diff --git a/config/plano_config_schema.yaml b/config/plano_config_schema.yaml index 891a4005..95a2e5cc 100644 --- a/config/plano_config_schema.yaml +++ b/config/plano_config_schema.yaml @@ -454,6 +454,11 @@ properties: url: type: string description: Redis URL, e.g. redis://localhost:6379. Required when type is redis. + tenant_header: + type: string + description: > + Optional HTTP header name whose value is used as a tenant prefix in the cache key. + When set, keys are scoped as plano:affinity:{tenant_id}:{session_id}. additionalProperties: false additionalProperties: false state_storage: diff --git a/crates/brightstaff/src/handlers/llm/mod.rs b/crates/brightstaff/src/handlers/llm/mod.rs index 80455cfb..5e108c56 100644 --- a/crates/brightstaff/src/handlers/llm/mod.rs +++ b/crates/brightstaff/src/handlers/llm/mod.rs @@ -99,10 +99,16 @@ async fn llm_chat_inner( .get(MODEL_AFFINITY_HEADER) .and_then(|h| h.to_str().ok()) .map(|s| s.to_string()); + let tenant_id: Option = state + .router_service + .tenant_header() + .and_then(|hdr| request_headers.get(hdr)) + .and_then(|v| v.to_str().ok()) + .map(|s| s.to_string()); let pinned_model: Option = if let Some(ref sid) = session_id { state .router_service - .get_cached_route(sid) + .get_cached_route(sid, tenant_id.as_deref()) .await .map(|c| c.model_name) } else { @@ -313,7 +319,7 @@ async fn llm_chat_inner( if let Some(ref sid) = session_id { state .router_service - .cache_route(sid.clone(), model.clone(), route_name) + .cache_route(sid.clone(), tenant_id.as_deref(), model.clone(), route_name) .await; } diff --git a/crates/brightstaff/src/handlers/routing_service.rs b/crates/brightstaff/src/handlers/routing_service.rs index d09afe21..3365b6e9 100644 --- a/crates/brightstaff/src/handlers/routing_service.rs +++ b/crates/brightstaff/src/handlers/routing_service.rs @@ -76,6 +76,12 @@ pub async fn routing_decision( .and_then(|h| h.to_str().ok()) .map(|s| s.to_string()); + let tenant_id: Option = router_service + .tenant_header() + .and_then(|hdr| request_headers.get(hdr)) + .and_then(|v| v.to_str().ok()) + .map(|s| s.to_string()); + let custom_attrs = collect_custom_trace_attributes(&request_headers, span_attributes.as_ref()); let request_span = info_span!( @@ -94,11 +100,13 @@ pub async fn routing_decision( request_headers, custom_attrs, session_id, + tenant_id, ) .instrument(request_span) .await } +#[allow(clippy::too_many_arguments)] async fn routing_decision_inner( request: Request, router_service: Arc, @@ -107,6 +115,7 @@ async fn routing_decision_inner( request_headers: hyper::HeaderMap, custom_attrs: std::collections::HashMap, session_id: Option, + tenant_id: Option, ) -> Result>, hyper::Error> { set_service_name(operation_component::ROUTING); opentelemetry::trace::get_active_span(|span| { @@ -126,7 +135,10 @@ async fn routing_decision_inner( // Session pinning: check cache before doing any routing work if let Some(ref sid) = session_id { - if let Some(cached) = router_service.get_cached_route(sid).await { + if let Some(cached) = router_service + .get_cached_route(sid, tenant_id.as_deref()) + .await + { info!( session_id = %sid, model = %cached.model_name, @@ -206,6 +218,7 @@ async fn routing_decision_inner( router_service .cache_route( sid.clone(), + tenant_id.as_deref(), result.model_name.clone(), result.route_name.clone(), ) diff --git a/crates/brightstaff/src/main.rs b/crates/brightstaff/src/main.rs index 8eb46df7..73102a97 100644 --- a/crates/brightstaff/src/main.rs +++ b/crates/brightstaff/src/main.rs @@ -298,6 +298,12 @@ async fn init_app_state( } } + let session_tenant_header = config + .routing + .as_ref() + .and_then(|r| r.session_cache.as_ref()) + .and_then(|c| c.tenant_header.clone()); + let router_service = Arc::new(RouterService::new( config.routing_preferences.clone(), metrics_service, @@ -306,6 +312,7 @@ async fn init_app_state( routing_llm_provider, session_ttl_seconds, session_cache, + session_tenant_header, )); let orchestrator_model_name: String = overrides diff --git a/crates/brightstaff/src/router/llm.rs b/crates/brightstaff/src/router/llm.rs index 29385768..b1a74641 100644 --- a/crates/brightstaff/src/router/llm.rs +++ b/crates/brightstaff/src/router/llm.rs @@ -1,4 +1,4 @@ -use std::{collections::HashMap, sync::Arc, time::Duration}; +use std::{borrow::Cow, collections::HashMap, sync::Arc, time::Duration}; use common::{ configuration::TopLevelRoutingPreference, @@ -31,6 +31,7 @@ pub struct RouterService { metrics_service: Option>, session_cache: Arc, session_ttl: Duration, + tenant_header: Option, } #[derive(Debug, Error)] @@ -45,6 +46,7 @@ pub enum RoutingError { pub type Result = std::result::Result; impl RouterService { + #[allow(clippy::too_many_arguments)] pub fn new( top_level_prefs: Option>, metrics_service: Option>, @@ -53,6 +55,7 @@ impl RouterService { routing_provider_name: String, session_ttl_seconds: Option, session_cache: Arc, + tenant_header: Option, ) -> Self { let top_level_preferences: HashMap = top_level_prefs .map_or_else(HashMap::new, |prefs| { @@ -95,25 +98,48 @@ impl RouterService { metrics_service, session_cache, session_ttl, + tenant_header, + } + } + + /// Name of the HTTP header used to scope cache keys by tenant, if configured. + #[must_use] + pub fn tenant_header(&self) -> Option<&str> { + self.tenant_header.as_deref() + } + + /// Build the cache key, optionally scoped by tenant: `{tenant_id}:{session_id}` or `{session_id}`. + /// Returns a borrowed key when no tenant prefix is needed, avoiding an allocation. + fn session_key<'a>(tenant_id: Option<&str>, session_id: &'a str) -> Cow<'a, str> { + match tenant_id { + Some(t) => Cow::Owned(format!("{t}:{session_id}")), + None => Cow::Borrowed(session_id), } } /// Look up a cached routing decision by session ID. /// Returns None if not found or expired. - pub async fn get_cached_route(&self, session_id: &str) -> Option { - self.session_cache.get(session_id).await + pub async fn get_cached_route( + &self, + session_id: &str, + tenant_id: Option<&str>, + ) -> Option { + self.session_cache + .get(&Self::session_key(tenant_id, session_id)) + .await } /// Store a routing decision in the session cache. pub async fn cache_route( &self, session_id: String, + tenant_id: Option<&str>, model_name: String, route_name: Option, ) { self.session_cache .put( - &session_id, + &Self::session_key(tenant_id, &session_id), CachedRoute { model_name, route_name, @@ -123,16 +149,6 @@ impl RouterService { .await; } - /// Log a routing decision, used to surface affinity hits in structured logs. - pub fn log_affinity_hit(session_id: &str, model_name: &str, route_name: &Option) { - info!( - session_id = %session_id, - model = %model_name, - route = ?route_name, - "returning pinned routing decision from cache" - ); - } - pub async fn determine_route( &self, messages: &[Message], @@ -263,13 +279,17 @@ mod tests { "arch-router".to_string(), Some(ttl_seconds), session_cache, + None, ) } #[tokio::test] async fn test_cache_miss_returns_none() { let svc = make_router_service(600, 100); - assert!(svc.get_cached_route("unknown-session").await.is_none()); + assert!(svc + .get_cached_route("unknown-session", None) + .await + .is_none()); } #[tokio::test] @@ -277,12 +297,13 @@ mod tests { let svc = make_router_service(600, 100); svc.cache_route( "s1".to_string(), + None, "gpt-4o".to_string(), Some("code".to_string()), ) .await; - let cached = svc.get_cached_route("s1").await.unwrap(); + let cached = svc.get_cached_route("s1", None).await.unwrap(); assert_eq!(cached.model_name, "gpt-4o"); assert_eq!(cached.route_name, Some("code".to_string())); } @@ -290,60 +311,61 @@ mod tests { #[tokio::test] async fn test_cache_expired_entry_returns_none() { let svc = make_router_service(0, 100); - svc.cache_route("s1".to_string(), "gpt-4o".to_string(), None) + svc.cache_route("s1".to_string(), None, "gpt-4o".to_string(), None) .await; - assert!(svc.get_cached_route("s1").await.is_none()); + assert!(svc.get_cached_route("s1", None).await.is_none()); } #[tokio::test] async fn test_expired_entries_not_returned() { let svc = make_router_service(0, 100); - svc.cache_route("s1".to_string(), "gpt-4o".to_string(), None) + svc.cache_route("s1".to_string(), None, "gpt-4o".to_string(), None) .await; - svc.cache_route("s2".to_string(), "claude".to_string(), None) + svc.cache_route("s2".to_string(), None, "claude".to_string(), None) .await; // Entries with TTL=0 should be expired immediately - assert!(svc.get_cached_route("s1").await.is_none()); - assert!(svc.get_cached_route("s2").await.is_none()); + assert!(svc.get_cached_route("s1", None).await.is_none()); + assert!(svc.get_cached_route("s2", None).await.is_none()); } #[tokio::test] async fn test_cache_evicts_oldest_when_full() { let svc = make_router_service(600, 2); - svc.cache_route("s1".to_string(), "model-a".to_string(), None) + svc.cache_route("s1".to_string(), None, "model-a".to_string(), None) .await; tokio::time::sleep(Duration::from_millis(10)).await; - svc.cache_route("s2".to_string(), "model-b".to_string(), None) + svc.cache_route("s2".to_string(), None, "model-b".to_string(), None) .await; - svc.cache_route("s3".to_string(), "model-c".to_string(), None) + svc.cache_route("s3".to_string(), None, "model-c".to_string(), None) .await; // s1 should be evicted (oldest); s2 and s3 should remain - assert!(svc.get_cached_route("s1").await.is_none()); - assert!(svc.get_cached_route("s2").await.is_some()); - assert!(svc.get_cached_route("s3").await.is_some()); + assert!(svc.get_cached_route("s1", None).await.is_none()); + assert!(svc.get_cached_route("s2", None).await.is_some()); + assert!(svc.get_cached_route("s3", None).await.is_some()); } #[tokio::test] async fn test_cache_update_existing_session_does_not_evict() { let svc = make_router_service(600, 2); - svc.cache_route("s1".to_string(), "model-a".to_string(), None) + svc.cache_route("s1".to_string(), None, "model-a".to_string(), None) .await; - svc.cache_route("s2".to_string(), "model-b".to_string(), None) + svc.cache_route("s2".to_string(), None, "model-b".to_string(), None) .await; svc.cache_route( "s1".to_string(), + None, "model-a-updated".to_string(), Some("route".to_string()), ) .await; // Both sessions should still be present - let s1 = svc.get_cached_route("s1").await.unwrap(); + let s1 = svc.get_cached_route("s1", None).await.unwrap(); assert_eq!(s1.model_name, "model-a-updated"); - assert!(svc.get_cached_route("s2").await.is_some()); + assert!(svc.get_cached_route("s2", None).await.is_some()); } } diff --git a/crates/common/src/configuration.rs b/crates/common/src/configuration.rs index d4065d0d..10114274 100644 --- a/crates/common/src/configuration.rs +++ b/crates/common/src/configuration.rs @@ -21,6 +21,9 @@ pub struct SessionCacheConfig { pub cache_type: SessionCacheType, /// Redis URL, e.g. `redis://localhost:6379`. Required when `type` is `redis`. pub url: Option, + /// Optional HTTP header name whose value is used as a tenant prefix in the cache key. + /// When set, keys are scoped as `plano:affinity:{tenant_id}:{session_id}`. + pub tenant_header: Option, } #[derive(Debug, Clone, Serialize, Deserialize)] diff --git a/docs/source/resources/includes/plano_config_full_reference.yaml b/docs/source/resources/includes/plano_config_full_reference.yaml index 049dba67..e9c89175 100644 --- a/docs/source/resources/includes/plano_config_full_reference.yaml +++ b/docs/source/resources/includes/plano_config_full_reference.yaml @@ -185,6 +185,7 @@ routing: type: memory # "memory" (default) or "redis" # url is required when type is "redis". Supports redis:// and rediss:// (TLS). # url: redis://localhost:6379 + # tenant_header: x-org-id # optional; when set, keys are scoped as plano:affinity:{tenant_id}:{session_id} # State storage for multi-turn conversation history state_storage: