feat: preserve original JSON bytes for prompt cache compatibility

Avoid re-serializing request bodies when unnecessary to maintain
JSON key order, whitespace, and unknown fields — critical for
prompt cache prefix matching on providers like Anthropic.

- routing_service: only re-serialize when routing_preferences
  were actually removed from the body
- stream_context: replace model name at byte level instead of
  full deserialization/re-serialization cycle
- Strip provider prefix from model name (e.g. 'custom-aws/claude-opus-4-6'
  -> 'claude-opus-4-6') before sending upstream

Signed-off-by: Troy Mitchell <i@troy-y.org>
This commit is contained in:
Troy Mitchell 2026-04-28 16:19:41 +08:00
parent d29ed70c0a
commit c34ff5b5fd
2 changed files with 170 additions and 71 deletions

View file

@ -45,7 +45,14 @@ pub fn extract_routing_policy(
},
);
let bytes = Bytes::from(serde_json::to_vec(&json_body).unwrap());
// Only re-serialize if we actually removed routing_preferences.
// Otherwise preserve the original bytes to maintain JSON key order,
// whitespace, and unknown fields — critical for prompt cache hits.
let bytes = if routing_preferences.is_some() {
Bytes::from(serde_json::to_vec(&json_body).unwrap())
} else {
Bytes::from(raw_bytes.to_vec())
};
Ok((bytes, routing_preferences))
}

View file

@ -1023,8 +1023,15 @@ impl HttpContext for StreamContext {
}
};
// Set the resolved model using the trait method
deserialized_client_request.set_model(resolved_model.clone());
// Set the resolved model using the trait method.
// Strip provider prefix (e.g., "custom-aws/claude-opus-4-6" -> "claude-opus-4-6")
// so the upstream API receives only the model name it recognizes.
let upstream_model = if let Some((_prefix, model_only)) = resolved_model.split_once('/') {
model_only.to_string()
} else {
resolved_model.clone()
};
deserialized_client_request.set_model(upstream_model.clone());
// Extract user message for tracing
self.user_message = deserialized_client_request.get_recent_user_message();
@ -1056,82 +1063,93 @@ impl HttpContext for StreamContext {
return Action::Continue;
}
// Convert chat completion request to llm provider specific request using provider interface
let serialized_body_bytes_upstream = match self.resolved_api.as_ref() {
Some(upstream) => {
info!(
"request_id={}: upstream transform, client_api={:?} -> upstream_api={:?}",
self.request_identifier(),
self.client_api,
upstream
);
match ProviderRequestType::try_from((deserialized_client_request, upstream)) {
Ok(mut request) => {
if let Err(e) =
request.normalize_for_upstream(self.get_provider_id(), upstream)
{
warn!(
"request_id={}: normalize_for_upstream failed: {}",
self.request_identifier(),
e
);
// Preserve original body bytes for prompt cache compatibility.
// Only replace the "model" field value at the byte level instead of
// deserializing + re-serializing, which destroys key order, whitespace,
// and unknown fields — breaking prompt cache prefix matching.
// Use upstream_model (prefix-stripped) so the upstream API receives
// only the model name it recognizes.
let original_model = model_requested.as_str();
let serialized_body_bytes_upstream = if original_model != upstream_model.as_str() {
match replace_json_model_value(&body_bytes, original_model, &upstream_model) {
Some(patched) => {
debug!(
"request_id={}: byte-level model replacement '{}' -> '{}'",
self.request_identifier(),
original_model,
upstream_model
);
patched
}
None => {
// Fallback: full re-serialization if byte-level replacement fails
warn!(
"request_id={}: byte-level model replacement failed, falling back to re-serialization",
self.request_identifier()
);
match self.resolved_api.as_ref() {
Some(upstream) => {
match ProviderRequestType::try_from((
deserialized_client_request,
upstream,
)) {
Ok(mut request) => {
if let Err(e) = request
.normalize_for_upstream(self.get_provider_id(), upstream)
{
warn!(
"request_id={}: normalize_for_upstream failed: {}",
self.request_identifier(),
e
);
self.send_server_error(
ServerError::LogicError(e.message),
Some(StatusCode::BAD_REQUEST),
);
return Action::Pause;
}
match request.to_bytes() {
Ok(bytes) => bytes,
Err(e) => {
self.send_server_error(
ServerError::LogicError(format!(
"Request serialization error: {}",
e
)),
Some(StatusCode::BAD_REQUEST),
);
return Action::Pause;
}
}
}
Err(e) => {
self.send_server_error(
ServerError::LogicError(format!(
"Provider request error: {}",
e
)),
Some(StatusCode::BAD_REQUEST),
);
return Action::Pause;
}
}
}
None => {
self.send_server_error(
ServerError::LogicError(e.message),
ServerError::LogicError("No upstream API resolved".into()),
Some(StatusCode::BAD_REQUEST),
);
return Action::Pause;
}
debug!(
"request_id={}: upstream request payload: {}",
self.request_identifier(),
String::from_utf8_lossy(&request.to_bytes().unwrap_or_default())
);
match request.to_bytes() {
Ok(bytes) => bytes,
Err(e) => {
warn!(
"request_id={}: failed to serialize request body: {}",
self.request_identifier(),
e
);
self.send_server_error(
ServerError::LogicError(format!(
"Request serialization error: {}",
e
)),
Some(StatusCode::BAD_REQUEST),
);
return Action::Pause;
}
}
}
Err(e) => {
warn!(
"request_id={}: failed to create provider request: {}",
self.request_identifier(),
e
);
self.send_server_error(
ServerError::LogicError(format!("Provider request error: {}", e)),
Some(StatusCode::BAD_REQUEST),
);
return Action::Pause;
}
}
}
None => {
warn!(
"request_id={}: no upstream api resolved",
self.request_identifier()
);
self.send_server_error(
ServerError::LogicError("No upstream API resolved".into()),
Some(StatusCode::BAD_REQUEST),
);
return Action::Pause;
}
} else {
debug!(
"request_id={}: model unchanged, passing original body through",
self.request_identifier()
);
body_bytes.clone()
};
self.set_http_request_body(0, body_size, &serialized_body_bytes_upstream);
@ -1260,6 +1278,80 @@ impl HttpContext for StreamContext {
}
}
/// Replace the value of the top-level `"model"` key in a JSON byte slice
/// without re-serializing. Returns `Some(new_bytes)` on success, `None` if the
/// pattern wasn't found (caller should fall back to full re-serialization).
///
/// This is intentionally simple and does NOT use regex (unavailable in WASM).
/// It scans for `"model"` followed by `:` and a quoted string value, then
/// splices in the new model name. Works for the common case where model values
/// are simple strings like `"gpt-4o"` without JSON escapes.
fn replace_json_model_value(body: &[u8], old_model: &str, new_model: &str) -> Option<Vec<u8>> {
// Build the needle: `"model"` (we'll then skip whitespace + colon + whitespace + opening quote)
let model_key = b"\"model\"";
// Find the position of `"model"` key
let key_pos = find_bytes(body, model_key)?;
// After the key, skip whitespace, expect ':', skip whitespace, expect '"'
let mut pos = key_pos + model_key.len();
pos = skip_json_whitespace(body, pos);
if body.get(pos)? != &b':' {
return None;
}
pos += 1;
pos = skip_json_whitespace(body, pos);
if body.get(pos)? != &b'"' {
return None;
}
let _value_start_quote = pos; // position of the opening '"'
pos += 1;
// Find the closing quote (handle escaped quotes)
let value_content_start = pos;
loop {
let ch = *body.get(pos)?;
if ch == b'\\' {
pos += 2; // skip escaped char
continue;
}
if ch == b'"' {
break;
}
pos += 1;
}
let value_content_end = pos; // position of closing '"'
// Verify the current value matches old_model
let current_value = &body[value_content_start..value_content_end];
if current_value != old_model.as_bytes() {
return None;
}
// Build new body: everything before value content + new model + everything after
let mut result = Vec::with_capacity(body.len() + new_model.len() - old_model.len());
result.extend_from_slice(&body[..value_content_start]);
result.extend_from_slice(new_model.as_bytes());
result.extend_from_slice(&body[value_content_end..]);
Some(result)
}
/// Find first occurrence of `needle` in `haystack`.
fn find_bytes(haystack: &[u8], needle: &[u8]) -> Option<usize> {
if needle.is_empty() || needle.len() > haystack.len() {
return None;
}
(0..=haystack.len() - needle.len()).find(|&i| &haystack[i..i + needle.len()] == needle)
}
/// Skip JSON whitespace (space, tab, newline, carriage return).
fn skip_json_whitespace(data: &[u8], mut pos: usize) -> usize {
while pos < data.len() && matches!(data[pos], b' ' | b'\t' | b'\n' | b'\r') {
pos += 1;
}
pos
}
fn current_time_ns() -> u128 {
SystemTime::now()
.duration_since(UNIX_EPOCH)