Update guardrail_handler and its associated tests

This commit is contained in:
Shuguang Chen 2024-12-05 11:30:58 -08:00
parent b686cf8b87
commit 09f7e1e604
7 changed files with 115 additions and 1091 deletions

View file

@ -1,43 +1,14 @@
import app.commons.utilities as utils
from openai import OpenAI
from app.commons.constants import *
from app.model_handler.function_calling import ArchIntentHandler, ArchFunctionHandler
from app.model_handler.guardrails import ArchGuardHanlder
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from optimum.intel import OVModelForSequenceClassification
from openai import OpenAI
from app.model_handler.guardrails import get_guardrail_handler
logger = utils.get_model_server_logger()
def get_guardrail_handler():
device = utils.get_device()
model_class, model_name = None, None
if device == "cpu":
model_class = OVModelForSequenceClassification
model_name = "katanemo/Arch-Guard-cpu"
else:
model_class = AutoModelForSequenceClassification
if device == "cuda":
model_name = "katanemo/Arch-Guard"
else:
model_name = "katanemo/Arch-Guard"
guardrail_dict = {
"device": device,
"model_name": model_name,
"tokenizer": AutoTokenizer.from_pretrained(model_name, trust_remote_code=True),
"model": model_class.from_pretrained(
model_name, device_map=device, low_cpu_mem_usage=True
),
}
return ArchGuardHanlder(model_dict=guardrail_dict)
# Define the client
ARCH_CLIENT = OpenAI(base_url="https://api.fc.archgw.com/v1", api_key="EMPTY")