src/pqc_reasoning_ledger/merkle.py
4.3 KB · 136 lines · python Raw
1 """SHA3-256 Merkle tree with domain-separated leaves/internal nodes.
2
3 Leaves: SHA3-256(0x00 || leaf_value)
4 Internal node: SHA3-256(0x01 || left || right)
5
6 For odd levels, the last node is hashed with a copy of itself (Bitcoin-style duplication); only the 0x00/0x01 prefixes follow RFC 6962.
7 """
8
9 from __future__ import annotations
10
11 import hashlib
12 from dataclasses import asdict, dataclass
13 from typing import Any
14
15 from pqc_reasoning_ledger.errors import ReasoningLedgerError
16
17
def _pair_hash(left: bytes, right: bytes) -> bytes:
    """Hash an internal node as SHA3-256(0x01 || left || right).

    The 0x01 prefix domain-separates internal nodes from leaves.
    """
    hasher = hashlib.sha3_256()
    hasher.update(b"\x01")
    hasher.update(left)
    hasher.update(right)
    return hasher.digest()
21
22
def _leaf_hash_bytes(leaf_bytes: bytes) -> bytes:
    """Hash a leaf as SHA3-256(0x00 || leaf_bytes).

    The 0x00 prefix domain-separates leaves from internal nodes.
    """
    hasher = hashlib.sha3_256()
    hasher.update(b"\x00")
    hasher.update(leaf_bytes)
    return hasher.digest()
26
27
@dataclass(frozen=True)
class InclusionProof:
    """Proof that a leaf is in a Merkle tree.

    `siblings` is the list of sibling hashes walking from leaf to root.
    `directions` are 'L' or 'R' indicating which side each sibling is on.
    `index` is the 0-based leaf index; `tree_size` is total leaf count.
    """

    leaf_hash: str  # hex leaf value, BEFORE the 0x00-prefixed leaf hashing
    index: int
    tree_size: int
    root: str  # hex
    siblings: list[str]  # hex, each a sibling node hash
    directions: list[str]  # 'L' or 'R'

    def to_dict(self) -> dict[str, Any]:
        """Return the proof as a plain dict keyed by field name.

        `dataclasses.asdict` copies the list fields, so mutating the
        returned dict does not affect this proof.
        """
        return asdict(self)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> InclusionProof:
        """Build a proof from a dict shaped like the output of `to_dict`.

        List values are shallow-copied so that later mutation of `data`
        cannot silently alter this frozen proof. Missing or unexpected
        keys raise TypeError from the dataclass constructor.
        """
        fields = {
            key: list(value) if isinstance(value, list) else value
            for key, value in data.items()
        }
        return cls(**fields)
50
51
def compute_merkle_root(leaves: list[str]) -> str:
    """Return the hex-encoded SHA3-256 Merkle root of `leaves`.

    Each entry of `leaves` is a hex string. It is decoded and hashed as a
    leaf node, then adjacent nodes are combined level by level until one
    node remains. On a level with an odd number of nodes, the final node
    is combined with itself.

    NOTE: because of that self-pairing, a list whose last leaf is repeated
    (e.g. [a, b, c, c]) yields the same root as the shorter list
    ([a, b, c]); the leaf count is not bound into the root.

    Raises ReasoningLedgerError if `leaves` is empty.
    """
    if not leaves:
        raise ReasoningLedgerError("cannot compute root of empty tree")

    nodes = [_leaf_hash_bytes(bytes.fromhex(leaf_hex)) for leaf_hex in leaves]
    while len(nodes) > 1:
        lefts = nodes[0::2]
        rights = nodes[1::2]
        if len(rights) < len(lefts):
            # Odd node count: the unpaired last node is its own partner.
            rights.append(lefts[-1])
        nodes = [_pair_hash(lhs, rhs) for lhs, rhs in zip(lefts, rights)]
    return nodes[0].hex()
69
70
def build_proof(
    leaves: list[str], index: int, root: str | None = None
) -> InclusionProof:
    """Build the inclusion proof for `leaves[index]`.

    The tree is rebuilt level by level. At each level the sibling of the
    tracked node is recorded together with the side it sits on: 'L' when
    the tracked node has an odd position, 'R' otherwise. When the tracked
    node is the unpaired last node of an odd-sized level, it is recorded
    as its own sibling.

    If `root` is given it is stored on the proof as-is (it is not checked
    against the tree); otherwise the root computed from `leaves` is used.

    Raises ReasoningLedgerError if `leaves` is empty or `index` is out of
    range.
    """
    if not leaves:
        raise ReasoningLedgerError("cannot build proof for empty tree")
    if not 0 <= index < len(leaves):
        raise ReasoningLedgerError(
            f"index {index} out of range [0, {len(leaves) - 1}]"
        )

    nodes = [_leaf_hash_bytes(bytes.fromhex(leaf_hex)) for leaf_hex in leaves]
    path: list[str] = []
    sides: list[str] = []
    position = index

    while len(nodes) > 1:
        # The partner of an even position is position + 1, of an odd one
        # position - 1. A missing partner means we are the unpaired last
        # node, which is hashed with itself.
        partner = position ^ 1
        if partner >= len(nodes):
            partner = position
        path.append(nodes[partner].hex())
        sides.append("L" if position & 1 else "R")

        lefts = nodes[0::2]
        rights = nodes[1::2]
        if len(rights) < len(lefts):
            rights.append(lefts[-1])
        nodes = [_pair_hash(lhs, rhs) for lhs, rhs in zip(lefts, rights)]
        position >>= 1

    return InclusionProof(
        leaf_hash=leaves[index],
        index=index,
        tree_size=len(leaves),
        root=nodes[0].hex() if root is None else root,
        siblings=path,
        directions=sides,
    )
122
123
def verify_inclusion(proof: InclusionProof) -> bool:
    """Independently verify an inclusion proof. Returns True iff valid.

    A proof is accepted only when all of the following hold:

    * `index` and `tree_size` are ints with 0 <= index < tree_size;
    * `siblings` and `directions` have the same length, equal to the
      number of levels between a leaf and the root of a `tree_size`-leaf
      tree;
    * every direction matches the side implied by `index` at that level
      ('L' when the node position is odd, 'R' when it is even);
    * `leaf_hash`, `root` and every sibling are valid hex;
    * folding the leaf hash with the siblings reproduces `root`.

    Malformed proofs yield False rather than raising.
    """
    index = proof.index
    tree_size = proof.tree_size
    if not (isinstance(index, int) and isinstance(tree_size, int)):
        return False
    if not 0 <= index < tree_size:
        return False

    try:
        # Each level halves the node count (rounding up), so a tree with
        # n leaves has (n - 1).bit_length() levels below the root.
        depth = (tree_size - 1).bit_length()
        if len(proof.siblings) != depth or len(proof.directions) != depth:
            return False
        current = _leaf_hash_bytes(bytes.fromhex(proof.leaf_hash))
        # Compare decoded bytes so the hex case of `root` is irrelevant.
        expected_root = bytes.fromhex(proof.root)
        sibling_hashes = [bytes.fromhex(sib_hex) for sib_hex in proof.siblings]
    except (TypeError, ValueError):
        return False

    position = index
    for sib, direction in zip(sibling_hashes, proof.directions):
        if position & 1:
            # Odd position: we are the right child, sibling is on the left.
            if direction != "L":
                return False
            current = _pair_hash(sib, current)
        else:
            # Even position: we are the left child, sibling is on the right.
            if direction != "R":
                return False
            current = _pair_hash(current, sib)
        position >>= 1
    return current == expected_root
136