diff --git a/model_server/src/core/function_calling.py b/model_server/src/core/function_calling.py index 847861e9..ac6c2605 100644 --- a/model_server/src/core/function_calling.py +++ b/model_server/src/core/function_calling.py @@ -27,16 +27,15 @@ logger = utils.get_model_server_logger() class ArchFunctionConfig: TASK_PROMPT = ( "You are a helpful assistant designed to assist with the user query by making one or more function calls if needed." - "\nToday's date: {today_date}" - "\n\nYou are provided with function signatures within XML tags:\n{tool_text}\n" - "\n\nYour task is to decide which functions are needed and collect missing parameters if necessary.\n\n" + "\n\nYou are provided with function signatures within XML tags:\n\n{tools}\n" + "\n\nYour task is to decide which functions are needed and collect missing parameters if necessary." ) FORMAT_PROMPT = ( - "Based on your analysis, provide your response in one of the following JSON formats:" - '\n1. If no functions are needed:\n```\n{"response": "Your response text here"}\n```' - '\n2. If functions are needed but some required parameters are missing:\n```\n{"required_functions": ["func_name1", "func_name2", ...], "clarification": "Text asking for missing parameters"}\n```' - '\n3. If functions are needed and all required parameters are available:\n```\n{"tool_calls": [{"name": "func_name1", "arguments": {"argument1": "value1", "argument2": "value2"}},... (more tool calls as required)]}\n```' + "\n\nBased on your analysis, provide your response in one of the following JSON formats:" + '\n1. If no functions are needed:\n```json\n{"response": "Your response text here"}\n```' + '\n2. If functions are needed but some required parameters are missing:\n```json\n{"required_functions": ["func_name1", "func_name2", ...], "clarification": "Text asking for missing parameters"}\n```' + '\n3. If functions are needed and all required parameters are available:\n```json\n{"tool_calls": [{"name": "func_name1", "arguments": {"argument1": "value1", "argument2": "value2"}},... (more tool calls as required)]}\n```' ) GENERATION_PARAMS = { @@ -193,6 +192,11 @@ class ArchFunctionHandler(ArchBaseHandler): } try: + if content.startswith("```") and content.endswith("```"): + content = content.strip("```").strip() + if content.startswith("json"): + content = content[4:].strip() + model_response = json.loads(self._fix_json_string(content)) response_dict["response"] = model_response.get("response", "") @@ -414,7 +418,7 @@ class ArchFunctionHandler(ArchBaseHandler): for _ in self.hallucination_state: # check if the first token is if len(self.hallucination_state.tokens) > 2 and has_tool_calls is None: - content = ''.join(self.hallucination_state.tokens) + content = "".join(self.hallucination_state.tokens) if "tool_calls" in content: has_tool_calls = True else: diff --git a/model_server/src/core/utils/model_utils.py b/model_server/src/core/utils/model_utils.py index 2dfd2237..0ce75333 100644 --- a/model_server/src/core/utils/model_utils.py +++ b/model_server/src/core/utils/model_utils.py @@ -104,10 +104,10 @@ class ArchBaseHandler: """ today_date = utils.get_today_date() - tool_text = self._convert_tools(tools) + tools = self._convert_tools(tools) system_prompt = ( - self.task_prompt.format(today_date=today_date, tool_text=tool_text) + self.task_prompt.format(today_date=today_date, tools=tools) + self.format_prompt )