mirror of
https://github.com/katanemo/plano.git
synced 2026-04-28 18:36:34 +02:00
Integrate Arch-Function-Calling-1.5B model (#85)
* add arch support * add missing file * e2e tests * delete old files and fix response * fmt
This commit is contained in:
parent
9ea6bb0d73
commit
3511798fa8
12 changed files with 203 additions and 427 deletions
|
|
@ -33,7 +33,7 @@ class ArchHandler:
|
|||
|
||||
def _format_system(self, tools: List[Dict[str, Any]]):
|
||||
def convert_tools(tools):
|
||||
return "\n".join([json.dumps(tool) for tool in tools])
|
||||
return "\n".join([json.dumps(tool["function"]) for tool in tools])
|
||||
|
||||
tool_text = convert_tools(tools)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,225 +0,0 @@
|
|||
import json
|
||||
from typing import Any, Dict, List
|
||||
|
||||
|
||||
SYSTEM_PROMPT = """
|
||||
[BEGIN OF TASK INSTRUCTION]
|
||||
You are a function calling assistant with access to the following tools. You task is to assist users as best as you can.
|
||||
For each user query, you may need to call one or more functions to to better generate responses.
|
||||
If none of the functions are relevant, you should point it out.
|
||||
If the given query lacks the parameters required by the function, you should ask users for clarification.
|
||||
The users may execute functions and return results as `Observation` to you. In the case, you MUST generate responses by summarizing it.
|
||||
[END OF TASK INSTRUCTION]
|
||||
""".strip()
|
||||
|
||||
TOOL_PROMPT = """
|
||||
[BEGIN OF AVAILABLE TOOLS]
|
||||
{tool_text}
|
||||
[END OF AVAILABLE TOOLS]
|
||||
""".strip()
|
||||
|
||||
FORMAT_PROMPT = """
|
||||
[BEGIN OF FORMAT INSTRUCTION]
|
||||
You MUST use the following JSON format if using tools.
|
||||
The example format is as follows. DO NOT use this format if no function call is needed.
|
||||
```
|
||||
{
|
||||
"tool_calls": [
|
||||
{"name": "func_name1", "arguments": {"argument1": "value1", "argument2": "value2"}},
|
||||
... (more tool calls as required)
|
||||
]
|
||||
}
|
||||
```
|
||||
[END OF FORMAT INSTRUCTION]
|
||||
""".strip()
|
||||
|
||||
|
||||
class BoltHandler:
|
||||
def _format_system(self, tools: List[Dict[str, Any]]):
|
||||
tool_text = self._format_tools(tools=tools)
|
||||
return (
|
||||
SYSTEM_PROMPT
|
||||
+ "\n\n"
|
||||
+ TOOL_PROMPT.format(tool_text=tool_text)
|
||||
+ "\n\n"
|
||||
+ FORMAT_PROMPT
|
||||
+ "\n"
|
||||
)
|
||||
|
||||
def _format_tools(self, tools: List[Dict[str, Any]]):
|
||||
TOOL_DESC = "> Tool Name: {name}\nTool Description: {desc}\nTool Args:\n{args}"
|
||||
|
||||
tool_text = []
|
||||
for tool in tools:
|
||||
param_text = self.get_param_text(tool.parameters)
|
||||
tool_text.append(
|
||||
TOOL_DESC.format(
|
||||
name=tool.name, desc=tool.description, args=param_text
|
||||
)
|
||||
)
|
||||
|
||||
return "\n".join(tool_text)
|
||||
|
||||
def extract_tools(self, content, executable=False):
|
||||
# retrieve `tool_calls` from model responses
|
||||
try:
|
||||
content_json = json.loads(content)
|
||||
except Exception:
|
||||
fixed_content = self.fix_json_string(content)
|
||||
try:
|
||||
content_json = json.loads(fixed_content)
|
||||
except json.JSONDecodeError:
|
||||
return content
|
||||
|
||||
if isinstance(content_json, list):
|
||||
tool_calls = content_json
|
||||
elif isinstance(content_json, dict):
|
||||
tool_calls = content_json.get("tool_calls", [])
|
||||
else:
|
||||
tool_calls = []
|
||||
|
||||
if not isinstance(tool_calls, list):
|
||||
return content
|
||||
|
||||
# process and extract tools from `tool_calls`
|
||||
extracted = []
|
||||
|
||||
for tool_call in tool_calls:
|
||||
if isinstance(tool_call, dict):
|
||||
try:
|
||||
if not executable:
|
||||
extracted.append({tool_call["name"]: tool_call["arguments"]})
|
||||
else:
|
||||
name, arguments = (
|
||||
tool_call.get("name", ""),
|
||||
tool_call.get("arguments", {}),
|
||||
)
|
||||
|
||||
for key, value in arguments.items():
|
||||
if value == "False" or value == "false":
|
||||
arguments[key] = False
|
||||
elif value == "True" or value == "true":
|
||||
arguments[key] = True
|
||||
|
||||
args_str = ", ".join(
|
||||
[f"{key}={repr(value)}" for key, value in arguments.items()]
|
||||
)
|
||||
|
||||
extracted.append(f"{name}({args_str})")
|
||||
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
return extracted
|
||||
|
||||
def get_param_text(self, parameter_dict, prefix=""):
|
||||
param_text = ""
|
||||
|
||||
for name, param in parameter_dict["properties"].items():
|
||||
param_type = param.get("type", "")
|
||||
|
||||
required, default, param_format, properties, enum, items = (
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
)
|
||||
|
||||
if name in parameter_dict.get("required", []):
|
||||
required = ", required"
|
||||
|
||||
required_param = parameter_dict.get("required", [])
|
||||
|
||||
if isinstance(required_param, bool):
|
||||
required = ", required" if required_param else ""
|
||||
elif isinstance(required_param, list) and name in required_param:
|
||||
required = ", required"
|
||||
else:
|
||||
required = ", optional"
|
||||
|
||||
default_param = param.get("default", None)
|
||||
if default_param:
|
||||
default = f", default: {default_param}"
|
||||
|
||||
format_in = param.get("format", None)
|
||||
if format_in:
|
||||
param_format = f", format: {format_in}"
|
||||
|
||||
desc = param.get("description", "")
|
||||
|
||||
if "properties" in param:
|
||||
arg_properties = self.get_param_text(param, prefix + " ")
|
||||
properties += "with the properties:\n{}".format(arg_properties)
|
||||
|
||||
enum_param = param.get("enum", None)
|
||||
if enum_param:
|
||||
enum = "should be one of [{}]".format(", ".join(enum_param))
|
||||
|
||||
item_param = param.get("items", None)
|
||||
if item_param:
|
||||
item_type = item_param.get("type", None)
|
||||
if item_type:
|
||||
items += "each item should be the {} type ".format(item_type)
|
||||
|
||||
item_properties = item_param.get("properties", None)
|
||||
if item_properties:
|
||||
item_properties = self.get_param_text(item_param, prefix + " ")
|
||||
items += "with the properties:\n{}".format(item_properties)
|
||||
|
||||
illustration = ", ".join(
|
||||
[x for x in [desc, properties, enum, items] if len(x)]
|
||||
)
|
||||
|
||||
param_text += (
|
||||
prefix
|
||||
+ "- {name} ({param_type}{required}{param_format}{default}): {illustration}\n".format(
|
||||
name=name,
|
||||
param_type=param_type,
|
||||
required=required,
|
||||
param_format=param_format,
|
||||
default=default,
|
||||
illustration=illustration,
|
||||
)
|
||||
)
|
||||
|
||||
return param_text
|
||||
|
||||
def fix_json_string(self, json_str):
|
||||
# Remove any leading or trailing whitespace or newline characters
|
||||
json_str = json_str.strip()
|
||||
|
||||
# Stack to keep track of brackets
|
||||
stack = []
|
||||
|
||||
# Clean string to collect valid characters
|
||||
fixed_str = ""
|
||||
|
||||
# Dictionary for matching brackets
|
||||
matching_bracket = {")": "(", "}": "{", "]": "["}
|
||||
|
||||
# Dictionary for the opposite of matching_bracket
|
||||
opening_bracket = {v: k for k, v in matching_bracket.items()}
|
||||
|
||||
for char in json_str:
|
||||
if char in "{[(":
|
||||
stack.append(char)
|
||||
fixed_str += char
|
||||
elif char in "}])":
|
||||
if stack and stack[-1] == matching_bracket[char]:
|
||||
stack.pop()
|
||||
fixed_str += char
|
||||
else:
|
||||
# Ignore the unmatched closing brackets
|
||||
continue
|
||||
else:
|
||||
fixed_str += char
|
||||
|
||||
# If there are unmatched opening brackets left in the stack, add corresponding closing brackets
|
||||
while stack:
|
||||
unmatched_opening = stack.pop()
|
||||
fixed_str += opening_bracket[unmatched_opening]
|
||||
|
||||
# Attempt to parse the corrected string to ensure it’s valid JSON
|
||||
return fixed_str
|
||||
|
|
@ -1,14 +1,10 @@
|
|||
from typing import Any, Dict, List
|
||||
from pydantic import BaseModel
|
||||
|
||||
class Tool(BaseModel):
|
||||
name: str
|
||||
description: str
|
||||
parameters: dict
|
||||
|
||||
class Message(BaseModel):
|
||||
role: str
|
||||
content: str
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
messages: list[Message]
|
||||
tools: list[Tool]
|
||||
tools: List[Dict[str, Any]]
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import json
|
||||
import random
|
||||
from fastapi import FastAPI, Response
|
||||
from bolt_handler import BoltHandler
|
||||
from arch_handler import ArchHandler
|
||||
from common import ChatMessage
|
||||
import logging
|
||||
|
|
@ -8,15 +8,15 @@ from openai import OpenAI
|
|||
import os
|
||||
|
||||
ollama_endpoint = os.getenv("OLLAMA_ENDPOINT", "localhost")
|
||||
ollama_model = os.getenv("OLLAMA_MODEL", "Bolt-Function-Calling-1B:Q4_K_M")
|
||||
ollama_model = os.getenv("OLLAMA_MODEL", "Arch-Function-Calling-1.5B-Q4_K_M")
|
||||
logger = logging.getLogger('uvicorn.error')
|
||||
|
||||
logger.info(f"using model: {ollama_model}")
|
||||
logger.info(f"using ollama endpoint: {ollama_endpoint}")
|
||||
|
||||
app = FastAPI()
|
||||
bolt_handler = BoltHandler()
|
||||
arch_handler = ArchHandler()
|
||||
|
||||
handler = ArchHandler()
|
||||
|
||||
client = OpenAI(
|
||||
base_url='http://{}:11434/v1/'.format(ollama_endpoint),
|
||||
|
|
@ -35,10 +35,6 @@ async def healthz():
|
|||
@app.post("/v1/chat/completions")
|
||||
async def chat_completion(req: ChatMessage, res: Response):
|
||||
logger.info("starting request")
|
||||
if ollama_model.startswith("Bolt"):
|
||||
handler = bolt_handler
|
||||
else:
|
||||
handler = arch_handler
|
||||
tools_encoded = handler._format_system(req.tools)
|
||||
# append system prompt with tools to messages
|
||||
messages = [{"role": "system", "content": tools_encoded}]
|
||||
|
|
@ -46,5 +42,21 @@ async def chat_completion(req: ChatMessage, res: Response):
|
|||
messages.append({"role": message.role, "content": message.content})
|
||||
logger.info(f"request model: {ollama_model}, messages: {json.dumps(messages)}")
|
||||
resp = client.chat.completions.create(messages=messages, model=ollama_model, stream=False)
|
||||
logger.info(f"response: {resp.to_json()}")
|
||||
tools = handler.extract_tools(resp.choices[0].message.content)
|
||||
tool_calls = []
|
||||
for tool in tools:
|
||||
for tool_name, tool_args in tool.items():
|
||||
tool_calls.append({
|
||||
"id": f"call_{random.randint(1000, 10000)}",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool_name,
|
||||
"arguments": tool_args
|
||||
}
|
||||
})
|
||||
if tools:
|
||||
resp.choices[0].message.tool_calls = tool_calls
|
||||
resp.choices[0].message.content = None
|
||||
logger.info(f"response (tools): {json.dumps(tools)}")
|
||||
logger.info(f"response: {json.dumps(resp.to_dict())}")
|
||||
return resp
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue