mirror of
https://github.com/katanemo/plano.git
synced 2026-06-26 15:39:40 +02:00
feat: make model pricing source configurable (models.dev + DigitalOcean) (#971)
This commit is contained in:
parent
5cc4c4ee77
commit
558df0307c
9 changed files with 687 additions and 48 deletions
|
|
@ -9,6 +9,7 @@ use tokio::sync::RwLock;
|
|||
use tracing::{debug, info, warn};
|
||||
|
||||
const DO_PRICING_URL: &str = "https://api.digitalocean.com/v2/gen-ai/models/catalog";
|
||||
const MODELS_DEV_URL: &str = "https://models.dev/api.json";
|
||||
|
||||
pub struct ModelMetricsService {
|
||||
cost: Arc<RwLock<HashMap<String, f64>>>,
|
||||
|
|
@ -22,28 +23,35 @@ impl ModelMetricsService {
|
|||
|
||||
for source in sources {
|
||||
match source {
|
||||
MetricsSource::Cost(cfg) => match cfg.provider {
|
||||
CostProvider::Digitalocean => {
|
||||
let aliases = cfg.model_aliases.clone().unwrap_or_default();
|
||||
let data = fetch_do_pricing(&client, &aliases).await;
|
||||
info!(models = data.len(), "fetched digitalocean pricing");
|
||||
*cost_data.write().await = data;
|
||||
MetricsSource::Cost(cfg) => {
|
||||
let provider = cfg.provider.clone();
|
||||
let url = cfg
|
||||
.url
|
||||
.clone()
|
||||
.unwrap_or_else(|| default_cost_url(&provider).to_string());
|
||||
let aliases = cfg.model_aliases.clone().unwrap_or_default();
|
||||
let provider_name = cost_provider_name(&provider);
|
||||
|
||||
if let Some(interval_secs) = cfg.refresh_interval {
|
||||
let cost_clone = Arc::clone(&cost_data);
|
||||
let client_clone = client.clone();
|
||||
let interval = Duration::from_secs(interval_secs);
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
tokio::time::sleep(interval).await;
|
||||
let data = fetch_do_pricing(&client_clone, &aliases).await;
|
||||
info!(models = data.len(), "refreshed digitalocean pricing");
|
||||
*cost_clone.write().await = data;
|
||||
}
|
||||
});
|
||||
}
|
||||
let data = fetch_cost_pricing(&provider, &url, &client, &aliases).await;
|
||||
info!(models = data.len(), provider = provider_name, url = %url, "fetched cost pricing");
|
||||
*cost_data.write().await = data;
|
||||
|
||||
if let Some(interval_secs) = cfg.refresh_interval {
|
||||
let cost_clone = Arc::clone(&cost_data);
|
||||
let client_clone = client.clone();
|
||||
let interval = Duration::from_secs(interval_secs);
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
tokio::time::sleep(interval).await;
|
||||
let data =
|
||||
fetch_cost_pricing(&provider, &url, &client_clone, &aliases)
|
||||
.await;
|
||||
info!(models = data.len(), provider = provider_name, url = %url, "refreshed cost pricing");
|
||||
*cost_clone.write().await = data;
|
||||
}
|
||||
});
|
||||
}
|
||||
},
|
||||
}
|
||||
MetricsSource::Latency(cfg) => match cfg.provider {
|
||||
LatencyProvider::Prometheus => {
|
||||
let data = fetch_prometheus_metrics(&cfg.url, &cfg.query, &client).await;
|
||||
|
|
@ -165,11 +173,55 @@ struct DoPricing {
|
|||
output_price_per_million: Option<f64>,
|
||||
}
|
||||
|
||||
async fn fetch_do_pricing(
|
||||
#[derive(serde::Deserialize)]
|
||||
struct ModelsDevProvider {
|
||||
#[serde(default)]
|
||||
models: HashMap<String, ModelsDevModel>,
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
struct ModelsDevModel {
|
||||
cost: Option<ModelsDevCost>,
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
struct ModelsDevCost {
|
||||
input: Option<f64>,
|
||||
output: Option<f64>,
|
||||
}
|
||||
|
||||
fn default_cost_url(provider: &CostProvider) -> &'static str {
|
||||
match provider {
|
||||
CostProvider::Digitalocean => DO_PRICING_URL,
|
||||
CostProvider::ModelsDev => MODELS_DEV_URL,
|
||||
}
|
||||
}
|
||||
|
||||
fn cost_provider_name(provider: &CostProvider) -> &'static str {
|
||||
match provider {
|
||||
CostProvider::Digitalocean => "digitalocean",
|
||||
CostProvider::ModelsDev => "models.dev",
|
||||
}
|
||||
}
|
||||
|
||||
async fn fetch_cost_pricing(
|
||||
provider: &CostProvider,
|
||||
url: &str,
|
||||
client: &reqwest::Client,
|
||||
aliases: &HashMap<String, String>,
|
||||
) -> HashMap<String, f64> {
|
||||
match client.get(DO_PRICING_URL).send().await {
|
||||
match provider {
|
||||
CostProvider::Digitalocean => fetch_do_pricing(url, client, aliases).await,
|
||||
CostProvider::ModelsDev => fetch_models_dev_pricing(url, client, aliases).await,
|
||||
}
|
||||
}
|
||||
|
||||
async fn fetch_do_pricing(
|
||||
url: &str,
|
||||
client: &reqwest::Client,
|
||||
aliases: &HashMap<String, String>,
|
||||
) -> HashMap<String, f64> {
|
||||
match client.get(url).send().await {
|
||||
Ok(resp) => match resp.json::<DoModelList>().await {
|
||||
Ok(list) => list
|
||||
.data
|
||||
|
|
@ -184,17 +236,66 @@ async fn fetch_do_pricing(
|
|||
})
|
||||
.collect(),
|
||||
Err(err) => {
|
||||
warn!(error = %err, url = DO_PRICING_URL, "failed to parse digitalocean pricing response");
|
||||
warn!(error = %err, url = %url, "failed to parse digitalocean pricing response");
|
||||
HashMap::new()
|
||||
}
|
||||
},
|
||||
Err(err) => {
|
||||
warn!(error = %err, url = DO_PRICING_URL, "failed to fetch digitalocean pricing");
|
||||
warn!(error = %err, url = %url, "failed to fetch digitalocean pricing");
|
||||
HashMap::new()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// models.dev publishes a top-level object keyed by provider id; each provider
|
||||
/// carries a `models` map whose keys are `creator/model` ids and whose `cost`
|
||||
/// block holds per-million USD rates. We sum input + output (mirroring the DO
|
||||
/// ranking metric) and key the result by `creator/model_id` so it lines up with
|
||||
/// Plano's `provider/model` routing names.
|
||||
async fn fetch_models_dev_pricing(
|
||||
url: &str,
|
||||
client: &reqwest::Client,
|
||||
aliases: &HashMap<String, String>,
|
||||
) -> HashMap<String, f64> {
|
||||
match client.get(url).send().await {
|
||||
Ok(resp) => match resp.json::<HashMap<String, ModelsDevProvider>>().await {
|
||||
Ok(providers) => parse_models_dev_pricing(providers, aliases),
|
||||
Err(err) => {
|
||||
warn!(error = %err, url = %url, "failed to parse models.dev pricing response");
|
||||
HashMap::new()
|
||||
}
|
||||
},
|
||||
Err(err) => {
|
||||
warn!(error = %err, url = %url, "failed to fetch models.dev pricing");
|
||||
HashMap::new()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_models_dev_pricing(
|
||||
providers: HashMap<String, ModelsDevProvider>,
|
||||
aliases: &HashMap<String, String>,
|
||||
) -> HashMap<String, f64> {
|
||||
let mut out = HashMap::new();
|
||||
for (provider_id, provider) in providers {
|
||||
for (model_key, model) in provider.models {
|
||||
let Some(cost) = model.cost else { continue };
|
||||
let (Some(input), Some(output)) = (cost.input, cost.output) else {
|
||||
continue;
|
||||
};
|
||||
// First-party providers use bare model keys (`claude-opus-4-5`),
|
||||
// so compose `provider/model` to line up with Plano routing names.
|
||||
let raw_key = format!("{provider_id}/{model_key}");
|
||||
let total = input + output;
|
||||
let key = aliases.get(&raw_key).cloned().unwrap_or(raw_key);
|
||||
out.insert(key, total);
|
||||
// Also register the bare model id as a fallback lookup.
|
||||
out.entry(model_key).or_insert(total);
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
struct PrometheusResponse {
|
||||
data: PrometheusData,
|
||||
|
|
@ -368,6 +469,50 @@ mod tests {
|
|||
assert_eq!(result, vec!["gpt-4o", "gpt-4o-mini"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_models_dev_pricing_composes_provider_keys() {
|
||||
let json = r#"{
|
||||
"anthropic": {
|
||||
"models": {
|
||||
"claude-opus-4-5": {"cost": {"input": 5.0, "output": 25.0}}
|
||||
}
|
||||
},
|
||||
"groq": {
|
||||
"models": {
|
||||
"llama-3.3-70b-versatile": {"cost": {"input": 0.59, "output": 0.79}},
|
||||
"whisper-large-v3-turbo": {"cost": null}
|
||||
}
|
||||
}
|
||||
}"#;
|
||||
let providers: HashMap<String, ModelsDevProvider> = serde_json::from_str(json).unwrap();
|
||||
let aliases = HashMap::new();
|
||||
let prices = parse_models_dev_pricing(providers, &aliases);
|
||||
|
||||
assert_eq!(prices.get("anthropic/claude-opus-4-5"), Some(&30.0));
|
||||
assert_eq!(prices.get("groq/llama-3.3-70b-versatile"), Some(&1.38));
|
||||
// bare fallback also registered
|
||||
assert_eq!(prices.get("claude-opus-4-5"), Some(&30.0));
|
||||
// models with no cost block are skipped
|
||||
assert!(!prices.contains_key("groq/whisper-large-v3-turbo"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_models_dev_pricing_applies_aliases() {
|
||||
let json = r#"{
|
||||
"openai": {"models": {"gpt-oss-120b": {"cost": {"input": 1.0, "output": 2.0}}}}
|
||||
}"#;
|
||||
let providers: HashMap<String, ModelsDevProvider> = serde_json::from_str(json).unwrap();
|
||||
let mut aliases = HashMap::new();
|
||||
aliases.insert(
|
||||
"openai/gpt-oss-120b".to_string(),
|
||||
"openai/gpt-4o".to_string(),
|
||||
);
|
||||
let prices = parse_models_dev_pricing(providers, &aliases);
|
||||
|
||||
assert_eq!(prices.get("openai/gpt-4o"), Some(&3.0));
|
||||
assert!(!prices.contains_key("openai/gpt-oss-120b"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rank_by_ascending_metric_nan_treated_as_missing() {
|
||||
let models = vec![
|
||||
|
|
|
|||
|
|
@ -177,8 +177,13 @@ pub enum MetricsSource {
|
|||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct CostMetricsConfig {
|
||||
pub provider: CostProvider,
|
||||
/// Optional override for the pricing catalog endpoint. When omitted, a
|
||||
/// sensible default is used per provider.
|
||||
pub url: Option<String>,
|
||||
pub refresh_interval: Option<u64>,
|
||||
/// Map DO catalog keys (`lowercase(creator)/model_id`) to Plano model names.
|
||||
/// Map catalog keys to Plano model names used in `routing_preferences`.
|
||||
/// DigitalOcean keys look like `lowercase(creator)/model_id`; models.dev
|
||||
/// keys look like `creator/model_id`.
|
||||
/// Example: `openai/openai-gpt-oss-120b: openai/gpt-4o`
|
||||
pub model_aliases: Option<HashMap<String, String>>,
|
||||
}
|
||||
|
|
@ -187,6 +192,8 @@ pub struct CostMetricsConfig {
|
|||
#[serde(rename_all = "snake_case")]
|
||||
pub enum CostProvider {
|
||||
Digitalocean,
|
||||
#[serde(rename = "models.dev")]
|
||||
ModelsDev,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
|
|
@ -741,6 +748,51 @@ mod test {
|
|||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_deserialize_models_dev_cost_source() {
|
||||
let yaml = r#"
|
||||
- type: cost
|
||||
provider: models.dev
|
||||
url: https://models.dev/api.json
|
||||
refresh_interval: 3600
|
||||
model_aliases:
|
||||
openai/gpt-oss-120b: openai/gpt-4o
|
||||
"#;
|
||||
let sources: Vec<super::MetricsSource> = serde_yaml::from_str(yaml).unwrap();
|
||||
assert_eq!(sources.len(), 1);
|
||||
match &sources[0] {
|
||||
super::MetricsSource::Cost(cfg) => {
|
||||
assert!(matches!(cfg.provider, super::CostProvider::ModelsDev));
|
||||
assert_eq!(cfg.url.as_deref(), Some("https://models.dev/api.json"));
|
||||
assert_eq!(cfg.refresh_interval, Some(3600));
|
||||
assert_eq!(
|
||||
cfg.model_aliases
|
||||
.as_ref()
|
||||
.and_then(|m| m.get("openai/gpt-oss-120b"))
|
||||
.map(String::as_str),
|
||||
Some("openai/gpt-4o")
|
||||
);
|
||||
}
|
||||
other => panic!("expected cost source, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_deserialize_digitalocean_cost_source_without_url() {
|
||||
let yaml = r#"
|
||||
- type: cost
|
||||
provider: digitalocean
|
||||
"#;
|
||||
let sources: Vec<super::MetricsSource> = serde_yaml::from_str(yaml).unwrap();
|
||||
match &sources[0] {
|
||||
super::MetricsSource::Cost(cfg) => {
|
||||
assert!(matches!(cfg.provider, super::CostProvider::Digitalocean));
|
||||
assert_eq!(cfg.url, None);
|
||||
}
|
||||
other => panic!("expected cost source, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_into_models_filters_internal_providers() {
|
||||
let providers = vec![
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue