examples/mutual_auth.py
| 1 | """ |
| 2 | Mutual Authentication Example |
| 3 | |
| 4 | Demonstrates both client and server verifying each other |
| 5 | in a single script using in-memory transport (no network). |
| 6 | """ |
| 7 | |
| 8 | import asyncio |
| 9 | |
| 10 | from quantumshield.identity.agent import AgentIdentity |
| 11 | |
| 12 | from pqc_mcp_transport import PQCMCPServer, PQCHandshake, MessageSigner |
| 13 | from pqc_mcp_transport.handshake import HandshakeResponse |
| 14 | |
| 15 | |
async def main() -> None:
    """Run the mutual-authentication demo end to end, printing each step.

    The flow, all in one process with no network:

    1. Create a client and a server ``AgentIdentity``.
    2. Register a ``multiply`` tool on a ``PQCMCPServer`` that requires auth.
    3. Perform the handshake (``PQCHandshake.initiate`` ->
       ``server.handle_handshake`` -> ``PQCHandshake.complete``).
    4. Sign a ``tools/call`` request, send it to the server, and verify the
       signature on the server's response.
    5. Print the session's audit log.

    Raises:
        RuntimeError: If the server's response fails signature verification.
    """
    # Create identities for both sides
    client_id = AgentIdentity.create("mutual-auth-client", capabilities=["tools:call"])
    server_id = AgentIdentity.create("mutual-auth-server", capabilities=["tools:serve"])

    print(f"Client DID: {client_id.did}")
    print(f"Server DID: {server_id.did}")
    print(f"Algorithm: {client_id.signing_keypair.algorithm.value}")
    print()

    # Set up server with a tool
    server = PQCMCPServer(identity=server_id, require_auth=True)

    @server.tool("multiply", description="Multiply two numbers")
    async def multiply(a: float, b: float) -> float:
        """Return the product of ``a`` and ``b``."""
        return a * b

    # --- Step 1: Mutual Handshake ---
    print("=== Step 1: PQC Handshake ===")
    hs_request, nonce = PQCHandshake.initiate(client_id)
    print(f" Client sent handshake request (nonce: {nonce[:16]}...)")

    # The request/response travel as plain dicts, standing in for the wire.
    hs_response_dict = await server.handle_handshake(hs_request.to_dict())
    print(" Server verified client and responded")

    hs_response = HandshakeResponse.from_dict(hs_response_dict)
    session = PQCHandshake.complete(hs_response, client_id, nonce)
    print(" Client verified server")
    print(f" Session established: {session.session_id[:16]}...")
    print(" Mutual authentication: COMPLETE")
    print()

    # --- Step 2: Signed Tool Call ---
    print("=== Step 2: Signed Tool Call ===")
    client_signer = MessageSigner(client_id)

    call_msg = {
        "jsonrpc": "2.0",
        "method": "tools/call",
        "id": "demo-1",
        "params": {"name": "multiply", "arguments": {"a": 6.0, "b": 7.0}},
    }
    signed_call = client_signer.sign_message(call_msg)
    # NOTE(review): the session id is attached after signing; presumably the
    # `_pqc` envelope is not covered by the signature -- confirm.
    signed_call["_pqc"]["session_id"] = session.session_id
    print(f" Client signed request with DID: {client_id.did[:32]}...")

    # Server verifies and processes
    # NOTE(review): the response is not inspected for an error payload before
    # the success message below is printed -- confirm how the server reports
    # a rejected request.
    response = await server.handle_request(signed_call)
    print(" Server verified client signature: OK")

    # --- Step 3: Verify Server Response ---
    print()
    print("=== Step 3: Verify Server Response ===")
    vr = MessageSigner.verify_message(response)
    print(f" Server signature valid: {vr.valid}")
    if not vr.valid:
        # Do not go on to report the DID as "confirmed" (or trust the result)
        # when the signature check failed.
        raise RuntimeError("Server response failed PQC signature verification")
    # NOTE(review): the signer DID is displayed but not compared against
    # server_id.did -- consider asserting equality once the DID formats are
    # confirmed to match.
    print(f" Server DID confirmed: {vr.signer_did[:32]}...")
    print(f" Algorithm: {vr.algorithm}")

    stripped = MessageSigner.strip_pqc(response)
    result = stripped.get("result", {}).get("content")
    print(f" Result: 6.0 * 7.0 = {result}")
    print()

    # --- Audit Trail ---
    print("=== Audit Trail ===")
    for entry in session.get_audit_log():
        print(
            f" [{entry.timestamp}] {entry.operation}: "
            f"method={entry.method} verified={entry.verified}"
        )

    print()
    print("All messages were PQC-signed with ML-DSA.")
    print("Both client and server identities were mutually verified.")
| 90 | |
| 91 | |
# Script entry point: run the async demo only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    asyncio.run(main())
| 94 | |