Update cli and related utils

This commit is contained in:
Shuguang Chen 2024-12-08 11:16:34 -08:00
parent b4f4695f16
commit 320f4612b8
6 changed files with 267 additions and 215 deletions

View file

@ -1,52 +1,22 @@
import importlib
import sys
import time
import requests
import subprocess
import logging
import argparse
def get_version():
try:
version = importlib.metadata.version("archgw_modelserver")
return version
except importlib.metadata.PackageNotFoundError:
return "version not found"
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
from src.commons.globals import logger
from src.commons.utils import (
get_version,
wait_for_health_check,
check_lsof,
install_lsof,
find_process_by_port,
kill_process_by_port,
)
log = logging.getLogger("model_server.cli")
log.setLevel(logging.INFO)
log.info(f"model server version: {get_version()}")
def run_server(port=51000):
"""Start, stop, or restart the Uvicorn server based on command-line arguments."""
if len(sys.argv) > 1:
action = sys.argv[1]
else:
action = "start"
if action == "start":
start_server(port)
elif action == "stop":
stop_server(port)
elif action == "restart":
restart_server(port)
else:
log.info(f"Unknown action: {action}")
sys.exit(1)
def start_server(port=51000):
"""Start the Uvicorn server"""
log.info(
"starting model server - loading some awesomeness, this may take some time :)"
)
"""Start the Uvicorn server."""
logger.info("Starting model server - loading some awesomeness, please wait...")
process = subprocess.Popen(
[
@ -57,119 +27,82 @@ def start_server(port=51000):
"--host",
"0.0.0.0",
"--port",
f"{port}",
str(port),
],
start_new_session=True,
bufsize=1,
universal_newlines=True,
stdout=subprocess.PIPE, # Suppress standard output. There is a logger that model_server prints to
stderr=subprocess.PIPE, # Suppress standard error. There is a logger that model_server prints to
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
if wait_for_health_check(f"http://0.0.0.0:{port}/healthz"):
log.info(f"Model server started with PID {process.pid}")
logger.info(f"Model server started successfully with PID {process.pid}.")
else:
# Add model_server boot-up logs
log.info("model server - didn't start in time, shutting down")
logger.error("Model server failed to start in time, shutting it down.")
process.terminate()
def wait_for_health_check(url, timeout=300):
"""Wait for the Uvicorn server to respond to health-check requests."""
start_time = time.time()
while time.time() - start_time < timeout:
try:
response = requests.get(url)
if response.status_code == 200:
return True
except requests.ConnectionError:
time.sleep(1)
print("Timed out waiting for model server to respond.")
return False
def check_and_install_lsof():
"""Check if lsof is installed, and if not, install it using apt-get."""
try:
# Check if lsof is installed by running "lsof -v"
subprocess.run(
["lsof", "-v"], check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE
)
print("lsof is already installed.")
except subprocess.CalledProcessError:
print("lsof not found, installing...")
try:
# Update package list and install lsof
subprocess.run(["sudo", "apt-get", "update"], check=True)
subprocess.run(["sudo", "apt-get", "install", "-y", "lsof"], check=True)
print("lsof installed successfully.")
except subprocess.CalledProcessError as install_error:
print(f"Failed to install lsof: {install_error}")
def kill_process(port=51000, wait=True, timeout=10):
"""Stop the running Uvicorn server."""
log.info("Stopping model server")
try:
# Run the function to check and install lsof if necessary
# Step 1: Run lsof command to get the process using the port
lsof_command = f"lsof -n | grep {port} | grep -i LISTEN"
result = subprocess.run(
lsof_command, shell=True, capture_output=True, text=True
)
if result.returncode != 0:
print(f"No process found listening on port {port}.")
return
# Step 2: Parse the process IDs from the output
process_ids = [line.split()[1] for line in result.stdout.splitlines()]
if not process_ids:
print(f"No process found listening on port {port}.")
return
# Step 3: Kill each process using its PID
for pid in process_ids:
print(f"Killing model server process with PID {pid}")
subprocess.run(f"kill {pid}", shell=True)
if wait:
# Step 4: Wait for the process to be killed by checking if it's still running
start_time = time.time()
while True:
check_process = subprocess.run(
f"ps -p {pid}", shell=True, capture_output=True, text=True
)
if check_process.returncode != 0:
print(f"Process {pid} has been killed.")
break
elapsed_time = time.time() - start_time
if elapsed_time > timeout:
print(
f"Process {pid} did not terminate within {timeout} seconds."
)
print(f"Attempting to force kill process {pid}...")
subprocess.run(f"kill -9 {pid}", shell=True) # SIGKILL
break
print(
f"Waiting for process {pid} to be killed... ({elapsed_time:.2f} seconds)"
)
time.sleep(0.5)
except Exception as e:
print(f"Error occurred: {e}")
def stop_server(port=51000, wait=True, timeout=10):
check_and_install_lsof()
kill_process(port, wait, timeout)
"""Stop the Uvicorn server."""
if check_lsof():
logger.info("`lsof` is already installed.")
else:
logger.info("`lsof` not found, attempting to install...")
if install_lsof():
logger.info("`lsof` installed successfully.")
else:
logger.error("Failed to install `lsof`.")
sys.exit(1)
logger.info(f"Stopping processes on port {port}...")
port_processes = find_process_by_port(port)
if port_processes is None:
logger.info(f"No processes found listening on port {port}.")
else:
if len(port_processes):
process_killed = kill_process_by_port(port_processes, wait, timeout)
if not process_killed:
logger.error(f"Unable to kill all processes on {port}")
else:
logger.error(f"Unable to find processes on {port}")
def restart_server(port=51000):
"""Restart the Uvicorn server."""
stop_server(port)
start_server(port)
def main():
"""
Start, stop, or restart the Uvicorn server based on command-line arguments.
"""
parser = argparse.ArgumentParser(description="Manage the Uvicorn server.")
parser.add_argument(
"action",
choices=["start", "stop", "restart"],
default="start",
nargs="?",
help="Action to perform on the server (default: start).",
)
parser.add_argument(
"--port",
type=int,
default=51000,
help="Port number for the server (default: 51000).",
)
args = parser.parse_args()
logger.info(f"Model server version: {get_version()}")
if args.action == "start":
start_server(args.port)
elif args.action == "stop":
stop_server(args.port)
elif args.action == "restart":
restart_server(args.port)
else:
logger.error(f"Unknown action: {args.action}")
sys.exit(1)

View file

@ -1,13 +1,11 @@
import src.commons.utilities as utils
from openai import OpenAI
from src.commons.constants import *
from src.core.function_calling import ArchIntentHandler, ArchFunctionHandler
from src.core.guardrails import get_guardrail_handler
from src.commons.utils import get_model_server_logger
logger = utils.get_model_server_logger()
logger = get_model_server_logger()
# Define the client
ARCH_ENDPOINT = "https://api.fc.archgw.com/v1"

View file

@ -1,65 +0,0 @@
import os
import torch
import logging
logger_instance = None
def get_device():
available_device = {
"cpu": True,
"cuda": torch.cuda.is_available(),
"mps": (
torch.backends.mps.is_available()
if hasattr(torch.backends, "mps")
else False
),
}
if available_device["cuda"]:
device = "cuda"
elif available_device["mps"]:
device = "mps"
else:
device = "cpu"
return device
def get_model_server_logger():
global logger_instance
if logger_instance is not None:
# If the logger is already initialized, return the existing instance
return logger_instance
# Define log file path outside current directory (e.g., ~/archgw_logs)
log_dir = os.path.expanduser("~/archgw_logs")
log_file = "modelserver.log"
log_file_path = os.path.join(log_dir, log_file)
# Ensure the log directory exists, create it if necessary, handle permissions errors
try:
if not os.path.exists(log_dir):
os.makedirs(log_dir, exist_ok=True) # Create directory if it doesn't exist
# Check if the script has write permission in the log directory
if not os.access(log_dir, os.W_OK):
raise PermissionError(f"No write permission for the directory: {log_dir}")
# Configure logging to file and console using basicConfig
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(levelname)s - %(message)s",
handlers=[
logging.FileHandler(log_file_path, mode="w"), # Overwrite logs in file
],
)
except (PermissionError, OSError):
# Dont' fallback to console logging if there are issues writing to the log file
raise RuntimeError(f"No write permission for the directory: {log_dir}")
# Initialize the logger instance after configuring handlers
logger_instance = logging.getLogger("model_server_logger")
return logger_instance

View file

@ -0,0 +1,186 @@
import os
import sys
import time
import torch
import logging
import requests
import subprocess
import importlib
PROJ_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# Default log directory and file
DEFAULT_LOG_DIR = os.path.join(PROJ_DIR, "logs")
DEFAULT_LOG_FILE = "modelserver.log"
def get_version():
try:
version = importlib.metadata.version("archgw_modelserver")
return version
except importlib.metadata.PackageNotFoundError:
return "version not found"
def get_device():
available_device = {
"cpu": True,
"cuda": torch.cuda.is_available(),
"mps": (
torch.backends.mps.is_available()
if hasattr(torch.backends, "mps")
else False
),
}
if available_device["cuda"]:
device = "cuda"
elif available_device["mps"]:
device = "mps"
else:
device = "cpu"
return device
def get_model_server_logger(log_dir=None, log_file=None):
"""
Get or initialize the logger instance for the model server.
Parameters:
- log_dir (str): Custom directory to store the log file. Defaults to `~/archgw_logs`.
- log_file (str): Custom log file name. Defaults to `modelserver.log`.
Returns:
- logging.Logger: Configured logger instance.
"""
log_dir = log_dir or DEFAULT_LOG_DIR
log_file = log_file or DEFAULT_LOG_FILE
log_file_path = os.path.join(log_dir, log_file)
# Check if the logger is already configured
logger = logging.getLogger("model_server_logger")
if logger.hasHandlers():
# Return existing logger instance if already configured
return logger
# Ensure the log directory exists, create it if necessary
try:
# Create directory if it doesn't exist
os.makedirs(log_dir, exist_ok=True)
# Check for write permissions
if not os.access(log_dir, os.W_OK):
raise PermissionError(f"No write permission for the directory: {log_dir}")
except (PermissionError, OSError) as e:
raise RuntimeError(f"Failed to initialize logger: {e}")
# Configure logging to file
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(levelname)s - %(message)s",
handlers=[
logging.FileHandler(log_file_path, mode="w"), # Overwrite logs in the file
logging.StreamHandler(), # Also log to console
],
)
return logger
def wait_for_health_check(url, timeout=300):
"""Wait for the Uvicorn server to respond to health-check requests."""
start_time = time.time()
while time.time() - start_time < timeout:
try:
response = requests.get(url)
if response.status_code == 200:
return True
except requests.ConnectionError:
time.sleep(1)
return False
def check_lsof():
"""Check if lsof is installed or not"""
try:
subprocess.run(
["lsof", "-v"], check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE
)
return True
except subprocess.CalledProcessError:
return False
def install_lsof():
"""Install lsof using apt-get."""
try:
subprocess.run(["sudo", "apt-get", "update"], check=True)
subprocess.run(["sudo", "apt-get", "install", "-y", "lsof"], check=True)
return True
except subprocess.CalledProcessError:
return False
sys.exit(1)
def terminate_process_by_pid(pid, timeout):
"""Terminate a process to terminate."""
start_time = time.time()
while time.time() - start_time < timeout:
result = subprocess.run(
["ps", "-p", str(pid)], stdout=subprocess.PIPE, stderr=subprocess.PIPE
)
if result.returncode != 0:
print.info(f"Process {pid} terminated successfully.")
return
time.sleep(0.5)
print.warning(
f"Process {pid} did not terminate within {timeout} seconds. Force killing..."
)
subprocess.run(["kill", "-9", str(pid)], check=False)
def find_process_by_port(port=51000):
"""Find processes listening on a specific port."""
port_processes = []
try:
lsof_command = f"lsof -n -i:{port} | grep LISTEN"
result = subprocess.run(
lsof_command, shell=True, capture_output=True, text=True
)
if result.returncode != 0 or not result.stdout.strip():
return None
else:
port_processes = result.stdout.splitlines()
return port_processes
except Exception:
return []
def kill_process_by_port(port_processes=51000, wait=True, timeout=10):
"""Kill processes on a specific port."""
try:
# Extract process IDs from lsof output
process_ids = [line.split()[1] for line in port_processes]
for pid in process_ids:
print(f"Killing process with PID {pid}...")
subprocess.run(["kill", pid], check=False)
if wait:
terminate_process_by_pid(pid, timeout)
return True
except Exception:
return False

View file

@ -1,12 +1,12 @@
import time
import torch
import numpy as np
import src.commons.utilities as utils
import src.commons.utils as utils
from typing import List
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from optimum.intel import OVModelForSequenceClassification
from typing import List
class GuardRequest(BaseModel):