src/pqc_mcp_transport/handshake.py
| 1 | """PQC mutual authentication handshake for MCP peers.""" |
| 2 | |
| 3 | from __future__ import annotations |
| 4 | |
| 5 | import hashlib |
| 6 | import os |
| 7 | import uuid |
| 8 | from dataclasses import dataclass |
| 9 | from datetime import datetime, timezone |
| 10 | |
| 11 | from quantumshield.core.algorithms import SignatureAlgorithm |
| 12 | from quantumshield.core.signatures import sign, verify |
| 13 | from quantumshield.identity.agent import AgentIdentity |
| 14 | |
| 15 | from pqc_mcp_transport.errors import HandshakeError |
| 16 | from pqc_mcp_transport.session import PQCSession |
| 17 | |
| 18 | |
@dataclass
class HandshakeRequest:
    """Handshake-initiation message sent by the client.

    Every field is a plain string so the message can be converted to and
    from a dictionary without further encoding.
    """

    client_did: str
    client_public_key: str  # hex-encoded
    algorithm: str
    timestamp: str
    nonce: str
    signature: str  # hex-encoded

    def to_dict(self) -> dict:
        """Return the message as a dict, tagged with its ``"type"`` marker.

        The ``"type"`` entry comes first, followed by the fields in
        declaration order.
        """
        return dict(
            type="pqc_handshake_request",
            client_did=self.client_did,
            client_public_key=self.client_public_key,
            algorithm=self.algorithm,
            timestamp=self.timestamp,
            nonce=self.nonce,
            signature=self.signature,
        )

    @classmethod
    def from_dict(cls, data: dict) -> HandshakeRequest:
        """Build a request from a mapping such as the one ``to_dict`` emits.

        Only the six field keys are read; ``"type"`` and any other extra
        keys are ignored. A missing field key raises ``KeyError``.
        """
        field_names = (
            "client_did",
            "client_public_key",
            "algorithm",
            "timestamp",
            "nonce",
            "signature",
        )
        return cls(**{name: data[name] for name in field_names})
| 51 | |
| 52 | |
@dataclass
class HandshakeResponse:
    """Handshake reply sent by the server.

    Every field is a plain string so the message can be converted to and
    from a dictionary without further encoding.
    """

    server_did: str
    server_public_key: str  # hex-encoded
    algorithm: str
    client_nonce: str  # the client's nonce, echoed back
    server_nonce: str
    signature: str  # hex-encoded
    session_id: str

    def to_dict(self) -> dict:
        """Return the message as a dict, tagged with its ``"type"`` marker.

        The ``"type"`` entry comes first, followed by the fields in
        declaration order.
        """
        return dict(
            type="pqc_handshake_response",
            server_did=self.server_did,
            server_public_key=self.server_public_key,
            algorithm=self.algorithm,
            client_nonce=self.client_nonce,
            server_nonce=self.server_nonce,
            signature=self.signature,
            session_id=self.session_id,
        )

    @classmethod
    def from_dict(cls, data: dict) -> HandshakeResponse:
        """Build a response from a mapping such as the one ``to_dict`` emits.

        Only the seven field keys are read; ``"type"`` and any other extra
        keys are ignored. A missing field key raises ``KeyError``.
        """
        field_names = (
            "server_did",
            "server_public_key",
            "algorithm",
            "client_nonce",
            "server_nonce",
            "signature",
            "session_id",
        )
        return cls(**{name: data[name] for name in field_names})
| 88 | |
| 89 | |
class PQCHandshake:
    """Mutual PQC authentication handshake between MCP client and server.

    The exchange consists of three calls:

    1. client: :meth:`initiate` builds a signed :class:`HandshakeRequest`.
    2. server: :meth:`respond` verifies it and returns a signed
       :class:`HandshakeResponse`.
    3. client: :meth:`complete` verifies the response and returns a
       :class:`PQCSession`.

    Every signature is computed over the SHA3-256 digest of a
    colon-joined, UTF-8 encoded payload.
    """

    @staticmethod
    def _sign_payload(payload: bytes, identity: AgentIdentity) -> bytes:
        """Sign the SHA3-256 digest of *payload* with *identity*'s signing keypair."""
        msg_hash = hashlib.sha3_256(payload).digest()
        return sign(msg_hash, identity.signing_keypair)

    @staticmethod
    def _verify_payload(
        payload: bytes,
        signature: bytes,
        public_key: bytes,
        algorithm: SignatureAlgorithm,
    ) -> bool:
        """Return the result of verifying *signature* over the SHA3-256 digest of *payload*."""
        msg_hash = hashlib.sha3_256(payload).digest()
        return verify(msg_hash, signature, public_key, algorithm)

    @staticmethod
    def _decode_peer_material(
        public_key_hex: str,
        algorithm_name: str,
        signature_hex: str,
        peer: str,
    ) -> tuple[bytes, SignatureAlgorithm, bytes]:
        """Decode the peer-supplied key, algorithm and signature strings.

        These values arrive from the remote side and cannot be trusted to
        be well formed. Decoding failures are converted to
        :class:`HandshakeError` so callers only ever have to handle the
        one exception type the public methods document.

        *peer* ("Client" or "Server") is used only in the error message.

        Raises :class:`HandshakeError` if any value cannot be decoded.
        """
        try:
            public_key = bytes.fromhex(public_key_hex)
            # Presumably an Enum lookup, which raises ValueError for an
            # unknown value -- confirm against quantumshield.
            algorithm = SignatureAlgorithm(algorithm_name)
            signature = bytes.fromhex(signature_hex)
        except (TypeError, ValueError) as err:
            # bytes.fromhex raises ValueError for non-hex text and
            # TypeError for non-string input.
            raise HandshakeError(
                f"{peer} handshake message is malformed: {err}"
            ) from err
        return public_key, algorithm, signature

    @staticmethod
    def initiate(identity: AgentIdentity) -> tuple[HandshakeRequest, str]:
        """Create a handshake request.

        The signed payload is ``"{did}:{nonce}:{timestamp}"``, where the
        nonce is 16 random bytes in hex and the timestamp is the current
        UTC time in ISO-8601 form. The public key and algorithm travel in
        the request but are not part of the signed payload.

        Returns the request and the nonce (needed later to complete the
        handshake).
        """
        nonce = os.urandom(16).hex()
        timestamp = datetime.now(timezone.utc).isoformat()

        payload = f"{identity.did}:{nonce}:{timestamp}".encode("utf-8")
        sig = PQCHandshake._sign_payload(payload, identity)

        request = HandshakeRequest(
            client_did=identity.did,
            client_public_key=identity.signing_keypair.public_key.hex(),
            algorithm=identity.signing_keypair.algorithm.value,
            timestamp=timestamp,
            nonce=nonce,
            signature=sig.hex(),
        )
        return request, nonce

    @staticmethod
    def respond(
        request: HandshakeRequest, server_identity: AgentIdentity
    ) -> HandshakeResponse:
        """Verify the client's request and create a signed response.

        The response payload that gets signed is
        ``"{server_did}:{client_nonce}:{server_nonce}:{session_id}"``,
        with a fresh random server nonce and session id.

        NOTE(review): the signature is checked against the public key
        carried in the request itself, and ``request.timestamp`` is not
        checked for freshness here. Confirm that DID-to-key binding and
        replay protection are enforced elsewhere.

        Raises :class:`HandshakeError` if the request fields are malformed
        or the client's signature is invalid.
        """
        # Verify client signature
        payload = f"{request.client_did}:{request.nonce}:{request.timestamp}".encode(
            "utf-8"
        )
        client_pub, algorithm, client_sig = PQCHandshake._decode_peer_material(
            request.client_public_key,
            request.algorithm,
            request.signature,
            "Client",
        )

        if not PQCHandshake._verify_payload(payload, client_sig, client_pub, algorithm):
            raise HandshakeError("Client handshake signature verification failed")

        # Create response
        session_id = uuid.uuid4().hex
        server_nonce = os.urandom(16).hex()

        resp_payload = (
            f"{server_identity.did}:{request.nonce}:{server_nonce}:{session_id}"
        ).encode("utf-8")
        sig = PQCHandshake._sign_payload(resp_payload, server_identity)

        return HandshakeResponse(
            server_did=server_identity.did,
            server_public_key=server_identity.signing_keypair.public_key.hex(),
            algorithm=server_identity.signing_keypair.algorithm.value,
            client_nonce=request.nonce,
            server_nonce=server_nonce,
            signature=sig.hex(),
            session_id=session_id,
        )

    @staticmethod
    def complete(
        response: HandshakeResponse,
        client_identity: AgentIdentity,
        original_nonce: str,
    ) -> PQCSession:
        """Verify the server's response and create a session.

        *original_nonce* is the nonce returned by :meth:`initiate`; the
        response must echo it back unchanged. The returned session uses
        the response's session id, *client_identity* as the local
        identity, and the server's DID, public key and algorithm as the
        peer details.

        NOTE(review): the signature is checked against the public key
        carried in the response itself. Confirm that DID-to-key binding
        is enforced elsewhere.

        Raises :class:`HandshakeError` if the nonce does not match, the
        response fields are malformed, or the server's signature is
        invalid.
        """
        # Verify nonce echo
        if response.client_nonce != original_nonce:
            raise HandshakeError("Server did not echo back the correct client nonce")

        # Verify server signature
        resp_payload = (
            f"{response.server_did}:{response.client_nonce}:{response.server_nonce}:{response.session_id}"
        ).encode("utf-8")
        server_pub, algorithm, server_sig = PQCHandshake._decode_peer_material(
            response.server_public_key,
            response.algorithm,
            response.signature,
            "Server",
        )

        if not PQCHandshake._verify_payload(
            resp_payload, server_sig, server_pub, algorithm
        ):
            raise HandshakeError("Server handshake signature verification failed")

        return PQCSession(
            session_id=response.session_id,
            local_identity=client_identity,
            peer_did=response.server_did,
            peer_public_key=server_pub,
            peer_algorithm=algorithm,
        )
| 204 | |