src/pqc_gpu_driver/backends/rocm.py
2.5 KB · 67 lines · python Raw
1 """ROCm backend (stub interface).
2
3 Real integration uses the AMD ROCm HIP runtime. This stub documents the
4 expected shape; implementers plug in the real HIP runtime calls.
5
6 A production implementation of this backend is expected to:
7
8 * Initialize a HIP context via ``hipInit`` / ``hipSetDevice`` for the target
9 AMD GPU (MI300X, MI325X, or future CDNA-class device).
10 * For :meth:`upload`, allocate device memory with ``hipMalloc`` and copy the
11 ciphertext bytes of the :class:`EncryptedTensor` from pinned host memory
12 with ``hipMemcpy`` (host-to-device). Register the pointer with HIP-IPC if
13 cross-process sharing is required.
14 * For :meth:`download`, issue ``hipMemcpy`` (device-to-host) from the device
15 buffer associated with ``device_handle`` back into a host buffer and return
16 it wrapped in an :class:`EncryptedTensor`.
17 * For :meth:`free`, call ``hipFree`` and drop the IPC handle.
18 * Keep tensor bytes encrypted at rest; plaintext exists only inside the
19 workload's trusted compute boundary.
20 """
21
22 from __future__ import annotations
23
24 from pqc_gpu_driver.backends.base import GPUBackend
25 from pqc_gpu_driver.errors import BackendError
26 from pqc_gpu_driver.tensor import EncryptedTensor
27
28
class ROCmBackend(GPUBackend):
    """Placeholder backend for AMD ROCm devices.

    No HIP runtime is wired in: every operation (``upload``, ``download``,
    ``free``, ``device_info``) unconditionally raises :class:`BackendError`
    with a message describing what a real implementation would do.
    """

    name = "rocm"
    device_type = "rocm"

    def __init__(self, device_index: int = 0) -> None:
        """Remember which HIP device index this backend targets.

        Args:
            device_index: Index of the device; stored as-is, not validated.
        """
        # NOTE(review): GPUBackend.__init__ is not invoked here — confirm the
        # base class requires no initialization of its own.
        self.device_index = device_index

    def upload(self, tensor: EncryptedTensor) -> str:
        """Stub for host-to-device transfer; always fails.

        Args:
            tensor: The encrypted tensor to upload (unused by the stub).

        Raises:
            BackendError: Always, since no runtime is wired in.
        """
        message = (
            "ROCmBackend.upload is a stub. A real implementation allocates "
            f"device memory on HIP device {self.device_index} via hipMalloc "
            "and copies the ciphertext bytes with hipMemcpy (HostToDevice)."
        )
        raise BackendError(message)

    def download(self, device_handle: str) -> EncryptedTensor:
        """Stub for device-to-host transfer; always fails.

        Args:
            device_handle: Handle of the device buffer; only interpolated
                into the error message.

        Raises:
            BackendError: Always, since no runtime is wired in.
        """
        message = (
            "ROCmBackend.download is a stub. A real implementation issues "
            f"hipMemcpy (DeviceToHost) for handle {device_handle} to pull "
            "ciphertext bytes back to host memory."
        )
        raise BackendError(message)

    def free(self, device_handle: str) -> None:
        """Stub for releasing a device buffer; always fails.

        Args:
            device_handle: Handle of the device buffer; only interpolated
                into the error message.

        Raises:
            BackendError: Always, since no runtime is wired in.
        """
        message = (
            "ROCmBackend.free is a stub. A real implementation calls hipFree "
            f"on the device pointer for {device_handle} and drops any "
            "HIP-IPC handles."
        )
        raise BackendError(message)

    def device_info(self) -> dict:
        """Stub for querying device properties; always fails.

        Raises:
            BackendError: Always, since no runtime is wired in.
        """
        message = (
            "ROCmBackend.device_info is a stub. A real implementation reads "
            f"device {self.device_index} properties via hipGetDeviceProperties."
        )
        raise BackendError(message)
67