Salmanap/fix network agent demo (#153)

* staging my changes to re-based from main

* adding debug statements to rust

* merged with main

* ready to push network agent

* removed the incomplete sql example

---------

Co-authored-by: Salman Paracha <salmanparacha@MacBook-Pro-261.local>
This commit is contained in:
Salman Paracha 2024-10-08 22:19:20 -07:00 committed by GitHub
parent 6acfea7787
commit b63a01fe82
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
41 changed files with 252 additions and 1987 deletions

View file

@ -7,6 +7,7 @@ from optimum.onnxruntime import ORTModelForFeatureExtraction, ORTModelForSequenc
def get_device():
if torch.cuda.is_available():
device = "cuda"
elif torch.backends.mps.is_available():
@ -14,10 +15,12 @@ def get_device():
else:
device = "cpu"
print(f"Devices Avialble: {device}")
return device
def load_transformers(model_name=os.getenv("MODELS", "katanemo/bge-large-en-v1.5-onnx")):
print("Loading Embedding Model")
transformers = {}
device = get_device()
transformers["tokenizer"] = AutoTokenizer.from_pretrained(model_name)
@ -33,6 +36,7 @@ def load_guard_model(
model_name,
hardware_config="cpu",
):
print("Loading Guard Model")
guard_model = {}
guard_model["tokenizer"] = AutoTokenizer.from_pretrained(
model_name, trust_remote_code=True
@ -58,9 +62,7 @@ def load_guard_model(
return guard_model
def load_zero_shot_models(
model_name=os.getenv("ZERO_SHOT_MODELS", "katanemo/deberta-base-nli-onnx")
):
def load_zero_shot_models(model_name=os.getenv("ZERO_SHOT_MODELS", "katanemo/deberta-base-nli-onnx")):
zero_shot_model = {}
device = get_device()
zero_shot_model["model"] = ORTModelForSequenceClassification.from_pretrained(
@ -79,6 +81,5 @@ def load_zero_shot_models(
return zero_shot_model
if __name__ == "__main__":
print(get_device())