src/pqc_reasoning_ledger/recorder.py
5.3 KB · 146 lines · python Raw
1 """ReasoningRecorder - high-level API for building traces and sealing them with ML-DSA."""
2
3 from __future__ import annotations
4
5 import hashlib
6 from datetime import datetime, timezone
7 from typing import Any
8
9 from quantumshield.core.signatures import sign
10 from quantumshield.identity.agent import AgentIdentity
11
12 from pqc_reasoning_ledger.errors import ReasoningLedgerError
13 from pqc_reasoning_ledger.merkle import compute_merkle_root
14 from pqc_reasoning_ledger.step import ReasoningStep, StepKind, StepReference
15 from pqc_reasoning_ledger.trace import ReasoningTrace, SealedTrace
16
17
class ReasoningRecorder:
    """Live recorder: append steps during inference, seal and sign at the end.

    A recorder wraps one ``AgentIdentity`` and holds at most one active
    ``ReasoningTrace`` (``self.trace``). Each ``record*`` call appends a
    ``ReasoningStep`` chained to the trace's current hash. ``seal()``
    snapshots the steps, computes their Merkle root and signs the result
    with the identity's signing keypair.

    Once a trace has been sealed, further ``record*()`` or ``seal()`` calls
    raise ``ReasoningLedgerError`` until ``begin_trace()`` starts a new
    trace. This keeps the live trace from silently diverging from the
    signed copy.

    Usage:
        identity = AgentIdentity.create("gpt-legal-advisor-signer")
        rec = ReasoningRecorder(identity)
        rec.begin_trace(
            model_did="did:pqaid:gpt-legal",
            model_version="2.1",
            task="contract-review",
            domain="legal",
        )
        rec.record_observation("Contract contains a liquidated damages clause.")
        rec.record_hypothesis("Clause is likely enforceable under NY law.")
        rec.record_deduction("Based on prior observation and hypothesis...")
        rec.record_decision("Recommend signing with modification to cap at $50k.")
        sealed = rec.seal()
    """

    def __init__(self, identity: AgentIdentity):
        """Bind the recorder to the identity used to sign sealed traces.

        Args:
            identity: Agent identity; ``seal()`` reads its ``did`` and
                ``signing_keypair``.
        """
        self.identity = identity
        # No trace is active until begin_trace() is called.
        self.trace: ReasoningTrace | None = None

    # -- trace lifecycle ----------------------------------------------------

    def begin_trace(
        self,
        model_did: str,
        model_version: str,
        task: str = "",
        actor_did: str = "",
        session_id: str = "",
        domain: str = "",
    ) -> ReasoningTrace:
        """Start a new trace and make it the recorder's active trace.

        Any previously active trace (sealed or not) is replaced. Its steps
        are no longer reachable through this recorder.

        All arguments are forwarded unchanged to ``ReasoningTrace.create``.

        Returns:
            The newly created trace (also stored as ``self.trace``).
        """
        self.trace = ReasoningTrace.create(
            model_did=model_did,
            model_version=model_version,
            task=task,
            actor_did=actor_did,
            session_id=session_id,
            domain=domain,
        )
        return self.trace

    def _require_trace(self) -> ReasoningTrace:
        """Return the active trace.

        Raises:
            ReasoningLedgerError: if ``begin_trace()`` has not been called.
        """
        if self.trace is None:
            raise ReasoningLedgerError("call begin_trace() first")
        return self.trace

    def _require_open_trace(self) -> ReasoningTrace:
        """Return the active trace, rejecting one that was already sealed.

        Raises:
            ReasoningLedgerError: if no trace is active, or the active trace
                has already been sealed by ``seal()``.
        """
        trace = self._require_trace()
        # ``sealed`` is set to True by seal(). getattr keeps this check safe
        # even if the trace type only gains the attribute at seal time.
        if getattr(trace, "sealed", False):
            raise ReasoningLedgerError(
                "trace is already sealed; call begin_trace() to start a new one"
            )
        return trace

    # -- generic step recording --------------------------------------------

    def record(
        self,
        kind: StepKind,
        content: str,
        references: list[StepReference] | None = None,
        confidence: float = 1.0,
        metadata: dict[str, Any] | None = None,
    ) -> ReasoningStep:
        """Append one step of ``kind`` to the active trace and return it.

        The step is numbered from 1 (``len(trace.steps) + 1``). It is linked
        to the trace's chain through ``previous_step_hash=trace.current_hash``.

        Args:
            kind: Step kind.
            content: Step text.
            references: Optional references to other steps, passed through.
            confidence: Passed through to ``ReasoningStep.create``.
            metadata: Optional extra metadata, passed through.

        Raises:
            ReasoningLedgerError: if no trace is active or it is already sealed.
        """
        trace = self._require_open_trace()
        step = ReasoningStep.create(
            kind=kind,
            content=content,
            step_number=len(trace.steps) + 1,
            previous_step_hash=trace.current_hash,
            references=references,
            confidence=confidence,
            metadata=metadata,
        )
        trace.append(step)
        return step

    # -- convenience wrappers for each StepKind ----------------------------
    # Each wrapper forwards ``content`` and any keyword arguments of record()
    # (references, confidence, metadata) with a fixed StepKind.

    def record_thought(self, content: str, **kw: Any) -> ReasoningStep:
        """Record a ``StepKind.THOUGHT`` step."""
        return self.record(StepKind.THOUGHT, content, **kw)

    def record_observation(self, content: str, **kw: Any) -> ReasoningStep:
        """Record a ``StepKind.OBSERVATION`` step."""
        return self.record(StepKind.OBSERVATION, content, **kw)

    def record_hypothesis(self, content: str, **kw: Any) -> ReasoningStep:
        """Record a ``StepKind.HYPOTHESIS`` step."""
        return self.record(StepKind.HYPOTHESIS, content, **kw)

    def record_deduction(self, content: str, **kw: Any) -> ReasoningStep:
        """Record a ``StepKind.DEDUCTION`` step."""
        return self.record(StepKind.DEDUCTION, content, **kw)

    def record_retrieval(self, content: str, **kw: Any) -> ReasoningStep:
        """Record a ``StepKind.RETRIEVAL`` step."""
        return self.record(StepKind.RETRIEVAL, content, **kw)

    def record_tool_call(self, content: str, **kw: Any) -> ReasoningStep:
        """Record a ``StepKind.TOOL_CALL`` step."""
        return self.record(StepKind.TOOL_CALL, content, **kw)

    def record_tool_result(self, content: str, **kw: Any) -> ReasoningStep:
        """Record a ``StepKind.TOOL_RESULT`` step."""
        return self.record(StepKind.TOOL_RESULT, content, **kw)

    def record_self_critique(self, content: str, **kw: Any) -> ReasoningStep:
        """Record a ``StepKind.SELF_CRITIQUE`` step."""
        return self.record(StepKind.SELF_CRITIQUE, content, **kw)

    def record_refinement(self, content: str, **kw: Any) -> ReasoningStep:
        """Record a ``StepKind.REFINEMENT`` step."""
        return self.record(StepKind.REFINEMENT, content, **kw)

    def record_decision(self, content: str, **kw: Any) -> ReasoningStep:
        """Record a ``StepKind.DECISION`` step."""
        return self.record(StepKind.DECISION, content, **kw)

    # -- sealing -----------------------------------------------------------

    def seal(self) -> SealedTrace:
        """Snapshot, hash and sign the active trace.

        Builds a ``SealedTrace`` from a copy of the current steps, their
        Merkle root, the trace's final chain hash and a UTC timestamp. It
        then signs the SHA3-256 digest of ``sealed.canonical_bytes()`` with
        the identity's signing keypair. The trace is marked sealed only
        after signing succeeds, so a failed signature leaves it open.

        Returns:
            The signed ``SealedTrace``.

        Raises:
            ReasoningLedgerError: if no trace is active, it is already sealed,
                or it has no steps.
        """
        trace = self._require_open_trace()
        # One snapshot feeds the step list, the Merkle leaves and the count,
        # so all three describe exactly the same steps.
        steps = list(trace.steps)
        if not steps:
            raise ReasoningLedgerError("cannot seal empty trace")
        sealed = SealedTrace(
            metadata=trace.metadata,
            steps=steps,
            final_chain_hash=trace.current_hash,
            merkle_root=compute_merkle_root([s.step_hash for s in steps]),
            step_count=len(steps),
            sealed_at=datetime.now(timezone.utc).isoformat(),
        )
        keypair = self.identity.signing_keypair
        # NOTE(review): the digest is taken before signer_did / algorithm /
        # signature / public_key are filled in. Presumably canonical_bytes()
        # excludes those fields so verifiers can recompute it; confirm.
        digest = hashlib.sha3_256(sealed.canonical_bytes()).digest()
        sig = sign(digest, keypair)
        sealed.signer_did = self.identity.did
        sealed.algorithm = keypair.algorithm.value
        sealed.signature = sig.hex()
        sealed.public_key = keypair.public_key.hex()
        trace.sealed = True
        return sealed
146