src/pqc_gpu_driver/tensor.py
2.0 KB · 62 lines · python Raw
1 """Encrypted tensor envelope for CPU<->GPU transfers."""
2
3 from __future__ import annotations
4
5 from dataclasses import asdict, dataclass
6 from typing import Any
7
8
@dataclass(frozen=True)
class TensorMetadata:
    """Non-secret metadata about a tensor transfer.

    Instances are immutable (``frozen=True``) and hashable. ``shape`` is
    normalized to a tuple on construction, so a list supplied by a caller
    neither breaks hashing nor compares unequal to the equivalent tuple.
    """

    tensor_id: str  # stable id within a session
    name: str = ""  # e.g. "model.dense_1.weights"
    dtype: str = "float32"
    shape: tuple[int, ...] = ()
    size_bytes: int = 0
    transfer_direction: str = "cpu_to_gpu"  # "cpu_to_gpu" | "gpu_to_cpu"

    def __post_init__(self) -> None:
        # The generated __hash__ hashes every field, so a list shape would
        # raise TypeError; frozen dataclasses require object.__setattr__ here.
        if not isinstance(self.shape, tuple):
            object.__setattr__(self, "shape", tuple(self.shape))

    def to_dict(self) -> dict[str, Any]:
        """Return all fields as a plain dict, with ``shape`` as a list."""
        d = asdict(self)
        d["shape"] = list(self.shape)
        return d

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> TensorMetadata:
        """Build metadata from a dict shaped like the output of ``to_dict``.

        Only ``tensor_id`` is required; every other key falls back to the
        field default. ``shape`` entries and ``size_bytes`` are coerced with
        ``int()``, and a missing or ``None`` shape becomes ``()``.

        Raises:
            KeyError: if ``tensor_id`` is missing.
            ValueError, TypeError: if a shape entry or ``size_bytes`` cannot
                be converted with ``int()``.
        """
        raw_shape = data.get("shape")
        shape = () if raw_shape is None else tuple(int(dim) for dim in raw_shape)
        return cls(
            tensor_id=data["tensor_id"],
            name=data.get("name", ""),
            dtype=data.get("dtype", "float32"),
            shape=shape,
            size_bytes=int(data.get("size_bytes", 0)),
            transfer_direction=data.get("transfer_direction", "cpu_to_gpu"),
        )
35
36
@dataclass
class EncryptedTensor:
    """Wire envelope for one AES-256-GCM encrypted tensor.

    Carries the hex-encoded nonce and ciphertext together with the tensor's
    metadata (which, per the original design, is authenticated as AAD) and a
    per-session sequence number. This class only (de)serializes; it performs
    no cryptography itself.
    """

    metadata: TensorMetadata
    nonce: str  # hex-encoded; 12 bytes -> 24 hex characters
    ciphertext: str  # hex-encoded
    sequence_number: int  # strictly increasing within a session

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a plain dict (keys: metadata, nonce, ciphertext, sequence_number)."""
        payload: dict[str, Any] = {"metadata": self.metadata.to_dict()}
        payload["nonce"] = self.nonce
        payload["ciphertext"] = self.ciphertext
        payload["sequence_number"] = self.sequence_number
        return payload

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> EncryptedTensor:
        """Rebuild an envelope from a dict produced by ``to_dict``.

        All four keys are required (a missing one raises ``KeyError``), and
        ``sequence_number`` is coerced with ``int()``.
        """
        meta = TensorMetadata.from_dict(data["metadata"])
        nonce_hex = data["nonce"]
        ciphertext_hex = data["ciphertext"]
        seq = int(data["sequence_number"])
        return cls(
            metadata=meta,
            nonce=nonce_hex,
            ciphertext=ciphertext_hex,
            sequence_number=seq,
        )
62