mirror of
https://github.com/katanemo/plano.git
synced 2026-06-20 15:28:07 +02:00
Merge branch 'shuguang/main' of https://github.com/katanemo/arch into shuguang/main
This commit is contained in:
commit
2405fb36e3
17 changed files with 311 additions and 1243 deletions
|
|
@ -43,7 +43,11 @@ class ArchIntentConfig:
|
|||
|
||||
EXTRA_INSTRUCTION = "Are there any tools can help?"
|
||||
|
||||
GENERATION_PARAMS = {"max_tokens": 1, "stop_token_ids": [151645]}
|
||||
GENERATION_PARAMS = {
|
||||
"temperature": 0.01,
|
||||
"max_tokens": 1,
|
||||
"stop_token_ids": [151645],
|
||||
}
|
||||
|
||||
|
||||
class ArchIntentHandler(ArchBaseHandler):
|
||||
|
|
@ -318,6 +322,9 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
|
||||
flag = False
|
||||
for line in content.split("\n"):
|
||||
if not is_valid:
|
||||
break
|
||||
|
||||
if "<tool_call>" == line:
|
||||
flag = True
|
||||
elif "</tool_call>" == line:
|
||||
|
|
@ -332,7 +339,7 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
tool_content = json.loads(fixed_content)
|
||||
except Exception:
|
||||
tool_calls, is_valid, error_message = [], False, e
|
||||
return tool_calls, is_valid, error_message
|
||||
break
|
||||
|
||||
tool_calls.append(
|
||||
{
|
||||
|
|
@ -347,7 +354,7 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
|
||||
flag = False
|
||||
|
||||
return {"result": tool_calls, "status": is_valid, "message": "error_message"}
|
||||
return {"result": tool_calls, "status": is_valid, "message": error_message}
|
||||
|
||||
def _verify_tool_calls(
|
||||
self, tools: List[Dict[str, Any]], tool_calls: List[Dict[str, Any]]
|
||||
|
|
@ -374,16 +381,19 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
functions[tool["function"]["name"]] = tool["function"]["parameters"]
|
||||
|
||||
for tool_call in tool_calls:
|
||||
func_name, func_args = (
|
||||
tool_call["function"]["name"],
|
||||
tool_call["function"]["arguments"],
|
||||
)
|
||||
if not is_valid:
|
||||
break
|
||||
|
||||
func_name = tool_call["function"]["name"]
|
||||
func_args = tool_call["function"]["arguments"]
|
||||
|
||||
# Check whether the function is available or not
|
||||
if func_name not in functions:
|
||||
is_valid = False
|
||||
invalid_tool_call = tool_call
|
||||
error_message = f"{func_name} is not defined!"
|
||||
return is_valid, error_message
|
||||
break
|
||||
|
||||
else:
|
||||
# Check if all the required parameters can be found in the tool calls
|
||||
for required_param in functions[func_name].get("required", []):
|
||||
|
|
@ -391,7 +401,7 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
is_valid = False
|
||||
invalid_tool_call = tool_call
|
||||
error_message = f"`{required_param}` is requiried by the function `{func_name}` but not found in the tool call!"
|
||||
return is_valid, invalid_tool_call, error_message
|
||||
break
|
||||
|
||||
# Verify the data type of each parameter in the tool calls
|
||||
for param_name in func_args:
|
||||
|
|
@ -405,7 +415,7 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
is_valid = False
|
||||
invalid_tool_call = tool_call
|
||||
error_message = f"Parameter `{param_name}` is expected to have the data type `{self.support_data_types[data_type]}`, but got `{type(param_value)}`."
|
||||
return is_valid, invalid_tool_call, error_message
|
||||
break
|
||||
|
||||
return {
|
||||
"status": is_valid,
|
||||
|
|
|
|||
|
|
@ -30,6 +30,7 @@ class ChatCompletionResponse(BaseModel):
|
|||
created: Optional[str] = ""
|
||||
choices: List[Choice]
|
||||
model: str
|
||||
metadata: Optional[Dict[str, str]] = {}
|
||||
|
||||
|
||||
class GuardRequest(BaseModel):
|
||||
|
|
|
|||
|
|
@ -67,11 +67,12 @@ async def function_calling(req: ChatMessage, res: Response):
|
|||
"Arch-Function"
|
||||
].chat_completion(req)
|
||||
function_latency = time.perf_counter() - function_start_time
|
||||
return {
|
||||
"response": function_calling_response,
|
||||
"intent_latency": round(intent_latency * 1000, 3),
|
||||
"function_latency": round(function_latency * 1000, 3),
|
||||
function_calling_response.metadata = {
|
||||
"intent_latency": str(round(intent_latency * 1000, 3)),
|
||||
"function_latency": str(round(function_latency * 1000, 3)),
|
||||
}
|
||||
|
||||
return function_calling_response
|
||||
except Exception as e:
|
||||
# [TODO] Review: update how to collect debugging outputs
|
||||
# logger.error(f"Error in chat_completion from `Arch-Function`: {e}")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue