src/pqc_ai_governance/tally.py
| 1 | """VoteTally - count signed votes, detect Byzantine behavior (double-voting).""" |
| 2 | |
| 3 | from __future__ import annotations |
| 4 | |
| 5 | from dataclasses import dataclass, field |
| 6 | from typing import Any |
| 7 | |
| 8 | from pqc_ai_governance.errors import ByzantineDetectedError |
| 9 | from pqc_ai_governance.node import GovernanceNode, NodeRegistry |
| 10 | from pqc_ai_governance.proposal import GovernanceProposal |
| 11 | from pqc_ai_governance.vote import SignedVote, VoteDecision |
| 12 | |
| 13 | |
@dataclass
class VoteTally:
    """Aggregate signed votes on a single proposal, screening out bad and Byzantine votes.

    Every incoming vote is screened in the order below. The first check that
    fails appends ``(signed_vote, reason)`` to ``invalid_votes`` and the vote
    is dropped:

    1. Its ``proposal_hash`` differs from ``proposal.proposal_hash()``.
    2. Its ``proposal_id`` differs from ``proposal.proposal_id``.
    3. Its signature fails ``GovernanceNode.verify_vote``.
    4. Its voter DID is not a member of ``registry``.

    A screened vote from a DID that has already been counted is ignored if it
    repeats the same decision. If the decision differs, it raises
    ``ByzantineDetectedError``. Otherwise the voter's registry weight is added
    to the bucket that matches the decision.
    """

    proposal: GovernanceProposal
    registry: NodeRegistry
    # Running weighted totals, one per decision bucket.
    approve_weight: int = 0
    reject_weight: int = 0
    abstain_weight: int = 0
    # Decision first counted for each voter DID; used to spot conflicting re-votes.
    _seen_voters: dict[str, VoteDecision] = field(default_factory=dict)
    # Votes that were accepted and counted, in the order they were added.
    valid_votes: list[SignedVote] = field(default_factory=list)
    # Rejected votes, each paired with the reason string it was rejected for.
    invalid_votes: list[tuple[SignedVote, str]] = field(default_factory=list)

    def _rejection_reason(self, signed: SignedVote) -> str | None:
        """Return why *signed* fails screening, or ``None`` if it passes every check."""
        ballot = signed.vote
        target = self.proposal
        if ballot.proposal_hash != target.proposal_hash():
            return "proposal hash mismatch"
        if ballot.proposal_id != target.proposal_id:
            return "proposal id mismatch"
        if not GovernanceNode.verify_vote(signed):
            return "invalid signature"
        if not self.registry.is_member(ballot.voter_did):
            return "non-member voter"
        return None

    def add(self, signed: SignedVote) -> None:
        """Screen *signed* and, when it is acceptable, count its voter's weight.

        Rejected votes are appended to ``invalid_votes``. An exact repeat of a
        voter's already-counted decision is silently ignored and recorded
        nowhere.

        Raises:
            ByzantineDetectedError: the voter already has a different decision
                counted. The tally is not modified by the conflicting vote, and
                the earlier vote stays counted.
        """
        reason = self._rejection_reason(signed)
        if reason is not None:
            self.invalid_votes.append((signed, reason))
            return

        ballot = signed.vote
        did = ballot.voter_did
        decision = ballot.decision

        seen = self._seen_voters
        if did in seen:
            prior = seen[did]
            if prior != decision:
                raise ByzantineDetectedError(
                    f"voter {did} cast conflicting votes: "
                    f"{prior.value} then {decision.value}"
                )
            # Identical repeat from the same voter: nothing new to count.
            return

        seen[did] = decision
        weight = self.registry.get(did).weight

        if decision == VoteDecision.APPROVE:
            self.approve_weight += weight
        elif decision == VoteDecision.REJECT:
            self.reject_weight += weight
        else:
            # Any decision other than APPROVE / REJECT lands in the abstain bucket.
            self.abstain_weight += weight

        self.valid_votes.append(signed)

    def total_cast_weight(self) -> int:
        """Return the sum of the approve, reject and abstain weight totals."""
        return sum((self.approve_weight, self.reject_weight, self.abstain_weight))

    def to_dict(self) -> dict[str, Any]:
        """Return a plain-dict summary: proposal id, per-bucket weights and vote counts."""
        return dict(
            proposal_id=self.proposal.proposal_id,
            approve_weight=self.approve_weight,
            reject_weight=self.reject_weight,
            abstain_weight=self.abstain_weight,
            valid_vote_count=len(self.valid_votes),
            invalid_vote_count=len(self.invalid_votes),
        )
| 93 | |