src/pqc_reasoning_ledger/merkle.py
| 1 | """SHA3-256 Merkle tree with domain-separated leaves/internal nodes. |
| 2 | |
| 3 | Leaves: SHA3-256(0x00 || leaf_value) |
| 4 | Internal node: SHA3-256(0x01 || left || right) |
| 5 | |
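| 6 | For odd levels, the last node is paired with a copy of itself (Bitcoin-style duplication). Note this is NOT RFC 6962, which splits unbalanced trees instead of duplicating nodes. |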
| 7 | """ |
| 8 | |
| 9 | from __future__ import annotations |
| 10 | |
| 11 | import hashlib |
| 12 | from dataclasses import asdict, dataclass |
| 13 | from typing import Any |
| 14 | |
| 15 | from pqc_reasoning_ledger.errors import ReasoningLedgerError |
| 16 | |
| 17 | |
def _pair_hash(left: bytes, right: bytes) -> bytes:
    """Combine two child digests into their parent digest.

    Computes SHA3-256(0x01 || left || right). The 0x01 prefix places
    internal-node hashes in a different domain from leaf hashes (0x00).
    """
    hasher = hashlib.sha3_256()
    hasher.update(b"\x01")
    hasher.update(left)
    hasher.update(right)
    return hasher.digest()
| 21 | |
| 22 | |
def _leaf_hash_bytes(leaf_bytes: bytes) -> bytes:
    """Return the leaf-node digest SHA3-256(0x00 || leaf_bytes).

    The 0x00 domain-separation prefix keeps leaf digests distinct from
    internal-node digests, which are computed with a 0x01 prefix.
    """
    digest = hashlib.sha3_256(b"\x00")
    digest.update(leaf_bytes)
    return digest.digest()
| 26 | |
| 27 | |
@dataclass(frozen=True)
class InclusionProof:
    """Proof that a leaf is included in a Merkle tree with a given root.

    `siblings` lists the sibling node hashes met while walking from the leaf
    up to the root. `directions[i]` is 'L' or 'R' and says on which side
    `siblings[i]` sits relative to the running hash. `index` is the 0-based
    leaf index and `tree_size` is the total leaf count.

    The dataclass is frozen, but its list fields are still mutable objects.
    `from_dict` therefore copies them so a proof never shares list objects
    with the caller's data.
    """

    leaf_hash: str  # hex of the leaf value itself (before 0x00-prefixed hashing)
    index: int  # 0-based position of the leaf
    tree_size: int  # total number of leaves in the tree
    root: str  # hex of the Merkle root this proof claims
    siblings: list[str]  # hex, each a sibling node hash
    directions: list[str]  # 'L' or 'R', parallel to `siblings`

    def to_dict(self) -> dict[str, Any]:
        """Return a plain-dict copy of the proof (list fields are copied too)."""
        return asdict(self)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> InclusionProof:
        """Rebuild a proof from a dict, e.g. one produced by `to_dict`.

        The list fields are copied, so later mutation of `data` cannot alter
        this frozen proof. Missing or unexpected keys raise TypeError, exactly
        as the dataclass constructor does.
        """
        kwargs = dict(data)
        for key in ("siblings", "directions"):
            if key in kwargs:
                kwargs[key] = list(kwargs[key])
        return cls(**kwargs)
| 50 | |
| 51 | |
def compute_merkle_root(leaves: list[str]) -> str:
    """Return the hex SHA3-256 Merkle root over hex-encoded leaf values.

    Each leaf is decoded from hex and hashed as SHA3-256(0x00 || leaf).
    Adjacent pairs are then combined as SHA3-256(0x01 || left || right),
    level by level, until one node remains. On a level with an odd node
    count, the last node is paired with a copy of itself.

    NOTE: because of that duplication, e.g. [a, b, c] and [a, b, c, c]
    produce the same root; the leaf count is not bound into the root.

    Raises ReasoningLedgerError if `leaves` is empty. Non-hex strings raise
    ValueError from `bytes.fromhex`.
    """
    if not leaves:
        raise ReasoningLedgerError("cannot compute root of empty tree")

    nodes = [_leaf_hash_bytes(bytes.fromhex(leaf)) for leaf in leaves]
    while len(nodes) > 1:
        if len(nodes) % 2 == 1:
            # Odd level: duplicate the trailing node so it pairs with itself.
            nodes.append(nodes[-1])
        nodes = [
            _pair_hash(nodes[pos], nodes[pos + 1])
            for pos in range(0, len(nodes), 2)
        ]
    return nodes[0].hex()
| 69 | |
| 70 | |
def build_proof(
    leaves: list[str], index: int, root: str | None = None
) -> InclusionProof:
    """Construct the inclusion proof for the leaf at `leaves[index]`.

    Walks the tree bottom-up. At each level it records the hash of the node
    paired with the current one, and whether that node sits to the left
    ('L') or right ('R'). When a level has an odd node count, the trailing
    node is paired with a copy of itself, so its recorded sibling is its
    own hash.

    Args:
        leaves: hex-encoded leaf values, in tree order.
        index: 0-based position of the leaf to prove.
        root: optional hex root to store on the proof. It is stored as given
            and is not compared with the root computed from `leaves`; when
            omitted, the computed root is stored.

    Raises:
        ReasoningLedgerError: if `leaves` is empty or `index` is out of range.
    """
    if not leaves:
        raise ReasoningLedgerError("cannot build proof for empty tree")
    if not 0 <= index < len(leaves):
        raise ReasoningLedgerError(
            f"index {index} out of range [0, {len(leaves) - 1}]"
        )

    nodes = [_leaf_hash_bytes(bytes.fromhex(leaf)) for leaf in leaves]
    sibling_hexes: list[str] = []
    sides: list[str] = []
    pos = index

    while len(nodes) > 1:
        if len(nodes) % 2 == 1:
            # Pad an odd level by duplicating its last node. A node in that
            # last (even) slot then finds its own copy as its right sibling.
            nodes.append(nodes[-1])
        partner = pos ^ 1
        sibling_hexes.append(nodes[partner].hex())
        sides.append("L" if partner < pos else "R")
        nodes = [
            _pair_hash(nodes[k], nodes[k + 1]) for k in range(0, len(nodes), 2)
        ]
        pos //= 2

    return InclusionProof(
        leaf_hash=leaves[index],
        index=index,
        tree_size=len(leaves),
        root=nodes[0].hex() if root is None else root,
        siblings=sibling_hexes,
        directions=sides,
    )
| 122 | |
| 123 | |
def verify_inclusion(proof: InclusionProof) -> bool:
    """Independently verify an inclusion proof. Returns True iff valid.

    Recomputes the root from `proof.leaf_hash` and `proof.siblings` and
    compares it byte-wise with `proof.root`. The proof's shape must also be
    consistent with its claimed position, so it cannot be presented under a
    different `index`/`tree_size`:

    * `0 <= index < tree_size`;
    * exactly ceil(log2(tree_size)) siblings AND directions (mismatched
      list lengths are rejected rather than silently truncated);
    * each direction matches the parity of the node position at that level
      ('L' for odd positions, 'R' for even ones);
    * the trailing node of an odd-sized level must have itself as sibling,
      because the tree duplicates that node;
    * every sibling has the same byte length as a node digest.

    Malformed hex or field types yield False instead of raising.

    NOTE: this only checks the proof against its own `root` field. Callers
    must separately compare `proof.root` with a root they trust.
    """
    index, size = proof.index, proof.tree_size
    if not (isinstance(index, int) and isinstance(size, int)):
        return False
    if not 0 <= index < size:
        return False

    # Number of levels above the leaves, i.e. ceil(log2(size)).
    depth = (size - 1).bit_length()
    try:
        if len(proof.siblings) != depth or len(proof.directions) != depth:
            return False
        current = _leaf_hash_bytes(bytes.fromhex(proof.leaf_hash))
        expected_root = bytes.fromhex(proof.root)
        path = [bytes.fromhex(sib_hex) for sib_hex in proof.siblings]
    except (TypeError, ValueError):
        return False

    pos, width = index, size
    for sib, direction in zip(path, proof.directions):
        if len(sib) != len(current):
            return False
        if pos % 2 == 1:
            if direction != "L":
                return False
            current = _pair_hash(sib, current)
        else:
            if direction != "R":
                return False
            # The trailing node of an odd-sized level is paired with itself.
            if pos == width - 1 and sib != current:
                return False
            current = _pair_hash(current, sib)
        pos //= 2
        width = (width + 1) // 2
    return current == expected_root
| 136 | |