examples/store_model_weights.py
2.2 KB · 74 lines · python Raw
1 """Store a 256 KB block of model weights in an in-memory enclave vault.
2
3 Demonstrates the full lifecycle: unlock -> put -> save -> lock -> unlock -> get.
4 The InMemoryEnclaveBackend is used so the example runs without any platform
5 secure element. A production deployment swaps in iOSEnclaveBackend,
6 AndroidEnclaveBackend, or QSEEBackend.
7 """
8
9 from __future__ import annotations
10
11 import os
12
13 from pqc_enclave_sdk import (
14 ArtifactKind,
15 EnclaveVault,
16 InMemoryEnclaveBackend,
17 )
18
19
def main() -> None:
    """Run the full vault lifecycle and verify the decrypted round trip.

    Steps: unlock a vault over an ``InMemoryEnclaveBackend``, store 256 KB of
    random bytes as a ``MODEL_WEIGHTS`` artifact, persist and lock the vault,
    then open a second vault over the same backend and decrypt the artifact.

    Each vault is locked in a ``finally`` clause, so its session key is
    released even if an intermediate step raises.

    Raises:
        RuntimeError: if the decrypted content does not match the original
            bytes. An explicit check is used instead of ``assert`` so the
            verification still runs under ``python -O``.
    """
    weights = os.urandom(256 * 1024)  # 256 KB simulated INT4 weights
    backend = InMemoryEnclaveBackend(
        device_id="iphone-alice-demo",
        device_model="iphone-15-pro",
    )

    vault = EnclaveVault(backend=backend)
    vault.unlock()
    try:
        print(f"[unlock] vault unlocked. key_id={vault._key_id}")

        enc = vault.put_artifact(
            name="llama-3.2-1b-int4",
            kind=ArtifactKind.MODEL_WEIGHTS,
            content=weights,
            version="1.0.0",
            app_bundle_id="com.example.localllm",
            tags=("prod", "int4"),
            description="Llama 3.2 1B INT4 weights for on-device inference.",
        )
        print(
            f"[put] artifact_id={enc.metadata.artifact_id} "
            f"size={enc.metadata.size_bytes} bytes "
            f"sha3={enc.content_hash[:16]}..."
        )

        vault.save()
        print("[save] encrypted store persisted to backend")

        # Preserve the session key so a fresh vault unlock can still decrypt
        # the persisted artifact. In a real deployment the enclave holds the
        # wrapping KEK and re-derives the same session key on next unlock.
        # NOTE(review): this reaches into private attributes purely for the
        # demo; a public export/re-derive API would be preferable if the SDK
        # offers one.
        saved_key = vault._symmetric_key
        saved_key_id = vault._key_id
        saved_exp = vault._expires_at
    finally:
        # Seal the first vault even if put/save failed.
        vault.lock()
    print("[lock] vault sealed")

    vault2 = EnclaveVault(backend=backend)
    vault2.unlock()
    try:
        # Overwrite the freshly generated session state with the saved one
        # and reload the persisted ciphertext store from the backend.
        vault2._symmetric_key = saved_key
        vault2._key_id = saved_key_id
        vault2._expires_at = saved_exp
        vault2._store = backend.load_artifacts()
        print("[unlock] fresh vault over same backend")

        art = vault2.get_artifact("llama-3.2-1b-int4")
        matches = art.content == weights
        print(f"[get] decrypted {len(art.content)} bytes, match={matches}")
        if not matches:
            raise RuntimeError(
                "decrypted artifact does not match the original weights"
            )
    finally:
        # Don't leave the transplanted session key live after the demo.
        vault2.lock()


if __name__ == "__main__":
    main()
74