src/pqc_hypervisor_attestation/signer.py
3.8 KB · 117 lines · python Raw
1 """Attester (signs reports) and AttestationVerifier (checks them)."""
2
3 from __future__ import annotations
4
5 import hashlib
6 from dataclasses import dataclass
7
8 from quantumshield.core.algorithms import SignatureAlgorithm
9 from quantumshield.core.signatures import sign, verify
10 from quantumshield.identity.agent import AgentIdentity
11
12 from pqc_hypervisor_attestation.claim import AttestationReport
13 from pqc_hypervisor_attestation.errors import (
14 AttestationVerificationError,
15 RegionDriftError,
16 )
17
18
19 @dataclass(frozen=True)
20 class VerificationResult:
21 """Outcome of verifying an AttestationReport."""
22
23 valid: bool
24 signature_valid: bool
25 not_expired: bool
26 drifts: list[str] # region_ids whose snapshot != expected
27 error: str | None = None
28
29
class Attester:
    """Produce signed AttestationReports on behalf of one AgentIdentity."""

    def __init__(self, identity: AgentIdentity):
        # Identity whose signing keypair and DID are stamped onto reports.
        self.identity = identity

    def sign(self, report: AttestationReport) -> AttestationReport:
        """Sign ``report`` in place and return the same object.

        The SHA3-256 digest of ``report.canonical_bytes()`` is signed with the
        identity's signing keypair. Afterwards the signer DID, the algorithm
        name, the hex-encoded signature and the hex-encoded public key are
        written onto the report.
        """
        keypair = self.identity.signing_keypair
        payload_digest = hashlib.sha3_256(report.canonical_bytes()).digest()
        raw_signature = sign(payload_digest, keypair)
        # NOTE(review): the canonical bytes are captured *before* the signer
        # fields below are written. Verification only round-trips if
        # canonical_bytes() excludes signer_did/algorithm/public_key/signature
        # — confirm in claim.py.
        report.signer_did = self.identity.did
        report.algorithm = keypair.algorithm.value
        report.signature = raw_signature.hex()
        report.public_key = keypair.public_key.hex()
        return report
45
46
class AttestationVerifier:
    """Independently verify AttestationReports.

    Checks:
    - ML-DSA signature over canonical bytes
    - Report not expired
    - Each claim's snapshot.content_hash matches its expected_hash (if set)

    NOTE(review): the signature is checked against the public key carried
    inside the report itself, and ``signer_did`` is never compared with that
    key here. A valid result proves the report was signed by the holder of
    ``report.public_key``; binding that key to a trusted identity is left to
    the caller.
    """

    @staticmethod
    def verify(
        report: AttestationReport,
        strict: bool = True,
    ) -> VerificationResult:
        """Run every check on ``report`` and summarise the outcome.

        Args:
            report: The report to check.
            strict: When True, any drifted region makes the report invalid.
                When False, drifted regions are still listed in ``drifts``
                but do not affect ``valid``.

        Returns:
            A VerificationResult. ``error`` holds the most fundamental failure
            (signature problems first, then expiry, then strict-mode drift)
            and is None exactly when ``valid`` is True. Failures inside the
            signature backend are reported through ``error``, not raised.
        """
        # 1. Signature check
        sig_ok, err = AttestationVerifier._check_signature(report)

        # 2. Expiry check. A signature failure is more fundamental, so it
        #    keeps precedence in ``err``.
        fresh = not report.is_expired()
        if sig_ok and not fresh:
            err = "report expired"

        # 3. Drift check: a claim drifts when it declares an expected_hash and
        #    its snapshot's content_hash differs from it.
        drifts: list[str] = [
            claim.region.region_id
            for claim in report.claims
            if claim.expected_hash
            and claim.snapshot.content_hash != claim.expected_hash
        ]
        if sig_ok and fresh and drifts and strict:
            err = (
                f"memory drift detected in {len(drifts)} region(s): "
                f"{', '.join(drifts)}"
            )

        valid = sig_ok and fresh and (not drifts or not strict)
        return VerificationResult(
            valid=valid,
            signature_valid=sig_ok,
            not_expired=fresh,
            drifts=drifts,
            error=err,
        )

    @staticmethod
    def _check_signature(report: AttestationReport) -> tuple[bool, str | None]:
        """Verify the report signature; return ``(ok, error_message)``."""
        if not report.signature:
            return False, "missing signature"
        if not report.public_key:
            return False, "missing public key"
        try:
            try:
                algorithm = SignatureAlgorithm(report.algorithm)
            except ValueError:
                return False, f"unknown algorithm {report.algorithm}"
            # Decode in its own step: bytes.fromhex also raises ValueError,
            # which previously was misreported as an unknown algorithm.
            try:
                signature = bytes.fromhex(report.signature)
                public_key = bytes.fromhex(report.public_key)
            except ValueError as exc:
                return False, f"malformed hex encoding: {exc}"
            digest = hashlib.sha3_256(report.canonical_bytes()).digest()
            if not verify(digest, signature, public_key, algorithm):
                return False, "invalid ML-DSA signature"
        except Exception as exc:  # noqa: BLE001 - surface backend failures uniformly
            return False, f"signature verify failed: {exc}"
        return True, None

    @staticmethod
    def verify_or_raise(report: AttestationReport, strict: bool = True) -> None:
        """Verify ``report`` and raise if it is not valid.

        Raises:
            RegionDriftError: the signature is valid and the report is fresh,
                but one or more regions drifted (strict mode only).
            AttestationVerificationError: any other failure, e.g. a missing
                or invalid signature or an expired report.
        """
        result = AttestationVerifier.verify(report, strict=strict)
        if result.valid:
            return
        # Drift is computed even for forged or stale reports. Only classify
        # the failure as drift once authenticity and freshness are confirmed;
        # otherwise a forged report would surface as RegionDriftError.
        if result.signature_valid and result.not_expired and result.drifts:
            raise RegionDriftError(result.error or "drift detected")
        raise AttestationVerificationError(
            result.error or "verification failed"
        )
117