tests/test_chain.py
4.1 KB · 148 lines · python Raw
1 """Tests for ProvenanceChain."""
2
3 from __future__ import annotations
4
5 import pytest
6
7 from pqc_content_provenance import (
8 AIGeneratedAssertion,
9 ContentManifest,
10 GenerationContext,
11 ManifestSigner,
12 ModelAttribution,
13 ProvenanceChain,
14 )
15 from pqc_content_provenance.errors import ChainBrokenError
16
17
def _build_manifest(
    content: bytes,
    attribution: ModelAttribution,
    context: GenerationContext,
    previous_manifest_id: str | None = None,
) -> ContentManifest:
    """Build a ``text/plain`` manifest for *content* via ``ContentManifest.create``.

    The manifest carries a single ``AIGeneratedAssertion`` whose model name is
    taken from *attribution*. Signing is not done here; callers hand the result
    to a signer themselves.

    Args:
        content: Raw bytes the manifest describes.
        attribution: Model attribution forwarded as ``model_attribution``.
        context: Generation context forwarded as ``generation_context``.
        previous_manifest_id: Id of the manifest this one follows, or ``None``
            (the default) when no predecessor is given.

    Returns:
        The manifest produced by ``ContentManifest.create``.
    """
    ai_assertion = AIGeneratedAssertion(model_name=attribution.model_name)
    manifest = ContentManifest.create(
        content=content,
        content_type="text/plain",
        model_attribution=attribution,
        generation_context=context,
        assertions=[ai_assertion],
        previous_manifest_id=previous_manifest_id,
    )
    return manifest
32
33
def test_chain_single_link_verifies(
    signer, sample_manifest: ContentManifest
) -> None:
    """A chain holding one signed manifest verifies with an empty error list."""
    # Arrange: sign the fixture manifest and make it the chain's only link.
    signed_manifest = signer.sign(sample_manifest)
    provenance = ProvenanceChain()
    provenance.add(signed_manifest)

    # Act
    is_valid, problems = provenance.verify_chain()

    # Assert
    assert is_valid is True
    assert problems == []
43
44
def test_chain_multiple_links_verifies(
    signer, sample_attribution, sample_context
) -> None:
    """Three manifests, each pointing at its predecessor's id, verify as a chain."""
    payloads = (b"original draft", b"first edit", b"second edit")

    # Sign each revision in order; every revision after the first references
    # the manifest id of the one signed just before it.
    signed_links = []
    for payload in payloads:
        parent_id = signed_links[-1].manifest_id if signed_links else None
        manifest = _build_manifest(
            payload,
            sample_attribution,
            sample_context,
            previous_manifest_id=parent_id,
        )
        signed_links.append(signer.sign(manifest))

    provenance = ProvenanceChain()
    for signed in signed_links:
        provenance.add(signed)

    is_valid, problems = provenance.verify_chain()
    assert is_valid is True
    assert problems == []
    assert len(provenance.links) == 3
70
71
def test_chain_broken_when_previous_id_mismatch(
    signer, sample_attribution, sample_context
) -> None:
    """Adding a manifest whose previous id is not the chain head must raise."""
    # An id that deliberately does not match the first link's manifest id.
    bogus_parent_id = "urn:pqc-prov:not-a-real-id"

    head = signer.sign(
        _build_manifest(b"original", sample_attribution, sample_context)
    )
    mislinked = signer.sign(
        _build_manifest(
            b"bogus edit",
            sample_attribution,
            sample_context,
            previous_manifest_id=bogus_parent_id,
        )
    )

    provenance = ProvenanceChain()
    provenance.add(head)
    with pytest.raises(ChainBrokenError):
        provenance.add(mislinked)
91
92
def test_chain_roundtrip_to_dicts(
    signer, sample_attribution, sample_context
) -> None:
    """A two-link chain survives ``to_dicts`` / ``from_dicts`` and still verifies."""
    first = signer.sign(
        _build_manifest(b"original", sample_attribution, sample_context)
    )
    second = signer.sign(
        _build_manifest(
            b"edit",
            sample_attribution,
            sample_context,
            previous_manifest_id=first.manifest_id,
        )
    )

    source_chain = ProvenanceChain()
    for link in (first, second):
        source_chain.add(link)

    # Serialize, then rebuild a fresh chain from the serialized form.
    serialized = source_chain.to_dicts()
    assert len(serialized) == 2

    rebuilt = ProvenanceChain.from_dicts(serialized)
    assert len(rebuilt.links) == 2

    is_valid, problems = rebuilt.verify_chain()
    assert is_valid is True
    assert problems == []
114
115
def test_chain_detects_tampered_signature_on_verify(
    signer, sample_attribution, sample_context
) -> None:
    """Mutating a link after it was added must surface as a signature error."""
    signed = signer.sign(
        _build_manifest(b"original", sample_attribution, sample_context)
    )
    provenance = ProvenanceChain()
    provenance.add(signed)

    # Tamper with the already-added link: change a field after signing.
    signed.content_type = "application/malicious"

    is_valid, problems = provenance.verify_chain()
    assert is_valid is False
    assert any("signature invalid" in message for message in problems)
128
129
def test_chain_verify_reports_link_break_from_raw_dicts(
    signer, sample_attribution, sample_context
) -> None:
    """``verify_chain`` flags a link break in a chain rebuilt from raw dicts."""
    # from_dicts bypasses add(), so verify_chain must also report link breaks
    head = signer.sign(_build_manifest(b"a", sample_attribution, sample_context))
    orphan = signer.sign(
        _build_manifest(
            b"b",
            sample_attribution,
            sample_context,
            previous_manifest_id="urn:pqc-prov:unrelated",
        )
    )

    raw_links = [head.to_dict(), orphan.to_dict()]
    provenance = ProvenanceChain.from_dicts(raw_links)
    _ = ManifestSigner  # keep imported name referenced for linter clarity

    is_valid, problems = provenance.verify_chain()
    assert is_valid is False
    assert any("link break" in message for message in problems)
148