Update model usage

This commit is contained in:
Shuguang Chen 2025-03-28 15:10:51 -07:00
parent 8290d1969f
commit 425f9b0dd5
2 changed files with 14 additions and 10 deletions

View file

@ -27,16 +27,15 @@ logger = utils.get_model_server_logger()
class ArchFunctionConfig:
TASK_PROMPT = (
"You are a helpful assistant designed to assist with the user query by making one or more function calls if needed."
"\nToday's date: {today_date}"
"\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>{tool_text}\n</tools>"
"\n\nYour task is to decide which functions are needed and collect missing parameters if necessary.\n\n"
"\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>\n{tools}\n</tools>"
"\n\nYour task is to decide which functions are needed and collect missing parameters if necessary."
)
FORMAT_PROMPT = (
"Based on your analysis, provide your response in one of the following JSON formats:"
'\n1. If no functions are needed:\n```\n{"response": "Your response text here"}\n```'
'\n2. If functions are needed but some required parameters are missing:\n```\n{"required_functions": ["func_name1", "func_name2", ...], "clarification": "Text asking for missing parameters"}\n```'
'\n3. If functions are needed and all required parameters are available:\n```\n{"tool_calls": [{"name": "func_name1", "arguments": {"argument1": "value1", "argument2": "value2"}},... (more tool calls as required)]}\n```'
"\n\nBased on your analysis, provide your response in one of the following JSON formats:"
'\n1. If no functions are needed:\n```json\n{"response": "Your response text here"}\n```'
'\n2. If functions are needed but some required parameters are missing:\n```json\n{"required_functions": ["func_name1", "func_name2", ...], "clarification": "Text asking for missing parameters"}\n```'
'\n3. If functions are needed and all required parameters are available:\n```json\n{"tool_calls": [{"name": "func_name1", "arguments": {"argument1": "value1", "argument2": "value2"}},... (more tool calls as required)]}\n```'
)
GENERATION_PARAMS = {
@ -193,6 +192,11 @@ class ArchFunctionHandler(ArchBaseHandler):
}
try:
if content.startswith("```") and content.endswith("```"):
content = content.strip("```").strip()
if content.startswith("json"):
content = content[4:].strip()
model_response = json.loads(self._fix_json_string(content))
response_dict["response"] = model_response.get("response", "")
@ -414,7 +418,7 @@ class ArchFunctionHandler(ArchBaseHandler):
for _ in self.hallucination_state:
# check if the first token is <tool_call>
if len(self.hallucination_state.tokens) > 2 and has_tool_calls is None:
content = ''.join(self.hallucination_state.tokens)
content = "".join(self.hallucination_state.tokens)
if "tool_calls" in content:
has_tool_calls = True
else:

View file

@ -104,10 +104,10 @@ class ArchBaseHandler:
"""
today_date = utils.get_today_date()
tool_text = self._convert_tools(tools)
tools = self._convert_tools(tools)
system_prompt = (
self.task_prompt.format(today_date=today_date, tool_text=tool_text)
self.task_prompt.format(today_date=today_date, tools=tools)
+ self.format_prompt
)