diff --git a/config/plano_config_schema.yaml b/config/plano_config_schema.yaml index b63cb824..9fd439b1 100644 --- a/config/plano_config_schema.yaml +++ b/config/plano_config_schema.yaml @@ -415,6 +415,14 @@ properties: type: string model: type: string + session_ttl_seconds: + type: integer + minimum: 1 + description: TTL in seconds for session-pinned routing cache entries. Default 600 (10 minutes). + session_max_entries: + type: integer + minimum: 1 + description: Maximum number of session-pinned routing cache entries. Default 10000. additionalProperties: false state_storage: type: object diff --git a/crates/brightstaff/src/handlers/routing_service.rs b/crates/brightstaff/src/handlers/routing_service.rs index 4eae4685..82043dad 100644 --- a/crates/brightstaff/src/handlers/routing_service.rs +++ b/crates/brightstaff/src/handlers/routing_service.rs @@ -1,6 +1,6 @@ use bytes::Bytes; use common::configuration::{ModelUsagePreference, SpanAttributes}; -use common::consts::{REQUEST_ID_HEADER, TRACE_PARENT_HEADER}; +use common::consts::{REQUEST_ID_HEADER, SESSION_ID_HEADER, TRACE_PARENT_HEADER}; use common::errors::BrightStaffError; use hermesllm::clients::SupportedAPIsFromClient; use hermesllm::ProviderRequestType; @@ -66,6 +66,9 @@ struct RoutingDecisionResponse { model: String, route: Option, trace_id: String, + #[serde(skip_serializing_if = "Option::is_none")] + session_id: Option, + pinned: bool, } pub async fn routing_decision( @@ -81,6 +84,11 @@ pub async fn routing_decision( .map(|s| s.to_string()) .unwrap_or_else(|| uuid::Uuid::new_v4().to_string()); + let session_id: Option = request_headers + .get(SESSION_ID_HEADER) + .and_then(|h| h.to_str().ok()) + .map(|s| s.to_string()); + let custom_attrs = collect_custom_trace_attributes(&request_headers, span_attributes.as_ref().as_ref()); @@ -99,6 +107,7 @@ pub async fn routing_decision( request_path, request_headers, custom_attrs, + session_id, ) .instrument(request_span) .await @@ -111,6 +120,7 @@ async fn routing_decision_inner( request_path: String, request_headers: hyper::HeaderMap, custom_attrs: std::collections::HashMap, + session_id: Option, ) -> Result>, hyper::Error> { set_service_name(operation_component::ROUTING); opentelemetry::trace::get_active_span(|span| { @@ -144,6 +154,34 @@ async fn routing_decision_inner( .unwrap_or("unknown") .to_string(); + // Session pinning: check cache before doing any routing work + if let Some(ref sid) = session_id { + if let Some(cached) = router_service.get_cached_route(sid).await { + info!( + session_id = %sid, + model = %cached.model_name, + route = ?cached.route_name, + "returning pinned routing decision from cache" + ); + let response = RoutingDecisionResponse { + model: cached.model_name, + route: cached.route_name, + trace_id, + session_id: Some(sid.clone()), + pinned: true, + }; + let json = serde_json::to_string(&response).unwrap(); + let body = Full::new(Bytes::from(json)) + .map_err(|never| match never {}) + .boxed(); + return Ok(Response::builder() + .status(StatusCode::OK) + .header("Content-Type", "application/json") + .body(body) + .unwrap()); + } + } + // Parse request body let raw_bytes = request.collect().await?.to_bytes(); @@ -182,7 +220,7 @@ async fn routing_decision_inner( // Call the existing routing logic with inline preferences let routing_result = router_chat_get_upstream_model( - router_service, + Arc::clone(&router_service), client_request, &traceparent, &request_path, @@ -193,10 +231,23 @@ async fn routing_decision_inner( match routing_result { Ok(result) => { + // Cache the result if session_id is present + if let Some(ref sid) = session_id { + router_service + .cache_route( + sid.clone(), + result.model_name.clone(), + result.route_name.clone(), + ) + .await; + } + let response = RoutingDecisionResponse { model: result.model_name, route: result.route_name, trace_id, + session_id, + pinned: false, }; info!( @@ -334,12 +385,16 @@ mod tests { model: "openai/gpt-4o".to_string(), route: Some("code_generation".to_string()), trace_id: "abc123".to_string(), + session_id: Some("sess-abc".to_string()), + pinned: true, }; let json = serde_json::to_string(&response).unwrap(); let parsed: serde_json::Value = serde_json::from_str(&json).unwrap(); assert_eq!(parsed["model"], "openai/gpt-4o"); assert_eq!(parsed["route"], "code_generation"); assert_eq!(parsed["trace_id"], "abc123"); + assert_eq!(parsed["session_id"], "sess-abc"); + assert_eq!(parsed["pinned"], true); } #[test] @@ -348,10 +403,14 @@ mod tests { model: "none".to_string(), route: None, trace_id: "abc123".to_string(), + session_id: None, + pinned: false, }; let json = serde_json::to_string(&response).unwrap(); let parsed: serde_json::Value = serde_json::from_str(&json).unwrap(); assert_eq!(parsed["model"], "none"); assert!(parsed["route"].is_null()); + assert!(parsed.get("session_id").is_none()); + assert_eq!(parsed["pinned"], false); } } diff --git a/crates/brightstaff/src/main.rs b/crates/brightstaff/src/main.rs index 51c9127f..190379a0 100644 --- a/crates/brightstaff/src/main.rs +++ b/crates/brightstaff/src/main.rs @@ -102,13 +102,37 @@ async fn main() -> Result<(), Box> { .and_then(|r| r.model_provider.clone()) .unwrap_or_else(|| DEFAULT_ROUTING_LLM_PROVIDER.to_string()); + let session_ttl_seconds = plano_config + .routing + .as_ref() + .and_then(|r| r.session_ttl_seconds); + + let session_max_entries = plano_config + .routing + .as_ref() + .and_then(|r| r.session_max_entries); + let router_service: Arc = Arc::new(RouterService::new( plano_config.model_providers.clone(), format!("{llm_provider_url}{CHAT_COMPLETIONS_PATH}"), routing_model_name, routing_llm_provider, + session_ttl_seconds, + session_max_entries, )); + // Spawn background task to clean up expired session cache entries every 5 minutes + { + let router_service = Arc::clone(&router_service); + tokio::spawn(async move { + let mut interval = tokio::time::interval(std::time::Duration::from_secs(300)); + loop { + interval.tick().await; + router_service.cleanup_expired_sessions().await; + } + }); + } + let orchestrator_service: Arc = Arc::new(OrchestratorService::new( format!("{llm_provider_url}{CHAT_COMPLETIONS_PATH}"), PLANO_ORCHESTRATOR_MODEL_NAME.to_string(), diff --git a/crates/brightstaff/src/router/llm_router.rs b/crates/brightstaff/src/router/llm_router.rs index ec3fe3ab..da2622b9 100644 --- a/crates/brightstaff/src/router/llm_router.rs +++ b/crates/brightstaff/src/router/llm_router.rs @@ -1,4 +1,4 @@ -use std::{collections::HashMap, sync::Arc}; +use std::{collections::HashMap, sync::Arc, time::Duration, time::Instant}; use common::{ configuration::{LlmProvider, ModelUsagePreference, RoutingPreference}, @@ -7,12 +7,23 @@ use common::{ use hermesllm::apis::openai::{ChatCompletionsResponse, Message}; use hyper::header; use thiserror::Error; +use tokio::sync::RwLock; use tracing::{debug, info, warn}; use crate::router::router_model_v1::{self}; use super::router_model::RouterModel; +const DEFAULT_SESSION_TTL_SECONDS: u64 = 600; +const DEFAULT_SESSION_MAX_ENTRIES: usize = 10_000; + +#[derive(Clone, Debug)] +pub struct CachedRoute { + pub model_name: String, + pub route_name: Option, + pub cached_at: Instant, +} + pub struct RouterService { router_url: String, client: reqwest::Client, @@ -20,6 +31,9 @@ pub struct RouterService { #[allow(dead_code)] routing_provider_name: String, llm_usage_defined: bool, + session_cache: RwLock>, + session_ttl: Duration, + session_max_entries: usize, } #[derive(Debug, Error)] @@ -42,6 +56,8 @@ impl RouterService { router_url: String, routing_model_name: String, routing_provider_name: String, + session_ttl_seconds: Option, + session_max_entries: Option, ) -> Self { let providers_with_usage = providers .iter() @@ -65,12 +81,75 @@ impl RouterService { router_model_v1::MAX_TOKEN_LEN, )); + let session_ttl = + Duration::from_secs(session_ttl_seconds.unwrap_or(DEFAULT_SESSION_TTL_SECONDS)); + let session_max_entries = session_max_entries.unwrap_or(DEFAULT_SESSION_MAX_ENTRIES); + RouterService { router_url, client: reqwest::Client::new(), router_model, routing_provider_name, llm_usage_defined: !providers_with_usage.is_empty(), + session_cache: RwLock::new(HashMap::new()), + session_ttl, + session_max_entries, + } + } + + /// Look up a cached routing decision by session ID. + /// Returns None if not found or expired. + pub async fn get_cached_route(&self, session_id: &str) -> Option { + let cache = self.session_cache.read().await; + if let Some(entry) = cache.get(session_id) { + if entry.cached_at.elapsed() < self.session_ttl { + return Some(entry.clone()); + } + } + None + } + + /// Store a routing decision in the session cache. + /// If at max capacity, evicts the oldest entry. + pub async fn cache_route( + &self, + session_id: String, + model_name: String, + route_name: Option, + ) { + let mut cache = self.session_cache.write().await; + if cache.len() >= self.session_max_entries && !cache.contains_key(&session_id) { + // Evict the oldest entry + if let Some(oldest_key) = cache + .iter() + .min_by_key(|(_, v)| v.cached_at) + .map(|(k, _)| k.clone()) + { + cache.remove(&oldest_key); + } + } + cache.insert( + session_id, + CachedRoute { + model_name, + route_name, + cached_at: Instant::now(), + }, + ); + } + + /// Remove all expired entries from the session cache. + pub async fn cleanup_expired_sessions(&self) { + let mut cache = self.session_cache.write().await; + let before = cache.len(); + cache.retain(|_, entry| entry.cached_at.elapsed() < self.session_ttl); + let removed = before - cache.len(); + if removed > 0 { + info!( + removed = removed, + remaining = cache.len(), + "cleaned up expired session cache entries" + ); } } @@ -185,3 +264,105 @@ impl RouterService { } } } + +#[cfg(test)] +mod tests { + use super::*; + + fn make_router_service(ttl_seconds: u64, max_entries: usize) -> RouterService { + RouterService::new( + vec![], + "http://localhost:12001/v1/chat/completions".to_string(), + "Arch-Router".to_string(), + "arch-router".to_string(), + Some(ttl_seconds), + Some(max_entries), + ) + } + + #[tokio::test] + async fn test_cache_miss_returns_none() { + let svc = make_router_service(600, 100); + assert!(svc.get_cached_route("unknown-session").await.is_none()); + } + + #[tokio::test] + async fn test_cache_hit_returns_cached_route() { + let svc = make_router_service(600, 100); + svc.cache_route( + "s1".to_string(), + "gpt-4o".to_string(), + Some("code".to_string()), + ) + .await; + + let cached = svc.get_cached_route("s1").await.unwrap(); + assert_eq!(cached.model_name, "gpt-4o"); + assert_eq!(cached.route_name, Some("code".to_string())); + } + + #[tokio::test] + async fn test_cache_expired_entry_returns_none() { + let svc = make_router_service(0, 100); + svc.cache_route("s1".to_string(), "gpt-4o".to_string(), None) + .await; + // TTL is 0 seconds, so the entry should be expired immediately + assert!(svc.get_cached_route("s1").await.is_none()); + } + + #[tokio::test] + async fn test_cleanup_removes_expired() { + let svc = make_router_service(0, 100); + svc.cache_route("s1".to_string(), "gpt-4o".to_string(), None) + .await; + svc.cache_route("s2".to_string(), "claude".to_string(), None) + .await; + + svc.cleanup_expired_sessions().await; + + let cache = svc.session_cache.read().await; + assert!(cache.is_empty()); + } + + #[tokio::test] + async fn test_cache_evicts_oldest_when_full() { + let svc = make_router_service(600, 2); + svc.cache_route("s1".to_string(), "model-a".to_string(), None) + .await; + // Small delay so s2 has a later cached_at + tokio::time::sleep(Duration::from_millis(10)).await; + svc.cache_route("s2".to_string(), "model-b".to_string(), None) + .await; + + // Cache is full (2 entries). Adding s3 should evict s1 (oldest). + svc.cache_route("s3".to_string(), "model-c".to_string(), None) + .await; + + let cache = svc.session_cache.read().await; + assert_eq!(cache.len(), 2); + assert!(!cache.contains_key("s1")); + assert!(cache.contains_key("s2")); + assert!(cache.contains_key("s3")); + } + + #[tokio::test] + async fn test_cache_update_existing_session_does_not_evict() { + let svc = make_router_service(600, 2); + svc.cache_route("s1".to_string(), "model-a".to_string(), None) + .await; + svc.cache_route("s2".to_string(), "model-b".to_string(), None) + .await; + + // Updating s1 should not trigger eviction since key already exists + svc.cache_route( + "s1".to_string(), + "model-a-updated".to_string(), + Some("route".to_string()), + ) + .await; + + let cache = svc.session_cache.read().await; + assert_eq!(cache.len(), 2); + assert_eq!(cache.get("s1").unwrap().model_name, "model-a-updated"); + } +} diff --git a/crates/common/src/configuration.rs b/crates/common/src/configuration.rs index f4e2b7b4..ed7f61ad 100644 --- a/crates/common/src/configuration.rs +++ b/crates/common/src/configuration.rs @@ -11,6 +11,8 @@ use crate::api::open_ai::{ pub struct Routing { pub model_provider: Option, pub model: Option, + pub session_ttl_seconds: Option, + pub session_max_entries: Option, } #[derive(Debug, Clone, Serialize, Deserialize)] diff --git a/crates/common/src/consts.rs b/crates/common/src/consts.rs index cafc8e80..69fa5799 100644 --- a/crates/common/src/consts.rs +++ b/crates/common/src/consts.rs @@ -22,6 +22,7 @@ pub const X_ARCH_TOOL_CALL: &str = "x-arch-tool-call-message"; pub const X_ARCH_FC_MODEL_RESPONSE: &str = "x-arch-fc-model-response"; pub const ARCH_FC_MODEL_NAME: &str = "Arch-Function"; pub const REQUEST_ID_HEADER: &str = "x-request-id"; +pub const SESSION_ID_HEADER: &str = "x-session-id"; pub const ENVOY_ORIGINAL_PATH_HEADER: &str = "x-envoy-original-path"; pub const TRACE_PARENT_HEADER: &str = "traceparent"; pub const ARCH_INTERNAL_CLUSTER_NAME: &str = "arch_internal"; diff --git a/demos/llm_routing/model_routing_service/README.md b/demos/llm_routing/model_routing_service/README.md index 85d56abf..761cb0e3 100644 --- a/demos/llm_routing/model_routing_service/README.md +++ b/demos/llm_routing/model_routing_service/README.md @@ -55,6 +55,63 @@ Response: The response tells you which model would handle this request and which route was matched, without actually making the LLM call. +## Session Pinning + +Send an `X-Session-Id` header to pin the routing decision for a session. Once a model is selected, all subsequent requests with the same session ID return the same model without re-running routing. + +```bash +# First call — runs routing, caches result +curl http://localhost:12000/routing/v1/chat/completions \ + -H "Content-Type: application/json" \ + -H "X-Session-Id: my-session-123" \ + -d '{ + "model": "gpt-4o-mini", + "messages": [{"role": "user", "content": "Write a Python function for binary search"}] + }' +``` + +Response (first call): +```json +{ + "model": "anthropic/claude-sonnet-4-20250514", + "route": "code_generation", + "trace_id": "c16d1096c1af4a17abb48fb182918a88", + "session_id": "my-session-123", + "pinned": false +} +``` + +```bash +# Second call — same session, returns cached result +curl http://localhost:12000/routing/v1/chat/completions \ + -H "Content-Type: application/json" \ + -H "X-Session-Id: my-session-123" \ + -d '{ + "model": "gpt-4o-mini", + "messages": [{"role": "user", "content": "Now explain merge sort"}] + }' +``` + +Response (pinned): +```json +{ + "model": "anthropic/claude-sonnet-4-20250514", + "route": "code_generation", + "trace_id": "a1b2c3d4e5f6...", + "session_id": "my-session-123", + "pinned": true +} +``` + +Session TTL and max cache size are configurable in `config.yaml`: +```yaml +routing: + session_ttl_seconds: 600 # default: 600 (10 minutes) + session_max_entries: 10000 # default: 10000 +``` + +Without the `X-Session-Id` header, routing runs fresh every time (no breaking change). + ## Demo Output ``` @@ -88,5 +145,33 @@ The response tells you which model would handle this request and which route was "trace_id": "26be822bbdf14a3ba19fe198e55ea4a9" } +--- 7. Session pinning - first call (fresh routing decision) --- +{ + "model": "anthropic/claude-sonnet-4-20250514", + "route": "code_generation", + "trace_id": "f1a2b3c4d5e6f7a8b9c0d1e2f3a4b5c6", + "session_id": "demo-session-001", + "pinned": false +} + +--- 8. Session pinning - second call (same session, pinned) --- + Notice: same model returned with "pinned": true, routing was skipped +{ + "model": "anthropic/claude-sonnet-4-20250514", + "route": "code_generation", + "trace_id": "a9b8c7d6e5f4a3b2c1d0e9f8a7b6c5d4", + "session_id": "demo-session-001", + "pinned": true +} + +--- 9. Different session gets its own fresh routing --- +{ + "model": "openai/gpt-4o", + "route": "complex_reasoning", + "trace_id": "1a2b3c4d5e6f7a8b9c0d1e2f3a4b5c6d", + "session_id": "demo-session-002", + "pinned": false +} + === Demo Complete === ``` diff --git a/demos/llm_routing/model_routing_service/demo.sh b/demos/llm_routing/model_routing_service/demo.sh index 0c3fdc5d..1e3d3b6c 100755 --- a/demos/llm_routing/model_routing_service/demo.sh +++ b/demos/llm_routing/model_routing_service/demo.sh @@ -117,4 +117,47 @@ curl -s "$PLANO_URL/routing/v1/messages" \ }' | python3 -m json.tool echo "" +# --- Example 7: Session pinning - first call (fresh routing) --- +echo "--- 7. Session pinning - first call (fresh routing decision) ---" +echo "" +curl -s "$PLANO_URL/routing/v1/chat/completions" \ + -H "Content-Type: application/json" \ + -H "X-Session-Id: demo-session-001" \ + -d '{ + "model": "gpt-4o-mini", + "messages": [ + {"role": "user", "content": "Write a Python function that implements binary search on a sorted array"} + ] + }' | python3 -m json.tool +echo "" + +# --- Example 8: Session pinning - second call (pinned result) --- +echo "--- 8. Session pinning - second call (same session, pinned) ---" +echo " Notice: same model returned with \"pinned\": true, routing was skipped" +echo "" +curl -s "$PLANO_URL/routing/v1/chat/completions" \ + -H "Content-Type: application/json" \ + -H "X-Session-Id: demo-session-001" \ + -d '{ + "model": "gpt-4o-mini", + "messages": [ + {"role": "user", "content": "Now explain how merge sort works and when to prefer it over quicksort"} + ] + }' | python3 -m json.tool +echo "" + +# --- Example 9: Different session gets fresh routing --- +echo "--- 9. Different session gets its own fresh routing ---" +echo "" +curl -s "$PLANO_URL/routing/v1/chat/completions" \ + -H "Content-Type: application/json" \ + -H "X-Session-Id: demo-session-002" \ + -d '{ + "model": "gpt-4o-mini", + "messages": [ + {"role": "user", "content": "Explain the trade-offs between microservices and monolithic architectures"} + ] + }' | python3 -m json.tool +echo "" + echo "=== Demo Complete ==="