Use intent model from archfc to pick prompt gateway (#328)

This commit is contained in:
Shuguang Chen 2024-12-20 13:25:01 -08:00 committed by GitHub
parent 67b8fd635e
commit ba7279becb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
151 changed files with 8642 additions and 10932 deletions

View file

@ -9,7 +9,7 @@
"type": "debugpy",
"request": "launch",
"module": "uvicorn",
"args": ["app.main:app","--reload", "--port", "51000"]
"args": ["src.main:app","--reload", "--port", "51000"]
}
]
}

View file

@ -15,7 +15,7 @@ WORKDIR /src
# specify list of models that will go into the image as a comma separated list
# following models have been tested to work with this image
# "sentence-transformers/all-MiniLM-L6-v2,sentence-transformers/all-mpnet-base-v2,thenlper/gte-base,thenlper/gte-large,thenlper/gte-small"
ENV MODELS="katanemo/bge-large-en-v1.5-onnx"
ENV MODELS=""
COPY ./app ./app
COPY ./app/guard_model_config.yaml .
@ -28,4 +28,4 @@ COPY ./app/openai_params.yaml .
# RUN python install.py && \
# find /root/.cache/torch/sentence_transformers/ -name onnx -exec rm -rf {} +
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "80"]
CMD ["uvicorn", "src.app.main:app", "--host", "0.0.0.0", "--port", "80"]

View file

@ -45,7 +45,7 @@ RUN if command -v nvcc >/dev/null 2>&1; then \
COPY . /src
# Specify list of models that will go into the image as a comma separated list
ENV MODELS="katanemo/bge-large-en-v1.5-onnx"
ENV MODELS=""
ENV DEBIAN_FRONTEND=noninteractive
COPY /app /app

View file

@ -1,178 +0,0 @@
import importlib
import sys
import os
import time
import requests
import psutil
import tempfile
import subprocess
import logging
def get_version():
try:
version = importlib.metadata.version("archgw_modelserver")
return version
except importlib.metadata.PackageNotFoundError:
return "version not found"
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
log = logging.getLogger("model_server.cli")
log.setLevel(logging.INFO)
log.info(f"model server version: {get_version()}")
def run_server(port=51000):
"""Start, stop, or restart the Uvicorn server based on command-line arguments."""
if len(sys.argv) > 1:
action = sys.argv[1]
else:
action = "start"
if action == "start":
start_server(port)
elif action == "stop":
stop_server(port)
elif action == "restart":
restart_server(port)
else:
log.info(f"Unknown action: {action}")
sys.exit(1)
def start_server(port=51000):
"""Start the Uvicorn server"""
log.info(
"starting model server - loading some awesomeness, this may take some time :)"
)
process = subprocess.Popen(
[
"python",
"-m",
"uvicorn",
"app.main:app",
"--host",
"0.0.0.0",
"--port",
f"{port}",
],
start_new_session=True,
bufsize=1,
universal_newlines=True,
stdout=subprocess.PIPE, # Suppress standard output. There is a logger that model_server prints to
stderr=subprocess.PIPE, # Suppress standard error. There is a logger that model_server prints to
)
if wait_for_health_check(f"http://0.0.0.0:{port}/healthz"):
log.info(f"Model server started with PID {process.pid}")
else:
# Add model_server boot-up logs
log.info("model server - didn't start in time, shutting down")
process.terminate()
def wait_for_health_check(url, timeout=300):
"""Wait for the Uvicorn server to respond to health-check requests."""
start_time = time.time()
while time.time() - start_time < timeout:
try:
response = requests.get(url)
if response.status_code == 200:
return True
except requests.ConnectionError:
time.sleep(1)
print("Timed out waiting for model server to respond.")
return False
def check_and_install_lsof():
"""Check if lsof is installed, and if not, install it using apt-get."""
try:
# Check if lsof is installed by running "lsof -v"
subprocess.run(
["lsof", "-v"], check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE
)
print("lsof is already installed.")
except subprocess.CalledProcessError:
print("lsof not found, installing...")
try:
# Update package list and install lsof
subprocess.run(["sudo", "apt-get", "update"], check=True)
subprocess.run(["sudo", "apt-get", "install", "-y", "lsof"], check=True)
print("lsof installed successfully.")
except subprocess.CalledProcessError as install_error:
print(f"Failed to install lsof: {install_error}")
def kill_process(port=51000, wait=True, timeout=10):
"""Stop the running Uvicorn server."""
log.info("Stopping model server")
try:
# Run the function to check and install lsof if necessary
# Step 1: Run lsof command to get the process using the port
lsof_command = f"lsof -n | grep {port} | grep -i LISTEN"
result = subprocess.run(
lsof_command, shell=True, capture_output=True, text=True
)
if result.returncode != 0:
print(f"No process found listening on port {port}.")
return
# Step 2: Parse the process IDs from the output
process_ids = [line.split()[1] for line in result.stdout.splitlines()]
if not process_ids:
print(f"No process found listening on port {port}.")
return
# Step 3: Kill each process using its PID
for pid in process_ids:
print(f"Killing model server process with PID {pid}")
subprocess.run(f"kill {pid}", shell=True)
if wait:
# Step 4: Wait for the process to be killed by checking if it's still running
start_time = time.time()
while True:
check_process = subprocess.run(
f"ps -p {pid}", shell=True, capture_output=True, text=True
)
if check_process.returncode != 0:
print(f"Process {pid} has been killed.")
break
elapsed_time = time.time() - start_time
if elapsed_time > timeout:
print(
f"Process {pid} did not terminate within {timeout} seconds."
)
print(f"Attempting to force kill process {pid}...")
subprocess.run(f"kill -9 {pid}", shell=True) # SIGKILL
break
print(
f"Waiting for process {pid} to be killed... ({elapsed_time:.2f} seconds)"
)
time.sleep(0.5)
except Exception as e:
print(f"Error occurred: {e}")
def stop_server(port=51000, wait=True, timeout=10):
check_and_install_lsof()
kill_process(port, wait, timeout)
def restart_server(port=51000):
"""Restart the Uvicorn server."""
stop_server(port)
start_server(port)

View file

@ -1,38 +0,0 @@
import app.commons.globals as glb
import app.commons.utilities as utils
import app.loader as loader
from app.function_calling.model_handler import ArchFunctionHandler
from app.prompt_guard.model_handler import ArchGuardHanlder
logger = utils.get_model_server_logger()
arch_function_hanlder = ArchFunctionHandler()
PREFILL_LIST = ["May", "Could", "Sure", "Definitely", "Certainly", "Of course", "Can"]
PREFILL_ENABLED = True
TOOL_CALL_TOKEN = "<tool_call>"
arch_function_endpoint = "https://api.fc.archgw.com/v1"
arch_function_client = utils.get_client(arch_function_endpoint)
arch_function_generation_params = {
"temperature": 0.2,
"top_p": 1.0,
"top_k": 50,
"max_tokens": 512,
"stop_token_ids": [151645],
# "top_logprobs": 10,
}
arch_guard_model_type = {
"cpu": "katanemo/Arch-Guard-cpu",
"cuda": "katanemo/Arch-Guard",
"mps": "katanemo/Arch-Guard",
}
# Model definition
embedding_model = loader.get_embedding_model()
zero_shot_model = loader.get_zero_shot_model()
prompt_guard_dict = loader.get_prompt_guard(arch_guard_model_type[glb.DEVICE])
arch_guard_handler = ArchGuardHanlder(model_dict=prompt_guard_dict)
# Patterns for function name and parameter parsing

View file

@ -1,4 +0,0 @@
import app.commons.utilities as utils
DEVICE = utils.get_device()

View file

@ -1,83 +0,0 @@
import os
import yaml
import torch
import string
import logging
from openai import OpenAI
logger_instance = None
def get_device():
available_device = {
"cpu": True,
"cuda": torch.cuda.is_available(),
"mps": (
torch.backends.mps.is_available()
if hasattr(torch.backends, "mps")
else False
),
}
if available_device["cuda"]:
device = "cuda"
elif available_device["mps"]:
device = "mps"
else:
device = "cpu"
return device
def get_client(endpoint):
client = OpenAI(base_url=endpoint, api_key="EMPTY")
return client
def get_model_server_logger():
global logger_instance
if logger_instance is not None:
# If the logger is already initialized, return the existing instance
return logger_instance
# Define log file path outside current directory (e.g., ~/archgw_logs)
log_dir = os.path.expanduser("~/archgw_logs")
log_file = "modelserver.log"
log_file_path = os.path.join(log_dir, log_file)
# Ensure the log directory exists, create it if necessary, handle permissions errors
try:
if not os.path.exists(log_dir):
os.makedirs(log_dir, exist_ok=True) # Create directory if it doesn't exist
# Check if the script has write permission in the log directory
if not os.access(log_dir, os.W_OK):
raise PermissionError(f"No write permission for the directory: {log_dir}")
# Configure logging to file and console using basicConfig
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(levelname)s - %(message)s",
handlers=[
logging.FileHandler(log_file_path, mode="w"), # Overwrite logs in file
],
)
except (PermissionError, OSError):
# Dont' fallback to console logging if there are issues writing to the log file
raise RuntimeError(f"No write permission for the directory: {log_dir}")
# Initialize the logger instance after configuring handlers
logger_instance = logging.getLogger("model_server_logger")
return logger_instance
def remove_punctuations(s):
s = s.translate(str.maketrans(string.punctuation, " " * len(string.punctuation)))
return " ".join(s.split()).lower()
def get_label_map(labels):
return {remove_punctuations(label): label for label in labels}

View file

@ -1,137 +0,0 @@
import json
import random
from typing import Any, Dict, List
ARCH_FUNCTION_CALLING_TASK_PROMPT = """
You are a helpful assistant.
""".strip()
ARCH_FUNCTION_CALLING_TOOL_PROMPT = """
# Tools
You may call one or more functions to assist with the user query.
You are provided with function signatures within <tools></tools> XML tags:
<tools>
{tool_text}
</tools>
""".strip()
ARCH_FUNCTION_CALLING_FORMAT_PROMPT = """
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
<tool_call>
{"name": <function-name>, "arguments": <args-json-object>}
</tool_call>
""".strip()
class ArchFunctionHandler:
def __init__(self) -> None:
super().__init__()
def _format_system(self, tools: List[Dict[str, Any]]):
def convert_tools(tools):
return "\n".join([json.dumps(tool) for tool in tools])
tool_text = convert_tools(tools)
system_prompt = (
ARCH_FUNCTION_CALLING_TASK_PROMPT
+ "\n\n"
+ ARCH_FUNCTION_CALLING_TOOL_PROMPT.format(tool_text=tool_text)
+ "\n\n"
+ ARCH_FUNCTION_CALLING_FORMAT_PROMPT
)
return system_prompt
def _add_execution_results_prompting(
self,
messages: list[dict],
execution_results: list,
) -> dict:
content = []
for result in execution_results:
content.append(f"<tool_response>\n{json.dumps(result)}\n</tool_response>")
content = "\n".join(content)
messages.append({"role": "user", "content": content})
return messages
def extract_tool_calls(self, content: str):
tool_calls = []
flag = False
for line in content.split("\n"):
if "<tool_call>" == line:
flag = True
elif "</tool_call>" == line:
flag = False
else:
if flag:
try:
tool_content = json.loads(line)
except Exception:
fixed_content = self.fix_json_string(line)
try:
tool_content = json.loads(fixed_content)
except json.JSONDecodeError:
return content
tool_calls.append(
{
"id": f"call_{random.randint(1000, 10000)}",
"type": "function",
"function": {
"name": tool_content["name"],
"arguments": tool_content["arguments"],
},
}
)
flag = False
return tool_calls
def fix_json_string(self, json_str: str):
# Remove any leading or trailing whitespace or newline characters
json_str = json_str.strip()
# Stack to keep track of brackets
stack = []
# Clean string to collect valid characters
fixed_str = ""
# Dictionary for matching brackets
matching_bracket = {")": "(", "}": "{", "]": "["}
# Dictionary for the opposite of matching_bracket
opening_bracket = {v: k for k, v in matching_bracket.items()}
for char in json_str:
if char in "{[(":
stack.append(char)
fixed_str += char
elif char in "}])":
if stack and stack[-1] == matching_bracket[char]:
stack.pop()
fixed_str += char
else:
# Ignore the unmatched closing brackets
continue
else:
fixed_str += char
# If there are unmatched opening brackets left in the stack, add corresponding closing brackets
while stack:
unmatched_opening = stack.pop()
fixed_str += opening_bracket[unmatched_opening]
# Attempt to parse the corrected string to ensure its valid JSON
return fixed_str.replace("'", '"')

View file

@ -1,157 +0,0 @@
import json
import hashlib
import app.commons.constants as const
import random
from fastapi import Response
from pydantic import BaseModel
from app.commons.utilities import get_model_server_logger
from typing import Any, Dict, List, Optional
logger = get_model_server_logger()
class Message(BaseModel):
role: Optional[str] = ""
content: Optional[str] = ""
tool_calls: Optional[List[Dict[str, Any]]] = []
tool_call_id: Optional[str] = ""
class ChatMessage(BaseModel):
messages: list[Message]
tools: List[Dict[str, Any]]
class Choice(BaseModel):
message: Message
finish_reason: Optional[str] = "stop"
index: Optional[int] = 0
class ChatCompletionResponse(BaseModel):
choices: List[Choice]
model: Optional[str] = "Arch-Function"
created: Optional[str] = ""
id: Optional[str] = ""
object: Optional[str] = "chat_completion"
def process_messages(history: list[Message]):
updated_history = []
for hist in history:
if hist.tool_calls:
if len(hist.tool_calls) > 1:
error_msg = f"Only one tool call is supported, tools counts: {len(hist.tool_calls)}"
logger.error(error_msg)
raise ValueError(error_msg)
tool_call_str = json.dumps(hist.tool_calls[0]["function"])
updated_history.append(
{
"role": "assistant",
"content": f"<tool_call>\n{tool_call_str}\n</tool_call>",
}
)
elif hist.role == "tool":
updated_history.append(
{
"role": "user",
"content": f"<tool_response>\n{hist.content}\n</tool_response>",
}
)
else:
updated_history.append({"role": hist.role, "content": hist.content})
return updated_history
async def chat_completion(req: ChatMessage, res: Response):
logger.info("starting request")
tools_encoded = const.arch_function_hanlder._format_system(req.tools)
messages = [{"role": "system", "content": tools_encoded}]
updated_history = process_messages(req.messages)
for message in updated_history:
messages.append({"role": message["role"], "content": message["content"]})
client_model_name = const.arch_function_client.models.list().data[0].id
logger.info(
f"model_server => arch_function: {client_model_name}, messages: {json.dumps(messages)}"
)
# Retrieve the first token, handling the Stream object carefully
try:
resp = const.arch_function_client.chat.completions.create(
messages=messages,
model=client_model_name,
stream=const.PREFILL_ENABLED,
extra_body=const.arch_function_generation_params,
)
except Exception as e:
logger.error(f"model_server <= arch_function: error: {e}")
raise
if const.PREFILL_ENABLED:
first_token_content = ""
for token in resp:
first_token_content = token.choices[
0
].delta.content.strip() # Clean up the content
if first_token_content: # Break if it's non-empty
break
# Check if the first token requires tool call handling
if first_token_content != const.TOOL_CALL_TOKEN:
# Engage pre-filling response if no tool call is indicated
resp.close()
logger.info("Tool call is not found! Engage pre filling")
prefill_content = random.choice(const.PREFILL_LIST)
messages.append({"role": "assistant", "content": prefill_content})
# Send a new completion request with the updated messages
# the model will continue the final message in the chat instead of starting a new one
# disable add_generation_prompt which tells the template to add tokens that indicate the start of a bot response.
extra_body = {
**const.arch_function_generation_params,
"continue_final_message": True,
"add_generation_prompt": False,
}
pre_fill_resp = const.arch_function_client.chat.completions.create(
messages=messages,
model=client_model_name,
stream=False,
extra_body=extra_body,
)
full_response = pre_fill_resp.choices[0].message.content
else:
# Initialize full response and iterate over tokens to gather the full response
full_response = first_token_content
for token in resp:
if hasattr(token.choices[0].delta, "content"):
full_response += token.choices[0].delta.content
else:
logger.info("Stream is disabled, not engaging pre-filling")
full_response = resp.choices[0].message.content
tool_calls = const.arch_function_hanlder.extract_tool_calls(full_response)
if tool_calls:
message = Message(content="", tool_calls=tool_calls)
else:
message = Message(content=full_response, tool_calls=[])
choice = Choice(message=message)
chat_completion_response = ChatCompletionResponse(
choices=[choice], model=client_model_name
)
logger.info(
f"model_server <= arch_function: (tools): {json.dumps([tool_call['function'] for tool_call in tool_calls])}"
)
logger.info(
f"model_server <= arch_function: response body: {json.dumps(chat_completion_response.dict())}"
)
return chat_completion_response

View file

@ -1,84 +0,0 @@
import os
import app.commons.globals as glb
from transformers import AutoTokenizer, AutoModel, pipeline
from optimum.onnxruntime import (
ORTModelForFeatureExtraction,
ORTModelForSequenceClassification,
)
import app.commons.utilities as utils
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from optimum.intel import OVModelForSequenceClassification
logger = utils.get_model_server_logger()
def get_embedding_model(
model_name=os.getenv("MODELS", "katanemo/bge-large-en-v1.5"),
):
logger.info("Loading Embedding Model...")
if glb.DEVICE != "cuda":
model = ORTModelForFeatureExtraction.from_pretrained(
model_name, file_name="onnx/model.onnx"
)
else:
model = AutoModel.from_pretrained(model_name, device_map=glb.DEVICE)
embedding_model = {
"model_name": model_name,
"tokenizer": AutoTokenizer.from_pretrained(model_name, trust_remote_code=True),
"model": model,
}
return embedding_model
def get_zero_shot_model(
model_name=os.getenv("ZERO_SHOT_MODELS", "katanemo/bart-large-mnli"),
):
logger.info("Loading Zero-shot Model...")
if glb.DEVICE != "cuda":
model = ORTModelForSequenceClassification.from_pretrained(
model_name, file_name="onnx/model.onnx"
)
else:
model = model_name
zero_shot_model = {
"model_name": model_name,
"tokenizer": AutoTokenizer.from_pretrained(model_name),
"model": model,
}
zero_shot_model["pipeline"] = pipeline(
"zero-shot-classification",
model=zero_shot_model["model"],
tokenizer=zero_shot_model["tokenizer"],
device=glb.DEVICE,
)
return zero_shot_model
def get_prompt_guard(model_name):
logger.info("Loading Guard Model...")
if glb.DEVICE == "cpu":
model_class = OVModelForSequenceClassification
else:
model_class = AutoModelForSequenceClassification
prompt_guard = {
"device": glb.DEVICE,
"model_name": model_name,
"tokenizer": AutoTokenizer.from_pretrained(model_name, trust_remote_code=True),
"model": model_class.from_pretrained(
model_name, device_map=glb.DEVICE, low_cpu_mem_usage=True
),
}
return prompt_guard

View file

@ -1,261 +0,0 @@
import os
import time
import torch
import app.commons.utilities as utils
import app.commons.globals as glb
import app.prompt_guard.model_utils as guard_utils
from typing import List, Dict
from pydantic import BaseModel
from fastapi import FastAPI, Response, HTTPException, Request
from app.function_calling.model_utils import ChatMessage
from app.commons.constants import embedding_model, zero_shot_model, arch_guard_handler
from app.function_calling.model_utils import (
chat_completion as arch_function_chat_completion,
)
from unittest.mock import patch
from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.sdk.resources import Resource
resource = Resource.create(
{
"service.name": "model-server",
}
)
# Initialize the tracer provider
trace.set_tracer_provider(TracerProvider(resource=resource))
tracer = trace.get_tracer(__name__)
logger = utils.get_model_server_logger()
logger.info(f"Ready to serve traffic. available device: {glb.DEVICE}")
app = FastAPI()
FastAPIInstrumentor().instrument_app(app)
# DEFAULT_OTLP_HOST = "http://localhost:4317"
DEFAULT_OTLP_HOST = "none"
# Configure the OTLP exporter (Jaeger, Zipkin, etc.)
otlp_exporter = OTLPSpanExporter(
endpoint=os.getenv("OTLP_HOST", DEFAULT_OTLP_HOST) # noqa: F821
)
trace.get_tracer_provider().add_span_processor(BatchSpanProcessor(otlp_exporter))
class EmbeddingRequest(BaseModel):
input: str
model: str
class GuardRequest(BaseModel):
input: str
task: str
class ZeroShotRequest(BaseModel):
input: str
labels: List[str]
model: str
class HallucinationRequest(BaseModel):
prompt: str
parameters: Dict
model: str
@app.get("/healthz")
async def healthz():
return {"status": "ok"}
@app.get("/models")
async def models():
return {
"object": "list",
"data": [{"id": embedding_model["model_name"], "object": "model"}],
}
@app.post("/embeddings")
async def embedding(req: EmbeddingRequest, res: Response):
logger.info(f"Embedding req: {req}")
if req.model != embedding_model["model_name"]:
raise HTTPException(status_code=400, detail="unknown model: " + req.model)
start_time = time.perf_counter()
encoded_input = embedding_model["tokenizer"](
req.input, padding=True, truncation=True, return_tensors="pt"
).to(glb.DEVICE)
with torch.no_grad():
embeddings = embedding_model["model"](**encoded_input)
embeddings = embeddings[0][:, 0]
embeddings = (
torch.nn.functional.normalize(embeddings, p=2, dim=1).detach().cpu().numpy()
)
logger.info(f"Embedding Call Complete Time: {time.perf_counter()-start_time}")
data = [
{"object": "embedding", "embedding": embedding, "index": index + 1}
for index, embedding in enumerate(embeddings.tolist())
]
usage = {
"prompt_tokens": 0,
"total_tokens": 0,
}
return {"data": data, "model": req.model, "object": "list", "usage": usage}
@app.post("/guard")
async def guard(req: GuardRequest, res: Response, max_num_words=300):
"""
Take input as text and return the prediction of toxic and jailbreak
"""
if req.task in ["both", "toxic", "jailbreak"]:
arch_guard_handler.task = req.task
else:
raise NotImplementedError(f"{req.task} is not supported!")
start_time = time.perf_counter()
if len(req.input.split()) < max_num_words:
guard_result = arch_guard_handler.guard_predict(req.input)
else:
# text is long, split into chunks
chunks = guard_utils.split_text_into_chunks(req.input)
guard_result = {
"jailbreak_prob": [],
"time": 0,
"jailbreak_verdict": False,
"toxic_sentence": [],
"jailbreak_sentence": [],
}
for chunk in chunks:
chunk_result = arch_guard_handler.guard_predict(chunk)
guard_result["time"] += chunk_result["time"]
if chunk_result[f"{arch_guard_handler.task}_verdict"]:
guard_result[f"{arch_guard_handler.task}_verdict"] = True
guard_result[f"{arch_guard_handler.task}_sentence"].append(
chunk_result[f"{arch_guard_handler.task}_sentence"]
)
guard_result[f"{arch_guard_handler.task}_prob"].append(
chunk_result[f"{arch_guard_handler.task}_prob"].item()
)
logger.info(f"Time taken for Guard: {time.perf_counter() - start_time}")
return guard_result
@app.post("/zeroshot")
async def zeroshot(req: ZeroShotRequest, res: Response):
logger.info(f"zero-shot request: {req}")
if req.model != zero_shot_model["model_name"]:
raise HTTPException(status_code=400, detail="unknown model: " + req.model)
classifier = zero_shot_model["pipeline"]
label_map = utils.get_label_map(req.labels)
start_time = time.perf_counter()
predictions = classifier(
req.input, candidate_labels=list(label_map.keys()), multi_label=True
)
logger.info(f"zero-shot taking {time.perf_counter() - start_time} seconds")
predicted_class = label_map[predictions["labels"][0]]
predicted_score = predictions["scores"][0]
scores = {
label_map[label]: score
for label, score in zip(predictions["labels"], predictions["scores"])
}
predicted_class = label_map[predictions["labels"][0]]
return {
"predicted_class": predicted_class,
"predicted_class_score": predicted_score,
"scores": scores,
"model": req.model,
}
@app.post("/hallucination")
@patch("app.loader.glb.DEVICE", "cpu") # Mock the device to 'cpu'
async def hallucination(req: HallucinationRequest, res: Response):
"""
Take input as text and return the prediction of hallucination for each parameter
"""
logger.info(f"hallucination request: {req}")
if req.model != zero_shot_model["model_name"]:
raise HTTPException(status_code=400, detail="unknown model: " + req.model)
start_time = time.perf_counter()
classifier = zero_shot_model["pipeline"]
if "messages" in req.parameters:
req.parameters.pop("messages")
if not req.parameters or len(req.parameters) == 0:
return {
"params_scores": {},
"model": req.model,
}
candidate_labels = {f"{k} is {v}": k for k, v in req.parameters.items()}
predictions = classifier(
req.prompt,
candidate_labels=list(candidate_labels.keys()),
hypothesis_template="{}",
multi_label=True,
)
params_scores = {
candidate_labels[label]: score
for label, score in zip(predictions["labels"], predictions["scores"])
}
logger.info(
f"hallucination time cost: {params_scores}, taking {time.perf_counter() - start_time} seconds"
)
return {
"params_scores": params_scores,
"model": req.model,
}
@app.post("/v1/chat/completions")
async def chat_completion(req: ChatMessage, res: Response, request: Request):
try:
result = await arch_function_chat_completion(req, res)
return result
except Exception as e:
logger.error(f"Error in chat_completion: {e}")
res.status_code = 500
return {"error": "Internal server error"}

View file

@ -1,42 +0,0 @@
import time
import torch
import app.prompt_guard.model_utils as model_utils
class ArchGuardHanlder:
def __init__(self, model_dict, threshold=0.5):
self.task = "jailbreak"
self.positive_class = 2
self.model = model_dict["model"]
self.tokenizer = model_dict["tokenizer"]
self.device = model_dict["device"]
self.threshold = threshold
def guard_predict(self, input_text, max_length=512):
start_time = time.perf_counter()
inputs = self.tokenizer(
input_text, truncation=True, max_length=max_length, return_tensors="pt"
).to(self.device)
with torch.no_grad():
logits = self.model(**inputs).logits.cpu().detach().numpy()[0]
prob = model_utils.softmax(logits)[self.positive_class]
if prob > self.threshold:
verdict = True
sentence = input_text
else:
verdict = False
sentence = None
result_dict = {
f"{self.task}_prob": prob.item(),
f"{self.task}_verdict": verdict,
f"{self.task}_sentence": sentence,
"time": time.perf_counter() - start_time,
}
return result_dict

View file

@ -1,19 +0,0 @@
import numpy as np
def split_text_into_chunks(text, max_words=300):
"""
Max number of tokens for tokenizer is 512
Split the text into chunks of 300 words (as approximation for tokens)
"""
words = text.split() # Split text into words
# Estimate token count based on word count (1 word ≈ 1 token)
chunk_size = max_words # Use the word count as an approximation for tokens
chunks = [
" ".join(words[i : i + chunk_size]) for i in range(0, len(words), chunk_size)
]
return chunks
def softmax(x):
return np.exp(x) / np.exp(x).sum(axis=0)

View file

@ -1,106 +0,0 @@
import pytest
import httpx
from fastapi.testclient import TestClient
from app.main import app # Assuming your FastAPI app is in main.py
from unittest.mock import patch
import app.commons.globals as glb
import logging
logger = logging.getLogger(__name__)
client = TestClient(app)
logger.info(f"Model will be loaded on device: {glb.DEVICE}")
# Unit tests for the health check endpoint
@pytest.mark.asyncio
@patch("app.loader.glb.DEVICE", glb.DEVICE) # Mock the device to 'cpu'
async def test_healthz():
response = client.get("/healthz")
assert response.status_code == 200
assert response.json() == {"status": "ok"}
# Unit test for the models endpoint
@pytest.mark.asyncio
@patch("app.loader.glb.DEVICE", glb.DEVICE) # Mock the device to 'cpu'
async def test_models():
response = client.get("/models")
assert response.status_code == 200
assert response.json()["object"] == "list"
assert len(response.json()["data"]) > 0
# Unit test for embeddings endpoint
@pytest.mark.asyncio
@patch("app.loader.glb.DEVICE", glb.DEVICE) # Mock the device to 'cpu'
async def test_embedding():
request_data = {"input": "Test embedding", "model": "katanemo/bge-large-en-v1.5"}
response = client.post("/embeddings", json=request_data)
if request_data["model"] == "katanemo/bge-large-en-v1.5":
assert response.status_code == 200
assert response.json()["object"] == "list"
assert "data" in response.json()
else:
assert response.status_code == 400
# Unit test for the guard endpoint
@pytest.mark.asyncio
@patch("app.loader.glb.DEVICE", glb.DEVICE) # Mock the device to 'cpu'
async def test_guard():
request_data = {"input": "Test for jailbreak and toxicity", "task": "jailbreak"}
response = client.post("/guard", json=request_data)
assert response.status_code == 200
assert "jailbreak_verdict" in response.json()
# Unit test for the zero-shot endpoint
@pytest.mark.asyncio
@patch("app.loader.glb.DEVICE", glb.DEVICE) # Mock the device to 'cpu'
async def test_zeroshot():
request_data = {
"input": "Test input",
"labels": ["label1", "label2"],
"model": "katanemo/bart-large-mnli",
}
response = client.post("/zeroshot", json=request_data)
if request_data["model"] == "katanemo/bart-large-mnli":
assert response.status_code == 200
assert "predicted_class" in response.json()
else:
assert response.status_code == 400
# Unit test for the hallucination endpoint
@pytest.mark.asyncio
@patch("app.loader.glb.DEVICE", glb.DEVICE) # Mock the device to 'cpu'
async def test_hallucination():
request_data = {
"prompt": "Test hallucination",
"parameters": {"param1": "value1"},
"model": "katanemo/bart-large-mnli",
}
response = client.post("/hallucination", json=request_data)
if request_data["model"] == "katanemo/bart-large-mnli":
assert response.status_code == 200
assert "params_scores" in response.json()
else:
assert response.status_code == 400
# Unit test for the chat completion endpoint
@pytest.mark.asyncio
@patch("app.loader.glb.DEVICE", glb.DEVICE) # Mock the device to 'cpu'
async def test_chat_completion():
async with httpx.AsyncClient(app=app, base_url="http://test") as client:
request_data = {
"messages": [{"role": "user", "content": "Hello!"}],
"model": "Arch-Function-1.5B",
"tools": [], # Assuming tools is part of the req as per the function
"metadata": {"x-arch-state": "[]"}, # Assuming metadata is needed
}
response = await client.post("/v1/chat/completions", json=request_data)
assert response.status_code == 200
assert "choices" in response.json()

View file

