tests/test_continuous.py
1.7 KB · 58 lines · python Raw
1 """Tests for ContinuousAttester."""
2
3 from __future__ import annotations
4
5 from pqc_hypervisor_attestation import (
6 Attester,
7 ContinuousAttester,
8 InMemoryBackend,
9 )
10
# Workload identifier passed as ``workload_id`` to every ContinuousAttester
# constructed in the tests below.
WORKLOAD_ID: str = "model-serving-1"
12
13
def test_attest_once_covers_all_regions(
    attester: Attester, backend: InMemoryBackend
) -> None:
    """One ``attest_once()`` pass yields a signed report covering both regions.

    The report's claims must name exactly the ``model-weights-0`` and
    ``activation-cache`` regions, and the report must identify the
    ``in-memory`` platform.
    """
    continuous = ContinuousAttester(
        attester=attester,
        backend=backend,
        workload_id=WORKLOAD_ID,
    )

    result = continuous.attest_once()
    assert result.signature

    claimed = {claim.region.region_id for claim in result.claims}
    expected = {"model-weights-0", "activation-cache"}
    assert claimed == expected

    assert result.platform == "in-memory"
25
26
def test_run_for_returns_expected_count(
    attester: Attester, backend: InMemoryBackend
) -> None:
    """``run_for()`` returns a plausible number of reports, all of them signed."""
    continuous = ContinuousAttester(
        attester=attester,
        backend=backend,
        workload_id=WORKLOAD_ID,
    )

    # A 2-second window at a 1.0 interval should produce roughly two reports.
    # The accepted range of 1..3 tolerates clock jitter.
    # NOTE(review): this presumably runs in wall-clock time (~2 s). Consider
    # an injectable clock if ContinuousAttester supports one. TODO confirm.
    produced = continuous.run_for(seconds=2, interval=1.0)

    assert len(produced) in range(1, 4)
    assert all(report.signature for report in produced)
38
39
def test_drift_between_calls_changes_snapshot_hash(
    attester: Attester, backend: InMemoryBackend
) -> None:
    """Rewriting the weights region between two passes changes its content hash."""
    continuous = ContinuousAttester(
        attester=attester,
        backend=backend,
        workload_id=WORKLOAD_ID,
    )

    # Sentinel distinguishes "region absent" from any real hash value.
    missing = object()

    def hash_of_weights(report) -> str:
        # Return the content hash of the first claim for the weights region.
        # Raise (failing the test) if no such claim is present.
        found = next(
            (
                claim.snapshot.content_hash
                for claim in report.claims
                if claim.region.region_id == "model-weights-0"
            ),
            missing,
        )
        if found is missing:
            raise AssertionError("weights region not in report")
        return found

    before = continuous.attest_once()
    # Overwrite the weights region to simulate tampering.
    backend.update("model-weights-0", b"\xcc" * 128)
    after = continuous.attest_once()

    assert hash_of_weights(before) != hash_of_weights(after)
58