mirror of
https://github.com/katanemo/plano.git
synced 2026-06-20 15:28:07 +02:00
Update guardrail_handler and its associated tests
This commit is contained in:
parent
b686cf8b87
commit
09f7e1e604
7 changed files with 115 additions and 1091 deletions
|
|
@ -1,43 +1,14 @@
|
|||
import app.commons.utilities as utils
|
||||
|
||||
from openai import OpenAI
|
||||
from app.commons.constants import *
|
||||
from app.model_handler.function_calling import ArchIntentHandler, ArchFunctionHandler
|
||||
from app.model_handler.guardrails import ArchGuardHanlder
|
||||
|
||||
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
||||
from optimum.intel import OVModelForSequenceClassification
|
||||
from openai import OpenAI
|
||||
from app.model_handler.guardrails import get_guardrail_handler
|
||||
|
||||
|
||||
logger = utils.get_model_server_logger()
|
||||
|
||||
|
||||
def get_guardrail_handler():
|
||||
device = utils.get_device()
|
||||
|
||||
model_class, model_name = None, None
|
||||
if device == "cpu":
|
||||
model_class = OVModelForSequenceClassification
|
||||
model_name = "katanemo/Arch-Guard-cpu"
|
||||
else:
|
||||
model_class = AutoModelForSequenceClassification
|
||||
if device == "cuda":
|
||||
model_name = "katanemo/Arch-Guard"
|
||||
else:
|
||||
model_name = "katanemo/Arch-Guard"
|
||||
|
||||
guardrail_dict = {
|
||||
"device": device,
|
||||
"model_name": model_name,
|
||||
"tokenizer": AutoTokenizer.from_pretrained(model_name, trust_remote_code=True),
|
||||
"model": model_class.from_pretrained(
|
||||
model_name, device_map=device, low_cpu_mem_usage=True
|
||||
),
|
||||
}
|
||||
|
||||
return ArchGuardHanlder(model_dict=guardrail_dict)
|
||||
|
||||
|
||||
# Define the client
|
||||
ARCH_CLIENT = OpenAI(base_url="https://api.fc.archgw.com/v1", api_key="EMPTY")
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue