mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
adding support for wildcard model providers
This commit is contained in:
parent
86cf8ccdaa
commit
34711c6f9d
14 changed files with 1027 additions and 1823 deletions
|
|
@ -187,11 +187,21 @@ def validate_and_render_schema():
|
|||
|
||||
model_name = model_provider.get("model")
|
||||
print("Processing model_provider: ", model_provider)
|
||||
if model_name in model_name_keys:
|
||||
|
||||
# Check if this is a wildcard model (provider/*)
|
||||
is_wildcard = False
|
||||
if "/" in model_name:
|
||||
model_name_tokens = model_name.split("/")
|
||||
if len(model_name_tokens) >= 2 and model_name_tokens[-1] == "*":
|
||||
is_wildcard = True
|
||||
|
||||
if model_name in model_name_keys and not is_wildcard:
|
||||
raise Exception(
|
||||
f"Duplicate model name {model_name}, please provide unique model name for each model_provider"
|
||||
)
|
||||
model_name_keys.add(model_name)
|
||||
|
||||
if not is_wildcard:
|
||||
model_name_keys.add(model_name)
|
||||
if model_provider.get("name") is None:
|
||||
model_provider["name"] = model_name
|
||||
|
||||
|
|
@ -202,7 +212,21 @@ def validate_and_render_schema():
|
|||
raise Exception(
|
||||
f"Invalid model name {model_name}. Please provide model name in the format <provider>/<model_id>."
|
||||
)
|
||||
provider = model_name_tokens[0]
|
||||
provider = model_name_tokens[0].strip()
|
||||
|
||||
# Check if this is a wildcard (provider/*)
|
||||
is_wildcard = model_name_tokens[-1].strip() == "*"
|
||||
|
||||
# Validate wildcard constraints
|
||||
if is_wildcard:
|
||||
if model_provider.get("default", False):
|
||||
raise Exception(
|
||||
f"Model {model_name} is configured as default but uses wildcard (*). Default models cannot be wildcards."
|
||||
)
|
||||
if model_provider.get("routing_preferences"):
|
||||
raise Exception(
|
||||
f"Model {model_name} has routing_preferences but uses wildcard (*). Models with routing preferences cannot be wildcards."
|
||||
)
|
||||
|
||||
# Validate azure_openai and ollama provider requires base_url
|
||||
if (provider in SUPPORTED_PROVIDERS_WITH_BASE_URL) and model_provider.get(
|
||||
|
|
@ -213,7 +237,9 @@ def validate_and_render_schema():
|
|||
)
|
||||
|
||||
model_id = "/".join(model_name_tokens[1:])
|
||||
if provider not in SUPPORTED_PROVIDERS:
|
||||
|
||||
# For wildcard providers, allow any provider name
|
||||
if not is_wildcard and provider not in SUPPORTED_PROVIDERS:
|
||||
if (
|
||||
model_provider.get("base_url", None) is None
|
||||
or model_provider.get("provider_interface", None) is None
|
||||
|
|
@ -227,11 +253,13 @@ def validate_and_render_schema():
|
|||
f"Please provide provider interface as part of model name {model_name} using the format <provider>/<model_id>. For example, use 'openai/gpt-3.5-turbo' instead of 'gpt-3.5-turbo' "
|
||||
)
|
||||
|
||||
if model_id in model_name_keys:
|
||||
raise Exception(
|
||||
f"Duplicate model_id {model_id}, please provide unique model_id for each model_provider"
|
||||
)
|
||||
model_name_keys.add(model_id)
|
||||
# For wildcard models, don't add model_id to the keys since it's "*"
|
||||
if not is_wildcard:
|
||||
if model_id in model_name_keys:
|
||||
raise Exception(
|
||||
f"Duplicate model_id {model_id}, please provide unique model_id for each model_provider"
|
||||
)
|
||||
model_name_keys.add(model_id)
|
||||
|
||||
for routing_preference in model_provider.get("routing_preferences", []):
|
||||
if routing_preference.get("name") in model_usage_name_keys:
|
||||
|
|
|
|||
94
crates/Cargo.lock
generated
94
crates/Cargo.lock
generated
|
|
@ -459,6 +459,35 @@ dependencies = [
|
|||
"urlencoding",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cookie"
|
||||
version = "0.18.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4ddef33a339a91ea89fb53151bd0a4689cfce27055c291dfa69945475d22c747"
|
||||
dependencies = [
|
||||
"percent-encoding",
|
||||
"time",
|
||||
"version_check",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cookie_store"
|
||||
version = "0.22.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3fc4bff745c9b4c7fb1e97b25d13153da2bc7796260141df62378998d070207f"
|
||||
dependencies = [
|
||||
"cookie",
|
||||
"document-features",
|
||||
"idna",
|
||||
"indexmap 2.9.0",
|
||||
"log",
|
||||
"serde",
|
||||
"serde_derive",
|
||||
"serde_json",
|
||||
"time",
|
||||
"url",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "core-foundation"
|
||||
version = "0.9.4"
|
||||
|
|
@ -628,6 +657,15 @@ dependencies = [
|
|||
"syn 2.0.101",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "document-features"
|
||||
version = "0.2.12"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d4b8a88685455ed29a21542a33abd9cb6510b6b129abadabdcef0f4c55bc8f61"
|
||||
dependencies = [
|
||||
"litrs",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "duration-string"
|
||||
version = "0.3.0"
|
||||
|
|
@ -999,11 +1037,13 @@ version = "0.1.0"
|
|||
dependencies = [
|
||||
"aws-smithy-eventstream",
|
||||
"bytes",
|
||||
"chrono",
|
||||
"log",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"serde_with",
|
||||
"thiserror 2.0.12",
|
||||
"ureq",
|
||||
"uuid",
|
||||
]
|
||||
|
||||
|
|
@ -1479,6 +1519,12 @@ version = "0.8.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "241eaef5fd12c88705a01fc1066c48c4b36e0dd4377dcdc7ec3942cea7a69956"
|
||||
|
||||
[[package]]
|
||||
name = "litrs"
|
||||
version = "1.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "11d3d7f243d5c5a8b9bb5d6dd2b1602c0cb0b9db1621bafc7ed66e35ff9fe092"
|
||||
|
||||
[[package]]
|
||||
name = "llm_gateway"
|
||||
version = "0.1.0"
|
||||
|
|
@ -2417,6 +2463,7 @@ version = "0.23.27"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "730944ca083c1c233a75c09f199e973ca499344a2b7ba9e755c457e86fb4a321"
|
||||
dependencies = [
|
||||
"log",
|
||||
"once_cell",
|
||||
"ring",
|
||||
"rustls-pki-types",
|
||||
|
|
@ -3385,6 +3432,38 @@ version = "0.9.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1"
|
||||
|
||||
[[package]]
|
||||
name = "ureq"
|
||||
version = "3.1.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d39cb1dbab692d82a977c0392ffac19e188bd9186a9f32806f0aaa859d75585a"
|
||||
dependencies = [
|
||||
"base64 0.22.1",
|
||||
"cookie_store",
|
||||
"flate2",
|
||||
"log",
|
||||
"percent-encoding",
|
||||
"rustls 0.23.27",
|
||||
"rustls-pki-types",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"ureq-proto",
|
||||
"utf-8",
|
||||
"webpki-roots",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ureq-proto"
|
||||
version = "0.5.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d81f9efa9df032be5934a46a068815a10a042b494b6a58cb0a1a97bb5467ed6f"
|
||||
dependencies = [
|
||||
"base64 0.22.1",
|
||||
"http 1.3.1",
|
||||
"httparse",
|
||||
"log",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "url"
|
||||
version = "2.5.4"
|
||||
|
|
@ -3402,6 +3481,12 @@ version = "2.1.3"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da"
|
||||
|
||||
[[package]]
|
||||
name = "utf-8"
|
||||
version = "0.7.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9"
|
||||
|
||||
[[package]]
|
||||
name = "utf8_iter"
|
||||
version = "1.0.4"
|
||||
|
|
@ -3578,6 +3663,15 @@ dependencies = [
|
|||
"wasm-bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "webpki-roots"
|
||||
version = "1.0.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "12bed680863276c63889429bfd6cab3b99943659923822de1c8a39c49e4d722c"
|
||||
dependencies = [
|
||||
"rustls-pki-types",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "whoami"
|
||||
version = "1.6.1"
|
||||
|
|
|
|||
|
|
@ -123,6 +123,14 @@ pub async fn llm_chat(
|
|||
let is_streaming_request = client_request.is_streaming();
|
||||
let resolved_model = resolve_model_alias(&model_from_request, &model_aliases);
|
||||
|
||||
// Handle provider/model slug format (e.g., "openai/gpt-4")
|
||||
// Extract just the model name for upstream (providers don't understand the slug)
|
||||
let model_name_only = if let Some((_, model)) = resolved_model.split_once('/') {
|
||||
model.to_string()
|
||||
} else {
|
||||
resolved_model.clone()
|
||||
};
|
||||
|
||||
// Extract tool names and user message preview for span attributes
|
||||
let tool_names = client_request.get_tool_names();
|
||||
let user_message_preview = client_request
|
||||
|
|
@ -132,7 +140,9 @@ pub async fn llm_chat(
|
|||
// Extract messages for signal analysis (clone before moving client_request)
|
||||
let messages_for_signals = client_request.get_messages();
|
||||
|
||||
client_request.set_model(resolved_model.clone());
|
||||
// Set the model to just the model name (without provider prefix)
|
||||
// This ensures upstream receives "gpt-4" not "openai/gpt-4"
|
||||
client_request.set_model(model_name_only.clone());
|
||||
if client_request.remove_metadata_key("archgw_preference_config") {
|
||||
debug!(
|
||||
"[PLANO_REQ_ID:{}] Removed archgw_preference_config from metadata",
|
||||
|
|
@ -240,16 +250,22 @@ pub async fn llm_chat(
|
|||
}
|
||||
};
|
||||
|
||||
// Use the resolved model (could be "gpt-4" or "openai/gpt-4") as the provider hint
|
||||
// The routing layer will use llm_providers.get() which handles both formats:
|
||||
// - "gpt-4" → looks up by model name
|
||||
// - "openai/gpt-4" → looks up by provider/model slug
|
||||
// If router doesn't find anything, it will use routing_result.model_name
|
||||
let provider_hint_value = resolved_model.clone();
|
||||
let model_name = routing_result.model_name;
|
||||
|
||||
debug!(
|
||||
"[PLANO_REQ_ID:{}] | ARCH_ROUTER URL | {}, Resolved Model: {}",
|
||||
request_id, full_qualified_llm_provider_url, model_name
|
||||
"[PLANO_REQ_ID:{}] | ARCH_ROUTER URL | {}, Provider Hint: {}, Model for upstream: {}",
|
||||
request_id, full_qualified_llm_provider_url, provider_hint_value, model_name_only
|
||||
);
|
||||
|
||||
request_headers.insert(
|
||||
ARCH_PROVIDER_HINT_HEADER,
|
||||
header::HeaderValue::from_str(&model_name).unwrap(),
|
||||
header::HeaderValue::from_str(&provider_hint_value).unwrap(),
|
||||
);
|
||||
|
||||
request_headers.insert(
|
||||
|
|
|
|||
|
|
@ -255,7 +255,8 @@ impl LlmProviderType {
|
|||
/// Get the ProviderId for this LlmProviderType
|
||||
/// Used with the new function-based hermesllm API
|
||||
pub fn to_provider_id(&self) -> hermesllm::ProviderId {
|
||||
hermesllm::ProviderId::from(self.to_string().as_str())
|
||||
hermesllm::ProviderId::try_from(self.to_string().as_str())
|
||||
.expect("LlmProviderType should always map to a valid ProviderId")
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
use crate::configuration::LlmProvider;
|
||||
use hermesllm::providers::ProviderId;
|
||||
use std::collections::HashMap;
|
||||
use std::rc::Rc;
|
||||
|
||||
|
|
@ -6,6 +7,9 @@ use std::rc::Rc;
|
|||
pub struct LlmProviders {
|
||||
providers: HashMap<String, Rc<LlmProvider>>,
|
||||
default: Option<Rc<LlmProvider>>,
|
||||
/// Wildcard providers: maps provider prefix to base provider config
|
||||
/// e.g., "openai" -> LlmProvider for "openai/*"
|
||||
wildcard_providers: HashMap<String, Rc<LlmProvider>>,
|
||||
}
|
||||
|
||||
impl LlmProviders {
|
||||
|
|
@ -18,7 +22,36 @@ impl LlmProviders {
|
|||
}
|
||||
|
||||
pub fn get(&self, name: &str) -> Option<Rc<LlmProvider>> {
|
||||
self.providers.get(name).cloned()
|
||||
// First try exact match
|
||||
if let Some(provider) = self.providers.get(name).cloned() {
|
||||
return Some(provider);
|
||||
}
|
||||
|
||||
// If name contains '/', it could be:
|
||||
// 1. A full model ID like "openai/gpt-4" that we need to lookup
|
||||
// 2. A provider/model slug that should match a wildcard provider
|
||||
if let Some((provider_prefix, model_name)) = name.split_once('/') {
|
||||
// Try to find the expanded model entry (e.g., "openai/gpt-4")
|
||||
let full_model_id = format!("{}/{}", provider_prefix, model_name);
|
||||
if let Some(provider) = self.providers.get(&full_model_id).cloned() {
|
||||
return Some(provider);
|
||||
}
|
||||
|
||||
// Try to find just the model name (for expanded wildcard entries)
|
||||
if let Some(provider) = self.providers.get(model_name).cloned() {
|
||||
return Some(provider);
|
||||
}
|
||||
|
||||
// Fall back to wildcard match (e.g., "openai/*")
|
||||
if let Some(wildcard_provider) = self.wildcard_providers.get(provider_prefix) {
|
||||
// Create a new provider with the specific model from the slug
|
||||
let mut specific_provider = (**wildcard_provider).clone();
|
||||
specific_provider.model = Some(model_name.to_string());
|
||||
return Some(Rc::new(specific_provider));
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -43,10 +76,12 @@ impl TryFrom<Vec<LlmProvider>> for LlmProviders {
|
|||
let mut llm_providers = LlmProviders {
|
||||
providers: HashMap::new(),
|
||||
default: None,
|
||||
wildcard_providers: HashMap::new(),
|
||||
};
|
||||
|
||||
for llm_provider in llm_providers_config {
|
||||
let llm_provider: Rc<LlmProvider> = Rc::new(llm_provider);
|
||||
|
||||
if llm_provider.default.unwrap_or_default() {
|
||||
match llm_providers.default {
|
||||
Some(_) => return Err(LlmProvidersNewError::MoreThanOneDefault),
|
||||
|
|
@ -54,27 +89,167 @@ impl TryFrom<Vec<LlmProvider>> for LlmProviders {
|
|||
}
|
||||
}
|
||||
|
||||
// Insert and check that there is no other provider with the same name.
|
||||
let name = llm_provider.name.clone();
|
||||
if llm_providers
|
||||
.providers
|
||||
.insert(name.clone(), Rc::clone(&llm_provider))
|
||||
.is_some()
|
||||
{
|
||||
return Err(LlmProvidersNewError::DuplicateName(name));
|
||||
}
|
||||
|
||||
// also add model_id as key for provider lookup
|
||||
if let Some(model) = llm_provider.model.clone() {
|
||||
// Check if this is a wildcard provider (model is "*" or ends with "/*")
|
||||
let is_wildcard = llm_provider
|
||||
.model
|
||||
.as_ref()
|
||||
.map(|m| m == "*" || m.ends_with("/*"))
|
||||
.unwrap_or(false);
|
||||
|
||||
if is_wildcard {
|
||||
// Extract provider prefix from name
|
||||
// e.g., "openai/*" -> "openai"
|
||||
let provider_prefix = name.trim_end_matches("/*").trim_end_matches('*');
|
||||
|
||||
// For wildcard providers, we:
|
||||
// 1. Store the base config in wildcard_providers for runtime matching
|
||||
// 2. Optionally expand to all known models if available
|
||||
|
||||
llm_providers
|
||||
.wildcard_providers
|
||||
.insert(provider_prefix.to_string(), Rc::clone(&llm_provider));
|
||||
|
||||
// Try to expand wildcard using ProviderId models
|
||||
if let Ok(provider_id) = ProviderId::try_from(provider_prefix) {
|
||||
let models = provider_id.models();
|
||||
if !models.is_empty() {
|
||||
log::info!(
|
||||
"Expanding wildcard provider '{}' to {} models",
|
||||
provider_prefix,
|
||||
models.len()
|
||||
);
|
||||
|
||||
// Create a provider entry for each model
|
||||
for model_name in models {
|
||||
let full_model_id = format!("{}/{}", provider_prefix, model_name);
|
||||
|
||||
// Create a new provider with the specific model
|
||||
let mut expanded_provider = (*llm_provider).clone();
|
||||
expanded_provider.model = Some(model_name.clone());
|
||||
expanded_provider.name = full_model_id.clone();
|
||||
|
||||
let expanded_rc = Rc::new(expanded_provider);
|
||||
|
||||
// Insert with full model ID as key
|
||||
llm_providers
|
||||
.providers
|
||||
.insert(full_model_id.clone(), Rc::clone(&expanded_rc));
|
||||
|
||||
// Also insert with just model name for backward compatibility
|
||||
llm_providers.providers.insert(model_name, expanded_rc);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
log::warn!(
|
||||
"Wildcard provider '{}' specified but no models found in registry. \
|
||||
Will match dynamically at runtime.",
|
||||
provider_prefix
|
||||
);
|
||||
}
|
||||
} else {
|
||||
// Non-wildcard provider - original behavior
|
||||
if llm_providers
|
||||
.providers
|
||||
.insert(model, llm_provider)
|
||||
.insert(name.clone(), Rc::clone(&llm_provider))
|
||||
.is_some()
|
||||
{
|
||||
return Err(LlmProvidersNewError::DuplicateName(name));
|
||||
}
|
||||
|
||||
// also add model_id as key for provider lookup
|
||||
if let Some(model) = llm_provider.model.clone() {
|
||||
if llm_providers
|
||||
.providers
|
||||
.insert(model, llm_provider)
|
||||
.is_some()
|
||||
{
|
||||
return Err(LlmProvidersNewError::DuplicateName(name));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(llm_providers)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::configuration::LlmProviderType;
|
||||
|
||||
fn create_test_provider(name: &str, model: Option<String>) -> LlmProvider {
|
||||
LlmProvider {
|
||||
name: name.to_string(),
|
||||
model,
|
||||
access_key: None,
|
||||
endpoint: None,
|
||||
cluster_name: None,
|
||||
provider_interface: LlmProviderType::OpenAI,
|
||||
default: None,
|
||||
base_url_path_prefix: None,
|
||||
port: None,
|
||||
rate_limits: None,
|
||||
usage: None,
|
||||
routing_preferences: None,
|
||||
internal: None,
|
||||
stream: None,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_static_provider_lookup() {
|
||||
// Test 1: Statically defined provider - should be findable by model or provider name
|
||||
let providers = vec![create_test_provider("my-openai", Some("gpt-4".to_string()))];
|
||||
let llm_providers = LlmProviders::try_from(providers).unwrap();
|
||||
|
||||
// Should find by model name
|
||||
let result = llm_providers.get("gpt-4");
|
||||
assert!(result.is_some());
|
||||
assert_eq!(result.unwrap().name, "my-openai");
|
||||
|
||||
// Should also find by provider name
|
||||
let result = llm_providers.get("my-openai");
|
||||
assert!(result.is_some());
|
||||
assert_eq!(result.unwrap().name, "my-openai");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_wildcard_provider_with_known_model() {
|
||||
// Test 2: Wildcard provider that expands to OpenAI models
|
||||
let providers = vec![create_test_provider("openai/*", Some("*".to_string()))];
|
||||
let llm_providers = LlmProviders::try_from(providers).unwrap();
|
||||
|
||||
// Should find via expanded wildcard entry
|
||||
let result = llm_providers.get("openai/gpt-4");
|
||||
let provider = result.unwrap();
|
||||
assert_eq!(provider.name, "openai/gpt-4");
|
||||
assert_eq!(provider.model.as_ref().unwrap(), "gpt-4");
|
||||
|
||||
// Should also be able to find by just model name (from expansion)
|
||||
let result = llm_providers.get("gpt-4");
|
||||
assert_eq!(result.unwrap().model.as_ref().unwrap(), "gpt-4");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_custom_wildcard_provider_with_full_slug() {
|
||||
// Test 3: Custom wildcard provider with full slug offered
|
||||
let providers = vec![create_test_provider(
|
||||
"custom-provider/*",
|
||||
Some("*".to_string()),
|
||||
)];
|
||||
let llm_providers = LlmProviders::try_from(providers).unwrap();
|
||||
|
||||
// Should match via wildcard fallback and extract model name from slug
|
||||
let result = llm_providers.get("custom-provider/custom-model");
|
||||
let provider = result.unwrap();
|
||||
assert_eq!(provider.model.as_ref().unwrap(), "custom-model");
|
||||
|
||||
// Wildcard should be stored
|
||||
assert!(llm_providers
|
||||
.wildcard_providers
|
||||
.contains_key("custom-provider"));
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,9 +2,8 @@ use std::rc::Rc;
|
|||
|
||||
use crate::{configuration, llm_providers::LlmProviders};
|
||||
use configuration::LlmProvider;
|
||||
use rand::{seq::IteratorRandom, thread_rng};
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum ProviderHint {
|
||||
Default,
|
||||
Name(String),
|
||||
|
|
@ -22,33 +21,19 @@ impl From<String> for ProviderHint {
|
|||
pub fn get_llm_provider(
|
||||
llm_providers: &LlmProviders,
|
||||
provider_hint: Option<ProviderHint>,
|
||||
) -> Rc<LlmProvider> {
|
||||
let maybe_provider = provider_hint.and_then(|hint| match hint {
|
||||
ProviderHint::Default => llm_providers.default(),
|
||||
// FIXME: should a non-existent name in the hint be more explicit? i.e, return a BAD_REQUEST?
|
||||
ProviderHint::Name(name) => llm_providers.get(&name),
|
||||
});
|
||||
|
||||
if let Some(provider) = maybe_provider {
|
||||
return provider;
|
||||
) -> Result<Rc<LlmProvider>, String> {
|
||||
match provider_hint {
|
||||
Some(ProviderHint::Default) => llm_providers
|
||||
.default()
|
||||
.ok_or_else(|| "No default provider configured".to_string()),
|
||||
Some(ProviderHint::Name(name)) => llm_providers
|
||||
.get(&name)
|
||||
.ok_or_else(|| format!("Model '{}' not found in configured providers", name)),
|
||||
None => {
|
||||
// No hint provided - must have a default configured
|
||||
llm_providers
|
||||
.default()
|
||||
.ok_or_else(|| "No model specified and no default provider configured".to_string())
|
||||
}
|
||||
}
|
||||
|
||||
if llm_providers.default().is_some() {
|
||||
return llm_providers.default().unwrap();
|
||||
}
|
||||
|
||||
let mut rng = thread_rng();
|
||||
llm_providers
|
||||
.iter()
|
||||
.filter(|(_, provider)| {
|
||||
provider
|
||||
.model
|
||||
.as_ref()
|
||||
.map(|m| !m.starts_with("Arch"))
|
||||
.unwrap_or(true)
|
||||
})
|
||||
.choose(&mut rng)
|
||||
.expect("There should always be at least one non-Arch llm provider")
|
||||
.1
|
||||
.clone()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,6 +3,11 @@ name = "hermesllm"
|
|||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[[bin]]
|
||||
name = "fetch_models"
|
||||
path = "src/bin/fetch_models.rs"
|
||||
required-features = ["model-fetch"]
|
||||
|
||||
[dependencies]
|
||||
serde = {version = "1.0.219", features = ["derive"]}
|
||||
serde_json = "1.0.140"
|
||||
|
|
@ -12,3 +17,9 @@ aws-smithy-eventstream = "0.60"
|
|||
bytes = "1.10"
|
||||
uuid = { version = "1.11", features = ["v4"] }
|
||||
log = "0.4"
|
||||
chrono = { version = "0.4", optional = true }
|
||||
ureq = { version = "3.1", features = ["json"], optional = true }
|
||||
|
||||
[features]
|
||||
default = []
|
||||
model-fetch = ["ureq", "chrono"]
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
167
crates/hermesllm/src/bin/fetch_models.rs
Normal file
167
crates/hermesllm/src/bin/fetch_models.rs
Normal file
|
|
@ -0,0 +1,167 @@
|
|||
// Fetch latest provider models from OpenRouter and update provider_models.json
|
||||
// Usage: OPENROUTER_API_KEY=xxx cargo run --bin fetch_models
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
fn main() {
|
||||
// Default to writing in the same directory as this source file
|
||||
let default_path = std::path::Path::new(file!())
|
||||
.parent()
|
||||
.unwrap()
|
||||
.join("provider_models.json");
|
||||
|
||||
let output_path = std::env::args()
|
||||
.nth(1)
|
||||
.unwrap_or_else(|| default_path.to_string_lossy().to_string());
|
||||
|
||||
println!("Fetching latest models from OpenRouter...");
|
||||
|
||||
match fetch_openrouter_models() {
|
||||
Ok(models) => {
|
||||
let json = serde_json::to_string_pretty(&models).expect("Failed to serialize models");
|
||||
|
||||
std::fs::write(&output_path, json).expect("Failed to write provider_models.json");
|
||||
|
||||
println!(
|
||||
"✓ Successfully updated {} providers ({} models) to {}",
|
||||
models.metadata.total_providers, models.metadata.total_models, output_path
|
||||
);
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("Error fetching models: {}", e);
|
||||
eprintln!("\nMake sure OPENROUTER_API_KEY is set:");
|
||||
eprintln!(" export OPENROUTER_API_KEY=your-key-here");
|
||||
eprintln!(" cargo run --bin fetch_models");
|
||||
std::process::exit(1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct OpenRouterModel {
|
||||
id: String,
|
||||
architecture: Option<Architecture>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct Architecture {
|
||||
modality: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct OpenRouterResponse {
|
||||
data: Vec<OpenRouterModel>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct ProviderModels {
|
||||
version: String,
|
||||
source: String,
|
||||
providers: HashMap<String, Vec<String>>,
|
||||
metadata: Metadata,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct Metadata {
|
||||
total_providers: usize,
|
||||
total_models: usize,
|
||||
last_updated: String,
|
||||
}
|
||||
|
||||
fn fetch_openrouter_models() -> Result<ProviderModels, Box<dyn std::error::Error>> {
|
||||
let api_key = std::env::var("OPENROUTER_API_KEY")
|
||||
.map_err(|_| "OPENROUTER_API_KEY environment variable not set")?;
|
||||
|
||||
let response_body = ureq::get("https://openrouter.ai/api/v1/models")
|
||||
.header("Authorization", &format!("Bearer {}", api_key))
|
||||
.call()?
|
||||
.body_mut()
|
||||
.read_to_string()?;
|
||||
|
||||
let openrouter_response: OpenRouterResponse = serde_json::from_str(&response_body)?;
|
||||
|
||||
// Supported providers to include
|
||||
let supported_providers = [
|
||||
"openai",
|
||||
"anthropic",
|
||||
"mistralai",
|
||||
"deepseek",
|
||||
"google",
|
||||
"x-ai",
|
||||
"moonshotai",
|
||||
"qwen",
|
||||
"amazon",
|
||||
"z-ai",
|
||||
];
|
||||
|
||||
let mut providers: HashMap<String, Vec<String>> = HashMap::new();
|
||||
let mut total_models = 0;
|
||||
let mut filtered_modality: Vec<(String, String)> = Vec::new();
|
||||
let mut filtered_provider: Vec<(String, Option<String>)> = Vec::new();
|
||||
|
||||
for model in openrouter_response.data {
|
||||
let modality = model
|
||||
.architecture
|
||||
.as_ref()
|
||||
.and_then(|arch| arch.modality.clone());
|
||||
|
||||
// Only include text->text and text+image->text models
|
||||
if let Some(ref mod_str) = modality {
|
||||
if mod_str != "text->text" && mod_str != "text" && mod_str != "text+image->text" {
|
||||
filtered_modality.push((model.id.clone(), mod_str.clone()));
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// Extract provider from model ID (e.g., "openai/gpt-4" -> "openai")
|
||||
if let Some(provider_name) = model.id.split('/').next() {
|
||||
if supported_providers.contains(&provider_name) {
|
||||
providers
|
||||
.entry(provider_name.to_string())
|
||||
.or_default()
|
||||
.push(model.id.clone());
|
||||
total_models += 1;
|
||||
} else {
|
||||
filtered_provider.push((model.id.clone(), modality));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
println!("✅ Loaded models from {} providers:", providers.len());
|
||||
let mut sorted_providers: Vec<_> = providers.iter().collect();
|
||||
sorted_providers.sort_by_key(|(name, _)| *name);
|
||||
for (provider, models) in sorted_providers {
|
||||
println!(" • {}: {} models", provider, models.len());
|
||||
}
|
||||
|
||||
// Group filtered providers to get counts
|
||||
let mut filtered_by_provider: HashMap<String, usize> = HashMap::new();
|
||||
for (model_id, _modality) in &filtered_provider {
|
||||
if let Some(provider_name) = model_id.split('/').next() {
|
||||
*filtered_by_provider
|
||||
.entry(provider_name.to_string())
|
||||
.or_insert(0) += 1;
|
||||
}
|
||||
}
|
||||
|
||||
println!(
|
||||
"\n⏭️ Skipped {} providers ({} models total)",
|
||||
filtered_by_provider.len(),
|
||||
filtered_provider.len()
|
||||
);
|
||||
println!();
|
||||
|
||||
let total_providers = providers.len();
|
||||
|
||||
Ok(ProviderModels {
|
||||
version: "1.0".to_string(),
|
||||
source: "openrouter".to_string(),
|
||||
providers,
|
||||
metadata: Metadata {
|
||||
total_providers,
|
||||
total_models,
|
||||
last_updated: chrono::Utc::now().to_rfc3339(),
|
||||
},
|
||||
})
|
||||
}
|
||||
236
crates/hermesllm/src/bin/provider_models.json
Normal file
236
crates/hermesllm/src/bin/provider_models.json
Normal file
|
|
@ -0,0 +1,236 @@
|
|||
{
|
||||
"version": "1.0",
|
||||
"source": "openrouter",
|
||||
"providers": {
|
||||
"openai": [
|
||||
"openai/gpt-5.2-codex",
|
||||
"openai/gpt-5.2-chat",
|
||||
"openai/gpt-5.2-pro",
|
||||
"openai/gpt-5.2",
|
||||
"openai/gpt-5.1-codex-max",
|
||||
"openai/gpt-5.1",
|
||||
"openai/gpt-5.1-chat",
|
||||
"openai/gpt-5.1-codex",
|
||||
"openai/gpt-5.1-codex-mini",
|
||||
"openai/gpt-oss-safeguard-20b",
|
||||
"openai/o3-deep-research",
|
||||
"openai/o4-mini-deep-research",
|
||||
"openai/gpt-5-pro",
|
||||
"openai/gpt-5-codex",
|
||||
"openai/gpt-4o-audio-preview",
|
||||
"openai/gpt-5-chat",
|
||||
"openai/gpt-5",
|
||||
"openai/gpt-5-mini",
|
||||
"openai/gpt-5-nano",
|
||||
"openai/gpt-oss-120b:free",
|
||||
"openai/gpt-oss-120b",
|
||||
"openai/gpt-oss-120b:exacto",
|
||||
"openai/gpt-oss-20b:free",
|
||||
"openai/gpt-oss-20b",
|
||||
"openai/o3-pro",
|
||||
"openai/o4-mini-high",
|
||||
"openai/o3",
|
||||
"openai/o4-mini",
|
||||
"openai/gpt-4.1",
|
||||
"openai/gpt-4.1-mini",
|
||||
"openai/gpt-4.1-nano",
|
||||
"openai/o1-pro",
|
||||
"openai/gpt-4o-mini-search-preview",
|
||||
"openai/gpt-4o-search-preview",
|
||||
"openai/o3-mini-high",
|
||||
"openai/o3-mini",
|
||||
"openai/o1",
|
||||
"openai/gpt-4o-2024-11-20",
|
||||
"openai/chatgpt-4o-latest",
|
||||
"openai/gpt-4o-2024-08-06",
|
||||
"openai/gpt-4o-mini-2024-07-18",
|
||||
"openai/gpt-4o-mini",
|
||||
"openai/gpt-4o-2024-05-13",
|
||||
"openai/gpt-4o",
|
||||
"openai/gpt-4o:extended",
|
||||
"openai/gpt-4-turbo",
|
||||
"openai/gpt-3.5-turbo-0613",
|
||||
"openai/gpt-4-turbo-preview",
|
||||
"openai/gpt-4-1106-preview",
|
||||
"openai/gpt-3.5-turbo-instruct",
|
||||
"openai/gpt-3.5-turbo-16k",
|
||||
"openai/gpt-4-0314",
|
||||
"openai/gpt-4",
|
||||
"openai/gpt-3.5-turbo"
|
||||
],
|
||||
"mistralai": [
|
||||
"mistralai/mistral-small-creative",
|
||||
"mistralai/devstral-2512:free",
|
||||
"mistralai/devstral-2512",
|
||||
"mistralai/ministral-14b-2512",
|
||||
"mistralai/ministral-8b-2512",
|
||||
"mistralai/ministral-3b-2512",
|
||||
"mistralai/mistral-large-2512",
|
||||
"mistralai/voxtral-small-24b-2507",
|
||||
"mistralai/mistral-medium-3.1",
|
||||
"mistralai/codestral-2508",
|
||||
"mistralai/devstral-medium",
|
||||
"mistralai/devstral-small",
|
||||
"mistralai/mistral-small-3.2-24b-instruct",
|
||||
"mistralai/mistral-medium-3",
|
||||
"mistralai/mistral-small-3.1-24b-instruct:free",
|
||||
"mistralai/mistral-small-3.1-24b-instruct",
|
||||
"mistralai/mistral-saba",
|
||||
"mistralai/mistral-small-24b-instruct-2501",
|
||||
"mistralai/mistral-large-2411",
|
||||
"mistralai/mistral-large-2407",
|
||||
"mistralai/pixtral-large-2411",
|
||||
"mistralai/ministral-8b",
|
||||
"mistralai/ministral-3b",
|
||||
"mistralai/pixtral-12b",
|
||||
"mistralai/mistral-nemo",
|
||||
"mistralai/mistral-7b-instruct",
|
||||
"mistralai/mistral-7b-instruct-v0.3",
|
||||
"mistralai/mixtral-8x22b-instruct",
|
||||
"mistralai/mistral-large",
|
||||
"mistralai/mistral-tiny",
|
||||
"mistralai/mistral-7b-instruct-v0.2",
|
||||
"mistralai/mixtral-8x7b-instruct",
|
||||
"mistralai/mistral-7b-instruct-v0.1"
|
||||
],
|
||||
"qwen": [
|
||||
"qwen/qwen3-vl-32b-instruct",
|
||||
"qwen/qwen3-vl-8b-thinking",
|
||||
"qwen/qwen3-vl-8b-instruct",
|
||||
"qwen/qwen3-vl-30b-a3b-thinking",
|
||||
"qwen/qwen3-vl-30b-a3b-instruct",
|
||||
"qwen/qwen3-vl-235b-a22b-thinking",
|
||||
"qwen/qwen3-vl-235b-a22b-instruct",
|
||||
"qwen/qwen3-max",
|
||||
"qwen/qwen3-coder-plus",
|
||||
"qwen/qwen3-coder-flash",
|
||||
"qwen/qwen3-next-80b-a3b-thinking",
|
||||
"qwen/qwen3-next-80b-a3b-instruct:free",
|
||||
"qwen/qwen3-next-80b-a3b-instruct",
|
||||
"qwen/qwen-plus-2025-07-28",
|
||||
"qwen/qwen-plus-2025-07-28:thinking",
|
||||
"qwen/qwen3-30b-a3b-thinking-2507",
|
||||
"qwen/qwen3-coder-30b-a3b-instruct",
|
||||
"qwen/qwen3-30b-a3b-instruct-2507",
|
||||
"qwen/qwen3-235b-a22b-thinking-2507",
|
||||
"qwen/qwen3-coder:free",
|
||||
"qwen/qwen3-coder",
|
||||
"qwen/qwen3-coder:exacto",
|
||||
"qwen/qwen3-235b-a22b-2507",
|
||||
"qwen/qwen3-4b:free",
|
||||
"qwen/qwen3-30b-a3b",
|
||||
"qwen/qwen3-8b",
|
||||
"qwen/qwen3-14b",
|
||||
"qwen/qwen3-32b",
|
||||
"qwen/qwen3-235b-a22b",
|
||||
"qwen/qwen2.5-coder-7b-instruct",
|
||||
"qwen/qwen2.5-vl-32b-instruct",
|
||||
"qwen/qwq-32b",
|
||||
"qwen/qwen-vl-plus",
|
||||
"qwen/qwen-vl-max",
|
||||
"qwen/qwen-turbo",
|
||||
"qwen/qwen2.5-vl-72b-instruct",
|
||||
"qwen/qwen-plus",
|
||||
"qwen/qwen-max",
|
||||
"qwen/qwen-2.5-coder-32b-instruct",
|
||||
"qwen/qwen-2.5-7b-instruct",
|
||||
"qwen/qwen-2.5-72b-instruct",
|
||||
"qwen/qwen-2.5-vl-7b-instruct:free",
|
||||
"qwen/qwen-2.5-vl-7b-instruct"
|
||||
],
|
||||
"z-ai": [
|
||||
"z-ai/glm-4.7",
|
||||
"z-ai/glm-4.6v",
|
||||
"z-ai/glm-4.6",
|
||||
"z-ai/glm-4.6:exacto",
|
||||
"z-ai/glm-4.5v",
|
||||
"z-ai/glm-4.5",
|
||||
"z-ai/glm-4.5-air:free",
|
||||
"z-ai/glm-4.5-air",
|
||||
"z-ai/glm-4-32b"
|
||||
],
|
||||
"moonshotai": [
|
||||
"moonshotai/kimi-k2-thinking",
|
||||
"moonshotai/kimi-k2-0905",
|
||||
"moonshotai/kimi-k2-0905:exacto",
|
||||
"moonshotai/kimi-k2:free",
|
||||
"moonshotai/kimi-k2",
|
||||
"moonshotai/kimi-dev-72b"
|
||||
],
|
||||
"anthropic": [
|
||||
"anthropic/claude-opus-4.5",
|
||||
"anthropic/claude-haiku-4.5",
|
||||
"anthropic/claude-sonnet-4.5",
|
||||
"anthropic/claude-opus-4.1",
|
||||
"anthropic/claude-opus-4",
|
||||
"anthropic/claude-sonnet-4",
|
||||
"anthropic/claude-3.7-sonnet:thinking",
|
||||
"anthropic/claude-3.7-sonnet",
|
||||
"anthropic/claude-3.5-haiku",
|
||||
"anthropic/claude-3.5-sonnet",
|
||||
"anthropic/claude-3-haiku"
|
||||
],
|
||||
"google": [
|
||||
"google/gemini-3-flash-preview",
|
||||
"google/gemini-3-pro-preview",
|
||||
"google/gemini-2.5-flash-preview-09-2025",
|
||||
"google/gemini-2.5-flash-lite-preview-09-2025",
|
||||
"google/gemini-2.5-flash-lite",
|
||||
"google/gemma-3n-e2b-it:free",
|
||||
"google/gemini-2.5-flash",
|
||||
"google/gemini-2.5-pro",
|
||||
"google/gemini-2.5-pro-preview",
|
||||
"google/gemma-3n-e4b-it:free",
|
||||
"google/gemma-3n-e4b-it",
|
||||
"google/gemini-2.5-pro-preview-05-06",
|
||||
"google/gemma-3-4b-it:free",
|
||||
"google/gemma-3-4b-it",
|
||||
"google/gemma-3-12b-it:free",
|
||||
"google/gemma-3-12b-it",
|
||||
"google/gemma-3-27b-it:free",
|
||||
"google/gemma-3-27b-it",
|
||||
"google/gemini-2.0-flash-lite-001",
|
||||
"google/gemini-2.0-flash-001",
|
||||
"google/gemini-2.0-flash-exp:free",
|
||||
"google/gemma-2-27b-it",
|
||||
"google/gemma-2-9b-it"
|
||||
],
|
||||
"amazon": [
|
||||
"amazon/nova-2-lite-v1",
|
||||
"amazon/nova-premier-v1",
|
||||
"amazon/nova-lite-v1",
|
||||
"amazon/nova-micro-v1",
|
||||
"amazon/nova-pro-v1"
|
||||
],
|
||||
"deepseek": [
|
||||
"deepseek/deepseek-v3.2-speciale",
|
||||
"deepseek/deepseek-v3.2",
|
||||
"deepseek/deepseek-v3.2-exp",
|
||||
"deepseek/deepseek-v3.1-terminus:exacto",
|
||||
"deepseek/deepseek-v3.1-terminus",
|
||||
"deepseek/deepseek-chat-v3.1",
|
||||
"deepseek/deepseek-r1-0528:free",
|
||||
"deepseek/deepseek-r1-0528",
|
||||
"deepseek/deepseek-chat-v3-0324",
|
||||
"deepseek/deepseek-r1-distill-qwen-32b",
|
||||
"deepseek/deepseek-r1-distill-llama-70b",
|
||||
"deepseek/deepseek-r1",
|
||||
"deepseek/deepseek-chat"
|
||||
],
|
||||
"x-ai": [
|
||||
"x-ai/grok-4.1-fast",
|
||||
"x-ai/grok-4-fast",
|
||||
"x-ai/grok-code-fast-1",
|
||||
"x-ai/grok-4",
|
||||
"x-ai/grok-3-mini",
|
||||
"x-ai/grok-3",
|
||||
"x-ai/grok-3-mini-beta",
|
||||
"x-ai/grok-3-beta"
|
||||
]
|
||||
},
|
||||
"metadata": {
|
||||
"total_providers": 10,
|
||||
"total_models": 205,
|
||||
"last_updated": "2026-01-16T20:30:00.806165+00:00"
|
||||
}
|
||||
}
|
||||
|
|
@ -29,10 +29,27 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn test_provider_id_conversion() {
|
||||
assert_eq!(ProviderId::from("openai"), ProviderId::OpenAI);
|
||||
assert_eq!(ProviderId::from("mistral"), ProviderId::Mistral);
|
||||
assert_eq!(ProviderId::from("groq"), ProviderId::Groq);
|
||||
assert_eq!(ProviderId::from("arch"), ProviderId::Arch);
|
||||
assert_eq!(ProviderId::try_from("openai").unwrap(), ProviderId::OpenAI);
|
||||
assert_eq!(
|
||||
ProviderId::try_from("mistral").unwrap(),
|
||||
ProviderId::Mistral
|
||||
);
|
||||
assert_eq!(ProviderId::try_from("groq").unwrap(), ProviderId::Groq);
|
||||
assert_eq!(ProviderId::try_from("arch").unwrap(), ProviderId::Arch);
|
||||
|
||||
// Test aliases
|
||||
assert_eq!(ProviderId::try_from("google").unwrap(), ProviderId::Gemini);
|
||||
assert_eq!(
|
||||
ProviderId::try_from("together").unwrap(),
|
||||
ProviderId::TogetherAI
|
||||
);
|
||||
assert_eq!(
|
||||
ProviderId::try_from("amazon").unwrap(),
|
||||
ProviderId::AmazonBedrock
|
||||
);
|
||||
|
||||
// Test error case
|
||||
assert!(ProviderId::try_from("unknown_provider").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
|
|||
|
|
@ -1,6 +1,36 @@
|
|||
use crate::apis::{AmazonBedrockApi, AnthropicApi, OpenAIApi};
|
||||
use crate::clients::endpoints::{SupportedAPIsFromClient, SupportedUpstreamAPIs};
|
||||
use std::collections::HashMap;
|
||||
use std::fmt::Display;
|
||||
use std::sync::OnceLock;
|
||||
|
||||
static PROVIDER_MODELS_JSON: &str = include_str!("../bin/provider_models.json");
|
||||
|
||||
fn load_provider_models() -> &'static HashMap<String, Vec<String>> {
|
||||
static MODELS: OnceLock<HashMap<String, Vec<String>>> = OnceLock::new();
|
||||
MODELS.get_or_init(|| {
|
||||
let data: serde_json::Value = serde_json::from_str(PROVIDER_MODELS_JSON)
|
||||
.expect("Failed to parse provider_models.json");
|
||||
|
||||
let providers = data
|
||||
.get("providers")
|
||||
.expect("Missing 'providers' key")
|
||||
.as_object()
|
||||
.expect("'providers' must be an object");
|
||||
|
||||
let mut result = HashMap::new();
|
||||
for (provider, models) in providers {
|
||||
let model_list: Vec<String> = models
|
||||
.as_array()
|
||||
.expect("Models must be an array")
|
||||
.iter()
|
||||
.map(|m| m.as_str().expect("Model must be a string").to_string())
|
||||
.collect();
|
||||
result.insert(provider.clone(), model_list);
|
||||
}
|
||||
result
|
||||
})
|
||||
}
|
||||
|
||||
/// Provider identifier enum - simple enum for identifying providers
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
|
|
@ -23,31 +53,70 @@ pub enum ProviderId {
|
|||
AmazonBedrock,
|
||||
}
|
||||
|
||||
impl From<&str> for ProviderId {
|
||||
fn from(value: &str) -> Self {
|
||||
impl TryFrom<&str> for ProviderId {
|
||||
type Error = String;
|
||||
|
||||
fn try_from(value: &str) -> Result<Self, Self::Error> {
|
||||
match value.to_lowercase().as_str() {
|
||||
"openai" => ProviderId::OpenAI,
|
||||
"mistral" => ProviderId::Mistral,
|
||||
"deepseek" => ProviderId::Deepseek,
|
||||
"groq" => ProviderId::Groq,
|
||||
"gemini" => ProviderId::Gemini,
|
||||
"anthropic" => ProviderId::Anthropic,
|
||||
"github" => ProviderId::GitHub,
|
||||
"arch" => ProviderId::Arch,
|
||||
"azure_openai" => ProviderId::AzureOpenAI,
|
||||
"xai" => ProviderId::XAI,
|
||||
"together_ai" => ProviderId::TogetherAI,
|
||||
"ollama" => ProviderId::Ollama,
|
||||
"moonshotai" => ProviderId::Moonshotai,
|
||||
"zhipu" => ProviderId::Zhipu,
|
||||
"qwen" => ProviderId::Qwen, // alias for Qwen
|
||||
"amazon_bedrock" => ProviderId::AmazonBedrock,
|
||||
_ => panic!("Unknown provider: {}", value),
|
||||
"openai" => Ok(ProviderId::OpenAI),
|
||||
"mistral" => Ok(ProviderId::Mistral),
|
||||
"deepseek" => Ok(ProviderId::Deepseek),
|
||||
"groq" => Ok(ProviderId::Groq),
|
||||
"gemini" => Ok(ProviderId::Gemini),
|
||||
"google" => Ok(ProviderId::Gemini), // alias
|
||||
"anthropic" => Ok(ProviderId::Anthropic),
|
||||
"github" => Ok(ProviderId::GitHub),
|
||||
"arch" => Ok(ProviderId::Arch),
|
||||
"azure_openai" => Ok(ProviderId::AzureOpenAI),
|
||||
"xai" => Ok(ProviderId::XAI),
|
||||
"together_ai" => Ok(ProviderId::TogetherAI),
|
||||
"together" => Ok(ProviderId::TogetherAI), // alias
|
||||
"ollama" => Ok(ProviderId::Ollama),
|
||||
"moonshotai" => Ok(ProviderId::Moonshotai),
|
||||
"zhipu" => Ok(ProviderId::Zhipu),
|
||||
"qwen" => Ok(ProviderId::Qwen),
|
||||
"amazon_bedrock" => Ok(ProviderId::AmazonBedrock),
|
||||
"amazon" => Ok(ProviderId::AmazonBedrock), // alias
|
||||
_ => Err(format!("Unknown provider: {}", value)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ProviderId {
|
||||
/// Get all available models for this provider
|
||||
/// Returns model names without the provider prefix (e.g., "gpt-4" not "openai/gpt-4")
|
||||
pub fn models(&self) -> Vec<String> {
|
||||
let provider_key = match self {
|
||||
ProviderId::AmazonBedrock => "amazon",
|
||||
ProviderId::AzureOpenAI => "openai",
|
||||
ProviderId::TogetherAI => "together",
|
||||
ProviderId::Gemini => "google",
|
||||
ProviderId::OpenAI => "openai",
|
||||
ProviderId::Anthropic => "anthropic",
|
||||
ProviderId::Mistral => "mistralai",
|
||||
ProviderId::Deepseek => "deepseek",
|
||||
ProviderId::Groq => "groq",
|
||||
ProviderId::XAI => "x-ai",
|
||||
ProviderId::Moonshotai => "moonshotai",
|
||||
ProviderId::Zhipu => "z-ai",
|
||||
ProviderId::Qwen => "qwen",
|
||||
_ => return Vec::new(),
|
||||
};
|
||||
|
||||
load_provider_models()
|
||||
.get(provider_key)
|
||||
.map(|models| {
|
||||
models
|
||||
.iter()
|
||||
.filter_map(|model| {
|
||||
// Strip provider prefix (e.g., "openai/gpt-4" -> "gpt-4")
|
||||
model.split_once('/').map(|(_, name)| name.to_string())
|
||||
})
|
||||
.collect()
|
||||
})
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
/// Given a client API, return the compatible upstream API for this provider
|
||||
pub fn compatible_api_for_client(
|
||||
&self,
|
||||
|
|
@ -169,3 +238,102 @@ impl Display for ProviderId {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_models_loaded_from_json() {
|
||||
// Test that we can load models for each supported provider
|
||||
let openai_models = ProviderId::OpenAI.models();
|
||||
assert!(!openai_models.is_empty(), "OpenAI should have models");
|
||||
|
||||
let anthropic_models = ProviderId::Anthropic.models();
|
||||
assert!(!anthropic_models.is_empty(), "Anthropic should have models");
|
||||
|
||||
let mistral_models = ProviderId::Mistral.models();
|
||||
assert!(!mistral_models.is_empty(), "Mistral should have models");
|
||||
|
||||
let deepseek_models = ProviderId::Deepseek.models();
|
||||
assert!(!deepseek_models.is_empty(), "Deepseek should have models");
|
||||
|
||||
let gemini_models = ProviderId::Gemini.models();
|
||||
assert!(!gemini_models.is_empty(), "Gemini should have models");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_model_names_without_provider_prefix() {
|
||||
// Test that model names don't include the provider/ prefix
|
||||
let openai_models = ProviderId::OpenAI.models();
|
||||
for model in &openai_models {
|
||||
assert!(
|
||||
!model.contains('/'),
|
||||
"Model name '{}' should not contain provider prefix",
|
||||
model
|
||||
);
|
||||
}
|
||||
|
||||
let anthropic_models = ProviderId::Anthropic.models();
|
||||
for model in &anthropic_models {
|
||||
assert!(
|
||||
!model.contains('/'),
|
||||
"Model name '{}' should not contain provider prefix",
|
||||
model
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_specific_models_exist() {
|
||||
// Test that specific well-known models are present
|
||||
let openai_models = ProviderId::OpenAI.models();
|
||||
let has_gpt4 = openai_models.iter().any(|m| m.contains("gpt-4"));
|
||||
assert!(has_gpt4, "OpenAI models should include GPT-4 variants");
|
||||
|
||||
let anthropic_models = ProviderId::Anthropic.models();
|
||||
let has_claude = anthropic_models.iter().any(|m| m.contains("claude"));
|
||||
assert!(
|
||||
has_claude,
|
||||
"Anthropic models should include Claude variants"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_unsupported_providers_return_empty() {
|
||||
// Providers without models should return empty vec
|
||||
let github_models = ProviderId::GitHub.models();
|
||||
assert!(
|
||||
github_models.is_empty(),
|
||||
"GitHub should return empty models list"
|
||||
);
|
||||
|
||||
let ollama_models = ProviderId::Ollama.models();
|
||||
assert!(
|
||||
ollama_models.is_empty(),
|
||||
"Ollama should return empty models list"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_provider_name_mapping() {
|
||||
// Test that provider key mappings work correctly
|
||||
let xai_models = ProviderId::XAI.models();
|
||||
assert!(
|
||||
!xai_models.is_empty(),
|
||||
"XAI should have models (mapped to x-ai)"
|
||||
);
|
||||
|
||||
let zhipu_models = ProviderId::Zhipu.models();
|
||||
assert!(
|
||||
!zhipu_models.is_empty(),
|
||||
"Zhipu should have models (mapped to z-ai)"
|
||||
);
|
||||
|
||||
let amazon_models = ProviderId::AmazonBedrock.models();
|
||||
assert!(
|
||||
!amazon_models.is_empty(),
|
||||
"AmazonBedrock should have models (mapped to amazon)"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
use hermesllm::clients::endpoints::SupportedUpstreamAPIs;
|
||||
use http::StatusCode;
|
||||
use log::{debug, info, warn};
|
||||
use log::{debug, error, info, warn};
|
||||
use proxy_wasm::hostcalls::get_current_time;
|
||||
use proxy_wasm::traits::*;
|
||||
use proxy_wasm::types::*;
|
||||
|
|
@ -128,16 +128,23 @@ impl StreamContext {
|
|||
}
|
||||
}
|
||||
|
||||
fn select_llm_provider(&mut self) {
|
||||
fn select_llm_provider(&mut self) -> Result<(), String> {
|
||||
let provider_hint = self
|
||||
.get_http_request_header(ARCH_PROVIDER_HINT_HEADER)
|
||||
.map(|llm_name| llm_name.into());
|
||||
|
||||
// info!("llm_providers: {:?}", self.llm_providers);
|
||||
self.llm_provider = Some(routing::get_llm_provider(
|
||||
&self.llm_providers,
|
||||
provider_hint,
|
||||
));
|
||||
let provider =
|
||||
routing::get_llm_provider(&self.llm_providers, provider_hint).map_err(|err| {
|
||||
error!(
|
||||
"[PLANO_REQ_ID:{}] PROVIDER_SELECTION_FAILED: Hint='None' Error='{}'",
|
||||
self.request_identifier(),
|
||||
err
|
||||
);
|
||||
err
|
||||
})?;
|
||||
|
||||
self.llm_provider = Some(provider);
|
||||
|
||||
info!(
|
||||
"[PLANO_REQ_ID:{}] PROVIDER_SELECTION: Hint='{}' -> Selected='{}'",
|
||||
|
|
@ -146,6 +153,8 @@ impl StreamContext {
|
|||
.unwrap_or("none".to_string()),
|
||||
self.llm_provider.as_ref().unwrap().name
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn modify_auth_headers(&mut self) -> Result<(), ServerError> {
|
||||
|
|
@ -747,7 +756,15 @@ impl HttpContext for StreamContext {
|
|||
|
||||
// let routing_header_value = self.get_http_request_header(ARCH_ROUTING_HEADER);
|
||||
|
||||
self.select_llm_provider();
|
||||
if let Err(err) = self.select_llm_provider() {
|
||||
self.send_http_response(
|
||||
400,
|
||||
vec![],
|
||||
Some(format!(r#"{{"error": "{}"}}"#, err).as_bytes()),
|
||||
);
|
||||
return Action::Continue;
|
||||
}
|
||||
|
||||
// Check if this is a supported API endpoint
|
||||
if SupportedAPIsFromClient::from_endpoint(&request_path).is_none() {
|
||||
self.send_http_response(404, vec![], Some(b"Unsupported endpoint"));
|
||||
|
|
|
|||
26
demos/use_cases/wildcard_providers/config.yaml
Normal file
26
demos/use_cases/wildcard_providers/config.yaml
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
version: v0.3.0
|
||||
|
||||
listeners:
|
||||
# Model listener for direct LLM access
|
||||
- type: model
|
||||
name: llms
|
||||
address: 0.0.0.0
|
||||
port: 12000
|
||||
|
||||
model_providers:
|
||||
# OpenAI - support all models via wildcard
|
||||
- model: openai/*
|
||||
access_key: $OPENAI_API_KEY
|
||||
|
||||
# Anthropic - support all Claude models
|
||||
- model: anthropic/*
|
||||
access_key: $ANTHROPIC_API_KEY
|
||||
|
||||
- model: xai/*
|
||||
access_key: $GROK_API_KEY
|
||||
|
||||
|
||||
# Custom internal LLM provider
|
||||
# Note: Requires base_url and provider_interface for unknown providers
|
||||
- model: ollama/*
|
||||
base_url: https://llm.internal.company.com
|
||||
Loading…
Add table
Add a link
Reference in a new issue