examples/robust_median.py
| 1 | """ |
| 2 | Robust Median Aggregation Example |
| 3 | |
| 4 | Four honest clients plus one malicious client who signs their update |
| 5 | correctly but ships extreme values designed to bias the global model. |
| 6 | FedMedian absorbs the attack: the per-element median ignores the outlier. |
| 7 | """ |
| 8 | |
| 9 | from quantumshield.identity.agent import AgentIdentity |
| 10 | |
| 11 | from pqc_federated_learning import ( |
| 12 | AggregationRound, |
| 13 | ClientUpdate, |
| 14 | ClientUpdateMetadata, |
| 15 | FedMedianAggregator, |
| 16 | FederatedAggregator, |
| 17 | GradientTensor, |
| 18 | UpdateSigner, |
| 19 | ) |
| 20 | |
| 21 | |
def build_signed_update(
    identity: AgentIdentity,
    round_id: str,
    model_id: str,
    scale: float,
) -> ClientUpdate:
    """Create a single-tensor client update and sign it with *identity*.

    The update holds one ``dense.weights`` tensor of shape ``(4,)`` whose
    values are the reference vector ``(0.10, 0.20, 0.30, 0.40)`` multiplied
    element-wise by *scale*. ``scale=1.0`` reproduces the honest baseline; a
    large scale simulates a poisoned update.

    Args:
        identity: Client identity. Its DID is recorded in the metadata and it
            is the key used to sign the update.
        round_id: Identifier of the aggregation round the update belongs to.
        model_id: Identifier of the model the update targets.
        scale: Multiplier applied to every reference value.

    Returns:
        Whatever ``UpdateSigner(identity).sign`` returns for the freshly
        created update.
    """
    reference = (0.10, 0.20, 0.30, 0.40)
    scaled_values = tuple(x * scale for x in reference)
    weights = GradientTensor(name="dense.weights", shape=(4,), values=scaled_values)

    metadata = ClientUpdateMetadata(
        client_did=identity.did,
        round_id=round_id,
        model_id=model_id,
        num_samples=100,
    )
    unsigned = ClientUpdate.create(metadata, [weights])
    signer = UpdateSigner(identity)
    return signer.sign(unsigned)
| 43 | |
| 44 | |
def main() -> None:
    """Run the robust-median demo and verify its outcome.

    Builds four honest signed updates plus one signed update whose values are
    scaled by ``attack_scale``, aggregates them with ``FedMedianAggregator``,
    prints the aggregate and its drift from the honest baseline, then checks
    the aggregation proof and the drift.

    Raises:
        AssertionError: if the aggregate lacks a correctly sized
            ``dense.weights`` tensor, the aggregation proof does not verify,
            or the median drifts from the honest values. These are explicit
            raises, not ``assert`` statements, so the checks still run under
            ``python -O``.
    """
    round_id = "round-robust-1"
    model_id = "credit-model-v7"
    # Must match the reference values used by build_signed_update (scale=1.0).
    honest_baseline = (0.10, 0.20, 0.30, 0.40)
    # Multiplier for the malicious client's values. It is also used below to
    # estimate the drift a plain mean would suffer, so the two cannot diverge.
    attack_scale = 1_000.0

    honest = [AgentIdentity.create(f"bank-{i}") for i in range(4)]
    malicious = AgentIdentity.create("malicious-bank")

    updates = [build_signed_update(h, round_id, model_id, scale=1.0) for h in honest]
    # Malicious client: values are 1000x - designed to bias the mean.
    updates.append(build_signed_update(malicious, round_id, model_id, scale=attack_scale))

    aggregator_id = AgentIdentity.create("regulator-aggregator")
    round_ = AggregationRound(round_id=round_id, model_id=model_id)
    for u in updates:
        round_.add(u)

    aggregator = FederatedAggregator(
        identity=aggregator_id,
        strategy=FedMedianAggregator(),
        # The malicious client is deliberately trusted: its signature is valid,
        # so only the aggregation strategy limits its influence.
        trusted_clients={h.did for h in honest} | {malicious.did},
        min_updates=3,
    )
    result = aggregator.aggregate(round_)

    print("--- Robust FedMedian aggregation ---")
    for t in result.aggregated:
        print(f" {t.name} shape={t.shape}")
        print(f" values = {tuple(round(v, 4) for v in t.values)}")

    agg_values = next(
        (t.values for t in result.aggregated if t.name == "dense.weights"),
        None,
    )
    if agg_values is None:
        raise AssertionError("Aggregate is missing the 'dense.weights' tensor")
    # zip() would silently truncate a wrongly sized tensor and hide the problem.
    if len(agg_values) != len(honest_baseline):
        raise AssertionError(
            f"'dense.weights' has {len(agg_values)} values, "
            f"expected {len(honest_baseline)}"
        )
    max_drift = max(abs(a - b) for a, b in zip(agg_values, honest_baseline))
    print(f"\n Max element drift from honest baseline: {max_drift:.6f}")

    # With one of n updates scaled by s, a plain mean moves element b by
    # b * (s - 1) / n; the largest baseline element gives the worst case.
    mean_drift = (attack_scale - 1) * max(honest_baseline) / len(updates)
    print(f" (Mean aggregation would have drifted by ~{mean_drift:.2f})")

    # Explicit checks rather than `assert`: with -O, an assert around
    # verify_proof() would skip verification entirely yet still report [OK].
    if not FederatedAggregator.verify_proof(result.proof):
        raise AssertionError("Aggregation proof failed verification")
    # Written as `not (<)` so that a NaN drift fails, as the original did.
    if not max_drift < 1e-9:
        raise AssertionError("Median should exactly match honest values")

    print("\n[OK] Median survived a 1000x malicious client. Proof PQ-signed.")
| 84 | |
| 85 | |
if __name__ == "__main__":
    # Entry point: run the demo only when executed as a script, not on import.
    main()
| 88 | |