feat: add provider arbitrage policy and fallback routing

This commit is contained in:
Musa 2026-03-18 15:54:49 -07:00
parent de2d8847f3
commit 07ad4c6ae2
No known key found for this signature in database
10 changed files with 670 additions and 57 deletions

View file

@ -29,11 +29,56 @@ use crate::state::{
extract_input_items, retrieve_and_combine_input, StateStorage, StateStorageError,
};
use crate::tracing::{
collect_custom_trace_attributes, llm as tracing_llm, operation_component, set_service_name,
collect_custom_trace_attributes, llm as tracing_llm, operation_component,
routing as tracing_routing, set_service_name,
};
use common::errors::BrightStaffError;
fn strip_provider_prefix(model: &str) -> String {
if let Some((_, model_name)) = model.split_once('/') {
model_name.to_string()
} else {
model.to_string()
}
}
fn is_retryable_upstream_status(status: reqwest::StatusCode) -> bool {
matches!(
status,
reqwest::StatusCode::TOO_MANY_REQUESTS
| reqwest::StatusCode::BAD_GATEWAY
| reqwest::StatusCode::SERVICE_UNAVAILABLE
| reqwest::StatusCode::GATEWAY_TIMEOUT
)
}
async fn build_arbitrage_candidate_chain(
llm_providers: &Arc<RwLock<LlmProviders>>,
primary_model: &str,
) -> Vec<String> {
let mut chain = Vec::new();
let providers = llm_providers.read().await;
if let Some(provider) = providers.get(primary_model) {
if let Some(arbitrage_policy) = &provider.arbitrage_policy {
if arbitrage_policy.enabled.unwrap_or(false) {
for ranked_candidate in arbitrage_policy.rank.clone().unwrap_or_default() {
if !chain.contains(&ranked_candidate) {
chain.push(ranked_candidate);
}
}
}
}
}
if !chain.contains(&primary_model.to_string()) {
chain.push(primary_model.to_string());
}
chain
}
pub async fn llm_chat(
request: Request<hyper::body::Incoming>,
router_service: Arc<RouterService>,
@ -97,7 +142,7 @@ async fn llm_chat_inner(
state_storage: Option<Arc<dyn StateStorage>>,
request_id: String,
request_path: String,
mut request_headers: hyper::HeaderMap,
request_headers: hyper::HeaderMap,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
// Set service name for LLM operations
set_service_name(operation_component::LLM);
@ -331,9 +376,6 @@ async fn llm_chat_inner(
}
}
// Serialize request for upstream BEFORE router consumes it
let client_request_bytes_for_upstream = ProviderRequestType::to_bytes(&client_request).unwrap();
// Determine routing using the dedicated router_chat module
// This gets its own span for latency and error tracking
let routing_span = info_span!(
@ -374,73 +416,209 @@ async fn llm_chat_inner(
// Determine final model to use
// Router returns "none" as a sentinel value when it doesn't select a specific model
let router_selected_model = routing_result.model_name;
let resolved_model = if router_selected_model != "none" {
let primary_model = if router_selected_model != "none" {
// Router selected a specific model via routing preferences
router_selected_model
} else {
// Router returned "none" sentinel, use validated resolved_model from request
alias_resolved_model.clone()
};
tracing::Span::current().record(tracing_llm::MODEL_NAME, resolved_model.as_str());
let arbitrage_chain = build_arbitrage_candidate_chain(&llm_providers, &primary_model).await;
let mut selected_model = primary_model.clone();
let mut llm_response: Option<reqwest::Response> = None;
let mut last_transport_error: Option<String> = None;
let request_start_time = std::time::Instant::now();
let http_client = reqwest::Client::new();
let span_name = if model_from_request == resolved_model {
format!("POST {} {}", request_path, resolved_model)
for (attempt_idx, candidate_model) in arbitrage_chain.iter().enumerate() {
selected_model = candidate_model.clone();
let candidate_model_name = strip_provider_prefix(candidate_model);
let mut candidate_request = match ProviderRequestType::try_from((
&chat_request_bytes[..],
&SupportedAPIsFromClient::from_endpoint(request_path.as_str()).unwrap(),
)) {
Ok(request) => request,
Err(err) => {
warn!(
candidate = %candidate_model,
error = %err,
"failed to build candidate request"
);
return Ok(BrightStaffError::InvalidRequest(format!(
"Failed to parse request: {}",
err
))
.into_response());
}
};
candidate_request.set_model(candidate_model_name.clone());
if candidate_request.remove_metadata_key("plano_preference_config") {
debug!("removed plano_preference_config from candidate metadata");
}
let (candidate_provider_id, _) = get_provider_info(&llm_providers, candidate_model).await;
let selection_reason = if attempt_idx == 0 {
if candidate_model == &primary_model {
"router_selected_primary"
} else {
"free_tier_available"
}
} else {
"fallback_on_retryable_error"
};
let is_fallback_attempt = attempt_idx > 0;
get_active_span(|span| {
span.set_attribute(opentelemetry::KeyValue::new(
tracing_routing::SELECTION_REASON,
selection_reason.to_string(),
));
span.set_attribute(opentelemetry::KeyValue::new(
tracing_routing::IS_FALLBACK,
is_fallback_attempt,
));
span.set_attribute(opentelemetry::KeyValue::new(
tracing_llm::PROVIDER,
candidate_provider_id.to_string(),
));
span.set_attribute(opentelemetry::KeyValue::new(
"routing.attempt_index",
(attempt_idx + 1) as i64,
));
span.set_attribute(opentelemetry::KeyValue::new(
"routing.attempt_total",
arbitrage_chain.len() as i64,
));
span.set_attribute(opentelemetry::KeyValue::new(
tracing_routing::UPSTREAM_ENDPOINT,
candidate_model.clone(),
));
});
if let Some(ref client_api_kind) = client_api {
let upstream_api = candidate_provider_id
.compatible_api_for_client(client_api_kind, is_streaming_request);
candidate_request.normalize_for_upstream(candidate_provider_id, &upstream_api);
}
let candidate_request_bytes = ProviderRequestType::to_bytes(&candidate_request).unwrap();
let mut candidate_headers = request_headers.clone();
candidate_headers.insert(
ARCH_PROVIDER_HINT_HEADER,
header::HeaderValue::from_str(candidate_model).unwrap(),
);
candidate_headers.insert(
header::HeaderName::from_static(ARCH_IS_STREAMING_HEADER),
header::HeaderValue::from_str(&is_streaming_request.to_string()).unwrap(),
);
candidate_headers.remove(header::CONTENT_LENGTH);
global::get_text_map_propagator(|propagator| {
let cx =
tracing_opentelemetry::OpenTelemetrySpanExt::context(&tracing::Span::current());
propagator.inject_context(&cx, &mut HeaderInjector(&mut candidate_headers));
});
debug!(
url = %full_qualified_llm_provider_url,
provider_hint = %candidate_model,
upstream_model = %candidate_model_name,
selection_reason = %selection_reason,
attempt_index = attempt_idx + 1,
attempt_total = arbitrage_chain.len(),
"Routing candidate to upstream"
);
let response = match http_client
.post(&full_qualified_llm_provider_url)
.headers(candidate_headers)
.body(candidate_request_bytes)
.send()
.await
{
Ok(response) => response,
Err(err) => {
last_transport_error = Some(err.to_string());
if attempt_idx + 1 < arbitrage_chain.len() {
let next_candidate = arbitrage_chain[attempt_idx + 1].as_str();
get_active_span(|span| {
span.set_attribute(opentelemetry::KeyValue::new(
"routing.fallback_trigger",
"transport_error".to_string(),
));
span.set_attribute(opentelemetry::KeyValue::new(
"routing.next_candidate",
next_candidate.to_string(),
));
});
warn!(
candidate = %candidate_model,
error = %err,
next_candidate = %next_candidate,
attempt_index = attempt_idx + 1,
attempt_total = arbitrage_chain.len(),
"candidate transport failure, trying next fallback"
);
continue;
}
return Ok(BrightStaffError::InternalServerError(format!(
"Failed to send request: {}",
err
))
.into_response());
}
};
let status = response.status();
if is_retryable_upstream_status(status) && attempt_idx + 1 < arbitrage_chain.len() {
let next_candidate = arbitrage_chain[attempt_idx + 1].as_str();
get_active_span(|span| {
span.set_attribute(opentelemetry::KeyValue::new(
"routing.fallback_trigger",
format!("http_{}", status.as_u16()),
));
span.set_attribute(opentelemetry::KeyValue::new(
"routing.next_candidate",
next_candidate.to_string(),
));
});
warn!(
candidate = %candidate_model,
status = status.as_u16(),
next_candidate = %next_candidate,
attempt_index = attempt_idx + 1,
attempt_total = arbitrage_chain.len(),
"candidate returned retryable status, trying next fallback"
);
continue;
}
llm_response = Some(response);
break;
}
let llm_response = match llm_response {
Some(response) => response,
None => {
return Ok(BrightStaffError::InternalServerError(format!(
"Failed to send request across arbitrage chain: {}",
last_transport_error.unwrap_or_else(|| "unknown error".to_string())
))
.into_response());
}
};
tracing::Span::current().record(tracing_llm::MODEL_NAME, selected_model.as_str());
let span_name = if model_from_request == selected_model {
format!("POST {} {}", request_path, selected_model)
} else {
format!(
"POST {} {} -> {}",
request_path, model_from_request, resolved_model
request_path, model_from_request, selected_model
)
};
get_active_span(|span| {
span.update_name(span_name.clone());
});
debug!(
url = %full_qualified_llm_provider_url,
provider_hint = %resolved_model,
upstream_model = %model_name_only,
"Routing to upstream"
);
request_headers.insert(
ARCH_PROVIDER_HINT_HEADER,
header::HeaderValue::from_str(&resolved_model).unwrap(),
);
request_headers.insert(
header::HeaderName::from_static(ARCH_IS_STREAMING_HEADER),
header::HeaderValue::from_str(&is_streaming_request.to_string()).unwrap(),
);
// remove content-length header if it exists
request_headers.remove(header::CONTENT_LENGTH);
// Inject current LLM span's trace context so upstream spans are children of plano(llm)
global::get_text_map_propagator(|propagator| {
let cx = tracing_opentelemetry::OpenTelemetrySpanExt::context(&tracing::Span::current());
propagator.inject_context(&cx, &mut HeaderInjector(&mut request_headers));
});
// Capture start time right before sending request to upstream
let request_start_time = std::time::Instant::now();
let _request_start_system_time = std::time::SystemTime::now();
let llm_response = match reqwest::Client::new()
.post(&full_qualified_llm_provider_url)
.headers(request_headers)
.body(client_request_bytes_for_upstream)
.send()
.await
{
Ok(res) => res,
Err(err) => {
return Ok(BrightStaffError::InternalServerError(format!(
"Failed to send request: {}",
err
))
.into_response());
}
};
// copy over the headers and status code from the original response
let response_headers = llm_response.headers().clone();
let upstream_status = llm_response.status();
@ -480,7 +658,7 @@ async fn llm_chat_inner(
state_store,
original_input_items,
alias_resolved_model.clone(),
resolved_model.clone(),
selected_model.clone(),
is_streaming_request,
false, // Not OpenAI upstream since should_manage_state is true
content_encoding,
@ -570,3 +748,101 @@ async fn get_provider_info(
(hermesllm::ProviderId::OpenAI, None)
}
}
#[cfg(test)]
mod tests {
use super::{
build_arbitrage_candidate_chain, is_retryable_upstream_status, strip_provider_prefix,
};
use common::configuration::{
ArbitrageFailurePolicy, ArbitragePolicy, LlmProvider, LlmProviderType,
};
use common::llm_providers::LlmProviders;
use std::sync::Arc;
use tokio::sync::RwLock;
fn provider(name: &str, model: &str, default: bool) -> LlmProvider {
LlmProvider {
name: name.to_string(),
model: Some(model.to_string()),
provider_interface: LlmProviderType::OpenAI,
default: Some(default),
stream: None,
access_key: None,
endpoint: None,
port: None,
rate_limits: None,
usage: None,
routing_preferences: None,
cluster_name: None,
base_url_path_prefix: None,
internal: None,
passthrough_auth: None,
arbitrage_policy: None,
}
}
#[test]
fn strips_provider_prefix() {
assert_eq!(strip_provider_prefix("openai/gpt-4o-mini"), "gpt-4o-mini");
assert_eq!(strip_provider_prefix("gpt-4o-mini"), "gpt-4o-mini");
}
#[test]
fn retryable_status_matrix_is_deterministic() {
assert!(is_retryable_upstream_status(
reqwest::StatusCode::TOO_MANY_REQUESTS
));
assert!(is_retryable_upstream_status(
reqwest::StatusCode::BAD_GATEWAY
));
assert!(is_retryable_upstream_status(
reqwest::StatusCode::SERVICE_UNAVAILABLE
));
assert!(is_retryable_upstream_status(
reqwest::StatusCode::GATEWAY_TIMEOUT
));
assert!(!is_retryable_upstream_status(
reqwest::StatusCode::BAD_REQUEST
));
assert!(!is_retryable_upstream_status(
reqwest::StatusCode::UNAUTHORIZED
));
}
#[tokio::test]
async fn arbitrage_chain_is_ranked_then_primary() {
let mut primary = provider("openai/gpt-4o-mini", "gpt-4o-mini", true);
primary.arbitrage_policy = Some(ArbitragePolicy {
enabled: Some(true),
rank: Some(vec![
"groq/llama-3.1-8b-instant".to_string(),
"together_ai/openai/gpt-oss-20b".to_string(),
]),
on_failure: Some(ArbitrageFailurePolicy {
fallback_to_primary: Some(true),
}),
});
let providers = vec![
primary,
provider("groq/llama-3.1-8b-instant", "llama-3.1-8b-instant", false),
provider(
"together_ai/openai/gpt-oss-20b",
"openai/gpt-oss-20b",
false,
),
];
let llm_providers = Arc::new(RwLock::new(LlmProviders::try_from(providers).unwrap()));
let chain = build_arbitrage_candidate_chain(&llm_providers, "openai/gpt-4o-mini").await;
assert_eq!(
chain,
vec![
"groq/llama-3.1-8b-instant".to_string(),
"together_ai/openai/gpt-oss-20b".to_string(),
"openai/gpt-4o-mini".to_string(),
]
);
}
}

View file

@ -277,6 +277,18 @@ pub struct RoutingPreference {
pub description: String,
}
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ArbitrageFailurePolicy {
pub fallback_to_primary: Option<bool>,
}
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ArbitragePolicy {
pub enabled: Option<bool>,
pub rank: Option<Vec<String>>,
pub on_failure: Option<ArbitrageFailurePolicy>,
}
#[derive(Serialize, Deserialize, Debug)]
pub struct AgentUsagePreference {
pub model: String,
@ -331,6 +343,7 @@ pub struct LlmProvider {
pub base_url_path_prefix: Option<String>,
pub internal: Option<bool>,
pub passthrough_auth: Option<bool>,
pub arbitrage_policy: Option<ArbitragePolicy>,
}
pub trait IntoModels {
@ -375,6 +388,7 @@ impl Default for LlmProvider {
base_url_path_prefix: None,
internal: None,
passthrough_auth: None,
arbitrage_policy: None,
}
}
}

View file

@ -278,6 +278,7 @@ mod tests {
internal: None,
stream: None,
passthrough_auth: None,
arbitrage_policy: None,
}
}