src/pqc_rag_signing/corpus.py
5.8 KB · 166 lines · python Raw
1 """Corpus manifest - proves an entire set of chunks is intact."""
2
3 from __future__ import annotations
4
5 import hashlib
6 import json
7 import uuid
8 from dataclasses import asdict, dataclass
9 from datetime import datetime, timezone
10 from typing import Iterable
11
12 from quantumshield.core.algorithms import SignatureAlgorithm
13 from quantumshield.core.signatures import sign, verify
14 from quantumshield.identity.agent import AgentIdentity
15
16 from pqc_rag_signing.chunk import SignedChunk
17 from pqc_rag_signing.errors import CorpusIntegrityError
18 from pqc_rag_signing.signer import ChunkSigner
19
20
@dataclass
class CorpusManifest:
    """Signed commitment to the complete set of chunks in a corpus.

    The manifest lists every ``(chunk_id, content_hash)`` pair and hashes the
    sorted list into a single ``root``. Adding, removing or modifying any
    pair changes the root, so the root is a compact fingerprint of the
    corpus.

    The root is a flat hash over the whole list rather than a Merkle tree:
    there are no per-chunk inclusion proofs, so membership is checked by
    looking a pair up in ``chunk_hashes`` directly.
    """

    # Identifier and human-readable name of the corpus.
    corpus_id: str
    name: str
    # Creation timestamp as a string (Corpus.build_manifest writes an
    # ISO-8601 UTC timestamp here).
    created_at: str
    # Number of chunks the manifest claims to cover.
    chunk_count: int
    # One (chunk_id, content_hash) pair per chunk.
    chunk_hashes: list[tuple[str, str]]
    # Hex SHA3-256 digest produced by compute_root() over chunk_hashes.
    root: str
    # Signer identity, algorithm name, and hex-encoded signature / public key.
    signer_did: str
    algorithm: str
    signature: str
    public_key: str

    @staticmethod
    def compute_root(chunk_hashes: list[tuple[str, str]]) -> str:
        """Return the hex SHA3-256 root over ``chunk_hashes``.

        Pairs are sorted, rendered as ``"<chunk_id>:<content_hash>"``, joined
        with ``"|"``, encoded as UTF-8 and hashed.

        The sort key is the whole pair, not just the chunk id. Sorting on the
        id alone leaves pairs that share an id in input order (the sort is
        stable), which would make the root depend on the order the caller
        happened to supply. For inputs whose chunk ids are unique the result
        is identical to sorting by id only.

        NOTE(review): the ``:`` / ``|`` framing is only unambiguous if ids
        and hashes never contain those characters - presumably true for hex
        hashes; confirm for chunk ids.
        """
        # Build an explicit tuple key so list-shaped pairs (e.g. straight
        # from JSON) and tuple-shaped pairs can be sorted together.
        ordered = sorted(chunk_hashes, key=lambda pair: (pair[0], pair[1]))
        payload = "|".join(
            f"{chunk_id}:{content_hash}" for chunk_id, content_hash in ordered
        )
        return hashlib.sha3_256(payload.encode("utf-8")).hexdigest()

    def to_dict(self) -> dict:
        """Return the manifest as a plain dict keyed by field name."""
        return asdict(self)

    @classmethod
    def from_dict(cls, data: dict) -> CorpusManifest:
        """Build a manifest from a mapping shaped like ``to_dict()`` output.

        Each entry of ``data["chunk_hashes"]`` is converted to a tuple, so
        the list-of-lists form produced by a JSON round trip is accepted.

        Raises:
            KeyError: if any manifest field is missing from ``data``.
        """
        return cls(
            corpus_id=data["corpus_id"],
            name=data["name"],
            created_at=data["created_at"],
            chunk_count=data["chunk_count"],
            chunk_hashes=[tuple(pair) for pair in data["chunk_hashes"]],
            root=data["root"],
            signer_did=data["signer_did"],
            algorithm=data["algorithm"],
            signature=data["signature"],
            public_key=data["public_key"],
        )

    def to_json(self) -> str:
        """Serialize the manifest to a JSON string indented by two spaces."""
        return json.dumps(self.to_dict(), indent=2)

    @classmethod
    def from_json(cls, text: str) -> CorpusManifest:
        """Parse a manifest from JSON text; the inverse of ``to_json()``.

        Raises:
            json.JSONDecodeError: if ``text`` is not valid JSON.
            KeyError: if the decoded object lacks a manifest field.
        """
        return cls.from_dict(json.loads(text))
71
72
class Corpus:
    """Sign a whole corpus of documents and produce a manifest for it.

    Usage:
        identity = AgentIdentity.create("rag-ingestion")
        corpus = Corpus(name="company-docs", identity=identity)
        corpus.add_document("handbook.pdf", chunks=["...", "...", "..."])
        corpus.add_document("policies.pdf", chunks=["...", "..."])
        signed_chunks = corpus.sign_all()
        manifest = corpus.build_manifest()
        # persist signed_chunks to vector DB and manifest.to_json() to disk/S3
    """

    def __init__(
        self,
        name: str,
        identity: AgentIdentity,
        corpus_id: str | None = None,
    ) -> None:
        """Create an empty corpus.

        Args:
            name: Human-readable corpus name, copied into the manifest.
            identity: Identity whose signing keypair signs chunks and the
                manifest root.
            corpus_id: Explicit corpus id. When omitted (or otherwise falsy)
                an id of the form ``corpus-<12 hex chars>`` is generated.
        """
        self.corpus_id = corpus_id or f"corpus-{uuid.uuid4().hex[:12]}"
        self.name = name
        self.identity = identity
        # Queued (source, chunks) pairs awaiting sign_all().
        self._documents: list[tuple[str, list[str]]] = []
        # Result of the most recent sign_all() call.
        self._signed: list[SignedChunk] = []

    def add_document(self, source: str, chunks: list[str]) -> None:
        """Queue a document's chunks for signing under the given source."""
        self._documents.append((source, chunks))

    def sign_all(self) -> list[SignedChunk]:
        """Sign every queued chunk and return all signed envelopes.

        Every call signs all queued documents again and replaces the stored
        result that ``build_manifest()`` uses by default.
        """
        signer = ChunkSigner(self.identity, corpus_id=self.corpus_id)
        signed: list[SignedChunk] = []
        for source, chunks in self._documents:
            signed.extend(signer.sign_chunks(chunks, source=source))
        self._signed = signed
        return signed

    def build_manifest(self, chunks: list[SignedChunk] | None = None) -> CorpusManifest:
        """Build a signed manifest committing to all chunks in the corpus.

        Args:
            chunks: Chunks to commit to. Defaults to the result of the most
                recent ``sign_all()`` call.

        Raises:
            CorpusIntegrityError: if there are no chunks to commit to.

        Only the root digest is signed. The other manifest fields (ids,
        name, timestamp, chunk count, signer DID) are carried alongside it
        and are not covered by the signature.
        """
        if chunks is None:
            chunks = self._signed
        if not chunks:
            raise CorpusIntegrityError("no chunks to build manifest from")

        keypair = self.identity.signing_keypair
        chunk_hashes = [(c.chunk_id, c.content_hash) for c in chunks]
        root = CorpusManifest.compute_root(chunk_hashes)
        # The signature is over the raw digest bytes, not the hex text.
        signature = sign(bytes.fromhex(root), keypair)
        return CorpusManifest(
            corpus_id=self.corpus_id,
            name=self.name,
            created_at=datetime.now(timezone.utc).isoformat(),
            chunk_count=len(chunks),
            chunk_hashes=sorted(chunk_hashes, key=lambda pair: pair[0]),
            root=root,
            signer_did=self.identity.did,
            algorithm=keypair.algorithm.value,
            signature=signature.hex(),
            public_key=keypair.public_key.hex(),
        )

    @staticmethod
    def verify_manifest(manifest: CorpusManifest) -> bool:
        """Recompute the manifest root and verify its signature.

        Returns ``False`` (rather than raising) when the stored root does not
        match the chunk list, when the algorithm name is unknown, or when the
        signature / public key are not valid hex strings - a manifest loaded
        from untrusted storage may be malformed in any of these ways.

        The public key used for verification is the one embedded in the
        manifest, so ``True`` only shows the manifest is self-consistent.
        Whether that key (or ``signer_did``) is trusted must be decided by
        the caller.
        """
        if CorpusManifest.compute_root(manifest.chunk_hashes) != manifest.root:
            return False
        try:
            algorithm = SignatureAlgorithm(manifest.algorithm)
            message = bytes.fromhex(manifest.root)
            signature = bytes.fromhex(manifest.signature)
            public_key = bytes.fromhex(manifest.public_key)
        except (ValueError, TypeError):
            # Unknown algorithm name, malformed hex, or a non-string field.
            return False
        return verify(message, signature, public_key, algorithm)

    @staticmethod
    def verify_chunks_against_manifest(
        chunks: Iterable[SignedChunk],
        manifest: CorpusManifest,
    ) -> tuple[bool, list[str]]:
        """Check that every chunk's (chunk_id, content_hash) is in the manifest.

        This is a one-way check: manifest entries that have no matching chunk
        in ``chunks`` are not reported.

        Returns (all_present, missing_chunk_ids), with the ids in the order
        the chunks were supplied.
        """
        # Normalise to tuples so list-shaped pairs are hashable and comparable.
        known_pairs = {tuple(pair) for pair in manifest.chunk_hashes}
        missing = [
            c.chunk_id
            for c in chunks
            if (c.chunk_id, c.content_hash) not in known_pairs
        ]
        return not missing, missing
166