adding support for wildcard model providers

This commit is contained in:
Salman Paracha 2026-01-16 14:53:02 -08:00
parent 86cf8ccdaa
commit 34711c6f9d
14 changed files with 1027 additions and 1823 deletions

View file

@ -255,7 +255,8 @@ impl LlmProviderType {
/// Returns the `hermesllm::ProviderId` that corresponds to this `LlmProviderType`.
///
/// The conversion goes through the string form of `self` (`self.to_string()`).
/// Used with the new function-based hermesllm API.
///
/// # Panics
///
/// The `try_from(..).expect(..)` form panics if `ProviderId::try_from` rejects the
/// string form of `self`.
pub fn to_provider_id(&self) -> hermesllm::ProviderId {
// NOTE(review): the next two expressions look like the removed (`from`) and the added
// (`try_from` + `expect`) lines of a single diff hunk shown together. Only one of them
// can be the function's tail expression - confirm against the checked-out file.
hermesllm::ProviderId::from(self.to_string().as_str())
hermesllm::ProviderId::try_from(self.to_string().as_str())
.expect("LlmProviderType should always map to a valid ProviderId")
}
}

View file

@ -1,4 +1,5 @@
use crate::configuration::LlmProvider;
use hermesllm::providers::ProviderId;
use std::collections::HashMap;
use std::rc::Rc;
@ -6,6 +7,9 @@ use std::rc::Rc;
/// Registry of configured LLM providers with exact-key and wildcard lookup.
pub struct LlmProviders {
/// Providers addressable by an exact string key. Construction inserts entries under
/// the provider name, under the model id, and - for expanded wildcard providers -
/// under both `prefix/model` and the bare model name.
providers: HashMap<String, Rc<LlmProvider>>,
/// The default provider, if any.
/// NOTE(review): presumably the single provider whose `default` flag is set (a second
/// one is rejected with `MoreThanOneDefault`); the assignment itself is not visible here.
default: Option<Rc<LlmProvider>>,
/// Wildcard providers: maps provider prefix to base provider config
/// e.g., "openai" -> LlmProvider for "openai/*"
/// The key is the provider name with a trailing "/*" (or "*") trimmed off.
wildcard_providers: HashMap<String, Rc<LlmProvider>>,
}
impl LlmProviders {
@ -18,7 +22,36 @@ impl LlmProviders {
}
/// Looks up the provider registered under `name`.
///
/// Resolution order:
/// 1. An exact match on `name` in `providers`.
/// 2. If `name` has the form `prefix/model` (split at the first `/`), a match on the
///    bare `model` part, which covers entries registered under the model name only.
/// 3. A wildcard provider registered for `prefix`. In that case a fresh copy of the
///    wildcard's config is returned with `model` set to the requested model.
///
/// Returns `None` when nothing matches, including when `name` has no `/` and is not
/// an exact key.
pub fn get(&self, name: &str) -> Option<Rc<LlmProvider>> {
    // An exact hit always wins, whether `name` is a provider name, a model id,
    // or an expanded `prefix/model` entry.
    if let Some(provider) = self.providers.get(name) {
        return Some(Rc::clone(provider));
    }

    // Only the first '/' separates the prefix, so the model part may itself contain '/'.
    let (provider_prefix, model_name) = name.split_once('/')?;

    // Re-joining prefix and model would reproduce `name`, which already missed above,
    // so the only other exact key worth trying is the bare model name.
    if let Some(provider) = self.providers.get(model_name) {
        return Some(Rc::clone(provider));
    }

    // Fall back to a wildcard match (e.g., "openai/*"): clone the base config and pin
    // it to the model taken from the slug.
    self.wildcard_providers.get(provider_prefix).map(|wildcard| {
        let mut specific_provider = (**wildcard).clone();
        specific_provider.model = Some(model_name.to_string());
        Rc::new(specific_provider)
    })
}
}
@ -43,10 +76,12 @@ impl TryFrom<Vec<LlmProvider>> for LlmProviders {
let mut llm_providers = LlmProviders {
providers: HashMap::new(),
default: None,
wildcard_providers: HashMap::new(),
};
for llm_provider in llm_providers_config {
let llm_provider: Rc<LlmProvider> = Rc::new(llm_provider);
if llm_provider.default.unwrap_or_default() {
match llm_providers.default {
Some(_) => return Err(LlmProvidersNewError::MoreThanOneDefault),
@ -54,27 +89,167 @@ impl TryFrom<Vec<LlmProvider>> for LlmProviders {
}
}
// Insert and check that there is no other provider with the same name.
let name = llm_provider.name.clone();
if llm_providers
.providers
.insert(name.clone(), Rc::clone(&llm_provider))
.is_some()
{
return Err(LlmProvidersNewError::DuplicateName(name));
}
// also add model_id as key for provider lookup
if let Some(model) = llm_provider.model.clone() {
// Check if this is a wildcard provider (model is "*" or ends with "/*")
let is_wildcard = llm_provider
.model
.as_ref()
.map(|m| m == "*" || m.ends_with("/*"))
.unwrap_or(false);
if is_wildcard {
// Extract provider prefix from name
// e.g., "openai/*" -> "openai"
let provider_prefix = name.trim_end_matches("/*").trim_end_matches('*');
// For wildcard providers, we:
// 1. Store the base config in wildcard_providers for runtime matching
// 2. Optionally expand to all known models if available
llm_providers
.wildcard_providers
.insert(provider_prefix.to_string(), Rc::clone(&llm_provider));
// Try to expand wildcard using ProviderId models
if let Ok(provider_id) = ProviderId::try_from(provider_prefix) {
let models = provider_id.models();
if !models.is_empty() {
log::info!(
"Expanding wildcard provider '{}' to {} models",
provider_prefix,
models.len()
);
// Create a provider entry for each model
for model_name in models {
let full_model_id = format!("{}/{}", provider_prefix, model_name);
// Create a new provider with the specific model
let mut expanded_provider = (*llm_provider).clone();
expanded_provider.model = Some(model_name.clone());
expanded_provider.name = full_model_id.clone();
let expanded_rc = Rc::new(expanded_provider);
// Insert with full model ID as key
llm_providers
.providers
.insert(full_model_id.clone(), Rc::clone(&expanded_rc));
// Also insert with just model name for backward compatibility
llm_providers.providers.insert(model_name, expanded_rc);
}
}
} else {
log::warn!(
"Wildcard provider '{}' specified but no models found in registry. \
Will match dynamically at runtime.",
provider_prefix
);
}
} else {
// Non-wildcard provider - original behavior
if llm_providers
.providers
.insert(model, llm_provider)
.insert(name.clone(), Rc::clone(&llm_provider))
.is_some()
{
return Err(LlmProvidersNewError::DuplicateName(name));
}
// also add model_id as key for provider lookup
if let Some(model) = llm_provider.model.clone() {
if llm_providers
.providers
.insert(model, llm_provider)
.is_some()
{
return Err(LlmProvidersNewError::DuplicateName(name));
}
}
}
}
Ok(llm_providers)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::configuration::LlmProviderType;

    /// Builds a bare-bones provider using the OpenAI interface; everything except
    /// `name` and `model` is left unset.
    fn create_test_provider(name: &str, model: Option<String>) -> LlmProvider {
        LlmProvider {
            name: name.to_owned(),
            model,
            provider_interface: LlmProviderType::OpenAI,
            access_key: None,
            endpoint: None,
            cluster_name: None,
            default: None,
            base_url_path_prefix: None,
            port: None,
            rate_limits: None,
            usage: None,
            routing_preferences: None,
            internal: None,
            stream: None,
        }
    }

    /// A statically configured provider resolves by its model id and by its own name.
    #[test]
    fn test_static_provider_lookup() {
        let config = vec![create_test_provider(
            "my-openai",
            Some(String::from("gpt-4")),
        )];
        let registry = LlmProviders::try_from(config).unwrap();

        // Model id first, then provider name - both must land on the same entry.
        for key in ["gpt-4", "my-openai"] {
            let found = registry.get(key);
            assert!(found.is_some());
            assert_eq!(found.unwrap().name, "my-openai");
        }
    }

    /// A wildcard over a provider with known models is expanded into concrete entries.
    #[test]
    fn test_wildcard_provider_with_known_model() {
        let config = vec![create_test_provider("openai/*", Some(String::from("*")))];
        let registry = LlmProviders::try_from(config).unwrap();

        // The full slug hits the expanded entry, which carries the slug as its name.
        let by_slug = registry.get("openai/gpt-4").unwrap();
        assert_eq!(by_slug.name, "openai/gpt-4");
        assert_eq!(by_slug.model.as_deref().unwrap(), "gpt-4");

        // The bare model name is registered by the expansion as well.
        let by_model = registry.get("gpt-4").unwrap();
        assert_eq!(by_model.model.as_deref().unwrap(), "gpt-4");
    }

    /// A wildcard for a custom provider is matched at lookup time from the slug.
    #[test]
    fn test_custom_wildcard_provider_with_full_slug() {
        let config = vec![create_test_provider(
            "custom-provider/*",
            Some(String::from("*")),
        )];
        let registry = LlmProviders::try_from(config).unwrap();

        // The wildcard fallback takes the model name from the part after the '/'.
        let resolved = registry.get("custom-provider/custom-model").unwrap();
        assert_eq!(resolved.model.as_deref().unwrap(), "custom-model");

        // The base config itself is kept under the bare prefix.
        let wildcard_registered = registry
            .wildcard_providers
            .contains_key("custom-provider");
        assert!(wildcard_registered);
    }
}

View file

@ -2,9 +2,8 @@ use std::rc::Rc;
use crate::{configuration, llm_providers::LlmProviders};
use configuration::LlmProvider;
use rand::{seq::IteratorRandom, thread_rng};
#[derive(Debug)]
#[derive(Debug, Clone)]
pub enum ProviderHint {
Default,
Name(String),
@ -22,33 +21,19 @@ impl From<String> for ProviderHint {
/// Selects the provider to use for a request, guided by an optional `provider_hint`.
///
/// NOTE(review): this span shows two versions of the function interleaved, as in a diff
/// rendered without +/- markers: one returning `Rc<LlmProvider>` and one returning
/// `Result<Rc<LlmProvider>, String>`. Confirm which lines are live in the actual file.
///
/// `Result` variant (the `match provider_hint` body):
/// - `Some(ProviderHint::Default)`: the configured default, or an error if none is set.
/// - `Some(ProviderHint::Name(name))`: the result of `llm_providers.get(&name)`, or an
///   error naming the model that was not found.
/// - `None`: the configured default, or an error if none is set.
///
/// `Rc` variant: tries the hint, then the configured default, then a randomly chosen
/// provider whose model does not start with "Arch"; the final `expect` panics when no
/// such provider exists.
pub fn get_llm_provider(
llm_providers: &LlmProviders,
provider_hint: Option<ProviderHint>,
) -> Rc<LlmProvider> {
// Resolve the hint, if one was given, to a provider.
let maybe_provider = provider_hint.and_then(|hint| match hint {
ProviderHint::Default => llm_providers.default(),
// FIXME: should a non-existent name in the hint be more explicit? i.e, return a BAD_REQUEST?
ProviderHint::Name(name) => llm_providers.get(&name),
});
if let Some(provider) = maybe_provider {
return provider;
) -> Result<Rc<LlmProvider>, String> {
// Every arm either yields a provider or a human-readable error string.
match provider_hint {
Some(ProviderHint::Default) => llm_providers
.default()
.ok_or_else(|| "No default provider configured".to_string()),
Some(ProviderHint::Name(name)) => llm_providers
.get(&name)
.ok_or_else(|| format!("Model '{}' not found in configured providers", name)),
None => {
// No hint provided - must have a default configured
llm_providers
.default()
.ok_or_else(|| "No model specified and no default provider configured".to_string())
}
}
// The hint did not resolve: fall back to the configured default when there is one.
if llm_providers.default().is_some() {
return llm_providers.default().unwrap();
}
// Last resort: pick at random among providers whose model is unset or does not
// start with "Arch".
let mut rng = thread_rng();
llm_providers
.iter()
.filter(|(_, provider)| {
provider
.model
.as_ref()
.map(|m| !m.starts_with("Arch"))
.unwrap_or(true)
})
.choose(&mut rng)
.expect("There should always be at least one non-Arch llm provider")
.1
.clone()
}