Init update on model_server

This commit is contained in:
Shuguang Chen 2024-12-04 16:41:30 -08:00
parent 1d9de28086
commit afe1410b37
25 changed files with 1758 additions and 1922 deletions

View file

@ -1,10 +1,7 @@
import importlib
import sys
import os
import time
import requests
import psutil
import tempfile
import subprocess
import logging

View file

@ -1,38 +1,83 @@
import app.commons.globals as glb
import app.commons.utilities as utils
import app.loader as loader
# ========================== Arch-Intent Default Params ==========================
ARCH_INTENT_MODEL_ALIAS = "Arch-Intent"
ARCH_INTENT_INSTRUCTION = "Are there any tools can help?"
from app.function_calling.model_handler import ArchFunctionHandler
from app.prompt_guard.model_handler import ArchGuardHanlder
ARCH_INTENT_TASK_PROMPT = """
You are a helpful assistant.
""".strip()
logger = utils.get_model_server_logger()
arch_function_hanlder = ArchFunctionHandler()
PREFILL_LIST = ["May", "Could", "Sure", "Definitely", "Certainly", "Of course", "Can"]
PREFILL_ENABLED = True
TOOL_CALL_TOKEN = "<tool_call>"
arch_function_endpoint = "https://api.fc.archgw.com/v1"
arch_function_client = utils.get_client(arch_function_endpoint)
arch_function_generation_params = {
"temperature": 0.2,
"top_p": 1.0,
"top_k": 50,
"max_tokens": 512,
"stop_token_ids": [151645],
# "top_logprobs": 10,
ARCH_INTENT_TOOL_PROMPT = """
You task is to check if there are any tools that can be used to help the last user message in conversations according to the available tools listed below.
<tools>
{tool_text}
</tools>
""".strip()
ARCH_INTENT_FORMAT_PROMPT = """
Provide your tool assessment for ONLY THE LAST USER MESSAGE in the above conversation:
- First line must read 'Yes' or 'No'.
- If yes, a second line must include a comma-separated list of tool indexes.
""".strip()
ARCH_INTENT_GENERATION_CONFIG = {
"generation_params": {
"stop_token_ids": [151645],
"max_tokens": 1,
"guided_choice": ["Yes", "No"],
}
}
arch_guard_model_type = {
"cpu": "katanemo/Arch-Guard-cpu",
"cuda": "katanemo/Arch-Guard",
"mps": "katanemo/Arch-Guard",
# ========================== Arch-Function Default Params ==========================
ARCH_FUNCTION_MODEL_ALIAS = "Arch-Function"
ARCH_FUNCTION_TASK_PROMPT = """
You are a helpful assistant.
""".strip()
ARCH_FUNCTION_TOOL_PROMPT = """
# Tools
You may call one or more functions to assist with the user query.
You are provided with function signatures within <tools></tools> XML tags:
<tools>
{tool_text}
</tools>
""".strip()
ARCH_FUNCTION_FORMAT_PROMPT = """
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
<tool_call>
{"name": <function-name>, "arguments": <args-json-object>}
</tool_call>
""".strip()
ARCH_FUNCTION_GENERATION_CONFIG = {
"generation_params": {
"temperature": 0.2,
"top_p": 1.0,
"top_k": 50,
"max_tokens": 512,
"stop_token_ids": [151645],
},
"prefill_params": {
"continue_final_message": True,
"add_generation_prompt": False,
},
"prefill_prefix": [
"May",
"Could",
"Sure",
"Definitely",
"Certainly",
"Of course",
"Can",
],
}
# Model definition
embedding_model = loader.get_embedding_model()
zero_shot_model = loader.get_zero_shot_model()
prompt_guard_dict = loader.get_prompt_guard(arch_guard_model_type[glb.DEVICE])
arch_guard_handler = ArchGuardHanlder(model_dict=prompt_guard_dict)
# Patterns for function name and parameter parsing

View file

@ -1,4 +1,65 @@
import app.commons.utilities as utils
from app.commons.constants import *
from app.model_handler.function_calling import ArchIntentHandler, ArchFunctionHandler
from app.model_handler.guardrails import ArchGuardHanlder
DEVICE = utils.get_device()
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from optimum.intel import OVModelForSequenceClassification
from openai import OpenAI
logger = utils.get_model_server_logger()
def get_guardrail_handler():
device = utils.get_device()
model_class, model_name = None, None
if device == "cpu":
model_class = OVModelForSequenceClassification
model_name = "katanemo/Arch-Guard-cpu"
else:
model_class = AutoModelForSequenceClassification
if device == "cuda":
model_name = "katanemo/Arch-Guard"
else:
model_name = "katanemo/Arch-Guard"
guardrail_dict = {
"device": device,
"model_name": model_name,
"tokenizer": AutoTokenizer.from_pretrained(model_name, trust_remote_code=True),
"model": model_class.from_pretrained(
model_name, device_map=device, low_cpu_mem_usage=True
),
}
return ArchGuardHanlder(model_dict=guardrail_dict)
# Define the client
ARCH_CLIENT = OpenAI(base_url="https://api.fc.archgw.com/v1", api_key="EMPTY")
# Define model handlers
handler_map = {
"Arch-Intent": ArchIntentHandler(
ARCH_CLIENT,
ARCH_INTENT_MODEL_ALIAS,
ARCH_INTENT_TASK_PROMPT,
ARCH_INTENT_TOOL_PROMPT,
ARCH_INTENT_FORMAT_PROMPT,
ARCH_INTENT_INSTRUCTION,
**ARCH_INTENT_GENERATION_CONFIG,
),
"Arch-Function": ArchFunctionHandler(
ARCH_CLIENT,
ARCH_FUNCTION_MODEL_ALIAS,
ARCH_FUNCTION_TASK_PROMPT,
ARCH_FUNCTION_TOOL_PROMPT,
ARCH_FUNCTION_FORMAT_PROMPT,
**ARCH_FUNCTION_GENERATION_CONFIG,
),
"Arch-Guard": get_guardrail_handler(),
}

View file

@ -1,11 +1,7 @@
import os
import yaml
import torch
import string
import logging
from openai import OpenAI
logger_instance = None
@ -31,11 +27,6 @@ def get_device():
return device
def get_client(endpoint):
client = OpenAI(base_url=endpoint, api_key="EMPTY")
return client
def get_model_server_logger():
global logger_instance
@ -72,12 +63,3 @@ def get_model_server_logger():
# Initialize the logger instance after configuring handlers
logger_instance = logging.getLogger("model_server_logger")
return logger_instance
def remove_punctuations(s):
s = s.translate(str.maketrans(string.punctuation, " " * len(string.punctuation)))
return " ".join(s.split()).lower()
def get_label_map(labels):
return {remove_punctuations(label): label for label in labels}

View file

@ -1,137 +0,0 @@
import json
import random
from typing import Any, Dict, List
ARCH_FUNCTION_CALLING_TASK_PROMPT = """
You are a helpful assistant.
""".strip()
ARCH_FUNCTION_CALLING_TOOL_PROMPT = """
# Tools
You may call one or more functions to assist with the user query.
You are provided with function signatures within <tools></tools> XML tags:
<tools>
{tool_text}
</tools>
""".strip()
ARCH_FUNCTION_CALLING_FORMAT_PROMPT = """
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
<tool_call>
{"name": <function-name>, "arguments": <args-json-object>}
</tool_call>
""".strip()
class ArchFunctionHandler:
def __init__(self) -> None:
super().__init__()
def _format_system(self, tools: List[Dict[str, Any]]):
def convert_tools(tools):
return "\n".join([json.dumps(tool) for tool in tools])
tool_text = convert_tools(tools)
system_prompt = (
ARCH_FUNCTION_CALLING_TASK_PROMPT
+ "\n\n"
+ ARCH_FUNCTION_CALLING_TOOL_PROMPT.format(tool_text=tool_text)
+ "\n\n"
+ ARCH_FUNCTION_CALLING_FORMAT_PROMPT
)
return system_prompt
def _add_execution_results_prompting(
self,
messages: list[dict],
execution_results: list,
) -> dict:
content = []
for result in execution_results:
content.append(f"<tool_response>\n{json.dumps(result)}\n</tool_response>")
content = "\n".join(content)
messages.append({"role": "user", "content": content})
return messages
def extract_tool_calls(self, content: str):
tool_calls = []
flag = False
for line in content.split("\n"):
if "<tool_call>" == line:
flag = True
elif "</tool_call>" == line:
flag = False
else:
if flag:
try:
tool_content = json.loads(line)
except Exception:
fixed_content = self.fix_json_string(line)
try:
tool_content = json.loads(fixed_content)
except json.JSONDecodeError:
return content
tool_calls.append(
{
"id": f"call_{random.randint(1000, 10000)}",
"type": "function",
"function": {
"name": tool_content["name"],
"arguments": tool_content["arguments"],
},
}
)
flag = False
return tool_calls
def fix_json_string(self, json_str: str):
# Remove any leading or trailing whitespace or newline characters
json_str = json_str.strip()
# Stack to keep track of brackets
stack = []
# Clean string to collect valid characters
fixed_str = ""
# Dictionary for matching brackets
matching_bracket = {")": "(", "}": "{", "]": "["}
# Dictionary for the opposite of matching_bracket
opening_bracket = {v: k for k, v in matching_bracket.items()}
for char in json_str:
if char in "{[(":
stack.append(char)
fixed_str += char
elif char in "}])":
if stack and stack[-1] == matching_bracket[char]:
stack.pop()
fixed_str += char
else:
# Ignore the unmatched closing brackets
continue
else:
fixed_str += char
# If there are unmatched opening brackets left in the stack, add corresponding closing brackets
while stack:
unmatched_opening = stack.pop()
fixed_str += opening_bracket[unmatched_opening]
# Attempt to parse the corrected string to ensure its valid JSON
return fixed_str.replace("'", '"')

View file

@ -1,157 +0,0 @@
import json
import hashlib
import app.commons.constants as const
import random
from fastapi import Response
from pydantic import BaseModel
from app.commons.utilities import get_model_server_logger
from typing import Any, Dict, List, Optional
logger = get_model_server_logger()
class Message(BaseModel):
role: Optional[str] = ""
content: Optional[str] = ""
tool_calls: Optional[List[Dict[str, Any]]] = []
tool_call_id: Optional[str] = ""
class ChatMessage(BaseModel):
messages: list[Message]
tools: List[Dict[str, Any]]
class Choice(BaseModel):
message: Message
finish_reason: Optional[str] = "stop"
index: Optional[int] = 0
class ChatCompletionResponse(BaseModel):
choices: List[Choice]
model: Optional[str] = "Arch-Function"
created: Optional[str] = ""
id: Optional[str] = ""
object: Optional[str] = "chat_completion"
def process_messages(history: list[Message]):
updated_history = []
for hist in history:
if hist.tool_calls:
if len(hist.tool_calls) > 1:
error_msg = f"Only one tool call is supported, tools counts: {len(hist.tool_calls)}"
logger.error(error_msg)
raise ValueError(error_msg)
tool_call_str = json.dumps(hist.tool_calls[0]["function"])
updated_history.append(
{
"role": "assistant",
"content": f"<tool_call>\n{tool_call_str}\n</tool_call>",
}
)
elif hist.role == "tool":
updated_history.append(
{
"role": "user",
"content": f"<tool_response>\n{hist.content}\n</tool_response>",
}
)
else:
updated_history.append({"role": hist.role, "content": hist.content})
return updated_history
async def chat_completion(req: ChatMessage, res: Response):
logger.info("starting request")
tools_encoded = const.arch_function_hanlder._format_system(req.tools)
messages = [{"role": "system", "content": tools_encoded}]
updated_history = process_messages(req.messages)
for message in updated_history:
messages.append({"role": message["role"], "content": message["content"]})
client_model_name = const.arch_function_client.models.list().data[0].id
logger.info(
f"model_server => arch_function: {client_model_name}, messages: {json.dumps(messages)}"
)
# Retrieve the first token, handling the Stream object carefully
try:
resp = const.arch_function_client.chat.completions.create(
messages=messages,
model=client_model_name,
stream=const.PREFILL_ENABLED,
extra_body=const.arch_function_generation_params,
)
except Exception as e:
logger.error(f"model_server <= arch_function: error: {e}")
raise
if const.PREFILL_ENABLED:
first_token_content = ""
for token in resp:
first_token_content = token.choices[
0
].delta.content.strip() # Clean up the content
if first_token_content: # Break if it's non-empty
break
# Check if the first token requires tool call handling
if first_token_content != const.TOOL_CALL_TOKEN:
# Engage pre-filling response if no tool call is indicated
resp.close()
logger.info("Tool call is not found! Engage pre filling")
prefill_content = random.choice(const.PREFILL_LIST)
messages.append({"role": "assistant", "content": prefill_content})
# Send a new completion request with the updated messages
# the model will continue the final message in the chat instead of starting a new one
# disable add_generation_prompt which tells the template to add tokens that indicate the start of a bot response.
extra_body = {
**const.arch_function_generation_params,
"continue_final_message": True,
"add_generation_prompt": False,
}
pre_fill_resp = const.arch_function_client.chat.completions.create(
messages=messages,
model=client_model_name,
stream=False,
extra_body=extra_body,
)
full_response = pre_fill_resp.choices[0].message.content
else:
# Initialize full response and iterate over tokens to gather the full response
full_response = first_token_content
for token in resp:
if hasattr(token.choices[0].delta, "content"):
full_response += token.choices[0].delta.content
else:
logger.info("Stream is disabled, not engaging pre-filling")
full_response = resp.choices[0].message.content
tool_calls = const.arch_function_hanlder.extract_tool_calls(full_response)
if tool_calls:
message = Message(content="", tool_calls=tool_calls)
else:
message = Message(content=full_response, tool_calls=[])
choice = Choice(message=message)
chat_completion_response = ChatCompletionResponse(
choices=[choice], model=client_model_name
)
logger.info(
f"model_server <= arch_function: (tools): {json.dumps([tool_call['function'] for tool_call in tool_calls])}"
)
logger.info(
f"model_server <= arch_function: response body: {json.dumps(chat_completion_response.dict())}"
)
return chat_completion_response

View file

@ -1,84 +0,0 @@
import os
import app.commons.globals as glb
from transformers import AutoTokenizer, AutoModel, pipeline
from optimum.onnxruntime import (
ORTModelForFeatureExtraction,
ORTModelForSequenceClassification,
)
import app.commons.utilities as utils
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from optimum.intel import OVModelForSequenceClassification
logger = utils.get_model_server_logger()
def get_embedding_model(
model_name=os.getenv("MODELS", "katanemo/bge-large-en-v1.5"),
):
logger.info("Loading Embedding Model...")
if glb.DEVICE != "cuda":
model = ORTModelForFeatureExtraction.from_pretrained(
model_name, file_name="onnx/model.onnx"
)
else:
model = AutoModel.from_pretrained(model_name, device_map=glb.DEVICE)
embedding_model = {
"model_name": model_name,
"tokenizer": AutoTokenizer.from_pretrained(model_name, trust_remote_code=True),
"model": model,
}
return embedding_model
def get_zero_shot_model(
model_name=os.getenv("ZERO_SHOT_MODELS", "katanemo/bart-large-mnli"),
):
logger.info("Loading Zero-shot Model...")
if glb.DEVICE != "cuda":
model = ORTModelForSequenceClassification.from_pretrained(
model_name, file_name="onnx/model.onnx"
)
else:
model = model_name
zero_shot_model = {
"model_name": model_name,
"tokenizer": AutoTokenizer.from_pretrained(model_name),
"model": model,
}
zero_shot_model["pipeline"] = pipeline(
"zero-shot-classification",
model=zero_shot_model["model"],
tokenizer=zero_shot_model["tokenizer"],
device=glb.DEVICE,
)
return zero_shot_model
def get_prompt_guard(model_name):
logger.info("Loading Guard Model...")
if glb.DEVICE == "cpu":
model_class = OVModelForSequenceClassification
else:
model_class = AutoModelForSequenceClassification
prompt_guard = {
"device": glb.DEVICE,
"model_name": model_name,
"tokenizer": AutoTokenizer.from_pretrained(model_name, trust_remote_code=True),
"model": model_class.from_pretrained(
model_name, device_map=glb.DEVICE, low_cpu_mem_usage=True
),
}
return prompt_guard

View file

@ -1,21 +1,10 @@
import os
import time
import torch
import app.commons.utilities as utils
import app.commons.globals as glb
import app.prompt_guard.model_utils as guard_utils
from typing import List, Dict
from pydantic import BaseModel
from fastapi import FastAPI, Response, HTTPException, Request
from app.function_calling.model_utils import ChatMessage
from app.commons.constants import embedding_model, zero_shot_model, arch_guard_handler
from app.function_calling.model_utils import (
chat_completion as arch_function_chat_completion,
)
from unittest.mock import patch
from app.commons.globals import handler_map
from app.model_handler.function_calling import ChatMessage
from app.model_handler.guardrails import GuardRequest
from fastapi import FastAPI, Response, Request
from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
@ -23,6 +12,7 @@ from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExport
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.sdk.resources import Resource
resource = Resource.create(
{
"service.name": "model-server",
@ -34,10 +24,6 @@ trace.set_tracer_provider(TracerProvider(resource=resource))
tracer = trace.get_tracer(__name__)
logger = utils.get_model_server_logger()
logger.info(f"Ready to serve traffic. available device: {glb.DEVICE}")
app = FastAPI()
FastAPIInstrumentor().instrument_app(app)
@ -53,28 +39,6 @@ otlp_exporter = OTLPSpanExporter(
trace.get_tracer_provider().add_span_processor(BatchSpanProcessor(otlp_exporter))
class EmbeddingRequest(BaseModel):
input: str
model: str
class GuardRequest(BaseModel):
input: str
task: str
class ZeroShotRequest(BaseModel):
input: str
labels: List[str]
model: str
class HallucinationRequest(BaseModel):
prompt: str
parameters: Dict
model: str
@app.get("/healthz")
async def healthz():
return {"status": "ok"}
@ -84,172 +48,40 @@ async def healthz():
async def models():
return {
"object": "list",
"data": [{"id": embedding_model["model_name"], "object": "model"}],
"data": [{"id": model_name, "object": "model"} for model_name in handler_map],
}
@app.post("/embeddings")
async def embedding(req: EmbeddingRequest, res: Response):
logger.info(f"Embedding req: {req}")
if req.model != embedding_model["model_name"]:
raise HTTPException(status_code=400, detail="unknown model: " + req.model)
start_time = time.perf_counter()
encoded_input = embedding_model["tokenizer"](
req.input, padding=True, truncation=True, return_tensors="pt"
).to(glb.DEVICE)
with torch.no_grad():
embeddings = embedding_model["model"](**encoded_input)
embeddings = embeddings[0][:, 0]
embeddings = (
torch.nn.functional.normalize(embeddings, p=2, dim=1).detach().cpu().numpy()
)
logger.info(f"Embedding Call Complete Time: {time.perf_counter()-start_time}")
data = [
{"object": "embedding", "embedding": embedding, "index": index + 1}
for index, embedding in enumerate(embeddings.tolist())
]
usage = {
"prompt_tokens": 0,
"total_tokens": 0,
}
return {"data": data, "model": req.model, "object": "list", "usage": usage}
@app.post("/guard")
async def guard(req: GuardRequest, res: Response, max_num_words=300):
"""
Take input as text and return the prediction of toxic and jailbreak
"""
if req.task in ["both", "toxic", "jailbreak"]:
arch_guard_handler.task = req.task
else:
raise NotImplementedError(f"{req.task} is not supported!")
start_time = time.perf_counter()
if len(req.input.split()) < max_num_words:
guard_result = arch_guard_handler.guard_predict(req.input)
else:
# text is long, split into chunks
chunks = guard_utils.split_text_into_chunks(req.input)
guard_result = {
"jailbreak_prob": [],
"time": 0,
"jailbreak_verdict": False,
"toxic_sentence": [],
"jailbreak_sentence": [],
}
for chunk in chunks:
chunk_result = arch_guard_handler.guard_predict(chunk)
guard_result["time"] += chunk_result["time"]
if chunk_result[f"{arch_guard_handler.task}_verdict"]:
guard_result[f"{arch_guard_handler.task}_verdict"] = True
guard_result[f"{arch_guard_handler.task}_sentence"].append(
chunk_result[f"{arch_guard_handler.task}_sentence"]
)
guard_result[f"{arch_guard_handler.task}_prob"].append(
chunk_result[f"{arch_guard_handler.task}_prob"].item()
)
logger.info(f"Time taken for Guard: {time.perf_counter() - start_time}")
return guard_result
@app.post("/zeroshot")
async def zeroshot(req: ZeroShotRequest, res: Response):
logger.info(f"zero-shot request: {req}")
if req.model != zero_shot_model["model_name"]:
raise HTTPException(status_code=400, detail="unknown model: " + req.model)
classifier = zero_shot_model["pipeline"]
label_map = utils.get_label_map(req.labels)
start_time = time.perf_counter()
predictions = classifier(
req.input, candidate_labels=list(label_map.keys()), multi_label=True
)
logger.info(f"zero-shot taking {time.perf_counter() - start_time} seconds")
predicted_class = label_map[predictions["labels"][0]]
predicted_score = predictions["scores"][0]
scores = {
label_map[label]: score
for label, score in zip(predictions["labels"], predictions["scores"])
}
predicted_class = label_map[predictions["labels"][0]]
return {
"predicted_class": predicted_class,
"predicted_class_score": predicted_score,
"scores": scores,
"model": req.model,
}
@app.post("/hallucination")
@patch("app.loader.glb.DEVICE", "cpu") # Mock the device to 'cpu'
async def hallucination(req: HallucinationRequest, res: Response):
"""
Take input as text and return the prediction of hallucination for each parameter
"""
logger.info(f"hallucination request: {req}")
if req.model != zero_shot_model["model_name"]:
raise HTTPException(status_code=400, detail="unknown model: " + req.model)
start_time = time.perf_counter()
classifier = zero_shot_model["pipeline"]
if "messages" in req.parameters:
req.parameters.pop("messages")
candidate_labels = {f"{k} is {v}": k for k, v in req.parameters.items()}
predictions = classifier(
req.prompt,
candidate_labels=list(candidate_labels.keys()),
hypothesis_template="{}",
multi_label=True,
)
params_scores = {
candidate_labels[label]: score
for label, score in zip(predictions["labels"], predictions["scores"])
}
logger.info(
f"hallucination time cost: {params_scores}, taking {time.perf_counter() - start_time} seconds"
)
return {
"params_scores": params_scores,
"model": req.model,
}
@app.post("/v1/chat/completions")
async def chat_completion(req: ChatMessage, res: Response, request: Request):
@app.post("/function_calling")
async def function_calling(req: ChatMessage, res: Response, request: Request):
try:
result = await arch_function_chat_completion(req, res)
return result
intent_result = await handler_map["Arch-Intent"].chat_completion(req)
if intent_result.choices[0].message.content == "Yes":
try:
function_result = await handler_map["Arch-Function"].chat_completion(
req
)
return function_result
except Exception as e:
# [TODO]
# logger.error(f"Error in chat_completion from `Arch-Function`: {e}")
res.status_code = 500
return {"error": f"[Arch-Function] - {e}"}
except Exception as e:
logger.error(f"Error in chat_completion: {e}")
# [TODO]
# logger.error(f"Error in chat_completion from `Arch-Intent`: {e}")
res.status_code = 500
return {"error": "Internal server error"}
return {"error": f"[Arch-Intent] - {e}"}
@app.post("/guardrails")
async def guardrails(req: GuardRequest, res: Response, max_num_words=300):
try:
guard_result = handler_map["Arch-Guard"].predict(req)
return guard_result
except Exception as e:
# [TODO]
res.status_code = 500
return {"error": f"[Arch-Guard] - {e}"}

View file

@ -0,0 +1,415 @@
import json
import random
import builtins
from openai import OpenAI
from pydantic import BaseModel
from typing import Any, Dict, List, Optional
from overrides import override, final
SUPPORT_DATA_TYPES = ["int", "float", "bool", "str", "list", "tuple", "set", "dict"]
class Message(BaseModel):
role: Optional[str] = ""
content: Optional[str] = ""
tool_call_id: Optional[str] = ""
tool_calls: Optional[List[Dict[str, Any]]] = []
class ChatMessage(BaseModel):
messages: list[Message]
tools: List[Dict[str, Any]]
class Choice(BaseModel):
id: Optional[int] = 0
message: Message
finish_reason: Optional[str] = "stop"
class ChatCompletionResponse(BaseModel):
id: Optional[int] = 0
object: Optional[str] = "chat_completion"
created: Optional[str] = ""
model: str
choices: List[Choice]
class ArchBaseHandler:
def __init__(
self,
client: OpenAI,
model_name: str,
task_prompt: str,
tool_prompt: str,
format_prompt: str,
generation_params: Dict,
):
self.client = client
self.model_name = model_name
self.task_prompt = task_prompt
self.tool_prompt = tool_prompt
self.format_prompt = format_prompt
self.generation_params = generation_params
def _convert_tools(self, tools: List[Dict[str, Any]]):
raise NotImplementedError()
@final
def _format_system(self, tools: List[Dict[str, Any]]):
tool_text = self._convert_tools(tools)
system_prompt = (
self.task_prompt
+ "\n\n"
+ self.tool_prompt.format(tool_text=tool_text)
+ "\n\n"
+ self.format_prompt
)
return system_prompt
@final
def _process_messages(
self,
messages: List[Message],
tools: List[Dict[str, Any]] = None,
extra_instructions: str = None,
):
processed_messages = []
if tools:
processed_messages.append(
{"role": "system", "content": self._format_system(tools)}
)
for message in messages:
role, content, tool_calls = (
message.role,
message.content,
message.tool_calls,
)
if tool_calls:
# [TODO] Extend to support multiple function calls
role = "assistant"
content = f"<tool_call>\n{json.dumps(tool_calls[0]['function'])}\n</tool_call>"
elif message.role == "tool":
role = "user"
content = (
f"<tool_response>\n{json.dumps(message.content)}\n</tool_response>"
)
processed_messages.append({"role": role, "content": content})
assert processed_messages[-1]["role"] == "user"
if extra_instructions:
processed_messages[-1]["content"] += extra_instructions
return processed_messages
async def chat_completion(self, req: ChatMessage):
raise NotImplementedError()
class ArchIntentHandler(ArchBaseHandler):
def __init__(
self,
client: OpenAI,
model_name: str,
task_prompt: str,
tool_prompt: str,
format_prompt: str,
intent_instruction: str,
generation_params: Dict,
):
super().__init__(
client,
model_name,
task_prompt,
tool_prompt,
format_prompt,
generation_params,
)
self.intent_instruction = intent_instruction
@override
def _convert_tools(self, tools: List[Dict[str, Any]]):
converted = [
json.dumps({"index": f"T{idx}"} | tool) for idx, tool in enumerate(tools)
]
return "\n".join(converted)
@override
async def chat_completion(self, req: ChatMessage):
"""
Note: Currently only support vllm inference
"""
messages = self._process_messages(
req.messages, req.tools, self.intent_instruction
)
model_response = self.client.chat.completions.create(
messages=messages,
model=self.model_name,
stream=False,
extra_body=self.generation_params,
)
model_response = Message(content=model_response, tool_calls=[])
chat_completion_response = ChatCompletionResponse(
choices=[Choice(message=model_response)], model=self.model_name
)
return chat_completion_response
class ArchFunctionHandler(ArchBaseHandler):
def __init__(
self,
client: OpenAI,
model_name: str,
task_prompt: str,
tool_prompt: str,
format_prompt: str,
generation_params: Dict,
prefill_params: Dict,
prefill_prefix: List,
):
super().__init__(
client,
model_name,
task_prompt,
tool_prompt,
format_prompt,
generation_params,
)
self.prefill_params = prefill_params
self.prefill_prefix = prefill_prefix
# Predefine data types for verification. Only support Python for now.
# [TODO] Extend the list of support data types
self.support_data_types = {
type_name: getattr(builtins, type_name) for type_name in SUPPORT_DATA_TYPES
}
@override
def _convert_tools(self, tools: List[Dict[str, Any]]):
converted = [json.dumps(tool) for tool in tools]
return "\n".join(converted)
def _fix_json_string(self, json_str: str):
# Remove any leading or trailing whitespace or newline characters
json_str = json_str.strip()
# Stack to keep track of brackets
stack = []
# Clean string to collect valid characters
fixed_str = ""
# Dictionary for matching brackets
matching_bracket = {")": "(", "}": "{", "]": "["}
# Dictionary for the opposite of matching_bracket
opening_bracket = {v: k for k, v in matching_bracket.items()}
for char in json_str:
if char in "{[(":
stack.append(char)
fixed_str += char
elif char in "}])":
if stack and stack[-1] == matching_bracket[char]:
stack.pop()
fixed_str += char
else:
# Ignore the unmatched closing brackets
continue
else:
fixed_str += char
# If there are unmatched opening brackets left in the stack, add corresponding closing brackets
while stack:
unmatched_opening = stack.pop()
fixed_str += opening_bracket[unmatched_opening]
# Attempt to parse the corrected string to ensure its valid JSON
return fixed_str.replace("'", '"')
def _extract_tool_calls(self, content: str):
tool_calls, is_valid, error_message = [], True, ""
flag = False
for line in content.split("\n"):
if "<tool_call>" == line:
flag = True
elif "</tool_call>" == line:
flag = False
else:
if flag:
try:
tool_content = json.loads(line)
except Exception as e:
fixed_content = self._fix_json_string(line)
try:
tool_content = json.loads(fixed_content)
except Exception:
tool_calls, is_valid, error_message = [], False, e
return tool_calls, is_valid, error_message
tool_calls.append(
{
"id": f"call_{random.randint(1000, 10000)}",
"type": "function",
"function": {
"name": tool_content["name"],
"arguments": tool_content["arguments"],
},
}
)
flag = False
return tool_calls, is_valid, error_message
def _verify_tool_calls(
self, tools: List[Dict[str, Any]], tool_calls: List[Dict[str, Any]]
):
is_valid, error_tool_call, error_message = True, None, ""
functions = {}
for tool in tools:
if tool["type"] == "function":
functions[tool["function"]["name"]] = tool["function"]["parameters"]
for tool_call in tool_calls:
func_name, func_args = (
tool_call["function"]["name"],
tool_call["function"]["arguments"],
)
# Check whether the function is available or not
if func_name not in functions:
is_valid = False
error_message = f"{func_name} is not defined!"
return is_valid, error_message
else:
# Check if all the requried parameters can be found in the tool calls
for required_param in functions[func_name].get("required", []):
if required_param not in func_args:
is_valid = False
error_tool_call = tool_call
error_message = f"`{required_param}` is requried by the function `{func_name}` but not found in the tool call!"
return is_valid, error_tool_call, error_message
# Verify the data type of each parameter in the tool calls
for param_name, param_value in func_args:
data_type = functions[func_name]["properties"][param_name]["type"]
if data_type in self.support_data_types:
if not isinstance(
param_value, self.support_data_types[data_type]
):
is_valid = False
error_tool_call = tool_call
error_message = f"Parameter `{param_name}` is expected to have the data type `{self.support_data_types[data_type]}`, but got `{type(param_value)}`."
return is_valid, error_tool_call, error_message
return is_valid, error_tool_call, error_message
@override
async def chat_completion(self, req: ChatMessage, enable_prefilling=True):
"""
Note: Currently only support vllm inference
"""
messages = self._process_messages(req.messages, req.tools)
# Retrieve the first token, handling the Stream object carefully
response = self.client.chat.completions.create(
messages=messages,
model=self.model_name,
stream=enable_prefilling,
extra_body=self.generation_params,
)
model_response = ""
if enable_prefilling:
has_tool_call = None
model_response = ""
for token in response:
token_content = token.choices[0].delta.content.strip()
if has_tool_call is None and token_content != "<tool_call>":
has_tool_call = False
response.close()
break
else:
has_tool_call = True
if has_tool_call is True:
model_response += token_content
# start parameter gathering if the model is not generating a tool call
if has_tool_call is False:
messages.append(
{
"role": "assistant",
"content": random.choice(self.prefill_prefix),
}
)
prefill_response = self.client.chat.completions.create(
messages=messages,
model=self.model_name,
stream=False,
extra_body={
**self.generation_params,
**self.prefill_params,
},
)
model_response = prefill_response.choices[0].message.content
else:
model_response = response.choices[0].message.content
tool_calls, is_valid, error_message = self._extract_tool_calls(model_response)
if tool_calls:
is_valid, error_tool_call, error_message = self._verify_tool_calls(
tools=req.tools, tool_calls=tool_calls
)
# [TODO] Review: In the case that tool calls are invalid, define the protocol to collect debugging output and the behavior to handle it appropriately
if is_valid:
model_response = Message(content="", tool_calls=tool_calls)
# else:
else:
model_response = Message(content=model_response, tool_calls=[])
chat_completion_response = ChatCompletionResponse(
choices=[Choice(message=model_response)], model=self.model_name
)
# [TODO] Review: define the protocol to collect debugging output
# logger.info(
# f"model_server <= arch_function: (tool_calls): {json.dumps([tool_call['function'] for tool_call in tool_calls])}"
# )
# logger.info(
# f"model_server <= arch_function: response body: {json.dumps(chat_completion_response.dict())}"
# )
return chat_completion_response

View file

@ -0,0 +1,95 @@
import time
import torch
import numpy as np
from pydantic import BaseModel
class GuardRequest(BaseModel):
input: str
task: str
class ArchGuardHanlder:
def __init__(self, model_dict):
self.model = model_dict["model"]
self.tokenizer = model_dict["tokenizer"]
self.device = model_dict["device"]
self.support_tasks = {"jailbreak": {"positive_class": 2, "threshold": 0.5}}
def _split_text_into_chunks(self, text, max_num_words=300):
"""
Split the text into chunks of `max_num_words` words
"""
words = text.split() # Split text into words
chunks = [
" ".join(words[i : i + max_num_words])
for i in range(0, len(words), max_num_words)
]
return chunks
@staticmethod
def softmax(x):
return np.exp(x) / np.exp(x).sum(axis=0)
def _predict_text(self, task, text, max_length=512):
inputs = self.tokenizer(
text, truncation=True, max_length=max_length, return_tensors="pt"
).to(self.device)
with torch.no_grad():
logits = self.model(**inputs).logits.cpu().detach().numpy()[0]
prob = ArchGuardHanlder.softmax(logits)[
self.support_tasks[task]["positive_class"]
]
if prob > self.support_tasks[task]["threshold"]:
verdict = True
sentence = text
else:
verdict = False
sentence = None
result_dict = {
"prob": prob.item(),
"verdict": verdict,
"sentence": sentence,
}
return result_dict
def predict(self, req: GuardRequest, max_num_words=300):
"""
Note: currently only support jailbreak check
"""
if req.task not in self.support_tasks:
raise NotImplementedError(f"{req.task} is not supported!")
guard_result = {
"prob": [],
"verdict": False,
"sentence": [],
}
start_time = time.perf_counter()
if len(req.input.split()) < max_num_words:
guard_result = self._predict_text(req.task, req.input)
else:
# split into chunks if text is long
text_chunks = self._split_text_into_chunks(req.input)
for chunk in text_chunks:
chunk_result = self._predict_text(req.task, chunk)
if chunk_result["verdict"]:
guard_result["verdict"] = True
guard_result["sentence"].append(chunk_result["sentence"])
guard_result["prob"].append(chunk_result["prob"].item())
guard_result["latency"] = time.perf_counter() - start_time
return guard_result

View file

@ -1,8 +1,6 @@
import json
import math
import torch
import random
from typing import Any, Dict, List, Tuple
from typing import Dict, List, Tuple
import itertools
from enum import Enum
@ -74,26 +72,6 @@ def calculate_entropy(log_probs: List[float]) -> Tuple[float, float]:
return entropy.item(), varentropy.item()
def is_parameter_property(
function_description: Dict, parameter_name: str, property_name: str
) -> bool:
"""
Check if a parameter in an API description has a specific property.
Args:
function_description (dict): The API description in JSON format.
parameter_name (str): The name of the parameter to check.
property_name (str): The property to look for (e.g., 'format', 'default').
Returns:
bool: True if the parameter has the specified property, False otherwise.
"""
parameters = function_description.get("properties", {})
parameter_info = parameters.get(parameter_name, {})
return property_name in parameter_info
class HallucinationStateHandler:
"""
A class to handle the state of hallucination detection in token processing.
@ -111,7 +89,7 @@ class HallucinationStateHandler:
token_probs_map (list): List mapping tokens to their entropy and variance of entropy.
"""
def __init__(self, response_iterator=None, function=None):
def __init__(self, response_iterator=None):
"""
Initializes the HallucinationStateHandler with default values.
"""
@ -126,19 +104,6 @@ class HallucinationStateHandler:
self.parameter_name: List[str] = []
self.token_probs_map: List[Tuple[str, float, float]] = []
self.response_iterator = response_iterator
self._process_function(function)
def _process_function(self, function):
self.function = function
if self.function is None:
raise ValueError("API descriptions not set.")
parameter_names = {}
for func in self.function:
func_name = func["name"]
parameters = func["parameters"]["properties"]
parameter_names[func_name] = list(parameters.keys())
self.function_description = parameter_names
self.function_properties = {x["name"]: x["parameters"] for x in self.function}
def append_and_check_token_hallucination(self, token, logprob):
"""
@ -237,6 +202,8 @@ class HallucinationStateHandler:
# checking if the token is a value token and is not empty
if self.tokens[-1].strip() not in ['"', ""]:
self.mask.append(MaskToken.PARAMETER_VALUE)
# [TODO] Review: update the following code: `is_parameter_property` should not be here
# checking if the parameter doesn't have default and the token is the first parameter value token
if (
len(self.mask) > 1
@ -299,26 +266,3 @@ class HallucinationStateHandler:
if self.mask and self.mask[-1] == token
else 0
)
def _is_function_name_hallucinated(self):
"""
Checks the extracted function name against the function descriptions.
Detects hallucinations if the function name is not found.
"""
f_len = self._count_consecutive_token(MaskToken.FUNCTION_NAME)
self.function_name = "".join(self.tokens[:-1][-f_len:])
if self.function_name not in self.function_description.keys():
self.error_type = "function_name"
self.error_message = f"Function name '{self.function_name}' not found in given function descriptions."
def _is_parameter_name_hallucinated(self):
"""
Checks the extracted parameter name against the function descriptions.
Detects hallucinations if the parameter name is not found.
"""
p_len = self._count_consecutive_token(MaskToken.PARAMETER_NAME)
parameter_name = "".join(self.tokens[:-1][-p_len:])
self.parameter_name.append(parameter_name)
if parameter_name not in self.function_description[self.function_name]:
self.error_type = "parameter_name"
self.error_message = f"Parameter name '{parameter_name}' not found in given function descriptions."

View file

@ -1,42 +0,0 @@
import time
import torch
import app.prompt_guard.model_utils as model_utils
class ArchGuardHanlder:
def __init__(self, model_dict, threshold=0.5):
self.task = "jailbreak"
self.positive_class = 2
self.model = model_dict["model"]
self.tokenizer = model_dict["tokenizer"]
self.device = model_dict["device"]
self.threshold = threshold
def guard_predict(self, input_text, max_length=512):
start_time = time.perf_counter()
inputs = self.tokenizer(
input_text, truncation=True, max_length=max_length, return_tensors="pt"
).to(self.device)
with torch.no_grad():
logits = self.model(**inputs).logits.cpu().detach().numpy()[0]
prob = model_utils.softmax(logits)[self.positive_class]
if prob > self.threshold:
verdict = True
sentence = input_text
else:
verdict = False
sentence = None
result_dict = {
f"{self.task}_prob": prob.item(),
f"{self.task}_verdict": verdict,
f"{self.task}_sentence": sentence,
"time": time.perf_counter() - start_time,
}
return result_dict

View file

@ -1,19 +0,0 @@
import numpy as np
def split_text_into_chunks(text, max_words=300):
"""
Max number of tokens for tokenizer is 512
Split the text into chunks of 300 words (as approximation for tokens)
"""
words = text.split() # Split text into words
# Estimate token count based on word count (1 word ≈ 1 token)
chunk_size = max_words # Use the word count as an approximation for tokens
chunks = [
" ".join(words[i : i + chunk_size]) for i in range(0, len(words), chunk_size)
]
return chunks
def softmax(x):
return np.exp(x) / np.exp(x).sum(axis=0)

View file

@ -13,6 +13,7 @@ client = TestClient(app)
logger.info(f"Model will be loaded on device: {glb.DEVICE}")
# [TODO] Review: update the following code
# Unit tests for the health check endpoint
@pytest.mark.asyncio
@patch("app.loader.glb.DEVICE", glb.DEVICE) # Mock the device to 'cpu'
@ -22,6 +23,7 @@ async def test_healthz():
assert response.json() == {"status": "ok"}
# [TODO] Review: update the following code
# Unit test for the models endpoint
@pytest.mark.asyncio
@patch("app.loader.glb.DEVICE", glb.DEVICE) # Mock the device to 'cpu'
@ -32,6 +34,7 @@ async def test_models():
assert len(response.json()["data"]) > 0
# [TODO] Review: update the following code
# Unit test for embeddings endpoint
@pytest.mark.asyncio
@patch("app.loader.glb.DEVICE", glb.DEVICE) # Mock the device to 'cpu'
@ -46,6 +49,7 @@ async def test_embedding():
assert response.status_code == 400
# [TODO] Review: update the following code
# Unit test for the guard endpoint
@pytest.mark.asyncio
@patch("app.loader.glb.DEVICE", glb.DEVICE) # Mock the device to 'cpu'
@ -56,6 +60,7 @@ async def test_guard():
assert "jailbreak_verdict" in response.json()
# [TODO] Review: update the following code
# Unit test for the zero-shot endpoint
@pytest.mark.asyncio
@patch("app.loader.glb.DEVICE", glb.DEVICE) # Mock the device to 'cpu'
@ -73,6 +78,7 @@ async def test_zeroshot():
assert response.status_code == 400
# [TODO] Review: update the following code
# Unit test for the hallucination endpoint
@pytest.mark.asyncio
@patch("app.loader.glb.DEVICE", glb.DEVICE) # Mock the device to 'cpu'
@ -90,6 +96,7 @@ async def test_hallucination():
assert response.status_code == 400
# [TODO] Review: update the following code
# Unit test for the chat completion endpoint
@pytest.mark.asyncio
@patch("app.loader.glb.DEVICE", glb.DEVICE) # Mock the device to 'cpu'

File diff suppressed because it is too large Load diff

View file

@ -1,7 +1,6 @@
import unittest
from unittest.mock import patch, MagicMock
import subprocess
import time
from app.cli import kill_process

View file

@ -1,13 +1,12 @@
import json
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
import app.commons.constants as const
from fastapi import Response
from app.function_calling.model_utils import (
process_messages,
chat_completion,
from unittest.mock import AsyncMock, MagicMock, patch
from app.commons.globals import handler_map
from app.model_handler.function_calling import (
Message,
ChatMessage,
Choice,
ChatCompletionResponse,
)
@ -31,14 +30,27 @@ def sample_messages():
def sample_request(sample_messages):
return ChatMessage(
messages=sample_messages,
tools=[{"name": "sample_tool", "description": "A sample tool"}],
tools=[
{
"type": "function",
"function": {
"name": "sample_tool",
"description": "A sample tool",
"parameters": {
"type": "object",
"properties": {},
"required": [],
},
},
}
],
)
@patch("app.commons.constants.arch_function_hanlder")
@patch("app.commons.globals.handler_map")
def test_process_messages(mock_hanlder):
messages = sample_messages()
processed = process_messages(messages)
processed = handler_map["Arch-Function"]._process_messages(messages)
assert len(processed) == 3
assert processed[0] == {"role": "user", "content": "Hello!"}
@ -48,10 +60,11 @@ def test_process_messages(mock_hanlder):
}
assert processed[2] == {
"role": "user",
"content": "<tool_response>\nResponse from tool\n</tool_response>",
"content": f"<tool_response>\n{json.dumps('Response from tool')}\n</tool_response>",
}
# [TODO] Review: Update the following test
@patch("app.commons.constants.arch_function_client")
@patch("app.commons.constants.arch_function_hanlder")
@pytest.mark.asyncio

View file

@ -0,0 +1,47 @@
import os
import pytest
from unittest.mock import patch, MagicMock
import app.commons.globals as glb
# Mock constants
glb.DEVICE = "cpu" # Adjust as needed for your test case
arch_guard_model_type = {
"cpu": "katanemo/Arch-Guard-cpu",
"cuda": "katanemo/Arch-Guard",
"mps": "katanemo/Arch-Guard",
}
# [TODO] Review: update the following code to test under `cpu`, `cuda`, and `mps`
# Test for get_prompt_guard function
@patch("app.loader.AutoTokenizer.from_pretrained")
@patch("app.loader.OVModelForSequenceClassification.from_pretrained")
@patch("app.loader.AutoModelForSequenceClassification.from_pretrained")
def test_get_prompt_guard(mock_auto_model, mock_ov_model, mock_tokenizer):
# Mock model based on device
if glb.DEVICE == "cpu":
mock_ov_model.return_value = MagicMock()
else:
mock_auto_model.return_value = MagicMock()
mock_tokenizer.return_value = MagicMock()
prompt_guard = get_prompt_guard(arch_guard_model_type[glb.DEVICE])
# Assertions
assert prompt_guard["model_name"] == arch_guard_model_type[glb.DEVICE]
mock_tokenizer.assert_called_once_with(
arch_guard_model_type[glb.DEVICE], trust_remote_code=True
)
if glb.DEVICE == "cpu":
mock_ov_model.assert_called_once_with(
arch_guard_model_type[glb.DEVICE],
device_map=glb.DEVICE,
low_cpu_mem_usage=True,
)
else:
mock_auto_model.assert_called_once_with(
arch_guard_model_type[glb.DEVICE],
device_map=glb.DEVICE,
low_cpu_mem_usage=True,
)

View file

@ -1,8 +1,11 @@
import json
from app.function_calling.hallucination_handler import HallucinationStateHandler
import pytest
import os
from app.model_handler.hallucination_handler import HallucinationStateHandler
# Get the directory of the current file
current_dir = os.path.dirname(__file__)
@ -45,6 +48,7 @@ if type(function_description) != list:
function_description = [get_weather_api["function"]]
# [TODO] Review: update the following code
@pytest.mark.parametrize("case", test_cases)
def test_hallucination(case):
state = HallucinationStateHandler(
@ -58,6 +62,7 @@ def test_hallucination(case):
assert state.hallucination == case["expect"]
# [TODO] Review: update the following code
@pytest.mark.parametrize("is_hallucinate_sample", [True, False])
def test_hallucination_prompt(is_hallucinate_sample):
TASK_PROMPT = """

View file

@ -1,102 +0,0 @@
import os
import pytest
from unittest.mock import patch, MagicMock
import app.commons.globals as glb
from app.loader import get_embedding_model, get_zero_shot_model, get_prompt_guard
# Mock constants
glb.DEVICE = "cpu" # Adjust as needed for your test case
arch_guard_model_type = {
"cpu": "katanemo/Arch-Guard-cpu",
"cuda": "katanemo/Arch-Guard",
"mps": "katanemo/Arch-Guard",
}
@pytest.fixture
def mock_env():
# Mock environment variables
os.environ["MODELS"] = "katanemo/bge-large-en-v1.5"
os.environ["ZERO_SHOT_MODELS"] = "katanemo/bart-large-mnli"
# Test for get_embedding_model function
@patch("app.loader.ORTModelForFeatureExtraction.from_pretrained")
@patch("app.loader.AutoModel.from_pretrained")
@patch("app.loader.AutoTokenizer.from_pretrained")
def test_get_embedding_model(mock_tokenizer, mock_automodel, mock_ort_model, mock_env):
mock_automodel.return_value = MagicMock()
mock_ort_model.return_value = MagicMock()
mock_tokenizer.return_value = MagicMock()
embedding_model = get_embedding_model()
# Assertions
assert embedding_model["model_name"] == "katanemo/bge-large-en-v1.5"
mock_tokenizer.assert_called_once_with(
"katanemo/bge-large-en-v1.5", trust_remote_code=True
)
if glb.DEVICE != "cuda":
mock_ort_model.assert_called_once_with(
"katanemo/bge-large-en-v1.5", file_name="onnx/model.onnx"
)
else:
mock_automodel.assert_called_once_with(
"katanemo/bge-large-en-v1.5", device_map=glb.DEVICE
)
# Test for get_zero_shot_model function
@patch("app.loader.ORTModelForSequenceClassification.from_pretrained")
@patch("app.loader.pipeline")
@patch("app.loader.AutoTokenizer.from_pretrained")
def test_get_zero_shot_model(mock_tokenizer, mock_pipeline, mock_ort_model, mock_env):
mock_pipeline.return_value = MagicMock()
mock_ort_model.return_value = MagicMock()
mock_tokenizer.return_value = MagicMock()
zero_shot_model = get_zero_shot_model()
# Assertions
assert zero_shot_model["model_name"] == "katanemo/bart-large-mnli"
mock_tokenizer.assert_called_once_with("katanemo/bart-large-mnli")
if glb.DEVICE != "cuda":
mock_ort_model.assert_called_once_with(
"katanemo/bart-large-mnli", file_name="onnx/model.onnx"
)
else:
assert mock_pipeline.called_once()
# Test for get_prompt_guard function
@patch("app.loader.AutoTokenizer.from_pretrained")
@patch("app.loader.OVModelForSequenceClassification.from_pretrained")
@patch("app.loader.AutoModelForSequenceClassification.from_pretrained")
def test_get_prompt_guard(mock_auto_model, mock_ov_model, mock_tokenizer):
# Mock model based on device
if glb.DEVICE == "cpu":
mock_ov_model.return_value = MagicMock()
else:
mock_auto_model.return_value = MagicMock()
mock_tokenizer.return_value = MagicMock()
prompt_guard = get_prompt_guard(arch_guard_model_type[glb.DEVICE])
# Assertions
assert prompt_guard["model_name"] == arch_guard_model_type[glb.DEVICE]
mock_tokenizer.assert_called_once_with(
arch_guard_model_type[glb.DEVICE], trust_remote_code=True
)
if glb.DEVICE == "cpu":
mock_ov_model.assert_called_once_with(
arch_guard_model_type[glb.DEVICE],
device_map=glb.DEVICE,
low_cpu_mem_usage=True,
)
else:
mock_auto_model.assert_called_once_with(
arch_guard_model_type[glb.DEVICE],
device_map=glb.DEVICE,
low_cpu_mem_usage=True,
)

View file

@ -1,102 +0,0 @@
import os
import pytest
from unittest.mock import patch, MagicMock
import app.commons.globals as glb
from app.loader import get_embedding_model, get_zero_shot_model, get_prompt_guard
# Mock constants
glb.DEVICE = "cuda" # Adjust as needed for your test case
arch_guard_model_type = {
"cpu": "katanemo/Arch-Guard-cpu",
"cuda": "katanemo/Arch-Guard",
"mps": "katanemo/Arch-Guard",
}
@pytest.fixture
def mock_env():
# Mock environment variables
os.environ["MODELS"] = "katanemo/bge-large-en-v1.5"
os.environ["ZERO_SHOT_MODELS"] = "katanemo/bart-large-mnli"
# Test for get_embedding_model function
@patch("app.loader.ORTModelForFeatureExtraction.from_pretrained")
@patch("app.loader.AutoModel.from_pretrained")
@patch("app.loader.AutoTokenizer.from_pretrained")
def test_get_embedding_model(mock_tokenizer, mock_automodel, mock_ort_model, mock_env):
mock_automodel.return_value = MagicMock()
mock_ort_model.return_value = MagicMock()
mock_tokenizer.return_value = MagicMock()
embedding_model = get_embedding_model()
# Assertions
assert embedding_model["model_name"] == "katanemo/bge-large-en-v1.5"
mock_tokenizer.assert_called_once_with(
"katanemo/bge-large-en-v1.5", trust_remote_code=True
)
if glb.DEVICE != "cuda":
mock_ort_model.assert_called_once_with(
"katanemo/bge-large-en-v1.5", file_name="onnx/model.onnx"
)
else:
mock_automodel.assert_called_once_with(
"katanemo/bge-large-en-v1.5", device_map=glb.DEVICE
)
# Test for get_zero_shot_model function
@patch("app.loader.ORTModelForSequenceClassification.from_pretrained")
@patch("app.loader.pipeline")
@patch("app.loader.AutoTokenizer.from_pretrained")
def test_get_zero_shot_model(mock_tokenizer, mock_pipeline, mock_ort_model, mock_env):
mock_pipeline.return_value = MagicMock()
mock_ort_model.return_value = MagicMock()
mock_tokenizer.return_value = MagicMock()
zero_shot_model = get_zero_shot_model()
# Assertions
assert zero_shot_model["model_name"] == "katanemo/bart-large-mnli"
mock_tokenizer.assert_called_once_with("katanemo/bart-large-mnli")
if glb.DEVICE != "cuda":
mock_ort_model.assert_called_once_with(
"katanemo/bart-large-mnli", file_name="onnx/model.onnx"
)
else:
assert mock_pipeline.called_once()
# Test for get_prompt_guard function
@patch("app.loader.AutoTokenizer.from_pretrained")
@patch("app.loader.OVModelForSequenceClassification.from_pretrained")
@patch("app.loader.AutoModelForSequenceClassification.from_pretrained")
def test_get_prompt_guard(mock_auto_model, mock_ov_model, mock_tokenizer):
# Mock model based on device
if glb.DEVICE == "cpu":
mock_ov_model.return_value = MagicMock()
else:
mock_auto_model.return_value = MagicMock()
mock_tokenizer.return_value = MagicMock()
prompt_guard = get_prompt_guard(arch_guard_model_type[glb.DEVICE])
# Assertions
assert prompt_guard["model_name"] == arch_guard_model_type[glb.DEVICE]
mock_tokenizer.assert_called_once_with(
arch_guard_model_type[glb.DEVICE], trust_remote_code=True
)
if glb.DEVICE == "cpu":
mock_ov_model.assert_called_once_with(
arch_guard_model_type[glb.DEVICE],
device_map=glb.DEVICE,
low_cpu_mem_usage=True,
)
else:
mock_auto_model.assert_called_once_with(
arch_guard_model_type[glb.DEVICE],
device_map=glb.DEVICE,
low_cpu_mem_usage=True,
)

View file

@ -1,102 +0,0 @@
import os
import pytest
from unittest.mock import patch, MagicMock
import app.commons.globals as glb
from app.loader import get_embedding_model, get_zero_shot_model, get_prompt_guard
# Mock constants
glb.DEVICE = "mps" # Adjust as needed for your test case
arch_guard_model_type = {
"cpu": "katanemo/Arch-Guard-cpu",
"cuda": "katanemo/Arch-Guard",
"mps": "katanemo/Arch-Guard",
}
@pytest.fixture
def mock_env():
# Mock environment variables
os.environ["MODELS"] = "katanemo/bge-large-en-v1.5"
os.environ["ZERO_SHOT_MODELS"] = "katanemo/bart-large-mnli"
# Test for get_embedding_model function
@patch("app.loader.ORTModelForFeatureExtraction.from_pretrained")
@patch("app.loader.AutoModel.from_pretrained")
@patch("app.loader.AutoTokenizer.from_pretrained")
def test_get_embedding_model(mock_tokenizer, mock_automodel, mock_ort_model, mock_env):
mock_automodel.return_value = MagicMock()
mock_ort_model.return_value = MagicMock()
mock_tokenizer.return_value = MagicMock()
embedding_model = get_embedding_model()
# Assertions
assert embedding_model["model_name"] == "katanemo/bge-large-en-v1.5"
mock_tokenizer.assert_called_once_with(
"katanemo/bge-large-en-v1.5", trust_remote_code=True
)
if glb.DEVICE != "cuda":
mock_ort_model.assert_called_once_with(
"katanemo/bge-large-en-v1.5", file_name="onnx/model.onnx"
)
else:
mock_automodel.assert_called_once_with(
"katanemo/bge-large-en-v1.5", device_map=glb.DEVICE
)
# Test for get_zero_shot_model function
@patch("app.loader.ORTModelForSequenceClassification.from_pretrained")
@patch("app.loader.pipeline")
@patch("app.loader.AutoTokenizer.from_pretrained")
def test_get_zero_shot_model(mock_tokenizer, mock_pipeline, mock_ort_model, mock_env):
mock_pipeline.return_value = MagicMock()
mock_ort_model.return_value = MagicMock()
mock_tokenizer.return_value = MagicMock()
zero_shot_model = get_zero_shot_model()
# Assertions
assert zero_shot_model["model_name"] == "katanemo/bart-large-mnli"
mock_tokenizer.assert_called_once_with("katanemo/bart-large-mnli")
if glb.DEVICE != "cuda":
mock_ort_model.assert_called_once_with(
"katanemo/bart-large-mnli", file_name="onnx/model.onnx"
)
else:
assert mock_pipeline.called_once()
# Test for get_prompt_guard function
@patch("app.loader.AutoTokenizer.from_pretrained")
@patch("app.loader.OVModelForSequenceClassification.from_pretrained")
@patch("app.loader.AutoModelForSequenceClassification.from_pretrained")
def test_get_prompt_guard(mock_auto_model, mock_ov_model, mock_tokenizer):
# Mock model based on device
if glb.DEVICE == "cpu":
mock_ov_model.return_value = MagicMock()
else:
mock_auto_model.return_value = MagicMock()
mock_tokenizer.return_value = MagicMock()
prompt_guard = get_prompt_guard(arch_guard_model_type[glb.DEVICE])
# Assertions
assert prompt_guard["model_name"] == arch_guard_model_type[glb.DEVICE]
mock_tokenizer.assert_called_once_with(
arch_guard_model_type[glb.DEVICE], trust_remote_code=True
)
if glb.DEVICE == "cpu":
mock_ov_model.assert_called_once_with(
arch_guard_model_type[glb.DEVICE],
device_map=glb.DEVICE,
low_cpu_mem_usage=True,
)
else:
mock_auto_model.assert_called_once_with(
arch_guard_model_type[glb.DEVICE],
device_map=glb.DEVICE,
low_cpu_mem_usage=True,
)

View file

@ -1,66 +1,50 @@
from typing import List
import pytest
import json
from app.function_calling.model_utils import Message, process_messages
from app.commons.globals import handler_map
from app.model_handler.function_calling import Message
test_input_history = """
[
test_input_history = [
{"role": "user", "content": "how is the weather in chicago for next 5 days?"},
{
"role": "user",
"content": "how is the weather in chicago for next 5 days?"
"role": "assistant",
"model": "Arch-Function",
"tool_calls": [
{
"id": "call_3394",
"type": "function",
"function": {
"name": "weather_forecast",
"arguments": {"city": "Chicago", "days": 5},
},
}
],
},
{"role": "tool", "content": "--", "tool_call_id": "call_3394"},
{"role": "assistant", "content": "--", "model": "gpt-3.5-turbo-0125"},
{"role": "user", "content": "how is the weather in chicago for next 5 days?"},
{
"role": "assistant",
"model": "Arch-Function-1.5B",
"tool_calls": [
{
"id": "call_3394",
"type": "function",
"function": {
"name": "weather_forecast",
"arguments": { "city": "Chicago", "days": 5 }
}
}
]
"role": "assistant",
"tool_calls": [
{
"id": "call_5306",
"type": "function",
"function": {
"name": "weather_forecast",
"arguments": {"city": "Chicago", "days": 5},
},
}
],
},
{
"role": "tool",
"content": "--",
"tool_call_id": "call_3394"
},
{
"role": "assistant",
"content": "--",
"model": "gpt-3.5-turbo-0125"
},
{
"role": "user",
"content": "how is the weather in chicago for next 5 days?"
},
{
"role": "assistant",
"tool_calls": [
{
"id": "call_5306",
"type": "function",
"function": {
"name": "weather_forecast",
"arguments": { "city": "Chicago", "days": 5 }
}
}
]
}
]
"""
{"role": "tool", "content": "--", "tool_call_id": "call_5306"},
]
def test_update_fc_history():
history = json.loads(test_input_history)
message_history = []
for h in history:
for h in test_input_history:
message_history.append(Message(**h))
updated_history = process_messages(message_history)
assert len(updated_history) == 6
updated_history = handler_map["Arch-Function"]._process_messages(message_history)
assert len(updated_history) == 7
# ensure that tool role does not exist anymore
assert all([h["role"] != "tool" for h in updated_history])