mirror of
https://github.com/katanemo/plano.git
synced 2026-06-20 15:28:07 +02:00
Use intent model from archfc to pick prompt gateway (#328)
This commit is contained in:
parent
67b8fd635e
commit
ba7279becb
151 changed files with 8642 additions and 10932 deletions
2
model_server/.vscode/launch.json
vendored
2
model_server/.vscode/launch.json
vendored
|
|
@ -9,7 +9,7 @@
|
|||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "uvicorn",
|
||||
"args": ["app.main:app","--reload", "--port", "51000"]
|
||||
"args": ["src.main:app","--reload", "--port", "51000"]
|
||||
}
|
||||
]
|
||||
}
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ WORKDIR /src
|
|||
# specify list of models that will go into the image as a comma separated list
|
||||
# following models have been tested to work with this image
|
||||
# "sentence-transformers/all-MiniLM-L6-v2,sentence-transformers/all-mpnet-base-v2,thenlper/gte-base,thenlper/gte-large,thenlper/gte-small"
|
||||
ENV MODELS="katanemo/bge-large-en-v1.5-onnx"
|
||||
ENV MODELS=""
|
||||
|
||||
COPY ./app ./app
|
||||
COPY ./app/guard_model_config.yaml .
|
||||
|
|
@ -28,4 +28,4 @@ COPY ./app/openai_params.yaml .
|
|||
# RUN python install.py && \
|
||||
# find /root/.cache/torch/sentence_transformers/ -name onnx -exec rm -rf {} +
|
||||
|
||||
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "80"]
|
||||
CMD ["uvicorn", "src.app.main:app", "--host", "0.0.0.0", "--port", "80"]
|
||||
|
|
|
|||
|
|
@ -45,7 +45,7 @@ RUN if command -v nvcc >/dev/null 2>&1; then \
|
|||
COPY . /src
|
||||
|
||||
# Specify list of models that will go into the image as a comma separated list
|
||||
ENV MODELS="katanemo/bge-large-en-v1.5-onnx"
|
||||
ENV MODELS=""
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
COPY /app /app
|
||||
|
|
|
|||
|
|
@ -1,178 +0,0 @@
|
|||
import importlib
|
||||
import sys
|
||||
import os
|
||||
import time
|
||||
import requests
|
||||
import psutil
|
||||
import tempfile
|
||||
import subprocess
|
||||
import logging
|
||||
|
||||
|
||||
def get_version():
    """Return the installed version of the ``archgw_modelserver`` package.

    Returns:
        The version string, or the literal ``"version not found"`` when the
        distribution is not installed in the current environment.
    """
    # ``import importlib`` alone does not guarantee that the ``metadata``
    # submodule has been loaded, so import it explicitly before use.
    from importlib import metadata

    try:
        return metadata.version("archgw_modelserver")
    except metadata.PackageNotFoundError:
        return "version not found"
|
||||
|
||||
|
||||
# Configure the root logger once at import time so that every message emitted
# by the CLI carries a timestamp, the logger name and the level.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)

# Logger used by every function in this module.
log = logging.getLogger("model_server.cli")
log.setLevel(logging.INFO)
# Emitted as a side effect of importing this module.
log.info(f"model server version: {get_version()}")
|
||||
|
||||
|
||||
def run_server(port=51000):
    """Run the action named by the first command-line argument.

    The action is read from ``sys.argv[1]`` and defaults to ``"start"``.
    Supported actions are ``start``, ``stop`` and ``restart``; anything else
    is logged and terminates the process with exit status 1.
    """
    action = sys.argv[1] if len(sys.argv) > 1 else "start"

    handlers = {
        "start": start_server,
        "stop": stop_server,
        "restart": restart_server,
    }
    handler = handlers.get(action)
    if handler is None:
        log.info(f"Unknown action: {action}")
        sys.exit(1)

    handler(port)
|
||||
|
||||
|
||||
def start_server(port=51000):
    """Start the Uvicorn server

    Launches ``uvicorn app.main:app`` in a new session, bound to all
    interfaces on ``port``, then waits for its ``/healthz`` endpoint to
    answer. If the health check times out the child is terminated.
    """
    log.info(
        "starting model server - loading some awesomeness, this may take some time :)"
    )

    process = subprocess.Popen(
        [
            # Run uvicorn with the interpreter that is running this CLI rather
            # than whatever "python" happens to be first on PATH.
            sys.executable or "python",
            "-m",
            "uvicorn",
            "app.main:app",
            "--host",
            "0.0.0.0",
            "--port",
            f"{port}",
        ],
        start_new_session=True,
        # Discard the child's output; the model server writes to its own
        # logger. Unread PIPEs would eventually fill up and block the child.
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
    )

    # The server binds 0.0.0.0, so it is reachable through the loopback
    # address; 0.0.0.0 itself is not a portable destination address.
    if wait_for_health_check(f"http://127.0.0.1:{port}/healthz"):
        log.info(f"Model server started with PID {process.pid}")
    else:
        log.info("model server - didn't start in time, shutting down")
        process.terminate()
        try:
            # Reap the child so it does not linger as a zombie.
            process.wait(timeout=10)
        except subprocess.TimeoutExpired:
            process.kill()
|
||||
|
||||
|
||||
def wait_for_health_check(url, timeout=300):
    """Wait for the Uvicorn server to respond to health-check requests.

    Args:
        url: Health-check URL that is polled with HTTP GET.
        timeout: Maximum number of seconds to keep polling.

    Returns:
        True as soon as the URL answers with status 200, False if ``timeout``
        seconds elapse first.
    """
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            # Bound each probe so a hung connection cannot outlive the deadline
            # indefinitely.
            response = requests.get(url, timeout=5)
            if response.status_code == 200:
                return True
        except requests.RequestException:
            # The server is not accepting requests yet; try again below.
            pass
        # Pause between probes on every path, including non-200 answers, so
        # the loop never spins at full speed.
        time.sleep(1)

    print("Timed out waiting for model server to respond.")
    return False
|
||||
|
||||
|
||||
def check_and_install_lsof():
    """Check if lsof is installed, and if not, install it using apt-get.

    Installation failures are reported on stdout and never raised.
    """
    import shutil

    # Look lsof up on PATH instead of executing it: running a missing binary
    # raises FileNotFoundError, not CalledProcessError.
    if shutil.which("lsof"):
        print("lsof is already installed.")
        return

    print("lsof not found, installing...")
    try:
        # Update package list and install lsof
        subprocess.run(["sudo", "apt-get", "update"], check=True)
        subprocess.run(["sudo", "apt-get", "install", "-y", "lsof"], check=True)
        print("lsof installed successfully.")
    except (subprocess.CalledProcessError, FileNotFoundError) as install_error:
        # FileNotFoundError covers hosts without sudo or apt-get.
        print(f"Failed to install lsof: {install_error}")
|
||||
|
||||
|
||||
def kill_process(port=51000, wait=True, timeout=10):
    """Stop the running Uvicorn server.

    Finds the processes listening on TCP ``port`` with ``lsof`` and sends
    each of them a plain ``kill``. When ``wait`` is true, each process is
    polled with ``ps`` until it is gone; after ``timeout`` seconds it is
    force-killed with ``kill -9``. Any error is printed rather than raised.

    Args:
        port: TCP port the model server listens on.
        wait: Whether to wait for each process to exit.
        timeout: Seconds to wait per process before force-killing it.
    """
    log.info("Stopping model server")
    try:
        # Step 1: ask lsof for exactly the TCP listeners on this port.
        # "-t" prints bare PIDs. Commands are passed as argument lists, so no
        # shell is involved and no unrelated line can match on the port digits.
        result = subprocess.run(
            ["lsof", "-n", "-P", "-t", f"-iTCP:{port}", "-sTCP:LISTEN"],
            capture_output=True,
            text=True,
        )

        # Step 2: collect the PIDs, de-duplicated defensively, in output order.
        process_ids = list(
            dict.fromkeys(tok for tok in result.stdout.split() if tok.isdigit())
        )
        if result.returncode != 0 or not process_ids:
            print(f"No process found listening on port {port}.")
            return

        # Step 3: kill each process using its PID
        for pid in process_ids:
            print(f"Killing model server process with PID {pid}")
            subprocess.run(["kill", pid])

            if not wait:
                continue

            # Step 4: poll until the process is gone or the timeout expires.
            start_time = time.time()
            while True:
                check_process = subprocess.run(
                    ["ps", "-p", pid], capture_output=True, text=True
                )
                if check_process.returncode != 0:
                    print(f"Process {pid} has been killed.")
                    break

                elapsed_time = time.time() - start_time
                if elapsed_time > timeout:
                    print(
                        f"Process {pid} did not terminate within {timeout} seconds."
                    )
                    print(f"Attempting to force kill process {pid}...")
                    subprocess.run(["kill", "-9", pid])  # SIGKILL
                    break

                print(
                    f"Waiting for process {pid} to be killed... ({elapsed_time:.2f} seconds)"
                )
                time.sleep(0.5)

    except Exception as e:
        # Best-effort shutdown: report the problem but never crash the CLI.
        print(f"Error occurred: {e}")
|
||||
|
||||
|
||||
def stop_server(port=51000, wait=True, timeout=10):
    """Ensure ``lsof`` is available, then kill the server listening on ``port``.

    ``wait`` and ``timeout`` are forwarded unchanged to ``kill_process``.
    """
    check_and_install_lsof()
    kill_process(port=port, wait=wait, timeout=timeout)
|
||||
|
||||
|
||||
def restart_server(port=51000):
    """Restart the Uvicorn server."""
    # Stop first (with stop_server's default wait/timeout), then start again.
    for step in (stop_server, start_server):
        step(port)
|
||||
|
|
@ -1,38 +0,0 @@
|
|||
import app.commons.globals as glb
|
||||
import app.commons.utilities as utils
|
||||
import app.loader as loader
|
||||
|
||||
from app.function_calling.model_handler import ArchFunctionHandler
|
||||
from app.prompt_guard.model_handler import ArchGuardHanlder
|
||||
|
||||
# Shared model-server logger.
logger = utils.get_model_server_logger()

# Handler that formats the tool-calling system prompt and parses tool calls
# from model output. (The "hanlder" spelling is referenced by other modules.)
arch_function_hanlder = ArchFunctionHandler()
# Candidate assistant prefixes; one is picked at random for a pre-filled retry
# when the first streamed token is not TOOL_CALL_TOKEN.
PREFILL_LIST = ["May", "Could", "Sure", "Definitely", "Certainly", "Of course", "Can"]
# When true, the function-calling request is streamed so the first token can
# be inspected.
PREFILL_ENABLED = True
# First token expected when the model decides to call a tool.
TOOL_CALL_TOKEN = "<tool_call>"
# OpenAI-compatible endpoint serving the function-calling model, and the
# client bound to it.
arch_function_endpoint = "https://api.fc.archgw.com/v1"
arch_function_client = utils.get_client(arch_function_endpoint)
# Sent as ``extra_body`` with every function-calling completion request.
arch_function_generation_params = {
    "temperature": 0.2,
    "top_p": 1.0,
    "top_k": 50,
    "max_tokens": 512,
    "stop_token_ids": [151645],
    # "top_logprobs": 10,
}

# Guard model identifier to load for each device name.
arch_guard_model_type = {
    "cpu": "katanemo/Arch-Guard-cpu",
    "cuda": "katanemo/Arch-Guard",
    "mps": "katanemo/Arch-Guard",
}

# Model definition
# NOTE: these calls load the models as a side effect of importing this module.
embedding_model = loader.get_embedding_model()
zero_shot_model = loader.get_zero_shot_model()

# Guard model matching the detected device, wrapped in its handler.
prompt_guard_dict = loader.get_prompt_guard(arch_guard_model_type[glb.DEVICE])

arch_guard_handler = ArchGuardHanlder(model_dict=prompt_guard_dict)
|
||||
# Patterns for function name and parameter parsing
|
||||
|
|
@ -1,4 +0,0 @@
|
|||
import app.commons.utilities as utils
|
||||
|
||||
|
||||
DEVICE = utils.get_device()
|
||||
|
|
@ -1,83 +0,0 @@
|
|||
import os
|
||||
import yaml
|
||||
import torch
|
||||
import string
|
||||
import logging
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
|
||||
logger_instance = None
|
||||
|
||||
|
||||
def get_device():
    """Return the preferred torch device name.

    Preference order is ``"cuda"``, then ``"mps"``, then ``"cpu"``. MPS is
    only probed when ``torch.backends`` exposes an ``mps`` attribute.
    """
    has_cuda = torch.cuda.is_available()
    has_mps = (
        torch.backends.mps.is_available() if hasattr(torch.backends, "mps") else False
    )

    if has_cuda:
        return "cuda"
    if has_mps:
        return "mps"
    return "cpu"
|
||||
|
||||
|
||||
def get_client(endpoint):
    """Create an ``OpenAI`` client that talks to ``endpoint``.

    The API key is the fixed placeholder ``"EMPTY"``.
    """
    return OpenAI(api_key="EMPTY", base_url=endpoint)
|
||||
|
||||
|
||||
def get_model_server_logger():
    """Return the process-wide model server logger, creating it on first use.

    On the first call the directory ``~/archgw_logs`` is created if needed
    and logging is configured to write to ``modelserver.log`` inside it
    (the file is overwritten). Later calls return the cached logger.

    Returns:
        The ``logging.Logger`` named ``"model_server_logger"``.

    Raises:
        RuntimeError: If the log directory or file cannot be created or
            written. The underlying ``OSError`` is attached as the cause.
    """
    global logger_instance

    if logger_instance is not None:
        # Already initialized: reuse the existing instance.
        return logger_instance

    # Log file lives outside the current directory (~/archgw_logs).
    log_dir = os.path.expanduser("~/archgw_logs")
    log_file_path = os.path.join(log_dir, "modelserver.log")

    try:
        # exist_ok makes a separate existence check unnecessary.
        os.makedirs(log_dir, exist_ok=True)

        if not os.access(log_dir, os.W_OK):
            raise PermissionError(f"No write permission for the directory: {log_dir}")

        # NOTE: basicConfig is a no-op when the root logger already has
        # handlers, in which case the file handler below is not installed.
        logging.basicConfig(
            level=logging.INFO,
            format="%(asctime)s - %(levelname)s - %(message)s",
            handlers=[
                logging.FileHandler(log_file_path, mode="w"),  # Overwrite logs in file
            ],
        )
    except OSError as err:
        # Don't fall back to console logging; surface the real cause instead
        # of reporting every failure as a permission problem.
        raise RuntimeError(
            f"Unable to set up model server logging in {log_dir}: {err}"
        ) from err

    # Initialize the logger instance after configuring handlers
    logger_instance = logging.getLogger("model_server_logger")
    return logger_instance
|
||||
|
||||
|
||||
def remove_punctuations(s):
    """Normalize ``s``: punctuation becomes spaces, whitespace runs collapse, text is lower-cased."""
    # Map every ASCII punctuation code point to a single space.
    to_space = {ord(mark): " " for mark in string.punctuation}
    words = s.translate(to_space).split()
    return " ".join(words).lower()
|
||||
|
||||
|
||||
def get_label_map(labels):
    """Map each label's normalized form (see ``remove_punctuations``) to the original label.

    When two labels normalize to the same key, the later one wins.
    """
    normalized_to_label = {}
    for label in labels:
        normalized_to_label[remove_punctuations(label)] = label
    return normalized_to_label
|
||||
|
|
@ -1,137 +0,0 @@
|
|||
import json
|
||||
import random
|
||||
|
||||
from typing import Any, Dict, List
|
||||
|
||||
|
||||
ARCH_FUNCTION_CALLING_TASK_PROMPT = """
|
||||
You are a helpful assistant.
|
||||
""".strip()
|
||||
|
||||
|
||||
ARCH_FUNCTION_CALLING_TOOL_PROMPT = """
|
||||
# Tools
|
||||
|
||||
You may call one or more functions to assist with the user query.
|
||||
|
||||
You are provided with function signatures within <tools></tools> XML tags:
|
||||
<tools>
|
||||
{tool_text}
|
||||
</tools>
|
||||
""".strip()
|
||||
|
||||
|
||||
ARCH_FUNCTION_CALLING_FORMAT_PROMPT = """
|
||||
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
|
||||
<tool_call>
|
||||
{"name": <function-name>, "arguments": <args-json-object>}
|
||||
</tool_call>
|
||||
""".strip()
|
||||
|
||||
|
||||
class ArchFunctionHandler:
    """Builds function-calling prompts and parses tool calls from model output."""

    def __init__(self) -> None:
        super().__init__()

    def _format_system(self, tools: List[Dict[str, Any]]):
        """Return the system prompt: task text, tool list and call-format text.

        Each tool is serialized as one JSON document per line and substituted
        into ``ARCH_FUNCTION_CALLING_TOOL_PROMPT``.
        """
        tool_text = "\n".join(json.dumps(tool) for tool in tools)

        return (
            ARCH_FUNCTION_CALLING_TASK_PROMPT
            + "\n\n"
            + ARCH_FUNCTION_CALLING_TOOL_PROMPT.format(tool_text=tool_text)
            + "\n\n"
            + ARCH_FUNCTION_CALLING_FORMAT_PROMPT
        )

    def _add_execution_results_prompting(
        self,
        messages: list[dict],
        execution_results: list,
    ) -> list[dict]:
        """Append the execution results to ``messages`` as one user message.

        Each result is JSON-encoded and wrapped in ``<tool_response>`` tags.
        ``messages`` is modified in place and also returned.
        """
        content = "\n".join(
            f"<tool_response>\n{json.dumps(result)}\n</tool_response>"
            for result in execution_results
        )
        messages.append({"role": "user", "content": content})

        return messages

    def _parse_tool_call(self, line: str):
        """Parse one tool-call line; return its dict, or None when unusable."""
        try:
            parsed = json.loads(line)
        except ValueError:
            # Retry once after repairing unbalanced brackets / single quotes.
            try:
                parsed = json.loads(self.fix_json_string(line))
            except ValueError:
                return None

        if not isinstance(parsed, dict):
            return None
        if "name" not in parsed or "arguments" not in parsed:
            return None
        return parsed

    def extract_tool_calls(self, content: str):
        """Extract the tool calls embedded in ``content``.

        A tool call is the first non-empty line between a ``<tool_call>`` line
        and a ``</tool_call>`` line, holding a JSON object with ``name`` and
        ``arguments``. Tag lines are matched after stripping whitespace.

        Returns:
            A list of ``{"id", "type", "function"}`` dicts. The list is empty
            when ``content`` has no tool call or when any tool call cannot be
            parsed, so callers can always treat the result as a list and fall
            back to the raw content.
        """
        tool_calls = []

        inside_tool_call = False
        for raw_line in content.split("\n"):
            line = raw_line.strip()
            if line == "<tool_call>":
                inside_tool_call = True
            elif line == "</tool_call>":
                inside_tool_call = False
            elif inside_tool_call and line:
                tool_content = self._parse_tool_call(line)
                if tool_content is None:
                    return []

                tool_calls.append(
                    {
                        "id": f"call_{random.randint(1000, 10000)}",
                        "type": "function",
                        "function": {
                            "name": tool_content["name"],
                            "arguments": tool_content["arguments"],
                        },
                    }
                )

                # Only one JSON line is consumed per <tool_call> block.
                inside_tool_call = False

        return tool_calls

    def fix_json_string(self, json_str: str):
        """Heuristically repair a JSON-like string.

        Unmatched closing brackets are dropped, missing closing brackets are
        appended, and every single quote is replaced by a double quote. The
        result is not validated; brackets and quotes inside string values are
        treated like any others.
        """
        # Remove any leading or trailing whitespace or newline characters
        json_str = json_str.strip()

        # Closing bracket -> its opening bracket, and the reverse mapping.
        matching_bracket = {")": "(", "}": "{", "]": "["}
        closing_for = {opener: closer for closer, opener in matching_bracket.items()}

        stack = []  # currently open brackets
        kept = []  # characters that make it into the output

        for char in json_str:
            if char in closing_for:
                stack.append(char)
                kept.append(char)
            elif char in matching_bracket:
                if stack and stack[-1] == matching_bracket[char]:
                    stack.pop()
                    kept.append(char)
                # else: ignore the unmatched closing bracket
            else:
                kept.append(char)

        # Close whatever is still open, innermost first.
        while stack:
            kept.append(closing_for[stack.pop()])

        return "".join(kept).replace("'", '"')
|
||||
|
|
@ -1,157 +0,0 @@
|
|||
import json
|
||||
import hashlib
|
||||
import app.commons.constants as const
|
||||
import random
|
||||
from fastapi import Response
|
||||
from pydantic import BaseModel
|
||||
from app.commons.utilities import get_model_server_logger
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
||||
logger = get_model_server_logger()
|
||||
|
||||
|
||||
# One chat message, used both in incoming history and in responses.
class Message(BaseModel):
    # Role of the author; process_messages special-cases "tool".
    role: Optional[str] = ""
    # Text of the message.
    content: Optional[str] = ""
    # Tool calls attached to the message; process_messages reads the
    # "function" entry of the first one.
    tool_calls: Optional[List[Dict[str, Any]]] = []
    # Not read anywhere in this module.
    tool_call_id: Optional[str] = ""
|
||||
|
||||
|
||||
# Request body accepted by chat_completion.
class ChatMessage(BaseModel):
    # Conversation history, converted by process_messages.
    messages: list[Message]
    # Tool definitions; each is JSON-encoded into the system prompt.
    tools: List[Dict[str, Any]]
|
||||
|
||||
|
||||
# A single completion choice in the response.
class Choice(BaseModel):
    # The generated assistant message.
    message: Message
    # chat_completion never overrides these two defaults.
    finish_reason: Optional[str] = "stop"
    index: Optional[int] = 0
|
||||
|
||||
|
||||
# Response body returned by chat_completion.
class ChatCompletionResponse(BaseModel):
    # chat_completion always supplies exactly one choice.
    choices: List[Choice]
    # Overridden by chat_completion with the upstream model id.
    model: Optional[str] = "Arch-Function"
    # Left at their defaults by chat_completion.
    created: Optional[str] = ""
    id: Optional[str] = ""
    object: Optional[str] = "chat_completion"
|
||||
|
||||
|
||||
def process_messages(history: list[Message]):
    """Convert chat history into plain ``{"role", "content"}`` dicts.

    * A message with a tool call becomes an assistant message whose content is
      the call's ``function`` entry as JSON inside ``<tool_call>`` tags.
    * A ``tool`` message becomes a user message wrapped in
      ``<tool_response>`` tags.
    * Any other message keeps its role and content.

    Raises:
        ValueError: If a message carries more than one tool call.
    """
    converted = []
    for entry in history:
        if entry.tool_calls:
            call_count = len(entry.tool_calls)
            if call_count > 1:
                error_msg = (
                    f"Only one tool call is supported, tools counts: {call_count}"
                )
                logger.error(error_msg)
                raise ValueError(error_msg)

            function_json = json.dumps(entry.tool_calls[0]["function"])
            role = "assistant"
            content = f"<tool_call>\n{function_json}\n</tool_call>"
        elif entry.role == "tool":
            role = "user"
            content = f"<tool_response>\n{entry.content}\n</tool_response>"
        else:
            role = entry.role
            content = entry.content

        converted.append({"role": role, "content": content})

    return converted
|
||||
|
||||
|
||||
def _chunk_text(chunk):
    """Return the text carried by one streamed chunk, or "" when it has none."""
    choices = getattr(chunk, "choices", None)
    if not choices:
        return ""
    delta = getattr(choices[0], "delta", None)
    # Chunks that only carry a role or a finish reason have content None.
    return getattr(delta, "content", None) or ""


async def chat_completion(req: ChatMessage, res: Response):
    """Run a function-calling completion and return it as a chat response.

    The system prompt is built from ``req.tools`` and the history from
    ``req.messages``. With ``const.PREFILL_ENABLED`` the upstream request is
    streamed: if the first non-empty token is not ``const.TOOL_CALL_TOKEN``
    the stream is closed and a second, non-streamed request is made with a
    random prefix from ``const.PREFILL_LIST`` appended as an assistant
    message for the model to continue. Otherwise the rest of the stream is
    collected and parsed for tool calls.

    Returns:
        A ``ChatCompletionResponse`` with a single choice holding either the
        parsed tool calls or the plain text.

    Raises:
        Exception: Whatever the upstream client raises on the initial
            completion request is logged and re-raised.
    """
    logger.info("starting request")

    tools_encoded = const.arch_function_hanlder._format_system(req.tools)

    messages = [{"role": "system", "content": tools_encoded}]
    messages.extend(
        {"role": message["role"], "content": message["content"]}
        for message in process_messages(req.messages)
    )

    client_model_name = const.arch_function_client.models.list().data[0].id

    logger.info(
        f"model_server => arch_function: {client_model_name}, messages: {json.dumps(messages)}"
    )

    try:
        resp = const.arch_function_client.chat.completions.create(
            messages=messages,
            model=client_model_name,
            stream=const.PREFILL_ENABLED,
            extra_body=const.arch_function_generation_params,
        )
    except Exception as e:
        logger.error(f"model_server <= arch_function: error: {e}")
        raise

    if const.PREFILL_ENABLED:
        # Read the stream up to the first non-empty token.
        first_token_content = ""
        for token in resp:
            first_token_content = _chunk_text(token).strip()
            if first_token_content:
                break

        if first_token_content != const.TOOL_CALL_TOKEN:
            # No tool call is indicated: abandon the stream and pre-fill.
            resp.close()
            logger.info("Tool call is not found! Engage pre filling")
            prefill_content = random.choice(const.PREFILL_LIST)
            messages.append({"role": "assistant", "content": prefill_content})

            # The model continues the final message instead of starting a new
            # one, so the template must not add generation-prompt tokens.
            extra_body = {
                **const.arch_function_generation_params,
                "continue_final_message": True,
                "add_generation_prompt": False,
            }
            pre_fill_resp = const.arch_function_client.chat.completions.create(
                messages=messages,
                model=client_model_name,
                stream=False,
                extra_body=extra_body,
            )
            full_response = pre_fill_resp.choices[0].message.content
        else:
            # Collect the remainder of the stream after the tool-call token.
            parts = [first_token_content]
            parts.extend(_chunk_text(token) for token in resp)
            full_response = "".join(parts)
    else:
        logger.info("Stream is disabled, not engaging pre-filling")
        full_response = resp.choices[0].message.content

    # A completion without text is reported as None upstream.
    full_response = full_response or ""

    tool_calls = const.arch_function_hanlder.extract_tool_calls(full_response)
    if not isinstance(tool_calls, list):
        # The handler may hand back the raw text when parsing fails.
        tool_calls = []

    if tool_calls:
        message = Message(content="", tool_calls=tool_calls)
    else:
        message = Message(content=full_response, tool_calls=[])
    choice = Choice(message=message)
    chat_completion_response = ChatCompletionResponse(
        choices=[choice], model=client_model_name
    )

    logger.info(
        f"model_server <= arch_function: (tools): {json.dumps([tool_call['function'] for tool_call in tool_calls])}"
    )
    logger.info(
        f"model_server <= arch_function: response body: {json.dumps(chat_completion_response.dict())}"
    )

    return chat_completion_response
|
||||
|
|
@ -1,84 +0,0 @@
|
|||
import os
|
||||
import app.commons.globals as glb
|
||||
|
||||
from transformers import AutoTokenizer, AutoModel, pipeline
|
||||
from optimum.onnxruntime import (
|
||||
ORTModelForFeatureExtraction,
|
||||
ORTModelForSequenceClassification,
|
||||
)
|
||||
import app.commons.utilities as utils
|
||||
import torch
|
||||
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
||||
from optimum.intel import OVModelForSequenceClassification
|
||||
|
||||
|
||||
logger = utils.get_model_server_logger()
|
||||
|
||||
|
||||
def get_embedding_model(
    model_name=os.getenv("MODELS", "katanemo/bge-large-en-v1.5"),
):
    """Load the embedding model and its tokenizer.

    Args:
        model_name: Model identifier. Defaults to the ``MODELS`` environment
            variable as read at import time. An empty value falls back to
            ``"katanemo/bge-large-en-v1.5"``.

    Returns:
        A dict with the keys ``"model_name"``, ``"tokenizer"`` and ``"model"``.
    """
    # os.getenv returns "" (not the default) when MODELS is set but empty.
    if not model_name:
        model_name = "katanemo/bge-large-en-v1.5"

    logger.info("Loading Embedding Model...")

    if glb.DEVICE != "cuda":
        # Off CUDA, load the ONNX export of the model.
        model = ORTModelForFeatureExtraction.from_pretrained(
            model_name, file_name="onnx/model.onnx"
        )
    else:
        model = AutoModel.from_pretrained(model_name, device_map=glb.DEVICE)

    embedding_model = {
        "model_name": model_name,
        "tokenizer": AutoTokenizer.from_pretrained(model_name, trust_remote_code=True),
        "model": model,
    }

    return embedding_model
|
||||
|
||||
|
||||
def get_zero_shot_model(
    model_name=os.getenv("ZERO_SHOT_MODELS", "katanemo/bart-large-mnli"),
):
    """Load the zero-shot model and wrap it in a classification pipeline.

    Off CUDA the ONNX export is loaded; on CUDA the model name itself is
    handed to the pipeline. The returned dict has the keys ``"model_name"``,
    ``"tokenizer"``, ``"model"`` and ``"pipeline"``.
    """
    logger.info("Loading Zero-shot Model...")

    on_cuda = glb.DEVICE == "cuda"
    if on_cuda:
        model = model_name
    else:
        model = ORTModelForSequenceClassification.from_pretrained(
            model_name, file_name="onnx/model.onnx"
        )

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    classifier = pipeline(
        "zero-shot-classification",
        model=model,
        tokenizer=tokenizer,
        device=glb.DEVICE,
    )

    return {
        "model_name": model_name,
        "tokenizer": tokenizer,
        "model": model,
        "pipeline": classifier,
    }
