The first major optimization technique for Federated Learning, which came out in 2017, was called FederatedAveraging. This algorithm combines local stochastic gradient descent (SGD) on each client with a server that performs model averaging. [2/n]
The paper, from Google, ran a number of experiments showing that this algorithm is robust to unbalanced and non-IID data distributions, and that it reduces the number of communication rounds needed for training by orders of magnitude. [3/n]
Since SGD had shown great results in deep learning, the authors decided to base the Federated Learning training algorithm on SGD as well. SGD can be applied naively in this setting: a single batch gradient calculation is done per round of communication across the clients. [4/n]
Before we get into the maths, I'll define some terms - [5/n]
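For reference, the original paper sets up the problem like this (I'm assuming these are the terms meant here): there are K clients, client k holds the examples indexed by P_k, with n_k = |P_k| and n = Σ_k n_k, f_i(w) is the loss on example i under model parameters w, and C is the fraction of clients selected in each round. The global objective is then the data-weighted average of the per-client objectives F_k:

```latex
f(w) = \sum_{k=1}^{K} \frac{n_k}{n} F_k(w),
\qquad
F_k(w) = \frac{1}{n_k} \sum_{i \in \mathcal{P}_k} f_i(w)
```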
I'll explain FederatedAveraging (FedAvg) by contrasting it with the baseline algorithm the authors presented, called FederatedSGD (FedSGD). [6/n]
For FedSGD, the parameter C (explained above), which controls the global batch size, is set to 1. This corresponds to full-batch (non-stochastic) gradient descent. For the current global model w_t, each client k computes g_k, the average gradient of the loss on its local data. [7/n]
The central server then aggregates these gradients and applies the update - [8/n]
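In the paper's notation (F_k is client k's average loss over its n_k local examples, n = Σ_k n_k, and η is a fixed learning rate), a FedSGD round with C = 1 looks like this:

```latex
g_k = \nabla F_k(w_t),
\qquad
w_{t+1} \leftarrow w_t - \eta \sum_{k=1}^{K} \frac{n_k}{n} \, g_k
```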
This was FedSGD. Now let's make a small change to the update above: each client takes one step of gradient descent on the current model using its local data, and the server then takes a weighted average of the resulting models. This gives exactly the same update as FedSGD. [9/n]
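Written out (with g_k = ∇F_k(w_t) the average local gradient of client k, n_k its number of examples, n = Σ_k n_k, and η the learning rate), the change is only a regrouping: the weights n_k/n sum to 1, so averaging the one-step client models reproduces the FedSGD step exactly.

```latex
w_{t+1}^{k} \leftarrow w_t - \eta \, g_k,
\qquad
w_{t+1} \leftarrow \sum_{k=1}^{K} \frac{n_k}{n} \, w_{t+1}^{k}
  = \Big( \sum_{k=1}^{K} \frac{n_k}{n} \Big) w_t - \eta \sum_{k=1}^{K} \frac{n_k}{n} \, g_k
  = w_t - \eta \sum_{k=1}^{K} \frac{n_k}{n} \, g_k
```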
This way, we can add more computation to each client by iterating the local update multiple times before the averaging step. This small modification gives us the FederatedAveraging (FedAvg) algorithm. [10/n]
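A toy numerical check of that equivalence, and of what changes once the local step is repeated. This is a minimal sketch assuming a least-squares loss, full-batch local steps and three hypothetical clients with random data; the helper names (`local_grad`, `fedsgd_step`, `fedavg_round`) are my own, not from the paper.

```python
import numpy as np

def local_grad(w, X, y):
    # Average gradient of 0.5 * (Xw - y)^2 over this client's n_k examples
    return X.T @ (X @ w - y) / len(y)

def fedsgd_step(w, clients, lr):
    # Server: weight each client's gradient by n_k / n, then take one step
    n = sum(len(y) for _, y in clients)
    g = sum(len(y) / n * local_grad(w, X, y) for X, y in clients)
    return w - lr * g

def fedavg_round(w, clients, lr, local_steps):
    # Clients: run `local_steps` gradient steps from w on local data.
    # Server: average the resulting models, weighting by n_k / n.
    n = sum(len(y) for _, y in clients)
    new_w = np.zeros_like(w)
    for X, y in clients:
        wk = w.copy()
        for _ in range(local_steps):
            wk = wk - lr * local_grad(wk, X, y)
        new_w += len(y) / n * wk
    return new_w

rng = np.random.default_rng(0)
# Three hypothetical clients holding 5, 20 and 50 examples
clients = [(rng.normal(size=(m, 3)), rng.normal(size=m)) for m in (5, 20, 50)]
w0 = np.zeros(3)

a = fedsgd_step(w0, clients, lr=0.1)
b = fedavg_round(w0, clients, lr=0.1, local_steps=1)
c = fedavg_round(w0, clients, lr=0.1, local_steps=5)
print(np.allclose(a, b))  # True: one local step + weighted average == FedSGD
print(np.allclose(a, c))
```

With `local_steps=1` the two updates agree up to floating-point error; with more local steps each client moves further toward its own optimum before averaging, which is where the extra client computation goes.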
Also, why do we make this modification to allow more compute on the client?
The answer is here - [11/n]
The amount of computation is controlled by three parameters:
C - Fraction of clients participating in each round
E - Number of training passes each client makes over its local dataset each round
B - Local minibatch size used for client updates
[12/n]
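The paper ties these knobs together as the number of local updates client k makes per round, u_k = E·n_k / B. A small sketch of that count; the function name is hypothetical, and rounding up for a final partial minibatch is my assumption (the formula itself assumes B divides n_k):

```python
import math

def local_updates_per_round(n_k, E, B):
    """u_k = E * n_k / B, counting a final partial minibatch as one update.
    B = math.inf means the whole local dataset is a single batch."""
    if math.isinf(B):
        return E  # one full-batch step per pass over the data
    return E * math.ceil(n_k / B)

# A client holding 600 examples (illustrative numbers)
print(local_updates_per_round(600, 1, math.inf))  # 1 (the FedSGD setting)
print(local_updates_per_round(600, 20, 10))       # 1200
print(local_updates_per_round(600, 5, 50))        # 60
```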
The pseudocode for the FedAvg algorithm is shown below. B = ∞ (used in the experiments) means the full local dataset is treated as a single minibatch. So, setting B = ∞ and E = 1 makes this exactly the FedSGD algorithm. [13/n]
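A compact Python sketch of the FedAvg loop with the C, E and B knobs, under loud assumptions: the model is linear least squares, clients are in-memory `(X, y)` NumPy arrays, the function names are hypothetical, and the averaging weights are normalised over the sampled clients so they sum to 1. It follows the structure of the paper's algorithm but is not the authors' code.

```python
import math
import numpy as np

def client_update(w, X, y, E, B, lr, rng):
    """Run E epochs of minibatch SGD (batch size B) on one client's least-squares loss."""
    w = w.copy()
    n_k = len(y)
    batch = n_k if math.isinf(B) else int(B)  # B = inf -> whole local dataset
    for _ in range(E):
        order = rng.permutation(n_k)  # reshuffle local data every epoch
        for start in range(0, n_k, batch):
            idx = order[start:start + batch]
            grad = X[idx].T @ (X[idx] @ w - y[idx]) / len(idx)
            w -= lr * grad
    return w

def fedavg(clients, w0, rounds, C, E, B, lr, seed=0):
    """Each round: sample m = max(C*K, 1) clients, update locally, average weighted by n_k."""
    rng = np.random.default_rng(seed)
    K = len(clients)
    m = max(int(C * K), 1)
    w = w0.copy()
    for _ in range(rounds):
        chosen = rng.choice(K, size=m, replace=False)
        n_total = sum(len(clients[k][1]) for k in chosen)
        w_new = np.zeros_like(w)
        for k in chosen:
            X, y = clients[k]
            w_new += len(y) / n_total * client_update(w, X, y, E, B, lr, rng)
        w = w_new
    return w

# Usage: four hypothetical clients with different dataset sizes
rng = np.random.default_rng(1)
true_w = np.array([1.0, -2.0, 0.5])
clients = []
for n_k in (30, 60, 90, 120):
    X = rng.normal(size=(n_k, 3))
    clients.append((X, X @ true_w + 0.01 * rng.normal(size=n_k)))

w = fedavg(clients, np.zeros(3), rounds=50, C=0.5, E=5, B=10, lr=0.05)
print(np.round(w, 2))
```

Running a single round with `C=1.0, E=1, B=math.inf` reproduces one FedSGD step, matching the B = ∞, E = 1 special case.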
Okay, now let's look at some experimental results, though I'd also suggest reading the results in the original paper.
One experiment measured the number of communication rounds needed to reach a target accuracy on two tasks: MNIST and a character modelling task. [14/n]
IID and non-IID here refer to two ways the authors partitioned the data among the clients. In the IID setting, the data is shuffled and split so that every client sees the same distribution; in the non-IID setting, each client's data follows a different distribution (for MNIST, each client gets examples of at most two digits). [15/n]
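A sketch of the two kinds of split for a labelled dataset. The paper's pathological non-IID MNIST split sorts examples by digit, cuts them into 200 shards of 300 and gives each of 100 clients two shards; the version below uses the same idea at a toy scale, with hypothetical labels and function names.

```python
import numpy as np

def iid_partition(labels, num_clients, rng):
    """IID split: shuffle all example indices and deal them out evenly."""
    idx = rng.permutation(len(labels))
    return np.array_split(idx, num_clients)

def shard_partition(labels, num_clients, shards_per_client, rng):
    """Non-IID split: sort indices by label, cut into equal shards,
    and give each client `shards_per_client` randomly chosen shards."""
    order = np.argsort(labels, kind="stable")
    num_shards = num_clients * shards_per_client
    shards = np.array_split(order, num_shards)
    shard_ids = rng.permutation(num_shards)
    parts = []
    for c in range(num_clients):
        mine = shard_ids[c * shards_per_client:(c + 1) * shards_per_client]
        parts.append(np.concatenate([shards[s] for s in mine]))
    return parts

rng = np.random.default_rng(0)
labels = np.repeat(np.arange(10), 60)  # 600 hypothetical examples, 60 per class
iid = iid_partition(labels, num_clients=10, rng=rng)
non_iid = shard_partition(labels, num_clients=10, shards_per_client=2, rng=rng)

# Number of distinct labels each client ends up with
for name, parts in (("IID", iid), ("non-IID", non_iid)):
    print(name, [len(np.unique(labels[p])) for p in parts])
```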
From the results, it can be seen that in both the IID and non-IID settings, a smaller minibatch size and a higher number of training passes on each client per round made the model converge faster, i.e., in fewer communication rounds. [16/n]
For all model classes, FedAvg converges to a higher level of test accuracy than the baseline FedSGD models. For the CNN, the B = ∞, E = 1 FedSGD model reaches 99.22% accuracy in 1200 rounds, while the B = 10, E = 20 FedAvg model reaches 99.44% accuracy in 300 rounds, i.e., higher accuracy with 4× fewer rounds. [17/n]
The authors also hypothesise that in addition to lowering communication costs, model averaging produces a regularization benefit similar to that achieved by dropout. [18/n]
That's the end for now!
This thread finishes my summary of the basics of Federated Learning and is also a concise version of the very famous paper "Communication-Efficient Learning of Deep Networks from Decentralized Data" by Google (3506 citations 🤯). arxiv.org/pdf/1602.05629…
I also have an annotated version of the paper on my GitHub.
Annotated paper - github.com/shreyansh26/An…
If the thread helps you or you have any questions, do let me know! 👋
[n/n]
I recently started reading about Privacy-preserving ML, as it's a topic that has always interested me.
I hope to share my learnings here on Twitter.
I started with Federated Learning and here's a detailed thread that will give you a high-level idea of FL🧵
Modern mobile devices have an abundance of data - textual data, image data. Applying ML to these can improve the user experience. [1/n]
However, combined over billions of devices, this data is large in quantity and, in most cases, privacy-sensitive. Storing it in a data center may be infeasible and also raises privacy concerns. [2/n]