src/pqc_mcp_transport/middleware.py
6.8 KB · 186 lines · python Raw
1 """ASGI middleware that adds PQC verification to MCP HTTP servers."""
2
3 from __future__ import annotations
4
5 import json
6 from typing import Any, Callable, Awaitable
7
8 from quantumshield.identity.agent import AgentIdentity
9
10 from pqc_mcp_transport.errors import SignatureVerificationError
11 from pqc_mcp_transport.signer import MessageSigner
12
13
class PQCMiddleware:
    """ASGI middleware that verifies PQC signatures on incoming MCP JSON-RPC
    requests and signs outgoing JSON responses.

    Usage with Starlette / FastAPI::

        from starlette.applications import Starlette
        from pqc_mcp_transport.middleware import PQCMiddleware

        app = Starlette(...)
        app = PQCMiddleware(app, server_identity=my_identity)

    Only ``http`` scopes are touched; every other scope type (``lifespan``,
    ``websocket``, ...) is forwarded to the wrapped app unchanged.

    Requests: when ``require_auth`` is true, the body of each ``POST`` is read
    in full. If it is a JSON object with a ``_pqc`` field it is checked with
    ``MessageSigner.verify_message``; a failed check, or any error raised
    while checking, is answered with ``403`` and the wrapped app is not
    called. Other bodies are replayed to the app unchanged.

    Responses: a response whose ``content-type`` contains ``json``
    (case-insensitive) is buffered, signed with the server identity and sent
    with a recomputed ``content-length``. Any other response (for example
    ``text/event-stream``) is streamed through as the app produces it.
    """

    def __init__(
        self,
        app: Any,
        server_identity: AgentIdentity,
        require_auth: bool = True,
    ) -> None:
        """Wrap *app*.

        Args:
            app: The inner ASGI application.
            server_identity: Identity used to build the response signer.
            require_auth: When true, signed ``POST`` request bodies are
                verified before they reach *app*.
        """
        self.app = app
        self.identity = server_identity
        self.signer = MessageSigner(server_identity)
        self._require_auth = require_auth

    async def __call__(
        self,
        scope: dict,
        receive: Callable[..., Awaitable[dict]],
        send: Callable[..., Awaitable[None]],
    ) -> None:
        """ASGI entry point: verify the request if required, then run the
        wrapped app with a ``send`` that signs JSON responses."""
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        signing_send = self._signing_send(send)

        if not (self._require_auth and scope.get("method", "").upper() == "POST"):
            await self.app(scope, receive, signing_send)
            return

        # NOTE(review): the whole request body is buffered in memory with no
        # size limit before verification -- consider capping it.
        body = await self._read_body(receive)
        if body is None:
            # The client disconnected before sending the full body; there is
            # nobody left to answer and nothing to hand to the app.
            return

        error = self._verification_error(body)
        if error is not None:
            await self._send_error(send, 403, error)
            return

        # NOTE(review): bodies without a ``_pqc`` envelope still reach the app
        # even though ``require_auth`` is set -- confirm this is intended.
        body_replayed = False

        async def replay_receive() -> dict:
            # The body was consumed for verification: hand it to the app once,
            # then fall back to the real channel (e.g. for http.disconnect).
            nonlocal body_replayed
            if not body_replayed:
                body_replayed = True
                return {"type": "http.request", "body": body, "more_body": False}
            return await receive()

        await self.app(scope, replay_receive, signing_send)

    @staticmethod
    async def _read_body(
        receive: Callable[..., Awaitable[dict]],
    ) -> bytes | None:
        """Read the complete request body, or return ``None`` if the client
        disconnects first."""
        chunks: list[bytes] = []
        while True:
            message = await receive()
            if message["type"] == "http.disconnect":
                return None
            if message["type"] == "http.request":
                chunks.append(message.get("body", b""))
                if not message.get("more_body", False):
                    return b"".join(chunks)

    @staticmethod
    def _verification_error(body: bytes) -> str | None:
        """Return the 403 error message for a request body carrying a bad PQC
        signature, or ``None`` when the request may proceed.

        Bodies that are not JSON objects, or that have no ``_pqc`` field, are
        not checked here and yield ``None``.
        """
        try:
            request_data = json.loads(body)
        except (ValueError, RecursionError):
            # ValueError covers JSONDecodeError and UnicodeDecodeError.
            return None  # Not JSON -- nothing to verify.
        if not isinstance(request_data, dict) or "_pqc" not in request_data:
            return None
        try:
            result = MessageSigner.verify_message(request_data)
        except Exception:
            # Fail closed: an envelope that cannot even be checked must never
            # be let through. Keep internal error details out of the reply.
            return "PQC signature verification failed: malformed signature envelope"
        if not result.valid:
            return f"PQC signature verification failed: {result.error}"
        return None

    @staticmethod
    async def _send_error(
        send: Callable[..., Awaitable[None]],
        status: int,
        message: str,
    ) -> None:
        """Send a complete ``{"error": message}`` JSON response."""
        error_body = json.dumps({"error": message}).encode("utf-8")
        await send(
            {
                "type": "http.response.start",
                "status": status,
                "headers": [
                    (b"content-type", b"application/json"),
                    (b"content-length", str(len(error_body)).encode("utf-8")),
                ],
            }
        )
        await send({"type": "http.response.body", "body": error_body})

    @staticmethod
    def _is_json(headers: Any) -> bool:
        """Return True if the ``content-type`` header mentions ``json``."""
        for name, value in headers:
            if name.lower() == b"content-type":
                return "json" in value.decode("utf-8", errors="replace").lower()
        return False

    def _sign_body(self, body: bytes) -> bytes:
        """Return *body* re-serialised with a PQC signature attached.

        *body* is returned unchanged when it is not valid JSON or when signing
        fails (best effort, matching the previous behaviour; the failure is
        logged).
        """
        try:
            data = json.loads(body)
        except (ValueError, RecursionError):
            return body  # Not valid JSON -- pass through.
        try:
            return json.dumps(self.signer.sign_message(data)).encode("utf-8")
        except Exception:
            import logging

            logging.getLogger(__name__).exception(
                "PQC response signing failed; sending unsigned response"
            )
            return body

    def _signing_send(
        self, send: Callable[..., Awaitable[None]]
    ) -> Callable[[dict], Awaitable[None]]:
        """Wrap *send* so JSON responses are buffered and signed, while every
        other response is forwarded as soon as the app produces it."""
        held_start: dict | None = None  # start message of a pending JSON response
        body_parts: list[bytes] = []

        async def signing_send(message: dict) -> None:
            nonlocal held_start
            message_type = message["type"]

            if message_type == "http.response.start":
                if self._is_json(message.get("headers", [])):
                    # Hold the start: content-length is only known after signing.
                    held_start = message
                    body_parts.clear()
                else:
                    await send(message)
                return

            if message_type == "http.response.body" and held_start is not None:
                body_parts.append(message.get("body", b""))
                if message.get("more_body", False):
                    return
                full_body = self._sign_body(b"".join(body_parts))
                headers = [
                    (k, v)
                    for k, v in held_start.get("headers", [])
                    if k.lower() != b"content-length"
                ]
                headers.append(
                    (b"content-length", str(len(full_body)).encode("utf-8"))
                )
                start = {
                    **held_start,
                    "status": held_start.get("status", 200),
                    "headers": headers,
                }
                held_start = None
                await send(start)
                await send({"type": "http.response.body", "body": full_body})
                return

            # Non-JSON bodies and any other message types go straight through.
            await send(message)

        return signing_send
186