src/pqc_mcp_transport/middleware.py
| 1 | """ASGI middleware that adds PQC verification to MCP HTTP servers.""" |
| 2 | |
| 3 | from __future__ import annotations |
| 4 | |
| 5 | import json |
| 6 | from typing import Any, Callable, Awaitable |
| 7 | |
| 8 | from quantumshield.identity.agent import AgentIdentity |
| 9 | |
| 10 | from pqc_mcp_transport.errors import SignatureVerificationError |
| 11 | from pqc_mcp_transport.signer import MessageSigner |
| 12 | |
| 13 | |
class PQCMiddleware:
    """ASGI middleware that verifies PQC signatures on incoming MCP JSON-RPC
    requests and signs outgoing JSON responses.

    Request side (only when ``require_auth`` is true and the method is POST):
    the whole body is read, and if it is a JSON object carrying a ``_pqc``
    field it is checked with ``MessageSigner.verify_message``.  A failed or
    crashing verification is answered with HTTP 403; otherwise the buffered
    body is replayed to the wrapped app.

    Response side: responses whose ``Content-Type`` mentions ``json`` are
    buffered, signed with the server identity, and re-emitted with a
    corrected ``Content-Length``.  Every other response (for example a
    ``text/event-stream``) is forwarded untouched, chunk by chunk.

    NOTE(review): JSON requests *without* a ``_pqc`` field, and bodies that
    are not JSON, are forwarded unverified even when ``require_auth`` is
    true.  Confirm whether unsigned requests should be rejected instead.

    Usage with Starlette / FastAPI::

        from starlette.applications import Starlette
        from pqc_mcp_transport.middleware import PQCMiddleware

        app = Starlette(...)
        app = PQCMiddleware(app, server_identity=my_identity)
    """

    def __init__(
        self,
        app: Any,
        server_identity: AgentIdentity,
        require_auth: bool = True,
    ) -> None:
        """Wrap *app*.

        Args:
            app: The inner ASGI application.
            server_identity: Identity used to sign outgoing JSON responses.
            require_auth: When true, POST request bodies are inspected and
                any ``_pqc`` signature they carry is verified.
        """
        self.app = app
        self.identity = server_identity
        self.signer = MessageSigner(server_identity)
        self._require_auth = require_auth

    async def __call__(
        self,
        scope: dict,
        receive: Callable[..., Awaitable[dict]],
        send: Callable[..., Awaitable[None]],
    ) -> None:
        """Handle one ASGI connection; non-HTTP scopes pass straight through."""
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        send_wrapper = self._wrap_send(send)

        must_verify = (
            self._require_auth and scope.get("method", "").upper() == "POST"
        )
        if not must_verify:
            await self.app(scope, receive, send_wrapper)
            return

        # The body has to be read in full before it can be verified.
        full_request_body = await self._read_body(receive)
        if full_request_body is None:
            # Client went away before the body completed; nothing to serve.
            return

        rejection = self._rejection_reason(full_request_body)
        if rejection is not None:
            await self._send_forbidden(send, rejection)
            return

        # Replay the already-consumed body to the inner app exactly once,
        # then fall back to the real channel (e.g. for http.disconnect).
        body_sent = False

        async def replay_receive() -> dict:
            nonlocal body_sent
            if not body_sent:
                body_sent = True
                return {
                    "type": "http.request",
                    "body": full_request_body,
                    "more_body": False,
                }
            return await receive()

        await self.app(scope, replay_receive, send_wrapper)

    @staticmethod
    async def _read_body(
        receive: Callable[..., Awaitable[dict]],
    ) -> bytes | None:
        """Drain the request body; return ``None`` if the client disconnects."""
        chunks: list[bytes] = []
        while True:
            msg = await receive()
            msg_type = msg["type"]
            if msg_type == "http.disconnect":
                # Without this exit the loop would spin forever, because
                # receive() keeps reporting the disconnect.
                return None
            if msg_type == "http.request":
                chunks.append(msg.get("body", b""))
                if not msg.get("more_body", False):
                    return b"".join(chunks)

    @staticmethod
    def _rejection_reason(body: bytes) -> str | None:
        """Return the 403 error text for *body*, or ``None`` to let it through."""
        try:
            request_data = json.loads(body)
        except (ValueError, RecursionError):
            return None  # Not JSON — pass through to app

        if not isinstance(request_data, dict) or "_pqc" not in request_data:
            return None  # Nothing to verify (see NOTE(review) on the class)

        try:
            vr = MessageSigner.verify_message(request_data)
        except Exception as exc:  # fail closed: never forward on a crash
            return (
                "PQC signature verification failed: "
                f"verifier raised {type(exc).__name__}"
            )
        if vr.valid:
            return None
        return f"PQC signature verification failed: {vr.error}"

    @staticmethod
    async def _send_forbidden(
        send: Callable[..., Awaitable[None]],
        reason: str,
    ) -> None:
        """Emit a complete ``403`` JSON response carrying *reason*."""
        error_body = json.dumps({"error": reason}).encode("utf-8")
        await send(
            {
                "type": "http.response.start",
                "status": 403,
                "headers": [
                    (b"content-type", b"application/json"),
                    (b"content-length", str(len(error_body)).encode("utf-8")),
                ],
            }
        )
        await send({"type": "http.response.body", "body": error_body})

    def _sign_body(self, body: bytes) -> bytes:
        """Return *body* signed as JSON, or unchanged if it cannot be signed."""
        try:
            signed = self.signer.sign_message(json.loads(body))
            return json.dumps(signed).encode("utf-8")
        except Exception:
            # Deliberate best effort: a body that is not valid JSON, or that
            # the signer rejects, is still delivered to the client unsigned.
            return body

    def _wrap_send(
        self,
        send: Callable[..., Awaitable[None]],
    ) -> Callable[[dict], Awaitable[None]]:
        """Build a ``send`` callable that signs JSON responses.

        Only JSON responses are buffered; anything else is forwarded as soon
        as the app emits it so that streamed responses are not stalled.
        """
        status = 200
        headers: list[tuple[bytes, bytes]] = []
        chunks: list[bytes] = []
        buffering = False

        async def send_wrapper(message: dict) -> None:
            nonlocal status, headers, buffering
            msg_type = message["type"]

            if msg_type == "http.response.start":
                headers = list(message.get("headers", []))
                content_type = next(
                    (
                        value.decode("utf-8", errors="replace")
                        for key, value in headers
                        if key.lower() == b"content-type"
                    ),
                    "",
                )
                # Media types are case-insensitive, hence the lower().
                buffering = "json" in content_type.lower()
                if buffering:
                    status = message.get("status", 200)
                    return  # Held back until the signed length is known
                await send(message)
                return

            if msg_type == "http.response.body" and buffering:
                chunks.append(message.get("body", b""))
                if message.get("more_body", False):
                    return

                full_body = self._sign_body(b"".join(chunks))

                # Signing changes the payload size, so replace Content-Length.
                new_headers = [
                    (key, value)
                    for key, value in headers
                    if key.lower() != b"content-length"
                ]
                new_headers.append(
                    (b"content-length", str(len(full_body)).encode("utf-8"))
                )
                await send(
                    {
                        "type": "http.response.start",
                        "status": status,
                        "headers": new_headers,
                    }
                )
                await send({"type": "http.response.body", "body": full_body})
                return

            # Non-buffered bodies and any other message type pass through.
            await send(message)

        return send_wrapper
| 186 | |