tests/test_session.py
2.9 KB · 86 lines · python Raw
1 """Tests for PQCSession — expiry, replay protection, audit logging."""
2
3 from __future__ import annotations
4
5 from datetime import datetime, timedelta, timezone
6
7 import pytest
8 from quantumshield.core.algorithms import SignatureAlgorithm
9 from quantumshield.identity.agent import AgentIdentity
10
11 from pqc_mcp_transport.errors import ReplayAttackError
12 from pqc_mcp_transport.session import PQCSession
13
14
def _make_session(
    client_identity: AgentIdentity,
    server_identity: AgentIdentity,
    expires_delta: timedelta | None = None,
) -> PQCSession:
    """Build a ``PQCSession`` directly, without running the handshake.

    The session's local identity is ``client_identity``. The peer is
    described by ``server_identity``'s DID and its signing keypair's public
    key and algorithm.

    Args:
        client_identity: Identity used as the session's ``local_identity``.
        server_identity: Identity whose DID and signing keypair describe
            the peer.
        expires_delta: Offset added to the creation time to obtain
            ``expires_at``. ``None`` selects a one-hour lifetime. Zero and
            negative offsets are used as given, so callers can build
            sessions that are already at or past their expiry.

    Returns:
        A new session with id ``"test-session-001"`` whose ``created_at``
        is the current UTC time.
    """
    # Compare against None explicitly: ``timedelta(0)`` is falsy, so an
    # ``expires_delta or default`` expression would silently replace an
    # intentional zero-length lifetime with the one-hour default.
    if expires_delta is None:
        expires_delta = timedelta(hours=1)
    created_at = datetime.now(timezone.utc)
    return PQCSession(
        session_id="test-session-001",
        local_identity=client_identity,
        peer_did=server_identity.did,
        peer_public_key=server_identity.signing_keypair.public_key,
        peer_algorithm=server_identity.signing_keypair.algorithm,
        created_at=created_at,
        expires_at=created_at + expires_delta,
    )
31
32
class TestSessionValidity:
    """A session reports valid before ``expires_at`` and invalid after it."""

    def test_session_valid_initially(
        self, client_identity: AgentIdentity, server_identity: AgentIdentity
    ) -> None:
        """A freshly built session with the helper's default lifetime is valid."""
        fresh = _make_session(client_identity, server_identity)
        validity = fresh.is_valid()
        assert validity is True

    def test_session_expires(
        self, client_identity: AgentIdentity, server_identity: AgentIdentity
    ) -> None:
        """A session whose expiry lies one second in the past is invalid."""
        already_past = timedelta(seconds=-1)
        stale = _make_session(
            client_identity, server_identity, expires_delta=already_past
        )
        assert stale.is_valid() is False
47
48
class TestReplayProtection:
    """``check_nonce`` accepts a nonce once and rejects any repeat of it."""

    def test_nonce_replay_protection(
        self, client_identity: AgentIdentity, server_identity: AgentIdentity
    ) -> None:
        """Presenting the same nonce twice raises ``ReplayAttackError``."""
        sess = _make_session(client_identity, server_identity)
        nonce = "nonce-1"
        first_use = sess.check_nonce(nonce)
        assert first_use is True
        # The second presentation of an already-seen nonce must be rejected.
        with pytest.raises(ReplayAttackError):
            sess.check_nonce(nonce)

    def test_different_nonces_ok(
        self, client_identity: AgentIdentity, server_identity: AgentIdentity
    ) -> None:
        """Distinct nonces are each accepted, in order."""
        sess = _make_session(client_identity, server_identity)
        for nonce in ("nonce-1", "nonce-2"):
            assert sess.check_nonce(nonce) is True
64
65
class TestAuditLog:
    """Operations passed to ``log_operation`` show up in ``get_audit_log``."""

    def test_audit_log_records_operations(
        self, client_identity: AgentIdentity, server_identity: AgentIdentity
    ) -> None:
        """Logging one tool call yields exactly one matching audit entry."""
        sess = _make_session(client_identity, server_identity)
        # A newly built session starts with an empty audit trail.
        assert len(sess.get_audit_log()) == 0

        call_details = dict(
            op_type="tool_call",
            method="greet",
            signer_did=client_identity.did,
            verified=True,
            signature_hex="aabbccdd",
            algorithm="ML-DSA-65",
        )
        sess.log_operation(**call_details)

        entries = sess.get_audit_log()
        assert len(entries) == 1
        entry = entries[0]
        # ``op_type`` is exposed on the entry as ``operation``.
        assert entry.operation == "tool_call"
        assert entry.method == "greet"
        assert entry.verified is True
86