|
||||
|
||||
|
||||
def get_prompt_guard(model_name):
    """Load the guard model and tokenizer for the current device.

    On ``"cpu"`` the model is loaded through
    ``OVModelForSequenceClassification``, otherwise through
    ``AutoModelForSequenceClassification``. The returned dict has the keys
    ``"device"``, ``"model_name"``, ``"tokenizer"`` and ``"model"``.
    """
    logger.info("Loading Guard Model...")

    model_class = (
        OVModelForSequenceClassification
        if glb.DEVICE == "cpu"
        else AutoModelForSequenceClassification
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model = model_class.from_pretrained(
        model_name, device_map=glb.DEVICE, low_cpu_mem_usage=True
    )

    return {
        "device": glb.DEVICE,
        "model_name": model_name,
        "tokenizer": tokenizer,
        "model": model,
    }
|
||||
|
|
@ -1,261 +0,0 @@
|
|||
import os
|
||||
import time
|
||||
import torch
|
||||
import app.commons.utilities as utils
|
||||
import app.commons.globals as glb
|
||||
import app.prompt_guard.model_utils as guard_utils
|
||||
|
||||
from typing import List, Dict
|
||||
from pydantic import BaseModel
|
||||
from fastapi import FastAPI, Response, HTTPException, Request
|
||||
from app.function_calling.model_utils import ChatMessage
|
||||
|
||||
from app.commons.constants import embedding_model, zero_shot_model, arch_guard_handler
|
||||
from app.function_calling.model_utils import (
|
||||
chat_completion as arch_function_chat_completion,
|
||||
)
|
||||
from unittest.mock import patch
|
||||
|
||||
from opentelemetry import trace
|
||||
from opentelemetry.sdk.trace import TracerProvider
|
||||
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
|
||||
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
|
||||
from opentelemetry.sdk.trace.export import BatchSpanProcessor
|
||||
from opentelemetry.sdk.resources import Resource
|
||||
|
||||
# Resource attributes attached to the tracer provider below.
resource = Resource.create(
    {
        "service.name": "model-server",
    }
)

# Initialize the tracer provider
trace.set_tracer_provider(TracerProvider(resource=resource))
tracer = trace.get_tracer(__name__)


logger = utils.get_model_server_logger()

# Logged at import time, after the model imports above have completed.
logger.info(f"Ready to serve traffic. available device: {glb.DEVICE}")

app = FastAPI()

# Instrument the FastAPI application with OpenTelemetry.
FastAPIInstrumentor().instrument_app(app)

# DEFAULT_OTLP_HOST = "http://localhost:4317"
# Exporter endpoint used when the OTLP_HOST environment variable is not set.
DEFAULT_OTLP_HOST = "none"

# Configure the OTLP exporter (Jaeger, Zipkin, etc.)
otlp_exporter = OTLPSpanExporter(
    endpoint=os.getenv("OTLP_HOST", DEFAULT_OTLP_HOST)  # noqa: F821
)

# Export finished spans through the OTLP exporter.
trace.get_tracer_provider().add_span_processor(BatchSpanProcessor(otlp_exporter))
|
||||
|
||||
|
||||
# Request body of POST /embeddings.
class EmbeddingRequest(BaseModel):
    # Text to embed.
    input: str
    # Must equal the loaded embedding model's name, otherwise 400.
    model: str
|
||||
|
||||
|
||||
# Request body of POST /guard.
class GuardRequest(BaseModel):
    # Text to check.
    input: str
    # The endpoint accepts "both", "toxic" or "jailbreak".
    task: str
|
||||
|
||||
|
||||
# Request body of POST /zeroshot.
class ZeroShotRequest(BaseModel):
    # Text to classify.
    input: str
    # Candidate labels; normalized with utils.get_label_map before use.
    labels: List[str]
    # Must equal the loaded zero-shot model's name, otherwise 400.
    model: str
|
||||
|
||||
|
||||
# Request body of POST /hallucination.
class HallucinationRequest(BaseModel):
    # Text the parameter hypotheses are scored against.
    prompt: str
    # Parameter name -> value; a "messages" entry is dropped by the endpoint.
    parameters: Dict
    # Must equal the loaded zero-shot model's name, otherwise 400.
    model: str
|
||||
|
||||
|
||||
# Health probe: always answers with a constant payload.
@app.get("/healthz")
async def healthz():
    return {"status": "ok"}
|
||||
|
||||
|
||||
# Lists the single embedding model that was loaded at import time.
@app.get("/models")
async def models():
    model_entry = {"id": embedding_model["model_name"], "object": "model"}
    return {"object": "list", "data": [model_entry]}
|
||||
|
||||
|
||||
# Computes one L2-normalized embedding per input row. The vector is taken from
# position 0 of the model's first output. Token usage is always reported as
# zero, and result indexes start at 1.
@app.post("/embeddings")
async def embedding(req: EmbeddingRequest, res: Response):
    logger.info(f"Embedding req: {req}")

    if req.model != embedding_model["model_name"]:
        raise HTTPException(status_code=400, detail="unknown model: " + req.model)

    start_time = time.perf_counter()

    tokenizer = embedding_model["tokenizer"]
    encoded_input = tokenizer(
        req.input, padding=True, truncation=True, return_tensors="pt"
    ).to(glb.DEVICE)

    with torch.no_grad():
        model_output = embedding_model["model"](**encoded_input)
        first_position = model_output[0][:, 0]
        normalized = torch.nn.functional.normalize(first_position, p=2, dim=1)
        embeddings = normalized.detach().cpu().numpy()

    logger.info(f"Embedding Call Complete Time: {time.perf_counter()-start_time}")

    data = []
    for position, vector in enumerate(embeddings.tolist(), start=1):
        data.append({"object": "embedding", "embedding": vector, "index": position})

    usage = {"prompt_tokens": 0, "total_tokens": 0}

    return {"data": data, "model": req.model, "object": "list", "usage": usage}
|
||||
|
||||
|
||||
# Inputs shorter than max_num_words are scored in one call. Longer inputs are
# split into chunks and the per-chunk results are aggregated: probabilities
# are collected in a list, the verdict is true if any chunk is flagged, and
# flagged chunks are collected as sentences. Unsupported tasks are rejected
# with a 400.
@app.post("/guard")
async def guard(req: GuardRequest, res: Response, max_num_words: int = 300):
    """
    Take input as text and return the prediction of toxic and jailbreak
    """

    if req.task not in ("both", "toxic", "jailbreak"):
        raise HTTPException(status_code=400, detail=f"{req.task} is not supported!")
    arch_guard_handler.task = req.task

    # Result keys are derived from the task name, as in guard_predict.
    prob_key = f"{req.task}_prob"
    verdict_key = f"{req.task}_verdict"
    sentence_key = f"{req.task}_sentence"

    start_time = time.perf_counter()

    if len(req.input.split()) < max_num_words:
        guard_result = arch_guard_handler.guard_predict(req.input)
    else:
        # text is long, split into chunks
        chunks = guard_utils.split_text_into_chunks(req.input, max_words=max_num_words)

        guard_result = {
            "jailbreak_prob": [],
            "time": 0,
            "jailbreak_verdict": False,
            "toxic_sentence": [],
            "jailbreak_sentence": [],
        }
        # Make sure the keys of the requested task exist for every task,
        # not only for "jailbreak".
        guard_result.setdefault(prob_key, [])
        guard_result.setdefault(verdict_key, False)
        guard_result.setdefault(sentence_key, [])

        for chunk in chunks:
            chunk_result = arch_guard_handler.guard_predict(chunk)
            guard_result["time"] += chunk_result["time"]
            if chunk_result[verdict_key]:
                guard_result[verdict_key] = True
                guard_result[sentence_key].append(chunk_result[sentence_key])
            prob = chunk_result[prob_key]
            # guard_predict already returns a plain float; only unwrap values
            # that still expose item().
            guard_result[prob_key].append(prob.item() if hasattr(prob, "item") else prob)

    logger.info(f"Time taken for Guard: {time.perf_counter() - start_time}")

    return guard_result
|
||||
|
||||
|
||||
# Scores the input against the normalized labels (multi-label) and reports the
# scores keyed by the original label spelling. The first label returned by the
# classifier is reported as the predicted class.
@app.post("/zeroshot")
async def zeroshot(req: ZeroShotRequest, res: Response):
    logger.info(f"zero-shot request: {req}")

    if req.model != zero_shot_model["model_name"]:
        raise HTTPException(status_code=400, detail="unknown model: " + req.model)

    # Normalized label -> original label.
    label_map = utils.get_label_map(req.labels)
    if not label_map:
        # Without candidates there is nothing to classify against.
        raise HTTPException(status_code=400, detail="at least one label is required")

    classifier = zero_shot_model["pipeline"]

    start_time = time.perf_counter()

    predictions = classifier(
        req.input, candidate_labels=list(label_map.keys()), multi_label=True
    )

    logger.info(f"zero-shot taking {time.perf_counter() - start_time} seconds")

    predicted_class = label_map[predictions["labels"][0]]
    predicted_score = predictions["scores"][0]

    scores = {
        label_map[label]: score
        for label, score in zip(predictions["labels"], predictions["scores"])
    }

    return {
        "predicted_class": predicted_class,
        "predicted_class_score": predicted_score,
        "scores": scores,
        "model": req.model,
    }
|
||||
|
||||
|
||||
# NOTE(review): the unittest.mock patch below swaps app.loader.glb.DEVICE for
# the duration of each call; nothing in this body reads it. Confirm whether it
# is still needed.
@app.post("/hallucination")
@patch("app.loader.glb.DEVICE", "cpu")  # Mock the device to 'cpu'
async def hallucination(req: HallucinationRequest, res: Response):
    """
    Take input as text and return the prediction of hallucination for each parameter
    """
    logger.info(f"hallucination request: {req}")
    if req.model != zero_shot_model["model_name"]:
        raise HTTPException(status_code=400, detail="unknown model: " + req.model)

    start_time = time.perf_counter()
    classifier = zero_shot_model["pipeline"]

    # The "messages" entry is dropped before scoring.
    req.parameters.pop("messages", None)

    if not req.parameters:
        return {"params_scores": {}, "model": req.model}

    # Hypothesis "<name> is <value>" -> parameter name.
    hypothesis_to_param = {
        f"{name} is {value}": name for name, value in req.parameters.items()
    }

    predictions = classifier(
        req.prompt,
        candidate_labels=list(hypothesis_to_param),
        hypothesis_template="{}",
        multi_label=True,
    )

    params_scores = {}
    for hypothesis, score in zip(predictions["labels"], predictions["scores"]):
        params_scores[hypothesis_to_param[hypothesis]] = score

    logger.info(
        f"hallucination time cost: {params_scores}, taking {time.perf_counter() - start_time} seconds"
    )

    return {"params_scores": params_scores, "model": req.model}
|
||||
|
||||
|
||||
@app.post("/v1/chat/completions")
async def chat_completion(req: ChatMessage, res: Response, request: Request):
    """
    Forward a chat request to ``arch_function_chat_completion``.

    Returns whatever the delegate returns. If the delegate raises, the
    response status is set to 500 and a generic error payload is returned so
    that internal details are not exposed to the caller.
    """
    try:
        return await arch_function_chat_completion(req, res)
    except Exception as e:
        # logger.exception logs at ERROR level and keeps the traceback, which
        # a plain logger.error of str(e) would lose.
        logger.exception(f"Error in chat_completion: {e}")
        res.status_code = 500
        return {"error": "Internal server error"}
|
||||
|
|
@ -1,42 +0,0 @@
|
|||
import time
|
||||
import torch
|
||||
import app.prompt_guard.model_utils as model_utils
|
||||
|
||||
|
||||
class ArchGuardHanlder:
    """Wrap a tokenizer/model pair and report a jailbreak verdict for a text."""

    def __init__(self, model_dict, threshold=0.5):
        """
        Args:
            model_dict: mapping providing the "model", "tokenizer" and
                "device" entries used for prediction.
            threshold: probability above which the input is flagged.
        """
        self.task = "jailbreak"
        # Index into the softmax output that is reported as the task probability.
        self.positive_class = 2

        self.model = model_dict["model"]
        self.tokenizer = model_dict["tokenizer"]
        self.device = model_dict["device"]

        self.threshold = threshold

    def guard_predict(self, input_text, max_length=512):
        """
        Classify ``input_text`` and return the result as a dict.

        The dict holds ``<task>_prob``, ``<task>_verdict``, ``<task>_sentence``
        (the input text when flagged, otherwise ``None``) and ``time`` (elapsed
        seconds measured with ``time.perf_counter``).
        """
        started = time.perf_counter()

        encoded = self.tokenizer(
            input_text, truncation=True, max_length=max_length, return_tensors="pt"
        )
        encoded = encoded.to(self.device)

        with torch.no_grad():
            output = self.model(**encoded)
            # Take row 0 of the logits (presumably the single input in the batch).
            logits = output.logits.cpu().detach().numpy()[0]

        prob = model_utils.softmax(logits)[self.positive_class]
        flagged = bool(prob > self.threshold)

        return {
            f"{self.task}_prob": prob.item(),
            f"{self.task}_verdict": flagged,
            f"{self.task}_sentence": input_text if flagged else None,
            "time": time.perf_counter() - started,
        }
|
||||
|
|
@ -1,19 +0,0 @@
|
|||
import numpy as np
|
||||
|
||||
|
||||
def split_text_into_chunks(text, max_words=300):
    """
    Split ``text`` into whitespace-delimited chunks of at most ``max_words`` words.

    The tokenizer accepts at most 512 tokens, so the word count is used as a
    rough approximation of the token count (1 word ~ 1 token).

    Args:
        text: the text to split.
        max_words: maximum number of words per chunk; must be at least 1.

    Returns:
        A list of chunk strings with words re-joined by single spaces. Empty
        or whitespace-only text gives an empty list.

    Raises:
        ValueError: if ``max_words`` is smaller than 1.
    """
    if max_words < 1:
        # A negative step would silently return no chunks and drop the text;
        # a zero step would fail with an unrelated range() error.
        raise ValueError(f"max_words must be at least 1, got {max_words}")

    words = text.split()
    return [
        " ".join(words[start : start + max_words])
        for start in range(0, len(words), max_words)
    ]
|
||||
|
||||
|
||||
def softmax(x):
    """
    Return the softmax of ``x`` computed along axis 0.

    The maximum along axis 0 is subtracted before exponentiating. This leaves
    the result unchanged mathematically but keeps ``np.exp`` from overflowing
    to ``inf`` (and the quotient from becoming ``nan``) for large logits.
    """
    x = np.asarray(x)
    if x.size == 0:
        # Nothing to normalise; max() is undefined on an empty array.
        return np.exp(x)
    exps = np.exp(x - x.max(axis=0))
    return exps / exps.sum(axis=0)
|
||||
|
|
@ -1,106 +0,0 @@
|
|||
import pytest
|
||||
import httpx
|
||||
from fastapi.testclient import TestClient
|
||||
from app.main import app # Assuming your FastAPI app is in main.py
|
||||
from unittest.mock import patch
|
||||
import app.commons.globals as glb
|
||||
import logging
|
||||
|
||||
# Logger named after this test module.
logger = logging.getLogger(__name__)

# Synchronous test client bound to the FastAPI app under test.
client = TestClient(app)

# Record which device the globals module reports.
logger.info(f"Model will be loaded on device: {glb.DEVICE}")
|
||||
|
||||
|
||||
# Unit tests for the health check endpoint
|
||||
@pytest.mark.asyncio
@patch("app.loader.glb.DEVICE", glb.DEVICE)  # re-applies the value read from app.commons.globals
async def test_healthz():
    """GET /healthz answers 200 with the fixed ``{"status": "ok"}`` body."""
    reply = client.get("/healthz")
    assert reply.status_code == 200
    assert reply.json() == {"status": "ok"}
|
||||
|
||||
|
||||
# Unit test for the models endpoint
|
||||
@pytest.mark.asyncio
@patch("app.loader.glb.DEVICE", glb.DEVICE)  # re-applies the value read from app.commons.globals
async def test_models():
    """GET /models answers 200 with a non-empty ``list`` object."""
    reply = client.get("/models")
    assert reply.status_code == 200
    payload = reply.json()
    assert payload["object"] == "list"
    assert len(payload["data"]) > 0
|
||||
|
||||
|
||||
# Unit test for embeddings endpoint
|
||||
@pytest.mark.asyncio
@patch("app.loader.glb.DEVICE", glb.DEVICE)  # re-applies the value read from app.commons.globals
async def test_embedding():
    """POST /embeddings with the known embedding model returns a ``list`` object."""
    request_data = {"input": "Test embedding", "model": "katanemo/bge-large-en-v1.5"}
    response = client.post("/embeddings", json=request_data)
    # The model name is the literal set above, so only the success path can
    # run; the former 400 branch was unreachable and has been dropped.
    assert response.status_code == 200
    payload = response.json()
    assert payload["object"] == "list"
    assert "data" in payload
|
||||
|
||||
|
||||
# Unit test for the guard endpoint
|
||||
@pytest.mark.asyncio
@patch("app.loader.glb.DEVICE", glb.DEVICE)  # re-applies the value read from app.commons.globals
async def test_guard():
    """POST /guard for the jailbreak task returns a ``jailbreak_verdict`` field."""
    payload = {"input": "Test for jailbreak and toxicity", "task": "jailbreak"}
    reply = client.post("/guard", json=payload)
    assert reply.status_code == 200
    assert "jailbreak_verdict" in reply.json()
|
||||
|
||||
|
||||
# Unit test for the zero-shot endpoint
|
||||
@pytest.mark.asyncio
@patch("app.loader.glb.DEVICE", glb.DEVICE)  # re-applies the value read from app.commons.globals
async def test_zeroshot():
    """POST /zeroshot with the known zero-shot model returns a ``predicted_class``."""
    request_data = {
        "input": "Test input",
        "labels": ["label1", "label2"],
        "model": "katanemo/bart-large-mnli",
    }
    response = client.post("/zeroshot", json=request_data)
    # The model name is the literal set above, so only the success path can
    # run; the former 400 branch was unreachable and has been dropped.
    assert response.status_code == 200
    assert "predicted_class" in response.json()
|
||||
|
||||
|
||||
# Unit test for the hallucination endpoint
|
||||
@pytest.mark.asyncio
@patch("app.loader.glb.DEVICE", glb.DEVICE)  # re-applies the value read from app.commons.globals
async def test_hallucination():
    """POST /hallucination with the known zero-shot model returns ``params_scores``."""
    request_data = {
        "prompt": "Test hallucination",
        "parameters": {"param1": "value1"},
        "model": "katanemo/bart-large-mnli",
    }
    response = client.post("/hallucination", json=request_data)
    # The model name is the literal set above, so only the success path can
    # run; the former 400 branch was unreachable and has been dropped.
    assert response.status_code == 200
    assert "params_scores" in response.json()
|
||||
|
||||
|
||||
# Unit test for the chat completion endpoint
|
||||
@pytest.mark.asyncio
@patch("app.loader.glb.DEVICE", glb.DEVICE)  # re-applies the value read from app.commons.globals
async def test_chat_completion():
    """POST /v1/chat/completions through an in-process ASGI client returns ``choices``."""
    request_data = {
        "messages": [{"role": "user", "content": "Hello!"}],
        "model": "Arch-Function-1.5B",
        "tools": [],  # Assuming tools is part of the req as per the function
        "metadata": {"x-arch-state": "[]"},  # Assuming metadata is needed
    }
    # Route requests straight into the app through an explicit ASGI transport;
    # the ``app=`` shortcut on AsyncClient is deprecated and later removed.
    transport = httpx.ASGITransport(app=app)
    # Named ``async_client`` so it does not shadow the module-level ``client``.
    async with httpx.AsyncClient(
        transport=transport, base_url="http://test"
    ) as async_client:
        response = await async_client.post("/v1/chat/completions", json=request_data)
    assert response.status_code == 200
    assert "choices" in response.json()
|
||||
|
|
@ -1,794 +0,0 @@
|
|||
[{
|
||||
"case": "tool_call_halluciation",
|
||||
"tokens" : ["<tool_call>"],
|
||||
"expect": 1,
|
||||
"logprobs": [[-0.3333307206630707,
|
||||
-1.5310522317886353,
|
||||
-3.5098977088928223,
|
||||
-3.9004578590393066,
|
||||
-5.775152683258057,
|
||||
-5.814209461212158,
|
||||
-5.9574151039123535,
|
||||
-6.0094895362854,
|
||||
-6.0094895362854,
|
||||
-6.673445224761963]]
|
||||
},
|
||||
{
|
||||
"case" : "parameter_value_hallucination",
|
||||
"expect" : 0,
|
||||
"tokens" : ["<tool_call>",
|
||||
"\n",
|
||||
"{'",
|
||||
"name",
|
||||
"':",
|
||||
" '",
|
||||
"get",
|
||||
"_current",
|
||||
"_weather",
|
||||
"',",
|
||||
" '",
|
||||
"arguments",
|
||||
"':",
|
||||
" {'",
|
||||
"location",
|
||||
"':",
|
||||
" '",
|
||||
"Sea",
|
||||
",",
|
||||
" Australia",
|
||||
"',",
|
||||
" '",
|
||||
"unit",
|
||||
"':",
|
||||
" '",
|
||||
"c",
|
||||
"elsius",
|
||||
"',",
|
||||
" '",
|
||||
"days",
|
||||
"':",
|
||||
" '",
|
||||
"1",
|
||||
"'}}\n",
|
||||
"</tool_call>"],
|
||||
"logprobs": [[-0.008103232830762863,
|
||||
-5.085402488708496,
|
||||
-6.777836799621582,
|
||||
-7.558959007263184,
|
||||
-9.850253105163574,
|
||||
-10.266852378845215,
|
||||
-10.540244102478027,
|
||||
-10.722506523132324,
|
||||
-10.800618171691895,
|
||||
-10.917786598205566],
|
||||
[0.0,
|
||||
-23.25142478942871,
|
||||
-25.139137268066406,
|
||||
-26.2847843170166,
|
||||
-28.992677688598633,
|
||||
-29.070789337158203,
|
||||
-29.55248260498047,
|
||||
-29.91700553894043,
|
||||
-30.20341682434082,
|
||||
-30.307567596435547],
|
||||
[0.0,
|
||||
-21.66313934326172,
|
||||
-23.06916046142578,
|
||||
-23.32953453063965,
|
||||
-25.65988540649414,
|
||||
-25.985353469848633,
|
||||
-26.519121170043945,
|
||||
-27.07892417907715,
|
||||
-27.977216720581055,
|
||||
-28.458908081054688],
|
||||
[0.0,
|
||||
-28.094383239746094,
|
||||
-28.56305694580078,
|
||||
-29.109844207763672,
|
||||
-29.44832992553711,
|
||||
-31.79170036315918,
|
||||
-32.0,
|
||||
-32.05207443237305,
|
||||
-32.31244659423828,
|
||||
-32.364524841308594],
|
||||
[0.0,
|
||||
-30.489830017089844,
|
||||
-31.140766143798828,
|
||||
-31.81774139404297,
|
||||
-34.525634765625,
|
||||
-35.8275032043457,
|
||||
-36.504478454589844,
|
||||
-39.05614471435547,
|
||||
-40.123680114746094,
|
||||
-40.696502685546875],
|
||||
[0.0,
|
||||
-25.646865844726562,
|
||||
-26.66232681274414,
|
||||
-27.781936645507812,
|
||||
-28.979660034179688,
|
||||
-31.140764236450195,
|
||||
-31.92188835144043,
|
||||
-31.973962783813477,
|
||||
-33.04149627685547,
|
||||
-33.58828353881836],
|
||||
[0.0,
|
||||
-23.511798858642578,
|
||||
-24.136695861816406,
|
||||
-25.230268478393555,
|
||||
-25.777053833007812,
|
||||
-25.80309295654297,
|
||||
-26.45402717590332,
|
||||
-26.636289596557617,
|
||||
-26.740440368652344,
|
||||
-26.896663665771484],
|
||||
[0.0,
|
||||
-22.366153717041016,
|
||||
-24.683483123779297,
|
||||
-26.610252380371094,
|
||||
-26.610252380371094,
|
||||
-27.313264846801758,
|
||||
-27.67778778076172,
|
||||
-28.510986328125,
|
||||
-28.615135192871094,
|
||||
-29.13588523864746],
|
||||
[0.0,
|
||||
-22.52237319946289,
|
||||
-24.292919158935547,
|
||||
-24.344993591308594,
|
||||
-24.39706802368164,
|
||||
-24.73555564880371,
|
||||
-29.943042755126953,
|
||||
-29.969079971313477,
|
||||
-30.021154403686523,
|
||||
-30.0341739654541],
|
||||
[0.0,
|
||||
-30.17738151550293,
|
||||
-30.411718368530273,
|
||||
-30.88039207458496,
|
||||
-30.984540939331055,
|
||||
-31.270952224731445,
|
||||
-31.895851135253906,
|
||||
-32.46867370605469,
|
||||
-32.624900817871094,
|
||||
-33.484134674072266],
|
||||
[0.0,
|
||||
-28.146459579467773,
|
||||
-29.396255493164062,
|
||||
-30.099267959594727,
|
||||
-31.127744674682617,
|
||||
-31.179821014404297,
|
||||
-32.807159423828125,
|
||||
-33.7445068359375,
|
||||
-33.770545959472656,
|
||||
-34.069976806640625],
|
||||
[0.0,
|
||||
-26.323841094970703,
|
||||
-26.558177947998047,
|
||||
-30.515867233276367,
|
||||
-30.932466506958008,
|
||||
-31.37510108947754,
|
||||
-31.531326293945312,
|
||||
-31.70056915283203,
|
||||
-32.065093994140625,
|
||||
-32.364524841308594],
|
||||
[0.0,
|
||||
-26.922698974609375,
|
||||
-30.28152847290039,
|
||||
-31.505287170410156,
|
||||
-33.30187225341797,
|
||||
-33.73148727416992,
|
||||
-34.27827453613281,
|
||||
-34.33034896850586,
|
||||
-34.460533142089844,
|
||||
-34.720909118652344],
|
||||
[0.0,
|
||||
-21.532955169677734,
|
||||
-26.94873809814453,
|
||||
-29.109848022460938,
|
||||
-30.80228042602539,
|
||||
-31.55736541748047,
|
||||
-33.484134674072266,
|
||||
-34.681854248046875,
|
||||
-35.384864807128906,
|
||||
-35.853538513183594],
|
||||
[0.0,
|
||||
-19.502033233642578,
|
||||
-20.46541976928711,
|
||||
-24.110658645629883,
|
||||
-24.501218795776367,
|
||||
-25.256305694580078,
|
||||
-25.82912826538086,
|
||||
-25.881202697753906,
|
||||
-26.063465118408203,
|
||||
-26.063465118408203],
|
||||
[0.0,
|
||||
-24.37103271484375,
|
||||
-25.256305694580078,
|
||||
-25.933277130126953,
|
||||
-26.714401245117188,
|
||||
-28.2506103515625,
|
||||
-31.010576248168945,
|
||||
-32.07810974121094,
|
||||
-34.62977981567383,
|
||||
-35.241661071777344],
|
||||
[-1.1920922133867862e-06,
|
||||
-14.398697853088379,
|
||||
-14.424736976623535,
|
||||
-17.158666610717773,
|
||||
-17.41904067993164,
|
||||
-18.200162887573242,
|
||||
-18.434499740600586,
|
||||
-18.66883659362793,
|
||||
-19.71033477783203,
|
||||
-19.71033477783203],
|
||||
[-0.0001445904199499637,
|
||||
-8.98305892944336,
|
||||
-11.35246467590332,
|
||||
-13.1490478515625,
|
||||
-13.669795989990234,
|
||||
-14.073375701904297,
|
||||
-14.516012191772461,
|
||||
-14.555068969726562,
|
||||
-15.622602462768555,
|
||||
-15.635622024536133],
|
||||
[-0.44747352600097656,
|
||||
-1.0202960968017578,
|
||||
-8.467000961303711,
|
||||
-10.914518356323242,
|
||||
-11.25300407409668,
|
||||
-11.435266494750977,
|
||||
-12.346576690673828,
|
||||
-13.075624465942383,
|
||||
-13.12769889831543,
|
||||
-13.231849670410156],
|
||||
[-3.123767137527466,
|
||||
-1.1188862323760986,
|
||||
-1.639634370803833,
|
||||
-2.0562336444854736,
|
||||
-2.8633930683135986,
|
||||
-2.9675419330596924,
|
||||
-3.4882919788360596,
|
||||
-3.69659161567688,
|
||||
-4.217339515686035,
|
||||
-4.243376731872559],
|
||||
[-7.199982064776123e-05,
|
||||
-9.76410961151123,
|
||||
-11.144091606140137,
|
||||
-16.507802963256836,
|
||||
-17.132701873779297,
|
||||
-17.44515037536621,
|
||||
-17.9138240814209,
|
||||
-18.33042335510254,
|
||||
-18.9162654876709,
|
||||
-19.39795684814453],
|
||||
[0.0,
|
||||
-22.991050720214844,
|
||||
-23.824249267578125,
|
||||
-24.969894409179688,
|
||||
-25.46460723876953,
|
||||
-25.829130172729492,
|
||||
-26.480066299438477,
|
||||
-26.909683227539062,
|
||||
-27.33930206298828,
|
||||
-27.391376495361328],
|
||||
[-0.21928852796554565,
|
||||
-1.625309705734253,
|
||||
-9.775025367736816,
|
||||
-12.977627754211426,
|
||||
-16.388530731201172,
|
||||
-17.091541290283203,
|
||||
-19.044347763061523,
|
||||
-19.38283348083496,
|
||||
-19.460947036743164,
|
||||
-19.59113311767578],
|
||||
[0.0,
|
||||
-24.006507873535156,
|
||||
-27.443450927734375,
|
||||
-27.729862213134766,
|
||||
-28.12042236328125,
|
||||
-28.276647567749023,
|
||||
-28.927583694458008,
|
||||
-30.099267959594727,
|
||||
-31.479251861572266,
|
||||
-32.07810974121094],
|
||||
[0.0,
|
||||
-18.17412567138672,
|
||||
-18.772987365722656,
|
||||
-21.689178466796875,
|
||||
-21.92351531982422,
|
||||
-23.7200984954834,
|
||||
-23.79821014404297,
|
||||
-23.79821014404297,
|
||||
-24.032546997070312,
|
||||
-25.308382034301758],
|
||||
[-0.12947827577590942,
|
||||
-2.1083219051361084,
|
||||
-12.419143676757812,
|
||||
-15.23118782043457,
|
||||
-15.595710754394531,
|
||||
-15.830047607421875,
|
||||
-17.001731872558594,
|
||||
-17.60059356689453,
|
||||
-18.121341705322266,
|
||||
-18.251529693603516],
|
||||
[0.0,
|
||||
-19.449962615966797,
|
||||
-24.371034622192383,
|
||||
-24.917821884155273,
|
||||
-25.529701232910156,
|
||||
-25.85516929626465,
|
||||
-26.037429809570312,
|
||||
-26.115543365478516,
|
||||
-26.623271942138672,
|
||||
-26.649309158325195],
|
||||
[-0.03332124650478363,
|
||||
-3.4181859493255615,
|
||||
-15.759925842285156,
|
||||
-15.812002182006836,
|
||||
-16.593124389648438,
|
||||
-17.894996643066406,
|
||||
-18.09027671813965,
|
||||
-18.79328727722168,
|
||||
-19.144792556762695,
|
||||
-20.147233963012695],
|
||||
[0.0,
|
||||
-21.142393112182617,
|
||||
-22.157852172851562,
|
||||
-23.511798858642578,
|
||||
-24.657445907592773,
|
||||
-25.021968841552734,
|
||||
-25.5427188873291,
|
||||
-25.59479331970215,
|
||||
-25.75101661682129,
|
||||
-25.95931625366211],
|
||||
[0.0,
|
||||
-23.04312515258789,
|
||||
-24.94385528564453,
|
||||
-26.323841094970703,
|
||||
-27.54759979248047,
|
||||
-28.563060760498047,
|
||||
-29.786819458007812,
|
||||
-30.620018005371094,
|
||||
-30.69812774658203,
|
||||
-31.08869171142578],
|
||||
[0.0,
|
||||
-26.167617797851562,
|
||||
-28.771360397338867,
|
||||
-29.55248260498047,
|
||||
-30.906429290771484,
|
||||
-31.114728927612305,
|
||||
-31.414159774780273,
|
||||
-31.622459411621094,
|
||||
-31.713590621948242,
|
||||
-31.726608276367188],
|
||||
[-0.05012698099017143,
|
||||
-3.018392562866211,
|
||||
-11.740934371948242,
|
||||
-13.146955490112305,
|
||||
-13.797887802124023,
|
||||
-14.943536758422852,
|
||||
-16.037107467651367,
|
||||
-16.375595092773438,
|
||||
-16.714080810546875,
|
||||
-17.36501693725586],
|
||||
[-0.9704352021217346,
|
||||
-0.7360983490943909,
|
||||
-2.1941938400268555,
|
||||
-4.225115776062012,
|
||||
-5.0062360763549805,
|
||||
-5.2666120529174805,
|
||||
-5.839434623718262,
|
||||
-7.2714948654174805,
|
||||
-8.33902645111084,
|
||||
-8.495253562927246],
|
||||
[-0.014467108063399792,
|
||||
-4.258565902709961,
|
||||
-8.789079666137695,
|
||||
-10.429437637329102,
|
||||
-10.793962478637695,
|
||||
-11.835458755493164,
|
||||
-11.939607620239258,
|
||||
-13.31959342956543,
|
||||
-13.866378784179688,
|
||||
-15.038063049316406],
|
||||
[0.0,
|
||||
-20.08787727355957,
|
||||
-21.350692749023438,
|
||||
-21.415786743164062,
|
||||
-21.50691795349121,
|
||||
-21.50691795349121,
|
||||
-22.7176570892334,
|
||||
-24.13669776916504,
|
||||
-24.188772201538086,
|
||||
-24.34499740600586]]
|
||||
},
|
||||
{
|
||||
"case": "fail_case",
|
||||
"expect" : 0,
|
||||
"tokens" : ["<tool_call>",
|
||||
"\n",
|
||||
"{'",
|
||||
"name",
|
||||
"':",
|
||||
" '",
|
||||
"get",
|
||||
"_current",
|
||||
"_weather",
|
||||
"',",
|
||||
" '",
|
||||
"arguments",
|
||||
"':",
|
||||
" {'",
|
||||
"location",
|
||||
"':",
|
||||
" '",
|
||||
"Seattle",
|
||||
",",
|
||||
" WA",
|
||||
"',",
|
||||
" '",
|
||||
"unit",
|
||||
"':",
|
||||
" '",
|
||||
"c",
|
||||
"elsius",
|
||||
"',",
|
||||
" '",
|
||||
"days",
|
||||
"':",
|
||||
" '",
|
||||
"7",
|
||||
"'}}\n",
|
||||
"</tool_call>"],
|
||||
"logprobs":[[-0.00013815402053296566,
|
||||
-9.113236427307129,
|
||||
-10.571331977844238,
|
||||
-14.099404335021973,
|
||||
-14.28166675567627,
|
||||
-15.583537101745605,
|
||||
-15.81787395477295,
|
||||
-16.143341064453125,
|
||||
-16.143341064453125,
|
||||
-16.260509490966797],
|
||||
[0.0,
|
||||
-26.896663665771484,
|
||||
-27.32628059387207,
|
||||
-27.41741180419922,
|
||||
-32.07810974121094,
|
||||
-32.07810974121094,
|
||||
-32.28641128540039,
|
||||
-32.29943084716797,
|
||||
-32.44263458251953,
|
||||
-32.520748138427734],
|
||||
[0.0,
|
||||
-22.444263458251953,
|
||||
-24.527257919311523,
|
||||
-27.15703773498535,
|
||||
-28.016273498535156,
|
||||
-28.2506103515625,
|
||||
-28.693246841430664,
|
||||
-29.070789337158203,
|
||||
-29.565500259399414,
|
||||
-29.812854766845703],
|
||||
[0.0,
|
||||
-27.860050201416016,
|
||||
-28.641170501708984,
|
||||
-29.448333740234375,
|
||||
-30.932466506958008,
|
||||
-31.63547706604004,
|
||||
-32.33848571777344,
|
||||
-32.85923767089844,
|
||||
-33.17168426513672,
|
||||
-33.45809555053711],
|
||||
[0.0,
|
||||
-31.81774139404297,
|
||||
-31.895854949951172,
|
||||
-32.05207824707031,
|
||||
-35.43694305419922,
|
||||
-36.3482551574707,
|
||||
-38.61351013183594,
|
||||
-39.26444625854492,
|
||||
-40.61839294433594,
|
||||
-41.71196365356445],
|
||||
[0.0,
|
||||
-27.33930206298828,
|
||||
-27.834014892578125,
|
||||
-28.849472045898438,
|
||||
-30.567943572998047,
|
||||
-32.98942565917969,
|
||||
-33.067535400390625,
|
||||
-33.067535400390625,
|
||||
-35.67127990722656,
|
||||
-35.69731903076172],
|
||||
[0.0,
|
||||
-25.33441925048828,
|
||||
-26.063465118408203,
|
||||
-26.219690322875977,
|
||||
-26.2457275390625,
|
||||
-26.53213882446289,
|
||||
-27.365337371826172,
|
||||
-28.354759216308594,
|
||||
-28.667207717895508,
|
||||
-28.74532127380371],
|
||||
[0.0,
|
||||
-24.423107147216797,
|
||||
-24.579330444335938,
|
||||
-26.81855010986328,
|
||||
-28.12042236328125,
|
||||
-28.32872200012207,
|
||||
-28.61513328552246,
|
||||
-29.16191864013672,
|
||||
-29.187957763671875,
|
||||
-29.240032196044922],
|
||||
[0.0,
|
||||
-22.027664184570312,
|
||||
-23.850284576416016,
|
||||
-23.980472564697266,
|
||||
-24.292922973632812,
|
||||
-24.787633895874023,
|
||||
-29.279088973999023,
|
||||
-29.55248260498047,
|
||||
-29.903987884521484,
|
||||
-30.190399169921875],
|
||||
[0.0,
|
||||
-31.609439849853516,
|
||||
-31.817739486694336,
|
||||
-32.54678726196289,
|
||||
-32.676971435546875,
|
||||
-32.781124114990234,
|
||||
-32.98942565917969,
|
||||
-33.106590270996094,
|
||||
-33.57526397705078,
|
||||
-34.369407653808594],
|
||||
[0.0,
|
||||
-29.34418296813965,
|
||||
-29.63059425354004,
|
||||
-30.021156311035156,
|
||||
-30.984540939331055,
|
||||
-33.21073913574219,
|
||||
-34.30431365966797,
|
||||
-34.56468963623047,
|
||||
-34.70789337158203,
|
||||
-34.79902648925781],
|
||||
[0.0,
|
||||
-25.438566207885742,
|
||||
-25.69894027709961,
|
||||
-30.190397262573242,
|
||||
-30.802276611328125,
|
||||
-31.58340072631836,
|
||||
-31.609437942504883,
|
||||
-31.64849281311035,
|
||||
-31.973960876464844,
|
||||
-32.29943084716797],
|
||||
[0.0,
|
||||
-27.157039642333984,
|
||||
-32.104148864746094,
|
||||
-32.33848571777344,
|
||||
-34.04393768310547,
|
||||
-34.12205505371094,
|
||||
-34.40846252441406,
|
||||
-34.42148208618164,
|
||||
-34.772987365722656,
|
||||
-34.87713623046875],
|
||||
[0.0,
|
||||
-24.813671112060547,
|
||||
-26.974777221679688,
|
||||
-31.010578155517578,
|
||||
-31.08869171142578,
|
||||
-32.1822624206543,
|
||||
-35.33279037475586,
|
||||
-35.489013671875,
|
||||
-36.999183654785156,
|
||||
-37.88446044921875],
|
||||
[0.0,
|
||||
-20.46541976928711,
|
||||
-20.647682189941406,
|
||||
-23.069164276123047,
|
||||
-24.136699676513672,
|
||||
-25.438570022583008,
|
||||
-25.646869659423828,
|
||||
-26.193655014038086,
|
||||
-26.297805786132812,
|
||||
-26.506103515625],
|
||||
[0.0,
|
||||
-27.18307113647461,
|
||||
-28.30268096923828,
|
||||
-28.56305694580078,
|
||||
-29.526439666748047,
|
||||
-32.416595458984375,
|
||||
-35.202598571777344,
|
||||
-36.426361083984375,
|
||||
-39.31651306152344,
|
||||
-39.38160705566406],
|
||||
[0.0,
|
||||
-18.7469482421875,
|
||||
-20.100894927978516,
|
||||
-21.402767181396484,
|
||||
-21.428804397583008,
|
||||
-22.20992660522461,
|
||||
-22.34011459350586,
|
||||
-22.730674743652344,
|
||||
-23.069162368774414,
|
||||
-23.980472564697266],
|
||||
[-3.576278118089249e-07,
|
||||
-15.2579345703125,
|
||||
-16.481693267822266,
|
||||
-17.991863250732422,
|
||||
-19.215621948242188,
|
||||
-20.25712013244629,
|
||||
-21.350692749023438,
|
||||
-22.314077377319336,
|
||||
-22.496337890625,
|
||||
-22.938974380493164],
|
||||
[-0.08506780862808228,
|
||||
-2.506549835205078,
|
||||
-14.848289489746094,
|
||||
-15.473188400268555,
|
||||
-16.33242416381836,
|
||||
-16.358461380004883,
|
||||
-16.566761016845703,
|
||||
-17.03543472290039,
|
||||
-17.686370849609375,
|
||||
-17.816556930541992],
|
||||
[-0.0194891095161438,
|
||||
-4.445854187011719,
|
||||
-5.591499328613281,
|
||||
-5.956024169921875,
|
||||
-6.685070037841797,
|
||||
-13.142353057861328,
|
||||
-13.558952331542969,
|
||||
-15.173273086547852,
|
||||
-15.303461074829102,
|
||||
-15.85024642944336],
|
||||
[-0.0005990855861455202,
|
||||
-7.4212646484375,
|
||||
-15.675132751464844,
|
||||
-15.72720718383789,
|
||||
-16.76870346069336,
|
||||
-16.76870346069336,
|
||||
-17.706050872802734,
|
||||
-18.669435501098633,
|
||||
-19.398483276367188,
|
||||
-19.658857345581055],
|
||||
[0.0,
|
||||
-24.110658645629883,
|
||||
-25.829130172729492,
|
||||
-26.011390686035156,
|
||||
-26.011390686035156,
|
||||
-26.532140731811523,
|
||||
-26.58421516418457,
|
||||
-27.651750564575195,
|
||||
-27.75589942932129,
|
||||
-28.055330276489258],
|
||||
[-1.1408883333206177,
|
||||
-0.38580334186553955,
|
||||
-7.494022369384766,
|
||||
-12.519245147705078,
|
||||
-14.576202392578125,
|
||||
-16.034297943115234,
|
||||
-16.945608139038086,
|
||||
-17.908992767333984,
|
||||
-18.664077758789062,
|
||||
-19.34105110168457],
|
||||
[0.0,
|
||||
-26.688365936279297,
|
||||
-29.83889389038086,
|
||||
-30.177383422851562,
|
||||
-30.64605712890625,
|
||||
-31.244916915893555,
|
||||
-31.270954132080078,
|
||||
-32.83319854736328,
|
||||
-34.655818939208984,
|
||||
-34.89015579223633],
|
||||
[0.0,
|
||||
-18.929210662841797,
|
||||
-19.16354751586914,
|
||||
-23.589908599853516,
|
||||
-24.683481216430664,
|
||||
-24.995929718017578,
|
||||
-25.516677856445312,
|
||||
-25.542715072631836,
|
||||
-25.77705192565918,
|
||||
-26.063465118408203],
|
||||
[-0.2519786059856415,
|
||||
-1.5017764568328857,
|
||||
-12.437495231628418,
|
||||
-15.457839012145996,
|
||||
-15.744250297546387,
|
||||
-16.837820053100586,
|
||||
-17.41064453125,
|
||||
-17.56686782836914,
|
||||
-17.61894416809082,
|
||||
-18.035541534423828],
|
||||
[0.0,
|
||||
-20.517494201660156,
|
||||
-24.683483123779297,
|
||||
-25.67290496826172,
|
||||
-26.58421516418457,
|
||||
-27.651750564575195,
|
||||
-27.781936645507812,
|
||||
-27.912124633789062,
|
||||
-28.09438705444336,
|
||||
-28.445892333984375],
|
||||
[-3.40932747349143e-05,
|
||||
-10.284820556640625,
|
||||
-18.252273559570312,
|
||||
-20.17904281616211,
|
||||
-21.663175582885742,
|
||||
-22.027700424194336,
|
||||
-22.288074493408203,
|
||||
-22.704673767089844,
|
||||
-23.12127113342285,
|
||||
-23.277496337890625],
|
||||
[0.0,
|
||||
-22.60049057006836,
|
||||
-25.46460723876953,
|
||||
-25.829130172729492,
|
||||
-26.063467025756836,
|
||||
-27.287227630615234,
|
||||
-27.391376495361328,
|
||||
-27.4694881439209,
|
||||
-27.67778778076172,
|
||||
-28.055330276489258],
|
||||
[0.0,
|
||||
-23.902362823486328,
|
||||
-28.823436737060547,
|
||||
-29.240036010742188,
|
||||
-29.31814956665039,
|
||||
-29.917007446289062,
|
||||
-30.021160125732422,
|
||||
-31.21887969970703,
|
||||
-32.416603088378906,
|
||||
-32.416603088378906],
|
||||
[0.0,
|
||||
-28.641170501708984,
|
||||
-31.947925567626953,
|
||||
-32.59886169433594,
|
||||
-33.848655700683594,
|
||||
-34.109031677246094,
|
||||
-34.73393249511719,
|
||||
-35.02033996582031,
|
||||
-35.02033996582031,
|
||||
-36.074859619140625],
|
||||
[-0.013183215633034706,
|
||||
-4.335395336151123,
|
||||
-19.619365692138672,
|
||||
-20.035964965820312,
|
||||
-20.244266510009766,
|
||||
-21.311800003051758,
|
||||
-21.441987991333008,
|
||||
-22.561595916748047,
|
||||
-23.108383178710938,
|
||||
-23.264606475830078],
|
||||
[-8.344646857949556e-07,
|
||||
-14.190400123596191,
|
||||
-15.9088716506958,
|
||||
-18.17412567138672,
|
||||
-18.46053695678711,
|
||||
-18.46053695678711,
|
||||
-18.512611389160156,
|
||||
-18.90317153930664,
|
||||
-19.059398651123047,
|
||||
-19.085433959960938],
|
||||
[0.0,
|
||||
-17.70545196533203,
|
||||
-18.903175354003906,
|
||||
-20.829944610595703,
|
||||
-22.574451446533203,
|
||||
-22.860862731933594,
|
||||
-23.069162368774414,
|
||||
-23.32953643798828,
|
||||
-23.694061279296875,
|
||||
-24.188772201538086],
|
||||
[0.0,
|
||||
-20.022781372070312,
|
||||
-21.038240432739258,
|
||||
-21.220502853393555,
|
||||
-22.496337890625,
|
||||
-22.769729614257812,
|
||||
-23.589908599853516,
|
||||
-23.65500259399414,
|
||||
-23.94141387939453,
|
||||
-24.266881942749023]]
|
||||
}
|
||||
]
|
||||
|
|
@ -1,55 +0,0 @@
|
|||
import unittest
|
||||
from unittest.mock import patch, MagicMock
|
||||
import subprocess
|
||||
import time
|
||||
from app.cli import kill_process
|
||||
|
||||
|
||||
class TestStopServer(unittest.TestCase):
    """Exercise ``kill_process`` with ``subprocess.run`` and ``print`` replaced by mocks."""

    @patch("subprocess.run")
    def test_stop_server_no_process(self, mock_run):
        """A non-zero return code is reported as "no process listening"."""
        # Every subprocess.run call reports failure, i.e. nothing on the port.
        mock_run.return_value.returncode = 1
        with patch("builtins.print") as mock_print:
            kill_process(port=51000)
            mock_print.assert_called_with("No process found listening on port 51000.")

    @patch("subprocess.run")
    def test_stop_server_process_killed(self, mock_run):
        """One listed PID is killed and the kill is confirmed."""
        lsof_result = MagicMock(returncode=0, stdout="uvicorn 1234 user LISTEN\n")
        kill_result = MagicMock(returncode=0)  # for killing the process
        gone_result = MagicMock(returncode=1)  # for checking the process after it is killed
        mock_run.side_effect = [lsof_result, kill_result, gone_result]

        with patch("builtins.print") as mock_print:
            kill_process(port=51000, wait=True, timeout=5)
            for message in (
                "Killing model server process with PID 1234",
                "Process 1234 has been killed.",
            ):
                mock_print.assert_any_call(message)

    @patch("subprocess.run")
    def test_stop_server_multiple_pids(self, mock_run):
        """Two listed PIDs (1234 and 5678) are each killed and confirmed."""
        lsof_result = MagicMock(
            returncode=0,
            stdout="uvicorn 1234 user LISTEN\nuvicorn 5678 user LISTEN\n",
        )
        mock_run.side_effect = [
            lsof_result,  # lsof output
            MagicMock(returncode=0),  # first kill command for PID 1234
            MagicMock(returncode=1),  # PID 1234 is successfully terminated
            MagicMock(returncode=0),  # second kill command for PID 5678
            MagicMock(returncode=1),  # PID 5678 is successfully terminated
        ]

        with patch("builtins.print") as mock_print:
            kill_process(port=51000, wait=True, timeout=5)

            # Assert that the function tried to kill both PIDs
            for message in (
                "Killing model server process with PID 1234",
                "Process 1234 has been killed.",
                "Killing model server process with PID 5678",
                "Process 5678 has been killed.",
            ):
                mock_print.assert_any_call(message)


if __name__ == "__main__":
    unittest.main()
|
||||
|
|
@ -1,90 +0,0 @@
|
|||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import app.commons.constants as const
|
||||
from fastapi import Response
|
||||
from app.function_calling.model_utils import (
|
||||
process_messages,
|
||||
chat_completion,
|
||||
Message,
|
||||
ChatMessage,
|
||||
Choice,
|
||||
ChatCompletionResponse,
|
||||
)
|
||||
|
||||
|
||||
def sample_messages():
    """Return a user message, an assistant tool call and a tool response, in that order."""
    # Every field is set explicitly, using empty values where nothing applies.
    user_turn = Message(role="user", content="Hello!", tool_calls=[], tool_call_id="")
    assistant_turn = Message(
        role="assistant",
        content="",
        tool_calls=[{"function": {"name": "sample_tool"}}],
        tool_call_id="sample_id",
    )
    tool_turn = Message(
        role="tool", content="Response from tool", tool_calls=[], tool_call_id=""
    )
    return [user_turn, assistant_turn, tool_turn]
|
||||
|
||||
|
||||
def sample_request(sample_messages):
    """Build a ``ChatMessage`` from the given messages plus one sample tool."""
    tool_specs = [{"name": "sample_tool", "description": "A sample tool"}]
    return ChatMessage(messages=sample_messages, tools=tool_specs)
|
||||
|
||||
|
||||
@patch("app.commons.constants.arch_function_hanlder")
def test_process_messages(mock_hanlder):
    """``process_messages`` rewrites tool calls and tool responses into tagged text."""
    expected = [
        {"role": "user", "content": "Hello!"},
        {
            "role": "assistant",
            "content": '<tool_call>\n{"name": "sample_tool"}\n</tool_call>',
        },
        {
            "role": "user",
            "content": "<tool_response>\nResponse from tool\n</tool_response>",
        },
    ]

    processed = process_messages(sample_messages())

    assert len(processed) == 3
    assert processed[0] == expected[0]
    assert processed[1] == expected[1]
    assert processed[2] == expected[2]
|
||||
|
||||
|
||||
@patch("app.commons.constants.arch_function_client")
@patch("app.commons.constants.arch_function_hanlder")
@pytest.mark.asyncio
async def test_chat_completion(mock_hanlder, mock_client):
    """``chat_completion`` streams first, then issues a non-streamed pre-filled call."""
    # Mock the model list return for client
    available_models = MagicMock(data=[MagicMock(id="sample_model")])
    mock_client.models.list.return_value = available_models

    # Simulate stream response as list of tokens
    streamed_chunks = [
        MagicMock(choices=[MagicMock(delta=MagicMock(content="Hi there!"))]),
        MagicMock(choices=[MagicMock(delta=MagicMock(content=""))]),  # end of stream
    ]
    mock_response = AsyncMock()
    mock_response.__aiter__.return_value = streamed_chunks
    mock_client.chat.completions.create.return_value = mock_response

    # Mock the tool formatter
    mock_hanlder._format_system.return_value = "<formatted_tools>"

    request = sample_request(sample_messages())
    chat_response = await chat_completion(request, Response())

    assert isinstance(chat_response, ChatCompletionResponse)
    assert chat_response.choices[0].message.content is not None

    create_calls = mock_client.chat.completions.create.call_args_list

    # Keyword arguments of the first call to 'create'.
    first_call_args = create_calls[0].kwargs
    assert first_call_args["stream"] == True
    assert "model" in first_call_args
    assert first_call_args["messages"][0]["content"] == "<formatted_tools>"

    # Check that the arguments for the second call to 'create' include the pre-fill completion
    second_call_args = create_calls[1].kwargs
    assert second_call_args["stream"] == False
    assert "model" in second_call_args
    assert second_call_args["messages"][-1]["content"] in const.PREFILL_LIST
|
||||
|
|
@ -1,148 +0,0 @@
|
|||
import json
|
||||
from app.function_calling.hallucination_handler import HallucinationStateHandler
|
||||
import pytest
|
||||
import os
|
||||
|
||||
# Get the directory of the current file
|
||||
current_dir = os.path.dirname(__file__)
|
||||
|
||||
# Construct the full path to the JSON file
|
||||
json_file_path = os.path.join(current_dir, "test_cases.json")
|
||||
|
||||
with open(json_file_path) as f:
|
||||
test_cases = json.load(f)
|
||||
|
||||
# OpenAI-style tool definition shared by the hallucination tests below.
get_weather_api = {
    "type": "function",
    "function": {
        "name": "get_current_weather",
        "description": "Get current weather at a location.",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {
                    "type": "str",
                    "description": "The location to get the weather for",
                    "format": "City, State",
                },
                "unit": {
                    "type": "str",
                    "description": "The unit to return the weather in.",
                    "enum": ["celsius", "fahrenheit"],
                    "default": "celsius",
                },
                "days": {
                    "type": "str",
                    "description": "the number of days for the request.",
                },
            },
            "required": ["location", "days"],
        },
    },
}

# The handler is given a list of function descriptions. The "function" entry
# above is a single dict, so it is wrapped directly instead of being
# type-checked with `type(...) != list`, which was always true here.
function_description = [get_weather_api["function"]]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("case", test_cases)
def test_hallucination(case):
    """Replay a recorded token/logprob sequence and compare the hallucination flag.

    Each case supplies parallel "tokens" and "logprobs" lists plus the
    expected final flag under "expect".
    """
    state = HallucinationStateHandler(
        response_iterator=None, function=function_description
    )

    for token, logprob in zip(case["tokens"], case["logprobs"]):
        # The closing tag is never fed to the handler.
        if token == "</tool_call>":
            continue
        state.append_and_check_token_hallucination(token, logprob)
        # Stop replaying as soon as the handler flags a hallucination.
        if state.hallucination:
            break

    assert state.hallucination == case["expect"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("is_hallucinate_sample", [True, False])
def test_hallucination_prompt(is_hallucinate_sample):
    """Stream a completion from the Arch-Function endpoint and check the hallucination flag.

    The non-hallucination sample states the number of days ("7 days"); the
    hallucination sample omits it ("in days"), while the tool declares "days"
    as a required parameter. This test talks to https://api.fc.archgw.com/v1.
    """
    from openai import OpenAI

    task_prompt = "You are a helpful assistant."

    tool_prompt = "\n".join(
        [
            "# Tools",
            "",
            "You may call one or more functions to assist with the user query.",
            "",
            "You are provided with function signatures within <tools></tools> XML tags:",
            "<tools>",
            "{tool_text}",
            "</tools>",
        ]
    )

    # Not passed through str.format, so the literal braces are kept as-is.
    format_prompt_text = "\n".join(
        [
            "For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:",
            "<tool_call>",
            '{"name": <function-name>, "arguments": <args-json-object>}',
            "</tool_call>",
        ]
    )

    # One JSON document per tool, newline separated.
    openai_format_tools = [get_weather_api]
    tool_text = "\n".join(json.dumps(tool) for tool in openai_format_tools)

    system_prompt = (
        task_prompt
        + "\n\n"
        + tool_prompt.format(tool_text=tool_text)
        + "\n\n"
        + format_prompt_text
        + "\n"
    )

    client = OpenAI(base_url="https://api.fc.archgw.com/v1", api_key="EMPTY")

    # The endpoint's first listed model must be the one requested below.
    model = client.models.list().data[0].id
    assert model == "Arch-Function"

    if is_hallucinate_sample:
        user_question = "How is the weather in Seattle in days?"
    else:
        user_question = "How is the weather in Seattle in 7 days?"

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_question},
    ]

    extra_body = {
        "temperature": 0.6,
        "top_p": 1.0,
        "top_k": 50,
        "logprobs": True,
        "top_logprobs": 10,
    }

    resp = client.chat.completions.create(
        model="Arch-Function", messages=messages, extra_body=extra_body, stream=True
    )

    hallu = HallucinationStateHandler(
        response_iterator=resp, function=function_description
    )

    # Drain the stream through the handler.
    # NOTE(review): `len(...) >= 0` can never fail; the loop only drives the iterator.
    for _token in hallu:
        assert len(hallu.tokens) >= 0

    assert hallu.hallucination == is_hallucinate_sample
