src/pqc_federated_learning/update.py
4.6 KB · 149 lines · python Raw
1 """Gradient update data structures."""
2
3 from __future__ import annotations
4
5 import hashlib
6 import json
7 from dataclasses import asdict, dataclass
8 from datetime import datetime, timezone
9 from typing import Any
10
11
@dataclass(frozen=True)
class GradientTensor:
    """A named tensor in a gradient update.

    We keep values as a flat tuple of floats so this library has NO dependency
    on numpy/torch. Users can trivially convert with
    `np.array(t.values).reshape(t.shape)`.

    Any iterable is accepted for ``shape`` and ``values``; both are stored as
    tuples, and every value is converted with ``float()``. Normalizing here
    (rather than only in ``from_dict``) keeps an instance equal to, and
    serialized identically to, its own ``from_dict(to_dict())`` round trip.

    Raises:
        ValueError: if a dimension is negative, or if the product of the
            dimensions differs from the number of values.
    """

    name: str  # layer name, e.g. "dense_1.weights"
    shape: tuple[int, ...]
    values: tuple[float, ...]  # flat, row-major

    def __post_init__(self) -> None:
        """Normalize ``shape``/``values`` and check that they agree."""
        shape = tuple(self.shape)
        values = tuple(float(v) for v in self.values)
        # The dataclass is frozen, so normal attribute assignment is blocked;
        # object.__setattr__ is the documented way to set fields here.
        object.__setattr__(self, "shape", shape)
        object.__setattr__(self, "values", values)

        # Two negative dimensions can multiply to a plausible element count,
        # so they must be rejected before the product check.
        if any(d < 0 for d in shape):
            raise ValueError(f"shape {shape} contains a negative dimension")

        expected = 1
        for d in shape:
            expected *= d
        if expected != len(values):
            raise ValueError(
                f"shape {shape} implies {expected} values, got {len(values)}"
            )

    def to_dict(self) -> dict[str, Any]:
        """Return a JSON-serializable dict with ``name``, ``shape``, ``values``."""
        return {
            "name": self.name,
            "shape": list(self.shape),
            "values": list(self.values),
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> GradientTensor:
        """Build a tensor from a dict shaped like the output of ``to_dict``.

        Raises:
            KeyError: if ``name``, ``shape`` or ``values`` is missing.
            ValueError: if the shape and the number of values disagree.
        """
        return cls(
            name=data["name"],
            shape=tuple(data["shape"]),
            values=tuple(float(v) for v in data["values"]),
        )
47
48
@dataclass(frozen=True)
class ClientUpdateMetadata:
    """Non-secret descriptive fields that accompany a client's update.

    Instances are immutable (frozen dataclass).
    """

    client_did: str
    round_id: str
    # Which model is being trained.
    model_id: str
    # Size of the local training set (used as weight in FedAvg).
    num_samples: int
    epochs: int = 1
    local_loss: float = 0.0

    def to_dict(self) -> dict[str, Any]:
        """Return the fields as a plain dict keyed by field name."""
        as_mapping = asdict(self)
        return as_mapping
62
63
@dataclass
class ClientUpdate:
    """Gradient update from a client, with fields for an attached signature.

    The content hash covers ``metadata``, ``tensors`` and ``created_at`` only.
    The signature-related fields default to empty strings; nothing in this
    class fills them in, so they are presumably set by a separate signing
    step (confirm against the caller).
    """

    metadata: ClientUpdateMetadata
    tensors: list[GradientTensor]
    created_at: str = ""
    content_hash: str = ""  # SHA3-256 over canonical serialization
    signer_did: str = ""
    algorithm: str = ""
    signature: str = ""  # hex
    public_key: str = ""  # hex
    signed_at: str = ""

    @staticmethod
    def _canonical_payload(
        metadata: ClientUpdateMetadata,
        tensors: list[GradientTensor],
        created_at: str,
    ) -> bytes:
        """Serialize the hashed fields to deterministic UTF-8 JSON bytes.

        This is the single source of truth for the canonical form, so
        ``canonical_bytes`` and ``compute_content_hash`` cannot drift apart.
        Keys are sorted and separators are compact so that equal content
        always yields identical bytes.
        """
        payload = {
            "metadata": metadata.to_dict(),
            "tensors": [t.to_dict() for t in tensors],
            "created_at": created_at,
        }
        return json.dumps(
            payload, sort_keys=True, separators=(",", ":"), ensure_ascii=False
        ).encode("utf-8")

    def canonical_bytes(self) -> bytes:
        """Return the canonical serialization of this update's hashed fields."""
        return self._canonical_payload(self.metadata, self.tensors, self.created_at)

    @staticmethod
    def compute_content_hash(
        metadata: ClientUpdateMetadata,
        tensors: list[GradientTensor],
        created_at: str,
    ) -> str:
        """Return the hex SHA3-256 digest of the canonical serialization."""
        canonical = ClientUpdate._canonical_payload(metadata, tensors, created_at)
        return hashlib.sha3_256(canonical).hexdigest()

    @classmethod
    def create(
        cls,
        metadata: ClientUpdateMetadata,
        tensors: list[GradientTensor],
    ) -> ClientUpdate:
        """Build an update stamped with the current UTC time and its hash.

        ``tensors`` is copied into a new list exactly once, and the hash is
        computed over that copy. Hashing the original argument instead would
        see an already-exhausted iterator when a one-shot iterable is passed,
        producing a hash that does not match the stored tensors.
        """
        now = datetime.now(timezone.utc).isoformat()
        materialized = list(tensors)
        u = cls(metadata=metadata, tensors=materialized, created_at=now)
        u.content_hash = cls.compute_content_hash(metadata, materialized, now)
        return u

    def to_dict(self) -> dict[str, Any]:
        """Return every field as a JSON-serializable dict."""
        return {
            "metadata": self.metadata.to_dict(),
            "tensors": [t.to_dict() for t in self.tensors],
            "created_at": self.created_at,
            "content_hash": self.content_hash,
            "signer_did": self.signer_did,
            "algorithm": self.algorithm,
            "signature": self.signature,
            "public_key": self.public_key,
            "signed_at": self.signed_at,
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> ClientUpdate:
        """Rebuild an update from a dict shaped like the output of ``to_dict``.

        ``metadata`` (with ``client_did``, ``round_id``, ``model_id`` and
        ``num_samples``) and ``tensors`` are required. ``epochs`` defaults to
        1, ``local_loss`` to 0.0, and every other field to an empty string.
        The stored ``content_hash`` is copied as-is; it is not recomputed or
        checked here.

        Raises:
            KeyError: if a required key is missing.
        """
        meta = data["metadata"]
        return cls(
            metadata=ClientUpdateMetadata(
                client_did=meta["client_did"],
                round_id=meta["round_id"],
                model_id=meta["model_id"],
                num_samples=int(meta["num_samples"]),
                epochs=int(meta.get("epochs", 1)),
                local_loss=float(meta.get("local_loss", 0.0)),
            ),
            tensors=[GradientTensor.from_dict(t) for t in data["tensors"]],
            created_at=data.get("created_at", ""),
            content_hash=data.get("content_hash", ""),
            signer_did=data.get("signer_did", ""),
            algorithm=data.get("algorithm", ""),
            signature=data.get("signature", ""),
            public_key=data.get("public_key", ""),
            signed_at=data.get("signed_at", ""),
        )
149