2024-12-04 16:41:30 -08:00
|
|
|
|
import json
|
|
|
|
|
|
import random
|
|
|
|
|
|
import builtins
|
|
|
|
|
|
|
|
|
|
|
|
from openai import OpenAI
|
2024-12-05 11:00:22 -08:00
|
|
|
|
from typing import Any, Dict, List, Tuple, Union
|
|
|
|
|
|
from overrides import override
|
|
|
|
|
|
from app.model_handler.base_handler import (
|
|
|
|
|
|
Message,
|
|
|
|
|
|
ChatMessage,
|
|
|
|
|
|
Choice,
|
|
|
|
|
|
ChatCompletionResponse,
|
|
|
|
|
|
ArchBaseHandler,
|
|
|
|
|
|
)
|
2024-12-04 16:41:30 -08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Names of the Python builtin types that tool-call argument values can be
# checked against. ArchFunctionHandler resolves each name through the
# `builtins` module to obtain the actual type object.
SUPPORT_DATA_TYPES = ["int", "float", "bool", "str", "list", "tuple", "set", "dict"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ArchIntentHandler(ArchBaseHandler):
|
|
|
|
|
|
def __init__(
    self,
    client: OpenAI,
    model_name: str,
    task_prompt: str,
    tool_prompt: str,
    format_prompt: str,
    extra_instruction: str,
    generation_params: Dict,
):
    """
    Set up the intent handler on top of the shared base handler.

    Args:
        client (OpenAI): An OpenAI client instance.
        model_name (str): Name of the model to use.
        task_prompt (str): The main task prompt for the system.
        tool_prompt (str): A prompt to describe tools.
        format_prompt (str): A prompt specifying the desired output format.
        extra_instruction (str): Instructions specific to intent handling;
            stored on the instance and forwarded when messages are processed.
        generation_params (Dict): Generation parameters for the model.
    """

    # Everything except `extra_instruction` is handled by the base class.
    base_args = (
        client,
        model_name,
        task_prompt,
        tool_prompt,
        format_prompt,
        generation_params,
    )
    super().__init__(*base_args)

    self.extra_instruction = extra_instruction
|
2024-12-04 16:41:30 -08:00
|
|
|
|
|
|
|
|
|
|
@override
|
2024-12-05 11:00:22 -08:00
|
|
|
|
def _convert_tools(self, tools: List[Dict[str, Any]]) -> str:
    """
    Serializes tools as one JSON object per line, each tagged with an index.

    Args:
        tools (List[Dict[str, Any]]): A list of tools represented as dictionaries.

    Returns:
        str: Newline-separated JSON objects. Each object starts with an
            "index" entry of the form "T<position>"; an "index" key already
            present in the tool takes precedence over the generated one.
    """

    lines = []
    for position, tool in enumerate(tools):
        tagged = {"index": f"T{position}"} | tool
        lines.append(json.dumps(tagged))

    return "\n".join(lines)
|
|
|
|
|
|
|
|
|
|
|
|
@override
|
2024-12-05 11:00:22 -08:00
|
|
|
|
async def chat_completion(self, req: ChatMessage) -> ChatCompletionResponse:
    """
    Generates a chat completion for a given request.

    The request messages and tools are turned into model input together
    with the handler's `extra_instruction`, a single non-streaming
    completion is requested, and the text of the first choice is wrapped in
    a `ChatCompletionResponse` with no tool calls.

    Args:
        req (ChatMessage): A chat message request object.

    Returns:
        ChatCompletionResponse: The model's response to the chat request.

    Note:
        Currently only support vllm inference
    """

    messages = self._process_messages(
        req.messages, req.tools, self.extra_instruction
    )

    completion = self.client.chat.completions.create(
        messages=messages,
        model=self.model_name,
        stream=False,
        extra_body=self.generation_params,
    )

    # The client returns a completion object; the reply text lives on the
    # first choice. Passing the whole object as `content` (as before) does
    # not match how the function handler builds its Message.
    reply_text = completion.choices[0].message.content
    message = Message(content=reply_text, tool_calls=[])

    return ChatCompletionResponse(
        choices=[Choice(message=message)], model=self.model_name
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ArchFunctionHandler(ArchBaseHandler):
|
|
|
|
|
|
def __init__(
    self,
    client: OpenAI,
    model_name: str,
    task_prompt: str,
    tool_prompt: str,
    format_prompt: str,
    generation_params: Dict,
    prefill_params: Dict,
    prefill_prefix: List,
):
    """
    Set up the function-calling handler on top of the shared base handler.

    Args:
        client (OpenAI): An OpenAI client instance.
        model_name (str): Name of the model to use.
        task_prompt (str): The main task prompt for the system.
        tool_prompt (str): A prompt to describe tools.
        format_prompt (str): A prompt specifying the desired output format.
        generation_params (Dict): Generation parameters for the model.
        prefill_params (Dict): Additional parameters for prefilling responses.
        prefill_prefix (List[str]): List of prefixes for prefill responses.
    """

    base_args = (
        client,
        model_name,
        task_prompt,
        tool_prompt,
        format_prompt,
        generation_params,
    )
    super().__init__(*base_args)

    self.prefill_params = prefill_params
    self.prefill_prefix = prefill_prefix

    # Map each supported type name to the builtin type object so argument
    # values can later be checked with isinstance(). Only Python types for now.
    # [TODO] Extend the list of support data types
    resolved_types = {}
    for type_name in SUPPORT_DATA_TYPES:
        resolved_types[type_name] = getattr(builtins, type_name)
    self.support_data_types = resolved_types
|
|
|
|
|
|
|
|
|
|
|
|
@override
|
2024-12-05 11:00:22 -08:00
|
|
|
|
def _convert_tools(self, tools: List[Dict[str, Any]]) -> str:
    """
    Serializes tools as one JSON object per line.

    Args:
        tools (List[Dict[str, Any]]): A list of tools represented as dictionaries.

    Returns:
        str: The JSON encoding of every tool, joined with newlines.
    """

    return "\n".join(json.dumps(tool) for tool in tools)
|
|
|
|
|
|
|
2024-12-05 11:00:22 -08:00
|
|
|
|
def _fix_json_string(self, json_str: str) -> str:
    """
    Repairs bracket mismatches in a JSON-like string.

    The input is trimmed, closing brackets that do not match the most
    recent unclosed opening bracket are dropped, missing closing brackets
    are appended in the right order, and finally every single quote is
    replaced with a double quote. The result is NOT parsed here, so it is
    not guaranteed to be valid JSON.

    Args:
        json_str (str): A JSON string that might be malformed.

    Returns:
        str: The repaired string.
    """

    closer_for = {"(": ")", "{": "}", "[": "]"}
    opener_for = {closer: opener for opener, closer in closer_for.items()}

    pending = []  # opening brackets that are still unclosed
    kept = []  # characters that make it into the output

    for ch in json_str.strip():
        if ch in closer_for:
            pending.append(ch)
            kept.append(ch)
        elif ch in opener_for:
            # Keep a closing bracket only when it closes the innermost
            # open one; any other closing bracket is silently dropped.
            if pending and pending[-1] == opener_for[ch]:
                pending.pop()
                kept.append(ch)
        else:
            kept.append(ch)

    # Close whatever is still open, innermost first.
    kept.extend(closer_for[opener] for opener in reversed(pending))

    return "".join(kept).replace("'", '"')
|
|
|
|
|
|
|
2024-12-05 11:00:22 -08:00
|
|
|
|
def _extract_tool_calls(
    self, content: str
) -> Tuple[List[Dict[str, Any]], bool, Union[str, Exception]]:
    """
    Extracts tool call information from a given string.

    The content is scanned line by line. A line equal to `<tool_call>`
    (ignoring surrounding whitespace) opens a block and `</tool_call>`
    closes it. The first non-blank line inside a block is parsed as JSON;
    if that fails, one more attempt is made after `_fix_json_string`. The
    parsed payload must be an object with `name` and `arguments` keys.

    Args:
        content (str): The content string containing potential tool call information.

    Returns:
        Tuple[List[Dict[str, Any]], bool, Union[str, Exception]]:
            - A list of tool call dictionaries (empty if extraction failed).
            - A boolean indicating if the extraction was valid.
            - An error message or exception if extraction failed, else "".
    """

    tool_calls = []
    inside_block = False

    for raw_line in content.split("\n"):
        # Strip so that "\r\n" line endings and stray indentation do not
        # prevent the markers from being recognized.
        line = raw_line.strip()

        if line == "<tool_call>":
            inside_block = True
            continue
        if line == "</tool_call>":
            inside_block = False
            continue
        if not inside_block or not line:
            # Text outside a block is ignored; blank lines inside a block
            # are skipped and do not count as the payload.
            continue

        try:
            tool_content = json.loads(line)
        except (ValueError, RecursionError) as parse_error:
            try:
                tool_content = json.loads(self._fix_json_string(line))
            except (ValueError, RecursionError):
                # Report the error from the untouched line, not the repaired one.
                return [], False, parse_error

        # A payload of the wrong shape is reported like a parse failure
        # and does not escape as KeyError/TypeError.
        if (
            not isinstance(tool_content, dict)
            or "name" not in tool_content
            or "arguments" not in tool_content
        ):
            return (
                [],
                False,
                "Tool call must be a JSON object with `name` and `arguments`.",
            )

        tool_calls.append(
            {
                "id": f"call_{random.randint(1000, 10000)}",
                "type": "function",
                "function": {
                    "name": tool_content["name"],
                    "arguments": tool_content["arguments"],
                },
            }
        )

        # Only one payload line is consumed per <tool_call> block.
        inside_block = False

    return tool_calls, True, ""
|
|
|
|
|
|
|
|
|
|
|
|
def _verify_tool_calls(
    self, tools: List[Dict[str, Any]], tool_calls: List[Dict[str, Any]]
) -> Tuple[bool, Union[Dict[str, Any], None], str]:
    """
    Verifies the validity of extracted tool calls against the provided tools.

    For every tool call the following is checked, stopping at the first
    failure: the function is declared in `tools`, the arguments are a
    dict, every required parameter is present, every supplied parameter is
    declared, and values whose declared type is one of
    `self.support_data_types` are instances of that type.

    Args:
        tools (List[Dict[str, Any]]): A list of available tools.
        tool_calls (List[Dict[str, Any]]): A list of tool calls to verify.

    Returns:
        Tuple[bool, Union[Dict[str, Any], None], str]:
            - A boolean indicating if the tool calls are valid.
            - The invalid tool call dictionary if any.
            - An error message.
    """

    # Map each declared function name to its parameter schema.
    functions = {}
    for tool in tools or []:
        if tool.get("type") == "function":
            spec = tool["function"]
            functions[spec["name"]] = spec.get("parameters") or {}

    for tool_call in tool_calls:
        func_name = tool_call["function"]["name"]
        func_args = tool_call["function"]["arguments"]

        # Check whether the function is available or not
        if func_name not in functions:
            return False, tool_call, f"{func_name} is not defined!"

        if not isinstance(func_args, dict):
            return (
                False,
                tool_call,
                f"Arguments of the function `{func_name}` are expected to be a dict, but got `{type(func_args)}`.",
            )

        schema = functions[func_name]

        # Check if all the required parameters can be found in the tool call
        for required_param in schema.get("required", []):
            if required_param not in func_args:
                return (
                    False,
                    tool_call,
                    f"`{required_param}` is requried by the function `{func_name}` but not found in the tool call!",
                )

        # Verify the data type of each parameter in the tool call
        properties = schema.get("properties") or {}
        for param_name, param_value in func_args.items():
            if param_name not in properties:
                return (
                    False,
                    tool_call,
                    f"Parameter `{param_name}` is not defined in the function `{func_name}`.",
                )

            data_type = properties[param_name].get("type")
            # Type names outside the supported set (or non-string type
            # declarations) are not checked.
            if isinstance(data_type, str) and data_type in self.support_data_types:
                expected_type = self.support_data_types[data_type]
                if not isinstance(param_value, expected_type):
                    return (
                        False,
                        tool_call,
                        f"Parameter `{param_name}` is expected to have the data type `{expected_type}`, but got `{type(param_value)}`.",
                    )

    return True, None, ""
|
|
|
|
|
|
|
|
|
|
|
|
@override
|
2024-12-05 11:00:22 -08:00
|
|
|
|
async def chat_completion(
    self, req: ChatMessage, enable_prefilling=True
) -> ChatCompletionResponse:
    """
    Generates a chat completion response for a given request.

    With prefilling enabled the completion is streamed and the first
    non-empty token decides the path: if it is `<tool_call>` the rest of
    the stream is collected as the model response; otherwise the stream is
    closed, a randomly chosen prefill prefix is appended as an assistant
    message, and a second, non-streaming completion is requested with
    `prefill_params` merged over `generation_params`. With prefilling
    disabled a single non-streaming completion is used.

    Tool calls are then extracted from the response text and verified
    against `req.tools`. Valid tool calls are returned with empty content;
    in every other case the response text is returned with no tool calls.

    Args:
        req (ChatMessage): A chat message request object.
        enable_prefilling (bool, optional): Whether to enable prefill responses. Defaults to True.

    Returns:
        ChatCompletionResponse: The model's response to the chat request.

    Note:
        Currently only support vllm inference
    """

    messages = self._process_messages(req.messages, req.tools)

    response = self.client.chat.completions.create(
        messages=messages,
        model=self.model_name,
        stream=enable_prefilling,
        extra_body=self.generation_params,
    )

    model_response = ""

    if enable_prefilling:
        has_tool_call = None
        pieces = []

        for chunk in response:
            # Chunks may carry no choices or no text (e.g. a role-only
            # first chunk or the final chunk); they cannot decide the path.
            if not chunk.choices:
                continue
            piece = chunk.choices[0].delta.content
            if not piece:
                continue

            if has_tool_call is None:
                first_token = piece.strip()
                if not first_token:
                    continue
                has_tool_call = first_token == "<tool_call>"
                if not has_tool_call:
                    response.close()
                    break

            # Keep the text exactly as generated: the newlines between the
            # markers and the payload are what _extract_tool_calls splits on.
            pieces.append(piece)

        model_response = "".join(pieces)

        # start parameter gathering if the model is not generating a tool call
        if has_tool_call is False:
            messages.append(
                {
                    "role": "assistant",
                    "content": random.choice(self.prefill_prefix),
                }
            )

            prefill_response = self.client.chat.completions.create(
                messages=messages,
                model=self.model_name,
                stream=False,
                extra_body={
                    **self.generation_params,
                    **self.prefill_params,
                },
            )

            model_response = prefill_response.choices[0].message.content or ""
    else:
        model_response = response.choices[0].message.content or ""

    tool_calls, is_valid, error_message = self._extract_tool_calls(model_response)

    if tool_calls:
        is_valid, _, error_message = self._verify_tool_calls(
            tools=req.tools, tool_calls=tool_calls
        )

    # [TODO] Review: In the case that tool calls are invalid, define the protocol to collect debugging output and the behavior to handle it appropriately
    if tool_calls and is_valid:
        message = Message(content="", tool_calls=tool_calls)
    else:
        # NOTE(review): invalid tool calls previously left a bare str here,
        # which was then handed to Choice(message=...). Until the protocol
        # above is defined, fall back to returning the raw text.
        message = Message(content=model_response, tool_calls=[])

    # [TODO] Review: define the protocol to collect debugging output
    return ChatCompletionResponse(
        choices=[Choice(message=message)], model=self.model_name
    )
|