src/pqc_rag_signing/signer.py
5.5 KB · 169 lines · python Raw
1 """Chunk signing and verification using ML-DSA."""
2
3 from __future__ import annotations
4
5 import os
6 import uuid
7 from dataclasses import dataclass
8 from datetime import datetime, timezone
9 from typing import Iterable
10
11 from quantumshield.core.algorithms import SignatureAlgorithm
12 from quantumshield.core.signatures import sign, verify
13 from quantumshield.identity.agent import AgentIdentity
14
15 from pqc_rag_signing.chunk import ChunkMetadata, SignedChunk
16 from pqc_rag_signing.errors import ChunkVerificationError
17
18
@dataclass(frozen=True)
class VerificationResult:
    """Immutable record of the verdict for one SignedChunk.

    ``valid`` is the overall outcome; ``error`` holds a human-readable reason
    for a failed check and defaults to ``None``.
    """

    valid: bool
    chunk_id: str
    signer_did: str | None  # signer DID as recorded on the chunk
    algorithm: str | None  # algorithm name as recorded on the chunk
    error: str | None = None

    def raise_if_invalid(self) -> None:
        """Turn a failed result into a ChunkVerificationError; no-op when valid."""
        if self.valid:
            return
        message = f"Chunk {self.chunk_id} failed verification: {self.error}"
        raise ChunkVerificationError(message)
