mirror of
https://github.com/katanemo/plano.git
synced 2026-04-30 19:36:34 +02:00
Use intent model from archfc to pick prompt gateway (#328)
This commit is contained in:
parent
67b8fd635e
commit
ba7279becb
151 changed files with 8642 additions and 10932 deletions
|
|
@ -1,3 +1,4 @@
|
|||
import json
|
||||
import os
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
import yaml
|
||||
|
|
@ -47,32 +48,27 @@ def validate_and_render_schema():
|
|||
config_schema_yaml = yaml.safe_load(arch_config_schema)
|
||||
inferred_clusters = {}
|
||||
|
||||
endpoints = config_yaml.get("endpoints", {})
|
||||
|
||||
# override the inferred clusters with the ones defined in the config
|
||||
for name, endpoint_details in endpoints.items():
|
||||
inferred_clusters[name] = endpoint_details
|
||||
endpoint = inferred_clusters[name]["endpoint"]
|
||||
if len(endpoint.split(":")) > 1:
|
||||
inferred_clusters[name]["endpoint"] = endpoint.split(":")[0]
|
||||
inferred_clusters[name]["port"] = int(endpoint.split(":")[1])
|
||||
|
||||
print("defined clusters from arch_config.yaml: ", json.dumps(inferred_clusters))
|
||||
|
||||
if "prompt_targets" in config_yaml:
|
||||
for prompt_target in config_yaml["prompt_targets"]:
|
||||
name = prompt_target.get("endpoint", {}).get("name", None)
|
||||
if not name:
|
||||
continue
|
||||
if name not in inferred_clusters:
|
||||
inferred_clusters[name] = {
|
||||
"name": name,
|
||||
"port": 80, # default port
|
||||
}
|
||||
|
||||
endpoints = config_yaml.get("endpoints", {})
|
||||
|
||||
# override the inferred clusters with the ones defined in the config
|
||||
for name, endpoint_details in endpoints.items():
|
||||
if name in inferred_clusters:
|
||||
print("updating cluster", endpoint_details)
|
||||
inferred_clusters[name].update(endpoint_details)
|
||||
endpoint = inferred_clusters[name]["endpoint"]
|
||||
if len(endpoint.split(":")) > 1:
|
||||
inferred_clusters[name]["endpoint"] = endpoint.split(":")[0]
|
||||
inferred_clusters[name]["port"] = int(endpoint.split(":")[1])
|
||||
else:
|
||||
inferred_clusters[name] = endpoint_details
|
||||
|
||||
print("updated clusters", inferred_clusters)
|
||||
raise Exception(
|
||||
f"Unknown endpoint {name}, please add it in endpoints section in your arch_config.yaml file"
|
||||
)
|
||||
|
||||
arch_llm_providers = config_yaml["llm_providers"]
|
||||
arch_tracing = config_yaml.get("tracing", {})
|
||||
|
|
@ -90,6 +86,7 @@ def validate_and_render_schema():
|
|||
|
||||
rendered = template.render(data)
|
||||
print(ENVOY_CONFIG_FILE_RENDERED)
|
||||
print(rendered)
|
||||
with open(ENVOY_CONFIG_FILE_RENDERED, "w") as file:
|
||||
file.write(rendered)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,8 +1,6 @@
|
|||
KATANEMO_DOCKERHUB_REPO = "katanemo/archgw"
|
||||
KATANEMO_LOCAL_MODEL_LIST = [
|
||||
"katanemo/Arch-Guard-cpu",
|
||||
"katanemo/Arch-Guard",
|
||||
"katanemo/bge-large-en-v1.5",
|
||||
]
|
||||
SERVICE_NAME_ARCHGW = "archgw"
|
||||
SERVICE_NAME_MODEL_SERVER = "model_server"
|
||||
|
|
|
|||
|
|
@ -44,6 +44,7 @@ def start_archgw_docker(client, arch_config_file, env):
|
|||
},
|
||||
environment={
|
||||
"OTEL_TRACING_HTTP_ENDPOINT": "http://host.docker.internal:4318/v1/traces",
|
||||
"MODEL_SERVER_PORT": os.getenv("MODEL_SERVER_PORT", "51000"),
|
||||
**env,
|
||||
},
|
||||
extra_hosts={"host.docker.internal": "host-gateway"},
|
||||
|
|
@ -78,25 +79,6 @@ def stream_gateway_logs(follow):
|
|||
log.info(f"Failed to stream logs: {str(e)}")
|
||||
|
||||
|
||||
def stream_model_server_logs(follow):
|
||||
"""
|
||||
Get the model server logs, check if the user wants to follow/tail them.
|
||||
"""
|
||||
log_file_expanded = os.path.expanduser(MODEL_SERVER_LOG_FILE)
|
||||
|
||||
stream_command = ["tail"]
|
||||
if follow:
|
||||
stream_command.append("-f")
|
||||
|
||||
stream_command.append(log_file_expanded)
|
||||
subprocess.run(
|
||||
stream_command,
|
||||
check=True,
|
||||
stdout=sys.stdout,
|
||||
stderr=sys.stderr,
|
||||
)
|
||||
|
||||
|
||||
def stream_access_logs(follow):
|
||||
"""
|
||||
Get the archgw access logs
|
||||
|
|
@ -117,7 +99,7 @@ def stream_access_logs(follow):
|
|||
)
|
||||
|
||||
|
||||
def start_arch(arch_config_file, env, log_timeout=120):
|
||||
def start_arch(arch_config_file, env, log_timeout=120, foreground=False):
|
||||
"""
|
||||
Start Docker Compose in detached mode and stream logs until services are healthy.
|
||||
|
||||
|
|
@ -130,6 +112,16 @@ def start_arch(arch_config_file, env, log_timeout=120):
|
|||
try:
|
||||
client = docker.from_env()
|
||||
|
||||
try:
|
||||
container = client.containers.get("archgw")
|
||||
log.info("archgw container found in docker, stopping and removing it")
|
||||
# ensure that previous docker container is stopped and removed
|
||||
container.stop()
|
||||
container.remove()
|
||||
log.info("Stopped and removed archgw container")
|
||||
except docker.errors.NotFound as e:
|
||||
pass
|
||||
|
||||
container = start_archgw_docker(client, arch_config_file, env)
|
||||
|
||||
start_time = time.time()
|
||||
|
|
@ -153,6 +145,13 @@ def start_arch(arch_config_file, env, log_timeout=120):
|
|||
log.info(f"Container health status: {container_status}")
|
||||
time.sleep(1)
|
||||
|
||||
if foreground:
|
||||
for line in container.logs(stream=True):
|
||||
print(line.decode("utf-8").strip("\n"))
|
||||
|
||||
except KeyboardInterrupt:
|
||||
log.info("Keyboard interrupt received, stopping arch gateway service.")
|
||||
stop_arch()
|
||||
except docker.errors.APIError as e:
|
||||
log.info(f"Failed to start Arch: {str(e)}")
|
||||
|
||||
|
|
@ -186,17 +185,23 @@ def download_models_from_hf():
|
|||
snapshot_download(repo_id=model)
|
||||
|
||||
|
||||
def start_arch_modelserver():
|
||||
def start_arch_modelserver(foreground):
|
||||
"""
|
||||
Start the model server. This assumes that the archgw_modelserver package is installed locally
|
||||
|
||||
"""
|
||||
try:
|
||||
log.info("archgw_modelserver restart")
|
||||
subprocess.run(
|
||||
["archgw_modelserver", "restart"], check=True, start_new_session=True
|
||||
)
|
||||
log.info("Successfully ran model_server")
|
||||
if foreground:
|
||||
subprocess.run(
|
||||
["archgw_modelserver", "start", "--foreground"],
|
||||
check=True,
|
||||
)
|
||||
else:
|
||||
subprocess.run(
|
||||
["archgw_modelserver", "start"],
|
||||
check=True,
|
||||
)
|
||||
except subprocess.CalledProcessError as e:
|
||||
log.info(f"Failed to start model_server. Please check archgw_modelserver logs")
|
||||
sys.exit(1)
|
||||
|
|
@ -212,7 +217,6 @@ def stop_arch_modelserver():
|
|||
["archgw_modelserver", "stop"],
|
||||
check=True,
|
||||
)
|
||||
log.info("Successfully stopped the archgw model_server")
|
||||
except subprocess.CalledProcessError as e:
|
||||
log.info(f"Failed to start model_server. Please check archgw_modelserver logs")
|
||||
sys.exit(1)
|
||||
|
|
|
|||
|
|
@ -16,10 +16,9 @@ from cli.core import (
|
|||
stop_arch_modelserver,
|
||||
start_arch,
|
||||
stop_arch,
|
||||
stream_gateway_logs,
|
||||
stream_model_server_logs,
|
||||
stream_access_logs,
|
||||
download_models_from_hf,
|
||||
stream_access_logs,
|
||||
stream_gateway_logs,
|
||||
)
|
||||
from cli.consts import (
|
||||
KATANEMO_DOCKERHUB_REPO,
|
||||
|
|
@ -138,16 +137,27 @@ def build(service):
|
|||
default=SERVICE_ALL,
|
||||
help="Service to start. Options are model_server, archgw.",
|
||||
)
|
||||
def up(file, path, service):
|
||||
@click.option(
|
||||
"--foreground",
|
||||
default=False,
|
||||
help="Run Arch in the foreground. Default is False",
|
||||
is_flag=True,
|
||||
)
|
||||
def up(file, path, service, foreground):
|
||||
"""Starts Arch."""
|
||||
if service not in [SERVICE_NAME_ARCHGW, SERVICE_NAME_MODEL_SERVER, SERVICE_ALL]:
|
||||
log.info(f"Error: Invalid service {service}. Exiting")
|
||||
sys.exit(1)
|
||||
|
||||
if service == SERVICE_ALL and foreground:
|
||||
# foreground can only be specified when starting individual services
|
||||
log.info("foreground flag is only supported for individual services. Exiting.")
|
||||
sys.exit(1)
|
||||
|
||||
if service == SERVICE_NAME_MODEL_SERVER:
|
||||
log.info("Download archgw models from HuggingFace...")
|
||||
download_models_from_hf()
|
||||
start_arch_modelserver()
|
||||
start_arch_modelserver(foreground)
|
||||
return
|
||||
|
||||
if file:
|
||||
|
|
@ -214,12 +224,11 @@ def up(file, path, service):
|
|||
env.update(env_stage)
|
||||
|
||||
if service == SERVICE_NAME_ARCHGW:
|
||||
start_arch(arch_config_file, env)
|
||||
start_arch(arch_config_file, env, foreground=foreground)
|
||||
else:
|
||||
# this will used the cached versions of the models, so its safe to use everytime.
|
||||
download_models_from_hf()
|
||||
start_arch_modelserver()
|
||||
start_arch(arch_config_file, env)
|
||||
start_arch_modelserver(foreground)
|
||||
start_arch(arch_config_file, env, foreground=foreground)
|
||||
|
||||
|
||||
@click.command()
|
||||
|
|
@ -267,65 +276,37 @@ def generate_prompt_targets(file):
|
|||
|
||||
|
||||
@click.command()
|
||||
@click.option(
|
||||
"--service",
|
||||
default=SERVICE_ALL,
|
||||
help="Service to monitor. By default it will monitor both core gateway and model_server logs.",
|
||||
)
|
||||
@click.option(
|
||||
"--debug",
|
||||
help="For detailed debug logs to trace calls from archgw <> model_server <> api_server, etc",
|
||||
is_flag=True,
|
||||
)
|
||||
@click.option("--follow", help="Follow the logs", is_flag=True)
|
||||
def logs(service, debug, follow):
|
||||
def logs(debug, follow):
|
||||
"""Stream logs from access logs services."""
|
||||
|
||||
if service not in [SERVICE_NAME_ARCHGW, SERVICE_NAME_MODEL_SERVER, SERVICE_ALL]:
|
||||
print(f"Error: Invalid service {service}. Exiting")
|
||||
sys.exit(1)
|
||||
|
||||
if debug:
|
||||
try:
|
||||
archgw_process = None
|
||||
if service == SERVICE_NAME_ARCHGW or service == SERVICE_ALL:
|
||||
archgw_process = multiprocessing.Process(
|
||||
target=stream_gateway_logs, args=(follow,)
|
||||
)
|
||||
archgw_process.start()
|
||||
|
||||
model_server_process = None
|
||||
if service == SERVICE_NAME_MODEL_SERVER or service == SERVICE_ALL:
|
||||
model_server_process = multiprocessing.Process(
|
||||
target=stream_model_server_logs, args=(follow,)
|
||||
)
|
||||
model_server_process.start()
|
||||
|
||||
if archgw_process:
|
||||
archgw_process.join()
|
||||
if model_server_process:
|
||||
model_server_process.join()
|
||||
except KeyboardInterrupt:
|
||||
log.info("KeyboardInterrupt detected. Exiting.")
|
||||
if archgw_process and archgw_process.is_alive():
|
||||
archgw_process.terminate()
|
||||
|
||||
if model_server_process and model_server_process.is_alive():
|
||||
model_server_process.terminate()
|
||||
else:
|
||||
try:
|
||||
archgw_access_logs_process = None
|
||||
archgw_access_logs_process = multiprocessing.Process(
|
||||
target=stream_access_logs, args=(follow,)
|
||||
archgw_process = None
|
||||
try:
|
||||
if debug:
|
||||
archgw_process = multiprocessing.Process(
|
||||
target=stream_gateway_logs, args=(follow,)
|
||||
)
|
||||
archgw_access_logs_process.start()
|
||||
archgw_process.start()
|
||||
|
||||
if archgw_access_logs_process:
|
||||
archgw_access_logs_process.join()
|
||||
except KeyboardInterrupt:
|
||||
log.info("KeyboardInterrupt detected. Exiting.")
|
||||
if archgw_access_logs_process.is_alive():
|
||||
archgw_access_logs_process.terminate()
|
||||
archgw_access_logs_process = multiprocessing.Process(
|
||||
target=stream_access_logs, args=(follow,)
|
||||
)
|
||||
archgw_access_logs_process.start()
|
||||
archgw_access_logs_process.join()
|
||||
|
||||
if archgw_process:
|
||||
archgw_process.join()
|
||||
except KeyboardInterrupt:
|
||||
log.info("KeyboardInterrupt detected. Exiting.")
|
||||
if archgw_access_logs_process.is_alive():
|
||||
archgw_access_logs_process.terminate()
|
||||
if archgw_process and archgw_process.is_alive():
|
||||
archgw_process.terminate()
|
||||
|
||||
|
||||
main.add_command(up)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue