tests/test_verifier.py
3.2 KB · 99 lines · python Raw
1 """Tests for RetrievalVerifier."""
2
3 from __future__ import annotations
4
5 from dataclasses import replace
6
7 import pytest
8 from quantumshield.identity.agent import AgentIdentity
9
10 from pqc_rag_signing import (
11 ChunkMetadata,
12 ChunkSigner,
13 RetrievalVerifier,
14 TamperedChunkError,
15 )
16
17
def _signed_batch(identity: AgentIdentity, texts: list[str]) -> list:
    """Sign *texts* as a single batch on behalf of *identity*.

    Builds a ``ChunkSigner`` for the identity and returns whatever its
    ``sign_chunks`` call yields for the fixed source label ``"batch.txt"``.
    """
    return ChunkSigner(identity).sign_chunks(texts, source="batch.txt")
21
22
def test_verify_retrieved_all_valid(ingest_identity: AgentIdentity) -> None:
    """An unmodified signed batch verifies in full and keeps its text order."""
    signed = _signed_batch(ingest_identity, ["alpha", "beta", "gamma"])
    report = RetrievalVerifier().verify_retrieved(signed)

    # Counters first, then the aggregate flag, then the surviving payloads.
    assert report.total == 3
    assert report.verified_count == 3
    assert report.failed_count == 0
    assert report.all_verified
    assert report.verified_texts() == ["alpha", "beta", "gamma"]
32
33
def test_verify_retrieved_detects_tampered_chunk(
    ingest_identity: AgentIdentity,
) -> None:
    """A chunk whose text is altered after signing is reported as failed."""
    first, second = _signed_batch(ingest_identity, ["alpha", "beta"])

    # Swap the text of the first chunk after signing (hash will mismatch).
    forged = replace(first, text="EVIL")

    report = RetrievalVerifier().verify_retrieved([forged, second])

    assert not report.all_verified
    assert report.failed_count == 1
    assert report.verified_count == 1
    # The failed entry's first element is the offending chunk.
    assert report.failed[0][0].chunk_id == forged.chunk_id
47
48
def test_verify_retrieved_detects_untrusted_signer(
    ingest_identity: AgentIdentity,
    attacker_identity: AgentIdentity,
) -> None:
    """Chunks signed by a DID outside ``trusted_signers`` are rejected."""
    trusted_chunks = _signed_batch(ingest_identity, ["safe content"])
    rogue_chunks = _signed_batch(attacker_identity, ["poisoned content"])

    # Only the ingest identity is placed on the verifier's allow-list.
    allowed_dids = {ingest_identity.did}
    verifier = RetrievalVerifier(trusted_signers=allowed_dids)
    report = verifier.verify_retrieved(trusted_chunks + rogue_chunks)

    assert report.verified_count == 1
    assert report.failed_count == 1

    rejected = report.failed[0]
    assert rejected[0].signer_did == attacker_identity.did
    # ``error`` may be None, so fall back to an empty string before searching.
    assert "allow-list" in (rejected[1].error or "")
61
62
def test_verify_or_raise_success(ingest_identity: AgentIdentity) -> None:
    """``verify_or_raise`` returns both chunks when nothing was altered."""
    signed = _signed_batch(ingest_identity, ["alpha", "beta"])
    accepted = RetrievalVerifier().verify_or_raise(signed)
    assert len(accepted) == 2
68
69
def test_verify_or_raise_raises_on_tamper(
    ingest_identity: AgentIdentity,
) -> None:
    """``verify_or_raise`` raises ``TamperedChunkError`` for an altered chunk."""
    signed = _signed_batch(ingest_identity, ["alpha"])
    # Overwrite the only chunk in place with a copy carrying different text.
    signed[0] = replace(signed[0], text="TAMPERED")

    # Build the verifier outside the guard so only verify_or_raise is checked.
    checker = RetrievalVerifier()
    with pytest.raises(TamperedChunkError):
        checker.verify_or_raise(signed)
78
79
def test_verified_texts_returns_only_safe_content(
    ingest_identity: AgentIdentity,
    attacker_identity: AgentIdentity,
) -> None:
    """``verified_texts`` keeps the trusted chunk and drops the attacker's."""
    trusted_chunk = ChunkSigner(ingest_identity).sign_chunk(
        "TRUSTED",
        ChunkMetadata(source="a.txt", chunk_index=0, total_chunks=1),
    )
    poisoned_chunk = ChunkSigner(attacker_identity).sign_chunk(
        "POISON",
        ChunkMetadata(source="a.txt", chunk_index=1, total_chunks=2),
    )

    # Only the ingest identity's DID is trusted by this verifier.
    checker = RetrievalVerifier(trusted_signers={ingest_identity.did})
    report = checker.verify_retrieved([trusted_chunk, poisoned_chunk])
    surviving = report.verified_texts()

    assert "TRUSTED" in surviving
    assert "POISON" not in surviving
99