tests/test_tensor.py
| 1 | """Tests for the EncryptedTensor / TensorMetadata envelope.""" |
| 2 | |
| 3 | from __future__ import annotations |
| 4 | |
| 5 | from pqc_gpu_driver import EncryptedTensor, TensorMetadata |
| 6 | |
| 7 | |
def test_tensor_metadata_roundtrip() -> None:
    """A fully populated TensorMetadata survives a to_dict/from_dict round trip."""
    rows, cols = 768, 3072
    original = TensorMetadata(
        tensor_id="t-1",
        name="model.dense_1.weights",
        dtype="float16",
        shape=(rows, cols),
        size_bytes=rows * cols * 2,  # float16 occupies 2 bytes per element
        transfer_direction="cpu_to_gpu",
    )
    serialized = original.to_dict()
    assert TensorMetadata.from_dict(serialized) == original
| 19 | |
| 20 | |
def test_encrypted_tensor_roundtrip() -> None:
    """EncryptedTensor.from_dict(to_dict()) preserves metadata, nonce, ciphertext and sequence."""
    metadata = TensorMetadata(tensor_id="t-2", shape=(4, 4), size_bytes=64)
    envelope = EncryptedTensor(
        metadata=metadata,
        nonce="a" * 24,
        ciphertext="deadbeef" * 8,
        sequence_number=7,
    )

    restored = EncryptedTensor.from_dict(envelope.to_dict())

    # Compare each field explicitly; the nested metadata is checked by equality.
    assert restored.metadata == metadata
    assert restored.nonce == envelope.nonce
    assert restored.ciphertext == envelope.ciphertext
    assert restored.sequence_number == 7
| 34 | |
| 35 | |
def test_tensor_metadata_preserves_shape_tuple() -> None:
    """Shape is emitted as a list by to_dict and restored as a tuple by from_dict."""
    dims = (1, 2, 3, 4)
    meta = TensorMetadata(tensor_id="t-3", shape=dims)

    # Serialized form carries a list, since JSON has no tuple type.
    assert meta.to_dict()["shape"] == list(dims)

    # Deserializing must hand back a tuple again, not the list.
    restored = TensorMetadata.from_dict(meta.to_dict())
    assert isinstance(restored.shape, tuple)
    assert restored.shape == dims
| 44 | |