|
||||
|
|
@ -1,102 +0,0 @@
|
|||
import os
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
import app.commons.globals as glb
|
||||
from app.loader import get_embedding_model, get_zero_shot_model, get_prompt_guard
|
||||
|
||||
# Mock constants
|
||||
glb.DEVICE = "cpu" # Adjust as needed for your test case
|
||||
arch_guard_model_type = {
|
||||
"cpu": "katanemo/Arch-Guard-cpu",
|
||||
"cuda": "katanemo/Arch-Guard",
|
||||
"mps": "katanemo/Arch-Guard",
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
def mock_env(monkeypatch):
    """Set the model-selection environment variables for a single test.

    `monkeypatch` restores the previous values when the test finishes, so the
    settings do not leak into other tests the way direct `os.environ` writes do.
    """
    monkeypatch.setenv("MODELS", "katanemo/bge-large-en-v1.5")
    monkeypatch.setenv("ZERO_SHOT_MODELS", "katanemo/bart-large-mnli")
|
||||
|
||||
|
||||
# Test for get_embedding_model function
|
||||
@patch("app.loader.ORTModelForFeatureExtraction.from_pretrained")
@patch("app.loader.AutoModel.from_pretrained")
@patch("app.loader.AutoTokenizer.from_pretrained")
def test_get_embedding_model(mock_tokenizer, mock_automodel, mock_ort_model, mock_env):
    """Check which loader get_embedding_model calls for the configured device."""
    expected_name = "katanemo/bge-large-en-v1.5"

    for patched_loader in (mock_automodel, mock_ort_model, mock_tokenizer):
        patched_loader.return_value = MagicMock()

    embedding_model = get_embedding_model()

    assert embedding_model["model_name"] == expected_name
    mock_tokenizer.assert_called_once_with(expected_name, trust_remote_code=True)

    # "cuda" is expected to use AutoModel; every other device the ONNX model.
    if glb.DEVICE == "cuda":
        mock_automodel.assert_called_once_with(expected_name, device_map=glb.DEVICE)
    else:
        mock_ort_model.assert_called_once_with(
            expected_name, file_name="onnx/model.onnx"
        )
|
||||
|
||||
|
||||
# Test for get_zero_shot_model function
|
||||
@patch("app.loader.ORTModelForSequenceClassification.from_pretrained")
@patch("app.loader.pipeline")
@patch("app.loader.AutoTokenizer.from_pretrained")
def test_get_zero_shot_model(mock_tokenizer, mock_pipeline, mock_ort_model, mock_env):
    """Check which loader get_zero_shot_model calls for the configured device."""
    mock_pipeline.return_value = MagicMock()
    mock_ort_model.return_value = MagicMock()
    mock_tokenizer.return_value = MagicMock()

    zero_shot_model = get_zero_shot_model()

    assert zero_shot_model["model_name"] == "katanemo/bart-large-mnli"
    mock_tokenizer.assert_called_once_with("katanemo/bart-large-mnli")
    if glb.DEVICE != "cuda":
        mock_ort_model.assert_called_once_with(
            "katanemo/bart-large-mnli", file_name="onnx/model.onnx"
        )
    else:
        # `assert mock.called_once()` only calls an auto-created mock attribute
        # and is always truthy; assert_called_once() performs the real check.
        mock_pipeline.assert_called_once()
|
||||
|
||||
|
||||
# Test for get_prompt_guard function
|
||||
@patch("app.loader.AutoTokenizer.from_pretrained")
@patch("app.loader.OVModelForSequenceClassification.from_pretrained")
@patch("app.loader.AutoModelForSequenceClassification.from_pretrained")
def test_get_prompt_guard(mock_auto_model, mock_ov_model, mock_tokenizer):
    """Check which model loader get_prompt_guard calls for the configured device."""
    guard_name = arch_guard_model_type[glb.DEVICE]

    # Only the loader expected for this device gets an explicit return value.
    if glb.DEVICE == "cpu":
        mock_ov_model.return_value = MagicMock()
    else:
        mock_auto_model.return_value = MagicMock()
    mock_tokenizer.return_value = MagicMock()

    prompt_guard = get_prompt_guard(guard_name)

    assert prompt_guard["model_name"] == arch_guard_model_type[glb.DEVICE]
    mock_tokenizer.assert_called_once_with(
        arch_guard_model_type[glb.DEVICE], trust_remote_code=True
    )

    # "cpu" is expected to use the OpenVINO loader; other devices the regular one.
    expected_loader = mock_ov_model if glb.DEVICE == "cpu" else mock_auto_model
    expected_loader.assert_called_once_with(
        arch_guard_model_type[glb.DEVICE],
        device_map=glb.DEVICE,
        low_cpu_mem_usage=True,
    )
|
||||
|
|
@ -1,102 +0,0 @@
|
|||
import os
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
import app.commons.globals as glb
|
||||
from app.loader import get_embedding_model, get_zero_shot_model, get_prompt_guard
|
||||
|
||||
# Mock constants
|
||||
glb.DEVICE = "cuda" # Adjust as needed for your test case
|
||||
arch_guard_model_type = {
|
||||
"cpu": "katanemo/Arch-Guard-cpu",
|
||||
"cuda": "katanemo/Arch-Guard",
|
||||
"mps": "katanemo/Arch-Guard",
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
def mock_env(monkeypatch):
    """Set the model-selection environment variables for a single test.

    `monkeypatch` restores the previous values when the test finishes, so the
    settings do not leak into other tests the way direct `os.environ` writes do.
    """
    monkeypatch.setenv("MODELS", "katanemo/bge-large-en-v1.5")
    monkeypatch.setenv("ZERO_SHOT_MODELS", "katanemo/bart-large-mnli")
|
||||
|
||||
|
||||
# Test for get_embedding_model function
|
||||
@patch("app.loader.ORTModelForFeatureExtraction.from_pretrained")
@patch("app.loader.AutoModel.from_pretrained")
@patch("app.loader.AutoTokenizer.from_pretrained")
def test_get_embedding_model(mock_tokenizer, mock_automodel, mock_ort_model, mock_env):
    """Check which loader get_embedding_model calls for the configured device."""
    expected_name = "katanemo/bge-large-en-v1.5"

    for patched_loader in (mock_automodel, mock_ort_model, mock_tokenizer):
        patched_loader.return_value = MagicMock()

    embedding_model = get_embedding_model()

    assert embedding_model["model_name"] == expected_name
    mock_tokenizer.assert_called_once_with(expected_name, trust_remote_code=True)

    # "cuda" is expected to use AutoModel; every other device the ONNX model.
    if glb.DEVICE == "cuda":
        mock_automodel.assert_called_once_with(expected_name, device_map=glb.DEVICE)
    else:
        mock_ort_model.assert_called_once_with(
            expected_name, file_name="onnx/model.onnx"
        )
|
||||
|
||||
|
||||
# Test for get_zero_shot_model function
|
||||
@patch("app.loader.ORTModelForSequenceClassification.from_pretrained")
@patch("app.loader.pipeline")
@patch("app.loader.AutoTokenizer.from_pretrained")
def test_get_zero_shot_model(mock_tokenizer, mock_pipeline, mock_ort_model, mock_env):
    """Check which loader get_zero_shot_model calls for the configured device."""
    mock_pipeline.return_value = MagicMock()
    mock_ort_model.return_value = MagicMock()
    mock_tokenizer.return_value = MagicMock()

    zero_shot_model = get_zero_shot_model()

    assert zero_shot_model["model_name"] == "katanemo/bart-large-mnli"
    mock_tokenizer.assert_called_once_with("katanemo/bart-large-mnli")
    if glb.DEVICE != "cuda":
        mock_ort_model.assert_called_once_with(
            "katanemo/bart-large-mnli", file_name="onnx/model.onnx"
        )
    else:
        # `assert mock.called_once()` only calls an auto-created mock attribute
        # and is always truthy; assert_called_once() performs the real check.
        mock_pipeline.assert_called_once()
|
||||
|
||||
|
||||
# Test for get_prompt_guard function
|
||||
@patch("app.loader.AutoTokenizer.from_pretrained")
@patch("app.loader.OVModelForSequenceClassification.from_pretrained")
@patch("app.loader.AutoModelForSequenceClassification.from_pretrained")
def test_get_prompt_guard(mock_auto_model, mock_ov_model, mock_tokenizer):
    """Check which model loader get_prompt_guard calls for the configured device."""
    guard_name = arch_guard_model_type[glb.DEVICE]

    # Only the loader expected for this device gets an explicit return value.
    if glb.DEVICE == "cpu":
        mock_ov_model.return_value = MagicMock()
    else:
        mock_auto_model.return_value = MagicMock()
    mock_tokenizer.return_value = MagicMock()

    prompt_guard = get_prompt_guard(guard_name)

    assert prompt_guard["model_name"] == arch_guard_model_type[glb.DEVICE]
    mock_tokenizer.assert_called_once_with(
        arch_guard_model_type[glb.DEVICE], trust_remote_code=True
    )

    # "cpu" is expected to use the OpenVINO loader; other devices the regular one.
    expected_loader = mock_ov_model if glb.DEVICE == "cpu" else mock_auto_model
    expected_loader.assert_called_once_with(
        arch_guard_model_type[glb.DEVICE],
        device_map=glb.DEVICE,
        low_cpu_mem_usage=True,
    )
|
||||
|
|
@ -1,102 +0,0 @@
|
|||
import os
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
import app.commons.globals as glb
|
||||
from app.loader import get_embedding_model, get_zero_shot_model, get_prompt_guard
|
||||
|
||||
# Mock constants
|
||||
glb.DEVICE = "mps" # Adjust as needed for your test case
|
||||
arch_guard_model_type = {
|
||||
"cpu": "katanemo/Arch-Guard-cpu",
|
||||
"cuda": "katanemo/Arch-Guard",
|
||||
"mps": "katanemo/Arch-Guard",
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
def mock_env(monkeypatch):
    """Set the model-selection environment variables for a single test.

    `monkeypatch` restores the previous values when the test finishes, so the
    settings do not leak into other tests the way direct `os.environ` writes do.
    """
    monkeypatch.setenv("MODELS", "katanemo/bge-large-en-v1.5")
    monkeypatch.setenv("ZERO_SHOT_MODELS", "katanemo/bart-large-mnli")
|
||||
|
||||
|
||||
# Test for get_embedding_model function
|
||||
@patch("app.loader.ORTModelForFeatureExtraction.from_pretrained")
@patch("app.loader.AutoModel.from_pretrained")
@patch("app.loader.AutoTokenizer.from_pretrained")
def test_get_embedding_model(mock_tokenizer, mock_automodel, mock_ort_model, mock_env):
    """Check which loader get_embedding_model calls for the configured device."""
    expected_name = "katanemo/bge-large-en-v1.5"

    for patched_loader in (mock_automodel, mock_ort_model, mock_tokenizer):
        patched_loader.return_value = MagicMock()

    embedding_model = get_embedding_model()

    assert embedding_model["model_name"] == expected_name
    mock_tokenizer.assert_called_once_with(expected_name, trust_remote_code=True)

    # "cuda" is expected to use AutoModel; every other device the ONNX model.
    if glb.DEVICE == "cuda":
        mock_automodel.assert_called_once_with(expected_name, device_map=glb.DEVICE)
    else:
        mock_ort_model.assert_called_once_with(
            expected_name, file_name="onnx/model.onnx"
        )
|
||||
|
||||
|
||||
# Test for get_zero_shot_model function
|
||||
@patch("app.loader.ORTModelForSequenceClassification.from_pretrained")
@patch("app.loader.pipeline")
@patch("app.loader.AutoTokenizer.from_pretrained")
def test_get_zero_shot_model(mock_tokenizer, mock_pipeline, mock_ort_model, mock_env):
    """Check which loader get_zero_shot_model calls for the configured device."""
    mock_pipeline.return_value = MagicMock()
    mock_ort_model.return_value = MagicMock()
    mock_tokenizer.return_value = MagicMock()

    zero_shot_model = get_zero_shot_model()

    assert zero_shot_model["model_name"] == "katanemo/bart-large-mnli"
    mock_tokenizer.assert_called_once_with("katanemo/bart-large-mnli")
    if glb.DEVICE != "cuda":
        mock_ort_model.assert_called_once_with(
            "katanemo/bart-large-mnli", file_name="onnx/model.onnx"
        )
    else:
        # `assert mock.called_once()` only calls an auto-created mock attribute
        # and is always truthy; assert_called_once() performs the real check.
        mock_pipeline.assert_called_once()
|
||||
|
||||
|
||||
# Test for get_prompt_guard function
|
||||
@patch("app.loader.AutoTokenizer.from_pretrained")
@patch("app.loader.OVModelForSequenceClassification.from_pretrained")
@patch("app.loader.AutoModelForSequenceClassification.from_pretrained")
def test_get_prompt_guard(mock_auto_model, mock_ov_model, mock_tokenizer):
    """Check which model loader get_prompt_guard calls for the configured device."""
    guard_name = arch_guard_model_type[glb.DEVICE]

    # Only the loader expected for this device gets an explicit return value.
    if glb.DEVICE == "cpu":
        mock_ov_model.return_value = MagicMock()
    else:
        mock_auto_model.return_value = MagicMock()
    mock_tokenizer.return_value = MagicMock()

    prompt_guard = get_prompt_guard(guard_name)

    assert prompt_guard["model_name"] == arch_guard_model_type[glb.DEVICE]
    mock_tokenizer.assert_called_once_with(
        arch_guard_model_type[glb.DEVICE], trust_remote_code=True
    )

    # "cpu" is expected to use the OpenVINO loader; other devices the regular one.
    expected_loader = mock_ov_model if glb.DEVICE == "cpu" else mock_auto_model
    expected_loader.assert_called_once_with(
        arch_guard_model_type[glb.DEVICE],
        device_map=glb.DEVICE,
        low_cpu_mem_usage=True,
    )
|
||||
|
|
@ -1,66 +0,0 @@
|
|||
from typing import List
|
||||
import pytest
|
||||
import json
|
||||
from app.function_calling.model_utils import Message, process_messages
|
||||
|
||||
test_input_history = """
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "how is the weather in chicago for next 5 days?"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"model": "Arch-Function-1.5B",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_3394",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "weather_forecast",
|
||||
"arguments": { "city": "Chicago", "days": 5 }
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"content": "--",
|
||||
"tool_call_id": "call_3394"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "--",
|
||||
"model": "gpt-3.5-turbo-0125"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "how is the weather in chicago for next 5 days?"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_5306",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "weather_forecast",
|
||||
"arguments": { "city": "Chicago", "days": 5 }
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
"""
|
||||
|
||||
|
||||
def test_update_fc_history():
    """process_messages must keep all six history entries and drop the "tool" role."""
    raw_history = json.loads(test_input_history)
    message_history = [Message(**entry) for entry in raw_history]

    updated_history = process_messages(message_history)

    assert len(updated_history) == 6
    # ensure that tool role does not exist anymore
    assert all(entry["role"] != "tool" for entry in updated_history)
|
||||
2583
model_server/poetry.lock
generated
2583
model_server/poetry.lock
generated
File diff suppressed because it is too large
Load diff
|
|
@ -1,34 +1,26 @@
|
|||
[tool.poetry]
|
||||
name = "archgw_modelserver"
|
||||
version = "0.1.6"
|
||||
version = "0.1.7"
|
||||
description = "A model server for serving models"
|
||||
authors = ["Katanemo Labs, Inc <archgw@katanemo.com>"]
|
||||
authors = ["Katanemo Labs, Inc <info@katanemo.com>"]
|
||||
license = "Apache 2.0"
|
||||
readme = "README.md"
|
||||
packages = [
|
||||
{ include = "app" }, # Include the 'app' package
|
||||
{ include = "app/function_calling" }, # Include the 'app' package
|
||||
{ include = "src" }
|
||||
]
|
||||
include = ["app/*.yaml"]
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = ">=3.12"
|
||||
python = "^3.12"
|
||||
fastapi = "0.115.0"
|
||||
sentence-transformers = "3.1.1"
|
||||
torch = "2.4.1"
|
||||
uvicorn = "0.31.0"
|
||||
transformers = "*"
|
||||
pyyaml = "6.0.2"
|
||||
accelerate = "*"
|
||||
psutil = "6.0.0"
|
||||
optimum-intel = "*"
|
||||
openvino = "2024.4.0"
|
||||
pandas = "*"
|
||||
dateparser = "*"
|
||||
openai = "1.50.2"
|
||||
tf-keras = "*"
|
||||
onnx = "1.17.0"
|
||||
onnxruntime = "1.19.2"
|
||||
httpx = "0.27.2" # https://community.openai.com/t/typeerror-asyncclient-init-got-an-unexpected-keyword-argument-proxies/1040287
|
||||
pytest-asyncio = "*"
|
||||
pytest = "*"
|
||||
|
|
@ -36,10 +28,20 @@ opentelemetry-api = "^1.28.0"
|
|||
opentelemetry-sdk = "^1.28.0"
|
||||
opentelemetry-exporter-otlp = "^1.28.0"
|
||||
opentelemetry-instrumentation-fastapi = "^0.49b0"
|
||||
overrides = "^7.7.0"
|
||||
pytest-retry = "^1.6.3"
|
||||
pytest-httpserver = "^1.1.0"
|
||||
|
||||
[tool.poetry.scripts]
|
||||
archgw_modelserver = "app.cli:run_server"
|
||||
archgw_modelserver = "src.cli:run_server"
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core>=1.0.0"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
python_files = ["test*.py"]
|
||||
addopts = ["-v", "-s"]
|
||||
retries = 2
|
||||
retry_delay = 0.5
|
||||
cumulative_timing = false
|
||||
|
|
|
|||
214
model_server/src/cli.py
Normal file
214
model_server/src/cli.py
Normal file
|
|
@ -0,0 +1,214 @@
|
|||
import importlib
|
||||
import logging
|
||||
from os import path
|
||||
import os
|
||||
from signal import SIGKILL
|
||||
import sys
|
||||
import subprocess
|
||||
import argparse
|
||||
import tempfile
|
||||
import time
|
||||
|
||||
import requests
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_version():
    """Return the installed version of the ``archgw_modelserver`` distribution.

    Returns:
        str: The version string, or ``"version not found"`` when the
        distribution is not installed.
    """
    # `import importlib` alone does not guarantee that the `metadata`
    # submodule is loaded, so import it explicitly where it is used.
    from importlib import metadata

    try:
        return metadata.version("archgw_modelserver")
    except metadata.PackageNotFoundError:
        return "version not found"
|
||||
|
||||
|
||||
def wait_for_health_check(url, timeout=300):
    """Wait for the Uvicorn server to respond to health-check requests.

    Args:
        url: Health-check URL that is polled with GET requests.
        timeout: Maximum number of seconds to keep polling.

    Returns:
        bool: True as soon as a request returns HTTP 200, False when the
        timeout elapses first.
    """
    start_time = time.time()
    while time.time() - start_time < timeout:
        try:
            # Bound each probe so one hung request cannot outlive the overall timeout.
            response = requests.get(url, timeout=5)
            if response.status_code == 200:
                return True
        except requests.RequestException:
            # Not reachable yet (refused connection, timed-out probe, ...): retry.
            pass
        # Pause between probes for failures *and* non-200 answers, so a server
        # that responds with an error status is not hammered in a busy loop.
        time.sleep(1)

    return False
|
||||
|
||||
|
||||
def parse_args(argv=None):
    """Parse the command-line arguments of the model-server CLI.

    Args:
        argv: Optional list of argument strings. When None (the default)
            argparse reads the arguments from ``sys.argv``.

    Returns:
        argparse.Namespace: with ``action`` ("start", "stop" or "restart";
        default "start"), ``port`` (int, default 51000) and ``foreground``
        (bool, default False).
    """
    parser = argparse.ArgumentParser(description="Manage the Uvicorn server.")
    parser.add_argument(
        "action",
        choices=["start", "stop", "restart"],
        default="start",
        nargs="?",
        help="Action to perform on the server (default: start).",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=51000,
        help="Port number for the server (default: 51000).",
    )
    parser.add_argument(
        "--foreground",
        default=False,
        action="store_true",
        help="Run the server in the foreground (default: False).",
    )

    return parser.parse_args(argv)
|
||||
|
||||
|
||||
def get_pid_file():
    """Return the path of the model server's PID file inside the system temp directory."""
    return path.join(tempfile.gettempdir(), "model_server.pid")
|
||||
|
||||
|
||||
def stop_server():
    """Stop the Uvicorn server recorded in the PID file, if there is one.

    Reads the PID from the PID file, sends it SIGKILL and removes the file.
    A missing process or an unreadable PID is logged and the file is still
    removed, so a stale file can never block later start/stop calls.
    """
    pid_file = get_pid_file()
    if not os.path.exists(pid_file):
        logger.info("No PID file found, server is not running.")
        return

    logger.info("PID file found, shutting down the server.")
    try:
        # read pid from file
        with open(pid_file, "r") as f:
            pid = int(f.read())
    except ValueError:
        # Empty or corrupted file: there is nothing to kill, just clean up.
        logger.warning("PID file %s does not contain a valid pid", pid_file)
    else:
        logger.info(f"Killing model server {pid}")
        try:
            os.kill(pid, SIGKILL)
        except ProcessLookupError:
            logger.info(f"Process {pid} not found")

    os.remove(pid_file)
|
||||
|
||||
|
||||
def restart_server(port=51000, foreground=False):
    """Restart the Uvicorn server: stop the recorded instance, then start a new one.

    Args:
        port: Port passed on to start_server.
        foreground: Foreground flag passed on to start_server.
    """
    stop_server()
    start_server(port=port, foreground=foreground)
|
||||
|
||||
|
||||
def run_server():
    """Entry point: start, stop, or restart the Uvicorn server based on the CLI arguments."""
    args = parse_args()
    requested = args.action

    if requested == "start":
        start_server(args.port, args.foreground)
        return
    if requested == "stop":
        stop_server()
        return
    if requested == "restart":
        restart_server(args.port, args.foreground)
        return

    # Fallback for any action value not handled above.
    logger.info(f"Unknown action: {requested}")
    sys.exit(1)
|
||||
|
||||
|
||||
def ensure_killed(process):
    """Terminate *process*, escalating to a hard kill if it does not exit in time.

    Args:
        process: A ``subprocess.Popen`` instance.
    """
    process.terminate()
    try:
        # Give the process up to 5 seconds to exit on its own; wait() returns
        # as soon as it does instead of polling in one-second steps.
        process.wait(timeout=5)
    except subprocess.TimeoutExpired:
        logger.info("Killing model server")
        process.kill()
        # Reap the killed child so it does not linger as a zombie.
        process.wait()
|
||||
|
||||
|
||||
def start_server(port=51000, foreground=False):
    """Start the Uvicorn server, wait for it to become healthy and record its PID.

    Args:
        port: Port the server listens on.
        foreground: When True the server inherits this process's stdout/stderr
            and this call blocks until the server exits. When False the
            server's output is discarded and the call returns once the health
            check has passed.

    The PID file is only written for a server that passed the health check.
    If the check fails or the user interrupts start-up, the child process is
    shut down and nothing is recorded.
    """
    logger.info("model server version: %s", get_version())

    # Make sure a previously recorded instance is gone before starting a new one.
    stop_server()

    logger.info(
        "starting model server, port: %s, foreground: %s. Please wait ...",
        port,
        foreground,
    )

    command = [
        # Use the interpreter running this CLI rather than whatever "python"
        # resolves to on PATH, which may lack the server's dependencies.
        sys.executable,
        "-m",
        "uvicorn",
        "src.main:app",
        "--host",
        "0.0.0.0",
        "--port",
        str(port),
    ]

    if foreground:
        process = subprocess.Popen(command)
    else:
        # Nobody reads the child's output in background mode. An unread PIPE
        # would block the child once the pipe buffer fills, so discard it.
        process = subprocess.Popen(
            command,
            stderr=subprocess.DEVNULL,
            stdout=subprocess.DEVNULL,
        )

    try:
        healthy = wait_for_health_check(f"http://0.0.0.0:{port}/healthz")
    except KeyboardInterrupt:
        logger.info("model server stopped by user during initialization.")
        ensure_killed(process)
        return

    if not healthy:
        logger.error("health check failed, shutting it down.")
        ensure_killed(process)
        return

    logger.info(f"model server health check passed, port {port}, pid: {process.pid}")

    # write process id to temp file in temp folder
    pid_file = get_pid_file()
    logger.info(f"writing pid {process.pid} to {pid_file}")
    with open(pid_file, "w") as f:
        f.write(str(process.pid))

    if foreground:
        try:
            process.wait()
        except KeyboardInterrupt:
            logger.info("model server stopped by user.")
            ensure_killed(process)
|
||||
|
||||
|
||||
def main():
    """
    Start, stop, or restart the Uvicorn server based on command-line arguments.
    """

    args = parse_args()

    if args.action == "start":
        start_server(args.port, args.foreground)
    elif args.action == "stop":
        stop_server()
    elif args.action == "restart":
        # Forward --foreground as well; it used to be dropped on restart.
        restart_server(args.port, args.foreground)
    else:
        logger.error(f"Unknown action: {args.action}")
        sys.exit(1)
|
||||
38
model_server/src/commons/globals.py
Normal file
38
model_server/src/commons/globals.py
Normal file
|
|
@ -0,0 +1,38 @@
|
|||
import os
|
||||
from openai import OpenAI
|
||||
from src.commons.utils import get_model_server_logger
|
||||
from src.core.guardrails import get_guardrail_handler
|
||||
from src.core.function_calling import (
|
||||
ArchIntentConfig,
|
||||
ArchIntentHandler,
|
||||
ArchFunctionConfig,
|
||||
ArchFunctionHandler,
|
||||
)
|
||||
|
||||
|
||||
# Shared logger for the model server.
logger = get_model_server_logger()


# OpenAI-compatible client; the endpoint can be overridden with ARCH_ENDPOINT.
ARCH_ENDPOINT = os.getenv("ARCH_ENDPOINT", "https://api.fc.archgw.com/v1")
ARCH_API_KEY = "EMPTY"
ARCH_CLIENT = OpenAI(base_url=ARCH_ENDPOINT, api_key=ARCH_API_KEY)

# Model aliases; they double as the keys of `handler_map` below.
ARCH_INTENT_MODEL_ALIAS = "Arch-Intent"
ARCH_FUNCTION_MODEL_ALIAS = "Arch-Function"

logger.info("loading prompt guard model ...")
arch_guard_model = get_guardrail_handler()

# One handler per model name.
handler_map = {
    ARCH_INTENT_MODEL_ALIAS: ArchIntentHandler(
        ARCH_CLIENT, ARCH_INTENT_MODEL_ALIAS, ArchIntentConfig
    ),
    ARCH_FUNCTION_MODEL_ALIAS: ArchFunctionHandler(
        ARCH_CLIENT, ARCH_FUNCTION_MODEL_ALIAS, ArchFunctionConfig
    ),
    "Arch-Guard": arch_guard_model,
}
|
||||
87
model_server/src/commons/utils.py
Normal file
87
model_server/src/commons/utils.py
Normal file
|
|
@ -0,0 +1,87 @@
|
|||
import os
|
||||
import sys
|
||||
import time
|
||||
import logging
|
||||
import requests
|
||||
import subprocess
|
||||
import importlib
|
||||
|
||||
|
||||
PROJ_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
|
||||
# Default log directory and file
|
||||
DEFAULT_LOG_DIR = os.path.join(PROJ_DIR, ".logs")
|
||||
DEFAULT_LOG_FILE = "modelserver.log"
|
||||
|
||||
|
||||
def get_model_server_logger(log_dir=None, log_file=None):
    """
    Get or initialize the logger instance for the model server.

    Parameters:
    - log_dir (str): Directory that is created if missing and checked for
      write access. Defaults to DEFAULT_LOG_DIR.
    - log_file (str): Custom log file name. File logging is currently
      disabled (output goes to the console only), so this value is accepted
      for compatibility but not used.

    Returns:
    - logging.Logger: The logger named "model_server_logger".

    Raises:
    - RuntimeError: If the log directory cannot be created or is not writable.
    """
    log_dir = log_dir or DEFAULT_LOG_DIR

    # hasHandlers() also looks at ancestor loggers, so a root logger that has
    # already been configured counts as "already configured" here.
    logger = logging.getLogger("model_server_logger")
    if logger.hasHandlers():
        return logger

    # Ensure the log directory exists and is writable.
    try:
        os.makedirs(log_dir, exist_ok=True)
        if not os.access(log_dir, os.W_OK):
            raise PermissionError(f"No write permission for the directory: {log_dir}")
    except OSError as e:
        # PermissionError is a subclass of OSError; keep the cause attached.
        raise RuntimeError(f"Failed to initialize logger: {e}") from e

    # Configure the root logger to write to the console.
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(levelname)s - %(message)s",
        handlers=[logging.StreamHandler()],
    )

    return logger
|
||||
|
||||
|
||||
logger = get_model_server_logger()
|
||||
|
||||
logging.info("initializing torch device ...")
|
||||
import torch
|
||||
|
||||
|
||||
def get_device():
    """Return the preferred torch device name: "cuda", then "mps", then "cpu"."""
    has_cuda = torch.cuda.is_available()
    # Only query the MPS backend when this torch build exposes it.
    has_mps = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()

    if has_cuda:
        return "cuda"
    if has_mps:
        return "mps"
    return "cpu"
|
||||
644
model_server/src/core/function_calling.py
Normal file
644
model_server/src/core/function_calling.py
Normal file
|
|
@ -0,0 +1,644 @@
|
|||
import json
|
||||
import random
|
||||
import builtins
|
||||
import textwrap
|
||||
|
||||
from openai import OpenAI
|
||||
from typing import Any, Dict, List
|
||||
from overrides import override
|
||||
from src.commons.utils import get_model_server_logger
|
||||
from src.core.model_utils import (
|
||||
Message,
|
||||
ChatMessage,
|
||||
Choice,
|
||||
ChatCompletionResponse,
|
||||
ArchBaseHandler,
|
||||
)
|
||||
from src.core.hallucination import HallucinationStateHandler
|
||||
|
||||
logger = get_model_server_logger()
|
||||
|
||||
|
||||
class ArchIntentConfig:
    """Prompt fragments and generation parameters for the Arch-Intent model."""

    # System-level task description.
    TASK_PROMPT = "You are a helpful assistant."

    # Tool section; `{tool_text}` is filled in with the converted tool list.
    TOOL_PROMPT_TEMPLATE = "\n".join(
        (
            "You task is to check if there are any tools that can be used to help the last user message in conversations according to the available tools listed below.",
            "",
            "<tools>",
            "{tool_text}",
            "</tools>",
        )
    )

    # Output-format instructions given to the model.
    FORMAT_PROMPT = "\n".join(
        (
            "Provide your tool assessment for ONLY THE LAST USER MESSAGE in the above conversation:",
            "- First line must read 'Yes' or 'No'.",
            "- If yes, a second line must include a comma-separated list of tool indexes.",
        )
    )

    EXTRA_INSTRUCTION = "Are there any tools can help?"

    # Parameters forwarded with each generation request.
    GENERATION_PARAMS = {
        "temperature": 0.01,
        "max_tokens": 1,
        "stop_token_ids": [151645],
    }
|
||||
|
||||
|
||||
class ArchIntentHandler(ArchBaseHandler):
    """Handler that asks the Arch-Intent model whether any tool applies to a conversation."""

    def __init__(self, client: OpenAI, model_name: str, config: ArchIntentConfig):
        """
        Initializes the intent handler.

        Args:
            client (OpenAI): An OpenAI client instance.
            model_name (str): Name of the model to use.
            config (ArchIntentConfig): The configuration for Arch-Intent.
        """

        super().__init__(
            client,
            model_name,
            config.TASK_PROMPT,
            config.TOOL_PROMPT_TEMPLATE,
            config.FORMAT_PROMPT,
            config.GENERATION_PARAMS,
        )

        self.extra_instruction = config.EXTRA_INSTRUCTION
        self.prompt_prefilling = False

    @override
    def _convert_tools(self, tools: List[Dict[str, Any]]) -> str:
        """
        Converts a list of tools into newline-separated JSON objects with indexed keys.

        Each tool is serialized with an extra "index" entry of the form "T<position>".
        Because the tool dictionary is merged on the right-hand side, a tool that
        already carries its own "index" key keeps its own value.

        Args:
            tools (List[Dict[str, Any]]): A list of tools represented as dictionaries.

        Returns:
            str: A string representation of converted tools, one JSON object per line.
        """

        converted = [
            json.dumps({"index": f"T{idx}"} | tool) for idx, tool in enumerate(tools)
        ]
        return "\n".join(converted)

    def detect_intent(self, content: ChatCompletionResponse) -> bool:
        """
        Detect if any intent match with prompts

        Args:
            content (ChatCompletionResponse): Model response that contains the
                intent detection result in its first choice.

        Returns:
            bool: True only when the first choice's message content is exactly
                "Yes"; False otherwise (including when there are no choices or
                the message has no `content` attribute).
        """

        if not content.choices:
            return False

        message = content.choices[0].message
        return getattr(message, "content", None) == "Yes"

    @override
    async def chat_completion(self, req: ChatMessage) -> ChatCompletionResponse:
        """
        Generates a chat completion for a given request.

        Args:
            req (ChatMessage): A chat message request object.

        Returns:
            ChatCompletionResponse: The model's response to the chat request.

        Note:
            Currently only support vllm inference
        """

        # In the case that no tools are available (empty or missing), simply
        # return `No` to avoid making a call
        if not req.tools:
            model_response = Message(content="No", tool_calls=[])
        else:
            messages = self._process_messages(
                req.messages, req.tools, self.extra_instruction
            )

            completion = self.client.chat.completions.create(
                messages=messages,
                model=self.model_name,
                stream=False,
                extra_body=self.generation_params,
            )

            logger.info("arch_intent response: %s", json.dumps(completion.model_dump()))

            # Only the text content is forwarded; intent detection never emits tool calls.
            model_response = Message(
                content=completion.choices[0].message.content, tool_calls=[]
            )

        return ChatCompletionResponse(
            choices=[Choice(message=model_response)], model=self.model_name
        )
|
||||
|
||||
|
||||
# =============================================================================================================
|
||||
|
||||
|
||||
class ArchFunctionConfig:
    """Static prompt fragments, sampling parameters, and prefill settings for Arch-Function."""

    # Generic system-level task description.
    TASK_PROMPT = "You are a helpful assistant."

    # Tool listing section; `{tool_text}` is the placeholder for the tool text.
    TOOL_PROMPT_TEMPLATE = (
        "# Tools\n"
        "\n"
        "You may call one or more functions to assist with the user query.\n"
        "\n"
        "You are provided with function signatures within <tools></tools> XML tags:\n"
        "<tools>\n"
        "{tool_text}\n"
        "</tools>"
    )

    # Describes the <tool_call> wrapper the model is asked to emit.
    FORMAT_PROMPT = (
        "For each function call, return a json object with function name and "
        "arguments within <tool_call></tool_call> XML tags:\n"
        "<tool_call>\n"
        '{"name": <function-name>, "arguments": <args-json-object>}\n'
        "</tool_call>"
    )

    # Sampling parameters; log-probabilities are requested alongside the tokens.
    GENERATION_PARAMS = dict(
        temperature=0.6,
        top_p=1.0,
        top_k=10,
        max_tokens=512,
        stop_token_ids=[151645],
        logprobs=True,
        top_logprobs=10,
    )

    # Settings for prompt prefilling: extra request parameters plus the
    # candidate prefixes for the prefilled assistant message.
    PREFILL_CONFIG = dict(
        prefill_params=dict(
            continue_final_message=True,
            add_generation_prompt=False,
        ),
        prefill_prefix=[
            "May",
            "Could",
            "Sure",
            "Definitely",
            "Certainly",
            "Of course",
            "Can",
        ],
    )

    # Names of the builtin Python types usable for parameter type verification.
    SUPPORT_DATA_TYPES = ["int", "float", "bool", "str", "list", "tuple", "set", "dict"]
|
||||
|
||||
|
||||
class ArchFunctionHandler(ArchBaseHandler):
    """Handler for the Arch-Function model.

    Builds the prompt, streams the generation through the hallucination
    handler, falls back to prompt prefilling ("parameter gathering") when
    needed, and extracts and verifies tool calls from the model output.
    """

    def __init__(
        self,
        client: OpenAI,
        model_name: str,
        config: ArchFunctionConfig,
    ):
        """
        Initializes the function handler.

        Args:
            client (OpenAI): An OpenAI client instance.
            model_name (str): Name of the model to use.
            config (ArchFunctionConfig): The configuration for Arch-Function
        """

        super().__init__(
            client,
            model_name,
            config.TASK_PROMPT,
            config.TOOL_PROMPT_TEMPLATE,
            config.FORMAT_PROMPT,
            config.GENERATION_PARAMS,
        )

        self.prefill_params = config.PREFILL_CONFIG["prefill_params"]
        self.prefill_prefix = config.PREFILL_CONFIG["prefill_prefix"]
        self.prompt_prefilling = False

        # Predefine data types for verification. Only support Python for now.
        # [TODO] Extend the list of support data types
        self.support_data_types = {
            type_name: getattr(builtins, type_name)
            for type_name in config.SUPPORT_DATA_TYPES
        }

    @override
    def _convert_tools(self, tools: List[Dict[str, Any]]) -> str:
        """
        Converts a list of tools into JSON format.

        Args:
            tools (List[Dict[str, Any]]): A list of tools represented as dictionaries.

        Returns:
            str: A string representation of converted tools, one JSON object per line.
        """

        return "\n".join(json.dumps(tool) for tool in tools)

    def _fix_json_string(self, json_str: str) -> str:
        """
        Best-effort repair of a malformed JSON string.

        Unmatched closing brackets are dropped, missing closing brackets are
        appended, and every single quote is replaced with a double quote. The
        result is NOT parsed here; the caller is expected to try `json.loads`
        on it.

        Note:
            Brackets and quotes are processed character by character without
            regard to string literals, so brackets or apostrophes inside
            string values are affected as well.

        Args:
            json_str (str): A JSON string that might be malformed.

        Returns:
            str: The repaired string.
        """

        closing_for = {"(": ")", "{": "}", "[": "]"}
        opening_for = {closer: opener for opener, closer in closing_for.items()}

        # Stack of currently open brackets, and the characters kept so far.
        stack, kept = [], []

        for char in json_str.strip():
            if char in closing_for:
                stack.append(char)
                kept.append(char)
            elif char in opening_for:
                # Keep a closing bracket only if it matches the innermost open one.
                if stack and stack[-1] == opening_for[char]:
                    stack.pop()
                    kept.append(char)
            else:
                kept.append(char)

        # Close whatever is still open, innermost bracket first.
        kept.extend(closing_for[opener] for opener in reversed(stack))

        return "".join(kept).replace("'", '"')

    def _extract_tool_calls(self, content: str) -> Dict[str, Any]:
        """
        Extracts tool call information from a given string.

        Only the single line directly following a `<tool_call>` line is parsed
        as a tool call. If plain JSON parsing fails, one repair attempt is made
        with `_fix_json_string`.

        Args:
            content (str): The content string containing potential tool call information.

        Returns:
            Dict: A dictionary of extraction, including:
                - "result": A list of tool call dictionaries (empty on failure).
                - "status": A boolean indicating if the extraction was valid.
                - "message": An empty string, or the exception that made the
                  extraction fail.
        """

        tool_calls, is_valid, error_message = [], True, ""

        inside_tool_call = False
        # `content` may be None when the model returned no text at all.
        for line in (content or "").split("\n"):
            if line == "<tool_call>":
                inside_tool_call = True
                continue
            if line == "</tool_call>":
                inside_tool_call = False
                continue

            if inside_tool_call:
                try:
                    tool_content = json.loads(line)
                except Exception as e:
                    try:
                        tool_content = json.loads(self._fix_json_string(line))
                    except Exception:
                        # Report the original parsing error, not the repair error.
                        tool_calls, is_valid, error_message = [], False, e
                        break

                try:
                    name = tool_content["name"]
                    arguments = tool_content["arguments"]
                except (KeyError, TypeError) as e:
                    # Parsed JSON that is not an object with name/arguments.
                    tool_calls, is_valid, error_message = [], False, e
                    break

                tool_calls.append(
                    {
                        "id": f"call_{random.randint(1000, 10000)}",
                        "type": "function",
                        "function": {"name": name, "arguments": arguments},
                    }
                )

            # Any other line ends the current tool-call section.
            inside_tool_call = False

        return {"result": tool_calls, "status": is_valid, "message": error_message}

    def _correcting_type(self, value, target_type):
        """
        Tries to coerce `value` into `target_type` for a few known cases.

        Supported conversions: int -> float, str -> list (via a Python literal),
        and anything -> str. The value is returned unchanged when no rule
        applies or when the conversion fails.

        Args:
            value: The value to convert.
            target_type (type): The expected builtin type.

        Returns:
            The converted value, or the original `value`.
        """

        # Imported locally because `ast` is not among this module's imports.
        import ast

        try:
            if target_type == float and isinstance(value, int):
                return float(value)
            elif target_type == list and isinstance(value, str):
                return ast.literal_eval(value)
            elif target_type == str and not isinstance(value, str):
                return str(value)
            # Add more conversion rules as needed
        except (ValueError, TypeError, SyntaxError):
            # `ast.literal_eval` raises ValueError or SyntaxError on bad input.
            pass
        return value

    def _verify_tool_calls(
        self, tools: List[Dict[str, Any]], tool_calls: List[Dict[str, Any]]
    ) -> Dict[str, Any]:
        """
        Verifies the validity of extracted tool calls against the provided tools.

        Checks, in order and for each tool call: that the function is defined,
        that all required parameters are present, that every supplied parameter
        is declared, and that each value matches (or can be coerced to) the
        declared type when that type is one of the supported data types.
        Verification stops at the first problem found.

        Args:
            tools (List[Dict[str, Any]]): A list of available tools.
            tool_calls (List[Dict[str, Any]]): A list of tool calls to verify.

        Returns:
            Dict: A dictionary of verification, including:
                - "status": A boolean indicating if the tool calls are valid.
                - "invalid_tool_call": A dictionary of the invalid tool call if any.
                - "message": An error message.
        """

        def invalid(tool_call, message):
            return {
                "status": False,
                "invalid_tool_call": tool_call,
                "message": message,
            }

        functions = {
            tool["function"]["name"]: tool["function"]["parameters"]
            for tool in tools
            if tool["type"] == "function"
        }

        for tool_call in tool_calls:
            func_name = tool_call["function"]["name"]
            func_args = tool_call["function"]["arguments"]

            # Check whether the function is available or not
            if func_name not in functions:
                return invalid(tool_call, f"{func_name} is not defined!")

            spec = functions[func_name]

            # Check if all the required parameters can be found in the tool calls
            for required_param in spec.get("required", []):
                if required_param not in func_args:
                    return invalid(
                        tool_call,
                        f"`{required_param}` is requiried by the function `{func_name}` but not found in the tool call!",
                    )

            # Verify the data type of each parameter in the tool calls
            properties = spec.get("properties", {})
            for param_name in func_args:
                if param_name not in properties:
                    return invalid(
                        tool_call,
                        f"Parameter `{param_name}` is not defined in the function `{func_name}`.",
                    )

                data_type = properties[param_name].get("type")
                # Types outside the supported set are not verified.
                if (
                    not isinstance(data_type, str)
                    or data_type not in self.support_data_types
                ):
                    continue

                expected_type = self.support_data_types[data_type]
                param_value = func_args[param_name]

                if isinstance(param_value, expected_type):
                    continue
                if isinstance(
                    self._correcting_type(param_value, expected_type), expected_type
                ):
                    continue

                return invalid(
                    tool_call,
                    f"Parameter `{param_name}` is expected to have the data type `{expected_type}`, but got `{type(param_value)}`.",
                )

        return {"status": True, "invalid_tool_call": None, "message": ""}

    def _add_prefill_message(self, messages: List[Dict[str, str]]):
        """
        Update messages and generation params for prompt prefilling

        Args:
            messages (List[Dict[str, str]]): A list of messages.

        Returns:
            prefill_messages (List[Dict[str, str]]): A new list consisting of
                `messages` followed by an assistant message whose content is a
                randomly chosen prefill prefix.
        """

        return messages + [
            {
                "role": "assistant",
                "content": random.choice(self.prefill_prefix),
            }
        ]

    def _engage_parameter_gathering(self, messages: List[Dict[str, str]]):
        """
        Engage parameter gathering for tool calls

        Sends a (non-streaming) completion request with a prefilled assistant
        message and the prefill parameters merged into the generation
        parameters, and marks `self.prompt_prefilling` as True.

        Args:
            messages (List[Dict[str, str]]): A list of messages.

        Returns:
            The completion response returned by the client.
        """

        # TODO: log enaging parameter gathering
        prefill_response = self.client.chat.completions.create(
            messages=self._add_prefill_message(messages),
            model=self.model_name,
            extra_body={
                **self.generation_params,
                **self.prefill_params,
            },
        )
        self.prompt_prefilling = True
        return prefill_response

    def _check_length_and_pop_messages(self, messages, max_tokens=4096):
        """
        Trims the `messages` list to ensure the total token count does not exceed `max_tokens`.

        The list is trimmed in place, oldest non-system messages first, and
        trimming stops once fewer than three messages remain or when only
        system messages are left.

        Args:
            messages (list): List of message dictionaries.
            max_tokens (int): Maximum allowed token count.

        Returns:
            list: Trimmed list of messages (the same list object).
        """

        def estimate_token_length(messages):
            """Estimate the total token length of the messages."""
            # Approximate token length: assuming ~4 characters per token on
            # average. Missing or None content counts as empty.
            return sum(len(message.get("content") or "") // 4 for message in messages)

        # Calculate initial token length
        total_tokens = estimate_token_length(messages)

        # Trim messages if token count exceeds the limit
        while total_tokens > max_tokens and len(messages) >= 3:
            # Find the first non-system message pair
            for i, message in enumerate(messages):
                if message["role"] != "system":
                    # Remove the 'user'/'assistant' pair
                    if i + 1 < len(messages) and messages[i + 1]["role"] in [
                        "user",
                        "assistant",
                    ]:
                        del messages[i : i + 2]
                    else:
                        del messages[i]
                    break
            else:
                # Only system messages remain, so nothing more can be removed;
                # stop instead of looping forever.
                break

            # Recalculate token length
            total_tokens = estimate_token_length(messages)

        return messages

    @override
    async def chat_completion(self, req: ChatMessage) -> ChatCompletionResponse:
        """
        Generates a chat completion response for a given request.

        The model output is streamed through the hallucination handler. When
        the first token is not `<tool_call>`, or when a hallucination is
        flagged, generation is restarted with prompt prefilling (parameter
        gathering). Tool calls found in the final text are verified before
        being returned.

        Args:
            req (ChatMessage): A chat message request object.

        Returns:
            ChatCompletionResponse: The model's response to the chat request.

        Raises:
            ValueError: If tool calls were extracted but failed verification.

        Note:
            Currently only support vllm inference
        """

        logger.info(
            f"model_server => arch_function: request body: {json.dumps(req.model_dump())}"
        )

        messages = self._process_messages(req.messages, req.tools)
        messages = self._check_length_and_pop_messages(messages)

        # always enable `stream=True` to collect model responses
        response = self.client.chat.completions.create(
            messages=messages,
            model=self.model_name,
            stream=True,
            extra_body=self.generation_params,
        )

        # initialize the hallucination handler, which is an iterator
        self.hallu_handler = HallucinationStateHandler(
            response_iterator=response, function=req.tools
        )

        model_response, self.has_tool_call = "", None
        self.hallucination = False
        for _ in self.hallu_handler:
            # check if the first token is <tool_call>
            if len(self.hallu_handler.tokens) > 0 and self.has_tool_call is None:
                if self.hallu_handler.tokens[0] == "<tool_call>":
                    self.has_tool_call = True
                else:
                    self.has_tool_call = False
                    break

            # if the model is hallucinating, start parameter gathering
            if self.hallu_handler.hallucination is True:
                self.hallucination = True
                logger.info(
                    f"{self.hallu_handler.error_message} - start parameter gathering"
                )
                logger.info(
                    f"Hallucinated response : {''.join(self.hallu_handler.tokens)}"
                )
                break

        if self.hallucination is True:
            prefill_response = self._engage_parameter_gathering(messages)
            model_response = prefill_response.choices[0].message.content

        if self.has_tool_call and self.hallucination is False:
            # [TODO] - Review: remove the following code
            model_response = "".join(self.hallu_handler.tokens)
            logger.info(f"Tool call found, no hallucination detected {model_response}!")

        # start parameter gathering if the model is not generating tool calls
        if self.has_tool_call is False:
            # [TODO] - Review: remove the following code
            logger.info("No tool call found, start parameter gathering")
            prefill_response = self._engage_parameter_gathering(messages)
            model_response = prefill_response.choices[0].message.content

        # Extract tool calls from model response
        extracted = self._extract_tool_calls(model_response)

        if len(extracted["result"]) and extracted["status"]:
            # [TODO] Review: define the behavior in the case that tool call extraction fails
            verified = self._verify_tool_calls(
                tools=req.tools, tool_calls=extracted["result"]
            )

            # [TODO] Review: In the case that tool calls are invalid, define the protocol to collect debugging output and the behavior to handle it appropriately
            if verified["status"]:
                model_response = Message(content="", tool_calls=extracted["result"])
                log_message = f"model_server <= arch_function: (tool_calls): {json.dumps([tool_call['function'] for tool_call in extracted['result']])}"
                logger.info(log_message)
            else:
                raise ValueError(f"Invalid tool call: {verified['message']}")
        else:
            # No usable tool call: return the raw text as a plain message.
            model_response = Message(content=model_response, tool_calls=[])

        chat_completion_response = ChatCompletionResponse(
            choices=[Choice(message=model_response)], model=self.model_name
        )

        # [TODO] Review: define the protocol to collect debugging output

        logger.info(
            f"model_server <= arch_function: response body: {json.dumps(chat_completion_response.model_dump())}"
        )

        return chat_completion_response
|
||||
171
model_server/src/core/guardrails.py
Normal file
171
model_server/src/core/guardrails.py
Normal file
|
|
@ -0,0 +1,171 @@
|
|||
import time
|
||||
import torch
|
||||
import numpy as np
|
||||
import src.commons.utils as utils
|
||||
from transformers import AutoTokenizer
|
||||
from src.core.model_utils import GuardRequest, GuardResponse
|
||||
|
||||
# from optimum.intel import OVModelForSequenceClassification
|
||||
from transformers import AutoModelForSequenceClassification
|
||||
|
||||
|
||||
class ArchGuardHanlder:
    """Wraps a sequence-classification model to run guard checks on input text.

    Only the "jailbreak" task is configured.
    """

    def __init__(self, model_dict):
        """
        Initializes the ArchGuardHanlder with the given model dictionary.

        Args:
            model_dict (dict): A dictionary containing the model, model name,
                tokenizer, and device information (keys "model", "model_name",
                "tokenizer", "device").
        """

        self.model = model_dict["model"]
        self.model_name = model_dict["model_name"]
        self.tokenizer = model_dict["tokenizer"]
        self.device = model_dict["device"]

        # Per-task settings: index of the positive class in the logits and the
        # probability above which the input is flagged.
        self.support_tasks = {"jailbreak": {"positive_class": 2, "threshold": 0.5}}

    def _split_text_into_chunks(self, text, max_num_words=300):
        """
        Splits the input text into chunks of up to `max_num_words` words.

        Args:
            text (str): The input text to be split.
            max_num_words (int, optional): The maximum number of words in each chunk. Defaults to 300.

        Returns:
            List[str]: A list of text chunks; words are separated by whitespace
                and re-joined with single spaces.
        """

        words = text.split()

        return [
            " ".join(words[i : i + max_num_words])
            for i in range(0, len(words), max_num_words)
        ]

    @staticmethod
    def softmax(x):
        """
        Computes the softmax of the input array along axis 0.

        The maximum is subtracted before exponentiation so that large logits
        do not overflow; this does not change the mathematical result.

        Args:
            x (np.ndarray): The input array.

        Returns:
            np.ndarray: The softmax of the input.
        """
        exps = np.exp(x - np.max(x, axis=0))
        return exps / exps.sum(axis=0)

    def _predict_text(self, task, text, max_length=512) -> GuardResponse:
        """
        Predicts the result for the provided text for a specific task.

        Args:
            task (str): The task to perform (e.g., "jailbreak").
            text (str): The input text to classify.
            max_length (int, optional): The maximum length for tokenization. Defaults to 512.

        Returns:
            GuardResponse: A GuardResponse object containing the prediction.
                `sentence` holds the input text when it is flagged and `None`
                otherwise; `latency` covers model inference only, not
                tokenization.
        """

        inputs = self.tokenizer(
            text, truncation=True, max_length=max_length, return_tensors="pt"
        ).to(self.device)

        task_config = self.support_tasks[task]

        start_time = time.perf_counter()

        with torch.no_grad():
            logits = self.model(**inputs).logits.cpu().detach().numpy()[0]
            prob = ArchGuardHanlder.softmax(logits)[task_config["positive_class"]]

        latency = time.perf_counter() - start_time

        verdict = bool(prob > task_config["threshold"])
        sentence = text if verdict else None

        return GuardResponse(
            prob=[prob.item()], verdict=verdict, sentence=[sentence], latency=latency
        )

    def predict(self, req: GuardRequest, max_num_words=300) -> GuardResponse:
        """
        Makes a prediction based on the GuardRequest input.

        Inputs shorter than `max_num_words` words are classified in one pass.
        Longer inputs are split into chunks of `max_num_words` words; the
        response then lists the probability and text of every flagged chunk,
        and the latency is the sum over all chunks evaluated.

        Args:
            req (GuardRequest): The GuardRequest object containing the input text and task.
            max_num_words (int, optional): The maximum number of words in each chunk if splitting is needed. Defaults to 300.

        Returns:
            GuardResponse: A GuardResponse object containing the prediction.

        Raises:
            NotImplementedError: If `req.task` is not a supported task.

        Note:
            currently only support jailbreak check
        """

        if req.task not in self.support_tasks:
            raise NotImplementedError(f"{req.task} is not supported!")

        if len(req.input.split()) < max_num_words:
            return self._predict_text(req.task, req.input)

        # split into chunks if text is long, honoring the caller's chunk size
        text_chunks = self._split_text_into_chunks(req.input, max_num_words)

        prob, verdict, sentence, latency = [], False, [], 0

        for chunk in text_chunks:
            chunk_result = self._predict_text(req.task, chunk)

            if chunk_result.verdict:
                prob.append(chunk_result.prob[0])
                verdict = True
                sentence.append(chunk_result.sentence[0])

            # Count inference time for every chunk, flagged or not.
            latency += chunk_result.latency

        return GuardResponse(
            prob=prob, verdict=verdict, sentence=sentence, latency=latency
        )
|
||||
|
||||
|
||||
def get_guardrail_handler(device: str = None):
    """
    Builds an ArchGuardHanlder backed by the "katanemo/Arch-Guard" model.

    Args:
        device (str, optional): The device to use for model inference (e.g., "cpu" or "cuda").
            When None, the device is chosen by `utils.get_device()`.

    Returns:
        ArchGuardHanlder: A handler holding the loaded tokenizer and model for that device.
    """

    target_device = utils.get_device() if device is None else device

    # NOTE: a CPU-specific variant (OVModelForSequenceClassification with
    # "katanemo/Arch-Guard-cpu") is currently disabled; the same model class
    # and checkpoint are used for every device.
    model_name = "katanemo/Arch-Guard"

    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name, device_map=target_device, low_cpu_mem_usage=True
    )

    return ArchGuardHanlder(
        model_dict={
            "device": target_device,
            "model_name": model_name,
            "tokenizer": tokenizer,
            "model": model,
        }
    )
|
||||
|
|
@ -1,15 +1,22 @@
|
|||
import json
|
||||
import math
|
||||
import torch
|
||||
import random
|
||||
from typing import Any, Dict, List, Tuple
|
||||
import itertools
|
||||
|
||||
|
||||
from typing import Dict, List, Tuple
|
||||
from enum import Enum
|
||||
import string
|
||||
|
||||
from src.commons.utils import get_model_server_logger
|
||||
|
||||
logger = get_model_server_logger()
|
||||
|
||||
# constants
|
||||
FUNC_NAME_START_PATTERN = ('<tool_call>\n{"name":"', "<tool_call>\n{'name':'")
|
||||
FUNC_NAME_END_TOKEN = ('",', "',")
|
||||
TOOL_CALL_TOKEN = "<tool_call>"
|
||||
END_TOOL_CALL_TOKEN = "</tool_call>"
|
||||
|
||||
FIRST_PARAM_NAME_START_PATTERN = ('"arguments":{"', "'arguments':{'")
|
||||
PARAMETER_NAME_END_TOKENS = ('":', ':"', "':", ":'")
|
||||
|
|
@ -17,6 +24,8 @@ PARAMETER_NAME_START_PATTERN = (',"', ",'")
|
|||
PARAMETER_VALUE_START_PATTERN = ('":', "':")
|
||||
PARAMETER_VALUE_END_TOKEN = ('",', "}}\n", "',")
|
||||
|
||||
BRACKETS = {"(": ")", "{": "}", "[": "]"}
|
||||
|
||||
|
||||
# Thresholds
|
||||
class MaskToken(Enum):
|
||||
|
|
@ -28,10 +37,15 @@ class MaskToken(Enum):
|
|||
|
||||
|
||||
HALLUCINATION_THRESHOLD_DICT = {
|
||||
MaskToken.TOOL_CALL.value: {"entropy": 0.1, "varentropy": 0.5},
|
||||
MaskToken.TOOL_CALL.value: {
|
||||
"entropy": 0.35,
|
||||
"varentropy": 1.7,
|
||||
"probability": 0.8,
|
||||
},
|
||||
MaskToken.PARAMETER_VALUE.value: {
|
||||
"entropy": 0.5,
|
||||
"varentropy": 2.5,
|
||||
"entropy": 0.28,
|
||||
"varentropy": 1.2,
|
||||
"probability": 0.8,
|
||||
},
|
||||
}
|
||||
|
||||
|
|
@ -48,10 +62,10 @@ def check_threshold(entropy: float, varentropy: float, thd: Dict) -> bool:
|
|||
Returns:
|
||||
bool: True if either the entropy or varentropy exceeds their respective thresholds, False otherwise.
|
||||
"""
|
||||
return entropy > thd["entropy"] or varentropy > thd["varentropy"]
|
||||
return entropy > thd["entropy"] and varentropy > thd["varentropy"]
|
||||
|
||||
|
||||
def calculate_entropy(log_probs: List[float]) -> Tuple[float, float]:
|
||||
def calculate_uncertainty(log_probs: List[float]) -> Tuple[float, float]:
|
||||
"""
|
||||
Calculate the entropy and variance of entropy (varentropy) from log probabilities.
|
||||
|
||||
|
|
@ -71,7 +85,26 @@ def calculate_entropy(log_probs: List[float]) -> Tuple[float, float]:
|
|||
token_probs * (log_probs / math.log(2, math.e)) + entropy.unsqueeze(-1) ** 2,
|
||||
dim=-1,
|
||||
)
|
||||
return entropy.item(), varentropy.item()
|
||||
return entropy.item(), varentropy.item(), token_probs[0].item()
|
||||
|
||||
|
||||
def is_parameter_required(
    function_description: Dict,
    parameter_name: str,
) -> bool:
    """
    Check if a parameter is in the function's required list.

    Args:
        function_description (dict): The API description in JSON format; its
            optional "required" entry lists the required parameter names.
        parameter_name (str): The name of the parameter to check.

    Returns:
        bool: True if the parameter is listed under "required", False otherwise
            (including when "required" is missing or None).
    """
    # `or []` also covers an explicit `"required": None`.
    required_parameters = function_description.get("required") or []

    return parameter_name in required_parameters
|
||||
|
||||
|
||||
def is_parameter_property(
|
||||
|
|
@ -107,7 +140,6 @@ class HallucinationStateHandler:
|
|||
hallucination (bool): Flag indicating if a hallucination is detected.
|
||||
hallucination_message (str): Message describing the hallucination.
|
||||
parameter_name (list): List of extracted parameter names.
|
||||
function_description (dict): Description of functions and their parameters.
|
||||
token_probs_map (list): List mapping tokens to their entropy and variance of entropy.
|
||||
"""
|
||||
|
||||
|
|
@ -122,23 +154,34 @@ class HallucinationStateHandler:
|
|||
self.parameter_name_done: bool = False
|
||||
self.hallucination: bool = False
|
||||
self.error_message: str = ""
|
||||
self.error_type: str = ""
|
||||
self.parameter_name: List[str] = []
|
||||
self.token_probs_map: List[Tuple[str, float, float]] = []
|
||||
self.response_iterator = response_iterator
|
||||
self._process_function(function)
|
||||
self.open_bracket = False
|
||||
self.bracket = None
|
||||
self.check_parameter_name = {}
|
||||
self.HALLUCINATION_THRESHOLD_DICT = HALLUCINATION_THRESHOLD_DICT
|
||||
|
||||
def _process_function(self, function):
|
||||
self.function = function
|
||||
if self.function is None:
|
||||
raise ValueError("API descriptions not set.")
|
||||
parameter_names = {}
|
||||
for func in self.function:
|
||||
func_name = func["name"]
|
||||
parameters = func["parameters"]["properties"]
|
||||
parameter_names[func_name] = list(parameters.keys())
|
||||
self.function_description = parameter_names
|
||||
self.function_properties = {x["name"]: x["parameters"] for x in self.function}
|
||||
self.function_properties = {
|
||||
x["function"]["name"]: x["function"]["parameters"] for x in self.function
|
||||
}
|
||||
|
||||
def _reset_parameters(self):
|
||||
"""
|
||||
Resets all parameters in the HallucinationStateHandler to their default values.
|
||||
"""
|
||||
self.state = None
|
||||
self.parameter_name_done = False
|
||||
self.hallucination = False
|
||||
self.error_message = ""
|
||||
self.open_bracket = False
|
||||
self.bracket = None
|
||||
self.check_parameter_name = {}
|
||||
|
||||
def append_and_check_token_hallucination(self, token, logprob):
|
||||
"""
|
||||
|
|
@ -175,9 +218,12 @@ class HallucinationStateHandler:
|
|||
raise ValueError(
|
||||
f"Error extracting logprobs from response: {e}"
|
||||
)
|
||||
self.append_and_check_token_hallucination(
|
||||
token_content, logprobs
|
||||
)
|
||||
if token_content == END_TOOL_CALL_TOKEN:
|
||||
self._reset_parameters()
|
||||
else:
|
||||
self.append_and_check_token_hallucination(
|
||||
token_content, logprobs
|
||||
)
|
||||
return token_content
|
||||
except StopIteration:
|
||||
raise StopIteration
|
||||
|
|
@ -199,7 +245,7 @@ class HallucinationStateHandler:
|
|||
self.mask.append(MaskToken.FUNCTION_NAME)
|
||||
else:
|
||||
self.state = None
|
||||
self._is_function_name_hallucinated()
|
||||
self._get_function_name()
|
||||
|
||||
# Check if the token is a function name start token, change the state
|
||||
if content.endswith(FUNC_NAME_START_PATTERN):
|
||||
|
|
@ -217,11 +263,13 @@ class HallucinationStateHandler:
|
|||
PARAMETER_NAME_END_TOKENS
|
||||
):
|
||||
self.state = None
|
||||
self._is_parameter_name_hallucinated()
|
||||
self.parameter_name_done = True
|
||||
self._get_parameter_name()
|
||||
# if the parameter name is done and the token is a parameter name start token, change the state
|
||||
elif self.parameter_name_done and content.endswith(
|
||||
PARAMETER_NAME_START_PATTERN
|
||||
elif (
|
||||
self.parameter_name_done
|
||||
and self.open_bracket == False
|
||||
and content.endswith(PARAMETER_NAME_START_PATTERN)
|
||||
):
|
||||
self.state = "parameter_name"
|
||||
|
||||
|
|
@ -235,24 +283,49 @@ class HallucinationStateHandler:
|
|||
PARAMETER_VALUE_END_TOKEN
|
||||
):
|
||||
# checking if the token is a value token and is not empty
|
||||
if self.tokens[-1].strip() not in ['"', ""]:
|
||||
open_brackets = [
|
||||
char for char in self.tokens[-1].strip() if char in BRACKETS
|
||||
]
|
||||
if open_brackets:
|
||||
self.open_bracket = True
|
||||
self.bracket = open_brackets[0]
|
||||
|
||||
if self.open_bracket and BRACKETS[self.bracket] in self.tokens[-1].strip():
|
||||
self.open_bracket = False
|
||||
self.bracket = None
|
||||
|
||||
if (
|
||||
not all(
|
||||
char in set(string.punctuation) for char in self.tokens[-1].strip()
|
||||
)
|
||||
and self.tokens[-1].strip() != ""
|
||||
):
|
||||
self.mask.append(MaskToken.PARAMETER_VALUE)
|
||||
# checking if the parameter doesn't have default and the token is the first parameter value token
|
||||
|
||||
# checking if the parameter doesn't have enum and the token is the first parameter value token
|
||||
if (
|
||||
len(self.mask) > 1
|
||||
and self.mask[-2] != MaskToken.PARAMETER_VALUE
|
||||
and is_parameter_required(
|
||||
self.function_properties[self.function_name],
|
||||
self.parameter_name[-1],
|
||||
)
|
||||
and not is_parameter_property(
|
||||
self.function_properties[self.function_name],
|
||||
self.parameter_name[-1],
|
||||
"default",
|
||||
"enum",
|
||||
)
|
||||
):
|
||||
self._check_logprob()
|
||||
if self.parameter_name[-1] not in self.check_parameter_name:
|
||||
self._check_logprob()
|
||||
self.check_parameter_name[self.parameter_name[-1]] = True
|
||||
else:
|
||||
self.mask.append(MaskToken.NOT_USED)
|
||||
# if the state is parameter value and the token is an end token, change the state
|
||||
elif self.state == "parameter_value" and content.endswith(
|
||||
PARAMETER_VALUE_END_TOKEN
|
||||
elif (
|
||||
self.state == "parameter_value"
|
||||
and self.open_bracket == False
|
||||
and content.endswith(PARAMETER_VALUE_END_TOKEN)
|
||||
):
|
||||
self.state = None
|
||||
# if the parameter name is done and the token is a parameter value start token, change the state
|
||||
|
|
@ -272,17 +345,16 @@ class HallucinationStateHandler:
|
|||
Detects hallucinations based on entropy and variance of entropy.
|
||||
"""
|
||||
probs = self.logprobs[-1]
|
||||
entropy, varentropy = calculate_entropy(probs)
|
||||
self.token_probs_map.append((self.tokens[-1], entropy, varentropy))
|
||||
entropy, varentropy, probability = calculate_uncertainty(probs)
|
||||
self.token_probs_map.append((self.tokens[-1], entropy, varentropy, probability))
|
||||
|
||||
if check_threshold(
|
||||
entropy, varentropy, HALLUCINATION_THRESHOLD_DICT[self.mask[-1].value]
|
||||
entropy,
|
||||
varentropy,
|
||||
self.HALLUCINATION_THRESHOLD_DICT[self.mask[-1].value],
|
||||
):
|
||||
self.hallucination = True
|
||||
self.error_type = "Hallucination"
|
||||
self.error_message = (
|
||||
f"Hallucination: token '{self.tokens[-1]}' is uncertain."
|
||||
)
|
||||
self.error_message = f"Hallucination: token '{self.tokens[-1]}' is uncertain. {self.token_probs_map}"
|
||||
|
||||
def _count_consecutive_token(self, token=MaskToken.PARAMETER_VALUE) -> int:
|
||||
"""
|
||||
|
|
@ -300,25 +372,23 @@ class HallucinationStateHandler:
|
|||
else 0
|
||||
)
|
||||
|
||||
def _is_function_name_hallucinated(self):
|
||||
def _get_parameter_name(self):
|
||||
"""
|
||||
Checks the extracted function name against the function descriptions.
|
||||
Detects hallucinations if the function name is not found.
|
||||
"""
|
||||
f_len = self._count_consecutive_token(MaskToken.FUNCTION_NAME)
|
||||
self.function_name = "".join(self.tokens[:-1][-f_len:])
|
||||
if self.function_name not in self.function_description.keys():
|
||||
self.error_type = "function_name"
|
||||
self.error_message = f"Function name '{self.function_name}' not found in given function descriptions."
|
||||
Get the parameter name from the tokens.
|
||||
|
||||
def _is_parameter_name_hallucinated(self):
|
||||
"""
|
||||
Checks the extracted parameter name against the function descriptions.
|
||||
Detects hallucinations if the parameter name is not found.
|
||||
Returns:
|
||||
str: The extracted parameter name.
|
||||
"""
|
||||
p_len = self._count_consecutive_token(MaskToken.PARAMETER_NAME)
|
||||
parameter_name = "".join(self.tokens[:-1][-p_len:])
|
||||
self.parameter_name.append(parameter_name)
|
||||
if parameter_name not in self.function_description[self.function_name]:
|
||||
self.error_type = "parameter_name"
|
||||
self.error_message = f"Parameter name '{parameter_name}' not found in given function descriptions."
|
||||
|
||||
def _get_function_name(self):
    """
    Extract the function name from the collected tokens.

    Takes the count reported by ``_count_consecutive_token`` for
    ``MaskToken.FUNCTION_NAME``, joins that many tokens immediately
    preceding the current (last) token, and stores the result in
    ``self.function_name``. Nothing is returned.
    """
    f_len = self._count_consecutive_token(MaskToken.FUNCTION_NAME)
    # Guard the zero case explicitly: ``seq[-0:]`` is ``seq[0:]``, so an
    # unguarded slice would join *every* token into the function name.
    if f_len:
        self.function_name = "".join(self.tokens[:-1][-f_len:])
    else:
        self.function_name = ""
|
||||
181
model_server/src/core/model_utils.py
Normal file
181
model_server/src/core/model_utils.py
Normal file
|
|
@ -0,0 +1,181 @@
|
|||
import json
|
||||
|
||||
from openai import OpenAI
|
||||
from pydantic import BaseModel
|
||||
from typing import Any, Dict, List, Optional
|
||||
from overrides import final
|
||||
|
||||
|
||||
class Message(BaseModel):
    """A single chat message; every field is optional and has a default."""

    # Speaker role; defaults to the empty string.
    role: Optional[str] = ""
    # Text content of the message; defaults to the empty string.
    content: Optional[str] = ""
    # Id of the tool call this message refers to, if any.
    tool_call_id: Optional[str] = ""
    # Tool calls attached to the message; defaults to an empty list.
    tool_calls: Optional[List[Dict[str, Any]]] = []
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
    """Chat request: the conversation so far plus the available tools."""

    # Conversation history; required.
    messages: list[Message]
    # Tool definitions as plain dictionaries; required (may be empty).
    tools: List[Dict[str, Any]]
|
||||
|
||||
|
||||
class Choice(BaseModel):
    """One candidate message inside a chat completion response."""

    # Identifier of the choice; defaults to 0.
    id: Optional[int] = 0
    # The generated message; required.
    message: Message
    # Why generation ended; defaults to "stop".
    finish_reason: Optional[str] = "stop"
|
||||
|
||||
|
||||
class ChatCompletionResponse(BaseModel):
    """Response envelope returned by a handler's chat completion."""

    # Identifier of the response; defaults to 0.
    id: Optional[int] = 0
    # Object type tag; defaults to "chat_completion".
    object: Optional[str] = "chat_completion"
    # Creation timestamp kept as a string; defaults to the empty string.
    created: Optional[str] = ""
    # Generated choices; required.
    choices: List[Choice]
    # Name of the model that produced the response; required.
    model: str
    # Free-form string-to-string metadata; defaults to an empty dict.
    metadata: Optional[Dict[str, str]] = {}
|
||||
|
||||
|
||||
class GuardRequest(BaseModel):
    """Request body for a guardrail check."""

    # Text to be checked; required.
    input: str
    # Name of the guard task to run; required.
    # NOTE(review): the set of accepted task names is not visible here.
    task: str
|
||||
|
||||
|
||||
class GuardResponse(BaseModel):
    """Result of a guardrail check."""

    # NOTE(review): presumably per-sentence probabilities - confirm with the guard handler.
    prob: List
    # Overall boolean outcome of the check.
    verdict: bool
    # NOTE(review): presumably the sentences that were scored - confirm with the guard handler.
    sentence: List
    # Latency of the check; defaults to 0 (unit not stated here).
    latency: float = 0
|
||||
|
||||
|
||||
# ================================================================================================
|
||||
|
||||
|
||||
class ArchBaseHandler:
    """
    Base class for model handlers that talk to an OpenAI-compatible client.

    Subclasses provide `_convert_tools` and `chat_completion`; the prompt
    assembly helpers (`_format_system_prompt`, `_process_messages`) are
    marked final and shared by all handlers.
    """

    def __init__(
        self,
        client: OpenAI,
        model_name: str,
        task_prompt: str,
        tool_prompt_template: str,
        format_prompt: str,
        generation_params: Dict,
    ):
        """
        Initializes the base handler.

        Args:
            client (OpenAI): An OpenAI client instance.
            model_name (str): Name of the model to use.
            task_prompt (str): The main task prompt for the system.
            tool_prompt_template (str): Template describing the tools; it is
                formatted with a `tool_text` field when the system prompt is built.
            format_prompt (str): A prompt specifying the desired output format.
            generation_params (Dict): Generation parameters for the model.
        """
        self.client = client
        self.model_name = model_name

        self.task_prompt = task_prompt
        self.tool_prompt_template = tool_prompt_template
        self.format_prompt = format_prompt

        self.generation_params = generation_params

    def _convert_tools(self, tools: List[Dict[str, Any]]) -> str:
        """
        Converts a list of tools into the desired internal representation.

        Args:
            tools (List[Dict[str, Any]]): A list of tools represented as dictionaries.

        Raises:
            NotImplementedError: Method should be overridden in subclasses.
        """

        raise NotImplementedError()

    @final
    def _format_system_prompt(self, tools: List[Dict[str, Any]]) -> str:
        """
        Formats the system prompt using provided tools.

        The prompt is the task prompt, the tool prompt template filled with
        the converted tools, and the format prompt, separated by blank lines.

        Args:
            tools (List[Dict[str, Any]]): A list of tools represented as dictionaries.

        Returns:
            str: A formatted system prompt.
        """

        tool_text = self._convert_tools(tools)

        system_prompt = (
            self.task_prompt
            + "\n\n"
            + self.tool_prompt_template.format(tool_text=tool_text)
            + "\n\n"
            + self.format_prompt
        )

        return system_prompt

    @final
    def _process_messages(
        self,
        messages: List[Message],
        tools: Optional[List[Dict[str, Any]]] = None,
        extra_instruction: Optional[str] = None,
    ) -> List[Dict[str, Any]]:
        """
        Processes a list of messages and formats them appropriately.

        Messages carrying tool calls become assistant messages wrapping the
        first call in `<tool_call>` tags; messages with role `tool` become
        user messages wrapping the content in `<tool_response>` tags.

        Args:
            messages (List[Message]): A list of message objects.
            tools (List[Dict[str, Any]], optional): A list of tools to include in the system prompt.
            extra_instruction (str, optional): Additional instructions to append to the last user message.

        Returns:
            List[Dict[str, Any]]: A list of processed message dictionaries.

        Raises:
            AssertionError: If there is no message, or the last processed
                message does not have the `user` role.
        """

        processed_messages = []

        if tools:
            processed_messages.append(
                {"role": "system", "content": self._format_system_prompt(tools)}
            )

        for message in messages:
            role, content, tool_calls = (
                message.role,
                message.content,
                message.tool_calls,
            )

            if tool_calls:
                # [TODO] Extend to support multiple function calls
                role = "assistant"
                content = f"<tool_call>\n{json.dumps(tool_calls[0]['function'])}\n</tool_call>"
            elif message.role == "tool":
                role = "user"
                content = (
                    f"<tool_response>\n{json.dumps(message.content)}\n</tool_response>"
                )

            processed_messages.append({"role": role, "content": content})

        # Raised explicitly (rather than via `assert`) so the check still runs
        # under `python -O` and an empty conversation fails with the same
        # exception type instead of an IndexError.
        if not processed_messages or processed_messages[-1]["role"] != "user":
            raise AssertionError("The last processed message must have the `user` role")

        if extra_instruction:
            # `content` is Optional on Message, so tolerate None here.
            last_content = processed_messages[-1]["content"] or ""
            processed_messages[-1]["content"] = last_content + extra_instruction

        return processed_messages

    async def chat_completion(self, req: ChatMessage) -> ChatCompletionResponse:
        """
        Abstract method for generating chat completions.

        Args:
            req (ChatMessage): A chat message request object.

        Raises:
            NotImplementedError: Method should be overridden in subclasses.
        """

        raise NotImplementedError()
|
||||
134
model_server/src/main.py
Normal file
134
model_server/src/main.py
Normal file
|
|
@ -0,0 +1,134 @@
|
|||
import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
|
||||
from src.commons.globals import handler_map
|
||||
from src.core.model_utils import ChatMessage, GuardRequest
|
||||
|
||||
from fastapi import FastAPI, Response
|
||||
from opentelemetry import trace
|
||||
from opentelemetry.sdk.trace import TracerProvider
|
||||
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
|
||||
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
|
||||
from opentelemetry.sdk.trace.export import BatchSpanProcessor
|
||||
from opentelemetry.sdk.resources import Resource
|
||||
|
||||
# Configure logging once at import time: INFO level with a timestamped format.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)

logger = logging.getLogger(__name__)

# Resource attributes attached to the tracer provider for this service.
resource = Resource.create(
    {
        "service.name": "model-server",
    }
)

# Initialize the tracer provider
trace.set_tracer_provider(TracerProvider(resource=resource))
tracer = trace.get_tracer(__name__)


app = FastAPI()

# Instrument the FastAPI application with OpenTelemetry.
FastAPIInstrumentor().instrument_app(app)

# DEFAULT_OTLP_HOST = "http://localhost:4317"
# Fallback endpoint used when the OTLP_HOST environment variable is not set.
DEFAULT_OTLP_HOST = "none"

# Configure the OTLP exporter (Jaeger, Zipkin, etc.)
otlp_exporter = OTLPSpanExporter(
    endpoint=os.getenv("OTLP_HOST", DEFAULT_OTLP_HOST)  # noqa: F821
)

# Export finished spans in batches through the OTLP exporter.
trace.get_tracer_provider().add_span_processor(BatchSpanProcessor(otlp_exporter))
|
||||
|
||||
|
||||
@app.get("/healthz")
async def healthz():
    """Health check endpoint: always answers with an ``ok`` status."""
    status_payload = {"status": "ok"}
    return status_payload
|
||||
|
||||
|
||||
@app.get("/models")
async def models():
    """List every handler registered in ``handler_map`` as a model entry."""
    model_entries = []
    for model_name in handler_map:
        model_entries.append({"id": model_name, "object": "model"})

    return {"object": "list", "data": model_entries}
|
||||
|
||||
|
||||
@app.post("/function_calling")
async def function_calling(req: ChatMessage, res: Response):
    """
    Run intent detection and, when an intent is found, function calling.

    The request is first sent to the `Arch-Intent` handler. If it detects
    an intent, the same request is sent to the `Arch-Function` handler and
    its response is returned with latency and hallucination metadata.
    Otherwise a "No intent matched" result is returned.

    Args:
        req (ChatMessage): The chat request (messages and tools).
        res (Response): Response object used to set the HTTP status on errors.

    Returns:
        The `Arch-Function` response on success, a "No intent matched"
        dictionary when no intent is detected, or an `{"error": ...}`
        dictionary (with a 5xx status code) when a handler fails.
    """
    try:
        intent_start_time = time.perf_counter()
        intent_response = await handler_map["Arch-Intent"].chat_completion(req)
        intent_latency = time.perf_counter() - intent_start_time

        if handler_map["Arch-Intent"].detect_intent(intent_response):
            # [TODO] measure agreement between intent detection and function calling
            try:
                function_start_time = time.perf_counter()
                function_calling_response = await handler_map[
                    "Arch-Function"
                ].chat_completion(req)
                function_latency = time.perf_counter() - function_start_time

                # Latencies are reported in milliseconds, as strings.
                function_calling_response.metadata = {
                    "intent_latency": str(round(intent_latency * 1000, 3)),
                    "function_latency": str(round(function_latency * 1000, 3)),
                    "hallucination": str(handler_map["Arch-Function"].hallucination),
                    "tokens_uncertainty": json.dumps(
                        handler_map["Arch-Function"].hallu_handler.token_probs_map
                    ),
                    "prompt_prefilling": str(
                        handler_map["Arch-Function"].prompt_prefilling
                    ),
                }

                return function_calling_response
            except ValueError as e:
                # Was `res.statuscode`, which only created an unused attribute
                # and left the HTTP status at 200.
                res.status_code = 503
                error_message = "Tool call extraction error"
                logger.error(f" {error_message}: {e}")
                return {"error": f"[Arch-Function] - {error_message} - {e}"}
            except StopIteration as e:
                # NOTE(review): a StopIteration escaping a coroutine is turned
                # into RuntimeError by Python, so this branch may never fire -
                # confirm how `chat_completion` surfaces iterator exhaustion.
                res.status_code = 500
                error_message = "Hallucination iterator error"
                logger.error(f" {error_message}: {e}")
                return {"error": f"[Arch-Function] - {error_message} - {e}"}
            except Exception as e:
                # [TODO] Review: update how to collect debugging outputs
                logger.error(f"Error in chat_completion from `Arch-Function`: {e}")
                res.status_code = 500
                return {"error": f"[Arch-Function] - {e}"}
        # [TODO] Review: define the behavior if `Arch-Intent` doesn't detect an intent
        else:
            return {
                "result": "No intent matched",
                "intent_latency": round(intent_latency * 1000, 3),
            }

    except Exception as e:
        # [TODO] Review: update how to collect debugging outputs
        # Reached when intent detection itself fails.
        logger.error(f"Error in chat_completion /function_calling: {e}")
        res.status_code = 500
        return {"error": f"[Arch-Intent] - {e}"}
|
||||
|
||||
|
||||
@app.post("/guardrails")
async def guardrails(req: GuardRequest, res: Response, max_num_words=300):
    """
    Run the `Arch-Guard` handler on the request and report its latency.

    Args:
        req (GuardRequest): The guard request (input text and task).
        res (Response): Response object used to set the HTTP status on errors.
        max_num_words: Accepted for interface compatibility; not used here.

    Returns:
        A dictionary with the guard result and the latency in milliseconds,
        or an `{"error": ...}` dictionary with status 500 on failure.
    """
    try:
        guard_start_time = time.perf_counter()
        guard_result = handler_map["Arch-Guard"].predict(req)
        guard_latency = time.perf_counter() - guard_start_time
        return {
            "response": guard_result,
            "guard_latency": round(guard_latency * 1000, 3),
        }
    except Exception as e:
        # [TODO] Review: update how to collect debugging outputs
        # Log the failure like the /function_calling endpoint does, so the
        # error is not visible only to the client.
        logger.error(f"Error in predict from `Arch-Guard`: {e}")
        res.status_code = 500
        return {"error": f"[Arch-Guard] - {e}"}
|
||||
0
model_server/tests/core/__init__.py
Normal file
0
model_server/tests/core/__init__.py
Normal file
173
model_server/tests/core/test_function_calling.py
Normal file
173
model_server/tests/core/test_function_calling.py
Normal file
|
|
@ -0,0 +1,173 @@
|
|||
import os
|
||||
|
||||
from src.commons.globals import handler_map
|
||||
from src.core.model_utils import ChatMessage, Message
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from src.main import app
|
||||
from src.commons.globals import handler_map
|
||||
|
||||
# Tool definition shared by every test request in this module: a weather
# lookup with required `location` and `days`, and an optional `unit` that
# has both an enum and a default.
get_weather_api = {
    "type": "function",
    "function": {
        "name": "get_current_weather",
        "description": "Get current weather at a location.",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {
                    "type": "str",
                    "description": "The location to get the weather for",
                    "format": "City, State",
                },
                "unit": {
                    "type": "str",
                    "description": "The unit to return the weather in.",
                    "enum": ["celsius", "fahrenheit"],
                    "default": "celsius",
                },
                "days": {
                    "type": "str",
                    "description": "the number of days for the request.",
                },
            },
            "required": ["location", "days"],
        },
    },
}
|
||||
|
||||
# Each get_*_data helper returns a tuple: (request, intent, hallucination, parameter_gathering)
|
||||
|
||||
|
||||
def get_hallucination_data_complex():
    """Multi-turn weather request; expects intent, hallucination and parameter gathering."""
    conversation = [
        Message(role="user", content="How is the weather in Seattle?"),
        Message(
            role="assistant",
            content="Can you specify the unit you want the weather in?",
        ),
        Message(role="user", content="In celcius please!"),
    ]

    request = ChatMessage(messages=conversation, tools=[get_weather_api])

    return request, True, True, True
|
||||
|
||||
|
||||
def get_hallucination_data_easy():
    """Single weather question; expects intent, hallucination and parameter gathering."""
    question = Message(role="user", content="How is the weather in Seattle?")
    request = ChatMessage(messages=[question], tools=[get_weather_api])

    # The model is expected to hallucinate on this request.
    return request, True, True, True
|
||||
|
||||
|
||||
def get_hallucination_data_medium():
    """Weather question without a location; expects intent, hallucination and parameter gathering."""
    question = Message(role="user", content="How is the weather in?")
    request = ChatMessage(messages=[question], tools=[get_weather_api])

    # The first generated token is not expected to be a tool call.
    return request, True, True, True
|
||||
|
||||
|
||||
def get_complete_data_2():
    """Fully specified forecast request; expects intent only."""
    question = Message(
        role="user",
        content="what is the weather forecast for seattle in the next 10 days?",
    )
    request = ChatMessage(messages=[question], tools=[get_weather_api])

    return request, True, False, False
|
||||
|
||||
|
||||
def get_complete_data():
    """Weather request that names both location and days; expects intent only."""
    question = Message(role="user", content="How is the weather in Seattle in 7 days?")
    request = ChatMessage(messages=[question], tools=[get_weather_api])

    return request, True, False, False
|
||||
|
||||
|
||||
def get_irrelevant_data():
    """Question unrelated to the weather tool; expects no intent."""
    question = Message(role="user", content="What is 1+1?")
    request = ChatMessage(messages=[question], tools=[get_weather_api])

    return request, False, False, False
