diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 4d07452b..9c3698fc 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -46,6 +46,10 @@ jobs: - name: Install plano tools run: uv sync --extra dev + - name: Sync CLI templates to demos + if: github.event_name == 'push' && github.ref == 'refs/heads/main' + run: uv run python -m planoai.template_sync + - name: Run tests run: uv run pytest @@ -75,13 +79,13 @@ jobs: load: true tags: | ${{ env.PLANO_DOCKER_IMAGE }} - ${{ env.DOCKER_IMAGE }}:0.4.7 + ${{ env.DOCKER_IMAGE }}:0.4.8 ${{ env.DOCKER_IMAGE }}:latest cache-from: type=gha cache-to: type=gha,mode=max - name: Save image as artifact - run: docker save ${{ env.PLANO_DOCKER_IMAGE }} ${{ env.DOCKER_IMAGE }}:0.4.7 ${{ env.DOCKER_IMAGE }}:latest -o /tmp/plano-image.tar + run: docker save ${{ env.PLANO_DOCKER_IMAGE }} ${{ env.DOCKER_IMAGE }}:0.4.8 ${{ env.DOCKER_IMAGE }}:latest -o /tmp/plano-image.tar - name: Upload image artifact uses: actions/upload-artifact@v4 diff --git a/CLAUDE.md b/CLAUDE.md index b8c1c1bd..71c94303 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -137,6 +137,12 @@ To prepare a release (e.g., bumping from `0.4.6` to `0.4.7`), update the version Commit message format: `release X.Y.Z` +## Workflow Preferences + +- **Git commits:** Do NOT add `Co-Authored-By` lines. Keep commit messages short and concise (one line, no verbose descriptions). NEVER commit and push directly to `main`—always use a feature branch and PR. +- **Git branches:** Use the format `<github-username>/<branch-name>` when creating branches for PRs. Determine the username from `gh api user --jq .login`. +- **GitHub issues:** When a GitHub issue URL is pasted, fetch all requirements and context from the issue first. The end goal is always a PR with all tests passing. 
+ ## Key Conventions - Rust edition 2021, formatted with `cargo fmt`, linted with `cargo clippy -D warnings` diff --git a/apps/www/src/components/Hero.tsx b/apps/www/src/components/Hero.tsx index f98bc4d2..2b490b7b 100644 --- a/apps/www/src/components/Hero.tsx +++ b/apps/www/src/components/Hero.tsx @@ -24,7 +24,7 @@ export function Hero() { >
- v0.4.7 + v0.4.8 — diff --git a/build_filter_image.sh b/build_filter_image.sh index 318fa542..7c79d45c 100644 --- a/build_filter_image.sh +++ b/build_filter_image.sh @@ -1 +1 @@ -docker build -f Dockerfile . -t katanemo/plano -t katanemo/plano:0.4.7 +docker build -f Dockerfile . -t katanemo/plano -t katanemo/plano:0.4.8 diff --git a/cli/README.md b/cli/README.md index 19567824..4bd769bc 100644 --- a/cli/README.md +++ b/cli/README.md @@ -71,6 +71,17 @@ uv run planoai logs --follow uv run planoai [options] ``` +### CI: Keep CLI templates and demos in sync + +The CLI templates in `cli/planoai/templates/` are the source of truth for mapped +demo `config.yaml` files. + +Use the sync utility to write mapped demo configs from templates: + +```bash +uv run python -m planoai.template_sync +``` + ### Optional: Manual Virtual Environment Activation While `uv run` handles the virtual environment automatically, you can activate it manually if needed: @@ -80,4 +91,4 @@ source .venv/bin/activate planoai build # No need for 'uv run' when activated ``` -**Note:** For end-user installation instructions, see the [plano documentation](https://docs.planoai.dev). +**Note:** For end-user installation instructions, see the [Plano documentation](https://docs.planoai.dev). 
diff --git a/cli/planoai/__init__.py b/cli/planoai/__init__.py index 9e014320..03c28daa 100644 --- a/cli/planoai/__init__.py +++ b/cli/planoai/__init__.py @@ -1,3 +1,3 @@ """Plano CLI - Intelligent Prompt Gateway.""" -__version__ = "0.4.7" +__version__ = "0.4.8" diff --git a/cli/planoai/config_generator.py b/cli/planoai/config_generator.py index 8354b8dc..522968c9 100644 --- a/cli/planoai/config_generator.py +++ b/cli/planoai/config_generator.py @@ -460,6 +460,12 @@ def validate_and_render_schema(): print("agent_orchestrator: ", agent_orchestrator) + overrides = config_yaml.get("overrides", {}) + upstream_connect_timeout = overrides.get("upstream_connect_timeout", "5s") + upstream_tls_ca_path = overrides.get( + "upstream_tls_ca_path", "/etc/ssl/certs/ca-certificates.crt" + ) + data = { "prompt_gateway_listener": prompt_gateway, "llm_gateway_listener": llm_gateway, @@ -471,6 +477,8 @@ def validate_and_render_schema(): "local_llms": llms_with_endpoint, "agent_orchestrator": agent_orchestrator, "listeners": listeners, + "upstream_connect_timeout": upstream_connect_timeout, + "upstream_tls_ca_path": upstream_tls_ca_path, } rendered = template.render(data) diff --git a/cli/planoai/consts.py b/cli/planoai/consts.py index fa94efb6..84b4439f 100644 --- a/cli/planoai/consts.py +++ b/cli/planoai/consts.py @@ -5,5 +5,5 @@ PLANO_COLOR = "#969FF4" SERVICE_NAME_ARCHGW = "plano" PLANO_DOCKER_NAME = "plano" -PLANO_DOCKER_IMAGE = os.getenv("PLANO_DOCKER_IMAGE", "katanemo/plano:0.4.7") +PLANO_DOCKER_IMAGE = os.getenv("PLANO_DOCKER_IMAGE", "katanemo/plano:0.4.8") DEFAULT_OTEL_TRACING_GRPC_ENDPOINT = "http://host.docker.internal:4317" diff --git a/cli/planoai/template_sync.py b/cli/planoai/template_sync.py new file mode 100644 index 00000000..f4f2e44e --- /dev/null +++ b/cli/planoai/template_sync.py @@ -0,0 +1,122 @@ +from __future__ import annotations + +import argparse +from dataclasses import dataclass +from pathlib import Path + +import yaml + +from planoai.init_cmd import 
BUILTIN_TEMPLATES + + +@dataclass(frozen=True) +class SyncEntry: + template_id: str + template_file: str + demo_configs: tuple[str, ...] + transform: str = "none" + + +REPO_ROOT = Path(__file__).resolve().parents[2] +TEMPLATES_DIR = REPO_ROOT / "cli" / "planoai" / "templates" +SYNC_MAP_PATH = TEMPLATES_DIR / "template_sync_map.yaml" + + +def _load_sync_entries() -> list[SyncEntry]: + payload = yaml.safe_load(SYNC_MAP_PATH.read_text(encoding="utf-8")) or {} + rows = payload.get("templates", []) + entries: list[SyncEntry] = [] + for row in rows: + entries.append( + SyncEntry( + template_id=row["template_id"], + template_file=row["template_file"], + demo_configs=tuple(row.get("demo_configs", [])), + transform=row.get("transform", "none"), + ) + ) + return entries + + +def _render_for_demo(template_text: str, transform: str) -> str: + if transform == "none": + rendered = template_text + else: + raise ValueError(f"Unknown transform profile: {transform}") + + return rendered if rendered.endswith("\n") else f"{rendered}\n" + + +def _validate_manifest(entries: list[SyncEntry]) -> list[str]: + errors: list[str] = [] + builtin_ids = {t.id for t in BUILTIN_TEMPLATES} + manifest_ids = {entry.template_id for entry in entries} + + missing = sorted(builtin_ids - manifest_ids) + extra = sorted(manifest_ids - builtin_ids) + if missing: + errors.append(f"Missing template IDs in sync map: {', '.join(missing)}") + if extra: + errors.append(f"Unknown template IDs in sync map: {', '.join(extra)}") + + for entry in entries: + template_path = TEMPLATES_DIR / entry.template_file + if not template_path.exists(): + errors.append( + f"template_file does not exist for '{entry.template_id}': {template_path}" + ) + for demo_rel_path in entry.demo_configs: + demo_path = REPO_ROOT / demo_rel_path + if not demo_path.exists(): + errors.append( + f"demo config does not exist for '{entry.template_id}': {demo_path}" + ) + + return errors + + +def write_mapped_demo_configs(*, verbose: bool = False) -> 
int: + entries = _load_sync_entries() + manifest_errors = _validate_manifest(entries) + if manifest_errors: + for error in manifest_errors: + print(f"[manifest] {error}") + return 2 + + write_count = 0 + for entry in entries: + template_text = (TEMPLATES_DIR / entry.template_file).read_text( + encoding="utf-8" + ) + expected_text = _render_for_demo(template_text, entry.transform) + + for demo_rel_path in entry.demo_configs: + demo_path = REPO_ROOT / demo_rel_path + # Keep this as a write-only sync step so CI behavior is deterministic. + demo_path.write_text(expected_text, encoding="utf-8") + write_count += 1 + if verbose: + print( + f"[wrote] {demo_rel_path} <- {entry.template_id} ({entry.template_file})" + ) + + print(f"Wrote {write_count} mapped demo config(s) from CLI templates.") + return 0 + + +def main() -> int: + parser = argparse.ArgumentParser( + description="Sync CLI templates to mapped demo config.yaml files (write-only)." + ) + parser.add_argument( + "--verbose", + action="store_true", + help="Print each file written during sync.", + ) + args = parser.parse_args() + + return write_mapped_demo_configs(verbose=bool(args.verbose)) + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/cli/planoai/templates/template_sync_map.yaml b/cli/planoai/templates/template_sync_map.yaml new file mode 100644 index 00000000..4d601f6c --- /dev/null +++ b/cli/planoai/templates/template_sync_map.yaml @@ -0,0 +1,29 @@ +templates: + - template_id: sub_agent_orchestration + template_file: sub_agent_orchestration.yaml + demo_configs: + - demos/agent_orchestration/multi_agent_crewai_langchain/config.yaml + transform: none + + - template_id: coding_agent_routing + template_file: coding_agent_routing.yaml + demo_configs: + - demos/llm_routing/claude_code_router/config.yaml + transform: none + + - template_id: preference_aware_routing + template_file: preference_aware_routing.yaml + demo_configs: + - demos/llm_routing/preference_based_routing/config.yaml + 
transform: none + + - template_id: filter_chain_guardrails + template_file: filter_chain_guardrails.yaml + demo_configs: + - demos/filter_chains/http_filter/config.yaml + transform: none + + - template_id: conversational_state_v1_responses + template_file: conversational_state_v1_responses.yaml + demo_configs: [] + transform: none diff --git a/cli/pyproject.toml b/cli/pyproject.toml index 673e821b..44b3a553 100644 --- a/cli/pyproject.toml +++ b/cli/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "planoai" -version = "0.4.7" +version = "0.4.8" description = "Python-based CLI tool to manage Plano." authors = [{name = "Katanemo Labs, Inc."}] readme = "README.md" diff --git a/config/envoy.template.yaml b/config/envoy.template.yaml index f514e728..a780c3f1 100644 --- a/config/envoy.template.yaml +++ b/config/envoy.template.yaml @@ -595,7 +595,7 @@ static_resources: clusters: - name: arch - connect_timeout: 5s + connect_timeout: {{ upstream_connect_timeout | default('5s') }} type: LOGICAL_DNS dns_lookup_family: V4_ONLY lb_policy: ROUND_ROBIN @@ -618,9 +618,12 @@ static_resources: tls_params: tls_minimum_protocol_version: TLSv1_2 tls_maximum_protocol_version: TLSv1_3 + validation_context: + trusted_ca: + filename: {{ upstream_tls_ca_path | default('/etc/ssl/certs/ca-certificates.crt') }} - name: anthropic - connect_timeout: 0.5s + connect_timeout: {{ upstream_connect_timeout | default('5s') }} type: LOGICAL_DNS dns_lookup_family: V4_ONLY lb_policy: ROUND_ROBIN @@ -643,9 +646,12 @@ static_resources: tls_params: tls_minimum_protocol_version: TLSv1_2 tls_maximum_protocol_version: TLSv1_3 + validation_context: + trusted_ca: + filename: {{ upstream_tls_ca_path | default('/etc/ssl/certs/ca-certificates.crt') }} - name: deepseek - connect_timeout: 0.5s + connect_timeout: {{ upstream_connect_timeout | default('5s') }} type: LOGICAL_DNS dns_lookup_family: V4_ONLY lb_policy: ROUND_ROBIN @@ -668,9 +674,12 @@ static_resources: tls_params: tls_minimum_protocol_version: TLSv1_2 
tls_maximum_protocol_version: TLSv1_3 + validation_context: + trusted_ca: + filename: {{ upstream_tls_ca_path | default('/etc/ssl/certs/ca-certificates.crt') }} - name: xai - connect_timeout: 0.5s + connect_timeout: {{ upstream_connect_timeout | default('5s') }} type: LOGICAL_DNS dns_lookup_family: V4_ONLY lb_policy: ROUND_ROBIN @@ -693,9 +702,12 @@ static_resources: tls_params: tls_minimum_protocol_version: TLSv1_2 tls_maximum_protocol_version: TLSv1_3 + validation_context: + trusted_ca: + filename: {{ upstream_tls_ca_path | default('/etc/ssl/certs/ca-certificates.crt') }} - name: moonshotai - connect_timeout: 0.5s + connect_timeout: {{ upstream_connect_timeout | default('5s') }} type: LOGICAL_DNS dns_lookup_family: V4_ONLY lb_policy: ROUND_ROBIN @@ -718,9 +730,12 @@ static_resources: tls_params: tls_minimum_protocol_version: TLSv1_2 tls_maximum_protocol_version: TLSv1_3 + validation_context: + trusted_ca: + filename: {{ upstream_tls_ca_path | default('/etc/ssl/certs/ca-certificates.crt') }} - name: zhipu - connect_timeout: 0.5s + connect_timeout: {{ upstream_connect_timeout | default('5s') }} type: LOGICAL_DNS dns_lookup_family: V4_ONLY lb_policy: ROUND_ROBIN @@ -743,9 +758,12 @@ static_resources: tls_params: tls_minimum_protocol_version: TLSv1_2 tls_maximum_protocol_version: TLSv1_3 + validation_context: + trusted_ca: + filename: {{ upstream_tls_ca_path | default('/etc/ssl/certs/ca-certificates.crt') }} - name: together_ai - connect_timeout: 0.5s + connect_timeout: {{ upstream_connect_timeout | default('5s') }} type: LOGICAL_DNS dns_lookup_family: V4_ONLY lb_policy: ROUND_ROBIN @@ -768,9 +786,12 @@ static_resources: tls_params: tls_minimum_protocol_version: TLSv1_2 tls_maximum_protocol_version: TLSv1_3 + validation_context: + trusted_ca: + filename: {{ upstream_tls_ca_path | default('/etc/ssl/certs/ca-certificates.crt') }} - name: gemini - connect_timeout: 0.5s + connect_timeout: {{ upstream_connect_timeout | default('5s') }} type: LOGICAL_DNS dns_lookup_family: 
V4_ONLY lb_policy: ROUND_ROBIN @@ -793,9 +814,12 @@ static_resources: tls_params: tls_minimum_protocol_version: TLSv1_2 tls_maximum_protocol_version: TLSv1_3 + validation_context: + trusted_ca: + filename: {{ upstream_tls_ca_path | default('/etc/ssl/certs/ca-certificates.crt') }} - name: groq - connect_timeout: 0.5s + connect_timeout: {{ upstream_connect_timeout | default('5s') }} type: LOGICAL_DNS dns_lookup_family: V4_ONLY lb_policy: ROUND_ROBIN @@ -818,9 +842,12 @@ static_resources: tls_params: tls_minimum_protocol_version: TLSv1_2 tls_maximum_protocol_version: TLSv1_3 + validation_context: + trusted_ca: + filename: {{ upstream_tls_ca_path | default('/etc/ssl/certs/ca-certificates.crt') }} - name: mistral - connect_timeout: 0.5s + connect_timeout: {{ upstream_connect_timeout | default('5s') }} type: LOGICAL_DNS dns_lookup_family: V4_ONLY lb_policy: ROUND_ROBIN @@ -839,9 +866,16 @@ static_resources: typed_config: "@type": type.googleapis.com/envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext sni: api.mistral.ai + common_tls_context: + tls_params: + tls_minimum_protocol_version: TLSv1_2 + tls_maximum_protocol_version: TLSv1_3 + validation_context: + trusted_ca: + filename: {{ upstream_tls_ca_path | default('/etc/ssl/certs/ca-certificates.crt') }} - name: openai - connect_timeout: 0.5s + connect_timeout: {{ upstream_connect_timeout | default('5s') }} type: LOGICAL_DNS dns_lookup_family: V4_ONLY lb_policy: ROUND_ROBIN @@ -864,6 +898,9 @@ static_resources: tls_params: tls_minimum_protocol_version: TLSv1_2 tls_maximum_protocol_version: TLSv1_3 + validation_context: + trusted_ca: + filename: {{ upstream_tls_ca_path | default('/etc/ssl/certs/ca-certificates.crt') }} - name: mistral_7b_instruct connect_timeout: 0.5s type: STRICT_DNS @@ -884,7 +921,7 @@ static_resources: {% if cluster.connect_timeout -%} connect_timeout: {{ cluster.connect_timeout }} {% else -%} - connect_timeout: 0.5s + connect_timeout: {{ upstream_connect_timeout | default('5s') }} {% endif 
-%} type: LOGICAL_DNS dns_lookup_family: V4_ONLY @@ -913,12 +950,15 @@ static_resources: tls_params: tls_minimum_protocol_version: TLSv1_2 tls_maximum_protocol_version: TLSv1_3 + validation_context: + trusted_ca: + filename: {{ upstream_tls_ca_path | default('/etc/ssl/certs/ca-certificates.crt') }} {% endif %} {% endfor %} {% for local_llm_provider in local_llms %} - name: {{ local_llm_provider.cluster_name }} - connect_timeout: 0.5s + connect_timeout: {{ upstream_connect_timeout | default('5s') }} type: LOGICAL_DNS dns_lookup_family: V4_ONLY lb_policy: ROUND_ROBIN @@ -946,6 +986,9 @@ static_resources: tls_params: tls_minimum_protocol_version: TLSv1_2 tls_maximum_protocol_version: TLSv1_3 + validation_context: + trusted_ca: + filename: {{ upstream_tls_ca_path | default('/etc/ssl/certs/ca-certificates.crt') }} {% endif %} {% endfor %} diff --git a/config/plano_config_schema.yaml b/config/plano_config_schema.yaml index 0f3cefb7..cd736eb6 100644 --- a/config/plano_config_schema.yaml +++ b/config/plano_config_schema.yaml @@ -265,6 +265,12 @@ properties: type: boolean use_agent_orchestrator: type: boolean + upstream_connect_timeout: + type: string + description: "Connect timeout for upstream provider clusters (e.g., '5s', '10s'). Default is '5s'." + upstream_tls_ca_path: + type: string + description: "Path to the trusted CA bundle for upstream TLS verification. Default is '/etc/ssl/certs/ca-certificates.crt'." system_prompt: type: string prompt_targets: diff --git a/config/validate_plano_config.sh b/config/validate_plano_config.sh index 8eafd344..5291341d 100644 --- a/config/validate_plano_config.sh +++ b/config/validate_plano_config.sh @@ -5,7 +5,7 @@ failed_files=() for file in $(find . -name config.yaml -o -name plano_config_full_reference.yaml); do echo "Validating ${file}..." touch $(pwd)/${file}_rendered - if ! 
docker run --rm -v "$(pwd)/${file}:/app/plano_config.yaml:ro" -v "$(pwd)/${file}_rendered:/app/plano_config_rendered.yaml:rw" --entrypoint /bin/sh ${PLANO_DOCKER_IMAGE:-katanemo/plano:0.4.7} -c "python -m planoai.config_generator" 2>&1 > /dev/null ; then + if ! docker run --rm -v "$(pwd)/${file}:/app/plano_config.yaml:ro" -v "$(pwd)/${file}_rendered:/app/plano_config_rendered.yaml:rw" --entrypoint /bin/sh ${PLANO_DOCKER_IMAGE:-katanemo/plano:0.4.8} -c "python -m planoai.config_generator" 2>&1 > /dev/null ; then echo "Validation failed for $file" failed_files+=("$file") fi diff --git a/crates/Cargo.lock b/crates/Cargo.lock index ebe5b881..fbf817e7 100644 --- a/crates/Cargo.lock +++ b/crates/Cargo.lock @@ -436,11 +436,14 @@ name = "common" version = "0.1.0" dependencies = [ "axum", + "bytes", "derivative", "duration-string", "governor", "hermesllm", "hex", + "http-body-util", + "hyper 1.6.0", "log", "pretty_assertions", "proxy-wasm", diff --git a/crates/brightstaff/src/handlers/agent_chat_completions.rs b/crates/brightstaff/src/handlers/agent_chat_completions.rs index adfdce02..dea736e3 100644 --- a/crates/brightstaff/src/handlers/agent_chat_completions.rs +++ b/crates/brightstaff/src/handlers/agent_chat_completions.rs @@ -2,15 +2,18 @@ use std::sync::Arc; use std::time::Instant; use bytes::Bytes; +use common::errors::BrightStaffError; +use common::llm_providers::LlmProviders; use hermesllm::apis::OpenAIMessage; use hermesllm::clients::SupportedAPIsFromClient; use hermesllm::providers::request::ProviderRequest; use hermesllm::ProviderRequestType; use http_body_util::combinators::BoxBody; use http_body_util::BodyExt; -use hyper::{Request, Response}; +use hyper::{Request, Response, StatusCode}; use opentelemetry::trace::get_active_span; use serde::ser::Error as SerError; +use tokio::sync::RwLock; use tracing::{debug, info, info_span, warn, Instrument}; use super::agent_selector::{AgentSelectionError, AgentSelector}; @@ -22,12 +25,12 @@ use 
crate::tracing::{operation_component, set_service_name}; /// Main errors for agent chat completions #[derive(Debug, thiserror::Error)] pub enum AgentFilterChainError { + #[error("Forwarded error: {0}")] + Brightstaff(#[from] BrightStaffError), #[error("Agent selection error: {0}")] Selection(#[from] AgentSelectionError), #[error("Pipeline processing error: {0}")] Pipeline(#[from] PipelineError), - #[error("Response handling error: {0}")] - Response(#[from] super::response_handler::ResponseError), #[error("Request parsing error: {0}")] RequestParsing(#[from] serde_json::Error), #[error("HTTP error: {0}")] @@ -40,6 +43,7 @@ pub async fn agent_chat( _: String, agents_list: Arc>>>, listeners: Arc>>, + llm_providers: Arc>, ) -> Result>, hyper::Error> { // Extract request_id from headers or generate a new one let request_id: String = match request @@ -71,6 +75,7 @@ pub async fn agent_chat( orchestrator_service, agents_list, listeners, + llm_providers, request_id, ) .await @@ -99,16 +104,15 @@ pub async fn agent_chat( "agent_response": body }); + let status_code = hyper::StatusCode::from_u16(*status) + .unwrap_or(hyper::StatusCode::INTERNAL_SERVER_ERROR); + let json_string = error_json.to_string(); - let mut response = - Response::new(ResponseHandler::create_full_body(json_string)); - *response.status_mut() = hyper::StatusCode::from_u16(*status) - .unwrap_or(hyper::StatusCode::BAD_REQUEST); - response.headers_mut().insert( - hyper::header::CONTENT_TYPE, - "application/json".parse().unwrap(), - ); - return Ok(response); + return Ok(BrightStaffError::ForwardedError { + status_code, + message: json_string, + } + .into_response()); } // Print detailed error information with full error chain for other errors @@ -141,8 +145,11 @@ pub async fn agent_chat( // Log the error for debugging info!(error = %error_json, "structured error info"); - // Return JSON error response - Ok(ResponseHandler::create_json_error_response(&error_json)) + Ok(BrightStaffError::ForwardedError { + 
status_code: StatusCode::BAD_REQUEST, + message: error_json.to_string(), + } + .into_response()) } } } @@ -155,6 +162,7 @@ async fn handle_agent_chat_inner( orchestrator_service: Arc, agents_list: Arc>>>, listeners: Arc>>, + llm_providers: Arc>, request_id: String, ) -> Result>, AgentFilterChainError> { // Initialize services @@ -221,16 +229,33 @@ async fn handle_agent_chat_inner( AgentFilterChainError::RequestParsing(serde_json::Error::custom(err_msg)) })?; - let client_request = match ProviderRequestType::try_from((&chat_request_bytes[..], &api_type)) { - Ok(request) => request, - Err(err) => { - warn!("failed to parse request as ProviderRequestType: {}", err); - let err_msg = format!("Failed to parse request: {}", err); - return Err(AgentFilterChainError::RequestParsing( - serde_json::Error::custom(err_msg), - )); + let mut client_request = + match ProviderRequestType::try_from((&chat_request_bytes[..], &api_type)) { + Ok(request) => request, + Err(err) => { + warn!("failed to parse request as ProviderRequestType: {}", err); + let err_msg = format!("Failed to parse request: {}", err); + return Err(AgentFilterChainError::RequestParsing( + serde_json::Error::custom(err_msg), + )); + } + }; + + // If model is not specified in the request, resolve from default provider + if client_request.model().is_empty() { + match llm_providers.read().await.default() { + Some(default_provider) => { + let default_model = default_provider.name.clone(); + info!(default_model = %default_model, "no model specified in request, using default provider"); + client_request.set_model(default_model); + } + None => { + let err_msg = "No model specified in request and no default provider configured"; + warn!("{}", err_msg); + return Ok(BrightStaffError::NoModelSpecified.into_response()); + } } - }; + } let message: Vec = client_request.get_messages(); diff --git a/crates/brightstaff/src/handlers/integration_tests.rs b/crates/brightstaff/src/handlers/integration_tests.rs index 
70eaacd7..70b2999d 100644 --- a/crates/brightstaff/src/handlers/integration_tests.rs +++ b/crates/brightstaff/src/handlers/integration_tests.rs @@ -5,9 +5,10 @@ use hyper::header::HeaderMap; use crate::handlers::agent_selector::{AgentSelectionError, AgentSelector}; use crate::handlers::pipeline_processor::PipelineProcessor; -use crate::handlers::response_handler::ResponseHandler; use crate::router::plano_orchestrator::OrchestratorService; - +use common::errors::BrightStaffError; +use http_body_util::BodyExt; +use hyper::StatusCode; /// Integration test that demonstrates the modular agent chat flow /// This test shows how the three main components work together: /// 1. AgentSelector - selects the appropriate agents based on orchestration @@ -128,8 +129,24 @@ mod tests { } // Test 4: Error Response Creation - let error_response = ResponseHandler::create_bad_request("Test error"); - assert_eq!(error_response.status(), hyper::StatusCode::BAD_REQUEST); + let err = BrightStaffError::ModelNotFound("gpt-5-secret".to_string()); + let response = err.into_response(); + + assert_eq!(response.status(), StatusCode::NOT_FOUND); + + // Helper to extract body as JSON + let body_bytes = response.into_body().collect().await.unwrap().to_bytes(); + let body: serde_json::Value = serde_json::from_slice(&body_bytes).unwrap(); + + assert_eq!(body["error"]["code"], "ModelNotFound"); + assert_eq!( + body["error"]["details"]["rejected_model_id"], + "gpt-5-secret" + ); + assert!(body["error"]["message"] + .as_str() + .unwrap() + .contains("gpt-5-secret")); println!("✅ All modular components working correctly!"); } @@ -148,12 +165,21 @@ mod tests { AgentSelectionError::ListenerNotFound(_) )); - // Test error response creation - let error_response = ResponseHandler::create_internal_error("Pipeline failed"); - assert_eq!( - error_response.status(), - hyper::StatusCode::INTERNAL_SERVER_ERROR - ); + let technical_reason = "Database connection timed out"; + let err = 
BrightStaffError::InternalServerError(technical_reason.to_string()); + + let response = err.into_response(); + + // --- 1. EXTRACT BYTES --- + let body_bytes = response.into_body().collect().await.unwrap().to_bytes(); + + // --- 2. DECLARE body_json HERE --- + let body_json: serde_json::Value = + serde_json::from_slice(&body_bytes).expect("Failed to parse JSON body"); + + // --- 3. USE body_json --- + assert_eq!(body_json["error"]["code"], "InternalServerError"); + assert_eq!(body_json["error"]["details"]["reason"], technical_reason); println!("✅ Error handling working correctly!"); } diff --git a/crates/brightstaff/src/handlers/llm.rs b/crates/brightstaff/src/handlers/llm.rs index 10a68c1a..8e8f9661 100644 --- a/crates/brightstaff/src/handlers/llm.rs +++ b/crates/brightstaff/src/handlers/llm.rs @@ -8,9 +8,9 @@ use hermesllm::apis::openai_responses::InputParam; use hermesllm::clients::{SupportedAPIsFromClient, SupportedUpstreamAPIs}; use hermesllm::{ProviderRequest, ProviderRequestType}; use http_body_util::combinators::BoxBody; -use http_body_util::{BodyExt, Full}; +use http_body_util::BodyExt; use hyper::header::{self}; -use hyper::{Request, Response, StatusCode}; +use hyper::{Request, Response}; use opentelemetry::global; use opentelemetry::trace::get_active_span; use opentelemetry_http::HeaderInjector; @@ -30,11 +30,7 @@ use crate::state::{ }; use crate::tracing::{llm as tracing_llm, operation_component, set_service_name}; -fn full>(chunk: T) -> BoxBody { - Full::new(chunk.into()) - .map_err(|never| match never {}) - .boxed() -} +use common::errors::BrightStaffError; pub async fn llm_chat( request: Request, @@ -135,10 +131,11 @@ async fn llm_chat_inner( error = %err, "failed to parse request as ProviderRequestType" ); - let err_msg = format!("Failed to parse request: {}", err); - let mut bad_request = Response::new(full(err_msg)); - *bad_request.status_mut() = StatusCode::BAD_REQUEST; - return Ok(bad_request); + return 
Ok(BrightStaffError::InvalidRequest(format!( + "Failed to parse request: {}", + err + )) + .into_response()); } }; @@ -150,9 +147,28 @@ async fn llm_chat_inner( Some(SupportedAPIsFromClient::OpenAIResponsesAPI(_)) ); + // If model is not specified in the request, resolve from default provider + let model_from_request = client_request.model().to_string(); + let model_from_request = if model_from_request.is_empty() { + match llm_providers.read().await.default() { + Some(default_provider) => { + let default_model = default_provider.name.clone(); + info!(default_model = %default_model, "no model specified in request, using default provider"); + client_request.set_model(default_model.clone()); + default_model + } + None => { + let err_msg = "No model specified in request and no default provider configured"; + warn!("{}", err_msg); + return Ok(BrightStaffError::NoModelSpecified.into_response()); + } + } + } else { + model_from_request + }; + // Model alias resolution: update model field in client_request immediately // This ensures all downstream objects use the resolved model - let model_from_request = client_request.model().to_string(); let temperature = client_request.get_temperature(); let is_streaming_request = client_request.is_streaming(); let alias_resolved_model = resolve_model_alias(&model_from_request, &model_aliases); @@ -165,14 +181,8 @@ async fn llm_chat_inner( .get(&alias_resolved_model) .is_none() { - let err_msg = format!( - "Model '{}' not found in configured providers", - alias_resolved_model - ); warn!(model = %alias_resolved_model, "model not found in configured providers"); - let mut bad_request = Response::new(full(err_msg)); - *bad_request.status_mut() = StatusCode::BAD_REQUEST; - return Ok(bad_request); + return Ok(BrightStaffError::ModelNotFound(alias_resolved_model).into_response()); } // Handle provider/model slug format (e.g., "openai/gpt-4") @@ -267,13 +277,10 @@ async fn llm_chat_inner( Err(StateStorageError::NotFound(_)) => { // Return 409 
Conflict when previous_response_id not found warn!(previous_response_id = %prev_resp_id, "previous response_id not found"); - let err_msg = format!( - "Conversation state not found for previous_response_id: {}", - prev_resp_id - ); - let mut conflict_response = Response::new(full(err_msg)); - *conflict_response.status_mut() = StatusCode::CONFLICT; - return Ok(conflict_response); + return Ok(BrightStaffError::ConversationStateNotFound( + prev_resp_id.to_string(), + ) + .into_response()); } Err(e) => { // Log warning but continue on other storage errors @@ -324,9 +331,11 @@ async fn llm_chat_inner( { Ok(result) => result, Err(err) => { - let mut internal_error = Response::new(full(err.message)); - *internal_error.status_mut() = err.status_code; - return Ok(internal_error); + return Ok(BrightStaffError::ForwardedError { + status_code: err.status_code, + message: err.message, + } + .into_response()); } }; @@ -392,10 +401,11 @@ async fn llm_chat_inner( { Ok(res) => res, Err(err) => { - let err_msg = format!("Failed to send request: {}", err); - let mut internal_error = Response::new(full(err_msg)); - *internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; - return Ok(internal_error); + return Ok(BrightStaffError::InternalServerError(format!( + "Failed to send request: {}", + err + )) + .into_response()); } }; @@ -452,12 +462,11 @@ async fn llm_chat_inner( match response.body(streaming_response.body) { Ok(response) => Ok(response), - Err(err) => { - let err_msg = format!("Failed to create response: {}", err); - let mut internal_error = Response::new(full(err_msg)); - *internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; - Ok(internal_error) - } + Err(err) => Ok(BrightStaffError::InternalServerError(format!( + "Failed to create response: {}", + err + )) + .into_response()), } } diff --git a/crates/brightstaff/src/handlers/response_handler.rs b/crates/brightstaff/src/handlers/response_handler.rs index e2561a8f..7331ab4c 100644 --- 
a/crates/brightstaff/src/handlers/response_handler.rs +++ b/crates/brightstaff/src/handlers/response_handler.rs @@ -1,25 +1,17 @@ use bytes::Bytes; +use common::errors::BrightStaffError; use hermesllm::apis::OpenAIApi; use hermesllm::clients::{SupportedAPIsFromClient, SupportedUpstreamAPIs}; use hermesllm::SseEvent; use http_body_util::combinators::BoxBody; use http_body_util::{BodyExt, Full, StreamBody}; use hyper::body::Frame; -use hyper::{Response, StatusCode}; +use hyper::Response; use tokio::sync::mpsc; use tokio_stream::wrappers::ReceiverStream; use tokio_stream::StreamExt; use tracing::{info, warn, Instrument}; -/// Errors that can occur during response handling -#[derive(Debug, thiserror::Error)] -pub enum ResponseError { - #[error("Failed to create response: {0}")] - ResponseCreationFailed(#[from] hyper::http::Error), - #[error("Stream error: {0}")] - StreamError(String), -} - /// Service for handling HTTP responses and streaming pub struct ResponseHandler; @@ -35,40 +27,6 @@ impl ResponseHandler { .boxed() } - /// Create an error response with a given status code and message - pub fn create_error_response( - status: StatusCode, - message: &str, - ) -> Response> { - let mut response = Response::new(Self::create_full_body(message.to_string())); - *response.status_mut() = status; - response - } - - /// Create a bad request response - pub fn create_bad_request(message: &str) -> Response> { - Self::create_error_response(StatusCode::BAD_REQUEST, message) - } - - /// Create an internal server error response - pub fn create_internal_error(message: &str) -> Response> { - Self::create_error_response(StatusCode::INTERNAL_SERVER_ERROR, message) - } - - /// Create a JSON error response - pub fn create_json_error_response( - error_json: &serde_json::Value, - ) -> Response> { - let json_string = error_json.to_string(); - let mut response = Response::new(Self::create_full_body(json_string)); - *response.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; - 
response.headers_mut().insert( - hyper::header::CONTENT_TYPE, - "application/json".parse().unwrap(), - ); - response - } - /// Create a streaming response from a reqwest response. /// The spawned streaming task is instrumented with both `agent_span` and `orchestrator_span` /// so their durations reflect the actual time spent streaming to the client. @@ -77,13 +35,13 @@ impl ResponseHandler { llm_response: reqwest::Response, agent_span: tracing::Span, orchestrator_span: tracing::Span, - ) -> Result>, ResponseError> { + ) -> Result>, BrightStaffError> { // Copy headers from the original response let response_headers = llm_response.headers(); let mut response_builder = Response::builder(); let headers = response_builder.headers_mut().ok_or_else(|| { - ResponseError::StreamError("Failed to get mutable headers".to_string()) + BrightStaffError::StreamError("Failed to get mutable headers".to_string()) })?; for (header_name, header_value) in response_headers.iter() { @@ -123,7 +81,7 @@ impl ResponseHandler { response_builder .body(stream_body) - .map_err(ResponseError::from) + .map_err(BrightStaffError::from) } /// Collect the full response body as a string @@ -136,7 +94,7 @@ impl ResponseHandler { pub async fn collect_full_response( &self, llm_response: reqwest::Response, - ) -> Result { + ) -> Result { use hermesllm::apis::streaming_shapes::sse::SseStreamIter; let response_headers = llm_response.headers(); @@ -144,10 +102,9 @@ impl ResponseHandler { .get(hyper::header::CONTENT_TYPE) .is_some_and(|v| v.to_str().unwrap_or("").contains("text/event-stream")); - let response_bytes = llm_response - .bytes() - .await - .map_err(|e| ResponseError::StreamError(format!("Failed to read response: {}", e)))?; + let response_bytes = llm_response.bytes().await.map_err(|e| { + BrightStaffError::StreamError(format!("Failed to read response: {}", e)) + })?; if is_sse_streaming { let client_api = @@ -185,7 +142,7 @@ impl ResponseHandler { } else { // If not SSE, treat as regular text 
response let response_text = String::from_utf8(response_bytes.to_vec()).map_err(|e| { - ResponseError::StreamError(format!("Failed to decode response: {}", e)) + BrightStaffError::StreamError(format!("Failed to decode response: {}", e)) })?; Ok(response_text) @@ -204,42 +161,6 @@ mod tests { use super::*; use hyper::StatusCode; - #[test] - fn test_create_bad_request() { - let response = ResponseHandler::create_bad_request("Invalid request"); - assert_eq!(response.status(), StatusCode::BAD_REQUEST); - } - - #[test] - fn test_create_internal_error() { - let response = ResponseHandler::create_internal_error("Server error"); - assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR); - } - - #[test] - fn test_create_error_response() { - let response = - ResponseHandler::create_error_response(StatusCode::NOT_FOUND, "Resource not found"); - assert_eq!(response.status(), StatusCode::NOT_FOUND); - } - - #[test] - fn test_create_json_error_response() { - let error_json = serde_json::json!({ - "error": { - "type": "TestError", - "message": "Test error message" - } - }); - - let response = ResponseHandler::create_json_error_response(&error_json); - assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR); - assert_eq!( - response.headers().get("content-type").unwrap(), - "application/json" - ); - } - #[tokio::test] async fn test_create_streaming_response_with_mock() { use mockito::Server; diff --git a/crates/brightstaff/src/main.rs b/crates/brightstaff/src/main.rs index fff69b00..87deda6a 100644 --- a/crates/brightstaff/src/main.rs +++ b/crates/brightstaff/src/main.rs @@ -202,6 +202,7 @@ async fn main() -> Result<(), Box> { fully_qualified_url, agents_list, listeners, + llm_providers, ) .with_context(parent_cx) .await; diff --git a/crates/common/Cargo.toml b/crates/common/Cargo.toml index cb471bd6..dd2cba15 100644 --- a/crates/common/Cargo.toml +++ b/crates/common/Cargo.toml @@ -20,6 +20,9 @@ urlencoding = "2.1.3" url = "2.5.4" hermesllm = { version = "0.1.0", 
path = "../hermesllm" } serde_with = "3.13.0" +hyper = "1.0" +bytes = "1.0" +http-body-util = "0.1" [features] default = [] @@ -30,3 +33,6 @@ serde_json = "1.0.64" serial_test = "3.2" axum = "0.7" tokio = { version = "1.44", features = ["sync", "time", "macros", "rt"] } +hyper = { version = "1.0", features = ["full"] } +bytes = "1.0" +http-body-util = "0.1" diff --git a/crates/common/src/errors.rs b/crates/common/src/errors.rs index 21af3c94..b2f57199 100644 --- a/crates/common/src/errors.rs +++ b/crates/common/src/errors.rs @@ -1,9 +1,13 @@ -use proxy_wasm::types::Status; - use crate::{api::open_ai::ChatCompletionChunkResponseError, ratelimit}; +use bytes::Bytes; use hermesllm::apis::openai::OpenAIError; +use http_body_util::{combinators::BoxBody, BodyExt, Full}; +use hyper::{Error as HyperError, Response, StatusCode}; +use proxy_wasm::types::Status; +use serde_json::json; +use thiserror::Error; -#[derive(thiserror::Error, Debug)] +#[derive(Error, Debug)] pub enum ClientError { #[error("Error dispatching HTTP call to `{upstream_name}/{path}`, error: {internal_status:?}")] DispatchError { @@ -13,7 +17,7 @@ pub enum ClientError { }, } -#[derive(thiserror::Error, Debug)] +#[derive(Error, Debug)] pub enum ServerError { #[error(transparent)] HttpDispatch(ClientError), @@ -43,3 +47,174 @@ pub enum ServerError { #[error("error parsing openai message: {0}")] OpenAIPError(#[from] OpenAIError), } +// ----------------------------------------------------------------------------- +// BrightStaff Errors (Standardized) +// ----------------------------------------------------------------------------- +#[derive(Debug, Error)] +pub enum BrightStaffError { + #[error("The requested model '{0}' does not exist")] + ModelNotFound(String), + + #[error("No model specified in request and no default provider configured")] + NoModelSpecified, + + #[error("Conversation state not found for previous_response_id: {0}")] + ConversationStateNotFound(String), + + #[error("Internal server error")] + 
InternalServerError(String), + + #[error("Invalid request")] + InvalidRequest(String), + + #[error("{message}")] + ForwardedError { + status_code: StatusCode, + message: String, + }, + + #[error("Stream error: {0}")] + StreamError(String), + + #[error("Failed to create response: {0}")] + ResponseCreationFailed(#[from] hyper::http::Error), +} + +impl BrightStaffError { + pub fn into_response(self) -> Response> { + let (status, code, details) = match &self { + BrightStaffError::ModelNotFound(model_name) => ( + StatusCode::NOT_FOUND, + "ModelNotFound", + json!({ "rejected_model_id": model_name }), + ), + + BrightStaffError::NoModelSpecified => { + (StatusCode::BAD_REQUEST, "NoModelSpecified", json!({})) + } + + BrightStaffError::ConversationStateNotFound(prev_resp_id) => ( + StatusCode::CONFLICT, + "ConversationStateNotFound", + json!({ "previous_response_id": prev_resp_id }), + ), + + BrightStaffError::InternalServerError(reason) => ( + StatusCode::INTERNAL_SERVER_ERROR, + "InternalServerError", + // Passing the reason into details for easier debugging + json!({ "reason": reason }), + ), + + BrightStaffError::InvalidRequest(reason) => ( + StatusCode::BAD_REQUEST, + "InvalidRequest", + json!({ "reason": reason }), + ), + + BrightStaffError::ForwardedError { + status_code, + message, + } => (*status_code, "ForwardedError", json!({ "reason": message })), + + BrightStaffError::StreamError(reason) => ( + StatusCode::BAD_REQUEST, + "StreamError", + json!({ "reason": reason }), + ), + + BrightStaffError::ResponseCreationFailed(reason) => ( + StatusCode::BAD_REQUEST, + "ResponseCreationFailed", + json!({ "reason": reason.to_string() }), + ), + }; + + let body_json = json!({ + "error": { + "code": code, + "message": self.to_string(), + "details": details + } + }); + + // 1. Create the concrete body + let full_body = Full::new(Bytes::from(body_json.to_string())); + + // 2. 
Convert it to BoxBody + // We map_err because Full never fails, but BoxBody expects a HyperError + let boxed_body = full_body + .map_err(|never| match never {}) // This handles the "Infallible" error type + .boxed(); + + Response::builder() + .status(status) + .header("content-type", "application/json") + .body(boxed_body) + .unwrap_or_else(|_| { + Response::new( + Full::new(Bytes::from("Internal Error")) + .map_err(|never| match never {}) + .boxed(), + ) + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use http_body_util::BodyExt; // For .collect().await + + #[tokio::test] + async fn test_model_not_found_format() { + let err = BrightStaffError::ModelNotFound("gpt-5-secret".to_string()); + let response = err.into_response(); + + assert_eq!(response.status(), StatusCode::NOT_FOUND); + + // Helper to extract body as JSON + let body_bytes = response.into_body().collect().await.unwrap().to_bytes(); + let body: serde_json::Value = serde_json::from_slice(&body_bytes).unwrap(); + + assert_eq!(body["error"]["code"], "ModelNotFound"); + assert_eq!( + body["error"]["details"]["rejected_model_id"], + "gpt-5-secret" + ); + assert!(body["error"]["message"] + .as_str() + .unwrap() + .contains("gpt-5-secret")); + } + + #[tokio::test] + async fn test_forwarded_error_preserves_status() { + let err = BrightStaffError::ForwardedError { + status_code: StatusCode::TOO_MANY_REQUESTS, + message: "Rate limit exceeded on agent side".to_string(), + }; + + let response = err.into_response(); + assert_eq!(response.status(), StatusCode::TOO_MANY_REQUESTS); + + let body_bytes = response.into_body().collect().await.unwrap().to_bytes(); + let body: serde_json::Value = serde_json::from_slice(&body_bytes).unwrap(); + + assert_eq!(body["error"]["code"], "ForwardedError"); + } + + #[tokio::test] + async fn test_hyper_error_wrapping() { + // Manually trigger a hyper error by creating an invalid URI/Header + let hyper_err = hyper::http::Response::builder() + .status(1000) // Invalid status + 
.body(()) + .unwrap_err(); + + let err = BrightStaffError::ResponseCreationFailed(hyper_err); + let response = err.into_response(); + + assert_eq!(response.status(), StatusCode::BAD_REQUEST); + } +} diff --git a/crates/hermesllm/src/apis/anthropic.rs b/crates/hermesllm/src/apis/anthropic.rs index 6e53e6db..3cb06828 100644 --- a/crates/hermesllm/src/apis/anthropic.rs +++ b/crates/hermesllm/src/apis/anthropic.rs @@ -102,6 +102,7 @@ pub struct McpServer { #[skip_serializing_none] #[derive(Serialize, Deserialize, Debug, Clone)] pub struct MessagesRequest { + #[serde(default)] pub model: String, pub messages: Vec, pub max_tokens: u32, diff --git a/crates/hermesllm/src/apis/openai.rs b/crates/hermesllm/src/apis/openai.rs index cd4e7d0b..53eee442 100644 --- a/crates/hermesllm/src/apis/openai.rs +++ b/crates/hermesllm/src/apis/openai.rs @@ -74,6 +74,7 @@ impl ApiDefinition for OpenAIApi { #[derive(Serialize, Deserialize, Debug, Clone, Default)] pub struct ChatCompletionsRequest { pub messages: Vec, + #[serde(default)] pub model: String, // pub audio: Option