src/pqc_bootloader/measured_boot.py
| 1 | """MeasuredBoot - TPM/PCR-style chain of boot-stage hashes.""" |
| 2 | |
| 3 | from __future__ import annotations |
| 4 | |
| 5 | import hashlib |
| 6 | from dataclasses import dataclass, field |
| 7 | from datetime import datetime, timezone |
| 8 | from enum import Enum |
| 9 | |
| 10 | |
class BootStage(str, Enum):
    """Label for the boot stage that a measurement was taken from.

    The enum mixes in ``str``, so every member is itself a string and compares
    equal to its plain value (for example ``BootStage.KERNEL == "kernel"``).
    Iteration yields members in the order they are defined below.
    """

    ROM = "rom"
    BOOTLOADER = "bootloader"
    KERNEL = "kernel"
    INITRD = "initrd"
    USERSPACE = "userspace"
    # The value uses a hyphen even though the member name uses an underscore.
    MODEL_WEIGHTS = "model-weights"
| 18 | |
| 19 | |
@dataclass(frozen=True)
class PCRMeasurement:
    """Immutable log entry describing one measurement in the boot chain.

    Because the dataclass is frozen (and keeps the default ``eq=True``),
    instances are hashable and compare field-by-field.
    """

    # Boot stage that the measured content belongs to.
    stage: BootStage
    # SHA3-256 digest of the measured bytes, hex encoded.
    measured_hash: str
    # When the measurement was taken; MeasuredBoot.extend stores an
    # ISO-8601 string produced from a UTC-aware datetime.
    measured_at: str
| 27 | |
| 28 | |
# Size in bytes of a SHA3-256 digest, and therefore of the PCR register.
_PCR_DIGEST_SIZE = 32
# Initial PCR value: an all-zero register, hex encoded (64 characters).
_ZERO_PCR = "00" * _PCR_DIGEST_SIZE


@dataclass
class MeasuredBoot:
    """TPM-like measurement chain.

    PCR update formula: new_pcr = SHA3-256(old_pcr || measurement), where
    ``measurement`` is the SHA3-256 digest of the stage content and both
    operands are the raw 32-byte values (not their hex encodings).
    Every stage is extended into the PCR in order; tampering with any stage
    produces a different final PCR value.

    Attributes:
        pcr_value: Current PCR register as lowercase hex (64 characters).
            Defaults to all zeros; a caller-supplied value is validated and
            canonicalised on construction.
        measurements: Log of every extension, in the order it was applied.
    """

    pcr_value: str = _ZERO_PCR
    measurements: list[PCRMeasurement] = field(default_factory=list)

    def __post_init__(self) -> None:
        # Fail fast on a malformed starting value instead of at the first
        # extend() (or, for wrong-length hex, never), and normalise it to
        # the lowercase form that extend() itself produces.
        self.pcr_value = self._decode_pcr(self.pcr_value).hex()

    @staticmethod
    def _decode_pcr(value: str) -> bytes:
        """Decode a hex PCR value that must hold exactly one SHA3-256 digest.

        Raises:
            ValueError: if ``value`` is not hex or does not decode to 32 bytes.
        """
        raw = bytes.fromhex(value)  # ValueError on non-hex / odd-length input
        if len(raw) != _PCR_DIGEST_SIZE:
            raise ValueError(
                f"PCR value must encode {_PCR_DIGEST_SIZE} bytes, got {len(raw)}"
            )
        return raw

    def extend(self, stage: BootStage, content: bytes) -> str:
        """Measure ``content`` for ``stage`` and fold it into the PCR.

        Args:
            stage: A ``BootStage`` member or its string value (e.g. "kernel").
            content: Bytes-like data to measure; ``hashlib`` raises TypeError
                for other types (such as ``str``).

        Returns:
            The new PCR value as lowercase hex.

        Raises:
            ValueError: if ``stage`` is not a known ``BootStage`` value, or the
                current ``pcr_value`` is malformed.
        """
        # Reject unknown stage labels (e.g. typos) rather than logging them.
        stage = BootStage(stage)
        digest = hashlib.sha3_256(content).digest()
        old_pcr = self._decode_pcr(self.pcr_value)
        new_pcr = hashlib.sha3_256(old_pcr + digest).hexdigest()
        record = PCRMeasurement(
            stage=stage,
            measured_hash=digest.hex(),
            measured_at=datetime.now(timezone.utc).isoformat(),
        )
        # Commit only after everything above succeeded, so the PCR value and
        # the measurement log never disagree.
        self.pcr_value = new_pcr
        self.measurements.append(record)
        return new_pcr

    def reset(self) -> None:
        """Return the PCR to all zeros and empty the measurement log.

        The log is cleared in place, so existing references to
        ``measurements`` observe the reset.
        """
        self.pcr_value = _ZERO_PCR
        self.measurements.clear()
| 57 | |