examples/simple_fedavg.py
3.4 KB · 109 lines · python Raw
1 """
2 Simple FedAvg Example
3
4 Three hospitals train a local model, sign their gradient updates with ML-DSA,
5 and send them to a central aggregator. The aggregator verifies every signature,
6 computes a weighted mean (FedAvg), and emits a signed aggregation proof.
7 """
8
9 from quantumshield.identity.agent import AgentIdentity
10
11 from pqc_federated_learning import (
12 AggregationRound,
13 ClientUpdate,
14 ClientUpdateMetadata,
15 FedAvgAggregator,
16 FederatedAggregator,
17 GradientTensor,
18 UpdateSigner,
19 )
20
21
def build_signed_update(
    identity: AgentIdentity,
    round_id: str,
    model_id: str,
    num_samples: int,
    scale: float,
    epochs: int = 3,
) -> ClientUpdate:
    """Build a synthetic gradient update for one client and sign it.

    The update holds two small tensors, ``conv1.weights`` (shape ``(2, 2)``)
    and ``conv1.bias`` (shape ``(2,)``). Their values are a fixed base
    pattern multiplied by ``scale``, so each simulated client contributes a
    distinguishable gradient.

    Args:
        identity: The client's identity. Its DID is recorded as
            ``client_did`` in the metadata, and it signs the update.
        round_id: Identifier of the aggregation round this update belongs to.
        model_id: Identifier of the model being trained.
        num_samples: Number of local training samples recorded in the
            metadata. FedAvg presumably weights clients by this value;
            confirm in the aggregator strategy.
        scale: Strictly positive multiplier applied to the base gradient
            values. The reported local loss is ``0.42 / scale``.
        epochs: Number of local epochs recorded in the metadata. Defaults
            to 3, matching the previous hard-coded value.

    Returns:
        The signed ``ClientUpdate`` produced by ``UpdateSigner(identity).sign``.

    Raises:
        ValueError: If ``scale`` is not strictly positive (including NaN).
            Zero would make the synthetic loss divide by zero, and a negative
            value would report a meaningless negative loss.
    """
    # ``not scale > 0`` (rather than ``scale <= 0``) also rejects NaN.
    if not scale > 0:
        raise ValueError(f"scale must be > 0, got {scale!r}")

    tensors = [
        GradientTensor(
            name="conv1.weights",
            shape=(2, 2),
            values=tuple(v * scale for v in (0.1, 0.2, 0.3, 0.4)),
        ),
        GradientTensor(
            name="conv1.bias",
            shape=(2,),
            values=tuple(v * scale for v in (0.01, 0.02)),
        ),
    ]
    meta = ClientUpdateMetadata(
        client_did=identity.did,
        round_id=round_id,
        model_id=model_id,
        num_samples=num_samples,
        epochs=epochs,
        # Synthetic loss: clients with a larger scale report a smaller loss.
        local_loss=0.42 / scale,
    )
    update = ClientUpdate.create(meta, tensors)
    return UpdateSigner(identity).sign(update)
51
52
def main() -> None:
    """Run one signed FedAvg round end to end and print the outcome.

    Three simulated hospitals build and sign updates for the same round and
    model. A central aggregator that trusts exactly those three DIDs combines
    them with ``FedAvgAggregator`` and returns a result carrying a signed
    proof. The proof is printed and then checked with
    ``FederatedAggregator.verify_proof``.

    Raises:
        SystemExit: With status 1 if the aggregation proof does not verify.
            The example then fails loudly (e.g. in CI) instead of printing a
            success banner.
    """
    round_id = "round-1"
    model_id = "pneumonia-detector-v2"

    # Three clients
    hospital_a = AgentIdentity.create("hospital-a")
    hospital_b = AgentIdentity.create("hospital-b")
    hospital_c = AgentIdentity.create("hospital-c")

    print(f"Hospital A DID: {hospital_a.did}")
    print(f"Hospital B DID: {hospital_b.did}")
    print(f"Hospital C DID: {hospital_c.did}")

    # Distinct sample counts and scales give each client a different weight
    # and gradient, so the weighted mean differs visibly from every input.
    updates = [
        build_signed_update(hospital_a, round_id, model_id, num_samples=1024, scale=1.0),
        build_signed_update(hospital_b, round_id, model_id, num_samples=512, scale=1.5),
        build_signed_update(hospital_c, round_id, model_id, num_samples=2048, scale=0.8),
    ]

    # Coordinator / aggregator
    aggregator_id = AgentIdentity.create("central-aggregator")
    print(f"\nAggregator DID: {aggregator_id.did}")

    round_ = AggregationRound(round_id=round_id, model_id=model_id)
    for update in updates:
        round_.add(update)

    aggregator = FederatedAggregator(
        identity=aggregator_id,
        strategy=FedAvgAggregator(),
        trusted_clients={hospital_a.did, hospital_b.did, hospital_c.did},
        # NOTE(review): min_updates=1 presumably lets the round succeed even
        # if some updates are rejected; the "excluded" count printed below
        # is the place to spot that. Confirm the aggregator's semantics.
        min_updates=1,
    )

    result = aggregator.aggregate(round_)

    print("\n--- Aggregated tensors ---")
    preview_len = 4
    for t in result.aggregated:
        preview = ", ".join(f"{v:.5f}" for v in t.values[:preview_len])
        # Mark truncation so a long tensor is not mistaken for a short one.
        if len(t.values) > preview_len:
            preview += ", ..."
        print(f" {t.name} shape={t.shape} values=[{preview}]")

    proof = result.proof
    print("\n--- Aggregation proof ---")
    print(f" round_id = {proof.round_id}")
    print(f" model_id = {proof.model_id}")
    print(f" aggregator_name = {proof.aggregator_name}")
    print(f" num_tensors = {proof.num_tensors}")
    print(f" result_hash = {proof.result_hash}")
    print(f" included clients = {len(proof.included_client_dids)}")
    print(f" excluded = {len(proof.excluded_reasons)}")
    print(f" signature[:32] = {proof.signature[:32]}...")

    ok = FederatedAggregator.verify_proof(proof)
    # The original always printed "[OK]", even when verification failed.
    status = "OK" if ok else "FAIL"
    print(f"\n[{status}] Proof signature verifies: {ok}")
    if not ok:
        raise SystemExit(1)


if __name__ == "__main__":
    main()
109