feat: head+tail trim with ellipsis and 16-turn cap for routing prompt

Replaces the previous head-only truncation of oversized user messages
with a middle-trim (head + ellipsis + tail) that preserves both the task
framing (start of message) and the actual ask (end of message) — a
common shape for long pasted content like code dumps or specs. The
unicode ellipsis also signals to the router model that content was
dropped, which can improve classification accuracy on truncated prompts.

Also adds an outer guardrail: only the last `MAX_ROUTING_TURNS` (16)
filtered messages are considered when building the routing prompt. This
bounds prompt growth for long conversations before the token-budget
loop runs, matching the approach HuggingFace chat-ui takes in its
arch-router client.

Tests:
- test_huge_single_user_message_is_middle_trimmed: regression test for
  the 500KB user message scenario. Verifies the prompt stays bounded,
  head + tail markers both survive, and the ellipsis is present.
- test_turn_cap_limits_routing_history: builds a 32-turn conversation
  and verifies only the last 16 make it into the prompt.
- test_trim_middle_utf8_helper: unit test for the helper covering the
  no-op path, the 60/40 split, the too-small-for-marker fallback, and
  UTF-8 boundary safety for multi-byte characters.
- Updated test_conversation_trim_upto_user_message to reflect the new
  middle-trim behavior.
This commit is contained in:
Adil Hafeez 2026-04-17 19:18:30 -07:00
parent c90b699c90
commit 42b7927122

View file

@ -10,6 +10,18 @@ use super::orchestrator_model::{OrchestratorModel, OrchestratorModelError};
pub const MAX_TOKEN_LEN: usize = 8192; // Default max token length for the orchestration model pub const MAX_TOKEN_LEN: usize = 8192; // Default max token length for the orchestration model
/// Hard cap on the number of recent messages considered when building the
/// routing prompt. Bounds prompt growth for long-running conversations and
/// acts as an outer guardrail before the token-budget loop runs. The most
/// recent `MAX_ROUTING_TURNS` filtered messages are kept; older turns are
/// dropped entirely.
///
/// Note that a "turn" here is a single filtered message, not a
/// user/assistant pair: the cap is applied by slicing the filtered message
/// list, so 16 means at most 16 individual messages.
pub const MAX_ROUTING_TURNS: usize = 16;
/// Unicode ellipsis (U+2026, three bytes in UTF-8) used to mark where content
/// was trimmed out of a long message. Helps signal to the downstream router
/// model that the message was truncated.
///
/// Must be non-empty: the trimming helper reserves `TRIM_MARKER.len()` bytes
/// of the budget for it, and the tests count its occurrences in the output.
const TRIM_MARKER: &str = "…";
/// Custom JSON formatter that produces spaced JSON (space after colons and commas), same as JSON in python /// Custom JSON formatter that produces spaced JSON (space after colons and commas), same as JSON in python
struct SpacedJsonFormatter; struct SpacedJsonFormatter;
@ -176,10 +188,9 @@ impl OrchestratorModel for OrchestratorModelV1 {
messages: &[Message], messages: &[Message],
usage_preferences_from_request: &Option<Vec<AgentUsagePreference>>, usage_preferences_from_request: &Option<Vec<AgentUsagePreference>>,
) -> ChatCompletionsRequest { ) -> ChatCompletionsRequest {
// remove system prompt, tool calls, tool call response and messages without content // Remove system/developer/tool messages and messages without extractable
// if content is empty its likely a tool call // text (tool calls have no text content we can classify against).
// when role == tool its tool call response let filtered: Vec<&Message> = messages
let messages_vec = messages
.iter() .iter()
.filter(|m| { .filter(|m| {
m.role != Role::System m.role != Role::System
@ -187,10 +198,17 @@ impl OrchestratorModel for OrchestratorModelV1 {
&& m.role != Role::Tool && m.role != Role::Tool
&& !m.content.extract_text().is_empty() && !m.content.extract_text().is_empty()
}) })
.collect::<Vec<&Message>>(); .collect();
// Following code is to ensure that the conversation does not exceed max token length // Outer guardrail: only consider the last `MAX_ROUTING_TURNS` filtered
// Note: we use a simple heuristic to estimate token count based on character length to optimize for performance // messages when building the routing prompt. Keeps prompt growth
// predictable for long conversations regardless of per-message size.
let start = filtered.len().saturating_sub(MAX_ROUTING_TURNS);
let messages_vec: &[&Message] = &filtered[start..];
// Ensure the conversation does not exceed the configured token budget.
// We use `len() / TOKEN_LENGTH_DIVISOR` as a cheap token estimate to
// avoid running a real tokenizer on the hot path.
let mut token_count = ARCH_ORCHESTRATOR_V1_SYSTEM_PROMPT.len() / TOKEN_LENGTH_DIVISOR; let mut token_count = ARCH_ORCHESTRATOR_V1_SYSTEM_PROMPT.len() / TOKEN_LENGTH_DIVISOR;
let mut selected_messages_list_reversed: Vec<Message> = vec![]; let mut selected_messages_list_reversed: Vec<Message> = vec![];
for (selected_messsage_count, message) in messages_vec.iter().rev().enumerate() { for (selected_messsage_count, message) in messages_vec.iter().rev().enumerate() {
@ -206,16 +224,19 @@ impl OrchestratorModel for OrchestratorModelV1 {
total = messages_vec.len(), total = messages_vec.len(),
"token count exceeds max, truncating conversation" "token count exceeds max, truncating conversation"
); );
// If the message that overflows the budget is from the user we need // If the overflow message is from the user we need to keep
// to keep at least some of it so the orchestrator sees the latest // some of it so the orchestrator still sees the latest user
// user intent. Trim from the end of the message toward the // intent. Use a middle-trim (head + ellipsis + tail): users
// beginning until we fit in the remaining token budget. // often frame the task at the start AND put the actual ask
// at the end of a long pasted block, so preserving both is
// better than a head-only cut. The ellipsis also signals to
// the router model that content was dropped.
if message.role == Role::User && remaining_tokens > 0 { if message.role == Role::User && remaining_tokens > 0 {
let max_bytes = remaining_tokens.saturating_mul(TOKEN_LENGTH_DIVISOR); let max_bytes = remaining_tokens.saturating_mul(TOKEN_LENGTH_DIVISOR);
let truncated = truncate_to_utf8_boundary(&message_text, max_bytes); let truncated = trim_middle_utf8(&message_text, max_bytes);
selected_messages_list_reversed.push(Message { selected_messages_list_reversed.push(Message {
role: Role::User, role: Role::User,
content: Some(MessageContent::Text(truncated.to_string())), content: Some(MessageContent::Text(truncated)),
name: None, name: None,
tool_calls: None, tool_calls: None,
tool_call_id: None, tool_call_id: None,
@ -416,18 +437,43 @@ fn fix_json_response(body: &str) -> String {
body.replace("'", "\"").replace("\\n", "") body.replace("'", "\"").replace("\\n", "")
} }
/// Truncate `s` so that the returned slice is at most `max_bytes` bytes long /// Truncate `s` so the result is at most `max_bytes` bytes long, keeping
/// and ends on a UTF-8 character boundary. Keeps the beginning of the string /// roughly 60% from the start and 40% from the end, with a Unicode ellipsis
/// and drops characters from the end. /// separating the two. All splits respect UTF-8 character boundaries. When
fn truncate_to_utf8_boundary(s: &str, max_bytes: usize) -> &str { /// `max_bytes` is too small to fit the marker at all, falls back to a
/// head-only truncation.
fn trim_middle_utf8(s: &str, max_bytes: usize) -> String {
if s.len() <= max_bytes { if s.len() <= max_bytes {
return s; return s.to_string();
} }
if max_bytes <= TRIM_MARKER.len() {
// Not enough room even for the marker — just keep the start.
let mut end = max_bytes; let mut end = max_bytes;
while end > 0 && !s.is_char_boundary(end) { while end > 0 && !s.is_char_boundary(end) {
end -= 1; end -= 1;
} }
&s[..end] return s[..end].to_string();
}
let available = max_bytes - TRIM_MARKER.len();
// Bias toward the start (60%) where task framing typically lives, while
// still preserving ~40% of the tail where the user's actual ask often
// appears after a long paste.
let mut start_len = available * 3 / 5;
while start_len > 0 && !s.is_char_boundary(start_len) {
start_len -= 1;
}
let end_len = available - start_len;
let mut end_start = s.len().saturating_sub(end_len);
while end_start < s.len() && !s.is_char_boundary(end_start) {
end_start += 1;
}
let mut out = String::with_capacity(start_len + TRIM_MARKER.len() + (s.len() - end_start));
out.push_str(&s[..start_len]);
out.push_str(TRIM_MARKER);
out.push_str(&s[end_start..]);
out
} }
impl std::fmt::Debug for dyn OrchestratorModel { impl std::fmt::Debug for dyn OrchestratorModel {
@ -802,10 +848,10 @@ If no routes are needed, return an empty list for `route`.
#[test] #[test]
fn test_conversation_trim_upto_user_message() { fn test_conversation_trim_upto_user_message() {
// With max_token_length=230 the older user message "given the image In // With max_token_length=230, the older user message "given the image
// style of Andy Warhol" overflows the remaining budget and is // In style of Andy Warhol" overflows the remaining budget and gets
// truncated from the end (chars dropped from the end of the string) // middle-trimmed (head + ellipsis + tail) until it fits. Newer turns
// until it fits. The newer assistant/user turns are preserved in full. // are kept in full.
let expected_prompt = r#" let expected_prompt = r#"
You are a helpful assistant that selects the most suitable routes based on user intent. You are a helpful assistant that selects the most suitable routes based on user intent.
You are provided with a list of available routes enclosed within <routes></routes> XML tags: You are provided with a list of available routes enclosed within <routes></routes> XML tags:
@ -818,7 +864,7 @@ You are also given the conversation context enclosed within <conversation></conv
[ [
{ {
"role": "user", "role": "user",
"content": "given the im" "content": "given…rhol"
}, },
{ {
"role": "assistant", "role": "assistant",
@ -892,12 +938,14 @@ If no routes are needed, return an empty list for `route`.
} }
#[test] #[test]
fn test_huge_single_user_message_is_truncated() { fn test_huge_single_user_message_is_middle_trimmed() {
// Regression test for the case where a single, extremely large user // Regression test for the case where a single, extremely large user
// message was being passed through to the orchestrator verbatim, // message was being passed to the orchestrator verbatim and blowing
// blowing past the upstream model's context window. The trimmer must // past the upstream model's context window. The trimmer must now
// now truncate the oversized user message from the end until it fits // middle-trim (head + ellipsis + tail) the oversized message so the
// within the configured budget. // resulting request stays within the configured budget, and the
// trim marker must be present so the router model knows content
// was dropped.
let orchestrations_str = r#" let orchestrations_str = r#"
{ {
"gpt-4o": [ "gpt-4o": [
@ -917,9 +965,13 @@ If no routes are needed, return an empty list for `route`.
max_token_length, max_token_length,
); );
// ~500KB of content — similar in scale to the real payload that // ~500KB of content — same scale as the real payload that triggered
// triggered the upstream 400 "context length exceeded". // the production upstream 400.
let huge_user_content = "A".repeat(500_000); let head = "HEAD_MARKER_START ";
let tail = " TAIL_MARKER_END";
let filler = "A".repeat(500_000);
let huge_user_content = format!("{head}{filler}{tail}");
let conversation = vec![Message { let conversation = vec![Message {
role: Role::User, role: Role::User,
content: Some(MessageContent::Text(huge_user_content.clone())), content: Some(MessageContent::Text(huge_user_content.clone())),
@ -931,13 +983,11 @@ If no routes are needed, return an empty list for `route`.
let req = orchestrator.generate_request(&conversation, &None); let req = orchestrator.generate_request(&conversation, &None);
let prompt = req.messages[0].content.extract_text(); let prompt = req.messages[0].content.extract_text();
// Final prompt must be bounded. Use a generous ceiling: the configured // Prompt must stay bounded. Generous ceiling = budget-in-bytes +
// budget converted to bytes (tokens * divisor) plus the system prompt // scaffolding + slack. Real result should be well under this.
// and routes JSON overhead. In practice the result should be well
// under this ceiling.
let byte_ceiling = max_token_length * TOKEN_LENGTH_DIVISOR let byte_ceiling = max_token_length * TOKEN_LENGTH_DIVISOR
+ ARCH_ORCHESTRATOR_V1_SYSTEM_PROMPT.len() + ARCH_ORCHESTRATOR_V1_SYSTEM_PROMPT.len()
+ 512; + 1024;
assert!( assert!(
prompt.len() < byte_ceiling, prompt.len() < byte_ceiling,
"prompt length {} exceeded ceiling {} — truncation did not apply", "prompt length {} exceeded ceiling {} — truncation did not apply",
@ -945,24 +995,132 @@ If no routes are needed, return an empty list for `route`.
byte_ceiling, byte_ceiling,
); );
// The oversized user message must have been truncated — i.e. not all // Not all 500k filler chars survive.
// 500k "A" characters made it through.
let a_count = prompt.chars().filter(|c| *c == 'A').count(); let a_count = prompt.chars().filter(|c| *c == 'A').count();
assert!( assert!(
a_count < huge_user_content.len(), a_count < filler.len(),
"expected user message to be truncated, but all {} 'A' chars survived", "expected user message to be truncated; all {} 'A's survived",
a_count a_count
); );
assert!( assert!(
a_count > 0, a_count > 0,
"expected at least some of the user message to survive truncation" "expected some of the user message to survive truncation"
); );
// Sanity: the prompt still includes the routing prompt scaffolding. // Head and tail of the message must both be preserved (that's the
// whole point of middle-trim over head-only).
assert!(
prompt.contains(head),
"head marker missing — head was not preserved"
);
assert!(
prompt.contains(tail),
"tail marker missing — tail was not preserved"
);
// Trim marker must be present so the router model can see that
// content was omitted.
assert!(
prompt.contains(TRIM_MARKER),
"ellipsis trim marker missing from truncated prompt"
);
// Routing prompt scaffolding remains intact.
assert!(prompt.contains("<conversation>")); assert!(prompt.contains("<conversation>"));
assert!(prompt.contains("<routes>")); assert!(prompt.contains("<routes>"));
} }
#[test]
fn test_turn_cap_limits_routing_history() {
    // Outer turn-cap guardrail: no matter how long the conversation is, only
    // the newest `MAX_ROUTING_TURNS` filtered messages may reach the prompt.
    // Every message carries a unique zero-padded tag so presence/absence can
    // be checked by substring without one tag being a prefix of another.
    let orchestrations_str = r#"
{
"gpt-4o": [
{"name": "Image generation", "description": "generating image"}
]
}
"#;
    let agent_orchestrations: HashMap<String, Vec<OrchestrationPreference>> =
        serde_json::from_str(orchestrations_str).unwrap();

    // usize::MAX as the token budget keeps the token-budget loop out of the
    // picture so that only the turn cap can drop messages.
    let orchestrator =
        OrchestratorModelV1::new(agent_orchestrations, "test-model".to_string(), usize::MAX);

    let tag_for = |idx: usize| format!("turn-{idx:03}");

    // Twice the cap: alternating user (even index) / assistant (odd index).
    let total_turns = MAX_ROUTING_TURNS * 2;
    let conversation: Vec<Message> = (0..total_turns)
        .map(|idx| Message {
            role: if idx % 2 == 0 {
                Role::User
            } else {
                Role::Assistant
            },
            content: Some(MessageContent::Text(tag_for(idx))),
            name: None,
            tool_calls: None,
            tool_call_id: None,
        })
        .collect();

    let req = orchestrator.generate_request(&conversation, &None);
    let prompt = req.messages[0].content.extract_text();

    // Index of the oldest message that is still allowed through the cap.
    let first_kept = total_turns - MAX_ROUTING_TURNS;

    // Everything from `first_kept` onwards must be present...
    for idx in first_kept..total_turns {
        let tag = tag_for(idx);
        assert!(
            prompt.contains(&tag),
            "expected recent turn tag {tag} to be present"
        );
    }

    // ...and everything before it must have been dropped.
    for idx in 0..first_kept {
        let tag = tag_for(idx);
        assert!(
            !prompt.contains(&tag),
            "old turn tag {tag} leaked past turn cap into the prompt"
        );
    }
}
#[test]
fn test_trim_middle_utf8_helper() {
    // Unit test for `trim_middle_utf8`. Covers the no-op path, the 60/40
    // head/tail split, the head-only fallback when the marker does not fit,
    // and UTF-8 boundary safety on both paths. The exact expected strings
    // below assume `TRIM_MARKER` is the 3-byte Unicode ellipsis.

    // No-op when the input already fits, including the exact-fit boundary
    // and the empty string.
    assert_eq!(trim_middle_utf8("hello", 100), "hello");
    assert_eq!(trim_middle_utf8("hello", 5), "hello");
    assert_eq!(trim_middle_utf8("", 0), "");

    // 60/40 split with ellipsis when too long: a 10-byte budget minus the
    // 3-byte marker leaves 7 bytes, split as 4 head + 3 tail.
    let long = "a".repeat(20);
    let out = trim_middle_utf8(&long, 10);
    assert!(out.len() <= 10);
    assert!(out.contains(TRIM_MARKER));
    // Exactly one ellipsis, rest are 'a's.
    assert_eq!(out.matches(TRIM_MARKER).count(), 1);
    assert_eq!(out, format!("aaaa{TRIM_MARKER}aaa"));

    // Same split on distinguishable input, proving the head really comes
    // from the start of the message and the tail from its end.
    let out = trim_middle_utf8("0123456789abcdefghij", 10);
    assert_eq!(out, format!("0123{TRIM_MARKER}hij"));

    // When max_bytes is smaller than OR EQUAL TO the marker length, falls
    // back to head-only truncation (no marker).
    assert_eq!(trim_middle_utf8("abcdefgh", 2), "ab");
    assert_eq!(trim_middle_utf8("abcdefgh", TRIM_MARKER.len()), "abc");
    assert_eq!(trim_middle_utf8("abcdefgh", 0), "");

    // One byte past the marker length is the smallest budget that emits the
    // marker: 1 spare byte, 0 go to the head (1 * 3 / 5 == 0), 1 to the tail.
    let out = trim_middle_utf8("abcdefgh", TRIM_MARKER.len() + 1);
    assert_eq!(out, format!("{TRIM_MARKER}h"));

    // The fallback path must also respect UTF-8 boundaries: byte 3 falls in
    // the middle of the second 'é', so the cut backs off to byte 2.
    assert_eq!(trim_middle_utf8("ééé", 3), "é");

    // UTF-8 boundary safety on the split path: 2-byte chars.
    let s = "é".repeat(50); // 100 bytes
    let out = trim_middle_utf8(&s, 25);
    assert!(out.len() <= 25);
    // Must still be valid UTF-8 that only contains 'é' and the marker.
    let ok = out.chars().all(|c| c == 'é' || TRIM_MARKER.contains(c));
    assert!(ok, "unexpected char in trimmed output: {out:?}");
    // 22 spare bytes: the head target of 13 backs off to 12 (6 chars), and
    // the remaining 10 bytes of tail already start on a boundary (5 chars).
    assert_eq!(
        out,
        format!("{}{}{}", "é".repeat(6), TRIM_MARKER, "é".repeat(5))
    );

    // With a 24-byte budget the tail start (byte 91) lands mid-character
    // and must round UP to byte 92, so the result comes in under budget.
    let out = trim_middle_utf8(&s, 24);
    assert!(out.len() <= 24);
    assert_eq!(
        out,
        format!("{}{}{}", "é".repeat(6), TRIM_MARKER, "é".repeat(4))
    );
}
#[test] #[test]
fn test_non_text_input() { fn test_non_text_input() {
let expected_prompt = r#" let expected_prompt = r#"