From f3e35914e39fac1082f72328e8283ea27cd0316c Mon Sep 17 00:00:00 2001 From: Troy Mitchell Date: Tue, 28 Apr 2026 16:20:50 +0800 Subject: [PATCH] feat: integrate retry orchestrator into LLM handler Wire up the retry module into the brightstaff LLM handler: - Add send_upstream_with_retry() that uses RetryOrchestrator to coordinate retry attempts with backoff and failover - Build forward_fn closure for per-attempt HTTP calls - Support failover to alternative providers on retryable errors - Fall back to single-attempt send_upstream() when no retry policy is configured Signed-off-by: Troy Mitchell --- crates/brightstaff/src/handlers/llm/mod.rs | 403 +++++++++++++++++++-- 1 file changed, 370 insertions(+), 33 deletions(-) diff --git a/crates/brightstaff/src/handlers/llm/mod.rs b/crates/brightstaff/src/handlers/llm/mod.rs index 3336209f..aad0ce07 100644 --- a/crates/brightstaff/src/handlers/llm/mod.rs +++ b/crates/brightstaff/src/handlers/llm/mod.rs @@ -1,19 +1,24 @@ use bytes::Bytes; -use common::configuration::{FilterPipeline, ModelAlias}; -use common::consts::{ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER, MODEL_AFFINITY_HEADER}; +use common::configuration::{FilterPipeline, LlmProvider, ModelAlias}; +use common::consts::{ + ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER, ARCH_ROUTING_HEADER, MODEL_AFFINITY_HEADER, +}; use common::llm_providers::LlmProviders; +use common::retry::error_response::build_error_response; +use common::retry::orchestrator::RetryOrchestrator; +use common::retry::{rebuild_request_for_provider, RequestContext, RequestSignature}; use hermesllm::apis::openai::Message; use hermesllm::apis::openai_responses::InputParam; use hermesllm::clients::{SupportedAPIsFromClient, SupportedUpstreamAPIs}; use hermesllm::{ProviderRequest, ProviderRequestType}; use http_body_util::combinators::BoxBody; -use http_body_util::BodyExt; +use http_body_util::{BodyExt, Full}; use hyper::header::{self}; use hyper::{Request, Response, StatusCode}; use opentelemetry::global; use opentelemetry::trace::get_active_span; use opentelemetry_http::HeaderInjector; -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::sync::Arc; use tokio::sync::RwLock; use tracing::{debug, info, info_span, warn, Instrument}; @@ -282,17 +287,10 @@ async fn llm_chat_inner( Err(response) => return Ok(response), }; - // Serialize request for upstream BEFORE router consumes it - let client_request_bytes_for_upstream: Bytes = - match ProviderRequestType::to_bytes(&client_request) { - Ok(bytes) => bytes.into(), - Err(err) => { - warn!(error = %err, "failed to serialize request for upstream"); - let mut r = Response::new(full(format!("Failed to serialize request: {}", err))); - *r.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; - return Ok(r); - } - }; + // Use the original bytes (from extract_routing_policy) for upstream to preserve + // JSON key order, whitespace, and unknown fields — critical for prompt cache hits. + // Only fall back to re-serialization if input filters modified the request. + let client_request_bytes_for_upstream: Bytes = chat_request_bytes.clone(); // --- Phase 3: Route the request (or use pinned model from session cache) --- let resolved_model = if let Some(cached_model) = pinned_model { @@ -367,24 +365,57 @@ async fn llm_chat_inner( tracing::Span::current().record(tracing_llm::MODEL_NAME, resolved_model.as_str()); // --- Phase 4: Forward to upstream and stream back --- - send_upstream( - &state.http_client, - &full_qualified_llm_provider_url, - &mut request_headers, - client_request_bytes_for_upstream, - &model_from_request, - &alias_resolved_model, - &resolved_model, - &model_name_only, - &request_path, - is_streaming_request, - messages_for_signals, - state_ctx, - state.state_storage.clone(), - request_id, - &state.filter_pipeline, - ) - .await + // Check if the resolved provider has a retry_policy configured. + // If so, use the RetryOrchestrator to wrap the upstream call with retry logic. + let resolved_provider: Option> = + state.llm_providers.read().await.get(&resolved_model); + + let has_retry_policy = resolved_provider + .as_ref() + .and_then(|p| p.retry_policy.as_ref()) + .is_some(); + + if has_retry_policy { + send_upstream_with_retry( + &state.http_client, + &full_qualified_llm_provider_url, + &mut request_headers, + client_request_bytes_for_upstream, + &model_from_request, + &alias_resolved_model, + &resolved_model, + &model_name_only, + &request_path, + is_streaming_request, + messages_for_signals, + state_ctx, + state.state_storage.clone(), + request_id, + &state.filter_pipeline, + &resolved_provider.unwrap(), + &state.llm_providers, + ) + .await + } else { + send_upstream( + &state.http_client, + &full_qualified_llm_provider_url, + &mut request_headers, + client_request_bytes_for_upstream, + &model_from_request, + &alias_resolved_model, + &resolved_model, + &model_name_only, + &request_path, + is_streaming_request, + messages_for_signals, + state_ctx, + state.state_storage.clone(), + request_id, + &state.filter_pipeline, + ) + .await + } } // --------------------------------------------------------------------------- @@ -845,6 +876,312 @@ async fn send_upstream( } } +/// Retry-aware version of send_upstream. Uses the RetryOrchestrator to wrap +/// the upstream HTTP call with automatic retry and provider failover logic. +#[allow(clippy::too_many_arguments)] +async fn send_upstream_with_retry( + http_client: &reqwest::Client, + upstream_url: &str, + request_headers: &mut hyper::HeaderMap, + body: bytes::Bytes, + model_from_request: &str, + alias_resolved_model: &str, + resolved_model: &str, + _model_name_only: &str, + request_path: &str, + is_streaming_request: bool, + messages_for_signals: Option>, + state_ctx: ConversationStateContext, + state_storage: Option>, + request_id: String, + filter_pipeline: &Arc, + primary_provider: &Arc, + llm_providers: &Arc>, +) -> Result>, hyper::Error> { + let retry_policy = primary_provider.retry_policy.as_ref().unwrap(); + + // Collect all providers for the retry orchestrator + let all_providers: Vec = llm_providers + .read() + .await + .iter() + .map(|(_, p)| (*p).as_ref().clone()) + .collect(); + + // Build request signature + let request_signature = RequestSignature::new( + &body, + request_headers, + is_streaming_request, + alias_resolved_model.to_string(), + ); + + // Build request context + let mut request_context = RequestContext { + request_id: request_id.clone(), + attempted_providers: HashSet::new(), + retry_start_time: None, + attempt_number: 0, + request_retry_after_state: HashMap::new(), + request_latency_block_state: HashMap::new(), + request_signature: request_signature.clone(), + errors: vec![], + }; + + let orchestrator = RetryOrchestrator::new_default(); + + debug!( + model = %alias_resolved_model, + fallback_models = ?retry_policy.fallback_models, + default_strategy = ?retry_policy.default_strategy, + default_max_attempts = retry_policy.default_max_attempts, + "Retry orchestrator initialized for request" + ); + + // Capture references for the forward_fn closure + let base_url = upstream_url.to_string(); + let original_headers = request_headers.clone(); + let primary_model = alias_resolved_model.to_string(); + let http_client = http_client.clone(); + + // The forward_fn handles the actual HTTP call to upstream for each attempt. + let forward_fn = |body: &Bytes, target_provider: &LlmProvider| { + let body = body.clone(); + let target_provider = target_provider.clone(); + let base_url = base_url.clone(); + let original_headers = original_headers.clone(); + let primary_model = primary_model.clone(); + let http_client = http_client.clone(); + + async move { + let target_model = target_provider + .model + .as_deref() + .unwrap_or(&target_provider.name); + + let (request_body, mut headers) = if target_model == primary_model { + (body.clone(), original_headers.clone()) + } else { + match rebuild_request_for_provider(&body, &target_provider, &original_headers) { + Ok((new_body, new_headers)) => (new_body, new_headers), + Err(e) => { + warn!(error = %e, "Failed to rebuild request for provider"); + return Err(common::retry::error_detector::TimeoutError { duration_ms: 0 }); + } + } + }; + + // Resolve the upstream URL for the target provider. + // Always route through the Envoy proxy (base_url) and let the + // provider-hint header select the upstream cluster. Building a + // direct URL from the provider endpoint is wrong because the + // endpoint field stores a bare hostname (no scheme), and + // bypassing Envoy loses TLS, load-balancing, and observability. + let upstream_url = base_url.clone(); + + // Set provider hint header so the WASM gateway selects the + // correct provider (and its credentials) for this retry attempt. + // Do NOT set x-arch-llm-provider here — the WASM gateway sets it + // via add_http_request_header after provider selection. If we set + // it too, Envoy sees a duplicate multi-value header that fails + // exact-match routing and falls through to the 400 catch-all. + headers.remove(header::HeaderName::from_static(ARCH_ROUTING_HEADER)); + headers.insert( + ARCH_PROVIDER_HINT_HEADER, + header::HeaderValue::from_str(target_model) + .unwrap_or_else(|_| header::HeaderValue::from_static("unknown")), + ); + headers.remove(header::CONTENT_LENGTH); + + // Send the request + let result = http_client + .post(&upstream_url) + .headers(headers) + .body(request_body.to_vec()) + .send() + .await; + + match result { + Ok(res) => { + let status = res.status().as_u16(); + let resp_headers = res.headers().clone(); + let body_bytes = res.bytes().await.unwrap_or_default(); + + // Debug: log upstream response for retry attempts + if status >= 400 { + let body_preview = String::from_utf8_lossy(&body_bytes); + warn!( + "Retry upstream response: status={}, model={}, body={}", + status, + target_model, + &body_preview[..body_preview.len().min(500)] + ); + } + + let full_body = Full::new(body_bytes) + .map_err(|never| match never {}) + .boxed(); + + let mut builder = Response::builder().status(status); + if let Some(hdrs) = builder.headers_mut() { + for (name, value) in resp_headers.iter() { + if let Ok(hyper_name) = + hyper::header::HeaderName::from_bytes(name.as_str().as_bytes()) + { + if let Ok(hyper_value) = + hyper::header::HeaderValue::from_bytes(value.as_bytes()) + { + hdrs.insert(hyper_name, hyper_value); + } + } + } + } + + Ok(builder.body(full_body).unwrap()) + } + Err(err) => { + warn!(error = %err, "Upstream request failed during retry"); + Err(common::retry::error_detector::TimeoutError { duration_ms: 0 }) + } + } + } + }; + + // Execute the retry orchestrator + let retry_result = orchestrator + .execute( + &body, + &request_signature, + primary_provider.as_ref(), + retry_policy, + &all_providers, + &mut request_context, + forward_fn, + ) + .await; + + match retry_result { + Ok(http_response) => { + // Success (possibly after retries) — stream the response back to client + let upstream_status = http_response.status(); + let response_headers = http_response.headers().clone(); + + let span_name = if model_from_request == resolved_model { + format!("POST {} {}", request_path, resolved_model) + } else { + format!( + "POST {} {} -> {}", + request_path, model_from_request, resolved_model + ) + }; + + let mut response = Response::builder().status(upstream_status); + let headers = response.headers_mut().unwrap(); + for (header_name, header_value) in response_headers.iter() { + headers.insert(header_name, header_value.clone()); + } + + // Collect the body from the HttpResponse + let body_bytes = http_response + .into_body() + .collect() + .await + .map(|collected| collected.to_bytes()) + .unwrap_or_default(); + + let byte_stream = futures::stream::iter(vec![Ok::(body_bytes)]); + + let (metric_provider_raw, metric_model_raw) = + bs_metrics::split_provider_model(resolved_model); + + let base_processor = ObservableStreamProcessor::new( + operation_component::LLM, + span_name, + std::time::Instant::now(), + messages_for_signals, + ) + .with_llm_metrics(LlmMetricsCtx { + provider: metric_provider_raw.to_string(), + model: metric_model_raw.to_string(), + upstream_status: upstream_status.as_u16(), + }); + + let output_filter_request_headers = if filter_pipeline.has_output_filters() { + Some(request_headers.clone()) + } else { + None + }; + + let processor: Box = if let (true, false, Some(state_store)) = ( + state_ctx.should_manage_state, + state_ctx.original_input_items.is_empty(), + &state_storage, + ) { + let content_encoding = response_headers + .get("content-encoding") + .and_then(|v| v.to_str().ok()) + .map(|s| s.to_string()); + + Box::new(ResponsesStateProcessor::new( + base_processor, + state_store.clone(), + state_ctx.original_input_items, + alias_resolved_model.to_string(), + resolved_model.to_string(), + is_streaming_request, + false, + content_encoding, + request_id.clone(), + )) + } else { + Box::new(base_processor) + }; + + let streaming_response = if let (Some(output_chain), Some(filter_headers)) = ( + filter_pipeline.output.as_ref().filter(|c| !c.is_empty()), + output_filter_request_headers, + ) { + create_streaming_response_with_output_filter( + byte_stream, + processor, + output_chain.clone(), + filter_headers, + request_path.to_string(), + ) + } else { + create_streaming_response(byte_stream, processor) + }; + + match response.body(streaming_response.body) { + Ok(response) => Ok(response), + Err(err) => { + let err_msg = format!("Failed to create response: {}", err); + let mut internal_error = Response::new(full(err_msg)); + *internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; + Ok(internal_error) + } + } + } + Err(retry_exhausted_error) => { + // All retries exhausted — build structured error response + info!( + request_id = %request_id, + total_attempts = retry_exhausted_error.attempts.len(), + budget_exhausted = retry_exhausted_error.retry_budget_exhausted, + "All retries exhausted" + ); + + let error_resp = build_error_response(&retry_exhausted_error, &request_id); + + // Convert Full body to BoxBody + let (parts, full_body) = error_resp.into_parts(); + let boxed_body = full_body.map_err(|never| match never {}).boxed(); + + Ok(Response::from_parts(parts, boxed_body)) + } + } +} + // --------------------------------------------------------------------------- // Helpers // ---------------------------------------------------------------------------