tests/test_integration.py
| 1 | """End-to-end integration tests.""" |
| 2 | |
| 3 | from __future__ import annotations |
| 4 | |
| 5 | from quantumshield.identity.agent import AgentIdentity |
| 6 | |
| 7 | from pqc_federated_learning import ( |
| 8 | AggregationRound, |
| 9 | FedAvgAggregator, |
| 10 | FedMedianAggregator, |
| 11 | FederatedAggregator, |
| 12 | ) |
| 13 | from tests.conftest import make_signed_update |
| 14 | |
| 15 | |
def test_end_to_end_fedavg_five_clients() -> None:
    """Run a full FedAvg round with five trusted clients and check its proof.

    Each client submits one signed update for the same round and model.
    All five client DIDs are on the aggregator's trusted list, so the test
    expects every update to be included and no exclusion reasons recorded.
    """
    round_name = "r42"
    model_name = "clf-v3"

    server_identity = AgentIdentity.create("aggregator")
    clients = [AgentIdentity.create(f"client-{i}") for i in range(5)]

    fl_round = AggregationRound(round_id=round_name, model_id=model_name)
    for client in clients:
        signed = make_signed_update(
            client,
            round_id=round_name,
            model_id=model_name,
            num_samples=100,
        )
        fl_round.add(signed)

    strategy = FedAvgAggregator()
    trusted_dids = {client.did for client in clients}
    server = FederatedAggregator(
        identity=server_identity,
        strategy=strategy,
        trusted_clients=trusted_dids,
    )
    outcome = server.aggregate(fl_round)

    # The fixture updates are expected to aggregate into two entries.
    assert len(outcome.aggregated) == 2

    proof = outcome.proof
    assert len(proof.included_client_dids) == 5
    assert not proof.excluded_reasons
    assert FederatedAggregator.verify_proof(proof) is True
| 38 | |
| 39 | |
def test_byzantine_bad_signature_excluded() -> None:
    """Check that an update with an overwritten signature is excluded.

    Two clients submit updates to the same round. One update has its
    signature replaced after signing. The test expects the honest client
    to be included, the tampered client to appear in the exclusion
    reasons, and the resulting proof to verify.
    """
    server_identity = AgentIdentity.create("aggregator")
    honest_client = AgentIdentity.create("honest")
    byzantine_client = AgentIdentity.create("byzantine")

    valid_update = make_signed_update(honest_client)
    tampered_update = make_signed_update(byzantine_client)
    # Overwrite the signature with a constant string of zeros so that
    # verification of this update is expected to fail.
    tampered_update.signature = "00" * 64

    fl_round = AggregationRound(round_id="round-1", model_id="model-x")
    for update in (valid_update, tampered_update):
        fl_round.add(update)

    # No trusted_clients argument is passed here; only identity and strategy.
    strategy = FedAvgAggregator()
    server = FederatedAggregator(identity=server_identity, strategy=strategy)
    proof = server.aggregate(fl_round).proof

    assert honest_client.did in proof.included_client_dids
    assert byzantine_client.did in proof.excluded_reasons
    assert FederatedAggregator.verify_proof(proof) is True
| 61 | |
| 62 | |
def test_fedmedian_survives_malicious_client() -> None:
    """Check that FedMedian ignores one correctly signed but extreme update.

    Four honest clients submit updates at ``values_scale=1.0`` and one
    malicious client submits at ``values_scale=1_000.0``. With four of five
    updates agreeing, the aggregated ``dense_1.weights`` tensor is expected
    to equal the honest values, and the proof is expected to verify.
    """
    round_id = "r1"
    model_id = "m1"

    aggregator_id = AgentIdentity.create("aggregator")
    honest_ids = [AgentIdentity.create(f"honest-{i}") for i in range(4)]
    malicious = AgentIdentity.create("malicious")

    round_ = AggregationRound(round_id=round_id, model_id=model_id)
    for cid in honest_ids:
        round_.add(
            make_signed_update(
                cid, round_id=round_id, model_id=model_id, values_scale=1.0
            )
        )
    # Malicious client sends extreme values but signs correctly.
    round_.add(
        make_signed_update(
            malicious, round_id=round_id, model_id=model_id, values_scale=1_000.0
        )
    )

    aggregator = FederatedAggregator(
        identity=aggregator_id, strategy=FedMedianAggregator()
    )
    result = aggregator.aggregate(round_)

    # Median should still reflect the honest 1.0-scale values.
    w = next(t for t in result.aggregated if t.name == "dense_1.weights")
    expected = (0.1, 0.2, 0.3, 0.4)

    # zip() stops at the shorter input, so without this length check an empty
    # or truncated tensor would let the loop below pass without comparing
    # every expected value.
    got_values = list(w.values)
    assert len(got_values) == len(expected)

    for got, want in zip(got_values, expected):
        assert abs(got - want) < 1e-9
    assert FederatedAggregator.verify_proof(result.proof) is True
| 84 | |