35
36
class ChunkSigner:
    """Signs RAG chunks with a fixed AgentIdentity and verifies signed chunks.

    Usage:
        identity = AgentIdentity.create("my-ingest-pipeline")
        signer = ChunkSigner(identity)
        signed = signer.sign_chunk("text content", metadata)
        # store signed.to_dict() in vector DB

        # At retrieval time, pin the signer key(s) you actually trust:
        trusted = {identity.signing_keypair.public_key.hex()}
        ChunkSigner.verify_chunk(signed, trusted_public_keys=trusted).raise_if_invalid()
    """

    def __init__(self, identity: AgentIdentity, corpus_id: str | None = None) -> None:
        """Bind the signer to a fixed identity.

        Args:
            identity: Identity whose ``signing_keypair`` signs every chunk and
                whose ``did`` is recorded as ``signer_did``.
            corpus_id: Optional label copied verbatim onto every SignedChunk.
        """
        self.identity = identity
        self.corpus_id = corpus_id

    # -- signing ------------------------------------------------------------

    def sign_chunk(
        self,
        text: str,
        metadata: ChunkMetadata,
        chunk_id: str | None = None,
    ) -> SignedChunk:
        """Sign a single chunk. Returns the full signed envelope.

        Args:
            text: Chunk text.
            metadata: Metadata hashed together with the text and nonce.
            chunk_id: Identifier for the chunk. When falsy (``None`` or ``""``)
                a random ``chunk-<16 hex chars>`` id is generated.

        Returns:
            A SignedChunk carrying the hex content hash, the hex signature over
            it, the signer's hex public key and a UTC ISO-8601 ``signed_at``.
        """
        chunk_id = chunk_id or f"chunk-{uuid.uuid4().hex[:16]}"
        # Fresh random nonce per chunk, fed into the content hash (presumably
        # so identical text/metadata signed twice yields different hashes).
        nonce = os.urandom(8).hex()
        keypair = self.identity.signing_keypair
        content_hash = SignedChunk.compute_content_hash(text, metadata, nonce)
        # The signature covers the decoded bytes of the hex hash, not the hex text.
        signature = sign(bytes.fromhex(content_hash), keypair)
        return SignedChunk(
            chunk_id=chunk_id,
            text=text,
            metadata=metadata,
            content_hash=content_hash,
            signer_did=self.identity.did,
            algorithm=keypair.algorithm.value,
            signature=signature.hex(),
            public_key=keypair.public_key.hex(),
            signed_at=datetime.now(timezone.utc).isoformat(),
            corpus_id=self.corpus_id,
            nonce=nonce,
        )

    def sign_chunks(
        self,
        texts: Iterable[str],
        source: str,
    ) -> list[SignedChunk]:
        """Sign a batch of chunks from a single source document.

        Metadata (chunk_index, total_chunks, offsets) is auto-computed. Offsets
        are cumulative character counts, so they are exact only when the
        chunks are contiguous, non-overlapping slices of the source in order.
        """
        text_list = list(texts)  # materialize so total_chunks is known up front
        total = len(text_list)
        offset = 0
        signed: list[SignedChunk] = []
        for index, text in enumerate(text_list):
            end = offset + len(text)
            meta = ChunkMetadata(
                source=source,
                chunk_index=index,
                total_chunks=total,
                start_offset=offset,
                end_offset=end,
            )
            signed.append(self.sign_chunk(text, meta))
            offset = end
        return signed

    # -- verification -------------------------------------------------------

    @staticmethod
    def _normalize_trusted_keys(
        keys: Iterable[str] | str | None,
    ) -> frozenset[str] | None:
        """Lower-case a hex public-key allow-list into a set; ``None`` = no pinning."""
        if keys is None:
            return None
        if isinstance(keys, str):
            # A lone hex string would otherwise be iterated character by character.
            keys = (keys,)
        return frozenset(key.lower() for key in keys)

    @staticmethod
    def verify_chunk(
        chunk: SignedChunk,
        *,
        trusted_public_keys: Iterable[str] | None = None,
    ) -> VerificationResult:
        """Verify a SignedChunk's content hash and ML-DSA signature.

        Checks run in this order:

        1. The recomputed content hash matches ``chunk.content_hash``.
        2. If ``trusted_public_keys`` is given, the embedded public key is one
           of the trusted keys.
        3. ``chunk.algorithm`` names a known SignatureAlgorithm.
        4. The signature verifies.

        Security: without ``trusted_public_keys`` the signature is checked
        against the public key stored *in the chunk itself*. A valid result
        then only proves internal consistency. Anyone who rewrites the text and
        re-signs it with their own key also passes, and ``signer_did`` is
        self-asserted. Pin the expected key(s) to get authenticity.

        NOTE(review): only text, metadata and nonce are passed to
        ``compute_content_hash``, so chunk_id, corpus_id, signer_did and
        signed_at are not covered by the signature. Confirm this is intended.

        Args:
            chunk: The signed envelope to check.
            trusted_public_keys: Optional allow-list of hex-encoded public keys,
                compared case-insensitively. ``None`` (the default) keeps the
                legacy behaviour of trusting the embedded key.

        Returns:
            A VerificationResult. Failed checks are reported through
            ``valid=False`` and ``error``. Exceptions raised while recomputing
            the content hash are not caught.
        """

        def _failure(error: str) -> VerificationResult:
            return VerificationResult(
                valid=False,
                chunk_id=chunk.chunk_id,
                signer_did=chunk.signer_did,
                algorithm=chunk.algorithm,
                error=error,
            )

        expected_hash = SignedChunk.compute_content_hash(
            chunk.text, chunk.metadata, chunk.nonce
        )
        if expected_hash != chunk.content_hash:
            return _failure(
                f"content hash mismatch (expected {expected_hash[:16]}, "
                f"got {chunk.content_hash[:16]})"
            )

        trusted = ChunkSigner._normalize_trusted_keys(trusted_public_keys)
        if trusted is not None:
            embedded_key = chunk.public_key
            if not isinstance(embedded_key, str) or embedded_key.lower() not in trusted:
                return _failure("public key is not in the trusted key set")

        try:
            algorithm = SignatureAlgorithm(chunk.algorithm)
        except ValueError:
            return _failure(f"unknown algorithm {chunk.algorithm}")

        try:
            sig_valid = verify(
                bytes.fromhex(chunk.content_hash),
                bytes.fromhex(chunk.signature),
                bytes.fromhex(chunk.public_key),
                algorithm,
            )
        except Exception as exc:
            # Malformed hex (ValueError from bytes.fromhex) or any error from the
            # verification backend means the chunk is not verified.
            return _failure(f"signature verify failed: {exc}")

        if not sig_valid:
            return _failure("invalid ML-DSA signature")

        return VerificationResult(
            valid=True,
            chunk_id=chunk.chunk_id,
            signer_did=chunk.signer_did,
            algorithm=chunk.algorithm,
        )

    @staticmethod
    def verify_chunks(
        chunks: Iterable[SignedChunk],
        *,
        trusted_public_keys: Iterable[str] | None = None,
    ) -> list[VerificationResult]:
        """Verify each chunk with ``verify_chunk``, preserving input order.

        ``trusted_public_keys`` has the same meaning as in ``verify_chunk``. It
        is materialized once, so a one-shot iterator applies to every chunk.
        """
        trusted = ChunkSigner._normalize_trusted_keys(trusted_public_keys)
        return [
            ChunkSigner.verify_chunk(c, trusted_public_keys=trusted) for c in chunks
        ]
169