mirror of
https://github.com/katanemo/plano.git
synced 2026-04-30 19:36:34 +02:00
Adding support for wildcard models in the model_providers config (#696)
* cleaning up plano cli commands * adding support for wildcard model providers * fixing compile errors * fixing bugs related to default model provider, provider hint and duplicates in the model provider list * fixed cargo fmt issues * updating tests to always include the model id * using default for the prompt_gateway path * fixed the model name, as gpt-5-mini-2025-08-07 wasn't in the config * making sure that all aliases and models match the config * fixed the config generator to allow for base_url providers LLMs to include wildcard models * re-ran the models list utility and added a shell script to run it * updating docs to mention wildcard model providers * updated provider_models.json to yaml, added that file to our docs for reference * updating the build docs to use the new root-based build --------- Co-authored-by: Salman Paracha <salmanparacha@MacBook-Pro-342.local>
This commit is contained in:
parent
8428b06e22
commit
2941392ed1
42 changed files with 1748 additions and 202 deletions
|
|
@ -255,7 +255,8 @@ impl LlmProviderType {
|
|||
/// Get the ProviderId for this LlmProviderType
|
||||
/// Used with the new function-based hermesllm API
|
||||
pub fn to_provider_id(&self) -> hermesllm::ProviderId {
|
||||
hermesllm::ProviderId::from(self.to_string().as_str())
|
||||
hermesllm::ProviderId::try_from(self.to_string().as_str())
|
||||
.expect("LlmProviderType should always map to a valid ProviderId")
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,24 +1,84 @@
|
|||
use crate::configuration::LlmProvider;
|
||||
use hermesllm::providers::ProviderId;
|
||||
use std::collections::HashMap;
|
||||
use std::rc::Rc;
|
||||
use std::sync::Arc;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct LlmProviders {
|
||||
providers: HashMap<String, Rc<LlmProvider>>,
|
||||
default: Option<Rc<LlmProvider>>,
|
||||
providers: HashMap<String, Arc<LlmProvider>>,
|
||||
default: Option<Arc<LlmProvider>>,
|
||||
/// Wildcard providers: maps provider prefix to base provider config
|
||||
/// e.g., "openai" -> LlmProvider for "openai/*"
|
||||
wildcard_providers: HashMap<String, Arc<LlmProvider>>,
|
||||
}
|
||||
|
||||
impl LlmProviders {
|
||||
pub fn iter(&self) -> std::collections::hash_map::Iter<'_, String, Rc<LlmProvider>> {
|
||||
pub fn iter(&self) -> std::collections::hash_map::Iter<'_, String, Arc<LlmProvider>> {
|
||||
self.providers.iter()
|
||||
}
|
||||
|
||||
pub fn default(&self) -> Option<Rc<LlmProvider>> {
|
||||
pub fn default(&self) -> Option<Arc<LlmProvider>> {
|
||||
self.default.clone()
|
||||
}
|
||||
/// Convert providers to OpenAI Models format for /v1/models endpoint
|
||||
/// Filters out internal models and duplicate entries (backward compatibility aliases)
|
||||
pub fn to_models(&self) -> hermesllm::apis::openai::Models {
|
||||
use hermesllm::apis::openai::{ModelDetail, ModelObject, Models};
|
||||
|
||||
pub fn get(&self, name: &str) -> Option<Rc<LlmProvider>> {
|
||||
self.providers.get(name).cloned()
|
||||
let data: Vec<ModelDetail> = self
|
||||
.providers
|
||||
.iter()
|
||||
.filter(|(key, provider)| {
|
||||
// Exclude internal models
|
||||
provider.internal != Some(true)
|
||||
// Only include canonical entries (key matches provider name)
|
||||
// This avoids duplicates from backward compatibility short names
|
||||
&& *key == &provider.name
|
||||
})
|
||||
.map(|(name, provider)| ModelDetail {
|
||||
id: name.clone(),
|
||||
object: Some("model".to_string()),
|
||||
created: 0,
|
||||
owned_by: provider.to_provider_id().to_string(),
|
||||
})
|
||||
.collect();
|
||||
|
||||
Models {
|
||||
object: ModelObject::List,
|
||||
data,
|
||||
}
|
||||
}
|
||||
pub fn get(&self, name: &str) -> Option<Arc<LlmProvider>> {
|
||||
// First try exact match
|
||||
if let Some(provider) = self.providers.get(name).cloned() {
|
||||
return Some(provider);
|
||||
}
|
||||
|
||||
// If name contains '/', it could be:
|
||||
// 1. A full model ID like "openai/gpt-4" that we need to lookup
|
||||
// 2. A provider/model slug that should match a wildcard provider
|
||||
if let Some((provider_prefix, model_name)) = name.split_once('/') {
|
||||
// Try to find the expanded model entry (e.g., "openai/gpt-4")
|
||||
let full_model_id = format!("{}/{}", provider_prefix, model_name);
|
||||
if let Some(provider) = self.providers.get(&full_model_id).cloned() {
|
||||
return Some(provider);
|
||||
}
|
||||
|
||||
// Try to find just the model name (for expanded wildcard entries)
|
||||
if let Some(provider) = self.providers.get(model_name).cloned() {
|
||||
return Some(provider);
|
||||
}
|
||||
|
||||
// Fall back to wildcard match (e.g., "openai/*")
|
||||
if let Some(wildcard_provider) = self.wildcard_providers.get(provider_prefix) {
|
||||
// Create a new provider with the specific model from the slug
|
||||
let mut specific_provider = (**wildcard_provider).clone();
|
||||
specific_provider.model = Some(model_name.to_string());
|
||||
return Some(Arc::new(specific_provider));
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -43,38 +103,235 @@ impl TryFrom<Vec<LlmProvider>> for LlmProviders {
|
|||
let mut llm_providers = LlmProviders {
|
||||
providers: HashMap::new(),
|
||||
default: None,
|
||||
wildcard_providers: HashMap::new(),
|
||||
};
|
||||
|
||||
for llm_provider in llm_providers_config {
|
||||
let llm_provider: Rc<LlmProvider> = Rc::new(llm_provider);
|
||||
if llm_provider.default.unwrap_or_default() {
|
||||
match llm_providers.default {
|
||||
Some(_) => return Err(LlmProvidersNewError::MoreThanOneDefault),
|
||||
None => llm_providers.default = Some(Rc::clone(&llm_provider)),
|
||||
}
|
||||
}
|
||||
// Track specific (non-wildcard) provider names to detect true duplicates
|
||||
let mut specific_provider_names = std::collections::HashSet::new();
|
||||
|
||||
// Insert and check that there is no other provider with the same name.
|
||||
let name = llm_provider.name.clone();
|
||||
if llm_providers
|
||||
.providers
|
||||
.insert(name.clone(), Rc::clone(&llm_provider))
|
||||
.is_some()
|
||||
{
|
||||
return Err(LlmProvidersNewError::DuplicateName(name));
|
||||
}
|
||||
// Track specific models that should be excluded from wildcard expansion
|
||||
// Maps provider_prefix -> Set of model names (e.g., "anthropic" -> {"claude-sonnet-4-20250514"})
|
||||
let mut specific_models_by_provider: HashMap<String, std::collections::HashSet<String>> =
|
||||
HashMap::new();
|
||||
|
||||
// also add model_id as key for provider lookup
|
||||
if let Some(model) = llm_provider.model.clone() {
|
||||
if llm_providers
|
||||
.providers
|
||||
.insert(model, llm_provider)
|
||||
.is_some()
|
||||
{
|
||||
return Err(LlmProvidersNewError::DuplicateName(name));
|
||||
// First pass: collect all specific model configurations
|
||||
for llm_provider in &llm_providers_config {
|
||||
let is_wildcard = llm_provider
|
||||
.model
|
||||
.as_ref()
|
||||
.map(|m| m == "*" || m.ends_with("/*"))
|
||||
.unwrap_or(false);
|
||||
|
||||
if !is_wildcard {
|
||||
// Check if this is a provider/model format
|
||||
if let Some((provider_prefix, model_name)) = llm_provider.name.split_once('/') {
|
||||
specific_models_by_provider
|
||||
.entry(provider_prefix.to_string())
|
||||
.or_default()
|
||||
.insert(model_name.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for llm_provider in llm_providers_config {
|
||||
let llm_provider: Arc<LlmProvider> = Arc::new(llm_provider);
|
||||
|
||||
if llm_provider.default.unwrap_or_default() {
|
||||
match llm_providers.default {
|
||||
Some(_) => return Err(LlmProvidersNewError::MoreThanOneDefault),
|
||||
None => llm_providers.default = Some(Arc::clone(&llm_provider)),
|
||||
}
|
||||
}
|
||||
|
||||
let name = llm_provider.name.clone();
|
||||
|
||||
// Check if this is a wildcard provider (model is "*" or ends with "/*")
|
||||
let is_wildcard = llm_provider
|
||||
.model
|
||||
.as_ref()
|
||||
.map(|m| m == "*" || m.ends_with("/*"))
|
||||
.unwrap_or(false);
|
||||
|
||||
if is_wildcard {
|
||||
// Extract provider prefix from name
|
||||
// e.g., "openai/*" -> "openai"
|
||||
let provider_prefix = name.trim_end_matches("/*").trim_end_matches('*');
|
||||
|
||||
// For wildcard providers, we:
|
||||
// 1. Store the base config in wildcard_providers for runtime matching
|
||||
// 2. Optionally expand to all known models if available
|
||||
|
||||
llm_providers
|
||||
.wildcard_providers
|
||||
.insert(provider_prefix.to_string(), Arc::clone(&llm_provider));
|
||||
|
||||
// Try to expand wildcard using ProviderId models
|
||||
if let Ok(provider_id) = ProviderId::try_from(provider_prefix) {
|
||||
let models = provider_id.models();
|
||||
|
||||
// Get the set of specific models to exclude for this provider
|
||||
let models_to_exclude = specific_models_by_provider
|
||||
.get(provider_prefix)
|
||||
.cloned()
|
||||
.unwrap_or_default();
|
||||
|
||||
if !models.is_empty() {
|
||||
let excluded_count = models_to_exclude.len();
|
||||
let total_models = models.len();
|
||||
|
||||
log::info!(
|
||||
"Expanding wildcard provider '{}' to {} models{}",
|
||||
provider_prefix,
|
||||
total_models - excluded_count,
|
||||
if excluded_count > 0 {
|
||||
format!(" (excluding {} specifically configured)", excluded_count)
|
||||
} else {
|
||||
String::new()
|
||||
}
|
||||
);
|
||||
|
||||
// Create a provider entry for each model (except those specifically configured)
|
||||
for model_name in models {
|
||||
// Skip this model if it has a specific configuration
|
||||
if models_to_exclude.contains(&model_name) {
|
||||
log::debug!(
|
||||
"Skipping wildcard expansion for '{}/{}' - specific configuration exists",
|
||||
provider_prefix,
|
||||
model_name
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
let full_model_id = format!("{}/{}", provider_prefix, model_name);
|
||||
|
||||
// Create a new provider with the specific model
|
||||
let mut expanded_provider = (*llm_provider).clone();
|
||||
expanded_provider.model = Some(model_name.clone());
|
||||
expanded_provider.name = full_model_id.clone();
|
||||
|
||||
let expanded_rc = Arc::new(expanded_provider);
|
||||
|
||||
// Insert with full model ID as key
|
||||
llm_providers
|
||||
.providers
|
||||
.insert(full_model_id.clone(), Arc::clone(&expanded_rc));
|
||||
|
||||
// Also insert with just model name for backward compatibility
|
||||
llm_providers.providers.insert(model_name, expanded_rc);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
log::warn!(
|
||||
"Wildcard provider '{}' specified but no models found in registry. \
|
||||
Will match dynamically at runtime.",
|
||||
provider_prefix
|
||||
);
|
||||
}
|
||||
} else {
|
||||
// Non-wildcard provider - specific configuration
|
||||
// Check for duplicate specific entries (not allowed)
|
||||
if specific_provider_names.contains(&name) {
|
||||
return Err(LlmProvidersNewError::DuplicateName(name));
|
||||
}
|
||||
specific_provider_names.insert(name.clone());
|
||||
|
||||
// This specific configuration takes precedence over any wildcard expansion
|
||||
// The wildcard expansion already excluded this model (see first pass above)
|
||||
|
||||
log::debug!("Processing specific provider configuration: {}", name);
|
||||
|
||||
// Insert with the provider name as key
|
||||
llm_providers
|
||||
.providers
|
||||
.insert(name.clone(), Arc::clone(&llm_provider));
|
||||
|
||||
// Also add model_id as key for provider lookup
|
||||
if let Some(model) = llm_provider.model.clone() {
|
||||
llm_providers.providers.insert(model, llm_provider);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(llm_providers)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::configuration::LlmProviderType;
|
||||
|
||||
fn create_test_provider(name: &str, model: Option<String>) -> LlmProvider {
|
||||
LlmProvider {
|
||||
name: name.to_string(),
|
||||
model,
|
||||
access_key: None,
|
||||
endpoint: None,
|
||||
cluster_name: None,
|
||||
provider_interface: LlmProviderType::OpenAI,
|
||||
default: None,
|
||||
base_url_path_prefix: None,
|
||||
port: None,
|
||||
rate_limits: None,
|
||||
usage: None,
|
||||
routing_preferences: None,
|
||||
internal: None,
|
||||
stream: None,
|
||||
passthrough_auth: None,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_static_provider_lookup() {
|
||||
// Test 1: Statically defined provider - should be findable by model or provider name
|
||||
let providers = vec![create_test_provider("my-openai", Some("gpt-4".to_string()))];
|
||||
let llm_providers = LlmProviders::try_from(providers).unwrap();
|
||||
|
||||
// Should find by model name
|
||||
let result = llm_providers.get("gpt-4");
|
||||
assert!(result.is_some());
|
||||
assert_eq!(result.unwrap().name, "my-openai");
|
||||
|
||||
// Should also find by provider name
|
||||
let result = llm_providers.get("my-openai");
|
||||
assert!(result.is_some());
|
||||
assert_eq!(result.unwrap().name, "my-openai");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_wildcard_provider_with_known_model() {
|
||||
// Test 2: Wildcard provider that expands to OpenAI models
|
||||
let providers = vec![create_test_provider("openai/*", Some("*".to_string()))];
|
||||
let llm_providers = LlmProviders::try_from(providers).unwrap();
|
||||
|
||||
// Should find via expanded wildcard entry
|
||||
let result = llm_providers.get("openai/gpt-4");
|
||||
let provider = result.unwrap();
|
||||
assert_eq!(provider.name, "openai/gpt-4");
|
||||
assert_eq!(provider.model.as_ref().unwrap(), "gpt-4");
|
||||
|
||||
// Should also be able to find by just model name (from expansion)
|
||||
let result = llm_providers.get("gpt-4");
|
||||
assert_eq!(result.unwrap().model.as_ref().unwrap(), "gpt-4");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_custom_wildcard_provider_with_full_slug() {
|
||||
// Test 3: Custom wildcard provider with full slug offered
|
||||
let providers = vec![create_test_provider(
|
||||
"custom-provider/*",
|
||||
Some("*".to_string()),
|
||||
)];
|
||||
let llm_providers = LlmProviders::try_from(providers).unwrap();
|
||||
|
||||
// Should match via wildcard fallback and extract model name from slug
|
||||
let result = llm_providers.get("custom-provider/custom-model");
|
||||
let provider = result.unwrap();
|
||||
assert_eq!(provider.model.as_ref().unwrap(), "custom-model");
|
||||
|
||||
// Wildcard should be stored
|
||||
assert!(llm_providers
|
||||
.wildcard_providers
|
||||
.contains_key("custom-provider"));
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,10 +1,9 @@
|
|||
use std::rc::Rc;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::{configuration, llm_providers::LlmProviders};
|
||||
use configuration::LlmProvider;
|
||||
use rand::{seq::IteratorRandom, thread_rng};
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum ProviderHint {
|
||||
Default,
|
||||
Name(String),
|
||||
|
|
@ -22,33 +21,14 @@ impl From<String> for ProviderHint {
|
|||
pub fn get_llm_provider(
|
||||
llm_providers: &LlmProviders,
|
||||
provider_hint: Option<ProviderHint>,
|
||||
) -> Rc<LlmProvider> {
|
||||
let maybe_provider = provider_hint.and_then(|hint| match hint {
|
||||
ProviderHint::Default => llm_providers.default(),
|
||||
// FIXME: should a non-existent name in the hint be more explicit? i.e, return a BAD_REQUEST?
|
||||
ProviderHint::Name(name) => llm_providers.get(&name),
|
||||
});
|
||||
|
||||
if let Some(provider) = maybe_provider {
|
||||
return provider;
|
||||
) -> Result<Arc<LlmProvider>, String> {
|
||||
match provider_hint {
|
||||
Some(ProviderHint::Default) => llm_providers
|
||||
.default()
|
||||
.ok_or_else(|| "No default provider configured".to_string()),
|
||||
Some(ProviderHint::Name(name)) => llm_providers
|
||||
.get(&name)
|
||||
.ok_or_else(|| format!("Model '{}' not found in configured providers", name)),
|
||||
None => Err("No model specified in request".to_string()),
|
||||
}
|
||||
|
||||
if llm_providers.default().is_some() {
|
||||
return llm_providers.default().unwrap();
|
||||
}
|
||||
|
||||
let mut rng = thread_rng();
|
||||
llm_providers
|
||||
.iter()
|
||||
.filter(|(_, provider)| {
|
||||
provider
|
||||
.model
|
||||
.as_ref()
|
||||
.map(|m| !m.starts_with("Arch"))
|
||||
.unwrap_or(true)
|
||||
})
|
||||
.choose(&mut rng)
|
||||
.expect("There should always be at least one non-Arch llm provider")
|
||||
.1
|
||||
.clone()
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue