mirror of
https://github.com/katanemo/plano.git
synced 2026-06-02 14:35:14 +02:00
remove mode/hardware
This commit is contained in:
parent
b1746b38b4
commit
cf612409f0
5 changed files with 18 additions and 36 deletions
|
|
@ -60,28 +60,25 @@ def get_zero_shot_model(
|
|||
return zero_shot_model
|
||||
|
||||
|
||||
def get_prompt_guard(model_name, hardware_config="cpu"):
|
||||
def get_prompt_guard(model_name):
|
||||
logger.info("Loading Guard Model...")
|
||||
|
||||
if hardware_config == "cpu":
|
||||
if glb.DEVICE == "cpu":
|
||||
from optimum.intel import OVModelForSequenceClassification
|
||||
|
||||
device = "cpu"
|
||||
model_class = OVModelForSequenceClassification
|
||||
elif hardware_config == "gpu":
|
||||
elif glb.DEVICE == "gpu":
|
||||
import torch
|
||||
from transformers import AutoModelForSequenceClassification
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
model_class = AutoModelForSequenceClassification
|
||||
|
||||
prompt_guard = {
|
||||
"hardware_config": hardware_config,
|
||||
"device": device,
|
||||
"device": glb.DEVICE,
|
||||
"model_name": model_name,
|
||||
"tokenizer": AutoTokenizer.from_pretrained(model_name, trust_remote_code=True),
|
||||
"model": model_class.from_pretrained(
|
||||
model_name, device_map=device, low_cpu_mem_usage=True
|
||||
model_name, device_map=glb.DEVICE, low_cpu_mem_usage=True
|
||||
),
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue