mirror of
https://github.com/katanemo/plano.git
synced 2026-05-06 22:32:42 +02:00
Unified overrides for custom router and orchestrator models (#820)
* support configurable orchestrator model via orchestration config section * add self-hosting docs and demo for Plano-Orchestrator * list all Plano-Orchestrator model variants in docs * use overrides for custom routing and orchestration model * update docs * update orchestrator model name * rename arch provider to plano, use llm_routing_model and agent_orchestration_model * regenerate rendered config reference
This commit is contained in:
parent
785bf7e021
commit
bc059aed4d
20 changed files with 312 additions and 103 deletions
|
|
@ -178,6 +178,7 @@ mod tests {
|
|||
Arc::new(OrchestratorService::new(
|
||||
"http://localhost:8080".to_string(),
|
||||
"test-model".to_string(),
|
||||
"plano-orchestrator".to_string(),
|
||||
))
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ mod tests {
|
|||
Arc::new(OrchestratorService::new(
|
||||
"http://localhost:8080".to_string(),
|
||||
"test-model".to_string(),
|
||||
"plano-orchestrator".to_string(),
|
||||
))
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -11,9 +11,7 @@ use brightstaff::state::StateStorage;
|
|||
use brightstaff::utils::tracing::init_tracer;
|
||||
use bytes::Bytes;
|
||||
use common::configuration::{Agent, Configuration};
|
||||
use common::consts::{
|
||||
CHAT_COMPLETIONS_PATH, MESSAGES_PATH, OPENAI_RESPONSES_API_PATH, PLANO_ORCHESTRATOR_MODEL_NAME,
|
||||
};
|
||||
use common::consts::{CHAT_COMPLETIONS_PATH, MESSAGES_PATH, OPENAI_RESPONSES_API_PATH};
|
||||
use common::llm_providers::LlmProviders;
|
||||
use http_body_util::{combinators::BoxBody, BodyExt, Empty};
|
||||
use hyper::body::Incoming;
|
||||
|
|
@ -35,6 +33,8 @@ pub mod router;
|
|||
const BIND_ADDRESS: &str = "0.0.0.0:9091";
|
||||
const DEFAULT_ROUTING_LLM_PROVIDER: &str = "arch-router";
|
||||
const DEFAULT_ROUTING_MODEL_NAME: &str = "Arch-Router";
|
||||
const DEFAULT_ORCHESTRATOR_LLM_PROVIDER: &str = "plano-orchestrator";
|
||||
const DEFAULT_ORCHESTRATOR_MODEL_NAME: &str = "Plano-Orchestrator";
|
||||
|
||||
// Utility function to extract the context from the incoming request headers
|
||||
fn extract_context_from_request(req: &Request<Incoming>) -> Context {
|
||||
|
|
@ -90,16 +90,21 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
env::var("LLM_PROVIDER_ENDPOINT").unwrap_or_else(|_| "http://localhost:12001".to_string());
|
||||
|
||||
let listener = TcpListener::bind(bind_address).await?;
|
||||
let routing_model_name: String = plano_config
|
||||
.routing
|
||||
.as_ref()
|
||||
.and_then(|r| r.model.clone())
|
||||
.unwrap_or_else(|| DEFAULT_ROUTING_MODEL_NAME.to_string());
|
||||
let overrides = plano_config.overrides.clone().unwrap_or_default();
|
||||
|
||||
// Strip provider prefix (e.g. "arch/") to get the model ID used in upstream requests
|
||||
let routing_model_name: String = overrides
|
||||
.llm_routing_model
|
||||
.as_deref()
|
||||
.map(|m| m.split_once('/').map(|(_, id)| id).unwrap_or(m))
|
||||
.unwrap_or(DEFAULT_ROUTING_MODEL_NAME)
|
||||
.to_string();
|
||||
|
||||
let routing_llm_provider = plano_config
|
||||
.routing
|
||||
.as_ref()
|
||||
.and_then(|r| r.model_provider.clone())
|
||||
.model_providers
|
||||
.iter()
|
||||
.find(|p| p.model.as_deref() == Some(routing_model_name.as_str()))
|
||||
.map(|p| p.name.clone())
|
||||
.unwrap_or_else(|| DEFAULT_ROUTING_LLM_PROVIDER.to_string());
|
||||
|
||||
let router_service: Arc<RouterService> = Arc::new(RouterService::new(
|
||||
|
|
@ -109,9 +114,25 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
routing_llm_provider,
|
||||
));
|
||||
|
||||
// Strip provider prefix (e.g. "arch/") to get the model ID used in upstream requests
|
||||
let orchestrator_model_name: String = overrides
|
||||
.agent_orchestration_model
|
||||
.as_deref()
|
||||
.map(|m| m.split_once('/').map(|(_, id)| id).unwrap_or(m))
|
||||
.unwrap_or(DEFAULT_ORCHESTRATOR_MODEL_NAME)
|
||||
.to_string();
|
||||
|
||||
let orchestrator_llm_provider: String = plano_config
|
||||
.model_providers
|
||||
.iter()
|
||||
.find(|p| p.model.as_deref() == Some(orchestrator_model_name.as_str()))
|
||||
.map(|p| p.name.clone())
|
||||
.unwrap_or_else(|| DEFAULT_ORCHESTRATOR_LLM_PROVIDER.to_string());
|
||||
|
||||
let orchestrator_service: Arc<OrchestratorService> = Arc::new(OrchestratorService::new(
|
||||
format!("{llm_provider_url}{CHAT_COMPLETIONS_PATH}"),
|
||||
PLANO_ORCHESTRATOR_MODEL_NAME.to_string(),
|
||||
orchestrator_model_name,
|
||||
orchestrator_llm_provider,
|
||||
));
|
||||
|
||||
let model_aliases = Arc::new(plano_config.model_aliases.clone());
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ use std::{collections::HashMap, sync::Arc};
|
|||
|
||||
use common::{
|
||||
configuration::{AgentUsagePreference, OrchestrationPreference},
|
||||
consts::{ARCH_PROVIDER_HINT_HEADER, PLANO_ORCHESTRATOR_MODEL_NAME, REQUEST_ID_HEADER},
|
||||
consts::{ARCH_PROVIDER_HINT_HEADER, REQUEST_ID_HEADER},
|
||||
};
|
||||
use hermesllm::apis::openai::{ChatCompletionsResponse, Message};
|
||||
use hyper::header;
|
||||
|
|
@ -19,6 +19,7 @@ pub struct OrchestratorService {
|
|||
orchestrator_url: String,
|
||||
client: reqwest::Client,
|
||||
orchestrator_model: Arc<dyn OrchestratorModel>,
|
||||
orchestrator_provider_name: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
|
|
@ -36,7 +37,11 @@ pub enum OrchestrationError {
|
|||
pub type Result<T> = std::result::Result<T, OrchestrationError>;
|
||||
|
||||
impl OrchestratorService {
|
||||
pub fn new(orchestrator_url: String, orchestration_model_name: String) -> Self {
|
||||
pub fn new(
|
||||
orchestrator_url: String,
|
||||
orchestration_model_name: String,
|
||||
orchestrator_provider_name: String,
|
||||
) -> Self {
|
||||
// Empty agent orchestrations - will be provided via usage_preferences in requests
|
||||
let agent_orchestrations: HashMap<String, Vec<OrchestrationPreference>> = HashMap::new();
|
||||
|
||||
|
|
@ -50,6 +55,7 @@ impl OrchestratorService {
|
|||
orchestrator_url,
|
||||
client: reqwest::Client::new(),
|
||||
orchestrator_model,
|
||||
orchestrator_provider_name,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -75,12 +81,12 @@ impl OrchestratorService {
|
|||
debug!(
|
||||
model = %self.orchestrator_model.get_model_name(),
|
||||
endpoint = %self.orchestrator_url,
|
||||
"sending request to arch-orchestrator"
|
||||
"sending request to plano-orchestrator"
|
||||
);
|
||||
|
||||
debug!(
|
||||
body = %serde_json::to_string(&orchestrator_request).unwrap(),
|
||||
"arch orchestrator request"
|
||||
"plano orchestrator request"
|
||||
);
|
||||
|
||||
let mut orchestration_request_headers = header::HeaderMap::new();
|
||||
|
|
@ -91,7 +97,7 @@ impl OrchestratorService {
|
|||
|
||||
orchestration_request_headers.insert(
|
||||
header::HeaderName::from_static(ARCH_PROVIDER_HINT_HEADER),
|
||||
header::HeaderValue::from_str(PLANO_ORCHESTRATOR_MODEL_NAME).unwrap(),
|
||||
header::HeaderValue::from_str(&self.orchestrator_provider_name).unwrap(),
|
||||
);
|
||||
|
||||
// Inject OpenTelemetry trace context from current span
|
||||
|
|
@ -110,7 +116,7 @@ impl OrchestratorService {
|
|||
|
||||
orchestration_request_headers.insert(
|
||||
header::HeaderName::from_static("model"),
|
||||
header::HeaderValue::from_static(PLANO_ORCHESTRATOR_MODEL_NAME),
|
||||
header::HeaderValue::from_str(&self.orchestrator_provider_name).unwrap(),
|
||||
);
|
||||
|
||||
let start_time = std::time::Instant::now();
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue