mirror of
https://github.com/katanemo/plano.git
synced 2026-06-23 15:38:07 +02:00
Update guardrail_handler and its associated tests
This commit is contained in:
parent
b686cf8b87
commit
09f7e1e604
7 changed files with 115 additions and 1091 deletions
|
|
@ -1,8 +1,11 @@
|
|||
import time
|
||||
import torch
|
||||
import numpy as np
|
||||
import app.commons.utilities as utils
|
||||
|
||||
from pydantic import BaseModel
|
||||
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
||||
from optimum.intel import OVModelForSequenceClassification
|
||||
|
||||
|
||||
class GuardRequest(BaseModel):
|
||||
|
|
@ -93,3 +96,27 @@ class ArchGuardHanlder:
|
|||
guard_result["latency"] = time.perf_counter() - start_time
|
||||
|
||||
return guard_result
|
||||
|
||||
|
||||
def get_guardrail_handler(device: str = None):
|
||||
if device is None:
|
||||
device = utils.get_device()
|
||||
|
||||
model_class, model_name = None, None
|
||||
if device == "cpu":
|
||||
model_class = OVModelForSequenceClassification
|
||||
model_name = "katanemo/Arch-Guard-cpu"
|
||||
else:
|
||||
model_class = AutoModelForSequenceClassification
|
||||
model_name = "katanemo/Arch-Guard"
|
||||
|
||||
guardrail_dict = {
|
||||
"device": device,
|
||||
"model_name": model_name,
|
||||
"tokenizer": AutoTokenizer.from_pretrained(model_name, trust_remote_code=True),
|
||||
"model": model_class.from_pretrained(
|
||||
model_name, device_map=device, low_cpu_mem_usage=True
|
||||
),
|
||||
}
|
||||
|
||||
return ArchGuardHanlder(model_dict=guardrail_dict)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue