src/pqc_rag_signing/verifier.py
3.5 KB · 115 lines · python Raw
1 """Retrieval-time verification wrapper for RAG pipelines."""
2
3 from __future__ import annotations
4
5 from dataclasses import dataclass, field
6 from datetime import datetime, timezone
7 from typing import Iterable
8
9 from pqc_rag_signing.chunk import SignedChunk
10 from pqc_rag_signing.errors import TamperedChunkError
11 from pqc_rag_signing.signer import ChunkSigner, VerificationResult
12
13
@dataclass
class RetrievalResult:
    """Outcome of verifying one batch of retrieved chunks.

    Chunks are split into two buckets: ``verified`` holds the chunks that
    passed, ``failed`` pairs each rejected chunk with the verification
    result describing why it was rejected.
    """

    # Number of chunks examined in the batch.
    total: int
    # Chunks that passed verification.
    verified: list[SignedChunk] = field(default_factory=list)
    # (chunk, verification result) pairs for chunks that did not pass.
    failed: list[tuple[SignedChunk, VerificationResult]] = field(default_factory=list)
    # Timestamp string recorded by whoever produced this result.
    verified_at: str = ""
    # Signer DIDs that were treated as trusted for this batch.
    trusted_signers: set[str] = field(default_factory=set)

    @property
    def all_verified(self) -> bool:
        """True when no chunk landed in the ``failed`` bucket."""
        return not self.failed

    @property
    def verified_count(self) -> int:
        """Number of chunks in the ``verified`` bucket."""
        passed = self.verified
        return len(passed)

    @property
    def failed_count(self) -> int:
        """Number of chunks in the ``failed`` bucket."""
        rejected = self.failed
        return len(rejected)

    def verified_texts(self) -> list[str]:
        """Return the ``text`` of every verified chunk, in order.

        Failed chunks are never included, so the returned list contains
        only content that passed verification.
        """
        texts: list[str] = []
        for chunk in self.verified:
            texts.append(chunk.text)
        return texts
39
40
class RetrievalVerifier:
    """Verify chunks retrieved from a vector DB before passing to an LLM.

    Every chunk is checked with ``ChunkSigner.verify_chunk``. An optional
    allow-list of trusted signer DIDs is applied on top: a chunk whose
    signature verifies but whose ``signer_did`` is not in the allow-list is
    rejected.

    Allow-list semantics:
        * ``trusted_signers=None`` (the default): no allow-list is applied;
          any chunk that verifies is accepted.
        * ``trusted_signers=set()``: an allow-list that trusts nobody;
          every chunk is rejected.
        * a non-empty set: only chunks signed by a listed DID are accepted.

    Usage:
        verifier = RetrievalVerifier(
            trusted_signers={"did:pqaid:abc123..."},
            strict=True,
        )
        result = verifier.verify_retrieved(signed_chunks)
        if not result.all_verified:
            # handle failures
            ...
        safe_texts = result.verified_texts()
    """

    def __init__(
        self,
        trusted_signers: set[str] | None = None,
        strict: bool = True,
    ) -> None:
        """Store the verifier configuration.

        Args:
            trusted_signers: Allow-list of signer DIDs, or ``None`` to
                disable allow-listing. The set is stored by reference.
            strict: Stored on the instance. None of the methods defined in
                this class read it.
        """
        self.trusted_signers = trusted_signers
        self.strict = strict

    def _is_trusted(self, signer_did: str) -> bool:
        """Return whether *signer_did* passes the allow-list check."""
        # Compare against None rather than relying on truthiness: an empty
        # allow-list must reject every signer, not silently disable the check.
        if self.trusted_signers is None:
            return True
        return signer_did in self.trusted_signers

    def verify_retrieved(
        self,
        chunks: Iterable[SignedChunk],
    ) -> RetrievalResult:
        """Verify each chunk and bucket it into verified vs failed.

        Args:
            chunks: Chunks to check; consumed once, in order.

        Returns:
            A ``RetrievalResult`` whose ``total`` counts the chunks seen,
            whose ``verified_at`` is the UTC time (ISO 8601) at which
            verification started, and whose ``trusted_signers`` is a copy
            of the allow-list in effect (empty when none is configured).
        """
        result = RetrievalResult(
            total=0,
            verified_at=datetime.now(timezone.utc).isoformat(),
            # Copy so the result keeps its own snapshot and mutating it
            # cannot alter this verifier's allow-list (or vice versa).
            trusted_signers=set(self.trusted_signers or ()),
        )
        for chunk in chunks:
            result.total += 1
            verification = ChunkSigner.verify_chunk(chunk)

            # A valid signature from a signer outside the allow-list is
            # downgraded to a failure with an explanatory error.
            if verification.valid and not self._is_trusted(chunk.signer_did):
                verification = VerificationResult(
                    valid=False,
                    chunk_id=chunk.chunk_id,
                    signer_did=chunk.signer_did,
                    algorithm=chunk.algorithm,
                    error=f"signer {chunk.signer_did} not in trusted allow-list",
                )

            if verification.valid:
                result.verified.append(chunk)
            else:
                result.failed.append((chunk, verification))

        return result

    def verify_or_raise(
        self,
        chunks: Iterable[SignedChunk],
    ) -> list[SignedChunk]:
        """Like ``verify_retrieved``, but raise on any failure.

        Args:
            chunks: Chunks to check.

        Returns:
            The verified chunks, when every chunk passed.

        Raises:
            TamperedChunkError: If at least one chunk failed verification,
                including rejection by the allow-list. The message reports
                the failure count and the first failure's error.
        """
        result = self.verify_retrieved(chunks)
        if not result.all_verified:
            _, first_failure = result.failed[0]
            raise TamperedChunkError(
                f"{result.failed_count}/{result.total} chunks failed "
                f"verification. First failure: {first_failure.error}"
            )
        return result.verified
115