src/pqc_bootloader/firmware.py
5.0 KB · 155 lines · python Raw
1 """Firmware image data structures."""
2
3 from __future__ import annotations
4
5 import hashlib
6 import json
7 from dataclasses import asdict, dataclass
8 from enum import Enum
9 from typing import Any
10
11
class TargetDevice(str, Enum):
    """Closed set of appliance families a firmware image can be built for.

    Because the enum mixes in ``str``, every member compares equal to its
    value string, and ``TargetDevice("<value>")`` looks a member up by value
    (raising ``ValueError`` for strings that are not listed here).
    """

    AI_INFERENCE_APPLIANCE = "ai-inference-appliance"
    MEDICAL_DIAGNOSTIC = "medical-diagnostic"
    INDUSTRIAL_CONTROL = "industrial-control"
    EDGE_GATEWAY = "edge-gateway"
    MILITARY_EMBEDDED = "military-embedded"
    # Catch-all value; SignedFirmware.from_dict falls back to it when the
    # serialized metadata carries no "target" key.
    UNKNOWN = "unknown"
21
22
@dataclass(frozen=True)
class FirmwareMetadata:
    """Descriptive, non-binary attributes of a firmware image.

    Instances are immutable (``frozen=True``). Only ``name``, ``version`` and
    ``target`` are required. Every other field defaults to an empty string,
    except ``architecture`` (``"x86_64"``) and ``security_level``
    (``"production"``).
    """

    name: str
    version: str
    target: TargetDevice
    kernel_version: str = ""
    architecture: str = "x86_64"  # x86_64 | arm64 | riscv64 | ...
    build_id: str = ""  # e.g. a git SHA or a CI build id
    release_notes_url: str = ""
    min_hardware_revision: str = ""
    security_level: str = "production"  # production | development | debug

    def to_dict(self) -> dict[str, Any]:
        """Return the fields as a fresh plain dict.

        Keys follow field-declaration order. ``target`` is replaced by its
        enum ``.value`` string. The key already exists from ``asdict``, so
        overriding it keeps its position.
        """
        return {**asdict(self), "target": self.target.value}
41
42
@dataclass
class FirmwareImage:
    """Raw firmware bytes together with their metadata and SHA3-256 digest.

    ``image_hash`` and ``image_size`` are ordinary fields. The ``from_bytes``
    and ``from_file`` constructors compute them from the image bytes. Direct
    construction stores whatever the caller passes (defaults ``""`` and ``0``).
    """

    metadata: FirmwareMetadata
    image_bytes: bytes
    image_hash: str = ""  # hex SHA3-256 digest
    image_size: int = 0  # byte length of the image

    @staticmethod
    def hash_bytes(data: bytes) -> str:
        """Return the hex SHA3-256 digest of *data*."""
        hasher = hashlib.sha3_256()
        hasher.update(data)
        return hasher.hexdigest()

    @classmethod
    def from_bytes(cls, metadata: FirmwareMetadata, data: bytes) -> FirmwareImage:
        """Wrap in-memory *data*, filling in its SHA3-256 hash and length."""
        digest = cls.hash_bytes(data)
        size = len(data)
        return cls(
            metadata=metadata,
            image_bytes=data,
            image_hash=digest,
            image_size=size,
        )

    @classmethod
    def from_file(cls, metadata: FirmwareMetadata, path: str) -> FirmwareImage:
        """Read the whole file at *path* in binary mode and wrap it via ``from_bytes``."""
        with open(path, "rb") as handle:
            raw = handle.read()
        return cls.from_bytes(metadata, raw)

    def canonical_manifest_bytes(self) -> bytes:
        """Return the deterministic byte string the manufacturer signs.

        The manifest covers the metadata, the image hash and the image size.
        It does NOT include the raw image bytes. Keys are sorted (recursively)
        and separators carry no whitespace, so equal inputs always give equal
        bytes. Non-ASCII text is kept as-is and the result is UTF-8 encoded.
        """
        # Insertion order is irrelevant here: sort_keys=True fixes the output order.
        manifest = {
            "image_hash": self.image_hash,
            "image_size": self.image_size,
            "metadata": self.metadata.to_dict(),
        }
        text = json.dumps(
            manifest, ensure_ascii=False, separators=(",", ":"), sort_keys=True
        )
        return text.encode("utf-8")

    def to_dict(self, include_image: bool = False) -> dict[str, Any]:
        """Serialize to a JSON-compatible dict.

        Args:
            include_image: When true, the image bytes are also added under
                ``"image_base64"`` as an ASCII base64 string.
        """
        result: dict[str, Any] = {
            "metadata": self.metadata.to_dict(),
            "image_hash": self.image_hash,
            "image_size": self.image_size,
        }
        if not include_image:
            return result

        import base64

        result["image_base64"] = base64.b64encode(self.image_bytes).decode("ascii")
        return result
93
94
@dataclass
class SignedFirmware:
    """Firmware image + ML-DSA signature envelope.

    ``signature`` and ``public_key`` are carried as hex strings.
    ``previous_firmware_hash`` is optional and defaults to ``""``.
    """

    firmware: FirmwareImage
    manufacturer_key_id: str  # fingerprint of the manufacturer public key
    signer_did: str
    algorithm: str
    signature: str  # hex
    public_key: str  # hex
    signed_at: str
    previous_firmware_hash: str = ""  # for update-chain continuity

    def to_dict(self, include_image: bool = True) -> dict[str, Any]:
        """Serialize the envelope to a JSON-compatible dict.

        Args:
            include_image: Forwarded to the nested firmware's ``to_dict``.
                Unlike ``FirmwareImage.to_dict``, the default here is ``True``.
        """
        return {
            "firmware": self.firmware.to_dict(include_image=include_image),
            "manufacturer_key_id": self.manufacturer_key_id,
            "signer_did": self.signer_did,
            "algorithm": self.algorithm,
            "signature": self.signature,
            "public_key": self.public_key,
            "signed_at": self.signed_at,
            "previous_firmware_hash": self.previous_firmware_hash,
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> SignedFirmware:
        """Rebuild an envelope from the dict layout produced by ``to_dict``.

        ``image_hash`` and ``image_size`` are taken from *data* as-is. They are
        NOT recomputed from the decoded image bytes, so callers must still
        check the image against the signed hash.

        Raises:
            KeyError: if a required key (``firmware``, ``metadata``, ``name``,
                ``version``, or one of the envelope fields) is missing.
            ValueError: if ``target`` is not a known ``TargetDevice`` value,
                or if ``image_base64`` is not strictly valid base64
                (``binascii.Error`` is a ``ValueError`` subclass).
        """
        import base64

        fw = data["firmware"]
        meta = fw["metadata"]
        image_bytes = b""
        if "image_base64" in fw:
            # validate=True: the default mode silently drops characters outside
            # the base64 alphabet. A corrupted or tampered payload would then
            # decode to different bytes instead of being rejected.
            image_bytes = base64.b64decode(fw["image_base64"], validate=True)
        firmware = FirmwareImage(
            metadata=FirmwareMetadata(
                name=meta["name"],
                version=meta["version"],
                target=TargetDevice(meta.get("target", "unknown")),
                kernel_version=meta.get("kernel_version", ""),
                architecture=meta.get("architecture", "x86_64"),
                build_id=meta.get("build_id", ""),
                release_notes_url=meta.get("release_notes_url", ""),
                min_hardware_revision=meta.get("min_hardware_revision", ""),
                security_level=meta.get("security_level", "production"),
            ),
            image_bytes=image_bytes,
            image_hash=fw.get("image_hash", ""),
            image_size=int(fw.get("image_size", 0)),
        )
        return cls(
            firmware=firmware,
            manufacturer_key_id=data["manufacturer_key_id"],
            signer_did=data["signer_did"],
            algorithm=data["algorithm"],
            signature=data["signature"],
            public_key=data["public_key"],
            signed_at=data["signed_at"],
            previous_firmware_hash=data.get("previous_firmware_hash", ""),
        )
155