tests/test_signer.py
2.9 KB · 96 lines · python Raw
1 """Tests for Attester and AttestationVerifier."""
2
3 from __future__ import annotations
4
5 import time
6
7 import pytest
8
9 from pqc_hypervisor_attestation import (
10 AttestationClaim,
11 AttestationReport,
12 AttestationVerifier,
13 Attester,
14 MemoryRegion,
15 RegionSnapshot,
16 )
17 from pqc_hypervisor_attestation.errors import AttestationVerificationError
18
19
def _make_report(*, ttl_seconds: int = 60) -> AttestationReport:
    """Build an unsigned single-claim report over region "w0".

    The claim covers one 4-byte "RO" region at 0x1000 whose snapshot is
    the bytes 00 01 02 03.

    Args:
        ttl_seconds: Forwarded as ``ttl_seconds`` to
            ``AttestationReport.create``. Defaults to 60, which matches the
            previously hard-coded value, so existing callers are unaffected.
    """
    region = MemoryRegion(
        region_id="w0",
        description="weights",
        address=0x1000,
        size=4,
        protection="RO",
    )
    snap = RegionSnapshot.create("w0", b"\x00\x01\x02\x03")
    claim = AttestationClaim.create(region=region, snapshot=snap)
    return AttestationReport.create(claims=[claim], ttl_seconds=ttl_seconds)
31
32
def test_sign_populates_fields(attester: Attester) -> None:
    """Signing fills in the signature, key and algorithm, and stamps the signer DID."""
    signed = attester.sign(_make_report())
    # Each of these must be non-empty once the report has been signed.
    for attr_name in ("signature", "public_key", "algorithm"):
        assert getattr(signed, attr_name), attr_name
    assert signed.signer_did == attester.identity.did
40
41
def test_verify_success(attester: Attester) -> None:
    """A freshly signed report verifies: valid signature, unexpired, no drift."""
    outcome = AttestationVerifier.verify(attester.sign(_make_report()))
    assert outcome.valid is True
    assert outcome.signature_valid is True
    assert outcome.not_expired is True
    assert outcome.drifts == []
49
50
def test_signature_tamper_detected(attester: Attester) -> None:
    """Corrupting the hex-encoded signature must make verification fail.

    The first hex digit is replaced by a digit of a different numeric value
    (its low bit flipped). Swapping characters is not enough on its own:
    hex decoding is case-insensitive, so turning ``"F"`` into ``"f"`` would
    leave the signature bytes unchanged and the "tamper" would be a no-op.
    """
    report = attester.sign(_make_report())
    original_bytes = bytes.fromhex(report.signature)
    # Flip a bit in the first hex digit. That invalidates the sig even under
    # Ed25519 transitional mode.
    first_nibble = int(report.signature[0], 16)
    report.signature = format(first_nibble ^ 0x1, "x") + report.signature[1:]
    # Guard the test's own precondition: the signature bytes really changed.
    assert bytes.fromhex(report.signature) != original_bytes
    result = AttestationVerifier.verify(report)
    assert result.signature_valid is False
    assert result.valid is False
61
62
def test_expired_report_rejected(attester: Attester) -> None:
    """A correctly signed report whose TTL has elapsed is rejected as expired."""
    claim = AttestationClaim.create(
        region=MemoryRegion(
            region_id="w0",
            description="weights",
            address=0x1000,
            size=4,
            protection="RO",
        ),
        snapshot=RegionSnapshot.create("w0", b"\x00\x01\x02\x03"),
    )
    signed = attester.sign(AttestationReport.create(claims=[claim], ttl_seconds=0))
    # With ttl_seconds=0, a short pause should put "now" past the expiry.
    # NOTE(review): assumes the expiry check has sub-second resolution. If
    # timestamps are whole seconds, 0.1 s may not cross the boundary — confirm.
    time.sleep(0.1)
    verdict = AttestationVerifier.verify(signed)
    # The signature itself is still good; only the expiry check fails.
    assert verdict.signature_valid is True
    assert verdict.not_expired is False
    assert verdict.valid is False
80
81
def test_verify_or_raise_raises_on_failure(attester: Attester) -> None:
    """verify_or_raise turns a tampered signature into AttestationVerificationError."""
    report = attester.sign(_make_report())
    # Tamper: overwrite the signature with all-zero bytes of the same length,
    # re-encoded as hex.
    sig_len = len(bytes.fromhex(report.signature))
    report.signature = bytes(sig_len).hex()
    with pytest.raises(AttestationVerificationError):
        AttestationVerifier.verify_or_raise(report)
88
89
def test_missing_signature_rejected() -> None:
    """Verifying a report that was never signed fails with "missing signature"."""
    unsigned = _make_report()
    outcome = AttestationVerifier.verify(unsigned)
    assert outcome.signature_valid is False
    assert outcome.valid is False
    assert outcome.error == "missing signature"
96