tests/test_handshake.py
| 1 | """Tests for PQC mutual authentication handshake.""" |
| 2 | |
| 3 | from __future__ import annotations |
| 4 | |
| 5 | import pytest |
| 6 | from quantumshield.identity.agent import AgentIdentity |
| 7 | |
| 8 | from pqc_mcp_transport.errors import HandshakeError |
| 9 | from pqc_mcp_transport.handshake import ( |
| 10 | HandshakeRequest, |
| 11 | HandshakeResponse, |
| 12 | PQCHandshake, |
| 13 | ) |
| 14 | from pqc_mcp_transport.session import PQCSession |
| 15 | |
| 16 | |
class TestHandshakeInitiate:
    """Tests for ``PQCHandshake.initiate`` (client side, first handshake step)."""

    def test_initiate_creates_valid_request(
        self, client_identity: AgentIdentity
    ) -> None:
        """initiate() returns a signed request bound to the client identity and nonce."""
        request, nonce = PQCHandshake.initiate(client_identity)
        assert isinstance(request, HandshakeRequest)
        assert request.client_did == client_identity.did
        assert request.client_public_key == client_identity.signing_keypair.public_key.hex()
        # The nonce is 16 bytes, hex-encoded -> 32 hex characters. Decoding it
        # also checks that it really is hex, not just any 32-character string.
        assert len(nonce) == 32
        assert len(bytes.fromhex(nonce)) == 16
        assert request.nonce == nonce
        assert len(request.signature) > 0
| 28 | |
| 29 | |
class TestHandshakeRespond:
    """Tests for ``PQCHandshake.respond`` (server side, second handshake step)."""

    def test_respond_verifies_client(
        self, client_identity: AgentIdentity, server_identity: AgentIdentity
    ) -> None:
        """A genuine client request yields a response that echoes the client nonce."""
        request, _nonce = PQCHandshake.initiate(client_identity)
        response = PQCHandshake.respond(request, server_identity)
        assert isinstance(response, HandshakeResponse)
        assert response.server_did == server_identity.did
        assert response.client_nonce == request.nonce
        assert len(response.session_id) > 0

    def test_respond_rejects_invalid_signature(
        self, client_identity: AgentIdentity, server_identity: AgentIdentity
    ) -> None:
        """A junk client signature (``"00" * 64``) is rejected."""
        request, _nonce = PQCHandshake.initiate(client_identity)
        # Tamper with the signature
        request.signature = "00" * 64
        with pytest.raises(HandshakeError, match="Client handshake signature"):
            PQCHandshake.respond(request, server_identity)

    def test_respond_rejects_bit_flipped_signature(
        self, client_identity: AgentIdentity, server_identity: AgentIdentity
    ) -> None:
        """A near-valid client signature (one hex digit changed) is rejected.

        Unlike the all-zero replacement in the test above, this keeps the
        signature's length and encoding intact, so the rejection is more likely
        to come from cryptographic verification than from a size/format check.
        """
        request, _nonce = PQCHandshake.initiate(client_identity)
        sig = request.signature
        # NOTE(review): assumes the signature is a hex string, as the
        # "00" * 64 tamper above implies. "0"/"1" keep it valid hex.
        request.signature = ("1" if sig[0] == "0" else "0") + sig[1:]
        with pytest.raises(HandshakeError, match="Client handshake signature"):
            PQCHandshake.respond(request, server_identity)
| 49 | |
| 50 | |
class TestHandshakeComplete:
    """Tests for ``PQCHandshake.complete`` (client side, final handshake step)."""

    def test_complete_creates_session(
        self, client_identity: AgentIdentity, server_identity: AgentIdentity
    ) -> None:
        """A genuine server response produces a session keyed by its session_id."""
        request, nonce = PQCHandshake.initiate(client_identity)
        response = PQCHandshake.respond(request, server_identity)
        session = PQCHandshake.complete(response, client_identity, nonce)
        assert isinstance(session, PQCSession)
        assert session.session_id == response.session_id
        assert session.peer_did == server_identity.did

    def test_complete_rejects_invalid_server_signature(
        self, client_identity: AgentIdentity, server_identity: AgentIdentity
    ) -> None:
        """A junk server signature (``"00" * 64``) is rejected."""
        request, nonce = PQCHandshake.initiate(client_identity)
        response = PQCHandshake.respond(request, server_identity)
        # Tamper with the server signature
        response.signature = "00" * 64
        with pytest.raises(HandshakeError, match="Server handshake signature"):
            PQCHandshake.complete(response, client_identity, nonce)

    def test_complete_rejects_bit_flipped_server_signature(
        self, client_identity: AgentIdentity, server_identity: AgentIdentity
    ) -> None:
        """A near-valid server signature (one hex digit changed) is rejected.

        Keeps the signature's length and encoding intact, so the rejection is
        more likely to come from cryptographic verification than from a
        size/format check.
        """
        request, nonce = PQCHandshake.initiate(client_identity)
        response = PQCHandshake.respond(request, server_identity)
        sig = response.signature
        # NOTE(review): assumes the signature is a hex string, as the
        # "00" * 64 tamper above implies. "0"/"1" keep it valid hex.
        response.signature = ("1" if sig[0] == "0" else "0") + sig[1:]
        with pytest.raises(HandshakeError, match="Server handshake signature"):
            PQCHandshake.complete(response, client_identity, nonce)

    def test_complete_rejects_wrong_nonce(
        self, client_identity: AgentIdentity, server_identity: AgentIdentity
    ) -> None:
        """complete() refuses a response when given a nonce the client did not send."""
        request, nonce = PQCHandshake.initiate(client_identity)
        response = PQCHandshake.respond(request, server_identity)
        with pytest.raises(HandshakeError, match="correct client nonce"):
            PQCHandshake.complete(response, client_identity, "wrong_nonce")
| 79 | |
| 80 | |
class TestFullRoundTrip:
    """Exercise all three handshake steps back to back."""

    def test_full_handshake_round_trip(
        self, client_identity: AgentIdentity, server_identity: AgentIdentity
    ) -> None:
        """Run initiate -> respond -> complete and inspect the client's session."""
        hello, client_nonce = PQCHandshake.initiate(client_identity)
        reply = PQCHandshake.respond(hello, server_identity)
        established = PQCHandshake.complete(reply, client_identity, client_nonce)

        # The client-side session should point at the server as its peer.
        expected_peer_key = server_identity.signing_keypair.public_key
        assert established.is_valid()
        assert established.peer_did == server_identity.did
        assert established.local_identity.did == client_identity.did
        assert established.peer_public_key == expected_peer_key
| 94 | |