tests/test_integration.py
2.9 KB · 84 lines · python Raw
1 """End-to-end integration tests."""
2
3 from __future__ import annotations
4
5 from quantumshield.identity.agent import AgentIdentity
6
7 from pqc_federated_learning import (
8 AggregationRound,
9 FedAvgAggregator,
10 FedMedianAggregator,
11 FederatedAggregator,
12 )
13 from tests.conftest import make_signed_update
14
15
def test_end_to_end_fedavg_five_clients() -> None:
    """Five trusted clients submit signed updates and FedAvg aggregates them all.

    Asserts that all five clients are included, nothing is excluded, the
    aggregate has the expected number of entries (2, which depends on what the
    shared ``make_signed_update`` helper produces), and the proof verifies.
    """
    round_id, model_id = "r42", "clf-v3"
    server = AgentIdentity.create("aggregator")
    clients = [AgentIdentity.create(f"client-{idx}") for idx in range(5)]

    round_ = AggregationRound(round_id=round_id, model_id=model_id)
    for client in clients:
        round_.add(
            make_signed_update(
                client, round_id=round_id, model_id=model_id, num_samples=100
            )
        )

    # Every participating client is explicitly trusted.
    trusted = {client.did for client in clients}
    result = FederatedAggregator(
        identity=server,
        strategy=FedAvgAggregator(),
        trusted_clients=trusted,
    ).aggregate(round_)

    proof = result.proof
    assert len(result.aggregated) == 2
    assert len(proof.included_client_dids) == 5
    assert not proof.excluded_reasons
    assert FederatedAggregator.verify_proof(proof) is True
38
39
def test_byzantine_bad_signature_excluded() -> None:
    """An update whose signature has been overwritten must be excluded.

    One honest and one Byzantine client each produce a signed update. The
    Byzantine update's signature is then replaced with zeros. The honest client
    must be included. The Byzantine client must have an exclusion reason AND
    must be absent from the included set, and the proof must still verify.
    """
    aggregator_id = AgentIdentity.create("aggregator")
    honest = AgentIdentity.create("honest")
    byz = AgentIdentity.create("byzantine")

    # NOTE(review): relies on make_signed_update's default round/model ids
    # matching the round built below ("round-1" / "model-x") — confirm in conftest.
    good = make_signed_update(honest)
    bad = make_signed_update(byz)
    # Overwrite the real signature with an all-zero placeholder (presumably a
    # hex string) so signature verification must fail for this update.
    bad.signature = "00" * 64

    round_ = AggregationRound(round_id="round-1", model_id="model-x")
    round_.add(good)
    round_.add(bad)

    aggregator = FederatedAggregator(
        identity=aggregator_id, strategy=FedAvgAggregator()
    )
    result = aggregator.aggregate(round_)
    assert honest.did in result.proof.included_client_dids
    assert byz.did in result.proof.excluded_reasons
    # Recording a reason is not enough: the bad update must not be aggregated.
    assert byz.did not in result.proof.included_client_dids
    assert FederatedAggregator.verify_proof(result.proof) is True
61
62
def test_fedmedian_survives_malicious_client() -> None:
    """FedMedian output stays at honest values despite one extreme, validly signed update.

    Four honest clients submit updates at scale 1.0 and one malicious client
    submits updates scaled by 1000 with a correct signature. Assuming
    FedMedianAggregator takes a per-coordinate median, the 4-vs-1 split keeps
    "dense_1.weights" at the honest values (0.1, 0.2, 0.3, 0.4).
    """
    aggregator_id = AgentIdentity.create("aggregator")
    honest_ids = [AgentIdentity.create(f"honest-{i}") for i in range(4)]
    malicious = AgentIdentity.create("malicious")

    round_ = AggregationRound(round_id="r1", model_id="m1")
    for cid in honest_ids:
        round_.add(make_signed_update(cid, round_id="r1", model_id="m1", values_scale=1.0))
    # Malicious client sends extreme values but signs correctly, so the test
    # relies on the median strategy, not signature checks, to neutralize it.
    round_.add(make_signed_update(malicious, round_id="r1", model_id="m1", values_scale=1_000.0))

    aggregator = FederatedAggregator(
        identity=aggregator_id, strategy=FedMedianAggregator()
    )
    result = aggregator.aggregate(round_)
    # Median should still reflect the honest 1.0-scale values.
    w = next((t for t in result.aggregated if t.name == "dense_1.weights"), None)
    assert w is not None, "dense_1.weights missing from aggregated output"
    expected = (0.1, 0.2, 0.3, 0.4)
    # zip() stops at the shorter input, so without this check an empty or
    # truncated tensor would let the value loop below pass vacuously.
    assert len(w.values) == len(expected)
    for got, want in zip(w.values, expected):
        assert abs(got - want) < 1e-9
    assert FederatedAggregator.verify_proof(result.proof) is True
84