src/pqc_federated_learning/aggregators/fedavg.py
| 1 | """FedAvg: num_samples-weighted mean of client gradient tensors.""" |
| 2 | |
| 3 | from __future__ import annotations |
| 4 | |
| 5 | from pqc_federated_learning.aggregators.base import Aggregator |
| 6 | from pqc_federated_learning.errors import InsufficientUpdatesError, ShapeMismatchError |
| 7 | from pqc_federated_learning.update import ClientUpdate, GradientTensor |
| 8 | |
| 9 | |
class FedAvgAggregator(Aggregator):
    """Federated Averaging: a ``num_samples``-weighted element-wise mean of client tensors.

    Each client's contribution is weighted by ``max(1, num_samples)``. A client
    reporting zero (or a negative) sample count therefore still contributes with
    weight 1, and the total weight of a non-empty update list is never zero.
    """

    name = "fedavg"

    @staticmethod
    def _num_elements(shape) -> int:
        """Return the number of scalar elements implied by ``shape`` (1 for ``()``)."""
        count = 1
        for dim in shape:
            count *= dim
        return count

    def aggregate(self, updates: list[ClientUpdate]) -> list[GradientTensor]:
        """Combine client updates into one set of weighted-mean tensors.

        Args:
            updates: Client updates to average. Every update must carry the same
                tensor names, in the same order, with the same shapes as the
                first update.

        Returns:
            One ``GradientTensor`` per tensor of the first update, in the same
            order. Each has the same name and shape as the input tensor, and its
            ``values`` are the weighted mean across all updates.

        Raises:
            InsufficientUpdatesError: If ``updates`` is empty.
            ShapeMismatchError: If any of the following holds:
                - an update's tensor names/order differ from the first update's;
                - a tensor's shape differs from the first update's;
                - a tensor's number of values does not equal the element
                  count implied by its shape.
        """
        if not updates:
            raise InsufficientUpdatesError("FedAvg requires at least one update")

        # The first update defines the expected layout. Tensors are matched by
        # position rather than looked up by name, because the name-list check
        # below already enforces identical ordering. Positional matching is
        # linear and stays correct even if a tensor name is repeated.
        tensor_names = [t.name for t in updates[0].tensors]
        shapes = [t.shape for t in updates[0].tensors]
        lengths = [self._num_elements(s) for s in shapes]

        for pos, u in enumerate(updates):
            if pos > 0:
                if [t.name for t in u.tensors] != tensor_names:
                    raise ShapeMismatchError(
                        f"client {u.metadata.client_did} has different tensor names"
                    )
                for t, expected_shape in zip(u.tensors, shapes):
                    if t.shape != expected_shape:
                        raise ShapeMismatchError(
                            f"tensor {t.name} shape mismatch: {t.shape} vs {expected_shape}"
                        )
            # Without this check, too many values would raise a bare IndexError
            # during accumulation, and too few would be silently zero-padded.
            for t, expected_len in zip(u.tensors, lengths):
                if len(t.values) != expected_len:
                    raise ShapeMismatchError(
                        f"client {u.metadata.client_did} tensor {t.name} has "
                        f"{len(t.values)} values, expected {expected_len} "
                        f"for shape {t.shape}"
                    )

        # max(1, ...) keeps every client's weight positive. Since updates is
        # non-empty, total_weight is therefore >= 1 and the division is safe.
        total_weight = sum(max(1, u.metadata.num_samples) for u in updates)
        weights = [max(1, u.metadata.num_samples) / total_weight for u in updates]

        # Transpose: one tuple per tensor position holding that tensor from every client.
        per_position = zip(*(u.tensors for u in updates))

        out: list[GradientTensor] = []
        for tname, shape, length, client_tensors in zip(
            tensor_names, shapes, lengths, per_position
        ):
            agg = [0.0] * length
            for weight, tensor in zip(weights, client_tensors):
                for i, v in enumerate(tensor.values):
                    agg[i] += weight * v
            out.append(GradientTensor(name=tname, shape=shape, values=tuple(agg)))
        return out
| 46 | |