src/pqc_agent_wallet/vault.py
| 1 | """Wallet - encrypted credential vault with ML-DSA integrity and optional ML-KEM key encapsulation.""" |
| 2 | |
| 3 | from __future__ import annotations |
| 4 | |
| 5 | import hashlib |
| 6 | import json |
| 7 | import os |
| 8 | from dataclasses import dataclass |
| 9 | from datetime import datetime, timezone |
| 10 | from typing import Any |
| 11 | |
| 12 | from cryptography.hazmat.primitives.ciphers.aead import AESGCM |
| 13 | from quantumshield.core.algorithms import KEMAlgorithm, SignatureAlgorithm |
| 14 | from quantumshield.core.signatures import sign, verify |
| 15 | from quantumshield.identity.agent import AgentIdentity |
| 16 | |
| 17 | from pqc_agent_wallet.audit import WalletAuditLog |
| 18 | from pqc_agent_wallet.credential import Credential, CredentialMetadata |
| 19 | from pqc_agent_wallet.errors import ( |
| 20 | CredentialNotFoundError, |
| 21 | InvalidPassphraseError, |
| 22 | TamperedWalletError, |
| 23 | WalletFormatError, |
| 24 | WalletLockedError, |
| 25 | ) |
| 26 | from pqc_agent_wallet.kdf import DEFAULT_ITERATIONS, derive_key_from_passphrase |
| 27 | |
# On-disk format version string; ``Wallet.load`` rejects files with any other value.
WALLET_FORMAT_VERSION: str = "1.0"
# Byte length of the random AES-GCM nonce generated for each stored credential.
NONCE_LENGTH: int = 12
# Byte length of the random KDF salt generated when a wallet is created without one.
SALT_LENGTH: int = 16
| 31 | |
| 32 | |
| 33 | @dataclass |
| 34 | class _EncryptedCredential: |
| 35 | nonce: str # hex |
| 36 | ciphertext: str # hex |
| 37 | metadata: CredentialMetadata |
| 38 | |
| 39 | |
class Wallet:
    """Encrypted credential store for an AI agent.

    Each credential value is encrypted individually with AES-GCM under the
    wallet's unlock key. Credential names and metadata are stored in
    plaintext. ``save()`` signs the whole payload with the owner's signing
    keypair, and ``load()`` only accepts files signed by the expected key.

    Two unlock modes:
      1. Passphrase: ``Wallet.create_with_passphrase(path, passphrase, owner)``
         - the unlock key is derived from the passphrase, the wallet's salt
           and its iteration count.
      2. KEM-encapsulated:
         ``Wallet.create_with_kem(path, recipient_kem_public_key,
         recipient_algorithm, owner)``
         - the wallet is created with an ephemeral symmetric key encapsulated
           to the recipient's KEM public key (e.g. ML-KEM-768). To unlock,
           the recipient uses their private key to decapsulate.

    Usage:
        owner = AgentIdentity.create("my-agent")
        wallet = Wallet.create_with_passphrase("agent.wallet", "hunter2", owner)
        wallet.put("openai_api_key", "sk-...", service="openai")
        wallet.save()
        wallet.lock()

        # Later...
        wallet = Wallet.load("agent.wallet", owner)
        wallet.unlock_with_passphrase("hunter2")
        key = wallet.get("openai_api_key")
    """

    # Envelope keys that save() adds on top of the signed payload; load()
    # strips them before recomputing the signed digest.
    _SIGNATURE_FIELDS = ("signature", "signature_algorithm", "owner_public_key")

    def __init__(
        self,
        path: str,
        owner: AgentIdentity,
        salt: bytes = b"",
        iterations: int = DEFAULT_ITERATIONS,
        kem_encapsulation: dict | None = None,
        encrypted_credentials: dict[str, _EncryptedCredential] | None = None,
        created_at: str = "",
        audit_log: WalletAuditLog | None = None,
    ) -> None:
        """Build an in-memory, locked wallet.

        Prefer the ``create_*`` factories or ``load``.

        Args:
            path: File that ``save()`` writes to.
            owner: Identity whose signing keypair signs saved payloads and
                which is recorded in audit entries.
            salt: KDF salt. When empty, a fresh random ``SALT_LENGTH``-byte
                salt is generated.
            iterations: KDF iteration count used for passphrase unlock.
            kem_encapsulation: KEM metadata for KEM-mode wallets, else None.
            encrypted_credentials: Existing encrypted entries, keyed by name.
            created_at: ISO-8601 creation timestamp. Defaults to now (UTC).
            audit_log: Audit sink. A new ``WalletAuditLog`` is used if omitted.
        """
        self.path = path
        self.owner = owner
        self.salt = salt or os.urandom(SALT_LENGTH)
        self.iterations = iterations
        self.kem_encapsulation = kem_encapsulation  # None for passphrase wallets
        self._encrypted: dict[str, _EncryptedCredential] = encrypted_credentials or {}
        self._unlock_key: bytes | None = None
        self.created_at = created_at or datetime.now(timezone.utc).isoformat()
        self.audit = audit_log or WalletAuditLog()

    # ------------------------------------------------------------------
    # Factories
    # ------------------------------------------------------------------

    @classmethod
    def create_with_passphrase(
        cls,
        path: str,
        passphrase: str,
        owner: AgentIdentity,
    ) -> Wallet:
        """Create a new, empty wallet that is already unlocked by *passphrase*.

        Nothing is written to disk until ``save()`` is called.
        """
        w = cls(path=path, owner=owner)
        w._unlock_key = derive_key_from_passphrase(passphrase, w.salt, w.iterations)
        return w

    @classmethod
    def create_with_kem(
        cls,
        path: str,
        recipient_kem_public_key: bytes,
        recipient_algorithm: KEMAlgorithm,
        owner: AgentIdentity,
    ) -> Wallet:
        """Create a wallet whose unlock key is encapsulated to a KEM public key.

        The caller is the issuer: it holds the ephemeral symmetric key only
        while this wallet object stays unlocked. The recipient who holds the
        matching KEM private key can decapsulate it to unlock.

        The returned wallet is unlocked. Its ``kem_encapsulation`` records the
        algorithm, the KEM ciphertext and the recipient public key (hex).

        NOTE: if ``quantumshield.core.keys`` has no ``encapsulate()`` (e.g. the
        Ed25519 fallback backend), a random 32-byte key is generated and the
        stored "ciphertext" is only SHA3-256(key || recipient public key).
        That value is one-way, so no private key can recover the unlock key
        from it. This path is for dev/testing only: the caller must hand the
        symmetric key to ``unlock_with_kem_private_key`` directly.
        """
        w = cls(path=path, owner=owner)

        # Imported lazily so the module can load even if the KEM helpers
        # are unavailable; presence of encapsulate() selects the real path.
        from quantumshield.core import keys as _qk

        symmetric_key: bytes
        ciphertext: bytes
        if hasattr(_qk, "encapsulate"):
            symmetric_key, ciphertext = _qk.encapsulate(
                recipient_kem_public_key, recipient_algorithm
            )
        else:
            # Dev fallback (see docstring): not a real encapsulation.
            symmetric_key = os.urandom(32)
            ciphertext = hashlib.sha3_256(
                symmetric_key + recipient_kem_public_key
            ).digest()

        w._unlock_key = symmetric_key[:32]
        w.kem_encapsulation = {
            "algorithm": recipient_algorithm.value,
            "ciphertext": ciphertext.hex(),
            "recipient_pubkey": recipient_kem_public_key.hex(),
        }
        return w

    # ------------------------------------------------------------------
    # Unlock
    # ------------------------------------------------------------------

    @property
    def is_unlocked(self) -> bool:
        """True while an unlock key is held in memory."""
        return self._unlock_key is not None

    def unlock_with_passphrase(self, passphrase: str) -> None:
        """Derive the unlock key from *passphrase* and unlock the wallet.

        The derived key is checked against a stored credential first. A wallet
        with no credentials accepts any passphrase.

        Raises:
            InvalidPassphraseError: The derived key cannot decrypt stored data.
        """
        candidate = derive_key_from_passphrase(passphrase, self.salt, self.iterations)
        self._require_valid_key(
            candidate, "passphrase", "passphrase failed to unlock wallet"
        )
        self._unlock_key = candidate
        self.audit.log("unlock", self.owner, "", True, details="passphrase")

    def unlock_with_kem_private_key(
        self,
        recipient_kem_private_key: bytes,
        algorithm: KEMAlgorithm,
    ) -> None:
        """Unlock a KEM-mode wallet by decapsulating its stored ciphertext.

        The recovered key is validated against a stored credential, just like
        passphrase unlock. A wallet with no credentials accepts any key.

        Raises:
            WalletFormatError: The wallet was not created with KEM encapsulation.
            InvalidPassphraseError: The recovered key cannot decrypt stored
                data. This is the module's unlock-failure error; it is also
                used here for KEM keys.
        """
        if not self.kem_encapsulation:
            raise WalletFormatError("wallet was not created with KEM encapsulation")
        from quantumshield.core import keys as _qk

        ciphertext = bytes.fromhex(self.kem_encapsulation["ciphertext"])
        if hasattr(_qk, "decapsulate"):
            symmetric_key = _qk.decapsulate(
                ciphertext, recipient_kem_private_key, algorithm
            )
        else:
            # Dev fallback cannot truly decapsulate (the stored value is a
            # one-way hash). The caller passes the symmetric key itself in
            # place of a private key; for tests only.
            symmetric_key = recipient_kem_private_key[:32]
        candidate = symmetric_key[:32]
        self._require_valid_key(candidate, "kem", "KEM private key failed to unlock wallet")
        self._unlock_key = candidate
        self.audit.log("unlock", self.owner, "", True, details="kem")

    def _require_valid_key(self, candidate: bytes, method: str, message: str) -> None:
        """Reject *candidate* if it cannot decrypt the first stored credential.

        An empty wallet has nothing to check against, so any key is accepted.
        On failure, a failed ``unlock`` audit entry (details=*method*) is
        recorded and ``InvalidPassphraseError(message)`` is raised.
        """
        if not self._encrypted:
            return
        from cryptography.exceptions import InvalidTag

        first = next(iter(self._encrypted.values()))
        try:
            self._decrypt_value(first, candidate)
        except (InvalidTag, ValueError) as exc:
            # InvalidTag: authentication failed (wrong key or altered data).
            # ValueError: unusable key length, bad hex, or non-UTF-8 plaintext.
            self.audit.log("unlock", self.owner, "", False, details=method)
            raise InvalidPassphraseError(message) from exc

    def lock(self) -> None:
        """Forget the in-memory unlock key and record a ``lock`` audit entry."""
        self._unlock_key = None
        self.audit.log("lock", self.owner, "", True)

    # ------------------------------------------------------------------
    # CRUD
    # ------------------------------------------------------------------

    def put(
        self,
        name: str,
        value: str,
        service: str = "",
        description: str = "",
        scheme: str = "api-key",
        tags: list[str] | None = None,
        expires_at: str = "",
    ) -> None:
        """Encrypt *value* and store it under *name*, replacing any existing entry.

        Replacing an existing name keeps its original ``created_at`` and sets
        ``rotated_at`` to now. Each call uses a fresh random nonce. Changes
        stay in memory until ``save()`` is called.

        Raises:
            WalletLockedError: The wallet is locked.
        """
        if not self.is_unlocked:
            raise WalletLockedError("wallet must be unlocked before put()")
        now = datetime.now(timezone.utc).isoformat()
        existing = self._encrypted.get(name)
        metadata = CredentialMetadata(
            name=name,
            scheme=scheme,
            service=service,
            description=description,
            created_at=existing.metadata.created_at if existing else now,
            rotated_at=now if existing else "",
            expires_at=expires_at,
            tags=list(tags or []),  # copy so callers cannot mutate stored tags
        )
        nonce = os.urandom(NONCE_LENGTH)
        ciphertext = self._encrypt_value(value, nonce)
        self._encrypted[name] = _EncryptedCredential(
            nonce=nonce.hex(),
            ciphertext=ciphertext.hex(),
            metadata=metadata,
        )
        self.audit.log("put", self.owner, name, True, details=f"service={service}")

    def get(self, name: str) -> str:
        """Decrypt and return the value stored under *name*.

        Raises:
            WalletLockedError: The wallet is locked.
            CredentialNotFoundError: No credential has that name (audited).
        """
        if not self.is_unlocked:
            raise WalletLockedError("wallet must be unlocked before get()")
        enc = self._encrypted.get(name)
        if enc is None:
            self.audit.log("get", self.owner, name, False, details="not found")
            raise CredentialNotFoundError(f"no credential named '{name}'")
        value = self._decrypt_value(enc, self._unlock_key or b"")
        self.audit.log("get", self.owner, name, True)
        return value

    def get_credential(self, name: str) -> Credential:
        """Return the decrypted value of *name* together with its metadata.

        Raises the same errors as ``get``.
        """
        value = self.get(name)
        meta = self._encrypted[name].metadata
        return Credential(metadata=meta, value=value)

    def delete(self, name: str) -> None:
        """Remove the credential *name* from memory (persist with ``save()``).

        Raises:
            WalletLockedError: The wallet is locked.
            CredentialNotFoundError: No credential has that name (audited).
        """
        if not self.is_unlocked:
            raise WalletLockedError("wallet must be unlocked before delete()")
        if name not in self._encrypted:
            self.audit.log("delete", self.owner, name, False, details="not found")
            raise CredentialNotFoundError(f"no credential named '{name}'")
        del self._encrypted[name]
        self.audit.log("delete", self.owner, name, True)

    def list_names(self) -> list[str]:
        """Return credential names in sorted order. Works while locked."""
        return sorted(self._encrypted.keys())

    def list_metadata(self) -> list[CredentialMetadata]:
        """Return credential metadata in insertion order. Works while locked."""
        return [e.metadata for e in self._encrypted.values()]

    def rotate(self, name: str, new_value: str) -> None:
        """Replace a credential's value while keeping its metadata.

        ``put`` preserves the original ``created_at`` and stamps ``rotated_at``.

        Raises:
            WalletLockedError: The wallet is locked.
            CredentialNotFoundError: No credential has that name.
        """
        # Check the lock first, matching get()/delete().
        if not self.is_unlocked:
            raise WalletLockedError("wallet must be unlocked before rotate()")
        existing = self._encrypted.get(name)
        if existing is None:
            raise CredentialNotFoundError(f"no credential named '{name}'")
        meta = existing.metadata
        self.put(
            name=name,
            value=new_value,
            service=meta.service,
            description=meta.description,
            scheme=meta.scheme,
            tags=meta.tags,
            expires_at=meta.expires_at,
        )

    # ------------------------------------------------------------------
    # Persistence
    # ------------------------------------------------------------------

    def _build_payload(self) -> dict[str, Any]:
        """Return the signed portion of the on-disk document.

        This is everything except the signature fields.
        """
        return {
            "version": WALLET_FORMAT_VERSION,
            "created_at": self.created_at,
            "owner_did": self.owner.did,
            "kdf": {
                "algorithm": "PBKDF2-HMAC-SHA256",
                "salt": self.salt.hex(),
                "iterations": self.iterations,
            },
            "kem_encapsulation": self.kem_encapsulation,
            "encrypted_credentials": {
                name: {
                    "nonce": enc.nonce,
                    "ciphertext": enc.ciphertext,
                    "metadata": enc.metadata.to_dict(),
                }
                for name, enc in self._encrypted.items()
            },
        }

    @staticmethod
    def _canonicalize(payload: dict[str, Any]) -> bytes:
        """Serialize *payload* deterministically.

        Uses sorted keys and compact separators, UTF-8 encoded. Shared by
        ``save`` and ``load`` so both hash exactly the same bytes.
        """
        return json.dumps(
            payload, sort_keys=True, separators=(",", ":"), ensure_ascii=False
        ).encode("utf-8")

    def _canonical_payload_bytes(self) -> bytes:
        """Deterministic serialization of the payload used for signing."""
        return self._canonicalize(self._build_payload())

    def save(self) -> None:
        """Write the wallet to ``self.path``, signed by the owner.

        The canonical payload is hashed with SHA3-256 and signed with the
        owner's signing keypair. The signer's public key, the signature and
        the signature algorithm are stored alongside the payload.

        The document is first written to a temporary file in the same
        directory, flushed and fsynced, then moved over ``self.path`` with
        ``os.replace``. A crash mid-write therefore cannot leave a truncated
        wallet. The temp file comes from ``tempfile.mkstemp``, which creates
        it readable/writable only by the current user.
        """
        import tempfile

        payload = self._build_payload()
        digest = hashlib.sha3_256(self._canonicalize(payload)).digest()
        signature = sign(digest, self.owner.signing_keypair)

        envelope = {
            **payload,
            "owner_public_key": self.owner.signing_keypair.public_key.hex(),
            "signature": signature.hex(),
            "signature_algorithm": self.owner.signing_keypair.algorithm.value,
        }

        directory = os.path.dirname(os.path.abspath(self.path))
        fd, tmp_path = tempfile.mkstemp(
            dir=directory, prefix=f".{os.path.basename(self.path)}.", suffix=".tmp"
        )
        try:
            with os.fdopen(fd, "w", encoding="utf-8") as f:
                json.dump(envelope, f, indent=2)
                f.flush()
                os.fsync(f.fileno())
            os.replace(tmp_path, self.path)
        except BaseException:
            # Best-effort removal of the partial temp file; the original
            # error is always re-raised.
            try:
                os.unlink(tmp_path)
            except OSError:
                pass
            raise

    @classmethod
    def load(
        cls,
        path: str,
        owner: AgentIdentity,
        *,
        trusted_public_key: bytes | None = None,
    ) -> Wallet:
        """Load a wallet from disk after verifying its signature.

        The signer's public key embedded in the file must match a pinned key.
        Checking the signature only against the file's own key would let
        anyone who can write the file edit it, re-sign it with their own key
        and swap in their public key.

        Args:
            path: Wallet file to read.
            owner: Identity the loaded wallet acts as. It is used for audit
                entries and for signing on the next ``save()``.
            trusted_public_key: Signing public key the file must have been
                signed with. Defaults to ``owner``'s signing public key. Pass
                the issuer's key when loading a wallet signed by someone else.

        Returns:
            A locked ``Wallet``; call one of the ``unlock_*`` methods next.

        Raises:
            WalletFormatError: The file is not valid JSON, has an unsupported
                version, or has missing or malformed fields.
            TamperedWalletError: The file was signed by an unexpected key, or
                the signature does not verify.
        """
        with open(path, encoding="utf-8") as f:
            try:
                envelope = json.load(f)
            except json.JSONDecodeError as exc:
                raise WalletFormatError(f"wallet file is not valid JSON: {exc}") from exc
        if not isinstance(envelope, dict):
            raise WalletFormatError("wallet file must contain a JSON object")

        if envelope.get("version") != WALLET_FORMAT_VERSION:
            raise WalletFormatError(
                f"unsupported wallet version: {envelope.get('version')}"
            )

        sig_hex = envelope.get("signature")
        sig_alg = envelope.get("signature_algorithm")
        owner_pk_hex = envelope.get("owner_public_key")
        if not sig_hex or not sig_alg or not owner_pk_hex:
            raise WalletFormatError("wallet missing signature fields")

        try:
            algorithm = SignatureAlgorithm(sig_alg)
        except ValueError as exc:
            raise WalletFormatError(f"unknown signature algorithm: {sig_alg}") from exc

        try:
            signature = bytes.fromhex(sig_hex)
            signer_public_key = bytes.fromhex(owner_pk_hex)
        except (TypeError, ValueError) as exc:
            raise WalletFormatError("wallet signature fields are not valid hex") from exc

        # Pin the signer: the embedded key alone proves nothing.
        expected_public_key = (
            trusted_public_key
            if trusted_public_key is not None
            else owner.signing_keypair.public_key
        )
        if signer_public_key.hex() != expected_public_key.hex():
            raise TamperedWalletError("wallet was not signed by the expected owner key")

        payload = {k: v for k, v in envelope.items() if k not in cls._SIGNATURE_FIELDS}
        digest = hashlib.sha3_256(cls._canonicalize(payload)).digest()
        if not verify(digest, signature, signer_public_key, algorithm):
            raise TamperedWalletError("wallet signature failed verification")

        # Reconstruct the in-memory state from the verified payload.
        try:
            kdf = envelope.get("kdf", {})
            encrypted: dict[str, _EncryptedCredential] = {
                name: _EncryptedCredential(
                    nonce=raw["nonce"],
                    ciphertext=raw["ciphertext"],
                    metadata=CredentialMetadata.from_dict(raw["metadata"]),
                )
                for name, raw in envelope.get("encrypted_credentials", {}).items()
            }
            salt = bytes.fromhex(kdf.get("salt", ""))
            iterations = int(kdf.get("iterations", DEFAULT_ITERATIONS))
        except (AttributeError, KeyError, TypeError, ValueError) as exc:
            raise WalletFormatError(f"malformed wallet contents: {exc}") from exc

        return cls(
            path=path,
            owner=owner,
            salt=salt,
            iterations=iterations,
            kem_encapsulation=envelope.get("kem_encapsulation"),
            encrypted_credentials=encrypted,
            created_at=envelope.get("created_at", ""),
        )

    # ------------------------------------------------------------------
    # Internal crypto
    # ------------------------------------------------------------------

    def _encrypt_value(self, value: str, nonce: bytes) -> bytes:
        """AES-GCM encrypt the UTF-8 bytes of *value* under the unlock key.

        No associated data is used.
        """
        if not self._unlock_key:
            raise WalletLockedError("wallet must be unlocked")
        aes = AESGCM(self._unlock_key)
        return aes.encrypt(nonce, value.encode("utf-8"), associated_data=None)

    def _decrypt_value(self, enc: _EncryptedCredential, key: bytes) -> str:
        """Decrypt *enc* with *key* and decode the plaintext as UTF-8.

        AES-GCM authentication failures and decoding errors propagate.
        """
        aes = AESGCM(key)
        return aes.decrypt(
            bytes.fromhex(enc.nonce),
            bytes.fromhex(enc.ciphertext),
            associated_data=None,
        ).decode("utf-8")

    # ------------------------------------------------------------------
    # Context manager sugar
    # ------------------------------------------------------------------

    def __enter__(self) -> Wallet:
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        # Always lock on exit, including when the body raised.
        self.lock()
| 413 | |