mirror of
https://github.com/katanemo/plano.git
synced 2026-04-29 10:56:35 +02:00
Some fixes on model server (#362)
* Some fixes on model server * Remove prompt_prefilling message * Fix logging * Fix poetry issues * Improve logging and update the support for text truncation * Fix tests * Fix tests * Fix tests * Fix modelserver tests * Update modelserver tests
This commit is contained in:
parent
ebda682b30
commit
88a02dc478
25 changed files with 1090 additions and 1666 deletions
|
|
@ -1,22 +1,17 @@
|
|||
import importlib
|
||||
import logging
|
||||
from os import path
|
||||
import os
|
||||
from signal import SIGKILL
|
||||
import sys
|
||||
import subprocess
|
||||
import argparse
|
||||
import signal
|
||||
import tempfile
|
||||
import time
|
||||
|
||||
import requests
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
)
|
||||
import src.commons.utils as utils
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
logger = utils.get_model_server_logger()
|
||||
|
||||
|
||||
def get_version():
|
||||
|
|
@ -42,76 +37,9 @@ def wait_for_health_check(url, timeout=300):
|
|||
return False
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Manage the Uvicorn server.")
|
||||
parser.add_argument(
|
||||
"action",
|
||||
choices=["start", "stop", "restart"],
|
||||
default="start",
|
||||
nargs="?",
|
||||
help="Action to perform on the server (default: start).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--port",
|
||||
type=int,
|
||||
default=51000,
|
||||
help="Port number for the server (default: 51000).",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--foreground",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Run the server in the foreground (default: False).",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def get_pid_file():
|
||||
temp_dir = tempfile.gettempdir()
|
||||
return path.join(temp_dir, "model_server.pid")
|
||||
|
||||
|
||||
def stop_server():
|
||||
"""Stop the Uvicorn server."""
|
||||
pid_file = get_pid_file()
|
||||
if os.path.exists(pid_file):
|
||||
logger.info(f"PID file found, shutting down the server.")
|
||||
# read pid from file
|
||||
with open(pid_file, "r") as f:
|
||||
pid = int(f.read())
|
||||
logger.info(f"Killing model server {pid}")
|
||||
try:
|
||||
os.kill(pid, SIGKILL)
|
||||
except ProcessLookupError:
|
||||
logger.info(f"Process {pid} not found")
|
||||
os.remove(pid_file)
|
||||
else:
|
||||
logger.info("No PID file found, server is not running.")
|
||||
|
||||
|
||||
def restart_server(port=51000, foreground=False):
|
||||
"""Restart the Uvicorn server."""
|
||||
stop_server()
|
||||
start_server(port, foreground)
|
||||
|
||||
|
||||
def run_server():
|
||||
"""Start, stop, or restart the Uvicorn server based on command-line arguments."""
|
||||
|
||||
args = parse_args()
|
||||
action = args.action
|
||||
|
||||
if action == "start":
|
||||
start_server(args.port, args.foreground)
|
||||
elif action == "stop":
|
||||
stop_server()
|
||||
elif action == "restart":
|
||||
restart_server(args.port, args.foreground)
|
||||
else:
|
||||
logger.info(f"Unknown action: {action}")
|
||||
sys.exit(1)
|
||||
return os.path.join(temp_dir, "model_server.pid")
|
||||
|
||||
|
||||
def ensure_killed(process):
|
||||
|
|
@ -131,7 +59,7 @@ def ensure_killed(process):
|
|||
def start_server(port=51000, foreground=False):
|
||||
"""Start the Uvicorn server."""
|
||||
|
||||
logging.info("model server version: %s", get_version())
|
||||
logger.info("model server version: %s", get_version())
|
||||
|
||||
stop_server()
|
||||
|
||||
|
|
@ -196,6 +124,57 @@ def start_server(port=51000, foreground=False):
|
|||
ensure_killed(process)
|
||||
|
||||
|
||||
def stop_server():
|
||||
"""Stop the Uvicorn server."""
|
||||
|
||||
pid_file = get_pid_file()
|
||||
if os.path.exists(pid_file):
|
||||
logger.info("PID file found, shutting down the server.")
|
||||
# read pid from file
|
||||
with open(pid_file, "r") as f:
|
||||
pid = int(f.read())
|
||||
logger.info(f"Killing model server {pid}")
|
||||
try:
|
||||
os.kill(pid, signal.SIGKILL)
|
||||
except ProcessLookupError:
|
||||
logger.info(f"Process {pid} not found")
|
||||
os.remove(pid_file)
|
||||
else:
|
||||
logger.info("No PID file found, server is not running.")
|
||||
|
||||
|
||||
def restart_server(port=51000, foreground=False):
|
||||
"""Restart the Uvicorn server."""
|
||||
stop_server()
|
||||
start_server(port, foreground)
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Manage the Uvicorn server.")
|
||||
parser.add_argument(
|
||||
"action",
|
||||
choices=["start", "stop", "restart"],
|
||||
default="start",
|
||||
nargs="?",
|
||||
help="Action to perform on the server (default: start).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--port",
|
||||
type=int,
|
||||
default=51000,
|
||||
help="Port number for the server (default: 51000).",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--foreground",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Run the server in the foreground (default: False).",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
"""
|
||||
Start, stop, or restart the Uvicorn server based on command-line arguments.
|
||||
|
|
@ -204,11 +183,14 @@ def main():
|
|||
args = parse_args()
|
||||
|
||||
if args.action == "start":
|
||||
logger.info("[CLI] - Starting server")
|
||||
start_server(args.port, args.foreground)
|
||||
elif args.action == "stop":
|
||||
logger.info("[CLI] - Stopping server")
|
||||
stop_server()
|
||||
elif args.action == "restart":
|
||||
logger.info("[CLI] - Restarting server")
|
||||
restart_server(args.port)
|
||||
else:
|
||||
logger.error(f"Unknown action: {args.action}")
|
||||
logger.error(f"[CLI] - Unknown action: {args.action}")
|
||||
sys.exit(1)
|
||||
|
|
|
|||
|
|
@ -22,9 +22,7 @@ ARCH_CLIENT = OpenAI(base_url=ARCH_ENDPOINT, api_key=ARCH_API_KEY)
|
|||
# Define model names
|
||||
ARCH_INTENT_MODEL_ALIAS = "Arch-Intent"
|
||||
ARCH_FUNCTION_MODEL_ALIAS = "Arch-Function"
|
||||
|
||||
logger.info("loading prompt guard model ...")
|
||||
arch_guard_model = get_guardrail_handler()
|
||||
ARCH_GUARD_MODEL_ALIAS = "katanemo/Arch-Guard"
|
||||
|
||||
# Define model handlers
|
||||
handler_map = {
|
||||
|
|
@ -34,5 +32,5 @@ handler_map = {
|
|||
"Arch-Function": ArchFunctionHandler(
|
||||
ARCH_CLIENT, ARCH_FUNCTION_MODEL_ALIAS, ArchFunctionConfig
|
||||
),
|
||||
"Arch-Guard": arch_guard_model,
|
||||
"Arch-Guard": get_guardrail_handler(ARCH_GUARD_MODEL_ALIAS),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,87 +1,50 @@
|
|||
import os
|
||||
import sys
|
||||
import time
|
||||
import torch
|
||||
import logging
|
||||
import requests
|
||||
import subprocess
|
||||
import importlib
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
PROJ_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
|
||||
# Default log directory and file
|
||||
DEFAULT_LOG_DIR = os.path.join(PROJ_DIR, ".logs")
|
||||
DEFAULT_LOG_FILE = "modelserver.log"
|
||||
|
||||
|
||||
def get_model_server_logger(log_dir=None, log_file=None):
|
||||
def get_model_server_logger():
|
||||
"""
|
||||
Get or initialize the logger instance for the model server.
|
||||
|
||||
Parameters:
|
||||
- log_dir (str): Custom directory to store the log file. Defaults to `./.logs`.
|
||||
- log_file (str): Custom log file name. Defaults to `modelserver.log`.
|
||||
|
||||
Returns:
|
||||
- logging.Logger: Configured logger instance.
|
||||
"""
|
||||
log_dir = log_dir or DEFAULT_LOG_DIR
|
||||
log_file = log_file or DEFAULT_LOG_FILE
|
||||
log_file_path = os.path.join(log_dir, log_file)
|
||||
|
||||
# Check if the logger is already configured
|
||||
logger = logging.getLogger("model_server_logger")
|
||||
logger = logging.getLogger("model_server")
|
||||
|
||||
# Return existing logger instance if already configured
|
||||
if logger.hasHandlers():
|
||||
# Return existing logger instance if already configured
|
||||
return logger
|
||||
|
||||
# Ensure the log directory exists, create it if necessary
|
||||
try:
|
||||
# Create directory if it doesn't exist
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
|
||||
# Check for write permissions
|
||||
if not os.access(log_dir, os.W_OK):
|
||||
raise PermissionError(f"No write permission for the directory: {log_dir}")
|
||||
except (PermissionError, OSError) as e:
|
||||
raise RuntimeError(f"Failed to initialize logger: {e}")
|
||||
|
||||
# Configure logging to file
|
||||
# Configure logging to only log to console
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s - %(levelname)s - %(message)s",
|
||||
handlers=[
|
||||
# logging.FileHandler(log_file_path, mode="w"), # Overwrite logs in the file
|
||||
logging.StreamHandler(), # Also log to console
|
||||
],
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
handlers=[logging.StreamHandler()],
|
||||
)
|
||||
|
||||
return logger
|
||||
|
||||
|
||||
logger = get_model_server_logger()
|
||||
|
||||
logging.info("initializing torch device ...")
|
||||
import torch
|
||||
|
||||
|
||||
def get_device():
|
||||
available_device = {
|
||||
"cpu": True,
|
||||
"cuda": torch.cuda.is_available(),
|
||||
"mps": (
|
||||
torch.backends.mps.is_available()
|
||||
if hasattr(torch.backends, "mps")
|
||||
else False
|
||||
),
|
||||
}
|
||||
|
||||
if available_device["cuda"]:
|
||||
if torch.cuda.is_available():
|
||||
device = "cuda"
|
||||
elif available_device["mps"]:
|
||||
elif torch.backends.mps.is_available():
|
||||
device = "mps"
|
||||
else:
|
||||
device = "cpu"
|
||||
|
||||
return device
|
||||
|
||||
|
||||
def get_today_date():
|
||||
# Get today's date
|
||||
today = datetime.now()
|
||||
|
||||
# Get full date with day of week
|
||||
full_date = today.strftime("%Y-%m-%d")
|
||||
|
||||
return full_date
|
||||
|
|
|
|||
|
|
@ -1,22 +1,24 @@
|
|||
import ast
|
||||
import json
|
||||
import random
|
||||
import builtins
|
||||
import textwrap
|
||||
import src.commons.utils as utils
|
||||
|
||||
from openai import OpenAI
|
||||
from typing import Any, Dict, List
|
||||
from overrides import override
|
||||
from src.commons.utils import get_model_server_logger
|
||||
from src.core.model_utils import (
|
||||
from src.core.utils.hallucination_utils import HallucinationState
|
||||
from src.core.utils.model_utils import (
|
||||
Message,
|
||||
ChatMessage,
|
||||
Choice,
|
||||
ChatCompletionResponse,
|
||||
ArchBaseHandler,
|
||||
)
|
||||
from src.core.hallucination import HallucinationStateHandler
|
||||
|
||||
logger = get_model_server_logger()
|
||||
|
||||
logger = utils.get_model_server_logger()
|
||||
|
||||
|
||||
class ArchIntentConfig:
|
||||
|
|
@ -74,7 +76,6 @@ class ArchIntentHandler(ArchBaseHandler):
|
|||
)
|
||||
|
||||
self.extra_instruction = config.EXTRA_INSTRUCTION
|
||||
self.prompt_prefilling = False
|
||||
|
||||
@override
|
||||
def _convert_tools(self, tools: List[Dict[str, Any]]) -> str:
|
||||
|
|
@ -122,15 +123,19 @@ class ArchIntentHandler(ArchBaseHandler):
|
|||
Note:
|
||||
Currently only support vllm inference
|
||||
"""
|
||||
logger.info("[Arch-Intent] - ChatCompletion")
|
||||
|
||||
# In the case that no tools are available, simply return `No` to avoid making a call
|
||||
if len(req.tools) == 0:
|
||||
model_response = Message(content="No", tool_calls=[])
|
||||
logger.info("No tools found, return `No` as the model response.")
|
||||
else:
|
||||
messages = self._process_messages(
|
||||
req.messages, req.tools, self.extra_instruction
|
||||
)
|
||||
|
||||
logger.info(f"[request]: {json.dumps(messages)}")
|
||||
|
||||
model_response = self.client.chat.completions.create(
|
||||
messages=messages,
|
||||
model=self.model_name,
|
||||
|
|
@ -138,9 +143,7 @@ class ArchIntentHandler(ArchBaseHandler):
|
|||
extra_body=self.generation_params,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"arch_intent response: %s", json.dumps(model_response.model_dump())
|
||||
)
|
||||
logger.info(f"[response]: {json.dumps(model_response.model_dump())}")
|
||||
|
||||
model_response = Message(
|
||||
content=model_response.choices[0].message.content, tool_calls=[]
|
||||
|
|
@ -160,7 +163,11 @@ class ArchFunctionConfig:
|
|||
TASK_PROMPT = textwrap.dedent(
|
||||
"""
|
||||
You are a helpful assistant.
|
||||
"""
|
||||
|
||||
Today's date: {}
|
||||
""".format(
|
||||
utils.get_today_date()
|
||||
)
|
||||
).strip()
|
||||
|
||||
TOOL_PROMPT_TEMPLATE = textwrap.dedent(
|
||||
|
|
@ -189,7 +196,7 @@ class ArchFunctionConfig:
|
|||
"temperature": 0.6,
|
||||
"top_p": 1.0,
|
||||
"top_k": 10,
|
||||
"max_tokens": 512,
|
||||
"max_tokens": 1024,
|
||||
"stop_token_ids": [151645],
|
||||
"logprobs": True,
|
||||
"top_logprobs": 10,
|
||||
|
|
@ -241,10 +248,11 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
|
||||
self.prefill_params = config.PREFILL_CONFIG["prefill_params"]
|
||||
self.prefill_prefix = config.PREFILL_CONFIG["prefill_prefix"]
|
||||
self.prompt_prefilling = False
|
||||
|
||||
self.hallucination_state = None
|
||||
|
||||
# Predefine data types for verification. Only support Python for now.
|
||||
# [TODO] Extend the list of support data types
|
||||
# TODO: Extend the list of support data types
|
||||
self.support_data_types = {
|
||||
type_name: getattr(builtins, type_name)
|
||||
for type_name in config.SUPPORT_DATA_TYPES
|
||||
|
|
@ -365,15 +373,15 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
|
||||
return {"result": tool_calls, "status": is_valid, "message": error_message}
|
||||
|
||||
def _correcting_type(self, value, target_type):
|
||||
def _convert_data_type(self, value: str, target_type: str):
|
||||
# TODO: Add more conversion rules as needed
|
||||
try:
|
||||
if target_type == float and isinstance(value, int):
|
||||
if target_type is float and isinstance(value, int):
|
||||
return float(value)
|
||||
elif target_type == list and isinstance(value, str):
|
||||
elif target_type is list and isinstance(value, str):
|
||||
return ast.literal_eval(value)
|
||||
elif target_type == str and not isinstance(value, str):
|
||||
elif target_type is str and not isinstance(value, str):
|
||||
return str(value)
|
||||
# Add more conversion rules as needed
|
||||
except (ValueError, TypeError, json.JSONDecodeError):
|
||||
pass
|
||||
return value
|
||||
|
|
@ -426,32 +434,34 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
break
|
||||
|
||||
# Verify the data type of each parameter in the tool calls
|
||||
function_properties = functions[func_name]["properties"]
|
||||
|
||||
for param_name in func_args:
|
||||
if param_name not in functions[func_name]["properties"]:
|
||||
if param_name not in function_properties:
|
||||
is_valid = False
|
||||
invalid_tool_call = tool_call
|
||||
error_message = f"Parameter `{param_name}` is not defined in the function `{func_name}`."
|
||||
break
|
||||
else:
|
||||
param_value = func_args[param_name]
|
||||
data_type = functions[func_name]["properties"][param_name][
|
||||
"type"
|
||||
]
|
||||
target_type = function_properties[param_name]["type"]
|
||||
|
||||
if data_type in self.support_data_types:
|
||||
if not isinstance(
|
||||
param_value,
|
||||
self.support_data_types[data_type],
|
||||
) and not isinstance(
|
||||
self._correcting_type(
|
||||
param_value, self.support_data_types[data_type]
|
||||
),
|
||||
self.support_data_types[data_type],
|
||||
):
|
||||
is_valid = False
|
||||
invalid_tool_call = tool_call
|
||||
error_message = f"Parameter `{param_name}` is expected to have the data type `{self.support_data_types[data_type]}`, but got `{type(param_value)}`."
|
||||
break
|
||||
if target_type in self.support_data_types:
|
||||
data_type = self.support_data_types[target_type]
|
||||
|
||||
if not isinstance(param_value, data_type):
|
||||
param_value = self._convert_data_type(
|
||||
param_value, data_type
|
||||
)
|
||||
if not isinstance(param_value, data_type):
|
||||
is_valid = False
|
||||
invalid_tool_call = tool_call
|
||||
error_message = f"Parameter `{param_name}` is expected to have the data type `{data_type}`, got `{type(param_value)}`."
|
||||
break
|
||||
else:
|
||||
error_message = (
|
||||
f"Data type `{target_type}` is not supported."
|
||||
)
|
||||
|
||||
return {
|
||||
"status": is_valid,
|
||||
|
|
@ -491,51 +501,8 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
**self.prefill_params,
|
||||
},
|
||||
)
|
||||
self.prompt_prefilling = True
|
||||
return prefill_response
|
||||
|
||||
def _check_length_and_pop_messages(self, messages, max_tokens=4096):
|
||||
"""
|
||||
Trims the `messages` list to ensure the total token count does not exceed `max_tokens`.
|
||||
|
||||
Args:
|
||||
messages (list): List of message dictionaries.
|
||||
max_tokens (int): Maximum allowed token count.
|
||||
|
||||
Returns:
|
||||
list: Trimmed list of messages.
|
||||
"""
|
||||
|
||||
def estimate_token_length(messages):
|
||||
"""Estimate the total token length of the messages."""
|
||||
total_tokens = 0
|
||||
for message in messages:
|
||||
# Approximate token length: assuming ~4 characters per token on average
|
||||
total_tokens += len(message["content"]) // 4
|
||||
return total_tokens
|
||||
|
||||
# Calculate initial token length
|
||||
total_tokens = estimate_token_length(messages)
|
||||
|
||||
# Trim messages if token count exceeds the limit
|
||||
while total_tokens > max_tokens and len(messages) >= 3:
|
||||
# Find the first non-system message pair
|
||||
for i in range(len(messages)):
|
||||
if messages[i]["role"] != "system":
|
||||
# Remove the 'user'/'assistant' pair
|
||||
if i + 1 < len(messages) and messages[i + 1]["role"] in [
|
||||
"user",
|
||||
"assistant",
|
||||
]:
|
||||
del messages[i : i + 2]
|
||||
else:
|
||||
del messages[i]
|
||||
break
|
||||
# Recalculate token length
|
||||
total_tokens = estimate_token_length(messages)
|
||||
|
||||
return messages
|
||||
|
||||
@override
|
||||
async def chat_completion(self, req: ChatMessage) -> ChatCompletionResponse:
|
||||
"""
|
||||
|
|
@ -550,13 +517,11 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
Note:
|
||||
Currently only support vllm inference
|
||||
"""
|
||||
|
||||
logger.info(
|
||||
f"model_server => arch_function: request body: {json.dumps(req.model_dump())}"
|
||||
)
|
||||
logger.info("[Arch-Function] - ChatCompletion")
|
||||
|
||||
messages = self._process_messages(req.messages, req.tools)
|
||||
messages = self._check_length_and_pop_messages(messages)
|
||||
|
||||
logger.info(f"[request]: {json.dumps(messages)}")
|
||||
|
||||
# always enable `stream=True` to collect model responses
|
||||
response = self.client.chat.completions.create(
|
||||
|
|
@ -567,45 +532,39 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
)
|
||||
|
||||
# initialize the hallucination handler, which is an iterator
|
||||
self.hallu_handler = HallucinationStateHandler(
|
||||
self.hallucination_state = HallucinationState(
|
||||
response_iterator=response, function=req.tools
|
||||
)
|
||||
|
||||
model_response, self.has_tool_call = "", None
|
||||
self.hallucination = False
|
||||
for _ in self.hallu_handler:
|
||||
model_response = ""
|
||||
|
||||
has_tool_calls, has_hallucination = None, False
|
||||
for _ in self.hallucination_state:
|
||||
# check if the first token is <tool_call>
|
||||
if len(self.hallu_handler.tokens) > 0 and self.has_tool_call is None:
|
||||
if self.hallu_handler.tokens[0] == "<tool_call>":
|
||||
self.has_tool_call = True
|
||||
if len(self.hallucination_state.tokens) > 0 and has_tool_calls is None:
|
||||
if self.hallucination_state.tokens[0] == "<tool_call>":
|
||||
has_tool_calls = True
|
||||
else:
|
||||
self.has_tool_call = False
|
||||
has_tool_calls = False
|
||||
break
|
||||
|
||||
# if the model is hallucinating, start parameter gathering
|
||||
if self.hallu_handler.hallucination is True:
|
||||
self.hallucination = True
|
||||
logger.info(
|
||||
f"{self.hallu_handler.error_message} - start parameter gathering"
|
||||
)
|
||||
logger.info(
|
||||
f"Hallucinated response : {''.join(self.hallu_handler.tokens)}"
|
||||
)
|
||||
# [TODO] - add break when hallucination is detected
|
||||
if self.hallucination_state.hallucination is True:
|
||||
has_hallucination = True
|
||||
break
|
||||
if self.hallucination is True:
|
||||
prefill_response = self._engage_parameter_gathering(messages)
|
||||
model_response = prefill_response.choices[0].message.content
|
||||
|
||||
if self.has_tool_call and self.hallucination is False:
|
||||
# [TODO] - Review: remove the following code
|
||||
|
||||
model_response = "".join(self.hallu_handler.tokens)
|
||||
logger.info(f"Tool call found, no hallucination detected {model_response}!")
|
||||
# start parameter gathering if the model is not generating tool calls
|
||||
if self.has_tool_call is False:
|
||||
# [TODO] - Review: remove the following code
|
||||
logger.info("No tool call found, start parameter gathering")
|
||||
if has_tool_calls:
|
||||
if has_hallucination:
|
||||
# start prompt prefilling if hallcuination is found in tool calls
|
||||
logger.info(
|
||||
f"[Hallucination]: {self.hallucination_state.error_message}"
|
||||
)
|
||||
prefill_response = self._engage_parameter_gathering(messages)
|
||||
model_response = prefill_response.choices[0].message.content
|
||||
else:
|
||||
model_response = "".join(self.hallucination_state.tokens)
|
||||
else:
|
||||
# start parameter gathering if the model is not generating tool calls
|
||||
prefill_response = self._engage_parameter_gathering(messages)
|
||||
model_response = prefill_response.choices[0].message.content
|
||||
|
||||
|
|
@ -613,21 +572,20 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
extracted = self._extract_tool_calls(model_response)
|
||||
|
||||
if len(extracted["result"]) and extracted["status"]:
|
||||
# [TODO] Review: define the behavior in the case that tool call extraction fails
|
||||
# if not extracted["status"]:
|
||||
verified = self._verify_tool_calls(
|
||||
tools=req.tools, tool_calls=extracted["result"]
|
||||
)
|
||||
# [TODO] - Review: remvoe the following code
|
||||
# print(f"[Verified] - {verified}")
|
||||
|
||||
# [TODO] Review: In the case that tool calls are invalid, define the protocol to collect debugging output and the behavior to handle it appropriately
|
||||
if verified["status"]:
|
||||
logger.info(
|
||||
f"[Tool calls]: {json.dumps([tool_call['function'] for tool_call in extracted['result']])}"
|
||||
)
|
||||
model_response = Message(content="", tool_calls=extracted["result"])
|
||||
log_message = f"model_server <= arch_function: (tool_calls): {json.dumps([tool_call['function'] for tool_call in extracted['result']])}"
|
||||
logger.info(log_message)
|
||||
else:
|
||||
raise ValueError(f"Invalid tool call: {verified['message']}")
|
||||
logger.error(f"Invalid tool call - {verified['message']}")
|
||||
# raise ValueError(
|
||||
# f"[Arch-Function]: Invalid tool call - {verified['message']}"
|
||||
# )
|
||||
else:
|
||||
model_response = Message(content=model_response, tool_calls=[])
|
||||
|
||||
|
|
@ -635,10 +593,6 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
choices=[Choice(message=model_response)], model=self.model_name
|
||||
)
|
||||
|
||||
# [TODO] Review: define the protocol to collect debugging output
|
||||
|
||||
logger.info(
|
||||
f"model_server <= arch_function: response body: {json.dumps(chat_completion_response.model_dump())}"
|
||||
)
|
||||
logger.info(f"[response]: {json.dumps(chat_completion_response.model_dump())}")
|
||||
|
||||
return chat_completion_response
|
||||
|
|
|
|||
|
|
@ -1,12 +1,12 @@
|
|||
import time
|
||||
import torch
|
||||
import numpy as np
|
||||
import src.commons.utils as utils
|
||||
from transformers import AutoTokenizer
|
||||
from src.core.model_utils import GuardRequest, GuardResponse
|
||||
|
||||
# from optimum.intel import OVModelForSequenceClassification
|
||||
from transformers import AutoModelForSequenceClassification
|
||||
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
||||
from src.core.utils.model_utils import GuardRequest, GuardResponse
|
||||
|
||||
|
||||
logger = utils.get_model_server_logger()
|
||||
|
||||
|
||||
class ArchGuardHanlder:
|
||||
|
|
@ -76,26 +76,15 @@ class ArchGuardHanlder:
|
|||
text, truncation=True, max_length=max_length, return_tensors="pt"
|
||||
).to(self.device)
|
||||
|
||||
start_time = time.perf_counter()
|
||||
|
||||
with torch.no_grad():
|
||||
logits = self.model(**inputs).logits.cpu().detach().numpy()[0]
|
||||
prob = ArchGuardHanlder.softmax(logits)[
|
||||
self.support_tasks[task]["positive_class"]
|
||||
]
|
||||
].item()
|
||||
|
||||
latency = time.perf_counter() - start_time
|
||||
verdict = prob > self.support_tasks[task]["threshold"]
|
||||
|
||||
if prob > self.support_tasks[task]["threshold"]:
|
||||
verdict = True
|
||||
sentence = text
|
||||
else:
|
||||
verdict = False
|
||||
sentence = None
|
||||
|
||||
return GuardResponse(
|
||||
prob=[prob.item()], verdict=verdict, sentence=[sentence], latency=latency
|
||||
)
|
||||
return GuardResponse(task=task, input=text, prob=prob, verdict=verdict)
|
||||
|
||||
def predict(self, req: GuardRequest, max_num_words=300) -> GuardResponse:
|
||||
"""
|
||||
|
|
@ -115,29 +104,37 @@ class ArchGuardHanlder:
|
|||
if req.task not in self.support_tasks:
|
||||
raise NotImplementedError(f"{req.task} is not supported!")
|
||||
|
||||
logger.info("[Arch-Guard] - Prediction")
|
||||
logger.info(f"[request]: {req.input}")
|
||||
|
||||
if len(req.input.split()) < max_num_words:
|
||||
return self._predict_text(req.task, req.input)
|
||||
result = self._predict_text(req.task, req.input)
|
||||
else:
|
||||
prob, verdict = 0.0, False
|
||||
|
||||
# split into chunks if text is long
|
||||
text_chunks = self._split_text_into_chunks(req.input)
|
||||
|
||||
prob, verdict, sentence, latency = [], False, [], 0
|
||||
|
||||
for chunk in text_chunks:
|
||||
chunk_result = self._predict_text(req.task, chunk)
|
||||
|
||||
if chunk_result.verdict:
|
||||
prob.append(chunk_result.prob[0])
|
||||
prob = chunk_result.prob
|
||||
verdict = True
|
||||
sentence.append(chunk_result.sentence[0])
|
||||
latency += chunk_result.latency
|
||||
break
|
||||
|
||||
return GuardResponse(
|
||||
prob=prob, verdict=verdict, sentence=sentence, latency=latency
|
||||
result = GuardResponse(
|
||||
task=req.task, input=req.input, prob=prob, verdict=verdict
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"[response]: {req.task}: {'True' if result.verdict else 'False'} (prob: {result.prob:.2f})"
|
||||
)
|
||||
|
||||
def get_guardrail_handler(device: str = None):
|
||||
return result
|
||||
|
||||
|
||||
def get_guardrail_handler(model_name: str = "katanemo/Arch-Guard", device: str = None):
|
||||
"""
|
||||
Initializes and returns an instance of ArchGuardHanlder based on the specified device.
|
||||
|
||||
|
|
@ -151,19 +148,11 @@ def get_guardrail_handler(device: str = None):
|
|||
if device is None:
|
||||
device = utils.get_device()
|
||||
|
||||
model_class, model_name = None, None
|
||||
# if device == "cpu":
|
||||
# model_class = OVModelForSequenceClassification
|
||||
# model_name = "katanemo/Arch-Guard-cpu"
|
||||
# else:
|
||||
model_class = AutoModelForSequenceClassification
|
||||
model_name = "katanemo/Arch-Guard"
|
||||
|
||||
guardrail_dict = {
|
||||
"device": device,
|
||||
"model_name": model_name,
|
||||
"tokenizer": AutoTokenizer.from_pretrained(model_name, trust_remote_code=True),
|
||||
"model": model_class.from_pretrained(
|
||||
"model": AutoModelForSequenceClassification.from_pretrained(
|
||||
model_name, device_map=device, low_cpu_mem_usage=True
|
||||
),
|
||||
}
|
||||
|
|
|
|||
0
model_server/src/core/utils/__init__.py
Normal file
0
model_server/src/core/utils/__init__.py
Normal file
|
|
@ -127,7 +127,7 @@ def is_parameter_property(
|
|||
return property_name in parameter_info
|
||||
|
||||
|
||||
class HallucinationStateHandler:
|
||||
class HallucinationState:
|
||||
"""
|
||||
A class to handle the state of hallucination detection in token processing.
|
||||
|
||||
|
|
@ -145,7 +145,7 @@ class HallucinationStateHandler:
|
|||
|
||||
def __init__(self, response_iterator=None, function=None):
|
||||
"""
|
||||
Initializes the HallucinationStateHandler with default values.
|
||||
Initializes the HallucinationState with default values.
|
||||
"""
|
||||
self.tokens: List[str] = []
|
||||
self.logprobs: List[float] = []
|
||||
|
|
@ -173,7 +173,7 @@ class HallucinationStateHandler:
|
|||
|
||||
def _reset_parameters(self):
|
||||
"""
|
||||
Resets all parameters in the HallucinationStateHandler to their default values.
|
||||
Resets all parameters in the HallucinationState to their default values.
|
||||
"""
|
||||
self.state = None
|
||||
self.parameter_name_done = False
|
||||
|
|
@ -268,7 +268,7 @@ class HallucinationStateHandler:
|
|||
# if the parameter name is done and the token is a parameter name start token, change the state
|
||||
elif (
|
||||
self.parameter_name_done
|
||||
and self.open_bracket == False
|
||||
and not self.open_bracket
|
||||
and content.endswith(PARAMETER_NAME_START_PATTERN)
|
||||
):
|
||||
self.state = "parameter_name"
|
||||
|
|
@ -324,7 +324,7 @@ class HallucinationStateHandler:
|
|||
# if the state is parameter value and the token is an end token, change the state
|
||||
elif (
|
||||
self.state == "parameter_value"
|
||||
and self.open_bracket == False
|
||||
and not self.open_bracket
|
||||
and content.endswith(PARAMETER_VALUE_END_TOKEN)
|
||||
):
|
||||
self.state = None
|
||||
|
|
@ -354,7 +354,7 @@ class HallucinationStateHandler:
|
|||
self.HALLUCINATION_THRESHOLD_DICT[self.mask[-1].value],
|
||||
):
|
||||
self.hallucination = True
|
||||
self.error_message = f"Hallucination: token '{self.tokens[-1]}' is uncertain. {self.token_probs_map}"
|
||||
self.error_message = f"token '{self.tokens[-1]}' is uncertain. Generated response:\n{''.join(self.tokens)}"
|
||||
|
||||
def _count_consecutive_token(self, token=MaskToken.PARAMETER_VALUE) -> int:
|
||||
"""
|
||||
|
|
@ -14,8 +14,8 @@ class Message(BaseModel):
|
|||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
messages: list[Message]
|
||||
tools: List[Dict[str, Any]]
|
||||
messages: List[Message] = []
|
||||
tools: List[Dict[str, Any]] = []
|
||||
|
||||
|
||||
class Choice(BaseModel):
|
||||
|
|
@ -28,8 +28,8 @@ class ChatCompletionResponse(BaseModel):
|
|||
id: Optional[int] = 0
|
||||
object: Optional[str] = "chat_completion"
|
||||
created: Optional[str] = ""
|
||||
choices: List[Choice]
|
||||
model: str
|
||||
choices: List[Choice] = []
|
||||
model: str = ""
|
||||
metadata: Optional[Dict[str, str]] = {}
|
||||
|
||||
|
||||
|
|
@ -39,10 +39,11 @@ class GuardRequest(BaseModel):
|
|||
|
||||
|
||||
class GuardResponse(BaseModel):
|
||||
prob: List
|
||||
verdict: bool
|
||||
sentence: List
|
||||
latency: float = 0
|
||||
task: str = ""
|
||||
input: str = ""
|
||||
prob: float = 0.0
|
||||
verdict: bool = False
|
||||
metadata: Optional[Dict[str, str]] = {}
|
||||
|
||||
|
||||
# ================================================================================================
|
||||
|
|
@ -121,6 +122,7 @@ class ArchBaseHandler:
|
|||
messages: List[Message],
|
||||
tools: List[Dict[str, Any]] = None,
|
||||
extra_instruction: str = None,
|
||||
max_tokens=4096,
|
||||
):
|
||||
"""
|
||||
Processes a list of messages and formats them appropriately.
|
||||
|
|
@ -129,6 +131,7 @@ class ArchBaseHandler:
|
|||
messages (List[Message]): A list of message objects.
|
||||
tools (List[Dict[str, Any]], optional): A list of tools to include in the system prompt.
|
||||
extra_instruction (str, optional): Additional instructions to append to the last user message.
|
||||
max_tokens (int): Maximum allowed token count, assuming ~4 characters per token on average.
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: A list of processed message dictionaries.
|
||||
|
|
@ -149,14 +152,12 @@ class ArchBaseHandler:
|
|||
)
|
||||
|
||||
if tool_calls:
|
||||
# [TODO] Extend to support multiple function calls
|
||||
# TODO: Extend to support multiple function calls
|
||||
role = "assistant"
|
||||
content = f"<tool_call>\n{json.dumps(tool_calls[0]['function'])}\n</tool_call>"
|
||||
elif message.role == "tool":
|
||||
elif role == "tool":
|
||||
role = "user"
|
||||
content = (
|
||||
f"<tool_response>\n{json.dumps(message.content)}\n</tool_response>"
|
||||
)
|
||||
content = f"<tool_response>\n{json.dumps(content)}\n</tool_response>"
|
||||
|
||||
processed_messages.append({"role": role, "content": content})
|
||||
|
||||
|
|
@ -165,6 +166,23 @@ class ArchBaseHandler:
|
|||
if extra_instruction:
|
||||
processed_messages[-1]["content"] += extra_instruction
|
||||
|
||||
# keep the first system message and shift conversation if the total token length exceeds the limit
|
||||
def truncate_messages(messages: List[Dict[str, Any]]):
|
||||
num_tokens, conversation_idx = 0, 0
|
||||
if messages[0]["role"] == "system":
|
||||
num_tokens += len(messages[0]["content"]) // 4
|
||||
conversation_idx = 1
|
||||
|
||||
for message_idx in range(len(messages) - 1, conversation_idx - 1, -1):
|
||||
num_tokens += len(messages[message_idx]["content"]) // 4
|
||||
if num_tokens >= max_tokens:
|
||||
if messages[message_idx]["role"] == "user":
|
||||
break
|
||||
|
||||
return messages[:conversation_idx] + messages[message_idx:]
|
||||
|
||||
processed_messages = truncate_messages(processed_messages)
|
||||
|
||||
return processed_messages
|
||||
|
||||
async def chat_completion(self, req: ChatMessage) -> ChatCompletionResponse:
|
||||
|
|
@ -1,25 +1,25 @@
|
|||
import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
import logging
|
||||
import src.commons.utils as utils
|
||||
|
||||
from src.commons.globals import handler_map
|
||||
from src.core.model_utils import ChatMessage, GuardRequest
|
||||
from src.core.utils.model_utils import (
|
||||
ChatMessage,
|
||||
ChatCompletionResponse,
|
||||
GuardRequest,
|
||||
GuardResponse,
|
||||
)
|
||||
|
||||
from fastapi import FastAPI, Response
|
||||
from opentelemetry import trace
|
||||
from opentelemetry.sdk.trace import TracerProvider
|
||||
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
|
||||
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
|
||||
from opentelemetry.sdk.trace.export import BatchSpanProcessor
|
||||
from opentelemetry.sdk.resources import Resource
|
||||
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
|
||||
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
resource = Resource.create(
|
||||
{
|
||||
|
|
@ -31,11 +31,6 @@ resource = Resource.create(
|
|||
trace.set_tracer_provider(TracerProvider(resource=resource))
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
FastAPIInstrumentor().instrument_app(app)
|
||||
|
||||
# DEFAULT_OTLP_HOST = "http://localhost:4317"
|
||||
DEFAULT_OTLP_HOST = "none"
|
||||
|
||||
|
|
@ -47,6 +42,16 @@ otlp_exporter = OTLPSpanExporter(
|
|||
trace.get_tracer_provider().add_span_processor(BatchSpanProcessor(otlp_exporter))
|
||||
|
||||
|
||||
logger = utils.get_model_server_logger()
|
||||
logging.getLogger("httpx").setLevel(logging.ERROR)
|
||||
logging.getLogger("opentelemetry.exporter.otlp.proto.grpc.exporter").setLevel(
|
||||
logging.ERROR
|
||||
)
|
||||
|
||||
app = FastAPI()
|
||||
FastAPIInstrumentor().instrument_app(app)
|
||||
|
||||
|
||||
@app.get("/healthz")
|
||||
async def healthz():
|
||||
return {"status": "ok"}
|
||||
|
|
@ -62,73 +67,78 @@ async def models():
|
|||
|
||||
@app.post("/function_calling")
|
||||
async def function_calling(req: ChatMessage, res: Response):
|
||||
logger.info("[Endpoint: /function_calling]")
|
||||
logger.info(f"[request body]: {json.dumps(req.model_dump())}")
|
||||
|
||||
final_response: ChatCompletionResponse = None
|
||||
error_messages = None
|
||||
|
||||
try:
|
||||
intent_start_time = time.perf_counter()
|
||||
intent_response = await handler_map["Arch-Intent"].chat_completion(req)
|
||||
intent_latency = time.perf_counter() - intent_start_time
|
||||
|
||||
if handler_map["Arch-Intent"].detect_intent(intent_response):
|
||||
# [TODO] measure agreement between intent detection and function calling
|
||||
# TODO: measure agreement between intent detection and function calling
|
||||
try:
|
||||
function_start_time = time.perf_counter()
|
||||
function_calling_response = await handler_map[
|
||||
"Arch-Function"
|
||||
].chat_completion(req)
|
||||
final_response = await handler_map["Arch-Function"].chat_completion(req)
|
||||
function_latency = time.perf_counter() - function_start_time
|
||||
function_calling_response.metadata = {
|
||||
|
||||
final_response.metadata = {
|
||||
"intent_latency": str(round(intent_latency * 1000, 3)),
|
||||
"function_latency": str(round(function_latency * 1000, 3)),
|
||||
"hallucination": str(handler_map["Arch-Function"].hallucination),
|
||||
"tokens_uncertainty": json.dumps(
|
||||
handler_map["Arch-Function"].hallu_handler.token_probs_map
|
||||
),
|
||||
"prompt_prefilling": str(
|
||||
handler_map["Arch-Function"].prompt_prefilling
|
||||
"hallucination": str(
|
||||
handler_map["Arch-Function"].hallucination_state.hallucination
|
||||
),
|
||||
}
|
||||
|
||||
return function_calling_response
|
||||
except ValueError as e:
|
||||
res.statuscode = 503
|
||||
error_message = "Tool call extraction error"
|
||||
logger.error(f" {error_message}: {e}")
|
||||
return {"error": f"[Arch-Function] - {error_message} - {e}"}
|
||||
error_messages = f"[Arch-Function] - Error in tool call extraction: {e}"
|
||||
except StopIteration as e:
|
||||
res.statuscode = 500
|
||||
error_message = "Hallucination iterator error"
|
||||
logger.error(f" {error_message}: {e}")
|
||||
return {"error": f"[Arch-Function] - {error_message} - {e}"}
|
||||
error_messages = f"[Arch-Function] - Error in hallucination check: {e}"
|
||||
except Exception as e:
|
||||
# [TODO] Review: update how to collect debugging outputs
|
||||
logger.error(f"Error in chat_completion from `Arch-Function`: {e}")
|
||||
res.status_code = 500
|
||||
return {"error": f"[Arch-Function] - {e}"}
|
||||
# [TODO] Review: define the behavior if `Arch-Intent` doesn't detect an intent
|
||||
error_messages = f"[Arch-Function] - Error in ChatCompletion: {e}"
|
||||
else:
|
||||
return {
|
||||
"result": "No intent matched",
|
||||
"intent_latency": round(intent_latency * 1000, 3),
|
||||
intent_response.metadata = {
|
||||
"intent_latency": str(round(intent_latency * 1000, 3)),
|
||||
}
|
||||
final_response = intent_response
|
||||
|
||||
except Exception as e:
|
||||
# [TODO] Review: update how to collect debugging outputs
|
||||
# logger.error(f"Error in chat_completion from `Arch-Intent`: {e}")
|
||||
logger.error(f"Error in chat_completion /function_calling: {e}")
|
||||
res.status_code = 500
|
||||
return {"error": f"[Arch-Intent] - {e}"}
|
||||
error_messages = f"[Arch-Intent] - Error in ChatCompletion: {e}"
|
||||
|
||||
if error_messages is not None:
|
||||
logger.error(error_messages)
|
||||
final_response = ChatCompletionResponse(metadata={"error": error_messages})
|
||||
|
||||
return final_response
|
||||
|
||||
|
||||
@app.post("/guardrails")
|
||||
async def guardrails(req: GuardRequest, res: Response, max_num_words=300):
|
||||
logger.info("[Endpoint: /guardrails] - Gateway")
|
||||
logger.info(f"[request body]: {json.dumps(req.model_dump())}")
|
||||
|
||||
final_response: GuardResponse = None
|
||||
error_messages = None
|
||||
|
||||
try:
|
||||
guard_start_time = time.perf_counter()
|
||||
guard_result = handler_map["Arch-Guard"].predict(req)
|
||||
final_response = handler_map["Arch-Guard"].predict(req)
|
||||
guard_latency = time.perf_counter() - guard_start_time
|
||||
return {
|
||||
"response": guard_result,
|
||||
final_response.metadata = {
|
||||
"guard_latency": round(guard_latency * 1000, 3),
|
||||
}
|
||||
except Exception as e:
|
||||
# [TODO] Review: update how to collect debugging outputs
|
||||
res.status_code = 500
|
||||
return {"error": f"[Arch-Guard] - {e}"}
|
||||
error_messages = f"[Arch-Guard]: {e}"
|
||||
|
||||
if error_messages is not None:
|
||||
logger.error(error_messages)
|
||||
final_response = GuardResponse(metadata={"error": error_messages})
|
||||
|
||||
return final_response
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue