src/pqc_federated_learning/aggregators/fedsum.py
1.5 KB · 41 lines · python Raw
1 """FedSum: unweighted elementwise sum of client tensors."""
2
3 from __future__ import annotations
4
5 from pqc_federated_learning.aggregators.base import Aggregator
6 from pqc_federated_learning.errors import InsufficientUpdatesError, ShapeMismatchError
7 from pqc_federated_learning.update import ClientUpdate, GradientTensor
8
9
class FedSumAggregator(Aggregator):
    """Combine client updates by unweighted elementwise summation.

    Every update must carry the same tensors, in the same order, with the
    same shapes. The result holds one summed ``GradientTensor`` per tensor
    position of the first update.
    """

    name = "fedsum"

    @staticmethod
    def _numel(shape) -> int:
        """Return the number of elements implied by *shape* (1 for ``()``)."""
        count = 1
        for d in shape:
            count *= d
        return count

    def aggregate(self, updates: list[ClientUpdate]) -> list[GradientTensor]:
        """Return the elementwise sum of the tensors in *updates*.

        Args:
            updates: Client updates to combine. Tensors are matched by
                position; their names must agree, in order, across updates.

        Returns:
            One ``GradientTensor`` per tensor of the first update, carrying
            that tensor's name and shape, whose ``values`` are the float sum
            (accumulated from 0.0, in update order) of the corresponding
            values across all updates.

        Raises:
            InsufficientUpdatesError: If *updates* is empty.
            ShapeMismatchError: If tensor names, shapes, or value counts
                disagree across updates.
        """
        if not updates:
            raise InsufficientUpdatesError("FedSum requires at least one update")

        reference = list(updates[0].tensors)
        names = [t.name for t in reference]
        for u in updates[1:]:
            if [t.name for t in u.tensors] != names:
                raise ShapeMismatchError("tensor name mismatch across updates")
            # Names line up position-by-position, so shapes are compared
            # positionally as well; a name->shape dict would silently
            # collapse duplicate names.
            for t, ref in zip(u.tensors, reference):
                if t.shape != ref.shape:
                    raise ShapeMismatchError(
                        f"{t.name} shape mismatch: {t.shape} vs {ref.shape}"
                    )

        sums = [[0.0] * self._numel(ref.shape) for ref in reference]
        for u in updates:
            # Pair tensors by position (validated above) instead of searching
            # each update's tensor list by name, which was O(T^2) per update.
            for ref, acc, tensor in zip(reference, sums, u.tensors):
                # Guard against value buffers that disagree with the declared
                # shape: too many would raise a bare IndexError, too few would
                # be silently zero-padded.
                if len(tensor.values) != len(acc):
                    raise ShapeMismatchError(
                        f"{tensor.name} has {len(tensor.values)} values, "
                        f"expected {len(acc)} for shape {ref.shape}"
                    )
                for i, v in enumerate(tensor.values):
                    acc[i] += v

        return [
            GradientTensor(name=ref.name, shape=ref.shape, values=tuple(acc))
            for ref, acc in zip(reference, sums)
        ]
41