qwen3coder_tool_parser.py
| 1 | # SPDX-License-Identifier: Apache-2.0 |
| 2 | # SPDX-FileCopyrightText: Copyright contributors to the vLLM project |
| 3 | import ast |
| 4 | import json |
| 5 | import uuid |
| 6 | from collections.abc import Sequence |
| 7 | from typing import Any, List, Optional, Union |
| 8 | |
| 9 | import regex as re |
| 10 | |
| 11 | from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, |
| 12 | ChatCompletionToolsParam, |
| 13 | DeltaFunctionCall, DeltaMessage, |
| 14 | DeltaToolCall, |
| 15 | ExtractedToolCallInformation, |
| 16 | FunctionCall, ToolCall) |
| 17 | from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( |
| 18 | ToolParser, ToolParserManager) |
| 19 | from vllm.logger import init_logger |
| 20 | from vllm.transformers_utils.tokenizer import AnyTokenizer |
| 21 | |
| 22 | logger = init_logger(__name__) |
| 23 | |
| 24 | |
| 25 | @ToolParserManager.register_module("qwen3_coder") |
| 26 | class Qwen3CoderToolParser(ToolParser): |
| 27 | |
def __init__(self, tokenizer: AnyTokenizer):
    """Prepare sentinel tokens, regexes and streaming state for the parser.

    Raises:
        ValueError: if ``self.model_tokenizer`` is falsy after the base
            class initializer has run.
        RuntimeError: if ``<tool_call>`` or ``</tool_call>`` cannot be found
            in ``self.vocab``.
    """
    super().__init__(tokenizer)

    # Bookkeeping that lives across calls on this parser instance.
    self.current_tool_name_sent: bool = False
    self.prev_tool_call_arr: list[dict] = []
    self.current_tool_id: int = -1
    self.streamed_args_for_tool: list[str] = []

    # Literal markers of the XML-like tool-call format.
    self.tool_call_start_token: str = "<tool_call>"
    self.tool_call_end_token: str = "</tool_call>"
    self.tool_call_prefix: str = "<function="
    self.function_end_token: str = "</function>"
    self.parameter_prefix: str = "<parameter="
    self.parameter_end_token: str = "</parameter>"
    self.is_tool_call_started: bool = False
    self.failed_count: int = 0

    # Per-message streaming fields; note this overwrites current_tool_id
    # and is_tool_call_started assigned above.
    self._reset_streaming_state()

    # Each pattern has an alternative that tolerates a missing closing tag.
    dotall = re.DOTALL
    self.tool_call_complete_regex = re.compile(
        r"<tool_call>(.*?)</tool_call>", dotall)
    self.tool_call_regex = re.compile(
        r"<tool_call>(.*?)</tool_call>|<tool_call>(.*?)$", dotall)
    self.tool_call_function_regex = re.compile(
        r"<function=(.*?)</function>|<function=(.*)$", dotall)
    self.tool_call_parameter_regex = re.compile(
        r"<parameter=(.*?)(?:</parameter>|(?=<parameter=)|(?=</function>)|$)",
        dotall)

    if not self.model_tokenizer:
        raise ValueError(
            "The model tokenizer must be passed to the ToolParser "
            "constructor during construction.")

    # Token ids let the streaming path spot the sentinels by id as well as
    # by text.
    start_id = self.vocab.get(self.tool_call_start_token)
    end_id = self.vocab.get(self.tool_call_end_token)
    self.tool_call_start_token_id = start_id
    self.tool_call_end_token_id = end_id

    if start_id is None or end_id is None:
        raise RuntimeError(
            "Qwen3 XML Tool parser could not locate tool call start/end "
            "tokens in the tokenizer!")

    logger.info(
        f"vLLM Successfully import tool parser {self.__class__.__name__} !"
    )
| 77 | |
| 78 | def _generate_tool_call_id(self) -> str: |
| 79 | """Generate a unique tool call ID.""" |
| 80 | return f"call_{uuid.uuid4().hex[:24]}" |
| 81 | |
| 82 | def _reset_streaming_state(self): |
| 83 | """Reset all streaming state.""" |
| 84 | self.current_tool_index = 0 |
| 85 | self.is_tool_call_started = False |
| 86 | self.header_sent = False |
| 87 | self.current_tool_id = None |
| 88 | self.current_function_name = None |
| 89 | self.current_param_name = None |
| 90 | self.current_param_value = "" |
| 91 | self.param_count = 0 |
| 92 | self.in_param = False |
| 93 | self.in_function = False |
| 94 | self.accumulated_text = "" |
| 95 | self.json_started = False |
| 96 | self.json_closed = False |
| 97 | # Store accumulated parameters for type conversion |
| 98 | self.accumulated_params = {} |
| 99 | self.streaming_request = None |
| 100 | |
| 101 | def _get_arguments_config( |
| 102 | self, func_name: str, |
| 103 | tools: Optional[list[ChatCompletionToolsParam]]) -> dict: |
| 104 | """Extract argument configuration for a function.""" |
| 105 | if tools is None: |
| 106 | return {} |
| 107 | for config in tools: |
| 108 | if not hasattr(config, "type") or not (hasattr( |
| 109 | config, "function") and hasattr(config.function, "name")): |
| 110 | continue |
| 111 | if config.type == "function" and config.function.name == func_name: |
| 112 | if not hasattr(config.function, "parameters"): |
| 113 | return {} |
| 114 | params = config.function.parameters |
| 115 | if isinstance(params, dict) and "properties" in params: |
| 116 | return params["properties"] |
| 117 | elif isinstance(params, dict): |
| 118 | return params |
| 119 | else: |
| 120 | return {} |
| 121 | logger.warning(f"Tool '{func_name}' is not defined in the tools list.") |
| 122 | return {} |
| 123 | |
| 124 | def _convert_param_value(self, param_value: str, param_name: str, |
| 125 | param_config: dict, func_name: str) -> Any: |
| 126 | """Convert parameter value based on its type in the schema.""" |
| 127 | # Handle null value for any type |
| 128 | if param_value.lower() == "null": |
| 129 | return None |
| 130 | |
| 131 | if param_name not in param_config: |
| 132 | if param_config != {}: |
| 133 | logger.warning( |
| 134 | f"Parsed parameter '{param_name}' is not defined in the tool " |
| 135 | f"parameters for tool '{func_name}', directly returning the string value." |
| 136 | ) |
| 137 | return param_value |
| 138 | |
| 139 | if isinstance(param_config[param_name], |
| 140 | dict) and "type" in param_config[param_name]: |
| 141 | param_type = str(param_config[param_name]["type"]).strip().lower() |
| 142 | else: |
| 143 | param_type = "string" |
| 144 | if param_type in ["string", "str", "text", "varchar", "char", "enum"]: |
| 145 | return param_value |
| 146 | elif param_type.startswith("int") or param_type.startswith( |
| 147 | "uint") or param_type.startswith( |
| 148 | "long") or param_type.startswith( |
| 149 | "short") or param_type.startswith("unsigned"): |
| 150 | try: |
| 151 | param_value = int(param_value) |
| 152 | except: |
| 153 | logger.warning( |
| 154 | f"Parsed value '{param_value}' of parameter '{param_name}' is not an integer in tool " |
| 155 | f"'{func_name}', degenerating to string.") |
| 156 | return param_value |
| 157 | elif param_type.startswith("num") or param_type.startswith("float"): |
| 158 | try: |
| 159 | float_param_value = float(param_value) |
| 160 | param_value = float_param_value if float_param_value - int( |
| 161 | float_param_value) != 0 else int(float_param_value) |
| 162 | except: |
| 163 | logger.warning( |
| 164 | f"Parsed value '{param_value}' of parameter '{param_name}' is not a float in tool " |
| 165 | f"'{func_name}', degenerating to string.") |
| 166 | return param_value |
| 167 | elif param_type in ["boolean", "bool", "binary"]: |
| 168 | param_value = param_value.lower() |
| 169 | if param_value not in ["true", "false"]: |
| 170 | logger.warning( |
| 171 | f"Parsed value '{param_value}' of parameter '{param_name}' is not a boolean (`true` of `false`) in tool '{func_name}', degenerating to false." |
| 172 | ) |
| 173 | return param_value == "true" |
| 174 | else: |
| 175 | if param_type in ["object", "array", "arr" |
| 176 | ] or param_type.startswith( |
| 177 | "dict") or param_type.startswith("list"): |
| 178 | try: |
| 179 | param_value = json.loads(param_value) |
| 180 | return param_value |
| 181 | except: |
| 182 | logger.warning( |
| 183 | f"Parsed value '{param_value}' of parameter '{param_name}' cannot be parsed with json.loads in tool " |
| 184 | f"'{func_name}', will try other methods to parse it.") |
| 185 | try: |
| 186 | param_value = ast.literal_eval(param_value) # safer |
| 187 | except: |
| 188 | logger.warning( |
| 189 | f"Parsed value '{param_value}' of parameter '{param_name}' cannot be converted via Python `ast.literal_eval()` in tool '{func_name}', degenerating to string." |
| 190 | ) |
| 191 | return param_value |
| 192 | |
| 193 | def _parse_xml_function_call( |
| 194 | self, function_call_str: str, |
| 195 | tools: Optional[list[ChatCompletionToolsParam]] |
| 196 | ) -> Optional[ToolCall]: |
| 197 | |
| 198 | # Extract function name |
| 199 | end_index = function_call_str.index(">") |
| 200 | function_name = function_call_str[:end_index] |
| 201 | param_config = self._get_arguments_config(function_name, tools) |
| 202 | parameters = function_call_str[end_index + 1:] |
| 203 | param_dict = {} |
| 204 | for match_text in self.tool_call_parameter_regex.findall(parameters): |
| 205 | idx = match_text.index(">") |
| 206 | param_name = match_text[:idx] |
| 207 | param_value = str(match_text[idx + 1:]) |
| 208 | # Remove prefix and trailing \n |
| 209 | if param_value.startswith("\n"): |
| 210 | param_value = param_value[1:] |
| 211 | if param_value.endswith("\n"): |
| 212 | param_value = param_value[:-1] |
| 213 | |
| 214 | param_dict[param_name] = self._convert_param_value( |
| 215 | param_value, param_name, param_config, function_name) |
| 216 | return ToolCall( |
| 217 | type="function", |
| 218 | function=FunctionCall(name=function_name, |
| 219 | arguments=json.dumps(param_dict, |
| 220 | ensure_ascii=False)), |
| 221 | ) |
| 222 | |
| 223 | def _get_function_calls(self, model_output: str) -> List[str]: |
| 224 | # Find all tool calls |
| 225 | matched_ranges = self.tool_call_regex.findall(model_output) |
| 226 | raw_tool_calls = [ |
| 227 | match[0] if match[0] else match[1] for match in matched_ranges |
| 228 | ] |
| 229 | |
| 230 | # Back-off strategy if no tool_call tags found |
| 231 | if len(raw_tool_calls) == 0: |
| 232 | raw_tool_calls = [model_output] |
| 233 | |
| 234 | raw_function_calls = [] |
| 235 | for tool_call in raw_tool_calls: |
| 236 | raw_function_calls.extend( |
| 237 | self.tool_call_function_regex.findall(tool_call)) |
| 238 | |
| 239 | function_calls = [ |
| 240 | match[0] if match[0] else match[1] for match in raw_function_calls |
| 241 | ] |
| 242 | return function_calls |
| 243 | |
def extract_tool_calls(
    self,
    model_output: str,
    request: ChatCompletionRequest,
) -> ExtractedToolCallInformation:
    """Extract all tool calls from a complete (non-streamed) model output.

    Args:
        model_output: Full text generated by the model.
        request: The chat request; ``request.tools`` supplies the schemas
            used for parameter type conversion.

    Returns:
        ``ExtractedToolCallInformation`` with the parsed calls and the text
        preceding the first tool call as ``content`` (``None`` if empty).
        When the output holds no parsable tool call, or parsing raises, the
        whole output is returned as plain content with
        ``tools_called=False``.
    """

    def plain_content() -> ExtractedToolCallInformation:
        return ExtractedToolCallInformation(tools_called=False,
                                            tool_calls=[],
                                            content=model_output)

    # Quick check to avoid unnecessary processing.
    if self.tool_call_prefix not in model_output:
        return plain_content()

    try:
        function_calls = self._get_function_calls(model_output)
        if not function_calls:
            return plain_content()

        # The parser is declared to return Optional[ToolCall]; keep only
        # real results so a None never reaches the response model.
        parsed_calls = (self._parse_xml_function_call(
            function_call_str, request.tools)
                        for function_call_str in function_calls)
        tool_calls = [call for call in parsed_calls if call is not None]
        if not tool_calls:
            return plain_content()

        # Populate prev_tool_call_arr for serving layer to set finish_reason.
        self.prev_tool_call_arr.clear()
        self.prev_tool_call_arr.extend({
            "name": call.function.name,
            "arguments": call.function.arguments,
        } for call in tool_calls)

        # Content is whatever precedes the first tool call; fall back to
        # the first "<function=" when the "<tool_call>" wrapper is absent.
        content_index = model_output.find(self.tool_call_start_token)
        if content_index < 0:
            content_index = model_output.find(self.tool_call_prefix)
        content = model_output[:content_index]

        return ExtractedToolCallInformation(
            tools_called=True,
            tool_calls=tool_calls,
            content=content if content else None,
        )

    except Exception:
        logger.exception("Error in extracting tool call from response.")
        return plain_content()
