encoding/encoding_dsv32.py
14.0 KB · 377 lines · python Raw
1 from typing import Any, Dict, List, Union, Optional, Tuple
2 import copy
3 import json
4 import re
5
# System-prompt section that teaches the model the DSML tool-calling protocol.
# Placeholders filled by render_tools():
#   {dsml_token}            - tag namespace prefix (e.g. |DSML|)
#   {tool_schemas}          - newline-joined JSON schemas of the available tools
#   {thinking_start_token}/{thinking_end_token} - reasoning block delimiters
TOOLS_SYSTEM_TEMPLATE: str = """## Tools

You have access to a set of tools you can use to answer the user's question.
You can invoke functions by writing a "<{dsml_token}function_calls>" block like the following as part of your reply to the user:
<{dsml_token}function_calls>
<{dsml_token}invoke name="$FUNCTION_NAME">
<{dsml_token}parameter name="$PARAMETER_NAME" string="true|false">$PARAMETER_VALUE</{dsml_token}parameter>
...
</{dsml_token}invoke>
<{dsml_token}invoke name="$FUNCTION_NAME2">
...
</{dsml_token}invoke>
</{dsml_token}function_calls>

String and scalar parameters should be specified as is without any escaping or quotes, while lists and objects should use JSON format. The "string" attribute should be set to "true" for string type parameters and "false" for other types (numbers, booleans, arrays, objects).

If the thinking_mode is enabled, then after function results you should strongly consider outputting a thinking block. Here is an example:

<{dsml_token}function_calls>
...
</{dsml_token}function_calls>

<function_results>
...
</function_results>

{thinking_start_token}...thinking about results{thinking_end_token}

Here are the functions available in JSONSchema format:
<functions>
{tool_schemas}
</functions>
"""
39
# --- Special tokens of the underlying model/tokenizer ---
bos_token: str = "<|begin▁of▁sentence|>"
eos_token: str = "<|end▁of▁sentence|>"
# Delimiters around the model's chain-of-thought (reasoning) block.
thinking_start_token: str = "<think>"
thinking_end_token: str = "</think>"
# Namespace prefix embedded in every DSML tag, e.g. <|DSML|invoke ...>.
dsml_token: str = "|DSML|"

# --- Per-role prompt templates ---
system_msg_template: str = "{content}"
# A user turn also opens the assistant turn via the <|Assistant|> marker.
user_msg_template: str = "<|User|>{content}<|Assistant|>"
# Assistant turn: optional reasoning, summary content, optional tool calls, then EOS.
assistant_msg_template: str = "{reasoning}{content}{tool_calls}<|end▁of▁sentence|>"
# Reasoning text inserted verbatim (the end token is appended by render_message).
thinking_template = "{reasoning_content}"

# Appended to the system/developer prompt when a response_format schema is given.
response_format_template: str = (
    "## Response Format:\n\nYou MUST strictly adhere to the following schema to reply:\n{schema}"
)
# One DSML <invoke> block per tool call.
tool_call_template: str = (
    "<{dsml_token}invoke name=\"{name}\">\n{arguments}\n</{dsml_token}invoke>"
)
# Wrapper around the batch of <invoke> blocks for one assistant turn.
tool_calls_template = (
    "<{dsml_token}function_calls>\n{tool_calls}\n</{dsml_token}function_calls>"
)

# Wrapper for a single tool result inside a <function_results> group.
tool_output_template: str = (
    "\n<result>{content}</result>"
)
63
def to_json(value: Any) -> str:
    """Serialize *value* to a JSON string without ASCII-escaping when possible.

    Tries ``ensure_ascii=False`` first for readable non-ASCII output and falls
    back to ASCII-escaped output if the unicode dump fails (e.g. text carrying
    unencodable surrogates).
    """
    try:
        return json.dumps(value, ensure_ascii=False)
    except Exception:
        # Was a bare `except:`, which also swallows KeyboardInterrupt and
        # SystemExit; `Exception` keeps the best-effort fallback without
        # masking interpreter exits. A truly unserializable value still
        # raises from the fallback call, as before.
        return json.dumps(value, ensure_ascii=True)
69
def tools_from_openai_format(tools):
    """Unwrap OpenAI-style tool specs down to their bare "function" dicts."""
    unwrapped = []
    for wrapper in tools:
        unwrapped.append(wrapper["function"])
    return unwrapped
72
def tool_calls_from_openai_format(tool_calls):
    """Reduce OpenAI-style tool calls to internal {name, arguments} dicts."""
    simplified = []
    for call in tool_calls:
        fn = call["function"]
        simplified.append({"name": fn["name"], "arguments": fn["arguments"]})
    return simplified
81
def tool_calls_to_openai_format(tool_calls):
    """Wrap internal {name, arguments} tool calls back into OpenAI format."""
    wrapped = []
    for call in tool_calls:
        wrapped.append(
            {
                "type": "function",
                "function": {
                    "name": call["name"],
                    "arguments": call["arguments"],
                },
            }
        )
    return wrapped
93
def encode_arguments_to_dsml(tool_call: Dict[str, str]) -> str:
    """Render one tool call's JSON-encoded arguments as DSML <parameter> lines.

    String values are emitted verbatim with string="true"; every other JSON
    type is serialized back to JSON text and tagged string="false".
    """
    param_template = """<{dsml_token}parameter name="{key}" string="{is_str}">{value}</{dsml_token}parameter>"""
    arguments = json.loads(tool_call["arguments"])

    def _render(key: str, value: Any) -> str:
        # One <parameter> line per argument.
        is_string = isinstance(value, str)
        return param_template.format(
            dsml_token=dsml_token,
            key=key,
            is_str="true" if is_string else "false",
            value=value if is_string else to_json(value),
        )

    return "\n".join(_render(k, v) for k, v in arguments.items())
111
112
def decode_dsml_to_arguments(tool_name: str, tool_args: Dict[str, Tuple[str, str]]) -> Dict[str, str]:
    """Rebuild a JSON arguments string from parsed DSML parameters.

    *tool_args* maps each parameter name to (raw_value, string_flag). Values
    flagged string="true" get JSON-quoted; anything else is assumed to already
    be valid JSON text and is spliced in verbatim.
    """
    pairs = []
    for key, (raw_value, string_flag) in tool_args.items():
        rendered = to_json(raw_value) if string_flag == "true" else raw_value
        pairs.append(f"{to_json(key)}: {rendered}")

    return dict(name=tool_name, arguments="{" + ", ".join(pairs) + "}")
121
def render_tools(tools: List[Dict[str, Union[str, Dict[str, Any]]]]) -> str:
    """Fill the tools system-prompt template with JSON schemas for *tools*."""
    schemas = "\n".join(to_json(tool) for tool in tools)
    return TOOLS_SYSTEM_TEMPLATE.format(
        tool_schemas=schemas,
        dsml_token=dsml_token,
        thinking_start_token=thinking_start_token,
        thinking_end_token=thinking_end_token,
    )
131
def find_last_user_index(messages: List[Dict[str, Any]]) -> int:
    """Return the index of the last "user" or "developer" message, or -1."""
    for idx in reversed(range(len(messages))):
        if messages[idx].get("role") in ("user", "developer"):
            return idx
    return -1
139
def render_message(index: int, messages: List[Dict[str, Any]], thinking_mode: str) -> str:
    """Render messages[index] into its fragment of the final prompt string.

    The whole *messages* list is needed (not just the one message) because
    rendering depends on position: the last user/developer message opens a
    thinking block in "thinking" mode, and tool outputs are grouped under the
    preceding assistant's tool_calls.

    NOTE(review): validation is assert-based, so it is stripped under `-O`.
    """
    assert 0 <= index < len(messages)
    assert thinking_mode in ["chat", "thinking"], f"Invalid thinking_mode `{thinking_mode}`"

    prompt = ""
    msg = messages[index]
    last_user_idx = find_last_user_index(messages)

    role = msg.get("role")
    content = msg.get("content")
    tools = msg.get("tools")
    response_format = msg.get("response_format")
    tool_calls = msg.get("tool_calls")
    reasoning_content = msg.get("reasoning_content")

    # Normalize OpenAI-style wrappers to the internal shape.
    if tools:
        tools = tools_from_openai_format(tools)
    if tool_calls:
        tool_calls = tool_calls_from_openai_format(tool_calls)

    if role == "system":
        # System text, then (optionally) the tools protocol and response schema.
        prompt += system_msg_template.format(content=content or "")
        if tools:
            prompt += "\n\n" + render_tools(tools)

        if response_format:
            prompt += "\n\n" + response_format_template.format(schema=to_json(response_format))

    elif role == "developer":
        # Developer messages are rendered inside the *user* template: tool and
        # schema sections first, then the user's text under a header.
        assert content, f"Invalid message for role `{role}`: {msg}"
        content_developer = ""
        if tools:
            content_developer += "\n\n" + render_tools(tools)

        if response_format:
            content_developer += "\n\n" + response_format_template.format(schema=to_json(response_format))

        content_developer += "\n\n# The user's message is: {}".format(content)

        prompt += user_msg_template.format(content=content_developer)
        # Open a thinking block only for the LAST user-like turn; earlier
        # turns close immediately (historical turns carry no live reasoning).
        if index == last_user_idx and thinking_mode == "thinking":
            prompt += thinking_start_token
        else:
            prompt += thinking_end_token

    elif role == "user":
        prompt += user_msg_template.format(content=content)

        # Same open/close rule as the developer branch.
        if index == last_user_idx and thinking_mode == "thinking":
            prompt += thinking_start_token
        else:
            prompt += thinking_end_token

    elif role == "tool":
        # Walk back over any preceding tool outputs to the assistant message
        # that issued the tool_calls this output answers.
        prev_assistant_idx = index - 1
        assistant_msg = messages[prev_assistant_idx]
        # NOTE(review): if index == 0 this reads messages[-1] (wraps to the
        # last message) before the assert below lets it pass unconditionally.
        while prev_assistant_idx >= 0 and assistant_msg.get("role") == "tool":
            prev_assistant_idx -= 1
            assistant_msg = messages[prev_assistant_idx]

        assert index == 0 or prev_assistant_idx >= 0 and assistant_msg.get("role") == "assistant", f"Invalid messages at {index}:\n{assistant_msg}"

        # 1-based position of this output among the assistant's tool_calls.
        tool_call_order = index - prev_assistant_idx
        assistant_tool_calls = assistant_msg.get("tool_calls")
        assert assistant_tool_calls and len(assistant_tool_calls) >= tool_call_order, "No tool calls but found tool output"

        # First output of the group opens <function_results>.
        if tool_call_order == 1:
            prompt += "\n\n<function_results>"

        prompt += tool_output_template.format(content=content)

        # Last output of the group closes the block and re-opens/closes the
        # thinking marker for the assistant turn that follows.
        if tool_call_order == len(assistant_tool_calls):
            prompt += "\n</function_results>"

            if index >= last_user_idx and thinking_mode == "thinking":
                prompt += "\n\n" + thinking_start_token
            else:
                prompt += "\n\n" + thinking_end_token

    elif role == "assistant":
        prev_assistant_idx = index  # NOTE(review): unused local
        thinking_part = ""

        # Render tool calls (if any) as a DSML function_calls block.
        tool_calls_content = ""
        if tool_calls:
            tool_calls = [
                tool_call_template.format(
                    dsml_token=dsml_token,
                    name=tool_call.get("name"),
                    arguments=encode_arguments_to_dsml(tool_call)
                )
                for tool_call in tool_calls
            ]
            tool_calls_content += "\n\n" + tool_calls_template.format(
                dsml_token=dsml_token,
                tool_calls="\n".join(tool_calls)
            )

        summary_content = content or ""

        # Only turns strictly after the last user message keep live reasoning;
        # the opening think token was emitted by the preceding user/tool turn,
        # so here we add the reasoning text plus the closing token.
        if thinking_mode == "thinking" and index > last_user_idx:
            assert reasoning_content or tool_calls, f"ThinkingMode: {thinking_mode}, invalid message without reasoning_content/tool_calls `{msg}` after last user message"
            thinking_part = thinking_template.format(reasoning_content=reasoning_content or "") + thinking_end_token

        prompt += assistant_msg_template.format(
            reasoning=thinking_part,
            content=summary_content,
            tool_calls=tool_calls_content,
        )
    else:
        raise NotImplementedError(f"Unknown role: {role}")

    return prompt
253
def drop_thinking_messages(messages: List[Dict[str, Any]], last_user_idx: Optional[int]=None) -> List[Dict[str, Any]]:
    """Strip reasoning_content from assistant turns before the last user turn.

    Messages at or after the last user/developer index are kept as-is, as are
    user/system/tool messages anywhere. Input messages are never mutated; a
    shallow copy is made before popping reasoning_content.

    NOTE(review): before the last user index, only the four roles above are
    copied through — e.g. a "developer" message there is dropped entirely.
    Confirm that is intended.
    """
    if last_user_idx is None:
        last_user_idx = find_last_user_index(messages)

    result: List[Dict[str, Any]] = []
    for idx, msg in enumerate(messages):
        role = msg.get("role")
        if role in ["user", "system", "tool"] or idx >= last_user_idx:
            result.append(msg)
        elif role == "assistant":
            trimmed = copy.copy(msg)
            trimmed.pop("reasoning_content", None)
            result.append(trimmed)

    return result
269
def encode_messages(messages: List[Dict[str, Any]], thinking_mode: str, context: Optional[List[Dict[str, Any]]] = None, drop_thinking: bool = True, add_default_bos_token: bool = True) -> str:
    """Encode a chat into one prompt string.

    *context* holds earlier messages that were already rendered upstream: only
    *messages* are rendered here, but rendering sees the full history so that
    position-dependent logic (last user turn, tool grouping) works. The BOS
    token is emitted only when there is no preceding context.
    """
    prefix = context or []
    full_history = prefix + messages

    if thinking_mode == "thinking" and drop_thinking:
        full_history = drop_thinking_messages(full_history)

    parts = [bos_token] if add_default_bos_token and not prefix else []
    offset = len(prefix)
    for position in range(len(messages)):
        parts.append(render_message(offset + position, full_history, thinking_mode=thinking_mode))

    return "".join(parts)
283
284 def _read_until_stop(index: int, text: str, stop: List[str]) -> Tuple[int, str, Optional[str]]:
285 min_pos = len(text)
286 matched_stop = None
287
288 for s in stop:
289 pos = text.find(s, index)
290 if pos != -1 and pos < min_pos:
291 min_pos = pos
292 matched_stop = s
293
294 if matched_stop:
295 content = text[index:min_pos]
296 return min_pos + len(matched_stop), content, matched_stop
297 else:
298 content = text[index:]
299 return len(text), content, None
300
def parse_tool_calls(index: int, text: str):
    """Parse a DSML <function_calls> body starting just after its opening tag.

    *index* must point at the ">\\n" that closes "<{dsml}function_calls"; the
    loop consumes one <invoke> block per iteration until the closing
    </function_calls> tag. Returns (index_past_end_tag, last_stop_token,
    tool_calls) where each tool call is {name, arguments-json-string}.

    NOTE(review): format validation is assert-based, so it vanishes under -O.
    """
    tool_calls: List[Dict[str, Any]] = []
    stop_token = None
    tool_calls_end_token = f"</{dsml_token}function_calls>"

    while index < len(text):
        # Consume up to the next <invoke> or the closing tag; the text in
        # between must be exactly the ">\n" that closed the previous tag.
        index, _, stop_token = _read_until_stop(index, text, [f"<{dsml_token}invoke", tool_calls_end_token])
        assert _ == ">\n", "Tool call format error"

        if stop_token == tool_calls_end_token:
            break

        assert stop_token is not None, "Missing special token"

        # Between "<...invoke" and the first parameter (or "</...invoke" for a
        # zero-arg call) sits the name attribute plus the tag's ">\n".
        index, tool_name_content, stop_token = _read_until_stop(index, text, [f"<{dsml_token}parameter", f"</{dsml_token}invoke"])

        p_tool_name = re.findall(r'^\s*name="(.*?)">\n$', tool_name_content, flags=re.DOTALL)
        assert len(p_tool_name) == 1, "Tool name format error"
        tool_name = p_tool_name[0]

        # name -> (raw_value, string_flag) in encounter order.
        tool_args: Dict[str, Tuple[str, str]] = {}
        while stop_token == f"<{dsml_token}parameter":
            # Everything up to "/{dsml}parameter" holds this parameter's
            # attributes and value: ` name="..." string="true|false">value<`.
            index, param_content, stop_token = _read_until_stop(index, text, [f"/{dsml_token}parameter"])

            param_kv = re.findall(r'^ name="(.*?)" string="(true|false)">(.*?)<$', param_content, flags=re.DOTALL)
            assert len(param_kv) == 1, "Parameter format error"
            param_name, string, param_value = param_kv[0]

            assert param_name not in tool_args, "Duplicate parameter name"
            tool_args[param_name] = (param_value, string)

            # After the closing parameter tag: either another <parameter> or
            # the end of this <invoke>; only its ">\n" may sit in between.
            index, content, stop_token = _read_until_stop(index, text, [f"<{dsml_token}parameter", f"</{dsml_token}invoke"])
            assert content == ">\n", "Parameter format error"

        # Re-assemble the collected parameters into a JSON arguments string.
        tool_call = decode_dsml_to_arguments(tool_name=tool_name, tool_args=tool_args)
        tool_calls.append(tool_call)

    return index, stop_token, tool_calls
339
# NOTE: This function is designed to parse only correctly formatted string and will not attempt to correct malformed output that may be generated by the model.
def parse_message_from_completion_text(text: str, thinking_mode: str):
    """Split a raw completion into an OpenAI-style assistant message dict.

    Expected layout: [reasoning + </think>] summary [tool_calls block] EOS.
    The reasoning segment is parsed only when thinking_mode == "thinking".
    Returns {role, content, reasoning_content, tool_calls}; raises
    AssertionError on any format violation (no repair is attempted).
    """
    summary_content, reasoning_content, tool_calls = "", "", []
    index, stop_token = 0, None
    # Tool calls always start on their own paragraph after the summary.
    tool_calls_start_token = f"\n\n<{dsml_token}function_calls"

    is_thinking, is_tool_calling = thinking_mode == "thinking", False

    if is_thinking:
        # The reasoning segment must end with </think>; a tool-calls block
        # appearing first would mean the thinking block was never closed.
        index, content_delta, stop_token = _read_until_stop(index, text, [thinking_end_token, tool_calls_start_token])
        reasoning_content = content_delta
        assert stop_token == thinking_end_token, "Invalid thinking format"

    # Summary runs until either EOS (plain reply) or a tool-calls block.
    index, content_delta, stop_token = _read_until_stop(index, text, [eos_token, tool_calls_start_token])
    summary_content = content_delta
    if stop_token == tool_calls_start_token:
        is_tool_calling = True
    else:
        assert stop_token == eos_token, "Invalid summary format"

    if is_tool_calling:
        index, stop_token, tool_calls = parse_tool_calls(index, text)

        # Nothing but EOS may follow the closing </function_calls>.
        index, tool_ends_text, stop_token = _read_until_stop(index, text, [eos_token])
        assert not tool_ends_text, "Unexpected content after tool calls"

    # The whole text must have been consumed exactly.
    assert len(text) == index and stop_token in [eos_token, None], "Unexpected content at end"

    # Special tokens may not leak into user-visible content.
    for sp_token in [bos_token, eos_token, thinking_start_token, thinking_end_token, dsml_token]:
        assert sp_token not in summary_content and sp_token not in reasoning_content, "Unexpected special token in content"

    return {
        "role": "assistant",
        "content": summary_content,
        "reasoning_content": reasoning_content,
        "tool_calls": tool_calls_to_openai_format(tool_calls)
    }
377