tests/test_update_chain.py
2.4 KB · 79 lines · python Raw
1 """Tests for UpdateChain."""
2
3 from __future__ import annotations
4
5 import pytest
6
7 from pqc_bootloader.errors import FirmwareRollbackError, UpdateChainError
8 from pqc_bootloader.firmware import FirmwareImage, FirmwareMetadata, TargetDevice
9 from pqc_bootloader.signer import FirmwareSigner
10 from pqc_bootloader.update_chain import UpdateChain
11
12
def _fw(version: str, payload: bytes, name: str = "acme-inference-os") -> FirmwareImage:
    """Build an unsigned firmware image for use in the chain tests.

    The image metadata always targets ``TargetDevice.AI_INFERENCE_APPLIANCE``;
    only the version, payload bytes and (optionally) the firmware name vary
    between test cases.

    Args:
        version: Version string recorded in the image metadata.
        payload: Raw bytes handed to ``FirmwareImage.from_bytes``.
        name: Firmware name recorded in the image metadata.

    Returns:
        The ``FirmwareImage`` produced by ``FirmwareImage.from_bytes``.
    """
    return FirmwareImage.from_bytes(
        FirmwareMetadata(
            name=name,
            version=version,
            target=TargetDevice.AI_INFERENCE_APPLIANCE,
        ),
        payload,
    )
20
21
def test_add_first_link_ok(firmware_signer: FirmwareSigner) -> None:
    """A single signed image can be added to an empty chain.

    After the add, the chain must report that exact object as current and
    ``verify_chain()`` must return a truthy status with an empty error list.
    """
    update_chain = UpdateChain()
    image = _fw("1.0.0", b"v1 payload")
    first_link = firmware_signer.sign(image)
    update_chain.add(first_link)

    # Identity, not equality: the chain should hand back the same object.
    assert update_chain.current() is first_link

    is_valid, problems = update_chain.verify_chain()
    assert is_valid
    assert problems == []
30
31
def test_second_link_verifies_previous_hash(firmware_signer: FirmwareSigner) -> None:
    """A second link signed against the first image's hash is accepted.

    The 1.1.0 image is signed with ``previous_firmware_hash`` set to the
    1.0.0 image's ``image_hash``; both adds must succeed and the whole chain
    must verify with no errors.
    """
    update_chain = UpdateChain()
    first_link = firmware_signer.sign(_fw("1.0.0", b"v1 payload"))
    next_image = _fw("1.1.0", b"v2 payload")
    second_link = firmware_signer.sign(
        next_image,
        previous_firmware_hash=first_link.firmware.image_hash,
    )

    for link in (first_link, second_link):
        update_chain.add(link)

    is_valid, problems = update_chain.verify_chain()
    assert is_valid
    assert problems == []
43
44
def test_mismatched_previous_hash_raises(firmware_signer: FirmwareSigner) -> None:
    """Adding a link whose previous-hash does not match is rejected.

    The 1.1.0 image is signed with a fabricated ``previous_firmware_hash``
    instead of the 1.0.0 image's hash, so ``UpdateChain.add`` is expected to
    raise ``UpdateChainError`` for it.
    """
    update_chain = UpdateChain()
    first_link = firmware_signer.sign(_fw("1.0.0", b"v1 payload"))

    # Deliberately wrong: a made-up value rather than first_link's image hash.
    bogus_previous_hash = "de" * 32
    second_link = firmware_signer.sign(
        _fw("1.1.0", b"v2 payload"),
        previous_firmware_hash=bogus_previous_hash,
    )

    update_chain.add(first_link)
    with pytest.raises(UpdateChainError):
        update_chain.add(second_link)
55
56
def test_rollback_blocked_by_default(firmware_signer: FirmwareSigner) -> None:
    """A lower-versioned image is refused when no rollback flag is passed.

    The 0.9.0 image is correctly linked to the 1.0.0 image's hash, so the
    only thing distinguishing it from a normal update is its older version;
    ``UpdateChain.add`` is expected to raise ``FirmwareRollbackError``.
    """
    update_chain = UpdateChain()
    current_link = firmware_signer.sign(_fw("1.0.0", b"v1 payload"))
    older_image = _fw("0.9.0", b"v0 payload")
    older_link = firmware_signer.sign(
        older_image,
        previous_firmware_hash=current_link.firmware.image_hash,
    )

    update_chain.add(current_link)
    with pytest.raises(FirmwareRollbackError):
        update_chain.add(older_link)
67
68
def test_rollback_allowed_when_flag_set(firmware_signer: FirmwareSigner) -> None:
    """A lower-versioned image is accepted when ``allow_rollback=True``.

    Same setup as the blocked-rollback case (0.9.0 linked to 1.0.0's hash),
    but the add is made with ``allow_rollback=True`` and the older image must
    then be reported as the chain's current link.
    """
    update_chain = UpdateChain()
    current_link = firmware_signer.sign(_fw("1.0.0", b"v1 payload"))
    older_image = _fw("0.9.0", b"v0 payload")
    older_link = firmware_signer.sign(
        older_image,
        previous_firmware_hash=current_link.firmware.image_hash,
    )

    update_chain.add(current_link)
    update_chain.add(older_link, allow_rollback=True)

    # Identity check: the chain should return the very object that was added.
    assert update_chain.current() is older_link
79