src/pqc_gpu_driver/channel.py
5.1 KB · 155 lines · python Raw
1 """ChannelSession - encrypted CPU<->GPU channel using ML-KEM-derived AES-256-GCM key."""
2
3 from __future__ import annotations
4
5 import hashlib
6 import json
7 import os
8 import uuid
9 from dataclasses import dataclass, field
10 from datetime import datetime, timedelta, timezone
11
12 from cryptography.hazmat.primitives.ciphers.aead import AESGCM
13 from quantumshield.core.algorithms import KEMAlgorithm
14 from quantumshield.core.keys import generate_kem_keypair
15
16 from pqc_gpu_driver.errors import (
17 ChannelExpiredError,
18 DecryptionError,
19 NonceReplayError,
20 )
21 from pqc_gpu_driver.tensor import EncryptedTensor, TensorMetadata
22
# Length in bytes (96 bits) of the random AES-GCM nonce drawn per message.
NONCE_SIZE = 12
# Default channel lifetime used by establish_channel(): one hour, in seconds.
SESSION_TTL_SECONDS = 60 * 60
25
26
def establish_channel(
    cpu_side_label: str = "cpu",
    gpu_side_label: str = "gpu",
    algorithm: KEMAlgorithm = KEMAlgorithm.ML_KEM_768,
    ttl_seconds: int = SESSION_TTL_SECONDS,
) -> tuple[ChannelSession, ChannelSession]:
    """Create the CPU-side and GPU-side ends of one encrypted channel.

    The two returned sessions share the session id, the 32-byte symmetric
    key, the algorithm name and the created/expires timestamps. They differ
    only in ``peer_label``: each end records the *other* end's label. As a
    result, either end can decrypt what the other encrypts.

    This is a test-friendly stand-in for a real key exchange. In production
    the CPU side would run ML-KEM encapsulation to the GPU's public key and
    both sides would obtain the key from ML-KEM.Decapsulate. Here the key is
    SHA3-256(private_key || public_key) of one freshly generated keypair, so
    no liboqs handshake is needed.

    Args:
        cpu_side_label: Label of the CPU end (stored as the GPU end's peer).
        gpu_side_label: Label of the GPU end (stored as the CPU end's peer).
        algorithm: KEM parameter set used to generate the keypair; its
            ``.value`` is recorded on both sessions.
        ttl_seconds: Session lifetime. ``expires_at`` is creation time plus
            this many seconds. A non-positive value is not rejected and
            yields a session that is already (or immediately) expired.

    Returns:
        ``(cpu_session, gpu_session)``.
    """
    keypair = generate_kem_keypair(algorithm)
    # Deterministic 32-byte key standing in for the ML-KEM shared secret.
    symmetric_key = hashlib.sha3_256(
        keypair.private_key + keypair.public_key
    ).digest()

    session_id = f"urn:pqc-gpu-sess:{uuid.uuid4().hex}"
    issued_at = datetime.now(timezone.utc)
    common = dict(
        session_id=session_id,
        symmetric_key=symmetric_key,
        algorithm=algorithm.value,
        created_at=issued_at.isoformat(),
        expires_at=(issued_at + timedelta(seconds=ttl_seconds)).isoformat(),
    )

    # Each end points at the opposite side as its peer.
    cpu_end = ChannelSession(peer_label=gpu_side_label, **common)
    gpu_end = ChannelSession(peer_label=cpu_side_label, **common)
    return cpu_end, gpu_end
65
66
@dataclass
class ChannelSession:
    """One end of an encrypted CPU<->GPU channel.

    Encryption scheme:
      - AES-256-GCM, one ``encrypt`` call per tensor transfer.
      - Nonce: ``NONCE_SIZE`` random bytes per message, shipped hex-encoded
        alongside the ciphertext.
      - AAD: canonical JSON of ``metadata.to_dict()`` plus the sequence
        number, which binds the metadata and the message ordering to the
        ciphertext.
      - Replay protection on receive: a sequence number is accepted only if it
        is strictly greater than the last accepted one (gaps are allowed).
        Nonces found in a bounded window of recently accepted nonces are also
        rejected.

    NOTE(review): sessions from ``establish_channel`` share one key for both
    directions, and the AAD does not identify the sender. A message could
    therefore presumably be reflected back to the end that sent it and still
    authenticate. Direction-specific keys or a sender label in the AAD would
    close this; confirm whether the transport already prevents it.
    """

    # Maximum number of recently accepted nonces remembered for replay checks.
    _NONCE_WINDOW = 1024

    session_id: str
    peer_label: str
    # repr=False keeps the raw AES key out of repr(), logs and tracebacks.
    symmetric_key: bytes = field(repr=False)
    algorithm: str
    created_at: str
    expires_at: str
    next_send_seq: int = 1
    last_recv_seq: int = 0
    # Canonical (lowercase) hex nonces of recently accepted messages.
    _used_nonces_recent: list[str] = field(default_factory=list)

    def is_valid(self) -> bool:
        """Return True while the current UTC time is at or before ``expires_at``.

        Fails closed: a non-string, unparseable or timezone-naive
        ``expires_at`` yields False instead of raising. A naive timestamp
        cannot be compared with the aware UTC clock. A trailing ``Z`` is
        accepted as UTC on every Python version.
        """
        raw = self.expires_at
        if not isinstance(raw, str):
            return False
        if raw.endswith("Z"):
            # datetime.fromisoformat only understands "Z" from Python 3.11 on.
            raw = raw[:-1] + "+00:00"
        try:
            exp = datetime.fromisoformat(raw)
        except ValueError:
            return False
        if exp.utcoffset() is None:
            # Comparing naive with aware datetimes raises TypeError.
            return False
        return datetime.now(timezone.utc) <= exp

    def _check_valid(self) -> None:
        """Raise ChannelExpiredError unless ``is_valid()`` is True."""
        if not self.is_valid():
            raise ChannelExpiredError(f"session {self.session_id} is expired")

    @staticmethod
    def _aad(metadata: TensorMetadata, sequence_number: int) -> bytes:
        """Build the AES-GCM associated data for one message.

        Returns compact, key-sorted UTF-8 JSON of
        ``{"metadata": metadata.to_dict(), "sequence_number": n}``. This makes
        the encoding byte-stable, so sender and receiver derive identical AAD.
        """
        payload = {
            "metadata": metadata.to_dict(),
            "sequence_number": sequence_number,
        }
        return json.dumps(
            payload, sort_keys=True, separators=(",", ":"), ensure_ascii=False
        ).encode("utf-8")

    def encrypt_tensor(
        self, tensor_bytes: bytes, metadata: TensorMetadata
    ) -> EncryptedTensor:
        """Encrypt tensor bytes for transmission across PCIe.

        Consumes the next send sequence number, draws a fresh random nonce and
        authenticates ``metadata`` and the sequence number as AAD.

        Raises:
            ChannelExpiredError: The session has expired.
        """
        self._check_valid()
        nonce = os.urandom(NONCE_SIZE)
        seq = self.next_send_seq
        self.next_send_seq += 1
        aes = AESGCM(self.symmetric_key)
        aad = self._aad(metadata, seq)
        ct = aes.encrypt(nonce, tensor_bytes, aad)
        return EncryptedTensor(
            metadata=metadata,
            nonce=nonce.hex(),
            ciphertext=ct.hex(),
            sequence_number=seq,
        )

    def decrypt_tensor(self, enc: EncryptedTensor) -> bytes:
        """Authenticate and decrypt a tensor received over the channel.

        Receive state (last sequence number and the recent-nonce window) is
        updated only after authentication succeeds.

        Raises:
            ChannelExpiredError: The session has expired.
            NonceReplayError: The sequence number is not strictly greater than
                the last accepted one, or the nonce was seen recently.
            DecryptionError: The nonce or ciphertext is malformed, the nonce
                has the wrong length, or AES-GCM authentication fails.
        """
        self._check_valid()
        if enc.sequence_number <= self.last_recv_seq:
            raise NonceReplayError(
                f"sequence {enc.sequence_number} <= last_recv {self.last_recv_seq}"
            )

        try:
            nonce = bytes.fromhex(enc.nonce)
        except (TypeError, ValueError) as exc:
            raise DecryptionError(f"AES-GCM decrypt failed: {exc}") from exc
        if len(nonce) != NONCE_SIZE:
            raise DecryptionError(
                f"AES-GCM decrypt failed: nonce must be {NONCE_SIZE} bytes, "
                f"got {len(nonce)}"
            )
        # Compare canonical lowercase hex so a case-altered copy of an already
        # accepted nonce cannot slip past the replay window.
        nonce_hex = nonce.hex()
        if nonce_hex in self._used_nonces_recent:
            raise NonceReplayError(f"nonce {enc.nonce} already seen")

        aes = AESGCM(self.symmetric_key)
        aad = self._aad(enc.metadata, enc.sequence_number)
        try:
            pt = aes.decrypt(nonce, bytes.fromhex(enc.ciphertext), aad)
        except Exception as exc:  # InvalidTag, malformed ciphertext hex, ...
            raise DecryptionError(f"AES-GCM decrypt failed: {exc}") from exc

        self.last_recv_seq = enc.sequence_number
        self._used_nonces_recent.append(nonce_hex)
        if len(self._used_nonces_recent) > self._NONCE_WINDOW:
            self._used_nonces_recent = self._used_nonces_recent[
                -self._NONCE_WINDOW:
            ]
        return pt
155