remove mode/hardware

This commit is contained in:
cotran 2024-10-16 13:33:47 -07:00
parent b1746b38b4
commit cf612409f0
5 changed files with 18 additions and 36 deletions

View file

@ -18,14 +18,16 @@ arch_function_generation_params = {
"stop_token_ids": [151645], "stop_token_ids": [151645],
} }
arch_guard_model_type = {"cpu": "katanemo/Arch-Guard-cpu", "gpu": "katanemo/Arch-Guard"} arch_guard_model_type = {
"cpu": "katanemo/Arch-Guard-cpu",
"cuda": "katanemo/Arch-Guard",
"mps": "katanemo/Arch-Guard",
}
# Model definition # Model definition
embedding_model = loader.get_embedding_model() embedding_model = loader.get_embedding_model()
zero_shot_model = loader.get_zero_shot_model() zero_shot_model = loader.get_zero_shot_model()
prompt_guard_dict = loader.get_prompt_guard( prompt_guard_dict = loader.get_prompt_guard(arch_guard_model_type[glb.device])
arch_guard_model_type[glb.HARDWARE], glb.HARDWARE
)
arch_guard_handler = ArchGuardHanlder(model_dict=prompt_guard_dict) arch_guard_handler = ArchGuardHanlder(model_dict=prompt_guard_dict)

View file

@ -3,4 +3,3 @@ import app.commons.utilities as utils
DEVICE = utils.get_device() DEVICE = utils.get_device()
MODE = utils.get_serving_mode() MODE = utils.get_serving_mode()
HARDWARE = utils.get_hardware(MODE)

View file

@ -22,9 +22,11 @@ def get_device():
available_device = { available_device = {
"cpu": True, "cpu": True,
"cuda": torch.cuda.is_available(), "cuda": torch.cuda.is_available(),
"mps": torch.backends.mps.is_available() "mps": (
torch.backends.mps.is_available()
if hasattr(torch.backends, "mps") if hasattr(torch.backends, "mps")
else False, else False
),
} }
if available_device["cuda"]: if available_device["cuda"]:
@ -37,24 +39,6 @@ def get_device():
return device return device
def get_serving_mode():
mode = os.getenv("MODE", "cloud")
if mode not in ["cloud", "local-gpu", "local-cpu"]:
raise ValueError(f"Invalid serving mode: {mode}")
return mode
def get_hardware(mode):
if mode == "local-cpu":
hardware = "cpu"
else:
hardware = "gpu" if torch.cuda.is_available() else "cpu"
return hardware
def get_client(endpoint): def get_client(endpoint):
client = OpenAI(base_url=endpoint, api_key="EMPTY") client = OpenAI(base_url=endpoint, api_key="EMPTY")
return client return client

View file

@ -60,28 +60,25 @@ def get_zero_shot_model(
return zero_shot_model return zero_shot_model
def get_prompt_guard(model_name, hardware_config="cpu"): def get_prompt_guard(model_name):
logger.info("Loading Guard Model...") logger.info("Loading Guard Model...")
if hardware_config == "cpu": if glb.DEVICE == "cpu":
from optimum.intel import OVModelForSequenceClassification from optimum.intel import OVModelForSequenceClassification
device = "cpu"
model_class = OVModelForSequenceClassification model_class = OVModelForSequenceClassification
elif hardware_config == "gpu": elif glb.DEVICE == "gpu":
import torch import torch
from transformers import AutoModelForSequenceClassification from transformers import AutoModelForSequenceClassification
device = "cuda" if torch.cuda.is_available() else "cpu"
model_class = AutoModelForSequenceClassification model_class = AutoModelForSequenceClassification
prompt_guard = { prompt_guard = {
"hardware_config": hardware_config, "device": glb.DEVICE,
"device": device,
"model_name": model_name, "model_name": model_name,
"tokenizer": AutoTokenizer.from_pretrained(model_name, trust_remote_code=True), "tokenizer": AutoTokenizer.from_pretrained(model_name, trust_remote_code=True),
"model": model_class.from_pretrained( "model": model_class.from_pretrained(
model_name, device_map=device, low_cpu_mem_usage=True model_name, device_map=glb.DEVICE, low_cpu_mem_usage=True
), ),
} }

View file

@ -15,11 +15,11 @@ class ArchGuardHanlder:
self.threshold = threshold self.threshold = threshold
def guard_predict(self, input_text): def guard_predict(self, input_text, max_length=512):
start_time = time.perf_counter() start_time = time.perf_counter()
inputs = self.tokenizer( inputs = self.tokenizer(
input_text, truncation=True, max_length=512, return_tensors="pt" input_text, truncation=True, max_length=max_length, return_tensors="pt"
).to(self.device) ).to(self.device)
with torch.no_grad(): with torch.no_grad():