# PQC Federated Learning

![PQC Native](https://img.shields.io/badge/PQC-Native-blue)
![ML-DSA-65](https://img.shields.io/badge/ML--DSA--65-FIPS%20204-green)
![License](https://img.shields.io/badge/License-Apache%202.0-orange)
![Version](https://img.shields.io/badge/version-0.1.0-lightgrey)

**Post-quantum secure federated learning.** Every client signs its gradient update with **ML-DSA** (FIPS 204). The aggregator verifies every signature, rejects anything it cannot verify, then emits a **signed aggregation proof** binding the set of included clients to a hash of the aggregated tensors. Pluggable strategies: FedAvg, FedSum, FedMedian, FedTrimmedMean.

## The Problem

Federated learning is sold as a privacy story: training data never leaves the client. But the gradient updates that *do* leave are rarely authenticated end-to-end with quantum-safe crypto. A compromised TLS connection or a malicious coordinator can substitute, forge, or selectively drop updates to bias the global model. An adversary who archives signed traffic today can also forge classical RSA/ECDSA signatures once a quantum computer breaks them. In regulated domains (medical imaging, fraud detection, loan underwriting), this is a 10+ year liability: model provenance must stay auditable long after today's classical crypto has fallen.

## The Solution

- **Per-update ML-DSA signature.** Each `ClientUpdate` carries a SHA3-256 content hash and a post-quantum signature over that hash, bound to the client's DID.
- **Signature-gated aggregation.** The `FederatedAggregator` refuses to include any update whose signature does not verify.
- **Optional allow-list.** Provide a `trusted_clients` set to hard-enforce which DIDs may contribute to a round.
- **Signed aggregation proof.** The aggregator emits an `AggregationProof` listing the included client DIDs, update hashes, exclusion reasons, and a hash of the final aggregated tensors. The aggregator signs this proof with ML-DSA under its own identity, so auditors can re-verify the entire round years later.
- **Robust aggregators.** FedMedian and FedTrimmedMean tolerate a bounded number of Byzantine clients (cryptographically valid but value-malicious). The median holds as long as attackers are a minority. The trimmed mean holds as long as attackers fit inside the `trim_ratio` fraction cut from each end.

## Installation

```bash
pip install pqc-federated-learning
```

Development:

```bash
pip install -e ".[dev]"
pytest
```

## Quick Start

```python
from quantumshield.identity.agent import AgentIdentity
from pqc_federated_learning import (
    AggregationRound,
    ClientUpdate,
    ClientUpdateMetadata,
    FedAvgAggregator,
    FederatedAggregator,
    GradientTensor,
    UpdateSigner,
)

# 1. Each client has a PQ-AID identity and signs its own update.
client = AgentIdentity.create("hospital-a")

update = ClientUpdate.create(
    metadata=ClientUpdateMetadata(
        client_did=client.did,
        round_id="round-1",
        model_id="pneumonia-detector-v2",
        num_samples=1024,
        epochs=3,
        local_loss=0.18,
    ),
    tensors=[
        GradientTensor(name="conv1.weights", shape=(2, 2), values=(0.1, 0.2, 0.3, 0.4)),
        GradientTensor(name="conv1.bias", shape=(2,), values=(0.01, 0.02)),
    ],
)
signed = UpdateSigner(client).sign(update)

# 2. Coordinator collects N client updates into a round.
round_ = AggregationRound(round_id="round-1", model_id="pneumonia-detector-v2")
round_.add(signed)
# ... add updates from other clients ...

# 3. Aggregator verifies every signature, then aggregates, then signs the result.
aggregator_id = AgentIdentity.create("central-aggregator")
aggregator = FederatedAggregator(
    identity=aggregator_id,
    strategy=FedAvgAggregator(),
    trusted_clients={client.did},  # optional allow-list
    min_updates=1,
)
result = aggregator.aggregate(round_)

# result.aggregated -> list[GradientTensor] ready to apply to the global model
# result.proof      -> signed AggregationProof with included/excluded DIDs + result hash

assert FederatedAggregator.verify_proof(result.proof)
```

## Architecture

```
   Client A              Client B              Client C
   --------              --------              --------
      |                     |                     |
      | local training      | local training      | local training
      |                     |                     |
      | compute gradient    | compute gradient    | compute gradient
      |                     |                     |
      | sign(update)        | sign(update)        | sign(update)
      | ML-DSA / SHA3-256   |                     |
      |                     |                     |
      +---------------------+---------------------+
                            |
                            v
           +----------------------------------+
           |       FederatedAggregator        |
           |                                  |
           |  1. UpdateSigner.verify() per    |
           |     update (ML-DSA signature,    |
           |     content hash, allow-list)    |
           |                                  |
           |  2. Strategy.aggregate()         |
           |     FedAvg | FedSum |            |
           |     FedMedian | FedTrimmedMean   |
           |                                  |
           |  3. SHA3-256(aggregated)         |
           |                                  |
           |  4. Sign AggregationProof        |
           |     with aggregator's ML-DSA     |
           |     identity                     |
           +----------------+-----------------+
                            |
                            v
           +----------------------------------+
           | AggregationResult                |
           |  - aggregated: [GradientTensor]  |
           |  - proof: AggregationProof       |
           |      included_client_dids        |
           |      included_update_hashes      |
           |      excluded_reasons            |
           |      result_hash                 |
           |      ML-DSA signature            |
           +----------------------------------+
```

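The four steps in the diagram can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the library's implementation. `Update`, `Result`, and `run_round` are hypothetical names, and each update is reduced to a single flat tensor. The `verify` and `sign` callables stand in for ML-DSA verification and the aggregator's ML-DSA signing key. The JSON encoding used for the result hash and the proof body is also an assumption.

```python
from __future__ import annotations

import hashlib
import json
from dataclasses import dataclass
from typing import Callable, Sequence


@dataclass
class Update:  # hypothetical: one client's signed update, reduced to a flat tensor
    client_did: str
    num_samples: int
    values: tuple[float, ...]


@dataclass
class Result:  # hypothetical: aggregated values plus the signed proof fields
    aggregated: tuple[float, ...]
    included: list[str]
    excluded_reasons: dict[str, str]
    result_hash: str
    proof_signature: bytes


def run_round(
    updates: Sequence[Update],
    verify: Callable[[Update], bool],  # stand-in for ML-DSA signature + hash checks
    aggregate: Callable[[list[Update]], tuple[float, ...]],  # pluggable strategy
    sign: Callable[[bytes], bytes],  # stand-in for the aggregator's ML-DSA key
    trusted: set[str] | None = None,
    min_updates: int = 1,
) -> Result:
    included: list[Update] = []
    excluded: dict[str, str] = {}
    # Step 1: gate each update on the allow-list and on signature verification.
    for u in updates:
        if trusted is not None and u.client_did not in trusted:
            excluded[u.client_did] = "untrusted client"
        elif not verify(u):
            excluded[u.client_did] = "signature did not verify"
        else:
            included.append(u)
    if len(included) < min_updates:
        raise ValueError(f"{len(included)} valid updates, need {min_updates}")
    # Step 2: the strategy only ever sees verified updates.
    aggregated = aggregate(included)
    # Step 3: hash the aggregated values.
    result_hash = hashlib.sha3_256(json.dumps(list(aggregated)).encode()).hexdigest()
    # Step 4: sign a canonical proof body binding DIDs, exclusions and the result hash.
    dids = [u.client_did for u in included]
    body = json.dumps(
        {"included": dids, "excluded": excluded, "result_hash": result_hash},
        sort_keys=True,
        separators=(",", ":"),
    ).encode()
    return Result(aggregated, dids, excluded, result_hash, sign(body))
```

The ordering is the important design point. Filtering happens before the strategy runs, so a forged update never influences the result. The proof also records why each DID was dropped, which lets an auditor review exclusions as well as inclusions.
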
## Threat Model

| Threat | Mitigation |
|---|---|
| **Update forgery** (attacker fabricates an update and claims it came from client A) | Only client A's private key can produce a valid ML-DSA signature over A's DID-bound content hash. |
| **Update tampering in transit** (flip a gradient value) | The recomputed SHA3-256 content hash no longer matches the signed hash, so the update is excluded. |
| **Malicious coordinator** (silently drops honest updates, keeps poisoned ones) | The `AggregationProof` is signed by the aggregator and lists every included DID and content hash, so auditors can detect missing clients. |
| **Untrusted client joins** (rogue node submits signed updates) | The `trusted_clients` allow-list rejects any DID not on the roster. |
| **Byzantine value attack** (valid signature, extreme values) | Use `FedMedianAggregator` or `FedTrimmedMeanAggregator`; both are robust to a bounded fraction of bad clients. |
| **Replay of an old round** | Each update is bound to a `round_id` and `model_id`; `AggregationRound.add()` refuses mismatches. |
| **Harvest now, forge later** (adversary archives signed rounds today, then forges RSA/ECDSA signatures once a quantum computer breaks them) | ML-DSA is the FIPS 204 post-quantum signature scheme; no known quantum attack forges its signatures. |
| **Proof tampering** (auditor is handed a modified proof) | `FederatedAggregator.verify_proof()` recomputes the canonical bytes and checks the aggregator's ML-DSA signature. |

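The malicious-coordinator row depends on an out-of-band check. Each client keeps the content hash of the update it submitted, and an auditor compares those receipts against a proof that has already passed `verify_proof()`. The sketch below is hypothetical (`audit_round` is not part of the library). It also assumes that `excluded_reasons` maps a DID to a reason string, which the README does not specify.

```python
from __future__ import annotations


def audit_round(
    submitted: dict[str, str],          # DID -> content hash each client says it sent
    included_client_dids: list[str],    # from a proof that passed verify_proof()
    included_update_hashes: list[str],
    excluded_reasons: dict[str, str],   # assumed shape: DID -> reason
) -> dict[str, str]:
    """Return one finding per client whose update is missing or was swapped."""
    included = dict(zip(included_client_dids, included_update_hashes))
    findings: dict[str, str] = {}
    for did, expected_hash in submitted.items():
        if did in included:
            if included[did] != expected_hash:
                findings[did] = "included hash differs from the submitted update"
        elif did not in excluded_reasons:
            # Neither included nor excluded with a stated reason: silently dropped.
            findings[did] = "silently dropped"
    return findings
```

Declared exclusions are not flagged automatically. A reviewer can still decide whether each stated reason is plausible.
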
## Why PQC for Federated Learning?

Federated models trained on medical images, financial transactions, or legal corpora have a shelf life measured in **decades**. A forged gradient injected in 2026 still corrupts the downstream model in 2040. Once a classical signature scheme falls to a cryptographically relevant quantum computer, every federated training round ever conducted over that scheme becomes retroactively unverifiable. PQC signatures are the only way to make an FL audit trail that still means something after Q-day.

## API Reference

### `GradientTensor`

Frozen dataclass. A single named tensor.

| Field | Description |
|---|---|
| `name` | Layer name, e.g. `"dense_1.weights"` |
| `shape` | Tuple of ints |
| `values` | Flat tuple of floats (row-major) |

| Method | Description |
|---|---|
| `to_dict()` / `from_dict()` | JSON-safe round-trip |

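Because `values` is a flat, row-major buffer, its length must equal the product of `shape`. Element `(i, j)` of a 2-D tensor sits at index `i * ncols + j`. The hypothetical `Tensor` class below mirrors `GradientTensor` to illustrate that invariant and a JSON-safe round-trip. The dictionary keys it emits are assumptions, not the library's wire format.

```python
from __future__ import annotations

import math
from dataclasses import dataclass


@dataclass(frozen=True)
class Tensor:  # hypothetical mirror of GradientTensor, for illustration only
    name: str
    shape: tuple[int, ...]
    values: tuple[float, ...]

    def __post_init__(self) -> None:
        # A flat row-major buffer must hold exactly prod(shape) elements.
        if len(self.values) != math.prod(self.shape):
            raise ValueError(f"{self.name}: {len(self.values)} values for shape {self.shape}")

    @classmethod
    def from_rows(cls, name: str, rows: list[list[float]]) -> Tensor:
        # Row-major: concatenate rows in order, so element (i, j) lands at i * ncols + j.
        shape = (len(rows), len(rows[0]))
        return cls(name, shape, tuple(float(v) for row in rows for v in row))

    def to_dict(self) -> dict:
        # JSON has no tuples, so emit lists (key names are an assumption).
        return {"name": self.name, "shape": list(self.shape), "values": list(self.values)}

    @classmethod
    def from_dict(cls, d: dict) -> Tensor:
        return cls(d["name"], tuple(d["shape"]), tuple(float(v) for v in d["values"]))
```
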
### `ClientUpdateMetadata`

Frozen dataclass with `client_did`, `round_id`, `model_id`, `num_samples`, `epochs`, and `local_loss`.

### `ClientUpdate`

| Field | Description |
|---|---|
| `metadata` | `ClientUpdateMetadata` |
| `tensors` | `list[GradientTensor]` |
| `created_at` | ISO-8601 timestamp |
| `content_hash` | SHA3-256 over canonical `(metadata, tensors, created_at)` |
| `signer_did`, `public_key`, `algorithm`, `signature`, `signed_at` | Signature envelope |

| Method | Description |
|---|---|
| `ClientUpdate.create(metadata, tensors)` | Build unsigned update with `content_hash` populated |
| `compute_content_hash(metadata, tensors, created_at)` | Static canonical hash |
| `to_dict()` / `from_dict()` | JSON-safe round-trip |

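A content hash is only useful if every party serializes the same update to the same bytes. One common way to get a canonical form is sorted-key, whitespace-free JSON, and the sketch below assumes that encoding. `content_hash` is a hypothetical helper, so its digests will not necessarily match the library's `compute_content_hash` byte-for-byte.

```python
from __future__ import annotations

import hashlib
import json


def content_hash(metadata: dict, tensors: list[dict], created_at: str) -> str:
    """Hypothetical canonical SHA3-256 over (metadata, tensors, created_at)."""
    payload = {"metadata": metadata, "tensors": tensors, "created_at": created_at}
    canonical = json.dumps(
        payload,
        sort_keys=True,          # key order must not change the digest
        separators=(",", ":"),   # no incidental whitespace
        ensure_ascii=False,
        allow_nan=False,         # NaN/inf have no standard JSON form; reject them
    )
    return hashlib.sha3_256(canonical.encode("utf-8")).hexdigest()
```

Flipping a single gradient value changes the digest. This is how the aggregator detects in-transit tampering: it recomputes the hash and compares it with the one the client signed.
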
### `UpdateSigner`

| Method | Description |
|---|---|
| `UpdateSigner(identity).sign(update)` | Populate signature envelope (mutates + returns) |
| `UpdateSigner.verify(update)` | Static; returns `UpdateVerificationResult` |

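`UpdateSigner.verify()` returns an `UpdateVerificationResult`, but the README does not list its checks or their order. What follows is a hedged sketch of what a verifier of this kind needs to check, written against plain dicts. `Verdict`, `verify_update`, and the three callables are hypothetical stand-ins: `key_owner` resolves a public key to its DID, and `ml_dsa_verify` wraps a real ML-DSA implementation. The key-to-DID check matters because, without it, an attacker could attach their own key to someone else's DID and the signature would still verify.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Callable


@dataclass
class Verdict:  # hypothetical stand-in for UpdateVerificationResult
    valid: bool
    reason: str = ""


def verify_update(
    update: dict,
    recompute_hash: Callable[[dict], str],             # canonical hash of the payload
    key_owner: Callable[[str], str],                   # public key -> DID it belongs to
    ml_dsa_verify: Callable[[str, bytes, str], bool],  # (public_key, message, signature)
) -> Verdict:
    if not update.get("signature"):
        return Verdict(False, "unsigned update")
    # The signer must be the client named in the metadata ...
    if update["signer_did"] != update["metadata"]["client_did"]:
        return Verdict(False, "signer DID does not match client DID")
    # ... and the attached public key must actually belong to that DID.
    if key_owner(update["public_key"]) != update["signer_did"]:
        return Verdict(False, "public key not bound to signer DID")
    # Never trust the stored hash: recompute it from the payload.
    if recompute_hash(update) != update["content_hash"]:
        return Verdict(False, "content hash mismatch")
    # Finally, the signature must cover exactly that hash.
    if not ml_dsa_verify(update["public_key"], update["content_hash"].encode(), update["signature"]):
        return Verdict(False, "signature did not verify")
    return Verdict(True)
```
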
### `AggregationRound`

| Method | Description |
|---|---|
| `AggregationRound(round_id, model_id)` | New empty round |
| `add(update)` | Append update; raises `AggregationError` on round/model mismatch |

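The replay protection in `add()` amounts to an equality check on `(round_id, model_id)`. The hypothetical `Round` class below mirrors that behavior. It also rejects a second update from the same DID within one round. The README does not say whether `AggregationRound` does this, so treat that check as an illustrative extra.

```python
from __future__ import annotations


class RoundMismatchError(Exception):
    """Hypothetical stand-in for AggregationError."""


class Round:  # hypothetical mirror of AggregationRound
    def __init__(self, round_id: str, model_id: str) -> None:
        self.round_id = round_id
        self.model_id = model_id
        self.updates: list[dict] = []

    def add(self, metadata: dict) -> None:
        got = (metadata["round_id"], metadata["model_id"])
        # An update signed for another round or model cannot be replayed here.
        if got != (self.round_id, self.model_id):
            raise RoundMismatchError(f"expected {(self.round_id, self.model_id)}, got {got}")
        # Illustrative extra: one update per client per round.
        if any(u["client_did"] == metadata["client_did"] for u in self.updates):
            raise RoundMismatchError(f"duplicate update from {metadata['client_did']}")
        self.updates.append(metadata)
```

`round_id` and `model_id` sit inside the signed metadata. An attacker therefore cannot relabel an old update for the current round without invalidating its signature.
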
### `FederatedAggregator`

| Method | Description |
|---|---|
| `FederatedAggregator(identity, strategy, trusted_clients=None, min_updates=1)` | Construct |
| `aggregate(round_)` | Returns `AggregationResult(aggregated, proof)` |
| `FederatedAggregator.verify_proof(proof)` | Static; verifies the aggregator's ML-DSA signature |

### `AggregationProof`

Fields: `round_id`, `model_id`, `aggregator_name`, `included_client_dids`, `included_update_hashes`, `excluded_reasons`, `result_hash`, `num_tensors`, `aggregated_at`, `signer_did`, `algorithm`, `signature`, `public_key`.

Methods: `canonical_bytes()`, `to_dict()`, `to_json()`, `from_dict()`.

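`canonical_bytes()` produces the bytes the aggregator signs, and `verify_proof()` recomputes them. This is a minimal sketch under three assumptions. First, the signed body is every proof field except `signature`. Second, that body is serialized as sorted-key compact JSON. Third, binary fields such as `public_key` are already text-encoded, as `to_json()` presumably requires. `proof_canonical_bytes` and `check_proof` are hypothetical names, and `ml_dsa_verify` stands in for a real ML-DSA verifier.

```python
from __future__ import annotations

import json
from typing import Callable


def proof_canonical_bytes(proof: dict) -> bytes:
    # Sign everything except the signature itself, so changing any DID, hash,
    # exclusion reason, timestamp or the claimed signer invalidates the proof.
    body = {k: v for k, v in proof.items() if k != "signature"}
    return json.dumps(body, sort_keys=True, separators=(",", ":")).encode("utf-8")


def check_proof(proof: dict, ml_dsa_verify: Callable[[str, bytes, str], bool]) -> bool:
    # Recompute the bytes from the proof as received; never trust a cached copy.
    return ml_dsa_verify(proof["public_key"], proof_canonical_bytes(proof), proof["signature"])
```

An auditor should also confirm that `public_key` belongs to the aggregator DID it expects. A valid signature under an unknown key only proves that *someone* signed the proof.
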
### Aggregator strategies

| Strategy | Behavior |
|---|---|
| `FedAvgAggregator()` | Weighted mean by `num_samples`. Default choice. |
| `FedSumAggregator()` | Unweighted element-wise sum. Building block for secure aggregation. |
| `FedMedianAggregator()` | Element-wise median. Robust to a minority of Byzantine clients. |
| `FedTrimmedMeanAggregator(trim_ratio=0.1)` | Drops top/bottom `trim_ratio` fraction before averaging. |

### Exceptions

| Exception | When |
|---|---|
| `FLError` | Base class |
| `InvalidUpdateError` | Structural problems with an update |
| `SignatureVerificationError` | Signature failed to verify |
| `AggregationError` | Round-level error (round/model mismatch) |
| `UntrustedClientError` | DID not in allow-list |
| `ShapeMismatchError` | Tensor names or shapes disagree across updates |
| `InsufficientUpdatesError` | Fewer valid updates than `min_updates` |

## Examples

See the `examples/` directory:

- **`simple_fedavg.py`** - 3 clients, FedAvg, signed aggregation proof.
- **`byzantine_client_rejected.py`** - attacker with a forged signature is excluded.
- **`robust_median.py`** - FedMedian absorbs one malicious client sending extreme values.

Run them:

```bash
python examples/simple_fedavg.py
python examples/byzantine_client_rejected.py
python examples/robust_median.py
```

## Development

```bash
pip install -e ".[dev]"
pytest
ruff check src/ tests/ examples/
```

## Related

Part of the [QuantaMrkt](https://quantamrkt.com) post-quantum tooling registry. See also:

- **QuantumShield** - the PQC toolkit (`AgentIdentity`, `SignatureAlgorithm`, `sign/verify`).
- **PQC RAG Signing** - sign retrieval chunks with ML-DSA.
- **PQC Training Data Transparency** - sign training datasets and commitments.
- **PQC Content Provenance** - sign manifests for generated content.

## License

Apache License 2.0. See [LICENSE](LICENSE).