Merge branch 'main' of https://github.com/katanemo/arch into cotran/hallu-fix

This commit is contained in:
cotran 2024-10-15 11:25:58 -07:00
commit b8c6bd73af
43 changed files with 865 additions and 644 deletions

View file

@ -1,107 +0,0 @@
import sys
import os
import time
import requests
import psutil
import tempfile
import subprocess
# Path to the file where the server process ID will be stored.
# The file lives in the system temporary directory; start_server() writes the
# PID of the spawned Uvicorn process to it and stop_server() reads it back.
PID_FILE = os.path.join(tempfile.gettempdir(), "model_server.pid")
def run_server():
    """Dispatch to start/stop/restart based on the first command-line argument.

    With no argument the action defaults to "start". An unrecognised action
    is reported on stdout and the process exits with status 1.
    """
    # Fall back to "start" when the CLI was invoked without an action.
    action = sys.argv[1] if len(sys.argv) > 1 else "start"

    if action == "start":
        start_server()
        return
    if action == "stop":
        stop_server()
        return
    if action == "restart":
        restart_server()
        return

    print(f"Unknown action: {action}")
    sys.exit(1)
def start_server():
    """Start the Uvicorn server in the background and record its process ID.

    Exits with status 1 when the PID file points at a process that is still
    alive. A PID file whose process no longer exists (or whose content cannot
    be read as an integer) is treated as stale and removed, so a crashed
    server does not block every later start.

    The PID file is written only after the health-check endpoint answers; if
    the server does not come up in time the spawned process is shut down.
    """
    if os.path.exists(PID_FILE):
        try:
            with open(PID_FILE, "r") as f:
                recorded_pid = int(f.read().strip())
        except (OSError, ValueError):
            recorded_pid = None
        if recorded_pid is not None and psutil.pid_exists(recorded_pid):
            print("Server is already running. Use 'model_server restart' to restart it.")
            sys.exit(1)
        print("Found a stale PID file. Cleaning up PID file.")
        os.remove(PID_FILE)

    print(
        "Starting Archgw Model Server - Loading some awesomeness, this may take a little time :)"
    )
    process = subprocess.Popen(
        ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "51000"],
        start_new_session=True,
        stdout=subprocess.DEVNULL,  # Suppress standard output. There is a logger that model_server prints to
        stderr=subprocess.DEVNULL,  # Suppress standard error. There is a logger that model_server prints to
    )

    # 0.0.0.0 is a bind address, not a portable destination; the server
    # listens on all interfaces, so probe it through loopback.
    if wait_for_health_check("http://127.0.0.1:51000/healthz"):
        # Write the process ID to the PID file
        with open(PID_FILE, "w") as f:
            f.write(str(process.pid))
        print(f"Archgw Model Server started with PID {process.pid}")
        return

    print("Archgw Model Server - Didn't Start In Time. Shutting Down")
    process.terminate()
    # Reap the child so it does not linger; escalate if SIGTERM is ignored.
    try:
        process.wait(timeout=10)
    except subprocess.TimeoutExpired:
        process.kill()
        process.wait()
def wait_for_health_check(url, timeout=180, poll_interval=1):
    """Poll ``url`` until it answers with HTTP 200 or ``timeout`` elapses.

    Args:
        url: Health-check endpoint to query with GET.
        timeout: Overall number of seconds to keep polling.
        poll_interval: Seconds to sleep between attempts.

    Returns:
        True as soon as a response with status code 200 is received,
        False if the deadline passes first.
    """
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            # Bound each request so one hung connection cannot outlive the
            # overall deadline.
            response = requests.get(url, timeout=5)
            if response.status_code == 200:
                return True
        except requests.RequestException:
            # Server not reachable yet (refused, timed out, ...): retry.
            pass
        # Sleep after every unsuccessful attempt, including non-200 answers,
        # so the loop never busy-spins.
        time.sleep(poll_interval)
    print("Timed out waiting for Archgw Model Server to respond.")
    return False
def stop_server():
    """Stop the server whose PID is recorded in the PID file.

    Sends a terminate request, waits up to 10 seconds, and falls back to a
    forced kill. The PID file is removed in every handled outcome; an
    unreadable PID file is reported and removed instead of raising.
    """
    if not os.path.exists(PID_FILE):
        print("Status: Archgw Model Server not running")
        return

    # Read the process ID from the PID file
    try:
        with open(PID_FILE, "r") as f:
            pid = int(f.read().strip())
    except ValueError:
        print("PID file does not contain a valid process ID. Cleaning up PID file.")
        os.remove(PID_FILE)
        return

    try:
        # Get process by PID
        process = psutil.Process(pid)
        # Gracefully terminate the process
        process.terminate()  # Sends SIGTERM by default
        process.wait(timeout=10)  # Wait for up to 10 seconds for the process to exit
        print(f"Server with PID {pid} stopped.")
        os.remove(PID_FILE)
    except psutil.NoSuchProcess:
        print(f"Process with PID {pid} not found. Cleaning up PID file.")
        os.remove(PID_FILE)
    except psutil.TimeoutExpired:
        print(f"Process with PID {pid} did not terminate in time. Forcing shutdown.")
        try:
            process.kill()  # Forcefully kill the process
            # Wait for the kill to take effect so a following start does not
            # race with the old process still holding the port.
            process.wait(timeout=10)
        except psutil.NoSuchProcess:
            # The process exited between the timeout and the kill.
            pass
        except psutil.TimeoutExpired:
            print(f"Process with PID {pid} is still shutting down.")
        os.remove(PID_FILE)
def restart_server():
    """Stop the server if it is running, then start it again."""
    print("Check: Is Archgw Model Server running?")
    # Run the two phases in order: shut down first, then bring up.
    for phase in (stop_server, start_server):
        phase()

119
model_server/app/cli.py Normal file
View file

@ -0,0 +1,119 @@
import sys
import os
import time
import requests
import psutil
import tempfile
import subprocess
import logging
# Configure root logging for the CLI: INFO level, with each record prefixed
# by timestamp, logger name and level.
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
# Logger used by every command in this module.
log = logging.getLogger("model_server.cli")
log.setLevel(logging.INFO)
# Path to the file where the server process ID will be stored.
# The file lives in the system temporary directory; start_server() writes the
# PID of the spawned Uvicorn process to it and stop_server() reads it back.
PID_FILE = os.path.join(tempfile.gettempdir(), "model_server.pid")
def run_server():
    """Run the CLI action named by the first command-line argument.

    The action defaults to "start" when no argument is given. An unknown
    action is logged and the process exits with status 1.
    """
    action = sys.argv[1] if len(sys.argv) > 1 else "start"

    # Map each supported action name to the function that implements it.
    handlers = {
        "start": start_server,
        "stop": stop_server,
        "restart": restart_server,
    }
    handler = handlers.get(action)
    if handler is None:
        log.info(f"Unknown action: {action}")
        sys.exit(1)
    handler()
def start_server():
    """Start the Uvicorn server in the background and record its process ID.

    Exits with status 1 when the PID file points at a process that is still
    alive. A PID file whose process no longer exists (or whose content cannot
    be read as an integer) is treated as stale and removed, so a crashed
    server does not block every later start.

    The server's stdout/stderr are redirected to a boot log file in the
    system temporary directory. The PID file is written only after the
    health-check endpoint answers; if the server does not come up in time
    the spawned process is shut down and its boot log is logged.
    """
    if os.path.exists(PID_FILE):
        try:
            with open(PID_FILE, "r") as f:
                recorded_pid = int(f.read().strip())
        except (OSError, ValueError):
            recorded_pid = None
        if recorded_pid is not None and psutil.pid_exists(recorded_pid):
            log.info("Server is already running. Use 'model_server restart' to restart it.")
            sys.exit(1)
        log.info("Found a stale process id file. Cleaning up PID file.")
        os.remove(PID_FILE)

    log.info(
        "Starting model server - loading some awesomeness, this may take some time :)"
    )

    # Send the child's output to a file rather than to pipes nobody reads:
    # an undrained pipe can fill up and block the server, and it breaks as
    # soon as this CLI process exits while the server keeps running.
    boot_log_path = os.path.join(tempfile.gettempdir(), "model_server_boot.log")
    with open(boot_log_path, "w") as boot_log:
        process = subprocess.Popen(
            ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "51000"],
            start_new_session=True,
            stdout=boot_log,
            stderr=subprocess.STDOUT,
        )

    # 0.0.0.0 is a bind address, not a portable destination; the server
    # listens on all interfaces, so probe it through loopback.
    if wait_for_health_check("http://127.0.0.1:51000/healthz"):
        # Write the process ID to the PID file
        with open(PID_FILE, "w") as f:
            f.write(str(process.pid))
        log.info(f"Model server started with PID {process.pid}")
        return

    log.info("Model server - Didn't Start In Time. Shutting Down")
    process.terminate()
    # Reap the child so it does not linger; escalate if SIGTERM is ignored.
    try:
        process.wait(timeout=10)
    except subprocess.TimeoutExpired:
        process.kill()
        process.wait()

    # Surface whatever the server printed while booting to aid diagnosis.
    try:
        with open(boot_log_path, "r", errors="replace") as boot_log:
            boot_output = boot_log.read().strip()
    except OSError:
        boot_output = ""
    if boot_output:
        log.info("Model server boot-up logs:\n%s", boot_output)
def wait_for_health_check(url, timeout=180, poll_interval=1):
    """Poll ``url`` until it answers with HTTP 200 or ``timeout`` elapses.

    Args:
        url: Health-check endpoint to query with GET.
        timeout: Overall number of seconds to keep polling.
        poll_interval: Seconds to sleep between attempts.

    Returns:
        True as soon as a response with status code 200 is received,
        False if the deadline passes first.
    """
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            # Bound each request so one hung connection cannot outlive the
            # overall deadline.
            response = requests.get(url, timeout=5)
            if response.status_code == 200:
                return True
        except requests.RequestException:
            # Server not reachable yet (refused, timed out, ...): retry.
            pass
        # Sleep after every unsuccessful attempt, including non-200 answers,
        # so the loop never busy-spins.
        time.sleep(poll_interval)
    # Report through the module logger like every other message in this CLI.
    log.info("Timed out waiting for model server to respond.")
    return False
def stop_server():
    """Stop the server whose PID is recorded in the PID file.

    Sends a terminate request, waits up to 10 seconds, and falls back to a
    forced kill. The PID file is removed in every handled outcome; an
    unreadable PID file is reported and removed instead of raising.
    """
    log.info("Stopping model server")
    if not os.path.exists(PID_FILE):
        log.info("Process id file not found, seems like model server was not running")
        return

    # Read the process ID from the PID file
    try:
        with open(PID_FILE, "r") as f:
            pid = int(f.read().strip())
    except ValueError:
        log.info("Process id file does not contain a valid PID. Cleaning up PID file.")
        os.remove(PID_FILE)
        return

    try:
        # Get process by PID
        process = psutil.Process(pid)
        # Gracefully terminate the process
        process.terminate()  # Sends SIGTERM by default
        process.wait(timeout=10)  # Wait for up to 10 seconds for the process to exit
        log.info(f"Model server with PID {pid} stopped.")
        os.remove(PID_FILE)
    except psutil.NoSuchProcess:
        log.info(f"Model server with PID {pid} not found. Cleaning up PID file.")
        os.remove(PID_FILE)
    except psutil.TimeoutExpired:
        log.info(
            f"Model server with PID {pid} did not terminate in time. Forcing shutdown."
        )
        try:
            process.kill()  # Forcefully kill the process
            # Wait for the kill to take effect so a following start does not
            # race with the old process still holding the port.
            process.wait(timeout=10)
        except psutil.NoSuchProcess:
            # The process exited between the timeout and the kill.
            pass
        except psutil.TimeoutExpired:
            log.info(f"Model server with PID {pid} is still shutting down.")
        os.remove(PID_FILE)
def restart_server():
    """Stop the model server if it is running, then start it again."""
    # Run the two phases in order: shut down first, then bring up.
    for phase in (stop_server, start_server):
        phase()

View file

@ -5,6 +5,7 @@ import app.loader as loader
from app.function_calling.model_handler import ArchFunctionHandler
from app.prompt_guard.model_handler import ArchGuardHanlder
logger = utils.get_model_server_logger()
arch_function_hanlder = ArchFunctionHandler()
arch_function_endpoint = "https://api.fc.archgw.com/v1"
@ -19,7 +20,6 @@ arch_function_generation_params = {
arch_guard_model_type = {"cpu": "katanemo/Arch-Guard-cpu", "gpu": "katanemo/Arch-Guard"}
# Model definition
embedding_model = loader.get_embedding_model()
zero_shot_model = loader.get_zero_shot_model()

View file

@ -6,12 +6,15 @@ from optimum.onnxruntime import (
ORTModelForFeatureExtraction,
ORTModelForSequenceClassification,
)
import app.commons.utilities as utils
logger = utils.get_model_server_logger()
def get_embedding_model(
model_name=os.getenv("MODELS", "katanemo/bge-large-en-v1.5"),
):
print("Loading Embedding Model...")
logger.info("Loading Embedding Model...")
if glb.DEVICE != "cuda":
model = ORTModelForFeatureExtraction.from_pretrained(
@ -32,7 +35,7 @@ def get_embedding_model(
def get_zero_shot_model(
model_name=os.getenv("ZERO_SHOT_MODELS", "katanemo/bart-large-mnli"),
):
print("Loading Zero-shot Model...")
logger.info("Loading Zero-shot Model...")
if glb.DEVICE != "cuda":
model = ORTModelForSequenceClassification.from_pretrained(
@ -58,7 +61,7 @@ def get_zero_shot_model(
def get_prompt_guard(model_name, hardware_config="cpu"):
print("Loading Guard Model...")
logger.info("Loading Guard Model...")
if hardware_config == "cpu":
from optimum.intel import OVModelForSequenceClassification

View file

@ -4,7 +4,6 @@ import app.commons.utilities as utils
import app.commons.globals as glb
import app.prompt_guard.model_utils as guard_utils
from typing import List, Dict
from pydantic import BaseModel
from fastapi import FastAPI, Response, HTTPException
@ -17,8 +16,7 @@ from app.function_calling.model_utils import (
logger = utils.get_model_server_logger()
logger.info(f"Devices Avialble: {glb.DEVICE}")
logger.info(f"Ready to serve traffic. available device: {glb.DEVICE}")
app = FastAPI()