tests/test_session.py
| 1 | """Tests for PQCSession — expiry, replay protection, audit logging.""" |
| 2 | |
| 3 | from __future__ import annotations |
| 4 | |
| 5 | from datetime import datetime, timedelta, timezone |
| 6 | |
| 7 | import pytest |
| 8 | from quantumshield.core.algorithms import SignatureAlgorithm |
| 9 | from quantumshield.identity.agent import AgentIdentity |
| 10 | |
| 11 | from pqc_mcp_transport.errors import ReplayAttackError |
| 12 | from pqc_mcp_transport.session import PQCSession |
| 13 | |
| 14 | |
def _make_session(
    client_identity: AgentIdentity,
    server_identity: AgentIdentity,
    expires_delta: timedelta | None = None,
) -> PQCSession:
    """Create a ``PQCSession`` directly, bypassing the handshake.

    The session's local side is *client_identity*. Its peer details (DID,
    signing public key, signing algorithm) are taken from *server_identity*.

    Args:
        client_identity: Identity used as the session's ``local_identity``.
        server_identity: Identity whose DID and signing keypair populate the
            session's peer fields.
        expires_delta: Offset added to the creation time to produce
            ``expires_at``. ``None`` selects a one-hour lifetime. Zero or
            negative deltas are used as given, e.g. to build a session whose
            expiry is already in the past.

    Returns:
        A ``PQCSession`` with ``session_id="test-session-001"`` and
        ``created_at`` set to the current UTC time.
    """
    now = datetime.now(timezone.utc)
    # Compare against None explicitly: ``timedelta(0)`` is falsy, so an
    # ``expires_delta or default`` expression would silently turn a
    # zero-length lifetime into the one-hour default.
    lifetime = timedelta(hours=1) if expires_delta is None else expires_delta
    return PQCSession(
        session_id="test-session-001",
        local_identity=client_identity,
        peer_did=server_identity.did,
        peer_public_key=server_identity.signing_keypair.public_key,
        peer_algorithm=server_identity.signing_keypair.algorithm,
        created_at=now,
        expires_at=now + lifetime,
    )
| 31 | |
| 32 | |
class TestSessionValidity:
    """Check ``PQCSession.is_valid`` before and after the expiry time."""

    def test_session_valid_initially(
        self, client_identity: AgentIdentity, server_identity: AgentIdentity
    ) -> None:
        # The helper's default lifetime is one hour, so a brand-new session
        # should still report itself as valid.
        fresh = _make_session(client_identity, server_identity)
        valid = fresh.is_valid()
        assert valid is True

    def test_session_expires(
        self, client_identity: AgentIdentity, server_identity: AgentIdentity
    ) -> None:
        # Put expires_at one second before created_at: already expired.
        one_second_ago = timedelta(seconds=-1)
        stale = _make_session(
            client_identity, server_identity, expires_delta=one_second_ago
        )
        assert stale.is_valid() is False
| 48 | |
class TestReplayProtection:
    """Check that ``check_nonce`` accepts a nonce once and rejects reuse."""

    def test_nonce_replay_protection(
        self, client_identity: AgentIdentity, server_identity: AgentIdentity
    ) -> None:
        session = _make_session(client_identity, server_identity)
        nonce = "nonce-1"
        # First presentation of the nonce is accepted...
        assert session.check_nonce(nonce) is True
        # ...presenting the same value again must raise ReplayAttackError.
        with pytest.raises(ReplayAttackError):
            session.check_nonce(nonce)

    def test_different_nonces_ok(
        self, client_identity: AgentIdentity, server_identity: AgentIdentity
    ) -> None:
        session = _make_session(client_identity, server_identity)
        # Distinct nonces within one session are each accepted.
        for nonce in ("nonce-1", "nonce-2"):
            assert session.check_nonce(nonce) is True
| 65 | |
class TestAuditLog:
    """Check that ``log_operation`` calls appear in ``get_audit_log``."""

    def test_audit_log_records_operations(
        self, client_identity: AgentIdentity, server_identity: AgentIdentity
    ) -> None:
        session = _make_session(client_identity, server_identity)
        # Nothing has been logged on a freshly built session.
        assert len(session.get_audit_log()) == 0

        session.log_operation(
            op_type="tool_call",
            method="greet",
            signer_did=client_identity.did,
            verified=True,
            signature_hex="aabbccdd",
            algorithm="ML-DSA-65",
        )

        entries = session.get_audit_log()
        assert len(entries) == 1
        entry = entries[0]
        # ``op_type`` passed to log_operation is exposed as ``operation``.
        assert entry.operation == "tool_call"
        assert entry.method == "greet"
        assert entry.verified is True
| 86 | |