add model_aliases to digitalocean_pricing, use model_id as key, warn on missing data at request time

This commit is contained in:
Adil Hafeez 2026-03-27 17:32:15 -07:00
parent bd335cd8bd
commit a7903d9271
6 changed files with 59 additions and 20 deletions

View file

@ -193,9 +193,7 @@ async fn init_app_state(
let provider_model_names: std::collections::HashSet<&str> = config
.model_providers
.iter()
.flat_map(|p| {
std::iter::once(p.name.as_str()).chain(p.model.as_deref())
})
.flat_map(|p| std::iter::once(p.name.as_str()).chain(p.model.as_deref()))
.collect();
for pref in route_prefs {
for model in &pref.models {

View file

@ -72,8 +72,12 @@ impl ModelMetricsService {
});
}
}
MetricsSource::DigitalOceanPricing { refresh_interval } => {
let data = fetch_do_pricing(&client).await;
MetricsSource::DigitalOceanPricing {
refresh_interval,
model_aliases,
} => {
let aliases = model_aliases.clone().unwrap_or_default();
let data = fetch_do_pricing(&client, &aliases).await;
info!(models = data.len(), "fetched digitalocean pricing");
*cost_data.write().await = data;
@ -84,7 +88,7 @@ impl ModelMetricsService {
tokio::spawn(async move {
loop {
tokio::time::sleep(interval).await;
let data = fetch_do_pricing(&client_clone).await;
let data = fetch_do_pricing(&client_clone, &aliases).await;
info!(models = data.len(), "refreshed digitalocean pricing");
*cost_clone.write().await = data;
}
@ -106,10 +110,20 @@ impl ModelMetricsService {
match policy.prefer {
SelectionPreference::Cheapest => {
let data = self.cost.read().await;
for m in models {
if !data.contains_key(m.as_str()) {
warn!(model = %m, "no cost data for model — ranking last (prefer: cheapest)");
}
}
rank_by_ascending_metric(models, &data)
}
SelectionPreference::Fastest => {
let data = self.latency.read().await;
for m in models {
if !data.contains_key(m.as_str()) {
warn!(model = %m, "no latency data for model — ranking last (prefer: fastest)");
}
}
rank_by_ascending_metric(models, &data)
}
SelectionPreference::Random => shuffle(models),
@ -210,27 +224,31 @@ struct DoModelList {
#[derive(serde::Deserialize)]
struct DoModel {
model_id: String,
creator: String,
pricing: DoPricing,
pricing: Option<DoPricing>,
}
#[derive(serde::Deserialize)]
struct DoPricing {
input_price_per_million: f64,
output_price_per_million: f64,
input_price_per_million: Option<f64>,
output_price_per_million: Option<f64>,
}
async fn fetch_do_pricing(client: &reqwest::Client) -> HashMap<String, f64> {
async fn fetch_do_pricing(
client: &reqwest::Client,
aliases: &HashMap<String, String>,
) -> HashMap<String, f64> {
match client.get(DO_PRICING_URL).send().await {
Ok(resp) => match resp.json::<DoModelList>().await {
Ok(list) => list
.data
.into_iter()
.map(|m| {
let key = format!("{}/{}", m.creator.to_lowercase(), m.model_id);
let cost =
m.pricing.input_price_per_million + m.pricing.output_price_per_million;
(key, cost)
.filter_map(|m| {
let pricing = m.pricing?;
let raw_key = m.model_id.clone();
let key = aliases.get(&raw_key).cloned().unwrap_or(raw_key);
let cost = pricing.input_price_per_million.unwrap_or(0.0)
+ pricing.output_price_per_million.unwrap_or(0.0);
Some((key, cost))
})
.collect(),
Err(err) => {

View file

@ -150,6 +150,9 @@ pub enum MetricsSource {
#[serde(rename = "digitalocean_pricing")]
DigitalOceanPricing {
refresh_interval: Option<u64>,
/// Map DO catalog keys (`lowercase(creator)/model_id`) to Plano model names.
/// Example: `openai/openai-gpt-oss-120b: openai/gpt-4o`
model_aliases: Option<HashMap<String, String>>,
},
}