mirror of
https://github.com/katanemo/plano.git
synced 2026-04-25 00:36:34 +02:00
88 lines
2.4 KiB
Python
88 lines
2.4 KiB
Python
import os
|
|
import app.commons.globals as glb
|
|
|
|
from transformers import AutoTokenizer, AutoModel, pipeline
|
|
from optimum.onnxruntime import (
|
|
ORTModelForFeatureExtraction,
|
|
ORTModelForSequenceClassification,
|
|
)
|
|
import app.commons.utilities as utils
|
|
|
|
# Module-level logger shared by the loader functions in this file,
# obtained via utils.get_model_server_logger().
logger = utils.get_model_server_logger()
|
|
|
|
|
|
def get_embedding_model(
    model_name=os.getenv("MODELS", "katanemo/bge-large-en-v1.5"),
):
    """Load the embedding model and its tokenizer.

    The backend depends on ``glb.DEVICE``: when it is ``"cuda"`` the model is
    loaded with ``AutoModel`` and ``device_map=glb.DEVICE``; for any other
    device the ONNX export at ``onnx/model.onnx`` is loaded through
    ``ORTModelForFeatureExtraction``.

    Args:
        model_name: Model identifier passed to ``from_pretrained``. The
            default is read from the ``MODELS`` environment variable once,
            when this function is defined (not on every call).

    Returns:
        A dict with the keys ``"model_name"``, ``"tokenizer"`` and ``"model"``.
    """
    logger.info("Loading Embedding Model...")

    if glb.DEVICE == "cuda":
        encoder = AutoModel.from_pretrained(model_name, device_map=glb.DEVICE)
    else:
        encoder = ORTModelForFeatureExtraction.from_pretrained(
            model_name, file_name="onnx/model.onnx"
        )

    # The tokenizer is loaded after the model, same as the dict-literal order
    # the callers have always observed.
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

    return {
        "model_name": model_name,
        "tokenizer": tokenizer,
        "model": encoder,
    }
|
|
|
|
|
|
def get_zero_shot_model(
    model_name=os.getenv("ZERO_SHOT_MODELS", "katanemo/bart-large-mnli"),
):
    """Load the zero-shot classification model, tokenizer and pipeline.

    When ``glb.DEVICE`` is not ``"cuda"`` the ONNX export at
    ``onnx/model.onnx`` is loaded through
    ``ORTModelForSequenceClassification``. On ``"cuda"`` no model object is
    created here: the model name itself is handed to ``pipeline``.

    Args:
        model_name: Model identifier passed to ``from_pretrained``. The
            default is read from the ``ZERO_SHOT_MODELS`` environment
            variable once, when this function is defined.

    Returns:
        A dict with the keys ``"model_name"``, ``"tokenizer"``, ``"model"``
        and ``"pipeline"`` (a ``"zero-shot-classification"`` pipeline built
        with ``device=glb.DEVICE``).
    """
    logger.info("Loading Zero-shot Model...")

    classifier = (
        model_name
        if glb.DEVICE == "cuda"
        else ORTModelForSequenceClassification.from_pretrained(
            model_name, file_name="onnx/model.onnx"
        )
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    zero_shot_pipeline = pipeline(
        "zero-shot-classification",
        model=classifier,
        tokenizer=tokenizer,
        device=glb.DEVICE,
    )

    return {
        "model_name": model_name,
        "tokenizer": tokenizer,
        "model": classifier,
        "pipeline": zero_shot_pipeline,
    }
|
|
|
|
|
|
def get_prompt_guard(model_name, hardware_config="cpu"):
    """Load the prompt-guard classifier and its tokenizer.

    Args:
        model_name: Model identifier passed to ``from_pretrained``.
        hardware_config: Either ``"cpu"`` or ``"gpu"``.
            ``"cpu"`` loads the model with
            ``optimum.intel.OVModelForSequenceClassification`` on device
            ``"cpu"``. ``"gpu"`` loads it with
            ``transformers.AutoModelForSequenceClassification`` on ``"cuda"``
            when ``torch.cuda.is_available()`` is true, otherwise on ``"cpu"``.

    Returns:
        A dict with the keys ``"hardware_config"``, ``"device"``,
        ``"model_name"``, ``"tokenizer"`` and ``"model"``.

    Raises:
        ValueError: If ``hardware_config`` is neither ``"cpu"`` nor ``"gpu"``.
    """
    # Reject unknown configurations up front; otherwise ``device`` and
    # ``model_class`` would be left unbound and the failure would surface
    # later as an opaque UnboundLocalError.
    if hardware_config not in ("cpu", "gpu"):
        raise ValueError(
            f"Unsupported hardware_config {hardware_config!r}; "
            "expected 'cpu' or 'gpu'"
        )

    logger.info("Loading Guard Model...")

    # The backend imports stay local so that only the selected backend's
    # package has to be importable.
    if hardware_config == "cpu":
        from optimum.intel import OVModelForSequenceClassification

        device = "cpu"
        model_class = OVModelForSequenceClassification
    else:
        import torch
        from transformers import AutoModelForSequenceClassification

        device = "cuda" if torch.cuda.is_available() else "cpu"
        model_class = AutoModelForSequenceClassification

    prompt_guard = {
        "hardware_config": hardware_config,
        "device": device,
        "model_name": model_name,
        "tokenizer": AutoTokenizer.from_pretrained(model_name, trust_remote_code=True),
        "model": model_class.from_pretrained(
            model_name, device_map=device, low_cpu_mem_usage=True
        ),
    }

    return prompt_guard
|