restructure model_metrics_sources to use type + provider pattern

This commit is contained in:
Adil Hafeez 2026-03-30 15:18:04 -07:00
parent e5751d6b13
commit ba701264be
7 changed files with 142 additions and 299 deletions

View file

@ -216,33 +216,17 @@ async fn init_app_state(
use common::configuration::MetricsSource;
let cost_count = sources
.iter()
.filter(|s| matches!(s, MetricsSource::CostMetrics { .. }))
.filter(|s| matches!(s, MetricsSource::Cost(_)))
.count();
let prom_count = sources
let latency_count = sources
.iter()
.filter(|s| matches!(s, MetricsSource::PrometheusMetrics { .. }))
.count();
let do_count = sources
.iter()
.filter(|s| matches!(s, MetricsSource::DigitalOceanPricing { .. }))
.filter(|s| matches!(s, MetricsSource::Latency(_)))
.count();
if cost_count > 1 {
return Err("model_metrics_sources: only one cost_metrics source is allowed".into());
return Err("model_metrics_sources: only one cost metrics source is allowed".into());
}
if prom_count > 1 {
return Err(
"model_metrics_sources: only one prometheus_metrics source is allowed".into(),
);
}
if do_count > 1 {
return Err(
"model_metrics_sources: only one digitalocean_pricing source is allowed".into(),
);
}
if cost_count > 0 && do_count > 0 {
return Err(
"model_metrics_sources: cost_metrics and digitalocean_pricing cannot both be configured — use one or the other".into(),
);
if latency_count > 1 {
return Err("model_metrics_sources: only one latency metrics source is allowed".into());
}
let svc = ModelMetricsService::new(sources, reqwest::Client::new()).await;
Some(Arc::new(svc))
@ -259,32 +243,27 @@ async fn init_app_state(
.as_deref()
.unwrap_or_default()
.iter()
.any(|s| {
matches!(
s,
MetricsSource::CostMetrics { .. } | MetricsSource::DigitalOceanPricing { .. }
)
});
let has_prometheus = config
.any(|s| matches!(s, MetricsSource::Cost(_)));
let has_latency_source = config
.model_metrics_sources
.as_deref()
.unwrap_or_default()
.iter()
.any(|s| matches!(s, MetricsSource::PrometheusMetrics { .. }));
.any(|s| matches!(s, MetricsSource::Latency(_)));
for pref in prefs {
if pref.selection_policy.prefer == SelectionPreference::Cheapest && !has_cost_source {
return Err(format!(
"routing_preferences route '{}' uses prefer: cheapest but no cost data source is configured — \
add cost_metrics or digitalocean_pricing to model_metrics_sources",
"routing_preferences route '{}' uses prefer: cheapest but no cost metrics source is configured — \
add a cost metrics source to model_metrics_sources",
pref.name
)
.into());
}
if pref.selection_policy.prefer == SelectionPreference::Fastest && !has_prometheus {
if pref.selection_policy.prefer == SelectionPreference::Fastest && !has_latency_source {
return Err(format!(
"routing_preferences route '{}' uses prefer: fastest but no prometheus_metrics source is configured — \
add prometheus_metrics to model_metrics_sources",
"routing_preferences route '{}' uses prefer: fastest but no latency metrics source is configured — \
add a latency metrics source to model_metrics_sources",
pref.name
)
.into());

View file

@ -2,7 +2,9 @@ use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use common::configuration::{MetricsSource, SelectionPolicy, SelectionPreference};
use common::configuration::{
CostProvider, LatencyProvider, MetricsSource, SelectionPolicy, SelectionPreference,
};
use tokio::sync::RwLock;
use tracing::{info, warn};
@ -20,81 +22,52 @@ impl ModelMetricsService {
for source in sources {
match source {
MetricsSource::CostMetrics {
url,
refresh_interval,
auth,
} => {
let data = fetch_cost_metrics(url, auth.as_ref(), &client).await;
info!(models = data.len(), url = %url, "fetched cost metrics");
*cost_data.write().await = data;
MetricsSource::Cost(cfg) => match cfg.provider {
CostProvider::Digitalocean => {
let aliases = cfg.model_aliases.clone().unwrap_or_default();
let data = fetch_do_pricing(&client, &aliases).await;
info!(models = data.len(), "fetched digitalocean pricing");
*cost_data.write().await = data;
if let Some(interval_secs) = refresh_interval {
let cost_clone = Arc::clone(&cost_data);
let client_clone = client.clone();
let url = url.clone();
let auth = auth.clone();
let interval = Duration::from_secs(*interval_secs);
tokio::spawn(async move {
loop {
tokio::time::sleep(interval).await;
let data =
fetch_cost_metrics(&url, auth.as_ref(), &client_clone).await;
info!(models = data.len(), url = %url, "refreshed cost metrics");
*cost_clone.write().await = data;
}
});
if let Some(interval_secs) = cfg.refresh_interval {
let cost_clone = Arc::clone(&cost_data);
let client_clone = client.clone();
let interval = Duration::from_secs(interval_secs);
tokio::spawn(async move {
loop {
tokio::time::sleep(interval).await;
let data = fetch_do_pricing(&client_clone, &aliases).await;
info!(models = data.len(), "refreshed digitalocean pricing");
*cost_clone.write().await = data;
}
});
}
}
}
MetricsSource::PrometheusMetrics {
url,
query,
refresh_interval,
} => {
let data = fetch_prometheus_metrics(url, query, &client).await;
info!(models = data.len(), url = %url, "fetched prometheus latency metrics");
*latency_data.write().await = data;
},
MetricsSource::Latency(cfg) => match cfg.provider {
LatencyProvider::Prometheus => {
let data = fetch_prometheus_metrics(&cfg.url, &cfg.query, &client).await;
info!(models = data.len(), url = %cfg.url, "fetched latency metrics");
*latency_data.write().await = data;
if let Some(interval_secs) = refresh_interval {
let latency_clone = Arc::clone(&latency_data);
let client_clone = client.clone();
let url = url.clone();
let query = query.clone();
let interval = Duration::from_secs(*interval_secs);
tokio::spawn(async move {
loop {
tokio::time::sleep(interval).await;
let data =
fetch_prometheus_metrics(&url, &query, &client_clone).await;
info!(models = data.len(), url = %url, "refreshed prometheus latency metrics");
*latency_clone.write().await = data;
}
});
if let Some(interval_secs) = cfg.refresh_interval {
let latency_clone = Arc::clone(&latency_data);
let client_clone = client.clone();
let url = cfg.url.clone();
let query = cfg.query.clone();
let interval = Duration::from_secs(interval_secs);
tokio::spawn(async move {
loop {
tokio::time::sleep(interval).await;
let data =
fetch_prometheus_metrics(&url, &query, &client_clone).await;
info!(models = data.len(), url = %url, "refreshed latency metrics");
*latency_clone.write().await = data;
}
});
}
}
}
MetricsSource::DigitalOceanPricing {
refresh_interval,
model_aliases,
} => {
let aliases = model_aliases.clone().unwrap_or_default();
let data = fetch_do_pricing(&client, &aliases).await;
info!(models = data.len(), "fetched digitalocean pricing");
*cost_data.write().await = data;
if let Some(interval_secs) = refresh_interval {
let cost_clone = Arc::clone(&cost_data);
let client_clone = client.clone();
let interval = Duration::from_secs(*interval_secs);
tokio::spawn(async move {
loop {
tokio::time::sleep(interval).await;
let data = fetch_do_pricing(&client_clone, &aliases).await;
info!(models = data.len(), "refreshed digitalocean pricing");
*cost_clone.write().await = data;
}
});
}
}
},
}
}
@ -160,43 +133,6 @@ fn rank_by_ascending_metric(models: &[String], data: &HashMap<String, f64>) -> V
.collect()
}
#[derive(serde::Deserialize)]
struct CostEntry {
input_per_million: f64,
output_per_million: f64,
}
async fn fetch_cost_metrics(
url: &str,
auth: Option<&common::configuration::MetricsAuth>,
client: &reqwest::Client,
) -> HashMap<String, f64> {
let mut req = client.get(url);
if let Some(auth) = auth {
if auth.auth_type == "bearer" {
req = req.header("Authorization", format!("Bearer {}", auth.token));
} else {
warn!(auth_type = %auth.auth_type, "unsupported auth type for cost_metrics, skipping auth");
}
}
match req.send().await {
Ok(resp) => match resp.json::<HashMap<String, CostEntry>>().await {
Ok(data) => data
.into_iter()
.map(|(k, v)| (k, v.input_per_million + v.output_per_million))
.collect(),
Err(err) => {
warn!(error = %err, url = %url, "failed to parse cost metrics response");
HashMap::new()
}
},
Err(err) => {
warn!(error = %err, url = %url, "failed to fetch cost metrics");
HashMap::new()
}
}
}
#[derive(serde::Deserialize)]
struct DoModelList {
data: Vec<DoModel>,

View file

@ -127,32 +127,39 @@ pub struct TopLevelRoutingPreference {
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MetricsAuth {
#[serde(rename = "type")]
pub auth_type: String, // only "bearer" supported
pub token: String,
#[serde(tag = "type", rename_all = "snake_case")]
pub enum MetricsSource {
Cost(CostMetricsConfig),
Latency(LatencyMetricsConfig),
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum MetricsSource {
CostMetrics {
url: String,
refresh_interval: Option<u64>,
auth: Option<MetricsAuth>,
},
PrometheusMetrics {
url: String,
query: String,
refresh_interval: Option<u64>,
},
#[serde(rename = "digitalocean_pricing")]
DigitalOceanPricing {
refresh_interval: Option<u64>,
/// Map DO catalog keys (`lowercase(creator)/model_id`) to Plano model names.
/// Example: `openai/openai-gpt-oss-120b: openai/gpt-4o`
model_aliases: Option<HashMap<String, String>>,
},
pub struct CostMetricsConfig {
pub provider: CostProvider,
pub refresh_interval: Option<u64>,
/// Map DO catalog keys (`lowercase(creator)/model_id`) to Plano model names.
/// Example: `openai/openai-gpt-oss-120b: openai/gpt-4o`
pub model_aliases: Option<HashMap<String, String>>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CostProvider {
Digitalocean,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LatencyMetricsConfig {
pub provider: LatencyProvider,
pub url: String,
pub query: String,
pub refresh_interval: Option<u64>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum LatencyProvider {
Prometheus,
}
#[derive(Debug, Clone, Serialize, Deserialize)]