src/pqc_hypervisor_attestation/claim.py
5.1 KB · 160 lines · python Raw
1 """AttestationClaim and AttestationReport — signed memory-state claims."""
2
3 from __future__ import annotations
4
5 import json
6 import uuid
7 from dataclasses import dataclass, field
8 from datetime import datetime, timedelta, timezone
9 from typing import Any
10
11 from pqc_hypervisor_attestation.region import MemoryRegion, RegionSnapshot
12
13
@dataclass
class AttestationClaim:
    """One statement about the state of a single memory region.

    A claim pairs a ``MemoryRegion`` with a ``RegionSnapshot`` and carries a
    handful of optional string fields, all of which default to ``""``.
    """

    claim_id: str
    region: MemoryRegion
    snapshot: RegionSnapshot
    expected_hash: str = ""  # optional: the hash the claim CLAIMS to be
    workload_id: str = ""  # which AI workload this attests to
    platform: str = ""  # "amd-sev-snp" | "intel-tdx" | "in-memory" | ...
    nonce: str = ""  # random, server-supplied for freshness

    @classmethod
    def create(
        cls,
        region: MemoryRegion,
        snapshot: RegionSnapshot,
        expected_hash: str = "",
        workload_id: str = "",
        platform: str = "",
        nonce: str = "",
    ) -> AttestationClaim:
        """Build a claim with a freshly generated ``urn:pqc-att:`` identifier.

        The identifier suffix is the hex form of a random UUID4; every other
        argument is stored on the new claim unchanged.
        """
        fresh_id = f"urn:pqc-att:{uuid.uuid4().hex}"
        optional_fields = {
            "expected_hash": expected_hash,
            "workload_id": workload_id,
            "platform": platform,
            "nonce": nonce,
        }
        return cls(
            claim_id=fresh_id,
            region=region,
            snapshot=snapshot,
            **optional_fields,
        )

    def to_dict(self) -> dict[str, Any]:
        """Return the claim as a plain dict.

        ``region`` and ``snapshot`` are serialised through their own
        ``to_dict()`` methods; the remaining fields are copied as-is.
        """
        serialized: dict[str, Any] = {
            "claim_id": self.claim_id,
            "region": self.region.to_dict(),
            "snapshot": self.snapshot.to_dict(),
        }
        # Appended in declaration order so the key order of the result
        # matches the field order of the dataclass.
        for name in ("expected_hash", "workload_id", "platform", "nonce"):
            serialized[name] = getattr(self, name)
        return serialized

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> AttestationClaim:
        """Rebuild a claim from a dict shaped like the output of ``to_dict``.

        ``claim_id``, ``region`` and ``snapshot`` are required (``KeyError``
        if absent); the optional string fields fall back to ``""``.  The
        nested ``region`` / ``snapshot`` mappings are expanded as keyword
        arguments into ``MemoryRegion`` / ``RegionSnapshot``.
        """
        region_kwargs = data["region"]
        snapshot_kwargs = data["snapshot"]
        claim_id = data["claim_id"]
        region = MemoryRegion(**region_kwargs)
        snapshot = RegionSnapshot(**snapshot_kwargs)
        optional_fields = {
            name: data.get(name, "")
            for name in ("expected_hash", "workload_id", "platform", "nonce")
        }
        return cls(
            claim_id=claim_id,
            region=region,
            snapshot=snapshot,
            **optional_fields,
        )
70
71
@dataclass
class AttestationReport:
    """Bundle of claims signed with a single ML-DSA signature.

    ``issued_at`` and ``expires_at`` are ISO-8601 strings.  The signature
    related fields (``signer_did``, ``algorithm``, ``signature``,
    ``public_key``) default to ``""`` and are not set by ``create``.
    """

    report_id: str
    claims: list[AttestationClaim] = field(default_factory=list)
    attester_id: str = ""  # ID of who made the attestation
    platform: str = ""
    issued_at: str = ""
    expires_at: str = ""
    signer_did: str = ""
    algorithm: str = ""
    signature: str = ""  # hex
    public_key: str = ""  # hex

    @classmethod
    def create(
        cls,
        claims: list[AttestationClaim],
        attester_id: str = "",
        platform: str = "",
        ttl_seconds: int = 300,
    ) -> AttestationReport:
        """Build an unsigned report that expires ``ttl_seconds`` from now.

        Args:
            claims: Claims to bundle; copied into a new list so later changes
                to the caller's list do not affect the report.
            attester_id: Identifier of whoever made the attestation.
            platform: Platform label stored on the report.
            ttl_seconds: Lifetime of the report in seconds (default 300).

        Returns:
            A report with a fresh ``urn:pqc-attreport:`` id and UTC
            ``issued_at`` / ``expires_at`` timestamps in ISO-8601 form.
        """
        now = datetime.now(timezone.utc)
        exp = now + timedelta(seconds=ttl_seconds)
        return cls(
            report_id=f"urn:pqc-attreport:{uuid.uuid4().hex}",
            claims=list(claims),
            attester_id=attester_id,
            platform=platform,
            issued_at=now.isoformat(),
            expires_at=exp.isoformat(),
        )

    def _core_payload(self) -> dict[str, Any]:
        """Return the six fields shared by ``canonical_bytes`` and ``to_dict``."""
        return {
            "report_id": self.report_id,
            "claims": [c.to_dict() for c in self.claims],
            "attester_id": self.attester_id,
            "platform": self.platform,
            "issued_at": self.issued_at,
            "expires_at": self.expires_at,
        }

    def canonical_bytes(self) -> bytes:
        """Return a deterministic UTF-8 JSON encoding of the core fields.

        Keys are sorted and separators are compact so the same report always
        yields the same bytes.  ``signer_did``, ``algorithm``, ``signature``
        and ``public_key`` are NOT part of this encoding.
        NOTE(review): presumably these are the bytes the signer signs and the
        verifier checks -- confirm against the signing module.
        """
        return json.dumps(
            self._core_payload(),
            sort_keys=True,
            separators=(",", ":"),
            ensure_ascii=False,
        ).encode("utf-8")

    @staticmethod
    def _parse_expiry(text: str) -> datetime | None:
        """Parse an ISO-8601 timestamp into an aware datetime, or ``None``."""
        try:
            parsed = datetime.fromisoformat(text)
        except ValueError:
            # Python < 3.11 rejects a trailing "Z"; retry with an explicit
            # UTC offset before declaring the value unparseable.
            if not text.endswith("Z"):
                return None
            try:
                parsed = datetime.fromisoformat(text[:-1] + "+00:00")
            except ValueError:
                return None
        if parsed.tzinfo is None:
            # A naive timestamp cannot be compared with an aware "now"
            # (TypeError); interpret it as UTC, matching what create() emits.
            parsed = parsed.replace(tzinfo=timezone.utc)
        return parsed

    def is_expired(self) -> bool:
        """Return ``True`` when the current UTC time is past ``expires_at``.

        Timestamps without a UTC offset are interpreted as UTC.  An empty or
        unparseable ``expires_at`` yields ``False``.
        NOTE(review): treating a malformed expiry as "not expired" is
        fail-open; kept for backward compatibility -- confirm that callers
        validate ``expires_at`` separately.
        """
        if not self.expires_at:
            return False
        exp = self._parse_expiry(self.expires_at)
        if exp is None:
            return False
        return datetime.now(timezone.utc) > exp

    def to_dict(self) -> dict[str, Any]:
        """Return every field of the report, claims included, as a plain dict."""
        return {
            **self._core_payload(),
            "signer_did": self.signer_did,
            "algorithm": self.algorithm,
            "signature": self.signature,
            "public_key": self.public_key,
        }

    def to_json(self) -> str:
        """Return ``to_dict()`` rendered as JSON indented by two spaces."""
        return json.dumps(self.to_dict(), indent=2)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> AttestationReport:
        """Rebuild a report from a dict shaped like the output of ``to_dict``.

        Only ``report_id`` is required (``KeyError`` if absent); a missing
        ``claims`` entry becomes an empty list and every other missing field
        becomes ``""``.
        """
        return cls(
            report_id=data["report_id"],
            claims=[AttestationClaim.from_dict(c) for c in data.get("claims", [])],
            attester_id=data.get("attester_id", ""),
            platform=data.get("platform", ""),
            issued_at=data.get("issued_at", ""),
            expires_at=data.get("expires_at", ""),
            signer_did=data.get("signer_did", ""),
            algorithm=data.get("algorithm", ""),
            signature=data.get("signature", ""),
            public_key=data.get("public_key", ""),
        )
160