mirror of
https://github.com/katanemo/plano.git
synced 2026-04-28 10:26:36 +02:00
add preliminary support for llm agents (#432)
This commit is contained in:
parent
8d66fefded
commit
84cd1df7bf
29 changed files with 1388 additions and 121 deletions
|
|
@ -3,6 +3,8 @@ from openai import OpenAI
|
|||
from src.commons.utils import get_model_server_logger
|
||||
from src.core.guardrails import get_guardrail_handler
|
||||
from src.core.function_calling import (
|
||||
ArchAgentConfig,
|
||||
ArchAgentHandler,
|
||||
ArchIntentConfig,
|
||||
ArchIntentHandler,
|
||||
ArchFunctionConfig,
|
||||
|
|
@ -18,10 +20,12 @@ logger = get_model_server_logger()
|
|||
ARCH_ENDPOINT = os.getenv("ARCH_ENDPOINT", "https://archfc.katanemo.dev/v1")
|
||||
ARCH_API_KEY = "EMPTY"
|
||||
ARCH_CLIENT = OpenAI(base_url=ARCH_ENDPOINT, api_key=ARCH_API_KEY)
|
||||
ARCH_AGENT_CLIENT = ARCH_CLIENT
|
||||
|
||||
# Define model names
|
||||
ARCH_INTENT_MODEL_ALIAS = "Arch-Intent"
|
||||
ARCH_FUNCTION_MODEL_ALIAS = "Arch-Function"
|
||||
ARCH_AGENT_MODEL_ALIAS = ARCH_FUNCTION_MODEL_ALIAS
|
||||
ARCH_GUARD_MODEL_ALIAS = "katanemo/Arch-Guard"
|
||||
|
||||
# Define model handlers
|
||||
|
|
@ -32,5 +36,8 @@ handler_map = {
|
|||
"Arch-Function": ArchFunctionHandler(
|
||||
ARCH_CLIENT, ARCH_FUNCTION_MODEL_ALIAS, ArchFunctionConfig
|
||||
),
|
||||
"Arch-Agent": ArchAgentHandler(
|
||||
ARCH_AGENT_CLIENT, ARCH_AGENT_MODEL_ALIAS, ArchAgentConfig
|
||||
),
|
||||
"Arch-Guard": get_guardrail_handler(ARCH_GUARD_MODEL_ALIAS),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import ast
|
||||
import copy
|
||||
import json
|
||||
import random
|
||||
import builtins
|
||||
|
|
@ -221,6 +222,15 @@ class ArchFunctionConfig:
|
|||
SUPPORT_DATA_TYPES = ["int", "float", "bool", "str", "list", "tuple", "set", "dict"]
|
||||
|
||||
|
||||
class ArchAgentConfig(ArchFunctionConfig):
|
||||
GENERATION_PARAMS = {
|
||||
"temperature": 0.01,
|
||||
"stop_token_ids": [151645],
|
||||
"logprobs": True,
|
||||
"top_logprobs": 10,
|
||||
}
|
||||
|
||||
|
||||
class ArchFunctionHandler(ArchBaseHandler):
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -358,16 +368,17 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
is_valid, error_message = False, e
|
||||
break
|
||||
|
||||
tool_calls.append(
|
||||
{
|
||||
"id": f"call_{random.randint(1000, 10000)}",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool_content["name"],
|
||||
"arguments": tool_content["arguments"],
|
||||
},
|
||||
}
|
||||
)
|
||||
tool = {
|
||||
"id": f"call_{random.randint(1000, 10000)}",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool_content["name"],
|
||||
},
|
||||
}
|
||||
if "arguments" in tool_content:
|
||||
tool["function"]["arguments"] = tool_content["arguments"]
|
||||
|
||||
tool_calls.append(tool)
|
||||
|
||||
flag = False
|
||||
|
||||
|
|
@ -415,7 +426,9 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
break
|
||||
|
||||
func_name = tool_call["function"]["name"]
|
||||
func_args = tool_call["function"]["arguments"]
|
||||
func_args = tool_call["function"].get("arguments")
|
||||
if not func_args:
|
||||
func_args = {}
|
||||
|
||||
# Check whether the function is available or not
|
||||
if func_name not in functions:
|
||||
|
|
@ -430,12 +443,14 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
if required_param not in func_args:
|
||||
is_valid = False
|
||||
invalid_tool_call = tool_call
|
||||
error_message = f"`{required_param}` is requiried by the function `{func_name}` but not found in the tool call!"
|
||||
error_message = f"`{required_param}` is required by the function `{func_name}` but not found in the tool call!"
|
||||
break
|
||||
|
||||
# Verify the data type of each parameter in the tool calls
|
||||
function_properties = functions[func_name]["properties"]
|
||||
|
||||
logger.info("== func_args ==")
|
||||
logger.info(func_args)
|
||||
for param_name in func_args:
|
||||
if param_name not in function_properties:
|
||||
is_valid = False
|
||||
|
|
@ -523,7 +538,9 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
req.messages, req.tools, metadata=req.metadata
|
||||
)
|
||||
|
||||
logger.info(f"[request to arch-fc]: {json.dumps(messages)}")
|
||||
logger.info(
|
||||
f"[request to arch-fc]: model: {self.model_name}, extra_body: {self.generation_params}, body: {json.dumps(messages)}"
|
||||
)
|
||||
|
||||
# always enable `stream=True` to collect model responses
|
||||
response = self.client.chat.completions.create(
|
||||
|
|
@ -533,42 +550,48 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
extra_body=self.generation_params,
|
||||
)
|
||||
|
||||
# initialize the hallucination handler, which is an iterator
|
||||
self.hallucination_state = HallucinationState(
|
||||
response_iterator=response, function=req.tools
|
||||
)
|
||||
|
||||
use_agent_orchestrator = req.metadata.get("use_agent_orchestrator", False)
|
||||
model_response = ""
|
||||
if use_agent_orchestrator:
|
||||
for chunk in response:
|
||||
if len(chunk.choices) > 0 and chunk.choices[0].delta.content:
|
||||
model_response += chunk.choices[0].delta.content
|
||||
logger.info(f"[Agent Orchestrator]: response received: {model_response}")
|
||||
else:
|
||||
# initialize the hallucination handler, which is an iterator
|
||||
self.hallucination_state = HallucinationState(
|
||||
response_iterator=response, function=req.tools
|
||||
)
|
||||
|
||||
has_tool_calls, has_hallucination = None, False
|
||||
for _ in self.hallucination_state:
|
||||
# check if the first token is <tool_call>
|
||||
if len(self.hallucination_state.tokens) > 0 and has_tool_calls is None:
|
||||
if self.hallucination_state.tokens[0] == "<tool_call>":
|
||||
has_tool_calls = True
|
||||
else:
|
||||
has_tool_calls = False
|
||||
has_tool_calls, has_hallucination = None, False
|
||||
for _ in self.hallucination_state:
|
||||
# check if the first token is <tool_call>
|
||||
if len(self.hallucination_state.tokens) > 0 and has_tool_calls is None:
|
||||
if self.hallucination_state.tokens[0] == "<tool_call>":
|
||||
has_tool_calls = True
|
||||
else:
|
||||
has_tool_calls = False
|
||||
break
|
||||
|
||||
# if the model is hallucinating, start parameter gathering
|
||||
if self.hallucination_state.hallucination is True:
|
||||
has_hallucination = True
|
||||
break
|
||||
|
||||
# if the model is hallucinating, start parameter gathering
|
||||
if self.hallucination_state.hallucination is True:
|
||||
has_hallucination = True
|
||||
break
|
||||
|
||||
if has_tool_calls:
|
||||
if has_hallucination:
|
||||
# start prompt prefilling if hallucination is found in tool calls
|
||||
logger.info(
|
||||
f"[Hallucination]: {self.hallucination_state.error_message}"
|
||||
)
|
||||
if has_tool_calls:
|
||||
if has_hallucination:
|
||||
# start prompt prefilling if hallucination is found in tool calls
|
||||
logger.info(
|
||||
f"[Hallucination]: {self.hallucination_state.error_message}"
|
||||
)
|
||||
prefill_response = self._engage_parameter_gathering(messages)
|
||||
model_response = prefill_response.choices[0].message.content
|
||||
else:
|
||||
model_response = "".join(self.hallucination_state.tokens)
|
||||
else:
|
||||
# start parameter gathering if the model is not generating tool calls
|
||||
prefill_response = self._engage_parameter_gathering(messages)
|
||||
model_response = prefill_response.choices[0].message.content
|
||||
else:
|
||||
model_response = "".join(self.hallucination_state.tokens)
|
||||
else:
|
||||
# start parameter gathering if the model is not generating tool calls
|
||||
prefill_response = self._engage_parameter_gathering(messages)
|
||||
model_response = prefill_response.choices[0].message.content
|
||||
|
||||
# Extract tool calls from model response
|
||||
extracted = self._extract_tool_calls(model_response)
|
||||
|
|
@ -576,9 +599,14 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
if extracted["status"]:
|
||||
# Response with tool calls
|
||||
if len(extracted["result"]):
|
||||
verified = self._verify_tool_calls(
|
||||
tools=req.tools, tool_calls=extracted["result"]
|
||||
)
|
||||
verified = {}
|
||||
if use_agent_orchestrator:
|
||||
# skip tool call verification if using agent orchestrator
|
||||
verified = {"status": True, "message": ""}
|
||||
else:
|
||||
verified = self._verify_tool_calls(
|
||||
tools=req.tools, tool_calls=extracted["result"]
|
||||
)
|
||||
|
||||
if verified["status"]:
|
||||
logger.info(
|
||||
|
|
@ -601,3 +629,35 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
logger.info(f"[response]: {json.dumps(chat_completion_response.model_dump())}")
|
||||
|
||||
return chat_completion_response
|
||||
|
||||
|
||||
class ArchAgentHandler(ArchFunctionHandler):
|
||||
def __init__(self, client: OpenAI, model_name: str, config: ArchAgentConfig):
|
||||
super().__init__(client, model_name, config)
|
||||
|
||||
@override
|
||||
def _convert_tools(self, tools: List[Dict[str, Any]]) -> str:
|
||||
"""
|
||||
Converts a list of tools into JSON format.
|
||||
|
||||
Args:
|
||||
tools (List[Dict[str, Any]]): A list of tools represented as dictionaries.
|
||||
|
||||
Returns:
|
||||
str: A string representation of converted tools.
|
||||
"""
|
||||
|
||||
converted = []
|
||||
# delete the "parameters" key if it's empty in the tool
|
||||
for tool in tools:
|
||||
if (
|
||||
"parameters" in tool["function"]
|
||||
and "properties" in tool["function"]["parameters"]
|
||||
and not tool["function"]["parameters"]["properties"]
|
||||
):
|
||||
tool_copy = copy.deepcopy(tool)
|
||||
del tool_copy["function"]["parameters"]
|
||||
converted.append(json.dumps(tool_copy))
|
||||
else:
|
||||
converted.append(json.dumps(tool))
|
||||
return "\n".join(converted)
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ import logging
|
|||
import src.commons.utils as utils
|
||||
|
||||
from src.commons.globals import ARCH_ENDPOINT, handler_map
|
||||
from src.core.function_calling import ArchFunctionHandler
|
||||
from src.core.utils.model_utils import (
|
||||
ChatMessage,
|
||||
ChatCompletionResponse,
|
||||
|
|
@ -76,33 +77,51 @@ async def function_calling(req: ChatMessage, res: Response):
|
|||
error_messages = None
|
||||
|
||||
try:
|
||||
intent_start_time = time.perf_counter()
|
||||
intent_response = await handler_map["Arch-Intent"].chat_completion(req)
|
||||
intent_latency = time.perf_counter() - intent_start_time
|
||||
intent_detected = False
|
||||
use_agent_orchestrator = req.metadata.get("use_agent_orchestrator", False)
|
||||
logger.info(f"Use agent orchestrator: {use_agent_orchestrator}")
|
||||
if not use_agent_orchestrator:
|
||||
intent_start_time = time.perf_counter()
|
||||
intent_response = await handler_map["Arch-Intent"].chat_completion(req)
|
||||
intent_latency = time.perf_counter() - intent_start_time
|
||||
intent_detected = handler_map["Arch-Intent"].detect_intent(intent_response)
|
||||
|
||||
if handler_map["Arch-Intent"].detect_intent(intent_response):
|
||||
if use_agent_orchestrator or intent_detected:
|
||||
# TODO: measure agreement between intent detection and function calling
|
||||
try:
|
||||
function_start_time = time.perf_counter()
|
||||
final_response = await handler_map["Arch-Function"].chat_completion(req)
|
||||
handler_name = (
|
||||
"Arch-Agent" if use_agent_orchestrator else "Arch-Function"
|
||||
)
|
||||
function_calling_handler: ArchFunctionHandler = handler_map[
|
||||
handler_name
|
||||
]
|
||||
final_response = await function_calling_handler.chat_completion(req)
|
||||
function_latency = time.perf_counter() - function_start_time
|
||||
|
||||
final_response.metadata = {
|
||||
"intent_latency": str(round(intent_latency * 1000, 3)),
|
||||
"function_latency": str(round(function_latency * 1000, 3)),
|
||||
"hallucination": str(
|
||||
handler_map["Arch-Function"].hallucination_state.hallucination
|
||||
),
|
||||
}
|
||||
|
||||
if not use_agent_orchestrator:
|
||||
final_response.metadata["intent_latency"] = str(
|
||||
round(intent_latency * 1000, 3)
|
||||
)
|
||||
final_response.metadata["hallucination"] = str(
|
||||
function_calling_handler.hallucination_state.hallucination
|
||||
)
|
||||
except ValueError as e:
|
||||
res.statuscode = 503
|
||||
error_messages = f"[Arch-Function] - Error in tool call extraction: {e}"
|
||||
error_messages = (
|
||||
f"[{handler_name}] - Error in tool call extraction: {e}"
|
||||
)
|
||||
except StopIteration as e:
|
||||
res.statuscode = 500
|
||||
error_messages = f"[Arch-Function] - Error in hallucination check: {e}"
|
||||
error_messages = f"[{handler_name}] - Error in hallucination check: {e}"
|
||||
except Exception as e:
|
||||
res.status_code = 500
|
||||
error_messages = f"[Arch-Function] - Error in ChatCompletion: {e}"
|
||||
error_messages = f"[{handler_name}] - Error in ChatCompletion: {e}"
|
||||
raise
|
||||
else:
|
||||
# no intent matched
|
||||
intent_response.metadata = {
|
||||
|
|
@ -113,6 +132,7 @@ async def function_calling(req: ChatMessage, res: Response):
|
|||
except Exception as e:
|
||||
res.status_code = 500
|
||||
error_messages = f"[Arch-Intent] - Error in ChatCompletion: {e}"
|
||||
raise
|
||||
|
||||
if error_messages is not None:
|
||||
logger.error(error_messages)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue