tests/conftest.py
2.1 KB · 84 lines · python Raw
1 """Pytest fixtures for pqc-federated-learning tests."""
2
from __future__ import annotations

from collections.abc import Callable

import pytest
from quantumshield.identity.agent import AgentIdentity

from pqc_federated_learning import (
    ClientUpdate,
    ClientUpdateMetadata,
    GradientTensor,
    UpdateSigner,
)
14
15
@pytest.fixture
def aggregator_identity() -> AgentIdentity:
    """Provide an ``AgentIdentity`` created under the name ``"test-aggregator"``.

    The fixture uses pytest's default (function) scope, so each requesting
    test receives its own newly created identity.
    """
    identity = AgentIdentity.create("test-aggregator")
    return identity
19
20
@pytest.fixture
def client_a_identity() -> AgentIdentity:
    """Provide an ``AgentIdentity`` created under the name ``"client-a"``.

    Function-scoped: a new identity is created for every requesting test.
    """
    identity = AgentIdentity.create("client-a")
    return identity
24
25
@pytest.fixture
def client_b_identity() -> AgentIdentity:
    """Provide an ``AgentIdentity`` created under the name ``"client-b"``.

    Function-scoped: a new identity is created for every requesting test.
    """
    identity = AgentIdentity.create("client-b")
    return identity
29
30
@pytest.fixture
def client_c_identity() -> AgentIdentity:
    """Provide an ``AgentIdentity`` created under the name ``"client-c"``.

    Function-scoped: a new identity is created for every requesting test.
    """
    identity = AgentIdentity.create("client-c")
    return identity
34
35
@pytest.fixture
def attacker_identity() -> AgentIdentity:
    """Provide an ``AgentIdentity`` created under the name ``"evil-attacker"``.

    The name suggests tests use it as an untrusted/adversarial signer;
    confirm against the tests that request it. Function-scoped: a new
    identity is created for every requesting test.
    """
    identity = AgentIdentity.create("evil-attacker")
    return identity
39
40
@pytest.fixture
def sample_tensors() -> list[GradientTensor]:
    """Provide a weights tensor and a bias tensor for layer ``dense_1``.

    The names, shapes and values equal the default tensors that
    ``make_signed_update`` builds when ``values_scale`` is 1.0.

    Returns:
        A new list ``[weights, bias]`` where ``weights`` has shape (2, 2)
        and ``bias`` has shape (2,).
    """
    weights = GradientTensor(
        name="dense_1.weights",
        shape=(2, 2),
        values=(0.1, 0.2, 0.3, 0.4),
    )
    bias = GradientTensor(
        name="dense_1.bias",
        shape=(2,),
        values=(0.01, 0.02),
    )
    return [weights, bias]
47
48
def make_signed_update(
    identity: AgentIdentity,
    round_id: str = "round-1",
    model_id: str = "model-x",
    num_samples: int = 100,
    values_scale: float = 1.0,
    tensors: list[GradientTensor] | None = None,
) -> ClientUpdate:
    """Build a ``ClientUpdate`` for ``identity`` and return it signed.

    Args:
        identity: Agent whose ``did`` is recorded as the metadata's
            ``client_did`` and which is handed to ``UpdateSigner``.
        round_id: Stored as the metadata's ``round_id``.
        model_id: Stored as the metadata's ``model_id``.
        num_samples: Stored as the metadata's ``num_samples``.
        values_scale: Multiplier applied to every value of the default
            tensors. It has no effect when ``tensors`` is supplied.
        tensors: Tensors to include. ``None`` selects the two defaults
            (``dense_1.weights`` with shape (2, 2) and ``dense_1.bias``
            with shape (2,)); any other value, including an empty list,
            is used as given.

    Returns:
        The update after ``UpdateSigner(identity).sign`` has been applied.
    """
    if tensors is None:
        # (name, shape, unscaled values) for the default gradient payload.
        default_specs = (
            ("dense_1.weights", (2, 2), (0.1, 0.2, 0.3, 0.4)),
            ("dense_1.bias", (2,), (0.01, 0.02)),
        )
        tensors = [
            GradientTensor(
                name=tensor_name,
                shape=tensor_shape,
                values=tuple(raw * values_scale for raw in raw_values),
            )
            for tensor_name, tensor_shape, raw_values in default_specs
        ]

    metadata = ClientUpdateMetadata(
        client_did=identity.did,
        round_id=round_id,
        model_id=model_id,
        num_samples=num_samples,
    )
    unsigned = ClientUpdate.create(metadata, tensors)
    return UpdateSigner(identity).sign(unsigned)
79
80
@pytest.fixture
def signed_update_factory() -> Callable[..., ClientUpdate]:
    """Expose ``make_signed_update`` to tests as a fixture.

    The fixture value is the ``make_signed_update`` function itself. Tests
    call it with that function's arguments to build signed updates inline,
    without importing it from this conftest module.
    """
    return make_signed_update
84