tests/test_mbom.py
3.1 KB · 88 lines · python Raw
1 """Tests for MBOM and MBOMBuilder."""
2
3 from __future__ import annotations
4
5 import pytest
6
7 from pqc_mbom import (
8 ComponentType,
9 MBOM,
10 MBOMBuilder,
11 ModelComponent,
12 )
13 from pqc_mbom.errors import MissingComponentError
14
15
def test_builder_populates_fields() -> None:
    """Builder setters and ``add_*`` helpers populate the MBOM produced by ``build()``."""
    builder = MBOMBuilder("Mixtral-8x22B", "1.0", supplier="Mistral")
    builder.set_description("MoE model")
    builder.add_base_architecture("mixtral", version="1.0", content_hash="a" * 64)
    builder.add_weights("model.safetensors", content_hash="b" * 64, content_size=100)
    mbom = builder.build()

    assert mbom.model_name == "Mixtral-8x22B"
    assert mbom.model_version == "1.0"
    assert mbom.supplier == "Mistral"
    assert mbom.description == "MoE model"

    # Components keep the order in which they were added to the builder.
    assert len(mbom.components) == 2
    assert mbom.components[0].component_type == ComponentType.BASE_ARCHITECTURE
    assert mbom.components[1].component_type == ComponentType.WEIGHTS

    assert mbom.mbom_id.startswith("urn:pqc-mbom:")
    # build() is expected to compute the root hash from the components. A
    # truthiness check alone would accept any placeholder, so also compare it
    # against a fresh recomputation.
    assert mbom.components_root_hash
    assert mbom.components_root_hash == mbom.recompute_root()
    assert mbom.created_at
32
33
def test_components_root_hash_is_deterministic(sample_mbom: MBOM) -> None:
    """Recomputing the components root hash reproduces the stored value."""
    stored_root = sample_mbom.components_root_hash
    recomputed_root = sample_mbom.recompute_root()
    assert recomputed_root == stored_root
38
39
def test_components_root_hash_changes_with_component_changes(sample_mbom: MBOM) -> None:
    """Editing a component and recomputing must produce a different root hash."""
    root_before = sample_mbom.components_root_hash

    first_component = sample_mbom.components[0]
    first_component.version = "3.1"
    sample_mbom.recompute_root()

    root_after = sample_mbom.components_root_hash
    assert root_after != root_before
45
46
def test_get_component_missing_raises(sample_mbom: MBOM) -> None:
    """``get_component`` resolves known IDs and raises for unknown ones."""
    # Sanity check: a known ID resolves to the expected component.
    found = sample_mbom.get_component("weights-0001")
    assert found.name == "llama3-8b.safetensors"

    # An ID that is not present in the MBOM must raise MissingComponentError.
    unknown_id = "not-in-mbom"
    with pytest.raises(MissingComponentError):
        sample_mbom.get_component(unknown_id)
54
55
def test_components_by_type_filters(sample_mbom: MBOM) -> None:
    """``components_by_type`` returns only components of the requested type."""
    weights = sample_mbom.components_by_type(ComponentType.WEIGHTS)
    training = sample_mbom.components_by_type(ComponentType.TRAINING_DATA)
    safety = sample_mbom.components_by_type(ComponentType.SAFETY_MODEL)

    # Separate asserts (rather than `a and b`) so a failure pinpoints the
    # condition that broke.
    assert len(weights) == 1
    assert weights[0].name == "llama3-8b.safetensors"
    assert len(training) == 1
    assert safety == []

    # Verify the filter itself: every returned component has the requested type.
    assert all(c.component_type == ComponentType.WEIGHTS for c in weights)
    assert all(c.component_type == ComponentType.TRAINING_DATA for c in training)
63
64
def test_to_json_from_json_roundtrip(sample_mbom: MBOM) -> None:
    """Serialising to JSON and parsing it back preserves identity and hashes."""
    serialized = sample_mbom.to_json()
    restored = MBOM.from_json(serialized)

    assert restored.mbom_id == sample_mbom.mbom_id
    assert restored.model_name == sample_mbom.model_name
    assert len(restored.components) == len(sample_mbom.components)
    assert restored.components_root_hash == sample_mbom.components_root_hash

    # Per-component hashes must also survive the round trip. Lengths are
    # already known to match, so comparing the lists is a pairwise check.
    original_hashes = [component.hash() for component in sample_mbom.components]
    restored_hashes = [component.hash() for component in restored.components]
    assert restored_hashes == original_hashes
75
76
def test_add_component_accepts_custom_component() -> None:
    """A hand-constructed ``ModelComponent`` can be added via ``add_component``."""
    mbom_builder = MBOMBuilder("Custom", "1")
    adapter = ModelComponent(
        component_id="custom-1",
        component_type=ComponentType.ADAPTER,
        name="lora-adapter",
        content_hash="d" * 64,
    )
    mbom_builder.add_component(adapter)

    built = mbom_builder.build()
    first = built.components[0]
    assert first.component_type == ComponentType.ADAPTER
88