
Preserving Data Privacy in Deep Learning | Part 1

Understanding the basics of Federated Learning and its implementation using PyTorch

Link to part 2 (Distribution of CIFAR10 into real-world/non-IID dataset): https://towardsdatascience.com/preserving-data-privacy-in-deep-learning-part-2-6c2e9494398b

Link to part 3 (Implementation of Federated Learning with non-IID dataset): https://towardsdatascience.com/preserving-data-privacy-in-deep-learning-part-3-ae2103c40c22

Photo by Harsh Yadav

Many thanks to renowned data scientist Mr. Akshay Kulkarni for his inspiration and guidance on this tutorial.

In a world revolutionized by data and digitalization, more and more personal information is shared and stored, which opens up a new field: preserving data privacy. But what is data privacy, and why does it need to be preserved?

Photo by Google AI

Data privacy defines how a particular piece of information or data should be handled, and who has authorized access to it, based on its relative importance. With the rise of AI (machine learning and deep learning), a lot of personal information can be extracted from trained models, which can cause irreparable damage to the people whose data has been exposed. Hence the need to preserve this data while implementing various machine learning models.

In this series of tutorials, the main concern is preserving data privacy in deep learning models. You will explore different methods such as Federated Learning, Differential Privacy, and Homomorphic Encryption.

In this tutorial, you will discover how to preserve data privacy in machine learning models using federated learning. After completing this tutorial, you will know:

  1. The basics of Federated Learning
  2. Dividing data amongst the clients (for federated learning)
  3. The model architecture
  4. Aggregation of the decentralized weights into the global weights

Introduction

Federated learning, also known as collaborative learning, is a deep learning technique in which training takes place across multiple decentralized edge devices (clients) or servers, each on its own personal data, without sharing that data with other clients, thus keeping the data private. It aims to train a machine learning algorithm, say a deep neural network, on multiple devices (clients) holding local datasets, without explicitly exchanging the data samples.

Photo by Google AI

This training happens simultaneously on hundreds or even thousands of devices. After the same model has been trained on different devices with different data, the weights (a summary of the training) are sent to the global server, where they are aggregated. Different aggregation techniques are used to get the most out of the weights learned on the clients' devices. After aggregation, the global weights are sent back to the clients, and training continues on the clients' devices. This entire cycle is called a communication round, and many such communication rounds take place to further improve the accuracy of the model.

Photo by Google AI

Federated learning makes it possible for AI algorithms to gain experience from a vast range of data located at different sites. This approach enables several organizations to collaborate on the development of models but without having to share sensitive data with each other. Over the course of several communication rounds, the shared model (global model) gets exposed to a significantly wider range of data than what any single client possesses, and at the same time, it learns the client-specific features as well.

There are two types of Federated Learning:

  1. Centralized federated learning: In this setting, a central server orchestrates the different steps of the algorithm and coordinates all participating nodes during the learning process. The server is responsible for selecting the nodes at the beginning of the training process and for aggregating the received model updates (weights). Because all updates pass through it, the server is the bottleneck of the system.
  2. Decentralized federated learning: In this type, nodes coordinate among themselves to obtain the global model. This setting prevents single-point failures, as model updates are exchanged only between interconnected nodes.

Now that we are clear on what federated learning is, let's build one from scratch in PyTorch and train it on the CIFAR-10 dataset. With the help of this tutorial, we will understand the basic flow of federated learning. We will go neither into the details of how server-client communication works in the real world nor into advanced aggregation techniques; those are discussed in the upcoming tutorials. In this study, the dataset is randomly divided among the clients, and all local models are trained on the same machine. Let's go.

1. Import all relevant packages

Photo by Google AI
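A minimal sketch of the imports assumed by the code snippets in this tutorial (all standard PyTorch and torchvision packages):

```python
# Standard packages used throughout this tutorial.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Train on a GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(0)  # illustrative seed for a reproducible client split
```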

2. Hyper-parameters

  1. num_clients: Total number of clients. The training data is divided into num_clients parts, with every client holding the same number of images.
  2. num_selected: Number of clients randomly selected from num_clients in each communication round, to be used in the training section. Generally, num_selected is around 30% of num_clients.
  3. num_rounds: Total number of communication rounds. In each communication round, num_selected clients are randomly chosen, training takes place on their devices, and the individual model weights are then aggregated into one global model.
  4. epochs: Total number of local training epochs on each selected client's device.
  5. batch_size: Number of images per batch when loading data into the data loaders.
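
A minimal sketch of these settings, using the values mentioned in this tutorial (the batch size of 32 is an illustrative choice, not prescribed by the text):

```python
num_clients = 20   # total number of clients; the training set is split evenly among them
num_selected = 6   # clients sampled per communication round (~30% of num_clients)
num_rounds = 150   # total number of communication rounds
epochs = 5         # local training epochs per selected client per round
batch_size = 32    # illustrative batch size for the data loaders
```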

3. Loading and dividing CIFAR 10 into clients

The CIFAR-10 dataset is used in this tutorial. It consists of 60,000 color images of 32×32 pixels in 10 classes, split into 50,000 training images and 10,000 test images, with 5,000 training images from each class. In PyTorch, CIFAR-10 is available through the torchvision module.

In this tutorial, the images are divided equally among the clients, representing the balanced (IID) case.

First, the image augmentation and normalization transforms for the training data are defined, to be used while loading the images. The training data is loaded with these augmentations and split via traindata_split into num_clients parts, i.e. 20 in our case. Finally, a train_loader is created for every client to feed its images into the neural network during training.

Similarly, a normalization-only transform is defined for the test images, and a test_loader is created to generate the test results of the given model.
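
A sketch of this data pipeline, assuming the imports and hyper-parameters above; the exact augmentation and normalization constants are a common CIFAR-10 recipe and may differ from the original code:

```python
# Augmentation + normalization for the training images (common CIFAR-10 recipe).
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
])

# Normalization only for the test images.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
])

# Load the 50,000 training images and split them evenly into num_clients IID shards.
traindata = datasets.CIFAR10("./data", train=True, download=True, transform=transform_train)
traindata_split = torch.utils.data.random_split(
    traindata, [len(traindata) // num_clients] * num_clients
)

# One DataLoader per client for local training.
train_loader = [DataLoader(x, batch_size=batch_size, shuffle=True) for x in traindata_split]

# A single loader over the full test set for evaluating the global model.
test_loader = DataLoader(
    datasets.CIFAR10("./data", train=False, download=True, transform=transform_test),
    batch_size=batch_size, shuffle=False,
)
```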

4. Building the Neural Network (Model Architecture)

VGG19 (16 convolutional layers, 3 fully connected layers, 5 max-pool layers, and 1 softmax layer) is used in this tutorial. Other variants of VGG, such as VGG11, VGG13, and VGG16, are also available.
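
A compact sketch of a CIFAR-10-sized VGG in this spirit is shown below (10 output classes, 32×32 inputs). Batch normalization after each convolution and the exact classifier head are assumptions; the softmax is applied implicitly by the cross-entropy loss used during training.

```python
# Numbers are output channels of 3x3 convolutions; 'M' marks a 2x2 max-pool.
vgg_cfg = {
    "VGG11": [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
    "VGG19": [64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M",
              512, 512, 512, 512, "M", 512, 512, 512, 512, "M"],
}

class VGG(nn.Module):
    def __init__(self, vgg_name="VGG19", num_classes=10):
        super().__init__()
        self.features = self._make_layers(vgg_cfg[vgg_name])
        # After five max-pools a 32x32 image is reduced to a 512-dim feature vector.
        self.classifier = nn.Sequential(
            nn.Linear(512, 512), nn.ReLU(inplace=True),
            nn.Linear(512, 512), nn.ReLU(inplace=True),
            nn.Linear(512, num_classes),  # logits; softmax is folded into the loss
        )

    def forward(self, x):
        out = self.features(x)
        out = out.view(out.size(0), -1)
        return self.classifier(out)

    @staticmethod
    def _make_layers(cfg):
        layers, in_channels = [], 3
        for v in cfg:
            if v == "M":
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            else:
                layers += [nn.Conv2d(in_channels, v, kernel_size=3, padding=1),
                           nn.BatchNorm2d(v),
                           nn.ReLU(inplace=True)]
                in_channels = v
        return nn.Sequential(*layers)
```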

5. Helper functions for Federated training

The client_update function trains the client model on the client's private data. This is the local training round that takes place on each of the num_selected clients, i.e. 6 in our case.
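
A sketch of such a function, assuming a cross-entropy loss (the structure of the original may differ slightly):

```python
def client_update(client_model, optimizer, train_loader, epoch=5):
    """Train one client model on its private data for `epoch` local passes."""
    client_model.train()
    for _ in range(epoch):
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = client_model(data)
            loss = F.cross_entropy(output, target)
            loss.backward()
            optimizer.step()
    return loss.item()  # last batch loss, used only for logging
```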

The server_aggregate function aggregates the model weights received from every client and updates the global model with the aggregated weights. In this tutorial, the mean of the client weights is taken as the global weights.
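
A sketch of this aggregation step: the global weights are set to the element-wise mean of the client weights, and the result is copied back into every client model for the next round:

```python
def server_aggregate(global_model, client_models):
    """Average the client weights into the global model (simple mean over equal IID shards)."""
    global_dict = global_model.state_dict()
    for key in global_dict.keys():
        global_dict[key] = torch.stack(
            [client.state_dict()[key].float() for client in client_models], dim=0
        ).mean(dim=0)
    global_model.load_state_dict(global_dict)
    # Send the aggregated global weights back to every client.
    for client in client_models:
        client.load_state_dict(global_model.state_dict())
```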

The test function is a standard evaluation function: it takes the global model and the test loader as input and returns the test loss and accuracy.
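
A sketch of this evaluation helper, consistent with the loss used in client_update:

```python
def test(global_model, test_loader):
    """Evaluate the global model; returns average test loss and accuracy."""
    global_model.eval()
    test_loss, correct = 0.0, 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = global_model(data)
            test_loss += F.cross_entropy(output, target, reduction="sum").item()
            correct += (output.argmax(dim=1) == target).sum().item()
    test_loss /= len(test_loader.dataset)
    acc = correct / len(test_loader.dataset)
    return test_loss, acc
```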

6. Training the model

One global model, along with the individual client models, is initialized with VGG19 on a GPU. In this tutorial, SGD is used as the optimizer for all the client models.

Instead of VGG19, one can also use VGG11, VGG13, or VGG16. Other optimizers are available as well; see the PyTorch documentation for more details.
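
A sketch of this initialization; the learning rate of 0.1 is an illustrative choice rather than a value given in the text:

```python
# One global model plus one model (and optimizer) per selected client,
# all starting from the same initial weights.
global_model = VGG("VGG19").to(device)
client_models = [VGG("VGG19").to(device) for _ in range(num_selected)]
for model in client_models:
    model.load_state_dict(global_model.state_dict())

# Plain SGD for every client model.
opt = [optim.SGD(model.parameters(), lr=0.1) for model in client_models]
```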

First, lists are created to keep track of the loss and accuracy on the train and test datasets. Training then proceeds one communication round at a time. In each round, num_selected clients are drawn from num_clients, i.e. 6 clients are randomly selected from the 20 available clients. Every selected client is trained locally using the client_update function. The weights are then aggregated using the server_aggregate function defined above, which updates the global model, the final model used for prediction. After being updated, the global model is evaluated with the help of the test function defined above.
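
Putting the pieces together, a sketch of the communication-round loop described above:

```python
losses_train, losses_test, acc_test = [], [], []

for r in range(num_rounds):
    # Randomly select num_selected clients out of num_clients for this round.
    client_idx = np.random.permutation(num_clients)[:num_selected]

    # Local training on each selected client's private shard.
    loss = 0.0
    for i in range(num_selected):
        loss += client_update(client_models[i], opt[i], train_loader[client_idx[i]], epoch=epochs)
    losses_train.append(loss / num_selected)

    # Aggregate the client weights into the global model and redistribute them.
    server_aggregate(global_model, client_models)

    # Evaluate the updated global model on the held-out test set.
    test_loss, acc = test(global_model, test_loader)
    losses_test.append(test_loss)
    acc_test.append(acc)
    print(f"round {r + 1:3d} | avg train loss {loss / num_selected:.3f} | "
          f"test loss {test_loss:.3f} | test acc {acc:.3f}")
```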

This process continues for num_rounds, i.e. 150 communication rounds in our case. This is an example of the most basic form of federated learning, which can be extended to a real client-based system.

7. Results

With 6 selected clients, each running 5 local epochs per communication round over 150 rounds, below is the truncated test result.

So, this model achieves around 89.8% accuracy using federated learning after 150 communication rounds. But how does it compare to a standard SGD model trained on the same dataset? For a comparative study, the VGG19 model is trained with a standard SGD optimizer on the same training data without federated learning. To ensure a level playing field, all the hyperparameters are kept the same, and it is trained for 200 epochs.

The standard SGD model achieves around 94% test accuracy after 200 epochs. So, our federated learning setup performs decently, at around 90% accuracy, compared to the original model.


SUMMARY

In this tutorial, you discovered how to preserve data privacy in deep learning models using federated learning with PyTorch.

Specifically, you learned:

  1. What data privacy is and why there is a need to preserve it.
  2. The basics of federated learning and its implementation on a balanced (IID) CIFAR-10 dataset. A similar model can be built for any image classification problem.

Below is the flow diagram for a quick revision of the entire process.


CONCLUSION

Federated learning is one of the best methods for preserving data privacy in machine learning models. The safety of client data is ensured by sending only the updated model weights, not the data itself. At the same time, the global model can learn client-specific features. But do not get too excited: the above results are unlikely in a real-world scenario, because real-world federated data held by clients is mostly not independent and identically distributed (non-IID). So, to tackle the problem of real-world/non-IID datasets, stay tuned for the next part of this series.

In the upcoming tutorials, you will learn not only about tackling non-IID datasets in federated learning, but also about different aggregation techniques, homomorphic encryption of model weights, differential privacy and its hybrid with federated learning, and a few more topics that help preserve data privacy.

Link to part 2 (Distribution of CIFAR10 into real-world/non-IID dataset): https://towardsdatascience.com/preserving-data-privacy-in-deep-learning-part-2-6c2e9494398b

Link to part 3 (Implementation of Federated Learning with non-IID dataset): https://towardsdatascience.com/preserving-data-privacy-in-deep-learning-part-3-ae2103c40c22


