examples/robust_median.py
2.7 KB · 88 lines · python Raw
1 """
2 Robust Median Aggregation Example
3
4 Four honest clients plus one malicious client who signs their update
5 correctly but ships extreme values designed to bias the global model.
6 FedMedian absorbs the attack: the per-element median ignores the outlier.
7 """
8
9 from quantumshield.identity.agent import AgentIdentity
10
11 from pqc_federated_learning import (
12 AggregationRound,
13 ClientUpdate,
14 ClientUpdateMetadata,
15 FedMedianAggregator,
16 FederatedAggregator,
17 GradientTensor,
18 UpdateSigner,
19 )
20
21
def build_signed_update(
    identity: AgentIdentity,
    round_id: str,
    model_id: str,
    scale: float,
    *,
    base_values: "tuple[float, ...]" = (0.10, 0.20, 0.30, 0.40),
    num_samples: int = 100,
    tensor_name: str = "dense.weights",
) -> ClientUpdate:
    """Build a one-tensor client update and sign it with ``identity``.

    Every element of ``base_values`` is multiplied by ``scale``. The example
    uses ``scale=1.0`` for honest clients and a large ``scale`` to simulate
    a poisoned update.

    Args:
        identity: Client identity. Its ``did`` is recorded in the update
            metadata, and it is the key used by ``UpdateSigner`` to sign.
        round_id: Aggregation round this update belongs to.
        model_id: Model this update targets.
        scale: Multiplier applied to each base value.
        base_values: Unscaled tensor values. The tensor shape is derived
            from this, so shape and values can never disagree.
        num_samples: Sample count reported in the update metadata.
        tensor_name: Name of the single gradient tensor.

    Returns:
        The update produced by ``ClientUpdate.create``, signed via
        ``UpdateSigner(identity).sign``.
    """
    values = tuple(v * scale for v in base_values)
    tensors = [
        GradientTensor(
            name=tensor_name,
            # Derive the shape from the data instead of hard-coding it.
            shape=(len(values),),
            values=values,
        ),
    ]
    meta = ClientUpdateMetadata(
        client_did=identity.did,
        round_id=round_id,
        model_id=model_id,
        num_samples=num_samples,
    )
    update = ClientUpdate.create(meta, tensors)
    return UpdateSigner(identity).sign(update)
43
44
def main() -> None:
    """Run the FedMedian poisoning demo end to end.

    Four honest clients and one malicious client each submit a correctly
    signed update. The malicious client's values are scaled by
    ``malicious_scale``. All five are aggregated with ``FedMedianAggregator``.
    The demo then checks that the aggregation proof verifies and that the
    aggregated ``dense.weights`` tensor matches the honest baseline.

    Raises:
        AssertionError: If the aggregated tensor is missing or has the wrong
            length, if the proof fails verification, or if the median
            drifted from the honest values. These are explicit checks rather
            than ``assert`` statements so they still run under ``python -O``.
    """
    round_id = "round-robust-1"
    model_id = "credit-model-v7"
    # Must match the unscaled values build_signed_update uses by default.
    honest_baseline = (0.10, 0.20, 0.30, 0.40)
    malicious_scale = 1_000.0

    honest = [AgentIdentity.create(f"bank-{i}") for i in range(4)]
    malicious = AgentIdentity.create("malicious-bank")

    updates = [build_signed_update(h, round_id, model_id, scale=1.0) for h in honest]
    # Malicious client: values are 1000x - designed to bias the mean.
    updates.append(
        build_signed_update(malicious, round_id, model_id, scale=malicious_scale)
    )

    aggregator_identity = AgentIdentity.create("regulator-aggregator")
    round_ = AggregationRound(round_id=round_id, model_id=model_id)
    for update in updates:
        round_.add(update)

    aggregator = FederatedAggregator(
        identity=aggregator_identity,
        strategy=FedMedianAggregator(),
        # The malicious client is trusted on purpose: its signature is valid,
        # so the aggregation strategy is the only line of defence here.
        trusted_clients={h.did for h in honest} | {malicious.did},
        min_updates=3,
    )
    result = aggregator.aggregate(round_)

    print("--- Robust FedMedian aggregation ---")
    for t in result.aggregated:
        print(f" {t.name} shape={t.shape}")
        print(f" values = {tuple(round(v, 4) for v in t.values)}")

    agg_values = next(
        (t.values for t in result.aggregated if t.name == "dense.weights"),
        None,
    )
    if agg_values is None:
        raise AssertionError("Aggregated result has no 'dense.weights' tensor")
    if len(agg_values) != len(honest_baseline):
        # zip() would silently truncate and hide a shape mismatch.
        raise AssertionError(
            f"Aggregated tensor has {len(agg_values)} elements, "
            f"expected {len(honest_baseline)}"
        )
    max_drift = max(abs(a - b) for a, b in zip(agg_values, honest_baseline))
    print(f"\n Max element drift from honest baseline: {max_drift:.6f}")

    # With one attacker among len(updates) clients, the per-element mean is
    # b * (n_honest + scale) / n, so its drift from b is b * (scale - 1) / n.
    # Report the worst element.
    mean_drift = (
        (malicious_scale - 1) * max(abs(b) for b in honest_baseline) / len(updates)
    )
    print(f" (Mean aggregation would have drifted by ~{mean_drift:.2f})")

    if not FederatedAggregator.verify_proof(result.proof):
        raise AssertionError("Aggregation proof failed verification")
    # Written as `not (x < tol)` so a NaN drift fails, as the original assert did.
    if not max_drift < 1e-9:
        raise AssertionError("Median should exactly match honest values")

    print("\n[OK] Median survived a 1000x malicious client. Proof PQ-signed.")
84
85
# Script entry point: importing this module only defines the helpers above;
# the demo runs only when the file is executed directly.
if __name__ == "__main__":
    main()
88