Unified overrides for custom router and orchestrator models (#820)

* support configurable orchestrator model via orchestration config section

* add self-hosting docs and demo for Plano-Orchestrator

* list all Plano-Orchestrator model variants in docs

* use overrides for custom routing and orchestration model

* update docs

* update orchestrator model name

* rename arch provider to plano, use llm_routing_model and agent_orchestration_model

* regenerate rendered config reference
This commit is contained in:
Adil Hafeez 2026-03-15 09:36:11 -07:00 committed by GitHub
parent 785bf7e021
commit bc059aed4d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
20 changed files with 312 additions and 103 deletions

View file

@ -178,6 +178,7 @@ mod tests {
Arc::new(OrchestratorService::new(
"http://localhost:8080".to_string(),
"test-model".to_string(),
"plano-orchestrator".to_string(),
))
}

View file

@ -23,6 +23,7 @@ mod tests {
Arc::new(OrchestratorService::new(
"http://localhost:8080".to_string(),
"test-model".to_string(),
"plano-orchestrator".to_string(),
))
}

View file

@ -11,9 +11,7 @@ use brightstaff::state::StateStorage;
use brightstaff::utils::tracing::init_tracer;
use bytes::Bytes;
use common::configuration::{Agent, Configuration};
use common::consts::{
CHAT_COMPLETIONS_PATH, MESSAGES_PATH, OPENAI_RESPONSES_API_PATH, PLANO_ORCHESTRATOR_MODEL_NAME,
};
use common::consts::{CHAT_COMPLETIONS_PATH, MESSAGES_PATH, OPENAI_RESPONSES_API_PATH};
use common::llm_providers::LlmProviders;
use http_body_util::{combinators::BoxBody, BodyExt, Empty};
use hyper::body::Incoming;
@ -35,6 +33,8 @@ pub mod router;
const BIND_ADDRESS: &str = "0.0.0.0:9091";
const DEFAULT_ROUTING_LLM_PROVIDER: &str = "arch-router";
const DEFAULT_ROUTING_MODEL_NAME: &str = "Arch-Router";
const DEFAULT_ORCHESTRATOR_LLM_PROVIDER: &str = "plano-orchestrator";
const DEFAULT_ORCHESTRATOR_MODEL_NAME: &str = "Plano-Orchestrator";
// Utility function to extract the context from the incoming request headers
fn extract_context_from_request(req: &Request<Incoming>) -> Context {
@ -90,16 +90,21 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
env::var("LLM_PROVIDER_ENDPOINT").unwrap_or_else(|_| "http://localhost:12001".to_string());
let listener = TcpListener::bind(bind_address).await?;
let routing_model_name: String = plano_config
.routing
.as_ref()
.and_then(|r| r.model.clone())
.unwrap_or_else(|| DEFAULT_ROUTING_MODEL_NAME.to_string());
let overrides = plano_config.overrides.clone().unwrap_or_default();
// Strip provider prefix (e.g. "arch/") to get the model ID used in upstream requests
let routing_model_name: String = overrides
.llm_routing_model
.as_deref()
.map(|m| m.split_once('/').map(|(_, id)| id).unwrap_or(m))
.unwrap_or(DEFAULT_ROUTING_MODEL_NAME)
.to_string();
let routing_llm_provider = plano_config
.routing
.as_ref()
.and_then(|r| r.model_provider.clone())
.model_providers
.iter()
.find(|p| p.model.as_deref() == Some(routing_model_name.as_str()))
.map(|p| p.name.clone())
.unwrap_or_else(|| DEFAULT_ROUTING_LLM_PROVIDER.to_string());
let router_service: Arc<RouterService> = Arc::new(RouterService::new(
@ -109,9 +114,25 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
routing_llm_provider,
));
// Strip provider prefix (e.g. "arch/") to get the model ID used in upstream requests
let orchestrator_model_name: String = overrides
.agent_orchestration_model
.as_deref()
.map(|m| m.split_once('/').map(|(_, id)| id).unwrap_or(m))
.unwrap_or(DEFAULT_ORCHESTRATOR_MODEL_NAME)
.to_string();
let orchestrator_llm_provider: String = plano_config
.model_providers
.iter()
.find(|p| p.model.as_deref() == Some(orchestrator_model_name.as_str()))
.map(|p| p.name.clone())
.unwrap_or_else(|| DEFAULT_ORCHESTRATOR_LLM_PROVIDER.to_string());
let orchestrator_service: Arc<OrchestratorService> = Arc::new(OrchestratorService::new(
format!("{llm_provider_url}{CHAT_COMPLETIONS_PATH}"),
PLANO_ORCHESTRATOR_MODEL_NAME.to_string(),
orchestrator_model_name,
orchestrator_llm_provider,
));
let model_aliases = Arc::new(plano_config.model_aliases.clone());

View file

@ -2,7 +2,7 @@ use std::{collections::HashMap, sync::Arc};
use common::{
configuration::{AgentUsagePreference, OrchestrationPreference},
consts::{ARCH_PROVIDER_HINT_HEADER, PLANO_ORCHESTRATOR_MODEL_NAME, REQUEST_ID_HEADER},
consts::{ARCH_PROVIDER_HINT_HEADER, REQUEST_ID_HEADER},
};
use hermesllm::apis::openai::{ChatCompletionsResponse, Message};
use hyper::header;
@ -19,6 +19,7 @@ pub struct OrchestratorService {
orchestrator_url: String,
client: reqwest::Client,
orchestrator_model: Arc<dyn OrchestratorModel>,
orchestrator_provider_name: String,
}
#[derive(Debug, Error)]
@ -36,7 +37,11 @@ pub enum OrchestrationError {
pub type Result<T> = std::result::Result<T, OrchestrationError>;
impl OrchestratorService {
pub fn new(orchestrator_url: String, orchestration_model_name: String) -> Self {
pub fn new(
orchestrator_url: String,
orchestration_model_name: String,
orchestrator_provider_name: String,
) -> Self {
// Empty agent orchestrations - will be provided via usage_preferences in requests
let agent_orchestrations: HashMap<String, Vec<OrchestrationPreference>> = HashMap::new();
@ -50,6 +55,7 @@ impl OrchestratorService {
orchestrator_url,
client: reqwest::Client::new(),
orchestrator_model,
orchestrator_provider_name,
}
}
@ -75,12 +81,12 @@ impl OrchestratorService {
debug!(
model = %self.orchestrator_model.get_model_name(),
endpoint = %self.orchestrator_url,
"sending request to arch-orchestrator"
"sending request to plano-orchestrator"
);
debug!(
body = %serde_json::to_string(&orchestrator_request).unwrap(),
"arch orchestrator request"
"plano orchestrator request"
);
let mut orchestration_request_headers = header::HeaderMap::new();
@ -91,7 +97,7 @@ impl OrchestratorService {
orchestration_request_headers.insert(
header::HeaderName::from_static(ARCH_PROVIDER_HINT_HEADER),
header::HeaderValue::from_str(PLANO_ORCHESTRATOR_MODEL_NAME).unwrap(),
header::HeaderValue::from_str(&self.orchestrator_provider_name).unwrap(),
);
// Inject OpenTelemetry trace context from current span
@ -110,7 +116,7 @@ impl OrchestratorService {
orchestration_request_headers.insert(
header::HeaderName::from_static("model"),
header::HeaderValue::from_static(PLANO_ORCHESTRATOR_MODEL_NAME),
header::HeaderValue::from_str(&self.orchestrator_provider_name).unwrap(),
);
let start_time = std::time::Instant::now();

View file

@ -7,12 +7,6 @@ use crate::api::open_ai::{
ChatCompletionTool, FunctionDefinition, FunctionParameter, FunctionParameters, ParameterType,
};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Routing {
pub model_provider: Option<String>,
pub model: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelAlias {
pub target: String,
@ -72,7 +66,6 @@ pub struct Configuration {
pub ratelimits: Option<Vec<Ratelimit>>,
pub tracing: Option<Tracing>,
pub mode: Option<GatewayMode>,
pub routing: Option<Routing>,
pub agents: Option<Vec<Agent>>,
pub filters: Option<Vec<Agent>>,
pub listeners: Vec<Listener>,
@ -84,6 +77,8 @@ pub struct Overrides {
pub prompt_target_intent_matching_threshold: Option<f64>,
pub optimize_context_window: Option<bool>,
pub use_agent_orchestrator: Option<bool>,
pub llm_routing_model: Option<String>,
pub agent_orchestration_model: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
@ -207,8 +202,6 @@ pub struct EmbeddingProviver {
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub enum LlmProviderType {
#[serde(rename = "arch")]
Arch,
#[serde(rename = "anthropic")]
Anthropic,
#[serde(rename = "deepseek")]
@ -237,12 +230,13 @@ pub enum LlmProviderType {
Qwen,
#[serde(rename = "amazon_bedrock")]
AmazonBedrock,
#[serde(rename = "plano")]
Plano,
}
impl Display for LlmProviderType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
LlmProviderType::Arch => write!(f, "arch"),
LlmProviderType::Anthropic => write!(f, "anthropic"),
LlmProviderType::Deepseek => write!(f, "deepseek"),
LlmProviderType::Groq => write!(f, "groq"),
@ -257,6 +251,7 @@ impl Display for LlmProviderType {
LlmProviderType::Zhipu => write!(f, "zhipu"),
LlmProviderType::Qwen => write!(f, "qwen"),
LlmProviderType::AmazonBedrock => write!(f, "amazon_bedrock"),
LlmProviderType::Plano => write!(f, "plano"),
}
}
}
@ -591,14 +586,14 @@ mod test {
},
LlmProvider {
name: "arch-router".to_string(),
provider_interface: LlmProviderType::Arch,
provider_interface: LlmProviderType::Plano,
model: Some("Arch-Router".to_string()),
internal: Some(true),
..Default::default()
},
LlmProvider {
name: "plano-orchestrator".to_string(),
provider_interface: LlmProviderType::Arch,
provider_interface: LlmProviderType::Plano,
model: Some("Plano-Orchestrator".to_string()),
internal: Some(true),
..Default::default()

View file

@ -33,5 +33,4 @@ pub const OTEL_COLLECTOR_HTTP: &str = "opentelemetry_collector_http";
pub const LLM_ROUTE_HEADER: &str = "x-arch-llm-route";
pub const ENVOY_RETRY_HEADER: &str = "x-envoy-max-retries";
pub const BRIGHT_STAFF_SERVICE_NAME: &str = "brightstaff";
pub const PLANO_ORCHESTRATOR_MODEL_NAME: &str = "Plano-Orchestrator";
pub const ARCH_FC_CLUSTER: &str = "arch";
pub const PLANO_FC_CLUSTER: &str = "plano";

View file

@ -35,7 +35,7 @@ mod tests {
ProviderId::Mistral
);
assert_eq!(ProviderId::try_from("groq").unwrap(), ProviderId::Groq);
assert_eq!(ProviderId::try_from("arch").unwrap(), ProviderId::Arch);
assert_eq!(ProviderId::try_from("plano").unwrap(), ProviderId::Plano);
// Test aliases
assert_eq!(ProviderId::try_from("google").unwrap(), ProviderId::Gemini);

View file

@ -34,7 +34,7 @@ pub enum ProviderId {
Gemini,
Anthropic,
GitHub,
Arch,
Plano,
AzureOpenAI,
XAI,
TogetherAI,
@ -58,7 +58,7 @@ impl TryFrom<&str> for ProviderId {
"google" => Ok(ProviderId::Gemini), // alias
"anthropic" => Ok(ProviderId::Anthropic),
"github" => Ok(ProviderId::GitHub),
"arch" => Ok(ProviderId::Arch),
"plano" => Ok(ProviderId::Plano),
"azure_openai" => Ok(ProviderId::AzureOpenAI),
"xai" => Ok(ProviderId::XAI),
"together_ai" => Ok(ProviderId::TogetherAI),
@ -135,7 +135,7 @@ impl ProviderId {
| ProviderId::Groq
| ProviderId::Mistral
| ProviderId::Deepseek
| ProviderId::Arch
| ProviderId::Plano
| ProviderId::Gemini
| ProviderId::GitHub
| ProviderId::AzureOpenAI
@ -153,7 +153,7 @@ impl ProviderId {
| ProviderId::Groq
| ProviderId::Mistral
| ProviderId::Deepseek
| ProviderId::Arch
| ProviderId::Plano
| ProviderId::Gemini
| ProviderId::GitHub
| ProviderId::AzureOpenAI
@ -219,7 +219,7 @@ impl Display for ProviderId {
ProviderId::Gemini => write!(f, "Gemini"),
ProviderId::Anthropic => write!(f, "Anthropic"),
ProviderId::GitHub => write!(f, "GitHub"),
ProviderId::Arch => write!(f, "Arch"),
ProviderId::Plano => write!(f, "Plano"),
ProviderId::AzureOpenAI => write!(f, "azure_openai"),
ProviderId::XAI => write!(f, "xai"),
ProviderId::TogetherAI => write!(f, "together_ai"),

View file

@ -873,7 +873,7 @@ impl HttpContext for StreamContext {
// ensure that the provider has an endpoint if the access key is missing else return a bad request
if self.llm_provider.as_ref().unwrap().endpoint.is_none()
&& self.llm_provider.as_ref().unwrap().provider_interface
!= LlmProviderType::Arch
!= LlmProviderType::Plano
{
self.send_server_error(error, Some(StatusCode::BAD_REQUEST));
}