tests/test_merkle.py
3.7 KB · 116 lines · python Raw
1 """Tests for the Merkle tree."""
2
3 import pytest
4
5 from pqc_training_data import DataRecord, MerkleTree
6 from pqc_training_data.errors import (
7 EmptyTreeError,
8 IndexOutOfRangeError,
9 )
10
11
def _tree_from_records(records: list[DataRecord]) -> MerkleTree:
    """Return a new MerkleTree populated with the leaf hash of each record.

    Leaves are added in the same order as *records*, in a single
    ``add_many`` call.
    """
    tree = MerkleTree()
    leaf_hashes = [record.leaf_hash() for record in records]
    tree.add_many(leaf_hashes)
    return tree
16
17
def test_empty_tree_root_raises() -> None:
    """Asking a freshly constructed, leafless tree for its root must fail."""
    empty_tree = MerkleTree()
    with pytest.raises(EmptyTreeError):
        empty_tree.root()
22
23
def test_single_leaf_root(single_record: DataRecord) -> None:
    """A tree built from exactly one record reports a 64-character root."""
    single_leaf_tree = _tree_from_records([single_record])
    digest = single_leaf_tree.root()
    # The intended construction for a one-leaf tree is H(0x00 || leaf_hash);
    # this test does not recompute that value, it only pins the root length.
    assert len(digest) == 64
29
30
def test_root_deterministic_same_inputs(sample_records: list[DataRecord]) -> None:
    """Two trees built separately from the same records must agree on the root."""
    # Build both trees before asking either one for its root.
    trees = [_tree_from_records(sample_records) for _ in range(2)]
    first_root, second_root = (tree.root() for tree in trees)
    assert first_root == second_root
35
36
def test_root_changes_with_leaf_change(sample_records: list[DataRecord]) -> None:
    """Swapping the record at index 2 for a different one must change the root."""
    baseline_tree = _tree_from_records(sample_records)

    # Work on a copy so the fixture list itself is left untouched.
    altered_records = list(sample_records)
    replacement = DataRecord(content=b"CHANGED", metadata={"doc_id": 2})
    altered_records[2] = replacement

    altered_tree = _tree_from_records(altered_records)
    assert baseline_tree.root() != altered_tree.root()
43
44
def test_inclusion_proof_even_count(sample_records: list[DataRecord]) -> None:
    """Every leaf of a four-leaf tree yields a proof that verifies.

    Each proof must report the tree's root, a tree size of 4, and the hex
    leaf hash of the record it was requested for.
    """
    # The fixture has an odd number of records (5); keep only the first
    # four so the tree is balanced.
    balanced = sample_records[:4]
    tree = _tree_from_records(balanced)
    expected_root = tree.root()

    for position, record in enumerate(balanced):
        proof = tree.inclusion_proof(position)
        assert proof.root == expected_root
        assert proof.tree_size == 4
        assert proof.leaf_hash == record.leaf_hash().hex
        assert MerkleTree.verify_inclusion(proof)
56
57
def test_inclusion_proof_odd_count(odd_records: list[DataRecord]) -> None:
    """Every leaf of an odd-sized tree yields a proof that verifies.

    The ``odd_records`` fixture is expected to supply 7 records, which is
    meant to exercise odd-level promotion at more than one level of the
    tree. For each leaf, the proof must carry the tree's root, a tree size
    of 7, and the hex leaf hash of the record at that index, and it must
    pass ``MerkleTree.verify_inclusion``.
    """
    tree = _tree_from_records(odd_records)
    root = tree.root()
    for index, record in enumerate(odd_records):
        proof = tree.inclusion_proof(index)
        assert proof.root == root
        assert proof.tree_size == 7
        # Bind the proof to the record it was requested for; without this a
        # self-consistent proof for the wrong leaf would still pass.
        assert proof.leaf_hash == record.leaf_hash().hex
        assert MerkleTree.verify_inclusion(proof)
67
68
def test_verify_inclusion_success(sample_records: list[DataRecord]) -> None:
    """An unmodified proof for leaf 2 verifies, and the result is exactly ``True``."""
    tree = _tree_from_records(sample_records)
    untouched_proof = tree.inclusion_proof(2)
    verdict = MerkleTree.verify_inclusion(untouched_proof)
    # Identity check: a merely truthy return value is not enough.
    assert verdict is True
73
74
def test_verify_inclusion_wrong_leaf_fails(sample_records: list[DataRecord]) -> None:
    """A proof whose leaf hash has been replaced must verify as exactly ``False``."""
    genuine = _tree_from_records(sample_records).inclusion_proof(2)

    # Rebuild the proof through its own class, keeping every field except the
    # leaf hash, which becomes 64 zero characters.
    proof_cls = type(genuine)
    forged_fields = {
        "leaf_hash": "0" * 64,
        "index": genuine.index,
        "tree_size": genuine.tree_size,
        "root": genuine.root,
        "siblings": list(genuine.siblings),
        "directions": list(genuine.directions),
    }
    forged = proof_cls(**forged_fields)

    assert MerkleTree.verify_inclusion(forged) is False
89
90
def test_verify_inclusion_tampered_sibling_fails(sample_records: list[DataRecord]) -> None:
    """Altering one character of the first sibling hash must make the proof fail."""
    genuine = _tree_from_records(sample_records).inclusion_proof(1)

    # Change only the leading character of the first sibling, choosing a
    # replacement that differs from it and is still a hex digit.
    siblings = list(genuine.siblings)
    original_sibling = siblings[0]
    replacement_char = "0" if original_sibling[0] == "f" else "f"
    siblings[0] = replacement_char + original_sibling[1:]

    proof_cls = type(genuine)
    forged = proof_cls(
        leaf_hash=genuine.leaf_hash,
        index=genuine.index,
        tree_size=genuine.tree_size,
        root=genuine.root,
        siblings=siblings,
        directions=list(genuine.directions),
    )

    assert MerkleTree.verify_inclusion(forged) is False
108
109
def test_index_out_of_range_raises(sample_records: list[DataRecord]) -> None:
    """Indices past the end (99) and below zero (-1) are both rejected."""
    tree = _tree_from_records(sample_records)
    # Same order as before: the too-large index first, then the negative one.
    for bad_index in (99, -1):
        with pytest.raises(IndexOutOfRangeError):
            tree.inclusion_proof(bad_index)
116