mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
add preliminary support for llm agents
This commit is contained in:
parent
ffb8566c36
commit
8104eac596
17 changed files with 1508 additions and 79 deletions
|
|
@ -3,6 +3,8 @@ from openai import OpenAI
|
|||
from src.commons.utils import get_model_server_logger
|
||||
from src.core.guardrails import get_guardrail_handler
|
||||
from src.core.function_calling import (
|
||||
ArchAgentConfig,
|
||||
ArchAgentHandler,
|
||||
ArchIntentConfig,
|
||||
ArchIntentHandler,
|
||||
ArchFunctionConfig,
|
||||
|
|
@ -18,10 +20,14 @@ logger = get_model_server_logger()
|
|||
ARCH_ENDPOINT = os.getenv("ARCH_ENDPOINT", "https://archfc.katanemo.dev/v1")
|
||||
ARCH_API_KEY = "EMPTY"
|
||||
ARCH_CLIENT = OpenAI(base_url=ARCH_ENDPOINT, api_key=ARCH_API_KEY)
|
||||
ARCH_AGENT_CLIENT = ARCH_CLIENT
|
||||
# ARCH_AGENT_CLIENT = OpenAI(api_key=os.getenv("OPENAI_API_KEY", "EMPTY"))
|
||||
|
||||
# Define model names
|
||||
ARCH_INTENT_MODEL_ALIAS = "Arch-Intent"
|
||||
ARCH_FUNCTION_MODEL_ALIAS = "Arch-Function"
|
||||
ARCH_AGENT_MODEL_ALIAS = ARCH_FUNCTION_MODEL_ALIAS
|
||||
# ARCH_AGENT_MODEL_ALIAS = "gpt-4o-mini"
|
||||
ARCH_GUARD_MODEL_ALIAS = "katanemo/Arch-Guard"
|
||||
|
||||
# Define model handlers
|
||||
|
|
@ -32,5 +38,8 @@ handler_map = {
|
|||
"Arch-Function": ArchFunctionHandler(
|
||||
ARCH_CLIENT, ARCH_FUNCTION_MODEL_ALIAS, ArchFunctionConfig
|
||||
),
|
||||
"Arch-Agent": ArchAgentHandler(
|
||||
ARCH_AGENT_CLIENT, ARCH_AGENT_MODEL_ALIAS, ArchAgentConfig
|
||||
),
|
||||
"Arch-Guard": get_guardrail_handler(ARCH_GUARD_MODEL_ALIAS),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import ast
|
||||
import copy
|
||||
import json
|
||||
import random
|
||||
import builtins
|
||||
|
|
@ -221,6 +222,41 @@ class ArchFunctionConfig:
|
|||
SUPPORT_DATA_TYPES = ["int", "float", "bool", "str", "list", "tuple", "set", "dict"]
|
||||
|
||||
|
||||
class ArchAgentConfig(ArchFunctionConfig):
    """Prompt strings and generation parameters for the agent handler.

    Extends ``ArchFunctionConfig`` with the class attributes declared below.
    Every prompt is a plain ``str`` with no leading or trailing whitespace.
    """

    # Instruction asking the model to pick the matching tool(s) and nothing else.
    TASK_PROMPT = (
        "You will be given a list of tools and a user request. "
        "Your task is to match the user request with the most appropriate "
        "tool(s) based on the tool descriptions. "
        "Do not explain your reasoning, just provide the tool(s) that best "
        "match the user request."
    )

    # Template with a single `{tool_text}` placeholder wrapped in <tools> tags.
    # NOTE(review): presumably filled with the serialized tool list — confirm
    # against the code that formats this template.
    TOOL_PROMPT_TEMPLATE = "\n".join(
        (
            "You will be presented with a list of tools and their descriptions:",
            "",
            "<tools>",
            "{tool_text}",
            "</tools>",
        )
    )

    # Output-format instruction: one <tool_call> element per selected function,
    # carrying only the function name.
    FORMAT_PROMPT = "\n".join(
        (
            "Provide your answer in the following format:",
            "For each function call, return a json object with function name <tool_call></tool_call> XML tags:",
            "<tool_call>",
            '{"name": <function-name>}',
            "</tool_call>",
        )
    )

    # Sampling / decoding options.
    # NOTE(review): 151645 is presumably the serving model's end-of-turn token
    # id — confirm against the tokenizer in use.
    GENERATION_PARAMS = dict(
        temperature=0.01,
        stop_token_ids=[151645],
        logprobs=True,
        top_logprobs=10,
    )
|
||||
|
||||
|
||||
class ArchFunctionHandler(ArchBaseHandler):
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -358,16 +394,17 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
is_valid, error_message = False, e
|
||||
break
|
||||
|
||||
tool_calls.append(
|
||||
{
|
||||
"id": f"call_{random.randint(1000, 10000)}",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool_content["name"],
|
||||
"arguments": tool_content["arguments"],
|
||||
},
|
||||
}
|
||||
)
|
||||
tool = {
|
||||
"id": f"call_{random.randint(1000, 10000)}",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool_content["name"],
|
||||
},
|
||||
}
|
||||
if "arguments" in tool_content:
|
||||
tool["function"]["arguments"] = tool_content["arguments"]
|
||||
|
||||
tool_calls.append(tool)
|
||||
|
||||
flag = False
|
||||
|
||||
|
|
@ -415,7 +452,9 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
break
|
||||
|
||||
func_name = tool_call["function"]["name"]
|
||||
func_args = tool_call["function"]["arguments"]
|
||||
func_args = tool_call["function"].get("arguments")
|
||||
if not func_args:
|
||||
func_args = {}
|
||||
|
||||
# Check whether the function is available or not
|
||||
if func_name not in functions:
|
||||
|
|
@ -430,12 +469,14 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
if required_param not in func_args:
|
||||
is_valid = False
|
||||
invalid_tool_call = tool_call
|
||||
error_message = f"`{required_param}` is requiried by the function `{func_name}` but not found in the tool call!"
|
||||
error_message = f"`{required_param}` is required by the function `{func_name}` but not found in the tool call!"
|
||||
break
|
||||
|
||||
# Verify the data type of each parameter in the tool calls
|
||||
function_properties = functions[func_name]["properties"]
|
||||
|
||||
logger.info("== func_args ==")
|
||||
logger.info(func_args)
|
||||
for param_name in func_args:
|
||||
if param_name not in function_properties:
|
||||
is_valid = False
|
||||
|
|
@ -523,52 +564,60 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
req.messages, req.tools, metadata=req.metadata
|
||||
)
|
||||
|
||||
logger.info(f"[request to arch-fc]: {json.dumps(messages)}")
|
||||
logger.info(
|
||||
f"[request to arch-fc]: model: {self.model_name}, extra_body: {self.generation_params}, body: {json.dumps(messages)}"
|
||||
)
|
||||
|
||||
# always enable `stream=True` to collect model responses
|
||||
response = self.client.chat.completions.create(
|
||||
messages=messages,
|
||||
model=self.model_name,
|
||||
stream=True,
|
||||
extra_body=self.generation_params,
|
||||
)
|
||||
|
||||
# initialize the hallucination handler, which is an iterator
|
||||
self.hallucination_state = HallucinationState(
|
||||
response_iterator=response, function=req.tools
|
||||
extra_body={"temperature": 0.01, "logprobs": True},
|
||||
)
|
||||
|
||||
use_agent_orchestrator = req.metadata.get("use_agent_orchestrator", False)
|
||||
model_response = ""
|
||||
if use_agent_orchestrator:
|
||||
for chunk in response:
|
||||
if len(chunk.choices) > 0 and chunk.choices[0].delta.content:
|
||||
model_response += chunk.choices[0].delta.content
|
||||
logger.info(f"[Agent Orchestrator]: response received: {model_response}")
|
||||
else:
|
||||
# initialize the hallucination handler, which is an iterator
|
||||
self.hallucination_state = HallucinationState(
|
||||
response_iterator=response, function=req.tools
|
||||
)
|
||||
|
||||
has_tool_calls, has_hallucination = None, False
|
||||
for _ in self.hallucination_state:
|
||||
# check if the first token is <tool_call>
|
||||
if len(self.hallucination_state.tokens) > 0 and has_tool_calls is None:
|
||||
if self.hallucination_state.tokens[0] == "<tool_call>":
|
||||
has_tool_calls = True
|
||||
else:
|
||||
has_tool_calls = False
|
||||
has_tool_calls, has_hallucination = None, False
|
||||
for _ in self.hallucination_state:
|
||||
# check if the first token is <tool_call>
|
||||
if len(self.hallucination_state.tokens) > 0 and has_tool_calls is None:
|
||||
if self.hallucination_state.tokens[0] == "<tool_call>":
|
||||
has_tool_calls = True
|
||||
else:
|
||||
has_tool_calls = False
|
||||
break
|
||||
|
||||
# if the model is hallucinating, start parameter gathering
|
||||
if self.hallucination_state.hallucination is True:
|
||||
has_hallucination = True
|
||||
break
|
||||
|
||||
# if the model is hallucinating, start parameter gathering
|
||||
if self.hallucination_state.hallucination is True:
|
||||
has_hallucination = True
|
||||
break
|
||||
|
||||
if has_tool_calls:
|
||||
if has_hallucination:
|
||||
# start prompt prefilling if hallcuination is found in tool calls
|
||||
logger.info(
|
||||
f"[Hallucination]: {self.hallucination_state.error_message}"
|
||||
)
|
||||
if has_tool_calls:
|
||||
if has_hallucination:
|
||||
# start prompt prefilling if hallcuination is found in tool calls
|
||||
logger.info(
|
||||
f"[Hallucination]: {self.hallucination_state.error_message}"
|
||||
)
|
||||
prefill_response = self._engage_parameter_gathering(messages)
|
||||
model_response = prefill_response.choices[0].message.content
|
||||
else:
|
||||
model_response = "".join(self.hallucination_state.tokens)
|
||||
else:
|
||||
# start parameter gathering if the model is not generating tool calls
|
||||
prefill_response = self._engage_parameter_gathering(messages)
|
||||
model_response = prefill_response.choices[0].message.content
|
||||
else:
|
||||
model_response = "".join(self.hallucination_state.tokens)
|
||||
else:
|
||||
# start parameter gathering if the model is not generating tool calls
|
||||
prefill_response = self._engage_parameter_gathering(messages)
|
||||
model_response = prefill_response.choices[0].message.content
|
||||
|
||||
# Extract tool calls from model response
|
||||
extracted = self._extract_tool_calls(model_response)
|
||||
|
|
@ -601,3 +650,36 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
logger.info(f"[response]: {json.dumps(chat_completion_response.model_dump())}")
|
||||
|
||||
return chat_completion_response
|
||||
|
||||
|
||||
# override ArchFunctionHandler
|
||||
class ArchAgentHandler(ArchFunctionHandler):
    """Function-calling handler variant that customizes tool serialization.

    Construction is delegated unchanged to ``ArchFunctionHandler``; the only
    behavior defined here is ``_convert_tools``.
    """

    def __init__(self, client: OpenAI, model_name: str, config: ArchAgentConfig):
        """Forward the client, model name and config to the parent handler."""
        super().__init__(client, model_name, config)

    @override
    def _convert_tools(self, tools: List[Dict[str, Any]]) -> str:
        """
        Serialize tools to newline-separated JSON, one tool per line.

        A tool whose ``function.parameters.properties`` is present but empty
        is emitted without its ``parameters`` key; all other tools are
        emitted as-is. The input list and its dictionaries are not modified.

        Args:
            tools (List[Dict[str, Any]]): Tool definitions; each must contain
                a ``"function"`` entry.

        Returns:
            str: The JSON documents joined with ``"\\n"`` (empty string when
            ``tools`` is empty).
        """

        serialized = []
        for tool in tools:
            function_spec = tool["function"]
            has_empty_properties = (
                "parameters" in function_spec
                and "properties" in function_spec["parameters"]
                and not function_spec["parameters"]["properties"]
            )

            if not has_empty_properties:
                serialized.append(json.dumps(tool))
                continue

            # Strip the parameters on a deep copy so the caller's tool
            # definition stays intact.
            trimmed = copy.deepcopy(tool)
            del trimmed["function"]["parameters"]
            serialized.append(json.dumps(trimmed))

        return "\n".join(serialized)
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ import logging
|
|||
import src.commons.utils as utils
|
||||
|
||||
from src.commons.globals import ARCH_ENDPOINT, handler_map
|
||||
from src.core.function_calling import ArchFunctionHandler
|
||||
from src.core.utils.model_utils import (
|
||||
ChatMessage,
|
||||
ChatCompletionResponse,
|
||||
|
|
@ -76,33 +77,51 @@ async def function_calling(req: ChatMessage, res: Response):
|
|||
error_messages = None
|
||||
|
||||
try:
|
||||
intent_start_time = time.perf_counter()
|
||||
intent_response = await handler_map["Arch-Intent"].chat_completion(req)
|
||||
intent_latency = time.perf_counter() - intent_start_time
|
||||
intent_detected = False
|
||||
use_agent_orchestrator = req.metadata.get("use_agent_orchestrator", False)
|
||||
logger.info(f"Use agent orchestrator: {use_agent_orchestrator}")
|
||||
if not use_agent_orchestrator:
|
||||
intent_start_time = time.perf_counter()
|
||||
intent_response = await handler_map["Arch-Intent"].chat_completion(req)
|
||||
intent_latency = time.perf_counter() - intent_start_time
|
||||
intent_detected = handler_map["Arch-Intent"].detect_intent(intent_response)
|
||||
|
||||
if handler_map["Arch-Intent"].detect_intent(intent_response):
|
||||
if use_agent_orchestrator or intent_detected:
|
||||
# TODO: measure agreement between intent detection and function calling
|
||||
try:
|
||||
function_start_time = time.perf_counter()
|
||||
final_response = await handler_map["Arch-Function"].chat_completion(req)
|
||||
handler_name = (
|
||||
"Arch-Agent" if use_agent_orchestrator else "Arch-Function"
|
||||
)
|
||||
function_calling_handler: ArchFunctionHandler = handler_map[
|
||||
handler_name
|
||||
]
|
||||
final_response = await function_calling_handler.chat_completion(req)
|
||||
function_latency = time.perf_counter() - function_start_time
|
||||
|
||||
final_response.metadata = {
|
||||
"intent_latency": str(round(intent_latency * 1000, 3)),
|
||||
"function_latency": str(round(function_latency * 1000, 3)),
|
||||
"hallucination": str(
|
||||
handler_map["Arch-Function"].hallucination_state.hallucination
|
||||
),
|
||||
}
|
||||
|
||||
if not use_agent_orchestrator:
|
||||
final_response.metadata["intent_latency"] = str(
|
||||
round(intent_latency * 1000, 3)
|
||||
)
|
||||
final_response.metadata["hallucination"] = str(
|
||||
function_calling_handler.hallucination_state.hallucination
|
||||
)
|
||||
except ValueError as e:
|
||||
res.statuscode = 503
|
||||
error_messages = f"[Arch-Function] - Error in tool call extraction: {e}"
|
||||
error_messages = (
|
||||
f"[{handler_name}] - Error in tool call extraction: {e}"
|
||||
)
|
||||
except StopIteration as e:
|
||||
res.statuscode = 500
|
||||
error_messages = f"[Arch-Function] - Error in hallucination check: {e}"
|
||||
error_messages = f"[{handler_name}] - Error in hallucination check: {e}"
|
||||
except Exception as e:
|
||||
res.status_code = 500
|
||||
error_messages = f"[Arch-Function] - Error in ChatCompletion: {e}"
|
||||
error_messages = f"[{handler_name}] - Error in ChatCompletion: {e}"
|
||||
raise
|
||||
else:
|
||||
# no intent matched
|
||||
intent_response.metadata = {
|
||||
|
|
@ -113,6 +132,7 @@ async def function_calling(req: ChatMessage, res: Response):
|
|||
except Exception as e:
|
||||
res.status_code = 500
|
||||
error_messages = f"[Arch-Intent] - Error in ChatCompletion: {e}"
|
||||
raise
|
||||
|
||||
if error_messages is not None:
|
||||
logger.error(error_messages)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue