mirror of
https://github.com/katanemo/plano.git
synced 2026-05-24 14:05:14 +02:00
model routing: cost/latency ranking with ranked fallback list (#849)
This commit is contained in:
parent
3a531ce22a
commit
e5751d6b13
23 changed files with 1524 additions and 317 deletions
|
|
@ -6,6 +6,7 @@ use brightstaff::handlers::llm::llm_chat;
|
|||
use brightstaff::handlers::models::list_models;
|
||||
use brightstaff::handlers::routing_service::routing_decision;
|
||||
use brightstaff::router::llm::RouterService;
|
||||
use brightstaff::router::model_metrics::ModelMetricsService;
|
||||
use brightstaff::router::orchestrator::OrchestratorService;
|
||||
use brightstaff::state::memory::MemoryConversationalStorage;
|
||||
use brightstaff::state::postgresql::PostgreSQLConversationStorage;
|
||||
|
|
@ -40,6 +41,17 @@ const DEFAULT_ROUTING_MODEL_NAME: &str = "Arch-Router";
|
|||
const DEFAULT_ORCHESTRATOR_LLM_PROVIDER: &str = "plano-orchestrator";
|
||||
const DEFAULT_ORCHESTRATOR_MODEL_NAME: &str = "Plano-Orchestrator";
|
||||
|
||||
/// Parse a version string like `v0.4.0`, `v0.3.0`, `0.2.0` into a `(major, minor, patch)` tuple.
|
||||
/// Missing parts default to 0. Non-numeric parts are treated as 0.
|
||||
fn parse_semver(version: &str) -> (u32, u32, u32) {
|
||||
let v = version.trim_start_matches('v');
|
||||
let mut parts = v.splitn(3, '.').map(|p| p.parse::<u32>().unwrap_or(0));
|
||||
let major = parts.next().unwrap_or(0);
|
||||
let minor = parts.next().unwrap_or(0);
|
||||
let patch = parts.next().unwrap_or(0);
|
||||
(major, minor, patch)
|
||||
}
|
||||
|
||||
/// CORS pre-flight response for the models endpoint.
|
||||
fn cors_preflight() -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
let mut response = Response::new(empty());
|
||||
|
|
@ -162,8 +174,150 @@ async fn init_app_state(
|
|||
.map(|p| p.name.clone())
|
||||
.unwrap_or_else(|| DEFAULT_ROUTING_LLM_PROVIDER.to_string());
|
||||
|
||||
// Validate that top-level routing_preferences requires v0.4.0+.
|
||||
let config_version = parse_semver(&config.version);
|
||||
let is_v040_plus = config_version >= (0, 4, 0);
|
||||
|
||||
if !is_v040_plus && config.routing_preferences.is_some() {
|
||||
return Err(
|
||||
"top-level routing_preferences requires version v0.4.0 or above. \
|
||||
Update the version field or remove routing_preferences."
|
||||
.into(),
|
||||
);
|
||||
}
|
||||
|
||||
// Validate that all models referenced in top-level routing_preferences exist in model_providers.
|
||||
// The CLI renders model_providers with `name` = "openai/gpt-4o" and `model` = "gpt-4o",
|
||||
// so we accept a match against either field.
|
||||
if let Some(ref route_prefs) = config.routing_preferences {
|
||||
let provider_model_names: std::collections::HashSet<&str> = config
|
||||
.model_providers
|
||||
.iter()
|
||||
.flat_map(|p| std::iter::once(p.name.as_str()).chain(p.model.as_deref()))
|
||||
.collect();
|
||||
for pref in route_prefs {
|
||||
for model in &pref.models {
|
||||
if !provider_model_names.contains(model.as_str()) {
|
||||
return Err(format!(
|
||||
"routing_preferences route '{}' references model '{}' \
|
||||
which is not declared in model_providers",
|
||||
pref.name, model
|
||||
)
|
||||
.into());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Validate and initialize ModelMetricsService if model_metrics_sources is configured.
|
||||
let metrics_service: Option<Arc<ModelMetricsService>> = if let Some(ref sources) =
|
||||
config.model_metrics_sources
|
||||
{
|
||||
use common::configuration::MetricsSource;
|
||||
let cost_count = sources
|
||||
.iter()
|
||||
.filter(|s| matches!(s, MetricsSource::CostMetrics { .. }))
|
||||
.count();
|
||||
let prom_count = sources
|
||||
.iter()
|
||||
.filter(|s| matches!(s, MetricsSource::PrometheusMetrics { .. }))
|
||||
.count();
|
||||
let do_count = sources
|
||||
.iter()
|
||||
.filter(|s| matches!(s, MetricsSource::DigitalOceanPricing { .. }))
|
||||
.count();
|
||||
if cost_count > 1 {
|
||||
return Err("model_metrics_sources: only one cost_metrics source is allowed".into());
|
||||
}
|
||||
if prom_count > 1 {
|
||||
return Err(
|
||||
"model_metrics_sources: only one prometheus_metrics source is allowed".into(),
|
||||
);
|
||||
}
|
||||
if do_count > 1 {
|
||||
return Err(
|
||||
"model_metrics_sources: only one digitalocean_pricing source is allowed".into(),
|
||||
);
|
||||
}
|
||||
if cost_count > 0 && do_count > 0 {
|
||||
return Err(
|
||||
"model_metrics_sources: cost_metrics and digitalocean_pricing cannot both be configured — use one or the other".into(),
|
||||
);
|
||||
}
|
||||
let svc = ModelMetricsService::new(sources, reqwest::Client::new()).await;
|
||||
Some(Arc::new(svc))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Validate that selection_policy.prefer is compatible with the configured metric sources.
|
||||
if let Some(ref prefs) = config.routing_preferences {
|
||||
use common::configuration::{MetricsSource, SelectionPreference};
|
||||
|
||||
let has_cost_source = config
|
||||
.model_metrics_sources
|
||||
.as_deref()
|
||||
.unwrap_or_default()
|
||||
.iter()
|
||||
.any(|s| {
|
||||
matches!(
|
||||
s,
|
||||
MetricsSource::CostMetrics { .. } | MetricsSource::DigitalOceanPricing { .. }
|
||||
)
|
||||
});
|
||||
let has_prometheus = config
|
||||
.model_metrics_sources
|
||||
.as_deref()
|
||||
.unwrap_or_default()
|
||||
.iter()
|
||||
.any(|s| matches!(s, MetricsSource::PrometheusMetrics { .. }));
|
||||
|
||||
for pref in prefs {
|
||||
if pref.selection_policy.prefer == SelectionPreference::Cheapest && !has_cost_source {
|
||||
return Err(format!(
|
||||
"routing_preferences route '{}' uses prefer: cheapest but no cost data source is configured — \
|
||||
add cost_metrics or digitalocean_pricing to model_metrics_sources",
|
||||
pref.name
|
||||
)
|
||||
.into());
|
||||
}
|
||||
if pref.selection_policy.prefer == SelectionPreference::Fastest && !has_prometheus {
|
||||
return Err(format!(
|
||||
"routing_preferences route '{}' uses prefer: fastest but no prometheus_metrics source is configured — \
|
||||
add prometheus_metrics to model_metrics_sources",
|
||||
pref.name
|
||||
)
|
||||
.into());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Warn about models in routing_preferences that have no matching pricing/latency data.
|
||||
if let (Some(ref prefs), Some(ref svc)) = (&config.routing_preferences, &metrics_service) {
|
||||
let cost_data = svc.cost_snapshot().await;
|
||||
let latency_data = svc.latency_snapshot().await;
|
||||
for pref in prefs {
|
||||
use common::configuration::SelectionPreference;
|
||||
for model in &pref.models {
|
||||
let missing = match pref.selection_policy.prefer {
|
||||
SelectionPreference::Cheapest => !cost_data.contains_key(model.as_str()),
|
||||
SelectionPreference::Fastest => !latency_data.contains_key(model.as_str()),
|
||||
_ => false,
|
||||
};
|
||||
if missing {
|
||||
warn!(
|
||||
model = %model,
|
||||
route = %pref.name,
|
||||
"model has no metric data — will be ranked last"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let router_service = Arc::new(RouterService::new(
|
||||
config.model_providers.clone(),
|
||||
config.routing_preferences.clone(),
|
||||
metrics_service,
|
||||
format!("{llm_provider_url}{CHAT_COMPLETIONS_PATH}"),
|
||||
routing_model_name,
|
||||
routing_llm_provider,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue