From 701187474fc69fc75b48110e2825ceb533a74d65 Mon Sep 17 00:00:00 2001
From: Salman Paracha
Date: Fri, 4 Oct 2024 13:09:35 -0700
Subject: [PATCH] load_models checks for device before getting the BGE or NLI model loaded in memory. Was defaulting to CPU. And removed gunk for load_sql (#119)

Co-authored-by: Salman Paracha
---
 model_server/app/load_models.py | 45 +++++++++++++--------------------
 1 file changed, 17 insertions(+), 28 deletions(-)

diff --git a/model_server/app/load_models.py b/model_server/app/load_models.py
index 673614a7..9fbafc9e 100644
--- a/model_server/app/load_models.py
+++ b/model_server/app/load_models.py
@@ -2,23 +2,28 @@ import os
 import sentence_transformers
 from transformers import AutoTokenizer, pipeline
 import sqlite3
-from app.employee_data_generator import generate_employee_data
-from app.network_data_generator import (
-    generate_device_data,
-    generate_interface_stats_data,
-    generate_flow_data,
-)
+import torch
 
+def get_device():
+    if torch.cuda.is_available():
+        device = "cuda"
+    elif torch.backends.mps.is_available():
+        device = "mps"
+    else:
+        device = "cpu"
+
+    return device
 
 def load_transformers(models=os.getenv("MODELS", "BAAI/bge-large-en-v1.5")):
     transformers = {}
 
+    device = get_device()
+    print(f"Using device: {device}")
     for model in models.split(","):
-        transformers[model] = sentence_transformers.SentenceTransformer(model)
+        transformers[model] = sentence_transformers.SentenceTransformer(model, device=device)
 
     return transformers
 
-
 def load_guard_model(
     model_name,
     hardware_config="cpu",
@@ -52,27 +57,11 @@ def load_zero_shot_models(
     models=os.getenv("ZERO_SHOT_MODELS", "tasksource/deberta-base-long-nli")
 ):
     zero_shot_models = {}
-
+    device = get_device()
     for model in models.split(","):
-        zero_shot_models[model] = pipeline("zero-shot-classification", model=model)
+        zero_shot_models[model] = pipeline("zero-shot-classification", model=model, device=device)
 
     return zero_shot_models
 
-
-def load_sql():
-    # Example Usage
-    conn = sqlite3.connect(":memory:")
-
-    # create and load the employees table
-    generate_employee_data(conn)
-
-    # create and load the devices table
-    device_data = generate_device_data(conn)
-
-    # create and load the interface_stats table
-    generate_interface_stats_data(conn, device_data)
-
-    # create and load the flow table
-    generate_flow_data(conn, device_data)
-
-    return conn
+if __name__ =="__main__":
+    print(get_device())