tests/test_verifier.py
| 1 | """Tests for RetrievalVerifier.""" |
| 2 | |
| 3 | from __future__ import annotations |
| 4 | |
| 5 | from dataclasses import replace |
| 6 | |
| 7 | import pytest |
| 8 | from quantumshield.identity.agent import AgentIdentity |
| 9 | |
| 10 | from pqc_rag_signing import ( |
| 11 | ChunkMetadata, |
| 12 | ChunkSigner, |
| 13 | RetrievalVerifier, |
| 14 | TamperedChunkError, |
| 15 | ) |
| 16 | |
| 17 | |
def _signed_batch(identity: AgentIdentity, texts: list[str]) -> list:
    """Sign ``texts`` with ``identity`` as a batch from ``batch.txt``.

    Args:
        identity: Agent identity used to construct the ``ChunkSigner``.
        texts: Chunk contents, passed straight to ``sign_chunks``.

    Returns:
        Whatever ``ChunkSigner.sign_chunks`` returns for these texts.
    """
    return ChunkSigner(identity).sign_chunks(texts, source="batch.txt")
| 21 | |
| 22 | |
def test_verify_retrieved_all_valid(ingest_identity: AgentIdentity) -> None:
    """Every chunk in an untouched, single-signer batch passes verification."""
    expected_texts = ["alpha", "beta", "gamma"]
    signed = _signed_batch(ingest_identity, list(expected_texts))
    outcome = RetrievalVerifier().verify_retrieved(signed)

    assert outcome.total == 3
    assert outcome.verified_count == 3
    assert outcome.failed_count == 0
    assert outcome.all_verified
    # Verified texts are expected back in the same order they were signed.
    assert outcome.verified_texts() == expected_texts
| 32 | |
| 33 | |
def test_verify_retrieved_detects_tampered_chunk(
    ingest_identity: AgentIdentity,
) -> None:
    """A chunk whose text is edited after signing is rejected.

    Only the untouched chunk's text may be exposed through
    ``verified_texts()``; the tampered text must never reach the caller.
    """
    chunks = _signed_batch(ingest_identity, ["alpha", "beta"])
    # Tamper with the text of the first chunk; every other field is kept,
    # so the content hash no longer matches what was signed.
    tampered = replace(chunks[0], text="EVIL")
    bad_batch = [tampered, chunks[1]]
    verifier = RetrievalVerifier()
    result = verifier.verify_retrieved(bad_batch)
    assert not result.all_verified
    assert result.failed_count == 1
    assert result.verified_count == 1
    assert result.failed[0][0].chunk_id == tampered.chunk_id
    # The security property that matters: tampered content is filtered out.
    assert result.verified_texts() == ["beta"]
| 47 | |
| 48 | |
def test_verify_retrieved_detects_untrusted_signer(
    ingest_identity: AgentIdentity,
    attacker_identity: AgentIdentity,
) -> None:
    """Chunks signed by a DID outside ``trusted_signers`` are rejected."""
    trusted_chunks = _signed_batch(ingest_identity, ["safe content"])
    attacker_chunks = _signed_batch(attacker_identity, ["poisoned content"])
    verifier = RetrievalVerifier(trusted_signers={ingest_identity.did})

    outcome = verifier.verify_retrieved(trusted_chunks + attacker_chunks)

    assert outcome.verified_count == 1
    assert outcome.failed_count == 1
    first_failure = outcome.failed[0]
    assert first_failure[0].signer_did == attacker_identity.did
    # The rejection reason should mention the allow-list check.
    assert "allow-list" in (first_failure[1].error or "")
| 61 | |
| 62 | |
def test_verify_or_raise_success(ingest_identity: AgentIdentity) -> None:
    """``verify_or_raise`` does not raise for two intact chunks.

    It returns a two-element result.
    """
    signed = _signed_batch(ingest_identity, ["alpha", "beta"])
    returned = RetrievalVerifier().verify_or_raise(signed)
    assert len(returned) == 2
| 68 | |
| 69 | |
def test_verify_or_raise_raises_on_tamper(
    ingest_identity: AgentIdentity,
) -> None:
    """``verify_or_raise`` raises ``TamperedChunkError`` for an edited chunk."""
    signed = _signed_batch(ingest_identity, ["alpha"])
    # Rewrite only the text of the first chunk; all other fields are kept.
    forged = [replace(signed[0], text="TAMPERED"), *signed[1:]]
    verifier = RetrievalVerifier()
    with pytest.raises(TamperedChunkError):
        verifier.verify_or_raise(forged)
| 78 | |
| 79 | |
def test_verified_texts_returns_only_safe_content(
    ingest_identity: AgentIdentity,
    attacker_identity: AgentIdentity,
) -> None:
    """``verified_texts()`` exposes exactly the trusted signer's content.

    The attacker's chunk claims the same source (``a.txt``) as the trusted
    chunk but is signed by a DID outside the allow-list. It must be
    rejected, and its text must not appear among the verified texts.
    """
    signer_good = ChunkSigner(ingest_identity)
    good = signer_good.sign_chunk(
        "TRUSTED",
        ChunkMetadata(source="a.txt", chunk_index=0, total_chunks=1),
    )
    signer_evil = ChunkSigner(attacker_identity)
    evil = signer_evil.sign_chunk(
        "POISON",
        ChunkMetadata(source="a.txt", chunk_index=1, total_chunks=2),
    )
    verifier = RetrievalVerifier(trusted_signers={ingest_identity.did})
    result = verifier.verify_retrieved([good, evil])
    texts = result.verified_texts()
    # Pin the exact result: a membership check alone would not catch
    # duplicated or extra leaked content.
    assert texts == ["TRUSTED"]
    assert "POISON" not in texts
    assert result.failed_count == 1
| 99 | |