plano/crates/common/src/retry/mod.rs
Troy Mitchell 5a2d0aa52e retry: add core retry types and module structure
Add the retry module with core type definitions including:
- RequestContext, RequestSignature for request deduplication
- RetryExhaustedError, AllProvidersExhaustedError for error handling
- AttemptError, AttemptErrorType for attempt tracking
- ValidationError, ValidationWarning for config validation
- Helper functions for provider extraction and hashing

Wire up pub mod retry in lib.rs.

Signed-off-by: Troy Mitchell <i@troy-y.org>
2026-04-28 17:05:03 +08:00

804 lines
31 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use std::time::Instant;
use bytes::Bytes;
use hyper::HeaderMap;
use sha2::{Digest, Sha256};
use crate::configuration::{ApplyTo, LlmProvider, LlmProviderType};
// Sub-modules
pub mod backoff;
pub mod error_detector;
pub mod error_response;
pub mod latency_block_state;
pub mod latency_trigger;
pub mod orchestrator;
pub mod provider_selector;
pub mod retry_after_state;
pub mod validation;
// ── State Structs ──────────────────────────────────────────────────────────
/// In-memory Retry-After state entry.
#[derive(Debug, Clone)]
pub struct RetryAfterEntry {
pub identifier: String,
pub expires_at: Instant,
pub apply_to: ApplyTo,
}
/// In-memory Latency Block state entry.
#[derive(Debug, Clone)]
pub struct LatencyBlockEntry {
pub identifier: String,
pub expires_at: Instant,
pub measured_latency_ms: u64,
pub apply_to: ApplyTo,
}
/// Error accumulated from a single attempt.
#[derive(Debug, Clone)]
pub struct AttemptError {
pub model_id: String,
pub error_type: AttemptErrorType,
pub attempt_number: u32,
}
#[derive(Debug, Clone)]
pub enum AttemptErrorType {
HttpError { status_code: u16, body: Vec<u8> },
Timeout { duration_ms: u64 },
HighLatency { measured_ms: u64, threshold_ms: u64 },
}
/// Lightweight request signature for retry tracking.
/// The actual request body bytes are passed by reference from the handler scope
/// (as `&Bytes`) rather than cloned into this struct.
#[derive(Debug, Clone)]
pub struct RequestSignature {
/// SHA-256 hash of the original request body
pub body_hash: [u8; 32],
pub headers: HeaderMap,
pub streaming: bool,
pub original_model: String,
}
impl RequestSignature {
pub fn new(body: &[u8], headers: &HeaderMap, streaming: bool, original_model: String) -> Self {
let mut hasher = Sha256::new();
hasher.update(body);
let hash: [u8; 32] = hasher.finalize().into();
Self {
body_hash: hash,
headers: headers.clone(),
streaming,
original_model,
}
}
}
// ── Auth Header Constants ───────────────────────────────────────────────────
/// Headers that carry authentication credentials and must be sanitized
/// when forwarding requests to a different provider.
const AUTH_HEADERS: &[&str] = &["authorization", "x-api-key"];
/// Additional provider-specific headers that should be sanitized.
const PROVIDER_SPECIFIC_HEADERS: &[&str] = &["anthropic-version"];
/// Rebuild a request for a different target provider.
///
/// Updates the `model` field in the JSON body to match the target provider's
/// model name (without provider prefix), and applies the correct auth
/// credentials for the target provider. Sanitizes auth headers from the
/// original request to prevent credential leakage across providers.
///
/// Returns the updated body bytes and headers, or an error if the body
/// cannot be parsed as JSON.
pub fn rebuild_request_for_provider(
body: &Bytes,
target_provider: &LlmProvider,
original_headers: &HeaderMap,
) -> Result<(Bytes, HeaderMap), RebuildError> {
// Update the model field in the JSON body
let mut json_body: serde_json::Value =
serde_json::from_slice(body).map_err(|e| RebuildError::InvalidJson(e.to_string()))?;
// Extract model name without provider prefix (e.g., "openai/gpt-4o" -> "gpt-4o")
let target_model = target_provider
.model
.as_deref()
.or(Some(&target_provider.name))
.unwrap_or(&target_provider.name);
let model_name_only = if let Some((_, model)) = target_model.split_once('/') {
model
} else {
target_model
};
if let Some(obj) = json_body.as_object_mut() {
obj.insert(
"model".to_string(),
serde_json::Value::String(model_name_only.to_string()),
);
}
let updated_body = Bytes::from(
serde_json::to_vec(&json_body).map_err(|e| RebuildError::InvalidJson(e.to_string()))?,
);
// Sanitize and rebuild headers
let mut headers = sanitize_headers(original_headers);
apply_auth_headers(&mut headers, target_provider)?;
Ok((updated_body, headers))
}
/// Remove auth-related headers from the original request to prevent
/// credential leakage when forwarding to a different provider.
fn sanitize_headers(original: &HeaderMap) -> HeaderMap {
let mut headers = original.clone();
for header_name in AUTH_HEADERS.iter().chain(PROVIDER_SPECIFIC_HEADERS.iter()) {
headers.remove(*header_name);
}
headers
}
/// Apply the correct auth headers for the target provider.
fn apply_auth_headers(headers: &mut HeaderMap, provider: &LlmProvider) -> Result<(), RebuildError> {
// If passthrough_auth is enabled, don't set provider credentials
if provider.passthrough_auth == Some(true) {
return Ok(());
}
let access_key = provider
.access_key
.as_ref()
.ok_or_else(|| RebuildError::MissingAccessKey(provider.name.clone()))?;
match provider.provider_interface {
LlmProviderType::Anthropic => {
headers.insert(
hyper::header::HeaderName::from_static("x-api-key"),
hyper::header::HeaderValue::from_str(access_key)
.map_err(|_| RebuildError::InvalidHeaderValue("x-api-key".to_string()))?,
);
headers.insert(
hyper::header::HeaderName::from_static("anthropic-version"),
hyper::header::HeaderValue::from_static("2023-06-01"),
);
}
_ => {
// OpenAI-compatible providers use Authorization: Bearer <key>
let bearer = format!("Bearer {}", access_key);
headers.insert(
hyper::header::AUTHORIZATION,
hyper::header::HeaderValue::from_str(&bearer)
.map_err(|_| RebuildError::InvalidHeaderValue("authorization".to_string()))?,
);
}
}
Ok(())
}
/// Errors that can occur when rebuilding a request for a different provider.
#[derive(Debug, Clone, PartialEq)]
pub enum RebuildError {
/// The request body is not valid JSON.
InvalidJson(String),
/// The target provider has no access_key configured.
MissingAccessKey(String),
/// A header value could not be constructed.
InvalidHeaderValue(String),
}
impl std::fmt::Display for RebuildError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
RebuildError::InvalidJson(e) => write!(f, "invalid JSON body: {}", e),
RebuildError::MissingAccessKey(name) => {
write!(f, "no access key configured for provider '{}'", name)
}
RebuildError::InvalidHeaderValue(header) => {
write!(f, "invalid header value for '{}'", header)
}
}
}
}
impl std::error::Error for RebuildError {}
/// Extended request context for retry tracking.
#[derive(Debug)]
pub struct RequestContext {
pub request_id: String,
pub attempted_providers: HashSet<String>,
pub retry_start_time: Option<Instant>,
pub attempt_number: u32,
/// Request-scoped Retry_After_State (when apply_to: "request")
pub request_retry_after_state: HashMap<String, Instant>,
/// Request-scoped Latency_Block_State (when apply_to: "request")
pub request_latency_block_state: HashMap<String, Instant>,
/// Request signature for tracking
pub request_signature: RequestSignature,
/// Accumulated errors from all attempts
pub errors: Vec<AttemptError>,
}
/// Bounded semaphore controlling the maximum number of concurrent in-flight
/// retry operations. Prevents OOM under high load by rejecting new retry
/// attempts when the limit is reached (fail-open: original request proceeds
/// without retry).
pub struct RetryGate {
pub semaphore: Arc<tokio::sync::Semaphore>,
}
impl RetryGate {
const DEFAULT_MAX_IN_FLIGHT: usize = 1000;
pub fn new(max_in_flight_retries: usize) -> Self {
Self {
semaphore: Arc::new(tokio::sync::Semaphore::new(max_in_flight_retries)),
}
}
pub fn try_acquire(&self) -> Option<tokio::sync::OwnedSemaphorePermit> {
self.semaphore.clone().try_acquire_owned().ok()
}
}
impl Default for RetryGate {
fn default() -> Self {
Self::new(Self::DEFAULT_MAX_IN_FLIGHT)
}
}
// ── Error Types ────────────────────────────────────────────────────────────
/// All retry attempts exhausted for a single provider's retry sequence.
#[derive(Debug)]
pub struct RetryExhaustedError {
/// All attempt errors accumulated during the retry sequence.
pub attempts: Vec<AttemptError>,
/// Maximum Retry-After value observed across all attempts (if any).
pub max_retry_after_seconds: Option<u64>,
/// Shortest remaining block duration among blocked candidates at exhaustion time.
pub shortest_remaining_block_seconds: Option<u64>,
/// Whether the retry budget (max_retry_duration_ms) was exceeded.
pub retry_budget_exhausted: bool,
}
/// All providers (including fallbacks) exhausted.
#[derive(Debug)]
pub struct AllProvidersExhaustedError {
/// Shortest remaining block duration among blocked candidates.
pub shortest_remaining_block_seconds: Option<u64>,
}
// ── Validation Types ───────────────────────────────────────────────────────
/// Configuration validation errors that prevent gateway startup.
#[derive(Debug, Clone, PartialEq)]
pub enum ValidationError {
/// Backoff section present without required `apply_to` field.
BackoffMissingApplyTo { model: String },
/// `min_triggers > 1` without `trigger_window_seconds`.
LatencyMissingTriggerWindow { model: String },
/// Invalid strategy value.
InvalidStrategy { model: String, value: String },
/// Invalid `apply_to` value.
InvalidApplyTo { model: String, value: String },
/// Invalid `scope` value.
InvalidScope { model: String, value: String },
/// Status code outside 100599.
StatusCodeOutOfRange { model: String, code: u16 },
/// Range with start > end.
StatusCodeRangeInverted { model: String, range: String },
/// Invalid status code range format.
StatusCodeRangeInvalid { model: String, range: String },
/// `threshold_ms`, `block_duration_seconds`, `max_retry_after_seconds`,
/// `max_retry_duration_ms`, or `base_ms` not positive.
NonPositiveValue { model: String, field: String },
/// `trigger_window_seconds` not positive when specified.
NonPositiveTriggerWindow { model: String },
/// `max_ms` ≤ `base_ms` in backoff config.
MaxMsNotGreaterThanBaseMs {
model: String,
base_ms: u64,
max_ms: u64,
},
/// `max_attempts` is negative (represented as u32, so this catches zero if needed).
InvalidMaxAttempts { model: String, value: String },
/// Fallback model string is empty or doesn't contain a "/" separator.
InvalidFallbackModel { model: String, fallback: String },
}
/// Configuration validation warnings (gateway starts, warning logged).
#[derive(Debug, Clone, PartialEq)]
pub enum ValidationWarning {
/// Single provider with failover strategy.
SingleProviderWithFailover { model: String, strategy: String },
/// Provider-scope Retry-After with same_model strategy.
ProviderScopeWithSameModel { model: String },
/// Backoff apply_to mismatch with default strategy.
BackoffApplyToMismatch {
model: String,
apply_to: String,
strategy: String,
},
/// Latency scope/strategy mismatch.
LatencyScopeStrategyMismatch { model: String },
/// Aggressive latency threshold (< 1000ms).
AggressiveLatencyThreshold { model: String, threshold_ms: u64 },
/// Fallback model not in Provider_List.
FallbackModelNotInProviderList { model: String, fallback: String },
/// Overlapping status codes across on_status_codes entries.
OverlappingStatusCodes { model: String, code: u16 },
}
#[cfg(test)]
mod tests {
use super::*;
use crate::configuration::{LlmProvider, LlmProviderType};
use bytes::Bytes;
use hyper::header::{HeaderMap, HeaderValue, AUTHORIZATION};
use proptest::prelude::*;
fn make_provider(name: &str, interface: LlmProviderType, key: Option<&str>) -> LlmProvider {
LlmProvider {
name: name.to_string(),
provider_interface: interface,
access_key: key.map(|k| k.to_string()),
model: Some(name.to_string()),
default: None,
stream: None,
endpoint: None,
port: None,
rate_limits: None,
usage: None,
cluster_name: None,
base_url_path_prefix: None,
internal: None,
passthrough_auth: None,
retry_policy: None,
headers: None,
}
}
// ── RequestSignature tests ─────────────────────────────────────────
#[test]
fn test_request_signature_computes_hash() {
let body = b"hello world";
let headers = HeaderMap::new();
let sig = RequestSignature::new(body, &headers, false, "openai/gpt-4o".to_string());
// SHA-256 of "hello world" is deterministic
let mut hasher = Sha256::new();
hasher.update(b"hello world");
let expected: [u8; 32] = hasher.finalize().into();
assert_eq!(sig.body_hash, expected);
assert!(!sig.streaming);
assert_eq!(sig.original_model, "openai/gpt-4o");
}
#[test]
fn test_request_signature_preserves_headers() {
let mut headers = HeaderMap::new();
headers.insert("x-custom", HeaderValue::from_static("value"));
let sig = RequestSignature::new(b"body", &headers, true, "model".to_string());
assert_eq!(sig.headers.get("x-custom").unwrap(), "value");
assert!(sig.streaming);
}
#[test]
fn test_request_signature_different_bodies_different_hashes() {
let headers = HeaderMap::new();
let sig1 = RequestSignature::new(b"body1", &headers, false, "m".to_string());
let sig2 = RequestSignature::new(b"body2", &headers, false, "m".to_string());
assert_ne!(sig1.body_hash, sig2.body_hash);
}
// ── RetryGate tests ────────────────────────────────────────────────
#[test]
fn test_retry_gate_default_permits() {
let gate = RetryGate::default();
// Should be able to acquire at least one permit
assert!(gate.try_acquire().is_some());
}
#[test]
fn test_retry_gate_exhaustion() {
let gate = RetryGate::new(1);
let permit = gate.try_acquire();
assert!(permit.is_some());
// Second acquire should fail (only 1 permit)
assert!(gate.try_acquire().is_none());
// Drop permit, should be able to acquire again
drop(permit);
assert!(gate.try_acquire().is_some());
}
#[test]
fn test_retry_gate_custom_capacity() {
let gate = RetryGate::new(3);
let _p1 = gate.try_acquire().unwrap();
let _p2 = gate.try_acquire().unwrap();
let _p3 = gate.try_acquire().unwrap();
assert!(gate.try_acquire().is_none());
}
// ── rebuild_request_for_provider tests ─────────────────────────────
#[test]
fn test_rebuild_updates_model_field() {
let body = Bytes::from(r#"{"model":"gpt-4o","messages":[]}"#);
let headers = HeaderMap::new();
let provider = make_provider(
"openai/gpt-4o-mini",
LlmProviderType::OpenAI,
Some("sk-test"),
);
let (new_body, _) = rebuild_request_for_provider(&body, &provider, &headers).unwrap();
let json: serde_json::Value = serde_json::from_slice(&new_body).unwrap();
assert_eq!(json["model"], "gpt-4o-mini");
}
#[test]
fn test_rebuild_preserves_other_fields() {
let body = Bytes::from(
r#"{"model":"gpt-4o","messages":[{"role":"user","content":"hi"}],"temperature":0.7}"#,
);
let headers = HeaderMap::new();
let provider = make_provider(
"openai/gpt-4o-mini",
LlmProviderType::OpenAI,
Some("sk-test"),
);
let (new_body, _) = rebuild_request_for_provider(&body, &provider, &headers).unwrap();
let json: serde_json::Value = serde_json::from_slice(&new_body).unwrap();
assert_eq!(json["messages"][0]["role"], "user");
assert_eq!(json["messages"][0]["content"], "hi");
assert_eq!(json["temperature"], 0.7);
}
#[test]
fn test_rebuild_sets_openai_auth() {
let body = Bytes::from(r#"{"model":"old"}"#);
let mut headers = HeaderMap::new();
headers.insert(AUTHORIZATION, HeaderValue::from_static("Bearer old-key"));
let provider = make_provider("openai/gpt-4o", LlmProviderType::OpenAI, Some("sk-new"));
let (_, new_headers) = rebuild_request_for_provider(&body, &provider, &headers).unwrap();
assert_eq!(
new_headers.get(AUTHORIZATION).unwrap().to_str().unwrap(),
"Bearer sk-new"
);
assert!(new_headers.get("x-api-key").is_none());
}
#[test]
fn test_rebuild_sets_anthropic_auth() {
let body = Bytes::from(r#"{"model":"old"}"#);
let mut headers = HeaderMap::new();
headers.insert(AUTHORIZATION, HeaderValue::from_static("Bearer old-key"));
let provider = make_provider(
"anthropic/claude-3-5-sonnet",
LlmProviderType::Anthropic,
Some("ant-key"),
);
let (_, new_headers) = rebuild_request_for_provider(&body, &provider, &headers).unwrap();
// Anthropic uses x-api-key, not Authorization
assert!(new_headers.get(AUTHORIZATION).is_none());
assert_eq!(
new_headers.get("x-api-key").unwrap().to_str().unwrap(),
"ant-key"
);
assert_eq!(
new_headers
.get("anthropic-version")
.unwrap()
.to_str()
.unwrap(),
"2023-06-01"
);
}
#[test]
fn test_rebuild_sanitizes_old_auth_headers() {
let body = Bytes::from(r#"{"model":"old"}"#);
let mut headers = HeaderMap::new();
headers.insert(AUTHORIZATION, HeaderValue::from_static("Bearer old-key"));
headers.insert("x-api-key", HeaderValue::from_static("old-api-key"));
headers.insert("anthropic-version", HeaderValue::from_static("old-version"));
headers.insert("x-custom", HeaderValue::from_static("keep-me"));
let provider = make_provider("openai/gpt-4o", LlmProviderType::OpenAI, Some("sk-new"));
let (_, new_headers) = rebuild_request_for_provider(&body, &provider, &headers).unwrap();
// Old x-api-key and anthropic-version should be removed
assert!(new_headers.get("anthropic-version").is_none());
// New auth should be set
assert_eq!(
new_headers.get(AUTHORIZATION).unwrap().to_str().unwrap(),
"Bearer sk-new"
);
// Custom headers preserved
assert_eq!(
new_headers.get("x-custom").unwrap().to_str().unwrap(),
"keep-me"
);
}
#[test]
fn test_rebuild_passthrough_auth_skips_credentials() {
let body = Bytes::from(r#"{"model":"old"}"#);
let mut headers = HeaderMap::new();
headers.insert(AUTHORIZATION, HeaderValue::from_static("Bearer client-key"));
let mut provider = make_provider("openai/gpt-4o", LlmProviderType::OpenAI, Some("sk-new"));
provider.passthrough_auth = Some(true);
let (_, new_headers) = rebuild_request_for_provider(&body, &provider, &headers).unwrap();
// Auth headers are sanitized, and passthrough_auth means no new ones are set
assert!(new_headers.get(AUTHORIZATION).is_none());
}
#[test]
fn test_rebuild_missing_access_key_errors() {
let body = Bytes::from(r#"{"model":"old"}"#);
let headers = HeaderMap::new();
let provider = make_provider("openai/gpt-4o", LlmProviderType::OpenAI, None);
let result = rebuild_request_for_provider(&body, &provider, &headers);
assert!(matches!(result, Err(RebuildError::MissingAccessKey(_))));
}
#[test]
fn test_rebuild_invalid_json_errors() {
let body = Bytes::from("not json");
let headers = HeaderMap::new();
let provider = make_provider("openai/gpt-4o", LlmProviderType::OpenAI, Some("key"));
let result = rebuild_request_for_provider(&body, &provider, &headers);
assert!(matches!(result, Err(RebuildError::InvalidJson(_))));
}
#[test]
fn test_rebuild_model_without_provider_prefix() {
let body = Bytes::from(r#"{"model":"old"}"#);
let headers = HeaderMap::new();
let mut provider = make_provider("gpt-4o", LlmProviderType::OpenAI, Some("key"));
provider.model = Some("gpt-4o".to_string());
let (new_body, _) = rebuild_request_for_provider(&body, &provider, &headers).unwrap();
let json: serde_json::Value = serde_json::from_slice(&new_body).unwrap();
// No prefix to strip, model name used as-is
assert_eq!(json["model"], "gpt-4o");
}
// --- Proptest strategies ---
fn arb_provider_type() -> impl Strategy<Value = LlmProviderType> {
prop_oneof![
Just(LlmProviderType::OpenAI),
Just(LlmProviderType::Anthropic),
Just(LlmProviderType::Gemini),
Just(LlmProviderType::Deepseek),
]
}
fn arb_model_name() -> impl Strategy<Value = String> {
prop_oneof![
Just("openai/gpt-4o".to_string()),
Just("openai/gpt-4o-mini".to_string()),
Just("anthropic/claude-3-5-sonnet".to_string()),
Just("gemini/gemini-pro".to_string()),
Just("deepseek/deepseek-chat".to_string()),
]
}
fn arb_target_provider() -> impl Strategy<Value = LlmProvider> {
(arb_model_name(), arb_provider_type())
.prop_map(|(model, iface)| make_provider(&model, iface, Some("test-key-123")))
}
fn arb_message_content() -> impl Strategy<Value = String> {
"[a-zA-Z0-9 ]{1,50}"
}
fn arb_messages() -> impl Strategy<Value = Vec<serde_json::Value>> {
prop::collection::vec(
(
prop_oneof![Just("user"), Just("assistant"), Just("system")],
arb_message_content(),
)
.prop_map(|(role, content)| serde_json::json!({"role": role, "content": content})),
1..5,
)
}
fn arb_json_body() -> impl Strategy<Value = serde_json::Value> {
(
arb_model_name(),
arb_messages(),
prop::option::of(0.0f64..2.0),
prop::option::of(1u32..4096),
proptest::bool::ANY,
)
.prop_map(|(model, messages, temperature, max_tokens, stream)| {
let model_only = model.split('/').nth(1).unwrap_or(&model);
let mut obj = serde_json::json!({
"model": model_only,
"messages": messages,
});
if let Some(t) = temperature {
obj["temperature"] = serde_json::json!(t);
}
if let Some(mt) = max_tokens {
obj["max_tokens"] = serde_json::json!(mt);
}
if stream {
obj["stream"] = serde_json::json!(true);
}
obj
})
}
fn arb_custom_headers() -> impl Strategy<Value = Vec<(String, String)>> {
prop::collection::vec(
(
prop_oneof![
Just("x-request-id".to_string()),
Just("x-custom-header".to_string()),
Just("x-trace-id".to_string()),
Just("content-type".to_string()),
],
"[a-zA-Z0-9-]{1,30}",
),
0..4,
)
}
// Feature: retry-on-ratelimit, Property 14: Request Preservation Across Retries
// **Validates: Requirements 5.1, 5.2, 5.3, 5.4, 5.5, 3.15**
proptest! {
#![proptest_config(ProptestConfig::with_cases(100))]
/// Property 14 The original body bytes are unchanged after rebuild (body is passed by reference).
/// The rebuilt body has the model field updated to the target provider's model.
/// All other JSON fields are preserved. The RequestSignature hash matches the original body hash.
/// Custom headers are preserved while auth headers are sanitized.
#[test]
fn prop_request_preservation_across_retries(
json_body in arb_json_body(),
custom_headers in arb_custom_headers(),
streaming in proptest::bool::ANY,
target_provider in arb_target_provider(),
) {
let body_bytes = serde_json::to_vec(&json_body).unwrap();
let body = Bytes::from(body_bytes.clone());
// Build original headers with custom + auth headers
let mut original_headers = HeaderMap::new();
for (name, value) in &custom_headers {
if let (Ok(hn), Ok(hv)) = (
hyper::header::HeaderName::from_bytes(name.as_bytes()),
HeaderValue::from_str(value),
) {
original_headers.insert(hn, hv);
}
}
// Add auth headers that should be sanitized
original_headers.insert(AUTHORIZATION, HeaderValue::from_static("Bearer old-secret"));
original_headers.insert("x-api-key", HeaderValue::from_static("old-api-key"));
let original_model = json_body["model"].as_str().unwrap_or("unknown").to_string();
// Create RequestSignature from original body
let sig = RequestSignature::new(&body, &original_headers, streaming, original_model.clone());
// Assert: body bytes are unchanged (passed by reference, not modified)
prop_assert_eq!(&body[..], &body_bytes[..], "Original body bytes must be unchanged");
// Assert: RequestSignature hash matches a fresh hash of the same body
let mut hasher = Sha256::new();
hasher.update(&body);
let expected_hash: [u8; 32] = hasher.finalize().into();
prop_assert_eq!(sig.body_hash, expected_hash, "RequestSignature hash must match original body hash");
// Assert: streaming flag preserved
prop_assert_eq!(sig.streaming, streaming, "Streaming flag must be preserved in signature");
// Rebuild for target provider
let result = rebuild_request_for_provider(&body, &target_provider, &original_headers);
prop_assert!(result.is_ok(), "rebuild_request_for_provider should succeed for valid JSON body");
let (rebuilt_body, rebuilt_headers) = result.unwrap();
// Parse rebuilt body
let rebuilt_json: serde_json::Value = serde_json::from_slice(&rebuilt_body).unwrap();
// Assert: model field updated to target provider's model (without prefix)
let target_model = target_provider.model.as_deref().unwrap_or(&target_provider.name);
let expected_model = target_model.split_once('/').map(|(_, m)| m).unwrap_or(target_model);
prop_assert_eq!(
rebuilt_json["model"].as_str().unwrap(),
expected_model,
"Model field must be updated to target provider's model"
);
// Assert: messages array preserved
prop_assert_eq!(
&rebuilt_json["messages"],
&json_body["messages"],
"Messages array must be preserved across rebuild"
);
// Assert: other JSON fields preserved (temperature, max_tokens, stream)
// The rebuild function does a JSON round-trip (deserialize → modify model → serialize),
// so we compare against a round-tripped version of the original to account for
// any f64 precision changes inherent to JSON serialization.
let original_round_tripped: serde_json::Value = serde_json::from_slice(
&serde_json::to_vec(&json_body).unwrap()
).unwrap();
for key in ["temperature", "max_tokens", "stream"] {
if let Some(original_val) = original_round_tripped.get(key) {
prop_assert_eq!(
&rebuilt_json[key],
original_val,
"Field '{}' must be preserved across rebuild",
key
);
}
}
// Assert: custom headers preserved (non-auth headers)
// Note: HeaderMap::insert overwrites, so only the last value for each name survives
let mut last_custom: std::collections::HashMap<String, String> = std::collections::HashMap::new();
for (name, value) in &custom_headers {
let lower = name.to_lowercase();
if lower == "authorization" || lower == "x-api-key" || lower == "anthropic-version" {
continue;
}
last_custom.insert(lower, value.clone());
}
for (name, value) in &last_custom {
if let Some(hv) = rebuilt_headers.get(name.as_str()) {
prop_assert_eq!(
hv.to_str().unwrap(),
value.as_str(),
"Custom header '{}' must be preserved",
name
);
}
}
// Assert: old auth headers are sanitized (not leaked to target provider)
// The old "Bearer old-secret" and "old-api-key" should NOT appear
if let Some(auth) = rebuilt_headers.get(AUTHORIZATION) {
prop_assert_ne!(
auth.to_str().unwrap(),
"Bearer old-secret",
"Old authorization header must be sanitized"
);
}
if let Some(api_key) = rebuilt_headers.get("x-api-key") {
prop_assert_ne!(
api_key.to_str().unwrap(),
"old-api-key",
"Old x-api-key header must be sanitized"
);
}
// Assert: original body is still unchanged after rebuild
prop_assert_eq!(&body[..], &body_bytes[..], "Original body bytes must remain unchanged after rebuild");
}
}
}