src/pqc_reasoning_ledger/trace.py
5.4 KB · 163 lines · python Raw
1 """ReasoningTrace - ordered list of steps; SealedTrace is a signed finalization."""
2
3 from __future__ import annotations
4
5 import json
6 import uuid
7 from dataclasses import asdict, dataclass, field
8 from datetime import datetime, timezone
9 from typing import Any
10
11 from pqc_reasoning_ledger.errors import ChainBrokenError, TraceSealedError
12 from pqc_reasoning_ledger.step import ReasoningStep
13
14
@dataclass
class TraceMetadata:
    """Descriptive, step-independent facts about one reasoning trace.

    All fields are plain strings. Only ``trace_id``, ``model_did`` and
    ``model_version`` are mandatory; the rest default to the empty string.
    """

    trace_id: str
    model_did: str  # DID identifying the model that produced the trace
    model_version: str
    task: str = ""  # free-form task label, e.g. "contract_review" | "medical_diagnosis"
    actor_did: str = ""  # DID of the party that invoked the model
    session_id: str = ""
    domain: str = ""  # free-form domain label, e.g. "legal" | "medical" | "finance"
    created_at: str = ""

    def to_dict(self) -> dict[str, Any]:
        """Return a new plain dict of all fields (built with ``dataclasses.asdict``)."""
        as_mapping = asdict(self)
        return as_mapping
30
31
@dataclass
class ReasoningTrace:
    """Append-only, hash-chained list of reasoning steps built up during inference.

    A step is accepted by ``append`` only if it extends the chain:
    its ``previous_step_hash`` equals the current tip, its ``step_number`` is the
    next 1-based position, and its declared ``step_hash`` matches
    ``step.compute_step_hash()``. Once ``sealed`` is true, ``append`` refuses
    further steps.
    """

    metadata: TraceMetadata
    steps: list[ReasoningStep] = field(default_factory=list)
    sealed: bool = False

    @classmethod
    def create(
        cls,
        model_did: str,
        model_version: str,
        task: str = "",
        actor_did: str = "",
        session_id: str = "",
        domain: str = "",
    ) -> ReasoningTrace:
        """Start an empty, unsealed trace.

        The trace id is a fresh ``urn:pqc-trace:`` URN built from a random UUID4.
        ``created_at`` is the current UTC time in ISO-8601 form.
        """
        new_id = f"urn:pqc-trace:{uuid.uuid4().hex}"
        opened_at = datetime.now(timezone.utc).isoformat()
        meta = TraceMetadata(
            trace_id=new_id,
            model_did=model_did,
            model_version=model_version,
            task=task,
            actor_did=actor_did,
            session_id=session_id,
            domain=domain,
            created_at=opened_at,
        )
        return cls(metadata=meta)

    @property
    def current_hash(self) -> str:
        """Chain tip: the newest step's ``step_hash``, or 64 zeros for an empty trace.

        This is the value the next appended step must carry as its
        ``previous_step_hash``.
        """
        return self.steps[-1].step_hash if self.steps else "0" * 64

    def append(self, step: ReasoningStep) -> None:
        """Validate ``step`` against the chain and add it to the end of ``steps``.

        Raises:
            TraceSealedError: the trace is already sealed.
            ChainBrokenError: ``step`` does not point at the current tip, has the
                wrong ``step_number``, or its declared ``step_hash`` differs from
                ``step.compute_step_hash()``.
        """
        if self.sealed:
            raise TraceSealedError(f"trace {self.metadata.trace_id} is sealed")

        tip = self.current_hash
        if step.previous_step_hash != tip:
            raise ChainBrokenError(
                f"step previous_step_hash {step.previous_step_hash[:16]}... does not match "
                f"current chain tip {tip[:16]}..."
            )

        # Step numbers are 1-based and must be strictly sequential.
        next_number = len(self.steps) + 1
        if step.step_number != next_number:
            raise ChainBrokenError(
                f"step_number {step.step_number} != expected {next_number}"
            )

        # Don't trust the declared hash; recompute it from the step's own contents.
        recomputed = step.compute_step_hash()
        if step.step_hash != recomputed:
            raise ChainBrokenError(
                f"step_hash mismatch: declared={step.step_hash[:16]}..., "
                f"expected={recomputed[:16]}..."
            )

        self.steps.append(step)

    def to_dict(self) -> dict[str, Any]:
        """Serialize the trace: metadata dict, list of step dicts, and the sealed flag."""
        meta_view = self.metadata.to_dict()
        step_views = [entry.to_dict() for entry in self.steps]
        return {"metadata": meta_view, "steps": step_views, "sealed": self.sealed}
