src/pqc_mcp_transport/session.py
| 1 | """PQC session management with replay protection and audit logging.""" |
| 2 | |
| 3 | from __future__ import annotations |
| 4 | |
| 5 | from dataclasses import dataclass, field |
| 6 | from datetime import datetime, timedelta, timezone |
| 7 | |
| 8 | from quantumshield.core.algorithms import SignatureAlgorithm |
| 9 | from quantumshield.identity.agent import AgentIdentity |
| 10 | |
| 11 | from pqc_mcp_transport.audit import AuditEntry |
| 12 | from pqc_mcp_transport.errors import ReplayAttackError, SessionExpiredError |
| 13 | |
| 14 | |
@dataclass
class PQCSession:
    """An authenticated PQC session between two MCP peers.

    Holds the peer's identity and public key, remembers every nonce seen so
    far for replay protection, and accumulates an in-memory audit log of the
    operations performed within the session.

    Attributes:
        session_id: Identifier copied into every audit entry.
        local_identity: This side's agent identity.
        peer_did: DID of the remote peer; copied into every audit entry.
        peer_public_key: The peer's raw public key bytes.
        peer_algorithm: The peer's signature algorithm; its ``value`` is the
            fallback ``algorithm`` recorded by :meth:`log_operation`.
        created_at: Creation time; defaults to the current UTC time.
        expires_at: Expiry time; defaults to "now + 1 hour" evaluated at
            construction. It is computed independently of ``created_at``, so
            passing an explicit ``created_at`` does not move the expiry.
        last_response_verified: Defaults to False and is not updated by any
            method of this class (presumably set by the transport after a
            response signature check -- verify against callers).
    """

    session_id: str
    # NOTE(review): local_identity is still included in the generated repr;
    # confirm AgentIdentity's repr does not expose secret key material.
    local_identity: AgentIdentity
    peer_did: str
    peer_public_key: bytes
    peer_algorithm: SignatureAlgorithm
    created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    expires_at: datetime = field(
        default_factory=lambda: datetime.now(timezone.utc) + timedelta(hours=1)
    )
    # Internal state. Both collections only grow for the life of the session,
    # so they are excluded from the generated __repr__ to keep printing or
    # logging a session from dumping every nonce and audit entry.
    _used_nonces: set[str] = field(default_factory=set, repr=False)
    _audit_log: list[AuditEntry] = field(default_factory=list, repr=False)
    last_response_verified: bool = False

    def is_valid(self) -> bool:
        """Return True if the current UTC time is strictly before ``expires_at``.

        ``expires_at`` must be timezone-aware; comparing the aware "now"
        against a naive datetime raises :class:`TypeError`.
        """
        return datetime.now(timezone.utc) < self.expires_at

    def check_nonce(self, nonce: str) -> bool:
        """Check and register a nonce for replay protection.

        Args:
            nonce: The nonce carried by an incoming message.

        Returns:
            True if the nonce has not been seen before in this session. The
            nonce is then recorded so any later reuse is rejected.

        Raises:
            ReplayAttackError: If the nonce has already been seen.

        Note:
            Recorded nonces are never evicted, so memory use grows with the
            number of messages handled during the session's lifetime.
        """
        if nonce in self._used_nonces:
            raise ReplayAttackError(f"Nonce already used: {nonce}")
        self._used_nonces.add(nonce)
        return True

    def log_operation(
        self,
        op_type: str,
        method: str | None,
        signer_did: str,
        verified: bool,
        signature_hex: str = "",
        algorithm: str = "",
        details: str | None = None,
    ) -> None:
        """Record an operation in the session audit log.

        Args:
            op_type: Operation label, stored as ``AuditEntry.operation``.
            method: Optional method name associated with the operation.
            signer_did: DID of the party that produced the signature.
            verified: Whether the signature check succeeded.
            signature_hex: Hex-encoded signature; only its first 32
                characters are stored.
            algorithm: Algorithm name; when empty, falls back to
                ``self.peer_algorithm.value``.
            details: Optional free-form details.
        """
        entry = AuditEntry(
            timestamp=datetime.now(timezone.utc).isoformat(),
            session_id=self.session_id,
            operation=op_type,
            method=method,
            signer_did=signer_did,
            peer_did=self.peer_did,
            algorithm=algorithm or self.peer_algorithm.value,
            # `or ""` keeps the original tolerance of a None/empty signature.
            signature_truncated=(signature_hex or "")[:32],
            verified=verified,
            details=details,
        )
        self._audit_log.append(entry)

    def get_audit_log(self) -> list[AuditEntry]:
        """Return the audit trail for this session in insertion order.

        The returned list is a shallow copy. Appending to it does not affect
        the session, but the entry objects themselves are shared.
        """
        return list(self._audit_log)
| 75 | |