src/pqc_kv_cache/isolation.py
| 1 | """TenantIsolationManager - supervises multiple TenantSessions simultaneously.""" |
| 2 | |
| 3 | from __future__ import annotations |
| 4 | |
| 5 | from dataclasses import dataclass, field |
| 6 | |
| 7 | from pqc_kv_cache.encryptor import CacheDecryptor, CacheEncryptor |
| 8 | from pqc_kv_cache.entry import EncryptedEntry, KVCacheEntry |
| 9 | from pqc_kv_cache.errors import TenantIsolationError, UnknownTenantError |
| 10 | from pqc_kv_cache.session import ( |
| 11 | TenantIdentity, |
| 12 | TenantSession, |
| 13 | establish_tenant_session, |
| 14 | ) |
| 15 | |
| 16 | |
@dataclass
class TenantIsolationManager:
    """Supervise several tenant sessions and keep their cache traffic apart.

    Sessions live in ``sessions`` keyed by tenant id. Every encrypt/decrypt
    call is routed through the session registered under the tenant id the
    caller passes in; ``encrypt`` additionally refuses entries whose metadata
    names a different tenant.
    """

    # tenant_id -> session currently registered for that tenant.
    sessions: dict[str, TenantSession] = field(default_factory=dict)

    def create_session(self, tenant: TenantIdentity) -> TenantSession:
        """Return a usable session for *tenant*, establishing one if needed.

        A registered session is reused only while its ``is_valid()`` returns
        True; otherwise a new session is established via
        ``establish_tenant_session`` and replaces the old entry.
        """
        key = tenant.tenant_id
        if key in self.sessions:
            existing = self.sessions[key]
            if existing.is_valid():
                return existing
        fresh = establish_tenant_session(tenant)
        self.sessions[key] = fresh
        return fresh

    def get_session(self, tenant_id: str) -> TenantSession:
        """Look up the session registered for *tenant_id*.

        Raises:
            UnknownTenantError: if no session is registered for the tenant.
        """
        # NOTE(review): a session whose is_valid() is False is still returned
        # here (and therefore used by encrypt/decrypt) — confirm intended.
        if tenant_id in self.sessions:
            return self.sessions[tenant_id]
        raise UnknownTenantError(f"no session for tenant {tenant_id}")

    def encrypt(self, tenant_id: str, entry: KVCacheEntry) -> EncryptedEntry:
        """Encrypt *entry* with the session of *tenant_id*.

        Raises:
            UnknownTenantError: if *tenant_id* has no registered session.
            TenantIsolationError: if ``entry.metadata.tenant_id`` differs
                from *tenant_id*.
        """
        # Resolve the session first so an unknown tenant is reported before
        # any ownership mismatch.
        session = self.get_session(tenant_id)
        owner = entry.metadata.tenant_id
        if owner != tenant_id:
            raise TenantIsolationError(
                f"entry tenant {owner} != provided tenant {tenant_id}"
            )
        encryptor = CacheEncryptor(session)
        return encryptor.encrypt_entry(entry)

    def decrypt(self, tenant_id: str, enc: EncryptedEntry) -> KVCacheEntry:
        """Decrypt *enc* with the session of *tenant_id*.

        Raises:
            UnknownTenantError: if *tenant_id* has no registered session.
        """
        # NOTE(review): unlike encrypt(), no explicit check ties *enc* to
        # *tenant_id* here; presumably CacheDecryptor rejects foreign
        # ciphertext — verify.
        decryptor = CacheDecryptor(self.get_session(tenant_id))
        return decryptor.decrypt_entry(enc)

    def close_session(self, tenant_id: str) -> None:
        """Drop the session registered for *tenant_id*; no-op if absent."""
        if tenant_id in self.sessions:
            del self.sessions[tenant_id]

    def list_active_tenants(self) -> list[str]:
        """Return tenant ids whose session reports ``is_valid()``, in dict order."""
        active: list[str] = []
        for tid, sess in self.sessions.items():
            if sess.is_valid():
                active.append(tid)
        return active
| 55 | |