diff --git a/crates/common/src/configuration.rs b/crates/common/src/configuration.rs index d08597dc..0bf7c2b6 100644 --- a/crates/common/src/configuration.rs +++ b/crates/common/src/configuration.rs @@ -710,3 +710,904 @@ impl From<&PromptTarget> for ChatCompletionTool { } } +#[cfg(test)] +mod test { + use pretty_assertions::assert_eq; + use std::fs; + + use super::{IntoModels, LlmProvider, LlmProviderType}; + use crate::api::open_ai::ToolType; + + use proptest::prelude::*; + + // ── Proptest Strategies for Retry Config Types ───────────────────────── + + fn arb_retry_strategy() -> impl Strategy { + prop_oneof![ + Just(super::RetryStrategy::SameModel), + Just(super::RetryStrategy::SameProvider), + Just(super::RetryStrategy::DifferentProvider), + ] + } + + fn arb_block_scope() -> impl Strategy { + prop_oneof![ + Just(super::BlockScope::Model), + Just(super::BlockScope::Provider), + ] + } + + fn arb_apply_to() -> impl Strategy { + prop_oneof![ + Just(super::ApplyTo::Global), + Just(super::ApplyTo::Request), + ] + } + + fn arb_backoff_apply_to() -> impl Strategy { + prop_oneof![ + Just(super::BackoffApplyTo::SameModel), + Just(super::BackoffApplyTo::SameProvider), + Just(super::BackoffApplyTo::Global), + ] + } + + fn arb_latency_measure() -> impl Strategy { + prop_oneof![ + Just(super::LatencyMeasure::Ttfb), + Just(super::LatencyMeasure::Total), + ] + } + + fn arb_status_code_entry() -> impl Strategy { + prop_oneof![ + (100u16..=599u16).prop_map(super::StatusCodeEntry::Single), + (100u16..=599u16) + .prop_flat_map(|start| (Just(start), start..=599u16)) + .prop_map(|(start, end)| super::StatusCodeEntry::Range(format!("{}-{}", start, end))), + ] + } + + fn arb_status_code_config() -> impl Strategy { + ( + prop::collection::vec(arb_status_code_entry(), 1..=3), + arb_retry_strategy(), + 1u32..=10u32, + ) + .prop_map(|(codes, strategy, max_attempts)| super::StatusCodeConfig { + codes, + strategy, + max_attempts, + }) + } + + fn arb_timeout_retry_config() -> impl Strategy { 
+ (arb_retry_strategy(), 1u32..=10u32).prop_map(|(strategy, max_attempts)| { + super::TimeoutRetryConfig { + strategy, + max_attempts, + } + }) + } + + fn arb_backoff_config() -> impl Strategy { + ( + arb_backoff_apply_to(), + 1u64..=1000u64, + prop::bool::ANY, + ) + .prop_flat_map(|(apply_to, base_ms, jitter)| { + let max_ms_min = base_ms + 1; + ( + Just(apply_to), + Just(base_ms), + max_ms_min..=(base_ms + 50000), + Just(jitter), + ) + }) + .prop_map(|(apply_to, base_ms, max_ms, jitter)| super::BackoffConfig { + apply_to, + base_ms, + max_ms, + jitter, + }) + } + + fn arb_retry_after_handling_config() -> impl Strategy { + (arb_block_scope(), arb_apply_to(), 1u64..=3600u64).prop_map( + |(scope, apply_to, max_retry_after_seconds)| super::RetryAfterHandlingConfig { + scope, + apply_to, + max_retry_after_seconds, + }, + ) + } + + fn arb_high_latency_config() -> impl Strategy { + ( + 1u64..=60000u64, + arb_latency_measure(), + 1u32..=10u32, + arb_retry_strategy(), + 1u32..=10u32, + 1u64..=3600u64, + arb_block_scope(), + arb_apply_to(), + ) + .prop_map( + |( + threshold_ms, + measure, + min_triggers, + strategy, + max_attempts, + block_duration_seconds, + scope, + apply_to, + )| { + let trigger_window_seconds = if min_triggers > 1 { + Some(60u64) + } else { + None + }; + super::HighLatencyConfig { + threshold_ms, + measure, + min_triggers, + trigger_window_seconds, + strategy, + max_attempts, + block_duration_seconds, + scope, + apply_to, + } + }, + ) + } + + fn arb_retry_policy() -> impl Strategy { + ( + prop::collection::vec("[a-z]{2,6}/[a-z0-9-]{3,10}", 0..=3), + arb_retry_strategy(), + 1u32..=10u32, + prop::collection::vec(arb_status_code_config(), 0..=3), + prop::option::of(arb_timeout_retry_config()), + prop::option::of(arb_high_latency_config()), + prop::option::of(arb_backoff_config()), + prop::option::of(arb_retry_after_handling_config()), + prop::option::of(1u64..=120000u64), + ) + .prop_map( + |( + fallback_models, + default_strategy, + default_max_attempts, 
+ on_status_codes, + on_timeout, + on_high_latency, + backoff, + retry_after_handling, + max_retry_duration_ms, + )| { + super::RetryPolicy { + fallback_models, + default_strategy, + default_max_attempts, + on_status_codes, + on_timeout, + on_high_latency, + backoff, + retry_after_handling, + max_retry_duration_ms, + } + }, + ) + } + + // ── Property Tests ───────────────────────────────────────────────────── + + // Feature: retry-on-ratelimit, Property 1: Configuration Round-Trip Parsing + // **Validates: Requirements 1.2** + proptest! { + #![proptest_config(proptest::prelude::ProptestConfig::with_cases(100))] + + /// Property 1: Configuration Round-Trip Parsing + /// Generate arbitrary valid RetryPolicy structs, serialize to YAML, + /// re-parse, and assert equivalence. + #[test] + fn prop_retry_policy_round_trip(policy in arb_retry_policy()) { + let yaml = serde_yaml::to_string(&policy) + .expect("serialization should succeed"); + let parsed: super::RetryPolicy = serde_yaml::from_str(&yaml) + .expect("deserialization should succeed"); + + // Direct structural equality — all types derive PartialEq + prop_assert_eq!(&policy, &parsed); + } + + } + + // Feature: retry-on-ratelimit, Property 2: Configuration Defaults Applied Correctly + // **Validates: Requirements 1.2** + proptest! { + #![proptest_config(proptest::prelude::ProptestConfig::with_cases(100))] + + /// Property 2: Configuration Defaults Applied Correctly + /// Generate RetryPolicy YAML with optional fields omitted, parse, + /// and assert correct defaults are applied. + #[test] + fn prop_retry_policy_defaults( + include_on_status_codes in prop::bool::ANY, + include_backoff in prop::bool::ANY, + include_retry_after in prop::bool::ANY, + include_on_timeout in prop::bool::ANY, + include_on_high_latency in prop::bool::ANY, + ) { + // Build a minimal YAML — RetryPolicy has serde defaults for all fields, + // so even an empty mapping is valid. 
+ let mut parts: Vec = Vec::new(); + + // When we include sections, only provide required sub-fields so + // we can verify the optional sub-fields get their defaults. + if include_on_status_codes { + parts.push("on_status_codes:\n - codes: [429]\n strategy: same_model\n max_attempts: 2".to_string()); + } + if include_backoff { + parts.push("backoff:\n apply_to: global".to_string()); + } + if include_retry_after { + parts.push("retry_after_handling:\n scope: provider".to_string()); + } + if include_on_timeout { + parts.push("on_timeout:\n strategy: same_model\n max_attempts: 1".to_string()); + } + if include_on_high_latency { + parts.push("on_high_latency:\n threshold_ms: 5000\n strategy: different_provider\n max_attempts: 2".to_string()); + } + + let yaml = if parts.is_empty() { + "{}".to_string() + } else { + parts.join("\n") + }; + + let parsed: super::RetryPolicy = serde_yaml::from_str(&yaml) + .expect("deserialization should succeed"); + + // Assert top-level defaults + prop_assert_eq!(parsed.default_strategy, super::RetryStrategy::DifferentProvider); + prop_assert_eq!(parsed.default_max_attempts, 2); + prop_assert!(parsed.fallback_models.is_empty()); + prop_assert_eq!(parsed.max_retry_duration_ms, None); + + // Assert on_status_codes defaults to empty vec + if !include_on_status_codes { + prop_assert!(parsed.on_status_codes.is_empty()); + } + + // Assert backoff defaults when present + if include_backoff { + let backoff = parsed.backoff.as_ref().unwrap(); + prop_assert_eq!(backoff.base_ms, 100); + prop_assert_eq!(backoff.max_ms, 5000); + prop_assert_eq!(backoff.jitter, true); + } else { + prop_assert!(parsed.backoff.is_none()); + } + + // Assert retry_after_handling defaults when present + if include_retry_after { + let rah = parsed.retry_after_handling.as_ref().unwrap(); + prop_assert_eq!(rah.scope, super::BlockScope::Provider); // explicitly set + prop_assert_eq!(rah.apply_to, super::ApplyTo::Global); // default + prop_assert_eq!(rah.max_retry_after_seconds, 
300); // default + } else { + prop_assert!(parsed.retry_after_handling.is_none()); + } + + // Assert effective_retry_after_config always returns valid defaults + let effective = parsed.effective_retry_after_config(); + if include_retry_after { + prop_assert_eq!(effective.scope, super::BlockScope::Provider); + } else { + prop_assert_eq!(effective.scope, super::BlockScope::Model); + } + prop_assert_eq!(effective.apply_to, super::ApplyTo::Global); + prop_assert_eq!(effective.max_retry_after_seconds, 300); + + // Assert high latency defaults when present + if include_on_high_latency { + let hl = parsed.on_high_latency.as_ref().unwrap(); + prop_assert_eq!(hl.measure, super::LatencyMeasure::Ttfb); // default + prop_assert_eq!(hl.min_triggers, 1); // default + prop_assert_eq!(hl.block_duration_seconds, 300); // default + prop_assert_eq!(hl.scope, super::BlockScope::Model); // default + prop_assert_eq!(hl.apply_to, super::ApplyTo::Global); // default + } + } + } + + #[test] + fn test_deserialize_configuration() { + let ref_config = fs::read_to_string( + "../../docs/source/resources/includes/plano_config_full_reference_rendered.yaml", + ) + .expect("reference config file not found"); + + let config: super::Configuration = serde_yaml::from_str(&ref_config).unwrap(); + assert_eq!(config.version, "v0.3.0"); + + if let Some(prompt_targets) = &config.prompt_targets { + assert!( + !prompt_targets.is_empty(), + "prompt_targets should not be empty if present" + ); + } + + if let Some(tracing) = config.tracing.as_ref() { + if let Some(sampling_rate) = tracing.sampling_rate { + assert_eq!(sampling_rate, 0.1); + } + } + + let mode = config.mode.as_ref().unwrap_or(&super::GatewayMode::Prompt); + assert_eq!(*mode, super::GatewayMode::Prompt); + } + + #[test] + fn test_tool_conversion() { + let ref_config = fs::read_to_string( + "../../docs/source/resources/includes/plano_config_full_reference_rendered.yaml", + ) + .expect("reference config file not found"); + let config: 
super::Configuration = serde_yaml::from_str(&ref_config).unwrap(); + if let Some(prompt_targets) = &config.prompt_targets { + if let Some(prompt_target) = prompt_targets + .iter() + .find(|p| p.name == "reboot_network_device") + { + let chat_completion_tool: super::ChatCompletionTool = prompt_target.into(); + assert_eq!(chat_completion_tool.tool_type, ToolType::Function); + assert_eq!(chat_completion_tool.function.name, "reboot_network_device"); + assert_eq!( + chat_completion_tool.function.description, + "Reboot a specific network device" + ); + assert_eq!(chat_completion_tool.function.parameters.properties.len(), 2); + assert!(chat_completion_tool + .function + .parameters + .properties + .contains_key("device_id")); + let device_id_param = chat_completion_tool + .function + .parameters + .properties + .get("device_id") + .unwrap(); + assert_eq!( + device_id_param.parameter_type, + crate::api::open_ai::ParameterType::String + ); + assert_eq!( + device_id_param.description, + "Identifier of the network device to reboot.".to_string() + ); + assert_eq!(device_id_param.required, Some(true)); + let confirmation_param = chat_completion_tool + .function + .parameters + .properties + .get("confirmation") + .unwrap(); + assert_eq!( + confirmation_param.parameter_type, + crate::api::open_ai::ParameterType::Bool + ); + } + } + } + + // Feature: retry-on-ratelimit, Property 4: Status Code Range Expansion + // **Validates: Requirements 1.8** + proptest! { + #![proptest_config(proptest::prelude::ProptestConfig::with_cases(100))] + + /// Property 4: Status Code Range Expansion — degenerate range (start == end) + /// A range "N-N" should expand to a single-element vec containing N. 
+ #[test] + fn prop_status_code_range_expansion( + code in 100u16..=599u16, + ) { + let range_str = format!("{}-{}", code, code); + let entry = super::StatusCodeEntry::Range(range_str); + let expanded = entry.expand().expect("expand should succeed for valid range"); + prop_assert_eq!(expanded.len(), 1); + prop_assert_eq!(expanded[0], code); + } + + /// Property 4: Status Code Range Expansion — Single variant + /// Generate arbitrary code (100..=599), expand, assert vec of length 1 containing that code. + #[test] + fn prop_status_code_single_expansion(code in 100u16..=599u16) { + let entry = super::StatusCodeEntry::Single(code); + let expanded = entry.expand().expect("expand should succeed for Single"); + prop_assert_eq!(expanded.len(), 1); + prop_assert_eq!(expanded[0], code); + } + } + + proptest! { + #![proptest_config(proptest::prelude::ProptestConfig::with_cases(100))] + + /// Property 4: Status Code Range Expansion — arbitrary start..=end range + /// Generate arbitrary valid range strings "start-end" (100 ≤ start ≤ end ≤ 599), + /// expand, and assert correct count and bounds. 
+ #[test] + fn prop_status_code_range_expansion_full( + (start, end) in (100u16..=599u16).prop_flat_map(|s| (Just(s), s..=599u16)) + ) { + let range_str = format!("{}-{}", start, end); + let entry = super::StatusCodeEntry::Range(range_str); + let expanded = entry.expand().expect("expand should succeed for valid range"); + + let expected_len = (end - start + 1) as usize; + prop_assert_eq!(expanded.len(), expected_len, "length should be end - start + 1"); + prop_assert_eq!(*expanded.first().unwrap(), start, "first element should be start"); + prop_assert_eq!(*expanded.last().unwrap(), end, "last element should be end"); + + for &code in &expanded { + prop_assert!(code >= start && code <= end, "all codes should be in [start, end]"); + } + } + } + + #[test] + fn test_into_models_filters_internal_providers() { + let providers = vec![ + LlmProvider { + name: "openai-gpt4".to_string(), + provider_interface: LlmProviderType::OpenAI, + model: Some("gpt-4".to_string()), + internal: None, + ..Default::default() + }, + LlmProvider { + name: "arch-router".to_string(), + provider_interface: LlmProviderType::Arch, + model: Some("Arch-Router".to_string()), + internal: Some(true), + ..Default::default() + }, + LlmProvider { + name: "plano-orchestrator".to_string(), + provider_interface: LlmProviderType::Arch, + model: Some("Plano-Orchestrator".to_string()), + internal: Some(true), + ..Default::default() + }, + ]; + + let models = providers.into_models(); + + // Should only have 1 model: openai-gpt4 + assert_eq!(models.data.len(), 1); + + // Verify internal models are excluded from /v1/models + let model_ids: Vec = models.data.iter().map(|m| m.id.clone()).collect(); + assert!(model_ids.contains(&"openai-gpt4".to_string())); + assert!(!model_ids.contains(&"arch-router".to_string())); + assert!(!model_ids.contains(&"plano-orchestrator".to_string())); + } + + // ── P0 Edge Case Tests: YAML Config Pattern Parsing ──────────────────── + + /// Helper to parse a RetryPolicy from a YAML 
string. + fn parse_retry_policy(yaml: &str) -> super::RetryPolicy { + serde_yaml::from_str(yaml).expect("YAML should parse into RetryPolicy") + } + + #[test] + fn test_pattern1_multi_provider_failover_for_rate_limits() { + let yaml = r#" + fallback_models: [anthropic/claude-3-5-sonnet] + on_status_codes: + - codes: [429] + strategy: "different_provider" + max_attempts: 2 + "#; + let policy = parse_retry_policy(yaml); + assert_eq!(policy.fallback_models, vec!["anthropic/claude-3-5-sonnet"]); + assert_eq!(policy.on_status_codes.len(), 1); + assert_eq!(policy.on_status_codes[0].strategy, super::RetryStrategy::DifferentProvider); + assert_eq!(policy.on_status_codes[0].max_attempts, 2); + } + + #[test] + fn test_pattern2_same_provider_failover_with_model_downgrade() { + let yaml = r#" + fallback_models: [openai/gpt-4o-mini, anthropic/claude-3-5-sonnet] + on_status_codes: + - codes: [429] + strategy: "same_provider" + max_attempts: 2 + "#; + let policy = parse_retry_policy(yaml); + assert_eq!(policy.fallback_models.len(), 2); + assert_eq!(policy.on_status_codes[0].strategy, super::RetryStrategy::SameProvider); + } + + #[test] + fn test_pattern3_single_model_with_backoff_on_multiple_error_types() { + let yaml = r#" + fallback_models: [] + on_status_codes: + - codes: [429] + strategy: "same_model" + max_attempts: 3 + - codes: [503] + strategy: "same_model" + max_attempts: 3 + backoff: + apply_to: "same_model" + base_ms: 500 + "#; + let policy = parse_retry_policy(yaml); + assert!(policy.fallback_models.is_empty()); + assert_eq!(policy.on_status_codes.len(), 2); + let backoff = policy.backoff.unwrap(); + assert_eq!(backoff.apply_to, super::BackoffApplyTo::SameModel); + assert_eq!(backoff.base_ms, 500); + // max_ms defaults to 5000 + assert_eq!(backoff.max_ms, 5000); + } + + #[test] + fn test_pattern4_per_status_code_strategy_customization() { + let yaml = r#" + fallback_models: [openai/gpt-4o-mini, anthropic/claude-3-5-sonnet] + default_strategy: "different_provider" + 
default_max_attempts: 2 + on_status_codes: + - codes: [429] + strategy: "same_provider" + max_attempts: 2 + - codes: [502] + strategy: "different_provider" + max_attempts: 3 + - codes: [503] + strategy: "same_model" + max_attempts: 2 + - codes: [504] + strategy: "different_provider" + max_attempts: 2 + on_timeout: + strategy: "different_provider" + max_attempts: 2 + "#; + let policy = parse_retry_policy(yaml); + assert_eq!(policy.default_strategy, super::RetryStrategy::DifferentProvider); + assert_eq!(policy.default_max_attempts, 2); + assert_eq!(policy.on_status_codes.len(), 4); + assert_eq!(policy.on_status_codes[2].strategy, super::RetryStrategy::SameModel); + let timeout = policy.on_timeout.unwrap(); + assert_eq!(timeout.strategy, super::RetryStrategy::DifferentProvider); + assert_eq!(timeout.max_attempts, 2); + } + + #[test] + fn test_pattern5_timeout_specific_configuration() { + let yaml = r#" + fallback_models: [anthropic/claude-3-5-sonnet] + default_strategy: "different_provider" + default_max_attempts: 2 + on_status_codes: + - codes: [429] + strategy: "same_provider" + max_attempts: 2 + on_timeout: + strategy: "different_provider" + max_attempts: 3 + "#; + let policy = parse_retry_policy(yaml); + let timeout = policy.on_timeout.unwrap(); + assert_eq!(timeout.max_attempts, 3); + } + + #[test] + fn test_pattern6_no_retry_parses_as_empty() { + // Pattern 6: No retry_policy section. We test that an empty YAML + // object parses with all defaults. 
+ let yaml = "{}"; + let policy = parse_retry_policy(yaml); + assert!(policy.fallback_models.is_empty()); + assert_eq!(policy.default_strategy, super::RetryStrategy::DifferentProvider); + assert_eq!(policy.default_max_attempts, 2); + assert!(policy.on_status_codes.is_empty()); + assert!(policy.on_timeout.is_none()); + assert!(policy.backoff.is_none()); + assert!(policy.max_retry_duration_ms.is_none()); + } + + #[test] + fn test_pattern7_backoff_only_for_same_model() { + let yaml = r#" + fallback_models: [anthropic/claude-3-5-sonnet] + on_status_codes: + - codes: [429] + strategy: "same_model" + max_attempts: 2 + backoff: + apply_to: "same_model" + base_ms: 100 + max_ms: 5000 + jitter: true + "#; + let policy = parse_retry_policy(yaml); + let backoff = policy.backoff.unwrap(); + assert_eq!(backoff.apply_to, super::BackoffApplyTo::SameModel); + assert!(backoff.jitter); + } + + #[test] + fn test_pattern8_backoff_for_same_provider() { + let yaml = r#" + fallback_models: [openai/gpt-4o-mini, anthropic/claude-3-5-sonnet] + on_status_codes: + - codes: [429] + strategy: "same_provider" + max_attempts: 2 + backoff: + apply_to: "same_provider" + base_ms: 200 + max_ms: 10000 + jitter: true + "#; + let policy = parse_retry_policy(yaml); + let backoff = policy.backoff.unwrap(); + assert_eq!(backoff.apply_to, super::BackoffApplyTo::SameProvider); + assert_eq!(backoff.base_ms, 200); + assert_eq!(backoff.max_ms, 10000); + } + + #[test] + fn test_pattern9_global_backoff() { + let yaml = r#" + fallback_models: [anthropic/claude-3-5-sonnet] + on_status_codes: + - codes: [429] + strategy: "different_provider" + max_attempts: 2 + backoff: + apply_to: "global" + base_ms: 50 + max_ms: 2000 + jitter: true + "#; + let policy = parse_retry_policy(yaml); + let backoff = policy.backoff.unwrap(); + assert_eq!(backoff.apply_to, super::BackoffApplyTo::Global); + assert_eq!(backoff.base_ms, 50); + assert_eq!(backoff.max_ms, 2000); + } + + #[test] + fn 
test_pattern10_deterministic_backoff_without_jitter() { + let yaml = r#" + fallback_models: [] + on_status_codes: + - codes: [429] + strategy: "same_model" + max_attempts: 3 + backoff: + apply_to: "same_model" + base_ms: 1000 + max_ms: 30000 + jitter: false + "#; + let policy = parse_retry_policy(yaml); + let backoff = policy.backoff.unwrap(); + assert!(!backoff.jitter); + assert_eq!(backoff.base_ms, 1000); + assert_eq!(backoff.max_ms, 30000); + } + + #[test] + fn test_pattern11_no_backoff_fast_failover() { + let yaml = r#" + fallback_models: [anthropic/claude-3-5-sonnet] + on_status_codes: + - codes: [429] + strategy: "different_provider" + max_attempts: 2 + "#; + let policy = parse_retry_policy(yaml); + assert!(policy.backoff.is_none()); + } + + #[test] + fn test_pattern17_mixed_integer_and_range_codes() { + let yaml = r#" + fallback_models: [anthropic/claude-3-5-sonnet] + default_strategy: "different_provider" + default_max_attempts: 2 + on_status_codes: + - codes: [429, "430-450", 526] + strategy: "same_provider" + max_attempts: 2 + - codes: ["502-504"] + strategy: "different_provider" + max_attempts: 3 + "#; + let policy = parse_retry_policy(yaml); + assert_eq!(policy.on_status_codes.len(), 2); + + // Verify first entry: 429 + range 430-450 + 526 + let first = &policy.on_status_codes[0]; + assert_eq!(first.codes.len(), 3); + let expanded: Vec = first.codes.iter() + .flat_map(|c| c.expand().unwrap()) + .collect(); + // 429 + (430..=450 = 21 codes) + 526 = 23 codes + assert_eq!(expanded.len(), 23); + assert!(expanded.contains(&429)); + assert!(expanded.contains(&430)); + assert!(expanded.contains(&450)); + assert!(expanded.contains(&526)); + assert!(!expanded.contains(&451)); + + // Verify second entry: range 502-504 + let second = &policy.on_status_codes[1]; + let expanded2: Vec = second.codes.iter() + .flat_map(|c| c.expand().unwrap()) + .collect(); + assert_eq!(expanded2, vec![502, 503, 504]); + } + + #[test] + fn 
test_pattern12_model_level_retry_after_blocking() { + let yaml = r#" + fallback_models: [openai/gpt-4o-mini, anthropic/claude-3-5-sonnet] + on_status_codes: + - codes: [429] + strategy: "different_provider" + max_attempts: 2 + - codes: [503] + strategy: "different_provider" + max_attempts: 2 + retry_after_handling: + scope: "model" + apply_to: "global" + "#; + let policy = parse_retry_policy(yaml); + assert_eq!(policy.fallback_models.len(), 2); + assert_eq!(policy.on_status_codes.len(), 2); + let rah = policy.retry_after_handling.unwrap(); + assert_eq!(rah.scope, super::BlockScope::Model); + assert_eq!(rah.apply_to, super::ApplyTo::Global); + // max_retry_after_seconds defaults to 300 + assert_eq!(rah.max_retry_after_seconds, 300); + } + + #[test] + fn test_pattern13_provider_level_retry_after_blocking() { + let yaml = r#" + fallback_models: [anthropic/claude-3-5-sonnet] + on_status_codes: + - codes: [429] + strategy: "different_provider" + max_attempts: 2 + - codes: [503] + strategy: "different_provider" + max_attempts: 2 + - codes: [502] + strategy: "different_provider" + max_attempts: 2 + retry_after_handling: + scope: "provider" + apply_to: "global" + "#; + let policy = parse_retry_policy(yaml); + assert_eq!(policy.on_status_codes.len(), 3); + let rah = policy.retry_after_handling.unwrap(); + assert_eq!(rah.scope, super::BlockScope::Provider); + assert_eq!(rah.apply_to, super::ApplyTo::Global); + assert_eq!(rah.max_retry_after_seconds, 300); + } + + #[test] + fn test_pattern14_request_level_retry_after() { + let yaml = r#" + fallback_models: [anthropic/claude-3-5-sonnet] + on_status_codes: + - codes: [429] + strategy: "different_provider" + max_attempts: 2 + - codes: [503] + strategy: "different_provider" + max_attempts: 2 + retry_after_handling: + scope: "model" + apply_to: "request" + "#; + let policy = parse_retry_policy(yaml); + let rah = policy.retry_after_handling.unwrap(); + assert_eq!(rah.scope, super::BlockScope::Model); + assert_eq!(rah.apply_to, 
super::ApplyTo::Request); + assert_eq!(rah.max_retry_after_seconds, 300); + } + + #[test] + fn test_pattern15_no_custom_retry_after_config_defaults_plus_backoff() { + let yaml = r#" + fallback_models: [] + on_status_codes: + - codes: [429] + strategy: "same_model" + max_attempts: 3 + - codes: [503] + strategy: "same_model" + max_attempts: 3 + backoff: + apply_to: "same_model" + base_ms: 1000 + max_ms: 30000 + jitter: true + "#; + let policy = parse_retry_policy(yaml); + // No retry_after_handling section → None + assert!(policy.retry_after_handling.is_none()); + // But effective config should return defaults + let effective = policy.effective_retry_after_config(); + assert_eq!(effective.scope, super::BlockScope::Model); + assert_eq!(effective.apply_to, super::ApplyTo::Global); + assert_eq!(effective.max_retry_after_seconds, 300); + // Backoff is present + let backoff = policy.backoff.unwrap(); + assert_eq!(backoff.apply_to, super::BackoffApplyTo::SameModel); + assert_eq!(backoff.base_ms, 1000); + assert_eq!(backoff.max_ms, 30000); + assert!(backoff.jitter); + } + + #[test] + fn test_pattern16_fallback_models_list_for_targeted_failover() { + let yaml = r#" + fallback_models: [openai/gpt-4o-mini, anthropic/claude-3-5-sonnet, anthropic/claude-3-opus] + default_strategy: "different_provider" + default_max_attempts: 2 + on_status_codes: + - codes: [429] + strategy: "same_provider" + max_attempts: 2 + "#; + let policy = parse_retry_policy(yaml); + assert_eq!(policy.fallback_models, vec![ + "openai/gpt-4o-mini", + "anthropic/claude-3-5-sonnet", + "anthropic/claude-3-opus", + ]); + assert_eq!(policy.default_strategy, super::RetryStrategy::DifferentProvider); + assert_eq!(policy.default_max_attempts, 2); + assert_eq!(policy.on_status_codes.len(), 1); + assert_eq!(policy.on_status_codes[0].strategy, super::RetryStrategy::SameProvider); + } + + #[test] + fn test_backoff_without_apply_to_fails_deserialization() { + // backoff.apply_to is a required field (no serde default), so 
YAML + // without it should fail to deserialize. + let yaml = r#" + on_status_codes: + - codes: [429] + strategy: "same_model" + max_attempts: 2 + backoff: + base_ms: 100 + max_ms: 5000 + "#; + let result: Result = serde_yaml::from_str(yaml); + assert!(result.is_err(), "backoff without apply_to should fail deserialization"); + } + +} diff --git a/crates/common/src/retry/backoff.rs b/crates/common/src/retry/backoff.rs index 6756ed56..545c9276 100644 --- a/crates/common/src/retry/backoff.rs +++ b/crates/common/src/retry/backoff.rs @@ -78,3 +78,305 @@ impl BackoffCalculator { } } +#[cfg(test)] +mod tests { + use super::*; + use crate::configuration::{BackoffApplyTo, BackoffConfig, RetryStrategy}; + use proptest::prelude::*; + + fn make_config(apply_to: BackoffApplyTo, base_ms: u64, max_ms: u64, jitter: bool) -> BackoffConfig { + BackoffConfig { apply_to, base_ms, max_ms, jitter } + } + + #[test] + fn no_backoff_config_returns_zero() { + let calc = BackoffCalculator; + let d = calc.calculate_delay(0, None, None, RetryStrategy::SameModel, "openai/gpt-4o", "openai/gpt-4o"); + assert_eq!(d, Duration::ZERO); + } + + #[test] + fn no_backoff_config_with_retry_after() { + let calc = BackoffCalculator; + let d = calc.calculate_delay(0, None, Some(5), RetryStrategy::SameModel, "openai/gpt-4o", "openai/gpt-4o"); + assert_eq!(d, Duration::from_secs(5)); + } + + #[test] + fn exponential_backoff_no_jitter() { + let calc = BackoffCalculator; + let config = make_config(BackoffApplyTo::Global, 100, 5000, false); + + // attempt 0: min(100 * 2^0, 5000) = 100 + assert_eq!(calc.calculate_delay(0, Some(&config), None, RetryStrategy::SameModel, "a", "a"), Duration::from_millis(100)); + // attempt 1: min(100 * 2^1, 5000) = 200 + assert_eq!(calc.calculate_delay(1, Some(&config), None, RetryStrategy::SameModel, "a", "a"), Duration::from_millis(200)); + // attempt 2: min(100 * 2^2, 5000) = 400 + assert_eq!(calc.calculate_delay(2, Some(&config), None, RetryStrategy::SameModel, "a", "a"), 
Duration::from_millis(400)); + // attempt 6: min(100 * 64, 5000) = 5000 (capped) + assert_eq!(calc.calculate_delay(6, Some(&config), None, RetryStrategy::SameModel, "a", "a"), Duration::from_millis(5000)); + } + + #[test] + fn jitter_stays_within_bounds() { + let calc = BackoffCalculator; + let config = make_config(BackoffApplyTo::Global, 1000, 50000, true); + + for attempt in 0..5 { + for _ in 0..20 { + let d = calc.calculate_delay(attempt, Some(&config), None, RetryStrategy::SameModel, "a", "a"); + let base = (1000u64.saturating_mul(1u64 << attempt)).min(50000); + // jitter: delay * (0.5 + random(0, 0.5)) => [0.5*base, 1.0*base] + assert!(d.as_millis() >= (base as f64 * 0.5) as u128, "delay {} too low for base {}", d.as_millis(), base); + assert!(d.as_millis() <= base as u128, "delay {} too high for base {}", d.as_millis(), base); + } + } + } + + #[test] + fn scope_same_model_filters_different_providers() { + let calc = BackoffCalculator; + let config = make_config(BackoffApplyTo::SameModel, 100, 5000, false); + + // Same model -> backoff applies + let d = calc.calculate_delay(0, Some(&config), None, RetryStrategy::SameModel, "openai/gpt-4o", "openai/gpt-4o"); + assert_eq!(d, Duration::from_millis(100)); + + // Different model, same provider -> no backoff + let d = calc.calculate_delay(0, Some(&config), None, RetryStrategy::SameProvider, "openai/gpt-4o-mini", "openai/gpt-4o"); + assert_eq!(d, Duration::ZERO); + + // Different provider -> no backoff + let d = calc.calculate_delay(0, Some(&config), None, RetryStrategy::DifferentProvider, "anthropic/claude", "openai/gpt-4o"); + assert_eq!(d, Duration::ZERO); + } + + #[test] + fn scope_same_provider_filters_different_providers() { + let calc = BackoffCalculator; + let config = make_config(BackoffApplyTo::SameProvider, 100, 5000, false); + + // Same provider -> backoff applies + let d = calc.calculate_delay(0, Some(&config), None, RetryStrategy::SameProvider, "openai/gpt-4o-mini", "openai/gpt-4o"); + assert_eq!(d, 
Duration::from_millis(100)); + + // Same model (same provider) -> backoff applies + let d = calc.calculate_delay(0, Some(&config), None, RetryStrategy::SameModel, "openai/gpt-4o", "openai/gpt-4o"); + assert_eq!(d, Duration::from_millis(100)); + + // Different provider -> no backoff + let d = calc.calculate_delay(0, Some(&config), None, RetryStrategy::DifferentProvider, "anthropic/claude", "openai/gpt-4o"); + assert_eq!(d, Duration::ZERO); + } + + #[test] + fn scope_global_always_applies() { + let calc = BackoffCalculator; + let config = make_config(BackoffApplyTo::Global, 100, 5000, false); + + let d = calc.calculate_delay(0, Some(&config), None, RetryStrategy::DifferentProvider, "anthropic/claude", "openai/gpt-4o"); + assert_eq!(d, Duration::from_millis(100)); + } + + #[test] + fn retry_after_wins_when_greater() { + let calc = BackoffCalculator; + let config = make_config(BackoffApplyTo::Global, 100, 5000, false); + + // retry_after = 10s >> backoff attempt 0 = 100ms + let d = calc.calculate_delay(0, Some(&config), Some(10), RetryStrategy::SameModel, "a", "a"); + assert_eq!(d, Duration::from_secs(10)); + } + + #[test] + fn backoff_wins_when_greater() { + let calc = BackoffCalculator; + // base_ms=10000, attempt 0 -> 10000ms = 10s + let config = make_config(BackoffApplyTo::Global, 10000, 50000, false); + + // retry_after = 5s < backoff = 10s + let d = calc.calculate_delay(0, Some(&config), Some(5), RetryStrategy::SameModel, "a", "a"); + assert_eq!(d, Duration::from_millis(10000)); + } + + #[test] + fn scope_mismatch_still_honors_retry_after() { + let calc = BackoffCalculator; + let config = make_config(BackoffApplyTo::SameModel, 100, 5000, false); + + // Scope doesn't match (different providers) but retry_after is set + let d = calc.calculate_delay(0, Some(&config), Some(3), RetryStrategy::DifferentProvider, "anthropic/claude", "openai/gpt-4o"); + assert_eq!(d, Duration::from_secs(3)); + } + + #[test] + fn large_attempt_number_saturates() { + let calc = 
BackoffCalculator; + let config = make_config(BackoffApplyTo::Global, 100, 5000, false); + + // Very large attempt number should saturate and cap at max_ms + let d = calc.calculate_delay(63, Some(&config), None, RetryStrategy::SameModel, "a", "a"); + assert_eq!(d, Duration::from_millis(5000)); + } + + // --- Proptest strategies --- + + fn arb_provider() -> impl Strategy<Value = String> { + prop_oneof![ + Just("openai/gpt-4o".to_string()), + Just("openai/gpt-4o-mini".to_string()), + Just("anthropic/claude-3".to_string()), + Just("azure/gpt-4o".to_string()), + Just("google/gemini-pro".to_string()), + ] + } + + // Feature: retry-on-ratelimit, Property 12: Exponential Backoff Formula and Bounds + // **Validates: Requirements 4.6, 4.7, 4.8, 4.9, 4.10, 4.11** + proptest! { + #![proptest_config(ProptestConfig::with_cases(100))] + + /// Property 12 – Case 1: No-jitter delay equals min(base_ms * 2^attempt, max_ms) exactly. + #[test] + fn prop_backoff_no_jitter_exact( + attempt in 0u32..20, + base_ms in 1u64..10000, + extra in 1u64..40001u64, + ) { + let max_ms = base_ms + extra; + let config = make_config(BackoffApplyTo::Global, base_ms, max_ms, false); + let calc = BackoffCalculator; + let d = calc.calculate_delay(attempt, Some(&config), None, RetryStrategy::SameModel, "a", "a"); + + let expected = if attempt >= 64 { + max_ms + } else { + base_ms.saturating_mul(1u64 << attempt).min(max_ms) + }; + prop_assert_eq!(d, Duration::from_millis(expected)); + } + + /// Property 12 – Case 2: Jitter delay is in [0.5 * computed_base, computed_base].
+ #[test] + fn prop_backoff_jitter_bounds( + attempt in 0u32..20, + base_ms in 1u64..10000, + extra in 1u64..40001u64, + ) { + let max_ms = base_ms + extra; + let config = make_config(BackoffApplyTo::Global, base_ms, max_ms, true); + let calc = BackoffCalculator; + let d = calc.calculate_delay(attempt, Some(&config), None, RetryStrategy::SameModel, "a", "a"); + + let computed_base = if attempt >= 64 { + max_ms + } else { + base_ms.saturating_mul(1u64 << attempt).min(max_ms) + }; + let lower = (computed_base as f64 * 0.5) as u64; + let upper = computed_base; + prop_assert!( + d.as_millis() >= lower as u128 && d.as_millis() <= upper as u128, + "delay {}ms not in [{}, {}] for attempt={}, base_ms={}, max_ms={}", + d.as_millis(), lower, upper, attempt, base_ms, max_ms + ); + } + + /// Property 12 – Case 3: Delay is always <= max_ms. + #[test] + fn prop_backoff_delay_capped_at_max( + attempt in 0u32..20, + base_ms in 1u64..10000, + extra in 1u64..40001u64, + jitter in proptest::bool::ANY, + ) { + let max_ms = base_ms + extra; + let config = make_config(BackoffApplyTo::Global, base_ms, max_ms, jitter); + let calc = BackoffCalculator; + let d = calc.calculate_delay(attempt, Some(&config), None, RetryStrategy::SameModel, "a", "a"); + + prop_assert!( + d.as_millis() <= max_ms as u128, + "delay {}ms exceeds max_ms {} for attempt={}, base_ms={}, jitter={}", + d.as_millis(), max_ms, attempt, base_ms, jitter + ); + } + } + + // Feature: retry-on-ratelimit, Property 13: Backoff Apply-To Scope Filtering + // **Validates: Requirements 4.3, 4.4, 4.5, 4.12, 4.13** + proptest! { + #![proptest_config(ProptestConfig::with_cases(100))] + + /// Property 13 – Case 1: SameModel apply_to with different providers → zero delay. 
+ #[test] + fn prop_scope_same_model_different_providers_zero( + attempt in 0u32..20, + base_ms in 1u64..10000, + extra in 1u64..40001u64, + current in arb_provider(), + previous in arb_provider(), + ) { + // Only test when providers are actually different models + prop_assume!(current != previous); + let max_ms = base_ms + extra; + let config = make_config(BackoffApplyTo::SameModel, base_ms, max_ms, false); + let calc = BackoffCalculator; + let d = calc.calculate_delay( + attempt, Some(&config), None, + RetryStrategy::DifferentProvider, &current, &previous, + ); + prop_assert_eq!(d, Duration::ZERO, + "Expected zero delay for SameModel apply_to with different models: {} vs {}", + current, previous + ); + } + + /// Property 13 – Case 2: SameProvider apply_to with different provider prefixes → zero delay. + #[test] + fn prop_scope_same_provider_different_prefix_zero( + attempt in 0u32..20, + base_ms in 1u64..10000, + extra in 1u64..40001u64, + current in arb_provider(), + previous in arb_provider(), + ) { + let current_prefix = extract_provider(&current); + let previous_prefix = extract_provider(&previous); + prop_assume!(current_prefix != previous_prefix); + let max_ms = base_ms + extra; + let config = make_config(BackoffApplyTo::SameProvider, base_ms, max_ms, false); + let calc = BackoffCalculator; + let d = calc.calculate_delay( + attempt, Some(&config), None, + RetryStrategy::DifferentProvider, &current, &previous, + ); + prop_assert_eq!(d, Duration::ZERO, + "Expected zero delay for SameProvider apply_to with different prefixes: {} vs {}", + current_prefix, previous_prefix + ); + } + + /// Property 13 – Case 3: Global apply_to always produces non-zero delay.
+ #[test] + fn prop_scope_global_always_nonzero( + attempt in 0u32..20, + base_ms in 1u64..10000, + extra in 1u64..40001u64, + current in arb_provider(), + previous in arb_provider(), + ) { + let max_ms = base_ms + extra; + let config = make_config(BackoffApplyTo::Global, base_ms, max_ms, false); + let calc = BackoffCalculator; + let d = calc.calculate_delay( + attempt, Some(&config), None, + RetryStrategy::DifferentProvider, &current, &previous, + ); + prop_assert!(d > Duration::ZERO, + "Expected non-zero delay for Global apply_to: current={}, previous={}", + current, previous + ); + } + } +} diff --git a/crates/common/src/retry/error_detector.rs b/crates/common/src/retry/error_detector.rs index 1fd36a16..edcf47ab 100644 --- a/crates/common/src/retry/error_detector.rs +++ b/crates/common/src/retry/error_detector.rs @@ -207,3 +207,724 @@ fn extract_retry_after(response: &HttpResponse) -> Option { .and_then(|s| s.trim().parse::().ok()) } +#[cfg(test)] +mod tests { + use super::*; + use crate::configuration::{ + StatusCodeConfig, TimeoutRetryConfig, + }; + use bytes::Bytes; + use http_body_util::{BodyExt, Full}; + + /// Helper to build an HttpResponse with a given status code. + fn make_response(status: u16) -> HttpResponse { + make_response_with_headers(status, vec![]) + } + + /// Helper to build an HttpResponse with a given status code and headers.
+ fn make_response_with_headers(status: u16, headers: Vec<(&str, &str)>) -> HttpResponse { + let body = Full::new(Bytes::from("test body")) + .map_err(|_| unreachable!()) + .boxed(); + let mut builder = Response::builder().status(status); + for (name, value) in headers { + builder = builder.header(name, value); + } + builder.body(body).unwrap() + } + + fn basic_retry_policy() -> RetryPolicy { + RetryPolicy { + fallback_models: vec![], + default_strategy: RetryStrategy::DifferentProvider, + default_max_attempts: 2, + on_status_codes: vec![ + StatusCodeConfig { + codes: vec![StatusCodeEntry::Single(429)], + strategy: RetryStrategy::SameProvider, + max_attempts: 3, + }, + StatusCodeConfig { + codes: vec![StatusCodeEntry::Single(503)], + strategy: RetryStrategy::DifferentProvider, + max_attempts: 4, + }, + ], + on_timeout: Some(TimeoutRetryConfig { + strategy: RetryStrategy::DifferentProvider, + max_attempts: 2, + }), + on_high_latency: None, + backoff: None, + retry_after_handling: None, + max_retry_duration_ms: None, + } + } + + // ── classify tests ───────────────────────────────────────────────── + + #[test] + fn classify_2xx_returns_success() { + let detector = ErrorDetector; + let policy = basic_retry_policy(); + let resp = make_response(200); + let result = detector.classify(Ok(resp), &policy, 0, 0); + assert!(matches!(result, ErrorClassification::Success(_))); + } + + #[test] + fn classify_201_returns_success() { + let detector = ErrorDetector; + let policy = basic_retry_policy(); + let resp = make_response(201); + let result = detector.classify(Ok(resp), &policy, 0, 0); + assert!(matches!(result, ErrorClassification::Success(_))); + } + + #[test] + fn classify_429_returns_retriable_error() { + let detector = ErrorDetector; + let policy = basic_retry_policy(); + let resp = make_response(429); + let result = detector.classify(Ok(resp), &policy, 0, 0); + match result { + ErrorClassification::RetriableError { status_code, .. 
} => { + assert_eq!(status_code, 429); + } + other => panic!("Expected RetriableError, got {:?}", other), + } + } + + #[test] + fn classify_503_returns_retriable_error() { + let detector = ErrorDetector; + let policy = basic_retry_policy(); + let resp = make_response(503); + let result = detector.classify(Ok(resp), &policy, 0, 0); + match result { + ErrorClassification::RetriableError { status_code, .. } => { + assert_eq!(status_code, 503); + } + other => panic!("Expected RetriableError, got {:?}", other), + } + } + + #[test] + fn classify_unconfigured_4xx_returns_retriable_with_defaults() { + let detector = ErrorDetector; + let policy = basic_retry_policy(); + let resp = make_response(400); + let result = detector.classify(Ok(resp), &policy, 0, 0); + match result { + ErrorClassification::RetriableError { status_code, .. } => { + assert_eq!(status_code, 400); + } + other => panic!("Expected RetriableError for unconfigured 4xx, got {:?}", other), + } + } + + #[test] + fn classify_unconfigured_5xx_returns_retriable_with_defaults() { + let detector = ErrorDetector; + let policy = basic_retry_policy(); + let resp = make_response(502); + let result = detector.classify(Ok(resp), &policy, 0, 0); + match result { + ErrorClassification::RetriableError { status_code, .. 
} => { + assert_eq!(status_code, 502); + } + other => panic!("Expected RetriableError for unconfigured 5xx, got {:?}", other), + } + } + + #[test] + fn classify_3xx_returns_non_retriable() { + let detector = ErrorDetector; + let policy = basic_retry_policy(); + let resp = make_response(301); + let result = detector.classify(Ok(resp), &policy, 0, 0); + assert!(matches!(result, ErrorClassification::NonRetriableError(_))); + } + + #[test] + fn classify_1xx_returns_non_retriable() { + let detector = ErrorDetector; + let policy = basic_retry_policy(); + let resp = make_response(100); + let result = detector.classify(Ok(resp), &policy, 0, 0); + assert!(matches!(result, ErrorClassification::NonRetriableError(_))); + } + + #[test] + fn classify_timeout_returns_timeout_error() { + let detector = ErrorDetector; + let policy = basic_retry_policy(); + let timeout = TimeoutError { duration_ms: 5000 }; + let result = detector.classify(Err(timeout), &policy, 0, 0); + match result { + ErrorClassification::TimeoutError { duration_ms } => { + assert_eq!(duration_ms, 5000); + } + other => panic!("Expected TimeoutError, got {:?}", other), + } + } + + #[test] + fn classify_extracts_retry_after_header() { + let detector = ErrorDetector; + let policy = basic_retry_policy(); + let resp = make_response_with_headers(429, vec![("retry-after", "120")]); + let result = detector.classify(Ok(resp), &policy, 0, 0); + match result { + ErrorClassification::RetriableError { + retry_after_seconds, .. + } => { + assert_eq!(retry_after_seconds, Some(120)); + } + other => panic!("Expected RetriableError, got {:?}", other), + } + } + + #[test] + fn classify_ignores_malformed_retry_after() { + let detector = ErrorDetector; + let policy = basic_retry_policy(); + let resp = make_response_with_headers(429, vec![("retry-after", "not-a-number")]); + let result = detector.classify(Ok(resp), &policy, 0, 0); + match result { + ErrorClassification::RetriableError { + retry_after_seconds, .. 
+ } => { + assert_eq!(retry_after_seconds, None); + } + other => panic!("Expected RetriableError, got {:?}", other), + } + } + + #[test] + fn classify_status_code_range() { + let detector = ErrorDetector; + let policy = RetryPolicy { + on_status_codes: vec![StatusCodeConfig { + codes: vec![StatusCodeEntry::Range("500-504".to_string())], + strategy: RetryStrategy::DifferentProvider, + max_attempts: 3, + }], + ..basic_retry_policy() + }; + // 502 is within the range + let resp = make_response(502); + let result = detector.classify(Ok(resp), &policy, 0, 0); + match result { + ErrorClassification::RetriableError { status_code, .. } => { + assert_eq!(status_code, 502); + } + other => panic!("Expected RetriableError, got {:?}", other), + } + } + + // ── resolve_retry_params tests ───────────────────────────────────── + + #[test] + fn resolve_params_for_configured_status_code() { + let detector = ErrorDetector; + let policy = basic_retry_policy(); + let classification = ErrorClassification::RetriableError { + status_code: 429, + retry_after_seconds: None, + response_body: vec![], + }; + let (strategy, max_attempts) = detector.resolve_retry_params(&classification, &policy); + assert_eq!(strategy, RetryStrategy::SameProvider); + assert_eq!(max_attempts, 3); + } + + #[test] + fn resolve_params_for_unconfigured_status_code_uses_defaults() { + let detector = ErrorDetector; + let policy = basic_retry_policy(); + let classification = ErrorClassification::RetriableError { + status_code: 400, + retry_after_seconds: None, + response_body: vec![], + }; + let (strategy, max_attempts) = detector.resolve_retry_params(&classification, &policy); + assert_eq!(strategy, RetryStrategy::DifferentProvider); + assert_eq!(max_attempts, 2); + } + + #[test] + fn resolve_params_for_timeout_with_config() { + let detector = ErrorDetector; + let policy = basic_retry_policy(); + let classification = ErrorClassification::TimeoutError { duration_ms: 5000 }; + let (strategy, max_attempts) = 
detector.resolve_retry_params(&classification, &policy); + assert_eq!(strategy, RetryStrategy::DifferentProvider); + assert_eq!(max_attempts, 2); + } + + #[test] + fn resolve_params_for_timeout_without_config_uses_defaults() { + let detector = ErrorDetector; + let mut policy = basic_retry_policy(); + policy.on_timeout = None; + let classification = ErrorClassification::TimeoutError { duration_ms: 5000 }; + let (strategy, max_attempts) = detector.resolve_retry_params(&classification, &policy); + assert_eq!(strategy, RetryStrategy::DifferentProvider); + assert_eq!(max_attempts, 2); + } + + #[test] + fn resolve_params_for_high_latency_with_config() { + let detector = ErrorDetector; + let mut policy = basic_retry_policy(); + policy.on_high_latency = Some(crate::configuration::HighLatencyConfig { + threshold_ms: 5000, + measure: LatencyMeasure::Ttfb, + min_triggers: 1, + trigger_window_seconds: None, + strategy: RetryStrategy::SameProvider, + max_attempts: 5, + block_duration_seconds: 300, + scope: crate::configuration::BlockScope::Model, + apply_to: crate::configuration::ApplyTo::Global, + }); + let classification = ErrorClassification::HighLatencyEvent { + measured_ms: 6000, + threshold_ms: 5000, + measure: LatencyMeasure::Ttfb, + response: None, + }; + let (strategy, max_attempts) = detector.resolve_retry_params(&classification, &policy); + assert_eq!(strategy, RetryStrategy::SameProvider); + assert_eq!(max_attempts, 5); + } + + #[test] + fn resolve_params_for_success_returns_defaults() { + let detector = ErrorDetector; + let policy = basic_retry_policy(); + let resp = make_response(200); + let classification = ErrorClassification::Success(resp); + let (strategy, max_attempts) = detector.resolve_retry_params(&classification, &policy); + // Shouldn't normally be called for Success, but returns defaults safely + assert_eq!(strategy, RetryStrategy::DifferentProvider); + assert_eq!(max_attempts, 2); + } + + #[test] + fn resolve_params_second_on_status_codes_entry() { + 
let detector = ErrorDetector; + let policy = basic_retry_policy(); + let classification = ErrorClassification::RetriableError { + status_code: 503, + retry_after_seconds: None, + response_body: vec![], + }; + let (strategy, max_attempts) = detector.resolve_retry_params(&classification, &policy); + assert_eq!(strategy, RetryStrategy::DifferentProvider); + assert_eq!(max_attempts, 4); + } + + // ── High latency classification tests ───────────────────────────── + + fn high_latency_retry_policy(threshold_ms: u64, measure: LatencyMeasure) -> RetryPolicy { + let mut policy = basic_retry_policy(); + policy.on_high_latency = Some(crate::configuration::HighLatencyConfig { + threshold_ms, + measure, + min_triggers: 1, + trigger_window_seconds: None, + strategy: RetryStrategy::DifferentProvider, + max_attempts: 2, + block_duration_seconds: 300, + scope: crate::configuration::BlockScope::Model, + apply_to: crate::configuration::ApplyTo::Global, + }); + policy + } + + #[test] + fn classify_2xx_high_latency_ttfb_returns_high_latency_event() { + let detector = ErrorDetector; + let policy = high_latency_retry_policy(5000, LatencyMeasure::Ttfb); + let resp = make_response(200); + // TTFB = 6000ms exceeds threshold of 5000ms + let result = detector.classify(Ok(resp), &policy, 6000, 7000); + match result { + ErrorClassification::HighLatencyEvent { + measured_ms, + threshold_ms, + measure, + response, + } => { + assert_eq!(measured_ms, 6000); + assert_eq!(threshold_ms, 5000); + assert_eq!(measure, LatencyMeasure::Ttfb); + assert!(response.is_some(), "Completed response should be present"); + } + other => panic!("Expected HighLatencyEvent, got {:?}", other), + } + } + + #[test] + fn classify_2xx_high_latency_total_returns_high_latency_event() { + let detector = ErrorDetector; + let policy = high_latency_retry_policy(5000, LatencyMeasure::Total); + let resp = make_response(200); + // Total = 8000ms exceeds threshold, TTFB = 3000ms does not + let result = detector.classify(Ok(resp), 
&policy, 3000, 8000); + match result { + ErrorClassification::HighLatencyEvent { + measured_ms, + threshold_ms, + measure, + .. + } => { + assert_eq!(measured_ms, 8000); + assert_eq!(threshold_ms, 5000); + assert_eq!(measure, LatencyMeasure::Total); + } + other => panic!("Expected HighLatencyEvent, got {:?}", other), + } + } + + #[test] + fn classify_2xx_below_threshold_returns_success() { + let detector = ErrorDetector; + let policy = high_latency_retry_policy(5000, LatencyMeasure::Ttfb); + let resp = make_response(200); + // TTFB = 3000ms is below threshold of 5000ms + let result = detector.classify(Ok(resp), &policy, 3000, 4000); + assert!(matches!(result, ErrorClassification::Success(_))); + } + + #[test] + fn classify_2xx_at_threshold_returns_success() { + let detector = ErrorDetector; + let policy = high_latency_retry_policy(5000, LatencyMeasure::Ttfb); + let resp = make_response(200); + // TTFB = 5000ms equals threshold — not exceeded + let result = detector.classify(Ok(resp), &policy, 5000, 6000); + assert!(matches!(result, ErrorClassification::Success(_))); + } + + #[test] + fn classify_2xx_no_high_latency_config_returns_success() { + let detector = ErrorDetector; + let policy = basic_retry_policy(); // no on_high_latency + let resp = make_response(200); + // High latency values but no config → Success + let result = detector.classify(Ok(resp), &policy, 99999, 99999); + assert!(matches!(result, ErrorClassification::Success(_))); + } + + #[test] + fn classify_timeout_takes_priority_over_high_latency() { + let detector = ErrorDetector; + let policy = high_latency_retry_policy(5000, LatencyMeasure::Ttfb); + let timeout = TimeoutError { duration_ms: 10000 }; + // Even with high latency config, timeout returns TimeoutError + let result = detector.classify(Err(timeout), &policy, 10000, 10000); + match result { + ErrorClassification::TimeoutError { duration_ms } => { + assert_eq!(duration_ms, 10000); + } + other => panic!("Expected TimeoutError, got {:?}", 
other), + } + } + + #[test] + fn classify_4xx_not_affected_by_high_latency() { + let detector = ErrorDetector; + let policy = high_latency_retry_policy(5000, LatencyMeasure::Ttfb); + let resp = make_response(429); + // Even with high latency, 4xx is still RetriableError + let result = detector.classify(Ok(resp), &policy, 6000, 7000); + assert!(matches!( + result, + ErrorClassification::RetriableError { status_code: 429, .. } + )); + } + + // ── P2 Edge Case: measure-specific classification tests ──────────── + + #[test] + fn classify_ttfb_measure_triggers_on_slow_ttfb_even_if_total_is_fast() { + let detector = ErrorDetector; + // measure: ttfb, threshold: 5000ms + let policy = high_latency_retry_policy(5000, LatencyMeasure::Ttfb); + let resp = make_response(200); + // TTFB = 6000ms exceeds threshold, but total = 4000ms is below threshold + let result = detector.classify(Ok(resp), &policy, 6000, 4000); + match result { + ErrorClassification::HighLatencyEvent { + measured_ms, + threshold_ms, + measure, + response, + } => { + assert_eq!(measured_ms, 6000, "Should measure TTFB, not total"); + assert_eq!(threshold_ms, 5000); + assert_eq!(measure, LatencyMeasure::Ttfb); + assert!(response.is_some(), "Completed response should be present"); + } + other => panic!("Expected HighLatencyEvent for slow TTFB, got {:?}", other), + } + } + + #[test] + fn classify_total_measure_does_not_trigger_when_only_ttfb_is_slow() { + let detector = ErrorDetector; + // measure: total, threshold: 5000ms + let policy = high_latency_retry_policy(5000, LatencyMeasure::Total); + let resp = make_response(200); + // TTFB = 8000ms is slow, but total = 4000ms is below threshold + // With measure: "total", only total time matters + let result = detector.classify(Ok(resp), &policy, 8000, 4000); + assert!( + matches!(result, ErrorClassification::Success(_)), + "measure: total should NOT trigger when only TTFB is slow but total is below threshold, got {:?}", + result + ); + } + + #[test] + fn 
classify_total_measure_triggers_on_slow_total_even_if_ttfb_is_fast() { + let detector = ErrorDetector; + // measure: total, threshold: 5000ms + let policy = high_latency_retry_policy(5000, LatencyMeasure::Total); + let resp = make_response(200); + // TTFB = 1000ms is fast, total = 7000ms exceeds threshold + let result = detector.classify(Ok(resp), &policy, 1000, 7000); + match result { + ErrorClassification::HighLatencyEvent { + measured_ms, + threshold_ms, + measure, + response, + } => { + assert_eq!(measured_ms, 7000, "Should measure total, not TTFB"); + assert_eq!(threshold_ms, 5000); + assert_eq!(measure, LatencyMeasure::Total); + assert!(response.is_some(), "Completed response should be present"); + } + other => panic!("Expected HighLatencyEvent for slow total, got {:?}", other), + } + } + + + // ── Property-based tests ─────────────────────────────────────────── + + use proptest::prelude::*; + + /// Generate an arbitrary RetryStrategy. + fn arb_retry_strategy() -> impl Strategy<Value = RetryStrategy> { + prop_oneof![ + Just(RetryStrategy::SameModel), + Just(RetryStrategy::SameProvider), + Just(RetryStrategy::DifferentProvider), + ] + } + + /// Generate an arbitrary StatusCodeEntry (single code in 100-599). + fn arb_status_code_entry() -> impl Strategy<Value = StatusCodeEntry> { + (100u16..=599u16).prop_map(StatusCodeEntry::Single) + } + + /// Generate an arbitrary StatusCodeConfig with 1-5 single status code entries. + fn arb_status_code_config() -> impl Strategy<Value = StatusCodeConfig> { + ( + proptest::collection::vec(arb_status_code_entry(), 1..=5), + arb_retry_strategy(), + 1u32..=10u32, + ) + .prop_map(|(codes, strategy, max_attempts)| StatusCodeConfig { + codes, + strategy, + max_attempts, + }) + } + + /// Generate an arbitrary RetryPolicy with 0-3 on_status_codes entries.
+ fn arb_retry_policy() -> impl Strategy<Value = RetryPolicy> { + ( + arb_retry_strategy(), + 1u32..=10u32, + proptest::collection::vec(arb_status_code_config(), 0..=3), + ) + .prop_map(|(default_strategy, default_max_attempts, on_status_codes)| { + RetryPolicy { + fallback_models: vec![], + default_strategy, + default_max_attempts, + on_status_codes, + on_timeout: None, + on_high_latency: None, + backoff: None, + retry_after_handling: None, + max_retry_duration_ms: None, + } + }) + } + + // Feature: retry-on-ratelimit, Property 5: Error Classification Correctness + // **Validates: Requirements 1.2** + proptest! { + #![proptest_config(proptest::prelude::ProptestConfig::with_cases(100))] + + /// Property 5: For any status code in 100-599 and any RetryPolicy, + /// classify() returns the correct variant: + /// 2xx → Success + /// 4xx/5xx → RetriableError with matching status_code + /// 1xx/3xx → NonRetriableError + #[test] + fn prop_error_classification_correctness( + status_code in 100u16..=599u16, + policy in arb_retry_policy(), + ) { + let detector = ErrorDetector; + let resp = make_response(status_code); + let result = detector.classify(Ok(resp), &policy, 0, 0); + + match status_code { + 200..=299 => { + prop_assert!( + matches!(result, ErrorClassification::Success(_)), + "Expected Success for status {}, got {:?}", status_code, result + ); + } + 400..=499 | 500..=599 => { + match &result { + ErrorClassification::RetriableError { status_code: sc, ..
} => { + prop_assert_eq!( + *sc, status_code, + "RetriableError status_code mismatch: expected {}, got {}", status_code, sc + ); + } + other => { + prop_assert!(false, "Expected RetriableError for status {}, got {:?}", status_code, other); + } + } + } + 100..=199 | 300..=399 => { + prop_assert!( + matches!(result, ErrorClassification::NonRetriableError(_)), + "Expected NonRetriableError for status {}, got {:?}", status_code, result + ); + } + _ => { + // Should not happen given our range 100-599 + prop_assert!(false, "Unexpected status code: {}", status_code); + } + } + } + } + + // Feature: retry-on-ratelimit, Property 17: Timeout vs High Latency Precedence + // **Validates: Requirements 2.13, 2.14, 2.15, 2a.19, 2a.20** + proptest! { + #![proptest_config(proptest::prelude::ProptestConfig::with_cases(100))] + + /// Property 17: When both on_high_latency and on_timeout are configured: + /// - Timeout (Err) → always TimeoutError regardless of latency config + /// - Completed 2xx exceeding threshold → HighLatencyEvent with response present + /// - Completed 2xx below/at threshold → Success + #[test] + fn prop_timeout_vs_high_latency_precedence( + threshold_ms in 1u64..=30_000u64, + elapsed_ttfb_ms in 0u64..=60_000u64, + elapsed_total_ms in 0u64..=60_000u64, + timeout_duration_ms in 1u64..=60_000u64, + measure_is_ttfb in proptest::bool::ANY, + // 0 = timeout scenario, 1 = completed-above-threshold, 2 = completed-below-threshold + scenario in 0u8..=2u8, + ) { + let measure = if measure_is_ttfb { LatencyMeasure::Ttfb } else { LatencyMeasure::Total }; + + let mut policy = basic_retry_policy(); + policy.on_high_latency = Some(crate::configuration::HighLatencyConfig { + threshold_ms, + measure, + min_triggers: 1, + trigger_window_seconds: None, + strategy: RetryStrategy::DifferentProvider, + max_attempts: 2, + block_duration_seconds: 300, + scope: crate::configuration::BlockScope::Model, + apply_to: crate::configuration::ApplyTo::Global, + }); + // Ensure on_timeout is 
configured + policy.on_timeout = Some(TimeoutRetryConfig { + strategy: RetryStrategy::DifferentProvider, + max_attempts: 2, + }); + + let detector = ErrorDetector; + + match scenario { + 0 => { + // Timeout scenario: Err(TimeoutError) → always TimeoutError + let timeout = TimeoutError { duration_ms: timeout_duration_ms }; + let result = detector.classify(Err(timeout), &policy, elapsed_ttfb_ms, elapsed_total_ms); + match result { + ErrorClassification::TimeoutError { duration_ms } => { + prop_assert_eq!(duration_ms, timeout_duration_ms, + "TimeoutError duration should match input"); + } + other => { + prop_assert!(false, + "Timeout should always produce TimeoutError, got {:?}", other); + } + } + } + 1 => { + // Completed 2xx with latency ABOVE threshold → HighLatencyEvent + // Force the measured value to exceed threshold + let forced_ttfb = if measure_is_ttfb { threshold_ms + 1 + (elapsed_ttfb_ms % 30_000) } else { elapsed_ttfb_ms }; + let forced_total = if !measure_is_ttfb { threshold_ms + 1 + (elapsed_total_ms % 30_000) } else { elapsed_total_ms }; + + let resp = make_response(200); + let result = detector.classify(Ok(resp), &policy, forced_ttfb, forced_total); + match result { + ErrorClassification::HighLatencyEvent { + measured_ms: actual_ms, + threshold_ms: actual_threshold, + measure: actual_measure, + response, + } => { + let expected_measured = if measure_is_ttfb { forced_ttfb } else { forced_total }; + prop_assert_eq!(actual_ms, expected_measured, + "HighLatencyEvent measured_ms should match the selected measure"); + prop_assert_eq!(actual_threshold, threshold_ms, + "HighLatencyEvent threshold_ms should match config"); + prop_assert_eq!(actual_measure, measure, + "HighLatencyEvent measure should match config"); + prop_assert!(response.is_some(), + "Completed response should be present in HighLatencyEvent"); + } + other => { + prop_assert!(false, + "Completed 2xx above threshold should produce HighLatencyEvent, got {:?}", other); + } + } + } + 2 => { + // 
Completed 2xx with latency AT or BELOW threshold → Success + // Force the measured value to be at or below threshold + let forced_ttfb = if measure_is_ttfb { threshold_ms.min(elapsed_ttfb_ms) } else { elapsed_ttfb_ms }; + let forced_total = if !measure_is_ttfb { threshold_ms.min(elapsed_total_ms) } else { elapsed_total_ms }; + + let resp = make_response(200); + let result = detector.classify(Ok(resp), &policy, forced_ttfb, forced_total); + prop_assert!( + matches!(result, ErrorClassification::Success(_)), + "Completed 2xx at/below threshold should be Success, got {:?}", result + ); + } + _ => {} // unreachable given range 0..=2 + } + } + } + +} diff --git a/crates/common/src/retry/error_response.rs b/crates/common/src/retry/error_response.rs index 7b764d11..6a5a6449 100644 --- a/crates/common/src/retry/error_response.rs +++ b/crates/common/src/retry/error_response.rs @@ -130,3 +130,472 @@ fn build_message(error: &RetryExhaustedError) -> String { } } +#[cfg(test)] +mod tests { + use super::*; + use crate::retry::{AttemptError, AttemptErrorType, RetryExhaustedError}; + use http_body_util::BodyExt; + use proptest::prelude::*; + + /// Helper to extract the JSON body from a response. 
+ async fn response_json(resp: Response>) -> serde_json::Value { + let body = resp.into_body().collect().await.unwrap().to_bytes(); + serde_json::from_slice(&body).unwrap() + } + + #[tokio::test] + async fn test_basic_http_error_response() { + let error = RetryExhaustedError { + attempts: vec![ + AttemptError { + model_id: "openai/gpt-4o".to_string(), + error_type: AttemptErrorType::HttpError { + status_code: 429, + body: b"rate limited".to_vec(), + }, + attempt_number: 1, + }, + AttemptError { + model_id: "anthropic/claude-3-5-sonnet".to_string(), + error_type: AttemptErrorType::HttpError { + status_code: 503, + body: b"unavailable".to_vec(), + }, + attempt_number: 2, + }, + ], + max_retry_after_seconds: Some(30), + shortest_remaining_block_seconds: Some(12), + retry_budget_exhausted: false, + }; + + let resp = build_error_response(&error, "req-123"); + assert_eq!(resp.status().as_u16(), 503); // most recent error + assert_eq!( + resp.headers().get("x-request-id").unwrap().to_str().unwrap(), + "req-123" + ); + assert_eq!( + resp.headers().get("content-type").unwrap().to_str().unwrap(), + "application/json" + ); + + let json = response_json(resp).await; + let err = &json["error"]; + assert_eq!(err["type"], "retry_exhausted"); + assert_eq!(err["total_attempts"], 2); + assert_eq!(err["observed_max_retry_after_seconds"], 30); + assert_eq!(err["shortest_remaining_block_seconds"], 12); + assert_eq!(err["retry_budget_exhausted"], false); + + let attempts = err["attempts"].as_array().unwrap(); + assert_eq!(attempts.len(), 2); + assert_eq!(attempts[0]["model"], "openai/gpt-4o"); + assert_eq!(attempts[0]["error_type"], "http_429"); + assert_eq!(attempts[0]["attempt"], 1); + assert_eq!(attempts[1]["model"], "anthropic/claude-3-5-sonnet"); + assert_eq!(attempts[1]["error_type"], "http_503"); + assert_eq!(attempts[1]["attempt"], 2); + } + + #[tokio::test] + async fn test_timeout_returns_504() { + let error = RetryExhaustedError { + attempts: vec![AttemptError { + model_id: 
"openai/gpt-4o".to_string(), + error_type: AttemptErrorType::Timeout { duration_ms: 30000 }, + attempt_number: 1, + }], + max_retry_after_seconds: None, + shortest_remaining_block_seconds: None, + retry_budget_exhausted: false, + }; + + let resp = build_error_response(&error, "req-timeout"); + assert_eq!(resp.status().as_u16(), 504); + + let json = response_json(resp).await; + let err = &json["error"]; + assert_eq!(err["attempts"][0]["error_type"], "timeout_30000ms"); + assert!(err["message"] + .as_str() + .unwrap() + .contains("timed out")); + } + + #[tokio::test] + async fn test_high_latency_returns_504() { + let error = RetryExhaustedError { + attempts: vec![AttemptError { + model_id: "openai/gpt-4o".to_string(), + error_type: AttemptErrorType::HighLatency { + measured_ms: 8000, + threshold_ms: 5000, + }, + attempt_number: 1, + }], + max_retry_after_seconds: None, + shortest_remaining_block_seconds: None, + retry_budget_exhausted: false, + }; + + let resp = build_error_response(&error, "req-latency"); + assert_eq!(resp.status().as_u16(), 504); + + let json = response_json(resp).await; + let err = &json["error"]; + assert_eq!( + err["attempts"][0]["error_type"], + "high_latency_8000ms_threshold_5000ms" + ); + assert!(err["message"] + .as_str() + .unwrap() + .contains("high latency")); + } + + #[tokio::test] + async fn test_optional_fields_omitted_when_none() { + let error = RetryExhaustedError { + attempts: vec![AttemptError { + model_id: "openai/gpt-4o".to_string(), + error_type: AttemptErrorType::HttpError { + status_code: 429, + body: vec![], + }, + attempt_number: 1, + }], + max_retry_after_seconds: None, + shortest_remaining_block_seconds: None, + retry_budget_exhausted: false, + }; + + let resp = build_error_response(&error, "req-456"); + let json = response_json(resp).await; + let err = &json["error"]; + + // These fields should not be present + assert!(err.get("observed_max_retry_after_seconds").is_none()); + 
assert!(err.get("shortest_remaining_block_seconds").is_none()); + + // These should always be present + assert!(err.get("retry_budget_exhausted").is_some()); + assert!(err.get("total_attempts").is_some()); + assert!(err.get("type").is_some()); + assert!(err.get("message").is_some()); + assert!(err.get("attempts").is_some()); + } + + #[tokio::test] + async fn test_retry_budget_exhausted_message() { + let error = RetryExhaustedError { + attempts: vec![AttemptError { + model_id: "openai/gpt-4o".to_string(), + error_type: AttemptErrorType::HttpError { + status_code: 429, + body: vec![], + }, + attempt_number: 1, + }], + max_retry_after_seconds: None, + shortest_remaining_block_seconds: None, + retry_budget_exhausted: true, + }; + + let resp = build_error_response(&error, "req-budget"); + let json = response_json(resp).await; + let err = &json["error"]; + assert_eq!(err["retry_budget_exhausted"], true); + assert!(err["message"] + .as_str() + .unwrap() + .contains("budget exceeded")); + } + + #[tokio::test] + async fn test_empty_attempts_returns_502() { + let error = RetryExhaustedError { + attempts: vec![], + max_retry_after_seconds: None, + shortest_remaining_block_seconds: None, + retry_budget_exhausted: false, + }; + + let resp = build_error_response(&error, "req-empty"); + assert_eq!(resp.status().as_u16(), 502); + + let json = response_json(resp).await; + assert_eq!(json["error"]["total_attempts"], 0); + assert_eq!(json["error"]["attempts"].as_array().unwrap().len(), 0); + } + + #[tokio::test] + async fn test_request_id_preserved_in_header() { + let error = RetryExhaustedError { + attempts: vec![AttemptError { + model_id: "m".to_string(), + error_type: AttemptErrorType::HttpError { + status_code: 500, + body: vec![], + }, + attempt_number: 1, + }], + max_retry_after_seconds: None, + shortest_remaining_block_seconds: None, + retry_budget_exhausted: false, + }; + + let resp = build_error_response(&error, "unique-request-id-abc-123"); + assert_eq!( + resp.headers() + 
.get("x-request-id") + .unwrap() + .to_str() + .unwrap(), + "unique-request-id-abc-123" + ); + } + + #[tokio::test] + async fn test_mixed_error_types_in_attempts() { + let error = RetryExhaustedError { + attempts: vec![ + AttemptError { + model_id: "openai/gpt-4o".to_string(), + error_type: AttemptErrorType::HttpError { + status_code: 429, + body: vec![], + }, + attempt_number: 1, + }, + AttemptError { + model_id: "anthropic/claude".to_string(), + error_type: AttemptErrorType::Timeout { duration_ms: 5000 }, + attempt_number: 2, + }, + AttemptError { + model_id: "gemini/pro".to_string(), + error_type: AttemptErrorType::HighLatency { + measured_ms: 10000, + threshold_ms: 3000, + }, + attempt_number: 3, + }, + ], + max_retry_after_seconds: Some(60), + shortest_remaining_block_seconds: Some(5), + retry_budget_exhausted: false, + }; + + // Last attempt is HighLatency → 504 + let resp = build_error_response(&error, "req-mixed"); + assert_eq!(resp.status().as_u16(), 504); + + let json = response_json(resp).await; + let err = &json["error"]; + assert_eq!(err["total_attempts"], 3); + assert_eq!(err["observed_max_retry_after_seconds"], 60); + assert_eq!(err["shortest_remaining_block_seconds"], 5); + + let attempts = err["attempts"].as_array().unwrap(); + assert_eq!(attempts[0]["error_type"], "http_429"); + assert_eq!(attempts[1]["error_type"], "timeout_5000ms"); + assert_eq!(attempts[2]["error_type"], "high_latency_10000ms_threshold_3000ms"); + } + + // ── Proptest strategies ──────────────────────────────────────────────── + + /// Generate an arbitrary AttemptErrorType. 
+ fn arb_attempt_error_type() -> impl Strategy { + prop_oneof![ + (100u16..=599u16, proptest::collection::vec(any::(), 0..32)) + .prop_map(|(status_code, body)| AttemptErrorType::HttpError { status_code, body }), + (1u64..=120_000u64) + .prop_map(|duration_ms| AttemptErrorType::Timeout { duration_ms }), + (1u64..=120_000u64, 1u64..=120_000u64) + .prop_map(|(measured_ms, threshold_ms)| AttemptErrorType::HighLatency { + measured_ms, + threshold_ms, + }), + ] + } + + /// Generate an arbitrary AttemptError with a model_id from a small set of + /// realistic provider/model identifiers. + fn arb_attempt_error() -> impl Strategy { + let model_ids = prop_oneof![ + Just("openai/gpt-4o".to_string()), + Just("openai/gpt-4o-mini".to_string()), + Just("anthropic/claude-3-5-sonnet".to_string()), + Just("gemini/pro".to_string()), + Just("azure/gpt-4o".to_string()), + ]; + (model_ids, arb_attempt_error_type(), 1u32..=10u32).prop_map( + |(model_id, error_type, attempt_number)| AttemptError { + model_id, + error_type, + attempt_number, + }, + ) + } + + /// Generate an arbitrary RetryExhaustedError with 1..=8 attempts. + fn arb_retry_exhausted_error() -> impl Strategy { + ( + proptest::collection::vec(arb_attempt_error(), 1..=8), + proptest::option::of(1u64..=600u64), + proptest::option::of(1u64..=600u64), + any::(), + ) + .prop_map( + |(attempts, max_retry_after_seconds, shortest_remaining_block_seconds, retry_budget_exhausted)| { + RetryExhaustedError { + attempts, + max_retry_after_seconds, + shortest_remaining_block_seconds, + retry_budget_exhausted, + } + }, + ) + } + + /// Generate an arbitrary request_id (non-empty ASCII string valid for HTTP headers). + fn arb_request_id() -> impl Strategy { + "[a-zA-Z0-9_-]{1,64}" + } + + // Feature: retry-on-ratelimit, Property 21: Error Response Contains Attempt Details + // **Validates: Requirements 10.4, 10.5, 10.7** + proptest! 
{ + #![proptest_config(ProptestConfig::with_cases(100))] + + /// Property 21: For any exhausted retry sequence, the error response + /// must include all attempted model identifiers and their error types, + /// and must preserve the original request_id. + #[test] + fn prop_error_response_contains_attempt_details( + error in arb_retry_exhausted_error(), + request_id in arb_request_id(), + ) { + let rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap(); + rt.block_on(async { + let resp = build_error_response(&error, &request_id); + + // request_id preserved in x-request-id header + let header_val = resp.headers().get("x-request-id") + .expect("x-request-id header must be present"); + prop_assert_eq!(header_val.to_str().unwrap(), request_id.as_str()); + + // Content-Type is application/json + let ct = resp.headers().get("content-type") + .expect("content-type header must be present"); + prop_assert_eq!(ct.to_str().unwrap(), "application/json"); + + // Parse JSON body + let body = resp.into_body().collect().await.unwrap().to_bytes(); + let json: serde_json::Value = serde_json::from_slice(&body) + .expect("response body must be valid JSON"); + + let err_obj = &json["error"]; + + // type is always "retry_exhausted" + prop_assert_eq!(err_obj["type"].as_str().unwrap(), "retry_exhausted"); + + // total_attempts matches input + prop_assert_eq!( + err_obj["total_attempts"].as_u64().unwrap(), + error.attempts.len() as u64 + ); + + // retry_budget_exhausted matches input + prop_assert_eq!( + err_obj["retry_budget_exhausted"].as_bool().unwrap(), + error.retry_budget_exhausted + ); + + // attempts array has correct length + let attempts_arr = err_obj["attempts"].as_array() + .expect("attempts must be an array"); + prop_assert_eq!(attempts_arr.len(), error.attempts.len()); + + // Every attempt's model_id and error_type are present and correct + for (i, attempt) in error.attempts.iter().enumerate() { + let json_attempt = &attempts_arr[i]; + + // 
model_id preserved + prop_assert_eq!( + json_attempt["model"].as_str().unwrap(), + attempt.model_id.as_str() + ); + + // attempt_number preserved + prop_assert_eq!( + json_attempt["attempt"].as_u64().unwrap(), + attempt.attempt_number as u64 + ); + + // error_type string matches the variant + let error_type_str = json_attempt["error_type"].as_str().unwrap(); + match &attempt.error_type { + AttemptErrorType::HttpError { status_code, .. } => { + prop_assert_eq!( + error_type_str, + &format!("http_{}", status_code) + ); + } + AttemptErrorType::Timeout { duration_ms } => { + prop_assert_eq!( + error_type_str, + &format!("timeout_{}ms", duration_ms) + ); + } + AttemptErrorType::HighLatency { measured_ms, threshold_ms } => { + prop_assert_eq!( + error_type_str, + &format!("high_latency_{}ms_threshold_{}ms", measured_ms, threshold_ms) + ); + } + } + } + + // Optional fields: observed_max_retry_after_seconds + match error.max_retry_after_seconds { + Some(v) => { + prop_assert_eq!( + err_obj["observed_max_retry_after_seconds"].as_u64().unwrap(), + v + ); + } + None => { + prop_assert!(err_obj.get("observed_max_retry_after_seconds").is_none() + || err_obj["observed_max_retry_after_seconds"].is_null()); + } + } + + // Optional fields: shortest_remaining_block_seconds + match error.shortest_remaining_block_seconds { + Some(v) => { + prop_assert_eq!( + err_obj["shortest_remaining_block_seconds"].as_u64().unwrap(), + v + ); + } + None => { + prop_assert!(err_obj.get("shortest_remaining_block_seconds").is_none() + || err_obj["shortest_remaining_block_seconds"].is_null()); + } + } + + // message is a non-empty string + let message = err_obj["message"].as_str() + .expect("message must be a string"); + prop_assert!(!message.is_empty()); + + Ok(()) + })?; + } + } +} diff --git a/crates/common/src/retry/latency_block_state.rs b/crates/common/src/retry/latency_block_state.rs index 60dec185..d2add5d9 100644 --- a/crates/common/src/retry/latency_block_state.rs +++ 
b/crates/common/src/retry/latency_block_state.rs @@ -118,3 +118,265 @@ impl Default for LatencyBlockStateManager { } } +#[cfg(test)] +mod tests { + use super::*; + use std::thread; + use std::time::Duration; + + #[test] + fn test_new_manager_has_no_blocks() { + let mgr = LatencyBlockStateManager::new(); + assert!(!mgr.is_blocked("openai/gpt-4o")); + assert!(mgr.remaining_block_duration("openai/gpt-4o").is_none()); + } + + #[test] + fn test_record_block_and_is_blocked() { + let mgr = LatencyBlockStateManager::new(); + mgr.record_block("openai/gpt-4o", 60, 5500); + assert!(mgr.is_blocked("openai/gpt-4o")); + assert!(!mgr.is_blocked("anthropic/claude")); + } + + #[test] + fn test_remaining_block_duration() { + let mgr = LatencyBlockStateManager::new(); + mgr.record_block("openai/gpt-4o", 10, 5000); + let remaining = mgr.remaining_block_duration("openai/gpt-4o").unwrap(); + assert!(remaining <= Duration::from_secs(11)); + assert!(remaining > Duration::from_secs(8)); + } + + #[test] + fn test_expired_entry_cleaned_up_on_is_blocked() { + let mgr = LatencyBlockStateManager::new(); + mgr.record_block("openai/gpt-4o", 0, 5000); + thread::sleep(Duration::from_millis(10)); + assert!(!mgr.is_blocked("openai/gpt-4o")); + } + + #[test] + fn test_expired_entry_cleaned_up_on_remaining() { + let mgr = LatencyBlockStateManager::new(); + mgr.record_block("openai/gpt-4o", 0, 5000); + thread::sleep(Duration::from_millis(10)); + assert!(mgr.remaining_block_duration("openai/gpt-4o").is_none()); + } + + #[test] + fn test_max_expiration_semantics_longer_wins() { + let mgr = LatencyBlockStateManager::new(); + mgr.record_block("openai/gpt-4o", 10, 5000); + let first_remaining = mgr.remaining_block_duration("openai/gpt-4o").unwrap(); + + mgr.record_block("openai/gpt-4o", 60, 6000); + let second_remaining = mgr.remaining_block_duration("openai/gpt-4o").unwrap(); + assert!(second_remaining > first_remaining); + } + + #[test] + fn test_max_expiration_semantics_shorter_does_not_overwrite() { + 
let mgr = LatencyBlockStateManager::new(); + mgr.record_block("openai/gpt-4o", 60, 5000); + let first_remaining = mgr.remaining_block_duration("openai/gpt-4o").unwrap(); + + mgr.record_block("openai/gpt-4o", 5, 6000); + let second_remaining = mgr.remaining_block_duration("openai/gpt-4o").unwrap(); + // Should still be close to the original 60s + assert!(second_remaining > Duration::from_secs(50)); + let diff = if first_remaining > second_remaining { + first_remaining - second_remaining + } else { + second_remaining - first_remaining + }; + assert!(diff < Duration::from_secs(2)); + } + + #[test] + fn test_is_model_blocked_model_scope() { + let mgr = LatencyBlockStateManager::new(); + mgr.record_block("openai/gpt-4o", 60, 5000); + + assert!(mgr.is_model_blocked("openai/gpt-4o", BlockScope::Model)); + assert!(!mgr.is_model_blocked("openai/gpt-4o-mini", BlockScope::Model)); + } + + #[test] + fn test_is_model_blocked_provider_scope() { + let mgr = LatencyBlockStateManager::new(); + mgr.record_block("openai", 60, 5000); + + assert!(mgr.is_model_blocked("openai/gpt-4o", BlockScope::Provider)); + assert!(mgr.is_model_blocked("openai/gpt-4o-mini", BlockScope::Provider)); + assert!(!mgr.is_model_blocked("anthropic/claude", BlockScope::Provider)); + } + + #[test] + fn test_multiple_identifiers_independent() { + let mgr = LatencyBlockStateManager::new(); + mgr.record_block("openai/gpt-4o", 60, 5000); + mgr.record_block("anthropic/claude", 30, 4000); + + assert!(mgr.is_blocked("openai/gpt-4o")); + assert!(mgr.is_blocked("anthropic/claude")); + assert!(!mgr.is_blocked("azure/gpt-4o")); + } + + #[test] + fn test_record_block_stores_measured_latency() { + let mgr = LatencyBlockStateManager::new(); + mgr.record_block("openai/gpt-4o", 60, 5500); + + // Verify the entry exists and has the correct latency + let entry = mgr.global_state.get("openai/gpt-4o").unwrap(); + assert_eq!(entry.1, 5500); + } + + #[test] + fn test_latency_updated_when_expiration_extended() { + let mgr = 
LatencyBlockStateManager::new(); + mgr.record_block("openai/gpt-4o", 10, 5000); + + // Extend with longer duration and different latency + mgr.record_block("openai/gpt-4o", 60, 7000); + + let entry = mgr.global_state.get("openai/gpt-4o").unwrap(); + assert_eq!(entry.1, 7000); + } + + #[test] + fn test_latency_not_updated_when_expiration_not_extended() { + let mgr = LatencyBlockStateManager::new(); + mgr.record_block("openai/gpt-4o", 60, 5000); + + // Shorter duration — should NOT update + mgr.record_block("openai/gpt-4o", 5, 9000); + + let entry = mgr.global_state.get("openai/gpt-4o").unwrap(); + // Latency should remain 5000 since expiration wasn't extended + assert_eq!(entry.1, 5000); + } + + #[test] + fn test_zero_duration_block_expires_immediately() { + let mgr = LatencyBlockStateManager::new(); + mgr.record_block("openai/gpt-4o", 0, 5000); + thread::sleep(Duration::from_millis(5)); + assert!(!mgr.is_blocked("openai/gpt-4o")); + } + + #[test] + fn test_default_trait() { + let mgr = LatencyBlockStateManager::default(); + assert!(!mgr.is_blocked("anything")); + } + + // --- Property-based tests --- + + use proptest::prelude::*; + + fn arb_identifier() -> impl Strategy { + prop_oneof![ + "[a-z]{3,8}/[a-z0-9\\-]{3,12}".prop_map(|s| s), + "[a-z]{3,8}".prop_map(|s| s), + ] + } + + /// A single block recording: (block_duration_seconds, measured_latency_ms) + fn arb_block_recording() -> impl Strategy { + (1u64..=600, 100u64..=30_000) + } + + // Feature: retry-on-ratelimit, Property 22: Latency Block State Max Expiration Update + // **Validates: Requirements 14.15** + proptest! { + #![proptest_config(ProptestConfig::with_cases(100))] + + /// Property 22 – Case 1: After recording multiple blocks for the same identifier + /// with different durations, the remaining block duration reflects the maximum + /// duration recorded (max-expiration semantics). 
+ #[test] + fn prop_latency_block_max_expiration_update( + identifier in arb_identifier(), + recordings in prop::collection::vec(arb_block_recording(), 2..=10), + ) { + let mgr = LatencyBlockStateManager::new(); + + for &(duration, latency) in &recordings { + mgr.record_block(&identifier, duration, latency); + } + + let max_duration = recordings.iter().map(|&(d, _)| d).max().unwrap(); + + // The identifier should still be blocked + let remaining = mgr.remaining_block_duration(&identifier); + prop_assert!( + remaining.is_some(), + "Identifier {} should be blocked after {} recordings (max_duration={}s)", + identifier, recordings.len(), max_duration + ); + + let remaining_secs = remaining.unwrap().as_secs(); + + // Remaining should be close to max_duration (allow 2s tolerance for execution time) + prop_assert!( + remaining_secs >= max_duration.saturating_sub(2), + "Remaining {}s should reflect the max duration ({}s), not a smaller value. Recordings: {:?}", + remaining_secs, max_duration, recordings + ); + + prop_assert!( + remaining_secs <= max_duration + 1, + "Remaining {}s should not exceed max duration {}s + tolerance. Recordings: {:?}", + remaining_secs, max_duration, recordings + ); + } + + /// Property 22 – Case 2: measured_latency_ms is updated when expiration is extended + /// but NOT when a shorter duration is recorded. 
+ #[test] + fn prop_latency_block_measured_latency_update_semantics( + identifier in arb_identifier(), + first_duration in 10u64..=300, + first_latency in 100u64..=30_000, + extra_duration in 1u64..=300, + longer_latency in 100u64..=30_000, + shorter_duration in 1u64..=9, + shorter_latency in 100u64..=30_000, + ) { + let mgr = LatencyBlockStateManager::new(); + + // Record initial block + mgr.record_block(&identifier, first_duration, first_latency); + { + let entry = mgr.global_state.get(&identifier).unwrap(); + prop_assert_eq!(entry.1, first_latency); + } + + // Record a longer duration — latency SHOULD be updated + let longer_duration = first_duration + extra_duration; + mgr.record_block(&identifier, longer_duration, longer_latency); + { + let entry = mgr.global_state.get(&identifier).unwrap(); + prop_assert_eq!( + entry.1, longer_latency, + "Latency should be updated to {} when expiration is extended (duration {} > {})", + longer_latency, longer_duration, first_duration + ); + } + + // Record a shorter duration — latency should NOT be updated + mgr.record_block(&identifier, shorter_duration, shorter_latency); + { + let entry = mgr.global_state.get(&identifier).unwrap(); + prop_assert_eq!( + entry.1, longer_latency, + "Latency should remain {} (not {}) when shorter duration {} < {} doesn't extend expiration", + longer_latency, shorter_latency, shorter_duration, longer_duration + ); + } + } + } +} + diff --git a/crates/common/src/retry/latency_trigger.rs b/crates/common/src/retry/latency_trigger.rs index dab5ffc7..059dbad8 100644 --- a/crates/common/src/retry/latency_trigger.rs +++ b/crates/common/src/retry/latency_trigger.rs @@ -57,3 +57,175 @@ impl Default for LatencyTriggerCounter { } } +#[cfg(test)] +mod tests { + use super::*; + use std::thread::sleep; + use std::time::Duration; + + #[test] + fn test_record_event_returns_true_when_threshold_met() { + let counter = LatencyTriggerCounter::new(); + assert!(!counter.record_event("model-a", 3, 60)); + 
assert!(!counter.record_event("model-a", 3, 60)); + assert!(counter.record_event("model-a", 3, 60)); + } + + #[test] + fn test_record_event_single_trigger_always_fires() { + let counter = LatencyTriggerCounter::new(); + assert!(counter.record_event("model-a", 1, 60)); + } + + #[test] + fn test_events_expire_outside_window() { + let counter = LatencyTriggerCounter::new(); + // Record 2 events + counter.record_event("model-a", 3, 1); + counter.record_event("model-a", 3, 1); + // Wait for them to expire + sleep(Duration::from_millis(1100)); + // Third event should not meet threshold since previous two expired + assert!(!counter.record_event("model-a", 3, 1)); + } + + #[test] + fn test_reset_clears_counter() { + let counter = LatencyTriggerCounter::new(); + counter.record_event("model-a", 3, 60); + counter.record_event("model-a", 3, 60); + counter.reset("model-a"); + // After reset, need 3 fresh events again + assert!(!counter.record_event("model-a", 3, 60)); + assert!(!counter.record_event("model-a", 3, 60)); + assert!(counter.record_event("model-a", 3, 60)); + } + + #[test] + fn test_reset_nonexistent_identifier_is_noop() { + let counter = LatencyTriggerCounter::new(); + // Should not panic + counter.reset("nonexistent"); + } + + #[test] + fn test_separate_identifiers_are_independent() { + let counter = LatencyTriggerCounter::new(); + counter.record_event("model-a", 2, 60); + counter.record_event("model-b", 2, 60); + // model-a has 1 event, model-b has 1 event — neither at threshold of 2 + assert!(!counter.record_event("model-b", 3, 60)); + // model-a reaches threshold + assert!(counter.record_event("model-a", 2, 60)); + } + + #[test] + fn test_threshold_exceeded_still_returns_true() { + let counter = LatencyTriggerCounter::new(); + assert!(counter.record_event("model-a", 1, 60)); + // Already past threshold, still returns true + assert!(counter.record_event("model-a", 1, 60)); + assert!(counter.record_event("model-a", 1, 60)); + } + + // --- Property-based tests --- 
+ + use proptest::prelude::*; + + // Feature: retry-on-ratelimit, Property 18: Latency Trigger Counter Sliding Window + // **Validates: Requirements 2a.6, 2a.7, 2a.8, 2a.21, 14.1, 14.2, 14.3, 14.12** + proptest! { + #![proptest_config(ProptestConfig::with_cases(100))] + + /// Property 18 – Case 1: Recording N events in quick succession (all within window) + /// returns true iff N >= min_triggers. + #[test] + fn prop_sliding_window_threshold( + min_triggers in 1u32..=10, + trigger_window_seconds in 1u64..=60, + num_events in 1u32..=20, + ) { + let counter = LatencyTriggerCounter::new(); + let identifier = "test-model"; + + let mut last_result = false; + for i in 1..=num_events { + last_result = counter.record_event(identifier, min_triggers, trigger_window_seconds); + // Before reaching threshold, should be false + if i < min_triggers { + prop_assert!(!last_result, "Expected false at event {} with min_triggers {}", i, min_triggers); + } else { + // At or past threshold, should be true + prop_assert!(last_result, "Expected true at event {} with min_triggers {}", i, min_triggers); + } + } + + // Final result should match whether we recorded enough events + prop_assert_eq!(last_result, num_events >= min_triggers); + } + + /// Property 18 – Case 2: After reset, counter starts fresh and previous events + /// do not count toward the threshold. 
+ #[test] + fn prop_reset_clears_counter( + min_triggers in 2u32..=10, + trigger_window_seconds in 1u64..=60, + events_before_reset in 1u32..=10, + ) { + let counter = LatencyTriggerCounter::new(); + let identifier = "test-model"; + + // Record some events before reset + for _ in 0..events_before_reset { + counter.record_event(identifier, min_triggers, trigger_window_seconds); + } + + // Reset the counter + counter.reset(identifier); + + // After reset, a single event should not meet threshold (min_triggers >= 2) + let result = counter.record_event(identifier, min_triggers, trigger_window_seconds); + prop_assert!(!result, "After reset, first event should not meet threshold of {}", min_triggers); + + // Need min_triggers - 1 more events to reach threshold again + let mut final_result = result; + for _ in 1..min_triggers { + final_result = counter.record_event(identifier, min_triggers, trigger_window_seconds); + } + prop_assert!(final_result, "After reset + {} events, should meet threshold", min_triggers); + } + + /// Property 18 – Case 3: Different identifiers are independent — events for one + /// identifier do not affect the count for another. 
+ #[test] + fn prop_identifiers_independent( + min_triggers in 1u32..=10, + trigger_window_seconds in 1u64..=60, + events_a in 1u32..=20, + events_b in 1u32..=20, + ) { + let counter = LatencyTriggerCounter::new(); + let id_a = "model-a"; + let id_b = "model-b"; + + // Record events for identifier A + let mut result_a = false; + for _ in 0..events_a { + result_a = counter.record_event(id_a, min_triggers, trigger_window_seconds); + } + + // Record events for identifier B + let mut result_b = false; + for _ in 0..events_b { + result_b = counter.record_event(id_b, min_triggers, trigger_window_seconds); + } + + // Each identifier's result depends only on its own event count + prop_assert_eq!(result_a, events_a >= min_triggers, + "id_a: events={}, min_triggers={}", events_a, min_triggers); + prop_assert_eq!(result_b, events_b >= min_triggers, + "id_b: events={}, min_triggers={}", events_b, min_triggers); + } + } + +} // mod tests \ No newline at end of file diff --git a/crates/common/src/retry/mod.rs b/crates/common/src/retry/mod.rs index e04b0ef3..3108fc68 100644 --- a/crates/common/src/retry/mod.rs +++ b/crates/common/src/retry/mod.rs @@ -331,3 +331,455 @@ pub enum ValidationWarning { } +#[cfg(test)] +mod tests { + use super::*; + use crate::configuration::{LlmProviderType, LlmProvider}; + use bytes::Bytes; + use hyper::header::{HeaderMap, HeaderValue, AUTHORIZATION}; + use proptest::prelude::*; + + fn make_provider(name: &str, interface: LlmProviderType, key: Option<&str>) -> LlmProvider { + LlmProvider { + name: name.to_string(), + provider_interface: interface, + access_key: key.map(|k| k.to_string()), + model: Some(name.to_string()), + default: None, + stream: None, + endpoint: None, + port: None, + rate_limits: None, + usage: None, + routing_preferences: None, + cluster_name: None, + base_url_path_prefix: None, + internal: None, + passthrough_auth: None, + retry_policy: None, + } + } + + // ── RequestSignature tests ───────────────────────────────────────── + + 
#[test] + fn test_request_signature_computes_hash() { + let body = b"hello world"; + let headers = HeaderMap::new(); + let sig = RequestSignature::new(body, &headers, false, "openai/gpt-4o".to_string()); + + // SHA-256 of "hello world" is deterministic + let mut hasher = Sha256::new(); + hasher.update(b"hello world"); + let expected: [u8; 32] = hasher.finalize().into(); + assert_eq!(sig.body_hash, expected); + assert!(!sig.streaming); + assert_eq!(sig.original_model, "openai/gpt-4o"); + } + + #[test] + fn test_request_signature_preserves_headers() { + let mut headers = HeaderMap::new(); + headers.insert("x-custom", HeaderValue::from_static("value")); + let sig = RequestSignature::new(b"body", &headers, true, "model".to_string()); + assert_eq!(sig.headers.get("x-custom").unwrap(), "value"); + assert!(sig.streaming); + } + + #[test] + fn test_request_signature_different_bodies_different_hashes() { + let headers = HeaderMap::new(); + let sig1 = RequestSignature::new(b"body1", &headers, false, "m".to_string()); + let sig2 = RequestSignature::new(b"body2", &headers, false, "m".to_string()); + assert_ne!(sig1.body_hash, sig2.body_hash); + } + + // ── RetryGate tests ──────────────────────────────────────────────── + + #[test] + fn test_retry_gate_default_permits() { + let gate = RetryGate::default(); + // Should be able to acquire at least one permit + assert!(gate.try_acquire().is_some()); + } + + #[test] + fn test_retry_gate_exhaustion() { + let gate = RetryGate::new(1); + let permit = gate.try_acquire(); + assert!(permit.is_some()); + // Second acquire should fail (only 1 permit) + assert!(gate.try_acquire().is_none()); + // Drop permit, should be able to acquire again + drop(permit); + assert!(gate.try_acquire().is_some()); + } + + #[test] + fn test_retry_gate_custom_capacity() { + let gate = RetryGate::new(3); + let _p1 = gate.try_acquire().unwrap(); + let _p2 = gate.try_acquire().unwrap(); + let _p3 = gate.try_acquire().unwrap(); + 
assert!(gate.try_acquire().is_none()); + } + + // ── rebuild_request_for_provider tests ───────────────────────────── + + #[test] + fn test_rebuild_updates_model_field() { + let body = Bytes::from(r#"{"model":"gpt-4o","messages":[]}"#); + let headers = HeaderMap::new(); + let provider = make_provider("openai/gpt-4o-mini", LlmProviderType::OpenAI, Some("sk-test")); + + let (new_body, _) = rebuild_request_for_provider(&body, &provider, &headers).unwrap(); + let json: serde_json::Value = serde_json::from_slice(&new_body).unwrap(); + assert_eq!(json["model"], "gpt-4o-mini"); + } + + #[test] + fn test_rebuild_preserves_other_fields() { + let body = Bytes::from(r#"{"model":"gpt-4o","messages":[{"role":"user","content":"hi"}],"temperature":0.7}"#); + let headers = HeaderMap::new(); + let provider = make_provider("openai/gpt-4o-mini", LlmProviderType::OpenAI, Some("sk-test")); + + let (new_body, _) = rebuild_request_for_provider(&body, &provider, &headers).unwrap(); + let json: serde_json::Value = serde_json::from_slice(&new_body).unwrap(); + assert_eq!(json["messages"][0]["role"], "user"); + assert_eq!(json["messages"][0]["content"], "hi"); + assert_eq!(json["temperature"], 0.7); + } + + #[test] + fn test_rebuild_sets_openai_auth() { + let body = Bytes::from(r#"{"model":"old"}"#); + let mut headers = HeaderMap::new(); + headers.insert(AUTHORIZATION, HeaderValue::from_static("Bearer old-key")); + let provider = make_provider("openai/gpt-4o", LlmProviderType::OpenAI, Some("sk-new")); + + let (_, new_headers) = rebuild_request_for_provider(&body, &provider, &headers).unwrap(); + assert_eq!( + new_headers.get(AUTHORIZATION).unwrap().to_str().unwrap(), + "Bearer sk-new" + ); + assert!(new_headers.get("x-api-key").is_none()); + } + + #[test] + fn test_rebuild_sets_anthropic_auth() { + let body = Bytes::from(r#"{"model":"old"}"#); + let mut headers = HeaderMap::new(); + headers.insert(AUTHORIZATION, HeaderValue::from_static("Bearer old-key")); + let provider = make_provider( + 
"anthropic/claude-3-5-sonnet", + LlmProviderType::Anthropic, + Some("ant-key"), + ); + + let (_, new_headers) = rebuild_request_for_provider(&body, &provider, &headers).unwrap(); + // Anthropic uses x-api-key, not Authorization + assert!(new_headers.get(AUTHORIZATION).is_none()); + assert_eq!( + new_headers.get("x-api-key").unwrap().to_str().unwrap(), + "ant-key" + ); + assert_eq!( + new_headers.get("anthropic-version").unwrap().to_str().unwrap(), + "2023-06-01" + ); + } + + #[test] + fn test_rebuild_sanitizes_old_auth_headers() { + let body = Bytes::from(r#"{"model":"old"}"#); + let mut headers = HeaderMap::new(); + headers.insert(AUTHORIZATION, HeaderValue::from_static("Bearer old-key")); + headers.insert("x-api-key", HeaderValue::from_static("old-api-key")); + headers.insert("anthropic-version", HeaderValue::from_static("old-version")); + headers.insert("x-custom", HeaderValue::from_static("keep-me")); + + let provider = make_provider("openai/gpt-4o", LlmProviderType::OpenAI, Some("sk-new")); + let (_, new_headers) = rebuild_request_for_provider(&body, &provider, &headers).unwrap(); + + // Old x-api-key and anthropic-version should be removed + assert!(new_headers.get("anthropic-version").is_none()); + // New auth should be set + assert_eq!( + new_headers.get(AUTHORIZATION).unwrap().to_str().unwrap(), + "Bearer sk-new" + ); + // Custom headers preserved + assert_eq!( + new_headers.get("x-custom").unwrap().to_str().unwrap(), + "keep-me" + ); + } + + #[test] + fn test_rebuild_passthrough_auth_skips_credentials() { + let body = Bytes::from(r#"{"model":"old"}"#); + let mut headers = HeaderMap::new(); + headers.insert(AUTHORIZATION, HeaderValue::from_static("Bearer client-key")); + + let mut provider = make_provider("openai/gpt-4o", LlmProviderType::OpenAI, Some("sk-new")); + provider.passthrough_auth = Some(true); + + let (_, new_headers) = rebuild_request_for_provider(&body, &provider, &headers).unwrap(); + // Auth headers are sanitized, and passthrough_auth means 
no new ones are set + assert!(new_headers.get(AUTHORIZATION).is_none()); + } + + #[test] + fn test_rebuild_missing_access_key_errors() { + let body = Bytes::from(r#"{"model":"old"}"#); + let headers = HeaderMap::new(); + let provider = make_provider("openai/gpt-4o", LlmProviderType::OpenAI, None); + + let result = rebuild_request_for_provider(&body, &provider, &headers); + assert!(matches!(result, Err(RebuildError::MissingAccessKey(_)))); + } + + #[test] + fn test_rebuild_invalid_json_errors() { + let body = Bytes::from("not json"); + let headers = HeaderMap::new(); + let provider = make_provider("openai/gpt-4o", LlmProviderType::OpenAI, Some("key")); + + let result = rebuild_request_for_provider(&body, &provider, &headers); + assert!(matches!(result, Err(RebuildError::InvalidJson(_)))); + } + + #[test] + fn test_rebuild_model_without_provider_prefix() { + let body = Bytes::from(r#"{"model":"old"}"#); + let headers = HeaderMap::new(); + let mut provider = make_provider("gpt-4o", LlmProviderType::OpenAI, Some("key")); + provider.model = Some("gpt-4o".to_string()); + + let (new_body, _) = rebuild_request_for_provider(&body, &provider, &headers).unwrap(); + let json: serde_json::Value = serde_json::from_slice(&new_body).unwrap(); + // No prefix to strip, model name used as-is + assert_eq!(json["model"], "gpt-4o"); + } + + // --- Proptest strategies --- + + fn arb_provider_type() -> impl Strategy { + prop_oneof![ + Just(LlmProviderType::OpenAI), + Just(LlmProviderType::Anthropic), + Just(LlmProviderType::Gemini), + Just(LlmProviderType::Deepseek), + ] + } + + fn arb_model_name() -> impl Strategy { + prop_oneof![ + Just("openai/gpt-4o".to_string()), + Just("openai/gpt-4o-mini".to_string()), + Just("anthropic/claude-3-5-sonnet".to_string()), + Just("gemini/gemini-pro".to_string()), + Just("deepseek/deepseek-chat".to_string()), + ] + } + + fn arb_target_provider() -> impl Strategy { + (arb_model_name(), arb_provider_type()).prop_map(|(model, iface)| { + 
make_provider(&model, iface, Some("test-key-123")) + }) + } + + /// Strategy yielding 1 to 50 characters of ASCII letters, digits, and spaces. + fn arb_message_content() -> impl Strategy<Value = String> { + "[a-zA-Z0-9 ]{1,50}" + } + + /// Strategy yielding 1 to 4 chat messages, each a JSON object with a `role` + /// (user, assistant, or system) and a `content` string. + fn arb_messages() -> impl Strategy<Value = Vec<serde_json::Value>> { + prop::collection::vec( + ( + prop_oneof![Just("user"), Just("assistant"), Just("system")], + arb_message_content(), + ) + .prop_map(|(role, content)| { + serde_json::json!({"role": role, "content": content}) + }), + 1..5, + ) + } + + /// Strategy yielding a JSON request body with `model` (provider prefix removed) + /// and `messages`, plus optional `temperature`, `max_tokens`, and `stream: true`. + fn arb_json_body() -> impl Strategy<Value = serde_json::Value> { + ( + arb_model_name(), + arb_messages(), + prop::option::of(0.0f64..2.0), + prop::option::of(1u32..4096), + proptest::bool::ANY, + ) + .prop_map(|(model, messages, temperature, max_tokens, stream)| { + let model_only = model.split('/').nth(1).unwrap_or(&model); + let mut obj = serde_json::json!({ + "model": model_only, + "messages": messages, + }); + if let Some(t) = temperature { + obj["temperature"] = serde_json::json!(t); + } + if let Some(mt) = max_tokens { + obj["max_tokens"] = serde_json::json!(mt); + } + if stream { + obj["stream"] = serde_json::json!(true); + } + obj + }) + } + + /// Strategy yielding 0 to 3 `(name, value)` header pairs; names come from a + /// fixed list, so the same name may appear more than once. + fn arb_custom_headers() -> impl Strategy<Value = Vec<(String, String)>> { + prop::collection::vec( + ( + prop_oneof![ + Just("x-request-id".to_string()), + Just("x-custom-header".to_string()), + Just("x-trace-id".to_string()), + Just("content-type".to_string()), + ], + "[a-zA-Z0-9-]{1,30}", + ), + 0..4, + ) + } + + // Feature: retry-on-ratelimit, Property 14: Request Preservation Across Retries + // **Validates: Requirements 5.1, 5.2, 5.3, 5.4, 5.5, 3.15** + proptest! { + #![proptest_config(ProptestConfig::with_cases(100))] + + /// Property 14 – The original body bytes are unchanged after rebuild (body is passed by reference). + /// The rebuilt body has the model field updated to the target provider's model. + /// All other JSON fields are preserved. The RequestSignature hash matches the original body hash. + /// Custom headers are preserved while auth headers are sanitized. 
+ #[test] + fn prop_request_preservation_across_retries( + json_body in arb_json_body(), + custom_headers in arb_custom_headers(), + streaming in proptest::bool::ANY, + target_provider in arb_target_provider(), + ) { + let body_bytes = serde_json::to_vec(&json_body).unwrap(); + let body = Bytes::from(body_bytes.clone()); + + // Build original headers with custom + auth headers + let mut original_headers = HeaderMap::new(); + for (name, value) in &custom_headers { + if let (Ok(hn), Ok(hv)) = ( + hyper::header::HeaderName::from_bytes(name.as_bytes()), + HeaderValue::from_str(value), + ) { + original_headers.insert(hn, hv); + } + } + // Add auth headers that should be sanitized + original_headers.insert(AUTHORIZATION, HeaderValue::from_static("Bearer old-secret")); + original_headers.insert("x-api-key", HeaderValue::from_static("old-api-key")); + + let original_model = json_body["model"].as_str().unwrap_or("unknown").to_string(); + + // Create RequestSignature from original body + let sig = RequestSignature::new(&body, &original_headers, streaming, original_model.clone()); + + // Assert: body bytes are unchanged (passed by reference, not modified) + prop_assert_eq!(&body[..], &body_bytes[..], "Original body bytes must be unchanged"); + + // Assert: RequestSignature hash matches a fresh hash of the same body + let mut hasher = Sha256::new(); + hasher.update(&body); + let expected_hash: [u8; 32] = hasher.finalize().into(); + prop_assert_eq!(sig.body_hash, expected_hash, "RequestSignature hash must match original body hash"); + + // Assert: streaming flag preserved + prop_assert_eq!(sig.streaming, streaming, "Streaming flag must be preserved in signature"); + + // Rebuild for target provider + let result = rebuild_request_for_provider(&body, &target_provider, &original_headers); + prop_assert!(result.is_ok(), "rebuild_request_for_provider should succeed for valid JSON body"); + let (rebuilt_body, rebuilt_headers) = result.unwrap(); + + // Parse rebuilt body + let 
rebuilt_json: serde_json::Value = serde_json::from_slice(&rebuilt_body).unwrap(); + + // Assert: model field updated to target provider's model (without prefix) + let target_model = target_provider.model.as_deref().unwrap_or(&target_provider.name); + let expected_model = target_model.split_once('/').map(|(_, m)| m).unwrap_or(target_model); + prop_assert_eq!( + rebuilt_json["model"].as_str().unwrap(), + expected_model, + "Model field must be updated to target provider's model" + ); + + // Assert: messages array preserved + prop_assert_eq!( + &rebuilt_json["messages"], + &json_body["messages"], + "Messages array must be preserved across rebuild" + ); + + // Assert: other JSON fields preserved (temperature, max_tokens, stream) + // The rebuild function does a JSON round-trip (deserialize → modify model → serialize), + // so we compare against a round-tripped version of the original to account for + // any f64 precision changes inherent to JSON serialization. + let original_round_tripped: serde_json::Value = serde_json::from_slice( + &serde_json::to_vec(&json_body).unwrap() + ).unwrap(); + for key in ["temperature", "max_tokens", "stream"] { + if let Some(original_val) = original_round_tripped.get(key) { + prop_assert_eq!( + &rebuilt_json[key], + original_val, + "Field '{}' must be preserved across rebuild", + key + ); + } + } + + // Assert: custom headers preserved (non-auth headers) + // Note: HeaderMap::insert overwrites, so only the last value for each name survives + let mut last_custom: std::collections::HashMap<String, String> = std::collections::HashMap::new(); + for (name, value) in &custom_headers { + let lower = name.to_lowercase(); + if lower == "authorization" || lower == "x-api-key" || lower == "anthropic-version" { + continue; + } + last_custom.insert(lower, value.clone()); + } + for (name, value) in &last_custom { + if let Some(hv) = rebuilt_headers.get(name.as_str()) { + prop_assert_eq!( + hv.to_str().unwrap(), + value.as_str(), + "Custom header '{}' must be 
preserved", + name + ); + } + } + + // Assert: old auth headers are sanitized (not leaked to target provider) + // The old "Bearer old-secret" and "old-api-key" should NOT appear + if let Some(auth) = rebuilt_headers.get(AUTHORIZATION) { + prop_assert_ne!( + auth.to_str().unwrap(), + "Bearer old-secret", + "Old authorization header must be sanitized" + ); + } + if let Some(api_key) = rebuilt_headers.get("x-api-key") { + prop_assert_ne!( + api_key.to_str().unwrap(), + "old-api-key", + "Old x-api-key header must be sanitized" + ); + } + + // Assert: original body is still unchanged after rebuild + prop_assert_eq!(&body[..], &body_bytes[..], "Original body bytes must remain unchanged after rebuild"); + } + } +} diff --git a/crates/common/src/retry/orchestrator.rs b/crates/common/src/retry/orchestrator.rs index 1deee6fe..eddc1ac0 100644 --- a/crates/common/src/retry/orchestrator.rs +++ b/crates/common/src/retry/orchestrator.rs @@ -809,3 +809,1936 @@ fn log_retriable_error( } } +#[cfg(test)] +mod tests { + use super::*; + use crate::configuration::{ + LlmProviderType, RetryPolicy, RetryStrategy, StatusCodeConfig, StatusCodeEntry, + TimeoutRetryConfig, RetryAfterHandlingConfig, BlockScope, ApplyTo, HighLatencyConfig, LatencyMeasure, + }; + use bytes::Bytes; + use http_body_util::{BodyExt, Full}; + use hyper::Response; + use proptest::prelude::*; + use std::collections::{HashMap, HashSet}; + + use super::super::error_detector::HttpResponse; + + /// Helper to build an HttpResponse with a given status code. + fn make_response(status: u16) -> HttpResponse { + let body = Full::new(Bytes::from("test body")) + .map_err(|_| unreachable!()) + .boxed(); + Response::builder().status(status).body(body).unwrap() + } + + /// Helper to build an HttpResponse with a given status code and headers. 
+ fn make_response_with_headers(status: u16, headers: Vec<(&str, &str)>) -> HttpResponse { + let body = Full::new(Bytes::from("test body")) + .map_err(|_| unreachable!()) + .boxed(); + let mut builder = Response::builder().status(status); + for (name, value) in headers { + builder = builder.header(name, value); + } + builder.body(body).unwrap() + } + + /// Helper to create a test LlmProvider with a given model name. + fn make_provider(model: &str) -> LlmProvider { + LlmProvider { + name: model.to_string(), + provider_interface: LlmProviderType::OpenAI, + model: Some(model.to_string()), + access_key: Some("test-key".to_string()), + ..LlmProvider::default() + } + } + + // Feature: retry-on-ratelimit, Property 8: Bounded Retry (CP-2) + // **Validates: Requirements 1.2** + proptest! { + #![proptest_config(ProptestConfig::with_cases(100))] + + /// Property 8: For arbitrary max_attempts and max_retry_duration_ms, + /// when all providers return 429 (all-failing), the orchestrator: + /// - Returns Err(RetryExhaustedError) + /// - The number of attempts ≤ max_attempts + /// - If max_retry_duration_ms was set, retry_budget_exhausted is true when budget exceeded + #[test] + fn prop_bounded_retry( + max_attempts in 1u32..=5u32, + has_budget in proptest::bool::ANY, + budget_ms in 100u64..=5000u64, + ) { + let max_retry_duration_ms = if has_budget { Some(budget_ms) } else { None }; + + let rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap(); + + // Run the async orchestrator and collect results for assertion. + let (attempt_count, retry_budget_exhausted) = rt.block_on(async { + let orchestrator = RetryOrchestrator::new_default(); + + // Use same_model strategy with a single provider so max_attempts + // is the precise bound on retry count. 
+ let provider = make_provider("openai/gpt-4o"); + let all_providers = vec![provider.clone()]; + + let retry_policy = RetryPolicy { + fallback_models: vec![], + default_strategy: RetryStrategy::SameModel, + default_max_attempts: max_attempts, + on_status_codes: vec![ + StatusCodeConfig { + codes: vec![StatusCodeEntry::Single(429)], + strategy: RetryStrategy::SameModel, + max_attempts, + }, + ], + on_timeout: None, + on_high_latency: None, + backoff: None, + retry_after_handling: None, + max_retry_duration_ms, + }; + + let sig = RequestSignature::new( + b"test body", + &hyper::HeaderMap::new(), + false, + "openai/gpt-4o".to_string(), + ); + let mut ctx = RequestContext { + request_id: "test-req".to_string(), + attempted_providers: HashSet::new(), + retry_start_time: None, + attempt_number: 0, + request_retry_after_state: HashMap::new(), + request_latency_block_state: HashMap::new(), + request_signature: sig.clone(), + errors: vec![], + }; + + let body = Bytes::from("test body"); + + let result = orchestrator + .execute( + &body, + &sig, + &all_providers[0], + &retry_policy, + &all_providers, + &mut ctx, + |_body, _provider| async { Ok(make_response(429)) }, + ) + .await; + + // Must be an error (all providers fail) + let err = result.expect_err( + "Expected RetryExhaustedError when all providers return 429", + ); + + (err.attempts.len() as u32, err.retry_budget_exhausted) + }); + + // Attempt count must be bounded by max_attempts. + // The orchestrator makes 1 initial attempt, then the per-classification + // counter increments. When count >= max_attempts, it stops. So total + // attempts recorded in errors = max_attempts (initial + retries that + // hit the counter limit). We allow max_attempts + 1 as an upper bound + // to account for the initial attempt before the counter check. 
+ prop_assert!( + attempt_count <= max_attempts + 1, + "Attempt count {} exceeded max_attempts + 1 ({})", + attempt_count, + max_attempts + 1 + ); + + // If max_retry_duration_ms was set, either budget was exhausted + // (retry_budget_exhausted = true) or attempts were exhausted first + // (retry_budget_exhausted = false). Both are valid outcomes. + // With no backoff and instant responses, attempts exhaust before budget. + // When no budget is set, retry_budget_exhausted must be false. + if max_retry_duration_ms.is_none() { + prop_assert!( + !retry_budget_exhausted, + "retry_budget_exhausted should be false when no budget is set" + ); + } + } + } + + // ── P0 Edge Case Unit Tests ──────────────────────────────────────────── + + /// Helper to create a RequestContext for tests. + fn make_context(request_id: &str) -> RequestContext { + let sig = RequestSignature::new( + b"test body", + &hyper::HeaderMap::new(), + false, + "openai/gpt-4o".to_string(), + ); + RequestContext { + request_id: request_id.to_string(), + attempted_providers: HashSet::new(), + retry_start_time: None, + attempt_number: 0, + request_retry_after_state: HashMap::new(), + request_latency_block_state: HashMap::new(), + request_signature: sig, + errors: vec![], + } + } + + /// Helper to create a basic retry policy for tests. 
+ fn basic_retry_policy(max_attempts: u32) -> RetryPolicy { + RetryPolicy { + fallback_models: vec![], + default_strategy: RetryStrategy::SameModel, + default_max_attempts: max_attempts, + on_status_codes: vec![StatusCodeConfig { + codes: vec![StatusCodeEntry::Single(429)], + strategy: RetryStrategy::SameModel, + max_attempts, + }], + on_timeout: None, + on_high_latency: None, + backoff: None, + retry_after_handling: None, + max_retry_duration_ms: None, + } + } + + #[tokio::test] + async fn test_max_retry_duration_ms_exceeded_mid_retry_stops_with_most_recent_error() { + // Use different_provider strategy with multiple providers so the retry + // loop actually continues past the first attempt. The budget is small + // enough that it will be exceeded during the retry sequence. + let orchestrator = RetryOrchestrator::new_default(); + let all_providers = vec![ + make_provider("openai/gpt-4o"), + make_provider("anthropic/claude-3"), + make_provider("azure/gpt-4o"), + ]; + + let policy = RetryPolicy { + fallback_models: vec![], + default_strategy: RetryStrategy::DifferentProvider, + default_max_attempts: 10, + on_status_codes: vec![StatusCodeConfig { + codes: vec![StatusCodeEntry::Single(429)], + strategy: RetryStrategy::DifferentProvider, + max_attempts: 10, + }], + on_timeout: None, + on_high_latency: None, + backoff: None, + retry_after_handling: None, + max_retry_duration_ms: Some(1), // 1ms budget — will be exhausted quickly + }; + + let sig = RequestSignature::new( + b"test body", + &hyper::HeaderMap::new(), + false, + "openai/gpt-4o".to_string(), + ); + let mut ctx = make_context("test-budget-exceeded"); + let body = Bytes::from("test body"); + + let result = orchestrator + .execute( + &body, + &sig, + &all_providers[0], + &policy, + &all_providers, + &mut ctx, + |_body, _provider| async { + // Small sleep to ensure budget is exceeded + tokio::time::sleep(std::time::Duration::from_millis(10)).await; + Ok(make_response(429)) + }, + ) + .await; + + let err = 
result.expect_err("Should return RetryExhaustedError when budget exceeded"); + // Either budget was exhausted or providers were exhausted — both are valid + // since we have 3 providers and a tiny budget. The key assertion is that + // the error contains attempt details. + assert!(!err.attempts.is_empty(), "Should have at least one attempt recorded"); + // The most recent error should be a 429 + let last = err.attempts.last().unwrap(); + match &last.error_type { + AttemptErrorType::HttpError { status_code, .. } => { + assert_eq!(*status_code, 429); + } + _ => panic!("Expected HttpError for last attempt"), + } + } + + #[tokio::test] + async fn test_max_retry_duration_timer_starts_on_first_retry_not_original_request() { + // Req 3.16: Timer starts when the first retry attempt begins, not the original request. + // We verify this by checking that retry_start_time is None before the first failure + // and set after it. + let orchestrator = RetryOrchestrator::new_default(); + let provider = make_provider("openai/gpt-4o"); + let all_providers = vec![provider.clone()]; + + // Use a generous budget so we can observe the timer behavior + let mut policy = basic_retry_policy(2); + policy.max_retry_duration_ms = Some(60000); // 60s budget + + let sig = RequestSignature::new( + b"test body", + &hyper::HeaderMap::new(), + false, + "openai/gpt-4o".to_string(), + ); + let mut ctx = make_context("test-timer-start"); + + // Verify retry_start_time is None before execute + assert!(ctx.retry_start_time.is_none(), "retry_start_time should be None before execute"); + + let body = Bytes::from("test body"); + + let _result = orchestrator + .execute( + &body, + &sig, + &all_providers[0], + &policy, + &all_providers, + &mut ctx, + |_body, _provider| async { Ok(make_response(429)) }, + ) + .await; + + // After execution with retries, retry_start_time should have been set + assert!( + ctx.retry_start_time.is_some(), + "retry_start_time should be set after first retry attempt" + ); + } + + 
#[tokio::test] + async fn test_max_retry_duration_zero_effectively_disables_retries() { + // max_retry_duration_ms = 0 is rejected by validation (NonPositiveValue). + // With a very small budget (1ms) and multiple providers, the budget should + // be exhausted very quickly, effectively limiting retries. + let orchestrator = RetryOrchestrator::new_default(); + let all_providers = vec![ + make_provider("openai/gpt-4o"), + make_provider("anthropic/claude-3"), + make_provider("azure/gpt-4o"), + make_provider("google/gemini-pro"), + ]; + + let policy = RetryPolicy { + fallback_models: vec![], + default_strategy: RetryStrategy::DifferentProvider, + default_max_attempts: 10, + on_status_codes: vec![StatusCodeConfig { + codes: vec![StatusCodeEntry::Single(429)], + strategy: RetryStrategy::DifferentProvider, + max_attempts: 10, + }], + on_timeout: None, + on_high_latency: None, + backoff: None, + retry_after_handling: None, + max_retry_duration_ms: Some(1), // Near-zero budget + }; + + let sig = RequestSignature::new( + b"test body", + &hyper::HeaderMap::new(), + false, + "openai/gpt-4o".to_string(), + ); + let mut ctx = make_context("test-zero-budget"); + let body = Bytes::from("test body"); + + let result = orchestrator + .execute( + &body, + &sig, + &all_providers[0], + &policy, + &all_providers, + &mut ctx, + |_body, _provider| async { + tokio::time::sleep(std::time::Duration::from_millis(5)).await; + Ok(make_response(429)) + }, + ) + .await; + + let err = result.expect_err("Should exhaust budget quickly"); + // With a 1ms budget and 5ms per attempt, we should get very few attempts + // before either budget or providers are exhausted. + assert!( + err.attempts.len() <= 4, + "With near-zero budget, should have few attempts, got {}", + err.attempts.len() + ); + } + + #[tokio::test] + async fn test_no_retry_policy_returns_error_directly() { + // When no retry_policy is configured, the orchestrator should still work + // but with default behavior. 
The key test is that without on_status_codes + // matching, a 429 is still treated as retriable (default strategy applies). + // However, when retry_policy has no on_status_codes and default_max_attempts = 0, + // no retries should occur. + let orchestrator = RetryOrchestrator::new_default(); + let provider = make_provider("openai/gpt-4o"); + let all_providers = vec![provider.clone()]; + + // Simulate "no retry" by setting max_attempts to 1 (only initial attempt) + let policy = RetryPolicy { + fallback_models: vec![], + default_strategy: RetryStrategy::DifferentProvider, + default_max_attempts: 1, + on_status_codes: vec![], + on_timeout: None, + on_high_latency: None, + backoff: None, + retry_after_handling: None, + max_retry_duration_ms: None, + }; + + let sig = RequestSignature::new( + b"test body", + &hyper::HeaderMap::new(), + false, + "openai/gpt-4o".to_string(), + ); + let mut ctx = make_context("test-no-retry"); + let body = Bytes::from("test body"); + + let result = orchestrator + .execute( + &body, + &sig, + &all_providers[0], + &policy, + &all_providers, + &mut ctx, + |_body, _provider| async { Ok(make_response(429)) }, + ) + .await; + + let err = result.expect_err("Should return error when max_attempts exhausted"); + // With default_max_attempts = 1, should have at most 2 attempts + // (initial + 1 retry that hits the limit) + assert!( + err.attempts.len() <= 2, + "With max_attempts=1, should have at most 2 attempts, got {}", + err.attempts.len() + ); + } + + #[tokio::test] + async fn test_empty_fallback_models_different_provider_uses_provider_list() { + // When fallback_models is empty and strategy is different_provider, + // the orchestrator should select from the Provider_List. 
+ let orchestrator = RetryOrchestrator::new_default(); + let primary = make_provider("openai/gpt-4o"); + let fallback = make_provider("anthropic/claude-3-5-sonnet"); + let all_providers = vec![primary.clone(), fallback.clone()]; + + let policy = RetryPolicy { + fallback_models: vec![], // empty — should fall back to Provider_List + default_strategy: RetryStrategy::DifferentProvider, + default_max_attempts: 2, + on_status_codes: vec![StatusCodeConfig { + codes: vec![StatusCodeEntry::Single(429)], + strategy: RetryStrategy::DifferentProvider, + max_attempts: 2, + }], + on_timeout: None, + on_high_latency: None, + backoff: None, + retry_after_handling: None, + max_retry_duration_ms: None, + }; + + let sig = RequestSignature::new( + b"test body", + &hyper::HeaderMap::new(), + false, + "openai/gpt-4o".to_string(), + ); + let mut ctx = make_context("test-empty-fallback"); + let body = Bytes::from("test body"); + + // Track which providers were called (model names, in call order) + let call_log = std::sync::Arc::new(std::sync::Mutex::new(Vec::<String>::new())); + let call_log_clone = call_log.clone(); + + let result = orchestrator + .execute( + &body, + &sig, + &all_providers[0], + &policy, + &all_providers, + &mut ctx, + move |_body, provider| { + let log = call_log_clone.clone(); + let model = provider.model.clone().unwrap_or_default(); + async move { + log.lock().unwrap().push(model.clone()); + if model == "anthropic/claude-3-5-sonnet" { + Ok(make_response(200)) + } else { + Ok(make_response(429)) + } + } + }, + ) + .await; + + assert!(result.is_ok(), "Should succeed after falling back to Provider_List"); + let calls = call_log.lock().unwrap(); + assert!(calls.len() >= 2, "Should have at least 2 calls"); + assert_eq!(calls[0], "openai/gpt-4o", "First call should be primary"); + assert_eq!( + calls[1], "anthropic/claude-3-5-sonnet", + "Second call should be from Provider_List (different provider)" + ); + } + + // ── P1 Timeout Classification Tests ──────────────────────────────────── + + #[tokio::test] + 
async fn test_timeout_triggers_retry_to_different_provider() { + // When the primary provider times out and on_timeout is configured with + // different_provider strategy, the orchestrator should retry on a different provider. + let orchestrator = RetryOrchestrator::new_default(); + let primary = make_provider("openai/gpt-4o"); + let fallback = make_provider("anthropic/claude-3-5-sonnet"); + let all_providers = vec![primary.clone(), fallback.clone()]; + + let policy = RetryPolicy { + fallback_models: vec![], + default_strategy: RetryStrategy::DifferentProvider, + default_max_attempts: 2, + on_status_codes: vec![], + on_timeout: Some(TimeoutRetryConfig { + strategy: RetryStrategy::DifferentProvider, + max_attempts: 2, + }), + on_high_latency: None, + backoff: None, + retry_after_handling: None, + max_retry_duration_ms: None, + }; + + let sig = RequestSignature::new( + b"test body", + &hyper::HeaderMap::new(), + false, + "openai/gpt-4o".to_string(), + ); + let mut ctx = make_context("test-timeout-retry"); + let body = Bytes::from("test body"); + + // Model names of the providers invoked, in call order + let call_log = std::sync::Arc::new(std::sync::Mutex::new(Vec::<String>::new())); + let call_log_clone = call_log.clone(); + + let result = orchestrator + .execute( + &body, + &sig, + &all_providers[0], + &policy, + &all_providers, + &mut ctx, + move |_body, provider| { + let log = call_log_clone.clone(); + let model = provider.model.clone().unwrap_or_default(); + async move { + log.lock().unwrap().push(model.clone()); + if model == "openai/gpt-4o" { + Err(TimeoutError { duration_ms: 5000 }) + } else { + Ok(make_response(200)) + } + } + }, + ) + .await; + + assert!(result.is_ok(), "Should succeed after timeout retry to different provider"); + let calls = call_log.lock().unwrap(); + assert_eq!(calls.len(), 2, "Should have 2 calls (primary + fallback)"); + assert_eq!(calls[0], "openai/gpt-4o"); + assert_eq!(calls[1], "anthropic/claude-3-5-sonnet"); + } + + #[tokio::test] + async fn test_timeout_uses_on_timeout_strategy_not_default() { 
+ // Verify that on_timeout config overrides default_strategy for timeout errors. + let orchestrator = RetryOrchestrator::new_default(); + let primary = make_provider("openai/gpt-4o"); + let all_providers = vec![primary.clone()]; + + let policy = RetryPolicy { + fallback_models: vec![], + default_strategy: RetryStrategy::DifferentProvider, + default_max_attempts: 5, + on_status_codes: vec![], + on_timeout: Some(TimeoutRetryConfig { + strategy: RetryStrategy::SameModel, + max_attempts: 2, + }), + on_high_latency: None, + backoff: None, + retry_after_handling: None, + max_retry_duration_ms: None, + }; + + let sig = RequestSignature::new( + b"test body", + &hyper::HeaderMap::new(), + false, + "openai/gpt-4o".to_string(), + ); + let mut ctx = make_context("test-timeout-strategy"); + let body = Bytes::from("test body"); + + let call_count = std::sync::Arc::new(std::sync::atomic::AtomicU32::new(0)); + let call_count_clone = call_count.clone(); + + let result = orchestrator + .execute( + &body, + &sig, + &all_providers[0], + &policy, + &all_providers, + &mut ctx, + move |_body, _provider| { + let count = call_count_clone.clone(); + async move { + count.fetch_add(1, std::sync::atomic::Ordering::SeqCst); + Err(TimeoutError { duration_ms: 3000 }) + } + }, + ) + .await; + + let err = result.expect_err("Should exhaust timeout retries"); + // on_timeout max_attempts = 2, so we should see at most 3 total attempts + // (1 initial + 2 retries) + assert!( + err.attempts.len() <= 3, + "With on_timeout max_attempts=2, should have at most 3 attempts, got {}", + err.attempts.len() + ); + // All attempts should be timeout errors + for attempt in &err.attempts { + assert!( + matches!(attempt.error_type, AttemptErrorType::Timeout { .. }), + "All attempts should be timeout errors" + ); + } + } + + #[tokio::test] + async fn test_timeout_without_on_timeout_uses_defaults() { + // When on_timeout is None, timeout errors should use default_strategy and + // default_max_attempts. 
+ let orchestrator = RetryOrchestrator::new_default(); + let primary = make_provider("openai/gpt-4o"); + let fallback = make_provider("anthropic/claude-3-5-sonnet"); + let all_providers = vec![primary.clone(), fallback.clone()]; + + let policy = RetryPolicy { + fallback_models: vec![], + default_strategy: RetryStrategy::DifferentProvider, + default_max_attempts: 2, + on_status_codes: vec![], + on_timeout: None, // No timeout-specific config + on_high_latency: None, + backoff: None, + retry_after_handling: None, + max_retry_duration_ms: None, + }; + + let sig = RequestSignature::new( + b"test body", + &hyper::HeaderMap::new(), + false, + "openai/gpt-4o".to_string(), + ); + let mut ctx = make_context("test-timeout-defaults"); + let body = Bytes::from("test body"); + + // Model names of the providers invoked, in call order + let call_log = std::sync::Arc::new(std::sync::Mutex::new(Vec::<String>::new())); + let call_log_clone = call_log.clone(); + + let result = orchestrator + .execute( + &body, + &sig, + &all_providers[0], + &policy, + &all_providers, + &mut ctx, + move |_body, provider| { + let log = call_log_clone.clone(); + let model = provider.model.clone().unwrap_or_default(); + async move { + log.lock().unwrap().push(model.clone()); + if model == "openai/gpt-4o" { + Err(TimeoutError { duration_ms: 5000 }) + } else { + Ok(make_response(200)) + } + } + }, + ) + .await; + + // With default_strategy=DifferentProvider and default_max_attempts=2, + // should retry to the different provider and succeed. + assert!(result.is_ok(), "Should succeed after timeout retry using defaults"); + let calls = call_log.lock().unwrap(); + assert_eq!(calls[0], "openai/gpt-4o"); + assert_eq!(calls[1], "anthropic/claude-3-5-sonnet"); + } + + #[tokio::test] + async fn test_timeout_max_attempts_exhausted_returns_error() { + // When all timeout retries are exhausted, should return RetryExhaustedError + // with timeout attempt details. 
+ let orchestrator = RetryOrchestrator::new_default(); + let primary = make_provider("openai/gpt-4o"); + let fallback = make_provider("anthropic/claude-3-5-sonnet"); + let all_providers = vec![primary.clone(), fallback.clone()]; + + let policy = RetryPolicy { + fallback_models: vec![], + default_strategy: RetryStrategy::DifferentProvider, + default_max_attempts: 2, + on_status_codes: vec![], + on_timeout: Some(TimeoutRetryConfig { + strategy: RetryStrategy::DifferentProvider, + max_attempts: 1, + }), + on_high_latency: None, + backoff: None, + retry_after_handling: None, + max_retry_duration_ms: None, + }; + + let sig = RequestSignature::new( + b"test body", + &hyper::HeaderMap::new(), + false, + "openai/gpt-4o".to_string(), + ); + let mut ctx = make_context("test-timeout-exhausted"); + let body = Bytes::from("test body"); + + let result = orchestrator + .execute( + &body, + &sig, + &all_providers[0], + &policy, + &all_providers, + &mut ctx, + |_body, _provider| async { + Err(TimeoutError { duration_ms: 5000 }) + }, + ) + .await; + + let err = result.expect_err("Should exhaust timeout retries"); + assert!(!err.attempts.is_empty(), "Should have recorded attempts"); + // Verify all attempts are timeout errors with correct duration + for attempt in &err.attempts { + match &attempt.error_type { + AttemptErrorType::Timeout { duration_ms } => { + assert_eq!(*duration_ms, 5000); + } + other => panic!("Expected Timeout error type, got {:?}", other), + } + } + } + + #[tokio::test] + async fn test_timeout_error_records_duration_in_attempt() { + // Verify that the timeout duration is correctly recorded in the attempt error. 
+ let orchestrator = RetryOrchestrator::new_default(); + let primary = make_provider("openai/gpt-4o"); + let all_providers = vec![primary.clone()]; + + let policy = RetryPolicy { + fallback_models: vec![], + default_strategy: RetryStrategy::SameModel, + default_max_attempts: 1, + on_status_codes: vec![], + on_timeout: Some(TimeoutRetryConfig { + strategy: RetryStrategy::SameModel, + max_attempts: 1, + }), + on_high_latency: None, + backoff: None, + retry_after_handling: None, + max_retry_duration_ms: None, + }; + + let sig = RequestSignature::new( + b"test body", + &hyper::HeaderMap::new(), + false, + "openai/gpt-4o".to_string(), + ); + let mut ctx = make_context("test-timeout-duration"); + let body = Bytes::from("test body"); + + let result = orchestrator + .execute( + &body, + &sig, + &all_providers[0], + &policy, + &all_providers, + &mut ctx, + |_body, _provider| async { + Err(TimeoutError { duration_ms: 12345 }) + }, + ) + .await; + + let err = result.expect_err("Should exhaust retries"); + let first_attempt = &err.attempts[0]; + assert_eq!(first_attempt.model_id, "openai/gpt-4o"); + match &first_attempt.error_type { + AttemptErrorType::Timeout { duration_ms } => { + assert_eq!(*duration_ms, 12345, "Duration should be preserved"); + } + other => panic!("Expected Timeout, got {:?}", other), + } + } + + #[tokio::test] + async fn test_timeout_then_success_on_retry() { + // Primary times out, retry to same model succeeds. 
+ let orchestrator = RetryOrchestrator::new_default(); + let primary = make_provider("openai/gpt-4o"); + let all_providers = vec![primary.clone()]; + + let policy = RetryPolicy { + fallback_models: vec![], + default_strategy: RetryStrategy::SameModel, + default_max_attempts: 2, + on_status_codes: vec![], + on_timeout: Some(TimeoutRetryConfig { + strategy: RetryStrategy::SameModel, + max_attempts: 2, + }), + on_high_latency: None, + backoff: None, + retry_after_handling: None, + max_retry_duration_ms: None, + }; + + let sig = RequestSignature::new( + b"test body", + &hyper::HeaderMap::new(), + false, + "openai/gpt-4o".to_string(), + ); + let mut ctx = make_context("test-timeout-then-success"); + let body = Bytes::from("test body"); + + let call_count = std::sync::Arc::new(std::sync::atomic::AtomicU32::new(0)); + let call_count_clone = call_count.clone(); + + let result = orchestrator + .execute( + &body, + &sig, + &all_providers[0], + &policy, + &all_providers, + &mut ctx, + move |_body, _provider| { + let count = call_count_clone.clone(); + async move { + let n = count.fetch_add(1, std::sync::atomic::Ordering::SeqCst); + if n == 0 { + Err(TimeoutError { duration_ms: 5000 }) + } else { + Ok(make_response(200)) + } + } + }, + ) + .await; + + assert!(result.is_ok(), "Should succeed on retry after timeout"); + assert_eq!( + call_count.load(std::sync::atomic::Ordering::SeqCst), + 2, + "Should have made 2 calls (initial timeout + successful retry)" + ); + } + + // ── Retry-After State Recording Tests (Task 16.1) ────────────────── + + #[tokio::test] + async fn test_retry_after_global_records_state_in_manager() { + // When a 429 response includes Retry-After header and apply_to is Global, + // the orchestrator should record the entry in the global RetryAfterStateManager. 
+ let orchestrator = RetryOrchestrator::new_default(); + let primary = make_provider("openai/gpt-4o"); + let secondary = make_provider("anthropic/claude-3"); + let all_providers = vec![primary.clone(), secondary.clone()]; + + let policy = RetryPolicy { + fallback_models: vec!["anthropic/claude-3".to_string()], + default_strategy: RetryStrategy::DifferentProvider, + default_max_attempts: 2, + on_status_codes: vec![StatusCodeConfig { + codes: vec![StatusCodeEntry::Single(429)], + strategy: RetryStrategy::DifferentProvider, + max_attempts: 2, + }], + on_timeout: None, + on_high_latency: None, + backoff: None, + retry_after_handling: Some(RetryAfterHandlingConfig { + scope: BlockScope::Model, + apply_to: ApplyTo::Global, + max_retry_after_seconds: 300, + }), + // Use a tight budget so the orchestrator records state but bails + // before sleeping the full Retry-After delay. + max_retry_duration_ms: Some(1), + }; + + let sig = RequestSignature::new( + b"test body", + &hyper::HeaderMap::new(), + false, + "openai/gpt-4o".to_string(), + ); + let mut ctx = make_context("test-ra-global"); + let body = Bytes::from("test body"); + + let _result = orchestrator + .execute( + &body, + &sig, + &all_providers[0], + &policy, + &all_providers, + &mut ctx, + |_body, _provider| async { + Ok(make_response_with_headers(429, vec![("retry-after", "10")])) + }, + ) + .await; + + // The global RetryAfterStateManager should have recorded the entry + assert!( + orchestrator.retry_after_state.is_blocked("openai/gpt-4o"), + "Model should be blocked in global RetryAfterStateManager after 429 with Retry-After" + ); + } + + #[tokio::test] + async fn test_retry_after_global_provider_scope_blocks_provider() { + // When scope is Provider, the entry should be recorded with the provider prefix. 
+ let orchestrator = RetryOrchestrator::new_default(); + let primary = make_provider("openai/gpt-4o"); + let secondary = make_provider("anthropic/claude-3"); + let all_providers = vec![primary.clone(), secondary.clone()]; + + let policy = RetryPolicy { + fallback_models: vec!["anthropic/claude-3".to_string()], + default_strategy: RetryStrategy::DifferentProvider, + default_max_attempts: 2, + on_status_codes: vec![StatusCodeConfig { + codes: vec![StatusCodeEntry::Single(429)], + strategy: RetryStrategy::DifferentProvider, + max_attempts: 2, + }], + on_timeout: None, + on_high_latency: None, + backoff: None, + retry_after_handling: Some(RetryAfterHandlingConfig { + scope: BlockScope::Provider, + apply_to: ApplyTo::Global, + max_retry_after_seconds: 300, + }), + max_retry_duration_ms: Some(1), + }; + + let sig = RequestSignature::new( + b"test body", + &hyper::HeaderMap::new(), + false, + "openai/gpt-4o".to_string(), + ); + let mut ctx = make_context("test-ra-provider-scope"); + let body = Bytes::from("test body"); + + let _result = orchestrator + .execute( + &body, + &sig, + &all_providers[0], + &policy, + &all_providers, + &mut ctx, + |_body, _provider| async { + Ok(make_response_with_headers(429, vec![("retry-after", "10")])) + }, + ) + .await; + + // Provider prefix "openai" should be blocked + assert!( + orchestrator.retry_after_state.is_blocked("openai"), + "Provider prefix should be blocked in global RetryAfterStateManager" + ); + // The full model ID should NOT be directly blocked (it's provider-scoped) + assert!( + !orchestrator.retry_after_state.is_blocked("openai/gpt-4o"), + "Full model ID should not be directly blocked when scope is Provider" + ); + } + + #[tokio::test] + async fn test_retry_after_request_scope_records_in_request_context() { + // When apply_to is Request, the entry should be recorded in request_context. 
+ let orchestrator = RetryOrchestrator::new_default(); + let primary = make_provider("openai/gpt-4o"); + let secondary = make_provider("anthropic/claude-3"); + let all_providers = vec![primary.clone(), secondary.clone()]; + + let policy = RetryPolicy { + fallback_models: vec!["anthropic/claude-3".to_string()], + default_strategy: RetryStrategy::DifferentProvider, + default_max_attempts: 2, + on_status_codes: vec![StatusCodeConfig { + codes: vec![StatusCodeEntry::Single(429)], + strategy: RetryStrategy::DifferentProvider, + max_attempts: 2, + }], + on_timeout: None, + on_high_latency: None, + backoff: None, + retry_after_handling: Some(RetryAfterHandlingConfig { + scope: BlockScope::Model, + apply_to: ApplyTo::Request, + max_retry_after_seconds: 300, + }), + max_retry_duration_ms: Some(1), + }; + + let sig = RequestSignature::new( + b"test body", + &hyper::HeaderMap::new(), + false, + "openai/gpt-4o".to_string(), + ); + let mut ctx = make_context("test-ra-request-scope"); + let body = Bytes::from("test body"); + + let _result = orchestrator + .execute( + &body, + &sig, + &all_providers[0], + &policy, + &all_providers, + &mut ctx, + |_body, _provider| async { + Ok(make_response_with_headers(429, vec![("retry-after", "10")])) + }, + ) + .await; + + // Request-scoped state should have the entry + assert!( + ctx.request_retry_after_state.contains_key("openai/gpt-4o"), + "Model should be recorded in request-scoped retry_after_state" + ); + // Global state should NOT have the entry + assert!( + !orchestrator.retry_after_state.is_blocked("openai/gpt-4o"), + "Global RetryAfterStateManager should not have entry when apply_to is Request" + ); + } + + #[tokio::test] + async fn test_retry_after_no_header_does_not_record_state() { + // When a 429 response does NOT include Retry-After header, + // no state entry should be created. 
+ let orchestrator = RetryOrchestrator::new_default(); + let primary = make_provider("openai/gpt-4o"); + let secondary = make_provider("anthropic/claude-3"); + let all_providers = vec![primary.clone(), secondary.clone()]; + + let policy = RetryPolicy { + fallback_models: vec!["anthropic/claude-3".to_string()], + default_strategy: RetryStrategy::DifferentProvider, + default_max_attempts: 2, + on_status_codes: vec![StatusCodeConfig { + codes: vec![StatusCodeEntry::Single(429)], + strategy: RetryStrategy::DifferentProvider, + max_attempts: 2, + }], + on_timeout: None, + on_high_latency: None, + backoff: None, + retry_after_handling: Some(RetryAfterHandlingConfig { + scope: BlockScope::Model, + apply_to: ApplyTo::Global, + max_retry_after_seconds: 300, + }), + max_retry_duration_ms: Some(1), + }; + + let sig = RequestSignature::new( + b"test body", + &hyper::HeaderMap::new(), + false, + "openai/gpt-4o".to_string(), + ); + let mut ctx = make_context("test-ra-no-header"); + let body = Bytes::from("test body"); + + let _result = orchestrator + .execute( + &body, + &sig, + &all_providers[0], + &policy, + &all_providers, + &mut ctx, + |_body, _provider| async { + // 429 without Retry-After header + Ok(make_response(429)) + }, + ) + .await; + + // No state should be recorded + assert!( + !orchestrator.retry_after_state.is_blocked("openai/gpt-4o"), + "No global state should be recorded when Retry-After header is absent" + ); + assert!( + ctx.request_retry_after_state.is_empty(), + "No request-scoped state should be recorded when Retry-After header is absent" + ); + } + + #[tokio::test] + async fn test_retry_after_malformed_header_does_not_record_state() { + // When Retry-After header has a malformed value, it should be ignored. 
+ let orchestrator = RetryOrchestrator::new_default(); + let primary = make_provider("openai/gpt-4o"); + let secondary = make_provider("anthropic/claude-3"); + let all_providers = vec![primary.clone(), secondary.clone()]; + + let policy = RetryPolicy { + fallback_models: vec!["anthropic/claude-3".to_string()], + default_strategy: RetryStrategy::DifferentProvider, + default_max_attempts: 2, + on_status_codes: vec![StatusCodeConfig { + codes: vec![StatusCodeEntry::Single(429)], + strategy: RetryStrategy::DifferentProvider, + max_attempts: 2, + }], + on_timeout: None, + on_high_latency: None, + backoff: None, + retry_after_handling: Some(RetryAfterHandlingConfig { + scope: BlockScope::Model, + apply_to: ApplyTo::Global, + max_retry_after_seconds: 300, + }), + max_retry_duration_ms: Some(1), + }; + + let sig = RequestSignature::new( + b"test body", + &hyper::HeaderMap::new(), + false, + "openai/gpt-4o".to_string(), + ); + let mut ctx = make_context("test-ra-malformed"); + let body = Bytes::from("test body"); + + let _result = orchestrator + .execute( + &body, + &sig, + &all_providers[0], + &policy, + &all_providers, + &mut ctx, + |_body, _provider| async { + // 429 with malformed Retry-After + Ok(make_response_with_headers(429, vec![("retry-after", "not-a-number")])) + }, + ) + .await; + + // No state should be recorded for malformed values + assert!( + !orchestrator.retry_after_state.is_blocked("openai/gpt-4o"), + "No state should be recorded when Retry-After header is malformed" + ); + } + + #[tokio::test] + async fn test_retry_after_default_config_when_retry_after_handling_omitted() { + // When retry_after_handling is None, effective_retry_after_config() returns + // defaults (scope: Model, apply_to: Global, max: 300). The orchestrator + // should still record state using these defaults. 
+ let orchestrator = RetryOrchestrator::new_default(); + let primary = make_provider("openai/gpt-4o"); + let secondary = make_provider("anthropic/claude-3"); + let all_providers = vec![primary.clone(), secondary.clone()]; + + let policy = RetryPolicy { + fallback_models: vec!["anthropic/claude-3".to_string()], + default_strategy: RetryStrategy::DifferentProvider, + default_max_attempts: 2, + on_status_codes: vec![StatusCodeConfig { + codes: vec![StatusCodeEntry::Single(429)], + strategy: RetryStrategy::DifferentProvider, + max_attempts: 2, + }], + on_timeout: None, + on_high_latency: None, + backoff: None, + retry_after_handling: None, // Omitted — defaults apply + max_retry_duration_ms: Some(1), + }; + + let sig = RequestSignature::new( + b"test body", + &hyper::HeaderMap::new(), + false, + "openai/gpt-4o".to_string(), + ); + let mut ctx = make_context("test-ra-defaults"); + let body = Bytes::from("test body"); + + let _result = orchestrator + .execute( + &body, + &sig, + &all_providers[0], + &policy, + &all_providers, + &mut ctx, + |_body, _provider| async { + Ok(make_response_with_headers(429, vec![("retry-after", "10")])) + }, + ) + .await; + + // Default config: scope=Model, apply_to=Global + // So the model ID should be blocked globally + assert!( + orchestrator.retry_after_state.is_blocked("openai/gpt-4o"), + "Model should be blocked with default retry_after config (scope: Model, apply_to: Global)" + ); + } + + // ── Task 23.2: High latency handling tests ───────────────────────────── + + fn high_latency_retry_policy(threshold_ms: u64) -> RetryPolicy { + RetryPolicy { + fallback_models: vec!["anthropic/claude-3".to_string()], + default_strategy: RetryStrategy::DifferentProvider, + default_max_attempts: 2, + on_status_codes: vec![], + on_timeout: Some(TimeoutRetryConfig { + strategy: RetryStrategy::DifferentProvider, + max_attempts: 2, + }), + on_high_latency: Some(HighLatencyConfig { + threshold_ms, + measure: LatencyMeasure::Ttfb, + min_triggers: 1, + 
trigger_window_seconds: Some(60), + strategy: RetryStrategy::DifferentProvider, + max_attempts: 2, + block_duration_seconds: 300, + scope: BlockScope::Model, + apply_to: ApplyTo::Global, + }), + backoff: None, + retry_after_handling: None, + max_retry_duration_ms: None, + } + } + + #[tokio::test] + async fn test_high_latency_completed_response_delivered_and_block_state_created() { + // When a response completes but exceeds the latency threshold, + // the response should be delivered to the client AND a block state + // should be created for future requests. + let orchestrator = RetryOrchestrator::new_default(); + let primary = make_provider("openai/gpt-4o"); + let secondary = make_provider("anthropic/claude-3"); + let all_providers = vec![primary.clone(), secondary.clone()]; + + // threshold_ms=100, so any response taking >100ms is "slow" + // But since our mock returns instantly, we need the ErrorDetector to + // classify based on elapsed time. The mock returns 200 OK, and the + // ErrorDetector will see elapsed_ttfb_ms > threshold_ms. + // However, in the test the elapsed time is near-zero. + // We need to use a threshold of 0 so that any response triggers it. + let policy = high_latency_retry_policy(0); + + let sig = RequestSignature::new( + b"test body", + &hyper::HeaderMap::new(), + false, + "openai/gpt-4o".to_string(), + ); + let mut ctx = make_context("test-hl-completed"); + let body = Bytes::from("test body"); + + // The mock returns 200 OK. With threshold_ms=0, any elapsed time > 0 + // will trigger HighLatencyEvent with response: Some(resp). + // But elapsed_ttfb_ms is measured as 0 in fast tests, so we need + // threshold_ms=0 and the classify logic checks measured_ms > threshold_ms. + // 0 > 0 is false, so we need to add a small delay. 
+ let result = orchestrator + .execute( + &body, + &sig, + &all_providers[0], + &policy, + &all_providers, + &mut ctx, + |_body, _provider| async { + // Small delay to ensure elapsed > 0 + tokio::time::sleep(std::time::Duration::from_millis(2)).await; + Ok(make_response(200)) + }, + ) + .await; + + // Response should be delivered successfully + assert!(result.is_ok(), "Completed-but-slow response should be delivered to client"); + let resp = result.unwrap(); + assert_eq!(resp.status().as_u16(), 200); + + // Block state should be created (min_triggers=1, so first event triggers block) + assert!( + orchestrator.latency_block_state.is_blocked("openai/gpt-4o"), + "Latency block state should be created for the slow model" + ); + } + + #[tokio::test] + async fn test_high_latency_completed_response_block_state_provider_scope() { + // When scope is "provider", the block should use the provider prefix. + let orchestrator = RetryOrchestrator::new_default(); + let primary = make_provider("openai/gpt-4o"); + let secondary = make_provider("anthropic/claude-3"); + let all_providers = vec![primary.clone(), secondary.clone()]; + + let mut policy = high_latency_retry_policy(0); + policy.on_high_latency.as_mut().unwrap().scope = BlockScope::Provider; + + let sig = RequestSignature::new( + b"test body", + &hyper::HeaderMap::new(), + false, + "openai/gpt-4o".to_string(), + ); + let mut ctx = make_context("test-hl-provider-scope"); + let body = Bytes::from("test body"); + + let result = orchestrator + .execute( + &body, + &sig, + &all_providers[0], + &policy, + &all_providers, + &mut ctx, + |_body, _provider| async { + tokio::time::sleep(std::time::Duration::from_millis(2)).await; + Ok(make_response(200)) + }, + ) + .await; + + assert!(result.is_ok()); + + // Provider prefix "openai" should be blocked, not the full model ID + assert!( + orchestrator.latency_block_state.is_blocked("openai"), + "Provider prefix should be blocked when scope is Provider" + ); + assert!( + 
!orchestrator.latency_block_state.is_blocked("openai/gpt-4o"), + "Full model ID should not be directly blocked when scope is Provider" + ); + } + + #[tokio::test] + async fn test_high_latency_completed_response_request_scoped_block() { + // When apply_to is "request", block state should be in RequestContext. + let orchestrator = RetryOrchestrator::new_default(); + let primary = make_provider("openai/gpt-4o"); + let secondary = make_provider("anthropic/claude-3"); + let all_providers = vec![primary.clone(), secondary.clone()]; + + let mut policy = high_latency_retry_policy(0); + policy.on_high_latency.as_mut().unwrap().apply_to = ApplyTo::Request; + + let sig = RequestSignature::new( + b"test body", + &hyper::HeaderMap::new(), + false, + "openai/gpt-4o".to_string(), + ); + let mut ctx = make_context("test-hl-request-scope"); + let body = Bytes::from("test body"); + + let result = orchestrator + .execute( + &body, + &sig, + &all_providers[0], + &policy, + &all_providers, + &mut ctx, + |_body, _provider| async { + tokio::time::sleep(std::time::Duration::from_millis(2)).await; + Ok(make_response(200)) + }, + ) + .await; + + assert!(result.is_ok()); + + // Block should be in request context, not global + assert!( + !orchestrator.latency_block_state.is_blocked("openai/gpt-4o"), + "Global state should NOT be blocked when apply_to is Request" + ); + assert!( + ctx.request_latency_block_state.contains_key("openai/gpt-4o"), + "Request-scoped latency block state should be recorded" + ); + } + + #[tokio::test] + async fn test_high_latency_without_response_triggers_retry() { + // When HighLatencyEvent has no completed response (response: None), + // the orchestrator should trigger retry and record the latency event. + // This scenario happens when TTFB exceeds threshold but response hasn't completed. + // In practice, this is simulated by the ErrorDetector returning HighLatencyEvent + // with response: None. 
Since our ErrorDetector always returns response: Some for + // 2xx, we test this indirectly through the retry loop behavior. + // + // For a direct test, we'd need a custom ErrorDetector. Instead, we verify + // that the retry loop handles HighLatencyEvent without response by checking + // that it falls through to retry logic (the attempt is recorded as an error). + let orchestrator = RetryOrchestrator::new_default(); + let primary = make_provider("openai/gpt-4o"); + let secondary = make_provider("anthropic/claude-3"); + let all_providers = vec![primary.clone(), secondary.clone()]; + + let policy = high_latency_retry_policy(0); + + let sig = RequestSignature::new( + b"test body", + &hyper::HeaderMap::new(), + false, + "openai/gpt-4o".to_string(), + ); + let mut ctx = make_context("test-hl-no-response"); + let body = Bytes::from("test body"); + + let call_count = std::sync::Arc::new(std::sync::atomic::AtomicU32::new(0)); + let call_count_clone = call_count.clone(); + + let result = orchestrator + .execute( + &body, + &sig, + &all_providers[0], + &policy, + &all_providers, + &mut ctx, + move |_body, provider| { + let count = call_count_clone.clone(); + let _model = provider.model.clone().unwrap_or_default(); + async move { + let n = count.fetch_add(1, std::sync::atomic::Ordering::SeqCst); + if n == 0 { + // First call: slow response (200 OK but exceeds threshold) + tokio::time::sleep(std::time::Duration::from_millis(2)).await; + Ok(make_response(200)) + } else { + // Second call: fast success + Ok(make_response(200)) + } + } + }, + ) + .await; + + // The first response is completed-but-slow, so it's delivered directly. + // The block state should still be recorded. 
+ assert!(result.is_ok()); + assert!( + orchestrator.latency_block_state.is_blocked("openai/gpt-4o"), + "Block state should be recorded even when completed response is delivered" + ); + } + + #[tokio::test] + async fn test_timeout_dual_classification_records_high_latency_event() { + // When a request times out AND on_high_latency is configured AND + // elapsed time exceeds threshold_ms, the orchestrator should: + // 1. Use TimeoutError for retry purposes + // 2. Also record a HighLatencyEvent for blocking purposes + let orchestrator = RetryOrchestrator::new_default(); + let primary = make_provider("openai/gpt-4o"); + let secondary = make_provider("anthropic/claude-3"); + let all_providers = vec![primary.clone(), secondary.clone()]; + + // threshold_ms=50, timeout will report duration_ms > 50 + let policy = high_latency_retry_policy(50); + + let sig = RequestSignature::new( + b"test body", + &hyper::HeaderMap::new(), + false, + "openai/gpt-4o".to_string(), + ); + let mut ctx = make_context("test-timeout-dual"); + let body = Bytes::from("test body"); + + let call_count = std::sync::Arc::new(std::sync::atomic::AtomicU32::new(0)); + let call_count_clone = call_count.clone(); + + let result = orchestrator + .execute( + &body, + &sig, + &all_providers[0], + &policy, + &all_providers, + &mut ctx, + move |_body, _provider| { + let count = call_count_clone.clone(); + async move { + let n = count.fetch_add(1, std::sync::atomic::Ordering::SeqCst); + if n == 0 { + // First call: timeout after 100ms (exceeds threshold of 50ms) + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + Err(TimeoutError { duration_ms: 100 }) + } else { + // Second call: success + Ok(make_response(200)) + } + } + }, + ) + .await; + + // Should succeed on retry + assert!(result.is_ok(), "Should succeed on retry after timeout"); + + // The timeout should have also recorded a latency block + // because duration_ms (100) > threshold_ms (50) + assert!( + 
orchestrator.latency_block_state.is_blocked("openai/gpt-4o"), + "Latency block state should be created via dual-classification (timeout + high latency)" + ); + + // The attempt error should be recorded as a Timeout (not HighLatency) + assert!( + ctx.errors.iter().any(|e| matches!(e.error_type, AttemptErrorType::Timeout { .. })), + "The attempt should be recorded as a Timeout error" + ); + } + + #[tokio::test] + async fn test_timeout_no_dual_classification_when_below_threshold() { + // When a request times out but elapsed time is below threshold_ms, + // no HighLatencyEvent should be recorded for blocking. + let orchestrator = RetryOrchestrator::new_default(); + let primary = make_provider("openai/gpt-4o"); + let secondary = make_provider("anthropic/claude-3"); + let all_providers = vec![primary.clone(), secondary.clone()]; + + // threshold_ms=5000, timeout will report duration_ms=10 (below threshold) + let policy = high_latency_retry_policy(5000); + + let sig = RequestSignature::new( + b"test body", + &hyper::HeaderMap::new(), + false, + "openai/gpt-4o".to_string(), + ); + let mut ctx = make_context("test-timeout-no-dual"); + let body = Bytes::from("test body"); + + let call_count = std::sync::Arc::new(std::sync::atomic::AtomicU32::new(0)); + let call_count_clone = call_count.clone(); + + let result = orchestrator + .execute( + &body, + &sig, + &all_providers[0], + &policy, + &all_providers, + &mut ctx, + move |_body, _provider| { + let count = call_count_clone.clone(); + async move { + let n = count.fetch_add(1, std::sync::atomic::Ordering::SeqCst); + if n == 0 { + // Timeout with short duration (below threshold) + Err(TimeoutError { duration_ms: 10 }) + } else { + Ok(make_response(200)) + } + } + }, + ) + .await; + + assert!(result.is_ok()); + + // No latency block should be created since timeout duration < threshold + assert!( + !orchestrator.latency_block_state.is_blocked("openai/gpt-4o"), + "No latency block should be created when timeout duration is below 
threshold" + ); + } + + #[tokio::test] + async fn test_high_latency_min_triggers_not_met_no_block() { + // When min_triggers > 1 and only 1 event occurs, no block should be created. + let orchestrator = RetryOrchestrator::new_default(); + let primary = make_provider("openai/gpt-4o"); + let secondary = make_provider("anthropic/claude-3"); + let all_providers = vec![primary.clone(), secondary.clone()]; + + let mut policy = high_latency_retry_policy(0); + policy.on_high_latency.as_mut().unwrap().min_triggers = 3; + + let sig = RequestSignature::new( + b"test body", + &hyper::HeaderMap::new(), + false, + "openai/gpt-4o".to_string(), + ); + let mut ctx = make_context("test-hl-min-triggers"); + let body = Bytes::from("test body"); + + let result = orchestrator + .execute( + &body, + &sig, + &all_providers[0], + &policy, + &all_providers, + &mut ctx, + |_body, _provider| async { + tokio::time::sleep(std::time::Duration::from_millis(2)).await; + Ok(make_response(200)) + }, + ) + .await; + + assert!(result.is_ok()); + + // Only 1 event recorded, but min_triggers=3, so no block + assert!( + !orchestrator.latency_block_state.is_blocked("openai/gpt-4o"), + "No block should be created when min_triggers threshold is not met" + ); + } + + #[tokio::test] + async fn test_timeout_dual_classification_provider_scope() { + // Dual-classification with provider scope should block the provider prefix. 
+ let orchestrator = RetryOrchestrator::new_default(); + let primary = make_provider("openai/gpt-4o"); + let secondary = make_provider("anthropic/claude-3"); + let all_providers = vec![primary.clone(), secondary.clone()]; + + let mut policy = high_latency_retry_policy(50); + policy.on_high_latency.as_mut().unwrap().scope = BlockScope::Provider; + + let sig = RequestSignature::new( + b"test body", + &hyper::HeaderMap::new(), + false, + "openai/gpt-4o".to_string(), + ); + let mut ctx = make_context("test-timeout-dual-provider"); + let body = Bytes::from("test body"); + + let call_count = std::sync::Arc::new(std::sync::atomic::AtomicU32::new(0)); + let call_count_clone = call_count.clone(); + + let result = orchestrator + .execute( + &body, + &sig, + &all_providers[0], + &policy, + &all_providers, + &mut ctx, + move |_body, _provider| { + let count = call_count_clone.clone(); + async move { + let n = count.fetch_add(1, std::sync::atomic::Ordering::SeqCst); + if n == 0 { + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + Err(TimeoutError { duration_ms: 100 }) + } else { + Ok(make_response(200)) + } + } + }, + ) + .await; + + assert!(result.is_ok()); + + // Provider prefix should be blocked + assert!( + orchestrator.latency_block_state.is_blocked("openai"), + "Provider prefix should be blocked via dual-classification" + ); + } + + // ── P2 Edge Case: successful request below threshold does NOT remove block ── + + #[tokio::test] + async fn test_successful_request_below_threshold_does_not_remove_latency_block() { + // Design Decision 9: A successful request with latency below the threshold + // does NOT remove an existing Latency_Block_State entry. Blocks expire only + // via their configured block_duration_seconds. 
+ let orchestrator = RetryOrchestrator::new_default(); + let primary = make_provider("openai/gpt-4o"); + let secondary = make_provider("anthropic/claude-3"); + let all_providers = vec![primary.clone(), secondary.clone()]; + + // Pre-create a latency block for the primary model (simulating a previous + // high latency event that triggered a block). + orchestrator + .latency_block_state + .record_block("openai/gpt-4o", 300, 6000); + assert!( + orchestrator.latency_block_state.is_blocked("openai/gpt-4o"), + "Pre-condition: model should be blocked" + ); + + // Now send a request with a high latency config that has a high threshold. + // The response will be fast (below threshold), so no new HighLatencyEvent + // should be triggered. The existing block must remain. + let policy = high_latency_retry_policy(99999); // very high threshold — response will be fast + + let sig = RequestSignature::new( + b"test body", + &hyper::HeaderMap::new(), + false, + "openai/gpt-4o".to_string(), + ); + let mut ctx = make_context("test-block-not-removed"); + let body = Bytes::from("test body"); + + // The primary is blocked, so the orchestrator should route to the secondary. + // The secondary returns 200 quickly (below threshold). + let result = orchestrator + .execute( + &body, + &sig, + &all_providers[0], + &policy, + &all_providers, + &mut ctx, + |_body, _provider| async { Ok(make_response(200)) }, + ) + .await; + + assert!(result.is_ok(), "Request should succeed via secondary provider"); + + // The existing block on the primary model must still be present. + // A successful fast request must NOT remove the block (Design Decision 9). + assert!( + orchestrator.latency_block_state.is_blocked("openai/gpt-4o"), + "Latency block must NOT be removed by a successful request below threshold" + ); + } + + // Feature: retry-on-ratelimit, Property 20: Completed High-Latency Response Delivered + // **Validates: Requirements 2a.17, 2a.18, 3.4** + proptest! 
{ + #![proptest_config(ProptestConfig::with_cases(100))] + + /// Property 20: For any request that completes successfully but exceeds + /// the latency threshold, the completed response must be delivered to the + /// client (no retry for the current request). However, a Latency_Block_State + /// entry must still be created (if min_triggers threshold is met) so future + /// requests skip the slow model/provider. + #[test] + fn prop_completed_high_latency_response_delivered( + min_triggers in 1u32..=3u32, + block_duration_seconds in 1u64..=600u64, + scope in prop_oneof![Just(BlockScope::Model), Just(BlockScope::Provider)], + apply_to in prop_oneof![Just(ApplyTo::Global), Just(ApplyTo::Request)], + measure in prop_oneof![Just(LatencyMeasure::Ttfb), Just(LatencyMeasure::Total)], + ) { + let rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap(); + + rt.block_on(async { + let orchestrator = RetryOrchestrator::new_default(); + let primary = make_provider("openai/gpt-4o"); + let secondary = make_provider("anthropic/claude-3"); + let all_providers = vec![primary.clone(), secondary.clone()]; + + // Use threshold_ms=0 so any elapsed time > 0 triggers HighLatencyEvent. + let policy = RetryPolicy { + fallback_models: vec!["anthropic/claude-3".to_string()], + default_strategy: RetryStrategy::DifferentProvider, + default_max_attempts: 2, + on_status_codes: vec![], + on_timeout: None, + on_high_latency: Some(HighLatencyConfig { + threshold_ms: 0, + measure, + min_triggers, + trigger_window_seconds: Some(60), + strategy: RetryStrategy::DifferentProvider, + max_attempts: 2, + block_duration_seconds, + scope, + apply_to, + }), + backoff: None, + retry_after_handling: None, + max_retry_duration_ms: None, + }; + + let sig = RequestSignature::new( + b"test body", + &hyper::HeaderMap::new(), + false, + "openai/gpt-4o".to_string(), + ); + let body = Bytes::from("test body"); + + // Send min_triggers requests so the trigger counter is met. 
+ // Each request should return Ok(200) since the response completed. + for i in 0..min_triggers { + let mut ctx = RequestContext { + request_id: format!("test-prop20-{}", i), + attempted_providers: HashSet::new(), + retry_start_time: None, + attempt_number: 0, + request_retry_after_state: HashMap::new(), + request_latency_block_state: HashMap::new(), + request_signature: sig.clone(), + errors: vec![], + }; + + let result = orchestrator + .execute( + &body, + &sig, + &all_providers[0], + &policy, + &all_providers, + &mut ctx, + |_body, _provider| async { + // Small delay to ensure elapsed > 0 (threshold_ms=0) + tokio::time::sleep(std::time::Duration::from_millis(2)).await; + Ok(make_response(200)) + }, + ) + .await; + + // Response must always be delivered to the client + prop_assert!( + result.is_ok(), + "Completed-but-slow response must be delivered to client (attempt {})", + i + 1 + ); + let resp = result.unwrap(); + prop_assert_eq!( + resp.status().as_u16(), + 200u16, + "Response status must be 200 (attempt {})", + i + 1 + ); + + // After the last request that meets min_triggers, check block state + if i + 1 == min_triggers { + let expected_identifier = match scope { + BlockScope::Model => "openai/gpt-4o".to_string(), + BlockScope::Provider => "openai".to_string(), + }; + + match apply_to { + ApplyTo::Global => { + prop_assert!( + orchestrator.latency_block_state.is_blocked(&expected_identifier), + "Global block state should be created for '{}' after {} triggers", + expected_identifier, + min_triggers + ); + } + ApplyTo::Request => { + // Request-scoped block is stored in the RequestContext, + // which is local to this request. Verify it was set. 
+ prop_assert!( + ctx.request_latency_block_state.contains_key(&expected_identifier), + "Request-scoped block state should be created for '{}' after {} triggers", + expected_identifier, + min_triggers + ); + // Global state should NOT be set for request-scoped blocks + prop_assert!( + !orchestrator.latency_block_state.is_blocked(&expected_identifier), + "Global block state should NOT be created when apply_to is Request" + ); + } + } + } + } + + Ok(()) + })?; + } + } +} \ No newline at end of file diff --git a/crates/common/src/retry/provider_selector.rs b/crates/common/src/retry/provider_selector.rs index 62ddf26d..54756439 100644 --- a/crates/common/src/retry/provider_selector.rs +++ b/crates/common/src/retry/provider_selector.rs @@ -469,3 +469,2715 @@ impl ProviderSelector { } } +#[cfg(test)] +mod tests { + use super::*; + use crate::configuration::{extract_provider, LlmProviderType}; + use proptest::prelude::*; + + fn make_provider(model: &str) -> LlmProvider { + LlmProvider { + name: model.to_string(), + provider_interface: LlmProviderType::OpenAI, + access_key: None, + model: Some(model.to_string()), + default: None, + stream: None, + endpoint: None, + port: None, + rate_limits: None, + usage: None, + routing_preferences: None, + cluster_name: None, + base_url_path_prefix: None, + internal: None, + passthrough_auth: None, + retry_policy: None, + } + } + + fn stub_context() -> RequestContext { + use std::collections::HashMap; + use hyper::HeaderMap; + use super::super::RequestSignature; + + let sig = RequestSignature::new(b"test", &HeaderMap::new(), false, "test".to_string()); + RequestContext { + request_id: "test-req".to_string(), + attempted_providers: HashSet::new(), + retry_start_time: None, + attempt_number: 0, + request_retry_after_state: HashMap::new(), + request_latency_block_state: HashMap::new(), + request_signature: sig, + errors: Vec::new(), + } + } + + #[test] + fn same_model_returns_matching_provider() { + let providers = vec![ + 
make_provider("openai/gpt-4o"), + make_provider("openai/gpt-4o-mini"), + ]; + let attempted = HashSet::new(); + let ctx = stub_context(); + let selector = ProviderSelector; + + let result = selector.select( + RetryStrategy::SameModel, + "openai/gpt-4o", + &[], + &providers, + &attempted, + &RetryAfterStateManager::new(), + &LatencyBlockStateManager::new(), + &ctx, + false, + false, + ); + + match result.unwrap() { + ProviderSelectionResult::Selected(p) => { + assert_eq!(p.model.as_deref(), Some("openai/gpt-4o")); + } + _ => panic!("expected Selected"), + } + } + + #[test] + fn same_model_exhausted_when_already_attempted() { + let providers = vec![make_provider("openai/gpt-4o")]; + let mut attempted = HashSet::new(); + attempted.insert("openai/gpt-4o".to_string()); + let ctx = stub_context(); + let selector = ProviderSelector; + + let result = selector.select( + RetryStrategy::SameModel, + "openai/gpt-4o", + &[], + &providers, + &attempted, + &RetryAfterStateManager::new(), + &LatencyBlockStateManager::new(), + &ctx, + false, + false, + ); + + assert!(result.is_err()); + } + + #[test] + fn same_provider_filters_by_prefix() { + let providers = vec![ + make_provider("openai/gpt-4o"), + make_provider("anthropic/claude-3-5-sonnet"), + make_provider("openai/gpt-4o-mini"), + ]; + let mut attempted = HashSet::new(); + attempted.insert("openai/gpt-4o".to_string()); + let ctx = stub_context(); + let selector = ProviderSelector; + + let result = selector.select( + RetryStrategy::SameProvider, + "openai/gpt-4o", + &[], + &providers, + &attempted, + &RetryAfterStateManager::new(), + &LatencyBlockStateManager::new(), + &ctx, + false, + false, + ); + + match result.unwrap() { + ProviderSelectionResult::Selected(p) => { + assert_eq!(p.model.as_deref(), Some("openai/gpt-4o-mini")); + } + _ => panic!("expected Selected"), + } + } + + #[test] + fn different_provider_filters_by_different_prefix() { + let providers = vec![ + make_provider("openai/gpt-4o"), + 
make_provider("openai/gpt-4o-mini"), + make_provider("anthropic/claude-3-5-sonnet"), + ]; + let attempted = HashSet::new(); + let ctx = stub_context(); + let selector = ProviderSelector; + + let result = selector.select( + RetryStrategy::DifferentProvider, + "openai/gpt-4o", + &[], + &providers, + &attempted, + &RetryAfterStateManager::new(), + &LatencyBlockStateManager::new(), + &ctx, + false, + false, + ); + + match result.unwrap() { + ProviderSelectionResult::Selected(p) => { + assert_eq!(p.model.as_deref(), Some("anthropic/claude-3-5-sonnet")); + } + _ => panic!("expected Selected"), + } + } + + #[test] + fn different_provider_exhausted_when_all_same_prefix() { + let providers = vec![ + make_provider("openai/gpt-4o"), + make_provider("openai/gpt-4o-mini"), + ]; + let attempted = HashSet::new(); + let ctx = stub_context(); + let selector = ProviderSelector; + + let result = selector.select( + RetryStrategy::DifferentProvider, + "openai/gpt-4o", + &[], + &providers, + &attempted, + &RetryAfterStateManager::new(), + &LatencyBlockStateManager::new(), + &ctx, + false, + false, + ); + + assert!(result.is_err()); + } + + #[test] + fn respects_provider_list_ordering() { + let providers = vec![ + make_provider("anthropic/claude-3-opus"), + make_provider("anthropic/claude-3-5-sonnet"), + make_provider("openai/gpt-4o"), + ]; + let attempted = HashSet::new(); + let ctx = stub_context(); + let selector = ProviderSelector; + + // different_provider from openai should pick the first anthropic in list order + let result = selector.select( + RetryStrategy::DifferentProvider, + "openai/gpt-4o", + &[], + &providers, + &attempted, + &RetryAfterStateManager::new(), + &LatencyBlockStateManager::new(), + &ctx, + false, + false, + ); + + match result.unwrap() { + ProviderSelectionResult::Selected(p) => { + assert_eq!(p.model.as_deref(), Some("anthropic/claude-3-opus")); + } + _ => panic!("expected Selected"), + } + } + + #[test] + fn skips_attempted_and_picks_next() { + let providers 
= vec![ + make_provider("anthropic/claude-3-opus"), + make_provider("anthropic/claude-3-5-sonnet"), + make_provider("openai/gpt-4o"), + ]; + let mut attempted = HashSet::new(); + attempted.insert("anthropic/claude-3-opus".to_string()); + let ctx = stub_context(); + let selector = ProviderSelector; + + let result = selector.select( + RetryStrategy::DifferentProvider, + "openai/gpt-4o", + &[], + &providers, + &attempted, + &RetryAfterStateManager::new(), + &LatencyBlockStateManager::new(), + &ctx, + false, + false, + ); + + match result.unwrap() { + ProviderSelectionResult::Selected(p) => { + assert_eq!(p.model.as_deref(), Some("anthropic/claude-3-5-sonnet")); + } + _ => panic!("expected Selected"), + } + } + + #[test] + fn all_providers_exhausted_returns_error() { + let providers = vec![ + make_provider("openai/gpt-4o"), + make_provider("anthropic/claude-3-5-sonnet"), + ]; + let mut attempted = HashSet::new(); + attempted.insert("openai/gpt-4o".to_string()); + attempted.insert("anthropic/claude-3-5-sonnet".to_string()); + let ctx = stub_context(); + let selector = ProviderSelector; + + let result = selector.select( + RetryStrategy::DifferentProvider, + "openai/gpt-4o", + &[], + &providers, + &attempted, + &RetryAfterStateManager::new(), + &LatencyBlockStateManager::new(), + &ctx, + false, + false, + ); + + assert!(result.is_err()); + } + + // ── Fallback models tests (Task 13.1) ───────────────────────────────── + + #[test] + fn fallback_models_tried_in_order_before_provider_list() { + // Provider_List has anthropic first, but fallback_models says try azure first. 
+ let providers = vec![ + make_provider("openai/gpt-4o"), + make_provider("anthropic/claude-3-5-sonnet"), + make_provider("azure/gpt-4o"), + ]; + let mut attempted = HashSet::new(); + attempted.insert("openai/gpt-4o".to_string()); + let ctx = stub_context(); + let selector = ProviderSelector; + + let fallback_models = vec![ + "azure/gpt-4o".to_string(), + "anthropic/claude-3-5-sonnet".to_string(), + ]; + + let result = selector.select( + RetryStrategy::DifferentProvider, + "openai/gpt-4o", + &fallback_models, + &providers, + &attempted, + &RetryAfterStateManager::new(), + &LatencyBlockStateManager::new(), + &ctx, + false, + false, + ); + + // Should pick azure/gpt-4o (first in fallback_models) not anthropic (first in Provider_List) + match result.unwrap() { + ProviderSelectionResult::Selected(p) => { + assert_eq!(p.model.as_deref(), Some("azure/gpt-4o")); + } + _ => panic!("expected Selected"), + } + } + + #[test] + fn fallback_models_skips_attempted_picks_next() { + let providers = vec![ + make_provider("openai/gpt-4o"), + make_provider("anthropic/claude-3-5-sonnet"), + make_provider("azure/gpt-4o"), + ]; + let mut attempted = HashSet::new(); + attempted.insert("openai/gpt-4o".to_string()); + attempted.insert("anthropic/claude-3-5-sonnet".to_string()); + let ctx = stub_context(); + let selector = ProviderSelector; + + let fallback_models = vec![ + "anthropic/claude-3-5-sonnet".to_string(), + "azure/gpt-4o".to_string(), + ]; + + let result = selector.select( + RetryStrategy::DifferentProvider, + "openai/gpt-4o", + &fallback_models, + &providers, + &attempted, + &RetryAfterStateManager::new(), + &LatencyBlockStateManager::new(), + &ctx, + false, + false, + ); + + match result.unwrap() { + ProviderSelectionResult::Selected(p) => { + assert_eq!(p.model.as_deref(), Some("azure/gpt-4o")); + } + _ => panic!("expected Selected"), + } + } + + #[test] + fn fallback_models_exhausted_falls_back_to_provider_list() { + let providers = vec![ + make_provider("openai/gpt-4o"), + 
make_provider("anthropic/claude-3-5-sonnet"), + make_provider("azure/gpt-4o"), + ]; + let mut attempted = HashSet::new(); + attempted.insert("openai/gpt-4o".to_string()); + attempted.insert("anthropic/claude-3-5-sonnet".to_string()); + let ctx = stub_context(); + let selector = ProviderSelector; + + // Fallback list only has anthropic (already attempted) + let fallback_models = vec!["anthropic/claude-3-5-sonnet".to_string()]; + + let result = selector.select( + RetryStrategy::DifferentProvider, + "openai/gpt-4o", + &fallback_models, + &providers, + &attempted, + &RetryAfterStateManager::new(), + &LatencyBlockStateManager::new(), + &ctx, + false, + false, + ); + + // Should fall back to Provider_List and find azure/gpt-4o + match result.unwrap() { + ProviderSelectionResult::Selected(p) => { + assert_eq!(p.model.as_deref(), Some("azure/gpt-4o")); + } + _ => panic!("expected Selected"), + } + } + + #[test] + fn fallback_models_not_in_provider_list_skipped() { + let providers = vec![ + make_provider("openai/gpt-4o"), + make_provider("anthropic/claude-3-5-sonnet"), + ]; + let mut attempted = HashSet::new(); + attempted.insert("openai/gpt-4o".to_string()); + let ctx = stub_context(); + let selector = ProviderSelector; + + // "azure/gpt-4o" is in fallback_models but NOT in Provider_List + let fallback_models = vec![ + "azure/gpt-4o".to_string(), + "anthropic/claude-3-5-sonnet".to_string(), + ]; + + let result = selector.select( + RetryStrategy::DifferentProvider, + "openai/gpt-4o", + &fallback_models, + &providers, + &attempted, + &RetryAfterStateManager::new(), + &LatencyBlockStateManager::new(), + &ctx, + false, + false, + ); + + // azure/gpt-4o skipped (not in Provider_List), picks anthropic from fallback list + match result.unwrap() { + ProviderSelectionResult::Selected(p) => { + assert_eq!(p.model.as_deref(), Some("anthropic/claude-3-5-sonnet")); + } + _ => panic!("expected Selected"), + } + } + + #[test] + fn fallback_models_strategy_filtering_same_provider() { + 
let providers = vec![ + make_provider("openai/gpt-4o"), + make_provider("openai/gpt-4o-mini"), + make_provider("anthropic/claude-3-5-sonnet"), + ]; + let mut attempted = HashSet::new(); + attempted.insert("openai/gpt-4o".to_string()); + let ctx = stub_context(); + let selector = ProviderSelector; + + // Fallback list has anthropic first, but strategy is same_provider + let fallback_models = vec![ + "anthropic/claude-3-5-sonnet".to_string(), + "openai/gpt-4o-mini".to_string(), + ]; + + let result = selector.select( + RetryStrategy::SameProvider, + "openai/gpt-4o", + &fallback_models, + &providers, + &attempted, + &RetryAfterStateManager::new(), + &LatencyBlockStateManager::new(), + &ctx, + false, + false, + ); + + // anthropic filtered out by same_provider strategy, picks openai/gpt-4o-mini + match result.unwrap() { + ProviderSelectionResult::Selected(p) => { + assert_eq!(p.model.as_deref(), Some("openai/gpt-4o-mini")); + } + _ => panic!("expected Selected"), + } + } + + #[test] + fn fallback_models_strategy_filtering_different_provider() { + let providers = vec![ + make_provider("openai/gpt-4o"), + make_provider("openai/gpt-4o-mini"), + make_provider("anthropic/claude-3-5-sonnet"), + ]; + let mut attempted = HashSet::new(); + attempted.insert("openai/gpt-4o".to_string()); + let ctx = stub_context(); + let selector = ProviderSelector; + + // Fallback list has openai/gpt-4o-mini first, but strategy is different_provider + let fallback_models = vec![ + "openai/gpt-4o-mini".to_string(), + "anthropic/claude-3-5-sonnet".to_string(), + ]; + + let result = selector.select( + RetryStrategy::DifferentProvider, + "openai/gpt-4o", + &fallback_models, + &providers, + &attempted, + &RetryAfterStateManager::new(), + &LatencyBlockStateManager::new(), + &ctx, + false, + false, + ); + + // openai/gpt-4o-mini filtered out by different_provider strategy, picks anthropic + match result.unwrap() { + ProviderSelectionResult::Selected(p) => { + assert_eq!(p.model.as_deref(), 
Some("anthropic/claude-3-5-sonnet")); + } + _ => panic!("expected Selected"), + } + } + + #[test] + fn same_model_ignores_fallback_models() { + let providers = vec![ + make_provider("openai/gpt-4o"), + make_provider("anthropic/claude-3-5-sonnet"), + ]; + let attempted = HashSet::new(); + let ctx = stub_context(); + let selector = ProviderSelector; + + let fallback_models = vec!["anthropic/claude-3-5-sonnet".to_string()]; + + let result = selector.select( + RetryStrategy::SameModel, + "openai/gpt-4o", + &fallback_models, + &providers, + &attempted, + &RetryAfterStateManager::new(), + &LatencyBlockStateManager::new(), + &ctx, + false, + false, + ); + + // SameModel always returns the primary model, ignoring fallback_models + match result.unwrap() { + ProviderSelectionResult::Selected(p) => { + assert_eq!(p.model.as_deref(), Some("openai/gpt-4o")); + } + _ => panic!("expected Selected"), + } + } + + #[test] + fn fallback_all_exhausted_and_provider_list_exhausted() { + let providers = vec![ + make_provider("openai/gpt-4o"), + make_provider("anthropic/claude-3-5-sonnet"), + ]; + let mut attempted = HashSet::new(); + attempted.insert("openai/gpt-4o".to_string()); + attempted.insert("anthropic/claude-3-5-sonnet".to_string()); + let ctx = stub_context(); + let selector = ProviderSelector; + + let fallback_models = vec!["anthropic/claude-3-5-sonnet".to_string()]; + + let result = selector.select( + RetryStrategy::DifferentProvider, + "openai/gpt-4o", + &fallback_models, + &providers, + &attempted, + &RetryAfterStateManager::new(), + &LatencyBlockStateManager::new(), + &ctx, + false, + false, + ); + + assert!(result.is_err()); + } + + #[test] + fn empty_fallback_models_uses_provider_list() { + // Verify backward compatibility: empty fallback_models behaves like P0 + let providers = vec![ + make_provider("openai/gpt-4o"), + make_provider("anthropic/claude-3-5-sonnet"), + make_provider("azure/gpt-4o"), + ]; + let mut attempted = HashSet::new(); + 
attempted.insert("openai/gpt-4o".to_string()); + let ctx = stub_context(); + let selector = ProviderSelector; + + let result = selector.select( + RetryStrategy::DifferentProvider, + "openai/gpt-4o", + &[], + &providers, + &attempted, + &RetryAfterStateManager::new(), + &LatencyBlockStateManager::new(), + &ctx, + false, + false, + ); + + // Should pick anthropic (first different-provider in Provider_List order) + match result.unwrap() { + ProviderSelectionResult::Selected(p) => { + assert_eq!(p.model.as_deref(), Some("anthropic/claude-3-5-sonnet")); + } + _ => panic!("expected Selected"), + } + } + + // ── Retry-After state integration tests (Task 17.1) ────────────────── + + use crate::configuration::{HighLatencyConfig, LatencyMeasure, RetryPolicy, RetryAfterHandlingConfig}; + + fn make_provider_with_retry_policy(model: &str, ra_config: Option) -> LlmProvider { + let mut p = make_provider(model); + p.retry_policy = Some(RetryPolicy { + fallback_models: vec![], + default_strategy: RetryStrategy::DifferentProvider, + default_max_attempts: 2, + on_status_codes: vec![], + on_timeout: None, + on_high_latency: None, + backoff: None, + retry_after_handling: ra_config, + max_retry_duration_ms: None, + }); + p + } + + #[test] + fn same_model_global_ra_block_returns_wait_and_retry() { + // When same_model strategy and model is globally RA-blocked, + // select() should return WaitAndRetrySameModel with remaining duration. 
+ let providers = vec![ + make_provider_with_retry_policy("openai/gpt-4o", None), // defaults: scope=Model, apply_to=Global + ]; + let attempted = HashSet::new(); + let ctx = stub_context(); + let selector = ProviderSelector; + let ra_state = RetryAfterStateManager::new(); + + // Block the model globally for 60 seconds + ra_state.record("openai/gpt-4o", 60, 300); + + let result = selector.select( + RetryStrategy::SameModel, + "openai/gpt-4o", + &[], + &providers, + &attempted, + &ra_state, + &LatencyBlockStateManager::new(), + &ctx, + true, + false, + ); + + match result.unwrap() { + ProviderSelectionResult::WaitAndRetrySameModel { wait_duration } => { + // Should have a positive remaining duration + assert!(wait_duration.as_secs() > 0, "wait_duration should be positive"); + assert!(wait_duration.as_secs() <= 60, "wait_duration should be <= 60s"); + } + _ => panic!("expected WaitAndRetrySameModel"), + } + } + + #[test] + fn same_model_no_ra_block_returns_selected() { + // When same_model strategy and model is NOT RA-blocked, + // select() should return Selected. + let providers = vec![ + make_provider_with_retry_policy("openai/gpt-4o", None), + ]; + let attempted = HashSet::new(); + let ctx = stub_context(); + let selector = ProviderSelector; + let ra_state = RetryAfterStateManager::new(); + + let result = selector.select( + RetryStrategy::SameModel, + "openai/gpt-4o", + &[], + &providers, + &attempted, + &ra_state, + &LatencyBlockStateManager::new(), + &ctx, + true, + false, + ); + + match result.unwrap() { + ProviderSelectionResult::Selected(p) => { + assert_eq!(p.model.as_deref(), Some("openai/gpt-4o")); + } + _ => panic!("expected Selected"), + } + } + + #[test] + fn same_model_ra_block_ignored_when_has_retry_policy_false() { + // When has_retry_policy is false, RA state should not be checked. 
+ let providers = vec![ + make_provider_with_retry_policy("openai/gpt-4o", None), + ]; + let attempted = HashSet::new(); + let ctx = stub_context(); + let selector = ProviderSelector; + let ra_state = RetryAfterStateManager::new(); + + // Block the model globally + ra_state.record("openai/gpt-4o", 60, 300); + + let result = selector.select( + RetryStrategy::SameModel, + "openai/gpt-4o", + &[], + &providers, + &attempted, + &ra_state, + &LatencyBlockStateManager::new(), + &ctx, + false, // has_retry_policy = false + false, + ); + + // Should return Selected despite the block, because has_retry_policy is false + match result.unwrap() { + ProviderSelectionResult::Selected(p) => { + assert_eq!(p.model.as_deref(), Some("openai/gpt-4o")); + } + _ => panic!("expected Selected when has_retry_policy is false"), + } + } + + #[test] + fn same_model_request_scoped_ra_block_returns_wait_and_retry() { + // When same_model strategy and model is request-scoped RA-blocked, + // select() should return WaitAndRetrySameModel. 
+ let ra_config = RetryAfterHandlingConfig { + scope: BlockScope::Model, + apply_to: ApplyTo::Request, + max_retry_after_seconds: 300, + }; + let providers = vec![ + make_provider_with_retry_policy("openai/gpt-4o", Some(ra_config)), + ]; + let attempted = HashSet::new(); + let mut ctx = stub_context(); + // Add request-scoped block + ctx.request_retry_after_state.insert( + "openai/gpt-4o".to_string(), + Instant::now() + Duration::from_secs(30), + ); + let selector = ProviderSelector; + let ra_state = RetryAfterStateManager::new(); + + let result = selector.select( + RetryStrategy::SameModel, + "openai/gpt-4o", + &[], + &providers, + &attempted, + &ra_state, + &LatencyBlockStateManager::new(), + &ctx, + true, + false, + ); + + match result.unwrap() { + ProviderSelectionResult::WaitAndRetrySameModel { wait_duration } => { + assert!(wait_duration.as_secs() > 0); + assert!(wait_duration.as_secs() <= 30); + } + _ => panic!("expected WaitAndRetrySameModel for request-scoped block"), + } + } + + #[test] + fn different_provider_skips_ra_blocked_candidate() { + // When different_provider strategy and a candidate is RA-blocked, + // it should be skipped and the next eligible candidate selected. 
+ let providers = vec![ + make_provider_with_retry_policy("openai/gpt-4o", None), + make_provider("anthropic/claude-3-5-sonnet"), + make_provider("azure/gpt-4o"), + ]; + let mut attempted = HashSet::new(); + attempted.insert("openai/gpt-4o".to_string()); + let ctx = stub_context(); + let selector = ProviderSelector; + let ra_state = RetryAfterStateManager::new(); + + // Block anthropic model globally (scope: Model by default) + ra_state.record("anthropic/claude-3-5-sonnet", 60, 300); + + let result = selector.select( + RetryStrategy::DifferentProvider, + "openai/gpt-4o", + &[], + &providers, + &attempted, + &ra_state, + &LatencyBlockStateManager::new(), + &ctx, + true, + false, + ); + + // Should skip anthropic (blocked) and pick azure + match result.unwrap() { + ProviderSelectionResult::Selected(p) => { + assert_eq!(p.model.as_deref(), Some("azure/gpt-4o")); + } + _ => panic!("expected Selected"), + } + } + + #[test] + fn provider_scope_blocks_all_models_from_provider() { + // When scope is Provider, blocking "openai" should block all openai/* models. 
+ let ra_config = RetryAfterHandlingConfig { + scope: BlockScope::Provider, + apply_to: ApplyTo::Global, + max_retry_after_seconds: 300, + }; + let providers = vec![ + make_provider_with_retry_policy("openai/gpt-4o", Some(ra_config)), + make_provider("openai/gpt-4o-mini"), + make_provider("anthropic/claude-3-5-sonnet"), + ]; + let mut attempted = HashSet::new(); + attempted.insert("openai/gpt-4o".to_string()); + let ctx = stub_context(); + let selector = ProviderSelector; + let ra_state = RetryAfterStateManager::new(); + + // Block at provider level: "openai" + ra_state.record("openai", 60, 300); + + let result = selector.select( + RetryStrategy::SameProvider, + "openai/gpt-4o", + &[], + &providers, + &attempted, + &ra_state, + &LatencyBlockStateManager::new(), + &ctx, + true, + false, + ); + + // openai/gpt-4o-mini should be blocked because provider "openai" is blocked + // No same-provider candidates available → error + assert!(result.is_err()); + } + + #[test] + fn fallback_model_ra_blocked_skipped() { + // When a fallback model is RA-blocked, it should be skipped. 
+ let providers = vec![ + make_provider_with_retry_policy("openai/gpt-4o", None), + make_provider("anthropic/claude-3-5-sonnet"), + make_provider("azure/gpt-4o"), + ]; + let mut attempted = HashSet::new(); + attempted.insert("openai/gpt-4o".to_string()); + let ctx = stub_context(); + let selector = ProviderSelector; + let ra_state = RetryAfterStateManager::new(); + + // Block anthropic model + ra_state.record("anthropic/claude-3-5-sonnet", 60, 300); + + let fallback_models = vec![ + "anthropic/claude-3-5-sonnet".to_string(), + "azure/gpt-4o".to_string(), + ]; + + let result = selector.select( + RetryStrategy::DifferentProvider, + "openai/gpt-4o", + &fallback_models, + &providers, + &attempted, + &ra_state, + &LatencyBlockStateManager::new(), + &ctx, + true, + false, + ); + + // anthropic blocked → skip to azure + match result.unwrap() { + ProviderSelectionResult::Selected(p) => { + assert_eq!(p.model.as_deref(), Some("azure/gpt-4o")); + } + _ => panic!("expected Selected"), + } + } + + #[test] + fn all_candidates_ra_blocked_returns_error_with_shortest_remaining() { + // When all candidates are RA-blocked, return error with shortest_remaining_block_seconds. 
+ let providers = vec![ + make_provider_with_retry_policy("openai/gpt-4o", None), + make_provider("anthropic/claude-3-5-sonnet"), + make_provider("azure/gpt-4o"), + ]; + let mut attempted = HashSet::new(); + attempted.insert("openai/gpt-4o".to_string()); + let ctx = stub_context(); + let selector = ProviderSelector; + let ra_state = RetryAfterStateManager::new(); + + // Block both alternative providers + ra_state.record("anthropic/claude-3-5-sonnet", 60, 300); + ra_state.record("azure/gpt-4o", 30, 300); + + let fallback_models = vec![ + "anthropic/claude-3-5-sonnet".to_string(), + "azure/gpt-4o".to_string(), + ]; + + let result = selector.select( + RetryStrategy::DifferentProvider, + "openai/gpt-4o", + &fallback_models, + &providers, + &attempted, + &ra_state, + &LatencyBlockStateManager::new(), + &ctx, + true, + false, + ); + + match result { + Err(e) => { + // shortest_remaining should be set (azure has 30s, anthropic has 60s) + assert!(e.shortest_remaining_block_seconds.is_some()); + let shortest = e.shortest_remaining_block_seconds.unwrap(); + assert!(shortest <= 30, "shortest remaining should be <= 30s, got {}", shortest); + } + Ok(_) => panic!("expected AllProvidersExhaustedError"), + } + } + + #[test] + fn same_model_provider_scope_global_ra_block_returns_wait() { + // When same_model strategy with provider-scope RA block, + // blocking the provider should trigger WaitAndRetrySameModel. 
+ let ra_config = RetryAfterHandlingConfig { + scope: BlockScope::Provider, + apply_to: ApplyTo::Global, + max_retry_after_seconds: 300, + }; + let providers = vec![ + make_provider_with_retry_policy("openai/gpt-4o", Some(ra_config)), + ]; + let attempted = HashSet::new(); + let ctx = stub_context(); + let selector = ProviderSelector; + let ra_state = RetryAfterStateManager::new(); + + // Block at provider level + ra_state.record("openai", 45, 300); + + let result = selector.select( + RetryStrategy::SameModel, + "openai/gpt-4o", + &[], + &providers, + &attempted, + &ra_state, + &LatencyBlockStateManager::new(), + &ctx, + true, + false, + ); + + match result.unwrap() { + ProviderSelectionResult::WaitAndRetrySameModel { wait_duration } => { + assert!(wait_duration.as_secs() > 0); + assert!(wait_duration.as_secs() <= 45); + } + _ => panic!("expected WaitAndRetrySameModel for provider-scope block"), + } + } + + // ── Latency Block state integration tests (Task 23.1) ──────────────── + + fn make_hl_config(scope: BlockScope, apply_to: ApplyTo) -> HighLatencyConfig { + HighLatencyConfig { + threshold_ms: 5000, + measure: LatencyMeasure::Ttfb, + min_triggers: 1, + trigger_window_seconds: None, + strategy: RetryStrategy::DifferentProvider, + max_attempts: 2, + block_duration_seconds: 300, + scope, + apply_to, + } + } + + fn make_provider_with_hl_config( + model: &str, + ra_config: Option, + hl_config: Option, + ) -> LlmProvider { + let mut p = make_provider(model); + p.retry_policy = Some(RetryPolicy { + fallback_models: vec![], + default_strategy: RetryStrategy::DifferentProvider, + default_max_attempts: 2, + on_status_codes: vec![], + on_timeout: None, + on_high_latency: hl_config, + backoff: None, + retry_after_handling: ra_config, + max_retry_duration_ms: None, + }); + p + } + + #[test] + fn same_model_lb_block_returns_error_not_wait() { + // For same_model strategy with LB block: return AllProvidersExhaustedError + // (skip to alternative), NOT WaitAndRetrySameModel 
(unlike RA). + let hl_config = make_hl_config(BlockScope::Model, ApplyTo::Global); + let providers = vec![make_provider_with_hl_config( + "openai/gpt-4o", + None, + Some(hl_config), + )]; + let attempted = HashSet::new(); + let ctx = stub_context(); + let selector = ProviderSelector; + let lb_state = LatencyBlockStateManager::new(); + + // Block the model globally via LB + lb_state.record_block("openai/gpt-4o", 60, 6000); + + let result = selector.select( + RetryStrategy::SameModel, + "openai/gpt-4o", + &[], + &providers, + &attempted, + &RetryAfterStateManager::new(), + &lb_state, + &ctx, + true, + true, + ); + + // Should return AllProvidersExhaustedError, NOT WaitAndRetrySameModel + match result { + Err(e) => { + assert!( + e.shortest_remaining_block_seconds.is_some(), + "should include remaining block seconds" + ); + let secs = e.shortest_remaining_block_seconds.unwrap(); + assert!(secs > 0 && secs <= 60); + } + Ok(ProviderSelectionResult::WaitAndRetrySameModel { .. }) => { + panic!("LB block on same_model should NOT return WaitAndRetrySameModel"); + } + Ok(ProviderSelectionResult::Selected(_)) => { + panic!("LB-blocked model should not be Selected"); + } + } + } + + #[test] + fn same_model_no_lb_block_returns_selected() { + // When same_model strategy and model is NOT LB-blocked, returns Selected. 
+ let hl_config = make_hl_config(BlockScope::Model, ApplyTo::Global); + let providers = vec![make_provider_with_hl_config( + "openai/gpt-4o", + None, + Some(hl_config), + )]; + let attempted = HashSet::new(); + let ctx = stub_context(); + let selector = ProviderSelector; + + let result = selector.select( + RetryStrategy::SameModel, + "openai/gpt-4o", + &[], + &providers, + &attempted, + &RetryAfterStateManager::new(), + &LatencyBlockStateManager::new(), + &ctx, + true, + true, + ); + + match result.unwrap() { + ProviderSelectionResult::Selected(p) => { + assert_eq!(p.model.as_deref(), Some("openai/gpt-4o")); + } + _ => panic!("expected Selected when not LB-blocked"), + } + } + + #[test] + fn same_model_lb_block_ignored_when_has_high_latency_config_false() { + // When has_high_latency_config is false, LB state should not be checked. + let hl_config = make_hl_config(BlockScope::Model, ApplyTo::Global); + let providers = vec![make_provider_with_hl_config( + "openai/gpt-4o", + None, + Some(hl_config), + )]; + let attempted = HashSet::new(); + let ctx = stub_context(); + let selector = ProviderSelector; + let lb_state = LatencyBlockStateManager::new(); + + lb_state.record_block("openai/gpt-4o", 60, 6000); + + let result = selector.select( + RetryStrategy::SameModel, + "openai/gpt-4o", + &[], + &providers, + &attempted, + &RetryAfterStateManager::new(), + &lb_state, + &ctx, + false, + false, // has_high_latency_config = false + ); + + match result.unwrap() { + ProviderSelectionResult::Selected(p) => { + assert_eq!(p.model.as_deref(), Some("openai/gpt-4o")); + } + _ => panic!("expected Selected when has_high_latency_config is false"), + } + } + + #[test] + fn same_model_request_scoped_lb_block_returns_error() { + // When same_model strategy and model is request-scoped LB-blocked, + // returns AllProvidersExhaustedError. 
+ let hl_config = make_hl_config(BlockScope::Model, ApplyTo::Request); + let providers = vec![make_provider_with_hl_config( + "openai/gpt-4o", + None, + Some(hl_config), + )]; + let attempted = HashSet::new(); + let mut ctx = stub_context(); + // Add request-scoped LB block + ctx.request_latency_block_state.insert( + "openai/gpt-4o".to_string(), + Instant::now() + Duration::from_secs(30), + ); + let selector = ProviderSelector; + + let result = selector.select( + RetryStrategy::SameModel, + "openai/gpt-4o", + &[], + &providers, + &attempted, + &RetryAfterStateManager::new(), + &LatencyBlockStateManager::new(), + &ctx, + true, + true, + ); + + assert!(result.is_err(), "request-scoped LB block should return error for same_model"); + } + + #[test] + fn different_provider_skips_lb_blocked_candidate() { + // When different_provider strategy and a candidate is LB-blocked, + // it should be skipped and the next eligible candidate selected. + let hl_config = make_hl_config(BlockScope::Model, ApplyTo::Global); + let providers = vec![ + make_provider_with_hl_config("openai/gpt-4o", None, Some(hl_config)), + make_provider("anthropic/claude-3-5-sonnet"), + make_provider("azure/gpt-4o"), + ]; + let mut attempted = HashSet::new(); + attempted.insert("openai/gpt-4o".to_string()); + let ctx = stub_context(); + let selector = ProviderSelector; + let lb_state = LatencyBlockStateManager::new(); + + // Block anthropic model globally via LB + lb_state.record_block("anthropic/claude-3-5-sonnet", 60, 6000); + + let result = selector.select( + RetryStrategy::DifferentProvider, + "openai/gpt-4o", + &[], + &providers, + &attempted, + &RetryAfterStateManager::new(), + &lb_state, + &ctx, + true, + true, + ); + + // Should skip anthropic (LB-blocked) and pick azure + match result.unwrap() { + ProviderSelectionResult::Selected(p) => { + assert_eq!(p.model.as_deref(), Some("azure/gpt-4o")); + } + _ => panic!("expected Selected"), + } + } + + #[test] + fn 
provider_scope_lb_blocks_all_models_from_provider() { + // When LB scope is Provider, blocking "openai" should block all openai/* models. + let hl_config = make_hl_config(BlockScope::Provider, ApplyTo::Global); + let providers = vec![ + make_provider_with_hl_config("openai/gpt-4o", None, Some(hl_config)), + make_provider("openai/gpt-4o-mini"), + make_provider("anthropic/claude-3-5-sonnet"), + ]; + let mut attempted = HashSet::new(); + attempted.insert("openai/gpt-4o".to_string()); + let ctx = stub_context(); + let selector = ProviderSelector; + let lb_state = LatencyBlockStateManager::new(); + + // Block at provider level: "openai" + lb_state.record_block("openai", 60, 6000); + + let result = selector.select( + RetryStrategy::SameProvider, + "openai/gpt-4o", + &[], + &providers, + &attempted, + &RetryAfterStateManager::new(), + &lb_state, + &ctx, + true, + true, + ); + + // openai/gpt-4o-mini should be blocked because provider "openai" is LB-blocked + assert!(result.is_err()); + } + + #[test] + fn fallback_model_lb_blocked_skipped() { + // When a fallback model is LB-blocked, it should be skipped. 
+ let hl_config = make_hl_config(BlockScope::Model, ApplyTo::Global); + let providers = vec![ + make_provider_with_hl_config("openai/gpt-4o", None, Some(hl_config)), + make_provider("anthropic/claude-3-5-sonnet"), + make_provider("azure/gpt-4o"), + ]; + let mut attempted = HashSet::new(); + attempted.insert("openai/gpt-4o".to_string()); + let ctx = stub_context(); + let selector = ProviderSelector; + let lb_state = LatencyBlockStateManager::new(); + + // Block anthropic model via LB + lb_state.record_block("anthropic/claude-3-5-sonnet", 60, 6000); + + let fallback_models = vec![ + "anthropic/claude-3-5-sonnet".to_string(), + "azure/gpt-4o".to_string(), + ]; + + let result = selector.select( + RetryStrategy::DifferentProvider, + "openai/gpt-4o", + &fallback_models, + &providers, + &attempted, + &RetryAfterStateManager::new(), + &lb_state, + &ctx, + true, + true, + ); + + // anthropic LB-blocked → skip to azure + match result.unwrap() { + ProviderSelectionResult::Selected(p) => { + assert_eq!(p.model.as_deref(), Some("azure/gpt-4o")); + } + _ => panic!("expected Selected"), + } + } + + #[test] + fn both_ra_and_lb_block_skips_candidate() { + // When both RA and LB block a candidate, skip it (either block is sufficient). 
+ let hl_config = make_hl_config(BlockScope::Model, ApplyTo::Global); + let providers = vec![ + make_provider_with_hl_config("openai/gpt-4o", None, Some(hl_config)), + make_provider("anthropic/claude-3-5-sonnet"), + make_provider("azure/gpt-4o"), + ]; + let mut attempted = HashSet::new(); + attempted.insert("openai/gpt-4o".to_string()); + let ctx = stub_context(); + let selector = ProviderSelector; + let ra_state = RetryAfterStateManager::new(); + let lb_state = LatencyBlockStateManager::new(); + + // Block anthropic via BOTH RA and LB + ra_state.record("anthropic/claude-3-5-sonnet", 60, 300); + lb_state.record_block("anthropic/claude-3-5-sonnet", 60, 6000); + + let result = selector.select( + RetryStrategy::DifferentProvider, + "openai/gpt-4o", + &[], + &providers, + &attempted, + &ra_state, + &lb_state, + &ctx, + true, + true, + ); + + // Should skip anthropic (both blocked) and pick azure + match result.unwrap() { + ProviderSelectionResult::Selected(p) => { + assert_eq!(p.model.as_deref(), Some("azure/gpt-4o")); + } + _ => panic!("expected Selected"), + } + } + + #[test] + fn ra_only_block_still_skips_when_lb_not_blocked() { + // When only RA blocks a candidate (LB does not), still skip it. 
+ let hl_config = make_hl_config(BlockScope::Model, ApplyTo::Global); + let providers = vec![ + make_provider_with_hl_config("openai/gpt-4o", None, Some(hl_config)), + make_provider("anthropic/claude-3-5-sonnet"), + make_provider("azure/gpt-4o"), + ]; + let mut attempted = HashSet::new(); + attempted.insert("openai/gpt-4o".to_string()); + let ctx = stub_context(); + let selector = ProviderSelector; + let ra_state = RetryAfterStateManager::new(); + + // Block anthropic via RA only + ra_state.record("anthropic/claude-3-5-sonnet", 60, 300); + + let result = selector.select( + RetryStrategy::DifferentProvider, + "openai/gpt-4o", + &[], + &providers, + &attempted, + &ra_state, + &LatencyBlockStateManager::new(), + &ctx, + true, + true, + ); + + // Should skip anthropic (RA-blocked) and pick azure + match result.unwrap() { + ProviderSelectionResult::Selected(p) => { + assert_eq!(p.model.as_deref(), Some("azure/gpt-4o")); + } + _ => panic!("expected Selected"), + } + } + + #[test] + fn lb_only_block_still_skips_when_ra_not_blocked() { + // When only LB blocks a candidate (RA does not), still skip it. 
+ let hl_config = make_hl_config(BlockScope::Model, ApplyTo::Global); + let providers = vec![ + make_provider_with_hl_config("openai/gpt-4o", None, Some(hl_config)), + make_provider("anthropic/claude-3-5-sonnet"), + make_provider("azure/gpt-4o"), + ]; + let mut attempted = HashSet::new(); + attempted.insert("openai/gpt-4o".to_string()); + let ctx = stub_context(); + let selector = ProviderSelector; + let lb_state = LatencyBlockStateManager::new(); + + // Block anthropic via LB only + lb_state.record_block("anthropic/claude-3-5-sonnet", 60, 6000); + + let result = selector.select( + RetryStrategy::DifferentProvider, + "openai/gpt-4o", + &[], + &providers, + &attempted, + &RetryAfterStateManager::new(), + &lb_state, + &ctx, + true, + true, + ); + + // Should skip anthropic (LB-blocked) and pick azure + match result.unwrap() { + ProviderSelectionResult::Selected(p) => { + assert_eq!(p.model.as_deref(), Some("azure/gpt-4o")); + } + _ => panic!("expected Selected"), + } + } + + #[test] + fn all_candidates_lb_blocked_returns_error_with_shortest_remaining() { + // When all candidates are LB-blocked, return error with shortest_remaining_block_seconds. 
+ let hl_config = make_hl_config(BlockScope::Model, ApplyTo::Global); + let providers = vec![ + make_provider_with_hl_config("openai/gpt-4o", None, Some(hl_config)), + make_provider("anthropic/claude-3-5-sonnet"), + make_provider("azure/gpt-4o"), + ]; + let mut attempted = HashSet::new(); + attempted.insert("openai/gpt-4o".to_string()); + let ctx = stub_context(); + let selector = ProviderSelector; + let lb_state = LatencyBlockStateManager::new(); + + // Block both alternative providers via LB + lb_state.record_block("anthropic/claude-3-5-sonnet", 60, 6000); + lb_state.record_block("azure/gpt-4o", 30, 6000); + + let result = selector.select( + RetryStrategy::DifferentProvider, + "openai/gpt-4o", + &[], + &providers, + &attempted, + &RetryAfterStateManager::new(), + &lb_state, + &ctx, + true, + true, + ); + + match result { + Err(e) => { + assert!(e.shortest_remaining_block_seconds.is_some()); + let shortest = e.shortest_remaining_block_seconds.unwrap(); + assert!(shortest <= 30, "shortest remaining should be <= 30s, got {}", shortest); + } + Ok(_) => panic!("expected AllProvidersExhaustedError"), + } + } + + #[test] + fn same_model_both_ra_and_lb_blocked_ra_takes_precedence() { + // When same_model and both RA and LB block the model, + // RA check happens first → returns WaitAndRetrySameModel. 
+ let hl_config = make_hl_config(BlockScope::Model, ApplyTo::Global); + let ra_config = RetryAfterHandlingConfig { + scope: BlockScope::Model, + apply_to: ApplyTo::Global, + max_retry_after_seconds: 300, + }; + let providers = vec![make_provider_with_hl_config( + "openai/gpt-4o", + Some(ra_config), + Some(hl_config), + )]; + let attempted = HashSet::new(); + let ctx = stub_context(); + let selector = ProviderSelector; + let ra_state = RetryAfterStateManager::new(); + let lb_state = LatencyBlockStateManager::new(); + + // Block via both RA and LB + ra_state.record("openai/gpt-4o", 60, 300); + lb_state.record_block("openai/gpt-4o", 60, 6000); + + let result = selector.select( + RetryStrategy::SameModel, + "openai/gpt-4o", + &[], + &providers, + &attempted, + &ra_state, + &lb_state, + &ctx, + true, + true, + ); + + // RA check happens first → WaitAndRetrySameModel + match result.unwrap() { + ProviderSelectionResult::WaitAndRetrySameModel { wait_duration } => { + assert!(wait_duration.as_secs() > 0); + } + _ => panic!("expected WaitAndRetrySameModel when both RA and LB block same_model"), + } + } + + #[test] + fn same_model_provider_scope_lb_block_returns_error() { + // When same_model strategy with provider-scope LB block, + // blocking the provider should return AllProvidersExhaustedError. 
+ let hl_config = make_hl_config(BlockScope::Provider, ApplyTo::Global); + let providers = vec![make_provider_with_hl_config( + "openai/gpt-4o", + None, + Some(hl_config), + )]; + let attempted = HashSet::new(); + let ctx = stub_context(); + let selector = ProviderSelector; + let lb_state = LatencyBlockStateManager::new(); + + // Block at provider level + lb_state.record_block("openai", 45, 6000); + + let result = selector.select( + RetryStrategy::SameModel, + "openai/gpt-4o", + &[], + &providers, + &attempted, + &RetryAfterStateManager::new(), + &lb_state, + &ctx, + true, + true, + ); + + match result { + Err(e) => { + assert!(e.shortest_remaining_block_seconds.is_some()); + let secs = e.shortest_remaining_block_seconds.unwrap(); + assert!(secs > 0 && secs <= 45); + } + Ok(_) => panic!("expected AllProvidersExhaustedError for provider-scope LB block"), + } + } + + // --- Proptest strategies --- + + /// Generates a provider prefix from a fixed set. + fn arb_prefix() -> impl Strategy { + prop_oneof![ + Just("openai".to_string()), + Just("anthropic".to_string()), + Just("azure".to_string()), + ] + } + + /// Generates a model identifier like "openai/gpt-4o". + fn arb_model_id() -> impl Strategy { + (arb_prefix(), prop_oneof![ + Just("model-a".to_string()), + Just("model-b".to_string()), + Just("model-c".to_string()), + ]) + .prop_map(|(prefix, model)| format!("{}/{}", prefix, model)) + } + + /// Generates a non-empty list of providers (1..=6). + fn arb_provider_list() -> impl Strategy> { + proptest::collection::vec(arb_model_id(), 1..=6) + .prop_map(|ids| ids.into_iter().map(|id| make_provider(&id)).collect()) + } + + + + // Feature: retry-on-ratelimit, Property 11: Strategy-Correct Provider Selection + // **Validates: Requirements 3.10, 3.11, 3.12, 3.13, 6.2, 6.3, 6.4** + proptest! { + #![proptest_config(ProptestConfig::with_cases(100))] + + /// Property 11 – Case 1: SameModel returns the provider whose model matches primary_model. 
+ #[test] + fn prop_same_model_returns_matching_or_exhausted( + providers in arb_provider_list(), + attempted_indices in proptest::collection::hash_set(0usize..6, 0..=3), + ) { + let primary_model = providers[0].model.as_deref().unwrap(); + let primary_model_owned = primary_model.to_string(); + let attempted: HashSet = attempted_indices + .into_iter() + .filter_map(|i| providers.get(i).and_then(|p| p.model.clone())) + .collect(); + let ctx = stub_context(); + let selector = ProviderSelector; + + let result = selector.select( + RetryStrategy::SameModel, + &primary_model_owned, + &[], + &providers, + &attempted, + &RetryAfterStateManager::new(), + &LatencyBlockStateManager::new(), + &ctx, + false, + false, + ); + + match result { + Ok(ProviderSelectionResult::Selected(p)) => { + // SameModel: selected provider's model must equal primary_model + prop_assert_eq!( + p.model.as_deref(), + Some(primary_model_owned.as_str()), + "SameModel selected a different model: {:?} vs {}", + p.model, primary_model_owned + ); + } + Ok(ProviderSelectionResult::WaitAndRetrySameModel { .. }) => { + // Acceptable in P1/P2 when RA-blocked; not expected in P0 but valid. + } + Err(_) => { + // All matching candidates must have been attempted + let has_unattempted = providers.iter().any(|p| { + p.model.as_deref() == Some(primary_model_owned.as_str()) + && !attempted.contains(&primary_model_owned) + }); + prop_assert!( + !has_unattempted, + "SameModel returned Err but unattempted candidate exists" + ); + } + } + } + + /// Property 11 – Case 2: SameProvider returns a provider with the same prefix as primary_model. 
+ #[test] + fn prop_same_provider_selects_matching_prefix( + providers in arb_provider_list(), + attempted_indices in proptest::collection::hash_set(0usize..6, 0..=3), + ) { + let primary_model = providers[0].model.as_deref().unwrap(); + let primary_model_owned = primary_model.to_string(); + let primary_prefix = extract_provider(&primary_model_owned).to_string(); + let attempted: HashSet = attempted_indices + .into_iter() + .filter_map(|i| providers.get(i).and_then(|p| p.model.clone())) + .collect(); + let ctx = stub_context(); + let selector = ProviderSelector; + + let result = selector.select( + RetryStrategy::SameProvider, + &primary_model_owned, + &[], + &providers, + &attempted, + &RetryAfterStateManager::new(), + &LatencyBlockStateManager::new(), + &ctx, + false, + false, + ); + + match result { + Ok(ProviderSelectionResult::Selected(p)) => { + let selected_model = p.model.as_deref().unwrap(); + let selected_prefix = extract_provider(selected_model); + prop_assert_eq!( + selected_prefix, primary_prefix.as_str(), + "SameProvider selected different prefix: {} vs {}", + selected_prefix, primary_prefix + ); + prop_assert!( + !attempted.contains(selected_model), + "SameProvider selected an already-attempted provider: {}", + selected_model + ); + } + Ok(ProviderSelectionResult::WaitAndRetrySameModel { .. }) => { + // Not expected for SameProvider in P0, but valid variant. + } + Err(_) => { + // All same-prefix candidates must have been attempted + let has_unattempted = providers.iter().any(|p| { + if let Some(ref m) = p.model { + extract_provider(m) == primary_prefix + && !attempted.contains(m.as_str()) + } else { + false + } + }); + prop_assert!( + !has_unattempted, + "SameProvider returned Err but unattempted same-prefix candidate exists" + ); + } + } + } + + /// Property 11 – Case 3: DifferentProvider returns a provider with a different prefix than primary_model. 
+ #[test] + fn prop_different_provider_selects_different_prefix( + providers in arb_provider_list(), + attempted_indices in proptest::collection::hash_set(0usize..6, 0..=3), + ) { + let primary_model = providers[0].model.as_deref().unwrap(); + let primary_model_owned = primary_model.to_string(); + let primary_prefix = extract_provider(&primary_model_owned).to_string(); + let attempted: HashSet = attempted_indices + .into_iter() + .filter_map(|i| providers.get(i).and_then(|p| p.model.clone())) + .collect(); + let ctx = stub_context(); + let selector = ProviderSelector; + + let result = selector.select( + RetryStrategy::DifferentProvider, + &primary_model_owned, + &[], + &providers, + &attempted, + &RetryAfterStateManager::new(), + &LatencyBlockStateManager::new(), + &ctx, + false, + false, + ); + + match result { + Ok(ProviderSelectionResult::Selected(p)) => { + let selected_model = p.model.as_deref().unwrap(); + let selected_prefix = extract_provider(selected_model); + prop_assert_ne!( + selected_prefix, primary_prefix.as_str(), + "DifferentProvider selected same prefix: {} vs {}", + selected_prefix, primary_prefix + ); + prop_assert!( + !attempted.contains(selected_model), + "DifferentProvider selected an already-attempted provider: {}", + selected_model + ); + } + Ok(ProviderSelectionResult::WaitAndRetrySameModel { .. }) => { + // Not expected for DifferentProvider, but valid variant. 
+ } + Err(_) => { + // All different-prefix candidates must have been attempted + let has_unattempted = providers.iter().any(|p| { + if let Some(ref m) = p.model { + extract_provider(m) != primary_prefix + && !attempted.contains(m.as_str()) + } else { + false + } + }); + prop_assert!( + !has_unattempted, + "DifferentProvider returned Err but unattempted different-prefix candidate exists" + ); + } + } + } + } + + // Feature: retry-on-ratelimit, Property 10: Fallback Models Priority Ordering + // **Validates: Requirements 3.10, 3.11, 3.12, 3.13, 6.2, 6.3, 6.4** + // + // For any provider selection where fallback_models is non-empty, the selector + // must try models from fallback_models in their defined order before considering + // models from the general Provider_List. A model should only be skipped if it + // has already been attempted or is blocked. + proptest! { + #![proptest_config(ProptestConfig::with_cases(100))] + + #[test] + fn prop_fallback_models_priority_ordering( + all_providers in arb_provider_list(), + fallback_indices in proptest::collection::vec(0usize..6, 0..=4), + attempted_indices in proptest::collection::hash_set(0usize..6, 0..=3), + strategy in prop_oneof![ + Just(RetryStrategy::SameProvider), + Just(RetryStrategy::DifferentProvider), + ], + ) { + // Use first provider as primary model. + let primary_model = all_providers[0].model.as_deref().unwrap().to_string(); + let primary_prefix = extract_provider(&primary_model).to_string(); + + // Build fallback_models from indices into all_providers (may reference + // models not in all_providers if index is out of range — that's fine, + // those get skipped). + let fallback_models: Vec = fallback_indices + .iter() + .filter_map(|&i| all_providers.get(i).and_then(|p| p.model.clone())) + .collect(); + + // Build attempted set from indices. 
+ let attempted: HashSet = attempted_indices + .iter() + .filter_map(|&i| all_providers.get(i).and_then(|p| p.model.clone())) + .collect(); + + let ctx = stub_context(); + let selector = ProviderSelector; + + let result = selector.select( + strategy, + &primary_model, + &fallback_models, + &all_providers, + &attempted, + &RetryAfterStateManager::new(), + &LatencyBlockStateManager::new(), + &ctx, + false, + false, + ); + + // Determine which fallback models are eligible: present in + // all_providers, match strategy, and not attempted. + let matches_strategy = |model_id: &str| -> bool { + let prefix = extract_provider(model_id); + match strategy { + RetryStrategy::SameProvider => prefix == primary_prefix, + RetryStrategy::DifferentProvider => prefix != primary_prefix, + _ => unreachable!(), + } + }; + + let first_eligible_fallback: Option<&str> = fallback_models.iter().find_map(|fm| { + if attempted.contains(fm.as_str()) { + return None; + } + if !matches_strategy(fm) { + return None; + } + // Must exist in all_providers. + if all_providers.iter().any(|p| p.model.as_deref() == Some(fm.as_str())) { + Some(fm.as_str()) + } else { + None + } + }); + + // First eligible Provider_List candidate (not in fallback, or any + // eligible candidate from Provider_List order). + let first_eligible_provider_list: Option<&str> = all_providers.iter().find_map(|p| { + if let Some(ref m) = p.model { + if matches_strategy(m) && !attempted.contains(m.as_str()) { + Some(m.as_str()) + } else { + None + } + } else { + None + } + }); + + match result { + Ok(ProviderSelectionResult::Selected(p)) => { + let selected = p.model.as_deref().unwrap(); + + if let Some(expected_fallback) = first_eligible_fallback { + // If there's an eligible fallback, it MUST be selected + // (priority over Provider_List). + prop_assert_eq!( + selected, expected_fallback, + "Expected first eligible fallback '{}' but got '{}'. 
\ + fallback_models={:?}, attempted={:?}, strategy={:?}", + expected_fallback, selected, fallback_models, attempted, strategy + ); + } else { + // No eligible fallback → must come from Provider_List. + // The selected model must match strategy and not be attempted. + prop_assert!( + matches_strategy(selected), + "Selected '{}' doesn't match strategy {:?}", + selected, strategy + ); + prop_assert!( + !attempted.contains(selected), + "Selected '{}' was already attempted", + selected + ); + // Should be the first eligible from Provider_List order. + if let Some(expected_pl) = first_eligible_provider_list { + prop_assert_eq!( + selected, expected_pl, + "Expected first Provider_List candidate '{}' but got '{}'", + expected_pl, selected + ); + } + } + } + Ok(ProviderSelectionResult::WaitAndRetrySameModel { .. }) => { + // Not expected for SameProvider/DifferentProvider, but valid variant. + } + Err(_) => { + // No eligible candidate at all — verify that's correct. + prop_assert!( + first_eligible_fallback.is_none(), + "Returned Err but eligible fallback exists: {:?}", + first_eligible_fallback + ); + prop_assert!( + first_eligible_provider_list.is_none(), + "Returned Err but eligible Provider_List candidate exists: {:?}", + first_eligible_provider_list + ); + } + } + } + } + + // Feature: retry-on-ratelimit, Property 7: Cooldown Exclusion Invariant (CP-1) + // **Validates: Requirements 6.5, 11.5, 11.6, 12.6, 12.7, 13.1, 13.3, 13.4, 13.9, CP-1** + // + // For any model/provider with an active Retry_After_State entry (expires_at > now), + // that model/provider must NOT be selected by ProviderSelector. For same_model strategy, + // WaitAndRetrySameModel is returned instead. Once expired, the model must be eligible again. + proptest! { + #![proptest_config(ProptestConfig::with_cases(100))] + + /// Property 7 – Case 1: Blocked models are never returned as Selected + /// for SameProvider / DifferentProvider strategies. 
+ #[test] + fn prop_cooldown_exclusion_blocked_never_selected( + all_providers in proptest::collection::vec(arb_model_id(), 2..=6) + .prop_map(|ids| { + ids.into_iter() + .map(|id| make_provider_with_retry_policy(&id, None)) + .collect::>() + }), + // Indices of providers to block via RA state + block_indices in proptest::collection::hash_set(0usize..6, 1..=3), + strategy in prop_oneof![ + Just(RetryStrategy::SameProvider), + Just(RetryStrategy::DifferentProvider), + ], + ) { + let primary_model = all_providers[0].model.as_deref().unwrap().to_string(); + let attempted = HashSet::new(); + let ctx = stub_context(); + let selector = ProviderSelector; + let ra_state = RetryAfterStateManager::new(); + + // Block selected providers with a long RA duration + let blocked_models: HashSet = block_indices + .iter() + .filter_map(|&i| all_providers.get(i).and_then(|p| p.model.clone())) + .collect(); + + for model_id in &blocked_models { + ra_state.record(model_id, 600, 600); + } + + let result = selector.select( + strategy, + &primary_model, + &[], + &all_providers, + &attempted, + &ra_state, + &LatencyBlockStateManager::new(), + &ctx, + true, // has_retry_policy = true to enable RA checks + false, + ); + + match result { + Ok(ProviderSelectionResult::Selected(p)) => { + let selected = p.model.as_deref().unwrap(); + prop_assert!( + !blocked_models.contains(selected), + "Blocked model '{}' was returned as Selected! blocked={:?}", + selected, blocked_models + ); + } + Ok(ProviderSelectionResult::WaitAndRetrySameModel { .. }) => { + // Not expected for SameProvider/DifferentProvider, but acceptable. + } + Err(_) => { + // All eligible candidates were blocked or exhausted — valid. + } + } + } + + /// Property 7 – Case 2: For same_model strategy with RA block, + /// WaitAndRetrySameModel is returned (not Selected). 
+ #[test] + fn prop_cooldown_exclusion_same_model_returns_wait( + model_id in arb_model_id(), + block_seconds in 1u64..=300, + ) { + let providers = vec![ + make_provider_with_retry_policy(&model_id, None), + ]; + let attempted = HashSet::new(); + let ctx = stub_context(); + let selector = ProviderSelector; + let ra_state = RetryAfterStateManager::new(); + + // Block the model + ra_state.record(&model_id, block_seconds, 300); + + let result = selector.select( + RetryStrategy::SameModel, + &model_id, + &[], + &providers, + &attempted, + &ra_state, + &LatencyBlockStateManager::new(), + &ctx, + true, + false, + ); + + match result { + Ok(ProviderSelectionResult::WaitAndRetrySameModel { wait_duration }) => { + // Duration must be positive and bounded by block_seconds + let capped = block_seconds.min(300); + prop_assert!( + wait_duration.as_secs() <= capped, + "wait_duration {}s exceeds capped block {}s", + wait_duration.as_secs(), capped + ); + prop_assert!( + !wait_duration.is_zero(), + "wait_duration should be positive for an active block" + ); + } + Ok(ProviderSelectionResult::Selected(_)) => { + prop_assert!(false, "Blocked model should not be Selected for same_model strategy"); + } + Err(_) => { + prop_assert!(false, "same_model with blocked model should return WaitAndRetrySameModel, not Err"); + } + } + } + + /// Property 7 – Case 3: Blocked models in fallback_models are skipped. 
+ #[test] + fn prop_cooldown_exclusion_fallback_blocked_skipped( + all_providers in proptest::collection::vec(arb_model_id(), 3..=6) + .prop_map(|ids| { + ids.into_iter() + .map(|id| make_provider_with_retry_policy(&id, None)) + .collect::>() + }), + // Block the first 1-2 fallback candidates + num_blocked in 1usize..=2, + strategy in prop_oneof![ + Just(RetryStrategy::SameProvider), + Just(RetryStrategy::DifferentProvider), + ], + ) { + let primary_model = all_providers[0].model.as_deref().unwrap().to_string(); + let attempted = HashSet::new(); + let ctx = stub_context(); + let selector = ProviderSelector; + let ra_state = RetryAfterStateManager::new(); + + // Build fallback_models from providers (skip primary) + let fallback_models: Vec = all_providers[1..] + .iter() + .filter_map(|p| p.model.clone()) + .collect(); + + // Block the first num_blocked fallback models + let blocked_models: HashSet = fallback_models + .iter() + .take(num_blocked.min(fallback_models.len())) + .cloned() + .collect(); + + for model_id in &blocked_models { + ra_state.record(model_id, 600, 600); + } + + let result = selector.select( + strategy, + &primary_model, + &fallback_models, + &all_providers, + &attempted, + &ra_state, + &LatencyBlockStateManager::new(), + &ctx, + true, + false, + ); + + match result { + Ok(ProviderSelectionResult::Selected(p)) => { + let selected = p.model.as_deref().unwrap(); + prop_assert!( + !blocked_models.contains(selected), + "Blocked fallback model '{}' was selected! blocked={:?}", + selected, blocked_models + ); + } + Ok(ProviderSelectionResult::WaitAndRetrySameModel { .. }) => { + // Not expected for these strategies, but acceptable. + } + Err(_) => { + // All eligible candidates blocked or exhausted — valid. + } + } + } + + /// Property 7 – Case 4: After RA expiration, model becomes selectable again. + /// We use a 0-second block which expires immediately. 
+ #[test] + fn prop_cooldown_exclusion_unblocked_after_expiration( + model_id in arb_model_id(), + strategy in prop_oneof![ + Just(RetryStrategy::SameModel), + Just(RetryStrategy::SameProvider), + Just(RetryStrategy::DifferentProvider), + ], + ) { + let providers = vec![ + make_provider_with_retry_policy(&model_id, None), + ]; + let attempted = HashSet::new(); + let ctx = stub_context(); + let selector = ProviderSelector; + let ra_state = RetryAfterStateManager::new(); + + // Record with 0 seconds — expires immediately + ra_state.record(&model_id, 0, 300); + + // The model should NOT be blocked (expired immediately) + prop_assert!( + !ra_state.is_blocked(&model_id), + "Model should not be blocked after 0-second RA record" + ); + + let result = selector.select( + strategy, + &model_id, + &[], + &providers, + &attempted, + &ra_state, + &LatencyBlockStateManager::new(), + &ctx, + true, + false, + ); + + // For any strategy, the model should be selectable (not blocked) + match result { + Ok(ProviderSelectionResult::Selected(p)) => { + prop_assert_eq!( + p.model.as_deref(), + Some(model_id.as_str()), + "Expected the unblocked model to be selected" + ); + } + Ok(ProviderSelectionResult::WaitAndRetrySameModel { .. }) => { + prop_assert!(false, "Expired RA should not trigger WaitAndRetrySameModel"); + } + Err(_) => { + // For DifferentProvider strategy, the single provider may not match + // (same prefix as primary). This is a strategy mismatch, not a block issue. + // Only fail if strategy should have matched. + match strategy { + RetryStrategy::SameModel | RetryStrategy::SameProvider => { + prop_assert!(false, "Unblocked model should be selectable for {:?}", strategy); + } + RetryStrategy::DifferentProvider => { + // Expected: single provider can't match "different provider" strategy. 
+ } + } + } + } + } + } + + // Feature: retry-on-ratelimit, Property 19: Latency Block Exclusion During Provider Selection + // **Validates: Requirements 6.7, 6.8, 15.1, 15.3, 15.4, 15.12, 15.13** + // + // For any model/provider with an active Latency_Block_State entry (expires_at > now), + // that model/provider must be skipped during provider selection (both initial and retry). + // When both Retry_After_State and Latency_Block_State exist for the same identifier, + // the candidate must be skipped if either state indicates blocking. + proptest! { + #![proptest_config(ProptestConfig::with_cases(100))] + + /// Property 19 – Case 1: LB-blocked models are never returned as Selected + /// for SameProvider / DifferentProvider strategies. + #[test] + fn prop_lb_blocked_never_selected( + all_providers in proptest::collection::vec(arb_model_id(), 2..=6) + .prop_map(|ids| { + ids.into_iter() + .map(|id| make_provider_with_hl_config( + &id, + None, + Some(make_hl_config(BlockScope::Model, ApplyTo::Global)), + )) + .collect::>() + }), + block_indices in proptest::collection::hash_set(0usize..6, 1..=3), + strategy in prop_oneof![ + Just(RetryStrategy::SameProvider), + Just(RetryStrategy::DifferentProvider), + ], + ) { + let primary_model = all_providers[0].model.as_deref().unwrap().to_string(); + let attempted = HashSet::new(); + let ctx = stub_context(); + let selector = ProviderSelector; + let lb_state = LatencyBlockStateManager::new(); + + // Block selected providers with a long LB duration + let blocked_models: HashSet = block_indices + .iter() + .filter_map(|&i| all_providers.get(i).and_then(|p| p.model.clone())) + .collect(); + + for model_id in &blocked_models { + lb_state.record_block(model_id, 600, 8000); + } + + let result = selector.select( + strategy, + &primary_model, + &[], + &all_providers, + &attempted, + &RetryAfterStateManager::new(), + &lb_state, + &ctx, + false, + true, // has_high_latency_config = true to enable LB checks + ); + + match result { + 
Ok(ProviderSelectionResult::Selected(p)) => { + let selected = p.model.as_deref().unwrap(); + prop_assert!( + !blocked_models.contains(selected), + "LB-blocked model '{}' was returned as Selected! blocked={:?}", + selected, blocked_models + ); + } + Ok(ProviderSelectionResult::WaitAndRetrySameModel { .. }) => { + // Not expected for SameProvider/DifferentProvider, but acceptable. + } + Err(_) => { + // All eligible candidates were blocked or exhausted — valid. + } + } + } + + /// Property 19 – Case 2: For same_model strategy with LB block, + /// AllProvidersExhaustedError is returned (skip to alternative, not wait). + #[test] + fn prop_lb_blocked_same_model_returns_error( + model_id in arb_model_id(), + block_seconds in 1u64..=300, + ) { + let providers = vec![ + make_provider_with_hl_config( + &model_id, + None, + Some(make_hl_config(BlockScope::Model, ApplyTo::Global)), + ), + ]; + let attempted = HashSet::new(); + let ctx = stub_context(); + let selector = ProviderSelector; + let lb_state = LatencyBlockStateManager::new(); + + // Block the model + lb_state.record_block(&model_id, block_seconds, 8000); + + let result = selector.select( + RetryStrategy::SameModel, + &model_id, + &[], + &providers, + &attempted, + &RetryAfterStateManager::new(), + &lb_state, + &ctx, + false, + true, + ); + + match result { + Err(_) => { + // Expected: same_model with LB block returns error (skip to alternative) + } + Ok(ProviderSelectionResult::Selected(_)) => { + prop_assert!(false, "LB-blocked model should not be Selected for same_model strategy"); + } + Ok(ProviderSelectionResult::WaitAndRetrySameModel { .. }) => { + prop_assert!(false, "LB block should return error, not WaitAndRetrySameModel (unlike RA)"); + } + } + } + + /// Property 19 – Case 3: When both RA and LB exist for the same identifier, + /// the candidate is skipped if either blocks. 
+ #[test] + fn prop_both_ra_and_lb_either_blocks_skips( + all_providers in proptest::collection::vec(arb_model_id(), 2..=6) + .prop_map(|ids| { + ids.into_iter() + .map(|id| make_provider_with_hl_config( + &id, + None, + Some(make_hl_config(BlockScope::Model, ApplyTo::Global)), + )) + .collect::>() + }), + block_index in 0usize..6, + // Which state(s) to block: 0 = RA only, 1 = LB only, 2 = both + block_type in 0u8..3, + strategy in prop_oneof![ + Just(RetryStrategy::SameProvider), + Just(RetryStrategy::DifferentProvider), + ], + ) { + let primary_model = all_providers[0].model.as_deref().unwrap().to_string(); + let attempted = HashSet::new(); + let ctx = stub_context(); + let selector = ProviderSelector; + let ra_state = RetryAfterStateManager::new(); + let lb_state = LatencyBlockStateManager::new(); + + // Pick a model to block (clamped to valid index) + let target_index = block_index % all_providers.len(); + let target_model = all_providers[target_index].model.as_deref().unwrap().to_string(); + + match block_type { + 0 => { + // RA only + ra_state.record(&target_model, 600, 600); + } + 1 => { + // LB only + lb_state.record_block(&target_model, 600, 8000); + } + _ => { + // Both RA and LB + ra_state.record(&target_model, 600, 600); + lb_state.record_block(&target_model, 600, 8000); + } + } + + let result = selector.select( + strategy, + &primary_model, + &[], + &all_providers, + &attempted, + &ra_state, + &lb_state, + &ctx, + true, // has_retry_policy = true to enable RA checks + true, // has_high_latency_config = true to enable LB checks + ); + + match result { + Ok(ProviderSelectionResult::Selected(p)) => { + let selected = p.model.as_deref().unwrap(); + prop_assert!( + selected != target_model, + "Blocked model '{}' was selected despite block_type={}! \ + (0=RA, 1=LB, 2=both)", + target_model, block_type + ); + } + Ok(ProviderSelectionResult::WaitAndRetrySameModel { .. }) => { + // Not expected for SameProvider/DifferentProvider. 
+ } + Err(_) => { + // All eligible candidates blocked or exhausted — valid. + } + } + } + + /// Property 19 – Case 4: After LB expiration, model becomes selectable again. + /// We use a 0-second block which expires immediately. + #[test] + fn prop_lb_unblocked_after_expiration( + model_id in arb_model_id(), + strategy in prop_oneof![ + Just(RetryStrategy::SameModel), + Just(RetryStrategy::SameProvider), + Just(RetryStrategy::DifferentProvider), + ], + ) { + let providers = vec![ + make_provider_with_hl_config( + &model_id, + None, + Some(make_hl_config(BlockScope::Model, ApplyTo::Global)), + ), + ]; + let attempted = HashSet::new(); + let ctx = stub_context(); + let selector = ProviderSelector; + let lb_state = LatencyBlockStateManager::new(); + + // Record with 0 seconds — expires immediately + lb_state.record_block(&model_id, 0, 8000); + + // The model should NOT be blocked (expired immediately) + prop_assert!( + !lb_state.is_blocked(&model_id), + "Model should not be blocked after 0-second LB record" + ); + + let result = selector.select( + strategy, + &model_id, + &[], + &providers, + &attempted, + &RetryAfterStateManager::new(), + &lb_state, + &ctx, + false, + true, + ); + + match result { + Ok(ProviderSelectionResult::Selected(p)) => { + prop_assert_eq!( + p.model.as_deref(), + Some(model_id.as_str()), + "Expected the unblocked model to be selected" + ); + } + Ok(ProviderSelectionResult::WaitAndRetrySameModel { .. }) => { + prop_assert!(false, "Expired LB should not trigger WaitAndRetrySameModel"); + } + Err(_) => { + match strategy { + RetryStrategy::SameModel | RetryStrategy::SameProvider => { + prop_assert!(false, "Unblocked model should be selectable for {:?}", strategy); + } + RetryStrategy::DifferentProvider => { + // Expected: single provider can't match "different provider" strategy. 
+ } + } + } + } + } + } + + // Feature: retry-on-ratelimit, Property 9: Cooldown Applies to Initial Provider Selection (CP-3) + // **Validates: Requirements 13.1, 13.12, CP-3** + // + // For any new request (not a retry) targeting a model that has an active + // Retry_After_State entry with apply_to: "global", the ProviderSelector must + // skip that model during initial provider selection and route to an alternative + // model, without first attempting the blocked model. + proptest! { + #![proptest_config(ProptestConfig::with_cases(100))] + + /// Property 9 – Case 1: Default model is globally RA-blocked → + /// new request with same_model strategy gets WaitAndRetrySameModel. + #[test] + fn prop_initial_selection_cooldown_same_model( + model_id in arb_model_id(), + block_seconds in 1u64..=300, + ) { + let providers = vec![ + make_provider_with_retry_policy(&model_id, Some(RetryAfterHandlingConfig { + scope: BlockScope::Model, + apply_to: ApplyTo::Global, + max_retry_after_seconds: 300, + })), + ]; + // Empty attempted set = brand new request (initial selection) + let attempted = HashSet::new(); + let ctx = stub_context(); + let selector = ProviderSelector; + let ra_state = RetryAfterStateManager::new(); + + // Block the default model globally + ra_state.record(&model_id, block_seconds, 300); + + let result = selector.select( + RetryStrategy::SameModel, + &model_id, + &[], + &providers, + &attempted, + &ra_state, + &LatencyBlockStateManager::new(), + &ctx, + true, + false, + ); + + // For same_model with global RA block, must return WaitAndRetrySameModel + match result { + Ok(ProviderSelectionResult::WaitAndRetrySameModel { wait_duration }) => { + let capped = block_seconds.min(300); + prop_assert!( + !wait_duration.is_zero(), + "wait_duration should be positive for an active block" + ); + prop_assert!( + wait_duration.as_secs() <= capped, + "wait_duration {}s exceeds capped block {}s", + wait_duration.as_secs(), capped + ); + } + 
Ok(ProviderSelectionResult::Selected(_)) => { + prop_assert!(false, + "Globally RA-blocked model should NOT be Selected on initial request \ + with same_model strategy; expected WaitAndRetrySameModel" + ); + } + Err(_) => { + prop_assert!(false, + "same_model with globally blocked model should return \ + WaitAndRetrySameModel, not AllProvidersExhausted" + ); + } + } + } + + /// Property 9 – Case 2: Default model is globally RA-blocked → + /// new request with different_provider strategy skips it and picks alternative. + #[test] + fn prop_initial_selection_cooldown_different_provider( + _primary_prefix in arb_prefix(), + alt_prefix in arb_prefix().prop_filter("must differ from primary", + |p| p != "openai"), // we'll force primary to "openai" + block_seconds in 1u64..=300, + ) { + let primary_model = format!("openai/model-a"); + let alt_model = format!("{}/model-b", alt_prefix); + + // Ensure alt is actually a different provider + if extract_provider(&alt_model) == extract_provider(&primary_model) { + // Skip this case — proptest will generate others + return Ok(()); + } + + let providers = vec![ + make_provider_with_retry_policy(&primary_model, Some(RetryAfterHandlingConfig { + scope: BlockScope::Model, + apply_to: ApplyTo::Global, + max_retry_after_seconds: 300, + })), + make_provider_with_retry_policy(&alt_model, None), + ]; + // Empty attempted set = brand new request + let attempted = HashSet::new(); + let ctx = stub_context(); + let selector = ProviderSelector; + let ra_state = RetryAfterStateManager::new(); + + // Block the primary/default model globally + ra_state.record(&primary_model, block_seconds, 300); + + let result = selector.select( + RetryStrategy::DifferentProvider, + &primary_model, + &[], + &providers, + &attempted, + &ra_state, + &LatencyBlockStateManager::new(), + &ctx, + true, + false, + ); + + match result { + Ok(ProviderSelectionResult::Selected(p)) => { + let selected = p.model.as_deref().unwrap(); + // Must NOT be the blocked primary model 
+ prop_assert_ne!( + selected, primary_model.as_str(), + "Blocked primary model was selected on initial request!" + ); + // Must be from a different provider (strategy constraint) + prop_assert_ne!( + extract_provider(selected), + extract_provider(&primary_model), + "DifferentProvider selected same provider prefix" + ); + } + Ok(ProviderSelectionResult::WaitAndRetrySameModel { .. }) => { + prop_assert!(false, + "DifferentProvider strategy should not return WaitAndRetrySameModel" + ); + } + Err(_) => { + // Only valid if the alt model also happens to be same provider + // (filtered out above) — should not happen. + prop_assert!(false, + "Should have selected alternative provider, not exhausted" + ); + } + } + } + + /// Property 9 – Case 3: Default model is globally RA-blocked → + /// new request with same_provider strategy skips it and picks same-provider alternative. + #[test] + fn prop_initial_selection_cooldown_same_provider( + prefix in arb_prefix(), + block_seconds in 1u64..=300, + ) { + let primary_model = format!("{}/model-a", prefix); + let alt_model = format!("{}/model-b", prefix); + + let providers = vec![ + make_provider_with_retry_policy(&primary_model, Some(RetryAfterHandlingConfig { + scope: BlockScope::Model, + apply_to: ApplyTo::Global, + max_retry_after_seconds: 300, + })), + make_provider_with_retry_policy(&alt_model, None), + ]; + // Empty attempted set = brand new request + let attempted = HashSet::new(); + let ctx = stub_context(); + let selector = ProviderSelector; + let ra_state = RetryAfterStateManager::new(); + + // Block the primary/default model globally (model-scope, not provider-scope) + ra_state.record(&primary_model, block_seconds, 300); + + let result = selector.select( + RetryStrategy::SameProvider, + &primary_model, + &[], + &providers, + &attempted, + &ra_state, + &LatencyBlockStateManager::new(), + &ctx, + true, + false, + ); + + match result { + Ok(ProviderSelectionResult::Selected(p)) => { + let selected = 
p.model.as_deref().unwrap(); + // Must NOT be the blocked primary model + prop_assert_ne!( + selected, primary_model.as_str(), + "Blocked primary model was selected on initial request!" + ); + // Must be from the same provider (strategy constraint) + prop_assert_eq!( + extract_provider(selected), + extract_provider(&primary_model), + "SameProvider selected different provider prefix" + ); + // Should be the alternative model + prop_assert_eq!( + selected, alt_model.as_str(), + "Expected the alternative same-provider model" + ); + } + Ok(ProviderSelectionResult::WaitAndRetrySameModel { .. }) => { + prop_assert!(false, + "SameProvider strategy should not return WaitAndRetrySameModel" + ); + } + Err(_) => { + prop_assert!(false, + "Should have selected same-provider alternative, not exhausted" + ); + } + } + } + } +} + diff --git a/crates/common/src/retry/retry_after_state.rs b/crates/common/src/retry/retry_after_state.rs index 5a2c43c1..c9e1a2c9 100644 --- a/crates/common/src/retry/retry_after_state.rs +++ b/crates/common/src/retry/retry_after_state.rs @@ -106,3 +106,410 @@ impl Default for RetryAfterStateManager { } } +#[cfg(test)] +mod tests { + use super::*; + use std::thread; + use std::time::Duration; + + #[test] + fn test_new_manager_has_no_blocks() { + let mgr = RetryAfterStateManager::new(); + assert!(!mgr.is_blocked("openai/gpt-4o")); + assert!(mgr.remaining_block_duration("openai/gpt-4o").is_none()); + } + + #[test] + fn test_record_and_is_blocked() { + let mgr = RetryAfterStateManager::new(); + mgr.record("openai/gpt-4o", 60, 300); + assert!(mgr.is_blocked("openai/gpt-4o")); + assert!(!mgr.is_blocked("anthropic/claude")); + } + + #[test] + fn test_record_caps_at_max() { + let mgr = RetryAfterStateManager::new(); + // Retry-After of 600 seconds, but max is 300 + mgr.record("openai/gpt-4o", 600, 300); + let remaining = mgr.remaining_block_duration("openai/gpt-4o").unwrap(); + // Should be capped at ~300 seconds (allow some tolerance) + assert!(remaining <= 
Duration::from_secs(301)); + assert!(remaining > Duration::from_secs(298)); + } + + #[test] + fn test_remaining_block_duration() { + let mgr = RetryAfterStateManager::new(); + mgr.record("openai/gpt-4o", 10, 300); + let remaining = mgr.remaining_block_duration("openai/gpt-4o").unwrap(); + assert!(remaining <= Duration::from_secs(11)); + assert!(remaining > Duration::from_secs(8)); + } + + #[test] + fn test_expired_entry_cleaned_up_on_is_blocked() { + let mgr = RetryAfterStateManager::new(); + // Record with 0 seconds — effectively expires immediately + mgr.record("openai/gpt-4o", 0, 300); + // Sleep briefly to ensure expiration + thread::sleep(Duration::from_millis(10)); + assert!(!mgr.is_blocked("openai/gpt-4o")); + } + + #[test] + fn test_expired_entry_cleaned_up_on_remaining() { + let mgr = RetryAfterStateManager::new(); + mgr.record("openai/gpt-4o", 0, 300); + thread::sleep(Duration::from_millis(10)); + assert!(mgr.remaining_block_duration("openai/gpt-4o").is_none()); + } + + #[test] + fn test_max_expiration_semantics_longer_wins() { + let mgr = RetryAfterStateManager::new(); + mgr.record("openai/gpt-4o", 10, 300); + let first_remaining = mgr.remaining_block_duration("openai/gpt-4o").unwrap(); + + // Record a longer duration — should update + mgr.record("openai/gpt-4o", 60, 300); + let second_remaining = mgr.remaining_block_duration("openai/gpt-4o").unwrap(); + assert!(second_remaining > first_remaining); + } + + #[test] + fn test_max_expiration_semantics_shorter_does_not_overwrite() { + let mgr = RetryAfterStateManager::new(); + mgr.record("openai/gpt-4o", 60, 300); + let first_remaining = mgr.remaining_block_duration("openai/gpt-4o").unwrap(); + + // Record a shorter duration — should NOT overwrite + mgr.record("openai/gpt-4o", 5, 300); + let second_remaining = mgr.remaining_block_duration("openai/gpt-4o").unwrap(); + // The remaining should still be close to the original 60s + assert!(second_remaining > Duration::from_secs(50)); + // Allow small timing 
variance + let diff = if first_remaining > second_remaining { + first_remaining - second_remaining + } else { + second_remaining - first_remaining + }; + assert!(diff < Duration::from_secs(2)); + } + + #[test] + fn test_is_model_blocked_model_scope() { + let mgr = RetryAfterStateManager::new(); + mgr.record("openai/gpt-4o", 60, 300); + + assert!(mgr.is_model_blocked("openai/gpt-4o", BlockScope::Model)); + assert!(!mgr.is_model_blocked("openai/gpt-4o-mini", BlockScope::Model)); + } + + #[test] + fn test_is_model_blocked_provider_scope() { + let mgr = RetryAfterStateManager::new(); + // Block at provider level by recording with provider prefix + mgr.record("openai", 60, 300); + + // Both openai models should be blocked + assert!(mgr.is_model_blocked("openai/gpt-4o", BlockScope::Provider)); + assert!(mgr.is_model_blocked("openai/gpt-4o-mini", BlockScope::Provider)); + // Anthropic should not be blocked + assert!(!mgr.is_model_blocked("anthropic/claude", BlockScope::Provider)); + } + + #[test] + fn test_model_scope_does_not_block_other_models() { + let mgr = RetryAfterStateManager::new(); + mgr.record("openai/gpt-4o", 60, 300); + + // Model scope: only exact match is blocked + assert!(mgr.is_model_blocked("openai/gpt-4o", BlockScope::Model)); + assert!(!mgr.is_model_blocked("openai/gpt-4o-mini", BlockScope::Model)); + } + + #[test] + fn test_multiple_identifiers_independent() { + let mgr = RetryAfterStateManager::new(); + mgr.record("openai/gpt-4o", 60, 300); + mgr.record("anthropic/claude", 30, 300); + + assert!(mgr.is_blocked("openai/gpt-4o")); + assert!(mgr.is_blocked("anthropic/claude")); + assert!(!mgr.is_blocked("azure/gpt-4o")); + } + + #[test] + fn test_record_with_zero_seconds() { + let mgr = RetryAfterStateManager::new(); + mgr.record("openai/gpt-4o", 0, 300); + // With 0 seconds, the entry expires at Instant::now() + 0, + // which is effectively immediately + thread::sleep(Duration::from_millis(5)); + assert!(!mgr.is_blocked("openai/gpt-4o")); + } + + 
#[test] + fn test_max_retry_after_seconds_zero_caps_to_zero() { + let mgr = RetryAfterStateManager::new(); + // Even with retry_after_seconds=60, max=0 caps to 0 + mgr.record("openai/gpt-4o", 60, 0); + thread::sleep(Duration::from_millis(5)); + assert!(!mgr.is_blocked("openai/gpt-4o")); + } + + #[test] + fn test_default_trait() { + let mgr = RetryAfterStateManager::default(); + assert!(!mgr.is_blocked("anything")); + } + + // --- Proptest strategies --- + + use proptest::prelude::*; + + fn arb_provider_prefix() -> impl Strategy { + prop_oneof![ + Just("openai".to_string()), + Just("anthropic".to_string()), + Just("azure".to_string()), + Just("google".to_string()), + Just("cohere".to_string()), + ] + } + + fn arb_model_suffix() -> impl Strategy { + prop_oneof![ + Just("gpt-4o".to_string()), + Just("gpt-4o-mini".to_string()), + Just("claude-3".to_string()), + Just("gemini-pro".to_string()), + ] + } + + fn arb_model_id() -> impl Strategy { + (arb_provider_prefix(), arb_model_suffix()) + .prop_map(|(prefix, suffix)| format!("{}/{}", prefix, suffix)) + } + + fn arb_scope() -> impl Strategy { + prop_oneof![Just(BlockScope::Model), Just(BlockScope::Provider),] + } + + // Feature: retry-on-ratelimit, Property 15: Retry_After_State Scope Behavior + // **Validates: Requirements 11.5, 11.6, 11.7, 11.8, 12.9, 12.10, 13.10, 13.11** + proptest! { + #![proptest_config(ProptestConfig::with_cases(100))] + + /// Property 15 – Case 1: Model scope blocks only the exact model_id. 
+ #[test] + fn prop_model_scope_blocks_exact_model_only( + model_id in arb_model_id(), + other_model_id in arb_model_id(), + retry_after in 1u64..300, + ) { + prop_assume!(model_id != other_model_id); + + let mgr = RetryAfterStateManager::new(); + // Record with the exact model_id (model scope records the full model ID) + mgr.record(&model_id, retry_after, 300); + + // The exact model should be blocked + prop_assert!( + mgr.is_model_blocked(&model_id, BlockScope::Model), + "Model {} should be blocked with Model scope after recording", + model_id + ); + + // A different model should NOT be blocked (even if same provider) + prop_assert!( + !mgr.is_model_blocked(&other_model_id, BlockScope::Model), + "Model {} should NOT be blocked when {} was recorded with Model scope", + other_model_id, model_id + ); + } + + /// Property 15 – Case 2: Provider scope blocks all models from the same provider. + #[test] + fn prop_provider_scope_blocks_all_same_provider_models( + provider in arb_provider_prefix(), + suffix1 in arb_model_suffix(), + suffix2 in arb_model_suffix(), + other_provider in arb_provider_prefix(), + other_suffix in arb_model_suffix(), + retry_after in 1u64..300, + ) { + let model1 = format!("{}/{}", provider, suffix1); + let model2 = format!("{}/{}", provider, suffix2); + let other_model = format!("{}/{}", other_provider, other_suffix); + prop_assume!(provider != other_provider); + + let mgr = RetryAfterStateManager::new(); + // Record at provider level (provider scope records the provider prefix) + mgr.record(&provider, retry_after, 300); + + // Both models from the same provider should be blocked + prop_assert!( + mgr.is_model_blocked(&model1, BlockScope::Provider), + "Model {} should be blocked with Provider scope after recording provider {}", + model1, provider + ); + prop_assert!( + mgr.is_model_blocked(&model2, BlockScope::Provider), + "Model {} should be blocked with Provider scope after recording provider {}", + model2, provider + ); + + // Model from a 
different provider should NOT be blocked + prop_assert!( + !mgr.is_model_blocked(&other_model, BlockScope::Provider), + "Model {} should NOT be blocked when provider {} was recorded", + other_model, provider + ); + } + + /// Property 15 – Case 3: Global state is visible across different "requests" + /// (same manager instance is shared). + #[test] + fn prop_global_state_shared_across_requests( + model_id in arb_model_id(), + scope in arb_scope(), + retry_after in 1u64..300, + ) { + let mgr = RetryAfterStateManager::new(); + + // Determine the identifier to record based on scope + let identifier = match scope { + BlockScope::Model => model_id.clone(), + BlockScope::Provider => extract_provider(&model_id).to_string(), + }; + mgr.record(&identifier, retry_after, 300); + + // Simulate "different requests" by checking from the same manager instance. + // Global state means any check against the same manager sees the block. + // Check 1 (simulating request A) + let blocked_a = mgr.is_model_blocked(&model_id, scope); + // Check 2 (simulating request B) + let blocked_b = mgr.is_model_blocked(&model_id, scope); + + prop_assert!( + blocked_a && blocked_b, + "Global state should be visible to all requests: request_a={}, request_b={}", + blocked_a, blocked_b + ); + } + + /// Property 15 – Case 4: Request-scoped state (HashMap) is isolated per request. + /// Two separate HashMaps don't share state. 
+ #[test] + fn prop_request_scoped_state_isolated( + model_id in arb_model_id(), + retry_after in 1u64..300, + ) { + use std::collections::HashMap; + use std::time::Instant; + + // Simulate request-scoped state using separate HashMaps + // (as RequestContext.request_retry_after_state would be) + let mut request_a_state: HashMap = HashMap::new(); + let mut request_b_state: HashMap = HashMap::new(); + + // Request A records a Retry-After entry + let expiration = Instant::now() + Duration::from_secs(retry_after); + request_a_state.insert(model_id.clone(), expiration); + + // Request A should see the block + let a_blocked = request_a_state + .get(&model_id) + .map_or(false, |exp| Instant::now() < *exp); + + // Request B should NOT see the block (separate HashMap) + let b_blocked = request_b_state + .get(&model_id) + .map_or(false, |exp| Instant::now() < *exp); + + prop_assert!( + a_blocked, + "Request A should see its own block for {}", + model_id + ); + prop_assert!( + !b_blocked, + "Request B should NOT see Request A's block for {}", + model_id + ); + + // Recording in request B should not affect request A + let expiration_b = Instant::now() + Duration::from_secs(retry_after); + request_b_state.insert(model_id.clone(), expiration_b); + + // Both should now be blocked independently + let a_still_blocked = request_a_state + .get(&model_id) + .map_or(false, |exp| Instant::now() < *exp); + let b_now_blocked = request_b_state + .get(&model_id) + .map_or(false, |exp| Instant::now() < *exp); + + prop_assert!(a_still_blocked, "Request A should still be blocked"); + prop_assert!(b_now_blocked, "Request B should now be blocked independently"); + } + } + + // Feature: retry-on-ratelimit, Property 16: Retry_After_State Max Expiration Update + // **Validates: Requirements 12.11** + proptest! 
{ + #![proptest_config(ProptestConfig::with_cases(100))] + + /// Property 16: Recording multiple Retry-After values for the same identifier + /// should result in the expiration reflecting the maximum value, not the most recent. + #[test] + fn prop_max_expiration_update( + identifier in arb_model_id(), + // Generate 2..=10 Retry-After values, each between 1 and 600 seconds + retry_after_values in prop::collection::vec(1u64..=600, 2..=10), + max_cap in 300u64..=600, + ) { + let mgr = RetryAfterStateManager::new(); + + // Record all values for the same identifier + for &val in &retry_after_values { + mgr.record(&identifier, val, max_cap); + } + + // The effective maximum is the max of all capped values + let effective_max = retry_after_values + .iter() + .map(|&v| v.min(max_cap)) + .max() + .unwrap(); + + // The remaining block duration should be close to the effective maximum + let remaining = mgr.remaining_block_duration(&identifier); + prop_assert!( + remaining.is_some(), + "Identifier {} should still be blocked after recording {} values (effective_max={}s)", + identifier, retry_after_values.len(), effective_max + ); + + let remaining_secs = remaining.unwrap().as_secs(); + + // The remaining duration should be within a reasonable tolerance of the + // effective maximum (allow up to 2 seconds for test execution time). + // It must be at least (effective_max - 2) to prove the max won. + prop_assert!( + remaining_secs >= effective_max.saturating_sub(2), + "Remaining {}s should reflect the max ({}s), not a smaller value. Values: {:?}", + remaining_secs, effective_max, retry_after_values + ); + + // It should not exceed the effective max (plus small tolerance for timing) + prop_assert!( + remaining_secs <= effective_max + 1, + "Remaining {}s should not exceed effective max {}s + tolerance. 
Values: {:?}", + remaining_secs, effective_max, retry_after_values + ); + } + } +} diff --git a/crates/common/src/retry/validation.rs b/crates/common/src/retry/validation.rs index e8bbf6a1..1d5678e9 100644 --- a/crates/common/src/retry/validation.rs +++ b/crates/common/src/retry/validation.rs @@ -311,3 +311,794 @@ impl ConfigValidator { } } +#[cfg(test)] +mod tests { + use super::*; + use crate::configuration::{ + ApplyTo, BackoffConfig, BackoffApplyTo, BlockScope, HighLatencyConfig, + LatencyMeasure, RetryAfterHandlingConfig, + RetryPolicy, RetryStrategy, StatusCodeConfig, StatusCodeEntry, + TimeoutRetryConfig, + }; + use proptest::prelude::*; + + fn make_provider(model: &str, policy: Option) -> LlmProvider { + LlmProvider { + model: Some(model.to_string()), + retry_policy: policy, + ..LlmProvider::default() + } + } + + fn basic_policy() -> RetryPolicy { + RetryPolicy { + fallback_models: vec![], + default_strategy: RetryStrategy::DifferentProvider, + default_max_attempts: 2, + on_status_codes: vec![], + on_timeout: None, + on_high_latency: None, + backoff: None, + retry_after_handling: None, + max_retry_duration_ms: None, + } + } + + #[test] + fn test_valid_basic_policy_no_errors() { + let providers = vec![ + make_provider("openai/gpt-4o", Some(basic_policy())), + make_provider("anthropic/claude-3", None), + ]; + let result = ConfigValidator::validate_retry_policies(&providers); + assert!(result.is_ok()); + } + + #[test] + fn test_no_retry_policy_skipped() { + let providers = vec![make_provider("openai/gpt-4o", None)]; + let result = ConfigValidator::validate_retry_policies(&providers); + assert!(result.is_ok()); + assert!(result.unwrap().is_empty()); + } + + #[test] + fn test_status_code_out_of_range() { + let mut policy = basic_policy(); + policy.on_status_codes = vec![StatusCodeConfig { + codes: vec![StatusCodeEntry::Single(600)], + strategy: RetryStrategy::SameModel, + max_attempts: 2, + }]; + let providers = vec![ + make_provider("openai/gpt-4o", 
Some(policy)), + make_provider("anthropic/claude-3", None), + ]; + let result = ConfigValidator::validate_retry_policies(&providers); + assert!(result.is_err()); + let errs = result.unwrap_err(); + assert!(errs.iter().any(|e| matches!(e, ValidationError::StatusCodeOutOfRange { code: 600, .. }))); + } + + #[test] + fn test_status_code_range_inverted() { + let mut policy = basic_policy(); + policy.on_status_codes = vec![StatusCodeConfig { + codes: vec![StatusCodeEntry::Range("504-502".to_string())], + strategy: RetryStrategy::SameModel, + max_attempts: 2, + }]; + let providers = vec![ + make_provider("openai/gpt-4o", Some(policy)), + make_provider("anthropic/claude-3", None), + ]; + let result = ConfigValidator::validate_retry_policies(&providers); + assert!(result.is_err()); + let errs = result.unwrap_err(); + assert!(errs.iter().any(|e| matches!(e, ValidationError::StatusCodeRangeInverted { .. }))); + } + + #[test] + fn test_backoff_max_ms_not_greater_than_base_ms() { + let mut policy = basic_policy(); + policy.backoff = Some(BackoffConfig { + apply_to: BackoffApplyTo::SameModel, + base_ms: 5000, + max_ms: 5000, + jitter: true, + }); + let providers = vec![ + make_provider("openai/gpt-4o", Some(policy)), + make_provider("anthropic/claude-3", None), + ]; + let result = ConfigValidator::validate_retry_policies(&providers); + assert!(result.is_err()); + let errs = result.unwrap_err(); + assert!(errs.iter().any(|e| matches!(e, ValidationError::MaxMsNotGreaterThanBaseMs { .. 
}))); + } + + #[test] + fn test_backoff_zero_base_ms() { + let mut policy = basic_policy(); + policy.backoff = Some(BackoffConfig { + apply_to: BackoffApplyTo::SameModel, + base_ms: 0, + max_ms: 5000, + jitter: true, + }); + let providers = vec![ + make_provider("openai/gpt-4o", Some(policy)), + make_provider("anthropic/claude-3", None), + ]; + let result = ConfigValidator::validate_retry_policies(&providers); + assert!(result.is_err()); + let errs = result.unwrap_err(); + assert!(errs.iter().any(|e| matches!(e, ValidationError::NonPositiveValue { field, .. } if field == "backoff.base_ms"))); + } + + #[test] + fn test_max_retry_duration_ms_zero() { + let mut policy = basic_policy(); + policy.max_retry_duration_ms = Some(0); + let providers = vec![ + make_provider("openai/gpt-4o", Some(policy)), + make_provider("anthropic/claude-3", None), + ]; + let result = ConfigValidator::validate_retry_policies(&providers); + assert!(result.is_err()); + let errs = result.unwrap_err(); + assert!(errs.iter().any(|e| matches!(e, ValidationError::NonPositiveValue { field, .. } if field == "max_retry_duration_ms"))); + } + + #[test] + fn test_single_provider_failover_warning() { + let policy = basic_policy(); // default_strategy is DifferentProvider + let providers = vec![make_provider("openai/gpt-4o", Some(policy))]; + let result = ConfigValidator::validate_retry_policies(&providers); + assert!(result.is_ok()); + let warnings = result.unwrap(); + assert!(warnings.iter().any(|w| matches!(w, ValidationWarning::SingleProviderWithFailover { .. 
}))); + } + + #[test] + fn test_overlapping_status_codes_warning() { + let mut policy = basic_policy(); + policy.on_status_codes = vec![ + StatusCodeConfig { + codes: vec![StatusCodeEntry::Single(429)], + strategy: RetryStrategy::SameModel, + max_attempts: 2, + }, + StatusCodeConfig { + codes: vec![StatusCodeEntry::Single(429)], + strategy: RetryStrategy::DifferentProvider, + max_attempts: 3, + }, + ]; + let providers = vec![ + make_provider("openai/gpt-4o", Some(policy)), + make_provider("anthropic/claude-3", None), + ]; + let result = ConfigValidator::validate_retry_policies(&providers); + assert!(result.is_ok()); + let warnings = result.unwrap(); + assert!(warnings.iter().any(|w| matches!(w, ValidationWarning::OverlappingStatusCodes { code: 429, .. }))); + } + + #[test] + fn test_backoff_apply_to_mismatch_warning() { + let mut policy = basic_policy(); + policy.default_strategy = RetryStrategy::DifferentProvider; + policy.backoff = Some(BackoffConfig { + apply_to: BackoffApplyTo::SameModel, + base_ms: 100, + max_ms: 5000, + jitter: true, + }); + let providers = vec![ + make_provider("openai/gpt-4o", Some(policy)), + make_provider("anthropic/claude-3", None), + ]; + let result = ConfigValidator::validate_retry_policies(&providers); + assert!(result.is_ok()); + let warnings = result.unwrap(); + assert!(warnings.iter().any(|w| matches!(w, ValidationWarning::BackoffApplyToMismatch { .. }))); + } + + #[test] + fn test_fallback_model_not_in_provider_list_warning() { + let mut policy = basic_policy(); + policy.fallback_models = vec!["nonexistent/model".to_string()]; + let providers = vec![ + make_provider("openai/gpt-4o", Some(policy)), + make_provider("anthropic/claude-3", None), + ]; + let result = ConfigValidator::validate_retry_policies(&providers); + assert!(result.is_ok()); + let warnings = result.unwrap(); + assert!(warnings.iter().any(|w| matches!(w, ValidationWarning::FallbackModelNotInProviderList { fallback, .. 
} if fallback == "nonexistent/model"))); + } + + #[test] + fn test_expand_status_codes_mixed() { + let codes = vec![ + StatusCodeEntry::Single(429), + StatusCodeEntry::Range("502-504".to_string()), + StatusCodeEntry::Single(526), + ]; + let result = ConfigValidator::expand_status_codes(&codes); + assert!(result.is_ok()); + let expanded = result.unwrap(); + assert_eq!(expanded, vec![429, 502, 503, 504, 526]); + } + + #[test] + fn test_valid_range_expansion() { + let codes = vec![StatusCodeEntry::Range("500-503".to_string())]; + let result = ConfigValidator::expand_status_codes(&codes); + assert!(result.is_ok()); + assert_eq!(result.unwrap(), vec![500, 501, 502, 503]); + } + + #[test] + fn test_valid_policy_with_backoff_and_status_codes() { + let mut policy = basic_policy(); + policy.default_strategy = RetryStrategy::SameModel; + policy.on_status_codes = vec![ + StatusCodeConfig { + codes: vec![StatusCodeEntry::Single(429), StatusCodeEntry::Range("502-504".to_string())], + strategy: RetryStrategy::SameModel, + max_attempts: 3, + }, + ]; + policy.backoff = Some(BackoffConfig { + apply_to: BackoffApplyTo::SameModel, + base_ms: 100, + max_ms: 5000, + jitter: true, + }); + policy.max_retry_duration_ms = Some(30000); + let providers = vec![ + make_provider("openai/gpt-4o", Some(policy)), + make_provider("anthropic/claude-3", None), + ]; + let result = ConfigValidator::validate_retry_policies(&providers); + assert!(result.is_ok()); + assert!(result.unwrap().is_empty()); + } + + // ── P1 Validation Tests ─────────────────────────────────────────────── + + #[test] + fn test_on_timeout_zero_max_attempts_rejected() { + let mut policy = basic_policy(); + policy.on_timeout = Some(TimeoutRetryConfig { + strategy: RetryStrategy::DifferentProvider, + max_attempts: 0, + }); + let providers = vec![ + make_provider("openai/gpt-4o", Some(policy)), + make_provider("anthropic/claude-3", None), + ]; + let result = ConfigValidator::validate_retry_policies(&providers); + 
assert!(result.is_err()); + let errs = result.unwrap_err(); + assert!(errs.iter().any(|e| matches!( + e, + ValidationError::NonPositiveValue { field, .. } if field == "on_timeout.max_attempts" + ))); + } + + #[test] + fn test_on_timeout_valid_max_attempts_accepted() { + let mut policy = basic_policy(); + policy.on_timeout = Some(TimeoutRetryConfig { + strategy: RetryStrategy::DifferentProvider, + max_attempts: 2, + }); + let providers = vec![ + make_provider("openai/gpt-4o", Some(policy)), + make_provider("anthropic/claude-3", None), + ]; + let result = ConfigValidator::validate_retry_policies(&providers); + assert!(result.is_ok()); + } + + #[test] + fn test_retry_after_handling_zero_max_seconds_rejected() { + let mut policy = basic_policy(); + policy.retry_after_handling = Some(RetryAfterHandlingConfig { + scope: BlockScope::Model, + apply_to: crate::configuration::ApplyTo::Global, + max_retry_after_seconds: 0, + }); + let providers = vec![ + make_provider("openai/gpt-4o", Some(policy)), + make_provider("anthropic/claude-3", None), + ]; + let result = ConfigValidator::validate_retry_policies(&providers); + assert!(result.is_err()); + let errs = result.unwrap_err(); + assert!(errs.iter().any(|e| matches!( + e, + ValidationError::NonPositiveValue { field, .. 
} + if field == "retry_after_handling.max_retry_after_seconds" + ))); + } + + #[test] + fn test_retry_after_handling_valid_max_seconds_accepted() { + let mut policy = basic_policy(); + policy.retry_after_handling = Some(RetryAfterHandlingConfig { + scope: BlockScope::Model, + apply_to: crate::configuration::ApplyTo::Global, + max_retry_after_seconds: 300, + }); + let providers = vec![ + make_provider("openai/gpt-4o", Some(policy)), + make_provider("anthropic/claude-3", None), + ]; + let result = ConfigValidator::validate_retry_policies(&providers); + assert!(result.is_ok()); + } + + #[test] + fn test_fallback_model_empty_string_rejected() { + let mut policy = basic_policy(); + policy.fallback_models = vec!["".to_string()]; + let providers = vec![ + make_provider("openai/gpt-4o", Some(policy)), + make_provider("anthropic/claude-3", None), + ]; + let result = ConfigValidator::validate_retry_policies(&providers); + assert!(result.is_err()); + let errs = result.unwrap_err(); + assert!(errs.iter().any(|e| matches!( + e, + ValidationError::InvalidFallbackModel { fallback, .. } if fallback.is_empty() + ))); + } + + #[test] + fn test_fallback_model_no_slash_rejected() { + let mut policy = basic_policy(); + policy.fallback_models = vec!["just-a-model-name".to_string()]; + let providers = vec![ + make_provider("openai/gpt-4o", Some(policy)), + make_provider("anthropic/claude-3", None), + ]; + let result = ConfigValidator::validate_retry_policies(&providers); + assert!(result.is_err()); + let errs = result.unwrap_err(); + assert!(errs.iter().any(|e| matches!( + e, + ValidationError::InvalidFallbackModel { fallback, .. 
} if fallback == "just-a-model-name" + ))); + } + + #[test] + fn test_fallback_model_valid_format_accepted() { + let mut policy = basic_policy(); + policy.fallback_models = vec!["anthropic/claude-3".to_string()]; + let providers = vec![ + make_provider("openai/gpt-4o", Some(policy)), + make_provider("anthropic/claude-3", None), + ]; + let result = ConfigValidator::validate_retry_policies(&providers); + assert!(result.is_ok()); + } + + #[test] + fn test_provider_scope_ra_with_same_model_strategy_warning() { + let mut policy = basic_policy(); + policy.default_strategy = RetryStrategy::SameModel; + policy.retry_after_handling = Some(RetryAfterHandlingConfig { + scope: BlockScope::Provider, + apply_to: crate::configuration::ApplyTo::Global, + max_retry_after_seconds: 300, + }); + let providers = vec![ + make_provider("openai/gpt-4o", Some(policy)), + make_provider("anthropic/claude-3", None), + ]; + let result = ConfigValidator::validate_retry_policies(&providers); + assert!(result.is_ok()); + let warnings = result.unwrap(); + assert!(warnings.iter().any(|w| matches!( + w, + ValidationWarning::ProviderScopeWithSameModel { .. } + ))); + } + + #[test] + fn test_model_scope_ra_with_same_model_no_warning() { + let mut policy = basic_policy(); + policy.default_strategy = RetryStrategy::SameModel; + policy.retry_after_handling = Some(RetryAfterHandlingConfig { + scope: BlockScope::Model, + apply_to: crate::configuration::ApplyTo::Global, + max_retry_after_seconds: 300, + }); + let providers = vec![ + make_provider("openai/gpt-4o", Some(policy)), + make_provider("anthropic/claude-3", None), + ]; + let result = ConfigValidator::validate_retry_policies(&providers); + assert!(result.is_ok()); + let warnings = result.unwrap(); + assert!(!warnings.iter().any(|w| matches!( + w, + ValidationWarning::ProviderScopeWithSameModel { .. 
} + ))); + } + + // ── P2 Validation Tests ─────────────────────────────────────────────── + + fn hl_config_valid() -> HighLatencyConfig { + HighLatencyConfig { + threshold_ms: 5000, + measure: LatencyMeasure::Ttfb, + min_triggers: 1, + trigger_window_seconds: None, + strategy: RetryStrategy::DifferentProvider, + max_attempts: 2, + block_duration_seconds: 300, + scope: BlockScope::Model, + apply_to: ApplyTo::Global, + } + } + + #[test] + fn test_on_high_latency_valid_config_accepted() { + let mut policy = basic_policy(); + policy.on_high_latency = Some(hl_config_valid()); + let providers = vec![ + make_provider("openai/gpt-4o", Some(policy)), + make_provider("anthropic/claude-3", None), + ]; + let result = ConfigValidator::validate_retry_policies(&providers); + assert!(result.is_ok()); + } + + #[test] + fn test_on_high_latency_zero_threshold_ms_rejected() { + let mut policy = basic_policy(); + let mut hl = hl_config_valid(); + hl.threshold_ms = 0; + policy.on_high_latency = Some(hl); + let providers = vec![ + make_provider("openai/gpt-4o", Some(policy)), + make_provider("anthropic/claude-3", None), + ]; + let result = ConfigValidator::validate_retry_policies(&providers); + assert!(result.is_err()); + let errs = result.unwrap_err(); + assert!(errs.iter().any(|e| matches!( + e, + ValidationError::NonPositiveValue { field, .. } + if field == "on_high_latency.threshold_ms" + ))); + } + + #[test] + fn test_on_high_latency_zero_max_attempts_rejected() { + let mut policy = basic_policy(); + let mut hl = hl_config_valid(); + hl.max_attempts = 0; + policy.on_high_latency = Some(hl); + let providers = vec![ + make_provider("openai/gpt-4o", Some(policy)), + make_provider("anthropic/claude-3", None), + ]; + let result = ConfigValidator::validate_retry_policies(&providers); + assert!(result.is_err()); + let errs = result.unwrap_err(); + assert!(errs.iter().any(|e| matches!( + e, + ValidationError::NonPositiveValue { field, .. 
} + if field == "on_high_latency.max_attempts" + ))); + } + + #[test] + fn test_on_high_latency_zero_block_duration_rejected() { + let mut policy = basic_policy(); + let mut hl = hl_config_valid(); + hl.block_duration_seconds = 0; + policy.on_high_latency = Some(hl); + let providers = vec![ + make_provider("openai/gpt-4o", Some(policy)), + make_provider("anthropic/claude-3", None), + ]; + let result = ConfigValidator::validate_retry_policies(&providers); + assert!(result.is_err()); + let errs = result.unwrap_err(); + assert!(errs.iter().any(|e| matches!( + e, + ValidationError::NonPositiveValue { field, .. } + if field == "on_high_latency.block_duration_seconds" + ))); + } + + #[test] + fn test_on_high_latency_min_triggers_gt1_without_window_rejected() { + let mut policy = basic_policy(); + let mut hl = hl_config_valid(); + hl.min_triggers = 3; + hl.trigger_window_seconds = None; + policy.on_high_latency = Some(hl); + let providers = vec![ + make_provider("openai/gpt-4o", Some(policy)), + make_provider("anthropic/claude-3", None), + ]; + let result = ConfigValidator::validate_retry_policies(&providers); + assert!(result.is_err()); + let errs = result.unwrap_err(); + assert!(errs.iter().any(|e| matches!( + e, + ValidationError::LatencyMissingTriggerWindow { .. 
} + ))); + } + + #[test] + fn test_on_high_latency_min_triggers_gt1_with_window_accepted() { + let mut policy = basic_policy(); + let mut hl = hl_config_valid(); + hl.min_triggers = 3; + hl.trigger_window_seconds = Some(60); + policy.on_high_latency = Some(hl); + let providers = vec![ + make_provider("openai/gpt-4o", Some(policy)), + make_provider("anthropic/claude-3", None), + ]; + let result = ConfigValidator::validate_retry_policies(&providers); + assert!(result.is_ok()); + } + + #[test] + fn test_on_high_latency_zero_trigger_window_rejected() { + let mut policy = basic_policy(); + let mut hl = hl_config_valid(); + hl.trigger_window_seconds = Some(0); + policy.on_high_latency = Some(hl); + let providers = vec![ + make_provider("openai/gpt-4o", Some(policy)), + make_provider("anthropic/claude-3", None), + ]; + let result = ConfigValidator::validate_retry_policies(&providers); + assert!(result.is_err()); + let errs = result.unwrap_err(); + assert!(errs.iter().any(|e| matches!( + e, + ValidationError::NonPositiveTriggerWindow { .. } + ))); + } + + #[test] + fn test_on_high_latency_provider_scope_same_model_warning() { + let mut policy = basic_policy(); + let mut hl = hl_config_valid(); + hl.scope = BlockScope::Provider; + hl.strategy = RetryStrategy::SameModel; + policy.on_high_latency = Some(hl); + let providers = vec![ + make_provider("openai/gpt-4o", Some(policy)), + make_provider("anthropic/claude-3", None), + ]; + let result = ConfigValidator::validate_retry_policies(&providers); + assert!(result.is_ok()); + let warnings = result.unwrap(); + assert!(warnings.iter().any(|w| matches!( + w, + ValidationWarning::LatencyScopeStrategyMismatch { .. 
} + ))); + } + + #[test] + fn test_on_high_latency_model_scope_same_model_no_warning() { + let mut policy = basic_policy(); + let mut hl = hl_config_valid(); + hl.scope = BlockScope::Model; + hl.strategy = RetryStrategy::SameModel; + policy.on_high_latency = Some(hl); + let providers = vec![ + make_provider("openai/gpt-4o", Some(policy)), + make_provider("anthropic/claude-3", None), + ]; + let result = ConfigValidator::validate_retry_policies(&providers); + assert!(result.is_ok()); + let warnings = result.unwrap(); + assert!(!warnings.iter().any(|w| matches!( + w, + ValidationWarning::LatencyScopeStrategyMismatch { .. } + ))); + } + + #[test] + fn test_on_high_latency_threshold_below_1000_warning() { + let mut policy = basic_policy(); + let mut hl = hl_config_valid(); + hl.threshold_ms = 500; + policy.on_high_latency = Some(hl); + let providers = vec![ + make_provider("openai/gpt-4o", Some(policy)), + make_provider("anthropic/claude-3", None), + ]; + let result = ConfigValidator::validate_retry_policies(&providers); + assert!(result.is_ok()); + let warnings = result.unwrap(); + assert!(warnings.iter().any(|w| matches!( + w, + ValidationWarning::AggressiveLatencyThreshold { threshold_ms: 500, .. } + ))); + } + + #[test] + fn test_on_high_latency_threshold_1000_no_warning() { + let mut policy = basic_policy(); + let mut hl = hl_config_valid(); + hl.threshold_ms = 1000; + policy.on_high_latency = Some(hl); + let providers = vec![ + make_provider("openai/gpt-4o", Some(policy)), + make_provider("anthropic/claude-3", None), + ]; + let result = ConfigValidator::validate_retry_policies(&providers); + assert!(result.is_ok()); + let warnings = result.unwrap(); + assert!(!warnings.iter().any(|w| matches!( + w, + ValidationWarning::AggressiveLatencyThreshold { .. 
} + ))); + } + + #[test] + fn test_on_high_latency_total_measure_accepted() { + let mut policy = basic_policy(); + let mut hl = hl_config_valid(); + hl.measure = LatencyMeasure::Total; + policy.on_high_latency = Some(hl); + let providers = vec![ + make_provider("openai/gpt-4o", Some(policy)), + make_provider("anthropic/claude-3", None), + ]; + let result = ConfigValidator::validate_retry_policies(&providers); + assert!(result.is_ok()); + } + + #[test] + fn test_on_high_latency_request_apply_to_accepted() { + let mut policy = basic_policy(); + let mut hl = hl_config_valid(); + hl.apply_to = ApplyTo::Request; + policy.on_high_latency = Some(hl); + let providers = vec![ + make_provider("openai/gpt-4o", Some(policy)), + make_provider("anthropic/claude-3", None), + ]; + let result = ConfigValidator::validate_retry_policies(&providers); + assert!(result.is_ok()); + } + + // ── Strategies for invalid config generation ─────────────────────────── + + /// Generates a status code outside the valid 100-599 range. + fn arb_out_of_range_code() -> impl Strategy { + prop_oneof![ + (0u16..100u16), // below 100 + (600u16..=u16::MAX), // above 599 + ] + } + + /// Generates a range string where start > end (both within valid range). + fn arb_inverted_range() -> impl Strategy { + (101u16..=599u16).prop_flat_map(|start| { + (100u16..start).prop_map(move |end| format!("{}-{}", start, end)) + }) + } + + /// Generates a backoff config where max_ms <= base_ms. + fn arb_backoff_max_lte_base() -> impl Strategy { + (1u64..=10000u64).prop_flat_map(|base_ms| { + (0u64..=base_ms).prop_map(move |max_ms| BackoffConfig { + apply_to: BackoffApplyTo::Global, + base_ms, + max_ms, + jitter: true, + }) + }) + } + + /// Generates a backoff config where base_ms = 0. 
+ fn arb_backoff_zero_base() -> impl Strategy { + (1u64..=10000u64).prop_map(|max_ms| BackoffConfig { + apply_to: BackoffApplyTo::Global, + base_ms: 0, + max_ms, + jitter: true, + }) + } + + // Feature: retry-on-ratelimit, Property 3: Invalid Configuration Rejected + // **Validates: Requirements 8.27** + proptest! { + #![proptest_config(proptest::prelude::ProptestConfig::with_cases(100))] + + /// Property 3 – Case 1: Status codes outside 100-599 are rejected. + #[test] + fn prop_invalid_status_code_out_of_range(code in arb_out_of_range_code()) { + let mut policy = basic_policy(); + policy.on_status_codes = vec![StatusCodeConfig { + codes: vec![StatusCodeEntry::Single(code)], + strategy: RetryStrategy::SameModel, + max_attempts: 2, + }]; + let providers = vec![ + make_provider("openai/gpt-4o", Some(policy)), + make_provider("anthropic/claude-3", None), + ]; + let result = ConfigValidator::validate_retry_policies(&providers); + prop_assert!(result.is_err(), "Expected Err for out-of-range code {}", code); + } + + /// Property 3 – Case 2: Range strings with start > end are rejected. + #[test] + fn prop_invalid_range_start_gt_end(range in arb_inverted_range()) { + let mut policy = basic_policy(); + policy.on_status_codes = vec![StatusCodeConfig { + codes: vec![StatusCodeEntry::Range(range.clone())], + strategy: RetryStrategy::SameModel, + max_attempts: 2, + }]; + let providers = vec![ + make_provider("openai/gpt-4o", Some(policy)), + make_provider("anthropic/claude-3", None), + ]; + let result = ConfigValidator::validate_retry_policies(&providers); + prop_assert!(result.is_err(), "Expected Err for inverted range {}", range); + } + + /// Property 3 – Case 3: Backoff with max_ms <= base_ms is rejected. 
+ #[test] + fn prop_invalid_backoff_max_lte_base(backoff in arb_backoff_max_lte_base()) { + let mut policy = basic_policy(); + policy.backoff = Some(backoff.clone()); + let providers = vec![ + make_provider("openai/gpt-4o", Some(policy)), + make_provider("anthropic/claude-3", None), + ]; + let result = ConfigValidator::validate_retry_policies(&providers); + prop_assert!( + result.is_err(), + "Expected Err for max_ms ({}) <= base_ms ({})", + backoff.max_ms, backoff.base_ms + ); + } + + /// Property 3 – Case 4: Backoff with base_ms = 0 is rejected. + #[test] + fn prop_invalid_backoff_zero_base(backoff in arb_backoff_zero_base()) { + let mut policy = basic_policy(); + policy.backoff = Some(backoff); + let providers = vec![ + make_provider("openai/gpt-4o", Some(policy)), + make_provider("anthropic/claude-3", None), + ]; + let result = ConfigValidator::validate_retry_policies(&providers); + prop_assert!(result.is_err(), "Expected Err for base_ms = 0"); + } + + /// Property 3 – Case 5: max_retry_duration_ms = 0 is rejected. 
+ #[test] + fn prop_invalid_max_retry_duration_zero(_dummy in Just(())) { + let mut policy = basic_policy(); + policy.max_retry_duration_ms = Some(0); + let providers = vec![ + make_provider("openai/gpt-4o", Some(policy)), + make_provider("anthropic/claude-3", None), + ]; + let result = ConfigValidator::validate_retry_policies(&providers); + prop_assert!(result.is_err(), "Expected Err for max_retry_duration_ms = 0"); + } + } +} diff --git a/tests/e2e/configs/retry_it10_timeout_triggers_retry.yaml b/tests/e2e/configs/retry_it10_timeout_triggers_retry.yaml new file mode 100644 index 00000000..22a340d1 --- /dev/null +++ b/tests/e2e/configs/retry_it10_timeout_triggers_retry.yaml @@ -0,0 +1,27 @@ +version: v0.3.0 + +listeners: + - type: model + name: model_listener + port: 12000 + +model_providers: + - model: openai/gpt-4o + base_url: http://host.docker.internal:${MOCK_PRIMARY_PORT} + access_key: test-key-primary + default: true + retry_policy: + fallback_models: [anthropic/claude-3-5-sonnet] + default_strategy: "different_provider" + default_max_attempts: 2 + on_status_codes: + - codes: [429] + strategy: "different_provider" + max_attempts: 2 + on_timeout: + strategy: "different_provider" + max_attempts: 2 + + - model: anthropic/claude-3-5-sonnet + base_url: http://host.docker.internal:${MOCK_SECONDARY_PORT} + access_key: test-key-secondary diff --git a/tests/e2e/configs/retry_it11_high_latency_failover.yaml b/tests/e2e/configs/retry_it11_high_latency_failover.yaml new file mode 100644 index 00000000..1dc8a7e2 --- /dev/null +++ b/tests/e2e/configs/retry_it11_high_latency_failover.yaml @@ -0,0 +1,33 @@ +version: v0.3.0 + +listeners: + - type: model + name: model_listener + port: 12000 + +model_providers: + - model: openai/gpt-4o + base_url: http://host.docker.internal:${MOCK_PRIMARY_PORT} + access_key: test-key-primary + default: true + retry_policy: + fallback_models: [anthropic/claude-3-5-sonnet] + default_strategy: "different_provider" + default_max_attempts: 2 + 
on_status_codes: + - codes: [429] + strategy: "different_provider" + max_attempts: 2 + on_high_latency: + threshold_ms: 1000 + measure: "total" + min_triggers: 1 + strategy: "different_provider" + max_attempts: 2 + block_duration_seconds: 60 + scope: "model" + apply_to: "global" + + - model: anthropic/claude-3-5-sonnet + base_url: http://host.docker.internal:${MOCK_SECONDARY_PORT} + access_key: test-key-secondary diff --git a/tests/e2e/configs/retry_it12_streaming.yaml b/tests/e2e/configs/retry_it12_streaming.yaml new file mode 100644 index 00000000..f1933fa0 --- /dev/null +++ b/tests/e2e/configs/retry_it12_streaming.yaml @@ -0,0 +1,23 @@ +version: v0.3.0 + +listeners: + - type: model + name: model_listener + port: 12000 + +model_providers: + - model: openai/gpt-4o + base_url: http://host.docker.internal:${MOCK_PRIMARY_PORT} + access_key: test-key-primary + default: true + retry_policy: + default_strategy: "different_provider" + default_max_attempts: 2 + on_status_codes: + - codes: [429] + strategy: "different_provider" + max_attempts: 2 + + - model: anthropic/claude-3-5-sonnet + base_url: http://host.docker.internal:${MOCK_SECONDARY_PORT} + access_key: test-key-secondary diff --git a/tests/e2e/configs/retry_it13_body_preserved.yaml b/tests/e2e/configs/retry_it13_body_preserved.yaml new file mode 100644 index 00000000..f1933fa0 --- /dev/null +++ b/tests/e2e/configs/retry_it13_body_preserved.yaml @@ -0,0 +1,23 @@ +version: v0.3.0 + +listeners: + - type: model + name: model_listener + port: 12000 + +model_providers: + - model: openai/gpt-4o + base_url: http://host.docker.internal:${MOCK_PRIMARY_PORT} + access_key: test-key-primary + default: true + retry_policy: + default_strategy: "different_provider" + default_max_attempts: 2 + on_status_codes: + - codes: [429] + strategy: "different_provider" + max_attempts: 2 + + - model: anthropic/claude-3-5-sonnet + base_url: http://host.docker.internal:${MOCK_SECONDARY_PORT} + access_key: test-key-secondary diff --git 
a/tests/e2e/configs/retry_it1_basic_429.yaml b/tests/e2e/configs/retry_it1_basic_429.yaml new file mode 100644 index 00000000..f1933fa0 --- /dev/null +++ b/tests/e2e/configs/retry_it1_basic_429.yaml @@ -0,0 +1,23 @@ +version: v0.3.0 + +listeners: + - type: model + name: model_listener + port: 12000 + +model_providers: + - model: openai/gpt-4o + base_url: http://host.docker.internal:${MOCK_PRIMARY_PORT} + access_key: test-key-primary + default: true + retry_policy: + default_strategy: "different_provider" + default_max_attempts: 2 + on_status_codes: + - codes: [429] + strategy: "different_provider" + max_attempts: 2 + + - model: anthropic/claude-3-5-sonnet + base_url: http://host.docker.internal:${MOCK_SECONDARY_PORT} + access_key: test-key-secondary diff --git a/tests/e2e/configs/retry_it2_503_different_provider.yaml b/tests/e2e/configs/retry_it2_503_different_provider.yaml new file mode 100644 index 00000000..38fe2edb --- /dev/null +++ b/tests/e2e/configs/retry_it2_503_different_provider.yaml @@ -0,0 +1,23 @@ +version: v0.3.0 + +listeners: + - type: model + name: model_listener + port: 12000 + +model_providers: + - model: openai/gpt-4o + base_url: http://host.docker.internal:${MOCK_PRIMARY_PORT} + access_key: test-key-primary + default: true + retry_policy: + default_strategy: "different_provider" + default_max_attempts: 2 + on_status_codes: + - codes: [503] + strategy: "different_provider" + max_attempts: 2 + + - model: anthropic/claude-3-5-sonnet + base_url: http://host.docker.internal:${MOCK_SECONDARY_PORT} + access_key: test-key-secondary diff --git a/tests/e2e/configs/retry_it3_all_exhausted.yaml b/tests/e2e/configs/retry_it3_all_exhausted.yaml new file mode 100644 index 00000000..f1933fa0 --- /dev/null +++ b/tests/e2e/configs/retry_it3_all_exhausted.yaml @@ -0,0 +1,23 @@ +version: v0.3.0 + +listeners: + - type: model + name: model_listener + port: 12000 + +model_providers: + - model: openai/gpt-4o + base_url: http://host.docker.internal:${MOCK_PRIMARY_PORT} 
+ access_key: test-key-primary + default: true + retry_policy: + default_strategy: "different_provider" + default_max_attempts: 2 + on_status_codes: + - codes: [429] + strategy: "different_provider" + max_attempts: 2 + + - model: anthropic/claude-3-5-sonnet + base_url: http://host.docker.internal:${MOCK_SECONDARY_PORT} + access_key: test-key-secondary diff --git a/tests/e2e/configs/retry_it4_no_retry_policy.yaml b/tests/e2e/configs/retry_it4_no_retry_policy.yaml new file mode 100644 index 00000000..26bf31a6 --- /dev/null +++ b/tests/e2e/configs/retry_it4_no_retry_policy.yaml @@ -0,0 +1,17 @@ +version: v0.3.0 + +listeners: + - type: model + name: model_listener + port: 12000 + +model_providers: + - model: openai/gpt-4o + base_url: http://host.docker.internal:${MOCK_PRIMARY_PORT} + access_key: test-key-primary + default: true + # No retry_policy — errors should be returned directly to client + + - model: anthropic/claude-3-5-sonnet + base_url: http://host.docker.internal:${MOCK_SECONDARY_PORT} + access_key: test-key-secondary diff --git a/tests/e2e/configs/retry_it5_max_attempts.yaml b/tests/e2e/configs/retry_it5_max_attempts.yaml new file mode 100644 index 00000000..f1cfa815 --- /dev/null +++ b/tests/e2e/configs/retry_it5_max_attempts.yaml @@ -0,0 +1,27 @@ +version: v0.3.0 + +listeners: + - type: model + name: model_listener + port: 12000 + +model_providers: + - model: openai/gpt-4o + base_url: http://host.docker.internal:${MOCK_PRIMARY_PORT} + access_key: test-key-primary + default: true + retry_policy: + default_strategy: "different_provider" + default_max_attempts: 1 + on_status_codes: + - codes: [429] + strategy: "different_provider" + max_attempts: 1 + + - model: anthropic/claude-3-5-sonnet + base_url: http://host.docker.internal:${MOCK_SECONDARY_PORT} + access_key: test-key-secondary + + - model: mistral/mistral-large + base_url: http://host.docker.internal:${MOCK_TERTIARY_PORT} + access_key: test-key-tertiary diff --git 
a/tests/e2e/configs/retry_it6_backoff_delay.yaml b/tests/e2e/configs/retry_it6_backoff_delay.yaml new file mode 100644 index 00000000..e7ec474c --- /dev/null +++ b/tests/e2e/configs/retry_it6_backoff_delay.yaml @@ -0,0 +1,24 @@ +version: v0.3.0 + +listeners: + - type: model + name: model_listener + port: 12000 + +model_providers: + - model: openai/gpt-4o + base_url: http://host.docker.internal:${MOCK_PRIMARY_PORT} + access_key: test-key-primary + default: true + retry_policy: + default_strategy: "same_model" + default_max_attempts: 3 + on_status_codes: + - codes: [429] + strategy: "same_model" + max_attempts: 3 + backoff: + apply_to: "same_model" + base_ms: 500 + max_ms: 5000 + jitter: false diff --git a/tests/e2e/configs/retry_it7_fallback_priority.yaml b/tests/e2e/configs/retry_it7_fallback_priority.yaml new file mode 100644 index 00000000..e5bee0c5 --- /dev/null +++ b/tests/e2e/configs/retry_it7_fallback_priority.yaml @@ -0,0 +1,28 @@ +version: v0.3.0 + +listeners: + - type: model + name: model_listener + port: 12000 + +model_providers: + - model: openai/gpt-4o + base_url: http://host.docker.internal:${MOCK_PRIMARY_PORT} + access_key: test-key-primary + default: true + retry_policy: + fallback_models: [anthropic/claude-3-5-sonnet, mistral/mistral-large] + default_strategy: "different_provider" + default_max_attempts: 3 + on_status_codes: + - codes: [429] + strategy: "different_provider" + max_attempts: 3 + + - model: anthropic/claude-3-5-sonnet + base_url: http://host.docker.internal:${MOCK_FALLBACK1_PORT} + access_key: test-key-fallback1 + + - model: mistral/mistral-large + base_url: http://host.docker.internal:${MOCK_FALLBACK2_PORT} + access_key: test-key-fallback2 diff --git a/tests/e2e/configs/retry_it8_retry_after_honored.yaml b/tests/e2e/configs/retry_it8_retry_after_honored.yaml new file mode 100644 index 00000000..3088759d --- /dev/null +++ b/tests/e2e/configs/retry_it8_retry_after_honored.yaml @@ -0,0 +1,23 @@ +version: v0.3.0 + +listeners: + - type: 
model + name: model_listener + port: 12000 + +model_providers: + - model: openai/gpt-4o + base_url: http://host.docker.internal:${MOCK_PRIMARY_PORT} + access_key: test-key-primary + default: true + retry_policy: + default_strategy: "same_model" + default_max_attempts: 2 + on_status_codes: + - codes: [429] + strategy: "same_model" + max_attempts: 2 + retry_after_handling: + scope: "model" + apply_to: "request" + max_retry_after_seconds: 300 diff --git a/tests/e2e/configs/retry_it9_retry_after_blocks_selection.yaml b/tests/e2e/configs/retry_it9_retry_after_blocks_selection.yaml new file mode 100644 index 00000000..ef3d7ad7 --- /dev/null +++ b/tests/e2e/configs/retry_it9_retry_after_blocks_selection.yaml @@ -0,0 +1,36 @@ +version: v0.3.0 + +listeners: + - type: model + name: model_listener + port: 12000 + +model_providers: + - model: openai/gpt-4o + base_url: http://host.docker.internal:${MOCK_PRIMARY_PORT} + access_key: test-key-primary + default: true + retry_policy: + fallback_models: [anthropic/claude-3-5-sonnet] + default_strategy: "different_provider" + default_max_attempts: 2 + on_status_codes: + - codes: [429] + strategy: "different_provider" + max_attempts: 2 + retry_after_handling: + scope: "model" + apply_to: "global" + max_retry_after_seconds: 300 + + - model: anthropic/claude-3-5-sonnet + base_url: http://host.docker.internal:${MOCK_SECONDARY_PORT} + access_key: test-key-secondary + default: false + retry_policy: + default_strategy: "different_provider" + default_max_attempts: 2 + on_status_codes: + - codes: [429] + strategy: "different_provider" + max_attempts: 2 diff --git a/tests/e2e/test_retry_integration.py b/tests/e2e/test_retry_integration.py new file mode 100644 index 00000000..a93ffb16 --- /dev/null +++ b/tests/e2e/test_retry_integration.py @@ -0,0 +1,1435 @@ +""" +Integration tests for retry-on-ratelimit feature (P0). 
+ +Tests IT-1 through IT-6, IT-12, IT-13 validate end-to-end retry behavior +through the real Plano gateway using Python mock HTTP servers as upstream providers. + +Each test: + 1. Starts mock upstream servers on ephemeral ports + 2. Writes a YAML config pointing the gateway at those mock ports + 3. Starts the gateway via `planoai up` + 4. Sends requests and asserts on response status/body/timing + 5. Tears down the gateway via `planoai down` +""" + +import json +import logging +import os +import subprocess +import sys +import tempfile +import threading +import time +from http.server import HTTPServer, BaseHTTPRequestHandler +from typing import Optional + +import pytest +import requests + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + handlers=[logging.StreamHandler(sys.stdout)], +) +logger = logging.getLogger(__name__) + +GATEWAY_BASE_URL = "http://localhost:12000" +GATEWAY_CHAT_URL = f"{GATEWAY_BASE_URL}/v1/chat/completions" +CONFIGS_DIR = os.path.join(os.path.dirname(__file__), "configs") + +# Standard OpenAI-compatible success response body +SUCCESS_RESPONSE = json.dumps({ + "id": "chatcmpl-test-001", + "object": "chat.completion", + "created": 1700000000, + "model": "mock-model", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "Hello from mock provider!", + }, + "finish_reason": "stop", + } + ], + "usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}, +}) + +# Standard chat request body +CHAT_REQUEST_BODY = { + "model": "openai/gpt-4o", + "messages": [{"role": "user", "content": "Hello"}], +} + + +# --------------------------------------------------------------------------- +# Mock upstream server infrastructure +# --------------------------------------------------------------------------- + +class MockUpstreamHandler(BaseHTTPRequestHandler): + """ + Configurable mock HTTP handler that returns responses from a per-server queue. 
+ + Each server instance has a response_queue (list of tuples): + (status_code, headers_dict, body_string) + + Responses are consumed in order. When the queue is exhausted, the last + response is repeated. The handler also records all received requests for + later assertion. + """ + + # These are set per-server-instance via the factory function below. + response_queue: list = [] + received_requests: list = [] + call_count: int = 0 + lock: threading.Lock = threading.Lock() + + def do_POST(self): + content_length = int(self.headers.get("Content-Length", 0)) + body = self.rfile.read(content_length) if content_length > 0 else b"" + + with self.__class__.lock: + self.__class__.call_count += 1 + self.__class__.received_requests.append({ + "path": self.path, + "headers": dict(self.headers), + "body": body.decode("utf-8", errors="replace"), + }) + idx = min( + self.__class__.call_count - 1, + len(self.__class__.response_queue) - 1, + ) + status_code, headers, response_body = self.__class__.response_queue[idx] + + self.send_response(status_code) + for key, value in headers.items(): + self.send_header(key, value) + self.send_header("Content-Type", "application/json") + self.end_headers() + if isinstance(response_body, str): + response_body = response_body.encode("utf-8") + self.wfile.write(response_body) + + def do_GET(self): + """Handle health checks or other GET requests.""" + self.send_response(200) + self.send_header("Content-Type", "application/json") + self.end_headers() + self.wfile.write(b'{"status": "ok"}') + + def log_message(self, format, *args): + """Suppress default request logging to reduce noise.""" + pass + + +def create_mock_handler_class(response_queue: list) -> type: + """ + Create a new handler class with its own response queue and state. + This avoids shared state between different mock servers. 
+ """ + class Handler(MockUpstreamHandler): + pass + + Handler.response_queue = list(response_queue) + Handler.received_requests = [] + Handler.call_count = 0 + Handler.lock = threading.Lock() + return Handler + + +class MockServer: + """Manages a mock HTTP server running in a background thread.""" + + def __init__(self, response_queue: list): + self.handler_class = create_mock_handler_class(response_queue) + self.server = HTTPServer(("0.0.0.0", 0), self.handler_class) + self.port = self.server.server_address[1] + self.thread = threading.Thread(target=self.server.serve_forever, daemon=True) + + def start(self): + self.thread.start() + logger.info(f"Mock server started on port {self.port}") + + def stop(self): + self.server.shutdown() + self.thread.join(timeout=5) + logger.info(f"Mock server stopped on port {self.port}") + + @property + def call_count(self) -> int: + return self.handler_class.call_count + + @property + def received_requests(self) -> list: + return self.handler_class.received_requests + + +# --------------------------------------------------------------------------- +# Gateway lifecycle helpers +# --------------------------------------------------------------------------- + +def write_config(template_name: str, substitutions: dict) -> str: + """ + Read a config template from configs/ dir, apply port substitutions, + and write to a temp file. Returns the path to the temp config file. 
+ """ + template_path = os.path.join(CONFIGS_DIR, template_name) + with open(template_path, "r") as f: + content = f.read() + + for key, value in substitutions.items(): + content = content.replace(f"${{{key}}}", str(value)) + + # Write to a temp file in the e2e directory so planoai can find it + fd, config_path = tempfile.mkstemp(suffix=".yaml", prefix="retry_test_") + with os.fdopen(fd, "w") as f: + f.write(content) + + logger.info(f"Wrote test config to {config_path}") + return config_path + + +def gateway_up(config_path: str, timeout: int = 30): + """Start the Plano gateway with the given config. Waits for health.""" + logger.info(f"Starting gateway with config: {config_path}") + subprocess.run( + ["planoai", "down", "--docker"], + capture_output=True, + timeout=30, + ) + result = subprocess.run( + ["planoai", "up", "--docker", config_path], + capture_output=True, + text=True, + timeout=60, + ) + if result.returncode != 0: + logger.error(f"planoai up failed: {result.stderr}") + raise RuntimeError(f"planoai up failed: {result.stderr}") + + # Wait for gateway to be healthy + start = time.time() + while time.time() - start < timeout: + try: + resp = requests.get(f"{GATEWAY_BASE_URL}/healthz", timeout=2) + if resp.status_code == 200: + logger.info("Gateway is healthy") + return + except requests.ConnectionError: + pass + time.sleep(1) + + raise RuntimeError(f"Gateway did not become healthy within {timeout}s") + + +def gateway_down(): + """Stop the Plano gateway.""" + logger.info("Stopping gateway") + subprocess.run( + ["planoai", "down", "--docker"], + capture_output=True, + timeout=30, + ) + + +def make_error_response(status_code: int, message: str = "error") -> str: + """Create a JSON error response body.""" + return json.dumps({ + "error": { + "message": message, + "type": "server_error", + "code": str(status_code), + } + }) + + +# --------------------------------------------------------------------------- +# Streaming helpers +# 
def create_streaming_handler_class(
    response_queue: list,
    streaming_chunks: Optional[list] = None,
) -> type:
    """
    Create a handler class that can return streaming SSE responses.

    Each ``response_queue`` entry is a ``(status_code, headers, body)`` tuple.
    The N-th POST uses the N-th entry; once the queue is exhausted the last
    entry is reused. An entry whose body is the marker string ``"STREAM"``
    is answered with ``streaming_chunks`` (``STREAMING_SUCCESS_CHUNKS`` when
    none are given) written one at a time as ``text/event-stream``. Any other
    body is sent as a single JSON response.

    The returned class records every request in ``received_requests`` and
    counts them in ``call_count``, both guarded by ``lock``.
    """
    chunks = streaming_chunks or STREAMING_SUCCESS_CHUNKS

    class Handler(StreamingMockHandler):
        pass

    # Per-class state, so each mock server gets its own queue and counters.
    Handler.response_queue = list(response_queue)
    Handler.received_requests = []
    Handler.call_count = 0
    Handler.lock = threading.Lock()

    def streaming_do_post(self):
        content_length = int(self.headers.get("Content-Length", 0))
        body = self.rfile.read(content_length) if content_length > 0 else b""

        with Handler.lock:
            Handler.call_count += 1
            Handler.received_requests.append({
                "path": self.path,
                "headers": dict(self.headers),
                "body": body.decode("utf-8", errors="replace"),
            })
            # Clamp to the last entry so extra calls repeat the final response.
            idx = min(Handler.call_count - 1, len(Handler.response_queue) - 1)
            status_code, headers, response_body = Handler.response_queue[idx]

        self.send_response(status_code)
        for key, value in headers.items():
            self.send_header(key, value)

        if response_body == "STREAM":
            self.send_header("Content-Type", "text/event-stream")
            # The chunks are written as raw bytes with no chunk-size framing,
            # so the body must not be advertised as "Transfer-Encoding:
            # chunked". It is delimited by closing the connection instead.
            self.send_header("Connection", "close")
            self.end_headers()
            for chunk in chunks:
                self.wfile.write(chunk.encode("utf-8"))
                self.wfile.flush()
                # Small pause so the client sees separate events.
                time.sleep(0.05)
            return

        self.send_header("Content-Type", "application/json")
        self.end_headers()
        if isinstance(response_body, str):
            response_body = response_body.encode("utf-8")
        self.wfile.write(response_body)

    Handler.do_POST = streaming_do_post
    return Handler
def create_echo_handler_class(response_queue: list) -> type:
    """
    Build a handler class whose 200 responses echo the request body.

    The queue entry for a call (the last entry once the queue runs out)
    decides the outcome. A status of 200 yields a chat-completion JSON whose
    assistant message content is the request body that was received. Any
    other status is answered with the queued headers and body unchanged.
    """

    class Handler(MockUpstreamHandler):
        pass

    Handler.response_queue = list(response_queue)
    Handler.received_requests = []
    Handler.call_count = 0
    Handler.lock = threading.Lock()

    def _handle_post(handler):
        declared_length = int(handler.headers.get("Content-Length", 0))
        raw = handler.rfile.read(declared_length) if declared_length > 0 else b""
        request_text = raw.decode("utf-8", errors="replace")

        with Handler.lock:
            Handler.call_count += 1
            Handler.received_requests.append({
                "path": handler.path,
                "headers": dict(handler.headers),
                "body": request_text,
            })
            # Calls beyond the end of the queue reuse its final entry.
            slot = min(Handler.call_count, len(Handler.response_queue)) - 1
            status_code, extra_headers, queued_body = Handler.response_queue[slot]

        if status_code != 200:
            # Non-200: replay the queued response as-is.
            handler.send_response(status_code)
            for name, value in extra_headers.items():
                handler.send_header(name, value)
            handler.send_header("Content-Type", "application/json")
            handler.end_headers()
            if isinstance(queued_body, str):
                queued_body = queued_body.encode("utf-8")
            handler.wfile.write(queued_body)
            return

        # 200: wrap the received body in a chat completion response.
        completion = {
            "id": "chatcmpl-echo-001",
            "object": "chat.completion",
            "created": 1700000000,
            "model": "echo-model",
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": request_text,
                    },
                    "finish_reason": "stop",
                }
            ],
            "usage": {
                "prompt_tokens": 10,
                "completion_tokens": 5,
                "total_tokens": 15,
            },
        }
        handler.send_response(200)
        handler.send_header("Content-Type", "application/json")
        handler.end_headers()
        handler.wfile.write(json.dumps(completion).encode("utf-8"))

    Handler.do_POST = _handle_post
    return Handler
class DelayedMockServer:
    """Mock server that delays responses to simulate slow upstreams / timeouts.

    Requests are handled on per-request daemon threads. A handler that is
    still sleeping therefore blocks neither new connections nor ``stop()``.
    """

    def __init__(self, response_queue: list, delay_seconds: float):
        # Local import: the module-level imports are outside this block.
        from http.server import ThreadingHTTPServer

        self.handler_class = create_delayed_handler_class(
            response_queue, delay_seconds
        )
        # A single-threaded HTTPServer runs the handler inside the serve loop,
        # so shutdown() would wait for the whole delay of an in-flight request
        # before returning. ThreadingHTTPServer keeps the loop responsive.
        self.server = ThreadingHTTPServer(("0.0.0.0", 0), self.handler_class)
        self.port = self.server.server_address[1]
        self.thread = threading.Thread(target=self.server.serve_forever, daemon=True)

    def start(self):
        """Start serving requests in the background thread."""
        self.thread.start()
        logger.info(f"Delayed mock server started on port {self.port} ")

    def stop(self):
        """Stop the serve loop and release the listening socket."""
        # shutdown() blocks forever if serve_forever() was never started.
        if self.thread.is_alive():
            self.server.shutdown()
            self.thread.join(timeout=5)
        # shutdown() does not close the socket; do it explicitly.
        self.server.server_close()

    @property
    def call_count(self) -> int:
        """Number of requests counted by the handler class."""
        return self.handler_class.call_count

    @property
    def received_requests(self) -> list:
        """Requests recorded by the handler class, in arrival order."""
        return self.handler_class.received_requests
def test_it1_basic_retry_on_429(self):
    """
    IT-1: Basic retry on 429.

    Primary mock returns 429, secondary returns 200.
    Assert client gets 200 from the secondary provider.
    """
    primary = MockServer([
        (429, {}, make_error_response(429, "Rate limit exceeded")),
    ])
    secondary = MockServer([
        (200, {}, SUCCESS_RESPONSE),
    ])
    primary.start()
    secondary.start()
    config_path = None

    try:
        # The mock ports are only known at runtime, so the config is rendered
        # from its template here.
        config_path = write_config("retry_it1_basic_429.yaml", {
            "MOCK_PRIMARY_PORT": primary.port,
            "MOCK_SECONDARY_PORT": secondary.port,
        })
        gateway_up(config_path)

        resp = requests.post(
            GATEWAY_CHAT_URL,
            json=CHAT_REQUEST_BODY,
            headers={"Authorization": "Bearer test-key"},
            timeout=30,
        )

        # The client must see the secondary's successful answer.
        assert resp.status_code == 200, (
            f"Expected 200 but got {resp.status_code}: {resp.text}"
        )
        body = resp.json()
        assert "choices" in body
        assert body["choices"][0]["message"]["content"] == "Hello from mock provider!"

        # Both upstreams must have been hit: primary (429), then secondary (200).
        assert primary.call_count >= 1, "Primary should have been called"
        assert secondary.call_count >= 1, "Secondary should have been called"

    finally:
        try:
            gateway_down()
        finally:
            # Runs even when gateway_down() raises (e.g. CLI timeout), so the
            # mock servers and the temp config are not leaked.
            primary.stop()
            secondary.stop()
            if config_path and os.path.exists(config_path):
                os.unlink(config_path)
def test_it4_no_retry_policy_no_retry(self):
    """
    IT-4: No retry_policy → no retry.

    Primary returns 429 with no retry_policy configured.
    Assert client gets 429 directly (no retry to secondary).
    """
    primary = MockServer([
        (429, {}, make_error_response(429, "Rate limit exceeded")),
    ])
    secondary = MockServer([
        (200, {}, SUCCESS_RESPONSE),
    ])
    primary.start()
    secondary.start()
    config_path = None

    try:
        config_path = write_config("retry_it4_no_retry_policy.yaml", {
            "MOCK_PRIMARY_PORT": primary.port,
            "MOCK_SECONDARY_PORT": secondary.port,
        })
        gateway_up(config_path)

        resp = requests.post(
            GATEWAY_CHAT_URL,
            json=CHAT_REQUEST_BODY,
            headers={"Authorization": "Bearer test-key"},
            timeout=30,
        )

        # Without a retry policy the upstream 429 is passed straight through.
        assert resp.status_code == 429, (
            f"Expected 429 but got {resp.status_code}: {resp.text}"
        )

        # The secondary must not have been contacted.
        assert secondary.call_count == 0, (
            f"Secondary should not be called without retry_policy, "
            f"but was called {secondary.call_count} times"
        )

    finally:
        try:
            gateway_down()
        finally:
            # Runs even when gateway_down() raises, so the mock servers and
            # the temp config are not leaked.
            primary.stop()
            secondary.stop()
            if config_path and os.path.exists(config_path):
                os.unlink(config_path)
+ """ + primary = MockServer([ + (429, {}, make_error_response(429, "Rate limit exceeded")), + ]) + secondary = MockServer([ + (429, {}, make_error_response(429, "Rate limit exceeded")), + ]) + tertiary = MockServer([ + (200, {}, SUCCESS_RESPONSE), + ]) + primary.start() + secondary.start() + tertiary.start() + config_path = None + + try: + config_path = write_config("retry_it5_max_attempts.yaml", { + "MOCK_PRIMARY_PORT": primary.port, + "MOCK_SECONDARY_PORT": secondary.port, + "MOCK_TERTIARY_PORT": tertiary.port, + }) + gateway_up(config_path) + + resp = requests.post( + GATEWAY_CHAT_URL, + json=CHAT_REQUEST_BODY, + headers={"Authorization": "Bearer test-key"}, + timeout=30, + ) + + # With max_attempts: 1, only 1 retry should happen after the initial failure. + # Primary fails (429) → 1 retry to secondary (429) → exhausted. + # Tertiary should NOT be reached. + assert resp.status_code >= 400, ( + f"Expected error status but got {resp.status_code}" + ) + + assert tertiary.call_count == 0, ( + f"Tertiary should not be called with max_attempts=1, " + f"but was called {tertiary.call_count} times" + ) + + # Total calls: primary (1) + secondary (1 retry) = 2 + total_calls = primary.call_count + secondary.call_count + assert total_calls <= 2, ( + f"Expected at most 2 total calls (1 original + 1 retry), " + f"got {total_calls}" + ) + + finally: + gateway_down() + primary.stop() + secondary.stop() + tertiary.stop() + if config_path and os.path.exists(config_path): + os.unlink(config_path) + + def test_it6_backoff_delay_observed(self): + """ + IT-6: Backoff delay observed. + + Configure same_model strategy with backoff (base_ms: 500, jitter: false). + Primary returns 429 twice, then 200 on third attempt. + Assert total response time includes backoff delays. 
def test_it12_streaming_preserved_across_retry(self):
    """
    IT-12: Streaming request preserved across retry.

    Primary returns 429, secondary returns 200 with SSE streaming.
    Assert client receives a streamed response.
    """
    # Primary always returns 429.
    primary = MockServer([
        (429, {}, make_error_response(429, "Rate limit exceeded")),
    ])
    # Secondary answers with the default SSE chunk sequence. The shared
    # StreamingMockServer wrapper is used rather than a hand-built HTTPServer.
    secondary = StreamingMockServer([
        (200, {}, "STREAM"),
    ])
    primary.start()
    secondary.start()
    config_path = None

    try:
        config_path = write_config("retry_it12_streaming.yaml", {
            "MOCK_PRIMARY_PORT": primary.port,
            "MOCK_SECONDARY_PORT": secondary.port,
        })
        gateway_up(config_path)

        # Same request as the other tests, but asking for a streamed reply.
        streaming_body = {**CHAT_REQUEST_BODY, "stream": True}

        # Context manager: a stream=True response holds its connection open
        # until it is closed explicitly.
        with requests.post(
            GATEWAY_CHAT_URL,
            json=streaming_body,
            headers={"Authorization": "Bearer test-key"},
            stream=True,
            timeout=30,
        ) as resp:
            assert resp.status_code == 200, (
                f"Expected 200 but got {resp.status_code}: {resp.text}"
            )
            # iter_lines(decode_unicode=True) yields bytes when no encoding
            # is known; pin UTF-8 so every line is a str.
            resp.encoding = "utf-8"
            chunks = [
                line for line in resp.iter_lines(decode_unicode=True) if line
            ]

        assert len(chunks) > 0, "Should have received streaming chunks"

        # SSE payload lines carry a "data:" prefix.
        data_chunks = [c for c in chunks if c.startswith("data:")]
        assert len(data_chunks) > 0, (
            f"Expected SSE data chunks, got: {chunks}"
        )

        # At least one event must carry assistant content in its delta.
        content_found = False
        for chunk in data_chunks:
            # Strip the prefix whether or not a space follows the colon.
            payload_text = chunk[len("data:"):].strip()
            if payload_text == "[DONE]":
                continue
            try:
                payload = json.loads(payload_text)
                delta = payload.get("choices", [{}])[0].get("delta", {})
                has_content = bool(delta.get("content"))
            except (json.JSONDecodeError, IndexError, AttributeError):
                # Unparseable or unexpectedly shaped event: ignore it.
                continue
            if has_content:
                content_found = True
                break

        assert content_found, "Should have received content in streaming chunks"

        # Primary should have been called (got 429)
        assert primary.call_count >= 1, "Primary should have been called"

    finally:
        try:
            gateway_down()
        finally:
            # Runs even when gateway_down() raises, so the mock servers and
            # the temp config are not leaked.
            primary.stop()
            secondary.stop()
            if config_path and os.path.exists(config_path):
                os.unlink(config_path)
def test_it7_fallback_models_priority(self):
    """
    IT-7: Fallback models priority.

    Primary mock returns 429, fallback[0] returns 429, fallback[1] returns 200.
    Assert client gets 200 from fallback[1] and that the primary and both
    fallbacks were each contacted.

    Config: fallback_models: [anthropic/claude-3-5-sonnet, mistral/mistral-large]
    """
    rate_limited = (429, {}, make_error_response(429, "Rate limit exceeded"))
    primary = MockServer([rate_limited])
    fallback1 = MockServer([rate_limited])
    fallback2 = MockServer([
        (200, {}, SUCCESS_RESPONSE),
    ])
    primary.start()
    fallback1.start()
    fallback2.start()
    config_path = None

    try:
        config_path = write_config("retry_it7_fallback_priority.yaml", {
            "MOCK_PRIMARY_PORT": primary.port,
            "MOCK_FALLBACK1_PORT": fallback1.port,
            "MOCK_FALLBACK2_PORT": fallback2.port,
        })
        gateway_up(config_path)

        resp = requests.post(
            GATEWAY_CHAT_URL,
            json=CHAT_REQUEST_BODY,
            headers={"Authorization": "Bearer test-key"},
            timeout=30,
        )

        # The client must get the successful answer served by fallback[1].
        assert resp.status_code == 200, (
            f"Expected 200 but got {resp.status_code}: {resp.text}"
        )
        body = resp.json()
        assert "choices" in body
        assert body["choices"][0]["message"]["content"] == "Hello from mock provider!"

        # Every provider in the chain must have been contacted. Call counts
        # cannot prove ordering; they only show each hop was reached.
        assert primary.call_count >= 1, "Primary should have been called first"
        assert fallback1.call_count >= 1, (
            "Fallback[0] (anthropic/claude-3-5-sonnet) should have been tried "
            "before fallback[1]"
        )
        assert fallback2.call_count >= 1, (
            "Fallback[1] (mistral/mistral-large) should have been called"
        )

    finally:
        try:
            gateway_down()
        finally:
            # Runs even when gateway_down() raises, so the mock servers and
            # the temp config are not leaked.
            primary.stop()
            fallback1.stop()
            fallback2.stop()
            if config_path and os.path.exists(config_path):
                os.unlink(config_path)
def test_it9_retry_after_blocks_initial_selection(self):
    """
    IT-9: Retry-After blocks initial selection.

    First request: primary returns 429 + Retry-After: 60 and the gateway
    retries to the secondary (which returns 200).

    Second request (sent within 60s): because the primary is globally
    blocked by the Retry-After state, the gateway should route directly
    to the alternative provider without hitting the primary again.
    """
    # Primary: the first call returns 429 + Retry-After: 60. The 200 entries
    # exist only so an unexpected extra call still gets a valid answer.
    primary = MockServer([
        (429, {"Retry-After": "60"}, make_error_response(429, "Rate limit exceeded")),
        (200, {}, SUCCESS_RESPONSE),
        (200, {}, SUCCESS_RESPONSE),
    ])
    secondary = MockServer([
        (200, {}, SUCCESS_RESPONSE),
        (200, {}, SUCCESS_RESPONSE),
    ])
    primary.start()
    secondary.start()
    config_path = None
    auth_headers = {"Authorization": "Bearer test-key"}

    try:
        config_path = write_config(
            "retry_it9_retry_after_blocks_selection.yaml",
            {
                "MOCK_PRIMARY_PORT": primary.port,
                "MOCK_SECONDARY_PORT": secondary.port,
            },
        )
        gateway_up(config_path)

        # --- First request: triggers the Retry-After state ---
        resp1 = requests.post(
            GATEWAY_CHAT_URL,
            json=CHAT_REQUEST_BODY,
            headers=auth_headers,
            timeout=30,
        )
        assert resp1.status_code == 200, (
            f"First request: expected 200 but got {resp1.status_code}: {resp1.text}"
        )

        # Snapshot the counters so the second request is judged on its own.
        primary_calls_after_first = primary.call_count
        secondary_calls_after_first = secondary.call_count

        assert primary_calls_after_first >= 1, (
            "Primary should have been called for the first request"
        )
        assert secondary_calls_after_first >= 1, (
            "Secondary should have been called as fallback for the first request"
        )

        # --- Second request: still inside the 60s Retry-After window ---
        resp2 = requests.post(
            GATEWAY_CHAT_URL,
            json={
                "model": "openai/gpt-4o",
                "messages": [{"role": "user", "content": "Second request"}],
            },
            headers=auth_headers,
            timeout=30,
        )
        assert resp2.status_code == 200, (
            f"Second request: expected 200 but got {resp2.status_code}: {resp2.text}"
        )

        # The blocked primary must not have received another call.
        assert primary.call_count == primary_calls_after_first, (
            f"Primary should not have been called for the second request "
            f"(blocked by Retry-After: 60). Calls before: "
            f"{primary_calls_after_first}, after: {primary.call_count}"
        )

        # The secondary must have served the second request.
        assert secondary.call_count > secondary_calls_after_first, (
            f"Secondary should have handled the second request. "
            f"Calls before: {secondary_calls_after_first}, "
            f"after: {secondary.call_count}"
        )

    finally:
        try:
            gateway_down()
        finally:
            # Runs even when gateway_down() raises, so the mock servers and
            # the temp config are not leaked.
            primary.stop()
            secondary.stop()
            if config_path and os.path.exists(config_path):
                os.unlink(config_path)
def test_it11_high_latency_proactive_failover(self):
    """
    IT-11: High latency proactive failover.

    First request: primary mock delays response by ~1.5s (threshold_ms=1000
    + 500ms buffer) but completes with 200 OK. The client receives the slow
    200 response (completed responses are always delivered). However, the
    gateway records a Latency_Block_State for the primary model.

    Second request: sent immediately after the first. Because the primary
    is now latency-blocked (block_duration_seconds=60, min_triggers=1),
    the gateway should route directly to the secondary provider.

    Config: on_high_latency with min_triggers: 1, threshold_ms: 1000,
    block_duration_seconds: 60, measure: "total", scope: "model",
    apply_to: "global".
    """
    # Primary: answers 200 after 1.5s, which exceeds the 1000ms threshold.
    # The second queue entry is only a safety net in case the primary is
    # (wrongly) called again.
    primary = DelayedMockServer(
        response_queue=[
            (200, {}, SUCCESS_RESPONSE),
            (200, {}, SUCCESS_RESPONSE),
        ],
        delay_seconds=1.5,
    )
    # Secondary: answers 200 immediately.
    secondary = MockServer([
        (200, {}, SUCCESS_RESPONSE),
        (200, {}, SUCCESS_RESPONSE),
    ])
    primary.start()
    secondary.start()
    config_path = None
    auth_headers = {"Authorization": "Bearer test-key"}

    try:
        config_path = write_config(
            "retry_it11_high_latency_failover.yaml",
            {
                "MOCK_PRIMARY_PORT": primary.port,
                "MOCK_SECONDARY_PORT": secondary.port,
            },
        )
        gateway_up(config_path)

        # --- First request: slow but successful; should arm the block ---
        resp1 = requests.post(
            GATEWAY_CHAT_URL,
            json=CHAT_REQUEST_BODY,
            headers=auth_headers,
            timeout=30,
        )
        assert resp1.status_code == 200, (
            f"First request: expected 200 but got {resp1.status_code}: "
            f"{resp1.text}"
        )

        # Snapshot the counters so the second request is judged on its own.
        primary_calls_after_first = primary.call_count
        secondary_calls_after_first = secondary.call_count

        assert primary_calls_after_first >= 1, (
            "Primary should have been called for the first request"
        )

        # --- Second request: inside the 60s latency block window ---
        resp2 = requests.post(
            GATEWAY_CHAT_URL,
            json={
                "model": "openai/gpt-4o",
                "messages": [{"role": "user", "content": "Second request"}],
            },
            headers=auth_headers,
            timeout=30,
        )
        assert resp2.status_code == 200, (
            f"Second request: expected 200 but got {resp2.status_code}: "
            f"{resp2.text}"
        )

        # The latency-blocked primary must not have received another call.
        assert primary.call_count == primary_calls_after_first, (
            f"Primary should not have been called for the second request "
            f"(latency-blocked for 60s). Calls before: "
            f"{primary_calls_after_first}, after: {primary.call_count}"
        )

        # The secondary must have served the second request.
        assert secondary.call_count > secondary_calls_after_first, (
            f"Secondary should have handled the second request. "
            f"Calls before: {secondary_calls_after_first}, "
            f"after: {secondary.call_count}"
        )

    finally:
        try:
            gateway_down()
        finally:
            # Runs even when gateway_down() raises, so the mock servers and
            # the temp config are not leaked.
            primary.stop()
            secondary.stop()
            if config_path and os.path.exists(config_path):
                os.unlink(config_path)