lint + formating with black (#158)

* lint + formating with black

* add black as pre commit
This commit is contained in:
Co Tran 2024-10-09 11:25:07 -07:00 committed by GitHub
parent 498e7f9724
commit 5c4a6bc8ff
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
22 changed files with 581 additions and 295 deletions

View file

@ -7,7 +7,6 @@ from optimum.onnxruntime import ORTModelForFeatureExtraction, ORTModelForSequenc
def get_device():
if torch.cuda.is_available():
device = "cuda"
elif torch.backends.mps.is_available():
@ -19,13 +18,15 @@ def get_device():
return device
def load_transformers(model_name=os.getenv("MODELS", "katanemo/bge-large-en-v1.5-onnx")):
def load_transformers(
model_name=os.getenv("MODELS", "katanemo/bge-large-en-v1.5-onnx")
):
print("Loading Embedding Model")
transformers = {}
device = get_device()
transformers["tokenizer"] = AutoTokenizer.from_pretrained(model_name)
transformers["model"] = ORTModelForFeatureExtraction.from_pretrained(
model_name, device_map = device
model_name, device_map=device
)
transformers["model_name"] = model_name
@ -62,7 +63,9 @@ def load_guard_model(
return guard_model
def load_zero_shot_models(model_name=os.getenv("ZERO_SHOT_MODELS", "katanemo/deberta-base-nli-onnx")):
def load_zero_shot_models(
model_name=os.getenv("ZERO_SHOT_MODELS", "katanemo/deberta-base-nli-onnx")
):
zero_shot_model = {}
device = get_device()
zero_shot_model["model"] = ORTModelForSequenceClassification.from_pretrained(
@ -81,5 +84,6 @@ def load_zero_shot_models(model_name=os.getenv("ZERO_SHOT_MODELS", "katanemo/deb
return zero_shot_model
if __name__ == "__main__":
print(get_device())