src/pqc_training_data/merkle.py
| 1 | """SHA3-256 Merkle tree over a list of leaf hashes.""" |
| 2 | |
| 3 | from __future__ import annotations |
| 4 | |
| 5 | import hashlib |
| 6 | from dataclasses import asdict, dataclass, field |
| 7 | from typing import Any |
| 8 | |
| 9 | from pqc_training_data.errors import ( |
| 10 | EmptyTreeError, |
| 11 | InclusionProofError, |
| 12 | IndexOutOfRangeError, |
| 13 | ) |
| 14 | from pqc_training_data.record import RecordHash |
| 15 | |
| 16 | |
def _pair_hash(left: bytes, right: bytes) -> bytes:
    """Combine two child digests into their parent digest.

    Returns SHA3-256(0x01 || left || right). The leading 0x01 byte puts
    internal-node inputs in a different domain from leaf inputs, which
    ``_leaf_hash_bytes`` prefixes with 0x00.
    """
    hasher = hashlib.sha3_256()
    hasher.update(b"\x01")  # internal-node domain tag
    hasher.update(left)
    hasher.update(right)
    return hasher.digest()
| 24 | |
| 25 | |
def _leaf_hash_bytes(leaf_bytes: bytes) -> bytes:
    """Return the tree-level digest of a leaf: SHA3-256(0x00 || leaf_bytes).

    The 0x00 prefix keeps leaf inputs disjoint from internal-node inputs
    (which ``_pair_hash`` prefixes with 0x01).
    """
    leaf_digest = hashlib.sha3_256(b"\x00")  # leaf domain tag
    leaf_digest.update(leaf_bytes)
    return leaf_digest.digest()
| 29 | |
| 30 | |
@dataclass(frozen=True)
class InclusionProof:
    """Evidence that a single leaf belongs to a Merkle tree with a given root.

    Attributes:
        leaf_hash: hex of the ORIGINAL leaf value, i.e. before the 0x00
            leaf-domain prefix is applied and the result re-hashed.
        index: 0-based position of the leaf among the tree's leaves.
        tree_size: total number of leaves in the tree.
        root: hex of the Merkle root the proof should resolve to.
        siblings: hex sibling node hashes (raw node digests), ordered from
            the leaf level upward.
        directions: one 'L' or 'R' per sibling, telling whether that sibling
            sits to the left or to the right of the path being hashed.
    """

    leaf_hash: str
    index: int
    tree_size: int
    root: str
    siblings: list[str]
    directions: list[str]

    def to_dict(self) -> dict[str, Any]:
        """Return the proof as a plain dict (field order; lists are copied)."""
        return asdict(self)
| 49 | |
| 50 | |
@dataclass
class MerkleTree:
    """SHA3-256 Merkle tree over ``RecordHash`` leaves. Works for >= 1 leaf.

    Hashing scheme:
      * leaf node     = SHA3-256(0x00 || leaf.bytes)
      * internal node = SHA3-256(0x01 || left || right)

    When a level has an odd number of nodes, the last node is paired with
    ITSELF (duplicated): its parent is SHA3-256(0x01 || x || x). This is
    Bitcoin-style padding, not RFC 6962 (which splits at the largest power of
    two below n and never duplicates a node).

    NOTE(review): because the odd node is duplicated, a tree over [a, b, c]
    has the same root as a tree over [a, b, c, c] (the CVE-2012-2459 class of
    ambiguity). The root alone therefore does not commit to the leaf count;
    callers that need that guarantee must publish/compare ``size`` alongside
    the root. The scheme is kept unchanged so already-published roots keep
    verifying.
    """

    leaves: list[RecordHash] = field(default_factory=list)

    def add(self, leaf_hash: RecordHash) -> None:
        """Append one leaf at the end of the tree."""
        self.leaves.append(leaf_hash)

    def add_many(self, leaf_hashes: list[RecordHash]) -> None:
        """Append several leaves at the end of the tree, preserving order."""
        self.leaves.extend(leaf_hashes)

    @property
    def size(self) -> int:
        """Number of leaves currently in the tree."""
        return len(self.leaves)

    def _leaf_level(self) -> list[bytes]:
        """Return the 0x00-domain-separated digest of every leaf, in order."""
        return [_leaf_hash_bytes(leaf.bytes) for leaf in self.leaves]

    @staticmethod
    def _next_level(level: list[bytes]) -> list[bytes]:
        """Hash one non-empty level pairwise into its parent level.

        A trailing unpaired node is paired with itself (see class docstring).
        """
        last = len(level) - 1
        # min() makes the final odd node its own right-hand partner.
        return [
            _pair_hash(level[i], level[min(i + 1, last)])
            for i in range(0, len(level), 2)
        ]

    @staticmethod
    def _expected_directions(index: int, tree_size: int) -> list[str]:
        """Return the sibling directions a valid proof for ``index`` must have.

        At every level the sibling is on the left exactly when our position
        is odd; otherwise it is on the right (a real neighbour, or our own
        duplicate for a trailing odd node). One entry per level below the root.
        """
        directions: list[str] = []
        position, width = index, tree_size
        while width > 1:
            directions.append("L" if position % 2 else "R")
            position //= 2
            width = (width + 1) // 2  # parent level size is ceil(width / 2)
        return directions

    def root(self) -> str:
        """Compute the Merkle root as lowercase hex.

        Raises:
            EmptyTreeError: if the tree has no leaves.
        """
        if not self.leaves:
            raise EmptyTreeError("cannot compute root of empty tree")

        level = self._leaf_level()
        while len(level) > 1:
            level = self._next_level(level)
        return level[0].hex()

    def inclusion_proof(self, index: int) -> InclusionProof:
        """Generate an inclusion proof for the leaf at ``index``.

        Raises:
            EmptyTreeError: if the tree has no leaves.
            IndexOutOfRangeError: if ``index`` is not in ``[0, size - 1]``.
        """
        if not self.leaves:
            raise EmptyTreeError("empty tree has no proofs")
        if index < 0 or index >= len(self.leaves):
            raise IndexOutOfRangeError(
                f"index {index} out of range [0, {len(self.leaves) - 1}]"
            )

        level = self._leaf_level()
        siblings: list[str] = []
        directions: list[str] = []
        idx = index

        while len(level) > 1:
            if idx % 2:
                # Odd position: our partner is the node just to the left.
                sib, direction = level[idx - 1], "L"
            elif idx + 1 < len(level):
                sib, direction = level[idx + 1], "R"
            else:
                # Trailing odd node: it was hashed together with itself.
                sib, direction = level[idx], "R"
            siblings.append(sib.hex())
            directions.append(direction)

            level = self._next_level(level)
            idx //= 2

        return InclusionProof(
            leaf_hash=self.leaves[index].hex,
            index=index,
            tree_size=self.size,
            root=level[0].hex(),
            siblings=siblings,
            directions=directions,
        )

    @staticmethod
    def verify_inclusion(proof: InclusionProof) -> bool:
        """Independently verify an inclusion proof. Returns True iff valid.

        Besides recomputing the root from the leaf and siblings, the proof's
        directions must match the path implied by its ``index`` and
        ``tree_size``; otherwise the claimed position would be unverified.

        Raises:
            InclusionProofError: if the proof is structurally malformed: a
                direction other than 'L'/'R', or ``siblings`` and
                ``directions`` of different lengths.
        """
        for direction in proof.directions:
            if direction not in ("L", "R"):
                raise InclusionProofError(f"invalid direction: {direction}")
        if len(proof.siblings) != len(proof.directions):
            # zip() would otherwise silently drop the unmatched tail.
            raise InclusionProofError(
                f"proof has {len(proof.siblings)} siblings but "
                f"{len(proof.directions)} directions"
            )

        # Bind the claimed position: index/tree_size must yield this path.
        if not 0 <= proof.index < proof.tree_size:
            return False
        expected = MerkleTree._expected_directions(proof.index, proof.tree_size)
        if list(proof.directions) != expected:
            return False

        current = _leaf_hash_bytes(bytes.fromhex(proof.leaf_hash))
        for sib_hex, direction in zip(proof.siblings, proof.directions):
            sib = bytes.fromhex(sib_hex)
            if direction == "L":
                current = _pair_hash(sib, current)
            else:
                current = _pair_hash(current, sib)
        return current.hex() == proof.root
| 145 | |