src/pqc_mcp_transport/server.py
9.2 KB · 274 lines · python Raw
1 """PQC-secured MCP server that verifies incoming signatures and signs responses."""
2
3 from __future__ import annotations
4
5 import asyncio
6 import json
7 import uuid
8 from dataclasses import dataclass
9 from typing import Any, Callable, Awaitable
10
11 from quantumshield.identity.agent import AgentIdentity
12
13 from pqc_mcp_transport.audit import AuditLog
14 from pqc_mcp_transport.errors import (
15 PeerNotAuthenticatedError,
16 PQCTransportError,
17 SessionExpiredError,
18 SignatureVerificationError,
19 )
20 from pqc_mcp_transport.handshake import HandshakeRequest, HandshakeResponse, PQCHandshake
21 from pqc_mcp_transport.session import PQCSession
22 from pqc_mcp_transport.signer import MessageSigner
23
24
@dataclass
class ToolHandler:
    """A registered MCP tool handler.

    Plain record that pairs a tool's name and description with the async
    callable that implements it.
    """

    # Identifier of the tool.
    name: str
    # Free-form, human-readable description of the tool.
    description: str
    # Callable that returns an awaitable producing the tool's result.
    handler: Callable[..., Awaitable[Any]]
32
33
class PQCMCPServer:
    """MCP server that verifies PQC signatures on incoming calls and signs all responses.

    Tools are registered with the :meth:`tool` decorator and invoked through
    :meth:`handle_request`. :meth:`run` starts a minimal built-in HTTP server
    that feeds ``POST`` bodies into :meth:`handle_handshake` or
    :meth:`handle_request`.
    """

    def __init__(
        self,
        identity: AgentIdentity,
        require_auth: bool = True,
    ) -> None:
        """Create a server bound to *identity*.

        Args:
            identity: Identity used to build the :class:`MessageSigner` that
                signs responses; also passed to the handshake and sessions.
            require_auth: When true, :meth:`handle_request` requires a
                ``_pqc`` envelope on every non-handshake message and
                verifies it before dispatching.
        """
        self.identity = identity
        self.signer = MessageSigner(identity)
        self._tools: dict[str, ToolHandler] = {}
        self._sessions: dict[str, PQCSession] = {}
        self._require_auth = require_auth
        self.audit = AuditLog()

    def tool(self, name: str, description: str = "") -> Callable:
        """Decorator to register an async tool handler.

        Registering a second handler under an existing *name* replaces the
        first. The decorated function is returned unchanged.
        """

        def decorator(func: Callable[..., Awaitable[Any]]) -> Callable[..., Awaitable[Any]]:
            self._tools[name] = ToolHandler(
                name=name, description=description, handler=func
            )
            return func

        return decorator

    def get_tool_list(self) -> list[dict]:
        """Return a list of registered tools and their descriptions."""
        return [
            {"name": t.name, "description": t.description} for t in self._tools.values()
        ]

    async def handle_handshake(self, request_data: dict) -> dict:
        """Handle a PQC handshake initiation and return the response dict.

        Builds the handshake response, stores a server-side session under the
        response's session id, records the handshake in that session's log,
        and returns the response as a dict.
        """
        request = HandshakeRequest.from_dict(request_data)
        response = PQCHandshake.respond(request, self.identity)

        # Imported here rather than at module level, as in the original code.
        from quantumshield.core.algorithms import SignatureAlgorithm

        # Create a server-side session
        session = PQCSession(
            session_id=response.session_id,
            local_identity=self.identity,
            peer_did=request.client_did,
            peer_public_key=bytes.fromhex(request.client_public_key),
            peer_algorithm=SignatureAlgorithm(request.algorithm),
        )
        self._sessions[response.session_id] = session

        # NOTE(review): the handshake is logged as verified unconditionally;
        # this presumably relies on PQCHandshake.respond rejecting requests
        # with a bad signature -- confirm.
        session.log_operation(
            op_type="handshake",
            method=None,
            signer_did=request.client_did,
            verified=True,
            signature_hex=request.signature,
            algorithm=request.algorithm,
            details="Handshake accepted",
        )

        return response.to_dict()

    async def handle_request(self, raw_message: dict) -> dict:
        """Process an incoming MCP request with PQC verification.

        Returns a signed JSON-RPC response dict. Two rejections are returned
        *unsigned* because they happen before the message is authenticated:
        a payload that is not a JSON object, and a message with no ``_pqc``
        envelope while authentication is required.

        Raises:
            SignatureVerificationError: The ``_pqc`` signature is invalid.
            SessionExpiredError: The referenced session is no longer valid.
        """
        # A JSON array / scalar cannot be a request object; reject it instead
        # of failing with AttributeError on ``.get``.
        if not isinstance(raw_message, dict):
            return self._error_response(
                None, -32600, "Invalid Request: expected a JSON object"
            )

        # Check if it is a handshake request
        if raw_message.get("type") == "pqc_handshake_request":
            return await self.handle_handshake(raw_message)

        # Verify PQC signature
        if self._require_auth:
            rejection = self._authenticate(raw_message)
            if rejection is not None:
                return rejection

        # Strip PQC envelope for processing
        clean = MessageSigner.strip_pqc(raw_message)
        method = clean.get("method", "")
        msg_id = clean.get("id")

        if method == "tools/list":
            response = {
                "jsonrpc": "2.0",
                "id": msg_id,
                "result": {"tools": self.get_tool_list()},
            }
        elif method == "tools/call":
            response = await self._call_tool(msg_id, clean.get("params"))
        else:
            response = self._error_response(
                msg_id, -32601, f"Unknown method: {method}"
            )

        return self.signer.sign_message(response)

    def _authenticate(self, raw_message: dict) -> dict | None:
        """Verify the ``_pqc`` envelope of *raw_message*.

        Returns ``None`` when the message is accepted, or an unsigned
        JSON-RPC error dict when the envelope is missing.

        Raises:
            SignatureVerificationError: The signature does not verify.
            SessionExpiredError: The referenced session exists but is no
                longer valid.
        """
        pqc = raw_message.get("_pqc")
        if not pqc:
            return self._error_response(
                raw_message.get("id"),
                -32600,
                "Missing _pqc signature envelope",
            )

        vr = MessageSigner.verify_message(raw_message)
        if not vr.valid:
            raise SignatureVerificationError(
                f"Request signature verification failed: {vr.error}"
            )

        # Check session
        # NOTE(review): a validly signed message whose session_id is absent or
        # unknown is processed without any session checks. Confirm this is
        # intended (PeerNotAuthenticatedError is imported by this module but
        # never raised).
        session_id = pqc.get("session_id")
        session = self._sessions.get(session_id) if session_id else None
        if session is None:
            return None

        if not session.is_valid():
            raise SessionExpiredError("Session has expired")

        # Replay protection; check_nonce is expected to raise on a repeated
        # nonce -- TODO confirm against PQCSession.
        if vr.nonce:
            session.check_nonce(vr.nonce)

        session.log_operation(
            op_type="tool_call",
            method=raw_message.get("method", ""),
            signer_did=vr.signer_did or "unknown",
            verified=True,
            signature_hex=pqc.get("signature", ""),
            algorithm=vr.algorithm or "",
        )
        return None

    async def _call_tool(self, msg_id: Any, params: Any) -> dict:
        """Run a ``tools/call`` request and return the unsigned JSON-RPC response.

        ``params`` and ``params["arguments"]`` may be omitted or ``null``
        (treated as empty). Any other non-object value yields a -32602
        "Invalid params" error rather than an unhandled exception.
        """
        if params is None:
            params = {}
        if not isinstance(params, dict):
            return self._error_response(
                msg_id, -32602, "Invalid params: expected an object"
            )

        tool_name = params.get("name", "")
        # Non-string names (e.g. a list) are unhashable or meaningless as
        # registry keys; treat them as unknown tools.
        registered = self._tools.get(tool_name) if isinstance(tool_name, str) else None
        if registered is None:
            return self._error_response(msg_id, -32601, f"Unknown tool: {tool_name}")

        arguments = params.get("arguments")
        if arguments is None:
            arguments = {}
        if not isinstance(arguments, dict):
            return self._error_response(
                msg_id, -32602, "Invalid params: 'arguments' must be an object"
            )

        try:
            result = await registered.handler(**arguments)
        except Exception as exc:
            # Tool failures are reported to the caller, not raised.
            return self._error_response(msg_id, -32000, f"Tool error: {exc}")

        return {
            "jsonrpc": "2.0",
            "id": msg_id,
            "result": {"content": result},
        }

    @staticmethod
    def _error_response(msg_id: Any, code: int, message: str) -> dict:
        """Build an unsigned JSON-RPC 2.0 error response."""
        return {
            "jsonrpc": "2.0",
            "id": msg_id,
            "error": {"code": code, "message": message},
        }

    async def run(self, host: str = "0.0.0.0", port: int = 8080) -> None:
        """Run a simple async HTTP server.

        This is a minimal server suitable for development and examples.
        For production, use :class:`PQCMiddleware` with an ASGI framework.

        Note that the default *host* listens on all interfaces.
        """
        server = await asyncio.start_server(self._handle_connection, host, port)
        async with server:
            await server.serve_forever()

    async def _handle_connection(
        self,
        reader: asyncio.StreamReader,
        writer: asyncio.StreamWriter,
    ) -> None:
        """Handle a single HTTP connection (minimal HTTP/1.1 parser).

        Exactly one request is served, then the connection is closed.
        Signature/session failures become ``403``, any other unexpected
        failure becomes ``500``. A peer that disconnects mid-exchange is
        dropped silently.
        """
        try:
            try:
                outcome = await self._serve_http_request(reader)
            except (
                PQCTransportError,
                SignatureVerificationError,
                SessionExpiredError,
            ) as exc:
                # Authentication problems are the client's fault, not a
                # server crash.
                outcome = (
                    "403 Forbidden",
                    json.dumps({"error": str(exc)}).encode("utf-8"),
                )
            except Exception:
                outcome = (
                    "500 Internal Server Error",
                    b'{"error": "Internal server error"}',
                )

            if outcome is not None:
                writer.write(self._http_response(*outcome))
                await writer.drain()
        except OSError:
            # The peer went away while we were responding; nothing to send.
            pass
        finally:
            writer.close()
            try:
                await writer.wait_closed()
            except OSError:
                pass

    async def _serve_http_request(
        self, reader: asyncio.StreamReader
    ) -> tuple[str, bytes] | None:
        """Read one HTTP request from *reader* and compute the reply.

        Returns ``(status_line, body)``, or ``None`` when the peer closed the
        connection before sending a request line. Malformed requests (bad
        ``Content-Length``, truncated body, invalid JSON) yield a ``400``
        reply. Exceptions from the handshake/request handlers propagate.
        """
        # Read request line and headers
        request_line = await reader.readline()
        if not request_line:
            return None

        headers: dict[str, str] = {}
        while True:
            line = await reader.readline()
            if line in (b"\r\n", b"\n", b""):
                break
            if b":" in line:
                # errors="replace": a stray non-UTF-8 byte in a header must
                # not abort the whole request.
                key, value = line.decode("utf-8", errors="replace").split(":", 1)
                headers[key.strip().lower()] = value.strip()

        # Read body
        try:
            content_length = int(headers.get("content-length", "0"))
        except ValueError:
            return "400 Bad Request", b'{"error": "Invalid Content-Length"}'

        body = b""
        if content_length > 0:
            try:
                # read(n) may return fewer than n bytes; readexactly waits
                # for the whole body.
                body = await reader.readexactly(content_length)
            except asyncio.IncompleteReadError:
                return "400 Bad Request", b'{"error": "Incomplete request body"}'

        # Parse path
        parts = request_line.decode("utf-8", errors="replace").split()
        method = parts[0] if parts else "GET"
        path = parts[1] if len(parts) > 1 else "/"

        if not (method == "POST" and body):
            return "200 OK", b'{"status": "PQC MCP Server running"}'

        try:
            request_data = json.loads(body)
        except ValueError:
            # Covers JSONDecodeError and undecodable bytes.
            return "400 Bad Request", b'{"error": "Invalid JSON"}'

        if path == "/handshake":
            response_data = await self.handle_handshake(request_data)
        else:
            response_data = await self.handle_request(request_data)

        return "200 OK", json.dumps(response_data).encode("utf-8")

    @staticmethod
    def _http_response(status: str, body: bytes) -> bytes:
        """Serialize a JSON HTTP/1.1 response with the given status line and body."""
        head = (
            f"HTTP/1.1 {status}\r\n"
            "Content-Type: application/json\r\n"
            f"Content-Length: {len(body)}\r\n"
            # One request per connection: tell the client we will close.
            "Connection: close\r\n"
            "\r\n"
        )
        return head.encode("utf-8") + body
274