src/pqc_enclave_sdk/vault.py
7.8 KB · 233 lines · python Raw
1 """EnclaveVault - high-level API over an EnclaveBackend."""
2
3 from __future__ import annotations
4
5 import hashlib
6 import json
7 import os
8 import uuid
9 from dataclasses import dataclass, field
10 from datetime import datetime, timedelta, timezone
11
12 from cryptography.hazmat.primitives.ciphers.aead import AESGCM
13 from quantumshield.core.algorithms import KEMAlgorithm
14 from quantumshield.core.keys import generate_kem_keypair
15
16 from pqc_enclave_sdk.artifact import (
17 ArtifactKind,
18 ArtifactMetadata,
19 EnclaveArtifact,
20 EncryptedArtifact,
21 )
22 from pqc_enclave_sdk.audit import EnclaveAuditLog
23 from pqc_enclave_sdk.backends.base import EnclaveBackend
24 from pqc_enclave_sdk.errors import (
25 DecryptionError,
26 EnclaveLockedError,
27 UnknownArtifactError,
28 )
29
# AES-GCM nonce length in bytes (96 bits); a fresh random nonce per artifact.
NONCE_SIZE = 12
# Default session lifetime, in seconds (one hour).
DEFAULT_SESSION_TTL = 60 * 60
32
33
def establish_enclave_session(
    backend: EnclaveBackend,
    algorithm: KEMAlgorithm = KEMAlgorithm.ML_KEM_768,
    ttl_seconds: int = DEFAULT_SESSION_TTL,
) -> tuple[bytes, str, str]:
    """Create a session key for *backend* and register it with an expiry.

    A KEM keypair for *algorithm* (ML-KEM-768 unless overridden) is generated
    in-process, and its private and public key bytes are hashed with SHA3-256
    to produce a 32-byte AES key. The key is handed to the backend under a new
    ``urn:pqc-enclave-key:`` id that expires *ttl_seconds* from now (UTC).

    Design note carried over from the original: in production the enclave is
    meant to run Decapsulate on a ciphertext encrypted to its KEM public key;
    the in-process derivation here exists so tests work with the stub backend.

    Returns:
        ``(symmetric_key, key_id, expires_at_iso)``
    """
    keypair = generate_kem_keypair(algorithm)
    # Stub KDF: hash the concatenated key material. A new keypair is generated
    # on every call (presumably random), so each session is expected to get a
    # different key; nothing here can re-derive an earlier session's key.
    digest = hashlib.sha3_256()
    digest.update(keypair.private_key)
    digest.update(keypair.public_key)
    session_key = digest.digest()
    key_id = "urn:pqc-enclave-key:" + uuid.uuid4().hex
    expires_at = datetime.now(timezone.utc) + timedelta(seconds=ttl_seconds)
    expires_iso = expires_at.isoformat()
    backend.store_session_key(key_id, session_key, expires_iso)
    return session_key, key_id, expires_iso
54
55
@dataclass
class EnclaveVault:
    """High-level vault backed by an EnclaveBackend + AES-256-GCM per artifact.

    Every artifact is encrypted individually with the current session key
    (AES-GCM, random ``NONCE_SIZE``-byte nonce) and is held in ``_store`` in
    encrypted form only. Each stored artifact is indexed twice: under its
    ``artifact_id`` and under ``"name:<name>"``; the name index points at the
    most recently put artifact with that name (until it is deleted).

    Usage:
        backend = InMemoryEnclaveBackend(device_id="iphone-alice")
        vault = EnclaveVault(backend)
        vault.unlock()
        vault.put_artifact(
            name="llama-3.2-1b-int4",
            kind=ArtifactKind.MODEL_WEIGHTS,
            content=weights_bytes,
        )
        vault.save()
        # later...
        vault.unlock()
        weights = vault.get_artifact("llama-3.2-1b-int4").content

    NOTE(review): ``unlock()`` always derives a brand-new session key via
    ``establish_enclave_session`` and never reads an earlier key back from the
    backend, while ``get_artifact`` decrypts with the *current* session key.
    Artifacts saved under a previous session's ``key_id`` therefore look
    undecryptable after a re-unlock (the "later..." step above) unless the
    backend compensates somewhere not visible here -- TODO confirm.
    """

    backend: EnclaveBackend
    audit: EnclaveAuditLog = field(default_factory=EnclaveAuditLog)
    # Session key. Excluded from repr() so the secret never ends up in logs,
    # tracebacks or debugger output via the dataclass-generated __repr__.
    _symmetric_key: bytes | None = field(default=None, repr=False)
    # Id of the current session key; bound into each artifact's AAD.
    _key_id: str = ""
    # ISO-8601 session expiry; empty string while locked.
    _expires_at: str = ""
    # Encrypted artifacts, keyed by artifact_id and by "name:<name>".
    _store: dict[str, EncryptedArtifact] = field(default_factory=dict)

    @property
    def is_unlocked(self) -> bool:
        """True while a session key is held and its expiry has not passed."""
        if self._symmetric_key is None:
            return False
        try:
            exp = datetime.fromisoformat(self._expires_at)
        except ValueError:
            # Empty or malformed expiry is treated as locked.
            return False
        return datetime.now(timezone.utc) <= exp

    def _require_unlocked(self) -> None:
        """Raise EnclaveLockedError unless the vault is currently unlocked."""
        if not self.is_unlocked:
            raise EnclaveLockedError("enclave vault is locked; call unlock() first")

    # -- lifecycle ---------------------------------------------------------

    def unlock(self, ttl_seconds: int = DEFAULT_SESSION_TTL) -> None:
        """Start a new session valid for *ttl_seconds* and reload the store.

        The in-memory store is replaced with the backend's persisted
        artifacts, so anything put since the last ``save()`` is discarded.
        """
        key, key_id, exp = establish_enclave_session(
            self.backend, ttl_seconds=ttl_seconds
        )
        self._symmetric_key = key
        self._key_id = key_id
        self._expires_at = exp
        self._store = dict(self.backend.load_artifacts())
        self.audit.log_unlock(self.backend.device_id, key_id)

    def lock(self) -> None:
        """Drop the session key; encrypted artifacts stay in memory."""
        self._symmetric_key = None
        self._key_id = ""
        self._expires_at = ""
        self.audit.log_lock(self.backend.device_id)

    def save(self) -> None:
        """Persist the encrypted store to the backend."""
        self.backend.save_artifacts(dict(self._store))

    # -- AAD ---------------------------------------------------------------

    @staticmethod
    def _aad(metadata: ArtifactMetadata, content_hash: str, key_id: str) -> bytes:
        """Canonical JSON (sorted keys, compact) binding metadata, hash and key id."""
        payload = {
            "metadata": metadata.to_dict(),
            "content_hash": content_hash,
            "key_id": key_id,
        }
        return json.dumps(
            payload, sort_keys=True, separators=(",", ":"), ensure_ascii=False
        ).encode("utf-8")

    # -- lookup --------------------------------------------------------------

    def _resolve_key(self, name_or_id: str) -> str | None:
        """Return the store key for *name_or_id* (exact key first, then name), or None."""
        if name_or_id in self._store:
            return name_or_id
        name_key = f"name:{name_or_id}"
        return name_key if name_key in self._store else None

    # -- CRUD --------------------------------------------------------------

    def put_artifact(
        self,
        name: str,
        kind: ArtifactKind,
        content: bytes,
        version: str = "",
        app_bundle_id: str = "",
        model_did: str = "",
        tags: tuple[str, ...] = (),
        description: str = "",
    ) -> EncryptedArtifact:
        """Encrypt *content* under the session key and add it to the store.

        A new ``urn:pqc-enclave-art:`` id is minted on every call. Putting a
        name that already exists repoints ``name:<name>`` at the new artifact;
        the older artifact stays reachable by its id.

        Raises:
            EnclaveLockedError: if the vault is not unlocked.
        """
        self._require_unlocked()
        assert self._symmetric_key is not None  # narrowed by _require_unlocked
        artifact_id = f"urn:pqc-enclave-art:{uuid.uuid4().hex}"
        metadata = ArtifactMetadata(
            artifact_id=artifact_id,
            name=name,
            kind=kind,
            version=version,
            app_bundle_id=app_bundle_id,
            size_bytes=len(content),
            created_at=datetime.now(timezone.utc).isoformat(),
            device_id=self.backend.device_id,
            model_did=model_did,
            tags=tuple(tags),
            description=description,
        )
        content_hash = EnclaveArtifact.content_hash(content)
        nonce = os.urandom(NONCE_SIZE)
        aes = AESGCM(self._symmetric_key)
        ct = aes.encrypt(
            nonce, content, self._aad(metadata, content_hash, self._key_id)
        )
        enc = EncryptedArtifact(
            metadata=metadata,
            nonce=nonce.hex(),
            ciphertext=ct.hex(),
            content_hash=content_hash,
            key_id=self._key_id,
        )
        self._store[artifact_id] = enc
        self._store[f"name:{name}"] = enc
        self.audit.log_put(self.backend.device_id, artifact_id, name, kind.value)
        return enc

    def get_artifact(self, name_or_id: str) -> EnclaveArtifact:
        """Decrypt and return the artifact with the given id or name.

        Both lookup misses and decryption failures are recorded in the audit
        log with ``success=False``.

        Raises:
            EnclaveLockedError: if the vault is not unlocked.
            UnknownArtifactError: if neither an id nor a name matches.
            DecryptionError: if AES-GCM authentication/decryption fails.
        """
        self._require_unlocked()
        assert self._symmetric_key is not None  # narrowed by _require_unlocked
        key = self._resolve_key(name_or_id)
        if key is None:
            self.audit.log_get(
                self.backend.device_id, name_or_id, success=False, details="not found"
            )
            raise UnknownArtifactError(f"no artifact '{name_or_id}'")
        enc = self._store[key]
        aes = AESGCM(self._symmetric_key)
        # The AAD uses the key id recorded on the artifact, not the session's.
        aad = self._aad(enc.metadata, enc.content_hash, enc.key_id)
        try:
            content = aes.decrypt(
                bytes.fromhex(enc.nonce), bytes.fromhex(enc.ciphertext), aad
            )
        except Exception as exc:
            # Boundary conversion: hex parsing and AES-GCM errors all surface
            # as DecryptionError, and the failed access is audited like a miss.
            self.audit.log_get(
                self.backend.device_id,
                enc.metadata.artifact_id,
                success=False,
                details="decrypt failed",
            )
            raise DecryptionError(f"AES-GCM decrypt failed: {exc}") from exc
        self.audit.log_get(
            self.backend.device_id, enc.metadata.artifact_id, success=True
        )
        return EnclaveArtifact(metadata=enc.metadata, content=content)

    def delete_artifact(self, name_or_id: str) -> None:
        """Remove an artifact (by id or name) and its index entries.

        The ``name:<name>`` index is removed only if it still points at the
        deleted artifact, so a newer artifact sharing the name stays reachable.

        Raises:
            EnclaveLockedError: if the vault is not unlocked.
            UnknownArtifactError: if neither an id nor a name matches.
        """
        self._require_unlocked()
        key = self._resolve_key(name_or_id)
        if key is None:
            raise UnknownArtifactError(f"no artifact '{name_or_id}'")
        enc = self._store.pop(key)
        artifact_id = enc.metadata.artifact_id
        self._store.pop(artifact_id, None)
        name_key = f"name:{enc.metadata.name}"
        current = self._store.get(name_key)
        if current is not None and current.metadata.artifact_id == artifact_id:
            del self._store[name_key]
        self.audit.log_delete(self.backend.device_id, artifact_id, enc.metadata.name)

    def list_artifacts(self) -> list[ArtifactMetadata]:
        """Return metadata for each distinct stored artifact, in store order."""
        self._require_unlocked()
        # Artifacts appear under both their id and a "name:" key; dedupe by id.
        unique = {enc.metadata.artifact_id: enc.metadata for enc in self._store.values()}
        return list(unique.values())

    # -- context manager ---------------------------------------------------

    def __enter__(self) -> EnclaveVault:
        """Unlock if needed and return the vault."""
        if not self.is_unlocked:
            self.unlock()
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        """Lock on exit -- unconditionally, even if already unlocked on entry."""
        self.lock()
233