|
||||
|
||||
|
||||
def get_greeting_data():
    """Plain greeting; expects no intent."""
    greeting = Message(role="user", content="Hello how are you?")
    request = ChatMessage(messages=[greeting], tools=[get_weather_api])

    return request, False, False, False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "get_data_func",
    [
        get_hallucination_data_complex,
        get_hallucination_data_easy,
        get_complete_data,
        get_irrelevant_data,
        get_complete_data_2,
    ],
)
async def test_function_calling(get_data_func):
    """Check intent detection, the hallucination flag and prefilled responses for each data set."""
    req, intent, hallucination, parameter_gathering = get_data_func()

    intent_handler = handler_map["Arch-Intent"]
    intent_response = await intent_handler.chat_completion(req)
    assert intent_handler.detect_intent(intent_response) == intent

    # Function calling is only exercised when an intent is expected.
    if not intent:
        return

    function_handler = handler_map["Arch-Function"]
    function_calling_response = await function_handler.chat_completion(req)
    assert function_handler.hallu_handler.hallucination == hallucination
    response_txt = function_calling_response.choices[0].message.content

    if parameter_gathering:
        prefill_prefix = function_handler.prefill_prefix
        # When gathering parameters the reply must begin with a prefill prefix.
        assert response_txt.startswith(
            tuple(prefill_prefix)
        ), f"Response '{response_txt}' does not start with any of the prefixes: {prefill_prefix}"
|
||||
69
model_server/tests/core/test_guardrails.py
Normal file
69
model_server/tests/core/test_guardrails.py
Normal file
|
|
@ -0,0 +1,69 @@
|
|||
from unittest.mock import patch, MagicMock
|
||||
from src.core.guardrails import get_guardrail_handler
|
||||
|
||||
# Mock constants: model identifier per device type.
# NOTE(review): not referenced by the tests visible in this module - confirm it is still needed.
arch_guard_model_type = {
    "cpu": "katanemo/Arch-Guard-cpu",
    "cuda": "katanemo/Arch-Guard",
    "mps": "katanemo/Arch-Guard",
}
|
||||
|
||||
|
||||
# [TODO] Review: check the following code to test under `cpu`, `cuda`, and `mps`
# Test for `get_guardrail_handler()` function on `cpu`
@patch("src.core.guardrails.AutoTokenizer.from_pretrained")
@patch("src.core.guardrails.AutoModelForSequenceClassification.from_pretrained")
def test_guardrail_handler_on_cpu(mock_auto_model, mock_tokenizer):
    """The handler loads tokenizer and model exactly once for the `cpu` device."""
    device = "cpu"
    mock_tokenizer.return_value = MagicMock()

    guardrail = get_guardrail_handler(device=device)
    model_name = guardrail.model_name

    mock_tokenizer.assert_called_once_with(model_name, trust_remote_code=True)
    mock_auto_model.assert_called_once_with(
        model_name,
        device_map=device,
        low_cpu_mem_usage=True,
    )
|
||||
|
||||
|
||||
# Test for `get_guardrail_handler()` function on `cuda`
@patch("src.core.guardrails.AutoTokenizer.from_pretrained")
@patch("src.core.guardrails.AutoModelForSequenceClassification.from_pretrained")
def test_guardrail_handler_on_cuda(mock_auto_model, mock_tokenizer):
    """The handler loads tokenizer and model exactly once for the `cuda` device."""
    device = "cuda"
    mock_auto_model.return_value = MagicMock()
    mock_tokenizer.return_value = MagicMock()

    guardrail = get_guardrail_handler(device=device)
    model_name = guardrail.model_name

    mock_tokenizer.assert_called_once_with(model_name, trust_remote_code=True)
    mock_auto_model.assert_called_once_with(
        model_name,
        device_map=device,
        low_cpu_mem_usage=True,
    )
|
||||
|
||||
|
||||
# Test for `get_guardrail_handler()` function on `mps`
@patch("src.core.guardrails.AutoTokenizer.from_pretrained")
@patch("src.core.guardrails.AutoModelForSequenceClassification.from_pretrained")
def test_guardrail_handler_on_mps(mock_auto_model, mock_tokenizer):
    """The handler loads tokenizer and model exactly once for the `mps` device."""
    device = "mps"
    mock_auto_model.return_value = MagicMock()
    mock_tokenizer.return_value = MagicMock()

    guardrail = get_guardrail_handler(device=device)
    model_name = guardrail.model_name

    mock_tokenizer.assert_called_once_with(model_name, trust_remote_code=True)
    mock_auto_model.assert_called_once_with(
        model_name,
        device_map=device,
        low_cpu_mem_usage=True,
    )
|
||||
50
model_server/tests/core/test_state.py
Normal file
50
model_server/tests/core/test_state.py
Normal file
|
|
@ -0,0 +1,50 @@
|
|||
from src.commons.globals import handler_map
|
||||
from src.core.function_calling import Message
|
||||
|
||||
|
||||
# Conversation fixture with seven entries: two rounds of
# user question -> assistant tool call -> tool result, with a plain
# assistant reply between the rounds.
test_input_history = [
    {"role": "user", "content": "how is the weather in chicago for next 5 days?"},
    {
        "role": "assistant",
        "model": "Arch-Function",
        "tool_calls": [
            {
                "id": "call_3394",
                "type": "function",
                "function": {
                    "name": "weather_forecast",
                    "arguments": {"city": "Chicago", "days": 5},
                },
            }
        ],
    },
    {"role": "tool", "content": "--", "tool_call_id": "call_3394"},
    {"role": "assistant", "content": "--", "model": "gpt-3.5-turbo-0125"},
    {"role": "user", "content": "how is the weather in chicago for next 5 days?"},
    {
        "role": "assistant",
        "tool_calls": [
            {
                "id": "call_5306",
                "type": "function",
                "function": {
                    "name": "weather_forecast",
                    "arguments": {"city": "Chicago", "days": 5},
                },
            }
        ],
    },
    {"role": "tool", "content": "--", "tool_call_id": "call_5306"},
]
|
||||
|
||||
|
||||
def test_update_fc_history():
    """Processing the fixture history keeps all seven messages and removes the `tool` role."""
    message_history = [Message(**entry) for entry in test_input_history]

    updated_history = handler_map["Arch-Function"]._process_messages(message_history)

    assert len(updated_history) == 7
    # ensure that tool role does not exist anymore
    roles = [entry["role"] for entry in updated_history]
    assert "tool" not in roles
|
||||
53
model_server/tests/test_app.py
Normal file
53
model_server/tests/test_app.py
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
import pytest
|
||||
import httpx
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
from src.main import app
|
||||
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
# [TODO] Review: check the following code; something seems wrong with the asyncio package.
# Unit test for the health check endpoint
@pytest.mark.asyncio
async def test_healthz():
    """`/healthz` answers 200 with an `ok` status body."""
    response = client.get("/healthz")
    status, body = response.status_code, response.json()
    assert status == 200
    assert body == {"status": "ok"}
|
||||
|
||||
|
||||
# [TODO] Review: check the following code; something seems wrong with the asyncio package.
# Unit test for the models endpoint
@pytest.mark.asyncio
async def test_models():
    """`/models` answers 200 with a non-empty model list."""
    response = client.get("/models")
    assert response.status_code == 200

    body = response.json()
    assert body["object"] == "list"
    assert len(body["data"]) > 0
|
||||
|
||||
|
||||
# [TODO] Review: check the following code; something seems wrong with the asyncio package.
# Unit test for the guardrail endpoint
@pytest.mark.asyncio
async def test_guardrail_endpoint():
    """`/guardrails` answers 200 and includes a `response` field."""
    response = client.post(
        "/guardrails",
        json={"input": "Test for jailbreak and toxicity", "task": "jailbreak"},
    )
    assert response.status_code == 200
    assert "response" in response.json()
|
||||
|
||||
|
||||
# [TODO] Review: check the following code; something seems wrong with the asyncio package.
# Unit test for the function calling endpoint
@pytest.mark.asyncio
async def test_function_calling_endpoint():
    """`/function_calling` answers 200 with a `result` field for a plain greeting."""
    # Route requests to the app through an explicit ASGI transport; the
    # `AsyncClient(app=...)` shortcut is deprecated and removed in newer httpx.
    transport = httpx.ASGITransport(app=app)
    # Named `async_client` so the module-level `client` is not shadowed.
    async with httpx.AsyncClient(
        transport=transport, base_url="http://test"
    ) as async_client:
        request_data = {
            "messages": [{"role": "user", "content": "Hello!"}],
            "model": "Arch-Function",
            "tools": [],
            "metadata": {"x-arch-state": "[]"},
        }
        response = await async_client.post("/function_calling", json=request_data)
        assert response.status_code == 200
        assert "result" in response.json()
|
||||
Loading…
Add table
Add a link
Reference in a new issue