Merge branch 'shuguang/main' of https://github.com/katanemo/arch into shuguang/main

This commit is contained in:
cotran 2024-12-10 18:04:56 -08:00
commit 2405fb36e3
17 changed files with 311 additions and 1243 deletions

View file

@ -43,7 +43,11 @@ class ArchIntentConfig:
EXTRA_INSTRUCTION = "Are there any tools can help?"
GENERATION_PARAMS = {"max_tokens": 1, "stop_token_ids": [151645]}
GENERATION_PARAMS = {
"temperature": 0.01,
"max_tokens": 1,
"stop_token_ids": [151645],
}
class ArchIntentHandler(ArchBaseHandler):
@ -318,6 +322,9 @@ class ArchFunctionHandler(ArchBaseHandler):
flag = False
for line in content.split("\n"):
if not is_valid:
break
if "<tool_call>" == line:
flag = True
elif "</tool_call>" == line:
@ -332,7 +339,7 @@ class ArchFunctionHandler(ArchBaseHandler):
tool_content = json.loads(fixed_content)
except Exception:
tool_calls, is_valid, error_message = [], False, e
return tool_calls, is_valid, error_message
break
tool_calls.append(
{
@ -347,7 +354,7 @@ class ArchFunctionHandler(ArchBaseHandler):
flag = False
return {"result": tool_calls, "status": is_valid, "message": "error_message"}
return {"result": tool_calls, "status": is_valid, "message": error_message}
def _verify_tool_calls(
self, tools: List[Dict[str, Any]], tool_calls: List[Dict[str, Any]]
@ -374,16 +381,19 @@ class ArchFunctionHandler(ArchBaseHandler):
functions[tool["function"]["name"]] = tool["function"]["parameters"]
for tool_call in tool_calls:
func_name, func_args = (
tool_call["function"]["name"],
tool_call["function"]["arguments"],
)
if not is_valid:
break
func_name = tool_call["function"]["name"]
func_args = tool_call["function"]["arguments"]
# Check whether the function is available or not
if func_name not in functions:
is_valid = False
invalid_tool_call = tool_call
error_message = f"{func_name} is not defined!"
return is_valid, error_message
break
else:
# Check if all the requried parameters can be found in the tool calls
for required_param in functions[func_name].get("required", []):
@ -391,7 +401,7 @@ class ArchFunctionHandler(ArchBaseHandler):
is_valid = False
invalid_tool_call = tool_call
error_message = f"`{required_param}` is requiried by the function `{func_name}` but not found in the tool call!"
return is_valid, invalid_tool_call, error_message
break
# Verify the data type of each parameter in the tool calls
for param_name in func_args:
@ -405,7 +415,7 @@ class ArchFunctionHandler(ArchBaseHandler):
is_valid = False
invalid_tool_call = tool_call
error_message = f"Parameter `{param_name}` is expected to have the data type `{self.support_data_types[data_type]}`, but got `{type(param_value)}`."
return is_valid, invalid_tool_call, error_message
break
return {
"status": is_valid,

View file

@ -30,6 +30,7 @@ class ChatCompletionResponse(BaseModel):
created: Optional[str] = ""
choices: List[Choice]
model: str
metadata: Optional[Dict[str, str]] = {}
class GuardRequest(BaseModel):

View file

@ -67,11 +67,12 @@ async def function_calling(req: ChatMessage, res: Response):
"Arch-Function"
].chat_completion(req)
function_latency = time.perf_counter() - function_start_time
return {
"response": function_calling_response,
"intent_latency": round(intent_latency * 1000, 3),
"function_latency": round(function_latency * 1000, 3),
function_calling_response.metadata = {
"intent_latency": str(round(intent_latency * 1000, 3)),
"function_latency": str(round(function_latency * 1000, 3)),
}
return function_calling_response
except Exception as e:
# [TODO] Review: update how to collect debugging outputs
# logger.error(f"Error in chat_completion from `Arch-Function`: {e}")