src/pqc_training_data/commitment.py
6.0 KB · 173 lines · python Raw
1 """TrainingCommitment - a signed Merkle-root commitment to a training dataset."""
2
3 from __future__ import annotations
4
5 import hashlib
6 import json
7 import uuid
8 from dataclasses import asdict, dataclass, field
9 from datetime import datetime, timezone
10 from typing import Any
11
12 from quantumshield.core.algorithms import SignatureAlgorithm
13 from quantumshield.core.signatures import sign, verify
14 from quantumshield.identity.agent import AgentIdentity
15
16 from pqc_training_data.merkle import MerkleTree
17 from pqc_training_data.record import DataRecord, RecordHash
18
19
@dataclass
class TrainingCommitment:
    """Signed commitment to a training dataset's Merkle root.

    Only the root is stored, never the records themselves: the records stay
    private and inclusion proofs can be issued selectively on demand.

    The dataset fields form the payload returned by ``canonical_bytes``. The
    signer fields at the bottom are written by ``CommitmentSigner`` and are
    NOT part of that payload.
    """

    commitment_id: str
    dataset_name: str
    dataset_version: str
    description: str
    record_count: int
    root: str  # Merkle root, hex-encoded
    created_at: str
    licenses: list[str] = field(default_factory=list)
    tags: list[str] = field(default_factory=list)
    extra: dict = field(default_factory=dict)

    # Signer fields - empty strings until CommitmentSigner fills them in.
    signer_did: str = ""
    algorithm: str = ""
    signature: str = ""  # hex
    public_key: str = ""  # hex
    signed_at: str = ""

    def canonical_bytes(self) -> bytes:
        """Return the deterministic UTF-8 JSON payload that gets signed.

        Keys are sorted at every nesting level (so ``extra`` is order
        independent), separators are compact, and ``licenses`` / ``tags`` are
        sorted so their list order does not change the bytes. The signer
        fields (``signer_did``, ``algorithm``, ``signature``, ``public_key``,
        ``signed_at``) are not included.
        """
        covered_fields = (
            "commitment_id",
            "dataset_name",
            "dataset_version",
            "description",
            "record_count",
            "root",
            "created_at",
            "extra",
        )
        payload = {name: getattr(self, name) for name in covered_fields}
        # Order-insensitive list fields are normalised before encoding.
        payload["licenses"] = sorted(self.licenses)
        payload["tags"] = sorted(self.tags)
        text = json.dumps(
            payload,
            ensure_ascii=False,
            separators=(",", ":"),
            sort_keys=True,
        )
        return text.encode("utf-8")

    def to_dict(self) -> dict[str, Any]:
        """Return every field, signer fields included, as a deep-copied dict."""
        return asdict(self)

    def to_json(self) -> str:
        """Serialise ``to_dict()`` as pretty-printed (2-space indented) JSON."""
        as_dict = self.to_dict()
        return json.dumps(as_dict, indent=2)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> TrainingCommitment:
        """Rebuild a commitment from ``to_dict`` output or an equivalent mapping.

        ``commitment_id``, ``dataset_name``, ``dataset_version``,
        ``record_count`` (coerced with ``int``) and ``root`` are required; a
        missing one raises ``KeyError``. All other keys fall back to empty
        defaults, and list/dict values are shallow-copied.
        """
        signer_keys = ("signer_did", "algorithm", "signature", "public_key", "signed_at")
        return cls(
            commitment_id=data["commitment_id"],
            dataset_name=data["dataset_name"],
            dataset_version=data["dataset_version"],
            description=data.get("description", ""),
            record_count=int(data["record_count"]),
            root=data["root"],
            created_at=data.get("created_at", ""),
            licenses=list(data.get("licenses", [])),
            tags=list(data.get("tags", [])),
            extra=dict(data.get("extra", {})),
            **{key: data.get(key, "") for key in signer_keys},
        )

    @classmethod
    def from_json(cls, blob: str) -> TrainingCommitment:
        """Parse a JSON string and delegate to ``from_dict``."""
        parsed = json.loads(blob)
        return cls.from_dict(parsed)
92
93
class CommitmentBuilder:
    """Accumulate record leaf hashes and emit an unsigned TrainingCommitment.

    Each added record contributes one leaf to an internal ``MerkleTree``;
    ``build`` snapshots the tree's size and root together with the builder's
    metadata.

    Usage:
        builder = CommitmentBuilder("company-train-v1", "1.0.0")
        for doc in corpus:
            builder.add_record(DataRecord(content=doc.bytes, metadata={...}))
        commitment = builder.build(description="Company training corpus")
    """

    def __init__(self, dataset_name: str, dataset_version: str):
        self.dataset_name, self.dataset_version = dataset_name, dataset_version
        self.tree = MerkleTree()
        # Public metadata read by build(); presumably populated by callers
        # before building - nothing in this class writes to them.
        self.licenses: list[str] = []
        self.tags: list[str] = []
        self.extra: dict = {}

    def add_record(self, record: DataRecord) -> None:
        """Hash *record* into a leaf and append it to the Merkle tree."""
        leaf = record.leaf_hash()
        self.tree.add(leaf)

    def add_records(self, records: list[DataRecord]) -> None:
        """Add each record in iteration order via ``add_record``."""
        for item in records:
            self.add_record(item)

    def add_leaf_hash_hex(self, hex_hash: str) -> None:
        """Append a precomputed leaf hash given as a hex string.

        Useful when the caller already hashed the data. NOTE(review): nothing
        here checks that *hex_hash* is well-formed or was produced the same
        way as ``DataRecord.leaf_hash()`` - confirm RecordHash/MerkleTree do.
        """
        leaf = RecordHash(hex=hex_hash)
        self.tree.add(leaf)

    def build(self, description: str = "") -> TrainingCommitment:
        """Snapshot the current tree into a new, unsigned commitment.

        The commitment gets a random ``urn:pqc-td:<uuid4 hex>`` id and a
        UTC ISO-8601 ``created_at``; licenses, tags and extra are copied so
        later builder mutations do not leak into it.
        """
        commitment_id = "urn:pqc-td:" + uuid.uuid4().hex
        name, version = self.dataset_name, self.dataset_version
        record_count = self.tree.size
        merkle_root = self.tree.root()
        timestamp = datetime.now(timezone.utc).isoformat()
        return TrainingCommitment(
            commitment_id=commitment_id,
            dataset_name=name,
            dataset_version=version,
            description=description,
            record_count=record_count,
            root=merkle_root,
            created_at=timestamp,
            licenses=list(self.licenses),
            tags=list(self.tags),
            extra=dict(self.extra),
        )
136
137
class CommitmentSigner:
    """Sign and verify TrainingCommitments with the identity's signing key.

    The signed message is the SHA3-256 digest of
    ``TrainingCommitment.canonical_bytes()``. That payload does not include
    ``signer_did``, ``algorithm``, ``public_key`` or ``signed_at``, so those
    fields are not protected by the signature. ``verify`` checks against the
    public key embedded in the commitment itself; pass
    ``expected_public_key`` to pin the key you actually trust.
    """

    def __init__(self, identity: AgentIdentity):
        self.identity = identity

    @staticmethod
    def _digest(commitment: TrainingCommitment) -> bytes:
        """Return the SHA3-256 digest of the commitment's canonical payload."""
        return hashlib.sha3_256(commitment.canonical_bytes()).digest()

    def sign(self, commitment: TrainingCommitment) -> TrainingCommitment:
        """Sign *commitment* in place and return the same object.

        Fills ``signer_did``, ``algorithm``, ``signature`` (hex),
        ``public_key`` (hex) and ``signed_at`` (UTC ISO-8601). ``signed_at``
        is stamped after signing and is not covered by the signature.
        """
        keypair = self.identity.signing_keypair
        # This calls the module-level quantumshield ``sign``, not this method:
        # method names are not in scope inside a method body.
        sig = sign(CommitmentSigner._digest(commitment), keypair)
        commitment.signer_did = self.identity.did
        commitment.algorithm = keypair.algorithm.value
        commitment.signature = sig.hex()
        commitment.public_key = keypair.public_key.hex()
        commitment.signed_at = datetime.now(timezone.utc).isoformat()
        return commitment

    @staticmethod
    def verify(
        commitment: TrainingCommitment,
        *,
        expected_public_key: bytes | str | None = None,
    ) -> bool:
        """Check the commitment's signature; fail closed on any problem.

        Args:
            commitment: A commitment previously filled in by ``sign``.
            expected_public_key: Optional trusted public key, as raw bytes or
                a hex string. When given, verification fails unless the
                commitment's embedded ``public_key`` equals it. Without it,
                a success only proves the commitment was signed by whoever
                holds the embedded key.

        Returns:
            False when signature/algorithm/public_key are missing or not
            valid hex, the algorithm is unknown, the key does not match the
            pin, or the signature backend raises; otherwise the backend's
            verification result.

        Raises:
            ValueError: if ``expected_public_key`` is a malformed hex string
                (a caller bug, so it is not silently treated as "invalid").
        """
        if not (commitment.signature and commitment.algorithm and commitment.public_key):
            return False
        if isinstance(expected_public_key, str):
            expected_public_key = bytes.fromhex(expected_public_key)
        try:
            signature = bytes.fromhex(commitment.signature)
            public_key = bytes.fromhex(commitment.public_key)
        except (ValueError, TypeError):
            return False
        if expected_public_key is not None and public_key != expected_public_key:
            return False
        try:
            algorithm = SignatureAlgorithm(commitment.algorithm)
        except ValueError:
            return False
        digest = CommitmentSigner._digest(commitment)
        try:
            return verify(digest, signature, public_key, algorithm)
        except Exception:
            # Deliberate fail-closed boundary: any error raised by the
            # signature backend is reported as "not verified".
            return False
173