src/pqc_audit_log_fs/reader.py
4.8 KB · 122 lines · python Raw
1 """LogReader - read sealed segments back, verify signatures + chain."""
2
3 from __future__ import annotations
4
5 import hashlib
6 import json
7 import os
8
9 from quantumshield.core.algorithms import SignatureAlgorithm
10 from quantumshield.core.signatures import verify
11
12 from pqc_audit_log_fs.errors import (
13 SegmentCorruptedError,
14 SegmentNotFoundError,
15 SignatureVerificationError,
16 )
17 from pqc_audit_log_fs.event import InferenceEvent
18 from pqc_audit_log_fs.merkle import compute_merkle_root
19 from pqc_audit_log_fs.segment import AuditSegment, SegmentHeader
20
21
class LogReader:
    """Read-only access to a log directory of sealed segments.

    Segment ``N`` consists of two files in ``log_dir``:
    ``segment-NNNNN.sig.json`` (the signed header) and ``segment-NNNNN.log``
    (one JSON-encoded event per non-blank line).
    """

    def __init__(
        self,
        log_dir: str,
        *,
        trusted_public_keys: list[str] | None = None,
    ) -> None:
        """Open *log_dir* for reading.

        Args:
            log_dir: directory holding the segment files.
            trusted_public_keys: optional iterable of hex-encoded public keys.
                When given, a segment whose header embeds any other key fails
                verification. When ``None`` (default) the key embedded in each
                header is used as-is, which only proves a segment is
                self-consistent, not who signed it.

        Raises:
            SegmentNotFoundError: if *log_dir* is not a directory.
            ValueError: if an entry of *trusted_public_keys* is not valid hex.
        """
        if not os.path.isdir(log_dir):
            raise SegmentNotFoundError(f"no directory {log_dir}")
        self.log_dir = log_dir
        # Keys are compared as decoded bytes so hex letter case is irrelevant.
        self._trusted_keys = (
            None
            if trusted_public_keys is None
            else frozenset(bytes.fromhex(k) for k in trusted_public_keys)
        )

    def list_segments(self) -> list[int]:
        """Return the sorted numbers of all segments that have a header file.

        Only files named exactly ``segment-{n:05d}.sig.json`` count, i.e. the
        same name ``read_header(n)`` opens. Stray files such as
        ``segment-00002-old.sig.json`` are ignored, and numbers of 100000 or
        more (six digits) are parsed in full.
        """
        prefix, suffix = "segment-", ".sig.json"
        nums: list[int] = []
        for name in os.listdir(self.log_dir):
            if not (name.startswith(prefix) and name.endswith(suffix)):
                continue
            digits = name[len(prefix):-len(suffix)]
            # isascii() guards against Unicode digits that int() would reject.
            if not (digits.isascii() and digits.isdigit()):
                continue
            n = int(digits)
            # Require the canonical spelling so every returned number maps to
            # a file read_header() can open, and no number appears twice.
            if name == f"{prefix}{n:05d}{suffix}":
                nums.append(n)
        return sorted(nums)

    def read_header(self, segment_number: int) -> SegmentHeader:
        """Load and parse the signed header of *segment_number*.

        Raises:
            SegmentNotFoundError: if the ``.sig.json`` file does not exist.
            SegmentCorruptedError: if the file is not valid UTF-8 JSON or
                ``SegmentHeader.from_dict`` rejects its contents (assumes it
                signals bad input via KeyError/TypeError/ValueError).
        """
        path = os.path.join(self.log_dir, f"segment-{segment_number:05d}.sig.json")
        if not os.path.exists(path):
            raise SegmentNotFoundError(f"no sig file for segment {segment_number}")
        try:
            with open(path, "r", encoding="utf-8") as f:
                data = json.load(f)
            return SegmentHeader.from_dict(data)
        except (ValueError, TypeError, KeyError) as exc:
            # JSONDecodeError and UnicodeDecodeError are both ValueError.
            raise SegmentCorruptedError(
                f"malformed sig file for segment {segment_number}: {exc}"
            ) from exc

    def _read_events(self, segment_number: int) -> list[InferenceEvent]:
        """Parse the ``.log`` file of *segment_number* into events, in order.

        Raises:
            SegmentNotFoundError: if the ``.log`` file does not exist.
            SegmentCorruptedError: on undecodable bytes, invalid JSON, or an
                event that ``InferenceEvent.from_dict`` rejects.
        """
        path = os.path.join(self.log_dir, f"segment-{segment_number:05d}.log")
        if not os.path.exists(path):
            raise SegmentNotFoundError(f"no jsonl file for segment {segment_number}")
        events: list[InferenceEvent] = []
        try:
            # The try spans the loop: UnicodeDecodeError is raised while
            # iterating the file, not while parsing an individual line.
            with open(path, "r", encoding="utf-8") as f:
                for raw in f:
                    line = raw.strip()
                    if line:
                        events.append(InferenceEvent.from_dict(json.loads(line)))
        except (ValueError, TypeError, KeyError) as exc:
            raise SegmentCorruptedError(
                f"malformed jsonl in segment {segment_number}: {exc}"
            ) from exc
        return events

    def read_segment(self, segment_number: int) -> AuditSegment:
        """Load header and events of *segment_number* without verifying them.

        The header is read first, so a missing header is reported before a
        missing ``.log`` file.
        """
        header = self.read_header(segment_number)
        return AuditSegment(header=header, events=self._read_events(segment_number))

    def _check_segment(self, segment_number: int, segment: AuditSegment) -> bool:
        """Verify an already-loaded segment; return True or raise.

        Checks, in order: the Merkle root recomputed from the events equals
        the declared one; the algorithm name is known; the signature and key
        fields are hex; the key is trusted (if a trust set was configured);
        the signature over SHA3-256 of the header's canonical bytes verifies.
        """
        header = segment.header
        leaves = [e.leaf_hash() for e in segment.events]
        # A segment without events is expected to declare an empty-string root.
        recomputed = compute_merkle_root(leaves) if leaves else ""
        if recomputed != header.merkle_root:
            raise SegmentCorruptedError(
                f"segment {segment_number} merkle_root mismatch: "
                f"declared={header.merkle_root[:16]}..., "
                f"recomputed={recomputed[:16]}..."
            )
        try:
            algorithm = SignatureAlgorithm(header.algorithm)
        except ValueError as exc:
            raise SignatureVerificationError(
                f"unknown algorithm {header.algorithm}"
            ) from exc
        try:
            signature = bytes.fromhex(header.signature)
            public_key = bytes.fromhex(header.public_key)
        except (TypeError, ValueError) as exc:
            # Non-hex or missing fields are a verification failure, not a crash.
            raise SignatureVerificationError(
                f"segment {segment_number} signature or public key is not valid hex"
            ) from exc
        if self._trusted_keys is not None and public_key not in self._trusted_keys:
            # Without this check anyone could re-sign a rewritten segment with
            # their own key and embed that key in the header.
            raise SignatureVerificationError(
                f"segment {segment_number} signed by untrusted public key"
            )
        digest = hashlib.sha3_256(header.canonical_bytes()).digest()
        # NOTE(review): assumes quantumshield's verify() returns False on a bad
        # signature rather than raising for malformed key/signature sizes.
        if not verify(digest, signature, public_key, algorithm):
            raise SignatureVerificationError(
                f"segment {segment_number} ML-DSA signature invalid"
            )
        return True

    def verify_segment(self, segment_number: int) -> bool:
        """Verify one segment's Merkle root and header signature.

        Returns True on success; every failure is reported by raising
        SegmentNotFoundError, SegmentCorruptedError or
        SignatureVerificationError.
        """
        return self._check_segment(segment_number, self.read_segment(segment_number))

    def verify_chain(self) -> tuple[bool, list[str]]:
        """Verify every segment's sig + chain link. Returns (ok, errors).

        Segment-level problems are collected as messages instead of raised.
        Each header is read once and the same object is both verified and
        chain-checked. The chain link (``previous_segment_root`` equal to the
        prior segment's ``merkle_root``; ``""`` for the first) is only checked
        for segments that verified.
        """
        errors: list[str] = []
        prev_root: str = ""
        for n in self.list_segments():
            try:
                header = self.read_header(n)
            except (SegmentNotFoundError, SegmentCorruptedError) as exc:
                # No usable header: nothing to verify and no root to chain
                # from, so prev_root stays unchanged.
                errors.append(f"segment {n}: {exc}")
                continue
            try:
                segment = AuditSegment(header=header, events=self._read_events(n))
                self._check_segment(n, segment)
            except (
                SegmentNotFoundError,
                SegmentCorruptedError,
                SignatureVerificationError,
            ) as exc:
                errors.append(f"segment {n}: {exc}")
            else:
                if header.previous_segment_root != prev_root:
                    errors.append(
                        f"segment {n} chain break: expected prev "
                        f"{prev_root[:16]}..., got "
                        f"{header.previous_segment_root[:16]}..."
                    )
            # Advance even after a failure so the next link is checked against
            # this segment's declared root instead of cascading errors.
            prev_root = header.merkle_root
        return not errors, errors
122