Integrate Arch-Function-Chat (#449)

This commit is contained in:
Shuguang Chen 2025-04-15 14:39:12 -07:00 committed by GitHub
parent f31aa59fac
commit 7d4b261a68
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
26 changed files with 558 additions and 603 deletions

View file

@ -5,8 +5,6 @@ from src.core.guardrails import get_guardrail_handler
from src.core.function_calling import (
ArchAgentConfig,
ArchAgentHandler,
ArchIntentConfig,
ArchIntentHandler,
ArchFunctionConfig,
ArchFunctionHandler,
)
@ -17,7 +15,10 @@ logger = get_model_server_logger()
# Define the client
ARCH_ENDPOINT = os.getenv("ARCH_ENDPOINT", "https://archfc.katanemo.dev/v1")
# ARCH_ENDPOINT = os.getenv("ARCH_ENDPOINT", "https://archfc.katanemo.dev/v1")
# use temporary endpoint until we deprecate archfc-v1.0 from archfc.katanemo.dev
# and officially release archfc-v1.1 on archfc.katanemo.dev
ARCH_ENDPOINT = os.getenv("ARCH_ENDPOINT", "http://34.72.123.163:8000/v1")
ARCH_API_KEY = "EMPTY"
ARCH_CLIENT = OpenAI(base_url=ARCH_ENDPOINT, api_key=ARCH_API_KEY)
ARCH_AGENT_CLIENT = ARCH_CLIENT
@ -30,9 +31,6 @@ ARCH_GUARD_MODEL_ALIAS = "katanemo/Arch-Guard"
# Define model handlers
handler_map = {
"Arch-Intent": ArchIntentHandler(
ARCH_CLIENT, ARCH_INTENT_MODEL_ALIAS, ArchIntentConfig
),
"Arch-Function": ArchFunctionHandler(
ARCH_CLIENT, ARCH_FUNCTION_MODEL_ALIAS, ArchFunctionConfig
),

View file

@ -3,7 +3,6 @@ import copy
import json
import random
import builtins
import textwrap
import src.commons.utils as utils
from openai import OpenAI
@ -22,179 +21,25 @@ from src.core.utils.model_utils import (
logger = utils.get_model_server_logger()
# Prompt fragments and sampling settings for the Arch-Intent model.
# ArchIntentHandler reads these as plain class attributes; nothing here holds
# instance state.
class ArchIntentConfig:
# Persona line; ArchIntentHandler forwards it to the base handler as the
# `task_prompt` argument.
TASK_PROMPT = textwrap.dedent(
"""
You are a helpful assistant.
"""
).strip()
# Template describing the available tools. `{tool_text}` is a str.format
# placeholder that the base handler fills with the output of `_convert_tools`.
TOOL_PROMPT_TEMPLATE = textwrap.dedent(
"""
You task is to check if there are any tools that can be used to help the last user message in conversations according to the available tools listed below.
<tools>
{tool_text}
</tools>
"""
).strip()
# Output contract given to the model: a Yes/No first line, optionally followed
# by a comma-separated list of tool indexes.
# NOTE(review): GENERATION_PARAMS below caps generation at one token, so only
# the first-line answer can actually come back - confirm this is intended.
FORMAT_PROMPT = textwrap.dedent(
"""
Provide your tool assessment for ONLY THE LAST USER MESSAGE in the above conversation:
- First line must read 'Yes' or 'No'.
- If yes, a second line must include a comma-separated list of tool indexes.
"""
).strip()
# Passed by ArchIntentHandler.chat_completion as the third argument of
# `_process_messages`; where it ends up in the conversation is decided there.
EXTRA_INSTRUCTION = "Are there any tools can help?"
# Sent verbatim as `extra_body` on the chat completion request.
GENERATION_PARAMS = {
"temperature": 0.01,
"max_tokens": 1,
# presumably the served model's end-of-turn token id - TODO confirm
"stop_token_ids": [151645],
}
class ArchIntentHandler(ArchBaseHandler):
    """Ask the Arch-Intent model whether any available tool fits the request.

    The handler builds a prompt from the ``ArchIntentConfig`` fragments, sends
    it to the model and wraps the short answer in a ``ChatCompletionResponse``.
    ``detect_intent`` then reduces that response to a boolean.
    """

    def __init__(self, client: OpenAI, model_name: str, config: ArchIntentConfig):
        """
        Initializes the intent handler.

        Only class-level attributes of ``config`` are read, so either the
        config class itself or an instance of it can be supplied.

        Args:
            client (OpenAI): An OpenAI client instance.
            model_name (str): Name of the model to use.
            config (ArchIntentConfig): The configuration for Arch-Intent.
        """
        super().__init__(
            client,
            model_name,
            config.TASK_PROMPT,
            config.TOOL_PROMPT_TEMPLATE,
            config.FORMAT_PROMPT,
            config.GENERATION_PARAMS,
        )

        self.extra_instruction = config.EXTRA_INSTRUCTION

    @override
    def _convert_tools(self, tools: List[Dict[str, Any]]) -> str:
        """
        Converts a list of tools into newline-separated JSON objects, each
        prefixed with an ``"index"`` key of the form ``T0``, ``T1``, ...

        Args:
            tools (List[Dict[str, Any]]): A list of tools represented as dictionaries.

        Returns:
            str: One JSON document per line, in the order the tools were given.
        """
        # The index dict is on the left of `|`, so a tool that already carries
        # an "index" key keeps its own value.
        return "\n".join(
            json.dumps({"index": f"T{idx}"} | tool) for idx, tool in enumerate(tools)
        )

    def detect_intent(self, content: ChatCompletionResponse) -> bool:
        """
        Tell whether the intent model answered ``Yes``.

        Args:
            content (ChatCompletionResponse): The response returned by
                ``chat_completion`` (an object exposing ``choices``), not a
                plain string.

        Returns:
            bool: True only when the first choice's message content is exactly
            ``"Yes"``; False for any other answer, a missing ``content``
            attribute, or a response without choices.
        """
        choices = getattr(content, "choices", None)
        if not choices:
            # Nothing to inspect: treat it as "no intent" rather than raising.
            return False

        return getattr(choices[0].message, "content", None) == "Yes"

    @override
    async def chat_completion(self, req: ChatMessage) -> ChatCompletionResponse:
        """
        Generates a chat completion for a given request.

        Args:
            req (ChatMessage): A chat message request object.

        Returns:
            ChatCompletionResponse: A single-choice response whose message
            content is the model's answer (or ``"No"`` when the request has no
            tools) and whose ``tool_calls`` list is always empty.

        Note:
            Currently only support vllm inference
        """
        logger.info("[Arch-Intent] - ChatCompletion")

        if not req.tools:
            # No tools (empty or missing): answer `No` without calling the model.
            logger.info("No tools found, return `No` as the model response.")
            answer = "No"
        else:
            messages = self._process_messages(
                req.messages, req.tools, self.extra_instruction
            )
            logger.info(f"[request to arch-fc (intent)]: {json.dumps(messages)}")

            # NOTE(review): this is a synchronous client call inside a
            # coroutine; it is awaited nowhere and will hold the event loop
            # for the duration of the request.
            raw_response = self.client.chat.completions.create(
                messages=messages,
                model=self.model_name,
                stream=False,
                extra_body=self.generation_params,
            )
            logger.info(f"[response]: {json.dumps(raw_response.model_dump())}")
            answer = raw_response.choices[0].message.content

        return ChatCompletionResponse(
            choices=[Choice(message=Message(content=answer, tool_calls=[]))],
            model=self.model_name,
        )
# =============================================================================================================
# ==============================================================================================================================================
class ArchFunctionConfig:
TASK_PROMPT = textwrap.dedent(
"""
You are a helpful assistant.
TASK_PROMPT = (
"You are a helpful assistant designed to assist with the user query by making one or more function calls if needed."
"\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>\n{tools}\n</tools>"
"\n\nYour task is to decide which functions are needed and collect missing parameters if necessary."
)
Today's date: {}
""".format(
utils.get_today_date()
)
).strip()
TOOL_PROMPT_TEMPLATE = textwrap.dedent(
"""
# Tools
You may call one or more functions to assist with the user query.
You are provided with function signatures within <tools></tools> XML tags:
<tools>
{tool_text}
</tools>
"""
).strip()
FORMAT_PROMPT = textwrap.dedent(
"""
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
<tool_call>
{"name": <function-name>, "arguments": <args-json-object>}
</tool_call>
"""
).strip()
FORMAT_PROMPT = (
"\n\nBased on your analysis, provide your response in one of the following JSON formats:"
'\n1. If no functions are needed:\n```json\n{"response": "Your response text here"}\n```'
'\n2. If functions are needed but some required parameters are missing:\n```json\n{"required_functions": ["func_name1", "func_name2", ...], "clarification": "Text asking for missing parameters"}\n```'
'\n3. If functions are needed and all required parameters are available:\n```json\n{"tool_calls": [{"name": "func_name1", "arguments": {"argument1": "value1", "argument2": "value2"}},... (more tool calls as required)]}\n```'
)
GENERATION_PARAMS = {
"temperature": 0.6,
"temperature": 0.1,
"top_p": 1.0,
"top_k": 10,
"max_tokens": 1024,
@ -203,34 +48,9 @@ class ArchFunctionConfig:
"top_logprobs": 10,
}
PREFILL_CONFIG = {
"prefill_params": {
"continue_final_message": True,
"add_generation_prompt": False,
},
"prefill_prefix": [
"May",
"Could",
"Sure",
"Definitely",
"Certainly",
"Of course",
"Can",
],
}
SUPPORT_DATA_TYPES = ["int", "float", "bool", "str", "list", "tuple", "set", "dict"]
class ArchAgentConfig(ArchFunctionConfig):
    """Arch-Function settings with generation parameters replaced for agent use.

    Everything except ``GENERATION_PARAMS`` is inherited from
    ``ArchFunctionConfig``; this mapping sets a near-zero temperature, a stop
    token id and log-probability reporting.
    """

    GENERATION_PARAMS = dict(
        temperature=0.01,
        stop_token_ids=[151645],
        logprobs=True,
        top_logprobs=10,
    )
class ArchFunctionHandler(ArchBaseHandler):
def __init__(
self,
@ -251,13 +71,17 @@ class ArchFunctionHandler(ArchBaseHandler):
client,
model_name,
config.TASK_PROMPT,
config.TOOL_PROMPT_TEMPLATE,
config.FORMAT_PROMPT,
config.GENERATION_PARAMS,
)
self.prefill_params = config.PREFILL_CONFIG["prefill_params"]
self.prefill_prefix = config.PREFILL_CONFIG["prefill_prefix"]
self.generation_params = self.generation_params | {
"continue_final_message": True,
"add_generation_prompt": False,
}
self.default_prefix = '```json\n{"'
self.clarify_prefix = '```json\n{"required_functions":'
self.hallucination_state = None
@ -280,7 +104,7 @@ class ArchFunctionHandler(ArchBaseHandler):
str: A string representation of converted tools.
"""
converted = [json.dumps(tool) for tool in tools]
converted = [json.dumps(tool["function"], ensure_ascii=False) for tool in tools]
return "\n".join(converted)
def _fix_json_string(self, json_str: str) -> str:
@ -328,10 +152,14 @@ class ArchFunctionHandler(ArchBaseHandler):
unmatched_opening = stack.pop()
fixed_str += opening_bracket[unmatched_opening]
# Attempt to parse the corrected string to ensure its valid JSON
return fixed_str.replace("'", '"')
try:
fixed_str = json.loads(fixed_str)
except Exception:
fixed_str = json.loads(fixed_str.replace("'", '"'))
def _extract_tool_calls(self, content: str) -> Dict[str, any]:
return json.dumps(fixed_str)
def _parse_model_response(self, content: str) -> Dict[str, any]:
"""
Extracts tool call information from a given string.
@ -340,49 +168,55 @@ class ArchFunctionHandler(ArchBaseHandler):
Returns:
Dict: A dictionary of extraction, including:
- "result": A list of tool call dictionaries.
- "status": A boolean indicating if the extraction was valid.
- "message": An error message or exception if extraction failed.
- "required_functions": A list of detected intents.
- "clarification": Text to collect missing parameters
- "tool_calls": A list of tool call dictionaries.
- "is_valid": A boolean indicating if the extraction was valid.
- "error_message": An error message or exception if parsing failed.
"""
tool_calls, is_valid, error_message = [], True, ""
response_dict = {
"raw_response": [],
"response": [],
"required_functions": [],
"clarification": "",
"tool_calls": [],
"is_valid": True,
"error_message": "",
}
flag = False
for line in content.split("\n"):
if not is_valid:
break
try:
if content.startswith("```") and content.endswith("```"):
content = content.strip("```").strip()
if content.startswith("json"):
content = content[4:].strip()
if "<tool_call>" == line:
flag = True
elif "</tool_call>" == line:
flag = False
else:
if flag:
try:
tool_content = json.loads(line)
except Exception as e:
fixed_content = self._fix_json_string(line)
try:
tool_content = json.loads(fixed_content)
except Exception:
is_valid, error_message = False, e
break
content = self._fix_json_string(content)
response_dict["raw_response"] = f"```json\n{content}\n```"
tool = {
model_response = json.loads(content)
response_dict["response"] = model_response.get("response", "")
response_dict["required_functions"] = model_response.get(
"required_functions", []
)
response_dict["clarification"] = model_response.get("clarification", "")
for tool_call in model_response.get("tool_calls", []):
response_dict["tool_calls"].append(
{
"id": f"call_{random.randint(1000, 10000)}",
"type": "function",
"function": {
"name": tool_content["name"],
"name": tool_call.get("name", ""),
"arguments": tool_call.get("arguments", {}),
},
}
if "arguments" in tool_content:
tool["function"]["arguments"] = tool_content["arguments"]
)
except Exception as e:
response_dict["is_valid"] = False
response_dict["error_message"] = f"Fail to parse model responses: {e}"
tool_calls.append(tool)
flag = False
return {"result": tool_calls, "status": is_valid, "message": error_message}
return response_dict
def _convert_data_type(self, value: str, target_type: str):
# TODO: Add more conversion rules as needed
@ -414,36 +248,37 @@ class ArchFunctionHandler(ArchBaseHandler):
- "message": An error message.
"""
is_valid, invalid_tool_call, error_message = True, None, ""
verification_dict = {
"is_valid": True,
"invalid_tool_call": {},
"error_message": "",
}
functions = {}
for tool in tools:
if tool["type"] == "function":
functions[tool["function"]["name"]] = tool["function"]["parameters"]
functions[tool["function"]["name"]] = tool["function"]["parameters"]
for tool_call in tool_calls:
if not is_valid:
if not verification_dict["is_valid"]:
break
func_name = tool_call["function"]["name"]
func_args = tool_call["function"].get("arguments")
if not func_args:
func_args = {}
func_args = tool_call["function"]["arguments"]
# Check whether the function is available or not
if func_name not in functions:
is_valid = False
invalid_tool_call = tool_call
error_message = f"{func_name} is not defined!"
break
verification_dict["is_valid"] = False
verification_dict["invalid_tool_call"] = tool_call
verification_dict["error_message"] = f"{func_name} is not available!"
else:
# Check if all the required parameters can be found in the tool calls
for required_param in functions[func_name].get("required", []):
if required_param not in func_args:
is_valid = False
invalid_tool_call = tool_call
error_message = f"`{required_param}` is required by the function `{func_name}` but not found in the tool call!"
verification_dict["is_valid"] = False
verification_dict["invalid_tool_call"] = tool_call
verification_dict[
"error_message"
] = f"`{required_param}` is required by the function `{func_name}` but not found in the tool call!"
break
# Verify the data type of each parameter in the tool calls
@ -453,9 +288,11 @@ class ArchFunctionHandler(ArchBaseHandler):
logger.info(func_args)
for param_name in func_args:
if param_name not in function_properties:
is_valid = False
invalid_tool_call = tool_call
error_message = f"Parameter `{param_name}` is not defined in the function `{func_name}`."
verification_dict["is_valid"] = False
verification_dict["invalid_tool_call"] = tool_call
verification_dict[
"error_message"
] = f"Parameter `{param_name}` is not defined in the function `{func_name}`."
break
else:
param_value = func_args[param_name]
@ -469,22 +306,22 @@ class ArchFunctionHandler(ArchBaseHandler):
param_value, data_type
)
if not isinstance(param_value, data_type):
is_valid = False
invalid_tool_call = tool_call
error_message = f"Parameter `{param_name}` is expected to have the data type `{data_type}`, got `{type(param_value)}`."
verification_dict["is_valid"] = False
verification_dict["invalid_tool_call"] = tool_call
verification_dict[
"error_message"
] = f"Parameter `{param_name}` is expected to have the data type `{data_type}`, got `{type(param_value)}`."
break
else:
error_message = (
f"Data type `{target_type}` is not supported."
)
verification_dict["is_valid"] = False
verification_dict["invalid_tool_call"] = tool_call
verification_dict[
"error_message"
] = f"Data type `{target_type}` is not supported."
return {
"status": is_valid,
"invalid_tool_call": invalid_tool_call,
"message": error_message,
}
return verification_dict
def _add_prefill_message(self, messages: List[Dict[str, str]]):
def _prefill_message(self, messages: List[Dict[str, str]], prefill_message):
"""
Update messages and generation params for prompt prefilling
@ -494,29 +331,7 @@ class ArchFunctionHandler(ArchBaseHandler):
Returns:
prefill_messages (List[Dict[str, str]]): A list of messages.
"""
return messages + [
{
"role": "assistant",
"content": random.choice(self.prefill_prefix),
}
]
def _engage_parameter_gathering(self, messages: List[Dict[str, str]]):
"""
Engage parameter gathering for tool calls
"""
# TODO: log engaging parameter gathering
prefill_response = self.client.chat.completions.create(
messages=self._add_prefill_message(messages),
model=self.model_name,
extra_body={
**self.generation_params,
**self.prefill_params,
},
)
return prefill_response
return messages + [{"role": "assistant", "content": prefill_message}]
@override
async def chat_completion(self, req: ChatMessage) -> ChatCompletionResponse:
@ -544,7 +359,7 @@ class ArchFunctionHandler(ArchBaseHandler):
# always enable `stream=True` to collect model responses
response = self.client.chat.completions.create(
messages=messages,
messages=self._prefill_message(messages, self.default_prefix),
model=self.model_name,
stream=True,
extra_body=self.generation_params,
@ -565,72 +380,114 @@ class ArchFunctionHandler(ArchBaseHandler):
has_tool_calls, has_hallucination = None, False
for _ in self.hallucination_state:
# check if the first token is <tool_call>
if len(self.hallucination_state.tokens) > 0 and has_tool_calls is None:
if self.hallucination_state.tokens[0] == "<tool_call>":
# check if the model response starts with tool calls; we do it after 5 tokens because we only check the first part of the response.
if len(self.hallucination_state.tokens) > 5 and has_tool_calls is None:
content = "".join(self.hallucination_state.tokens)
if "tool_calls" in content:
has_tool_calls = True
else:
has_tool_calls = False
break
# if the model is hallucinating, start parameter gathering
if self.hallucination_state.hallucination is True:
has_hallucination = True
break
if has_tool_calls:
if has_hallucination:
# start prompt prefilling if hallucination is found in tool calls
logger.info(
f"[Hallucination]: {self.hallucination_state.error_message}"
)
prefill_response = self._engage_parameter_gathering(messages)
model_response = prefill_response.choices[0].message.content
else:
model_response = "".join(self.hallucination_state.tokens)
if has_tool_calls and has_hallucination:
# start prompt prefilling if hallucination is found in tool calls
logger.info(
f"[Hallucination]: {self.hallucination_state.error_message}"
)
response = self.client.chat.completions.create(
messages=self._prefill_message(messages, self.clarify_prefix),
model=self.model_name,
stream=False,
extra_body=self.generation_params,
)
model_response = response.choices[0].message.content
else:
# start parameter gathering if the model is not generating tool calls
prefill_response = self._engage_parameter_gathering(messages)
model_response = prefill_response.choices[0].message.content
model_response = "".join(self.hallucination_state.tokens)
# Extract tool calls from model response
extracted = self._extract_tool_calls(model_response)
response_dict = self._parse_model_response(model_response)
logger.info(f"[arch-fc]: raw model response: {response_dict['raw_response']}")
if extracted["status"]:
# Response with tool calls
if len(extracted["result"]):
verified = {}
if use_agent_orchestrator:
# skip tool call verification if using agent orchestrator
verified = {"status": True, "message": ""}
else:
verified = self._verify_tool_calls(
tools=req.tools, tool_calls=extracted["result"]
)
if verified["status"]:
logger.info(
f"[Tool calls]: {json.dumps([tool_call['function'] for tool_call in extracted['result']])}"
)
model_response = Message(content="", tool_calls=extracted["result"])
else:
logger.error(f"Invalid tool call - {verified['message']}")
# Response without tool calls
# General model response
if response_dict.get("response", ""):
model_message = Message(content="", tool_calls=[])
# Parameter gathering
elif response_dict.get("required_functions", []):
if not use_agent_orchestrator:
clarification = response_dict.get("clarification", "")
model_message = Message(content=clarification, tool_calls=[])
else:
model_response = Message(content=model_response, tool_calls=[])
# Response with tool calls but contain errors
model_message = Message(content="", tool_calls=[])
# Function Calling
elif response_dict.get("tool_calls", []):
if response_dict["is_valid"]:
if not use_agent_orchestrator:
verification_dict = self._verify_tool_calls(
tools=req.tools, tool_calls=response_dict["tool_calls"]
)
if verification_dict["is_valid"]:
logger.info(
f"[Tool calls]: {json.dumps([tool_call['function'] for tool_call in response_dict['tool_calls']])}"
)
model_message = Message(
content="", tool_calls=response_dict["tool_calls"]
)
else:
logger.error(
f"Invalid tool call - {verification_dict['error_message']}"
)
model_message = Message(content="", tool_calls=[])
else:
# skip tool call verification if using agent orchestrator
logger.info(
f"[Tool calls]: {json.dumps([tool_call['function'] for tool_call in response_dict['tool_calls']])}"
)
model_message = Message(
content="", tool_calls=response_dict["tool_calls"]
)
else:
# Response with tool calls but invalid
model_message = Message(content="", tool_calls=[])
# Response not in the desired format
else:
logger.error(f"Tool call extraction error - {extracted['message']}")
logger.error(f"Invalid model response - {model_response}")
model_message = Message(content="", tool_calls=[])
chat_completion_response = ChatCompletionResponse(
choices=[Choice(message=model_response)], model=self.model_name
choices=[Choice(message=model_message)],
model=self.model_name,
metadata={"x-arch-fc-model-response": response_dict["raw_response"]},
role="assistant",
)
logger.info(f"[response]: {json.dumps(chat_completion_response.model_dump())}")
logger.info(
f"[response arch-fc]: {json.dumps(chat_completion_response.model_dump(exclude_none=True))}"
)
return chat_completion_response
# ==============================================================================================================================================
class ArchAgentConfig(ArchFunctionConfig):
    """Arch-Function settings with generation parameters replaced for agent use.

    Everything except ``GENERATION_PARAMS`` is inherited from
    ``ArchFunctionConfig``; the mapping below is a full replacement, not a
    merge, so every key the agent needs is spelled out here.
    """

    GENERATION_PARAMS = dict(
        temperature=0.01,
        top_p=1.0,
        top_k=10,
        max_tokens=1024,
        stop_token_ids=[151645],
        logprobs=True,
        top_logprobs=10,
    )
class ArchAgentHandler(ArchFunctionHandler):
def __init__(self, client: OpenAI, model_name: str, config: ArchAgentConfig):
super().__init__(client, model_name, config)
@ -657,7 +514,7 @@ class ArchAgentHandler(ArchFunctionHandler):
):
tool_copy = copy.deepcopy(tool)
del tool_copy["function"]["parameters"]
converted.append(json.dumps(tool_copy))
converted.append(json.dumps(tool_copy["function"], ensure_ascii=False))
else:
converted.append(json.dumps(tool))
converted.append(json.dumps(tool["function"], ensure_ascii=False))
return "\n".join(converted)

View file

@ -13,16 +13,15 @@ from src.commons.utils import get_model_server_logger
logger = get_model_server_logger()
# constants
FUNC_NAME_START_PATTERN = ('<tool_call>\n{"name":"', "<tool_call>\n{'name':'")
FUNC_NAME_START_PATTERN = ('{"name":"', "{'name':'")
FUNC_NAME_END_TOKEN = ('",', "',")
TOOL_CALL_TOKEN = "<tool_call>"
END_TOOL_CALL_TOKEN = "</tool_call>"
END_TOOL_CALL_TOKEN = "}}"
FIRST_PARAM_NAME_START_PATTERN = ('"arguments":{"', "'arguments':{'")
PARAMETER_NAME_END_TOKENS = ('":', ':"', "':", ":'")
PARAMETER_NAME_START_PATTERN = (',"', ",'")
PARAMETER_NAME_END_TOKENS = ('":', ':"', "':", ":'", '":"', "':'")
PARAMETER_NAME_START_PATTERN = ('","', "','")
PARAMETER_VALUE_START_PATTERN = ('":', "':")
PARAMETER_VALUE_END_TOKEN = ('",', "}}\n", "',")
PARAMETER_VALUE_END_TOKEN = ('",', '"}')
BRACKETS = {"(": ")", "{": "}", "[": "]"}
@ -37,16 +36,9 @@ class MaskToken(Enum):
HALLUCINATION_THRESHOLD_DICT = {
MaskToken.TOOL_CALL.value: {
"entropy": 0.35,
"varentropy": 1.7,
"probability": 0.8,
},
MaskToken.PARAMETER_VALUE.value: {
"entropy": 0.28,
"varentropy": 1.4,
"probability": 0.8,
},
"entropy": 0.0001,
"varentropy": 0.0001,
"probability": 0.8,
}
@ -160,6 +152,7 @@ class HallucinationState:
self._process_function(function)
self.open_bracket = False
self.bracket = None
self.function_name = ""
self.check_parameter_name = {}
self.HALLUCINATION_THRESHOLD_DICT = HALLUCINATION_THRESHOLD_DICT
@ -208,22 +201,20 @@ class HallucinationState:
r = next(self.response_iterator)
if hasattr(r.choices[0].delta, "content"):
token_content = r.choices[0].delta.content
if token_content:
if token_content != "":
try:
logprobs = [
p.logprob
for p in r.choices[0].logprobs.content[0].top_logprobs
]
except Exception as e:
raise ValueError(
f"Error extracting logprobs from response: {e}"
)
if token_content == END_TOOL_CALL_TOKEN:
self._reset_parameters()
else:
self.append_and_check_token_hallucination(
token_content, logprobs
)
except Exception as e:
self.append_and_check_token_hallucination(
token_content, [None]
)
return token_content
except StopIteration:
raise StopIteration
@ -234,12 +225,12 @@ class HallucinationState:
Detects hallucinations based on the token type and log probabilities.
"""
content = "".join(self.tokens).replace(" ", "")
if self.tokens[-1] == TOOL_CALL_TOKEN:
self.mask.append(MaskToken.TOOL_CALL)
self._check_logprob()
# Function name extraction logic
# If the state is function name and the token is not an end token, add to the mask
if content.endswith(END_TOOL_CALL_TOKEN):
self._reset_parameters()
if self.state == "function_name":
if self.tokens[-1] not in FUNC_NAME_END_TOKEN:
self.mask.append(MaskToken.FUNCTION_NAME)
@ -359,7 +350,7 @@ class HallucinationState:
if check_threshold(
entropy,
varentropy,
self.HALLUCINATION_THRESHOLD_DICT[self.mask[-1].value],
self.HALLUCINATION_THRESHOLD_DICT,
):
self.hallucination = True
self.error_message = f"token '{self.tokens[-1]}' is uncertain. Generated response:\n{''.join(self.tokens)}"

View file

@ -1,4 +1,5 @@
import json
import src.commons.utils as utils
from openai import OpenAI
from pydantic import BaseModel
@ -56,7 +57,6 @@ class ArchBaseHandler:
client: OpenAI,
model_name: str,
task_prompt: str,
tool_prompt_template: str,
format_prompt: str,
generation_params: Dict,
):
@ -67,7 +67,6 @@ class ArchBaseHandler:
client (OpenAI): An OpenAI client instance.
model_name (str): Name of the model to use.
task_prompt (str): The main task prompt for the system.
tool_prompt (str): A prompt to describe tools.
format_prompt (str): A prompt specifying the desired output format.
generation_params (Dict): Generation parameters for the model.
"""
@ -75,7 +74,6 @@ class ArchBaseHandler:
self.model_name = model_name
self.task_prompt = task_prompt
self.tool_prompt_template = tool_prompt_template
self.format_prompt = format_prompt
self.generation_params = generation_params
@ -105,13 +103,11 @@ class ArchBaseHandler:
str: A formatted system prompt.
"""
tool_text = self._convert_tools(tools)
today_date = utils.get_today_date()
tools = self._convert_tools(tools)
system_prompt = (
self.task_prompt
+ "\n\n"
+ self.tool_prompt_template.format(tool_text=tool_text)
+ "\n\n"
self.task_prompt.format(today_date=today_date, tools=tools)
+ self.format_prompt
)
@ -146,7 +142,7 @@ class ArchBaseHandler:
{"role": "system", "content": self._format_system_prompt(tools)}
)
for message in messages:
for idx, message in enumerate(messages):
role, content, tool_calls = (
message.role,
message.content,
@ -162,9 +158,24 @@ class ArchBaseHandler:
if metadata.get("optimize_context_window", "false").lower() == "true":
content = f"<tool_response>\n\n</tool_response>"
else:
content = (
f"<tool_response>\n{json.dumps(content)}\n</tool_response>"
# sample response below
# "content": "<tool_response>\n{'name': 'get_stock_price', 'result': '$196.66'}\n</tool_response>"
# msg[idx-1] contains tool call = '{"tool_calls": [{"name": "currency_exchange", "arguments": {"currency_symbol": "NZD"}}]}'
tool_call_msg = messages[idx - 1].content
if tool_call_msg.startswith("```") and tool_call_msg.endswith(
"```"
):
tool_call_msg = tool_call_msg.strip("```").strip()
if tool_call_msg.startswith("json"):
tool_call_msg = tool_call_msg[4:].strip()
func_name = json.loads(tool_call_msg)["tool_calls"][0].get(
"name", "no_name"
)
tool_response = {
"name": func_name,
"result": content,
}
content = f"<tool_response>\n{json.dumps(tool_response)}\n</tool_response>"
processed_messages.append({"role": role, "content": content})

View file

@ -71,67 +71,58 @@ async def models():
@app.post("/function_calling")
async def function_calling(req: ChatMessage, res: Response):
logger.info("[Endpoint: /function_calling]")
logger.info(f"[request body]: {json.dumps(req.model_dump())}")
logger.info(f"[request body]: {json.dumps(req.model_dump(exclude_none=True))}")
final_response: ChatCompletionResponse = None
error_messages = None
use_agent_orchestrator = req.metadata.get("use_agent_orchestrator", False)
logger.info(f"Use agent orchestrator: {use_agent_orchestrator}")
try:
intent_detected = False
use_agent_orchestrator = req.metadata.get("use_agent_orchestrator", False)
logger.info(f"Use agent orchestrator: {use_agent_orchestrator}")
if not use_agent_orchestrator:
intent_start_time = time.perf_counter()
intent_response = await handler_map["Arch-Intent"].chat_completion(req)
intent_latency = time.perf_counter() - intent_start_time
intent_detected = handler_map["Arch-Intent"].detect_intent(intent_response)
handler_name = "Arch-Agent" if use_agent_orchestrator else "Arch-Function"
model_handler: ArchFunctionHandler = handler_map[handler_name]
if use_agent_orchestrator or intent_detected:
# TODO: measure agreement between intent detection and function calling
try:
function_start_time = time.perf_counter()
handler_name = (
"Arch-Agent" if use_agent_orchestrator else "Arch-Function"
start_time = time.perf_counter()
final_response = await model_handler.chat_completion(req)
latency = time.perf_counter() - start_time
if not final_response.metadata:
final_response.metadata = {}
# Parameter gathering for detected intents
if final_response.choices[0].message.content:
final_response.metadata["function_latency"] = str(round(latency * 1000, 3))
# Function Calling
elif final_response.choices[0].message.tool_calls:
final_response.metadata["function_latency"] = str(round(latency * 1000, 3))
if not use_agent_orchestrator:
final_response.metadata["hallucination"] = str(
model_handler.hallucination_state.hallucination
)
function_calling_handler: ArchFunctionHandler = handler_map[
handler_name
]
final_response = await function_calling_handler.chat_completion(req)
function_latency = time.perf_counter() - function_start_time
final_response.metadata = {
"function_latency": str(round(function_latency * 1000, 3)),
}
if not use_agent_orchestrator:
final_response.metadata["intent_latency"] = str(
round(intent_latency * 1000, 3)
)
final_response.metadata["hallucination"] = str(
function_calling_handler.hallucination_state.hallucination
)
except ValueError as e:
res.statuscode = 503
error_messages = (
f"[{handler_name}] - Error in tool call extraction: {e}"
)
except StopIteration as e:
res.statuscode = 500
error_messages = f"[{handler_name}] - Error in hallucination check: {e}"
except Exception as e:
res.status_code = 500
error_messages = f"[{handler_name}] - Error in ChatCompletion: {e}"
raise
# No intent detected
else:
# no intent matched
intent_response.metadata = {
"intent_latency": str(round(intent_latency * 1000, 3)),
}
final_response = intent_response
final_response.metadata["intent_latency"] = str(round(latency * 1000, 3))
if not use_agent_orchestrator:
final_response.metadata["intent_latency"] = str(round(latency * 1000, 3))
final_response.metadata["hallucination"] = str(
model_handler.hallucination_state.hallucination
)
except ValueError as e:
res.statuscode = 503
error_messages = f"[{handler_name}] - Error in tool call extraction: {e}"
raise
except StopIteration as e:
res.statuscode = 500
error_messages = f"[{handler_name}] - Error in hallucination check: {e}"
raise
except Exception as e:
res.status_code = 500
error_messages = f"[Arch-Intent] - Error in ChatCompletion: {e}"
error_messages = f"[{handler_name}] - Error in ChatCompletion: {e}"
raise
if error_messages is not None:
@ -144,7 +135,7 @@ async def function_calling(req: ChatMessage, res: Response):
@app.post("/guardrails")
async def guardrails(req: GuardRequest, res: Response, max_num_words=300):
logger.info("[Endpoint: /guardrails] - Gateway")
logger.info(f"[request body]: {json.dumps(req.model_dump())}")
logger.info(f"[request body]: {json.dumps(req.model_dump(exclude_none=True))}")
final_response: GuardResponse = None
error_messages = None