qwen3coder_tool_parser.py
30.9 KB · 690 lines · python Raw
1 # SPDX-License-Identifier: Apache-2.0
2 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3 import ast
4 import json
5 import uuid
6 from collections.abc import Sequence
7 from typing import Any, List, Optional, Union
8
9 import regex as re
10
11 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
12 ChatCompletionToolsParam,
13 DeltaFunctionCall, DeltaMessage,
14 DeltaToolCall,
15 ExtractedToolCallInformation,
16 FunctionCall, ToolCall)
17 from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
18 ToolParser, ToolParserManager)
19 from vllm.logger import init_logger
20 from vllm.transformers_utils.tokenizer import AnyTokenizer
21
22 logger = init_logger(__name__)
23
24
25 @ToolParserManager.register_module("qwen3_coder")
26 class Qwen3CoderToolParser(ToolParser):
27
def __init__(self, tokenizer: AnyTokenizer):
    """Prepare markers, regexes and streaming state for Qwen3-Coder output.

    Args:
        tokenizer: Tokenizer handed to the base ``ToolParser`` constructor.

    Raises:
        ValueError: If ``self.model_tokenizer`` is falsy after the base
            constructor ran.
        RuntimeError: If ``<tool_call>`` or ``</tool_call>`` cannot be
            found in ``self.vocab``.
    """
    super().__init__(tokenizer)

    # Generic tool-call bookkeeping. ``prev_tool_call_arr`` is filled by the
    # extract methods; ``current_tool_id`` is reset to ``None`` by
    # ``_reset_streaming_state()`` a few lines below.
    self.current_tool_name_sent: bool = False
    self.prev_tool_call_arr: list[dict] = []
    self.current_tool_id: int = -1
    self.streamed_args_for_tool: list[str] = []

    # Literal markers that delimit the XML-like tool call syntax.
    self.tool_call_start_token: str = "<tool_call>"
    self.tool_call_end_token: str = "</tool_call>"
    self.tool_call_prefix: str = "<function="
    self.function_end_token: str = "</function>"
    self.parameter_prefix: str = "<parameter="
    self.parameter_end_token: str = "</parameter>"
    self.is_tool_call_started: bool = False
    self.failed_count: int = 0

    # Per-message streaming state starts from a clean slate.
    self._reset_streaming_state()

    # Each pattern also tolerates an unterminated trailing element (the
    # ``$`` alternatives), so truncated output can still be inspected.
    dotall = re.DOTALL
    self.tool_call_complete_regex = re.compile(
        r"<tool_call>(.*?)</tool_call>", dotall)
    self.tool_call_regex = re.compile(
        r"<tool_call>(.*?)</tool_call>|<tool_call>(.*?)$", dotall)
    self.tool_call_function_regex = re.compile(
        r"<function=(.*?)</function>|<function=(.*)$", dotall)
    self.tool_call_parameter_regex = re.compile(
        r"<parameter=(.*?)(?:</parameter>|(?=<parameter=)|(?=</function>)|$)",
        dotall)

    if not self.model_tokenizer:
        raise ValueError(
            "The model tokenizer must be passed to the ToolParser "
            "constructor during construction.")

    # Token ids of the call delimiters; streaming also checks ids because a
    # delta may carry the token without its text.
    start_id = self.vocab.get(self.tool_call_start_token)
    self.tool_call_start_token_id = start_id
    end_id = self.vocab.get(self.tool_call_end_token)
    self.tool_call_end_token_id = end_id

    if any(token_id is None for token_id in (start_id, end_id)):
        raise RuntimeError(
            "Qwen3 XML Tool parser could not locate tool call start/end "
            "tokens in the tokenizer!")

    logger.info(
        f"vLLM Successfully import tool parser {self.__class__.__name__} !"
    )
77
def _generate_tool_call_id(self) -> str:
    """Return a fresh id: ``call_`` plus 24 hex digits of a random UUID."""
    random_hex = uuid.uuid4().hex
    return "call_" + random_hex[:24]
81
def _reset_streaming_state(self):
    """Restore every per-message streaming attribute to its initial value."""
    # Built on each call so mutable defaults (the params dict) are never
    # shared between resets.
    initial_state = {
        "current_tool_index": 0,
        "is_tool_call_started": False,
        "header_sent": False,
        "current_tool_id": None,
        "current_function_name": None,
        "current_param_name": None,
        "current_param_value": "",
        "param_count": 0,
        "in_param": False,
        "in_function": False,
        "accumulated_text": "",
        "json_started": False,
        "json_closed": False,
        # Raw parameter strings collected for the tool call in progress.
        "accumulated_params": {},
        "streaming_request": None,
    }
    for attribute, value in initial_state.items():
        setattr(self, attribute, value)
100
101 def _get_arguments_config(
102 self, func_name: str,
103 tools: Optional[list[ChatCompletionToolsParam]]) -> dict:
104 """Extract argument configuration for a function."""
105 if tools is None:
106 return {}
107 for config in tools:
108 if not hasattr(config, "type") or not (hasattr(
109 config, "function") and hasattr(config.function, "name")):
110 continue
111 if config.type == "function" and config.function.name == func_name:
112 if not hasattr(config.function, "parameters"):
113 return {}
114 params = config.function.parameters
115 if isinstance(params, dict) and "properties" in params:
116 return params["properties"]
117 elif isinstance(params, dict):
118 return params
119 else:
120 return {}
121 logger.warning(f"Tool '{func_name}' is not defined in the tools list.")
122 return {}
123
def _convert_param_value(self, param_value: str, param_name: str,
                         param_config: dict, func_name: str) -> Any:
    """Coerce a raw parameter string to the type declared in the schema.

    Args:
        param_value: Text found between the parameter tags.
        param_name: Name of the parameter being converted.
        param_config: Mapping of parameter name to its schema entry, as
            returned by ``_get_arguments_config``.
        func_name: Name of the tool; only used in log messages.

    Returns:
        ``None`` for the literal ``null`` (any casing); otherwise the value
        converted according to the declared ``type``. Whenever a conversion
        is impossible the original string is returned (booleans degrade to
        ``False``) and a warning is logged. The result is always something
        ``json.dumps`` can serialise.
    """
    # "null" maps to None regardless of the declared type.
    if param_value.lower() == "null":
        return None

    if param_name not in param_config:
        # An empty schema means the tool itself is unknown; stay quiet then.
        if param_config != {}:
            logger.warning(
                f"Parsed parameter '{param_name}' is not defined in the tool "
                f"parameters for tool '{func_name}', directly returning the string value."
            )
        return param_value

    spec = param_config[param_name]
    if isinstance(spec, dict) and "type" in spec:
        param_type = str(spec["type"]).strip().lower()
    else:
        param_type = "string"

    if param_type in ("string", "str", "text", "varchar", "char", "enum"):
        return param_value

    if param_type.startswith(("int", "uint", "long", "short", "unsigned")):
        try:
            return int(param_value)
        except ValueError:
            logger.warning(
                f"Parsed value '{param_value}' of parameter '{param_name}' is not an integer in tool "
                f"'{func_name}', degenerating to string.")
            return param_value

    if param_type.startswith(("num", "float")):
        try:
            as_float = float(param_value)
            # int() rejects inf (OverflowError) and nan (ValueError); both
            # have no JSON representation, so they stay strings.
            as_int = int(as_float)
        except (ValueError, OverflowError):
            logger.warning(
                f"Parsed value '{param_value}' of parameter '{param_name}' is not a float in tool "
                f"'{func_name}', degenerating to string.")
            return param_value
        # Whole numbers are emitted as ints ("2.0" -> 2).
        return as_int if as_int == as_float else as_float

    if param_type in ("boolean", "bool", "binary"):
        lowered = param_value.lower()
        if lowered not in ("true", "false"):
            logger.warning(
                f"Parsed value '{param_value}' of parameter '{param_name}' is not a boolean (`true` of `false`) in tool '{func_name}', degenerating to false."
            )
        return lowered == "true"

    if param_type in ("object", "array", "arr") or param_type.startswith(
            ("dict", "list")):
        try:
            return json.loads(param_value)
        except (ValueError, RecursionError):
            # Not strict JSON (e.g. single quotes); try Python literals next.
            logger.warning(
                f"Parsed value '{param_value}' of parameter '{param_name}' cannot be parsed with json.loads in tool "
                f"'{func_name}', will try other methods to parse it.")

    # Unknown types and non-JSON containers: accept Python literal syntax.
    # literal_eval never executes code, unlike eval().
    try:
        literal = ast.literal_eval(param_value)
    except (ValueError, TypeError, SyntaxError, MemoryError,
            RecursionError):
        logger.warning(
            f"Parsed value '{param_value}' of parameter '{param_name}' cannot be converted via Python `ast.literal_eval()` in tool '{func_name}', degenerating to string."
        )
        return param_value

    # Literals such as sets, bytes or complex numbers would make the later
    # json.dumps() of the arguments fail, so keep the raw text for those.
    try:
        json.dumps(literal)
    except (TypeError, ValueError):
        logger.warning(
            f"Parsed value '{param_value}' of parameter '{param_name}' is not JSON serializable in tool '{func_name}', degenerating to string."
        )
        return param_value
    return literal
192
def _parse_xml_function_call(
    self, function_call_str: str,
    tools: Optional[list[ChatCompletionToolsParam]]
) -> Optional[ToolCall]:
    """Build a ``ToolCall`` from the text that follows ``<function=``.

    Args:
        function_call_str: ``name>`` followed by the parameter markup, i.e.
            one item returned by ``_get_function_calls``.
        tools: Tool definitions used to look up parameter types.

    Returns:
        A function-type ``ToolCall`` whose arguments are the JSON encoding
        of the converted parameters.

    Raises:
        ValueError: If the function name or a parameter name is not
            terminated by ``>`` (raised by ``str.index``).
    """
    # Everything up to the first ">" is the function name.
    name_end = function_call_str.index(">")
    function_name = function_call_str[:name_end]
    body = function_call_str[name_end + 1:]
    schema = self._get_arguments_config(function_name, tools)

    arguments = {}
    for chunk in self.tool_call_parameter_regex.findall(body):
        # Each chunk looks like "name>value".
        split_at = chunk.index(">")
        key = chunk[:split_at]
        # Drop the single newline that wraps the value on either side.
        raw_value = chunk[split_at + 1:].removeprefix("\n").removesuffix(
            "\n")
        arguments[key] = self._convert_param_value(raw_value, key, schema,
                                                   function_name)

    serialized = json.dumps(arguments, ensure_ascii=False)
    return ToolCall(
        type="function",
        function=FunctionCall(name=function_name, arguments=serialized),
    )
222
def _get_function_calls(self, model_output: str) -> List[str]:
    """Return the raw text of every ``<function=...>`` call in the output.

    Each returned string starts right after ``<function=`` and stops before
    ``</function>`` (or at the end of the text when the call is
    unterminated).
    """
    # Both regexes have two groups: the terminated form and the
    # unterminated tail; at most one of them is non-empty per match.
    tool_call_bodies = [
        closed or unterminated
        for closed, unterminated in self.tool_call_regex.findall(model_output)
    ]

    # Without any <tool_call> wrapper, scan the whole output instead.
    if not tool_call_bodies:
        tool_call_bodies = [model_output]

    return [
        complete or partial for body in tool_call_bodies
        for complete, partial in self.tool_call_function_regex.findall(body)
    ]
243
def extract_tool_calls(
    self,
    model_output: str,
    request: ChatCompletionRequest,
) -> ExtractedToolCallInformation:
    """Extract all tool calls from a complete (non-streamed) model output.

    Args:
        model_output: Full text generated by the model.
        request: The chat request; its ``tools`` drive type conversion.

    Returns:
        The parsed tool calls plus any text preceding the first call. When
        the output has no ``<function=`` marker, yields no function calls,
        or parsing raises, the whole output is returned as plain content
        with ``tools_called=False``.
    """

    def as_plain_content() -> ExtractedToolCallInformation:
        return ExtractedToolCallInformation(tools_called=False,
                                            tool_calls=[],
                                            content=model_output)

    # Cheap substring test before any regex work.
    if self.tool_call_prefix not in model_output:
        return as_plain_content()

    try:
        raw_calls = self._get_function_calls(model_output)
        if not raw_calls:
            return as_plain_content()

        parsed_calls = [
            self._parse_xml_function_call(raw_call, request.tools)
            for raw_call in raw_calls
        ]

        # The serving layer reads prev_tool_call_arr to set finish_reason.
        self.prev_tool_call_arr.clear()
        self.prev_tool_call_arr.extend({
            "name": call.function.name,
            "arguments": call.function.arguments,
        } for call in parsed_calls if call)

        # Text before the first <tool_call> (or bare <function=) is content.
        cut = model_output.find(self.tool_call_start_token)
        if cut < 0:
            cut = model_output.find(self.tool_call_prefix)
        leading_text = model_output[:cut]

        return ExtractedToolCallInformation(
            tools_called=bool(parsed_calls),
            tool_calls=parsed_calls,
            content=leading_text or None,
        )

    except Exception:
        logger.exception("Error in extracting tool call from response.")
        return as_plain_content()
