fix cli models and logs (#196)

* removing unnecessary setup.py files

* updated the cli for debug and access logs

* ran the pre-commit locally to fix pull request

* fixed bug where if archgw_process is None we didn't handle it gracefully

* Apply suggestions from code review

Co-authored-by: Adil Hafeez <adil@katanemo.com>

* fixed changes based on PR

* fixed version not found message

* fixed message based on PR feedback

* adding poetry lock

* fixed pre-commit

---------

Co-authored-by: Salman Paracha <salmanparacha@MacBook-Pro-261.local>
Co-authored-by: Adil Hafeez <adil@katanemo.com>
This commit is contained in:
Salman Paracha 2024-10-18 12:09:45 -07:00 committed by GitHub
parent 6cd05572c4
commit 6fb63510b3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 362 additions and 274 deletions

View file

@ -3,22 +3,28 @@ import os
import pkg_resources
import sys
import subprocess
import multiprocessing
import importlib.metadata
from cli import targets
from cli import config_generator
from cli.utils import getLogger, get_llm_provider_access_keys, load_env_file_to_dict
from cli.core import (
start_arch_modelserver,
stop_arch_modelserver,
start_arch,
stop_arch,
stream_gateway_logs,
stream_model_server_logs,
stream_access_logs,
download_models_from_hf,
)
from cli.consts import (
KATANEMO_DOCKERHUB_REPO,
KATANEMO_LOCAL_MODEL_LIST,
SERVICE_NAME_ARCHGW,
SERVICE_NAME_MODEL_SERVER,
SERVICE_ALL,
)
from cli.utils import get_llm_provider_access_keys, load_env_file_to_dict
from cli.consts import KATANEMO_DOCKERHUB_REPO
from cli.utils import getLogger
import multiprocessing
from huggingface_hub import snapshot_download
import joblib
log = getLogger(__name__)
@ -31,34 +37,46 @@ logo = r"""
"""
# Command to build archgw and model_server Docker images
ARCHGW_DOCKERFILE = "./arch/Dockerfile"
MODEL_SERVER_BUILD_FILE = "./model_server/pyproject.toml"
def get_version():
    """Return the installed version string of the ``archgw`` distribution.

    Returns:
        The version reported by ``importlib.metadata`` for ``archgw``, or the
        literal string ``"version not found"`` when no such distribution is
        installed.
    """
    distribution_name = "archgw"
    try:
        installed_version = importlib.metadata.version(distribution_name)
    except importlib.metadata.PackageNotFoundError:
        # Package metadata is missing; report a placeholder instead of raising.
        return "version not found"
    return installed_version
# Root command group of the archgw CLI.
#
# Parameters:
#   ctx     -- the click context, injected by @click.pass_context.
#   version -- True when the --version flag was given.
#
# Behaviour:
#   * with --version, prints the CLI version and exits through ctx.exit();
#   * with no subcommand, prints the banner line, the ASCII logo and the
#     group's help text;
#   * otherwise falls through so click dispatches to the subcommand.
#
# NOTE: this is documented with comments rather than a docstring on purpose:
# click uses a command's docstring as its help text, so adding one would
# change what `ctx.get_help()` prints.
@click.group(invoke_without_command=True)
@click.option("--version", is_flag=True, help="Show the archgw cli version and exit.")
@click.pass_context
def main(ctx, version):
    if version:
        click.echo(f"archgw cli version: {get_version()}")
        ctx.exit()

    # invoke_without_command=True means this callback also runs for a bare
    # invocation; show the banner and usage in that case.
    if ctx.invoked_subcommand is None:
        click.echo("""Arch (The Intelligent Prompt Gateway) CLI""")
        click.echo(logo)
        click.echo(ctx.get_help())
# Command to build archgw and model_server Docker images
ARCHGW_DOCKERFILE = "./arch/Dockerfile"
MODEL_SERVER_BUILD_FILE = "./model_server/pyproject.toml"
@click.command()
@click.option(
"--service",
default="all",
default=SERVICE_ALL,
help="Optioanl parameter to specify which service to build. Options are model_server, archgw",
)
def build(service):
"""Build Arch from source. Must be in root of cloned repo."""
if service not in ["model_server", "archgw", "all"]:
if service not in [SERVICE_NAME_ARCHGW, SERVICE_NAME_MODEL_SERVER, SERVICE_ALL]:
print(f"Error: Invalid service {service}. Exiting")
sys.exit(1)
# Check if /arch/Dockerfile exists
if service == "archgw" or service == "all":
if service == SERVICE_NAME_ARCHGW or service == SERVICE_ALL:
if os.path.exists(ARCHGW_DOCKERFILE):
click.echo("Building archgw image...")
try:
@ -86,7 +104,7 @@ def build(service):
click.echo("archgw image built successfully.")
"""Install the model server dependencies using Poetry."""
if service == "model_server" or service == "all":
if service == SERVICE_NAME_MODEL_SERVER or service == SERVICE_ALL:
# Check if pyproject.toml exists
if os.path.exists(MODEL_SERVER_BUILD_FILE):
click.echo("Installing model server dependencies with Poetry...")
@ -112,16 +130,18 @@ def build(service):
)
@click.option(
"--service",
default="all",
default=SERVICE_ALL,
help="Service to start. Options are model_server, archgw.",
)
def up(file, path, service):
"""Starts Arch."""
if service not in ["all", "model_server", "archgw"]:
print(f"Error: Invalid service {service}. Exiting")
if service not in [SERVICE_NAME_ARCHGW, SERVICE_NAME_MODEL_SERVER, SERVICE_ALL]:
log.info(f"Error: Invalid service {service}. Exiting")
sys.exit(1)
if service == "model_server":
if service == SERVICE_NAME_MODEL_SERVER:
log.info("Download archgw models from HuggingFace...")
download_models_from_hf()
start_arch_modelserver()
return
@ -134,10 +154,10 @@ def up(file, path, service):
# Check if the file exists
if not os.path.exists(arch_config_file):
print(f"Error: {arch_config_file} does not exist.")
log.info(f"Error: {arch_config_file} does not exist.")
return
print(f"Validating {arch_config_file}")
log.info(f"Validating {arch_config_file}")
arch_schema_config = pkg_resources.resource_filename(
__name__, "../config/arch_config_schema.yaml"
)
@ -148,7 +168,7 @@ def up(file, path, service):
arch_config_schema_file=arch_schema_config,
)
except Exception as e:
print(f"Exiting archgw up: {e}")
log.info(f"Exiting archgw up: {e}")
sys.exit(1)
log.info("Starging arch model server and arch gateway")
@ -171,7 +191,7 @@ def up(file, path, service):
): # check to see if the environment variables in the current environment or not
for access_key in access_keys:
if env.get(access_key) is None:
print(f"Access Key: {access_key} not found. Exiting Start")
log.info(f"Access Key: {access_key} not found. Exiting Start")
sys.exit(1)
else:
env_stage[access_key] = env.get(access_key)
@ -179,7 +199,7 @@ def up(file, path, service):
env_file_dict = load_env_file_to_dict(app_env_file)
for access_key in access_keys:
if env_file_dict.get(access_key) is None:
print(f"Access Key: {access_key} not found. Exiting Start")
log.info(f"Access Key: {access_key} not found. Exiting Start")
sys.exit(1)
else:
env_stage[access_key] = env_file_dict[access_key]
@ -193,9 +213,11 @@ def up(file, path, service):
env.update(env_stage)
env["ARCH_CONFIG_FILE"] = arch_config_file
if service == "archgw":
if service == SERVICE_NAME_ARCHGW:
start_arch(arch_config_file, env)
else:
# this will used the cached versions of the models, so its safe to use everytime.
download_models_from_hf()
start_arch_modelserver()
start_arch(arch_config_file, env)
@ -203,18 +225,19 @@ def up(file, path, service):
@click.command()
@click.option(
"--service",
default="all",
default=SERVICE_ALL,
help="Service to down. Options are all, model_server, archgw. Default is all",
)
def down(service):
"""Stops Arch."""
if service not in ["all", "model_server", "archgw"]:
print(f"Error: Invalid service {service}. Exiting")
if service not in [SERVICE_NAME_ARCHGW, SERVICE_NAME_MODEL_SERVER, SERVICE_ALL]:
log.info(f"Error: Invalid service {service}. Exiting")
sys.exit(1)
if service == "model_server":
if service == SERVICE_NAME_MODEL_SERVER:
stop_arch_modelserver()
elif service == "archgw":
elif service == SERVICE_NAME_ARCHGW:
stop_arch()
else:
stop_arch_modelserver()
@ -243,74 +266,72 @@ def generate_prompt_targets(file):
targets.generate_prompt_targets(file)
def stream_model_server_logs(follow):
    """Print the model server log file by running ``tail`` on it.

    Args:
        follow: When truthy, ``-f`` is passed to ``tail`` so it keeps
            streaming new lines; otherwise ``tail`` is run without it.

    Raises:
        subprocess.CalledProcessError: If ``tail`` exits with a non-zero
            status (the call uses ``check=True``).
    """
    # The log lives under the user's home directory; expand "~" ourselves
    # because the command is run without a shell.
    expanded_path = os.path.expanduser("~/archgw_logs/modelserver.log")

    if follow:
        tail_command = ["tail", "-f", expanded_path]
    else:
        tail_command = ["tail", expanded_path]

    # Forward the child's output to this process's own stdout/stderr.
    subprocess.run(
        tail_command,
        check=True,
        stdout=sys.stdout,
        stderr=sys.stderr,
    )
def _run_log_streamers(streamers, follow):
    """Run each log streamer in its own process until all exit or Ctrl-C.

    Args:
        streamers: Callables taking a single ``follow`` argument; each one is
            used as the target of a separate ``multiprocessing.Process``.
        follow: Value forwarded to every streamer.

    On ``KeyboardInterrupt`` the interruption is logged and every process
    that is still alive is terminated.
    """
    processes = []
    try:
        for streamer in streamers:
            process = multiprocessing.Process(target=streamer, args=(follow,))
            # Track the process before starting it so that an interrupt
            # arriving during start() still finds it in the cleanup loop.
            processes.append(process)
            process.start()

        for process in processes:
            process.join()
    except KeyboardInterrupt:
        log.info("KeyboardInterrupt detected. Exiting.")
        for process in processes:
            if process.is_alive():
                process.terminate()


@click.command()
@click.option(
    "--service",
    default=SERVICE_ALL,
    help="Service to monitor. By default it will monitor both core gateway and model_server logs.",
)
@click.option(
    "--debug",
    help="For detailed debug logs to trace calls from archgw <> model_server <> api_server, etc",
    is_flag=True,
)
@click.option("--follow", help="Follow the logs", is_flag=True)
def logs(service, debug, follow):
    """Stream logs from access logs services."""
    # Parameters:
    #   service -- which service's logs to stream; must be one of
    #              SERVICE_NAME_ARCHGW, SERVICE_NAME_MODEL_SERVER, SERVICE_ALL.
    #   debug   -- when set, stream the gateway and/or model server logs
    #              (selected by `service`); otherwise stream the access logs.
    #   follow  -- forwarded unchanged to each streaming function.
    # An invalid `service` prints an error and exits with status 1.
    if service not in [SERVICE_NAME_ARCHGW, SERVICE_NAME_MODEL_SERVER, SERVICE_ALL]:
        print(f"Error: Invalid service {service}. Exiting")
        sys.exit(1)

    if debug:
        streamers = []
        if service in (SERVICE_NAME_ARCHGW, SERVICE_ALL):
            streamers.append(stream_gateway_logs)
        if service in (SERVICE_NAME_MODEL_SERVER, SERVICE_ALL):
            streamers.append(stream_model_server_logs)
    else:
        # Without --debug only the access logs are streamed; `service` is
        # validated above but does not change what is shown here.
        streamers = [stream_access_logs]

    _run_log_streamers(streamers, follow)


# Hugging Face repositories fetched by the `download_models` command below.
model_list = [
    "katanemo/Arch-Guard-cpu",
    "katanemo/Arch-Guard",
    "katanemo/bge-large-en-v1.5",
]


@click.command()
def download_models():
    """Download required models from Hugging Face Hub in the cache directory"""
    for model in model_list:
        log.info(f"Downloading model: {model}")
        snapshot_download(repo_id=model)
main.add_command(up)
main.add_command(down)
main.add_command(build)
main.add_command(logs)
main.add_command(download_models)
main.add_command(generate_prompt_targets)
if __name__ == "__main__":