tests/test_component.py
2.2 KB · 64 lines · python Raw
1 """Tests for ModelComponent canonical hashing and serialization."""
2
3 from __future__ import annotations
4
5 from pqc_mbom import ComponentReference, ComponentType, LicenseInfo, ModelComponent
6
7
def test_hash_is_deterministic(weights_component: ModelComponent) -> None:
    """Hashing the same component twice gives equal digests of length 64."""
    digest_a = weights_component.hash()
    digest_b = weights_component.hash()

    # Two consecutive calls on an unmodified component must agree.
    assert digest_a == digest_b
    assert len(digest_a) == 64
13
14
def test_hash_changes_with_field_change(weights_component: ModelComponent) -> None:
    """Overwriting ``content_hash`` must produce a different component hash."""
    before = weights_component.hash()

    # Mutate a single field on the fixture, then hash again.
    weights_component.content_hash = "ff" * 32
    after = weights_component.hash()

    assert after != before
19
20
def test_roundtrip_to_dict_from_dict(weights_component: ModelComponent) -> None:
    """A component rebuilt via ``from_dict(to_dict())`` keeps its hash and fields."""
    # Populate the optional fields so the round trip has something to carry.
    weights_component.properties = {"framework": "pytorch"}
    parent_ref = ComponentReference(
        component_id="base-llama3-0001",
        relationship="derived-from",
    )
    weights_component.references = [parent_ref]
    apache = LicenseInfo(
        spdx_id="apache-2.0",
        name="Apache License 2.0",
        commercial_use=True,
    )
    weights_component.license = apache

    serialized = weights_component.to_dict()
    rebuilt = ModelComponent.from_dict(serialized)

    # The hash must be identical after the round trip.
    assert rebuilt.hash() == weights_component.hash()

    # Spot-check the individual fields that were set above.
    assert rebuilt.properties == {"framework": "pytorch"}
    rebuilt_license = rebuilt.license
    assert rebuilt_license.spdx_id == "apache-2.0"
    assert rebuilt_license.commercial_use is True
    first_reference = rebuilt.references[0]
    assert first_reference.relationship == "derived-from"
38
39
def test_license_defaults() -> None:
    """A ``LicenseInfo`` created with no arguments exposes the expected defaults."""
    default_license = LicenseInfo()

    # Empty SPDX id, no commercial use, attribution required.
    assert default_license.spdx_id == ""
    assert default_license.commercial_use is False
    assert default_license.attribution_required is True
45
46
def test_hash_content_sha3_256() -> None:
    """``hash_content`` is repeatable, input-sensitive, and returns 64 characters."""
    payload = b"deterministic-content"
    reference_digest = ModelComponent.hash_content(payload)

    assert len(reference_digest) == 64
    # Same bytes in -> same digest out.
    assert reference_digest == ModelComponent.hash_content(payload)
    # Different bytes in -> different digest out.
    assert reference_digest != ModelComponent.hash_content(b"other-content")
52
53
def test_component_type_enum_coverage() -> None:
    """Every ``ComponentType`` member can be used to build and hash a component."""
    for member in ComponentType:
        tag = member.value
        component = ModelComponent(
            component_id=f"id-{tag}",
            component_type=member,
            name=f"comp-{tag}",
        )
        # Hashing must complete without raising and yield a 64-character digest.
        assert len(component.hash()) == 64
64