examples/byzantine_client_rejected.py
| 1 | """ |
| 2 | Byzantine Client Rejected Example |
| 3 | |
| 4 | One honest client plus one attacker whose signature is invalid (forged with the wrong key or tampered bytes; this example overwrites it with a string of zeros). |
| 5 | The aggregator detects the invalid signature and excludes the attacker |
| 6 | from the aggregation. The signed AggregationProof lists the exclusion |
| 7 | for auditors. |
| 8 | """ |
| 9 | |
| 10 | from quantumshield.identity.agent import AgentIdentity |
| 11 | |
| 12 | from pqc_federated_learning import ( |
| 13 | AggregationRound, |
| 14 | ClientUpdate, |
| 15 | ClientUpdateMetadata, |
| 16 | FedAvgAggregator, |
| 17 | FederatedAggregator, |
| 18 | GradientTensor, |
| 19 | UpdateSigner, |
| 20 | ) |
| 21 | |
| 22 | |
def build_update(
    identity: AgentIdentity,
    round_id: str,
    model_id: str,
    num_samples: int,
) -> ClientUpdate:
    """Create a one-tensor client update for a round and sign it with *identity*.

    The update carries a single gradient tensor named ``"w"`` with shape
    ``(2,)`` and values ``(0.1, 0.2)``. Its metadata records the signer's DID,
    the round, the model and the reported sample count.

    Args:
        identity: Agent whose DID is written into the metadata and whose
            ``UpdateSigner`` signs the update.
        round_id: Identifier of the aggregation round.
        model_id: Identifier of the model being trained.
        num_samples: Number of local samples the client reports.

    Returns:
        The result of ``UpdateSigner(identity).sign`` applied to the freshly
        created update.
    """
    gradients = [GradientTensor(name="w", shape=(2,), values=(0.1, 0.2))]
    metadata = ClientUpdateMetadata(
        client_did=identity.did,
        round_id=round_id,
        model_id=model_id,
        num_samples=num_samples,
    )
    unsigned_update = ClientUpdate.create(metadata, gradients)
    signer = UpdateSigner(identity)
    return signer.sign(unsigned_update)
| 40 | |
| 41 | |
def _require(condition: bool, message: str) -> None:
    """Raise ``AssertionError(message)`` when *condition* is false.

    Used instead of the ``assert`` statement because ``python -O`` strips
    asserts. The demo's checks must run regardless of optimization level.
    """
    if not condition:
        raise AssertionError(message)


def main() -> None:
    """Run the Byzantine-client demo end to end.

    Builds one honest and one tampered client update for the same round and
    model, and aggregates them with ``FedAvgAggregator``. It then prints a
    summary of included and excluded clients. Finally it checks three things:
    the honest client was included, the attacker was excluded, and the
    aggregation proof verifies.

    Raises:
        AssertionError: if any of the expected outcomes does not hold.
    """
    round_id = "round-42"
    model_id = "fraud-detector-v1"

    honest = AgentIdentity.create("honest-bank")
    attacker = AgentIdentity.create("evil-bank")

    honest_update = build_update(honest, round_id, model_id, 1000)
    # NOTE(review): the larger sample count is presumably chosen so the attacker
    # would dominate a sample-weighted average if it were accepted.
    attacker_update = build_update(attacker, round_id, model_id, 5000)

    # Attacker corrupts their own signature bytes (simulates forgery / tampered transit).
    attacker_update.signature = "00" * 64

    aggregator_id = AgentIdentity.create("regulator-aggregator")
    round_ = AggregationRound(round_id=round_id, model_id=model_id)
    round_.add(honest_update)
    round_.add(attacker_update)

    aggregator = FederatedAggregator(
        identity=aggregator_id,
        strategy=FedAvgAggregator(),
        min_updates=1,
    )
    result = aggregator.aggregate(round_)
    proof = result.proof

    print("--- Round summary ---")
    print(f" total submissions = {len(round_.updates)}")
    print(f" included = {len(proof.included_client_dids)}")
    print(f" excluded = {len(proof.excluded_reasons)}")
    print()
    for did, reason in proof.excluded_reasons.items():
        print(f" [EXCLUDED] {did}")
        print(f" reason: {reason}")

    # The attacker must have been excluded. These are explicit checks rather
    # than `assert` statements, so `python -O` cannot skip them and print [OK]
    # unverified.
    _require(
        honest.did in proof.included_client_dids,
        f"honest client {honest.did} was not included in the aggregation",
    )
    _require(
        attacker.did in proof.excluded_reasons,
        f"attacker {attacker.did} was not excluded from the aggregation",
    )
    _require(
        FederatedAggregator.verify_proof(proof),
        "aggregation proof failed verification",
    )

    print("\n[OK] Attacker excluded. Aggregation proof is valid and PQ-signed.")
| 82 | |
| 83 | |
if __name__ == "__main__":
    # Run the demo only when this file is executed as a script, not on import.
    main()
| 86 | |