# File: examples/basic_attestation.py
"""Basic attestation example.

Registers two memory regions (model weights and an activation cache) with the
in-memory backend, signs an AttestationReport with ML-DSA, and then verifies
that report using only the signed report itself — the minimum end-to-end flow.
"""
7
8 from __future__ import annotations
9
10 from quantumshield.identity.agent import AgentIdentity
11
12 from pqc_hypervisor_attestation import (
13 AttestationVerifier,
14 Attester,
15 ContinuousAttester,
16 InMemoryBackend,
17 MemoryRegion,
18 RegionSnapshot,
19 )
20
# Workload identifier shared by every step of the example: both memory
# regions are registered under it and the attester reports against it.
WORKLOAD_ID: str = "model-serving-1"
22
23
def main() -> int:
    """Run the attest-then-verify flow once and print the results.

    Builds an ML-DSA attester identity, registers two memory regions in an
    in-memory backend, produces one signed AttestationReport, and verifies it.

    Returns:
        Process exit status: 0 when the verifier reports the report as valid,
        1 otherwise, so the example can double as a smoke test.
    """
    # 1. Build an attester identity with an ML-DSA keypair.
    identity = AgentIdentity.create(
        name="llama-host-attester",
        capabilities=["attest"],
    )
    attester = Attester(identity)

    # 2. Register two in-memory regions with content.
    # Each region's declared size and its backing payload come from a single
    # value, so the region metadata and the bytes cannot drift apart.
    weights_size = 128
    cache_size = 64
    backend = InMemoryBackend()
    weights = MemoryRegion(
        region_id="model-weights-0",
        description="Llama weight shard 0",
        address=0x1000,
        size=weights_size,
        protection="RO",
    )
    cache = MemoryRegion(
        region_id="activation-cache",
        description="KV cache for in-flight request",
        address=0x2000,
        size=cache_size,
        protection="RW",
    )
    weights_bytes = b"\xaa" * weights_size
    cache_bytes = b"\xbb" * cache_size
    backend.register(WORKLOAD_ID, weights, weights_bytes)
    backend.register(WORKLOAD_ID, cache, cache_bytes)

    # 3. Pin expected hashes computed at VM boot.
    # In this example they are hashed from the same payloads registered above,
    # so the report should show no drift.
    expected = {
        weights.region_id: RegionSnapshot.hash_bytes(weights_bytes),
        cache.region_id: RegionSnapshot.hash_bytes(cache_bytes),
    }

    # 4. Attest once.
    continuous_attester = ContinuousAttester(
        attester=attester,
        backend=backend,
        workload_id=WORKLOAD_ID,
        expected_hashes=expected,
    )
    report = continuous_attester.attest_once()
    print(f"signed report : {report.report_id}")
    print(f"attester did : {report.signer_did}")
    print(f"algorithm : {report.algorithm}")
    print(f"claims : {len(report.claims)}")

    # 5. Verify.
    result = AttestationVerifier.verify(report, strict=True)
    print(f"valid : {result.valid}")
    print(f"signature_valid : {result.signature_valid}")
    print(f"not_expired : {result.not_expired}")
    print(f"drifts : {result.drifts}")

    # Surface verification failure to the caller instead of only printing it.
    return 0 if result.valid else 1
78
79
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status so a failed
    # verification exits non-zero. A None return (SystemExit(None)) exits 0.
    raise SystemExit(main())
82