295
296 def extract_tool_calls_streaming(
297 self,
298 previous_text: str,
299 current_text: str,
300 delta_text: str,
301 previous_token_ids: Sequence[int],
302 current_token_ids: Sequence[int],
303 delta_token_ids: Sequence[int],
304 request: ChatCompletionRequest,
305 ) -> Union[DeltaMessage, None]:
306 # Store request for type conversion
307 if not previous_text:
308 self._reset_streaming_state()
309 self.streaming_request = request
310
311 # If no delta text, return None unless it's an EOS token after tool calls
312 if not delta_text:
313 # Check if this is an EOS token after all tool calls are complete
314 # We check for tool calls in the text even if is_tool_call_started is False
315 # because it might have been reset after processing all tools
316 if delta_token_ids and self.tool_call_end_token_id not in delta_token_ids:
317 # Count complete tool calls
318 complete_calls = len(
319 self.tool_call_complete_regex.findall(current_text))
320
321 # If we have completed tool calls and populated prev_tool_call_arr
322 if complete_calls > 0 and len(self.prev_tool_call_arr) > 0:
323 # Check if all tool calls are closed
324 open_calls = current_text.count(
325 self.tool_call_start_token) - current_text.count(
326 self.tool_call_end_token)
327 if open_calls == 0:
328 # Return empty delta message to allow finish_reason processing
329 return DeltaMessage(content="")
330 elif not self.is_tool_call_started and current_text:
331 # This is a regular content response that's now complete
332 return DeltaMessage(content="")
333 return None
334
335 # Update accumulated text
336 self.accumulated_text = current_text
337
338 # Check if we need to advance to next tool
339 if self.json_closed and not self.in_function:
340 # Check if this tool call has ended
341 tool_ends = current_text.count(self.tool_call_end_token)
342 if tool_ends > self.current_tool_index:
343 # This tool has ended, advance to next
344 self.current_tool_index += 1
345 self.header_sent = False
346 self.param_count = 0
347 self.json_started = False
348 self.json_closed = False
349 self.accumulated_params = {}
350
351 # Check if there are more tool calls
352 tool_starts = current_text.count(self.tool_call_start_token)
353 if self.current_tool_index >= tool_starts:
354 # No more tool calls
355 self.is_tool_call_started = False
356 # Continue processing next tool
357 return None
358
359 # Handle normal content before tool calls
360 if not self.is_tool_call_started:
361 # Check if tool call is starting
362 if self.tool_call_start_token_id in delta_token_ids or self.tool_call_start_token in delta_text:
363 self.is_tool_call_started = True
364 # Return any content before the tool call
365 if self.tool_call_start_token in delta_text:
366 content_before = delta_text[:delta_text.index(
367 self.tool_call_start_token)]
368 if content_before:
369 return DeltaMessage(content=content_before)
370 return None
371 else:
372 # Check if we're between tool calls - skip whitespace
373 if current_text.rstrip().endswith(self.tool_call_end_token):
374 # We just ended a tool call, skip whitespace
375 if delta_text.strip() == "":
376 return None
377 # Normal content, no tool call
378 return DeltaMessage(content=delta_text)
379
380 # Check if we're between tool calls (waiting for next one)
381 # Count tool calls we've seen vs processed
382 tool_starts_count = current_text.count(self.tool_call_start_token)
383 if self.current_tool_index >= tool_starts_count:
384 # We're past all tool calls, shouldn't be here
385 return None
386
387 # We're in a tool call, find the current tool call portion
388 # Need to find the correct tool call based on current_tool_index
389 tool_starts = []
390 idx = 0
391 while True:
392 idx = current_text.find(self.tool_call_start_token, idx)
393 if idx == -1:
394 break
395 tool_starts.append(idx)
396 idx += len(self.tool_call_start_token)
397
398 if self.current_tool_index >= len(tool_starts):
399 # No more tool calls to process yet
400 return None
401
402 tool_start_idx = tool_starts[self.current_tool_index]
403 # Find where this tool call ends (or current position if not ended yet)
404 tool_end_idx = current_text.find(self.tool_call_end_token,
405 tool_start_idx)
406 if tool_end_idx == -1:
407 tool_text = current_text[tool_start_idx:]
408 else:
409 tool_text = current_text[tool_start_idx:tool_end_idx +
410 len(self.tool_call_end_token)]
411
412 # Looking for function header
413 if not self.header_sent:
414 if self.tool_call_prefix in tool_text:
415 func_start = tool_text.find(self.tool_call_prefix) + len(
416 self.tool_call_prefix)
417 func_end = tool_text.find(">", func_start)
418
419 if func_end != -1:
420 # Found complete function name
421 self.current_function_name = tool_text[func_start:func_end]
422 self.current_tool_id = self._generate_tool_call_id()
423 self.header_sent = True
424 self.in_function = True
425
426 # IMPORTANT: Add to prev_tool_call_arr immediately when we detect a tool call
427 # This ensures finish_reason="tool_calls" even if parsing isn't complete
428 already_added = any(
429 tool.get("name") == self.current_function_name
430 for tool in self.prev_tool_call_arr)
431 if not already_added:
432 self.prev_tool_call_arr.append({
433 "name": self.current_function_name,
434 "arguments":
435 "{}", # Placeholder, will be updated later
436 })
437
438 # Send header with function info
439 return DeltaMessage(tool_calls=[
440 DeltaToolCall(
441 index=self.current_tool_index,
442 id=self.current_tool_id,
443 function=DeltaFunctionCall(
444 name=self.current_function_name, arguments=""),
445 type="function",
446 )
447 ])
448 return None
449
450 # We've sent header, now handle function body
451 if self.in_function:
452 # Send opening brace if not sent yet
453 if not self.json_started and self.parameter_prefix not in delta_text:
454 self.json_started = True
455 return DeltaMessage(tool_calls=[
456 DeltaToolCall(
457 index=self.current_tool_index,
458 function=DeltaFunctionCall(arguments="{"),
459 )
460 ])
461
462 # Make sure json_started is set if we're processing parameters
463 if not self.json_started:
464 self.json_started = True
465
466 # Check for function end in accumulated text
467 if not self.json_closed and self.function_end_token in tool_text:
468 # Close JSON
469 self.json_closed = True
470
471 # Extract the complete tool call to update prev_tool_call_arr with final arguments
472 # Find the function content
473 func_start = tool_text.find(self.tool_call_prefix) + len(
474 self.tool_call_prefix)
475 func_content_end = tool_text.find(self.function_end_token,
476 func_start)
477 if func_content_end != -1:
478 func_content = tool_text[func_start:func_content_end]
479 # Parse to get the complete arguments
480 try:
481 parsed_tool = self._parse_xml_function_call(
482 func_content, self.streaming_request.tools
483 if self.streaming_request else None)
484 if parsed_tool:
485 # Update existing entry in prev_tool_call_arr with complete arguments
486 for i, tool in enumerate(self.prev_tool_call_arr):
487 if tool.get(
488 "name") == parsed_tool.function.name:
489 self.prev_tool_call_arr[i][
490 "arguments"] = parsed_tool.function.arguments
491 break
492 except Exception:
493 pass # Ignore parsing errors during streaming
494
495 result = DeltaMessage(tool_calls=[
496 DeltaToolCall(
497 index=self.current_tool_index,
498 function=DeltaFunctionCall(arguments="}"),
499 )
500 ])
501
502 # Reset state for next tool
503 self.in_function = False
504 self.json_closed = True
505 self.accumulated_params = {}
506
507 return result
508
509 # Look for parameters
510 # Find all parameter starts
511 param_starts = []
512 idx = 0
513 while True:
514 idx = tool_text.find(self.parameter_prefix, idx)
515 if idx == -1:
516 break
517 param_starts.append(idx)
518 idx += len(self.parameter_prefix)
519
520 # Check if we should start a new parameter
521 if not self.in_param and self.param_count < len(param_starts):
522
523 if len(param_starts) > self.param_count:
524 # Process the next parameter
525 param_idx = param_starts[self.param_count]
526 param_start = param_idx + len(self.parameter_prefix)
527 remaining = tool_text[param_start:]
528
529 if ">" in remaining:
530 # We have the complete parameter name
531 name_end = remaining.find(">")
532 self.current_param_name = remaining[:name_end]
533
534 # Find the parameter value
535 value_start = param_start + name_end + 1
536 value_text = tool_text[value_start:]
537 if value_text.startswith("\n"):
538 value_text = value_text[1:]
539
540 # Find where this parameter ends
541 param_end_idx = value_text.find(
542 self.parameter_end_token)
543 if param_end_idx == -1:
544 # No closing tag, look for next parameter or function end
545 next_param_idx = value_text.find(
546 self.parameter_prefix)
547 func_end_idx = value_text.find(
548 self.function_end_token)
549
550 if next_param_idx != -1 and (func_end_idx == -1
551 or next_param_idx
552 < func_end_idx):
553 param_end_idx = next_param_idx
554 elif func_end_idx != -1:
555 param_end_idx = func_end_idx
556 else:
557 # Neither found, check if tool call is complete
558 if self.tool_call_end_token in tool_text:
559 # Tool call is complete, so parameter must be complete too
560 # Use all remaining text before function end as value
561 param_end_idx = len(value_text)
562 else:
563 # Still streaming, wait for more content
564 return None
565
566 if param_end_idx != -1:
567 # Complete parameter found
568 param_value = value_text[:param_end_idx]
569 if param_value.endswith("\n"):
570 param_value = param_value[:-1]
571
572 # Store raw value for later processing
573 self.accumulated_params[
574 self.current_param_name] = param_value
575
576 # Get parameter configuration for type conversion
577 param_config = self._get_arguments_config(
578 self.current_function_name,
579 self.streaming_request.tools
580 if self.streaming_request else None)
581
582 # Convert the parameter value to the appropriate type
583 converted_value = self._convert_param_value(
584 param_value, self.current_param_name,
585 param_config, self.current_function_name)
586
587 # Build JSON fragment based on the converted type
588 # Use json.dumps to properly serialize the value
589 serialized_value = json.dumps(converted_value,
590 ensure_ascii=False)
591
592 if self.param_count == 0:
593 json_fragment = f'"{self.current_param_name}": {serialized_value}'
594 else:
595 json_fragment = f', "{self.current_param_name}": {serialized_value}'
596
597 self.param_count += 1
598
599 return DeltaMessage(tool_calls=[
600 DeltaToolCall(
601 index=self.current_tool_index,
602 function=DeltaFunctionCall(
603 arguments=json_fragment),
604 )
605 ])
606
607 # Continue parameter value - Not used in the current implementation
608 # since we process complete parameters above
609 if self.in_param:
610 if self.parameter_end_token in delta_text:
611 # End of parameter
612 end_idx = delta_text.find(self.parameter_end_token)
613 value_chunk = delta_text[:end_idx]
614
615 # Skip past > if at start
616 if not self.current_param_value and ">" in value_chunk:
617 gt_idx = value_chunk.find(">")
618 value_chunk = value_chunk[gt_idx + 1:]
619
620 if not self.current_param_value and value_chunk.startswith(
621 "\n"):
622 value_chunk = value_chunk[1:]
623
624 # Store complete value
625 full_value = self.current_param_value + value_chunk
626 self.accumulated_params[
627 self.current_param_name] = full_value
628
629 # Get parameter configuration for type conversion
630 param_config = self._get_arguments_config(
631 self.current_function_name,
632 self.streaming_request.tools
633 if self.streaming_request else None)
634
635 # Convert the parameter value to the appropriate type
636 converted_value = self._convert_param_value(
637 full_value, self.current_param_name, param_config,
638 self.current_function_name)
639
640 # Serialize the converted value
641 serialized_value = json.dumps(converted_value,
642 ensure_ascii=False)
643
644 # Since we've been streaming the quoted version, we need to close it properly
645 # This is complex - for now just complete the value
646 self.in_param = False
647 self.current_param_value = ""
648
649 # Just close the current parameter string
650 return DeltaMessage(tool_calls=[
651 DeltaToolCall(
652 index=self.current_tool_index,
653 function=DeltaFunctionCall(
654 arguments='"'), # Close the string quote
655 )
656 ])
657 else:
658 # Continue accumulating value
659 value_chunk = delta_text
660
661 # Handle first chunk after param name
662 if not self.current_param_value and ">" in value_chunk:
663 gt_idx = value_chunk.find(">")
664 value_chunk = value_chunk[gt_idx + 1:]
665
666 if not self.current_param_value and value_chunk.startswith(
667 "\n"):
668 value_chunk = value_chunk[1:]
669
670 if value_chunk:
671 # Stream the escaped delta
672 prev_escaped = json.dumps(
673 self.current_param_value, ensure_ascii=False
674 )[1:-1] if self.current_param_value else ""
675 self.current_param_value += value_chunk
676 full_escaped = json.dumps(self.current_param_value,
677 ensure_ascii=False)[1:-1]
678 delta_escaped = full_escaped[len(prev_escaped):]
679
680 if delta_escaped:
681 return DeltaMessage(tool_calls=[
682 DeltaToolCall(
683 index=self.current_tool_index,
684 function=DeltaFunctionCall(
685 arguments=delta_escaped),
686 )
687 ])
688
689 return None
690