tests/test_vault.py
4.2 KB · 141 lines · python Raw
1 """Tests for EnclaveVault lifecycle and CRUD."""
2
from __future__ import annotations

import pytest

from pqc_enclave_sdk import (
    ArtifactKind,
    EnclaveLockedError,
    EnclaveVault,
    InMemoryEnclaveBackend,
    UnknownArtifactError,
)
13
14
def test_unlock_sets_is_unlocked(backend: InMemoryEnclaveBackend) -> None:
    """A newly built vault starts locked and reports unlocked after unlock()."""
    fresh_vault = EnclaveVault(backend=backend)
    assert not fresh_vault.is_unlocked

    fresh_vault.unlock()

    assert fresh_vault.is_unlocked
20
21
def test_lock_clears_key(vault: EnclaveVault) -> None:
    """lock() returns the vault to the locked state and refuses further writes.

    Checking ``is_unlocked`` alone would pass even if lock() only toggled a
    flag. Attempting a put_artifact afterwards confirms the locked vault
    actually rejects operations with EnclaveLockedError.
    """
    assert vault.is_unlocked
    vault.lock()
    assert not vault.is_unlocked
    # A write after locking must be rejected just like on a never-unlocked vault.
    with pytest.raises(EnclaveLockedError):
        vault.put_artifact(
            name="after-lock", kind=ArtifactKind.CREDENTIAL, content=b"secret"
        )
26
27
def test_put_requires_unlock(backend: InMemoryEnclaveBackend) -> None:
    """Storing an artifact in a vault that was never unlocked is rejected."""
    locked_vault = EnclaveVault(backend=backend)
    with pytest.raises(EnclaveLockedError):
        locked_vault.put_artifact(
            name="foo",
            kind=ArtifactKind.CREDENTIAL,
            content=b"secret",
        )
34
35
def test_get_requires_unlock(
    vault: EnclaveVault, api_credential: bytes
) -> None:
    """Reads are refused with EnclaveLockedError once the vault is locked again."""
    vault.put_artifact(
        name="openai",
        kind=ArtifactKind.CREDENTIAL,
        content=api_credential,
    )
    vault.lock()
    # The artifact was stored before locking, so this lookup targets an
    # existing entry; the failure under test is the lock, not a missing name.
    with pytest.raises(EnclaveLockedError):
        vault.get_artifact("openai")
45
46
def test_put_get_roundtrip_preserves_content(
    vault: EnclaveVault, small_weights: bytes
) -> None:
    """Bytes written via put_artifact come back unchanged from get_artifact."""
    artifact_name = "tiny-weights"
    vault.put_artifact(
        name=artifact_name,
        kind=ArtifactKind.MODEL_WEIGHTS,
        content=small_weights,
    )

    fetched = vault.get_artifact(artifact_name)

    assert fetched.content == small_weights
    assert fetched.metadata.kind == ArtifactKind.MODEL_WEIGHTS
58
59
def test_put_by_kind_tagged_correctly(vault: EnclaveVault) -> None:
    """The object returned by put_artifact records kind, name and byte size."""
    payload = b"adapter-bytes"
    stored = vault.put_artifact(
        name="lora-x",
        kind=ArtifactKind.LORA_ADAPTER,
        content=payload,
    )

    meta = stored.metadata
    assert meta.kind == ArtifactKind.LORA_ADAPTER
    assert meta.name == "lora-x"
    assert meta.size_bytes == len(payload)
69
70
def test_get_by_name_works(
    vault: EnclaveVault, api_credential: bytes
) -> None:
    """An artifact resolves both by its human name and by its artifact id."""
    stored = vault.put_artifact(
        name="stripe-key",
        kind=ArtifactKind.CREDENTIAL,
        content=api_credential,
    )

    via_name = vault.get_artifact("stripe-key")
    via_id = vault.get_artifact(stored.metadata.artifact_id)

    for fetched in (via_name, via_id):
        assert fetched.content == api_credential
83
84
def test_delete_removes_both_id_and_name_entries(
    vault: EnclaveVault,
) -> None:
    """Deleting by name drops the artifact under both its name and its id.

    Besides both lookups failing with UnknownArtifactError, the deleted
    artifact must also disappear from list_artifacts().
    """
    enc = vault.put_artifact(
        name="temp", kind=ArtifactKind.OTHER, content=b"temp-bytes"
    )
    aid = enc.metadata.artifact_id
    vault.delete_artifact("temp")

    # Neither name nor id should resolve now.
    with pytest.raises(UnknownArtifactError):
        vault.get_artifact("temp")
    with pytest.raises(UnknownArtifactError):
        vault.get_artifact(aid)

    # ...and the public listing must no longer report it.
    assert all(m.artifact_id != aid for m in vault.list_artifacts())
100
101
def test_list_artifacts_returns_unique_metadata(vault: EnclaveVault) -> None:
    """list_artifacts reports every stored artifact, each exactly once."""
    for name, kind, content in (
        ("a", ArtifactKind.CREDENTIAL, b"1"),
        ("b", ArtifactKind.CREDENTIAL, b"2"),
        ("c", ArtifactKind.TOKENIZER, b"3"),
    ):
        vault.put_artifact(name=name, kind=kind, content=content)

    metas = vault.list_artifacts()
    assert {m.name for m in metas} == {"a", "b", "c"}

    # Artifacts are reachable by both name and id, so the internal store
    # holds two keys per artifact; the listing must still not repeat one.
    artifact_ids = [m.artifact_id for m in metas]
    assert len(artifact_ids) == len(set(artifact_ids))
112
113
def test_save_and_reload_via_backend_preserves_artifacts(
    backend: InMemoryEnclaveBackend, small_weights: bytes
) -> None:
    """Artifacts saved by one vault are visible to a new vault on the same backend.

    Content is deliberately not decrypted here. Per this test's premise, each
    unlock() yields a fresh session key, so v2 cannot read v1's ciphertext.
    The test therefore checks that the encrypted artifacts survive the
    save -> backend -> unlock pipeline intact, by name.
    """
    v1 = EnclaveVault(backend=backend)
    v1.unlock()
    v1.put_artifact(
        name="weights-1",
        kind=ArtifactKind.MODEL_WEIGHTS,
        content=small_weights,
    )
    v1.save()

    # The backend itself must now hold the persisted encrypted artifact.
    loaded = backend.load_artifacts()
    assert any(enc.metadata.name == "weights-1" for enc in loaded.values())

    # A fresh vault over the same backend populates itself on unlock().
    v2 = EnclaveVault(backend=backend)
    v2.unlock()
    # Go through the public listing instead of the private ``_store`` mapping
    # so the test does not depend on the vault's internal indexing layout.
    assert "weights-1" in {meta.name for meta in v2.list_artifacts()}
141