src/pqc_federated_learning/signer.py
| 1 | """Client-side signing and server-side verification of ClientUpdates.""" |
| 2 | |
| 3 | from __future__ import annotations |
| 4 | |
| 5 | import hashlib |
| 6 | from dataclasses import dataclass |
| 7 | from datetime import datetime, timezone |
| 8 | |
| 9 | from quantumshield.core.algorithms import SignatureAlgorithm |
| 10 | from quantumshield.core.signatures import sign, verify |
| 11 | from quantumshield.identity.agent import AgentIdentity |
| 12 | |
| 13 | from pqc_federated_learning.update import ClientUpdate |
| 14 | |
| 15 | |
@dataclass(frozen=True)
class UpdateVerificationResult:
    """Immutable outcome of verifying a single ClientUpdate.

    Instances are frozen, so a result cannot be altered after it has been
    produced. Field order is part of the constructor signature.
    """

    # Overall verdict reported by the verifier.
    valid: bool
    # DID the update claims as its signer, or None when none is available.
    client_did: str | None
    # Round identifier taken from the update's metadata, if any.
    round_id: str | None
    # Whether the recomputed content hash matched the one on the update.
    content_hash_ok: bool
    # Whether the signature check itself succeeded.
    signature_ok: bool
    # Human-readable failure reason; stays None when nothing went wrong.
    error: str | None = None
| 24 | |
| 25 | |
class UpdateSigner:
    """Signs ClientUpdates with an AgentIdentity (the client's identity).

    ``sign`` is the client-side half: it stamps signature metadata onto an
    update. ``verify`` is the server-side half: it needs no identity and is
    therefore a static method.

    NOTE(review): ``verify`` only proves that the update was signed by the
    holder of ``update.public_key``. Nothing in this class checks that this
    key actually belongs to ``update.signer_did`` -- presumably callers must
    perform that binding check themselves; confirm against the call sites.
    """

    def __init__(self, identity: AgentIdentity):
        """Keep the identity whose signing keypair and DID are put on updates."""
        self.identity = identity

    def sign(self, update: ClientUpdate) -> ClientUpdate:
        """Sign ``update`` in place and return the same object.

        The signature is computed over the SHA3-256 digest of
        ``update.canonical_bytes()``. The signer DID, algorithm name,
        hex-encoded signature, hex-encoded public key and a UTC ISO-8601
        timestamp are then written onto the update.

        Args:
            update: The update to sign; it is mutated.

        Returns:
            The same ``update`` instance, now carrying the signature fields.
        """
        # Digest is taken before any signature field is written to the update.
        digest = hashlib.sha3_256(update.canonical_bytes()).digest()
        keypair = self.identity.signing_keypair
        # Inside a method body the bare name `sign` resolves to the
        # module-level import, not to this method.
        sig = sign(digest, keypair)

        update.signer_did = self.identity.did
        update.algorithm = keypair.algorithm.value
        update.signature = sig.hex()
        update.public_key = keypair.public_key.hex()
        update.signed_at = datetime.now(timezone.utc).isoformat()
        return update

    @staticmethod
    def _result(
        update: ClientUpdate,
        *,
        content_hash_ok: bool,
        signature_ok: bool,
        error: str | None,
    ) -> UpdateVerificationResult:
        """Build a verification result for ``update`` with uniform field handling."""
        return UpdateVerificationResult(
            valid=signature_ok and content_hash_ok,
            # An empty signer DID is reported as None on every path, so callers
            # never have to distinguish "" from None.
            client_did=update.signer_did or None,
            round_id=update.metadata.round_id,
            content_hash_ok=content_hash_ok,
            signature_ok=signature_ok,
            error=error,
        )

    @staticmethod
    def verify(update: ClientUpdate) -> UpdateVerificationResult:
        """Check the content hash and the signature of ``update``.

        Failures of the signature machinery (missing fields, unknown
        algorithm, malformed hex, errors raised by the verifier) are reported
        through the returned result rather than raised.

        Args:
            update: The update to verify; it is not modified.

        Returns:
            A result whose ``valid`` flag is true only when both the content
            hash matches and the signature verifies. ``error`` names the
            first failing check, with signature problems taking precedence
            over a content hash mismatch.
        """
        # Recompute the content hash from the payload the update carries.
        expected_hash = ClientUpdate.compute_content_hash(
            update.metadata, update.tensors, update.created_at
        )
        content_hash_ok = expected_hash == update.content_hash

        if not (update.signature and update.algorithm and update.public_key):
            return UpdateSigner._result(
                update,
                content_hash_ok=content_hash_ok,
                signature_ok=False,
                error="missing signature fields",
            )

        try:
            algorithm = SignatureAlgorithm(update.algorithm)
        except ValueError:
            return UpdateSigner._result(
                update,
                content_hash_ok=content_hash_ok,
                signature_ok=False,
                error=f"unknown algorithm {update.algorithm}",
            )

        digest = hashlib.sha3_256(update.canonical_bytes()).digest()
        try:
            # The bare name `verify` is the module-level import here. Hex
            # decoding sits inside the try so malformed hex fails closed.
            sig_ok = bool(
                verify(
                    digest,
                    bytes.fromhex(update.signature),
                    bytes.fromhex(update.public_key),
                    algorithm,
                )
            )
        except Exception as exc:
            # Deliberately broad: any verifier failure means "not verified".
            return UpdateSigner._result(
                update,
                content_hash_ok=content_hash_ok,
                signature_ok=False,
                error=f"verify failed: {exc}",
            )

        if not sig_ok:
            error = "invalid ML-DSA signature"
        elif not content_hash_ok:
            error = "content hash mismatch"
        else:
            error = None

        return UpdateSigner._result(
            update,
            content_hash_ok=content_hash_ok,
            signature_ok=sig_ok,
            error=error,
        )
| 106 | |