From 0e2f53426a12db062b868c45dfb4aad47cd3eba0 Mon Sep 17 00:00:00 2001 From: Adil Hafeez Date: Wed, 14 May 2025 17:15:42 -0700 Subject: [PATCH] fix rust tests --- arch/tools/cli/consts.py | 5 ++ arch/tools/cli/core.py | 50 +++++++++++++++--- arch/tools/cli/docker_cli.py | 53 ++++++++++++++++++-- arch/tools/cli/main.py | 64 ++++++++++++++++++++---- crates/brightstaff/src/main.rs | 14 ++++-- crates/llm_gateway/src/stream_context.rs | 4 +- crates/llm_gateway/tests/integration.rs | 1 - 7 files changed, 165 insertions(+), 26 deletions(-) diff --git a/arch/tools/cli/consts.py b/arch/tools/cli/consts.py index 23694fe4..47d638c2 100644 --- a/arch/tools/cli/consts.py +++ b/arch/tools/cli/consts.py @@ -6,9 +6,14 @@ KATANEMO_LOCAL_MODEL_LIST = [ "katanemo/Arch-Guard", ] SERVICE_NAME_ARCHGW = "archgw" +SERVICE_NAME_BRIGHTSTAFF = "brightstaff" SERVICE_NAME_MODEL_SERVER = "model_server" SERVICE_ALL = "all" MODEL_SERVER_LOG_FILE = "~/archgw_logs/modelserver.log" ACCESS_LOG_FILES = "~/archgw_logs/access*" ARCHGW_DOCKER_NAME = "archgw" +BRIGHTSTAFF_DOCKER_NAME = "brightstaff" ARCHGW_DOCKER_IMAGE = os.getenv("ARCHGW_DOCKER_IMAGE", "katanemo/archgw:0.2.8") +BRIGHTSTAFF_DOCKER_IMAGE = os.getenv( + "BRIGHTSTAFF_DOCKER_IMAGE", "katanemo/archgw:brightstaff_0.2.8" +) diff --git a/arch/tools/cli/core.py b/arch/tools/cli/core.py index b0a6e58c..c32d16c2 100644 --- a/arch/tools/cli/core.py +++ b/arch/tools/cli/core.py @@ -7,6 +7,7 @@ import yaml from cli.utils import getLogger from cli.consts import ( ARCHGW_DOCKER_NAME, + BRIGHTSTAFF_DOCKER_NAME, KATANEMO_LOCAL_MODEL_LIST, ) from huggingface_hub import snapshot_download @@ -15,6 +16,7 @@ from cli.docker_cli import ( docker_container_status, docker_remove_container, docker_start_archgw_detached, + docker_start_brightstaff_detached, docker_stop_container, health_check_endpoint, stream_gateway_logs, @@ -109,27 +111,63 @@ def start_arch(arch_config_file, env, log_timeout=120, foreground=False): except KeyboardInterrupt: log.info("Keyboard interrupt received, stopping arch gateway service.") - stop_arch() + stop_docker_container() -def stop_arch(): +def start_brightstaff(arch_config_file, env, log_timeout=120, foreground=False): + """ + Start Docker Compose in detached mode and stream logs until services are healthy. + + Args: + path (str): The path where the prompt_config.yml file is located. + log_timeout (int): Time in seconds to show logs before checking for healthy state. + """ + log.info("Starting brightstaff") + + try: + brightstaff_container_status = docker_container_status(BRIGHTSTAFF_DOCKER_NAME) + if brightstaff_container_status != "not found": + log.info( + f"brightstaff found in docker, stopping and removing it: status: {brightstaff_container_status}" + ) + docker_stop_container(BRIGHTSTAFF_DOCKER_NAME) + docker_remove_container(BRIGHTSTAFF_DOCKER_NAME) + + return_code, _, brightstaff_stderr = docker_start_brightstaff_detached( + arch_config_file, + env, + ) + if return_code != 0: + log.info("Failed to start brightstaff: " + str(return_code)) + log.info("stderr: " + brightstaff_stderr) + sys.exit(1) + + if foreground: + stream_gateway_logs(follow=True, service="brightstaff") + + except KeyboardInterrupt: + log.info("Keyboard interrupt received, stopping arch gateway service.") + stop_docker_container(service=BRIGHTSTAFF_DOCKER_NAME) + + +def stop_docker_container(service=ARCHGW_DOCKER_NAME): """ Shutdown all Docker Compose services by running `docker-compose down`. Args: path (str): The path where the docker-compose.yml file is located. """ - log.info("Shutting down arch gateway service.") + log.info(f"Shutting down {service} service.") try: subprocess.run( - ["docker", "stop", ARCHGW_DOCKER_NAME], + ["docker", "stop", service], ) subprocess.run( - ["docker", "rm", ARCHGW_DOCKER_NAME], + ["docker", "rm", service], ) - log.info("Successfully shut down arch gateway service.") + log.info(f"Successfully shut down {service} service.") except subprocess.CalledProcessError as e: log.info(f"Failed to shut down services: {str(e)}") diff --git a/arch/tools/cli/docker_cli.py b/arch/tools/cli/docker_cli.py index d12354eb..b17c498c 100644 --- a/arch/tools/cli/docker_cli.py +++ b/arch/tools/cli/docker_cli.py @@ -3,7 +3,12 @@ import json import sys import requests -from cli.consts import ARCHGW_DOCKER_IMAGE, ARCHGW_DOCKER_NAME +from cli.consts import ( + ARCHGW_DOCKER_IMAGE, + ARCHGW_DOCKER_NAME, + BRIGHTSTAFF_DOCKER_IMAGE, + BRIGHTSTAFF_DOCKER_NAME, +) from cli.utils import getLogger log = getLogger(__name__) @@ -81,6 +86,48 @@ def docker_start_archgw_detached( return result.returncode, result.stdout, result.stderr +def docker_start_brightstaff_detached( + arch_config_file: str, + env: dict, +) -> str: + env_args = [item for key, value in env.items() for item in ["-e", f"{key}={value}"]] + + port_mappings = [ + "9091:9091", + ] + port_mappings_args = [item for port in port_mappings for item in ("-p", port)] + + volume_mappings = [ + f"{arch_config_file}:/app/arch_config.yaml:ro", + ] + volume_mappings_args = [ + item for volume in volume_mappings for item in ("-v", volume) + ] + + llm_provider_endpoint = env.get( + "LLM_PROVIDER_ENDPOINT", "http://host.docker.internal:12000/v1/chat/completions" + ) + + options = [ + "docker", + "run", + "-d", + "--name", + BRIGHTSTAFF_DOCKER_NAME, + *port_mappings_args, + *volume_mappings_args, + *env_args, + "-e", + f"LLM_PROVIDER_ENDPOINT={llm_provider_endpoint}", + "--add-host", + "host.docker.internal:host-gateway", + BRIGHTSTAFF_DOCKER_IMAGE, + ] + + result = subprocess.run(options, capture_output=True, text=True, check=False) + return result.returncode, result.stdout, result.stderr + + def health_check_endpoint(endpoint: str) -> bool: try: response = requests.get(endpoint) @@ -91,7 +138,7 @@ def health_check_endpoint(endpoint: str) -> bool: return False -def stream_gateway_logs(follow): +def stream_gateway_logs(follow, service="archgw"): """ Stream logs from the arch gateway service. """ @@ -100,7 +147,7 @@ def stream_gateway_logs(follow): options = ["docker", "logs"] if follow: options.append("-f") - options.append(ARCHGW_DOCKER_NAME) + options.append(service) try: # Run `docker-compose logs` to stream logs from the gateway service subprocess.run( diff --git a/arch/tools/cli/main.py b/arch/tools/cli/main.py index 6541b51a..b86e1d03 100644 --- a/arch/tools/cli/main.py +++ b/arch/tools/cli/main.py @@ -14,15 +14,18 @@ from cli.utils import ( ) from cli.core import ( start_arch_modelserver, + start_brightstaff, stop_arch_modelserver, start_arch, - stop_arch, + stop_docker_container, download_models_from_hf, ) from cli.consts import ( ARCHGW_DOCKER_IMAGE, + BRIGHTSTAFF_DOCKER_IMAGE, KATANEMO_DOCKERHUB_REPO, SERVICE_NAME_ARCHGW, + SERVICE_NAME_BRIGHTSTAFF, SERVICE_NAME_MODEL_SERVER, SERVICE_ALL, ) @@ -40,6 +43,7 @@ logo = r""" # Command to build archgw and model_server Docker images ARCHGW_DOCKERFILE = "./arch/Dockerfile" +BRIGHTSTAFF_DOCKERFILE = "./arch/Dockerfile.brightstaff" MODEL_SERVER_BUILD_FILE = "./model_server/pyproject.toml" @@ -51,6 +55,19 @@ def get_version(): return "version not found" +def verify_service_name(service): + """Verify if the service name is valid.""" + if service not in [ + SERVICE_NAME_ARCHGW, + SERVICE_NAME_MODEL_SERVER, + SERVICE_NAME_BRIGHTSTAFF, + SERVICE_ALL, + ]: + print(f"Error: Invalid service {service}. Exiting") + sys.exit(1) + return True + + @click.group(invoke_without_command=True) @click.option("--version", is_flag=True, help="Show the archgw cli version and exit.") @click.pass_context @@ -75,9 +92,8 @@ def main(ctx, version): ) def build(service): """Build Arch from source. Must be in root of cloned repo.""" - if service not in [SERVICE_NAME_ARCHGW, SERVICE_NAME_MODEL_SERVER, SERVICE_ALL]: - print(f"Error: Invalid service {service}. Exiting") - sys.exit(1) + verify_service_name(service) + # Check if /arch/Dockerfile exists if service == SERVICE_NAME_ARCHGW or service == SERVICE_ALL: if os.path.exists(ARCHGW_DOCKERFILE): @@ -108,6 +124,35 @@ def build(service): click.echo("archgw image built successfully.") + if service == SERVICE_NAME_BRIGHTSTAFF or service == SERVICE_ALL: + if os.path.exists(BRIGHTSTAFF_DOCKERFILE): + click.echo("Building brightstaff image...") + try: + subprocess.run( + [ + "docker", + "build", + "-f", + BRIGHTSTAFF_DOCKERFILE, + "-t", + f"{KATANEMO_DOCKERHUB_REPO}:brightstaff_latest", + "-t", + f"{BRIGHTSTAFF_DOCKER_IMAGE}", + ".", + "--add-host=host.docker.internal:host-gateway", + ], + check=True, + ) + click.echo("brightstaff image built successfully.") + except subprocess.CalledProcessError as e: + click.echo(f"Error building brightstaff image: {e}") + sys.exit(1) + else: + click.echo("Error: Dockerfile not found in /arch") + sys.exit(1) + + click.echo("brightstaff image built successfully.") + """Install the model server dependencies using Poetry.""" if service == SERVICE_NAME_MODEL_SERVER or service == SERVICE_ALL: # Check if pyproject.toml exists @@ -146,9 +191,7 @@ def build(service): ) def up(file, path, service, foreground): """Starts Arch.""" - if service not in [SERVICE_NAME_ARCHGW, SERVICE_NAME_MODEL_SERVER, SERVICE_ALL]: - log.info(f"Error: Invalid service {service}. Exiting") - sys.exit(1) + verify_service_name(service) if service == SERVICE_ALL and foreground: # foreground can only be specified when starting individual services @@ -233,10 +276,13 @@ def up(file, path, service, foreground): if service == SERVICE_NAME_ARCHGW: start_arch(arch_config_file, env, foreground=foreground) + if service == SERVICE_NAME_BRIGHTSTAFF: + start_brightstaff(arch_config_file, env, foreground=foreground) else: download_models_from_hf() start_arch_modelserver(foreground) start_arch(arch_config_file, env, foreground=foreground) + start_brightstaff(arch_config_file, env, foreground=foreground) @click.command() @@ -255,10 +301,10 @@ def down(service): if service == SERVICE_NAME_MODEL_SERVER: stop_arch_modelserver() elif service == SERVICE_NAME_ARCHGW: - stop_arch() + stop_docker_container() else: stop_arch_modelserver() - stop_arch() + stop_docker_container() @click.command() diff --git a/crates/brightstaff/src/main.rs b/crates/brightstaff/src/main.rs index 9bd51b93..02ed0909 100644 --- a/crates/brightstaff/src/main.rs +++ b/crates/brightstaff/src/main.rs @@ -72,7 +72,7 @@ async fn main() -> Result<(), Box> { //loading arch_config.yaml file let arch_config_path = - env::var("ARCH_CONFIG_PATH").unwrap_or_else(|_| "arch_config.yaml".to_string()); + env::var("ARCH_CONFIG_PATH").unwrap_or_else(|_| "./arch_config.yaml".to_string()); info!("Loading arch_config.yaml from {}", arch_config_path); let config_contents = @@ -88,14 +88,16 @@ async fn main() -> Result<(), Box> { shorten_string(&serde_json::to_string(arch_config.as_ref()).unwrap()) ); + let llm_provider_endpoint = env::var("LLM_PROVIDER_ENDPOINT") + .unwrap_or_else(|_| "http://localhost:12000/v1/chat/completions".to_string()); + + info!("llm provider endpoint: {}", llm_provider_endpoint); info!("Listening on http://{}", bind_address); let listener = TcpListener::bind(bind_address).await?; - let llm_provider_endpoint = "http://localhost:12000/v1/chat/completions"; - let router_service: Arc = Arc::new(RouterService::new( arch_config.llm_providers.clone(), - llm_provider_endpoint.to_string(), + llm_provider_endpoint.clone(), arch_config.routing.as_ref().unwrap().model.clone(), )); @@ -105,6 +107,7 @@ async fn main() -> Result<(), Box> { let io = TokioIo::new(stream); let router_service = Arc::clone(&router_service); + let llm_provider_endpoint = llm_provider_endpoint.clone(); let service = service_fn(move |req| { let router_service = Arc::clone(&router_service); @@ -115,11 +118,12 @@ async fn main() -> Result<(), Box> { .span_builder("router_service") .with_kind(SpanKind::Server) .start_with_context(tracer, &parent_cx); + let llm_provider_endpoint = llm_provider_endpoint.clone(); async move { match (req.method(), req.uri().path()) { (&Method::POST, "/v1/chat/completions") => { - chat_completion(req, router_service, llm_provider_endpoint.to_string()) + chat_completion(req, router_service, llm_provider_endpoint) .with_context(parent_cx) .await } diff --git a/crates/llm_gateway/src/stream_context.rs b/crates/llm_gateway/src/stream_context.rs index 10192984..aa77dd34 100644 --- a/crates/llm_gateway/src/stream_context.rs +++ b/crates/llm_gateway/src/stream_context.rs @@ -327,10 +327,10 @@ impl HttpContext for StreamContext { let model_requested = deserialized_body.model.clone(); info!( - "on_http_request_body: provider: {}, model requested: {}, model selected: {:?}", + "on_http_request_body: provider: {}, model requested: {}, model selected: {}", self.llm_provider().name, model_requested, - self.llm_provider().model, + self.llm_provider().model.as_ref().unwrap_or(&String::new()) ); deserialized_body.model = self.llm_provider().model.clone().unwrap(); diff --git a/crates/llm_gateway/tests/integration.rs b/crates/llm_gateway/tests/integration.rs index 0f94d2d8..ccd4bb4c 100644 --- a/crates/llm_gateway/tests/integration.rs +++ b/crates/llm_gateway/tests/integration.rs @@ -489,7 +489,6 @@ fn llm_gateway_override_model_name() { .expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) .expect_metric_record("input_sequence_length", 29) .expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Debug), None)