examples/basic_channel.py
1.6 KB · 50 lines · python Raw
"""Establish an encrypted CPU<->GPU channel and round-trip a synthetic tensor.

Run:

    python examples/basic_channel.py
"""
7
8 from __future__ import annotations
9
10 import os
11
12 from pqc_gpu_driver import TensorMetadata, establish_channel
13
14
def main() -> None:
    """Open a CPU<->GPU channel, encrypt a random tensor, decrypt it, and verify.

    Prints the session parameters reported by the CPU side of the channel
    (session id, algorithm, key length, expiry), the sequence number, nonce
    and ciphertext length of the encrypted tensor, and the GPU side's last
    received sequence number.

    Raises:
        AssertionError: if the plaintext recovered on the GPU side differs
            from the original bytes. Raised explicitly rather than via an
            ``assert`` statement so the check still runs under ``python -O``.
    """
    print("[*] Establishing ML-KEM-768 channel between CPU and GPU ...")
    cpu, gpu = establish_channel(
        cpu_side_label="inference-host", gpu_side_label="h100-0"
    )
    print(f" session_id = {cpu.session_id}")
    print(f" algorithm = {cpu.algorithm}")
    print(f" key_bytes = {len(cpu.symmetric_key)} (AES-256-GCM)")
    print(f" expires_at = {cpu.expires_at}")

    # Simulate a weight tensor shipped to the GPU: 512 float32 elements of
    # random bytes. The byte count is derived from the element count so the
    # payload size cannot drift out of sync with `shape`/`dtype` below.
    num_elements = 512
    float32_itemsize = 4
    tensor = os.urandom(num_elements * float32_itemsize)
    meta = TensorMetadata(
        tensor_id="layer_0.q_proj",
        name="model.layers.0.self_attn.q_proj.weight",
        dtype="float32",
        shape=(num_elements,),
        size_bytes=len(tensor),
        transfer_direction="cpu_to_gpu",
    )

    print("\n[*] CPU side encrypting tensor with AES-256-GCM ...")
    enc = cpu.encrypt_tensor(tensor, meta)
    print(f" sequence_number = {enc.sequence_number}")
    print(f" nonce = {enc.nonce}")
    # NOTE(review): `// 2` presumably converts the length of a hex-encoded
    # ciphertext into a byte count -- confirm against the encrypted-tensor
    # type in pqc_gpu_driver; if `ciphertext` is raw bytes this halves it.
    print(f" ciphertext_len = {len(enc.ciphertext) // 2} bytes")

    print("\n[*] GPU side decrypting tensor ...")
    pt = gpu.decrypt_tensor(enc)
    # Explicit check instead of `assert`: assert statements are removed when
    # Python runs with -O, which would let a failed round-trip report success.
    if pt != tensor:
        raise AssertionError("round-trip failed!")
    print(f" decrypted_len = {len(pt)} bytes")
    print(f" last_recv_seq = {gpu.last_recv_seq}")
    print("[+] Round-trip OK. Tensor integrity + confidentiality preserved.")


if __name__ == "__main__":
    main()
50