tests/test_prover.py
2.7 KB · 86 lines · python Raw
1 """Tests for InclusionProver."""
2
3 from __future__ import annotations
4
5 from collections.abc import Callable
6
7 import pytest
8 from quantumshield.identity.agent import AgentIdentity
9
10 from pqc_audit_log_fs.appender import LogAppender, RotationPolicy
11 from pqc_audit_log_fs.errors import SegmentNotFoundError
12 from pqc_audit_log_fs.event import InferenceEvent
13 from pqc_audit_log_fs.prover import InclusionProver
14 from pqc_audit_log_fs.reader import LogReader
15
16
def _seed_segment(
    log_dir: str,
    signer: AgentIdentity,
    factory: Callable[..., InferenceEvent],
    n: int = 8,
) -> list[InferenceEvent]:
    """Build ``n`` events, append them to a fresh log in ``log_dir``, return them.

    The rotation limit (1000 events per segment) is far above the sizes used
    by callers in this module, which then address the data as segment ``1``.

    Args:
        log_dir: Directory the ``LogAppender`` writes into.
        signer: Identity handed to the ``LogAppender``.
        factory: Zero-argument callable producing one ``InferenceEvent``.
        n: Number of events to create and append.

    Returns:
        The events in append order. Callers rely on ``events[i]`` having
        proof index ``i``.
    """
    events = [factory() for _ in range(n)]
    app = LogAppender(
        log_dir,
        signer,
        rotation=RotationPolicy(max_events_per_segment=1000),
    )
    # Close in ``finally`` so a failing append does not leave the appender
    # (and whatever it holds open) dangling into later tests.
    try:
        for event in events:
            app.append(event)
    finally:
        app.close()
    return events
32
33
def test_prove_event_returns_valid_proof(
    signer_identity: AgentIdentity,
    tmp_log_dir: str,
    event_factory: Callable[..., InferenceEvent],
) -> None:
    """A proof for event #3 of 6 has the right position and size, and verifies."""
    events = _seed_segment(tmp_log_dir, signer_identity, event_factory, n=6)
    reader = LogReader(tmp_log_dir)
    prover = InclusionProver(reader)
    target = events[3]
    proof = prover.prove_event(1, target.event_id)
    assert proof.index == 3
    assert proof.tree_size == 6
    # The test name promises a *valid* proof, so check that it actually
    # verifies instead of only inspecting its metadata.
    assert InclusionProver.verify_proof(target, proof) is True
45
46
def test_verify_proof_passes(
    signer_identity: AgentIdentity,
    tmp_log_dir: str,
    event_factory: Callable[..., InferenceEvent],
) -> None:
    """An untouched event verifies against the proof generated for it."""
    seeded = _seed_segment(tmp_log_dir, signer_identity, event_factory, n=10)
    prover = InclusionProver(LogReader(tmp_log_dir))
    chosen = seeded[7]
    inclusion_proof = prover.prove_event(1, chosen.event_id)
    verified = InclusionProver.verify_proof(chosen, inclusion_proof)
    assert verified is True
58
59
def test_verify_proof_fails_for_tampered_event(
    signer_identity: AgentIdentity,
    tmp_log_dir: str,
    event_factory: Callable[..., InferenceEvent],
) -> None:
    """Changing an event after it was proven must make verification fail."""
    events = _seed_segment(tmp_log_dir, signer_identity, event_factory, n=5)
    reader = LogReader(tmp_log_dir)
    prover = InclusionProver(reader)
    target = events[2]
    proof = prover.prove_event(1, target.event_id)

    # Control: an unmodified to_dict/from_dict round-trip must still verify.
    # Without it, a lossy round-trip would make the ``is False`` check below
    # pass for the wrong reason.
    untouched = InferenceEvent.from_dict(target.to_dict())
    assert InclusionProver.verify_proof(untouched, proof) is True

    # Mutate a copy of the event after the proof was generated; make sure the
    # assignment really changes the label so the check is not vacuous.
    tampered = InferenceEvent.from_dict(target.to_dict())
    assert tampered.decision_label != "FORGED"
    tampered.decision_label = "FORGED"
    assert InclusionProver.verify_proof(tampered, proof) is False
74
75
def test_missing_event_raises(
    signer_identity: AgentIdentity,
    tmp_log_dir: str,
    event_factory: Callable[..., InferenceEvent],
) -> None:
    """Proving an event id that was never appended raises SegmentNotFoundError."""
    _seed_segment(tmp_log_dir, signer_identity, event_factory, n=3)
    prover = InclusionProver(LogReader(tmp_log_dir))
    unknown_id = "urn:pqc-audit-evt:does-not-exist"
    with pytest.raises(SegmentNotFoundError):
        prover.prove_event(1, unknown_id)
86