src/pqc_mcp_transport/client.py
5.4 KB · 158 lines · python Raw
1 """PQC-secured MCP client that wraps tool calls with ML-DSA signatures."""
2
3 from __future__ import annotations
4
5 import uuid
6 from typing import Any
7
8 import httpx
9
10 from quantumshield.identity.agent import AgentIdentity
11
12 from pqc_mcp_transport.errors import (
13 HandshakeError,
14 PeerNotAuthenticatedError,
15 PQCTransportError,
16 SessionExpiredError,
17 SignatureVerificationError,
18 )
19 from pqc_mcp_transport.handshake import HandshakeRequest, HandshakeResponse, PQCHandshake
20 from pqc_mcp_transport.session import PQCSession
21 from pqc_mcp_transport.signer import MessageSigner
22
23
class PQCMCPClient:
    """Async MCP client that signs every request with ML-DSA and verifies responses.

    Typical flow: construct the client, ``await connect()`` to run the
    handshake, issue ``call_tool`` / ``list_tools`` requests, then
    ``await close()`` to release the HTTP client.
    """

    def __init__(
        self,
        identity: AgentIdentity,
        server_url: str,
        verify_responses: bool = True,
    ) -> None:
        """Store the identity and prepare the signer and HTTP client.

        Args:
            identity: Agent identity used for the handshake and handed to
                ``MessageSigner`` for signing outgoing messages.
            server_url: Base URL of the server; trailing slashes are removed.
            verify_responses: When true, responses carrying a ``_pqc``
                envelope are verified and rejected if the signature is invalid.
        """
        self.identity = identity
        # Normalised so endpoint paths can be appended as "/handshake", "/mcp".
        self.server_url = server_url.rstrip("/")
        self.signer = MessageSigner(identity)
        # None until connect() succeeds; reset to None by close().
        self.session: PQCSession | None = None
        self._verify_responses = verify_responses
        self._http = httpx.AsyncClient()

    async def connect(self) -> PQCSession:
        """Perform a PQC handshake with the server and establish a session.

        Returns:
            The session produced by ``PQCHandshake.complete``; it is also
            stored on ``self.session``.

        Raises:
            HandshakeError: If ``/handshake`` answers with a non-200 status.
        """
        request, nonce = PQCHandshake.initiate(self.identity)

        resp = await self._http.post(
            f"{self.server_url}/handshake",
            json=request.to_dict(),
        )
        if resp.status_code != 200:
            raise HandshakeError(f"Handshake request failed: HTTP {resp.status_code}")

        response = HandshakeResponse.from_dict(resp.json())
        session = PQCHandshake.complete(response, self.identity, nonce)
        self.session = session

        # NOTE(review): recorded as verified unconditionally; presumably
        # PQCHandshake.complete raises on a bad server signature - confirm.
        session.log_operation(
            op_type="handshake",
            method=None,
            signer_did=response.server_did,
            verified=True,
            signature_hex=response.signature,
            algorithm=response.algorithm,
            details="Handshake completed successfully",
        )
        return session

    def _require_session(self) -> PQCSession:
        """Return the active session or raise if there is none or it is invalid.

        Raises:
            PeerNotAuthenticatedError: If ``connect()`` has not produced a session.
            SessionExpiredError: If the session's ``is_valid()`` returns false.
        """
        if self.session is None:
            raise PeerNotAuthenticatedError("No active session. Call connect() first.")
        if not self.session.is_valid():
            raise SessionExpiredError("Session has expired. Reconnect.")
        return self.session

    async def _post_signed(
        self,
        session: PQCSession,
        message: dict[str, Any],
        failure_prefix: str,
    ) -> tuple[dict[str, Any], Any]:
        """Sign *message*, POST it to ``/mcp`` and return ``(signed, response_json)``.

        Args:
            session: Active session whose id is attached to the envelope.
            message: JSON-RPC request to sign.
            failure_prefix: Leading text of the error raised on a non-200 status.

        Raises:
            PQCTransportError: If the server answers with a non-200 status.
        """
        signed = self.signer.sign_message(message)
        # The session id is added after signing, exactly as the callers did
        # before this helper existed.
        signed["_pqc"]["session_id"] = session.session_id

        resp = await self._http.post(f"{self.server_url}/mcp", json=signed)
        if resp.status_code != 200:
            raise PQCTransportError(f"{failure_prefix}: HTTP {resp.status_code}")
        return signed, resp.json()

    def _check_response(
        self,
        session: PQCSession,
        resp_data: Any,
        log_method: str | None = None,
    ) -> None:
        """Verify a response signature and record the outcome on the session.

        ``session.last_response_verified`` is always updated, so it never
        keeps a stale value from an earlier request.

        Args:
            session: Session that receives the verification flag and log entry.
            resp_data: Decoded JSON response from the server.
            log_method: When given, a ``tool_response`` entry is logged with
                this value as its method.

        Raises:
            SignatureVerificationError: If a ``_pqc`` envelope is present,
                verification is enabled, and the signature is invalid.
        """
        if not (self._verify_responses and "_pqc" in resp_data):
            # NOTE(review): a response without a "_pqc" envelope is accepted
            # even when verification is enabled, so removing the envelope in
            # transit bypasses the check. Callers must inspect
            # session.last_response_verified; consider rejecting unsigned
            # responses outright.
            session.last_response_verified = False
            return

        vr = MessageSigner.verify_message(resp_data)
        session.last_response_verified = vr.valid
        if log_method is not None:
            # Logged before raising so failed verifications are recorded too.
            session.log_operation(
                op_type="tool_response",
                method=log_method,
                signer_did=vr.signer_did or "unknown",
                verified=vr.valid,
                signature_hex=resp_data.get("_pqc", {}).get("signature", ""),
                algorithm=vr.algorithm or "",
            )
        if not vr.valid:
            raise SignatureVerificationError(
                f"Server response signature invalid: {vr.error}"
            )

    async def call_tool(self, name: str, arguments: dict[str, Any] | None = None) -> dict:
        """Call an MCP tool with a PQC-signed request.

        Args:
            name: Tool name placed in ``params.name``.
            arguments: Tool arguments; ``None`` or an empty mapping is sent as ``{}``.

        Returns:
            The server's response after ``MessageSigner.strip_pqc``. This is
            the whole response object rather than only its ``"result"`` member.

        Raises:
            PeerNotAuthenticatedError: If there is no session.
            SessionExpiredError: If the session is no longer valid.
            PQCTransportError: If the server answers with a non-200 status.
            SignatureVerificationError: If the response signature is invalid.
        """
        session = self._require_session()

        message: dict[str, Any] = {
            "jsonrpc": "2.0",
            "method": "tools/call",
            "id": uuid.uuid4().hex,
            "params": {
                "name": name,
                "arguments": arguments or {},
            },
        }
        signed, resp_data = await self._post_signed(session, message, "Tool call failed")
        self._check_response(session, resp_data, log_method=name)

        # The outgoing request is logged only once the response was accepted,
        # after the tool_response entry (original ordering preserved).
        session.log_operation(
            op_type="tool_call",
            method=name,
            signer_did=self.identity.did,
            verified=True,
            signature_hex=signed.get("_pqc", {}).get("signature", ""),
            algorithm=self.identity.signing_keypair.algorithm.value,
        )

        return MessageSigner.strip_pqc(resp_data)

    async def list_tools(self) -> list[dict]:
        """List available tools via a PQC-signed request.

        Returns:
            The ``result.tools`` list from the response, or an empty list when
            ``result`` or ``tools`` is missing or null.

        Raises:
            PeerNotAuthenticatedError: If there is no session.
            SessionExpiredError: If the session is no longer valid.
            PQCTransportError: If the server answers with a non-200 status.
            SignatureVerificationError: If the response signature is invalid.
        """
        session = self._require_session()

        message: dict[str, Any] = {
            "jsonrpc": "2.0",
            "method": "tools/list",
            "id": uuid.uuid4().hex,
        }
        _, resp_data = await self._post_signed(session, message, "List tools failed")
        # No log_method: this call does not write audit-log entries.
        self._check_response(session, resp_data)

        payload = MessageSigner.strip_pqc(resp_data)
        # "or" guards against an explicit null for "result" or "tools".
        return (payload.get("result") or {}).get("tools") or []

    async def close(self) -> None:
        """Close the HTTP client and invalidate the session.

        The session is cleared even if closing the HTTP client raises.
        """
        try:
            await self._http.aclose()
        finally:
            self.session = None
158