src/pqc_audit_log_fs/merkle.py
4.3 KB · 134 lines · python Raw
1 """SHA3-256 Merkle tree with domain-separated leaves/internal nodes.
2
3 Leaves: SHA3-256(0x00 || leaf_value)
4 Internal node: SHA3-256(0x01 || left || right)
5
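6 For levels with an odd number of nodes, the last node is paired with a copy of itself (duplicated, as in Bitcoin's Merkle tree); only the 0x00/0x01 domain separation follows RFC 6962, which does not duplicate nodes.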
7 """
8
9 from __future__ import annotations
10
11 import hashlib
12 from dataclasses import asdict, dataclass
13 from typing import Any
14
15 from pqc_audit_log_fs.errors import AuditLogError
16
17
def _pair_hash(left: bytes, right: bytes) -> bytes:
    """Hash two child nodes into their parent: SHA3-256(0x01 || left || right).

    The leading 0x01 byte marks the input as an internal node, so it can
    never be confused with a leaf input (which is prefixed with 0x00).
    """
    hasher = hashlib.sha3_256(b"\x01")
    hasher.update(left)
    hasher.update(right)
    return hasher.digest()
21
22
def _leaf_hash_bytes(leaf_bytes: bytes) -> bytes:
    """Hash a leaf value into its tree node: SHA3-256(0x00 || leaf_bytes).

    The leading 0x00 byte marks the input as a leaf, keeping leaf hashes
    distinct from internal-node hashes (which are prefixed with 0x01).
    """
    hasher = hashlib.sha3_256(b"\x00")
    hasher.update(leaf_bytes)
    return hasher.digest()
26
27
@dataclass(frozen=True)
class InclusionProof:
    """Proof that a leaf is in a Merkle tree.

    `siblings` is the list of sibling hashes walking from leaf to root.
    `directions` are 'L' or 'R' indicating which side each sibling is on.
    `index` is the 0-based leaf index; `tree_size` is total leaf count.

    The dataclass is frozen, so fields cannot be rebound; `from_dict` and
    `to_dict` additionally copy the list fields so a proof does not share
    mutable state with the dictionaries it is converted from or to.
    """

    # hex; the leaf value as supplied to the tree, i.e. BEFORE the
    # 0x00-prefixed leaf hashing is applied
    leaf_hash: str
    index: int
    tree_size: int
    root: str  # hex
    siblings: list[str]  # hex, each a sibling node hash
    directions: list[str]  # 'L' or 'R'

    def to_dict(self) -> dict[str, Any]:
        """Return the proof as a plain dict keyed by field name.

        `dataclasses.asdict` copies the list fields, so mutating the result
        does not affect this proof.
        """
        return asdict(self)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> InclusionProof:
        """Build a proof from a dict with exactly the dataclass field names.

        Raises TypeError (from the dataclass constructor) if a field is
        missing or an unknown key is present.
        """
        fields = {**data}
        for key, value in fields.items():
            # Copy list values so later edits to `data` cannot mutate the
            # (nominally immutable) proof through a shared list object.
            if isinstance(value, list):
                fields[key] = list(value)
        return cls(**fields)
50
51
def compute_merkle_root(leaves: list[str]) -> str:
    """Return the hex SHA3-256 Merkle root of the hex-encoded `leaves`.

    Each leaf is hex-decoded and leaf-hashed, then adjacent nodes are
    pair-hashed level by level until one node remains. On a level with an
    odd number of nodes the last node is paired with itself.

    Raises AuditLogError if `leaves` is empty.
    """
    if not leaves:
        raise AuditLogError("cannot compute root of empty tree")

    nodes: list[bytes] = [
        _leaf_hash_bytes(bytes.fromhex(leaf_hex)) for leaf_hex in leaves
    ]
    while len(nodes) > 1:
        if len(nodes) % 2:
            # Odd width: the unpaired last node is hashed with a copy of itself.
            nodes.append(nodes[-1])
        nodes = [
            _pair_hash(lhs, rhs) for lhs, rhs in zip(nodes[0::2], nodes[1::2])
        ]
    return nodes[0].hex()
69
70
def build_merkle_proof(
    leaves: list[str], index: int, root: str | None = None
) -> InclusionProof:
    """Build an inclusion proof for `leaves[index]`.

    `leaves` are hex-encoded leaf values. The proof records, for every level
    from the leaves up to (but excluding) the root, the sibling node hash and
    which side ('L' or 'R') that sibling sits on.

    Raises AuditLogError if leaves is empty or index out of range.
    If `root` is provided, it is stored on the proof verbatim (it is NOT
    checked against the tree); otherwise the root re-computed from `leaves`
    is stored.
    """
    if not leaves:
        raise AuditLogError("cannot build proof for empty tree")
    if index < 0 or index >= len(leaves):
        raise AuditLogError(f"index {index} out of range [0, {len(leaves) - 1}]")

    level: list[bytes] = [_leaf_hash_bytes(bytes.fromhex(h)) for h in leaves]
    siblings: list[str] = []
    directions: list[str] = []
    idx = index

    while len(level) > 1:
        sib_index = idx ^ 1
        # The last node of an odd-width level has no partner and is paired
        # with a copy of itself, so it acts as its own sibling.
        sib = level[sib_index] if sib_index < len(level) else level[idx]
        siblings.append(sib.hex())
        # An odd position is a right child (sibling on the left); an even
        # position is a left child (sibling on the right). The self-paired
        # node is always at an even position, so parity alone decides.
        directions.append("L" if idx % 2 else "R")

        level = [
            _pair_hash(level[i], level[i + 1] if i + 1 < len(level) else level[i])
            for i in range(0, len(level), 2)
        ]
        idx //= 2

    computed_root = level[0].hex()
    return InclusionProof(
        leaf_hash=leaves[index],
        index=index,
        tree_size=len(leaves),
        root=root if root is not None else computed_root,
        siblings=siblings,
        directions=directions,
    )
120
121
def verify_inclusion(proof: InclusionProof) -> bool:
    """Independently verify an inclusion proof. Returns True iff valid.

    A proof is valid when all of the following hold:

    * `index` and `tree_size` are integers with `0 <= index < tree_size`;
    * `siblings` and `directions` have the same length, and that length is
      exactly the height of a tree with `tree_size` leaves;
    * every direction matches the side implied by `index` at that level;
    * where the path node is the unpaired last node of an odd-width level,
      the sibling equals the node itself (it is paired with its own copy);
    * folding the leaf hash with the siblings reproduces `root`.

    Malformed fields (bad hex, wrong types) make the proof invalid rather
    than raising. The root is compared as bytes, so hex case is irrelevant.
    """
    index = proof.index
    width = proof.tree_size
    if not (isinstance(index, int) and isinstance(width, int)):
        return False
    if not 0 <= index < width:
        return False

    try:
        # zip() would silently drop the surplus of a longer list.
        if len(proof.siblings) != len(proof.directions):
            return False

        current = _leaf_hash_bytes(bytes.fromhex(proof.leaf_hash))
        expected_root = bytes.fromhex(proof.root)

        idx = index
        for sib_hex, direction in zip(proof.siblings, proof.directions):
            if width <= 1:
                # More path entries than the tree has levels.
                return False
            # Bind the claimed index to the path: odd positions have their
            # sibling on the left, even positions on the right.
            if direction != ("L" if idx % 2 else "R"):
                return False
            sib = bytes.fromhex(sib_hex)
            if direction == "L":
                current = _pair_hash(sib, current)
            else:
                # Unpaired last node of an odd-width level: its sibling must
                # be a copy of itself, otherwise `tree_size` is misstated.
                if idx + 1 == width and sib != current:
                    return False
                current = _pair_hash(current, sib)
            idx //= 2
            width = (width + 1) // 2
    except (TypeError, ValueError):
        # Non-hex strings or wrongly typed fields: treat as an invalid proof.
        return False

    # width != 1 here means the path stopped short of the root.
    return width == 1 and current == expected_root
134