src/pqc_gpu_driver/driver_attest.py
7.6 KB · 225 lines · python Raw
1 """Driver module attestation with ML-DSA.
2
3 Every GPU driver / kernel module loaded into the AI inference system gets an
4 ML-DSA signature over its bytecode hash. At load time, the verifier checks the
5 signature against an allow-list of signers before permitting the module to load.
6 """
7
8 from __future__ import annotations
9
10 import hashlib
11 import json
12 from dataclasses import asdict, dataclass
13 from datetime import datetime, timezone
14 from typing import Any
15
16 from quantumshield.core.algorithms import SignatureAlgorithm
17 from quantumshield.core.signatures import sign, verify
18 from quantumshield.identity.agent import AgentIdentity
19
20 from pqc_gpu_driver.errors import DriverAttestationError
21
22
@dataclass(frozen=True)
class DriverModule:
    """Immutable summary of a GPU driver module binary (e.g. nvidia.ko, amdgpu.ko).

    ``canonical_bytes`` serializes exactly these five fields, so any change
    to one of them changes the bytes an attestation digest is computed over.
    """

    name: str
    version: str
    module_hash: str  # hex SHA3-256 digest of the module file
    module_size: int
    target: str = "linux"  # platform label, e.g. "linux" or "windows"

    def canonical_bytes(self) -> bytes:
        """Return a deterministic UTF-8 JSON encoding of this module.

        Keys are sorted and separators carry no whitespace, so identical
        field values always produce identical bytes.
        """
        fields = dict(
            name=self.name,
            version=self.version,
            module_hash=self.module_hash,
            module_size=self.module_size,
            target=self.target,
        )
        text = json.dumps(
            fields,
            ensure_ascii=False,
            separators=(",", ":"),
            sort_keys=True,
        )
        return text.encode("utf-8")

    @staticmethod
    def hash_module_bytes(data: bytes) -> str:
        """Return the hex SHA3-256 digest of the raw module bytes *data*."""
        hasher = hashlib.sha3_256()
        hasher.update(data)
        return hasher.hexdigest()

    def to_dict(self) -> dict[str, Any]:
        """Return the module's fields as a plain dict (``dataclasses.asdict``)."""
        return asdict(self)
51
52
@dataclass
class DriverAttestation:
    """A DriverModule together with the signature fields that authorize it.

    All signature-related fields default to empty strings, so an unsigned
    attestation can be created first and filled in by the signer.
    """

    module: DriverModule
    signer_did: str = ""
    algorithm: str = ""
    signature: str = ""  # signature bytes, hex-encoded
    public_key: str = ""  # signer public key bytes, hex-encoded
    signed_at: str = ""

    def canonical_bytes(self) -> bytes:
        """Return the bytes covered by the signature.

        Only the module's canonical encoding is included; signer_did,
        algorithm, signature, public_key and signed_at are not covered.
        """
        return self.module.canonical_bytes()

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a JSON-friendly dict, with the module nested as a dict."""
        serialized: dict[str, Any] = {"module": self.module.to_dict()}
        for key in ("signer_did", "algorithm", "signature", "public_key", "signed_at"):
            serialized[key] = getattr(self, key)
        return serialized

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> DriverAttestation:
        """Rebuild an attestation from a dict shaped like ``to_dict`` output.

        ``data["module"]`` must provide name, version, module_hash and
        module_size (KeyError otherwise). module_size is coerced with int(),
        and target falls back to "linux". Missing top-level signature fields
        fall back to "".
        """
        raw_module = data["module"]
        module = DriverModule(
            name=raw_module["name"],
            version=raw_module["version"],
            module_hash=raw_module["module_hash"],
            module_size=int(raw_module["module_size"]),
            target=raw_module.get("target", "linux"),
        )
        extras = {
            key: data.get(key, "")
            for key in ("signer_did", "algorithm", "signature", "public_key", "signed_at")
        }
        return cls(module=module, **extras)
95
96
class DriverAttester:
    """Produce signed DriverAttestations with an AgentIdentity's signing keypair."""

    def __init__(self, identity: AgentIdentity):
        self.identity = identity

    def attest(self, module: DriverModule) -> DriverAttestation:
        """Sign *module* and return a fully populated attestation.

        The signature is computed over the SHA3-256 digest of the module's
        canonical bytes. The signer's DID, the algorithm name, the hex public
        key and an ISO-8601 UTC timestamp are recorded alongside it. The
        timestamp is taken after signing and is not covered by the signature.
        """
        unsigned = DriverAttestation(module=module)
        digest = hashlib.sha3_256(unsigned.canonical_bytes()).digest()
        raw_signature = sign(digest, self.identity.signing_keypair)
        return DriverAttestation(
            module=module,
            signer_did=self.identity.did,
            algorithm=self.identity.signing_keypair.algorithm.value,
            signature=raw_signature.hex(),
            public_key=self.identity.signing_keypair.public_key.hex(),
            signed_at=datetime.now(timezone.utc).isoformat(),
        )
114
115
@dataclass(frozen=True)
class VerificationResult:
    """Immutable outcome of checking one driver attestation."""

    valid: bool  # overall verdict: True only when every check passed
    module_name: str  # name of the module the attestation describes
    signer_did: str | None  # DID claimed by the attestation
    trusted: bool  # whether the signer passed the allow-list check
    error: str | None = None  # human-readable failure reason; None on success
