add top-level routing_preferences with selection_policy and model metrics fetch

This commit is contained in:
Adil Hafeez 2026-03-26 17:35:39 -07:00
parent 406fa92802
commit 2ef938ac5f
9 changed files with 568 additions and 49 deletions

View file

@ -6,6 +6,7 @@ use brightstaff::handlers::llm::llm_chat;
use brightstaff::handlers::models::list_models;
use brightstaff::handlers::routing_service::routing_decision;
use brightstaff::router::llm::RouterService;
use brightstaff::router::model_metrics::ModelMetricsService;
use brightstaff::router::orchestrator::OrchestratorService;
use brightstaff::state::memory::MemoryConversationalStorage;
use brightstaff::state::postgresql::PostgreSQLConversationStorage;
@ -40,6 +41,17 @@ const DEFAULT_ROUTING_MODEL_NAME: &str = "Arch-Router";
const DEFAULT_ORCHESTRATOR_LLM_PROVIDER: &str = "plano-orchestrator";
const DEFAULT_ORCHESTRATOR_MODEL_NAME: &str = "Plano-Orchestrator";
/// Parse a version string like `v0.4.0`, `v0.3.0`, `0.2.0` into a `(major, minor, patch)` tuple.
///
/// Surrounding whitespace and any leading `v`/`V` are ignored. Each dot-separated component
/// contributes the value of its leading run of ASCII digits, so a pre-release or build suffix
/// (`0.4.1-rc1`) or an extra component (`1.2.3.4`) does not zero out an otherwise valid number.
/// Missing parts default to 0, as do parts that have no leading digits or overflow `u32`.
/// Components after the third are ignored.
fn parse_semver(version: &str) -> (u32, u32, u32) {
    // Value of the leading ASCII-digit run of `part`; 0 when the run is empty or overflows.
    fn leading_number(part: &str) -> u32 {
        let end = part
            .find(|c: char| !c.is_ascii_digit())
            .unwrap_or(part.len());
        part[..end].parse().unwrap_or(0)
    }

    // Trim first so leading whitespace cannot hide the `v` prefix from the strip below.
    let v = version
        .trim()
        .trim_start_matches(|c: char| c == 'v' || c == 'V');
    // Split on every dot (not `splitn(3, ..)`) so a fourth component cannot be glued onto the
    // patch field and make it unparsable.
    let mut parts = v.split('.').map(leading_number);
    let major = parts.next().unwrap_or(0);
    let minor = parts.next().unwrap_or(0);
    let patch = parts.next().unwrap_or(0);
    (major, minor, patch)
}
/// CORS pre-flight response for the models endpoint.
fn cors_preflight() -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
let mut response = Response::new(empty());
@ -162,8 +174,69 @@ async fn init_app_state(
.map(|p| p.name.clone())
.unwrap_or_else(|| DEFAULT_ROUTING_LLM_PROVIDER.to_string());
// Validate version-gated routing_preferences rules.
let config_version = parse_semver(&config.version);
let is_v040_plus = config_version >= (0, 4, 0);
if is_v040_plus {
// v0.4.0+: per-provider routing_preferences are forbidden.
let providers_with_per_provider_prefs: Vec<&str> = config
.model_providers
.iter()
.filter(|p| p.routing_preferences.is_some())
.filter_map(|p| p.model.as_deref())
.collect();
if !providers_with_per_provider_prefs.is_empty() {
return Err(format!(
"routing_preferences inside model_providers is not allowed in v0.4.0+. \
Use the top-level routing_preferences instead. \
Offending models: {}",
providers_with_per_provider_prefs.join(", ")
)
.into());
}
} else if config.routing_preferences.is_some() {
return Err(
"top-level routing_preferences requires version v0.4.0 or above. \
Update the version field or remove routing_preferences."
.into(),
);
}
// Validate that all models referenced in top-level routing_preferences exist in model_providers.
if let Some(ref route_prefs) = config.routing_preferences {
let provider_model_names: std::collections::HashSet<&str> = config
.model_providers
.iter()
.flat_map(|p| p.model.as_deref())
.collect();
for pref in route_prefs {
for model in &pref.models {
if !provider_model_names.contains(model.as_str()) {
return Err(format!(
"routing_preferences route '{}' references model '{}' \
which is not declared in model_providers",
pref.name, model
)
.into());
}
}
}
}
// Initialize ModelMetricsService if model_metrics_sources is configured.
let metrics_service: Option<Arc<ModelMetricsService>> =
if let Some(ref sources) = config.model_metrics_sources {
let svc = ModelMetricsService::new(sources, reqwest::Client::new()).await;
Some(Arc::new(svc))
} else {
None
};
let router_service = Arc::new(RouterService::new(
config.model_providers.clone(),
config.routing_preferences.clone(),
metrics_service,
format!("{llm_provider_url}{CHAT_COMPLETIONS_PATH}"),
routing_model_name,
routing_llm_provider,