src/pqc_audit_log_fs/reader.py
| 1 | """LogReader - read sealed segments back, verify signatures + chain.""" |
| 2 | |
| 3 | from __future__ import annotations |
| 4 | |
| 5 | import hashlib |
| 6 | import json |
| 7 | import os |
| 8 | |
| 9 | from quantumshield.core.algorithms import SignatureAlgorithm |
| 10 | from quantumshield.core.signatures import verify |
| 11 | |
| 12 | from pqc_audit_log_fs.errors import ( |
| 13 | SegmentCorruptedError, |
| 14 | SegmentNotFoundError, |
| 15 | SignatureVerificationError, |
| 16 | ) |
| 17 | from pqc_audit_log_fs.event import InferenceEvent |
| 18 | from pqc_audit_log_fs.merkle import compute_merkle_root |
| 19 | from pqc_audit_log_fs.segment import AuditSegment, SegmentHeader |
| 20 | |
| 21 | |
class LogReader:
    """Read-only access to a log directory.

    Segment ``N`` is stored as two files inside ``log_dir``:
    ``segment-NNNNN.log`` (one JSON event per line) and
    ``segment-NNNNN.sig.json`` (the header carrying the signature), where
    ``NNNNN`` is the segment number zero-padded to at least five digits.
    """

    _PREFIX = "segment-"
    _SIG_SUFFIX = ".sig.json"
    _LOG_SUFFIX = ".log"

    def __init__(self, log_dir: str) -> None:
        """Bind the reader to *log_dir*.

        Raises:
            SegmentNotFoundError: if *log_dir* is not an existing directory.
        """
        if not os.path.isdir(log_dir):
            raise SegmentNotFoundError(f"no directory {log_dir}")
        self.log_dir = log_dir

    def _sig_name(self, segment_number: int) -> str:
        """Return the canonical sig-file name for *segment_number*."""
        return f"{self._PREFIX}{segment_number:05d}{self._SIG_SUFFIX}"

    def _log_name(self, segment_number: int) -> str:
        """Return the canonical log-file name for *segment_number*."""
        return f"{self._PREFIX}{segment_number:05d}{self._LOG_SUFFIX}"

    def list_segments(self) -> list[int]:
        """Return the sorted numbers of all segments that have a sig file.

        Only files whose name is exactly what ``read_header`` would open for
        the parsed number are listed. The whole numeric stem is parsed, so
        segment numbers with more than five digits are not truncated and
        names with trailing junk after the digits are ignored.
        """
        prefix, suffix = self._PREFIX, self._SIG_SUFFIX
        numbers: list[int] = []
        for name in os.listdir(self.log_dir):
            if not (name.startswith(prefix) and name.endswith(suffix)):
                continue
            stem = name[len(prefix):-len(suffix)]
            # ASCII digits only: rejects signs, whitespace, underscores and
            # non-ASCII digit characters that int() would otherwise accept.
            if not (stem.isascii() and stem.isdigit()):
                continue
            number = int(stem)
            # Round-trip check: skip names such as "segment-7.sig.json" that
            # read_header() could never open.
            if name == self._sig_name(number):
                numbers.append(number)
        return sorted(numbers)

    def read_header(self, segment_number: int) -> SegmentHeader:
        """Load the header stored in the segment's ``.sig.json`` file.

        Raises:
            SegmentNotFoundError: if the sig file does not exist.
            SegmentCorruptedError: if the sig file is not valid JSON.
        """
        path = os.path.join(self.log_dir, self._sig_name(segment_number))
        # Open directly instead of exists()-then-open to avoid the race
        # between the check and the read.
        try:
            with open(path, "r", encoding="utf-8") as f:
                data = json.load(f)
        except FileNotFoundError as exc:
            raise SegmentNotFoundError(
                f"no sig file for segment {segment_number}"
            ) from exc
        except json.JSONDecodeError as exc:
            # Same error type read_segment() uses for malformed event lines.
            raise SegmentCorruptedError(
                f"malformed sig json in segment {segment_number}: {exc}"
            ) from exc
        return SegmentHeader.from_dict(data)

    def read_segment(self, segment_number: int) -> AuditSegment:
        """Load a segment's header and all events from its ``.log`` file.

        Blank lines in the log file are skipped.

        Raises:
            SegmentNotFoundError: if the sig file or the log file is missing.
            SegmentCorruptedError: if the sig file or a log line cannot be
                parsed, or the log file is not valid UTF-8.
        """
        header = self.read_header(segment_number)
        path = os.path.join(self.log_dir, self._log_name(segment_number))
        events: list[InferenceEvent] = []
        try:
            with open(path, "r", encoding="utf-8") as f:
                for raw in f:
                    line = raw.strip()
                    if not line:
                        continue
                    try:
                        events.append(InferenceEvent.from_dict(json.loads(line)))
                    except (json.JSONDecodeError, TypeError) as exc:
                        raise SegmentCorruptedError(
                            f"malformed jsonl in segment {segment_number}: {exc}"
                        ) from exc
        except FileNotFoundError as exc:
            raise SegmentNotFoundError(
                f"no jsonl file for segment {segment_number}"
            ) from exc
        except UnicodeDecodeError as exc:
            # Undecodable bytes mean the file content was damaged.
            raise SegmentCorruptedError(
                f"malformed jsonl in segment {segment_number}: {exc}"
            ) from exc
        return AuditSegment(header=header, events=events)

    def verify_segment(self, segment_number: int) -> bool:
        """Verify (a) ML-DSA sig on header, (b) merkle_root matches recomputed root.

        The merkle root is recomputed from the events' leaf hashes (an empty
        segment is expected to declare ``""``). The signature is checked over
        the SHA3-256 digest of the header's canonical bytes.

        Returns:
            ``True`` when both checks pass; failures are raised, never
            returned as ``False``.

        Raises:
            SegmentNotFoundError: if a segment file is missing.
            SegmentCorruptedError: if a file is unparseable or the merkle
                root does not match.
            SignatureVerificationError: if the algorithm is unknown, the
                signature/public key is not valid hex, or the signature
                does not verify.
        """
        segment = self.read_segment(segment_number)
        header = segment.header
        leaves = [e.leaf_hash() for e in segment.events]
        recomputed = compute_merkle_root(leaves) if leaves else ""
        if recomputed != header.merkle_root:
            raise SegmentCorruptedError(
                f"segment {segment_number} merkle_root mismatch: "
                f"declared={header.merkle_root[:16]}..., "
                f"recomputed={recomputed[:16]}..."
            )
        try:
            algorithm = SignatureAlgorithm(header.algorithm)
        except ValueError as exc:
            raise SignatureVerificationError(
                f"unknown algorithm {header.algorithm}"
            ) from exc
        canonical = hashlib.sha3_256(header.canonical_bytes()).digest()
        # A tampered header may carry non-hex text; report that as a
        # verification failure rather than leaking a bare ValueError.
        try:
            signature = bytes.fromhex(header.signature)
            public_key = bytes.fromhex(header.public_key)
        except (TypeError, ValueError) as exc:
            raise SignatureVerificationError(
                f"segment {segment_number} signature or public key is not valid hex"
            ) from exc
        if not verify(canonical, signature, public_key, algorithm):
            raise SignatureVerificationError(
                f"segment {segment_number} ML-DSA signature invalid"
            )
        return True

    def verify_chain(self) -> tuple[bool, list[str]]:
        """Verify every segment's sig + chain link. Returns (ok, errors).

        Segments are visited in ascending number order. Each segment must
        pass ``verify_segment`` and its ``previous_segment_root`` must equal
        the ``merkle_root`` of the segment visited before it (``""`` for the
        first one). Missing, corrupted, or badly signed segments are
        collected in ``errors`` instead of aborting the walk.
        """
        errors: list[str] = []
        prev_root = ""
        for n in self.list_segments():
            try:
                self.verify_segment(n)
                header = self.read_header(n)
            except (
                SegmentNotFoundError,
                SegmentCorruptedError,
                SignatureVerificationError,
            ) as exc:
                errors.append(f"segment {n}: {exc}")
                # Even if verification fails, try to track chain if header readable
                try:
                    prev_root = self.read_header(n).merkle_root
                except (SegmentNotFoundError, SegmentCorruptedError):
                    # Header unreadable: the failure is already recorded above;
                    # keep the last known root so the walk can continue.
                    pass
                continue
            if header.previous_segment_root != prev_root:
                errors.append(
                    f"segment {n} chain break: expected prev "
                    f"{prev_root[:16]}..., got "
                    f"{header.previous_segment_root[:16]}..."
                )
            prev_root = header.merkle_root
        return not errors, errors
| 122 | |