Update ArchFunctionHandler

This commit is contained in:
Shuguang Chen 2024-12-08 16:41:45 -08:00
parent 95e167c2f6
commit e0d4ee7357

View file

@ -4,7 +4,7 @@ import builtins
import textwrap
from openai import OpenAI
from typing import Any, Dict, List, Tuple, Union
from typing import Any, Dict, List
from overrides import override
from src.core.model_utils import (
Message,
@ -299,9 +299,7 @@ class ArchFunctionHandler(ArchBaseHandler):
# Attempt to parse the corrected string to ensure it's valid JSON
return fixed_str.replace("'", '"')
def _extract_tool_calls(
self, content: str
) -> Tuple[List[Dict[str, Any]], bool, Union[str, Exception]]:
def _extract_tool_calls(self, content: str) -> Dict[str, Any]:
"""
Extracts tool call information from a given string.
@ -309,10 +307,10 @@ class ArchFunctionHandler(ArchBaseHandler):
content (str): The content string containing potential tool call information.
Returns:
Tuple[List[Dict[str, Any]], bool, Union[str, Exception]]:
- A list of tool call dictionaries.
- A boolean indicating if the extraction was valid.
- An error message or exception if extraction failed.
Dict: A dictionary of extraction, including:
- "result": A list of tool call dictionaries.
- "status": A boolean indicating if the extraction was valid.
- "message": An error message or exception if extraction failed.
"""
tool_calls, is_valid, error_message = [], True, ""
@ -348,11 +346,11 @@ class ArchFunctionHandler(ArchBaseHandler):
flag = False
return tool_calls, is_valid, error_message
return {"result": tool_calls, "status": is_valid, "message": error_message}
def _verify_tool_calls(
self, tools: List[Dict[str, Any]], tool_calls: List[Dict[str, Any]]
) -> Tuple[bool, Union[Dict[str, Any], None], str]:
) -> Dict[str, Any]:
"""
Verifies the validity of extracted tool calls against the provided tools.
@ -361,13 +359,13 @@ class ArchFunctionHandler(ArchBaseHandler):
tool_calls (List[Dict[str, Any]]): A list of tool calls to verify.
Returns:
Tuple[bool, Union[Dict[str, Any], None], str]:
- A boolean indicating if the tool calls are valid.
- The invalid tool call dictionary if any.
- An error message.
Dict: A dictionary of verification, including:
- "status": A boolean indicating if the tool calls are valid.
- "invalid_tool_call": A dictionary of the invalid tool call if any.
- "message": An error message.
"""
is_valid, error_tool_call, error_message = True, None, ""
is_valid, invalid_tool_call, error_message = True, None, ""
functions = {}
for tool in tools:
@ -390,9 +388,9 @@ class ArchFunctionHandler(ArchBaseHandler):
for required_param in functions[func_name].get("required", []):
if required_param not in func_args:
is_valid = False
error_tool_call = tool_call
invalid_tool_call = tool_call
error_message = f"`{required_param}` is required by the function `{func_name}` but not found in the tool call!"
return is_valid, error_tool_call, error_message
return {"status": is_valid, "invalid_tool_call": invalid_tool_call, "message": error_message}
# Verify the data type of each parameter in the tool calls
for param_name, param_value in func_args:
@ -403,46 +401,36 @@ class ArchFunctionHandler(ArchBaseHandler):
param_value, self.support_data_types[data_type]
):
is_valid = False
error_tool_call = tool_call
invalid_tool_call = tool_call
error_message = f"Parameter `{param_name}` is expected to have the data type `{self.support_data_types[data_type]}`, but got `{type(param_value)}`."
return is_valid, error_tool_call, error_message
return {"status": is_valid, "invalid_tool_call": invalid_tool_call, "message": error_message}
return is_valid, error_tool_call, error_message
return {
"status": is_valid,
"invalid_tool_call": invalid_tool_call,
"message": error_message,
}
def _prefill_response(self, messages: List[Dict[str, str]]):
def _add_prefill_message(self, messages: List[Dict[str, str]]):
"""
Prefills the response with the tool call prefix.
Return a new message list with a randomly chosen assistant prefill prefix appended, for prompt prefilling.
Args:
messages (List[Dict[str, str]]): A list of messages.
tools (List[Dict[str, Any]]): A list of tools.
Returns:
List[Dict[str, str]]: A list of messages with the prefill prefix.
prefill_messages (List[Dict[str, str]]): The input messages followed by the assistant prefill message.
"""
messages.append(
return messages + [
{
"role": "assistant",
"content": random.choice(self.prefill_prefix),
}
)
prefill_response = self.client.chat.completions.create(
messages=messages,
model=self.model_name,
stream=False,
extra_body={
**self.generation_params,
**self.prefill_params,
},
)
return prefill_response
]
@override
async def chat_completion(
self, req: ChatMessage, enable_prefilling=True
) -> ChatCompletionResponse:
async def chat_completion(self, req: ChatMessage) -> ChatCompletionResponse:
"""
Generates a chat completion response for a given request.
@ -459,60 +447,56 @@ class ArchFunctionHandler(ArchBaseHandler):
messages = self._process_messages(req.messages, req.tools)
# Retrieve the first token, handling the Stream object carefully
# always enable `stream=True` to collect model responses
response = self.client.chat.completions.create(
messages=messages,
model=self.model_name,
stream=enable_prefilling,
stream=True,
extra_body=self.generation_params,
)
model_response = ""
model_response, has_tool_call = "", None
if enable_prefilling:
has_tool_call = None
for token in response:
token_content = token.choices[0].delta.content.strip()
model_response = ""
for token in response:
token_content = (token.choices[0].delta.content or "").strip()
if token_content:
if has_tool_call is None and token_content != "<tool_call>":
has_tool_call = False
response.close()
break
else:
has_tool_call = True
if token_content:
if has_tool_call is None and token_content != "<tool_call>":
has_tool_call = False
response.close()
break
else:
has_tool_call = True
if has_tool_call is True:
model_response += token_content
if has_tool_call is True:
model_response += token_content
# start parameter gathering if the model is not generating tool calls
if has_tool_call is False:
prefill_response = self.client.chat.completions.create(
messages=self._add_prefill_message(messages),
model=self.model_name,
extra_body={
**self.generation_params,
**self.prefill_params,
},
)
model_response = prefill_response.choices[0].message.content
# start parameter gathering if the model is not generating a tool call
if has_tool_call is False:
prefill_response = self._prefill_response(messages)
model_response = prefill_response.choices[0].message.content
else:
model_response = response.choices[0].message.content
# Extract tool calls from model response
extracted = self._extract_tool_calls(model_response)
(
tool_calls,
extraction_status,
extraction_error_message,
) = self._extract_tool_calls(model_response)
if tool_calls:
if extracted["result"]:
# [TODO] Review: define the behavior in the case that tool call extraction fails
# if not extraction_status:
# if not extracted["status"]:
(
verification_status,
invalid_tool_call,
verification_error_message,
) = self._verify_tool_calls(tools=req.tools, tool_calls=tool_calls)
verified = self._verify_tool_calls(
tools=req.tools, tool_calls=extracted["result"]
)
# [TODO] Review: In the case that tool calls are invalid, define the protocol to collect debugging output and the behavior to handle it appropriately
if verification_status:
model_response = Message(content="", tool_calls=tool_calls)
if verified["status"]:
model_response = Message(content="", tool_calls=extracted["result"])
# else:
else: