The FederatedAveraging Algorithm
The \(\texttt{FederatedSGD}\) Algorithm
The \(\texttt{FederatedAveraging}\) Algorithm
Algorithm (Federated Averaging)
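Input: \(K\) clients, where client \(k\) holds a local dataset \(\mathcal{P}_k\) of \(n_k = |\mathcal{P}_k|\) examples; the number of local epochs \(E\); the learning rate \(\alpha\); the batch size \(B\); the fraction \(C\) of clients sampled per round; and the number of communication rounds \(T\)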
Output: the global model \(w_T\)
Server executes:
Initialize global model \(w_0\)
for each round \(t = 0, 1, \dots, T-1\) do
  \(m \leftarrow \max(C \cdot K, 1)\)
  \(S_t \leftarrow\) (random set of \(m\) clients)
  for each client \(k \in S_t\) in parallel do
    \(w^k_{t+1} \leftarrow \text{ClientUpdate}(k, w_t)\)
  end for
  \(n \leftarrow \sum_{k \in S_t} n_k\)
  \(w_{t+1} \leftarrow \sum_{k \in S_t} \frac{n_k}{n} w^k_{t+1}\)
end for
ClientUpdate(\(k, w\)):
  \(\mathcal{B} \leftarrow\) (split \(\mathcal{P}_k\) into batches of size \(B\))
  for \(e = 1, 2, \dots, E\) do
    for batch \(b \in \mathcal{B}\) do
      \(w \leftarrow w - \alpha \nabla \ell(w; b)\)
    end for
  end for
  return \(w\)
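A minimal pure-Python sketch of the server loop and ClientUpdate above may help make the data flow concrete. It rests on assumptions that are not part of the algorithm itself: each client fits a scalar linear model \(y \approx w x\) with squared loss, \(m\) is obtained by rounding \(C \cdot K\) down to an integer, clients are run one after another rather than in parallel, and the helper names `client_update` and `federated_averaging` are hypothetical.

```python
import random


def client_update(w, data, epochs, lr, batch_size):
    """Hypothetical ClientUpdate(k, w): E epochs of minibatch SGD on one client.

    Assumed model: scalar linear regression y ~ w * x with squared loss,
    so the gradient of mean((w*x - y)**2) over a batch is mean(2*(w*x - y)*x).
    """
    # Split P_k into batches of size B once, as in the pseudocode.
    batches = [data[i:i + batch_size] for i in range(0, len(data), batch_size)]
    for _ in range(epochs):
        for batch in batches:
            grad = sum(2 * (w * x - y) * x for x, y in batch) / len(batch)
            w = w - lr * grad  # w <- w - alpha * grad l(w; b)
    return w


def federated_averaging(client_data, rounds, frac, epochs, lr, batch_size,
                        w0=0.0, seed=0):
    """Hypothetical server loop: sample m clients per round, average their models."""
    rng = random.Random(seed)
    num_clients = len(client_data)  # K
    # m <- max(C*K, 1); rounding C*K down to an integer is an assumption.
    m = max(int(frac * num_clients), 1)
    w = w0
    for _ in range(rounds):
        selected = rng.sample(range(num_clients), m)  # S_t
        # Each ClientUpdate depends only on the current global w, so these
        # calls could run in parallel; they run sequentially here for simplicity.
        results = [(len(client_data[k]),
                    client_update(w, client_data[k], epochs, lr, batch_size))
                   for k in selected]
        n = sum(n_k for n_k, _ in results)
        # w_{t+1} <- sum over k in S_t of (n_k / n) * w^k_{t+1}
        w = sum(n_k / n * w_k for n_k, w_k in results)
    return w


# Toy data: 5 clients with different dataset sizes, all drawn from y = 3x.
data_rng = random.Random(1)
clients = []
for size in (10, 20, 40, 80, 15):
    xs = [data_rng.uniform(-1.0, 1.0) for _ in range(size)]
    clients.append([(x, 3.0 * x) for x in xs])

w = federated_averaging(clients, rounds=50, frac=0.4, epochs=2, lr=0.1,
                        batch_size=5)
print(w)
```

Because every client's data in this toy setup satisfies \(y = 3x\) exactly, all local optima coincide and the printed value should be very close to \(3\). The weights \(n_k / n\) mean that a sampled client with more examples pulls the averaged model harder than one with fewer.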