This commit is contained in:
Troy 2026-06-05 10:39:36 -07:00 committed by GitHub
commit 9c18136a96
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
20 changed files with 13114 additions and 107 deletions

View file

@ -214,6 +214,183 @@ properties:
required:
- name
- description
retry_policy:
type: object
description: "Retry policy configuration. When not specified, no retry logic is enabled."
properties:
fallback_models:
type: array
description: "Ordered list of model identifiers to fallback to before using Provider_List."
items:
type: string
default_strategy:
type: string
description: "Default retry strategy for unconfigured status codes. Default: different_provider."
enum:
- same_model
- same_provider
- different_provider
default_max_attempts:
type: integer
description: "Default max retry attempts for unconfigured status codes. Default: 2."
minimum: 0
on_status_codes:
type: array
description: "Per-status-code retry configuration."
items:
type: object
properties:
codes:
type: array
description: "List of status codes as integers or range strings (e.g. '502-504')."
items:
anyOf:
- type: integer
minimum: 100
maximum: 599
- type: string
description: "Range string in 'start-end' format (e.g. '502-504')."
strategy:
type: string
description: "Retry strategy for these status codes."
enum:
- same_model
- same_provider
- different_provider
max_attempts:
type: integer
description: "Max retry attempts for these status codes."
minimum: 0
additionalProperties: false
required:
- codes
- strategy
- max_attempts
on_timeout:
type: object
description: "Timeout-specific retry configuration. When omitted, timeouts use default_strategy and default_max_attempts."
properties:
strategy:
type: string
description: "Retry strategy for timeout errors."
enum:
- same_model
- same_provider
- different_provider
max_attempts:
type: integer
description: "Max retry attempts for timeout errors."
minimum: 1
additionalProperties: false
required:
- strategy
- max_attempts
on_high_latency:
type: object
description: "High latency proactive failover configuration. When omitted, no latency-based failover is performed."
properties:
threshold_ms:
type: integer
description: "Latency threshold in milliseconds. When response time exceeds this value, a High_Latency_Event is triggered."
minimum: 1
measure:
type: string
description: "What latency metric to measure. Default: ttfb."
enum:
- ttfb
- total
strategy:
type: string
description: "Retry strategy when latency threshold is exceeded."
enum:
- same_model
- same_provider
- different_provider
max_attempts:
type: integer
description: "Max retry attempts when latency threshold is exceeded."
minimum: 1
block_duration_seconds:
type: integer
description: "How long to block the model/provider after detecting high latency, in seconds. Default: 300."
minimum: 1
scope:
type: string
description: "What to block: model-level or provider-level. Default: model."
enum:
- model
- provider
apply_to:
type: string
description: "Blocking scope: global or request-scoped. Default: global."
enum:
- global
- request
min_triggers:
type: integer
description: "Number of High_Latency_Events required before creating a block. Default: 1."
minimum: 1
trigger_window_seconds:
type: integer
description: "Sliding time window in seconds for counting triggers. Required when min_triggers > 1."
minimum: 1
additionalProperties: false
required:
- threshold_ms
- strategy
- max_attempts
- block_duration_seconds
backoff:
type: object
description: "Exponential backoff configuration. When omitted, no backoff delays are applied."
properties:
apply_to:
type: string
description: "REQUIRED. Determines when backoff delays are applied."
enum:
- same_model
- same_provider
- global
base_ms:
type: integer
description: "Base delay in milliseconds for exponential backoff. Default: 100."
minimum: 1
max_ms:
type: integer
description: "Maximum delay in milliseconds for exponential backoff. Default: 5000."
minimum: 1
jitter:
type: boolean
description: "Add random jitter to prevent thundering herd. Default: true."
additionalProperties: false
required:
- apply_to
retry_after_handling:
type: object
description: "Retry-After header handling customization. When omitted, Retry-After is honored with defaults (scope: model, apply_to: global, max_retry_after_seconds: 300)."
properties:
scope:
type: string
description: "What to block: model-level or provider-level. Default: model."
enum:
- model
- provider
apply_to:
type: string
description: "Blocking scope: request-scoped or global. Default: global."
enum:
- request
- global
max_retry_after_seconds:
type: integer
description: "Maximum Retry-After value honored in seconds. Default: 300."
minimum: 1
additionalProperties: false
max_retry_duration_ms:
type: integer
description: "Maximum total time in milliseconds for all retry attempts combined. Timer starts on first retry."
minimum: 0
additionalProperties: false
additionalProperties: false
required:
- model
@ -273,6 +450,183 @@ properties:
required:
- name
- description
retry_policy:
type: object
description: "Retry policy configuration. When not specified, no retry logic is enabled."
properties:
fallback_models:
type: array
description: "Ordered list of model identifiers to fallback to before using Provider_List."
items:
type: string
default_strategy:
type: string
description: "Default retry strategy for unconfigured status codes. Default: different_provider."
enum:
- same_model
- same_provider
- different_provider
default_max_attempts:
type: integer
description: "Default max retry attempts for unconfigured status codes. Default: 2."
minimum: 0
on_status_codes:
type: array
description: "Per-status-code retry configuration."
items:
type: object
properties:
codes:
type: array
description: "List of status codes as integers or range strings (e.g. '502-504')."
items:
anyOf:
- type: integer
minimum: 100
maximum: 599
- type: string
description: "Range string in 'start-end' format (e.g. '502-504')."
strategy:
type: string
description: "Retry strategy for these status codes."
enum:
- same_model
- same_provider
- different_provider
max_attempts:
type: integer
description: "Max retry attempts for these status codes."
minimum: 0
additionalProperties: false
required:
- codes
- strategy
- max_attempts
on_timeout:
type: object
description: "Timeout-specific retry configuration. When omitted, timeouts use default_strategy and default_max_attempts."
properties:
strategy:
type: string
description: "Retry strategy for timeout errors."
enum:
- same_model
- same_provider
- different_provider
max_attempts:
type: integer
description: "Max retry attempts for timeout errors."
minimum: 1
additionalProperties: false
required:
- strategy
- max_attempts
on_high_latency:
type: object
description: "High latency proactive failover configuration. When omitted, no latency-based failover is performed."
properties:
threshold_ms:
type: integer
description: "Latency threshold in milliseconds. When response time exceeds this value, a High_Latency_Event is triggered."
minimum: 1
measure:
type: string
description: "What latency metric to measure. Default: ttfb."
enum:
- ttfb
- total
strategy:
type: string
description: "Retry strategy when latency threshold is exceeded."
enum:
- same_model
- same_provider
- different_provider
max_attempts:
type: integer
description: "Max retry attempts when latency threshold is exceeded."
minimum: 1
block_duration_seconds:
type: integer
description: "How long to block the model/provider after detecting high latency, in seconds. Default: 300."
minimum: 1
scope:
type: string
description: "What to block: model-level or provider-level. Default: model."
enum:
- model
- provider
apply_to:
type: string
description: "Blocking scope: global or request-scoped. Default: global."
enum:
- global
- request
min_triggers:
type: integer
description: "Number of High_Latency_Events required before creating a block. Default: 1."
minimum: 1
trigger_window_seconds:
type: integer
description: "Sliding time window in seconds for counting triggers. Required when min_triggers > 1."
minimum: 1
additionalProperties: false
required:
- threshold_ms
- strategy
- max_attempts
- block_duration_seconds
backoff:
type: object
description: "Exponential backoff configuration. When omitted, no backoff delays are applied."
properties:
apply_to:
type: string
description: "REQUIRED. Determines when backoff delays are applied."
enum:
- same_model
- same_provider
- global
base_ms:
type: integer
description: "Base delay in milliseconds for exponential backoff. Default: 100."
minimum: 1
max_ms:
type: integer
description: "Maximum delay in milliseconds for exponential backoff. Default: 5000."
minimum: 1
jitter:
type: boolean
description: "Add random jitter to prevent thundering herd. Default: true."
additionalProperties: false
required:
- apply_to
retry_after_handling:
type: object
description: "Retry-After header handling customization. When omitted, Retry-After is honored with defaults (scope: model, apply_to: global, max_retry_after_seconds: 300)."
properties:
scope:
type: string
description: "What to block: model-level or provider-level. Default: model."
enum:
- model
- provider
apply_to:
type: string
description: "Blocking scope: request-scoped or global. Default: global."
enum:
- request
- global
max_retry_after_seconds:
type: integer
description: "Maximum Retry-After value honored in seconds. Default: 300."
minimum: 1
additionalProperties: false
max_retry_duration_ms:
type: integer
description: "Maximum total time in milliseconds for all retry attempts combined. Timer starts on first retry."
minimum: 0
additionalProperties: false
additionalProperties: false
required:
- model

97
crates/Cargo.lock generated
View file

@ -293,7 +293,16 @@ version = "0.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0700ddab506f33b20a03b13996eccd309a48e5ff77d0d95926aa0210fb4e95f1"
dependencies = [
"bit-vec",
"bit-vec 0.6.3",
]
[[package]]
name = "bit-set"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "08807e080ed7f9d5433fa9b275196cfc35414f66a0c79d864dc51a0d825231a3"
dependencies = [
"bit-vec 0.8.0",
]
[[package]]
@ -302,6 +311,12 @@ version = "0.6.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "349f9b6a179ed607305526ca489b34ad0a41aed5f7980fa90eb03160b69598fb"
[[package]]
name = "bit-vec"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5e764a1d40d510daf35e07be9eb06e75770908c27d411ee6c92109c9840eaaf7"
[[package]]
name = "bitflags"
version = "2.11.0"
@ -519,6 +534,7 @@ version = "0.1.0"
dependencies = [
"axum",
"bytes",
"dashmap",
"derivative",
"duration-string",
"governor",
@ -528,6 +544,7 @@ dependencies = [
"hyper 1.9.0",
"log",
"pretty_assertions",
"proptest",
"proxy-wasm",
"rand 0.8.5",
"serde",
@ -535,6 +552,7 @@ dependencies = [
"serde_with",
"serde_yaml",
"serial_test",
"sha2 0.10.9",
"thiserror 1.0.69",
"tiktoken-rs",
"tokio",
@ -742,6 +760,20 @@ dependencies = [
"syn 2.0.117",
]
[[package]]
name = "dashmap"
version = "6.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5041cc499144891f3790297212f32a74fb938e5136a14943f338ef9e0ae276cf"
dependencies = [
"cfg-if",
"crossbeam-utils",
"hashbrown 0.14.5",
"lock_api",
"once_cell",
"parking_lot_core",
]
[[package]]
name = "deranged"
version = "0.5.8"
@ -928,7 +960,7 @@ version = "0.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7493d4c459da9f84325ad297371a6b2b8a162800873a22e3b6b6512e61d18c05"
dependencies = [
"bit-set",
"bit-set 0.5.3",
"regex",
]
@ -2527,6 +2559,25 @@ dependencies = [
"thiserror 1.0.69",
]
[[package]]
name = "proptest"
version = "1.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4b45fcc2344c680f5025fe57779faef368840d0bd1f42f216291f0dc4ace4744"
dependencies = [
"bit-set 0.8.0",
"bit-vec 0.8.0",
"bitflags",
"num-traits",
"rand 0.9.4",
"rand_chacha 0.9.0",
"rand_xorshift",
"regex-syntax",
"rusty-fork",
"tempfile",
"unarray",
]
[[package]]
name = "prost"
version = "0.14.3"
@ -2575,6 +2626,12 @@ dependencies = [
"winapi",
]
[[package]]
name = "quick-error"
version = "1.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0"
[[package]]
name = "quinn"
version = "0.11.9"
@ -2727,6 +2784,15 @@ version = "0.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "63b8176103e19a2643978565ca18b50549f6101881c443590420e4dc998a3c69"
[[package]]
name = "rand_xorshift"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "513962919efc330f829edb2535844d1b912b0fbe2ca165d613e4e8788bb05a5a"
dependencies = [
"rand_core 0.9.5",
]
[[package]]
name = "raw-cpuid"
version = "11.6.0"
@ -3056,6 +3122,18 @@ version = "1.0.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d"
[[package]]
name = "rusty-fork"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cc6bf79ff24e648f6da1f8d1f011e9cac26491b619e6b9280f2b47f1774e6ee2"
dependencies = [
"fnv",
"quick-error",
"tempfile",
"wait-timeout",
]
[[package]]
name = "ryu"
version = "1.0.23"
@ -3984,6 +4062,12 @@ version = "1.19.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "562d481066bde0658276a35467c4af00bdc6ee726305698a55b86e61d7ad82bb"
[[package]]
name = "unarray"
version = "0.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eaea85b334db583fe3274d12b4cd1880032beab409c0d774be044d4480ab9a94"
[[package]]
name = "unicase"
version = "2.9.0"
@ -4133,6 +4217,15 @@ version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5c3082ca00d5a5ef149bb8b555a72ae84c9c59f7250f013ac822ac2e49b19c64"
[[package]]
name = "wait-timeout"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "09ac3b126d3914f9849036f826e054cbabdc8519970b8998ddaf3b5bd3c65f11"
dependencies = [
"libc",
]
[[package]]
name = "want"
version = "0.3.1"

View file

@ -1,19 +1,24 @@
use bytes::Bytes;
use common::configuration::{FilterPipeline, ModelAlias};
use common::consts::{ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER, MODEL_AFFINITY_HEADER};
use common::configuration::{FilterPipeline, LlmProvider, ModelAlias};
use common::consts::{
ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER, ARCH_ROUTING_HEADER, MODEL_AFFINITY_HEADER,
};
use common::llm_providers::LlmProviders;
use common::retry::error_response::build_error_response;
use common::retry::orchestrator::RetryOrchestrator;
use common::retry::{rebuild_request_for_provider, RequestContext, RequestSignature};
use hermesllm::apis::openai::Message;
use hermesllm::apis::openai_responses::InputParam;
use hermesllm::clients::{SupportedAPIsFromClient, SupportedUpstreamAPIs};
use hermesllm::{ProviderRequest, ProviderRequestType};
use http_body_util::combinators::BoxBody;
use http_body_util::BodyExt;
use http_body_util::{BodyExt, Full};
use hyper::header::{self};
use hyper::{Request, Response, StatusCode};
use opentelemetry::global;
use opentelemetry::trace::get_active_span;
use opentelemetry_http::HeaderInjector;
use std::collections::HashMap;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::{debug, info, info_span, warn, Instrument};
@ -282,17 +287,10 @@ async fn llm_chat_inner(
Err(response) => return Ok(response),
};
// Serialize request for upstream BEFORE router consumes it
let client_request_bytes_for_upstream: Bytes =
match ProviderRequestType::to_bytes(&client_request) {
Ok(bytes) => bytes.into(),
Err(err) => {
warn!(error = %err, "failed to serialize request for upstream");
let mut r = Response::new(full(format!("Failed to serialize request: {}", err)));
*r.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
return Ok(r);
}
};
// Use the original bytes (from extract_routing_policy) for upstream to preserve
// JSON key order, whitespace, and unknown fields — critical for prompt cache hits.
// Only fall back to re-serialization if input filters modified the request.
let client_request_bytes_for_upstream: Bytes = chat_request_bytes.clone();
// --- Phase 3: Route the request (or use pinned model from session cache) ---
let resolved_model = if let Some(cached_model) = pinned_model {
@ -367,24 +365,57 @@ async fn llm_chat_inner(
tracing::Span::current().record(tracing_llm::MODEL_NAME, resolved_model.as_str());
// --- Phase 4: Forward to upstream and stream back ---
send_upstream(
&state.http_client,
&full_qualified_llm_provider_url,
&mut request_headers,
client_request_bytes_for_upstream,
&model_from_request,
&alias_resolved_model,
&resolved_model,
&model_name_only,
&request_path,
is_streaming_request,
messages_for_signals,
state_ctx,
state.state_storage.clone(),
request_id,
&state.filter_pipeline,
)
.await
// Check if the resolved provider has a retry_policy configured.
// If so, use the RetryOrchestrator to wrap the upstream call with retry logic.
let resolved_provider: Option<Arc<LlmProvider>> =
state.llm_providers.read().await.get(&resolved_model);
let has_retry_policy = resolved_provider
.as_ref()
.and_then(|p| p.retry_policy.as_ref())
.is_some();
if has_retry_policy {
send_upstream_with_retry(
&state.http_client,
&full_qualified_llm_provider_url,
&mut request_headers,
client_request_bytes_for_upstream,
&model_from_request,
&alias_resolved_model,
&resolved_model,
&model_name_only,
&request_path,
is_streaming_request,
messages_for_signals,
state_ctx,
state.state_storage.clone(),
request_id,
&state.filter_pipeline,
&resolved_provider.unwrap(),
&state.llm_providers,
)
.await
} else {
send_upstream(
&state.http_client,
&full_qualified_llm_provider_url,
&mut request_headers,
client_request_bytes_for_upstream,
&model_from_request,
&alias_resolved_model,
&resolved_model,
&model_name_only,
&request_path,
is_streaming_request,
messages_for_signals,
state_ctx,
state.state_storage.clone(),
request_id,
&state.filter_pipeline,
)
.await
}
}
// ---------------------------------------------------------------------------
@ -845,6 +876,312 @@ async fn send_upstream(
}
}
/// Retry-aware version of send_upstream. Uses the RetryOrchestrator to wrap
/// the upstream HTTP call with automatic retry and provider failover logic.
#[allow(clippy::too_many_arguments)]
async fn send_upstream_with_retry(
http_client: &reqwest::Client,
upstream_url: &str,
request_headers: &mut hyper::HeaderMap,
body: bytes::Bytes,
model_from_request: &str,
alias_resolved_model: &str,
resolved_model: &str,
_model_name_only: &str,
request_path: &str,
is_streaming_request: bool,
messages_for_signals: Option<Vec<Message>>,
state_ctx: ConversationStateContext,
state_storage: Option<Arc<dyn StateStorage>>,
request_id: String,
filter_pipeline: &Arc<FilterPipeline>,
primary_provider: &Arc<LlmProvider>,
llm_providers: &Arc<RwLock<LlmProviders>>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
let retry_policy = primary_provider.retry_policy.as_ref().unwrap();
// Collect all providers for the retry orchestrator
let all_providers: Vec<LlmProvider> = llm_providers
.read()
.await
.iter()
.map(|(_, p)| (*p).as_ref().clone())
.collect();
// Build request signature
let request_signature = RequestSignature::new(
&body,
request_headers,
is_streaming_request,
alias_resolved_model.to_string(),
);
// Build request context
let mut request_context = RequestContext {
request_id: request_id.clone(),
attempted_providers: HashSet::new(),
retry_start_time: None,
attempt_number: 0,
request_retry_after_state: HashMap::new(),
request_latency_block_state: HashMap::new(),
request_signature: request_signature.clone(),
errors: vec![],
};
let orchestrator = RetryOrchestrator::new_default();
debug!(
model = %alias_resolved_model,
fallback_models = ?retry_policy.fallback_models,
default_strategy = ?retry_policy.default_strategy,
default_max_attempts = retry_policy.default_max_attempts,
"Retry orchestrator initialized for request"
);
// Capture references for the forward_fn closure
let base_url = upstream_url.to_string();
let original_headers = request_headers.clone();
let primary_model = alias_resolved_model.to_string();
let http_client = http_client.clone();
// The forward_fn handles the actual HTTP call to upstream for each attempt.
let forward_fn = |body: &Bytes, target_provider: &LlmProvider| {
let body = body.clone();
let target_provider = target_provider.clone();
let base_url = base_url.clone();
let original_headers = original_headers.clone();
let primary_model = primary_model.clone();
let http_client = http_client.clone();
async move {
let target_model = target_provider
.model
.as_deref()
.unwrap_or(&target_provider.name);
let (request_body, mut headers) = if target_model == primary_model {
(body.clone(), original_headers.clone())
} else {
match rebuild_request_for_provider(&body, &target_provider, &original_headers) {
Ok((new_body, new_headers)) => (new_body, new_headers),
Err(e) => {
warn!(error = %e, "Failed to rebuild request for provider");
return Err(common::retry::error_detector::TimeoutError { duration_ms: 0 });
}
}
};
// Resolve the upstream URL for the target provider.
// Always route through the Envoy proxy (base_url) and let the
// provider-hint header select the upstream cluster. Building a
// direct URL from the provider endpoint is wrong because the
// endpoint field stores a bare hostname (no scheme), and
// bypassing Envoy loses TLS, load-balancing, and observability.
let upstream_url = base_url.clone();
// Set provider hint header so the WASM gateway selects the
// correct provider (and its credentials) for this retry attempt.
// Do NOT set x-arch-llm-provider here — the WASM gateway sets it
// via add_http_request_header after provider selection. If we set
// it too, Envoy sees a duplicate multi-value header that fails
// exact-match routing and falls through to the 400 catch-all.
headers.remove(header::HeaderName::from_static(ARCH_ROUTING_HEADER));
headers.insert(
ARCH_PROVIDER_HINT_HEADER,
header::HeaderValue::from_str(target_model)
.unwrap_or_else(|_| header::HeaderValue::from_static("unknown")),
);
headers.remove(header::CONTENT_LENGTH);
// Send the request
let result = http_client
.post(&upstream_url)
.headers(headers)
.body(request_body.to_vec())
.send()
.await;
match result {
Ok(res) => {
let status = res.status().as_u16();
let resp_headers = res.headers().clone();
let body_bytes = res.bytes().await.unwrap_or_default();
// Debug: log upstream response for retry attempts
if status >= 400 {
let body_preview = String::from_utf8_lossy(&body_bytes);
warn!(
"Retry upstream response: status={}, model={}, body={}",
status,
target_model,
&body_preview[..body_preview.len().min(500)]
);
}
let full_body = Full::new(body_bytes)
.map_err(|never| match never {})
.boxed();
let mut builder = Response::builder().status(status);
if let Some(hdrs) = builder.headers_mut() {
for (name, value) in resp_headers.iter() {
if let Ok(hyper_name) =
hyper::header::HeaderName::from_bytes(name.as_str().as_bytes())
{
if let Ok(hyper_value) =
hyper::header::HeaderValue::from_bytes(value.as_bytes())
{
hdrs.insert(hyper_name, hyper_value);
}
}
}
}
Ok(builder.body(full_body).unwrap())
}
Err(err) => {
warn!(error = %err, "Upstream request failed during retry");
Err(common::retry::error_detector::TimeoutError { duration_ms: 0 })
}
}
}
};
// Execute the retry orchestrator
let retry_result = orchestrator
.execute(
&body,
&request_signature,
primary_provider.as_ref(),
retry_policy,
&all_providers,
&mut request_context,
forward_fn,
)
.await;
match retry_result {
Ok(http_response) => {
// Success (possibly after retries) — stream the response back to client
let upstream_status = http_response.status();
let response_headers = http_response.headers().clone();
let span_name = if model_from_request == resolved_model {
format!("POST {} {}", request_path, resolved_model)
} else {
format!(
"POST {} {} -> {}",
request_path, model_from_request, resolved_model
)
};
let mut response = Response::builder().status(upstream_status);
let headers = response.headers_mut().unwrap();
for (header_name, header_value) in response_headers.iter() {
headers.insert(header_name, header_value.clone());
}
// Collect the body from the HttpResponse
let body_bytes = http_response
.into_body()
.collect()
.await
.map(|collected| collected.to_bytes())
.unwrap_or_default();
let byte_stream = futures::stream::iter(vec![Ok::<Bytes, reqwest::Error>(body_bytes)]);
let (metric_provider_raw, metric_model_raw) =
bs_metrics::split_provider_model(resolved_model);
let base_processor = ObservableStreamProcessor::new(
operation_component::LLM,
span_name,
std::time::Instant::now(),
messages_for_signals,
)
.with_llm_metrics(LlmMetricsCtx {
provider: metric_provider_raw.to_string(),
model: metric_model_raw.to_string(),
upstream_status: upstream_status.as_u16(),
});
let output_filter_request_headers = if filter_pipeline.has_output_filters() {
Some(request_headers.clone())
} else {
None
};
let processor: Box<dyn StreamProcessor> = if let (true, false, Some(state_store)) = (
state_ctx.should_manage_state,
state_ctx.original_input_items.is_empty(),
&state_storage,
) {
let content_encoding = response_headers
.get("content-encoding")
.and_then(|v| v.to_str().ok())
.map(|s| s.to_string());
Box::new(ResponsesStateProcessor::new(
base_processor,
state_store.clone(),
state_ctx.original_input_items,
alias_resolved_model.to_string(),
resolved_model.to_string(),
is_streaming_request,
false,
content_encoding,
request_id.clone(),
))
} else {
Box::new(base_processor)
};
let streaming_response = if let (Some(output_chain), Some(filter_headers)) = (
filter_pipeline.output.as_ref().filter(|c| !c.is_empty()),
output_filter_request_headers,
) {
create_streaming_response_with_output_filter(
byte_stream,
processor,
output_chain.clone(),
filter_headers,
request_path.to_string(),
)
} else {
create_streaming_response(byte_stream, processor)
};
match response.body(streaming_response.body) {
Ok(response) => Ok(response),
Err(err) => {
let err_msg = format!("Failed to create response: {}", err);
let mut internal_error = Response::new(full(err_msg));
*internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
Ok(internal_error)
}
}
}
Err(retry_exhausted_error) => {
// All retries exhausted — build structured error response
info!(
request_id = %request_id,
total_attempts = retry_exhausted_error.attempts.len(),
budget_exhausted = retry_exhausted_error.retry_budget_exhausted,
"All retries exhausted"
);
let error_resp = build_error_response(&retry_exhausted_error, &request_id);
// Convert Full<Bytes> body to BoxBody<Bytes, hyper::Error>
let (parts, full_body) = error_resp.into_parts();
let boxed_body = full_body.map_err(|never| match never {}).boxed();
Ok(Response::from_parts(parts, boxed_body))
}
}
}
// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------

View file

@ -45,7 +45,14 @@ pub fn extract_routing_policy(
},
);
let bytes = Bytes::from(serde_json::to_vec(&json_body).unwrap());
// Only re-serialize if we actually removed routing_preferences.
// Otherwise preserve the original bytes to maintain JSON key order,
// whitespace, and unknown fields — critical for prompt cache hits.
let bytes = if routing_preferences.is_some() {
Bytes::from(serde_json::to_vec(&json_body).unwrap())
} else {
Bytes::from(raw_bytes.to_vec())
};
Ok((bytes, routing_preferences))
}

View file

@ -20,6 +20,9 @@ urlencoding = "2.1.3"
url = "2.5.4"
hermesllm = { version = "0.1.0", path = "../hermesllm" }
serde_with = "3.13.0"
sha2 = "0.10"
dashmap = "6"
tokio = { version = "1.44", features = ["sync", "time"] }
hyper = "1.0"
bytes = "1.0"
http-body-util = "0.1"
@ -36,3 +39,4 @@ tokio = { version = "1.44", features = ["sync", "time", "macros", "rt"] }
hyper = { version = "1.0", features = ["full"] }
bytes = "1.0"
http-body-util = "0.1"
proptest = "1.4"

View file

@ -0,0 +1,7 @@
# Seeds for failure cases proptest has generated in the past. It is
# automatically read and these particular cases re-run before any
# novel cases are generated.
#
# It is recommended to check this file in to source control so that
# everyone who runs the test benefits from these saved cases.
cc e6443c9611ecf84b57514e7d12084d62e6558989f663f1106d3cedd746a20bf3 # shrinks to include_on_status_codes = false, include_backoff = true, include_retry_after = false, include_on_timeout = false, include_on_high_latency = false

File diff suppressed because it is too large Load diff

View file

@ -7,6 +7,7 @@ pub mod llm_providers;
pub mod path;
pub mod pii;
pub mod ratelimit;
pub mod retry;
pub mod routing;
pub mod stats;
pub mod tokenizer;

View file

@ -278,6 +278,7 @@ mod tests {
stream: None,
passthrough_auth: None,
headers: None,
retry_policy: None,
}
}

View file

@ -0,0 +1,510 @@
use std::time::Duration;
use rand::Rng;
use crate::configuration::{extract_provider, BackoffApplyTo, BackoffConfig, RetryStrategy};
/// Calculator for exponential backoff delays with jitter and scope filtering.
pub struct BackoffCalculator;
impl BackoffCalculator {
/// Calculate the delay before the next retry attempt.
///
/// Returns the greater of the computed backoff delay and the Retry-After delay.
/// Returns zero when the backoff `apply_to` scope doesn't match the
/// current/previous provider relationship (unless retry_after_seconds is set).
pub fn calculate_delay(
&self,
attempt_number: u32,
backoff_config: Option<&BackoffConfig>,
retry_after_seconds: Option<u64>,
current_strategy: RetryStrategy,
current_provider: &str,
previous_provider: &str,
) -> Duration {
let backoff_delay = match backoff_config {
Some(config) => {
if !Self::scope_matches(
config.apply_to,
current_strategy,
current_provider,
previous_provider,
) {
Duration::ZERO
} else {
Self::compute_backoff(attempt_number, config)
}
}
None => Duration::ZERO,
};
let retry_after_delay = retry_after_seconds
.map(|s| Duration::from_secs(s))
.unwrap_or(Duration::ZERO);
backoff_delay.max(retry_after_delay)
}
/// Check whether the backoff `apply_to` scope matches the current retry context.
fn scope_matches(
apply_to: BackoffApplyTo,
_current_strategy: RetryStrategy,
current_provider: &str,
previous_provider: &str,
) -> bool {
let current_prefix = extract_provider(current_provider);
let previous_prefix = extract_provider(previous_provider);
match apply_to {
BackoffApplyTo::SameModel => current_provider == previous_provider,
BackoffApplyTo::SameProvider => current_prefix == previous_prefix,
BackoffApplyTo::Global => true,
}
}
/// Compute exponential backoff: min(base_ms * 2^attempt, max_ms), with optional jitter.
fn compute_backoff(attempt_number: u32, config: &BackoffConfig) -> Duration {
let exp_delay = if attempt_number >= 64 {
config.max_ms
} else {
config.base_ms.saturating_mul(1u64 << attempt_number)
};
let capped = exp_delay.min(config.max_ms);
let final_ms = if config.jitter {
let mut rng = rand::thread_rng();
let jitter_factor: f64 = 0.5 + rng.gen::<f64>() * 0.5;
((capped as f64) * jitter_factor) as u64
} else {
capped
};
Duration::from_millis(final_ms)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::configuration::{BackoffApplyTo, BackoffConfig, RetryStrategy};
use proptest::prelude::*;
fn make_config(
apply_to: BackoffApplyTo,
base_ms: u64,
max_ms: u64,
jitter: bool,
) -> BackoffConfig {
BackoffConfig {
apply_to,
base_ms,
max_ms,
jitter,
}
}
#[test]
fn no_backoff_config_returns_zero() {
let calc = BackoffCalculator;
let d = calc.calculate_delay(
0,
None,
None,
RetryStrategy::SameModel,
"openai/gpt-4o",
"openai/gpt-4o",
);
assert_eq!(d, Duration::ZERO);
}
#[test]
fn no_backoff_config_with_retry_after() {
let calc = BackoffCalculator;
let d = calc.calculate_delay(
0,
None,
Some(5),
RetryStrategy::SameModel,
"openai/gpt-4o",
"openai/gpt-4o",
);
assert_eq!(d, Duration::from_secs(5));
}
#[test]
fn exponential_backoff_no_jitter() {
let calc = BackoffCalculator;
let config = make_config(BackoffApplyTo::Global, 100, 5000, false);
// attempt 0: min(100 * 2^0, 5000) = 100
assert_eq!(
calc.calculate_delay(0, Some(&config), None, RetryStrategy::SameModel, "a", "a"),
Duration::from_millis(100)
);
// attempt 1: min(100 * 2^1, 5000) = 200
assert_eq!(
calc.calculate_delay(1, Some(&config), None, RetryStrategy::SameModel, "a", "a"),
Duration::from_millis(200)
);
// attempt 2: min(100 * 2^2, 5000) = 400
assert_eq!(
calc.calculate_delay(2, Some(&config), None, RetryStrategy::SameModel, "a", "a"),
Duration::from_millis(400)
);
// attempt 6: min(100 * 64, 5000) = 5000 (capped)
assert_eq!(
calc.calculate_delay(6, Some(&config), None, RetryStrategy::SameModel, "a", "a"),
Duration::from_millis(5000)
);
}
#[test]
fn jitter_stays_within_bounds() {
let calc = BackoffCalculator;
let config = make_config(BackoffApplyTo::Global, 1000, 50000, true);
for attempt in 0..5 {
for _ in 0..20 {
let d = calc.calculate_delay(
attempt,
Some(&config),
None,
RetryStrategy::SameModel,
"a",
"a",
);
let base = (1000u64.saturating_mul(1u64 << attempt)).min(50000);
// jitter: delay * (0.5 + random(0, 0.5)) => [0.5*base, 1.0*base]
assert!(
d.as_millis() >= (base as f64 * 0.5) as u128,
"delay {} too low for base {}",
d.as_millis(),
base
);
assert!(
d.as_millis() <= base as u128,
"delay {} too high for base {}",
d.as_millis(),
base
);
}
}
}
#[test]
fn scope_same_model_filters_different_providers() {
let calc = BackoffCalculator;
let config = make_config(BackoffApplyTo::SameModel, 100, 5000, false);
// Same model -> backoff applies
let d = calc.calculate_delay(
0,
Some(&config),
None,
RetryStrategy::SameModel,
"openai/gpt-4o",
"openai/gpt-4o",
);
assert_eq!(d, Duration::from_millis(100));
// Different model, same provider -> no backoff
let d = calc.calculate_delay(
0,
Some(&config),
None,
RetryStrategy::SameProvider,
"openai/gpt-4o-mini",
"openai/gpt-4o",
);
assert_eq!(d, Duration::ZERO);
// Different provider -> no backoff
let d = calc.calculate_delay(
0,
Some(&config),
None,
RetryStrategy::DifferentProvider,
"anthropic/claude",
"openai/gpt-4o",
);
assert_eq!(d, Duration::ZERO);
}
#[test]
fn scope_same_provider_filters_different_providers() {
let calc = BackoffCalculator;
let config = make_config(BackoffApplyTo::SameProvider, 100, 5000, false);
// Same provider -> backoff applies
let d = calc.calculate_delay(
0,
Some(&config),
None,
RetryStrategy::SameProvider,
"openai/gpt-4o-mini",
"openai/gpt-4o",
);
assert_eq!(d, Duration::from_millis(100));
// Same model (same provider) -> backoff applies
let d = calc.calculate_delay(
0,
Some(&config),
None,
RetryStrategy::SameModel,
"openai/gpt-4o",
"openai/gpt-4o",
);
assert_eq!(d, Duration::from_millis(100));
// Different provider -> no backoff
let d = calc.calculate_delay(
0,
Some(&config),
None,
RetryStrategy::DifferentProvider,
"anthropic/claude",
"openai/gpt-4o",
);
assert_eq!(d, Duration::ZERO);
}
#[test]
fn scope_global_always_applies() {
let calc = BackoffCalculator;
let config = make_config(BackoffApplyTo::Global, 100, 5000, false);
let d = calc.calculate_delay(
0,
Some(&config),
None,
RetryStrategy::DifferentProvider,
"anthropic/claude",
"openai/gpt-4o",
);
assert_eq!(d, Duration::from_millis(100));
}
#[test]
fn retry_after_wins_when_greater() {
let calc = BackoffCalculator;
let config = make_config(BackoffApplyTo::Global, 100, 5000, false);
// retry_after = 10s >> backoff attempt 0 = 100ms
let d = calc.calculate_delay(
0,
Some(&config),
Some(10),
RetryStrategy::SameModel,
"a",
"a",
);
assert_eq!(d, Duration::from_secs(10));
}
#[test]
fn backoff_wins_when_greater() {
let calc = BackoffCalculator;
// base_ms=10000, attempt 0 -> 10000ms = 10s
let config = make_config(BackoffApplyTo::Global, 10000, 50000, false);
// retry_after = 5s < backoff = 10s
let d = calc.calculate_delay(
0,
Some(&config),
Some(5),
RetryStrategy::SameModel,
"a",
"a",
);
assert_eq!(d, Duration::from_millis(10000));
}
#[test]
fn scope_mismatch_still_honors_retry_after() {
let calc = BackoffCalculator;
let config = make_config(BackoffApplyTo::SameModel, 100, 5000, false);
// Scope doesn't match (different providers) but retry_after is set
let d = calc.calculate_delay(
0,
Some(&config),
Some(3),
RetryStrategy::DifferentProvider,
"anthropic/claude",
"openai/gpt-4o",
);
assert_eq!(d, Duration::from_secs(3));
}
#[test]
fn large_attempt_number_saturates() {
let calc = BackoffCalculator;
let config = make_config(BackoffApplyTo::Global, 100, 5000, false);
// Very large attempt number should saturate and cap at max_ms
let d = calc.calculate_delay(63, Some(&config), None, RetryStrategy::SameModel, "a", "a");
assert_eq!(d, Duration::from_millis(5000));
}
// --- Proptest strategies ---
fn arb_provider() -> impl Strategy<Value = String> {
prop_oneof![
Just("openai/gpt-4o".to_string()),
Just("openai/gpt-4o-mini".to_string()),
Just("anthropic/claude-3".to_string()),
Just("azure/gpt-4o".to_string()),
Just("google/gemini-pro".to_string()),
]
}
// Feature: retry-on-ratelimit, Property 12: Exponential Backoff Formula and Bounds
// **Validates: Requirements 4.6, 4.7, 4.8, 4.9, 4.10, 4.11**
proptest! {
#![proptest_config(ProptestConfig::with_cases(100))]
/// Property 12 Case 1: No-jitter delay equals min(base_ms * 2^attempt, max_ms) exactly.
#[test]
fn prop_backoff_no_jitter_exact(
attempt in 0u32..20,
base_ms in 1u64..10000,
extra in 1u64..40001u64,
) {
let max_ms = base_ms + extra;
let config = make_config(BackoffApplyTo::Global, base_ms, max_ms, false);
let calc = BackoffCalculator;
let d = calc.calculate_delay(attempt, Some(&config), None, RetryStrategy::SameModel, "a", "a");
let expected = if attempt >= 64 {
max_ms
} else {
base_ms.saturating_mul(1u64 << attempt).min(max_ms)
};
prop_assert_eq!(d, Duration::from_millis(expected));
}
/// Property 12 Case 2: Jitter delay is in [0.5 * computed_base, computed_base].
#[test]
fn prop_backoff_jitter_bounds(
attempt in 0u32..20,
base_ms in 1u64..10000,
extra in 1u64..40001u64,
) {
let max_ms = base_ms + extra;
let config = make_config(BackoffApplyTo::Global, base_ms, max_ms, true);
let calc = BackoffCalculator;
let d = calc.calculate_delay(attempt, Some(&config), None, RetryStrategy::SameModel, "a", "a");
let computed_base = if attempt >= 64 {
max_ms
} else {
base_ms.saturating_mul(1u64 << attempt).min(max_ms)
};
let lower = (computed_base as f64 * 0.5) as u64;
let upper = computed_base;
prop_assert!(
d.as_millis() >= lower as u128 && d.as_millis() <= upper as u128,
"delay {}ms not in [{}, {}] for attempt={}, base_ms={}, max_ms={}",
d.as_millis(), lower, upper, attempt, base_ms, max_ms
);
}
/// Property 12 Case 3: Delay is always <= max_ms.
#[test]
fn prop_backoff_delay_capped_at_max(
attempt in 0u32..20,
base_ms in 1u64..10000,
extra in 1u64..40001u64,
jitter in proptest::bool::ANY,
) {
let max_ms = base_ms + extra;
let config = make_config(BackoffApplyTo::Global, base_ms, max_ms, jitter);
let calc = BackoffCalculator;
let d = calc.calculate_delay(attempt, Some(&config), None, RetryStrategy::SameModel, "a", "a");
prop_assert!(
d.as_millis() <= max_ms as u128,
"delay {}ms exceeds max_ms {} for attempt={}, base_ms={}, jitter={}",
d.as_millis(), max_ms, attempt, base_ms, jitter
);
}
}
// Feature: retry-on-ratelimit, Property 13: Backoff Apply-To Scope Filtering
// **Validates: Requirements 4.3, 4.4, 4.5, 4.12, 4.13**
proptest! {
#![proptest_config(ProptestConfig::with_cases(100))]
/// Property 13 Case 1: SameModel apply_to with different providers → zero delay.
#[test]
fn prop_scope_same_model_different_providers_zero(
attempt in 0u32..20,
base_ms in 1u64..10000,
extra in 1u64..40001u64,
current in arb_provider(),
previous in arb_provider(),
) {
// Only test when providers are actually different models
prop_assume!(current != previous);
let max_ms = base_ms + extra;
let config = make_config(BackoffApplyTo::SameModel, base_ms, max_ms, false);
let calc = BackoffCalculator;
let d = calc.calculate_delay(
attempt, Some(&config), None,
RetryStrategy::DifferentProvider, &current, &previous,
);
prop_assert_eq!(d, Duration::ZERO,
"Expected zero delay for SameModel apply_to with different models: {} vs {}",
current, previous
);
}
/// Property 13 Case 2: SameProvider apply_to with different provider prefixes → zero delay.
#[test]
fn prop_scope_same_provider_different_prefix_zero(
attempt in 0u32..20,
base_ms in 1u64..10000,
extra in 1u64..40001u64,
current in arb_provider(),
previous in arb_provider(),
) {
let current_prefix = extract_provider(&current);
let previous_prefix = extract_provider(&previous);
prop_assume!(current_prefix != previous_prefix);
let max_ms = base_ms + extra;
let config = make_config(BackoffApplyTo::SameProvider, base_ms, max_ms, false);
let calc = BackoffCalculator;
let d = calc.calculate_delay(
attempt, Some(&config), None,
RetryStrategy::DifferentProvider, &current, &previous,
);
prop_assert_eq!(d, Duration::ZERO,
"Expected zero delay for SameProvider apply_to with different prefixes: {} vs {}",
current_prefix, previous_prefix
);
}
/// Property 13 Case 3: Global apply_to always produces non-zero delay.
#[test]
fn prop_scope_global_always_nonzero(
attempt in 0u32..20,
base_ms in 1u64..10000,
extra in 1u64..40001u64,
current in arb_provider(),
previous in arb_provider(),
) {
let max_ms = base_ms + extra;
let config = make_config(BackoffApplyTo::Global, base_ms, max_ms, false);
let calc = BackoffCalculator;
let d = calc.calculate_delay(
attempt, Some(&config), None,
RetryStrategy::DifferentProvider, &current, &previous,
);
prop_assert!(d > Duration::ZERO,
"Expected non-zero delay for Global apply_to: current={}, previous={}",
current, previous
);
}
}
}

View file

@ -0,0 +1,945 @@
use bytes::Bytes;
use http_body_util::combinators::BoxBody;
use hyper::Response;
use crate::configuration::{LatencyMeasure, RetryPolicy, RetryStrategy, StatusCodeEntry};
// ── Types ──────────────────────────────────────────────────────────────────
/// Represents a request timeout (used in P1).
#[derive(Debug)]
pub struct TimeoutError {
pub duration_ms: u64,
}
/// The HTTP response type used throughout the gateway.
pub type HttpResponse = Response<BoxBody<Bytes, hyper::Error>>;
/// Result of classifying an upstream response or error condition.
#[derive(Debug)]
pub enum ErrorClassification {
/// 2xx success — pass through to client.
Success(HttpResponse),
/// Retriable HTTP error (matched on_status_codes or default 4xx/5xx).
RetriableError {
status_code: u16,
retry_after_seconds: Option<u64>,
response_body: Vec<u8>,
},
/// Request timed out (P1 — variant defined now for forward compatibility).
TimeoutError { duration_ms: u64 },
/// Response latency exceeded threshold (P2 — variant defined for forward compat).
HighLatencyEvent {
measured_ms: u64,
threshold_ms: u64,
measure: LatencyMeasure,
response: Option<HttpResponse>,
},
/// Non-retriable error — return as-is to client.
NonRetriableError(HttpResponse),
}
// ── ErrorDetector ──────────────────────────────────────────────────────────
pub struct ErrorDetector;
impl ErrorDetector {
/// Classify an upstream response or error condition.
///
/// In P0, only handles the `Ok(response)` path for HTTP status codes.
/// The `Err(timeout)` path is added in P1.
///
/// Dual-classification for timeout + high latency:
/// When both `on_high_latency` and `on_timeout` are configured and a request
/// times out after exceeding `threshold_ms`, this returns `TimeoutError` (for
/// retry purposes) but the caller must ALSO record a `HighLatencyEvent` for
/// blocking purposes.
pub fn classify(
&self,
response: Result<HttpResponse, TimeoutError>,
retry_policy: &RetryPolicy,
elapsed_ttfb_ms: u64,
elapsed_total_ms: u64,
) -> ErrorClassification {
match response {
Ok(resp) => {
self.classify_http_response(resp, retry_policy, elapsed_ttfb_ms, elapsed_total_ms)
}
// Timeout takes priority for retry; caller handles dual-classification
// for blocking (records HighLatencyEvent separately if applicable).
Err(timeout) => ErrorClassification::TimeoutError {
duration_ms: timeout.duration_ms,
},
}
}
/// Determine retry strategy and max_attempts for a given classification.
///
/// - `RetriableError` with a matching `on_status_codes` entry → that entry's params
/// - `RetriableError` without a match (default 4xx/5xx) → (default_strategy, default_max_attempts)
/// - `TimeoutError` → `on_timeout` config or defaults
/// - `HighLatencyEvent` → `on_high_latency` config (strategy, max_attempts)
pub fn resolve_retry_params(
&self,
classification: &ErrorClassification,
retry_policy: &RetryPolicy,
) -> (RetryStrategy, u32) {
match classification {
ErrorClassification::RetriableError { status_code, .. } => {
// Try to find a matching on_status_codes entry
for entry in &retry_policy.on_status_codes {
if status_code_matches(*status_code, &entry.codes) {
return (entry.strategy, entry.max_attempts);
}
}
// No specific match — use defaults
(
retry_policy.default_strategy,
retry_policy.default_max_attempts,
)
}
ErrorClassification::TimeoutError { .. } => match &retry_policy.on_timeout {
Some(timeout_config) => (timeout_config.strategy, timeout_config.max_attempts),
None => (
retry_policy.default_strategy,
retry_policy.default_max_attempts,
),
},
ErrorClassification::HighLatencyEvent { .. } => {
match &retry_policy.on_high_latency {
Some(hl_config) => (hl_config.strategy, hl_config.max_attempts),
// Shouldn't happen (HighLatencyEvent only created when config exists),
// but fall back to defaults for safety.
None => (
retry_policy.default_strategy,
retry_policy.default_max_attempts,
),
}
}
// Success and NonRetriableError should not be passed here,
// but return defaults as a safe fallback.
_ => (
retry_policy.default_strategy,
retry_policy.default_max_attempts,
),
}
}
// ── Private helpers ────────────────────────────────────────────────────
fn classify_http_response(
&self,
response: HttpResponse,
retry_policy: &RetryPolicy,
elapsed_ttfb_ms: u64,
elapsed_total_ms: u64,
) -> ErrorClassification {
let status = response.status().as_u16();
// 2xx → check for high latency, otherwise Success
if (200..300).contains(&status) {
// If on_high_latency is configured, check if the response was slow
if let Some(hl_config) = &retry_policy.on_high_latency {
let measured_ms = match hl_config.measure {
LatencyMeasure::Ttfb => elapsed_ttfb_ms,
LatencyMeasure::Total => elapsed_total_ms,
};
if measured_ms > hl_config.threshold_ms {
return ErrorClassification::HighLatencyEvent {
measured_ms,
threshold_ms: hl_config.threshold_ms,
measure: hl_config.measure,
response: Some(response), // completed-but-slow: include the response
};
}
}
return ErrorClassification::Success(response);
}
// Check if this status code is retriable (4xx or 5xx)
let is_4xx = (400..500).contains(&status);
let is_5xx = (500..600).contains(&status);
if is_4xx || is_5xx {
// Check if it matches any on_status_codes entry, OR fall back to
// default handling for all 4xx/5xx when retry_policy exists.
let has_specific_match = retry_policy
.on_status_codes
.iter()
.any(|entry| status_code_matches(status, &entry.codes));
if has_specific_match || is_4xx || is_5xx {
// Extract Retry-After header (P1 will use this; capture it now)
let retry_after_seconds = extract_retry_after(&response);
// We need the response body for the error record.
// Since we can't easily consume the body from a BoxBody synchronously,
// store an empty body for now — the orchestrator will handle body capture.
return ErrorClassification::RetriableError {
status_code: status,
retry_after_seconds,
response_body: Vec::new(),
};
}
}
// Non-2xx, non-4xx, non-5xx (e.g. 3xx, 1xx) → NonRetriableError
ErrorClassification::NonRetriableError(response)
}
}
// ── Free functions ─────────────────────────────────────────────────────────
/// Check if a status code matches any entry in a codes list.
fn status_code_matches(status: u16, codes: &[StatusCodeEntry]) -> bool {
for entry in codes {
match entry.expand() {
Ok(expanded) => {
if expanded.contains(&status) {
return true;
}
}
Err(_) => continue, // Skip malformed ranges
}
}
false
}
/// Extract the Retry-After header value as seconds.
/// Parses integer seconds only; ignores malformed values.
fn extract_retry_after(response: &HttpResponse) -> Option<u64> {
response
.headers()
.get("retry-after")
.and_then(|v| v.to_str().ok())
.and_then(|s| s.trim().parse::<u64>().ok())
}
#[cfg(test)]
mod tests {
use super::*;
use crate::configuration::{StatusCodeConfig, TimeoutRetryConfig};
use bytes::Bytes;
use http_body_util::{BodyExt, Full};
/// Helper to build an HttpResponse with a given status code.
fn make_response(status: u16) -> HttpResponse {
make_response_with_headers(status, vec![])
}
/// Helper to build an HttpResponse with a given status code and headers.
fn make_response_with_headers(status: u16, headers: Vec<(&str, &str)>) -> HttpResponse {
let body = Full::new(Bytes::from("test body"))
.map_err(|_| unreachable!())
.boxed();
let mut builder = Response::builder().status(status);
for (name, value) in headers {
builder = builder.header(name, value);
}
builder.body(body).unwrap()
}
fn basic_retry_policy() -> RetryPolicy {
RetryPolicy {
fallback_models: vec![],
default_strategy: RetryStrategy::DifferentProvider,
default_max_attempts: 2,
on_status_codes: vec![
StatusCodeConfig {
codes: vec![StatusCodeEntry::Single(429)],
strategy: RetryStrategy::SameProvider,
max_attempts: 3,
},
StatusCodeConfig {
codes: vec![StatusCodeEntry::Single(503)],
strategy: RetryStrategy::DifferentProvider,
max_attempts: 4,
},
],
on_timeout: Some(TimeoutRetryConfig {
strategy: RetryStrategy::DifferentProvider,
max_attempts: 2,
}),
on_high_latency: None,
backoff: None,
retry_after_handling: None,
max_retry_duration_ms: None,
}
}
// ── classify tests ─────────────────────────────────────────────────
#[test]
fn classify_2xx_returns_success() {
let detector = ErrorDetector;
let policy = basic_retry_policy();
let resp = make_response(200);
let result = detector.classify(Ok(resp), &policy, 0, 0);
assert!(matches!(result, ErrorClassification::Success(_)));
}
#[test]
fn classify_201_returns_success() {
let detector = ErrorDetector;
let policy = basic_retry_policy();
let resp = make_response(201);
let result = detector.classify(Ok(resp), &policy, 0, 0);
assert!(matches!(result, ErrorClassification::Success(_)));
}
#[test]
fn classify_429_returns_retriable_error() {
let detector = ErrorDetector;
let policy = basic_retry_policy();
let resp = make_response(429);
let result = detector.classify(Ok(resp), &policy, 0, 0);
match result {
ErrorClassification::RetriableError { status_code, .. } => {
assert_eq!(status_code, 429);
}
other => panic!("Expected RetriableError, got {:?}", other),
}
}
#[test]
fn classify_503_returns_retriable_error() {
let detector = ErrorDetector;
let policy = basic_retry_policy();
let resp = make_response(503);
let result = detector.classify(Ok(resp), &policy, 0, 0);
match result {
ErrorClassification::RetriableError { status_code, .. } => {
assert_eq!(status_code, 503);
}
other => panic!("Expected RetriableError, got {:?}", other),
}
}
#[test]
fn classify_unconfigured_4xx_returns_retriable_with_defaults() {
let detector = ErrorDetector;
let policy = basic_retry_policy();
let resp = make_response(400);
let result = detector.classify(Ok(resp), &policy, 0, 0);
match result {
ErrorClassification::RetriableError { status_code, .. } => {
assert_eq!(status_code, 400);
}
other => panic!(
"Expected RetriableError for unconfigured 4xx, got {:?}",
other
),
}
}
#[test]
fn classify_unconfigured_5xx_returns_retriable_with_defaults() {
let detector = ErrorDetector;
let policy = basic_retry_policy();
let resp = make_response(502);
let result = detector.classify(Ok(resp), &policy, 0, 0);
match result {
ErrorClassification::RetriableError { status_code, .. } => {
assert_eq!(status_code, 502);
}
other => panic!(
"Expected RetriableError for unconfigured 5xx, got {:?}",
other
),
}
}
#[test]
fn classify_3xx_returns_non_retriable() {
let detector = ErrorDetector;
let policy = basic_retry_policy();
let resp = make_response(301);
let result = detector.classify(Ok(resp), &policy, 0, 0);
assert!(matches!(result, ErrorClassification::NonRetriableError(_)));
}
#[test]
fn classify_1xx_returns_non_retriable() {
let detector = ErrorDetector;
let policy = basic_retry_policy();
let resp = make_response(100);
let result = detector.classify(Ok(resp), &policy, 0, 0);
assert!(matches!(result, ErrorClassification::NonRetriableError(_)));
}
#[test]
fn classify_timeout_returns_timeout_error() {
let detector = ErrorDetector;
let policy = basic_retry_policy();
let timeout = TimeoutError { duration_ms: 5000 };
let result = detector.classify(Err(timeout), &policy, 0, 0);
match result {
ErrorClassification::TimeoutError { duration_ms } => {
assert_eq!(duration_ms, 5000);
}
other => panic!("Expected TimeoutError, got {:?}", other),
}
}
#[test]
fn classify_extracts_retry_after_header() {
let detector = ErrorDetector;
let policy = basic_retry_policy();
let resp = make_response_with_headers(429, vec![("retry-after", "120")]);
let result = detector.classify(Ok(resp), &policy, 0, 0);
match result {
ErrorClassification::RetriableError {
retry_after_seconds,
..
} => {
assert_eq!(retry_after_seconds, Some(120));
}
other => panic!("Expected RetriableError, got {:?}", other),
}
}
#[test]
fn classify_ignores_malformed_retry_after() {
let detector = ErrorDetector;
let policy = basic_retry_policy();
let resp = make_response_with_headers(429, vec![("retry-after", "not-a-number")]);
let result = detector.classify(Ok(resp), &policy, 0, 0);
match result {
ErrorClassification::RetriableError {
retry_after_seconds,
..
} => {
assert_eq!(retry_after_seconds, None);
}
other => panic!("Expected RetriableError, got {:?}", other),
}
}
#[test]
fn classify_status_code_range() {
let detector = ErrorDetector;
let policy = RetryPolicy {
on_status_codes: vec![StatusCodeConfig {
codes: vec![StatusCodeEntry::Range("500-504".to_string())],
strategy: RetryStrategy::DifferentProvider,
max_attempts: 3,
}],
..basic_retry_policy()
};
// 502 is within the range
let resp = make_response(502);
let result = detector.classify(Ok(resp), &policy, 0, 0);
match result {
ErrorClassification::RetriableError { status_code, .. } => {
assert_eq!(status_code, 502);
}
other => panic!("Expected RetriableError, got {:?}", other),
}
}
// ── resolve_retry_params tests ─────────────────────────────────────
#[test]
fn resolve_params_for_configured_status_code() {
let detector = ErrorDetector;
let policy = basic_retry_policy();
let classification = ErrorClassification::RetriableError {
status_code: 429,
retry_after_seconds: None,
response_body: vec![],
};
let (strategy, max_attempts) = detector.resolve_retry_params(&classification, &policy);
assert_eq!(strategy, RetryStrategy::SameProvider);
assert_eq!(max_attempts, 3);
}
#[test]
fn resolve_params_for_unconfigured_status_code_uses_defaults() {
let detector = ErrorDetector;
let policy = basic_retry_policy();
let classification = ErrorClassification::RetriableError {
status_code: 400,
retry_after_seconds: None,
response_body: vec![],
};
let (strategy, max_attempts) = detector.resolve_retry_params(&classification, &policy);
assert_eq!(strategy, RetryStrategy::DifferentProvider);
assert_eq!(max_attempts, 2);
}
#[test]
fn resolve_params_for_timeout_with_config() {
let detector = ErrorDetector;
let policy = basic_retry_policy();
let classification = ErrorClassification::TimeoutError { duration_ms: 5000 };
let (strategy, max_attempts) = detector.resolve_retry_params(&classification, &policy);
assert_eq!(strategy, RetryStrategy::DifferentProvider);
assert_eq!(max_attempts, 2);
}
#[test]
fn resolve_params_for_timeout_without_config_uses_defaults() {
let detector = ErrorDetector;
let mut policy = basic_retry_policy();
policy.on_timeout = None;
let classification = ErrorClassification::TimeoutError { duration_ms: 5000 };
let (strategy, max_attempts) = detector.resolve_retry_params(&classification, &policy);
assert_eq!(strategy, RetryStrategy::DifferentProvider);
assert_eq!(max_attempts, 2);
}
#[test]
fn resolve_params_for_high_latency_with_config() {
let detector = ErrorDetector;
let mut policy = basic_retry_policy();
policy.on_high_latency = Some(crate::configuration::HighLatencyConfig {
threshold_ms: 5000,
measure: LatencyMeasure::Ttfb,
min_triggers: 1,
trigger_window_seconds: None,
strategy: RetryStrategy::SameProvider,
max_attempts: 5,
block_duration_seconds: 300,
scope: crate::configuration::BlockScope::Model,
apply_to: crate::configuration::ApplyTo::Global,
});
let classification = ErrorClassification::HighLatencyEvent {
measured_ms: 6000,
threshold_ms: 5000,
measure: LatencyMeasure::Ttfb,
response: None,
};
let (strategy, max_attempts) = detector.resolve_retry_params(&classification, &policy);
assert_eq!(strategy, RetryStrategy::SameProvider);
assert_eq!(max_attempts, 5);
}
#[test]
fn resolve_params_for_success_returns_defaults() {
let detector = ErrorDetector;
let policy = basic_retry_policy();
let resp = make_response(200);
let classification = ErrorClassification::Success(resp);
let (strategy, max_attempts) = detector.resolve_retry_params(&classification, &policy);
// Shouldn't normally be called for Success, but returns defaults safely
assert_eq!(strategy, RetryStrategy::DifferentProvider);
assert_eq!(max_attempts, 2);
}
#[test]
fn resolve_params_second_on_status_codes_entry() {
let detector = ErrorDetector;
let policy = basic_retry_policy();
let classification = ErrorClassification::RetriableError {
status_code: 503,
retry_after_seconds: None,
response_body: vec![],
};
let (strategy, max_attempts) = detector.resolve_retry_params(&classification, &policy);
assert_eq!(strategy, RetryStrategy::DifferentProvider);
assert_eq!(max_attempts, 4);
}
// ── High latency classification tests ─────────────────────────────
fn high_latency_retry_policy(threshold_ms: u64, measure: LatencyMeasure) -> RetryPolicy {
let mut policy = basic_retry_policy();
policy.on_high_latency = Some(crate::configuration::HighLatencyConfig {
threshold_ms,
measure,
min_triggers: 1,
trigger_window_seconds: None,
strategy: RetryStrategy::DifferentProvider,
max_attempts: 2,
block_duration_seconds: 300,
scope: crate::configuration::BlockScope::Model,
apply_to: crate::configuration::ApplyTo::Global,
});
policy
}
#[test]
fn classify_2xx_high_latency_ttfb_returns_high_latency_event() {
let detector = ErrorDetector;
let policy = high_latency_retry_policy(5000, LatencyMeasure::Ttfb);
let resp = make_response(200);
// TTFB = 6000ms exceeds threshold of 5000ms
let result = detector.classify(Ok(resp), &policy, 6000, 7000);
match result {
ErrorClassification::HighLatencyEvent {
measured_ms,
threshold_ms,
measure,
response,
} => {
assert_eq!(measured_ms, 6000);
assert_eq!(threshold_ms, 5000);
assert_eq!(measure, LatencyMeasure::Ttfb);
assert!(response.is_some(), "Completed response should be present");
}
other => panic!("Expected HighLatencyEvent, got {:?}", other),
}
}
#[test]
fn classify_2xx_high_latency_total_returns_high_latency_event() {
let detector = ErrorDetector;
let policy = high_latency_retry_policy(5000, LatencyMeasure::Total);
let resp = make_response(200);
// Total = 8000ms exceeds threshold, TTFB = 3000ms does not
let result = detector.classify(Ok(resp), &policy, 3000, 8000);
match result {
ErrorClassification::HighLatencyEvent {
measured_ms,
threshold_ms,
measure,
..
} => {
assert_eq!(measured_ms, 8000);
assert_eq!(threshold_ms, 5000);
assert_eq!(measure, LatencyMeasure::Total);
}
other => panic!("Expected HighLatencyEvent, got {:?}", other),
}
}
#[test]
fn classify_2xx_below_threshold_returns_success() {
let detector = ErrorDetector;
let policy = high_latency_retry_policy(5000, LatencyMeasure::Ttfb);
let resp = make_response(200);
// TTFB = 3000ms is below threshold of 5000ms
let result = detector.classify(Ok(resp), &policy, 3000, 4000);
assert!(matches!(result, ErrorClassification::Success(_)));
}
#[test]
fn classify_2xx_at_threshold_returns_success() {
let detector = ErrorDetector;
let policy = high_latency_retry_policy(5000, LatencyMeasure::Ttfb);
let resp = make_response(200);
// TTFB = 5000ms equals threshold — not exceeded
let result = detector.classify(Ok(resp), &policy, 5000, 6000);
assert!(matches!(result, ErrorClassification::Success(_)));
}
#[test]
fn classify_2xx_no_high_latency_config_returns_success() {
let detector = ErrorDetector;
let policy = basic_retry_policy(); // no on_high_latency
let resp = make_response(200);
// High latency values but no config → Success
let result = detector.classify(Ok(resp), &policy, 99999, 99999);
assert!(matches!(result, ErrorClassification::Success(_)));
}
#[test]
fn classify_timeout_takes_priority_over_high_latency() {
let detector = ErrorDetector;
let policy = high_latency_retry_policy(5000, LatencyMeasure::Ttfb);
let timeout = TimeoutError { duration_ms: 10000 };
// Even with high latency config, timeout returns TimeoutError
let result = detector.classify(Err(timeout), &policy, 10000, 10000);
match result {
ErrorClassification::TimeoutError { duration_ms } => {
assert_eq!(duration_ms, 10000);
}
other => panic!("Expected TimeoutError, got {:?}", other),
}
}
#[test]
fn classify_4xx_not_affected_by_high_latency() {
let detector = ErrorDetector;
let policy = high_latency_retry_policy(5000, LatencyMeasure::Ttfb);
let resp = make_response(429);
// Even with high latency, 4xx is still RetriableError
let result = detector.classify(Ok(resp), &policy, 6000, 7000);
assert!(matches!(
result,
ErrorClassification::RetriableError {
status_code: 429,
..
}
));
}
// ── P2 Edge Case: measure-specific classification tests ────────────
#[test]
fn classify_ttfb_measure_triggers_on_slow_ttfb_even_if_total_is_fast() {
let detector = ErrorDetector;
// measure: ttfb, threshold: 5000ms
let policy = high_latency_retry_policy(5000, LatencyMeasure::Ttfb);
let resp = make_response(200);
// TTFB = 6000ms exceeds threshold, but total = 4000ms is below threshold
let result = detector.classify(Ok(resp), &policy, 6000, 4000);
match result {
ErrorClassification::HighLatencyEvent {
measured_ms,
threshold_ms,
measure,
response,
} => {
assert_eq!(measured_ms, 6000, "Should measure TTFB, not total");
assert_eq!(threshold_ms, 5000);
assert_eq!(measure, LatencyMeasure::Ttfb);
assert!(response.is_some(), "Completed response should be present");
}
other => panic!("Expected HighLatencyEvent for slow TTFB, got {:?}", other),
}
}
#[test]
fn classify_total_measure_does_not_trigger_when_only_ttfb_is_slow() {
let detector = ErrorDetector;
// measure: total, threshold: 5000ms
let policy = high_latency_retry_policy(5000, LatencyMeasure::Total);
let resp = make_response(200);
// TTFB = 8000ms is slow, but total = 4000ms is below threshold
// With measure: "total", only total time matters
let result = detector.classify(Ok(resp), &policy, 8000, 4000);
assert!(
matches!(result, ErrorClassification::Success(_)),
"measure: total should NOT trigger when only TTFB is slow but total is below threshold, got {:?}",
result
);
}
#[test]
fn classify_total_measure_triggers_on_slow_total_even_if_ttfb_is_fast() {
let detector = ErrorDetector;
// measure: total, threshold: 5000ms
let policy = high_latency_retry_policy(5000, LatencyMeasure::Total);
let resp = make_response(200);
// TTFB = 1000ms is fast, total = 7000ms exceeds threshold
let result = detector.classify(Ok(resp), &policy, 1000, 7000);
match result {
ErrorClassification::HighLatencyEvent {
measured_ms,
threshold_ms,
measure,
response,
} => {
assert_eq!(measured_ms, 7000, "Should measure total, not TTFB");
assert_eq!(threshold_ms, 5000);
assert_eq!(measure, LatencyMeasure::Total);
assert!(response.is_some(), "Completed response should be present");
}
other => panic!("Expected HighLatencyEvent for slow total, got {:?}", other),
}
}
// ── Property-based tests ───────────────────────────────────────────
use proptest::prelude::*;
/// Generate an arbitrary RetryStrategy.
fn arb_retry_strategy() -> impl Strategy<Value = RetryStrategy> {
prop_oneof![
Just(RetryStrategy::SameModel),
Just(RetryStrategy::SameProvider),
Just(RetryStrategy::DifferentProvider),
]
}
/// Generate an arbitrary StatusCodeEntry (single code in 100-599).
fn arb_status_code_entry() -> impl Strategy<Value = StatusCodeEntry> {
(100u16..=599u16).prop_map(StatusCodeEntry::Single)
}
/// Generate an arbitrary StatusCodeConfig with 1-5 single status code entries.
fn arb_status_code_config() -> impl Strategy<Value = StatusCodeConfig> {
(
proptest::collection::vec(arb_status_code_entry(), 1..=5),
arb_retry_strategy(),
1u32..=10u32,
)
.prop_map(|(codes, strategy, max_attempts)| StatusCodeConfig {
codes,
strategy,
max_attempts,
})
}
/// Generate an arbitrary RetryPolicy with 0-3 on_status_codes entries.
fn arb_retry_policy() -> impl Strategy<Value = RetryPolicy> {
(
arb_retry_strategy(),
1u32..=10u32,
proptest::collection::vec(arb_status_code_config(), 0..=3),
)
.prop_map(
|(default_strategy, default_max_attempts, on_status_codes)| RetryPolicy {
fallback_models: vec![],
default_strategy,
default_max_attempts,
on_status_codes,
on_timeout: None,
on_high_latency: None,
backoff: None,
retry_after_handling: None,
max_retry_duration_ms: None,
},
)
}
// Feature: retry-on-ratelimit, Property 5: Error Classification Correctness
// **Validates: Requirements 1.2**
proptest! {
#![proptest_config(proptest::prelude::ProptestConfig::with_cases(100))]
/// Property 5: For any status code in 100-599 and any RetryPolicy,
/// classify() returns the correct variant:
/// 2xx → Success
/// 4xx/5xx → RetriableError with matching status_code
/// 1xx/3xx → NonRetriableError
#[test]
fn prop_error_classification_correctness(
status_code in 100u16..=599u16,
policy in arb_retry_policy(),
) {
let detector = ErrorDetector;
let resp = make_response(status_code);
let result = detector.classify(Ok(resp), &policy, 0, 0);
match status_code {
200..=299 => {
prop_assert!(
matches!(result, ErrorClassification::Success(_)),
"Expected Success for status {}, got {:?}", status_code, result
);
}
400..=499 | 500..=599 => {
match &result {
ErrorClassification::RetriableError { status_code: sc, .. } => {
prop_assert_eq!(
*sc, status_code,
"RetriableError status_code mismatch: expected {}, got {}", status_code, sc
);
}
other => {
prop_assert!(false, "Expected RetriableError for status {}, got {:?}", status_code, other);
}
}
}
100..=199 | 300..=399 => {
prop_assert!(
matches!(result, ErrorClassification::NonRetriableError(_)),
"Expected NonRetriableError for status {}, got {:?}", status_code, result
);
}
_ => {
// Should not happen given our range 100-599
prop_assert!(false, "Unexpected status code: {}", status_code);
}
}
}
}
// Feature: retry-on-ratelimit, Property 17: Timeout vs High Latency Precedence
// **Validates: Requirements 2.13, 2.14, 2.15, 2a.19, 2a.20**
proptest! {
#![proptest_config(proptest::prelude::ProptestConfig::with_cases(100))]
/// Property 17: When both on_high_latency and on_timeout are configured:
/// - Timeout (Err) → always TimeoutError regardless of latency config
/// - Completed 2xx exceeding threshold → HighLatencyEvent with response present
/// - Completed 2xx below/at threshold → Success
#[test]
fn prop_timeout_vs_high_latency_precedence(
threshold_ms in 1u64..=30_000u64,
elapsed_ttfb_ms in 0u64..=60_000u64,
elapsed_total_ms in 0u64..=60_000u64,
timeout_duration_ms in 1u64..=60_000u64,
measure_is_ttfb in proptest::bool::ANY,
// 0 = timeout scenario, 1 = completed-above-threshold, 2 = completed-below-threshold
scenario in 0u8..=2u8,
) {
let measure = if measure_is_ttfb { LatencyMeasure::Ttfb } else { LatencyMeasure::Total };
let mut policy = basic_retry_policy();
policy.on_high_latency = Some(crate::configuration::HighLatencyConfig {
threshold_ms,
measure,
min_triggers: 1,
trigger_window_seconds: None,
strategy: RetryStrategy::DifferentProvider,
max_attempts: 2,
block_duration_seconds: 300,
scope: crate::configuration::BlockScope::Model,
apply_to: crate::configuration::ApplyTo::Global,
});
// Ensure on_timeout is configured
policy.on_timeout = Some(TimeoutRetryConfig {
strategy: RetryStrategy::DifferentProvider,
max_attempts: 2,
});
let detector = ErrorDetector;
match scenario {
0 => {
// Timeout scenario: Err(TimeoutError) → always TimeoutError
let timeout = TimeoutError { duration_ms: timeout_duration_ms };
let result = detector.classify(Err(timeout), &policy, elapsed_ttfb_ms, elapsed_total_ms);
match result {
ErrorClassification::TimeoutError { duration_ms } => {
prop_assert_eq!(duration_ms, timeout_duration_ms,
"TimeoutError duration should match input");
}
other => {
prop_assert!(false,
"Timeout should always produce TimeoutError, got {:?}", other);
}
}
}
1 => {
// Completed 2xx with latency ABOVE threshold → HighLatencyEvent
// Force the measured value to exceed threshold
let forced_ttfb = if measure_is_ttfb { threshold_ms + 1 + (elapsed_ttfb_ms % 30_000) } else { elapsed_ttfb_ms };
let forced_total = if !measure_is_ttfb { threshold_ms + 1 + (elapsed_total_ms % 30_000) } else { elapsed_total_ms };
let resp = make_response(200);
let result = detector.classify(Ok(resp), &policy, forced_ttfb, forced_total);
match result {
ErrorClassification::HighLatencyEvent {
measured_ms: actual_ms,
threshold_ms: actual_threshold,
measure: actual_measure,
response,
} => {
let expected_measured = if measure_is_ttfb { forced_ttfb } else { forced_total };
prop_assert_eq!(actual_ms, expected_measured,
"HighLatencyEvent measured_ms should match the selected measure");
prop_assert_eq!(actual_threshold, threshold_ms,
"HighLatencyEvent threshold_ms should match config");
prop_assert_eq!(actual_measure, measure,
"HighLatencyEvent measure should match config");
prop_assert!(response.is_some(),
"Completed response should be present in HighLatencyEvent");
}
other => {
prop_assert!(false,
"Completed 2xx above threshold should produce HighLatencyEvent, got {:?}", other);
}
}
}
2 => {
// Completed 2xx with latency AT or BELOW threshold → Success
// Force the measured value to be at or below threshold
let forced_ttfb = if measure_is_ttfb { threshold_ms.min(elapsed_ttfb_ms) } else { elapsed_ttfb_ms };
let forced_total = if !measure_is_ttfb { threshold_ms.min(elapsed_total_ms) } else { elapsed_total_ms };
let resp = make_response(200);
let result = detector.classify(Ok(resp), &policy, forced_ttfb, forced_total);
prop_assert!(
matches!(result, ErrorClassification::Success(_)),
"Completed 2xx at/below threshold should be Success, got {:?}", result
);
}
_ => {} // unreachable given range 0..=2
}
}
}
}

View file

@ -0,0 +1,611 @@
use bytes::Bytes;
use http_body_util::Full;
use hyper::header::HeaderValue;
use hyper::Response;
use serde_json::json;
use super::{AttemptErrorType, RetryExhaustedError};
/// Build an HTTP response from a `RetryExhaustedError`.
///
/// The response body is a JSON object matching the design's error response format.
/// The HTTP status code is derived from the most recent attempt's error:
/// - For `HttpError`: the upstream status code
/// - For `Timeout` or `HighLatency`: 504 Gateway Timeout
///
/// The `request_id` is preserved in the `x-request-id` response header.
///
/// Optional fields `observed_max_retry_after_seconds` and
/// `shortest_remaining_block_seconds` are included only when their
/// corresponding values are `Some`.
pub fn build_error_response(
error: &RetryExhaustedError,
request_id: &str,
) -> Response<Full<Bytes>> {
let status_code = determine_status_code(error);
let attempts_json: Vec<serde_json::Value> = error
.attempts
.iter()
.map(|a| {
let error_type_str = match &a.error_type {
AttemptErrorType::HttpError { status_code, .. } => {
format!("http_{}", status_code)
}
AttemptErrorType::Timeout { duration_ms } => {
format!("timeout_{}ms", duration_ms)
}
AttemptErrorType::HighLatency {
measured_ms,
threshold_ms,
} => {
format!(
"high_latency_{}ms_threshold_{}ms",
measured_ms, threshold_ms
)
}
};
json!({
"model": a.model_id,
"error_type": error_type_str,
"attempt": a.attempt_number,
})
})
.collect();
let message = build_message(error);
let mut error_obj = serde_json::Map::new();
error_obj.insert("message".to_string(), json!(message));
error_obj.insert("type".to_string(), json!("retry_exhausted"));
error_obj.insert("attempts".to_string(), json!(attempts_json));
error_obj.insert("total_attempts".to_string(), json!(error.attempts.len()));
if let Some(max_ra) = error.max_retry_after_seconds {
error_obj.insert(
"observed_max_retry_after_seconds".to_string(),
json!(max_ra),
);
}
if let Some(shortest) = error.shortest_remaining_block_seconds {
error_obj.insert(
"shortest_remaining_block_seconds".to_string(),
json!(shortest),
);
}
error_obj.insert(
"retry_budget_exhausted".to_string(),
json!(error.retry_budget_exhausted),
);
let body_json = json!({ "error": error_obj });
let body_bytes = serde_json::to_vec(&body_json).unwrap_or_default();
let mut response = Response::builder()
.status(status_code)
.header("content-type", "application/json")
.body(Full::new(Bytes::from(body_bytes)))
.unwrap();
if let Ok(val) = HeaderValue::from_str(request_id) {
response.headers_mut().insert("x-request-id", val);
}
response
}
/// Determine the HTTP status code from the most recent attempt error.
/// Returns 504 for timeouts and high latency exhaustion, otherwise the
/// upstream HTTP status code. Falls back to 502 if no attempts exist.
fn determine_status_code(error: &RetryExhaustedError) -> u16 {
match error.attempts.last() {
Some(last) => match &last.error_type {
AttemptErrorType::HttpError { status_code, .. } => *status_code,
AttemptErrorType::Timeout { .. } => 504,
AttemptErrorType::HighLatency { .. } => 504,
},
None => 502,
}
}
/// Build a human-readable message describing the exhaustion cause.
fn build_message(error: &RetryExhaustedError) -> String {
if error.retry_budget_exhausted {
return "All retry attempts exhausted: retry budget exceeded".to_string();
}
match error.attempts.last() {
Some(last) => match &last.error_type {
AttemptErrorType::Timeout { .. } => {
"All retry attempts exhausted: upstream request timed out".to_string()
}
AttemptErrorType::HighLatency { .. } => {
"All retry attempts exhausted: upstream high latency detected".to_string()
}
_ => "All retry attempts exhausted".to_string(),
},
None => "All retry attempts exhausted".to_string(),
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::retry::{AttemptError, AttemptErrorType, RetryExhaustedError};
use http_body_util::BodyExt;
use proptest::prelude::*;
/// Helper to extract the JSON body from a response.
async fn response_json(resp: Response<Full<Bytes>>) -> serde_json::Value {
let body = resp.into_body().collect().await.unwrap().to_bytes();
serde_json::from_slice(&body).unwrap()
}
#[tokio::test]
async fn test_basic_http_error_response() {
let error = RetryExhaustedError {
attempts: vec![
AttemptError {
model_id: "openai/gpt-4o".to_string(),
error_type: AttemptErrorType::HttpError {
status_code: 429,
body: b"rate limited".to_vec(),
},
attempt_number: 1,
},
AttemptError {
model_id: "anthropic/claude-3-5-sonnet".to_string(),
error_type: AttemptErrorType::HttpError {
status_code: 503,
body: b"unavailable".to_vec(),
},
attempt_number: 2,
},
],
max_retry_after_seconds: Some(30),
shortest_remaining_block_seconds: Some(12),
retry_budget_exhausted: false,
};
let resp = build_error_response(&error, "req-123");
assert_eq!(resp.status().as_u16(), 503); // most recent error
assert_eq!(
resp.headers()
.get("x-request-id")
.unwrap()
.to_str()
.unwrap(),
"req-123"
);
assert_eq!(
resp.headers()
.get("content-type")
.unwrap()
.to_str()
.unwrap(),
"application/json"
);
let json = response_json(resp).await;
let err = &json["error"];
assert_eq!(err["type"], "retry_exhausted");
assert_eq!(err["total_attempts"], 2);
assert_eq!(err["observed_max_retry_after_seconds"], 30);
assert_eq!(err["shortest_remaining_block_seconds"], 12);
assert_eq!(err["retry_budget_exhausted"], false);
let attempts = err["attempts"].as_array().unwrap();
assert_eq!(attempts.len(), 2);
assert_eq!(attempts[0]["model"], "openai/gpt-4o");
assert_eq!(attempts[0]["error_type"], "http_429");
assert_eq!(attempts[0]["attempt"], 1);
assert_eq!(attempts[1]["model"], "anthropic/claude-3-5-sonnet");
assert_eq!(attempts[1]["error_type"], "http_503");
assert_eq!(attempts[1]["attempt"], 2);
}
#[tokio::test]
async fn test_timeout_returns_504() {
let error = RetryExhaustedError {
attempts: vec![AttemptError {
model_id: "openai/gpt-4o".to_string(),
error_type: AttemptErrorType::Timeout { duration_ms: 30000 },
attempt_number: 1,
}],
max_retry_after_seconds: None,
shortest_remaining_block_seconds: None,
retry_budget_exhausted: false,
};
let resp = build_error_response(&error, "req-timeout");
assert_eq!(resp.status().as_u16(), 504);
let json = response_json(resp).await;
let err = &json["error"];
assert_eq!(err["attempts"][0]["error_type"], "timeout_30000ms");
assert!(err["message"].as_str().unwrap().contains("timed out"));
}
#[tokio::test]
async fn test_high_latency_returns_504() {
let error = RetryExhaustedError {
attempts: vec![AttemptError {
model_id: "openai/gpt-4o".to_string(),
error_type: AttemptErrorType::HighLatency {
measured_ms: 8000,
threshold_ms: 5000,
},
attempt_number: 1,
}],
max_retry_after_seconds: None,
shortest_remaining_block_seconds: None,
retry_budget_exhausted: false,
};
let resp = build_error_response(&error, "req-latency");
assert_eq!(resp.status().as_u16(), 504);
let json = response_json(resp).await;
let err = &json["error"];
assert_eq!(
err["attempts"][0]["error_type"],
"high_latency_8000ms_threshold_5000ms"
);
assert!(err["message"].as_str().unwrap().contains("high latency"));
}
#[tokio::test]
async fn test_optional_fields_omitted_when_none() {
let error = RetryExhaustedError {
attempts: vec![AttemptError {
model_id: "openai/gpt-4o".to_string(),
error_type: AttemptErrorType::HttpError {
status_code: 429,
body: vec![],
},
attempt_number: 1,
}],
max_retry_after_seconds: None,
shortest_remaining_block_seconds: None,
retry_budget_exhausted: false,
};
let resp = build_error_response(&error, "req-456");
let json = response_json(resp).await;
let err = &json["error"];
// These fields should not be present
assert!(err.get("observed_max_retry_after_seconds").is_none());
assert!(err.get("shortest_remaining_block_seconds").is_none());
// These should always be present
assert!(err.get("retry_budget_exhausted").is_some());
assert!(err.get("total_attempts").is_some());
assert!(err.get("type").is_some());
assert!(err.get("message").is_some());
assert!(err.get("attempts").is_some());
}
#[tokio::test]
async fn test_retry_budget_exhausted_message() {
let error = RetryExhaustedError {
attempts: vec![AttemptError {
model_id: "openai/gpt-4o".to_string(),
error_type: AttemptErrorType::HttpError {
status_code: 429,
body: vec![],
},
attempt_number: 1,
}],
max_retry_after_seconds: None,
shortest_remaining_block_seconds: None,
retry_budget_exhausted: true,
};
let resp = build_error_response(&error, "req-budget");
let json = response_json(resp).await;
let err = &json["error"];
assert_eq!(err["retry_budget_exhausted"], true);
assert!(err["message"].as_str().unwrap().contains("budget exceeded"));
}
#[tokio::test]
async fn test_empty_attempts_returns_502() {
let error = RetryExhaustedError {
attempts: vec![],
max_retry_after_seconds: None,
shortest_remaining_block_seconds: None,
retry_budget_exhausted: false,
};
let resp = build_error_response(&error, "req-empty");
assert_eq!(resp.status().as_u16(), 502);
let json = response_json(resp).await;
assert_eq!(json["error"]["total_attempts"], 0);
assert_eq!(json["error"]["attempts"].as_array().unwrap().len(), 0);
}
#[tokio::test]
async fn test_request_id_preserved_in_header() {
let error = RetryExhaustedError {
attempts: vec![AttemptError {
model_id: "m".to_string(),
error_type: AttemptErrorType::HttpError {
status_code: 500,
body: vec![],
},
attempt_number: 1,
}],
max_retry_after_seconds: None,
shortest_remaining_block_seconds: None,
retry_budget_exhausted: false,
};
let resp = build_error_response(&error, "unique-request-id-abc-123");
assert_eq!(
resp.headers()
.get("x-request-id")
.unwrap()
.to_str()
.unwrap(),
"unique-request-id-abc-123"
);
}
#[tokio::test]
async fn test_mixed_error_types_in_attempts() {
let error = RetryExhaustedError {
attempts: vec![
AttemptError {
model_id: "openai/gpt-4o".to_string(),
error_type: AttemptErrorType::HttpError {
status_code: 429,
body: vec![],
},
attempt_number: 1,
},
AttemptError {
model_id: "anthropic/claude".to_string(),
error_type: AttemptErrorType::Timeout { duration_ms: 5000 },
attempt_number: 2,
},
AttemptError {
model_id: "gemini/pro".to_string(),
error_type: AttemptErrorType::HighLatency {
measured_ms: 10000,
threshold_ms: 3000,
},
attempt_number: 3,
},
],
max_retry_after_seconds: Some(60),
shortest_remaining_block_seconds: Some(5),
retry_budget_exhausted: false,
};
// Last attempt is HighLatency → 504
let resp = build_error_response(&error, "req-mixed");
assert_eq!(resp.status().as_u16(), 504);
let json = response_json(resp).await;
let err = &json["error"];
assert_eq!(err["total_attempts"], 3);
assert_eq!(err["observed_max_retry_after_seconds"], 60);
assert_eq!(err["shortest_remaining_block_seconds"], 5);
let attempts = err["attempts"].as_array().unwrap();
assert_eq!(attempts[0]["error_type"], "http_429");
assert_eq!(attempts[1]["error_type"], "timeout_5000ms");
assert_eq!(
attempts[2]["error_type"],
"high_latency_10000ms_threshold_3000ms"
);
}
// ── Proptest strategies ────────────────────────────────────────────────
/// Generate an arbitrary AttemptErrorType.
fn arb_attempt_error_type() -> impl Strategy<Value = AttemptErrorType> {
prop_oneof![
(
100u16..=599u16,
proptest::collection::vec(any::<u8>(), 0..32)
)
.prop_map(|(status_code, body)| AttemptErrorType::HttpError { status_code, body }),
(1u64..=120_000u64).prop_map(|duration_ms| AttemptErrorType::Timeout { duration_ms }),
(1u64..=120_000u64, 1u64..=120_000u64).prop_map(|(measured_ms, threshold_ms)| {
AttemptErrorType::HighLatency {
measured_ms,
threshold_ms,
}
}),
]
}
/// Generate an arbitrary AttemptError with a model_id from a small set of
/// realistic provider/model identifiers.
fn arb_attempt_error() -> impl Strategy<Value = AttemptError> {
let model_ids = prop_oneof![
Just("openai/gpt-4o".to_string()),
Just("openai/gpt-4o-mini".to_string()),
Just("anthropic/claude-3-5-sonnet".to_string()),
Just("gemini/pro".to_string()),
Just("azure/gpt-4o".to_string()),
];
(model_ids, arb_attempt_error_type(), 1u32..=10u32).prop_map(
|(model_id, error_type, attempt_number)| AttemptError {
model_id,
error_type,
attempt_number,
},
)
}
/// Generate an arbitrary RetryExhaustedError with 1..=8 attempts.
fn arb_retry_exhausted_error() -> impl Strategy<Value = RetryExhaustedError> {
(
proptest::collection::vec(arb_attempt_error(), 1..=8),
proptest::option::of(1u64..=600u64),
proptest::option::of(1u64..=600u64),
any::<bool>(),
)
.prop_map(
|(
attempts,
max_retry_after_seconds,
shortest_remaining_block_seconds,
retry_budget_exhausted,
)| {
RetryExhaustedError {
attempts,
max_retry_after_seconds,
shortest_remaining_block_seconds,
retry_budget_exhausted,
}
},
)
}
/// Generate an arbitrary request_id (non-empty ASCII string valid for HTTP headers).
fn arb_request_id() -> impl Strategy<Value = String> {
"[a-zA-Z0-9_-]{1,64}"
}
// Feature: retry-on-ratelimit, Property 21: Error Response Contains Attempt Details
// **Validates: Requirements 10.4, 10.5, 10.7**
proptest! {
#![proptest_config(ProptestConfig::with_cases(100))]
/// Property 21: For any exhausted retry sequence, the error response
/// must include all attempted model identifiers and their error types,
/// and must preserve the original request_id.
#[test]
fn prop_error_response_contains_attempt_details(
error in arb_retry_exhausted_error(),
request_id in arb_request_id(),
) {
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.unwrap();
rt.block_on(async {
let resp = build_error_response(&error, &request_id);
// request_id preserved in x-request-id header
let header_val = resp.headers().get("x-request-id")
.expect("x-request-id header must be present");
prop_assert_eq!(header_val.to_str().unwrap(), request_id.as_str());
// Content-Type is application/json
let ct = resp.headers().get("content-type")
.expect("content-type header must be present");
prop_assert_eq!(ct.to_str().unwrap(), "application/json");
// Parse JSON body
let body = resp.into_body().collect().await.unwrap().to_bytes();
let json: serde_json::Value = serde_json::from_slice(&body)
.expect("response body must be valid JSON");
let err_obj = &json["error"];
// type is always "retry_exhausted"
prop_assert_eq!(err_obj["type"].as_str().unwrap(), "retry_exhausted");
// total_attempts matches input
prop_assert_eq!(
err_obj["total_attempts"].as_u64().unwrap(),
error.attempts.len() as u64
);
// retry_budget_exhausted matches input
prop_assert_eq!(
err_obj["retry_budget_exhausted"].as_bool().unwrap(),
error.retry_budget_exhausted
);
// attempts array has correct length
let attempts_arr = err_obj["attempts"].as_array()
.expect("attempts must be an array");
prop_assert_eq!(attempts_arr.len(), error.attempts.len());
// Every attempt's model_id and error_type are present and correct
for (i, attempt) in error.attempts.iter().enumerate() {
let json_attempt = &attempts_arr[i];
// model_id preserved
prop_assert_eq!(
json_attempt["model"].as_str().unwrap(),
attempt.model_id.as_str()
);
// attempt_number preserved
prop_assert_eq!(
json_attempt["attempt"].as_u64().unwrap(),
attempt.attempt_number as u64
);
// error_type string matches the variant
let error_type_str = json_attempt["error_type"].as_str().unwrap();
match &attempt.error_type {
AttemptErrorType::HttpError { status_code, .. } => {
prop_assert_eq!(
error_type_str,
&format!("http_{}", status_code)
);
}
AttemptErrorType::Timeout { duration_ms } => {
prop_assert_eq!(
error_type_str,
&format!("timeout_{}ms", duration_ms)
);
}
AttemptErrorType::HighLatency { measured_ms, threshold_ms } => {
prop_assert_eq!(
error_type_str,
&format!("high_latency_{}ms_threshold_{}ms", measured_ms, threshold_ms)
);
}
}
}
// Optional fields: observed_max_retry_after_seconds
match error.max_retry_after_seconds {
Some(v) => {
prop_assert_eq!(
err_obj["observed_max_retry_after_seconds"].as_u64().unwrap(),
v
);
}
None => {
prop_assert!(err_obj.get("observed_max_retry_after_seconds").is_none()
|| err_obj["observed_max_retry_after_seconds"].is_null());
}
}
// Optional fields: shortest_remaining_block_seconds
match error.shortest_remaining_block_seconds {
Some(v) => {
prop_assert_eq!(
err_obj["shortest_remaining_block_seconds"].as_u64().unwrap(),
v
);
}
None => {
prop_assert!(err_obj.get("shortest_remaining_block_seconds").is_none()
|| err_obj["shortest_remaining_block_seconds"].is_null());
}
}
// message is a non-empty string
let message = err_obj["message"].as_str()
.expect("message must be a string");
prop_assert!(!message.is_empty());
Ok(())
})?;
}
}
}

View file

@ -0,0 +1,375 @@
use std::time::{Duration, Instant};
use dashmap::DashMap;
use log::info;
use crate::configuration::{extract_provider, BlockScope};
/// Thread-safe global state manager for latency-based blocking.
///
/// Blocks expire only via `block_duration_seconds` — successful requests
/// do NOT remove existing blocks. There is no `remove_block()` method.
///
/// This manager handles ONLY global state (`apply_to: "global"`).
/// Request-scoped state (`apply_to: "request"`) is stored in
/// `RequestContext.request_latency_block_state` and managed by the orchestrator.
///
/// Entries use max-expiration semantics: if a new block is recorded for an
/// identifier that already has an entry, the expiration is updated only if
/// the new expiration is later than the existing one.
pub struct LatencyBlockStateManager {
/// Global state: identifier (model ID or provider prefix) -> (expiration timestamp, measured_latency_ms)
global_state: DashMap<String, (Instant, u64)>,
}
impl LatencyBlockStateManager {
pub fn new() -> Self {
Self {
global_state: DashMap::new(),
}
}
/// Record a latency block after min_triggers threshold is met.
///
/// If an entry already exists for the identifier, updates only if the new
/// expiration is later than the existing one (max-expiration semantics).
/// The `measured_latency_ms` is always updated to the latest value when
/// the expiration is extended.
pub fn record_block(
&self,
identifier: &str,
block_duration_seconds: u64,
measured_latency_ms: u64,
) {
let new_expiration = Instant::now() + Duration::from_secs(block_duration_seconds);
self.global_state
.entry(identifier.to_string())
.and_modify(|existing| {
if new_expiration > existing.0 {
existing.0 = new_expiration;
existing.1 = measured_latency_ms;
}
})
.or_insert((new_expiration, measured_latency_ms));
}
/// Check if an identifier is currently blocked.
///
/// Lazily cleans up expired entries.
pub fn is_blocked(&self, identifier: &str) -> bool {
if let Some(entry) = self.global_state.get(identifier) {
if Instant::now() < entry.0 {
return true;
}
// Entry expired — drop the read guard before removing
drop(entry);
self.global_state.remove(identifier);
info!("Latency_Block_State expired: identifier={}", identifier);
info!("metric.latency_block_expired: model={}", identifier);
}
false
}
/// Get remaining block duration for an identifier, if blocked.
///
/// Returns `None` if the identifier is not blocked or the entry has expired.
/// Lazily cleans up expired entries.
pub fn remaining_block_duration(&self, identifier: &str) -> Option<Duration> {
if let Some(entry) = self.global_state.get(identifier) {
let now = Instant::now();
if now < entry.0 {
return Some(entry.0 - now);
}
// Entry expired — drop the read guard before removing
drop(entry);
self.global_state.remove(identifier);
info!("Latency_Block_State expired: identifier={}", identifier);
info!("metric.latency_block_expired: model={}", identifier);
}
None
}
/// Check if a model is blocked, considering scope (model or provider).
///
/// - `BlockScope::Model`: checks if the exact `model_id` is blocked.
/// - `BlockScope::Provider`: extracts the provider prefix from `model_id`
/// and checks if that prefix is blocked.
pub fn is_model_blocked(&self, model_id: &str, scope: BlockScope) -> bool {
match scope {
BlockScope::Model => self.is_blocked(model_id),
BlockScope::Provider => {
let provider = extract_provider(model_id);
self.is_blocked(provider)
}
}
}
}
impl Default for LatencyBlockStateManager {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::thread;
use std::time::Duration;
#[test]
fn test_new_manager_has_no_blocks() {
let mgr = LatencyBlockStateManager::new();
assert!(!mgr.is_blocked("openai/gpt-4o"));
assert!(mgr.remaining_block_duration("openai/gpt-4o").is_none());
}
#[test]
fn test_record_block_and_is_blocked() {
let mgr = LatencyBlockStateManager::new();
mgr.record_block("openai/gpt-4o", 60, 5500);
assert!(mgr.is_blocked("openai/gpt-4o"));
assert!(!mgr.is_blocked("anthropic/claude"));
}
#[test]
fn test_remaining_block_duration() {
let mgr = LatencyBlockStateManager::new();
mgr.record_block("openai/gpt-4o", 10, 5000);
let remaining = mgr.remaining_block_duration("openai/gpt-4o").unwrap();
assert!(remaining <= Duration::from_secs(11));
assert!(remaining > Duration::from_secs(8));
}
#[test]
fn test_expired_entry_cleaned_up_on_is_blocked() {
let mgr = LatencyBlockStateManager::new();
mgr.record_block("openai/gpt-4o", 0, 5000);
thread::sleep(Duration::from_millis(10));
assert!(!mgr.is_blocked("openai/gpt-4o"));
}
#[test]
fn test_expired_entry_cleaned_up_on_remaining() {
let mgr = LatencyBlockStateManager::new();
mgr.record_block("openai/gpt-4o", 0, 5000);
thread::sleep(Duration::from_millis(10));
assert!(mgr.remaining_block_duration("openai/gpt-4o").is_none());
}
#[test]
fn test_max_expiration_semantics_longer_wins() {
let mgr = LatencyBlockStateManager::new();
mgr.record_block("openai/gpt-4o", 10, 5000);
let first_remaining = mgr.remaining_block_duration("openai/gpt-4o").unwrap();
mgr.record_block("openai/gpt-4o", 60, 6000);
let second_remaining = mgr.remaining_block_duration("openai/gpt-4o").unwrap();
assert!(second_remaining > first_remaining);
}
#[test]
fn test_max_expiration_semantics_shorter_does_not_overwrite() {
let mgr = LatencyBlockStateManager::new();
mgr.record_block("openai/gpt-4o", 60, 5000);
let first_remaining = mgr.remaining_block_duration("openai/gpt-4o").unwrap();
mgr.record_block("openai/gpt-4o", 5, 6000);
let second_remaining = mgr.remaining_block_duration("openai/gpt-4o").unwrap();
// Should still be close to the original 60s
assert!(second_remaining > Duration::from_secs(50));
let diff = if first_remaining > second_remaining {
first_remaining - second_remaining
} else {
second_remaining - first_remaining
};
assert!(diff < Duration::from_secs(2));
}
#[test]
fn test_is_model_blocked_model_scope() {
let mgr = LatencyBlockStateManager::new();
mgr.record_block("openai/gpt-4o", 60, 5000);
assert!(mgr.is_model_blocked("openai/gpt-4o", BlockScope::Model));
assert!(!mgr.is_model_blocked("openai/gpt-4o-mini", BlockScope::Model));
}
#[test]
fn test_is_model_blocked_provider_scope() {
let mgr = LatencyBlockStateManager::new();
mgr.record_block("openai", 60, 5000);
assert!(mgr.is_model_blocked("openai/gpt-4o", BlockScope::Provider));
assert!(mgr.is_model_blocked("openai/gpt-4o-mini", BlockScope::Provider));
assert!(!mgr.is_model_blocked("anthropic/claude", BlockScope::Provider));
}
#[test]
fn test_multiple_identifiers_independent() {
let mgr = LatencyBlockStateManager::new();
mgr.record_block("openai/gpt-4o", 60, 5000);
mgr.record_block("anthropic/claude", 30, 4000);
assert!(mgr.is_blocked("openai/gpt-4o"));
assert!(mgr.is_blocked("anthropic/claude"));
assert!(!mgr.is_blocked("azure/gpt-4o"));
}
#[test]
fn test_record_block_stores_measured_latency() {
let mgr = LatencyBlockStateManager::new();
mgr.record_block("openai/gpt-4o", 60, 5500);
// Verify the entry exists and has the correct latency
let entry = mgr.global_state.get("openai/gpt-4o").unwrap();
assert_eq!(entry.1, 5500);
}
#[test]
fn test_latency_updated_when_expiration_extended() {
let mgr = LatencyBlockStateManager::new();
mgr.record_block("openai/gpt-4o", 10, 5000);
// Extend with longer duration and different latency
mgr.record_block("openai/gpt-4o", 60, 7000);
let entry = mgr.global_state.get("openai/gpt-4o").unwrap();
assert_eq!(entry.1, 7000);
}
#[test]
fn test_latency_not_updated_when_expiration_not_extended() {
let mgr = LatencyBlockStateManager::new();
mgr.record_block("openai/gpt-4o", 60, 5000);
// Shorter duration — should NOT update
mgr.record_block("openai/gpt-4o", 5, 9000);
let entry = mgr.global_state.get("openai/gpt-4o").unwrap();
// Latency should remain 5000 since expiration wasn't extended
assert_eq!(entry.1, 5000);
}
#[test]
fn test_zero_duration_block_expires_immediately() {
let mgr = LatencyBlockStateManager::new();
mgr.record_block("openai/gpt-4o", 0, 5000);
thread::sleep(Duration::from_millis(5));
assert!(!mgr.is_blocked("openai/gpt-4o"));
}
#[test]
fn test_default_trait() {
let mgr = LatencyBlockStateManager::default();
assert!(!mgr.is_blocked("anything"));
}
// --- Property-based tests ---
use proptest::prelude::*;
fn arb_identifier() -> impl Strategy<Value = String> {
prop_oneof![
"[a-z]{3,8}/[a-z0-9\\-]{3,12}".prop_map(|s| s),
"[a-z]{3,8}".prop_map(|s| s),
]
}
/// A single block recording: (block_duration_seconds, measured_latency_ms)
fn arb_block_recording() -> impl Strategy<Value = (u64, u64)> {
(1u64..=600, 100u64..=30_000)
}
// Feature: retry-on-ratelimit, Property 22: Latency Block State Max Expiration Update
// **Validates: Requirements 14.15**
proptest! {
#![proptest_config(ProptestConfig::with_cases(100))]
/// Property 22 Case 1: After recording multiple blocks for the same identifier
/// with different durations, the remaining block duration reflects the maximum
/// duration recorded (max-expiration semantics).
#[test]
fn prop_latency_block_max_expiration_update(
identifier in arb_identifier(),
recordings in prop::collection::vec(arb_block_recording(), 2..=10),
) {
let mgr = LatencyBlockStateManager::new();
for &(duration, latency) in &recordings {
mgr.record_block(&identifier, duration, latency);
}
let max_duration = recordings.iter().map(|&(d, _)| d).max().unwrap();
// The identifier should still be blocked
let remaining = mgr.remaining_block_duration(&identifier);
prop_assert!(
remaining.is_some(),
"Identifier {} should be blocked after {} recordings (max_duration={}s)",
identifier, recordings.len(), max_duration
);
let remaining_secs = remaining.unwrap().as_secs();
// Remaining should be close to max_duration (allow 2s tolerance for execution time)
prop_assert!(
remaining_secs >= max_duration.saturating_sub(2),
"Remaining {}s should reflect the max duration ({}s), not a smaller value. Recordings: {:?}",
remaining_secs, max_duration, recordings
);
prop_assert!(
remaining_secs <= max_duration + 1,
"Remaining {}s should not exceed max duration {}s + tolerance. Recordings: {:?}",
remaining_secs, max_duration, recordings
);
}
/// Property 22 Case 2: measured_latency_ms is updated when expiration is extended
/// but NOT when a shorter duration is recorded.
#[test]
fn prop_latency_block_measured_latency_update_semantics(
identifier in arb_identifier(),
first_duration in 10u64..=300,
first_latency in 100u64..=30_000,
extra_duration in 1u64..=300,
longer_latency in 100u64..=30_000,
shorter_duration in 1u64..=9,
shorter_latency in 100u64..=30_000,
) {
let mgr = LatencyBlockStateManager::new();
// Record initial block
mgr.record_block(&identifier, first_duration, first_latency);
{
let entry = mgr.global_state.get(&identifier).unwrap();
prop_assert_eq!(entry.1, first_latency);
}
// Record a longer duration — latency SHOULD be updated
let longer_duration = first_duration + extra_duration;
mgr.record_block(&identifier, longer_duration, longer_latency);
{
let entry = mgr.global_state.get(&identifier).unwrap();
prop_assert_eq!(
entry.1, longer_latency,
"Latency should be updated to {} when expiration is extended (duration {} > {})",
longer_latency, longer_duration, first_duration
);
}
// Record a shorter duration — latency should NOT be updated
mgr.record_block(&identifier, shorter_duration, shorter_latency);
{
let entry = mgr.global_state.get(&identifier).unwrap();
prop_assert_eq!(
entry.1, longer_latency,
"Latency should remain {} (not {}) when shorter duration {} < {} doesn't extend expiration",
longer_latency, shorter_latency, shorter_duration, longer_duration
);
}
}
}
}

View file

@ -0,0 +1,230 @@
use std::time::Instant;
use dashmap::DashMap;
/// Thread-safe sliding window counter for tracking High_Latency_Events.
///
/// Maintains per-identifier timestamps of latency events within a configurable
/// sliding window. When the count of recent events meets or exceeds `min_triggers`,
/// the caller should create a `Latency_Block_State` entry and then call `reset()`.
pub struct LatencyTriggerCounter {
/// model/provider identifier -> list of event timestamps within the window
counters: DashMap<String, Vec<Instant>>,
}
impl LatencyTriggerCounter {
pub fn new() -> Self {
Self {
counters: DashMap::new(),
}
}
/// Record a High_Latency_Event. Returns true if `min_triggers` threshold
/// is now met (caller should create a Latency_Block_State).
///
/// Lazily discards events older than `trigger_window_seconds` before checking
/// the count.
pub fn record_event(
&self,
identifier: &str,
min_triggers: u32,
trigger_window_seconds: u64,
) -> bool {
let now = Instant::now();
let window = std::time::Duration::from_secs(trigger_window_seconds);
let mut entry = self.counters.entry(identifier.to_string()).or_default();
// Add current event
entry.push(now);
// Discard events older than the window
entry.retain(|ts| now.duration_since(*ts) <= window);
// Check threshold
entry.len() >= min_triggers as usize
}
/// Reset the counter for an identifier (called after a block is created
/// to prevent re-triggering on the same events).
pub fn reset(&self, identifier: &str) {
if let Some(mut entry) = self.counters.get_mut(identifier) {
entry.clear();
}
}
}
impl Default for LatencyTriggerCounter {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::thread::sleep;
use std::time::Duration;
#[test]
fn test_record_event_returns_true_when_threshold_met() {
let counter = LatencyTriggerCounter::new();
assert!(!counter.record_event("model-a", 3, 60));
assert!(!counter.record_event("model-a", 3, 60));
assert!(counter.record_event("model-a", 3, 60));
}
#[test]
fn test_record_event_single_trigger_always_fires() {
let counter = LatencyTriggerCounter::new();
assert!(counter.record_event("model-a", 1, 60));
}
#[test]
fn test_events_expire_outside_window() {
let counter = LatencyTriggerCounter::new();
// Record 2 events
counter.record_event("model-a", 3, 1);
counter.record_event("model-a", 3, 1);
// Wait for them to expire
sleep(Duration::from_millis(1100));
// Third event should not meet threshold since previous two expired
assert!(!counter.record_event("model-a", 3, 1));
}
#[test]
fn test_reset_clears_counter() {
let counter = LatencyTriggerCounter::new();
counter.record_event("model-a", 3, 60);
counter.record_event("model-a", 3, 60);
counter.reset("model-a");
// After reset, need 3 fresh events again
assert!(!counter.record_event("model-a", 3, 60));
assert!(!counter.record_event("model-a", 3, 60));
assert!(counter.record_event("model-a", 3, 60));
}
#[test]
fn test_reset_nonexistent_identifier_is_noop() {
let counter = LatencyTriggerCounter::new();
// Should not panic
counter.reset("nonexistent");
}
#[test]
fn test_separate_identifiers_are_independent() {
let counter = LatencyTriggerCounter::new();
counter.record_event("model-a", 2, 60);
counter.record_event("model-b", 2, 60);
// model-a has 1 event, model-b has 1 event — neither at threshold of 2
assert!(!counter.record_event("model-b", 3, 60));
// model-a reaches threshold
assert!(counter.record_event("model-a", 2, 60));
}
#[test]
fn test_threshold_exceeded_still_returns_true() {
let counter = LatencyTriggerCounter::new();
assert!(counter.record_event("model-a", 1, 60));
// Already past threshold, still returns true
assert!(counter.record_event("model-a", 1, 60));
assert!(counter.record_event("model-a", 1, 60));
}
// --- Property-based tests ---
use proptest::prelude::*;
// Feature: retry-on-ratelimit, Property 18: Latency Trigger Counter Sliding Window
// **Validates: Requirements 2a.6, 2a.7, 2a.8, 2a.21, 14.1, 14.2, 14.3, 14.12**
proptest! {
#![proptest_config(ProptestConfig::with_cases(100))]
/// Property 18 Case 1: Recording N events in quick succession (all within window)
/// returns true iff N >= min_triggers.
#[test]
fn prop_sliding_window_threshold(
min_triggers in 1u32..=10,
trigger_window_seconds in 1u64..=60,
num_events in 1u32..=20,
) {
let counter = LatencyTriggerCounter::new();
let identifier = "test-model";
let mut last_result = false;
for i in 1..=num_events {
last_result = counter.record_event(identifier, min_triggers, trigger_window_seconds);
// Before reaching threshold, should be false
if i < min_triggers {
prop_assert!(!last_result, "Expected false at event {} with min_triggers {}", i, min_triggers);
} else {
// At or past threshold, should be true
prop_assert!(last_result, "Expected true at event {} with min_triggers {}", i, min_triggers);
}
}
// Final result should match whether we recorded enough events
prop_assert_eq!(last_result, num_events >= min_triggers);
}
/// Property 18 Case 2: After reset, counter starts fresh and previous events
/// do not count toward the threshold.
#[test]
fn prop_reset_clears_counter(
min_triggers in 2u32..=10,
trigger_window_seconds in 1u64..=60,
events_before_reset in 1u32..=10,
) {
let counter = LatencyTriggerCounter::new();
let identifier = "test-model";
// Record some events before reset
for _ in 0..events_before_reset {
counter.record_event(identifier, min_triggers, trigger_window_seconds);
}
// Reset the counter
counter.reset(identifier);
// After reset, a single event should not meet threshold (min_triggers >= 2)
let result = counter.record_event(identifier, min_triggers, trigger_window_seconds);
prop_assert!(!result, "After reset, first event should not meet threshold of {}", min_triggers);
// Need min_triggers - 1 more events to reach threshold again
let mut final_result = result;
for _ in 1..min_triggers {
final_result = counter.record_event(identifier, min_triggers, trigger_window_seconds);
}
prop_assert!(final_result, "After reset + {} events, should meet threshold", min_triggers);
}
/// Property 18 Case 3: Different identifiers are independent — events for one
/// identifier do not affect the count for another.
#[test]
fn prop_identifiers_independent(
min_triggers in 1u32..=10,
trigger_window_seconds in 1u64..=60,
events_a in 1u32..=20,
events_b in 1u32..=20,
) {
let counter = LatencyTriggerCounter::new();
let id_a = "model-a";
let id_b = "model-b";
// Record events for identifier A
let mut result_a = false;
for _ in 0..events_a {
result_a = counter.record_event(id_a, min_triggers, trigger_window_seconds);
}
// Record events for identifier B
let mut result_b = false;
for _ in 0..events_b {
result_b = counter.record_event(id_b, min_triggers, trigger_window_seconds);
}
// Each identifier's result depends only on its own event count
prop_assert_eq!(result_a, events_a >= min_triggers,
"id_a: events={}, min_triggers={}", events_a, min_triggers);
prop_assert_eq!(result_b, events_b >= min_triggers,
"id_b: events={}, min_triggers={}", events_b, min_triggers);
}
}
} // mod tests

View file

@ -0,0 +1,804 @@
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use std::time::Instant;
use bytes::Bytes;
use hyper::HeaderMap;
use sha2::{Digest, Sha256};
use crate::configuration::{ApplyTo, LlmProvider, LlmProviderType};
// Sub-modules
pub mod backoff;
pub mod error_detector;
pub mod error_response;
pub mod latency_block_state;
pub mod latency_trigger;
pub mod orchestrator;
pub mod provider_selector;
pub mod retry_after_state;
pub mod validation;
// ── State Structs ──────────────────────────────────────────────────────────
/// In-memory Retry-After state entry.
#[derive(Debug, Clone)]
pub struct RetryAfterEntry {
pub identifier: String,
pub expires_at: Instant,
pub apply_to: ApplyTo,
}
/// In-memory Latency Block state entry.
#[derive(Debug, Clone)]
pub struct LatencyBlockEntry {
pub identifier: String,
pub expires_at: Instant,
pub measured_latency_ms: u64,
pub apply_to: ApplyTo,
}
/// Error accumulated from a single attempt.
#[derive(Debug, Clone)]
pub struct AttemptError {
pub model_id: String,
pub error_type: AttemptErrorType,
pub attempt_number: u32,
}
#[derive(Debug, Clone)]
pub enum AttemptErrorType {
HttpError { status_code: u16, body: Vec<u8> },
Timeout { duration_ms: u64 },
HighLatency { measured_ms: u64, threshold_ms: u64 },
}
/// Lightweight request signature for retry tracking.
/// The actual request body bytes are passed by reference from the handler scope
/// (as `&Bytes`) rather than cloned into this struct.
#[derive(Debug, Clone)]
pub struct RequestSignature {
/// SHA-256 hash of the original request body
pub body_hash: [u8; 32],
pub headers: HeaderMap,
pub streaming: bool,
pub original_model: String,
}
impl RequestSignature {
pub fn new(body: &[u8], headers: &HeaderMap, streaming: bool, original_model: String) -> Self {
let mut hasher = Sha256::new();
hasher.update(body);
let hash: [u8; 32] = hasher.finalize().into();
Self {
body_hash: hash,
headers: headers.clone(),
streaming,
original_model,
}
}
}
// ── Auth Header Constants ───────────────────────────────────────────────────
/// Headers that carry authentication credentials and must be sanitized
/// when forwarding requests to a different provider.
const AUTH_HEADERS: &[&str] = &["authorization", "x-api-key"];
/// Additional provider-specific headers that should be sanitized.
const PROVIDER_SPECIFIC_HEADERS: &[&str] = &["anthropic-version"];
/// Rebuild a request for a different target provider.
///
/// Updates the `model` field in the JSON body to match the target provider's
/// model name (without provider prefix), and applies the correct auth
/// credentials for the target provider. Sanitizes auth headers from the
/// original request to prevent credential leakage across providers.
///
/// Returns the updated body bytes and headers, or an error if the body
/// cannot be parsed as JSON.
pub fn rebuild_request_for_provider(
body: &Bytes,
target_provider: &LlmProvider,
original_headers: &HeaderMap,
) -> Result<(Bytes, HeaderMap), RebuildError> {
// Update the model field in the JSON body
let mut json_body: serde_json::Value =
serde_json::from_slice(body).map_err(|e| RebuildError::InvalidJson(e.to_string()))?;
// Extract model name without provider prefix (e.g., "openai/gpt-4o" -> "gpt-4o")
let target_model = target_provider
.model
.as_deref()
.or(Some(&target_provider.name))
.unwrap_or(&target_provider.name);
let model_name_only = if let Some((_, model)) = target_model.split_once('/') {
model
} else {
target_model
};
if let Some(obj) = json_body.as_object_mut() {
obj.insert(
"model".to_string(),
serde_json::Value::String(model_name_only.to_string()),
);
}
let updated_body = Bytes::from(
serde_json::to_vec(&json_body).map_err(|e| RebuildError::InvalidJson(e.to_string()))?,
);
// Sanitize and rebuild headers
let mut headers = sanitize_headers(original_headers);
apply_auth_headers(&mut headers, target_provider)?;
Ok((updated_body, headers))
}
/// Remove auth-related headers from the original request to prevent
/// credential leakage when forwarding to a different provider.
fn sanitize_headers(original: &HeaderMap) -> HeaderMap {
let mut headers = original.clone();
for header_name in AUTH_HEADERS.iter().chain(PROVIDER_SPECIFIC_HEADERS.iter()) {
headers.remove(*header_name);
}
headers
}
/// Apply the correct auth headers for the target provider.
fn apply_auth_headers(headers: &mut HeaderMap, provider: &LlmProvider) -> Result<(), RebuildError> {
// If passthrough_auth is enabled, don't set provider credentials
if provider.passthrough_auth == Some(true) {
return Ok(());
}
let access_key = provider
.access_key
.as_ref()
.ok_or_else(|| RebuildError::MissingAccessKey(provider.name.clone()))?;
match provider.provider_interface {
LlmProviderType::Anthropic => {
headers.insert(
hyper::header::HeaderName::from_static("x-api-key"),
hyper::header::HeaderValue::from_str(access_key)
.map_err(|_| RebuildError::InvalidHeaderValue("x-api-key".to_string()))?,
);
headers.insert(
hyper::header::HeaderName::from_static("anthropic-version"),
hyper::header::HeaderValue::from_static("2023-06-01"),
);
}
_ => {
// OpenAI-compatible providers use Authorization: Bearer <key>
let bearer = format!("Bearer {}", access_key);
headers.insert(
hyper::header::AUTHORIZATION,
hyper::header::HeaderValue::from_str(&bearer)
.map_err(|_| RebuildError::InvalidHeaderValue("authorization".to_string()))?,
);
}
}
Ok(())
}
/// Errors that can occur when rebuilding a request for a different provider.
#[derive(Debug, Clone, PartialEq)]
pub enum RebuildError {
/// The request body is not valid JSON.
InvalidJson(String),
/// The target provider has no access_key configured.
MissingAccessKey(String),
/// A header value could not be constructed.
InvalidHeaderValue(String),
}
impl std::fmt::Display for RebuildError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
RebuildError::InvalidJson(e) => write!(f, "invalid JSON body: {}", e),
RebuildError::MissingAccessKey(name) => {
write!(f, "no access key configured for provider '{}'", name)
}
RebuildError::InvalidHeaderValue(header) => {
write!(f, "invalid header value for '{}'", header)
}
}
}
}
impl std::error::Error for RebuildError {}
/// Extended request context for retry tracking.
#[derive(Debug)]
pub struct RequestContext {
pub request_id: String,
pub attempted_providers: HashSet<String>,
pub retry_start_time: Option<Instant>,
pub attempt_number: u32,
/// Request-scoped Retry_After_State (when apply_to: "request")
pub request_retry_after_state: HashMap<String, Instant>,
/// Request-scoped Latency_Block_State (when apply_to: "request")
pub request_latency_block_state: HashMap<String, Instant>,
/// Request signature for tracking
pub request_signature: RequestSignature,
/// Accumulated errors from all attempts
pub errors: Vec<AttemptError>,
}
/// Bounded semaphore controlling the maximum number of concurrent in-flight
/// retry operations. Prevents OOM under high load by rejecting new retry
/// attempts when the limit is reached (fail-open: original request proceeds
/// without retry).
pub struct RetryGate {
pub semaphore: Arc<tokio::sync::Semaphore>,
}
impl RetryGate {
const DEFAULT_MAX_IN_FLIGHT: usize = 1000;
pub fn new(max_in_flight_retries: usize) -> Self {
Self {
semaphore: Arc::new(tokio::sync::Semaphore::new(max_in_flight_retries)),
}
}
pub fn try_acquire(&self) -> Option<tokio::sync::OwnedSemaphorePermit> {
self.semaphore.clone().try_acquire_owned().ok()
}
}
impl Default for RetryGate {
fn default() -> Self {
Self::new(Self::DEFAULT_MAX_IN_FLIGHT)
}
}
// ── Error Types ────────────────────────────────────────────────────────────
/// All retry attempts exhausted for a single provider's retry sequence.
#[derive(Debug)]
pub struct RetryExhaustedError {
/// All attempt errors accumulated during the retry sequence.
pub attempts: Vec<AttemptError>,
/// Maximum Retry-After value observed across all attempts (if any).
pub max_retry_after_seconds: Option<u64>,
/// Shortest remaining block duration among blocked candidates at exhaustion time.
pub shortest_remaining_block_seconds: Option<u64>,
/// Whether the retry budget (max_retry_duration_ms) was exceeded.
pub retry_budget_exhausted: bool,
}
/// All providers (including fallbacks) exhausted.
#[derive(Debug)]
pub struct AllProvidersExhaustedError {
/// Shortest remaining block duration among blocked candidates.
pub shortest_remaining_block_seconds: Option<u64>,
}
// ── Validation Types ───────────────────────────────────────────────────────
/// Configuration validation errors that prevent gateway startup.
#[derive(Debug, Clone, PartialEq)]
pub enum ValidationError {
/// Backoff section present without required `apply_to` field.
BackoffMissingApplyTo { model: String },
/// `min_triggers > 1` without `trigger_window_seconds`.
LatencyMissingTriggerWindow { model: String },
/// Invalid strategy value.
InvalidStrategy { model: String, value: String },
/// Invalid `apply_to` value.
InvalidApplyTo { model: String, value: String },
/// Invalid `scope` value.
InvalidScope { model: String, value: String },
/// Status code outside 100599.
StatusCodeOutOfRange { model: String, code: u16 },
/// Range with start > end.
StatusCodeRangeInverted { model: String, range: String },
/// Invalid status code range format.
StatusCodeRangeInvalid { model: String, range: String },
/// `threshold_ms`, `block_duration_seconds`, `max_retry_after_seconds`,
/// `max_retry_duration_ms`, or `base_ms` not positive.
NonPositiveValue { model: String, field: String },
/// `trigger_window_seconds` not positive when specified.
NonPositiveTriggerWindow { model: String },
/// `max_ms` ≤ `base_ms` in backoff config.
MaxMsNotGreaterThanBaseMs {
model: String,
base_ms: u64,
max_ms: u64,
},
/// `max_attempts` is negative (represented as u32, so this catches zero if needed).
InvalidMaxAttempts { model: String, value: String },
/// Fallback model string is empty or doesn't contain a "/" separator.
InvalidFallbackModel { model: String, fallback: String },
}
/// Configuration validation warnings (gateway starts, warning logged).
#[derive(Debug, Clone, PartialEq)]
pub enum ValidationWarning {
/// Single provider with failover strategy.
SingleProviderWithFailover { model: String, strategy: String },
/// Provider-scope Retry-After with same_model strategy.
ProviderScopeWithSameModel { model: String },
/// Backoff apply_to mismatch with default strategy.
BackoffApplyToMismatch {
model: String,
apply_to: String,
strategy: String,
},
/// Latency scope/strategy mismatch.
LatencyScopeStrategyMismatch { model: String },
/// Aggressive latency threshold (< 1000ms).
AggressiveLatencyThreshold { model: String, threshold_ms: u64 },
/// Fallback model not in Provider_List.
FallbackModelNotInProviderList { model: String, fallback: String },
/// Overlapping status codes across on_status_codes entries.
OverlappingStatusCodes { model: String, code: u16 },
}
#[cfg(test)]
mod tests {
use super::*;
use crate::configuration::{LlmProvider, LlmProviderType};
use bytes::Bytes;
use hyper::header::{HeaderMap, HeaderValue, AUTHORIZATION};
use proptest::prelude::*;
fn make_provider(name: &str, interface: LlmProviderType, key: Option<&str>) -> LlmProvider {
LlmProvider {
name: name.to_string(),
provider_interface: interface,
access_key: key.map(|k| k.to_string()),
model: Some(name.to_string()),
default: None,
stream: None,
endpoint: None,
port: None,
rate_limits: None,
usage: None,
cluster_name: None,
base_url_path_prefix: None,
internal: None,
passthrough_auth: None,
retry_policy: None,
headers: None,
}
}
// ── RequestSignature tests ─────────────────────────────────────────
#[test]
fn test_request_signature_computes_hash() {
let body = b"hello world";
let headers = HeaderMap::new();
let sig = RequestSignature::new(body, &headers, false, "openai/gpt-4o".to_string());
// SHA-256 of "hello world" is deterministic
let mut hasher = Sha256::new();
hasher.update(b"hello world");
let expected: [u8; 32] = hasher.finalize().into();
assert_eq!(sig.body_hash, expected);
assert!(!sig.streaming);
assert_eq!(sig.original_model, "openai/gpt-4o");
}
#[test]
fn test_request_signature_preserves_headers() {
let mut headers = HeaderMap::new();
headers.insert("x-custom", HeaderValue::from_static("value"));
let sig = RequestSignature::new(b"body", &headers, true, "model".to_string());
assert_eq!(sig.headers.get("x-custom").unwrap(), "value");
assert!(sig.streaming);
}
#[test]
fn test_request_signature_different_bodies_different_hashes() {
let headers = HeaderMap::new();
let sig1 = RequestSignature::new(b"body1", &headers, false, "m".to_string());
let sig2 = RequestSignature::new(b"body2", &headers, false, "m".to_string());
assert_ne!(sig1.body_hash, sig2.body_hash);
}
// ── RetryGate tests ────────────────────────────────────────────────
#[test]
fn test_retry_gate_default_permits() {
let gate = RetryGate::default();
// Should be able to acquire at least one permit
assert!(gate.try_acquire().is_some());
}
#[test]
fn test_retry_gate_exhaustion() {
let gate = RetryGate::new(1);
let permit = gate.try_acquire();
assert!(permit.is_some());
// Second acquire should fail (only 1 permit)
assert!(gate.try_acquire().is_none());
// Drop permit, should be able to acquire again
drop(permit);
assert!(gate.try_acquire().is_some());
}
#[test]
fn test_retry_gate_custom_capacity() {
let gate = RetryGate::new(3);
let _p1 = gate.try_acquire().unwrap();
let _p2 = gate.try_acquire().unwrap();
let _p3 = gate.try_acquire().unwrap();
assert!(gate.try_acquire().is_none());
}
// ── rebuild_request_for_provider tests ─────────────────────────────
#[test]
fn test_rebuild_updates_model_field() {
let body = Bytes::from(r#"{"model":"gpt-4o","messages":[]}"#);
let headers = HeaderMap::new();
let provider = make_provider(
"openai/gpt-4o-mini",
LlmProviderType::OpenAI,
Some("sk-test"),
);
let (new_body, _) = rebuild_request_for_provider(&body, &provider, &headers).unwrap();
let json: serde_json::Value = serde_json::from_slice(&new_body).unwrap();
assert_eq!(json["model"], "gpt-4o-mini");
}
#[test]
fn test_rebuild_preserves_other_fields() {
let body = Bytes::from(
r#"{"model":"gpt-4o","messages":[{"role":"user","content":"hi"}],"temperature":0.7}"#,
);
let headers = HeaderMap::new();
let provider = make_provider(
"openai/gpt-4o-mini",
LlmProviderType::OpenAI,
Some("sk-test"),
);
let (new_body, _) = rebuild_request_for_provider(&body, &provider, &headers).unwrap();
let json: serde_json::Value = serde_json::from_slice(&new_body).unwrap();
assert_eq!(json["messages"][0]["role"], "user");
assert_eq!(json["messages"][0]["content"], "hi");
assert_eq!(json["temperature"], 0.7);
}
#[test]
fn test_rebuild_sets_openai_auth() {
let body = Bytes::from(r#"{"model":"old"}"#);
let mut headers = HeaderMap::new();
headers.insert(AUTHORIZATION, HeaderValue::from_static("Bearer old-key"));
let provider = make_provider("openai/gpt-4o", LlmProviderType::OpenAI, Some("sk-new"));
let (_, new_headers) = rebuild_request_for_provider(&body, &provider, &headers).unwrap();
assert_eq!(
new_headers.get(AUTHORIZATION).unwrap().to_str().unwrap(),
"Bearer sk-new"
);
assert!(new_headers.get("x-api-key").is_none());
}
#[test]
fn test_rebuild_sets_anthropic_auth() {
let body = Bytes::from(r#"{"model":"old"}"#);
let mut headers = HeaderMap::new();
headers.insert(AUTHORIZATION, HeaderValue::from_static("Bearer old-key"));
let provider = make_provider(
"anthropic/claude-3-5-sonnet",
LlmProviderType::Anthropic,
Some("ant-key"),
);
let (_, new_headers) = rebuild_request_for_provider(&body, &provider, &headers).unwrap();
// Anthropic uses x-api-key, not Authorization
assert!(new_headers.get(AUTHORIZATION).is_none());
assert_eq!(
new_headers.get("x-api-key").unwrap().to_str().unwrap(),
"ant-key"
);
assert_eq!(
new_headers
.get("anthropic-version")
.unwrap()
.to_str()
.unwrap(),
"2023-06-01"
);
}
#[test]
fn test_rebuild_sanitizes_old_auth_headers() {
let body = Bytes::from(r#"{"model":"old"}"#);
let mut headers = HeaderMap::new();
headers.insert(AUTHORIZATION, HeaderValue::from_static("Bearer old-key"));
headers.insert("x-api-key", HeaderValue::from_static("old-api-key"));
headers.insert("anthropic-version", HeaderValue::from_static("old-version"));
headers.insert("x-custom", HeaderValue::from_static("keep-me"));
let provider = make_provider("openai/gpt-4o", LlmProviderType::OpenAI, Some("sk-new"));
let (_, new_headers) = rebuild_request_for_provider(&body, &provider, &headers).unwrap();
// Old x-api-key and anthropic-version should be removed
assert!(new_headers.get("anthropic-version").is_none());
// New auth should be set
assert_eq!(
new_headers.get(AUTHORIZATION).unwrap().to_str().unwrap(),
"Bearer sk-new"
);
// Custom headers preserved
assert_eq!(
new_headers.get("x-custom").unwrap().to_str().unwrap(),
"keep-me"
);
}
#[test]
fn test_rebuild_passthrough_auth_skips_credentials() {
let body = Bytes::from(r#"{"model":"old"}"#);
let mut headers = HeaderMap::new();
headers.insert(AUTHORIZATION, HeaderValue::from_static("Bearer client-key"));
let mut provider = make_provider("openai/gpt-4o", LlmProviderType::OpenAI, Some("sk-new"));
provider.passthrough_auth = Some(true);
let (_, new_headers) = rebuild_request_for_provider(&body, &provider, &headers).unwrap();
// Auth headers are sanitized, and passthrough_auth means no new ones are set
assert!(new_headers.get(AUTHORIZATION).is_none());
}
#[test]
fn test_rebuild_missing_access_key_errors() {
let body = Bytes::from(r#"{"model":"old"}"#);
let headers = HeaderMap::new();
let provider = make_provider("openai/gpt-4o", LlmProviderType::OpenAI, None);
let result = rebuild_request_for_provider(&body, &provider, &headers);
assert!(matches!(result, Err(RebuildError::MissingAccessKey(_))));
}
#[test]
fn test_rebuild_invalid_json_errors() {
let body = Bytes::from("not json");
let headers = HeaderMap::new();
let provider = make_provider("openai/gpt-4o", LlmProviderType::OpenAI, Some("key"));
let result = rebuild_request_for_provider(&body, &provider, &headers);
assert!(matches!(result, Err(RebuildError::InvalidJson(_))));
}
#[test]
fn test_rebuild_model_without_provider_prefix() {
let body = Bytes::from(r#"{"model":"old"}"#);
let headers = HeaderMap::new();
let mut provider = make_provider("gpt-4o", LlmProviderType::OpenAI, Some("key"));
provider.model = Some("gpt-4o".to_string());
let (new_body, _) = rebuild_request_for_provider(&body, &provider, &headers).unwrap();
let json: serde_json::Value = serde_json::from_slice(&new_body).unwrap();
// No prefix to strip, model name used as-is
assert_eq!(json["model"], "gpt-4o");
}
// --- Proptest strategies ---
fn arb_provider_type() -> impl Strategy<Value = LlmProviderType> {
prop_oneof![
Just(LlmProviderType::OpenAI),
Just(LlmProviderType::Anthropic),
Just(LlmProviderType::Gemini),
Just(LlmProviderType::Deepseek),
]
}
fn arb_model_name() -> impl Strategy<Value = String> {
prop_oneof![
Just("openai/gpt-4o".to_string()),
Just("openai/gpt-4o-mini".to_string()),
Just("anthropic/claude-3-5-sonnet".to_string()),
Just("gemini/gemini-pro".to_string()),
Just("deepseek/deepseek-chat".to_string()),
]
}
fn arb_target_provider() -> impl Strategy<Value = LlmProvider> {
(arb_model_name(), arb_provider_type())
.prop_map(|(model, iface)| make_provider(&model, iface, Some("test-key-123")))
}
fn arb_message_content() -> impl Strategy<Value = String> {
"[a-zA-Z0-9 ]{1,50}"
}
fn arb_messages() -> impl Strategy<Value = Vec<serde_json::Value>> {
prop::collection::vec(
(
prop_oneof![Just("user"), Just("assistant"), Just("system")],
arb_message_content(),
)
.prop_map(|(role, content)| serde_json::json!({"role": role, "content": content})),
1..5,
)
}
fn arb_json_body() -> impl Strategy<Value = serde_json::Value> {
(
arb_model_name(),
arb_messages(),
prop::option::of(0.0f64..2.0),
prop::option::of(1u32..4096),
proptest::bool::ANY,
)
.prop_map(|(model, messages, temperature, max_tokens, stream)| {
let model_only = model.split('/').nth(1).unwrap_or(&model);
let mut obj = serde_json::json!({
"model": model_only,
"messages": messages,
});
if let Some(t) = temperature {
obj["temperature"] = serde_json::json!(t);
}
if let Some(mt) = max_tokens {
obj["max_tokens"] = serde_json::json!(mt);
}
if stream {
obj["stream"] = serde_json::json!(true);
}
obj
})
}
fn arb_custom_headers() -> impl Strategy<Value = Vec<(String, String)>> {
prop::collection::vec(
(
prop_oneof![
Just("x-request-id".to_string()),
Just("x-custom-header".to_string()),
Just("x-trace-id".to_string()),
Just("content-type".to_string()),
],
"[a-zA-Z0-9-]{1,30}",
),
0..4,
)
}
// Feature: retry-on-ratelimit, Property 14: Request Preservation Across Retries
// **Validates: Requirements 5.1, 5.2, 5.3, 5.4, 5.5, 3.15**
proptest! {
#![proptest_config(ProptestConfig::with_cases(100))]
/// Property 14 The original body bytes are unchanged after rebuild (body is passed by reference).
/// The rebuilt body has the model field updated to the target provider's model.
/// All other JSON fields are preserved. The RequestSignature hash matches the original body hash.
/// Custom headers are preserved while auth headers are sanitized.
#[test]
fn prop_request_preservation_across_retries(
json_body in arb_json_body(),
custom_headers in arb_custom_headers(),
streaming in proptest::bool::ANY,
target_provider in arb_target_provider(),
) {
let body_bytes = serde_json::to_vec(&json_body).unwrap();
let body = Bytes::from(body_bytes.clone());
// Build original headers with custom + auth headers
let mut original_headers = HeaderMap::new();
for (name, value) in &custom_headers {
if let (Ok(hn), Ok(hv)) = (
hyper::header::HeaderName::from_bytes(name.as_bytes()),
HeaderValue::from_str(value),
) {
original_headers.insert(hn, hv);
}
}
// Add auth headers that should be sanitized
original_headers.insert(AUTHORIZATION, HeaderValue::from_static("Bearer old-secret"));
original_headers.insert("x-api-key", HeaderValue::from_static("old-api-key"));
let original_model = json_body["model"].as_str().unwrap_or("unknown").to_string();
// Create RequestSignature from original body
let sig = RequestSignature::new(&body, &original_headers, streaming, original_model.clone());
// Assert: body bytes are unchanged (passed by reference, not modified)
prop_assert_eq!(&body[..], &body_bytes[..], "Original body bytes must be unchanged");
// Assert: RequestSignature hash matches a fresh hash of the same body
let mut hasher = Sha256::new();
hasher.update(&body);
let expected_hash: [u8; 32] = hasher.finalize().into();
prop_assert_eq!(sig.body_hash, expected_hash, "RequestSignature hash must match original body hash");
// Assert: streaming flag preserved
prop_assert_eq!(sig.streaming, streaming, "Streaming flag must be preserved in signature");
// Rebuild for target provider
let result = rebuild_request_for_provider(&body, &target_provider, &original_headers);
prop_assert!(result.is_ok(), "rebuild_request_for_provider should succeed for valid JSON body");
let (rebuilt_body, rebuilt_headers) = result.unwrap();
// Parse rebuilt body
let rebuilt_json: serde_json::Value = serde_json::from_slice(&rebuilt_body).unwrap();
// Assert: model field updated to target provider's model (without prefix)
let target_model = target_provider.model.as_deref().unwrap_or(&target_provider.name);
let expected_model = target_model.split_once('/').map(|(_, m)| m).unwrap_or(target_model);
prop_assert_eq!(
rebuilt_json["model"].as_str().unwrap(),
expected_model,
"Model field must be updated to target provider's model"
);
// Assert: messages array preserved
prop_assert_eq!(
&rebuilt_json["messages"],
&json_body["messages"],
"Messages array must be preserved across rebuild"
);
// Assert: other JSON fields preserved (temperature, max_tokens, stream)
// The rebuild function does a JSON round-trip (deserialize → modify model → serialize),
// so we compare against a round-tripped version of the original to account for
// any f64 precision changes inherent to JSON serialization.
let original_round_tripped: serde_json::Value = serde_json::from_slice(
&serde_json::to_vec(&json_body).unwrap()
).unwrap();
for key in ["temperature", "max_tokens", "stream"] {
if let Some(original_val) = original_round_tripped.get(key) {
prop_assert_eq!(
&rebuilt_json[key],
original_val,
"Field '{}' must be preserved across rebuild",
key
);
}
}
// Assert: custom headers preserved (non-auth headers)
// Note: HeaderMap::insert overwrites, so only the last value for each name survives
let mut last_custom: std::collections::HashMap<String, String> = std::collections::HashMap::new();
for (name, value) in &custom_headers {
let lower = name.to_lowercase();
if lower == "authorization" || lower == "x-api-key" || lower == "anthropic-version" {
continue;
}
last_custom.insert(lower, value.clone());
}
for (name, value) in &last_custom {
if let Some(hv) = rebuilt_headers.get(name.as_str()) {
prop_assert_eq!(
hv.to_str().unwrap(),
value.as_str(),
"Custom header '{}' must be preserved",
name
);
}
}
// Assert: old auth headers are sanitized (not leaked to target provider)
// The old "Bearer old-secret" and "old-api-key" should NOT appear
if let Some(auth) = rebuilt_headers.get(AUTHORIZATION) {
prop_assert_ne!(
auth.to_str().unwrap(),
"Bearer old-secret",
"Old authorization header must be sanitized"
);
}
if let Some(api_key) = rebuilt_headers.get("x-api-key") {
prop_assert_ne!(
api_key.to_str().unwrap(),
"old-api-key",
"Old x-api-key header must be sanitized"
);
}
// Assert: original body is still unchanged after rebuild
prop_assert_eq!(&body[..], &body_bytes[..], "Original body bytes must remain unchanged after rebuild");
}
}
}

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,510 @@
use std::time::{Duration, Instant};
use dashmap::DashMap;
use log::info;
use crate::configuration::{extract_provider, BlockScope};
/// Thread-safe global state manager for Retry-After header blocking.
///
/// This manager handles ONLY global state (`apply_to: "global"`).
/// Request-scoped state (`apply_to: "request"`) is stored in
/// `RequestContext.request_retry_after_state` and managed by the orchestrator.
///
/// Entries use max-expiration semantics: if a new Retry-After value is recorded
/// for an identifier that already has an entry, the expiration is updated only
/// if the new expiration is later than the existing one.
pub struct RetryAfterStateManager {
/// Global state: identifier (model ID or provider prefix) -> expiration timestamp
global_state: DashMap<String, Instant>,
}
impl RetryAfterStateManager {
pub fn new() -> Self {
Self {
global_state: DashMap::new(),
}
}
/// Record a Retry-After header, creating or updating the block entry.
///
/// The `retry_after_seconds` value is capped at `max_retry_after_seconds`.
/// Uses max-expiration semantics: if an entry already exists, the expiration
/// is updated only if the new expiration is later.
pub fn record(&self, identifier: &str, retry_after_seconds: u64, max_retry_after_seconds: u64) {
let capped = retry_after_seconds.min(max_retry_after_seconds);
let new_expiration = Instant::now() + Duration::from_secs(capped);
self.global_state
.entry(identifier.to_string())
.and_modify(|existing| {
if new_expiration > *existing {
*existing = new_expiration;
}
})
.or_insert(new_expiration);
}
/// Check if an identifier is currently blocked.
///
/// Lazily cleans up expired entries.
pub fn is_blocked(&self, identifier: &str) -> bool {
if let Some(entry) = self.global_state.get(identifier) {
if Instant::now() < *entry {
return true;
}
// Entry expired — drop the read guard before removing
drop(entry);
self.global_state.remove(identifier);
info!("Retry_After_State expired: identifier={}", identifier);
}
false
}
/// Get remaining block duration for an identifier, if blocked.
///
/// Returns `None` if the identifier is not blocked or the entry has expired.
/// Lazily cleans up expired entries.
pub fn remaining_block_duration(&self, identifier: &str) -> Option<Duration> {
if let Some(entry) = self.global_state.get(identifier) {
let now = Instant::now();
if now < *entry {
return Some(*entry - now);
}
// Entry expired — drop the read guard before removing
drop(entry);
self.global_state.remove(identifier);
info!("Retry_After_State expired: identifier={}", identifier);
}
None
}
/// Check if a model is blocked, considering scope (model or provider).
///
/// - `BlockScope::Model`: checks if the exact `model_id` is blocked.
/// - `BlockScope::Provider`: extracts the provider prefix from `model_id`
/// and checks if that prefix is blocked.
pub fn is_model_blocked(&self, model_id: &str, scope: BlockScope) -> bool {
match scope {
BlockScope::Model => self.is_blocked(model_id),
BlockScope::Provider => {
let provider = extract_provider(model_id);
self.is_blocked(provider)
}
}
}
}
impl Default for RetryAfterStateManager {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::thread;
use std::time::Duration;
#[test]
fn test_new_manager_has_no_blocks() {
let mgr = RetryAfterStateManager::new();
assert!(!mgr.is_blocked("openai/gpt-4o"));
assert!(mgr.remaining_block_duration("openai/gpt-4o").is_none());
}
#[test]
fn test_record_and_is_blocked() {
let mgr = RetryAfterStateManager::new();
mgr.record("openai/gpt-4o", 60, 300);
assert!(mgr.is_blocked("openai/gpt-4o"));
assert!(!mgr.is_blocked("anthropic/claude"));
}
#[test]
fn test_record_caps_at_max() {
let mgr = RetryAfterStateManager::new();
// Retry-After of 600 seconds, but max is 300
mgr.record("openai/gpt-4o", 600, 300);
let remaining = mgr.remaining_block_duration("openai/gpt-4o").unwrap();
// Should be capped at ~300 seconds (allow some tolerance)
assert!(remaining <= Duration::from_secs(301));
assert!(remaining > Duration::from_secs(298));
}
#[test]
fn test_remaining_block_duration() {
let mgr = RetryAfterStateManager::new();
mgr.record("openai/gpt-4o", 10, 300);
let remaining = mgr.remaining_block_duration("openai/gpt-4o").unwrap();
assert!(remaining <= Duration::from_secs(11));
assert!(remaining > Duration::from_secs(8));
}
#[test]
fn test_expired_entry_cleaned_up_on_is_blocked() {
let mgr = RetryAfterStateManager::new();
// Record with 0 seconds — effectively expires immediately
mgr.record("openai/gpt-4o", 0, 300);
// Sleep briefly to ensure expiration
thread::sleep(Duration::from_millis(10));
assert!(!mgr.is_blocked("openai/gpt-4o"));
}
#[test]
fn test_expired_entry_cleaned_up_on_remaining() {
let mgr = RetryAfterStateManager::new();
mgr.record("openai/gpt-4o", 0, 300);
thread::sleep(Duration::from_millis(10));
assert!(mgr.remaining_block_duration("openai/gpt-4o").is_none());
}
#[test]
fn test_max_expiration_semantics_longer_wins() {
let mgr = RetryAfterStateManager::new();
mgr.record("openai/gpt-4o", 10, 300);
let first_remaining = mgr.remaining_block_duration("openai/gpt-4o").unwrap();
// Record a longer duration — should update
mgr.record("openai/gpt-4o", 60, 300);
let second_remaining = mgr.remaining_block_duration("openai/gpt-4o").unwrap();
assert!(second_remaining > first_remaining);
}
#[test]
fn test_max_expiration_semantics_shorter_does_not_overwrite() {
let mgr = RetryAfterStateManager::new();
mgr.record("openai/gpt-4o", 60, 300);
let first_remaining = mgr.remaining_block_duration("openai/gpt-4o").unwrap();
// Record a shorter duration — should NOT overwrite
mgr.record("openai/gpt-4o", 5, 300);
let second_remaining = mgr.remaining_block_duration("openai/gpt-4o").unwrap();
// The remaining should still be close to the original 60s
assert!(second_remaining > Duration::from_secs(50));
// Allow small timing variance
let diff = if first_remaining > second_remaining {
first_remaining - second_remaining
} else {
second_remaining - first_remaining
};
assert!(diff < Duration::from_secs(2));
}
#[test]
fn test_is_model_blocked_model_scope() {
let mgr = RetryAfterStateManager::new();
mgr.record("openai/gpt-4o", 60, 300);
assert!(mgr.is_model_blocked("openai/gpt-4o", BlockScope::Model));
assert!(!mgr.is_model_blocked("openai/gpt-4o-mini", BlockScope::Model));
}
#[test]
fn test_is_model_blocked_provider_scope() {
let mgr = RetryAfterStateManager::new();
// Block at provider level by recording with provider prefix
mgr.record("openai", 60, 300);
// Both openai models should be blocked
assert!(mgr.is_model_blocked("openai/gpt-4o", BlockScope::Provider));
assert!(mgr.is_model_blocked("openai/gpt-4o-mini", BlockScope::Provider));
// Anthropic should not be blocked
assert!(!mgr.is_model_blocked("anthropic/claude", BlockScope::Provider));
}
#[test]
fn test_model_scope_does_not_block_other_models() {
let mgr = RetryAfterStateManager::new();
mgr.record("openai/gpt-4o", 60, 300);
// Model scope: only exact match is blocked
assert!(mgr.is_model_blocked("openai/gpt-4o", BlockScope::Model));
assert!(!mgr.is_model_blocked("openai/gpt-4o-mini", BlockScope::Model));
}
#[test]
fn test_multiple_identifiers_independent() {
let mgr = RetryAfterStateManager::new();
mgr.record("openai/gpt-4o", 60, 300);
mgr.record("anthropic/claude", 30, 300);
assert!(mgr.is_blocked("openai/gpt-4o"));
assert!(mgr.is_blocked("anthropic/claude"));
assert!(!mgr.is_blocked("azure/gpt-4o"));
}
#[test]
fn test_record_with_zero_seconds() {
let mgr = RetryAfterStateManager::new();
mgr.record("openai/gpt-4o", 0, 300);
// With 0 seconds, the entry expires at Instant::now() + 0,
// which is effectively immediately
thread::sleep(Duration::from_millis(5));
assert!(!mgr.is_blocked("openai/gpt-4o"));
}
#[test]
fn test_max_retry_after_seconds_zero_caps_to_zero() {
let mgr = RetryAfterStateManager::new();
// Even with retry_after_seconds=60, max=0 caps to 0
mgr.record("openai/gpt-4o", 60, 0);
thread::sleep(Duration::from_millis(5));
assert!(!mgr.is_blocked("openai/gpt-4o"));
}
#[test]
fn test_default_trait() {
let mgr = RetryAfterStateManager::default();
assert!(!mgr.is_blocked("anything"));
}
// --- Proptest strategies ---
use proptest::prelude::*;
fn arb_provider_prefix() -> impl Strategy<Value = String> {
prop_oneof![
Just("openai".to_string()),
Just("anthropic".to_string()),
Just("azure".to_string()),
Just("google".to_string()),
Just("cohere".to_string()),
]
}
fn arb_model_suffix() -> impl Strategy<Value = String> {
prop_oneof![
Just("gpt-4o".to_string()),
Just("gpt-4o-mini".to_string()),
Just("claude-3".to_string()),
Just("gemini-pro".to_string()),
]
}
fn arb_model_id() -> impl Strategy<Value = String> {
(arb_provider_prefix(), arb_model_suffix())
.prop_map(|(prefix, suffix)| format!("{}/{}", prefix, suffix))
}
fn arb_scope() -> impl Strategy<Value = BlockScope> {
prop_oneof![Just(BlockScope::Model), Just(BlockScope::Provider),]
}
// Feature: retry-on-ratelimit, Property 15: Retry_After_State Scope Behavior
// **Validates: Requirements 11.5, 11.6, 11.7, 11.8, 12.9, 12.10, 13.10, 13.11**
proptest! {
#![proptest_config(ProptestConfig::with_cases(100))]
/// Property 15 Case 1: Model scope blocks only the exact model_id.
#[test]
fn prop_model_scope_blocks_exact_model_only(
model_id in arb_model_id(),
other_model_id in arb_model_id(),
retry_after in 1u64..300,
) {
prop_assume!(model_id != other_model_id);
let mgr = RetryAfterStateManager::new();
// Record with the exact model_id (model scope records the full model ID)
mgr.record(&model_id, retry_after, 300);
// The exact model should be blocked
prop_assert!(
mgr.is_model_blocked(&model_id, BlockScope::Model),
"Model {} should be blocked with Model scope after recording",
model_id
);
// A different model should NOT be blocked (even if same provider)
prop_assert!(
!mgr.is_model_blocked(&other_model_id, BlockScope::Model),
"Model {} should NOT be blocked when {} was recorded with Model scope",
other_model_id, model_id
);
}
/// Property 15 Case 2: Provider scope blocks all models from the same provider.
#[test]
fn prop_provider_scope_blocks_all_same_provider_models(
provider in arb_provider_prefix(),
suffix1 in arb_model_suffix(),
suffix2 in arb_model_suffix(),
other_provider in arb_provider_prefix(),
other_suffix in arb_model_suffix(),
retry_after in 1u64..300,
) {
let model1 = format!("{}/{}", provider, suffix1);
let model2 = format!("{}/{}", provider, suffix2);
let other_model = format!("{}/{}", other_provider, other_suffix);
prop_assume!(provider != other_provider);
let mgr = RetryAfterStateManager::new();
// Record at provider level (provider scope records the provider prefix)
mgr.record(&provider, retry_after, 300);
// Both models from the same provider should be blocked
prop_assert!(
mgr.is_model_blocked(&model1, BlockScope::Provider),
"Model {} should be blocked with Provider scope after recording provider {}",
model1, provider
);
prop_assert!(
mgr.is_model_blocked(&model2, BlockScope::Provider),
"Model {} should be blocked with Provider scope after recording provider {}",
model2, provider
);
// Model from a different provider should NOT be blocked
prop_assert!(
!mgr.is_model_blocked(&other_model, BlockScope::Provider),
"Model {} should NOT be blocked when provider {} was recorded",
other_model, provider
);
}
/// Property 15 Case 3: Global state is visible across different "requests"
/// (same manager instance is shared).
#[test]
fn prop_global_state_shared_across_requests(
model_id in arb_model_id(),
scope in arb_scope(),
retry_after in 1u64..300,
) {
let mgr = RetryAfterStateManager::new();
// Determine the identifier to record based on scope
let identifier = match scope {
BlockScope::Model => model_id.clone(),
BlockScope::Provider => extract_provider(&model_id).to_string(),
};
mgr.record(&identifier, retry_after, 300);
// Simulate "different requests" by checking from the same manager instance.
// Global state means any check against the same manager sees the block.
// Check 1 (simulating request A)
let blocked_a = mgr.is_model_blocked(&model_id, scope);
// Check 2 (simulating request B)
let blocked_b = mgr.is_model_blocked(&model_id, scope);
prop_assert!(
blocked_a && blocked_b,
"Global state should be visible to all requests: request_a={}, request_b={}",
blocked_a, blocked_b
);
}
/// Property 15 Case 4: Request-scoped state (HashMap) is isolated per request.
/// Two separate HashMaps don't share state.
#[test]
fn prop_request_scoped_state_isolated(
model_id in arb_model_id(),
retry_after in 1u64..300,
) {
use std::collections::HashMap;
use std::time::Instant;
// Simulate request-scoped state using separate HashMaps
// (as RequestContext.request_retry_after_state would be)
let mut request_a_state: HashMap<String, Instant> = HashMap::new();
let mut request_b_state: HashMap<String, Instant> = HashMap::new();
// Request A records a Retry-After entry
let expiration = Instant::now() + Duration::from_secs(retry_after);
request_a_state.insert(model_id.clone(), expiration);
// Request A should see the block
let a_blocked = request_a_state
.get(&model_id)
.map_or(false, |exp| Instant::now() < *exp);
// Request B should NOT see the block (separate HashMap)
let b_blocked = request_b_state
.get(&model_id)
.map_or(false, |exp| Instant::now() < *exp);
prop_assert!(
a_blocked,
"Request A should see its own block for {}",
model_id
);
prop_assert!(
!b_blocked,
"Request B should NOT see Request A's block for {}",
model_id
);
// Recording in request B should not affect request A
let expiration_b = Instant::now() + Duration::from_secs(retry_after);
request_b_state.insert(model_id.clone(), expiration_b);
// Both should now be blocked independently
let a_still_blocked = request_a_state
.get(&model_id)
.map_or(false, |exp| Instant::now() < *exp);
let b_now_blocked = request_b_state
.get(&model_id)
.map_or(false, |exp| Instant::now() < *exp);
prop_assert!(a_still_blocked, "Request A should still be blocked");
prop_assert!(b_now_blocked, "Request B should now be blocked independently");
}
}
// Feature: retry-on-ratelimit, Property 16: Retry_After_State Max Expiration Update
// **Validates: Requirements 12.11**
proptest! {
#![proptest_config(ProptestConfig::with_cases(100))]
/// Property 16: Recording multiple Retry-After values for the same identifier
/// should result in the expiration reflecting the maximum value, not the most recent.
#[test]
fn prop_max_expiration_update(
identifier in arb_model_id(),
// Generate 2..=10 Retry-After values, each between 1 and 600 seconds
retry_after_values in prop::collection::vec(1u64..=600, 2..=10),
max_cap in 300u64..=600,
) {
let mgr = RetryAfterStateManager::new();
// Record all values for the same identifier
for &val in &retry_after_values {
mgr.record(&identifier, val, max_cap);
}
// The effective maximum is the max of all capped values
let effective_max = retry_after_values
.iter()
.map(|&v| v.min(max_cap))
.max()
.unwrap();
// The remaining block duration should be close to the effective maximum
let remaining = mgr.remaining_block_duration(&identifier);
prop_assert!(
remaining.is_some(),
"Identifier {} should still be blocked after recording {} values (effective_max={}s)",
identifier, retry_after_values.len(), effective_max
);
let remaining_secs = remaining.unwrap().as_secs();
// The remaining duration should be within a reasonable tolerance of the
// effective maximum (allow up to 2 seconds for test execution time).
// It must be at least (effective_max - 2) to prove the max won.
prop_assert!(
remaining_secs >= effective_max.saturating_sub(2),
"Remaining {}s should reflect the max ({}s), not a smaller value. Values: {:?}",
remaining_secs, effective_max, retry_after_values
);
// It should not exceed the effective max (plus small tolerance for timing)
prop_assert!(
remaining_secs <= effective_max + 1,
"Remaining {}s should not exceed effective max {}s + tolerance. Values: {:?}",
remaining_secs, effective_max, retry_after_values
);
}
}
}

File diff suppressed because it is too large Load diff

View file

@ -1023,8 +1023,15 @@ impl HttpContext for StreamContext {
}
};
// Set the resolved model using the trait method
deserialized_client_request.set_model(resolved_model.clone());
// Set the resolved model using the trait method.
// Strip provider prefix (e.g., "custom-aws/claude-opus-4-6" -> "claude-opus-4-6")
// so the upstream API receives only the model name it recognizes.
let upstream_model = if let Some((_prefix, model_only)) = resolved_model.split_once('/') {
model_only.to_string()
} else {
resolved_model.clone()
};
deserialized_client_request.set_model(upstream_model.clone());
// Extract user message for tracing
self.user_message = deserialized_client_request.get_recent_user_message();
@ -1056,82 +1063,93 @@ impl HttpContext for StreamContext {
return Action::Continue;
}
// Convert chat completion request to llm provider specific request using provider interface
let serialized_body_bytes_upstream = match self.resolved_api.as_ref() {
Some(upstream) => {
info!(
"request_id={}: upstream transform, client_api={:?} -> upstream_api={:?}",
self.request_identifier(),
self.client_api,
upstream
);
match ProviderRequestType::try_from((deserialized_client_request, upstream)) {
Ok(mut request) => {
if let Err(e) =
request.normalize_for_upstream(self.get_provider_id(), upstream)
{
warn!(
"request_id={}: normalize_for_upstream failed: {}",
self.request_identifier(),
e
);
// Preserve original body bytes for prompt cache compatibility.
// Only replace the "model" field value at the byte level instead of
// deserializing + re-serializing, which destroys key order, whitespace,
// and unknown fields — breaking prompt cache prefix matching.
// Use upstream_model (prefix-stripped) so the upstream API receives
// only the model name it recognizes.
let original_model = model_requested.as_str();
let serialized_body_bytes_upstream = if original_model != upstream_model.as_str() {
match replace_json_model_value(&body_bytes, original_model, &upstream_model) {
Some(patched) => {
debug!(
"request_id={}: byte-level model replacement '{}' -> '{}'",
self.request_identifier(),
original_model,
upstream_model
);
patched
}
None => {
// Fallback: full re-serialization if byte-level replacement fails
warn!(
"request_id={}: byte-level model replacement failed, falling back to re-serialization",
self.request_identifier()
);
match self.resolved_api.as_ref() {
Some(upstream) => {
match ProviderRequestType::try_from((
deserialized_client_request,
upstream,
)) {
Ok(mut request) => {
if let Err(e) = request
.normalize_for_upstream(self.get_provider_id(), upstream)
{
warn!(
"request_id={}: normalize_for_upstream failed: {}",
self.request_identifier(),
e
);
self.send_server_error(
ServerError::LogicError(e.message),
Some(StatusCode::BAD_REQUEST),
);
return Action::Pause;
}
match request.to_bytes() {
Ok(bytes) => bytes,
Err(e) => {
self.send_server_error(
ServerError::LogicError(format!(
"Request serialization error: {}",
e
)),
Some(StatusCode::BAD_REQUEST),
);
return Action::Pause;
}
}
}
Err(e) => {
self.send_server_error(
ServerError::LogicError(format!(
"Provider request error: {}",
e
)),
Some(StatusCode::BAD_REQUEST),
);
return Action::Pause;
}
}
}
None => {
self.send_server_error(
ServerError::LogicError(e.message),
ServerError::LogicError("No upstream API resolved".into()),
Some(StatusCode::BAD_REQUEST),
);
return Action::Pause;
}
debug!(
"request_id={}: upstream request payload: {}",
self.request_identifier(),
String::from_utf8_lossy(&request.to_bytes().unwrap_or_default())
);
match request.to_bytes() {
Ok(bytes) => bytes,
Err(e) => {
warn!(
"request_id={}: failed to serialize request body: {}",
self.request_identifier(),
e
);
self.send_server_error(
ServerError::LogicError(format!(
"Request serialization error: {}",
e
)),
Some(StatusCode::BAD_REQUEST),
);
return Action::Pause;
}
}
}
Err(e) => {
warn!(
"request_id={}: failed to create provider request: {}",
self.request_identifier(),
e
);
self.send_server_error(
ServerError::LogicError(format!("Provider request error: {}", e)),
Some(StatusCode::BAD_REQUEST),
);
return Action::Pause;
}
}
}
None => {
warn!(
"request_id={}: no upstream api resolved",
self.request_identifier()
);
self.send_server_error(
ServerError::LogicError("No upstream API resolved".into()),
Some(StatusCode::BAD_REQUEST),
);
return Action::Pause;
}
} else {
debug!(
"request_id={}: model unchanged, passing original body through",
self.request_identifier()
);
body_bytes.clone()
};
self.set_http_request_body(0, body_size, &serialized_body_bytes_upstream);
@ -1260,6 +1278,80 @@ impl HttpContext for StreamContext {
}
}
/// Replace the value of the top-level `"model"` key in a JSON byte slice
/// without re-serializing. Returns `Some(new_bytes)` on success, `None` if the
/// pattern wasn't found (caller should fall back to full re-serialization).
///
/// This is intentionally simple and does NOT use regex (unavailable in WASM).
/// It scans for `"model"` followed by `:` and a quoted string value, then
/// splices in the new model name. Works for the common case where model values
/// are simple strings like `"gpt-4o"` without JSON escapes.
fn replace_json_model_value(body: &[u8], old_model: &str, new_model: &str) -> Option<Vec<u8>> {
// Build the needle: `"model"` (we'll then skip whitespace + colon + whitespace + opening quote)
let model_key = b"\"model\"";
// Find the position of `"model"` key
let key_pos = find_bytes(body, model_key)?;
// After the key, skip whitespace, expect ':', skip whitespace, expect '"'
let mut pos = key_pos + model_key.len();
pos = skip_json_whitespace(body, pos);
if body.get(pos)? != &b':' {
return None;
}
pos += 1;
pos = skip_json_whitespace(body, pos);
if body.get(pos)? != &b'"' {
return None;
}
let _value_start_quote = pos; // position of the opening '"'
pos += 1;
// Find the closing quote (handle escaped quotes)
let value_content_start = pos;
loop {
let ch = *body.get(pos)?;
if ch == b'\\' {
pos += 2; // skip escaped char
continue;
}
if ch == b'"' {
break;
}
pos += 1;
}
let value_content_end = pos; // position of closing '"'
// Verify the current value matches old_model
let current_value = &body[value_content_start..value_content_end];
if current_value != old_model.as_bytes() {
return None;
}
// Build new body: everything before value content + new model + everything after
let mut result = Vec::with_capacity(body.len() + new_model.len() - old_model.len());
result.extend_from_slice(&body[..value_content_start]);
result.extend_from_slice(new_model.as_bytes());
result.extend_from_slice(&body[value_content_end..]);
Some(result)
}
/// Find first occurrence of `needle` in `haystack`.
fn find_bytes(haystack: &[u8], needle: &[u8]) -> Option<usize> {
if needle.is_empty() || needle.len() > haystack.len() {
return None;
}
(0..=haystack.len() - needle.len()).find(|&i| &haystack[i..i + needle.len()] == needle)
}
/// Skip JSON whitespace (space, tab, newline, carriage return).
fn skip_json_whitespace(data: &[u8], mut pos: usize) -> usize {
while pos < data.len() && matches!(data[pos], b' ' | b'\t' | b'\n' | b'\r') {
pos += 1;
}
pos
}
fn current_time_ns() -> u128 {
SystemTime::now()
.duration_since(UNIX_EPOCH)