Use intent model from archfc to pick prompt gateway (#328)

This commit is contained in:
Shuguang Chen 2024-12-20 13:25:01 -08:00 committed by GitHub
parent 67b8fd635e
commit ba7279becb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
151 changed files with 8642 additions and 10932 deletions

View file

@@ -1,3 +1,4 @@
import json
import os
from jinja2 import Environment, FileSystemLoader
import yaml
@@ -47,32 +48,27 @@ def validate_and_render_schema():
config_schema_yaml = yaml.safe_load(arch_config_schema)
inferred_clusters = {}
endpoints = config_yaml.get("endpoints", {})
# override the inferred clusters with the ones defined in the config
for name, endpoint_details in endpoints.items():
inferred_clusters[name] = endpoint_details
endpoint = inferred_clusters[name]["endpoint"]
if len(endpoint.split(":")) > 1:
inferred_clusters[name]["endpoint"] = endpoint.split(":")[0]
inferred_clusters[name]["port"] = int(endpoint.split(":")[1])
print("defined clusters from arch_config.yaml: ", json.dumps(inferred_clusters))
if "prompt_targets" in config_yaml:
for prompt_target in config_yaml["prompt_targets"]:
name = prompt_target.get("endpoint", {}).get("name", None)
if not name:
continue
if name not in inferred_clusters:
inferred_clusters[name] = {
"name": name,
"port": 80, # default port
}
endpoints = config_yaml.get("endpoints", {})
# override the inferred clusters with the ones defined in the config
for name, endpoint_details in endpoints.items():
if name in inferred_clusters:
print("updating cluster", endpoint_details)
inferred_clusters[name].update(endpoint_details)
endpoint = inferred_clusters[name]["endpoint"]
if len(endpoint.split(":")) > 1:
inferred_clusters[name]["endpoint"] = endpoint.split(":")[0]
inferred_clusters[name]["port"] = int(endpoint.split(":")[1])
else:
inferred_clusters[name] = endpoint_details
print("updated clusters", inferred_clusters)
raise Exception(
f"Unknown endpoint {name}, please add it in endpoints section in your arch_config.yaml file"
)
arch_llm_providers = config_yaml["llm_providers"]
arch_tracing = config_yaml.get("tracing", {})
@@ -90,6 +86,7 @@ def validate_and_render_schema():
rendered = template.render(data)
print(ENVOY_CONFIG_FILE_RENDERED)
print(rendered)
with open(ENVOY_CONFIG_FILE_RENDERED, "w") as file:
file.write(rendered)

View file

@@ -1,8 +1,6 @@
# Docker Hub repository the archgw gateway image is pulled from / published under.
KATANEMO_DOCKERHUB_REPO = "katanemo/archgw"
# Hugging Face repo ids fetched via snapshot_download before the model server starts
# (presumably the Arch-Guard safety models plus the bge embedding model — confirm
# against model_server's loading code).
KATANEMO_LOCAL_MODEL_LIST = [
    "katanemo/Arch-Guard-cpu",
    "katanemo/Arch-Guard",
    "katanemo/bge-large-en-v1.5",
]
# Canonical service names the CLI accepts for its --service option.
SERVICE_NAME_ARCHGW = "archgw"
SERVICE_NAME_MODEL_SERVER = "model_server"

View file

@@ -44,6 +44,7 @@ def start_archgw_docker(client, arch_config_file, env):
},
environment={
"OTEL_TRACING_HTTP_ENDPOINT": "http://host.docker.internal:4318/v1/traces",
"MODEL_SERVER_PORT": os.getenv("MODEL_SERVER_PORT", "51000"),
**env,
},
extra_hosts={"host.docker.internal": "host-gateway"},
@@ -78,25 +79,6 @@ def stream_gateway_logs(follow):
log.info(f"Failed to stream logs: {str(e)}")
def stream_model_server_logs(follow):
    """
    Stream the model server log file to the terminal.

    Args:
        follow: when True, keep the stream open (``tail -f``); otherwise
            print the current contents once and return.
    """
    expanded_path = os.path.expanduser(MODEL_SERVER_LOG_FILE)
    # tail inherits our stdout/stderr so output lands directly on the user's terminal.
    tail_cmd = ["tail", "-f", expanded_path] if follow else ["tail", expanded_path]
    subprocess.run(
        tail_cmd,
        check=True,
        stdout=sys.stdout,
        stderr=sys.stderr,
    )
def stream_access_logs(follow):
"""
Get the archgw access logs
@@ -117,7 +99,7 @@ def stream_access_logs(follow):
)
def start_arch(arch_config_file, env, log_timeout=120):
def start_arch(arch_config_file, env, log_timeout=120, foreground=False):
"""
Start Docker Compose in detached mode and stream logs until services are healthy.
@@ -130,6 +112,16 @@ def start_arch(arch_config_file, env, log_timeout=120):
try:
client = docker.from_env()
try:
container = client.containers.get("archgw")
log.info("archgw container found in docker, stopping and removing it")
# ensure that previous docker container is stopped and removed
container.stop()
container.remove()
log.info("Stopped and removed archgw container")
except docker.errors.NotFound as e:
pass
container = start_archgw_docker(client, arch_config_file, env)
start_time = time.time()
@@ -153,6 +145,13 @@ def start_arch(arch_config_file, env, log_timeout=120):
log.info(f"Container health status: {container_status}")
time.sleep(1)
if foreground:
for line in container.logs(stream=True):
print(line.decode("utf-8").strip("\n"))
except KeyboardInterrupt:
log.info("Keyboard interrupt received, stopping arch gateway service.")
stop_arch()
except docker.errors.APIError as e:
log.info(f"Failed to start Arch: {str(e)}")
@@ -186,17 +185,23 @@ def download_models_from_hf():
snapshot_download(repo_id=model)
def start_arch_modelserver():
def start_arch_modelserver(foreground):
    """
    Start the model server. This assumes that the archgw_modelserver package is
    installed locally.

    Args:
        foreground: when True, run the model server attached to the terminal
            (passes ``--foreground`` to archgw_modelserver); otherwise start it
            detached in the background.

    Exits the whole process with status 1 if the server fails to start.
    """
    try:
        log.info("archgw_modelserver restart")
        # NOTE(review): a "restart" immediately followed by a "start" below looks
        # like leftover from an earlier revision of this function — confirm both
        # invocations are intended before removing either.
        subprocess.run(
            ["archgw_modelserver", "restart"], check=True, start_new_session=True
        )
        log.info("Successfully ran model_server")
        # Build the start command once instead of duplicating the run() call
        # in both branches.
        start_cmd = ["archgw_modelserver", "start"]
        if foreground:
            start_cmd.append("--foreground")
        subprocess.run(start_cmd, check=True)
    except subprocess.CalledProcessError:
        # The subcommand's own logs carry the details; surface a pointer and bail.
        log.info("Failed to start model_server. Please check archgw_modelserver logs")
        sys.exit(1)
@@ -212,7 +217,6 @@ def stop_arch_modelserver():
["archgw_modelserver", "stop"],
check=True,
)
log.info("Successfully stopped the archgw model_server")
except subprocess.CalledProcessError as e:
log.info(f"Failed to start model_server. Please check archgw_modelserver logs")
sys.exit(1)

View file

@@ -16,10 +16,9 @@ from cli.core import (
stop_arch_modelserver,
start_arch,
stop_arch,
stream_gateway_logs,
stream_model_server_logs,
stream_access_logs,
download_models_from_hf,
stream_access_logs,
stream_gateway_logs,
)
from cli.consts import (
KATANEMO_DOCKERHUB_REPO,
@@ -138,16 +137,27 @@ def build(service):
default=SERVICE_ALL,
help="Service to start. Options are model_server, archgw.",
)
def up(file, path, service):
@click.option(
"--foreground",
default=False,
help="Run Arch in the foreground. Default is False",
is_flag=True,
)
def up(file, path, service, foreground):
"""Starts Arch."""
if service not in [SERVICE_NAME_ARCHGW, SERVICE_NAME_MODEL_SERVER, SERVICE_ALL]:
log.info(f"Error: Invalid service {service}. Exiting")
sys.exit(1)
if service == SERVICE_ALL and foreground:
# foreground can only be specified when starting individual services
log.info("foreground flag is only supported for individual services. Exiting.")
sys.exit(1)
if service == SERVICE_NAME_MODEL_SERVER:
log.info("Download archgw models from HuggingFace...")
download_models_from_hf()
start_arch_modelserver()
start_arch_modelserver(foreground)
return
if file:
@@ -214,12 +224,11 @@ def up(file, path, service):
env.update(env_stage)
if service == SERVICE_NAME_ARCHGW:
start_arch(arch_config_file, env)
start_arch(arch_config_file, env, foreground=foreground)
else:
# this will used the cached versions of the models, so its safe to use everytime.
download_models_from_hf()
start_arch_modelserver()
start_arch(arch_config_file, env)
start_arch_modelserver(foreground)
start_arch(arch_config_file, env, foreground=foreground)
@click.command()
@@ -267,65 +276,37 @@ def generate_prompt_targets(file):
@click.command()
@click.option(
"--service",
default=SERVICE_ALL,
help="Service to monitor. By default it will monitor both core gateway and model_server logs.",
)
@click.option(
"--debug",
help="For detailed debug logs to trace calls from archgw <> model_server <> api_server, etc",
is_flag=True,
)
@click.option("--follow", help="Follow the logs", is_flag=True)
def logs(service, debug, follow):
def logs(debug, follow):
"""Stream logs from access logs services."""
if service not in [SERVICE_NAME_ARCHGW, SERVICE_NAME_MODEL_SERVER, SERVICE_ALL]:
print(f"Error: Invalid service {service}. Exiting")
sys.exit(1)
if debug:
try:
archgw_process = None
if service == SERVICE_NAME_ARCHGW or service == SERVICE_ALL:
archgw_process = multiprocessing.Process(
target=stream_gateway_logs, args=(follow,)
)
archgw_process.start()
model_server_process = None
if service == SERVICE_NAME_MODEL_SERVER or service == SERVICE_ALL:
model_server_process = multiprocessing.Process(
target=stream_model_server_logs, args=(follow,)
)
model_server_process.start()
if archgw_process:
archgw_process.join()
if model_server_process:
model_server_process.join()
except KeyboardInterrupt:
log.info("KeyboardInterrupt detected. Exiting.")
if archgw_process and archgw_process.is_alive():
archgw_process.terminate()
if model_server_process and model_server_process.is_alive():
model_server_process.terminate()
else:
try:
archgw_access_logs_process = None
archgw_access_logs_process = multiprocessing.Process(
target=stream_access_logs, args=(follow,)
archgw_process = None
try:
if debug:
archgw_process = multiprocessing.Process(
target=stream_gateway_logs, args=(follow,)
)
archgw_access_logs_process.start()
archgw_process.start()
if archgw_access_logs_process:
archgw_access_logs_process.join()
except KeyboardInterrupt:
log.info("KeyboardInterrupt detected. Exiting.")
if archgw_access_logs_process.is_alive():
archgw_access_logs_process.terminate()
archgw_access_logs_process = multiprocessing.Process(
target=stream_access_logs, args=(follow,)
)
archgw_access_logs_process.start()
archgw_access_logs_process.join()
if archgw_process:
archgw_process.join()
except KeyboardInterrupt:
log.info("KeyboardInterrupt detected. Exiting.")
if archgw_access_logs_process.is_alive():
archgw_access_logs_process.terminate()
if archgw_process and archgw_process.is_alive():
archgw_process.terminate()
main.add_command(up)