diff --git a/crates/brightstaff/src/router/orchestrator_model_v1.rs b/crates/brightstaff/src/router/orchestrator_model_v1.rs index 240d2f1f..693aacc2 100644 --- a/crates/brightstaff/src/router/orchestrator_model_v1.rs +++ b/crates/brightstaff/src/router/orchestrator_model_v1.rs @@ -10,6 +10,18 @@ use super::orchestrator_model::{OrchestratorModel, OrchestratorModelError}; pub const MAX_TOKEN_LEN: usize = 8192; // Default max token length for the orchestration model +/// Hard cap on the number of recent messages considered when building the +/// routing prompt. Bounds prompt growth for long-running conversations and +/// acts as an outer guardrail before the token-budget loop runs. The most +/// recent `MAX_ROUTING_TURNS` filtered messages are kept; older turns are +/// dropped entirely. +pub const MAX_ROUTING_TURNS: usize = 16; + +/// Unicode ellipsis used to mark where content was trimmed out of a long +/// message. Helps signal to the downstream router model that the message was +/// truncated. +const TRIM_MARKER: &str = "…"; + /// Custom JSON formatter that produces spaced JSON (space after colons and commas), same as JSON in python struct SpacedJsonFormatter; @@ -176,10 +188,9 @@ impl OrchestratorModel for OrchestratorModelV1 { messages: &[Message], usage_preferences_from_request: &Option>, ) -> ChatCompletionsRequest { - // remove system prompt, tool calls, tool call response and messages without content - // if content is empty its likely a tool call - // when role == tool its tool call response - let messages_vec = messages + // Remove system/developer/tool messages and messages without extractable + // text (tool calls have no text content we can classify against). + let filtered: Vec<&Message> = messages .iter() .filter(|m| { m.role != Role::System @@ -187,10 +198,17 @@ impl OrchestratorModel for OrchestratorModelV1 { && m.role != Role::Tool && !m.content.extract_text().is_empty() }) - .collect::>(); + .collect(); - // Following code is to ensure that the conversation does not exceed max token length - // Note: we use a simple heuristic to estimate token count based on character length to optimize for performance + // Outer guardrail: only consider the last `MAX_ROUTING_TURNS` filtered + // messages when building the routing prompt. Keeps prompt growth + // predictable for long conversations regardless of per-message size. + let start = filtered.len().saturating_sub(MAX_ROUTING_TURNS); + let messages_vec: &[&Message] = &filtered[start..]; + + // Ensure the conversation does not exceed the configured token budget. + // We use `len() / TOKEN_LENGTH_DIVISOR` as a cheap token estimate to + // avoid running a real tokenizer on the hot path. let mut token_count = ARCH_ORCHESTRATOR_V1_SYSTEM_PROMPT.len() / TOKEN_LENGTH_DIVISOR; let mut selected_messages_list_reversed: Vec = vec![]; for (selected_messsage_count, message) in messages_vec.iter().rev().enumerate() { @@ -206,16 +224,19 @@ impl OrchestratorModel for OrchestratorModelV1 { total = messages_vec.len(), "token count exceeds max, truncating conversation" ); - // If the message that overflows the budget is from the user we need - // to keep at least some of it so the orchestrator sees the latest - // user intent. Trim from the end of the message toward the - // beginning until we fit in the remaining token budget. + // If the overflow message is from the user we need to keep + // some of it so the orchestrator still sees the latest user + // intent. Use a middle-trim (head + ellipsis + tail): users + // often frame the task at the start AND put the actual ask + // at the end of a long pasted block, so preserving both is + // better than a head-only cut. The ellipsis also signals to + // the router model that content was dropped. if message.role == Role::User && remaining_tokens > 0 { let max_bytes = remaining_tokens.saturating_mul(TOKEN_LENGTH_DIVISOR); - let truncated = truncate_to_utf8_boundary(&message_text, max_bytes); + let truncated = trim_middle_utf8(&message_text, max_bytes); selected_messages_list_reversed.push(Message { role: Role::User, - content: Some(MessageContent::Text(truncated.to_string())), + content: Some(MessageContent::Text(truncated)), name: None, tool_calls: None, tool_call_id: None, @@ -416,18 +437,43 @@ fn fix_json_response(body: &str) -> String { body.replace("'", "\"").replace("\\n", "") } -/// Truncate `s` so that the returned slice is at most `max_bytes` bytes long -/// and ends on a UTF-8 character boundary. Keeps the beginning of the string -/// and drops characters from the end. -fn truncate_to_utf8_boundary(s: &str, max_bytes: usize) -> &str { +/// Truncate `s` so the result is at most `max_bytes` bytes long, keeping +/// roughly 60% from the start and 40% from the end, with a Unicode ellipsis +/// separating the two. All splits respect UTF-8 character boundaries. When +/// `max_bytes` is too small to fit the marker at all, falls back to a +/// head-only truncation. +fn trim_middle_utf8(s: &str, max_bytes: usize) -> String { if s.len() <= max_bytes { - return s; + return s.to_string(); } - let mut end = max_bytes; - while end > 0 && !s.is_char_boundary(end) { - end -= 1; + if max_bytes <= TRIM_MARKER.len() { + // Not enough room even for the marker — just keep the start. + let mut end = max_bytes; + while end > 0 && !s.is_char_boundary(end) { + end -= 1; + } + return s[..end].to_string(); } - &s[..end] + + let available = max_bytes - TRIM_MARKER.len(); + // Bias toward the start (60%) where task framing typically lives, while + // still preserving ~40% of the tail where the user's actual ask often + // appears after a long paste. + let mut start_len = available * 3 / 5; + while start_len > 0 && !s.is_char_boundary(start_len) { + start_len -= 1; + } + let end_len = available - start_len; + let mut end_start = s.len().saturating_sub(end_len); + while end_start < s.len() && !s.is_char_boundary(end_start) { + end_start += 1; + } + + let mut out = String::with_capacity(start_len + TRIM_MARKER.len() + (s.len() - end_start)); + out.push_str(&s[..start_len]); + out.push_str(TRIM_MARKER); + out.push_str(&s[end_start..]); + out } impl std::fmt::Debug for dyn OrchestratorModel { @@ -802,10 +848,10 @@ If no routes are needed, return an empty list for `route`. #[test] fn test_conversation_trim_upto_user_message() { - // With max_token_length=230 the older user message "given the image In - // style of Andy Warhol" overflows the remaining budget and is - // truncated from the end (chars dropped from the end of the string) - // until it fits. The newer assistant/user turns are preserved in full. + // With max_token_length=230, the older user message "given the image + // In style of Andy Warhol" overflows the remaining budget and gets + // middle-trimmed (head + ellipsis + tail) until it fits. Newer turns + // are kept in full. let expected_prompt = r#" You are a helpful assistant that selects the most suitable routes based on user intent. You are provided with a list of available routes enclosed within XML tags: @@ -818,7 +864,7 @@ You are also given the conversation context enclosed within 0, - "expected at least some of the user message to survive truncation" + "expected some of the user message to survive truncation" ); - // Sanity: the prompt still includes the routing prompt scaffolding. + // Head and tail of the message must both be preserved (that's the + // whole point of middle-trim over head-only). + assert!( + prompt.contains(head), + "head marker missing — head was not preserved" + ); + assert!( + prompt.contains(tail), + "tail marker missing — tail was not preserved" + ); + + // Trim marker must be present so the router model can see that + // content was omitted. + assert!( + prompt.contains(TRIM_MARKER), + "ellipsis trim marker missing from truncated prompt" + ); + + // Routing prompt scaffolding remains intact. assert!(prompt.contains("")); assert!(prompt.contains("")); } + #[test] + fn test_turn_cap_limits_routing_history() { + // The outer turn-cap guardrail should keep only the last + // `MAX_ROUTING_TURNS` filtered messages regardless of how long the + // conversation is. We build a conversation with alternating + // user/assistant turns tagged with their index and verify that only + // the tail of the conversation makes it into the prompt. + let orchestrations_str = r#" + { + "gpt-4o": [ + {"name": "Image generation", "description": "generating image"} + ] + } + "#; + let agent_orchestrations = serde_json::from_str::< + HashMap>, + >(orchestrations_str) + .unwrap(); + + let orchestrator = + OrchestratorModelV1::new(agent_orchestrations, "test-model".to_string(), usize::MAX); + + let mut conversation: Vec = Vec::new(); + let total_turns = MAX_ROUTING_TURNS * 2; // well past the cap + for i in 0..total_turns { + let role = if i % 2 == 0 { + Role::User + } else { + Role::Assistant + }; + conversation.push(Message { + role, + content: Some(MessageContent::Text(format!("turn-{i:03}"))), + name: None, + tool_calls: None, + tool_call_id: None, + }); + } + + let req = orchestrator.generate_request(&conversation, &None); + let prompt = req.messages[0].content.extract_text(); + + // The last MAX_ROUTING_TURNS messages (indexes total-cap..total) + // must all appear. + for i in (total_turns - MAX_ROUTING_TURNS)..total_turns { + let tag = format!("turn-{i:03}"); + assert!( + prompt.contains(&tag), + "expected recent turn tag {tag} to be present" + ); + } + + // And earlier turns (indexes 0..total-cap) must all be dropped. + for i in 0..(total_turns - MAX_ROUTING_TURNS) { + let tag = format!("turn-{i:03}"); + assert!( + !prompt.contains(&tag), + "old turn tag {tag} leaked past turn cap into the prompt" + ); + } + } + + #[test] + fn test_trim_middle_utf8_helper() { + // No-op when already small enough. + assert_eq!(trim_middle_utf8("hello", 100), "hello"); + assert_eq!(trim_middle_utf8("hello", 5), "hello"); + + // 60/40 split with ellipsis when too long. + let long = "a".repeat(20); + let out = trim_middle_utf8(&long, 10); + assert!(out.len() <= 10); + assert!(out.contains(TRIM_MARKER)); + // Exactly one ellipsis, rest are 'a's. + assert_eq!(out.matches(TRIM_MARKER).count(), 1); + assert!(out.chars().filter(|c| *c == 'a').count() > 0); + + // When max_bytes is smaller than the marker, falls back to + // head-only truncation (no marker). + let out = trim_middle_utf8("abcdefgh", 2); + assert_eq!(out, "ab"); + + // UTF-8 boundary safety: 2-byte chars. + let s = "é".repeat(50); // 100 bytes + let out = trim_middle_utf8(&s, 25); + assert!(out.len() <= 25); + // Must still be valid UTF-8 that only contains 'é' and the marker. + let ok = out.chars().all(|c| c == 'é' || c == '…'); + assert!(ok, "unexpected char in trimmed output: {out:?}"); + } + #[test] fn test_non_text_input() { let expected_prompt = r#"