Use intent model from archfc to pick prompt gateway (#328)

This commit is contained in:
Shuguang Chen 2024-12-20 13:25:01 -08:00 committed by GitHub
parent 67b8fd635e
commit ba7279becb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
151 changed files with 8642 additions and 10932 deletions

View file

214
model_server/src/cli.py Normal file
View file

@ -0,0 +1,214 @@
import importlib
import logging
from os import path
import os
from signal import SIGKILL
import sys
import subprocess
import argparse
import tempfile
import time
import requests
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
def get_version():
try:
version = importlib.metadata.version("archgw_modelserver")
return version
except importlib.metadata.PackageNotFoundError:
return "version not found"
def wait_for_health_check(url, timeout=300):
"""Wait for the Uvicorn server to respond to health-check requests."""
start_time = time.time()
while time.time() - start_time < timeout:
try:
response = requests.get(url)
if response.status_code == 200:
return True
except requests.ConnectionError:
time.sleep(1)
return False
def parse_args():
parser = argparse.ArgumentParser(description="Manage the Uvicorn server.")
parser.add_argument(
"action",
choices=["start", "stop", "restart"],
default="start",
nargs="?",
help="Action to perform on the server (default: start).",
)
parser.add_argument(
"--port",
type=int,
default=51000,
help="Port number for the server (default: 51000).",
)
parser.add_argument(
"--foreground",
default=False,
action="store_true",
help="Run the server in the foreground (default: False).",
)
return parser.parse_args()
def get_pid_file():
temp_dir = tempfile.gettempdir()
return path.join(temp_dir, "model_server.pid")
def stop_server():
"""Stop the Uvicorn server."""
pid_file = get_pid_file()
if os.path.exists(pid_file):
logger.info(f"PID file found, shutting down the server.")
# read pid from file
with open(pid_file, "r") as f:
pid = int(f.read())
logger.info(f"Killing model server {pid}")
try:
os.kill(pid, SIGKILL)
except ProcessLookupError:
logger.info(f"Process {pid} not found")
os.remove(pid_file)
else:
logger.info("No PID file found, server is not running.")
def restart_server(port=51000, foreground=False):
"""Restart the Uvicorn server."""
stop_server()
start_server(port, foreground)
def run_server():
"""Start, stop, or restart the Uvicorn server based on command-line arguments."""
args = parse_args()
action = args.action
if action == "start":
start_server(args.port, args.foreground)
elif action == "stop":
stop_server()
elif action == "restart":
restart_server(args.port, args.foreground)
else:
logger.info(f"Unknown action: {action}")
sys.exit(1)
def ensure_killed(process):
process.terminate()
# if the process is not terminated, kill it
now = time.time()
# wait for 5 seconds
while time.time() - now < 5:
if process.poll() is not None:
break
time.sleep(1)
if process.poll() is None:
logger.info("Killing model server")
process.kill()
def start_server(port=51000, foreground=False):
"""Start the Uvicorn server."""
logging.info("model server version: %s", get_version())
stop_server()
logger.info(
"starting model server, port: %s, foreground: %s. Please wait ...",
port,
foreground,
)
if foreground:
process = subprocess.Popen(
[
"python",
"-m",
"uvicorn",
"src.main:app",
"--host",
"0.0.0.0",
"--port",
str(port),
],
)
else:
process = subprocess.Popen(
[
"python",
"-m",
"uvicorn",
"src.main:app",
"--host",
"0.0.0.0",
"--port",
str(port),
],
stderr=subprocess.PIPE,
stdout=subprocess.PIPE,
)
try:
if wait_for_health_check(f"http://0.0.0.0:{port}/healthz"):
logger.info(
f"model server health check passed, port {port}, pid: {process.pid}"
)
else:
logger.error("health check failed, shutting it down.")
process.terminate()
except KeyboardInterrupt:
logger.info("model server stopped by user during initialization.")
ensure_killed(process)
# write process id to temp file in temp folder
pid_file = get_pid_file()
logger.info(f"writing pid {process.pid} to {pid_file}")
with open(pid_file, "w") as f:
f.write(str(process.pid))
if foreground:
try:
process.wait()
except KeyboardInterrupt:
logger.info("model server stopped by user.")
ensure_killed(process)
def main():
"""
Start, stop, or restart the Uvicorn server based on command-line arguments.
"""
args = parse_args()
if args.action == "start":
start_server(args.port, args.foreground)
elif args.action == "stop":
stop_server()
elif args.action == "restart":
restart_server(args.port)
else:
logger.error(f"Unknown action: {args.action}")
sys.exit(1)

View file

View file

@ -0,0 +1,38 @@
import os
from openai import OpenAI
from src.commons.utils import get_model_server_logger
from src.core.guardrails import get_guardrail_handler
from src.core.function_calling import (
ArchIntentConfig,
ArchIntentHandler,
ArchFunctionConfig,
ArchFunctionHandler,
)
# Define logger
logger = get_model_server_logger()
# Define the client
ARCH_ENDPOINT = os.getenv("ARCH_ENDPOINT", "https://api.fc.archgw.com/v1")
ARCH_API_KEY = "EMPTY"
ARCH_CLIENT = OpenAI(base_url=ARCH_ENDPOINT, api_key=ARCH_API_KEY)
# Define model names
ARCH_INTENT_MODEL_ALIAS = "Arch-Intent"
ARCH_FUNCTION_MODEL_ALIAS = "Arch-Function"
logger.info("loading prompt guard model ...")
arch_guard_model = get_guardrail_handler()
# Define model handlers
handler_map = {
"Arch-Intent": ArchIntentHandler(
ARCH_CLIENT, ARCH_INTENT_MODEL_ALIAS, ArchIntentConfig
),
"Arch-Function": ArchFunctionHandler(
ARCH_CLIENT, ARCH_FUNCTION_MODEL_ALIAS, ArchFunctionConfig
),
"Arch-Guard": arch_guard_model,
}

View file

@ -0,0 +1,87 @@
import os
import sys
import time
import logging
import requests
import subprocess
import importlib
PROJ_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# Default log directory and file
DEFAULT_LOG_DIR = os.path.join(PROJ_DIR, ".logs")
DEFAULT_LOG_FILE = "modelserver.log"
def get_model_server_logger(log_dir=None, log_file=None):
"""
Get or initialize the logger instance for the model server.
Parameters:
- log_dir (str): Custom directory to store the log file. Defaults to `./.logs`.
- log_file (str): Custom log file name. Defaults to `modelserver.log`.
Returns:
- logging.Logger: Configured logger instance.
"""
log_dir = log_dir or DEFAULT_LOG_DIR
log_file = log_file or DEFAULT_LOG_FILE
log_file_path = os.path.join(log_dir, log_file)
# Check if the logger is already configured
logger = logging.getLogger("model_server_logger")
if logger.hasHandlers():
# Return existing logger instance if already configured
return logger
# Ensure the log directory exists, create it if necessary
try:
# Create directory if it doesn't exist
os.makedirs(log_dir, exist_ok=True)
# Check for write permissions
if not os.access(log_dir, os.W_OK):
raise PermissionError(f"No write permission for the directory: {log_dir}")
except (PermissionError, OSError) as e:
raise RuntimeError(f"Failed to initialize logger: {e}")
# Configure logging to file
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(levelname)s - %(message)s",
handlers=[
# logging.FileHandler(log_file_path, mode="w"), # Overwrite logs in the file
logging.StreamHandler(), # Also log to console
],
)
return logger
logger = get_model_server_logger()
logging.info("initializing torch device ...")
import torch
def get_device():
available_device = {
"cpu": True,
"cuda": torch.cuda.is_available(),
"mps": (
torch.backends.mps.is_available()
if hasattr(torch.backends, "mps")
else False
),
}
if available_device["cuda"]:
device = "cuda"
elif available_device["mps"]:
device = "mps"
else:
device = "cpu"
return device

View file

View file

@ -0,0 +1,644 @@
import json
import random
import builtins
import textwrap
from openai import OpenAI
from typing import Any, Dict, List
from overrides import override
from src.commons.utils import get_model_server_logger
from src.core.model_utils import (
Message,
ChatMessage,
Choice,
ChatCompletionResponse,
ArchBaseHandler,
)
from src.core.hallucination import HallucinationStateHandler
logger = get_model_server_logger()
class ArchIntentConfig:
TASK_PROMPT = textwrap.dedent(
"""
You are a helpful assistant.
"""
).strip()
TOOL_PROMPT_TEMPLATE = textwrap.dedent(
"""
You task is to check if there are any tools that can be used to help the last user message in conversations according to the available tools listed below.
<tools>
{tool_text}
</tools>
"""
).strip()
FORMAT_PROMPT = textwrap.dedent(
"""
Provide your tool assessment for ONLY THE LAST USER MESSAGE in the above conversation:
- First line must read 'Yes' or 'No'.
- If yes, a second line must include a comma-separated list of tool indexes.
"""
).strip()
EXTRA_INSTRUCTION = "Are there any tools can help?"
GENERATION_PARAMS = {
"temperature": 0.01,
"max_tokens": 1,
"stop_token_ids": [151645],
}
class ArchIntentHandler(ArchBaseHandler):
def __init__(self, client: OpenAI, model_name: str, config: ArchIntentConfig):
"""
Initializes the intent handler.
Args:
client (OpenAI): An OpenAI client instance.
model_name (str): Name of the model to use.
config (ArchIntentConfig): The configuration for Arch-Intent.
"""
super().__init__(
client,
model_name,
config.TASK_PROMPT,
config.TOOL_PROMPT_TEMPLATE,
config.FORMAT_PROMPT,
config.GENERATION_PARAMS,
)
self.extra_instruction = config.EXTRA_INSTRUCTION
self.prompt_prefilling = False
@override
def _convert_tools(self, tools: List[Dict[str, Any]]) -> str:
"""
Converts a list of tools into a JSON-like format with indexed keys.
Args:
tools (List[Dict[str, Any]]): A list of tools represented as dictionaries.
Returns:
str: A string representation of converted tools.
"""
converted = [
json.dumps({"index": f"T{idx}"} | tool) for idx, tool in enumerate(tools)
]
return "\n".join(converted)
def detect_intent(self, content: str) -> bool:
"""
Detect if any intent match with prompts
Args:
content: str: Model response that contains intent detection results
Returns:
bool: A boolean value to indicate if any intent match with prompts or not
"""
if hasattr(content.choices[0].message, "content"):
return content.choices[0].message.content == "Yes"
else:
return False
@override
async def chat_completion(self, req: ChatMessage) -> ChatCompletionResponse:
"""
Generates a chat completion for a given request.
Args:
req (ChatMessage): A chat message request object.
Returns:
ChatCompletionResponse: The model's response to the chat request.
Note:
Currently only support vllm inference
"""
# In the case that no tools are available, simply return `No` to avoid making a call
if len(req.tools) == 0:
model_response = Message(content="No", tool_calls=[])
else:
messages = self._process_messages(
req.messages, req.tools, self.extra_instruction
)
model_response = self.client.chat.completions.create(
messages=messages,
model=self.model_name,
stream=False,
extra_body=self.generation_params,
)
logger.info(
"arch_intent response: %s", json.dumps(model_response.model_dump())
)
model_response = Message(
content=model_response.choices[0].message.content, tool_calls=[]
)
chat_completion_response = ChatCompletionResponse(
choices=[Choice(message=model_response)], model=self.model_name
)
return chat_completion_response
# =============================================================================================================
class ArchFunctionConfig:
TASK_PROMPT = textwrap.dedent(
"""
You are a helpful assistant.
"""
).strip()
TOOL_PROMPT_TEMPLATE = textwrap.dedent(
"""
# Tools
You may call one or more functions to assist with the user query.
You are provided with function signatures within <tools></tools> XML tags:
<tools>
{tool_text}
</tools>
"""
).strip()
FORMAT_PROMPT = textwrap.dedent(
"""
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
<tool_call>
{"name": <function-name>, "arguments": <args-json-object>}
</tool_call>
"""
).strip()
GENERATION_PARAMS = {
"temperature": 0.6,
"top_p": 1.0,
"top_k": 10,
"max_tokens": 512,
"stop_token_ids": [151645],
"logprobs": True,
"top_logprobs": 10,
}
PREFILL_CONFIG = {
"prefill_params": {
"continue_final_message": True,
"add_generation_prompt": False,
},
"prefill_prefix": [
"May",
"Could",
"Sure",
"Definitely",
"Certainly",
"Of course",
"Can",
],
}
SUPPORT_DATA_TYPES = ["int", "float", "bool", "str", "list", "tuple", "set", "dict"]
class ArchFunctionHandler(ArchBaseHandler):
def __init__(
self,
client: OpenAI,
model_name: str,
config: ArchFunctionConfig,
):
"""
Initializes the function handler.
Args:
client (OpenAI): An OpenAI client instance.
model_name (str): Name of the model to use.
config (ArchFunctionConfig): The configuration for Arch-Function
"""
super().__init__(
client,
model_name,
config.TASK_PROMPT,
config.TOOL_PROMPT_TEMPLATE,
config.FORMAT_PROMPT,
config.GENERATION_PARAMS,
)
self.prefill_params = config.PREFILL_CONFIG["prefill_params"]
self.prefill_prefix = config.PREFILL_CONFIG["prefill_prefix"]
self.prompt_prefilling = False
# Predefine data types for verification. Only support Python for now.
# [TODO] Extend the list of support data types
self.support_data_types = {
type_name: getattr(builtins, type_name)
for type_name in config.SUPPORT_DATA_TYPES
}
@override
def _convert_tools(self, tools: List[Dict[str, Any]]) -> str:
"""
Converts a list of tools into JSON format.
Args:
tools (List[Dict[str, Any]]): A list of tools represented as dictionaries.
Returns:
str: A string representation of converted tools.
"""
converted = [json.dumps(tool) for tool in tools]
return "\n".join(converted)
def _fix_json_string(self, json_str: str) -> str:
"""
Fixes malformed JSON strings by ensuring proper bracket matching.
Args:
json_str (str): A JSON string that might be malformed.
Returns:
str: A corrected JSON string.
"""
# Remove any leading or trailing whitespace or newline characters
json_str = json_str.strip()
# Stack to keep track of brackets
stack = []
# Clean string to collect valid characters
fixed_str = ""
# Dictionary for matching brackets
matching_bracket = {")": "(", "}": "{", "]": "["}
# Dictionary for the opposite of matching_bracket
opening_bracket = {v: k for k, v in matching_bracket.items()}
for char in json_str:
if char in "{[(":
stack.append(char)
fixed_str += char
elif char in "}])":
if stack and stack[-1] == matching_bracket[char]:
stack.pop()
fixed_str += char
else:
# Ignore the unmatched closing brackets
continue
else:
fixed_str += char
# If there are unmatched opening brackets left in the stack, add corresponding closing brackets
while stack:
unmatched_opening = stack.pop()
fixed_str += opening_bracket[unmatched_opening]
# Attempt to parse the corrected string to ensure its valid JSON
return fixed_str.replace("'", '"')
def _extract_tool_calls(self, content: str) -> Dict[str, any]:
"""
Extracts tool call information from a given string.
Args:
content (str): The content string containing potential tool call information.
Returns:
Dict: A dictionary of extraction, including:
- "result": A list of tool call dictionaries.
- "status": A boolean indicating if the extraction was valid.
- "message": An error message or exception if extraction failed.
"""
tool_calls, is_valid, error_message = [], True, ""
flag = False
for line in content.split("\n"):
if not is_valid:
break
if "<tool_call>" == line:
flag = True
elif "</tool_call>" == line:
flag = False
else:
if flag:
try:
tool_content = json.loads(line)
except Exception as e:
fixed_content = self._fix_json_string(line)
try:
tool_content = json.loads(fixed_content)
except Exception:
tool_calls, is_valid, error_message = [], False, e
break
tool_calls.append(
{
"id": f"call_{random.randint(1000, 10000)}",
"type": "function",
"function": {
"name": tool_content["name"],
"arguments": tool_content["arguments"],
},
}
)
flag = False
return {"result": tool_calls, "status": is_valid, "message": error_message}
def _correcting_type(self, value, target_type):
try:
if target_type == float and isinstance(value, int):
return float(value)
elif target_type == list and isinstance(value, str):
return ast.literal_eval(value)
elif target_type == str and not isinstance(value, str):
return str(value)
# Add more conversion rules as needed
except (ValueError, TypeError, json.JSONDecodeError):
pass
return value
def _verify_tool_calls(
self, tools: List[Dict[str, Any]], tool_calls: List[Dict[str, Any]]
) -> Dict[str, any]:
"""
Verifies the validity of extracted tool calls against the provided tools.
Args:
tools (List[Dict[str, Any]]): A list of available tools.
tool_calls (List[Dict[str, Any]]): A list of tool calls to verify.
Returns:
Dict: A dictionary of verification, including:
- "status": A boolean indicating if the tool calls are valid.
- "invalid_tool_call": A dictionary of the invalid tool call if any.
- "message": An error message.
"""
is_valid, invalid_tool_call, error_message = True, None, ""
functions = {}
for tool in tools:
if tool["type"] == "function":
functions[tool["function"]["name"]] = tool["function"]["parameters"]
for tool_call in tool_calls:
if not is_valid:
break
func_name = tool_call["function"]["name"]
func_args = tool_call["function"]["arguments"]
# Check whether the function is available or not
if func_name not in functions:
is_valid = False
invalid_tool_call = tool_call
error_message = f"{func_name} is not defined!"
break
else:
# Check if all the requried parameters can be found in the tool calls
for required_param in functions[func_name].get("required", []):
if required_param not in func_args:
is_valid = False
invalid_tool_call = tool_call
error_message = f"`{required_param}` is requiried by the function `{func_name}` but not found in the tool call!"
break
# Verify the data type of each parameter in the tool calls
for param_name in func_args:
if param_name not in functions[func_name]["properties"]:
is_valid = False
invalid_tool_call = tool_call
error_message = f"Parameter `{param_name}` is not defined in the function `{func_name}`."
break
else:
param_value = func_args[param_name]
data_type = functions[func_name]["properties"][param_name][
"type"
]
if data_type in self.support_data_types:
if not isinstance(
param_value,
self.support_data_types[data_type],
) and not isinstance(
self._correcting_type(
param_value, self.support_data_types[data_type]
),
self.support_data_types[data_type],
):
is_valid = False
invalid_tool_call = tool_call
error_message = f"Parameter `{param_name}` is expected to have the data type `{self.support_data_types[data_type]}`, but got `{type(param_value)}`."
break
return {
"status": is_valid,
"invalid_tool_call": invalid_tool_call,
"message": error_message,
}
def _add_prefill_message(self, messages: List[Dict[str, str]]):
"""
Update messages and generation params for prompt prefilling
Args:
messages (List[Dict[str, str]]): A list of messages.
Returns:
prefill_messages (List[Dict[str, str]]): A list of messages.
"""
return messages + [
{
"role": "assistant",
"content": random.choice(self.prefill_prefix),
}
]
def _engage_parameter_gathering(self, messages: List[Dict[str, str]]):
"""
Engage parameter gathering for tool calls
"""
# TODO: log enaging parameter gathering
prefill_response = self.client.chat.completions.create(
messages=self._add_prefill_message(messages),
model=self.model_name,
extra_body={
**self.generation_params,
**self.prefill_params,
},
)
self.prompt_prefilling = True
return prefill_response
def _check_length_and_pop_messages(self, messages, max_tokens=4096):
"""
Trims the `messages` list to ensure the total token count does not exceed `max_tokens`.
Args:
messages (list): List of message dictionaries.
max_tokens (int): Maximum allowed token count.
Returns:
list: Trimmed list of messages.
"""
def estimate_token_length(messages):
"""Estimate the total token length of the messages."""
total_tokens = 0
for message in messages:
# Approximate token length: assuming ~4 characters per token on average
total_tokens += len(message["content"]) // 4
return total_tokens
# Calculate initial token length
total_tokens = estimate_token_length(messages)
# Trim messages if token count exceeds the limit
while total_tokens > max_tokens and len(messages) >= 3:
# Find the first non-system message pair
for i in range(len(messages)):
if messages[i]["role"] != "system":
# Remove the 'user'/'assistant' pair
if i + 1 < len(messages) and messages[i + 1]["role"] in [
"user",
"assistant",
]:
del messages[i : i + 2]
else:
del messages[i]
break
# Recalculate token length
total_tokens = estimate_token_length(messages)
return messages
@override
async def chat_completion(self, req: ChatMessage) -> ChatCompletionResponse:
"""
Generates a chat completion response for a given request.
Args:
req (ChatMessage): A chat message request object.
enable_prefilling (bool, optional): Whether to enable prefill responses. Defaults to True.
Returns:
ChatCompletionResponse: The model's response to the chat request.
Note:
Currently only support vllm inference
"""
logger.info(
f"model_server => arch_function: request body: {json.dumps(req.model_dump())}"
)
messages = self._process_messages(req.messages, req.tools)
messages = self._check_length_and_pop_messages(messages)
# always enable `stream=True` to collect model responses
response = self.client.chat.completions.create(
messages=messages,
model=self.model_name,
stream=True,
extra_body=self.generation_params,
)
# initialize the hallucination handler, which is an iterator
self.hallu_handler = HallucinationStateHandler(
response_iterator=response, function=req.tools
)
model_response, self.has_tool_call = "", None
self.hallucination = False
for _ in self.hallu_handler:
# check if the first token is <tool_call>
if len(self.hallu_handler.tokens) > 0 and self.has_tool_call is None:
if self.hallu_handler.tokens[0] == "<tool_call>":
self.has_tool_call = True
else:
self.has_tool_call = False
break
# if the model is hallucinating, start parameter gathering
if self.hallu_handler.hallucination is True:
self.hallucination = True
logger.info(
f"{self.hallu_handler.error_message} - start parameter gathering"
)
logger.info(
f"Hallucinated response : {''.join(self.hallu_handler.tokens)}"
)
# [TODO] - add break when hallucination is detected
break
if self.hallucination is True:
prefill_response = self._engage_parameter_gathering(messages)
model_response = prefill_response.choices[0].message.content
if self.has_tool_call and self.hallucination is False:
# [TODO] - Review: remove the following code
model_response = "".join(self.hallu_handler.tokens)
logger.info(f"Tool call found, no hallucination detected {model_response}!")
# start parameter gathering if the model is not generating tool calls
if self.has_tool_call is False:
# [TODO] - Review: remove the following code
logger.info("No tool call found, start parameter gathering")
prefill_response = self._engage_parameter_gathering(messages)
model_response = prefill_response.choices[0].message.content
# Extract tool calls from model response
extracted = self._extract_tool_calls(model_response)
if len(extracted["result"]) and extracted["status"]:
# [TODO] Review: define the behavior in the case that tool call extraction fails
# if not extracted["status"]:
verified = self._verify_tool_calls(
tools=req.tools, tool_calls=extracted["result"]
)
# [TODO] - Review: remvoe the following code
# print(f"[Verified] - {verified}")
# [TODO] Review: In the case that tool calls are invalid, define the protocol to collect debugging output and the behavior to handle it appropriately
if verified["status"]:
model_response = Message(content="", tool_calls=extracted["result"])
log_message = f"model_server <= arch_function: (tool_calls): {json.dumps([tool_call['function'] for tool_call in extracted['result']])}"
logger.info(log_message)
else:
raise ValueError(f"Invalid tool call: {verified['message']}")
else:
model_response = Message(content=model_response, tool_calls=[])
chat_completion_response = ChatCompletionResponse(
choices=[Choice(message=model_response)], model=self.model_name
)
# [TODO] Review: define the protocol to collect debugging output
logger.info(
f"model_server <= arch_function: response body: {json.dumps(chat_completion_response.model_dump())}"
)
return chat_completion_response

View file

@ -0,0 +1,171 @@
import time
import torch
import numpy as np
import src.commons.utils as utils
from transformers import AutoTokenizer
from src.core.model_utils import GuardRequest, GuardResponse
# from optimum.intel import OVModelForSequenceClassification
from transformers import AutoModelForSequenceClassification
class ArchGuardHanlder:
def __init__(self, model_dict):
"""
Initializes the ArchGuardHanlder with the given model dictionary.
Args:
model_dict (dict): A dictionary containing the model, tokenizer, and device information.
"""
self.model = model_dict["model"]
self.model_name = model_dict["model_name"]
self.tokenizer = model_dict["tokenizer"]
self.device = model_dict["device"]
self.support_tasks = {"jailbreak": {"positive_class": 2, "threshold": 0.5}}
def _split_text_into_chunks(self, text, max_num_words=300):
"""
Splits the input text into chunks of up to `max_num_words` words.
Args:
text (str): The input text to be split.
max_num_words (int, optional): The maximum number of words in each chunk. Defaults to 300.
Returns:
List[str]: A list of text chunks.
"""
words = text.split()
chunks = [
" ".join(words[i : i + max_num_words])
for i in range(0, len(words), max_num_words)
]
return chunks
@staticmethod
def softmax(x):
"""
Computes the softmax of the input array.
Args:
x (np.ndarray): The input array.
Returns:
np.ndarray: The softmax of the input.
"""
return np.exp(x) / np.exp(x).sum(axis=0)
def _predict_text(self, task, text, max_length=512) -> GuardResponse:
"""
Predicts the result for the provided text for a specific task.
Args:
task (str): The task to perform (e.g., "jailbreak").
text (str): The input text to classify.
max_length (int, optional): The maximum length for tokenization. Defaults to 512.
Returns:
GuardResponse: A GuardResponse object containing the prediction.
"""
inputs = self.tokenizer(
text, truncation=True, max_length=max_length, return_tensors="pt"
).to(self.device)
start_time = time.perf_counter()
with torch.no_grad():
logits = self.model(**inputs).logits.cpu().detach().numpy()[0]
prob = ArchGuardHanlder.softmax(logits)[
self.support_tasks[task]["positive_class"]
]
latency = time.perf_counter() - start_time
if prob > self.support_tasks[task]["threshold"]:
verdict = True
sentence = text
else:
verdict = False
sentence = None
return GuardResponse(
prob=[prob.item()], verdict=verdict, sentence=[sentence], latency=latency
)
def predict(self, req: GuardRequest, max_num_words=300) -> GuardResponse:
"""
Makes a prediction based on the GuardRequest input.
Args:
req (GuardRequest): The GuardRequest object containing the input text and task.
max_num_words (int, optional): The maximum number of words in each chunk if splitting is needed. Defaults to 300.
Returns:
GuardResponse: A GuardResponse object containing the prediction.
Note:
currently only support jailbreak check
"""
if req.task not in self.support_tasks:
raise NotImplementedError(f"{req.task} is not supported!")
if len(req.input.split()) < max_num_words:
return self._predict_text(req.task, req.input)
else:
# split into chunks if text is long
text_chunks = self._split_text_into_chunks(req.input)
prob, verdict, sentence, latency = [], False, [], 0
for chunk in text_chunks:
chunk_result = self._predict_text(req.task, chunk)
if chunk_result.verdict:
prob.append(chunk_result.prob[0])
verdict = True
sentence.append(chunk_result.sentence[0])
latency += chunk_result.latency
return GuardResponse(
prob=prob, verdict=verdict, sentence=sentence, latency=latency
)
def get_guardrail_handler(device: str = None):
"""
Initializes and returns an instance of ArchGuardHanlder based on the specified device.
Args:
device (str, optional): The device to use for model inference (e.g., "cpu" or "cuda"). Defaults to None.
Returns:
ArchGuardHanlder: An instance of ArchGuardHanlder configured for the specified device.
"""
if device is None:
device = utils.get_device()
model_class, model_name = None, None
# if device == "cpu":
# model_class = OVModelForSequenceClassification
# model_name = "katanemo/Arch-Guard-cpu"
# else:
model_class = AutoModelForSequenceClassification
model_name = "katanemo/Arch-Guard"
guardrail_dict = {
"device": device,
"model_name": model_name,
"tokenizer": AutoTokenizer.from_pretrained(model_name, trust_remote_code=True),
"model": model_class.from_pretrained(
model_name, device_map=device, low_cpu_mem_usage=True
),
}
return ArchGuardHanlder(model_dict=guardrail_dict)

View file

@ -0,0 +1,394 @@
import json
import math
import torch
import itertools
from typing import Dict, List, Tuple
from enum import Enum
import string
from src.commons.utils import get_model_server_logger
logger = get_model_server_logger()
# constants
FUNC_NAME_START_PATTERN = ('<tool_call>\n{"name":"', "<tool_call>\n{'name':'")
FUNC_NAME_END_TOKEN = ('",', "',")
TOOL_CALL_TOKEN = "<tool_call>"
END_TOOL_CALL_TOKEN = "</tool_call>"
FIRST_PARAM_NAME_START_PATTERN = ('"arguments":{"', "'arguments':{'")
PARAMETER_NAME_END_TOKENS = ('":', ':"', "':", ":'")
PARAMETER_NAME_START_PATTERN = (',"', ",'")
PARAMETER_VALUE_START_PATTERN = ('":', "':")
PARAMETER_VALUE_END_TOKEN = ('",', "}}\n", "',")
BRACKETS = {"(": ")", "{": "}", "[": "]"}
# Thresholds
class MaskToken(Enum):
FUNCTION_NAME = "f"
PARAMETER_VALUE = "v"
PARAMETER_NAME = "p"
NOT_USED = "e"
TOOL_CALL = "t"
HALLUCINATION_THRESHOLD_DICT = {
MaskToken.TOOL_CALL.value: {
"entropy": 0.35,
"varentropy": 1.7,
"probability": 0.8,
},
MaskToken.PARAMETER_VALUE.value: {
"entropy": 0.28,
"varentropy": 1.2,
"probability": 0.8,
},
}
def check_threshold(entropy: float, varentropy: float, thd: Dict) -> bool:
"""
Check if the given entropy or variance of entropy exceeds the specified thresholds.
Args:
entropy (float): The entropy value to check.
varentropy (float): The variance of entropy value to check.
thd (dict): A dictionary containing the threshold values with keys 'entropy' and 'varentropy'.
Returns:
bool: True if either the entropy or varentropy exceeds their respective thresholds, False otherwise.
"""
return entropy > thd["entropy"] and varentropy > thd["varentropy"]
def calculate_uncertainty(log_probs: List[float]) -> Tuple[float, float]:
"""
Calculate the entropy and variance of entropy (varentropy) from log probabilities.
Args:
log_probs (list of float): A list of log probabilities.
Returns:
tuple: A tuple containing:
- log_probs (list of float): The input log probabilities as a list.
- entropy (float): The calculated entropy.
- varentropy (float): The calculated variance of entropy.
"""
log_probs = torch.tensor(log_probs)
token_probs = torch.exp(log_probs)
entropy = -torch.sum(log_probs * token_probs, dim=-1) / math.log(2, math.e)
varentropy = torch.sum(
token_probs * (log_probs / math.log(2, math.e)) + entropy.unsqueeze(-1) ** 2,
dim=-1,
)
return entropy.item(), varentropy.item(), token_probs[0].item()
def is_parameter_required(
function_description: Dict,
parameter_name: str,
) -> bool:
"""
Check if a parameter in required list
Args:
function_description (dict): The API description in JSON format.
parameter_name (str): The name of the parameter to check.
Returns:
bool: True if the parameter has the specified property, False otherwise.
"""
required_parameters = function_description.get("required", {})
return parameter_name in required_parameters
def is_parameter_property(
function_description: Dict, parameter_name: str, property_name: str
) -> bool:
"""
Check if a parameter in an API description has a specific property.
Args:
function_description (dict): The API description in JSON format.
parameter_name (str): The name of the parameter to check.
property_name (str): The property to look for (e.g., 'format', 'default').
Returns:
bool: True if the parameter has the specified property, False otherwise.
"""
parameters = function_description.get("properties", {})
parameter_info = parameters.get(parameter_name, {})
return property_name in parameter_info
class HallucinationStateHandler:
"""
A class to handle the state of hallucination detection in token processing.
Attributes:
tokens (list): List of tokens processed.
logprobs (list): List of log probabilities for each token.
state (str): Current state of the handler.
mask (list): List of masks indicating the type of each token.
parameter_name_done (bool): Flag indicating if parameter name extraction is done.
hallucination (bool): Flag indicating if a hallucination is detected.
hallucination_message (str): Message describing the hallucination.
parameter_name (list): List of extracted parameter names.
token_probs_map (list): List mapping tokens to their entropy and variance of entropy.
"""
def __init__(self, response_iterator=None, function=None):
"""
Initializes the HallucinationStateHandler with default values.
"""
self.tokens: List[str] = []
self.logprobs: List[float] = []
self.state: str = None
self.mask: List[str] = []
self.parameter_name_done: bool = False
self.hallucination: bool = False
self.error_message: str = ""
self.parameter_name: List[str] = []
self.token_probs_map: List[Tuple[str, float, float]] = []
self.response_iterator = response_iterator
self._process_function(function)
self.open_bracket = False
self.bracket = None
self.check_parameter_name = {}
self.HALLUCINATION_THRESHOLD_DICT = HALLUCINATION_THRESHOLD_DICT
def _process_function(self, function):
self.function = function
if self.function is None:
raise ValueError("API descriptions not set.")
self.function_properties = {
x["function"]["name"]: x["function"]["parameters"] for x in self.function
}
def _reset_parameters(self):
"""
Resets all parameters in the HallucinationStateHandler to their default values.
"""
self.state = None
self.parameter_name_done = False
self.hallucination = False
self.error_message = ""
self.open_bracket = False
self.bracket = None
self.check_parameter_name = {}
def append_and_check_token_hallucination(self, token, logprob):
"""
Check if the given token is hallucinated based on the log probability.
Args:
token (str): The token to check.
logprob (float): The log probability of the token.
Returns:
bool: True if the token is hallucinated, False otherwise.
"""
self.tokens.append(token)
self.logprobs.append(logprob)
self._process_token()
return self.hallucination
def __iter__(self):
return self
def __next__(self):
if self.response_iterator is not None:
try:
r = next(self.response_iterator)
if hasattr(r.choices[0].delta, "content"):
token_content = r.choices[0].delta.content
if token_content:
try:
logprobs = [
p.logprob
for p in r.choices[0].logprobs.content[0].top_logprobs
]
except Exception as e:
raise ValueError(
f"Error extracting logprobs from response: {e}"
)
if token_content == END_TOOL_CALL_TOKEN:
self._reset_parameters()
else:
self.append_and_check_token_hallucination(
token_content, logprobs
)
return token_content
except StopIteration:
raise StopIteration
def _process_token(self):
"""
Processes the current token and updates the state and mask accordingly.
Detects hallucinations based on the token type and log probabilities.
"""
content = "".join(self.tokens).replace(" ", "")
if self.tokens[-1] == TOOL_CALL_TOKEN:
self.mask.append(MaskToken.TOOL_CALL)
self._check_logprob()
# Function name extraction logic
# If the state is function name and the token is not an end token, add to the mask
if self.state == "function_name":
if self.tokens[-1] not in FUNC_NAME_END_TOKEN:
self.mask.append(MaskToken.FUNCTION_NAME)
else:
self.state = None
self._get_function_name()
# Check if the token is a function name start token, change the state
if content.endswith(FUNC_NAME_START_PATTERN):
self.state = "function_name"
# Parameter name extraction logic
# if the state is parameter name and the token is not an end token, add to the mask
if self.state == "parameter_name" and not content.endswith(
PARAMETER_NAME_END_TOKENS
):
self.mask.append(MaskToken.PARAMETER_NAME)
# if the state is parameter name and the token is an end token, change the state, check hallucination and set the flag parameter name done
# The need for parameter name done is to allow the check of parameter value pattern
elif self.state == "parameter_name" and content.endswith(
PARAMETER_NAME_END_TOKENS
):
self.state = None
self.parameter_name_done = True
self._get_parameter_name()
# if the parameter name is done and the token is a parameter name start token, change the state
elif (
self.parameter_name_done
and self.open_bracket == False
and content.endswith(PARAMETER_NAME_START_PATTERN)
):
self.state = "parameter_name"
# if token is a first parameter value start token, change the state
if content.endswith(FIRST_PARAM_NAME_START_PATTERN):
self.state = "parameter_name"
# Parameter value extraction logic
# if the state is parameter value and the token is not an end token, add to the mask
if self.state == "parameter_value" and not content.endswith(
PARAMETER_VALUE_END_TOKEN
):
# checking if the token is a value token and is not empty
open_brackets = [
char for char in self.tokens[-1].strip() if char in BRACKETS
]
if open_brackets:
self.open_bracket = True
self.bracket = open_brackets[0]
if self.open_bracket and BRACKETS[self.bracket] in self.tokens[-1].strip():
self.open_bracket = False
self.bracket = None
if (
not all(
char in set(string.punctuation) for char in self.tokens[-1].strip()
)
and self.tokens[-1].strip() != ""
):
self.mask.append(MaskToken.PARAMETER_VALUE)
# checking if the parameter doesn't have enum and the token is the first parameter value token
if (
len(self.mask) > 1
and self.mask[-2] != MaskToken.PARAMETER_VALUE
and is_parameter_required(
self.function_properties[self.function_name],
self.parameter_name[-1],
)
and not is_parameter_property(
self.function_properties[self.function_name],
self.parameter_name[-1],
"enum",
)
):
if self.parameter_name[-1] not in self.check_parameter_name:
self._check_logprob()
self.check_parameter_name[self.parameter_name[-1]] = True
else:
self.mask.append(MaskToken.NOT_USED)
# if the state is parameter value and the token is an end token, change the state
elif (
self.state == "parameter_value"
and self.open_bracket == False
and content.endswith(PARAMETER_VALUE_END_TOKEN)
):
self.state = None
# if the parameter name is done and the token is a parameter value start token, change the state
elif self.parameter_name_done and content.endswith(
PARAMETER_VALUE_START_PATTERN
):
self.state = "parameter_value"
# Maintain consistency between stack and mask
# If the mask length is less than tokens, add an not used (e) token to the mask
if len(self.mask) != len(self.tokens):
self.mask.append(MaskToken.NOT_USED)
def _check_logprob(self):
"""
Checks the log probability of the current token and updates the token probability map.
Detects hallucinations based on entropy and variance of entropy.
"""
probs = self.logprobs[-1]
entropy, varentropy, probability = calculate_uncertainty(probs)
self.token_probs_map.append((self.tokens[-1], entropy, varentropy, probability))
if check_threshold(
entropy,
varentropy,
self.HALLUCINATION_THRESHOLD_DICT[self.mask[-1].value],
):
self.hallucination = True
self.error_message = f"Hallucination: token '{self.tokens[-1]}' is uncertain. {self.token_probs_map}"
def _count_consecutive_token(self, token=MaskToken.PARAMETER_VALUE) -> int:
"""
Counts the number of consecutive occurrences of a given token in the mask.
Args:
token (str): The token to count in the mask.
Returns:
int: The number of consecutive occurrences of the token.
"""
return (
len(list(itertools.takewhile(lambda x: x == token, reversed(self.mask))))
if self.mask and self.mask[-1] == token
else 0
)
def _get_parameter_name(self):
"""
Get the parameter name from the tokens.
Returns:
str: The extracted parameter name.
"""
p_len = self._count_consecutive_token(MaskToken.PARAMETER_NAME)
parameter_name = "".join(self.tokens[:-1][-p_len:])
self.parameter_name.append(parameter_name)
def _get_function_name(self):
"""
Get the function name from the tokens.
Returns:
str: The extracted function name.
"""
f_len = self._count_consecutive_token(MaskToken.FUNCTION_NAME)
self.function_name = "".join(self.tokens[:-1][-f_len:])

View file

@ -0,0 +1,181 @@
import json
from openai import OpenAI
from pydantic import BaseModel
from typing import Any, Dict, List, Optional
from overrides import final
class Message(BaseModel):
role: Optional[str] = ""
content: Optional[str] = ""
tool_call_id: Optional[str] = ""
tool_calls: Optional[List[Dict[str, Any]]] = []
class ChatMessage(BaseModel):
messages: list[Message]
tools: List[Dict[str, Any]]
class Choice(BaseModel):
id: Optional[int] = 0
message: Message
finish_reason: Optional[str] = "stop"
class ChatCompletionResponse(BaseModel):
id: Optional[int] = 0
object: Optional[str] = "chat_completion"
created: Optional[str] = ""
choices: List[Choice]
model: str
metadata: Optional[Dict[str, str]] = {}
class GuardRequest(BaseModel):
input: str
task: str
class GuardResponse(BaseModel):
prob: List
verdict: bool
sentence: List
latency: float = 0
# ================================================================================================
class ArchBaseHandler:
def __init__(
self,
client: OpenAI,
model_name: str,
task_prompt: str,
tool_prompt_template: str,
format_prompt: str,
generation_params: Dict,
):
"""
Initializes the base handler.
Args:
client (OpenAI): An OpenAI client instance.
model_name (str): Name of the model to use.
task_prompt (str): The main task prompt for the system.
tool_prompt (str): A prompt to describe tools.
format_prompt (str): A prompt specifying the desired output format.
generation_params (Dict): Generation parameters for the model.
"""
self.client = client
self.model_name = model_name
self.task_prompt = task_prompt
self.tool_prompt_template = tool_prompt_template
self.format_prompt = format_prompt
self.generation_params = generation_params
def _convert_tools(self, tools: List[Dict[str, Any]]) -> str:
"""
Converts a list of tools into the desired internal representation.
Args:
tools (List[Dict[str, Any]]): A list of tools represented as dictionaries.
Raises:
NotImplementedError: Method should be overridden in subclasses.
"""
raise NotImplementedError()
@final
def _format_system_prompt(self, tools: List[Dict[str, Any]]) -> str:
"""
Formats the system prompt using provided tools.
Args:
tools (List[Dict[str, Any]]): A list of tools represented as dictionaries.
Returns:
str: A formatted system prompt.
"""
tool_text = self._convert_tools(tools)
system_prompt = (
self.task_prompt
+ "\n\n"
+ self.tool_prompt_template.format(tool_text=tool_text)
+ "\n\n"
+ self.format_prompt
)
return system_prompt
@final
def _process_messages(
self,
messages: List[Message],
tools: List[Dict[str, Any]] = None,
extra_instruction: str = None,
):
"""
Processes a list of messages and formats them appropriately.
Args:
messages (List[Message]): A list of message objects.
tools (List[Dict[str, Any]], optional): A list of tools to include in the system prompt.
extra_instruction (str, optional): Additional instructions to append to the last user message.
Returns:
List[Dict[str, Any]]: A list of processed message dictionaries.
"""
processed_messages = []
if tools:
processed_messages.append(
{"role": "system", "content": self._format_system_prompt(tools)}
)
for message in messages:
role, content, tool_calls = (
message.role,
message.content,
message.tool_calls,
)
if tool_calls:
# [TODO] Extend to support multiple function calls
role = "assistant"
content = f"<tool_call>\n{json.dumps(tool_calls[0]['function'])}\n</tool_call>"
elif message.role == "tool":
role = "user"
content = (
f"<tool_response>\n{json.dumps(message.content)}\n</tool_response>"
)
processed_messages.append({"role": role, "content": content})
assert processed_messages[-1]["role"] == "user"
if extra_instruction:
processed_messages[-1]["content"] += extra_instruction
return processed_messages
async def chat_completion(self, req: ChatMessage) -> ChatCompletionResponse:
"""
Abstract method for generating chat completions.
Args:
req (ChatMessage): A chat message request object.
Raises:
NotImplementedError: Method should be overridden in subclasses.
"""
raise NotImplementedError()

134
model_server/src/main.py Normal file
View file

@ -0,0 +1,134 @@
import json
import logging
import os
import time
from src.commons.globals import handler_map
from src.core.model_utils import ChatMessage, GuardRequest
from fastapi import FastAPI, Response
from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.sdk.resources import Resource
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
resource = Resource.create(
{
"service.name": "model-server",
}
)
# Initialize the tracer provider
trace.set_tracer_provider(TracerProvider(resource=resource))
tracer = trace.get_tracer(__name__)
app = FastAPI()
FastAPIInstrumentor().instrument_app(app)
# DEFAULT_OTLP_HOST = "http://localhost:4317"
DEFAULT_OTLP_HOST = "none"
# Configure the OTLP exporter (Jaeger, Zipkin, etc.)
otlp_exporter = OTLPSpanExporter(
endpoint=os.getenv("OTLP_HOST", DEFAULT_OTLP_HOST) # noqa: F821
)
trace.get_tracer_provider().add_span_processor(BatchSpanProcessor(otlp_exporter))
@app.get("/healthz")
async def healthz():
return {"status": "ok"}
@app.get("/models")
async def models():
return {
"object": "list",
"data": [{"id": model_name, "object": "model"} for model_name in handler_map],
}
@app.post("/function_calling")
async def function_calling(req: ChatMessage, res: Response):
try:
intent_start_time = time.perf_counter()
intent_response = await handler_map["Arch-Intent"].chat_completion(req)
intent_latency = time.perf_counter() - intent_start_time
if handler_map["Arch-Intent"].detect_intent(intent_response):
# [TODO] measure agreement between intent detection and function calling
try:
function_start_time = time.perf_counter()
function_calling_response = await handler_map[
"Arch-Function"
].chat_completion(req)
function_latency = time.perf_counter() - function_start_time
function_calling_response.metadata = {
"intent_latency": str(round(intent_latency * 1000, 3)),
"function_latency": str(round(function_latency * 1000, 3)),
"hallucination": str(handler_map["Arch-Function"].hallucination),
"tokens_uncertainty": json.dumps(
handler_map["Arch-Function"].hallu_handler.token_probs_map
),
"prompt_prefilling": str(
handler_map["Arch-Function"].prompt_prefilling
),
}
return function_calling_response
except ValueError as e:
res.statuscode = 503
error_message = "Tool call extraction error"
logger.error(f" {error_message}: {e}")
return {"error": f"[Arch-Function] - {error_message} - {e}"}
except StopIteration as e:
res.statuscode = 500
error_message = "Hallucination iterator error"
logger.error(f" {error_message}: {e}")
return {"error": f"[Arch-Function] - {error_message} - {e}"}
except Exception as e:
# [TODO] Review: update how to collect debugging outputs
logger.error(f"Error in chat_completion from `Arch-Function`: {e}")
res.status_code = 500
return {"error": f"[Arch-Function] - {e}"}
# [TODO] Review: define the behavior if `Arch-Intent` doesn't detect an intent
else:
return {
"result": "No intent matched",
"intent_latency": round(intent_latency * 1000, 3),
}
except Exception as e:
# [TODO] Review: update how to collect debugging outputs
# logger.error(f"Error in chat_completion from `Arch-Intent`: {e}")
logger.error(f"Error in chat_completion /function_calling: {e}")
res.status_code = 500
return {"error": f"[Arch-Intent] - {e}"}
@app.post("/guardrails")
async def guardrails(req: GuardRequest, res: Response, max_num_words=300):
try:
guard_start_time = time.perf_counter()
guard_result = handler_map["Arch-Guard"].predict(req)
guard_latency = time.perf_counter() - guard_start_time
return {
"response": guard_result,
"guard_latency": round(guard_latency * 1000, 3),
}
except Exception as e:
# [TODO] Review: update how to collect debugging outputs
res.status_code = 500
return {"error": f"[Arch-Guard] - {e}"}