src/pqc_federated_learning/aggregators/fedmedian.py
| 1 | """FedMedian: element-wise median of client tensors. Robust to outliers.""" |
| 2 | |
| 3 | from __future__ import annotations |
| 4 | |
| 5 | import statistics |
| 6 | |
| 7 | from pqc_federated_learning.aggregators.base import Aggregator |
| 8 | from pqc_federated_learning.errors import InsufficientUpdatesError, ShapeMismatchError |
| 9 | from pqc_federated_learning.update import ClientUpdate, GradientTensor |
| 10 | |
| 11 | |
class FedMedianAggregator(Aggregator):
    """Aggregate client updates by taking the element-wise median.

    Each output value is ``statistics.median`` of the values at the same
    position across all client updates. A single client reporting extreme
    values cannot pull the median arbitrarily far, which is what makes this
    aggregator robust to outliers.
    """

    name = "fedmedian"

    def aggregate(self, updates: list[ClientUpdate]) -> list[GradientTensor]:
        """Return the element-wise median of the tensors in *updates*.

        Args:
            updates: Client updates to combine. Every update must carry
                tensors with the same names, in the same order, with the same
                shapes and value counts as the first update.

        Returns:
            One ``GradientTensor`` per tensor of the first update, in the same
            order, whose values are the per-element medians. With an even
            number of updates ``statistics.median`` returns the mean of the
            two middle values.

        Raises:
            InsufficientUpdatesError: if *updates* is empty.
            ShapeMismatchError: if tensor names, shapes, or value counts
                disagree across updates.
        """
        if not updates:
            raise InsufficientUpdatesError("FedMedian requires at least one update")

        # Materialise each update's tensors once so they can be paired by
        # position instead of being re-searched by name for every element.
        per_update = [list(u.tensors) for u in updates]
        names = [t.name for t in per_update[0]]
        for tensors in per_update[1:]:
            if [t.name for t in tensors] != names:
                raise ShapeMismatchError("tensor name mismatch across updates")

        out: list[GradientTensor] = []
        # Name lists are identical (and therefore equal length), so zip pairs
        # the same tensor from every update without truncation.
        for group in zip(*per_update):
            reference = group[0]
            shape = reference.shape
            expected_shape = tuple(shape)
            length = 1
            for dim in shape:
                length *= dim

            # Validate before indexing: a mismatched client would otherwise
            # raise a bare IndexError, or have surplus values silently dropped.
            for tensor in group:
                if tuple(tensor.shape) != expected_shape:
                    raise ShapeMismatchError(
                        f"shape mismatch for tensor {reference.name!r}: "
                        f"{tuple(tensor.shape)} != {expected_shape}"
                    )
                if len(tensor.values) != length:
                    raise ShapeMismatchError(
                        f"tensor {reference.name!r} has {len(tensor.values)} "
                        f"values, expected {length} for shape {expected_shape}"
                    )

            # Each column holds the i-th value from every update.
            medians = [
                statistics.median(column)
                for column in zip(*(t.values for t in group))
            ]
            out.append(
                GradientTensor(name=reference.name, shape=shape, values=tuple(medians))
            )
        return out
| 40 | |