src/pqc_gpu_driver/channel.py
| 1 | """ChannelSession - encrypted CPU<->GPU channel using ML-KEM-derived AES-256-GCM key.""" |
| 2 | |
| 3 | from __future__ import annotations |
| 4 | |
| 5 | import hashlib |
| 6 | import json |
| 7 | import os |
| 8 | import uuid |
| 9 | from dataclasses import dataclass, field |
| 10 | from datetime import datetime, timedelta, timezone |
| 11 | |
| 12 | from cryptography.hazmat.primitives.ciphers.aead import AESGCM |
| 13 | from quantumshield.core.algorithms import KEMAlgorithm |
| 14 | from quantumshield.core.keys import generate_kem_keypair |
| 15 | |
| 16 | from pqc_gpu_driver.errors import ( |
| 17 | ChannelExpiredError, |
| 18 | DecryptionError, |
| 19 | NonceReplayError, |
| 20 | ) |
| 21 | from pqc_gpu_driver.tensor import EncryptedTensor, TensorMetadata |
| 22 | |
# Length in bytes of the random nonce generated for every encrypted tensor.
NONCE_SIZE = 12
# Default session lifetime in seconds (default for establish_channel's ttl_seconds).
SESSION_TTL_SECONDS = 3600
| 25 | |
| 26 | |
def establish_channel(
    cpu_side_label: str = "cpu",
    gpu_side_label: str = "gpu",
    algorithm: KEMAlgorithm = KEMAlgorithm.ML_KEM_768,
    ttl_seconds: int = SESSION_TTL_SECONDS,
) -> tuple[ChannelSession, ChannelSession]:
    """Build the CPU-side and GPU-side sessions of one encrypted channel.

    Both returned sessions carry the same session id, symmetric key,
    algorithm name and creation/expiry timestamps; they differ only in
    ``peer_label``, which names the *other* end of the channel.

    In production the CPU side runs ML-KEM encapsulation to the GPU's public
    key. Here the key is instead derived from a freshly generated ML-KEM
    keypair so tests work without liboqs.

    Args:
        cpu_side_label: Label of the CPU endpoint (peer of the GPU session).
        gpu_side_label: Label of the GPU endpoint (peer of the CPU session).
        algorithm: KEM algorithm used to generate the keypair; its ``value``
            is recorded on both sessions.
        ttl_seconds: Session lifetime in seconds, counted from now (UTC).

    Returns:
        ``(cpu_session, gpu_session)``.
    """
    keypair = generate_kem_keypair(algorithm)
    # Stand-in for the ML-KEM shared secret: a 32-byte SHA3-256 digest over
    # the keypair bytes. A real deployment would use the output of
    # ML-KEM.Decapsulate on both sides instead.
    key_material = hashlib.sha3_256(
        keypair.private_key + keypair.public_key
    ).digest()

    channel_id = f"urn:pqc-gpu-sess:{uuid.uuid4().hex}"
    created = datetime.now(timezone.utc)
    expiry = created + timedelta(seconds=ttl_seconds)

    # Everything except the peer label is identical on both ends.
    common = {
        "session_id": channel_id,
        "symmetric_key": key_material,
        "algorithm": algorithm.value,
        "created_at": created.isoformat(),
        "expires_at": expiry.isoformat(),
    }
    return (
        ChannelSession(peer_label=gpu_side_label, **common),
        ChannelSession(peer_label=cpu_side_label, **common),
    )
| 65 | |
| 66 | |
@dataclass
class ChannelSession:
    """One side of an encrypted CPU<->GPU channel.

    Encryption:
      - AES-256-GCM per tensor transfer
      - Nonce = 12 bytes random (unique per message); stored with ciphertext
      - AAD = canonical bytes of TensorMetadata + sequence_number
        (binds metadata + ordering)
      - Sequence number enforced monotonically on recv side (replay protection)

    ``symmetric_key`` and the nonce history are left out of the generated
    ``repr`` so that logging or printing a session never exposes key material.
    """

    # Upper bound on how many recently accepted nonces are remembered.
    # Deliberately unannotated so the dataclass does not treat it as a field.
    _RECENT_NONCE_LIMIT = 1024

    session_id: str
    peer_label: str
    symmetric_key: bytes = field(repr=False)
    algorithm: str
    created_at: str
    expires_at: str
    next_send_seq: int = 1
    last_recv_seq: int = 0
    _used_nonces_recent: list[str] = field(default_factory=list, repr=False)

    def is_valid(self) -> bool:
        """Return True while the current UTC time is at or before ``expires_at``.

        Fails closed: an ``expires_at`` that cannot be parsed as ISO-8601, is
        not a string, or carries no timezone (and therefore cannot be compared
        with an aware "now") makes the session invalid instead of raising.
        """
        try:
            expiry = datetime.fromisoformat(self.expires_at)
            return datetime.now(timezone.utc) <= expiry
        except (TypeError, ValueError):
            # ValueError: malformed timestamp.
            # TypeError: non-string value, or naive-vs-aware comparison.
            return False

    def _check_valid(self) -> None:
        """Raise ChannelExpiredError unless ``is_valid()`` is true."""
        if not self.is_valid():
            raise ChannelExpiredError(f"session {self.session_id} is expired")

    @staticmethod
    def _aad(metadata: TensorMetadata, sequence_number: int) -> bytes:
        """Return the AES-GCM associated data for one message.

        The result is compact JSON with sorted keys, UTF-8 encoded, containing
        ``metadata.to_dict()`` and the sequence number, so both sides derive
        byte-identical AAD from the same inputs.
        """
        payload = {
            "metadata": metadata.to_dict(),
            "sequence_number": sequence_number,
        }
        return json.dumps(
            payload, sort_keys=True, separators=(",", ":"), ensure_ascii=False
        ).encode("utf-8")

    def encrypt_tensor(
        self, tensor_bytes: bytes, metadata: TensorMetadata
    ) -> EncryptedTensor:
        """Encrypt tensor bytes for transmission across PCIe.

        Uses a fresh random nonce and the next send sequence number; the
        sequence number is bound into the AAD together with ``metadata``.

        Returns:
            An EncryptedTensor holding the metadata, hex-encoded nonce,
            hex-encoded ciphertext and the sequence number used.

        Raises:
            ChannelExpiredError: If the session is no longer valid.
        """
        self._check_valid()
        nonce = os.urandom(NONCE_SIZE)
        # The sequence number is consumed before encrypting, so a failed
        # encryption leaves a gap rather than a reused number.
        seq = self.next_send_seq
        self.next_send_seq += 1
        aes = AESGCM(self.symmetric_key)
        aad = self._aad(metadata, seq)
        ct = aes.encrypt(nonce, tensor_bytes, aad)
        return EncryptedTensor(
            metadata=metadata,
            nonce=nonce.hex(),
            ciphertext=ct.hex(),
            sequence_number=seq,
        )

    def decrypt_tensor(self, enc: EncryptedTensor) -> bytes:
        """Decrypt a tensor received over the channel.

        Enforces strict monotonicity of sequence numbers to prevent replay,
        and additionally rejects any nonce seen among the most recently
        accepted messages. Receive state (last sequence number, nonce
        history) is only advanced after authentication succeeds.

        Returns:
            The decrypted tensor bytes.

        Raises:
            ChannelExpiredError: If the session is no longer valid.
            NonceReplayError: If the sequence number is not greater than the
                last accepted one, or the nonce was recently accepted.
            DecryptionError: If the nonce/ciphertext are not valid hex or
                AES-GCM authentication/decryption fails.
        """
        self._check_valid()
        if enc.sequence_number <= self.last_recv_seq:
            raise NonceReplayError(
                f"sequence {enc.sequence_number} <= last_recv {self.last_recv_seq}"
            )

        # Decode the nonce first and compare its canonical (lowercase,
        # whitespace-free) hex form: bytes.fromhex accepts several spellings
        # of the same bytes, so comparing the raw string could miss a replay.
        try:
            nonce = bytes.fromhex(enc.nonce)
        except (TypeError, ValueError) as exc:
            raise DecryptionError(f"AES-GCM decrypt failed: {exc}") from exc
        nonce_hex = nonce.hex()
        if nonce_hex in self._used_nonces_recent:
            raise NonceReplayError(f"nonce {enc.nonce} already seen")

        aes = AESGCM(self.symmetric_key)
        aad = self._aad(enc.metadata, enc.sequence_number)
        try:
            pt = aes.decrypt(nonce, bytes.fromhex(enc.ciphertext), aad)
        except Exception as exc:
            # Translation boundary: any failure (bad hex, bad tag, bad nonce
            # length) is reported uniformly, with the cause chained.
            raise DecryptionError(f"AES-GCM decrypt failed: {exc}") from exc

        self.last_recv_seq = enc.sequence_number
        self._used_nonces_recent.append(nonce_hex)
        if len(self._used_nonces_recent) > self._RECENT_NONCE_LIMIT:
            # Drop the oldest entries in place, keeping the newest ones.
            del self._used_nonces_recent[: -self._RECENT_NONCE_LIMIT]
        return pt
| 155 | |