diff --git a/demos/shared/chatbot_ui/run_stream.py b/demos/shared/chatbot_ui/run_stream.py index b406e147..f06a6155 100644 --- a/demos/shared/chatbot_ui/run_stream.py +++ b/demos/shared/chatbot_ui/run_stream.py @@ -88,6 +88,18 @@ def chat( yield "", conversation, history, debug_output, model_selector + # update assistant response to have correct format + # arch-fc 1.1 expects following format: + # { + # "response": "", + # } + + if not history[-1]["model"].startswith("Arch"): + assistant_response = { + "response": history[-1]["content"], + } + history[-1]["content"] = json.dumps(assistant_response) + def main(): with gr.Blocks( diff --git a/model_server/src/core/function_calling.py b/model_server/src/core/function_calling.py index 73cf4fd7..0fd81e27 100644 --- a/model_server/src/core/function_calling.py +++ b/model_server/src/core/function_calling.py @@ -154,7 +154,7 @@ class ArchFunctionHandler(ArchBaseHandler): return fixed_str - def _parse_model_resonse(self, content: str) -> Dict[str, any]: + def _parse_model_response(self, content: str) -> Dict[str, any]: """ Extracts tool call information from a given string. 
@@ -212,7 +212,7 @@ class ArchFunctionHandler(ArchBaseHandler): response_dict["is_valid"] = False response_dict["error_message"] = f"Fail to parse model responses: {e}" - return response_dict + return content, response_dict def _convert_data_type(self, value: str, target_type: str): # TODO: Add more conversion rules as needed @@ -408,9 +408,13 @@ class ArchFunctionHandler(ArchBaseHandler): self.hallucination_state.tokens ) - logger.info(f"[arch-fc]: raw model response: {model_response}") # Extract tool calls from model response - response_dict = self._parse_model_resonse(model_response) + raw_model_response_json_fixed, response_dict = self._parse_model_response( + model_response + ) + logger.info( + f"[arch-fc]: raw model response (json fixed): {raw_model_response_json_fixed}" + ) # General model response if response_dict.get("response", ""): @@ -462,12 +466,12 @@ class ArchFunctionHandler(ArchBaseHandler): chat_completion_response = ChatCompletionResponse( choices=[Choice(message=model_message)], model=self.model_name, - metadata={"x-arch-fc-model-response": model_response}, + metadata={"x-arch-fc-model-response": raw_model_response_json_fixed}, role="assistant", ) logger.info( - f"[response arch-fc]: {json.dumps(chat_completion_response.model_dump())}" + f"[response arch-fc]: {json.dumps(chat_completion_response.model_dump(exclude_none=True))}" ) return chat_completion_response diff --git a/model_server/src/core/utils/model_utils.py b/model_server/src/core/utils/model_utils.py index 0ce75333..30a7865c 100644 --- a/model_server/src/core/utils/model_utils.py +++ b/model_server/src/core/utils/model_utils.py @@ -161,9 +161,10 @@ class ArchBaseHandler: # sample response below # "content": "\n{'name': 'get_stock_price', 'result': '$196.66'}\n" # msg[idx-1] contains tool call = '{"tool_calls": [{"name": "currency_exchange", "arguments": {"currency_symbol": "NZD"}}]}' - func_name = json.loads(messages[idx - 1].content)["tool_calls"][ - 0 - ].get("name", "no_name") + 
tool_call_msg = messages[idx - 1].content + func_name = json.loads(tool_call_msg)["tool_calls"][0].get( + "name", "no_name" + ) tool_response = { "name": func_name, "result": content, diff --git a/model_server/src/main.py b/model_server/src/main.py index e37136cd..34856498 100644 --- a/model_server/src/main.py +++ b/model_server/src/main.py @@ -71,7 +71,7 @@ async def models(): @app.post("/function_calling") async def function_calling(req: ChatMessage, res: Response): logger.info("[Endpoint: /function_calling]") - logger.info(f"[request body]: {json.dumps(req.model_dump())}") + logger.info(f"[request body]: {json.dumps(req.model_dump(exclude_none=True))}") final_response: ChatCompletionResponse = None error_messages = None @@ -115,9 +115,11 @@ async def function_calling(req: ChatMessage, res: Response): except ValueError as e: - res.statuscode = 503 + res.status_code = 503 error_messages = f"[{handler_name}] - Error in tool call extraction: {e}" + raise except StopIteration as e: - res.statuscode = 500 + res.status_code = 500 error_messages = f"[{handler_name}] - Error in hallucination check: {e}" + raise except Exception as e: res.status_code = 500 error_messages = f"[{handler_name}] - Error in ChatCompletion: {e}" @@ -133,7 +135,7 @@ async def function_calling(req: ChatMessage, res: Response): @app.post("/guardrails") async def guardrails(req: GuardRequest, res: Response, max_num_words=300): logger.info("[Endpoint: /guardrails] - Gateway") - logger.info(f"[request body]: {json.dumps(req.model_dump())}") + logger.info(f"[request body]: {json.dumps(req.model_dump(exclude_none=True))}") final_response: GuardResponse = None error_messages = None