123
124
class DriverAttestationVerifier:
    """Verify DriverAttestations, optionally against an allow-list of signers.

    ``trusted_signers`` controls the allow-list step:

    * ``None`` (default): no allow-list; any attestation whose signature
      verifies is accepted.
    * a set of DIDs: the attestation's ``signer_did`` must be a member.
      Caution: ``signer_did`` and ``public_key`` are both read from the
      attestation itself, and the signature is checked against that embedded
      key. This class does not tie the DID to the key, so a DID-only
      allow-list can be satisfied by anyone who signs with their own key and
      writes a trusted DID into ``signer_did``.
    * a mapping of DID -> pinned public key (hex ``str`` or bytes-like): the
      DID must be a key of the mapping AND the attestation's ``public_key``
      must equal the pinned key. Prefer this form.
    """

    def __init__(
        self,
        trusted_signers: set[str] | dict[str, str | bytes] | None = None,
    ):
        # Stored by reference, not copied: later changes the caller makes to
        # the collection are visible to this verifier.
        self.trusted_signers = trusted_signers

    def verify(
        self,
        attestation: DriverAttestation,
        actual_module_bytes: bytes | None = None,
    ) -> VerificationResult:
        """Check an attestation and report the outcome without raising.

        Checks run in order, and the first failure is returned:

        1. If ``actual_module_bytes`` is given, its SHA3-256 hex digest must
           equal the declared ``module.module_hash`` (case-insensitive).
           Without bytes, the declared hash is taken on trust.
        2. ``signature``, ``algorithm`` and ``public_key`` must be non-empty,
           ``algorithm`` must name a known SignatureAlgorithm, and the
           signature must verify over SHA3-256(module canonical bytes) using
           the attestation's embedded public key.
        3. If an allow-list is configured, the signer must pass it (see the
           class docstring for the set vs. pinned-key forms).

        Args:
            attestation: The attestation to check.
            actual_module_bytes: Optional raw bytes of the module to be loaded.

        Returns:
            A VerificationResult whose ``valid`` is True only if every check
            passed; otherwise ``error`` describes the first failure.
        """
        module = attestation.module

        # 1. Bind the attestation to the actual module bytes, when supplied.
        if actual_module_bytes is not None:
            actual_hash = DriverModule.hash_module_bytes(actual_module_bytes)
            declared_hash = module.module_hash
            # hexdigest() is lowercase; an upper/mixed-case declared digest of
            # the same value must not be reported as a mismatch.
            if (
                not isinstance(declared_hash, str)
                or declared_hash.lower() != actual_hash
            ):
                return self._reject(
                    attestation,
                    f"module hash mismatch: "
                    f"declared={str(declared_hash)[:16]}..., "
                    f"actual={actual_hash[:16]}...",
                )

        # 2. Signature must verify. public_key is required too: without it
        # there is nothing to verify the signature against.
        if not (
            attestation.signature and attestation.algorithm and attestation.public_key
        ):
            return self._reject(attestation, "missing signature fields")
        try:
            algorithm = SignatureAlgorithm(attestation.algorithm)
        except ValueError:
            return self._reject(
                attestation, f"unknown algorithm {attestation.algorithm}"
            )

        digest = hashlib.sha3_256(attestation.canonical_bytes()).digest()
        try:
            signature = bytes.fromhex(attestation.signature)
            public_key = bytes.fromhex(attestation.public_key)
            sig_ok = verify(digest, signature, public_key, algorithm)
        except Exception as exc:  # malformed hex or backend error: fail closed
            return self._reject(attestation, f"signature verify failed: {exc}")
        if not sig_ok:
            return self._reject(attestation, "invalid ML-DSA signature")

        # 3. Signer must pass the allow-list (if configured).
        trust_error = self._check_trust(attestation.signer_did, public_key)
        if trust_error is not None:
            return self._reject(attestation, trust_error)

        return VerificationResult(
            valid=True,
            module_name=module.name,
            signer_did=attestation.signer_did,
            trusted=True,
            error=None,
        )

    def verify_or_raise(
        self,
        attestation: DriverAttestation,
        actual_module_bytes: bytes | None = None,
    ) -> None:
        """Run ``verify`` and raise DriverAttestationError if the result is not valid."""
        result = self.verify(attestation, actual_module_bytes)
        if not result.valid:
            raise DriverAttestationError(result.error or "verification failed")

    def _check_trust(self, signer_did: str, public_key: bytes) -> str | None:
        """Return None if the signer passes the allow-list, else an error message."""
        from collections.abc import Mapping

        trusted = self.trusted_signers
        if trusted is None:
            return None
        if signer_did not in trusted:
            return f"signer {signer_did} not in trusted set"
        if isinstance(trusted, Mapping):
            # Key pinning: the embedded key must be the one registered for
            # this DID, otherwise the DID claim is unauthenticated.
            pinned = self._key_bytes(trusted[signer_did])
            if pinned is None or pinned != public_key:
                return f"public key for signer {signer_did} does not match pinned key"
        return None

    @staticmethod
    def _key_bytes(key: object) -> bytes | None:
        """Decode a pinned key given as hex str or bytes-like; None if unusable."""
        if isinstance(key, (bytes, bytearray, memoryview)):
            return bytes(key)
        if isinstance(key, str):
            try:
                return bytes.fromhex(key)
            except ValueError:
                return None
        return None

    @staticmethod
    def _reject(attestation: DriverAttestation, error: str) -> VerificationResult:
        """Build the failed (untrusted) VerificationResult shared by every rejection."""
        return VerificationResult(
            valid=False,
            module_name=attestation.module.name,
            signer_did=attestation.signer_did,
            trusted=False,
            error=error,
        )
225