feat: make model pricing source configurable (models.dev + DigitalOcean) (#971)

This commit is contained in:
Musa 2026-06-24 10:14:12 -07:00 committed by GitHub
parent 5cc4c4ee77
commit 558df0307c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 687 additions and 48 deletions

View file

@ -9,6 +9,7 @@ use tokio::sync::RwLock;
use tracing::{debug, info, warn};
const DO_PRICING_URL: &str = "https://api.digitalocean.com/v2/gen-ai/models/catalog";
const MODELS_DEV_URL: &str = "https://models.dev/api.json";
pub struct ModelMetricsService {
cost: Arc<RwLock<HashMap<String, f64>>>,
@ -22,28 +23,35 @@ impl ModelMetricsService {
for source in sources {
match source {
MetricsSource::Cost(cfg) => match cfg.provider {
CostProvider::Digitalocean => {
let aliases = cfg.model_aliases.clone().unwrap_or_default();
let data = fetch_do_pricing(&client, &aliases).await;
info!(models = data.len(), "fetched digitalocean pricing");
*cost_data.write().await = data;
MetricsSource::Cost(cfg) => {
let provider = cfg.provider.clone();
let url = cfg
.url
.clone()
.unwrap_or_else(|| default_cost_url(&provider).to_string());
let aliases = cfg.model_aliases.clone().unwrap_or_default();
let provider_name = cost_provider_name(&provider);
if let Some(interval_secs) = cfg.refresh_interval {
let cost_clone = Arc::clone(&cost_data);
let client_clone = client.clone();
let interval = Duration::from_secs(interval_secs);
tokio::spawn(async move {
loop {
tokio::time::sleep(interval).await;
let data = fetch_do_pricing(&client_clone, &aliases).await;
info!(models = data.len(), "refreshed digitalocean pricing");
*cost_clone.write().await = data;
}
});
}
let data = fetch_cost_pricing(&provider, &url, &client, &aliases).await;
info!(models = data.len(), provider = provider_name, url = %url, "fetched cost pricing");
*cost_data.write().await = data;
if let Some(interval_secs) = cfg.refresh_interval {
let cost_clone = Arc::clone(&cost_data);
let client_clone = client.clone();
let interval = Duration::from_secs(interval_secs);
tokio::spawn(async move {
loop {
tokio::time::sleep(interval).await;
let data =
fetch_cost_pricing(&provider, &url, &client_clone, &aliases)
.await;
info!(models = data.len(), provider = provider_name, url = %url, "refreshed cost pricing");
*cost_clone.write().await = data;
}
});
}
},
}
MetricsSource::Latency(cfg) => match cfg.provider {
LatencyProvider::Prometheus => {
let data = fetch_prometheus_metrics(&cfg.url, &cfg.query, &client).await;
@ -165,11 +173,55 @@ struct DoPricing {
output_price_per_million: Option<f64>,
}
async fn fetch_do_pricing(
#[derive(serde::Deserialize)]
struct ModelsDevProvider {
#[serde(default)]
models: HashMap<String, ModelsDevModel>,
}
#[derive(serde::Deserialize)]
struct ModelsDevModel {
cost: Option<ModelsDevCost>,
}
#[derive(serde::Deserialize)]
struct ModelsDevCost {
input: Option<f64>,
output: Option<f64>,
}
fn default_cost_url(provider: &CostProvider) -> &'static str {
match provider {
CostProvider::Digitalocean => DO_PRICING_URL,
CostProvider::ModelsDev => MODELS_DEV_URL,
}
}
fn cost_provider_name(provider: &CostProvider) -> &'static str {
match provider {
CostProvider::Digitalocean => "digitalocean",
CostProvider::ModelsDev => "models.dev",
}
}
async fn fetch_cost_pricing(
provider: &CostProvider,
url: &str,
client: &reqwest::Client,
aliases: &HashMap<String, String>,
) -> HashMap<String, f64> {
match client.get(DO_PRICING_URL).send().await {
match provider {
CostProvider::Digitalocean => fetch_do_pricing(url, client, aliases).await,
CostProvider::ModelsDev => fetch_models_dev_pricing(url, client, aliases).await,
}
}
async fn fetch_do_pricing(
url: &str,
client: &reqwest::Client,
aliases: &HashMap<String, String>,
) -> HashMap<String, f64> {
match client.get(url).send().await {
Ok(resp) => match resp.json::<DoModelList>().await {
Ok(list) => list
.data
@ -184,17 +236,66 @@ async fn fetch_do_pricing(
})
.collect(),
Err(err) => {
warn!(error = %err, url = DO_PRICING_URL, "failed to parse digitalocean pricing response");
warn!(error = %err, url = %url, "failed to parse digitalocean pricing response");
HashMap::new()
}
},
Err(err) => {
warn!(error = %err, url = DO_PRICING_URL, "failed to fetch digitalocean pricing");
warn!(error = %err, url = %url, "failed to fetch digitalocean pricing");
HashMap::new()
}
}
}
/// models.dev publishes a top-level object keyed by provider id; each provider
/// carries a `models` map whose keys are `creator/model` ids and whose `cost`
/// block holds per-million USD rates. We sum input + output (mirroring the DO
/// ranking metric) and key the result by `creator/model_id` so it lines up with
/// Plano's `provider/model` routing names.
async fn fetch_models_dev_pricing(
url: &str,
client: &reqwest::Client,
aliases: &HashMap<String, String>,
) -> HashMap<String, f64> {
match client.get(url).send().await {
Ok(resp) => match resp.json::<HashMap<String, ModelsDevProvider>>().await {
Ok(providers) => parse_models_dev_pricing(providers, aliases),
Err(err) => {
warn!(error = %err, url = %url, "failed to parse models.dev pricing response");
HashMap::new()
}
},
Err(err) => {
warn!(error = %err, url = %url, "failed to fetch models.dev pricing");
HashMap::new()
}
}
}
fn parse_models_dev_pricing(
providers: HashMap<String, ModelsDevProvider>,
aliases: &HashMap<String, String>,
) -> HashMap<String, f64> {
let mut out = HashMap::new();
for (provider_id, provider) in providers {
for (model_key, model) in provider.models {
let Some(cost) = model.cost else { continue };
let (Some(input), Some(output)) = (cost.input, cost.output) else {
continue;
};
// First-party providers use bare model keys (`claude-opus-4-5`),
// so compose `provider/model` to line up with Plano routing names.
let raw_key = format!("{provider_id}/{model_key}");
let total = input + output;
let key = aliases.get(&raw_key).cloned().unwrap_or(raw_key);
out.insert(key, total);
// Also register the bare model id as a fallback lookup.
out.entry(model_key).or_insert(total);
}
}
out
}
#[derive(serde::Deserialize)]
struct PrometheusResponse {
data: PrometheusData,
@ -368,6 +469,50 @@ mod tests {
assert_eq!(result, vec!["gpt-4o", "gpt-4o-mini"]);
}
#[test]
fn test_parse_models_dev_pricing_composes_provider_keys() {
let json = r#"{
"anthropic": {
"models": {
"claude-opus-4-5": {"cost": {"input": 5.0, "output": 25.0}}
}
},
"groq": {
"models": {
"llama-3.3-70b-versatile": {"cost": {"input": 0.59, "output": 0.79}},
"whisper-large-v3-turbo": {"cost": null}
}
}
}"#;
let providers: HashMap<String, ModelsDevProvider> = serde_json::from_str(json).unwrap();
let aliases = HashMap::new();
let prices = parse_models_dev_pricing(providers, &aliases);
assert_eq!(prices.get("anthropic/claude-opus-4-5"), Some(&30.0));
assert_eq!(prices.get("groq/llama-3.3-70b-versatile"), Some(&1.38));
// bare fallback also registered
assert_eq!(prices.get("claude-opus-4-5"), Some(&30.0));
// models with no cost block are skipped
assert!(!prices.contains_key("groq/whisper-large-v3-turbo"));
}
#[test]
fn test_parse_models_dev_pricing_applies_aliases() {
let json = r#"{
"openai": {"models": {"gpt-oss-120b": {"cost": {"input": 1.0, "output": 2.0}}}}
}"#;
let providers: HashMap<String, ModelsDevProvider> = serde_json::from_str(json).unwrap();
let mut aliases = HashMap::new();
aliases.insert(
"openai/gpt-oss-120b".to_string(),
"openai/gpt-4o".to_string(),
);
let prices = parse_models_dev_pricing(providers, &aliases);
assert_eq!(prices.get("openai/gpt-4o"), Some(&3.0));
assert!(!prices.contains_key("openai/gpt-oss-120b"));
}
#[test]
fn test_rank_by_ascending_metric_nan_treated_as_missing() {
let models = vec![

View file

@ -177,8 +177,13 @@ pub enum MetricsSource {
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CostMetricsConfig {
pub provider: CostProvider,
/// Optional override for the pricing catalog endpoint. When omitted, a
/// sensible default is used per provider.
pub url: Option<String>,
pub refresh_interval: Option<u64>,
/// Map DO catalog keys (`lowercase(creator)/model_id`) to Plano model names.
/// Map catalog keys to Plano model names used in `routing_preferences`.
/// DigitalOcean keys look like `lowercase(creator)/model_id`; models.dev
/// keys look like `creator/model_id`.
/// Example: `openai/openai-gpt-oss-120b: openai/gpt-4o`
pub model_aliases: Option<HashMap<String, String>>,
}
@ -187,6 +192,8 @@ pub struct CostMetricsConfig {
#[serde(rename_all = "snake_case")]
pub enum CostProvider {
Digitalocean,
#[serde(rename = "models.dev")]
ModelsDev,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
@ -741,6 +748,51 @@ mod test {
}
}
#[test]
fn test_deserialize_models_dev_cost_source() {
let yaml = r#"
- type: cost
provider: models.dev
url: https://models.dev/api.json
refresh_interval: 3600
model_aliases:
openai/gpt-oss-120b: openai/gpt-4o
"#;
let sources: Vec<super::MetricsSource> = serde_yaml::from_str(yaml).unwrap();
assert_eq!(sources.len(), 1);
match &sources[0] {
super::MetricsSource::Cost(cfg) => {
assert!(matches!(cfg.provider, super::CostProvider::ModelsDev));
assert_eq!(cfg.url.as_deref(), Some("https://models.dev/api.json"));
assert_eq!(cfg.refresh_interval, Some(3600));
assert_eq!(
cfg.model_aliases
.as_ref()
.and_then(|m| m.get("openai/gpt-oss-120b"))
.map(String::as_str),
Some("openai/gpt-4o")
);
}
other => panic!("expected cost source, got {other:?}"),
}
}
#[test]
fn test_deserialize_digitalocean_cost_source_without_url() {
let yaml = r#"
- type: cost
provider: digitalocean
"#;
let sources: Vec<super::MetricsSource> = serde_yaml::from_str(yaml).unwrap();
match &sources[0] {
super::MetricsSource::Cost(cfg) => {
assert!(matches!(cfg.provider, super::CostProvider::Digitalocean));
assert_eq!(cfg.url, None);
}
other => panic!("expected cost source, got {other:?}"),
}
}
#[test]
fn test_into_models_filters_internal_providers() {
let providers = vec![