@ -1,794 +0,0 @@
[{
"case": "tool_call_halluciation",
"tokens" : ["<tool_call>"],
"expect": 1,
"logprobs": [[-0.3333307206630707,
-1.5310522317886353,
-3.5098977088928223,
-3.9004578590393066,
-5.775152683258057,
-5.814209461212158,
-5.9574151039123535,
-6.0094895362854,
-6.0094895362854,
-6.673445224761963]]
},
{
"case" : "parameter_value_hallucination",
"expect" : 0,
"tokens" : ["<tool_call>",
"\n",
"{'",
"name",
"':",
" '",
"get",
"_current",
"_weather",
"',",
" '",
"arguments",
"':",
" {'",
"location",
"':",
" '",
"Sea",
",",
" Australia",
"',",
" '",
"unit",
"':",
" '",
"c",
"elsius",
"',",
" '",
"days",
"':",
" '",
"1",
"'}}\n",
"</tool_call>"],
"logprobs": [[-0.008103232830762863,
-5.085402488708496,
-6.777836799621582,
-7.558959007263184,
-9.850253105163574,
-10.266852378845215,
-10.540244102478027,
-10.722506523132324,
-10.800618171691895,
-10.917786598205566],
[0.0,
-23.25142478942871,
-25.139137268066406,
-26.2847843170166,
-28.992677688598633,
-29.070789337158203,
-29.55248260498047,
-29.91700553894043,
-30.20341682434082,
-30.307567596435547],
[0.0,
-21.66313934326172,
-23.06916046142578,
-23.32953453063965,
-25.65988540649414,
-25.985353469848633,
-26.519121170043945,
-27.07892417907715,
-27.977216720581055,
-28.458908081054688],
[0.0,
-28.094383239746094,
-28.56305694580078,
-29.109844207763672,
-29.44832992553711,
-31.79170036315918,
-32.0,
-32.05207443237305,
-32.31244659423828,
-32.364524841308594],
[0.0,
-30.489830017089844,
-31.140766143798828,
-31.81774139404297,
-34.525634765625,
-35.8275032043457,
-36.504478454589844,
-39.05614471435547,
-40.123680114746094,
-40.696502685546875],
[0.0,
-25.646865844726562,
-26.66232681274414,
-27.781936645507812,
-28.979660034179688,
-31.140764236450195,
-31.92188835144043,
-31.973962783813477,
-33.04149627685547,
-33.58828353881836],
[0.0,
-23.511798858642578,
-24.136695861816406,
-25.230268478393555,
-25.777053833007812,
-25.80309295654297,
-26.45402717590332,
-26.636289596557617,
-26.740440368652344,
-26.896663665771484],
[0.0,
-22.366153717041016,
-24.683483123779297,
-26.610252380371094,
-26.610252380371094,
-27.313264846801758,
-27.67778778076172,
-28.510986328125,
-28.615135192871094,
-29.13588523864746],
[0.0,
-22.52237319946289,
-24.292919158935547,
-24.344993591308594,
-24.39706802368164,
-24.73555564880371,
-29.943042755126953,
-29.969079971313477,
-30.021154403686523,
-30.0341739654541],
[0.0,
-30.17738151550293,
-30.411718368530273,
-30.88039207458496,
-30.984540939331055,
-31.270952224731445,
-31.895851135253906,
-32.46867370605469,
-32.624900817871094,
-33.484134674072266],
[0.0,
-28.146459579467773,
-29.396255493164062,
-30.099267959594727,
-31.127744674682617,
-31.179821014404297,
-32.807159423828125,
-33.7445068359375,
-33.770545959472656,
-34.069976806640625],
[0.0,
-26.323841094970703,
-26.558177947998047,
-30.515867233276367,
-30.932466506958008,
-31.37510108947754,
-31.531326293945312,
-31.70056915283203,
-32.065093994140625,
-32.364524841308594],
[0.0,
-26.922698974609375,
-30.28152847290039,
-31.505287170410156,
-33.30187225341797,
-33.73148727416992,
-34.27827453613281,
-34.33034896850586,
-34.460533142089844,
-34.720909118652344],
[0.0,
-21.532955169677734,
-26.94873809814453,
-29.109848022460938,
-30.80228042602539,
-31.55736541748047,
-33.484134674072266,
-34.681854248046875,
-35.384864807128906,
-35.853538513183594],
[0.0,
-19.502033233642578,
-20.46541976928711,
-24.110658645629883,
-24.501218795776367,
-25.256305694580078,
-25.82912826538086,
-25.881202697753906,
-26.063465118408203,
-26.063465118408203],
[0.0,
-24.37103271484375,
-25.256305694580078,
-25.933277130126953,
-26.714401245117188,
-28.2506103515625,
-31.010576248168945,
-32.07810974121094,
-34.62977981567383,
-35.241661071777344],
[-1.1920922133867862e-06,
-14.398697853088379,
-14.424736976623535,
-17.158666610717773,
-17.41904067993164,
-18.200162887573242,
-18.434499740600586,
-18.66883659362793,
-19.71033477783203,
-19.71033477783203],
[-0.0001445904199499637,
-8.98305892944336,
-11.35246467590332,
-13.1490478515625,
-13.669795989990234,
-14.073375701904297,
-14.516012191772461,
-14.555068969726562,
-15.622602462768555,
-15.635622024536133],
[-0.44747352600097656,
-1.0202960968017578,
-8.467000961303711,
-10.914518356323242,
-11.25300407409668,
-11.435266494750977,
-12.346576690673828,
-13.075624465942383,
-13.12769889831543,
-13.231849670410156],
[-3.123767137527466,
-1.1188862323760986,
-1.639634370803833,
-2.0562336444854736,
-2.8633930683135986,
-2.9675419330596924,
-3.4882919788360596,
-3.69659161567688,
-4.217339515686035,
-4.243376731872559],
[-7.199982064776123e-05,
-9.76410961151123,
-11.144091606140137,
-16.507802963256836,
-17.132701873779297,
-17.44515037536621,
-17.9138240814209,
-18.33042335510254,
-18.9162654876709,
-19.39795684814453],
[0.0,
-22.991050720214844,
-23.824249267578125,
-24.969894409179688,
-25.46460723876953,
-25.829130172729492,
-26.480066299438477,
-26.909683227539062,
-27.33930206298828,
-27.391376495361328],
[-0.21928852796554565,
-1.625309705734253,
-9.775025367736816,
-12.977627754211426,
-16.388530731201172,
-17.091541290283203,
-19.044347763061523,
-19.38283348083496,
-19.460947036743164,
-19.59113311767578],
[0.0,
-24.006507873535156,
-27.443450927734375,
-27.729862213134766,
-28.12042236328125,
-28.276647567749023,
-28.927583694458008,
-30.099267959594727,
-31.479251861572266,
-32.07810974121094],
[0.0,
-18.17412567138672,
-18.772987365722656,
-21.689178466796875,
-21.92351531982422,
-23.7200984954834,
-23.79821014404297,
-23.79821014404297,
-24.032546997070312,
-25.308382034301758],
[-0.12947827577590942,
-2.1083219051361084,
-12.419143676757812,
-15.23118782043457,
-15.595710754394531,
-15.830047607421875,
-17.001731872558594,
-17.60059356689453,
-18.121341705322266,
-18.251529693603516],
[0.0,
-19.449962615966797,
-24.371034622192383,
-24.917821884155273,
-25.529701232910156,
-25.85516929626465,
-26.037429809570312,
-26.115543365478516,
-26.623271942138672,
-26.649309158325195],
[-0.03332124650478363,
-3.4181859493255615,
-15.759925842285156,
-15.812002182006836,
-16.593124389648438,
-17.894996643066406,
-18.09027671813965,
-18.79328727722168,
-19.144792556762695,
-20.147233963012695],
[0.0,
-21.142393112182617,
-22.157852172851562,
-23.511798858642578,
-24.657445907592773,
-25.021968841552734,
-25.5427188873291,
-25.59479331970215,
-25.75101661682129,
-25.95931625366211],
[0.0,
-23.04312515258789,
-24.94385528564453,
-26.323841094970703,
-27.54759979248047,
-28.563060760498047,
-29.786819458007812,
-30.620018005371094,
-30.69812774658203,
-31.08869171142578],
[0.0,
-26.167617797851562,
-28.771360397338867,
-29.55248260498047,
-30.906429290771484,
-31.114728927612305,
-31.414159774780273,
-31.622459411621094,
-31.713590621948242,
-31.726608276367188],
[-0.05012698099017143,
-3.018392562866211,
-11.740934371948242,
-13.146955490112305,
-13.797887802124023,
-14.943536758422852,
-16.037107467651367,
-16.375595092773438,
-16.714080810546875,
-17.36501693725586],
[-0.9704352021217346,
-0.7360983490943909,
-2.1941938400268555,
-4.225115776062012,
-5.0062360763549805,
-5.2666120529174805,
-5.839434623718262,
-7.2714948654174805,
-8.33902645111084,
-8.495253562927246],
[-0.014467108063399792,
-4.258565902709961,
-8.789079666137695,
-10.429437637329102,
-10.793962478637695,
-11.835458755493164,
-11.939607620239258,
-13.31959342956543,
-13.866378784179688,
-15.038063049316406],
[0.0,
-20.08787727355957,
-21.350692749023438,
-21.415786743164062,
-21.50691795349121,
-21.50691795349121,
-22.7176570892334,
-24.13669776916504,
-24.188772201538086,
-24.34499740600586]]
},
{
"case": "fail_case",
"expect" : 0,
"tokens" : ["<tool_call>",
"\n",
"{'",
"name",
"':",
" '",
"get",
"_current",
"_weather",
"',",
" '",
"arguments",
"':",
" {'",
"location",
"':",
" '",
"Seattle",
",",
" WA",
"',",
" '",
"unit",
"':",
" '",
"c",
"elsius",
"',",
" '",
"days",
"':",
" '",
"7",
"'}}\n",
"</tool_call>"],
"logprobs":[[-0.00013815402053296566,
-9.113236427307129,
-10.571331977844238,
-14.099404335021973,
-14.28166675567627,
-15.583537101745605,
-15.81787395477295,
-16.143341064453125,
-16.143341064453125,
-16.260509490966797],
[0.0,
-26.896663665771484,
-27.32628059387207,
-27.41741180419922,
-32.07810974121094,
-32.07810974121094,
-32.28641128540039,
-32.29943084716797,
-32.44263458251953,
-32.520748138427734],
[0.0,
-22.444263458251953,
-24.527257919311523,
-27.15703773498535,
-28.016273498535156,
-28.2506103515625,
-28.693246841430664,
-29.070789337158203,
-29.565500259399414,
-29.812854766845703],
[0.0,
-27.860050201416016,
-28.641170501708984,
-29.448333740234375,
-30.932466506958008,
-31.63547706604004,
-32.33848571777344,
-32.85923767089844,
-33.17168426513672,
-33.45809555053711],
[0.0,
-31.81774139404297,
-31.895854949951172,
-32.05207824707031,
-35.43694305419922,
-36.3482551574707,
-38.61351013183594,
-39.26444625854492,
-40.61839294433594,
-41.71196365356445],
[0.0,
-27.33930206298828,
-27.834014892578125,
-28.849472045898438,
-30.567943572998047,
-32.98942565917969,
-33.067535400390625,
-33.067535400390625,
-35.67127990722656,
-35.69731903076172],
[0.0,
-25.33441925048828,
-26.063465118408203,
-26.219690322875977,
-26.2457275390625,
-26.53213882446289,
-27.365337371826172,
-28.354759216308594,
-28.667207717895508,
-28.74532127380371],
[0.0,
-24.423107147216797,
-24.579330444335938,
-26.81855010986328,
-28.12042236328125,
-28.32872200012207,
-28.61513328552246,
-29.16191864013672,
-29.187957763671875,
-29.240032196044922],
[0.0,
-22.027664184570312,
-23.850284576416016,
-23.980472564697266,
-24.292922973632812,
-24.787633895874023,
-29.279088973999023,
-29.55248260498047,
-29.903987884521484,
-30.190399169921875],
[0.0,
-31.609439849853516,
-31.817739486694336,
-32.54678726196289,
-32.676971435546875,
-32.781124114990234,
-32.98942565917969,
-33.106590270996094,
-33.57526397705078,
-34.369407653808594],
[0.0,
-29.34418296813965,
-29.63059425354004,
-30.021156311035156,
-30.984540939331055,
-33.21073913574219,
-34.30431365966797,
-34.56468963623047,
-34.70789337158203,
-34.79902648925781],
[0.0,
-25.438566207885742,
-25.69894027709961,
-30.190397262573242,
-30.802276611328125,
-31.58340072631836,
-31.609437942504883,
-31.64849281311035,
-31.973960876464844,
-32.29943084716797],
[0.0,
-27.157039642333984,
-32.104148864746094,
-32.33848571777344,
-34.04393768310547,
-34.12205505371094,
-34.40846252441406,
-34.42148208618164,
-34.772987365722656,
-34.87713623046875],
[0.0,
-24.813671112060547,
-26.974777221679688,
-31.010578155517578,
-31.08869171142578,
-32.1822624206543,
-35.33279037475586,
-35.489013671875,
-36.999183654785156,
-37.88446044921875],
[0.0,
-20.46541976928711,
-20.647682189941406,
-23.069164276123047,
-24.136699676513672,
-25.438570022583008,
-25.646869659423828,
-26.193655014038086,
-26.297805786132812,
-26.506103515625],
[0.0,
-27.18307113647461,
-28.30268096923828,
-28.56305694580078,
-29.526439666748047,
-32.416595458984375,
-35.202598571777344,
-36.426361083984375,
-39.31651306152344,
-39.38160705566406],
[0.0,
-18.7469482421875,
-20.100894927978516,
-21.402767181396484,
-21.428804397583008,
-22.20992660522461,
-22.34011459350586,
-22.730674743652344,
-23.069162368774414,
-23.980472564697266],
[-3.576278118089249e-07,
-15.2579345703125,
-16.481693267822266,
-17.991863250732422,
-19.215621948242188,
-20.25712013244629,
-21.350692749023438,
-22.314077377319336,
-22.496337890625,
-22.938974380493164],
[-0.08506780862808228,
-2.506549835205078,
-14.848289489746094,
-15.473188400268555,
-16.33242416381836,
-16.358461380004883,
-16.566761016845703,
-17.03543472290039,
-17.686370849609375,
-17.816556930541992],
[-0.0194891095161438,
-4.445854187011719,
-5.591499328613281,
-5.956024169921875,
-6.685070037841797,
-13.142353057861328,
-13.558952331542969,
-15.173273086547852,
-15.303461074829102,
-15.85024642944336],
[-0.0005990855861455202,
-7.4212646484375,
-15.675132751464844,
-15.72720718383789,
-16.76870346069336,
-16.76870346069336,
-17.706050872802734,
-18.669435501098633,
-19.398483276367188,
-19.658857345581055],
[0.0,
-24.110658645629883,
-25.829130172729492,
-26.011390686035156,
-26.011390686035156,
-26.532140731811523,
-26.58421516418457,
-27.651750564575195,
-27.75589942932129,
-28.055330276489258],
[-1.1408883333206177,
-0.38580334186553955,
-7.494022369384766,
-12.519245147705078,
-14.576202392578125,
-16.034297943115234,
-16.945608139038086,
-17.908992767333984,
-18.664077758789062,
-19.34105110168457],
[0.0,
-26.688365936279297,
-29.83889389038086,
-30.177383422851562,
-30.64605712890625,
-31.244916915893555,
-31.270954132080078,
-32.83319854736328,
-34.655818939208984,
-34.89015579223633],
[0.0,
-18.929210662841797,
-19.16354751586914,
-23.589908599853516,
-24.683481216430664,
-24.995929718017578,
-25.516677856445312,
-25.542715072631836,
-25.77705192565918,
-26.063465118408203],
[-0.2519786059856415,
-1.5017764568328857,
-12.437495231628418,
-15.457839012145996,
-15.744250297546387,
-16.837820053100586,
-17.41064453125,
-17.56686782836914,
-17.61894416809082,
-18.035541534423828],
[0.0,
-20.517494201660156,
-24.683483123779297,
-25.67290496826172,
-26.58421516418457,
-27.651750564575195,
-27.781936645507812,
-27.912124633789062,
-28.09438705444336,
-28.445892333984375],
[-3.40932747349143e-05,
-10.284820556640625,
-18.252273559570312,
-20.17904281616211,
-21.663175582885742,
-22.027700424194336,
-22.288074493408203,
-22.704673767089844,
-23.12127113342285,
-23.277496337890625],
[0.0,
-22.60049057006836,
-25.46460723876953,
-25.829130172729492,
-26.063467025756836,
-27.287227630615234,
-27.391376495361328,
-27.4694881439209,
-27.67778778076172,
-28.055330276489258],
[0.0,
-23.902362823486328,
-28.823436737060547,
-29.240036010742188,
-29.31814956665039,
-29.917007446289062,
-30.021160125732422,
-31.21887969970703,
-32.416603088378906,
-32.416603088378906],
[0.0,
-28.641170501708984,
-31.947925567626953,
-32.59886169433594,
-33.848655700683594,
-34.109031677246094,
-34.73393249511719,
-35.02033996582031,
-35.02033996582031,
-36.074859619140625],
[-0.013183215633034706,
-4.335395336151123,
-19.619365692138672,
-20.035964965820312,
-20.244266510009766,
-21.311800003051758,
-21.441987991333008,
-22.561595916748047,
-23.108383178710938,
-23.264606475830078],
[-8.344646857949556e-07,
-14.190400123596191,
-15.9088716506958,
-18.17412567138672,
-18.46053695678711,
-18.46053695678711,
-18.512611389160156,
-18.90317153930664,
-19.059398651123047,
-19.085433959960938],
[0.0,
-17.70545196533203,
-18.903175354003906,
-20.829944610595703,
-22.574451446533203,
-22.860862731933594,
-23.069162368774414,
-23.32953643798828,
-23.694061279296875,
-24.188772201538086],
[0.0,
-20.022781372070312,
-21.038240432739258,
-21.220502853393555,
-22.496337890625,
-22.769729614257812,
-23.589908599853516,
-23.65500259399414,
-23.94141387939453,
-24.266881942749023]]
}
]

View file

@ -1,55 +0,0 @@
import unittest
from unittest.mock import patch, MagicMock
import subprocess
import time
from app.cli import kill_process
class TestStopServer(unittest.TestCase):
@patch("subprocess.run")
def test_stop_server_no_process(self, mock_run):
# Mock subprocess.run to simulate no process listening on the port
mock_run.return_value.returncode = 1
with patch("builtins.print") as mock_print:
kill_process(port=51000)
mock_print.assert_called_with("No process found listening on port 51000.")
@patch("subprocess.run")
def test_stop_server_process_killed(self, mock_run):
# Simulate lsof returning a process id
mock_run.side_effect = [
MagicMock(returncode=0, stdout="uvicorn 1234 user LISTEN\n"),
MagicMock(returncode=0), # for killing the process
MagicMock(returncode=1), # for checking the process after it is killed
]
with patch("builtins.print") as mock_print:
kill_process(port=51000, wait=True, timeout=5)
mock_print.assert_any_call("Killing model server process with PID 1234")
mock_print.assert_any_call("Process 1234 has been killed.")
@patch("subprocess.run")
def test_stop_server_multiple_pids(self, mock_run):
# Simulate lsof returning multiple process ids (e.g., 1234 and 5678)
mock_run.side_effect = [
MagicMock(
returncode=0,
stdout="uvicorn 1234 user LISTEN\nuvicorn 5678 user LISTEN\n",
), # lsof output
MagicMock(returncode=0), # first kill command for PID 1234
MagicMock(returncode=1), # PID 1234 is successfully terminated
MagicMock(returncode=0), # second kill command for PID 5678
MagicMock(returncode=1), # PID 5678 is successfully terminated
]
with patch("builtins.print") as mock_print:
kill_process(port=51000, wait=True, timeout=5)
# Assert that the function tried to kill both PIDs
mock_print.assert_any_call("Killing model server process with PID 1234")
mock_print.assert_any_call("Process 1234 has been killed.")
mock_print.assert_any_call("Killing model server process with PID 5678")
mock_print.assert_any_call("Process 5678 has been killed.")
if __name__ == "__main__":
unittest.main()

View file

@ -1,90 +0,0 @@
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
import app.commons.constants as const
from fastapi import Response
from app.function_calling.model_utils import (
process_messages,
chat_completion,
Message,
ChatMessage,
Choice,
ChatCompletionResponse,
)
def sample_messages():
# Ensure fields are explicitly set with valid data or empty values
return [
Message(role="user", content="Hello!", tool_calls=[], tool_call_id=""),
Message(
role="assistant",
content="",
tool_calls=[{"function": {"name": "sample_tool"}}],
tool_call_id="sample_id",
),
Message(
role="tool", content="Response from tool", tool_calls=[], tool_call_id=""
),
]
def sample_request(sample_messages):
return ChatMessage(
messages=sample_messages,
tools=[{"name": "sample_tool", "description": "A sample tool"}],
)
@patch("app.commons.constants.arch_function_hanlder")
def test_process_messages(mock_hanlder):
messages = sample_messages()
processed = process_messages(messages)
assert len(processed) == 3
assert processed[0] == {"role": "user", "content": "Hello!"}
assert processed[1] == {
"role": "assistant",
"content": '<tool_call>\n{"name": "sample_tool"}\n</tool_call>',
}
assert processed[2] == {
"role": "user",
"content": "<tool_response>\nResponse from tool\n</tool_response>",
}
@patch("app.commons.constants.arch_function_client")
@patch("app.commons.constants.arch_function_hanlder")
@pytest.mark.asyncio
async def test_chat_completion(mock_hanlder, mock_client):
# Mock the model list return for client
mock_client.models.list.return_value = MagicMock(
data=[MagicMock(id="sample_model")]
)
request = sample_request(sample_messages())
# Simulate stream response as list of tokens
mock_response = AsyncMock()
mock_response.__aiter__.return_value = [
MagicMock(choices=[MagicMock(delta=MagicMock(content="Hi there!"))]),
MagicMock(choices=[MagicMock(delta=MagicMock(content=""))]), # end of stream
]
mock_client.chat.completions.create.return_value = mock_response
# Mock the tool formatter
mock_hanlder._format_system.return_value = "<formatted_tools>"
response = Response()
chat_response = await chat_completion(request, response)
assert isinstance(chat_response, ChatCompletionResponse)
assert chat_response.choices[0].message.content is not None
first_call_args = mock_client.chat.completions.create.call_args_list[0][1]
assert first_call_args["stream"] == True
assert "model" in first_call_args
assert first_call_args["messages"][0]["content"] == "<formatted_tools>"
# Check that the arguments for the second call to 'create' include the pre-fill completion
second_call_args = mock_client.chat.completions.create.call_args_list[1][1]
assert second_call_args["stream"] == False
assert "model" in second_call_args
assert second_call_args["messages"][-1]["content"] in const.PREFILL_LIST

View file

@ -1,148 +0,0 @@
import json
from app.function_calling.hallucination_handler import HallucinationStateHandler
import pytest
import os
# Get the directory of the current file
current_dir = os.path.dirname(__file__)
# Construct the full path to the JSON file
json_file_path = os.path.join(current_dir, "test_cases.json")
with open(json_file_path) as f:
test_cases = json.load(f)
get_weather_api = {
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get current weather at a location.",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "str",
"description": "The location to get the weather for",
"format": "City, State",
},
"unit": {
"type": "str",
"description": "The unit to return the weather in.",
"enum": ["celsius", "fahrenheit"],
"default": "celsius",
},
"days": {
"type": "str",
"description": "the number of days for the request.",
},
},
"required": ["location", "days"],
},
},
}
function_description = get_weather_api["function"]
if type(function_description) != list:
function_description = [get_weather_api["function"]]
@pytest.mark.parametrize("case", test_cases)
def test_hallucination(case):
state = HallucinationStateHandler(
response_iterator=None, function=function_description
)
for token, logprob in zip(case["tokens"], case["logprobs"]):
if token != "</tool_call>":
state.append_and_check_token_hallucination(token, logprob)
if state.hallucination:
break
assert state.hallucination == case["expect"]
@pytest.mark.parametrize("is_hallucinate_sample", [True, False])
def test_hallucination_prompt(is_hallucinate_sample):
TASK_PROMPT = """
You are a helpful assistant.
""".strip()
TOOL_PROMPT = """
# Tools
You may call one or more functions to assist with the user query.
You are provided with function signatures within <tools></tools> XML tags:
<tools>
{tool_text}
</tools>
""".strip()
FORMAT_PROMPT = """
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
<tool_call>
{"name": <function-name>, "arguments": <args-json-object>}
</tool_call>
""".strip()
def convert_tools(tools):
return "\n".join([json.dumps(tool) for tool in tools])
def format_prompt(tools):
tool_text = convert_tools(tools)
return (
TASK_PROMPT
+ "\n\n"
+ TOOL_PROMPT.format(tool_text=tool_text)
+ "\n\n"
+ FORMAT_PROMPT
+ "\n"
)
openai_format_tools = [get_weather_api]
system_prompt = format_prompt(openai_format_tools)
from openai import OpenAI
client = OpenAI(base_url="https://api.fc.archgw.com/v1", api_key="EMPTY")
# List models API
model = client.models.list().data[0].id
assert model == "Arch-Function"
if not is_hallucinate_sample:
messages = [
{"role": "system", "content": system_prompt},
# {"role": "user", "content": "can you help me check weather?"},
{"role": "user", "content": "How is the weather in Seattle in 7 days?"},
# {"role": "assistant", "content": "Of course!"},
# {"role": "user", "content": "Seattle please"}
]
else:
messages = [
{"role": "system", "content": system_prompt},
# {"role": "user", "content": "can you help me check weather?"},
{"role": "user", "content": "How is the weather in Seattle in days?"},
# {"role": "assistant", "content": "Of course!"},
# {"role": "user", "content": "Seattle please"}
]
extra_body = {
"temperature": 0.6,
"top_p": 1.0,
"top_k": 50,
# "continue_final_message": True,
# "add_generation_prompt": False,
"logprobs": True,
"top_logprobs": 10,
}
resp = client.chat.completions.create(
model="Arch-Function", messages=messages, extra_body=extra_body, stream=True
)
hallu = HallucinationStateHandler(
response_iterator=resp, function=function_description
)
for token in hallu:
assert len(hallu.tokens) >= 0
assert hallu.hallucination == is_hallucinate_sample

View file

@ -1,102 +0,0 @@
import os
import pytest
from unittest.mock import patch, MagicMock
import app.commons.globals as glb
from app.loader import get_embedding_model, get_zero_shot_model, get_prompt_guard
# Mock constants
glb.DEVICE = "cpu" # Adjust as needed for your test case
arch_guard_model_type = {
"cpu": "katanemo/Arch-Guard-cpu",
"cuda": "katanemo/Arch-Guard",
"mps": "katanemo/Arch-Guard",
}
@pytest.fixture
def mock_env():
# Mock environment variables
os.environ["MODELS"] = "katanemo/bge-large-en-v1.5"
os.environ["ZERO_SHOT_MODELS"] = "katanemo/bart-large-mnli"
# Test for get_embedding_model function
@patch("app.loader.ORTModelForFeatureExtraction.from_pretrained")
@patch("app.loader.AutoModel.from_pretrained")
@patch("app.loader.AutoTokenizer.from_pretrained")
def test_get_embedding_model(mock_tokenizer, mock_automodel, mock_ort_model, mock_env):
mock_automodel.return_value = MagicMock()
mock_ort_model.return_value = MagicMock()
mock_tokenizer.return_value = MagicMock()
embedding_model = get_embedding_model()
# Assertions
assert embedding_model["model_name"] == "katanemo/bge-large-en-v1.5"
mock_tokenizer.assert_called_once_with(
"katanemo/bge-large-en-v1.5", trust_remote_code=True
)
if glb.DEVICE != "cuda":
mock_ort_model.assert_called_once_with(
"katanemo/bge-large-en-v1.5", file_name="onnx/model.onnx"
)
else:
mock_automodel.assert_called_once_with(
"katanemo/bge-large-en-v1.5", device_map=glb.DEVICE
)
# Test for get_zero_shot_model function
@patch("app.loader.ORTModelForSequenceClassification.from_pretrained")
@patch("app.loader.pipeline")
@patch("app.loader.AutoTokenizer.from_pretrained")
def test_get_zero_shot_model(mock_tokenizer, mock_pipeline, mock_ort_model, mock_env):
mock_pipeline.return_value = MagicMock()
mock_ort_model.return_value = MagicMock()
mock_tokenizer.return_value = MagicMock()
zero_shot_model = get_zero_shot_model()
# Assertions
assert zero_shot_model["model_name"] == "katanemo/bart-large-mnli"
mock_tokenizer.assert_called_once_with("katanemo/bart-large-mnli")
if glb.DEVICE != "cuda":
mock_ort_model.assert_called_once_with(
"katanemo/bart-large-mnli", file_name="onnx/model.onnx"
)
else:
assert mock_pipeline.called_once()
# Test for get_prompt_guard function
@patch("app.loader.AutoTokenizer.from_pretrained")
@patch("app.loader.OVModelForSequenceClassification.from_pretrained")
@patch("app.loader.AutoModelForSequenceClassification.from_pretrained")
def test_get_prompt_guard(mock_auto_model, mock_ov_model, mock_tokenizer):
# Mock model based on device
if glb.DEVICE == "cpu":
mock_ov_model.return_value = MagicMock()
else:
mock_auto_model.return_value = MagicMock()
mock_tokenizer.return_value = MagicMock()
prompt_guard = get_prompt_guard(arch_guard_model_type[glb.DEVICE])
# Assertions
assert prompt_guard["model_name"] == arch_guard_model_type[glb.DEVICE]
mock_tokenizer.assert_called_once_with(
arch_guard_model_type[glb.DEVICE], trust_remote_code=True
)
if glb.DEVICE == "cpu":
mock_ov_model.assert_called_once_with(
arch_guard_model_type[glb.DEVICE],
device_map=glb.DEVICE,
low_cpu_mem_usage=True,
)
else:
mock_auto_model.assert_called_once_with(
arch_guard_model_type[glb.DEVICE],
device_map=glb.DEVICE,
low_cpu_mem_usage=True,
)

View file

@ -1,102 +0,0 @@
import os
import pytest
from unittest.mock import patch, MagicMock
import app.commons.globals as glb
from app.loader import get_embedding_model, get_zero_shot_model, get_prompt_guard
# Mock constants
glb.DEVICE = "cuda" # Adjust as needed for your test case
arch_guard_model_type = {
"cpu": "katanemo/Arch-Guard-cpu",
"cuda": "katanemo/Arch-Guard",
"mps": "katanemo/Arch-Guard",
}
@pytest.fixture
def mock_env():
# Mock environment variables
os.environ["MODELS"] = "katanemo/bge-large-en-v1.5"
os.environ["ZERO_SHOT_MODELS"] = "katanemo/bart-large-mnli"
# Test for get_embedding_model function
@patch("app.loader.ORTModelForFeatureExtraction.from_pretrained")
@patch("app.loader.AutoModel.from_pretrained")
@patch("app.loader.AutoTokenizer.from_pretrained")
def test_get_embedding_model(mock_tokenizer, mock_automodel, mock_ort_model, mock_env):
mock_automodel.return_value = MagicMock()
mock_ort_model.return_value = MagicMock()
mock_tokenizer.return_value = MagicMock()
embedding_model = get_embedding_model()
# Assertions
assert embedding_model["model_name"] == "katanemo/bge-large-en-v1.5"
mock_tokenizer.assert_called_once_with(
"katanemo/bge-large-en-v1.5", trust_remote_code=True
)
if glb.DEVICE != "cuda":
mock_ort_model.assert_called_once_with(
"katanemo/bge-large-en-v1.5", file_name="onnx/model.onnx"
)
else:
mock_automodel.assert_called_once_with(
"katanemo/bge-large-en-v1.5", device_map=glb.DEVICE
)
# Test for get_zero_shot_model function
@patch("app.loader.ORTModelForSequenceClassification.from_pretrained")
@patch("app.loader.pipeline")
@patch("app.loader.AutoTokenizer.from_pretrained")
def test_get_zero_shot_model(mock_tokenizer, mock_pipeline, mock_ort_model, mock_env):
mock_pipeline.return_value = MagicMock()
mock_ort_model.return_value = MagicMock()
mock_tokenizer.return_value = MagicMock()
zero_shot_model = get_zero_shot_model()
# Assertions
assert zero_shot_model["model_name"] == "katanemo/bart-large-mnli"
mock_tokenizer.assert_called_once_with("katanemo/bart-large-mnli")
if glb.DEVICE != "cuda":
mock_ort_model.assert_called_once_with(
"katanemo/bart-large-mnli", file_name="onnx/model.onnx"
)
else:
assert mock_pipeline.called_once()
# Test for get_prompt_guard function
@patch("app.loader.AutoTokenizer.from_pretrained")
@patch("app.loader.OVModelForSequenceClassification.from_pretrained")
@patch("app.loader.AutoModelForSequenceClassification.from_pretrained")
def test_get_prompt_guard(mock_auto_model, mock_ov_model, mock_tokenizer):
# Mock model based on device
if glb.DEVICE == "cpu":
mock_ov_model.return_value = MagicMock()
else:
mock_auto_model.return_value = MagicMock()
mock_tokenizer.return_value = MagicMock()
prompt_guard = get_prompt_guard(arch_guard_model_type[glb.DEVICE])
# Assertions
assert prompt_guard["model_name"] == arch_guard_model_type[glb.DEVICE]
mock_tokenizer.assert_called_once_with(
arch_guard_model_type[glb.DEVICE], trust_remote_code=True
)
if glb.DEVICE == "cpu":
mock_ov_model.assert_called_once_with(
arch_guard_model_type[glb.DEVICE],
device_map=glb.DEVICE,
low_cpu_mem_usage=True,
)
else:
mock_auto_model.assert_called_once_with(
arch_guard_model_type[glb.DEVICE],
device_map=glb.DEVICE,
low_cpu_mem_usage=True,
)

View file

@ -1,102 +0,0 @@
import os
import pytest
from unittest.mock import patch, MagicMock
import app.commons.globals as glb
from app.loader import get_embedding_model, get_zero_shot_model, get_prompt_guard
# Mock constants
glb.DEVICE = "mps" # Adjust as needed for your test case
arch_guard_model_type = {
"cpu": "katanemo/Arch-Guard-cpu",
"cuda": "katanemo/Arch-Guard",
"mps": "katanemo/Arch-Guard",
}
@pytest.fixture
def mock_env():
# Mock environment variables
os.environ["MODELS"] = "katanemo/bge-large-en-v1.5"
os.environ["ZERO_SHOT_MODELS"] = "katanemo/bart-large-mnli"
# Test for get_embedding_model function
@patch("app.loader.ORTModelForFeatureExtraction.from_pretrained")
@patch("app.loader.AutoModel.from_pretrained")
@patch("app.loader.AutoTokenizer.from_pretrained")
def test_get_embedding_model(mock_tokenizer, mock_automodel, mock_ort_model, mock_env):
mock_automodel.return_value = MagicMock()
mock_ort_model.return_value = MagicMock()
mock_tokenizer.return_value = MagicMock()
embedding_model = get_embedding_model()
# Assertions
assert embedding_model["model_name"] == "katanemo/bge-large-en-v1.5"
mock_tokenizer.assert_called_once_with(
"katanemo/bge-large-en-v1.5", trust_remote_code=True
)
if glb.DEVICE != "cuda":
mock_ort_model.assert_called_once_with(
"katanemo/bge-large-en-v1.5", file_name="onnx/model.onnx"
)
else:
mock_automodel.assert_called_once_with(
"katanemo/bge-large-en-v1.5", device_map=glb.DEVICE
)
# Test for get_zero_shot_model function
@patch("app.loader.ORTModelForSequenceClassification.from_pretrained")
@patch("app.loader.pipeline")
@patch("app.loader.AutoTokenizer.from_pretrained")
def test_get_zero_shot_model(mock_tokenizer, mock_pipeline, mock_ort_model, mock_env):
mock_pipeline.return_value = MagicMock()
mock_ort_model.return_value = MagicMock()
mock_tokenizer.return_value = MagicMock()
zero_shot_model = get_zero_shot_model()
# Assertions
assert zero_shot_model["model_name"] == "katanemo/bart-large-mnli"
mock_tokenizer.assert_called_once_with("katanemo/bart-large-mnli")
if glb.DEVICE != "cuda":
mock_ort_model.assert_called_once_with(
"katanemo/bart-large-mnli", file_name="onnx/model.onnx"
)
else:
assert mock_pipeline.called_once()
# Test for get_prompt_guard function
@patch("app.loader.AutoTokenizer.from_pretrained")
@patch("app.loader.OVModelForSequenceClassification.from_pretrained")
@patch("app.loader.AutoModelForSequenceClassification.from_pretrained")
def test_get_prompt_guard(mock_auto_model, mock_ov_model, mock_tokenizer):
# Mock model based on device
if glb.DEVICE == "cpu":
mock_ov_model.return_value = MagicMock()
else:
mock_auto_model.return_value = MagicMock()
mock_tokenizer.return_value = MagicMock()
prompt_guard = get_prompt_guard(arch_guard_model_type[glb.DEVICE])
# Assertions
assert prompt_guard["model_name"] == arch_guard_model_type[glb.DEVICE]
mock_tokenizer.assert_called_once_with(
arch_guard_model_type[glb.DEVICE], trust_remote_code=True
)
if glb.DEVICE == "cpu":
mock_ov_model.assert_called_once_with(
arch_guard_model_type[glb.DEVICE],
device_map=glb.DEVICE,
low_cpu_mem_usage=True,
)
else:
mock_auto_model.assert_called_once_with(
arch_guard_model_type[glb.DEVICE],
device_map=glb.DEVICE,
low_cpu_mem_usage=True,
)

View file

@ -1,66 +0,0 @@
from typing import List
import pytest
import json
from app.function_calling.model_utils import Message, process_messages
test_input_history = """
[
{
"role": "user",
"content": "how is the weather in chicago for next 5 days?"
},
{
"role": "assistant",
"model": "Arch-Function-1.5B",
"tool_calls": [
{
"id": "call_3394",
"type": "function",
"function": {
"name": "weather_forecast",
"arguments": { "city": "Chicago", "days": 5 }
}
}
]
},
{
"role": "tool",
"content": "--",
"tool_call_id": "call_3394"
},
{
"role": "assistant",
"content": "--",
"model": "gpt-3.5-turbo-0125"
},
{
"role": "user",
"content": "how is the weather in chicago for next 5 days?"
},
{
"role": "assistant",
"tool_calls": [
{
"id": "call_5306",
"type": "function",
"function": {
"name": "weather_forecast",
"arguments": { "city": "Chicago", "days": 5 }
}
}
]
}
]
"""
def test_update_fc_history():
history = json.loads(test_input_history)
message_history = []
for h in history:
message_history.append(Message(**h))
updated_history = process_messages(message_history)
assert len(updated_history) == 6
# ensure that tool role does not exist anymore
assert all([h["role"] != "tool" for h in updated_history])

2583
model_server/poetry.lock generated

File diff suppressed because it is too large Load diff

View file

@ -1,34 +1,26 @@
[tool.poetry]
name = "archgw_modelserver"
version = "0.1.6"
version = "0.1.7"
description = "A model server for serving models"
authors = ["Katanemo Labs, Inc <archgw@katanemo.com>"]
authors = ["Katanemo Labs, Inc <info@katanemo.com>"]
license = "Apache 2.0"
readme = "README.md"
packages = [
{ include = "app" }, # Include the 'app' package
{ include = "app/function_calling" }, # Include the 'app' package
{ include = "src" }
]
include = ["app/*.yaml"]
[tool.poetry.dependencies]
python = ">=3.12"
python = "^3.12"
fastapi = "0.115.0"
sentence-transformers = "3.1.1"
torch = "2.4.1"
uvicorn = "0.31.0"
transformers = "*"
pyyaml = "6.0.2"
accelerate = "*"
psutil = "6.0.0"
optimum-intel = "*"
openvino = "2024.4.0"
pandas = "*"
dateparser = "*"
openai = "1.50.2"
tf-keras = "*"
onnx = "1.17.0"
onnxruntime = "1.19.2"
httpx = "0.27.2" # https://community.openai.com/t/typeerror-asyncclient-init-got-an-unexpected-keyword-argument-proxies/1040287
pytest-asyncio = "*"
pytest = "*"
@ -36,10 +28,20 @@ opentelemetry-api = "^1.28.0"
opentelemetry-sdk = "^1.28.0"
opentelemetry-exporter-otlp = "^1.28.0"
opentelemetry-instrumentation-fastapi = "^0.49b0"
overrides = "^7.7.0"
pytest-retry = "^1.6.3"
pytest-httpserver = "^1.1.0"
[tool.poetry.scripts]
archgw_modelserver = "app.cli:run_server"
archgw_modelserver = "src.cli:run_server"
[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"
[tool.pytest.ini_options]
python_files = ["test*.py"]
addopts = ["-v", "-s"]
retries = 2
retry_delay = 0.5
cumulative_timing = false

214
model_server/src/cli.py Normal file
View file

@ -0,0 +1,214 @@
import importlib
import logging
from os import path
import os
from signal import SIGKILL
import sys
import subprocess
import argparse
import tempfile
import time
import requests
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
def get_version():
try:
version = importlib.metadata.version("archgw_modelserver")
return version
except importlib.metadata.PackageNotFoundError:
return "version not found"
def wait_for_health_check(url, timeout=300):
"""Wait for the Uvicorn server to respond to health-check requests."""
start_time = time.time()
while time.time() - start_time < timeout:
try:
response = requests.get(url)
if response.status_code == 200:
return True
except requests.ConnectionError:
time.sleep(1)
return False
def parse_args():
parser = argparse.ArgumentParser(description="Manage the Uvicorn server.")
parser.add_argument(
"action",
choices=["start", "stop", "restart"],
default="start",
nargs="?",
help="Action to perform on the server (default: start).",
)
parser.add_argument(
"--port",
type=int,
default=51000,
help="Port number for the server (default: 51000).",
)
parser.add_argument(
"--foreground",
default=False,
action="store_true",
help="Run the server in the foreground (default: False).",
)
return parser.parse_args()
def get_pid_file():
temp_dir = tempfile.gettempdir()
return path.join(temp_dir, "model_server.pid")
def stop_server():
"""Stop the Uvicorn server."""
pid_file = get_pid_file()
if os.path.exists(pid_file):
logger.info(f"PID file found, shutting down the server.")
# read pid from file
with open(pid_file, "r") as f:
pid = int(f.read())
logger.info(f"Killing model server {pid}")
try:
os.kill(pid, SIGKILL)
except ProcessLookupError:
logger.info(f"Process {pid} not found")
os.remove(pid_file)
else:
logger.info("No PID file found, server is not running.")
def restart_server(port=51000, foreground=False):
"""Restart the Uvicorn server."""
stop_server()
start_server(port, foreground)
def run_server():
"""Start, stop, or restart the Uvicorn server based on command-line arguments."""
args = parse_args()
action = args.action
if action == "start":
start_server(args.port, args.foreground)
elif action == "stop":
stop_server()
elif action == "restart":
restart_server(args.port, args.foreground)
else:
logger.info(f"Unknown action: {action}")
sys.exit(1)
def ensure_killed(process):
process.terminate()
# if the process is not terminated, kill it
now = time.time()
# wait for 5 seconds
while time.time() - now < 5:
if process.poll() is not None:
break
time.sleep(1)
if process.poll() is None:
logger.info("Killing model server")
process.kill()
def start_server(port=51000, foreground=False):
"""Start the Uvicorn server."""
logging.info("model server version: %s", get_version())
stop_server()
logger.info(
"starting model server, port: %s, foreground: %s. Please wait ...",
port,
foreground,
)
if foreground:
process = subprocess.Popen(
[
"python",
"-m",
"uvicorn",
"src.main:app",
"--host",
"0.0.0.0",
"--port",
str(port),
],
)
else:
process = subprocess.Popen(
[
"python",
"-m",
"uvicorn",
"src.main:app",
"--host",
"0.0.0.0",
"--port",
str(port),
],
stderr=subprocess.PIPE,
stdout=subprocess.PIPE,
)
try:
if wait_for_health_check(f"http://0.0.0.0:{port}/healthz"):
logger.info(
f"model server health check passed, port {port}, pid: {process.pid}"
)
else:
logger.error("health check failed, shutting it down.")
process.terminate()
except KeyboardInterrupt:
logger.info("model server stopped by user during initialization.")
ensure_killed(process)
# write process id to temp file in temp folder
pid_file = get_pid_file()
logger.info(f"writing pid {process.pid} to {pid_file}")
with open(pid_file, "w") as f:
f.write(str(process.pid))
if foreground:
try:
process.wait()
except KeyboardInterrupt:
logger.info("model server stopped by user.")
ensure_killed(process)
def main():
"""
Start, stop, or restart the Uvicorn server based on command-line arguments.
"""
args = parse_args()
if args.action == "start":
start_server(args.port, args.foreground)
elif args.action == "stop":
stop_server()
elif args.action == "restart":
restart_server(args.port)
else:
logger.error(f"Unknown action: {args.action}")
sys.exit(1)

View file

@ -0,0 +1,38 @@
import os
from openai import OpenAI
from src.commons.utils import get_model_server_logger
from src.core.guardrails import get_guardrail_handler
from src.core.function_calling import (
ArchIntentConfig,
ArchIntentHandler,
ArchFunctionConfig,
ArchFunctionHandler,
)
# Define logger
logger = get_model_server_logger()
# Define the client
ARCH_ENDPOINT = os.getenv("ARCH_ENDPOINT", "https://api.fc.archgw.com/v1")
ARCH_API_KEY = "EMPTY"
ARCH_CLIENT = OpenAI(base_url=ARCH_ENDPOINT, api_key=ARCH_API_KEY)
# Define model names
ARCH_INTENT_MODEL_ALIAS = "Arch-Intent"
ARCH_FUNCTION_MODEL_ALIAS = "Arch-Function"
logger.info("loading prompt guard model ...")
arch_guard_model = get_guardrail_handler()
# Define model handlers
handler_map = {
"Arch-Intent": ArchIntentHandler(
ARCH_CLIENT, ARCH_INTENT_MODEL_ALIAS, ArchIntentConfig
),
"Arch-Function": ArchFunctionHandler(
ARCH_CLIENT, ARCH_FUNCTION_MODEL_ALIAS, ArchFunctionConfig
),
"Arch-Guard": arch_guard_model,
}

View file

@ -0,0 +1,87 @@
import os
import sys
import time
import logging
import requests
import subprocess
import importlib
PROJ_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# Default log directory and file
DEFAULT_LOG_DIR = os.path.join(PROJ_DIR, ".logs")
DEFAULT_LOG_FILE = "modelserver.log"
def get_model_server_logger(log_dir=None, log_file=None):
"""
Get or initialize the logger instance for the model server.
Parameters:
- log_dir (str): Custom directory to store the log file. Defaults to `./.logs`.
- log_file (str): Custom log file name. Defaults to `modelserver.log`.
Returns:
- logging.Logger: Configured logger instance.
"""
log_dir = log_dir or DEFAULT_LOG_DIR
log_file = log_file or DEFAULT_LOG_FILE
log_file_path = os.path.join(log_dir, log_file)
# Check if the logger is already configured
logger = logging.getLogger("model_server_logger")
if logger.hasHandlers():
# Return existing logger instance if already configured
return logger
# Ensure the log directory exists, create it if necessary
try:
# Create directory if it doesn't exist
os.makedirs(log_dir, exist_ok=True)
# Check for write permissions
if not os.access(log_dir, os.W_OK):
raise PermissionError(f"No write permission for the directory: {log_dir}")
except (PermissionError, OSError) as e:
raise RuntimeError(f"Failed to initialize logger: {e}")
# Configure logging to file
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(levelname)s - %(message)s",
handlers=[
# logging.FileHandler(log_file_path, mode="w"), # Overwrite logs in the file
logging.StreamHandler(), # Also log to console
],
)
return logger
logger = get_model_server_logger()
logging.info("initializing torch device ...")
import torch
def get_device():
available_device = {
"cpu": True,
"cuda": torch.cuda.is_available(),
"mps": (
torch.backends.mps.is_available()
if hasattr(torch.backends, "mps")
else False
),
}
if available_device["cuda"]:
device = "cuda"
elif available_device["mps"]:
device = "mps"
else:
device = "cpu"
return device

View file

@ -0,0 +1,644 @@
import json
import random
import builtins
import textwrap
from openai import OpenAI
from typing import Any, Dict, List
from overrides import override
from src.commons.utils import get_model_server_logger
from src.core.model_utils import (
Message,
ChatMessage,
Choice,
ChatCompletionResponse,
ArchBaseHandler,
)
from src.core.hallucination import HallucinationStateHandler
logger = get_model_server_logger()
class ArchIntentConfig:
TASK_PROMPT = textwrap.dedent(
"""
You are a helpful assistant.
"""
).strip()
TOOL_PROMPT_TEMPLATE = textwrap.dedent(
"""
You task is to check if there are any tools that can be used to help the last user message in conversations according to the available tools listed below.
<tools>
{tool_text}
</tools>
"""
).strip()
FORMAT_PROMPT = textwrap.dedent(
"""
Provide your tool assessment for ONLY THE LAST USER MESSAGE in the above conversation:
- First line must read 'Yes' or 'No'.
- If yes, a second line must include a comma-separated list of tool indexes.
"""
).strip()
EXTRA_INSTRUCTION = "Are there any tools can help?"
GENERATION_PARAMS = {
"temperature": 0.01,
"max_tokens": 1,
"stop_token_ids": [151645],
}
class ArchIntentHandler(ArchBaseHandler):
def __init__(self, client: OpenAI, model_name: str, config: ArchIntentConfig):
"""
Initializes the intent handler.
Args:
client (OpenAI): An OpenAI client instance.
model_name (str): Name of the model to use.
config (ArchIntentConfig): The configuration for Arch-Intent.
"""
super().__init__(
client,
model_name,
config.TASK_PROMPT,
config.TOOL_PROMPT_TEMPLATE,
config.FORMAT_PROMPT,
config.GENERATION_PARAMS,
)
self.extra_instruction = config.EXTRA_INSTRUCTION
self.prompt_prefilling = False
@override
def _convert_tools(self, tools: List[Dict[str, Any]]) -> str:
"""
Converts a list of tools into a JSON-like format with indexed keys.
Args:
tools (List[Dict[str, Any]]): A list of tools represented as dictionaries.
Returns:
str: A string representation of converted tools.
"""
converted = [
json.dumps({"index": f"T{idx}"} | tool) for idx, tool in enumerate(tools)
]
return "\n".join(converted)
def detect_intent(self, content: str) -> bool:
"""
Detect if any intent match with prompts
Args:
content: str: Model response that contains intent detection results
Returns:
bool: A boolean value to indicate if any intent match with prompts or not
"""
if hasattr(content.choices[0].message, "content"):
return content.choices[0].message.content == "Yes"
else:
return False
@override
async def chat_completion(self, req: ChatMessage) -> ChatCompletionResponse:
"""
Generates a chat completion for a given request.
Args:
req (ChatMessage): A chat message request object.
Returns:
ChatCompletionResponse: The model's response to the chat request.
Note:
Currently only support vllm inference
"""
# In the case that no tools are available, simply return `No` to avoid making a call
if len(req.tools) == 0:
model_response = Message(content="No", tool_calls=[])
else:
messages = self._process_messages(
req.messages, req.tools, self.extra_instruction
)
model_response = self.client.chat.completions.create(
messages=messages,
model=self.model_name,
stream=False,
extra_body=self.generation_params,
)
logger.info(
"arch_intent response: %s", json.dumps(model_response.model_dump())
)
model_response = Message(
content=model_response.choices[0].message.content, tool_calls=[]
)
chat_completion_response = ChatCompletionResponse(
choices=[Choice(message=model_response)], model=self.model_name
)
return chat_completion_response
# =============================================================================================================
class ArchFunctionConfig:
TASK_PROMPT = textwrap.dedent(
"""
You are a helpful assistant.
"""
).strip()
TOOL_PROMPT_TEMPLATE = textwrap.dedent(
"""
# Tools
You may call one or more functions to assist with the user query.
You are provided with function signatures within <tools></tools> XML tags:
<tools>
{tool_text}
</tools>
"""
).strip()
FORMAT_PROMPT = textwrap.dedent(
"""
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
<tool_call>
{"name": <function-name>, "arguments": <args-json-object>}
</tool_call>
"""
).strip()
GENERATION_PARAMS = {
"temperature": 0.6,
"top_p": 1.0,
"top_k": 10,
"max_tokens": 512,
"stop_token_ids": [151645],
"logprobs": True,
"top_logprobs": 10,
}
PREFILL_CONFIG = {
"prefill_params": {
"continue_final_message": True,
"add_generation_prompt": False,
},
"prefill_prefix": [
"May",
"Could",
"Sure",
"Definitely",
"Certainly",
"Of course",
"Can",
],
}
SUPPORT_DATA_TYPES = ["int", "float", "bool", "str", "list", "tuple", "set", "dict"]
class ArchFunctionHandler(ArchBaseHandler):
def __init__(
self,
client: OpenAI,
model_name: str,
config: ArchFunctionConfig,
):
"""
Initializes the function handler.
Args:
client (OpenAI): An OpenAI client instance.
model_name (str): Name of the model to use.
config (ArchFunctionConfig): The configuration for Arch-Function
"""
super().__init__(
client,
model_name,
config.TASK_PROMPT,
config.TOOL_PROMPT_TEMPLATE,
config.FORMAT_PROMPT,
config.GENERATION_PARAMS,
)
self.prefill_params = config.PREFILL_CONFIG["prefill_params"]
self.prefill_prefix = config.PREFILL_CONFIG["prefill_prefix"]
self.prompt_prefilling = False
# Predefine data types for verification. Only support Python for now.
# [TODO] Extend the list of support data types
self.support_data_types = {
type_name: getattr(builtins, type_name)
for type_name in config.SUPPORT_DATA_TYPES
}
@override
def _convert_tools(self, tools: List[Dict[str, Any]]) -> str:
"""
Converts a list of tools into JSON format.
Args:
tools (List[Dict[str, Any]]): A list of tools represented as dictionaries.
Returns:
str: A string representation of converted tools.
"""
converted = [json.dumps(tool) for tool in tools]
return "\n".join(converted)
def _fix_json_string(self, json_str: str) -> str:
"""
Fixes malformed JSON strings by ensuring proper bracket matching.
Args:
json_str (str): A JSON string that might be malformed.
Returns:
str: A corrected JSON string.
"""
# Remove any leading or trailing whitespace or newline characters
json_str = json_str.strip()
# Stack to keep track of brackets
stack = []
# Clean string to collect valid characters
fixed_str = ""
# Dictionary for matching brackets
matching_bracket = {")": "(", "}": "{", "]": "["}
# Dictionary for the opposite of matching_bracket
opening_bracket = {v: k for k, v in matching_bracket.items()}
for char in json_str:
if char in "{[(":
stack.append(char)
fixed_str += char
elif char in "}])":
if stack and stack[-1] == matching_bracket[char]:
stack.pop()
fixed_str += char
else:
# Ignore the unmatched closing brackets
continue
else:
fixed_str += char
# If there are unmatched opening brackets left in the stack, add corresponding closing brackets
while stack:
unmatched_opening = stack.pop()
fixed_str += opening_bracket[unmatched_opening]
# Attempt to parse the corrected string to ensure its valid JSON
return fixed_str.replace("'", '"')
def _extract_tool_calls(self, content: str) -> Dict[str, any]:
"""
Extracts tool call information from a given string.
Args:
content (str): The content string containing potential tool call information.
Returns:
Dict: A dictionary of extraction, including:
- "result": A list of tool call dictionaries.
- "status": A boolean indicating if the extraction was valid.
- "message": An error message or exception if extraction failed.
"""
tool_calls, is_valid, error_message = [], True, ""
flag = False
for line in content.split("\n"):
if not is_valid:
break
if "<tool_call>" == line:
flag = True
elif "</tool_call>" == line:
flag = False
else:
if flag:
try:
tool_content = json.loads(line)
except Exception as e:
fixed_content = self._fix_json_string(line)
try:
tool_content = json.loads(fixed_content)
except Exception:
tool_calls, is_valid, error_message = [], False, e
break
tool_calls.append(
{
"id": f"call_{random.randint(1000, 10000)}",
"type": "function",
"function": {
"name": tool_content["name"],
"arguments": tool_content["arguments"],
},
}
)
flag = False
return {"result": tool_calls, "status": is_valid, "message": error_message}
def _correcting_type(self, value, target_type):
try:
if target_type == float and isinstance(value, int):
return float(value)
elif target_type == list and isinstance(value, str):
return ast.literal_eval(value)
elif target_type == str and not isinstance(value, str):
return str(value)
# Add more conversion rules as needed
except (ValueError, TypeError, json.JSONDecodeError):
pass
return value
def _verify_tool_calls(
self, tools: List[Dict[str, Any]], tool_calls: List[Dict[str, Any]]
) -> Dict[str, any]:
"""
Verifies the validity of extracted tool calls against the provided tools.
Args:
tools (List[Dict[str, Any]]): A list of available tools.
tool_calls (List[Dict[str, Any]]): A list of tool calls to verify.
Returns:
Dict: A dictionary of verification, including:
- "status": A boolean indicating if the tool calls are valid.
- "invalid_tool_call": A dictionary of the invalid tool call if any.
- "message": An error message.
"""
is_valid, invalid_tool_call, error_message = True, None, ""
functions = {}
for tool in tools:
if tool["type"] == "function":
functions[tool["function"]["name"]] = tool["function"]["parameters"]
for tool_call in tool_calls:
if not is_valid:
break
func_name = tool_call["function"]["name"]
func_args = tool_call["function"]["arguments"]
# Check whether the function is available or not
if func_name not in functions:
is_valid = False
invalid_tool_call = tool_call
error_message = f"{func_name} is not defined!"
break
else:
# Check if all the requried parameters can be found in the tool calls
for required_param in functions[func_name].get("required", []):
if required_param not in func_args:
is_valid = False
invalid_tool_call = tool_call
error_message = f"`{required_param}` is requiried by the function `{func_name}` but not found in the tool call!"
break
# Verify the data type of each parameter in the tool calls
for param_name in func_args:
if param_name not in functions[func_name]["properties"]:
is_valid = False
invalid_tool_call = tool_call
error_message = f"Parameter `{param_name}` is not defined in the function `{func_name}`."
break
else:
param_value = func_args[param_name]
data_type = functions[func_name]["properties"][param_name][
"type"
]
if data_type in self.support_data_types:
if not isinstance(
param_value,
self.support_data_types[data_type],
) and not isinstance(
self._correcting_type(
param_value, self.support_data_types[data_type]
),
self.support_data_types[data_type],
):
is_valid = False
invalid_tool_call = tool_call
error_message = f"Parameter `{param_name}` is expected to have the data type `{self.support_data_types[data_type]}`, but got `{type(param_value)}`."
break
return {
"status": is_valid,
"invalid_tool_call": invalid_tool_call,
"message": error_message,
}
def _add_prefill_message(self, messages: List[Dict[str, str]]):
"""
Update messages and generation params for prompt prefilling
Args:
messages (List[Dict[str, str]]): A list of messages.
Returns:
prefill_messages (List[Dict[str, str]]): A list of messages.
"""
return messages + [
{
"role": "assistant",
"content": random.choice(self.prefill_prefix),
}
]
def _engage_parameter_gathering(self, messages: List[Dict[str, str]]):
"""
Engage parameter gathering for tool calls
"""
# TODO: log enaging parameter gathering
prefill_response = self.client.chat.completions.create(
messages=self._add_prefill_message(messages),
model=self.model_name,
extra_body={
**self.generation_params,
**self.prefill_params,
},
)
self.prompt_prefilling = True
return prefill_response
def _check_length_and_pop_messages(self, messages, max_tokens=4096):
"""
Trims the `messages` list to ensure the total token count does not exceed `max_tokens`.
Args:
messages (list): List of message dictionaries.
max_tokens (int): Maximum allowed token count.
Returns:
list: Trimmed list of messages.
"""
def estimate_token_length(messages):
"""Estimate the total token length of the messages."""
total_tokens = 0
for message in messages:
# Approximate token length: assuming ~4 characters per token on average
total_tokens += len(message["content"]) // 4
return total_tokens
# Calculate initial token length
total_tokens = estimate_token_length(messages)
# Trim messages if token count exceeds the limit
while total_tokens > max_tokens and len(messages) >= 3:
# Find the first non-system message pair
for i in range(len(messages)):
if messages[i]["role"] != "system":
# Remove the 'user'/'assistant' pair
if i + 1 < len(messages) and messages[i + 1]["role"] in [
"user",
"assistant",
]:
del messages[i : i + 2]
else:
del messages[i]
break
# Recalculate token length
total_tokens = estimate_token_length(messages)
return messages
@override
async def chat_completion(self, req: ChatMessage) -> ChatCompletionResponse:
"""
Generates a chat completion response for a given request.
Args:
req (ChatMessage): A chat message request object.
enable_prefilling (bool, optional): Whether to enable prefill responses. Defaults to True.
Returns:
ChatCompletionResponse: The model's response to the chat request.
Note:
Currently only support vllm inference
"""
logger.info(
f"model_server => arch_function: request body: {json.dumps(req.model_dump())}"
)
messages = self._process_messages(req.messages, req.tools)
messages = self._check_length_and_pop_messages(messages)
# always enable `stream=True` to collect model responses
response = self.client.chat.completions.create(
messages=messages,
model=self.model_name,
stream=True,
extra_body=self.generation_params,
)
# initialize the hallucination handler, which is an iterator
self.hallu_handler = HallucinationStateHandler(
response_iterator=response, function=req.tools
)
model_response, self.has_tool_call = "", None
self.hallucination = False
for _ in self.hallu_handler:
# check if the first token is <tool_call>
if len(self.hallu_handler.tokens) > 0 and self.has_tool_call is None:
if self.hallu_handler.tokens[0] == "<tool_call>":
self.has_tool_call = True
else:
self.has_tool_call = False
break
# if the model is hallucinating, start parameter gathering
if self.hallu_handler.hallucination is True:
self.hallucination = True
logger.info(
f"{self.hallu_handler.error_message} - start parameter gathering"
)
logger.info(
f"Hallucinated response : {''.join(self.hallu_handler.tokens)}"
)
# [TODO] - add break when hallucination is detected
break
if self.hallucination is True:
prefill_response = self._engage_parameter_gathering(messages)
model_response = prefill_response.choices[0].message.content
if self.has_tool_call and self.hallucination is False:
# [TODO] - Review: remove the following code
model_response = "".join(self.hallu_handler.tokens)
logger.info(f"Tool call found, no hallucination detected {model_response}!")
# start parameter gathering if the model is not generating tool calls
if self.has_tool_call is False:
# [TODO] - Review: remove the following code
logger.info("No tool call found, start parameter gathering")
prefill_response = self._engage_parameter_gathering(messages)
model_response = prefill_response.choices[0].message.content
# Extract tool calls from model response
extracted = self._extract_tool_calls(model_response)
if len(extracted["result"]) and extracted["status"]:
# [TODO] Review: define the behavior in the case that tool call extraction fails
# if not extracted["status"]:
verified = self._verify_tool_calls(
tools=req.tools, tool_calls=extracted["result"]
)
# [TODO] - Review: remvoe the following code
# print(f"[Verified] - {verified}")
# [TODO] Review: In the case that tool calls are invalid, define the protocol to collect debugging output and the behavior to handle it appropriately
if verified["status"]:
model_response = Message(content="", tool_calls=extracted["result"])
log_message = f"model_server <= arch_function: (tool_calls): {json.dumps([tool_call['function'] for tool_call in extracted['result']])}"
logger.info(log_message)
else:
raise ValueError(f"Invalid tool call: {verified['message']}")
else:
model_response = Message(content=model_response, tool_calls=[])
chat_completion_response = ChatCompletionResponse(
choices=[Choice(message=model_response)], model=self.model_name
)
# [TODO] Review: define the protocol to collect debugging output
logger.info(
f"model_server <= arch_function: response body: {json.dumps(chat_completion_response.model_dump())}"
)
return chat_completion_response

View file

@ -0,0 +1,171 @@
import time
import torch
import numpy as np
import src.commons.utils as utils
from transformers import AutoTokenizer
from src.core.model_utils import GuardRequest, GuardResponse
# from optimum.intel import OVModelForSequenceClassification
from transformers import AutoModelForSequenceClassification
class ArchGuardHanlder:
def __init__(self, model_dict):
"""
Initializes the ArchGuardHanlder with the given model dictionary.
Args:
model_dict (dict): A dictionary containing the model, tokenizer, and device information.
"""
self.model = model_dict["model"]
self.model_name = model_dict["model_name"]
self.tokenizer = model_dict["tokenizer"]
self.device = model_dict["device"]
self.support_tasks = {"jailbreak": {"positive_class": 2, "threshold": 0.5}}
def _split_text_into_chunks(self, text, max_num_words=300):
"""
Splits the input text into chunks of up to `max_num_words` words.
Args:
text (str): The input text to be split.
max_num_words (int, optional): The maximum number of words in each chunk. Defaults to 300.
Returns:
List[str]: A list of text chunks.
"""
words = text.split()
chunks = [
" ".join(words[i : i + max_num_words])
for i in range(0, len(words), max_num_words)
]
return chunks
@staticmethod
def softmax(x):
"""
Computes the softmax of the input array.
Args:
x (np.ndarray): The input array.
Returns:
np.ndarray: The softmax of the input.
"""
return np.exp(x) / np.exp(x).sum(axis=0)
def _predict_text(self, task, text, max_length=512) -> GuardResponse:
"""
Predicts the result for the provided text for a specific task.
Args:
task (str): The task to perform (e.g., "jailbreak").
text (str): The input text to classify.
max_length (int, optional): The maximum length for tokenization. Defaults to 512.
Returns:
GuardResponse: A GuardResponse object containing the prediction.
"""
inputs = self.tokenizer(
text, truncation=True, max_length=max_length, return_tensors="pt"
).to(self.device)
start_time = time.perf_counter()
with torch.no_grad():
logits = self.model(**inputs).logits.cpu().detach().numpy()[0]
prob = ArchGuardHanlder.softmax(logits)[
self.support_tasks[task]["positive_class"]
]
latency = time.perf_counter() - start_time
if prob > self.support_tasks[task]["threshold"]:
verdict = True
sentence = text
else:
verdict = False
sentence = None
return GuardResponse(
prob=[prob.item()], verdict=verdict, sentence=[sentence], latency=latency
)
def predict(self, req: GuardRequest, max_num_words=300) -> GuardResponse:
"""
Makes a prediction based on the GuardRequest input.
Args:
req (GuardRequest): The GuardRequest object containing the input text and task.
max_num_words (int, optional): The maximum number of words in each chunk if splitting is needed. Defaults to 300.
Returns:
GuardResponse: A GuardResponse object containing the prediction.
Note:
currently only support jailbreak check
"""
if req.task not in self.support_tasks:
raise NotImplementedError(f"{req.task} is not supported!")
if len(req.input.split()) < max_num_words:
return self._predict_text(req.task, req.input)
else:
# split into chunks if text is long
text_chunks = self._split_text_into_chunks(req.input)
prob, verdict, sentence, latency = [], False, [], 0
for chunk in text_chunks:
chunk_result = self._predict_text(req.task, chunk)
if chunk_result.verdict:
prob.append(chunk_result.prob[0])
verdict = True
sentence.append(chunk_result.sentence[0])
latency += chunk_result.latency
return GuardResponse(
prob=prob, verdict=verdict, sentence=sentence, latency=latency
)
def get_guardrail_handler(device: str = None):
"""
Initializes and returns an instance of ArchGuardHanlder based on the specified device.
Args:
device (str, optional): The device to use for model inference (e.g., "cpu" or "cuda"). Defaults to None.
Returns:
ArchGuardHanlder: An instance of ArchGuardHanlder configured for the specified device.
"""
if device is None:
device = utils.get_device()
model_class, model_name = None, None
# if device == "cpu":
# model_class = OVModelForSequenceClassification
# model_name = "katanemo/Arch-Guard-cpu"
# else:
model_class = AutoModelForSequenceClassification
model_name = "katanemo/Arch-Guard"
guardrail_dict = {
"device": device,
"model_name": model_name,
"tokenizer": AutoTokenizer.from_pretrained(model_name, trust_remote_code=True),
"model": model_class.from_pretrained(
model_name, device_map=device, low_cpu_mem_usage=True
),
}
return ArchGuardHanlder(model_dict=guardrail_dict)

View file

@ -1,15 +1,22 @@
import json
import math
import torch
import random
from typing import Any, Dict, List, Tuple
import itertools
from typing import Dict, List, Tuple
from enum import Enum
import string
from src.commons.utils import get_model_server_logger
logger = get_model_server_logger()
# constants
FUNC_NAME_START_PATTERN = ('<tool_call>\n{"name":"', "<tool_call>\n{'name':'")
FUNC_NAME_END_TOKEN = ('",', "',")
TOOL_CALL_TOKEN = "<tool_call>"
END_TOOL_CALL_TOKEN = "</tool_call>"
FIRST_PARAM_NAME_START_PATTERN = ('"arguments":{"', "'arguments':{'")
PARAMETER_NAME_END_TOKENS = ('":', ':"', "':", ":'")
@ -17,6 +24,8 @@ PARAMETER_NAME_START_PATTERN = (',"', ",'")
PARAMETER_VALUE_START_PATTERN = ('":', "':")
PARAMETER_VALUE_END_TOKEN = ('",', "}}\n", "',")
BRACKETS = {"(": ")", "{": "}", "[": "]"}
# Thresholds
class MaskToken(Enum):
@ -28,10 +37,15 @@ class MaskToken(Enum):
HALLUCINATION_THRESHOLD_DICT = {
MaskToken.TOOL_CALL.value: {"entropy": 0.1, "varentropy": 0.5},
MaskToken.TOOL_CALL.value: {
"entropy": 0.35,
"varentropy": 1.7,
"probability": 0.8,
},
MaskToken.PARAMETER_VALUE.value: {
"entropy": 0.5,
"varentropy": 2.5,
"entropy": 0.28,
"varentropy": 1.2,
"probability": 0.8,
},
}
@ -48,10 +62,10 @@ def check_threshold(entropy: float, varentropy: float, thd: Dict) -> bool:
Returns:
bool: True if either the entropy or varentropy exceeds their respective thresholds, False otherwise.
"""
return entropy > thd["entropy"] or varentropy > thd["varentropy"]
return entropy > thd["entropy"] and varentropy > thd["varentropy"]
def calculate_entropy(log_probs: List[float]) -> Tuple[float, float]:
def calculate_uncertainty(log_probs: List[float]) -> Tuple[float, float]:
"""
Calculate the entropy and variance of entropy (varentropy) from log probabilities.
@ -71,7 +85,26 @@ def calculate_entropy(log_probs: List[float]) -> Tuple[float, float]:
token_probs * (log_probs / math.log(2, math.e)) + entropy.unsqueeze(-1) ** 2,
dim=-1,
)
return entropy.item(), varentropy.item()
return entropy.item(), varentropy.item(), token_probs[0].item()
def is_parameter_required(
function_description: Dict,
parameter_name: str,
) -> bool:
"""
Check if a parameter in required list
Args:
function_description (dict): The API description in JSON format.
parameter_name (str): The name of the parameter to check.
Returns:
bool: True if the parameter has the specified property, False otherwise.
"""
required_parameters = function_description.get("required", {})
return parameter_name in required_parameters
def is_parameter_property(
@ -107,7 +140,6 @@ class HallucinationStateHandler:
hallucination (bool): Flag indicating if a hallucination is detected.
hallucination_message (str): Message describing the hallucination.
parameter_name (list): List of extracted parameter names.
function_description (dict): Description of functions and their parameters.
token_probs_map (list): List mapping tokens to their entropy and variance of entropy.
"""
@ -122,23 +154,34 @@ class HallucinationStateHandler:
self.parameter_name_done: bool = False
self.hallucination: bool = False
self.error_message: str = ""
self.error_type: str = ""
self.parameter_name: List[str] = []
self.token_probs_map: List[Tuple[str, float, float]] = []
self.response_iterator = response_iterator
self._process_function(function)
self.open_bracket = False
self.bracket = None
self.check_parameter_name = {}
self.HALLUCINATION_THRESHOLD_DICT = HALLUCINATION_THRESHOLD_DICT
def _process_function(self, function):
self.function = function
if self.function is None:
raise ValueError("API descriptions not set.")
parameter_names = {}
for func in self.function:
func_name = func["name"]
parameters = func["parameters"]["properties"]
parameter_names[func_name] = list(parameters.keys())
self.function_description = parameter_names
self.function_properties = {x["name"]: x["parameters"] for x in self.function}
self.function_properties = {
x["function"]["name"]: x["function"]["parameters"] for x in self.function
}
def _reset_parameters(self):
"""
Resets all parameters in the HallucinationStateHandler to their default values.
"""
self.state = None
self.parameter_name_done = False
self.hallucination = False
self.error_message = ""
self.open_bracket = False
self.bracket = None
self.check_parameter_name = {}
def append_and_check_token_hallucination(self, token, logprob):
"""
@ -175,9 +218,12 @@ class HallucinationStateHandler:
raise ValueError(
f"Error extracting logprobs from response: {e}"
)
self.append_and_check_token_hallucination(
token_content, logprobs
)
if token_content == END_TOOL_CALL_TOKEN:
self._reset_parameters()
else:
self.append_and_check_token_hallucination(
token_content, logprobs
)
return token_content
except StopIteration:
raise StopIteration
@ -199,7 +245,7 @@ class HallucinationStateHandler:
self.mask.append(MaskToken.FUNCTION_NAME)
else:
self.state = None
self._is_function_name_hallucinated()
self._get_function_name()
# Check if the token is a function name start token, change the state
if content.endswith(FUNC_NAME_START_PATTERN):
@ -217,11 +263,13 @@ class HallucinationStateHandler:
PARAMETER_NAME_END_TOKENS
):
self.state = None
self._is_parameter_name_hallucinated()
self.parameter_name_done = True
self._get_parameter_name()
# if the parameter name is done and the token is a parameter name start token, change the state
elif self.parameter_name_done and content.endswith(
PARAMETER_NAME_START_PATTERN
elif (
self.parameter_name_done
and self.open_bracket == False
and content.endswith(PARAMETER_NAME_START_PATTERN)
):
self.state = "parameter_name"
@ -235,24 +283,49 @@ class HallucinationStateHandler:
PARAMETER_VALUE_END_TOKEN
):
# checking if the token is a value token and is not empty
if self.tokens[-1].strip() not in ['"', ""]:
open_brackets = [
char for char in self.tokens[-1].strip() if char in BRACKETS
]
if open_brackets:
self.open_bracket = True
self.bracket = open_brackets[0]
if self.open_bracket and BRACKETS[self.bracket] in self.tokens[-1].strip():
self.open_bracket = False
self.bracket = None
if (
not all(
char in set(string.punctuation) for char in self.tokens[-1].strip()
)
and self.tokens[-1].strip() != ""
):
self.mask.append(MaskToken.PARAMETER_VALUE)
# checking if the parameter doesn't have default and the token is the first parameter value token
# checking if the parameter doesn't have enum and the token is the first parameter value token
if (
len(self.mask) > 1
and self.mask[-2] != MaskToken.PARAMETER_VALUE
and is_parameter_required(
self.function_properties[self.function_name],
self.parameter_name[-1],
)
and not is_parameter_property(
self.function_properties[self.function_name],
self.parameter_name[-1],
"default",
"enum",
)
):
self._check_logprob()
if self.parameter_name[-1] not in self.check_parameter_name:
self._check_logprob()
self.check_parameter_name[self.parameter_name[-1]] = True
else:
self.mask.append(MaskToken.NOT_USED)
# if the state is parameter value and the token is an end token, change the state
elif self.state == "parameter_value" and content.endswith(
PARAMETER_VALUE_END_TOKEN
elif (
self.state == "parameter_value"
and self.open_bracket == False
and content.endswith(PARAMETER_VALUE_END_TOKEN)
):
self.state = None
# if the parameter name is done and the token is a parameter value start token, change the state
@ -272,17 +345,16 @@ class HallucinationStateHandler:
Detects hallucinations based on entropy and variance of entropy.
"""
probs = self.logprobs[-1]
entropy, varentropy = calculate_entropy(probs)
self.token_probs_map.append((self.tokens[-1], entropy, varentropy))
entropy, varentropy, probability = calculate_uncertainty(probs)
self.token_probs_map.append((self.tokens[-1], entropy, varentropy, probability))
if check_threshold(
entropy, varentropy, HALLUCINATION_THRESHOLD_DICT[self.mask[-1].value]
entropy,
varentropy,
self.HALLUCINATION_THRESHOLD_DICT[self.mask[-1].value],
):
self.hallucination = True
self.error_type = "Hallucination"
self.error_message = (
f"Hallucination: token '{self.tokens[-1]}' is uncertain."
)
self.error_message = f"Hallucination: token '{self.tokens[-1]}' is uncertain. {self.token_probs_map}"
def _count_consecutive_token(self, token=MaskToken.PARAMETER_VALUE) -> int:
"""
@ -300,25 +372,23 @@ class HallucinationStateHandler:
else 0
)
def _is_function_name_hallucinated(self):
def _get_parameter_name(self):
"""
Checks the extracted function name against the function descriptions.
Detects hallucinations if the function name is not found.
"""
f_len = self._count_consecutive_token(MaskToken.FUNCTION_NAME)
self.function_name = "".join(self.tokens[:-1][-f_len:])
if self.function_name not in self.function_description.keys():
self.error_type = "function_name"
self.error_message = f"Function name '{self.function_name}' not found in given function descriptions."
Get the parameter name from the tokens.
def _is_parameter_name_hallucinated(self):
"""
Checks the extracted parameter name against the function descriptions.
Detects hallucinations if the parameter name is not found.
Returns:
str: The extracted parameter name.
"""
p_len = self._count_consecutive_token(MaskToken.PARAMETER_NAME)
parameter_name = "".join(self.tokens[:-1][-p_len:])
self.parameter_name.append(parameter_name)
if parameter_name not in self.function_description[self.function_name]:
self.error_type = "parameter_name"
self.error_message = f"Parameter name '{parameter_name}' not found in given function descriptions."
def _get_function_name(self):
"""
Get the function name from the tokens.
Returns:
str: The extracted function name.
"""
f_len = self._count_consecutive_token(MaskToken.FUNCTION_NAME)
self.function_name = "".join(self.tokens[:-1][-f_len:])

View file

@ -0,0 +1,181 @@
import json
from openai import OpenAI
from pydantic import BaseModel
from typing import Any, Dict, List, Optional
from overrides import final
class Message(BaseModel):
role: Optional[str] = ""
content: Optional[str] = ""
tool_call_id: Optional[str] = ""
tool_calls: Optional[List[Dict[str, Any]]] = []
class ChatMessage(BaseModel):
messages: list[Message]
tools: List[Dict[str, Any]]
class Choice(BaseModel):
id: Optional[int] = 0
message: Message
finish_reason: Optional[str] = "stop"
class ChatCompletionResponse(BaseModel):
id: Optional[int] = 0
object: Optional[str] = "chat_completion"
created: Optional[str] = ""
choices: List[Choice]
model: str
metadata: Optional[Dict[str, str]] = {}
class GuardRequest(BaseModel):
input: str
task: str
class GuardResponse(BaseModel):
prob: List
verdict: bool
sentence: List
latency: float = 0
# ================================================================================================
class ArchBaseHandler:
def __init__(
self,
client: OpenAI,
model_name: str,
task_prompt: str,
tool_prompt_template: str,
format_prompt: str,
generation_params: Dict,
):
"""
Initializes the base handler.
Args:
client (OpenAI): An OpenAI client instance.
model_name (str): Name of the model to use.
task_prompt (str): The main task prompt for the system.
tool_prompt (str): A prompt to describe tools.
format_prompt (str): A prompt specifying the desired output format.
generation_params (Dict): Generation parameters for the model.
"""
self.client = client
self.model_name = model_name
self.task_prompt = task_prompt
self.tool_prompt_template = tool_prompt_template
self.format_prompt = format_prompt
self.generation_params = generation_params
def _convert_tools(self, tools: List[Dict[str, Any]]) -> str:
"""
Converts a list of tools into the desired internal representation.
Args:
tools (List[Dict[str, Any]]): A list of tools represented as dictionaries.
Raises:
NotImplementedError: Method should be overridden in subclasses.
"""
raise NotImplementedError()
@final
def _format_system_prompt(self, tools: List[Dict[str, Any]]) -> str:
"""
Formats the system prompt using provided tools.
Args:
tools (List[Dict[str, Any]]): A list of tools represented as dictionaries.
Returns:
str: A formatted system prompt.
"""
tool_text = self._convert_tools(tools)
system_prompt = (
self.task_prompt
+ "\n\n"
+ self.tool_prompt_template.format(tool_text=tool_text)
+ "\n\n"
+ self.format_prompt
)
return system_prompt
@final
def _process_messages(
self,
messages: List[Message],
tools: List[Dict[str, Any]] = None,
extra_instruction: str = None,
):
"""
Processes a list of messages and formats them appropriately.
Args:
messages (List[Message]): A list of message objects.
tools (List[Dict[str, Any]], optional): A list of tools to include in the system prompt.
extra_instruction (str, optional): Additional instructions to append to the last user message.
Returns:
List[Dict[str, Any]]: A list of processed message dictionaries.
"""
processed_messages = []
if tools:
processed_messages.append(
{"role": "system", "content": self._format_system_prompt(tools)}
)
for message in messages:
role, content, tool_calls = (
message.role,
message.content,
message.tool_calls,
)
if tool_calls:
# [TODO] Extend to support multiple function calls
role = "assistant"
content = f"<tool_call>\n{json.dumps(tool_calls[0]['function'])}\n</tool_call>"
elif message.role == "tool":
role = "user"
content = (
f"<tool_response>\n{json.dumps(message.content)}\n</tool_response>"
)
processed_messages.append({"role": role, "content": content})
assert processed_messages[-1]["role"] == "user"
if extra_instruction:
processed_messages[-1]["content"] += extra_instruction
return processed_messages
async def chat_completion(self, req: ChatMessage) -> ChatCompletionResponse:
"""
Abstract method for generating chat completions.
Args:
req (ChatMessage): A chat message request object.
Raises:
NotImplementedError: Method should be overridden in subclasses.
"""
raise NotImplementedError()

134
model_server/src/main.py Normal file
View file

@ -0,0 +1,134 @@
import json
import logging
import os
import time
from src.commons.globals import handler_map
from src.core.model_utils import ChatMessage, GuardRequest
from fastapi import FastAPI, Response
from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.sdk.resources import Resource
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
resource = Resource.create(
{
"service.name": "model-server",
}
)
# Initialize the tracer provider
trace.set_tracer_provider(TracerProvider(resource=resource))
tracer = trace.get_tracer(__name__)
app = FastAPI()
FastAPIInstrumentor().instrument_app(app)
# DEFAULT_OTLP_HOST = "http://localhost:4317"
DEFAULT_OTLP_HOST = "none"
# Configure the OTLP exporter (Jaeger, Zipkin, etc.)
otlp_exporter = OTLPSpanExporter(
endpoint=os.getenv("OTLP_HOST", DEFAULT_OTLP_HOST) # noqa: F821
)
trace.get_tracer_provider().add_span_processor(BatchSpanProcessor(otlp_exporter))
@app.get("/healthz")
async def healthz():
return {"status": "ok"}
@app.get("/models")
async def models():
return {
"object": "list",
"data": [{"id": model_name, "object": "model"} for model_name in handler_map],
}
@app.post("/function_calling")
async def function_calling(req: ChatMessage, res: Response):
try:
intent_start_time = time.perf_counter()
intent_response = await handler_map["Arch-Intent"].chat_completion(req)
intent_latency = time.perf_counter() - intent_start_time
if handler_map["Arch-Intent"].detect_intent(intent_response):
# [TODO] measure agreement between intent detection and function calling
try:
function_start_time = time.perf_counter()
function_calling_response = await handler_map[
"Arch-Function"
].chat_completion(req)
function_latency = time.perf_counter() - function_start_time
function_calling_response.metadata = {
"intent_latency": str(round(intent_latency * 1000, 3)),
"function_latency": str(round(function_latency * 1000, 3)),
"hallucination": str(handler_map["Arch-Function"].hallucination),
"tokens_uncertainty": json.dumps(
handler_map["Arch-Function"].hallu_handler.token_probs_map
),
"prompt_prefilling": str(
handler_map["Arch-Function"].prompt_prefilling
),
}
return function_calling_response
except ValueError as e:
res.statuscode = 503
error_message = "Tool call extraction error"
logger.error(f" {error_message}: {e}")
return {"error": f"[Arch-Function] - {error_message} - {e}"}
except StopIteration as e:
res.statuscode = 500
error_message = "Hallucination iterator error"
logger.error(f" {error_message}: {e}")
return {"error": f"[Arch-Function] - {error_message} - {e}"}
except Exception as e:
# [TODO] Review: update how to collect debugging outputs
logger.error(f"Error in chat_completion from `Arch-Function`: {e}")
res.status_code = 500
return {"error": f"[Arch-Function] - {e}"}
# [TODO] Review: define the behavior if `Arch-Intent` doesn't detect an intent
else:
return {
"result": "No intent matched",
"intent_latency": round(intent_latency * 1000, 3),
}
except Exception as e:
# [TODO] Review: update how to collect debugging outputs
# logger.error(f"Error in chat_completion from `Arch-Intent`: {e}")
logger.error(f"Error in chat_completion /function_calling: {e}")
res.status_code = 500
return {"error": f"[Arch-Intent] - {e}"}
@app.post("/guardrails")
async def guardrails(req: GuardRequest, res: Response, max_num_words=300):
try:
guard_start_time = time.perf_counter()
guard_result = handler_map["Arch-Guard"].predict(req)
guard_latency = time.perf_counter() - guard_start_time
return {
"response": guard_result,
"guard_latency": round(guard_latency * 1000, 3),
}
except Exception as e:
# [TODO] Review: update how to collect debugging outputs
res.status_code = 500
return {"error": f"[Arch-Guard] - {e}"}

View file

View file

@ -0,0 +1,173 @@
import os
from src.commons.globals import handler_map
from src.core.model_utils import ChatMessage, Message
import pytest
from fastapi.testclient import TestClient
from unittest.mock import AsyncMock, patch
from src.main import app
from src.commons.globals import handler_map
# define function
get_weather_api = {
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get current weather at a location.",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "str",
"description": "The location to get the weather for",
"format": "City, State",
},
"unit": {
"type": "str",
"description": "The unit to return the weather in.",
"enum": ["celsius", "fahrenheit"],
"default": "celsius",
},
"days": {
"type": "str",
"description": "the number of days for the request.",
},
},
"required": ["location", "days"],
},
},
}
# get_data class return request, intent, hallucination, parameter_gathering
def get_hallucination_data_complex():
# Create instances of the Message class
message1 = Message(role="user", content="How is the weather in Seattle?")
message2 = Message(
role="assistant", content="Can you specify the unit you want the weather in?"
)
message3 = Message(role="user", content="In celcius please!")
# Create a list of tools
tools = [get_weather_api]
# Create an instance of the ChatMessage class
req = ChatMessage(messages=[message1, message2, message3], tools=tools)
return req, True, True, True
def get_hallucination_data_easy():
# Create instances of the Message class
message1 = Message(role="user", content="How is the weather in Seattle?")
# Create a list of tools
tools = [get_weather_api]
# Create an instance of the ChatMessage class
req = ChatMessage(messages=[message1], tools=tools)
# model will hallucinate
return req, True, True, True
def get_hallucination_data_medium():
# Create instances of the Message class
message1 = Message(role="user", content="How is the weather in?")
# Create a list of tools
tools = [get_weather_api]
# Create an instance of the ChatMessage class
req = ChatMessage(messages=[message1], tools=tools)
# first token will not be tool call
return req, True, True, True
def get_complete_data_2():
# Create instances of the Message class
message1 = Message(
role="user",
content="what is the weather forecast for seattle in the next 10 days?",
)
# Create a list of tools
tools = [get_weather_api]
# Create an instance of the ChatMessage class
req = ChatMessage(messages=[message1], tools=tools)
return req, True, False, False
def get_complete_data():
# Create instances of the Message class
message1 = Message(role="user", content="How is the weather in Seattle in 7 days?")
# Create a list of tools
tools = [get_weather_api]
# Create an instance of the ChatMessage class
req = ChatMessage(messages=[message1], tools=tools)
return req, True, False, False
def get_irrelevant_data():
# Create instances of the Message class
message1 = Message(role="user", content="What is 1+1?")
# Create a list of tools
tools = [get_weather_api]
# Create an instance of the ChatMessage class
req = ChatMessage(messages=[message1], tools=tools)
return req, False, False, False
def get_greeting_data():
# Create instances of the Message class
message1 = Message(role="user", content="Hello how are you?")
# Create a list of tools
tools = [get_weather_api]
# Create an instance of the ChatMessage class
req = ChatMessage(messages=[message1], tools=tools)
return req, False, False, False
@pytest.mark.asyncio
@pytest.mark.parametrize(
"get_data_func",
[
get_hallucination_data_complex,
get_hallucination_data_easy,
get_complete_data,
get_irrelevant_data,
get_complete_data_2,
],
)
async def test_function_calling(get_data_func):
req, intent, hallucination, parameter_gathering = get_data_func()
intent_response = await handler_map["Arch-Intent"].chat_completion(req)
assert handler_map["Arch-Intent"].detect_intent(intent_response) == intent
if intent:
function_calling_response = await handler_map["Arch-Function"].chat_completion(
req
)
assert handler_map["Arch-Function"].hallu_handler.hallucination == hallucination
response_txt = function_calling_response.choices[0].message.content
if parameter_gathering:
prefill_prefix = handler_map["Arch-Function"].prefill_prefix
assert any(
response_txt.startswith(prefix) for prefix in prefill_prefix
), f"Response '{response_txt}' does not start with any of the prefixes: {prefill_prefix}"

View file

@ -0,0 +1,69 @@
from unittest.mock import patch, MagicMock
from src.core.guardrails import get_guardrail_handler
# Mock constants
arch_guard_model_type = {
"cpu": "katanemo/Arch-Guard-cpu",
"cuda": "katanemo/Arch-Guard",
"mps": "katanemo/Arch-Guard",
}
# [TODO] Review: check the following code to test under `cpu`, `cuda`, and `mps`
# Test for `get_guardrail_handler()` function on `cpu`
@patch("src.core.guardrails.AutoTokenizer.from_pretrained")
@patch("src.core.guardrails.AutoModelForSequenceClassification.from_pretrained")
def test_guardrail_handler_on_cpu(mock_auto_model, mock_tokenizer):
device = "cpu"
mock_tokenizer.return_value = MagicMock()
guardrail = get_guardrail_handler(device=device)
mock_tokenizer.assert_called_once_with(guardrail.model_name, trust_remote_code=True)
mock_auto_model.assert_called_once_with(
guardrail.model_name,
device_map=device,
low_cpu_mem_usage=True,
)
# Test for `get_guardrail_handler()` function on `cuda`
@patch("src.core.guardrails.AutoTokenizer.from_pretrained")
@patch("src.core.guardrails.AutoModelForSequenceClassification.from_pretrained")
def test_guardrail_handler_on_cuda(mock_auto_model, mock_tokenizer):
device = "cuda"
mock_auto_model.return_value = MagicMock()
mock_tokenizer.return_value = MagicMock()
guardrail = get_guardrail_handler(device=device)
mock_tokenizer.assert_called_once_with(guardrail.model_name, trust_remote_code=True)
mock_auto_model.assert_called_once_with(
guardrail.model_name,
device_map=device,
low_cpu_mem_usage=True,
)
# Test for `get_guardrail_handler()` function on `mps`
@patch("src.core.guardrails.AutoTokenizer.from_pretrained")
@patch("src.core.guardrails.AutoModelForSequenceClassification.from_pretrained")
def test_guardrail_handler_on_mps(mock_auto_model, mock_tokenizer):
device = "mps"
mock_auto_model.return_value = MagicMock()
mock_tokenizer.return_value = MagicMock()
guardrail = get_guardrail_handler(device=device)
mock_tokenizer.assert_called_once_with(guardrail.model_name, trust_remote_code=True)
mock_auto_model.assert_called_once_with(
guardrail.model_name,
device_map=device,
low_cpu_mem_usage=True,
)

View file

@ -0,0 +1,50 @@
from src.commons.globals import handler_map
from src.core.function_calling import Message
test_input_history = [
{"role": "user", "content": "how is the weather in chicago for next 5 days?"},
{
"role": "assistant",
"model": "Arch-Function",
"tool_calls": [
{
"id": "call_3394",
"type": "function",
"function": {
"name": "weather_forecast",
"arguments": {"city": "Chicago", "days": 5},
},
}
],
},
{"role": "tool", "content": "--", "tool_call_id": "call_3394"},
{"role": "assistant", "content": "--", "model": "gpt-3.5-turbo-0125"},
{"role": "user", "content": "how is the weather in chicago for next 5 days?"},
{
"role": "assistant",
"tool_calls": [
{
"id": "call_5306",
"type": "function",
"function": {
"name": "weather_forecast",
"arguments": {"city": "Chicago", "days": 5},
},
}
],
},
{"role": "tool", "content": "--", "tool_call_id": "call_5306"},
]
def test_update_fc_history():
message_history = []
for h in test_input_history:
message_history.append(Message(**h))
updated_history = handler_map["Arch-Function"]._process_messages(message_history)
assert len(updated_history) == 7
# ensure that tool role does not exist anymore
assert all([h["role"] != "tool" for h in updated_history])

View file

@ -0,0 +1,53 @@
import pytest
import httpx
from fastapi.testclient import TestClient
from src.main import app
client = TestClient(app)
# [TODO] Review: check the following code. Seems something wrong with asyncio package❗
# Unit tests for the health check endpoint
@pytest.mark.asyncio
async def test_healthz():
response = client.get("/healthz")
assert response.status_code == 200
assert response.json() == {"status": "ok"}
# [TODO] Review: check the following code. Seems something wrong with asyncio package❗
# Unit test for the models endpoint
@pytest.mark.asyncio
async def test_models():
response = client.get("/models")
assert response.status_code == 200
assert response.json()["object"] == "list"
assert len(response.json()["data"]) > 0
# [TODO] Review: check the following code. Seems something wrong with asyncio package❗
# Unit test for the guardrail endpoint
@pytest.mark.asyncio
async def test_guardrail_endpoint():
request_data = {"input": "Test for jailbreak and toxicity", "task": "jailbreak"}
response = client.post("/guardrails", json=request_data)
assert response.status_code == 200
assert "response" in response.json()
# [TODO] Review: check the following code. Seems something wrong with asyncio package❗
# Unit test for the function calling endpoint
@pytest.mark.asyncio
async def test_function_calling_endpoint():
async with httpx.AsyncClient(app=app, base_url="http://test") as client:
request_data = {
"messages": [{"role": "user", "content": "Hello!"}],
"model": "Arch-Function",
"tools": [],
"metadata": {"x-arch-state": "[]"},
}
response = await client.post("/function_calling", json=request_data)
assert response.status_code == 200
assert "result" in response.json()