mirror of
https://github.com/katanemo/plano.git
synced 2026-05-10 08:12:48 +02:00
use overrides for custom routing and orchestration model
This commit is contained in:
parent
98038690b0
commit
6143b7ad54
9 changed files with 93 additions and 114 deletions
|
|
@ -90,16 +90,21 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
env::var("LLM_PROVIDER_ENDPOINT").unwrap_or_else(|_| "http://localhost:12001".to_string());
|
||||
|
||||
let listener = TcpListener::bind(bind_address).await?;
|
||||
let routing_model_name: String = plano_config
|
||||
.routing
|
||||
.as_ref()
|
||||
.and_then(|r| r.model.clone())
|
||||
.unwrap_or_else(|| DEFAULT_ROUTING_MODEL_NAME.to_string());
|
||||
let overrides = plano_config.overrides.clone().unwrap_or_default();
|
||||
|
||||
// Strip provider prefix (e.g. "arch/") to get the model ID used in upstream requests
|
||||
let routing_model_name: String = overrides
|
||||
.router_model
|
||||
.as_deref()
|
||||
.map(|m| m.split_once('/').map(|(_, id)| id).unwrap_or(m))
|
||||
.unwrap_or(DEFAULT_ROUTING_MODEL_NAME)
|
||||
.to_string();
|
||||
|
||||
let routing_llm_provider = plano_config
|
||||
.routing
|
||||
.as_ref()
|
||||
.and_then(|r| r.model_provider.clone())
|
||||
.model_providers
|
||||
.iter()
|
||||
.find(|p| p.model.as_deref() == Some(routing_model_name.as_str()))
|
||||
.map(|p| p.name.clone())
|
||||
.unwrap_or_else(|| DEFAULT_ROUTING_LLM_PROVIDER.to_string());
|
||||
|
||||
let router_service: Arc<RouterService> = Arc::new(RouterService::new(
|
||||
|
|
@ -109,16 +114,19 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
routing_llm_provider,
|
||||
));
|
||||
|
||||
let orchestrator_model_name: String = plano_config
|
||||
.orchestration
|
||||
.as_ref()
|
||||
.and_then(|o| o.model.clone())
|
||||
.unwrap_or_else(|| DEFAULT_ORCHESTRATOR_MODEL_NAME.to_string());
|
||||
// Strip provider prefix (e.g. "arch/") to get the model ID used in upstream requests
|
||||
let orchestrator_model_name: String = overrides
|
||||
.orchestrator_model
|
||||
.as_deref()
|
||||
.map(|m| m.split_once('/').map(|(_, id)| id).unwrap_or(m))
|
||||
.unwrap_or(DEFAULT_ORCHESTRATOR_MODEL_NAME)
|
||||
.to_string();
|
||||
|
||||
let orchestrator_llm_provider: String = plano_config
|
||||
.orchestration
|
||||
.as_ref()
|
||||
.and_then(|o| o.model_provider.clone())
|
||||
.model_providers
|
||||
.iter()
|
||||
.find(|p| p.model.as_deref() == Some(orchestrator_model_name.as_str()))
|
||||
.map(|p| p.name.clone())
|
||||
.unwrap_or_else(|| DEFAULT_ORCHESTRATOR_LLM_PROVIDER.to_string());
|
||||
|
||||
let orchestrator_service: Arc<OrchestratorService> = Arc::new(OrchestratorService::new(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue