src/pqc_kv_cache/session.py
| 1 | """TenantSession - per-tenant symmetric key derived from ML-KEM-768.""" |
| 2 | |
from __future__ import annotations

import hashlib
import uuid
from dataclasses import dataclass, field
from datetime import datetime, timedelta, timezone
from typing import Any

from quantumshield.core.algorithms import KEMAlgorithm
from quantumshield.core.keys import generate_kem_keypair

from pqc_kv_cache.errors import SessionExpiredError
| 15 | |
SESSION_TTL_SECONDS = 900  # default session lifetime in seconds (900 s = 15 minutes)
| 17 | |
| 18 | |
@dataclass(init=True, repr=True, eq=True, frozen=True)
class TenantIdentity:
    """Immutable value object naming one tenant (a user, org, or session).

    Instances are frozen, so they compare by value and are hashable.
    """

    # Identifier of the tenant; the only required field.
    tenant_id: str
    # Optional human-readable label; empty when none was supplied.
    display_name: str = ""
| 25 | |
| 26 | |
def establish_tenant_session(
    tenant: TenantIdentity,
    algorithm: KEMAlgorithm = KEMAlgorithm.ML_KEM_768,
    ttl_seconds: int = SESSION_TTL_SECONDS,
) -> TenantSession:
    """Create a new ``TenantSession`` with a freshly derived 32-byte key.

    A new KEM keypair is generated for ``algorithm`` and the session key is
    the SHA3-256 digest of the private key bytes followed by the public key
    bytes. The hash is deterministic for a given keypair, but the keypair
    itself is new on every call.

    In production the tenant would supply their own KEM public key and the
    inference server would run Encapsulate to obtain the shared secret; the
    local keypair-and-hash scheme used here stands in for that exchange
    (per the original author's note, so that the pattern also works under
    the Ed25519 fallback backend).

    Args:
        tenant: Identity the session is bound to.
        algorithm: KEM algorithm used to generate the keypair; its ``value``
            is recorded on the session.
        ttl_seconds: Session lifetime in seconds, counted from the moment of
            creation (UTC).

    Returns:
        A ``TenantSession`` with a random ``urn:pqc-kv-sess:`` identifier and
        ISO-8601 creation/expiry timestamps.
    """
    keypair = generate_kem_keypair(algorithm)

    # Session key = SHA3-256(private || public) over the keypair's raw bytes.
    key_material = keypair.private_key + keypair.public_key
    derived_key = hashlib.sha3_256(key_material).digest()

    new_session_id = f"urn:pqc-kv-sess:{uuid.uuid4().hex}"

    issued_at = datetime.now(timezone.utc)
    expiry = issued_at + timedelta(seconds=ttl_seconds)

    return TenantSession(
        session_id=new_session_id,
        tenant=tenant,
        symmetric_key=derived_key,
        algorithm=algorithm.value,
        created_at=issued_at.isoformat(),
        expires_at=expiry.isoformat(),
    )
| 53 | |
| 54 | |
@dataclass
class TenantSession:
    """Per-tenant session holding the AES-256-GCM key and usage counters.

    The session is mutable: sequence numbers advance as entries are
    encrypted, and the key can be swapped by ``rotate_key``. Timestamps are
    stored as ISO-8601 strings.

    The key is deliberately kept out of ``repr()`` so that printing or
    logging a session object does not disclose it; use ``to_public_dict``
    for structured telemetry.
    """

    # Unique identifier of this session.
    session_id: str
    # Tenant the session belongs to.
    tenant: TenantIdentity
    # Secret key material. Excluded from the generated repr so it cannot
    # leak through logs, debuggers, or exception messages.
    symmetric_key: bytes = field(repr=False)
    # Name of the KEM algorithm the session was established with.
    algorithm: str
    # ISO-8601 creation timestamp.
    created_at: str
    # ISO-8601 expiry timestamp; must carry a UTC offset to be honoured.
    expires_at: str
    # Sequence number the next call to ``consume_sequence`` will hand out.
    next_sequence: int = 1
    # Number of sequence numbers handed out since the last key rotation.
    entries_encrypted: int = 0

    def is_valid(self) -> bool:
        """Return True while the session has not passed ``expires_at``.

        Fails closed: an ``expires_at`` that is not a string, cannot be
        parsed as ISO-8601, or lacks a UTC offset (and therefore cannot be
        compared with the current UTC time) makes the session invalid
        instead of raising.
        """
        try:
            expiry = datetime.fromisoformat(self.expires_at)
        except (TypeError, ValueError):
            # TypeError: not a string; ValueError: not ISO-8601.
            return False
        if expiry.utcoffset() is None:
            # Comparing a naive datetime with an aware one raises TypeError;
            # treat an ambiguous expiry as expired rather than crash.
            return False
        return datetime.now(timezone.utc) <= expiry

    def check_valid(self) -> None:
        """Raise ``SessionExpiredError`` unless ``is_valid()`` is True."""
        if not self.is_valid():
            raise SessionExpiredError(f"session {self.session_id} expired")

    def consume_sequence(self) -> int:
        """Hand out the current sequence number and advance the counters.

        Returns:
            The value of ``next_sequence`` before it was incremented.
        """
        # NOTE(review): the read-increment pair is not synchronized; confirm
        # callers serialize access if sessions are shared across threads.
        seq = self.next_sequence
        self.next_sequence += 1
        self.entries_encrypted += 1
        return seq

    def rotate_key(self, new_key: bytes) -> None:
        """Replace the symmetric key (used by KeyRotationPolicy).

        Both counters restart, so sequence numbering begins again at 1
        under the new key.
        """
        self.symmetric_key = new_key
        self.next_sequence = 1
        self.entries_encrypted = 0

    def to_public_dict(self) -> dict[str, Any]:
        """Serialize without the symmetric key - safe for logs/telemetry.

        ``is_valid`` is evaluated at call time, so the snapshot reflects
        whether the session was still live when it was taken.
        """
        return {
            "session_id": self.session_id,
            "tenant_id": self.tenant.tenant_id,
            "tenant_display": self.tenant.display_name,
            "algorithm": self.algorithm,
            "created_at": self.created_at,
            "expires_at": self.expires_at,
            "entries_encrypted": self.entries_encrypted,
            "is_valid": self.is_valid(),
        }
| 103 | |