tests/test_continuous.py
| 1 | """Tests for ContinuousAttester.""" |
| 2 | |
| 3 | from __future__ import annotations |
| 4 | |
| 5 | from pqc_hypervisor_attestation import ( |
| 6 | Attester, |
| 7 | ContinuousAttester, |
| 8 | InMemoryBackend, |
| 9 | ) |
| 10 | |
| 11 | WORKLOAD_ID = "model-serving-1" |
| 12 | |
| 13 | |
def test_attest_once_covers_all_regions(
    attester: Attester, backend: InMemoryBackend
) -> None:
    """Check the report produced by a single ``attest_once`` call.

    The report must carry a truthy signature, its claims must name exactly
    the two expected region ids, and its platform must be ``"in-memory"``.
    """
    expected_regions = {"model-weights-0", "activation-cache"}
    continuous = ContinuousAttester(
        attester=attester,
        backend=backend,
        workload_id=WORKLOAD_ID,
    )

    result = continuous.attest_once()

    assert result.signature

    # Collect the region id named by each claim; duplicates collapse in the set.
    seen_regions = set()
    for claim in result.claims:
        seen_regions.add(claim.region.region_id)
    assert seen_regions == expected_regions

    assert result.platform == "in-memory"
| 25 | |
| 26 | |
def test_run_for_returns_expected_count(
    attester: Attester, backend: InMemoryBackend
) -> None:
    """Check how many reports ``run_for`` returns and that each has a signature."""
    continuous = ContinuousAttester(
        attester=attester,
        backend=backend,
        workload_id=WORKLOAD_ID,
    )

    # With seconds=2 and interval=1.0 roughly two reports are expected; the
    # accepted range is widened to 1..3 to tolerate clock jitter.
    fewest, most = 1, 3
    produced = continuous.run_for(seconds=2, interval=1.0)

    report_count = len(produced)
    assert fewest <= report_count <= most

    for item in produced:
        assert item.signature
| 38 | |
| 39 | |
def test_drift_between_calls_changes_snapshot_hash(
    attester: Attester, backend: InMemoryBackend
) -> None:
    """Check that updating a region between two attestations changes its hash.

    The ``model-weights-0`` region is overwritten via ``backend.update`` after
    the first ``attest_once`` call; the content hash reported for that region
    by the second call must differ from the first.
    """
    target_region = "model-weights-0"
    not_found = object()

    def weights_hash(report) -> str:
        # Lazily yield hashes of claims for the target region so that only the
        # first matching claim's snapshot is ever touched.
        candidates = (
            claim.snapshot.content_hash
            for claim in report.claims
            if claim.region.region_id == target_region
        )
        found = next(candidates, not_found)
        if found is not_found:
            raise AssertionError("weights region not in report")
        return found

    continuous = ContinuousAttester(
        attester=attester,
        backend=backend,
        workload_id=WORKLOAD_ID,
    )

    before = continuous.attest_once()
    # Overwrite the weights region to simulate tampering between attestations.
    backend.update(target_region, b"\xcc" * 128)
    after = continuous.attest_once()

    hash_before = weights_hash(before)
    hash_after = weights_hash(after)
    assert hash_before != hash_after
| 58 | |