mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
Update ArchFunctionHandler
This commit is contained in:
parent
95e167c2f6
commit
e0d4ee7357
1 changed files with 62 additions and 78 deletions
|
|
@ -4,7 +4,7 @@ import builtins
|
|||
import textwrap
|
||||
|
||||
from openai import OpenAI
|
||||
from typing import Any, Dict, List, Tuple, Union
|
||||
from typing import Any, Dict, List
|
||||
from overrides import override
|
||||
from src.core.model_utils import (
|
||||
Message,
|
||||
|
|
@ -299,9 +299,7 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
# Attempt to parse the corrected string to ensure it’s valid JSON
|
||||
return fixed_str.replace("'", '"')
|
||||
|
||||
def _extract_tool_calls(
|
||||
self, content: str
|
||||
) -> Tuple[List[Dict[str, Any]], bool, Union[str, Exception]]:
|
||||
def _extract_tool_calls(self, content: str) -> Dict[str, any]:
|
||||
"""
|
||||
Extracts tool call information from a given string.
|
||||
|
||||
|
|
@ -309,10 +307,10 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
content (str): The content string containing potential tool call information.
|
||||
|
||||
Returns:
|
||||
Tuple[List[Dict[str, Any]], bool, Union[str, Exception]]:
|
||||
- A list of tool call dictionaries.
|
||||
- A boolean indicating if the extraction was valid.
|
||||
- An error message or exception if extraction failed.
|
||||
Dict: A dictionary of extraction, including:
|
||||
- "result": A list of tool call dictionaries.
|
||||
- "status": A boolean indicating if the extraction was valid.
|
||||
- "message": An error message or exception if extraction failed.
|
||||
"""
|
||||
|
||||
tool_calls, is_valid, error_message = [], True, ""
|
||||
|
|
@ -348,11 +346,11 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
|
||||
flag = False
|
||||
|
||||
return tool_calls, is_valid, error_message
|
||||
return {"result": tool_calls, "status": is_valid, "message": "error_message"}
|
||||
|
||||
def _verify_tool_calls(
|
||||
self, tools: List[Dict[str, Any]], tool_calls: List[Dict[str, Any]]
|
||||
) -> Tuple[bool, Union[Dict[str, Any], None], str]:
|
||||
) -> Dict[str, any]:
|
||||
"""
|
||||
Verifies the validity of extracted tool calls against the provided tools.
|
||||
|
||||
|
|
@ -361,13 +359,13 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
tool_calls (List[Dict[str, Any]]): A list of tool calls to verify.
|
||||
|
||||
Returns:
|
||||
Tuple[bool, Union[Dict[str, Any], None], str]:
|
||||
- A boolean indicating if the tool calls are valid.
|
||||
- The invalid tool call dictionary if any.
|
||||
- An error message.
|
||||
Dict: A dictionary of verification, including:
|
||||
- "status": A boolean indicating if the tool calls are valid.
|
||||
- "invalid_tool_call": A dictionary of the invalid tool call if any.
|
||||
- "message": An error message.
|
||||
"""
|
||||
|
||||
is_valid, error_tool_call, error_message = True, None, ""
|
||||
is_valid, invalid_tool_call, error_message = True, None, ""
|
||||
|
||||
functions = {}
|
||||
for tool in tools:
|
||||
|
|
@ -390,9 +388,9 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
for required_param in functions[func_name].get("required", []):
|
||||
if required_param not in func_args:
|
||||
is_valid = False
|
||||
error_tool_call = tool_call
|
||||
invalid_tool_call = tool_call
|
||||
error_message = f"`{required_param}` is requiried by the function `{func_name}` but not found in the tool call!"
|
||||
return is_valid, error_tool_call, error_message
|
||||
return is_valid, invalid_tool_call, error_message
|
||||
|
||||
# Verify the data type of each parameter in the tool calls
|
||||
for param_name, param_value in func_args:
|
||||
|
|
@ -403,46 +401,36 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
param_value, self.support_data_types[data_type]
|
||||
):
|
||||
is_valid = False
|
||||
error_tool_call = tool_call
|
||||
invalid_tool_call = tool_call
|
||||
error_message = f"Parameter `{param_name}` is expected to have the data type `{self.support_data_types[data_type]}`, but got `{type(param_value)}`."
|
||||
return is_valid, error_tool_call, error_message
|
||||
return is_valid, invalid_tool_call, error_message
|
||||
|
||||
return is_valid, error_tool_call, error_message
|
||||
return {
|
||||
"status": is_valid,
|
||||
"invalid_tool_call": invalid_tool_call,
|
||||
"message": error_message,
|
||||
}
|
||||
|
||||
def _prefill_response(self, messages: List[Dict[str, str]]):
|
||||
def _add_prefill_message(self, messages: List[Dict[str, str]]):
|
||||
"""
|
||||
Prefills the response with the tool call prefix.
|
||||
Update messages and generation params for prompt prefilling
|
||||
|
||||
Args:
|
||||
messages (List[Dict[str, str]]): A list of messages.
|
||||
tools (List[Dict[str, Any]]): A list of tools.
|
||||
|
||||
Returns:
|
||||
List[Dict[str, str]]: A list of messages with the prefill prefix.
|
||||
prefill_messages (List[Dict[str, str]]): A list of messages.
|
||||
"""
|
||||
|
||||
messages.append(
|
||||
return messages + [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": random.choice(self.prefill_prefix),
|
||||
}
|
||||
)
|
||||
prefill_response = self.client.chat.completions.create(
|
||||
messages=messages,
|
||||
model=self.model_name,
|
||||
stream=False,
|
||||
extra_body={
|
||||
**self.generation_params,
|
||||
**self.prefill_params,
|
||||
},
|
||||
)
|
||||
|
||||
return prefill_response
|
||||
]
|
||||
|
||||
@override
|
||||
async def chat_completion(
|
||||
self, req: ChatMessage, enable_prefilling=True
|
||||
) -> ChatCompletionResponse:
|
||||
async def chat_completion(self, req: ChatMessage) -> ChatCompletionResponse:
|
||||
"""
|
||||
Generates a chat completion response for a given request.
|
||||
|
||||
|
|
@ -459,60 +447,56 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
|
||||
messages = self._process_messages(req.messages, req.tools)
|
||||
|
||||
# Retrieve the first token, handling the Stream object carefully
|
||||
# always enable `stream=True` to collect model responses
|
||||
response = self.client.chat.completions.create(
|
||||
messages=messages,
|
||||
model=self.model_name,
|
||||
stream=enable_prefilling,
|
||||
stream=True,
|
||||
extra_body=self.generation_params,
|
||||
)
|
||||
|
||||
model_response = ""
|
||||
model_response, has_tool_call = "", None
|
||||
|
||||
if enable_prefilling:
|
||||
has_tool_call = None
|
||||
for token in response:
|
||||
token_content = token.choices[0].delta.content.strip()
|
||||
|
||||
model_response = ""
|
||||
for token in response:
|
||||
token_content = token.choices[0].delta.content.strip()
|
||||
if token_content:
|
||||
if has_tool_call is None and token_content != "<tool_call>":
|
||||
has_tool_call = False
|
||||
response.close()
|
||||
break
|
||||
else:
|
||||
has_tool_call = True
|
||||
|
||||
if token_content:
|
||||
if has_tool_call is None and token_content != "<tool_call>":
|
||||
has_tool_call = False
|
||||
response.close()
|
||||
break
|
||||
else:
|
||||
has_tool_call = True
|
||||
if has_tool_call is True:
|
||||
model_response += token_content
|
||||
|
||||
if has_tool_call is True:
|
||||
model_response += token_content
|
||||
# start parameter gathering if the model is not generating tool calls
|
||||
if has_tool_call is False:
|
||||
prefill_response = self.client.chat.completions.create(
|
||||
messages=self._add_prefill_message(messages),
|
||||
model=self.model_name,
|
||||
extra_body={
|
||||
**self.generation_params,
|
||||
**self.prefill_params,
|
||||
},
|
||||
)
|
||||
model_response = prefill_response.choices[0].message.content
|
||||
|
||||
# start parameter gathering if the model is not generating a tool call
|
||||
if has_tool_call is False:
|
||||
prefill_response = self._prefill_response(messages)
|
||||
model_response = prefill_response.choices[0].message.content
|
||||
else:
|
||||
model_response = response.choices[0].message.content
|
||||
# Extract tool calls from model response
|
||||
extracted = self._extract_tool_calls(model_response)
|
||||
|
||||
(
|
||||
tool_calls,
|
||||
extraction_status,
|
||||
extraction_error_message,
|
||||
) = self._extract_tool_calls(model_response)
|
||||
|
||||
if tool_calls:
|
||||
if extracted["tool_calls"]:
|
||||
# [TODO] Review: define the behavior in the case that tool call extraction fails
|
||||
# if not extraction_status:
|
||||
# if not extracted["status"]:
|
||||
|
||||
(
|
||||
verification_status,
|
||||
invalid_tool_call,
|
||||
verification_error_message,
|
||||
) = self._verify_tool_calls(tools=req.tools, tool_calls=tool_calls)
|
||||
verified = self._verify_tool_calls(
|
||||
tools=req.tools, tool_calls=extracted["tool_calls"]
|
||||
)
|
||||
|
||||
# [TODO] Review: In the case that tool calls are invalid, define the protocol to collect debugging output and the behavior to handle it appropriately
|
||||
if verification_status:
|
||||
model_response = Message(content="", tool_calls=tool_calls)
|
||||
if verified["status"]:
|
||||
model_response = Message(content="", tool_calls=extracted["tool_calls"])
|
||||
# else:
|
||||
|
||||
else:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue