src/pqc_mbom/component.py
5.4 KB · 139 lines · python Raw
1 """AI model component types and data structures."""
2
3 from __future__ import annotations
4
5 import hashlib
6 import json
7 from dataclasses import dataclass, field, asdict
8 from enum import Enum
9 from typing import Any
10
11
class ComponentType(str, Enum):
    """Kinds of artifact that a single MBOM entry can describe.

    Every member is also a ``str``; its value is the lowercase, hyphenated
    identifier for that kind of component.
    """

    # Model architecture on its own, e.g. the Llama-3 8B architecture.
    BASE_ARCHITECTURE = "base-architecture"
    # Serialized weights file.
    WEIGHTS = "weights"
    # Raw training dataset.
    TRAINING_DATA = "training-data"
    FINE_TUNING_DATA = "fine-tuning-data"
    # Human feedback dataset.
    RLHF_DATA = "rlhf-data"
    EVALUATION_BENCHMARK = "evaluation-benchmark"
    TOKENIZER = "tokenizer"
    QUANTIZATION_METHOD = "quantization-method"
    # Training / inference code.
    CODE = "code"
    # Configuration files (JSON / YAML).
    CONFIG = "config"
    # LoRA / QLoRA adapter weights.
    ADAPTER = "adapter"
    # Content filter or classifier.
    SAFETY_MODEL = "safety-model"
    OTHER = "other"
27
28
@dataclass(frozen=True)
class LicenseInfo:
    """Immutable description of the license a component is released under."""

    # SPDX identifier, e.g. "apache-2.0" or "cc-by-4.0".
    spdx_id: str = ""
    # Human-readable license name.
    name: str = ""
    # Link to the license text.
    url: str = ""
    commercial_use: bool = False
    attribution_required: bool = True

    def to_dict(self) -> dict[str, Any]:
        """Return the license fields as a new plain dict, in field order."""
        as_mapping = asdict(self)
        return as_mapping
40
41
@dataclass(frozen=True)
class ComponentReference:
    """Immutable link from one component to another component it relates to."""

    # Identifier of the component being pointed at.
    component_id: str
    # Kind of link: "depends-on" | "derived-from" | "contains".
    relationship: str

    def to_dict(self) -> dict[str, Any]:
        """Return the reference as a new plain dict, in field order."""
        as_mapping = asdict(self)
        return as_mapping
50
51
@dataclass
class ModelComponent:
    """One entry in the MBOM.

    content_hash is SHA3-256 over the bytes the component represents (weights,
    data files, code, etc.). For pointer-only components (e.g. a published
    dataset referenced by URL), you can supply content_hash of the declared
    manifest and set external_url.

    The serialized form (``to_dict``) and the hashed form (``canonical_bytes``)
    are both produced from one shared field mapping, so a field cannot be
    present in one and missing from the other.
    """

    component_id: str  # stable UUID or slug
    component_type: ComponentType
    name: str
    version: str = ""
    content_hash: str = ""  # hex SHA3-256
    content_size: int = 0  # bytes; 0 = unknown
    supplier: str = ""  # organization
    author: str = ""  # person
    external_url: str = ""  # where to fetch (optional)
    license: LicenseInfo = field(default_factory=LicenseInfo)
    references: list[ComponentReference] = field(default_factory=list)
    properties: dict[str, str] = field(default_factory=dict)  # arbitrary extras

    @staticmethod
    def hash_content(content: bytes) -> str:
        """Return the hex SHA3-256 digest of *content*."""
        return hashlib.sha3_256(content).hexdigest()

    def _payload(self) -> dict[str, Any]:
        """Build the plain-dict view of this component.

        Single source of truth for both ``to_dict`` and ``canonical_bytes``.
        The returned dict, the nested license dict, the reference dicts and
        the properties dict are all fresh objects, so callers may mutate the
        result without affecting this component.
        """
        return {
            "component_id": self.component_id,
            "component_type": self.component_type.value,
            "name": self.name,
            "version": self.version,
            "content_hash": self.content_hash,
            "content_size": self.content_size,
            "supplier": self.supplier,
            "author": self.author,
            "external_url": self.external_url,
            "license": self.license.to_dict(),
            "references": [r.to_dict() for r in self.references],
            "properties": dict(self.properties),
        }

    def canonical_bytes(self) -> bytes:
        """Return the deterministic byte encoding that ``hash`` digests.

        The component is serialized as compact JSON (no whitespace after
        separators) with keys sorted at every nesting level and non-ASCII
        characters written literally, then encoded as UTF-8. The order of
        ``references`` is preserved and therefore affects the result; the
        insertion order of ``properties`` does not.
        """
        return json.dumps(
            self._payload(),
            sort_keys=True,
            separators=(",", ":"),
            ensure_ascii=False,
        ).encode("utf-8")

    def hash(self) -> str:
        """Return the hex SHA3-256 digest of ``canonical_bytes()``."""
        return hashlib.sha3_256(self.canonical_bytes()).hexdigest()

    def to_dict(self) -> dict[str, Any]:
        """Return a JSON-serializable dict accepted by ``from_dict``."""
        return self._payload()

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> ModelComponent:
        """Build a component from a dict shaped like ``to_dict`` output.

        Optional keys may be absent or explicitly ``None`` (JSON ``null``);
        both are replaced by the field's default so that the resulting
        component hashes the same either way.

        Args:
            data: Mapping with required keys ``component_id``,
                ``component_type`` and ``name``; every other key is optional.

        Returns:
            A new ``ModelComponent``.

        Raises:
            KeyError: A required key is missing.
            ValueError: ``component_type`` is not a known ``ComponentType``
                value, or ``content_size`` cannot be converted to ``int``.
            TypeError: A ``references`` entry does not match the fields of
                ``ComponentReference``.
        """
        # `or` collapses both "key absent" and "key is null" to the default.
        lic = data.get("license") or {}
        raw_references = data.get("references") or []
        raw_properties = data.get("properties") or {}

        # A null flag means "not stated", so fall back to the field default
        # (True) instead of letting bool(None) turn it into False.
        attribution = lic.get("attribution_required")
        attribution_required = True if attribution is None else bool(attribution)

        return cls(
            component_id=data["component_id"],
            component_type=ComponentType(data["component_type"]),
            name=data["name"],
            version=data.get("version") or "",
            content_hash=data.get("content_hash") or "",
            content_size=int(data.get("content_size") or 0),
            supplier=data.get("supplier") or "",
            author=data.get("author") or "",
            external_url=data.get("external_url") or "",
            license=LicenseInfo(
                spdx_id=lic.get("spdx_id") or "",
                name=lic.get("name") or "",
                url=lic.get("url") or "",
                commercial_use=bool(lic.get("commercial_use", False)),
                attribution_required=attribution_required,
            ),
            references=[ComponentReference(**r) for r in raw_references],
            properties=dict(raw_properties),
        )
139