src/pqc_mcp_transport/session.py
2.5 KB · 75 lines · python Raw
1 """PQC session management with replay protection and audit logging."""
2
3 from __future__ import annotations
4
5 from dataclasses import dataclass, field
6 from datetime import datetime, timedelta, timezone
7
8 from quantumshield.core.algorithms import SignatureAlgorithm
9 from quantumshield.identity.agent import AgentIdentity
10
11 from pqc_mcp_transport.audit import AuditEntry
12 from pqc_mcp_transport.errors import ReplayAttackError, SessionExpiredError
13
14
@dataclass
class PQCSession:
    """State for one PQC session between this MCP agent and a single peer.

    Holds the peer's identity and key material, the session's validity
    window, the nonces already accepted on this session (used to detect
    replays), and an in-memory audit trail of logged operations.
    """

    # Identifier copied into every audit entry written by log_operation().
    session_id: str
    # This side's agent identity.
    local_identity: AgentIdentity
    # DID of the remote peer; recorded as ``peer_did`` in every audit entry.
    peer_did: str
    # The peer's public key and the signature algorithm it is used with.
    peer_public_key: bytes
    peer_algorithm: SignatureAlgorithm
    # Validity window, timezone-aware UTC. Each default is computed from its
    # own "now" call, so supplying only ``created_at`` does NOT shift the
    # default ``expires_at`` (which is one hour after construction).
    created_at: datetime = field(
        default_factory=lambda: datetime.now(tz=timezone.utc)
    )
    expires_at: datetime = field(
        default_factory=lambda: datetime.now(tz=timezone.utc)
        + timedelta(seconds=3600)
    )
    # Every nonce accepted by check_nonce(); nothing in this class prunes it.
    _used_nonces: set[str] = field(default_factory=set)
    # Append-only audit trail; exposed as a copy through get_audit_log().
    _audit_log: list[AuditEntry] = field(default_factory=list)
    # Starts False; not touched by the methods below. Presumably updated by
    # the transport after checking a peer response — confirm against callers.
    last_response_verified: bool = False

    def is_valid(self) -> bool:
        """Report whether the session is still inside its validity window.

        Only ``expires_at`` is consulted (``created_at`` is not checked).
        The comparison is strict, so the exact expiry instant already
        counts as expired.
        """
        current = datetime.now(tz=timezone.utc)
        return self.expires_at > current

    def check_nonce(self, nonce: str) -> bool:
        """Accept *nonce* once and reject any later reuse on this session.

        A previously unseen nonce is remembered and ``True`` is returned.

        Raises:
            ReplayAttackError: *nonce* was already accepted by this session.
                The remembered-nonce set is left unchanged in that case.
        """
        seen = self._used_nonces
        if nonce not in seen:
            seen.add(nonce)
            return True
        raise ReplayAttackError(f"Nonce already used: {nonce}")

    def log_operation(
        self,
        op_type: str,
        method: str | None,
        signer_did: str,
        verified: bool,
        signature_hex: str = "",
        algorithm: str = "",
        details: str | None = None,
    ) -> None:
        """Append one :class:`AuditEntry` describing an operation.

        Args:
            op_type: Recorded as the entry's ``operation``.
            method: Recorded as-is; may be ``None``.
            signer_did: Recorded as the entry's ``signer_did``.
            verified: Recorded as the entry's ``verified`` flag.
            signature_hex: Only its first 32 characters are kept; an empty
                (or otherwise falsy) value is stored as ``""``.
            algorithm: Algorithm label; when empty, ``peer_algorithm.value``
                is recorded instead.
            details: Optional free-form note, recorded as-is.

        Each entry is stamped with the current UTC time in ISO-8601 form and
        carries this session's ``session_id`` and ``peer_did``.
        """
        stamp = datetime.now(tz=timezone.utc).isoformat()
        algo_name = algorithm if algorithm else self.peer_algorithm.value
        sig_prefix = (signature_hex or "")[:32]
        self._audit_log.append(
            AuditEntry(
                timestamp=stamp,
                session_id=self.session_id,
                operation=op_type,
                method=method,
                signer_did=signer_did,
                peer_did=self.peer_did,
                algorithm=algo_name,
                signature_truncated=sig_prefix,
                verified=verified,
                details=details,
            )
        )

    def get_audit_log(self) -> list[AuditEntry]:
        """Return a shallow copy of the audit entries, in insertion order.

        Mutating the returned list does not affect the session's own log.
        """
        return [*self._audit_log]
75