tests/test_backends.py
2.2 KB · 81 lines · python Raw
1 """Tests for the GPU backends."""
2
3 from __future__ import annotations
4
5 import pytest
6
7 from pqc_gpu_driver import (
8 BackendError,
9 CUDABackend,
10 InMemoryBackend,
11 ROCmBackend,
12 TensorMetadata,
13 establish_channel,
14 )
15
16
def _fake_tensor(random_tensor_bytes: bytes):
    """Encrypt ``random_tensor_bytes`` on a freshly established channel.

    A new channel pair is created with ``establish_channel()`` on every call;
    only the CPU side is used, the GPU side is discarded. The payload is
    described as a 1-D ``float32`` tensor named ``"some.weights"`` with id
    ``"t-backend"``.

    Returns:
        Whatever ``encrypt_tensor`` produces. The backend tests read its
        ``ciphertext``, ``nonce`` and ``sequence_number`` attributes.
    """
    cpu_side, _ = establish_channel()
    payload_len = len(random_tensor_bytes)
    # NOTE(review): element count uses floor division by the float32 width, so
    # a payload whose length is not a multiple of 4 would yield metadata where
    # shape * 4 != size_bytes — presumably the fixture guarantees alignment;
    # confirm.
    metadata = TensorMetadata(
        tensor_id="t-backend",
        name="some.weights",
        dtype="float32",
        shape=(payload_len // 4,),
        size_bytes=payload_len,
    )
    encrypted = cpu_side.encrypt_tensor(random_tensor_bytes, metadata)
    return encrypted
27
28
def test_in_memory_backend_upload_download_roundtrip(
    random_tensor_bytes: bytes,
) -> None:
    """Upload then download through ``InMemoryBackend`` returns an equal tensor.

    Checks the ``"mem:"`` handle prefix, that ciphertext, nonce and sequence
    number survive the round trip unchanged, and what ``device_info()``
    reports afterwards.
    """
    backend = InMemoryBackend()
    original = _fake_tensor(random_tensor_bytes)

    mem_handle = backend.upload(original)
    assert mem_handle.startswith("mem:")

    restored = backend.download(mem_handle)
    for attr in ("ciphertext", "nonce", "sequence_number"):
        assert getattr(restored, attr) == getattr(original, attr), attr

    details = backend.device_info()
    assert details["device_type"] == "in-memory"
    # Downloading must not consume the handle: the single upload is still live.
    assert details["live_handles"] == 1
43
44
def test_in_memory_backend_free_removes_handle(
    random_tensor_bytes: bytes,
) -> None:
    """Freeing a handle invalidates it and releases the backend's bookkeeping.

    After ``free`` the handle can no longer be downloaded or freed a second
    time (both raise ``BackendError``), and ``device_info()`` no longer
    counts it among the live handles.
    """
    be = InMemoryBackend()
    handle = be.upload(_fake_tensor(random_tensor_bytes))
    # Sanity check so the post-free assertion below is meaningful.
    assert be.device_info()["live_handles"] == 1
    be.free(handle)
    # The live-handle count must drop back to zero; otherwise free() would
    # leak its entry even though the handle itself is rejected afterwards.
    assert be.device_info()["live_handles"] == 0
    with pytest.raises(BackendError):
        be.download(handle)
    with pytest.raises(BackendError):
        be.free(handle)
55
56
def test_cuda_backend_raises_backend_error(random_tensor_bytes: bytes) -> None:
    """Every ``CUDABackend`` operation is expected to raise ``BackendError``.

    Exercises upload, download, free and device_info in that order. Each call
    is checked under its own ``pytest.raises`` context.
    """
    backend = CUDABackend(device_index=0)
    payload = _fake_tensor(random_tensor_bytes)
    checks = (
        ("upload", lambda: backend.upload(payload)),
        ("download", lambda: backend.download("cuda:fake")),
        ("free", lambda: backend.free("cuda:fake")),
        ("device_info", lambda: backend.device_info()),
    )
    for _label, call in checks:
        # _label only identifies the failing operation in tracebacks
        # (e.g. with ``pytest --showlocals``).
        with pytest.raises(BackendError):
            call()
68
69
def test_rocm_backend_raises_backend_error(random_tensor_bytes: bytes) -> None:
    """Every ``ROCmBackend`` operation is expected to raise ``BackendError``.

    Exercises upload, download, free and device_info in that order. Each call
    is checked under its own ``pytest.raises`` context.
    """
    backend = ROCmBackend(device_index=0)
    payload = _fake_tensor(random_tensor_bytes)
    checks = (
        ("upload", lambda: backend.upload(payload)),
        ("download", lambda: backend.download("hip:fake")),
        ("free", lambda: backend.free("hip:fake")),
        ("device_info", lambda: backend.device_info()),
    )
    for _label, call in checks:
        # _label only identifies the failing operation in tracebacks
        # (e.g. with ``pytest --showlocals``).
        with pytest.raises(BackendError):
            call()
81