mirror of
https://github.com/katanemo/plano.git
synced 2026-06-08 14:55:14 +02:00
feat: preserve original JSON bytes for prompt cache compatibility
Avoid re-serializing request bodies when unnecessary to maintain JSON key order, whitespace, and unknown fields — critical for prompt cache prefix matching on providers like Anthropic. - routing_service: only re-serialize when routing_preferences were actually removed from the body - stream_context: replace model name at byte level instead of full deserialization/re-serialization cycle - Strip provider prefix from model name (e.g. 'custom-aws/claude-opus-4-6' -> 'claude-opus-4-6') before sending upstream Signed-off-by: Troy Mitchell <i@troy-y.org>
This commit is contained in:
parent
d29ed70c0a
commit
c34ff5b5fd
2 changed files with 170 additions and 71 deletions
|
|
@ -45,7 +45,14 @@ pub fn extract_routing_policy(
|
|||
},
|
||||
);
|
||||
|
||||
let bytes = Bytes::from(serde_json::to_vec(&json_body).unwrap());
|
||||
// Only re-serialize if we actually removed routing_preferences.
|
||||
// Otherwise preserve the original bytes to maintain JSON key order,
|
||||
// whitespace, and unknown fields — critical for prompt cache hits.
|
||||
let bytes = if routing_preferences.is_some() {
|
||||
Bytes::from(serde_json::to_vec(&json_body).unwrap())
|
||||
} else {
|
||||
Bytes::from(raw_bytes.to_vec())
|
||||
};
|
||||
Ok((bytes, routing_preferences))
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1023,8 +1023,15 @@ impl HttpContext for StreamContext {
|
|||
}
|
||||
};
|
||||
|
||||
// Set the resolved model using the trait method
|
||||
deserialized_client_request.set_model(resolved_model.clone());
|
||||
// Set the resolved model using the trait method.
|
||||
// Strip provider prefix (e.g., "custom-aws/claude-opus-4-6" -> "claude-opus-4-6")
|
||||
// so the upstream API receives only the model name it recognizes.
|
||||
let upstream_model = if let Some((_prefix, model_only)) = resolved_model.split_once('/') {
|
||||
model_only.to_string()
|
||||
} else {
|
||||
resolved_model.clone()
|
||||
};
|
||||
deserialized_client_request.set_model(upstream_model.clone());
|
||||
|
||||
// Extract user message for tracing
|
||||
self.user_message = deserialized_client_request.get_recent_user_message();
|
||||
|
|
@ -1056,82 +1063,93 @@ impl HttpContext for StreamContext {
|
|||
return Action::Continue;
|
||||
}
|
||||
|
||||
// Convert chat completion request to llm provider specific request using provider interface
|
||||
let serialized_body_bytes_upstream = match self.resolved_api.as_ref() {
|
||||
Some(upstream) => {
|
||||
info!(
|
||||
"request_id={}: upstream transform, client_api={:?} -> upstream_api={:?}",
|
||||
self.request_identifier(),
|
||||
self.client_api,
|
||||
upstream
|
||||
);
|
||||
|
||||
match ProviderRequestType::try_from((deserialized_client_request, upstream)) {
|
||||
Ok(mut request) => {
|
||||
if let Err(e) =
|
||||
request.normalize_for_upstream(self.get_provider_id(), upstream)
|
||||
{
|
||||
warn!(
|
||||
"request_id={}: normalize_for_upstream failed: {}",
|
||||
self.request_identifier(),
|
||||
e
|
||||
);
|
||||
// Preserve original body bytes for prompt cache compatibility.
|
||||
// Only replace the "model" field value at the byte level instead of
|
||||
// deserializing + re-serializing, which destroys key order, whitespace,
|
||||
// and unknown fields — breaking prompt cache prefix matching.
|
||||
// Use upstream_model (prefix-stripped) so the upstream API receives
|
||||
// only the model name it recognizes.
|
||||
let original_model = model_requested.as_str();
|
||||
let serialized_body_bytes_upstream = if original_model != upstream_model.as_str() {
|
||||
match replace_json_model_value(&body_bytes, original_model, &upstream_model) {
|
||||
Some(patched) => {
|
||||
debug!(
|
||||
"request_id={}: byte-level model replacement '{}' -> '{}'",
|
||||
self.request_identifier(),
|
||||
original_model,
|
||||
upstream_model
|
||||
);
|
||||
patched
|
||||
}
|
||||
None => {
|
||||
// Fallback: full re-serialization if byte-level replacement fails
|
||||
warn!(
|
||||
"request_id={}: byte-level model replacement failed, falling back to re-serialization",
|
||||
self.request_identifier()
|
||||
);
|
||||
match self.resolved_api.as_ref() {
|
||||
Some(upstream) => {
|
||||
match ProviderRequestType::try_from((
|
||||
deserialized_client_request,
|
||||
upstream,
|
||||
)) {
|
||||
Ok(mut request) => {
|
||||
if let Err(e) = request
|
||||
.normalize_for_upstream(self.get_provider_id(), upstream)
|
||||
{
|
||||
warn!(
|
||||
"request_id={}: normalize_for_upstream failed: {}",
|
||||
self.request_identifier(),
|
||||
e
|
||||
);
|
||||
self.send_server_error(
|
||||
ServerError::LogicError(e.message),
|
||||
Some(StatusCode::BAD_REQUEST),
|
||||
);
|
||||
return Action::Pause;
|
||||
}
|
||||
match request.to_bytes() {
|
||||
Ok(bytes) => bytes,
|
||||
Err(e) => {
|
||||
self.send_server_error(
|
||||
ServerError::LogicError(format!(
|
||||
"Request serialization error: {}",
|
||||
e
|
||||
)),
|
||||
Some(StatusCode::BAD_REQUEST),
|
||||
);
|
||||
return Action::Pause;
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
self.send_server_error(
|
||||
ServerError::LogicError(format!(
|
||||
"Provider request error: {}",
|
||||
e
|
||||
)),
|
||||
Some(StatusCode::BAD_REQUEST),
|
||||
);
|
||||
return Action::Pause;
|
||||
}
|
||||
}
|
||||
}
|
||||
None => {
|
||||
self.send_server_error(
|
||||
ServerError::LogicError(e.message),
|
||||
ServerError::LogicError("No upstream API resolved".into()),
|
||||
Some(StatusCode::BAD_REQUEST),
|
||||
);
|
||||
return Action::Pause;
|
||||
}
|
||||
debug!(
|
||||
"request_id={}: upstream request payload: {}",
|
||||
self.request_identifier(),
|
||||
String::from_utf8_lossy(&request.to_bytes().unwrap_or_default())
|
||||
);
|
||||
|
||||
match request.to_bytes() {
|
||||
Ok(bytes) => bytes,
|
||||
Err(e) => {
|
||||
warn!(
|
||||
"request_id={}: failed to serialize request body: {}",
|
||||
self.request_identifier(),
|
||||
e
|
||||
);
|
||||
self.send_server_error(
|
||||
ServerError::LogicError(format!(
|
||||
"Request serialization error: {}",
|
||||
e
|
||||
)),
|
||||
Some(StatusCode::BAD_REQUEST),
|
||||
);
|
||||
return Action::Pause;
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!(
|
||||
"request_id={}: failed to create provider request: {}",
|
||||
self.request_identifier(),
|
||||
e
|
||||
);
|
||||
self.send_server_error(
|
||||
ServerError::LogicError(format!("Provider request error: {}", e)),
|
||||
Some(StatusCode::BAD_REQUEST),
|
||||
);
|
||||
return Action::Pause;
|
||||
}
|
||||
}
|
||||
}
|
||||
None => {
|
||||
warn!(
|
||||
"request_id={}: no upstream api resolved",
|
||||
self.request_identifier()
|
||||
);
|
||||
self.send_server_error(
|
||||
ServerError::LogicError("No upstream API resolved".into()),
|
||||
Some(StatusCode::BAD_REQUEST),
|
||||
);
|
||||
return Action::Pause;
|
||||
}
|
||||
} else {
|
||||
debug!(
|
||||
"request_id={}: model unchanged, passing original body through",
|
||||
self.request_identifier()
|
||||
);
|
||||
body_bytes.clone()
|
||||
};
|
||||
|
||||
self.set_http_request_body(0, body_size, &serialized_body_bytes_upstream);
|
||||
|
|
@ -1260,6 +1278,80 @@ impl HttpContext for StreamContext {
|
|||
}
|
||||
}
|
||||
|
||||
/// Replace the value of the top-level `"model"` key in a JSON byte slice
|
||||
/// without re-serializing. Returns `Some(new_bytes)` on success, `None` if the
|
||||
/// pattern wasn't found (caller should fall back to full re-serialization).
|
||||
///
|
||||
/// This is intentionally simple and does NOT use regex (unavailable in WASM).
|
||||
/// It scans for `"model"` followed by `:` and a quoted string value, then
|
||||
/// splices in the new model name. Works for the common case where model values
|
||||
/// are simple strings like `"gpt-4o"` without JSON escapes.
|
||||
fn replace_json_model_value(body: &[u8], old_model: &str, new_model: &str) -> Option<Vec<u8>> {
|
||||
// Build the needle: `"model"` (we'll then skip whitespace + colon + whitespace + opening quote)
|
||||
let model_key = b"\"model\"";
|
||||
|
||||
// Find the position of `"model"` key
|
||||
let key_pos = find_bytes(body, model_key)?;
|
||||
|
||||
// After the key, skip whitespace, expect ':', skip whitespace, expect '"'
|
||||
let mut pos = key_pos + model_key.len();
|
||||
pos = skip_json_whitespace(body, pos);
|
||||
if body.get(pos)? != &b':' {
|
||||
return None;
|
||||
}
|
||||
pos += 1;
|
||||
pos = skip_json_whitespace(body, pos);
|
||||
if body.get(pos)? != &b'"' {
|
||||
return None;
|
||||
}
|
||||
let _value_start_quote = pos; // position of the opening '"'
|
||||
pos += 1;
|
||||
|
||||
// Find the closing quote (handle escaped quotes)
|
||||
let value_content_start = pos;
|
||||
loop {
|
||||
let ch = *body.get(pos)?;
|
||||
if ch == b'\\' {
|
||||
pos += 2; // skip escaped char
|
||||
continue;
|
||||
}
|
||||
if ch == b'"' {
|
||||
break;
|
||||
}
|
||||
pos += 1;
|
||||
}
|
||||
let value_content_end = pos; // position of closing '"'
|
||||
|
||||
// Verify the current value matches old_model
|
||||
let current_value = &body[value_content_start..value_content_end];
|
||||
if current_value != old_model.as_bytes() {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Build new body: everything before value content + new model + everything after
|
||||
let mut result = Vec::with_capacity(body.len() + new_model.len() - old_model.len());
|
||||
result.extend_from_slice(&body[..value_content_start]);
|
||||
result.extend_from_slice(new_model.as_bytes());
|
||||
result.extend_from_slice(&body[value_content_end..]);
|
||||
Some(result)
|
||||
}
|
||||
|
||||
/// Find first occurrence of `needle` in `haystack`.
|
||||
fn find_bytes(haystack: &[u8], needle: &[u8]) -> Option<usize> {
|
||||
if needle.is_empty() || needle.len() > haystack.len() {
|
||||
return None;
|
||||
}
|
||||
(0..=haystack.len() - needle.len()).find(|&i| &haystack[i..i + needle.len()] == needle)
|
||||
}
|
||||
|
||||
/// Skip JSON whitespace (space, tab, newline, carriage return).
|
||||
fn skip_json_whitespace(data: &[u8], mut pos: usize) -> usize {
|
||||
while pos < data.len() && matches!(data[pos], b' ' | b'\t' | b'\n' | b'\r') {
|
||||
pos += 1;
|
||||
}
|
||||
pos
|
||||
}
|
||||
|
||||
fn current_time_ns() -> u128 {
|
||||
SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue