tests/test_tensor.py
1.4 KB · 44 lines · python Raw
1 """Tests for the EncryptedTensor / TensorMetadata envelope."""
2
3 from __future__ import annotations
4
5 from pqc_gpu_driver import EncryptedTensor, TensorMetadata
6
7
def test_tensor_metadata_roundtrip() -> None:
    """A fully populated TensorMetadata equals itself after to_dict -> from_dict."""
    rows, cols = 768, 3072
    fields = {
        "tensor_id": "t-1",
        "name": "model.dense_1.weights",
        "dtype": "float16",
        "shape": (rows, cols),
        # Two bytes per element, consistent with the float16 dtype above.
        "size_bytes": rows * cols * 2,
        "transfer_direction": "cpu_to_gpu",
    }
    original = TensorMetadata(**fields)

    serialized = original.to_dict()
    rebuilt = TensorMetadata.from_dict(serialized)

    assert rebuilt == original
19
20
def test_encrypted_tensor_roundtrip() -> None:
    """EncryptedTensor keeps metadata, nonce, ciphertext and sequence number
    across a to_dict -> from_dict round trip."""
    inner_meta = TensorMetadata(tensor_id="t-2", shape=(4, 4), size_bytes=64)
    nonce_value = "a" * 24
    ciphertext_value = "deadbeef" * 8
    envelope = EncryptedTensor(
        metadata=inner_meta,
        nonce=nonce_value,
        ciphertext=ciphertext_value,
        sequence_number=7,
    )

    payload = envelope.to_dict()
    restored = EncryptedTensor.from_dict(payload)

    # Compare field by field, against the source envelope's own attributes.
    assert restored.metadata == inner_meta
    assert restored.nonce == envelope.nonce
    assert restored.ciphertext == envelope.ciphertext
    assert restored.sequence_number == 7
34
35
def test_tensor_metadata_preserves_shape_tuple() -> None:
    """Shape is emitted as a list by to_dict and comes back as a tuple from from_dict."""
    expected_dims = (1, 2, 3, 4)
    source_meta = TensorMetadata(tensor_id="t-3", shape=(1, 2, 3, 4))

    # Serialized form: shape is a plain list (JSON-compatible).
    serialized_shape = source_meta.to_dict()["shape"]
    assert serialized_shape == [1, 2, 3, 4]

    # Deserialized form: shape is a tuple again, with the same dimensions.
    round_tripped = TensorMetadata.from_dict(source_meta.to_dict())
    assert isinstance(round_tripped.shape, tuple)
    assert round_tripped.shape == expected_dims
44