mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
Init update on model_server
This commit is contained in:
parent
1d9de28086
commit
afe1410b37
25 changed files with 1758 additions and 1922 deletions
|
|
@ -1,10 +1,7 @@
|
|||
import importlib
|
||||
import sys
|
||||
import os
|
||||
import time
|
||||
import requests
|
||||
import psutil
|
||||
import tempfile
|
||||
import subprocess
|
||||
import logging
|
||||
|
||||
|
|
|
|||
|
|
@ -1,38 +1,83 @@
|
|||
import app.commons.globals as glb
|
||||
import app.commons.utilities as utils
|
||||
import app.loader as loader
|
||||
# ========================== Arch-Intent Default Params ==========================
|
||||
ARCH_INTENT_MODEL_ALIAS = "Arch-Intent"
|
||||
ARCH_INTENT_INSTRUCTION = "Are there any tools can help?"
|
||||
|
||||
from app.function_calling.model_handler import ArchFunctionHandler
|
||||
from app.prompt_guard.model_handler import ArchGuardHanlder
|
||||
ARCH_INTENT_TASK_PROMPT = """
|
||||
You are a helpful assistant.
|
||||
""".strip()
|
||||
|
||||
logger = utils.get_model_server_logger()
|
||||
|
||||
arch_function_hanlder = ArchFunctionHandler()
|
||||
PREFILL_LIST = ["May", "Could", "Sure", "Definitely", "Certainly", "Of course", "Can"]
|
||||
PREFILL_ENABLED = True
|
||||
TOOL_CALL_TOKEN = "<tool_call>"
|
||||
arch_function_endpoint = "https://api.fc.archgw.com/v1"
|
||||
arch_function_client = utils.get_client(arch_function_endpoint)
|
||||
arch_function_generation_params = {
|
||||
"temperature": 0.2,
|
||||
"top_p": 1.0,
|
||||
"top_k": 50,
|
||||
"max_tokens": 512,
|
||||
"stop_token_ids": [151645],
|
||||
# "top_logprobs": 10,
|
||||
ARCH_INTENT_TOOL_PROMPT = """
|
||||
You task is to check if there are any tools that can be used to help the last user message in conversations according to the available tools listed below.
|
||||
|
||||
<tools>
|
||||
{tool_text}
|
||||
</tools>
|
||||
""".strip()
|
||||
|
||||
|
||||
ARCH_INTENT_FORMAT_PROMPT = """
|
||||
Provide your tool assessment for ONLY THE LAST USER MESSAGE in the above conversation:
|
||||
- First line must read 'Yes' or 'No'.
|
||||
- If yes, a second line must include a comma-separated list of tool indexes.
|
||||
""".strip()
|
||||
|
||||
|
||||
ARCH_INTENT_GENERATION_CONFIG = {
|
||||
"generation_params": {
|
||||
"stop_token_ids": [151645],
|
||||
"max_tokens": 1,
|
||||
"guided_choice": ["Yes", "No"],
|
||||
}
|
||||
}
|
||||
|
||||
arch_guard_model_type = {
|
||||
"cpu": "katanemo/Arch-Guard-cpu",
|
||||
"cuda": "katanemo/Arch-Guard",
|
||||
"mps": "katanemo/Arch-Guard",
|
||||
|
||||
# ========================== Arch-Function Default Params ==========================
|
||||
ARCH_FUNCTION_MODEL_ALIAS = "Arch-Function"
|
||||
|
||||
ARCH_FUNCTION_TASK_PROMPT = """
|
||||
You are a helpful assistant.
|
||||
""".strip()
|
||||
|
||||
|
||||
ARCH_FUNCTION_TOOL_PROMPT = """
|
||||
# Tools
|
||||
|
||||
You may call one or more functions to assist with the user query.
|
||||
|
||||
You are provided with function signatures within <tools></tools> XML tags:
|
||||
<tools>
|
||||
{tool_text}
|
||||
</tools>
|
||||
""".strip()
|
||||
|
||||
|
||||
ARCH_FUNCTION_FORMAT_PROMPT = """
|
||||
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
|
||||
<tool_call>
|
||||
{"name": <function-name>, "arguments": <args-json-object>}
|
||||
</tool_call>
|
||||
""".strip()
|
||||
|
||||
ARCH_FUNCTION_GENERATION_CONFIG = {
|
||||
"generation_params": {
|
||||
"temperature": 0.2,
|
||||
"top_p": 1.0,
|
||||
"top_k": 50,
|
||||
"max_tokens": 512,
|
||||
"stop_token_ids": [151645],
|
||||
},
|
||||
"prefill_params": {
|
||||
"continue_final_message": True,
|
||||
"add_generation_prompt": False,
|
||||
},
|
||||
"prefill_prefix": [
|
||||
"May",
|
||||
"Could",
|
||||
"Sure",
|
||||
"Definitely",
|
||||
"Certainly",
|
||||
"Of course",
|
||||
"Can",
|
||||
],
|
||||
}
|
||||
|
||||
# Model definition
|
||||
embedding_model = loader.get_embedding_model()
|
||||
zero_shot_model = loader.get_zero_shot_model()
|
||||
|
||||
prompt_guard_dict = loader.get_prompt_guard(arch_guard_model_type[glb.DEVICE])
|
||||
|
||||
arch_guard_handler = ArchGuardHanlder(model_dict=prompt_guard_dict)
|
||||
# Patterns for function name and parameter parsing
|
||||
|
|
|
|||
|
|
@ -1,4 +1,65 @@
|
|||
import app.commons.utilities as utils
|
||||
|
||||
from app.commons.constants import *
|
||||
from app.model_handler.function_calling import ArchIntentHandler, ArchFunctionHandler
|
||||
from app.model_handler.guardrails import ArchGuardHanlder
|
||||
|
||||
DEVICE = utils.get_device()
|
||||
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
||||
from optimum.intel import OVModelForSequenceClassification
|
||||
from openai import OpenAI
|
||||
|
||||
|
||||
logger = utils.get_model_server_logger()
|
||||
|
||||
|
||||
def get_guardrail_handler():
|
||||
device = utils.get_device()
|
||||
|
||||
model_class, model_name = None, None
|
||||
if device == "cpu":
|
||||
model_class = OVModelForSequenceClassification
|
||||
model_name = "katanemo/Arch-Guard-cpu"
|
||||
else:
|
||||
model_class = AutoModelForSequenceClassification
|
||||
if device == "cuda":
|
||||
model_name = "katanemo/Arch-Guard"
|
||||
else:
|
||||
model_name = "katanemo/Arch-Guard"
|
||||
|
||||
guardrail_dict = {
|
||||
"device": device,
|
||||
"model_name": model_name,
|
||||
"tokenizer": AutoTokenizer.from_pretrained(model_name, trust_remote_code=True),
|
||||
"model": model_class.from_pretrained(
|
||||
model_name, device_map=device, low_cpu_mem_usage=True
|
||||
),
|
||||
}
|
||||
|
||||
return ArchGuardHanlder(model_dict=guardrail_dict)
|
||||
|
||||
|
||||
# Define the client
|
||||
ARCH_CLIENT = OpenAI(base_url="https://api.fc.archgw.com/v1", api_key="EMPTY")
|
||||
|
||||
|
||||
# Define model handlers
|
||||
handler_map = {
|
||||
"Arch-Intent": ArchIntentHandler(
|
||||
ARCH_CLIENT,
|
||||
ARCH_INTENT_MODEL_ALIAS,
|
||||
ARCH_INTENT_TASK_PROMPT,
|
||||
ARCH_INTENT_TOOL_PROMPT,
|
||||
ARCH_INTENT_FORMAT_PROMPT,
|
||||
ARCH_INTENT_INSTRUCTION,
|
||||
**ARCH_INTENT_GENERATION_CONFIG,
|
||||
),
|
||||
"Arch-Function": ArchFunctionHandler(
|
||||
ARCH_CLIENT,
|
||||
ARCH_FUNCTION_MODEL_ALIAS,
|
||||
ARCH_FUNCTION_TASK_PROMPT,
|
||||
ARCH_FUNCTION_TOOL_PROMPT,
|
||||
ARCH_FUNCTION_FORMAT_PROMPT,
|
||||
**ARCH_FUNCTION_GENERATION_CONFIG,
|
||||
),
|
||||
"Arch-Guard": get_guardrail_handler(),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,11 +1,7 @@
|
|||
import os
|
||||
import yaml
|
||||
import torch
|
||||
import string
|
||||
import logging
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
|
||||
logger_instance = None
|
||||
|
||||
|
|
@ -31,11 +27,6 @@ def get_device():
|
|||
return device
|
||||
|
||||
|
||||
def get_client(endpoint):
|
||||
client = OpenAI(base_url=endpoint, api_key="EMPTY")
|
||||
return client
|
||||
|
||||
|
||||
def get_model_server_logger():
|
||||
global logger_instance
|
||||
|
||||
|
|
@ -72,12 +63,3 @@ def get_model_server_logger():
|
|||
# Initialize the logger instance after configuring handlers
|
||||
logger_instance = logging.getLogger("model_server_logger")
|
||||
return logger_instance
|
||||
|
||||
|
||||
def remove_punctuations(s):
|
||||
s = s.translate(str.maketrans(string.punctuation, " " * len(string.punctuation)))
|
||||
return " ".join(s.split()).lower()
|
||||
|
||||
|
||||
def get_label_map(labels):
|
||||
return {remove_punctuations(label): label for label in labels}
|
||||
|
|
|
|||
|
|
@ -1,137 +0,0 @@
|
|||
import json
|
||||
import random
|
||||
|
||||
from typing import Any, Dict, List
|
||||
|
||||
|
||||
ARCH_FUNCTION_CALLING_TASK_PROMPT = """
|
||||
You are a helpful assistant.
|
||||
""".strip()
|
||||
|
||||
|
||||
ARCH_FUNCTION_CALLING_TOOL_PROMPT = """
|
||||
# Tools
|
||||
|
||||
You may call one or more functions to assist with the user query.
|
||||
|
||||
You are provided with function signatures within <tools></tools> XML tags:
|
||||
<tools>
|
||||
{tool_text}
|
||||
</tools>
|
||||
""".strip()
|
||||
|
||||
|
||||
ARCH_FUNCTION_CALLING_FORMAT_PROMPT = """
|
||||
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
|
||||
<tool_call>
|
||||
{"name": <function-name>, "arguments": <args-json-object>}
|
||||
</tool_call>
|
||||
""".strip()
|
||||
|
||||
|
||||
class ArchFunctionHandler:
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def _format_system(self, tools: List[Dict[str, Any]]):
|
||||
def convert_tools(tools):
|
||||
return "\n".join([json.dumps(tool) for tool in tools])
|
||||
|
||||
tool_text = convert_tools(tools)
|
||||
|
||||
system_prompt = (
|
||||
ARCH_FUNCTION_CALLING_TASK_PROMPT
|
||||
+ "\n\n"
|
||||
+ ARCH_FUNCTION_CALLING_TOOL_PROMPT.format(tool_text=tool_text)
|
||||
+ "\n\n"
|
||||
+ ARCH_FUNCTION_CALLING_FORMAT_PROMPT
|
||||
)
|
||||
|
||||
return system_prompt
|
||||
|
||||
def _add_execution_results_prompting(
|
||||
self,
|
||||
messages: list[dict],
|
||||
execution_results: list,
|
||||
) -> dict:
|
||||
content = []
|
||||
for result in execution_results:
|
||||
content.append(f"<tool_response>\n{json.dumps(result)}\n</tool_response>")
|
||||
|
||||
content = "\n".join(content)
|
||||
messages.append({"role": "user", "content": content})
|
||||
|
||||
return messages
|
||||
|
||||
def extract_tool_calls(self, content: str):
|
||||
tool_calls = []
|
||||
|
||||
flag = False
|
||||
for line in content.split("\n"):
|
||||
if "<tool_call>" == line:
|
||||
flag = True
|
||||
elif "</tool_call>" == line:
|
||||
flag = False
|
||||
else:
|
||||
if flag:
|
||||
try:
|
||||
tool_content = json.loads(line)
|
||||
except Exception:
|
||||
fixed_content = self.fix_json_string(line)
|
||||
try:
|
||||
tool_content = json.loads(fixed_content)
|
||||
except json.JSONDecodeError:
|
||||
return content
|
||||
|
||||
tool_calls.append(
|
||||
{
|
||||
"id": f"call_{random.randint(1000, 10000)}",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool_content["name"],
|
||||
"arguments": tool_content["arguments"],
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
flag = False
|
||||
|
||||
return tool_calls
|
||||
|
||||
def fix_json_string(self, json_str: str):
|
||||
# Remove any leading or trailing whitespace or newline characters
|
||||
json_str = json_str.strip()
|
||||
|
||||
# Stack to keep track of brackets
|
||||
stack = []
|
||||
|
||||
# Clean string to collect valid characters
|
||||
fixed_str = ""
|
||||
|
||||
# Dictionary for matching brackets
|
||||
matching_bracket = {")": "(", "}": "{", "]": "["}
|
||||
|
||||
# Dictionary for the opposite of matching_bracket
|
||||
opening_bracket = {v: k for k, v in matching_bracket.items()}
|
||||
|
||||
for char in json_str:
|
||||
if char in "{[(":
|
||||
stack.append(char)
|
||||
fixed_str += char
|
||||
elif char in "}])":
|
||||
if stack and stack[-1] == matching_bracket[char]:
|
||||
stack.pop()
|
||||
fixed_str += char
|
||||
else:
|
||||
# Ignore the unmatched closing brackets
|
||||
continue
|
||||
else:
|
||||
fixed_str += char
|
||||
|
||||
# If there are unmatched opening brackets left in the stack, add corresponding closing brackets
|
||||
while stack:
|
||||
unmatched_opening = stack.pop()
|
||||
fixed_str += opening_bracket[unmatched_opening]
|
||||
|
||||
# Attempt to parse the corrected string to ensure it’s valid JSON
|
||||
return fixed_str.replace("'", '"')
|
||||
|
|
@ -1,157 +0,0 @@
|
|||
import json
|
||||
import hashlib
|
||||
import app.commons.constants as const
|
||||
import random
|
||||
from fastapi import Response
|
||||
from pydantic import BaseModel
|
||||
from app.commons.utilities import get_model_server_logger
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
||||
logger = get_model_server_logger()
|
||||
|
||||
|
||||
class Message(BaseModel):
|
||||
role: Optional[str] = ""
|
||||
content: Optional[str] = ""
|
||||
tool_calls: Optional[List[Dict[str, Any]]] = []
|
||||
tool_call_id: Optional[str] = ""
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
messages: list[Message]
|
||||
tools: List[Dict[str, Any]]
|
||||
|
||||
|
||||
class Choice(BaseModel):
|
||||
message: Message
|
||||
finish_reason: Optional[str] = "stop"
|
||||
index: Optional[int] = 0
|
||||
|
||||
|
||||
class ChatCompletionResponse(BaseModel):
|
||||
choices: List[Choice]
|
||||
model: Optional[str] = "Arch-Function"
|
||||
created: Optional[str] = ""
|
||||
id: Optional[str] = ""
|
||||
object: Optional[str] = "chat_completion"
|
||||
|
||||
|
||||
def process_messages(history: list[Message]):
|
||||
updated_history = []
|
||||
for hist in history:
|
||||
if hist.tool_calls:
|
||||
if len(hist.tool_calls) > 1:
|
||||
error_msg = f"Only one tool call is supported, tools counts: {len(hist.tool_calls)}"
|
||||
logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
tool_call_str = json.dumps(hist.tool_calls[0]["function"])
|
||||
updated_history.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": f"<tool_call>\n{tool_call_str}\n</tool_call>",
|
||||
}
|
||||
)
|
||||
elif hist.role == "tool":
|
||||
updated_history.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"<tool_response>\n{hist.content}\n</tool_response>",
|
||||
}
|
||||
)
|
||||
else:
|
||||
updated_history.append({"role": hist.role, "content": hist.content})
|
||||
return updated_history
|
||||
|
||||
|
||||
async def chat_completion(req: ChatMessage, res: Response):
|
||||
logger.info("starting request")
|
||||
|
||||
tools_encoded = const.arch_function_hanlder._format_system(req.tools)
|
||||
|
||||
messages = [{"role": "system", "content": tools_encoded}]
|
||||
|
||||
updated_history = process_messages(req.messages)
|
||||
for message in updated_history:
|
||||
messages.append({"role": message["role"], "content": message["content"]})
|
||||
|
||||
client_model_name = const.arch_function_client.models.list().data[0].id
|
||||
|
||||
logger.info(
|
||||
f"model_server => arch_function: {client_model_name}, messages: {json.dumps(messages)}"
|
||||
)
|
||||
|
||||
# Retrieve the first token, handling the Stream object carefully
|
||||
|
||||
try:
|
||||
resp = const.arch_function_client.chat.completions.create(
|
||||
messages=messages,
|
||||
model=client_model_name,
|
||||
stream=const.PREFILL_ENABLED,
|
||||
extra_body=const.arch_function_generation_params,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"model_server <= arch_function: error: {e}")
|
||||
raise
|
||||
|
||||
if const.PREFILL_ENABLED:
|
||||
first_token_content = ""
|
||||
for token in resp:
|
||||
first_token_content = token.choices[
|
||||
0
|
||||
].delta.content.strip() # Clean up the content
|
||||
if first_token_content: # Break if it's non-empty
|
||||
break
|
||||
|
||||
# Check if the first token requires tool call handling
|
||||
if first_token_content != const.TOOL_CALL_TOKEN:
|
||||
# Engage pre-filling response if no tool call is indicated
|
||||
resp.close()
|
||||
logger.info("Tool call is not found! Engage pre filling")
|
||||
prefill_content = random.choice(const.PREFILL_LIST)
|
||||
messages.append({"role": "assistant", "content": prefill_content})
|
||||
|
||||
# Send a new completion request with the updated messages
|
||||
# the model will continue the final message in the chat instead of starting a new one
|
||||
# disable add_generation_prompt which tells the template to add tokens that indicate the start of a bot response.
|
||||
extra_body = {
|
||||
**const.arch_function_generation_params,
|
||||
"continue_final_message": True,
|
||||
"add_generation_prompt": False,
|
||||
}
|
||||
pre_fill_resp = const.arch_function_client.chat.completions.create(
|
||||
messages=messages,
|
||||
model=client_model_name,
|
||||
stream=False,
|
||||
extra_body=extra_body,
|
||||
)
|
||||
full_response = pre_fill_resp.choices[0].message.content
|
||||
else:
|
||||
# Initialize full response and iterate over tokens to gather the full response
|
||||
full_response = first_token_content
|
||||
for token in resp:
|
||||
if hasattr(token.choices[0].delta, "content"):
|
||||
full_response += token.choices[0].delta.content
|
||||
else:
|
||||
logger.info("Stream is disabled, not engaging pre-filling")
|
||||
full_response = resp.choices[0].message.content
|
||||
|
||||
tool_calls = const.arch_function_hanlder.extract_tool_calls(full_response)
|
||||
|
||||
if tool_calls:
|
||||
message = Message(content="", tool_calls=tool_calls)
|
||||
else:
|
||||
message = Message(content=full_response, tool_calls=[])
|
||||
choice = Choice(message=message)
|
||||
chat_completion_response = ChatCompletionResponse(
|
||||
choices=[choice], model=client_model_name
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"model_server <= arch_function: (tools): {json.dumps([tool_call['function'] for tool_call in tool_calls])}"
|
||||
)
|
||||
logger.info(
|
||||
f"model_server <= arch_function: response body: {json.dumps(chat_completion_response.dict())}"
|
||||
)
|
||||
|
||||
return chat_completion_response
|
||||
|
|
@ -1,84 +0,0 @@
|
|||
import os
|
||||
import app.commons.globals as glb
|
||||
|
||||
from transformers import AutoTokenizer, AutoModel, pipeline
|
||||
from optimum.onnxruntime import (
|
||||
ORTModelForFeatureExtraction,
|
||||
ORTModelForSequenceClassification,
|
||||
)
|
||||
import app.commons.utilities as utils
|
||||
import torch
|
||||
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
||||
from optimum.intel import OVModelForSequenceClassification
|
||||
|
||||
|
||||
logger = utils.get_model_server_logger()
|
||||
|
||||
|
||||
def get_embedding_model(
|
||||
model_name=os.getenv("MODELS", "katanemo/bge-large-en-v1.5"),
|
||||
):
|
||||
logger.info("Loading Embedding Model...")
|
||||
|
||||
if glb.DEVICE != "cuda":
|
||||
model = ORTModelForFeatureExtraction.from_pretrained(
|
||||
model_name, file_name="onnx/model.onnx"
|
||||
)
|
||||
else:
|
||||
model = AutoModel.from_pretrained(model_name, device_map=glb.DEVICE)
|
||||
|
||||
embedding_model = {
|
||||
"model_name": model_name,
|
||||
"tokenizer": AutoTokenizer.from_pretrained(model_name, trust_remote_code=True),
|
||||
"model": model,
|
||||
}
|
||||
|
||||
return embedding_model
|
||||
|
||||
|
||||
def get_zero_shot_model(
|
||||
model_name=os.getenv("ZERO_SHOT_MODELS", "katanemo/bart-large-mnli"),
|
||||
):
|
||||
logger.info("Loading Zero-shot Model...")
|
||||
|
||||
if glb.DEVICE != "cuda":
|
||||
model = ORTModelForSequenceClassification.from_pretrained(
|
||||
model_name, file_name="onnx/model.onnx"
|
||||
)
|
||||
else:
|
||||
model = model_name
|
||||
|
||||
zero_shot_model = {
|
||||
"model_name": model_name,
|
||||
"tokenizer": AutoTokenizer.from_pretrained(model_name),
|
||||
"model": model,
|
||||
}
|
||||
|
||||
zero_shot_model["pipeline"] = pipeline(
|
||||
"zero-shot-classification",
|
||||
model=zero_shot_model["model"],
|
||||
tokenizer=zero_shot_model["tokenizer"],
|
||||
device=glb.DEVICE,
|
||||
)
|
||||
|
||||
return zero_shot_model
|
||||
|
||||
|
||||
def get_prompt_guard(model_name):
|
||||
logger.info("Loading Guard Model...")
|
||||
|
||||
if glb.DEVICE == "cpu":
|
||||
model_class = OVModelForSequenceClassification
|
||||
else:
|
||||
model_class = AutoModelForSequenceClassification
|
||||
|
||||
prompt_guard = {
|
||||
"device": glb.DEVICE,
|
||||
"model_name": model_name,
|
||||
"tokenizer": AutoTokenizer.from_pretrained(model_name, trust_remote_code=True),
|
||||
"model": model_class.from_pretrained(
|
||||
model_name, device_map=glb.DEVICE, low_cpu_mem_usage=True
|
||||
),
|
||||
}
|
||||
|
||||
return prompt_guard
|
||||
|
|
@ -1,21 +1,10 @@
|
|||
import os
|
||||
import time
|
||||
import torch
|
||||
import app.commons.utilities as utils
|
||||
import app.commons.globals as glb
|
||||
import app.prompt_guard.model_utils as guard_utils
|
||||
|
||||
from typing import List, Dict
|
||||
from pydantic import BaseModel
|
||||
from fastapi import FastAPI, Response, HTTPException, Request
|
||||
from app.function_calling.model_utils import ChatMessage
|
||||
|
||||
from app.commons.constants import embedding_model, zero_shot_model, arch_guard_handler
|
||||
from app.function_calling.model_utils import (
|
||||
chat_completion as arch_function_chat_completion,
|
||||
)
|
||||
from unittest.mock import patch
|
||||
from app.commons.globals import handler_map
|
||||
from app.model_handler.function_calling import ChatMessage
|
||||
from app.model_handler.guardrails import GuardRequest
|
||||
|
||||
from fastapi import FastAPI, Response, Request
|
||||
from opentelemetry import trace
|
||||
from opentelemetry.sdk.trace import TracerProvider
|
||||
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
|
||||
|
|
@ -23,6 +12,7 @@ from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExport
|
|||
from opentelemetry.sdk.trace.export import BatchSpanProcessor
|
||||
from opentelemetry.sdk.resources import Resource
|
||||
|
||||
|
||||
resource = Resource.create(
|
||||
{
|
||||
"service.name": "model-server",
|
||||
|
|
@ -34,10 +24,6 @@ trace.set_tracer_provider(TracerProvider(resource=resource))
|
|||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
|
||||
logger = utils.get_model_server_logger()
|
||||
|
||||
logger.info(f"Ready to serve traffic. available device: {glb.DEVICE}")
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
FastAPIInstrumentor().instrument_app(app)
|
||||
|
|
@ -53,28 +39,6 @@ otlp_exporter = OTLPSpanExporter(
|
|||
trace.get_tracer_provider().add_span_processor(BatchSpanProcessor(otlp_exporter))
|
||||
|
||||
|
||||
class EmbeddingRequest(BaseModel):
|
||||
input: str
|
||||
model: str
|
||||
|
||||
|
||||
class GuardRequest(BaseModel):
|
||||
input: str
|
||||
task: str
|
||||
|
||||
|
||||
class ZeroShotRequest(BaseModel):
|
||||
input: str
|
||||
labels: List[str]
|
||||
model: str
|
||||
|
||||
|
||||
class HallucinationRequest(BaseModel):
|
||||
prompt: str
|
||||
parameters: Dict
|
||||
model: str
|
||||
|
||||
|
||||
@app.get("/healthz")
|
||||
async def healthz():
|
||||
return {"status": "ok"}
|
||||
|
|
@ -84,172 +48,40 @@ async def healthz():
|
|||
async def models():
|
||||
return {
|
||||
"object": "list",
|
||||
"data": [{"id": embedding_model["model_name"], "object": "model"}],
|
||||
"data": [{"id": model_name, "object": "model"} for model_name in handler_map],
|
||||
}
|
||||
|
||||
|
||||
@app.post("/embeddings")
|
||||
async def embedding(req: EmbeddingRequest, res: Response):
|
||||
logger.info(f"Embedding req: {req}")
|
||||
|
||||
if req.model != embedding_model["model_name"]:
|
||||
raise HTTPException(status_code=400, detail="unknown model: " + req.model)
|
||||
|
||||
start_time = time.perf_counter()
|
||||
|
||||
encoded_input = embedding_model["tokenizer"](
|
||||
req.input, padding=True, truncation=True, return_tensors="pt"
|
||||
).to(glb.DEVICE)
|
||||
|
||||
with torch.no_grad():
|
||||
embeddings = embedding_model["model"](**encoded_input)
|
||||
embeddings = embeddings[0][:, 0]
|
||||
embeddings = (
|
||||
torch.nn.functional.normalize(embeddings, p=2, dim=1).detach().cpu().numpy()
|
||||
)
|
||||
|
||||
logger.info(f"Embedding Call Complete Time: {time.perf_counter()-start_time}")
|
||||
|
||||
data = [
|
||||
{"object": "embedding", "embedding": embedding, "index": index + 1}
|
||||
for index, embedding in enumerate(embeddings.tolist())
|
||||
]
|
||||
|
||||
usage = {
|
||||
"prompt_tokens": 0,
|
||||
"total_tokens": 0,
|
||||
}
|
||||
|
||||
return {"data": data, "model": req.model, "object": "list", "usage": usage}
|
||||
|
||||
|
||||
@app.post("/guard")
|
||||
async def guard(req: GuardRequest, res: Response, max_num_words=300):
|
||||
"""
|
||||
Take input as text and return the prediction of toxic and jailbreak
|
||||
"""
|
||||
|
||||
if req.task in ["both", "toxic", "jailbreak"]:
|
||||
arch_guard_handler.task = req.task
|
||||
else:
|
||||
raise NotImplementedError(f"{req.task} is not supported!")
|
||||
|
||||
start_time = time.perf_counter()
|
||||
|
||||
if len(req.input.split()) < max_num_words:
|
||||
guard_result = arch_guard_handler.guard_predict(req.input)
|
||||
else:
|
||||
# text is long, split into chunks
|
||||
chunks = guard_utils.split_text_into_chunks(req.input)
|
||||
|
||||
guard_result = {
|
||||
"jailbreak_prob": [],
|
||||
"time": 0,
|
||||
"jailbreak_verdict": False,
|
||||
"toxic_sentence": [],
|
||||
"jailbreak_sentence": [],
|
||||
}
|
||||
|
||||
for chunk in chunks:
|
||||
chunk_result = arch_guard_handler.guard_predict(chunk)
|
||||
guard_result["time"] += chunk_result["time"]
|
||||
if chunk_result[f"{arch_guard_handler.task}_verdict"]:
|
||||
guard_result[f"{arch_guard_handler.task}_verdict"] = True
|
||||
guard_result[f"{arch_guard_handler.task}_sentence"].append(
|
||||
chunk_result[f"{arch_guard_handler.task}_sentence"]
|
||||
)
|
||||
guard_result[f"{arch_guard_handler.task}_prob"].append(
|
||||
chunk_result[f"{arch_guard_handler.task}_prob"].item()
|
||||
)
|
||||
|
||||
logger.info(f"Time taken for Guard: {time.perf_counter() - start_time}")
|
||||
|
||||
return guard_result
|
||||
|
||||
|
||||
@app.post("/zeroshot")
|
||||
async def zeroshot(req: ZeroShotRequest, res: Response):
|
||||
logger.info(f"zero-shot request: {req}")
|
||||
|
||||
if req.model != zero_shot_model["model_name"]:
|
||||
raise HTTPException(status_code=400, detail="unknown model: " + req.model)
|
||||
|
||||
classifier = zero_shot_model["pipeline"]
|
||||
|
||||
label_map = utils.get_label_map(req.labels)
|
||||
|
||||
start_time = time.perf_counter()
|
||||
|
||||
predictions = classifier(
|
||||
req.input, candidate_labels=list(label_map.keys()), multi_label=True
|
||||
)
|
||||
|
||||
logger.info(f"zero-shot taking {time.perf_counter() - start_time} seconds")
|
||||
|
||||
predicted_class = label_map[predictions["labels"][0]]
|
||||
predicted_score = predictions["scores"][0]
|
||||
|
||||
scores = {
|
||||
label_map[label]: score
|
||||
for label, score in zip(predictions["labels"], predictions["scores"])
|
||||
}
|
||||
|
||||
predicted_class = label_map[predictions["labels"][0]]
|
||||
|
||||
return {
|
||||
"predicted_class": predicted_class,
|
||||
"predicted_class_score": predicted_score,
|
||||
"scores": scores,
|
||||
"model": req.model,
|
||||
}
|
||||
|
||||
|
||||
@app.post("/hallucination")
|
||||
@patch("app.loader.glb.DEVICE", "cpu") # Mock the device to 'cpu'
|
||||
async def hallucination(req: HallucinationRequest, res: Response):
|
||||
"""
|
||||
Take input as text and return the prediction of hallucination for each parameter
|
||||
"""
|
||||
logger.info(f"hallucination request: {req}")
|
||||
if req.model != zero_shot_model["model_name"]:
|
||||
raise HTTPException(status_code=400, detail="unknown model: " + req.model)
|
||||
|
||||
start_time = time.perf_counter()
|
||||
classifier = zero_shot_model["pipeline"]
|
||||
|
||||
if "messages" in req.parameters:
|
||||
req.parameters.pop("messages")
|
||||
|
||||
candidate_labels = {f"{k} is {v}": k for k, v in req.parameters.items()}
|
||||
|
||||
predictions = classifier(
|
||||
req.prompt,
|
||||
candidate_labels=list(candidate_labels.keys()),
|
||||
hypothesis_template="{}",
|
||||
multi_label=True,
|
||||
)
|
||||
|
||||
params_scores = {
|
||||
candidate_labels[label]: score
|
||||
for label, score in zip(predictions["labels"], predictions["scores"])
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"hallucination time cost: {params_scores}, taking {time.perf_counter() - start_time} seconds"
|
||||
)
|
||||
|
||||
return {
|
||||
"params_scores": params_scores,
|
||||
"model": req.model,
|
||||
}
|
||||
|
||||
|
||||
@app.post("/v1/chat/completions")
|
||||
async def chat_completion(req: ChatMessage, res: Response, request: Request):
|
||||
@app.post("/function_calling")
|
||||
async def function_calling(req: ChatMessage, res: Response, request: Request):
|
||||
try:
|
||||
result = await arch_function_chat_completion(req, res)
|
||||
return result
|
||||
intent_result = await handler_map["Arch-Intent"].chat_completion(req)
|
||||
|
||||
if intent_result.choices[0].message.content == "Yes":
|
||||
try:
|
||||
function_result = await handler_map["Arch-Function"].chat_completion(
|
||||
req
|
||||
)
|
||||
return function_result
|
||||
except Exception as e:
|
||||
# [TODO]
|
||||
# logger.error(f"Error in chat_completion from `Arch-Function`: {e}")
|
||||
res.status_code = 500
|
||||
return {"error": f"[Arch-Function] - {e}"}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in chat_completion: {e}")
|
||||
# [TODO]
|
||||
# logger.error(f"Error in chat_completion from `Arch-Intent`: {e}")
|
||||
res.status_code = 500
|
||||
return {"error": "Internal server error"}
|
||||
return {"error": f"[Arch-Intent] - {e}"}
|
||||
|
||||
|
||||
@app.post("/guardrails")
|
||||
async def guardrails(req: GuardRequest, res: Response, max_num_words=300):
|
||||
try:
|
||||
guard_result = handler_map["Arch-Guard"].predict(req)
|
||||
return guard_result
|
||||
except Exception as e:
|
||||
# [TODO]
|
||||
res.status_code = 500
|
||||
return {"error": f"[Arch-Guard] - {e}"}
|
||||
|
|
|
|||
415
model_server/app/model_handler/function_calling.py
Normal file
415
model_server/app/model_handler/function_calling.py
Normal file
|
|
@ -0,0 +1,415 @@
|
|||
import json
|
||||
import random
|
||||
import builtins
|
||||
|
||||
from openai import OpenAI
|
||||
from pydantic import BaseModel
|
||||
from typing import Any, Dict, List, Optional
|
||||
from overrides import override, final
|
||||
|
||||
|
||||
SUPPORT_DATA_TYPES = ["int", "float", "bool", "str", "list", "tuple", "set", "dict"]
|
||||
|
||||
|
||||
class Message(BaseModel):
|
||||
role: Optional[str] = ""
|
||||
content: Optional[str] = ""
|
||||
tool_call_id: Optional[str] = ""
|
||||
tool_calls: Optional[List[Dict[str, Any]]] = []
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
messages: list[Message]
|
||||
tools: List[Dict[str, Any]]
|
||||
|
||||
|
||||
class Choice(BaseModel):
|
||||
id: Optional[int] = 0
|
||||
message: Message
|
||||
finish_reason: Optional[str] = "stop"
|
||||
|
||||
|
||||
class ChatCompletionResponse(BaseModel):
|
||||
id: Optional[int] = 0
|
||||
object: Optional[str] = "chat_completion"
|
||||
created: Optional[str] = ""
|
||||
model: str
|
||||
choices: List[Choice]
|
||||
|
||||
|
||||
class ArchBaseHandler:
|
||||
def __init__(
|
||||
self,
|
||||
client: OpenAI,
|
||||
model_name: str,
|
||||
task_prompt: str,
|
||||
tool_prompt: str,
|
||||
format_prompt: str,
|
||||
generation_params: Dict,
|
||||
):
|
||||
self.client = client
|
||||
|
||||
self.model_name = model_name
|
||||
|
||||
self.task_prompt = task_prompt
|
||||
self.tool_prompt = tool_prompt
|
||||
self.format_prompt = format_prompt
|
||||
|
||||
self.generation_params = generation_params
|
||||
|
||||
def _convert_tools(self, tools: List[Dict[str, Any]]):
|
||||
raise NotImplementedError()
|
||||
|
||||
@final
|
||||
def _format_system(self, tools: List[Dict[str, Any]]):
|
||||
tool_text = self._convert_tools(tools)
|
||||
|
||||
system_prompt = (
|
||||
self.task_prompt
|
||||
+ "\n\n"
|
||||
+ self.tool_prompt.format(tool_text=tool_text)
|
||||
+ "\n\n"
|
||||
+ self.format_prompt
|
||||
)
|
||||
|
||||
return system_prompt
|
||||
|
||||
@final
|
||||
def _process_messages(
|
||||
self,
|
||||
messages: List[Message],
|
||||
tools: List[Dict[str, Any]] = None,
|
||||
extra_instructions: str = None,
|
||||
):
|
||||
processed_messages = []
|
||||
|
||||
if tools:
|
||||
processed_messages.append(
|
||||
{"role": "system", "content": self._format_system(tools)}
|
||||
)
|
||||
|
||||
for message in messages:
|
||||
role, content, tool_calls = (
|
||||
message.role,
|
||||
message.content,
|
||||
message.tool_calls,
|
||||
)
|
||||
|
||||
if tool_calls:
|
||||
# [TODO] Extend to support multiple function calls
|
||||
role = "assistant"
|
||||
content = f"<tool_call>\n{json.dumps(tool_calls[0]['function'])}\n</tool_call>"
|
||||
elif message.role == "tool":
|
||||
role = "user"
|
||||
content = (
|
||||
f"<tool_response>\n{json.dumps(message.content)}\n</tool_response>"
|
||||
)
|
||||
|
||||
processed_messages.append({"role": role, "content": content})
|
||||
|
||||
assert processed_messages[-1]["role"] == "user"
|
||||
|
||||
if extra_instructions:
|
||||
processed_messages[-1]["content"] += extra_instructions
|
||||
|
||||
return processed_messages
|
||||
|
||||
async def chat_completion(self, req: ChatMessage):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class ArchIntentHandler(ArchBaseHandler):
|
||||
def __init__(
|
||||
self,
|
||||
client: OpenAI,
|
||||
model_name: str,
|
||||
task_prompt: str,
|
||||
tool_prompt: str,
|
||||
format_prompt: str,
|
||||
intent_instruction: str,
|
||||
generation_params: Dict,
|
||||
):
|
||||
super().__init__(
|
||||
client,
|
||||
model_name,
|
||||
task_prompt,
|
||||
tool_prompt,
|
||||
format_prompt,
|
||||
generation_params,
|
||||
)
|
||||
|
||||
self.intent_instruction = intent_instruction
|
||||
|
||||
@override
|
||||
def _convert_tools(self, tools: List[Dict[str, Any]]):
|
||||
converted = [
|
||||
json.dumps({"index": f"T{idx}"} | tool) for idx, tool in enumerate(tools)
|
||||
]
|
||||
return "\n".join(converted)
|
||||
|
||||
@override
|
||||
async def chat_completion(self, req: ChatMessage):
|
||||
"""
|
||||
Note: Currently only support vllm inference
|
||||
"""
|
||||
|
||||
messages = self._process_messages(
|
||||
req.messages, req.tools, self.intent_instruction
|
||||
)
|
||||
|
||||
model_response = self.client.chat.completions.create(
|
||||
messages=messages,
|
||||
model=self.model_name,
|
||||
stream=False,
|
||||
extra_body=self.generation_params,
|
||||
)
|
||||
|
||||
model_response = Message(content=model_response, tool_calls=[])
|
||||
|
||||
chat_completion_response = ChatCompletionResponse(
|
||||
choices=[Choice(message=model_response)], model=self.model_name
|
||||
)
|
||||
|
||||
return chat_completion_response
|
||||
|
||||
|
||||
class ArchFunctionHandler(ArchBaseHandler):
|
||||
def __init__(
|
||||
self,
|
||||
client: OpenAI,
|
||||
model_name: str,
|
||||
task_prompt: str,
|
||||
tool_prompt: str,
|
||||
format_prompt: str,
|
||||
generation_params: Dict,
|
||||
prefill_params: Dict,
|
||||
prefill_prefix: List,
|
||||
):
|
||||
super().__init__(
|
||||
client,
|
||||
model_name,
|
||||
task_prompt,
|
||||
tool_prompt,
|
||||
format_prompt,
|
||||
generation_params,
|
||||
)
|
||||
|
||||
self.prefill_params = prefill_params
|
||||
self.prefill_prefix = prefill_prefix
|
||||
|
||||
# Predefine data types for verification. Only support Python for now.
|
||||
# [TODO] Extend the list of support data types
|
||||
self.support_data_types = {
|
||||
type_name: getattr(builtins, type_name) for type_name in SUPPORT_DATA_TYPES
|
||||
}
|
||||
|
||||
@override
|
||||
def _convert_tools(self, tools: List[Dict[str, Any]]):
|
||||
converted = [json.dumps(tool) for tool in tools]
|
||||
return "\n".join(converted)
|
||||
|
||||
def _fix_json_string(self, json_str: str):
|
||||
# Remove any leading or trailing whitespace or newline characters
|
||||
json_str = json_str.strip()
|
||||
|
||||
# Stack to keep track of brackets
|
||||
stack = []
|
||||
|
||||
# Clean string to collect valid characters
|
||||
fixed_str = ""
|
||||
|
||||
# Dictionary for matching brackets
|
||||
matching_bracket = {")": "(", "}": "{", "]": "["}
|
||||
|
||||
# Dictionary for the opposite of matching_bracket
|
||||
opening_bracket = {v: k for k, v in matching_bracket.items()}
|
||||
|
||||
for char in json_str:
|
||||
if char in "{[(":
|
||||
stack.append(char)
|
||||
fixed_str += char
|
||||
elif char in "}])":
|
||||
if stack and stack[-1] == matching_bracket[char]:
|
||||
stack.pop()
|
||||
fixed_str += char
|
||||
else:
|
||||
# Ignore the unmatched closing brackets
|
||||
continue
|
||||
else:
|
||||
fixed_str += char
|
||||
|
||||
# If there are unmatched opening brackets left in the stack, add corresponding closing brackets
|
||||
while stack:
|
||||
unmatched_opening = stack.pop()
|
||||
fixed_str += opening_bracket[unmatched_opening]
|
||||
|
||||
# Attempt to parse the corrected string to ensure it’s valid JSON
|
||||
return fixed_str.replace("'", '"')
|
||||
|
||||
def _extract_tool_calls(self, content: str):
|
||||
tool_calls, is_valid, error_message = [], True, ""
|
||||
|
||||
flag = False
|
||||
for line in content.split("\n"):
|
||||
if "<tool_call>" == line:
|
||||
flag = True
|
||||
elif "</tool_call>" == line:
|
||||
flag = False
|
||||
else:
|
||||
if flag:
|
||||
try:
|
||||
tool_content = json.loads(line)
|
||||
except Exception as e:
|
||||
fixed_content = self._fix_json_string(line)
|
||||
try:
|
||||
tool_content = json.loads(fixed_content)
|
||||
except Exception:
|
||||
tool_calls, is_valid, error_message = [], False, e
|
||||
return tool_calls, is_valid, error_message
|
||||
|
||||
tool_calls.append(
|
||||
{
|
||||
"id": f"call_{random.randint(1000, 10000)}",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool_content["name"],
|
||||
"arguments": tool_content["arguments"],
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
flag = False
|
||||
|
||||
return tool_calls, is_valid, error_message
|
||||
|
||||
def _verify_tool_calls(
|
||||
self, tools: List[Dict[str, Any]], tool_calls: List[Dict[str, Any]]
|
||||
):
|
||||
is_valid, error_tool_call, error_message = True, None, ""
|
||||
|
||||
functions = {}
|
||||
for tool in tools:
|
||||
if tool["type"] == "function":
|
||||
functions[tool["function"]["name"]] = tool["function"]["parameters"]
|
||||
|
||||
for tool_call in tool_calls:
|
||||
func_name, func_args = (
|
||||
tool_call["function"]["name"],
|
||||
tool_call["function"]["arguments"],
|
||||
)
|
||||
|
||||
# Check whether the function is available or not
|
||||
if func_name not in functions:
|
||||
is_valid = False
|
||||
error_message = f"{func_name} is not defined!"
|
||||
return is_valid, error_message
|
||||
else:
|
||||
# Check if all the requried parameters can be found in the tool calls
|
||||
for required_param in functions[func_name].get("required", []):
|
||||
if required_param not in func_args:
|
||||
is_valid = False
|
||||
error_tool_call = tool_call
|
||||
error_message = f"`{required_param}` is requried by the function `{func_name}` but not found in the tool call!"
|
||||
return is_valid, error_tool_call, error_message
|
||||
|
||||
# Verify the data type of each parameter in the tool calls
|
||||
for param_name, param_value in func_args:
|
||||
data_type = functions[func_name]["properties"][param_name]["type"]
|
||||
|
||||
if data_type in self.support_data_types:
|
||||
if not isinstance(
|
||||
param_value, self.support_data_types[data_type]
|
||||
):
|
||||
is_valid = False
|
||||
error_tool_call = tool_call
|
||||
error_message = f"Parameter `{param_name}` is expected to have the data type `{self.support_data_types[data_type]}`, but got `{type(param_value)}`."
|
||||
return is_valid, error_tool_call, error_message
|
||||
|
||||
return is_valid, error_tool_call, error_message
|
||||
|
||||
@override
|
||||
async def chat_completion(self, req: ChatMessage, enable_prefilling=True):
|
||||
"""
|
||||
Note: Currently only support vllm inference
|
||||
"""
|
||||
|
||||
messages = self._process_messages(req.messages, req.tools)
|
||||
|
||||
# Retrieve the first token, handling the Stream object carefully
|
||||
response = self.client.chat.completions.create(
|
||||
messages=messages,
|
||||
model=self.model_name,
|
||||
stream=enable_prefilling,
|
||||
extra_body=self.generation_params,
|
||||
)
|
||||
|
||||
model_response = ""
|
||||
|
||||
if enable_prefilling:
|
||||
has_tool_call = None
|
||||
|
||||
model_response = ""
|
||||
for token in response:
|
||||
token_content = token.choices[0].delta.content.strip()
|
||||
|
||||
if has_tool_call is None and token_content != "<tool_call>":
|
||||
has_tool_call = False
|
||||
response.close()
|
||||
break
|
||||
else:
|
||||
has_tool_call = True
|
||||
|
||||
if has_tool_call is True:
|
||||
model_response += token_content
|
||||
|
||||
# start parameter gathering if the model is not generating a tool call
|
||||
if has_tool_call is False:
|
||||
messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": random.choice(self.prefill_prefix),
|
||||
}
|
||||
)
|
||||
|
||||
prefill_response = self.client.chat.completions.create(
|
||||
messages=messages,
|
||||
model=self.model_name,
|
||||
stream=False,
|
||||
extra_body={
|
||||
**self.generation_params,
|
||||
**self.prefill_params,
|
||||
},
|
||||
)
|
||||
|
||||
model_response = prefill_response.choices[0].message.content
|
||||
else:
|
||||
model_response = response.choices[0].message.content
|
||||
|
||||
tool_calls, is_valid, error_message = self._extract_tool_calls(model_response)
|
||||
|
||||
if tool_calls:
|
||||
is_valid, error_tool_call, error_message = self._verify_tool_calls(
|
||||
tools=req.tools, tool_calls=tool_calls
|
||||
)
|
||||
|
||||
# [TODO] Review: In the case that tool calls are invalid, define the protocol to collect debugging output and the behavior to handle it appropriately
|
||||
if is_valid:
|
||||
model_response = Message(content="", tool_calls=tool_calls)
|
||||
# else:
|
||||
|
||||
else:
|
||||
model_response = Message(content=model_response, tool_calls=[])
|
||||
|
||||
chat_completion_response = ChatCompletionResponse(
|
||||
choices=[Choice(message=model_response)], model=self.model_name
|
||||
)
|
||||
|
||||
# [TODO] Review: define the protocol to collect debugging output
|
||||
# logger.info(
|
||||
# f"model_server <= arch_function: (tool_calls): {json.dumps([tool_call['function'] for tool_call in tool_calls])}"
|
||||
# )
|
||||
# logger.info(
|
||||
# f"model_server <= arch_function: response body: {json.dumps(chat_completion_response.dict())}"
|
||||
# )
|
||||
|
||||
return chat_completion_response
|
||||
95
model_server/app/model_handler/guardrails.py
Normal file
95
model_server/app/model_handler/guardrails.py
Normal file
|
|
@ -0,0 +1,95 @@
|
|||
import time
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class GuardRequest(BaseModel):
|
||||
input: str
|
||||
task: str
|
||||
|
||||
|
||||
class ArchGuardHanlder:
|
||||
def __init__(self, model_dict):
|
||||
self.model = model_dict["model"]
|
||||
self.tokenizer = model_dict["tokenizer"]
|
||||
self.device = model_dict["device"]
|
||||
|
||||
self.support_tasks = {"jailbreak": {"positive_class": 2, "threshold": 0.5}}
|
||||
|
||||
def _split_text_into_chunks(self, text, max_num_words=300):
|
||||
"""
|
||||
Split the text into chunks of `max_num_words` words
|
||||
"""
|
||||
words = text.split() # Split text into words
|
||||
|
||||
chunks = [
|
||||
" ".join(words[i : i + max_num_words])
|
||||
for i in range(0, len(words), max_num_words)
|
||||
]
|
||||
|
||||
return chunks
|
||||
|
||||
@staticmethod
|
||||
def softmax(x):
|
||||
return np.exp(x) / np.exp(x).sum(axis=0)
|
||||
|
||||
def _predict_text(self, task, text, max_length=512):
|
||||
inputs = self.tokenizer(
|
||||
text, truncation=True, max_length=max_length, return_tensors="pt"
|
||||
).to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
logits = self.model(**inputs).logits.cpu().detach().numpy()[0]
|
||||
prob = ArchGuardHanlder.softmax(logits)[
|
||||
self.support_tasks[task]["positive_class"]
|
||||
]
|
||||
|
||||
if prob > self.support_tasks[task]["threshold"]:
|
||||
verdict = True
|
||||
sentence = text
|
||||
else:
|
||||
verdict = False
|
||||
sentence = None
|
||||
|
||||
result_dict = {
|
||||
"prob": prob.item(),
|
||||
"verdict": verdict,
|
||||
"sentence": sentence,
|
||||
}
|
||||
|
||||
return result_dict
|
||||
|
||||
def predict(self, req: GuardRequest, max_num_words=300):
|
||||
"""
|
||||
Note: currently only support jailbreak check
|
||||
"""
|
||||
|
||||
if req.task not in self.support_tasks:
|
||||
raise NotImplementedError(f"{req.task} is not supported!")
|
||||
|
||||
guard_result = {
|
||||
"prob": [],
|
||||
"verdict": False,
|
||||
"sentence": [],
|
||||
}
|
||||
|
||||
start_time = time.perf_counter()
|
||||
|
||||
if len(req.input.split()) < max_num_words:
|
||||
guard_result = self._predict_text(req.task, req.input)
|
||||
else:
|
||||
# split into chunks if text is long
|
||||
text_chunks = self._split_text_into_chunks(req.input)
|
||||
|
||||
for chunk in text_chunks:
|
||||
chunk_result = self._predict_text(req.task, chunk)
|
||||
if chunk_result["verdict"]:
|
||||
guard_result["verdict"] = True
|
||||
guard_result["sentence"].append(chunk_result["sentence"])
|
||||
guard_result["prob"].append(chunk_result["prob"].item())
|
||||
|
||||
guard_result["latency"] = time.perf_counter() - start_time
|
||||
|
||||
return guard_result
|
||||
|
|
@ -1,8 +1,6 @@
|
|||
import json
|
||||
import math
|
||||
import torch
|
||||
import random
|
||||
from typing import Any, Dict, List, Tuple
|
||||
from typing import Dict, List, Tuple
|
||||
import itertools
|
||||
from enum import Enum
|
||||
|
||||
|
|
@ -74,26 +72,6 @@ def calculate_entropy(log_probs: List[float]) -> Tuple[float, float]:
|
|||
return entropy.item(), varentropy.item()
|
||||
|
||||
|
||||
def is_parameter_property(
|
||||
function_description: Dict, parameter_name: str, property_name: str
|
||||
) -> bool:
|
||||
"""
|
||||
Check if a parameter in an API description has a specific property.
|
||||
|
||||
Args:
|
||||
function_description (dict): The API description in JSON format.
|
||||
parameter_name (str): The name of the parameter to check.
|
||||
property_name (str): The property to look for (e.g., 'format', 'default').
|
||||
|
||||
Returns:
|
||||
bool: True if the parameter has the specified property, False otherwise.
|
||||
"""
|
||||
parameters = function_description.get("properties", {})
|
||||
parameter_info = parameters.get(parameter_name, {})
|
||||
|
||||
return property_name in parameter_info
|
||||
|
||||
|
||||
class HallucinationStateHandler:
|
||||
"""
|
||||
A class to handle the state of hallucination detection in token processing.
|
||||
|
|
@ -111,7 +89,7 @@ class HallucinationStateHandler:
|
|||
token_probs_map (list): List mapping tokens to their entropy and variance of entropy.
|
||||
"""
|
||||
|
||||
def __init__(self, response_iterator=None, function=None):
|
||||
def __init__(self, response_iterator=None):
|
||||
"""
|
||||
Initializes the HallucinationStateHandler with default values.
|
||||
"""
|
||||
|
|
@ -126,19 +104,6 @@ class HallucinationStateHandler:
|
|||
self.parameter_name: List[str] = []
|
||||
self.token_probs_map: List[Tuple[str, float, float]] = []
|
||||
self.response_iterator = response_iterator
|
||||
self._process_function(function)
|
||||
|
||||
def _process_function(self, function):
|
||||
self.function = function
|
||||
if self.function is None:
|
||||
raise ValueError("API descriptions not set.")
|
||||
parameter_names = {}
|
||||
for func in self.function:
|
||||
func_name = func["name"]
|
||||
parameters = func["parameters"]["properties"]
|
||||
parameter_names[func_name] = list(parameters.keys())
|
||||
self.function_description = parameter_names
|
||||
self.function_properties = {x["name"]: x["parameters"] for x in self.function}
|
||||
|
||||
def append_and_check_token_hallucination(self, token, logprob):
|
||||
"""
|
||||
|
|
@ -237,6 +202,8 @@ class HallucinationStateHandler:
|
|||
# checking if the token is a value token and is not empty
|
||||
if self.tokens[-1].strip() not in ['"', ""]:
|
||||
self.mask.append(MaskToken.PARAMETER_VALUE)
|
||||
|
||||
# [TODO] Review: update the following code: `is_parameter_property` should not be here
|
||||
# checking if the parameter doesn't have default and the token is the first parameter value token
|
||||
if (
|
||||
len(self.mask) > 1
|
||||
|
|
@ -299,26 +266,3 @@ class HallucinationStateHandler:
|
|||
if self.mask and self.mask[-1] == token
|
||||
else 0
|
||||
)
|
||||
|
||||
def _is_function_name_hallucinated(self):
|
||||
"""
|
||||
Checks the extracted function name against the function descriptions.
|
||||
Detects hallucinations if the function name is not found.
|
||||
"""
|
||||
f_len = self._count_consecutive_token(MaskToken.FUNCTION_NAME)
|
||||
self.function_name = "".join(self.tokens[:-1][-f_len:])
|
||||
if self.function_name not in self.function_description.keys():
|
||||
self.error_type = "function_name"
|
||||
self.error_message = f"Function name '{self.function_name}' not found in given function descriptions."
|
||||
|
||||
def _is_parameter_name_hallucinated(self):
|
||||
"""
|
||||
Checks the extracted parameter name against the function descriptions.
|
||||
Detects hallucinations if the parameter name is not found.
|
||||
"""
|
||||
p_len = self._count_consecutive_token(MaskToken.PARAMETER_NAME)
|
||||
parameter_name = "".join(self.tokens[:-1][-p_len:])
|
||||
self.parameter_name.append(parameter_name)
|
||||
if parameter_name not in self.function_description[self.function_name]:
|
||||
self.error_type = "parameter_name"
|
||||
self.error_message = f"Parameter name '{parameter_name}' not found in given function descriptions."
|
||||
|
|
@ -1,42 +0,0 @@
|
|||
import time
|
||||
import torch
|
||||
import app.prompt_guard.model_utils as model_utils
|
||||
|
||||
|
||||
class ArchGuardHanlder:
|
||||
def __init__(self, model_dict, threshold=0.5):
|
||||
self.task = "jailbreak"
|
||||
self.positive_class = 2
|
||||
|
||||
self.model = model_dict["model"]
|
||||
self.tokenizer = model_dict["tokenizer"]
|
||||
self.device = model_dict["device"]
|
||||
|
||||
self.threshold = threshold
|
||||
|
||||
def guard_predict(self, input_text, max_length=512):
|
||||
start_time = time.perf_counter()
|
||||
|
||||
inputs = self.tokenizer(
|
||||
input_text, truncation=True, max_length=max_length, return_tensors="pt"
|
||||
).to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
logits = self.model(**inputs).logits.cpu().detach().numpy()[0]
|
||||
prob = model_utils.softmax(logits)[self.positive_class]
|
||||
|
||||
if prob > self.threshold:
|
||||
verdict = True
|
||||
sentence = input_text
|
||||
else:
|
||||
verdict = False
|
||||
sentence = None
|
||||
|
||||
result_dict = {
|
||||
f"{self.task}_prob": prob.item(),
|
||||
f"{self.task}_verdict": verdict,
|
||||
f"{self.task}_sentence": sentence,
|
||||
"time": time.perf_counter() - start_time,
|
||||
}
|
||||
|
||||
return result_dict
|
||||
|
|
@ -1,19 +0,0 @@
|
|||
import numpy as np
|
||||
|
||||
|
||||
def split_text_into_chunks(text, max_words=300):
|
||||
"""
|
||||
Max number of tokens for tokenizer is 512
|
||||
Split the text into chunks of 300 words (as approximation for tokens)
|
||||
"""
|
||||
words = text.split() # Split text into words
|
||||
# Estimate token count based on word count (1 word ≈ 1 token)
|
||||
chunk_size = max_words # Use the word count as an approximation for tokens
|
||||
chunks = [
|
||||
" ".join(words[i : i + chunk_size]) for i in range(0, len(words), chunk_size)
|
||||
]
|
||||
return chunks
|
||||
|
||||
|
||||
def softmax(x):
|
||||
return np.exp(x) / np.exp(x).sum(axis=0)
|
||||
|
|
@ -13,6 +13,7 @@ client = TestClient(app)
|
|||
logger.info(f"Model will be loaded on device: {glb.DEVICE}")
|
||||
|
||||
|
||||
# [TODO] Review: update the following code
|
||||
# Unit tests for the health check endpoint
|
||||
@pytest.mark.asyncio
|
||||
@patch("app.loader.glb.DEVICE", glb.DEVICE) # Mock the device to 'cpu'
|
||||
|
|
@ -22,6 +23,7 @@ async def test_healthz():
|
|||
assert response.json() == {"status": "ok"}
|
||||
|
||||
|
||||
# [TODO] Review: update the following code
|
||||
# Unit test for the models endpoint
|
||||
@pytest.mark.asyncio
|
||||
@patch("app.loader.glb.DEVICE", glb.DEVICE) # Mock the device to 'cpu'
|
||||
|
|
@ -32,6 +34,7 @@ async def test_models():
|
|||
assert len(response.json()["data"]) > 0
|
||||
|
||||
|
||||
# [TODO] Review: update the following code
|
||||
# Unit test for embeddings endpoint
|
||||
@pytest.mark.asyncio
|
||||
@patch("app.loader.glb.DEVICE", glb.DEVICE) # Mock the device to 'cpu'
|
||||
|
|
@ -46,6 +49,7 @@ async def test_embedding():
|
|||
assert response.status_code == 400
|
||||
|
||||
|
||||
# [TODO] Review: update the following code
|
||||
# Unit test for the guard endpoint
|
||||
@pytest.mark.asyncio
|
||||
@patch("app.loader.glb.DEVICE", glb.DEVICE) # Mock the device to 'cpu'
|
||||
|
|
@ -56,6 +60,7 @@ async def test_guard():
|
|||
assert "jailbreak_verdict" in response.json()
|
||||
|
||||
|
||||
# [TODO] Review: update the following code
|
||||
# Unit test for the zero-shot endpoint
|
||||
@pytest.mark.asyncio
|
||||
@patch("app.loader.glb.DEVICE", glb.DEVICE) # Mock the device to 'cpu'
|
||||
|
|
@ -73,6 +78,7 @@ async def test_zeroshot():
|
|||
assert response.status_code == 400
|
||||
|
||||
|
||||
# [TODO] Review: update the following code
|
||||
# Unit test for the hallucination endpoint
|
||||
@pytest.mark.asyncio
|
||||
@patch("app.loader.glb.DEVICE", glb.DEVICE) # Mock the device to 'cpu'
|
||||
|
|
@ -90,6 +96,7 @@ async def test_hallucination():
|
|||
assert response.status_code == 400
|
||||
|
||||
|
||||
# [TODO] Review: update the following code
|
||||
# Unit test for the chat completion endpoint
|
||||
@pytest.mark.asyncio
|
||||
@patch("app.loader.glb.DEVICE", glb.DEVICE) # Mock the device to 'cpu'
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -1,7 +1,6 @@
|
|||
import unittest
|
||||
|
||||
from unittest.mock import patch, MagicMock
|
||||
import subprocess
|
||||
import time
|
||||
from app.cli import kill_process
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,13 +1,12 @@
|
|||
import json
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import app.commons.constants as const
|
||||
|
||||
from fastapi import Response
|
||||
from app.function_calling.model_utils import (
|
||||
process_messages,
|
||||
chat_completion,
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from app.commons.globals import handler_map
|
||||
from app.model_handler.function_calling import (
|
||||
Message,
|
||||
ChatMessage,
|
||||
Choice,
|
||||
ChatCompletionResponse,
|
||||
)
|
||||
|
||||
|
|
@ -31,14 +30,27 @@ def sample_messages():
|
|||
def sample_request(sample_messages):
|
||||
return ChatMessage(
|
||||
messages=sample_messages,
|
||||
tools=[{"name": "sample_tool", "description": "A sample tool"}],
|
||||
tools=[
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "sample_tool",
|
||||
"description": "A sample tool",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": [],
|
||||
},
|
||||
},
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@patch("app.commons.constants.arch_function_hanlder")
|
||||
@patch("app.commons.globals.handler_map")
|
||||
def test_process_messages(mock_hanlder):
|
||||
messages = sample_messages()
|
||||
processed = process_messages(messages)
|
||||
processed = handler_map["Arch-Function"]._process_messages(messages)
|
||||
|
||||
assert len(processed) == 3
|
||||
assert processed[0] == {"role": "user", "content": "Hello!"}
|
||||
|
|
@ -48,10 +60,11 @@ def test_process_messages(mock_hanlder):
|
|||
}
|
||||
assert processed[2] == {
|
||||
"role": "user",
|
||||
"content": "<tool_response>\nResponse from tool\n</tool_response>",
|
||||
"content": f"<tool_response>\n{json.dumps('Response from tool')}\n</tool_response>",
|
||||
}
|
||||
|
||||
|
||||
# [TODO] Review: Update the following test
|
||||
@patch("app.commons.constants.arch_function_client")
|
||||
@patch("app.commons.constants.arch_function_hanlder")
|
||||
@pytest.mark.asyncio
|
||||
|
|
|
|||
47
model_server/app/tests/test_guardrails.py
Normal file
47
model_server/app/tests/test_guardrails.py
Normal file
|
|
@ -0,0 +1,47 @@
|
|||
import os
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
import app.commons.globals as glb
|
||||
|
||||
# Mock constants
|
||||
glb.DEVICE = "cpu" # Adjust as needed for your test case
|
||||
arch_guard_model_type = {
|
||||
"cpu": "katanemo/Arch-Guard-cpu",
|
||||
"cuda": "katanemo/Arch-Guard",
|
||||
"mps": "katanemo/Arch-Guard",
|
||||
}
|
||||
|
||||
|
||||
# [TODO] Review: update the following code to test under `cpu`, `cuda`, and `mps`
|
||||
# Test for get_prompt_guard function
|
||||
@patch("app.loader.AutoTokenizer.from_pretrained")
|
||||
@patch("app.loader.OVModelForSequenceClassification.from_pretrained")
|
||||
@patch("app.loader.AutoModelForSequenceClassification.from_pretrained")
|
||||
def test_get_prompt_guard(mock_auto_model, mock_ov_model, mock_tokenizer):
|
||||
# Mock model based on device
|
||||
if glb.DEVICE == "cpu":
|
||||
mock_ov_model.return_value = MagicMock()
|
||||
else:
|
||||
mock_auto_model.return_value = MagicMock()
|
||||
|
||||
mock_tokenizer.return_value = MagicMock()
|
||||
|
||||
prompt_guard = get_prompt_guard(arch_guard_model_type[glb.DEVICE])
|
||||
|
||||
# Assertions
|
||||
assert prompt_guard["model_name"] == arch_guard_model_type[glb.DEVICE]
|
||||
mock_tokenizer.assert_called_once_with(
|
||||
arch_guard_model_type[glb.DEVICE], trust_remote_code=True
|
||||
)
|
||||
if glb.DEVICE == "cpu":
|
||||
mock_ov_model.assert_called_once_with(
|
||||
arch_guard_model_type[glb.DEVICE],
|
||||
device_map=glb.DEVICE,
|
||||
low_cpu_mem_usage=True,
|
||||
)
|
||||
else:
|
||||
mock_auto_model.assert_called_once_with(
|
||||
arch_guard_model_type[glb.DEVICE],
|
||||
device_map=glb.DEVICE,
|
||||
low_cpu_mem_usage=True,
|
||||
)
|
||||
|
|
@ -1,8 +1,11 @@
|
|||
import json
|
||||
from app.function_calling.hallucination_handler import HallucinationStateHandler
|
||||
import pytest
|
||||
import os
|
||||
|
||||
|
||||
from app.model_handler.hallucination_handler import HallucinationStateHandler
|
||||
|
||||
|
||||
# Get the directory of the current file
|
||||
current_dir = os.path.dirname(__file__)
|
||||
|
||||
|
|
@ -45,6 +48,7 @@ if type(function_description) != list:
|
|||
function_description = [get_weather_api["function"]]
|
||||
|
||||
|
||||
# [TODO] Review: update the following code
|
||||
@pytest.mark.parametrize("case", test_cases)
|
||||
def test_hallucination(case):
|
||||
state = HallucinationStateHandler(
|
||||
|
|
@ -58,6 +62,7 @@ def test_hallucination(case):
|
|||
assert state.hallucination == case["expect"]
|
||||
|
||||
|
||||
# [TODO] Review: update the following code
|
||||
@pytest.mark.parametrize("is_hallucinate_sample", [True, False])
|
||||
def test_hallucination_prompt(is_hallucinate_sample):
|
||||
TASK_PROMPT = """
|
||||
|
|
|
|||
|
|
@ -1,102 +0,0 @@
|
|||
import os
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
import app.commons.globals as glb
|
||||
from app.loader import get_embedding_model, get_zero_shot_model, get_prompt_guard
|
||||
|
||||
# Mock constants
|
||||
glb.DEVICE = "cpu" # Adjust as needed for your test case
|
||||
arch_guard_model_type = {
|
||||
"cpu": "katanemo/Arch-Guard-cpu",
|
||||
"cuda": "katanemo/Arch-Guard",
|
||||
"mps": "katanemo/Arch-Guard",
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_env():
|
||||
# Mock environment variables
|
||||
os.environ["MODELS"] = "katanemo/bge-large-en-v1.5"
|
||||
os.environ["ZERO_SHOT_MODELS"] = "katanemo/bart-large-mnli"
|
||||
|
||||
|
||||
# Test for get_embedding_model function
|
||||
@patch("app.loader.ORTModelForFeatureExtraction.from_pretrained")
|
||||
@patch("app.loader.AutoModel.from_pretrained")
|
||||
@patch("app.loader.AutoTokenizer.from_pretrained")
|
||||
def test_get_embedding_model(mock_tokenizer, mock_automodel, mock_ort_model, mock_env):
|
||||
mock_automodel.return_value = MagicMock()
|
||||
mock_ort_model.return_value = MagicMock()
|
||||
mock_tokenizer.return_value = MagicMock()
|
||||
|
||||
embedding_model = get_embedding_model()
|
||||
|
||||
# Assertions
|
||||
assert embedding_model["model_name"] == "katanemo/bge-large-en-v1.5"
|
||||
mock_tokenizer.assert_called_once_with(
|
||||
"katanemo/bge-large-en-v1.5", trust_remote_code=True
|
||||
)
|
||||
if glb.DEVICE != "cuda":
|
||||
mock_ort_model.assert_called_once_with(
|
||||
"katanemo/bge-large-en-v1.5", file_name="onnx/model.onnx"
|
||||
)
|
||||
else:
|
||||
mock_automodel.assert_called_once_with(
|
||||
"katanemo/bge-large-en-v1.5", device_map=glb.DEVICE
|
||||
)
|
||||
|
||||
|
||||
# Test for get_zero_shot_model function
|
||||
@patch("app.loader.ORTModelForSequenceClassification.from_pretrained")
|
||||
@patch("app.loader.pipeline")
|
||||
@patch("app.loader.AutoTokenizer.from_pretrained")
|
||||
def test_get_zero_shot_model(mock_tokenizer, mock_pipeline, mock_ort_model, mock_env):
|
||||
mock_pipeline.return_value = MagicMock()
|
||||
mock_ort_model.return_value = MagicMock()
|
||||
mock_tokenizer.return_value = MagicMock()
|
||||
|
||||
zero_shot_model = get_zero_shot_model()
|
||||
|
||||
# Assertions
|
||||
assert zero_shot_model["model_name"] == "katanemo/bart-large-mnli"
|
||||
mock_tokenizer.assert_called_once_with("katanemo/bart-large-mnli")
|
||||
if glb.DEVICE != "cuda":
|
||||
mock_ort_model.assert_called_once_with(
|
||||
"katanemo/bart-large-mnli", file_name="onnx/model.onnx"
|
||||
)
|
||||
else:
|
||||
assert mock_pipeline.called_once()
|
||||
|
||||
|
||||
# Test for get_prompt_guard function
|
||||
@patch("app.loader.AutoTokenizer.from_pretrained")
|
||||
@patch("app.loader.OVModelForSequenceClassification.from_pretrained")
|
||||
@patch("app.loader.AutoModelForSequenceClassification.from_pretrained")
|
||||
def test_get_prompt_guard(mock_auto_model, mock_ov_model, mock_tokenizer):
|
||||
# Mock model based on device
|
||||
if glb.DEVICE == "cpu":
|
||||
mock_ov_model.return_value = MagicMock()
|
||||
else:
|
||||
mock_auto_model.return_value = MagicMock()
|
||||
|
||||
mock_tokenizer.return_value = MagicMock()
|
||||
|
||||
prompt_guard = get_prompt_guard(arch_guard_model_type[glb.DEVICE])
|
||||
|
||||
# Assertions
|
||||
assert prompt_guard["model_name"] == arch_guard_model_type[glb.DEVICE]
|
||||
mock_tokenizer.assert_called_once_with(
|
||||
arch_guard_model_type[glb.DEVICE], trust_remote_code=True
|
||||
)
|
||||
if glb.DEVICE == "cpu":
|
||||
mock_ov_model.assert_called_once_with(
|
||||
arch_guard_model_type[glb.DEVICE],
|
||||
device_map=glb.DEVICE,
|
||||
low_cpu_mem_usage=True,
|
||||
)
|
||||
else:
|
||||
mock_auto_model.assert_called_once_with(
|
||||
arch_guard_model_type[glb.DEVICE],
|
||||
device_map=glb.DEVICE,
|
||||
low_cpu_mem_usage=True,
|
||||
)
|
||||
|
|
@ -1,102 +0,0 @@
|
|||
import os
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
import app.commons.globals as glb
|
||||
from app.loader import get_embedding_model, get_zero_shot_model, get_prompt_guard
|
||||
|
||||
# Mock constants
|
||||
glb.DEVICE = "cuda" # Adjust as needed for your test case
|
||||
arch_guard_model_type = {
|
||||
"cpu": "katanemo/Arch-Guard-cpu",
|
||||
"cuda": "katanemo/Arch-Guard",
|
||||
"mps": "katanemo/Arch-Guard",
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_env():
|
||||
# Mock environment variables
|
||||
os.environ["MODELS"] = "katanemo/bge-large-en-v1.5"
|
||||
os.environ["ZERO_SHOT_MODELS"] = "katanemo/bart-large-mnli"
|
||||
|
||||
|
||||
# Test for get_embedding_model function
|
||||
@patch("app.loader.ORTModelForFeatureExtraction.from_pretrained")
|
||||
@patch("app.loader.AutoModel.from_pretrained")
|
||||
@patch("app.loader.AutoTokenizer.from_pretrained")
|
||||
def test_get_embedding_model(mock_tokenizer, mock_automodel, mock_ort_model, mock_env):
|
||||
mock_automodel.return_value = MagicMock()
|
||||
mock_ort_model.return_value = MagicMock()
|
||||
mock_tokenizer.return_value = MagicMock()
|
||||
|
||||
embedding_model = get_embedding_model()
|
||||
|
||||
# Assertions
|
||||
assert embedding_model["model_name"] == "katanemo/bge-large-en-v1.5"
|
||||
mock_tokenizer.assert_called_once_with(
|
||||
"katanemo/bge-large-en-v1.5", trust_remote_code=True
|
||||
)
|
||||
if glb.DEVICE != "cuda":
|
||||
mock_ort_model.assert_called_once_with(
|
||||
"katanemo/bge-large-en-v1.5", file_name="onnx/model.onnx"
|
||||
)
|
||||
else:
|
||||
mock_automodel.assert_called_once_with(
|
||||
"katanemo/bge-large-en-v1.5", device_map=glb.DEVICE
|
||||
)
|
||||
|
||||
|
||||
# Test for get_zero_shot_model function
|
||||
@patch("app.loader.ORTModelForSequenceClassification.from_pretrained")
|
||||
@patch("app.loader.pipeline")
|
||||
@patch("app.loader.AutoTokenizer.from_pretrained")
|
||||
def test_get_zero_shot_model(mock_tokenizer, mock_pipeline, mock_ort_model, mock_env):
|
||||
mock_pipeline.return_value = MagicMock()
|
||||
mock_ort_model.return_value = MagicMock()
|
||||
mock_tokenizer.return_value = MagicMock()
|
||||
|
||||
zero_shot_model = get_zero_shot_model()
|
||||
|
||||
# Assertions
|
||||
assert zero_shot_model["model_name"] == "katanemo/bart-large-mnli"
|
||||
mock_tokenizer.assert_called_once_with("katanemo/bart-large-mnli")
|
||||
if glb.DEVICE != "cuda":
|
||||
mock_ort_model.assert_called_once_with(
|
||||
"katanemo/bart-large-mnli", file_name="onnx/model.onnx"
|
||||
)
|
||||
else:
|
||||
assert mock_pipeline.called_once()
|
||||
|
||||
|
||||
# Test for get_prompt_guard function
|
||||
@patch("app.loader.AutoTokenizer.from_pretrained")
|
||||
@patch("app.loader.OVModelForSequenceClassification.from_pretrained")
|
||||
@patch("app.loader.AutoModelForSequenceClassification.from_pretrained")
|
||||
def test_get_prompt_guard(mock_auto_model, mock_ov_model, mock_tokenizer):
|
||||
# Mock model based on device
|
||||
if glb.DEVICE == "cpu":
|
||||
mock_ov_model.return_value = MagicMock()
|
||||
else:
|
||||
mock_auto_model.return_value = MagicMock()
|
||||
|
||||
mock_tokenizer.return_value = MagicMock()
|
||||
|
||||
prompt_guard = get_prompt_guard(arch_guard_model_type[glb.DEVICE])
|
||||
|
||||
# Assertions
|
||||
assert prompt_guard["model_name"] == arch_guard_model_type[glb.DEVICE]
|
||||
mock_tokenizer.assert_called_once_with(
|
||||
arch_guard_model_type[glb.DEVICE], trust_remote_code=True
|
||||
)
|
||||
if glb.DEVICE == "cpu":
|
||||
mock_ov_model.assert_called_once_with(
|
||||
arch_guard_model_type[glb.DEVICE],
|
||||
device_map=glb.DEVICE,
|
||||
low_cpu_mem_usage=True,
|
||||
)
|
||||
else:
|
||||
mock_auto_model.assert_called_once_with(
|
||||
arch_guard_model_type[glb.DEVICE],
|
||||
device_map=glb.DEVICE,
|
||||
low_cpu_mem_usage=True,
|
||||
)
|
||||
|
|
@ -1,102 +0,0 @@
|
|||
import os
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
import app.commons.globals as glb
|
||||
from app.loader import get_embedding_model, get_zero_shot_model, get_prompt_guard
|
||||
|
||||
# Mock constants
|
||||
glb.DEVICE = "mps" # Adjust as needed for your test case
|
||||
arch_guard_model_type = {
|
||||
"cpu": "katanemo/Arch-Guard-cpu",
|
||||
"cuda": "katanemo/Arch-Guard",
|
||||
"mps": "katanemo/Arch-Guard",
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_env():
|
||||
# Mock environment variables
|
||||
os.environ["MODELS"] = "katanemo/bge-large-en-v1.5"
|
||||
os.environ["ZERO_SHOT_MODELS"] = "katanemo/bart-large-mnli"
|
||||
|
||||
|
||||
# Test for get_embedding_model function
|
||||
@patch("app.loader.ORTModelForFeatureExtraction.from_pretrained")
|
||||
@patch("app.loader.AutoModel.from_pretrained")
|
||||
@patch("app.loader.AutoTokenizer.from_pretrained")
|
||||
def test_get_embedding_model(mock_tokenizer, mock_automodel, mock_ort_model, mock_env):
|
||||
mock_automodel.return_value = MagicMock()
|
||||
mock_ort_model.return_value = MagicMock()
|
||||
mock_tokenizer.return_value = MagicMock()
|
||||
|
||||
embedding_model = get_embedding_model()
|
||||
|
||||
# Assertions
|
||||
assert embedding_model["model_name"] == "katanemo/bge-large-en-v1.5"
|
||||
mock_tokenizer.assert_called_once_with(
|
||||
"katanemo/bge-large-en-v1.5", trust_remote_code=True
|
||||
)
|
||||
if glb.DEVICE != "cuda":
|
||||
mock_ort_model.assert_called_once_with(
|
||||
"katanemo/bge-large-en-v1.5", file_name="onnx/model.onnx"
|
||||
)
|
||||
else:
|
||||
mock_automodel.assert_called_once_with(
|
||||
"katanemo/bge-large-en-v1.5", device_map=glb.DEVICE
|
||||
)
|
||||
|
||||
|
||||
# Test for get_zero_shot_model function
|
||||
@patch("app.loader.ORTModelForSequenceClassification.from_pretrained")
|
||||
@patch("app.loader.pipeline")
|
||||
@patch("app.loader.AutoTokenizer.from_pretrained")
|
||||
def test_get_zero_shot_model(mock_tokenizer, mock_pipeline, mock_ort_model, mock_env):
|
||||
mock_pipeline.return_value = MagicMock()
|
||||
mock_ort_model.return_value = MagicMock()
|
||||
mock_tokenizer.return_value = MagicMock()
|
||||
|
||||
zero_shot_model = get_zero_shot_model()
|
||||
|
||||
# Assertions
|
||||
assert zero_shot_model["model_name"] == "katanemo/bart-large-mnli"
|
||||
mock_tokenizer.assert_called_once_with("katanemo/bart-large-mnli")
|
||||
if glb.DEVICE != "cuda":
|
||||
mock_ort_model.assert_called_once_with(
|
||||
"katanemo/bart-large-mnli", file_name="onnx/model.onnx"
|
||||
)
|
||||
else:
|
||||
assert mock_pipeline.called_once()
|
||||
|
||||
|
||||
# Test for get_prompt_guard function
|
||||
@patch("app.loader.AutoTokenizer.from_pretrained")
|
||||
@patch("app.loader.OVModelForSequenceClassification.from_pretrained")
|
||||
@patch("app.loader.AutoModelForSequenceClassification.from_pretrained")
|
||||
def test_get_prompt_guard(mock_auto_model, mock_ov_model, mock_tokenizer):
|
||||
# Mock model based on device
|
||||
if glb.DEVICE == "cpu":
|
||||
mock_ov_model.return_value = MagicMock()
|
||||
else:
|
||||
mock_auto_model.return_value = MagicMock()
|
||||
|
||||
mock_tokenizer.return_value = MagicMock()
|
||||
|
||||
prompt_guard = get_prompt_guard(arch_guard_model_type[glb.DEVICE])
|
||||
|
||||
# Assertions
|
||||
assert prompt_guard["model_name"] == arch_guard_model_type[glb.DEVICE]
|
||||
mock_tokenizer.assert_called_once_with(
|
||||
arch_guard_model_type[glb.DEVICE], trust_remote_code=True
|
||||
)
|
||||
if glb.DEVICE == "cpu":
|
||||
mock_ov_model.assert_called_once_with(
|
||||
arch_guard_model_type[glb.DEVICE],
|
||||
device_map=glb.DEVICE,
|
||||
low_cpu_mem_usage=True,
|
||||
)
|
||||
else:
|
||||
mock_auto_model.assert_called_once_with(
|
||||
arch_guard_model_type[glb.DEVICE],
|
||||
device_map=glb.DEVICE,
|
||||
low_cpu_mem_usage=True,
|
||||
)
|
||||
|
|
@ -1,66 +1,50 @@
|
|||
from typing import List
|
||||
import pytest
|
||||
import json
|
||||
from app.function_calling.model_utils import Message, process_messages
|
||||
from app.commons.globals import handler_map
|
||||
from app.model_handler.function_calling import Message
|
||||
|
||||
test_input_history = """
|
||||
[
|
||||
|
||||
test_input_history = [
|
||||
{"role": "user", "content": "how is the weather in chicago for next 5 days?"},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "how is the weather in chicago for next 5 days?"
|
||||
"role": "assistant",
|
||||
"model": "Arch-Function",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_3394",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "weather_forecast",
|
||||
"arguments": {"city": "Chicago", "days": 5},
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
{"role": "tool", "content": "--", "tool_call_id": "call_3394"},
|
||||
{"role": "assistant", "content": "--", "model": "gpt-3.5-turbo-0125"},
|
||||
{"role": "user", "content": "how is the weather in chicago for next 5 days?"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"model": "Arch-Function-1.5B",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_3394",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "weather_forecast",
|
||||
"arguments": { "city": "Chicago", "days": 5 }
|
||||
}
|
||||
}
|
||||
]
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_5306",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "weather_forecast",
|
||||
"arguments": {"city": "Chicago", "days": 5},
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"content": "--",
|
||||
"tool_call_id": "call_3394"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "--",
|
||||
"model": "gpt-3.5-turbo-0125"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "how is the weather in chicago for next 5 days?"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_5306",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "weather_forecast",
|
||||
"arguments": { "city": "Chicago", "days": 5 }
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
"""
|
||||
{"role": "tool", "content": "--", "tool_call_id": "call_5306"},
|
||||
]
|
||||
|
||||
|
||||
def test_update_fc_history():
|
||||
history = json.loads(test_input_history)
|
||||
message_history = []
|
||||
for h in history:
|
||||
|
||||
for h in test_input_history:
|
||||
message_history.append(Message(**h))
|
||||
|
||||
updated_history = process_messages(message_history)
|
||||
assert len(updated_history) == 6
|
||||
updated_history = handler_map["Arch-Function"]._process_messages(message_history)
|
||||
assert len(updated_history) == 7
|
||||
# ensure that tool role does not exist anymore
|
||||
assert all([h["role"] != "tool" for h in updated_history])
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue