tests/test_update.py
| 1 | """Tests for ClientUpdate / GradientTensor / ClientUpdateMetadata.""" |
| 2 | |
| 3 | from __future__ import annotations |
| 4 | |
| 5 | import pytest |
| 6 | |
| 7 | from pqc_federated_learning import ( |
| 8 | ClientUpdate, |
| 9 | ClientUpdateMetadata, |
| 10 | GradientTensor, |
| 11 | ) |
| 12 | |
| 13 | |
def test_gradient_tensor_shape_validation() -> None:
    """Constructing a tensor whose value count disagrees with its shape raises ValueError."""
    # A (2, 2) shape paired with only three values.
    mismatched = {"name": "w", "shape": (2, 2), "values": (1.0, 2.0, 3.0)}
    with pytest.raises(ValueError):
        GradientTensor(**mismatched)
| 18 | |
| 19 | |
def test_gradient_tensor_roundtrip() -> None:
    """A tensor rebuilt via ``from_dict`` from its ``to_dict`` output equals the original."""
    original = GradientTensor(
        name="w",
        shape=(2, 2),
        values=(1.0, 2.0, 3.0, 4.0),
    )
    restored = GradientTensor.from_dict(original.to_dict())
    assert original == restored
| 25 | |
| 26 | |
def test_content_hash_is_deterministic() -> None:
    """Hashing the same metadata, tensors and timestamp twice yields the same digest."""
    timestamp = "2026-01-01T00:00:00+00:00"
    metadata = ClientUpdateMetadata(
        client_did="did:pqaid:abc",
        round_id="r1",
        model_id="m1",
        num_samples=10,
    )
    gradients = [GradientTensor(name="w", shape=(2,), values=(1.0, 2.0))]

    # Two independent computations over identical inputs.
    first, second = (
        ClientUpdate.compute_content_hash(metadata, gradients, timestamp)
        for _ in range(2)
    )

    assert first == second
    assert len(first) == 64  # sha3-256 hex
| 39 | |
| 40 | |
def test_content_hash_changes_with_values() -> None:
    """Changing a single tensor value produces a different content hash."""
    timestamp = "2026-01-01T00:00:00+00:00"
    metadata = ClientUpdateMetadata(
        client_did="did:pqaid:abc",
        round_id="r1",
        model_id="m1",
        num_samples=10,
    )

    # Same name and shape; only the second value differs between the two tensors.
    baseline = [GradientTensor(name="w", shape=(2,), values=(1.0, 2.0))]
    altered = [GradientTensor(name="w", shape=(2,), values=(1.0, 3.0))]

    baseline_hash, altered_hash = (
        ClientUpdate.compute_content_hash(metadata, tensor_list, timestamp)
        for tensor_list in (baseline, altered)
    )
    assert baseline_hash != altered_hash
| 53 | |
| 54 | |
def test_create_populates_content_hash() -> None:
    """``ClientUpdate.create`` fills in ``created_at`` and a matching ``content_hash``."""
    metadata = ClientUpdateMetadata(
        client_did="did:pqaid:abc",
        round_id="r1",
        model_id="m1",
        num_samples=10,
    )
    gradients = [GradientTensor(name="w", shape=(2,), values=(1.0, 2.0))]

    update = ClientUpdate.create(metadata, gradients)

    assert update.content_hash != ""
    assert update.created_at != ""

    # Recompute the hash from the same inputs plus the timestamp create() chose.
    expected_hash = ClientUpdate.compute_content_hash(
        metadata, gradients, update.created_at
    )
    assert update.content_hash == expected_hash
| 68 | |
| 69 | |
def test_client_update_roundtrip() -> None:
    """A ClientUpdate survives ``to_dict`` / ``from_dict`` with all set fields intact.

    The signature-related attributes are assigned placeholder strings by hand
    (no signing is performed here); the test only checks that whatever was set
    on the update is present on the reconstructed object.
    """
    meta = ClientUpdateMetadata(
        client_did="did:pqaid:abc",
        round_id="r1",
        model_id="m1",
        num_samples=10,
        epochs=3,
        local_loss=0.25,
    )
    tensors = [GradientTensor(name="w", shape=(2,), values=(1.0, 2.0))]
    u = ClientUpdate.create(meta, tensors)

    # Placeholder signature envelope values.
    u.signer_did = "did:pqaid:abc"
    u.algorithm = "ML-DSA-65"
    u.signature = "deadbeef"
    u.public_key = "cafe"
    u.signed_at = "2026-01-01T00:00:00+00:00"

    u2 = ClientUpdate.from_dict(u.to_dict())

    assert u2.metadata == u.metadata
    assert u2.tensors == u.tensors
    assert u2.content_hash == u.content_hash
    assert u2.signature == u.signature
    assert u2.algorithm == u.algorithm
    # The remaining envelope fields set above were previously never verified,
    # so a serializer that dropped them would have gone unnoticed.
    assert u2.signer_did == u.signer_did
    assert u2.public_key == u.public_key
    assert u2.signed_at == u.signed_at
| 94 | |