examples/tensor_tamper_detection.py
1.4 KB · 50 lines · python Raw
1 """Show that tampering with a single ciphertext byte of an encrypted tensor is detected.
2
3 AES-256-GCM ties ciphertext and AAD (metadata + sequence number) into one
4 authentication tag. Any bit flip in ciphertext or tampering with metadata
5 causes decrypt_tensor() to raise DecryptionError.
6
7 Run:
8
9 python examples/tensor_tamper_detection.py
10 """
11
12 from __future__ import annotations
13
14 import os
15
16 from pqc_gpu_driver import DecryptionError, TensorMetadata, establish_channel
17
18
def main() -> None:
    """Run the tamper-detection demo end to end and print the outcome.

    Steps, in order:

    1. Obtain the two channel endpoints from ``establish_channel()``.
    2. Encrypt 512 random bytes, described by a ``TensorMetadata`` record,
       on the first endpoint.
    3. Decode the hex ciphertext, invert every bit of its first byte, and
       write the re-encoded result back onto the encrypted object.
    4. Ask the second endpoint to decrypt and report whether it raised
       ``DecryptionError`` (tamper detected) or returned normally (tamper
       missed).

    Nothing is returned; all results are written to stdout.
    """
    sender, receiver = establish_channel()
    payload = os.urandom(512)
    metadata = TensorMetadata(
        tensor_id="t-tamper",
        name="model.fc.weight",
        dtype="float32",
        shape=(128,),
        size_bytes=len(payload),
    )

    print("[*] Encrypting tensor on CPU side ...")
    sealed = sender.encrypt_tensor(payload, metadata)
    print(f" original ciphertext prefix = {sealed.ciphertext[:32]}...")

    print("\n[*] Attacker flips one byte of ciphertext over PCIe ...")
    raw = bytes.fromhex(sealed.ciphertext)
    # XOR with 0xFF inverts all eight bits of the first byte; the remaining
    # bytes are carried over unchanged.
    corrupted = bytes([raw[0] ^ 0xFF]) + raw[1:]
    sealed.ciphertext = corrupted.hex()
    print(f" tampered ciphertext prefix = {sealed.ciphertext[:32]}...")

    print("\n[*] GPU side attempts decryption ...")
    try:
        receiver.decrypt_tensor(sealed)
    except DecryptionError as err:
        print(f"[+] Tamper detected. DecryptionError: {err}")
    else:
        # Reached only when decryption returned without raising.
        print("[-] FAIL: tamper went undetected.")
46
47
# Run the demo only when this file is executed as a script, so importing
# the module does not trigger it.
if __name__ == "__main__":
    main()
50