mirror of
https://github.com/katanemo/plano.git
synced 2026-06-02 14:35:14 +02:00
remove mode/hardware
This commit is contained in:
parent
b1746b38b4
commit
cf612409f0
5 changed files with 18 additions and 36 deletions
|
|
@ -18,14 +18,16 @@ arch_function_generation_params = {
|
||||||
"stop_token_ids": [151645],
|
"stop_token_ids": [151645],
|
||||||
}
|
}
|
||||||
|
|
||||||
arch_guard_model_type = {"cpu": "katanemo/Arch-Guard-cpu", "gpu": "katanemo/Arch-Guard"}
|
arch_guard_model_type = {
|
||||||
|
"cpu": "katanemo/Arch-Guard-cpu",
|
||||||
|
"cuda": "katanemo/Arch-Guard",
|
||||||
|
"mps": "katanemo/Arch-Guard",
|
||||||
|
}
|
||||||
|
|
||||||
# Model definition
|
# Model definition
|
||||||
embedding_model = loader.get_embedding_model()
|
embedding_model = loader.get_embedding_model()
|
||||||
zero_shot_model = loader.get_zero_shot_model()
|
zero_shot_model = loader.get_zero_shot_model()
|
||||||
|
|
||||||
prompt_guard_dict = loader.get_prompt_guard(
|
prompt_guard_dict = loader.get_prompt_guard(arch_guard_model_type[glb.device])
|
||||||
arch_guard_model_type[glb.HARDWARE], glb.HARDWARE
|
|
||||||
)
|
|
||||||
|
|
||||||
arch_guard_handler = ArchGuardHanlder(model_dict=prompt_guard_dict)
|
arch_guard_handler = ArchGuardHanlder(model_dict=prompt_guard_dict)
|
||||||
|
|
|
||||||
|
|
@ -3,4 +3,3 @@ import app.commons.utilities as utils
|
||||||
|
|
||||||
DEVICE = utils.get_device()
|
DEVICE = utils.get_device()
|
||||||
MODE = utils.get_serving_mode()
|
MODE = utils.get_serving_mode()
|
||||||
HARDWARE = utils.get_hardware(MODE)
|
|
||||||
|
|
|
||||||
|
|
@ -22,9 +22,11 @@ def get_device():
|
||||||
available_device = {
|
available_device = {
|
||||||
"cpu": True,
|
"cpu": True,
|
||||||
"cuda": torch.cuda.is_available(),
|
"cuda": torch.cuda.is_available(),
|
||||||
"mps": torch.backends.mps.is_available()
|
"mps": (
|
||||||
|
torch.backends.mps.is_available()
|
||||||
if hasattr(torch.backends, "mps")
|
if hasattr(torch.backends, "mps")
|
||||||
else False,
|
else False
|
||||||
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
if available_device["cuda"]:
|
if available_device["cuda"]:
|
||||||
|
|
@ -37,24 +39,6 @@ def get_device():
|
||||||
return device
|
return device
|
||||||
|
|
||||||
|
|
||||||
def get_serving_mode():
|
|
||||||
mode = os.getenv("MODE", "cloud")
|
|
||||||
|
|
||||||
if mode not in ["cloud", "local-gpu", "local-cpu"]:
|
|
||||||
raise ValueError(f"Invalid serving mode: {mode}")
|
|
||||||
|
|
||||||
return mode
|
|
||||||
|
|
||||||
|
|
||||||
def get_hardware(mode):
|
|
||||||
if mode == "local-cpu":
|
|
||||||
hardware = "cpu"
|
|
||||||
else:
|
|
||||||
hardware = "gpu" if torch.cuda.is_available() else "cpu"
|
|
||||||
|
|
||||||
return hardware
|
|
||||||
|
|
||||||
|
|
||||||
def get_client(endpoint):
|
def get_client(endpoint):
|
||||||
client = OpenAI(base_url=endpoint, api_key="EMPTY")
|
client = OpenAI(base_url=endpoint, api_key="EMPTY")
|
||||||
return client
|
return client
|
||||||
|
|
|
||||||
|
|
@ -60,28 +60,25 @@ def get_zero_shot_model(
|
||||||
return zero_shot_model
|
return zero_shot_model
|
||||||
|
|
||||||
|
|
||||||
def get_prompt_guard(model_name, hardware_config="cpu"):
|
def get_prompt_guard(model_name):
|
||||||
logger.info("Loading Guard Model...")
|
logger.info("Loading Guard Model...")
|
||||||
|
|
||||||
if hardware_config == "cpu":
|
if glb.DEVICE == "cpu":
|
||||||
from optimum.intel import OVModelForSequenceClassification
|
from optimum.intel import OVModelForSequenceClassification
|
||||||
|
|
||||||
device = "cpu"
|
|
||||||
model_class = OVModelForSequenceClassification
|
model_class = OVModelForSequenceClassification
|
||||||
elif hardware_config == "gpu":
|
elif glb.DEVICE == "gpu":
|
||||||
import torch
|
import torch
|
||||||
from transformers import AutoModelForSequenceClassification
|
from transformers import AutoModelForSequenceClassification
|
||||||
|
|
||||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
||||||
model_class = AutoModelForSequenceClassification
|
model_class = AutoModelForSequenceClassification
|
||||||
|
|
||||||
prompt_guard = {
|
prompt_guard = {
|
||||||
"hardware_config": hardware_config,
|
"device": glb.DEVICE,
|
||||||
"device": device,
|
|
||||||
"model_name": model_name,
|
"model_name": model_name,
|
||||||
"tokenizer": AutoTokenizer.from_pretrained(model_name, trust_remote_code=True),
|
"tokenizer": AutoTokenizer.from_pretrained(model_name, trust_remote_code=True),
|
||||||
"model": model_class.from_pretrained(
|
"model": model_class.from_pretrained(
|
||||||
model_name, device_map=device, low_cpu_mem_usage=True
|
model_name, device_map=glb.DEVICE, low_cpu_mem_usage=True
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -15,11 +15,11 @@ class ArchGuardHanlder:
|
||||||
|
|
||||||
self.threshold = threshold
|
self.threshold = threshold
|
||||||
|
|
||||||
def guard_predict(self, input_text):
|
def guard_predict(self, input_text, max_length=512):
|
||||||
start_time = time.perf_counter()
|
start_time = time.perf_counter()
|
||||||
|
|
||||||
inputs = self.tokenizer(
|
inputs = self.tokenizer(
|
||||||
input_text, truncation=True, max_length=512, return_tensors="pt"
|
input_text, truncation=True, max_length=max_length, return_tensors="pt"
|
||||||
).to(self.device)
|
).to(self.device)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue