tests/test_drift.py
2.0 KB · 67 lines · python Raw
1 """Tests for expected_hash drift detection."""
2
3 from __future__ import annotations
4
5 import pytest
6
7 from pqc_hypervisor_attestation import (
8 AttestationClaim,
9 AttestationReport,
10 AttestationVerifier,
11 Attester,
12 MemoryRegion,
13 RegionSnapshot,
14 )
15 from pqc_hypervisor_attestation.errors import RegionDriftError
16
17
def _signed_report(
    attester: Attester,
    content: bytes,
    expected: str,
    *,
    region_id: str = "w0",
) -> AttestationReport:
    """Build a single-claim report over *content* and sign it with *attester*.

    The snapshot is always computed from *content* itself, while *expected*
    is the hash recorded on the claim. Passing the hash of different bytes
    therefore yields a properly signed report whose region has drifted.

    Args:
        attester: Signs the finished report.
        content: Bytes the region snapshot is taken over; its length is also
            used as the region size.
        expected: Expected hash recorded on the claim.
        region_id: Identifier used for both the region and its snapshot.
            Defaults to ``"w0"``, which the drift assertions below rely on.

    Returns:
        Whatever ``attester.sign`` returns for the built report.
    """
    # One source of truth for the id so the region and snapshot can't diverge.
    region = MemoryRegion(
        region_id=region_id,
        description="weights",
        address=0x1000,
        size=len(content),
        protection="RO",
    )
    snap = RegionSnapshot.create(region_id, content)
    claim = AttestationClaim.create(
        region=region,
        snapshot=snap,
        expected_hash=expected,
    )
    report = AttestationReport.create(claims=[claim], ttl_seconds=60)
    return attester.sign(report)
38
39
def test_matching_expected_hash_verifies(attester: Attester) -> None:
    """A snapshot whose hash equals the expected hash verifies cleanly."""
    content = b"trusted-weights"
    expected = RegionSnapshot.hash_bytes(content)
    report = _signed_report(attester, content, expected)
    result = AttestationVerifier.verify(report)
    # Check the signature explicitly, mirroring the drift tests, so a
    # signature failure can't hide behind an overall verdict.
    assert result.signature_valid is True
    assert result.valid is True
    assert result.drifts == []
47
48
def test_drift_fails_in_strict_mode(attester: Attester) -> None:
    """Strict verification rejects a correctly signed but drifted report."""
    trusted_digest = RegionSnapshot.hash_bytes(b"trusted-weights")
    # The snapshot is genuinely taken over tampered bytes, so it no longer
    # matches the trusted digest recorded on the claim.
    tampered = _signed_report(attester, b"TAMPERED-weights", trusted_digest)

    outcome = AttestationVerifier.verify(tampered, strict=True)
    assert outcome.signature_valid is True
    assert outcome.drifts == ["w0"]
    assert outcome.valid is False

    # The raising entry point surfaces the same failure as a typed error.
    with pytest.raises(RegionDriftError):
        AttestationVerifier.verify_or_raise(tampered, strict=True)
59
60
def test_drift_allowed_in_non_strict_mode(attester: Attester) -> None:
    """Non-strict verification reports drift without failing the report."""
    expected = RegionSnapshot.hash_bytes(b"trusted-weights")
    report = _signed_report(attester, b"TAMPERED-weights", expected)
    result = AttestationVerifier.verify(report, strict=False)
    assert result.signature_valid is True
    assert result.valid is True
    # Drift is still recorded even though it does not invalidate the report.
    assert result.drifts == ["w0"]
    # Counterpart of the strict test: the raising entry point must tolerate
    # the drift when strict mode is off.
    AttestationVerifier.verify_or_raise(report, strict=False)
67