tests/test_aggregators.py
5.7 KB · 154 lines · python Raw
1 """Tests for aggregator strategies."""
2
3 from __future__ import annotations
4
5 import pytest
6 from quantumshield.identity.agent import AgentIdentity
7
8 from pqc_federated_learning import (
9 FedAvgAggregator,
10 FedMedianAggregator,
11 FedSumAggregator,
12 FedTrimmedMeanAggregator,
13 GradientTensor,
14 )
15 from pqc_federated_learning.errors import (
16 InsufficientUpdatesError,
17 ShapeMismatchError,
18 )
19 from tests.conftest import make_signed_update
20
21
def test_fedavg_weighted_mean(
    client_a_identity: AgentIdentity, client_b_identity: AgentIdentity
) -> None:
    """FedAvg averages client tensors weighted by each client's sample count.

    Client A contributes 100 samples and client B 300, so their tensors are
    expected to be combined with weights 0.25 and 0.75 respectively.
    """
    a = make_signed_update(client_a_identity, num_samples=100, values_scale=1.0)
    b = make_signed_update(client_b_identity, num_samples=300, values_scale=2.0)
    # tensors: (0.1, 0.2, 0.3, 0.4) with weight 0.25
    #          (0.2, 0.4, 0.6, 0.8) with weight 0.75
    # mean:    (0.175, 0.35, 0.525, 0.7)
    result = FedAvgAggregator().aggregate([a, b])
    assert len(result) == 2
    w = next(t for t in result if t.name == "dense_1.weights")
    expected = (0.175, 0.35, 0.525, 0.7)
    # zip() stops at the shorter input, so pin the length first; otherwise a
    # truncated (or empty) result would pass the element-wise check vacuously.
    assert len(w.values) == len(expected)
    for got, want in zip(w.values, expected):
        assert abs(got - want) < 1e-9
36
37
def test_fedavg_equal_weight_when_num_samples_zero(
    client_a_identity: AgentIdentity, client_b_identity: AgentIdentity
) -> None:
    """FedAvg falls back to equal weighting when every client reports 0 samples."""
    # num_samples=0 is treated as weight 1 (min).
    a = make_signed_update(client_a_identity, num_samples=0, values_scale=1.0)
    b = make_signed_update(client_b_identity, num_samples=0, values_scale=3.0)
    result = FedAvgAggregator().aggregate([a, b])
    w = next(t for t in result if t.name == "dense_1.weights")
    # mean of (0.1,..,0.4) and (0.3,..,1.2) => (0.2, 0.4, 0.6, 0.8)
    expected = (0.2, 0.4, 0.6, 0.8)
    # zip() stops at the shorter input, so pin the length first; otherwise a
    # truncated (or empty) result would pass the element-wise check vacuously.
    assert len(w.values) == len(expected)
    for got, want in zip(w.values, expected):
        assert abs(got - want) < 1e-9
50
51
def test_fedsum_plain_sum(
    client_a_identity: AgentIdentity, client_b_identity: AgentIdentity
) -> None:
    """FedSum adds the clients' tensors element-wise without any weighting."""
    a = make_signed_update(client_a_identity, values_scale=1.0)
    b = make_signed_update(client_b_identity, values_scale=2.0)
    result = FedSumAggregator().aggregate([a, b])
    w = next(t for t in result if t.name == "dense_1.weights")
    # (0.1+0.2, 0.2+0.4, 0.3+0.6, 0.4+0.8) = (0.3, 0.6, 0.9, 1.2)
    expected = (0.3, 0.6, 0.9, 1.2)
    # zip() stops at the shorter input, so pin the length first; otherwise a
    # truncated (or empty) result would pass the element-wise check vacuously.
    assert len(w.values) == len(expected)
    for got, want in zip(w.values, expected):
        assert abs(got - want) < 1e-9
63
64
def test_fedmedian_robust_to_outlier(
    client_a_identity: AgentIdentity,
    client_b_identity: AgentIdentity,
    client_c_identity: AgentIdentity,
) -> None:
    """The coordinate-wise median ignores a single 100x-scaled client update."""
    a = make_signed_update(client_a_identity, values_scale=1.0)
    b = make_signed_update(client_b_identity, values_scale=1.0)
    # Outlier: 100x scale
    c = make_signed_update(client_c_identity, values_scale=100.0)
    result = FedMedianAggregator().aggregate([a, b, c])
    w = next(t for t in result if t.name == "dense_1.weights")
    # Median of (v, v, 100v) with v=(0.1..0.4) -> v
    expected = (0.1, 0.2, 0.3, 0.4)
    # zip() stops at the shorter input, so pin the length first; otherwise a
    # truncated (or empty) result would pass the element-wise check vacuously.
    assert len(w.values) == len(expected)
    for got, want in zip(w.values, expected):
        assert abs(got - want) < 1e-9
