src/pqc_federated_learning/aggregators/fedmedian.py
1.4 KB · 40 lines · python Raw
1 """FedMedian: element-wise median of client tensors. Robust to outliers."""
2
3 from __future__ import annotations
4
5 import statistics
6
7 from pqc_federated_learning.aggregators.base import Aggregator
8 from pqc_federated_learning.errors import InsufficientUpdatesError, ShapeMismatchError
9 from pqc_federated_learning.update import ClientUpdate, GradientTensor
10
11
class FedMedianAggregator(Aggregator):
    """Aggregator that combines client updates by element-wise median.

    For every tensor, the i-th output value is the median of the i-th
    values of that tensor across all client updates.
    """

    name = "fedmedian"

    def aggregate(self, updates: list[ClientUpdate]) -> list[GradientTensor]:
        """Return one tensor per input tensor holding the element-wise median.

        Args:
            updates: Client updates to combine. Every update must carry the
                same tensor names in the same order, and each tensor must
                have the same shape (and a matching number of values) as
                the corresponding tensor of the first update.

        Returns:
            A list of ``GradientTensor`` objects, in the tensor order of the
            first update, whose values are the per-element medians.

        Raises:
            InsufficientUpdatesError: If ``updates`` is empty.
            ShapeMismatchError: If tensor names, shapes, or value counts
                differ between updates.
        """
        if not updates:
            raise InsufficientUpdatesError("FedMedian requires at least one update")

        reference = list(updates[0].tensors)
        names = [t.name for t in reference]
        for u in updates[1:]:
            if [t.name for t in u.tensors] != names:
                raise ShapeMismatchError("tensor name mismatch across updates")

        # The name check above guarantees every update lists the same tensors
        # in the same order, so tensors can be grouped by position once
        # instead of being searched for by name for every single element.
        groups = zip(*(u.tensors for u in updates))

        out: list[GradientTensor] = []
        for ref, group in zip(reference, groups):
            shape = ref.shape
            length = 1
            for d in shape:
                length *= d

            for tensor in group:
                # Compare as tuples so a list-vs-tuple shape is not flagged.
                if tuple(tensor.shape) != tuple(shape):
                    raise ShapeMismatchError(
                        f"tensor {ref.name!r} shape mismatch across updates"
                    )
                # A wrong value count would otherwise surface as a bare
                # IndexError (too few) or be silently truncated (too many).
                if len(tensor.values) != length:
                    raise ShapeMismatchError(
                        f"tensor {ref.name!r} has {len(tensor.values)} values, "
                        f"expected {length}"
                    )

            # zip(*...) yields, for each element index, that element's value
            # from every client; the median is taken over each such column.
            agg = tuple(
                statistics.median(column)
                for column in zip(*(tensor.values for tensor in group))
            )
            out.append(GradientTensor(name=ref.name, shape=shape, values=agg))
        return out
40