src/pqc_kv_cache/session.py
3.2 KB · 103 lines · python Raw
1 """TenantSession - per-tenant symmetric key derived from ML-KEM-768."""
2
from __future__ import annotations

import hashlib
import uuid
from dataclasses import dataclass, field
from datetime import datetime, timedelta, timezone
from typing import Any

from quantumshield.core.algorithms import KEMAlgorithm
from quantumshield.core.keys import generate_kem_keypair

from pqc_kv_cache.errors import SessionExpiredError
15
# Default session lifetime in seconds (15 minutes).
SESSION_TTL_SECONDS = 15 * 60
17
18
@dataclass(frozen=True)
class TenantIdentity:
    """Immutable, hashable handle naming a tenant (user, org, or session)."""

    # Primary identifier of the tenant.
    tenant_id: str
    # Optional human-readable label; empty string when not supplied.
    display_name: str = ""
25
26
def establish_tenant_session(
    tenant: TenantIdentity,
    algorithm: KEMAlgorithm = KEMAlgorithm.ML_KEM_768,
    ttl_seconds: int = SESSION_TTL_SECONDS,
) -> TenantSession:
    """Create a new TenantSession with a freshly derived 32-byte key.

    A new KEM keypair is generated for ``algorithm``, and the session key is
    the SHA3-256 digest of the private-key bytes followed by the public-key
    bytes.

    This stands in for the production flow, in which the tenant supplies a
    KEM public key and the inference server runs Encapsulate to obtain the
    shared secret. Hashing the local keypair keeps the pattern working under
    the Ed25519 fallback backend.

    Args:
        tenant: Identity the session belongs to.
        algorithm: KEM algorithm used to generate the keypair.
        ttl_seconds: Session lifetime, counted from the creation timestamp.

    Returns:
        A TenantSession whose ``created_at``/``expires_at`` are UTC
        ISO-8601 strings and whose ``algorithm`` is ``algorithm.value``.
    """
    keypair = generate_kem_keypair(algorithm)

    # Feeding the two halves through one hash object is equivalent to hashing
    # their concatenation (private key first, then public key).
    hasher = hashlib.sha3_256()
    hasher.update(keypair.private_key)
    hasher.update(keypair.public_key)
    session_key = hasher.digest()

    new_id = "urn:pqc-kv-sess:" + uuid.uuid4().hex

    issued_at = datetime.now(timezone.utc)
    valid_until = issued_at + timedelta(seconds=ttl_seconds)

    return TenantSession(
        session_id=new_id,
        tenant=tenant,
        symmetric_key=session_key,
        algorithm=algorithm.value,
        created_at=issued_at.isoformat(),
        expires_at=valid_until.isoformat(),
    )
53
54
@dataclass
class TenantSession:
    """Per-tenant session holding the AES-256-GCM key + counters.

    ``symmetric_key`` is secret. It is excluded from the generated ``repr``
    and from ``to_public_dict()``, so printing or logging a session does not
    disclose it.
    """

    session_id: str
    tenant: TenantIdentity
    # repr=False keeps the raw key out of repr(), tracebacks and debug logs;
    # the default dataclass __repr__ would otherwise print every field.
    symmetric_key: bytes = field(repr=False)
    algorithm: str
    # ISO-8601 timestamps. expires_at must carry a UTC offset for is_valid()
    # to accept it.
    created_at: str
    expires_at: str
    # Next value returned by consume_sequence(); restarts at 1 on rotate_key().
    next_sequence: int = 1
    # Number of sequence numbers consumed under the current key.
    entries_encrypted: int = 0

    def is_valid(self) -> bool:
        """Return True if ``expires_at`` is a timezone-aware time not yet passed.

        Fails closed: an unparseable, missing or naive (offset-less)
        ``expires_at`` yields False instead of raising.
        """
        try:
            exp = datetime.fromisoformat(self.expires_at)
        except (TypeError, ValueError):
            # ValueError: malformed string. TypeError: not a string at all.
            return False
        if exp.utcoffset() is None:
            # Comparing a naive datetime with the aware "now" below raises
            # TypeError, and the intended zone is ambiguous, so treat the
            # session as expired rather than guess.
            return False
        return datetime.now(timezone.utc) <= exp

    def check_valid(self) -> None:
        """Raise SessionExpiredError unless is_valid() returns True."""
        if not self.is_valid():
            raise SessionExpiredError(f"session {self.session_id} expired")

    def consume_sequence(self) -> int:
        """Return the next sequence number and advance both counters."""
        seq = self.next_sequence
        self.next_sequence += 1
        self.entries_encrypted += 1
        return seq

    def rotate_key(self, new_key: bytes) -> None:
        """Replace the symmetric key (used by KeyRotationPolicy).

        After rotation the sequence counter restarts at 1 and
        ``entries_encrypted`` restarts at 0.

        Raises:
            ValueError: if ``new_key`` equals the current key. Because the
                counters restart, "rotating" to the same key would hand out
                sequence numbers already used under that key.
        """
        if new_key == self.symmetric_key:
            # NOTE(review): if sequence numbers feed the AES-GCM nonce
            # (assumed; not visible in this module), reusing them under one
            # key would be nonce reuse. Reject instead of silently resetting.
            raise ValueError(
                f"session {self.session_id}: rotate_key requires a different key"
            )
        self.symmetric_key = new_key
        self.next_sequence = 1
        self.entries_encrypted = 0

    def to_public_dict(self) -> dict[str, Any]:
        """Serialize without the symmetric key - safe for logs/telemetry."""
        return {
            "session_id": self.session_id,
            "tenant_id": self.tenant.tenant_id,
            "tenant_display": self.tenant.display_name,
            "algorithm": self.algorithm,
            "created_at": self.created_at,
            "expires_at": self.expires_at,
            "entries_encrypted": self.entries_encrypted,
            "is_valid": self.is_valid(),
        }
103