diff --git a/model_server/src/cli.py b/model_server/src/cli.py index 42b0c341..d863d028 100644 --- a/model_server/src/cli.py +++ b/model_server/src/cli.py @@ -1,52 +1,22 @@ -import importlib import sys -import time -import requests import subprocess -import logging +import argparse - -def get_version(): - try: - version = importlib.metadata.version("archgw_modelserver") - return version - except importlib.metadata.PackageNotFoundError: - return "version not found" - - -logging.basicConfig( - level=logging.INFO, - format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", +from src.commons.globals import logger +from src.commons.utils import ( + get_version, + wait_for_health_check, + check_lsof, + install_lsof, + find_process_by_port, + kill_process_by_port, ) -log = logging.getLogger("model_server.cli") -log.setLevel(logging.INFO) -log.info(f"model server version: {get_version()}") - - -def run_server(port=51000): - """Start, stop, or restart the Uvicorn server based on command-line arguments.""" - if len(sys.argv) > 1: - action = sys.argv[1] - else: - action = "start" - - if action == "start": - start_server(port) - elif action == "stop": - stop_server(port) - elif action == "restart": - restart_server(port) - else: - log.info(f"Unknown action: {action}") - sys.exit(1) - def start_server(port=51000): - """Start the Uvicorn server""" - log.info( - "starting model server - loading some awesomeness, this may take some time :)" - ) + """Start the Uvicorn server.""" + + logger.info("Starting model server - loading some awesomeness, please wait...") process = subprocess.Popen( [ @@ -57,119 +27,82 @@ def start_server(port=51000): "--host", "0.0.0.0", "--port", - f"{port}", + str(port), ], start_new_session=True, bufsize=1, universal_newlines=True, - stdout=subprocess.PIPE, # Suppress standard output. There is a logger that model_server prints to - stderr=subprocess.PIPE, # Suppress standard error. 
There is a logger that model_server prints to + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, ) if wait_for_health_check(f"http://0.0.0.0:{port}/healthz"): - log.info(f"Model server started with PID {process.pid}") + logger.info(f"Model server started successfully with PID {process.pid}.") else: - # Add model_server boot-up logs - log.info("model server - didn't start in time, shutting down") + logger.error("Model server failed to start in time, shutting it down.") process.terminate() -def wait_for_health_check(url, timeout=300): - """Wait for the Uvicorn server to respond to health-check requests.""" - start_time = time.time() - while time.time() - start_time < timeout: - try: - response = requests.get(url) - if response.status_code == 200: - return True - except requests.ConnectionError: - time.sleep(1) - print("Timed out waiting for model server to respond.") - return False - - -def check_and_install_lsof(): - """Check if lsof is installed, and if not, install it using apt-get.""" - try: - # Check if lsof is installed by running "lsof -v" - subprocess.run( - ["lsof", "-v"], check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE - ) - print("lsof is already installed.") - except subprocess.CalledProcessError: - print("lsof not found, installing...") - try: - # Update package list and install lsof - subprocess.run(["sudo", "apt-get", "update"], check=True) - subprocess.run(["sudo", "apt-get", "install", "-y", "lsof"], check=True) - print("lsof installed successfully.") - except subprocess.CalledProcessError as install_error: - print(f"Failed to install lsof: {install_error}") - - -def kill_process(port=51000, wait=True, timeout=10): - """Stop the running Uvicorn server.""" - log.info("Stopping model server") - try: - # Run the function to check and install lsof if necessary - # Step 1: Run lsof command to get the process using the port - lsof_command = f"lsof -n | grep {port} | grep -i LISTEN" - result = subprocess.run( - lsof_command, shell=True, 
capture_output=True, text=True - ) - - if result.returncode != 0: - print(f"No process found listening on port {port}.") - return - - # Step 2: Parse the process IDs from the output - process_ids = [line.split()[1] for line in result.stdout.splitlines()] - - if not process_ids: - print(f"No process found listening on port {port}.") - return - - # Step 3: Kill each process using its PID - for pid in process_ids: - print(f"Killing model server process with PID {pid}") - subprocess.run(f"kill {pid}", shell=True) - - if wait: - # Step 4: Wait for the process to be killed by checking if it's still running - start_time = time.time() - - while True: - check_process = subprocess.run( - f"ps -p {pid}", shell=True, capture_output=True, text=True - ) - if check_process.returncode != 0: - print(f"Process {pid} has been killed.") - break - - elapsed_time = time.time() - start_time - if elapsed_time > timeout: - print( - f"Process {pid} did not terminate within {timeout} seconds." - ) - print(f"Attempting to force kill process {pid}...") - subprocess.run(f"kill -9 {pid}", shell=True) # SIGKILL - break - - print( - f"Waiting for process {pid} to be killed... 
def stop_server(port=51000, wait=True, timeout=10):
    """Stop the processes listening on ``port``.

    Makes sure ``lsof`` is available (installing it when it is missing),
    looks up the listeners on ``port`` and hands them to
    ``kill_process_by_port``.

    Parameters:
    - port (int): TCP port the model server listens on.
    - wait (bool): Forwarded to ``kill_process_by_port``.
    - timeout (int | float): Forwarded to ``kill_process_by_port``.

    Exits the interpreter with status 1 when ``lsof`` cannot be installed.
    """
    if check_lsof():
        logger.info("`lsof` is already installed.")
    else:
        logger.info("`lsof` not found, attempting to install...")
        if not install_lsof():
            logger.error("Failed to install `lsof`.")
            sys.exit(1)
        logger.info("`lsof` installed successfully.")

    logger.info(f"Stopping processes on port {port}...")
    port_processes = find_process_by_port(port)

    # `find_process_by_port` returns None when nothing is listening and an
    # empty list when the lookup itself failed, so the two cases are reported
    # differently.
    if port_processes is None:
        logger.info(f"No processes found listening on port {port}.")
        return
    if not port_processes:
        logger.error(f"Unable to find processes on {port}")
        return
    if not kill_process_by_port(port_processes, wait, timeout):
        logger.error(f"Unable to kill all processes on {port}")


def restart_server(port=51000):
    """Restart the Uvicorn server: stop whatever holds ``port``, then start it."""
    stop_server(port)
    start_server(port)


def main(argv=None):
    """
    Start, stop, or restart the Uvicorn server based on command-line arguments.

    Parameters:
    - argv (list[str] | None): Arguments to parse. ``None`` (the default)
      makes argparse read ``sys.argv[1:]``, which is what a console entry
      point needs.
    """
    # Single source of truth for the supported actions; argparse rejects
    # anything else, so no "unknown action" fallback is needed afterwards.
    actions = {
        "start": start_server,
        "stop": stop_server,
        "restart": restart_server,
    }

    parser = argparse.ArgumentParser(description="Manage the Uvicorn server.")
    parser.add_argument(
        "action",
        choices=list(actions),
        default="start",
        nargs="?",
        help="Action to perform on the server (default: start).",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=51000,
        help="Port number for the server (default: 51000).",
    )

    args = parser.parse_args(argv)

    logger.info(f"Model server version: {get_version()}")

    actions[args.action](args.port)
PROJ_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))


# Default log directory and file
DEFAULT_LOG_DIR = os.path.join(PROJ_DIR, "logs")
DEFAULT_LOG_FILE = "modelserver.log"


def get_version():
    """Return the installed ``archgw_modelserver`` version string.

    Returns ``"version not found"`` when the distribution is not installed.
    """
    # `import importlib` alone does not guarantee that the `metadata`
    # submodule is loaded, so import it explicitly instead of relying on some
    # other package having imported it first.
    from importlib import metadata

    try:
        return metadata.version("archgw_modelserver")
    except metadata.PackageNotFoundError:
        return "version not found"


def get_device():
    """Return the torch device name to use: ``"cuda"``, else ``"mps"``, else ``"cpu"``."""
    if torch.cuda.is_available():
        return "cuda"
    # Older torch builds have no `torch.backends.mps` attribute at all.
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        return "mps"
    return "cpu"


def get_model_server_logger(log_dir=None, log_file=None):
    """
    Get or initialize the logger instance for the model server.

    Parameters:
    - log_dir (str): Custom directory to store the log file. Defaults to
      `DEFAULT_LOG_DIR` (the `logs` directory under the project root).
    - log_file (str): Custom log file name. Defaults to `modelserver.log`.

    Returns:
    - logging.Logger: The `model_server_logger` logger.

    Raises:
    - RuntimeError: If the log directory cannot be created or is not writable.
    """
    logger = logging.getLogger("model_server_logger")
    # hasHandlers() also inspects ancestor loggers, so logging that was already
    # configured on the root logger counts as "configured" here.
    if logger.hasHandlers():
        return logger

    log_dir = log_dir or DEFAULT_LOG_DIR
    log_file = log_file or DEFAULT_LOG_FILE
    log_file_path = os.path.join(log_dir, log_file)

    # Ensure the log directory exists and is writable before attaching the
    # file handler.
    try:
        os.makedirs(log_dir, exist_ok=True)
        if not os.access(log_dir, os.W_OK):
            raise PermissionError(f"No write permission for the directory: {log_dir}")
    except OSError as e:  # PermissionError is a subclass of OSError
        raise RuntimeError(f"Failed to initialize logger: {e}") from e

    # Configure logging to both the log file and the console.
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(levelname)s - %(message)s",
        handlers=[
            logging.FileHandler(log_file_path, mode="w"),  # Overwrite logs in the file
            logging.StreamHandler(),  # Also log to console
        ],
    )

    return logger
def wait_for_health_check(url, timeout=300):
    """Wait for the Uvicorn server to respond to health-check requests.

    Polls ``url`` about once per second until it answers with HTTP 200.

    Parameters:
    - url (str): Health-check endpoint to poll.
    - timeout (int | float): Overall time budget in seconds.

    Returns:
    - bool: True once a 200 response is received, False if the budget runs out.
    """
    deadline = time.monotonic() + timeout
    while True:
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            return False
        try:
            # Bound every request so one hung connection cannot outlive the
            # overall timeout.
            response = requests.get(url, timeout=max(0.1, min(5, remaining)))
            if response.status_code == 200:
                return True
        except requests.RequestException:
            # Server is not accepting connections yet (or is too slow); retry.
            pass
        # Sleep on every unsuccessful attempt, including non-200 answers, so
        # the loop never hammers the server.
        time.sleep(1)


def check_lsof():
    """Check if lsof is installed or not.

    Returns:
    - bool: True when ``lsof -v`` runs successfully, False otherwise.
    """
    try:
        subprocess.run(
            ["lsof", "-v"],
            check=True,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
        )
    except (OSError, subprocess.CalledProcessError):
        # A missing executable raises FileNotFoundError (an OSError), not
        # CalledProcessError, so both must be handled.
        return False
    return True


def install_lsof():
    """Install lsof using apt-get.

    Returns:
    - bool: True when both apt-get commands succeed, False otherwise
      (including when ``sudo`` or ``apt-get`` is not available).
    """
    try:
        subprocess.run(["sudo", "apt-get", "update"], check=True)
        subprocess.run(["sudo", "apt-get", "install", "-y", "lsof"], check=True)
    except (OSError, subprocess.CalledProcessError):
        return False
    return True


def _process_exists(pid):
    """Return True while a process with ``pid`` is still present."""
    try:
        os.kill(pid, 0)  # signal 0 only probes for existence/permission
    except ProcessLookupError:
        return False
    except PermissionError:
        # The process exists but belongs to another user.
        return True
    return True


def terminate_process_by_pid(pid, timeout, poll_interval=0.5):
    """Wait for a process to exit and force-kill it if it does not.

    Parameters:
    - pid (int | str): Process ID to watch.
    - timeout (int | float): Seconds to wait before sending SIGKILL.
    - poll_interval (int | float): Seconds between liveness checks.

    Returns:
    - bool: True if the process exited on its own within ``timeout``,
      False if it had to be force-killed.
    """
    import signal

    log = logging.getLogger("model_server_logger")
    pid = int(pid)  # callers pass the PID as text parsed from lsof output

    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        if not _process_exists(pid):
            log.info(f"Process {pid} terminated successfully.")
            return True
        time.sleep(poll_interval)

    log.warning(
        f"Process {pid} did not terminate within {timeout} seconds. Force killing..."
    )
    try:
        os.kill(pid, signal.SIGKILL)
    except ProcessLookupError:
        pass  # it exited between the last poll and the kill
    except OSError as e:
        log.error(f"Failed to force kill process {pid}: {e}")
    return False


def find_process_by_port(port=51000):
    """Find processes listening on a specific port.

    Parameters:
    - port (int): TCP port to inspect.

    Returns:
    - list[str]: The ``lsof`` output lines that contain ``LISTEN``.
    - None: When no process is listening on ``port``.
    - []: When ``lsof`` itself could not be run.
    """
    try:
        # Argument list instead of a shell pipeline: nothing is interpreted by
        # a shell, and the LISTEN filter is applied in Python below.
        result = subprocess.run(
            ["lsof", "-n", f"-i:{port}"], capture_output=True, text=True
        )
    except (OSError, subprocess.SubprocessError):
        return []

    listeners = [line for line in result.stdout.splitlines() if "LISTEN" in line]
    return listeners or None


def kill_process_by_port(port_processes=51000, wait=True, timeout=10):
    """Kill processes on a specific port.

    Parameters:
    - port_processes (list[str]): ``lsof`` output lines as returned by
      ``find_process_by_port``; the PID is the second column of each line.
    - wait (bool): Wait for each process to exit (force-killing on timeout).
    - timeout (int | float): Seconds to wait per process when ``wait`` is set.

    Returns:
    - bool: True when every process was signalled, False when the input
      could not be parsed or a process could not be signalled.
    """
    import signal

    log = logging.getLogger("model_server_logger")

    try:
        # The same PID may appear on several lsof lines (presumably one per
        # listening socket); dict.fromkeys de-duplicates while keeping order.
        process_ids = list(
            dict.fromkeys(int(line.split()[1]) for line in port_processes)
        )
    except (TypeError, AttributeError, IndexError, ValueError):
        return False

    for pid in process_ids:
        log.info(f"Killing process with PID {pid}...")
        try:
            os.kill(pid, signal.SIGTERM)
        except ProcessLookupError:
            continue  # already gone, nothing left to wait for
        except OSError as e:
            log.error(f"Failed to signal process {pid}: {e}")
            return False

        if wait:
            terminate_process_by_pid(pid, timeout)

    return True
Seems something wrong with asyncio package❗ # Unit tests for the health check endpoint @pytest.mark.asyncio async def test_healthz(): @@ -17,7 +17,7 @@ async def test_healthz(): assert response.json() == {"status": "ok"} -# [TODO] Review: check the following code +# [TODO] Review: check the following code. Seems something wrong with asyncio package❗ # Unit test for the models endpoint @pytest.mark.asyncio async def test_models(): @@ -27,7 +27,7 @@ async def test_models(): assert len(response.json()["data"]) > 0 -# [TODO] Review: check the following code +# [TODO] Review: check the following code. Seems something wrong with asyncio package❗ # Unit test for the guardrail endpoint @pytest.mark.asyncio async def test_guardrail_endpoint(): @@ -37,7 +37,7 @@ async def test_guardrail_endpoint(): assert "jailbreak_verdict" in response.json() -# [TODO] Review: check the following code +# [TODO] Review: check the following code. Seems something wrong with asyncio package❗ # Unit test for the function calling endpoint @pytest.mark.asyncio async def test_function_calling_endpoint():