feat: integrate retry orchestrator into LLM handler

Wire up the retry module into the brightstaff LLM handler:
- Add send_upstream_with_retry() that uses RetryOrchestrator
  to coordinate retry attempts with backoff and failover
- Build forward_fn closure for per-attempt HTTP calls
- Support failover to alternative providers on retryable errors
- Fall back to single-attempt send_upstream() when no retry
  policy is configured

Signed-off-by: Troy Mitchell <i@troy-y.org>
This commit is contained in:
Troy Mitchell 2026-04-28 16:20:50 +08:00
parent c34ff5b5fd
commit f3e35914e3

View file

@ -1,19 +1,24 @@
use bytes::Bytes;
use common::configuration::{FilterPipeline, ModelAlias};
use common::consts::{ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER, MODEL_AFFINITY_HEADER};
use common::configuration::{FilterPipeline, LlmProvider, ModelAlias};
use common::consts::{
ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER, ARCH_ROUTING_HEADER, MODEL_AFFINITY_HEADER,
};
use common::llm_providers::LlmProviders;
use common::retry::error_response::build_error_response;
use common::retry::orchestrator::RetryOrchestrator;
use common::retry::{rebuild_request_for_provider, RequestContext, RequestSignature};
use hermesllm::apis::openai::Message;
use hermesllm::apis::openai_responses::InputParam;
use hermesllm::clients::{SupportedAPIsFromClient, SupportedUpstreamAPIs};
use hermesllm::{ProviderRequest, ProviderRequestType};
use http_body_util::combinators::BoxBody;
use http_body_util::BodyExt;
use http_body_util::{BodyExt, Full};
use hyper::header::{self};
use hyper::{Request, Response, StatusCode};
use opentelemetry::global;
use opentelemetry::trace::get_active_span;
use opentelemetry_http::HeaderInjector;
use std::collections::HashMap;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::{debug, info, info_span, warn, Instrument};
@ -282,17 +287,10 @@ async fn llm_chat_inner(
Err(response) => return Ok(response),
};
// Serialize request for upstream BEFORE router consumes it
let client_request_bytes_for_upstream: Bytes =
match ProviderRequestType::to_bytes(&client_request) {
Ok(bytes) => bytes.into(),
Err(err) => {
warn!(error = %err, "failed to serialize request for upstream");
let mut r = Response::new(full(format!("Failed to serialize request: {}", err)));
*r.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
return Ok(r);
}
};
// Use the original bytes (from extract_routing_policy) for upstream to preserve
// JSON key order, whitespace, and unknown fields — critical for prompt cache hits.
// Only fall back to re-serialization if input filters modified the request.
let client_request_bytes_for_upstream: Bytes = chat_request_bytes.clone();
// --- Phase 3: Route the request (or use pinned model from session cache) ---
let resolved_model = if let Some(cached_model) = pinned_model {
@ -367,24 +365,57 @@ async fn llm_chat_inner(
tracing::Span::current().record(tracing_llm::MODEL_NAME, resolved_model.as_str());
// --- Phase 4: Forward to upstream and stream back ---
send_upstream(
&state.http_client,
&full_qualified_llm_provider_url,
&mut request_headers,
client_request_bytes_for_upstream,
&model_from_request,
&alias_resolved_model,
&resolved_model,
&model_name_only,
&request_path,
is_streaming_request,
messages_for_signals,
state_ctx,
state.state_storage.clone(),
request_id,
&state.filter_pipeline,
)
.await
// Check if the resolved provider has a retry_policy configured.
// If so, use the RetryOrchestrator to wrap the upstream call with retry logic.
let resolved_provider: Option<Arc<LlmProvider>> =
state.llm_providers.read().await.get(&resolved_model);
let has_retry_policy = resolved_provider
.as_ref()
.and_then(|p| p.retry_policy.as_ref())
.is_some();
if has_retry_policy {
send_upstream_with_retry(
&state.http_client,
&full_qualified_llm_provider_url,
&mut request_headers,
client_request_bytes_for_upstream,
&model_from_request,
&alias_resolved_model,
&resolved_model,
&model_name_only,
&request_path,
is_streaming_request,
messages_for_signals,
state_ctx,
state.state_storage.clone(),
request_id,
&state.filter_pipeline,
&resolved_provider.unwrap(),
&state.llm_providers,
)
.await
} else {
send_upstream(
&state.http_client,
&full_qualified_llm_provider_url,
&mut request_headers,
client_request_bytes_for_upstream,
&model_from_request,
&alias_resolved_model,
&resolved_model,
&model_name_only,
&request_path,
is_streaming_request,
messages_for_signals,
state_ctx,
state.state_storage.clone(),
request_id,
&state.filter_pipeline,
)
.await
}
}
// ---------------------------------------------------------------------------
@ -845,6 +876,312 @@ async fn send_upstream(
}
}
/// Retry-aware version of send_upstream. Uses the RetryOrchestrator to wrap
/// the upstream HTTP call with automatic retry and provider failover logic.
#[allow(clippy::too_many_arguments)]
async fn send_upstream_with_retry(
http_client: &reqwest::Client,
upstream_url: &str,
request_headers: &mut hyper::HeaderMap,
body: bytes::Bytes,
model_from_request: &str,
alias_resolved_model: &str,
resolved_model: &str,
_model_name_only: &str,
request_path: &str,
is_streaming_request: bool,
messages_for_signals: Option<Vec<Message>>,
state_ctx: ConversationStateContext,
state_storage: Option<Arc<dyn StateStorage>>,
request_id: String,
filter_pipeline: &Arc<FilterPipeline>,
primary_provider: &Arc<LlmProvider>,
llm_providers: &Arc<RwLock<LlmProviders>>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
let retry_policy = primary_provider.retry_policy.as_ref().unwrap();
// Collect all providers for the retry orchestrator
let all_providers: Vec<LlmProvider> = llm_providers
.read()
.await
.iter()
.map(|(_, p)| (*p).as_ref().clone())
.collect();
// Build request signature
let request_signature = RequestSignature::new(
&body,
request_headers,
is_streaming_request,
alias_resolved_model.to_string(),
);
// Build request context
let mut request_context = RequestContext {
request_id: request_id.clone(),
attempted_providers: HashSet::new(),
retry_start_time: None,
attempt_number: 0,
request_retry_after_state: HashMap::new(),
request_latency_block_state: HashMap::new(),
request_signature: request_signature.clone(),
errors: vec![],
};
let orchestrator = RetryOrchestrator::new_default();
debug!(
model = %alias_resolved_model,
fallback_models = ?retry_policy.fallback_models,
default_strategy = ?retry_policy.default_strategy,
default_max_attempts = retry_policy.default_max_attempts,
"Retry orchestrator initialized for request"
);
// Capture references for the forward_fn closure
let base_url = upstream_url.to_string();
let original_headers = request_headers.clone();
let primary_model = alias_resolved_model.to_string();
let http_client = http_client.clone();
// The forward_fn handles the actual HTTP call to upstream for each attempt.
let forward_fn = |body: &Bytes, target_provider: &LlmProvider| {
let body = body.clone();
let target_provider = target_provider.clone();
let base_url = base_url.clone();
let original_headers = original_headers.clone();
let primary_model = primary_model.clone();
let http_client = http_client.clone();
async move {
let target_model = target_provider
.model
.as_deref()
.unwrap_or(&target_provider.name);
let (request_body, mut headers) = if target_model == primary_model {
(body.clone(), original_headers.clone())
} else {
match rebuild_request_for_provider(&body, &target_provider, &original_headers) {
Ok((new_body, new_headers)) => (new_body, new_headers),
Err(e) => {
warn!(error = %e, "Failed to rebuild request for provider");
return Err(common::retry::error_detector::TimeoutError { duration_ms: 0 });
}
}
};
// Resolve the upstream URL for the target provider.
// Always route through the Envoy proxy (base_url) and let the
// provider-hint header select the upstream cluster. Building a
// direct URL from the provider endpoint is wrong because the
// endpoint field stores a bare hostname (no scheme), and
// bypassing Envoy loses TLS, load-balancing, and observability.
let upstream_url = base_url.clone();
// Set provider hint header so the WASM gateway selects the
// correct provider (and its credentials) for this retry attempt.
// Do NOT set x-arch-llm-provider here — the WASM gateway sets it
// via add_http_request_header after provider selection. If we set
// it too, Envoy sees a duplicate multi-value header that fails
// exact-match routing and falls through to the 400 catch-all.
headers.remove(header::HeaderName::from_static(ARCH_ROUTING_HEADER));
headers.insert(
ARCH_PROVIDER_HINT_HEADER,
header::HeaderValue::from_str(target_model)
.unwrap_or_else(|_| header::HeaderValue::from_static("unknown")),
);
headers.remove(header::CONTENT_LENGTH);
// Send the request
let result = http_client
.post(&upstream_url)
.headers(headers)
.body(request_body.to_vec())
.send()
.await;
match result {
Ok(res) => {
let status = res.status().as_u16();
let resp_headers = res.headers().clone();
let body_bytes = res.bytes().await.unwrap_or_default();
// Debug: log upstream response for retry attempts
if status >= 400 {
let body_preview = String::from_utf8_lossy(&body_bytes);
warn!(
"Retry upstream response: status={}, model={}, body={}",
status,
target_model,
&body_preview[..body_preview.len().min(500)]
);
}
let full_body = Full::new(body_bytes)
.map_err(|never| match never {})
.boxed();
let mut builder = Response::builder().status(status);
if let Some(hdrs) = builder.headers_mut() {
for (name, value) in resp_headers.iter() {
if let Ok(hyper_name) =
hyper::header::HeaderName::from_bytes(name.as_str().as_bytes())
{
if let Ok(hyper_value) =
hyper::header::HeaderValue::from_bytes(value.as_bytes())
{
hdrs.insert(hyper_name, hyper_value);
}
}
}
}
Ok(builder.body(full_body).unwrap())
}
Err(err) => {
warn!(error = %err, "Upstream request failed during retry");
Err(common::retry::error_detector::TimeoutError { duration_ms: 0 })
}
}
}
};
// Execute the retry orchestrator
let retry_result = orchestrator
.execute(
&body,
&request_signature,
primary_provider.as_ref(),
retry_policy,
&all_providers,
&mut request_context,
forward_fn,
)
.await;
match retry_result {
Ok(http_response) => {
// Success (possibly after retries) — stream the response back to client
let upstream_status = http_response.status();
let response_headers = http_response.headers().clone();
let span_name = if model_from_request == resolved_model {
format!("POST {} {}", request_path, resolved_model)
} else {
format!(
"POST {} {} -> {}",
request_path, model_from_request, resolved_model
)
};
let mut response = Response::builder().status(upstream_status);
let headers = response.headers_mut().unwrap();
for (header_name, header_value) in response_headers.iter() {
headers.insert(header_name, header_value.clone());
}
// Collect the body from the HttpResponse
let body_bytes = http_response
.into_body()
.collect()
.await
.map(|collected| collected.to_bytes())
.unwrap_or_default();
let byte_stream = futures::stream::iter(vec![Ok::<Bytes, reqwest::Error>(body_bytes)]);
let (metric_provider_raw, metric_model_raw) =
bs_metrics::split_provider_model(resolved_model);
let base_processor = ObservableStreamProcessor::new(
operation_component::LLM,
span_name,
std::time::Instant::now(),
messages_for_signals,
)
.with_llm_metrics(LlmMetricsCtx {
provider: metric_provider_raw.to_string(),
model: metric_model_raw.to_string(),
upstream_status: upstream_status.as_u16(),
});
let output_filter_request_headers = if filter_pipeline.has_output_filters() {
Some(request_headers.clone())
} else {
None
};
let processor: Box<dyn StreamProcessor> = if let (true, false, Some(state_store)) = (
state_ctx.should_manage_state,
state_ctx.original_input_items.is_empty(),
&state_storage,
) {
let content_encoding = response_headers
.get("content-encoding")
.and_then(|v| v.to_str().ok())
.map(|s| s.to_string());
Box::new(ResponsesStateProcessor::new(
base_processor,
state_store.clone(),
state_ctx.original_input_items,
alias_resolved_model.to_string(),
resolved_model.to_string(),
is_streaming_request,
false,
content_encoding,
request_id.clone(),
))
} else {
Box::new(base_processor)
};
let streaming_response = if let (Some(output_chain), Some(filter_headers)) = (
filter_pipeline.output.as_ref().filter(|c| !c.is_empty()),
output_filter_request_headers,
) {
create_streaming_response_with_output_filter(
byte_stream,
processor,
output_chain.clone(),
filter_headers,
request_path.to_string(),
)
} else {
create_streaming_response(byte_stream, processor)
};
match response.body(streaming_response.body) {
Ok(response) => Ok(response),
Err(err) => {
let err_msg = format!("Failed to create response: {}", err);
let mut internal_error = Response::new(full(err_msg));
*internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
Ok(internal_error)
}
}
}
Err(retry_exhausted_error) => {
// All retries exhausted — build structured error response
info!(
request_id = %request_id,
total_attempts = retry_exhausted_error.attempts.len(),
budget_exhausted = retry_exhausted_error.retry_budget_exhausted,
"All retries exhausted"
);
let error_resp = build_error_response(&retry_exhausted_error, &request_id);
// Convert Full<Bytes> body to BoxBody<Bytes, hyper::Error>
let (parts, full_body) = error_resp.into_parts();
let boxed_body = full_body.map_err(|never| match never {}).boxed();
Ok(Response::from_parts(parts, boxed_body))
}
}
}
// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------