src/pqc_mbom/mbom.py
9.3 KB · 252 lines · python Raw
1 """MBOM - the signed bill of materials for an AI model."""
2
3 from __future__ import annotations
4
5 import hashlib
6 import json
7 import uuid
8 from dataclasses import dataclass, field
9 from datetime import datetime, timezone
10 from typing import Any
11
12 from pqc_mbom.component import ModelComponent, ComponentType
13 from pqc_mbom.errors import InvalidMBOMError, MissingComponentError
14
15
# Schema version stamped on newly created MBOMs; also the fallback used by
# MBOM.from_dict when a serialized MBOM carries no "schema_version" key.
SCHEMA_VERSION = "1.0"


@dataclass
class MBOM:
    """A Model Bill of Materials.

    Contains the model's own identity (name, version, supplier) and an
    enumeration of ModelComponents with hashes. The MBOM as a whole is
    signed separately via MBOMSigner.
    """

    mbom_id: str
    schema_version: str
    model_name: str
    model_version: str
    supplier: str = ""
    description: str = ""
    components: list[ModelComponent] = field(default_factory=list)
    created_at: str = ""
    components_root_hash: str = ""  # SHA3-256 over sorted component hashes

    # Set by MBOMSigner.sign
    signer_did: str = ""
    algorithm: str = ""
    signature: str = ""
    public_key: str = ""
    signed_at: str = ""

    @classmethod
    def create(
        cls,
        model_name: str,
        model_version: str,
        supplier: str = "",
        description: str = "",
        components: list[ModelComponent] | None = None,
    ) -> MBOM:
        """Build a new, unsigned MBOM.

        Assigns a fresh ``urn:pqc-mbom:<uuid4 hex>`` identifier, stamps the
        current UTC time as ``created_at``, copies ``components`` into a new
        list, and computes ``components_root_hash``.
        """
        m = cls(
            mbom_id=f"urn:pqc-mbom:{uuid.uuid4().hex}",
            schema_version=SCHEMA_VERSION,
            model_name=model_name,
            model_version=model_version,
            supplier=supplier,
            description=description,
            # Copy so later mutation of the caller's list does not leak in.
            components=list(components or []),
            created_at=datetime.now(timezone.utc).isoformat(),
        )
        m.recompute_root()
        return m

    def recompute_root(self) -> str:
        """Recompute, store, and return ``components_root_hash``.

        The root is the SHA3-256 hex digest of the components' hashes,
        sorted and joined with ``"|"``; sorting makes it independent of the
        order in which components were added.
        """
        component_hashes = sorted(c.hash() for c in self.components)
        concat = "|".join(component_hashes).encode("utf-8")
        self.components_root_hash = hashlib.sha3_256(concat).hexdigest()
        return self.components_root_hash

    def get_component(self, component_id: str) -> ModelComponent:
        """Return the first component whose ``component_id`` matches.

        Raises:
            MissingComponentError: if no component has that id.
        """
        for c in self.components:
            if c.component_id == component_id:
                return c
        raise MissingComponentError(f"no component with id '{component_id}'")

    def components_by_type(self, ctype: ComponentType) -> list[ModelComponent]:
        """Return all components of the given type, in stored order."""
        return [c for c in self.components if c.component_type == ctype]

    def _unsigned_fields(self) -> dict[str, Any]:
        """Return every field except the signature block.

        Single source of truth for both ``canonical_bytes`` and ``to_dict``
        so the two representations cannot drift apart.
        """
        return {
            "mbom_id": self.mbom_id,
            "schema_version": self.schema_version,
            "model_name": self.model_name,
            "model_version": self.model_version,
            "supplier": self.supplier,
            "description": self.description,
            "components": [c.to_dict() for c in self.components],
            "created_at": self.created_at,
            "components_root_hash": self.components_root_hash,
        }

    def canonical_bytes(self) -> bytes:
        """Return a deterministic UTF-8 JSON encoding of the unsigned content.

        Keys are sorted, separators are compact, and non-ASCII characters are
        emitted literally. The signature fields (``signer_did``,
        ``algorithm``, ``signature``, ``public_key``, ``signed_at``) are
        excluded. The stored ``components_root_hash`` is used as-is; it is
        not recomputed here.
        """
        # NOTE(review): presumably the exact payload MBOMSigner signs -
        # confirm against the signer before changing the encoding.
        return json.dumps(
            self._unsigned_fields(),
            sort_keys=True,
            separators=(",", ":"),
            ensure_ascii=False,
        ).encode("utf-8")

    def to_dict(self) -> dict[str, Any]:
        """Return the full MBOM, including signature fields, as a plain dict."""
        return {
            **self._unsigned_fields(),
            "signer_did": self.signer_did,
            "algorithm": self.algorithm,
            "signature": self.signature,
            "public_key": self.public_key,
            "signed_at": self.signed_at,
        }

    def to_json(self) -> str:
        """Return ``to_dict()`` serialized as 2-space-indented JSON."""
        return json.dumps(self.to_dict(), indent=2)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> MBOM:
        """Rebuild an MBOM from the mapping produced by ``to_dict``.

        ``mbom_id``, ``model_name`` and ``model_version`` are required; every
        other key falls back to its default. The stored
        ``components_root_hash`` is taken as given and is not re-verified.

        Raises:
            InvalidMBOMError: if ``data`` is not a dict, if ``components`` is
                not a list/tuple, or if a required key is missing.
        """
        # json.loads can legitimately return a list, str, number or None;
        # reject those here instead of leaking a TypeError/AttributeError.
        if not isinstance(data, dict):
            raise InvalidMBOMError(
                f"MBOM must be a JSON object, got {type(data).__name__}"
            )
        raw_components = data.get("components", [])
        if not isinstance(raw_components, (list, tuple)):
            raise InvalidMBOMError(
                f"'components' must be a list, got {type(raw_components).__name__}"
            )
        try:
            return cls(
                mbom_id=data["mbom_id"],
                schema_version=data.get("schema_version", SCHEMA_VERSION),
                model_name=data["model_name"],
                model_version=data["model_version"],
                supplier=data.get("supplier", ""),
                description=data.get("description", ""),
                components=[ModelComponent.from_dict(c) for c in raw_components],
                created_at=data.get("created_at", ""),
                components_root_hash=data.get("components_root_hash", ""),
                signer_did=data.get("signer_did", ""),
                algorithm=data.get("algorithm", ""),
                signature=data.get("signature", ""),
                public_key=data.get("public_key", ""),
                signed_at=data.get("signed_at", ""),
            )
        except KeyError as e:
            raise InvalidMBOMError(f"missing required field: {e}") from e

    @classmethod
    def from_json(cls, blob: str) -> MBOM:
        """Parse a JSON document and rebuild the MBOM it describes.

        Raises:
            InvalidMBOMError: if ``blob`` is not valid JSON, or if the parsed
                value is rejected by ``from_dict``.
        """
        try:
            return cls.from_dict(json.loads(blob))
        except json.JSONDecodeError as e:
            raise InvalidMBOMError(f"invalid JSON: {e}") from e
146
147
class MBOMBuilder:
    """Fluent builder for MBOMs.

    Usage:
        builder = MBOMBuilder("Llama-3-8B-Instruct", "1.0", supplier="Meta")
        builder.add_base_architecture("Llama-3", version="3.0", content_hash=...)
        builder.add_training_data(
            "common-crawl-2024", content_hash=..., content_size=1_000_000_000_000
        )
        builder.add_fine_tuning_data("instruct-v1", content_hash=...)
        builder.add_rlhf_data("hh-rlhf", content_hash=...)
        builder.add_tokenizer("Llama-3-tokenizer", content_hash=...)
        builder.add_weights(
            "model.safetensors", content_hash=..., content_size=16_000_000_000
        )
        mbom = builder.build()

    Every ``add_*`` and ``set_*`` method returns the builder itself so calls
    can be chained. Extra keyword arguments to the ``add_*`` methods are
    forwarded unchanged to ``ModelComponent``.
    """

    def __init__(self, model_name: str, model_version: str, supplier: str = ""):
        """Start an empty builder for the named model."""
        self.model_name = model_name
        self.model_version = model_version
        self.supplier = supplier
        self.description = ""
        self.components: list[ModelComponent] = []

    def _component_id(self, name: str) -> str:
        """Derive a component id: lower-cased name, spaces -> dashes, plus 8 random hex chars."""
        return f"{name.lower().replace(' ', '-')}-{uuid.uuid4().hex[:8]}"

    def _add(self, component_type: ComponentType, name: str, **fields: Any) -> MBOMBuilder:
        """Create a ModelComponent of ``component_type`` with a generated id and append it."""
        return self.add_component(ModelComponent(
            component_id=self._component_id(name),
            component_type=component_type,
            name=name,
            **fields,
        ))

    def add_component(self, component: ModelComponent) -> MBOMBuilder:
        """Append an already-constructed component and return the builder."""
        self.components.append(component)
        return self

    def add_base_architecture(
        self, name: str, version: str = "", content_hash: str = "", **kwargs: Any
    ) -> MBOMBuilder:
        """Add a BASE_ARCHITECTURE component."""
        return self._add(
            ComponentType.BASE_ARCHITECTURE, name,
            version=version, content_hash=content_hash, **kwargs,
        )

    def add_weights(
        self, name: str, content_hash: str = "", content_size: int = 0, **kwargs: Any
    ) -> MBOMBuilder:
        """Add a WEIGHTS component."""
        return self._add(
            ComponentType.WEIGHTS, name,
            content_hash=content_hash, content_size=content_size, **kwargs,
        )

    def add_training_data(
        self, name: str, content_hash: str = "", content_size: int = 0, **kwargs: Any
    ) -> MBOMBuilder:
        """Add a TRAINING_DATA component."""
        return self._add(
            ComponentType.TRAINING_DATA, name,
            content_hash=content_hash, content_size=content_size, **kwargs,
        )

    def add_fine_tuning_data(
        self, name: str, content_hash: str = "", **kwargs: Any
    ) -> MBOMBuilder:
        """Add a FINE_TUNING_DATA component."""
        return self._add(
            ComponentType.FINE_TUNING_DATA, name, content_hash=content_hash, **kwargs,
        )

    def add_rlhf_data(self, name: str, content_hash: str = "", **kwargs: Any) -> MBOMBuilder:
        """Add an RLHF_DATA component."""
        return self._add(
            ComponentType.RLHF_DATA, name, content_hash=content_hash, **kwargs,
        )

    def add_tokenizer(self, name: str, content_hash: str = "", **kwargs: Any) -> MBOMBuilder:
        """Add a TOKENIZER component."""
        return self._add(
            ComponentType.TOKENIZER, name, content_hash=content_hash, **kwargs,
        )

    def add_quantization(self, name: str, **kwargs: Any) -> MBOMBuilder:
        """Add a QUANTIZATION_METHOD component (no content hash parameter)."""
        return self._add(ComponentType.QUANTIZATION_METHOD, name, **kwargs)

    def add_evaluation(self, name: str, content_hash: str = "", **kwargs: Any) -> MBOMBuilder:
        """Add an EVALUATION_BENCHMARK component."""
        return self._add(
            ComponentType.EVALUATION_BENCHMARK, name, content_hash=content_hash, **kwargs,
        )

    def set_description(self, description: str) -> MBOMBuilder:
        """Set the free-text description carried into the built MBOM."""
        self.description = description
        return self

    def build(self) -> MBOM:
        """Create the MBOM from the accumulated state via ``MBOM.create``.

        A copy of the component list is passed, so the builder can keep
        being used afterwards without affecting the returned MBOM's list.
        """
        return MBOM.create(
            model_name=self.model_name,
            model_version=self.model_version,
            supplier=self.supplier,
            description=self.description,
            components=list(self.components),
        )
252