examples/basic_kv_encryption.py
2.5 KB · 86 lines · python Raw
1 """Basic single-tenant KV cache encryption example.
2
3 Demonstrates:
4 1. Establishing a TenantSession for one tenant.
5 2. Encrypting 3 KV cache entries (simulating 3 token positions).
6 3. Decrypting them back and verifying round-trip.
7 4. Printing the audit log.
8 """
9
10 from __future__ import annotations
11
12 import os
13
14 from pqc_kv_cache import (
15 CacheDecryptor,
16 CacheEncryptor,
17 EntryMetadata,
18 KVAuditLog,
19 KVCacheEntry,
20 TenantIdentity,
21 establish_tenant_session,
22 )
23
24
def main() -> None:
    """Run the single-tenant KV cache encryption demo end to end.

    Establishes a session for one tenant, encrypts three KV cache entries
    (one per token position in layer 0), decrypts each one and checks that
    the key and value tensor bytes match the originals, records every
    encrypt and decrypt in a ``KVAuditLog``, and finally prints up to ten
    audit log entries.

    Raises:
        AssertionError: If a decrypted entry does not match its original.
            Raised explicitly (not via ``assert``) so the round-trip check
            still runs when Python is started with ``-O``.
    """
    tenant = TenantIdentity(tenant_id="tenant-alice", display_name="Alice Corp")
    session = establish_tenant_session(tenant)
    print(f"Session established: {session.session_id}")
    print(f"Algorithm: {session.algorithm}")
    print(f"Expires at: {session.expires_at}")

    encryptor = CacheEncryptor(session)
    decryptor = CacheDecryptor(session)
    audit = KVAuditLog()

    # Simulate encrypting K/V for 3 token positions in layer 0
    encrypted_entries = []
    originals: list[KVCacheEntry] = []
    for pos in range(3):
        meta = EntryMetadata(
            tenant_id=tenant.tenant_id,
            session_id=session.session_id,
            layer_idx=0,
            position=pos,
            token_id=1000 + pos,
        )
        # Random bytes stand in for real key/value tensor payloads.
        entry = KVCacheEntry(
            metadata=meta,
            key_tensor_bytes=os.urandom(64),
            value_tensor_bytes=os.urandom(64),
        )
        originals.append(entry)
        enc = encryptor.encrypt_entry(entry)
        audit.log_encrypt(
            tenant.tenant_id,
            session.session_id,
            meta.layer_idx,
            meta.position,
            enc.sequence_number,
        )
        encrypted_entries.append(enc)
        # NOTE(review): the `// 2` presumably converts a hex-encoded
        # ciphertext length into a byte count -- confirm against the library.
        print(f"Encrypted pos={pos} seq={enc.sequence_number} ct_bytes={len(enc.ciphertext) // 2}")

    # Decrypt and verify
    for orig, enc in zip(originals, encrypted_entries):
        dec = decryptor.decrypt_entry(enc)
        # Explicit checks instead of `assert`: assert statements are removed
        # under `python -O`, which would silently skip the verification.
        if dec.key_tensor_bytes != orig.key_tensor_bytes:
            raise AssertionError(
                f"key tensor round-trip mismatch at position {orig.metadata.position}"
            )
        if dec.value_tensor_bytes != orig.value_tensor_bytes:
            raise AssertionError(
                f"value tensor round-trip mismatch at position {orig.metadata.position}"
            )
        audit.log_decrypt(
            tenant.tenant_id,
            session.session_id,
            orig.metadata.layer_idx,
            orig.metadata.position,
            enc.sequence_number,
            success=True,
        )
        print(f"Decrypted pos={orig.metadata.position} OK")

    print("\nAudit log entries:")
    # `record` rather than `entry` to avoid shadowing the cache entry above.
    for record in audit.entries(limit=10):
        print(f" {record.timestamp} {record.operation:8s} seq={record.sequence_number}")
82
83
# Run the demo only when this file is executed as a script, so importing
# the module has no side effects.
if __name__ == "__main__":
    main()
86