src/pqc_federated_learning/aggregators/fedavg.py
1.8 KB · 46 lines · python Raw
1 """FedAvg: num_samples-weighted mean of client gradient tensors."""
2
3 from __future__ import annotations
4
5 from pqc_federated_learning.aggregators.base import Aggregator
6 from pqc_federated_learning.errors import InsufficientUpdatesError, ShapeMismatchError
7 from pqc_federated_learning.update import ClientUpdate, GradientTensor
8
9
class FedAvgAggregator(Aggregator):
    """Federated Averaging: per-tensor mean of client updates weighted by sample count.

    Each client's weight is ``max(1, num_samples) / total``, where ``total`` is
    the sum of the same clamped counts over all clients. A client reporting a
    zero (or negative) sample count therefore still contributes with weight 1
    instead of being dropped, and the total can never be zero.
    """

    name = "fedavg"

    def aggregate(self, updates: list[ClientUpdate]) -> list[GradientTensor]:
        """Return the sample-weighted mean of every tensor across *updates*.

        Args:
            updates: Client updates to combine. All must carry the same tensor
                names in the same order, with identical shapes, and each
                tensor's flat ``values`` must hold exactly ``prod(shape)``
                elements.

        Returns:
            One ``GradientTensor`` per tensor of the first update, in the same
            order, with the same name and shape and the weighted-mean values.

        Raises:
            InsufficientUpdatesError: If *updates* is empty.
            ShapeMismatchError: If an update's tensor names/order differ from
                the first update's, a tensor's shape differs, or a tensor's
                value count does not match its declared shape.
        """
        if not updates:
            raise InsufficientUpdatesError("FedAvg requires at least one update")

        reference = list(updates[0].tensors)
        tensor_names = [t.name for t in reference]
        for u in updates[1:]:
            if [t.name for t in u.tensors] != tensor_names:
                raise ShapeMismatchError(
                    f"client {u.metadata.client_did} has different tensor names"
                )
            # Names already match position-by-position, so compare shapes
            # positionally; a name->shape dict would pick the wrong reference
            # if a tensor name were repeated.
            for ref, t in zip(reference, u.tensors):
                if t.shape != ref.shape:
                    raise ShapeMismatchError(
                        f"tensor {t.name} shape mismatch: {t.shape} vs {ref.shape}"
                    )

        # Weights depend only on the update, so compute them once rather than
        # once per tensor.
        total_weight = sum(max(1, u.metadata.num_samples) for u in updates)
        weights = [max(1, u.metadata.num_samples) / total_weight for u in updates]

        out: list[GradientTensor] = []
        # zip(*...) yields, for each tensor position, that tensor from every
        # update (first update first). The validation above guarantees
        # equal counts and names at each position.
        for group in zip(*(u.tensors for u in updates)):
            ref = group[0]
            length = 1
            for d in ref.shape:
                length *= d
            agg = [0.0] * length
            for u, weight, tensor in zip(updates, weights, group):
                values = tensor.values
                # Without this check, extra values raise IndexError below and
                # missing values silently leave zeros in the average.
                if len(values) != length:
                    raise ShapeMismatchError(
                        f"tensor {tensor.name} from client "
                        f"{u.metadata.client_did} has {len(values)} values, "
                        f"expected {length} for shape {ref.shape}"
                    )
                for i, v in enumerate(values):
                    agg[i] += weight * v
            out.append(GradientTensor(name=ref.name, shape=ref.shape, values=tuple(agg)))
        return out
46