80
81
def test_fedtrimmedmean_drops_extremes(
    client_a_identity: AgentIdentity,
    client_b_identity: AgentIdentity,
    client_c_identity: AgentIdentity,
) -> None:
    """A 20% trimmed mean over five clients discards the 100x-scaled outlier."""
    # 5 clients; trim 20% = drop 1 low, 1 high.
    ids = [
        client_a_identity,
        client_b_identity,
        client_c_identity,
        AgentIdentity.create("d"),
        AgentIdentity.create("e"),
    ]
    scales = [1.0, 1.0, 1.0, 1.0, 100.0]  # last is outlier
    updates = [make_signed_update(i, values_scale=s) for i, s in zip(ids, scales)]
    agg = FedTrimmedMeanAggregator(trim_ratio=0.2).aggregate(updates)
    w = next(t for t in agg if t.name == "dense_1.weights")
    # After trim, only three 1.0-scale updates remain -> mean = base values.
    expected = (0.1, 0.2, 0.3, 0.4)
    # zip() stops at the shorter input, so pin the length first; otherwise a
    # truncated (or empty) result would pass the element-wise check vacuously.
    assert len(w.values) == len(expected)
    for got, want in zip(w.values, expected):
        assert abs(got - want) < 1e-9
103
104
def test_fedtrimmedmean_rejects_bad_ratio() -> None:
    """Constructing the trimmed-mean aggregator with 0.5 or -0.01 raises ValueError."""
    for bad_ratio in (0.5, -0.01):
        with pytest.raises(ValueError):
            FedTrimmedMeanAggregator(trim_ratio=bad_ratio)
110
111
def test_empty_updates_raises() -> None:
    """Each aggregator rejects an empty update list with InsufficientUpdatesError."""
    aggregator_types = (
        FedAvgAggregator,
        FedSumAggregator,
        FedMedianAggregator,
        FedTrimmedMeanAggregator,
    )
    for aggregator_type in aggregator_types:
        # Construct inside the context manager, exactly as each original
        # stanza did, so the checked expression is unchanged.
        with pytest.raises(InsufficientUpdatesError):
            aggregator_type().aggregate([])
121
122
def test_shape_mismatch_raises(
    client_a_identity: AgentIdentity, client_b_identity: AgentIdentity
) -> None:
    """FedAvg and FedSum raise ShapeMismatchError when a same-named tensor differs in shape."""
    reference = make_signed_update(client_a_identity)
    # The second update reuses the tensor names but declares a different
    # shape for the same-named weights tensor.
    weights = GradientTensor(
        name="dense_1.weights", shape=(4,), values=(0.1, 0.2, 0.3, 0.4)
    )
    bias = GradientTensor(name="dense_1.bias", shape=(2,), values=(0.01, 0.02))
    mismatched = make_signed_update(client_b_identity, tensors=[weights, bias])
    for aggregator_type in (FedAvgAggregator, FedSumAggregator):
        with pytest.raises(ShapeMismatchError):
            aggregator_type().aggregate([reference, mismatched])
137
138
def test_tensor_name_mismatch_raises(
    client_a_identity: AgentIdentity, client_b_identity: AgentIdentity
) -> None:
    """Updates whose tensor names differ are rejected with ShapeMismatchError."""
    reference = make_signed_update(client_a_identity)
    # The second update carries tensors under "other.*" names.
    weights = GradientTensor(
        name="other.weights", shape=(2, 2), values=(0.1, 0.2, 0.3, 0.4)
    )
    bias = GradientTensor(name="other.bias", shape=(2,), values=(0.01, 0.02))
    renamed = make_signed_update(client_b_identity, tensors=[weights, bias])
    aggregator_types = (
        FedAvgAggregator,
        FedMedianAggregator,
        FedTrimmedMeanAggregator,
    )
    for aggregator_type in aggregator_types:
        with pytest.raises(ShapeMismatchError):
            aggregator_type().aggregate([reference, renamed])
154