mirror of
https://github.com/katanemo/plano.git
synced 2026-04-27 09:46:28 +02:00
Fold function_resolver into model_server (#103)
This commit is contained in:
parent
b0ce5eca93
commit
f4395d39f9
24 changed files with 31 additions and 197 deletions
|
|
@ -1,124 +0,0 @@
|
|||
import json
|
||||
from typing import Any, Dict, List
|
||||
|
||||
|
||||
ARCH_FUNCTION_CALLING_TASK_PROMPT = """
|
||||
You are a helpful assistant.
|
||||
""".strip()
|
||||
|
||||
|
||||
ARCH_FUNCTION_CALLING_TOOL_PROMPT = """
|
||||
# Tools
|
||||
|
||||
You may call one or more functions to assist with the user query.
|
||||
|
||||
You are provided with function signatures within <tools></tools> XML tags:
|
||||
<tools>
|
||||
{tool_text}
|
||||
</tools>
|
||||
""".strip()
|
||||
|
||||
|
||||
ARCH_FUNCTION_CALLING_FORMAT_PROMPT = """
|
||||
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
|
||||
<tool_call>
|
||||
{"name": <function-name>, "arguments": <args-json-object>}
|
||||
</tool_call>
|
||||
""".strip()
|
||||
|
||||
|
||||
class ArchHandler:
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def _format_system(self, tools: List[Dict[str, Any]]):
|
||||
def convert_tools(tools):
|
||||
return "\n".join([json.dumps(tool) for tool in tools])
|
||||
|
||||
tool_text = convert_tools(tools)
|
||||
|
||||
system_prompt = (
|
||||
ARCH_FUNCTION_CALLING_TASK_PROMPT
|
||||
+ "\n\n"
|
||||
+ ARCH_FUNCTION_CALLING_TOOL_PROMPT.format(tool_text=tool_text)
|
||||
+ "\n\n"
|
||||
+ ARCH_FUNCTION_CALLING_FORMAT_PROMPT
|
||||
)
|
||||
|
||||
return system_prompt
|
||||
|
||||
def _add_execution_results_prompting(
|
||||
self,
|
||||
messages: list[dict],
|
||||
execution_results: list,
|
||||
) -> dict:
|
||||
|
||||
content = []
|
||||
for result in execution_results:
|
||||
content.append(f"<tool_response>\n{json.dumps(result)}\n</tool_response>")
|
||||
|
||||
content = "\n".join(content)
|
||||
messages.append({"role": "user", "content": content})
|
||||
|
||||
return messages
|
||||
|
||||
def extract_tools(self, result: str):
|
||||
lines = result.split("\n")
|
||||
flag = False
|
||||
func_call = []
|
||||
for line in lines:
|
||||
if "<tool_call>" == line:
|
||||
flag = True
|
||||
elif "</tool_call>" == line:
|
||||
flag = False
|
||||
else:
|
||||
if flag:
|
||||
try:
|
||||
tool_result = json.loads(line)
|
||||
except Exception:
|
||||
fixed_content = self.fix_json_string(line)
|
||||
try:
|
||||
tool_result = json.loads(fixed_content)
|
||||
except json.JSONDecodeError:
|
||||
return result
|
||||
func_call.append({tool_result["name"]: tool_result["arguments"]})
|
||||
flag = False
|
||||
return func_call
|
||||
|
||||
def fix_json_string(self, json_str: str):
|
||||
# Remove any leading or trailing whitespace or newline characters
|
||||
json_str = json_str.strip()
|
||||
|
||||
# Stack to keep track of brackets
|
||||
stack = []
|
||||
|
||||
# Clean string to collect valid characters
|
||||
fixed_str = ""
|
||||
|
||||
# Dictionary for matching brackets
|
||||
matching_bracket = {")": "(", "}": "{", "]": "["}
|
||||
|
||||
# Dictionary for the opposite of matching_bracket
|
||||
opening_bracket = {v: k for k, v in matching_bracket.items()}
|
||||
|
||||
for char in json_str:
|
||||
if char in "{[(":
|
||||
stack.append(char)
|
||||
fixed_str += char
|
||||
elif char in "}])":
|
||||
if stack and stack[-1] == matching_bracket[char]:
|
||||
stack.pop()
|
||||
fixed_str += char
|
||||
else:
|
||||
# Ignore the unmatched closing brackets
|
||||
continue
|
||||
else:
|
||||
fixed_str += char
|
||||
|
||||
# If there are unmatched opening brackets left in the stack, add corresponding closing brackets
|
||||
while stack:
|
||||
unmatched_opening = stack.pop()
|
||||
fixed_str += opening_bracket[unmatched_opening]
|
||||
|
||||
# Attempt to parse the corrected string to ensure it’s valid JSON
|
||||
return fixed_str
|
||||
Loading…
Add table
Add a link
Reference in a new issue