| 295 | |
def extract_tool_calls_streaming(
    self,
    previous_text: str,
    current_text: str,
    delta_text: str,
    previous_token_ids: Sequence[int],
    current_token_ids: Sequence[int],
    delta_token_ids: Sequence[int],
    request: ChatCompletionRequest,
) -> Union[DeltaMessage, None]:
    """Incrementally convert streamed model text into tool-call deltas.

    The method is called once per generated chunk. It re-reads the current
    tool call from ``current_text`` on every call and emits, in order: a
    header delta (id, name, empty arguments), then argument fragments that
    concatenate to one JSON object (``{``, one ``"name": value`` fragment
    per completed parameter, ``}``). Several fragments may be combined in a
    single delta when a chunk carries more than one token.

    Args:
        previous_text: Text generated before this chunk; empty on the first
            chunk of a message, which triggers a state reset.
        current_text: All text generated so far, including this chunk.
        delta_text: Text of this chunk.
        previous_token_ids: Unused.
        current_token_ids: Unused.
        delta_token_ids: Token ids of this chunk; used to spot the
            tool-call sentinel tokens.
        request: Chat request whose ``tools`` drive type conversion.

    Returns:
        A ``DeltaMessage`` with plain content or tool-call data, or ``None``
        when this chunk produces nothing to send.
    """
    # Store request for type conversion.
    if not previous_text:
        self._reset_streaming_state()
        self.streaming_request = request

    # A chunk without text only matters when it may be the end of stream:
    # an empty content delta lets the caller finish the response.
    if not delta_text:
        if (delta_token_ids
                and self.tool_call_end_token_id not in delta_token_ids):
            complete_calls = len(
                self.tool_call_complete_regex.findall(current_text))
            if complete_calls > 0 and len(self.prev_tool_call_arr) > 0:
                open_calls = (
                    current_text.count(self.tool_call_start_token) -
                    current_text.count(self.tool_call_end_token))
                if open_calls == 0:
                    return DeltaMessage(content="")
            elif not self.is_tool_call_started and current_text:
                # Regular content response that is now complete.
                return DeltaMessage(content="")
        return None

    self.accumulated_text = current_text

    # The arguments JSON of the current tool was closed; once its
    # </tool_call> has arrived, move on to the next tool call.
    if self.json_closed and not self.in_function:
        tool_ends = current_text.count(self.tool_call_end_token)
        if tool_ends > self.current_tool_index:
            self.current_tool_index += 1
            self.header_sent = False
            self.param_count = 0
            self.json_started = False
            self.json_closed = False
            self.accumulated_params = {}

            if self.current_tool_index >= current_text.count(
                    self.tool_call_start_token):
                # No further tool call has started yet.
                self.is_tool_call_started = False
            return None

    # Plain content before / between tool calls.
    if not self.is_tool_call_started:
        if (self.tool_call_start_token_id in delta_token_ids
                or self.tool_call_start_token in delta_text):
            self.is_tool_call_started = True
            if self.tool_call_start_token in delta_text:
                content_before = delta_text[:delta_text.index(
                    self.tool_call_start_token)]
                if content_before:
                    return DeltaMessage(content=content_before)
            return None
        # Swallow whitespace that directly follows a finished tool call.
        if (current_text.rstrip().endswith(self.tool_call_end_token)
                and delta_text.strip() == ""):
            return None
        return DeltaMessage(content=delta_text)

    # Locate the text of the tool call currently being streamed.
    tool_starts = self._find_token_offsets(current_text,
                                           self.tool_call_start_token)
    if self.current_tool_index >= len(tool_starts):
        return None

    tool_start_idx = tool_starts[self.current_tool_index]
    tool_end_idx = current_text.find(self.tool_call_end_token,
                                     tool_start_idx)
    tool_call_complete = tool_end_idx != -1
    if tool_call_complete:
        tool_text = current_text[tool_start_idx:tool_end_idx +
                                 len(self.tool_call_end_token)]
    else:
        tool_text = current_text[tool_start_idx:]

    # Function header: wait until "<function=name>" is complete.
    if not self.header_sent:
        prefix_idx = tool_text.find(self.tool_call_prefix)
        if prefix_idx == -1:
            return None
        name_start = prefix_idx + len(self.tool_call_prefix)
        name_end = tool_text.find(">", name_start)
        if name_end == -1:
            return None

        self.current_function_name = tool_text[name_start:name_end]
        self.current_tool_id = self._generate_tool_call_id()
        self.header_sent = True
        self.in_function = True

        # Record the call immediately so the caller can see a tool call
        # even if argument parsing never completes. The entry is keyed by
        # position, not by name, so repeated calls to the same function
        # each get their own slot.
        entry = {
            "name": self.current_function_name,
            "arguments": "{}",  # Placeholder, updated on function end.
        }
        if self.current_tool_index < len(self.prev_tool_call_arr):
            self.prev_tool_call_arr[self.current_tool_index] = entry
        else:
            self.prev_tool_call_arr.append(entry)

        return DeltaMessage(tool_calls=[
            DeltaToolCall(
                index=self.current_tool_index,
                id=self.current_tool_id,
                function=DeltaFunctionCall(
                    name=self.current_function_name, arguments=""),
                type="function",
            )
        ])

    # Header sent but function already closed: waiting for </tool_call>.
    if not self.in_function:
        return None

    fragments: list[str] = []

    # The opening brace is always emitted first, even when this chunk
    # already contains the start of a parameter.
    if not self.json_started:
        self.json_started = True
        fragments.append("{")

    # Determine where the function body ends. A finished tool call
    # without </function> is treated as an implicitly closed function.
    func_start = tool_text.find(self.tool_call_prefix) + len(
        self.tool_call_prefix)
    func_end = tool_text.find(self.function_end_token, func_start)
    if func_end == -1 and tool_call_complete:
        func_end = len(tool_text) - len(self.tool_call_end_token)
    function_done = func_end != -1
    body_text = tool_text[:func_end] if function_done else tool_text

    # Emit every parameter that is complete and not sent yet, BEFORE
    # closing the object, so none is lost when tags arrive together.
    param_starts = self._find_token_offsets(body_text, self.parameter_prefix)
    while self.param_count < len(param_starts):
        fragment = self._next_param_fragment(
            body_text, param_starts[self.param_count], function_done)
        if fragment is None:
            break
        fragments.append(fragment)

    if function_done and not self.json_closed:
        self._record_final_arguments(tool_text[func_start:func_end])
        fragments.append("}")
        # Reset state for next tool.
        self.in_function = False
        self.json_closed = True
        self.accumulated_params = {}

    if not fragments:
        return None
    return DeltaMessage(tool_calls=[
        DeltaToolCall(
            index=self.current_tool_index,
            function=DeltaFunctionCall(arguments="".join(fragments)),
        )
    ])

def _find_token_offsets(self, text: str, token: str) -> list[int]:
    """Return the start offsets of all non-overlapping ``token`` matches."""
    offsets = []
    cursor = text.find(token)
    while cursor != -1:
        offsets.append(cursor)
        cursor = text.find(token, cursor + len(token))
    return offsets

def _next_param_fragment(self, body_text: str, param_idx: int,
                         value_is_final: bool) -> Optional[str]:
    """Return the JSON fragment for the parameter tag at ``param_idx``.

    ``body_text`` is the tool-call text up to (not including) the function
    end. The value ends at the earliest ``</parameter>`` or next
    ``<parameter=``; if neither is present it runs to the end of
    ``body_text`` when ``value_is_final`` is true. Returns ``None`` while
    the parameter is still incomplete. On success ``param_count`` is
    advanced and the raw value is stored in ``accumulated_params``.
    """
    name_start = param_idx + len(self.parameter_prefix)
    name_end = body_text.find(">", name_start)
    if name_end == -1:
        return None
    self.current_param_name = body_text[name_start:name_end]

    value_text = body_text[name_end + 1:]
    if value_text.startswith("\n"):
        value_text = value_text[1:]

    terminators = [
        pos for pos in (value_text.find(self.parameter_end_token),
                        value_text.find(self.parameter_prefix)) if pos != -1
    ]
    if terminators:
        value_end = min(terminators)
    elif value_is_final:
        value_end = len(value_text)
    else:
        # Still streaming, wait for more content.
        return None

    param_value = value_text[:value_end]
    if param_value.endswith("\n"):
        param_value = param_value[:-1]

    # Store raw value, then convert it using the tool schema.
    self.accumulated_params[self.current_param_name] = param_value
    tools = (self.streaming_request.tools
             if self.streaming_request else None)
    param_config = self._get_arguments_config(self.current_function_name,
                                              tools)
    converted_value = self._convert_param_value(param_value,
                                                self.current_param_name,
                                                param_config,
                                                self.current_function_name)

    # json.dumps on the key keeps the fragment valid JSON even if the
    # name contains quotes or backslashes.
    serialized_name = json.dumps(self.current_param_name,
                                 ensure_ascii=False)
    serialized_value = json.dumps(converted_value, ensure_ascii=False)
    separator = ", " if self.param_count > 0 else ""
    self.param_count += 1
    return f"{separator}{serialized_name}: {serialized_value}"

def _record_final_arguments(self, func_content: str) -> None:
    """Store the fully parsed arguments of the current tool call.

    Best effort: a parsing failure leaves the placeholder entry in
    ``prev_tool_call_arr`` untouched and is only logged.
    """
    tools = (self.streaming_request.tools
             if self.streaming_request else None)
    try:
        parsed_tool = self._parse_xml_function_call(func_content, tools)
    except Exception:
        logger.debug(
            "Could not parse final arguments of streamed tool call.",
            exc_info=True)
        return
    if parsed_tool and self.current_tool_index < len(
            self.prev_tool_call_arr):
        self.prev_tool_call_arr[self.current_tool_index][
            "arguments"] = parsed_tool.function.arguments
| 690 | |