add preliminary support for llm agents (#432)

Adil Hafeez authored 2025-03-19 15:21:34 -07:00, committed by GitHub
parent 8d66fefded
commit 84cd1df7bf
29 changed files with 1388 additions and 121 deletions

View file

@@ -3,6 +3,8 @@ from openai import OpenAI
 from src.commons.utils import get_model_server_logger
 from src.core.guardrails import get_guardrail_handler
 from src.core.function_calling import (
+    ArchAgentConfig,
+    ArchAgentHandler,
     ArchIntentConfig,
     ArchIntentHandler,
     ArchFunctionConfig,
@@ -18,10 +20,12 @@ logger = get_model_server_logger()
 ARCH_ENDPOINT = os.getenv("ARCH_ENDPOINT", "https://archfc.katanemo.dev/v1")
 ARCH_API_KEY = "EMPTY"
 ARCH_CLIENT = OpenAI(base_url=ARCH_ENDPOINT, api_key=ARCH_API_KEY)
+ARCH_AGENT_CLIENT = ARCH_CLIENT

 # Define model names
 ARCH_INTENT_MODEL_ALIAS = "Arch-Intent"
 ARCH_FUNCTION_MODEL_ALIAS = "Arch-Function"
+ARCH_AGENT_MODEL_ALIAS = ARCH_FUNCTION_MODEL_ALIAS
 ARCH_GUARD_MODEL_ALIAS = "katanemo/Arch-Guard"

 # Define model handlers
@@ -32,5 +36,8 @@ handler_map = {
     "Arch-Function": ArchFunctionHandler(
         ARCH_CLIENT, ARCH_FUNCTION_MODEL_ALIAS, ArchFunctionConfig
     ),
+    "Arch-Agent": ArchAgentHandler(
+        ARCH_AGENT_CLIENT, ARCH_AGENT_MODEL_ALIAS, ArchAgentConfig
+    ),
     "Arch-Guard": get_guardrail_handler(ARCH_GUARD_MODEL_ALIAS),
 }
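
For readers skimming the commit, a minimal sketch of how a handler map like this is consumed, assuming only the async chat_completion(req) interface visible in the route changes further down; the route helper itself is hypothetical, not part of this commit.

# Hypothetical dispatch helper over the handler_map defined above.
async def route(model_alias: str, req):
    handler = handler_map.get(model_alias)
    if handler is None:
        raise ValueError(f"unknown model alias: {model_alias}")
    # every handler (Arch-Intent, Arch-Function, Arch-Agent, Arch-Guard)
    # is awaited through the same interface
    return await handler.chat_completion(req)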

View file

@@ -1,4 +1,5 @@
 import ast
+import copy
 import json
 import random
 import builtins
@@ -221,6 +222,15 @@ class ArchFunctionConfig:
     SUPPORT_DATA_TYPES = ["int", "float", "bool", "str", "list", "tuple", "set", "dict"]


+class ArchAgentConfig(ArchFunctionConfig):
+    GENERATION_PARAMS = {
+        "temperature": 0.01,
+        "stop_token_ids": [151645],
+        "logprobs": True,
+        "top_logprobs": 10,
+    }
+
+
 class ArchFunctionHandler(ArchBaseHandler):
     def __init__(
         self,
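
For context on the new GENERATION_PARAMS: temperature, logprobs, and top_logprobs are standard OpenAI-style sampling fields, while stop_token_ids looks like a vLLM-specific extension (151645 appears to be the Qwen <|im_end|> token id, though that is an inference, not stated in the diff). The handler forwards the whole dict through extra_body, as a later hunk shows. A minimal sketch with placeholder endpoint and model:

# Sketch: forwarding ArchAgentConfig-style params via extra_body to an
# OpenAI-compatible server such as vLLM; base_url and model are placeholders.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
stream = client.chat.completions.create(
    model="Arch-Function",
    messages=[{"role": "user", "content": "hello"}],
    stream=True,
    extra_body={
        "temperature": 0.01,
        "stop_token_ids": [151645],  # assumed Qwen end-of-turn token
        "logprobs": True,
        "top_logprobs": 10,
    },
)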
@@ -358,16 +368,17 @@ class ArchFunctionHandler(ArchBaseHandler):
                     is_valid, error_message = False, e
                     break

-            tool_calls.append(
-                {
-                    "id": f"call_{random.randint(1000, 10000)}",
-                    "type": "function",
-                    "function": {
-                        "name": tool_content["name"],
-                        "arguments": tool_content["arguments"],
-                    },
-                }
-            )
+            tool = {
+                "id": f"call_{random.randint(1000, 10000)}",
+                "type": "function",
+                "function": {
+                    "name": tool_content["name"],
+                },
+            }
+            if "arguments" in tool_content:
+                tool["function"]["arguments"] = tool_content["arguments"]
+            tool_calls.append(tool)

         flag = False
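
The net effect of this hunk, sketched with hypothetical model output: "arguments" is now treated as optional, so a tool call that supplies no arguments no longer raises a KeyError during extraction.

# Hypothetical extracted tool content with no "arguments" key; the old code
# indexed tool_content["arguments"] unconditionally and would fail here.
import random

tool_content = {"name": "list_devices"}
tool = {
    "id": f"call_{random.randint(1000, 10000)}",
    "type": "function",
    "function": {"name": tool_content["name"]},
}
if "arguments" in tool_content:
    tool["function"]["arguments"] = tool_content["arguments"]
print(tool)  # the "function" entry contains only "name"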
@@ -415,7 +426,9 @@ class ArchFunctionHandler(ArchBaseHandler):
                 break

             func_name = tool_call["function"]["name"]
-            func_args = tool_call["function"]["arguments"]
+            func_args = tool_call["function"].get("arguments")
+            if not func_args:
+                func_args = {}

             # Check whether the function is available or not
             if func_name not in functions:
@@ -430,12 +443,14 @@ class ArchFunctionHandler(ArchBaseHandler):
                 if required_param not in func_args:
                     is_valid = False
                     invalid_tool_call = tool_call
-                    error_message = f"`{required_param}` is requiried by the function `{func_name}` but not found in the tool call!"
+                    error_message = f"`{required_param}` is required by the function `{func_name}` but not found in the tool call!"
                     break

             # Verify the data type of each parameter in the tool calls
             function_properties = functions[func_name]["properties"]
+            logger.info("== func_args ==")
+            logger.info(func_args)
             for param_name in func_args:
                 if param_name not in function_properties:
                     is_valid = False
@@ -523,7 +538,9 @@ class ArchFunctionHandler(ArchBaseHandler):
             req.messages, req.tools, metadata=req.metadata
         )

-        logger.info(f"[request to arch-fc]: {json.dumps(messages)}")
+        logger.info(
+            f"[request to arch-fc]: model: {self.model_name}, extra_body: {self.generation_params}, body: {json.dumps(messages)}"
+        )

         # always enable `stream=True` to collect model responses
         response = self.client.chat.completions.create(
@@ -533,42 +550,48 @@ class ArchFunctionHandler(ArchBaseHandler):
             extra_body=self.generation_params,
         )

-        # initialize the hallucination handler, which is an iterator
-        self.hallucination_state = HallucinationState(
-            response_iterator=response, function=req.tools
-        )
-
-        has_tool_calls, has_hallucination = None, False
-        for _ in self.hallucination_state:
-            # check if the first token is <tool_call>
-            if len(self.hallucination_state.tokens) > 0 and has_tool_calls is None:
-                if self.hallucination_state.tokens[0] == "<tool_call>":
-                    has_tool_calls = True
-                else:
-                    has_tool_calls = False
-
-            # if the model is hallucinating, start parameter gathering
-            if self.hallucination_state.hallucination is True:
-                has_hallucination = True
-                break
-
-        if has_tool_calls:
-            if has_hallucination:
-                # start prompt prefilling if hallcuination is found in tool calls
-                logger.info(
-                    f"[Hallucination]: {self.hallucination_state.error_message}"
-                )
-                prefill_response = self._engage_parameter_gathering(messages)
-                model_response = prefill_response.choices[0].message.content
-            else:
-                model_response = "".join(self.hallucination_state.tokens)
-        else:
-            # start parameter gathering if the model is not generating tool calls
-            prefill_response = self._engage_parameter_gathering(messages)
-            model_response = prefill_response.choices[0].message.content
+        use_agent_orchestrator = req.metadata.get("use_agent_orchestrator", False)
+        model_response = ""
+        if use_agent_orchestrator:
+            for chunk in response:
+                if len(chunk.choices) > 0 and chunk.choices[0].delta.content:
+                    model_response += chunk.choices[0].delta.content
+            logger.info(f"[Agent Orchestrator]: response received: {model_response}")
+        else:
+            # initialize the hallucination handler, which is an iterator
+            self.hallucination_state = HallucinationState(
+                response_iterator=response, function=req.tools
+            )
+
+            has_tool_calls, has_hallucination = None, False
+            for _ in self.hallucination_state:
+                # check if the first token is <tool_call>
+                if len(self.hallucination_state.tokens) > 0 and has_tool_calls is None:
+                    if self.hallucination_state.tokens[0] == "<tool_call>":
+                        has_tool_calls = True
+                    else:
+                        has_tool_calls = False
+                        break
+
+                # if the model is hallucinating, start parameter gathering
+                if self.hallucination_state.hallucination is True:
+                    has_hallucination = True
+                    break
+
+            if has_tool_calls:
+                if has_hallucination:
+                    # start prompt prefilling if hallcuination is found in tool calls
+                    logger.info(
+                        f"[Hallucination]: {self.hallucination_state.error_message}"
+                    )
+                    prefill_response = self._engage_parameter_gathering(messages)
+                    model_response = prefill_response.choices[0].message.content
+                else:
+                    model_response = "".join(self.hallucination_state.tokens)
+            else:
+                # start parameter gathering if the model is not generating tool calls
+                prefill_response = self._engage_parameter_gathering(messages)
+                model_response = prefill_response.choices[0].message.content

         # Extract tool calls from model response
         extracted = self._extract_tool_calls(model_response)
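
A standalone sketch of the new agent-orchestrator branch, assuming an OpenAI-compatible streaming endpoint (base_url and model alias are placeholders): the stream is simply drained into one string, bypassing the HallucinationState machinery used on the function-calling path.

# Sketch of the use_agent_orchestrator path: accumulate streamed deltas.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
stream = client.chat.completions.create(
    model="Arch-Function",  # placeholder alias
    messages=[{"role": "user", "content": "hello"}],
    stream=True,
)

model_response = ""
for chunk in stream:
    # skip keep-alive chunks that carry no choices or an empty delta
    if len(chunk.choices) > 0 and chunk.choices[0].delta.content:
        model_response += chunk.choices[0].delta.content
print(model_response)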
@@ -576,9 +599,14 @@ class ArchFunctionHandler(ArchBaseHandler):
         if extracted["status"]:
             # Response with tool calls
             if len(extracted["result"]):
-                verified = self._verify_tool_calls(
-                    tools=req.tools, tool_calls=extracted["result"]
-                )
+                verified = {}
+                if use_agent_orchestrator:
+                    # skip tool call verification if using agent orchestrator
+                    verified = {"status": True, "message": ""}
+                else:
+                    verified = self._verify_tool_calls(
+                        tools=req.tools, tool_calls=extracted["result"]
+                    )

                 if verified["status"]:
                     logger.info(
@@ -601,3 +629,35 @@ class ArchFunctionHandler(ArchBaseHandler):
         logger.info(f"[response]: {json.dumps(chat_completion_response.model_dump())}")

         return chat_completion_response
+
+
+class ArchAgentHandler(ArchFunctionHandler):
+    def __init__(self, client: OpenAI, model_name: str, config: ArchAgentConfig):
+        super().__init__(client, model_name, config)
+
+    @override
+    def _convert_tools(self, tools: List[Dict[str, Any]]) -> str:
+        """
+        Converts a list of tools into JSON format.
+
+        Args:
+            tools (List[Dict[str, Any]]): A list of tools represented as dictionaries.
+
+        Returns:
+            str: A string representation of converted tools.
+        """
+        converted = []
+        # delete parameters key if its empty in tool
+        for tool in tools:
+            if (
+                "parameters" in tool["function"]
+                and "properties" in tool["function"]["parameters"]
+                and not tool["function"]["parameters"]["properties"]
+            ):
+                tool_copy = copy.deepcopy(tool)
+                del tool_copy["function"]["parameters"]
+                converted.append(json.dumps(tool_copy))
+            else:
+                converted.append(json.dumps(tool))
+
+        return "\n".join(converted)
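
To make the override concrete, a sketch with a hypothetical tool definition: a tool whose parameter schema has an empty properties dict is serialized without its parameters key, while any other tool passes through unchanged.

# Sketch of _convert_tools behavior on an empty-parameters tool (hypothetical).
import copy
import json

tools = [
    {
        "type": "function",
        "function": {
            "name": "list_devices",
            "parameters": {"type": "object", "properties": {}},
        },
    }
]

converted = []
for tool in tools:
    fn = tool["function"]
    if "parameters" in fn and "properties" in fn["parameters"] and not fn["parameters"]["properties"]:
        tool_copy = copy.deepcopy(tool)
        del tool_copy["function"]["parameters"]  # an empty schema carries no signal
        converted.append(json.dumps(tool_copy))
    else:
        converted.append(json.dumps(tool))

print("\n".join(converted))
# {"type": "function", "function": {"name": "list_devices"}}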

View file

@@ -5,6 +5,7 @@ import logging
 import src.commons.utils as utils
 from src.commons.globals import ARCH_ENDPOINT, handler_map
+from src.core.function_calling import ArchFunctionHandler
 from src.core.utils.model_utils import (
     ChatMessage,
     ChatCompletionResponse,
@@ -76,33 +77,51 @@ async def function_calling(req: ChatMessage, res: Response):
     error_messages = None

     try:
-        intent_start_time = time.perf_counter()
-        intent_response = await handler_map["Arch-Intent"].chat_completion(req)
-        intent_latency = time.perf_counter() - intent_start_time
+        intent_detected = False
+        use_agent_orchestrator = req.metadata.get("use_agent_orchestrator", False)
+        logger.info(f"Use agent orchestrator: {use_agent_orchestrator}")
+        if not use_agent_orchestrator:
+            intent_start_time = time.perf_counter()
+            intent_response = await handler_map["Arch-Intent"].chat_completion(req)
+            intent_latency = time.perf_counter() - intent_start_time
+            intent_detected = handler_map["Arch-Intent"].detect_intent(intent_response)

-        if handler_map["Arch-Intent"].detect_intent(intent_response):
+        if use_agent_orchestrator or intent_detected:
             # TODO: measure agreement between intent detection and function calling
             try:
                 function_start_time = time.perf_counter()
-                final_response = await handler_map["Arch-Function"].chat_completion(req)
+                handler_name = (
+                    "Arch-Agent" if use_agent_orchestrator else "Arch-Function"
+                )
+                function_calling_handler: ArchFunctionHandler = handler_map[
+                    handler_name
+                ]
+                final_response = await function_calling_handler.chat_completion(req)
                 function_latency = time.perf_counter() - function_start_time

                 final_response.metadata = {
-                    "intent_latency": str(round(intent_latency * 1000, 3)),
                     "function_latency": str(round(function_latency * 1000, 3)),
-                    "hallucination": str(
-                        handler_map["Arch-Function"].hallucination_state.hallucination
-                    ),
                 }
+                if not use_agent_orchestrator:
+                    final_response.metadata["intent_latency"] = str(
+                        round(intent_latency * 1000, 3)
+                    )
+                    final_response.metadata["hallucination"] = str(
+                        function_calling_handler.hallucination_state.hallucination
+                    )
             except ValueError as e:
                 res.statuscode = 503
-                error_messages = f"[Arch-Function] - Error in tool call extraction: {e}"
+                error_messages = (
+                    f"[{handler_name}] - Error in tool call extraction: {e}"
+                )
             except StopIteration as e:
                 res.statuscode = 500
-                error_messages = f"[Arch-Function] - Error in hallucination check: {e}"
+                error_messages = f"[{handler_name}] - Error in hallucination check: {e}"
             except Exception as e:
                 res.status_code = 500
-                error_messages = f"[Arch-Function] - Error in ChatCompletion: {e}"
+                error_messages = f"[{handler_name}] - Error in ChatCompletion: {e}"
+                raise
         else:
             # no intent matched
             intent_response.metadata = {
@@ -113,6 +132,7 @@ async def function_calling(req: ChatMessage, res: Response):
     except Exception as e:
         res.status_code = 500
         error_messages = f"[Arch-Intent] - Error in ChatCompletion: {e}"
+        raise

     if error_messages is not None:
         logger.error(error_messages)
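
Finally, a hedged sketch of how a caller might opt into the new agent path, assuming only what the route above shows: use_agent_orchestrator is read from the request's metadata. The URL, port, and the rest of the payload shape are assumptions for illustration, not part of this commit.

# Hypothetical HTTP call into the function_calling route.
import requests

payload = {
    "messages": [{"role": "user", "content": "What is the weather in Seattle?"}],
    "tools": [
        {
            "type": "function",
            "function": {
                "name": "get_weather",
                "parameters": {
                    "type": "object",
                    "properties": {"city": {"type": "str"}},
                    "required": ["city"],
                },
            },
        }
    ],
    # True skips Arch-Intent detection and routes to the Arch-Agent handler
    "metadata": {"use_agent_orchestrator": True},
}

response = requests.post("http://localhost:51000/function_calling", json=payload)
print(response.json())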