diff --git a/config/plano_config_schema.yaml b/config/plano_config_schema.yaml index 5190fecf..4563c290 100644 --- a/config/plano_config_schema.yaml +++ b/config/plano_config_schema.yaml @@ -9,6 +9,7 @@ properties: - 0.1-beta - 0.2.0 - v0.3.0 + - v0.4.0 agents: type: array @@ -470,6 +471,51 @@ properties: additionalProperties: false required: - jailbreak + routing_preferences: + type: array + items: + type: object + properties: + name: + type: string + description: + type: string + models: + type: array + items: + type: string + minItems: 1 + selection_policy: + type: object + properties: + prefer: + type: string + enum: + - cheapest + - fastest + - random + additionalProperties: false + required: + - prefer + additionalProperties: false + required: + - name + - description + - models + - selection_policy + + model_metrics_sources: + type: object + properties: + url: + type: string + refresh_interval: + type: integer + minimum: 1 + additionalProperties: false + required: + - url + additionalProperties: false required: - version diff --git a/crates/brightstaff/src/handlers/llm/mod.rs b/crates/brightstaff/src/handlers/llm/mod.rs index 9d4a2dfb..2133a231 100644 --- a/crates/brightstaff/src/handlers/llm/mod.rs +++ b/crates/brightstaff/src/handlers/llm/mod.rs @@ -120,6 +120,7 @@ async fn llm_chat_inner( tool_names, user_message_preview, inline_routing_policy, + inline_routing_preferences, client_api, provider_id, } = parsed; @@ -262,6 +263,7 @@ async fn llm_chat_inner( &request_path, &request_id, inline_routing_policy, + inline_routing_preferences, ) .await } @@ -324,6 +326,7 @@ struct PreparedRequest { tool_names: Option>, user_message_preview: Option, inline_routing_policy: Option>, + inline_routing_preferences: Option>, client_api: Option, provider_id: hermesllm::ProviderId, } @@ -352,8 +355,8 @@ async fn parse_and_validate_request( "request body received" ); - // Extract routing_policy from request body if present - let (chat_request_bytes, inline_routing_policy) = 
+ // Extract routing_policy and routing_preferences from request body if present + let (chat_request_bytes, inline_routing_policy, inline_routing_preferences) = crate::handlers::routing_service::extract_routing_policy(&raw_bytes, false).map_err( |err| { warn!(error = %err, "failed to parse request JSON"); @@ -440,6 +443,7 @@ async fn parse_and_validate_request( tool_names, user_message_preview, inline_routing_policy, + inline_routing_preferences, client_api, provider_id, }) diff --git a/crates/brightstaff/src/handlers/llm/model_selection.rs b/crates/brightstaff/src/handlers/llm/model_selection.rs index 455b7c0e..b6d6f325 100644 --- a/crates/brightstaff/src/handlers/llm/model_selection.rs +++ b/crates/brightstaff/src/handlers/llm/model_selection.rs @@ -1,4 +1,4 @@ -use common::configuration::ModelUsagePreference; +use common::configuration::{ModelUsagePreference, TopLevelRoutingPreference}; use hermesllm::clients::endpoints::SupportedUpstreamAPIs; use hermesllm::{ProviderRequest, ProviderRequestType}; use hyper::StatusCode; @@ -40,6 +40,7 @@ pub async fn router_chat_get_upstream_model( request_path: &str, request_id: &str, inline_usage_preferences: Option>, + inline_routing_preferences: Option>, ) -> Result { // Clone metadata for routing before converting (which consumes client_request) let routing_metadata = client_request.metadata().clone(); @@ -122,6 +123,7 @@ pub async fn router_chat_get_upstream_model( &chat_request.messages, traceparent, usage_preferences, + inline_routing_preferences, request_id, ) .await; diff --git a/crates/brightstaff/src/handlers/routing_service.rs b/crates/brightstaff/src/handlers/routing_service.rs index ec09f06f..4e471ec2 100644 --- a/crates/brightstaff/src/handlers/routing_service.rs +++ b/crates/brightstaff/src/handlers/routing_service.rs @@ -1,5 +1,5 @@ use bytes::Bytes; -use common::configuration::{ModelUsagePreference, SpanAttributes}; +use common::configuration::{ModelUsagePreference, SpanAttributes, TopLevelRoutingPreference}; 
use common::consts::REQUEST_ID_HEADER; use common::errors::BrightStaffError; use hermesllm::clients::SupportedAPIsFromClient; @@ -17,21 +17,31 @@ use crate::tracing::{collect_custom_trace_attributes, operation_component, set_s const ROUTING_POLICY_SIZE_WARNING_BYTES: usize = 5120; -/// Extracts `routing_policy` from a JSON body, returning the cleaned body bytes -/// and parsed preferences. The `routing_policy` field is removed from the JSON -/// before re-serializing so downstream parsers don't see the non-standard field. +type ExtractedRoutingPolicies = ( + Bytes, + Option>, + Option>, +); + +/// Extracts `routing_policy` and `routing_preferences` from a JSON body, returning +/// the cleaned body bytes and both sets of parsed preferences. Both fields are removed +/// from the JSON before re-serializing so downstream parsers don't see them. +/// +/// - `routing_policy` — legacy per-provider format (`Vec`) +/// - `routing_preferences` — v0.4.0+ format (`Vec`) /// /// If `warn_on_size` is true, logs a warning when the serialized policy exceeds 5KB. 
pub fn extract_routing_policy( raw_bytes: &[u8], warn_on_size: bool, -) -> Result<(Bytes, Option>), String> { +) -> Result { let mut json_body: serde_json::Value = serde_json::from_slice(raw_bytes) .map_err(|err| format!("Failed to parse JSON: {}", err))?; - let preferences = json_body + // Extract legacy routing_policy + let legacy_preferences = json_body .as_object_mut() - .and_then(|obj| obj.remove("routing_policy")) + .and_then(|o| o.remove("routing_policy")) .and_then(|policy_value| { if warn_on_size { let policy_str = serde_json::to_string(&policy_value).unwrap_or_default(); @@ -58,8 +68,28 @@ pub fn extract_routing_policy( } }); + // Extract new v0.4.0 routing_preferences + let top_level_preferences = json_body + .as_object_mut() + .and_then(|o| o.remove("routing_preferences")) + .and_then(|value| { + match serde_json::from_value::>(value) { + Ok(prefs) => { + info!( + num_routes = prefs.len(), + "using inline routing_preferences from request body" + ); + Some(prefs) + } + Err(err) => { + warn!(error = %err, "failed to parse routing_preferences"); + None + } + } + }); + let bytes = Bytes::from(serde_json::to_vec(&json_body).unwrap()); - Ok((bytes, preferences)) + Ok((bytes, legacy_preferences, top_level_preferences)) } #[derive(serde::Serialize)] @@ -136,18 +166,19 @@ async fn routing_decision_inner( "routing decision request body received" ); - // Extract routing_policy from request body before parsing as ProviderRequestType - let (chat_request_bytes, inline_preferences) = match extract_routing_policy(&raw_bytes, true) { - Ok(result) => result, - Err(err) => { - warn!(error = %err, "failed to parse request JSON"); - return Ok(BrightStaffError::InvalidRequest(format!( - "Failed to parse request JSON: {}", - err - )) - .into_response()); - } - }; + // Extract routing_policy and routing_preferences from body before parsing as ProviderRequestType + let (chat_request_bytes, inline_preferences, inline_routing_preferences) = + match 
extract_routing_policy(&raw_bytes, true) { + Ok(result) => result, + Err(err) => { + warn!(error = %err, "failed to parse request JSON"); + return Ok(BrightStaffError::InvalidRequest(format!( + "Failed to parse request JSON: {}", + err + )) + .into_response()); + } + }; let client_request = match ProviderRequestType::try_from(( &chat_request_bytes[..], @@ -172,6 +203,7 @@ async fn routing_decision_inner( &request_path, &request_id, inline_preferences, + inline_routing_preferences, ) .await; @@ -227,9 +259,10 @@ mod tests { #[test] fn extract_routing_policy_no_policy() { let body = make_chat_body(""); - let (cleaned, prefs) = extract_routing_policy(&body, false).unwrap(); + let (cleaned, prefs, top_prefs) = extract_routing_policy(&body, false).unwrap(); assert!(prefs.is_none()); + assert!(top_prefs.is_none()); let cleaned_json: serde_json::Value = serde_json::from_slice(&cleaned).unwrap(); assert_eq!(cleaned_json["model"], "gpt-4o-mini"); assert!(cleaned_json.get("routing_policy").is_none()); @@ -252,7 +285,7 @@ mod tests { } ]"#; let body = make_chat_body(policy); - let (cleaned, prefs) = extract_routing_policy(&body, false).unwrap(); + let (cleaned, prefs, top_prefs) = extract_routing_policy(&body, false).unwrap(); let prefs = prefs.expect("should have parsed preferences"); assert_eq!(prefs.len(), 2); @@ -260,6 +293,7 @@ mod tests { assert_eq!(prefs[0].routing_preferences[0].name, "coding"); assert_eq!(prefs[1].model, "openai/gpt-4o-mini"); assert_eq!(prefs[1].routing_preferences[0].name, "general"); + assert!(top_prefs.is_none()); // routing_policy should be stripped from cleaned body let cleaned_json: serde_json::Value = serde_json::from_slice(&cleaned).unwrap(); @@ -272,7 +306,7 @@ mod tests { // routing_policy is present but has wrong shape let policy = r#""routing_policy": "not-an-array""#; let body = make_chat_body(policy); - let (cleaned, prefs) = extract_routing_policy(&body, false).unwrap(); + let (cleaned, prefs, _) = extract_routing_policy(&body, 
false).unwrap(); // Invalid policy should be ignored (returns None), not error assert!(prefs.is_none()); @@ -293,7 +327,7 @@ mod tests { fn extract_routing_policy_empty_array() { let policy = r#""routing_policy": []"#; let body = make_chat_body(policy); - let (_, prefs) = extract_routing_policy(&body, false).unwrap(); + let (_, prefs, _) = extract_routing_policy(&body, false).unwrap(); let prefs = prefs.expect("empty array is valid"); assert_eq!(prefs.len(), 0); @@ -303,7 +337,7 @@ mod tests { fn extract_routing_policy_preserves_other_fields() { let policy = r#""routing_policy": [{"model": "gpt-4o", "routing_preferences": [{"name": "test", "description": "test"}]}], "temperature": 0.5, "max_tokens": 100"#; let body = make_chat_body(policy); - let (cleaned, prefs) = extract_routing_policy(&body, false).unwrap(); + let (cleaned, prefs, _) = extract_routing_policy(&body, false).unwrap(); assert!(prefs.is_some()); let cleaned_json: serde_json::Value = serde_json::from_slice(&cleaned).unwrap(); @@ -312,6 +346,29 @@ mod tests { assert!(cleaned_json.get("routing_policy").is_none()); } + #[test] + fn extract_routing_policy_top_level_routing_preferences() { + let policy = r#""routing_preferences": [ + { + "name": "code generation", + "description": "generate new code", + "models": ["openai/gpt-4o", "openai/gpt-4o-mini"], + "selection_policy": {"prefer": "fastest"} + } + ]"#; + let body = make_chat_body(policy); + let (cleaned, legacy_prefs, top_prefs) = extract_routing_policy(&body, false).unwrap(); + + assert!(legacy_prefs.is_none()); + let top_prefs = top_prefs.expect("should have parsed top-level routing_preferences"); + assert_eq!(top_prefs.len(), 1); + assert_eq!(top_prefs[0].name, "code generation"); + assert_eq!(top_prefs[0].models, vec!["openai/gpt-4o", "openai/gpt-4o-mini"]); + + let cleaned_json: serde_json::Value = serde_json::from_slice(&cleaned).unwrap(); + assert!(cleaned_json.get("routing_preferences").is_none()); + } + #[test] fn 
/// Parse a version string such as `v0.4.0`, `0.2.0` or `0.1-beta` into a
/// `(major, minor, patch)` tuple.
///
/// Leading `v` characters are ignored. The string is split on `.` into at most three
/// components, and each component contributes its leading run of ASCII digits. A
/// pre-release or build suffix (`1-beta`, `0-rc1`, `0+build`) therefore no longer
/// collapses the whole component to 0, which previously made `0.1-beta` parse as
/// `(0, 0, 0)` and `v0.4-beta` compare as older than `(0, 4, 0)`.
///
/// Missing components, components with no leading digits, and values that overflow
/// `u32` are treated as 0. The function never fails.
fn parse_semver(version: &str) -> (u32, u32, u32) {
    /// Numeric value of the leading ASCII-digit prefix of `part`, or 0 if there is none.
    fn leading_number(part: &str) -> u32 {
        // `find` returns the byte offset of the first non-digit char, which is always a
        // valid char boundary, so the slice below cannot panic.
        let digits_end = part
            .find(|c: char| !c.is_ascii_digit())
            .unwrap_or(part.len());
        part[..digits_end].parse().unwrap_or(0)
    }

    let mut parts = version
        .trim_start_matches('v')
        .splitn(3, '.')
        .map(leading_number);
    let major = parts.next().unwrap_or(0);
    let minor = parts.next().unwrap_or(0);
    let patch = parts.next().unwrap_or(0);
    (major, minor, patch)
}
+ let providers_with_per_provider_prefs: Vec<&str> = config + .model_providers + .iter() + .filter(|p| p.routing_preferences.is_some()) + .filter_map(|p| p.model.as_deref()) + .collect(); + if !providers_with_per_provider_prefs.is_empty() { + return Err(format!( + "routing_preferences inside model_providers is not allowed in v0.4.0+. \ + Use the top-level routing_preferences instead. \ + Offending models: {}", + providers_with_per_provider_prefs.join(", ") + ) + .into()); + } + } else if config.routing_preferences.is_some() { + return Err( + "top-level routing_preferences requires version v0.4.0 or above. \ + Update the version field or remove routing_preferences." + .into(), + ); + } + + // Validate that all models referenced in top-level routing_preferences exist in model_providers. + if let Some(ref route_prefs) = config.routing_preferences { + let provider_model_names: std::collections::HashSet<&str> = config + .model_providers + .iter() + .flat_map(|p| p.model.as_deref()) + .collect(); + for pref in route_prefs { + for model in &pref.models { + if !provider_model_names.contains(model.as_str()) { + return Err(format!( + "routing_preferences route '{}' references model '{}' \ + which is not declared in model_providers", + pref.name, model + ) + .into()); + } + } + } + } + + // Initialize ModelMetricsService if model_metrics_sources is configured. 
+ let metrics_service: Option> = + if let Some(ref sources) = config.model_metrics_sources { + let svc = ModelMetricsService::new(sources, reqwest::Client::new()).await; + Some(Arc::new(svc)) + } else { + None + }; + let router_service = Arc::new(RouterService::new( config.model_providers.clone(), + config.routing_preferences.clone(), + metrics_service, format!("{llm_provider_url}{CHAT_COMPLETIONS_PATH}"), routing_model_name, routing_llm_provider, diff --git a/crates/brightstaff/src/router/llm.rs b/crates/brightstaff/src/router/llm.rs index 7d27e80a..9abceb98 100644 --- a/crates/brightstaff/src/router/llm.rs +++ b/crates/brightstaff/src/router/llm.rs @@ -1,7 +1,9 @@ use std::{collections::HashMap, sync::Arc}; use common::{ - configuration::{LlmProvider, ModelUsagePreference, RoutingPreference}, + configuration::{ + LlmProvider, ModelUsagePreference, RoutingPreference, TopLevelRoutingPreference, + }, consts::{ARCH_PROVIDER_HINT_HEADER, REQUEST_ID_HEADER, TRACE_PARENT_HEADER}, }; use hermesllm::apis::openai::Message; @@ -10,6 +12,7 @@ use thiserror::Error; use tracing::{debug, info}; use super::http::{self, post_and_extract_content}; +use super::model_metrics::ModelMetricsService; use super::router_model::RouterModel; use crate::router::router_model_v1; @@ -20,6 +23,8 @@ pub struct RouterService { router_model: Arc, routing_provider_name: String, llm_usage_defined: bool, + top_level_preferences: HashMap, + metrics_service: Option>, } #[derive(Debug, Error)] @@ -36,25 +41,58 @@ pub type Result = std::result::Result; impl RouterService { pub fn new( providers: Vec, + top_level_prefs: Option>, + metrics_service: Option>, router_url: String, routing_model_name: String, routing_provider_name: String, ) -> Self { - let providers_with_usage = providers - .iter() - .filter(|provider| provider.routing_preferences.is_some()) - .cloned() - .collect::>(); + // Build top-level preference map and sentinel llm_routes when v0.4.0 format is used. 
+ let (top_level_preferences, llm_routes, llm_usage_defined) = + if let Some(top_prefs) = top_level_prefs { + let top_level_map: HashMap = top_prefs + .into_iter() + .map(|p| (p.name.clone(), p)) + .collect(); + // Build sentinel routes: route_name → first model (RouterModelV1 needs a model + // mapping, but RouterService overrides the selection via metrics_service). + let sentinel_routes: HashMap> = top_level_map + .iter() + .filter_map(|(name, pref)| { + pref.models.first().map(|first_model| { + ( + first_model.clone(), + vec![RoutingPreference { + name: name.clone(), + description: pref.description.clone(), + }], + ) + }) + }) + .collect(); + let defined = !top_level_map.is_empty(); + (top_level_map, sentinel_routes, defined) + } else { + // Legacy per-provider format. + let providers_with_usage = providers + .iter() + .filter(|provider| provider.routing_preferences.is_some()) + .cloned() + .collect::>(); - let llm_routes: HashMap> = providers_with_usage - .iter() - .filter_map(|provider| { - provider - .routing_preferences - .as_ref() - .map(|prefs| (provider.name.clone(), prefs.clone())) - }) - .collect(); + let routes: HashMap> = providers_with_usage + .iter() + .filter_map(|provider| { + provider + .routing_preferences + .as_ref() + .map(|prefs| (provider.name.clone(), prefs.clone())) + }) + .collect(); + + let defined = !providers_with_usage.is_empty(); + (HashMap::new(), routes, defined) + }; let router_model = Arc::new(router_model_v1::RouterModelV1::new( llm_routes, @@ -67,7 +105,9 @@ impl RouterService { client: reqwest::Client::new(), router_model, routing_provider_name, - llm_usage_defined: !providers_with_usage.is_empty(), + llm_usage_defined, + top_level_preferences, + metrics_service, } } @@ -76,23 +116,58 @@ impl RouterService { messages: &[Message], traceparent: &str, usage_preferences: Option>, + inline_routing_preferences: Option>, request_id: &str, ) -> Result> { if messages.is_empty() { return Ok(None); } + // Build inline top-level map from 
request if present (inline overrides config). + let inline_top_map: Option> = + inline_routing_preferences.map(|prefs| { + prefs.into_iter().map(|p| (p.name.clone(), p)).collect() + }); + + // Determine whether any routing is defined. + let has_top_level = inline_top_map.is_some() || !self.top_level_preferences.is_empty(); + if usage_preferences .as_ref() .is_none_or(|prefs| prefs.len() < 2) && !self.llm_usage_defined + && !has_top_level { return Ok(None); } + // For top-level format, build a synthetic ModelUsagePreference list so RouterModelV1 + // generates the correct prompt (route name + description pairs). + let effective_usage_preferences: Option> = + if let Some(ref inline_map) = inline_top_map { + Some( + inline_map + .values() + .map(|p| ModelUsagePreference { + model: p.models.first().cloned().unwrap_or_default(), + routing_preferences: vec![RoutingPreference { + name: p.name.clone(), + description: p.description.clone(), + }], + }) + .collect(), + ) + } else if !self.top_level_preferences.is_empty() { + // Config top-level prefs: already encoded as sentinel routes in RouterModelV1, + // pass None so it uses the pre-built llm_route_json_str. + None + } else { + usage_preferences.clone() + }; + let router_request = self .router_model - .generate_request(messages, &usage_preferences); + .generate_request(messages, &effective_usage_preferences); debug!( model = %self.router_model.get_model_name(), @@ -132,17 +207,40 @@ impl RouterService { return Ok(None); }; + // Parse the route name from the router response. let parsed = self .router_model - .parse_response(&content, &usage_preferences)?; + .parse_response(&content, &effective_usage_preferences)?; + + let result = if let Some((route_name, _sentinel_model)) = parsed { + // Check if this route belongs to the top-level preference format. 
+ let top_pref = inline_top_map + .as_ref() + .and_then(|m| m.get(&route_name)) + .or_else(|| self.top_level_preferences.get(&route_name)); + + if let Some(pref) = top_pref { + let selected_model = match &self.metrics_service { + Some(svc) => { + svc.select_model(&pref.models, &pref.selection_policy).await + } + None => pref.models.first().cloned().unwrap_or_default(), + }; + Some((route_name, selected_model)) + } else { + Some((route_name, _sentinel_model)) + } + } else { + None + }; info!( content = %content.replace("\n", "\\n"), - selected_model = ?parsed, + selected_model = ?result, response_time_ms = elapsed.as_millis(), "arch-router determined route" ); - Ok(parsed) + Ok(result) } } diff --git a/crates/brightstaff/src/router/mod.rs b/crates/brightstaff/src/router/mod.rs index b010d80c..2d9d00a7 100644 --- a/crates/brightstaff/src/router/mod.rs +++ b/crates/brightstaff/src/router/mod.rs @@ -1,5 +1,6 @@ pub(crate) mod http; pub mod llm; +pub mod model_metrics; pub mod orchestrator; pub mod orchestrator_model; pub mod orchestrator_model_v1; diff --git a/crates/brightstaff/src/router/model_metrics.rs b/crates/brightstaff/src/router/model_metrics.rs new file mode 100644 index 00000000..07544ef6 --- /dev/null +++ b/crates/brightstaff/src/router/model_metrics.rs @@ -0,0 +1,209 @@ +use std::collections::HashMap; +use std::sync::Arc; +use std::time::Duration; + +use common::configuration::{ModelMetricsSources, SelectionPolicy, SelectionPreference}; +use serde::Deserialize; +use tokio::sync::RwLock; +use tracing::{info, warn}; + +#[derive(Deserialize)] +struct MetricsResponse { + #[serde(default)] + cost: HashMap, + #[serde(default)] + latency: HashMap, +} + +pub struct ModelMetricsService { + cost: Arc>>, + latency: Arc>>, +} + +impl ModelMetricsService { + pub async fn new(sources: &ModelMetricsSources, client: reqwest::Client) -> Self { + let cost_data = Arc::new(RwLock::new(HashMap::new())); + let latency_data = Arc::new(RwLock::new(HashMap::new())); + + let 
metrics = fetch_metrics(&sources.url, &client).await; + info!( + cost_models = metrics.cost.len(), + latency_models = metrics.latency.len(), + url = %sources.url, + "fetched model metrics" + ); + *cost_data.write().await = metrics.cost; + *latency_data.write().await = metrics.latency; + + if let Some(interval_secs) = sources.refresh_interval { + let cost_clone = Arc::clone(&cost_data); + let latency_clone = Arc::clone(&latency_data); + let client_clone = client.clone(); + let url = sources.url.clone(); + tokio::spawn(async move { + let interval = Duration::from_secs(interval_secs); + loop { + tokio::time::sleep(interval).await; + let metrics = fetch_metrics(&url, &client_clone).await; + info!( + cost_models = metrics.cost.len(), + latency_models = metrics.latency.len(), + url = %url, + "refreshed model metrics" + ); + *cost_clone.write().await = metrics.cost; + *latency_clone.write().await = metrics.latency; + } + }); + } + + ModelMetricsService { + cost: cost_data, + latency: latency_data, + } + } + + /// Select the best model from `models` according to `policy`. + /// Falls back to `models[0]` if metric data is unavailable for all candidates. 
/// Returns the candidate from `models` whose entry in `data` has the lowest value.
///
/// Candidates with no entry in `data` are skipped. When several candidates share the
/// lowest value, the one that appears first in `models` is returned. If none of the
/// candidates has an entry, the first element of `models` is returned.
///
/// An empty `models` slice yields an empty string instead of panicking on `models[0]`.
/// This mirrors the `first().cloned().unwrap_or_default()` fallback the router uses
/// when no metrics service is configured.
fn select_by_ascending_metric(models: &[String], data: &HashMap<String, f64>) -> String {
    models
        .iter()
        .filter_map(|model| data.get(model.as_str()).map(|value| (model, *value)))
        // Incomparable values (NaN) are treated as equal, as before.
        .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
        .map(|(model, _)| model.clone())
        // No metric data for any candidate: fall back to the first one, if any.
        .or_else(|| models.first().cloned())
        .unwrap_or_default()
}
data.insert("b".to_string(), 0.005); + data.insert("c".to_string(), 0.02); + assert_eq!(select_by_ascending_metric(&models, &data), "b"); + } + + #[test] + fn test_select_by_ascending_metric_fallback_to_first() { + let models = vec!["x".to_string(), "y".to_string()]; + let data = HashMap::new(); + assert_eq!(select_by_ascending_metric(&models, &data), "x"); + } + + #[test] + fn test_select_by_ascending_metric_partial_data() { + let models = vec!["a".to_string(), "b".to_string()]; + let mut data = HashMap::new(); + data.insert("b".to_string(), 100.0); + assert_eq!(select_by_ascending_metric(&models, &data), "b"); + } + + #[tokio::test] + async fn test_select_model_cheapest() { + let service = ModelMetricsService { + cost: Arc::new(RwLock::new({ + let mut m = HashMap::new(); + m.insert("gpt-4o".to_string(), 0.005); + m.insert("gpt-4o-mini".to_string(), 0.0001); + m + })), + latency: Arc::new(RwLock::new(HashMap::new())), + }; + let models = vec!["gpt-4o".to_string(), "gpt-4o-mini".to_string()]; + let result = service + .select_model(&models, &make_policy(SelectionPreference::Cheapest)) + .await; + assert_eq!(result, "gpt-4o-mini"); + } + + #[tokio::test] + async fn test_select_model_fastest() { + let service = ModelMetricsService { + cost: Arc::new(RwLock::new(HashMap::new())), + latency: Arc::new(RwLock::new({ + let mut m = HashMap::new(); + m.insert("gpt-4o".to_string(), 200.0); + m.insert("claude-sonnet".to_string(), 120.0); + m + })), + }; + let models = vec!["gpt-4o".to_string(), "claude-sonnet".to_string()]; + let result = service + .select_model(&models, &make_policy(SelectionPreference::Fastest)) + .await; + assert_eq!(result, "claude-sonnet"); + } + + #[tokio::test] + async fn test_select_model_fallback_no_metrics() { + let service = ModelMetricsService { + cost: Arc::new(RwLock::new(HashMap::new())), + latency: Arc::new(RwLock::new(HashMap::new())), + }; + let models = vec!["model-a".to_string(), "model-b".to_string()]; + let result = service + 
.select_model(&models, &make_policy(SelectionPreference::Cheapest)) + .await; + assert_eq!(result, "model-a"); + } +} diff --git a/crates/common/src/configuration.rs b/crates/common/src/configuration.rs index df179059..9581b0a7 100644 --- a/crates/common/src/configuration.rs +++ b/crates/common/src/configuration.rs @@ -104,6 +104,33 @@ pub enum StateStorageType { Postgres, } +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "lowercase")] +pub enum SelectionPreference { + Cheapest, + Fastest, + Random, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SelectionPolicy { + pub prefer: SelectionPreference, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TopLevelRoutingPreference { + pub name: String, + pub description: String, + pub models: Vec, + pub selection_policy: SelectionPolicy, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ModelMetricsSources { + pub url: String, + pub refresh_interval: Option, +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Configuration { pub version: String, @@ -122,6 +149,8 @@ pub struct Configuration { pub filters: Option>, pub listeners: Vec, pub state_storage: Option, + pub routing_preferences: Option>, + pub model_metrics_sources: Option, } #[derive(Debug, Clone, Serialize, Deserialize, Default)] @@ -317,7 +346,7 @@ impl LlmProviderType { } } -#[derive(Serialize, Deserialize, Debug)] +#[derive(Clone, Serialize, Deserialize, Debug)] pub struct ModelUsagePreference { pub model: String, pub routing_preferences: Vec,