mirror of
https://github.com/katanemo/plano.git
synced 2026-05-11 16:52:41 +02:00
Merge origin/main into musa/chatgpt-subscription
This commit is contained in:
commit
6f67048c04
118 changed files with 11627 additions and 2194 deletions
1757
crates/Cargo.lock
generated
1757
crates/Cargo.lock
generated
File diff suppressed because it is too large
Load diff
|
|
@ -26,6 +26,8 @@ opentelemetry-stdout = "0.31"
|
|||
opentelemetry_sdk = { version = "0.31", features = ["rt-tokio"] }
|
||||
pretty_assertions = "1.4.1"
|
||||
rand = "0.9.2"
|
||||
lru = "0.12"
|
||||
redis = { version = "0.27", features = ["tokio-comp"] }
|
||||
reqwest = { version = "0.12.15", features = ["stream"] }
|
||||
serde = { version = "1.0.219", features = ["derive"] }
|
||||
serde_json = "1.0.140"
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@ use common::configuration::{Agent, FilterPipeline, Listener, ModelAlias, SpanAtt
|
|||
use common::llm_providers::LlmProviders;
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
use crate::router::llm::RouterService;
|
||||
use crate::router::orchestrator::OrchestratorService;
|
||||
use crate::state::StateStorage;
|
||||
|
||||
|
|
@ -14,7 +13,6 @@ use crate::state::StateStorage;
|
|||
/// Instead of cloning 8+ individual `Arc`s per connection, a single
|
||||
/// `Arc<AppState>` is cloned once and passed to the request handler.
|
||||
pub struct AppState {
|
||||
pub router_service: Arc<RouterService>,
|
||||
pub orchestrator_service: Arc<OrchestratorService>,
|
||||
pub model_aliases: Option<HashMap<String, ModelAlias>>,
|
||||
pub llm_providers: Arc<RwLock<LlmProviders>>,
|
||||
|
|
|
|||
|
|
@ -177,6 +177,7 @@ mod tests {
|
|||
"http://localhost:8080".to_string(),
|
||||
"test-model".to_string(),
|
||||
"plano-orchestrator".to_string(),
|
||||
crate::router::orchestrator_model_v1::MAX_TOKEN_LEN,
|
||||
))
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ mod tests {
|
|||
"http://localhost:8080".to_string(),
|
||||
"test-model".to_string(),
|
||||
"plano-orchestrator".to_string(),
|
||||
crate::router::orchestrator_model_v1::MAX_TOKEN_LEN,
|
||||
))
|
||||
}
|
||||
|
||||
|
|
@ -147,8 +148,8 @@ mod tests {
|
|||
|
||||
#[tokio::test]
|
||||
async fn test_error_handling_flow() {
|
||||
let router_service = create_test_orchestrator_service();
|
||||
let agent_selector = AgentSelector::new(router_service);
|
||||
let orchestrator_service = create_test_orchestrator_service();
|
||||
let agent_selector = AgentSelector::new(orchestrator_service);
|
||||
|
||||
// Test listener not found
|
||||
let result = agent_selector.find_listener(Some("nonexistent"), &[]);
|
||||
|
|
|
|||
|
|
@ -22,7 +22,6 @@ pub(crate) mod model_selection;
|
|||
|
||||
use crate::app_state::AppState;
|
||||
use crate::handlers::agents::pipeline::PipelineProcessor;
|
||||
use crate::handlers::extract_or_generate_traceparent;
|
||||
use crate::handlers::extract_request_id;
|
||||
use crate::handlers::full;
|
||||
use crate::state::response_state_processor::ResponsesStateProcessor;
|
||||
|
|
@ -34,7 +33,8 @@ use crate::streaming::{
|
|||
ObservableStreamProcessor, StreamProcessor,
|
||||
};
|
||||
use crate::tracing::{
|
||||
collect_custom_trace_attributes, llm as tracing_llm, operation_component, set_service_name,
|
||||
collect_custom_trace_attributes, llm as tracing_llm, operation_component,
|
||||
plano as tracing_plano, set_service_name,
|
||||
};
|
||||
use model_selection::router_chat_get_upstream_model;
|
||||
|
||||
|
|
@ -92,22 +92,47 @@ async fn llm_chat_inner(
|
|||
}
|
||||
});
|
||||
|
||||
let traceparent = extract_or_generate_traceparent(&request_headers);
|
||||
|
||||
// Session pinning: extract session ID and check cache before routing
|
||||
let session_id: Option<String> = request_headers
|
||||
.get(MODEL_AFFINITY_HEADER)
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.map(|s| s.to_string());
|
||||
let pinned_model: Option<String> = if let Some(ref sid) = session_id {
|
||||
let tenant_id: Option<String> = state
|
||||
.orchestrator_service
|
||||
.tenant_header()
|
||||
.and_then(|hdr| request_headers.get(hdr))
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(|s| s.to_string());
|
||||
let cached_route = if let Some(ref sid) = session_id {
|
||||
state
|
||||
.router_service
|
||||
.get_cached_route(sid)
|
||||
.orchestrator_service
|
||||
.get_cached_route(sid, tenant_id.as_deref())
|
||||
.await
|
||||
.map(|c| c.model_name)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let (pinned_model, pinned_route_name): (Option<String>, Option<String>) = match cached_route {
|
||||
Some(c) => (Some(c.model_name), c.route_name),
|
||||
None => (None, None),
|
||||
};
|
||||
|
||||
// Record session id on the LLM span for the observability console.
|
||||
if let Some(ref sid) = session_id {
|
||||
get_active_span(|span| {
|
||||
span.set_attribute(opentelemetry::KeyValue::new(
|
||||
tracing_plano::SESSION_ID,
|
||||
sid.clone(),
|
||||
));
|
||||
});
|
||||
}
|
||||
if let Some(ref route_name) = pinned_route_name {
|
||||
get_active_span(|span| {
|
||||
span.set_attribute(opentelemetry::KeyValue::new(
|
||||
tracing_plano::ROUTE_NAME,
|
||||
route_name.clone(),
|
||||
));
|
||||
});
|
||||
}
|
||||
|
||||
let full_qualified_llm_provider_url = format!("{}{}", state.llm_provider_url, request_path);
|
||||
|
||||
|
|
@ -289,9 +314,8 @@ async fn llm_chat_inner(
|
|||
let routing_result = match async {
|
||||
set_service_name(operation_component::ROUTING);
|
||||
router_chat_get_upstream_model(
|
||||
Arc::clone(&state.router_service),
|
||||
Arc::clone(&state.orchestrator_service),
|
||||
client_request,
|
||||
&traceparent,
|
||||
&request_path,
|
||||
&request_id,
|
||||
inline_routing_preferences,
|
||||
|
|
@ -317,11 +341,22 @@ async fn llm_chat_inner(
|
|||
alias_resolved_model.clone()
|
||||
};
|
||||
|
||||
// Cache the routing decision so subsequent requests with the same session ID are pinned
|
||||
// Record route name on the LLM span (only when the orchestrator produced one).
|
||||
if let Some(ref rn) = route_name {
|
||||
if !rn.is_empty() && rn != "none" {
|
||||
get_active_span(|span| {
|
||||
span.set_attribute(opentelemetry::KeyValue::new(
|
||||
tracing_plano::ROUTE_NAME,
|
||||
rn.clone(),
|
||||
));
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(ref sid) = session_id {
|
||||
state
|
||||
.router_service
|
||||
.cache_route(sid.clone(), model.clone(), route_name)
|
||||
.orchestrator_service
|
||||
.cache_route(sid.clone(), tenant_id.as_deref(), model.clone(), route_name)
|
||||
.await;
|
||||
}
|
||||
|
||||
|
|
@ -678,6 +713,36 @@ async fn send_upstream(
|
|||
// Propagate upstream headers and status
|
||||
let response_headers = llm_response.headers().clone();
|
||||
let upstream_status = llm_response.status();
|
||||
|
||||
// Upstream routers (e.g. DigitalOcean Gradient) may return an
|
||||
// `x-model-router-selected-route` header indicating which task-level
|
||||
// route the request was classified into (e.g. "Code Generation"). Surface
|
||||
// it as `plano.route.name` so the obs console's Route hit % panel can
|
||||
// show the breakdown even when Plano's own orchestrator wasn't in the
|
||||
// routing path. Any value from Plano's orchestrator already set earlier
|
||||
// takes precedence — this only fires when the span doesn't already have
|
||||
// a route name.
|
||||
if let Some(upstream_route) = response_headers
|
||||
.get("x-model-router-selected-route")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
{
|
||||
if !upstream_route.is_empty() {
|
||||
get_active_span(|span| {
|
||||
span.set_attribute(opentelemetry::KeyValue::new(
|
||||
crate::tracing::plano::ROUTE_NAME,
|
||||
upstream_route.to_string(),
|
||||
));
|
||||
});
|
||||
}
|
||||
}
|
||||
// Record the upstream HTTP status on the span for the obs console.
|
||||
get_active_span(|span| {
|
||||
span.set_attribute(opentelemetry::KeyValue::new(
|
||||
crate::tracing::http::STATUS_CODE,
|
||||
upstream_status.as_u16() as i64,
|
||||
));
|
||||
});
|
||||
|
||||
let mut response = Response::builder().status(upstream_status);
|
||||
if let Some(headers) = response.headers_mut() {
|
||||
for (name, value) in response_headers.iter() {
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ use hyper::StatusCode;
|
|||
use std::sync::Arc;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use crate::router::llm::RouterService;
|
||||
use crate::router::orchestrator::OrchestratorService;
|
||||
use crate::streaming::truncate_message;
|
||||
use crate::tracing::routing;
|
||||
|
||||
|
|
@ -37,9 +37,8 @@ impl RoutingError {
|
|||
/// * `Ok(RoutingResult)` - Contains the selected model name and span ID
|
||||
/// * `Err(RoutingError)` - Contains error details and optional span ID
|
||||
pub async fn router_chat_get_upstream_model(
|
||||
router_service: Arc<RouterService>,
|
||||
orchestrator_service: Arc<OrchestratorService>,
|
||||
client_request: ProviderRequestType,
|
||||
traceparent: &str,
|
||||
request_path: &str,
|
||||
request_id: &str,
|
||||
inline_routing_preferences: Option<Vec<TopLevelRoutingPreference>>,
|
||||
|
|
@ -99,11 +98,9 @@ pub async fn router_chat_get_upstream_model(
|
|||
// Capture start time for routing span
|
||||
let routing_start_time = std::time::Instant::now();
|
||||
|
||||
// Attempt to determine route using the router service
|
||||
let routing_result = router_service
|
||||
let routing_result = orchestrator_service
|
||||
.determine_route(
|
||||
&chat_request.messages,
|
||||
traceparent,
|
||||
inline_routing_preferences,
|
||||
request_id,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ use tracing::{debug, info, info_span, warn, Instrument};
|
|||
|
||||
use super::extract_or_generate_traceparent;
|
||||
use crate::handlers::llm::model_selection::router_chat_get_upstream_model;
|
||||
use crate::router::llm::RouterService;
|
||||
use crate::router::orchestrator::OrchestratorService;
|
||||
use crate::tracing::{collect_custom_trace_attributes, operation_component, set_service_name};
|
||||
|
||||
/// Extracts `routing_preferences` from a JSON body, returning the cleaned body bytes
|
||||
|
|
@ -60,7 +60,7 @@ struct RoutingDecisionResponse {
|
|||
|
||||
pub async fn routing_decision(
|
||||
request: Request<hyper::body::Incoming>,
|
||||
router_service: Arc<RouterService>,
|
||||
orchestrator_service: Arc<OrchestratorService>,
|
||||
request_path: String,
|
||||
span_attributes: &Option<SpanAttributes>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
|
|
@ -76,6 +76,12 @@ pub async fn routing_decision(
|
|||
.and_then(|h| h.to_str().ok())
|
||||
.map(|s| s.to_string());
|
||||
|
||||
let tenant_id: Option<String> = orchestrator_service
|
||||
.tenant_header()
|
||||
.and_then(|hdr| request_headers.get(hdr))
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(|s| s.to_string());
|
||||
|
||||
let custom_attrs = collect_custom_trace_attributes(&request_headers, span_attributes.as_ref());
|
||||
|
||||
let request_span = info_span!(
|
||||
|
|
@ -88,25 +94,28 @@ pub async fn routing_decision(
|
|||
|
||||
routing_decision_inner(
|
||||
request,
|
||||
router_service,
|
||||
orchestrator_service,
|
||||
request_id,
|
||||
request_path,
|
||||
request_headers,
|
||||
custom_attrs,
|
||||
session_id,
|
||||
tenant_id,
|
||||
)
|
||||
.instrument(request_span)
|
||||
.await
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn routing_decision_inner(
|
||||
request: Request<hyper::body::Incoming>,
|
||||
router_service: Arc<RouterService>,
|
||||
orchestrator_service: Arc<OrchestratorService>,
|
||||
request_id: String,
|
||||
request_path: String,
|
||||
request_headers: hyper::HeaderMap,
|
||||
custom_attrs: std::collections::HashMap<String, String>,
|
||||
session_id: Option<String>,
|
||||
tenant_id: Option<String>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
set_service_name(operation_component::ROUTING);
|
||||
opentelemetry::trace::get_active_span(|span| {
|
||||
|
|
@ -124,9 +133,11 @@ async fn routing_decision_inner(
|
|||
.unwrap_or("unknown")
|
||||
.to_string();
|
||||
|
||||
// Session pinning: check cache before doing any routing work
|
||||
if let Some(ref sid) = session_id {
|
||||
if let Some(cached) = router_service.get_cached_route(sid).await {
|
||||
if let Some(cached) = orchestrator_service
|
||||
.get_cached_route(sid, tenant_id.as_deref())
|
||||
.await
|
||||
{
|
||||
info!(
|
||||
session_id = %sid,
|
||||
model = %cached.model_name,
|
||||
|
|
@ -190,9 +201,8 @@ async fn routing_decision_inner(
|
|||
};
|
||||
|
||||
let routing_result = router_chat_get_upstream_model(
|
||||
Arc::clone(&router_service),
|
||||
Arc::clone(&orchestrator_service),
|
||||
client_request,
|
||||
&traceparent,
|
||||
&request_path,
|
||||
&request_id,
|
||||
inline_routing_preferences,
|
||||
|
|
@ -201,11 +211,11 @@ async fn routing_decision_inner(
|
|||
|
||||
match routing_result {
|
||||
Ok(result) => {
|
||||
// Cache the result if session_id is present
|
||||
if let Some(ref sid) = session_id {
|
||||
router_service
|
||||
orchestrator_service
|
||||
.cache_route(
|
||||
sid.clone(),
|
||||
tenant_id.as_deref(),
|
||||
result.model_name.clone(),
|
||||
result.route_name.clone(),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
pub mod app_state;
|
||||
pub mod handlers;
|
||||
pub mod router;
|
||||
pub mod session_cache;
|
||||
pub mod signals;
|
||||
pub mod state;
|
||||
pub mod streaming;
|
||||
|
|
|
|||
|
|
@ -5,9 +5,9 @@ use brightstaff::handlers::function_calling::function_calling_chat_handler;
|
|||
use brightstaff::handlers::llm::llm_chat;
|
||||
use brightstaff::handlers::models::list_models;
|
||||
use brightstaff::handlers::routing_service::routing_decision;
|
||||
use brightstaff::router::llm::RouterService;
|
||||
use brightstaff::router::model_metrics::ModelMetricsService;
|
||||
use brightstaff::router::orchestrator::OrchestratorService;
|
||||
use brightstaff::session_cache::init_session_cache;
|
||||
use brightstaff::state::memory::MemoryConversationalStorage;
|
||||
use brightstaff::state::postgresql::PostgreSQLConversationStorage;
|
||||
use brightstaff::state::StateStorage;
|
||||
|
|
@ -36,8 +36,6 @@ use tokio::sync::RwLock;
|
|||
use tracing::{debug, info, warn};
|
||||
|
||||
const BIND_ADDRESS: &str = "0.0.0.0:9091";
|
||||
const DEFAULT_ROUTING_LLM_PROVIDER: &str = "arch-router";
|
||||
const DEFAULT_ROUTING_MODEL_NAME: &str = "Arch-Router";
|
||||
const DEFAULT_ORCHESTRATOR_LLM_PROVIDER: &str = "plano-orchestrator";
|
||||
const DEFAULT_ORCHESTRATOR_MODEL_NAME: &str = "Plano-Orchestrator";
|
||||
|
||||
|
|
@ -160,22 +158,8 @@ async fn init_app_state(
|
|||
|
||||
let overrides = config.overrides.clone().unwrap_or_default();
|
||||
|
||||
let routing_model_name: String = overrides
|
||||
.llm_routing_model
|
||||
.as_deref()
|
||||
.map(|m| m.split_once('/').map(|(_, id)| id).unwrap_or(m))
|
||||
.unwrap_or(DEFAULT_ROUTING_MODEL_NAME)
|
||||
.to_string();
|
||||
|
||||
let routing_llm_provider = config
|
||||
.model_providers
|
||||
.iter()
|
||||
.find(|p| p.model.as_deref() == Some(routing_model_name.as_str()))
|
||||
.map(|p| p.name.clone())
|
||||
.unwrap_or_else(|| DEFAULT_ROUTING_LLM_PROVIDER.to_string());
|
||||
|
||||
let session_ttl_seconds = config.routing.as_ref().and_then(|r| r.session_ttl_seconds);
|
||||
let session_max_entries = config.routing.as_ref().and_then(|r| r.session_max_entries);
|
||||
let session_cache = init_session_cache(config).await?;
|
||||
|
||||
// Validate that top-level routing_preferences requires v0.4.0+.
|
||||
let config_version = parse_semver(&config.version);
|
||||
|
|
@ -297,31 +281,17 @@ async fn init_app_state(
|
|||
}
|
||||
}
|
||||
|
||||
let router_service = Arc::new(RouterService::new(
|
||||
config.routing_preferences.clone(),
|
||||
metrics_service,
|
||||
format!("{llm_provider_url}{CHAT_COMPLETIONS_PATH}"),
|
||||
routing_model_name,
|
||||
routing_llm_provider,
|
||||
session_ttl_seconds,
|
||||
session_max_entries,
|
||||
));
|
||||
|
||||
// Spawn background task to clean up expired session cache entries every 5 minutes
|
||||
{
|
||||
let router_service = Arc::clone(&router_service);
|
||||
tokio::spawn(async move {
|
||||
let mut interval = tokio::time::interval(std::time::Duration::from_secs(300));
|
||||
loop {
|
||||
interval.tick().await;
|
||||
router_service.cleanup_expired_sessions().await;
|
||||
}
|
||||
});
|
||||
}
|
||||
let session_tenant_header = config
|
||||
.routing
|
||||
.as_ref()
|
||||
.and_then(|r| r.session_cache.as_ref())
|
||||
.and_then(|c| c.tenant_header.clone());
|
||||
|
||||
// Resolve model name: prefer llm_routing_model override, then agent_orchestration_model, then default.
|
||||
let orchestrator_model_name: String = overrides
|
||||
.agent_orchestration_model
|
||||
.llm_routing_model
|
||||
.as_deref()
|
||||
.or(overrides.agent_orchestration_model.as_deref())
|
||||
.map(|m| m.split_once('/').map(|(_, id)| id).unwrap_or(m))
|
||||
.unwrap_or(DEFAULT_ORCHESTRATOR_MODEL_NAME)
|
||||
.to_string();
|
||||
|
|
@ -333,10 +303,20 @@ async fn init_app_state(
|
|||
.map(|p| p.name.clone())
|
||||
.unwrap_or_else(|| DEFAULT_ORCHESTRATOR_LLM_PROVIDER.to_string());
|
||||
|
||||
let orchestrator_service = Arc::new(OrchestratorService::new(
|
||||
let orchestrator_max_tokens = overrides
|
||||
.orchestrator_model_context_length
|
||||
.unwrap_or(brightstaff::router::orchestrator_model_v1::MAX_TOKEN_LEN);
|
||||
|
||||
let orchestrator_service = Arc::new(OrchestratorService::with_routing(
|
||||
format!("{llm_provider_url}{CHAT_COMPLETIONS_PATH}"),
|
||||
orchestrator_model_name,
|
||||
orchestrator_llm_provider,
|
||||
config.routing_preferences.clone(),
|
||||
metrics_service,
|
||||
session_ttl_seconds,
|
||||
session_cache,
|
||||
session_tenant_header,
|
||||
orchestrator_max_tokens,
|
||||
));
|
||||
|
||||
let state_storage = init_state_storage(config).await?;
|
||||
|
|
@ -347,7 +327,6 @@ async fn init_app_state(
|
|||
.and_then(|tracing| tracing.span_attributes.clone());
|
||||
|
||||
Ok(AppState {
|
||||
router_service,
|
||||
orchestrator_service,
|
||||
model_aliases: config.model_aliases.clone(),
|
||||
llm_providers: Arc::new(RwLock::new(llm_providers)),
|
||||
|
|
@ -434,7 +413,7 @@ async fn route(
|
|||
) {
|
||||
return routing_decision(
|
||||
req,
|
||||
Arc::clone(&state.router_service),
|
||||
Arc::clone(&state.orchestrator_service),
|
||||
stripped,
|
||||
&state.span_attributes,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,8 +1,14 @@
|
|||
use hermesllm::apis::openai::ChatCompletionsResponse;
|
||||
use hyper::header;
|
||||
use serde::Deserialize;
|
||||
use thiserror::Error;
|
||||
use tracing::warn;
|
||||
|
||||
/// Max bytes of raw upstream body we include in a log message or error text
|
||||
/// when the body is not a recognizable error envelope. Keeps logs from being
|
||||
/// flooded by huge HTML error pages.
|
||||
const RAW_BODY_LOG_LIMIT: usize = 512;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum HttpError {
|
||||
#[error("Failed to send request: {0}")]
|
||||
|
|
@ -10,13 +16,64 @@ pub enum HttpError {
|
|||
|
||||
#[error("Failed to parse JSON response: {0}")]
|
||||
Json(serde_json::Error, String),
|
||||
|
||||
#[error("Upstream returned {status}: {message}")]
|
||||
Upstream { status: u16, message: String },
|
||||
}
|
||||
|
||||
/// Shape of an OpenAI-style error response body, e.g.
|
||||
/// `{"error": {"message": "...", "type": "...", "param": "...", "code": ...}}`.
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct UpstreamErrorEnvelope {
|
||||
error: UpstreamErrorBody,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct UpstreamErrorBody {
|
||||
message: String,
|
||||
#[serde(default, rename = "type")]
|
||||
err_type: Option<String>,
|
||||
#[serde(default)]
|
||||
param: Option<String>,
|
||||
}
|
||||
|
||||
/// Extract a human-readable error message from an upstream response body.
|
||||
/// Tries to parse an OpenAI-style `{"error": {"message": ...}}` envelope; if
|
||||
/// that fails, falls back to the first `RAW_BODY_LOG_LIMIT` bytes of the raw
|
||||
/// body (UTF-8 safe).
|
||||
fn extract_upstream_error_message(body: &str) -> String {
|
||||
if let Ok(env) = serde_json::from_str::<UpstreamErrorEnvelope>(body) {
|
||||
let mut msg = env.error.message;
|
||||
if let Some(param) = env.error.param {
|
||||
msg.push_str(&format!(" (param={param})"));
|
||||
}
|
||||
if let Some(err_type) = env.error.err_type {
|
||||
msg.push_str(&format!(" [type={err_type}]"));
|
||||
}
|
||||
return msg;
|
||||
}
|
||||
truncate_for_log(body).to_string()
|
||||
}
|
||||
|
||||
fn truncate_for_log(s: &str) -> &str {
|
||||
if s.len() <= RAW_BODY_LOG_LIMIT {
|
||||
return s;
|
||||
}
|
||||
let mut end = RAW_BODY_LOG_LIMIT;
|
||||
while end > 0 && !s.is_char_boundary(end) {
|
||||
end -= 1;
|
||||
}
|
||||
&s[..end]
|
||||
}
|
||||
|
||||
/// Sends a POST request to the given URL and extracts the text content
|
||||
/// from the first choice of the `ChatCompletionsResponse`.
|
||||
///
|
||||
/// Returns `Some((content, elapsed))` on success, or `None` if the response
|
||||
/// had no choices or the first choice had no content.
|
||||
/// Returns `Some((content, elapsed))` on success, `None` if the response
|
||||
/// had no choices or the first choice had no content. Returns
|
||||
/// `HttpError::Upstream` for any non-2xx status, carrying a message
|
||||
/// extracted from the OpenAI-style error envelope (or a truncated raw body
|
||||
/// if the body is not in that shape).
|
||||
pub async fn post_and_extract_content(
|
||||
client: &reqwest::Client,
|
||||
url: &str,
|
||||
|
|
@ -26,17 +83,36 @@ pub async fn post_and_extract_content(
|
|||
let start_time = std::time::Instant::now();
|
||||
|
||||
let res = client.post(url).headers(headers).body(body).send().await?;
|
||||
let status = res.status();
|
||||
|
||||
let body = res.text().await?;
|
||||
let elapsed = start_time.elapsed();
|
||||
|
||||
if !status.is_success() {
|
||||
let message = extract_upstream_error_message(&body);
|
||||
warn!(
|
||||
status = status.as_u16(),
|
||||
message = %message,
|
||||
body_size = body.len(),
|
||||
"upstream returned error response"
|
||||
);
|
||||
return Err(HttpError::Upstream {
|
||||
status: status.as_u16(),
|
||||
message,
|
||||
});
|
||||
}
|
||||
|
||||
let response: ChatCompletionsResponse = serde_json::from_str(&body).map_err(|err| {
|
||||
warn!(error = %err, body = %body, "failed to parse json response");
|
||||
warn!(
|
||||
error = %err,
|
||||
body = %truncate_for_log(&body),
|
||||
"failed to parse json response",
|
||||
);
|
||||
HttpError::Json(err, format!("Failed to parse JSON: {}", body))
|
||||
})?;
|
||||
|
||||
if response.choices.is_empty() {
|
||||
warn!(body = %body, "no choices in response");
|
||||
warn!(body = %truncate_for_log(&body), "no choices in response");
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
|
|
@ -46,3 +122,52 @@ pub async fn post_and_extract_content(
|
|||
.as_ref()
|
||||
.map(|c| (c.clone(), elapsed)))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn extracts_message_from_openai_style_error_envelope() {
|
||||
let body = r#"{"error":{"code":400,"message":"This model's maximum context length is 32768 tokens. However, you requested 0 output tokens and your prompt contains at least 32769 input tokens, for a total of at least 32769 tokens.","param":"input_tokens","type":"BadRequestError"}}"#;
|
||||
let msg = extract_upstream_error_message(body);
|
||||
assert!(
|
||||
msg.starts_with("This model's maximum context length is 32768 tokens."),
|
||||
"unexpected message: {msg}"
|
||||
);
|
||||
assert!(msg.contains("(param=input_tokens)"));
|
||||
assert!(msg.contains("[type=BadRequestError]"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extracts_message_without_optional_fields() {
|
||||
let body = r#"{"error":{"message":"something broke"}}"#;
|
||||
let msg = extract_upstream_error_message(body);
|
||||
assert_eq!(msg, "something broke");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn falls_back_to_raw_body_when_not_error_envelope() {
|
||||
let body = "<html><body>502 Bad Gateway</body></html>";
|
||||
let msg = extract_upstream_error_message(body);
|
||||
assert_eq!(msg, body);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn truncates_non_envelope_bodies_in_logs() {
|
||||
let body = "x".repeat(RAW_BODY_LOG_LIMIT * 3);
|
||||
let msg = extract_upstream_error_message(&body);
|
||||
assert_eq!(msg.len(), RAW_BODY_LOG_LIMIT);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn truncate_for_log_respects_utf8_boundaries() {
|
||||
// 2-byte characters; picking a length that would split mid-char.
|
||||
let body = "é".repeat(RAW_BODY_LOG_LIMIT);
|
||||
let out = truncate_for_log(&body);
|
||||
// Should be a valid &str (implicit — would panic if we returned
|
||||
// a non-boundary slice) and at most RAW_BODY_LOG_LIMIT bytes.
|
||||
assert!(out.len() <= RAW_BODY_LOG_LIMIT);
|
||||
assert!(out.chars().all(|c| c == 'é'));
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,380 +0,0 @@
|
|||
use std::{collections::HashMap, sync::Arc, time::Duration, time::Instant};
|
||||
|
||||
use common::{
|
||||
configuration::TopLevelRoutingPreference,
|
||||
consts::{ARCH_PROVIDER_HINT_HEADER, REQUEST_ID_HEADER, TRACE_PARENT_HEADER},
|
||||
};
|
||||
|
||||
use super::router_model::{ModelUsagePreference, RoutingPreference};
|
||||
use hermesllm::apis::openai::Message;
|
||||
use hyper::header;
|
||||
use thiserror::Error;
|
||||
use tokio::sync::RwLock;
|
||||
use tracing::{debug, info};
|
||||
|
||||
use super::http::{self, post_and_extract_content};
|
||||
use super::model_metrics::ModelMetricsService;
|
||||
use super::router_model::RouterModel;
|
||||
|
||||
use crate::router::router_model_v1;
|
||||
|
||||
const DEFAULT_SESSION_TTL_SECONDS: u64 = 600;
|
||||
const DEFAULT_SESSION_MAX_ENTRIES: usize = 10_000;
|
||||
const MAX_SESSION_MAX_ENTRIES: usize = 10_000;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct CachedRoute {
|
||||
pub model_name: String,
|
||||
pub route_name: Option<String>,
|
||||
pub cached_at: Instant,
|
||||
}
|
||||
|
||||
pub struct RouterService {
|
||||
router_url: String,
|
||||
client: reqwest::Client,
|
||||
router_model: Arc<dyn RouterModel>,
|
||||
routing_provider_name: String,
|
||||
top_level_preferences: HashMap<String, TopLevelRoutingPreference>,
|
||||
metrics_service: Option<Arc<ModelMetricsService>>,
|
||||
session_cache: RwLock<HashMap<String, CachedRoute>>,
|
||||
session_ttl: Duration,
|
||||
session_max_entries: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum RoutingError {
|
||||
#[error(transparent)]
|
||||
Http(#[from] http::HttpError),
|
||||
|
||||
#[error("Router model error: {0}")]
|
||||
RouterModelError(#[from] super::router_model::RoutingModelError),
|
||||
}
|
||||
|
||||
pub type Result<T> = std::result::Result<T, RoutingError>;
|
||||
|
||||
impl RouterService {
|
||||
pub fn new(
|
||||
top_level_prefs: Option<Vec<TopLevelRoutingPreference>>,
|
||||
metrics_service: Option<Arc<ModelMetricsService>>,
|
||||
router_url: String,
|
||||
routing_model_name: String,
|
||||
routing_provider_name: String,
|
||||
session_ttl_seconds: Option<u64>,
|
||||
session_max_entries: Option<usize>,
|
||||
) -> Self {
|
||||
let top_level_preferences: HashMap<String, TopLevelRoutingPreference> = top_level_prefs
|
||||
.map_or_else(HashMap::new, |prefs| {
|
||||
prefs.into_iter().map(|p| (p.name.clone(), p)).collect()
|
||||
});
|
||||
|
||||
// Build sentinel routes for RouterModelV1: route_name → first model.
|
||||
// RouterModelV1 uses this to build its prompt; RouterService overrides
|
||||
// the model selection via rank_models() after the route is determined.
|
||||
let sentinel_routes: HashMap<String, Vec<RoutingPreference>> = top_level_preferences
|
||||
.iter()
|
||||
.filter_map(|(name, pref)| {
|
||||
pref.models.first().map(|first_model| {
|
||||
(
|
||||
first_model.clone(),
|
||||
vec![RoutingPreference {
|
||||
name: name.clone(),
|
||||
description: pref.description.clone(),
|
||||
}],
|
||||
)
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
let router_model = Arc::new(router_model_v1::RouterModelV1::new(
|
||||
sentinel_routes,
|
||||
routing_model_name,
|
||||
router_model_v1::MAX_TOKEN_LEN,
|
||||
));
|
||||
|
||||
let session_ttl =
|
||||
Duration::from_secs(session_ttl_seconds.unwrap_or(DEFAULT_SESSION_TTL_SECONDS));
|
||||
let session_max_entries = session_max_entries
|
||||
.unwrap_or(DEFAULT_SESSION_MAX_ENTRIES)
|
||||
.min(MAX_SESSION_MAX_ENTRIES);
|
||||
|
||||
RouterService {
|
||||
router_url,
|
||||
client: reqwest::Client::new(),
|
||||
router_model,
|
||||
routing_provider_name,
|
||||
top_level_preferences,
|
||||
metrics_service,
|
||||
session_cache: RwLock::new(HashMap::new()),
|
||||
session_ttl,
|
||||
session_max_entries,
|
||||
}
|
||||
}
|
||||
|
||||
/// Look up a cached routing decision by session ID.
|
||||
/// Returns None if not found or expired.
|
||||
pub async fn get_cached_route(&self, session_id: &str) -> Option<CachedRoute> {
|
||||
let cache = self.session_cache.read().await;
|
||||
if let Some(entry) = cache.get(session_id) {
|
||||
if entry.cached_at.elapsed() < self.session_ttl {
|
||||
return Some(entry.clone());
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
/// Store a routing decision in the session cache.
|
||||
/// If at max capacity, evicts the oldest entry.
|
||||
pub async fn cache_route(
|
||||
&self,
|
||||
session_id: String,
|
||||
model_name: String,
|
||||
route_name: Option<String>,
|
||||
) {
|
||||
let mut cache = self.session_cache.write().await;
|
||||
if cache.len() >= self.session_max_entries && !cache.contains_key(&session_id) {
|
||||
if let Some(oldest_key) = cache
|
||||
.iter()
|
||||
.min_by_key(|(_, v)| v.cached_at)
|
||||
.map(|(k, _)| k.clone())
|
||||
{
|
||||
cache.remove(&oldest_key);
|
||||
}
|
||||
}
|
||||
cache.insert(
|
||||
session_id,
|
||||
CachedRoute {
|
||||
model_name,
|
||||
route_name,
|
||||
cached_at: Instant::now(),
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
/// Remove all expired entries from the session cache.
|
||||
pub async fn cleanup_expired_sessions(&self) {
|
||||
let mut cache = self.session_cache.write().await;
|
||||
let before = cache.len();
|
||||
cache.retain(|_, entry| entry.cached_at.elapsed() < self.session_ttl);
|
||||
let removed = before - cache.len();
|
||||
if removed > 0 {
|
||||
info!(
|
||||
removed = removed,
|
||||
remaining = cache.len(),
|
||||
"cleaned up expired session cache entries"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn determine_route(
|
||||
&self,
|
||||
messages: &[Message],
|
||||
traceparent: &str,
|
||||
inline_routing_preferences: Option<Vec<TopLevelRoutingPreference>>,
|
||||
request_id: &str,
|
||||
) -> Result<Option<(String, Vec<String>)>> {
|
||||
if messages.is_empty() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
// Build inline top-level map from request if present (inline overrides config).
|
||||
let inline_top_map: Option<HashMap<String, TopLevelRoutingPreference>> =
|
||||
inline_routing_preferences
|
||||
.map(|prefs| prefs.into_iter().map(|p| (p.name.clone(), p)).collect());
|
||||
|
||||
// No routing defined — skip the router call entirely.
|
||||
if inline_top_map.is_none() && self.top_level_preferences.is_empty() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
// For inline overrides, build synthetic ModelUsagePreference list so RouterModelV1
|
||||
// generates the correct prompt (route name + description pairs).
|
||||
// For config-level prefs the sentinel routes are already baked into RouterModelV1.
|
||||
let effective_usage_preferences: Option<Vec<ModelUsagePreference>> =
|
||||
inline_top_map.as_ref().map(|inline_map| {
|
||||
inline_map
|
||||
.values()
|
||||
.map(|p| ModelUsagePreference {
|
||||
model: p.models.first().cloned().unwrap_or_default(),
|
||||
routing_preferences: vec![RoutingPreference {
|
||||
name: p.name.clone(),
|
||||
description: p.description.clone(),
|
||||
}],
|
||||
})
|
||||
.collect()
|
||||
});
|
||||
|
||||
let router_request = self
|
||||
.router_model
|
||||
.generate_request(messages, &effective_usage_preferences);
|
||||
|
||||
debug!(
|
||||
model = %self.router_model.get_model_name(),
|
||||
endpoint = %self.router_url,
|
||||
"sending request to arch-router"
|
||||
);
|
||||
|
||||
let body = serde_json::to_string(&router_request)
|
||||
.map_err(super::router_model::RoutingModelError::from)?;
|
||||
debug!(body = %body, "arch router request");
|
||||
|
||||
let mut headers = header::HeaderMap::new();
|
||||
headers.insert(
|
||||
header::CONTENT_TYPE,
|
||||
header::HeaderValue::from_static("application/json"),
|
||||
);
|
||||
if let Ok(val) = header::HeaderValue::from_str(&self.routing_provider_name) {
|
||||
headers.insert(
|
||||
header::HeaderName::from_static(ARCH_PROVIDER_HINT_HEADER),
|
||||
val,
|
||||
);
|
||||
}
|
||||
if let Ok(val) = header::HeaderValue::from_str(traceparent) {
|
||||
headers.insert(header::HeaderName::from_static(TRACE_PARENT_HEADER), val);
|
||||
}
|
||||
if let Ok(val) = header::HeaderValue::from_str(request_id) {
|
||||
headers.insert(header::HeaderName::from_static(REQUEST_ID_HEADER), val);
|
||||
}
|
||||
headers.insert(
|
||||
header::HeaderName::from_static("model"),
|
||||
header::HeaderValue::from_static("arch-router"),
|
||||
);
|
||||
|
||||
let Some((content, elapsed)) =
|
||||
post_and_extract_content(&self.client, &self.router_url, headers, body).await?
|
||||
else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
// Parse the route name from the router response.
|
||||
let parsed = self
|
||||
.router_model
|
||||
.parse_response(&content, &effective_usage_preferences)?;
|
||||
|
||||
let result = if let Some((route_name, _sentinel)) = parsed {
|
||||
let top_pref = inline_top_map
|
||||
.as_ref()
|
||||
.and_then(|m| m.get(&route_name))
|
||||
.or_else(|| self.top_level_preferences.get(&route_name));
|
||||
|
||||
if let Some(pref) = top_pref {
|
||||
let ranked = match &self.metrics_service {
|
||||
Some(svc) => svc.rank_models(&pref.models, &pref.selection_policy).await,
|
||||
None => pref.models.clone(),
|
||||
};
|
||||
Some((route_name, ranked))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
info!(
|
||||
content = %content.replace("\n", "\\n"),
|
||||
selected_model = ?result,
|
||||
response_time_ms = elapsed.as_millis(),
|
||||
"arch-router determined route"
|
||||
);
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn make_router_service(ttl_seconds: u64, max_entries: usize) -> RouterService {
|
||||
RouterService::new(
|
||||
None,
|
||||
None,
|
||||
"http://localhost:12001/v1/chat/completions".to_string(),
|
||||
"Arch-Router".to_string(),
|
||||
"arch-router".to_string(),
|
||||
Some(ttl_seconds),
|
||||
Some(max_entries),
|
||||
)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_cache_miss_returns_none() {
|
||||
let svc = make_router_service(600, 100);
|
||||
assert!(svc.get_cached_route("unknown-session").await.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_cache_hit_returns_cached_route() {
|
||||
let svc = make_router_service(600, 100);
|
||||
svc.cache_route(
|
||||
"s1".to_string(),
|
||||
"gpt-4o".to_string(),
|
||||
Some("code".to_string()),
|
||||
)
|
||||
.await;
|
||||
|
||||
let cached = svc.get_cached_route("s1").await.unwrap();
|
||||
assert_eq!(cached.model_name, "gpt-4o");
|
||||
assert_eq!(cached.route_name, Some("code".to_string()));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_cache_expired_entry_returns_none() {
|
||||
let svc = make_router_service(0, 100);
|
||||
svc.cache_route("s1".to_string(), "gpt-4o".to_string(), None)
|
||||
.await;
|
||||
assert!(svc.get_cached_route("s1").await.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_cleanup_removes_expired() {
|
||||
let svc = make_router_service(0, 100);
|
||||
svc.cache_route("s1".to_string(), "gpt-4o".to_string(), None)
|
||||
.await;
|
||||
svc.cache_route("s2".to_string(), "claude".to_string(), None)
|
||||
.await;
|
||||
|
||||
svc.cleanup_expired_sessions().await;
|
||||
|
||||
let cache = svc.session_cache.read().await;
|
||||
assert!(cache.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_cache_evicts_oldest_when_full() {
|
||||
let svc = make_router_service(600, 2);
|
||||
svc.cache_route("s1".to_string(), "model-a".to_string(), None)
|
||||
.await;
|
||||
tokio::time::sleep(Duration::from_millis(10)).await;
|
||||
svc.cache_route("s2".to_string(), "model-b".to_string(), None)
|
||||
.await;
|
||||
|
||||
svc.cache_route("s3".to_string(), "model-c".to_string(), None)
|
||||
.await;
|
||||
|
||||
let cache = svc.session_cache.read().await;
|
||||
assert_eq!(cache.len(), 2);
|
||||
assert!(!cache.contains_key("s1"));
|
||||
assert!(cache.contains_key("s2"));
|
||||
assert!(cache.contains_key("s3"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_cache_update_existing_session_does_not_evict() {
|
||||
let svc = make_router_service(600, 2);
|
||||
svc.cache_route("s1".to_string(), "model-a".to_string(), None)
|
||||
.await;
|
||||
svc.cache_route("s2".to_string(), "model-b".to_string(), None)
|
||||
.await;
|
||||
|
||||
svc.cache_route(
|
||||
"s1".to_string(),
|
||||
"model-a-updated".to_string(),
|
||||
Some("route".to_string()),
|
||||
)
|
||||
.await;
|
||||
|
||||
let cache = svc.session_cache.read().await;
|
||||
assert_eq!(cache.len(), 2);
|
||||
assert_eq!(cache.get("s1").unwrap().model_name, "model-a-updated");
|
||||
}
|
||||
}
|
||||
|
|
@ -1,8 +1,5 @@
|
|||
pub(crate) mod http;
|
||||
pub mod llm;
|
||||
pub mod model_metrics;
|
||||
pub mod orchestrator;
|
||||
pub mod orchestrator_model;
|
||||
pub mod orchestrator_model_v1;
|
||||
pub mod router_model;
|
||||
pub mod router_model_v1;
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
use std::{collections::HashMap, sync::Arc};
|
||||
use std::{borrow::Cow, collections::HashMap, sync::Arc, time::Duration};
|
||||
|
||||
use common::{
|
||||
configuration::{AgentUsagePreference, OrchestrationPreference},
|
||||
configuration::{AgentUsagePreference, OrchestrationPreference, TopLevelRoutingPreference},
|
||||
consts::{ARCH_PROVIDER_HINT_HEADER, REQUEST_ID_HEADER},
|
||||
};
|
||||
use hermesllm::apis::openai::Message;
|
||||
|
|
@ -12,15 +12,26 @@ use thiserror::Error;
|
|||
use tracing::{debug, info};
|
||||
|
||||
use super::http::{self, post_and_extract_content};
|
||||
use super::model_metrics::ModelMetricsService;
|
||||
use super::orchestrator_model::OrchestratorModel;
|
||||
|
||||
use crate::router::orchestrator_model_v1;
|
||||
use crate::session_cache::SessionCache;
|
||||
|
||||
pub use crate::session_cache::CachedRoute;
|
||||
|
||||
const DEFAULT_SESSION_TTL_SECONDS: u64 = 600;
|
||||
|
||||
pub struct OrchestratorService {
|
||||
orchestrator_url: String,
|
||||
client: reqwest::Client,
|
||||
orchestrator_model: Arc<dyn OrchestratorModel>,
|
||||
orchestrator_provider_name: String,
|
||||
top_level_preferences: HashMap<String, TopLevelRoutingPreference>,
|
||||
metrics_service: Option<Arc<ModelMetricsService>>,
|
||||
session_cache: Option<Arc<dyn SessionCache>>,
|
||||
session_ttl: Duration,
|
||||
tenant_header: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
|
|
@ -39,13 +50,12 @@ impl OrchestratorService {
|
|||
orchestrator_url: String,
|
||||
orchestration_model_name: String,
|
||||
orchestrator_provider_name: String,
|
||||
max_token_length: usize,
|
||||
) -> Self {
|
||||
let agent_orchestrations: HashMap<String, Vec<OrchestrationPreference>> = HashMap::new();
|
||||
|
||||
let orchestrator_model = Arc::new(orchestrator_model_v1::OrchestratorModelV1::new(
|
||||
agent_orchestrations,
|
||||
orchestration_model_name.clone(),
|
||||
orchestrator_model_v1::MAX_TOKEN_LEN,
|
||||
HashMap::new(),
|
||||
orchestration_model_name,
|
||||
max_token_length,
|
||||
));
|
||||
|
||||
OrchestratorService {
|
||||
|
|
@ -53,9 +63,182 @@ impl OrchestratorService {
|
|||
client: reqwest::Client::new(),
|
||||
orchestrator_model,
|
||||
orchestrator_provider_name,
|
||||
top_level_preferences: HashMap::new(),
|
||||
metrics_service: None,
|
||||
session_cache: None,
|
||||
session_ttl: Duration::from_secs(DEFAULT_SESSION_TTL_SECONDS),
|
||||
tenant_header: None,
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn with_routing(
|
||||
orchestrator_url: String,
|
||||
orchestration_model_name: String,
|
||||
orchestrator_provider_name: String,
|
||||
top_level_prefs: Option<Vec<TopLevelRoutingPreference>>,
|
||||
metrics_service: Option<Arc<ModelMetricsService>>,
|
||||
session_ttl_seconds: Option<u64>,
|
||||
session_cache: Arc<dyn SessionCache>,
|
||||
tenant_header: Option<String>,
|
||||
max_token_length: usize,
|
||||
) -> Self {
|
||||
let top_level_preferences: HashMap<String, TopLevelRoutingPreference> = top_level_prefs
|
||||
.map_or_else(HashMap::new, |prefs| {
|
||||
prefs.into_iter().map(|p| (p.name.clone(), p)).collect()
|
||||
});
|
||||
|
||||
let orchestrator_model = Arc::new(orchestrator_model_v1::OrchestratorModelV1::new(
|
||||
HashMap::new(),
|
||||
orchestration_model_name,
|
||||
max_token_length,
|
||||
));
|
||||
|
||||
let session_ttl =
|
||||
Duration::from_secs(session_ttl_seconds.unwrap_or(DEFAULT_SESSION_TTL_SECONDS));
|
||||
|
||||
OrchestratorService {
|
||||
orchestrator_url,
|
||||
client: reqwest::Client::new(),
|
||||
orchestrator_model,
|
||||
orchestrator_provider_name,
|
||||
top_level_preferences,
|
||||
metrics_service,
|
||||
session_cache: Some(session_cache),
|
||||
session_ttl,
|
||||
tenant_header,
|
||||
}
|
||||
}
|
||||
|
||||
// ---- Session cache methods ----
|
||||
|
||||
#[must_use]
|
||||
pub fn tenant_header(&self) -> Option<&str> {
|
||||
self.tenant_header.as_deref()
|
||||
}
|
||||
|
||||
fn session_key<'a>(tenant_id: Option<&str>, session_id: &'a str) -> Cow<'a, str> {
|
||||
match tenant_id {
|
||||
Some(t) => Cow::Owned(format!("{t}:{session_id}")),
|
||||
None => Cow::Borrowed(session_id),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_cached_route(
|
||||
&self,
|
||||
session_id: &str,
|
||||
tenant_id: Option<&str>,
|
||||
) -> Option<CachedRoute> {
|
||||
let cache = self.session_cache.as_ref()?;
|
||||
cache.get(&Self::session_key(tenant_id, session_id)).await
|
||||
}
|
||||
|
||||
pub async fn cache_route(
|
||||
&self,
|
||||
session_id: String,
|
||||
tenant_id: Option<&str>,
|
||||
model_name: String,
|
||||
route_name: Option<String>,
|
||||
) {
|
||||
if let Some(ref cache) = self.session_cache {
|
||||
cache
|
||||
.put(
|
||||
&Self::session_key(tenant_id, &session_id),
|
||||
CachedRoute {
|
||||
model_name,
|
||||
route_name,
|
||||
},
|
||||
self.session_ttl,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
||||
// ---- LLM routing ----
|
||||
|
||||
pub async fn determine_route(
|
||||
&self,
|
||||
messages: &[Message],
|
||||
inline_routing_preferences: Option<Vec<TopLevelRoutingPreference>>,
|
||||
request_id: &str,
|
||||
) -> Result<Option<(String, Vec<String>)>> {
|
||||
if messages.is_empty() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let inline_top_map: Option<HashMap<String, TopLevelRoutingPreference>> =
|
||||
inline_routing_preferences
|
||||
.map(|prefs| prefs.into_iter().map(|p| (p.name.clone(), p)).collect());
|
||||
|
||||
if inline_top_map.is_none() && self.top_level_preferences.is_empty() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let effective_source = inline_top_map
|
||||
.as_ref()
|
||||
.unwrap_or(&self.top_level_preferences);
|
||||
|
||||
let effective_prefs: Vec<AgentUsagePreference> = effective_source
|
||||
.values()
|
||||
.map(|p| AgentUsagePreference {
|
||||
model: p.models.first().cloned().unwrap_or_default(),
|
||||
orchestration_preferences: vec![OrchestrationPreference {
|
||||
name: p.name.clone(),
|
||||
description: p.description.clone(),
|
||||
}],
|
||||
})
|
||||
.collect();
|
||||
|
||||
let orchestration_result = self
|
||||
.determine_orchestration(
|
||||
messages,
|
||||
Some(effective_prefs),
|
||||
Some(request_id.to_string()),
|
||||
)
|
||||
.await?;
|
||||
|
||||
let result = if let Some(ref routes) = orchestration_result {
|
||||
if routes.len() > 1 {
|
||||
let all_routes: Vec<&str> = routes.iter().map(|(name, _)| name.as_str()).collect();
|
||||
info!(
|
||||
routes = ?all_routes,
|
||||
using = %all_routes.first().unwrap_or(&"none"),
|
||||
"plano-orchestrator detected multiple intents, using first"
|
||||
);
|
||||
}
|
||||
|
||||
if let Some((route_name, _)) = routes.first() {
|
||||
let top_pref = inline_top_map
|
||||
.as_ref()
|
||||
.and_then(|m| m.get(route_name))
|
||||
.or_else(|| self.top_level_preferences.get(route_name));
|
||||
|
||||
if let Some(pref) = top_pref {
|
||||
let ranked = match &self.metrics_service {
|
||||
Some(svc) => svc.rank_models(&pref.models, &pref.selection_policy).await,
|
||||
None => pref.models.clone(),
|
||||
};
|
||||
Some((route_name.clone(), ranked))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
info!(
|
||||
selected_model = ?result,
|
||||
"plano-orchestrator determined route"
|
||||
);
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
// ---- Agent orchestration (existing) ----
|
||||
|
||||
pub async fn determine_orchestration(
|
||||
&self,
|
||||
messages: &[Message],
|
||||
|
|
@ -80,12 +263,12 @@ impl OrchestratorService {
|
|||
debug!(
|
||||
model = %self.orchestrator_model.get_model_name(),
|
||||
endpoint = %self.orchestrator_url,
|
||||
"sending request to arch-orchestrator"
|
||||
"sending request to plano-orchestrator"
|
||||
);
|
||||
|
||||
let body = serde_json::to_string(&orchestrator_request)
|
||||
.map_err(super::orchestrator_model::OrchestratorModelError::from)?;
|
||||
debug!(body = %body, "arch orchestrator request");
|
||||
debug!(body = %body, "plano-orchestrator request");
|
||||
|
||||
let mut headers = header::HeaderMap::new();
|
||||
headers.insert(
|
||||
|
|
@ -98,7 +281,6 @@ impl OrchestratorService {
|
|||
.unwrap_or_else(|_| header::HeaderValue::from_static("plano-orchestrator")),
|
||||
);
|
||||
|
||||
// Inject OpenTelemetry trace context from current span
|
||||
global::get_text_map_propagator(|propagator| {
|
||||
let cx =
|
||||
tracing_opentelemetry::OpenTelemetrySpanExt::context(&tracing::Span::current());
|
||||
|
|
@ -130,9 +312,113 @@ impl OrchestratorService {
|
|||
content = %content.replace("\n", "\\n"),
|
||||
selected_routes = ?parsed,
|
||||
response_time_ms = elapsed.as_millis(),
|
||||
"arch-orchestrator determined routes"
|
||||
"plano-orchestrator determined routes"
|
||||
);
|
||||
|
||||
Ok(parsed)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::session_cache::memory::MemorySessionCache;
|
||||
|
||||
fn make_orchestrator_service(ttl_seconds: u64, max_entries: usize) -> OrchestratorService {
|
||||
let session_cache = Arc::new(MemorySessionCache::new(max_entries));
|
||||
OrchestratorService::with_routing(
|
||||
"http://localhost:12001/v1/chat/completions".to_string(),
|
||||
"Plano-Orchestrator".to_string(),
|
||||
"plano-orchestrator".to_string(),
|
||||
None,
|
||||
None,
|
||||
Some(ttl_seconds),
|
||||
session_cache,
|
||||
None,
|
||||
orchestrator_model_v1::MAX_TOKEN_LEN,
|
||||
)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_cache_miss_returns_none() {
|
||||
let svc = make_orchestrator_service(600, 100);
|
||||
assert!(svc
|
||||
.get_cached_route("unknown-session", None)
|
||||
.await
|
||||
.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_cache_hit_returns_cached_route() {
|
||||
let svc = make_orchestrator_service(600, 100);
|
||||
svc.cache_route(
|
||||
"s1".to_string(),
|
||||
None,
|
||||
"gpt-4o".to_string(),
|
||||
Some("code".to_string()),
|
||||
)
|
||||
.await;
|
||||
|
||||
let cached = svc.get_cached_route("s1", None).await.unwrap();
|
||||
assert_eq!(cached.model_name, "gpt-4o");
|
||||
assert_eq!(cached.route_name, Some("code".to_string()));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_cache_expired_entry_returns_none() {
|
||||
let svc = make_orchestrator_service(0, 100);
|
||||
svc.cache_route("s1".to_string(), None, "gpt-4o".to_string(), None)
|
||||
.await;
|
||||
assert!(svc.get_cached_route("s1", None).await.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_expired_entries_not_returned() {
|
||||
let svc = make_orchestrator_service(0, 100);
|
||||
svc.cache_route("s1".to_string(), None, "gpt-4o".to_string(), None)
|
||||
.await;
|
||||
svc.cache_route("s2".to_string(), None, "claude".to_string(), None)
|
||||
.await;
|
||||
|
||||
assert!(svc.get_cached_route("s1", None).await.is_none());
|
||||
assert!(svc.get_cached_route("s2", None).await.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_cache_evicts_oldest_when_full() {
|
||||
let svc = make_orchestrator_service(600, 2);
|
||||
svc.cache_route("s1".to_string(), None, "model-a".to_string(), None)
|
||||
.await;
|
||||
tokio::time::sleep(Duration::from_millis(10)).await;
|
||||
svc.cache_route("s2".to_string(), None, "model-b".to_string(), None)
|
||||
.await;
|
||||
|
||||
svc.cache_route("s3".to_string(), None, "model-c".to_string(), None)
|
||||
.await;
|
||||
|
||||
assert!(svc.get_cached_route("s1", None).await.is_none());
|
||||
assert!(svc.get_cached_route("s2", None).await.is_some());
|
||||
assert!(svc.get_cached_route("s3", None).await.is_some());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_cache_update_existing_session_does_not_evict() {
|
||||
let svc = make_orchestrator_service(600, 2);
|
||||
svc.cache_route("s1".to_string(), None, "model-a".to_string(), None)
|
||||
.await;
|
||||
svc.cache_route("s2".to_string(), None, "model-b".to_string(), None)
|
||||
.await;
|
||||
|
||||
svc.cache_route(
|
||||
"s1".to_string(),
|
||||
None,
|
||||
"model-a-updated".to_string(),
|
||||
Some("route".to_string()),
|
||||
)
|
||||
.await;
|
||||
|
||||
let s1 = svc.get_cached_route("s1", None).await.unwrap();
|
||||
assert_eq!(s1.model_name, "model-a-updated");
|
||||
assert!(svc.get_cached_route("s2", None).await.is_some());
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -11,8 +11,7 @@ pub enum OrchestratorModelError {
|
|||
pub type Result<T> = std::result::Result<T, OrchestratorModelError>;
|
||||
|
||||
/// OrchestratorModel trait for handling orchestration requests.
|
||||
/// Unlike RouterModel which returns a single route, OrchestratorModel
|
||||
/// can return multiple routes as the model output format is:
|
||||
/// Returns multiple routes as the model output format is:
|
||||
/// {"route": ["route_name_1", "route_name_2", ...]}
|
||||
pub trait OrchestratorModel: Send + Sync {
|
||||
fn generate_request(
|
||||
|
|
|
|||
|
|
@ -8,7 +8,19 @@ use tracing::{debug, warn};
|
|||
|
||||
use super::orchestrator_model::{OrchestratorModel, OrchestratorModelError};
|
||||
|
||||
pub const MAX_TOKEN_LEN: usize = 2048; // Default max token length for the orchestration model
|
||||
pub const MAX_TOKEN_LEN: usize = 8192; // Default max token length for the orchestration model
|
||||
|
||||
/// Hard cap on the number of recent messages considered when building the
|
||||
/// routing prompt. Bounds prompt growth for long-running conversations and
|
||||
/// acts as an outer guardrail before the token-budget loop runs. The most
|
||||
/// recent `MAX_ROUTING_TURNS` filtered messages are kept; older turns are
|
||||
/// dropped entirely.
|
||||
pub const MAX_ROUTING_TURNS: usize = 16;
|
||||
|
||||
/// Unicode ellipsis used to mark where content was trimmed out of a long
|
||||
/// message. Helps signal to the downstream router model that the message was
|
||||
/// truncated.
|
||||
const TRIM_MARKER: &str = "…";
|
||||
|
||||
/// Custom JSON formatter that produces spaced JSON (space after colons and commas), same as JSON in python
|
||||
struct SpacedJsonFormatter;
|
||||
|
|
@ -176,10 +188,9 @@ impl OrchestratorModel for OrchestratorModelV1 {
|
|||
messages: &[Message],
|
||||
usage_preferences_from_request: &Option<Vec<AgentUsagePreference>>,
|
||||
) -> ChatCompletionsRequest {
|
||||
// remove system prompt, tool calls, tool call response and messages without content
|
||||
// if content is empty its likely a tool call
|
||||
// when role == tool its tool call response
|
||||
let messages_vec = messages
|
||||
// Remove system/developer/tool messages and messages without extractable
|
||||
// text (tool calls have no text content we can classify against).
|
||||
let filtered: Vec<&Message> = messages
|
||||
.iter()
|
||||
.filter(|m| {
|
||||
m.role != Role::System
|
||||
|
|
@ -187,37 +198,72 @@ impl OrchestratorModel for OrchestratorModelV1 {
|
|||
&& m.role != Role::Tool
|
||||
&& !m.content.extract_text().is_empty()
|
||||
})
|
||||
.collect::<Vec<&Message>>();
|
||||
.collect();
|
||||
|
||||
// Following code is to ensure that the conversation does not exceed max token length
|
||||
// Note: we use a simple heuristic to estimate token count based on character length to optimize for performance
|
||||
// Outer guardrail: only consider the last `MAX_ROUTING_TURNS` filtered
|
||||
// messages when building the routing prompt. Keeps prompt growth
|
||||
// predictable for long conversations regardless of per-message size.
|
||||
let start = filtered.len().saturating_sub(MAX_ROUTING_TURNS);
|
||||
let messages_vec: &[&Message] = &filtered[start..];
|
||||
|
||||
// Ensure the conversation does not exceed the configured token budget.
|
||||
// We use `len() / TOKEN_LENGTH_DIVISOR` as a cheap token estimate to
|
||||
// avoid running a real tokenizer on the hot path.
|
||||
let mut token_count = ARCH_ORCHESTRATOR_V1_SYSTEM_PROMPT.len() / TOKEN_LENGTH_DIVISOR;
|
||||
let mut selected_messages_list_reversed: Vec<&Message> = vec![];
|
||||
let mut selected_messages_list_reversed: Vec<Message> = vec![];
|
||||
for (selected_messsage_count, message) in messages_vec.iter().rev().enumerate() {
|
||||
let message_token_count = message.content.extract_text().len() / TOKEN_LENGTH_DIVISOR;
|
||||
token_count += message_token_count;
|
||||
if token_count > self.max_token_length {
|
||||
let message_text = message.content.extract_text();
|
||||
let message_token_count = message_text.len() / TOKEN_LENGTH_DIVISOR;
|
||||
if token_count + message_token_count > self.max_token_length {
|
||||
let remaining_tokens = self.max_token_length.saturating_sub(token_count);
|
||||
debug!(
|
||||
token_count = token_count,
|
||||
attempted_total_tokens = token_count + message_token_count,
|
||||
max_tokens = self.max_token_length,
|
||||
remaining_tokens,
|
||||
selected = selected_messsage_count,
|
||||
total = messages_vec.len(),
|
||||
"token count exceeds max, truncating conversation"
|
||||
);
|
||||
if message.role == Role::User {
|
||||
// If message that exceeds max token length is from user, we need to keep it
|
||||
selected_messages_list_reversed.push(message);
|
||||
// If the overflow message is from the user we need to keep
|
||||
// some of it so the orchestrator still sees the latest user
|
||||
// intent. Use a middle-trim (head + ellipsis + tail): users
|
||||
// often frame the task at the start AND put the actual ask
|
||||
// at the end of a long pasted block, so preserving both is
|
||||
// better than a head-only cut. The ellipsis also signals to
|
||||
// the router model that content was dropped.
|
||||
if message.role == Role::User && remaining_tokens > 0 {
|
||||
let max_bytes = remaining_tokens.saturating_mul(TOKEN_LENGTH_DIVISOR);
|
||||
let truncated = trim_middle_utf8(&message_text, max_bytes);
|
||||
selected_messages_list_reversed.push(Message {
|
||||
role: Role::User,
|
||||
content: Some(MessageContent::Text(truncated)),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
});
|
||||
}
|
||||
break;
|
||||
}
|
||||
// If we are here, it means that the message is within the max token length
|
||||
selected_messages_list_reversed.push(message);
|
||||
token_count += message_token_count;
|
||||
selected_messages_list_reversed.push(Message {
|
||||
role: message.role.clone(),
|
||||
content: Some(MessageContent::Text(message_text)),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
});
|
||||
}
|
||||
|
||||
if selected_messages_list_reversed.is_empty() {
|
||||
debug!("no messages selected, using last message");
|
||||
if let Some(last_message) = messages_vec.last() {
|
||||
selected_messages_list_reversed.push(last_message);
|
||||
selected_messages_list_reversed.push(Message {
|
||||
role: last_message.role.clone(),
|
||||
content: Some(MessageContent::Text(last_message.content.extract_text())),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -237,22 +283,8 @@ impl OrchestratorModel for OrchestratorModelV1 {
|
|||
}
|
||||
|
||||
// Reverse the selected messages to maintain the conversation order
|
||||
let selected_conversation_list = selected_messages_list_reversed
|
||||
.iter()
|
||||
.rev()
|
||||
.map(|message| Message {
|
||||
role: message.role.clone(),
|
||||
content: Some(MessageContent::Text(
|
||||
message
|
||||
.content
|
||||
.as_ref()
|
||||
.map_or(String::new(), |c| c.to_string()),
|
||||
)),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
})
|
||||
.collect::<Vec<Message>>();
|
||||
let selected_conversation_list: Vec<Message> =
|
||||
selected_messages_list_reversed.into_iter().rev().collect();
|
||||
|
||||
// Generate the orchestrator request message based on the usage preferences.
|
||||
// If preferences are passed in request then we use them;
|
||||
|
|
@ -405,6 +437,45 @@ fn fix_json_response(body: &str) -> String {
|
|||
body.replace("'", "\"").replace("\\n", "")
|
||||
}
|
||||
|
||||
/// Truncate `s` so the result is at most `max_bytes` bytes long, keeping
|
||||
/// roughly 60% from the start and 40% from the end, with a Unicode ellipsis
|
||||
/// separating the two. All splits respect UTF-8 character boundaries. When
|
||||
/// `max_bytes` is too small to fit the marker at all, falls back to a
|
||||
/// head-only truncation.
|
||||
fn trim_middle_utf8(s: &str, max_bytes: usize) -> String {
|
||||
if s.len() <= max_bytes {
|
||||
return s.to_string();
|
||||
}
|
||||
if max_bytes <= TRIM_MARKER.len() {
|
||||
// Not enough room even for the marker — just keep the start.
|
||||
let mut end = max_bytes;
|
||||
while end > 0 && !s.is_char_boundary(end) {
|
||||
end -= 1;
|
||||
}
|
||||
return s[..end].to_string();
|
||||
}
|
||||
|
||||
let available = max_bytes - TRIM_MARKER.len();
|
||||
// Bias toward the start (60%) where task framing typically lives, while
|
||||
// still preserving ~40% of the tail where the user's actual ask often
|
||||
// appears after a long paste.
|
||||
let mut start_len = available * 3 / 5;
|
||||
while start_len > 0 && !s.is_char_boundary(start_len) {
|
||||
start_len -= 1;
|
||||
}
|
||||
let end_len = available - start_len;
|
||||
let mut end_start = s.len().saturating_sub(end_len);
|
||||
while end_start < s.len() && !s.is_char_boundary(end_start) {
|
||||
end_start += 1;
|
||||
}
|
||||
|
||||
let mut out = String::with_capacity(start_len + TRIM_MARKER.len() + (s.len() - end_start));
|
||||
out.push_str(&s[..start_len]);
|
||||
out.push_str(TRIM_MARKER);
|
||||
out.push_str(&s[end_start..]);
|
||||
out
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for dyn OrchestratorModel {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "OrchestratorModel")
|
||||
|
|
@ -777,6 +848,10 @@ If no routes are needed, return an empty list for `route`.
|
|||
|
||||
#[test]
|
||||
fn test_conversation_trim_upto_user_message() {
|
||||
// With max_token_length=230, the older user message "given the image
|
||||
// In style of Andy Warhol" overflows the remaining budget and gets
|
||||
// middle-trimmed (head + ellipsis + tail) until it fits. Newer turns
|
||||
// are kept in full.
|
||||
let expected_prompt = r#"
|
||||
You are a helpful assistant that selects the most suitable routes based on user intent.
|
||||
You are provided with a list of available routes enclosed within <routes></routes> XML tags:
|
||||
|
|
@ -789,7 +864,7 @@ You are also given the conversation context enclosed within <conversation></conv
|
|||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "given the image In style of Andy Warhol"
|
||||
"content": "given…rhol"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
|
|
@ -862,6 +937,190 @@ If no routes are needed, return an empty list for `route`.
|
|||
assert_eq!(expected_prompt, prompt);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_huge_single_user_message_is_middle_trimmed() {
|
||||
// Regression test for the case where a single, extremely large user
|
||||
// message was being passed to the orchestrator verbatim and blowing
|
||||
// past the upstream model's context window. The trimmer must now
|
||||
// middle-trim (head + ellipsis + tail) the oversized message so the
|
||||
// resulting request stays within the configured budget, and the
|
||||
// trim marker must be present so the router model knows content
|
||||
// was dropped.
|
||||
let orchestrations_str = r#"
|
||||
{
|
||||
"gpt-4o": [
|
||||
{"name": "Image generation", "description": "generating image"}
|
||||
]
|
||||
}
|
||||
"#;
|
||||
let agent_orchestrations = serde_json::from_str::<
|
||||
HashMap<String, Vec<OrchestrationPreference>>,
|
||||
>(orchestrations_str)
|
||||
.unwrap();
|
||||
|
||||
let max_token_length = 2048;
|
||||
let orchestrator = OrchestratorModelV1::new(
|
||||
agent_orchestrations,
|
||||
"test-model".to_string(),
|
||||
max_token_length,
|
||||
);
|
||||
|
||||
// ~500KB of content — same scale as the real payload that triggered
|
||||
// the production upstream 400.
|
||||
let head = "HEAD_MARKER_START ";
|
||||
let tail = " TAIL_MARKER_END";
|
||||
let filler = "A".repeat(500_000);
|
||||
let huge_user_content = format!("{head}{filler}{tail}");
|
||||
|
||||
let conversation = vec![Message {
|
||||
role: Role::User,
|
||||
content: Some(MessageContent::Text(huge_user_content.clone())),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
}];
|
||||
|
||||
let req = orchestrator.generate_request(&conversation, &None);
|
||||
let prompt = req.messages[0].content.extract_text();
|
||||
|
||||
// Prompt must stay bounded. Generous ceiling = budget-in-bytes +
|
||||
// scaffolding + slack. Real result should be well under this.
|
||||
let byte_ceiling = max_token_length * TOKEN_LENGTH_DIVISOR
|
||||
+ ARCH_ORCHESTRATOR_V1_SYSTEM_PROMPT.len()
|
||||
+ 1024;
|
||||
assert!(
|
||||
prompt.len() < byte_ceiling,
|
||||
"prompt length {} exceeded ceiling {} — truncation did not apply",
|
||||
prompt.len(),
|
||||
byte_ceiling,
|
||||
);
|
||||
|
||||
// Not all 500k filler chars survive.
|
||||
let a_count = prompt.chars().filter(|c| *c == 'A').count();
|
||||
assert!(
|
||||
a_count < filler.len(),
|
||||
"expected user message to be truncated; all {} 'A's survived",
|
||||
a_count
|
||||
);
|
||||
assert!(
|
||||
a_count > 0,
|
||||
"expected some of the user message to survive truncation"
|
||||
);
|
||||
|
||||
// Head and tail of the message must both be preserved (that's the
|
||||
// whole point of middle-trim over head-only).
|
||||
assert!(
|
||||
prompt.contains(head),
|
||||
"head marker missing — head was not preserved"
|
||||
);
|
||||
assert!(
|
||||
prompt.contains(tail),
|
||||
"tail marker missing — tail was not preserved"
|
||||
);
|
||||
|
||||
// Trim marker must be present so the router model can see that
|
||||
// content was omitted.
|
||||
assert!(
|
||||
prompt.contains(TRIM_MARKER),
|
||||
"ellipsis trim marker missing from truncated prompt"
|
||||
);
|
||||
|
||||
// Routing prompt scaffolding remains intact.
|
||||
assert!(prompt.contains("<conversation>"));
|
||||
assert!(prompt.contains("<routes>"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_turn_cap_limits_routing_history() {
|
||||
// The outer turn-cap guardrail should keep only the last
|
||||
// `MAX_ROUTING_TURNS` filtered messages regardless of how long the
|
||||
// conversation is. We build a conversation with alternating
|
||||
// user/assistant turns tagged with their index and verify that only
|
||||
// the tail of the conversation makes it into the prompt.
|
||||
let orchestrations_str = r#"
|
||||
{
|
||||
"gpt-4o": [
|
||||
{"name": "Image generation", "description": "generating image"}
|
||||
]
|
||||
}
|
||||
"#;
|
||||
let agent_orchestrations = serde_json::from_str::<
|
||||
HashMap<String, Vec<OrchestrationPreference>>,
|
||||
>(orchestrations_str)
|
||||
.unwrap();
|
||||
|
||||
let orchestrator =
|
||||
OrchestratorModelV1::new(agent_orchestrations, "test-model".to_string(), usize::MAX);
|
||||
|
||||
let mut conversation: Vec<Message> = Vec::new();
|
||||
let total_turns = MAX_ROUTING_TURNS * 2; // well past the cap
|
||||
for i in 0..total_turns {
|
||||
let role = if i % 2 == 0 {
|
||||
Role::User
|
||||
} else {
|
||||
Role::Assistant
|
||||
};
|
||||
conversation.push(Message {
|
||||
role,
|
||||
content: Some(MessageContent::Text(format!("turn-{i:03}"))),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
});
|
||||
}
|
||||
|
||||
let req = orchestrator.generate_request(&conversation, &None);
|
||||
let prompt = req.messages[0].content.extract_text();
|
||||
|
||||
// The last MAX_ROUTING_TURNS messages (indexes total-cap..total)
|
||||
// must all appear.
|
||||
for i in (total_turns - MAX_ROUTING_TURNS)..total_turns {
|
||||
let tag = format!("turn-{i:03}");
|
||||
assert!(
|
||||
prompt.contains(&tag),
|
||||
"expected recent turn tag {tag} to be present"
|
||||
);
|
||||
}
|
||||
|
||||
// And earlier turns (indexes 0..total-cap) must all be dropped.
|
||||
for i in 0..(total_turns - MAX_ROUTING_TURNS) {
|
||||
let tag = format!("turn-{i:03}");
|
||||
assert!(
|
||||
!prompt.contains(&tag),
|
||||
"old turn tag {tag} leaked past turn cap into the prompt"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_trim_middle_utf8_helper() {
|
||||
// No-op when already small enough.
|
||||
assert_eq!(trim_middle_utf8("hello", 100), "hello");
|
||||
assert_eq!(trim_middle_utf8("hello", 5), "hello");
|
||||
|
||||
// 60/40 split with ellipsis when too long.
|
||||
let long = "a".repeat(20);
|
||||
let out = trim_middle_utf8(&long, 10);
|
||||
assert!(out.len() <= 10);
|
||||
assert!(out.contains(TRIM_MARKER));
|
||||
// Exactly one ellipsis, rest are 'a's.
|
||||
assert_eq!(out.matches(TRIM_MARKER).count(), 1);
|
||||
assert!(out.chars().filter(|c| *c == 'a').count() > 0);
|
||||
|
||||
// When max_bytes is smaller than the marker, falls back to
|
||||
// head-only truncation (no marker).
|
||||
let out = trim_middle_utf8("abcdefgh", 2);
|
||||
assert_eq!(out, "ab");
|
||||
|
||||
// UTF-8 boundary safety: 2-byte chars.
|
||||
let s = "é".repeat(50); // 100 bytes
|
||||
let out = trim_middle_utf8(&s, 25);
|
||||
assert!(out.len() <= 25);
|
||||
// Must still be valid UTF-8 that only contains 'é' and the marker.
|
||||
let ok = out.chars().all(|c| c == 'é' || c == '…');
|
||||
assert!(ok, "unexpected char in trimmed output: {out:?}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_non_text_input() {
|
||||
let expected_prompt = r#"
|
||||
|
|
|
|||
|
|
@ -1,39 +0,0 @@
|
|||
use hermesllm::apis::openai::{ChatCompletionsRequest, Message};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum RoutingModelError {
|
||||
#[error("Failed to parse JSON: {0}")]
|
||||
JsonError(#[from] serde_json::Error),
|
||||
}
|
||||
|
||||
pub type Result<T> = std::result::Result<T, RoutingModelError>;
|
||||
|
||||
/// Internal route descriptor passed to the router model to build its prompt.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct RoutingPreference {
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
}
|
||||
|
||||
/// Groups a model with its routing preferences (used internally by RouterModelV1).
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ModelUsagePreference {
|
||||
pub model: String,
|
||||
pub routing_preferences: Vec<RoutingPreference>,
|
||||
}
|
||||
|
||||
pub trait RouterModel: Send + Sync {
|
||||
fn generate_request(
|
||||
&self,
|
||||
messages: &[Message],
|
||||
usage_preferences: &Option<Vec<ModelUsagePreference>>,
|
||||
) -> ChatCompletionsRequest;
|
||||
fn parse_response(
|
||||
&self,
|
||||
content: &str,
|
||||
usage_preferences: &Option<Vec<ModelUsagePreference>>,
|
||||
) -> Result<Option<(String, String)>>;
|
||||
fn get_model_name(&self) -> String;
|
||||
}
|
||||
|
|
@ -1,842 +0,0 @@
|
|||
use std::collections::HashMap;
|
||||
|
||||
use super::router_model::{ModelUsagePreference, RoutingPreference};
|
||||
use hermesllm::apis::openai::{ChatCompletionsRequest, Message, MessageContent, Role};
|
||||
use hermesllm::transforms::lib::ExtractText;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tracing::{debug, warn};
|
||||
|
||||
use super::router_model::{RouterModel, RoutingModelError};
|
||||
|
||||
pub const MAX_TOKEN_LEN: usize = 2048; // Default max token length for the routing model
|
||||
pub const ARCH_ROUTER_V1_SYSTEM_PROMPT: &str = r#"
|
||||
You are a helpful assistant designed to find the best suited route.
|
||||
You are provided with route description within <routes></routes> XML tags:
|
||||
<routes>
|
||||
{routes}
|
||||
</routes>
|
||||
|
||||
<conversation>
|
||||
{conversation}
|
||||
</conversation>
|
||||
|
||||
Your task is to decide which route is best suit with user intent on the conversation in <conversation></conversation> XML tags. Follow the instruction:
|
||||
1. If the latest intent from user is irrelevant or user intent is full filled, response with other route {"route": "other"}.
|
||||
2. You must analyze the route descriptions and find the best match route for user latest intent.
|
||||
3. You only response the name of the route that best matches the user's request, use the exact name in the <routes></routes>.
|
||||
|
||||
Based on your analysis, provide your response in the following JSON formats if you decide to match any route:
|
||||
{"route": "route_name"}
|
||||
"#;
|
||||
|
||||
pub type Result<T> = std::result::Result<T, RoutingModelError>;
|
||||
pub struct RouterModelV1 {
|
||||
llm_route_json_str: String,
|
||||
llm_route_to_model_map: HashMap<String, String>,
|
||||
routing_model: String,
|
||||
max_token_length: usize,
|
||||
}
|
||||
impl RouterModelV1 {
|
||||
pub fn new(
|
||||
llm_routes: HashMap<String, Vec<RoutingPreference>>,
|
||||
routing_model: String,
|
||||
max_token_length: usize,
|
||||
) -> Self {
|
||||
let llm_route_values: Vec<RoutingPreference> =
|
||||
llm_routes.values().flatten().cloned().collect();
|
||||
let llm_route_json_str =
|
||||
serde_json::to_string(&llm_route_values).unwrap_or_else(|_| "[]".to_string());
|
||||
let llm_route_to_model_map: HashMap<String, String> = llm_routes
|
||||
.iter()
|
||||
.flat_map(|(model, prefs)| prefs.iter().map(|pref| (pref.name.clone(), model.clone())))
|
||||
.collect();
|
||||
|
||||
RouterModelV1 {
|
||||
routing_model,
|
||||
max_token_length,
|
||||
llm_route_json_str,
|
||||
llm_route_to_model_map,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct LlmRouterResponse {
|
||||
pub route: Option<String>,
|
||||
}
|
||||
|
||||
const TOKEN_LENGTH_DIVISOR: usize = 4; // Approximate token length divisor for UTF-8 characters
|
||||
|
||||
impl RouterModel for RouterModelV1 {
|
||||
fn generate_request(
|
||||
&self,
|
||||
messages: &[Message],
|
||||
usage_preferences_from_request: &Option<Vec<ModelUsagePreference>>,
|
||||
) -> ChatCompletionsRequest {
|
||||
// remove system prompt, tool calls, tool call response and messages without content
|
||||
// if content is empty its likely a tool call
|
||||
// when role == tool its tool call response
|
||||
let messages_vec = messages
|
||||
.iter()
|
||||
.filter(|m| {
|
||||
m.role != Role::System
|
||||
&& m.role != Role::Developer
|
||||
&& m.role != Role::Tool
|
||||
&& !m.content.extract_text().is_empty()
|
||||
})
|
||||
.collect::<Vec<&Message>>();
|
||||
|
||||
// Following code is to ensure that the conversation does not exceed max token length
|
||||
// Note: we use a simple heuristic to estimate token count based on character length to optimize for performance
|
||||
let mut token_count = ARCH_ROUTER_V1_SYSTEM_PROMPT.len() / TOKEN_LENGTH_DIVISOR;
|
||||
let mut selected_messages_list_reversed: Vec<&Message> = vec![];
|
||||
for (selected_messsage_count, message) in messages_vec.iter().rev().enumerate() {
|
||||
let message_token_count = message.content.extract_text().len() / TOKEN_LENGTH_DIVISOR;
|
||||
token_count += message_token_count;
|
||||
if token_count > self.max_token_length {
|
||||
debug!(
|
||||
token_count = token_count,
|
||||
max_tokens = self.max_token_length,
|
||||
selected = selected_messsage_count,
|
||||
total = messages_vec.len(),
|
||||
"token count exceeds max, truncating conversation"
|
||||
);
|
||||
if message.role == Role::User {
|
||||
// If message that exceeds max token length is from user, we need to keep it
|
||||
selected_messages_list_reversed.push(message);
|
||||
}
|
||||
break;
|
||||
}
|
||||
// If we are here, it means that the message is within the max token length
|
||||
selected_messages_list_reversed.push(message);
|
||||
}
|
||||
|
||||
if selected_messages_list_reversed.is_empty() {
|
||||
debug!("no messages selected, using last message");
|
||||
if let Some(last_message) = messages_vec.last() {
|
||||
selected_messages_list_reversed.push(last_message);
|
||||
}
|
||||
}
|
||||
|
||||
// ensure that first and last selected message is from user
|
||||
if let Some(first_message) = selected_messages_list_reversed.first() {
|
||||
if first_message.role != Role::User {
|
||||
warn!("last message is not from user, may lead to incorrect routing");
|
||||
}
|
||||
}
|
||||
if let Some(last_message) = selected_messages_list_reversed.last() {
|
||||
if last_message.role != Role::User {
|
||||
warn!("first message is not from user, may lead to incorrect routing");
|
||||
}
|
||||
}
|
||||
|
||||
// Reverse the selected messages to maintain the conversation order
|
||||
let selected_conversation_list = selected_messages_list_reversed
|
||||
.iter()
|
||||
.rev()
|
||||
.map(|message| {
|
||||
Message {
|
||||
role: message.role.clone(),
|
||||
// we can unwrap here because we have already filtered out messages without content
|
||||
content: Some(MessageContent::Text(
|
||||
message
|
||||
.content
|
||||
.as_ref()
|
||||
.map_or(String::new(), |c| c.to_string()),
|
||||
)),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
}
|
||||
})
|
||||
.collect::<Vec<Message>>();
|
||||
|
||||
// Generate the router request message based on the usage preferences.
|
||||
// If preferences are passed in request then we use them otherwise we use the default routing model preferences.
|
||||
let router_message = match convert_to_router_preferences(usage_preferences_from_request) {
|
||||
Some(prefs) => generate_router_message(&prefs, &selected_conversation_list),
|
||||
None => generate_router_message(&self.llm_route_json_str, &selected_conversation_list),
|
||||
};
|
||||
|
||||
ChatCompletionsRequest {
|
||||
model: self.routing_model.clone(),
|
||||
messages: vec![Message {
|
||||
content: Some(MessageContent::Text(router_message)),
|
||||
role: Role::User,
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
}],
|
||||
temperature: Some(0.01),
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_response(
|
||||
&self,
|
||||
content: &str,
|
||||
usage_preferences: &Option<Vec<ModelUsagePreference>>,
|
||||
) -> Result<Option<(String, String)>> {
|
||||
if content.is_empty() {
|
||||
return Ok(None);
|
||||
}
|
||||
let router_resp_fixed = fix_json_response(content);
|
||||
let router_response: LlmRouterResponse = serde_json::from_str(router_resp_fixed.as_str())?;
|
||||
|
||||
let selected_route = router_response.route.unwrap_or_default().to_string();
|
||||
|
||||
if selected_route.is_empty() || selected_route == "other" {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
if let Some(usage_preferences) = usage_preferences {
|
||||
// If usage preferences are defined, we need to find the model that matches the selected route
|
||||
let model_name: Option<String> = usage_preferences
|
||||
.iter()
|
||||
.map(|pref| {
|
||||
pref.routing_preferences
|
||||
.iter()
|
||||
.find(|routing_pref| routing_pref.name == selected_route)
|
||||
.map(|_| pref.model.clone())
|
||||
})
|
||||
.find_map(|model| model);
|
||||
|
||||
if let Some(model_name) = model_name {
|
||||
return Ok(Some((selected_route, model_name)));
|
||||
} else {
|
||||
warn!(
|
||||
route = %selected_route,
|
||||
preferences = ?usage_preferences,
|
||||
"no matching model found for route"
|
||||
);
|
||||
return Ok(None);
|
||||
}
|
||||
}
|
||||
|
||||
// If no usage preferences are passed in request then use the default routing model preferences
|
||||
if let Some(model) = self.llm_route_to_model_map.get(&selected_route).cloned() {
|
||||
return Ok(Some((selected_route, model)));
|
||||
}
|
||||
|
||||
warn!(
|
||||
route = %selected_route,
|
||||
preferences = ?self.llm_route_to_model_map,
|
||||
"no model found for route"
|
||||
);
|
||||
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
fn get_model_name(&self) -> String {
|
||||
self.routing_model.clone()
|
||||
}
|
||||
}
|
||||
|
||||
fn generate_router_message(prefs: &str, selected_conversation_list: &Vec<Message>) -> String {
|
||||
ARCH_ROUTER_V1_SYSTEM_PROMPT
|
||||
.replace("{routes}", prefs)
|
||||
.replace(
|
||||
"{conversation}",
|
||||
&serde_json::to_string(&selected_conversation_list).unwrap_or_default(),
|
||||
)
|
||||
}
|
||||
|
||||
fn convert_to_router_preferences(
|
||||
prefs_from_request: &Option<Vec<ModelUsagePreference>>,
|
||||
) -> Option<String> {
|
||||
if let Some(usage_preferences) = prefs_from_request {
|
||||
let routing_preferences = usage_preferences
|
||||
.iter()
|
||||
.flat_map(|pref| {
|
||||
pref.routing_preferences
|
||||
.iter()
|
||||
.map(|routing_pref| RoutingPreference {
|
||||
name: routing_pref.name.clone(),
|
||||
description: routing_pref.description.clone(),
|
||||
})
|
||||
})
|
||||
.collect::<Vec<RoutingPreference>>();
|
||||
|
||||
return Some(serde_json::to_string(&routing_preferences).unwrap_or_default());
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
fn fix_json_response(body: &str) -> String {
|
||||
let mut updated_body = body.to_string();
|
||||
|
||||
updated_body = updated_body.replace("'", "\"");
|
||||
|
||||
if updated_body.contains("\\n") {
|
||||
updated_body = updated_body.replace("\\n", "");
|
||||
}
|
||||
|
||||
if updated_body.starts_with("```json") {
|
||||
updated_body = updated_body
|
||||
.strip_prefix("```json")
|
||||
.unwrap_or(&updated_body)
|
||||
.to_string();
|
||||
}
|
||||
|
||||
if updated_body.ends_with("```") {
|
||||
updated_body = updated_body
|
||||
.strip_suffix("```")
|
||||
.unwrap_or(&updated_body)
|
||||
.to_string();
|
||||
}
|
||||
|
||||
updated_body
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for dyn RouterModel {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "RouterModel")
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
#[test]
|
||||
fn test_system_prompt_format() {
|
||||
let expected_prompt = r#"
|
||||
You are a helpful assistant designed to find the best suited route.
|
||||
You are provided with route description within <routes></routes> XML tags:
|
||||
<routes>
|
||||
[{"name":"Image generation","description":"generating image"}]
|
||||
</routes>
|
||||
|
||||
<conversation>
|
||||
[{"role":"user","content":"hi"},{"role":"assistant","content":"Hello! How can I assist you today?"},{"role":"user","content":"given the image In style of Andy Warhol, portrait of Bart and Lisa Simpson"}]
|
||||
</conversation>
|
||||
|
||||
Your task is to decide which route is best suit with user intent on the conversation in <conversation></conversation> XML tags. Follow the instruction:
|
||||
1. If the latest intent from user is irrelevant or user intent is full filled, response with other route {"route": "other"}.
|
||||
2. You must analyze the route descriptions and find the best match route for user latest intent.
|
||||
3. You only response the name of the route that best matches the user's request, use the exact name in the <routes></routes>.
|
||||
|
||||
Based on your analysis, provide your response in the following JSON formats if you decide to match any route:
|
||||
{"route": "route_name"}
|
||||
"#;
|
||||
let routes_str = r#"
|
||||
{
|
||||
"gpt-4o": [
|
||||
{"name": "Image generation", "description": "generating image"}
|
||||
]
|
||||
}
|
||||
"#;
|
||||
let llm_routes =
|
||||
serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
|
||||
let routing_model = "test-model".to_string();
|
||||
let router = RouterModelV1::new(llm_routes, routing_model, usize::MAX);
|
||||
|
||||
let conversation_str = r#"
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "hi"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Hello! How can I assist you today?"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "given the image In style of Andy Warhol, portrait of Bart and Lisa Simpson"
|
||||
}
|
||||
]
|
||||
"#;
|
||||
let conversation: Vec<Message> = serde_json::from_str(conversation_str).unwrap();
|
||||
|
||||
let req = router.generate_request(&conversation, &None);
|
||||
|
||||
let prompt = req.messages[0].content.extract_text();
|
||||
|
||||
assert_eq!(expected_prompt, prompt);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_system_prompt_format_usage_preferences() {
|
||||
let expected_prompt = r#"
|
||||
You are a helpful assistant designed to find the best suited route.
|
||||
You are provided with route description within <routes></routes> XML tags:
|
||||
<routes>
|
||||
[{"name":"code-generation","description":"generating new code snippets, functions, or boilerplate based on user prompts or requirements"}]
|
||||
</routes>
|
||||
|
||||
<conversation>
|
||||
[{"role":"user","content":"hi"},{"role":"assistant","content":"Hello! How can I assist you today?"},{"role":"user","content":"given the image In style of Andy Warhol, portrait of Bart and Lisa Simpson"}]
|
||||
</conversation>
|
||||
|
||||
Your task is to decide which route is best suit with user intent on the conversation in <conversation></conversation> XML tags. Follow the instruction:
|
||||
1. If the latest intent from user is irrelevant or user intent is full filled, response with other route {"route": "other"}.
|
||||
2. You must analyze the route descriptions and find the best match route for user latest intent.
|
||||
3. You only response the name of the route that best matches the user's request, use the exact name in the <routes></routes>.
|
||||
|
||||
Based on your analysis, provide your response in the following JSON formats if you decide to match any route:
|
||||
{"route": "route_name"}
|
||||
"#;
|
||||
let routes_str = r#"
|
||||
{
|
||||
"gpt-4o": [
|
||||
{"name": "Image generation", "description": "generating image"}
|
||||
]
|
||||
}
|
||||
"#;
|
||||
let llm_routes =
|
||||
serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
|
||||
let routing_model = "test-model".to_string();
|
||||
let router = RouterModelV1::new(llm_routes, routing_model, usize::MAX);
|
||||
|
||||
let conversation_str = r#"
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "hi"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Hello! How can I assist you today?"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "given the image In style of Andy Warhol, portrait of Bart and Lisa Simpson"
|
||||
}
|
||||
]
|
||||
"#;
|
||||
let conversation: Vec<Message> = serde_json::from_str(conversation_str).unwrap();
|
||||
|
||||
let usage_preferences = Some(vec![ModelUsagePreference {
|
||||
model: "claude/claude-3-7-sonnet".to_string(),
|
||||
routing_preferences: vec![RoutingPreference {
|
||||
name: "code-generation".to_string(),
|
||||
description: "generating new code snippets, functions, or boilerplate based on user prompts or requirements".to_string(),
|
||||
}],
|
||||
}]);
|
||||
let req = router.generate_request(&conversation, &usage_preferences);
|
||||
|
||||
let prompt = req.messages[0].content.extract_text();
|
||||
|
||||
assert_eq!(expected_prompt, prompt);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conversation_exceed_token_count() {
|
||||
let expected_prompt = r#"
|
||||
You are a helpful assistant designed to find the best suited route.
|
||||
You are provided with route description within <routes></routes> XML tags:
|
||||
<routes>
|
||||
[{"name":"Image generation","description":"generating image"}]
|
||||
</routes>
|
||||
|
||||
<conversation>
|
||||
[{"role":"user","content":"given the image In style of Andy Warhol, portrait of Bart and Lisa Simpson"}]
|
||||
</conversation>
|
||||
|
||||
Your task is to decide which route is best suit with user intent on the conversation in <conversation></conversation> XML tags. Follow the instruction:
|
||||
1. If the latest intent from user is irrelevant or user intent is full filled, response with other route {"route": "other"}.
|
||||
2. You must analyze the route descriptions and find the best match route for user latest intent.
|
||||
3. You only response the name of the route that best matches the user's request, use the exact name in the <routes></routes>.
|
||||
|
||||
Based on your analysis, provide your response in the following JSON formats if you decide to match any route:
|
||||
{"route": "route_name"}
|
||||
"#;
|
||||
|
||||
let routes_str = r#"
|
||||
{
|
||||
"gpt-4o": [
|
||||
{"name": "Image generation", "description": "generating image"}
|
||||
]
|
||||
}
|
||||
"#;
|
||||
let llm_routes =
|
||||
serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
|
||||
let routing_model = "test-model".to_string();
|
||||
let router = RouterModelV1::new(llm_routes, routing_model, 235);
|
||||
|
||||
let conversation_str = r#"
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "hi"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Hello! How can I assist you today?"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "given the image In style of Andy Warhol, portrait of Bart and Lisa Simpson"
|
||||
}
|
||||
]
|
||||
"#;
|
||||
|
||||
let conversation: Vec<Message> = serde_json::from_str(conversation_str).unwrap();
|
||||
|
||||
let req = router.generate_request(&conversation, &None);
|
||||
|
||||
let prompt = req.messages[0].content.extract_text();
|
||||
|
||||
assert_eq!(expected_prompt, prompt);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conversation_exceed_token_count_large_single_message() {
|
||||
let expected_prompt = r#"
|
||||
You are a helpful assistant designed to find the best suited route.
|
||||
You are provided with route description within <routes></routes> XML tags:
|
||||
<routes>
|
||||
[{"name":"Image generation","description":"generating image"}]
|
||||
</routes>
|
||||
|
||||
<conversation>
|
||||
[{"role":"user","content":"given the image In style of Andy Warhol, portrait of Bart and Lisa Simpson and this is a very long message that exceeds the max token length of the routing model, so it should be truncated and only the last user message should be included in the conversation for routing."}]
|
||||
</conversation>
|
||||
|
||||
Your task is to decide which route is best suit with user intent on the conversation in <conversation></conversation> XML tags. Follow the instruction:
|
||||
1. If the latest intent from user is irrelevant or user intent is full filled, response with other route {"route": "other"}.
|
||||
2. You must analyze the route descriptions and find the best match route for user latest intent.
|
||||
3. You only response the name of the route that best matches the user's request, use the exact name in the <routes></routes>.
|
||||
|
||||
Based on your analysis, provide your response in the following JSON formats if you decide to match any route:
|
||||
{"route": "route_name"}
|
||||
"#;
|
||||
|
||||
let routes_str = r#"
|
||||
{
|
||||
"gpt-4o": [
|
||||
{"name": "Image generation", "description": "generating image"}
|
||||
]
|
||||
}
|
||||
"#;
|
||||
let llm_routes =
|
||||
serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
|
||||
|
||||
let routing_model = "test-model".to_string();
|
||||
let router = RouterModelV1::new(llm_routes, routing_model, 200);
|
||||
|
||||
let conversation_str = r#"
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "hi"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Hello! How can I assist you today?"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "given the image In style of Andy Warhol, portrait of Bart and Lisa Simpson and this is a very long message that exceeds the max token length of the routing model, so it should be truncated and only the last user message should be included in the conversation for routing."
|
||||
}
|
||||
]
|
||||
"#;
|
||||
|
||||
let conversation: Vec<Message> = serde_json::from_str(conversation_str).unwrap();
|
||||
|
||||
let req = router.generate_request(&conversation, &None);
|
||||
|
||||
let prompt = req.messages[0].content.extract_text();
|
||||
|
||||
assert_eq!(expected_prompt, prompt);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conversation_trim_upto_user_message() {
|
||||
let expected_prompt = r#"
|
||||
You are a helpful assistant designed to find the best suited route.
|
||||
You are provided with route description within <routes></routes> XML tags:
|
||||
<routes>
|
||||
[{"name":"Image generation","description":"generating image"}]
|
||||
</routes>
|
||||
|
||||
<conversation>
|
||||
[{"role":"user","content":"given the image In style of Andy Warhol"},{"role":"assistant","content":"ok here is the image"},{"role":"user","content":"pls give me another image about Bart and Lisa"}]
|
||||
</conversation>
|
||||
|
||||
Your task is to decide which route is best suit with user intent on the conversation in <conversation></conversation> XML tags. Follow the instruction:
|
||||
1. If the latest intent from user is irrelevant or user intent is full filled, response with other route {"route": "other"}.
|
||||
2. You must analyze the route descriptions and find the best match route for user latest intent.
|
||||
3. You only response the name of the route that best matches the user's request, use the exact name in the <routes></routes>.
|
||||
|
||||
Based on your analysis, provide your response in the following JSON formats if you decide to match any route:
|
||||
{"route": "route_name"}
|
||||
"#;
|
||||
|
||||
let routes_str = r#"
|
||||
{
|
||||
"gpt-4o": [
|
||||
{"name": "Image generation", "description": "generating image"}
|
||||
]
|
||||
}
|
||||
"#;
|
||||
let llm_routes =
|
||||
serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
|
||||
let routing_model = "test-model".to_string();
|
||||
let router = RouterModelV1::new(llm_routes, routing_model, 230);
|
||||
|
||||
let conversation_str = r#"
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "hi"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Hello! How can I assist you today?"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "given the image In style of Andy Warhol"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "ok here is the image"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "pls give me another image about Bart and Lisa"
|
||||
}
|
||||
]
|
||||
"#;
|
||||
|
||||
let conversation: Vec<Message> = serde_json::from_str(conversation_str).unwrap();
|
||||
|
||||
let req = router.generate_request(&conversation, &None);
|
||||
|
||||
let prompt = req.messages[0].content.extract_text();
|
||||
|
||||
assert_eq!(expected_prompt, prompt);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_non_text_input() {
|
||||
let expected_prompt = r#"
|
||||
You are a helpful assistant designed to find the best suited route.
|
||||
You are provided with route description within <routes></routes> XML tags:
|
||||
<routes>
|
||||
[{"name":"Image generation","description":"generating image"}]
|
||||
</routes>
|
||||
|
||||
<conversation>
|
||||
[{"role":"user","content":"hi"},{"role":"assistant","content":"Hello! How can I assist you today?"},{"role":"user","content":"given the image In style of Andy Warhol, portrait of Bart and Lisa Simpson"}]
|
||||
</conversation>
|
||||
|
||||
Your task is to decide which route is best suit with user intent on the conversation in <conversation></conversation> XML tags. Follow the instruction:
|
||||
1. If the latest intent from user is irrelevant or user intent is full filled, response with other route {"route": "other"}.
|
||||
2. You must analyze the route descriptions and find the best match route for user latest intent.
|
||||
3. You only response the name of the route that best matches the user's request, use the exact name in the <routes></routes>.
|
||||
|
||||
Based on your analysis, provide your response in the following JSON formats if you decide to match any route:
|
||||
{"route": "route_name"}
|
||||
"#;
|
||||
let routes_str = r#"
|
||||
{
|
||||
"gpt-4o": [
|
||||
{"name": "Image generation", "description": "generating image"}
|
||||
]
|
||||
}
|
||||
"#;
|
||||
let llm_routes =
|
||||
serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
|
||||
let routing_model = "test-model".to_string();
|
||||
let router = RouterModelV1::new(llm_routes, routing_model, usize::MAX);
|
||||
|
||||
let conversation_str = r#"
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "hi"
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "https://example.com/image.png"
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Hello! How can I assist you today?"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "given the image In style of Andy Warhol, portrait of Bart and Lisa Simpson"
|
||||
}
|
||||
]
|
||||
"#;
|
||||
let conversation: Vec<Message> = serde_json::from_str(conversation_str).unwrap();
|
||||
|
||||
let req = router.generate_request(&conversation, &None);
|
||||
|
||||
let prompt = req.messages[0].content.extract_text();
|
||||
|
||||
assert_eq!(expected_prompt, prompt);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_skip_tool_call() {
|
||||
let expected_prompt = r#"
|
||||
You are a helpful assistant designed to find the best suited route.
|
||||
You are provided with route description within <routes></routes> XML tags:
|
||||
<routes>
|
||||
[{"name":"Image generation","description":"generating image"}]
|
||||
</routes>
|
||||
|
||||
<conversation>
|
||||
[{"role":"user","content":"What's the weather like in Tokyo?"},{"role":"assistant","content":"The current weather in Tokyo is 22°C and sunny."},{"role":"user","content":"What about in New York?"}]
|
||||
</conversation>
|
||||
|
||||
Your task is to decide which route is best suit with user intent on the conversation in <conversation></conversation> XML tags. Follow the instruction:
|
||||
1. If the latest intent from user is irrelevant or user intent is full filled, response with other route {"route": "other"}.
|
||||
2. You must analyze the route descriptions and find the best match route for user latest intent.
|
||||
3. You only response the name of the route that best matches the user's request, use the exact name in the <routes></routes>.
|
||||
|
||||
Based on your analysis, provide your response in the following JSON formats if you decide to match any route:
|
||||
{"route": "route_name"}
|
||||
"#;
|
||||
let routes_str = r#"
|
||||
{
|
||||
"gpt-4o": [
|
||||
{"name": "Image generation", "description": "generating image"}
|
||||
]
|
||||
}
|
||||
"#;
|
||||
let llm_routes =
|
||||
serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
|
||||
let routing_model = "test-model".to_string();
|
||||
let router = RouterModelV1::new(llm_routes, routing_model, usize::MAX);
|
||||
|
||||
let conversation_str = r#"
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What's the weather like in Tokyo?"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "toolcall-abc123",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": "{ \"location\": \"Tokyo\" }"
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": "toolcall-abc123",
|
||||
"content": "{ \"temperature\": \"22°C\", \"condition\": \"Sunny\" }"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "The current weather in Tokyo is 22°C and sunny."
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What about in New York?"
|
||||
}
|
||||
]
|
||||
"#;
|
||||
|
||||
// expects conversation to look like this
|
||||
|
||||
// [
|
||||
// {
|
||||
// "role": "user",
|
||||
// "content": "What's the weather like in Tokyo?"
|
||||
// },
|
||||
// {
|
||||
// "role": "assistant",
|
||||
// "content": "The current weather in Tokyo is 22°C and sunny."
|
||||
// },
|
||||
// {
|
||||
// "role": "user",
|
||||
// "content": "What about in New York?"
|
||||
// }
|
||||
// ]
|
||||
|
||||
let conversation: Vec<Message> = serde_json::from_str(conversation_str).unwrap();
|
||||
|
||||
let req: ChatCompletionsRequest = router.generate_request(&conversation, &None);
|
||||
|
||||
let prompt = req.messages[0].content.extract_text();
|
||||
|
||||
assert_eq!(expected_prompt, prompt);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_response() {
|
||||
let routes_str = r#"
|
||||
{
|
||||
"gpt-4o": [
|
||||
{"name": "Image generation", "description": "generating image"}
|
||||
]
|
||||
}
|
||||
"#;
|
||||
let llm_routes =
|
||||
serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
|
||||
|
||||
let router = RouterModelV1::new(llm_routes, "test-model".to_string(), 2000);
|
||||
|
||||
// Case 1: Valid JSON with non-empty route
|
||||
let input = r#"{"route": "Image generation"}"#;
|
||||
let result = router.parse_response(input, &None).unwrap();
|
||||
assert_eq!(
|
||||
result,
|
||||
Some(("Image generation".to_string(), "gpt-4o".to_string()))
|
||||
);
|
||||
|
||||
// Case 2: Valid JSON with empty route
|
||||
let input = r#"{"route": ""}"#;
|
||||
let result = router.parse_response(input, &None).unwrap();
|
||||
assert_eq!(result, None);
|
||||
|
||||
// Case 3: Valid JSON with null route
|
||||
let input = r#"{"route": null}"#;
|
||||
let result = router.parse_response(input, &None).unwrap();
|
||||
assert_eq!(result, None);
|
||||
|
||||
// Case 4: JSON missing route field
|
||||
let input = r#"{}"#;
|
||||
let result = router.parse_response(input, &None).unwrap();
|
||||
assert_eq!(result, None);
|
||||
|
||||
// Case 4.1: empty string
|
||||
let input = r#""#;
|
||||
let result = router.parse_response(input, &None).unwrap();
|
||||
assert_eq!(result, None);
|
||||
|
||||
// Case 5: Malformed JSON
|
||||
let input = r#"{"route": "route1""#; // missing closing }
|
||||
let result = router.parse_response(input, &None);
|
||||
assert!(result.is_err());
|
||||
|
||||
// Case 6: Single quotes and \n in JSON
|
||||
let input = "{'route': 'Image generation'}\\n";
|
||||
let result = router.parse_response(input, &None).unwrap();
|
||||
assert_eq!(
|
||||
result,
|
||||
Some(("Image generation".to_string(), "gpt-4o".to_string()))
|
||||
);
|
||||
|
||||
// Case 7: Code block marker
|
||||
let input = "```json\n{\"route\": \"Image generation\"}\n```";
|
||||
let result = router.parse_response(input, &None).unwrap();
|
||||
assert_eq!(
|
||||
result,
|
||||
Some(("Image generation".to_string(), "gpt-4o".to_string()))
|
||||
);
|
||||
}
|
||||
}
|
||||
82
crates/brightstaff/src/session_cache/memory.rs
Normal file
82
crates/brightstaff/src/session_cache/memory.rs
Normal file
|
|
@ -0,0 +1,82 @@
|
|||
use std::{
|
||||
num::NonZeroUsize,
|
||||
sync::Arc,
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use lru::LruCache;
|
||||
use tokio::sync::Mutex;
|
||||
use tracing::info;
|
||||
|
||||
use super::{CachedRoute, SessionCache};
|
||||
|
||||
type CacheStore = Mutex<LruCache<String, (CachedRoute, Instant, Duration)>>;
|
||||
|
||||
pub struct MemorySessionCache {
|
||||
store: Arc<CacheStore>,
|
||||
}
|
||||
|
||||
impl MemorySessionCache {
|
||||
pub fn new(max_entries: usize) -> Self {
|
||||
let capacity = NonZeroUsize::new(max_entries)
|
||||
.unwrap_or_else(|| NonZeroUsize::new(10_000).expect("10_000 is non-zero"));
|
||||
let store = Arc::new(Mutex::new(LruCache::new(capacity)));
|
||||
|
||||
// Spawn a background task to evict TTL-expired entries every 5 minutes.
|
||||
let store_clone = Arc::clone(&store);
|
||||
tokio::spawn(async move {
|
||||
let mut interval = tokio::time::interval(Duration::from_secs(300));
|
||||
loop {
|
||||
interval.tick().await;
|
||||
Self::evict_expired(&store_clone).await;
|
||||
}
|
||||
});
|
||||
|
||||
Self { store }
|
||||
}
|
||||
|
||||
async fn evict_expired(store: &CacheStore) {
|
||||
let mut cache = store.lock().await;
|
||||
let expired: Vec<String> = cache
|
||||
.iter()
|
||||
.filter(|(_, (_, inserted_at, ttl))| inserted_at.elapsed() >= *ttl)
|
||||
.map(|(k, _)| k.clone())
|
||||
.collect();
|
||||
let removed = expired.len();
|
||||
for key in &expired {
|
||||
cache.pop(key.as_str());
|
||||
}
|
||||
if removed > 0 {
|
||||
info!(
|
||||
removed = removed,
|
||||
remaining = cache.len(),
|
||||
"cleaned up expired session cache entries"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl SessionCache for MemorySessionCache {
|
||||
async fn get(&self, key: &str) -> Option<CachedRoute> {
|
||||
let mut cache = self.store.lock().await;
|
||||
if let Some((route, inserted_at, ttl)) = cache.get(key) {
|
||||
if inserted_at.elapsed() < *ttl {
|
||||
return Some(route.clone());
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
async fn put(&self, key: &str, route: CachedRoute, ttl: Duration) {
|
||||
self.store
|
||||
.lock()
|
||||
.await
|
||||
.put(key.to_string(), (route, Instant::now(), ttl));
|
||||
}
|
||||
|
||||
async fn remove(&self, key: &str) {
|
||||
self.store.lock().await.pop(key);
|
||||
}
|
||||
}
|
||||
70
crates/brightstaff/src/session_cache/mod.rs
Normal file
70
crates/brightstaff/src/session_cache/mod.rs
Normal file
|
|
@ -0,0 +1,70 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use common::configuration::Configuration;
|
||||
use std::time::Duration;
|
||||
use tracing::{debug, info};
|
||||
|
||||
pub mod memory;
|
||||
pub mod redis;
|
||||
|
||||
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
|
||||
pub struct CachedRoute {
|
||||
pub model_name: String,
|
||||
pub route_name: Option<String>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait SessionCache: Send + Sync {
|
||||
/// Look up a cached routing decision by key.
|
||||
async fn get(&self, key: &str) -> Option<CachedRoute>;
|
||||
|
||||
/// Store a routing decision in the session cache with the given TTL.
|
||||
async fn put(&self, key: &str, route: CachedRoute, ttl: Duration);
|
||||
|
||||
/// Remove a cached routing decision by key.
|
||||
async fn remove(&self, key: &str);
|
||||
}
|
||||
|
||||
/// Initialize the session cache backend from config.
|
||||
/// Defaults to the in-memory backend when no `session_cache` block is configured.
|
||||
pub async fn init_session_cache(
|
||||
config: &Configuration,
|
||||
) -> Result<Arc<dyn SessionCache>, Box<dyn std::error::Error + Send + Sync>> {
|
||||
use common::configuration::SessionCacheType;
|
||||
|
||||
let session_max_entries = config.routing.as_ref().and_then(|r| r.session_max_entries);
|
||||
|
||||
const DEFAULT_SESSION_MAX_ENTRIES: usize = 10_000;
|
||||
const MAX_SESSION_MAX_ENTRIES: usize = 10_000;
|
||||
|
||||
let max_entries = session_max_entries
|
||||
.unwrap_or(DEFAULT_SESSION_MAX_ENTRIES)
|
||||
.min(MAX_SESSION_MAX_ENTRIES);
|
||||
|
||||
let cache_config = config
|
||||
.routing
|
||||
.as_ref()
|
||||
.and_then(|r| r.session_cache.as_ref());
|
||||
|
||||
let cache_type = cache_config
|
||||
.map(|c| &c.cache_type)
|
||||
.unwrap_or(&SessionCacheType::Memory);
|
||||
|
||||
match cache_type {
|
||||
SessionCacheType::Memory => {
|
||||
info!(storage_type = "memory", "initialized session cache");
|
||||
Ok(Arc::new(memory::MemorySessionCache::new(max_entries)))
|
||||
}
|
||||
SessionCacheType::Redis => {
|
||||
let url = cache_config
|
||||
.and_then(|c| c.url.as_ref())
|
||||
.ok_or("session_cache.url is required when type is redis")?;
|
||||
debug!(storage_type = "redis", url = %url, "initializing session cache");
|
||||
let cache = redis::RedisSessionCache::new(url)
|
||||
.await
|
||||
.map_err(|e| format!("failed to connect to Redis session cache: {e}"))?;
|
||||
Ok(Arc::new(cache))
|
||||
}
|
||||
}
|
||||
}
|
||||
48
crates/brightstaff/src/session_cache/redis.rs
Normal file
48
crates/brightstaff/src/session_cache/redis.rs
Normal file
|
|
@ -0,0 +1,48 @@
|
|||
use std::time::Duration;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use redis::aio::MultiplexedConnection;
|
||||
use redis::AsyncCommands;
|
||||
|
||||
use super::{CachedRoute, SessionCache};
|
||||
|
||||
const KEY_PREFIX: &str = "plano:affinity:";
|
||||
|
||||
pub struct RedisSessionCache {
|
||||
conn: MultiplexedConnection,
|
||||
}
|
||||
|
||||
impl RedisSessionCache {
|
||||
pub async fn new(url: &str) -> Result<Self, redis::RedisError> {
|
||||
let client = redis::Client::open(url)?;
|
||||
let conn = client.get_multiplexed_async_connection().await?;
|
||||
Ok(Self { conn })
|
||||
}
|
||||
|
||||
fn make_key(key: &str) -> String {
|
||||
format!("{KEY_PREFIX}{key}")
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl SessionCache for RedisSessionCache {
|
||||
async fn get(&self, key: &str) -> Option<CachedRoute> {
|
||||
let mut conn = self.conn.clone();
|
||||
let value: Option<String> = conn.get(Self::make_key(key)).await.ok()?;
|
||||
value.and_then(|v| serde_json::from_str(&v).ok())
|
||||
}
|
||||
|
||||
async fn put(&self, key: &str, route: CachedRoute, ttl: Duration) {
|
||||
let mut conn = self.conn.clone();
|
||||
let Ok(json) = serde_json::to_string(&route) else {
|
||||
return;
|
||||
};
|
||||
let ttl_secs = ttl.as_secs().max(1);
|
||||
let _: Result<(), _> = conn.set_ex(Self::make_key(key), json, ttl_secs).await;
|
||||
}
|
||||
|
||||
async fn remove(&self, key: &str) {
|
||||
let mut conn = self.conn.clone();
|
||||
let _: Result<(), _> = conn.del(Self::make_key(key)).await;
|
||||
}
|
||||
}
|
||||
|
|
@ -1250,7 +1250,7 @@ impl TextBasedSignalAnalyzer {
|
|||
let mut repair_phrases = Vec::new();
|
||||
let mut user_turn_count = 0;
|
||||
|
||||
for (i, role, norm_msg) in normalized_messages {
|
||||
for (pos, (i, role, norm_msg)) in normalized_messages.iter().enumerate() {
|
||||
if *role != Role::User {
|
||||
continue;
|
||||
}
|
||||
|
|
@ -1274,10 +1274,13 @@ impl TextBasedSignalAnalyzer {
|
|||
}
|
||||
}
|
||||
|
||||
// Only check for semantic similarity if no pattern matched
|
||||
if !found_in_turn && *i >= 2 {
|
||||
// Find previous user message
|
||||
for j in (0..*i).rev() {
|
||||
// Only check for semantic similarity if no pattern matched. Walk
|
||||
// backwards through the *normalized* list (not the original
|
||||
// conversation indices, which may be non-contiguous because
|
||||
// messages without extractable text are filtered out) to find the
|
||||
// most recent prior user message.
|
||||
if !found_in_turn && pos >= 1 {
|
||||
for j in (0..pos).rev() {
|
||||
let (_, prev_role, prev_norm_msg) = &normalized_messages[j];
|
||||
if *prev_role == Role::User {
|
||||
if self.is_similar_rephrase(norm_msg, prev_norm_msg) {
|
||||
|
|
@ -2199,6 +2202,68 @@ mod tests {
|
|||
println!("test_follow_up_detection took: {:?}", start.elapsed());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_follow_up_does_not_panic_with_filtered_messages() {
|
||||
// Regression test: the preprocessing pipeline filters out messages
|
||||
// without extractable text (tool calls, tool results, empty content).
|
||||
// The stored tuple index `i` is the ORIGINAL-conversation index, so
|
||||
// once anything is filtered out, `i` no longer matches the position
|
||||
// inside `normalized_messages`. The old code used `*i` to index into
|
||||
// `normalized_messages`, which panicked with "index out of bounds"
|
||||
// when a user message appeared after filtered entries.
|
||||
let analyzer = TextBasedSignalAnalyzer::new();
|
||||
let messages = vec![
|
||||
Message {
|
||||
role: Role::User,
|
||||
content: Some(hermesllm::apis::openai::MessageContent::Text(
|
||||
"first question".to_string(),
|
||||
)),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
// Assistant message with no text content (e.g. tool call) — filtered out.
|
||||
Message {
|
||||
role: Role::Assistant,
|
||||
content: None,
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
// Tool-role message with no extractable text — filtered out.
|
||||
Message {
|
||||
role: Role::Tool,
|
||||
content: None,
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
Message {
|
||||
role: Role::Assistant,
|
||||
content: Some(hermesllm::apis::openai::MessageContent::Text(
|
||||
"some answer".to_string(),
|
||||
)),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
// Rephrased user turn — original index 4, but after filtering
|
||||
// only 3 messages remain in `normalized_messages` before it.
|
||||
Message {
|
||||
role: Role::User,
|
||||
content: Some(hermesllm::apis::openai::MessageContent::Text(
|
||||
"first question please".to_string(),
|
||||
)),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
];
|
||||
|
||||
// Must not panic — exercises the full analyze pipeline.
|
||||
let _report = analyzer.analyze(&messages);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_frustration_detection() {
|
||||
let start = Instant::now();
|
||||
|
|
|
|||
|
|
@ -16,10 +16,131 @@ use tracing_opentelemetry::OpenTelemetrySpanExt;
|
|||
use crate::handlers::agents::pipeline::{PipelineError, PipelineProcessor};
|
||||
|
||||
const STREAM_BUFFER_SIZE: usize = 16;
|
||||
/// Cap on accumulated response bytes kept for usage extraction.
|
||||
/// Most chat responses are well under this; pathological ones are dropped without
|
||||
/// affecting pass-through streaming to the client.
|
||||
const USAGE_BUFFER_MAX: usize = 2 * 1024 * 1024;
|
||||
use crate::signals::{InteractionQuality, SignalAnalyzer, TextBasedSignalAnalyzer, FLAG_MARKER};
|
||||
use crate::tracing::{llm, set_service_name, signals as signal_constants};
|
||||
use hermesllm::apis::openai::Message;
|
||||
|
||||
/// Parsed usage + resolved-model details from a provider response.
|
||||
#[derive(Debug, Default, Clone)]
|
||||
struct ExtractedUsage {
|
||||
prompt_tokens: Option<i64>,
|
||||
completion_tokens: Option<i64>,
|
||||
total_tokens: Option<i64>,
|
||||
cached_input_tokens: Option<i64>,
|
||||
cache_creation_tokens: Option<i64>,
|
||||
reasoning_tokens: Option<i64>,
|
||||
/// The model the upstream actually used. For router aliases (e.g.
|
||||
/// `router:software-engineering`), this differs from the request model.
|
||||
resolved_model: Option<String>,
|
||||
}
|
||||
|
||||
impl ExtractedUsage {
|
||||
fn is_empty(&self) -> bool {
|
||||
self.prompt_tokens.is_none()
|
||||
&& self.completion_tokens.is_none()
|
||||
&& self.total_tokens.is_none()
|
||||
&& self.resolved_model.is_none()
|
||||
}
|
||||
|
||||
fn from_json(value: &serde_json::Value) -> Self {
|
||||
let mut out = Self::default();
|
||||
if let Some(model) = value.get("model").and_then(|v| v.as_str()) {
|
||||
if !model.is_empty() {
|
||||
out.resolved_model = Some(model.to_string());
|
||||
}
|
||||
}
|
||||
if let Some(u) = value.get("usage") {
|
||||
// OpenAI-shape usage
|
||||
out.prompt_tokens = u.get("prompt_tokens").and_then(|v| v.as_i64());
|
||||
out.completion_tokens = u.get("completion_tokens").and_then(|v| v.as_i64());
|
||||
out.total_tokens = u.get("total_tokens").and_then(|v| v.as_i64());
|
||||
out.cached_input_tokens = u
|
||||
.get("prompt_tokens_details")
|
||||
.and_then(|d| d.get("cached_tokens"))
|
||||
.and_then(|v| v.as_i64());
|
||||
out.reasoning_tokens = u
|
||||
.get("completion_tokens_details")
|
||||
.and_then(|d| d.get("reasoning_tokens"))
|
||||
.and_then(|v| v.as_i64());
|
||||
|
||||
// Anthropic-shape fallbacks
|
||||
if out.prompt_tokens.is_none() {
|
||||
out.prompt_tokens = u.get("input_tokens").and_then(|v| v.as_i64());
|
||||
}
|
||||
if out.completion_tokens.is_none() {
|
||||
out.completion_tokens = u.get("output_tokens").and_then(|v| v.as_i64());
|
||||
}
|
||||
if out.total_tokens.is_none() {
|
||||
if let (Some(p), Some(c)) = (out.prompt_tokens, out.completion_tokens) {
|
||||
out.total_tokens = Some(p + c);
|
||||
}
|
||||
}
|
||||
if out.cached_input_tokens.is_none() {
|
||||
out.cached_input_tokens = u.get("cache_read_input_tokens").and_then(|v| v.as_i64());
|
||||
}
|
||||
if out.cached_input_tokens.is_none() {
|
||||
out.cached_input_tokens =
|
||||
u.get("cached_content_token_count").and_then(|v| v.as_i64());
|
||||
}
|
||||
out.cache_creation_tokens = u
|
||||
.get("cache_creation_input_tokens")
|
||||
.and_then(|v| v.as_i64());
|
||||
if out.reasoning_tokens.is_none() {
|
||||
out.reasoning_tokens = u.get("thoughts_token_count").and_then(|v| v.as_i64());
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
}
|
||||
|
||||
/// Try to pull usage out of an accumulated response body.
|
||||
/// Handles both a single JSON object (non-streaming) and SSE streams where the
|
||||
/// final `data: {...}` event carries the `usage` field.
|
||||
fn extract_usage_from_bytes(buf: &[u8]) -> ExtractedUsage {
|
||||
if buf.is_empty() {
|
||||
return ExtractedUsage::default();
|
||||
}
|
||||
|
||||
// Fast path: full-body JSON (non-streaming).
|
||||
if let Ok(value) = serde_json::from_slice::<serde_json::Value>(buf) {
|
||||
let u = ExtractedUsage::from_json(&value);
|
||||
if !u.is_empty() {
|
||||
return u;
|
||||
}
|
||||
}
|
||||
|
||||
// SSE path: scan from the end for a `data:` line containing a usage object.
|
||||
let text = match std::str::from_utf8(buf) {
|
||||
Ok(t) => t,
|
||||
Err(_) => return ExtractedUsage::default(),
|
||||
};
|
||||
for line in text.lines().rev() {
|
||||
let trimmed = line.trim_start();
|
||||
let payload = match trimmed.strip_prefix("data:") {
|
||||
Some(p) => p.trim_start(),
|
||||
None => continue,
|
||||
};
|
||||
if payload == "[DONE]" || payload.is_empty() {
|
||||
continue;
|
||||
}
|
||||
if !payload.contains("\"usage\"") {
|
||||
continue;
|
||||
}
|
||||
if let Ok(value) = serde_json::from_str::<serde_json::Value>(payload) {
|
||||
let u = ExtractedUsage::from_json(&value);
|
||||
if !u.is_empty() {
|
||||
return u;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ExtractedUsage::default()
|
||||
}
|
||||
|
||||
/// Trait for processing streaming chunks
|
||||
/// Implementors can inject custom logic during streaming (e.g., hallucination detection, logging)
|
||||
pub trait StreamProcessor: Send + 'static {
|
||||
|
|
@ -60,6 +181,10 @@ pub struct ObservableStreamProcessor {
|
|||
start_time: Instant,
|
||||
time_to_first_token: Option<u128>,
|
||||
messages: Option<Vec<Message>>,
|
||||
/// Accumulated response bytes used only for best-effort usage extraction
|
||||
/// on `on_complete`. Capped at `USAGE_BUFFER_MAX`; excess chunks are dropped
|
||||
/// from the buffer (they still pass through to the client).
|
||||
response_buffer: Vec<u8>,
|
||||
}
|
||||
|
||||
impl ObservableStreamProcessor {
|
||||
|
|
@ -93,6 +218,7 @@ impl ObservableStreamProcessor {
|
|||
start_time,
|
||||
time_to_first_token: None,
|
||||
messages,
|
||||
response_buffer: Vec::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -101,6 +227,13 @@ impl StreamProcessor for ObservableStreamProcessor {
|
|||
fn process_chunk(&mut self, chunk: Bytes) -> Result<Option<Bytes>, String> {
|
||||
self.total_bytes += chunk.len();
|
||||
self.chunk_count += 1;
|
||||
// Accumulate for best-effort usage extraction; drop further chunks once
|
||||
// the cap is reached so we don't retain huge response bodies in memory.
|
||||
if self.response_buffer.len() < USAGE_BUFFER_MAX {
|
||||
let remaining = USAGE_BUFFER_MAX - self.response_buffer.len();
|
||||
let take = chunk.len().min(remaining);
|
||||
self.response_buffer.extend_from_slice(&chunk[..take]);
|
||||
}
|
||||
Ok(Some(chunk))
|
||||
}
|
||||
|
||||
|
|
@ -124,6 +257,52 @@ impl StreamProcessor for ObservableStreamProcessor {
|
|||
);
|
||||
}
|
||||
|
||||
// Record total duration on the span for the observability console.
|
||||
let duration_ms = self.start_time.elapsed().as_millis() as i64;
|
||||
{
|
||||
let span = tracing::Span::current();
|
||||
let otel_context = span.context();
|
||||
let otel_span = otel_context.span();
|
||||
otel_span.set_attribute(KeyValue::new(llm::DURATION_MS, duration_ms));
|
||||
otel_span.set_attribute(KeyValue::new(llm::RESPONSE_BYTES, self.total_bytes as i64));
|
||||
}
|
||||
|
||||
// Best-effort usage extraction + emission (works for both streaming
|
||||
// SSE and non-streaming JSON responses that include a `usage` object).
|
||||
let usage = extract_usage_from_bytes(&self.response_buffer);
|
||||
if !usage.is_empty() {
|
||||
let span = tracing::Span::current();
|
||||
let otel_context = span.context();
|
||||
let otel_span = otel_context.span();
|
||||
if let Some(v) = usage.prompt_tokens {
|
||||
otel_span.set_attribute(KeyValue::new(llm::PROMPT_TOKENS, v));
|
||||
}
|
||||
if let Some(v) = usage.completion_tokens {
|
||||
otel_span.set_attribute(KeyValue::new(llm::COMPLETION_TOKENS, v));
|
||||
}
|
||||
if let Some(v) = usage.total_tokens {
|
||||
otel_span.set_attribute(KeyValue::new(llm::TOTAL_TOKENS, v));
|
||||
}
|
||||
if let Some(v) = usage.cached_input_tokens {
|
||||
otel_span.set_attribute(KeyValue::new(llm::CACHED_INPUT_TOKENS, v));
|
||||
}
|
||||
if let Some(v) = usage.cache_creation_tokens {
|
||||
otel_span.set_attribute(KeyValue::new(llm::CACHE_CREATION_TOKENS, v));
|
||||
}
|
||||
if let Some(v) = usage.reasoning_tokens {
|
||||
otel_span.set_attribute(KeyValue::new(llm::REASONING_TOKENS, v));
|
||||
}
|
||||
// Override `llm.model` with the model the upstream actually ran
|
||||
// (e.g. `openai-gpt-5.4` resolved from `router:software-engineering`).
|
||||
// Cost lookup keys off the real model, not the alias.
|
||||
if let Some(resolved) = usage.resolved_model.clone() {
|
||||
otel_span.set_attribute(KeyValue::new(llm::MODEL_NAME, resolved));
|
||||
}
|
||||
}
|
||||
// Release the buffered bytes early; nothing downstream needs them.
|
||||
self.response_buffer.clear();
|
||||
self.response_buffer.shrink_to_fit();
|
||||
|
||||
// Analyze signals if messages are available and record as span attributes
|
||||
if let Some(ref messages) = self.messages {
|
||||
let analyzer: Box<dyn SignalAnalyzer> = Box::new(TextBasedSignalAnalyzer::new());
|
||||
|
|
@ -404,3 +583,55 @@ pub fn truncate_message(message: &str, max_length: usize) -> String {
|
|||
message.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod usage_extraction_tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn non_streaming_openai_with_cached() {
|
||||
let body = br#"{"id":"x","model":"gpt-4o","choices":[],"usage":{"prompt_tokens":12,"completion_tokens":34,"total_tokens":46,"prompt_tokens_details":{"cached_tokens":5}}}"#;
|
||||
let u = extract_usage_from_bytes(body);
|
||||
assert_eq!(u.prompt_tokens, Some(12));
|
||||
assert_eq!(u.completion_tokens, Some(34));
|
||||
assert_eq!(u.total_tokens, Some(46));
|
||||
assert_eq!(u.cached_input_tokens, Some(5));
|
||||
assert_eq!(u.reasoning_tokens, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn non_streaming_anthropic_with_cache_creation() {
|
||||
let body = br#"{"id":"x","model":"claude","usage":{"input_tokens":100,"output_tokens":50,"cache_creation_input_tokens":20,"cache_read_input_tokens":30}}"#;
|
||||
let u = extract_usage_from_bytes(body);
|
||||
assert_eq!(u.prompt_tokens, Some(100));
|
||||
assert_eq!(u.completion_tokens, Some(50));
|
||||
assert_eq!(u.total_tokens, Some(150));
|
||||
assert_eq!(u.cached_input_tokens, Some(30));
|
||||
assert_eq!(u.cache_creation_tokens, Some(20));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn streaming_openai_final_chunk_has_usage() {
|
||||
let sse = b"data: {\"choices\":[{\"delta\":{\"content\":\"hi\"}}]}
|
||||
|
||||
data: {\"choices\":[{\"delta\":{}, \"finish_reason\":\"stop\"}],\"usage\":{\"prompt_tokens\":7,\"completion_tokens\":3,\"total_tokens\":10}}
|
||||
|
||||
data: [DONE]
|
||||
|
||||
";
|
||||
let u = extract_usage_from_bytes(sse);
|
||||
assert_eq!(u.prompt_tokens, Some(7));
|
||||
assert_eq!(u.completion_tokens, Some(3));
|
||||
assert_eq!(u.total_tokens, Some(10));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn empty_returns_default() {
|
||||
assert!(extract_usage_from_bytes(b"").is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn no_usage_in_body_returns_default() {
|
||||
assert!(extract_usage_from_bytes(br#"{"ok":true}"#).is_empty());
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -80,6 +80,18 @@ pub mod llm {
|
|||
/// Total tokens used (prompt + completion)
|
||||
pub const TOTAL_TOKENS: &str = "llm.usage.total_tokens";
|
||||
|
||||
/// Tokens served from a prompt cache read
|
||||
/// (OpenAI `prompt_tokens_details.cached_tokens`, Anthropic `cache_read_input_tokens`,
|
||||
/// Google `cached_content_token_count`)
|
||||
pub const CACHED_INPUT_TOKENS: &str = "llm.usage.cached_input_tokens";
|
||||
|
||||
/// Tokens used to write a prompt cache entry (Anthropic `cache_creation_input_tokens`)
|
||||
pub const CACHE_CREATION_TOKENS: &str = "llm.usage.cache_creation_tokens";
|
||||
|
||||
/// Reasoning tokens for reasoning models
|
||||
/// (OpenAI `completion_tokens_details.reasoning_tokens`, Google `thoughts_token_count`)
|
||||
pub const REASONING_TOKENS: &str = "llm.usage.reasoning_tokens";
|
||||
|
||||
/// Temperature parameter used
|
||||
pub const TEMPERATURE: &str = "llm.temperature";
|
||||
|
||||
|
|
@ -119,6 +131,22 @@ pub mod routing {
|
|||
pub const SELECTION_REASON: &str = "routing.selection_reason";
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Span Attributes - Plano-specific
|
||||
// =============================================================================
|
||||
|
||||
/// Attributes specific to Plano (session affinity, routing decisions).
|
||||
pub mod plano {
|
||||
/// Session identifier propagated via the `x-model-affinity` header.
|
||||
/// Absent when the client did not send the header.
|
||||
pub const SESSION_ID: &str = "plano.session_id";
|
||||
|
||||
/// Matched route name from routing (e.g. "code", "summarization",
|
||||
/// "software-engineering"). Absent when the client routed directly
|
||||
/// to a concrete model.
|
||||
pub const ROUTE_NAME: &str = "plano.route.name";
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Span Attributes - Error Handling
|
||||
// =============================================================================
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ mod init;
|
|||
mod service_name_exporter;
|
||||
|
||||
pub use constants::{
|
||||
error, http, llm, operation_component, routing, signals, OperationNameBuilder,
|
||||
error, http, llm, operation_component, plano, routing, signals, OperationNameBuilder,
|
||||
};
|
||||
pub use custom_attributes::collect_custom_trace_attributes;
|
||||
pub use init::init_tracer;
|
||||
|
|
|
|||
|
|
@ -7,12 +7,32 @@ use crate::api::open_ai::{
|
|||
ChatCompletionTool, FunctionDefinition, FunctionParameter, FunctionParameters, ParameterType,
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum SessionCacheType {
|
||||
#[default]
|
||||
Memory,
|
||||
Redis,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SessionCacheConfig {
|
||||
#[serde(rename = "type", default)]
|
||||
pub cache_type: SessionCacheType,
|
||||
/// Redis URL, e.g. `redis://localhost:6379`. Required when `type` is `redis`.
|
||||
pub url: Option<String>,
|
||||
/// Optional HTTP header name whose value is used as a tenant prefix in the cache key.
|
||||
/// When set, keys are scoped as `plano:affinity:{tenant_id}:{session_id}`.
|
||||
pub tenant_header: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Routing {
|
||||
pub llm_provider: Option<String>,
|
||||
pub model: Option<String>,
|
||||
pub session_ttl_seconds: Option<u64>,
|
||||
pub session_max_entries: Option<usize>,
|
||||
pub session_cache: Option<SessionCacheConfig>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
|
|
@ -213,6 +233,7 @@ pub struct Overrides {
|
|||
pub use_agent_orchestrator: Option<bool>,
|
||||
pub llm_routing_model: Option<String>,
|
||||
pub agent_orchestration_model: Option<String>,
|
||||
pub orchestrator_model_context_length: Option<usize>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
|
|
@ -372,6 +393,8 @@ pub enum LlmProviderType {
|
|||
Plano,
|
||||
#[serde(rename = "chatgpt")]
|
||||
ChatGPT,
|
||||
#[serde(rename = "digitalocean")]
|
||||
DigitalOcean,
|
||||
}
|
||||
|
||||
impl Display for LlmProviderType {
|
||||
|
|
@ -394,6 +417,7 @@ impl Display for LlmProviderType {
|
|||
LlmProviderType::AmazonBedrock => write!(f, "amazon_bedrock"),
|
||||
LlmProviderType::Plano => write!(f, "plano"),
|
||||
LlmProviderType::ChatGPT => write!(f, "chatgpt"),
|
||||
LlmProviderType::DigitalOcean => write!(f, "digitalocean"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -714,13 +738,6 @@ mod test {
|
|||
internal: None,
|
||||
..Default::default()
|
||||
},
|
||||
LlmProvider {
|
||||
name: "arch-router".to_string(),
|
||||
provider_interface: LlmProviderType::Plano,
|
||||
model: Some("Arch-Router".to_string()),
|
||||
internal: Some(true),
|
||||
..Default::default()
|
||||
},
|
||||
LlmProvider {
|
||||
name: "plano-orchestrator".to_string(),
|
||||
provider_interface: LlmProviderType::Plano,
|
||||
|
|
@ -732,13 +749,10 @@ mod test {
|
|||
|
||||
let models = providers.into_models();
|
||||
|
||||
// Should only have 1 model: openai-gpt4
|
||||
assert_eq!(models.data.len(), 1);
|
||||
|
||||
// Verify internal models are excluded from /v1/models
|
||||
let model_ids: Vec<String> = models.data.iter().map(|m| m.id.clone()).collect();
|
||||
assert!(model_ids.contains(&"openai-gpt4".to_string()));
|
||||
assert!(!model_ids.contains(&"arch-router".to_string()));
|
||||
assert!(!model_ids.contains(&"plano-orchestrator".to_string()));
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -435,6 +435,12 @@ impl TokenUsage for MessagesResponse {
|
|||
fn total_tokens(&self) -> usize {
|
||||
(self.usage.input_tokens + self.usage.output_tokens) as usize
|
||||
}
|
||||
fn cached_input_tokens(&self) -> Option<usize> {
|
||||
self.usage.cache_read_input_tokens.map(|t| t as usize)
|
||||
}
|
||||
fn cache_creation_tokens(&self) -> Option<usize> {
|
||||
self.usage.cache_creation_input_tokens.map(|t| t as usize)
|
||||
}
|
||||
}
|
||||
|
||||
impl ProviderResponse for MessagesResponse {
|
||||
|
|
|
|||
|
|
@ -596,6 +596,18 @@ impl TokenUsage for Usage {
|
|||
fn total_tokens(&self) -> usize {
|
||||
self.total_tokens as usize
|
||||
}
|
||||
|
||||
fn cached_input_tokens(&self) -> Option<usize> {
|
||||
self.prompt_tokens_details
|
||||
.as_ref()
|
||||
.and_then(|d| d.cached_tokens.map(|t| t as usize))
|
||||
}
|
||||
|
||||
fn reasoning_tokens(&self) -> Option<usize> {
|
||||
self.completion_tokens_details
|
||||
.as_ref()
|
||||
.and_then(|d| d.reasoning_tokens.map(|t| t as usize))
|
||||
}
|
||||
}
|
||||
|
||||
/// Implementation of ProviderRequest for ChatCompletionsRequest
|
||||
|
|
|
|||
|
|
@ -710,6 +710,18 @@ impl crate::providers::response::TokenUsage for ResponseUsage {
|
|||
fn total_tokens(&self) -> usize {
|
||||
self.total_tokens as usize
|
||||
}
|
||||
|
||||
fn cached_input_tokens(&self) -> Option<usize> {
|
||||
self.input_tokens_details
|
||||
.as_ref()
|
||||
.map(|d| d.cached_tokens.max(0) as usize)
|
||||
}
|
||||
|
||||
fn reasoning_tokens(&self) -> Option<usize> {
|
||||
self.output_tokens_details
|
||||
.as_ref()
|
||||
.map(|d| d.reasoning_tokens.max(0) as usize)
|
||||
}
|
||||
}
|
||||
|
||||
/// Token details
|
||||
|
|
|
|||
|
|
@ -1,6 +1,9 @@
|
|||
use crate::apis::anthropic::MessagesStreamEvent;
|
||||
use crate::apis::anthropic::{
|
||||
MessagesMessageDelta, MessagesStopReason, MessagesStreamEvent, MessagesUsage,
|
||||
};
|
||||
use crate::apis::streaming_shapes::sse::{SseEvent, SseStreamBufferTrait};
|
||||
use crate::providers::streaming_response::ProviderStreamResponseType;
|
||||
use log::warn;
|
||||
use std::collections::HashSet;
|
||||
|
||||
/// SSE Stream Buffer for Anthropic Messages API streaming.
|
||||
|
|
@ -11,13 +14,24 @@ use std::collections::HashSet;
|
|||
///
|
||||
/// When converting from OpenAI to Anthropic format, this buffer injects the required
|
||||
/// ContentBlockStart and ContentBlockStop events to maintain proper Anthropic protocol.
|
||||
///
|
||||
/// Guarantees (Anthropic Messages API contract):
|
||||
/// 1. `message_stop` is never emitted unless a matching `message_start` was emitted first.
|
||||
/// 2. `message_stop` is emitted at most once per stream (no double-close).
|
||||
/// 3. If upstream terminates with no content (empty/filtered/errored response), a
|
||||
/// minimal but well-formed envelope is synthesized so the client's state machine
|
||||
/// stays consistent.
|
||||
pub struct AnthropicMessagesStreamBuffer {
|
||||
/// Buffered SSE events ready to be written to wire
|
||||
buffered_events: Vec<SseEvent>,
|
||||
|
||||
/// Track if we've seen a message_start event
|
||||
/// Track if we've emitted a message_start event
|
||||
message_started: bool,
|
||||
|
||||
/// Track if we've emitted a terminal message_stop event (for idempotency /
|
||||
/// double-close protection).
|
||||
message_stopped: bool,
|
||||
|
||||
/// Track content block indices that have received ContentBlockStart events
|
||||
content_block_start_indices: HashSet<i32>,
|
||||
|
||||
|
|
@ -42,6 +56,7 @@ impl AnthropicMessagesStreamBuffer {
|
|||
Self {
|
||||
buffered_events: Vec::new(),
|
||||
message_started: false,
|
||||
message_stopped: false,
|
||||
content_block_start_indices: HashSet::new(),
|
||||
needs_content_block_stop: false,
|
||||
seen_message_delta: false,
|
||||
|
|
@ -49,6 +64,66 @@ impl AnthropicMessagesStreamBuffer {
|
|||
}
|
||||
}
|
||||
|
||||
/// Inject a `message_start` event into the buffer if one hasn't been emitted yet.
|
||||
/// This is the single source of truth for opening a message — every handler
|
||||
/// that can legitimately be the first event on the wire must call this before
|
||||
/// pushing its own event.
|
||||
fn ensure_message_started(&mut self) {
|
||||
if self.message_started {
|
||||
return;
|
||||
}
|
||||
let model = self.model.as_deref().unwrap_or("unknown");
|
||||
let message_start = AnthropicMessagesStreamBuffer::create_message_start_event(model);
|
||||
self.buffered_events.push(message_start);
|
||||
self.message_started = true;
|
||||
}
|
||||
|
||||
/// Inject a synthetic `message_delta` with `end_turn` / zero usage.
|
||||
/// Used when we must close a message but upstream never produced a terminal
|
||||
/// event (e.g. `[DONE]` arrives with no prior `finish_reason`).
|
||||
fn push_synthetic_message_delta(&mut self) {
|
||||
let event = MessagesStreamEvent::MessageDelta {
|
||||
delta: MessagesMessageDelta {
|
||||
stop_reason: MessagesStopReason::EndTurn,
|
||||
stop_sequence: None,
|
||||
},
|
||||
usage: MessagesUsage {
|
||||
input_tokens: 0,
|
||||
output_tokens: 0,
|
||||
cache_creation_input_tokens: None,
|
||||
cache_read_input_tokens: None,
|
||||
},
|
||||
};
|
||||
let sse_string: String = event.clone().into();
|
||||
self.buffered_events.push(SseEvent {
|
||||
data: None,
|
||||
event: Some("message_delta".to_string()),
|
||||
raw_line: sse_string.clone(),
|
||||
sse_transformed_lines: sse_string,
|
||||
provider_stream_response: Some(ProviderStreamResponseType::MessagesStreamEvent(event)),
|
||||
});
|
||||
self.seen_message_delta = true;
|
||||
}
|
||||
|
||||
/// Inject a `message_stop` event into the buffer, marking the stream as closed.
|
||||
/// Idempotent — subsequent calls are no-ops.
|
||||
fn push_message_stop(&mut self) {
|
||||
if self.message_stopped {
|
||||
return;
|
||||
}
|
||||
let message_stop = MessagesStreamEvent::MessageStop;
|
||||
let sse_string: String = message_stop.into();
|
||||
self.buffered_events.push(SseEvent {
|
||||
data: None,
|
||||
event: Some("message_stop".to_string()),
|
||||
raw_line: sse_string.clone(),
|
||||
sse_transformed_lines: sse_string,
|
||||
provider_stream_response: None,
|
||||
});
|
||||
self.message_stopped = true;
|
||||
self.seen_message_delta = false;
|
||||
}
|
||||
|
||||
/// Check if a content_block_start event has been sent for the given index
|
||||
fn has_content_block_start_been_sent(&self, index: i32) -> bool {
|
||||
self.content_block_start_indices.contains(&index)
|
||||
|
|
@ -149,6 +224,27 @@ impl SseStreamBufferTrait for AnthropicMessagesStreamBuffer {
|
|||
// We match on a reference first to determine the type, then move the event
|
||||
match &event.provider_stream_response {
|
||||
Some(ProviderStreamResponseType::MessagesStreamEvent(evt)) => {
|
||||
// If the message has already been closed, drop any trailing events
|
||||
// to avoid emitting data after `message_stop` (protocol violation).
|
||||
// This typically indicates a duplicate `[DONE]` from upstream or a
|
||||
// replay of previously-buffered bytes — worth surfacing so we can
|
||||
// spot misbehaving providers.
|
||||
if self.message_stopped {
|
||||
warn!(
|
||||
"anthropic stream buffer: dropping event after message_stop (variant={})",
|
||||
match evt {
|
||||
MessagesStreamEvent::MessageStart { .. } => "message_start",
|
||||
MessagesStreamEvent::ContentBlockStart { .. } => "content_block_start",
|
||||
MessagesStreamEvent::ContentBlockDelta { .. } => "content_block_delta",
|
||||
MessagesStreamEvent::ContentBlockStop { .. } => "content_block_stop",
|
||||
MessagesStreamEvent::MessageDelta { .. } => "message_delta",
|
||||
MessagesStreamEvent::MessageStop => "message_stop",
|
||||
MessagesStreamEvent::Ping => "ping",
|
||||
}
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
match evt {
|
||||
MessagesStreamEvent::MessageStart { .. } => {
|
||||
// Add the message_start event
|
||||
|
|
@ -157,14 +253,7 @@ impl SseStreamBufferTrait for AnthropicMessagesStreamBuffer {
|
|||
}
|
||||
MessagesStreamEvent::ContentBlockStart { index, .. } => {
|
||||
let index = *index as i32;
|
||||
// Inject message_start if needed
|
||||
if !self.message_started {
|
||||
let model = self.model.as_deref().unwrap_or("unknown");
|
||||
let message_start =
|
||||
AnthropicMessagesStreamBuffer::create_message_start_event(model);
|
||||
self.buffered_events.push(message_start);
|
||||
self.message_started = true;
|
||||
}
|
||||
self.ensure_message_started();
|
||||
|
||||
// Add the content_block_start event (from tool calls or other sources)
|
||||
self.buffered_events.push(event);
|
||||
|
|
@ -173,14 +262,7 @@ impl SseStreamBufferTrait for AnthropicMessagesStreamBuffer {
|
|||
}
|
||||
MessagesStreamEvent::ContentBlockDelta { index, .. } => {
|
||||
let index = *index as i32;
|
||||
// Inject message_start if needed
|
||||
if !self.message_started {
|
||||
let model = self.model.as_deref().unwrap_or("unknown");
|
||||
let message_start =
|
||||
AnthropicMessagesStreamBuffer::create_message_start_event(model);
|
||||
self.buffered_events.push(message_start);
|
||||
self.message_started = true;
|
||||
}
|
||||
self.ensure_message_started();
|
||||
|
||||
// Check if ContentBlockStart was sent for this index
|
||||
if !self.has_content_block_start_been_sent(index) {
|
||||
|
|
@ -196,6 +278,11 @@ impl SseStreamBufferTrait for AnthropicMessagesStreamBuffer {
|
|||
self.buffered_events.push(event);
|
||||
}
|
||||
MessagesStreamEvent::MessageDelta { usage, .. } => {
|
||||
// `message_delta` is only meaningful inside an open message.
|
||||
// Upstream can send it with no prior content (empty completion,
|
||||
// content filter, etc.), so we must open a message first.
|
||||
self.ensure_message_started();
|
||||
|
||||
// Inject ContentBlockStop before message_delta
|
||||
if self.needs_content_block_stop {
|
||||
let content_block_stop =
|
||||
|
|
@ -230,15 +317,52 @@ impl SseStreamBufferTrait for AnthropicMessagesStreamBuffer {
|
|||
}
|
||||
MessagesStreamEvent::ContentBlockStop { .. } => {
|
||||
// ContentBlockStop received from upstream (e.g., Bedrock)
|
||||
self.ensure_message_started();
|
||||
// Clear the flag so we don't inject another one
|
||||
self.needs_content_block_stop = false;
|
||||
self.buffered_events.push(event);
|
||||
}
|
||||
MessagesStreamEvent::MessageStop => {
|
||||
// MessageStop received from upstream (e.g., OpenAI via [DONE])
|
||||
// Clear the flag so we don't inject another one
|
||||
self.seen_message_delta = false;
|
||||
// MessageStop received from upstream (e.g., OpenAI via [DONE]).
|
||||
//
|
||||
// The Anthropic protocol requires the full envelope
|
||||
// message_start → [content blocks] → message_delta → message_stop
|
||||
// so we must not emit a bare `message_stop`. Synthesize whatever
|
||||
// is missing to keep the client's state machine consistent.
|
||||
self.ensure_message_started();
|
||||
|
||||
if self.needs_content_block_stop {
|
||||
let content_block_stop =
|
||||
AnthropicMessagesStreamBuffer::create_content_block_stop_event();
|
||||
self.buffered_events.push(content_block_stop);
|
||||
self.needs_content_block_stop = false;
|
||||
}
|
||||
|
||||
// If no message_delta has been emitted yet (empty/filtered upstream
|
||||
// response), synthesize a minimal one carrying `end_turn`.
|
||||
if !self.seen_message_delta {
|
||||
// If we also never opened a content block, open and close one
|
||||
// so clients that expect at least one block are happy.
|
||||
if self.content_block_start_indices.is_empty() {
|
||||
let content_block_start =
|
||||
AnthropicMessagesStreamBuffer::create_content_block_start_event(
|
||||
);
|
||||
self.buffered_events.push(content_block_start);
|
||||
self.set_content_block_start_sent(0);
|
||||
let content_block_stop =
|
||||
AnthropicMessagesStreamBuffer::create_content_block_stop_event(
|
||||
);
|
||||
self.buffered_events.push(content_block_stop);
|
||||
}
|
||||
self.push_synthetic_message_delta();
|
||||
}
|
||||
|
||||
// Push the upstream-provided message_stop and mark closed.
|
||||
// `push_message_stop` is idempotent but we want to reuse the
|
||||
// original SseEvent so raw passthrough semantics are preserved.
|
||||
self.buffered_events.push(event);
|
||||
self.message_stopped = true;
|
||||
self.seen_message_delta = false;
|
||||
}
|
||||
_ => {
|
||||
// Other Anthropic event types (Ping, etc.), just accumulate
|
||||
|
|
@ -254,24 +378,23 @@ impl SseStreamBufferTrait for AnthropicMessagesStreamBuffer {
|
|||
}
|
||||
|
||||
fn to_bytes(&mut self) -> Vec<u8> {
|
||||
// Convert all accumulated events to bytes and clear buffer
|
||||
// Convert all accumulated events to bytes and clear buffer.
|
||||
//
|
||||
// NOTE: We do NOT inject ContentBlockStop here because it's injected when we see MessageDelta
|
||||
// or MessageStop. Injecting it here causes premature ContentBlockStop in the middle of streaming.
|
||||
|
||||
// Inject MessageStop after MessageDelta if we've seen one
|
||||
// This completes the Anthropic Messages API event sequence
|
||||
if self.seen_message_delta {
|
||||
let message_stop = MessagesStreamEvent::MessageStop;
|
||||
let sse_string: String = message_stop.into();
|
||||
let message_stop_event = SseEvent {
|
||||
data: None,
|
||||
event: Some("message_stop".to_string()),
|
||||
raw_line: sse_string.clone(),
|
||||
sse_transformed_lines: sse_string,
|
||||
provider_stream_response: None,
|
||||
};
|
||||
self.buffered_events.push(message_stop_event);
|
||||
self.seen_message_delta = false;
|
||||
//
|
||||
// Inject a synthetic `message_stop` only when:
|
||||
// 1. A `message_delta` has been seen (otherwise we'd violate the Anthropic
|
||||
// protocol by emitting `message_stop` without a preceding `message_delta`), AND
|
||||
// 2. We haven't already emitted `message_stop` (either synthetic from a
|
||||
// previous flush, or real from an upstream `[DONE]`).
|
||||
//
|
||||
// Without the `!message_stopped` guard, a stream whose `finish_reason` chunk
|
||||
// and `[DONE]` marker land in separate HTTP body chunks would receive two
|
||||
// `message_stop` events, triggering Claude Code's "Received message_stop
|
||||
// without a current message" error.
|
||||
if self.seen_message_delta && !self.message_stopped {
|
||||
self.push_message_stop();
|
||||
}
|
||||
|
||||
let mut buffer = Vec::new();
|
||||
|
|
@ -615,4 +738,133 @@ data: [DONE]"#;
|
|||
println!("✓ Stop reason: tool_use");
|
||||
println!("✓ Proper Anthropic tool_use protocol\n");
|
||||
}
|
||||
|
||||
/// Regression test for:
|
||||
/// Claude Code CLI error: "Received message_stop without a current message"
|
||||
///
|
||||
/// Reproduces the *double-close* scenario: OpenAI's final `finish_reason`
|
||||
/// chunk and the `[DONE]` marker arrive in **separate** HTTP body chunks, so
|
||||
/// `to_bytes()` is called between them. Before the fix, this produced two
|
||||
/// `message_stop` events on the wire (one synthetic, one from `[DONE]`).
|
||||
#[test]
|
||||
fn test_openai_to_anthropic_emits_single_message_stop_across_chunk_boundary() {
|
||||
let client_api = SupportedAPIsFromClient::AnthropicMessagesAPI(AnthropicApi::Messages);
|
||||
let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions);
|
||||
let mut buffer = AnthropicMessagesStreamBuffer::new();
|
||||
|
||||
// --- HTTP chunk 1: content + finish_reason (no [DONE] yet) -----------
|
||||
let chunk_1 = r#"data: {"id":"c1","object":"chat.completion.chunk","created":1,"model":"gpt-4o","choices":[{"index":0,"delta":{"role":"assistant","content":"Hi"},"finish_reason":null}]}
|
||||
|
||||
data: {"id":"c1","object":"chat.completion.chunk","created":1,"model":"gpt-4o","choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}"#;
|
||||
|
||||
for raw in SseStreamIter::try_from(chunk_1.as_bytes()).unwrap() {
|
||||
let e = SseEvent::try_from((raw, &client_api, &upstream_api)).unwrap();
|
||||
buffer.add_transformed_event(e);
|
||||
}
|
||||
let out_1 = String::from_utf8(buffer.to_bytes()).unwrap();
|
||||
|
||||
// --- HTTP chunk 2: just the [DONE] marker ----------------------------
|
||||
let chunk_2 = "data: [DONE]";
|
||||
for raw in SseStreamIter::try_from(chunk_2.as_bytes()).unwrap() {
|
||||
let e = SseEvent::try_from((raw, &client_api, &upstream_api)).unwrap();
|
||||
buffer.add_transformed_event(e);
|
||||
}
|
||||
let out_2 = String::from_utf8(buffer.to_bytes()).unwrap();
|
||||
|
||||
let combined = format!("{}{}", out_1, out_2);
|
||||
let start_count = combined.matches("event: message_start").count();
|
||||
let stop_count = combined.matches("event: message_stop").count();
|
||||
|
||||
assert_eq!(
|
||||
start_count, 1,
|
||||
"Must emit exactly one message_start across chunks, got {start_count}. Output:\n{combined}"
|
||||
);
|
||||
assert_eq!(
|
||||
stop_count, 1,
|
||||
"Must emit exactly one message_stop across chunks (no double-close), got {stop_count}. Output:\n{combined}"
|
||||
);
|
||||
// Every message_stop must be preceded by a message_start earlier in the stream.
|
||||
let start_pos = combined.find("event: message_start").unwrap();
|
||||
let stop_pos = combined.find("event: message_stop").unwrap();
|
||||
assert!(
|
||||
start_pos < stop_pos,
|
||||
"message_start must come before message_stop. Output:\n{combined}"
|
||||
);
|
||||
}
|
||||
|
||||
/// Regression test for:
|
||||
/// "Received message_stop without a current message" on empty upstream responses.
|
||||
///
|
||||
/// OpenAI returns only `[DONE]` with no content deltas and no `finish_reason`
|
||||
/// (this happens with content filters, truncated upstream streams, and some
|
||||
/// 5xx recoveries). Before the fix, the buffer emitted a bare `message_stop`
|
||||
/// with no preceding `message_start`. After the fix, it synthesizes a
|
||||
/// minimal but well-formed envelope.
|
||||
#[test]
|
||||
fn test_openai_done_only_stream_synthesizes_valid_envelope() {
|
||||
let client_api = SupportedAPIsFromClient::AnthropicMessagesAPI(AnthropicApi::Messages);
|
||||
let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions);
|
||||
let mut buffer = AnthropicMessagesStreamBuffer::new();
|
||||
|
||||
let raw_input = "data: [DONE]";
|
||||
for raw in SseStreamIter::try_from(raw_input.as_bytes()).unwrap() {
|
||||
let e = SseEvent::try_from((raw, &client_api, &upstream_api)).unwrap();
|
||||
buffer.add_transformed_event(e);
|
||||
}
|
||||
let out = String::from_utf8(buffer.to_bytes()).unwrap();
|
||||
|
||||
assert!(
|
||||
out.contains("event: message_start"),
|
||||
"Empty upstream must still produce message_start. Output:\n{out}"
|
||||
);
|
||||
assert!(
|
||||
out.contains("event: message_delta"),
|
||||
"Empty upstream must produce a synthesized message_delta. Output:\n{out}"
|
||||
);
|
||||
assert_eq!(
|
||||
out.matches("event: message_stop").count(),
|
||||
1,
|
||||
"Empty upstream must produce exactly one message_stop. Output:\n{out}"
|
||||
);
|
||||
|
||||
// Protocol ordering: start < delta < stop.
|
||||
let p_start = out.find("event: message_start").unwrap();
|
||||
let p_delta = out.find("event: message_delta").unwrap();
|
||||
let p_stop = out.find("event: message_stop").unwrap();
|
||||
assert!(
|
||||
p_start < p_delta && p_delta < p_stop,
|
||||
"Bad ordering. Output:\n{out}"
|
||||
);
|
||||
}
|
||||
|
||||
/// Regression test: events arriving after `message_stop` (e.g. a stray `[DONE]`
|
||||
/// echo, or late-arriving deltas from a racing upstream) must be dropped
|
||||
/// rather than written after the terminal frame.
|
||||
#[test]
|
||||
fn test_events_after_message_stop_are_dropped() {
|
||||
let client_api = SupportedAPIsFromClient::AnthropicMessagesAPI(AnthropicApi::Messages);
|
||||
let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions);
|
||||
let mut buffer = AnthropicMessagesStreamBuffer::new();
|
||||
|
||||
let first = r#"data: {"id":"c1","object":"chat.completion.chunk","created":1,"model":"gpt-4o","choices":[{"index":0,"delta":{"content":"ok"},"finish_reason":"stop"}]}
|
||||
|
||||
data: [DONE]"#;
|
||||
for raw in SseStreamIter::try_from(first.as_bytes()).unwrap() {
|
||||
let e = SseEvent::try_from((raw, &client_api, &upstream_api)).unwrap();
|
||||
buffer.add_transformed_event(e);
|
||||
}
|
||||
let _ = buffer.to_bytes();
|
||||
|
||||
// Simulate a duplicate / late `[DONE]` after the stream was already closed.
|
||||
let late = "data: [DONE]";
|
||||
for raw in SseStreamIter::try_from(late.as_bytes()).unwrap() {
|
||||
let e = SseEvent::try_from((raw, &client_api, &upstream_api)).unwrap();
|
||||
buffer.add_transformed_event(e);
|
||||
}
|
||||
let tail = String::from_utf8(buffer.to_bytes()).unwrap();
|
||||
assert!(
|
||||
tail.is_empty(),
|
||||
"No bytes should be emitted after message_stop, got: {tail:?}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -95,6 +95,7 @@ providers:
|
|||
anthropic:
|
||||
- anthropic/claude-sonnet-4-6
|
||||
- anthropic/claude-opus-4-6
|
||||
- anthropic/claude-opus-4-7
|
||||
- anthropic/claude-opus-4-5-20251101
|
||||
- anthropic/claude-opus-4-5
|
||||
- anthropic/claude-haiku-4-5-20251001
|
||||
|
|
@ -332,7 +333,53 @@ providers:
|
|||
- chatgpt/gpt-5.4
|
||||
- chatgpt/gpt-5.3-codex
|
||||
- chatgpt/gpt-5.2
|
||||
digitalocean:
|
||||
- digitalocean/openai-gpt-4.1
|
||||
- digitalocean/openai-gpt-4o
|
||||
- digitalocean/openai-gpt-4o-mini
|
||||
- digitalocean/openai-gpt-5
|
||||
- digitalocean/openai-gpt-5-mini
|
||||
- digitalocean/openai-gpt-5-nano
|
||||
- digitalocean/openai-gpt-5.1-codex-max
|
||||
- digitalocean/openai-gpt-5.2
|
||||
- digitalocean/openai-gpt-5.2-pro
|
||||
- digitalocean/openai-gpt-5.3-codex
|
||||
- digitalocean/openai-gpt-5.4
|
||||
- digitalocean/openai-gpt-5.4-mini
|
||||
- digitalocean/openai-gpt-5.4-nano
|
||||
- digitalocean/openai-gpt-5.4-pro
|
||||
- digitalocean/openai-gpt-oss-120b
|
||||
- digitalocean/openai-gpt-oss-20b
|
||||
- digitalocean/openai-o1
|
||||
- digitalocean/openai-o3
|
||||
- digitalocean/openai-o3-mini
|
||||
- digitalocean/anthropic-claude-4.1-opus
|
||||
- digitalocean/anthropic-claude-4.5-sonnet
|
||||
- digitalocean/anthropic-claude-4.6-sonnet
|
||||
- digitalocean/anthropic-claude-haiku-4.5
|
||||
- digitalocean/anthropic-claude-opus-4
|
||||
- digitalocean/anthropic-claude-opus-4.5
|
||||
- digitalocean/anthropic-claude-opus-4.6
|
||||
- digitalocean/anthropic-claude-opus-4.7
|
||||
- digitalocean/anthropic-claude-sonnet-4
|
||||
- digitalocean/alibaba-qwen3-32b
|
||||
- digitalocean/arcee-trinity-large-thinking
|
||||
- digitalocean/deepseek-3.2
|
||||
- digitalocean/deepseek-r1-distill-llama-70b
|
||||
- digitalocean/gemma-4-31B-it
|
||||
- digitalocean/glm-5
|
||||
- digitalocean/kimi-k2.5
|
||||
- digitalocean/llama3.3-70b-instruct
|
||||
- digitalocean/minimax-m2.5
|
||||
- digitalocean/nvidia-nemotron-3-super-120b
|
||||
- digitalocean/qwen3-coder-flash
|
||||
- digitalocean/qwen3.5-397b-a17b
|
||||
- digitalocean/all-mini-lm-l6-v2
|
||||
- digitalocean/gte-large-en-v1.5
|
||||
- digitalocean/multi-qa-mpnet-base-dot-v1
|
||||
- digitalocean/qwen3-embedding-0.6b
|
||||
- digitalocean/router:software-engineering
|
||||
metadata:
|
||||
total_providers: 12
|
||||
total_models: 319
|
||||
last_updated: 2026-04-03T23:14:46.956158+00:00
|
||||
total_providers: 13
|
||||
total_models: 364
|
||||
last_updated: 2026-04-20T00:00:00.000000+00:00
|
||||
|
|
|
|||
|
|
@ -45,6 +45,7 @@ pub enum ProviderId {
|
|||
Qwen,
|
||||
AmazonBedrock,
|
||||
ChatGPT,
|
||||
DigitalOcean,
|
||||
}
|
||||
|
||||
impl TryFrom<&str> for ProviderId {
|
||||
|
|
@ -73,6 +74,9 @@ impl TryFrom<&str> for ProviderId {
|
|||
"amazon_bedrock" => Ok(ProviderId::AmazonBedrock),
|
||||
"amazon" => Ok(ProviderId::AmazonBedrock), // alias
|
||||
"chatgpt" => Ok(ProviderId::ChatGPT),
|
||||
"digitalocean" => Ok(ProviderId::DigitalOcean),
|
||||
"do" => Ok(ProviderId::DigitalOcean), // alias
|
||||
"do_ai" => Ok(ProviderId::DigitalOcean), // alias
|
||||
_ => Err(format!("Unknown provider: {}", value)),
|
||||
}
|
||||
}
|
||||
|
|
@ -98,6 +102,7 @@ impl ProviderId {
|
|||
ProviderId::Zhipu => "z-ai",
|
||||
ProviderId::Qwen => "qwen",
|
||||
ProviderId::ChatGPT => "chatgpt",
|
||||
ProviderId::DigitalOcean => "digitalocean",
|
||||
_ => return Vec::new(),
|
||||
};
|
||||
|
||||
|
|
@ -152,7 +157,8 @@ impl ProviderId {
|
|||
| ProviderId::Moonshotai
|
||||
| ProviderId::Zhipu
|
||||
| ProviderId::Qwen
|
||||
| ProviderId::ChatGPT,
|
||||
| ProviderId::ChatGPT
|
||||
| ProviderId::DigitalOcean,
|
||||
SupportedAPIsFromClient::AnthropicMessagesAPI(_),
|
||||
) => SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
|
||||
|
||||
|
|
@ -172,7 +178,8 @@ impl ProviderId {
|
|||
| ProviderId::Moonshotai
|
||||
| ProviderId::Zhipu
|
||||
| ProviderId::Qwen
|
||||
| ProviderId::ChatGPT,
|
||||
| ProviderId::ChatGPT
|
||||
| ProviderId::DigitalOcean,
|
||||
SupportedAPIsFromClient::OpenAIChatCompletions(_),
|
||||
) => SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions),
|
||||
|
||||
|
|
@ -240,6 +247,7 @@ impl Display for ProviderId {
|
|||
ProviderId::Qwen => write!(f, "qwen"),
|
||||
ProviderId::AmazonBedrock => write!(f, "amazon_bedrock"),
|
||||
ProviderId::ChatGPT => write!(f, "chatgpt"),
|
||||
ProviderId::DigitalOcean => write!(f, "digitalocean"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -23,6 +23,31 @@ pub trait TokenUsage {
|
|||
fn completion_tokens(&self) -> usize;
|
||||
fn prompt_tokens(&self) -> usize;
|
||||
fn total_tokens(&self) -> usize;
|
||||
/// Tokens served from a prompt cache read (OpenAI `prompt_tokens_details.cached_tokens`,
|
||||
/// Anthropic `cache_read_input_tokens`, Google `cached_content_token_count`).
|
||||
fn cached_input_tokens(&self) -> Option<usize> {
|
||||
None
|
||||
}
|
||||
/// Tokens used to write a cache entry (Anthropic `cache_creation_input_tokens`).
|
||||
fn cache_creation_tokens(&self) -> Option<usize> {
|
||||
None
|
||||
}
|
||||
/// Reasoning tokens for reasoning models (OpenAI `completion_tokens_details.reasoning_tokens`,
|
||||
/// Google `thoughts_token_count`).
|
||||
fn reasoning_tokens(&self) -> Option<usize> {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Rich usage breakdown extracted from a provider response.
|
||||
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
|
||||
pub struct UsageDetails {
|
||||
pub prompt_tokens: usize,
|
||||
pub completion_tokens: usize,
|
||||
pub total_tokens: usize,
|
||||
pub cached_input_tokens: Option<usize>,
|
||||
pub cache_creation_tokens: Option<usize>,
|
||||
pub reasoning_tokens: Option<usize>,
|
||||
}
|
||||
|
||||
pub trait ProviderResponse: Send + Sync {
|
||||
|
|
@ -34,6 +59,18 @@ pub trait ProviderResponse: Send + Sync {
|
|||
self.usage()
|
||||
.map(|u| (u.prompt_tokens(), u.completion_tokens(), u.total_tokens()))
|
||||
}
|
||||
|
||||
/// Extract a rich usage breakdown including cached/cache-creation/reasoning tokens.
|
||||
fn extract_usage_details(&self) -> Option<UsageDetails> {
|
||||
self.usage().map(|u| UsageDetails {
|
||||
prompt_tokens: u.prompt_tokens(),
|
||||
completion_tokens: u.completion_tokens(),
|
||||
total_tokens: u.total_tokens(),
|
||||
cached_input_tokens: u.cached_input_tokens(),
|
||||
cache_creation_tokens: u.cache_creation_tokens(),
|
||||
reasoning_tokens: u.reasoning_tokens(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl ProviderResponse for ProviderResponseType {
|
||||
|
|
|
|||
|
|
@ -177,24 +177,33 @@ impl StreamContext {
|
|||
}
|
||||
|
||||
fn modify_auth_headers(&mut self) -> Result<(), ServerError> {
|
||||
if self.llm_provider().passthrough_auth == Some(true) {
|
||||
// Check if client provided an Authorization header
|
||||
if self.get_http_request_header("Authorization").is_none() {
|
||||
warn!(
|
||||
"request_id={}: passthrough_auth enabled but no authorization header present in client request",
|
||||
self.request_identifier()
|
||||
);
|
||||
} else {
|
||||
debug!(
|
||||
"request_id={}: preserving client authorization header for provider '{}'",
|
||||
self.request_identifier(),
|
||||
self.llm_provider().name
|
||||
);
|
||||
// Determine the credential to forward upstream. Either the client
|
||||
// supplied one (passthrough_auth) or it's configured on the provider.
|
||||
let credential: String = if self.llm_provider().passthrough_auth == Some(true) {
|
||||
// Client auth may arrive in either Anthropic-style (`x-api-key`)
|
||||
// or OpenAI-style (`Authorization: Bearer ...`). Accept both so
|
||||
// clients using Anthropic SDKs (which default to `x-api-key`)
|
||||
// work when the upstream is OpenAI-compatible, and vice versa.
|
||||
let authorization = self.get_http_request_header("Authorization");
|
||||
let x_api_key = self.get_http_request_header("x-api-key");
|
||||
match extract_client_credential(authorization.as_deref(), x_api_key.as_deref()) {
|
||||
Some(key) => {
|
||||
debug!(
|
||||
"request_id={}: forwarding client credential to provider '{}'",
|
||||
self.request_identifier(),
|
||||
self.llm_provider().name
|
||||
);
|
||||
key
|
||||
}
|
||||
None => {
|
||||
warn!(
|
||||
"request_id={}: passthrough_auth enabled but no Authorization / x-api-key header present in client request",
|
||||
self.request_identifier()
|
||||
);
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let llm_provider_api_key_value =
|
||||
} else {
|
||||
self.llm_provider()
|
||||
.access_key
|
||||
.as_ref()
|
||||
|
|
@ -203,15 +212,19 @@ impl StreamContext {
|
|||
"No access key configured for selected LLM Provider \"{}\"",
|
||||
self.llm_provider()
|
||||
),
|
||||
})?;
|
||||
})?
|
||||
.clone()
|
||||
};
|
||||
|
||||
// Set API-specific headers based on the resolved upstream API
|
||||
// Normalize the credential into whichever header the upstream expects.
|
||||
// This lets an Anthropic-SDK client reach an OpenAI-compatible upstream
|
||||
// (and vice versa) without the caller needing to know what format the
|
||||
// upstream uses.
|
||||
match self.resolved_api.as_ref() {
|
||||
Some(SupportedUpstreamAPIs::AnthropicMessagesAPI(_)) => {
|
||||
// Anthropic API requires x-api-key and anthropic-version headers
|
||||
// Remove any existing Authorization header since Anthropic doesn't use it
|
||||
// Anthropic expects `x-api-key` + `anthropic-version`.
|
||||
self.remove_http_request_header("Authorization");
|
||||
self.set_http_request_header("x-api-key", Some(llm_provider_api_key_value));
|
||||
self.set_http_request_header("x-api-key", Some(&credential));
|
||||
self.set_http_request_header("anthropic-version", Some("2023-06-01"));
|
||||
}
|
||||
Some(
|
||||
|
|
@ -221,10 +234,9 @@ impl StreamContext {
|
|||
| SupportedUpstreamAPIs::OpenAIResponsesAPI(_),
|
||||
)
|
||||
| None => {
|
||||
// OpenAI and default: use Authorization Bearer token
|
||||
// Remove any existing x-api-key header since OpenAI doesn't use it
|
||||
// OpenAI (and default): `Authorization: Bearer ...`.
|
||||
self.remove_http_request_header("x-api-key");
|
||||
let authorization_header_value = format!("Bearer {}", llm_provider_api_key_value);
|
||||
let authorization_header_value = format!("Bearer {}", credential);
|
||||
self.set_http_request_header("Authorization", Some(&authorization_header_value));
|
||||
}
|
||||
}
|
||||
|
|
@ -1256,3 +1268,86 @@ fn current_time_ns() -> u128 {
|
|||
}
|
||||
|
||||
impl Context for StreamContext {}
|
||||
|
||||
/// Extract the credential a client sent in either an OpenAI-style
|
||||
/// `Authorization` header or an Anthropic-style `x-api-key` header.
|
||||
///
|
||||
/// Returns `None` when neither header is present or both are empty/whitespace.
|
||||
/// The `Bearer ` prefix on the `Authorization` value is stripped if present;
|
||||
/// otherwise the value is taken verbatim (some clients send a raw token).
|
||||
fn extract_client_credential(
|
||||
authorization: Option<&str>,
|
||||
x_api_key: Option<&str>,
|
||||
) -> Option<String> {
|
||||
// Strip the optional "Bearer " / "Bearer" prefix (case-sensitive, matches
|
||||
// OpenAI SDK behavior) and trim surrounding whitespace before validating
|
||||
// non-empty.
|
||||
let from_authorization = authorization
|
||||
.map(|v| {
|
||||
v.strip_prefix("Bearer ")
|
||||
.or_else(|| v.strip_prefix("Bearer"))
|
||||
.unwrap_or(v)
|
||||
.trim()
|
||||
.to_string()
|
||||
})
|
||||
.filter(|s| !s.is_empty());
|
||||
if from_authorization.is_some() {
|
||||
return from_authorization;
|
||||
}
|
||||
x_api_key
|
||||
.map(str::trim)
|
||||
.filter(|s| !s.is_empty())
|
||||
.map(|s| s.to_string())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::extract_client_credential;
|
||||
|
||||
#[test]
|
||||
fn authorization_bearer_strips_prefix() {
|
||||
assert_eq!(
|
||||
extract_client_credential(Some("Bearer sk-abc"), None),
|
||||
Some("sk-abc".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn authorization_raw_token_preserved() {
|
||||
// Some clients send the raw token without "Bearer " — accept it.
|
||||
assert_eq!(
|
||||
extract_client_credential(Some("sk-abc"), None),
|
||||
Some("sk-abc".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn x_api_key_used_when_authorization_absent() {
|
||||
assert_eq!(
|
||||
extract_client_credential(None, Some("sk-ant-api-key")),
|
||||
Some("sk-ant-api-key".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn authorization_wins_when_both_present() {
|
||||
// If a client is particularly exotic and sends both, prefer the
|
||||
// OpenAI-style Authorization header.
|
||||
assert_eq!(
|
||||
extract_client_credential(Some("Bearer openai-key"), Some("anthropic-key")),
|
||||
Some("openai-key".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn returns_none_when_neither_present() {
|
||||
assert!(extract_client_credential(None, None).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn empty_and_whitespace_headers_are_ignored() {
|
||||
assert!(extract_client_credential(Some(""), None).is_none());
|
||||
assert!(extract_client_credential(Some("Bearer "), None).is_none());
|
||||
assert!(extract_client_credential(Some(" "), Some(" ")).is_none());
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue