From 8edf686665536b017092167e06d09c47a92a85e1 Mon Sep 17 00:00:00 2001 From: Adil Hafeez Date: Wed, 11 Mar 2026 15:43:10 -0700 Subject: [PATCH] support configurable orchestrator model via orchestration config section --- cli/planoai/config_generator.py | 18 ++++++++++++++-- config/plano_config_schema.yaml | 8 +++++++ .../src/handlers/agent_selector.rs | 1 + .../src/handlers/integration_tests.rs | 1 + crates/brightstaff/src/main.rs | 21 +++++++++++++++---- .../src/router/plano_orchestrator.rs | 14 +++++++++---- crates/common/src/configuration.rs | 7 +++++++ crates/common/src/consts.rs | 1 - 8 files changed, 60 insertions(+), 11 deletions(-) diff --git a/cli/planoai/config_generator.py b/cli/planoai/config_generator.py index 522968c9..4bc9c4b7 100644 --- a/cli/planoai/config_generator.py +++ b/cli/planoai/config_generator.py @@ -403,12 +403,26 @@ def validate_and_render_schema(): } ) - if "plano-orchestrator" not in model_provider_name_set: + orchestration_config = config_yaml.get("orchestration", {}) + orchestration_model_provider = orchestration_config.get("llm_provider", None) + + if ( + orchestration_model_provider + and orchestration_model_provider not in model_provider_name_set + ): + raise Exception( + f"Orchestration llm_provider {orchestration_model_provider} is not defined in model_providers" + ) + + if ( + orchestration_model_provider is None + and "plano-orchestrator" not in model_provider_name_set + ): updated_model_providers.append( { "name": "plano-orchestrator", "provider_interface": "arch", - "model": "Plano-Orchestrator", + "model": orchestration_config.get("model", "Plano-Orchestrator"), "internal": True, } ) diff --git a/config/plano_config_schema.yaml b/config/plano_config_schema.yaml index b63cb824..d2293650 100644 --- a/config/plano_config_schema.yaml +++ b/config/plano_config_schema.yaml @@ -416,6 +416,14 @@ properties: model: type: string additionalProperties: false + orchestration: + type: object + properties: + llm_provider: + type: string + model: + type: string + additionalProperties: false state_storage: type: object properties: diff --git a/crates/brightstaff/src/handlers/agent_selector.rs b/crates/brightstaff/src/handlers/agent_selector.rs index faa734ee..a1b38b2c 100644 --- a/crates/brightstaff/src/handlers/agent_selector.rs +++ b/crates/brightstaff/src/handlers/agent_selector.rs @@ -178,6 +178,7 @@ mod tests { Arc::new(OrchestratorService::new( "http://localhost:8080".to_string(), "test-model".to_string(), + "plano-orchestrator".to_string(), )) } diff --git a/crates/brightstaff/src/handlers/integration_tests.rs b/crates/brightstaff/src/handlers/integration_tests.rs index 70b2999d..b440e198 100644 --- a/crates/brightstaff/src/handlers/integration_tests.rs +++ b/crates/brightstaff/src/handlers/integration_tests.rs @@ -23,6 +23,7 @@ mod tests { Arc::new(OrchestratorService::new( "http://localhost:8080".to_string(), "test-model".to_string(), + "plano-orchestrator".to_string(), )) } diff --git a/crates/brightstaff/src/main.rs b/crates/brightstaff/src/main.rs index 51c9127f..025ff545 100644 --- a/crates/brightstaff/src/main.rs +++ b/crates/brightstaff/src/main.rs @@ -11,9 +11,7 @@ use brightstaff::state::StateStorage; use brightstaff::utils::tracing::init_tracer; use bytes::Bytes; use common::configuration::{Agent, Configuration}; -use common::consts::{ - CHAT_COMPLETIONS_PATH, MESSAGES_PATH, OPENAI_RESPONSES_API_PATH, PLANO_ORCHESTRATOR_MODEL_NAME, -}; +use common::consts::{CHAT_COMPLETIONS_PATH, MESSAGES_PATH, OPENAI_RESPONSES_API_PATH}; use common::llm_providers::LlmProviders; use http_body_util::{combinators::BoxBody, BodyExt, Empty}; use hyper::body::Incoming; @@ -35,6 +33,8 @@ pub mod router; const BIND_ADDRESS: &str = "0.0.0.0:9091"; const DEFAULT_ROUTING_LLM_PROVIDER: &str = "arch-router"; const DEFAULT_ROUTING_MODEL_NAME: &str = "Arch-Router"; +const DEFAULT_ORCHESTRATOR_LLM_PROVIDER: &str = "plano-orchestrator"; +const DEFAULT_ORCHESTRATOR_MODEL_NAME: &str = "Plano-Orchestrator"; // Utility function to extract the context from the incoming request headers fn extract_context_from_request(req: &Request) -> Context { @@ -109,9 +109,22 @@ async fn main() -> Result<(), Box> { routing_llm_provider, )); + let orchestrator_model_name: String = plano_config + .orchestration + .as_ref() + .and_then(|o| o.model.clone()) + .unwrap_or_else(|| DEFAULT_ORCHESTRATOR_MODEL_NAME.to_string()); + + let orchestrator_llm_provider: String = plano_config + .orchestration + .as_ref() + .and_then(|o| o.model_provider.clone()) + .unwrap_or_else(|| DEFAULT_ORCHESTRATOR_LLM_PROVIDER.to_string()); + let orchestrator_service: Arc = Arc::new(OrchestratorService::new( format!("{llm_provider_url}{CHAT_COMPLETIONS_PATH}"), - PLANO_ORCHESTRATOR_MODEL_NAME.to_string(), + orchestrator_model_name, + orchestrator_llm_provider, )); let model_aliases = Arc::new(plano_config.model_aliases.clone()); diff --git a/crates/brightstaff/src/router/plano_orchestrator.rs b/crates/brightstaff/src/router/plano_orchestrator.rs index cf2688b9..ed303187 100644 --- a/crates/brightstaff/src/router/plano_orchestrator.rs +++ b/crates/brightstaff/src/router/plano_orchestrator.rs @@ -2,7 +2,7 @@ use std::{collections::HashMap, sync::Arc}; use common::{ configuration::{AgentUsagePreference, OrchestrationPreference}, - consts::{ARCH_PROVIDER_HINT_HEADER, PLANO_ORCHESTRATOR_MODEL_NAME, REQUEST_ID_HEADER}, + consts::{ARCH_PROVIDER_HINT_HEADER, REQUEST_ID_HEADER}, }; use hermesllm::apis::openai::{ChatCompletionsResponse, Message}; use hyper::header; @@ -19,6 +19,7 @@ pub struct OrchestratorService { orchestrator_url: String, client: reqwest::Client, orchestrator_model: Arc, + orchestrator_provider_name: String, } #[derive(Debug, Error)] @@ -36,7 +37,11 @@ pub enum OrchestrationError { pub type Result = std::result::Result; impl OrchestratorService { - pub fn new(orchestrator_url: String, orchestration_model_name: String) -> Self { + pub fn new( + orchestrator_url: String, + orchestration_model_name: String, + orchestrator_provider_name: String, + ) -> Self { // Empty agent orchestrations - will be provided via usage_preferences in requests let agent_orchestrations: HashMap> = HashMap::new(); @@ -50,6 +55,7 @@ impl OrchestratorService { orchestrator_url, client: reqwest::Client::new(), orchestrator_model, + orchestrator_provider_name, } } @@ -91,7 +97,7 @@ impl OrchestratorService { orchestration_request_headers.insert( header::HeaderName::from_static(ARCH_PROVIDER_HINT_HEADER), - header::HeaderValue::from_str(PLANO_ORCHESTRATOR_MODEL_NAME).unwrap(), + header::HeaderValue::from_str(&self.orchestrator_provider_name).unwrap(), ); // Inject OpenTelemetry trace context from current span @@ -110,7 +116,7 @@ impl OrchestratorService { orchestration_request_headers.insert( header::HeaderName::from_static("model"), - header::HeaderValue::from_static(PLANO_ORCHESTRATOR_MODEL_NAME), + header::HeaderValue::from_str(&self.orchestrator_provider_name).unwrap(), ); let start_time = std::time::Instant::now(); diff --git a/crates/common/src/configuration.rs b/crates/common/src/configuration.rs index f4e2b7b4..948526b2 100644 --- a/crates/common/src/configuration.rs +++ b/crates/common/src/configuration.rs @@ -13,6 +13,12 @@ pub struct Routing { pub model: Option, } +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Orchestration { + pub model_provider: Option, + pub model: Option, +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ModelAlias { pub target: String, @@ -73,6 +79,7 @@ pub struct Configuration { pub tracing: Option, pub mode: Option, pub routing: Option, + pub orchestration: Option, pub agents: Option>, pub filters: Option>, pub listeners: Vec, diff --git a/crates/common/src/consts.rs b/crates/common/src/consts.rs index cafc8e80..11a15028 100644 --- a/crates/common/src/consts.rs +++ b/crates/common/src/consts.rs @@ -33,5 +33,4 @@ pub const OTEL_COLLECTOR_HTTP: &str = "opentelemetry_collector_http"; pub const LLM_ROUTE_HEADER: &str = "x-arch-llm-route"; pub const ENVOY_RETRY_HEADER: &str = "x-envoy-max-retries"; pub const BRIGHT_STAFF_SERVICE_NAME: &str = "brightstaff"; -pub const PLANO_ORCHESTRATOR_MODEL_NAME: &str = "Plano-Orchestrator"; pub const ARCH_FC_CLUSTER: &str = "arch";