mirror of
https://github.com/katanemo/plano.git
synced 2026-05-11 08:42:48 +02:00
Merge origin/main into musa/chatgpt-subscription
This commit is contained in:
commit
6f67048c04
118 changed files with 11627 additions and 2194 deletions
|
|
@ -26,6 +26,8 @@ opentelemetry-stdout = "0.31"
|
|||
opentelemetry_sdk = { version = "0.31", features = ["rt-tokio"] }
|
||||
pretty_assertions = "1.4.1"
|
||||
rand = "0.9.2"
|
||||
lru = "0.12"
|
||||
redis = { version = "0.27", features = ["tokio-comp"] }
|
||||
reqwest = { version = "0.12.15", features = ["stream"] }
|
||||
serde = { version = "1.0.219", features = ["derive"] }
|
||||
serde_json = "1.0.140"
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@ use common::configuration::{Agent, FilterPipeline, Listener, ModelAlias, SpanAtt
|
|||
use common::llm_providers::LlmProviders;
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
use crate::router::llm::RouterService;
|
||||
use crate::router::orchestrator::OrchestratorService;
|
||||
use crate::state::StateStorage;
|
||||
|
||||
|
|
@ -14,7 +13,6 @@ use crate::state::StateStorage;
|
|||
/// Instead of cloning 8+ individual `Arc`s per connection, a single
|
||||
/// `Arc<AppState>` is cloned once and passed to the request handler.
|
||||
pub struct AppState {
|
||||
pub router_service: Arc<RouterService>,
|
||||
pub orchestrator_service: Arc<OrchestratorService>,
|
||||
pub model_aliases: Option<HashMap<String, ModelAlias>>,
|
||||
pub llm_providers: Arc<RwLock<LlmProviders>>,
|
||||
|
|
|
|||
|
|
@ -177,6 +177,7 @@ mod tests {
|
|||
"http://localhost:8080".to_string(),
|
||||
"test-model".to_string(),
|
||||
"plano-orchestrator".to_string(),
|
||||
crate::router::orchestrator_model_v1::MAX_TOKEN_LEN,
|
||||
))
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ mod tests {
|
|||
"http://localhost:8080".to_string(),
|
||||
"test-model".to_string(),
|
||||
"plano-orchestrator".to_string(),
|
||||
crate::router::orchestrator_model_v1::MAX_TOKEN_LEN,
|
||||
))
|
||||
}
|
||||
|
||||
|
|
@ -147,8 +148,8 @@ mod tests {
|
|||
|
||||
#[tokio::test]
|
||||
async fn test_error_handling_flow() {
|
||||
let router_service = create_test_orchestrator_service();
|
||||
let agent_selector = AgentSelector::new(router_service);
|
||||
let orchestrator_service = create_test_orchestrator_service();
|
||||
let agent_selector = AgentSelector::new(orchestrator_service);
|
||||
|
||||
// Test listener not found
|
||||
let result = agent_selector.find_listener(Some("nonexistent"), &[]);
|
||||
|
|
|
|||
|
|
@ -22,7 +22,6 @@ pub(crate) mod model_selection;
|
|||
|
||||
use crate::app_state::AppState;
|
||||
use crate::handlers::agents::pipeline::PipelineProcessor;
|
||||
use crate::handlers::extract_or_generate_traceparent;
|
||||
use crate::handlers::extract_request_id;
|
||||
use crate::handlers::full;
|
||||
use crate::state::response_state_processor::ResponsesStateProcessor;
|
||||
|
|
@ -34,7 +33,8 @@ use crate::streaming::{
|
|||
ObservableStreamProcessor, StreamProcessor,
|
||||
};
|
||||
use crate::tracing::{
|
||||
collect_custom_trace_attributes, llm as tracing_llm, operation_component, set_service_name,
|
||||
collect_custom_trace_attributes, llm as tracing_llm, operation_component,
|
||||
plano as tracing_plano, set_service_name,
|
||||
};
|
||||
use model_selection::router_chat_get_upstream_model;
|
||||
|
||||
|
|
@ -92,22 +92,47 @@ async fn llm_chat_inner(
|
|||
}
|
||||
});
|
||||
|
||||
let traceparent = extract_or_generate_traceparent(&request_headers);
|
||||
|
||||
// Session pinning: extract session ID and check cache before routing
|
||||
let session_id: Option<String> = request_headers
|
||||
.get(MODEL_AFFINITY_HEADER)
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.map(|s| s.to_string());
|
||||
let pinned_model: Option<String> = if let Some(ref sid) = session_id {
|
||||
let tenant_id: Option<String> = state
|
||||
.orchestrator_service
|
||||
.tenant_header()
|
||||
.and_then(|hdr| request_headers.get(hdr))
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(|s| s.to_string());
|
||||
let cached_route = if let Some(ref sid) = session_id {
|
||||
state
|
||||
.router_service
|
||||
.get_cached_route(sid)
|
||||
.orchestrator_service
|
||||
.get_cached_route(sid, tenant_id.as_deref())
|
||||
.await
|
||||
.map(|c| c.model_name)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let (pinned_model, pinned_route_name): (Option<String>, Option<String>) = match cached_route {
|
||||
Some(c) => (Some(c.model_name), c.route_name),
|
||||
None => (None, None),
|
||||
};
|
||||
|
||||
// Record session id on the LLM span for the observability console.
|
||||
if let Some(ref sid) = session_id {
|
||||
get_active_span(|span| {
|
||||
span.set_attribute(opentelemetry::KeyValue::new(
|
||||
tracing_plano::SESSION_ID,
|
||||
sid.clone(),
|
||||
));
|
||||
});
|
||||
}
|
||||
if let Some(ref route_name) = pinned_route_name {
|
||||
get_active_span(|span| {
|
||||
span.set_attribute(opentelemetry::KeyValue::new(
|
||||
tracing_plano::ROUTE_NAME,
|
||||
route_name.clone(),
|
||||
));
|
||||
});
|
||||
}
|
||||
|
||||
let full_qualified_llm_provider_url = format!("{}{}", state.llm_provider_url, request_path);
|
||||
|
||||
|
|
@ -289,9 +314,8 @@ async fn llm_chat_inner(
|
|||
let routing_result = match async {
|
||||
set_service_name(operation_component::ROUTING);
|
||||
router_chat_get_upstream_model(
|
||||
Arc::clone(&state.router_service),
|
||||
Arc::clone(&state.orchestrator_service),
|
||||
client_request,
|
||||
&traceparent,
|
||||
&request_path,
|
||||
&request_id,
|
||||
inline_routing_preferences,
|
||||
|
|
@ -317,11 +341,22 @@ async fn llm_chat_inner(
|
|||
alias_resolved_model.clone()
|
||||
};
|
||||
|
||||
// Cache the routing decision so subsequent requests with the same session ID are pinned
|
||||
// Record route name on the LLM span (only when the orchestrator produced one).
|
||||
if let Some(ref rn) = route_name {
|
||||
if !rn.is_empty() && rn != "none" {
|
||||
get_active_span(|span| {
|
||||
span.set_attribute(opentelemetry::KeyValue::new(
|
||||
tracing_plano::ROUTE_NAME,
|
||||
rn.clone(),
|
||||
));
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(ref sid) = session_id {
|
||||
state
|
||||
.router_service
|
||||
.cache_route(sid.clone(), model.clone(), route_name)
|
||||
.orchestrator_service
|
||||
.cache_route(sid.clone(), tenant_id.as_deref(), model.clone(), route_name)
|
||||
.await;
|
||||
}
|
||||
|
||||
|
|
@ -678,6 +713,36 @@ async fn send_upstream(
|
|||
// Propagate upstream headers and status
|
||||
let response_headers = llm_response.headers().clone();
|
||||
let upstream_status = llm_response.status();
|
||||
|
||||
// Upstream routers (e.g. DigitalOcean Gradient) may return an
|
||||
// `x-model-router-selected-route` header indicating which task-level
|
||||
// route the request was classified into (e.g. "Code Generation"). Surface
|
||||
// it as `plano.route.name` so the obs console's Route hit % panel can
|
||||
// show the breakdown even when Plano's own orchestrator wasn't in the
|
||||
// routing path. Any value from Plano's orchestrator already set earlier
|
||||
// takes precedence — this only fires when the span doesn't already have
|
||||
// a route name.
|
||||
if let Some(upstream_route) = response_headers
|
||||
.get("x-model-router-selected-route")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
{
|
||||
if !upstream_route.is_empty() {
|
||||
get_active_span(|span| {
|
||||
span.set_attribute(opentelemetry::KeyValue::new(
|
||||
crate::tracing::plano::ROUTE_NAME,
|
||||
upstream_route.to_string(),
|
||||
));
|
||||
});
|
||||
}
|
||||
}
|
||||
// Record the upstream HTTP status on the span for the obs console.
|
||||
get_active_span(|span| {
|
||||
span.set_attribute(opentelemetry::KeyValue::new(
|
||||
crate::tracing::http::STATUS_CODE,
|
||||
upstream_status.as_u16() as i64,
|
||||
));
|
||||
});
|
||||
|
||||
let mut response = Response::builder().status(upstream_status);
|
||||
if let Some(headers) = response.headers_mut() {
|
||||
for (name, value) in response_headers.iter() {
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ use hyper::StatusCode;
|
|||
use std::sync::Arc;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use crate::router::llm::RouterService;
|
||||
use crate::router::orchestrator::OrchestratorService;
|
||||
use crate::streaming::truncate_message;
|
||||
use crate::tracing::routing;
|
||||
|
||||
|
|
@ -37,9 +37,8 @@ impl RoutingError {
|
|||
/// * `Ok(RoutingResult)` - Contains the selected model name and span ID
|
||||
/// * `Err(RoutingError)` - Contains error details and optional span ID
|
||||
pub async fn router_chat_get_upstream_model(
|
||||
router_service: Arc<RouterService>,
|
||||
orchestrator_service: Arc<OrchestratorService>,
|
||||
client_request: ProviderRequestType,
|
||||
traceparent: &str,
|
||||
request_path: &str,
|
||||
request_id: &str,
|
||||
inline_routing_preferences: Option<Vec<TopLevelRoutingPreference>>,
|
||||
|
|
@ -99,11 +98,9 @@ pub async fn router_chat_get_upstream_model(
|
|||
// Capture start time for routing span
|
||||
let routing_start_time = std::time::Instant::now();
|
||||
|
||||
// Attempt to determine route using the router service
|
||||
let routing_result = router_service
|
||||
let routing_result = orchestrator_service
|
||||
.determine_route(
|
||||
&chat_request.messages,
|
||||
traceparent,
|
||||
inline_routing_preferences,
|
||||
request_id,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ use tracing::{debug, info, info_span, warn, Instrument};
|
|||
|
||||
use super::extract_or_generate_traceparent;
|
||||
use crate::handlers::llm::model_selection::router_chat_get_upstream_model;
|
||||
use crate::router::llm::RouterService;
|
||||
use crate::router::orchestrator::OrchestratorService;
|
||||
use crate::tracing::{collect_custom_trace_attributes, operation_component, set_service_name};
|
||||
|
||||
/// Extracts `routing_preferences` from a JSON body, returning the cleaned body bytes
|
||||
|
|
@ -60,7 +60,7 @@ struct RoutingDecisionResponse {
|
|||
|
||||
pub async fn routing_decision(
|
||||
request: Request<hyper::body::Incoming>,
|
||||
router_service: Arc<RouterService>,
|
||||
orchestrator_service: Arc<OrchestratorService>,
|
||||
request_path: String,
|
||||
span_attributes: &Option<SpanAttributes>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
|
|
@ -76,6 +76,12 @@ pub async fn routing_decision(
|
|||
.and_then(|h| h.to_str().ok())
|
||||
.map(|s| s.to_string());
|
||||
|
||||
let tenant_id: Option<String> = orchestrator_service
|
||||
.tenant_header()
|
||||
.and_then(|hdr| request_headers.get(hdr))
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(|s| s.to_string());
|
||||
|
||||
let custom_attrs = collect_custom_trace_attributes(&request_headers, span_attributes.as_ref());
|
||||
|
||||
let request_span = info_span!(
|
||||
|
|
@ -88,25 +94,28 @@ pub async fn routing_decision(
|
|||
|
||||
routing_decision_inner(
|
||||
request,
|
||||
router_service,
|
||||
orchestrator_service,
|
||||
request_id,
|
||||
request_path,
|
||||
request_headers,
|
||||
custom_attrs,
|
||||
session_id,
|
||||
tenant_id,
|
||||
)
|
||||
.instrument(request_span)
|
||||
.await
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn routing_decision_inner(
|
||||
request: Request<hyper::body::Incoming>,
|
||||
router_service: Arc<RouterService>,
|
||||
orchestrator_service: Arc<OrchestratorService>,
|
||||
request_id: String,
|
||||
request_path: String,
|
||||
request_headers: hyper::HeaderMap,
|
||||
custom_attrs: std::collections::HashMap<String, String>,
|
||||
session_id: Option<String>,
|
||||
tenant_id: Option<String>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
set_service_name(operation_component::ROUTING);
|
||||
opentelemetry::trace::get_active_span(|span| {
|
||||
|
|
@ -124,9 +133,11 @@ async fn routing_decision_inner(
|
|||
.unwrap_or("unknown")
|
||||
.to_string();
|
||||
|
||||
// Session pinning: check cache before doing any routing work
|
||||
if let Some(ref sid) = session_id {
|
||||
if let Some(cached) = router_service.get_cached_route(sid).await {
|
||||
if let Some(cached) = orchestrator_service
|
||||
.get_cached_route(sid, tenant_id.as_deref())
|
||||
.await
|
||||
{
|
||||
info!(
|
||||
session_id = %sid,
|
||||
model = %cached.model_name,
|
||||
|
|
@ -190,9 +201,8 @@ async fn routing_decision_inner(
|
|||
};
|
||||
|
||||
let routing_result = router_chat_get_upstream_model(
|
||||
Arc::clone(&router_service),
|
||||
Arc::clone(&orchestrator_service),
|
||||
client_request,
|
||||
&traceparent,
|
||||
&request_path,
|
||||
&request_id,
|
||||
inline_routing_preferences,
|
||||
|
|
@ -201,11 +211,11 @@ async fn routing_decision_inner(
|
|||
|
||||
match routing_result {
|
||||
Ok(result) => {
|
||||
// Cache the result if session_id is present
|
||||
if let Some(ref sid) = session_id {
|
||||
router_service
|
||||
orchestrator_service
|
||||
.cache_route(
|
||||
sid.clone(),
|
||||
tenant_id.as_deref(),
|
||||
result.model_name.clone(),
|
||||
result.route_name.clone(),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
pub mod app_state;
|
||||
pub mod handlers;
|
||||
pub mod router;
|
||||
pub mod session_cache;
|
||||
pub mod signals;
|
||||
pub mod state;
|
||||
pub mod streaming;
|
||||
|
|
|
|||
|
|
@ -5,9 +5,9 @@ use brightstaff::handlers::function_calling::function_calling_chat_handler;
|
|||
use brightstaff::handlers::llm::llm_chat;
|
||||
use brightstaff::handlers::models::list_models;
|
||||
use brightstaff::handlers::routing_service::routing_decision;
|
||||
use brightstaff::router::llm::RouterService;
|
||||
use brightstaff::router::model_metrics::ModelMetricsService;
|
||||
use brightstaff::router::orchestrator::OrchestratorService;
|
||||
use brightstaff::session_cache::init_session_cache;
|
||||
use brightstaff::state::memory::MemoryConversationalStorage;
|
||||
use brightstaff::state::postgresql::PostgreSQLConversationStorage;
|
||||
use brightstaff::state::StateStorage;
|
||||
|
|
@ -36,8 +36,6 @@ use tokio::sync::RwLock;
|
|||
use tracing::{debug, info, warn};
|
||||
|
||||
const BIND_ADDRESS: &str = "0.0.0.0:9091";
|
||||
const DEFAULT_ROUTING_LLM_PROVIDER: &str = "arch-router";
|
||||
const DEFAULT_ROUTING_MODEL_NAME: &str = "Arch-Router";
|
||||
const DEFAULT_ORCHESTRATOR_LLM_PROVIDER: &str = "plano-orchestrator";
|
||||
const DEFAULT_ORCHESTRATOR_MODEL_NAME: &str = "Plano-Orchestrator";
|
||||
|
||||
|
|
@ -160,22 +158,8 @@ async fn init_app_state(
|
|||
|
||||
let overrides = config.overrides.clone().unwrap_or_default();
|
||||
|
||||
let routing_model_name: String = overrides
|
||||
.llm_routing_model
|
||||
.as_deref()
|
||||
.map(|m| m.split_once('/').map(|(_, id)| id).unwrap_or(m))
|
||||
.unwrap_or(DEFAULT_ROUTING_MODEL_NAME)
|
||||
.to_string();
|
||||
|
||||
let routing_llm_provider = config
|
||||
.model_providers
|
||||
.iter()
|
||||
.find(|p| p.model.as_deref() == Some(routing_model_name.as_str()))
|
||||
.map(|p| p.name.clone())
|
||||
.unwrap_or_else(|| DEFAULT_ROUTING_LLM_PROVIDER.to_string());
|
||||
|
||||
let session_ttl_seconds = config.routing.as_ref().and_then(|r| r.session_ttl_seconds);
|
||||
let session_max_entries = config.routing.as_ref().and_then(|r| r.session_max_entries);
|
||||
let session_cache = init_session_cache(config).await?;
|
||||
|
||||
// Validate that top-level routing_preferences requires v0.4.0+.
|
||||
let config_version = parse_semver(&config.version);
|
||||
|
|
@ -297,31 +281,17 @@ async fn init_app_state(
|
|||
}
|
||||
}
|
||||
|
||||
let router_service = Arc::new(RouterService::new(
|
||||
config.routing_preferences.clone(),
|
||||
metrics_service,
|
||||
format!("{llm_provider_url}{CHAT_COMPLETIONS_PATH}"),
|
||||
routing_model_name,
|
||||
routing_llm_provider,
|
||||
session_ttl_seconds,
|
||||
session_max_entries,
|
||||
));
|
||||
|
||||
// Spawn background task to clean up expired session cache entries every 5 minutes
|
||||
{
|
||||
let router_service = Arc::clone(&router_service);
|
||||
tokio::spawn(async move {
|
||||
let mut interval = tokio::time::interval(std::time::Duration::from_secs(300));
|
||||
loop {
|
||||
interval.tick().await;
|
||||
router_service.cleanup_expired_sessions().await;
|
||||
}
|
||||
});
|
||||
}
|
||||
let session_tenant_header = config
|
||||
.routing
|
||||
.as_ref()
|
||||
.and_then(|r| r.session_cache.as_ref())
|
||||
.and_then(|c| c.tenant_header.clone());
|
||||
|
||||
// Resolve model name: prefer llm_routing_model override, then agent_orchestration_model, then default.
|
||||
let orchestrator_model_name: String = overrides
|
||||
.agent_orchestration_model
|
||||
.llm_routing_model
|
||||
.as_deref()
|
||||
.or(overrides.agent_orchestration_model.as_deref())
|
||||
.map(|m| m.split_once('/').map(|(_, id)| id).unwrap_or(m))
|
||||
.unwrap_or(DEFAULT_ORCHESTRATOR_MODEL_NAME)
|
||||
.to_string();
|
||||
|
|
@ -333,10 +303,20 @@ async fn init_app_state(
|
|||
.map(|p| p.name.clone())
|
||||
.unwrap_or_else(|| DEFAULT_ORCHESTRATOR_LLM_PROVIDER.to_string());
|
||||
|
||||
let orchestrator_service = Arc::new(OrchestratorService::new(
|
||||
let orchestrator_max_tokens = overrides
|
||||
.orchestrator_model_context_length
|
||||
.unwrap_or(brightstaff::router::orchestrator_model_v1::MAX_TOKEN_LEN);
|
||||
|
||||
let orchestrator_service = Arc::new(OrchestratorService::with_routing(
|
||||
format!("{llm_provider_url}{CHAT_COMPLETIONS_PATH}"),
|
||||
orchestrator_model_name,
|
||||
orchestrator_llm_provider,
|
||||
config.routing_preferences.clone(),
|
||||
metrics_service,
|
||||
session_ttl_seconds,
|
||||
session_cache,
|
||||
session_tenant_header,
|
||||
orchestrator_max_tokens,
|
||||
));
|
||||
|
||||
let state_storage = init_state_storage(config).await?;
|
||||
|
|
@ -347,7 +327,6 @@ async fn init_app_state(
|
|||
.and_then(|tracing| tracing.span_attributes.clone());
|
||||
|
||||
Ok(AppState {
|
||||
router_service,
|
||||
orchestrator_service,
|
||||
model_aliases: config.model_aliases.clone(),
|
||||
llm_providers: Arc::new(RwLock::new(llm_providers)),
|
||||
|
|
@ -434,7 +413,7 @@ async fn route(
|
|||
) {
|
||||
return routing_decision(
|
||||
req,
|
||||
Arc::clone(&state.router_service),
|
||||
Arc::clone(&state.orchestrator_service),
|
||||
stripped,
|
||||
&state.span_attributes,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,8 +1,14 @@
|
|||
use hermesllm::apis::openai::ChatCompletionsResponse;
|
||||
use hyper::header;
|
||||
use serde::Deserialize;
|
||||
use thiserror::Error;
|
||||
use tracing::warn;
|
||||
|
||||
/// Max bytes of raw upstream body we include in a log message or error text
|
||||
/// when the body is not a recognizable error envelope. Keeps logs from being
|
||||
/// flooded by huge HTML error pages.
|
||||
const RAW_BODY_LOG_LIMIT: usize = 512;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum HttpError {
|
||||
#[error("Failed to send request: {0}")]
|
||||
|
|
@ -10,13 +16,64 @@ pub enum HttpError {
|
|||
|
||||
#[error("Failed to parse JSON response: {0}")]
|
||||
Json(serde_json::Error, String),
|
||||
|
||||
#[error("Upstream returned {status}: {message}")]
|
||||
Upstream { status: u16, message: String },
|
||||
}
|
||||
|
||||
/// Shape of an OpenAI-style error response body, e.g.
|
||||
/// `{"error": {"message": "...", "type": "...", "param": "...", "code": ...}}`.
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct UpstreamErrorEnvelope {
|
||||
error: UpstreamErrorBody,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct UpstreamErrorBody {
|
||||
message: String,
|
||||
#[serde(default, rename = "type")]
|
||||
err_type: Option<String>,
|
||||
#[serde(default)]
|
||||
param: Option<String>,
|
||||
}
|
||||
|
||||
/// Extract a human-readable error message from an upstream response body.
|
||||
/// Tries to parse an OpenAI-style `{"error": {"message": ...}}` envelope; if
|
||||
/// that fails, falls back to the first `RAW_BODY_LOG_LIMIT` bytes of the raw
|
||||
/// body (UTF-8 safe).
|
||||
fn extract_upstream_error_message(body: &str) -> String {
|
||||
if let Ok(env) = serde_json::from_str::<UpstreamErrorEnvelope>(body) {
|
||||
let mut msg = env.error.message;
|
||||
if let Some(param) = env.error.param {
|
||||
msg.push_str(&format!(" (param={param})"));
|
||||
}
|
||||
if let Some(err_type) = env.error.err_type {
|
||||
msg.push_str(&format!(" [type={err_type}]"));
|
||||
}
|
||||
return msg;
|
||||
}
|
||||
truncate_for_log(body).to_string()
|
||||
}
|
||||
|
||||
fn truncate_for_log(s: &str) -> &str {
|
||||
if s.len() <= RAW_BODY_LOG_LIMIT {
|
||||
return s;
|
||||
}
|
||||
let mut end = RAW_BODY_LOG_LIMIT;
|
||||
while end > 0 && !s.is_char_boundary(end) {
|
||||
end -= 1;
|
||||
}
|
||||
&s[..end]
|
||||
}
|
||||
|
||||
/// Sends a POST request to the given URL and extracts the text content
|
||||
/// from the first choice of the `ChatCompletionsResponse`.
|
||||
///
|
||||
/// Returns `Some((content, elapsed))` on success, or `None` if the response
|
||||
/// had no choices or the first choice had no content.
|
||||
/// Returns `Some((content, elapsed))` on success, `None` if the response
|
||||
/// had no choices or the first choice had no content. Returns
|
||||
/// `HttpError::Upstream` for any non-2xx status, carrying a message
|
||||
/// extracted from the OpenAI-style error envelope (or a truncated raw body
|
||||
/// if the body is not in that shape).
|
||||
pub async fn post_and_extract_content(
|
||||
client: &reqwest::Client,
|
||||
url: &str,
|
||||
|
|
@ -26,17 +83,36 @@ pub async fn post_and_extract_content(
|
|||
let start_time = std::time::Instant::now();
|
||||
|
||||
let res = client.post(url).headers(headers).body(body).send().await?;
|
||||
let status = res.status();
|
||||
|
||||
let body = res.text().await?;
|
||||
let elapsed = start_time.elapsed();
|
||||
|
||||
if !status.is_success() {
|
||||
let message = extract_upstream_error_message(&body);
|
||||
warn!(
|
||||
status = status.as_u16(),
|
||||
message = %message,
|
||||
body_size = body.len(),
|
||||
"upstream returned error response"
|
||||
);
|
||||
return Err(HttpError::Upstream {
|
||||
status: status.as_u16(),
|
||||
message,
|
||||
});
|
||||
}
|
||||
|
||||
let response: ChatCompletionsResponse = serde_json::from_str(&body).map_err(|err| {
|
||||
warn!(error = %err, body = %body, "failed to parse json response");
|
||||
warn!(
|
||||
error = %err,
|
||||
body = %truncate_for_log(&body),
|
||||
"failed to parse json response",
|
||||
);
|
||||
HttpError::Json(err, format!("Failed to parse JSON: {}", body))
|
||||
})?;
|
||||
|
||||
if response.choices.is_empty() {
|
||||
warn!(body = %body, "no choices in response");
|
||||
warn!(body = %truncate_for_log(&body), "no choices in response");
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
|
|
@ -46,3 +122,52 @@ pub async fn post_and_extract_content(
|
|||
.as_ref()
|
||||
.map(|c| (c.clone(), elapsed)))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn extracts_message_from_openai_style_error_envelope() {
|
||||
let body = r#"{"error":{"code":400,"message":"This model's maximum context length is 32768 tokens. However, you requested 0 output tokens and your prompt contains at least 32769 input tokens, for a total of at least 32769 tokens.","param":"input_tokens","type":"BadRequestError"}}"#;
|
||||
let msg = extract_upstream_error_message(body);
|
||||
assert!(
|
||||
msg.starts_with("This model's maximum context length is 32768 tokens."),
|
||||
"unexpected message: {msg}"
|
||||
);
|
||||
assert!(msg.contains("(param=input_tokens)"));
|
||||
assert!(msg.contains("[type=BadRequestError]"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extracts_message_without_optional_fields() {
|
||||
let body = r#"{"error":{"message":"something broke"}}"#;
|
||||
let msg = extract_upstream_error_message(body);
|
||||
assert_eq!(msg, "something broke");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn falls_back_to_raw_body_when_not_error_envelope() {
|
||||
let body = "<html><body>502 Bad Gateway</body></html>";
|
||||
let msg = extract_upstream_error_message(body);
|
||||
assert_eq!(msg, body);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn truncates_non_envelope_bodies_in_logs() {
|
||||
let body = "x".repeat(RAW_BODY_LOG_LIMIT * 3);
|
||||
let msg = extract_upstream_error_message(&body);
|
||||
assert_eq!(msg.len(), RAW_BODY_LOG_LIMIT);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn truncate_for_log_respects_utf8_boundaries() {
|
||||
// 2-byte characters; picking a length that would split mid-char.
|
||||
let body = "é".repeat(RAW_BODY_LOG_LIMIT);
|
||||
let out = truncate_for_log(&body);
|
||||
// Should be a valid &str (implicit — would panic if we returned
|
||||
// a non-boundary slice) and at most RAW_BODY_LOG_LIMIT bytes.
|
||||
assert!(out.len() <= RAW_BODY_LOG_LIMIT);
|
||||
assert!(out.chars().all(|c| c == 'é'));
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,380 +0,0 @@
|
|||
use std::{collections::HashMap, sync::Arc, time::Duration, time::Instant};
|
||||
|
||||
use common::{
|
||||
configuration::TopLevelRoutingPreference,
|
||||
consts::{ARCH_PROVIDER_HINT_HEADER, REQUEST_ID_HEADER, TRACE_PARENT_HEADER},
|
||||
};
|
||||
|
||||
use super::router_model::{ModelUsagePreference, RoutingPreference};
|
||||
use hermesllm::apis::openai::Message;
|
||||
use hyper::header;
|
||||
use thiserror::Error;
|
||||
use tokio::sync::RwLock;
|
||||
use tracing::{debug, info};
|
||||
|
||||
use super::http::{self, post_and_extract_content};
|
||||
use super::model_metrics::ModelMetricsService;
|
||||
use super::router_model::RouterModel;
|
||||
|
||||
use crate::router::router_model_v1;
|
||||
|
||||
const DEFAULT_SESSION_TTL_SECONDS: u64 = 600;
|
||||
const DEFAULT_SESSION_MAX_ENTRIES: usize = 10_000;
|
||||
const MAX_SESSION_MAX_ENTRIES: usize = 10_000;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct CachedRoute {
|
||||
pub model_name: String,
|
||||
pub route_name: Option<String>,
|
||||
pub cached_at: Instant,
|
||||
}
|
||||
|
||||
pub struct RouterService {
|
||||
router_url: String,
|
||||
client: reqwest::Client,
|
||||
router_model: Arc<dyn RouterModel>,
|
||||
routing_provider_name: String,
|
||||
top_level_preferences: HashMap<String, TopLevelRoutingPreference>,
|
||||
metrics_service: Option<Arc<ModelMetricsService>>,
|
||||
session_cache: RwLock<HashMap<String, CachedRoute>>,
|
||||
session_ttl: Duration,
|
||||
session_max_entries: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum RoutingError {
|
||||
#[error(transparent)]
|
||||
Http(#[from] http::HttpError),
|
||||
|
||||
#[error("Router model error: {0}")]
|
||||
RouterModelError(#[from] super::router_model::RoutingModelError),
|
||||
}
|
||||
|
||||
pub type Result<T> = std::result::Result<T, RoutingError>;
|
||||
|
||||
impl RouterService {
|
||||
pub fn new(
|
||||
top_level_prefs: Option<Vec<TopLevelRoutingPreference>>,
|
||||
metrics_service: Option<Arc<ModelMetricsService>>,
|
||||
router_url: String,
|
||||
routing_model_name: String,
|
||||
routing_provider_name: String,
|
||||
session_ttl_seconds: Option<u64>,
|
||||
session_max_entries: Option<usize>,
|
||||
) -> Self {
|
||||
let top_level_preferences: HashMap<String, TopLevelRoutingPreference> = top_level_prefs
|
||||
.map_or_else(HashMap::new, |prefs| {
|
||||
prefs.into_iter().map(|p| (p.name.clone(), p)).collect()
|
||||
});
|
||||
|
||||
// Build sentinel routes for RouterModelV1: route_name → first model.
|
||||
// RouterModelV1 uses this to build its prompt; RouterService overrides
|
||||
// the model selection via rank_models() after the route is determined.
|
||||
let sentinel_routes: HashMap<String, Vec<RoutingPreference>> = top_level_preferences
|
||||
.iter()
|
||||
.filter_map(|(name, pref)| {
|
||||
pref.models.first().map(|first_model| {
|
||||
(
|
||||
first_model.clone(),
|
||||
vec![RoutingPreference {
|
||||
name: name.clone(),
|
||||
description: pref.description.clone(),
|
||||
}],
|
||||
)
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
let router_model = Arc::new(router_model_v1::RouterModelV1::new(
|
||||
sentinel_routes,
|
||||
routing_model_name,
|
||||
router_model_v1::MAX_TOKEN_LEN,
|
||||
));
|
||||
|
||||
let session_ttl =
|
||||
Duration::from_secs(session_ttl_seconds.unwrap_or(DEFAULT_SESSION_TTL_SECONDS));
|
||||
let session_max_entries = session_max_entries
|
||||
.unwrap_or(DEFAULT_SESSION_MAX_ENTRIES)
|
||||
.min(MAX_SESSION_MAX_ENTRIES);
|
||||
|
||||
RouterService {
|
||||
router_url,
|
||||
client: reqwest::Client::new(),
|
||||
router_model,
|
||||
routing_provider_name,
|
||||
top_level_preferences,
|
||||
metrics_service,
|
||||
session_cache: RwLock::new(HashMap::new()),
|
||||
session_ttl,
|
||||
session_max_entries,
|
||||
}
|
||||
}
|
||||
|
||||
/// Look up a cached routing decision by session ID.
|
||||
/// Returns None if not found or expired.
|
||||
pub async fn get_cached_route(&self, session_id: &str) -> Option<CachedRoute> {
|
||||
let cache = self.session_cache.read().await;
|
||||
if let Some(entry) = cache.get(session_id) {
|
||||
if entry.cached_at.elapsed() < self.session_ttl {
|
||||
return Some(entry.clone());
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
/// Store a routing decision in the session cache.
|
||||
/// If at max capacity, evicts the oldest entry.
|
||||
pub async fn cache_route(
|
||||
&self,
|
||||
session_id: String,
|
||||
model_name: String,
|
||||
route_name: Option<String>,
|
||||
) {
|
||||
let mut cache = self.session_cache.write().await;
|
||||
if cache.len() >= self.session_max_entries && !cache.contains_key(&session_id) {
|
||||
if let Some(oldest_key) = cache
|
||||
.iter()
|
||||
.min_by_key(|(_, v)| v.cached_at)
|
||||
.map(|(k, _)| k.clone())
|
||||
{
|
||||
cache.remove(&oldest_key);
|
||||
}
|
||||
}
|
||||
cache.insert(
|
||||
session_id,
|
||||
CachedRoute {
|
||||
model_name,
|
||||
route_name,
|
||||
cached_at: Instant::now(),
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
/// Remove all expired entries from the session cache.
|
||||
pub async fn cleanup_expired_sessions(&self) {
|
||||
let mut cache = self.session_cache.write().await;
|
||||
let before = cache.len();
|
||||
cache.retain(|_, entry| entry.cached_at.elapsed() < self.session_ttl);
|
||||
let removed = before - cache.len();
|
||||
if removed > 0 {
|
||||
info!(
|
||||
removed = removed,
|
||||
remaining = cache.len(),
|
||||
"cleaned up expired session cache entries"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn determine_route(
|
||||
&self,
|
||||
messages: &[Message],
|
||||
traceparent: &str,
|
||||
inline_routing_preferences: Option<Vec<TopLevelRoutingPreference>>,
|
||||
request_id: &str,
|
||||
) -> Result<Option<(String, Vec<String>)>> {
|
||||
if messages.is_empty() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
// Build inline top-level map from request if present (inline overrides config).
|
||||
let inline_top_map: Option<HashMap<String, TopLevelRoutingPreference>> =
|
||||
inline_routing_preferences
|
||||
.map(|prefs| prefs.into_iter().map(|p| (p.name.clone(), p)).collect());
|
||||
|
||||
// No routing defined — skip the router call entirely.
|
||||
if inline_top_map.is_none() && self.top_level_preferences.is_empty() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
// For inline overrides, build synthetic ModelUsagePreference list so RouterModelV1
|
||||
// generates the correct prompt (route name + description pairs).
|
||||
// For config-level prefs the sentinel routes are already baked into RouterModelV1.
|
||||
let effective_usage_preferences: Option<Vec<ModelUsagePreference>> =
|
||||
inline_top_map.as_ref().map(|inline_map| {
|
||||
inline_map
|
||||
.values()
|
||||
.map(|p| ModelUsagePreference {
|
||||
model: p.models.first().cloned().unwrap_or_default(),
|
||||
routing_preferences: vec![RoutingPreference {
|
||||
name: p.name.clone(),
|
||||
description: p.description.clone(),
|
||||
}],
|
||||
})
|
||||
.collect()
|
||||
});
|
||||
|
||||
let router_request = self
|
||||
.router_model
|
||||
.generate_request(messages, &effective_usage_preferences);
|
||||
|
||||
debug!(
|
||||
model = %self.router_model.get_model_name(),
|
||||
endpoint = %self.router_url,
|
||||
"sending request to arch-router"
|
||||
);
|
||||
|
||||
let body = serde_json::to_string(&router_request)
|
||||
.map_err(super::router_model::RoutingModelError::from)?;
|
||||
debug!(body = %body, "arch router request");
|
||||
|
||||
let mut headers = header::HeaderMap::new();
|
||||
headers.insert(
|
||||
header::CONTENT_TYPE,
|
||||
header::HeaderValue::from_static("application/json"),
|
||||
);
|
||||
if let Ok(val) = header::HeaderValue::from_str(&self.routing_provider_name) {
|
||||
headers.insert(
|
||||
header::HeaderName::from_static(ARCH_PROVIDER_HINT_HEADER),
|
||||
val,
|
||||
);
|
||||
}
|
||||
if let Ok(val) = header::HeaderValue::from_str(traceparent) {
|
||||
headers.insert(header::HeaderName::from_static(TRACE_PARENT_HEADER), val);
|
||||
}
|
||||
if let Ok(val) = header::HeaderValue::from_str(request_id) {
|
||||
headers.insert(header::HeaderName::from_static(REQUEST_ID_HEADER), val);
|
||||
}
|
||||
headers.insert(
|
||||
header::HeaderName::from_static("model"),
|
||||
header::HeaderValue::from_static("arch-router"),
|
||||
);
|
||||
|
||||
let Some((content, elapsed)) =
|
||||
post_and_extract_content(&self.client, &self.router_url, headers, body).await?
|
||||
else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
// Parse the route name from the router response.
|
||||
let parsed = self
|
||||
.router_model
|
||||
.parse_response(&content, &effective_usage_preferences)?;
|
||||
|
||||
let result = if let Some((route_name, _sentinel)) = parsed {
|
||||
let top_pref = inline_top_map
|
||||
.as_ref()
|
||||
.and_then(|m| m.get(&route_name))
|
||||
.or_else(|| self.top_level_preferences.get(&route_name));
|
||||
|
||||
if let Some(pref) = top_pref {
|
||||
let ranked = match &self.metrics_service {
|
||||
Some(svc) => svc.rank_models(&pref.models, &pref.selection_policy).await,
|
||||
None => pref.models.clone(),
|
||||
};
|
||||
Some((route_name, ranked))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
info!(
|
||||
content = %content.replace("\n", "\\n"),
|
||||
selected_model = ?result,
|
||||
response_time_ms = elapsed.as_millis(),
|
||||
"arch-router determined route"
|
||||
);
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn make_router_service(ttl_seconds: u64, max_entries: usize) -> RouterService {
|
||||
RouterService::new(
|
||||
None,
|
||||
None,
|
||||
"http://localhost:12001/v1/chat/completions".to_string(),
|
||||
"Arch-Router".to_string(),
|
||||
"arch-router".to_string(),
|
||||
Some(ttl_seconds),
|
||||
Some(max_entries),
|
||||
)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_cache_miss_returns_none() {
|
||||
let svc = make_router_service(600, 100);
|
||||
assert!(svc.get_cached_route("unknown-session").await.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_cache_hit_returns_cached_route() {
|
||||
let svc = make_router_service(600, 100);
|
||||
svc.cache_route(
|
||||
"s1".to_string(),
|
||||
"gpt-4o".to_string(),
|
||||
Some("code".to_string()),
|
||||
)
|
||||
.await;
|
||||
|
||||
let cached = svc.get_cached_route("s1").await.unwrap();
|
||||
assert_eq!(cached.model_name, "gpt-4o");
|
||||
assert_eq!(cached.route_name, Some("code".to_string()));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_cache_expired_entry_returns_none() {
|
||||
let svc = make_router_service(0, 100);
|
||||
svc.cache_route("s1".to_string(), "gpt-4o".to_string(), None)
|
||||
.await;
|
||||
assert!(svc.get_cached_route("s1").await.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_cleanup_removes_expired() {
|
||||
let svc = make_router_service(0, 100);
|
||||
svc.cache_route("s1".to_string(), "gpt-4o".to_string(), None)
|
||||
.await;
|
||||
svc.cache_route("s2".to_string(), "claude".to_string(), None)
|
||||
.await;
|
||||
|
||||
svc.cleanup_expired_sessions().await;
|
||||
|
||||
let cache = svc.session_cache.read().await;
|
||||
assert!(cache.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_cache_evicts_oldest_when_full() {
|
||||
let svc = make_router_service(600, 2);
|
||||
svc.cache_route("s1".to_string(), "model-a".to_string(), None)
|
||||
.await;
|
||||
tokio::time::sleep(Duration::from_millis(10)).await;
|
||||
svc.cache_route("s2".to_string(), "model-b".to_string(), None)
|
||||
.await;
|
||||
|
||||
svc.cache_route("s3".to_string(), "model-c".to_string(), None)
|
||||
.await;
|
||||
|
||||
let cache = svc.session_cache.read().await;
|
||||
assert_eq!(cache.len(), 2);
|
||||
assert!(!cache.contains_key("s1"));
|
||||
assert!(cache.contains_key("s2"));
|
||||
assert!(cache.contains_key("s3"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_cache_update_existing_session_does_not_evict() {
|
||||
let svc = make_router_service(600, 2);
|
||||
svc.cache_route("s1".to_string(), "model-a".to_string(), None)
|
||||
.await;
|
||||
svc.cache_route("s2".to_string(), "model-b".to_string(), None)
|
||||
.await;
|
||||
|
||||
svc.cache_route(
|
||||
"s1".to_string(),
|
||||
"model-a-updated".to_string(),
|
||||
Some("route".to_string()),
|
||||
)
|
||||
.await;
|
||||
|
||||
let cache = svc.session_cache.read().await;
|
||||
assert_eq!(cache.len(), 2);
|
||||
assert_eq!(cache.get("s1").unwrap().model_name, "model-a-updated");
|
||||
}
|
||||
}
|
||||
|
|
@ -1,8 +1,5 @@
|
|||
pub(crate) mod http;
|
||||
pub mod llm;
|
||||
pub mod model_metrics;
|
||||
pub mod orchestrator;
|
||||
pub mod orchestrator_model;
|
||||
pub mod orchestrator_model_v1;
|
||||
pub mod router_model;
|
||||
pub mod router_model_v1;
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
use std::{collections::HashMap, sync::Arc};
|
||||
use std::{borrow::Cow, collections::HashMap, sync::Arc, time::Duration};
|
||||
|
||||
use common::{
|
||||
configuration::{AgentUsagePreference, OrchestrationPreference},
|
||||
configuration::{AgentUsagePreference, OrchestrationPreference, TopLevelRoutingPreference},
|
||||
consts::{ARCH_PROVIDER_HINT_HEADER, REQUEST_ID_HEADER},
|
||||
};
|
||||
use hermesllm::apis::openai::Message;
|
||||
|
|
@ -12,15 +12,26 @@ use thiserror::Error;
|
|||
use tracing::{debug, info};
|
||||
|
||||
use super::http::{self, post_and_extract_content};
|
||||
use super::model_metrics::ModelMetricsService;
|
||||
use super::orchestrator_model::OrchestratorModel;
|
||||
|
||||
use crate::router::orchestrator_model_v1;
|
||||
use crate::session_cache::SessionCache;
|
||||
|
||||
pub use crate::session_cache::CachedRoute;
|
||||
|
||||
const DEFAULT_SESSION_TTL_SECONDS: u64 = 600;
|
||||
|
||||
pub struct OrchestratorService {
|
||||
orchestrator_url: String,
|
||||
client: reqwest::Client,
|
||||
orchestrator_model: Arc<dyn OrchestratorModel>,
|
||||
orchestrator_provider_name: String,
|
||||
top_level_preferences: HashMap<String, TopLevelRoutingPreference>,
|
||||
metrics_service: Option<Arc<ModelMetricsService>>,
|
||||
session_cache: Option<Arc<dyn SessionCache>>,
|
||||
session_ttl: Duration,
|
||||
tenant_header: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
|
|
@ -39,13 +50,12 @@ impl OrchestratorService {
|
|||
orchestrator_url: String,
|
||||
orchestration_model_name: String,
|
||||
orchestrator_provider_name: String,
|
||||
max_token_length: usize,
|
||||
) -> Self {
|
||||
let agent_orchestrations: HashMap<String, Vec<OrchestrationPreference>> = HashMap::new();
|
||||
|
||||
let orchestrator_model = Arc::new(orchestrator_model_v1::OrchestratorModelV1::new(
|
||||
agent_orchestrations,
|
||||
orchestration_model_name.clone(),
|
||||
orchestrator_model_v1::MAX_TOKEN_LEN,
|
||||
HashMap::new(),
|
||||
orchestration_model_name,
|
||||
max_token_length,
|
||||
));
|
||||
|
||||
OrchestratorService {
|
||||
|
|
@ -53,9 +63,182 @@ impl OrchestratorService {
|
|||
client: reqwest::Client::new(),
|
||||
orchestrator_model,
|
||||
orchestrator_provider_name,
|
||||
top_level_preferences: HashMap::new(),
|
||||
metrics_service: None,
|
||||
session_cache: None,
|
||||
session_ttl: Duration::from_secs(DEFAULT_SESSION_TTL_SECONDS),
|
||||
tenant_header: None,
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn with_routing(
|
||||
orchestrator_url: String,
|
||||
orchestration_model_name: String,
|
||||
orchestrator_provider_name: String,
|
||||
top_level_prefs: Option<Vec<TopLevelRoutingPreference>>,
|
||||
metrics_service: Option<Arc<ModelMetricsService>>,
|
||||
session_ttl_seconds: Option<u64>,
|
||||
session_cache: Arc<dyn SessionCache>,
|
||||
tenant_header: Option<String>,
|
||||
max_token_length: usize,
|
||||
) -> Self {
|
||||
let top_level_preferences: HashMap<String, TopLevelRoutingPreference> = top_level_prefs
|
||||
.map_or_else(HashMap::new, |prefs| {
|
||||
prefs.into_iter().map(|p| (p.name.clone(), p)).collect()
|
||||
});
|
||||
|
||||
let orchestrator_model = Arc::new(orchestrator_model_v1::OrchestratorModelV1::new(
|
||||
HashMap::new(),
|
||||
orchestration_model_name,
|
||||
max_token_length,
|
||||
));
|
||||
|
||||
let session_ttl =
|
||||
Duration::from_secs(session_ttl_seconds.unwrap_or(DEFAULT_SESSION_TTL_SECONDS));
|
||||
|
||||
OrchestratorService {
|
||||
orchestrator_url,
|
||||
client: reqwest::Client::new(),
|
||||
orchestrator_model,
|
||||
orchestrator_provider_name,
|
||||
top_level_preferences,
|
||||
metrics_service,
|
||||
session_cache: Some(session_cache),
|
||||
session_ttl,
|
||||
tenant_header,
|
||||
}
|
||||
}
|
||||
|
||||
// ---- Session cache methods ----
|
||||
|
||||
#[must_use]
|
||||
pub fn tenant_header(&self) -> Option<&str> {
|
||||
self.tenant_header.as_deref()
|
||||
}
|
||||
|
||||
fn session_key<'a>(tenant_id: Option<&str>, session_id: &'a str) -> Cow<'a, str> {
|
||||
match tenant_id {
|
||||
Some(t) => Cow::Owned(format!("{t}:{session_id}")),
|
||||
None => Cow::Borrowed(session_id),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_cached_route(
|
||||
&self,
|
||||
session_id: &str,
|
||||
tenant_id: Option<&str>,
|
||||
) -> Option<CachedRoute> {
|
||||
let cache = self.session_cache.as_ref()?;
|
||||
cache.get(&Self::session_key(tenant_id, session_id)).await
|
||||
}
|
||||
|
||||
pub async fn cache_route(
|
||||
&self,
|
||||
session_id: String,
|
||||
tenant_id: Option<&str>,
|
||||
model_name: String,
|
||||
route_name: Option<String>,
|
||||
) {
|
||||
if let Some(ref cache) = self.session_cache {
|
||||
cache
|
||||
.put(
|
||||
&Self::session_key(tenant_id, &session_id),
|
||||
CachedRoute {
|
||||
model_name,
|
||||
route_name,
|
||||
},
|
||||
self.session_ttl,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
||||
// ---- LLM routing ----
|
||||
|
||||
pub async fn determine_route(
|
||||
&self,
|
||||
messages: &[Message],
|
||||
inline_routing_preferences: Option<Vec<TopLevelRoutingPreference>>,
|
||||
request_id: &str,
|
||||
) -> Result<Option<(String, Vec<String>)>> {
|
||||
if messages.is_empty() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let inline_top_map: Option<HashMap<String, TopLevelRoutingPreference>> =
|
||||
inline_routing_preferences
|
||||
.map(|prefs| prefs.into_iter().map(|p| (p.name.clone(), p)).collect());
|
||||
|
||||
if inline_top_map.is_none() && self.top_level_preferences.is_empty() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let effective_source = inline_top_map
|
||||
.as_ref()
|
||||
.unwrap_or(&self.top_level_preferences);
|
||||
|
||||
let effective_prefs: Vec<AgentUsagePreference> = effective_source
|
||||
.values()
|
||||
.map(|p| AgentUsagePreference {
|
||||
model: p.models.first().cloned().unwrap_or_default(),
|
||||
orchestration_preferences: vec![OrchestrationPreference {
|
||||
name: p.name.clone(),
|
||||
description: p.description.clone(),
|
||||
}],
|
||||
})
|
||||
.collect();
|
||||
|
||||
let orchestration_result = self
|
||||
.determine_orchestration(
|
||||
messages,
|
||||
Some(effective_prefs),
|
||||
Some(request_id.to_string()),
|
||||
)
|
||||
.await?;
|
||||
|
||||
let result = if let Some(ref routes) = orchestration_result {
|
||||
if routes.len() > 1 {
|
||||
let all_routes: Vec<&str> = routes.iter().map(|(name, _)| name.as_str()).collect();
|
||||
info!(
|
||||
routes = ?all_routes,
|
||||
using = %all_routes.first().unwrap_or(&"none"),
|
||||
"plano-orchestrator detected multiple intents, using first"
|
||||
);
|
||||
}
|
||||
|
||||
if let Some((route_name, _)) = routes.first() {
|
||||
let top_pref = inline_top_map
|
||||
.as_ref()
|
||||
.and_then(|m| m.get(route_name))
|
||||
.or_else(|| self.top_level_preferences.get(route_name));
|
||||
|
||||
if let Some(pref) = top_pref {
|
||||
let ranked = match &self.metrics_service {
|
||||
Some(svc) => svc.rank_models(&pref.models, &pref.selection_policy).await,
|
||||
None => pref.models.clone(),
|
||||
};
|
||||
Some((route_name.clone(), ranked))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
info!(
|
||||
selected_model = ?result,
|
||||
"plano-orchestrator determined route"
|
||||
);
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
// ---- Agent orchestration (existing) ----
|
||||
|
||||
pub async fn determine_orchestration(
|
||||
&self,
|
||||
messages: &[Message],
|
||||
|
|
@ -80,12 +263,12 @@ impl OrchestratorService {
|
|||
debug!(
|
||||
model = %self.orchestrator_model.get_model_name(),
|
||||
endpoint = %self.orchestrator_url,
|
||||
"sending request to arch-orchestrator"
|
||||
"sending request to plano-orchestrator"
|
||||
);
|
||||
|
||||
let body = serde_json::to_string(&orchestrator_request)
|
||||
.map_err(super::orchestrator_model::OrchestratorModelError::from)?;
|
||||
debug!(body = %body, "arch orchestrator request");
|
||||
debug!(body = %body, "plano-orchestrator request");
|
||||
|
||||
let mut headers = header::HeaderMap::new();
|
||||
headers.insert(
|
||||
|
|
@ -98,7 +281,6 @@ impl OrchestratorService {
|
|||
.unwrap_or_else(|_| header::HeaderValue::from_static("plano-orchestrator")),
|
||||
);
|
||||
|
||||
// Inject OpenTelemetry trace context from current span
|
||||
global::get_text_map_propagator(|propagator| {
|
||||
let cx =
|
||||
tracing_opentelemetry::OpenTelemetrySpanExt::context(&tracing::Span::current());
|
||||
|
|
@ -130,9 +312,113 @@ impl OrchestratorService {
|
|||
content = %content.replace("\n", "\\n"),
|
||||
selected_routes = ?parsed,
|
||||
response_time_ms = elapsed.as_millis(),
|
||||
"arch-orchestrator determined routes"
|
||||
"plano-orchestrator determined routes"
|
||||
);
|
||||
|
||||
Ok(parsed)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::session_cache::memory::MemorySessionCache;
|
||||
|
||||
fn make_orchestrator_service(ttl_seconds: u64, max_entries: usize) -> OrchestratorService {
|
||||
let session_cache = Arc::new(MemorySessionCache::new(max_entries));
|
||||
OrchestratorService::with_routing(
|
||||
"http://localhost:12001/v1/chat/completions".to_string(),
|
||||
"Plano-Orchestrator".to_string(),
|
||||
"plano-orchestrator".to_string(),
|
||||
None,
|
||||
None,
|
||||
Some(ttl_seconds),
|
||||
session_cache,
|
||||
None,
|
||||
orchestrator_model_v1::MAX_TOKEN_LEN,
|
||||
)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_cache_miss_returns_none() {
|
||||
let svc = make_orchestrator_service(600, 100);
|
||||
assert!(svc
|
||||
.get_cached_route("unknown-session", None)
|
||||
.await
|
||||
.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_cache_hit_returns_cached_route() {
|
||||
let svc = make_orchestrator_service(600, 100);
|
||||
svc.cache_route(
|
||||
"s1".to_string(),
|
||||
None,
|
||||
"gpt-4o".to_string(),
|
||||
Some("code".to_string()),
|
||||
)
|
||||
.await;
|
||||
|
||||
let cached = svc.get_cached_route("s1", None).await.unwrap();
|
||||
assert_eq!(cached.model_name, "gpt-4o");
|
||||
assert_eq!(cached.route_name, Some("code".to_string()));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_cache_expired_entry_returns_none() {
|
||||
let svc = make_orchestrator_service(0, 100);
|
||||
svc.cache_route("s1".to_string(), None, "gpt-4o".to_string(), None)
|
||||
.await;
|
||||
assert!(svc.get_cached_route("s1", None).await.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_expired_entries_not_returned() {
|
||||
let svc = make_orchestrator_service(0, 100);
|
||||
svc.cache_route("s1".to_string(), None, "gpt-4o".to_string(), None)
|
||||
.await;
|
||||
svc.cache_route("s2".to_string(), None, "claude".to_string(), None)
|
||||
.await;
|
||||
|
||||
assert!(svc.get_cached_route("s1", None).await.is_none());
|
||||
assert!(svc.get_cached_route("s2", None).await.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_cache_evicts_oldest_when_full() {
|
||||
let svc = make_orchestrator_service(600, 2);
|
||||
svc.cache_route("s1".to_string(), None, "model-a".to_string(), None)
|
||||
.await;
|
||||
tokio::time::sleep(Duration::from_millis(10)).await;
|
||||
svc.cache_route("s2".to_string(), None, "model-b".to_string(), None)
|
||||
.await;
|
||||
|
||||
svc.cache_route("s3".to_string(), None, "model-c".to_string(), None)
|
||||
.await;
|
||||
|
||||
assert!(svc.get_cached_route("s1", None).await.is_none());
|
||||
assert!(svc.get_cached_route("s2", None).await.is_some());
|
||||
assert!(svc.get_cached_route("s3", None).await.is_some());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_cache_update_existing_session_does_not_evict() {
|
||||
let svc = make_orchestrator_service(600, 2);
|
||||
svc.cache_route("s1".to_string(), None, "model-a".to_string(), None)
|
||||
.await;
|
||||
svc.cache_route("s2".to_string(), None, "model-b".to_string(), None)
|
||||
.await;
|
||||
|
||||
svc.cache_route(
|
||||
"s1".to_string(),
|
||||
None,
|
||||
"model-a-updated".to_string(),
|
||||
Some("route".to_string()),
|
||||
)
|
||||
.await;
|
||||
|
||||
let s1 = svc.get_cached_route("s1", None).await.unwrap();
|
||||
assert_eq!(s1.model_name, "model-a-updated");
|
||||
assert!(svc.get_cached_route("s2", None).await.is_some());
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -11,8 +11,7 @@ pub enum OrchestratorModelError {
|
|||
pub type Result<T> = std::result::Result<T, OrchestratorModelError>;
|
||||
|
||||
/// OrchestratorModel trait for handling orchestration requests.
|
||||
/// Unlike RouterModel which returns a single route, OrchestratorModel
|
||||
/// can return multiple routes as the model output format is:
|
||||
/// Returns multiple routes as the model output format is:
|
||||
/// {"route": ["route_name_1", "route_name_2", ...]}
|
||||
pub trait OrchestratorModel: Send + Sync {
|
||||
fn generate_request(
|
||||
|
|
|
|||
|
|
@ -8,7 +8,19 @@ use tracing::{debug, warn};
|
|||
|
||||
use super::orchestrator_model::{OrchestratorModel, OrchestratorModelError};
|
||||
|
||||
pub const MAX_TOKEN_LEN: usize = 2048; // Default max token length for the orchestration model
|
||||
pub const MAX_TOKEN_LEN: usize = 8192; // Default max token length for the orchestration model
|
||||
|
||||
/// Hard cap on the number of recent messages considered when building the
|
||||
/// routing prompt. Bounds prompt growth for long-running conversations and
|
||||
/// acts as an outer guardrail before the token-budget loop runs. The most
|
||||
/// recent `MAX_ROUTING_TURNS` filtered messages are kept; older turns are
|
||||
/// dropped entirely.
|
||||
pub const MAX_ROUTING_TURNS: usize = 16;
|
||||
|
||||
/// Unicode ellipsis used to mark where content was trimmed out of a long
|
||||
/// message. Helps signal to the downstream router model that the message was
|
||||
/// truncated.
|
||||
const TRIM_MARKER: &str = "…";
|
||||
|
||||
/// Custom JSON formatter that produces spaced JSON (space after colons and commas), same as JSON in python
|
||||
struct SpacedJsonFormatter;
|
||||
|
|
@ -176,10 +188,9 @@ impl OrchestratorModel for OrchestratorModelV1 {
|
|||
messages: &[Message],
|
||||
usage_preferences_from_request: &Option<Vec<AgentUsagePreference>>,
|
||||
) -> ChatCompletionsRequest {
|
||||
// remove system prompt, tool calls, tool call response and messages without content
|
||||
// if content is empty its likely a tool call
|
||||
// when role == tool its tool call response
|
||||
let messages_vec = messages
|
||||
// Remove system/developer/tool messages and messages without extractable
|
||||
// text (tool calls have no text content we can classify against).
|
||||
let filtered: Vec<&Message> = messages
|
||||
.iter()
|
||||
.filter(|m| {
|
||||
m.role != Role::System
|
||||
|
|
@ -187,37 +198,72 @@ impl OrchestratorModel for OrchestratorModelV1 {
|
|||
&& m.role != Role::Tool
|
||||
&& !m.content.extract_text().is_empty()
|
||||
})
|
||||
.collect::<Vec<&Message>>();
|
||||
.collect();
|
||||
|
||||
// Following code is to ensure that the conversation does not exceed max token length
|
||||
// Note: we use a simple heuristic to estimate token count based on character length to optimize for performance
|
||||
// Outer guardrail: only consider the last `MAX_ROUTING_TURNS` filtered
|
||||
// messages when building the routing prompt. Keeps prompt growth
|
||||
// predictable for long conversations regardless of per-message size.
|
||||
let start = filtered.len().saturating_sub(MAX_ROUTING_TURNS);
|
||||
let messages_vec: &[&Message] = &filtered[start..];
|
||||
|
||||
// Ensure the conversation does not exceed the configured token budget.
|
||||
// We use `len() / TOKEN_LENGTH_DIVISOR` as a cheap token estimate to
|
||||
// avoid running a real tokenizer on the hot path.
|
||||
let mut token_count = ARCH_ORCHESTRATOR_V1_SYSTEM_PROMPT.len() / TOKEN_LENGTH_DIVISOR;
|
||||
let mut selected_messages_list_reversed: Vec<&Message> = vec![];
|
||||
let mut selected_messages_list_reversed: Vec<Message> = vec![];
|
||||
for (selected_messsage_count, message) in messages_vec.iter().rev().enumerate() {
|
||||
let message_token_count = message.content.extract_text().len() / TOKEN_LENGTH_DIVISOR;
|
||||
token_count += message_token_count;
|
||||
if token_count > self.max_token_length {
|
||||
let message_text = message.content.extract_text();
|
||||
let message_token_count = message_text.len() / TOKEN_LENGTH_DIVISOR;
|
||||
if token_count + message_token_count > self.max_token_length {
|
||||
let remaining_tokens = self.max_token_length.saturating_sub(token_count);
|
||||
debug!(
|
||||
token_count = token_count,
|
||||
attempted_total_tokens = token_count + message_token_count,
|
||||
max_tokens = self.max_token_length,
|
||||
remaining_tokens,
|
||||
selected = selected_messsage_count,
|
||||
total = messages_vec.len(),
|
||||
"token count exceeds max, truncating conversation"
|
||||
);
|
||||
if message.role == Role::User {
|
||||
// If message that exceeds max token length is from user, we need to keep it
|
||||
selected_messages_list_reversed.push(message);
|
||||
// If the overflow message is from the user we need to keep
|
||||
// some of it so the orchestrator still sees the latest user
|
||||
// intent. Use a middle-trim (head + ellipsis + tail): users
|
||||
// often frame the task at the start AND put the actual ask
|
||||
// at the end of a long pasted block, so preserving both is
|
||||
// better than a head-only cut. The ellipsis also signals to
|
||||
// the router model that content was dropped.
|
||||
if message.role == Role::User && remaining_tokens > 0 {
|
||||
let max_bytes = remaining_tokens.saturating_mul(TOKEN_LENGTH_DIVISOR);
|
||||
let truncated = trim_middle_utf8(&message_text, max_bytes);
|
||||
selected_messages_list_reversed.push(Message {
|
||||
role: Role::User,
|
||||
content: Some(MessageContent::Text(truncated)),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
});
|
||||
}
|
||||
break;
|
||||
}
|
||||
// If we are here, it means that the message is within the max token length
|
||||
selected_messages_list_reversed.push(message);
|
||||
token_count += message_token_count;
|
||||
selected_messages_list_reversed.push(Message {
|
||||
role: message.role.clone(),
|
||||
content: Some(MessageContent::Text(message_text)),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
});
|
||||
}
|
||||
|
||||
if selected_messages_list_reversed.is_empty() {
|
||||
debug!("no messages selected, using last message");
|
||||
if let Some(last_message) = messages_vec.last() {
|
||||
selected_messages_list_reversed.push(last_message);
|
||||
selected_messages_list_reversed.push(Message {
|
||||
role: last_message.role.clone(),
|
||||
content: Some(MessageContent::Text(last_message.content.extract_text())),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -237,22 +283,8 @@ impl OrchestratorModel for OrchestratorModelV1 {
|
|||
}
|
||||
|
||||
// Reverse the selected messages to maintain the conversation order
|
||||
let selected_conversation_list = selected_messages_list_reversed
|
||||
.iter()
|
||||
.rev()
|
||||
.map(|message| Message {
|
||||
role: message.role.clone(),
|
||||
content: Some(MessageContent::Text(
|
||||
message
|
||||
.content
|
||||
.as_ref()
|
||||
.map_or(String::new(), |c| c.to_string()),
|
||||
)),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
})
|
||||
.collect::<Vec<Message>>();
|
||||
let selected_conversation_list: Vec<Message> =
|
||||
selected_messages_list_reversed.into_iter().rev().collect();
|
||||
|
||||
// Generate the orchestrator request message based on the usage preferences.
|
||||
// If preferences are passed in request then we use them;
|
||||
|
|
@ -405,6 +437,45 @@ fn fix_json_response(body: &str) -> String {
|
|||
body.replace("'", "\"").replace("\\n", "")
|
||||
}
|
||||
|
||||
/// Truncate `s` so the result is at most `max_bytes` bytes long, keeping
|
||||
/// roughly 60% from the start and 40% from the end, with a Unicode ellipsis
|
||||
/// separating the two. All splits respect UTF-8 character boundaries. When
|
||||
/// `max_bytes` is too small to fit the marker at all, falls back to a
|
||||
/// head-only truncation.
|
||||
fn trim_middle_utf8(s: &str, max_bytes: usize) -> String {
|
||||
if s.len() <= max_bytes {
|
||||
return s.to_string();
|
||||
}
|
||||
if max_bytes <= TRIM_MARKER.len() {
|
||||
// Not enough room even for the marker — just keep the start.
|
||||
let mut end = max_bytes;
|
||||
while end > 0 && !s.is_char_boundary(end) {
|
||||
end -= 1;
|
||||
}
|
||||
return s[..end].to_string();
|
||||
}
|
||||
|
||||
let available = max_bytes - TRIM_MARKER.len();
|
||||
// Bias toward the start (60%) where task framing typically lives, while
|
||||
// still preserving ~40% of the tail where the user's actual ask often
|
||||
// appears after a long paste.
|
||||
let mut start_len = available * 3 / 5;
|
||||
while start_len > 0 && !s.is_char_boundary(start_len) {
|
||||
start_len -= 1;
|
||||
}
|
||||
let end_len = available - start_len;
|
||||
let mut end_start = s.len().saturating_sub(end_len);
|
||||
while end_start < s.len() && !s.is_char_boundary(end_start) {
|
||||
end_start += 1;
|
||||
}
|
||||
|
||||
let mut out = String::with_capacity(start_len + TRIM_MARKER.len() + (s.len() - end_start));
|
||||
out.push_str(&s[..start_len]);
|
||||
out.push_str(TRIM_MARKER);
|
||||
out.push_str(&s[end_start..]);
|
||||
out
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for dyn OrchestratorModel {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "OrchestratorModel")
|
||||
|
|
@ -777,6 +848,10 @@ If no routes are needed, return an empty list for `route`.
|
|||
|
||||
#[test]
|
||||
fn test_conversation_trim_upto_user_message() {
|
||||
// With max_token_length=230, the older user message "given the image
|
||||
// In style of Andy Warhol" overflows the remaining budget and gets
|
||||
// middle-trimmed (head + ellipsis + tail) until it fits. Newer turns
|
||||
// are kept in full.
|
||||
let expected_prompt = r#"
|
||||
You are a helpful assistant that selects the most suitable routes based on user intent.
|
||||
You are provided with a list of available routes enclosed within <routes></routes> XML tags:
|
||||
|
|
@ -789,7 +864,7 @@ You are also given the conversation context enclosed within <conversation></conv
|
|||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "given the image In style of Andy Warhol"
|
||||
"content": "given…rhol"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
|
|
@ -862,6 +937,190 @@ If no routes are needed, return an empty list for `route`.
|
|||
assert_eq!(expected_prompt, prompt);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_huge_single_user_message_is_middle_trimmed() {
|
||||
// Regression test for the case where a single, extremely large user
|
||||
// message was being passed to the orchestrator verbatim and blowing
|
||||
// past the upstream model's context window. The trimmer must now
|
||||
// middle-trim (head + ellipsis + tail) the oversized message so the
|
||||
// resulting request stays within the configured budget, and the
|
||||
// trim marker must be present so the router model knows content
|
||||
// was dropped.
|
||||
let orchestrations_str = r#"
|
||||
{
|
||||
"gpt-4o": [
|
||||
{"name": "Image generation", "description": "generating image"}
|
||||
]
|
||||
}
|
||||
"#;
|
||||
let agent_orchestrations = serde_json::from_str::<
|
||||
HashMap<String, Vec<OrchestrationPreference>>,
|
||||
>(orchestrations_str)
|
||||
.unwrap();
|
||||
|
||||
let max_token_length = 2048;
|
||||
let orchestrator = OrchestratorModelV1::new(
|
||||
agent_orchestrations,
|
||||
"test-model".to_string(),
|
||||
max_token_length,
|
||||
);
|
||||
|
||||
// ~500KB of content — same scale as the real payload that triggered
|
||||
// the production upstream 400.
|
||||
let head = "HEAD_MARKER_START ";
|
||||
let tail = " TAIL_MARKER_END";
|
||||
let filler = "A".repeat(500_000);
|
||||
let huge_user_content = format!("{head}{filler}{tail}");
|
||||
|
||||
let conversation = vec![Message {
|
||||
role: Role::User,
|
||||
content: Some(MessageContent::Text(huge_user_content.clone())),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
}];
|
||||
|
||||
let req = orchestrator.generate_request(&conversation, &None);
|
||||
let prompt = req.messages[0].content.extract_text();
|
||||
|
||||
// Prompt must stay bounded. Generous ceiling = budget-in-bytes +
|
||||
// scaffolding + slack. Real result should be well under this.
|
||||
let byte_ceiling = max_token_length * TOKEN_LENGTH_DIVISOR
|
||||
+ ARCH_ORCHESTRATOR_V1_SYSTEM_PROMPT.len()
|
||||
+ 1024;
|
||||
assert!(
|
||||
prompt.len() < byte_ceiling,
|
||||
"prompt length {} exceeded ceiling {} — truncation did not apply",
|
||||
prompt.len(),
|
||||
byte_ceiling,
|
||||
);
|
||||
|
||||
// Not all 500k filler chars survive.
|
||||
let a_count = prompt.chars().filter(|c| *c == 'A').count();
|
||||
assert!(
|
||||
a_count < filler.len(),
|
||||
"expected user message to be truncated; all {} 'A's survived",
|
||||
a_count
|
||||
);
|
||||
assert!(
|
||||
a_count > 0,
|
||||
"expected some of the user message to survive truncation"
|
||||
);
|
||||
|
||||
// Head and tail of the message must both be preserved (that's the
|
||||
// whole point of middle-trim over head-only).
|
||||
assert!(
|
||||
prompt.contains(head),
|
||||
"head marker missing — head was not preserved"
|
||||
);
|
||||
assert!(
|
||||
prompt.contains(tail),
|
||||
"tail marker missing — tail was not preserved"
|
||||
);
|
||||
|
||||
// Trim marker must be present so the router model can see that
|
||||
// content was omitted.
|
||||
assert!(
|
||||
prompt.contains(TRIM_MARKER),
|
||||
"ellipsis trim marker missing from truncated prompt"
|
||||
);
|
||||
|
||||
// Routing prompt scaffolding remains intact.
|
||||
assert!(prompt.contains("<conversation>"));
|
||||
assert!(prompt.contains("<routes>"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_turn_cap_limits_routing_history() {
|
||||
// The outer turn-cap guardrail should keep only the last
|
||||
// `MAX_ROUTING_TURNS` filtered messages regardless of how long the
|
||||
// conversation is. We build a conversation with alternating
|
||||
// user/assistant turns tagged with their index and verify that only
|
||||
// the tail of the conversation makes it into the prompt.
|
||||
let orchestrations_str = r#"
|
||||
{
|
||||
"gpt-4o": [
|
||||
{"name": "Image generation", "description": "generating image"}
|
||||
]
|
||||
}
|
||||
"#;
|
||||
let agent_orchestrations = serde_json::from_str::<
|
||||
HashMap<String, Vec<OrchestrationPreference>>,
|
||||
>(orchestrations_str)
|
||||
.unwrap();
|
||||
|
||||
let orchestrator =
|
||||
OrchestratorModelV1::new(agent_orchestrations, "test-model".to_string(), usize::MAX);
|
||||
|
||||
let mut conversation: Vec<Message> = Vec::new();
|
||||
let total_turns = MAX_ROUTING_TURNS * 2; // well past the cap
|
||||
for i in 0..total_turns {
|
||||
let role = if i % 2 == 0 {
|
||||
Role::User
|
||||
} else {
|
||||
Role::Assistant
|
||||
};
|
||||
conversation.push(Message {
|
||||
role,
|
||||
content: Some(MessageContent::Text(format!("turn-{i:03}"))),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
});
|
||||
}
|
||||
|
||||
let req = orchestrator.generate_request(&conversation, &None);
|
||||
let prompt = req.messages[0].content.extract_text();
|
||||
|
||||
// The last MAX_ROUTING_TURNS messages (indexes total-cap..total)
|
||||
// must all appear.
|
||||
for i in (total_turns - MAX_ROUTING_TURNS)..total_turns {
|
||||
let tag = format!("turn-{i:03}");
|
||||
assert!(
|
||||
prompt.contains(&tag),
|
||||
"expected recent turn tag {tag} to be present"
|
||||
);
|
||||
}
|
||||
|
||||
// And earlier turns (indexes 0..total-cap) must all be dropped.
|
||||
for i in 0..(total_turns - MAX_ROUTING_TURNS) {
|
||||
let tag = format!("turn-{i:03}");
|
||||
assert!(
|
||||
!prompt.contains(&tag),
|
||||
"old turn tag {tag} leaked past turn cap into the prompt"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_trim_middle_utf8_helper() {
|
||||
// No-op when already small enough.
|
||||
assert_eq!(trim_middle_utf8("hello", 100), "hello");
|
||||
assert_eq!(trim_middle_utf8("hello", 5), "hello");
|
||||
|
||||
// 60/40 split with ellipsis when too long.
|
||||
let long = "a".repeat(20);
|
||||
let out = trim_middle_utf8(&long, 10);
|
||||
assert!(out.len() <= 10);
|
||||
assert!(out.contains(TRIM_MARKER));
|
||||
// Exactly one ellipsis, rest are 'a's.
|
||||
assert_eq!(out.matches(TRIM_MARKER).count(), 1);
|
||||
assert!(out.chars().filter(|c| *c == 'a').count() > 0);
|
||||
|
||||
// When max_bytes is smaller than the marker, falls back to
|
||||
// head-only truncation (no marker).
|
||||
let out = trim_middle_utf8("abcdefgh", 2);
|
||||
assert_eq!(out, "ab");
|
||||
|
||||
// UTF-8 boundary safety: 2-byte chars.
|
||||
let s = "é".repeat(50); // 100 bytes
|
||||
let out = trim_middle_utf8(&s, 25);
|
||||
assert!(out.len() <= 25);
|
||||
// Must still be valid UTF-8 that only contains 'é' and the marker.
|
||||
let ok = out.chars().all(|c| c == 'é' || c == '…');
|
||||
assert!(ok, "unexpected char in trimmed output: {out:?}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_non_text_input() {
|
||||
let expected_prompt = r#"
|
||||
|
|
|
|||
|
|
@ -1,39 +0,0 @@
|
|||
use hermesllm::apis::openai::{ChatCompletionsRequest, Message};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum RoutingModelError {
|
||||
#[error("Failed to parse JSON: {0}")]
|
||||
JsonError(#[from] serde_json::Error),
|
||||
}
|
||||
|
||||
pub type Result<T> = std::result::Result<T, RoutingModelError>;
|
||||
|
||||
/// Internal route descriptor passed to the router model to build its prompt.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct RoutingPreference {
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
}
|
||||
|
||||
/// Groups a model with its routing preferences (used internally by RouterModelV1).
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ModelUsagePreference {
|
||||
pub model: String,
|
||||
pub routing_preferences: Vec<RoutingPreference>,
|
||||
}
|
||||
|
||||
pub trait RouterModel: Send + Sync {
|
||||
fn generate_request(
|
||||
&self,
|
||||
messages: &[Message],
|
||||
usage_preferences: &Option<Vec<ModelUsagePreference>>,
|
||||
) -> ChatCompletionsRequest;
|
||||
fn parse_response(
|
||||
&self,
|
||||
content: &str,
|
||||
usage_preferences: &Option<Vec<ModelUsagePreference>>,
|
||||
) -> Result<Option<(String, String)>>;
|
||||
fn get_model_name(&self) -> String;
|
||||
}
|
||||
|
|
@ -1,842 +0,0 @@
|
|||
use std::collections::HashMap;
|
||||
|
||||
use super::router_model::{ModelUsagePreference, RoutingPreference};
|
||||
use hermesllm::apis::openai::{ChatCompletionsRequest, Message, MessageContent, Role};
|
||||
use hermesllm::transforms::lib::ExtractText;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tracing::{debug, warn};
|
||||
|
||||
use super::router_model::{RouterModel, RoutingModelError};
|
||||
|
||||
pub const MAX_TOKEN_LEN: usize = 2048; // Default max token length for the routing model
|
||||
pub const ARCH_ROUTER_V1_SYSTEM_PROMPT: &str = r#"
|
||||
You are a helpful assistant designed to find the best suited route.
|
||||
You are provided with route description within <routes></routes> XML tags:
|
||||
<routes>
|
||||
{routes}
|
||||
</routes>
|
||||
|
||||
<conversation>
|
||||
{conversation}
|
||||
</conversation>
|
||||
|
||||
Your task is to decide which route is best suit with user intent on the conversation in <conversation></conversation> XML tags. Follow the instruction:
|
||||
1. If the latest intent from user is irrelevant or user intent is full filled, response with other route {"route": "other"}.
|
||||
2. You must analyze the route descriptions and find the best match route for user latest intent.
|
||||
3. You only response the name of the route that best matches the user's request, use the exact name in the <routes></routes>.
|
||||
|
||||
Based on your analysis, provide your response in the following JSON formats if you decide to match any route:
|
||||
{"route": "route_name"}
|
||||
"#;
|
||||
|
||||
pub type Result<T> = std::result::Result<T, RoutingModelError>;
|
||||
pub struct RouterModelV1 {
|
||||
llm_route_json_str: String,
|
||||
llm_route_to_model_map: HashMap<String, String>,
|
||||
routing_model: String,
|
||||
max_token_length: usize,
|
||||
}
|
||||
impl RouterModelV1 {
|
||||
pub fn new(
|
||||
llm_routes: HashMap<String, Vec<RoutingPreference>>,
|
||||
routing_model: String,
|
||||
max_token_length: usize,
|
||||
) -> Self {
|
||||
let llm_route_values: Vec<RoutingPreference> =
|
||||
llm_routes.values().flatten().cloned().collect();
|
||||
let llm_route_json_str =
|
||||
serde_json::to_string(&llm_route_values).unwrap_or_else(|_| "[]".to_string());
|
||||
let llm_route_to_model_map: HashMap<String, String> = llm_routes
|
||||
.iter()
|
||||
.flat_map(|(model, prefs)| prefs.iter().map(|pref| (pref.name.clone(), model.clone())))
|
||||
.collect();
|
||||
|
||||
RouterModelV1 {
|
||||
routing_model,
|
||||
max_token_length,
|
||||
llm_route_json_str,
|
||||
llm_route_to_model_map,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct LlmRouterResponse {
|
||||
pub route: Option<String>,
|
||||
}
|
||||
|
||||
const TOKEN_LENGTH_DIVISOR: usize = 4; // Approximate token length divisor for UTF-8 characters
|
||||
|
||||
impl RouterModel for RouterModelV1 {
|
||||
fn generate_request(
|
||||
&self,
|
||||
messages: &[Message],
|
||||
usage_preferences_from_request: &Option<Vec<ModelUsagePreference>>,
|
||||
) -> ChatCompletionsRequest {
|
||||
// remove system prompt, tool calls, tool call response and messages without content
|
||||
// if content is empty its likely a tool call
|
||||
// when role == tool its tool call response
|
||||
let messages_vec = messages
|
||||
.iter()
|
||||
.filter(|m| {
|
||||
m.role != Role::System
|
||||
&& m.role != Role::Developer
|
||||
&& m.role != Role::Tool
|
||||
&& !m.content.extract_text().is_empty()
|
||||
})
|
||||
.collect::<Vec<&Message>>();
|
||||
|
||||
// Following code is to ensure that the conversation does not exceed max token length
|
||||
// Note: we use a simple heuristic to estimate token count based on character length to optimize for performance
|
||||
let mut token_count = ARCH_ROUTER_V1_SYSTEM_PROMPT.len() / TOKEN_LENGTH_DIVISOR;
|
||||
let mut selected_messages_list_reversed: Vec<&Message> = vec![];
|
||||
for (selected_messsage_count, message) in messages_vec.iter().rev().enumerate() {
|
||||
let message_token_count = message.content.extract_text().len() / TOKEN_LENGTH_DIVISOR;
|
||||
token_count += message_token_count;
|
||||
if token_count > self.max_token_length {
|
||||
debug!(
|
||||
token_count = token_count,
|
||||
max_tokens = self.max_token_length,
|
||||
selected = selected_messsage_count,
|
||||
total = messages_vec.len(),
|
||||
"token count exceeds max, truncating conversation"
|
||||
);
|
||||
if message.role == Role::User {
|
||||
// If message that exceeds max token length is from user, we need to keep it
|
||||
selected_messages_list_reversed.push(message);
|
||||
}
|
||||
break;
|
||||
}
|
||||
// If we are here, it means that the message is within the max token length
|
||||
selected_messages_list_reversed.push(message);
|
||||
}
|
||||
|
||||
if selected_messages_list_reversed.is_empty() {
|
||||
debug!("no messages selected, using last message");
|
||||
if let Some(last_message) = messages_vec.last() {
|
||||
selected_messages_list_reversed.push(last_message);
|
||||
}
|
||||
}
|
||||
|
||||
// ensure that first and last selected message is from user
|
||||
if let Some(first_message) = selected_messages_list_reversed.first() {
|
||||
if first_message.role != Role::User {
|
||||
warn!("last message is not from user, may lead to incorrect routing");
|
||||
}
|
||||
}
|
||||
if let Some(last_message) = selected_messages_list_reversed.last() {
|
||||
if last_message.role != Role::User {
|
||||
warn!("first message is not from user, may lead to incorrect routing");
|
||||
}
|
||||
}
|
||||
|
||||
// Reverse the selected messages to maintain the conversation order
|
||||
let selected_conversation_list = selected_messages_list_reversed
|
||||
.iter()
|
||||
.rev()
|
||||
.map(|message| {
|
||||
Message {
|
||||
role: message.role.clone(),
|
||||
// we can unwrap here because we have already filtered out messages without content
|
||||
content: Some(MessageContent::Text(
|
||||
message
|
||||
.content
|
||||
.as_ref()
|
||||
.map_or(String::new(), |c| c.to_string()),
|
||||
)),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
}
|
||||
})
|
||||
.collect::<Vec<Message>>();
|
||||
|
||||
// Generate the router request message based on the usage preferences.
|
||||
// If preferences are passed in request then we use them otherwise we use the default routing model preferences.
|
||||
let router_message = match convert_to_router_preferences(usage_preferences_from_request) {
|
||||
Some(prefs) => generate_router_message(&prefs, &selected_conversation_list),
|
||||
None => generate_router_message(&self.llm_route_json_str, &selected_conversation_list),
|
||||
};
|
||||
|
||||
ChatCompletionsRequest {
|
||||
model: self.routing_model.clone(),
|
||||
messages: vec![Message {
|
||||
content: Some(MessageContent::Text(router_message)),
|
||||
role: Role::User,
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
}],
|
||||
temperature: Some(0.01),
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_response(
|
||||
&self,
|
||||
content: &str,
|
||||
usage_preferences: &Option<Vec<ModelUsagePreference>>,
|
||||
) -> Result<Option<(String, String)>> {
|
||||
if content.is_empty() {
|
||||
return Ok(None);
|
||||
}
|
||||
let router_resp_fixed = fix_json_response(content);
|
||||
let router_response: LlmRouterResponse = serde_json::from_str(router_resp_fixed.as_str())?;
|
||||
|
||||
let selected_route = router_response.route.unwrap_or_default().to_string();
|
||||
|
||||
if selected_route.is_empty() || selected_route == "other" {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
if let Some(usage_preferences) = usage_preferences {
|
||||
// If usage preferences are defined, we need to find the model that matches the selected route
|
||||
let model_name: Option<String> = usage_preferences
|
||||
.iter()
|
||||
.map(|pref| {
|
||||
pref.routing_preferences
|
||||
.iter()
|
||||
.find(|routing_pref| routing_pref.name == selected_route)
|
||||
.map(|_| pref.model.clone())
|
||||
})
|
||||
.find_map(|model| model);
|
||||
|
||||
if let Some(model_name) = model_name {
|
||||
return Ok(Some((selected_route, model_name)));
|
||||
} else {
|
||||
warn!(
|
||||
route = %selected_route,
|
||||
preferences = ?usage_preferences,
|
||||
"no matching model found for route"
|
||||
);
|
||||
return Ok(None);
|
||||
}
|
||||
}
|
||||
|
||||
// If no usage preferences are passed in request then use the default routing model preferences
|
||||
if let Some(model) = self.llm_route_to_model_map.get(&selected_route).cloned() {
|
||||
return Ok(Some((selected_route, model)));
|
||||
}
|
||||
|
||||
warn!(
|
||||
route = %selected_route,
|
||||
preferences = ?self.llm_route_to_model_map,
|
||||
"no model found for route"
|
||||
);
|
||||
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
fn get_model_name(&self) -> String {
|
||||
self.routing_model.clone()
|
||||
}
|
||||
}
|
||||
|
||||
fn generate_router_message(prefs: &str, selected_conversation_list: &Vec<Message>) -> String {
|
||||
ARCH_ROUTER_V1_SYSTEM_PROMPT
|
||||
.replace("{routes}", prefs)
|
||||
.replace(
|
||||
"{conversation}",
|
||||
&serde_json::to_string(&selected_conversation_list).unwrap_or_default(),
|
||||
)
|
||||
}
|
||||
|
||||
fn convert_to_router_preferences(
|
||||
prefs_from_request: &Option<Vec<ModelUsagePreference>>,
|
||||
) -> Option<String> {
|
||||
if let Some(usage_preferences) = prefs_from_request {
|
||||
let routing_preferences = usage_preferences
|
||||
.iter()
|
||||
.flat_map(|pref| {
|
||||
pref.routing_preferences
|
||||
.iter()
|
||||
.map(|routing_pref| RoutingPreference {
|
||||
name: routing_pref.name.clone(),
|
||||
description: routing_pref.description.clone(),
|
||||
})
|
||||
})
|
||||
.collect::<Vec<RoutingPreference>>();
|
||||
|
||||
return Some(serde_json::to_string(&routing_preferences).unwrap_or_default());
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
fn fix_json_response(body: &str) -> String {
|
||||
let mut updated_body = body.to_string();
|
||||
|
||||
updated_body = updated_body.replace("'", "\"");
|
||||
|
||||
if updated_body.contains("\\n") {
|
||||
updated_body = updated_body.replace("\\n", "");
|
||||
}
|
||||
|
||||
if updated_body.starts_with("```json") {
|
||||
updated_body = updated_body
|
||||
.strip_prefix("```json")
|
||||
.unwrap_or(&updated_body)
|
||||
.to_string();
|
||||
}
|
||||
|
||||
if updated_body.ends_with("```") {
|
||||
updated_body = updated_body
|
||||
.strip_suffix("```")
|
||||
.unwrap_or(&updated_body)
|
||||
.to_string();
|
||||
}
|
||||
|
||||
updated_body
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for dyn RouterModel {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "RouterModel")
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
#[test]
|
||||
fn test_system_prompt_format() {
|
||||
let expected_prompt = r#"
|
||||
You are a helpful assistant designed to find the best suited route.
|
||||
You are provided with route description within <routes></routes> XML tags:
|
||||
<routes>
|
||||
[{"name":"Image generation","description":"generating image"}]
|
||||
</routes>
|
||||
|
||||
<conversation>
|
||||
[{"role":"user","content":"hi"},{"role":"assistant","content":"Hello! How can I assist you today?"},{"role":"user","content":"given the image In style of Andy Warhol, portrait of Bart and Lisa Simpson"}]
|
||||
</conversation>
|
||||
|
||||
Your task is to decide which route is best suit with user intent on the conversation in <conversation></conversation> XML tags. Follow the instruction:
|
||||
1. If the latest intent from user is irrelevant or user intent is full filled, response with other route {"route": "other"}.
|
||||
2. You must analyze the route descriptions and find the best match route for user latest intent.
|
||||
3. You only response the name of the route that best matches the user's request, use the exact name in the <routes></routes>.
|
||||
|
||||
Based on your analysis, provide your response in the following JSON formats if you decide to match any route:
|
||||
{"route": "route_name"}
|
||||
"#;
|
||||
let routes_str = r#"
|
||||
{
|
||||
"gpt-4o": [
|
||||
{"name": "Image generation", "description": "generating image"}
|
||||
]
|
||||
}
|
||||
"#;
|
||||
let llm_routes =
|
||||
serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
|
||||
let routing_model = "test-model".to_string();
|
||||
let router = RouterModelV1::new(llm_routes, routing_model, usize::MAX);
|
||||
|
||||
let conversation_str = r#"
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "hi"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Hello! How can I assist you today?"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "given the image In style of Andy Warhol, portrait of Bart and Lisa Simpson"
|
||||
}
|
||||
]
|
||||
"#;
|
||||
let conversation: Vec<Message> = serde_json::from_str(conversation_str).unwrap();
|
||||
|
||||
let req = router.generate_request(&conversation, &None);
|
||||
|
||||
let prompt = req.messages[0].content.extract_text();
|
||||
|
||||
assert_eq!(expected_prompt, prompt);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_system_prompt_format_usage_preferences() {
|
||||
let expected_prompt = r#"
|
||||
You are a helpful assistant designed to find the best suited route.
|
||||
You are provided with route description within <routes></routes> XML tags:
|
||||
<routes>
|
||||
[{"name":"code-generation","description":"generating new code snippets, functions, or boilerplate based on user prompts or requirements"}]
|
||||
</routes>
|
||||
|
||||
<conversation>
|
||||
[{"role":"user","content":"hi"},{"role":"assistant","content":"Hello! How can I assist you today?"},{"role":"user","content":"given the image In style of Andy Warhol, portrait of Bart and Lisa Simpson"}]
|
||||
</conversation>
|
||||
|
||||
Your task is to decide which route is best suit with user intent on the conversation in <conversation></conversation> XML tags. Follow the instruction:
|
||||
1. If the latest intent from user is irrelevant or user intent is full filled, response with other route {"route": "other"}.
|
||||
2. You must analyze the route descriptions and find the best match route for user latest intent.
|
||||
3. You only response the name of the route that best matches the user's request, use the exact name in the <routes></routes>.
|
||||
|
||||
Based on your analysis, provide your response in the following JSON formats if you decide to match any route:
|
||||
{"route": "route_name"}
|
||||
"#;
|
||||
let routes_str = r#"
|
||||
{
|
||||
"gpt-4o": [
|
||||
{"name": "Image generation", "description": "generating image"}
|
||||
]
|
||||
}
|
||||
"#;
|
||||
let llm_routes =
|
||||
serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
|
||||
let routing_model = "test-model".to_string();
|
||||
let router = RouterModelV1::new(llm_routes, routing_model, usize::MAX);
|
||||
|
||||
let conversation_str = r#"
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "hi"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Hello! How can I assist you today?"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "given the image In style of Andy Warhol, portrait of Bart and Lisa Simpson"
|
||||
}
|
||||
]
|
||||
"#;
|
||||
let conversation: Vec<Message> = serde_json::from_str(conversation_str).unwrap();
|
||||
|
||||
let usage_preferences = Some(vec![ModelUsagePreference {
|
||||
model: "claude/claude-3-7-sonnet".to_string(),
|
||||
routing_preferences: vec![RoutingPreference {
|
||||
name: "code-generation".to_string(),
|
||||
description: "generating new code snippets, functions, or boilerplate based on user prompts or requirements".to_string(),
|
||||
}],
|
||||
}]);
|
||||
let req = router.generate_request(&conversation, &usage_preferences);
|
||||
|
||||
let prompt = req.messages[0].content.extract_text();
|
||||
|
||||
assert_eq!(expected_prompt, prompt);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conversation_exceed_token_count() {
|
||||
let expected_prompt = r#"
|
||||
You are a helpful assistant designed to find the best suited route.
|
||||
You are provided with route description within <routes></routes> XML tags:
|
||||
<routes>
|
||||
[{"name":"Image generation","description":"generating image"}]
|
||||
</routes>
|
||||
|
||||
<conversation>
|
||||
[{"role":"user","content":"given the image In style of Andy Warhol, portrait of Bart and Lisa Simpson"}]
|
||||
</conversation>
|
||||
|
||||
Your task is to decide which route is best suit with user intent on the conversation in <conversation></conversation> XML tags. Follow the instruction:
|
||||
1. If the latest intent from user is irrelevant or user intent is full filled, response with other route {"route": "other"}.
|
||||
2. You must analyze the route descriptions and find the best match route for user latest intent.
|
||||
3. You only response the name of the route that best matches the user's request, use the exact name in the <routes></routes>.
|
||||
|
||||
Based on your analysis, provide your response in the following JSON formats if you decide to match any route:
|
||||
{"route": "route_name"}
|
||||
"#;
|
||||
|
||||
let routes_str = r#"
|
||||
{
|
||||
"gpt-4o": [
|
||||
{"name": "Image generation", "description": "generating image"}
|
||||
]
|
||||
}
|
||||
"#;
|
||||
let llm_routes =
|
||||
serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
|
||||
let routing_model = "test-model".to_string();
|
||||
let router = RouterModelV1::new(llm_routes, routing_model, 235);
|
||||
|
||||
let conversation_str = r#"
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "hi"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Hello! How can I assist you today?"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "given the image In style of Andy Warhol, portrait of Bart and Lisa Simpson"
|
||||
}
|
||||
]
|
||||
"#;
|
||||
|
||||
let conversation: Vec<Message> = serde_json::from_str(conversation_str).unwrap();
|
||||
|
||||
let req = router.generate_request(&conversation, &None);
|
||||
|
||||
let prompt = req.messages[0].content.extract_text();
|
||||
|
||||
assert_eq!(expected_prompt, prompt);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conversation_exceed_token_count_large_single_message() {
|
||||
let expected_prompt = r#"
|
||||
You are a helpful assistant designed to find the best suited route.
|
||||
You are provided with route description within <routes></routes> XML tags:
|
||||
<routes>
|
||||
[{"name":"Image generation","description":"generating image"}]
|
||||
</routes>
|
||||
|
||||
<conversation>
|
||||
[{"role":"user","content":"given the image In style of Andy Warhol, portrait of Bart and Lisa Simpson and this is a very long message that exceeds the max token length of the routing model, so it should be truncated and only the last user message should be included in the conversation for routing."}]
|
||||
</conversation>
|
||||
|
||||
Your task is to decide which route is best suit with user intent on the conversation in <conversation></conversation> XML tags. Follow the instruction:
|
||||
1. If the latest intent from user is irrelevant or user intent is full filled, response with other route {"route": "other"}.
|
||||
2. You must analyze the route descriptions and find the best match route for user latest intent.
|
||||
3. You only response the name of the route that best matches the user's request, use the exact name in the <routes></routes>.
|
||||
|
||||
Based on your analysis, provide your response in the following JSON formats if you decide to match any route:
|
||||
{"route": "route_name"}
|
||||
"#;
|
||||
|
||||
let routes_str = r#"
|
||||
{
|
||||
"gpt-4o": [
|
||||
{"name": "Image generation", "description": "generating image"}
|
||||
]
|
||||
}
|
||||
"#;
|
||||
let llm_routes =
|
||||
serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
|
||||
|
||||
let routing_model = "test-model".to_string();
|
||||
let router = RouterModelV1::new(llm_routes, routing_model, 200);
|
||||
|
||||
let conversation_str = r#"
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "hi"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Hello! How can I assist you today?"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "given the image In style of Andy Warhol, portrait of Bart and Lisa Simpson and this is a very long message that exceeds the max token length of the routing model, so it should be truncated and only the last user message should be included in the conversation for routing."
|
||||
}
|
||||
]
|
||||
"#;
|
||||
|
||||
let conversation: Vec<Message> = serde_json::from_str(conversation_str).unwrap();
|
||||
|
||||
let req = router.generate_request(&conversation, &None);
|
||||
|
||||
let prompt = req.messages[0].content.extract_text();
|
||||
|
||||
assert_eq!(expected_prompt, prompt);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conversation_trim_upto_user_message() {
|
||||
let expected_prompt = r#"
|
||||
You are a helpful assistant designed to find the best suited route.
|
||||
You are provided with route description within <routes></routes> XML tags:
|
||||
<routes>
|
||||
[{"name":"Image generation","description":"generating image"}]
|
||||
</routes>
|
||||
|
||||
<conversation>
|
||||
[{"role":"user","content":"given the image In style of Andy Warhol"},{"role":"assistant","content":"ok here is the image"},{"role":"user","content":"pls give me another image about Bart and Lisa"}]
|
||||
</conversation>
|
||||
|
||||
Your task is to decide which route is best suit with user intent on the conversation in <conversation></conversation> XML tags. Follow the instruction:
|
||||
1. If the latest intent from user is irrelevant or user intent is full filled, response with other route {"route": "other"}.
|
||||
2. You must analyze the route descriptions and find the best match route for user latest intent.
|
||||
3. You only response the name of the route that best matches the user's request, use the exact name in the <routes></routes>.
|
||||
|
||||
Based on your analysis, provide your response in the following JSON formats if you decide to match any route:
|
||||
{"route": "route_name"}
|
||||
"#;
|
||||
|
||||
let routes_str = r#"
|
||||
{
|
||||
"gpt-4o": [
|
||||
{"name": "Image generation", "description": "generating image"}
|
||||
]
|
||||
}
|
||||
"#;
|
||||
let llm_routes =
|
||||
serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
|
||||
let routing_model = "test-model".to_string();
|
||||
let router = RouterModelV1::new(llm_routes, routing_model, 230);
|
||||
|
||||
let conversation_str = r#"
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "hi"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Hello! How can I assist you today?"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "given the image In style of Andy Warhol"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "ok here is the image"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "pls give me another image about Bart and Lisa"
|
||||
}
|
||||
]
|
||||
"#;
|
||||
|
||||
let conversation: Vec<Message> = serde_json::from_str(conversation_str).unwrap();
|
||||
|
||||
let req = router.generate_request(&conversation, &None);
|
||||
|
||||
let prompt = req.messages[0].content.extract_text();
|
||||
|
||||
assert_eq!(expected_prompt, prompt);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_non_text_input() {
|
||||
let expected_prompt = r#"
|
||||
You are a helpful assistant designed to find the best suited route.
|
||||
You are provided with route description within <routes></routes> XML tags:
|
||||
<routes>
|
||||
[{"name":"Image generation","description":"generating image"}]
|
||||
</routes>
|
||||
|
||||
<conversation>
|
||||
[{"role":"user","content":"hi"},{"role":"assistant","content":"Hello! How can I assist you today?"},{"role":"user","content":"given the image In style of Andy Warhol, portrait of Bart and Lisa Simpson"}]
|
||||
</conversation>
|
||||
|
||||
Your task is to decide which route is best suit with user intent on the conversation in <conversation></conversation> XML tags. Follow the instruction:
|
||||
1. If the latest intent from user is irrelevant or user intent is full filled, response with other route {"route": "other"}.
|
||||
2. You must analyze the route descriptions and find the best match route for user latest intent.
|
||||
3. You only response the name of the route that best matches the user's request, use the exact name in the <routes></routes>.
|
||||
|
||||
Based on your analysis, provide your response in the following JSON formats if you decide to match any route:
|
||||
{"route": "route_name"}
|
||||
"#;
|
||||
let routes_str = r#"
|
||||
{
|
||||
"gpt-4o": [
|
||||
{"name": "Image generation", "description": "generating image"}
|
||||
]
|
||||
}
|
||||
"#;
|
||||
let llm_routes =
|
||||
serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
|
||||
let routing_model = "test-model".to_string();
|
||||
let router = RouterModelV1::new(llm_routes, routing_model, usize::MAX);
|
||||
|
||||
let conversation_str = r#"
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "hi"
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "https://example.com/image.png"
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Hello! How can I assist you today?"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "given the image In style of Andy Warhol, portrait of Bart and Lisa Simpson"
|
||||
}
|
||||
]
|
||||
"#;
|
||||
let conversation: Vec<Message> = serde_json::from_str(conversation_str).unwrap();
|
||||
|
||||
let req = router.generate_request(&conversation, &None);
|
||||
|
||||
let prompt = req.messages[0].content.extract_text();
|
||||
|
||||
assert_eq!(expected_prompt, prompt);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_skip_tool_call() {
|
||||
let expected_prompt = r#"
|
||||
You are a helpful assistant designed to find the best suited route.
|
||||
You are provided with route description within <routes></routes> XML tags:
|
||||
<routes>
|
||||
[{"name":"Image generation","description":"generating image"}]
|
||||
</routes>
|
||||
|
||||
<conversation>
|
||||
[{"role":"user","content":"What's the weather like in Tokyo?"},{"role":"assistant","content":"The current weather in Tokyo is 22°C and sunny."},{"role":"user","content":"What about in New York?"}]
|
||||
</conversation>
|
||||
|
||||
Your task is to decide which route is best suit with user intent on the conversation in <conversation></conversation> XML tags. Follow the instruction:
|
||||
1. If the latest intent from user is irrelevant or user intent is full filled, response with other route {"route": "other"}.
|
||||
2. You must analyze the route descriptions and find the best match route for user latest intent.
|
||||
3. You only response the name of the route that best matches the user's request, use the exact name in the <routes></routes>.
|
||||
|
||||
Based on your analysis, provide your response in the following JSON formats if you decide to match any route:
|
||||
{"route": "route_name"}
|
||||
"#;
|
||||
let routes_str = r#"
|
||||
{
|
||||
"gpt-4o": [
|
||||
{"name": "Image generation", "description": "generating image"}
|
||||
]
|
||||
}
|
||||
"#;
|
||||
let llm_routes =
|
||||
serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
|
||||
let routing_model = "test-model".to_string();
|
||||
let router = RouterModelV1::new(llm_routes, routing_model, usize::MAX);
|
||||
|
||||
let conversation_str = r#"
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What's the weather like in Tokyo?"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "toolcall-abc123",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": "{ \"location\": \"Tokyo\" }"
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": "toolcall-abc123",
|
||||
"content": "{ \"temperature\": \"22°C\", \"condition\": \"Sunny\" }"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "The current weather in Tokyo is 22°C and sunny."
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What about in New York?"
|
||||
}
|
||||
]
|
||||
"#;
|
||||
|
||||
// expects conversation to look like this
|
||||
|
||||
// [
|
||||
// {
|
||||
// "role": "user",
|
||||
// "content": "What's the weather like in Tokyo?"
|
||||
// },
|
||||
// {
|
||||
// "role": "assistant",
|
||||
// "content": "The current weather in Tokyo is 22°C and sunny."
|
||||
// },
|
||||
// {
|
||||
// "role": "user",
|
||||
// "content": "What about in New York?"
|
||||
// }
|
||||
// ]
|
||||
|
||||
let conversation: Vec<Message> = serde_json::from_str(conversation_str).unwrap();
|
||||
|
||||
let req: ChatCompletionsRequest = router.generate_request(&conversation, &None);
|
||||
|
||||
let prompt = req.messages[0].content.extract_text();
|
||||
|
||||
assert_eq!(expected_prompt, prompt);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_response() {
|
||||
let routes_str = r#"
|
||||
{
|
||||
"gpt-4o": [
|
||||
{"name": "Image generation", "description": "generating image"}
|
||||
]
|
||||
}
|
||||
"#;
|
||||
let llm_routes =
|
||||
serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
|
||||
|
||||
let router = RouterModelV1::new(llm_routes, "test-model".to_string(), 2000);
|
||||
|
||||
// Case 1: Valid JSON with non-empty route
|
||||
let input = r#"{"route": "Image generation"}"#;
|
||||
let result = router.parse_response(input, &None).unwrap();
|
||||
assert_eq!(
|
||||
result,
|
||||
Some(("Image generation".to_string(), "gpt-4o".to_string()))
|
||||
);
|
||||
|
||||
// Case 2: Valid JSON with empty route
|
||||
let input = r#"{"route": ""}"#;
|
||||
let result = router.parse_response(input, &None).unwrap();
|
||||
assert_eq!(result, None);
|
||||
|
||||
// Case 3: Valid JSON with null route
|
||||
let input = r#"{"route": null}"#;
|
||||
let result = router.parse_response(input, &None).unwrap();
|
||||
assert_eq!(result, None);
|
||||
|
||||
// Case 4: JSON missing route field
|
||||
let input = r#"{}"#;
|
||||
let result = router.parse_response(input, &None).unwrap();
|
||||
assert_eq!(result, None);
|
||||
|
||||
// Case 4.1: empty string
|
||||
let input = r#""#;
|
||||
let result = router.parse_response(input, &None).unwrap();
|
||||
assert_eq!(result, None);
|
||||
|
||||
// Case 5: Malformed JSON
|
||||
let input = r#"{"route": "route1""#; // missing closing }
|
||||
let result = router.parse_response(input, &None);
|
||||
assert!(result.is_err());
|
||||
|
||||
// Case 6: Single quotes and \n in JSON
|
||||
let input = "{'route': 'Image generation'}\\n";
|
||||
let result = router.parse_response(input, &None).unwrap();
|
||||
assert_eq!(
|
||||
result,
|
||||
Some(("Image generation".to_string(), "gpt-4o".to_string()))
|
||||
);
|
||||
|
||||
// Case 7: Code block marker
|
||||
let input = "```json\n{\"route\": \"Image generation\"}\n```";
|
||||
let result = router.parse_response(input, &None).unwrap();
|
||||
assert_eq!(
|
||||
result,
|
||||
Some(("Image generation".to_string(), "gpt-4o".to_string()))
|
||||
);
|
||||
}
|
||||
}
|
||||
82
crates/brightstaff/src/session_cache/memory.rs
Normal file
82
crates/brightstaff/src/session_cache/memory.rs
Normal file
|
|
@ -0,0 +1,82 @@
|
|||
use std::{
|
||||
num::NonZeroUsize,
|
||||
sync::Arc,
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use lru::LruCache;
|
||||
use tokio::sync::Mutex;
|
||||
use tracing::info;
|
||||
|
||||
use super::{CachedRoute, SessionCache};
|
||||
|
||||
type CacheStore = Mutex<LruCache<String, (CachedRoute, Instant, Duration)>>;
|
||||
|
||||
pub struct MemorySessionCache {
|
||||
store: Arc<CacheStore>,
|
||||
}
|
||||
|
||||
impl MemorySessionCache {
|
||||
pub fn new(max_entries: usize) -> Self {
|
||||
let capacity = NonZeroUsize::new(max_entries)
|
||||
.unwrap_or_else(|| NonZeroUsize::new(10_000).expect("10_000 is non-zero"));
|
||||
let store = Arc::new(Mutex::new(LruCache::new(capacity)));
|
||||
|
||||
// Spawn a background task to evict TTL-expired entries every 5 minutes.
|
||||
let store_clone = Arc::clone(&store);
|
||||
tokio::spawn(async move {
|
||||
let mut interval = tokio::time::interval(Duration::from_secs(300));
|
||||
loop {
|
||||
interval.tick().await;
|
||||
Self::evict_expired(&store_clone).await;
|
||||
}
|
||||
});
|
||||
|
||||
Self { store }
|
||||
}
|
||||
|
||||
async fn evict_expired(store: &CacheStore) {
|
||||
let mut cache = store.lock().await;
|
||||
let expired: Vec<String> = cache
|
||||
.iter()
|
||||
.filter(|(_, (_, inserted_at, ttl))| inserted_at.elapsed() >= *ttl)
|
||||
.map(|(k, _)| k.clone())
|
||||
.collect();
|
||||
let removed = expired.len();
|
||||
for key in &expired {
|
||||
cache.pop(key.as_str());
|
||||
}
|
||||
if removed > 0 {
|
||||
info!(
|
||||
removed = removed,
|
||||
remaining = cache.len(),
|
||||
"cleaned up expired session cache entries"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl SessionCache for MemorySessionCache {
|
||||
async fn get(&self, key: &str) -> Option<CachedRoute> {
|
||||
let mut cache = self.store.lock().await;
|
||||
if let Some((route, inserted_at, ttl)) = cache.get(key) {
|
||||
if inserted_at.elapsed() < *ttl {
|
||||
return Some(route.clone());
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
async fn put(&self, key: &str, route: CachedRoute, ttl: Duration) {
|
||||
self.store
|
||||
.lock()
|
||||
.await
|
||||
.put(key.to_string(), (route, Instant::now(), ttl));
|
||||
}
|
||||
|
||||
async fn remove(&self, key: &str) {
|
||||
self.store.lock().await.pop(key);
|
||||
}
|
||||
}
|
||||
70
crates/brightstaff/src/session_cache/mod.rs
Normal file
70
crates/brightstaff/src/session_cache/mod.rs
Normal file
|
|
@ -0,0 +1,70 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use common::configuration::Configuration;
|
||||
use std::time::Duration;
|
||||
use tracing::{debug, info};
|
||||
|
||||
pub mod memory;
|
||||
pub mod redis;
|
||||
|
||||
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
|
||||
pub struct CachedRoute {
|
||||
pub model_name: String,
|
||||
pub route_name: Option<String>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait SessionCache: Send + Sync {
|
||||
/// Look up a cached routing decision by key.
|
||||
async fn get(&self, key: &str) -> Option<CachedRoute>;
|
||||
|
||||
/// Store a routing decision in the session cache with the given TTL.
|
||||
async fn put(&self, key: &str, route: CachedRoute, ttl: Duration);
|
||||
|
||||
/// Remove a cached routing decision by key.
|
||||
async fn remove(&self, key: &str);
|
||||
}
|
||||
|
||||
/// Initialize the session cache backend from config.
|
||||
/// Defaults to the in-memory backend when no `session_cache` block is configured.
|
||||
pub async fn init_session_cache(
|
||||
config: &Configuration,
|
||||
) -> Result<Arc<dyn SessionCache>, Box<dyn std::error::Error + Send + Sync>> {
|
||||
use common::configuration::SessionCacheType;
|
||||
|
||||
let session_max_entries = config.routing.as_ref().and_then(|r| r.session_max_entries);
|
||||
|
||||
const DEFAULT_SESSION_MAX_ENTRIES: usize = 10_000;
|
||||
const MAX_SESSION_MAX_ENTRIES: usize = 10_000;
|
||||
|
||||
let max_entries = session_max_entries
|
||||
.unwrap_or(DEFAULT_SESSION_MAX_ENTRIES)
|
||||
.min(MAX_SESSION_MAX_ENTRIES);
|
||||
|
||||
let cache_config = config
|
||||
.routing
|
||||
.as_ref()
|
||||
.and_then(|r| r.session_cache.as_ref());
|
||||
|
||||
let cache_type = cache_config
|
||||
.map(|c| &c.cache_type)
|
||||
.unwrap_or(&SessionCacheType::Memory);
|
||||
|
||||
match cache_type {
|
||||
SessionCacheType::Memory => {
|
||||
info!(storage_type = "memory", "initialized session cache");
|
||||
Ok(Arc::new(memory::MemorySessionCache::new(max_entries)))
|
||||
}
|
||||
SessionCacheType::Redis => {
|
||||
let url = cache_config
|
||||
.and_then(|c| c.url.as_ref())
|
||||
.ok_or("session_cache.url is required when type is redis")?;
|
||||
debug!(storage_type = "redis", url = %url, "initializing session cache");
|
||||
let cache = redis::RedisSessionCache::new(url)
|
||||
.await
|
||||
.map_err(|e| format!("failed to connect to Redis session cache: {e}"))?;
|
||||
Ok(Arc::new(cache))
|
||||
}
|
||||
}
|
||||
}
|
||||
48
crates/brightstaff/src/session_cache/redis.rs
Normal file
48
crates/brightstaff/src/session_cache/redis.rs
Normal file
|
|
@ -0,0 +1,48 @@
|
|||
use std::time::Duration;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use redis::aio::MultiplexedConnection;
|
||||
use redis::AsyncCommands;
|
||||
|
||||
use super::{CachedRoute, SessionCache};
|
||||
|
||||
const KEY_PREFIX: &str = "plano:affinity:";
|
||||
|
||||
pub struct RedisSessionCache {
|
||||
conn: MultiplexedConnection,
|
||||
}
|
||||
|
||||
impl RedisSessionCache {
|
||||
pub async fn new(url: &str) -> Result<Self, redis::RedisError> {
|
||||
let client = redis::Client::open(url)?;
|
||||
let conn = client.get_multiplexed_async_connection().await?;
|
||||
Ok(Self { conn })
|
||||
}
|
||||
|
||||
fn make_key(key: &str) -> String {
|
||||
format!("{KEY_PREFIX}{key}")
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl SessionCache for RedisSessionCache {
|
||||
async fn get(&self, key: &str) -> Option<CachedRoute> {
|
||||
let mut conn = self.conn.clone();
|
||||
let value: Option<String> = conn.get(Self::make_key(key)).await.ok()?;
|
||||
value.and_then(|v| serde_json::from_str(&v).ok())
|
||||
}
|
||||
|
||||
async fn put(&self, key: &str, route: CachedRoute, ttl: Duration) {
|
||||
let mut conn = self.conn.clone();
|
||||
let Ok(json) = serde_json::to_string(&route) else {
|
||||
return;
|
||||
};
|
||||
let ttl_secs = ttl.as_secs().max(1);
|
||||
let _: Result<(), _> = conn.set_ex(Self::make_key(key), json, ttl_secs).await;
|
||||
}
|
||||
|
||||
async fn remove(&self, key: &str) {
|
||||
let mut conn = self.conn.clone();
|
||||
let _: Result<(), _> = conn.del(Self::make_key(key)).await;
|
||||
}
|
||||
}
|
||||
|
|
@ -1250,7 +1250,7 @@ impl TextBasedSignalAnalyzer {
|
|||
let mut repair_phrases = Vec::new();
|
||||
let mut user_turn_count = 0;
|
||||
|
||||
for (i, role, norm_msg) in normalized_messages {
|
||||
for (pos, (i, role, norm_msg)) in normalized_messages.iter().enumerate() {
|
||||
if *role != Role::User {
|
||||
continue;
|
||||
}
|
||||
|
|
@ -1274,10 +1274,13 @@ impl TextBasedSignalAnalyzer {
|
|||
}
|
||||
}
|
||||
|
||||
// Only check for semantic similarity if no pattern matched
|
||||
if !found_in_turn && *i >= 2 {
|
||||
// Find previous user message
|
||||
for j in (0..*i).rev() {
|
||||
// Only check for semantic similarity if no pattern matched. Walk
|
||||
// backwards through the *normalized* list (not the original
|
||||
// conversation indices, which may be non-contiguous because
|
||||
// messages without extractable text are filtered out) to find the
|
||||
// most recent prior user message.
|
||||
if !found_in_turn && pos >= 1 {
|
||||
for j in (0..pos).rev() {
|
||||
let (_, prev_role, prev_norm_msg) = &normalized_messages[j];
|
||||
if *prev_role == Role::User {
|
||||
if self.is_similar_rephrase(norm_msg, prev_norm_msg) {
|
||||
|
|
@ -2199,6 +2202,68 @@ mod tests {
|
|||
println!("test_follow_up_detection took: {:?}", start.elapsed());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_follow_up_does_not_panic_with_filtered_messages() {
|
||||
// Regression test: the preprocessing pipeline filters out messages
|
||||
// without extractable text (tool calls, tool results, empty content).
|
||||
// The stored tuple index `i` is the ORIGINAL-conversation index, so
|
||||
// once anything is filtered out, `i` no longer matches the position
|
||||
// inside `normalized_messages`. The old code used `*i` to index into
|
||||
// `normalized_messages`, which panicked with "index out of bounds"
|
||||
// when a user message appeared after filtered entries.
|
||||
let analyzer = TextBasedSignalAnalyzer::new();
|
||||
let messages = vec![
|
||||
Message {
|
||||
role: Role::User,
|
||||
content: Some(hermesllm::apis::openai::MessageContent::Text(
|
||||
"first question".to_string(),
|
||||
)),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
// Assistant message with no text content (e.g. tool call) — filtered out.
|
||||
Message {
|
||||
role: Role::Assistant,
|
||||
content: None,
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
// Tool-role message with no extractable text — filtered out.
|
||||
Message {
|
||||
role: Role::Tool,
|
||||
content: None,
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
Message {
|
||||
role: Role::Assistant,
|
||||
content: Some(hermesllm::apis::openai::MessageContent::Text(
|
||||
"some answer".to_string(),
|
||||
)),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
// Rephrased user turn — original index 4, but after filtering
|
||||
// only 3 messages remain in `normalized_messages` before it.
|
||||
Message {
|
||||
role: Role::User,
|
||||
content: Some(hermesllm::apis::openai::MessageContent::Text(
|
||||
"first question please".to_string(),
|
||||
)),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
];
|
||||
|
||||
// Must not panic — exercises the full analyze pipeline.
|
||||
let _report = analyzer.analyze(&messages);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_frustration_detection() {
|
||||
let start = Instant::now();
|
||||
|
|
|
|||
|
|
@ -16,10 +16,131 @@ use tracing_opentelemetry::OpenTelemetrySpanExt;
|
|||
use crate::handlers::agents::pipeline::{PipelineError, PipelineProcessor};
|
||||
|
||||
const STREAM_BUFFER_SIZE: usize = 16;
|
||||
/// Cap on accumulated response bytes kept for usage extraction.
|
||||
/// Most chat responses are well under this; pathological ones are dropped without
|
||||
/// affecting pass-through streaming to the client.
|
||||
const USAGE_BUFFER_MAX: usize = 2 * 1024 * 1024;
|
||||
use crate::signals::{InteractionQuality, SignalAnalyzer, TextBasedSignalAnalyzer, FLAG_MARKER};
|
||||
use crate::tracing::{llm, set_service_name, signals as signal_constants};
|
||||
use hermesllm::apis::openai::Message;
|
||||
|
||||
/// Parsed usage + resolved-model details from a provider response.
|
||||
#[derive(Debug, Default, Clone)]
|
||||
struct ExtractedUsage {
|
||||
prompt_tokens: Option<i64>,
|
||||
completion_tokens: Option<i64>,
|
||||
total_tokens: Option<i64>,
|
||||
cached_input_tokens: Option<i64>,
|
||||
cache_creation_tokens: Option<i64>,
|
||||
reasoning_tokens: Option<i64>,
|
||||
/// The model the upstream actually used. For router aliases (e.g.
|
||||
/// `router:software-engineering`), this differs from the request model.
|
||||
resolved_model: Option<String>,
|
||||
}
|
||||
|
||||
impl ExtractedUsage {
|
||||
fn is_empty(&self) -> bool {
|
||||
self.prompt_tokens.is_none()
|
||||
&& self.completion_tokens.is_none()
|
||||
&& self.total_tokens.is_none()
|
||||
&& self.resolved_model.is_none()
|
||||
}
|
||||
|
||||
fn from_json(value: &serde_json::Value) -> Self {
|
||||
let mut out = Self::default();
|
||||
if let Some(model) = value.get("model").and_then(|v| v.as_str()) {
|
||||
if !model.is_empty() {
|
||||
out.resolved_model = Some(model.to_string());
|
||||
}
|
||||
}
|
||||
if let Some(u) = value.get("usage") {
|
||||
// OpenAI-shape usage
|
||||
out.prompt_tokens = u.get("prompt_tokens").and_then(|v| v.as_i64());
|
||||
out.completion_tokens = u.get("completion_tokens").and_then(|v| v.as_i64());
|
||||
out.total_tokens = u.get("total_tokens").and_then(|v| v.as_i64());
|
||||
out.cached_input_tokens = u
|
||||
.get("prompt_tokens_details")
|
||||
.and_then(|d| d.get("cached_tokens"))
|
||||
.and_then(|v| v.as_i64());
|
||||
out.reasoning_tokens = u
|
||||
.get("completion_tokens_details")
|
||||
.and_then(|d| d.get("reasoning_tokens"))
|
||||
.and_then(|v| v.as_i64());
|
||||
|
||||
// Anthropic-shape fallbacks
|
||||
if out.prompt_tokens.is_none() {
|
||||
out.prompt_tokens = u.get("input_tokens").and_then(|v| v.as_i64());
|
||||
}
|
||||
if out.completion_tokens.is_none() {
|
||||
out.completion_tokens = u.get("output_tokens").and_then(|v| v.as_i64());
|
||||
}
|
||||
if out.total_tokens.is_none() {
|
||||
if let (Some(p), Some(c)) = (out.prompt_tokens, out.completion_tokens) {
|
||||
out.total_tokens = Some(p + c);
|
||||
}
|
||||
}
|
||||
if out.cached_input_tokens.is_none() {
|
||||
out.cached_input_tokens = u.get("cache_read_input_tokens").and_then(|v| v.as_i64());
|
||||
}
|
||||
if out.cached_input_tokens.is_none() {
|
||||
out.cached_input_tokens =
|
||||
u.get("cached_content_token_count").and_then(|v| v.as_i64());
|
||||
}
|
||||
out.cache_creation_tokens = u
|
||||
.get("cache_creation_input_tokens")
|
||||
.and_then(|v| v.as_i64());
|
||||
if out.reasoning_tokens.is_none() {
|
||||
out.reasoning_tokens = u.get("thoughts_token_count").and_then(|v| v.as_i64());
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
}
|
||||
|
||||
/// Try to pull usage out of an accumulated response body.
|
||||
/// Handles both a single JSON object (non-streaming) and SSE streams where the
|
||||
/// final `data: {...}` event carries the `usage` field.
|
||||
fn extract_usage_from_bytes(buf: &[u8]) -> ExtractedUsage {
|
||||
if buf.is_empty() {
|
||||
return ExtractedUsage::default();
|
||||
}
|
||||
|
||||
// Fast path: full-body JSON (non-streaming).
|
||||
if let Ok(value) = serde_json::from_slice::<serde_json::Value>(buf) {
|
||||
let u = ExtractedUsage::from_json(&value);
|
||||
if !u.is_empty() {
|
||||
return u;
|
||||
}
|
||||
}
|
||||
|
||||
// SSE path: scan from the end for a `data:` line containing a usage object.
|
||||
let text = match std::str::from_utf8(buf) {
|
||||
Ok(t) => t,
|
||||
Err(_) => return ExtractedUsage::default(),
|
||||
};
|
||||
for line in text.lines().rev() {
|
||||
let trimmed = line.trim_start();
|
||||
let payload = match trimmed.strip_prefix("data:") {
|
||||
Some(p) => p.trim_start(),
|
||||
None => continue,
|
||||
};
|
||||
if payload == "[DONE]" || payload.is_empty() {
|
||||
continue;
|
||||
}
|
||||
if !payload.contains("\"usage\"") {
|
||||
continue;
|
||||
}
|
||||
if let Ok(value) = serde_json::from_str::<serde_json::Value>(payload) {
|
||||
let u = ExtractedUsage::from_json(&value);
|
||||
if !u.is_empty() {
|
||||
return u;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ExtractedUsage::default()
|
||||
}
|
||||
|
||||
/// Trait for processing streaming chunks
|
||||
/// Implementors can inject custom logic during streaming (e.g., hallucination detection, logging)
|
||||
pub trait StreamProcessor: Send + 'static {
|
||||
|
|
@ -60,6 +181,10 @@ pub struct ObservableStreamProcessor {
|
|||
start_time: Instant,
|
||||
time_to_first_token: Option<u128>,
|
||||
messages: Option<Vec<Message>>,
|
||||
/// Accumulated response bytes used only for best-effort usage extraction
|
||||
/// on `on_complete`. Capped at `USAGE_BUFFER_MAX`; excess chunks are dropped
|
||||
/// from the buffer (they still pass through to the client).
|
||||
response_buffer: Vec<u8>,
|
||||
}
|
||||
|
||||
impl ObservableStreamProcessor {
|
||||
|
|
@ -93,6 +218,7 @@ impl ObservableStreamProcessor {
|
|||
start_time,
|
||||
time_to_first_token: None,
|
||||
messages,
|
||||
response_buffer: Vec::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -101,6 +227,13 @@ impl StreamProcessor for ObservableStreamProcessor {
|
|||
fn process_chunk(&mut self, chunk: Bytes) -> Result<Option<Bytes>, String> {
|
||||
self.total_bytes += chunk.len();
|
||||
self.chunk_count += 1;
|
||||
// Accumulate for best-effort usage extraction; drop further chunks once
|
||||
// the cap is reached so we don't retain huge response bodies in memory.
|
||||
if self.response_buffer.len() < USAGE_BUFFER_MAX {
|
||||
let remaining = USAGE_BUFFER_MAX - self.response_buffer.len();
|
||||
let take = chunk.len().min(remaining);
|
||||
self.response_buffer.extend_from_slice(&chunk[..take]);
|
||||
}
|
||||
Ok(Some(chunk))
|
||||
}
|
||||
|
||||
|
|
@ -124,6 +257,52 @@ impl StreamProcessor for ObservableStreamProcessor {
|
|||
);
|
||||
}
|
||||
|
||||
// Record total duration on the span for the observability console.
|
||||
let duration_ms = self.start_time.elapsed().as_millis() as i64;
|
||||
{
|
||||
let span = tracing::Span::current();
|
||||
let otel_context = span.context();
|
||||
let otel_span = otel_context.span();
|
||||
otel_span.set_attribute(KeyValue::new(llm::DURATION_MS, duration_ms));
|
||||
otel_span.set_attribute(KeyValue::new(llm::RESPONSE_BYTES, self.total_bytes as i64));
|
||||
}
|
||||
|
||||
// Best-effort usage extraction + emission (works for both streaming
|
||||
// SSE and non-streaming JSON responses that include a `usage` object).
|
||||
let usage = extract_usage_from_bytes(&self.response_buffer);
|
||||
if !usage.is_empty() {
|
||||
let span = tracing::Span::current();
|
||||
let otel_context = span.context();
|
||||
let otel_span = otel_context.span();
|
||||
if let Some(v) = usage.prompt_tokens {
|
||||
otel_span.set_attribute(KeyValue::new(llm::PROMPT_TOKENS, v));
|
||||
}
|
||||
if let Some(v) = usage.completion_tokens {
|
||||
otel_span.set_attribute(KeyValue::new(llm::COMPLETION_TOKENS, v));
|
||||
}
|
||||
if let Some(v) = usage.total_tokens {
|
||||
otel_span.set_attribute(KeyValue::new(llm::TOTAL_TOKENS, v));
|
||||
}
|
||||
if let Some(v) = usage.cached_input_tokens {
|
||||
otel_span.set_attribute(KeyValue::new(llm::CACHED_INPUT_TOKENS, v));
|
||||
}
|
||||
if let Some(v) = usage.cache_creation_tokens {
|
||||
otel_span.set_attribute(KeyValue::new(llm::CACHE_CREATION_TOKENS, v));
|
||||
}
|
||||
if let Some(v) = usage.reasoning_tokens {
|
||||
otel_span.set_attribute(KeyValue::new(llm::REASONING_TOKENS, v));
|
||||
}
|
||||
// Override `llm.model` with the model the upstream actually ran
|
||||
// (e.g. `openai-gpt-5.4` resolved from `router:software-engineering`).
|
||||
// Cost lookup keys off the real model, not the alias.
|
||||
if let Some(resolved) = usage.resolved_model.clone() {
|
||||
otel_span.set_attribute(KeyValue::new(llm::MODEL_NAME, resolved));
|
||||
}
|
||||
}
|
||||
// Release the buffered bytes early; nothing downstream needs them.
|
||||
self.response_buffer.clear();
|
||||
self.response_buffer.shrink_to_fit();
|
||||
|
||||
// Analyze signals if messages are available and record as span attributes
|
||||
if let Some(ref messages) = self.messages {
|
||||
let analyzer: Box<dyn SignalAnalyzer> = Box::new(TextBasedSignalAnalyzer::new());
|
||||
|
|
@ -404,3 +583,55 @@ pub fn truncate_message(message: &str, max_length: usize) -> String {
|
|||
message.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod usage_extraction_tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn non_streaming_openai_with_cached() {
|
||||
let body = br#"{"id":"x","model":"gpt-4o","choices":[],"usage":{"prompt_tokens":12,"completion_tokens":34,"total_tokens":46,"prompt_tokens_details":{"cached_tokens":5}}}"#;
|
||||
let u = extract_usage_from_bytes(body);
|
||||
assert_eq!(u.prompt_tokens, Some(12));
|
||||
assert_eq!(u.completion_tokens, Some(34));
|
||||
assert_eq!(u.total_tokens, Some(46));
|
||||
assert_eq!(u.cached_input_tokens, Some(5));
|
||||
assert_eq!(u.reasoning_tokens, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn non_streaming_anthropic_with_cache_creation() {
|
||||
let body = br#"{"id":"x","model":"claude","usage":{"input_tokens":100,"output_tokens":50,"cache_creation_input_tokens":20,"cache_read_input_tokens":30}}"#;
|
||||
let u = extract_usage_from_bytes(body);
|
||||
assert_eq!(u.prompt_tokens, Some(100));
|
||||
assert_eq!(u.completion_tokens, Some(50));
|
||||
assert_eq!(u.total_tokens, Some(150));
|
||||
assert_eq!(u.cached_input_tokens, Some(30));
|
||||
assert_eq!(u.cache_creation_tokens, Some(20));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn streaming_openai_final_chunk_has_usage() {
|
||||
let sse = b"data: {\"choices\":[{\"delta\":{\"content\":\"hi\"}}]}
|
||||
|
||||
data: {\"choices\":[{\"delta\":{}, \"finish_reason\":\"stop\"}],\"usage\":{\"prompt_tokens\":7,\"completion_tokens\":3,\"total_tokens\":10}}
|
||||
|
||||
data: [DONE]
|
||||
|
||||
";
|
||||
let u = extract_usage_from_bytes(sse);
|
||||
assert_eq!(u.prompt_tokens, Some(7));
|
||||
assert_eq!(u.completion_tokens, Some(3));
|
||||
assert_eq!(u.total_tokens, Some(10));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn empty_returns_default() {
|
||||
assert!(extract_usage_from_bytes(b"").is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn no_usage_in_body_returns_default() {
|
||||
assert!(extract_usage_from_bytes(br#"{"ok":true}"#).is_empty());
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -80,6 +80,18 @@ pub mod llm {
|
|||
/// Total tokens used (prompt + completion)
|
||||
pub const TOTAL_TOKENS: &str = "llm.usage.total_tokens";
|
||||
|
||||
/// Tokens served from a prompt cache read
|
||||
/// (OpenAI `prompt_tokens_details.cached_tokens`, Anthropic `cache_read_input_tokens`,
|
||||
/// Google `cached_content_token_count`)
|
||||
pub const CACHED_INPUT_TOKENS: &str = "llm.usage.cached_input_tokens";
|
||||
|
||||
/// Tokens used to write a prompt cache entry (Anthropic `cache_creation_input_tokens`)
|
||||
pub const CACHE_CREATION_TOKENS: &str = "llm.usage.cache_creation_tokens";
|
||||
|
||||
/// Reasoning tokens for reasoning models
|
||||
/// (OpenAI `completion_tokens_details.reasoning_tokens`, Google `thoughts_token_count`)
|
||||
pub const REASONING_TOKENS: &str = "llm.usage.reasoning_tokens";
|
||||
|
||||
/// Temperature parameter used
|
||||
pub const TEMPERATURE: &str = "llm.temperature";
|
||||
|
||||
|
|
@ -119,6 +131,22 @@ pub mod routing {
|
|||
pub const SELECTION_REASON: &str = "routing.selection_reason";
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Span Attributes - Plano-specific
|
||||
// =============================================================================
|
||||
|
||||
/// Attributes specific to Plano (session affinity, routing decisions).
|
||||
pub mod plano {
|
||||
/// Session identifier propagated via the `x-model-affinity` header.
|
||||
/// Absent when the client did not send the header.
|
||||
pub const SESSION_ID: &str = "plano.session_id";
|
||||
|
||||
/// Matched route name from routing (e.g. "code", "summarization",
|
||||
/// "software-engineering"). Absent when the client routed directly
|
||||
/// to a concrete model.
|
||||
pub const ROUTE_NAME: &str = "plano.route.name";
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Span Attributes - Error Handling
|
||||
// =============================================================================
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ mod init;
|
|||
mod service_name_exporter;
|
||||
|
||||
pub use constants::{
|
||||
error, http, llm, operation_component, routing, signals, OperationNameBuilder,
|
||||
error, http, llm, operation_component, plano, routing, signals, OperationNameBuilder,
|
||||
};
|
||||
pub use custom_attributes::collect_custom_trace_attributes;
|
||||
pub use init::init_tracer;
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue