mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
feat: add support for retrying LLM requests on 429 ratelimits (#697)
- Added 'retry_on_ratelimit' configuration to LlmProvider. - Implemented a retry loop in the LLM handler to automatically failover to an alternative model when a 429 status is received. - Added comprehensive unit tests for fallback selection and failover logic. - Ensured default behavior is unchanged when the feature is disabled.
This commit is contained in:
parent
46de89590b
commit
851d8a3054
4 changed files with 239 additions and 43 deletions
|
|
@ -90,7 +90,7 @@ async fn llm_chat_inner(
|
|||
state_storage: Option<Arc<dyn StateStorage>>,
|
||||
request_id: String,
|
||||
request_path: String,
|
||||
mut request_headers: hyper::HeaderMap,
|
||||
request_headers: hyper::HeaderMap,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
// Set service name for LLM operations
|
||||
set_service_name(operation_component::LLM);
|
||||
|
|
@ -274,9 +274,6 @@ async fn llm_chat_inner(
|
|||
}
|
||||
}
|
||||
|
||||
// Serialize request for upstream BEFORE router consumes it
|
||||
let client_request_bytes_for_upstream = ProviderRequestType::to_bytes(&client_request).unwrap();
|
||||
|
||||
// Determine routing using the dedicated router_chat module
|
||||
// This gets its own span for latency and error tracking
|
||||
let routing_span = info_span!(
|
||||
|
|
@ -293,7 +290,7 @@ async fn llm_chat_inner(
|
|||
set_service_name(operation_component::ROUTING);
|
||||
router_chat_get_upstream_model(
|
||||
router_service,
|
||||
client_request, // Pass the original request - router_chat will convert it
|
||||
client_request.clone(), // Clone here to preserve for retries
|
||||
&traceparent,
|
||||
&request_path,
|
||||
&request_id,
|
||||
|
|
@ -334,49 +331,93 @@ async fn llm_chat_inner(
|
|||
span.update_name(span_name.clone());
|
||||
});
|
||||
|
||||
debug!(
|
||||
url = %full_qualified_llm_provider_url,
|
||||
provider_hint = %resolved_model,
|
||||
upstream_model = %model_name_only,
|
||||
"Routing to upstream"
|
||||
);
|
||||
|
||||
request_headers.insert(
|
||||
ARCH_PROVIDER_HINT_HEADER,
|
||||
header::HeaderValue::from_str(&resolved_model).unwrap(),
|
||||
);
|
||||
|
||||
request_headers.insert(
|
||||
header::HeaderName::from_static(ARCH_IS_STREAMING_HEADER),
|
||||
header::HeaderValue::from_str(&is_streaming_request.to_string()).unwrap(),
|
||||
);
|
||||
// remove content-length header if it exists
|
||||
request_headers.remove(header::CONTENT_LENGTH);
|
||||
|
||||
// Inject current LLM span's trace context so upstream spans are children of plano(llm)
|
||||
global::get_text_map_propagator(|propagator| {
|
||||
let cx = tracing_opentelemetry::OpenTelemetrySpanExt::context(&tracing::Span::current());
|
||||
propagator.inject_context(&cx, &mut HeaderInjector(&mut request_headers));
|
||||
});
|
||||
|
||||
// Capture start time right before sending request to upstream
|
||||
let request_start_time = std::time::Instant::now();
|
||||
let _request_start_system_time = std::time::SystemTime::now();
|
||||
|
||||
let llm_response = match reqwest::Client::new()
|
||||
.post(&full_qualified_llm_provider_url)
|
||||
.headers(request_headers)
|
||||
.body(client_request_bytes_for_upstream)
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(res) => res,
|
||||
Err(err) => {
|
||||
let err_msg = format!("Failed to send request: {}", err);
|
||||
let mut internal_error = Response::new(full(err_msg));
|
||||
*internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
|
||||
return Ok(internal_error);
|
||||
let mut current_resolved_model = resolved_model.clone();
|
||||
let mut current_client_request = client_request;
|
||||
let mut attempts = 0;
|
||||
let max_attempts = 2; // Original + 1 retry
|
||||
|
||||
let llm_response = loop {
|
||||
attempts += 1;
|
||||
|
||||
// Handle provider/model slug format (e.g., "openai/gpt-4")
|
||||
// Extract just the model name for upstream (providers don't understand the slug)
|
||||
let current_model_name_only = if let Some((_, model)) = current_resolved_model.split_once('/') {
|
||||
model.to_string()
|
||||
} else {
|
||||
current_resolved_model.clone()
|
||||
};
|
||||
|
||||
debug!(
|
||||
url = %full_qualified_llm_provider_url,
|
||||
provider_hint = %current_resolved_model,
|
||||
upstream_model = %current_model_name_only,
|
||||
attempt = attempts,
|
||||
"Routing to upstream"
|
||||
);
|
||||
|
||||
// Set the model to just the model name (without provider prefix)
|
||||
current_client_request.set_model(current_model_name_only.clone());
|
||||
|
||||
// Serialize request for upstream
|
||||
let current_request_bytes = ProviderRequestType::to_bytes(¤t_client_request).unwrap();
|
||||
|
||||
let mut current_request_headers = request_headers.clone();
|
||||
current_request_headers.insert(
|
||||
ARCH_PROVIDER_HINT_HEADER,
|
||||
header::HeaderValue::from_str(¤t_resolved_model).unwrap(),
|
||||
);
|
||||
|
||||
current_request_headers.insert(
|
||||
header::HeaderName::from_static(ARCH_IS_STREAMING_HEADER),
|
||||
header::HeaderValue::from_str(&is_streaming_request.to_string()).unwrap(),
|
||||
);
|
||||
// remove content-length header if it exists
|
||||
current_request_headers.remove(header::CONTENT_LENGTH);
|
||||
|
||||
// Inject current LLM span's trace context so upstream spans are children of plano(llm)
|
||||
global::get_text_map_propagator(|propagator| {
|
||||
let cx = tracing_opentelemetry::OpenTelemetrySpanExt::context(&tracing::Span::current());
|
||||
propagator.inject_context(&cx, &mut HeaderInjector(&mut current_request_headers));
|
||||
});
|
||||
|
||||
let res = match reqwest::Client::new()
|
||||
.post(&full_qualified_llm_provider_url)
|
||||
.headers(current_request_headers)
|
||||
.body(current_request_bytes)
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(res) => res,
|
||||
Err(err) => {
|
||||
let err_msg = format!("Failed to send request: {}", err);
|
||||
let mut internal_error = Response::new(full(err_msg));
|
||||
*internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
|
||||
return Ok(internal_error);
|
||||
}
|
||||
};
|
||||
|
||||
if res.status() == StatusCode::TOO_MANY_REQUESTS && attempts < max_attempts {
|
||||
let providers = llm_providers.read().await;
|
||||
if let Some(provider) = providers.get(¤t_resolved_model) {
|
||||
if provider.retry_on_ratelimit == Some(true) {
|
||||
if let Some(alt_provider) = providers.get_alternative(¤t_resolved_model) {
|
||||
info!(
|
||||
request_id = %request_id,
|
||||
current_model = %current_resolved_model,
|
||||
alt_model = %alt_provider.name,
|
||||
"429 received, retrying with alternative model"
|
||||
);
|
||||
current_resolved_model = alt_provider.name.clone();
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
break res;
|
||||
};
|
||||
|
||||
// copy over the headers and status code from the original response
|
||||
|
|
@ -391,6 +432,7 @@ async fn llm_chat_inner(
|
|||
// Build LLM span with actual status code using constants
|
||||
let byte_stream = llm_response.bytes_stream();
|
||||
|
||||
|
||||
// Create base processor for metrics and tracing
|
||||
let base_processor = ObservableStreamProcessor::new(
|
||||
operation_component::LLM,
|
||||
|
|
@ -441,6 +483,82 @@ async fn llm_chat_inner(
|
|||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use common::configuration::{LlmProvider, LlmProviderType};
|
||||
use common::llm_providers::LlmProviders;
|
||||
|
||||
// We can't easily create Request<Incoming> in tests without a full server setup.
|
||||
// So we'll skip the functional test of llm_chat and rely on unit tests of get_alternative.
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_llm_providers_get_alternative() {
|
||||
let primary = LlmProvider {
|
||||
name: "primary".to_string(),
|
||||
provider_interface: LlmProviderType::OpenAI,
|
||||
model: Some("gpt-4".to_string()),
|
||||
default: Some(false),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let secondary = LlmProvider {
|
||||
name: "secondary".to_string(),
|
||||
provider_interface: LlmProviderType::OpenAI,
|
||||
model: Some("gpt-4-alt".to_string()),
|
||||
default: Some(true),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let providers_vec = vec![primary.clone(), secondary.clone()];
|
||||
let llm_providers = LlmProviders::try_from(providers_vec).unwrap();
|
||||
|
||||
let alt = llm_providers.get_alternative("primary");
|
||||
assert!(alt.is_some());
|
||||
assert_eq!(alt.unwrap().name, "secondary");
|
||||
|
||||
let alt_none = llm_providers.get_alternative("secondary");
|
||||
assert!(alt_none.is_some());
|
||||
assert_eq!(alt_none.unwrap().name, "primary");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_llm_providers_get_alternative_internal_skipped() {
|
||||
let primary = LlmProvider {
|
||||
name: "primary".to_string(),
|
||||
provider_interface: LlmProviderType::OpenAI,
|
||||
model: Some("gpt-4".to_string()),
|
||||
default: Some(true),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let internal = LlmProvider {
|
||||
name: "internal".to_string(),
|
||||
provider_interface: LlmProviderType::Arch,
|
||||
model: Some("router".to_string()),
|
||||
internal: Some(true),
|
||||
default: Some(false),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let secondary = LlmProvider {
|
||||
name: "secondary".to_string(),
|
||||
provider_interface: LlmProviderType::OpenAI,
|
||||
model: Some("gpt-4-alt".to_string()),
|
||||
default: Some(false),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let providers_vec = vec![primary, internal, secondary];
|
||||
let llm_providers = LlmProviders::try_from(providers_vec).unwrap();
|
||||
|
||||
let alt = llm_providers.get_alternative("primary");
|
||||
assert!(alt.is_some());
|
||||
assert_eq!(alt.unwrap().name, "secondary");
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
/// Resolves model aliases by looking up the requested model in the model_aliases map.
|
||||
/// Returns the target model if an alias is found, otherwise returns the original model.
|
||||
fn resolve_model_alias(
|
||||
|
|
|
|||
0
crates/build.sh
Normal file → Executable file
0
crates/build.sh
Normal file → Executable file
|
|
@ -328,6 +328,7 @@ pub struct LlmProvider {
|
|||
pub base_url_path_prefix: Option<String>,
|
||||
pub internal: Option<bool>,
|
||||
pub passthrough_auth: Option<bool>,
|
||||
pub retry_on_ratelimit: Option<bool>,
|
||||
}
|
||||
|
||||
pub trait IntoModels {
|
||||
|
|
@ -372,6 +373,7 @@ impl Default for LlmProvider {
|
|||
base_url_path_prefix: None,
|
||||
internal: None,
|
||||
passthrough_auth: None,
|
||||
retry_on_ratelimit: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -80,6 +80,29 @@ impl LlmProviders {
|
|||
|
||||
None
|
||||
}
|
||||
|
||||
/// Get an alternative provider that is not the one specified by current_name.
|
||||
/// Prefers the default provider if it's different, otherwise picks the first non-internal provider.
|
||||
pub fn get_alternative(&self, current_name: &str) -> Option<Arc<LlmProvider>> {
|
||||
// Try to find a default provider that is not the current one
|
||||
if let Some(default_provider) = &self.default {
|
||||
if default_provider.name != current_name {
|
||||
return Some(Arc::clone(default_provider));
|
||||
}
|
||||
}
|
||||
|
||||
// Otherwise just pick the first canonical non-internal provider that is not the current one
|
||||
self.providers.iter().find_map(|(key, provider)| {
|
||||
if provider.internal != Some(true)
|
||||
&& provider.name != current_name
|
||||
&& key == &provider.name
|
||||
{
|
||||
Some(Arc::clone(provider))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
|
|
@ -278,6 +301,7 @@ mod tests {
|
|||
internal: None,
|
||||
stream: None,
|
||||
passthrough_auth: None,
|
||||
retry_on_ratelimit: None,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -334,4 +358,56 @@ mod tests {
|
|||
.wildcard_providers
|
||||
.contains_key("custom-provider"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_get_alternative_prefers_default() {
|
||||
let primary = create_test_provider("primary", Some("gpt-4".to_string()));
|
||||
let mut secondary = create_test_provider("secondary", Some("gpt-4-alt".to_string()));
|
||||
secondary.default = Some(true);
|
||||
let tertiary = create_test_provider("tertiary", Some("gpt-4-other".to_string()));
|
||||
|
||||
let providers = vec![primary, secondary, tertiary];
|
||||
let llm_providers = LlmProviders::try_from(providers).unwrap();
|
||||
|
||||
// If we are at primary, should return secondary (default)
|
||||
let alt = llm_providers.get_alternative("primary");
|
||||
assert_eq!(alt.unwrap().name, "secondary");
|
||||
|
||||
// If we are at tertiary, should return secondary (default)
|
||||
let alt = llm_providers.get_alternative("tertiary");
|
||||
assert_eq!(alt.unwrap().name, "secondary");
|
||||
|
||||
// If we are at secondary (the default), should return something else (primary or tertiary)
|
||||
let alt = llm_providers.get_alternative("secondary");
|
||||
let alt_name = alt.unwrap().name.clone();
|
||||
assert!(alt_name == "primary" || alt_name == "tertiary");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_get_alternative_skips_internal() {
|
||||
let primary = create_test_provider("primary", Some("gpt-4".to_string()));
|
||||
let mut internal = create_test_provider("internal", Some("router".to_string()));
|
||||
internal.internal = Some(true);
|
||||
let secondary = create_test_provider("secondary", Some("gpt-4-alt".to_string()));
|
||||
|
||||
let providers = vec![primary, internal, secondary];
|
||||
let llm_providers = LlmProviders::try_from(providers).unwrap();
|
||||
|
||||
// Should return secondary, NOT internal
|
||||
let alt = llm_providers.get_alternative("primary");
|
||||
assert_eq!(alt.unwrap().name, "secondary");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_get_alternative_returns_none_if_no_other_available() {
|
||||
let primary = create_test_provider("primary", Some("gpt-4".to_string()));
|
||||
let mut internal = create_test_provider("internal", Some("router".to_string()));
|
||||
internal.internal = Some(true);
|
||||
|
||||
let providers = vec![primary, internal];
|
||||
let llm_providers = LlmProviders::try_from(providers).unwrap();
|
||||
|
||||
let alt = llm_providers.get_alternative("primary");
|
||||
assert!(alt.is_none());
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue