tests/test_integration.py
2.6 KB · 84 lines · python Raw
1 """End-to-end integration tests."""
2
3 from __future__ import annotations
4
5 import pytest
6 from quantumshield.identity.agent import AgentIdentity
7
8 from pqc_gpu_driver import (
9 DriverAttestationError,
10 DriverAttestationVerifier,
11 DriverAttester,
12 DriverModule,
13 InMemoryBackend,
14 TensorMetadata,
15 establish_channel,
16 )
17
18
def test_full_flow_attest_channel_upload_download_decrypt(
    attester: DriverAttester,
    sample_module: DriverModule,
    sample_module_bytes: bytes,
    random_tensor_bytes: bytes,
) -> None:
    """Run the whole attest -> channel -> upload -> download -> decrypt path.

    The driver module is attested and checked against an allow-list that
    contains only the attester's own DID. A CPU<->GPU channel is opened, and a
    tensor is encrypted on the CPU end. The encrypted tensor is pushed through
    an ``InMemoryBackend`` and pulled back, then decrypted on the GPU end. The
    recovered bytes must equal the original input exactly.
    """
    # Attestation: the producing attester is the only trusted signer.
    report = attester.attest(sample_module)
    checker = DriverAttestationVerifier(trusted_signers={attester.identity.did})
    checker.verify_or_raise(report, actual_module_bytes=sample_module_bytes)

    # Encrypted channel with one endpoint per side.
    cpu_end, gpu_end = establish_channel()

    # Describe the payload as a flat float32 tensor (4 bytes per element).
    payload_len = len(random_tensor_bytes)
    metadata = TensorMetadata(
        tensor_id="llama-layer-0-w",
        name="layer_0.self_attn.q_proj.weight",
        dtype="float32",
        shape=(payload_len // 4,),
        size_bytes=payload_len,
        transfer_direction="cpu_to_gpu",
    )
    sealed = cpu_end.encrypt_tensor(random_tensor_bytes, metadata)

    # The backend only ever receives the encrypted tensor object.
    store = InMemoryBackend()
    slot = store.upload(sealed)
    fetched = store.download(slot)

    # Decryption on the GPU end must reproduce the input bit-for-bit.
    round_tripped = gpu_end.decrypt_tensor(fetched)
    assert round_tripped == random_tensor_bytes

    store.free(slot)
56
57
def test_byzantine_untrusted_signer_rejected(
    untrusted_identity: AgentIdentity,
    trusted_identity: AgentIdentity,
    sample_module: DriverModule,
    sample_module_bytes: bytes,
) -> None:
    """Reject an attestation whose signer is outside the allow-list.

    The attestation is produced by a real ``DriverAttester`` over the genuine
    module. Its signature and hash are therefore expected to check out, so the
    trust check is what must fail. The test exercises both the non-raising
    ``verify`` result and the raising ``verify_or_raise`` path.
    """
    # An attacker signs the genuine module with their own, untrusted identity.
    forged = DriverAttester(untrusted_identity).attest(sample_module)

    # Only the vendor identity is on the allow-list.
    gatekeeper = DriverAttestationVerifier(trusted_signers={trusted_identity.did})

    outcome = gatekeeper.verify(forged, actual_module_bytes=sample_module_bytes)
    # Identity checks (`is False`) keep falsy-but-not-False values from passing.
    assert outcome.valid is False
    assert outcome.trusted is False
    reason = outcome.error or ""
    assert "not in trusted set" in reason

    # The raising variant must refuse the same attestation.
    with pytest.raises(DriverAttestationError):
        gatekeeper.verify_or_raise(forged, actual_module_bytes=sample_module_bytes)
84