mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
Init update
This commit is contained in:
parent
9f59943041
commit
cf30e94415
5 changed files with 508 additions and 632 deletions
|
|
@ -5,8 +5,6 @@ from src.core.guardrails import get_guardrail_handler
|
|||
from src.core.function_calling import (
|
||||
ArchAgentConfig,
|
||||
ArchAgentHandler,
|
||||
ArchIntentConfig,
|
||||
ArchIntentHandler,
|
||||
ArchFunctionConfig,
|
||||
ArchFunctionHandler,
|
||||
)
|
||||
|
|
@ -30,9 +28,6 @@ ARCH_GUARD_MODEL_ALIAS = "katanemo/Arch-Guard"
|
|||
|
||||
# Define model handlers
|
||||
handler_map = {
|
||||
"Arch-Intent": ArchIntentHandler(
|
||||
ARCH_CLIENT, ARCH_INTENT_MODEL_ALIAS, ArchIntentConfig
|
||||
),
|
||||
"Arch-Function": ArchFunctionHandler(
|
||||
ARCH_CLIENT, ARCH_FUNCTION_MODEL_ALIAS, ArchFunctionConfig
|
||||
),
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ import copy
|
|||
import json
|
||||
import random
|
||||
import builtins
|
||||
import textwrap
|
||||
import src.commons.utils as utils
|
||||
|
||||
from openai import OpenAI
|
||||
|
|
@ -22,176 +21,23 @@ from src.core.utils.model_utils import (
|
|||
logger = utils.get_model_server_logger()
|
||||
|
||||
|
||||
class ArchIntentConfig:
|
||||
TASK_PROMPT = textwrap.dedent(
|
||||
"""
|
||||
You are a helpful assistant.
|
||||
"""
|
||||
).strip()
|
||||
|
||||
TOOL_PROMPT_TEMPLATE = textwrap.dedent(
|
||||
"""
|
||||
You task is to check if there are any tools that can be used to help the last user message in conversations according to the available tools listed below.
|
||||
|
||||
<tools>
|
||||
{tool_text}
|
||||
</tools>
|
||||
"""
|
||||
).strip()
|
||||
|
||||
FORMAT_PROMPT = textwrap.dedent(
|
||||
"""
|
||||
Provide your tool assessment for ONLY THE LAST USER MESSAGE in the above conversation:
|
||||
- First line must read 'Yes' or 'No'.
|
||||
- If yes, a second line must include a comma-separated list of tool indexes.
|
||||
"""
|
||||
).strip()
|
||||
|
||||
EXTRA_INSTRUCTION = "Are there any tools can help?"
|
||||
|
||||
GENERATION_PARAMS = {
|
||||
"temperature": 0.01,
|
||||
"max_tokens": 1,
|
||||
"stop_token_ids": [151645],
|
||||
}
|
||||
|
||||
|
||||
class ArchIntentHandler(ArchBaseHandler):
|
||||
def __init__(self, client: OpenAI, model_name: str, config: ArchIntentConfig):
|
||||
"""
|
||||
Initializes the intent handler.
|
||||
|
||||
Args:
|
||||
client (OpenAI): An OpenAI client instance.
|
||||
model_name (str): Name of the model to use.
|
||||
config (ArchIntentConfig): The configuration for Arch-Intent.
|
||||
"""
|
||||
|
||||
super().__init__(
|
||||
client,
|
||||
model_name,
|
||||
config.TASK_PROMPT,
|
||||
config.TOOL_PROMPT_TEMPLATE,
|
||||
config.FORMAT_PROMPT,
|
||||
config.GENERATION_PARAMS,
|
||||
)
|
||||
|
||||
self.extra_instruction = config.EXTRA_INSTRUCTION
|
||||
|
||||
@override
|
||||
def _convert_tools(self, tools: List[Dict[str, Any]]) -> str:
|
||||
"""
|
||||
Converts a list of tools into a JSON-like format with indexed keys.
|
||||
|
||||
Args:
|
||||
tools (List[Dict[str, Any]]): A list of tools represented as dictionaries.
|
||||
|
||||
Returns:
|
||||
str: A string representation of converted tools.
|
||||
"""
|
||||
|
||||
converted = [
|
||||
json.dumps({"index": f"T{idx}"} | tool) for idx, tool in enumerate(tools)
|
||||
]
|
||||
return "\n".join(converted)
|
||||
|
||||
def detect_intent(self, content: str) -> bool:
|
||||
"""
|
||||
Detect if any intent match with prompts
|
||||
|
||||
Args:
|
||||
content: str: Model response that contains intent detection results
|
||||
|
||||
Returns:
|
||||
bool: A boolean value to indicate if any intent match with prompts or not
|
||||
"""
|
||||
if hasattr(content.choices[0].message, "content"):
|
||||
return content.choices[0].message.content == "Yes"
|
||||
else:
|
||||
return False
|
||||
|
||||
@override
|
||||
async def chat_completion(self, req: ChatMessage) -> ChatCompletionResponse:
|
||||
"""
|
||||
Generates a chat completion for a given request.
|
||||
|
||||
Args:
|
||||
req (ChatMessage): A chat message request object.
|
||||
|
||||
Returns:
|
||||
ChatCompletionResponse: The model's response to the chat request.
|
||||
|
||||
Note:
|
||||
Currently only support vllm inference
|
||||
"""
|
||||
logger.info("[Arch-Intent] - ChatCompletion")
|
||||
|
||||
# In the case that no tools are available, simply return `No` to avoid making a call
|
||||
if len(req.tools) == 0:
|
||||
model_response = Message(content="No", tool_calls=[])
|
||||
logger.info("No tools found, return `No` as the model response.")
|
||||
else:
|
||||
messages = self._process_messages(
|
||||
req.messages, req.tools, self.extra_instruction
|
||||
)
|
||||
|
||||
logger.info(f"[request to arch-fc (intent)]: {json.dumps(messages)}")
|
||||
|
||||
model_response = self.client.chat.completions.create(
|
||||
messages=messages,
|
||||
model=self.model_name,
|
||||
stream=False,
|
||||
extra_body=self.generation_params,
|
||||
)
|
||||
|
||||
logger.info(f"[response]: {json.dumps(model_response.model_dump())}")
|
||||
|
||||
model_response = Message(
|
||||
content=model_response.choices[0].message.content, tool_calls=[]
|
||||
)
|
||||
|
||||
chat_completion_response = ChatCompletionResponse(
|
||||
choices=[Choice(message=model_response)], model=self.model_name
|
||||
)
|
||||
|
||||
return chat_completion_response
|
||||
|
||||
|
||||
# =============================================================================================================
|
||||
# ==============================================================================================================================================
|
||||
|
||||
|
||||
class ArchFunctionConfig:
|
||||
TASK_PROMPT = textwrap.dedent(
|
||||
"""
|
||||
You are a helpful assistant.
|
||||
TASK_PROMPT = (
|
||||
"You are a helpful assistant designed to assist with the user query by making one or more function calls if needed."
|
||||
"\nToday's date: {today_date}"
|
||||
"\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>{tool_text}\n</tools>"
|
||||
"\n\nYour task is to decide which functions are needed and collect missing parameters if necessary.\n\n"
|
||||
)
|
||||
|
||||
Today's date: {}
|
||||
""".format(
|
||||
utils.get_today_date()
|
||||
)
|
||||
).strip()
|
||||
|
||||
TOOL_PROMPT_TEMPLATE = textwrap.dedent(
|
||||
"""
|
||||
# Tools
|
||||
|
||||
You may call one or more functions to assist with the user query.
|
||||
|
||||
You are provided with function signatures within <tools></tools> XML tags:
|
||||
<tools>
|
||||
{tool_text}
|
||||
</tools>
|
||||
"""
|
||||
).strip()
|
||||
|
||||
FORMAT_PROMPT = textwrap.dedent(
|
||||
"""
|
||||
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
|
||||
<tool_call>
|
||||
{"name": <function-name>, "arguments": <args-json-object>}
|
||||
</tool_call>
|
||||
"""
|
||||
).strip()
|
||||
FORMAT_PROMPT = (
|
||||
"Based on your analysis, provide your response in one of the following JSON formats:"
|
||||
'\n1. If no functions are needed:\n```\n{"response": "Your response text here"}\n```'
|
||||
'\n2. If functions are needed but some required parameters are missing:\n```\n{"required_functions": ["func_name1", "func_name2", ...], "clarification": "Text asking for missing parameters"}\n```'
|
||||
'\n3. If functions are needed and all required parameters are available:\n```\n{"tool_calls": [{"name": "func_name1", "arguments": {"argument1": "value1", "argument2": "value2"}},... (more tool calls as required)]}\n```'
|
||||
)
|
||||
|
||||
GENERATION_PARAMS = {
|
||||
"temperature": 0.6,
|
||||
|
|
@ -222,15 +68,6 @@ class ArchFunctionConfig:
|
|||
SUPPORT_DATA_TYPES = ["int", "float", "bool", "str", "list", "tuple", "set", "dict"]
|
||||
|
||||
|
||||
class ArchAgentConfig(ArchFunctionConfig):
|
||||
GENERATION_PARAMS = {
|
||||
"temperature": 0.01,
|
||||
"stop_token_ids": [151645],
|
||||
"logprobs": True,
|
||||
"top_logprobs": 10,
|
||||
}
|
||||
|
||||
|
||||
class ArchFunctionHandler(ArchBaseHandler):
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -251,7 +88,6 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
client,
|
||||
model_name,
|
||||
config.TASK_PROMPT,
|
||||
config.TOOL_PROMPT_TEMPLATE,
|
||||
config.FORMAT_PROMPT,
|
||||
config.GENERATION_PARAMS,
|
||||
)
|
||||
|
|
@ -280,7 +116,7 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
str: A string representation of converted tools.
|
||||
"""
|
||||
|
||||
converted = [json.dumps(tool) for tool in tools]
|
||||
converted = [json.dumps(tool["function"], ensure_ascii=False) for tool in tools]
|
||||
return "\n".join(converted)
|
||||
|
||||
def _fix_json_string(self, json_str: str) -> str:
|
||||
|
|
@ -331,7 +167,7 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
# Attempt to parse the corrected string to ensure it’s valid JSON
|
||||
return fixed_str.replace("'", '"')
|
||||
|
||||
def _extract_tool_calls(self, content: str) -> Dict[str, any]:
|
||||
def _parse_model_resonse(self, content: str) -> Dict[str, any]:
|
||||
"""
|
||||
Extracts tool call information from a given string.
|
||||
|
||||
|
|
@ -340,49 +176,47 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
|
||||
Returns:
|
||||
Dict: A dictionary of extraction, including:
|
||||
- "result": A list of tool call dictionaries.
|
||||
- "status": A boolean indicating if the extraction was valid.
|
||||
- "message": An error message or exception if extraction failed.
|
||||
- "required_functions": A list of detected intents.
|
||||
- "clarification": Text to collect missing parameters
|
||||
- "tool_calls": A list of tool call dictionaries.
|
||||
- "is_valid": A boolean indicating if the extraction was valid.
|
||||
- "error_message": An error message or exception if parsing failed.
|
||||
"""
|
||||
|
||||
tool_calls, is_valid, error_message = [], True, ""
|
||||
response_dict = {
|
||||
"response": [],
|
||||
"required_functions": [],
|
||||
"clarification": "",
|
||||
"tool_calls": [],
|
||||
"is_valid": True,
|
||||
"error_message": "",
|
||||
}
|
||||
|
||||
flag = False
|
||||
for line in content.split("\n"):
|
||||
if not is_valid:
|
||||
break
|
||||
try:
|
||||
model_response = json.loads(self._fix_json_string(content))
|
||||
|
||||
if "<tool_call>" == line:
|
||||
flag = True
|
||||
elif "</tool_call>" == line:
|
||||
flag = False
|
||||
else:
|
||||
if flag:
|
||||
try:
|
||||
tool_content = json.loads(line)
|
||||
except Exception as e:
|
||||
fixed_content = self._fix_json_string(line)
|
||||
try:
|
||||
tool_content = json.loads(fixed_content)
|
||||
except Exception:
|
||||
is_valid, error_message = False, e
|
||||
break
|
||||
response_dict["response"] = model_response.get("response", "")
|
||||
response_dict["required_functions"] = model_response.get(
|
||||
"required_functions", ""
|
||||
)
|
||||
response_dict["clarification"] = model_response.get("clarification", "")
|
||||
|
||||
tool = {
|
||||
for tool_call in model_response.get("tool_calls", []):
|
||||
response_dict["tool_call"].append(
|
||||
{
|
||||
"id": f"call_{random.randint(1000, 10000)}",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool_content["name"],
|
||||
"name": tool_call.get("name", ""),
|
||||
"arguments": tool_call.get("arguments", {}),
|
||||
},
|
||||
}
|
||||
if "arguments" in tool_content:
|
||||
tool["function"]["arguments"] = tool_content["arguments"]
|
||||
)
|
||||
except Exception as e:
|
||||
response_dict["is_valid"] = False
|
||||
response_dict["error_message"] = f"Fail to parse model responses: {e}"
|
||||
|
||||
tool_calls.append(tool)
|
||||
|
||||
flag = False
|
||||
|
||||
return {"result": tool_calls, "status": is_valid, "message": error_message}
|
||||
return response_dict
|
||||
|
||||
def _convert_data_type(self, value: str, target_type: str):
|
||||
# TODO: Add more conversion rules as needed
|
||||
|
|
@ -414,7 +248,11 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
- "message": An error message.
|
||||
"""
|
||||
|
||||
is_valid, invalid_tool_call, error_message = True, None, ""
|
||||
verification_dict = {
|
||||
"is_valid": True,
|
||||
"invalid_tool_call": {},
|
||||
"error_message": "",
|
||||
}
|
||||
|
||||
functions = {}
|
||||
for tool in tools:
|
||||
|
|
@ -422,28 +260,26 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
functions[tool["function"]["name"]] = tool["function"]["parameters"]
|
||||
|
||||
for tool_call in tool_calls:
|
||||
if not is_valid:
|
||||
if not verification_dict["is_valid"]:
|
||||
break
|
||||
|
||||
func_name = tool_call["function"]["name"]
|
||||
func_args = tool_call["function"].get("arguments")
|
||||
if not func_args:
|
||||
func_args = {}
|
||||
func_args = tool_call["function"]["arguments"]
|
||||
|
||||
# Check whether the function is available or not
|
||||
if func_name not in functions:
|
||||
is_valid = False
|
||||
invalid_tool_call = tool_call
|
||||
error_message = f"{func_name} is not defined!"
|
||||
break
|
||||
|
||||
verification_dict["is_valid"] = False
|
||||
verification_dict["invalid_tool_call"] = tool_call
|
||||
verification_dict["error_message"] = f"{func_name} is not available!"
|
||||
else:
|
||||
# Check if all the requried parameters can be found in the tool calls
|
||||
for required_param in functions[func_name].get("required", []):
|
||||
if required_param not in func_args:
|
||||
is_valid = False
|
||||
invalid_tool_call = tool_call
|
||||
error_message = f"`{required_param}` is required by the function `{func_name}` but not found in the tool call!"
|
||||
verification_dict["is_valid"] = False
|
||||
verification_dict["invalid_tool_call"] = tool_call
|
||||
verification_dict[
|
||||
"error_message"
|
||||
] = f"`{required_param}` is required by the function `{func_name}` but not found in the tool call!"
|
||||
break
|
||||
|
||||
# Verify the data type of each parameter in the tool calls
|
||||
|
|
@ -453,9 +289,11 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
logger.info(func_args)
|
||||
for param_name in func_args:
|
||||
if param_name not in function_properties:
|
||||
is_valid = False
|
||||
invalid_tool_call = tool_call
|
||||
error_message = f"Parameter `{param_name}` is not defined in the function `{func_name}`."
|
||||
verification_dict["is_valid"] = False
|
||||
verification_dict["invalid_tool_call"] = tool_call
|
||||
verification_dict[
|
||||
"error_message"
|
||||
] = f"Parameter `{param_name}` is not defined in the function `{func_name}`."
|
||||
break
|
||||
else:
|
||||
param_value = func_args[param_name]
|
||||
|
|
@ -469,20 +307,20 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
param_value, data_type
|
||||
)
|
||||
if not isinstance(param_value, data_type):
|
||||
is_valid = False
|
||||
invalid_tool_call = tool_call
|
||||
error_message = f"Parameter `{param_name}` is expected to have the data type `{data_type}`, got `{type(param_value)}`."
|
||||
verification_dict["is_valid"] = False
|
||||
verification_dict["invalid_tool_call"] = tool_call
|
||||
verification_dict[
|
||||
"error_message"
|
||||
] = f"Parameter `{param_name}` is expected to have the data type `{data_type}`, got `{type(param_value)}`."
|
||||
break
|
||||
else:
|
||||
error_message = (
|
||||
f"Data type `{target_type}` is not supported."
|
||||
)
|
||||
verification_dict["is_valid"] = False
|
||||
verification_dict["invalid_tool_call"] = tool_call
|
||||
verification_dict[
|
||||
"error_message"
|
||||
] = f"Data type `{target_type}` is not supported."
|
||||
|
||||
return {
|
||||
"status": is_valid,
|
||||
"invalid_tool_call": invalid_tool_call,
|
||||
"message": error_message,
|
||||
}
|
||||
return verification_dict
|
||||
|
||||
def _add_prefill_message(self, messages: List[Dict[str, str]]):
|
||||
"""
|
||||
|
|
@ -558,72 +396,110 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
model_response += chunk.choices[0].delta.content
|
||||
logger.info(f"[Agent Orchestrator]: response received: {model_response}")
|
||||
else:
|
||||
# *********************************************************************************************\
|
||||
# Update the following logic for hallucination check
|
||||
# 1. If the model response starts wtth `tool_calls`, continue halluciantion check:
|
||||
# - If hallucination detected, start prompt prefilling
|
||||
# - Otherwise, continue until the end
|
||||
# 2. Otherwise, stop it
|
||||
# *********************************************************************************************
|
||||
|
||||
# initialize the hallucination handler, which is an iterator
|
||||
self.hallucination_state = HallucinationState(
|
||||
response_iterator=response, function=req.tools
|
||||
)
|
||||
|
||||
has_tool_calls, has_hallucination = None, False
|
||||
for _ in self.hallucination_state:
|
||||
# check if the first token is <tool_call>
|
||||
if len(self.hallucination_state.tokens) > 0 and has_tool_calls is None:
|
||||
if self.hallucination_state.tokens[0] == "<tool_call>":
|
||||
has_tool_calls = True
|
||||
else:
|
||||
has_tool_calls = False
|
||||
break
|
||||
# self.hallucination_state = HallucinationState(
|
||||
# response_iterator=response, function=req.tools
|
||||
# )
|
||||
|
||||
# if the model is hallucinating, start parameter gathering
|
||||
if self.hallucination_state.hallucination is True:
|
||||
has_hallucination = True
|
||||
break
|
||||
# has_tool_calls, has_hallucination = None, False
|
||||
# for _ in self.hallucination_state:
|
||||
# # check if the first token is <tool_call>
|
||||
# if len(self.hallucination_state.tokens) > 0 and has_tool_calls is None:
|
||||
# if self.hallucination_state.tokens[0] == "<tool_call>":
|
||||
# has_tool_calls = True
|
||||
# else:
|
||||
# has_tool_calls = False
|
||||
# break
|
||||
|
||||
if has_tool_calls:
|
||||
if has_hallucination:
|
||||
# start prompt prefilling if hallcuination is found in tool calls
|
||||
logger.info(
|
||||
f"[Hallucination]: {self.hallucination_state.error_message}"
|
||||
)
|
||||
prefill_response = self._engage_parameter_gathering(messages)
|
||||
model_response = prefill_response.choices[0].message.content
|
||||
else:
|
||||
model_response = "".join(self.hallucination_state.tokens)
|
||||
else:
|
||||
# start parameter gathering if the model is not generating tool calls
|
||||
prefill_response = self._engage_parameter_gathering(messages)
|
||||
model_response = prefill_response.choices[0].message.content
|
||||
# # if the model is hallucinating, start parameter gathering
|
||||
# if self.hallucination_state.hallucination is True:
|
||||
# has_hallucination = True
|
||||
# break
|
||||
|
||||
# if has_tool_calls:
|
||||
# if has_hallucination:
|
||||
# # start prompt prefilling if hallcuination is found in tool calls
|
||||
# logger.info(
|
||||
# f"[Hallucination]: {self.hallucination_state.error_message}"
|
||||
# )
|
||||
# prefill_response = self._engage_parameter_gathering(messages)
|
||||
# model_response = prefill_response.choices[0].message.content
|
||||
# else:
|
||||
# model_response = "".join(self.hallucination_state.tokens)
|
||||
# else:
|
||||
# # start parameter gathering if the model is not generating tool calls
|
||||
# prefill_response = self._engage_parameter_gathering(messages)
|
||||
# model_response = prefill_response.choices[0].message.content
|
||||
|
||||
# *********************************************************************************************\
|
||||
# Remove the following for loop after updating hallucination check
|
||||
# *********************************************************************************************
|
||||
for chunk in response:
|
||||
if len(chunk.choices) > 0 and chunk.choices[0].delta.content:
|
||||
model_response += chunk.choices[0].delta.content
|
||||
|
||||
# Extract tool calls from model response
|
||||
extracted = self._extract_tool_calls(model_response)
|
||||
response_dict = self._parse_model_resonse(model_response)
|
||||
|
||||
if extracted["status"]:
|
||||
# Response with tool calls
|
||||
if len(extracted["result"]):
|
||||
verified = {}
|
||||
if use_agent_orchestrator:
|
||||
# skip tool call verification if using agent orchestrator
|
||||
verified = {"status": True, "message": ""}
|
||||
else:
|
||||
verified = self._verify_tool_calls(
|
||||
tools=req.tools, tool_calls=extracted["result"]
|
||||
)
|
||||
|
||||
if verified["status"]:
|
||||
logger.info(
|
||||
f"[Tool calls]: {json.dumps([tool_call['function'] for tool_call in extracted['result']])}"
|
||||
)
|
||||
model_response = Message(content="", tool_calls=extracted["result"])
|
||||
else:
|
||||
logger.error(f"Invalid tool call - {verified['message']}")
|
||||
# Response without tool calls
|
||||
if response_dict.get("response", ""):
|
||||
# General model response
|
||||
model_message = Message(content="", tool_calls=[])
|
||||
elif response_dict.get("required_functions", []):
|
||||
# Model response for parameter gathering
|
||||
if not use_agent_orchestrator:
|
||||
clarification = response_dict.get("clarification", "")
|
||||
model_message = Message(content=clarification, tool_calls=[])
|
||||
else:
|
||||
model_response = Message(content=model_response, tool_calls=[])
|
||||
# Response with tool calls but contain errors
|
||||
model_message = Message(content="", tool_calls=[])
|
||||
elif response_dict.get("tool_calls", []):
|
||||
# Response with tool calls
|
||||
if response_dict["is_valid"]:
|
||||
if not use_agent_orchestrator:
|
||||
verification_dict = self._verify_tool_calls(
|
||||
tools=req.tools, tool_calls=response_dict["tool_calls"]
|
||||
)
|
||||
|
||||
if verification_dict["is_valid"]:
|
||||
logger.info(
|
||||
f"[Tool calls]: {json.dumps([tool_call['function'] for tool_call in response_dict['tool_calls']])}"
|
||||
)
|
||||
model_message = Message(
|
||||
content="", tool_calls=response_dict["tool_calls"]
|
||||
)
|
||||
else:
|
||||
logger.error(
|
||||
f"Invalid tool call - {verification_dict['error_message']}"
|
||||
)
|
||||
model_message = Message(content="", tool_calls=[])
|
||||
else:
|
||||
# skip tool call verification if using agent orchestrator
|
||||
logger.info(
|
||||
f"[Tool calls]: {json.dumps([tool_call['function'] for tool_call in response_dict['tool_calls']])}"
|
||||
)
|
||||
model_message = Message(
|
||||
content="", tool_calls=response_dict["tool_calls"]
|
||||
)
|
||||
|
||||
else:
|
||||
# Response with tool calls but contain errors
|
||||
model_message = Message(content="", tool_calls=[])
|
||||
else:
|
||||
logger.error(f"Tool call extraction error - {extracted['message']}")
|
||||
logger.error(f"Invalid model response - {model_response}")
|
||||
|
||||
# Response with tool calls but contain errors
|
||||
model_message = Message(content="", tool_calls=[])
|
||||
|
||||
chat_completion_response = ChatCompletionResponse(
|
||||
choices=[Choice(message=model_response)], model=self.model_name
|
||||
choices=[Choice(message=model_message)], model=self.model_name
|
||||
)
|
||||
|
||||
logger.info(f"[response]: {json.dumps(chat_completion_response.model_dump())}")
|
||||
|
|
@ -631,6 +507,21 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
return chat_completion_response
|
||||
|
||||
|
||||
# ==============================================================================================================================================
|
||||
|
||||
|
||||
class ArchAgentConfig(ArchFunctionConfig):
|
||||
GENERATION_PARAMS = {
|
||||
"temperature": 0.01,
|
||||
"top_p": 1.0,
|
||||
"top_k": 10,
|
||||
"max_tokens": 1024,
|
||||
"stop_token_ids": [151645],
|
||||
"logprobs": True,
|
||||
"top_logprobs": 10,
|
||||
}
|
||||
|
||||
|
||||
class ArchAgentHandler(ArchFunctionHandler):
|
||||
def __init__(self, client: OpenAI, model_name: str, config: ArchAgentConfig):
|
||||
super().__init__(client, model_name, config)
|
||||
|
|
@ -657,7 +548,7 @@ class ArchAgentHandler(ArchFunctionHandler):
|
|||
):
|
||||
tool_copy = copy.deepcopy(tool)
|
||||
del tool_copy["function"]["parameters"]
|
||||
converted.append(json.dumps(tool_copy))
|
||||
converted.append(json.dumps(tool_copy["function"], ensure_ascii=False))
|
||||
else:
|
||||
converted.append(json.dumps(tool))
|
||||
converted.append(json.dumps(tool["function"], ensure_ascii=False))
|
||||
return "\n".join(converted)
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import json
|
||||
import src.commons.utils as utils
|
||||
|
||||
from openai import OpenAI
|
||||
from pydantic import BaseModel
|
||||
|
|
@ -56,7 +57,6 @@ class ArchBaseHandler:
|
|||
client: OpenAI,
|
||||
model_name: str,
|
||||
task_prompt: str,
|
||||
tool_prompt_template: str,
|
||||
format_prompt: str,
|
||||
generation_params: Dict,
|
||||
):
|
||||
|
|
@ -67,7 +67,6 @@ class ArchBaseHandler:
|
|||
client (OpenAI): An OpenAI client instance.
|
||||
model_name (str): Name of the model to use.
|
||||
task_prompt (str): The main task prompt for the system.
|
||||
tool_prompt (str): A prompt to describe tools.
|
||||
format_prompt (str): A prompt specifying the desired output format.
|
||||
generation_params (Dict): Generation parameters for the model.
|
||||
"""
|
||||
|
|
@ -75,7 +74,6 @@ class ArchBaseHandler:
|
|||
self.model_name = model_name
|
||||
|
||||
self.task_prompt = task_prompt
|
||||
self.tool_prompt_template = tool_prompt_template
|
||||
self.format_prompt = format_prompt
|
||||
|
||||
self.generation_params = generation_params
|
||||
|
|
@ -105,13 +103,11 @@ class ArchBaseHandler:
|
|||
str: A formatted system prompt.
|
||||
"""
|
||||
|
||||
today_date = utils.get_today_date()
|
||||
tool_text = self._convert_tools(tools)
|
||||
|
||||
system_prompt = (
|
||||
self.task_prompt
|
||||
+ "\n\n"
|
||||
+ self.tool_prompt_template.format(tool_text=tool_text)
|
||||
+ "\n\n"
|
||||
self.task_prompt.format(today_date=today_date, tool_text=tool_text)
|
||||
+ self.format_prompt
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -76,62 +76,57 @@ async def function_calling(req: ChatMessage, res: Response):
|
|||
final_response: ChatCompletionResponse = None
|
||||
error_messages = None
|
||||
|
||||
use_agent_orchestrator = req.metadata.get("use_agent_orchestrator", False)
|
||||
logger.info(f"Use agent orchestrator: {use_agent_orchestrator}")
|
||||
# if not use_agent_orchestrator:
|
||||
# intent_start_time = time.perf_counter()
|
||||
# intent_response = await handler_map["Arch-Intent"].chat_completion(req)
|
||||
# intent_latency = time.perf_counter() - intent_start_time
|
||||
# intent_detected = handler_map["Arch-Intent"].detect_intent(intent_response)
|
||||
|
||||
try:
|
||||
intent_detected = False
|
||||
use_agent_orchestrator = req.metadata.get("use_agent_orchestrator", False)
|
||||
logger.info(f"Use agent orchestrator: {use_agent_orchestrator}")
|
||||
if not use_agent_orchestrator:
|
||||
intent_start_time = time.perf_counter()
|
||||
intent_response = await handler_map["Arch-Intent"].chat_completion(req)
|
||||
intent_latency = time.perf_counter() - intent_start_time
|
||||
intent_detected = handler_map["Arch-Intent"].detect_intent(intent_response)
|
||||
handler_name = "Arch-Agent" if use_agent_orchestrator else "Arch-Function"
|
||||
model_handler: ArchFunctionHandler = handler_map[handler_name]
|
||||
|
||||
if use_agent_orchestrator or intent_detected:
|
||||
# TODO: measure agreement between intent detection and function calling
|
||||
try:
|
||||
function_start_time = time.perf_counter()
|
||||
handler_name = (
|
||||
"Arch-Agent" if use_agent_orchestrator else "Arch-Function"
|
||||
)
|
||||
function_calling_handler: ArchFunctionHandler = handler_map[
|
||||
handler_name
|
||||
]
|
||||
final_response = await function_calling_handler.chat_completion(req)
|
||||
function_latency = time.perf_counter() - function_start_time
|
||||
start_time = time.perf_counter()
|
||||
final_response = await model_handler.chat_completion(req)
|
||||
latency = time.perf_counter() - start_time
|
||||
|
||||
final_response.metadata = {
|
||||
"function_latency": str(round(function_latency * 1000, 3)),
|
||||
}
|
||||
|
||||
if not use_agent_orchestrator:
|
||||
final_response.metadata["intent_latency"] = str(
|
||||
round(intent_latency * 1000, 3)
|
||||
)
|
||||
final_response.metadata["hallucination"] = str(
|
||||
function_calling_handler.hallucination_state.hallucination
|
||||
)
|
||||
except ValueError as e:
|
||||
res.statuscode = 503
|
||||
error_messages = (
|
||||
f"[{handler_name}] - Error in tool call extraction: {e}"
|
||||
)
|
||||
except StopIteration as e:
|
||||
res.statuscode = 500
|
||||
error_messages = f"[{handler_name}] - Error in hallucination check: {e}"
|
||||
except Exception as e:
|
||||
res.status_code = 500
|
||||
error_messages = f"[{handler_name}] - Error in ChatCompletion: {e}"
|
||||
raise
|
||||
else:
|
||||
# no intent matched
|
||||
intent_response.metadata = {
|
||||
"intent_latency": str(round(intent_latency * 1000, 3)),
|
||||
# Parameter gathering for detected intents
|
||||
if final_response.choices[0].message.content:
|
||||
final_response.metadata = {
|
||||
"function_latency": str(round(latency * 1000, 3)),
|
||||
}
|
||||
# Function Calling
|
||||
elif final_response.choices[0].message.tool_calls:
|
||||
final_response.metadata = {
|
||||
"function_latency": str(round(latency * 1000, 3)),
|
||||
}
|
||||
if not use_agent_orchestrator:
|
||||
final_response.metadata["hallucination"] = str(
|
||||
model_handler.hallucination_state.hallucination
|
||||
)
|
||||
# No intent detected
|
||||
else:
|
||||
final_response.metadata = {
|
||||
"intent_latency": str(round(latency * 1000, 3)),
|
||||
}
|
||||
final_response = intent_response
|
||||
|
||||
if not use_agent_orchestrator:
|
||||
final_response.metadata["intent_latency"] = str(round(latency * 1000, 3))
|
||||
final_response.metadata["hallucination"] = str(
|
||||
model_handler.hallucination_state.hallucination
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
res.statuscode = 503
|
||||
error_messages = f"[{handler_name}] - Error in tool call extraction: {e}"
|
||||
except StopIteration as e:
|
||||
res.statuscode = 500
|
||||
error_messages = f"[{handler_name}] - Error in hallucination check: {e}"
|
||||
except Exception as e:
|
||||
res.status_code = 500
|
||||
error_messages = f"[Arch-Intent] - Error in ChatCompletion: {e}"
|
||||
error_messages = f"[{handler_name}] - Error in ChatCompletion: {e}"
|
||||
raise
|
||||
|
||||
if error_messages is not None:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue