use overrides for custom routing and orchestration model

This commit is contained in:
Adil Hafeez 2026-03-11 16:38:00 -07:00
parent 98038690b0
commit 6143b7ad54
No known key found for this signature in database
GPG key ID: 9B18EF7691369645
9 changed files with 93 additions and 114 deletions

View file

@ -90,16 +90,21 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
env::var("LLM_PROVIDER_ENDPOINT").unwrap_or_else(|_| "http://localhost:12001".to_string());
let listener = TcpListener::bind(bind_address).await?;
let routing_model_name: String = plano_config
.routing
.as_ref()
.and_then(|r| r.model.clone())
.unwrap_or_else(|| DEFAULT_ROUTING_MODEL_NAME.to_string());
let overrides = plano_config.overrides.clone().unwrap_or_default();
// Strip provider prefix (e.g. "arch/") to get the model ID used in upstream requests
let routing_model_name: String = overrides
.router_model
.as_deref()
.map(|m| m.split_once('/').map(|(_, id)| id).unwrap_or(m))
.unwrap_or(DEFAULT_ROUTING_MODEL_NAME)
.to_string();
let routing_llm_provider = plano_config
.routing
.as_ref()
.and_then(|r| r.model_provider.clone())
.model_providers
.iter()
.find(|p| p.model.as_deref() == Some(routing_model_name.as_str()))
.map(|p| p.name.clone())
.unwrap_or_else(|| DEFAULT_ROUTING_LLM_PROVIDER.to_string());
let router_service: Arc<RouterService> = Arc::new(RouterService::new(
@ -109,16 +114,19 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
routing_llm_provider,
));
let orchestrator_model_name: String = plano_config
.orchestration
.as_ref()
.and_then(|o| o.model.clone())
.unwrap_or_else(|| DEFAULT_ORCHESTRATOR_MODEL_NAME.to_string());
// Strip provider prefix (e.g. "arch/") to get the model ID used in upstream requests
let orchestrator_model_name: String = overrides
.orchestrator_model
.as_deref()
.map(|m| m.split_once('/').map(|(_, id)| id).unwrap_or(m))
.unwrap_or(DEFAULT_ORCHESTRATOR_MODEL_NAME)
.to_string();
let orchestrator_llm_provider: String = plano_config
.orchestration
.as_ref()
.and_then(|o| o.model_provider.clone())
.model_providers
.iter()
.find(|p| p.model.as_deref() == Some(orchestrator_model_name.as_str()))
.map(|p| p.name.clone())
.unwrap_or_else(|| DEFAULT_ORCHESTRATOR_LLM_PROVIDER.to_string());
let orchestrator_service: Arc<OrchestratorService> = Arc::new(OrchestratorService::new(

View file

@ -7,18 +7,6 @@ use crate::api::open_ai::{
ChatCompletionTool, FunctionDefinition, FunctionParameter, FunctionParameters, ParameterType,
};
/// Routing settings: an optional model provider and an optional model, both
/// stored as plain strings.
///
/// Both fields are `Option`s, so either may be omitted from the serialized
/// form; what happens when a value is absent is decided by the code that
/// reads this struct.
// NOTE(review): presumably names the LLM used for routing decisions — confirm
// against the call sites.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Routing {
/// Provider name for the routing model; `None` when not configured.
pub model_provider: Option<String>,
/// Routing model identifier; `None` when not configured.
pub model: Option<String>,
}
/// Orchestration settings: an optional model provider and an optional model,
/// both stored as plain strings.
///
/// Both fields are `Option`s, so either may be omitted from the serialized
/// form; what happens when a value is absent is decided by the code that
/// reads this struct.
// NOTE(review): presumably names the LLM used for agent orchestration —
// confirm against the call sites.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Orchestration {
/// Provider name for the orchestration model; `None` when not configured.
pub model_provider: Option<String>,
/// Orchestration model identifier; `None` when not configured.
pub model: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelAlias {
pub target: String,
@ -78,8 +66,6 @@ pub struct Configuration {
pub ratelimits: Option<Vec<Ratelimit>>,
pub tracing: Option<Tracing>,
pub mode: Option<GatewayMode>,
pub routing: Option<Routing>,
pub orchestration: Option<Orchestration>,
pub agents: Option<Vec<Agent>>,
pub filters: Option<Vec<Agent>>,
pub listeners: Vec<Listener>,
@ -91,6 +77,8 @@ pub struct Overrides {
pub prompt_target_intent_matching_threshold: Option<f64>,
pub optimize_context_window: Option<bool>,
pub use_agent_orchestrator: Option<bool>,
pub router_model: Option<String>,
pub orchestrator_model: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
@ -244,6 +232,8 @@ pub enum LlmProviderType {
Qwen,
#[serde(rename = "amazon_bedrock")]
AmazonBedrock,
#[serde(rename = "plano")]
Plano,
}
impl Display for LlmProviderType {
@ -264,6 +254,7 @@ impl Display for LlmProviderType {
LlmProviderType::Zhipu => write!(f, "zhipu"),
LlmProviderType::Qwen => write!(f, "qwen"),
LlmProviderType::AmazonBedrock => write!(f, "amazon_bedrock"),
LlmProviderType::Plano => write!(f, "plano"),
}
}
}
@ -272,7 +263,15 @@ impl LlmProviderType {
/// Get the `hermesllm::ProviderId` for this `LlmProviderType`.
///
/// Used with the new function-based hermesllm API. Every variant is looked
/// up by its `Display` string, except `Plano`, which is looked up as
/// `"arch"` because the Plano provider uses the same interface as Arch.
///
/// # Panics
///
/// Panics if `hermesllm::ProviderId::try_from` rejects the resolved provider
/// name. That would mean a variant was added here without a matching
/// `ProviderId`, which is treated as a programming error.
pub fn to_provider_id(&self) -> hermesllm::ProviderId {
    // Resolve the lookup name first so the conversion (and its `expect`)
    // is written exactly once instead of once per match arm.
    let provider_name = match self {
        // Plano provider uses the same interface as Arch
        LlmProviderType::Plano => "arch".to_string(),
        other => other.to_string(),
    };
    hermesllm::ProviderId::try_from(provider_name.as_str())
        .expect("LlmProviderType should always map to a valid ProviderId")
}
}