mirror of
https://github.com/katanemo/plano.git
synced 2026-04-25 00:36:34 +02:00
fix: truncate oversized user messages in orchestrator routing prompt (#895)
This commit is contained in:
parent
37600fd07a
commit
95a7beaab3
2 changed files with 423 additions and 39 deletions
|
|
@ -1,8 +1,14 @@
|
||||||
use hermesllm::apis::openai::ChatCompletionsResponse;
|
use hermesllm::apis::openai::ChatCompletionsResponse;
|
||||||
use hyper::header;
|
use hyper::header;
|
||||||
|
use serde::Deserialize;
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use tracing::warn;
|
use tracing::warn;
|
||||||
|
|
||||||
|
/// Max bytes of raw upstream body we include in a log message or error text
|
||||||
|
/// when the body is not a recognizable error envelope. Keeps logs from being
|
||||||
|
/// flooded by huge HTML error pages.
|
||||||
|
const RAW_BODY_LOG_LIMIT: usize = 512;
|
||||||
|
|
||||||
#[derive(Debug, Error)]
|
#[derive(Debug, Error)]
|
||||||
pub enum HttpError {
|
pub enum HttpError {
|
||||||
#[error("Failed to send request: {0}")]
|
#[error("Failed to send request: {0}")]
|
||||||
|
|
@ -10,13 +16,64 @@ pub enum HttpError {
|
||||||
|
|
||||||
#[error("Failed to parse JSON response: {0}")]
|
#[error("Failed to parse JSON response: {0}")]
|
||||||
Json(serde_json::Error, String),
|
Json(serde_json::Error, String),
|
||||||
|
|
||||||
|
#[error("Upstream returned {status}: {message}")]
|
||||||
|
Upstream { status: u16, message: String },
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Shape of an OpenAI-style error response body, e.g.
|
||||||
|
/// `{"error": {"message": "...", "type": "...", "param": "...", "code": ...}}`.
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct UpstreamErrorEnvelope {
|
||||||
|
error: UpstreamErrorBody,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct UpstreamErrorBody {
|
||||||
|
message: String,
|
||||||
|
#[serde(default, rename = "type")]
|
||||||
|
err_type: Option<String>,
|
||||||
|
#[serde(default)]
|
||||||
|
param: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Extract a human-readable error message from an upstream response body.
|
||||||
|
/// Tries to parse an OpenAI-style `{"error": {"message": ...}}` envelope; if
|
||||||
|
/// that fails, falls back to the first `RAW_BODY_LOG_LIMIT` bytes of the raw
|
||||||
|
/// body (UTF-8 safe).
|
||||||
|
fn extract_upstream_error_message(body: &str) -> String {
|
||||||
|
if let Ok(env) = serde_json::from_str::<UpstreamErrorEnvelope>(body) {
|
||||||
|
let mut msg = env.error.message;
|
||||||
|
if let Some(param) = env.error.param {
|
||||||
|
msg.push_str(&format!(" (param={param})"));
|
||||||
|
}
|
||||||
|
if let Some(err_type) = env.error.err_type {
|
||||||
|
msg.push_str(&format!(" [type={err_type}]"));
|
||||||
|
}
|
||||||
|
return msg;
|
||||||
|
}
|
||||||
|
truncate_for_log(body).to_string()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn truncate_for_log(s: &str) -> &str {
|
||||||
|
if s.len() <= RAW_BODY_LOG_LIMIT {
|
||||||
|
return s;
|
||||||
|
}
|
||||||
|
let mut end = RAW_BODY_LOG_LIMIT;
|
||||||
|
while end > 0 && !s.is_char_boundary(end) {
|
||||||
|
end -= 1;
|
||||||
|
}
|
||||||
|
&s[..end]
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Sends a POST request to the given URL and extracts the text content
|
/// Sends a POST request to the given URL and extracts the text content
|
||||||
/// from the first choice of the `ChatCompletionsResponse`.
|
/// from the first choice of the `ChatCompletionsResponse`.
|
||||||
///
|
///
|
||||||
/// Returns `Some((content, elapsed))` on success, or `None` if the response
|
/// Returns `Some((content, elapsed))` on success, `None` if the response
|
||||||
/// had no choices or the first choice had no content.
|
/// had no choices or the first choice had no content. Returns
|
||||||
|
/// `HttpError::Upstream` for any non-2xx status, carrying a message
|
||||||
|
/// extracted from the OpenAI-style error envelope (or a truncated raw body
|
||||||
|
/// if the body is not in that shape).
|
||||||
pub async fn post_and_extract_content(
|
pub async fn post_and_extract_content(
|
||||||
client: &reqwest::Client,
|
client: &reqwest::Client,
|
||||||
url: &str,
|
url: &str,
|
||||||
|
|
@ -26,17 +83,36 @@ pub async fn post_and_extract_content(
|
||||||
let start_time = std::time::Instant::now();
|
let start_time = std::time::Instant::now();
|
||||||
|
|
||||||
let res = client.post(url).headers(headers).body(body).send().await?;
|
let res = client.post(url).headers(headers).body(body).send().await?;
|
||||||
|
let status = res.status();
|
||||||
|
|
||||||
let body = res.text().await?;
|
let body = res.text().await?;
|
||||||
let elapsed = start_time.elapsed();
|
let elapsed = start_time.elapsed();
|
||||||
|
|
||||||
|
if !status.is_success() {
|
||||||
|
let message = extract_upstream_error_message(&body);
|
||||||
|
warn!(
|
||||||
|
status = status.as_u16(),
|
||||||
|
message = %message,
|
||||||
|
body_size = body.len(),
|
||||||
|
"upstream returned error response"
|
||||||
|
);
|
||||||
|
return Err(HttpError::Upstream {
|
||||||
|
status: status.as_u16(),
|
||||||
|
message,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
let response: ChatCompletionsResponse = serde_json::from_str(&body).map_err(|err| {
|
let response: ChatCompletionsResponse = serde_json::from_str(&body).map_err(|err| {
|
||||||
warn!(error = %err, body = %body, "failed to parse json response");
|
warn!(
|
||||||
|
error = %err,
|
||||||
|
body = %truncate_for_log(&body),
|
||||||
|
"failed to parse json response",
|
||||||
|
);
|
||||||
HttpError::Json(err, format!("Failed to parse JSON: {}", body))
|
HttpError::Json(err, format!("Failed to parse JSON: {}", body))
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
if response.choices.is_empty() {
|
if response.choices.is_empty() {
|
||||||
warn!(body = %body, "no choices in response");
|
warn!(body = %truncate_for_log(&body), "no choices in response");
|
||||||
return Ok(None);
|
return Ok(None);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -46,3 +122,52 @@ pub async fn post_and_extract_content(
|
||||||
.as_ref()
|
.as_ref()
|
||||||
.map(|c| (c.clone(), elapsed)))
|
.map(|c| (c.clone(), elapsed)))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn extracts_message_from_openai_style_error_envelope() {
|
||||||
|
let body = r#"{"error":{"code":400,"message":"This model's maximum context length is 32768 tokens. However, you requested 0 output tokens and your prompt contains at least 32769 input tokens, for a total of at least 32769 tokens.","param":"input_tokens","type":"BadRequestError"}}"#;
|
||||||
|
let msg = extract_upstream_error_message(body);
|
||||||
|
assert!(
|
||||||
|
msg.starts_with("This model's maximum context length is 32768 tokens."),
|
||||||
|
"unexpected message: {msg}"
|
||||||
|
);
|
||||||
|
assert!(msg.contains("(param=input_tokens)"));
|
||||||
|
assert!(msg.contains("[type=BadRequestError]"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn extracts_message_without_optional_fields() {
|
||||||
|
let body = r#"{"error":{"message":"something broke"}}"#;
|
||||||
|
let msg = extract_upstream_error_message(body);
|
||||||
|
assert_eq!(msg, "something broke");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn falls_back_to_raw_body_when_not_error_envelope() {
|
||||||
|
let body = "<html><body>502 Bad Gateway</body></html>";
|
||||||
|
let msg = extract_upstream_error_message(body);
|
||||||
|
assert_eq!(msg, body);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn truncates_non_envelope_bodies_in_logs() {
|
||||||
|
let body = "x".repeat(RAW_BODY_LOG_LIMIT * 3);
|
||||||
|
let msg = extract_upstream_error_message(&body);
|
||||||
|
assert_eq!(msg.len(), RAW_BODY_LOG_LIMIT);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn truncate_for_log_respects_utf8_boundaries() {
|
||||||
|
// 2-byte characters; picking a length that would split mid-char.
|
||||||
|
let body = "é".repeat(RAW_BODY_LOG_LIMIT);
|
||||||
|
let out = truncate_for_log(&body);
|
||||||
|
// Should be a valid &str (implicit — would panic if we returned
|
||||||
|
// a non-boundary slice) and at most RAW_BODY_LOG_LIMIT bytes.
|
||||||
|
assert!(out.len() <= RAW_BODY_LOG_LIMIT);
|
||||||
|
assert!(out.chars().all(|c| c == 'é'));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -10,6 +10,18 @@ use super::orchestrator_model::{OrchestratorModel, OrchestratorModelError};
|
||||||
|
|
||||||
pub const MAX_TOKEN_LEN: usize = 8192; // Default max token length for the orchestration model
|
pub const MAX_TOKEN_LEN: usize = 8192; // Default max token length for the orchestration model
|
||||||
|
|
||||||
|
/// Hard cap on the number of recent messages considered when building the
|
||||||
|
/// routing prompt. Bounds prompt growth for long-running conversations and
|
||||||
|
/// acts as an outer guardrail before the token-budget loop runs. The most
|
||||||
|
/// recent `MAX_ROUTING_TURNS` filtered messages are kept; older turns are
|
||||||
|
/// dropped entirely.
|
||||||
|
pub const MAX_ROUTING_TURNS: usize = 16;
|
||||||
|
|
||||||
|
/// Unicode ellipsis used to mark where content was trimmed out of a long
|
||||||
|
/// message. Helps signal to the downstream router model that the message was
|
||||||
|
/// truncated.
|
||||||
|
const TRIM_MARKER: &str = "…";
|
||||||
|
|
||||||
/// Custom JSON formatter that produces spaced JSON (space after colons and commas), same as JSON in python
|
/// Custom JSON formatter that produces spaced JSON (space after colons and commas), same as JSON in python
|
||||||
struct SpacedJsonFormatter;
|
struct SpacedJsonFormatter;
|
||||||
|
|
||||||
|
|
@ -176,10 +188,9 @@ impl OrchestratorModel for OrchestratorModelV1 {
|
||||||
messages: &[Message],
|
messages: &[Message],
|
||||||
usage_preferences_from_request: &Option<Vec<AgentUsagePreference>>,
|
usage_preferences_from_request: &Option<Vec<AgentUsagePreference>>,
|
||||||
) -> ChatCompletionsRequest {
|
) -> ChatCompletionsRequest {
|
||||||
// remove system prompt, tool calls, tool call response and messages without content
|
// Remove system/developer/tool messages and messages without extractable
|
||||||
// if content is empty its likely a tool call
|
// text (tool calls have no text content we can classify against).
|
||||||
// when role == tool its tool call response
|
let filtered: Vec<&Message> = messages
|
||||||
let messages_vec = messages
|
|
||||||
.iter()
|
.iter()
|
||||||
.filter(|m| {
|
.filter(|m| {
|
||||||
m.role != Role::System
|
m.role != Role::System
|
||||||
|
|
@ -187,37 +198,72 @@ impl OrchestratorModel for OrchestratorModelV1 {
|
||||||
&& m.role != Role::Tool
|
&& m.role != Role::Tool
|
||||||
&& !m.content.extract_text().is_empty()
|
&& !m.content.extract_text().is_empty()
|
||||||
})
|
})
|
||||||
.collect::<Vec<&Message>>();
|
.collect();
|
||||||
|
|
||||||
// Following code is to ensure that the conversation does not exceed max token length
|
// Outer guardrail: only consider the last `MAX_ROUTING_TURNS` filtered
|
||||||
// Note: we use a simple heuristic to estimate token count based on character length to optimize for performance
|
// messages when building the routing prompt. Keeps prompt growth
|
||||||
|
// predictable for long conversations regardless of per-message size.
|
||||||
|
let start = filtered.len().saturating_sub(MAX_ROUTING_TURNS);
|
||||||
|
let messages_vec: &[&Message] = &filtered[start..];
|
||||||
|
|
||||||
|
// Ensure the conversation does not exceed the configured token budget.
|
||||||
|
// We use `len() / TOKEN_LENGTH_DIVISOR` as a cheap token estimate to
|
||||||
|
// avoid running a real tokenizer on the hot path.
|
||||||
let mut token_count = ARCH_ORCHESTRATOR_V1_SYSTEM_PROMPT.len() / TOKEN_LENGTH_DIVISOR;
|
let mut token_count = ARCH_ORCHESTRATOR_V1_SYSTEM_PROMPT.len() / TOKEN_LENGTH_DIVISOR;
|
||||||
let mut selected_messages_list_reversed: Vec<&Message> = vec![];
|
let mut selected_messages_list_reversed: Vec<Message> = vec![];
|
||||||
for (selected_messsage_count, message) in messages_vec.iter().rev().enumerate() {
|
for (selected_messsage_count, message) in messages_vec.iter().rev().enumerate() {
|
||||||
let message_token_count = message.content.extract_text().len() / TOKEN_LENGTH_DIVISOR;
|
let message_text = message.content.extract_text();
|
||||||
token_count += message_token_count;
|
let message_token_count = message_text.len() / TOKEN_LENGTH_DIVISOR;
|
||||||
if token_count > self.max_token_length {
|
if token_count + message_token_count > self.max_token_length {
|
||||||
|
let remaining_tokens = self.max_token_length.saturating_sub(token_count);
|
||||||
debug!(
|
debug!(
|
||||||
token_count = token_count,
|
attempted_total_tokens = token_count + message_token_count,
|
||||||
max_tokens = self.max_token_length,
|
max_tokens = self.max_token_length,
|
||||||
|
remaining_tokens,
|
||||||
selected = selected_messsage_count,
|
selected = selected_messsage_count,
|
||||||
total = messages_vec.len(),
|
total = messages_vec.len(),
|
||||||
"token count exceeds max, truncating conversation"
|
"token count exceeds max, truncating conversation"
|
||||||
);
|
);
|
||||||
if message.role == Role::User {
|
// If the overflow message is from the user we need to keep
|
||||||
// If message that exceeds max token length is from user, we need to keep it
|
// some of it so the orchestrator still sees the latest user
|
||||||
selected_messages_list_reversed.push(message);
|
// intent. Use a middle-trim (head + ellipsis + tail): users
|
||||||
|
// often frame the task at the start AND put the actual ask
|
||||||
|
// at the end of a long pasted block, so preserving both is
|
||||||
|
// better than a head-only cut. The ellipsis also signals to
|
||||||
|
// the router model that content was dropped.
|
||||||
|
if message.role == Role::User && remaining_tokens > 0 {
|
||||||
|
let max_bytes = remaining_tokens.saturating_mul(TOKEN_LENGTH_DIVISOR);
|
||||||
|
let truncated = trim_middle_utf8(&message_text, max_bytes);
|
||||||
|
selected_messages_list_reversed.push(Message {
|
||||||
|
role: Role::User,
|
||||||
|
content: Some(MessageContent::Text(truncated)),
|
||||||
|
name: None,
|
||||||
|
tool_calls: None,
|
||||||
|
tool_call_id: None,
|
||||||
|
});
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
// If we are here, it means that the message is within the max token length
|
token_count += message_token_count;
|
||||||
selected_messages_list_reversed.push(message);
|
selected_messages_list_reversed.push(Message {
|
||||||
|
role: message.role.clone(),
|
||||||
|
content: Some(MessageContent::Text(message_text)),
|
||||||
|
name: None,
|
||||||
|
tool_calls: None,
|
||||||
|
tool_call_id: None,
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
if selected_messages_list_reversed.is_empty() {
|
if selected_messages_list_reversed.is_empty() {
|
||||||
debug!("no messages selected, using last message");
|
debug!("no messages selected, using last message");
|
||||||
if let Some(last_message) = messages_vec.last() {
|
if let Some(last_message) = messages_vec.last() {
|
||||||
selected_messages_list_reversed.push(last_message);
|
selected_messages_list_reversed.push(Message {
|
||||||
|
role: last_message.role.clone(),
|
||||||
|
content: Some(MessageContent::Text(last_message.content.extract_text())),
|
||||||
|
name: None,
|
||||||
|
tool_calls: None,
|
||||||
|
tool_call_id: None,
|
||||||
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -237,22 +283,8 @@ impl OrchestratorModel for OrchestratorModelV1 {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reverse the selected messages to maintain the conversation order
|
// Reverse the selected messages to maintain the conversation order
|
||||||
let selected_conversation_list = selected_messages_list_reversed
|
let selected_conversation_list: Vec<Message> =
|
||||||
.iter()
|
selected_messages_list_reversed.into_iter().rev().collect();
|
||||||
.rev()
|
|
||||||
.map(|message| Message {
|
|
||||||
role: message.role.clone(),
|
|
||||||
content: Some(MessageContent::Text(
|
|
||||||
message
|
|
||||||
.content
|
|
||||||
.as_ref()
|
|
||||||
.map_or(String::new(), |c| c.to_string()),
|
|
||||||
)),
|
|
||||||
name: None,
|
|
||||||
tool_calls: None,
|
|
||||||
tool_call_id: None,
|
|
||||||
})
|
|
||||||
.collect::<Vec<Message>>();
|
|
||||||
|
|
||||||
// Generate the orchestrator request message based on the usage preferences.
|
// Generate the orchestrator request message based on the usage preferences.
|
||||||
// If preferences are passed in request then we use them;
|
// If preferences are passed in request then we use them;
|
||||||
|
|
@ -405,6 +437,45 @@ fn fix_json_response(body: &str) -> String {
|
||||||
body.replace("'", "\"").replace("\\n", "")
|
body.replace("'", "\"").replace("\\n", "")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Truncate `s` so the result is at most `max_bytes` bytes long, keeping
|
||||||
|
/// roughly 60% from the start and 40% from the end, with a Unicode ellipsis
|
||||||
|
/// separating the two. All splits respect UTF-8 character boundaries. When
|
||||||
|
/// `max_bytes` is too small to fit the marker at all, falls back to a
|
||||||
|
/// head-only truncation.
|
||||||
|
fn trim_middle_utf8(s: &str, max_bytes: usize) -> String {
|
||||||
|
if s.len() <= max_bytes {
|
||||||
|
return s.to_string();
|
||||||
|
}
|
||||||
|
if max_bytes <= TRIM_MARKER.len() {
|
||||||
|
// Not enough room even for the marker — just keep the start.
|
||||||
|
let mut end = max_bytes;
|
||||||
|
while end > 0 && !s.is_char_boundary(end) {
|
||||||
|
end -= 1;
|
||||||
|
}
|
||||||
|
return s[..end].to_string();
|
||||||
|
}
|
||||||
|
|
||||||
|
let available = max_bytes - TRIM_MARKER.len();
|
||||||
|
// Bias toward the start (60%) where task framing typically lives, while
|
||||||
|
// still preserving ~40% of the tail where the user's actual ask often
|
||||||
|
// appears after a long paste.
|
||||||
|
let mut start_len = available * 3 / 5;
|
||||||
|
while start_len > 0 && !s.is_char_boundary(start_len) {
|
||||||
|
start_len -= 1;
|
||||||
|
}
|
||||||
|
let end_len = available - start_len;
|
||||||
|
let mut end_start = s.len().saturating_sub(end_len);
|
||||||
|
while end_start < s.len() && !s.is_char_boundary(end_start) {
|
||||||
|
end_start += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut out = String::with_capacity(start_len + TRIM_MARKER.len() + (s.len() - end_start));
|
||||||
|
out.push_str(&s[..start_len]);
|
||||||
|
out.push_str(TRIM_MARKER);
|
||||||
|
out.push_str(&s[end_start..]);
|
||||||
|
out
|
||||||
|
}
|
||||||
|
|
||||||
impl std::fmt::Debug for dyn OrchestratorModel {
|
impl std::fmt::Debug for dyn OrchestratorModel {
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
write!(f, "OrchestratorModel")
|
write!(f, "OrchestratorModel")
|
||||||
|
|
@ -777,6 +848,10 @@ If no routes are needed, return an empty list for `route`.
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_conversation_trim_upto_user_message() {
|
fn test_conversation_trim_upto_user_message() {
|
||||||
|
// With max_token_length=230, the older user message "given the image
|
||||||
|
// In style of Andy Warhol" overflows the remaining budget and gets
|
||||||
|
// middle-trimmed (head + ellipsis + tail) until it fits. Newer turns
|
||||||
|
// are kept in full.
|
||||||
let expected_prompt = r#"
|
let expected_prompt = r#"
|
||||||
You are a helpful assistant that selects the most suitable routes based on user intent.
|
You are a helpful assistant that selects the most suitable routes based on user intent.
|
||||||
You are provided with a list of available routes enclosed within <routes></routes> XML tags:
|
You are provided with a list of available routes enclosed within <routes></routes> XML tags:
|
||||||
|
|
@ -789,7 +864,7 @@ You are also given the conversation context enclosed within <conversation></conv
|
||||||
[
|
[
|
||||||
{
|
{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": "given the image In style of Andy Warhol"
|
"content": "given…rhol"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
|
|
@ -862,6 +937,190 @@ If no routes are needed, return an empty list for `route`.
|
||||||
assert_eq!(expected_prompt, prompt);
|
assert_eq!(expected_prompt, prompt);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_huge_single_user_message_is_middle_trimmed() {
|
||||||
|
// Regression test for the case where a single, extremely large user
|
||||||
|
// message was being passed to the orchestrator verbatim and blowing
|
||||||
|
// past the upstream model's context window. The trimmer must now
|
||||||
|
// middle-trim (head + ellipsis + tail) the oversized message so the
|
||||||
|
// resulting request stays within the configured budget, and the
|
||||||
|
// trim marker must be present so the router model knows content
|
||||||
|
// was dropped.
|
||||||
|
let orchestrations_str = r#"
|
||||||
|
{
|
||||||
|
"gpt-4o": [
|
||||||
|
{"name": "Image generation", "description": "generating image"}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
"#;
|
||||||
|
let agent_orchestrations = serde_json::from_str::<
|
||||||
|
HashMap<String, Vec<OrchestrationPreference>>,
|
||||||
|
>(orchestrations_str)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let max_token_length = 2048;
|
||||||
|
let orchestrator = OrchestratorModelV1::new(
|
||||||
|
agent_orchestrations,
|
||||||
|
"test-model".to_string(),
|
||||||
|
max_token_length,
|
||||||
|
);
|
||||||
|
|
||||||
|
// ~500KB of content — same scale as the real payload that triggered
|
||||||
|
// the production upstream 400.
|
||||||
|
let head = "HEAD_MARKER_START ";
|
||||||
|
let tail = " TAIL_MARKER_END";
|
||||||
|
let filler = "A".repeat(500_000);
|
||||||
|
let huge_user_content = format!("{head}{filler}{tail}");
|
||||||
|
|
||||||
|
let conversation = vec![Message {
|
||||||
|
role: Role::User,
|
||||||
|
content: Some(MessageContent::Text(huge_user_content.clone())),
|
||||||
|
name: None,
|
||||||
|
tool_calls: None,
|
||||||
|
tool_call_id: None,
|
||||||
|
}];
|
||||||
|
|
||||||
|
let req = orchestrator.generate_request(&conversation, &None);
|
||||||
|
let prompt = req.messages[0].content.extract_text();
|
||||||
|
|
||||||
|
// Prompt must stay bounded. Generous ceiling = budget-in-bytes +
|
||||||
|
// scaffolding + slack. Real result should be well under this.
|
||||||
|
let byte_ceiling = max_token_length * TOKEN_LENGTH_DIVISOR
|
||||||
|
+ ARCH_ORCHESTRATOR_V1_SYSTEM_PROMPT.len()
|
||||||
|
+ 1024;
|
||||||
|
assert!(
|
||||||
|
prompt.len() < byte_ceiling,
|
||||||
|
"prompt length {} exceeded ceiling {} — truncation did not apply",
|
||||||
|
prompt.len(),
|
||||||
|
byte_ceiling,
|
||||||
|
);
|
||||||
|
|
||||||
|
// Not all 500k filler chars survive.
|
||||||
|
let a_count = prompt.chars().filter(|c| *c == 'A').count();
|
||||||
|
assert!(
|
||||||
|
a_count < filler.len(),
|
||||||
|
"expected user message to be truncated; all {} 'A's survived",
|
||||||
|
a_count
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
a_count > 0,
|
||||||
|
"expected some of the user message to survive truncation"
|
||||||
|
);
|
||||||
|
|
||||||
|
// Head and tail of the message must both be preserved (that's the
|
||||||
|
// whole point of middle-trim over head-only).
|
||||||
|
assert!(
|
||||||
|
prompt.contains(head),
|
||||||
|
"head marker missing — head was not preserved"
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
prompt.contains(tail),
|
||||||
|
"tail marker missing — tail was not preserved"
|
||||||
|
);
|
||||||
|
|
||||||
|
// Trim marker must be present so the router model can see that
|
||||||
|
// content was omitted.
|
||||||
|
assert!(
|
||||||
|
prompt.contains(TRIM_MARKER),
|
||||||
|
"ellipsis trim marker missing from truncated prompt"
|
||||||
|
);
|
||||||
|
|
||||||
|
// Routing prompt scaffolding remains intact.
|
||||||
|
assert!(prompt.contains("<conversation>"));
|
||||||
|
assert!(prompt.contains("<routes>"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_turn_cap_limits_routing_history() {
|
||||||
|
// The outer turn-cap guardrail should keep only the last
|
||||||
|
// `MAX_ROUTING_TURNS` filtered messages regardless of how long the
|
||||||
|
// conversation is. We build a conversation with alternating
|
||||||
|
// user/assistant turns tagged with their index and verify that only
|
||||||
|
// the tail of the conversation makes it into the prompt.
|
||||||
|
let orchestrations_str = r#"
|
||||||
|
{
|
||||||
|
"gpt-4o": [
|
||||||
|
{"name": "Image generation", "description": "generating image"}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
"#;
|
||||||
|
let agent_orchestrations = serde_json::from_str::<
|
||||||
|
HashMap<String, Vec<OrchestrationPreference>>,
|
||||||
|
>(orchestrations_str)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let orchestrator =
|
||||||
|
OrchestratorModelV1::new(agent_orchestrations, "test-model".to_string(), usize::MAX);
|
||||||
|
|
||||||
|
let mut conversation: Vec<Message> = Vec::new();
|
||||||
|
let total_turns = MAX_ROUTING_TURNS * 2; // well past the cap
|
||||||
|
for i in 0..total_turns {
|
||||||
|
let role = if i % 2 == 0 {
|
||||||
|
Role::User
|
||||||
|
} else {
|
||||||
|
Role::Assistant
|
||||||
|
};
|
||||||
|
conversation.push(Message {
|
||||||
|
role,
|
||||||
|
content: Some(MessageContent::Text(format!("turn-{i:03}"))),
|
||||||
|
name: None,
|
||||||
|
tool_calls: None,
|
||||||
|
tool_call_id: None,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
let req = orchestrator.generate_request(&conversation, &None);
|
||||||
|
let prompt = req.messages[0].content.extract_text();
|
||||||
|
|
||||||
|
// The last MAX_ROUTING_TURNS messages (indexes total-cap..total)
|
||||||
|
// must all appear.
|
||||||
|
for i in (total_turns - MAX_ROUTING_TURNS)..total_turns {
|
||||||
|
let tag = format!("turn-{i:03}");
|
||||||
|
assert!(
|
||||||
|
prompt.contains(&tag),
|
||||||
|
"expected recent turn tag {tag} to be present"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// And earlier turns (indexes 0..total-cap) must all be dropped.
|
||||||
|
for i in 0..(total_turns - MAX_ROUTING_TURNS) {
|
||||||
|
let tag = format!("turn-{i:03}");
|
||||||
|
assert!(
|
||||||
|
!prompt.contains(&tag),
|
||||||
|
"old turn tag {tag} leaked past turn cap into the prompt"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_trim_middle_utf8_helper() {
|
||||||
|
// No-op when already small enough.
|
||||||
|
assert_eq!(trim_middle_utf8("hello", 100), "hello");
|
||||||
|
assert_eq!(trim_middle_utf8("hello", 5), "hello");
|
||||||
|
|
||||||
|
// 60/40 split with ellipsis when too long.
|
||||||
|
let long = "a".repeat(20);
|
||||||
|
let out = trim_middle_utf8(&long, 10);
|
||||||
|
assert!(out.len() <= 10);
|
||||||
|
assert!(out.contains(TRIM_MARKER));
|
||||||
|
// Exactly one ellipsis, rest are 'a's.
|
||||||
|
assert_eq!(out.matches(TRIM_MARKER).count(), 1);
|
||||||
|
assert!(out.chars().filter(|c| *c == 'a').count() > 0);
|
||||||
|
|
||||||
|
// When max_bytes is smaller than the marker, falls back to
|
||||||
|
// head-only truncation (no marker).
|
||||||
|
let out = trim_middle_utf8("abcdefgh", 2);
|
||||||
|
assert_eq!(out, "ab");
|
||||||
|
|
||||||
|
// UTF-8 boundary safety: 2-byte chars.
|
||||||
|
let s = "é".repeat(50); // 100 bytes
|
||||||
|
let out = trim_middle_utf8(&s, 25);
|
||||||
|
assert!(out.len() <= 25);
|
||||||
|
// Must still be valid UTF-8 that only contains 'é' and the marker.
|
||||||
|
let ok = out.chars().all(|c| c == 'é' || c == '…');
|
||||||
|
assert!(ok, "unexpected char in trimmed output: {out:?}");
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_non_text_input() {
|
fn test_non_text_input() {
|
||||||
let expected_prompt = r#"
|
let expected_prompt = r#"
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue