Reorganize model_server

This commit is contained in:
Shuguang Chen 2024-12-08 09:21:53 -08:00
parent a40cdc7b75
commit b4f4695f16
20 changed files with 20 additions and 20 deletions

View file

175
model_server/src/cli.py Normal file
View file

@ -0,0 +1,175 @@
import importlib
import sys
import time
import requests
import subprocess
import logging
def get_version():
try:
version = importlib.metadata.version("archgw_modelserver")
return version
except importlib.metadata.PackageNotFoundError:
return "version not found"
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
log = logging.getLogger("model_server.cli")
log.setLevel(logging.INFO)
log.info(f"model server version: {get_version()}")
def run_server(port=51000):
"""Start, stop, or restart the Uvicorn server based on command-line arguments."""
if len(sys.argv) > 1:
action = sys.argv[1]
else:
action = "start"
if action == "start":
start_server(port)
elif action == "stop":
stop_server(port)
elif action == "restart":
restart_server(port)
else:
log.info(f"Unknown action: {action}")
sys.exit(1)
def start_server(port=51000):
"""Start the Uvicorn server"""
log.info(
"starting model server - loading some awesomeness, this may take some time :)"
)
process = subprocess.Popen(
[
"python",
"-m",
"uvicorn",
"app.main:app",
"--host",
"0.0.0.0",
"--port",
f"{port}",
],
start_new_session=True,
bufsize=1,
universal_newlines=True,
stdout=subprocess.PIPE, # Suppress standard output. There is a logger that model_server prints to
stderr=subprocess.PIPE, # Suppress standard error. There is a logger that model_server prints to
)
if wait_for_health_check(f"http://0.0.0.0:{port}/healthz"):
log.info(f"Model server started with PID {process.pid}")
else:
# Add model_server boot-up logs
log.info("model server - didn't start in time, shutting down")
process.terminate()
def wait_for_health_check(url, timeout=300):
"""Wait for the Uvicorn server to respond to health-check requests."""
start_time = time.time()
while time.time() - start_time < timeout:
try:
response = requests.get(url)
if response.status_code == 200:
return True
except requests.ConnectionError:
time.sleep(1)
print("Timed out waiting for model server to respond.")
return False
def check_and_install_lsof():
"""Check if lsof is installed, and if not, install it using apt-get."""
try:
# Check if lsof is installed by running "lsof -v"
subprocess.run(
["lsof", "-v"], check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE
)
print("lsof is already installed.")
except subprocess.CalledProcessError:
print("lsof not found, installing...")
try:
# Update package list and install lsof
subprocess.run(["sudo", "apt-get", "update"], check=True)
subprocess.run(["sudo", "apt-get", "install", "-y", "lsof"], check=True)
print("lsof installed successfully.")
except subprocess.CalledProcessError as install_error:
print(f"Failed to install lsof: {install_error}")
def kill_process(port=51000, wait=True, timeout=10):
"""Stop the running Uvicorn server."""
log.info("Stopping model server")
try:
# Run the function to check and install lsof if necessary
# Step 1: Run lsof command to get the process using the port
lsof_command = f"lsof -n | grep {port} | grep -i LISTEN"
result = subprocess.run(
lsof_command, shell=True, capture_output=True, text=True
)
if result.returncode != 0:
print(f"No process found listening on port {port}.")
return
# Step 2: Parse the process IDs from the output
process_ids = [line.split()[1] for line in result.stdout.splitlines()]
if not process_ids:
print(f"No process found listening on port {port}.")
return
# Step 3: Kill each process using its PID
for pid in process_ids:
print(f"Killing model server process with PID {pid}")
subprocess.run(f"kill {pid}", shell=True)
if wait:
# Step 4: Wait for the process to be killed by checking if it's still running
start_time = time.time()
while True:
check_process = subprocess.run(
f"ps -p {pid}", shell=True, capture_output=True, text=True
)
if check_process.returncode != 0:
print(f"Process {pid} has been killed.")
break
elapsed_time = time.time() - start_time
if elapsed_time > timeout:
print(
f"Process {pid} did not terminate within {timeout} seconds."
)
print(f"Attempting to force kill process {pid}...")
subprocess.run(f"kill -9 {pid}", shell=True) # SIGKILL
break
print(
f"Waiting for process {pid} to be killed... ({elapsed_time:.2f} seconds)"
)
time.sleep(0.5)
except Exception as e:
print(f"Error occurred: {e}")
def stop_server(port=51000, wait=True, timeout=10):
check_and_install_lsof()
kill_process(port, wait, timeout)
def restart_server(port=51000):
"""Restart the Uvicorn server."""
stop_server(port)
start_server(port)

View file

View file

@ -0,0 +1,79 @@
# ========================== Arch-Intent Default Params ==========================
ARCH_INTENT_MODEL_ALIAS = "Arch-Intent"
ARCH_INTENT_INSTRUCTION = "Are there any tools can help?"
ARCH_INTENT_TASK_PROMPT = """
You are a helpful assistant.
"""
ARCH_INTENT_TOOL_PROMPT_TEMPLATE = """
You task is to check if there are any tools that can be used to help the last user message in conversations according to the available tools listed below.
<tools>
{tool_text}
</tools>
"""
ARCH_INTENT_FORMAT_PROMPT = """
Provide your tool assessment for ONLY THE LAST USER MESSAGE in the above conversation:
- First line must read 'Yes' or 'No'.
- If yes, a second line must include a comma-separated list of tool indexes.
"""
ARCH_INTENT_GENERATION_CONFIG = {
"generation_params": {"max_tokens": 1, "stop_token_ids": [151645]}
}
# ========================== Arch-Function Default Params ==========================
ARCH_FUNCTION_MODEL_ALIAS = "Arch-Function"
ARCH_FUNCTION_TASK_PROMPT = """
You are a helpful assistant.
"""
ARCH_FUNCTION_TOOL_PROMPT_TEMPLATE = """
# Tools
You may call one or more functions to assist with the user query.
You are provided with function signatures within <tools></tools> XML tags:
<tools>
{tool_text}
</tools>
"""
ARCH_FUNCTION_FORMAT_PROMPT = """
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
<tool_call>
{"name": <function-name>, "arguments": <args-json-object>}
</tool_call>
"""
ARCH_FUNCTION_GENERATION_CONFIG = {
"generation_params": {
"temperature": 0.2,
"top_p": 1.0,
"top_k": 50,
"max_tokens": 512,
"stop_token_ids": [151645],
},
"prefill_params": {
"continue_final_message": True,
"add_generation_prompt": False,
},
"prefill_prefix": [
"May",
"Could",
"Sure",
"Definitely",
"Certainly",
"Of course",
"Can",
],
}

View file

@ -0,0 +1,38 @@
import src.commons.utilities as utils
from openai import OpenAI
from src.commons.constants import *
from src.core.function_calling import ArchIntentHandler, ArchFunctionHandler
from src.core.guardrails import get_guardrail_handler
logger = utils.get_model_server_logger()
# Define the client
ARCH_ENDPOINT = "https://api.fc.archgw.com/v1"
ARCH_API_KEY = "EMPTY"
ARCH_CLIENT = OpenAI(base_url=ARCH_ENDPOINT, api_key=ARCH_API_KEY)
# Define model handlers
handler_map = {
"Arch-Intent": ArchIntentHandler(
ARCH_CLIENT,
ARCH_INTENT_MODEL_ALIAS,
ARCH_INTENT_TASK_PROMPT,
ARCH_INTENT_TOOL_PROMPT_TEMPLATE,
ARCH_INTENT_FORMAT_PROMPT,
ARCH_INTENT_INSTRUCTION,
**ARCH_INTENT_GENERATION_CONFIG,
),
"Arch-Function": ArchFunctionHandler(
ARCH_CLIENT,
ARCH_FUNCTION_MODEL_ALIAS,
ARCH_FUNCTION_TASK_PROMPT,
ARCH_FUNCTION_TOOL_PROMPT_TEMPLATE,
ARCH_FUNCTION_FORMAT_PROMPT,
**ARCH_FUNCTION_GENERATION_CONFIG,
),
"Arch-Guard": get_guardrail_handler(),
}

View file

@ -0,0 +1,65 @@
import os
import torch
import logging
logger_instance = None
def get_device():
available_device = {
"cpu": True,
"cuda": torch.cuda.is_available(),
"mps": (
torch.backends.mps.is_available()
if hasattr(torch.backends, "mps")
else False
),
}
if available_device["cuda"]:
device = "cuda"
elif available_device["mps"]:
device = "mps"
else:
device = "cpu"
return device
def get_model_server_logger():
global logger_instance
if logger_instance is not None:
# If the logger is already initialized, return the existing instance
return logger_instance
# Define log file path outside current directory (e.g., ~/archgw_logs)
log_dir = os.path.expanduser("~/archgw_logs")
log_file = "modelserver.log"
log_file_path = os.path.join(log_dir, log_file)
# Ensure the log directory exists, create it if necessary, handle permissions errors
try:
if not os.path.exists(log_dir):
os.makedirs(log_dir, exist_ok=True) # Create directory if it doesn't exist
# Check if the script has write permission in the log directory
if not os.access(log_dir, os.W_OK):
raise PermissionError(f"No write permission for the directory: {log_dir}")
# Configure logging to file and console using basicConfig
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(levelname)s - %(message)s",
handlers=[
logging.FileHandler(log_file_path, mode="w"), # Overwrite logs in file
],
)
except (PermissionError, OSError):
# Dont' fallback to console logging if there are issues writing to the log file
raise RuntimeError(f"No write permission for the directory: {log_dir}")
# Initialize the logger instance after configuring handlers
logger_instance = logging.getLogger("model_server_logger")
return logger_instance

View file

View file

@ -0,0 +1,167 @@
import json
from openai import OpenAI
from pydantic import BaseModel
from typing import Any, Dict, List, Optional
from overrides import final
class Message(BaseModel):
role: Optional[str] = ""
content: Optional[str] = ""
tool_call_id: Optional[str] = ""
tool_calls: Optional[List[Dict[str, Any]]] = []
class ChatMessage(BaseModel):
messages: list[Message]
tools: List[Dict[str, Any]]
class Choice(BaseModel):
id: Optional[int] = 0
message: Message
finish_reason: Optional[str] = "stop"
class ChatCompletionResponse(BaseModel):
id: Optional[int] = 0
object: Optional[str] = "chat_completion"
created: Optional[str] = ""
choices: List[Choice]
model: str
class ArchBaseHandler:
def __init__(
self,
client: OpenAI,
model_name: str,
task_prompt: str,
tool_prompt_template: str,
format_prompt: str,
generation_params: Dict,
):
"""
Initializes the base handler.
Args:
client (OpenAI): An OpenAI client instance.
model_name (str): Name of the model to use.
task_prompt (str): The main task prompt for the system.
tool_prompt (str): A prompt to describe tools.
format_prompt (str): A prompt specifying the desired output format.
generation_params (Dict): Generation parameters for the model.
"""
self.client = client
self.model_name = model_name
self.task_prompt = task_prompt
self.tool_prompt_template = tool_prompt_template
self.format_prompt = format_prompt
self.generation_params = generation_params
def _convert_tools(self, tools: List[Dict[str, Any]]) -> str:
"""
Converts a list of tools into the desired internal representation.
Args:
tools (List[Dict[str, Any]]): A list of tools represented as dictionaries.
Raises:
NotImplementedError: Method should be overridden in subclasses.
"""
raise NotImplementedError()
@final
def _format_system_prompt(self, tools: List[Dict[str, Any]]) -> str:
"""
Formats the system prompt using provided tools.
Args:
tools (List[Dict[str, Any]]): A list of tools represented as dictionaries.
Returns:
str: A formatted system prompt.
"""
tool_text = self._convert_tools(tools)
system_prompt = (
self.task_prompt
+ "\n\n"
+ self.tool_prompt_template.format(tool_text=tool_text)
+ "\n\n"
+ self.format_prompt
)
return system_prompt
@final
def _process_messages(
self,
messages: List[Message],
tools: List[Dict[str, Any]] = None,
extra_instruction: str = None,
):
"""
Processes a list of messages and formats them appropriately.
Args:
messages (List[Message]): A list of message objects.
tools (List[Dict[str, Any]], optional): A list of tools to include in the system prompt.
extra_instruction (str, optional): Additional instructions to append to the last user message.
Returns:
List[Dict[str, Any]]: A list of processed message dictionaries.
"""
processed_messages = []
if tools:
processed_messages.append(
{"role": "system", "content": self._format_system_prompt(tools)}
)
for message in messages:
role, content, tool_calls = (
message.role,
message.content,
message.tool_calls,
)
if tool_calls:
# [TODO] Extend to support multiple function calls
role = "assistant"
content = f"<tool_call>\n{json.dumps(tool_calls[0]['function'])}\n</tool_call>"
elif message.role == "tool":
role = "user"
content = (
f"<tool_response>\n{json.dumps(message.content)}\n</tool_response>"
)
processed_messages.append({"role": role, "content": content})
assert processed_messages[-1]["role"] == "user"
if extra_instruction:
processed_messages[-1]["content"] += extra_instruction
return processed_messages
async def chat_completion(self, req: ChatMessage) -> ChatCompletionResponse:
"""
Abstract method for generating chat completions.
Args:
req (ChatMessage): A chat message request object.
Raises:
NotImplementedError: Method should be overridden in subclasses.
"""
raise NotImplementedError()

View file

@ -0,0 +1,466 @@
import json
import random
import builtins
from openai import OpenAI
from typing import Any, Dict, List, Tuple, Union
from overrides import override
from src.core.base_handler import (
Message,
ChatMessage,
Choice,
ChatCompletionResponse,
ArchBaseHandler,
)
SUPPORT_DATA_TYPES = ["int", "float", "bool", "str", "list", "tuple", "set", "dict"]
class ArchIntentHandler(ArchBaseHandler):
def __init__(
self,
client: OpenAI,
model_name: str,
task_prompt: str,
tool_prompt_template: str,
format_prompt: str,
extra_instruction: str,
generation_params: Dict,
):
"""
Initializes the intent handler.
Args:
client (OpenAI): An OpenAI client instance.
model_name (str): Name of the model to use.
task_prompt (str): The main task prompt for the system.
tool_prompt_template (str): A prompt to describe tools.
format_prompt (str): A prompt specifying the desired output format.
extra_instruction (str): Instructions specific to intent handling.
generation_params (Dict): Generation parameters for the model.
"""
super().__init__(
client,
model_name,
task_prompt,
tool_prompt_template,
format_prompt,
generation_params,
)
self.extra_instruction = extra_instruction
@override
def _convert_tools(self, tools: List[Dict[str, Any]]) -> str:
"""
Converts a list of tools into a JSON-like format with indexed keys.
Args:
tools (List[Dict[str, Any]]): A list of tools represented as dictionaries.
Returns:
str: A string representation of converted tools.
"""
converted = [
json.dumps({"index": f"T{idx}"} | tool) for idx, tool in enumerate(tools)
]
return "\n".join(converted)
def detect_intent(self, content: str) -> bool:
"""
Detect if any intent match with prompts
Args:
content: str: Model response that contains intent detection results
Returns:
bool: A boolean value to indicate if any intent match with prompts or not
"""
if hasattr(content.choices[0].message, "content"):
return content.choices[0].message.content == "Yes"
else:
return False
@override
async def chat_completion(self, req: ChatMessage) -> ChatCompletionResponse:
"""
Generates a chat completion for a given request.
Args:
req (ChatMessage): A chat message request object.
Returns:
ChatCompletionResponse: The model's response to the chat request.
Note:
Currently only support vllm inference
"""
# In the case that no tools are available, simply return `No` to avoid making a call
if len(req.tools) == 0:
model_response = Message(content="No", tool_calls=[])
else:
messages = self._process_messages(
req.messages, req.tools, self.extra_instruction
)
model_response = self.client.chat.completions.create(
messages=messages,
model=self.model_name,
stream=False,
extra_body=self.generation_params,
)
model_response = Message(
content=model_response.choices[0].message.content, tool_calls=[]
)
chat_completion_response = ChatCompletionResponse(
choices=[Choice(message=model_response)], model=self.model_name
)
return chat_completion_response
class ArchFunctionHandler(ArchBaseHandler):
def __init__(
self,
client: OpenAI,
model_name: str,
task_prompt: str,
tool_prompt_template: str,
format_prompt: str,
generation_params: Dict,
prefill_params: Dict,
prefill_prefix: List,
):
"""
Initializes the function handler.
Args:
client (OpenAI): An OpenAI client instance.
model_name (str): Name of the model to use.
task_prompt (str): The main task prompt for the system.
tool_prompt_template (str): A prompt to describe tools.
format_prompt (str): A prompt specifying the desired output format.
generation_params (Dict): Generation parameters for the model.
prefill_params (Dict): Additional parameters for prefilling responses.
prefill_prefix (List[str]): List of prefixes for prefill responses.
"""
super().__init__(
client,
model_name,
task_prompt,
tool_prompt_template,
format_prompt,
generation_params,
)
self.prefill_params = prefill_params
self.prefill_prefix = prefill_prefix
# Predefine data types for verification. Only support Python for now.
# [TODO] Extend the list of support data types
self.support_data_types = {
type_name: getattr(builtins, type_name) for type_name in SUPPORT_DATA_TYPES
}
@override
def _convert_tools(self, tools: List[Dict[str, Any]]) -> str:
"""
Converts a list of tools into JSON format.
Args:
tools (List[Dict[str, Any]]): A list of tools represented as dictionaries.
Returns:
str: A string representation of converted tools.
"""
converted = [json.dumps(tool) for tool in tools]
return "\n".join(converted)
def _fix_json_string(self, json_str: str) -> str:
"""
Fixes malformed JSON strings by ensuring proper bracket matching.
Args:
json_str (str): A JSON string that might be malformed.
Returns:
str: A corrected JSON string.
"""
# Remove any leading or trailing whitespace or newline characters
json_str = json_str.strip()
# Stack to keep track of brackets
stack = []
# Clean string to collect valid characters
fixed_str = ""
# Dictionary for matching brackets
matching_bracket = {")": "(", "}": "{", "]": "["}
# Dictionary for the opposite of matching_bracket
opening_bracket = {v: k for k, v in matching_bracket.items()}
for char in json_str:
if char in "{[(":
stack.append(char)
fixed_str += char
elif char in "}])":
if stack and stack[-1] == matching_bracket[char]:
stack.pop()
fixed_str += char
else:
# Ignore the unmatched closing brackets
continue
else:
fixed_str += char
# If there are unmatched opening brackets left in the stack, add corresponding closing brackets
while stack:
unmatched_opening = stack.pop()
fixed_str += opening_bracket[unmatched_opening]
# Attempt to parse the corrected string to ensure its valid JSON
return fixed_str.replace("'", '"')
def _extract_tool_calls(
self, content: str
) -> Tuple[List[Dict[str, Any]], bool, Union[str, Exception]]:
"""
Extracts tool call information from a given string.
Args:
content (str): The content string containing potential tool call information.
Returns:
Tuple[List[Dict[str, Any]], bool, Union[str, Exception]]:
- A list of tool call dictionaries.
- A boolean indicating if the extraction was valid.
- An error message or exception if extraction failed.
"""
tool_calls, is_valid, error_message = [], True, ""
flag = False
for line in content.split("\n"):
if "<tool_call>" == line:
flag = True
elif "</tool_call>" == line:
flag = False
else:
if flag:
try:
tool_content = json.loads(line)
except Exception as e:
fixed_content = self._fix_json_string(line)
try:
tool_content = json.loads(fixed_content)
except Exception:
tool_calls, is_valid, error_message = [], False, e
return tool_calls, is_valid, error_message
tool_calls.append(
{
"id": f"call_{random.randint(1000, 10000)}",
"type": "function",
"function": {
"name": tool_content["name"],
"arguments": tool_content["arguments"],
},
}
)
flag = False
return tool_calls, is_valid, error_message
def _verify_tool_calls(
self, tools: List[Dict[str, Any]], tool_calls: List[Dict[str, Any]]
) -> Tuple[bool, Union[Dict[str, Any], None], str]:
"""
Verifies the validity of extracted tool calls against the provided tools.
Args:
tools (List[Dict[str, Any]]): A list of available tools.
tool_calls (List[Dict[str, Any]]): A list of tool calls to verify.
Returns:
Tuple[bool, Union[Dict[str, Any], None], str]:
- A boolean indicating if the tool calls are valid.
- The invalid tool call dictionary if any.
- An error message.
"""
is_valid, error_tool_call, error_message = True, None, ""
functions = {}
for tool in tools:
if tool["type"] == "function":
functions[tool["function"]["name"]] = tool["function"]["parameters"]
for tool_call in tool_calls:
func_name, func_args = (
tool_call["function"]["name"],
tool_call["function"]["arguments"],
)
# Check whether the function is available or not
if func_name not in functions:
is_valid = False
error_message = f"{func_name} is not defined!"
return is_valid, error_message
else:
# Check if all the requried parameters can be found in the tool calls
for required_param in functions[func_name].get("required", []):
if required_param not in func_args:
is_valid = False
error_tool_call = tool_call
error_message = f"`{required_param}` is requiried by the function `{func_name}` but not found in the tool call!"
return is_valid, error_tool_call, error_message
# Verify the data type of each parameter in the tool calls
for param_name, param_value in func_args:
data_type = functions[func_name]["properties"][param_name]["type"]
if data_type in self.support_data_types:
if not isinstance(
param_value, self.support_data_types[data_type]
):
is_valid = False
error_tool_call = tool_call
error_message = f"Parameter `{param_name}` is expected to have the data type `{self.support_data_types[data_type]}`, but got `{type(param_value)}`."
return is_valid, error_tool_call, error_message
return is_valid, error_tool_call, error_message
def _prefill_response(self, messages: List[Dict[str, str]]):
"""
Prefills the response with the tool call prefix.
Args:
messages (List[Dict[str, str]]): A list of messages.
tools (List[Dict[str, Any]]): A list of tools.
Returns:
List[Dict[str, str]]: A list of messages with the prefill prefix.
"""
messages.append(
{
"role": "assistant",
"content": random.choice(self.prefill_prefix),
}
)
prefill_response = self.client.chat.completions.create(
messages=messages,
model=self.model_name,
stream=False,
extra_body={
**self.generation_params,
**self.prefill_params,
},
)
return prefill_response
@override
async def chat_completion(
self, req: ChatMessage, enable_prefilling=True
) -> ChatCompletionResponse:
"""
Generates a chat completion response for a given request.
Args:
req (ChatMessage): A chat message request object.
enable_prefilling (bool, optional): Whether to enable prefill responses. Defaults to True.
Returns:
ChatCompletionResponse: The model's response to the chat request.
Note:
Currently only support vllm inference
"""
messages = self._process_messages(req.messages, req.tools)
# Retrieve the first token, handling the Stream object carefully
response = self.client.chat.completions.create(
messages=messages,
model=self.model_name,
stream=enable_prefilling,
extra_body=self.generation_params,
)
model_response = ""
if enable_prefilling:
has_tool_call = None
model_response = ""
for token in response:
token_content = token.choices[0].delta.content.strip()
if token_content:
if has_tool_call is None and token_content != "<tool_call>":
has_tool_call = False
response.close()
break
else:
has_tool_call = True
if has_tool_call is True:
model_response += token_content
# start parameter gathering if the model is not generating a tool call
if has_tool_call is False:
prefill_response = self._prefill_response(messages)
model_response = prefill_response.choices[0].message.content
else:
model_response = response.choices[0].message.content
(
tool_calls,
extraction_status,
extraction_error_message,
) = self._extract_tool_calls(model_response)
if tool_calls:
# [TODO] Review: define the behavior in the case that tool call extraction fails
# if not extraction_status:
(
verification_status,
invalid_tool_call,
verification_error_message,
) = self._verify_tool_calls(tools=req.tools, tool_calls=tool_calls)
# [TODO] Review: In the case that tool calls are invalid, define the protocol to collect debugging output and the behavior to handle it appropriately
if verification_status:
model_response = Message(content="", tool_calls=tool_calls)
# else:
else:
model_response = Message(content=model_response, tool_calls=[])
chat_completion_response = ChatCompletionResponse(
choices=[Choice(message=model_response)], model=self.model_name
)
# [TODO] Review: define the protocol to collect debugging output
# logger.info(
# f"model_server <= arch_function: (tool_calls): {json.dumps([tool_call['function'] for tool_call in tool_calls])}"
# )
# logger.info(
# f"model_server <= arch_function: response body: {json.dumps(chat_completion_response.dict())}"
# )
return chat_completion_response

View file

@ -0,0 +1,182 @@
import time
import torch
import numpy as np
import src.commons.utilities as utils
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from optimum.intel import OVModelForSequenceClassification
from typing import List
class GuardRequest(BaseModel):
input: str
task: str
class GuardResponse(BaseModel):
prob: List
verdict: bool
sentence: List
latency: float = 0
class ArchGuardHanlder:
def __init__(self, model_dict):
"""
Initializes the ArchGuardHanlder with the given model dictionary.
Args:
model_dict (dict): A dictionary containing the model, tokenizer, and device information.
"""
self.model = model_dict["model"]
self.tokenizer = model_dict["tokenizer"]
self.device = model_dict["device"]
self.support_tasks = {"jailbreak": {"positive_class": 2, "threshold": 0.5}}
def _split_text_into_chunks(self, text, max_num_words=300):
"""
Splits the input text into chunks of up to `max_num_words` words.
Args:
text (str): The input text to be split.
max_num_words (int, optional): The maximum number of words in each chunk. Defaults to 300.
Returns:
List[str]: A list of text chunks.
"""
words = text.split()
chunks = [
" ".join(words[i : i + max_num_words])
for i in range(0, len(words), max_num_words)
]
return chunks
@staticmethod
def softmax(x):
"""
Computes the softmax of the input array.
Args:
x (np.ndarray): The input array.
Returns:
np.ndarray: The softmax of the input.
"""
return np.exp(x) / np.exp(x).sum(axis=0)
def _predict_text(self, task, text, max_length=512) -> GuardResponse:
"""
Predicts the result for the provided text for a specific task.
Args:
task (str): The task to perform (e.g., "jailbreak").
text (str): The input text to classify.
max_length (int, optional): The maximum length for tokenization. Defaults to 512.
Returns:
GuardResponse: A GuardResponse object containing the prediction.
"""
inputs = self.tokenizer(
text, truncation=True, max_length=max_length, return_tensors="pt"
).to(self.device)
start_time = time.perf_counter()
with torch.no_grad():
logits = self.model(**inputs).logits.cpu().detach().numpy()[0]
prob = ArchGuardHanlder.softmax(logits)[
self.support_tasks[task]["positive_class"]
]
latency = time.perf_counter() - start_time
if prob > self.support_tasks[task]["threshold"]:
verdict = True
sentence = text
else:
verdict = False
sentence = None
return GuardResponse(
prob=[prob.item()], verdict=verdict, sentence=[sentence], latency=latency
)
def predict(self, req: GuardRequest, max_num_words=300) -> GuardResponse:
"""
Makes a prediction based on the GuardRequest input.
Args:
req (GuardRequest): The GuardRequest object containing the input text and task.
max_num_words (int, optional): The maximum number of words in each chunk if splitting is needed. Defaults to 300.
Returns:
GuardResponse: A GuardResponse object containing the prediction.
Note:
currently only support jailbreak check
"""
if req.task not in self.support_tasks:
raise NotImplementedError(f"{req.task} is not supported!")
if len(req.input.split()) < max_num_words:
return self._predict_text(req.task, req.input)
else:
# split into chunks if text is long
text_chunks = self._split_text_into_chunks(req.input)
prob, verdict, sentence, latency = [], False, [], 0
for chunk in text_chunks:
chunk_result = self._predict_text(req.task, chunk)
if chunk_result.verdict:
prob.append(chunk_result.prob[0])
verdict = True
sentence.append(chunk_result.sentence[0])
latency += chunk_result.latency
return GuardResponse(
prob=prob, verdict=verdict, sentence=sentence, latency=latency
)
def get_guardrail_handler(device: str = None):
"""
Initializes and returns an instance of ArchGuardHanlder based on the specified device.
Args:
device (str, optional): The device to use for model inference (e.g., "cpu" or "cuda"). Defaults to None.
Returns:
ArchGuardHanlder: An instance of ArchGuardHanlder configured for the specified device.
"""
if device is None:
device = utils.get_device()
model_class, model_name = None, None
if device == "cpu":
model_class = OVModelForSequenceClassification
model_name = "katanemo/Arch-Guard-cpu"
else:
model_class = AutoModelForSequenceClassification
model_name = "katanemo/Arch-Guard"
guardrail_dict = {
"device": device,
"model_name": model_name,
"tokenizer": AutoTokenizer.from_pretrained(model_name, trust_remote_code=True),
"model": model_class.from_pretrained(
model_name, device_map=device, low_cpu_mem_usage=True
),
}
return ArchGuardHanlder(model_dict=guardrail_dict)

View file

@ -0,0 +1,310 @@
import math
import torch
import itertools
from typing import Dict, List, Tuple
from enum import Enum
# constants
FUNC_NAME_START_PATTERN = ('<tool_call>\n{"name":"', "<tool_call>\n{'name':'")
FUNC_NAME_END_TOKEN = ('",', "',")
TOOL_CALL_TOKEN = "<tool_call>"
FIRST_PARAM_NAME_START_PATTERN = ('"arguments":{"', "'arguments':{'")
PARAMETER_NAME_END_TOKENS = ('":', ':"', "':", ":'")
PARAMETER_NAME_START_PATTERN = (',"', ",'")
PARAMETER_VALUE_START_PATTERN = ('":', "':")
PARAMETER_VALUE_END_TOKEN = ('",', "}}\n", "',")
# Thresholds
class MaskToken(Enum):
FUNCTION_NAME = "f"
PARAMETER_VALUE = "v"
PARAMETER_NAME = "p"
NOT_USED = "e"
TOOL_CALL = "t"
HALLUCINATION_THRESHOLD_DICT = {
MaskToken.TOOL_CALL.value: {"entropy": 0.1, "varentropy": 0.5},
MaskToken.PARAMETER_VALUE.value: {
"entropy": 0.5,
"varentropy": 2.5,
},
}
def check_threshold(entropy: float, varentropy: float, thd: Dict) -> bool:
"""
Check if the given entropy or variance of entropy exceeds the specified thresholds.
Args:
entropy (float): The entropy value to check.
varentropy (float): The variance of entropy value to check.
thd (dict): A dictionary containing the threshold values with keys 'entropy' and 'varentropy'.
Returns:
bool: True if either the entropy or varentropy exceeds their respective thresholds, False otherwise.
"""
return entropy > thd["entropy"] or varentropy > thd["varentropy"]
def calculate_entropy(log_probs: List[float]) -> Tuple[float, float]:
"""
Calculate the entropy and variance of entropy (varentropy) from log probabilities.
Args:
log_probs (list of float): A list of log probabilities.
Returns:
tuple: A tuple containing:
- log_probs (list of float): The input log probabilities as a list.
- entropy (float): The calculated entropy.
- varentropy (float): The calculated variance of entropy.
"""
log_probs = torch.tensor(log_probs)
token_probs = torch.exp(log_probs)
entropy = -torch.sum(log_probs * token_probs, dim=-1) / math.log(2, math.e)
varentropy = torch.sum(
token_probs * (log_probs / math.log(2, math.e)) + entropy.unsqueeze(-1) ** 2,
dim=-1,
)
return entropy.item(), varentropy.item()
def is_parameter_required(
function_description: Dict,
parameter_name: str,
) -> bool:
"""
Check if a parameter in required list
Args:
function_description (dict): The API description in JSON format.
parameter_name (str): The name of the parameter to check.
Returns:
bool: True if the parameter has the specified property, False otherwise.
"""
required_parameters = function_description.get("required", {})
return parameter_name in required_parameters
class HallucinationStateHandler:
"""
A class to handle the state of hallucination detection in token processing.
Attributes:
tokens (list): List of tokens processed.
logprobs (list): List of log probabilities for each token.
state (str): Current state of the handler.
mask (list): List of masks indicating the type of each token.
parameter_name_done (bool): Flag indicating if parameter name extraction is done.
hallucination (bool): Flag indicating if a hallucination is detected.
hallucination_message (str): Message describing the hallucination.
parameter_name (list): List of extracted parameter names.
function_description (dict): Description of functions and their parameters.
token_probs_map (list): List mapping tokens to their entropy and variance of entropy.
"""
def __init__(self, response_iterator=None):
"""
Initializes the HallucinationStateHandler with default values.
"""
self.tokens: List[str] = []
self.logprobs: List[float] = []
self.state: str = None
self.mask: List[str] = []
self.parameter_name_done: bool = False
self.hallucination: bool = False
self.error_message: str = ""
self.error_type: str = ""
self.parameter_name: List[str] = []
self.token_probs_map: List[Tuple[str, float, float]] = []
self.response_iterator = response_iterator
self.has_tool_call = False
def append_and_check_token_hallucination(self, token, logprob):
"""
Check if the given token is hallucinated based on the log probability.
Args:
token (str): The token to check.
logprob (float): The log probability of the token.
Returns:
bool: True if the token is hallucinated, False otherwise.
"""
self.tokens.append(token)
self.logprobs.append(logprob)
if self.has_tool_call:
self._process_token()
return self.hallucination
def __iter__(self):
return self
def __next__(self):
if self.response_iterator is not None:
try:
r = next(self.response_iterator)
if hasattr(r.choices[0].delta, "content"):
token_content = r.choices[0].delta.content
if token_content:
try:
logprobs = [
p.logprob
for p in r.choices[0].logprobs.content[0].top_logprobs
]
except Exception as e:
raise ValueError(
f"Error extracting logprobs from response: {e}"
)
self.append_and_check_token_hallucination(
token_content, logprobs
)
return token_content
except StopIteration:
raise StopIteration
def _process_token(self):
"""
Processes the current token and updates the state and mask accordingly.
Detects hallucinations based on the token type and log probabilities.
"""
content = "".join(self.tokens).replace(" ", "")
if self.tokens[-1] == TOOL_CALL_TOKEN:
self.mask.append(MaskToken.TOOL_CALL)
self._check_logprob()
# Function name extraction logic
# If the state is function name and the token is not an end token, add to the mask
if self.state == "function_name":
if self.tokens[-1] not in FUNC_NAME_END_TOKEN:
self.mask.append(MaskToken.FUNCTION_NAME)
else:
self.state = None
self._get_function_name()
# Check if the token is a function name start token, change the state
if content.endswith(FUNC_NAME_START_PATTERN):
self.state = "function_name"
# Parameter name extraction logic
# if the state is parameter name and the token is not an end token, add to the mask
if self.state == "parameter_name" and not content.endswith(
PARAMETER_NAME_END_TOKENS
):
self.mask.append(MaskToken.PARAMETER_NAME)
# if the state is parameter name and the token is an end token, change the state, check hallucination and set the flag parameter name done
# The need for parameter name done is to allow the check of parameter value pattern
elif self.state == "parameter_name" and content.endswith(
PARAMETER_NAME_END_TOKENS
):
self.state = None
self.parameter_name_done = True
self._get_parameter_name()
# if the parameter name is done and the token is a parameter name start token, change the state
elif self.parameter_name_done and content.endswith(
PARAMETER_NAME_START_PATTERN
):
self.state = "parameter_name"
# if token is a first parameter value start token, change the state
if content.endswith(FIRST_PARAM_NAME_START_PATTERN):
self.state = "parameter_name"
# Parameter value extraction logic
# if the state is parameter value and the token is not an end token, add to the mask
if self.state == "parameter_value" and not content.endswith(
PARAMETER_VALUE_END_TOKEN
):
# checking if the token is a value token and is not empty
if self.tokens[-1].strip() not in ['"', ""]:
self.mask.append(MaskToken.PARAMETER_VALUE)
# [TODO] Review: update the following code: `is_parameter_property` should not be here, move to `ArchFunctionHandler`
# checking if the parameter doesn't have default and the token is the first parameter value token
if (
len(self.mask) > 1
and self.mask[-2] != MaskToken.PARAMETER_VALUE
and is_parameter_required(
self.function_properties[self.function_name],
self.parameter_name[-1],
)
):
self._check_logprob()
else:
self.mask.append(MaskToken.NOT_USED)
# if the state is parameter value and the token is an end token, change the state
elif self.state == "parameter_value" and content.endswith(
PARAMETER_VALUE_END_TOKEN
):
self.state = None
# if the parameter name is done and the token is a parameter value start token, change the state
elif self.parameter_name_done and content.endswith(
PARAMETER_VALUE_START_PATTERN
):
self.state = "parameter_value"
# Maintain consistency between stack and mask
# If the mask length is less than tokens, add an not used (e) token to the mask
if len(self.mask) != len(self.tokens):
self.mask.append(MaskToken.NOT_USED)
def _check_logprob(self):
"""
Checks the log probability of the current token and updates the token probability map.
Detects hallucinations based on entropy and variance of entropy.
"""
probs = self.logprobs[-1]
entropy, varentropy = calculate_entropy(probs)
self.token_probs_map.append((self.tokens[-1], entropy, varentropy))
if check_threshold(
entropy, varentropy, HALLUCINATION_THRESHOLD_DICT[self.mask[-1].value]
):
self.hallucination = True
self.error_type = "Hallucination"
self.error_message = (
f"Hallucination: token '{self.tokens[-1]}' is uncertain."
)
def _count_consecutive_token(self, token=MaskToken.PARAMETER_VALUE) -> int:
"""
Counts the number of consecutive occurrences of a given token in the mask.
Args:
token (str): The token to count in the mask.
Returns:
int: The number of consecutive occurrences of the token.
"""
return (
len(list(itertools.takewhile(lambda x: x == token, reversed(self.mask))))
if self.mask and self.mask[-1] == token
else 0
)
def _get_parameter_name(self):
"""
Get the parameter name from the tokens.
Returns:
str: The extracted parameter name.
"""
p_len = self._count_consecutive_token(MaskToken.PARAMETER_NAME)
parameter_name = "".join(self.tokens[:-1][-p_len:])
self.parameter_name.append(parameter_name)
def _get_function_name(self):
"""
Get the function name from the tokens.
Returns:
str: The extracted function name.
"""
f_len = self._count_consecutive_token(MaskToken.FUNCTION_NAME)
self.function_name = "".join(self.tokens[:-1][-f_len:])

90
model_server/src/main.py Normal file
View file

@ -0,0 +1,90 @@
import os
from src.commons.globals import handler_map
from src.core.base_handler import ChatMessage
from src.core.guardrails import GuardRequest
from fastapi import FastAPI, Response
from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.sdk.resources import Resource
resource = Resource.create(
{
"service.name": "model-server",
}
)
# Initialize the tracer provider
trace.set_tracer_provider(TracerProvider(resource=resource))
tracer = trace.get_tracer(__name__)
app = FastAPI()
FastAPIInstrumentor().instrument_app(app)
# DEFAULT_OTLP_HOST = "http://localhost:4317"
DEFAULT_OTLP_HOST = "none"
# Configure the OTLP exporter (Jaeger, Zipkin, etc.)
otlp_exporter = OTLPSpanExporter(
endpoint=os.getenv("OTLP_HOST", DEFAULT_OTLP_HOST) # noqa: F821
)
trace.get_tracer_provider().add_span_processor(BatchSpanProcessor(otlp_exporter))
@app.get("/healthz")
async def healthz():
return {"status": "ok"}
@app.get("/models")
async def models():
return {
"object": "list",
"data": [{"id": model_name, "object": "model"} for model_name in handler_map],
}
@app.post("/function_calling")
async def function_calling(req: ChatMessage, res: Response):
try:
intent_response = await handler_map["Arch-Intent"].chat_completion(req)
if handler_map["Arch-Intent"].detect_intent(intent_response):
# [TODO] measure agreement between intent detection and function calling
try:
function_calling_response = await handler_map[
"Arch-Function"
].chat_completion(req)
return function_calling_response
except Exception as e:
# [TODO] Review: update how to collect debugging outputs
# logger.error(f"Error in chat_completion from `Arch-Function`: {e}")
res.status_code = 500
return {"error": f"[Arch-Function] - {e}"}
# [TODO] Review: define the behavior if `Arch-Intent` doesn't detect an intent
# else:
except Exception as e:
# [TODO] Review: update how to collect debugging outputs
# logger.error(f"Error in chat_completion from `Arch-Intent`: {e}")
res.status_code = 500
return {"error": f"[Arch-Intent] - {e}"}
@app.post("/guardrails")
async def guardrails(req: GuardRequest, res: Response, max_num_words=300):
try:
guard_result = handler_map["Arch-Guard"].predict(req)
return guard_result
except Exception as e:
# [TODO] Review: update how to collect debugging outputs
res.status_code = 500
return {"error": f"[Arch-Guard] - {e}"}