Improve cli (#179)

This commit is contained in:
Adil Hafeez 2024-10-10 17:44:41 -07:00 committed by GitHub
parent ceca0dba28
commit 7d5f760884
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
18 changed files with 611 additions and 445 deletions

View file

@ -6,12 +6,15 @@ from optimum.onnxruntime import (
ORTModelForFeatureExtraction,
ORTModelForSequenceClassification,
)
import app.commons.utilities as utils
logger = utils.get_model_server_logger()
def get_embedding_model(
model_name=os.getenv("MODELS", "katanemo/bge-large-en-v1.5"),
):
print("Loading Embedding Model...")
logger.info("Loading Embedding Model...")
if glb.DEVICE != "cuda":
model = ORTModelForFeatureExtraction.from_pretrained(
@ -32,7 +35,7 @@ def get_embedding_model(
def get_zero_shot_model(
model_name=os.getenv("ZERO_SHOT_MODELS", "katanemo/bart-large-mnli"),
):
print("Loading Zero-shot Model...")
logger.info("Loading Zero-shot Model...")
if glb.DEVICE != "cuda":
model = ORTModelForSequenceClassification.from_pretrained(
@ -58,7 +61,7 @@ def get_zero_shot_model(
def get_prompt_guard(model_name, hardware_config="cpu"):
print("Loading Guard Model...")
logger.info("Loading Guard Model...")
if hardware_config == "cpu":
from optimum.intel import OVModelForSequenceClassification