tests/test_encryptor.py
3.9 KB · 122 lines · python Raw
1 """Tests for CacheEncryptor / CacheDecryptor."""
2
3 from __future__ import annotations
4
5 import dataclasses
6
7 import pytest
8
9 from pqc_kv_cache.encryptor import CacheDecryptor, CacheEncryptor
10 from pqc_kv_cache.entry import EncryptedEntry, EntryMetadata, KVCacheEntry
11 from pqc_kv_cache.errors import (
12 DecryptionError,
13 NonceReplayError,
14 TenantIsolationError,
15 )
16
17
def test_encrypt_returns_encrypted_entry(
    session_alice, sample_entry_factory
) -> None:
    """Encrypting one entry yields a well-formed ``EncryptedEntry``.

    Checks the result type, that the hex-encoded nonce decodes to 12 bytes,
    that the hex-encoded ciphertext is non-empty, and that the first entry
    encrypted in this test is stamped with sequence number 1.
    """
    plain = sample_entry_factory(session_alice, 0, 0)
    sealed = CacheEncryptor(session_alice).encrypt_entry(plain)

    assert isinstance(sealed, EncryptedEntry)

    nonce_raw = bytes.fromhex(sealed.nonce)
    assert len(nonce_raw) == 12

    ciphertext_raw = bytes.fromhex(sealed.ciphertext)
    assert len(ciphertext_raw) > 0

    assert sealed.sequence_number == 1
27
28
def test_encrypt_then_decrypt_roundtrip(
    session_alice, sample_entry_factory
) -> None:
    """Decrypting a freshly encrypted entry restores it exactly.

    The entry is built for layer 2, position 5. After an encrypt/decrypt
    round trip through the same session, the metadata, the key tensor bytes
    and the value tensor bytes must all equal the originals.
    """
    original = sample_entry_factory(session_alice, 2, 5)
    sealed = CacheEncryptor(session_alice).encrypt_entry(original)
    recovered = CacheDecryptor(session_alice).decrypt_entry(sealed)

    assert recovered.metadata == original.metadata
    assert recovered.key_tensor_bytes == original.key_tensor_bytes
    assert recovered.value_tensor_bytes == original.value_tensor_bytes
38
39
def test_tenant_id_mismatch_on_encrypt_raises(
    session_alice, tenant_bob
) -> None:
    """Alice's encryptor refuses an entry labelled with Bob's tenant id.

    The metadata pairs Bob's ``tenant_id`` with Alice's ``session_id``.
    The tensors are 32 zero bytes each, since their content is irrelevant.
    Encrypting must raise ``TenantIsolationError``.
    """
    zero_block = bytes(32)  # equal to b"\x00" * 32
    mislabelled_meta = EntryMetadata(
        tenant_id=tenant_bob.tenant_id,
        session_id=session_alice.session_id,
        layer_idx=0,
        position=0,
    )
    mislabelled = KVCacheEntry(
        metadata=mislabelled_meta,
        key_tensor_bytes=zero_block,
        value_tensor_bytes=zero_block,
    )

    with pytest.raises(TenantIsolationError):
        CacheEncryptor(session_alice).encrypt_entry(mislabelled)
56
57
def test_wrong_tenant_session_decrypt_raises(
    session_alice, session_bob, sample_entry_factory
) -> None:
    """Bob's session cannot decrypt an entry that Alice's session sealed.

    The attempt must raise ``TenantIsolationError``.
    """
    alice_plain = sample_entry_factory(session_alice, 0, 0)
    sealed_by_alice = CacheEncryptor(session_alice).encrypt_entry(alice_plain)

    with pytest.raises(TenantIsolationError):
        CacheDecryptor(session_bob).decrypt_entry(sealed_by_alice)
65
66
def test_sequence_counter_increments(
    session_alice, sample_entry_factory
) -> None:
    """Sequence numbers keep counting across encryptors on one session.

    A brand-new ``CacheEncryptor`` is built for every entry, so the check
    shows the counter is not reset per encryptor instance. Three entries at
    layer 0, positions 0, 1 and 2, must be numbered 1, 2 and 3.
    """
    observed = []
    for position in range(3):
        sealed = CacheEncryptor(session_alice).encrypt_entry(
            sample_entry_factory(session_alice, 0, position)
        )
        observed.append(sealed.sequence_number)

    assert tuple(observed) == (1, 2, 3)
80
81
def test_aad_tamper_detected(session_alice, sample_entry_factory) -> None:
    """Changing the metadata of a sealed entry makes decryption fail.

    Only ``layer_idx`` is altered (set to 99). The nonce, ciphertext,
    ``key_len`` and sequence number are carried over untouched, so
    ``DecryptionError`` can only come from the metadata mismatch. The test
    name suggests the metadata is bound as AEAD associated data.
    """
    plain = sample_entry_factory(session_alice, 0, 0)
    sealed = CacheEncryptor(session_alice).encrypt_entry(plain)

    forged = EncryptedEntry(
        metadata=dataclasses.replace(sealed.metadata, layer_idx=99),
        nonce=sealed.nonce,
        ciphertext=sealed.ciphertext,
        key_len=sealed.key_len,
        sequence_number=sealed.sequence_number,
    )

    with pytest.raises(DecryptionError):
        CacheDecryptor(session_alice).decrypt_entry(forged)
95
96
def test_ciphertext_tamper_detected(
    session_alice, sample_entry_factory
) -> None:
    """Flipping a single ciphertext bit makes decryption fail.

    The lowest bit of the first ciphertext byte is inverted. Every other
    field of the encrypted entry is kept as produced, and decryption must
    raise ``DecryptionError``.
    """
    plain = sample_entry_factory(session_alice, 0, 0)
    sealed = CacheEncryptor(session_alice).encrypt_entry(plain)

    raw = bytes.fromhex(sealed.ciphertext)
    # Rebuild the bytes with only bit 0 of byte 0 inverted.
    corrupted = bytes([raw[0] ^ 0x01]) + raw[1:]

    forged = EncryptedEntry(
        metadata=sealed.metadata,
        nonce=sealed.nonce,
        ciphertext=corrupted.hex(),
        key_len=sealed.key_len,
        sequence_number=sealed.sequence_number,
    )

    with pytest.raises(DecryptionError):
        CacheDecryptor(session_alice).decrypt_entry(forged)
113
114
def test_replay_detected(session_alice, sample_entry_factory) -> None:
    """A decryptor rejects the second submission of the same entry.

    The first ``decrypt_entry`` call must succeed. Handing the identical
    encrypted entry to the same decryptor again must raise
    ``NonceReplayError``.
    """
    plain = sample_entry_factory(session_alice, 0, 0)
    sealed = CacheEncryptor(session_alice).encrypt_entry(plain)

    decryptor = CacheDecryptor(session_alice)
    decryptor.decrypt_entry(sealed)  # first delivery is accepted

    with pytest.raises(NonceReplayError):
        decryptor.decrypt_entry(sealed)  # identical resubmission
122