mirror of
https://github.com/katanemo/plano.git
synced 2026-04-29 19:06:34 +02:00
Improve cli (#179)
This commit is contained in:
parent
ceca0dba28
commit
7d5f760884
18 changed files with 611 additions and 445 deletions
|
|
@@ -6,12 +6,15 @@ from optimum.onnxruntime import (
|
|||
ORTModelForFeatureExtraction,
|
||||
ORTModelForSequenceClassification,
|
||||
)
|
||||
import app.commons.utilities as utils
|
||||
|
||||
logger = utils.get_model_server_logger()
|
||||
|
||||
|
||||
def get_embedding_model(
|
||||
model_name=os.getenv("MODELS", "katanemo/bge-large-en-v1.5"),
|
||||
):
|
||||
print("Loading Embedding Model...")
|
||||
logger.info("Loading Embedding Model...")
|
||||
|
||||
if glb.DEVICE != "cuda":
|
||||
model = ORTModelForFeatureExtraction.from_pretrained(
|
||||
|
|
@@ -32,7 +35,7 @@ def get_embedding_model(
|
|||
def get_zero_shot_model(
|
||||
model_name=os.getenv("ZERO_SHOT_MODELS", "katanemo/bart-large-mnli"),
|
||||
):
|
||||
print("Loading Zero-shot Model...")
|
||||
logger.info("Loading Zero-shot Model...")
|
||||
|
||||
if glb.DEVICE != "cuda":
|
||||
model = ORTModelForSequenceClassification.from_pretrained(
|
||||
|
|
@@ -58,7 +61,7 @@ def get_zero_shot_model(
|
|||
|
||||
|
||||
def get_prompt_guard(model_name, hardware_config="cpu"):
|
||||
print("Loading Guard Model...")
|
||||
logger.info("Loading Guard Model...")
|
||||
|
||||
if hardware_config == "cpu":
|
||||
from optimum.intel import OVModelForSequenceClassification
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue