From e0d4ee73571ae2976d0e2115e4aee64dc51a7161 Mon Sep 17 00:00:00 2001 From: Shuguang Chen <54548843+nehcgs@users.noreply.github.com> Date: Sun, 8 Dec 2024 16:41:45 -0800 Subject: [PATCH] Update `ArchFunctionHandler` --- model_server/src/core/function_calling.py | 140 ++++++++++------------ 1 file changed, 62 insertions(+), 78 deletions(-) diff --git a/model_server/src/core/function_calling.py b/model_server/src/core/function_calling.py index 0352ab47..f652fa22 100644 --- a/model_server/src/core/function_calling.py +++ b/model_server/src/core/function_calling.py @@ -4,7 +4,7 @@ import builtins import textwrap from openai import OpenAI -from typing import Any, Dict, List, Tuple, Union +from typing import Any, Dict, List from overrides import override from src.core.model_utils import ( Message, @@ -299,9 +299,7 @@ class ArchFunctionHandler(ArchBaseHandler): # Attempt to parse the corrected string to ensure it’s valid JSON return fixed_str.replace("'", '"') - def _extract_tool_calls( - self, content: str - ) -> Tuple[List[Dict[str, Any]], bool, Union[str, Exception]]: + def _extract_tool_calls(self, content: str) -> Dict[str, Any]: """ Extracts tool call information from a given string. @@ -309,10 +307,10 @@ class ArchFunctionHandler(ArchBaseHandler): content (str): The content string containing potential tool call information. Returns: - Tuple[List[Dict[str, Any]], bool, Union[str, Exception]]: - - A list of tool call dictionaries. - - A boolean indicating if the extraction was valid. - - An error message or exception if extraction failed. + Dict: A dictionary of extraction, including: + - "result": A list of tool call dictionaries. + - "status": A boolean indicating if the extraction was valid. + - "message": An error message or exception if extraction failed. 
""" tool_calls, is_valid, error_message = [], True, "" @@ -348,11 +346,11 @@ class ArchFunctionHandler(ArchBaseHandler): flag = False - return tool_calls, is_valid, error_message + return {"result": tool_calls, "status": is_valid, "message": error_message} def _verify_tool_calls( self, tools: List[Dict[str, Any]], tool_calls: List[Dict[str, Any]] - ) -> Tuple[bool, Union[Dict[str, Any], None], str]: + ) -> Dict[str, Any]: """ Verifies the validity of extracted tool calls against the provided tools. @@ -361,13 +359,13 @@ class ArchFunctionHandler(ArchBaseHandler): tool_calls (List[Dict[str, Any]]): A list of tool calls to verify. Returns: - Tuple[bool, Union[Dict[str, Any], None], str]: - - A boolean indicating if the tool calls are valid. - - The invalid tool call dictionary if any. - - An error message. + Dict: A dictionary of verification, including: + - "status": A boolean indicating if the tool calls are valid. + - "invalid_tool_call": A dictionary of the invalid tool call if any. + - "message": An error message. """ - is_valid, error_tool_call, error_message = True, None, "" + is_valid, invalid_tool_call, error_message = True, None, "" functions = {} for tool in tools: @@ -390,9 +388,9 @@ class ArchFunctionHandler(ArchBaseHandler): for required_param in functions[func_name].get("required", []): if required_param not in func_args: is_valid = False - error_tool_call = tool_call + invalid_tool_call = tool_call error_message = f"`{required_param}` is requiried by the function `{func_name}` but not found in the tool call!" 
- return is_valid, error_tool_call, error_message + return {"status": is_valid, "invalid_tool_call": invalid_tool_call, "message": error_message} # Verify the data type of each parameter in the tool calls for param_name, param_value in func_args: @@ -403,46 +401,36 @@ class ArchFunctionHandler(ArchBaseHandler): param_value, self.support_data_types[data_type] ): is_valid = False - error_tool_call = tool_call + invalid_tool_call = tool_call error_message = f"Parameter `{param_name}` is expected to have the data type `{self.support_data_types[data_type]}`, but got `{type(param_value)}`." - return is_valid, error_tool_call, error_message + return {"status": is_valid, "invalid_tool_call": invalid_tool_call, "message": error_message} - return is_valid, error_tool_call, error_message + return { + "status": is_valid, + "invalid_tool_call": invalid_tool_call, + "message": error_message, + } - def _prefill_response(self, messages: List[Dict[str, str]]): + def _add_prefill_message(self, messages: List[Dict[str, str]]): """ - Prefills the response with the tool call prefix. + Return a new message list with an assistant prefill message appended (the input list is not mutated). Args: messages (List[Dict[str, str]]): A list of messages. - tools (List[Dict[str, Any]]): A list of tools. Returns: - List[Dict[str, str]]: A list of messages with the prefill prefix. + List[Dict[str, str]]: The input messages followed by one assistant message whose content is a randomly chosen prefill prefix. """ - messages.append( + return messages + [ { "role": "assistant", "content": random.choice(self.prefill_prefix), } - ) - prefill_response = self.client.chat.completions.create( - messages=messages, - model=self.model_name, - stream=False, - extra_body={ - **self.generation_params, - **self.prefill_params, - }, - ) - - return prefill_response + ] @override - async def chat_completion( - self, req: ChatMessage, enable_prefilling=True - ) -> ChatCompletionResponse: + async def chat_completion(self, req: ChatMessage) -> ChatCompletionResponse: """ Generates a chat completion response for a given request. 
@@ -459,60 +447,56 @@ class ArchFunctionHandler(ArchBaseHandler): messages = self._process_messages(req.messages, req.tools) - # Retrieve the first token, handling the Stream object carefully + # always enable `stream=True` to collect model responses response = self.client.chat.completions.create( messages=messages, model=self.model_name, - stream=enable_prefilling, + stream=True, extra_body=self.generation_params, ) - model_response = "" + model_response, has_tool_call = "", None - if enable_prefilling: - has_tool_call = None + for token in response: + token_content = (token.choices[0].delta.content or "").strip() - model_response = "" - for token in response: - token_content = token.choices[0].delta.content.strip() + if token_content: + if has_tool_call is None and token_content != "": + has_tool_call = False + response.close() + break + else: + has_tool_call = True - if token_content: - if has_tool_call is None and token_content != "": - has_tool_call = False - response.close() - break - else: - has_tool_call = True + if has_tool_call is True: + model_response += token_content - if has_tool_call is True: - model_response += token_content + # start parameter gathering if the model is not generating tool calls + if has_tool_call is False: + prefill_response = self.client.chat.completions.create( + messages=self._add_prefill_message(messages), + model=self.model_name, + extra_body={ + **self.generation_params, + **self.prefill_params, + }, + ) + model_response = prefill_response.choices[0].message.content - # start parameter gathering if the model is not generating a tool call - if has_tool_call is False: - prefill_response = self._prefill_response(messages) - model_response = prefill_response.choices[0].message.content - else: - model_response = response.choices[0].message.content + # Extract tool calls from model response + extracted = self._extract_tool_calls(model_response) - ( - tool_calls, - extraction_status, - extraction_error_message, - ) = 
self._extract_tool_calls(model_response) - - if tool_calls: + if extracted["result"]: # [TODO] Review: define the behavior in the case that tool call extraction fails - # if not extraction_status: + # if not extracted["status"]: - ( - verification_status, - invalid_tool_call, - verification_error_message, - ) = self._verify_tool_calls(tools=req.tools, tool_calls=tool_calls) + verified = self._verify_tool_calls( + tools=req.tools, tool_calls=extracted["result"] + ) # [TODO] Review: In the case that tool calls are invalid, define the protocol to collect debugging output and the behavior to handle it appropriately - if verification_status: - model_response = Message(content="", tool_calls=tool_calls) + if verified["status"]: + model_response = Message(content="", tool_calls=extracted["result"]) # else: else: