tests/test_channel.py
3.8 KB · 121 lines · python Raw
1 """Tests for ChannelSession and establish_channel()."""
2
3 from __future__ import annotations
4
5 import time
6
7 import pytest
8
9 from pqc_gpu_driver import (
10 ChannelExpiredError,
11 DecryptionError,
12 NonceReplayError,
13 TensorMetadata,
14 establish_channel,
15 )
16
17
def _meta(tid: str, size: int) -> TensorMetadata:
    """Build float32, CPU-to-GPU metadata for a flat tensor of ``size`` bytes.

    The shape is one-dimensional with ``size // 4`` elements, because a
    float32 is 4 bytes wide. Any remainder bytes are not reflected in the
    shape, while ``size_bytes`` reports ``size`` unchanged.
    """
    element_count = size // 4
    return TensorMetadata(
        tensor_id=tid,
        name="model.layer_0.weights",
        dtype="float32",
        shape=(element_count,),
        size_bytes=size,
        transfer_direction="cpu_to_gpu",
    )
27
28
def test_establish_channel_returns_matching_sessions() -> None:
    """Both ends of a fresh channel agree on session, key and algorithm."""
    cpu_side, gpu_side = establish_channel()

    # Values derived from the key exchange must be identical on both ends.
    assert cpu_side.session_id == gpu_side.session_id
    assert cpu_side.symmetric_key == gpu_side.symmetric_key

    # Each end is labelled with the identity of its counterpart.
    assert (cpu_side.peer_label, gpu_side.peer_label) == ("gpu", "cpu")
    assert cpu_side.algorithm == "ML-KEM-768"

    # A freshly established pair of sessions is usable immediately.
    assert all(side.is_valid() for side in (cpu_side, gpu_side))
38
39
def test_encrypt_decrypt_roundtrip(random_tensor_bytes: bytes) -> None:
    """A tensor encrypted on the CPU side decrypts unchanged on the GPU side."""
    cpu, gpu = establish_channel()
    envelope = cpu.encrypt_tensor(
        random_tensor_bytes, _meta("t-1", len(random_tensor_bytes))
    )

    assert gpu.decrypt_tensor(envelope) == random_tensor_bytes
    # The first message on a channel is numbered 1 ...
    assert envelope.sequence_number == 1
    # ... and the sender's counter has moved on to the next number.
    assert cpu.next_send_seq == 2
48
49
def test_decrypt_with_wrong_nonce_fails(random_tensor_bytes: bytes) -> None:
    """Decryption fails when the envelope's nonce is altered after encryption."""
    cpu, gpu = establish_channel()
    enc = cpu.encrypt_tensor(
        random_tensor_bytes, _meta("t-1", len(random_tensor_bytes))
    )

    # Replace the first character of the nonce string: use "0", or "1" if it
    # already is "0". This guarantees the tampered nonce differs.
    head, tail = enc.nonce[0], enc.nonce[1:]
    enc.nonce = ("1" if head == "0" else "0") + tail

    with pytest.raises(DecryptionError):
        gpu.decrypt_tensor(enc)
59
60
def test_aad_tamper_detected(random_tensor_bytes: bytes) -> None:
    """Substituting the envelope's metadata after encryption is detected.

    The metadata acts as associated data (AAD), so a post-encryption swap
    must surface as a DecryptionError on the receiving side.
    """
    cpu, gpu = establish_channel()
    original = _meta("t-1", len(random_tensor_bytes))
    enc = cpu.encrypt_tensor(random_tensor_bytes, original)

    # Forge the identity fields; copy every other field verbatim.
    forged = TensorMetadata(
        tensor_id="t-999",
        name="attacker.renamed",
        dtype=original.dtype,
        shape=original.shape,
        size_bytes=original.size_bytes,
        transfer_direction=original.transfer_direction,
    )
    enc.metadata = forged

    with pytest.raises(DecryptionError):
        gpu.decrypt_tensor(enc)
77
78
def test_replay_rejected(random_tensor_bytes: bytes) -> None:
    """Delivering the same envelope twice is refused on the second delivery."""
    cpu, gpu = establish_channel()
    envelope = cpu.encrypt_tensor(
        random_tensor_bytes, _meta("t-1", len(random_tensor_bytes))
    )

    first_delivery = gpu.decrypt_tensor(envelope)
    assert first_delivery == random_tensor_bytes

    # Re-sending the identical envelope is a replay.
    with pytest.raises(NonceReplayError):
        gpu.decrypt_tensor(envelope)
87
88
def test_lower_sequence_rejected(random_tensor_bytes: bytes) -> None:
    """An envelope older than one already accepted is rejected as a replay.

    Two envelopes are produced in order and the newer one is delivered first.
    The older one must then be refused with NonceReplayError.
    """
    cpu, gpu = establish_channel()
    meta = _meta("t-1", len(random_tensor_bytes))
    enc1 = cpu.encrypt_tensor(random_tensor_bytes, meta)
    enc2 = cpu.encrypt_tensor(random_tensor_bytes, meta)
    # Precondition: enc1 really carries the lower sequence number, so the
    # rejection below exercises the "lower sequence" path as intended.
    assert enc1.sequence_number < enc2.sequence_number

    # Out-of-order delivery: the newer envelope must still decrypt correctly ...
    assert gpu.decrypt_tensor(enc2) == random_tensor_bytes
    # ... but the older one has a lower sequence number than one already
    # accepted and must be refused.
    with pytest.raises(NonceReplayError):
        gpu.decrypt_tensor(enc1)
98
99
def test_expired_session_raises(random_tensor_bytes: bytes) -> None:
    """Encrypting on a session whose TTL has elapsed raises ChannelExpiredError."""
    cpu, _gpu = establish_channel(ttl_seconds=0)
    # With a zero TTL the session expires essentially at creation. Sleep
    # briefly so the clock is past expires_at even with coarse clock
    # granularity.
    time.sleep(0.05)
    meta = _meta("t-1", len(random_tensor_bytes))

    with pytest.raises(ChannelExpiredError):
        cpu.encrypt_tensor(random_tensor_bytes, meta)
107
108
def test_sequence_numbers_increment(random_tensor_bytes: bytes) -> None:
    """Successive encryptions get consecutive sequence numbers starting at 1."""
    cpu, _gpu = establish_channel()
    meta = _meta("t-1", len(random_tensor_bytes))
    envelopes = [cpu.encrypt_tensor(random_tensor_bytes, meta) for _ in range(3)]

    assert [env.sequence_number for env in envelopes] == [1, 2, 3]
    # After three sends, the sender's counter points at the next unused number.
    assert cpu.next_send_seq == 4
121