src/pqc_bootloader/measured_boot.py
1.5 KB · 57 lines · python Raw
1 """MeasuredBoot - TPM/PCR-style chain of boot-stage hashes."""
2
3 from __future__ import annotations
4
5 import hashlib
6 from dataclasses import dataclass, field
7 from datetime import datetime, timezone
8 from enum import Enum
9
10
class BootStage(str, Enum):
    """Labels for the stages that can be recorded in a measurement chain.

    Because of the ``str`` mix-in, every member compares equal to its string
    value, e.g. ``BootStage.KERNEL == "kernel"``. Members iterate in the
    order declared below.
    """

    # The string values are how members are looked up by value,
    # e.g. BootStage("model-weights") -> BootStage.MODEL_WEIGHTS.
    ROM = "rom"
    BOOTLOADER = "bootloader"
    KERNEL = "kernel"
    INITRD = "initrd"
    USERSPACE = "userspace"
    MODEL_WEIGHTS = "model-weights"
18
19
@dataclass(frozen=True)
class PCRMeasurement:
    """Read-only record of one measurement appended to a chain.

    Fields, in positional-constructor order:
        stage: the boot stage the measured bytes belong to.
        measured_hash: SHA3-256 digest of the measured bytes, as hex text.
        measured_at: timestamp text for when the measurement was taken
            (``MeasuredBoot.extend`` supplies an ISO-8601 UTC time).

    ``frozen=True`` forbids attribute assignment after construction and,
    together with the generated ``__eq__``, makes instances hashable.
    """

    stage: BootStage
    measured_hash: str
    measured_at: str
27
28
@dataclass
class MeasuredBoot:
    """TPM-like measurement chain.

    PCR update formula: ``new_pcr = SHA3-256(old_pcr || SHA3-256(content))``,
    where both operands are raw 32-byte digests (not their hex text).
    Every stage is extended into the PCR in order; changing the content of
    any stage, or the order in which stages are extended, produces a
    different final PCR value.

    The stage label itself is NOT hashed into the PCR; it is only recorded
    in ``measurements`` alongside the content digest.

    Attributes:
        pcr_value: current PCR as 64 hex characters (32 bytes). Defaults to
            all zeros, which is also the value restored by ``reset``.
        measurements: one ``PCRMeasurement`` per ``extend`` call, in call
            order.
    """

    # Initial / reset PCR: 32 zero bytes, hex-encoded. Shared by the field
    # default and ``reset`` so the two cannot drift apart. Left unannotated
    # on purpose so ``dataclass`` does not turn it into an init field.
    _ZERO_PCR = "0" * 64

    pcr_value: str = _ZERO_PCR
    measurements: list[PCRMeasurement] = field(default_factory=list)

    def __post_init__(self) -> None:
        # Reject a malformed starting PCR up front. bytes.fromhex would happily
        # accept "" or "00", and the chain would then be built from a value
        # that can never match a real 32-byte PCR.
        self._check_pcr(self.pcr_value)

    @staticmethod
    def _check_pcr(value: str) -> None:
        """Raise unless *value* is exactly 64 hex characters (32 bytes).

        Raises:
            TypeError: if *value* is not a ``str``.
            ValueError: if *value* is not 64 hex characters.
        """
        if not isinstance(value, str):
            raise TypeError(f"pcr_value must be str, not {type(value).__name__}")
        try:
            decoded = bytes.fromhex(value)
        except ValueError:
            decoded = b""
        # bytes.fromhex skips ASCII whitespace, so also require that the 64
        # characters decode to a full 32 bytes.
        if len(value) != 64 or len(decoded) != 32:
            raise ValueError(f"pcr_value must be 64 hex characters, got {value!r}")

    def extend(self, stage: BootStage, content: bytes) -> str:
        """Measure *content* for *stage* and fold the digest into the PCR.

        Args:
            stage: boot stage being measured; recorded in the log only.
            content: bytes to measure (any bytes-like object hashlib accepts).

        Returns:
            The new PCR value as lowercase hex.

        Raises:
            ValueError: if ``pcr_value`` was reassigned to non-hex text after
                construction.
        """
        digest = hashlib.sha3_256(content).digest()
        new_pcr = hashlib.sha3_256(bytes.fromhex(self.pcr_value) + digest).hexdigest()
        record = PCRMeasurement(
            stage=stage,
            measured_hash=digest.hex(),
            measured_at=datetime.now(timezone.utc).isoformat(),
        )
        # Commit the PCR and its log entry together, only after everything
        # above succeeded, so a failure cannot leave them out of sync.
        self.pcr_value = new_pcr
        self.measurements.append(record)
        return new_pcr

    def reset(self) -> None:
        """Restore the all-zero PCR and clear the measurement log in place."""
        self.pcr_value = self._ZERO_PCR
        self.measurements.clear()
57