Improve cli (#179)

This commit is contained in:
Adil Hafeez 2024-10-10 17:44:41 -07:00 committed by GitHub
parent ceca0dba28
commit 7d5f760884
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
18 changed files with 611 additions and 445 deletions

View file

@ -10,8 +10,17 @@ from cli.core import (
stop_arch_modelserver,
start_arch,
stop_arch,
stream_gateway_logs,
)
from cli.utils import get_llm_provider_access_keys, load_env_file_to_dict
from cli.consts import KATANEMO_DOCKERHUB_REPO
from cli.utils import getLogger
import multiprocessing
from huggingface_hub import snapshot_download
import joblib
log = getLogger(__name__)
logo = r"""
_ _
@ -39,17 +48,17 @@ MODEL_SERVER_BUILD_FILE = "./model_server/pyproject.toml"
@click.command()
@click.option(
"--services",
"--service",
default="all",
help="Services to build. Options are all, model_server, archgw. Default is all",
help="Optioanl parameter to specify which service to build. Options are model_server, archgw",
)
def build(services):
def build(service):
"""Build Arch from source. Must be in root of cloned repo."""
if services not in ["all", "model_server", "archgw"]:
print(f"Error: Invalid service {services}. Exiting")
if service not in ["model_server", "archgw", "all"]:
print(f"Error: Invalid service {service}. Exiting")
sys.exit(1)
# Check if /arch/Dockerfile exists
if services == "archgw" or services == "all":
if service == "archgw" or service == "all":
if os.path.exists(ARCHGW_DOCKERFILE):
click.echo("Building archgw image...")
try:
@ -60,7 +69,7 @@ def build(services):
"-f",
ARCHGW_DOCKERFILE,
"-t",
"archgw:latest",
f"{KATANEMO_DOCKERHUB_REPO}:latest",
".",
"--add-host=host.docker.internal:host-gateway",
],
@ -77,7 +86,7 @@ def build(services):
click.echo("archgw image built successfully.")
"""Install the model server dependencies using Poetry."""
if services == "model_server" or services == "all":
if service == "model_server" or service == "all":
# Check if pyproject.toml exists
if os.path.exists(MODEL_SERVER_BUILD_FILE):
click.echo("Installing model server dependencies with Poetry...")
@ -102,17 +111,17 @@ def build(services):
"--path", default=".", help="Path to the directory containing arch_config.yaml"
)
@click.option(
"--services",
"--service",
default="all",
help="Services to start. Options are all, model_server, archgw. Default is all",
help="Service to start. Options are model_server, archgw.",
)
def up(file, path, services):
def up(file, path, service):
"""Starts Arch."""
if services not in ["all", "model_server", "archgw"]:
print(f"Error: Invalid service {services}. Exiting")
if service not in ["all", "model_server", "archgw"]:
print(f"Error: Invalid service {service}. Exiting")
sys.exit(1)
if services == "model_server":
if service == "model_server":
start_arch_modelserver()
return
@ -142,7 +151,7 @@ def up(file, path, services):
print(f"Exiting archgw up: {e}")
sys.exit(1)
print("Starting Arch gateway and Arch model server services via docker ")
log.info("Starging arch model server and arch gateway")
# Set the ARCH_CONFIG_FILE environment variable
env_stage = {}
@ -184,7 +193,7 @@ def up(file, path, services):
env.update(env_stage)
env["ARCH_CONFIG_FILE"] = arch_config_file
if services == "archgw":
if service == "archgw":
start_arch(arch_config_file, env)
else:
start_arch_modelserver()
@ -193,19 +202,19 @@ def up(file, path, services):
@click.command()
@click.option(
"--services",
"--service",
default="all",
help="Services to down. Options are all, model_server, archgw. Default is all",
help="Service to down. Options are all, model_server, archgw. Default is all",
)
def down(services):
def down(service):
"""Stops Arch."""
if services not in ["all", "model_server", "archgw"]:
print(f"Error: Invalid service {services}. Exiting")
if service not in ["all", "model_server", "archgw"]:
print(f"Error: Invalid service {service}. Exiting")
sys.exit(1)
if services == "model_server":
if service == "model_server":
stop_arch_modelserver()
elif services == "archgw":
elif service == "archgw":
stop_arch()
else:
stop_arch_modelserver()
@ -234,9 +243,74 @@ def generate_prompt_targets(file):
targets.generate_prompt_targets(file)
def stream_model_server_logs(follow):
log_file = "~/archgw_logs/modelserver.log"
log_file_expanded = os.path.expanduser(log_file)
stream_command = ["tail"]
if follow:
stream_command.append("-f")
stream_command.append(log_file_expanded)
subprocess.run(
stream_command,
check=True,
stdout=sys.stdout,
stderr=sys.stderr,
)
@click.command()
@click.option(
"--service",
default="all",
help="Service to monitor. By default it will monitor both gateway and model_serve",
)
@click.option("--follow", help="Follow the logs", is_flag=True)
def logs(service, follow):
"""Stream logs from arch services."""
if service not in ["all", "model_server", "archgw"]:
print(f"Error: Invalid service {service}. Exiting")
sys.exit(1)
archgw_process = None
if service == "archgw" or service == "all":
archgw_process = multiprocessing.Process(
target=stream_gateway_logs, args=(follow,)
)
archgw_process.start()
model_server_process = None
if service == "model_server" or service == "all":
model_server_process = multiprocessing.Process(
target=stream_model_server_logs, args=(follow,)
)
model_server_process.start()
if archgw_process:
archgw_process.join()
if model_server_process:
model_server_process.join()
model_list = [
"katanemo/Arch-Guard-cpu",
"katanemo/Arch-Guard",
"katanemo/bge-large-en-v1.5",
]
@click.command()
def download_models():
"""Download required models from Hugging Face Hub in the cache directory"""
for model in model_list:
log.info(f"Downloading model: {model}")
snapshot_download(repo_id=model)
main.add_command(up)
main.add_command(down)
main.add_command(build)
main.add_command(logs)
main.add_command(download_models)
main.add_command(generate_prompt_targets)
if __name__ == "__main__":