mirror of
https://github.com/katanemo/plano.git
synced 2026-05-06 22:32:42 +02:00
Adding support for wildcard models in the model_providers config (#696)
* cleaning up plano cli commands * adding support for wildcard model providers * fixing compile errors * fixing bugs related to default model provider, provider hint and duplicates in the model provider list * fixed cargo fmt issues * updating tests to always include the model id * using default for the prompt_gateway path * fixed the model name, as gpt-5-mini-2025-08-07 wasn't in the config * making sure that all aliases and models match the config * fixed the config generator to allow for base_url providers LLMs to include wildcard models * re-ran the models list utility and added a shell script to run it * updating docs to mention wildcard model providers * updated provider_models.json to yaml, added that file to our docs for reference * updating the build docs to use the new root-based build --------- Co-authored-by: Salman Paracha <salmanparacha@MacBook-Pro-342.local>
This commit is contained in:
parent
8428b06e22
commit
2941392ed1
42 changed files with 1748 additions and 202 deletions
|
|
@ -1,8 +1,9 @@
|
|||
use bytes::Bytes;
|
||||
use common::configuration::{LlmProvider, ModelAlias};
|
||||
use common::configuration::ModelAlias;
|
||||
use common::consts::{
|
||||
ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER, REQUEST_ID_HEADER, TRACE_PARENT_HEADER,
|
||||
};
|
||||
use common::llm_providers::LlmProviders;
|
||||
use common::traces::TraceCollector;
|
||||
use hermesllm::apis::openai_responses::InputParam;
|
||||
use hermesllm::clients::{SupportedAPIsFromClient, SupportedUpstreamAPIs};
|
||||
|
|
@ -38,7 +39,7 @@ pub async fn llm_chat(
|
|||
router_service: Arc<RouterService>,
|
||||
full_qualified_llm_provider_url: String,
|
||||
model_aliases: Arc<Option<HashMap<String, ModelAlias>>>,
|
||||
llm_providers: Arc<RwLock<Vec<LlmProvider>>>,
|
||||
llm_providers: Arc<RwLock<LlmProviders>>,
|
||||
trace_collector: Arc<TraceCollector>,
|
||||
state_storage: Option<Arc<dyn StateStorage>>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
|
|
@ -123,6 +124,27 @@ pub async fn llm_chat(
|
|||
let is_streaming_request = client_request.is_streaming();
|
||||
let resolved_model = resolve_model_alias(&model_from_request, &model_aliases);
|
||||
|
||||
// Validate that the requested model exists in configuration
|
||||
// This matches the validation in llm_gateway routing.rs
|
||||
if llm_providers.read().await.get(&resolved_model).is_none() {
|
||||
let err_msg = format!(
|
||||
"Model '{}' not found in configured providers",
|
||||
resolved_model
|
||||
);
|
||||
warn!("[PLANO_REQ_ID:{}] | FAILURE | {}", request_id, err_msg);
|
||||
let mut bad_request = Response::new(full(err_msg));
|
||||
*bad_request.status_mut() = StatusCode::BAD_REQUEST;
|
||||
return Ok(bad_request);
|
||||
}
|
||||
|
||||
// Handle provider/model slug format (e.g., "openai/gpt-4")
|
||||
// Extract just the model name for upstream (providers don't understand the slug)
|
||||
let model_name_only = if let Some((_, model)) = resolved_model.split_once('/') {
|
||||
model.to_string()
|
||||
} else {
|
||||
resolved_model.clone()
|
||||
};
|
||||
|
||||
// Extract tool names and user message preview for span attributes
|
||||
let tool_names = client_request.get_tool_names();
|
||||
let user_message_preview = client_request
|
||||
|
|
@ -132,7 +154,9 @@ pub async fn llm_chat(
|
|||
// Extract messages for signal analysis (clone before moving client_request)
|
||||
let messages_for_signals = client_request.get_messages();
|
||||
|
||||
client_request.set_model(resolved_model.clone());
|
||||
// Set the model to just the model name (without provider prefix)
|
||||
// This ensures upstream receives "gpt-4" not "openai/gpt-4"
|
||||
client_request.set_model(model_name_only.clone());
|
||||
if client_request.remove_metadata_key("archgw_preference_config") {
|
||||
debug!(
|
||||
"[PLANO_REQ_ID:{}] Removed archgw_preference_config from metadata",
|
||||
|
|
@ -240,11 +264,20 @@ pub async fn llm_chat(
|
|||
}
|
||||
};
|
||||
|
||||
let model_name = routing_result.model_name;
|
||||
// Determine final model to use
|
||||
// Router returns "none" as a sentinel value when it doesn't select a specific model
|
||||
let router_selected_model = routing_result.model_name;
|
||||
let model_name = if router_selected_model != "none" {
|
||||
// Router selected a specific model via routing preferences
|
||||
router_selected_model
|
||||
} else {
|
||||
// Router returned "none" sentinel, use validated resolved_model from request
|
||||
resolved_model.clone()
|
||||
};
|
||||
|
||||
debug!(
|
||||
"[PLANO_REQ_ID:{}] | ARCH_ROUTER URL | {}, Resolved Model: {}",
|
||||
request_id, full_qualified_llm_provider_url, model_name
|
||||
"[PLANO_REQ_ID:{}] | ARCH_ROUTER URL | {}, Provider Hint: {}, Model for upstream: {}",
|
||||
request_id, full_qualified_llm_provider_url, model_name, model_name_only
|
||||
);
|
||||
|
||||
request_headers.insert(
|
||||
|
|
@ -389,7 +422,7 @@ async fn build_llm_span(
|
|||
tool_names: Option<Vec<String>>,
|
||||
user_message_preview: Option<String>,
|
||||
temperature: Option<f32>,
|
||||
llm_providers: &Arc<RwLock<Vec<LlmProvider>>>,
|
||||
llm_providers: &Arc<RwLock<LlmProviders>>,
|
||||
) -> common::traces::Span {
|
||||
use crate::tracing::{http, llm, OperationNameBuilder};
|
||||
use common::traces::{parse_traceparent, SpanBuilder, SpanKind};
|
||||
|
|
@ -462,7 +495,7 @@ async fn build_llm_span(
|
|||
/// Looks up provider configuration, gets the ProviderId and base_url_path_prefix,
|
||||
/// then uses target_endpoint_for_provider to calculate the correct upstream path.
|
||||
async fn get_upstream_path(
|
||||
llm_providers: &Arc<RwLock<Vec<LlmProvider>>>,
|
||||
llm_providers: &Arc<RwLock<LlmProviders>>,
|
||||
model_name: &str,
|
||||
request_path: &str,
|
||||
resolved_model: &str,
|
||||
|
|
@ -485,25 +518,21 @@ async fn get_upstream_path(
|
|||
|
||||
/// Helper function to get provider info (ProviderId and base_url_path_prefix)
|
||||
async fn get_provider_info(
|
||||
llm_providers: &Arc<RwLock<Vec<LlmProvider>>>,
|
||||
llm_providers: &Arc<RwLock<LlmProviders>>,
|
||||
model_name: &str,
|
||||
) -> (hermesllm::ProviderId, Option<String>) {
|
||||
let providers_lock = llm_providers.read().await;
|
||||
|
||||
// First, try to find by model name or provider name
|
||||
let provider = providers_lock.iter().find(|p| {
|
||||
p.model.as_ref().map(|m| m == model_name).unwrap_or(false) || p.name == model_name
|
||||
});
|
||||
|
||||
if let Some(provider) = provider {
|
||||
// Try to find by model name or provider name using LlmProviders::get
|
||||
// This handles both "gpt-4" and "openai/gpt-4" formats
|
||||
if let Some(provider) = providers_lock.get(model_name) {
|
||||
let provider_id = provider.provider_interface.to_provider_id();
|
||||
let prefix = provider.base_url_path_prefix.clone();
|
||||
return (provider_id, prefix);
|
||||
}
|
||||
|
||||
let default_provider = providers_lock.iter().find(|p| p.default.unwrap_or(false));
|
||||
|
||||
if let Some(provider) = default_provider {
|
||||
// Fall back to default provider
|
||||
if let Some(provider) = providers_lock.default() {
|
||||
let provider_id = provider.provider_interface.to_provider_id();
|
||||
let prefix = provider.base_url_path_prefix.clone();
|
||||
(provider_id, prefix)
|
||||
|
|
|
|||
|
|
@ -1,19 +1,17 @@
|
|||
use bytes::Bytes;
|
||||
use common::configuration::{IntoModels, LlmProvider};
|
||||
use hermesllm::apis::openai::Models;
|
||||
use common::llm_providers::LlmProviders;
|
||||
use http_body_util::{combinators::BoxBody, BodyExt, Full};
|
||||
use hyper::{Response, StatusCode};
|
||||
use serde_json;
|
||||
use std::sync::Arc;
|
||||
|
||||
pub async fn list_models(
|
||||
llm_providers: Arc<tokio::sync::RwLock<Vec<LlmProvider>>>,
|
||||
llm_providers: Arc<tokio::sync::RwLock<LlmProviders>>,
|
||||
) -> Response<BoxBody<Bytes, hyper::Error>> {
|
||||
let prov = llm_providers.read().await;
|
||||
let providers = prov.clone();
|
||||
let openai_models: Models = providers.into_models();
|
||||
let models = prov.to_models();
|
||||
|
||||
match serde_json::to_string(&openai_models) {
|
||||
match serde_json::to_string(&models) {
|
||||
Ok(json) => {
|
||||
let body = Full::new(Bytes::from(json))
|
||||
.map_err(|never| match never {})
|
||||
|
|
|
|||
|
|
@ -151,16 +151,15 @@ pub async fn router_chat_get_upstream_model(
|
|||
Ok(RoutingResult { model_name })
|
||||
}
|
||||
None => {
|
||||
// No route determined, use default model from request
|
||||
// No route determined, return sentinel value "none"
|
||||
// This signals to llm.rs to use the original validated request model
|
||||
info!(
|
||||
"[PLANO_REQ_ID: {}] | ROUTER_REQ | No route determined, using default model from request: {}",
|
||||
request_id,
|
||||
chat_request.model
|
||||
"[PLANO_REQ_ID: {}] | ROUTER_REQ | No route determined, returning sentinel 'none'",
|
||||
request_id
|
||||
);
|
||||
|
||||
let default_model = chat_request.model.clone();
|
||||
let mut attrs = HashMap::new();
|
||||
attrs.insert("route.selected_model".to_string(), default_model.clone());
|
||||
attrs.insert("route.selected_model".to_string(), "none".to_string());
|
||||
record_routing_span(
|
||||
trace_collector,
|
||||
traceparent,
|
||||
|
|
@ -171,7 +170,7 @@ pub async fn router_chat_get_upstream_model(
|
|||
.await;
|
||||
|
||||
Ok(RoutingResult {
|
||||
model_name: default_model,
|
||||
model_name: "none".to_string(),
|
||||
})
|
||||
}
|
||||
},
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ use common::configuration::{Agent, Configuration};
|
|||
use common::consts::{
|
||||
CHAT_COMPLETIONS_PATH, MESSAGES_PATH, OPENAI_RESPONSES_API_PATH, PLANO_ORCHESTRATOR_MODEL_NAME,
|
||||
};
|
||||
use common::llm_providers::LlmProviders;
|
||||
use common::traces::TraceCollector;
|
||||
use http_body_util::{combinators::BoxBody, BodyExt, Empty};
|
||||
use hyper::body::Incoming;
|
||||
|
|
@ -76,7 +77,10 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
.cloned()
|
||||
.collect();
|
||||
|
||||
let llm_providers = Arc::new(RwLock::new(arch_config.model_providers.clone()));
|
||||
// Create expanded provider list for /v1/models endpoint
|
||||
let llm_providers = LlmProviders::try_from(arch_config.model_providers.clone())
|
||||
.expect("Failed to create LlmProviders");
|
||||
let llm_providers = Arc::new(RwLock::new(llm_providers));
|
||||
let combined_agents_filters_list = Arc::new(RwLock::new(Some(all_agents)));
|
||||
let listeners = Arc::new(RwLock::new(arch_config.listeners.clone()));
|
||||
let llm_provider_url =
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue