src/pqc_mcp_transport/handshake.py
6.6 KB · 204 lines · python Raw
1 """PQC mutual authentication handshake for MCP peers."""
2
3 from __future__ import annotations
4
5 import hashlib
6 import os
7 import uuid
8 from dataclasses import dataclass
9 from datetime import datetime, timezone
10
11 from quantumshield.core.algorithms import SignatureAlgorithm
12 from quantumshield.core.signatures import sign, verify
13 from quantumshield.identity.agent import AgentIdentity
14
15 from pqc_mcp_transport.errors import HandshakeError
16 from pqc_mcp_transport.session import PQCSession
17
18
@dataclass
class HandshakeRequest:
    """Client -> server handshake initiation message.

    Binary values (public key, signature) travel hex-encoded, so the whole
    message can be carried as a plain dict of strings via :meth:`to_dict` and
    :meth:`from_dict`.
    """

    client_did: str
    client_public_key: str  # hex-encoded signing public key
    algorithm: str  # SignatureAlgorithm value naming the signature scheme
    timestamp: str  # ISO-8601 creation time chosen by the client
    nonce: str  # random hex nonce chosen by the client
    signature: str  # hex-encoded signature produced by the client

    # Wire order of the fields. Deliberately left unannotated so that
    # @dataclass treats it as a plain class attribute, not a field.
    _WIRE_FIELDS = (
        "client_did",
        "client_public_key",
        "algorithm",
        "timestamp",
        "nonce",
        "signature",
    )

    def to_dict(self) -> dict:
        """Return the wire form: a ``type`` tag followed by every field in order."""
        wire = {"type": "pqc_handshake_request"}
        wire.update((name, getattr(self, name)) for name in self._WIRE_FIELDS)
        return wire

    @classmethod
    def from_dict(cls, data: dict) -> HandshakeRequest:
        """Rebuild a request from its wire form.

        The ``type`` tag and any unrecognised keys are ignored. A missing
        field raises :class:`KeyError`.
        """
        return cls(**{name: data[name] for name in cls._WIRE_FIELDS})
51
52
@dataclass
class HandshakeResponse:
    """Server -> client handshake reply.

    Echoes the client's nonce, adds the server's own nonce and a session id,
    and carries the server's hex-encoded public key and signature. Round-trips
    through a plain dict of strings via :meth:`to_dict` and :meth:`from_dict`.
    """

    server_did: str
    server_public_key: str  # hex-encoded signing public key
    algorithm: str  # SignatureAlgorithm value naming the signature scheme
    client_nonce: str  # the client's nonce, echoed back unchanged
    server_nonce: str  # random hex nonce chosen by the server
    signature: str  # hex-encoded signature produced by the server
    session_id: str  # identifier of the session being established

    # Wire order of the fields. Deliberately left unannotated so that
    # @dataclass treats it as a plain class attribute, not a field.
    _WIRE_FIELDS = (
        "server_did",
        "server_public_key",
        "algorithm",
        "client_nonce",
        "server_nonce",
        "signature",
        "session_id",
    )

    def to_dict(self) -> dict:
        """Return the wire form: a ``type`` tag followed by every field in order."""
        wire = {"type": "pqc_handshake_response"}
        wire.update((name, getattr(self, name)) for name in self._WIRE_FIELDS)
        return wire

    @classmethod
    def from_dict(cls, data: dict) -> HandshakeResponse:
        """Rebuild a response from its wire form.

        The ``type`` tag and any unrecognised keys are ignored. A missing
        field raises :class:`KeyError`.
        """
        return cls(**{name: data[name] for name in cls._WIRE_FIELDS})
88
89
class PQCHandshake:
    """Mutual PQC authentication handshake between MCP client and server.

    The exchange has three steps:

    1. :meth:`initiate` (client) signs ``"{did}:{nonce}:{timestamp}"`` and
       produces a :class:`HandshakeRequest`.
    2. :meth:`respond` (server) verifies that signature and signs
       ``"{server_did}:{client_nonce}:{server_nonce}:{session_id}"`` into a
       :class:`HandshakeResponse`.
    3. :meth:`complete` (client) checks the nonce echo and verifies the server
       signature, then returns a :class:`PQCSession`.

    Every signature is computed over the SHA3-256 digest of the UTF-8 payload.

    NOTE(review): neither side checks that a presented DID actually belongs to
    the presented public key. If that binding matters, the caller (or the
    identity layer) must enforce it. Confirm this is intended.
    """

    @staticmethod
    def _request_payload(client_did: str, nonce: str, timestamp: str) -> bytes:
        """Return the bytes the client signs and the server re-derives to verify."""
        # A single builder keeps the signer and the verifier byte-for-byte in sync.
        return f"{client_did}:{nonce}:{timestamp}".encode("utf-8")

    @staticmethod
    def _response_payload(
        server_did: str, client_nonce: str, server_nonce: str, session_id: str
    ) -> bytes:
        """Return the bytes the server signs and the client re-derives to verify."""
        return f"{server_did}:{client_nonce}:{server_nonce}:{session_id}".encode(
            "utf-8"
        )

    @staticmethod
    def _sign_payload(payload: bytes, identity: AgentIdentity) -> bytes:
        """Sign the SHA3-256 digest of ``payload`` with ``identity``'s signing keypair."""
        msg_hash = hashlib.sha3_256(payload).digest()
        return sign(msg_hash, identity.signing_keypair)

    @staticmethod
    def _verify_payload(
        payload: bytes,
        signature: bytes,
        public_key: bytes,
        algorithm: SignatureAlgorithm,
    ) -> bool:
        """Verify ``signature`` over the SHA3-256 digest of ``payload``.

        Returns the boolean result of the underlying ``verify`` call.
        """
        msg_hash = hashlib.sha3_256(payload).digest()
        return verify(msg_hash, signature, public_key, algorithm)

    @staticmethod
    def _decode_peer_material(
        public_key_hex: str,
        algorithm_name: str,
        signature_hex: str,
        *,
        peer: str,
    ) -> tuple[bytes, SignatureAlgorithm, bytes]:
        """Decode a peer's hex public key, algorithm name and hex signature.

        Args:
            public_key_hex: Hex-encoded public key received from the peer.
            algorithm_name: Value of the peer's ``SignatureAlgorithm``.
            signature_hex: Hex-encoded signature received from the peer.
            peer: ``"client"`` or ``"server"``. Used only in the error message.

        Returns:
            ``(public_key, algorithm, signature)``.

        Raises:
            HandshakeError: if any field is not valid hex, is not a string, or
                names an unknown algorithm. The original error is chained.
        """
        # These fields come straight off the wire, so malformed data must
        # surface as the documented HandshakeError rather than a bare
        # ValueError/TypeError escaping to the caller.
        try:
            public_key = bytes.fromhex(public_key_hex)
            algorithm = SignatureAlgorithm(algorithm_name)
            signature = bytes.fromhex(signature_hex)
        except (TypeError, ValueError) as err:
            raise HandshakeError(f"Malformed {peer} handshake message: {err}") from err
        return public_key, algorithm, signature

    @staticmethod
    def _check_freshness(timestamp: str, max_skew_seconds: float) -> None:
        """Reject ``timestamp`` if it is more than ``max_skew_seconds`` from now (UTC).

        The comparison is symmetric, so timestamps too far in the future are
        rejected as well as stale ones.

        Raises:
            HandshakeError: if the timestamp is unparsable, has no UTC offset,
                or lies outside the allowed window.
        """
        try:
            sent_at = datetime.fromisoformat(timestamp)
        except (TypeError, ValueError) as err:
            raise HandshakeError(f"Invalid handshake timestamp: {timestamp!r}") from err
        if sent_at.utcoffset() is None:
            # Comparing naive and aware datetimes raises TypeError, and a naive
            # value is ambiguous anyway, so require an explicit offset.
            raise HandshakeError("Handshake timestamp must include a UTC offset")
        skew = abs((datetime.now(timezone.utc) - sent_at).total_seconds())
        if skew > max_skew_seconds:
            raise HandshakeError(
                f"Handshake timestamp outside allowed window "
                f"({skew:.0f}s > {max_skew_seconds}s)"
            )

    @staticmethod
    def initiate(identity: AgentIdentity) -> tuple[HandshakeRequest, str]:
        """Create a signed handshake request for ``identity``.

        The method generates a random 16-byte nonce (hex) and the current UTC
        time in ISO-8601 form, then signs ``"{did}:{nonce}:{timestamp}"`` with
        the identity's signing keypair.

        Returns:
            ``(request, nonce)``. Keep the nonce: :meth:`complete` needs it to
            check that the server echoed it back.
        """
        nonce = os.urandom(16).hex()
        timestamp = datetime.now(timezone.utc).isoformat()

        payload = PQCHandshake._request_payload(identity.did, nonce, timestamp)
        sig = PQCHandshake._sign_payload(payload, identity)

        request = HandshakeRequest(
            client_did=identity.did,
            client_public_key=identity.signing_keypair.public_key.hex(),
            algorithm=identity.signing_keypair.algorithm.value,
            timestamp=timestamp,
            nonce=nonce,
            signature=sig.hex(),
        )
        return request, nonce

    @staticmethod
    def respond(
        request: HandshakeRequest,
        server_identity: AgentIdentity,
        *,
        max_skew_seconds: float | None = None,
    ) -> HandshakeResponse:
        """Verify the client's request and create a signed response.

        Args:
            request: The client's handshake initiation (untrusted input).
            server_identity: Identity whose signing keypair signs the response.
            max_skew_seconds: If given, reject requests whose signed timestamp
                is more than this many seconds away from the server's current
                UTC time, in either direction. ``None`` (the default) skips the
                check, which preserves the previous behaviour.

        Returns:
            A :class:`HandshakeResponse` that echoes the client nonce and
            carries a fresh server nonce and a random session id.

        Raises:
            HandshakeError: if the request's key, algorithm or signature fields
                are malformed, the timestamp is outside the allowed window (when
                checked), or the client's signature does not verify.

        NOTE: no record of previously seen nonces is kept here, so a request
        replayed within the allowed window is not detected by this method.
        """
        client_pub, algorithm, client_sig = PQCHandshake._decode_peer_material(
            request.client_public_key,
            request.algorithm,
            request.signature,
            peer="client",
        )

        if max_skew_seconds is not None:
            # Cheap check first, before the comparatively expensive PQC verify.
            PQCHandshake._check_freshness(request.timestamp, max_skew_seconds)

        payload = PQCHandshake._request_payload(
            request.client_did, request.nonce, request.timestamp
        )
        if not PQCHandshake._verify_payload(payload, client_sig, client_pub, algorithm):
            raise HandshakeError("Client handshake signature verification failed")

        session_id = uuid.uuid4().hex
        server_nonce = os.urandom(16).hex()

        resp_payload = PQCHandshake._response_payload(
            server_identity.did, request.nonce, server_nonce, session_id
        )
        sig = PQCHandshake._sign_payload(resp_payload, server_identity)

        return HandshakeResponse(
            server_did=server_identity.did,
            server_public_key=server_identity.signing_keypair.public_key.hex(),
            algorithm=server_identity.signing_keypair.algorithm.value,
            client_nonce=request.nonce,
            server_nonce=server_nonce,
            signature=sig.hex(),
            session_id=session_id,
        )

    @staticmethod
    def complete(
        response: HandshakeResponse,
        client_identity: AgentIdentity,
        original_nonce: str,
    ) -> PQCSession:
        """Verify the server's response and create a session.

        Args:
            response: The server's handshake reply (untrusted input).
            client_identity: The local identity, stored on the session.
            original_nonce: The nonce returned by :meth:`initiate`.

        Returns:
            A :class:`PQCSession` bound to the server's DID, public key and
            algorithm, under the session id chosen by the server.

        Raises:
            HandshakeError: if the nonce echo does not match, the response's
                key, algorithm or signature fields are malformed, or the server
                signature does not verify.
        """
        if response.client_nonce != original_nonce:
            raise HandshakeError("Server did not echo back the correct client nonce")

        server_pub, algorithm, server_sig = PQCHandshake._decode_peer_material(
            response.server_public_key,
            response.algorithm,
            response.signature,
            peer="server",
        )

        resp_payload = PQCHandshake._response_payload(
            response.server_did,
            response.client_nonce,
            response.server_nonce,
            response.session_id,
        )
        if not PQCHandshake._verify_payload(
            resp_payload, server_sig, server_pub, algorithm
        ):
            raise HandshakeError("Server handshake signature verification failed")

        return PQCSession(
            session_id=response.session_id,
            local_identity=client_identity,
            peer_did=response.server_did,
            peer_public_key=server_pub,
            peer_algorithm=algorithm,
        )
204