diff --git a/config/plano_config_schema.yaml b/config/plano_config_schema.yaml index d681b089..891a4005 100644 --- a/config/plano_config_schema.yaml +++ b/config/plano_config_schema.yaml @@ -441,6 +441,20 @@ properties: minimum: 1 maximum: 10000 description: Maximum number of session-pinned routing cache entries. Default 10000. + session_cache: + type: object + properties: + type: + type: string + enum: + - memory + - redis + default: memory + description: Session cache backend. "memory" (default) is in-process; "redis" is shared across replicas. + url: + type: string + description: Redis URL, e.g. redis://localhost:6379. Required when type is redis. + additionalProperties: false additionalProperties: false state_storage: type: object diff --git a/crates/Cargo.lock b/crates/Cargo.lock index fbf817e7..fc43f035 100644 --- a/crates/Cargo.lock +++ b/crates/Cargo.lock @@ -68,6 +68,15 @@ version = "1.0.98" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e16d2d3311acee920a9eb8d33b8cbc1787ce4a264e85f964c2404b969bdcd487" +[[package]] +name = "arc-swap" +version = "1.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a3a1fd6f75306b68087b831f025c712524bcb19aad54e557b1129cfa0a2b207" +dependencies = [ + "rustversion", +] + [[package]] name = "assert-json-diff" version = "2.0.2" @@ -330,6 +339,7 @@ dependencies = [ "opentelemetry_sdk", "pretty_assertions", "rand 0.9.2", + "redis", "reqwest", "serde", "serde_json", @@ -428,7 +438,21 @@ version = "3.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fde0e0ec90c9dfb3b4b1a0891a7dcd0e2bffde2f7efed5fe7c9bb00e5bfb915e" dependencies = [ - "windows-sys 0.52.0", + "windows-sys 0.59.0", +] + +[[package]] +name = "combine" +version = "4.6.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba5a308b75df32fe02788e748662718f03fde005016435c444eea572398219fd" +dependencies = [ + "bytes", + "futures-core", + "memchr", + "pin-project-lite", + "tokio", + "tokio-util", ] [[package]] @@ -710,7 +734,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cea14ef9355e3beab063703aa9dab15afd25f0667c341310c1e5274bb1d0da18" dependencies = [ "libc", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -1455,6 +1479,15 @@ dependencies = [ "serde", ] +[[package]] +name = "itertools" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" +dependencies = [ + "either", +] + [[package]] name = "itertools" version = "0.14.0" @@ -1721,6 +1754,16 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "num-bigint" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" +dependencies = [ + "num-integer", + "num-traits", +] + [[package]] name = "num-conv" version = "0.2.0" @@ -2101,7 +2144,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "27c6023962132f4b30eb4c172c91ce92d933da334c59c23cddee82358ddafb0b" dependencies = [ "anyhow", - "itertools", + "itertools 0.14.0", "proc-macro2", "quote", "syn 2.0.101", @@ -2169,7 +2212,7 @@ dependencies = [ "once_cell", "socket2", "tracing", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -2246,6 +2289,30 @@ dependencies = [ "getrandom 0.3.3", ] +[[package]] +name = "redis" +version = "0.27.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09d8f99a4090c89cc489a94833c901ead69bfbf3877b4867d5482e321ee875bc" +dependencies = [ + "arc-swap", + "async-trait", + "bytes", + "combine", + "futures-util", + "itertools 0.13.0", + "itoa", + "num-bigint", + "percent-encoding", + "pin-project-lite", + "ryu", + "sha1_smol", + "socket2", + "tokio", + "tokio-util", + "url", +] + [[package]] name = "redox_syscall" version = "0.5.18" @@ -2413,7 +2480,7 @@ dependencies = [ "errno", "libc", "linux-raw-sys", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -2751,6 +2818,12 @@ dependencies = [ "syn 2.0.101", ] +[[package]] +name = "sha1_smol" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbfa15b3dddfee50a0fff136974b3e1bde555604ba463834a7eb7deb6417705d" + [[package]] name = "sha2" version = "0.10.9" @@ -2940,7 +3013,7 @@ dependencies = [ "getrandom 0.3.3", "once_cell", "rustix", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] diff --git a/crates/brightstaff/Cargo.toml b/crates/brightstaff/Cargo.toml index 5d986ffa..78bc7037 100644 --- a/crates/brightstaff/Cargo.toml +++ b/crates/brightstaff/Cargo.toml @@ -26,6 +26,7 @@ opentelemetry-stdout = "0.31" opentelemetry_sdk = { version = "0.31", features = ["rt-tokio"] } pretty_assertions = "1.4.1" rand = "0.9.2" +redis = { version = "0.27", features = ["tokio-comp"] } reqwest = { version = "0.12.15", features = ["stream"] } serde = { version = "1.0.219", features = ["derive"] } serde_json = "1.0.140" diff --git a/crates/brightstaff/src/lib.rs b/crates/brightstaff/src/lib.rs index b4ab82a9..a0ba5f43 100644 --- a/crates/brightstaff/src/lib.rs +++ b/crates/brightstaff/src/lib.rs @@ -1,6 +1,7 @@ pub mod app_state; pub mod handlers; pub mod router; +pub mod session_cache; pub mod signals; pub mod state; pub mod streaming; diff --git a/crates/brightstaff/src/main.rs b/crates/brightstaff/src/main.rs index f179dc4b..b988fd74 100644 --- a/crates/brightstaff/src/main.rs +++ b/crates/brightstaff/src/main.rs @@ -8,6 +8,8 @@ use brightstaff::handlers::routing_service::routing_decision; use brightstaff::router::llm::RouterService; use brightstaff::router::model_metrics::ModelMetricsService; use brightstaff::router::orchestrator::OrchestratorService; +use brightstaff::session_cache::memory::MemorySessionCache; +use brightstaff::session_cache::SessionCache; use brightstaff::state::memory::MemoryConversationalStorage; use brightstaff::state::postgresql::PostgreSQLConversationStorage; use brightstaff::state::StateStorage; @@ -175,7 +177,7 @@ async fn init_app_state( .unwrap_or_else(|| DEFAULT_ROUTING_LLM_PROVIDER.to_string()); let session_ttl_seconds = config.routing.as_ref().and_then(|r| r.session_ttl_seconds); - let session_max_entries = config.routing.as_ref().and_then(|r| r.session_max_entries); + let session_cache = init_session_cache(config).await?; // Validate that top-level routing_preferences requires v0.4.0+. let config_version = parse_semver(&config.version); @@ -304,7 +306,7 @@ async fn init_app_state( routing_model_name, routing_llm_provider, session_ttl_seconds, - session_max_entries, + session_cache, )); // Spawn background task to clean up expired session cache entries every 5 minutes @@ -361,6 +363,54 @@ async fn init_app_state( }) } +/// Initialize the session cache backend based on config. +/// Defaults to the in-memory backend when no `session_cache` is configured. +async fn init_session_cache( + config: &Configuration, +) -> Result, Box> { + use common::configuration::SessionCacheType; + use std::time::Duration; + + let session_ttl_seconds = config.routing.as_ref().and_then(|r| r.session_ttl_seconds); + let session_max_entries = config.routing.as_ref().and_then(|r| r.session_max_entries); + + const DEFAULT_SESSION_TTL_SECONDS: u64 = 600; + const DEFAULT_SESSION_MAX_ENTRIES: usize = 10_000; + const MAX_SESSION_MAX_ENTRIES: usize = 10_000; + + let ttl = Duration::from_secs(session_ttl_seconds.unwrap_or(DEFAULT_SESSION_TTL_SECONDS)); + let max_entries = session_max_entries + .unwrap_or(DEFAULT_SESSION_MAX_ENTRIES) + .min(MAX_SESSION_MAX_ENTRIES); + + let cache_config = config + .routing + .as_ref() + .and_then(|r| r.session_cache.as_ref()); + + let cache_type = cache_config + .map(|c| &c.cache_type) + .unwrap_or(&SessionCacheType::Memory); + + match cache_type { + SessionCacheType::Memory => { + info!(storage_type = "memory", "initialized session cache"); + Ok(Arc::new(MemorySessionCache::new(ttl, max_entries))) + } + SessionCacheType::Redis => { + use brightstaff::session_cache::redis::RedisSessionCache; + let url = cache_config + .and_then(|c| c.url.as_ref()) + .ok_or("session_cache.url is required when type is redis")?; + info!(storage_type = "redis", url = %url, "initializing session cache"); + let cache = RedisSessionCache::new(url) + .await + .map_err(|e| format!("failed to connect to Redis session cache: {e}"))?; + Ok(Arc::new(cache)) + } + } +} + /// Initialize the conversation state storage backend (if configured). async fn init_state_storage( config: &Configuration, diff --git a/crates/brightstaff/src/router/llm.rs b/crates/brightstaff/src/router/llm.rs index 5a208c6e..b40753c2 100644 --- a/crates/brightstaff/src/router/llm.rs +++ b/crates/brightstaff/src/router/llm.rs @@ -1,4 +1,4 @@ -use std::{collections::HashMap, sync::Arc, time::Duration, time::Instant}; +use std::{collections::HashMap, sync::Arc, time::Duration}; use common::{ configuration::TopLevelRoutingPreference, @@ -9,7 +9,6 @@ use super::router_model::{ModelUsagePreference, RoutingPreference}; use hermesllm::apis::openai::Message; use hyper::header; use thiserror::Error; -use tokio::sync::RwLock; use tracing::{debug, info}; use super::http::{self, post_and_extract_content}; @@ -17,17 +16,11 @@ use super::model_metrics::ModelMetricsService; use super::router_model::RouterModel; use crate::router::router_model_v1; +use crate::session_cache::SessionCache; + +pub use crate::session_cache::CachedRoute; const DEFAULT_SESSION_TTL_SECONDS: u64 = 600; -const DEFAULT_SESSION_MAX_ENTRIES: usize = 10_000; -const MAX_SESSION_MAX_ENTRIES: usize = 10_000; - -#[derive(Clone, Debug)] -pub struct CachedRoute { - pub model_name: String, - pub route_name: Option, - pub cached_at: Instant, -} pub struct RouterService { router_url: String, @@ -36,9 +29,8 @@ pub struct RouterService { routing_provider_name: String, top_level_preferences: HashMap, metrics_service: Option>, - session_cache: RwLock>, + session_cache: Arc, session_ttl: Duration, - session_max_entries: usize, } #[derive(Debug, Error)] @@ -60,7 +52,7 @@ impl RouterService { routing_model_name: String, routing_provider_name: String, session_ttl_seconds: Option, - session_max_entries: Option, + session_cache: Arc, ) -> Self { let top_level_preferences: HashMap = top_level_prefs .map_or_else(HashMap::new, |prefs| { @@ -93,9 +85,6 @@ impl RouterService { let session_ttl = Duration::from_secs(session_ttl_seconds.unwrap_or(DEFAULT_SESSION_TTL_SECONDS)); - let session_max_entries = session_max_entries - .unwrap_or(DEFAULT_SESSION_MAX_ENTRIES) - .min(MAX_SESSION_MAX_ENTRIES); RouterService { router_url, @@ -104,65 +93,49 @@ impl RouterService { routing_provider_name, top_level_preferences, metrics_service, - session_cache: RwLock::new(HashMap::new()), + session_cache, session_ttl, - session_max_entries, } } /// Look up a cached routing decision by session ID. /// Returns None if not found or expired. pub async fn get_cached_route(&self, session_id: &str) -> Option { - let cache = self.session_cache.read().await; - if let Some(entry) = cache.get(session_id) { - if entry.cached_at.elapsed() < self.session_ttl { - return Some(entry.clone()); - } - } - None + self.session_cache.get(session_id).await } /// Store a routing decision in the session cache. - /// If at max capacity, evicts the oldest entry. pub async fn cache_route( &self, session_id: String, model_name: String, route_name: Option, ) { - let mut cache = self.session_cache.write().await; - if cache.len() >= self.session_max_entries && !cache.contains_key(&session_id) { - if let Some(oldest_key) = cache - .iter() - .min_by_key(|(_, v)| v.cached_at) - .map(|(k, _)| k.clone()) - { - cache.remove(&oldest_key); - } - } - cache.insert( - session_id, - CachedRoute { - model_name, - route_name, - cached_at: Instant::now(), - }, - ); + self.session_cache + .put( + &session_id, + CachedRoute { + model_name, + route_name, + }, + self.session_ttl, + ) + .await; } /// Remove all expired entries from the session cache. pub async fn cleanup_expired_sessions(&self) { - let mut cache = self.session_cache.write().await; - let before = cache.len(); - cache.retain(|_, entry| entry.cached_at.elapsed() < self.session_ttl); - let removed = before - cache.len(); - if removed > 0 { - info!( - removed = removed, - remaining = cache.len(), - "cleaned up expired session cache entries" - ); - } + self.session_cache.cleanup_expired().await; + } + + /// Log a routing decision, used to surface affinity hits in structured logs. + pub fn log_affinity_hit(session_id: &str, model_name: &str, route_name: &Option) { + info!( + session_id = %session_id, + model = %model_name, + route = ?route_name, + "returning pinned routing decision from cache" + ); } pub async fn determine_route( @@ -283,8 +256,11 @@ impl RouterService { #[cfg(test)] mod tests { use super::*; + use crate::session_cache::memory::MemorySessionCache; fn make_router_service(ttl_seconds: u64, max_entries: usize) -> RouterService { + let ttl = Duration::from_secs(ttl_seconds); + let session_cache = Arc::new(MemorySessionCache::new(ttl, max_entries)); RouterService::new( None, None, @@ -292,7 +268,7 @@ mod tests { "Arch-Router".to_string(), "arch-router".to_string(), Some(ttl_seconds), - Some(max_entries), + session_cache, ) } @@ -335,8 +311,9 @@ mod tests { svc.cleanup_expired_sessions().await; - let cache = svc.session_cache.read().await; - assert!(cache.is_empty()); + // After cleanup, both expired entries should be gone + assert!(svc.get_cached_route("s1").await.is_none()); + assert!(svc.get_cached_route("s2").await.is_none()); } #[tokio::test] @@ -351,11 +328,10 @@ mod tests { svc.cache_route("s3".to_string(), "model-c".to_string(), None) .await; - let cache = svc.session_cache.read().await; - assert_eq!(cache.len(), 2); - assert!(!cache.contains_key("s1")); - assert!(cache.contains_key("s2")); - assert!(cache.contains_key("s3")); + // s1 should be evicted (oldest); s2 and s3 should remain + assert!(svc.get_cached_route("s1").await.is_none()); + assert!(svc.get_cached_route("s2").await.is_some()); + assert!(svc.get_cached_route("s3").await.is_some()); } #[tokio::test] @@ -373,8 +349,9 @@ mod tests { ) .await; - let cache = svc.session_cache.read().await; - assert_eq!(cache.len(), 2); - assert_eq!(cache.get("s1").unwrap().model_name, "model-a-updated"); + // Both sessions should still be present + let s1 = svc.get_cached_route("s1").await.unwrap(); + assert_eq!(s1.model_name, "model-a-updated"); + assert!(svc.get_cached_route("s2").await.is_some()); } } diff --git a/crates/brightstaff/src/session_cache/memory.rs b/crates/brightstaff/src/session_cache/memory.rs new file mode 100644 index 00000000..fd0c5a1f --- /dev/null +++ b/crates/brightstaff/src/session_cache/memory.rs @@ -0,0 +1,73 @@ +use std::{ + collections::HashMap, + sync::Arc, + time::{Duration, Instant}, +}; + +use async_trait::async_trait; +use tokio::sync::RwLock; +use tracing::info; + +use super::{CachedRoute, SessionCache}; + +pub struct MemorySessionCache { + store: Arc>>, + ttl: Duration, + max_entries: usize, +} + +impl MemorySessionCache { + pub fn new(ttl: Duration, max_entries: usize) -> Self { + Self { + store: Arc::new(RwLock::new(HashMap::new())), + ttl, + max_entries, + } + } +} + +#[async_trait] +impl SessionCache for MemorySessionCache { + async fn get(&self, session_id: &str) -> Option { + let store = self.store.read().await; + if let Some((route, inserted_at)) = store.get(session_id) { + if inserted_at.elapsed() < self.ttl { + return Some(route.clone()); + } + } + None + } + + async fn put(&self, session_id: &str, route: CachedRoute, _ttl: Duration) { + let mut store = self.store.write().await; + if store.len() >= self.max_entries && !store.contains_key(session_id) { + if let Some(oldest_key) = store + .iter() + .min_by_key(|(_, (_, inserted_at))| *inserted_at) + .map(|(k, _)| k.clone()) + { + store.remove(&oldest_key); + } + } + store.insert(session_id.to_string(), (route, Instant::now())); + } + + async fn remove(&self, session_id: &str) { + self.store.write().await.remove(session_id); + } + + async fn cleanup_expired(&self) { + let ttl = self.ttl; + let mut store = self.store.write().await; + let before = store.len(); + store.retain(|_, (_, inserted_at)| inserted_at.elapsed() < ttl); + let removed = before - store.len(); + if removed > 0 { + info!( + removed = removed, + remaining = store.len(), + "cleaned up expired session cache entries" + ); + } + } +} diff --git a/crates/brightstaff/src/session_cache/mod.rs b/crates/brightstaff/src/session_cache/mod.rs new file mode 100644 index 00000000..e6d0d75e --- /dev/null +++ b/crates/brightstaff/src/session_cache/mod.rs @@ -0,0 +1,26 @@ +use async_trait::async_trait; +use std::time::Duration; + +pub mod memory; +pub mod redis; + +#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] +pub struct CachedRoute { + pub model_name: String, + pub route_name: Option, +} + +#[async_trait] +pub trait SessionCache: Send + Sync { + /// Look up a cached routing decision by session ID. + async fn get(&self, session_id: &str) -> Option; + + /// Store a routing decision in the session cache with the given TTL. + async fn put(&self, session_id: &str, route: CachedRoute, ttl: Duration); + + /// Remove a cached routing decision by session ID. + async fn remove(&self, session_id: &str); + + /// Remove all expired entries. No-op for backends that handle expiry natively (e.g. Redis). + async fn cleanup_expired(&self); +} diff --git a/crates/brightstaff/src/session_cache/redis.rs b/crates/brightstaff/src/session_cache/redis.rs new file mode 100644 index 00000000..2b31bdf6 --- /dev/null +++ b/crates/brightstaff/src/session_cache/redis.rs @@ -0,0 +1,46 @@ +use std::time::Duration; + +use async_trait::async_trait; +use redis::aio::MultiplexedConnection; +use redis::AsyncCommands; + +use super::{CachedRoute, SessionCache}; + +pub struct RedisSessionCache { + conn: MultiplexedConnection, +} + +impl RedisSessionCache { + pub async fn new(url: &str) -> Result { + let client = redis::Client::open(url)?; + let conn = client.get_multiplexed_async_connection().await?; + Ok(Self { conn }) + } +} + +#[async_trait] +impl SessionCache for RedisSessionCache { + async fn get(&self, session_id: &str) -> Option { + let mut conn = self.conn.clone(); + let value: Option = conn.get(session_id).await.ok()?; + value.and_then(|v| serde_json::from_str(&v).ok()) + } + + async fn put(&self, session_id: &str, route: CachedRoute, ttl: Duration) { + let mut conn = self.conn.clone(); + let Ok(json) = serde_json::to_string(&route) else { + return; + }; + let ttl_secs = ttl.as_secs().max(1); + let _: Result<(), _> = conn.set_ex(session_id, json, ttl_secs).await; + } + + async fn remove(&self, session_id: &str) { + let mut conn = self.conn.clone(); + let _: Result<(), _> = conn.del(session_id).await; + } + + async fn cleanup_expired(&self) { + // Redis handles TTL expiry natively via EX — nothing to do here. + } +} diff --git a/crates/common/src/configuration.rs b/crates/common/src/configuration.rs index c4c5924a..d4065d0d 100644 --- a/crates/common/src/configuration.rs +++ b/crates/common/src/configuration.rs @@ -7,12 +7,29 @@ use crate::api::open_ai::{ ChatCompletionTool, FunctionDefinition, FunctionParameter, FunctionParameters, ParameterType, }; +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)] +#[serde(rename_all = "lowercase")] +pub enum SessionCacheType { + #[default] + Memory, + Redis, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SessionCacheConfig { + #[serde(rename = "type", default)] + pub cache_type: SessionCacheType, + /// Redis URL, e.g. `redis://localhost:6379`. Required when `type` is `redis`. + pub url: Option, +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Routing { pub llm_provider: Option, pub model: Option, pub session_ttl_seconds: Option, pub session_max_entries: Option, + pub session_cache: Option, } #[derive(Debug, Clone, Serialize, Deserialize)]