diff --git a/model_server/src/core/function_calling.py b/model_server/src/core/function_calling.py
index 847861e9..ac6c2605 100644
--- a/model_server/src/core/function_calling.py
+++ b/model_server/src/core/function_calling.py
@@ -27,16 +27,15 @@ logger = utils.get_model_server_logger()
class ArchFunctionConfig:
TASK_PROMPT = (
"You are a helpful assistant designed to assist with the user query by making one or more function calls if needed."
- "\nToday's date: {today_date}"
- "\n\nYou are provided with function signatures within XML tags:\n{tool_text}\n"
- "\n\nYour task is to decide which functions are needed and collect missing parameters if necessary.\n\n"
+ "\n\nYou are provided with function signatures within XML tags:\n\n{tools}\n"
+ "\n\nYour task is to decide which functions are needed and collect missing parameters if necessary."
)
FORMAT_PROMPT = (
- "Based on your analysis, provide your response in one of the following JSON formats:"
- '\n1. If no functions are needed:\n```\n{"response": "Your response text here"}\n```'
- '\n2. If functions are needed but some required parameters are missing:\n```\n{"required_functions": ["func_name1", "func_name2", ...], "clarification": "Text asking for missing parameters"}\n```'
- '\n3. If functions are needed and all required parameters are available:\n```\n{"tool_calls": [{"name": "func_name1", "arguments": {"argument1": "value1", "argument2": "value2"}},... (more tool calls as required)]}\n```'
+ "\n\nBased on your analysis, provide your response in one of the following JSON formats:"
+ '\n1. If no functions are needed:\n```json\n{"response": "Your response text here"}\n```'
+ '\n2. If functions are needed but some required parameters are missing:\n```json\n{"required_functions": ["func_name1", "func_name2", ...], "clarification": "Text asking for missing parameters"}\n```'
+ '\n3. If functions are needed and all required parameters are available:\n```json\n{"tool_calls": [{"name": "func_name1", "arguments": {"argument1": "value1", "argument2": "value2"}},... (more tool calls as required)]}\n```'
)
GENERATION_PARAMS = {
@@ -193,6 +192,11 @@ class ArchFunctionHandler(ArchBaseHandler):
}
try:
+ if content.startswith("```") and content.endswith("```"):
+ content = content.strip("```").strip()
+ if content.startswith("json"):
+ content = content[4:].strip()
+
model_response = json.loads(self._fix_json_string(content))
response_dict["response"] = model_response.get("response", "")
@@ -414,7 +418,7 @@ class ArchFunctionHandler(ArchBaseHandler):
for _ in self.hallucination_state:
# check if the first token is
if len(self.hallucination_state.tokens) > 2 and has_tool_calls is None:
- content = ''.join(self.hallucination_state.tokens)
+ content = "".join(self.hallucination_state.tokens)
if "tool_calls" in content:
has_tool_calls = True
else:
diff --git a/model_server/src/core/utils/model_utils.py b/model_server/src/core/utils/model_utils.py
index 2dfd2237..0ce75333 100644
--- a/model_server/src/core/utils/model_utils.py
+++ b/model_server/src/core/utils/model_utils.py
@@ -104,10 +104,10 @@ class ArchBaseHandler:
"""
today_date = utils.get_today_date()
- tool_text = self._convert_tools(tools)
+ tools = self._convert_tools(tools)
system_prompt = (
- self.task_prompt.format(today_date=today_date, tool_text=tool_text)
+ self.task_prompt.format(today_date=today_date, tools=tools)
+ self.format_prompt
)