mirror of
https://github.com/katanemo/plano.git
synced 2026-04-29 19:06:34 +02:00
Refactor model server hardware config + add unit tests to load/request to the server (#189)
* remove mode/hardware * add test and pre-commit hook * add pytest dependencies * fix format * fix lint * fix precommit * fix pre commit * fix pre commit * fix pre commit * fix precommit * fix precommit * fix precommit * fix precommit * fix precommit * fix precommit * fix precommit * fix precommit * fix precommit * fix precommit
This commit is contained in:
parent
3bd2ffe9fb
commit
8e54ac20d8
13 changed files with 480 additions and 43 deletions
|
|
@@ -7,6 +7,10 @@ from optimum.onnxruntime import (
|
|||
ORTModelForSequenceClassification,
|
||||
)
|
||||
import app.commons.utilities as utils
|
||||
import torch
|
||||
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
||||
from optimum.intel import OVModelForSequenceClassification
|
||||
|
||||
|
||||
logger = utils.get_model_server_logger()
|
||||
|
||||
|
|
@@ -60,28 +64,20 @@ def get_zero_shot_model(
|
|||
return zero_shot_model
|
||||
|
||||
|
||||
def get_prompt_guard(model_name, hardware_config="cpu"):
|
||||
def get_prompt_guard(model_name):
|
||||
logger.info("Loading Guard Model...")
|
||||
|
||||
if hardware_config == "cpu":
|
||||
from optimum.intel import OVModelForSequenceClassification
|
||||
|
||||
device = "cpu"
|
||||
if glb.DEVICE == "cpu":
|
||||
model_class = OVModelForSequenceClassification
|
||||
elif hardware_config == "gpu":
|
||||
import torch
|
||||
from transformers import AutoModelForSequenceClassification
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
else:
|
||||
model_class = AutoModelForSequenceClassification
|
||||
|
||||
prompt_guard = {
|
||||
"hardware_config": hardware_config,
|
||||
"device": device,
|
||||
"device": glb.DEVICE,
|
||||
"model_name": model_name,
|
||||
"tokenizer": AutoTokenizer.from_pretrained(model_name, trust_remote_code=True),
|
||||
"model": model_class.from_pretrained(
|
||||
model_name, device_map=device, low_cpu_mem_usage=True
|
||||
model_name, device_map=glb.DEVICE, low_cpu_mem_usage=True
|
||||
),
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue