src/pqc_training_data/commitment.py
| 1 | """TrainingCommitment - a signed Merkle-root commitment to a training dataset.""" |
| 2 | |
| 3 | from __future__ import annotations |
| 4 | |
| 5 | import hashlib |
| 6 | import json |
| 7 | import uuid |
| 8 | from dataclasses import asdict, dataclass, field |
| 9 | from datetime import datetime, timezone |
| 10 | from typing import Any |
| 11 | |
| 12 | from quantumshield.core.algorithms import SignatureAlgorithm |
| 13 | from quantumshield.core.signatures import sign, verify |
| 14 | from quantumshield.identity.agent import AgentIdentity |
| 15 | |
| 16 | from pqc_training_data.merkle import MerkleTree |
| 17 | from pqc_training_data.record import DataRecord, RecordHash |
| 18 | |
| 19 | |
@dataclass
class TrainingCommitment:
    """Merkle-root commitment to a training dataset, plus its signature envelope.

    Only the root hash and descriptive metadata are stored; the records
    themselves are never embedded, so they can stay private while proofs
    are issued selectively on demand.
    """

    # Descriptive fields: these are what canonical_bytes() serializes.
    commitment_id: str
    dataset_name: str
    dataset_version: str
    description: str
    record_count: int
    root: str  # Merkle root, hex-encoded
    created_at: str
    licenses: list[str] = field(default_factory=list)
    tags: list[str] = field(default_factory=list)
    extra: dict = field(default_factory=dict)

    # Signature envelope: left empty until CommitmentSigner fills it in.
    signer_did: str = ""
    algorithm: str = ""
    signature: str = ""  # hex
    public_key: str = ""  # hex
    signed_at: str = ""

    def canonical_bytes(self) -> bytes:
        """Return the deterministic UTF-8 JSON encoding of the descriptive fields.

        Keys are emitted in sorted order with compact separators, and
        ``licenses`` / ``tags`` are sorted so list order does not affect the
        output. The signature envelope fields (``signer_did``, ``algorithm``,
        ``signature``, ``public_key``, ``signed_at``) are not included.
        """
        passthrough = (
            "commitment_id",
            "dataset_name",
            "dataset_version",
            "description",
            "record_count",
            "root",
            "created_at",
        )
        body = {name: getattr(self, name) for name in passthrough}
        # sorted() returns new lists, so the instance's own lists are untouched.
        body["licenses"] = sorted(self.licenses)
        body["tags"] = sorted(self.tags)
        body["extra"] = self.extra
        text = json.dumps(
            body,
            sort_keys=True,
            separators=(",", ":"),
            ensure_ascii=False,
        )
        return text.encode("utf-8")

    def to_dict(self) -> dict[str, Any]:
        """Return every field (envelope included) as a plain dict."""
        return asdict(self)

    def to_json(self) -> str:
        """Return ``to_dict()`` rendered as JSON with two-space indentation."""
        as_mapping = self.to_dict()
        return json.dumps(as_mapping, indent=2)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "TrainingCommitment":
        """Build a commitment from a mapping such as the one ``to_dict()`` emits.

        ``commitment_id``, ``dataset_name``, ``dataset_version``,
        ``record_count`` and ``root`` are required (``KeyError`` if absent).
        ``record_count`` is passed through ``int()``. Every other field is
        optional and falls back to an empty string, list or dict. List and
        dict values are shallow-copied so the new instance does not share
        them with ``data``.
        """

        def _or_blank(key: str) -> Any:
            return data.get(key, "")

        # Filled in the same order the fields are declared, so the first
        # missing or invalid key is the one reported.
        values: dict[str, Any] = {}
        for key in ("commitment_id", "dataset_name", "dataset_version"):
            values[key] = data[key]
        values["description"] = _or_blank("description")
        values["record_count"] = int(data["record_count"])
        values["root"] = data["root"]
        values["created_at"] = _or_blank("created_at")
        values["licenses"] = list(data.get("licenses", []))
        values["tags"] = list(data.get("tags", []))
        values["extra"] = dict(data.get("extra", {}))
        for key in ("signer_did", "algorithm", "signature", "public_key", "signed_at"):
            values[key] = _or_blank(key)
        return cls(**values)

    @classmethod
    def from_json(cls, blob: str) -> "TrainingCommitment":
        """Parse a JSON document and hand the result to ``from_dict``."""
        parsed = json.loads(blob)
        return cls.from_dict(parsed)
| 92 | |
| 93 | |
class CommitmentBuilder:
    """Accumulate record hashes in a Merkle tree and emit an unsigned commitment.

    Example:
        builder = CommitmentBuilder("company-train-v1", "1.0.0")
        for doc in corpus:
            builder.add_record(DataRecord(content=doc.bytes, metadata={...}))
        commitment = builder.build(description="Company training corpus")
    """

    def __init__(self, dataset_name: str, dataset_version: str):
        """Start with an empty tree and empty licenses / tags / extra metadata."""
        self.dataset_name = dataset_name
        self.dataset_version = dataset_version
        self.tree = MerkleTree()
        # build() copies whatever these three hold at the time it is called.
        self.licenses: list[str] = []
        self.tags: list[str] = []
        self.extra: dict = {}

    def add_record(self, record: DataRecord) -> None:
        """Append the record's leaf hash to the tree."""
        leaf = record.leaf_hash()
        self.tree.add(leaf)

    def add_records(self, records: list[DataRecord]) -> None:
        """Append each record in order, one ``add_record`` call per item."""
        for record in records:
            self.add_record(record)

    def add_leaf_hash_hex(self, hex_hash: str) -> None:
        """Append a precomputed leaf hash given as a hex string."""
        leaf = RecordHash(hex=hex_hash)
        self.tree.add(leaf)

    def build(self, description: str = "") -> TrainingCommitment:
        """Snapshot the tree into a new, unsigned TrainingCommitment.

        Every call mints a fresh ``urn:pqc-td:<uuid4 hex>`` id and a UTC
        ISO-8601 ``created_at``. ``licenses``, ``tags`` and ``extra`` are
        shallow-copied, so later changes on the builder do not reach the
        returned commitment.
        """
        fresh_id = f"urn:pqc-td:{uuid.uuid4().hex}"
        leaf_count = self.tree.size
        merkle_root = self.tree.root()
        now_utc = datetime.now(timezone.utc).isoformat()
        return TrainingCommitment(
            commitment_id=fresh_id,
            dataset_name=self.dataset_name,
            dataset_version=self.dataset_version,
            description=description,
            record_count=leaf_count,
            root=merkle_root,
            created_at=now_utc,
            licenses=list(self.licenses),
            tags=list(self.tags),
            extra=dict(self.extra),
        )
| 136 | |
| 137 | |
class CommitmentSigner:
    """Sign TrainingCommitments with an AgentIdentity and verify signed ones.

    The signed message is the SHA3-256 digest of
    ``commitment.canonical_bytes()``. The signature algorithm is whatever the
    identity's signing keypair reports (ML-DSA, per this module's original
    description).

    NOTE(review): ``verify`` checks the signature against the ``public_key``
    embedded in the commitment itself. It does not check that this key
    belongs to ``signer_did``, so callers that care who signed must establish
    that binding separately.
    """

    def __init__(self, identity: AgentIdentity):
        """Remember the identity whose signing keypair ``sign`` will use."""
        self.identity = identity

    def sign(self, commitment: TrainingCommitment) -> TrainingCommitment:
        """Sign ``commitment`` in place and return the same object.

        Fills ``signer_did``, ``algorithm``, ``signature`` (hex),
        ``public_key`` (hex) and ``signed_at`` (UTC ISO-8601). Any values
        already present in those fields are overwritten.
        """
        # Read the keypair once so the signature, the recorded algorithm and
        # the recorded public key all come from the same object.
        keypair = self.identity.signing_keypair
        digest = hashlib.sha3_256(commitment.canonical_bytes()).digest()
        # Bare `sign` is the module-level function imported from quantumshield:
        # names defined in a class body are not visible inside its methods.
        raw_signature = sign(digest, keypair)
        commitment.signer_did = self.identity.did
        commitment.algorithm = keypair.algorithm.value
        commitment.signature = raw_signature.hex()
        commitment.public_key = keypair.public_key.hex()
        commitment.signed_at = datetime.now(timezone.utc).isoformat()
        return commitment

    @staticmethod
    def verify(commitment: TrainingCommitment) -> bool:
        """Return True only if the embedded signature checks out.

        Fails closed: a missing signature, algorithm or public key, an
        unknown algorithm name, malformed hex, a commitment whose canonical
        form cannot be computed, or any error from the underlying verifier
        all produce ``False`` rather than an exception.
        """
        if not (commitment.signature and commitment.algorithm and commitment.public_key):
            return False
        try:
            algorithm = SignatureAlgorithm(commitment.algorithm)
        except ValueError:
            return False
        try:
            # canonical_bytes() sits inside the try because a commitment
            # parsed from untrusted JSON can hold values that sorted() or
            # json.dumps() reject; that must mean "not verified", not a crash.
            digest = hashlib.sha3_256(commitment.canonical_bytes()).digest()
            # Bare `verify` is the imported module-level function, not this
            # static method (class-scope names are skipped in method bodies).
            return verify(
                digest,
                bytes.fromhex(commitment.signature),
                bytes.fromhex(commitment.public_key),
                algorithm,
            )
        except Exception:
            # Deliberately broad: any failure while checking is a failed check.
            return False
| 173 | |