src/pqc_ai_governance/tally.py
3.3 KB · 93 lines · python Raw
1 """VoteTally - count signed votes, detect Byzantine behavior (double-voting)."""
2
3 from __future__ import annotations
4
5 from dataclasses import dataclass, field
6 from typing import Any
7
8 from pqc_ai_governance.errors import ByzantineDetectedError
9 from pqc_ai_governance.node import GovernanceNode, NodeRegistry
10 from pqc_ai_governance.proposal import GovernanceProposal
11 from pqc_ai_governance.vote import SignedVote, VoteDecision
12
13
@dataclass
class VoteTally:
    """Accumulate weighted, signed votes for one proposal and flag Byzantine voters.

    Each submitted ``SignedVote`` is screened in the order below. The first
    check that fails files the vote under ``invalid_votes`` with the quoted
    reason:

    * ``"proposal hash mismatch"`` - ``vote.proposal_hash`` differs from
      ``proposal.proposal_hash()``.
    * ``"proposal id mismatch"`` - ``vote.proposal_id`` differs from
      ``proposal.proposal_id``.
    * ``"invalid signature"`` - ``GovernanceNode.verify_vote`` rejects it.
    * ``"non-member voter"`` - ``registry.is_member`` rejects ``vote.voter_did``.

    Votes that pass every check are deduplicated per ``voter_did``:

    * A repeat with the same decision is silently ignored.
    * A repeat with a different decision raises ``ByzantineDetectedError``.

    Otherwise the voter's node weight is added to the bucket for its decision
    and the vote is appended to ``valid_votes``. Any decision other than
    APPROVE or REJECT counts as abstain.
    """

    proposal: GovernanceProposal
    registry: NodeRegistry
    # Running sums of node weight, one per decision bucket.
    approve_weight: int = 0
    reject_weight: int = 0
    abstain_weight: int = 0
    # voter_did -> decision of that voter's first accepted vote.
    _seen_voters: dict[str, VoteDecision] = field(default_factory=dict)
    # Votes that passed every check and were counted.
    valid_votes: list[SignedVote] = field(default_factory=list)
    # (vote, reason) pairs for votes rejected by a validation check.
    invalid_votes: list[tuple[SignedVote, str]] = field(default_factory=list)

    def add(self, signed: SignedVote) -> None:
        """Screen *signed* and, if it is acceptable, fold its weight into the tally.

        Votes that fail validation are appended to ``invalid_votes``; they
        are not raised.

        Raises:
            ByzantineDetectedError: ``vote.voter_did`` already has an accepted
                vote with a different decision. The earlier vote stays counted,
                and the conflicting vote is not recorded anywhere.
        """
        reason = self._rejection_reason(signed)
        if reason is not None:
            self.invalid_votes.append((signed, reason))
            return

        ballot = signed.vote
        voter = ballot.voter_did
        choice = ballot.decision

        # This check runs only after the signature check, so an
        # unauthenticated vote is filed as invalid. It cannot get the named
        # voter flagged as Byzantine.
        if voter in self._seen_voters:
            earlier = self._seen_voters[voter]
            if earlier == choice:
                # Identical re-submission: count it once and ignore the repeat.
                return
            raise ByzantineDetectedError(
                f"voter {voter} cast conflicting votes: "
                f"{earlier.value} then {choice.value}"
            )

        # Remember the voter before the registry lookup, as the original ordering did.
        self._seen_voters[voter] = choice
        weight = self.registry.get(voter).weight

        if choice == VoteDecision.APPROVE:
            self.approve_weight += weight
        elif choice == VoteDecision.REJECT:
            self.reject_weight += weight
        else:
            # Every other decision value lands in the abstain bucket.
            self.abstain_weight += weight

        self.valid_votes.append(signed)

    def _rejection_reason(self, signed: SignedVote) -> str | None:
        """Return why *signed* must be rejected, or ``None`` if it passes every check."""
        ballot = signed.vote
        if ballot.proposal_hash != self.proposal.proposal_hash():
            return "proposal hash mismatch"
        if ballot.proposal_id != self.proposal.proposal_id:
            return "proposal id mismatch"
        # NOTE(review): verify_vote is a static call on the vote alone. Confirm
        # that it checks the signature against the key the registry holds for
        # voter_did. Otherwise the membership check below does not tie the
        # signer to the DID.
        if not GovernanceNode.verify_vote(signed):
            return "invalid signature"
        if not self.registry.is_member(ballot.voter_did):
            return "non-member voter"
        return None

    def total_cast_weight(self) -> int:
        """Return the combined weight of every counted vote, abstentions included."""
        return sum((self.approve_weight, self.reject_weight, self.abstain_weight))

    def to_dict(self) -> dict[str, Any]:
        """Return a plain-dict summary: per-decision weights plus valid/invalid vote counts."""
        return dict(
            proposal_id=self.proposal.proposal_id,
            approve_weight=self.approve_weight,
            reject_weight=self.reject_weight,
            abstain_weight=self.abstain_weight,
            valid_vote_count=len(self.valid_votes),
            invalid_vote_count=len(self.invalid_votes),
        )
93