add preliminary support for llm agents

This commit is contained in:
Adil Hafeez 2025-03-12 15:45:05 -07:00
parent ffb8566c36
commit 8104eac596
No known key found for this signature in database
GPG key ID: 9B18EF7691369645
17 changed files with 1508 additions and 79 deletions

View file

@ -3,6 +3,8 @@ from openai import OpenAI
from src.commons.utils import get_model_server_logger
from src.core.guardrails import get_guardrail_handler
from src.core.function_calling import (
ArchAgentConfig,
ArchAgentHandler,
ArchIntentConfig,
ArchIntentHandler,
ArchFunctionConfig,
@ -18,10 +20,14 @@ logger = get_model_server_logger()
ARCH_ENDPOINT = os.getenv("ARCH_ENDPOINT", "https://archfc.katanemo.dev/v1")
ARCH_API_KEY = "EMPTY"
ARCH_CLIENT = OpenAI(base_url=ARCH_ENDPOINT, api_key=ARCH_API_KEY)
ARCH_AGENT_CLIENT = ARCH_CLIENT
# ARCH_AGENT_CLIENT = OpenAI(api_key=os.getenv("OPENAI_API_KEY", "EMPTY"))
# Define model names
ARCH_INTENT_MODEL_ALIAS = "Arch-Intent"
ARCH_FUNCTION_MODEL_ALIAS = "Arch-Function"
ARCH_AGENT_MODEL_ALIAS = ARCH_FUNCTION_MODEL_ALIAS
# ARCH_AGENT_MODEL_ALIAS = "gpt-4o-mini"
ARCH_GUARD_MODEL_ALIAS = "katanemo/Arch-Guard"
# Define model handlers
@ -32,5 +38,8 @@ handler_map = {
"Arch-Function": ArchFunctionHandler(
ARCH_CLIENT, ARCH_FUNCTION_MODEL_ALIAS, ArchFunctionConfig
),
"Arch-Agent": ArchAgentHandler(
ARCH_AGENT_CLIENT, ARCH_AGENT_MODEL_ALIAS, ArchAgentConfig
),
"Arch-Guard": get_guardrail_handler(ARCH_GUARD_MODEL_ALIAS),
}

View file

@ -1,4 +1,5 @@
import ast
import copy
import json
import random
import builtins
@ -221,6 +222,41 @@ class ArchFunctionConfig:
SUPPORT_DATA_TYPES = ["int", "float", "bool", "str", "list", "tuple", "set", "dict"]
class ArchAgentConfig(ArchFunctionConfig):
TASK_PROMPT = textwrap.dedent(
"""
You will be given a list of tools and a user request. Your task is to match the user request with the most appropriate tool(s) based on the tool descriptions. Do not explain your reasoning, just provide the tool(s) that best match the user request.
"""
).strip()
TOOL_PROMPT_TEMPLATE = textwrap.dedent(
"""
You will be presented with a list of tools and their descriptions:
<tools>
{tool_text}
</tools>
"""
).strip()
FORMAT_PROMPT = textwrap.dedent(
"""
Provide your answer in the following format:
For each function call, return a json object with function name <tool_call></tool_call> XML tags:
<tool_call>
{"name": <function-name>}
</tool_call>
"""
).strip()
GENERATION_PARAMS = {
"temperature": 0.01,
"stop_token_ids": [151645],
"logprobs": True,
"top_logprobs": 10,
}
class ArchFunctionHandler(ArchBaseHandler):
def __init__(
self,
@ -358,16 +394,17 @@ class ArchFunctionHandler(ArchBaseHandler):
is_valid, error_message = False, e
break
tool_calls.append(
{
"id": f"call_{random.randint(1000, 10000)}",
"type": "function",
"function": {
"name": tool_content["name"],
"arguments": tool_content["arguments"],
},
}
)
tool = {
"id": f"call_{random.randint(1000, 10000)}",
"type": "function",
"function": {
"name": tool_content["name"],
},
}
if "arguments" in tool_content:
tool["function"]["arguments"] = tool_content["arguments"]
tool_calls.append(tool)
flag = False
@ -415,7 +452,9 @@ class ArchFunctionHandler(ArchBaseHandler):
break
func_name = tool_call["function"]["name"]
func_args = tool_call["function"]["arguments"]
func_args = tool_call["function"].get("arguments")
if not func_args:
func_args = {}
# Check whether the function is available or not
if func_name not in functions:
@ -430,12 +469,14 @@ class ArchFunctionHandler(ArchBaseHandler):
if required_param not in func_args:
is_valid = False
invalid_tool_call = tool_call
error_message = f"`{required_param}` is requiried by the function `{func_name}` but not found in the tool call!"
error_message = f"`{required_param}` is required by the function `{func_name}` but not found in the tool call!"
break
# Verify the data type of each parameter in the tool calls
function_properties = functions[func_name]["properties"]
logger.info("== func_args ==")
logger.info(func_args)
for param_name in func_args:
if param_name not in function_properties:
is_valid = False
@ -523,52 +564,60 @@ class ArchFunctionHandler(ArchBaseHandler):
req.messages, req.tools, metadata=req.metadata
)
logger.info(f"[request to arch-fc]: {json.dumps(messages)}")
logger.info(
f"[request to arch-fc]: model: {self.model_name}, extra_body: {self.generation_params}, body: {json.dumps(messages)}"
)
# always enable `stream=True` to collect model responses
response = self.client.chat.completions.create(
messages=messages,
model=self.model_name,
stream=True,
extra_body=self.generation_params,
)
# initialize the hallucination handler, which is an iterator
self.hallucination_state = HallucinationState(
response_iterator=response, function=req.tools
extra_body={"temperature": 0.01, "logprobs": True},
)
use_agent_orchestrator = req.metadata.get("use_agent_orchestrator", False)
model_response = ""
if use_agent_orchestrator:
for chunk in response:
if len(chunk.choices) > 0 and chunk.choices[0].delta.content:
model_response += chunk.choices[0].delta.content
logger.info(f"[Agent Orchestrator]: response received: {model_response}")
else:
# initialize the hallucination handler, which is an iterator
self.hallucination_state = HallucinationState(
response_iterator=response, function=req.tools
)
has_tool_calls, has_hallucination = None, False
for _ in self.hallucination_state:
# check if the first token is <tool_call>
if len(self.hallucination_state.tokens) > 0 and has_tool_calls is None:
if self.hallucination_state.tokens[0] == "<tool_call>":
has_tool_calls = True
else:
has_tool_calls = False
has_tool_calls, has_hallucination = None, False
for _ in self.hallucination_state:
# check if the first token is <tool_call>
if len(self.hallucination_state.tokens) > 0 and has_tool_calls is None:
if self.hallucination_state.tokens[0] == "<tool_call>":
has_tool_calls = True
else:
has_tool_calls = False
break
# if the model is hallucinating, start parameter gathering
if self.hallucination_state.hallucination is True:
has_hallucination = True
break
# if the model is hallucinating, start parameter gathering
if self.hallucination_state.hallucination is True:
has_hallucination = True
break
if has_tool_calls:
if has_hallucination:
# start prompt prefilling if hallcuination is found in tool calls
logger.info(
f"[Hallucination]: {self.hallucination_state.error_message}"
)
if has_tool_calls:
if has_hallucination:
# start prompt prefilling if hallcuination is found in tool calls
logger.info(
f"[Hallucination]: {self.hallucination_state.error_message}"
)
prefill_response = self._engage_parameter_gathering(messages)
model_response = prefill_response.choices[0].message.content
else:
model_response = "".join(self.hallucination_state.tokens)
else:
# start parameter gathering if the model is not generating tool calls
prefill_response = self._engage_parameter_gathering(messages)
model_response = prefill_response.choices[0].message.content
else:
model_response = "".join(self.hallucination_state.tokens)
else:
# start parameter gathering if the model is not generating tool calls
prefill_response = self._engage_parameter_gathering(messages)
model_response = prefill_response.choices[0].message.content
# Extract tool calls from model response
extracted = self._extract_tool_calls(model_response)
@ -601,3 +650,36 @@ class ArchFunctionHandler(ArchBaseHandler):
logger.info(f"[response]: {json.dumps(chat_completion_response.model_dump())}")
return chat_completion_response
# override ArchFunctionHandler
class ArchAgentHandler(ArchFunctionHandler):
def __init__(self, client: OpenAI, model_name: str, config: ArchAgentConfig):
super().__init__(client, model_name, config)
@override
def _convert_tools(self, tools: List[Dict[str, Any]]) -> str:
"""
Converts a list of tools into JSON format.
Args:
tools (List[Dict[str, Any]]): A list of tools represented as dictionaries.
Returns:
str: A string representation of converted tools.
"""
converted = []
# delete parameters key if its empty in tool
for tool in tools:
if (
"parameters" in tool["function"]
and "properties" in tool["function"]["parameters"]
and not tool["function"]["parameters"]["properties"]
):
tool_copy = copy.deepcopy(tool)
del tool_copy["function"]["parameters"]
converted.append(json.dumps(tool_copy))
else:
converted.append(json.dumps(tool))
return "\n".join(converted)

View file

@ -5,6 +5,7 @@ import logging
import src.commons.utils as utils
from src.commons.globals import ARCH_ENDPOINT, handler_map
from src.core.function_calling import ArchFunctionHandler
from src.core.utils.model_utils import (
ChatMessage,
ChatCompletionResponse,
@ -76,33 +77,51 @@ async def function_calling(req: ChatMessage, res: Response):
error_messages = None
try:
intent_start_time = time.perf_counter()
intent_response = await handler_map["Arch-Intent"].chat_completion(req)
intent_latency = time.perf_counter() - intent_start_time
intent_detected = False
use_agent_orchestrator = req.metadata.get("use_agent_orchestrator", False)
logger.info(f"Use agent orchestrator: {use_agent_orchestrator}")
if not use_agent_orchestrator:
intent_start_time = time.perf_counter()
intent_response = await handler_map["Arch-Intent"].chat_completion(req)
intent_latency = time.perf_counter() - intent_start_time
intent_detected = handler_map["Arch-Intent"].detect_intent(intent_response)
if handler_map["Arch-Intent"].detect_intent(intent_response):
if use_agent_orchestrator or intent_detected:
# TODO: measure agreement between intent detection and function calling
try:
function_start_time = time.perf_counter()
final_response = await handler_map["Arch-Function"].chat_completion(req)
handler_name = (
"Arch-Agent" if use_agent_orchestrator else "Arch-Function"
)
function_calling_handler: ArchFunctionHandler = handler_map[
handler_name
]
final_response = await function_calling_handler.chat_completion(req)
function_latency = time.perf_counter() - function_start_time
final_response.metadata = {
"intent_latency": str(round(intent_latency * 1000, 3)),
"function_latency": str(round(function_latency * 1000, 3)),
"hallucination": str(
handler_map["Arch-Function"].hallucination_state.hallucination
),
}
if not use_agent_orchestrator:
final_response.metadata["intent_latency"] = str(
round(intent_latency * 1000, 3)
)
final_response.metadata["hallucination"] = str(
function_calling_handler.hallucination_state.hallucination
)
except ValueError as e:
res.statuscode = 503
error_messages = f"[Arch-Function] - Error in tool call extraction: {e}"
error_messages = (
f"[{handler_name}] - Error in tool call extraction: {e}"
)
except StopIteration as e:
res.statuscode = 500
error_messages = f"[Arch-Function] - Error in hallucination check: {e}"
error_messages = f"[{handler_name}] - Error in hallucination check: {e}"
except Exception as e:
res.status_code = 500
error_messages = f"[Arch-Function] - Error in ChatCompletion: {e}"
error_messages = f"[{handler_name}] - Error in ChatCompletion: {e}"
raise
else:
# no intent matched
intent_response.metadata = {
@ -113,6 +132,7 @@ async def function_calling(req: ChatMessage, res: Response):
except Exception as e:
res.status_code = 500
error_messages = f"[Arch-Intent] - Error in ChatCompletion: {e}"
raise
if error_messages is not None:
logger.error(error_messages)