remove mode/hardware

This commit is contained in:
cotran 2024-10-16 13:33:47 -07:00
parent b1746b38b4
commit cf612409f0
5 changed files with 18 additions and 36 deletions

View file

@ -60,28 +60,25 @@ def get_zero_shot_model(
return zero_shot_model
def get_prompt_guard(model_name, hardware_config="cpu"):
def get_prompt_guard(model_name):
logger.info("Loading Guard Model...")
if hardware_config == "cpu":
if glb.DEVICE == "cpu":
from optimum.intel import OVModelForSequenceClassification
device = "cpu"
model_class = OVModelForSequenceClassification
elif hardware_config == "gpu":
elif glb.DEVICE == "gpu":
import torch
from transformers import AutoModelForSequenceClassification
device = "cuda" if torch.cuda.is_available() else "cpu"
model_class = AutoModelForSequenceClassification
prompt_guard = {
"hardware_config": hardware_config,
"device": device,
"device": glb.DEVICE,
"model_name": model_name,
"tokenizer": AutoTokenizer.from_pretrained(model_name, trust_remote_code=True),
"model": model_class.from_pretrained(
model_name, device_map=device, low_cpu_mem_usage=True
model_name, device_map=glb.DEVICE, low_cpu_mem_usage=True
),
}