src/pqc_federated_learning/aggregators/fedtrimmed.py
1.8 KB · 48 lines · python Raw
1 """FedTrimmedMean: drops top/bottom fraction of values per element before averaging."""
2
from __future__ import annotations

import math

from pqc_federated_learning.aggregators.base import Aggregator
from pqc_federated_learning.errors import InsufficientUpdatesError, ShapeMismatchError
from pqc_federated_learning.update import ClientUpdate, GradientTensor
8
9
class FedTrimmedMeanAggregator(Aggregator):
    """Coordinate-wise trimmed-mean aggregator.

    For every element position, the ``trim`` smallest and ``trim`` largest
    client values are discarded before averaging, where
    ``trim = int(n_clients * trim_ratio)``.  Trimming each tail bounds the
    influence of outlier/Byzantine client updates.
    """

    name = "fedtrimmedmean"

    def __init__(self, trim_ratio: float = 0.1):
        """Initialize the aggregator.

        Args:
            trim_ratio: fraction trimmed from EACH tail, in ``[0, 0.5)``.

        Raises:
            ValueError: if ``trim_ratio`` is outside ``[0, 0.5)``.
        """
        if not 0.0 <= trim_ratio < 0.5:
            raise ValueError("trim_ratio must be in [0, 0.5)")
        self.trim_ratio = trim_ratio

    def aggregate(self, updates: list[ClientUpdate]) -> list[GradientTensor]:
        """Aggregate client updates with an element-wise trimmed mean.

        Args:
            updates: one update per client; all must carry the same tensor
                names (in the same order) with matching shapes.

        Returns:
            One ``GradientTensor`` per tensor name, in the order of the
            first update's tensors.

        Raises:
            InsufficientUpdatesError: if ``updates`` is empty.
            ShapeMismatchError: if updates disagree on tensor names, tensor
                shapes, or flattened value lengths.
        """
        if not updates:
            raise InsufficientUpdatesError("FedTrimmedMean requires at least one update")

        n = len(updates)
        trim = int(n * self.trim_ratio)

        names = [t.name for t in updates[0].tensors]
        shapes = {t.name: t.shape for t in updates[0].tensors}

        # Index each update's tensors by name ONCE.  Previously the tensor
        # was re-found with a linear scan inside the innermost element loop,
        # making aggregation O(elements * clients * tensors).
        indexed: list[dict[str, GradientTensor]] = []
        for u in updates:
            if [t.name for t in u.tensors] != names:
                raise ShapeMismatchError("tensor name mismatch across updates")
            by_name = {t.name: t for t in u.tensors}
            # Also verify shapes agree across updates (previously only the
            # names were compared, so shape-mismatched updates were silently
            # mis-aggregated or failed with an opaque IndexError below).
            for name in names:
                if by_name[name].shape != shapes[name]:
                    raise ShapeMismatchError(
                        f"shape mismatch for tensor {name!r} across updates"
                    )
            indexed.append(by_name)

        out: list[GradientTensor] = []
        for name in names:
            shape = shapes[name]
            length = math.prod(shape)
            columns = [by_name[name].values for by_name in indexed]
            if any(len(col) != length for col in columns):
                raise ShapeMismatchError(
                    f"value length mismatch for tensor {name!r}"
                )
            agg: list[float] = []
            # zip(*columns) yields, per element position, the n client values.
            for vals in zip(*columns):
                ordered = sorted(vals)
                # trim < n/2 by construction, so the slice is never empty;
                # the fallback keeps behavior defined for degenerate input.
                kept = ordered[trim : n - trim] or ordered
                agg.append(sum(kept) / len(kept))
            out.append(GradientTensor(name=name, shape=shape, values=tuple(agg)))
        return out
48