97
98
@dataclass
class SealedTrace:
    """Finalized ReasoningTrace snapshot carrying an ML-DSA signature and a Merkle root.

    The Merkle root is taken over the steps' ``step_hash`` values. Signature
    material (``signature``, ``public_key``) is stored as hex strings. This class
    only holds and (de)serializes the data; it performs no verification itself.
    """

    metadata: TraceMetadata
    steps: list[ReasoningStep]
    final_chain_hash: str  # step_hash of the final step in the chain
    merkle_root: str  # Merkle root computed over the steps' step_hash values
    step_count: int
    sealed_at: str
    signer_did: str = ""
    algorithm: str = ""
    signature: str = ""  # hex-encoded
    public_key: str = ""  # hex-encoded

    def canonical_bytes(self) -> bytes:
        """Deterministic UTF-8 JSON encoding of the fields that describe the sealed trace.

        Keys are sorted and separators are compact so equal inputs always produce
        identical bytes. Steps are represented only by their ``step_hash`` values.

        NOTE(review): ``signer_did``, ``algorithm`` and ``public_key`` are not part
        of this payload. If the signature is computed over these bytes (presumably
        — confirm in the signer), those fields are not bound by it and must be
        checked separately by the verifier. Changing the payload would invalidate
        existing signatures.
        """
        signed_fields = dict(
            metadata=self.metadata.to_dict(),
            step_hashes=[entry.step_hash for entry in self.steps],
            final_chain_hash=self.final_chain_hash,
            merkle_root=self.merkle_root,
            step_count=self.step_count,
            sealed_at=self.sealed_at,
        )
        text = json.dumps(
            signed_fields,
            sort_keys=True,
            separators=(",", ":"),
            ensure_ascii=False,
        )
        return text.encode("utf-8")

    def to_dict(self) -> dict[str, Any]:
        """Serialize every field, expanding metadata and each step into plain dicts.

        Key order is fixed: metadata, steps, then the remaining fields in
        declaration order. ``to_json`` output therefore has a stable layout.
        """
        result: dict[str, Any] = {
            "metadata": self.metadata.to_dict(),
            "steps": [step.to_dict() for step in self.steps],
        }
        for key in (
            "final_chain_hash",
            "merkle_root",
            "step_count",
            "sealed_at",
            "signer_did",
            "algorithm",
            "signature",
            "public_key",
        ):
            result[key] = getattr(self, key)
        return result

    def to_json(self) -> str:
        """Pretty-printed (indent=2) JSON form of ``to_dict()``."""
        as_dict = self.to_dict()
        return json.dumps(as_dict, indent=2)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> SealedTrace:
        """Rebuild a SealedTrace from a mapping shaped like ``to_dict()`` output.

        Required keys: ``metadata``, ``final_chain_hash``, ``merkle_root``,
        ``step_count`` and ``sealed_at``. A missing key raises KeyError.
        ``steps`` defaults to an empty list, and the four signature-related
        fields default to "". ``step_count`` is coerced with ``int()``. Metadata
        keys that ``TraceMetadata`` does not declare make its constructor raise
        TypeError. No hash, Merkle or signature check happens here.
        """
        metadata = TraceMetadata(**data["metadata"])
        restored_steps = [ReasoningStep.from_dict(raw) for raw in data.get("steps", [])]
        signing_fields = {
            key: data.get(key, "")
            for key in ("signer_did", "algorithm", "signature", "public_key")
        }
        return cls(
            metadata=metadata,
            steps=restored_steps,
            final_chain_hash=data["final_chain_hash"],
            merkle_root=data["merkle_root"],
            step_count=int(data["step_count"]),
            sealed_at=data["sealed_at"],
            **signing_fields,
        )

    @classmethod
    def from_json(cls, blob: str) -> SealedTrace:
        """Parse a JSON string and hand the resulting dict to ``from_dict``."""
        parsed = json.loads(blob)
        return cls.from_dict(parsed)
163