lint + formatting with black (#158)

* lint + formatting with black

* add black as a pre-commit hook
This commit is contained in:
Co Tran 2024-10-09 11:25:07 -07:00 committed by GitHub
parent 498e7f9724
commit 5c4a6bc8ff
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
22 changed files with 581 additions and 295 deletions

View file

@ -35,11 +35,13 @@ def start_server():
print("Server is already running. Use 'model_server restart' to restart it.")
sys.exit(1)
print(f"Starting Archgw Model Server - Loading some awesomeness, this may take a little time.)")
print(
f"Starting Archgw Model Server - Loading some awesomeness, this may take a little time.)"
)
process = subprocess.Popen(
["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "51000"],
start_new_session=True,
stdout=subprocess.DEVNULL, # Suppress standard output. There is a logger that model_server prints to
stdout=subprocess.DEVNULL, # Suppress standard output. There is a logger that model_server prints to
stderr=subprocess.DEVNULL, # Suppress standard error. There is a logger that model_server prints to
)

View file

@ -7,7 +7,6 @@ from optimum.onnxruntime import ORTModelForFeatureExtraction, ORTModelForSequenc
def get_device():
if torch.cuda.is_available():
device = "cuda"
elif torch.backends.mps.is_available():
@ -19,13 +18,15 @@ def get_device():
return device
def load_transformers(model_name=os.getenv("MODELS", "katanemo/bge-large-en-v1.5-onnx")):
def load_transformers(
model_name=os.getenv("MODELS", "katanemo/bge-large-en-v1.5-onnx")
):
print("Loading Embedding Model")
transformers = {}
device = get_device()
transformers["tokenizer"] = AutoTokenizer.from_pretrained(model_name)
transformers["model"] = ORTModelForFeatureExtraction.from_pretrained(
model_name, device_map = device
model_name, device_map=device
)
transformers["model_name"] = model_name
@ -62,7 +63,9 @@ def load_guard_model(
return guard_model
def load_zero_shot_models(model_name=os.getenv("ZERO_SHOT_MODELS", "katanemo/deberta-base-nli-onnx")):
def load_zero_shot_models(
model_name=os.getenv("ZERO_SHOT_MODELS", "katanemo/deberta-base-nli-onnx")
):
zero_shot_model = {}
device = get_device()
zero_shot_model["model"] = ORTModelForSequenceClassification.from_pretrained(
@ -81,5 +84,6 @@ def load_zero_shot_models(model_name=os.getenv("ZERO_SHOT_MODELS", "katanemo/deb
return zero_shot_model
if __name__ == "__main__":
print(get_device())

View file

@ -7,7 +7,12 @@ from app.load_models import (
get_device,
)
import os
from app.utils import GuardHandler, split_text_into_chunks, load_yaml_config, get_model_server_logger
from app.utils import (
GuardHandler,
split_text_into_chunks,
load_yaml_config,
get_model_server_logger,
)
import torch
import yaml
import string
@ -39,6 +44,7 @@ guard_handler = GuardHandler(toxic_model=None, jailbreak_model=jailbreak_model)
app = FastAPI()
class EmbeddingRequest(BaseModel):
input: str
model: str
@ -84,6 +90,7 @@ async def embedding(req: EmbeddingRequest, res: Response):
}
return {"data": data, "model": req.model, "object": "list", "usage": usage}
class GuardRequest(BaseModel):
input: str
task: str

View file

@ -9,6 +9,7 @@ import logging
logger_instance = None
def load_yaml_config(file_name):
# Load the YAML file from the package
yaml_path = pkg_resources.resource_filename("app", file_name)
@ -138,6 +139,7 @@ class GuardHandler:
}
return result_dict
def get_model_server_logger():
global logger_instance
@ -164,8 +166,8 @@ def get_model_server_logger():
level=logging.INFO,
format="%(asctime)s - %(levelname)s - %(message)s",
handlers=[
logging.FileHandler(log_file_path, mode='w'), # Overwrite logs in file
]
logging.FileHandler(log_file_path, mode="w"), # Overwrite logs in file
],
)
except (PermissionError, OSError) as e:
# Don't fall back to console logging if there are issues writing to the log file