src/pqc_federated_learning/aggregator.py
| 1 | """FederatedAggregator - verify client updates and produce a signed aggregation proof.""" |
| 2 | |
| 3 | from __future__ import annotations |
| 4 | |
| 5 | import hashlib |
| 6 | import json |
| 7 | from dataclasses import asdict, dataclass, field |
| 8 | from datetime import datetime, timezone |
| 9 | from typing import Any |
| 10 | |
| 11 | from quantumshield.core.algorithms import SignatureAlgorithm |
| 12 | from quantumshield.core.signatures import sign, verify |
| 13 | from quantumshield.identity.agent import AgentIdentity |
| 14 | |
| 15 | from pqc_federated_learning.aggregators.base import Aggregator |
| 16 | from pqc_federated_learning.errors import ( |
| 17 | AggregationError, |
| 18 | InsufficientUpdatesError, |
| 19 | ) |
| 20 | from pqc_federated_learning.signer import UpdateSigner |
| 21 | from pqc_federated_learning.update import ClientUpdate, GradientTensor |
| 22 | |
| 23 | |
@dataclass
class AggregationProof:
    """Signed provenance record: which updates went into one aggregation and its hash.

    The signable payload is produced by :meth:`canonical_bytes`; signature
    fields (``signer_did``/``algorithm``/``signature``/``public_key``) are
    filled in after signing and deliberately excluded from the payload.
    """

    round_id: str
    model_id: str
    aggregator_name: str
    included_client_dids: list[str]
    included_update_hashes: list[str]  # content_hash of each included update
    excluded_reasons: dict[str, str]  # {client_did: reason} for excluded updates
    result_hash: str  # SHA3-256 of canonical aggregated tensors
    num_tensors: int
    aggregated_at: str
    signer_did: str = ""
    algorithm: str = ""
    signature: str = ""
    public_key: str = ""

    def canonical_bytes(self) -> bytes:
        """Return a deterministic byte encoding of the signable fields.

        Both list fields are sorted so that input ordering does not affect
        the encoding; ``sort_keys`` fixes dict key order.
        """
        body = {
            "aggregated_at": self.aggregated_at,
            "aggregator_name": self.aggregator_name,
            "excluded_reasons": self.excluded_reasons,
            "included_client_dids": sorted(self.included_client_dids),
            "included_update_hashes": sorted(self.included_update_hashes),
            "model_id": self.model_id,
            "num_tensors": self.num_tensors,
            "result_hash": self.result_hash,
            "round_id": self.round_id,
        }
        encoded = json.dumps(
            body, sort_keys=True, separators=(",", ":"), ensure_ascii=False
        )
        return encoded.encode("utf-8")

    def to_dict(self) -> dict[str, Any]:
        """Return all fields (including signature fields) as a plain dict."""
        return asdict(self)

    def to_json(self) -> str:
        """Pretty-printed JSON form of :meth:`to_dict`."""
        return json.dumps(self.to_dict(), indent=2)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> AggregationProof:
        """Inverse of :meth:`to_dict`."""
        return cls(**data)
| 67 | |
| 68 | |
@dataclass
class AggregationResult:
    """Outcome of one aggregation: tensors + signed proof."""

    # Tensors returned by the configured aggregation strategy.
    aggregated: list[GradientTensor]
    # Signed provenance record (see AggregationProof) covering the inputs
    # and the SHA3-256 hash of the aggregated tensors.
    proof: AggregationProof
| 75 | |
| 76 | |
@dataclass
class AggregationRound:
    """One federated round: N client updates, configured aggregator.

    Collects client updates for a single (round_id, model_id) pair;
    :meth:`add` rejects updates addressed to a different round or model.
    """

    round_id: str
    model_id: str
    updates: list[ClientUpdate] = field(default_factory=list)

    def add(self, update: ClientUpdate) -> None:
        """Append *update* after checking it targets this round and model.

        Raises:
            AggregationError: on a round_id or model_id mismatch.
        """
        meta = update.metadata
        if meta.round_id != self.round_id:
            raise AggregationError(
                f"update round_id {meta.round_id} != round {self.round_id}"
            )
        if meta.model_id != self.model_id:
            raise AggregationError(
                f"update model_id {meta.model_id} != round {self.model_id}"
            )
        self.updates.append(update)
| 95 | |
| 96 | |
class FederatedAggregator:
    """Verify signed client updates and produce a signed aggregation proof.

    Usage:
        identity = AgentIdentity.create("aggregator")
        aggregator = FederatedAggregator(
            identity=identity,
            strategy=FedAvgAggregator(),
            trusted_clients={"did:pqaid:..."},
        )
        result = aggregator.aggregate(round)
        # result.aggregated: list[GradientTensor]
        # result.proof: AggregationProof (signed with ML-DSA)
    """

    def __init__(
        self,
        identity: AgentIdentity,
        strategy: Aggregator,
        trusted_clients: set[str] | None = None,
        min_updates: int = 1,
    ):
        # trusted_clients=None means "no allow-list": any client whose
        # signature verifies is accepted.
        self.identity = identity
        self.strategy = strategy
        self.trusted_clients = trusted_clients
        self.min_updates = min_updates

    def _screen(
        self, updates: list[ClientUpdate]
    ) -> tuple[list[ClientUpdate], dict[str, str]]:
        """Split updates into (accepted, {client_did: exclusion reason}).

        Signature verification runs first, then the optional allow-list
        check; order is preserved for accepted updates.
        """
        accepted: list[ClientUpdate] = []
        excluded: dict[str, str] = {}
        allow = self.trusted_clients
        for upd in updates:
            did = upd.metadata.client_did
            check = UpdateSigner.verify(upd)
            if not check.valid:
                excluded[did] = check.error or "signature invalid"
                continue
            if allow is not None and did not in allow:
                excluded[did] = "client not in trusted set"
                continue
            accepted.append(upd)
        return accepted, excluded

    @staticmethod
    def _hash_tensors(tensors: list[GradientTensor]) -> str:
        """SHA3-256 hex digest over the canonical JSON of *tensors*."""
        canonical = json.dumps(
            [t.to_dict() for t in tensors],
            sort_keys=True,
            separators=(",", ":"),
            ensure_ascii=False,
        ).encode("utf-8")
        return hashlib.sha3_256(canonical).hexdigest()

    def aggregate(self, round_: AggregationRound) -> AggregationResult:
        """Verify, filter, aggregate, and sign one round of client updates.

        Raises:
            InsufficientUpdatesError: if fewer than ``min_updates`` updates
                survive screening.
        """
        accepted, excluded = self._screen(round_.updates)

        if len(accepted) < self.min_updates:
            raise InsufficientUpdatesError(
                f"only {len(accepted)} valid updates, need {self.min_updates}"
            )

        aggregated = self.strategy.aggregate(accepted)

        proof = AggregationProof(
            round_id=round_.round_id,
            model_id=round_.model_id,
            aggregator_name=self.strategy.name,
            included_client_dids=[u.metadata.client_did for u in accepted],
            included_update_hashes=[u.content_hash for u in accepted],
            excluded_reasons=excluded,
            result_hash=self._hash_tensors(aggregated),
            num_tensors=len(aggregated),
            aggregated_at=datetime.now(timezone.utc).isoformat(),
        )

        # Sign a SHA3-256 digest of the canonical proof payload, then attach
        # the signature material to the proof itself.
        keypair = self.identity.signing_keypair
        digest = hashlib.sha3_256(proof.canonical_bytes()).digest()
        proof.signature = sign(digest, keypair).hex()
        proof.signer_did = self.identity.did
        proof.algorithm = keypair.algorithm.value
        proof.public_key = keypair.public_key.hex()

        return AggregationResult(aggregated=aggregated, proof=proof)

    @staticmethod
    def verify_proof(proof: AggregationProof) -> bool:
        """Return True iff *proof* carries a valid signature over its payload.

        Any malformed input (unknown algorithm, non-hex signature/key,
        verifier error) yields False rather than raising.
        """
        if not proof.signature:
            return False
        try:
            algorithm = SignatureAlgorithm(proof.algorithm)
        except ValueError:
            return False
        try:
            # bytes.fromhex raises only ValueError on bad hex.
            sig_bytes = bytes.fromhex(proof.signature)
            key_bytes = bytes.fromhex(proof.public_key)
        except ValueError:
            return False
        digest = hashlib.sha3_256(proof.canonical_bytes()).digest()
        try:
            return verify(digest, sig_bytes, key_bytes, algorithm)
        except Exception:
            return False
| 201 | |