formating and mointoring change (#136)

This commit is contained in:
Co Tran 2024-10-07 15:21:05 -07:00 committed by GitHub
parent 976b2eaae0
commit 93abe553e3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 83 additions and 47 deletions

View file

@ -4,6 +4,7 @@ from transformers import AutoTokenizer, pipeline
import sqlite3
import torch
def get_device():
if torch.cuda.is_available():
device = "cuda"
@ -14,14 +15,18 @@ def get_device():
return device
def load_transformers(models=os.getenv("MODELS", "BAAI/bge-large-en-v1.5")):
transformers = {}
device = get_device()
for model in models.split(","):
transformers[model] = sentence_transformers.SentenceTransformer(model, device=device)
transformers[model] = sentence_transformers.SentenceTransformer(
model, device=device
)
return transformers
def load_guard_model(
model_name,
hardware_config="cpu",
@ -57,9 +62,12 @@ def load_zero_shot_models(
zero_shot_models = {}
device = get_device()
for model in models.split(","):
zero_shot_models[model] = pipeline("zero-shot-classification", model=model, device=device)
zero_shot_models[model] = pipeline(
"zero-shot-classification", model=model, device=device
)
return zero_shot_models
if __name__ =="__main__":
if __name__ == "__main__":
print(get_device())