Fold function_resolver into model_server (#103)

This commit is contained in:
Adil Hafeez 2024-10-01 09:13:50 -07:00 committed by GitHub
parent b0ce5eca93
commit f4395d39f9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
24 changed files with 31 additions and 197 deletions

View file

@ -5,12 +5,11 @@
"version": "0.2.0",
"configurations": [
{
"name": "embedding server",
"cwd": "${workspaceFolder}/app",
"name": "model server",
"type": "debugpy",
"request": "launch",
"module": "uvicorn",
"args": ["main:app","--reload", "--port", "8000"],
"args": ["app.main:app","--reload", "--port", "8000"],
}
]
}

View file

@ -31,7 +31,7 @@ ENV NER_MODELS="urchade/gliner_large-v2.1"
COPY --from=builder /runtime /usr/local
COPY /app /app
COPY ./ /app
WORKDIR /app
RUN apt-get update && apt-get install -y \
@ -45,4 +45,4 @@ RUN apt-get update && apt-get install -y \
# RUN python install.py && \
# find /root/.cache/torch/sentence_transformers/ -name onnx -exec rm -rf {} +
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "80"]
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "80"]

View file

@ -0,0 +1,60 @@
import json
import random
from fastapi import FastAPI, Response
from app.arch_fc.arch_handler import ArchHandler
from app.arch_fc.bolt_handler import BoltHandler
from app.arch_fc.common import ChatMessage
import logging
from openai import OpenAI
import os
ollama_endpoint = os.getenv("OLLAMA_ENDPOINT", "localhost")
ollama_model = os.getenv("OLLAMA_MODEL", "Arch-Function-Calling-1.5B-Q4_K_M")
logger = logging.getLogger('uvicorn.error')
handler = None
if ollama_model.startswith("Arch"):
handler = ArchHandler()
else:
handler = BoltHandler()
logger.info(f"using model: {ollama_model}")
logger.info(f"using ollama endpoint: {ollama_endpoint}")
# app = FastAPI()
client = OpenAI(
base_url='http://{}:11434/v1/'.format(ollama_endpoint),
# required but ignored
api_key='ollama',
)
async def chat_completion(req: ChatMessage, res: Response):
logger.info("starting request")
tools_encoded = handler._format_system(req.tools)
# append system prompt with tools to messages
messages = [{"role": "system", "content": tools_encoded}]
for message in req.messages:
messages.append({"role": message.role, "content": message.content})
logger.info(f"request model: {ollama_model}, messages: {json.dumps(messages)}")
resp = client.chat.completions.create(messages=messages, model=ollama_model, stream=False)
tools = handler.extract_tools(resp.choices[0].message.content)
tool_calls = []
for tool in tools:
for tool_name, tool_args in tool.items():
tool_calls.append({
"id": f"call_{random.randint(1000, 10000)}",
"type": "function",
"function": {
"name": tool_name,
"arguments": tool_args
}
})
if tools:
resp.choices[0].message.tool_calls = tool_calls
resp.choices[0].message.content = None
logger.info(f"response (tools): {json.dumps(tools)}")
logger.info(f"response: {json.dumps(resp.to_dict())}")
return resp

View file

@ -0,0 +1,124 @@
import json
from typing import Any, Dict, List
ARCH_FUNCTION_CALLING_TASK_PROMPT = """
You are a helpful assistant.
""".strip()
ARCH_FUNCTION_CALLING_TOOL_PROMPT = """
# Tools
You may call one or more functions to assist with the user query.
You are provided with function signatures within <tools></tools> XML tags:
<tools>
{tool_text}
</tools>
""".strip()
ARCH_FUNCTION_CALLING_FORMAT_PROMPT = """
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
<tool_call>
{"name": <function-name>, "arguments": <args-json-object>}
</tool_call>
""".strip()
class ArchHandler:
def __init__(self) -> None:
super().__init__()
def _format_system(self, tools: List[Dict[str, Any]]):
def convert_tools(tools):
return "\n".join([json.dumps(tool) for tool in tools])
tool_text = convert_tools(tools)
system_prompt = (
ARCH_FUNCTION_CALLING_TASK_PROMPT
+ "\n\n"
+ ARCH_FUNCTION_CALLING_TOOL_PROMPT.format(tool_text=tool_text)
+ "\n\n"
+ ARCH_FUNCTION_CALLING_FORMAT_PROMPT
)
return system_prompt
def _add_execution_results_prompting(
self,
messages: list[dict],
execution_results: list,
) -> dict:
content = []
for result in execution_results:
content.append(f"<tool_response>\n{json.dumps(result)}\n</tool_response>")
content = "\n".join(content)
messages.append({"role": "user", "content": content})
return messages
def extract_tools(self, result: str):
lines = result.split("\n")
flag = False
func_call = []
for line in lines:
if "<tool_call>" == line:
flag = True
elif "</tool_call>" == line:
flag = False
else:
if flag:
try:
tool_result = json.loads(line)
except Exception:
fixed_content = self.fix_json_string(line)
try:
tool_result = json.loads(fixed_content)
except json.JSONDecodeError:
return result
func_call.append({tool_result["name"]: tool_result["arguments"]})
flag = False
return func_call
def fix_json_string(self, json_str: str):
# Remove any leading or trailing whitespace or newline characters
json_str = json_str.strip()
# Stack to keep track of brackets
stack = []
# Clean string to collect valid characters
fixed_str = ""
# Dictionary for matching brackets
matching_bracket = {")": "(", "}": "{", "]": "["}
# Dictionary for the opposite of matching_bracket
opening_bracket = {v: k for k, v in matching_bracket.items()}
for char in json_str:
if char in "{[(":
stack.append(char)
fixed_str += char
elif char in "}])":
if stack and stack[-1] == matching_bracket[char]:
stack.pop()
fixed_str += char
else:
# Ignore the unmatched closing brackets
continue
else:
fixed_str += char
# If there are unmatched opening brackets left in the stack, add corresponding closing brackets
while stack:
unmatched_opening = stack.pop()
fixed_str += opening_bracket[unmatched_opening]
# Attempt to parse the corrected string to ensure its valid JSON
return fixed_str

View file

@ -0,0 +1,226 @@
import json
from typing import Any, Dict, List
SYSTEM_PROMPT = """
[BEGIN OF TASK INSTRUCTION]
You are a function calling assistant with access to the following tools. You task is to assist users as best as you can.
For each user query, you may need to call one or more functions to to better generate responses.
If none of the functions are relevant, you should point it out.
If the given query lacks the parameters required by the function, you should ask users for clarification.
The users may execute functions and return results as `Observation` to you. In the case, you MUST generate responses by summarizing it.
[END OF TASK INSTRUCTION]
""".strip()
TOOL_PROMPT = """
[BEGIN OF AVAILABLE TOOLS]
{tool_text}
[END OF AVAILABLE TOOLS]
""".strip()
FORMAT_PROMPT = """
[BEGIN OF FORMAT INSTRUCTION]
You MUST use the following JSON format if using tools.
The example format is as follows. DO NOT use this format if no function call is needed.
```
{
"tool_calls": [
{"name": "func_name1", "arguments": {"argument1": "value1", "argument2": "value2"}},
... (more tool calls as required)
]
}
```
[END OF FORMAT INSTRUCTION]
""".strip()
class BoltHandler:
def _format_system(self, tools: List[Dict[str, Any]]):
tool_text = self._format_tools(tools=tools)
return (
SYSTEM_PROMPT
+ "\n\n"
+ TOOL_PROMPT.format(tool_text=tool_text)
+ "\n\n"
+ FORMAT_PROMPT
+ "\n"
)
def _format_tools(self, tools: List[Dict[str, Any]]):
TOOL_DESC = "> Tool Name: {name}\nTool Description: {desc}\nTool Args:\n{args}"
tool_text = []
for fn in tools:
tool = fn["function"]
param_text = self.get_param_text(tool["parameters"])
tool_text.append(
TOOL_DESC.format(
name=tool["name"], desc=tool["description"], args=param_text
)
)
return "\n".join(tool_text)
def extract_tools(self, content, executable=False):
extracted_tools = []
# retrieve `tool_calls` from model responses
try:
content_json = json.loads(content)
except Exception:
fixed_content = self.fix_json_string(content)
try:
content_json = json.loads(fixed_content)
except json.JSONDecodeError:
return extracted_tools
if isinstance(content_json, list):
tool_calls = content_json
elif isinstance(content_json, dict):
tool_calls = content_json.get("tool_calls", [])
else:
tool_calls = []
if not isinstance(tool_calls, list):
return extracted_tools
# process and extract tools from `tool_calls`
for tool_call in tool_calls:
if isinstance(tool_call, dict):
try:
if not executable:
extracted_tools.append({tool_call["name"]: tool_call["arguments"]})
else:
name, arguments = (
tool_call.get("name", ""),
tool_call.get("arguments", {}),
)
for key, value in arguments.items():
if value == "False" or value == "false":
arguments[key] = False
elif value == "True" or value == "true":
arguments[key] = True
args_str = ", ".join(
[f"{key}={repr(value)}" for key, value in arguments.items()]
)
extracted_tools.append(f"{name}({args_str})")
except Exception:
continue
return extracted_tools
def get_param_text(self, parameter_dict, prefix=""):
param_text = ""
for name, param in parameter_dict["properties"].items():
param_type = param.get("type", "")
required, default, param_format, properties, enum, items = (
"",
"",
"",
"",
"",
"",
)
if name in parameter_dict.get("required", []):
required = ", required"
required_param = parameter_dict.get("required", [])
if isinstance(required_param, bool):
required = ", required" if required_param else ""
elif isinstance(required_param, list) and name in required_param:
required = ", required"
else:
required = ", optional"
default_param = param.get("default", None)
if default_param:
default = f", default: {default_param}"
format_in = param.get("format", None)
if format_in:
param_format = f", format: {format_in}"
desc = param.get("description", "")
if "properties" in param:
arg_properties = self.get_param_text(param, prefix + " ")
properties += "with the properties:\n{}".format(arg_properties)
enum_param = param.get("enum", None)
if enum_param:
enum = "should be one of [{}]".format(", ".join(enum_param))
item_param = param.get("items", None)
if item_param:
item_type = item_param.get("type", None)
if item_type:
items += "each item should be the {} type ".format(item_type)
item_properties = item_param.get("properties", None)
if item_properties:
item_properties = self.get_param_text(item_param, prefix + " ")
items += "with the properties:\n{}".format(item_properties)
illustration = ", ".join(
[x for x in [desc, properties, enum, items] if len(x)]
)
param_text += (
prefix
+ "- {name} ({param_type}{required}{param_format}{default}): {illustration}\n".format(
name=name,
param_type=param_type,
required=required,
param_format=param_format,
default=default,
illustration=illustration,
)
)
return param_text
def fix_json_string(self, json_str):
# Remove any leading or trailing whitespace or newline characters
json_str = json_str.strip()
# Stack to keep track of brackets
stack = []
# Clean string to collect valid characters
fixed_str = ""
# Dictionary for matching brackets
matching_bracket = {")": "(", "}": "{", "]": "["}
# Dictionary for the opposite of matching_bracket
opening_bracket = {v: k for k, v in matching_bracket.items()}
for char in json_str:
if char in "{[(":
stack.append(char)
fixed_str += char
elif char in "}])":
if stack and stack[-1] == matching_bracket[char]:
stack.pop()
fixed_str += char
else:
# Ignore the unmatched closing brackets
continue
else:
fixed_str += char
# If there are unmatched opening brackets left in the stack, add corresponding closing brackets
while stack:
unmatched_opening = stack.pop()
fixed_str += opening_bracket[unmatched_opening]
# Attempt to parse the corrected string to ensure its valid JSON
return fixed_str

View file

@ -0,0 +1,10 @@
from typing import Any, Dict, List
from pydantic import BaseModel
class Message(BaseModel):
role: str
content: str
class ChatMessage(BaseModel):
messages: list[Message]
tools: List[Dict[str, Any]]

View file

@ -0,0 +1,14 @@
version: 1
disable_existing_loggers: False
formatters:
timestamped:
format: '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
handlers:
console:
class: logging.StreamHandler
level: INFO
formatter: timestamped
stream: ext://sys.stdout
root:
level: INFO
handlers: [console]

View file

@ -3,8 +3,8 @@ import sentence_transformers
from gliner import GLiNER
from transformers import AutoTokenizer, pipeline
import sqlite3
from employee_data_generator import generate_employee_data
from network_data_generator import (
from app.employee_data_generator import generate_employee_data
from app.network_data_generator import (
generate_device_data,
generate_interface_stats_data,
generate_flow_data,

View file

@ -1,17 +1,20 @@
import os
from fastapi import FastAPI, Response, HTTPException
from pydantic import BaseModel
from load_models import (
from app.load_models import (
load_ner_models,
load_transformers,
load_guard_model,
load_zero_shot_models,
)
from utils import GuardHandler, split_text_into_chunks
from app.utils import GuardHandler, split_text_into_chunks
import torch
import yaml
import string
import time
import logging
from app.arch_fc.arch_fc import chat_completion as arch_fc_chat_completion, ChatMessage
import os.path
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
@ -22,8 +25,11 @@ transformers = load_transformers()
ner_models = load_ner_models()
zero_shot_models = load_zero_shot_models()
with open("/root/arch_config.yaml", "r") as file:
config = yaml.safe_load(file)
config = {}
if os.path.exists("/root/arch_config.yaml"):
with open("/root/arch_config.yaml", "r") as file:
config = yaml.safe_load(file)
with open("guard_model_config.yaml") as f:
guard_model_config = yaml.safe_load(f)
@ -231,6 +237,12 @@ async def zeroshot(req: ZeroShotRequest, res: Response):
}
@app.post("/v1/chat/completions")
async def chat_completion(req: ChatMessage, res: Response):
result = await arch_fc_chat_completion(req, res)
return result
'''
*****
Adding new functions to test the usecases - Sampreeth

View file

@ -13,3 +13,6 @@ openvino
psutil
pandas
dateparser
openai
pandas
tf-keras