Improve cli (#179)

This commit is contained in:
Adil Hafeez 2024-10-10 17:44:41 -07:00 committed by GitHub
parent ceca0dba28
commit 7d5f760884
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
18 changed files with 611 additions and 445 deletions

1
arch/tools/cli/consts.py Normal file
View file

@ -0,0 +1 @@
KATANEMO_DOCKERHUB_REPO = "katanemo/archgw"

View file

@ -4,6 +4,37 @@ import time
import pkg_resources
import select
from cli.utils import run_docker_compose_ps, print_service_status, check_services_state
from cli.utils import getLogger
import sys
log = getLogger(__name__)
def stream_gateway_logs(follow):
    """
    Stream logs from the arch gateway service.

    Runs `docker compose -p arch logs` with the working directory set to the
    folder that holds the packaged docker-compose.yaml, and connects the
    child's stdout/stderr directly to this process's streams.

    Args:
        follow (bool): When truthy, `-f` is appended to the docker command.
    """
    compose_file = pkg_resources.resource_filename(
        __name__, "../config/docker-compose.yaml"
    )

    log.info("Logs from arch gateway service.")

    options = ["docker", "compose", "-p", "arch", "logs"]
    if follow:
        options.append("-f")
    try:
        # Run `docker compose logs` to stream logs from the gateway service
        subprocess.run(
            options,
            cwd=os.path.dirname(compose_file),
            check=True,
            stdout=sys.stdout,
            stderr=sys.stderr,
        )
    except subprocess.CalledProcessError as e:
        # check=True turns a non-zero docker exit into this exception; it is a
        # failure, so report it at error level rather than info.
        log.error(f"Failed to stream logs: {str(e)}")
def start_arch(arch_config_file, env, log_timeout=120):
@ -14,7 +45,7 @@ def start_arch(arch_config_file, env, log_timeout=120):
path (str): The path where the prompt_confi.yml file is located.
log_timeout (int): Time in seconds to show logs before checking for healthy state.
"""
log.info("Starting arch gateway")
compose_file = pkg_resources.resource_filename(
__name__, "../config/docker-compose.yaml"
)
@ -35,9 +66,10 @@ def start_arch(arch_config_file, env, log_timeout=120):
), # Ensure the Docker command runs in the correct path
env=env, # Pass the modified environment
check=True, # Raise an exception if the command fails
stderr=subprocess.PIPE,
stdout=subprocess.PIPE,
)
print(f"Arch docker-compose started in detached.")
print("Monitoring `docker-compose ps` logs...")
log.info(f"Arch docker-compose started in detached.")
start_time = time.time()
services_status = {}
@ -51,14 +83,14 @@ def start_arch(arch_config_file, env, log_timeout=120):
# Check if timeout is reached
if elapsed_time > log_timeout:
print(f"Stopping log monitoring after {log_timeout} seconds.")
log.info(f"Stopping log monitoring after {log_timeout} seconds.")
break
current_services_status = run_docker_compose_ps(
compose_file=compose_file, env=env
)
if not current_services_status:
print(
log.info(
"Status for the services could not be detected. Something went wrong. Please run docker logs"
)
break
@ -74,11 +106,11 @@ def start_arch(arch_config_file, env, log_timeout=120):
running_states = ["running", "up"]
if check_services_state(current_services_status, running_states):
print("Arch is up and running!")
log.info("Arch gateway is up and running!")
break
if check_services_state(current_services_status, unhealthy_states):
print(
log.info(
"One or more Arch services are unhealthy. Please run `docker logs` for more information"
)
print_service_status(
@ -92,7 +124,7 @@ def start_arch(arch_config_file, env, log_timeout=120):
services_status[service_name]["State"]
!= current_services_status[service_name]["State"]
):
print(
log.info(
"One or more Arch services have changed state. Printing current state"
)
print_service_status(current_services_status)
@ -101,7 +133,7 @@ def start_arch(arch_config_file, env, log_timeout=120):
services_status = current_services_status
except subprocess.CalledProcessError as e:
print(f"Failed to start Arch: {str(e)}")
log.info(f"Failed to start Arch: {str(e)}")
def stop_arch():
@ -115,17 +147,21 @@ def stop_arch():
__name__, "../config/docker-compose.yaml"
)
log.info("Shutting down arch gateway service.")
try:
# Run `docker-compose down` to shut down all services
subprocess.run(
["docker", "compose", "-p", "arch", "down"],
cwd=os.path.dirname(compose_file),
check=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
print("Successfully shut down all services.")
log.info("Successfully shut down arch gateway service.")
except subprocess.CalledProcessError as e:
print(f"Failed to shut down services: {str(e)}")
log.info(f"Failed to shut down services: {str(e)}")
def start_arch_modelserver():
@ -134,12 +170,13 @@ def start_arch_modelserver():
"""
try:
log.info("archgw_modelserver restart")
subprocess.run(
["archgw_modelserver", "restart"], check=True, start_new_session=True
)
print("Successfull run the archgw model_server")
log.info("Successfull ran model_server")
except subprocess.CalledProcessError as e:
print(f"Failed to start model_server. Please check archgw_modelserver logs")
log.info(f"Failed to start model_server. Please check archgw_modelserver logs")
sys.exit(1)
@ -153,7 +190,7 @@ def stop_arch_modelserver():
["archgw_modelserver", "stop"],
check=True,
)
print("Successfull stopped the archgw model_server")
log.info("Successfull stopped the archgw model_server")
except subprocess.CalledProcessError as e:
print(f"Failed to start model_server. Please check archgw_modelserver logs")
log.info(f"Failed to start model_server. Please check archgw_modelserver logs")
sys.exit(1)

View file

@ -10,8 +10,17 @@ from cli.core import (
stop_arch_modelserver,
start_arch,
stop_arch,
stream_gateway_logs,
)
from cli.utils import get_llm_provider_access_keys, load_env_file_to_dict
from cli.consts import KATANEMO_DOCKERHUB_REPO
from cli.utils import getLogger
import multiprocessing
from huggingface_hub import snapshot_download
import joblib
log = getLogger(__name__)
logo = r"""
_ _
@ -39,17 +48,17 @@ MODEL_SERVER_BUILD_FILE = "./model_server/pyproject.toml"
@click.command()
@click.option(
"--services",
"--service",
default="all",
help="Services to build. Options are all, model_server, archgw. Default is all",
help="Optioanl parameter to specify which service to build. Options are model_server, archgw",
)
def build(services):
def build(service):
"""Build Arch from source. Must be in root of cloned repo."""
if services not in ["all", "model_server", "archgw"]:
print(f"Error: Invalid service {services}. Exiting")
if service not in ["model_server", "archgw", "all"]:
print(f"Error: Invalid service {service}. Exiting")
sys.exit(1)
# Check if /arch/Dockerfile exists
if services == "archgw" or services == "all":
if service == "archgw" or service == "all":
if os.path.exists(ARCHGW_DOCKERFILE):
click.echo("Building archgw image...")
try:
@ -60,7 +69,7 @@ def build(services):
"-f",
ARCHGW_DOCKERFILE,
"-t",
"archgw:latest",
f"{KATANEMO_DOCKERHUB_REPO}:latest",
".",
"--add-host=host.docker.internal:host-gateway",
],
@ -77,7 +86,7 @@ def build(services):
click.echo("archgw image built successfully.")
"""Install the model server dependencies using Poetry."""
if services == "model_server" or services == "all":
if service == "model_server" or service == "all":
# Check if pyproject.toml exists
if os.path.exists(MODEL_SERVER_BUILD_FILE):
click.echo("Installing model server dependencies with Poetry...")
@ -102,17 +111,17 @@ def build(services):
"--path", default=".", help="Path to the directory containing arch_config.yaml"
)
@click.option(
"--services",
"--service",
default="all",
help="Services to start. Options are all, model_server, archgw. Default is all",
help="Service to start. Options are model_server, archgw.",
)
def up(file, path, services):
def up(file, path, service):
"""Starts Arch."""
if services not in ["all", "model_server", "archgw"]:
print(f"Error: Invalid service {services}. Exiting")
if service not in ["all", "model_server", "archgw"]:
print(f"Error: Invalid service {service}. Exiting")
sys.exit(1)
if services == "model_server":
if service == "model_server":
start_arch_modelserver()
return
@ -142,7 +151,7 @@ def up(file, path, services):
print(f"Exiting archgw up: {e}")
sys.exit(1)
print("Starting Arch gateway and Arch model server services via docker ")
log.info("Starging arch model server and arch gateway")
# Set the ARCH_CONFIG_FILE environment variable
env_stage = {}
@ -184,7 +193,7 @@ def up(file, path, services):
env.update(env_stage)
env["ARCH_CONFIG_FILE"] = arch_config_file
if services == "archgw":
if service == "archgw":
start_arch(arch_config_file, env)
else:
start_arch_modelserver()
@ -193,19 +202,19 @@ def up(file, path, services):
@click.command()
@click.option(
"--services",
"--service",
default="all",
help="Services to down. Options are all, model_server, archgw. Default is all",
help="Service to down. Options are all, model_server, archgw. Default is all",
)
def down(services):
def down(service):
"""Stops Arch."""
if services not in ["all", "model_server", "archgw"]:
print(f"Error: Invalid service {services}. Exiting")
if service not in ["all", "model_server", "archgw"]:
print(f"Error: Invalid service {service}. Exiting")
sys.exit(1)
if services == "model_server":
if service == "model_server":
stop_arch_modelserver()
elif services == "archgw":
elif service == "archgw":
stop_arch()
else:
stop_arch_modelserver()
@ -234,9 +243,74 @@ def generate_prompt_targets(file):
targets.generate_prompt_targets(file)
def stream_model_server_logs(follow):
    """
    Stream the model server log file using `tail`.

    The log file path is `~/archgw_logs/modelserver.log`, expanded for the
    current user. The child's stdout/stderr are connected directly to this
    process's streams.

    Args:
        follow (bool): When truthy, `-f` is passed to `tail`.
    """
    log_file = "~/archgw_logs/modelserver.log"
    log_file_expanded = os.path.expanduser(log_file)

    stream_command = ["tail"]
    if follow:
        stream_command.append("-f")
    stream_command.append(log_file_expanded)

    try:
        subprocess.run(
            stream_command,
            check=True,
            stdout=sys.stdout,
            stderr=sys.stderr,
        )
    except subprocess.CalledProcessError as e:
        # check=True raises when tail exits non-zero; report it the same way
        # the gateway log streamer does instead of surfacing a traceback.
        log.error(f"Failed to stream model server logs: {str(e)}")
@click.command()
@click.option(
    "--service",
    default="all",
    help="Service to monitor. By default it will monitor both gateway and model_server",
)
@click.option("--follow", help="Follow the logs", is_flag=True)
def logs(service, follow):
    """Stream logs from arch services."""
    # NOTE: the docstring above doubles as the CLI help text, so extra
    # documentation lives in comments instead.

    if service not in ["all", "model_server", "archgw"]:
        print(f"Error: Invalid service {service}. Exiting")
        sys.exit(1)

    # Each selectable service paired with the function that streams its logs,
    # kept in the order the processes are started and joined (gateway first).
    streamers = [
        ("archgw", stream_gateway_logs),
        ("model_server", stream_model_server_logs),
    ]

    # One child process per selected service so both streams run side by side.
    processes = []
    for service_name, target in streamers:
        if service in (service_name, "all"):
            process = multiprocessing.Process(target=target, args=(follow,))
            process.start()
            processes.append(process)

    # Block until every streamer has finished.
    for process in processes:
        process.join()
# Hugging Face Hub repository ids fetched by the `download_models` command.
model_list = [
"katanemo/Arch-Guard-cpu",
"katanemo/Arch-Guard",
"katanemo/bge-large-en-v1.5",
]
@click.command()
def download_models():
    """Download required models from Hugging Face Hub in the cache directory"""
    # Walk the module-level model_list and fetch each repository in turn,
    # logging the repository id before its download starts.
    for repo_id in model_list:
        log.info(f"Downloading model: {repo_id}")
        snapshot_download(repo_id=repo_id)
main.add_command(up)
main.add_command(down)
main.add_command(build)
main.add_command(logs)
main.add_command(download_models)
main.add_command(generate_prompt_targets)
if __name__ == "__main__":

View file

@ -5,6 +5,23 @@ import select
import shlex
import yaml
import json
import logging
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
def getLogger(name="cli"):
    """
    Return the logger called `name` with its level set to INFO.

    Args:
        name (str): Name handed to `logging.getLogger`. Defaults to "cli".

    Returns:
        logging.Logger: The named logger, with its level set to `logging.INFO`.
    """
    # `logging` is already imported at module level, so the function-local
    # import the original carried is unnecessary.
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)
    return logger
log = getLogger(__name__)
def run_docker_compose_ps(compose_file, env):
@ -15,7 +32,7 @@ def run_docker_compose_ps(compose_file, env):
path (str): The path where the docker-compose.yml file is located.
"""
try:
# Run `docker-compose ps` to get the health status of each service
# Run `docker compose ps` to get the health status of each service
ps_process = subprocess.Popen(
[
"docker",
@ -38,7 +55,7 @@ def run_docker_compose_ps(compose_file, env):
# Check if there is any error output
if error_output:
print(
log.info(
f"Error while checking service status:\n{error_output}",
file=os.sys.stderr,
)
@ -48,18 +65,18 @@ def run_docker_compose_ps(compose_file, env):
return services
except subprocess.CalledProcessError as e:
print(f"Failed to check service status. Error:\n{e.stderr}")
log.info(f"Failed to check service status. Error:\n{e.stderr}")
return e
# Helper method to print service status
def print_service_status(services):
print(f"{'Service Name':<25} {'State':<20} {'Ports'}")
print("=" * 72)
log.info(f"{'Service Name':<25} {'State':<20} {'Ports'}")
log.info("=" * 72)
for service_name, info in services.items():
status = info["STATE"]
ports = info["PORTS"]
print(f"{service_name:<25} {status:<20} {ports}")
log.info(f"{service_name:<25} {status:<20} {ports}")
# check for states based on the states passed in