Update arch_config and add tests for arch config file (#407)

This commit is contained in:
Adil Hafeez 2025-02-14 19:28:10 -08:00 committed by GitHub
parent d0a783cca8
commit e40b13be05
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
31 changed files with 379 additions and 212 deletions

View file

@ -104,7 +104,27 @@ def validate_and_render_schema():
arch_config_string = yaml.dump(config_yaml)
arch_llm_config_string = yaml.dump(config_yaml)
prompt_gateway_listener = config_yaml.get("listeners", {}).get(
"ingress_traffic", {}
)
if prompt_gateway_listener.get("port") == None:
prompt_gateway_listener["port"] = 10000 # default port for prompt gateway
if prompt_gateway_listener.get("address") == None:
prompt_gateway_listener["address"] = "127.0.0.1"
if prompt_gateway_listener.get("timeout") == None:
prompt_gateway_listener["timeout"] = "10s"
llm_gateway_listener = config_yaml.get("listeners", {}).get("egress_traffic", {})
if llm_gateway_listener.get("port") == None:
llm_gateway_listener["port"] = 12000 # default port for llm gateway
if llm_gateway_listener.get("address") == None:
llm_gateway_listener["address"] = "127.0.0.1"
if llm_gateway_listener.get("timeout") == None:
llm_gateway_listener["timeout"] = "10s"
data = {
"prompt_gateway_listener": prompt_gateway_listener,
"llm_gateway_listener": llm_gateway_listener,
"arch_config": arch_config_string,
"arch_llm_config": arch_llm_config_string,
"arch_clusters": inferred_clusters,

View file

@ -2,6 +2,8 @@ import subprocess
import os
import time
import sys
import yaml
from cli.utils import getLogger
from cli.consts import (
ARCHGW_DOCKER_NAME,
@ -22,6 +24,29 @@ from cli.docker_cli import (
log = getLogger(__name__)
def _get_gateway_ports(arch_config_file: str) -> tuple:
PROMPT_GATEWAY_DEFAULT_PORT = 10000
LLM_GATEWAY_DEFAULT_PORT = 12000
# parse arch_config_file yaml file and get prompt_gateway_port
arch_config_dict = {}
with open(arch_config_file) as f:
arch_config_dict = yaml.safe_load(f)
prompt_gateway_port = (
arch_config_dict.get("listeners", {})
.get("ingress_traffic", {})
.get("port", PROMPT_GATEWAY_DEFAULT_PORT)
)
llm_gateway_port = (
arch_config_dict.get("listeners", {})
.get("egress_traffic", {})
.get("port", LLM_GATEWAY_DEFAULT_PORT)
)
return prompt_gateway_port, llm_gateway_port
def start_arch(arch_config_file, env, log_timeout=120, foreground=False):
"""
Start Docker Compose in detached mode and stream logs until services are healthy.
@ -39,8 +64,14 @@ def start_arch(arch_config_file, env, log_timeout=120, foreground=False):
docker_stop_container(ARCHGW_DOCKER_NAME)
docker_remove_container(ARCHGW_DOCKER_NAME)
prompt_gateway_port, llm_gateway_port = _get_gateway_ports(arch_config_file)
return_code, _, archgw_stderr = docker_start_archgw_detached(
arch_config_file, os.path.expanduser("~/archgw_logs"), env
arch_config_file,
os.path.expanduser("~/archgw_logs"),
env,
prompt_gateway_port,
llm_gateway_port,
)
if return_code != 0:
log.info("Failed to start arch gateway: " + str(return_code))
@ -50,7 +81,7 @@ def start_arch(arch_config_file, env, log_timeout=120, foreground=False):
start_time = time.time()
while True:
health_check_status = health_check_endpoint(
"http://localhost:10000/healthz"
f"http://localhost:{prompt_gateway_port}/healthz"
)
archgw_status = docker_container_status(ARCHGW_DOCKER_NAME)
current_time = time.time()

View file

@ -1,7 +1,7 @@
import subprocess
import json
import sys
import requests # Add this import
import requests
from cli.consts import ARCHGW_DOCKER_IMAGE, ARCHGW_DOCKER_NAME
from cli.utils import getLogger
@ -33,11 +33,19 @@ def docker_remove_container(container: str) -> str:
def docker_start_archgw_detached(
arch_config_file: str, logs_path_abs: str, env: dict
arch_config_file: str,
logs_path_abs: str,
env: dict,
prompt_gateway_port,
llm_gateway_port,
) -> str:
env_args = [item for key, value in env.items() for item in ["-e", f"{key}={value}"]]
port_mappings = ["10000:10000", "12000:12000", "9901:19901"]
port_mappings = [
f"{prompt_gateway_port}:{prompt_gateway_port}",
f"{llm_gateway_port}:{llm_gateway_port}",
"9901:19901",
]
port_mappings_args = [item for port in port_mappings for item in ("-p", port)]
volume_mappings = [