add top-level routing_preferences with selection_policy and model metrics fetch

This commit is contained in:
Adil Hafeez 2026-03-26 17:35:39 -07:00
parent 406fa92802
commit 2ef938ac5f
9 changed files with 568 additions and 49 deletions

View file

@ -9,6 +9,7 @@ properties:
- 0.1-beta
- 0.2.0
- v0.3.0
- v0.4.0
agents:
type: array
@ -470,6 +471,51 @@ properties:
additionalProperties: false
required:
- jailbreak
routing_preferences:
type: array
items:
type: object
properties:
name:
type: string
description:
type: string
models:
type: array
items:
type: string
minItems: 1
selection_policy:
type: object
properties:
prefer:
type: string
enum:
- cheapest
- fastest
- random
additionalProperties: false
required:
- prefer
additionalProperties: false
required:
- name
- description
- models
- selection_policy
model_metrics_sources:
type: object
properties:
url:
type: string
refresh_interval:
type: integer
minimum: 1
additionalProperties: false
required:
- url
additionalProperties: false
required:
- version

View file

@ -120,6 +120,7 @@ async fn llm_chat_inner(
tool_names,
user_message_preview,
inline_routing_policy,
inline_routing_preferences,
client_api,
provider_id,
} = parsed;
@ -262,6 +263,7 @@ async fn llm_chat_inner(
&request_path,
&request_id,
inline_routing_policy,
inline_routing_preferences,
)
.await
}
@ -324,6 +326,7 @@ struct PreparedRequest {
tool_names: Option<Vec<String>>,
user_message_preview: Option<String>,
inline_routing_policy: Option<Vec<common::configuration::ModelUsagePreference>>,
inline_routing_preferences: Option<Vec<common::configuration::TopLevelRoutingPreference>>,
client_api: Option<SupportedAPIsFromClient>,
provider_id: hermesllm::ProviderId,
}
@ -352,8 +355,8 @@ async fn parse_and_validate_request(
"request body received"
);
// Extract routing_policy from request body if present
let (chat_request_bytes, inline_routing_policy) =
// Extract routing_policy and routing_preferences from request body if present
let (chat_request_bytes, inline_routing_policy, inline_routing_preferences) =
crate::handlers::routing_service::extract_routing_policy(&raw_bytes, false).map_err(
|err| {
warn!(error = %err, "failed to parse request JSON");
@ -440,6 +443,7 @@ async fn parse_and_validate_request(
tool_names,
user_message_preview,
inline_routing_policy,
inline_routing_preferences,
client_api,
provider_id,
})

View file

@ -1,4 +1,4 @@
use common::configuration::ModelUsagePreference;
use common::configuration::{ModelUsagePreference, TopLevelRoutingPreference};
use hermesllm::clients::endpoints::SupportedUpstreamAPIs;
use hermesllm::{ProviderRequest, ProviderRequestType};
use hyper::StatusCode;
@ -40,6 +40,7 @@ pub async fn router_chat_get_upstream_model(
request_path: &str,
request_id: &str,
inline_usage_preferences: Option<Vec<ModelUsagePreference>>,
inline_routing_preferences: Option<Vec<TopLevelRoutingPreference>>,
) -> Result<RoutingResult, RoutingError> {
// Clone metadata for routing before converting (which consumes client_request)
let routing_metadata = client_request.metadata().clone();
@ -122,6 +123,7 @@ pub async fn router_chat_get_upstream_model(
&chat_request.messages,
traceparent,
usage_preferences,
inline_routing_preferences,
request_id,
)
.await;

View file

@ -1,5 +1,5 @@
use bytes::Bytes;
use common::configuration::{ModelUsagePreference, SpanAttributes};
use common::configuration::{ModelUsagePreference, SpanAttributes, TopLevelRoutingPreference};
use common::consts::REQUEST_ID_HEADER;
use common::errors::BrightStaffError;
use hermesllm::clients::SupportedAPIsFromClient;
@ -17,21 +17,31 @@ use crate::tracing::{collect_custom_trace_attributes, operation_component, set_s
const ROUTING_POLICY_SIZE_WARNING_BYTES: usize = 5120;
/// Extracts `routing_policy` from a JSON body, returning the cleaned body bytes
/// and parsed preferences. The `routing_policy` field is removed from the JSON
/// before re-serializing so downstream parsers don't see the non-standard field.
type ExtractedRoutingPolicies = (
Bytes,
Option<Vec<ModelUsagePreference>>,
Option<Vec<TopLevelRoutingPreference>>,
);
/// Extracts `routing_policy` and `routing_preferences` from a JSON body, returning
/// the cleaned body bytes and both sets of parsed preferences. Both fields are removed
/// from the JSON before re-serializing so downstream parsers don't see them.
///
/// - `routing_policy` — legacy per-provider format (`Vec<ModelUsagePreference>`)
/// - `routing_preferences` — v0.4.0+ format (`Vec<TopLevelRoutingPreference>`)
///
/// If `warn_on_size` is true, logs a warning when the serialized policy exceeds 5KB.
pub fn extract_routing_policy(
raw_bytes: &[u8],
warn_on_size: bool,
) -> Result<(Bytes, Option<Vec<ModelUsagePreference>>), String> {
) -> Result<ExtractedRoutingPolicies, String> {
let mut json_body: serde_json::Value = serde_json::from_slice(raw_bytes)
.map_err(|err| format!("Failed to parse JSON: {}", err))?;
let preferences = json_body
// Extract legacy routing_policy
let legacy_preferences = json_body
.as_object_mut()
.and_then(|obj| obj.remove("routing_policy"))
.and_then(|o| o.remove("routing_policy"))
.and_then(|policy_value| {
if warn_on_size {
let policy_str = serde_json::to_string(&policy_value).unwrap_or_default();
@ -58,8 +68,28 @@ pub fn extract_routing_policy(
}
});
// Extract new v0.4.0 routing_preferences
let top_level_preferences = json_body
.as_object_mut()
.and_then(|o| o.remove("routing_preferences"))
.and_then(|value| {
match serde_json::from_value::<Vec<TopLevelRoutingPreference>>(value) {
Ok(prefs) => {
info!(
num_routes = prefs.len(),
"using inline routing_preferences from request body"
);
Some(prefs)
}
Err(err) => {
warn!(error = %err, "failed to parse routing_preferences");
None
}
}
});
let bytes = Bytes::from(serde_json::to_vec(&json_body).unwrap());
Ok((bytes, preferences))
Ok((bytes, legacy_preferences, top_level_preferences))
}
#[derive(serde::Serialize)]
@ -136,18 +166,19 @@ async fn routing_decision_inner(
"routing decision request body received"
);
// Extract routing_policy from request body before parsing as ProviderRequestType
let (chat_request_bytes, inline_preferences) = match extract_routing_policy(&raw_bytes, true) {
Ok(result) => result,
Err(err) => {
warn!(error = %err, "failed to parse request JSON");
return Ok(BrightStaffError::InvalidRequest(format!(
"Failed to parse request JSON: {}",
err
))
.into_response());
}
};
// Extract routing_policy and routing_preferences from body before parsing as ProviderRequestType
let (chat_request_bytes, inline_preferences, inline_routing_preferences) =
match extract_routing_policy(&raw_bytes, true) {
Ok(result) => result,
Err(err) => {
warn!(error = %err, "failed to parse request JSON");
return Ok(BrightStaffError::InvalidRequest(format!(
"Failed to parse request JSON: {}",
err
))
.into_response());
}
};
let client_request = match ProviderRequestType::try_from((
&chat_request_bytes[..],
@ -172,6 +203,7 @@ async fn routing_decision_inner(
&request_path,
&request_id,
inline_preferences,
inline_routing_preferences,
)
.await;
@ -227,9 +259,10 @@ mod tests {
#[test]
fn extract_routing_policy_no_policy() {
let body = make_chat_body("");
let (cleaned, prefs) = extract_routing_policy(&body, false).unwrap();
let (cleaned, prefs, top_prefs) = extract_routing_policy(&body, false).unwrap();
assert!(prefs.is_none());
assert!(top_prefs.is_none());
let cleaned_json: serde_json::Value = serde_json::from_slice(&cleaned).unwrap();
assert_eq!(cleaned_json["model"], "gpt-4o-mini");
assert!(cleaned_json.get("routing_policy").is_none());
@ -252,7 +285,7 @@ mod tests {
}
]"#;
let body = make_chat_body(policy);
let (cleaned, prefs) = extract_routing_policy(&body, false).unwrap();
let (cleaned, prefs, top_prefs) = extract_routing_policy(&body, false).unwrap();
let prefs = prefs.expect("should have parsed preferences");
assert_eq!(prefs.len(), 2);
@ -260,6 +293,7 @@ mod tests {
assert_eq!(prefs[0].routing_preferences[0].name, "coding");
assert_eq!(prefs[1].model, "openai/gpt-4o-mini");
assert_eq!(prefs[1].routing_preferences[0].name, "general");
assert!(top_prefs.is_none());
// routing_policy should be stripped from cleaned body
let cleaned_json: serde_json::Value = serde_json::from_slice(&cleaned).unwrap();
@ -272,7 +306,7 @@ mod tests {
// routing_policy is present but has wrong shape
let policy = r#""routing_policy": "not-an-array""#;
let body = make_chat_body(policy);
let (cleaned, prefs) = extract_routing_policy(&body, false).unwrap();
let (cleaned, prefs, _) = extract_routing_policy(&body, false).unwrap();
// Invalid policy should be ignored (returns None), not error
assert!(prefs.is_none());
@ -293,7 +327,7 @@ mod tests {
fn extract_routing_policy_empty_array() {
let policy = r#""routing_policy": []"#;
let body = make_chat_body(policy);
let (_, prefs) = extract_routing_policy(&body, false).unwrap();
let (_, prefs, _) = extract_routing_policy(&body, false).unwrap();
let prefs = prefs.expect("empty array is valid");
assert_eq!(prefs.len(), 0);
@ -303,7 +337,7 @@ mod tests {
fn extract_routing_policy_preserves_other_fields() {
let policy = r#""routing_policy": [{"model": "gpt-4o", "routing_preferences": [{"name": "test", "description": "test"}]}], "temperature": 0.5, "max_tokens": 100"#;
let body = make_chat_body(policy);
let (cleaned, prefs) = extract_routing_policy(&body, false).unwrap();
let (cleaned, prefs, _) = extract_routing_policy(&body, false).unwrap();
assert!(prefs.is_some());
let cleaned_json: serde_json::Value = serde_json::from_slice(&cleaned).unwrap();
@ -312,6 +346,29 @@ mod tests {
assert!(cleaned_json.get("routing_policy").is_none());
}
#[test]
fn extract_routing_policy_top_level_routing_preferences() {
let policy = r#""routing_preferences": [
{
"name": "code generation",
"description": "generate new code",
"models": ["openai/gpt-4o", "openai/gpt-4o-mini"],
"selection_policy": {"prefer": "fastest"}
}
]"#;
let body = make_chat_body(policy);
let (cleaned, legacy_prefs, top_prefs) = extract_routing_policy(&body, false).unwrap();
assert!(legacy_prefs.is_none());
let top_prefs = top_prefs.expect("should have parsed top-level routing_preferences");
assert_eq!(top_prefs.len(), 1);
assert_eq!(top_prefs[0].name, "code generation");
assert_eq!(top_prefs[0].models, vec!["openai/gpt-4o", "openai/gpt-4o-mini"]);
let cleaned_json: serde_json::Value = serde_json::from_slice(&cleaned).unwrap();
assert!(cleaned_json.get("routing_preferences").is_none());
}
#[test]
fn routing_decision_response_serialization() {
let response = RoutingDecisionResponse {

View file

@ -6,6 +6,7 @@ use brightstaff::handlers::llm::llm_chat;
use brightstaff::handlers::models::list_models;
use brightstaff::handlers::routing_service::routing_decision;
use brightstaff::router::llm::RouterService;
use brightstaff::router::model_metrics::ModelMetricsService;
use brightstaff::router::orchestrator::OrchestratorService;
use brightstaff::state::memory::MemoryConversationalStorage;
use brightstaff::state::postgresql::PostgreSQLConversationStorage;
@ -40,6 +41,17 @@ const DEFAULT_ROUTING_MODEL_NAME: &str = "Arch-Router";
const DEFAULT_ORCHESTRATOR_LLM_PROVIDER: &str = "plano-orchestrator";
const DEFAULT_ORCHESTRATOR_MODEL_NAME: &str = "Plano-Orchestrator";
/// Parse a version string like `v0.4.0`, `v0.3.0`, `0.2.0` into a `(major, minor, patch)` tuple.
/// Missing parts default to 0. Non-numeric parts are treated as 0.
fn parse_semver(version: &str) -> (u32, u32, u32) {
let v = version.trim_start_matches('v');
let mut parts = v.splitn(3, '.').map(|p| p.parse::<u32>().unwrap_or(0));
let major = parts.next().unwrap_or(0);
let minor = parts.next().unwrap_or(0);
let patch = parts.next().unwrap_or(0);
(major, minor, patch)
}
/// CORS pre-flight response for the models endpoint.
fn cors_preflight() -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
let mut response = Response::new(empty());
@ -162,8 +174,69 @@ async fn init_app_state(
.map(|p| p.name.clone())
.unwrap_or_else(|| DEFAULT_ROUTING_LLM_PROVIDER.to_string());
// Validate version-gated routing_preferences rules.
let config_version = parse_semver(&config.version);
let is_v040_plus = config_version >= (0, 4, 0);
if is_v040_plus {
// v0.4.0+: per-provider routing_preferences are forbidden.
let providers_with_per_provider_prefs: Vec<&str> = config
.model_providers
.iter()
.filter(|p| p.routing_preferences.is_some())
.filter_map(|p| p.model.as_deref())
.collect();
if !providers_with_per_provider_prefs.is_empty() {
return Err(format!(
"routing_preferences inside model_providers is not allowed in v0.4.0+. \
Use the top-level routing_preferences instead. \
Offending models: {}",
providers_with_per_provider_prefs.join(", ")
)
.into());
}
} else if config.routing_preferences.is_some() {
return Err(
"top-level routing_preferences requires version v0.4.0 or above. \
Update the version field or remove routing_preferences."
.into(),
);
}
// Validate that all models referenced in top-level routing_preferences exist in model_providers.
if let Some(ref route_prefs) = config.routing_preferences {
let provider_model_names: std::collections::HashSet<&str> = config
.model_providers
.iter()
.flat_map(|p| p.model.as_deref())
.collect();
for pref in route_prefs {
for model in &pref.models {
if !provider_model_names.contains(model.as_str()) {
return Err(format!(
"routing_preferences route '{}' references model '{}' \
which is not declared in model_providers",
pref.name, model
)
.into());
}
}
}
}
// Initialize ModelMetricsService if model_metrics_sources is configured.
let metrics_service: Option<Arc<ModelMetricsService>> =
if let Some(ref sources) = config.model_metrics_sources {
let svc = ModelMetricsService::new(sources, reqwest::Client::new()).await;
Some(Arc::new(svc))
} else {
None
};
let router_service = Arc::new(RouterService::new(
config.model_providers.clone(),
config.routing_preferences.clone(),
metrics_service,
format!("{llm_provider_url}{CHAT_COMPLETIONS_PATH}"),
routing_model_name,
routing_llm_provider,

View file

@ -1,7 +1,9 @@
use std::{collections::HashMap, sync::Arc};
use common::{
configuration::{LlmProvider, ModelUsagePreference, RoutingPreference},
configuration::{
LlmProvider, ModelUsagePreference, RoutingPreference, TopLevelRoutingPreference,
},
consts::{ARCH_PROVIDER_HINT_HEADER, REQUEST_ID_HEADER, TRACE_PARENT_HEADER},
};
use hermesllm::apis::openai::Message;
@ -10,6 +12,7 @@ use thiserror::Error;
use tracing::{debug, info};
use super::http::{self, post_and_extract_content};
use super::model_metrics::ModelMetricsService;
use super::router_model::RouterModel;
use crate::router::router_model_v1;
@ -20,6 +23,8 @@ pub struct RouterService {
router_model: Arc<dyn RouterModel>,
routing_provider_name: String,
llm_usage_defined: bool,
top_level_preferences: HashMap<String, TopLevelRoutingPreference>,
metrics_service: Option<Arc<ModelMetricsService>>,
}
#[derive(Debug, Error)]
@ -36,25 +41,58 @@ pub type Result<T> = std::result::Result<T, RoutingError>;
impl RouterService {
pub fn new(
providers: Vec<LlmProvider>,
top_level_prefs: Option<Vec<TopLevelRoutingPreference>>,
metrics_service: Option<Arc<ModelMetricsService>>,
router_url: String,
routing_model_name: String,
routing_provider_name: String,
) -> Self {
let providers_with_usage = providers
.iter()
.filter(|provider| provider.routing_preferences.is_some())
.cloned()
.collect::<Vec<LlmProvider>>();
// Build top-level preference map and sentinel llm_routes when v0.4.0 format is used.
let (top_level_preferences, llm_routes, llm_usage_defined) =
if let Some(top_prefs) = top_level_prefs {
let top_level_map: HashMap<String, TopLevelRoutingPreference> = top_prefs
.into_iter()
.map(|p| (p.name.clone(), p))
.collect();
// Build sentinel routes: route_name → first model (RouterModelV1 needs a model
// mapping, but RouterService overrides the selection via metrics_service).
let sentinel_routes: HashMap<String, Vec<RoutingPreference>> = top_level_map
.iter()
.filter_map(|(name, pref)| {
pref.models.first().map(|first_model| {
(
first_model.clone(),
vec![RoutingPreference {
name: name.clone(),
description: pref.description.clone(),
}],
)
})
})
.collect();
let defined = !top_level_map.is_empty();
(top_level_map, sentinel_routes, defined)
} else {
// Legacy per-provider format.
let providers_with_usage = providers
.iter()
.filter(|provider| provider.routing_preferences.is_some())
.cloned()
.collect::<Vec<LlmProvider>>();
let llm_routes: HashMap<String, Vec<RoutingPreference>> = providers_with_usage
.iter()
.filter_map(|provider| {
provider
.routing_preferences
.as_ref()
.map(|prefs| (provider.name.clone(), prefs.clone()))
})
.collect();
let routes: HashMap<String, Vec<RoutingPreference>> = providers_with_usage
.iter()
.filter_map(|provider| {
provider
.routing_preferences
.as_ref()
.map(|prefs| (provider.name.clone(), prefs.clone()))
})
.collect();
let defined = !providers_with_usage.is_empty();
(HashMap::new(), routes, defined)
};
let router_model = Arc::new(router_model_v1::RouterModelV1::new(
llm_routes,
@ -67,7 +105,9 @@ impl RouterService {
client: reqwest::Client::new(),
router_model,
routing_provider_name,
llm_usage_defined: !providers_with_usage.is_empty(),
llm_usage_defined,
top_level_preferences,
metrics_service,
}
}
@ -76,23 +116,58 @@ impl RouterService {
messages: &[Message],
traceparent: &str,
usage_preferences: Option<Vec<ModelUsagePreference>>,
inline_routing_preferences: Option<Vec<TopLevelRoutingPreference>>,
request_id: &str,
) -> Result<Option<(String, String)>> {
if messages.is_empty() {
return Ok(None);
}
// Build inline top-level map from request if present (inline overrides config).
let inline_top_map: Option<HashMap<String, TopLevelRoutingPreference>> =
inline_routing_preferences.map(|prefs| {
prefs.into_iter().map(|p| (p.name.clone(), p)).collect()
});
// Determine whether any routing is defined.
let has_top_level = inline_top_map.is_some() || !self.top_level_preferences.is_empty();
if usage_preferences
.as_ref()
.is_none_or(|prefs| prefs.len() < 2)
&& !self.llm_usage_defined
&& !has_top_level
{
return Ok(None);
}
// For top-level format, build a synthetic ModelUsagePreference list so RouterModelV1
// generates the correct prompt (route name + description pairs).
let effective_usage_preferences: Option<Vec<ModelUsagePreference>> =
if let Some(ref inline_map) = inline_top_map {
Some(
inline_map
.values()
.map(|p| ModelUsagePreference {
model: p.models.first().cloned().unwrap_or_default(),
routing_preferences: vec![RoutingPreference {
name: p.name.clone(),
description: p.description.clone(),
}],
})
.collect(),
)
} else if !self.top_level_preferences.is_empty() {
// Config top-level prefs: already encoded as sentinel routes in RouterModelV1,
// pass None so it uses the pre-built llm_route_json_str.
None
} else {
usage_preferences.clone()
};
let router_request = self
.router_model
.generate_request(messages, &usage_preferences);
.generate_request(messages, &effective_usage_preferences);
debug!(
model = %self.router_model.get_model_name(),
@ -132,17 +207,40 @@ impl RouterService {
return Ok(None);
};
// Parse the route name from the router response.
let parsed = self
.router_model
.parse_response(&content, &usage_preferences)?;
.parse_response(&content, &effective_usage_preferences)?;
let result = if let Some((route_name, _sentinel_model)) = parsed {
// Check if this route belongs to the top-level preference format.
let top_pref = inline_top_map
.as_ref()
.and_then(|m| m.get(&route_name))
.or_else(|| self.top_level_preferences.get(&route_name));
if let Some(pref) = top_pref {
let selected_model = match &self.metrics_service {
Some(svc) => {
svc.select_model(&pref.models, &pref.selection_policy).await
}
None => pref.models.first().cloned().unwrap_or_default(),
};
Some((route_name, selected_model))
} else {
Some((route_name, _sentinel_model))
}
} else {
None
};
info!(
content = %content.replace("\n", "\\n"),
selected_model = ?parsed,
selected_model = ?result,
response_time_ms = elapsed.as_millis(),
"arch-router determined route"
);
Ok(parsed)
Ok(result)
}
}

View file

@ -1,5 +1,6 @@
pub(crate) mod http;
pub mod llm;
pub mod model_metrics;
pub mod orchestrator;
pub mod orchestrator_model;
pub mod orchestrator_model_v1;

View file

@ -0,0 +1,209 @@
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use common::configuration::{ModelMetricsSources, SelectionPolicy, SelectionPreference};
use serde::Deserialize;
use tokio::sync::RwLock;
use tracing::{info, warn};
#[derive(Deserialize)]
struct MetricsResponse {
#[serde(default)]
cost: HashMap<String, f64>,
#[serde(default)]
latency: HashMap<String, f64>,
}
pub struct ModelMetricsService {
cost: Arc<RwLock<HashMap<String, f64>>>,
latency: Arc<RwLock<HashMap<String, f64>>>,
}
impl ModelMetricsService {
pub async fn new(sources: &ModelMetricsSources, client: reqwest::Client) -> Self {
let cost_data = Arc::new(RwLock::new(HashMap::new()));
let latency_data = Arc::new(RwLock::new(HashMap::new()));
let metrics = fetch_metrics(&sources.url, &client).await;
info!(
cost_models = metrics.cost.len(),
latency_models = metrics.latency.len(),
url = %sources.url,
"fetched model metrics"
);
*cost_data.write().await = metrics.cost;
*latency_data.write().await = metrics.latency;
if let Some(interval_secs) = sources.refresh_interval {
let cost_clone = Arc::clone(&cost_data);
let latency_clone = Arc::clone(&latency_data);
let client_clone = client.clone();
let url = sources.url.clone();
tokio::spawn(async move {
let interval = Duration::from_secs(interval_secs);
loop {
tokio::time::sleep(interval).await;
let metrics = fetch_metrics(&url, &client_clone).await;
info!(
cost_models = metrics.cost.len(),
latency_models = metrics.latency.len(),
url = %url,
"refreshed model metrics"
);
*cost_clone.write().await = metrics.cost;
*latency_clone.write().await = metrics.latency;
}
});
}
ModelMetricsService {
cost: cost_data,
latency: latency_data,
}
}
/// Select the best model from `models` according to `policy`.
/// Falls back to `models[0]` if metric data is unavailable for all candidates.
pub async fn select_model(&self, models: &[String], policy: &SelectionPolicy) -> String {
match policy.prefer {
SelectionPreference::Cheapest => {
let data = self.cost.read().await;
select_by_ascending_metric(models, &data)
}
SelectionPreference::Fastest => {
let data = self.latency.read().await;
select_by_ascending_metric(models, &data)
}
SelectionPreference::Random => {
let idx = rand_index(models.len());
models[idx].clone()
}
}
}
}
fn select_by_ascending_metric(models: &[String], data: &HashMap<String, f64>) -> String {
models
.iter()
.filter_map(|m| data.get(m.as_str()).map(|v| (m, *v)))
.min_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
.map(|(m, _)| m.clone())
.unwrap_or_else(|| models[0].clone())
}
/// Simple non-crypto random index using system time nanoseconds.
fn rand_index(len: usize) -> usize {
use std::time::{SystemTime, UNIX_EPOCH};
let nanos = SystemTime::now()
.duration_since(UNIX_EPOCH)
.map(|d| d.subsec_nanos() as usize)
.unwrap_or(0);
nanos % len
}
async fn fetch_metrics(url: &str, client: &reqwest::Client) -> MetricsResponse {
match client.get(url).send().await {
Ok(resp) => match resp.json::<MetricsResponse>().await {
Ok(data) => data,
Err(err) => {
warn!(error = %err, url = %url, "failed to parse metrics response");
MetricsResponse {
cost: HashMap::new(),
latency: HashMap::new(),
}
}
},
Err(err) => {
warn!(error = %err, url = %url, "failed to fetch metrics");
MetricsResponse {
cost: HashMap::new(),
latency: HashMap::new(),
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use common::configuration::SelectionPreference;
fn make_policy(prefer: SelectionPreference) -> SelectionPolicy {
SelectionPolicy { prefer }
}
#[test]
fn test_select_by_ascending_metric_picks_lowest() {
let models = vec!["a".to_string(), "b".to_string(), "c".to_string()];
let mut data = HashMap::new();
data.insert("a".to_string(), 0.01);
data.insert("b".to_string(), 0.005);
data.insert("c".to_string(), 0.02);
assert_eq!(select_by_ascending_metric(&models, &data), "b");
}
#[test]
fn test_select_by_ascending_metric_fallback_to_first() {
let models = vec!["x".to_string(), "y".to_string()];
let data = HashMap::new();
assert_eq!(select_by_ascending_metric(&models, &data), "x");
}
#[test]
fn test_select_by_ascending_metric_partial_data() {
let models = vec!["a".to_string(), "b".to_string()];
let mut data = HashMap::new();
data.insert("b".to_string(), 100.0);
assert_eq!(select_by_ascending_metric(&models, &data), "b");
}
#[tokio::test]
async fn test_select_model_cheapest() {
let service = ModelMetricsService {
cost: Arc::new(RwLock::new({
let mut m = HashMap::new();
m.insert("gpt-4o".to_string(), 0.005);
m.insert("gpt-4o-mini".to_string(), 0.0001);
m
})),
latency: Arc::new(RwLock::new(HashMap::new())),
};
let models = vec!["gpt-4o".to_string(), "gpt-4o-mini".to_string()];
let result = service
.select_model(&models, &make_policy(SelectionPreference::Cheapest))
.await;
assert_eq!(result, "gpt-4o-mini");
}
#[tokio::test]
async fn test_select_model_fastest() {
let service = ModelMetricsService {
cost: Arc::new(RwLock::new(HashMap::new())),
latency: Arc::new(RwLock::new({
let mut m = HashMap::new();
m.insert("gpt-4o".to_string(), 200.0);
m.insert("claude-sonnet".to_string(), 120.0);
m
})),
};
let models = vec!["gpt-4o".to_string(), "claude-sonnet".to_string()];
let result = service
.select_model(&models, &make_policy(SelectionPreference::Fastest))
.await;
assert_eq!(result, "claude-sonnet");
}
#[tokio::test]
async fn test_select_model_fallback_no_metrics() {
let service = ModelMetricsService {
cost: Arc::new(RwLock::new(HashMap::new())),
latency: Arc::new(RwLock::new(HashMap::new())),
};
let models = vec!["model-a".to_string(), "model-b".to_string()];
let result = service
.select_model(&models, &make_policy(SelectionPreference::Cheapest))
.await;
assert_eq!(result, "model-a");
}
}

View file

@ -104,6 +104,33 @@ pub enum StateStorageType {
Postgres,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum SelectionPreference {
Cheapest,
Fastest,
Random,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SelectionPolicy {
pub prefer: SelectionPreference,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TopLevelRoutingPreference {
pub name: String,
pub description: String,
pub models: Vec<String>,
pub selection_policy: SelectionPolicy,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelMetricsSources {
pub url: String,
pub refresh_interval: Option<u64>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Configuration {
pub version: String,
@ -122,6 +149,8 @@ pub struct Configuration {
pub filters: Option<Vec<Agent>>,
pub listeners: Vec<Listener>,
pub state_storage: Option<StateStorageConfig>,
pub routing_preferences: Option<Vec<TopLevelRoutingPreference>>,
pub model_metrics_sources: Option<ModelMetricsSources>,
}
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
@ -317,7 +346,7 @@ impl LlmProviderType {
}
}
#[derive(Serialize, Deserialize, Debug)]
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct ModelUsagePreference {
pub model: String,
pub routing_preferences: Vec<RoutingPreference>,