From c34ff5b5fd6ecdae479337aed29a2b728a54cdf0 Mon Sep 17 00:00:00 2001 From: Troy Mitchell Date: Tue, 28 Apr 2026 16:19:41 +0800 Subject: [PATCH] feat: preserve original JSON bytes for prompt cache compatibility MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Avoid re-serializing request bodies when unnecessary to maintain JSON key order, whitespace, and unknown fields — critical for prompt cache prefix matching on providers like Anthropic. - routing_service: only re-serialize when routing_preferences were actually removed from the body - stream_context: replace model name at byte level instead of full deserialization/re-serialization cycle - Strip provider prefix from model name (e.g. 'custom-aws/claude-opus-4-6' -> 'claude-opus-4-6') before sending upstream Signed-off-by: Troy Mitchell --- .../src/handlers/routing_service.rs | 9 +- crates/llm_gateway/src/stream_context.rs | 232 ++++++++++++------ 2 files changed, 170 insertions(+), 71 deletions(-) diff --git a/crates/brightstaff/src/handlers/routing_service.rs b/crates/brightstaff/src/handlers/routing_service.rs index b93b1422..ed724620 100644 --- a/crates/brightstaff/src/handlers/routing_service.rs +++ b/crates/brightstaff/src/handlers/routing_service.rs @@ -45,7 +45,14 @@ pub fn extract_routing_policy( }, ); - let bytes = Bytes::from(serde_json::to_vec(&json_body).unwrap()); + // Only re-serialize if we actually removed routing_preferences. + // Otherwise preserve the original bytes to maintain JSON key order, + // whitespace, and unknown fields — critical for prompt cache hits. 
+ let bytes = if routing_preferences.is_some() { + Bytes::from(serde_json::to_vec(&json_body).unwrap()) + } else { + Bytes::from(raw_bytes.to_vec()) + }; Ok((bytes, routing_preferences)) } diff --git a/crates/llm_gateway/src/stream_context.rs b/crates/llm_gateway/src/stream_context.rs index fa9964dd..252a2511 100644 --- a/crates/llm_gateway/src/stream_context.rs +++ b/crates/llm_gateway/src/stream_context.rs @@ -1023,8 +1023,15 @@ impl HttpContext for StreamContext { } }; - // Set the resolved model using the trait method - deserialized_client_request.set_model(resolved_model.clone()); + // Set the resolved model using the trait method. + // Strip provider prefix (e.g., "custom-aws/claude-opus-4-6" -> "claude-opus-4-6") + // so the upstream API receives only the model name it recognizes. + let upstream_model = if let Some((_prefix, model_only)) = resolved_model.split_once('/') { + model_only.to_string() + } else { + resolved_model.clone() + }; + deserialized_client_request.set_model(upstream_model.clone()); // Extract user message for tracing self.user_message = deserialized_client_request.get_recent_user_message(); @@ -1056,82 +1063,93 @@ impl HttpContext for StreamContext { return Action::Continue; } - // Convert chat completion request to llm provider specific request using provider interface - let serialized_body_bytes_upstream = match self.resolved_api.as_ref() { - Some(upstream) => { - info!( - "request_id={}: upstream transform, client_api={:?} -> upstream_api={:?}", - self.request_identifier(), - self.client_api, - upstream - ); - - match ProviderRequestType::try_from((deserialized_client_request, upstream)) { - Ok(mut request) => { - if let Err(e) = - request.normalize_for_upstream(self.get_provider_id(), upstream) - { - warn!( - "request_id={}: normalize_for_upstream failed: {}", - self.request_identifier(), - e - ); + // Preserve original body bytes for prompt cache compatibility. 
+ // Only replace the "model" field value at the byte level instead of + // deserializing + re-serializing, which destroys key order, whitespace, + // and unknown fields — breaking prompt cache prefix matching. + // Use upstream_model (prefix-stripped) so the upstream API receives + // only the model name it recognizes. + let original_model = model_requested.as_str(); + let serialized_body_bytes_upstream = if original_model != upstream_model.as_str() { + match replace_json_model_value(&body_bytes, original_model, &upstream_model) { + Some(patched) => { + debug!( + "request_id={}: byte-level model replacement '{}' -> '{}'", + self.request_identifier(), + original_model, + upstream_model + ); + patched + } + None => { + // Fallback: full re-serialization if byte-level replacement fails + warn!( + "request_id={}: byte-level model replacement failed, falling back to re-serialization", + self.request_identifier() + ); + match self.resolved_api.as_ref() { + Some(upstream) => { + match ProviderRequestType::try_from(( + deserialized_client_request, + upstream, + )) { + Ok(mut request) => { + if let Err(e) = request + .normalize_for_upstream(self.get_provider_id(), upstream) + { + warn!( + "request_id={}: normalize_for_upstream failed: {}", + self.request_identifier(), + e + ); + self.send_server_error( + ServerError::LogicError(e.message), + Some(StatusCode::BAD_REQUEST), + ); + return Action::Pause; + } + match request.to_bytes() { + Ok(bytes) => bytes, + Err(e) => { + self.send_server_error( + ServerError::LogicError(format!( + "Request serialization error: {}", + e + )), + Some(StatusCode::BAD_REQUEST), + ); + return Action::Pause; + } + } + } + Err(e) => { + self.send_server_error( + ServerError::LogicError(format!( + "Provider request error: {}", + e + )), + Some(StatusCode::BAD_REQUEST), + ); + return Action::Pause; + } + } + } + None => { self.send_server_error( - ServerError::LogicError(e.message), + ServerError::LogicError("No upstream API resolved".into()), 
Some(StatusCode::BAD_REQUEST), ); return Action::Pause; } - debug!( - "request_id={}: upstream request payload: {}", - self.request_identifier(), - String::from_utf8_lossy(&request.to_bytes().unwrap_or_default()) - ); - - match request.to_bytes() { - Ok(bytes) => bytes, - Err(e) => { - warn!( - "request_id={}: failed to serialize request body: {}", - self.request_identifier(), - e - ); - self.send_server_error( - ServerError::LogicError(format!( - "Request serialization error: {}", - e - )), - Some(StatusCode::BAD_REQUEST), - ); - return Action::Pause; - } - } - } - Err(e) => { - warn!( - "request_id={}: failed to create provider request: {}", - self.request_identifier(), - e - ); - self.send_server_error( - ServerError::LogicError(format!("Provider request error: {}", e)), - Some(StatusCode::BAD_REQUEST), - ); - return Action::Pause; } } } - None => { - warn!( - "request_id={}: no upstream api resolved", - self.request_identifier() - ); - self.send_server_error( - ServerError::LogicError("No upstream API resolved".into()), - Some(StatusCode::BAD_REQUEST), - ); - return Action::Pause; - } + } else { + debug!( + "request_id={}: model unchanged, passing original body through", + self.request_identifier() + ); + body_bytes.clone() }; self.set_http_request_body(0, body_size, &serialized_body_bytes_upstream); @@ -1260,6 +1278,80 @@ impl HttpContext for StreamContext { } } +/// Replace the value of the top-level `"model"` key in a JSON byte slice +/// without re-serializing. Returns `Some(new_bytes)` on success, `None` if the +/// pattern wasn't found (caller should fall back to full re-serialization). +/// +/// This is intentionally simple and does NOT use regex (unavailable in WASM). +/// It scans for `"model"` followed by `:` and a quoted string value, then +/// splices in the new model name. Works for the common case where model values +/// are simple strings like `"gpt-4o"` without JSON escapes. 
/// Replace the string value of the top-level `"model"` key in a JSON byte
/// slice without re-serializing the document.
///
/// Every byte outside the model value is copied through untouched, so key
/// order, whitespace and unknown fields survive.
///
/// Returns `Some(patched_body)` on success. Returns `None` (the caller is
/// expected to fall back to full re-serialization) when:
/// - `new_model` contains a byte that would need JSON escaping,
/// - the root value has no literally spelled `"model"` key, or that key's
///   value is not a string,
/// - the raw bytes of the current value differ from `old_model`,
/// - the body is truncated or its brackets are unbalanced.
///
/// No regex is used on purpose (unavailable in the WASM build).
fn replace_json_model_value(body: &[u8], old_model: &str, new_model: &str) -> Option<Vec<u8>> {
    const QUOTED_MODEL_KEY: &[u8] = b"\"model\"";

    // The replacement is spliced in as raw bytes, so it must already be a
    // valid JSON string body: no quotes, backslashes or control characters.
    if new_model.bytes().any(|b| b == b'"' || b == b'\\' || b < 0x20) {
        return None;
    }

    // Cheap rejection: without the quoted key anywhere in the body there is
    // nothing to patch and the structural scan can be skipped.
    find_bytes(body, QUOTED_MODEL_KEY)?;

    let (value_start, value_end) = top_level_model_value_span(body)?;

    // Only patch when the value is exactly what the caller believes it is;
    // the comparison is on raw bytes, so escaped spellings do not match.
    if &body[value_start..value_end] != old_model.as_bytes() {
        return None;
    }

    let mut patched = Vec::with_capacity(body.len() - (value_end - value_start) + new_model.len());
    patched.extend_from_slice(&body[..value_start]);
    patched.extend_from_slice(new_model.as_bytes());
    patched.extend_from_slice(&body[value_end..]);
    Some(patched)
}

/// Locate the content (between the quotes) of the string value that belongs
/// to the `"model"` key at nesting depth 1.
///
/// Returns the half-open byte range `(start, end)` of the value content, or
/// `None` if there is no such key, its value is not a string, a string is
/// unterminated, or a closing bracket appears without a matching opener.
fn top_level_model_value_span(body: &[u8]) -> Option<(usize, usize)> {
    const MODEL_KEY: &[u8] = b"model";

    let mut depth = 0usize;
    let mut pos = 0;
    while let Some(&byte) = body.get(pos) {
        match byte {
            b'"' => {
                // Strings are skipped as a unit so that brackets or quoted
                // text inside them never affect depth tracking.
                let after = skip_json_string(body, pos)?;
                if depth == 1 && &body[pos + 1..after - 1] == MODEL_KEY {
                    let colon = skip_json_whitespace(body, after);
                    // A following ':' distinguishes a key from a string
                    // value that merely happens to read "model".
                    if body.get(colon) == Some(&b':') {
                        let value = skip_json_whitespace(body, colon + 1);
                        if body.get(value) != Some(&b'"') {
                            return None;
                        }
                        let value_after = skip_json_string(body, value)?;
                        return Some((value + 1, value_after - 1));
                    }
                }
                pos = after;
            }
            b'{' | b'[' => {
                depth += 1;
                pos += 1;
            }
            b'}' | b']' => {
                depth = depth.checked_sub(1)?;
                pos += 1;
            }
            _ => pos += 1,
        }
    }
    None
}

