fix model server stop process (#217)

* fix model server stop process

* replace

* replace

* add test

* add multiple pids test

* add check install for linux

* reformat
This commit is contained in:
CTran 2024-10-24 19:21:47 -07:00 committed by GitHub
parent ff6e9bd9bd
commit 25dddcbfd9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 141 additions and 48 deletions

View file

@ -15,11 +15,8 @@ logging.basicConfig(
log = logging.getLogger("model_server.cli")
log.setLevel(logging.INFO)
# Path to the file where the server process ID will be stored
PID_FILE = os.path.join(tempfile.gettempdir(), "model_server.pid")
def run_server():
def run_server(port=51000):
"""Start, stop, or restart the Uvicorn server based on command-line arguments."""
if len(sys.argv) > 1:
action = sys.argv[1]
@ -27,22 +24,18 @@ def run_server():
action = "start"
if action == "start":
start_server()
start_server(port)
elif action == "stop":
stop_server()
stop_server(port)
elif action == "restart":
restart_server()
restart_server(port)
else:
log.info(f"Unknown action: {action}")
sys.exit(1)
def start_server():
"""Start the Uvicorn server and save the process ID."""
if os.path.exists(PID_FILE):
log.info("Server is already running. Use 'model_server restart' to restart it.")
sys.exit(1)
def start_server(port=51000):
"""Start the Uvicorn server"""
log.info(
"Starting model server - loading some awesomeness, this may take some time :)"
)
@ -55,7 +48,7 @@ def start_server():
"--host",
"0.0.0.0",
"--port",
"51000",
f"{port}",
],
start_new_session=True,
bufsize=1,
@ -64,10 +57,7 @@ def start_server():
stderr=subprocess.PIPE, # Suppress standard error. There is a logger that model_server prints to
)
if wait_for_health_check("http://0.0.0.0:51000/healthz"):
# Write the process ID to the PID file
with open(PID_FILE, "w") as f:
f.write(str(process.pid))
if wait_for_health_check(f"http://0.0.0.0:{port}/healthz"):
log.info(f"Model server started with PID {process.pid}")
else:
# Add model_server boot-up logs
@ -89,40 +79,88 @@ def wait_for_health_check(url, timeout=180):
return False
def stop_server():
def check_and_install_lsof():
"""Check if lsof is installed, and if not, install it using apt-get."""
try:
# Check if lsof is installed by running "lsof -v"
subprocess.run(
["lsof", "-v"], check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE
)
print("lsof is already installed.")
except subprocess.CalledProcessError:
print("lsof not found, installing...")
try:
# Update package list and install lsof
subprocess.run(["sudo", "apt-get", "update"], check=True)
subprocess.run(["sudo", "apt-get", "install", "-y", "lsof"], check=True)
print("lsof installed successfully.")
except subprocess.CalledProcessError as install_error:
print(f"Failed to install lsof: {install_error}")
def kill_process(port=51000, wait=True, timeout=10):
"""Stop the running Uvicorn server."""
log.info("Stopping model server")
if not os.path.exists(PID_FILE):
log.info("Process id file not found, seems like model server was not running")
return
# Read the process ID from the PID file
with open(PID_FILE, "r") as f:
pid = int(f.read())
try:
# Get process by PID
process = psutil.Process(pid)
# Gracefully terminate the process
process.terminate() # Sends SIGTERM by default
process.wait(timeout=10) # Wait for up to 10 seconds for the process to exit
log.info(f"Model server with PID {pid} stopped.")
os.remove(PID_FILE)
except psutil.NoSuchProcess:
log.info(f"Model server with PID {pid} not found. Cleaning up PID file.")
os.remove(PID_FILE)
except psutil.TimeoutExpired:
log.info(
f"Model server with PID {pid} did not terminate in time. Forcing shutdown."
# Run the function to check and install lsof if necessary
# Step 1: Run lsof command to get the process using the port
lsof_command = f"lsof -n | grep {port} | grep -i LISTEN"
result = subprocess.run(
lsof_command, shell=True, capture_output=True, text=True
)
process.kill() # Forcefully kill the process
os.remove(PID_FILE)
if result.returncode != 0:
print(f"No process found listening on port {port}.")
return
# Step 2: Parse the process IDs from the output
process_ids = [line.split()[1] for line in result.stdout.splitlines()]
if not process_ids:
print(f"No process found listening on port {port}.")
return
# Step 3: Kill each process using its PID
for pid in process_ids:
print(f"Killing model server process with PID {pid}")
subprocess.run(f"kill {pid}", shell=True)
if wait:
# Step 4: Wait for the process to be killed by checking if it's still running
start_time = time.time()
while True:
check_process = subprocess.run(
f"ps -p {pid}", shell=True, capture_output=True, text=True
)
if check_process.returncode != 0:
print(f"Process {pid} has been killed.")
break
elapsed_time = time.time() - start_time
if elapsed_time > timeout:
print(
f"Process {pid} did not terminate within {timeout} seconds."
)
print(f"Attempting to force kill process {pid}...")
subprocess.run(f"kill -9 {pid}", shell=True) # SIGKILL
break
print(
f"Waiting for process {pid} to be killed... ({elapsed_time:.2f} seconds)"
)
time.sleep(0.5)
except Exception as e:
print(f"Error occurred: {e}")
def restart_server():
def stop_server(port=51000, wait=True, timeout=10):
check_and_install_lsof()
kill_process(port, wait, timeout)
def restart_server(port=51000):
"""Restart the Uvicorn server."""
stop_server()
start_server()
stop_server(port)
start_server(port)