src/pqc_mbom/component.py
| 1 | """AI model component types and data structures.""" |
| 2 | |
| 3 | from __future__ import annotations |
| 4 | |
| 5 | import hashlib |
| 6 | import json |
| 7 | from dataclasses import dataclass, field, asdict |
| 8 | from enum import Enum |
| 9 | from typing import Any |
| 10 | |
| 11 | |
class ComponentType(str, Enum):
    """Closed set of categories an MBOM entry can declare for a component.

    Because members also subclass ``str``, each one compares equal to its
    hyphenated string value (``ComponentType.WEIGHTS == "weights"``), and
    ``ComponentType("weights")`` looks a member up by that value.
    """

    BASE_ARCHITECTURE = "base-architecture"  # model architecture, e.g. Llama-3 8B
    WEIGHTS = "weights"  # a serialized weights file
    TRAINING_DATA = "training-data"  # the raw training dataset
    FINE_TUNING_DATA = "fine-tuning-data"  # dataset used for fine-tuning
    RLHF_DATA = "rlhf-data"  # human-feedback dataset
    EVALUATION_BENCHMARK = "evaluation-benchmark"  # benchmark used for evaluation
    TOKENIZER = "tokenizer"  # tokenizer definition
    QUANTIZATION_METHOD = "quantization-method"  # quantization scheme applied
    CODE = "code"  # training or inference code
    CONFIG = "config"  # configuration files (JSON/YAML)
    ADAPTER = "adapter"  # LoRA/QLoRA adapter weights
    SAFETY_MODEL = "safety-model"  # content filter or safety classifier
    OTHER = "other"  # anything not covered above
| 27 | |
| 28 | |
@dataclass(frozen=True)
class LicenseInfo:
    """Immutable license declaration attached to a component.

    Every field has a default, so ``LicenseInfo()`` is a valid "unspecified"
    license: empty identifiers, commercial use not granted, attribution
    required.
    """

    spdx_id: str = ""  # license identifier, e.g. "apache-2.0", "cc-by-4.0"
    name: str = ""  # human-readable license name
    url: str = ""  # link to the license text
    commercial_use: bool = False
    attribution_required: bool = True

    def to_dict(self) -> dict[str, Any]:
        """Return the license fields as a plain dict, in declaration order."""
        as_mapping: dict[str, Any] = {}
        as_mapping["spdx_id"] = self.spdx_id
        as_mapping["name"] = self.name
        as_mapping["url"] = self.url
        as_mapping["commercial_use"] = self.commercial_use
        as_mapping["attribution_required"] = self.attribution_required
        return as_mapping
| 40 | |
| 41 | |
@dataclass(frozen=True)
class ComponentReference:
    """Immutable link from the owning component to another MBOM entry.

    ``relationship`` is a plain string and is not validated here; the
    intended labels are "depends-on", "derived-from" and "contains".
    """

    component_id: str  # id of the referenced component
    relationship: str  # "depends-on" | "derived-from" | "contains"

    def to_dict(self) -> dict[str, Any]:
        """Return the reference as a plain dict, in declaration order."""
        return {
            "component_id": self.component_id,
            "relationship": self.relationship,
        }
| 50 | |
| 51 | |
@dataclass
class ModelComponent:
    """One entry in the MBOM.

    ``content_hash`` is the hex SHA3-256 digest of the bytes the component
    represents (weights, data files, code, etc.). For pointer-only components
    (e.g. a published dataset referenced by URL), supply the content_hash of
    the declared manifest and set ``external_url``.

    ``component_type`` may be passed either as a :class:`ComponentType`
    member or as its string value (e.g. ``"weights"``); strings are coerced to
    the member on construction, and unknown values raise ``ValueError``.
    """

    component_id: str  # stable UUID or slug
    component_type: ComponentType
    name: str
    version: str = ""
    content_hash: str = ""  # hex SHA3-256
    content_size: int = 0  # bytes; 0 = unknown
    supplier: str = ""  # organization
    author: str = ""  # person
    external_url: str = ""  # where to fetch (optional)
    license: LicenseInfo = field(default_factory=LicenseInfo)
    references: list[ComponentReference] = field(default_factory=list)
    properties: dict[str, str] = field(default_factory=dict)  # arbitrary extras

    def __post_init__(self) -> None:
        # ComponentType is a str-Enum, so passing the raw value ("weights") is
        # an easy mistake that previously only surfaced later, as an
        # AttributeError on `.value` inside to_dict()/canonical_bytes().
        # ComponentType(member) returns the member unchanged.
        self.component_type = ComponentType(self.component_type)

    @staticmethod
    def hash_content(content: bytes) -> str:
        """Return the hex SHA3-256 digest of ``content``.

        Use this to compute ``content_hash`` for the component's bytes.
        """
        return hashlib.sha3_256(content).hexdigest()

    def canonical_bytes(self) -> bytes:
        """Return the deterministic serialization that :meth:`hash` digests.

        The payload is :meth:`to_dict` rendered as JSON with sorted keys, no
        insignificant whitespace and non-ASCII characters kept as-is, then
        encoded as UTF-8. Building it from to_dict() keeps the hashed payload
        and the exported dict from drifting apart when fields change.
        """
        return json.dumps(
            self.to_dict(), sort_keys=True, separators=(",", ":"), ensure_ascii=False
        ).encode("utf-8")

    def hash(self) -> str:
        """Return the hex SHA3-256 digest of :meth:`canonical_bytes`."""
        return hashlib.sha3_256(self.canonical_bytes()).hexdigest()

    def to_dict(self) -> dict[str, Any]:
        """Return a JSON-serializable dict of every field.

        ``component_type`` is emitted as its string value, ``license`` and
        ``references`` as nested dicts, and ``properties`` as a shallow copy so
        that mutating the result does not mutate the component.
        """
        return {
            "component_id": self.component_id,
            "component_type": self.component_type.value,
            "name": self.name,
            "version": self.version,
            "content_hash": self.content_hash,
            "content_size": self.content_size,
            "supplier": self.supplier,
            "author": self.author,
            "external_url": self.external_url,
            "license": self.license.to_dict(),
            "references": [r.to_dict() for r in self.references],
            "properties": dict(self.properties),
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> ModelComponent:
        """Rebuild a component from a mapping shaped like :meth:`to_dict` output.

        ``component_id``, ``component_type`` and ``name`` are required
        (``KeyError`` if absent; ``ValueError`` for an unknown
        component_type). Other keys fall back to the field defaults. An
        explicit ``None`` (JSON ``null``) for ``license``, ``references``,
        ``properties`` or ``content_size`` is treated the same as a missing
        key rather than crashing.
        """
        lic = data.get("license") or {}
        return cls(
            component_id=data["component_id"],
            component_type=ComponentType(data["component_type"]),
            name=data["name"],
            version=data.get("version", ""),
            content_hash=data.get("content_hash", ""),
            content_size=int(data.get("content_size") or 0),
            supplier=data.get("supplier", ""),
            author=data.get("author", ""),
            external_url=data.get("external_url", ""),
            license=LicenseInfo(
                spdx_id=lic.get("spdx_id", ""),
                name=lic.get("name", ""),
                url=lic.get("url", ""),
                commercial_use=bool(lic.get("commercial_use", False)),
                attribution_required=bool(lic.get("attribution_required", True)),
            ),
            references=[
                ComponentReference(**r) for r in data.get("references") or []
            ],
            properties=dict(data.get("properties") or {}),
        )
| 139 | |