tests/test_aggregator.py
5.5 KB · 180 lines · python Raw
1 """Tests for FederatedAggregator, AggregationRound, and AggregationProof."""
2
3 from __future__ import annotations
4
5 import pytest
6 from quantumshield.identity.agent import AgentIdentity
7
8 from pqc_federated_learning import (
9 AggregationRound,
10 FedAvgAggregator,
11 FederatedAggregator,
12 )
13 from pqc_federated_learning.errors import (
14 AggregationError,
15 InsufficientUpdatesError,
16 )
17 from tests.conftest import make_signed_update
18
19
def test_verified_updates_aggregate(
    aggregator_identity: AgentIdentity,
    client_a_identity: AgentIdentity,
    client_b_identity: AgentIdentity,
) -> None:
    """Aggregate two signed client updates and inspect the resulting proof.

    Both clients must appear as included, nothing may be excluded, and the
    proof must carry a non-empty signature.
    """
    clients = (client_a_identity, client_b_identity)
    updates = [make_signed_update(client) for client in clients]

    fl_round = AggregationRound(round_id="round-1", model_id="model-x")
    for update in updates:
        fl_round.add(update)

    aggregator = FederatedAggregator(
        identity=aggregator_identity,
        strategy=FedAvgAggregator(),
    )
    outcome = aggregator.aggregate(fl_round)

    assert len(outcome.aggregated) == 2
    proof = outcome.proof
    assert proof.num_tensors == 2
    assert proof.signature != ""
    assert set(proof.included_client_dids) == {client.did for client in clients}
    assert not proof.excluded_reasons
43
44
def test_unsigned_update_excluded(
    aggregator_identity: AgentIdentity,
    client_a_identity: AgentIdentity,
    client_b_identity: AgentIdentity,
) -> None:
    """Check that an update with a blanked signature is excluded.

    Client A's intact update must be listed as included, while client B's
    update (signature replaced by an empty string) must show up in the
    proof's exclusion reasons.
    """
    signed_update = make_signed_update(client_a_identity)
    unsigned_update = make_signed_update(client_b_identity)
    # Blank out the signature so this update no longer carries one.
    unsigned_update.signature = ""

    fl_round = AggregationRound(round_id="round-1", model_id="model-x")
    for update in (signed_update, unsigned_update):
        fl_round.add(update)

    aggregator = FederatedAggregator(
        identity=aggregator_identity,
        strategy=FedAvgAggregator(),
    )
    outcome = aggregator.aggregate(fl_round)

    proof = outcome.proof
    assert client_a_identity.did in proof.included_client_dids
    assert client_b_identity.did in proof.excluded_reasons
65
66
def test_untrusted_client_excluded(
    aggregator_identity: AgentIdentity,
    client_a_identity: AgentIdentity,
    attacker_identity: AgentIdentity,
) -> None:
    """Check that a signer outside ``trusted_clients`` is excluded with a reason.

    Only client A is in the trusted set, so the attacker's (validly signed)
    update must be reported under ``excluded_reasons`` with the expected
    message, while client A's update is included.
    """
    trusted_update = make_signed_update(client_a_identity)
    attacker_update = make_signed_update(attacker_identity)

    fl_round = AggregationRound(round_id="round-1", model_id="model-x")
    for update in (trusted_update, attacker_update):
        fl_round.add(update)

    aggregator = FederatedAggregator(
        identity=aggregator_identity,
        strategy=FedAvgAggregator(),
        trusted_clients={client_a_identity.did},
    )
    outcome = aggregator.aggregate(fl_round)

    proof = outcome.proof
    assert client_a_identity.did in proof.included_client_dids
    assert attacker_identity.did in proof.excluded_reasons
    rejection_reason = proof.excluded_reasons[attacker_identity.did]
    assert rejection_reason == "client not in trusted set"
88
89
def test_min_updates_enforced(
    aggregator_identity: AgentIdentity,
    client_a_identity: AgentIdentity,
) -> None:
    """Check that aggregating one update with ``min_updates=2`` raises.

    The round holds a single update, so ``aggregate`` is expected to raise
    ``InsufficientUpdatesError``.
    """
    lone_update = make_signed_update(client_a_identity)
    fl_round = AggregationRound(round_id="round-1", model_id="model-x")
    fl_round.add(lone_update)

    aggregator = FederatedAggregator(
        identity=aggregator_identity,
        strategy=FedAvgAggregator(),
        min_updates=2,
    )

    with pytest.raises(InsufficientUpdatesError):
        aggregator.aggregate(fl_round)
105
106
def test_proof_signed_and_verifiable(
    aggregator_identity: AgentIdentity,
    client_a_identity: AgentIdentity,
    client_b_identity: AgentIdentity,
) -> None:
    """Check that the proof verifies and names the aggregator as signer."""
    updates = [
        make_signed_update(client)
        for client in (client_a_identity, client_b_identity)
    ]
    fl_round = AggregationRound(round_id="round-1", model_id="model-x")
    for update in updates:
        fl_round.add(update)

    aggregator = FederatedAggregator(
        identity=aggregator_identity,
        strategy=FedAvgAggregator(),
    )
    outcome = aggregator.aggregate(fl_round)

    proof_is_valid = FederatedAggregator.verify_proof(outcome.proof)
    assert proof_is_valid is True
    assert outcome.proof.signer_did == aggregator_identity.did
124
125
def test_tampered_proof_rejected(
    aggregator_identity: AgentIdentity,
    client_a_identity: AgentIdentity,
    client_b_identity: AgentIdentity,
) -> None:
    """Check that overwriting the proof's ``result_hash`` fails verification."""
    updates = [
        make_signed_update(client)
        for client in (client_a_identity, client_b_identity)
    ]
    fl_round = AggregationRound(round_id="round-1", model_id="model-x")
    for update in updates:
        fl_round.add(update)

    aggregator = FederatedAggregator(
        identity=aggregator_identity,
        strategy=FedAvgAggregator(),
    )
    outcome = aggregator.aggregate(fl_round)

    # Overwrite the result hash after signing with a string of 64 zeros.
    outcome.proof.result_hash = "0" * 64

    proof_is_valid = FederatedAggregator.verify_proof(outcome.proof)
    assert proof_is_valid is False
144
145
def test_round_add_validates_round_and_model(
    client_a_identity: AgentIdentity,
) -> None:
    """Check that ``AggregationRound.add`` rejects mismatched identifiers.

    An update built for a different ``round_id`` and one built for a
    different ``model_id`` must each raise ``AggregationError`` when added to
    a round created for ``round-1`` / ``model-x``.
    """
    fl_round = AggregationRound(round_id="round-1", model_id="model-x")

    mismatched_ids = (
        {"round_id": "round-2", "model_id": "model-x"},  # wrong round
        {"round_id": "round-1", "model_id": "model-y"},  # wrong model
    )
    for ids in mismatched_ids:
        stray_update = make_signed_update(client_a_identity, **ids)
        with pytest.raises(AggregationError):
            fl_round.add(stray_update)
158
159
def test_proof_roundtrip_to_dict(
    aggregator_identity: AgentIdentity,
    client_a_identity: AgentIdentity,
    client_b_identity: AgentIdentity,
) -> None:
    """Check that a proof rebuilt from its ``to_dict`` output still verifies."""
    updates = [
        make_signed_update(client)
        for client in (client_a_identity, client_b_identity)
    ]
    fl_round = AggregationRound(round_id="round-1", model_id="model-x")
    for update in updates:
        fl_round.add(update)

    aggregator = FederatedAggregator(
        identity=aggregator_identity,
        strategy=FedAvgAggregator(),
    )
    outcome = aggregator.aggregate(fl_round)

    # Imported inside the test body rather than at module level.
    from pqc_federated_learning import AggregationProof

    serialized = outcome.proof.to_dict()
    restored_proof = AggregationProof.from_dict(serialized)
    assert FederatedAggregator.verify_proof(restored_proof) is True
180