Last active
January 10, 2025 22:26
-
-
Save dranger003/5f5daa9c80b4193f180b93f71399f817 to your computer and use it in GitHub Desktop.
A vLLM tool parser plugin for Command-R7B that handles function calling through action blocks.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # ** HOW TO USE ** | |
| # python -m vllm.entrypoints.openai.api_server \ | |
| # --pipeline-parallel-size "$GPU_COUNT" \ | |
| # --api-key "$API_KEY" \ | |
| # --model CohereForAI/c4ai-command-r7b-12-2024 \ | |
| # --chat-template c4ai-command-r7b-12-2024-tool_use.jinja \ | |
| # --chat-template-content-format string \ | |
| # --enable-auto-tool-choice \ | |
| # --tool-parser-plugin vllm-tool-parser-plugin-command-r7b.py \ | |
| # --tool-call-parser command-r7b | |
| import json | |
| import re | |
| from typing import Dict, List, Optional, Sequence, Union, Any | |
| import partial_json_parser | |
| from partial_json_parser.core.options import Allow | |
| from vllm.entrypoints.openai.tool_parsers.utils import ( | |
| partial_json_loads, | |
| is_complete_json, | |
| ) | |
| from vllm.entrypoints.openai.protocol import ( | |
| ChatCompletionRequest, | |
| ExtractedToolCallInformation, | |
| DeltaMessage, | |
| DeltaToolCall, | |
| DeltaFunctionCall, | |
| ToolCall, | |
| FunctionCall, | |
| ) | |
| from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( | |
| ToolParser, | |
| ToolParserManager, | |
| ) | |
| from vllm.transformers_utils.tokenizer import AnyTokenizer | |
| from vllm.logger import init_logger | |
| from vllm.utils import random_uuid | |
| logger = init_logger(__name__) | |
@ToolParserManager.register_module(["command-r7b"])
class CommandR7BToolParser(ToolParser):
    """
    A tool parser for Command-R7B that handles function calling through action blocks.

    This parser:
    - Removes response wrapper tokens (START_RESPONSE/END_RESPONSE)
    - Optionally removes thinking plans (START_THINKING/END_THINKING)
    - Looks for tool calls between START_ACTION and END_ACTION tokens
    - Handles both streaming and non-streaming parsing modes
    - Supports both single tool calls and arrays of tool calls

    Expected JSON format within action blocks:
    {
        "tool_call_id": "string",
        "tool_name": "string",
        "parameters": {
            // Tool-specific parameters
        }
    }
    """

    # Special control tokens emitted by the Command-R7B chat template.
    START_ACTION_TOKEN = "<|START_ACTION|>"
    END_ACTION_TOKEN = "<|END_ACTION|>"
    START_RESPONSE_TOKEN = "<|START_RESPONSE|>"
    END_RESPONSE_TOKEN = "<|END_RESPONSE|>"
    START_THINKING_TOKEN = "<|START_THINKING|>"
    END_THINKING_TOKEN = "<|END_THINKING|>"

    def __init__(self, tokenizer: AnyTokenizer, remove_thinking: bool = True):
        """
        Initialize the parser with necessary state tracking and token validation.

        Args:
            tokenizer: The tokenizer to use
            remove_thinking: Whether to remove thinking plans before tool calls

        Raises:
            RuntimeError: If any required control token is missing from the
                tokenizer vocabulary (streaming detection relies on token ids).
        """
        super().__init__(tokenizer)
        self.remove_thinking = remove_thinking

        # Initialize state tracking
        self.reset_state()

        # Validate that each start/end control token pair exists in the
        # vocabulary; extract_tool_calls_streaming matches token ids, not text.
        token_pairs = [
            (self.START_ACTION_TOKEN, self.END_ACTION_TOKEN),
            (self.START_RESPONSE_TOKEN, self.END_RESPONSE_TOKEN),
            (self.START_THINKING_TOKEN, self.END_THINKING_TOKEN),
        ]
        self.token_ids: Dict[str, int] = {}
        for start_token, end_token in token_pairs:
            start_id = self.vocab.get(start_token)
            end_id = self.vocab.get(end_token)
            if None in (start_id, end_id):
                raise RuntimeError(
                    f"Command-R7B parser could not locate {start_token} "
                    f"or {end_token} in tokenizer vocabulary"
                )
            self.token_ids[start_token] = start_id
            self.token_ids[end_token] = end_id

        # Buffer for partial JSON parsing (also cleared by reset_state()).
        self.partial_json_buffer = ""
        self.last_complete_json = None

    def reset_state(self) -> None:
        """Reset parser state between requests."""
        self.in_action_block = False
        self.in_thinking_block = False
        self.current_block_token_ids: List[int] = []
        self.tool_call_index = 0
        self.current_block_content = ""
        self.available_tools: Dict[str, Dict] = {}
        self.partial_json_buffer = ""
        self.last_complete_json = None
        self.current_tool_name_sent = False
        self.streamed_args_for_tool: List[str] = []

    def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest:
        """Prepare for parsing by storing available tools and resetting state."""
        self.reset_state()
        if request.tools:
            self.available_tools = {
                tool.function.name: tool.function.model_dump() for tool in request.tools
            }
        return request

    def clean_response_text(self, text: str) -> str:
        """Remove response wrapper tokens from text, keeping only the wrapped body."""
        if self.START_RESPONSE_TOKEN in text and self.END_RESPONSE_TOKEN in text:
            parts = text.split(self.START_RESPONSE_TOKEN, 1)
            if len(parts) > 1:
                text = parts[1]
                if self.END_RESPONSE_TOKEN in text:
                    text = text.split(self.END_RESPONSE_TOKEN, 1)[0]
        return text.strip()

    def clean_thinking_text(self, text: str) -> str:
        """Remove thinking-plan sections if configured to do so."""
        if not self.remove_thinking:
            return text
        # Excise every START_THINKING..END_THINKING span, including the tokens.
        while self.START_THINKING_TOKEN in text and self.END_THINKING_TOKEN in text:
            start_idx = text.find(self.START_THINKING_TOKEN)
            end_idx = text.find(self.END_THINKING_TOKEN, start_idx) + len(
                self.END_THINKING_TOKEN
            )
            text = text[:start_idx] + text[end_idx:]
        return text.strip()

    def validate_tool_call(self, tool_name: str, parameters: Dict) -> bool:
        """
        Validate a tool call against the tools declared in the request.

        NOTE: may mutate `parameters` in place, renaming an "arguments" key
        to "parameters" (mirrors the InternLM2 parser's behavior).
        """
        if not self.available_tools:
            return True  # No tools specified in request
        if tool_name not in self.available_tools:
            logger.warning("Tool '%s' not found in available tools", tool_name)
            return False

        # Check parameters format
        if not isinstance(parameters, dict):
            logger.warning("Parameters must be a dictionary for tool '%s'", tool_name)
            return False

        # Handle both "parameters" and "arguments" fields like InternLM2
        if "arguments" in parameters and "parameters" not in parameters:
            parameters["parameters"] = parameters.pop("arguments")

        # Validate unicode content
        try:
            json.dumps(parameters, ensure_ascii=False)
        except UnicodeEncodeError:
            logger.warning("Invalid Unicode in parameters for tool '%s'", tool_name)
            return False
        return True

    def parse_action_block(
        self, block_text: str, streaming: bool = False
    ) -> List[Union[ToolCall, DeltaToolCall]]:
        """
        Parse a complete action block into tool calls.

        Args:
            block_text: The text to parse (JSON object or array of objects)
            streaming: Whether this is being called in streaming mode

        Returns:
            A list of validated tool calls; empty on JSON errors or when
            every entry fails validation.
        """
        try:
            parsed = json.loads(block_text)
            if not isinstance(parsed, list):
                parsed = [parsed]

            tool_calls = []
            for call_dict in parsed:
                # Extract and validate required fields. The model's
                # "tool_call_id" field is ignored; a fresh id is generated.
                tool_name = call_dict.get("tool_name")
                parameters = call_dict.get("parameters", {})

                if not tool_name:
                    logger.warning("Tool call missing required 'tool_name' field")
                    continue
                if not self.validate_tool_call(tool_name, parameters):
                    continue

                # Create the tool call
                function_call = FunctionCall(
                    name=tool_name, arguments=json.dumps(parameters)
                )
                tool_call = ToolCall(
                    id=f"chatcmpl-tool-{random_uuid()}",
                    type="function",
                    function=function_call,
                )
                tool_calls.append(tool_call)
            return tool_calls
        except json.JSONDecodeError as e:
            logger.error("Failed to parse action block JSON: %s", e)
            return []

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """
        Extract tool calls from complete model output.

        Handles complete responses by:
        1. Removing response wrapper tokens
        2. Optionally removing thinking plans
        3. Finding all action blocks and parsing their content
        """
        # Clean the output text
        model_output = self.clean_response_text(model_output)
        model_output = self.clean_thinking_text(model_output)

        # Regex for capturing everything between action markers
        action_pattern = f"{self.START_ACTION_TOKEN}(.*?){self.END_ACTION_TOKEN}"
        action_blocks = re.findall(action_pattern, model_output, re.DOTALL)

        all_tool_calls = []
        for block in action_blocks:
            block = block.strip()
            if not block:
                continue
            tool_calls = self.parse_action_block(block)
            all_tool_calls.extend(tool_calls)

        # Extract content before the first action block
        content = None
        if self.START_ACTION_TOKEN in model_output:
            content = model_output.split(self.START_ACTION_TOKEN)[0]
            if content.strip():
                content = content.strip()
            else:
                content = None
        else:
            content = model_output

        return ExtractedToolCallInformation(
            tools_called=bool(all_tool_calls),
            tool_calls=all_tool_calls,
            content=content,
        )

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Optional[DeltaMessage]:
        """
        Handle streaming extraction of tool calls.

        Accumulates tokens within action blocks and parses them when complete.
        Returns appropriate streaming deltas for tool calls and content, or
        None to suppress a chunk (wrapper tokens, thinking text, partial blocks).
        """
        # Check for response tokens to remove
        if self.token_ids[self.START_RESPONSE_TOKEN] in delta_token_ids:
            return None
        if self.token_ids[self.END_RESPONSE_TOKEN] in delta_token_ids:
            return None

        # Check for thinking tokens to potentially remove
        if self.remove_thinking:
            if self.token_ids[self.START_THINKING_TOKEN] in delta_token_ids:
                self.in_thinking_block = True
                return None
            if self.token_ids[self.END_THINKING_TOKEN] in delta_token_ids:
                self.in_thinking_block = False
                return None
            if self.in_thinking_block:
                return None

        new_tool_calls = []
        try:
            for tid in delta_token_ids:
                if tid == self.token_ids[self.START_ACTION_TOKEN]:
                    # Entering an action block
                    self.in_action_block = True
                    self.current_block_token_ids = []
                    self.current_block_content = ""
                    # FIX: clear the JSON buffer so text from a previous action
                    # block in this request cannot corrupt parsing of this one.
                    self.partial_json_buffer = ""
                    self.last_complete_json = None
                elif tid == self.token_ids[self.END_ACTION_TOKEN]:
                    # Process complete action block
                    if self.current_block_token_ids:
                        block_text = self.model_tokenizer.decode(
                            self.current_block_token_ids
                        )
                        block_text = block_text.strip()
                        self.partial_json_buffer += block_text

                        if block_text:
                            try:
                                # Try partial JSON parsing first. Withhold
                                # partial strings until the tool name is sent.
                                flags = (
                                    Allow.ALL
                                    if self.current_tool_name_sent
                                    else Allow.ALL & ~Allow.STR
                                )
                                try:
                                    parsed_obj, end_idx = partial_json_loads(
                                        self.partial_json_buffer, flags
                                    )
                                    is_complete = is_complete_json(
                                        self.partial_json_buffer[:end_idx]
                                    )
                                    if isinstance(parsed_obj, (dict, list)):
                                        parsed = (
                                            [parsed_obj]
                                            if isinstance(parsed_obj, dict)
                                            else parsed_obj
                                        )
                                        # If we have a complete JSON object, update our last known good state
                                        if is_complete:
                                            self.last_complete_json = parsed
                                    else:
                                        # Fall back to regular JSON parsing if partial parse gave unexpected type
                                        parsed = json.loads(block_text)
                                        if not isinstance(parsed, list):
                                            parsed = [parsed]
                                except (
                                    partial_json_parser.core.exceptions.MalformedJSON
                                ):
                                    # If partial parsing fails, try regular JSON parse
                                    parsed = json.loads(block_text)
                                    if not isinstance(parsed, list):
                                        parsed = [parsed]

                                for call_dict in parsed:
                                    tool_name = call_dict.get("tool_name")
                                    parameters = call_dict.get("parameters", {})
                                    if not tool_name or not self.validate_tool_call(
                                        tool_name, parameters
                                    ):
                                        continue
                                    delta_fc = DeltaFunctionCall(
                                        name=tool_name, arguments=json.dumps(parameters)
                                    ).model_dump(exclude_none=True)
                                    dtc = DeltaToolCall(
                                        index=self.tool_call_index,
                                        type="function",
                                        id=f"chatcmpl-tool-{random_uuid()}",
                                        function=delta_fc,
                                    )
                                    new_tool_calls.append(dtc)
                                    self.tool_call_index += 1
                            except json.JSONDecodeError as e:
                                logger.error("JSON parsing error in streaming: %s", e)

                    # Reset block state (FIX: including the JSON buffer, so
                    # subsequent action blocks start from a clean slate).
                    self.in_action_block = False
                    self.current_block_token_ids = []
                    self.current_block_content = ""
                    self.partial_json_buffer = ""
                    self.last_complete_json = None
                else:
                    # Accumulate tokens if in action block
                    if self.in_action_block:
                        self.current_block_token_ids.append(tid)
                    else:
                        # Regular content outside action block
                        return DeltaMessage(content=delta_text)

            # Return tool calls if any were found
            if new_tool_calls:
                return DeltaMessage(tool_calls=new_tool_calls)

            # Return None to skip this chunk if we're accumulating an action block
            if self.in_action_block:
                return None

            # Otherwise return regular content
            return DeltaMessage(content=delta_text)

        except Exception:
            logger.exception("Error in streaming tool call extraction")
            # Reset state and skip chunk on error
            self.reset_state()
            return None
Author
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
C4AI Command R7B emits wrapper response tokens, which this plugin removes. Also, when performing tool calls the model emits its plan ahead of the calls using wrapper thinking tokens. To keep the thinking plan in the model's response, change the default in the parser's `__init__` signature from
`remove_thinking: bool = True` to `remove_thinking: bool = False`.