model routing: cost/latency ranking with ranked fallback list (#849)

This commit is contained in:
Adil Hafeez 2026-03-30 13:46:52 -07:00 committed by GitHub
parent 3a531ce22a
commit e5751d6b13
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
23 changed files with 1524 additions and 317 deletions

View file

@ -1,15 +1,18 @@
use std::{collections::HashMap, sync::Arc};
use common::{
configuration::{LlmProvider, ModelUsagePreference, RoutingPreference},
configuration::TopLevelRoutingPreference,
consts::{ARCH_PROVIDER_HINT_HEADER, REQUEST_ID_HEADER, TRACE_PARENT_HEADER},
};
use super::router_model::{ModelUsagePreference, RoutingPreference};
use hermesllm::apis::openai::Message;
use hyper::header;
use thiserror::Error;
use tracing::{debug, info};
use super::http::{self, post_and_extract_content};
use super::model_metrics::ModelMetricsService;
use super::router_model::RouterModel;
use crate::router::router_model_v1;
@ -19,7 +22,8 @@ pub struct RouterService {
client: reqwest::Client,
router_model: Arc<dyn RouterModel>,
routing_provider_name: String,
llm_usage_defined: bool,
top_level_preferences: HashMap<String, TopLevelRoutingPreference>,
metrics_service: Option<Arc<ModelMetricsService>>,
}
#[derive(Debug, Error)]
@ -35,29 +39,37 @@ pub type Result<T> = std::result::Result<T, RoutingError>;
impl RouterService {
pub fn new(
providers: Vec<LlmProvider>,
top_level_prefs: Option<Vec<TopLevelRoutingPreference>>,
metrics_service: Option<Arc<ModelMetricsService>>,
router_url: String,
routing_model_name: String,
routing_provider_name: String,
) -> Self {
let providers_with_usage = providers
.iter()
.filter(|provider| provider.routing_preferences.is_some())
.cloned()
.collect::<Vec<LlmProvider>>();
let top_level_preferences: HashMap<String, TopLevelRoutingPreference> = top_level_prefs
.map_or_else(HashMap::new, |prefs| {
prefs.into_iter().map(|p| (p.name.clone(), p)).collect()
});
let llm_routes: HashMap<String, Vec<RoutingPreference>> = providers_with_usage
// Build sentinel routes for RouterModelV1: route_name → first model.
// RouterModelV1 uses this to build its prompt; RouterService overrides
// the model selection via rank_models() after the route is determined.
let sentinel_routes: HashMap<String, Vec<RoutingPreference>> = top_level_preferences
.iter()
.filter_map(|provider| {
provider
.routing_preferences
.as_ref()
.map(|prefs| (provider.name.clone(), prefs.clone()))
.filter_map(|(name, pref)| {
pref.models.first().map(|first_model| {
(
first_model.clone(),
vec![RoutingPreference {
name: name.clone(),
description: pref.description.clone(),
}],
)
})
})
.collect();
let router_model = Arc::new(router_model_v1::RouterModelV1::new(
llm_routes,
sentinel_routes,
routing_model_name,
router_model_v1::MAX_TOKEN_LEN,
));
@ -67,7 +79,8 @@ impl RouterService {
client: reqwest::Client::new(),
router_model,
routing_provider_name,
llm_usage_defined: !providers_with_usage.is_empty(),
top_level_preferences,
metrics_service,
}
}
@ -75,24 +88,43 @@ impl RouterService {
&self,
messages: &[Message],
traceparent: &str,
usage_preferences: Option<Vec<ModelUsagePreference>>,
inline_routing_preferences: Option<Vec<TopLevelRoutingPreference>>,
request_id: &str,
) -> Result<Option<(String, String)>> {
) -> Result<Option<(String, Vec<String>)>> {
if messages.is_empty() {
return Ok(None);
}
if usage_preferences
.as_ref()
.is_none_or(|prefs| prefs.len() < 2)
&& !self.llm_usage_defined
{
// Build inline top-level map from request if present (inline overrides config).
let inline_top_map: Option<HashMap<String, TopLevelRoutingPreference>> =
inline_routing_preferences
.map(|prefs| prefs.into_iter().map(|p| (p.name.clone(), p)).collect());
// No routing defined — skip the router call entirely.
if inline_top_map.is_none() && self.top_level_preferences.is_empty() {
return Ok(None);
}
// For inline overrides, build synthetic ModelUsagePreference list so RouterModelV1
// generates the correct prompt (route name + description pairs).
// For config-level prefs the sentinel routes are already baked into RouterModelV1.
let effective_usage_preferences: Option<Vec<ModelUsagePreference>> =
inline_top_map.as_ref().map(|inline_map| {
inline_map
.values()
.map(|p| ModelUsagePreference {
model: p.models.first().cloned().unwrap_or_default(),
routing_preferences: vec![RoutingPreference {
name: p.name.clone(),
description: p.description.clone(),
}],
})
.collect()
});
let router_request = self
.router_model
.generate_request(messages, &usage_preferences);
.generate_request(messages, &effective_usage_preferences);
debug!(
model = %self.router_model.get_model_name(),
@ -132,17 +164,37 @@ impl RouterService {
return Ok(None);
};
// Parse the route name from the router response.
let parsed = self
.router_model
.parse_response(&content, &usage_preferences)?;
.parse_response(&content, &effective_usage_preferences)?;
let result = if let Some((route_name, _sentinel)) = parsed {
let top_pref = inline_top_map
.as_ref()
.and_then(|m| m.get(&route_name))
.or_else(|| self.top_level_preferences.get(&route_name));
if let Some(pref) = top_pref {
let ranked = match &self.metrics_service {
Some(svc) => svc.rank_models(&pref.models, &pref.selection_policy).await,
None => pref.models.clone(),
};
Some((route_name, ranked))
} else {
None
}
} else {
None
};
info!(
content = %content.replace("\n", "\\n"),
selected_model = ?parsed,
selected_model = ?result,
response_time_ms = elapsed.as_millis(),
"arch-router determined route"
);
Ok(parsed)
Ok(result)
}
}

View file

@ -1,5 +1,6 @@
pub(crate) mod http;
pub mod llm;
pub mod model_metrics;
pub mod orchestrator;
pub mod orchestrator_model;
pub mod orchestrator_model_v1;

View file

@ -0,0 +1,419 @@
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use common::configuration::{MetricsSource, SelectionPolicy, SelectionPreference};
use tokio::sync::RwLock;
use tracing::{info, warn};
const DO_PRICING_URL: &str = "https://api.digitalocean.com/v2/gen-ai/models/catalog";
pub struct ModelMetricsService {
cost: Arc<RwLock<HashMap<String, f64>>>,
latency: Arc<RwLock<HashMap<String, f64>>>,
}
impl ModelMetricsService {
pub async fn new(sources: &[MetricsSource], client: reqwest::Client) -> Self {
let cost_data = Arc::new(RwLock::new(HashMap::new()));
let latency_data = Arc::new(RwLock::new(HashMap::new()));
for source in sources {
match source {
MetricsSource::CostMetrics {
url,
refresh_interval,
auth,
} => {
let data = fetch_cost_metrics(url, auth.as_ref(), &client).await;
info!(models = data.len(), url = %url, "fetched cost metrics");
*cost_data.write().await = data;
if let Some(interval_secs) = refresh_interval {
let cost_clone = Arc::clone(&cost_data);
let client_clone = client.clone();
let url = url.clone();
let auth = auth.clone();
let interval = Duration::from_secs(*interval_secs);
tokio::spawn(async move {
loop {
tokio::time::sleep(interval).await;
let data =
fetch_cost_metrics(&url, auth.as_ref(), &client_clone).await;
info!(models = data.len(), url = %url, "refreshed cost metrics");
*cost_clone.write().await = data;
}
});
}
}
MetricsSource::PrometheusMetrics {
url,
query,
refresh_interval,
} => {
let data = fetch_prometheus_metrics(url, query, &client).await;
info!(models = data.len(), url = %url, "fetched prometheus latency metrics");
*latency_data.write().await = data;
if let Some(interval_secs) = refresh_interval {
let latency_clone = Arc::clone(&latency_data);
let client_clone = client.clone();
let url = url.clone();
let query = query.clone();
let interval = Duration::from_secs(*interval_secs);
tokio::spawn(async move {
loop {
tokio::time::sleep(interval).await;
let data =
fetch_prometheus_metrics(&url, &query, &client_clone).await;
info!(models = data.len(), url = %url, "refreshed prometheus latency metrics");
*latency_clone.write().await = data;
}
});
}
}
MetricsSource::DigitalOceanPricing {
refresh_interval,
model_aliases,
} => {
let aliases = model_aliases.clone().unwrap_or_default();
let data = fetch_do_pricing(&client, &aliases).await;
info!(models = data.len(), "fetched digitalocean pricing");
*cost_data.write().await = data;
if let Some(interval_secs) = refresh_interval {
let cost_clone = Arc::clone(&cost_data);
let client_clone = client.clone();
let interval = Duration::from_secs(*interval_secs);
tokio::spawn(async move {
loop {
tokio::time::sleep(interval).await;
let data = fetch_do_pricing(&client_clone, &aliases).await;
info!(models = data.len(), "refreshed digitalocean pricing");
*cost_clone.write().await = data;
}
});
}
}
}
}
ModelMetricsService {
cost: cost_data,
latency: latency_data,
}
}
/// Rank `models` by `policy`, returning them in preference order.
/// Models with no metric data are appended at the end in their original order.
pub async fn rank_models(&self, models: &[String], policy: &SelectionPolicy) -> Vec<String> {
match policy.prefer {
SelectionPreference::Cheapest => {
let data = self.cost.read().await;
for m in models {
if !data.contains_key(m.as_str()) {
warn!(model = %m, "no cost data for model — ranking last (prefer: cheapest)");
}
}
rank_by_ascending_metric(models, &data)
}
SelectionPreference::Fastest => {
let data = self.latency.read().await;
for m in models {
if !data.contains_key(m.as_str()) {
warn!(model = %m, "no latency data for model — ranking last (prefer: fastest)");
}
}
rank_by_ascending_metric(models, &data)
}
SelectionPreference::None => models.to_vec(),
}
}
/// Returns a snapshot of the current cost data. Used at startup to warn about unmatched models.
pub async fn cost_snapshot(&self) -> HashMap<String, f64> {
self.cost.read().await.clone()
}
/// Returns a snapshot of the current latency data. Used at startup to warn about unmatched models.
pub async fn latency_snapshot(&self) -> HashMap<String, f64> {
self.latency.read().await.clone()
}
}
fn rank_by_ascending_metric(models: &[String], data: &HashMap<String, f64>) -> Vec<String> {
let mut with_data: Vec<(&String, f64)> = models
.iter()
.filter_map(|m| data.get(m.as_str()).map(|v| (m, *v)))
.collect();
with_data.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
let without_data: Vec<&String> = models
.iter()
.filter(|m| !data.contains_key(m.as_str()))
.collect();
with_data
.iter()
.map(|(m, _)| (*m).clone())
.chain(without_data.iter().map(|m| (*m).clone()))
.collect()
}
#[derive(serde::Deserialize)]
struct CostEntry {
input_per_million: f64,
output_per_million: f64,
}
async fn fetch_cost_metrics(
url: &str,
auth: Option<&common::configuration::MetricsAuth>,
client: &reqwest::Client,
) -> HashMap<String, f64> {
let mut req = client.get(url);
if let Some(auth) = auth {
if auth.auth_type == "bearer" {
req = req.header("Authorization", format!("Bearer {}", auth.token));
} else {
warn!(auth_type = %auth.auth_type, "unsupported auth type for cost_metrics, skipping auth");
}
}
match req.send().await {
Ok(resp) => match resp.json::<HashMap<String, CostEntry>>().await {
Ok(data) => data
.into_iter()
.map(|(k, v)| (k, v.input_per_million + v.output_per_million))
.collect(),
Err(err) => {
warn!(error = %err, url = %url, "failed to parse cost metrics response");
HashMap::new()
}
},
Err(err) => {
warn!(error = %err, url = %url, "failed to fetch cost metrics");
HashMap::new()
}
}
}
#[derive(serde::Deserialize)]
struct DoModelList {
data: Vec<DoModel>,
}
#[derive(serde::Deserialize)]
struct DoModel {
model_id: String,
pricing: Option<DoPricing>,
}
#[derive(serde::Deserialize)]
struct DoPricing {
input_price_per_million: Option<f64>,
output_price_per_million: Option<f64>,
}
async fn fetch_do_pricing(
client: &reqwest::Client,
aliases: &HashMap<String, String>,
) -> HashMap<String, f64> {
match client.get(DO_PRICING_URL).send().await {
Ok(resp) => match resp.json::<DoModelList>().await {
Ok(list) => list
.data
.into_iter()
.filter_map(|m| {
let pricing = m.pricing?;
let raw_key = m.model_id.clone();
let key = aliases.get(&raw_key).cloned().unwrap_or(raw_key);
let cost = pricing.input_price_per_million.unwrap_or(0.0)
+ pricing.output_price_per_million.unwrap_or(0.0);
Some((key, cost))
})
.collect(),
Err(err) => {
warn!(error = %err, url = DO_PRICING_URL, "failed to parse digitalocean pricing response");
HashMap::new()
}
},
Err(err) => {
warn!(error = %err, url = DO_PRICING_URL, "failed to fetch digitalocean pricing");
HashMap::new()
}
}
}
#[derive(serde::Deserialize)]
struct PrometheusResponse {
data: PrometheusData,
}
#[derive(serde::Deserialize)]
struct PrometheusData {
result: Vec<PrometheusResult>,
}
#[derive(serde::Deserialize)]
struct PrometheusResult {
metric: HashMap<String, String>,
value: (f64, String), // (timestamp, value_str)
}
async fn fetch_prometheus_metrics(
url: &str,
query: &str,
client: &reqwest::Client,
) -> HashMap<String, f64> {
let query_url = format!("{}/api/v1/query", url.trim_end_matches('/'));
match client
.get(&query_url)
.query(&[("query", query)])
.send()
.await
{
Ok(resp) => match resp.json::<PrometheusResponse>().await {
Ok(prom) => prom
.data
.result
.into_iter()
.filter_map(|r| {
let model_name = r.metric.get("model_name")?.clone();
let value: f64 = r.value.1.parse().ok()?;
Some((model_name, value))
})
.collect(),
Err(err) => {
warn!(error = %err, url = %query_url, "failed to parse prometheus response");
HashMap::new()
}
},
Err(err) => {
warn!(error = %err, url = %query_url, "failed to fetch prometheus metrics");
HashMap::new()
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use common::configuration::SelectionPreference;
fn make_policy(prefer: SelectionPreference) -> SelectionPolicy {
SelectionPolicy { prefer }
}
#[test]
fn test_rank_by_ascending_metric_picks_lowest_first() {
let models = vec!["a".to_string(), "b".to_string(), "c".to_string()];
let mut data = HashMap::new();
data.insert("a".to_string(), 0.01);
data.insert("b".to_string(), 0.005);
data.insert("c".to_string(), 0.02);
assert_eq!(
rank_by_ascending_metric(&models, &data),
vec!["b", "a", "c"]
);
}
#[test]
fn test_rank_by_ascending_metric_no_data_preserves_order() {
let models = vec!["x".to_string(), "y".to_string()];
let data = HashMap::new();
assert_eq!(rank_by_ascending_metric(&models, &data), vec!["x", "y"]);
}
#[test]
fn test_rank_by_ascending_metric_partial_data() {
let models = vec!["a".to_string(), "b".to_string()];
let mut data = HashMap::new();
data.insert("b".to_string(), 100.0);
assert_eq!(rank_by_ascending_metric(&models, &data), vec!["b", "a"]);
}
#[tokio::test]
async fn test_rank_models_cheapest() {
let service = ModelMetricsService {
cost: Arc::new(RwLock::new({
let mut m = HashMap::new();
m.insert("gpt-4o".to_string(), 0.005);
m.insert("gpt-4o-mini".to_string(), 0.0001);
m
})),
latency: Arc::new(RwLock::new(HashMap::new())),
};
let models = vec!["gpt-4o".to_string(), "gpt-4o-mini".to_string()];
let result = service
.rank_models(&models, &make_policy(SelectionPreference::Cheapest))
.await;
assert_eq!(result, vec!["gpt-4o-mini", "gpt-4o"]);
}
#[tokio::test]
async fn test_rank_models_fastest() {
let service = ModelMetricsService {
cost: Arc::new(RwLock::new(HashMap::new())),
latency: Arc::new(RwLock::new({
let mut m = HashMap::new();
m.insert("gpt-4o".to_string(), 200.0);
m.insert("claude-sonnet".to_string(), 120.0);
m
})),
};
let models = vec!["gpt-4o".to_string(), "claude-sonnet".to_string()];
let result = service
.rank_models(&models, &make_policy(SelectionPreference::Fastest))
.await;
assert_eq!(result, vec!["claude-sonnet", "gpt-4o"]);
}
#[tokio::test]
async fn test_rank_models_fallback_no_metrics() {
let service = ModelMetricsService {
cost: Arc::new(RwLock::new(HashMap::new())),
latency: Arc::new(RwLock::new(HashMap::new())),
};
let models = vec!["model-a".to_string(), "model-b".to_string()];
let result = service
.rank_models(&models, &make_policy(SelectionPreference::Cheapest))
.await;
assert_eq!(result, vec!["model-a", "model-b"]);
}
#[tokio::test]
async fn test_rank_models_partial_data_appended_last() {
let service = ModelMetricsService {
cost: Arc::new(RwLock::new({
let mut m = HashMap::new();
m.insert("gpt-4o".to_string(), 0.005);
m
})),
latency: Arc::new(RwLock::new(HashMap::new())),
};
let models = vec!["gpt-4o-mini".to_string(), "gpt-4o".to_string()];
let result = service
.rank_models(&models, &make_policy(SelectionPreference::Cheapest))
.await;
assert_eq!(result, vec!["gpt-4o", "gpt-4o-mini"]);
}
#[tokio::test]
async fn test_rank_models_none_preserves_order() {
let service = ModelMetricsService {
cost: Arc::new(RwLock::new({
let mut m = HashMap::new();
m.insert("gpt-4o-mini".to_string(), 0.0001);
m.insert("gpt-4o".to_string(), 0.005);
m
})),
latency: Arc::new(RwLock::new(HashMap::new())),
};
let models = vec!["gpt-4o".to_string(), "gpt-4o-mini".to_string()];
let result = service
.rank_models(&models, &make_policy(SelectionPreference::None))
.await;
// none → original order, despite gpt-4o-mini being cheaper
assert_eq!(result, vec!["gpt-4o", "gpt-4o-mini"]);
}
}

View file

@ -1,5 +1,5 @@
use common::configuration::ModelUsagePreference;
use hermesllm::apis::openai::{ChatCompletionsRequest, Message};
use serde::{Deserialize, Serialize};
use thiserror::Error;
#[derive(Debug, Error)]
@ -10,6 +10,20 @@ pub enum RoutingModelError {
pub type Result<T> = std::result::Result<T, RoutingModelError>;
/// Internal route descriptor passed to the router model to build its prompt.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RoutingPreference {
pub name: String,
pub description: String,
}
/// Groups a model with its routing preferences (used internally by RouterModelV1).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelUsagePreference {
pub model: String,
pub routing_preferences: Vec<RoutingPreference>,
}
pub trait RouterModel: Send + Sync {
fn generate_request(
&self,

View file

@ -1,6 +1,6 @@
use std::collections::HashMap;
use common::configuration::{ModelUsagePreference, RoutingPreference};
use super::router_model::{ModelUsagePreference, RoutingPreference};
use hermesllm::apis::openai::{ChatCompletionsRequest, Message, MessageContent, Role};
use hermesllm::transforms::lib::ExtractText;
use serde::{Deserialize, Serialize};