src/pqc_agent_wallet/vault.py
15.1 KB · 413 lines · python Raw
1 """Wallet - encrypted credential vault with ML-DSA integrity and optional ML-KEM key encapsulation."""
2
3 from __future__ import annotations
4
5 import hashlib
6 import json
7 import os
8 from dataclasses import dataclass
9 from datetime import datetime, timezone
10 from typing import Any
11
12 from cryptography.hazmat.primitives.ciphers.aead import AESGCM
13 from quantumshield.core.algorithms import KEMAlgorithm, SignatureAlgorithm
14 from quantumshield.core.signatures import sign, verify
15 from quantumshield.identity.agent import AgentIdentity
16
17 from pqc_agent_wallet.audit import WalletAuditLog
18 from pqc_agent_wallet.credential import Credential, CredentialMetadata
19 from pqc_agent_wallet.errors import (
20 CredentialNotFoundError,
21 InvalidPassphraseError,
22 TamperedWalletError,
23 WalletFormatError,
24 WalletLockedError,
25 )
26 from pqc_agent_wallet.kdf import DEFAULT_ITERATIONS, derive_key_from_passphrase
27
# Format tag stamped into every saved wallet; Wallet.load() rejects any other value.
WALLET_FORMAT_VERSION: str = "1.0"
# Size in bytes of the random AES-GCM nonce drawn for each credential write (96 bits).
NONCE_LENGTH: int = 12
# Size in bytes of the random KDF salt drawn when a wallet is built without one.
SALT_LENGTH: int = 16
31
32
@dataclass()
class _EncryptedCredential:
    """A single stored credential: AES-GCM output plus clear-text metadata.

    ``nonce`` and ``ciphertext`` hold hex strings so the record can be written
    straight into the wallet's JSON payload. ``metadata`` is not encrypted,
    which is why credential names and metadata remain listable while the
    wallet is locked.
    """

    nonce: str  # hex-encoded AES-GCM nonce
    ciphertext: str  # hex-encoded AES-GCM ciphertext (authentication tag appended)
    metadata: CredentialMetadata  # kept in clear text
38
39
class Wallet:
    """Encrypted credential store for an AI agent.

    Each credential value is encrypted separately with AES-GCM under the
    wallet's unlock key; credential metadata is kept in clear text. ``save()``
    signs the serialized payload with the owner's signing keypair, and
    ``load()`` rejects files whose signature does not verify against that same
    owner's public key.

    Two unlock modes:
      1. Passphrase:
         ``Wallet.create_with_passphrase(path, passphrase, owner)``
         - the unlock key is derived from the passphrase and the wallet's salt.
      2. KEM-encapsulated:
         ``Wallet.create_with_kem(path, recipient_kem_public_key,
         recipient_algorithm, owner)``
         - the wallet is created with an ephemeral symmetric key encapsulated
           to the recipient's KEM public key (e.g. ML-KEM-768). To unlock, the
           recipient uses their private key to decapsulate.

    Usage:
        owner = AgentIdentity.create("my-agent")
        wallet = Wallet.create_with_passphrase("agent.wallet", "hunter2", owner)
        wallet.put("openai_api_key", "sk-...", service="openai")
        wallet.save()
        wallet.lock()

        # Later...
        wallet = Wallet.load("agent.wallet", owner)
        wallet.unlock_with_passphrase("hunter2")
        key = wallet.get("openai_api_key")
    """

    # Envelope keys that save() adds around the signed payload; load() strips
    # exactly these before recomputing the digest it verifies.
    _SIGNATURE_FIELDS = ("signature", "signature_algorithm", "owner_public_key")

    def __init__(
        self,
        path: str,
        owner: AgentIdentity,
        salt: bytes = b"",
        iterations: int = DEFAULT_ITERATIONS,
        kem_encapsulation: dict | None = None,
        encrypted_credentials: dict[str, _EncryptedCredential] | None = None,
        created_at: str = "",
        audit_log: WalletAuditLog | None = None,
    ) -> None:
        """Build an in-memory, locked wallet (nothing is read from or written to disk).

        Prefer the ``create_with_*`` factories or ``load()``.

        Args:
            path: File that ``save()`` writes to.
            owner: Identity whose signing keypair signs the saved file.
            salt: KDF salt; a fresh random ``SALT_LENGTH``-byte salt when empty.
            iterations: KDF iteration count.
            kem_encapsulation: KEM header dict, or ``None`` for passphrase wallets.
            encrypted_credentials: Existing encrypted entries keyed by name.
            created_at: ISO-8601 creation time; defaults to the current UTC time.
            audit_log: Audit sink; a new ``WalletAuditLog`` when omitted.
        """
        self.path = path
        self.owner = owner
        self.salt = salt or os.urandom(SALT_LENGTH)
        self.iterations = iterations
        self.kem_encapsulation = kem_encapsulation  # None for passphrase-only wallets
        self._encrypted: dict[str, _EncryptedCredential] = encrypted_credentials or {}
        self._unlock_key: bytes | None = None  # None means "locked"
        self.created_at = created_at or datetime.now(timezone.utc).isoformat()
        # `is not None` rather than `or`: a caller-supplied log that happens to
        # evaluate as falsy (e.g. empty, if it defines __len__) must be kept.
        self.audit = audit_log if audit_log is not None else WalletAuditLog()

    # ------------------------------------------------------------------
    # Factories
    # ------------------------------------------------------------------

    @classmethod
    def create_with_passphrase(
        cls,
        path: str,
        passphrase: str,
        owner: AgentIdentity,
    ) -> Wallet:
        """Create a new, already-unlocked wallet protected by ``passphrase``.

        A fresh random salt is generated and the unlock key is derived with
        ``derive_key_from_passphrase``. Nothing is written until ``save()``.
        """
        w = cls(path=path, owner=owner)
        w._unlock_key = derive_key_from_passphrase(passphrase, w.salt, w.iterations)
        return w

    @classmethod
    def create_with_kem(
        cls,
        path: str,
        recipient_kem_public_key: bytes,
        recipient_algorithm: KEMAlgorithm,
        owner: AgentIdentity,
    ) -> Wallet:
        """Create a wallet whose unlock key is encapsulated to a KEM public key.

        The caller is the issuer (who knows the ephemeral symmetric key briefly
        and then throws it away). The recipient who holds the matching KEM
        private key can decapsulate to unlock. The returned wallet is unlocked.

        NOTE: quantumshield's ``encapsulate()`` API may not be available in the
        Ed25519 fallback backend. In that case a random 32-byte key is used and
        the stored "ciphertext" is only SHA3-256(key || pubkey), which cannot
        be decapsulated; this path exists for dev/testing only.
        """
        w = cls(path=path, owner=owner)

        # Use quantumshield's encapsulate if available; else fall back to a
        # deterministic-from-random derivation for dev/testing.
        from quantumshield.core import keys as _qk

        symmetric_key: bytes
        ciphertext: bytes
        if hasattr(_qk, "encapsulate"):
            symmetric_key, ciphertext = _qk.encapsulate(
                recipient_kem_public_key, recipient_algorithm
            )
        else:
            # Dev fallback: generate a 32-byte symmetric key, "encapsulate" by
            # SHA3-256'ing it with the recipient pubkey. Real liboqs integration
            # replaces this path.
            symmetric_key = os.urandom(32)
            ciphertext = hashlib.sha3_256(
                symmetric_key + recipient_kem_public_key
            ).digest()

        w._unlock_key = symmetric_key[:32]
        w.kem_encapsulation = {
            "algorithm": recipient_algorithm.value,
            "ciphertext": ciphertext.hex(),
            "recipient_pubkey": recipient_kem_public_key.hex(),
        }
        return w

    # ------------------------------------------------------------------
    # Unlock
    # ------------------------------------------------------------------

    @property
    def is_unlocked(self) -> bool:
        """True while an unlock key is held in memory."""
        return self._unlock_key is not None

    def _key_opens_wallet(self, candidate: bytes) -> bool:
        """Return True if ``candidate`` is a usable unlock key for this wallet.

        The key must have a length AES-GCM accepts (16, 24 or 32 bytes). If the
        wallet already holds credentials, the key must also authenticate and
        decrypt the first one; an empty wallet accepts any well-formed key.
        """
        from cryptography.exceptions import InvalidTag

        if len(candidate) not in (16, 24, 32):
            return False
        if not self._encrypted:
            return True
        first = next(iter(self._encrypted.values()))
        try:
            self._decrypt_value(first, candidate)
        except (InvalidTag, ValueError):
            # InvalidTag: wrong key / tampered ciphertext.
            # ValueError: malformed hex, bad nonce, or undecodable plaintext.
            return False
        return True

    def unlock_with_passphrase(self, passphrase: str) -> None:
        """Derive the unlock key from ``passphrase`` and unlock the wallet.

        Raises:
            InvalidPassphraseError: the derived key does not decrypt the
                wallet's first credential (the failure is audit-logged).
        """
        candidate = derive_key_from_passphrase(passphrase, self.salt, self.iterations)
        if not self._key_opens_wallet(candidate):
            self.audit.log("unlock", self.owner, "", False, details="passphrase")
            raise InvalidPassphraseError("passphrase failed to unlock wallet")
        self._unlock_key = candidate
        self.audit.log("unlock", self.owner, "", True, details="passphrase")

    def unlock_with_kem_private_key(
        self,
        recipient_kem_private_key: bytes,
        algorithm: KEMAlgorithm,
    ) -> None:
        """Decapsulate the wallet's KEM ciphertext and unlock with the result.

        Without a quantumshield ``decapsulate()`` (dev fallback), the first 32
        bytes of ``recipient_kem_private_key`` are used directly as the key.

        Raises:
            WalletFormatError: the wallet was not created with KEM encapsulation.
            InvalidPassphraseError: the recovered key does not unlock the wallet
                (the package's wrong-secret error is reused here; the failure
                is audit-logged).
        """
        if not self.kem_encapsulation:
            raise WalletFormatError("wallet was not created with KEM encapsulation")
        from quantumshield.core import keys as _qk

        ciphertext = bytes.fromhex(self.kem_encapsulation["ciphertext"])
        if hasattr(_qk, "decapsulate"):
            symmetric_key = _qk.decapsulate(
                ciphertext, recipient_kem_private_key, algorithm
            )
        else:
            # Dev fallback can't truly decapsulate; callers must pass the key
            # they used to create. Accept a private-key-as-symmetric for tests.
            symmetric_key = recipient_kem_private_key[:32]
        candidate = symmetric_key[:32]
        # Validate exactly like the passphrase path so a wrong private key fails
        # here instead of surfacing later as a decryption error in get().
        if not self._key_opens_wallet(candidate):
            self.audit.log("unlock", self.owner, "", False, details="kem")
            raise InvalidPassphraseError("KEM private key failed to unlock wallet")
        self._unlock_key = candidate
        self.audit.log("unlock", self.owner, "", True, details="kem")

    def lock(self) -> None:
        """Drop the in-memory unlock key.

        The key object is released, not zeroed (Python ``bytes`` are immutable).
        """
        self._unlock_key = None
        self.audit.log("lock", self.owner, "", True)

    # ------------------------------------------------------------------
    # CRUD
    # ------------------------------------------------------------------

    def put(
        self,
        name: str,
        value: str,
        service: str = "",
        description: str = "",
        scheme: str = "api-key",
        tags: list[str] | None = None,
        expires_at: str = "",
    ) -> None:
        """Encrypt ``value`` and store it under ``name``, replacing any existing entry.

        When an entry is replaced, its original ``created_at`` is kept and
        ``rotated_at`` is set to now. Every write uses a fresh random nonce.
        The change is in memory only until ``save()``.

        Raises:
            WalletLockedError: the wallet is locked.
        """
        if not self.is_unlocked:
            raise WalletLockedError("wallet must be unlocked before put()")
        now = datetime.now(timezone.utc).isoformat()
        existing = self._encrypted.get(name)
        metadata = CredentialMetadata(
            name=name,
            scheme=scheme,
            service=service,
            description=description,
            created_at=existing.metadata.created_at if existing else now,
            rotated_at=now if existing else "",
            expires_at=expires_at,
            tags=list(tags or []),  # copy so later caller mutations don't leak in
        )
        nonce = os.urandom(NONCE_LENGTH)
        ciphertext = self._encrypt_value(value, nonce)
        self._encrypted[name] = _EncryptedCredential(
            nonce=nonce.hex(),
            ciphertext=ciphertext.hex(),
            metadata=metadata,
        )
        self.audit.log("put", self.owner, name, True, details=f"service={service}")

    def get(self, name: str) -> str:
        """Decrypt and return the value stored under ``name``.

        Raises:
            WalletLockedError: the wallet is locked.
            CredentialNotFoundError: no such credential (audit-logged).
            cryptography.exceptions.InvalidTag: the ciphertext fails authentication.
        """
        if not self.is_unlocked:
            raise WalletLockedError("wallet must be unlocked before get()")
        if name not in self._encrypted:
            self.audit.log("get", self.owner, name, False, details="not found")
            raise CredentialNotFoundError(f"no credential named '{name}'")
        value = self._decrypt_value(self._encrypted[name], self._unlock_key or b"")
        self.audit.log("get", self.owner, name, True)
        return value

    def get_credential(self, name: str) -> Credential:
        """Return the decrypted value of ``name`` together with its metadata.

        Raises the same errors as ``get()``.
        """
        value = self.get(name)
        meta = self._encrypted[name].metadata
        return Credential(metadata=meta, value=value)

    def delete(self, name: str) -> None:
        """Remove the credential ``name`` (in memory until ``save()``).

        Raises:
            WalletLockedError: the wallet is locked.
            CredentialNotFoundError: no such credential (audit-logged).
        """
        if not self.is_unlocked:
            raise WalletLockedError("wallet must be unlocked before delete()")
        if name not in self._encrypted:
            self.audit.log("delete", self.owner, name, False, details="not found")
            raise CredentialNotFoundError(f"no credential named '{name}'")
        del self._encrypted[name]
        self.audit.log("delete", self.owner, name, True)

    def list_names(self) -> list[str]:
        """Return all credential names, sorted. Works while locked."""
        return sorted(self._encrypted.keys())

    def list_metadata(self) -> list[CredentialMetadata]:
        """Return the metadata of every credential in storage order. Works while locked."""
        return [e.metadata for e in self._encrypted.values()]

    def rotate(self, name: str, new_value: str) -> None:
        """Replace a credential's value while keeping its metadata creation date.

        Service, description, scheme, tags and expiry are carried over; the
        new value is stored via ``put()``, which stamps ``rotated_at``.

        Raises:
            WalletLockedError: the wallet is locked.
            CredentialNotFoundError: no such credential.
        """
        # Check the lock first, like get()/delete(), so a locked wallet always
        # reports WalletLockedError regardless of whether the name exists.
        if not self.is_unlocked:
            raise WalletLockedError("wallet must be unlocked before rotate()")
        existing = self._encrypted.get(name)
        if existing is None:
            raise CredentialNotFoundError(f"no credential named '{name}'")
        meta = existing.metadata
        self.put(
            name=name,
            value=new_value,
            service=meta.service,
            description=meta.description,
            scheme=meta.scheme,
            tags=meta.tags,
            expires_at=meta.expires_at,
        )

    # ------------------------------------------------------------------
    # Persistence
    # ------------------------------------------------------------------

    def _build_payload(self) -> dict[str, Any]:
        """Return the JSON-serializable wallet contents that get signed."""
        return {
            "version": WALLET_FORMAT_VERSION,
            "created_at": self.created_at,
            "owner_did": self.owner.did,
            "kdf": {
                "algorithm": "PBKDF2-HMAC-SHA256",
                "salt": self.salt.hex(),
                "iterations": self.iterations,
            },
            "kem_encapsulation": self.kem_encapsulation,
            "encrypted_credentials": {
                name: {
                    "nonce": enc.nonce,
                    "ciphertext": enc.ciphertext,
                    "metadata": enc.metadata.to_dict(),
                }
                for name, enc in self._encrypted.items()
            },
        }

    @staticmethod
    def _canonical_bytes(payload: dict[str, Any]) -> bytes:
        """Serialize ``payload`` deterministically (sorted keys, compact, UTF-8).

        Shared by ``save()`` and ``load()`` so both sides hash identical bytes.
        """
        return json.dumps(
            payload, sort_keys=True, separators=(",", ":"), ensure_ascii=False
        ).encode("utf-8")

    def _canonical_payload_bytes(self) -> bytes:
        """Deterministic serialization of the payload used for signing."""
        return self._canonical_bytes(self._build_payload())

    def save(self) -> None:
        """Write the wallet to ``self.path`` with a signature over the payload.

        The SHA3-256 digest of the canonical payload is signed with the owner's
        signing keypair. The file is written to a temporary file in the same
        directory (created with mode 0600 by ``tempfile.mkstemp``), fsynced,
        and moved into place with ``os.replace`` so a crash mid-write cannot
        leave a truncated wallet behind.
        """
        import tempfile

        payload = self._build_payload()
        digest = hashlib.sha3_256(self._canonical_bytes(payload)).digest()
        keypair = self.owner.signing_keypair
        signature = sign(digest, keypair)

        envelope = {
            **payload,
            "owner_public_key": keypair.public_key.hex(),
            "signature": signature.hex(),
            "signature_algorithm": keypair.algorithm.value,
        }

        directory = os.path.dirname(os.path.abspath(self.path))
        fd, tmp_path = tempfile.mkstemp(dir=directory, prefix=".wallet-", suffix=".tmp")
        try:
            with os.fdopen(fd, "w", encoding="utf-8") as f:
                json.dump(envelope, f, indent=2)
                f.flush()
                os.fsync(f.fileno())
            os.replace(tmp_path, self.path)
        except BaseException:
            # Best-effort removal of the partial temp file, then propagate.
            try:
                os.unlink(tmp_path)
            except OSError:
                pass
            raise

    @classmethod
    def load(cls, path: str, owner: AgentIdentity) -> Wallet:
        """Load a wallet from disk after verifying it was signed by ``owner``.

        The returned wallet is locked.

        Raises:
            WalletFormatError: the file is not valid wallet JSON, has an
                unsupported version, lacks or has malformed signature fields,
                names an unknown signature algorithm, or has malformed entries.
            TamperedWalletError: the embedded public key is not ``owner``'s, or
                the signature does not verify.
        """
        try:
            with open(path, encoding="utf-8") as f:
                envelope = json.load(f)
        except ValueError as exc:  # JSONDecodeError / UnicodeDecodeError
            raise WalletFormatError(f"wallet file is not valid JSON: {exc}") from exc
        if not isinstance(envelope, dict):
            raise WalletFormatError("wallet file must contain a JSON object")

        if envelope.get("version") != WALLET_FORMAT_VERSION:
            raise WalletFormatError(
                f"unsupported wallet version: {envelope.get('version')}"
            )

        sig_hex = envelope.get("signature")
        sig_alg = envelope.get("signature_algorithm")
        owner_pk_hex = envelope.get("owner_public_key")
        if not sig_hex or not sig_alg or not owner_pk_hex:
            raise WalletFormatError("wallet missing signature fields")

        try:
            algorithm = SignatureAlgorithm(sig_alg)
        except ValueError as exc:
            raise WalletFormatError(f"unknown signature algorithm: {sig_alg}") from exc

        try:
            signature = bytes.fromhex(sig_hex)
            file_public_key = bytes.fromhex(owner_pk_hex)
        except (TypeError, ValueError) as exc:
            raise WalletFormatError("wallet signature fields are not valid hex") from exc

        # The signature only proves integrity if it is checked against a key we
        # trust. The key embedded in the file is attacker-controlled, so require
        # it to be the supplied owner's key before verifying.
        if file_public_key != owner.signing_keypair.public_key:
            raise TamperedWalletError("wallet was not signed by the supplied owner")

        payload = {
            k: v for k, v in envelope.items() if k not in cls._SIGNATURE_FIELDS
        }
        digest = hashlib.sha3_256(cls._canonical_bytes(payload)).digest()
        if not verify(digest, signature, file_public_key, algorithm):
            raise TamperedWalletError("wallet signature failed verification")

        # Reconstruct in-memory state from the (now authenticated) payload.
        try:
            kdf = envelope.get("kdf", {})
            encrypted: dict[str, _EncryptedCredential] = {
                name: _EncryptedCredential(
                    nonce=raw["nonce"],
                    ciphertext=raw["ciphertext"],
                    metadata=CredentialMetadata.from_dict(raw["metadata"]),
                )
                for name, raw in envelope.get("encrypted_credentials", {}).items()
            }
            salt = bytes.fromhex(kdf.get("salt", ""))
            iterations = int(kdf.get("iterations", DEFAULT_ITERATIONS))
        except (AttributeError, KeyError, TypeError, ValueError) as exc:
            raise WalletFormatError(f"malformed wallet contents: {exc}") from exc

        return cls(
            path=path,
            owner=owner,
            salt=salt,
            iterations=iterations,
            kem_encapsulation=envelope.get("kem_encapsulation"),
            encrypted_credentials=encrypted,
            created_at=envelope.get("created_at", ""),
        )

    # ------------------------------------------------------------------
    # Internal crypto
    # ------------------------------------------------------------------

    def _encrypt_value(self, value: str, nonce: bytes) -> bytes:
        """AES-GCM-encrypt the UTF-8 bytes of ``value`` under the unlock key."""
        if not self._unlock_key:
            raise WalletLockedError("wallet must be unlocked")
        aes = AESGCM(self._unlock_key)
        return aes.encrypt(nonce, value.encode("utf-8"), associated_data=None)

    def _decrypt_value(self, enc: _EncryptedCredential, key: bytes) -> str:
        """AES-GCM-decrypt ``enc`` with ``key`` and decode the result as UTF-8.

        Raises ``cryptography.exceptions.InvalidTag`` on authentication failure.
        """
        aes = AESGCM(key)
        return aes.decrypt(
            bytes.fromhex(enc.nonce),
            bytes.fromhex(enc.ciphertext),
            associated_data=None,
        ).decode("utf-8")

    # ------------------------------------------------------------------
    # Context manager sugar
    # ------------------------------------------------------------------

    def __enter__(self) -> Wallet:
        """Return the wallet itself; it is locked again on exit."""
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        """Lock the wallet; exceptions are not suppressed."""
        self.lock()
413