src/pqc_federated_learning/aggregators/fedtrimmed.py
| 1 | """FedTrimmedMean: drops top/bottom fraction of values per element before averaging.""" |
| 2 | |
| 3 | from __future__ import annotations |
| 4 | |
| 5 | from pqc_federated_learning.aggregators.base import Aggregator |
| 6 | from pqc_federated_learning.errors import InsufficientUpdatesError, ShapeMismatchError |
| 7 | from pqc_federated_learning.update import ClientUpdate, GradientTensor |
| 8 | |
| 9 | |
class FedTrimmedMeanAggregator(Aggregator):
    """Coordinate-wise trimmed-mean aggregator.

    For every element of every tensor, the ``n`` client values are sorted,
    the ``int(n * trim_ratio)`` smallest and largest are discarded, and the
    remaining values are averaged.
    """

    name = "fedtrimmedmean"

    def __init__(self, trim_ratio: float = 0.1):
        """Create the aggregator.

        Args:
            trim_ratio: Fraction of updates dropped from *each* end of the
                sorted per-element values. Must lie in ``[0, 0.5)`` so that at
                least one value survives trimming.

        Raises:
            ValueError: If ``trim_ratio`` is outside ``[0, 0.5)``.
        """
        if not 0.0 <= trim_ratio < 0.5:
            raise ValueError("trim_ratio must be in [0, 0.5)")
        self.trim_ratio = trim_ratio

    def aggregate(self, updates: list[ClientUpdate]) -> list[GradientTensor]:
        """Return the element-wise trimmed mean of the clients' tensors.

        Tensors are matched positionally across updates; every update must
        list the same tensor names in the same order, and matching tensors
        must have identical shapes and ``len(values) == prod(shape)``.

        Args:
            updates: Client updates to aggregate.

        Returns:
            One ``GradientTensor`` per tensor name, in the order of
            ``updates[0].tensors``, carrying that tensor's name and shape.

        Raises:
            InsufficientUpdatesError: If ``updates`` is empty.
            ShapeMismatchError: If tensor names, shapes or value lengths
                differ across updates.
        """
        import math  # function-scope so the module import block is untouched

        if not updates:
            raise InsufficientUpdatesError("FedTrimmedMean requires at least one update")

        n = len(updates)
        trim = int(n * self.trim_ratio)

        names = [t.name for t in updates[0].tensors]
        for u in updates[1:]:
            if [t.name for t in u.tensors] != names:
                raise ShapeMismatchError("tensor name mismatch across updates")

        out: list[GradientTensor] = []
        # Names (and their order) are verified equal, so the k-th tensor of
        # every update refers to the same parameter; align them positionally
        # instead of re-searching each update's tensor list per element.
        for column in zip(*(u.tensors for u in updates)):
            reference = column[0]
            shape = reference.shape
            length = math.prod(shape)
            for tensor in column:
                if tuple(tensor.shape) != tuple(shape):
                    raise ShapeMismatchError(
                        f"shape mismatch for tensor {reference.name!r} across updates"
                    )
                if len(tensor.values) != length:
                    raise ShapeMismatchError(
                        f"value count for tensor {reference.name!r} does not match its shape"
                    )

            agg: list[float] = []
            for element_values in zip(*(t.values for t in column)):
                vals = sorted(element_values)
                # With trim_ratio < 0.5, 2 * trim < n always holds; the
                # fallback only matters if trim_ratio was changed after init.
                kept = vals[trim : n - trim] if (n - 2 * trim) > 0 else vals
                agg.append(sum(kept) / len(kept))
            out.append(GradientTensor(name=reference.name, shape=shape, values=tuple(agg)))
        return out
| 48 | |