examples/simple_fedavg.py
| 1 | """ |
| 2 | Simple FedAvg Example |
| 3 | |
| 4 | Three hospitals train a local model, sign their gradient updates with ML-DSA, |
| 5 | and send them to a central aggregator. The aggregator verifies every signature, |
| 6 | computes a weighted mean (FedAvg), and emits a signed aggregation proof. |
| 7 | """ |
| 8 | |
| 9 | from quantumshield.identity.agent import AgentIdentity |
| 10 | |
| 11 | from pqc_federated_learning import ( |
| 12 | AggregationRound, |
| 13 | ClientUpdate, |
| 14 | ClientUpdateMetadata, |
| 15 | FedAvgAggregator, |
| 16 | FederatedAggregator, |
| 17 | GradientTensor, |
| 18 | UpdateSigner, |
| 19 | ) |
| 20 | |
| 21 | |
def build_signed_update(
    identity: AgentIdentity,
    round_id: str,
    model_id: str,
    num_samples: int,
    scale: float,
) -> ClientUpdate:
    """Create a toy client update for one round and sign it.

    Two fixed gradient tensors ("conv1.weights" and "conv1.bias") are built
    with every base value multiplied by ``scale``. They are bundled with
    metadata describing the client and the round, then signed with
    ``identity``.

    Args:
        identity: Identity whose ``did`` is recorded in the metadata and
            which is handed to ``UpdateSigner`` for signing.
        round_id: Identifier of the aggregation round.
        model_id: Identifier of the model being trained.
        num_samples: Sample count recorded in the update metadata.
        scale: Multiplier applied to the base gradient values. It also
            divides the reported loss, so it must be non-zero.

    Returns:
        Whatever ``UpdateSigner(identity).sign`` returns for the new update.

    Raises:
        ZeroDivisionError: If ``scale`` is zero.
    """

    def scaled(base_values):
        # Same element-wise product as the inline generator it replaces.
        return tuple(base * scale for base in base_values)

    weights = GradientTensor(
        name="conv1.weights",
        shape=(2, 2),
        values=scaled((0.1, 0.2, 0.3, 0.4)),
    )
    bias = GradientTensor(
        name="conv1.bias",
        shape=(2,),
        values=scaled((0.01, 0.02)),
    )

    # Metadata is assembled only after both tensors exist, as in the
    # original statement order.
    metadata = ClientUpdateMetadata(
        client_did=identity.did,
        round_id=round_id,
        model_id=model_id,
        num_samples=num_samples,
        epochs=3,
        local_loss=0.42 / scale,
    )

    unsigned = ClientUpdate.create(metadata, [weights, bias])
    signer = UpdateSigner(identity)
    return signer.sign(unsigned)
| 51 | |
| 52 | |
def main() -> None:
    """Run one aggregation round with three clients and print the outcome.

    Creates three client identities and one aggregator identity, builds a
    signed update per client, aggregates them with ``FedAvgAggregator`` as
    the strategy, prints the aggregated tensors and the aggregation proof,
    and finally reports whether the proof's signature verifies.
    """
    round_id = "round-1"
    model_id = "pneumonia-detector-v2"

    # Three clients
    hospital_a = AgentIdentity.create("hospital-a")
    hospital_b = AgentIdentity.create("hospital-b")
    hospital_c = AgentIdentity.create("hospital-c")

    print(f"Hospital A DID: {hospital_a.did}")
    print(f"Hospital B DID: {hospital_b.did}")
    print(f"Hospital C DID: {hospital_c.did}")

    # Each client reports a different sample count and gradient scale so the
    # three updates are not identical.
    u_a = build_signed_update(hospital_a, round_id, model_id, num_samples=1024, scale=1.0)
    u_b = build_signed_update(hospital_b, round_id, model_id, num_samples=512, scale=1.5)
    u_c = build_signed_update(hospital_c, round_id, model_id, num_samples=2048, scale=0.8)

    # Coordinator / aggregator
    aggregator_id = AgentIdentity.create("central-aggregator")
    print(f"\nAggregator DID: {aggregator_id.did}")

    round_ = AggregationRound(round_id=round_id, model_id=model_id)
    round_.add(u_a)
    round_.add(u_b)
    round_.add(u_c)

    aggregator = FederatedAggregator(
        identity=aggregator_id,
        strategy=FedAvgAggregator(),
        # Only the three hospitals created above are listed as trusted.
        trusted_clients={hospital_a.did, hospital_b.did, hospital_c.did},
        # NOTE(review): presumably the minimum number of accepted updates
        # needed to aggregate -- confirm against FederatedAggregator.
        min_updates=1,
    )

    result = aggregator.aggregate(round_)

    print("\n--- Aggregated tensors ---")
    for t in result.aggregated:
        # Show at most the first four values of each tensor.
        preview = ", ".join(f"{v:.5f}" for v in t.values[:4])
        print(f" {t.name} shape={t.shape} values=[{preview}]")

    proof = result.proof
    print("\n--- Aggregation proof ---")
    print(f" round_id = {proof.round_id}")
    print(f" model_id = {proof.model_id}")
    print(f" aggregator_name = {proof.aggregator_name}")
    print(f" num_tensors = {proof.num_tensors}")
    print(f" result_hash = {proof.result_hash}")
    print(f" included clients = {len(proof.included_client_dids)}")
    print(f" excluded = {len(proof.excluded_reasons)}")
    print(f" signature[:32] = {proof.signature[:32]}...")

    ok = FederatedAggregator.verify_proof(proof)
    # The tag must follow the verification result: an unconditional "[OK]"
    # would label a failed verification as a success.
    status = "[OK]" if ok else "[FAIL]"
    print(f"\n{status} Proof signature verifies: {ok}")
| 105 | |
| 106 | |
# Run the example only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
| 109 | |