Integrate Arch-Function-Chat (#449)

This commit is contained in:
Shuguang Chen 2025-04-15 14:39:12 -07:00 committed by GitHub
parent f31aa59fac
commit 7d4b261a68
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
26 changed files with 558 additions and 603 deletions

View file

@ -71,67 +71,58 @@ async def models():
@app.post("/function_calling")
async def function_calling(req: ChatMessage, res: Response):
logger.info("[Endpoint: /function_calling]")
logger.info(f"[request body]: {json.dumps(req.model_dump())}")
logger.info(f"[request body]: {json.dumps(req.model_dump(exclude_none=True))}")
final_response: ChatCompletionResponse = None
error_messages = None
use_agent_orchestrator = req.metadata.get("use_agent_orchestrator", False)
logger.info(f"Use agent orchestrator: {use_agent_orchestrator}")
try:
intent_detected = False
use_agent_orchestrator = req.metadata.get("use_agent_orchestrator", False)
logger.info(f"Use agent orchestrator: {use_agent_orchestrator}")
if not use_agent_orchestrator:
intent_start_time = time.perf_counter()
intent_response = await handler_map["Arch-Intent"].chat_completion(req)
intent_latency = time.perf_counter() - intent_start_time
intent_detected = handler_map["Arch-Intent"].detect_intent(intent_response)
handler_name = "Arch-Agent" if use_agent_orchestrator else "Arch-Function"
model_handler: ArchFunctionHandler = handler_map[handler_name]
if use_agent_orchestrator or intent_detected:
# TODO: measure agreement between intent detection and function calling
try:
function_start_time = time.perf_counter()
handler_name = (
"Arch-Agent" if use_agent_orchestrator else "Arch-Function"
start_time = time.perf_counter()
final_response = await model_handler.chat_completion(req)
latency = time.perf_counter() - start_time
if not final_response.metadata:
final_response.metadata = {}
# Parameter gathering for detected intents
if final_response.choices[0].message.content:
final_response.metadata["function_latency"] = str(round(latency * 1000, 3))
# Function Calling
elif final_response.choices[0].message.tool_calls:
final_response.metadata["function_latency"] = str(round(latency * 1000, 3))
if not use_agent_orchestrator:
final_response.metadata["hallucination"] = str(
model_handler.hallucination_state.hallucination
)
function_calling_handler: ArchFunctionHandler = handler_map[
handler_name
]
final_response = await function_calling_handler.chat_completion(req)
function_latency = time.perf_counter() - function_start_time
final_response.metadata = {
"function_latency": str(round(function_latency * 1000, 3)),
}
if not use_agent_orchestrator:
final_response.metadata["intent_latency"] = str(
round(intent_latency * 1000, 3)
)
final_response.metadata["hallucination"] = str(
function_calling_handler.hallucination_state.hallucination
)
except ValueError as e:
res.statuscode = 503
error_messages = (
f"[{handler_name}] - Error in tool call extraction: {e}"
)
except StopIteration as e:
res.statuscode = 500
error_messages = f"[{handler_name}] - Error in hallucination check: {e}"
except Exception as e:
res.status_code = 500
error_messages = f"[{handler_name}] - Error in ChatCompletion: {e}"
raise
# No intent detected
else:
# no intent matched
intent_response.metadata = {
"intent_latency": str(round(intent_latency * 1000, 3)),
}
final_response = intent_response
final_response.metadata["intent_latency"] = str(round(latency * 1000, 3))
if not use_agent_orchestrator:
final_response.metadata["intent_latency"] = str(round(latency * 1000, 3))
final_response.metadata["hallucination"] = str(
model_handler.hallucination_state.hallucination
)
except ValueError as e:
res.statuscode = 503
error_messages = f"[{handler_name}] - Error in tool call extraction: {e}"
raise
except StopIteration as e:
res.statuscode = 500
error_messages = f"[{handler_name}] - Error in hallucination check: {e}"
raise
except Exception as e:
res.status_code = 500
error_messages = f"[Arch-Intent] - Error in ChatCompletion: {e}"
error_messages = f"[{handler_name}] - Error in ChatCompletion: {e}"
raise
if error_messages is not None:
@ -144,7 +135,7 @@ async def function_calling(req: ChatMessage, res: Response):
@app.post("/guardrails")
async def guardrails(req: GuardRequest, res: Response, max_num_words=300):
logger.info("[Endpoint: /guardrails] - Gateway")
logger.info(f"[request body]: {json.dumps(req.model_dump())}")
logger.info(f"[request body]: {json.dumps(req.model_dump(exclude_none=True))}")
final_response: GuardResponse = None
error_messages = None