src/pqc_enclave_sdk/vault.py
| 1 | """EnclaveVault - high-level API over an EnclaveBackend.""" |
| 2 | |
| 3 | from __future__ import annotations |
| 4 | |
| 5 | import hashlib |
| 6 | import json |
| 7 | import os |
| 8 | import uuid |
| 9 | from dataclasses import dataclass, field |
| 10 | from datetime import datetime, timedelta, timezone |
| 11 | |
| 12 | from cryptography.hazmat.primitives.ciphers.aead import AESGCM |
| 13 | from quantumshield.core.algorithms import KEMAlgorithm |
| 14 | from quantumshield.core.keys import generate_kem_keypair |
| 15 | |
| 16 | from pqc_enclave_sdk.artifact import ( |
| 17 | ArtifactKind, |
| 18 | ArtifactMetadata, |
| 19 | EnclaveArtifact, |
| 20 | EncryptedArtifact, |
| 21 | ) |
| 22 | from pqc_enclave_sdk.audit import EnclaveAuditLog |
| 23 | from pqc_enclave_sdk.backends.base import EnclaveBackend |
| 24 | from pqc_enclave_sdk.errors import ( |
| 25 | DecryptionError, |
| 26 | EnclaveLockedError, |
| 27 | UnknownArtifactError, |
| 28 | ) |
| 29 | |
# Length in bytes of the random nonce drawn with os.urandom for each
# AES-GCM encryption performed by EnclaveVault.put_artifact.
NONCE_SIZE = 12
# Default session lifetime in seconds; used as the default ttl_seconds of
# establish_enclave_session and EnclaveVault.unlock.
DEFAULT_SESSION_TTL = 3600
| 32 | |
| 33 | |
def establish_enclave_session(
    backend: EnclaveBackend,
    algorithm: KEMAlgorithm = KEMAlgorithm.ML_KEM_768,
    ttl_seconds: int = DEFAULT_SESSION_TTL,
) -> tuple[bytes, str, str]:
    """Create a session key from a fresh KEM keypair and register it with the backend.

    A production enclave would run Decapsulate on a ciphertext encrypted to
    its own KEM public key. This implementation instead generates the KEM
    keypair in-process and hashes it into the symmetric key, so that tests
    work with the Ed25519/stub backend.

    Args:
        backend: Backend that receives the session key via
            ``store_session_key(key_id, key, expires_at)``.
        algorithm: KEM algorithm handed to ``generate_kem_keypair``.
        ttl_seconds: Session lifetime in seconds, counted from the current
            UTC time.

    Returns:
        ``(symmetric_key, key_id, expires_at_iso)`` where ``symmetric_key`` is
        the 32-byte SHA3-256 digest of the keypair material, ``key_id`` is a
        ``urn:pqc-enclave-key:`` URN carrying a random UUID4 hex string, and
        ``expires_at_iso`` is the ISO-8601 expiry timestamp in UTC.
    """
    keypair = generate_kem_keypair(algorithm)

    # The session key is SHA3-256 over private key followed by public key.
    key_material = keypair.private_key + keypair.public_key
    session_key = hashlib.sha3_256(key_material).digest()

    session_id = f"urn:pqc-enclave-key:{uuid.uuid4().hex}"

    lifetime = timedelta(seconds=ttl_seconds)
    expires_at = (datetime.now(timezone.utc) + lifetime).isoformat()

    backend.store_session_key(session_id, session_key, expires_at)
    return session_key, session_id, expires_at
| 54 | |
| 55 | |
@dataclass
class EnclaveVault:
    """High-level vault backed by an EnclaveBackend + AES-256-GCM per artifact.

    Encrypted artifacts live in an in-memory store that is indexed twice:
    once under the artifact id and once under a ``name:<name>`` alias. The
    store is read from the backend by ``unlock()`` and written back only by
    ``save()``; ``lock()`` and the context-manager exit do not persist it.

    Usage:
        backend = InMemoryEnclaveBackend(device_id="iphone-alice")
        vault = EnclaveVault(backend)
        vault.unlock()
        vault.put_artifact(
            name="llama-3.2-1b-int4",
            kind=ArtifactKind.MODEL_WEIGHTS,
            content=weights_bytes,
        )
        vault.save()
        # later...
        vault.unlock()
        weights = vault.get_artifact("llama-3.2-1b-int4").content
    """

    backend: EnclaveBackend
    audit: EnclaveAuditLog = field(default_factory=EnclaveAuditLog)
    # repr=False keeps the raw session key out of the dataclass-generated
    # __repr__, so printing or logging a vault cannot leak key material.
    _symmetric_key: bytes | None = field(default=None, repr=False)
    _key_id: str = ""
    _expires_at: str = ""
    _store: dict[str, EncryptedArtifact] = field(default_factory=dict)

    @property
    def is_unlocked(self) -> bool:
        """Return True while a session key is held and its expiry has not passed."""
        if self._symmetric_key is None:
            return False
        try:
            expires_at = datetime.fromisoformat(self._expires_at)
            return datetime.now(timezone.utc) <= expires_at
        except (ValueError, TypeError):
            # ValueError: the stored expiry is not a valid ISO-8601 string.
            # TypeError: the expiry is not a string, or it parsed to a naive
            # datetime that cannot be compared with the timezone-aware "now".
            # Either way the session cannot be trusted, so report locked.
            return False

    def _require_unlocked(self) -> None:
        """Raise EnclaveLockedError unless the vault is currently unlocked."""
        if not self.is_unlocked:
            raise EnclaveLockedError("enclave vault is locked; call unlock() first")

    def _session_key(self) -> bytes:
        """Return the active session key, raising EnclaveLockedError if locked."""
        self._require_unlocked()
        key = self._symmetric_key
        # An explicit check rather than `assert`, which `python -O` strips.
        if key is None:
            raise EnclaveLockedError("enclave vault is locked; call unlock() first")
        return key

    def _lookup_key(self, name_or_id: str) -> str | None:
        """Return the store key for an artifact id or name, or None if absent.

        A direct match on the store key is tried first, then the
        ``name:<name>`` alias.
        """
        if name_or_id in self._store:
            return name_or_id
        alias = f"name:{name_or_id}"
        return alias if alias in self._store else None

    # -- lifecycle ---------------------------------------------------------

    def unlock(self, ttl_seconds: int = DEFAULT_SESSION_TTL) -> None:
        """Establish a new session and reload the store from the backend.

        The in-memory store is replaced by ``backend.load_artifacts()``, so
        changes that were not persisted with ``save()`` are discarded.

        Args:
            ttl_seconds: Session lifetime in seconds.
        """
        key, key_id, exp = establish_enclave_session(
            self.backend, ttl_seconds=ttl_seconds
        )
        self._symmetric_key = key
        self._key_id = key_id
        self._expires_at = exp
        self._store = dict(self.backend.load_artifacts())
        self.audit.log_unlock(self.backend.device_id, key_id)

    def lock(self) -> None:
        """Forget the session key, key id and expiry, and audit the lock.

        The in-memory store (ciphertext only) is left in place.
        """
        self._symmetric_key = None
        self._key_id = ""
        self._expires_at = ""
        self.audit.log_lock(self.backend.device_id)

    def save(self) -> None:
        """Persist the encrypted store to the backend."""
        self.backend.save_artifacts(dict(self._store))

    # -- AAD ---------------------------------------------------------------

    @staticmethod
    def _aad(metadata: ArtifactMetadata, content_hash: str, key_id: str) -> bytes:
        """Build the AES-GCM associated data for one artifact.

        The metadata dict, content hash and key id are serialized as compact
        JSON with sorted keys, so the same inputs always yield the same
        bytes, and then UTF-8 encoded.
        """
        payload = {
            "metadata": metadata.to_dict(),
            "content_hash": content_hash,
            "key_id": key_id,
        }
        return json.dumps(
            payload, sort_keys=True, separators=(",", ":"), ensure_ascii=False
        ).encode("utf-8")

    # -- CRUD --------------------------------------------------------------

    def put_artifact(
        self,
        name: str,
        kind: ArtifactKind,
        content: bytes,
        version: str = "",
        app_bundle_id: str = "",
        model_did: str = "",
        tags: tuple[str, ...] = (),
        description: str = "",
    ) -> EncryptedArtifact:
        """Encrypt ``content`` under the session key and add it to the store.

        The artifact receives a fresh ``urn:pqc-enclave-art:`` id and is
        indexed under both that id and ``name:<name>``. A later put with the
        same name takes over the name alias; the earlier artifact stays in
        the store under its id.

        Returns:
            The stored EncryptedArtifact (hex-encoded nonce and ciphertext).

        Raises:
            EnclaveLockedError: If the vault is locked or the session expired.
        """
        session_key = self._session_key()
        artifact_id = f"urn:pqc-enclave-art:{uuid.uuid4().hex}"
        metadata = ArtifactMetadata(
            artifact_id=artifact_id,
            name=name,
            kind=kind,
            version=version,
            app_bundle_id=app_bundle_id,
            size_bytes=len(content),
            created_at=datetime.now(timezone.utc).isoformat(),
            device_id=self.backend.device_id,
            model_did=model_did,
            tags=tuple(tags),
            description=description,
        )
        content_hash = EnclaveArtifact.content_hash(content)
        nonce = os.urandom(NONCE_SIZE)
        # Metadata, content hash and key id are bound as associated data, so
        # altering any of them makes decryption fail.
        ciphertext = AESGCM(session_key).encrypt(
            nonce, content, self._aad(metadata, content_hash, self._key_id)
        )
        enc = EncryptedArtifact(
            metadata=metadata,
            nonce=nonce.hex(),
            ciphertext=ciphertext.hex(),
            content_hash=content_hash,
            key_id=self._key_id,
        )
        self._store[artifact_id] = enc
        self._store[f"name:{name}"] = enc
        self.audit.log_put(self.backend.device_id, artifact_id, name, kind.value)
        return enc

    def get_artifact(self, name_or_id: str) -> EnclaveArtifact:
        """Look up an artifact by id or name and return it decrypted.

        Both failed lookups and failed decryptions are written to the audit
        log before the error is raised.

        NOTE(review): decryption always uses the current session key, and
        unlock() obtains a new key from establish_enclave_session each time.
        Artifacts encrypted under an earlier session key may therefore fail
        to decrypt - confirm how the backend is expected to handle this.

        Raises:
            EnclaveLockedError: If the vault is locked or the session expired.
            UnknownArtifactError: If no artifact matches ``name_or_id``.
            DecryptionError: If hex decoding or AES-GCM decryption fails.
        """
        session_key = self._session_key()
        key = self._lookup_key(name_or_id)
        if key is None:
            self.audit.log_get(
                self.backend.device_id, name_or_id, success=False, details="not found"
            )
            raise UnknownArtifactError(f"no artifact '{name_or_id}'")
        enc = self._store[key]
        aes = AESGCM(session_key)
        # The AAD is rebuilt from the stored record, including the key id it
        # was written with.
        aad = self._aad(enc.metadata, enc.content_hash, enc.key_id)
        try:
            content = aes.decrypt(
                bytes.fromhex(enc.nonce), bytes.fromhex(enc.ciphertext), aad
            )
        except Exception as exc:
            # Record the failed access, as is done for failed lookups.
            self.audit.log_get(
                self.backend.device_id,
                enc.metadata.artifact_id,
                success=False,
                details="decrypt failed",
            )
            raise DecryptionError(f"AES-GCM decrypt failed: {exc}") from exc
        self.audit.log_get(
            self.backend.device_id, enc.metadata.artifact_id, success=True
        )
        return EnclaveArtifact(metadata=enc.metadata, content=content)

    def delete_artifact(self, name_or_id: str) -> None:
        """Remove an artifact, looked up by id or name, from the store.

        The ``name:<name>`` alias is removed only when it points at the
        artifact being deleted, so deleting an older artifact by id leaves
        the alias of a newer artifact with the same name intact.

        Raises:
            EnclaveLockedError: If the vault is locked or the session expired.
            UnknownArtifactError: If no artifact matches ``name_or_id``.
        """
        self._require_unlocked()
        key = self._lookup_key(name_or_id)
        if key is None:
            raise UnknownArtifactError(f"no artifact '{name_or_id}'")
        enc = self._store.pop(key)
        artifact_id = enc.metadata.artifact_id
        self._store.pop(artifact_id, None)
        alias = f"name:{enc.metadata.name}"
        aliased = self._store.get(alias)
        if aliased is not None and aliased.metadata.artifact_id == artifact_id:
            del self._store[alias]
        self.audit.log_delete(
            self.backend.device_id, artifact_id, enc.metadata.name
        )

    def list_artifacts(self) -> list[ArtifactMetadata]:
        """Return the metadata of every stored artifact, one entry per artifact id.

        Raises:
            EnclaveLockedError: If the vault is locked or the session expired.
        """
        self._require_unlocked()
        # Each artifact can appear under both its id and its name alias;
        # keying by artifact id collapses the duplicates.
        unique = {
            enc.metadata.artifact_id: enc.metadata for enc in self._store.values()
        }
        return list(unique.values())

    # -- context manager ---------------------------------------------------

    def __enter__(self) -> EnclaveVault:
        """Unlock the vault if it is not already unlocked and return it."""
        if not self.is_unlocked:
            self.unlock()
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        """Lock the vault on exit; the store is not saved automatically."""
        self.lock()
| 233 | |