# conversation.py
"""
From https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
"""
4
import dataclasses
from enum import IntEnum, auto
from typing import Any, Dict, List, Optional, Tuple, Union
8
9
class SeparatorStyle(IntEnum):
    """Prompt layouts that a conversation can be rendered with.

    The numeric values are the ones ``auto()`` would assign (counting up
    from 1 in declaration order).
    """

    DeepSeek = 1
    DeepSeekV2 = 2
    PLAIN = 3
    ALIGNMENT = 4
17
18
@dataclasses.dataclass
class Conversation:
    """Manage a prompt template and keep the whole conversation history.

    The history is a list of ``[role, message]`` pairs. ``get_prompt`` renders
    that history into a single string according to ``sep_style``.
    """

    # The name of this template
    name: str
    # The template of the system prompt; formatted with ``system_message``
    system_template: str = "{system_message}"
    # The system message
    system_message: str = ""
    # The names of the two roles (even-index speaker, odd-index speaker)
    roles: Tuple[str, str] = ("USER", "ASSISTANT")
    # All messages. Each item is [role, message].
    messages: List[List[Any]] = dataclasses.field(default_factory=list)
    # The number of few shot examples
    offset: int = 0
    # The separator style and configurations
    sep_style: SeparatorStyle = SeparatorStyle.DeepSeek
    sep: str = "\n"
    sep2: Optional[str] = None
    # Stop criteria (the default one is EOS token)
    stop_str: Optional[Union[str, List[str]]] = None
    # Stops generation if meeting any token in this list
    stop_token_ids: Optional[List[int]] = None

    def __post_init__(self) -> None:
        """Make sure ``messages`` is a mutable list of mutable pairs.

        Templates are often declared with ``messages=()``; without this
        normalisation ``append_message`` / ``update_last_message`` would fail
        on such an instance. A list passed by the caller is kept as-is.
        """
        if not isinstance(self.messages, list):
            self.messages = [list(pair) for pair in self.messages]

    def get_prompt(self) -> str:
        """Get the prompt for generation.

        Returns:
            The conversation rendered according to ``sep_style``.

        Raises:
            ValueError: If ``sep_style`` is not a known ``SeparatorStyle``.
            TypeError: If a separator that has to be emitted is ``None``
                (``sep2`` defaults to ``None``).
        """
        # Formatted up front for every style, exactly once.
        system_prompt = self.system_template.format(system_message=self.system_message)

        if self.sep_style == SeparatorStyle.DeepSeek:
            return self._deepseek_prompt(system_prompt)
        if self.sep_style == SeparatorStyle.DeepSeekV2:
            return self._deepseek_v2_prompt(system_prompt)
        if self.sep_style == SeparatorStyle.PLAIN:
            return self._plain_prompt(mask_even_turns=False)
        if self.sep_style == SeparatorStyle.ALIGNMENT:
            return self._plain_prompt(mask_even_turns=True)
        raise ValueError(f"Invalid style: {self.sep_style}")

    def _deepseek_prompt(self, system_prompt: str) -> str:
        """Render ``role: message`` turns, alternating ``sep`` and ``sep2``."""
        seps = (self.sep, self.sep2)
        parts = [system_prompt + seps[0]] if system_prompt else []
        for index, (role, message) in enumerate(self.messages):
            if message:
                parts.append(role + ": " + message + seps[index % 2])
            else:
                # An empty/None message leaves the role open with no separator.
                parts.append(role + ":")
        return "".join(parts)

    def _deepseek_v2_prompt(self, system_prompt: str) -> str:
        """Render turns without role prefixes; empty messages are skipped."""
        parts = [system_prompt + self.sep] if system_prompt else []
        for role, message in self.messages:
            if not message:
                continue
            # NOTE(review): only the literal role "User" gets the sft marker;
            # every other role is treated as a response. Confirm this matches
            # the role names used by callers of this style.
            if role == "User":
                parts.append("<|sft▁begin|>\n" + message + self.sep)
            else:
                parts.append(message + self.sep2)
        return "".join(parts)

    def _plain_prompt(self, *, mask_even_turns: bool) -> str:
        """Render bare messages, alternating ``sep`` and ``sep2``.

        Empty messages are skipped. A tuple message must have three items and
        only its first item is used. With ``mask_even_turns`` (the ALIGNMENT
        style) every non-empty even-index message is replaced by
        ``"<image>\\n"``.
        """
        seps = (self.sep, self.sep2)
        parts = []
        for index, (_, message) in enumerate(self.messages):
            if not message:
                continue
            if isinstance(message, tuple):
                message, _, _ = message
            if mask_even_turns and index % 2 == 0:
                parts.append("<image>\n" + seps[index % 2])
            else:
                parts.append(message + seps[index % 2])
        return "".join(parts)

    def set_system_message(self, system_message: str) -> None:
        """Set the system message."""
        self.system_message = system_message

    def append_message(self, role: str, message: Optional[str]) -> None:
        """Append a new message.

        ``message`` may be ``None`` to leave a slot that is filled in later
        with ``update_last_message``.
        """
        self.messages.append([role, message])

    def update_last_message(self, message: str) -> None:
        """Update the last output.

        The last message is typically set to be None when constructing the prompt,
        so we need to update it in-place after getting the response from a model.
        """
        self.messages[-1][1] = message

    def reset_message(self) -> None:
        """Drop the whole message history."""
        self.messages = []

    def to_gradio_chatbot(self) -> List[List[Any]]:
        """Convert the conversation to gradio chatbot format.

        Messages before ``offset`` are skipped. The rest are grouped into
        ``[even-index message, odd-index message]`` pairs; the second slot is
        ``None`` until the matching odd-index message exists.
        """
        ret = []
        for i, (_, msg) in enumerate(self.messages[self.offset :]):
            if i % 2 == 0:
                ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret

    def to_openai_api_messages(self) -> List[Dict[str, Any]]:
        """Convert the conversation to OpenAI chat completion format.

        The formatted system prompt always comes first. After ``offset``,
        even-index messages become "user" entries and odd-index messages
        become "assistant" entries, except that ``None`` assistant messages
        are left out.
        """
        system_prompt = self.system_template.format(system_message=self.system_message)
        ret = [{"role": "system", "content": system_prompt}]

        for i, (_, msg) in enumerate(self.messages[self.offset :]):
            if i % 2 == 0:
                ret.append({"role": "user", "content": msg})
            elif msg is not None:
                ret.append({"role": "assistant", "content": msg})
        return ret

    @staticmethod
    def _clone_if_list(value: Any) -> Any:
        """Return a shallow copy of ``value`` when it is a list, else ``value``."""
        return list(value) if isinstance(value, list) else value

    def copy(self) -> "Conversation":
        """Return a copy whose history and stop lists can be mutated freely.

        The message pairs are rebuilt, and list-valued ``stop_str`` /
        ``stop_token_ids`` are shallow-copied, so changing the copy does not
        leak back into the conversation it was copied from.
        """
        return Conversation(
            name=self.name,
            system_template=self.system_template,
            system_message=self.system_message,
            roles=self.roles,
            messages=[[role, message] for role, message in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            stop_str=self._clone_if_list(self.stop_str),
            stop_token_ids=self._clone_if_list(self.stop_token_ids),
        )

    def dict(self) -> Dict[str, Any]:
        """Return the template name, system message, roles, messages and offset."""
        return {
            "template_name": self.name,
            "system_message": self.system_message,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
        }
172
173
# A global registry for all conversation templates, keyed by template name.
conv_templates: Dict[str, Conversation] = {}
176
177
def register_conv_template(template: Conversation, override: bool = False):
    """Register a new conversation template.

    Args:
        template: The template to store under ``template.name``.
        override: When true, silently replace a template already registered
            under the same name.

    Raises:
        AssertionError: If the name is already registered and ``override``
            is false.
    """
    if not override and template.name in conv_templates:
        # Raised explicitly instead of via ``assert`` so the duplicate check
        # still runs under ``python -O``; the exception type is unchanged.
        raise AssertionError(f"{template.name} has been registered.")

    conv_templates[template.name] = template
184
185
def get_conv_template(name: str) -> Conversation:
    """Get a conversation template.

    Args:
        name: The name the template was registered under.

    Returns:
        A copy of the registered template, so the caller can add messages
        without touching the registry entry.

    Raises:
        KeyError: If no template is registered under ``name``; the message
            lists the names that are available.
    """
    try:
        template = conv_templates[name]
    except KeyError:
        available = ", ".join(sorted(conv_templates))
        raise KeyError(
            f"Unknown conversation template {name!r}; registered templates: {available}"
        ) from None
    return template.copy()
189
190
# Built-in templates, registered at import time in the order listed.
for _builtin_template in (
    Conversation(
        name="deepseek",
        system_template="{system_message}",
        # system_message="You are a helpful assistant. Please answer truthfully and write out your "
        # "thinking step by step to be sure you get the right answer.",
        system_message="",
        roles=("<|User|>", "<|Assistant|>"),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.DeepSeek,
        sep="\n\n",
        sep2="<|end▁of▁sentence|>",
        stop_token_ids=[100001],
        stop_str=["User:", "<|end▁of▁sentence|>"],
    ),
    # NOTE(review): "deepseekv2" renders with SeparatorStyle.DeepSeek (not
    # DeepSeekV2); it differs from "deepseek" only by its empty ``sep``.
    # Presumably intentional -- confirm before changing.
    Conversation(
        name="deepseekv2",
        system_template="{system_message}",
        # system_message="You are a helpful assistant. Please answer truthfully and write out your "
        # "thinking step by step to be sure you get the right answer.",
        system_message="",
        roles=("<|User|>", "<|Assistant|>"),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.DeepSeek,
        sep="",
        sep2="<|end▁of▁sentence|>",
        stop_token_ids=[100001],
        stop_str=["User:", "<|end▁of▁sentence|>"],
    ),
    # Bare concatenation of messages: no system prompt, roles or separators.
    Conversation(
        name="plain",
        system_template="",
        system_message="",
        roles=("", ""),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.PLAIN,
        sep="",
        sep2="",
        stop_token_ids=[100001],
        stop_str=['</s>'],
    ),
    # Same settings as "plain" but rendered with the ALIGNMENT style.
    Conversation(
        name="alignment",
        system_template="",
        system_message="",
        roles=("", ""),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.ALIGNMENT,
        sep="",
        sep2="",
        stop_token_ids=[100001],
        stop_str=['</s>'],
    ),
):
    register_conv_template(_builtin_template)
# Do not leave the loop variable behind as a module attribute.
del _builtin_template
259
260
if __name__ == "__main__":
    # Demo: render the same short exchange with both DeepSeek templates.
    # The trailing None leaves the last assistant turn open for generation.
    demo_turns = [
        "Hello!",
        "Hi! This is Tony.",
        "Who are you?",
        "I am a helpful assistant.",
        "How are you?",
        None,
    ]
    for banner, template_name in (
        ("deepseek template:", "deepseek"),
        ("deepseekv2 template:", "deepseekv2"),
    ):
        print(banner)
        conv = get_conv_template(template_name)
        for turn_index, text in enumerate(demo_turns):
            # Even turns belong to roles[0], odd turns to roles[1].
            conv.append_message(conv.roles[turn_index % 2], text)
        print(conv.get_prompt())
281