/// Given the index of a string's opening quote, return the index just past
/// its closing quote, or `None` if the string never terminates.
fn skip_json_string(data: &[u8], open_quote: usize) -> Option<usize> {
    let mut pos = open_quote + 1;
    loop {
        match *data.get(pos)? {
            // A backslash escapes the next byte, which must not be
            // interpreted as a closing quote.
            b'\\' => pos += 2,
            b'"' => return Some(pos + 1),
            _ => pos += 1,
        }
    }
}

/// Find the first occurrence of `needle` in `haystack`.
///
/// Returns `None` for an empty needle or when the needle does not occur.
fn find_bytes(haystack: &[u8], needle: &[u8]) -> Option<usize> {
    if needle.is_empty() {
        // `windows(0)` would panic; an empty needle is treated as "not found".
        return None;
    }
    haystack
        .windows(needle.len())
        .position(|window| window == needle)
}

/// Skip JSON whitespace (space, tab, newline, carriage return) starting at
/// `pos` and return the index of the first non-whitespace byte. If `pos` is
/// at or past the end of `data` it is returned unchanged.
fn skip_json_whitespace(data: &[u8], mut pos: usize) -> usize {
    while let Some(b' ' | b'\t' | b'\n' | b'\r') = data.get(pos).copied() {
        pos += 1;
    }
    pos
}

// Unchanged patch context that follows in the diff (truncated in this view):
// fn current_time_ns() -> u128 { SystemTime::now() .duration_since(UNIX_EPOCH)