src/pqc_mcp_transport/server.py
| 1 | """PQC-secured MCP server that verifies incoming signatures and signs responses.""" |
| 2 | |
| 3 | from __future__ import annotations |
| 4 | |
| 5 | import asyncio |
| 6 | import json |
| 7 | import uuid |
| 8 | from dataclasses import dataclass |
| 9 | from typing import Any, Callable, Awaitable |
| 10 | |
| 11 | from quantumshield.identity.agent import AgentIdentity |
| 12 | |
| 13 | from pqc_mcp_transport.audit import AuditLog |
| 14 | from pqc_mcp_transport.errors import ( |
| 15 | PeerNotAuthenticatedError, |
| 16 | PQCTransportError, |
| 17 | SessionExpiredError, |
| 18 | SignatureVerificationError, |
| 19 | ) |
| 20 | from pqc_mcp_transport.handshake import HandshakeRequest, HandshakeResponse, PQCHandshake |
| 21 | from pqc_mcp_transport.session import PQCSession |
| 22 | from pqc_mcp_transport.signer import MessageSigner |
| 23 | |
| 24 | |
@dataclass
class ToolHandler:
    """Registry record describing one MCP tool exposed by the server.

    Instances are created by ``PQCMCPServer.tool`` and stored by tool name.
    """

    # Name under which the tool is registered and looked up.
    name: str
    # Free-form description returned alongside the name in tool listings.
    description: str
    # Async callable awaited with the call's arguments as keyword arguments.
    handler: Callable[..., Awaitable[Any]]
| 32 | |
| 33 | |
class PQCMCPServer:
    """MCP server that verifies PQC signatures on incoming calls and signs all responses.

    Tools are registered with :meth:`tool`, JSON-RPC messages are dispatched by
    :meth:`handle_request`, and :meth:`run` exposes both over a minimal HTTP
    listener intended for development use.
    """

    def __init__(
        self,
        identity: AgentIdentity,
        require_auth: bool = True,
    ) -> None:
        """Create a server bound to *identity*.

        Args:
            identity: Identity used to sign responses and answer handshakes.
            require_auth: When true, every non-handshake request must carry a
                ``_pqc`` envelope that passes signature verification.
        """
        self.identity = identity
        self.signer = MessageSigner(identity)
        self._tools: dict[str, ToolHandler] = {}
        self._sessions: dict[str, PQCSession] = {}
        self._require_auth = require_auth
        self.audit = AuditLog()

    def tool(self, name: str, description: str = "") -> Callable:
        """Decorator to register an async tool handler.

        The decorated function is returned unchanged; registering a second
        handler under the same *name* replaces the first.
        """

        def decorator(func: Callable[..., Awaitable[Any]]) -> Callable[..., Awaitable[Any]]:
            self._tools[name] = ToolHandler(
                name=name, description=description, handler=func
            )
            return func

        return decorator

    def get_tool_list(self) -> list[dict]:
        """Return a list of registered tools and their descriptions."""
        return [
            {"name": t.name, "description": t.description} for t in self._tools.values()
        ]

    async def handle_handshake(self, request_data: dict) -> dict:
        """Handle a PQC handshake initiation and return the response dict.

        Builds the handshake response, stores a server-side session keyed by
        the response's session id, and records the handshake on that session.
        """
        request = HandshakeRequest.from_dict(request_data)
        response = PQCHandshake.respond(request, self.identity)

        # Imported at call time, as in the original code (presumably to avoid
        # an import cycle -- TODO confirm).
        from quantumshield.core.algorithms import SignatureAlgorithm

        # Create a server-side session
        session = PQCSession(
            session_id=response.session_id,
            local_identity=self.identity,
            peer_did=request.client_did,
            peer_public_key=bytes.fromhex(request.client_public_key),
            peer_algorithm=SignatureAlgorithm(request.algorithm),
        )
        self._sessions[response.session_id] = session

        # NOTE(review): logged as verified; the request signature is assumed to
        # have been checked inside PQCHandshake.respond -- confirm.
        session.log_operation(
            op_type="handshake",
            method=None,
            signer_did=request.client_did,
            verified=True,
            signature_hex=request.signature,
            algorithm=request.algorithm,
            details="Handshake accepted",
        )

        return response.to_dict()

    async def handle_request(self, raw_message: dict) -> dict:
        """Process an incoming MCP request with PQC verification.

        Handshake messages (``type == "pqc_handshake_request"``) are forwarded
        to :meth:`handle_handshake`. Everything else is verified (when
        authentication is required), stripped of its ``_pqc`` envelope and
        dispatched as ``tools/list`` or ``tools/call``.

        Returns:
            A signed JSON-RPC response dict (or the handshake response dict).

        Raises:
            SignatureVerificationError: The request signature did not verify.
            SessionExpiredError: The referenced session is no longer valid.
        """
        # Check if it is a handshake request
        if raw_message.get("type") == "pqc_handshake_request":
            return await self.handle_handshake(raw_message)

        # Verify PQC signature
        if self._require_auth:
            pqc = raw_message.get("_pqc")
            if not pqc:
                # Signed like every other error so that clients which verify
                # response signatures can trust this rejection too.
                return self.signer.sign_message(
                    self._error_response(
                        raw_message.get("id"),
                        -32600,
                        "Missing _pqc signature envelope",
                    )
                )

            vr = MessageSigner.verify_message(raw_message)
            if not vr.valid:
                raise SignatureVerificationError(
                    f"Request signature verification failed: {vr.error}"
                )

            # Check session
            # NOTE(review): a request with a missing or unknown session_id is
            # accepted on signature verification alone; PeerNotAuthenticatedError
            # is imported by this module but never raised -- confirm intent.
            session_id = pqc.get("session_id")
            session = self._sessions.get(session_id) if session_id else None
            if session and not session.is_valid():
                raise SessionExpiredError("Session has expired")

            if session:
                # Replay protection
                if vr.nonce:
                    session.check_nonce(vr.nonce)

                session.log_operation(
                    op_type="tool_call",
                    method=raw_message.get("method", ""),
                    signer_did=vr.signer_did or "unknown",
                    verified=True,
                    signature_hex=pqc.get("signature", ""),
                    algorithm=vr.algorithm or "",
                )

        # Strip PQC envelope for processing
        clean = MessageSigner.strip_pqc(raw_message)
        method = clean.get("method", "")
        msg_id = clean.get("id")
        params = clean.get("params", {})

        # Handle tools/list
        if method == "tools/list":
            return self.signer.sign_message(
                {
                    "jsonrpc": "2.0",
                    "id": msg_id,
                    "result": {"tools": self.get_tool_list()},
                }
            )

        # Handle tools/call
        if method == "tools/call":
            # Reject malformed params up front instead of failing with an
            # AttributeError/TypeError further down.
            if not isinstance(params, dict):
                return self.signer.sign_message(
                    self._error_response(
                        msg_id, -32602, "Invalid params: expected an object"
                    )
                )

            tool_name = params.get("name", "")
            arguments = params.get("arguments", {})
            if not isinstance(arguments, dict):
                return self.signer.sign_message(
                    self._error_response(
                        msg_id, -32602, "Invalid params: 'arguments' must be an object"
                    )
                )

            # A non-string name cannot match a registered tool (and may not
            # even be hashable), so treat it as unknown.
            handler = self._tools.get(tool_name) if isinstance(tool_name, str) else None
            if handler is None:
                return self.signer.sign_message(
                    self._error_response(msg_id, -32601, f"Unknown tool: {tool_name}")
                )

            try:
                result = await handler.handler(**arguments)
            except Exception as exc:
                # Tool failures are reported to the caller, not raised.
                response = self._error_response(
                    msg_id, -32000, f"Tool error: {exc}"
                )
            else:
                response = {
                    "jsonrpc": "2.0",
                    "id": msg_id,
                    "result": {"content": result},
                }

            return self.signer.sign_message(response)

        return self.signer.sign_message(
            self._error_response(msg_id, -32601, f"Unknown method: {method}")
        )

    @staticmethod
    def _error_response(msg_id: Any, code: int, message: str) -> dict:
        """Build an (unsigned) JSON-RPC 2.0 error object for request *msg_id*."""
        return {
            "jsonrpc": "2.0",
            "id": msg_id,
            "error": {"code": code, "message": message},
        }

    async def run(self, host: str = "0.0.0.0", port: int = 8080) -> None:
        """Run a simple async HTTP server.

        This is a minimal server suitable for development and examples.
        For production, use :class:`PQCMiddleware` with an ASGI framework.
        """
        server = await asyncio.start_server(self._handle_connection, host, port)
        async with server:
            await server.serve_forever()

    async def _handle_connection(
        self,
        reader: asyncio.StreamReader,
        writer: asyncio.StreamWriter,
    ) -> None:
        """Handle a single HTTP connection (minimal HTTP/1.1 parser).

        Serves exactly one request, writes one response and closes the
        connection. Nothing is written when the peer sends no request line.
        """
        try:
            try:
                outcome = await self._serve_http_request(reader)
            except Exception:
                # Top-level boundary: anything unexpected becomes a 500.
                outcome = (
                    "500 Internal Server Error",
                    b'{"error": "Internal server error"}',
                )

            if outcome is None:
                return

            status, response_body = outcome
            writer.write(self._http_response(status, response_body))
            await writer.drain()
        except OSError:
            # The peer went away while we were responding; nothing to report to.
            pass
        finally:
            writer.close()
            try:
                await writer.wait_closed()
            except OSError:
                pass

    async def _serve_http_request(
        self, reader: asyncio.StreamReader
    ) -> tuple[str, bytes] | None:
        """Read one HTTP request and return ``(status, body)``, or None if no request arrived."""
        request_line = await reader.readline()
        if not request_line:
            return None

        try:
            method, path, body = await self._read_http_request(request_line, reader)
        except (ValueError, asyncio.IncompleteReadError):
            # Undecodable headers, bad Content-Length, or a body shorter than
            # advertised are client errors, not server faults.
            return "400 Bad Request", b'{"error": "Malformed HTTP request"}'

        if method != "POST" or not body:
            return "200 OK", b'{"status": "PQC MCP Server running"}'

        try:
            request_data = json.loads(body)
        except ValueError:
            return "400 Bad Request", b'{"error": "Request body is not valid JSON"}'
        if not isinstance(request_data, dict):
            return "400 Bad Request", b'{"error": "Request body must be a JSON object"}'

        if path == "/handshake":
            response_data = await self.handle_handshake(request_data)
        else:
            response_data = await self.handle_request(request_data)

        return "200 OK", json.dumps(response_data).encode("utf-8")

    @staticmethod
    async def _read_http_request(
        request_line: bytes, reader: asyncio.StreamReader
    ) -> tuple[str, str, bytes]:
        """Parse headers and body after *request_line*; return ``(method, path, body)``."""
        headers: dict[str, str] = {}
        while True:
            line = await reader.readline()
            if line in (b"\r\n", b"\n", b""):
                break
            if b":" in line:
                key, value = line.decode("utf-8").split(":", 1)
                headers[key.strip().lower()] = value.strip()

        content_length = int(headers.get("content-length", "0"))
        if content_length < 0:
            raise ValueError("negative Content-Length")

        # read(n) may return fewer than n bytes when the body spans several
        # TCP segments; readexactly waits for the whole advertised body.
        body = await reader.readexactly(content_length) if content_length else b""

        parts = request_line.decode("utf-8").split()
        method = parts[0] if parts else "GET"
        path = parts[1] if len(parts) > 1 else "/"
        return method, path, body

    @staticmethod
    def _http_response(status: str, body: bytes) -> bytes:
        """Serialize a JSON HTTP/1.1 response with the given status line and body."""
        head = (
            f"HTTP/1.1 {status}\r\n"
            f"Content-Type: application/json\r\n"
            f"Content-Length: {len(body)}\r\n"
            f"\r\n"
        ).encode("utf-8")
        return head + body
| 274 | |