src/pqc_training_data/merkle.py
4.9 KB · 145 lines · python Raw
1 """SHA3-256 Merkle tree over a list of leaf hashes."""
2
3 from __future__ import annotations
4
5 import hashlib
6 from dataclasses import asdict, dataclass, field
7 from typing import Any
8
9 from pqc_training_data.errors import (
10 EmptyTreeError,
11 InclusionProofError,
12 IndexOutOfRangeError,
13 )
14 from pqc_training_data.record import RecordHash
15
16
def _pair_hash(left: bytes, right: bytes) -> bytes:
    """Hash two child nodes into their parent node.

    The digest is SHA3-256 over the single byte 0x01 followed by `left`
    and then `right`. The 0x01 tag keeps internal-node hashes in a
    different domain from leaf hashes, which are tagged with 0x00.

    Returns the raw 32-byte digest.
    """
    hasher = hashlib.sha3_256()
    hasher.update(b"\x01")  # internal-node domain tag
    hasher.update(left)
    hasher.update(right)
    return hasher.digest()
24
25
def _leaf_hash_bytes(leaf_bytes: bytes) -> bytes:
    """Turn a leaf value into its leaf-node hash.

    The digest is SHA3-256 over the single byte 0x00 followed by
    `leaf_bytes`. The 0x00 tag is the leaf-side domain separator, so a
    leaf node can never hash the same way as an internal node (tagged 0x01).

    Returns the raw 32-byte digest.
    """
    hasher = hashlib.sha3_256()
    hasher.update(b"\x00")  # leaf domain tag
    hasher.update(leaf_bytes)
    return hasher.digest()
29
30
@dataclass(frozen=True)
class InclusionProof:
    """Merkle inclusion proof for a single leaf.

    Holds what is needed to recompute the root starting from one leaf:
    the leaf's own hash, the sibling hashes met on the way up (ordered
    from leaf to root), and, for each sibling, whether it sits to the
    left ('L') or right ('R') of the node on the path. `index` is the
    leaf's 0-based position and `tree_size` is the total leaf count.
    All hashes are stored as hex strings.
    """

    # Hex of the ORIGINAL leaf hash, i.e. before the 0x00 leaf prefix is applied.
    leaf_hash: str
    # 0-based position of the leaf.
    index: int
    # Total number of leaves.
    tree_size: int
    # Hex of the Merkle root.
    root: str
    # Hex of each sibling node hash (raw node hash), leaf level first.
    siblings: list[str]
    # One entry per sibling: 'L' if the sibling is left of our path, 'R' if right.
    directions: list[str]

    def to_dict(self) -> dict[str, Any]:
        """Return the proof's fields as a plain dict keyed by field name."""
        payload: dict[str, Any] = asdict(self)
        return payload
49
50
@dataclass
class MerkleTree:
    """SHA3-256 Merkle tree. Works for any number of leaves (>= 1).

    Leaf nodes are SHA3-256(0x00 || leaf_value); internal nodes are
    SHA3-256(0x01 || left || right).

    When a level has an odd number of nodes, the last node is paired with
    a copy of itself (duplicate-last-node handling). This is NOT the
    RFC 6962 construction, which never duplicates nodes.

    NOTE: because of the duplication, the root alone does not pin down the
    leaf count: leaves ``[a, b, c]`` and ``[a, b, c, c]`` produce the same
    root. Callers that need to distinguish such trees must authenticate
    the tree size separately.
    """

    leaves: list[RecordHash] = field(default_factory=list)

    def add(self, leaf_hash: RecordHash) -> None:
        """Append a single leaf to the end of the tree."""
        self.leaves.append(leaf_hash)

    def add_many(self, leaf_hashes: list[RecordHash]) -> None:
        """Append several leaves, preserving their order."""
        self.leaves.extend(leaf_hashes)

    @property
    def size(self) -> int:
        """Number of leaves currently in the tree."""
        return len(self.leaves)

    def _hashed_leaves(self) -> list[bytes]:
        """Return the bottom level: one 0x00-prefixed hash per leaf."""
        return [_leaf_hash_bytes(leaf.bytes) for leaf in self.leaves]

    @staticmethod
    def _parent_level(level: list[bytes]) -> list[bytes]:
        """Collapse one level into the level above it.

        Nodes are paired left to right; an unpaired last node is hashed
        with a copy of itself.
        """
        parents: list[bytes] = []
        for i in range(0, len(level), 2):
            left = level[i]
            right = level[i + 1] if i + 1 < len(level) else left
            parents.append(_pair_hash(left, right))
        return parents

    def root(self) -> str:
        """Compute the Merkle root hex. Raises EmptyTreeError if no leaves."""
        if not self.leaves:
            raise EmptyTreeError("cannot compute root of empty tree")

        level = self._hashed_leaves()
        while len(level) > 1:
            level = self._parent_level(level)
        return level[0].hex()

    def inclusion_proof(self, index: int) -> InclusionProof:
        """Generate an inclusion proof for the leaf at `index`.

        Raises:
            EmptyTreeError: the tree has no leaves.
            IndexOutOfRangeError: `index` is not in ``[0, size - 1]``.
        """
        if not self.leaves:
            raise EmptyTreeError("empty tree has no proofs")
        if index < 0 or index >= len(self.leaves):
            raise IndexOutOfRangeError(
                f"index {index} out of range [0, {len(self.leaves) - 1}]"
            )

        level = self._hashed_leaves()
        siblings: list[str] = []
        directions: list[str] = []
        idx = index

        while len(level) > 1:
            sib_index = idx ^ 1  # flip the low bit: the other node of our pair
            if sib_index < len(level):
                sib = level[sib_index]
                direction = "L" if sib_index < idx else "R"
            else:
                # Only reachable when idx is the unpaired last node, which is
                # always at an even position: it is hashed with a copy of
                # itself placed on the right.
                sib = level[idx]
                direction = "R"
            siblings.append(sib.hex())
            directions.append(direction)

            idx //= 2
            level = self._parent_level(level)

        return InclusionProof(
            # `hex` is read as an attribute of the leaf object, not called.
            leaf_hash=self.leaves[index].hex,
            index=index,
            tree_size=self.size,
            root=level[0].hex(),
            siblings=siblings,
            directions=directions,
        )

    @staticmethod
    def verify_inclusion(proof: InclusionProof) -> bool:
        """Independently verify an inclusion proof.

        Recomputes the root from `proof.leaf_hash`, `proof.siblings` and
        `proof.directions`, and returns True iff it equals `proof.root`.

        NOTE: `proof.index` and `proof.tree_size` are not consulted, so a
        True result says nothing about those two fields.

        Raises:
            InclusionProofError: `siblings` and `directions` differ in
                length, or a direction is neither 'L' nor 'R'.
            ValueError: propagated from `bytes.fromhex` when a hash is not
                valid hex.
        """
        # zip() would silently drop the unmatched tail and could let a
        # malformed proof verify, so reject a shape mismatch explicitly.
        if len(proof.siblings) != len(proof.directions):
            raise InclusionProofError(
                f"siblings/directions length mismatch: "
                f"{len(proof.siblings)} != {len(proof.directions)}"
            )

        current = _leaf_hash_bytes(bytes.fromhex(proof.leaf_hash))
        for sib_hex, direction in zip(proof.siblings, proof.directions):
            sib = bytes.fromhex(sib_hex)
            if direction == "L":
                current = _pair_hash(sib, current)
            elif direction == "R":
                current = _pair_hash(current, sib)
            else:
                raise InclusionProofError(f"invalid direction: {direction}")
        return current.hex() == proof.root
145