src/pqc_hypervisor_attestation/claim.py
| 1 | """AttestationClaim and AttestationReport — signed memory-state claims.""" |
| 2 | |
| 3 | from __future__ import annotations |
| 4 | |
| 5 | import json |
| 6 | import uuid |
| 7 | from dataclasses import dataclass, field |
| 8 | from datetime import datetime, timedelta, timezone |
| 9 | from typing import Any |
| 10 | |
| 11 | from pqc_hypervisor_attestation.region import MemoryRegion, RegionSnapshot |
| 12 | |
| 13 | |
@dataclass
class AttestationClaim:
    """One statement about the measured state of a single memory region.

    A claim pairs the attested region with a snapshot of it, plus optional
    context: the hash the attester expects, the workload the claim concerns,
    the platform it was produced on, and a freshness nonce.
    """

    claim_id: str
    region: MemoryRegion
    snapshot: RegionSnapshot
    expected_hash: str = ""  # optional: the hash the claim asserts the region has
    workload_id: str = ""  # identifier of the AI workload this claim attests to
    platform: str = ""  # e.g. "amd-sev-snp" | "intel-tdx" | "in-memory" | ...
    nonce: str = ""  # random, server-supplied value used for freshness

    @classmethod
    def create(
        cls,
        region: MemoryRegion,
        snapshot: RegionSnapshot,
        expected_hash: str = "",
        workload_id: str = "",
        platform: str = "",
        nonce: str = "",
    ) -> AttestationClaim:
        """Build a claim with a freshly generated identifier.

        The identifier is ``urn:pqc-att:`` followed by the hex form of a
        random UUID4; every other field is taken from the arguments as-is.
        """
        new_id = "urn:pqc-att:" + uuid.uuid4().hex
        return cls(
            claim_id=new_id,
            region=region,
            snapshot=snapshot,
            expected_hash=expected_hash,
            workload_id=workload_id,
            platform=platform,
            nonce=nonce,
        )

    def to_dict(self) -> dict[str, Any]:
        """Serialize the claim to a plain dict with a fixed key order.

        ``region`` and ``snapshot`` are serialized through their own
        ``to_dict`` methods; the string fields are copied unchanged.
        """
        result: dict[str, Any] = {
            "claim_id": self.claim_id,
            "region": self.region.to_dict(),
            "snapshot": self.snapshot.to_dict(),
        }
        for key in ("expected_hash", "workload_id", "platform", "nonce"):
            result[key] = getattr(self, key)
        return result

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> AttestationClaim:
        """Rebuild a claim from the dict shape produced by ``to_dict``.

        ``region``, ``snapshot`` and ``claim_id`` are required (a missing key
        raises ``KeyError``); the remaining string fields default to ``""``.
        """
        # NOTE(review): region/snapshot are rebuilt by splatting their dicts
        # as keyword arguments; this assumes MemoryRegion/RegionSnapshot
        # accept their own to_dict() output that way — confirm.
        region_fields = data["region"]
        snapshot_fields = data["snapshot"]
        claim_id = data["claim_id"]
        region = MemoryRegion(**region_fields)
        snapshot = RegionSnapshot(**snapshot_fields)
        extras = {
            key: data.get(key, "")
            for key in ("expected_hash", "workload_id", "platform", "nonce")
        }
        return cls(claim_id=claim_id, region=region, snapshot=snapshot, **extras)
| 70 | |
| 71 | |
@dataclass
class AttestationReport:
    """Bundle of claims signed with a single ML-DSA signature.

    ``canonical_bytes`` yields the byte string that is signed; ``signer_did``,
    ``algorithm``, ``signature`` and ``public_key`` carry the signing metadata
    and are populated outside this class.
    """

    report_id: str
    claims: list[AttestationClaim] = field(default_factory=list)
    attester_id: str = ""  # ID of who made the attestation
    platform: str = ""
    issued_at: str = ""  # ISO-8601; create() writes a UTC-aware timestamp
    expires_at: str = ""  # ISO-8601; "" means the report never expires
    signer_did: str = ""
    algorithm: str = ""
    signature: str = ""  # hex
    public_key: str = ""  # hex

    @classmethod
    def create(
        cls,
        claims: list[AttestationClaim],
        attester_id: str = "",
        platform: str = "",
        ttl_seconds: int = 300,
    ) -> AttestationReport:
        """Build an unsigned report valid for ``ttl_seconds`` from now (UTC).

        The claims list is copied so later mutation by the caller does not
        affect the report. Signature fields are left empty.
        """
        now = datetime.now(timezone.utc)
        exp = now + timedelta(seconds=ttl_seconds)
        return cls(
            report_id=f"urn:pqc-attreport:{uuid.uuid4().hex}",
            claims=list(claims),
            attester_id=attester_id,
            platform=platform,
            issued_at=now.isoformat(),
            expires_at=exp.isoformat(),
        )

    def canonical_bytes(self) -> bytes:
        """Return the deterministic UTF-8 JSON encoding that gets signed.

        Keys are sorted and separators are compact so that equal reports
        always produce identical bytes.
        """
        # NOTE(review): signer_did, algorithm and public_key are NOT part of
        # the signed payload, so the signature does not bind them. Adding
        # them here would change the payload and invalidate existing
        # signatures; consider a versioned canonical form instead.
        payload = {
            "report_id": self.report_id,
            "claims": [c.to_dict() for c in self.claims],
            "attester_id": self.attester_id,
            "platform": self.platform,
            "issued_at": self.issued_at,
            "expires_at": self.expires_at,
        }
        return json.dumps(
            payload, sort_keys=True, separators=(",", ":"), ensure_ascii=False
        ).encode("utf-8")

    def is_expired(self, now: datetime | None = None) -> bool:
        """Return True if the report's expiry time has passed.

        - An empty ``expires_at`` means the report never expires.
        - A timestamp without a UTC offset is interpreted as UTC, matching
          what ``create`` writes.
        - A trailing ``Z`` is accepted on every Python version.
        - A malformed or non-string ``expires_at`` is treated as expired
          (fail closed): a freshness check must not pass on input it cannot
          interpret.

        Args:
            now: Reference time; defaults to the current UTC time. A naive
                value is interpreted as UTC.
        """
        if not self.expires_at:
            return False
        if not isinstance(self.expires_at, str):
            return True
        raw = self.expires_at.strip()
        # fromisoformat() only accepts a "Z" suffix from Python 3.11 onward.
        if raw.endswith(("Z", "z")):
            raw = raw[:-1] + "+00:00"
        try:
            exp = datetime.fromisoformat(raw)
        except ValueError:
            return True
        # Comparing naive and aware datetimes raises TypeError; normalize both.
        if exp.tzinfo is None:
            exp = exp.replace(tzinfo=timezone.utc)
        if now is None:
            now = datetime.now(timezone.utc)
        elif now.tzinfo is None:
            now = now.replace(tzinfo=timezone.utc)
        return now > exp

    def to_dict(self) -> dict[str, Any]:
        """Serialize the full report, including signature metadata."""
        return {
            "report_id": self.report_id,
            "claims": [c.to_dict() for c in self.claims],
            "attester_id": self.attester_id,
            "platform": self.platform,
            "issued_at": self.issued_at,
            "expires_at": self.expires_at,
            "signer_did": self.signer_did,
            "algorithm": self.algorithm,
            "signature": self.signature,
            "public_key": self.public_key,
        }

    def to_json(self) -> str:
        """Return ``to_dict()`` as indented, human-readable JSON."""
        return json.dumps(self.to_dict(), indent=2)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> AttestationReport:
        """Rebuild a report from the shape produced by ``to_dict``.

        Only ``report_id`` is required; every other field falls back to an
        empty value when absent.
        """
        return cls(
            report_id=data["report_id"],
            claims=[AttestationClaim.from_dict(c) for c in data.get("claims", [])],
            attester_id=data.get("attester_id", ""),
            platform=data.get("platform", ""),
            issued_at=data.get("issued_at", ""),
            expires_at=data.get("expires_at", ""),
            signer_did=data.get("signer_did", ""),
            algorithm=data.get("algorithm", ""),
            signature=data.get("signature", ""),
            public_key=data.get("public_key", ""),
        )
| 160 | |