restructure model_metrics_sources to type + provider (#855)

This commit is contained in:
Adil Hafeez 2026-03-30 17:12:20 -07:00 committed by GitHub
parent e5751d6b13
commit af98c11a6d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 171 additions and 455 deletions

View file

@ -2,9 +2,11 @@ use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use common::configuration::{MetricsSource, SelectionPolicy, SelectionPreference};
use common::configuration::{
CostProvider, LatencyProvider, MetricsSource, SelectionPolicy, SelectionPreference,
};
use tokio::sync::RwLock;
use tracing::{info, warn};
use tracing::{debug, info, warn};
const DO_PRICING_URL: &str = "https://api.digitalocean.com/v2/gen-ai/models/catalog";
@ -20,81 +22,52 @@ impl ModelMetricsService {
for source in sources {
match source {
MetricsSource::CostMetrics {
url,
refresh_interval,
auth,
} => {
let data = fetch_cost_metrics(url, auth.as_ref(), &client).await;
info!(models = data.len(), url = %url, "fetched cost metrics");
*cost_data.write().await = data;
MetricsSource::Cost(cfg) => match cfg.provider {
CostProvider::Digitalocean => {
let aliases = cfg.model_aliases.clone().unwrap_or_default();
let data = fetch_do_pricing(&client, &aliases).await;
info!(models = data.len(), "fetched digitalocean pricing");
*cost_data.write().await = data;
if let Some(interval_secs) = refresh_interval {
let cost_clone = Arc::clone(&cost_data);
let client_clone = client.clone();
let url = url.clone();
let auth = auth.clone();
let interval = Duration::from_secs(*interval_secs);
tokio::spawn(async move {
loop {
tokio::time::sleep(interval).await;
let data =
fetch_cost_metrics(&url, auth.as_ref(), &client_clone).await;
info!(models = data.len(), url = %url, "refreshed cost metrics");
*cost_clone.write().await = data;
}
});
if let Some(interval_secs) = cfg.refresh_interval {
let cost_clone = Arc::clone(&cost_data);
let client_clone = client.clone();
let interval = Duration::from_secs(interval_secs);
tokio::spawn(async move {
loop {
tokio::time::sleep(interval).await;
let data = fetch_do_pricing(&client_clone, &aliases).await;
info!(models = data.len(), "refreshed digitalocean pricing");
*cost_clone.write().await = data;
}
});
}
}
}
MetricsSource::PrometheusMetrics {
url,
query,
refresh_interval,
} => {
let data = fetch_prometheus_metrics(url, query, &client).await;
info!(models = data.len(), url = %url, "fetched prometheus latency metrics");
*latency_data.write().await = data;
},
MetricsSource::Latency(cfg) => match cfg.provider {
LatencyProvider::Prometheus => {
let data = fetch_prometheus_metrics(&cfg.url, &cfg.query, &client).await;
info!(models = data.len(), url = %cfg.url, "fetched latency metrics");
*latency_data.write().await = data;
if let Some(interval_secs) = refresh_interval {
let latency_clone = Arc::clone(&latency_data);
let client_clone = client.clone();
let url = url.clone();
let query = query.clone();
let interval = Duration::from_secs(*interval_secs);
tokio::spawn(async move {
loop {
tokio::time::sleep(interval).await;
let data =
fetch_prometheus_metrics(&url, &query, &client_clone).await;
info!(models = data.len(), url = %url, "refreshed prometheus latency metrics");
*latency_clone.write().await = data;
}
});
if let Some(interval_secs) = cfg.refresh_interval {
let latency_clone = Arc::clone(&latency_data);
let client_clone = client.clone();
let url = cfg.url.clone();
let query = cfg.query.clone();
let interval = Duration::from_secs(interval_secs);
tokio::spawn(async move {
loop {
tokio::time::sleep(interval).await;
let data =
fetch_prometheus_metrics(&url, &query, &client_clone).await;
info!(models = data.len(), url = %url, "refreshed latency metrics");
*latency_clone.write().await = data;
}
});
}
}
}
MetricsSource::DigitalOceanPricing {
refresh_interval,
model_aliases,
} => {
let aliases = model_aliases.clone().unwrap_or_default();
let data = fetch_do_pricing(&client, &aliases).await;
info!(models = data.len(), "fetched digitalocean pricing");
*cost_data.write().await = data;
if let Some(interval_secs) = refresh_interval {
let cost_clone = Arc::clone(&cost_data);
let client_clone = client.clone();
let interval = Duration::from_secs(*interval_secs);
tokio::spawn(async move {
loop {
tokio::time::sleep(interval).await;
let data = fetch_do_pricing(&client_clone, &aliases).await;
info!(models = data.len(), "refreshed digitalocean pricing");
*cost_clone.write().await = data;
}
});
}
}
},
}
}
@ -107,24 +80,32 @@ impl ModelMetricsService {
/// Rank `models` by `policy`, returning them in preference order.
/// Models with no metric data are appended at the end in their original order.
pub async fn rank_models(&self, models: &[String], policy: &SelectionPolicy) -> Vec<String> {
let cost_data = self.cost.read().await;
let latency_data = self.latency.read().await;
debug!(
input_models = ?models,
cost_data = ?cost_data.iter().collect::<Vec<_>>(),
latency_data = ?latency_data.iter().collect::<Vec<_>>(),
prefer = ?policy.prefer,
"rank_models called"
);
match policy.prefer {
SelectionPreference::Cheapest => {
let data = self.cost.read().await;
for m in models {
if !data.contains_key(m.as_str()) {
if !cost_data.contains_key(m.as_str()) {
warn!(model = %m, "no cost data for model — ranking last (prefer: cheapest)");
}
}
rank_by_ascending_metric(models, &data)
rank_by_ascending_metric(models, &cost_data)
}
SelectionPreference::Fastest => {
let data = self.latency.read().await;
for m in models {
if !data.contains_key(m.as_str()) {
if !latency_data.contains_key(m.as_str()) {
warn!(model = %m, "no latency data for model — ranking last (prefer: fastest)");
}
}
rank_by_ascending_metric(models, &data)
rank_by_ascending_metric(models, &latency_data)
}
SelectionPreference::None => models.to_vec(),
}
@ -144,13 +125,20 @@ impl ModelMetricsService {
fn rank_by_ascending_metric(models: &[String], data: &HashMap<String, f64>) -> Vec<String> {
let mut with_data: Vec<(&String, f64)> = models
.iter()
.filter_map(|m| data.get(m.as_str()).map(|v| (m, *v)))
.filter_map(|m| {
let v = *data.get(m.as_str())?;
if v.is_nan() {
None
} else {
Some((m, v))
}
})
.collect();
with_data.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
let without_data: Vec<&String> = models
.iter()
.filter(|m| !data.contains_key(m.as_str()))
.filter(|m| data.get(m.as_str()).is_none_or(|v| v.is_nan()))
.collect();
with_data
@ -160,43 +148,6 @@ fn rank_by_ascending_metric(models: &[String], data: &HashMap<String, f64>) -> V
.collect()
}
#[derive(serde::Deserialize)]
struct CostEntry {
input_per_million: f64,
output_per_million: f64,
}
async fn fetch_cost_metrics(
url: &str,
auth: Option<&common::configuration::MetricsAuth>,
client: &reqwest::Client,
) -> HashMap<String, f64> {
let mut req = client.get(url);
if let Some(auth) = auth {
if auth.auth_type == "bearer" {
req = req.header("Authorization", format!("Bearer {}", auth.token));
} else {
warn!(auth_type = %auth.auth_type, "unsupported auth type for cost_metrics, skipping auth");
}
}
match req.send().await {
Ok(resp) => match resp.json::<HashMap<String, CostEntry>>().await {
Ok(data) => data
.into_iter()
.map(|(k, v)| (k, v.input_per_million + v.output_per_million))
.collect(),
Err(err) => {
warn!(error = %err, url = %url, "failed to parse cost metrics response");
HashMap::new()
}
},
Err(err) => {
warn!(error = %err, url = %url, "failed to fetch cost metrics");
HashMap::new()
}
}
}
#[derive(serde::Deserialize)]
struct DoModelList {
data: Vec<DoModel>,
@ -416,4 +367,22 @@ mod tests {
// none → original order, despite gpt-4o-mini being cheaper
assert_eq!(result, vec!["gpt-4o", "gpt-4o-mini"]);
}
#[test]
fn test_rank_by_ascending_metric_nan_treated_as_missing() {
let models = vec![
"a".to_string(),
"b".to_string(),
"c".to_string(),
"d".to_string(),
];
let mut data = HashMap::new();
data.insert("a".to_string(), f64::NAN);
data.insert("b".to_string(), 0.5);
data.insert("c".to_string(), 0.1);
// "d" has no entry at all
let result = rank_by_ascending_metric(&models, &data);
// c (0.1) < b (0.5), then NaN "a" and missing "d" appended in original order
assert_eq!(result, vec!["c", "b", "a", "d"]);
}
}