tests/test_update.py
2.9 KB · 94 lines · python Raw
1 """Tests for ClientUpdate / GradientTensor / ClientUpdateMetadata."""
2
3 from __future__ import annotations
4
5 import pytest
6
7 from pqc_federated_learning import (
8 ClientUpdate,
9 ClientUpdateMetadata,
10 GradientTensor,
11 )
12
13
def test_gradient_tensor_shape_validation() -> None:
    """A tensor whose value count does not fit its declared shape is rejected."""
    # Shape (2, 2) needs four values; only three are supplied here.
    too_few_values = (1.0, 2.0, 3.0)
    with pytest.raises(ValueError):
        GradientTensor(name="w", shape=(2, 2), values=too_few_values)
18
19
def test_gradient_tensor_roundtrip() -> None:
    """Serialising a GradientTensor to a dict and back yields an equal tensor."""
    original = GradientTensor(name="w", shape=(2, 2), values=(1.0, 2.0, 3.0, 4.0))
    restored = GradientTensor.from_dict(original.to_dict())
    assert original == restored
25
26
def test_content_hash_is_deterministic() -> None:
    """compute_content_hash returns the same digest for equal inputs.

    The hash is computed twice over the same objects and once over a freshly
    built, equal set of inputs. Comparing all three catches a hash that
    accidentally depends on object identity rather than on content.
    """

    def _build_inputs() -> tuple[ClientUpdateMetadata, list[GradientTensor]]:
        # Fresh, content-equal metadata and tensors on every call.
        meta = ClientUpdateMetadata(
            client_did="did:pqaid:abc",
            round_id="r1",
            model_id="m1",
            num_samples=10,
        )
        tensors = [GradientTensor(name="w", shape=(2,), values=(1.0, 2.0))]
        return meta, tensors

    created_at = "2026-01-01T00:00:00+00:00"
    meta, tensors = _build_inputs()
    h1 = ClientUpdate.compute_content_hash(meta, tensors, created_at)
    h2 = ClientUpdate.compute_content_hash(meta, tensors, created_at)
    h3 = ClientUpdate.compute_content_hash(*_build_inputs(), created_at)

    assert h1 == h2 == h3
    assert len(h1) == 64  # sha3-256 hex: 32-byte digest -> 64 hex characters
    # Must decode as hex into exactly the 32 digest bytes.
    assert len(bytes.fromhex(h1)) == 32
39
40
def test_content_hash_changes_with_values() -> None:
    """Changing a single gradient value produces a different content hash."""
    meta = ClientUpdateMetadata(
        client_did="did:pqaid:abc",
        round_id="r1",
        model_id="m1",
        num_samples=10,
    )
    timestamp = "2026-01-01T00:00:00+00:00"

    # Same metadata and timestamp; only the second gradient value differs.
    baseline, tweaked = (
        ClientUpdate.compute_content_hash(
            meta,
            [GradientTensor(name="w", shape=(2,), values=vals)],
            timestamp,
        )
        for vals in ((1.0, 2.0), (1.0, 3.0))
    )
    assert baseline != tweaked
53
54
def test_create_populates_content_hash() -> None:
    """ClientUpdate.create fills in created_at and a matching content_hash."""
    metadata = ClientUpdateMetadata(
        client_did="did:pqaid:abc",
        round_id="r1",
        model_id="m1",
        num_samples=10,
    )
    gradients = [GradientTensor(name="w", shape=(2,), values=(1.0, 2.0))]
    update = ClientUpdate.create(metadata, gradients)

    assert update.content_hash != ""
    assert update.created_at != ""
    # Recomputing over the same inputs and the timestamp create() recorded
    # must reproduce the stored hash.
    expected = ClientUpdate.compute_content_hash(metadata, gradients, update.created_at)
    assert update.content_hash == expected
68
69
def test_client_update_roundtrip() -> None:
    """A signed ClientUpdate survives to_dict()/from_dict() unchanged.

    Every field the test populates (metadata, tensors, timestamps, hash, and
    all signature fields) is checked after the round trip. Setting a field
    without asserting it would let a serialisation regression slip through.
    """
    meta = ClientUpdateMetadata(
        client_did="did:pqaid:abc",
        round_id="r1",
        model_id="m1",
        num_samples=10,
        epochs=3,
        local_loss=0.25,
    )
    tensors = [GradientTensor(name="w", shape=(2,), values=(1.0, 2.0))]
    u = ClientUpdate.create(meta, tensors)
    # Simulate a signed update; the values are placeholders, not a real signature.
    u.signer_did = "did:pqaid:abc"
    u.algorithm = "ML-DSA-65"
    u.signature = "deadbeef"
    u.public_key = "cafe"
    u.signed_at = "2026-01-01T00:00:00+00:00"

    d = u.to_dict()
    u2 = ClientUpdate.from_dict(d)
    assert u2.metadata == u.metadata
    assert u2.tensors == u.tensors
    assert u2.content_hash == u.content_hash
    # created_at feeds the content hash, so it must be preserved too.
    assert u2.created_at == u.created_at
    assert u2.signature == u.signature
    assert u2.algorithm == u.algorithm
    assert u2.signer_did == u.signer_did
    assert u2.public_key == u.public_key
    assert u2.signed_at == u.signed_at
94