src/pqc_federated_learning/aggregator.py
6.7 KB · 201 lines · python Raw
1 """FederatedAggregator - verify client updates and produce a signed aggregation proof."""
2
3 from __future__ import annotations
4
5 import hashlib
6 import json
7 from dataclasses import asdict, dataclass, field
8 from datetime import datetime, timezone
9 from typing import Any
10
11 from quantumshield.core.algorithms import SignatureAlgorithm
12 from quantumshield.core.signatures import sign, verify
13 from quantumshield.identity.agent import AgentIdentity
14
15 from pqc_federated_learning.aggregators.base import Aggregator
16 from pqc_federated_learning.errors import (
17 AggregationError,
18 InsufficientUpdatesError,
19 )
20 from pqc_federated_learning.signer import UpdateSigner
21 from pqc_federated_learning.update import ClientUpdate, GradientTensor
22
23
@dataclass
class AggregationProof:
    """Record of one aggregation: what went in, what was dropped, and the result hash.

    The last four fields form the signature envelope. They default to empty
    strings and are not part of the bytes returned by ``canonical_bytes``.
    """

    round_id: str
    model_id: str
    aggregator_name: str
    included_client_dids: list[str]
    included_update_hashes: list[str]  # content_hash of every aggregated update
    excluded_reasons: dict[str, str]  # client_did -> why its update was left out
    result_hash: str  # SHA3-256 hex digest of the canonical aggregated tensors
    num_tensors: int
    aggregated_at: str
    signer_did: str = ""
    algorithm: str = ""
    signature: str = ""
    public_key: str = ""

    def canonical_bytes(self) -> bytes:
        """Return the deterministic UTF-8 JSON encoding that gets signed.

        Only the nine content fields are serialized; ``signer_did``,
        ``algorithm``, ``signature`` and ``public_key`` are left out. The two
        ``included_*`` lists are emitted in sorted order (each sorted on its
        own), so the order in which updates were accepted does not affect the
        output. The stored lists themselves are not modified.
        """
        signed_fields = dict(
            round_id=self.round_id,
            model_id=self.model_id,
            aggregator_name=self.aggregator_name,
            included_client_dids=sorted(self.included_client_dids),
            included_update_hashes=sorted(self.included_update_hashes),
            excluded_reasons=self.excluded_reasons,
            result_hash=self.result_hash,
            num_tensors=self.num_tensors,
            aggregated_at=self.aggregated_at,
        )
        # Sorted keys + compact separators give one stable text form per proof.
        text = json.dumps(
            signed_fields,
            ensure_ascii=False,
            sort_keys=True,
            separators=(",", ":"),
        )
        return text.encode("utf-8")

    def to_dict(self) -> dict[str, Any]:
        """Return every field (signature envelope included) as a plain dict."""
        return asdict(self)

    def to_json(self) -> str:
        """Return ``to_dict()`` rendered as JSON with two-space indentation."""
        as_mapping = self.to_dict()
        return json.dumps(as_mapping, indent=2)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> AggregationProof:
        """Build a proof from a mapping whose keys are the dataclass field names."""
        return cls(**data)
67
68
@dataclass
class AggregationResult:
    """What one aggregation hands back: the combined tensors and their proof."""

    # Tensors returned by the aggregation strategy.
    aggregated: list[GradientTensor]
    # Proof describing which updates produced ``aggregated``.
    proof: AggregationProof
75
76
@dataclass
class AggregationRound:
    """Collects the client updates submitted for a single (round, model) pair."""

    round_id: str
    model_id: str
    updates: list[ClientUpdate] = field(default_factory=list)

    def add(self, update: ClientUpdate) -> None:
        """Append ``update`` after checking it targets this round and model.

        Raises:
            AggregationError: if the update's metadata names a different
                ``round_id`` or ``model_id`` than this round.
        """
        metadata = update.metadata

        # Guard clauses: the round check runs first, then the model check.
        if metadata.round_id != self.round_id:
            raise AggregationError(
                f"update round_id {metadata.round_id} != round {self.round_id}"
            )
        if metadata.model_id != self.model_id:
            raise AggregationError(
                f"update model_id {metadata.model_id} != round {self.model_id}"
            )

        self.updates.append(update)
95
96
class FederatedAggregator:
    """Verify signed client updates and produce a signed aggregation proof.

    Every update in a round is screened before it reaches the aggregation
    strategy: its signature must verify, its client must be on the allow-list
    (when one is configured), and its metadata must name the round and model
    being aggregated.

    Usage:
        identity = AgentIdentity.create("aggregator")
        aggregator = FederatedAggregator(
            identity=identity,
            strategy=FedAvgAggregator(),
            trusted_clients={"did:pqaid:..."},
        )
        result = aggregator.aggregate(round)
        # result.aggregated: list[GradientTensor]
        # result.proof: AggregationProof (signed with the identity's signing keypair)
    """

    def __init__(
        self,
        identity: AgentIdentity,
        strategy: Aggregator,
        trusted_clients: set[str] | None = None,
        min_updates: int = 1,
    ):
        """Store the aggregator configuration.

        Args:
            identity: Identity whose ``signing_keypair`` signs each proof and
                whose ``did`` is recorded as the proof's ``signer_did``.
            strategy: Object whose ``aggregate(updates)`` combines the accepted
                updates and whose ``name`` is recorded in the proof.
            trusted_clients: Allow-list of client DIDs. ``None`` disables the
                allow-list check entirely.
            min_updates: Minimum number of accepted updates required before
                the strategy is invoked.
        """
        self.identity = identity
        self.strategy = strategy
        self.trusted_clients = trusted_clients
        self.min_updates = min_updates

    def _rejection_reason(
        self, update: ClientUpdate, round_: AggregationRound
    ) -> str | None:
        """Return why ``update`` must be excluded, or ``None`` if it is acceptable."""
        verification = UpdateSigner.verify(update)
        if not verification.valid:
            return verification.error or "signature invalid"

        metadata = update.metadata
        if (
            self.trusted_clients is not None
            and metadata.client_did not in self.trusted_clients
        ):
            return "client not in trusted set"

        # AggregationRound.add() enforces this binding, but ``updates`` can also
        # be supplied through the dataclass constructor or mutated directly.
        # Re-checking here stops a validly signed update from another round or
        # model from being replayed into this aggregation.
        if metadata.round_id != round_.round_id:
            return f"update round_id {metadata.round_id} != round {round_.round_id}"
        if metadata.model_id != round_.model_id:
            return f"update model_id {metadata.model_id} != round {round_.model_id}"

        return None

    def aggregate(self, round_: AggregationRound) -> AggregationResult:
        """Screen the round's updates, aggregate the accepted ones, and sign a proof.

        Args:
            round_: Round whose ``updates`` are screened and aggregated.

        Returns:
            The strategy's output tensors together with a signed
            ``AggregationProof`` listing included and excluded clients.

        Raises:
            InsufficientUpdatesError: if fewer than ``min_updates`` updates
                pass screening.

        Exceptions raised by the strategy or by signing are not caught here.
        """
        accepted: list[ClientUpdate] = []
        excluded: dict[str, str] = {}

        for update in round_.updates:
            reason = self._rejection_reason(update, round_)
            if reason is None:
                accepted.append(update)
            else:
                # Keyed by DID: a later exclusion for the same client replaces
                # an earlier one.
                excluded[update.metadata.client_did] = reason

        if len(accepted) < self.min_updates:
            raise InsufficientUpdatesError(
                f"only {len(accepted)} valid updates, need {self.min_updates}"
            )

        aggregated = self.strategy.aggregate(accepted)

        # Result hash: SHA3-256 over a canonical JSON rendering of the tensors.
        payload = [t.to_dict() for t in aggregated]
        canonical = json.dumps(
            payload, sort_keys=True, separators=(",", ":"), ensure_ascii=False
        ).encode("utf-8")
        result_hash = hashlib.sha3_256(canonical).hexdigest()

        proof = AggregationProof(
            round_id=round_.round_id,
            model_id=round_.model_id,
            aggregator_name=self.strategy.name,
            included_client_dids=[u.metadata.client_did for u in accepted],
            included_update_hashes=[u.content_hash for u in accepted],
            excluded_reasons=excluded,
            result_hash=result_hash,
            num_tensors=len(aggregated),
            aggregated_at=datetime.now(timezone.utc).isoformat(),
        )

        # The signature covers the SHA3-256 digest of the proof's canonical
        # bytes; the envelope fields are filled in only after signing.
        keypair = self.identity.signing_keypair
        digest = hashlib.sha3_256(proof.canonical_bytes()).digest()
        sig = sign(digest, keypair)
        proof.signer_did = self.identity.did
        proof.algorithm = keypair.algorithm.value
        proof.signature = sig.hex()
        proof.public_key = keypair.public_key.hex()

        return AggregationResult(aggregated=aggregated, proof=proof)

    @staticmethod
    def verify_proof(proof: AggregationProof) -> bool:
        """Return True if the proof's signature verifies against its embedded public key.

        The check recomputes the SHA3-256 digest of ``proof.canonical_bytes()``
        and verifies ``proof.signature`` with ``proof.public_key`` under
        ``proof.algorithm``. ``proof.signer_did`` is never read, so a True
        result does not by itself show that the key belongs to that DID;
        callers must establish that binding separately.

        Returns False (never raises) for an unsigned proof, an unknown
        algorithm name, malformed hex, or any error during verification.
        """
        if not proof.signature:
            return False
        try:
            algorithm = SignatureAlgorithm(proof.algorithm)
        except ValueError:
            return False
        digest = hashlib.sha3_256(proof.canonical_bytes()).digest()
        try:
            return verify(
                digest,
                bytes.fromhex(proof.signature),
                bytes.fromhex(proof.public_key),
                algorithm,
            )
        except Exception:
            # Deliberately fail closed: any decoding or backend error means
            # the proof is treated as unverified.
            return False
201