mirror of
https://github.com/katanemo/plano.git
synced 2026-05-06 14:22:51 +02:00
Merge branch 'main' of https://github.com/katanemo/arch into cotran/hallu-fix
This commit is contained in:
commit
b8c6bd73af
43 changed files with 865 additions and 644 deletions
|
|
@ -1,107 +0,0 @@
|
|||
import sys
|
||||
import os
|
||||
import time
|
||||
import requests
|
||||
import psutil
|
||||
import tempfile
|
||||
import subprocess
|
||||
|
||||
|
||||
# Path to the file where the server process ID will be stored.
# start_server() and stop_server() both test for this file's existence to
# decide whether a server is considered to be running.
PID_FILE = os.path.join(tempfile.gettempdir(), "model_server.pid")
|
||||
|
||||
|
||||
def run_server():
    """Dispatch the requested CLI action to the matching server helper.

    The action is read from the first command-line argument and falls back
    to ``start`` when no argument is given. Recognised actions are
    ``start``, ``stop`` and ``restart``; anything else is reported and the
    process exits with status 1.
    """
    action = sys.argv[1] if len(sys.argv) > 1 else "start"

    # Reject unknown actions up front so the dispatch below only sees
    # values it knows how to handle.
    if action not in ("start", "stop", "restart"):
        print(f"Unknown action: {action}")
        sys.exit(1)

    if action == "start":
        start_server()
    elif action == "stop":
        stop_server()
    else:
        restart_server()
|
||||
|
||||
|
||||
def start_server():
    """Start the Uvicorn server and save the process ID.

    Launches ``uvicorn app.main:app`` on port 51000 in a new session with
    its output discarded, then polls the ``/healthz`` endpoint. When the
    server becomes healthy its PID is written to ``PID_FILE``.

    Exits the CLI with status 1 when a PID file already exists, or when the
    server does not become healthy in time (after shutting the child down),
    so that callers and shell scripts can detect the failure.
    """
    if os.path.exists(PID_FILE):
        print("Server is already running. Use 'model_server restart' to restart it.")
        sys.exit(1)

    print(
        "Starting Archgw Model Server - Loading some awesomeness, this may take a little time.)"
    )
    process = subprocess.Popen(
        ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "51000"],
        start_new_session=True,
        stdout=subprocess.DEVNULL,  # Suppress standard output. There is a logger that model_server prints to
        stderr=subprocess.DEVNULL,  # Suppress standard error. There is a logger that model_server prints to
    )

    if wait_for_health_check("http://0.0.0.0:51000/healthz"):
        # Write the process ID to the PID file
        with open(PID_FILE, "w") as f:
            f.write(str(process.pid))
        print(f"Archgw Model Server started with PID {process.pid}")
        return

    print("Archgw Model Server - Didn't Start In Time. Shutting Down")
    process.terminate()
    # Reap the child so it is actually gone (and not left as a zombie)
    # before we return control; escalate if it ignores the terminate request.
    try:
        process.wait(timeout=10)
    except subprocess.TimeoutExpired:
        process.kill()
        process.wait()
    # A failed start must not look like success to the calling shell.
    sys.exit(1)
|
||||
|
||||
|
||||
def wait_for_health_check(url, timeout=180):
    """Wait for the Uvicorn server to respond to health-check requests.

    Polls ``url`` roughly once per second until it answers with HTTP 200 or
    ``timeout`` seconds have elapsed.

    Args:
        url: Health-check endpoint to poll.
        timeout: Maximum number of seconds to keep polling.

    Returns:
        True as soon as the endpoint returns status 200, False if the
        deadline passes first.
    """
    # Monotonic clock: the deadline must not move if the wall clock is adjusted.
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            # Bound each request so a server that accepts the connection but
            # never answers cannot stall us past the overall deadline.
            response = requests.get(url, timeout=5)
            if response.status_code == 200:
                return True
        except (requests.ConnectionError, requests.Timeout):
            # Server not accepting connections / not answering yet; retry.
            pass
        # Pause between attempts in every non-success case, including a
        # non-200 response, so we never busy-loop against the server.
        time.sleep(1)
    print("Timed out waiting for Archgw Model Server to respond.")
    return False
|
||||
|
||||
|
||||
def stop_server():
    """Stop the running Uvicorn server.

    Reads the server PID from ``PID_FILE``, asks the process to terminate
    and waits up to 10 seconds for it to exit, escalating to a forced kill
    if it does not. The PID file is removed afterwards. When no PID file
    exists the function only reports that the server is not running.
    """
    if not os.path.exists(PID_FILE):
        print("Status: Archgw Model Server not running")
        return

    # Read the process ID from the PID file
    with open(PID_FILE, "r") as f:
        raw_pid = f.read().strip()

    # An empty or corrupt PID file cannot identify a process; treat it as
    # stale instead of crashing on int().
    try:
        pid = int(raw_pid)
    except ValueError:
        print(f"PID file contains invalid data {raw_pid!r}. Cleaning up PID file.")
        os.remove(PID_FILE)
        return

    try:
        # Get process by PID
        process = psutil.Process(pid)

        # Gracefully terminate the process
        process.terminate()  # Sends SIGTERM by default
        process.wait(timeout=10)  # Wait for up to 10 seconds for the process to exit

        print(f"Server with PID {pid} stopped.")
    except psutil.NoSuchProcess:
        print(f"Process with PID {pid} not found. Cleaning up PID file.")
    except psutil.TimeoutExpired:
        print(f"Process with PID {pid} did not terminate in time. Forcing shutdown.")
        try:
            process.kill()  # Forcefully kill the process
            # Give the kill a moment to take effect so an immediate restart
            # does not race the old process for the port.
            process.wait(timeout=5)
        except psutil.NoSuchProcess:
            pass  # It exited between the timeout and the kill.
        except psutil.TimeoutExpired:
            print(f"Process with PID {pid} is still shutting down after kill.")

    # Reached only for the handled outcomes above; any other error (for
    # example a permissions failure) propagates and leaves the PID file alone.
    os.remove(PID_FILE)
|
||||
|
||||
|
||||
def restart_server():
    """Stop the server if one is recorded as running, then start a fresh one.

    Delegates to ``stop_server()`` followed by ``start_server()``; stopping
    is a no-op (apart from a status message) when no PID file exists.
    """
    print("Check: Is Archgw Model Server running?")
    stop_server()
    start_server()
|
||||
119
model_server/app/cli.py
Normal file
119
model_server/app/cli.py
Normal file
|
|
@ -0,0 +1,119 @@
|
|||
import sys
|
||||
import os
|
||||
import time
|
||||
import requests
|
||||
import psutil
|
||||
import tempfile
|
||||
import subprocess
|
||||
import logging
|
||||
|
||||
# Configure the root logger at import time: INFO level, with each record
# prefixed by timestamp, logger name and level.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)

# Module-level logger for the CLI's status messages.
log = logging.getLogger("model_server.cli")
log.setLevel(logging.INFO)

# Path to the file where the server process ID will be stored.
# start_server() and stop_server() both test for this file's existence to
# decide whether a server is considered to be running.
PID_FILE = os.path.join(tempfile.gettempdir(), "model_server.pid")
|
||||
|
||||
|
||||
def run_server():
    """Dispatch the requested CLI action to the matching server helper.

    The action is read from the first command-line argument and falls back
    to ``start`` when no argument is given. Recognised actions are
    ``start``, ``stop`` and ``restart``; anything else is logged and the
    process exits with status 1.
    """
    action = sys.argv[1] if len(sys.argv) > 1 else "start"

    # Reject unknown actions up front so the dispatch below only sees
    # values it knows how to handle.
    if action not in ("start", "stop", "restart"):
        log.info(f"Unknown action: {action}")
        sys.exit(1)

    if action == "start":
        start_server()
    elif action == "stop":
        stop_server()
    else:
        restart_server()
|
||||
|
||||
|
||||
def start_server():
    """Start the Uvicorn server and save the process ID.

    Launches ``uvicorn app.main:app`` on port 51000 in a new session with
    its output discarded, then polls the ``/healthz`` endpoint. When the
    server becomes healthy its PID is written to ``PID_FILE``.

    Exits the CLI with status 1 when a PID file already exists, or when the
    server does not become healthy in time (after shutting the child down),
    so that callers and shell scripts can detect the failure.
    """
    if os.path.exists(PID_FILE):
        log.info("Server is already running. Use 'model_server restart' to restart it.")
        sys.exit(1)

    log.info(
        "Starting model server - loading some awesomeness, this may take some time :)"
    )
    # Output is discarded rather than piped: nothing in this function reads
    # from the child's pipes, so a PIPE would eventually fill and block the
    # server on write (and be left dangling once this CLI process exits).
    process = subprocess.Popen(
        ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "51000"],
        start_new_session=True,
        stdout=subprocess.DEVNULL,  # Suppress standard output. There is a logger that model_server prints to
        stderr=subprocess.DEVNULL,  # Suppress standard error. There is a logger that model_server prints to
    )

    if wait_for_health_check("http://0.0.0.0:51000/healthz"):
        # Write the process ID to the PID file
        with open(PID_FILE, "w") as f:
            f.write(str(process.pid))
        log.info(f"Model server started with PID {process.pid}")
        return

    log.info("Model server - Didn't Start In Time. Shutting Down")
    process.terminate()
    # Reap the child so it is actually gone (and not left as a zombie)
    # before we return control; escalate if it ignores the terminate request.
    try:
        process.wait(timeout=10)
    except subprocess.TimeoutExpired:
        process.kill()
        process.wait()
    # A failed start must not look like success to the calling shell.
    sys.exit(1)
|
||||
|
||||
|
||||
def wait_for_health_check(url, timeout=180):
    """Wait for the Uvicorn server to respond to health-check requests.

    Polls ``url`` roughly once per second until it answers with HTTP 200 or
    ``timeout`` seconds have elapsed.

    Args:
        url: Health-check endpoint to poll.
        timeout: Maximum number of seconds to keep polling.

    Returns:
        True as soon as the endpoint returns status 200, False if the
        deadline passes first.
    """
    # Monotonic clock: the deadline must not move if the wall clock is adjusted.
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            # Bound each request so a server that accepts the connection but
            # never answers cannot stall us past the overall deadline.
            response = requests.get(url, timeout=5)
            if response.status_code == 200:
                return True
        except (requests.ConnectionError, requests.Timeout):
            # Server not accepting connections / not answering yet; retry.
            pass
        # Pause between attempts in every non-success case, including a
        # non-200 response, so we never busy-loop against the server.
        time.sleep(1)
    # Reported through the module logger like every other CLI message.
    log.info("Timed out waiting for model server to respond.")
    return False
|
||||
|
||||
|
||||
def stop_server():
    """Stop the running Uvicorn server.

    Reads the server PID from ``PID_FILE``, asks the process to terminate
    and waits up to 10 seconds for it to exit, escalating to a forced kill
    if it does not. The PID file is removed afterwards. When no PID file
    exists the function only logs that the server was not running.
    """
    log.info("Stopping model server")
    if not os.path.exists(PID_FILE):
        log.info("Process id file not found, seems like model server was not running")
        return

    # Read the process ID from the PID file
    with open(PID_FILE, "r") as f:
        raw_pid = f.read().strip()

    # An empty or corrupt PID file cannot identify a process; treat it as
    # stale instead of crashing on int().
    try:
        pid = int(raw_pid)
    except ValueError:
        log.info(f"PID file contains invalid data {raw_pid!r}. Cleaning up PID file.")
        os.remove(PID_FILE)
        return

    try:
        # Get process by PID
        process = psutil.Process(pid)

        # Gracefully terminate the process
        process.terminate()  # Sends SIGTERM by default
        process.wait(timeout=10)  # Wait for up to 10 seconds for the process to exit

        log.info(f"Model server with PID {pid} stopped.")
    except psutil.NoSuchProcess:
        log.info(f"Model server with PID {pid} not found. Cleaning up PID file.")
    except psutil.TimeoutExpired:
        log.info(
            f"Model server with PID {pid} did not terminate in time. Forcing shutdown."
        )
        try:
            process.kill()  # Forcefully kill the process
            # Give the kill a moment to take effect so an immediate restart
            # does not race the old process for the port.
            process.wait(timeout=5)
        except psutil.NoSuchProcess:
            pass  # It exited between the timeout and the kill.
        except psutil.TimeoutExpired:
            log.info(f"Model server with PID {pid} is still shutting down after kill.")

    # Reached only for the handled outcomes above; any other error (for
    # example a permissions failure) propagates and leaves the PID file alone.
    os.remove(PID_FILE)
|
||||
|
||||
|
||||
def restart_server():
    """Stop the server if one is recorded as running, then start a fresh one.

    Delegates to ``stop_server()`` followed by ``start_server()``; stopping
    only logs a message when no PID file exists.
    """
    stop_server()
    start_server()
|
||||
|
|
@ -5,6 +5,7 @@ import app.loader as loader
|
|||
from app.function_calling.model_handler import ArchFunctionHandler
|
||||
from app.prompt_guard.model_handler import ArchGuardHanlder
|
||||
|
||||
logger = utils.get_model_server_logger()
|
||||
|
||||
arch_function_hanlder = ArchFunctionHandler()
|
||||
arch_function_endpoint = "https://api.fc.archgw.com/v1"
|
||||
|
|
@ -19,7 +20,6 @@ arch_function_generation_params = {
|
|||
|
||||
arch_guard_model_type = {"cpu": "katanemo/Arch-Guard-cpu", "gpu": "katanemo/Arch-Guard"}
|
||||
|
||||
|
||||
# Model definition
|
||||
embedding_model = loader.get_embedding_model()
|
||||
zero_shot_model = loader.get_zero_shot_model()
|
||||
|
|
|
|||
|
|
@ -6,12 +6,15 @@ from optimum.onnxruntime import (
|
|||
ORTModelForFeatureExtraction,
|
||||
ORTModelForSequenceClassification,
|
||||
)
|
||||
import app.commons.utilities as utils
|
||||
|
||||
logger = utils.get_model_server_logger()
|
||||
|
||||
|
||||
def get_embedding_model(
|
||||
model_name=os.getenv("MODELS", "katanemo/bge-large-en-v1.5"),
|
||||
):
|
||||
print("Loading Embedding Model...")
|
||||
logger.info("Loading Embedding Model...")
|
||||
|
||||
if glb.DEVICE != "cuda":
|
||||
model = ORTModelForFeatureExtraction.from_pretrained(
|
||||
|
|
@ -32,7 +35,7 @@ def get_embedding_model(
|
|||
def get_zero_shot_model(
|
||||
model_name=os.getenv("ZERO_SHOT_MODELS", "katanemo/bart-large-mnli"),
|
||||
):
|
||||
print("Loading Zero-shot Model...")
|
||||
logger.info("Loading Zero-shot Model...")
|
||||
|
||||
if glb.DEVICE != "cuda":
|
||||
model = ORTModelForSequenceClassification.from_pretrained(
|
||||
|
|
@ -58,7 +61,7 @@ def get_zero_shot_model(
|
|||
|
||||
|
||||
def get_prompt_guard(model_name, hardware_config="cpu"):
|
||||
print("Loading Guard Model...")
|
||||
logger.info("Loading Guard Model...")
|
||||
|
||||
if hardware_config == "cpu":
|
||||
from optimum.intel import OVModelForSequenceClassification
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@ import app.commons.utilities as utils
|
|||
import app.commons.globals as glb
|
||||
import app.prompt_guard.model_utils as guard_utils
|
||||
|
||||
|
||||
from typing import List, Dict
|
||||
from pydantic import BaseModel
|
||||
from fastapi import FastAPI, Response, HTTPException
|
||||
|
|
@ -17,8 +16,7 @@ from app.function_calling.model_utils import (
|
|||
|
||||
logger = utils.get_model_server_logger()
|
||||
|
||||
logger.info(f"Devices Avialble: {glb.DEVICE}")
|
||||
|
||||
logger.info(f"Ready to serve traffic. available device: {glb.DEVICE}")
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
[tool.poetry]
|
||||
name = "archgw_modelserver"
|
||||
version = "0.0.2"
|
||||
version = "0.0.3"
|
||||
description = "A model server for serving models"
|
||||
authors = ["Katanemo Labs, Inc <archgw@katanemo.com>"]
|
||||
license = "Apache 2.0"
|
||||
|
|
@ -31,7 +31,7 @@ onnx = "1.17.0"
|
|||
onnxruntime = "1.19.2"
|
||||
|
||||
[tool.poetry.scripts]
|
||||
archgw_modelserver = "app:run_server"
|
||||
archgw_modelserver = "app.cli:run_server"
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core>=1.0.0"]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue