mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
fix rust tests
This commit is contained in:
parent
f60cac27f4
commit
0e2f53426a
7 changed files with 165 additions and 26 deletions
|
|
@ -6,9 +6,14 @@ KATANEMO_LOCAL_MODEL_LIST = [
|
|||
"katanemo/Arch-Guard",
|
||||
]
|
||||
SERVICE_NAME_ARCHGW = "archgw"
|
||||
SERVICE_NAME_BRIGHTSTAFF = "brightstaff"
|
||||
SERVICE_NAME_MODEL_SERVER = "model_server"
|
||||
SERVICE_ALL = "all"
|
||||
MODEL_SERVER_LOG_FILE = "~/archgw_logs/modelserver.log"
|
||||
ACCESS_LOG_FILES = "~/archgw_logs/access*"
|
||||
ARCHGW_DOCKER_NAME = "archgw"
|
||||
BRIGHTSTAFF_DOCKER_NAME = "brightstaff"
|
||||
ARCHGW_DOCKER_IMAGE = os.getenv("ARCHGW_DOCKER_IMAGE", "katanemo/archgw:0.2.8")
|
||||
BRIGHTSTAFF_DOCKER_IMAGE = os.getenv(
|
||||
"BRIGHTSTAFF_DOCKER_IMAGE", "katanemo/archgw:brightstaff_0.2.8"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ import yaml
|
|||
from cli.utils import getLogger
|
||||
from cli.consts import (
|
||||
ARCHGW_DOCKER_NAME,
|
||||
BRIGHTSTAFF_DOCKER_NAME,
|
||||
KATANEMO_LOCAL_MODEL_LIST,
|
||||
)
|
||||
from huggingface_hub import snapshot_download
|
||||
|
|
@ -15,6 +16,7 @@ from cli.docker_cli import (
|
|||
docker_container_status,
|
||||
docker_remove_container,
|
||||
docker_start_archgw_detached,
|
||||
docker_start_brightstaff_detached,
|
||||
docker_stop_container,
|
||||
health_check_endpoint,
|
||||
stream_gateway_logs,
|
||||
|
|
@ -109,27 +111,63 @@ def start_arch(arch_config_file, env, log_timeout=120, foreground=False):
|
|||
|
||||
except KeyboardInterrupt:
|
||||
log.info("Keyboard interrupt received, stopping arch gateway service.")
|
||||
stop_arch()
|
||||
stop_docker_container()
|
||||
|
||||
|
||||
def stop_arch():
|
||||
def start_brightstaff(arch_config_file, env, log_timeout=120, foreground=False):
|
||||
"""
|
||||
Start Docker Compose in detached mode and stream logs until services are healthy.
|
||||
|
||||
Args:
|
||||
path (str): The path where the prompt_config.yml file is located.
|
||||
log_timeout (int): Time in seconds to show logs before checking for healthy state.
|
||||
"""
|
||||
log.info("Starting brightstaff")
|
||||
|
||||
try:
|
||||
brightstaff_container_status = docker_container_status(BRIGHTSTAFF_DOCKER_NAME)
|
||||
if brightstaff_container_status != "not found":
|
||||
log.info(
|
||||
f"brightstaff found in docker, stopping and removing it: status: {brightstaff_container_status}"
|
||||
)
|
||||
docker_stop_container(BRIGHTSTAFF_DOCKER_NAME)
|
||||
docker_remove_container(BRIGHTSTAFF_DOCKER_NAME)
|
||||
|
||||
return_code, _, brightstaff_stderr = docker_start_brightstaff_detached(
|
||||
arch_config_file,
|
||||
env,
|
||||
)
|
||||
if return_code != 0:
|
||||
log.info("Failed to start brightstaff: " + str(return_code))
|
||||
log.info("stderr: " + brightstaff_stderr)
|
||||
sys.exit(1)
|
||||
|
||||
if foreground:
|
||||
stream_gateway_logs(follow=True, service="brightstaff")
|
||||
|
||||
except KeyboardInterrupt:
|
||||
log.info("Keyboard interrupt received, stopping arch gateway service.")
|
||||
stop_docker_container(service=BRIGHTSTAFF_DOCKER_NAME)
|
||||
|
||||
|
||||
def stop_docker_container(service=ARCHGW_DOCKER_NAME):
|
||||
"""
|
||||
Shutdown all Docker Compose services by running `docker-compose down`.
|
||||
|
||||
Args:
|
||||
path (str): The path where the docker-compose.yml file is located.
|
||||
"""
|
||||
log.info("Shutting down arch gateway service.")
|
||||
log.info(f"Shutting down {service} service.")
|
||||
|
||||
try:
|
||||
subprocess.run(
|
||||
["docker", "stop", ARCHGW_DOCKER_NAME],
|
||||
["docker", "stop", service],
|
||||
)
|
||||
subprocess.run(
|
||||
["docker", "rm", ARCHGW_DOCKER_NAME],
|
||||
["docker", "rm", service],
|
||||
)
|
||||
|
||||
log.info("Successfully shut down arch gateway service.")
|
||||
log.info(f"Successfully shut down {service} service.")
|
||||
|
||||
except subprocess.CalledProcessError as e:
|
||||
log.info(f"Failed to shut down services: {str(e)}")
|
||||
|
|
|
|||
|
|
@ -3,7 +3,12 @@ import json
|
|||
import sys
|
||||
import requests
|
||||
|
||||
from cli.consts import ARCHGW_DOCKER_IMAGE, ARCHGW_DOCKER_NAME
|
||||
from cli.consts import (
|
||||
ARCHGW_DOCKER_IMAGE,
|
||||
ARCHGW_DOCKER_NAME,
|
||||
BRIGHTSTAFF_DOCKER_IMAGE,
|
||||
BRIGHTSTAFF_DOCKER_NAME,
|
||||
)
|
||||
from cli.utils import getLogger
|
||||
|
||||
log = getLogger(__name__)
|
||||
|
|
@ -81,6 +86,48 @@ def docker_start_archgw_detached(
|
|||
return result.returncode, result.stdout, result.stderr
|
||||
|
||||
|
||||
def docker_start_brightstaff_detached(
|
||||
arch_config_file: str,
|
||||
env: dict,
|
||||
) -> str:
|
||||
env_args = [item for key, value in env.items() for item in ["-e", f"{key}={value}"]]
|
||||
|
||||
port_mappings = [
|
||||
"9091:9091",
|
||||
]
|
||||
port_mappings_args = [item for port in port_mappings for item in ("-p", port)]
|
||||
|
||||
volume_mappings = [
|
||||
f"{arch_config_file}:/app/arch_config.yaml:ro",
|
||||
]
|
||||
volume_mappings_args = [
|
||||
item for volume in volume_mappings for item in ("-v", volume)
|
||||
]
|
||||
|
||||
llm_provider_endpoint = env.get(
|
||||
"LLM_PROVIDER_ENDPOINT", "http://host.docker.internal:12000/v1/chat/completions"
|
||||
)
|
||||
|
||||
options = [
|
||||
"docker",
|
||||
"run",
|
||||
"-d",
|
||||
"--name",
|
||||
BRIGHTSTAFF_DOCKER_NAME,
|
||||
*port_mappings_args,
|
||||
*volume_mappings_args,
|
||||
*env_args,
|
||||
"-e",
|
||||
f"LLM_PROVIDER_ENDPOINT={llm_provider_endpoint}",
|
||||
"--add-host",
|
||||
"host.docker.internal:host-gateway",
|
||||
BRIGHTSTAFF_DOCKER_IMAGE,
|
||||
]
|
||||
|
||||
result = subprocess.run(options, capture_output=True, text=True, check=False)
|
||||
return result.returncode, result.stdout, result.stderr
|
||||
|
||||
|
||||
def health_check_endpoint(endpoint: str) -> bool:
|
||||
try:
|
||||
response = requests.get(endpoint)
|
||||
|
|
@ -91,7 +138,7 @@ def health_check_endpoint(endpoint: str) -> bool:
|
|||
return False
|
||||
|
||||
|
||||
def stream_gateway_logs(follow):
|
||||
def stream_gateway_logs(follow, service="archgw"):
|
||||
"""
|
||||
Stream logs from the arch gateway service.
|
||||
"""
|
||||
|
|
@ -100,7 +147,7 @@ def stream_gateway_logs(follow):
|
|||
options = ["docker", "logs"]
|
||||
if follow:
|
||||
options.append("-f")
|
||||
options.append(ARCHGW_DOCKER_NAME)
|
||||
options.append(service)
|
||||
try:
|
||||
# Run `docker-compose logs` to stream logs from the gateway service
|
||||
subprocess.run(
|
||||
|
|
|
|||
|
|
@ -14,15 +14,18 @@ from cli.utils import (
|
|||
)
|
||||
from cli.core import (
|
||||
start_arch_modelserver,
|
||||
start_brightstaff,
|
||||
stop_arch_modelserver,
|
||||
start_arch,
|
||||
stop_arch,
|
||||
stop_docker_container,
|
||||
download_models_from_hf,
|
||||
)
|
||||
from cli.consts import (
|
||||
ARCHGW_DOCKER_IMAGE,
|
||||
BRIGHTSTAFF_DOCKER_IMAGE,
|
||||
KATANEMO_DOCKERHUB_REPO,
|
||||
SERVICE_NAME_ARCHGW,
|
||||
SERVICE_NAME_BRIGHTSTAFF,
|
||||
SERVICE_NAME_MODEL_SERVER,
|
||||
SERVICE_ALL,
|
||||
)
|
||||
|
|
@ -40,6 +43,7 @@ logo = r"""
|
|||
|
||||
# Command to build archgw and model_server Docker images
|
||||
ARCHGW_DOCKERFILE = "./arch/Dockerfile"
|
||||
BRIGHTSTAFF_DOCKERFILE = "./arch/Dockerfile.brightstaff"
|
||||
MODEL_SERVER_BUILD_FILE = "./model_server/pyproject.toml"
|
||||
|
||||
|
||||
|
|
@ -51,6 +55,19 @@ def get_version():
|
|||
return "version not found"
|
||||
|
||||
|
||||
def verify_service_name(service):
|
||||
"""Verify if the service name is valid."""
|
||||
if service not in [
|
||||
SERVICE_NAME_ARCHGW,
|
||||
SERVICE_NAME_MODEL_SERVER,
|
||||
SERVICE_NAME_BRIGHTSTAFF,
|
||||
SERVICE_ALL,
|
||||
]:
|
||||
print(f"Error: Invalid service {service}. Exiting")
|
||||
sys.exit(1)
|
||||
return True
|
||||
|
||||
|
||||
@click.group(invoke_without_command=True)
|
||||
@click.option("--version", is_flag=True, help="Show the archgw cli version and exit.")
|
||||
@click.pass_context
|
||||
|
|
@ -75,9 +92,8 @@ def main(ctx, version):
|
|||
)
|
||||
def build(service):
|
||||
"""Build Arch from source. Must be in root of cloned repo."""
|
||||
if service not in [SERVICE_NAME_ARCHGW, SERVICE_NAME_MODEL_SERVER, SERVICE_ALL]:
|
||||
print(f"Error: Invalid service {service}. Exiting")
|
||||
sys.exit(1)
|
||||
verify_service_name(service)
|
||||
|
||||
# Check if /arch/Dockerfile exists
|
||||
if service == SERVICE_NAME_ARCHGW or service == SERVICE_ALL:
|
||||
if os.path.exists(ARCHGW_DOCKERFILE):
|
||||
|
|
@ -108,6 +124,35 @@ def build(service):
|
|||
|
||||
click.echo("archgw image built successfully.")
|
||||
|
||||
if service == SERVICE_NAME_BRIGHTSTAFF or service == SERVICE_ALL:
|
||||
if os.path.exists(BRIGHTSTAFF_DOCKERFILE):
|
||||
click.echo("Building brightstaff image...")
|
||||
try:
|
||||
subprocess.run(
|
||||
[
|
||||
"docker",
|
||||
"build",
|
||||
"-f",
|
||||
BRIGHTSTAFF_DOCKERFILE,
|
||||
"-t",
|
||||
f"{KATANEMO_DOCKERHUB_REPO}:brightstaff_latest",
|
||||
"-t",
|
||||
f"{BRIGHTSTAFF_DOCKER_IMAGE}",
|
||||
".",
|
||||
"--add-host=host.docker.internal:host-gateway",
|
||||
],
|
||||
check=True,
|
||||
)
|
||||
click.echo("brightstaff image built successfully.")
|
||||
except subprocess.CalledProcessError as e:
|
||||
click.echo(f"Error building brightstaff image: {e}")
|
||||
sys.exit(1)
|
||||
else:
|
||||
click.echo("Error: Dockerfile not found in /arch")
|
||||
sys.exit(1)
|
||||
|
||||
click.echo("brightstaff image built successfully.")
|
||||
|
||||
"""Install the model server dependencies using Poetry."""
|
||||
if service == SERVICE_NAME_MODEL_SERVER or service == SERVICE_ALL:
|
||||
# Check if pyproject.toml exists
|
||||
|
|
@ -146,9 +191,7 @@ def build(service):
|
|||
)
|
||||
def up(file, path, service, foreground):
|
||||
"""Starts Arch."""
|
||||
if service not in [SERVICE_NAME_ARCHGW, SERVICE_NAME_MODEL_SERVER, SERVICE_ALL]:
|
||||
log.info(f"Error: Invalid service {service}. Exiting")
|
||||
sys.exit(1)
|
||||
verify_service_name(service)
|
||||
|
||||
if service == SERVICE_ALL and foreground:
|
||||
# foreground can only be specified when starting individual services
|
||||
|
|
@ -233,10 +276,13 @@ def up(file, path, service, foreground):
|
|||
|
||||
if service == SERVICE_NAME_ARCHGW:
|
||||
start_arch(arch_config_file, env, foreground=foreground)
|
||||
if service == SERVICE_NAME_BRIGHTSTAFF:
|
||||
start_brightstaff(arch_config_file, env, foreground=foreground)
|
||||
else:
|
||||
download_models_from_hf()
|
||||
start_arch_modelserver(foreground)
|
||||
start_arch(arch_config_file, env, foreground=foreground)
|
||||
start_brightstaff(arch_config_file, env, foreground=foreground)
|
||||
|
||||
|
||||
@click.command()
|
||||
|
|
@ -255,10 +301,10 @@ def down(service):
|
|||
if service == SERVICE_NAME_MODEL_SERVER:
|
||||
stop_arch_modelserver()
|
||||
elif service == SERVICE_NAME_ARCHGW:
|
||||
stop_arch()
|
||||
stop_docker_container()
|
||||
else:
|
||||
stop_arch_modelserver()
|
||||
stop_arch()
|
||||
stop_docker_container()
|
||||
|
||||
|
||||
@click.command()
|
||||
|
|
|
|||
|
|
@ -72,7 +72,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
|
||||
//loading arch_config.yaml file
|
||||
let arch_config_path =
|
||||
env::var("ARCH_CONFIG_PATH").unwrap_or_else(|_| "arch_config.yaml".to_string());
|
||||
env::var("ARCH_CONFIG_PATH").unwrap_or_else(|_| "./arch_config.yaml".to_string());
|
||||
info!("Loading arch_config.yaml from {}", arch_config_path);
|
||||
|
||||
let config_contents =
|
||||
|
|
@ -88,14 +88,16 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
shorten_string(&serde_json::to_string(arch_config.as_ref()).unwrap())
|
||||
);
|
||||
|
||||
let llm_provider_endpoint = env::var("LLM_PROVIDER_ENDPOINT")
|
||||
.unwrap_or_else(|_| "http://localhost:12000/v1/chat/completions".to_string());
|
||||
|
||||
info!("llm provider endpoint: {}", llm_provider_endpoint);
|
||||
info!("Listening on http://{}", bind_address);
|
||||
let listener = TcpListener::bind(bind_address).await?;
|
||||
|
||||
let llm_provider_endpoint = "http://localhost:12000/v1/chat/completions";
|
||||
|
||||
let router_service: Arc<RouterService> = Arc::new(RouterService::new(
|
||||
arch_config.llm_providers.clone(),
|
||||
llm_provider_endpoint.to_string(),
|
||||
llm_provider_endpoint.clone(),
|
||||
arch_config.routing.as_ref().unwrap().model.clone(),
|
||||
));
|
||||
|
||||
|
|
@ -105,6 +107,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
let io = TokioIo::new(stream);
|
||||
|
||||
let router_service = Arc::clone(&router_service);
|
||||
let llm_provider_endpoint = llm_provider_endpoint.clone();
|
||||
|
||||
let service = service_fn(move |req| {
|
||||
let router_service = Arc::clone(&router_service);
|
||||
|
|
@ -115,11 +118,12 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
.span_builder("router_service")
|
||||
.with_kind(SpanKind::Server)
|
||||
.start_with_context(tracer, &parent_cx);
|
||||
let llm_provider_endpoint = llm_provider_endpoint.clone();
|
||||
|
||||
async move {
|
||||
match (req.method(), req.uri().path()) {
|
||||
(&Method::POST, "/v1/chat/completions") => {
|
||||
chat_completion(req, router_service, llm_provider_endpoint.to_string())
|
||||
chat_completion(req, router_service, llm_provider_endpoint)
|
||||
.with_context(parent_cx)
|
||||
.await
|
||||
}
|
||||
|
|
|
|||
|
|
@ -327,10 +327,10 @@ impl HttpContext for StreamContext {
|
|||
let model_requested = deserialized_body.model.clone();
|
||||
|
||||
info!(
|
||||
"on_http_request_body: provider: {}, model requested: {}, model selected: {:?}",
|
||||
"on_http_request_body: provider: {}, model requested: {}, model selected: {}",
|
||||
self.llm_provider().name,
|
||||
model_requested,
|
||||
self.llm_provider().model,
|
||||
self.llm_provider().model.as_ref().unwrap_or(&String::new())
|
||||
);
|
||||
|
||||
deserialized_body.model = self.llm_provider().model.clone().unwrap();
|
||||
|
|
|
|||
|
|
@ -489,7 +489,6 @@ fn llm_gateway_override_model_name() {
|
|||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_metric_record("input_sequence_length", 29)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue