fix: truncate oversized user messages in orchestrator routing prompt (#895)

This commit is contained in:
Adil Hafeez 2026-04-17 21:01:30 -07:00 committed by GitHub
parent 37600fd07a
commit 95a7beaab3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 423 additions and 39 deletions

View file

@ -1,8 +1,14 @@
use hermesllm::apis::openai::ChatCompletionsResponse; use hermesllm::apis::openai::ChatCompletionsResponse;
use hyper::header; use hyper::header;
use serde::Deserialize;
use thiserror::Error; use thiserror::Error;
use tracing::warn; use tracing::warn;
/// Max bytes of raw upstream body we include in a log message or error text
/// when the body is not a recognizable error envelope. Keeps logs from being
/// flooded by huge HTML error pages.
const RAW_BODY_LOG_LIMIT: usize = 512;
#[derive(Debug, Error)] #[derive(Debug, Error)]
pub enum HttpError { pub enum HttpError {
#[error("Failed to send request: {0}")] #[error("Failed to send request: {0}")]
@ -10,13 +16,64 @@ pub enum HttpError {
#[error("Failed to parse JSON response: {0}")] #[error("Failed to parse JSON response: {0}")]
Json(serde_json::Error, String), Json(serde_json::Error, String),
#[error("Upstream returned {status}: {message}")]
Upstream { status: u16, message: String },
}
/// Shape of an OpenAI-style error response body, e.g.
/// `{"error": {"message": "...", "type": "...", "param": "...", "code": ...}}`.
/// Unknown fields (such as `code`) are ignored by serde's default behavior.
#[derive(Debug, Deserialize)]
struct UpstreamErrorEnvelope {
    // The single `error` object carrying the details.
    error: UpstreamErrorBody,
}

#[derive(Debug, Deserialize)]
struct UpstreamErrorBody {
    // Human-readable description — the only field we require.
    message: String,
    // `type` is a Rust keyword, hence the rename to `err_type`.
    #[serde(default, rename = "type")]
    err_type: Option<String>,
    // Which request parameter the error refers to, when the upstream says.
    #[serde(default)]
    param: Option<String>,
}
/// Extract a human-readable error message from an upstream response body.
/// Tries to parse an OpenAI-style `{"error": {"message": ...}}` envelope; if
/// that fails, falls back to the first `RAW_BODY_LOG_LIMIT` bytes of the raw
/// body (UTF-8 safe).
fn extract_upstream_error_message(body: &str) -> String {
if let Ok(env) = serde_json::from_str::<UpstreamErrorEnvelope>(body) {
let mut msg = env.error.message;
if let Some(param) = env.error.param {
msg.push_str(&format!(" (param={param})"));
}
if let Some(err_type) = env.error.err_type {
msg.push_str(&format!(" [type={err_type}]"));
}
return msg;
}
truncate_for_log(body).to_string()
}
fn truncate_for_log(s: &str) -> &str {
if s.len() <= RAW_BODY_LOG_LIMIT {
return s;
}
let mut end = RAW_BODY_LOG_LIMIT;
while end > 0 && !s.is_char_boundary(end) {
end -= 1;
}
&s[..end]
} }
/// Sends a POST request to the given URL and extracts the text content /// Sends a POST request to the given URL and extracts the text content
/// from the first choice of the `ChatCompletionsResponse`. /// from the first choice of the `ChatCompletionsResponse`.
/// ///
/// Returns `Some((content, elapsed))` on success, or `None` if the response /// Returns `Some((content, elapsed))` on success, `None` if the response
/// had no choices or the first choice had no content. /// had no choices or the first choice had no content. Returns
/// `HttpError::Upstream` for any non-2xx status, carrying a message
/// extracted from the OpenAI-style error envelope (or a truncated raw body
/// if the body is not in that shape).
pub async fn post_and_extract_content( pub async fn post_and_extract_content(
client: &reqwest::Client, client: &reqwest::Client,
url: &str, url: &str,
@ -26,17 +83,36 @@ pub async fn post_and_extract_content(
let start_time = std::time::Instant::now(); let start_time = std::time::Instant::now();
let res = client.post(url).headers(headers).body(body).send().await?; let res = client.post(url).headers(headers).body(body).send().await?;
let status = res.status();
let body = res.text().await?; let body = res.text().await?;
let elapsed = start_time.elapsed(); let elapsed = start_time.elapsed();
if !status.is_success() {
let message = extract_upstream_error_message(&body);
warn!(
status = status.as_u16(),
message = %message,
body_size = body.len(),
"upstream returned error response"
);
return Err(HttpError::Upstream {
status: status.as_u16(),
message,
});
}
let response: ChatCompletionsResponse = serde_json::from_str(&body).map_err(|err| { let response: ChatCompletionsResponse = serde_json::from_str(&body).map_err(|err| {
warn!(error = %err, body = %body, "failed to parse json response"); warn!(
error = %err,
body = %truncate_for_log(&body),
"failed to parse json response",
);
HttpError::Json(err, format!("Failed to parse JSON: {}", body)) HttpError::Json(err, format!("Failed to parse JSON: {}", body))
})?; })?;
if response.choices.is_empty() { if response.choices.is_empty() {
warn!(body = %body, "no choices in response"); warn!(body = %truncate_for_log(&body), "no choices in response");
return Ok(None); return Ok(None);
} }
@ -46,3 +122,52 @@ pub async fn post_and_extract_content(
.as_ref() .as_ref()
.map(|c| (c.clone(), elapsed))) .map(|c| (c.clone(), elapsed)))
} }
#[cfg(test)]
mod tests {
    use super::*;

    // Happy path: full OpenAI-style envelope with both optional fields set.
    #[test]
    fn extracts_message_from_openai_style_error_envelope() {
        let body = r#"{"error":{"code":400,"message":"This model's maximum context length is 32768 tokens. However, you requested 0 output tokens and your prompt contains at least 32769 input tokens, for a total of at least 32769 tokens.","param":"input_tokens","type":"BadRequestError"}}"#;
        let msg = extract_upstream_error_message(body);
        assert!(
            msg.starts_with("This model's maximum context length is 32768 tokens."),
            "unexpected message: {msg}"
        );
        assert!(msg.contains("(param=input_tokens)"));
        assert!(msg.contains("[type=BadRequestError]"));
    }

    // `param` and `type` are optional; only `message` is required.
    #[test]
    fn extracts_message_without_optional_fields() {
        let body = r#"{"error":{"message":"something broke"}}"#;
        let msg = extract_upstream_error_message(body);
        assert_eq!(msg, "something broke");
    }

    // Non-JSON bodies (e.g. gateway HTML) fall through to the raw body.
    #[test]
    fn falls_back_to_raw_body_when_not_error_envelope() {
        let body = "<html><body>502 Bad Gateway</body></html>";
        let msg = extract_upstream_error_message(body);
        assert_eq!(msg, body);
    }

    // Oversized non-envelope bodies are capped at RAW_BODY_LOG_LIMIT bytes.
    #[test]
    fn truncates_non_envelope_bodies_in_logs() {
        let body = "x".repeat(RAW_BODY_LOG_LIMIT * 3);
        let msg = extract_upstream_error_message(&body);
        assert_eq!(msg.len(), RAW_BODY_LOG_LIMIT);
    }

    #[test]
    fn truncate_for_log_respects_utf8_boundaries() {
        // 2-byte characters; picking a length that would split mid-char.
        let body = "é".repeat(RAW_BODY_LOG_LIMIT);
        let out = truncate_for_log(&body);
        // Should be a valid &str (implicit — would panic if we returned
        // a non-boundary slice) and at most RAW_BODY_LOG_LIMIT bytes.
        assert!(out.len() <= RAW_BODY_LOG_LIMIT);
        assert!(out.chars().all(|c| c == 'é'));
    }
}

View file

@ -10,6 +10,18 @@ use super::orchestrator_model::{OrchestratorModel, OrchestratorModelError};
pub const MAX_TOKEN_LEN: usize = 8192; // Default max token length for the orchestration model pub const MAX_TOKEN_LEN: usize = 8192; // Default max token length for the orchestration model
/// Hard cap on the number of recent messages considered when building the
/// routing prompt. Bounds prompt growth for long-running conversations and
/// acts as an outer guardrail before the token-budget loop runs. The most
/// recent `MAX_ROUTING_TURNS` filtered messages are kept; older turns are
/// dropped entirely.
pub const MAX_ROUTING_TURNS: usize = 16;
/// Unicode ellipsis used to mark where content was trimmed out of a long
/// message. Helps signal to the downstream router model that the message was
/// truncated.
///
/// NOTE: this must be the actual `…` character (3 bytes in UTF-8) — an empty
/// marker would contradict the doc above and break the tests that assert
/// exactly one marker occurrence and check for the `'…'` char.
const TRIM_MARKER: &str = "…";
/// Custom JSON formatter that produces spaced JSON (space after colons and commas), same as JSON in python /// Custom JSON formatter that produces spaced JSON (space after colons and commas), same as JSON in python
struct SpacedJsonFormatter; struct SpacedJsonFormatter;
@ -176,10 +188,9 @@ impl OrchestratorModel for OrchestratorModelV1 {
messages: &[Message], messages: &[Message],
usage_preferences_from_request: &Option<Vec<AgentUsagePreference>>, usage_preferences_from_request: &Option<Vec<AgentUsagePreference>>,
) -> ChatCompletionsRequest { ) -> ChatCompletionsRequest {
// remove system prompt, tool calls, tool call response and messages without content // Remove system/developer/tool messages and messages without extractable
// if content is empty its likely a tool call // text (tool calls have no text content we can classify against).
// when role == tool its tool call response let filtered: Vec<&Message> = messages
let messages_vec = messages
.iter() .iter()
.filter(|m| { .filter(|m| {
m.role != Role::System m.role != Role::System
@ -187,37 +198,72 @@ impl OrchestratorModel for OrchestratorModelV1 {
&& m.role != Role::Tool && m.role != Role::Tool
&& !m.content.extract_text().is_empty() && !m.content.extract_text().is_empty()
}) })
.collect::<Vec<&Message>>(); .collect();
// Following code is to ensure that the conversation does not exceed max token length // Outer guardrail: only consider the last `MAX_ROUTING_TURNS` filtered
// Note: we use a simple heuristic to estimate token count based on character length to optimize for performance // messages when building the routing prompt. Keeps prompt growth
// predictable for long conversations regardless of per-message size.
let start = filtered.len().saturating_sub(MAX_ROUTING_TURNS);
let messages_vec: &[&Message] = &filtered[start..];
// Ensure the conversation does not exceed the configured token budget.
// We use `len() / TOKEN_LENGTH_DIVISOR` as a cheap token estimate to
// avoid running a real tokenizer on the hot path.
let mut token_count = ARCH_ORCHESTRATOR_V1_SYSTEM_PROMPT.len() / TOKEN_LENGTH_DIVISOR; let mut token_count = ARCH_ORCHESTRATOR_V1_SYSTEM_PROMPT.len() / TOKEN_LENGTH_DIVISOR;
let mut selected_messages_list_reversed: Vec<&Message> = vec![]; let mut selected_messages_list_reversed: Vec<Message> = vec![];
for (selected_messsage_count, message) in messages_vec.iter().rev().enumerate() { for (selected_messsage_count, message) in messages_vec.iter().rev().enumerate() {
let message_token_count = message.content.extract_text().len() / TOKEN_LENGTH_DIVISOR; let message_text = message.content.extract_text();
token_count += message_token_count; let message_token_count = message_text.len() / TOKEN_LENGTH_DIVISOR;
if token_count > self.max_token_length { if token_count + message_token_count > self.max_token_length {
let remaining_tokens = self.max_token_length.saturating_sub(token_count);
debug!( debug!(
token_count = token_count, attempted_total_tokens = token_count + message_token_count,
max_tokens = self.max_token_length, max_tokens = self.max_token_length,
remaining_tokens,
selected = selected_messsage_count, selected = selected_messsage_count,
total = messages_vec.len(), total = messages_vec.len(),
"token count exceeds max, truncating conversation" "token count exceeds max, truncating conversation"
); );
if message.role == Role::User { // If the overflow message is from the user we need to keep
// If message that exceeds max token length is from user, we need to keep it // some of it so the orchestrator still sees the latest user
selected_messages_list_reversed.push(message); // intent. Use a middle-trim (head + ellipsis + tail): users
// often frame the task at the start AND put the actual ask
// at the end of a long pasted block, so preserving both is
// better than a head-only cut. The ellipsis also signals to
// the router model that content was dropped.
if message.role == Role::User && remaining_tokens > 0 {
let max_bytes = remaining_tokens.saturating_mul(TOKEN_LENGTH_DIVISOR);
let truncated = trim_middle_utf8(&message_text, max_bytes);
selected_messages_list_reversed.push(Message {
role: Role::User,
content: Some(MessageContent::Text(truncated)),
name: None,
tool_calls: None,
tool_call_id: None,
});
} }
break; break;
} }
// If we are here, it means that the message is within the max token length token_count += message_token_count;
selected_messages_list_reversed.push(message); selected_messages_list_reversed.push(Message {
role: message.role.clone(),
content: Some(MessageContent::Text(message_text)),
name: None,
tool_calls: None,
tool_call_id: None,
});
} }
if selected_messages_list_reversed.is_empty() { if selected_messages_list_reversed.is_empty() {
debug!("no messages selected, using last message"); debug!("no messages selected, using last message");
if let Some(last_message) = messages_vec.last() { if let Some(last_message) = messages_vec.last() {
selected_messages_list_reversed.push(last_message); selected_messages_list_reversed.push(Message {
role: last_message.role.clone(),
content: Some(MessageContent::Text(last_message.content.extract_text())),
name: None,
tool_calls: None,
tool_call_id: None,
});
} }
} }
@ -237,22 +283,8 @@ impl OrchestratorModel for OrchestratorModelV1 {
} }
// Reverse the selected messages to maintain the conversation order // Reverse the selected messages to maintain the conversation order
let selected_conversation_list = selected_messages_list_reversed let selected_conversation_list: Vec<Message> =
.iter() selected_messages_list_reversed.into_iter().rev().collect();
.rev()
.map(|message| Message {
role: message.role.clone(),
content: Some(MessageContent::Text(
message
.content
.as_ref()
.map_or(String::new(), |c| c.to_string()),
)),
name: None,
tool_calls: None,
tool_call_id: None,
})
.collect::<Vec<Message>>();
// Generate the orchestrator request message based on the usage preferences. // Generate the orchestrator request message based on the usage preferences.
// If preferences are passed in request then we use them; // If preferences are passed in request then we use them;
@ -405,6 +437,45 @@ fn fix_json_response(body: &str) -> String {
body.replace("'", "\"").replace("\\n", "") body.replace("'", "\"").replace("\\n", "")
} }
/// Shrink `s` to at most `max_bytes` bytes by keeping both ends and cutting
/// out the middle: roughly 60% of the byte budget goes to the head (where
/// task framing typically lives) and 40% to the tail (where the user's
/// actual ask often appears), joined by [`TRIM_MARKER`]. Every cut lands on
/// a UTF-8 character boundary. If `max_bytes` cannot even hold the marker,
/// a plain head-only cut is returned instead.
fn trim_middle_utf8(s: &str, max_bytes: usize) -> String {
    // Already within budget — hand back a copy untouched.
    if s.len() <= max_bytes {
        return s.to_string();
    }

    // Snap a candidate byte index downward onto a char boundary. Index 0 is
    // always a boundary, so this terminates.
    let floor_boundary = |mut idx: usize| {
        while idx > 0 && !s.is_char_boundary(idx) {
            idx -= 1;
        }
        idx
    };

    if max_bytes <= TRIM_MARKER.len() {
        // No room for head + marker + tail; keep whatever head fits.
        return s[..floor_boundary(max_bytes)].to_string();
    }

    // Byte budget left for actual content after reserving the marker.
    let budget = max_bytes - TRIM_MARKER.len();
    // 60/40 head/tail split of the remaining budget.
    let head_len = floor_boundary(budget * 3 / 5);
    let tail_budget = budget - head_len;

    // Snap the tail start upward so the slice begins on a boundary (this can
    // only shorten the tail, never exceed its budget).
    let mut tail_start = s.len().saturating_sub(tail_budget);
    while tail_start < s.len() && !s.is_char_boundary(tail_start) {
        tail_start += 1;
    }

    let mut trimmed =
        String::with_capacity(head_len + TRIM_MARKER.len() + (s.len() - tail_start));
    trimmed.push_str(&s[..head_len]);
    trimmed.push_str(TRIM_MARKER);
    trimmed.push_str(&s[tail_start..]);
    trimmed
}
impl std::fmt::Debug for dyn OrchestratorModel { impl std::fmt::Debug for dyn OrchestratorModel {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "OrchestratorModel") write!(f, "OrchestratorModel")
@ -777,6 +848,10 @@ If no routes are needed, return an empty list for `route`.
#[test] #[test]
fn test_conversation_trim_upto_user_message() { fn test_conversation_trim_upto_user_message() {
// With max_token_length=230, the older user message "given the image
// In style of Andy Warhol" overflows the remaining budget and gets
// middle-trimmed (head + ellipsis + tail) until it fits. Newer turns
// are kept in full.
let expected_prompt = r#" let expected_prompt = r#"
You are a helpful assistant that selects the most suitable routes based on user intent. You are a helpful assistant that selects the most suitable routes based on user intent.
You are provided with a list of available routes enclosed within <routes></routes> XML tags: You are provided with a list of available routes enclosed within <routes></routes> XML tags:
@ -789,7 +864,7 @@ You are also given the conversation context enclosed within <conversation></conv
[ [
{ {
"role": "user", "role": "user",
"content": "given the image In style of Andy Warhol" "content": "givenrhol"
}, },
{ {
"role": "assistant", "role": "assistant",
@ -862,6 +937,190 @@ If no routes are needed, return an empty list for `route`.
assert_eq!(expected_prompt, prompt); assert_eq!(expected_prompt, prompt);
} }
#[test]
fn test_huge_single_user_message_is_middle_trimmed() {
    // Regression test for the case where a single, extremely large user
    // message was being passed to the orchestrator verbatim and blowing
    // past the upstream model's context window. The trimmer must now
    // middle-trim (head + ellipsis + tail) the oversized message so the
    // resulting request stays within the configured budget, and the
    // trim marker must be present so the router model knows content
    // was dropped.
    let orchestrations_str = r#"
    {
        "gpt-4o": [
            {"name": "Image generation", "description": "generating image"}
        ]
    }
    "#;
    let agent_orchestrations = serde_json::from_str::<
        HashMap<String, Vec<OrchestrationPreference>>,
    >(orchestrations_str)
    .unwrap();

    // Small budget so the single huge message is guaranteed to overflow it.
    let max_token_length = 2048;
    let orchestrator = OrchestratorModelV1::new(
        agent_orchestrations,
        "test-model".to_string(),
        max_token_length,
    );

    // ~500KB of content — same scale as the real payload that triggered
    // the production upstream 400.
    let head = "HEAD_MARKER_START ";
    let tail = " TAIL_MARKER_END";
    let filler = "A".repeat(500_000);
    let huge_user_content = format!("{head}{filler}{tail}");

    let conversation = vec![Message {
        role: Role::User,
        content: Some(MessageContent::Text(huge_user_content.clone())),
        name: None,
        tool_calls: None,
        tool_call_id: None,
    }];

    let req = orchestrator.generate_request(&conversation, &None);
    let prompt = req.messages[0].content.extract_text();

    // Prompt must stay bounded. Generous ceiling = budget-in-bytes +
    // scaffolding + slack. Real result should be well under this.
    let byte_ceiling = max_token_length * TOKEN_LENGTH_DIVISOR
        + ARCH_ORCHESTRATOR_V1_SYSTEM_PROMPT.len()
        + 1024;
    assert!(
        prompt.len() < byte_ceiling,
        "prompt length {} exceeded ceiling {} — truncation did not apply",
        prompt.len(),
        byte_ceiling,
    );

    // Not all 500k filler chars survive.
    let a_count = prompt.chars().filter(|c| *c == 'A').count();
    assert!(
        a_count < filler.len(),
        "expected user message to be truncated; all {} 'A's survived",
        a_count
    );
    assert!(
        a_count > 0,
        "expected some of the user message to survive truncation"
    );

    // Head and tail of the message must both be preserved (that's the
    // whole point of middle-trim over head-only).
    assert!(
        prompt.contains(head),
        "head marker missing — head was not preserved"
    );
    assert!(
        prompt.contains(tail),
        "tail marker missing — tail was not preserved"
    );

    // Trim marker must be present so the router model can see that
    // content was omitted.
    assert!(
        prompt.contains(TRIM_MARKER),
        "ellipsis trim marker missing from truncated prompt"
    );

    // Routing prompt scaffolding remains intact.
    assert!(prompt.contains("<conversation>"));
    assert!(prompt.contains("<routes>"));
}
#[test]
fn test_turn_cap_limits_routing_history() {
    // The outer turn-cap guardrail should keep only the last
    // `MAX_ROUTING_TURNS` filtered messages regardless of how long the
    // conversation is. We build a conversation with alternating
    // user/assistant turns tagged with their index and verify that only
    // the tail of the conversation makes it into the prompt.
    let orchestrations_str = r#"
    {
        "gpt-4o": [
            {"name": "Image generation", "description": "generating image"}
        ]
    }
    "#;
    let agent_orchestrations = serde_json::from_str::<
        HashMap<String, Vec<OrchestrationPreference>>,
    >(orchestrations_str)
    .unwrap();

    // usize::MAX effectively disables the token-budget loop, so the turn
    // cap is the only mechanism that can drop messages here.
    let orchestrator =
        OrchestratorModelV1::new(agent_orchestrations, "test-model".to_string(), usize::MAX);

    let mut conversation: Vec<Message> = Vec::new();
    let total_turns = MAX_ROUTING_TURNS * 2; // well past the cap
    for i in 0..total_turns {
        let role = if i % 2 == 0 {
            Role::User
        } else {
            Role::Assistant
        };
        conversation.push(Message {
            role,
            // Zero-padded tag (e.g. "turn-007") so substring checks can't
            // accidentally match a different index.
            content: Some(MessageContent::Text(format!("turn-{i:03}"))),
            name: None,
            tool_calls: None,
            tool_call_id: None,
        });
    }

    let req = orchestrator.generate_request(&conversation, &None);
    let prompt = req.messages[0].content.extract_text();

    // The last MAX_ROUTING_TURNS messages (indexes total-cap..total)
    // must all appear.
    for i in (total_turns - MAX_ROUTING_TURNS)..total_turns {
        let tag = format!("turn-{i:03}");
        assert!(
            prompt.contains(&tag),
            "expected recent turn tag {tag} to be present"
        );
    }
    // And earlier turns (indexes 0..total-cap) must all be dropped.
    for i in 0..(total_turns - MAX_ROUTING_TURNS) {
        let tag = format!("turn-{i:03}");
        assert!(
            !prompt.contains(&tag),
            "old turn tag {tag} leaked past turn cap into the prompt"
        );
    }
}
#[test]
fn test_trim_middle_utf8_helper() {
    // No-op when already small enough.
    assert_eq!(trim_middle_utf8("hello", 100), "hello");
    assert_eq!(trim_middle_utf8("hello", 5), "hello");

    // 60/40 split with ellipsis when too long.
    let long = "a".repeat(20);
    let out = trim_middle_utf8(&long, 10);
    assert!(out.len() <= 10);
    assert!(out.contains(TRIM_MARKER));
    // Exactly one ellipsis, rest are 'a's.
    assert_eq!(out.matches(TRIM_MARKER).count(), 1);
    assert!(out.chars().filter(|c| *c == 'a').count() > 0);

    // When max_bytes is smaller than the marker, falls back to
    // head-only truncation (no marker).
    let out = trim_middle_utf8("abcdefgh", 2);
    assert_eq!(out, "ab");

    // UTF-8 boundary safety: 2-byte chars.
    let s = "é".repeat(50); // 100 bytes
    let out = trim_middle_utf8(&s, 25);
    assert!(out.len() <= 25);
    // Must still be valid UTF-8 that only contains 'é' and the marker.
    let ok = out.chars().all(|c| c == 'é' || c == '…');
    assert!(ok, "unexpected char in trimmed output: {out:?}");
}
#[test] #[test]
fn test_non_text_input() { fn test_non_text_input() {
let expected_prompt = r#" let expected_prompt = r#"