src/pqc_gpu_driver/driver_attest.py
| 1 | """Driver module attestation with ML-DSA. |
| 2 | |
| 3 | Every GPU driver / kernel module loaded into the AI inference system gets an |
| 4 | ML-DSA signature over its binary hash. At load time, the verifier checks the |
| 5 | signature against an allow-list of signers before permitting the module to load. |
| 6 | """ |
| 7 | |
| 8 | from __future__ import annotations |
| 9 | |
| 10 | import hashlib |
| 11 | import json |
| 12 | from dataclasses import asdict, dataclass |
| 13 | from datetime import datetime, timezone |
| 14 | from typing import Any |
| 15 | |
| 16 | from quantumshield.core.algorithms import SignatureAlgorithm |
| 17 | from quantumshield.core.signatures import sign, verify |
| 18 | from quantumshield.identity.agent import AgentIdentity |
| 19 | |
| 20 | from pqc_gpu_driver.errors import DriverAttestationError |
| 21 | |
| 22 | |
@dataclass(frozen=True)
class DriverModule:
    """Immutable summary of one GPU driver module binary (e.g. nvidia.ko, amdgpu.ko).

    The summary, not the binary itself, is what gets signed; see
    ``canonical_bytes`` for the exact byte encoding.
    """

    name: str
    version: str
    module_hash: str  # hex SHA3-256 of the module file (see hash_module_bytes)
    module_size: int  # presumably the module file size in bytes
    target: str = "linux"  # platform tag, e.g. "linux" | "windows" | ...

    def canonical_bytes(self) -> bytes:
        """Return a deterministic UTF-8 JSON encoding of this summary.

        Keys are sorted and separators carry no whitespace, so two equal
        summaries always encode to identical bytes (safe to hash and sign).
        Non-ASCII characters are emitted as raw UTF-8 rather than escaped.
        """
        summary = {
            "module_hash": self.module_hash,
            "module_size": self.module_size,
            "name": self.name,
            "target": self.target,
            "version": self.version,
        }
        encoded = json.dumps(
            summary,
            ensure_ascii=False,
            separators=(",", ":"),
            sort_keys=True,
        )
        return encoded.encode("utf-8")

    @staticmethod
    def hash_module_bytes(data: bytes) -> str:
        """Return the lower-case hex SHA3-256 digest of raw module *data*."""
        hasher = hashlib.sha3_256()
        hasher.update(data)
        return hasher.hexdigest()

    def to_dict(self) -> dict[str, Any]:
        """Return the fields as a plain dict keyed by field name."""
        as_mapping: dict[str, Any] = asdict(self)
        return as_mapping
| 51 | |
| 52 | |
@dataclass
class DriverAttestation:
    """Signed statement that a DriverModule is authorized to load.

    Only ``module`` is covered by the signature (see ``canonical_bytes``). The
    remaining fields record who signed, with which algorithm and key, and when.
    They are plain strings that default to ``""`` for an unsigned attestation.
    """

    module: DriverModule
    signer_did: str = ""  # DID of the signing identity
    algorithm: str = ""  # SignatureAlgorithm value string
    signature: str = ""  # hex-encoded signature bytes
    public_key: str = ""  # hex-encoded signer public key
    signed_at: str = ""  # timestamp string; DriverAttester writes ISO-8601 UTC

    def canonical_bytes(self) -> bytes:
        """Return the bytes the signature is computed over.

        This is deliberately just the module summary's canonical encoding; none
        of the signature metadata fields are included.
        """
        summary = self.module
        return summary.canonical_bytes()

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a plain dict (module nested via its own ``to_dict``)."""
        serialized: dict[str, Any] = {"module": self.module.to_dict()}
        for key in ("signer_did", "algorithm", "signature", "public_key", "signed_at"):
            serialized[key] = getattr(self, key)
        return serialized

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> DriverAttestation:
        """Rebuild an attestation from the shape produced by ``to_dict``.

        ``module`` and its name/version/module_hash/module_size keys are
        required (a missing one raises KeyError); ``target`` defaults to
        "linux" and every signature field defaults to "".
        """
        raw_module = data["module"]
        module = DriverModule(
            name=raw_module["name"],
            version=raw_module["version"],
            module_hash=raw_module["module_hash"],
            module_size=int(raw_module["module_size"]),
            target=raw_module.get("target", "linux"),
        )
        metadata = {
            key: data.get(key, "")
            for key in ("signer_did", "algorithm", "signature", "public_key", "signed_at")
        }
        return cls(module=module, **metadata)
| 95 | |
| 96 | |
class DriverAttester:
    """Produce signed DriverAttestation records using an AgentIdentity's key."""

    def __init__(self, identity: AgentIdentity):
        # The identity supplies both the signer DID and the signing keypair.
        self.identity = identity

    def attest(self, module: DriverModule) -> DriverAttestation:
        """Sign *module* and return a fully populated attestation.

        The signature is computed over the SHA3-256 digest of the attestation's
        canonical bytes (the module summary only). ``signed_at`` is recorded
        after signing and is therefore not covered by the signature.
        """
        keypair = self.identity.signing_keypair
        unsigned = DriverAttestation(module=module)
        to_sign = hashlib.sha3_256(unsigned.canonical_bytes()).digest()
        raw_signature = sign(to_sign, keypair)
        return DriverAttestation(
            module=module,
            signer_did=self.identity.did,
            algorithm=keypair.algorithm.value,
            signature=raw_signature.hex(),
            public_key=keypair.public_key.hex(),
            signed_at=datetime.now(timezone.utc).isoformat(),
        )
| 114 | |
| 115 | |
@dataclass(frozen=True)
class VerificationResult:
    """Immutable outcome of verifying one DriverAttestation.

    ``valid`` is the overall accept/reject decision; when it is False,
    ``error`` carries a human-readable reason.
    """

    valid: bool  # True only when every verification step passed
    module_name: str  # name copied from the attested module
    signer_did: str | None  # signer DID as claimed by the attestation
    trusted: bool  # allow-list outcome; the verifier sets False on any rejection
    error: str | None = None  # rejection reason, or None on success
| 123 | |
| 124 | |
class DriverAttestationVerifier:
    """Verify a DriverAttestation against an allow-list of trusted signers.

    ``trusted_signers`` controls the provenance check and may be:

    * ``None`` (default) -- no allow-list. Any attestation whose signature
      verifies against its *own embedded* ``public_key`` is accepted. This
      proves the claim is self-consistent but not who made it.
    * a collection of signer DIDs (e.g. a ``set``) -- ``signer_did`` must be a
      member. Caution: in this module ``signer_did`` is neither covered by the
      signature nor tied to ``public_key``. A membership check alone therefore
      cannot reject an attestation that names a trusted DID but was signed with
      some other key, unless DIDs are bound to keys elsewhere (TODO confirm).
    * a mapping of signer DID -> hex-encoded public key -- ``signer_did`` must
      be a key AND the attestation's ``public_key`` must equal the pinned key.
      This mode binds the allow-list to actual signing keys.
    """

    def __init__(self, trusted_signers: set[str] | dict[str, str] | None = None):
        self.trusted_signers = trusted_signers

    @staticmethod
    def _reject(attestation: DriverAttestation, error: str) -> VerificationResult:
        """Build the uniform failure result for *attestation* with *error*."""
        return VerificationResult(
            valid=False,
            module_name=attestation.module.name,
            signer_did=attestation.signer_did,
            trusted=False,
            error=error,
        )

    def verify(
        self,
        attestation: DriverAttestation,
        actual_module_bytes: bytes | None = None,
    ) -> VerificationResult:
        """Check *attestation* and report the first failing step.

        Steps, in order:
        1. When *actual_module_bytes* is given, its SHA3-256 must equal the
           declared ``module_hash`` (hex letter case is ignored).
        2. ``signature``, ``algorithm`` and ``public_key`` must be non-empty.
           The algorithm must parse as a ``SignatureAlgorithm``. The signature
           must verify over the SHA3-256 of the attestation's canonical bytes.
        3. The signer must satisfy ``trusted_signers`` (see class docstring).

        Returns:
            A VerificationResult. ``valid`` and ``trusted`` are both True only
            when every step passed; otherwise both are False and ``error``
            explains why. Use ``verify_or_raise`` for an exception instead.
        """
        from collections.abc import Mapping

        # 1. Module hash must match the declared hash when bytes are supplied.
        if actual_module_bytes is not None:
            actual_hash = DriverModule.hash_module_bytes(actual_module_bytes)
            declared = attestation.module.module_hash
            # hash_module_bytes always yields lower-case hex, so normalize the
            # declared value; an upper-case hex digest of the same bytes is the
            # same hash and must not be rejected.
            if not isinstance(declared, str) or actual_hash != declared.lower():
                return self._reject(
                    attestation,
                    f"module hash mismatch: "
                    f"declared={str(declared)[:16]}..., "
                    f"actual={actual_hash[:16]}...",
                )

        # 2. Signature must be present, use a known algorithm, and verify.
        if not (attestation.signature and attestation.algorithm and attestation.public_key):
            return self._reject(attestation, "missing signature fields")
        try:
            algorithm = SignatureAlgorithm(attestation.algorithm)
        except ValueError:
            return self._reject(attestation, f"unknown algorithm {attestation.algorithm}")
        digest = hashlib.sha3_256(attestation.canonical_bytes()).digest()
        try:
            signature = bytes.fromhex(attestation.signature)
            public_key = bytes.fromhex(attestation.public_key)
            sig_ok = verify(digest, signature, public_key, algorithm)
        except Exception as exc:
            # Malformed hex or a backend failure: fail closed but keep the cause.
            return self._reject(attestation, f"signature verify failed: {exc}")
        if not sig_ok:
            return self._reject(attestation, "invalid ML-DSA signature")

        # 3. Signer must satisfy the allow-list (if one is configured).
        trusted_signers = self.trusted_signers
        if trusted_signers is not None:
            signer = attestation.signer_did
            if signer not in trusted_signers:
                return self._reject(attestation, f"signer {signer} not in trusted set")
            if isinstance(trusted_signers, Mapping):
                # The signature above was checked against the attestation's OWN
                # key, so pin the claimed DID to its expected key; otherwise a
                # self-generated key could sign while claiming a trusted DID.
                try:
                    pinned_key = bytes.fromhex(trusted_signers[signer])
                except (TypeError, ValueError):
                    return self._reject(
                        attestation,
                        f"pinned public key for signer {signer} is not valid hex",
                    )
                if pinned_key != public_key:
                    return self._reject(
                        attestation,
                        f"public key does not match pinned key for signer {signer}",
                    )

        return VerificationResult(
            valid=True,
            module_name=attestation.module.name,
            signer_did=attestation.signer_did,
            trusted=True,
            error=None,
        )

    def verify_or_raise(
        self,
        attestation: DriverAttestation,
        actual_module_bytes: bytes | None = None,
    ) -> None:
        """Run ``verify`` and raise DriverAttestationError if it fails.

        Raises:
            DriverAttestationError: carrying the result's ``error`` text, or
                "verification failed" if no text was provided.
        """
        result = self.verify(attestation, actual_module_bytes)
        if not result.valid:
            raise DriverAttestationError(result.error or "verification failed")
| 225 | |