examples/byzantine_client_rejected.py
2.5 KB · 86 lines · python Raw
1 """
2 Byzantine Client Rejected Example
3
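One honest client and one attacker whose update carries an invalid signature
(simulated here by overwriting the signature with zero bytes, as a forged or
tampered-in-transit signature would appear). The aggregator detects the
invalid signature and excludes the attacker from the aggregation. The signed
AggregationProof records the exclusion, with its reason, for auditors.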
8 """
9
10 from quantumshield.identity.agent import AgentIdentity
11
12 from pqc_federated_learning import (
13 AggregationRound,
14 ClientUpdate,
15 ClientUpdateMetadata,
16 FedAvgAggregator,
17 FederatedAggregator,
18 GradientTensor,
19 UpdateSigner,
20 )
21
22
def build_update(
    identity: AgentIdentity,
    round_id: str,
    model_id: str,
    num_samples: int,
) -> ClientUpdate:
    """Create a one-tensor client update for *identity* and sign it.

    The update contains a single gradient tensor named ``"w"`` of shape
    ``(2,)`` with the fixed values ``(0.1, 0.2)``. Its metadata ties it to
    the client's DID, the given round and model, and the reported sample
    count.

    Args:
        identity: The client's agent identity; its DID goes into the
            metadata and it is the key holder passed to ``UpdateSigner``.
        round_id: Identifier of the aggregation round.
        model_id: Identifier of the model being trained.
        num_samples: Number of local samples the client reports.

    Returns:
        The update as returned by ``UpdateSigner(identity).sign(...)``.
    """
    gradient_list = [GradientTensor(name="w", shape=(2,), values=(0.1, 0.2))]
    metadata = ClientUpdateMetadata(
        client_did=identity.did,
        round_id=round_id,
        model_id=model_id,
        num_samples=num_samples,
    )
    unsigned_update = ClientUpdate.create(metadata, gradient_list)
    signer = UpdateSigner(identity)
    signed_update = signer.sign(unsigned_update)
    return signed_update
40
41
def _require(condition: bool, message: str) -> None:
    """Raise ``AssertionError(message)`` unless *condition* is truthy.

    Unlike a bare ``assert`` statement, this check is not stripped when
    Python runs with ``-O``.
    """
    if not condition:
        raise AssertionError(message)


def main() -> None:
    """Run one aggregation round with an honest client and a forging attacker.

    Builds and signs one update per client for the same round and model,
    corrupts the attacker's signature, aggregates with FedAvg, prints a round
    summary, and then verifies the outcome:

    * the honest client's DID is in the proof's included set,
    * the attacker's DID is in the proof's exclusion reasons and not in the
      included set,
    * ``FederatedAggregator.verify_proof`` accepts the proof.

    Raises:
        AssertionError: if any of the expected outcomes above does not hold.
    """
    round_id = "round-42"
    model_id = "fraud-detector-v1"

    honest = AgentIdentity.create("honest-bank")
    attacker = AgentIdentity.create("evil-bank")

    honest_update = build_update(honest, round_id, model_id, 1000)
    attacker_update = build_update(attacker, round_id, model_id, 5000)

    # Attacker corrupts their own signature bytes (simulates forgery / tampered transit).
    attacker_update.signature = "00" * 64

    aggregator_id = AgentIdentity.create("regulator-aggregator")
    round_ = AggregationRound(round_id=round_id, model_id=model_id)
    round_.add(honest_update)
    round_.add(attacker_update)

    aggregator = FederatedAggregator(
        identity=aggregator_id,
        strategy=FedAvgAggregator(),
        min_updates=1,
    )
    result = aggregator.aggregate(round_)
    proof = result.proof

    print("--- Round summary ---")
    print(f" total submissions = {len(round_.updates)}")
    print(f" included = {len(proof.included_client_dids)}")
    print(f" excluded = {len(proof.excluded_reasons)}")
    print()
    for did, reason in proof.excluded_reasons.items():
        print(f" [EXCLUDED] {did}")
        print(f" reason: {reason}")

    # Explicit checks instead of bare `assert` statements: under `python -O`
    # asserts are removed, and the script would report success unchecked.
    _require(
        honest.did in proof.included_client_dids,
        "honest client was not included in the aggregation",
    )
    _require(
        attacker.did in proof.excluded_reasons,
        "attacker with an invalid signature was not excluded",
    )
    _require(
        attacker.did not in proof.included_client_dids,
        "attacker with an invalid signature was included in the aggregation",
    )
    _require(
        FederatedAggregator.verify_proof(proof),
        "aggregation proof failed verification",
    )

    print("\n[OK] Attacker excluded. Aggregation proof is valid and PQ-signed.")
82
83
if __name__ == "__main__":
    # Run the demo only when this file is executed directly, not on import.
    main()
86