Hands-on Tutorials

Federated Learning: A Simple Implementation of FedAvg (Federated Averaging) with PyTorch

Ece Işık Polat
Towards Data Science
6 min read · Sep 24, 2020



Mobile devices such as phones, tablets, and smartwatches are now primary computing devices and have become integral for many people. These devices host huge amounts of valuable and private data thanks to the combination of rich user interactions and powerful sensors. Models trained on such data could significantly improve the usability and power of intelligent applications. However, the sensitive nature of this data means there are also some risks and responsibilities [1]. At this point, the Federated Learning (FL) concept comes into play.

In FL, each client trains its model locally, in a decentralized fashion; in other words, the model training process is carried out separately on each client's own data. Only the learned model parameters are sent to a trusted center, where they are combined into an aggregated main model. The trusted center then sends the aggregated main model back to the clients, and this cycle is repeated [2].

In this context, I prepared a simple implementation with IID (independent and identically distributed) data to show how the parameters of hundreds of different models running on different nodes can be combined with the FedAvg method, and whether the resulting model gives reasonable results. The implementation was carried out on the MNIST data set, which contains 28×28-pixel grayscale images of handwritten digits from 0 to 9 [3].

Handwritten Digits from the MNIST dataset (Image by Author*)
  • The MNIST data set does not contain an equal number of examples of each label. Therefore, to fulfill the IID requirement, the data set was grouped by label, shuffled, and then distributed so that each node contains an equal number of each label.
  • A simple 2-layer model was created for the classification process (a sketch of such a model is given after this list).
  • Functions to be used for FedAvg are defined.
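
Since the original code lives in the linked notebook, here is only a minimal sketch of what such a 2-layer classifier could look like in PyTorch; the class name SimpleNet and the hidden size of 200 are my own placeholder choices, not taken from the author's code:

import torch.nn as nn

class SimpleNet(nn.Module):
    """Placeholder 2-layer fully connected classifier for flattened 28x28 MNIST images."""
    def __init__(self, input_size=784, hidden_size=200, num_classes=10):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, num_classes)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(self.fc1(x))   # hidden layer with ReLU activation
        return self.fc2(x)           # raw logits; CrossEntropyLoss applies log-softmax internally

The later sketches in this post reuse this SimpleNet placeholder wherever a node model or the main model is instantiated.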

Here, an iteration is completed as follows.

  1. Since the parameters of the main model and parameters of all local models in the nodes are randomly initialized, all these parameters will be different from each other. For this reason, the main model sends its parameters to the nodes before the training of local models in the nodes begins.
  2. Nodes start to train their local models over their own data by using these parameters.
  3. Each node updates its parameters while training its own model. After the training process is completed, each node sends its parameters to the main model.
  4. The main model takes the average of these parameters and sets them as its new weight parameters and passes them back to the nodes for the next iteration.

The above flow is for one iteration. This iteration can be repeated over and over to improve the performance of the main model.
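
Conceptually, one such iteration can be sketched as follows; local_train stands for an ordinary PyTorch training loop and is a placeholder here, as are the node models and data loaders:

import copy
import torch

def fedavg_round(main_model, node_models, node_loaders, local_train):
    # 1. Broadcast: every node starts from the current main-model weights.
    for model in node_models:
        model.load_state_dict(copy.deepcopy(main_model.state_dict()))

    # 2-3. Each node trains on its own data and keeps its updated weights.
    for model, loader in zip(node_models, node_loaders):
        local_train(model, loader)

    # 4. Aggregate: average the node weights parameter by parameter and
    #    load the result back into the main model.
    avg_state = copy.deepcopy(node_models[0].state_dict())
    for key in avg_state:
        avg_state[key] = torch.stack(
            [m.state_dict()[key].float() for m in node_models]).mean(dim=0)
    main_model.load_state_dict(avg_state)
    return main_model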

Note: The purpose here is not to increase the performance of the classification algorithm, but to compare the performance of the model obtained with federated learning with a centralized model. If you wish, you can use more complicated models or adjust hyperparameters to improve performance.

And if you are ready, here we go!


Function Explanations

Functions for data distribution

  • split_and_shuffle_labels(y_data, seed, amount): The data set does not contain an equal number of each label, so to distribute the data to the nodes as IID, an equal number of examples must be taken from each label. This function takes the given amount of examples from each label and shuffles the order within each group. Note that what is shuffled here are the indexes of the data; we will use them later when retrieving the data itself. These indexes need to be reset to avoid key errors, so a new column is defined and the shuffled indexes are kept there.
  • get_iid_subsamples_indices(label_dict, number_of_samples, amount): This function divides the indexes among the nodes so that each node receives an equal number of indexes from each label. (At this point it is still the indexes being distributed, not the data itself.)
  • create_iid_subsamples(sample_dict, x_data, y_data, x_name, y_name): This function distributes the x and y data to the nodes as a dictionary. Rough sketches of these three helpers are given below.
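
The authoritative implementations are in the linked notebook; the sketch below only illustrates how these three helpers might look, assuming x_data and y_data are plain NumPy arrays of flattened images and integer labels. The variable and column names here are my own guesses:

import numpy as np
import pandas as pd
import torch

def split_and_shuffle_labels(y_data, seed, amount):
    # Keep, for each label, `amount` shuffled example indexes (not the data itself).
    y_df = pd.DataFrame(y_data, columns=["labels"])
    y_df["i"] = np.arange(len(y_df))
    label_dict = {}
    for label in range(10):
        subset = y_df[y_df["labels"] == label]
        subset = subset.sample(frac=1, random_state=seed)[:amount].reset_index(drop=True)
        label_dict["label" + str(label)] = subset
    return label_dict

def get_iid_subsamples_indices(label_dict, number_of_samples, amount):
    # Give every node an equal slice of indexes from every label (still indexes, not data).
    sample_dict = {}
    per_node = amount // number_of_samples
    for node in range(number_of_samples):
        indices = np.array([], dtype=int)
        for label_df in label_dict.values():
            start = node * per_node
            chunk = label_df["i"].iloc[start:start + per_node].values
            indices = np.concatenate([indices, chunk])
        sample_dict["sample" + str(node)] = indices
    return sample_dict

def create_iid_subsamples(sample_dict, x_data, y_data, x_name, y_name):
    # Materialize per-node x/y tensors from the index dictionary.
    data_dict = {}
    for i, node in enumerate(sample_dict):
        idx = np.sort(sample_dict[node])
        data_dict[x_name + str(i)] = torch.tensor(x_data[idx], dtype=torch.float32)
        data_dict[y_name + str(i)] = torch.tensor(y_data[idx], dtype=torch.long)
    return data_dict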

Functions for FedAvg

  • create_model_optimizer_criterion_dict(number_of_samples): This function creates a model, optimizer and loss function for each node.
  • get_averaged_weights(model_dict, number_of_samples): This function takes the average of the weights in individual nodes.
  • set_averaged_weights_as_main_model_weights_and_update_main_model(main_model, model_dict, number_of_samples): This function sets the averaged weights of the individual nodes as the new weights of the main model (it calls get_averaged_weights(model_dict, number_of_samples)).
  • compare_local_and_merged_model_performance(number_of_samples): This function compares the accuracy of the main model and of the local model running on each node.
  • send_main_model_to_nodes_and_update_model_dict(main_model, model_dict, number_of_samples): This function sends the parameters of the main model to the nodes.
  • start_train_end_node_process_without_print(): This function trains individual local models in nodes.
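
As a rough sketch, the averaging and broadcasting helpers could be written as follows, assuming model_dict maps keys like "model0", "model1", ... to the node models (the exact key scheme in the notebook may differ):

import copy
import torch

def get_averaged_weights(model_dict, number_of_samples):
    # Average the state_dicts of all node models, parameter by parameter.
    models = [model_dict["model" + str(i)] for i in range(number_of_samples)]
    avg_state = copy.deepcopy(models[0].state_dict())
    with torch.no_grad():
        for key in avg_state:
            avg_state[key] = torch.stack(
                [m.state_dict()[key].float() for m in models]).mean(dim=0)
    return avg_state

def set_averaged_weights_as_main_model_weights_and_update_main_model(
        main_model, model_dict, number_of_samples):
    # Load the averaged node weights into the main model.
    main_model.load_state_dict(get_averaged_weights(model_dict, number_of_samples))
    return main_model

def send_main_model_to_nodes_and_update_model_dict(
        main_model, model_dict, number_of_samples):
    # Overwrite every node model with the current main-model weights.
    for i in range(number_of_samples):
        model_dict["model" + str(i)].load_state_dict(
            copy.deepcopy(main_model.state_dict()))
    return model_dict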

What would the performance of a centralized model that is based on all training data be?

First, let's examine what the performance of a centralized model would be if the data were not distributed to the nodes at all.
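
The centralized baseline is just ordinary mini-batch training on the full training set. A minimal sketch is below; the optimizer, learning rate, and batch size are assumptions rather than the notebook's exact settings, and x_train/y_train/x_test/y_test are assumed to be tensors of flattened images and integer labels:

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

def train_centralized(model, x_train, y_train, x_test, y_test,
                      epochs=10, batch_size=32, lr=0.01):
    # Train one model on all of the data and report train/test accuracy per epoch.
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    train_loader = DataLoader(TensorDataset(x_train, y_train),
                              batch_size=batch_size, shuffle=True)
    for epoch in range(epochs):
        model.train()
        correct = 0
        for xb, yb in train_loader:
            optimizer.zero_grad()
            out = model(xb)
            loss = criterion(out, yb)
            loss.backward()
            optimizer.step()
            correct += (out.argmax(dim=1) == yb).sum().item()
        model.eval()
        with torch.no_grad():
            test_acc = (model(x_test).argmax(dim=1) == y_test).float().mean().item()
        print(f"epoch: {epoch + 1} | train accuracy: {correct / len(x_train):.4f} "
              f"| test accuracy: {test_acc:.4f}")

The notebook's centralized run produced the following output: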

------ Centralized Model ------
epoch: 1 | train accuracy: 0.8743 | test accuracy: 0.9437
epoch: 2 | train accuracy: 0.9567 | test accuracy: 0.9654
epoch: 3 | train accuracy: 0.9712 | test accuracy: 0.9701
epoch: 4 | train accuracy: 0.9785 | test accuracy: 0.9738
epoch: 5 | train accuracy: 0.9834 | test accuracy: 0.9713
epoch: 6 | train accuracy: 0.9864 | test accuracy: 0.9768
epoch: 7 | train accuracy: 0.9898 | test accuracy: 0.9763
epoch: 8 | train accuracy: 0.9923 | test accuracy: 0.9804
epoch: 9 | train accuracy: 0.9941 | test accuracy: 0.9784
epoch: 10 | train accuracy: 0.9959 | test accuracy: 0.9792
------ Training finished ------

The model used in this example is very simple; various improvements could be made to increase performance, such as using a more complex model, training for more epochs, or tuning hyperparameters. However, the purpose here is to compare the performance of the main model, formed by combining the parameters of local models trained on their own data, with a centralized model trained on all of the training data. In this way, we can gain insight into the capacity of federated learning.

Now, let's start our first iteration.

Data is distributed to nodes
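
Roughly, this step amounts to calling the distribution helpers sketched earlier; the node count and per-label amount below are illustrative values, not necessarily those used in the notebook, and x_train/y_train are assumed to be NumPy arrays:

# Illustrative node count and per-label sample count
number_of_samples = 100          # number of nodes
amount = 5400                    # examples taken from each label

label_dict = split_and_shuffle_labels(y_train, seed=1, amount=amount)
sample_dict = get_iid_subsamples_indices(label_dict, number_of_samples, amount)
train_dict = create_iid_subsamples(sample_dict, x_train, y_train, "x_train", "y_train")
# The test set could be distributed to the nodes in the same way if per-node evaluation is wanted.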

The main model is created
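
Creating the main (global) model is just instantiating the network with its own optimizer and loss function; a sketch using the SimpleNet placeholder and an assumed SGD learning rate:

import torch
import torch.nn as nn

main_model = SimpleNet()                                              # placeholder net from the earlier sketch
main_optimizer = torch.optim.SGD(main_model.parameters(), lr=0.01)    # assumed optimizer and learning rate
main_criterion = nn.CrossEntropyLoss()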

Models, optimizers, and loss functions in nodes are defined
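
A possible sketch of create_model_optimizer_criterion_dict, building one (model, optimizer, criterion) triple per node; again, SimpleNet and the learning rate are placeholders:

import torch
import torch.nn as nn

def create_model_optimizer_criterion_dict(number_of_samples, lr=0.01):
    # One (model, optimizer, criterion) triple per node, keyed by index.
    model_dict, optimizer_dict, criterion_dict = {}, {}, {}
    for i in range(number_of_samples):
        model = SimpleNet()                      # placeholder net from the earlier sketch
        model_dict["model" + str(i)] = model
        optimizer_dict["optimizer" + str(i)] = torch.optim.SGD(model.parameters(), lr=lr)
        criterion_dict["criterion" + str(i)] = nn.CrossEntropyLoss()
    return model_dict, optimizer_dict, criterion_dict

model_dict, optimizer_dict, criterion_dict = create_model_optimizer_criterion_dict(number_of_samples)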

Keys of the dicts are made iterable
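
This simply means turning the dictionary key views into plain lists so that a node's entries can be fetched by position, for example:

# Turn the dictionary key views into plain lists so entries can be fetched by position.
name_of_models = list(model_dict.keys())            # "model0", "model1", ...
name_of_optimizers = list(optimizer_dict.keys())    # "optimizer0", "optimizer1", ...
name_of_criterions = list(criterion_dict.keys())    # "criterion0", "criterion1", ...
name_of_train_sets = list(train_dict.keys())        # "x_train0", "y_train0", ...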

Parameters of the main model are sent to nodes
Since the parameters of the main model and the parameters of all local models in the nodes are randomly initialized, all these parameters will be different from each other. For this reason, the main model sends its parameters to the nodes before the training of the local models in the nodes begins. You can check the actual weight values in the notebook.
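
With the helpers sketched earlier, this broadcast is a single call:

# Broadcast the (randomly initialized) main-model weights to every node.
model_dict = send_main_model_to_nodes_and_update_model_dict(
    main_model, model_dict, number_of_samples)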

Models in the nodes are trained
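
A rough sketch of start_train_end_node_process_without_print, assuming each node trains for a fixed number of local epochs with mini-batches (the epoch count and batch size are assumptions) and using the dictionaries built above:

from torch.utils.data import DataLoader, TensorDataset

def start_train_end_node_process_without_print(number_of_samples, epochs=10, batch_size=32):
    # Train each node's model on its own data slice, without any logging.
    for i in range(number_of_samples):
        loader = DataLoader(
            TensorDataset(train_dict["x_train" + str(i)], train_dict["y_train" + str(i)]),
            batch_size=batch_size, shuffle=True)
        model = model_dict["model" + str(i)]
        optimizer = optimizer_dict["optimizer" + str(i)]
        criterion = criterion_dict["criterion" + str(i)]
        model.train()
        for _ in range(epochs):
            for xb, yb in loader:
                optimizer.zero_grad()
                loss = criterion(model(xb), yb)
                loss.backward()
                optimizer.step()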

Let’s compare the performance of the federated main model and centralized model

Federated main model vs centralized model before 1st iteration (on all test data)
Since the main model is randomly initialized and has not been trained yet, its performance before the first iteration is very poor. After the first iteration, the accuracy of the main model increased to 85%.

Before 1st iteration main model accuracy on all test data: 0.1180
After 1st iteration main model accuracy on all test data: 0.8529
Centralized model accuracy on all test data: 0.9790
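
The "after" number above comes from averaging the freshly trained node weights into the main model and evaluating it on the full test set; with the helpers sketched earlier (and x_test/y_test as tensors), this is roughly:

# Average the freshly trained node weights into the main model ...
main_model = set_averaged_weights_as_main_model_weights_and_update_main_model(
    main_model, model_dict, number_of_samples)

# ... and evaluate it on the full (undistributed) test set.
main_model.eval()
with torch.no_grad():
    preds = main_model(x_test).argmax(dim=1)
    accuracy = (preds == y_test).float().mean().item()
print("main model accuracy on all test data:", accuracy)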

This was a single iteration; we can send the parameters of the main model back to the nodes and repeat the above steps. Now let's check how the performance of the main model improves when we repeat the iteration 10 more times.
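
Repeating the process is just a loop over the broadcast, local-training, and averaging steps; a sketch using the helpers above:

for iteration in range(2, 12):
    # Broadcast the current main-model weights to all nodes.
    model_dict = send_main_model_to_nodes_and_update_model_dict(
        main_model, model_dict, number_of_samples)
    # Local training on each node's own data.
    start_train_end_node_process_without_print(number_of_samples)
    # Aggregate the node weights back into the main model.
    main_model = set_averaged_weights_as_main_model_weights_and_update_main_model(
        main_model, model_dict, number_of_samples)
    main_model.eval()
    with torch.no_grad():
        acc = (main_model(x_test).argmax(dim=1) == y_test).float().mean().item()
    print(f"Iteration {iteration} : main_model accuracy on all test data: {acc:.4f}")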

Iteration 2 : main_model accuracy on all test data:  0.8928
Iteration 3 : main_model accuracy on all test data: 0.9073
Iteration 4 : main_model accuracy on all test data: 0.9150
Iteration 5 : main_model accuracy on all test data: 0.9209
Iteration 6 : main_model accuracy on all test data: 0.9273
Iteration 7 : main_model accuracy on all test data: 0.9321
Iteration 8 : main_model accuracy on all test data: 0.9358
Iteration 9 : main_model accuracy on all test data: 0.9382
Iteration 10 : main_model accuracy on all test data: 0.9411
Iteration 11 : main_model accuracy on all test data: 0.9431

The accuracy of the centralized model was calculated as approximately 98%. The accuracy of the main model obtained by the FedAvg method started at 85% and improved to 94%. In this case, we can say that although the main model obtained by the FedAvg method was trained without ever seeing the raw data, its performance should not be underestimated.

You can visit https://github.com/eceisik/fl_public/blob/master/fedavg_mnist_iid.ipynb to see the full implementation.

*You can visit the GitHub page linked above.

[1] J. Konečný, H. B. McMahan, D. Ramage, and P. Richtárik, “Federated Optimization: Distributed Machine Learning for On-Device Intelligence,” arXiv preprint, pp. 1–38, 2016.

[2] H. B. McMahan and D. Ramage, “Communication-Efficient Learning of Deep Networks from Decentralized Data,” Proceedings of AISTATS, vol. 54, 2017.

[3] Y. LeCun, L. Bottou, Y. Bengio, and P. Haffner. “Gradient-based learning applied to document recognition.” Proceedings of the IEEE, 86(11):2278–2324, November 1998.
