use overrides for custom routing and orchestration model

This commit is contained in:
Adil Hafeez 2026-03-11 16:38:00 -07:00
parent 98038690b0
commit 6143b7ad54
No known key found for this signature in database
GPG key ID: 9B18EF7691369645
9 changed files with 93 additions and 114 deletions

View file

@ -90,16 +90,21 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
env::var("LLM_PROVIDER_ENDPOINT").unwrap_or_else(|_| "http://localhost:12001".to_string());
let listener = TcpListener::bind(bind_address).await?;
let routing_model_name: String = plano_config
.routing
.as_ref()
.and_then(|r| r.model.clone())
.unwrap_or_else(|| DEFAULT_ROUTING_MODEL_NAME.to_string());
let overrides = plano_config.overrides.clone().unwrap_or_default();
// Strip provider prefix (e.g. "arch/") to get the model ID used in upstream requests
let routing_model_name: String = overrides
.router_model
.as_deref()
.map(|m| m.split_once('/').map(|(_, id)| id).unwrap_or(m))
.unwrap_or(DEFAULT_ROUTING_MODEL_NAME)
.to_string();
let routing_llm_provider = plano_config
.routing
.as_ref()
.and_then(|r| r.model_provider.clone())
.model_providers
.iter()
.find(|p| p.model.as_deref() == Some(routing_model_name.as_str()))
.map(|p| p.name.clone())
.unwrap_or_else(|| DEFAULT_ROUTING_LLM_PROVIDER.to_string());
let router_service: Arc<RouterService> = Arc::new(RouterService::new(
@ -109,16 +114,19 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
routing_llm_provider,
));
let orchestrator_model_name: String = plano_config
.orchestration
.as_ref()
.and_then(|o| o.model.clone())
.unwrap_or_else(|| DEFAULT_ORCHESTRATOR_MODEL_NAME.to_string());
// Strip provider prefix (e.g. "arch/") to get the model ID used in upstream requests
let orchestrator_model_name: String = overrides
.orchestrator_model
.as_deref()
.map(|m| m.split_once('/').map(|(_, id)| id).unwrap_or(m))
.unwrap_or(DEFAULT_ORCHESTRATOR_MODEL_NAME)
.to_string();
let orchestrator_llm_provider: String = plano_config
.orchestration
.as_ref()
.and_then(|o| o.model_provider.clone())
.model_providers
.iter()
.find(|p| p.model.as_deref() == Some(orchestrator_model_name.as_str()))
.map(|p| p.name.clone())
.unwrap_or_else(|| DEFAULT_ORCHESTRATOR_LLM_PROVIDER.to_string());
let orchestrator_service: Arc<OrchestratorService> = Arc::new(OrchestratorService::new(