mirror of
https://github.com/katanemo/plano.git
synced 2026-04-25 00:36:34 +02:00
Adding support for wildcard models in the model_providers config (#696)
* cleaning up plano cli commands * adding support for wildcard model providers * fixing compile errors * fixing bugs related to default model provider, provider hint and duplicates in the model provider list * fixed cargo fmt issues * updating tests to always include the model id * using default for the prompt_gateway path * fixed the model name, as gpt-5-mini-2025-08-07 wasn't in the config * making sure that all aliases and models match the config * fixed the config generator to allow for base_url providers LLMs to include wildcard models * re-ran the models list utility and added a shell script to run it * updating docs to mention wildcard model providers * updated provider_models.json to yaml, added that file to our docs for reference * updating the build docs to use the new root-based build --------- Co-authored-by: Salman Paracha <salmanparacha@MacBook-Pro-342.local>
This commit is contained in:
parent
8428b06e22
commit
2941392ed1
42 changed files with 1748 additions and 202 deletions
95
crates/Cargo.lock
generated
95
crates/Cargo.lock
generated
|
|
@ -459,6 +459,35 @@ dependencies = [
|
|||
"urlencoding",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cookie"
|
||||
version = "0.18.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4ddef33a339a91ea89fb53151bd0a4689cfce27055c291dfa69945475d22c747"
|
||||
dependencies = [
|
||||
"percent-encoding",
|
||||
"time",
|
||||
"version_check",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cookie_store"
|
||||
version = "0.22.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3fc4bff745c9b4c7fb1e97b25d13153da2bc7796260141df62378998d070207f"
|
||||
dependencies = [
|
||||
"cookie",
|
||||
"document-features",
|
||||
"idna",
|
||||
"indexmap 2.9.0",
|
||||
"log",
|
||||
"serde",
|
||||
"serde_derive",
|
||||
"serde_json",
|
||||
"time",
|
||||
"url",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "core-foundation"
|
||||
version = "0.9.4"
|
||||
|
|
@ -628,6 +657,15 @@ dependencies = [
|
|||
"syn 2.0.101",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "document-features"
|
||||
version = "0.2.12"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d4b8a88685455ed29a21542a33abd9cb6510b6b129abadabdcef0f4c55bc8f61"
|
||||
dependencies = [
|
||||
"litrs",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "duration-string"
|
||||
version = "0.3.0"
|
||||
|
|
@ -999,11 +1037,14 @@ version = "0.1.0"
|
|||
dependencies = [
|
||||
"aws-smithy-eventstream",
|
||||
"bytes",
|
||||
"chrono",
|
||||
"log",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"serde_with",
|
||||
"serde_yaml",
|
||||
"thiserror 2.0.12",
|
||||
"ureq",
|
||||
"uuid",
|
||||
]
|
||||
|
||||
|
|
@ -1479,6 +1520,12 @@ version = "0.8.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "241eaef5fd12c88705a01fc1066c48c4b36e0dd4377dcdc7ec3942cea7a69956"
|
||||
|
||||
[[package]]
|
||||
name = "litrs"
|
||||
version = "1.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "11d3d7f243d5c5a8b9bb5d6dd2b1602c0cb0b9db1621bafc7ed66e35ff9fe092"
|
||||
|
||||
[[package]]
|
||||
name = "llm_gateway"
|
||||
version = "0.1.0"
|
||||
|
|
@ -2417,6 +2464,7 @@ version = "0.23.27"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "730944ca083c1c233a75c09f199e973ca499344a2b7ba9e755c457e86fb4a321"
|
||||
dependencies = [
|
||||
"log",
|
||||
"once_cell",
|
||||
"ring",
|
||||
"rustls-pki-types",
|
||||
|
|
@ -3385,6 +3433,38 @@ version = "0.9.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1"
|
||||
|
||||
[[package]]
|
||||
name = "ureq"
|
||||
version = "3.1.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d39cb1dbab692d82a977c0392ffac19e188bd9186a9f32806f0aaa859d75585a"
|
||||
dependencies = [
|
||||
"base64 0.22.1",
|
||||
"cookie_store",
|
||||
"flate2",
|
||||
"log",
|
||||
"percent-encoding",
|
||||
"rustls 0.23.27",
|
||||
"rustls-pki-types",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"ureq-proto",
|
||||
"utf-8",
|
||||
"webpki-roots",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ureq-proto"
|
||||
version = "0.5.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d81f9efa9df032be5934a46a068815a10a042b494b6a58cb0a1a97bb5467ed6f"
|
||||
dependencies = [
|
||||
"base64 0.22.1",
|
||||
"http 1.3.1",
|
||||
"httparse",
|
||||
"log",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "url"
|
||||
version = "2.5.4"
|
||||
|
|
@ -3402,6 +3482,12 @@ version = "2.1.3"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da"
|
||||
|
||||
[[package]]
|
||||
name = "utf-8"
|
||||
version = "0.7.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9"
|
||||
|
||||
[[package]]
|
||||
name = "utf8_iter"
|
||||
version = "1.0.4"
|
||||
|
|
@ -3578,6 +3664,15 @@ dependencies = [
|
|||
"wasm-bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "webpki-roots"
|
||||
version = "1.0.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "12bed680863276c63889429bfd6cab3b99943659923822de1c8a39c49e4d722c"
|
||||
dependencies = [
|
||||
"rustls-pki-types",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "whoami"
|
||||
version = "1.6.1"
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
use bytes::Bytes;
|
||||
use common::configuration::{LlmProvider, ModelAlias};
|
||||
use common::configuration::ModelAlias;
|
||||
use common::consts::{
|
||||
ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER, REQUEST_ID_HEADER, TRACE_PARENT_HEADER,
|
||||
};
|
||||
use common::llm_providers::LlmProviders;
|
||||
use common::traces::TraceCollector;
|
||||
use hermesllm::apis::openai_responses::InputParam;
|
||||
use hermesllm::clients::{SupportedAPIsFromClient, SupportedUpstreamAPIs};
|
||||
|
|
@ -38,7 +39,7 @@ pub async fn llm_chat(
|
|||
router_service: Arc<RouterService>,
|
||||
full_qualified_llm_provider_url: String,
|
||||
model_aliases: Arc<Option<HashMap<String, ModelAlias>>>,
|
||||
llm_providers: Arc<RwLock<Vec<LlmProvider>>>,
|
||||
llm_providers: Arc<RwLock<LlmProviders>>,
|
||||
trace_collector: Arc<TraceCollector>,
|
||||
state_storage: Option<Arc<dyn StateStorage>>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
|
|
@ -123,6 +124,27 @@ pub async fn llm_chat(
|
|||
let is_streaming_request = client_request.is_streaming();
|
||||
let resolved_model = resolve_model_alias(&model_from_request, &model_aliases);
|
||||
|
||||
// Validate that the requested model exists in configuration
|
||||
// This matches the validation in llm_gateway routing.rs
|
||||
if llm_providers.read().await.get(&resolved_model).is_none() {
|
||||
let err_msg = format!(
|
||||
"Model '{}' not found in configured providers",
|
||||
resolved_model
|
||||
);
|
||||
warn!("[PLANO_REQ_ID:{}] | FAILURE | {}", request_id, err_msg);
|
||||
let mut bad_request = Response::new(full(err_msg));
|
||||
*bad_request.status_mut() = StatusCode::BAD_REQUEST;
|
||||
return Ok(bad_request);
|
||||
}
|
||||
|
||||
// Handle provider/model slug format (e.g., "openai/gpt-4")
|
||||
// Extract just the model name for upstream (providers don't understand the slug)
|
||||
let model_name_only = if let Some((_, model)) = resolved_model.split_once('/') {
|
||||
model.to_string()
|
||||
} else {
|
||||
resolved_model.clone()
|
||||
};
|
||||
|
||||
// Extract tool names and user message preview for span attributes
|
||||
let tool_names = client_request.get_tool_names();
|
||||
let user_message_preview = client_request
|
||||
|
|
@ -132,7 +154,9 @@ pub async fn llm_chat(
|
|||
// Extract messages for signal analysis (clone before moving client_request)
|
||||
let messages_for_signals = client_request.get_messages();
|
||||
|
||||
client_request.set_model(resolved_model.clone());
|
||||
// Set the model to just the model name (without provider prefix)
|
||||
// This ensures upstream receives "gpt-4" not "openai/gpt-4"
|
||||
client_request.set_model(model_name_only.clone());
|
||||
if client_request.remove_metadata_key("archgw_preference_config") {
|
||||
debug!(
|
||||
"[PLANO_REQ_ID:{}] Removed archgw_preference_config from metadata",
|
||||
|
|
@ -240,11 +264,20 @@ pub async fn llm_chat(
|
|||
}
|
||||
};
|
||||
|
||||
let model_name = routing_result.model_name;
|
||||
// Determine final model to use
|
||||
// Router returns "none" as a sentinel value when it doesn't select a specific model
|
||||
let router_selected_model = routing_result.model_name;
|
||||
let model_name = if router_selected_model != "none" {
|
||||
// Router selected a specific model via routing preferences
|
||||
router_selected_model
|
||||
} else {
|
||||
// Router returned "none" sentinel, use validated resolved_model from request
|
||||
resolved_model.clone()
|
||||
};
|
||||
|
||||
debug!(
|
||||
"[PLANO_REQ_ID:{}] | ARCH_ROUTER URL | {}, Resolved Model: {}",
|
||||
request_id, full_qualified_llm_provider_url, model_name
|
||||
"[PLANO_REQ_ID:{}] | ARCH_ROUTER URL | {}, Provider Hint: {}, Model for upstream: {}",
|
||||
request_id, full_qualified_llm_provider_url, model_name, model_name_only
|
||||
);
|
||||
|
||||
request_headers.insert(
|
||||
|
|
@ -389,7 +422,7 @@ async fn build_llm_span(
|
|||
tool_names: Option<Vec<String>>,
|
||||
user_message_preview: Option<String>,
|
||||
temperature: Option<f32>,
|
||||
llm_providers: &Arc<RwLock<Vec<LlmProvider>>>,
|
||||
llm_providers: &Arc<RwLock<LlmProviders>>,
|
||||
) -> common::traces::Span {
|
||||
use crate::tracing::{http, llm, OperationNameBuilder};
|
||||
use common::traces::{parse_traceparent, SpanBuilder, SpanKind};
|
||||
|
|
@ -462,7 +495,7 @@ async fn build_llm_span(
|
|||
/// Looks up provider configuration, gets the ProviderId and base_url_path_prefix,
|
||||
/// then uses target_endpoint_for_provider to calculate the correct upstream path.
|
||||
async fn get_upstream_path(
|
||||
llm_providers: &Arc<RwLock<Vec<LlmProvider>>>,
|
||||
llm_providers: &Arc<RwLock<LlmProviders>>,
|
||||
model_name: &str,
|
||||
request_path: &str,
|
||||
resolved_model: &str,
|
||||
|
|
@ -485,25 +518,21 @@ async fn get_upstream_path(
|
|||
|
||||
/// Helper function to get provider info (ProviderId and base_url_path_prefix)
|
||||
async fn get_provider_info(
|
||||
llm_providers: &Arc<RwLock<Vec<LlmProvider>>>,
|
||||
llm_providers: &Arc<RwLock<LlmProviders>>,
|
||||
model_name: &str,
|
||||
) -> (hermesllm::ProviderId, Option<String>) {
|
||||
let providers_lock = llm_providers.read().await;
|
||||
|
||||
// First, try to find by model name or provider name
|
||||
let provider = providers_lock.iter().find(|p| {
|
||||
p.model.as_ref().map(|m| m == model_name).unwrap_or(false) || p.name == model_name
|
||||
});
|
||||
|
||||
if let Some(provider) = provider {
|
||||
// Try to find by model name or provider name using LlmProviders::get
|
||||
// This handles both "gpt-4" and "openai/gpt-4" formats
|
||||
if let Some(provider) = providers_lock.get(model_name) {
|
||||
let provider_id = provider.provider_interface.to_provider_id();
|
||||
let prefix = provider.base_url_path_prefix.clone();
|
||||
return (provider_id, prefix);
|
||||
}
|
||||
|
||||
let default_provider = providers_lock.iter().find(|p| p.default.unwrap_or(false));
|
||||
|
||||
if let Some(provider) = default_provider {
|
||||
// Fall back to default provider
|
||||
if let Some(provider) = providers_lock.default() {
|
||||
let provider_id = provider.provider_interface.to_provider_id();
|
||||
let prefix = provider.base_url_path_prefix.clone();
|
||||
(provider_id, prefix)
|
||||
|
|
|
|||
|
|
@ -1,19 +1,17 @@
|
|||
use bytes::Bytes;
|
||||
use common::configuration::{IntoModels, LlmProvider};
|
||||
use hermesllm::apis::openai::Models;
|
||||
use common::llm_providers::LlmProviders;
|
||||
use http_body_util::{combinators::BoxBody, BodyExt, Full};
|
||||
use hyper::{Response, StatusCode};
|
||||
use serde_json;
|
||||
use std::sync::Arc;
|
||||
|
||||
pub async fn list_models(
|
||||
llm_providers: Arc<tokio::sync::RwLock<Vec<LlmProvider>>>,
|
||||
llm_providers: Arc<tokio::sync::RwLock<LlmProviders>>,
|
||||
) -> Response<BoxBody<Bytes, hyper::Error>> {
|
||||
let prov = llm_providers.read().await;
|
||||
let providers = prov.clone();
|
||||
let openai_models: Models = providers.into_models();
|
||||
let models = prov.to_models();
|
||||
|
||||
match serde_json::to_string(&openai_models) {
|
||||
match serde_json::to_string(&models) {
|
||||
Ok(json) => {
|
||||
let body = Full::new(Bytes::from(json))
|
||||
.map_err(|never| match never {})
|
||||
|
|
|
|||
|
|
@ -151,16 +151,15 @@ pub async fn router_chat_get_upstream_model(
|
|||
Ok(RoutingResult { model_name })
|
||||
}
|
||||
None => {
|
||||
// No route determined, use default model from request
|
||||
// No route determined, return sentinel value "none"
|
||||
// This signals to llm.rs to use the original validated request model
|
||||
info!(
|
||||
"[PLANO_REQ_ID: {}] | ROUTER_REQ | No route determined, using default model from request: {}",
|
||||
request_id,
|
||||
chat_request.model
|
||||
"[PLANO_REQ_ID: {}] | ROUTER_REQ | No route determined, returning sentinel 'none'",
|
||||
request_id
|
||||
);
|
||||
|
||||
let default_model = chat_request.model.clone();
|
||||
let mut attrs = HashMap::new();
|
||||
attrs.insert("route.selected_model".to_string(), default_model.clone());
|
||||
attrs.insert("route.selected_model".to_string(), "none".to_string());
|
||||
record_routing_span(
|
||||
trace_collector,
|
||||
traceparent,
|
||||
|
|
@ -171,7 +170,7 @@ pub async fn router_chat_get_upstream_model(
|
|||
.await;
|
||||
|
||||
Ok(RoutingResult {
|
||||
model_name: default_model,
|
||||
model_name: "none".to_string(),
|
||||
})
|
||||
}
|
||||
},
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ use common::configuration::{Agent, Configuration};
|
|||
use common::consts::{
|
||||
CHAT_COMPLETIONS_PATH, MESSAGES_PATH, OPENAI_RESPONSES_API_PATH, PLANO_ORCHESTRATOR_MODEL_NAME,
|
||||
};
|
||||
use common::llm_providers::LlmProviders;
|
||||
use common::traces::TraceCollector;
|
||||
use http_body_util::{combinators::BoxBody, BodyExt, Empty};
|
||||
use hyper::body::Incoming;
|
||||
|
|
@ -76,7 +77,10 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
.cloned()
|
||||
.collect();
|
||||
|
||||
let llm_providers = Arc::new(RwLock::new(arch_config.model_providers.clone()));
|
||||
// Create expanded provider list for /v1/models endpoint
|
||||
let llm_providers = LlmProviders::try_from(arch_config.model_providers.clone())
|
||||
.expect("Failed to create LlmProviders");
|
||||
let llm_providers = Arc::new(RwLock::new(llm_providers));
|
||||
let combined_agents_filters_list = Arc::new(RwLock::new(Some(all_agents)));
|
||||
let listeners = Arc::new(RwLock::new(arch_config.listeners.clone()));
|
||||
let llm_provider_url =
|
||||
|
|
|
|||
|
|
@ -255,7 +255,8 @@ impl LlmProviderType {
|
|||
/// Get the ProviderId for this LlmProviderType
|
||||
/// Used with the new function-based hermesllm API
|
||||
pub fn to_provider_id(&self) -> hermesllm::ProviderId {
|
||||
hermesllm::ProviderId::from(self.to_string().as_str())
|
||||
hermesllm::ProviderId::try_from(self.to_string().as_str())
|
||||
.expect("LlmProviderType should always map to a valid ProviderId")
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,24 +1,84 @@
|
|||
use crate::configuration::LlmProvider;
|
||||
use hermesllm::providers::ProviderId;
|
||||
use std::collections::HashMap;
|
||||
use std::rc::Rc;
|
||||
use std::sync::Arc;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct LlmProviders {
|
||||
providers: HashMap<String, Rc<LlmProvider>>,
|
||||
default: Option<Rc<LlmProvider>>,
|
||||
providers: HashMap<String, Arc<LlmProvider>>,
|
||||
default: Option<Arc<LlmProvider>>,
|
||||
/// Wildcard providers: maps provider prefix to base provider config
|
||||
/// e.g., "openai" -> LlmProvider for "openai/*"
|
||||
wildcard_providers: HashMap<String, Arc<LlmProvider>>,
|
||||
}
|
||||
|
||||
impl LlmProviders {
|
||||
pub fn iter(&self) -> std::collections::hash_map::Iter<'_, String, Rc<LlmProvider>> {
|
||||
pub fn iter(&self) -> std::collections::hash_map::Iter<'_, String, Arc<LlmProvider>> {
|
||||
self.providers.iter()
|
||||
}
|
||||
|
||||
pub fn default(&self) -> Option<Rc<LlmProvider>> {
|
||||
pub fn default(&self) -> Option<Arc<LlmProvider>> {
|
||||
self.default.clone()
|
||||
}
|
||||
/// Convert providers to OpenAI Models format for /v1/models endpoint
|
||||
/// Filters out internal models and duplicate entries (backward compatibility aliases)
|
||||
pub fn to_models(&self) -> hermesllm::apis::openai::Models {
|
||||
use hermesllm::apis::openai::{ModelDetail, ModelObject, Models};
|
||||
|
||||
pub fn get(&self, name: &str) -> Option<Rc<LlmProvider>> {
|
||||
self.providers.get(name).cloned()
|
||||
let data: Vec<ModelDetail> = self
|
||||
.providers
|
||||
.iter()
|
||||
.filter(|(key, provider)| {
|
||||
// Exclude internal models
|
||||
provider.internal != Some(true)
|
||||
// Only include canonical entries (key matches provider name)
|
||||
// This avoids duplicates from backward compatibility short names
|
||||
&& *key == &provider.name
|
||||
})
|
||||
.map(|(name, provider)| ModelDetail {
|
||||
id: name.clone(),
|
||||
object: Some("model".to_string()),
|
||||
created: 0,
|
||||
owned_by: provider.to_provider_id().to_string(),
|
||||
})
|
||||
.collect();
|
||||
|
||||
Models {
|
||||
object: ModelObject::List,
|
||||
data,
|
||||
}
|
||||
}
|
||||
pub fn get(&self, name: &str) -> Option<Arc<LlmProvider>> {
|
||||
// First try exact match
|
||||
if let Some(provider) = self.providers.get(name).cloned() {
|
||||
return Some(provider);
|
||||
}
|
||||
|
||||
// If name contains '/', it could be:
|
||||
// 1. A full model ID like "openai/gpt-4" that we need to lookup
|
||||
// 2. A provider/model slug that should match a wildcard provider
|
||||
if let Some((provider_prefix, model_name)) = name.split_once('/') {
|
||||
// Try to find the expanded model entry (e.g., "openai/gpt-4")
|
||||
let full_model_id = format!("{}/{}", provider_prefix, model_name);
|
||||
if let Some(provider) = self.providers.get(&full_model_id).cloned() {
|
||||
return Some(provider);
|
||||
}
|
||||
|
||||
// Try to find just the model name (for expanded wildcard entries)
|
||||
if let Some(provider) = self.providers.get(model_name).cloned() {
|
||||
return Some(provider);
|
||||
}
|
||||
|
||||
// Fall back to wildcard match (e.g., "openai/*")
|
||||
if let Some(wildcard_provider) = self.wildcard_providers.get(provider_prefix) {
|
||||
// Create a new provider with the specific model from the slug
|
||||
let mut specific_provider = (**wildcard_provider).clone();
|
||||
specific_provider.model = Some(model_name.to_string());
|
||||
return Some(Arc::new(specific_provider));
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -43,38 +103,235 @@ impl TryFrom<Vec<LlmProvider>> for LlmProviders {
|
|||
let mut llm_providers = LlmProviders {
|
||||
providers: HashMap::new(),
|
||||
default: None,
|
||||
wildcard_providers: HashMap::new(),
|
||||
};
|
||||
|
||||
for llm_provider in llm_providers_config {
|
||||
let llm_provider: Rc<LlmProvider> = Rc::new(llm_provider);
|
||||
if llm_provider.default.unwrap_or_default() {
|
||||
match llm_providers.default {
|
||||
Some(_) => return Err(LlmProvidersNewError::MoreThanOneDefault),
|
||||
None => llm_providers.default = Some(Rc::clone(&llm_provider)),
|
||||
}
|
||||
}
|
||||
// Track specific (non-wildcard) provider names to detect true duplicates
|
||||
let mut specific_provider_names = std::collections::HashSet::new();
|
||||
|
||||
// Insert and check that there is no other provider with the same name.
|
||||
let name = llm_provider.name.clone();
|
||||
if llm_providers
|
||||
.providers
|
||||
.insert(name.clone(), Rc::clone(&llm_provider))
|
||||
.is_some()
|
||||
{
|
||||
return Err(LlmProvidersNewError::DuplicateName(name));
|
||||
}
|
||||
// Track specific models that should be excluded from wildcard expansion
|
||||
// Maps provider_prefix -> Set of model names (e.g., "anthropic" -> {"claude-sonnet-4-20250514"})
|
||||
let mut specific_models_by_provider: HashMap<String, std::collections::HashSet<String>> =
|
||||
HashMap::new();
|
||||
|
||||
// also add model_id as key for provider lookup
|
||||
if let Some(model) = llm_provider.model.clone() {
|
||||
if llm_providers
|
||||
.providers
|
||||
.insert(model, llm_provider)
|
||||
.is_some()
|
||||
{
|
||||
return Err(LlmProvidersNewError::DuplicateName(name));
|
||||
// First pass: collect all specific model configurations
|
||||
for llm_provider in &llm_providers_config {
|
||||
let is_wildcard = llm_provider
|
||||
.model
|
||||
.as_ref()
|
||||
.map(|m| m == "*" || m.ends_with("/*"))
|
||||
.unwrap_or(false);
|
||||
|
||||
if !is_wildcard {
|
||||
// Check if this is a provider/model format
|
||||
if let Some((provider_prefix, model_name)) = llm_provider.name.split_once('/') {
|
||||
specific_models_by_provider
|
||||
.entry(provider_prefix.to_string())
|
||||
.or_default()
|
||||
.insert(model_name.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for llm_provider in llm_providers_config {
|
||||
let llm_provider: Arc<LlmProvider> = Arc::new(llm_provider);
|
||||
|
||||
if llm_provider.default.unwrap_or_default() {
|
||||
match llm_providers.default {
|
||||
Some(_) => return Err(LlmProvidersNewError::MoreThanOneDefault),
|
||||
None => llm_providers.default = Some(Arc::clone(&llm_provider)),
|
||||
}
|
||||
}
|
||||
|
||||
let name = llm_provider.name.clone();
|
||||
|
||||
// Check if this is a wildcard provider (model is "*" or ends with "/*")
|
||||
let is_wildcard = llm_provider
|
||||
.model
|
||||
.as_ref()
|
||||
.map(|m| m == "*" || m.ends_with("/*"))
|
||||
.unwrap_or(false);
|
||||
|
||||
if is_wildcard {
|
||||
// Extract provider prefix from name
|
||||
// e.g., "openai/*" -> "openai"
|
||||
let provider_prefix = name.trim_end_matches("/*").trim_end_matches('*');
|
||||
|
||||
// For wildcard providers, we:
|
||||
// 1. Store the base config in wildcard_providers for runtime matching
|
||||
// 2. Optionally expand to all known models if available
|
||||
|
||||
llm_providers
|
||||
.wildcard_providers
|
||||
.insert(provider_prefix.to_string(), Arc::clone(&llm_provider));
|
||||
|
||||
// Try to expand wildcard using ProviderId models
|
||||
if let Ok(provider_id) = ProviderId::try_from(provider_prefix) {
|
||||
let models = provider_id.models();
|
||||
|
||||
// Get the set of specific models to exclude for this provider
|
||||
let models_to_exclude = specific_models_by_provider
|
||||
.get(provider_prefix)
|
||||
.cloned()
|
||||
.unwrap_or_default();
|
||||
|
||||
if !models.is_empty() {
|
||||
let excluded_count = models_to_exclude.len();
|
||||
let total_models = models.len();
|
||||
|
||||
log::info!(
|
||||
"Expanding wildcard provider '{}' to {} models{}",
|
||||
provider_prefix,
|
||||
total_models - excluded_count,
|
||||
if excluded_count > 0 {
|
||||
format!(" (excluding {} specifically configured)", excluded_count)
|
||||
} else {
|
||||
String::new()
|
||||
}
|
||||
);
|
||||
|
||||
// Create a provider entry for each model (except those specifically configured)
|
||||
for model_name in models {
|
||||
// Skip this model if it has a specific configuration
|
||||
if models_to_exclude.contains(&model_name) {
|
||||
log::debug!(
|
||||
"Skipping wildcard expansion for '{}/{}' - specific configuration exists",
|
||||
provider_prefix,
|
||||
model_name
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
let full_model_id = format!("{}/{}", provider_prefix, model_name);
|
||||
|
||||
// Create a new provider with the specific model
|
||||
let mut expanded_provider = (*llm_provider).clone();
|
||||
expanded_provider.model = Some(model_name.clone());
|
||||
expanded_provider.name = full_model_id.clone();
|
||||
|
||||
let expanded_rc = Arc::new(expanded_provider);
|
||||
|
||||
// Insert with full model ID as key
|
||||
llm_providers
|
||||
.providers
|
||||
.insert(full_model_id.clone(), Arc::clone(&expanded_rc));
|
||||
|
||||
// Also insert with just model name for backward compatibility
|
||||
llm_providers.providers.insert(model_name, expanded_rc);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
log::warn!(
|
||||
"Wildcard provider '{}' specified but no models found in registry. \
|
||||
Will match dynamically at runtime.",
|
||||
provider_prefix
|
||||
);
|
||||
}
|
||||
} else {
|
||||
// Non-wildcard provider - specific configuration
|
||||
// Check for duplicate specific entries (not allowed)
|
||||
if specific_provider_names.contains(&name) {
|
||||
return Err(LlmProvidersNewError::DuplicateName(name));
|
||||
}
|
||||
specific_provider_names.insert(name.clone());
|
||||
|
||||
// This specific configuration takes precedence over any wildcard expansion
|
||||
// The wildcard expansion already excluded this model (see first pass above)
|
||||
|
||||
log::debug!("Processing specific provider configuration: {}", name);
|
||||
|
||||
// Insert with the provider name as key
|
||||
llm_providers
|
||||
.providers
|
||||
.insert(name.clone(), Arc::clone(&llm_provider));
|
||||
|
||||
// Also add model_id as key for provider lookup
|
||||
if let Some(model) = llm_provider.model.clone() {
|
||||
llm_providers.providers.insert(model, llm_provider);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(llm_providers)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::configuration::LlmProviderType;
|
||||
|
||||
fn create_test_provider(name: &str, model: Option<String>) -> LlmProvider {
|
||||
LlmProvider {
|
||||
name: name.to_string(),
|
||||
model,
|
||||
access_key: None,
|
||||
endpoint: None,
|
||||
cluster_name: None,
|
||||
provider_interface: LlmProviderType::OpenAI,
|
||||
default: None,
|
||||
base_url_path_prefix: None,
|
||||
port: None,
|
||||
rate_limits: None,
|
||||
usage: None,
|
||||
routing_preferences: None,
|
||||
internal: None,
|
||||
stream: None,
|
||||
passthrough_auth: None,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_static_provider_lookup() {
|
||||
// Test 1: Statically defined provider - should be findable by model or provider name
|
||||
let providers = vec![create_test_provider("my-openai", Some("gpt-4".to_string()))];
|
||||
let llm_providers = LlmProviders::try_from(providers).unwrap();
|
||||
|
||||
// Should find by model name
|
||||
let result = llm_providers.get("gpt-4");
|
||||
assert!(result.is_some());
|
||||
assert_eq!(result.unwrap().name, "my-openai");
|
||||
|
||||
// Should also find by provider name
|
||||
let result = llm_providers.get("my-openai");
|
||||
assert!(result.is_some());
|
||||
assert_eq!(result.unwrap().name, "my-openai");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_wildcard_provider_with_known_model() {
|
||||
// Test 2: Wildcard provider that expands to OpenAI models
|
||||
let providers = vec![create_test_provider("openai/*", Some("*".to_string()))];
|
||||
let llm_providers = LlmProviders::try_from(providers).unwrap();
|
||||
|
||||
// Should find via expanded wildcard entry
|
||||
let result = llm_providers.get("openai/gpt-4");
|
||||
let provider = result.unwrap();
|
||||
assert_eq!(provider.name, "openai/gpt-4");
|
||||
assert_eq!(provider.model.as_ref().unwrap(), "gpt-4");
|
||||
|
||||
// Should also be able to find by just model name (from expansion)
|
||||
let result = llm_providers.get("gpt-4");
|
||||
assert_eq!(result.unwrap().model.as_ref().unwrap(), "gpt-4");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_custom_wildcard_provider_with_full_slug() {
|
||||
// Test 3: Custom wildcard provider with full slug offered
|
||||
let providers = vec![create_test_provider(
|
||||
"custom-provider/*",
|
||||
Some("*".to_string()),
|
||||
)];
|
||||
let llm_providers = LlmProviders::try_from(providers).unwrap();
|
||||
|
||||
// Should match via wildcard fallback and extract model name from slug
|
||||
let result = llm_providers.get("custom-provider/custom-model");
|
||||
let provider = result.unwrap();
|
||||
assert_eq!(provider.model.as_ref().unwrap(), "custom-model");
|
||||
|
||||
// Wildcard should be stored
|
||||
assert!(llm_providers
|
||||
.wildcard_providers
|
||||
.contains_key("custom-provider"));
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,10 +1,9 @@
|
|||
use std::rc::Rc;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::{configuration, llm_providers::LlmProviders};
|
||||
use configuration::LlmProvider;
|
||||
use rand::{seq::IteratorRandom, thread_rng};
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum ProviderHint {
|
||||
Default,
|
||||
Name(String),
|
||||
|
|
@ -22,33 +21,14 @@ impl From<String> for ProviderHint {
|
|||
pub fn get_llm_provider(
|
||||
llm_providers: &LlmProviders,
|
||||
provider_hint: Option<ProviderHint>,
|
||||
) -> Rc<LlmProvider> {
|
||||
let maybe_provider = provider_hint.and_then(|hint| match hint {
|
||||
ProviderHint::Default => llm_providers.default(),
|
||||
// FIXME: should a non-existent name in the hint be more explicit? i.e, return a BAD_REQUEST?
|
||||
ProviderHint::Name(name) => llm_providers.get(&name),
|
||||
});
|
||||
|
||||
if let Some(provider) = maybe_provider {
|
||||
return provider;
|
||||
) -> Result<Arc<LlmProvider>, String> {
|
||||
match provider_hint {
|
||||
Some(ProviderHint::Default) => llm_providers
|
||||
.default()
|
||||
.ok_or_else(|| "No default provider configured".to_string()),
|
||||
Some(ProviderHint::Name(name)) => llm_providers
|
||||
.get(&name)
|
||||
.ok_or_else(|| format!("Model '{}' not found in configured providers", name)),
|
||||
None => Err("No model specified in request".to_string()),
|
||||
}
|
||||
|
||||
if llm_providers.default().is_some() {
|
||||
return llm_providers.default().unwrap();
|
||||
}
|
||||
|
||||
let mut rng = thread_rng();
|
||||
llm_providers
|
||||
.iter()
|
||||
.filter(|(_, provider)| {
|
||||
provider
|
||||
.model
|
||||
.as_ref()
|
||||
.map(|m| !m.starts_with("Arch"))
|
||||
.unwrap_or(true)
|
||||
})
|
||||
.choose(&mut rng)
|
||||
.expect("There should always be at least one non-Arch llm provider")
|
||||
.1
|
||||
.clone()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,12 +3,24 @@ name = "hermesllm"
|
|||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[[bin]]
|
||||
name = "fetch_models"
|
||||
path = "src/bin/fetch_models.rs"
|
||||
required-features = ["model-fetch"]
|
||||
|
||||
[dependencies]
|
||||
serde = {version = "1.0.219", features = ["derive"]}
|
||||
serde_json = "1.0.140"
|
||||
serde_yaml = "0.9.34-deprecated"
|
||||
serde_with = {version = "3.12.0", features = ["base64"]}
|
||||
thiserror = "2.0.12"
|
||||
aws-smithy-eventstream = "0.60"
|
||||
bytes = "1.10"
|
||||
uuid = { version = "1.11", features = ["v4"] }
|
||||
log = "0.4"
|
||||
chrono = { version = "0.4", optional = true }
|
||||
ureq = { version = "3.1", features = ["json"], optional = true }
|
||||
|
||||
[features]
|
||||
default = []
|
||||
model-fetch = ["ureq", "chrono"]
|
||||
|
|
|
|||
412
crates/hermesllm/src/bin/fetch_models.rs
Normal file
412
crates/hermesllm/src/bin/fetch_models.rs
Normal file
|
|
@ -0,0 +1,412 @@
|
|||
// Fetch latest provider models from canonical provider APIs and update provider_models.yaml
|
||||
// Usage:
|
||||
// Optional: OPENAI_API_KEY, ANTHROPIC_API_KEY, DEEPSEEK_API_KEY, GROK_API_KEY,
|
||||
// DASHSCOPE_API_KEY, MOONSHOT_API_KEY, ZHIPU_API_KEY, GOOGLE_API_KEY
|
||||
// Required: AWS CLI configured for Amazon Bedrock models
|
||||
// cargo run --bin fetch_models
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
fn main() {
|
||||
// Default to writing in the same directory as this source file
|
||||
let default_path = std::path::Path::new(file!())
|
||||
.parent()
|
||||
.unwrap()
|
||||
.join("provider_models.yaml");
|
||||
|
||||
let output_path = std::env::args()
|
||||
.nth(1)
|
||||
.unwrap_or_else(|| default_path.to_string_lossy().to_string());
|
||||
|
||||
println!("Fetching latest models from provider APIs...");
|
||||
|
||||
match fetch_all_models() {
|
||||
Ok(models) => {
|
||||
let yaml = serde_yaml::to_string(&models).expect("Failed to serialize models");
|
||||
|
||||
std::fs::write(&output_path, yaml).expect("Failed to write provider_models.yaml");
|
||||
|
||||
println!(
|
||||
"✓ Successfully updated {} providers ({} models) to {}",
|
||||
models.metadata.total_providers, models.metadata.total_models, output_path
|
||||
);
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("Error fetching models: {}", e);
|
||||
eprintln!("\nMake sure required tools are set up:");
|
||||
eprintln!(" AWS CLI configured for Bedrock (for Amazon models)");
|
||||
eprintln!(" export OPENAI_API_KEY=your-key-here # Optional");
|
||||
eprintln!(" export DEEPSEEK_API_KEY=your-key-here # Optional");
|
||||
eprintln!(" cargo run --bin fetch_models");
|
||||
std::process::exit(1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OpenAI-compatible API response (used by most providers)
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct OpenAICompatibleModel {
|
||||
id: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct OpenAICompatibleResponse {
|
||||
data: Vec<OpenAICompatibleModel>,
|
||||
}
|
||||
|
||||
// Google Gemini API response
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct GoogleModel {
|
||||
name: String,
|
||||
#[serde(rename = "supportedGenerationMethods")]
|
||||
supported_generation_methods: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct GoogleResponse {
|
||||
models: Vec<GoogleModel>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct ProviderModels {
|
||||
version: String,
|
||||
source: String,
|
||||
providers: HashMap<String, Vec<String>>,
|
||||
metadata: Metadata,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct Metadata {
|
||||
total_providers: usize,
|
||||
total_models: usize,
|
||||
last_updated: String,
|
||||
}
|
||||
|
||||
fn is_text_model(model_id: &str) -> bool {
|
||||
let id_lower = model_id.to_lowercase();
|
||||
|
||||
// Filter out known non-text models
|
||||
let non_text_patterns = [
|
||||
"embedding", // Embedding models
|
||||
"whisper", // Audio transcription
|
||||
"-tts", // Text-to-speech (with dash to avoid matching in middle of words)
|
||||
"tts-", // Text-to-speech prefix
|
||||
"dall-e", // Image generation
|
||||
"sora", // Video generation
|
||||
"moderation", // Moderation models
|
||||
"babbage", // Legacy completion models
|
||||
"davinci-002", // Legacy completion models
|
||||
"transcribe", // Audio transcription models
|
||||
"realtime", // Realtime audio models
|
||||
"audio", // Audio models (gpt-audio, gpt-audio-mini)
|
||||
"-image-", // Image generation models (grok-2-image-1212)
|
||||
"-ocr-", // OCR models
|
||||
"ocr-", // OCR models prefix
|
||||
"voxtral", // Audio/voice models
|
||||
];
|
||||
|
||||
// Additional pattern: models that are purely for image generation usually have "image" in the name
|
||||
// but we need to be careful not to filter vision models that can process images
|
||||
// Models like "gpt-image-1" or "chatgpt-image-latest" are image generators
|
||||
// Models like "grok-2-vision" or "gemini-vision" are vision models (text+image->text)
|
||||
|
||||
if non_text_patterns
|
||||
.iter()
|
||||
.any(|pattern| id_lower.contains(pattern))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// Filter models starting with "gpt-image" (image generators)
|
||||
if id_lower.contains("/gpt-image") || id_lower.contains("/chatgpt-image") {
|
||||
return false;
|
||||
}
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
fn fetch_openai_compatible_models(
|
||||
api_url: &str,
|
||||
api_key: &str,
|
||||
provider_prefix: &str,
|
||||
) -> Result<Vec<String>, Box<dyn std::error::Error>> {
|
||||
let response_body = ureq::get(api_url)
|
||||
.header("Authorization", &format!("Bearer {}", api_key))
|
||||
.call()?
|
||||
.body_mut()
|
||||
.read_to_string()?;
|
||||
|
||||
let response: OpenAICompatibleResponse = serde_json::from_str(&response_body)?;
|
||||
|
||||
Ok(response
|
||||
.data
|
||||
.into_iter()
|
||||
.filter(|m| is_text_model(&m.id))
|
||||
.map(|m| format!("{}/{}", provider_prefix, m.id))
|
||||
.collect())
|
||||
}
|
||||
|
||||
fn fetch_anthropic_models(api_key: &str) -> Result<Vec<String>, Box<dyn std::error::Error>> {
|
||||
let response_body = ureq::get("https://api.anthropic.com/v1/models")
|
||||
.header("x-api-key", api_key)
|
||||
.header("anthropic-version", "2023-06-01")
|
||||
.call()?
|
||||
.body_mut()
|
||||
.read_to_string()?;
|
||||
|
||||
let response: OpenAICompatibleResponse = serde_json::from_str(&response_body)?;
|
||||
|
||||
let dated_models: Vec<String> = response
|
||||
.data
|
||||
.into_iter()
|
||||
.filter(|m| is_text_model(&m.id))
|
||||
.map(|m| m.id)
|
||||
.collect();
|
||||
|
||||
let mut models: Vec<String> = Vec::new();
|
||||
|
||||
// Add both dated versions and their aliases (without the -YYYYMMDD suffix)
|
||||
for model_id in dated_models {
|
||||
// Add the full dated model ID
|
||||
models.push(format!("anthropic/{}", model_id));
|
||||
|
||||
// Generate alias by removing trailing -YYYYMMDD pattern
|
||||
// Pattern: ends with -YYYYMMDD where YYYY is year, MM is month, DD is day
|
||||
if let Some(date_pos) = model_id.rfind('-') {
|
||||
let potential_date = &model_id[date_pos + 1..];
|
||||
// Check if it's an 8-digit date (YYYYMMDD)
|
||||
if potential_date.len() == 8 && potential_date.chars().all(|c| c.is_ascii_digit()) {
|
||||
let alias = &model_id[..date_pos];
|
||||
let alias_full = format!("anthropic/{}", alias);
|
||||
// Only add if not already present
|
||||
if !models.contains(&alias_full) {
|
||||
models.push(alias_full);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(models)
|
||||
}
|
||||
|
||||
fn fetch_google_models(api_key: &str) -> Result<Vec<String>, Box<dyn std::error::Error>> {
|
||||
let api_url = format!(
|
||||
"https://generativelanguage.googleapis.com/v1beta/models?key={}",
|
||||
api_key
|
||||
);
|
||||
|
||||
let response_body = ureq::get(&api_url).call()?.body_mut().read_to_string()?;
|
||||
|
||||
let response: GoogleResponse = serde_json::from_str(&response_body)?;
|
||||
|
||||
// Only include models that support generateContent
|
||||
Ok(response
|
||||
.models
|
||||
.into_iter()
|
||||
.filter(|m| {
|
||||
m.supported_generation_methods
|
||||
.as_ref()
|
||||
.is_some_and(|methods| methods.contains(&"generateContent".to_string()))
|
||||
})
|
||||
.map(|m| {
|
||||
// Convert "models/gemini-pro" to "google/gemini-pro"
|
||||
let model_id = m.name.strip_prefix("models/").unwrap_or(&m.name);
|
||||
format!("google/{}", model_id)
|
||||
})
|
||||
.collect())
|
||||
}
|
||||
|
||||
fn fetch_bedrock_amazon_models() -> Result<Vec<String>, Box<dyn std::error::Error>> {
|
||||
// Use AWS CLI to fetch Amazon models from Bedrock
|
||||
let output = std::process::Command::new("aws")
|
||||
.args([
|
||||
"bedrock",
|
||||
"list-foundation-models",
|
||||
"--by-provider",
|
||||
"amazon",
|
||||
"--by-output-modality",
|
||||
"TEXT",
|
||||
"--no-cli-pager",
|
||||
"--output",
|
||||
"json",
|
||||
])
|
||||
.output()?;
|
||||
|
||||
if !output.status.success() {
|
||||
return Err(format!(
|
||||
"AWS CLI command failed: {}",
|
||||
String::from_utf8_lossy(&output.stderr)
|
||||
)
|
||||
.into());
|
||||
}
|
||||
|
||||
let response_body = String::from_utf8(output.stdout)?;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct BedrockModelSummary {
|
||||
#[serde(rename = "modelId")]
|
||||
model_id: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct BedrockResponse {
|
||||
#[serde(rename = "modelSummaries")]
|
||||
model_summaries: Vec<BedrockModelSummary>,
|
||||
}
|
||||
|
||||
let bedrock_response: BedrockResponse = serde_json::from_str(&response_body)?;
|
||||
|
||||
// Filter out embedding, image generation, and rerank models
|
||||
let amazon_models: Vec<String> = bedrock_response
|
||||
.model_summaries
|
||||
.into_iter()
|
||||
.filter(|model| {
|
||||
let id_lower = model.model_id.to_lowercase();
|
||||
!id_lower.contains("embed")
|
||||
&& !id_lower.contains("image")
|
||||
&& !id_lower.contains("rerank")
|
||||
})
|
||||
.map(|m| format!("amazon/{}", m.model_id))
|
||||
.collect();
|
||||
|
||||
Ok(amazon_models)
|
||||
}
|
||||
|
||||
fn fetch_all_models() -> Result<ProviderModels, Box<dyn std::error::Error>> {
|
||||
let mut providers: HashMap<String, Vec<String>> = HashMap::new();
|
||||
let mut errors: Vec<String> = Vec::new();
|
||||
|
||||
// Configuration: provider name, env var, API URL, prefix for model IDs
|
||||
let provider_configs = vec![
|
||||
(
|
||||
"openai",
|
||||
"OPENAI_API_KEY",
|
||||
"https://api.openai.com/v1/models",
|
||||
"openai",
|
||||
),
|
||||
(
|
||||
"mistralai",
|
||||
"MISTRAL_API_KEY",
|
||||
"https://api.mistral.ai/v1/models",
|
||||
"mistralai",
|
||||
),
|
||||
(
|
||||
"deepseek",
|
||||
"DEEPSEEK_API_KEY",
|
||||
"https://api.deepseek.com/v1/models",
|
||||
"deepseek",
|
||||
),
|
||||
("x-ai", "GROK_API_KEY", "https://api.x.ai/v1/models", "x-ai"),
|
||||
(
|
||||
"moonshotai",
|
||||
"MOONSHOT_API_KEY",
|
||||
"https://api.moonshot.ai/v1/models",
|
||||
"moonshotai",
|
||||
),
|
||||
(
|
||||
"qwen",
|
||||
"DASHSCOPE_API_KEY",
|
||||
"https://dashscope-intl.aliyuncs.com/compatible-mode/v1/models",
|
||||
"qwen",
|
||||
),
|
||||
(
|
||||
"z-ai",
|
||||
"ZHIPU_API_KEY",
|
||||
"https://open.bigmodel.cn/api/paas/v4/models",
|
||||
"z-ai",
|
||||
),
|
||||
];
|
||||
|
||||
// Fetch from OpenAI-compatible providers
|
||||
for (provider_name, env_var, api_url, prefix) in provider_configs {
|
||||
if let Ok(api_key) = std::env::var(env_var) {
|
||||
match fetch_openai_compatible_models(api_url, &api_key, prefix) {
|
||||
Ok(models) => {
|
||||
println!(" ✓ {}: {} models", provider_name, models.len());
|
||||
providers.insert(provider_name.to_string(), models);
|
||||
}
|
||||
Err(e) => {
|
||||
let err_msg = format!(" ✗ {}: {}", provider_name, e);
|
||||
eprintln!("{}", err_msg);
|
||||
errors.push(err_msg);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
println!(" ⊘ {}: {} not set (skipped)", provider_name, env_var);
|
||||
}
|
||||
}
|
||||
|
||||
// Fetch Anthropic models (different authentication)
|
||||
if let Ok(api_key) = std::env::var("ANTHROPIC_API_KEY") {
|
||||
match fetch_anthropic_models(&api_key) {
|
||||
Ok(models) => {
|
||||
println!(" ✓ anthropic: {} models", models.len());
|
||||
providers.insert("anthropic".to_string(), models);
|
||||
}
|
||||
Err(e) => {
|
||||
let err_msg = format!(" ✗ anthropic: {}", e);
|
||||
eprintln!("{}", err_msg);
|
||||
errors.push(err_msg);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
println!(" ⊘ anthropic: ANTHROPIC_API_KEY not set (skipped)");
|
||||
}
|
||||
|
||||
// Fetch Google models (different API format)
|
||||
if let Ok(api_key) = std::env::var("GOOGLE_API_KEY") {
|
||||
match fetch_google_models(&api_key) {
|
||||
Ok(models) => {
|
||||
println!(" ✓ google: {} models", models.len());
|
||||
providers.insert("google".to_string(), models);
|
||||
}
|
||||
Err(e) => {
|
||||
let err_msg = format!(" ✗ google: {}", e);
|
||||
eprintln!("{}", err_msg);
|
||||
errors.push(err_msg);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
println!(" ⊘ google: GOOGLE_API_KEY not set (skipped)");
|
||||
}
|
||||
|
||||
// Fetch Amazon models from AWS Bedrock
|
||||
match fetch_bedrock_amazon_models() {
|
||||
Ok(models) => {
|
||||
println!(" ✓ amazon: {} models (via AWS Bedrock)", models.len());
|
||||
providers.insert("amazon".to_string(), models);
|
||||
}
|
||||
Err(e) => {
|
||||
let err_msg = format!(" ✗ amazon: {} (AWS Bedrock required)", e);
|
||||
eprintln!("{}", err_msg);
|
||||
errors.push(err_msg);
|
||||
}
|
||||
}
|
||||
|
||||
if providers.is_empty() {
|
||||
return Err("No models fetched from any provider. Check API keys.".into());
|
||||
}
|
||||
|
||||
let total_providers = providers.len();
|
||||
let total_models: usize = providers.values().map(|v| v.len()).sum();
|
||||
|
||||
println!(
|
||||
"\n✅ Successfully fetched models from {} providers",
|
||||
total_providers
|
||||
);
|
||||
if !errors.is_empty() {
|
||||
println!("⚠️ {} providers failed", errors.len());
|
||||
}
|
||||
|
||||
Ok(ProviderModels {
|
||||
version: "1.0".to_string(),
|
||||
source: "canonical-apis".to_string(),
|
||||
providers,
|
||||
metadata: Metadata {
|
||||
total_providers,
|
||||
total_models,
|
||||
last_updated: chrono::Utc::now().to_rfc3339(),
|
||||
},
|
||||
})
|
||||
}
|
||||
315
crates/hermesllm/src/bin/provider_models.yaml
Normal file
315
crates/hermesllm/src/bin/provider_models.yaml
Normal file
|
|
@ -0,0 +1,315 @@
|
|||
version: '1.0'
|
||||
source: canonical-apis
|
||||
providers:
|
||||
qwen:
|
||||
- qwen/qwen3-max-2026-01-23
|
||||
- qwen/qwen-plus-character
|
||||
- qwen/qwen-flash-character
|
||||
- qwen/qwen-flash
|
||||
- qwen/qwen3-vl-plus-2025-12-19
|
||||
- qwen/qwen3-omni-flash-2025-12-01
|
||||
- qwen/qwen3-livetranslate-flash-2025-12-01
|
||||
- qwen/qwen3-livetranslate-flash
|
||||
- qwen/qwen-mt-lite
|
||||
- qwen/qwen-plus-2025-12-01
|
||||
- qwen/qwen-mt-flash
|
||||
- qwen/ccai-pro
|
||||
- qwen/tongyi-tingwu-slp
|
||||
- qwen/qwen3-vl-flash
|
||||
- qwen/qwen3-vl-flash-2025-10-15
|
||||
- qwen/qwen3-omni-flash
|
||||
- qwen/qwen3-omni-flash-2025-09-15
|
||||
- qwen/qwen3-omni-30b-a3b-captioner
|
||||
- qwen/qwen2.5-7b-instruct
|
||||
- qwen/qwen2.5-14b-instruct
|
||||
- qwen/qwen2.5-32b-instruct
|
||||
- qwen/qwen2.5-72b-instruct
|
||||
- qwen/qwen2.5-14b-instruct-1m
|
||||
- qwen/qwen2.5-7b-instruct-1m
|
||||
- qwen/qwen-max-2025-01-25
|
||||
- qwen/qwen-max-latest
|
||||
- qwen/qwen-turbo-2024-11-01
|
||||
- qwen/qwen-turbo-latest
|
||||
- qwen/qwen-plus-latest
|
||||
- qwen/qwen-plus-2025-01-25
|
||||
- qwen/qwq-plus-2025-03-05
|
||||
- qwen/qwen-mt-turbo
|
||||
- qwen/qwen-mt-plus
|
||||
- qwen/qwen-coder-plus
|
||||
- qwen/qwq-plus
|
||||
- qwen/qwen2.5-vl-32b-instruct
|
||||
- qwen/qvq-max
|
||||
- qwen/qwen-omni-turbo
|
||||
- qwen/qwen3-8b
|
||||
- qwen/qwen3-30b-a3b
|
||||
- qwen/qwen3-235b-a22b
|
||||
- qwen/qwen-turbo-2025-04-28
|
||||
- qwen/qwen-plus-2025-04-28
|
||||
- qwen/qwen-vl-max-2025-04-08
|
||||
- qwen/qwen-vl-plus-2025-01-25
|
||||
- qwen/qwen-vl-plus-latest
|
||||
- qwen/qwen-vl-max-latest
|
||||
- qwen/qwen-vl-plus-2025-05-07
|
||||
- qwen/qwen3-coder-plus
|
||||
- qwen/qwen3-coder-480b-a35b-instruct
|
||||
- qwen/qwen3-235b-a22b-instruct-2507
|
||||
- qwen/qwen-plus-2025-07-14
|
||||
- qwen/qwen3-coder-plus-2025-07-22
|
||||
- qwen/qwen3-235b-a22b-thinking-2507
|
||||
- qwen/qwen3-coder-flash
|
||||
- qwen/qwen-vl-max
|
||||
- qwen/qwen-vl-max-2025-08-13
|
||||
- qwen/qwen3-max
|
||||
- qwen/qwen3-max-2025-09-23
|
||||
- qwen/qwen3-vl-plus
|
||||
- qwen/qwen3-vl-235b-a22b-instruct
|
||||
- qwen/qwen3-vl-235b-a22b-thinking
|
||||
- qwen/qwen3-30b-a3b-thinking-2507
|
||||
- qwen/qwen3-30b-a3b-instruct-2507
|
||||
- qwen/qwen3-14b
|
||||
- qwen/qwen3-32b
|
||||
- qwen/qwen3-0.6b
|
||||
- qwen/qwen3-4b
|
||||
- qwen/qwen3-1.7b
|
||||
- qwen/qwen-vl-plus
|
||||
- qwen/qwen3-coder-plus-2025-09-23
|
||||
- qwen/qwen3-vl-plus-2025-09-23
|
||||
- qwen/qwen-plus-2025-09-11
|
||||
- qwen/qwen3-next-80b-a3b-thinking
|
||||
- qwen/qwen3-next-80b-a3b-instruct
|
||||
- qwen/qwen3-max-preview
|
||||
- qwen/qwen2-7b-instruct
|
||||
- qwen/qwen-max
|
||||
- qwen/qwen-plus
|
||||
- qwen/qwen-turbo
|
||||
openai:
|
||||
- openai/gpt-4-0613
|
||||
- openai/gpt-4
|
||||
- openai/gpt-3.5-turbo
|
||||
- openai/gpt-5.2-codex
|
||||
- openai/gpt-3.5-turbo-instruct
|
||||
- openai/gpt-3.5-turbo-instruct-0914
|
||||
- openai/gpt-4-1106-preview
|
||||
- openai/gpt-3.5-turbo-1106
|
||||
- openai/gpt-4-0125-preview
|
||||
- openai/gpt-4-turbo-preview
|
||||
- openai/gpt-3.5-turbo-0125
|
||||
- openai/gpt-4-turbo
|
||||
- openai/gpt-4-turbo-2024-04-09
|
||||
- openai/gpt-4o
|
||||
- openai/gpt-4o-2024-05-13
|
||||
- openai/gpt-4o-mini-2024-07-18
|
||||
- openai/gpt-4o-mini
|
||||
- openai/gpt-4o-2024-08-06
|
||||
- openai/chatgpt-4o-latest
|
||||
- openai/o1-2024-12-17
|
||||
- openai/o1
|
||||
- openai/computer-use-preview
|
||||
- openai/o3-mini
|
||||
- openai/o3-mini-2025-01-31
|
||||
- openai/gpt-4o-2024-11-20
|
||||
- openai/computer-use-preview-2025-03-11
|
||||
- openai/gpt-4o-search-preview-2025-03-11
|
||||
- openai/gpt-4o-search-preview
|
||||
- openai/gpt-4o-mini-search-preview-2025-03-11
|
||||
- openai/gpt-4o-mini-search-preview
|
||||
- openai/o1-pro-2025-03-19
|
||||
- openai/o1-pro
|
||||
- openai/o3-2025-04-16
|
||||
- openai/o4-mini-2025-04-16
|
||||
- openai/o3
|
||||
- openai/o4-mini
|
||||
- openai/gpt-4.1-2025-04-14
|
||||
- openai/gpt-4.1
|
||||
- openai/gpt-4.1-mini-2025-04-14
|
||||
- openai/gpt-4.1-mini
|
||||
- openai/gpt-4.1-nano-2025-04-14
|
||||
- openai/gpt-4.1-nano
|
||||
- openai/codex-mini-latest
|
||||
- openai/o3-pro
|
||||
- openai/o3-pro-2025-06-10
|
||||
- openai/o4-mini-deep-research
|
||||
- openai/o3-deep-research
|
||||
- openai/o3-deep-research-2025-06-26
|
||||
- openai/o4-mini-deep-research-2025-06-26
|
||||
- openai/gpt-5-chat-latest
|
||||
- openai/gpt-5-2025-08-07
|
||||
- openai/gpt-5
|
||||
- openai/gpt-5-mini-2025-08-07
|
||||
- openai/gpt-5-mini
|
||||
- openai/gpt-5-nano-2025-08-07
|
||||
- openai/gpt-5-nano
|
||||
- openai/gpt-5-codex
|
||||
- openai/gpt-5-pro-2025-10-06
|
||||
- openai/gpt-5-pro
|
||||
- openai/gpt-5-search-api
|
||||
- openai/gpt-5-search-api-2025-10-14
|
||||
- openai/gpt-5.1-chat-latest
|
||||
- openai/gpt-5.1-2025-11-13
|
||||
- openai/gpt-5.1
|
||||
- openai/gpt-5.1-codex
|
||||
- openai/gpt-5.1-codex-mini
|
||||
- openai/gpt-5.1-codex-max
|
||||
- openai/gpt-5.2-2025-12-11
|
||||
- openai/gpt-5.2
|
||||
- openai/gpt-5.2-pro-2025-12-11
|
||||
- openai/gpt-5.2-pro
|
||||
- openai/gpt-5.2-chat-latest
|
||||
- openai/gpt-3.5-turbo-16k
|
||||
- openai/ft:gpt-3.5-turbo-0613:katanemo::8CMZbm0P
|
||||
google:
|
||||
- google/gemini-2.5-flash
|
||||
- google/gemini-2.5-pro
|
||||
- google/gemini-2.0-flash-exp
|
||||
- google/gemini-2.0-flash
|
||||
- google/gemini-2.0-flash-001
|
||||
- google/gemini-2.0-flash-exp-image-generation
|
||||
- google/gemini-2.0-flash-lite-001
|
||||
- google/gemini-2.0-flash-lite
|
||||
- google/gemini-2.0-flash-lite-preview-02-05
|
||||
- google/gemini-2.0-flash-lite-preview
|
||||
- google/gemini-exp-1206
|
||||
- google/gemini-2.5-flash-preview-tts
|
||||
- google/gemini-2.5-pro-preview-tts
|
||||
- google/gemma-3-1b-it
|
||||
- google/gemma-3-4b-it
|
||||
- google/gemma-3-12b-it
|
||||
- google/gemma-3-27b-it
|
||||
- google/gemma-3n-e4b-it
|
||||
- google/gemma-3n-e2b-it
|
||||
- google/gemini-flash-latest
|
||||
- google/gemini-flash-lite-latest
|
||||
- google/gemini-pro-latest
|
||||
- google/gemini-2.5-flash-lite
|
||||
- google/gemini-2.5-flash-image
|
||||
- google/gemini-2.5-flash-preview-09-2025
|
||||
- google/gemini-2.5-flash-lite-preview-09-2025
|
||||
- google/gemini-3-pro-preview
|
||||
- google/gemini-3-flash-preview
|
||||
- google/gemini-3-pro-image-preview
|
||||
- google/nano-banana-pro-preview
|
||||
- google/gemini-robotics-er-1.5-preview
|
||||
- google/gemini-2.5-computer-use-preview-10-2025
|
||||
- google/deep-research-pro-preview-12-2025
|
||||
mistralai:
|
||||
- mistralai/mistral-medium-2505
|
||||
- mistralai/mistral-medium-2508
|
||||
- mistralai/mistral-medium-latest
|
||||
- mistralai/mistral-medium
|
||||
- mistralai/open-mistral-nemo
|
||||
- mistralai/open-mistral-nemo-2407
|
||||
- mistralai/mistral-tiny-2407
|
||||
- mistralai/mistral-tiny-latest
|
||||
- mistralai/mistral-large-2411
|
||||
- mistralai/pixtral-large-2411
|
||||
- mistralai/pixtral-large-latest
|
||||
- mistralai/mistral-large-pixtral-2411
|
||||
- mistralai/codestral-2508
|
||||
- mistralai/codestral-latest
|
||||
- mistralai/devstral-small-2507
|
||||
- mistralai/devstral-medium-2507
|
||||
- mistralai/devstral-2512
|
||||
- mistralai/mistral-vibe-cli-latest
|
||||
- mistralai/devstral-medium-latest
|
||||
- mistralai/devstral-latest
|
||||
- mistralai/labs-devstral-small-2512
|
||||
- mistralai/devstral-small-latest
|
||||
- mistralai/mistral-small-2506
|
||||
- mistralai/mistral-small-latest
|
||||
- mistralai/labs-mistral-small-creative
|
||||
- mistralai/magistral-medium-2509
|
||||
- mistralai/magistral-medium-latest
|
||||
- mistralai/magistral-small-2509
|
||||
- mistralai/magistral-small-latest
|
||||
- mistralai/mistral-large-2512
|
||||
- mistralai/mistral-large-latest
|
||||
- mistralai/ministral-3b-2512
|
||||
- mistralai/ministral-3b-latest
|
||||
- mistralai/ministral-8b-2512
|
||||
- mistralai/ministral-8b-latest
|
||||
- mistralai/ministral-14b-2512
|
||||
- mistralai/ministral-14b-latest
|
||||
- mistralai/open-mistral-7b
|
||||
- mistralai/mistral-tiny
|
||||
- mistralai/mistral-tiny-2312
|
||||
- mistralai/pixtral-12b-2409
|
||||
- mistralai/pixtral-12b
|
||||
- mistralai/pixtral-12b-latest
|
||||
- mistralai/ministral-3b-2410
|
||||
- mistralai/ministral-8b-2410
|
||||
- mistralai/codestral-2501
|
||||
- mistralai/codestral-2412
|
||||
- mistralai/codestral-2411-rc5
|
||||
- mistralai/mistral-small-2501
|
||||
- mistralai/mistral-embed-2312
|
||||
- mistralai/mistral-embed
|
||||
- mistralai/codestral-embed
|
||||
- mistralai/codestral-embed-2505
|
||||
z-ai:
|
||||
- z-ai/glm-4.5
|
||||
- z-ai/glm-4.5-air
|
||||
- z-ai/glm-4.6
|
||||
- z-ai/glm-4.7
|
||||
amazon:
|
||||
- amazon/amazon.nova-pro-v1:0
|
||||
- amazon/amazon.nova-2-lite-v1:0
|
||||
- amazon/amazon.nova-2-sonic-v1:0
|
||||
- amazon/amazon.titan-tg1-large
|
||||
- amazon/amazon.nova-premier-v1:0:8k
|
||||
- amazon/amazon.nova-premier-v1:0:20k
|
||||
- amazon/amazon.nova-premier-v1:0:1000k
|
||||
- amazon/amazon.nova-premier-v1:0:mm
|
||||
- amazon/amazon.nova-premier-v1:0
|
||||
- amazon/amazon.nova-lite-v1:0
|
||||
- amazon/amazon.nova-micro-v1:0
|
||||
deepseek:
|
||||
- deepseek/deepseek-chat
|
||||
- deepseek/deepseek-reasoner
|
||||
x-ai:
|
||||
- x-ai/grok-2-vision-1212
|
||||
- x-ai/grok-3
|
||||
- x-ai/grok-3-mini
|
||||
- x-ai/grok-4-0709
|
||||
- x-ai/grok-4-1-fast-non-reasoning
|
||||
- x-ai/grok-4-1-fast-reasoning
|
||||
- x-ai/grok-4-fast-non-reasoning
|
||||
- x-ai/grok-4-fast-reasoning
|
||||
- x-ai/grok-code-fast-1
|
||||
moonshotai:
|
||||
- moonshotai/kimi-latest
|
||||
- moonshotai/kimi-k2.5
|
||||
- moonshotai/moonshot-v1-8k-vision-preview
|
||||
- moonshotai/kimi-k2-thinking
|
||||
- moonshotai/moonshot-v1-auto
|
||||
- moonshotai/kimi-k2-0711-preview
|
||||
- moonshotai/moonshot-v1-32k
|
||||
- moonshotai/kimi-k2-thinking-turbo
|
||||
- moonshotai/kimi-k2-0905-preview
|
||||
- moonshotai/moonshot-v1-128k
|
||||
- moonshotai/moonshot-v1-32k-vision-preview
|
||||
- moonshotai/moonshot-v1-128k-vision-preview
|
||||
- moonshotai/kimi-k2-turbo-preview
|
||||
- moonshotai/moonshot-v1-8k
|
||||
anthropic:
|
||||
- anthropic/claude-opus-4-5-20251101
|
||||
- anthropic/claude-opus-4-5
|
||||
- anthropic/claude-haiku-4-5-20251001
|
||||
- anthropic/claude-haiku-4-5
|
||||
- anthropic/claude-sonnet-4-5-20250929
|
||||
- anthropic/claude-sonnet-4-5
|
||||
- anthropic/claude-opus-4-1-20250805
|
||||
- anthropic/claude-opus-4-1
|
||||
- anthropic/claude-opus-4-20250514
|
||||
- anthropic/claude-opus-4
|
||||
- anthropic/claude-sonnet-4-20250514
|
||||
- anthropic/claude-sonnet-4
|
||||
- anthropic/claude-3-7-sonnet-20250219
|
||||
- anthropic/claude-3-7-sonnet
|
||||
- anthropic/claude-3-5-haiku-20241022
|
||||
- anthropic/claude-3-5-haiku
|
||||
- anthropic/claude-3-haiku-20240307
|
||||
- anthropic/claude-3-haiku
|
||||
metadata:
|
||||
total_providers: 10
|
||||
total_models: 298
|
||||
last_updated: 2026-01-27T22:40:53.653700+00:00
|
||||
15
crates/hermesllm/src/bin/run.sh
Normal file
15
crates/hermesllm/src/bin/run.sh
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
# Get the directory where this script is located
|
||||
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
|
||||
|
||||
# Navigate to crates directory (bin -> src -> hermesllm -> crates)
|
||||
cd "$SCRIPT_DIR/../../.."
|
||||
|
||||
# Load environment variables silently and run fetch_models
|
||||
set -a
|
||||
source hermesllm/src/bin/.env
|
||||
set +a
|
||||
|
||||
cargo run --bin fetch_models --features model-fetch
|
||||
|
|
@ -29,10 +29,27 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn test_provider_id_conversion() {
|
||||
assert_eq!(ProviderId::from("openai"), ProviderId::OpenAI);
|
||||
assert_eq!(ProviderId::from("mistral"), ProviderId::Mistral);
|
||||
assert_eq!(ProviderId::from("groq"), ProviderId::Groq);
|
||||
assert_eq!(ProviderId::from("arch"), ProviderId::Arch);
|
||||
assert_eq!(ProviderId::try_from("openai").unwrap(), ProviderId::OpenAI);
|
||||
assert_eq!(
|
||||
ProviderId::try_from("mistral").unwrap(),
|
||||
ProviderId::Mistral
|
||||
);
|
||||
assert_eq!(ProviderId::try_from("groq").unwrap(), ProviderId::Groq);
|
||||
assert_eq!(ProviderId::try_from("arch").unwrap(), ProviderId::Arch);
|
||||
|
||||
// Test aliases
|
||||
assert_eq!(ProviderId::try_from("google").unwrap(), ProviderId::Gemini);
|
||||
assert_eq!(
|
||||
ProviderId::try_from("together").unwrap(),
|
||||
ProviderId::TogetherAI
|
||||
);
|
||||
assert_eq!(
|
||||
ProviderId::try_from("amazon").unwrap(),
|
||||
ProviderId::AmazonBedrock
|
||||
);
|
||||
|
||||
// Test error case
|
||||
assert!(ProviderId::try_from("unknown_provider").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
|
|||
|
|
@ -1,6 +1,28 @@
|
|||
use crate::apis::{AmazonBedrockApi, AnthropicApi, OpenAIApi};
|
||||
use crate::clients::endpoints::{SupportedAPIsFromClient, SupportedUpstreamAPIs};
|
||||
use serde::Deserialize;
|
||||
use std::collections::HashMap;
|
||||
use std::fmt::Display;
|
||||
use std::sync::OnceLock;
|
||||
|
||||
static PROVIDER_MODELS_YAML: &str = include_str!(concat!(
|
||||
env!("CARGO_MANIFEST_DIR"),
|
||||
"/src/bin/provider_models.yaml"
|
||||
));
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct ProviderModelsFile {
|
||||
providers: HashMap<String, Vec<String>>,
|
||||
}
|
||||
|
||||
fn load_provider_models() -> &'static HashMap<String, Vec<String>> {
|
||||
static MODELS: OnceLock<HashMap<String, Vec<String>>> = OnceLock::new();
|
||||
MODELS.get_or_init(|| {
|
||||
let ProviderModelsFile { providers } = serde_yaml::from_str(PROVIDER_MODELS_YAML)
|
||||
.expect("Failed to parse provider_models.yaml");
|
||||
providers
|
||||
})
|
||||
}
|
||||
|
||||
/// Provider identifier enum - simple enum for identifying providers
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
|
|
@ -23,31 +45,70 @@ pub enum ProviderId {
|
|||
AmazonBedrock,
|
||||
}
|
||||
|
||||
impl From<&str> for ProviderId {
|
||||
fn from(value: &str) -> Self {
|
||||
impl TryFrom<&str> for ProviderId {
|
||||
type Error = String;
|
||||
|
||||
fn try_from(value: &str) -> Result<Self, Self::Error> {
|
||||
match value.to_lowercase().as_str() {
|
||||
"openai" => ProviderId::OpenAI,
|
||||
"mistral" => ProviderId::Mistral,
|
||||
"deepseek" => ProviderId::Deepseek,
|
||||
"groq" => ProviderId::Groq,
|
||||
"gemini" => ProviderId::Gemini,
|
||||
"anthropic" => ProviderId::Anthropic,
|
||||
"github" => ProviderId::GitHub,
|
||||
"arch" => ProviderId::Arch,
|
||||
"azure_openai" => ProviderId::AzureOpenAI,
|
||||
"xai" => ProviderId::XAI,
|
||||
"together_ai" => ProviderId::TogetherAI,
|
||||
"ollama" => ProviderId::Ollama,
|
||||
"moonshotai" => ProviderId::Moonshotai,
|
||||
"zhipu" => ProviderId::Zhipu,
|
||||
"qwen" => ProviderId::Qwen, // alias for Qwen
|
||||
"amazon_bedrock" => ProviderId::AmazonBedrock,
|
||||
_ => panic!("Unknown provider: {}", value),
|
||||
"openai" => Ok(ProviderId::OpenAI),
|
||||
"mistral" => Ok(ProviderId::Mistral),
|
||||
"deepseek" => Ok(ProviderId::Deepseek),
|
||||
"groq" => Ok(ProviderId::Groq),
|
||||
"gemini" => Ok(ProviderId::Gemini),
|
||||
"google" => Ok(ProviderId::Gemini), // alias
|
||||
"anthropic" => Ok(ProviderId::Anthropic),
|
||||
"github" => Ok(ProviderId::GitHub),
|
||||
"arch" => Ok(ProviderId::Arch),
|
||||
"azure_openai" => Ok(ProviderId::AzureOpenAI),
|
||||
"xai" => Ok(ProviderId::XAI),
|
||||
"together_ai" => Ok(ProviderId::TogetherAI),
|
||||
"together" => Ok(ProviderId::TogetherAI), // alias
|
||||
"ollama" => Ok(ProviderId::Ollama),
|
||||
"moonshotai" => Ok(ProviderId::Moonshotai),
|
||||
"zhipu" => Ok(ProviderId::Zhipu),
|
||||
"qwen" => Ok(ProviderId::Qwen),
|
||||
"amazon_bedrock" => Ok(ProviderId::AmazonBedrock),
|
||||
"amazon" => Ok(ProviderId::AmazonBedrock), // alias
|
||||
_ => Err(format!("Unknown provider: {}", value)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ProviderId {
|
||||
/// Get all available models for this provider
|
||||
/// Returns model names without the provider prefix (e.g., "gpt-4" not "openai/gpt-4")
|
||||
pub fn models(&self) -> Vec<String> {
|
||||
let provider_key = match self {
|
||||
ProviderId::AmazonBedrock => "amazon",
|
||||
ProviderId::AzureOpenAI => "openai",
|
||||
ProviderId::TogetherAI => "together",
|
||||
ProviderId::Gemini => "google",
|
||||
ProviderId::OpenAI => "openai",
|
||||
ProviderId::Anthropic => "anthropic",
|
||||
ProviderId::Mistral => "mistralai",
|
||||
ProviderId::Deepseek => "deepseek",
|
||||
ProviderId::Groq => "groq",
|
||||
ProviderId::XAI => "x-ai",
|
||||
ProviderId::Moonshotai => "moonshotai",
|
||||
ProviderId::Zhipu => "z-ai",
|
||||
ProviderId::Qwen => "qwen",
|
||||
_ => return Vec::new(),
|
||||
};
|
||||
|
||||
load_provider_models()
|
||||
.get(provider_key)
|
||||
.map(|models| {
|
||||
models
|
||||
.iter()
|
||||
.filter_map(|model| {
|
||||
// Strip provider prefix (e.g., "openai/gpt-4" -> "gpt-4")
|
||||
model.split_once('/').map(|(_, name)| name.to_string())
|
||||
})
|
||||
.collect()
|
||||
})
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
/// Given a client API, return the compatible upstream API for this provider
|
||||
pub fn compatible_api_for_client(
|
||||
&self,
|
||||
|
|
@ -169,3 +230,102 @@ impl Display for ProviderId {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_models_loaded_from_yaml() {
|
||||
// Test that we can load models for each supported provider
|
||||
let openai_models = ProviderId::OpenAI.models();
|
||||
assert!(!openai_models.is_empty(), "OpenAI should have models");
|
||||
|
||||
let anthropic_models = ProviderId::Anthropic.models();
|
||||
assert!(!anthropic_models.is_empty(), "Anthropic should have models");
|
||||
|
||||
let mistral_models = ProviderId::Mistral.models();
|
||||
assert!(!mistral_models.is_empty(), "Mistral should have models");
|
||||
|
||||
let deepseek_models = ProviderId::Deepseek.models();
|
||||
assert!(!deepseek_models.is_empty(), "Deepseek should have models");
|
||||
|
||||
let gemini_models = ProviderId::Gemini.models();
|
||||
assert!(!gemini_models.is_empty(), "Gemini should have models");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_model_names_without_provider_prefix() {
|
||||
// Test that model names don't include the provider/ prefix
|
||||
let openai_models = ProviderId::OpenAI.models();
|
||||
for model in &openai_models {
|
||||
assert!(
|
||||
!model.contains('/'),
|
||||
"Model name '{}' should not contain provider prefix",
|
||||
model
|
||||
);
|
||||
}
|
||||
|
||||
let anthropic_models = ProviderId::Anthropic.models();
|
||||
for model in &anthropic_models {
|
||||
assert!(
|
||||
!model.contains('/'),
|
||||
"Model name '{}' should not contain provider prefix",
|
||||
model
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_specific_models_exist() {
|
||||
// Test that specific well-known models are present
|
||||
let openai_models = ProviderId::OpenAI.models();
|
||||
let has_gpt4 = openai_models.iter().any(|m| m.contains("gpt-4"));
|
||||
assert!(has_gpt4, "OpenAI models should include GPT-4 variants");
|
||||
|
||||
let anthropic_models = ProviderId::Anthropic.models();
|
||||
let has_claude = anthropic_models.iter().any(|m| m.contains("claude"));
|
||||
assert!(
|
||||
has_claude,
|
||||
"Anthropic models should include Claude variants"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_unsupported_providers_return_empty() {
|
||||
// Providers without models should return empty vec
|
||||
let github_models = ProviderId::GitHub.models();
|
||||
assert!(
|
||||
github_models.is_empty(),
|
||||
"GitHub should return empty models list"
|
||||
);
|
||||
|
||||
let ollama_models = ProviderId::Ollama.models();
|
||||
assert!(
|
||||
ollama_models.is_empty(),
|
||||
"Ollama should return empty models list"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_provider_name_mapping() {
|
||||
// Test that provider key mappings work correctly
|
||||
let xai_models = ProviderId::XAI.models();
|
||||
assert!(
|
||||
!xai_models.is_empty(),
|
||||
"XAI should have models (mapped to x-ai)"
|
||||
);
|
||||
|
||||
let zhipu_models = ProviderId::Zhipu.models();
|
||||
assert!(
|
||||
!zhipu_models.is_empty(),
|
||||
"Zhipu should have models (mapped to z-ai)"
|
||||
);
|
||||
|
||||
let amazon_models = ProviderId::AmazonBedrock.models();
|
||||
assert!(
|
||||
!amazon_models.is_empty(),
|
||||
"AmazonBedrock should have models (mapped to amazon)"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,11 +1,12 @@
|
|||
use hermesllm::clients::endpoints::SupportedUpstreamAPIs;
|
||||
use http::StatusCode;
|
||||
use log::{debug, info, warn};
|
||||
use log::{debug, error, info, warn};
|
||||
use proxy_wasm::hostcalls::get_current_time;
|
||||
use proxy_wasm::traits::*;
|
||||
use proxy_wasm::types::*;
|
||||
use std::num::NonZero;
|
||||
use std::rc::Rc;
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, SystemTime, UNIX_EPOCH};
|
||||
|
||||
use crate::metrics::Metrics;
|
||||
|
|
@ -40,7 +41,7 @@ pub struct StreamContext {
|
|||
/// The API that should be used for the upstream provider (after compatibility mapping)
|
||||
resolved_api: Option<SupportedUpstreamAPIs>,
|
||||
llm_providers: Rc<LlmProviders>,
|
||||
llm_provider: Option<Rc<LlmProvider>>,
|
||||
llm_provider: Option<Arc<LlmProvider>>,
|
||||
request_id: Option<String>,
|
||||
start_time: SystemTime,
|
||||
ttft_duration: Option<Duration>,
|
||||
|
|
@ -128,16 +129,40 @@ impl StreamContext {
|
|||
}
|
||||
}
|
||||
|
||||
fn select_llm_provider(&mut self) {
|
||||
fn select_llm_provider(&mut self) -> Result<(), String> {
|
||||
let provider_hint = self
|
||||
.get_http_request_header(ARCH_PROVIDER_HINT_HEADER)
|
||||
.map(|llm_name| llm_name.into());
|
||||
|
||||
// info!("llm_providers: {:?}", self.llm_providers);
|
||||
self.llm_provider = Some(routing::get_llm_provider(
|
||||
&self.llm_providers,
|
||||
provider_hint,
|
||||
));
|
||||
// Try to get provider with hint, fallback to default if error
|
||||
// This handles prompt_gateway requests which don't set ARCH_PROVIDER_HINT_HEADER
|
||||
// since prompt_gateway doesn't have access to model configuration.
|
||||
// brightstaff (model proxy) always validates and sets the provider hint.
|
||||
let provider = match routing::get_llm_provider(&self.llm_providers, provider_hint) {
|
||||
Ok(provider) => provider,
|
||||
Err(err) => {
|
||||
// Try default provider as fallback
|
||||
match self.llm_providers.default() {
|
||||
Some(default_provider) => {
|
||||
info!(
|
||||
"[PLANO_REQ_ID:{}] Provider selection failed, using default provider",
|
||||
self.request_identifier()
|
||||
);
|
||||
default_provider
|
||||
}
|
||||
None => {
|
||||
error!(
|
||||
"[PLANO_REQ_ID:{}] PROVIDER_SELECTION_FAILED: Error='{}' and no default provider configured",
|
||||
self.request_identifier(),
|
||||
err
|
||||
);
|
||||
return Err(err);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
self.llm_provider = Some(provider);
|
||||
|
||||
info!(
|
||||
"[PLANO_REQ_ID:{}] PROVIDER_SELECTION: Hint='{}' -> Selected='{}'",
|
||||
|
|
@ -146,6 +171,8 @@ impl StreamContext {
|
|||
.unwrap_or("none".to_string()),
|
||||
self.llm_provider.as_ref().unwrap().name
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn modify_auth_headers(&mut self) -> Result<(), ServerError> {
|
||||
|
|
@ -764,7 +791,15 @@ impl HttpContext for StreamContext {
|
|||
|
||||
// let routing_header_value = self.get_http_request_header(ARCH_ROUTING_HEADER);
|
||||
|
||||
self.select_llm_provider();
|
||||
if let Err(err) = self.select_llm_provider() {
|
||||
self.send_http_response(
|
||||
400,
|
||||
vec![],
|
||||
Some(format!(r#"{{"error": "{}"}}"#, err).as_bytes()),
|
||||
);
|
||||
return Action::Continue;
|
||||
}
|
||||
|
||||
// Check if this is a supported API endpoint
|
||||
if SupportedAPIsFromClient::from_endpoint(&request_path).is_none() {
|
||||
self.send_http_response(404, vec![], Some(b"Unsupported endpoint"));
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue