add pluggable session cache with Redis backend

This commit is contained in:
Spherrrical 2026-04-09 16:32:31 -07:00
parent 8dedf0bec1
commit 50670f843d
10 changed files with 353 additions and 75 deletions

85
crates/Cargo.lock generated
View file

@ -68,6 +68,15 @@ version = "1.0.98"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e16d2d3311acee920a9eb8d33b8cbc1787ce4a264e85f964c2404b969bdcd487"
[[package]]
name = "arc-swap"
version = "1.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6a3a1fd6f75306b68087b831f025c712524bcb19aad54e557b1129cfa0a2b207"
dependencies = [
"rustversion",
]
[[package]]
name = "assert-json-diff"
version = "2.0.2"
@ -330,6 +339,7 @@ dependencies = [
"opentelemetry_sdk",
"pretty_assertions",
"rand 0.9.2",
"redis",
"reqwest",
"serde",
"serde_json",
@ -428,7 +438,21 @@ version = "3.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fde0e0ec90c9dfb3b4b1a0891a7dcd0e2bffde2f7efed5fe7c9bb00e5bfb915e"
dependencies = [
"windows-sys 0.52.0",
"windows-sys 0.59.0",
]
[[package]]
name = "combine"
version = "4.6.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ba5a308b75df32fe02788e748662718f03fde005016435c444eea572398219fd"
dependencies = [
"bytes",
"futures-core",
"memchr",
"pin-project-lite",
"tokio",
"tokio-util",
]
[[package]]
@ -710,7 +734,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cea14ef9355e3beab063703aa9dab15afd25f0667c341310c1e5274bb1d0da18"
dependencies = [
"libc",
"windows-sys 0.52.0",
"windows-sys 0.59.0",
]
[[package]]
@ -1455,6 +1479,15 @@ dependencies = [
"serde",
]
[[package]]
name = "itertools"
version = "0.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186"
dependencies = [
"either",
]
[[package]]
name = "itertools"
version = "0.14.0"
@ -1721,6 +1754,16 @@ dependencies = [
"windows-sys 0.59.0",
]
[[package]]
name = "num-bigint"
version = "0.4.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9"
dependencies = [
"num-integer",
"num-traits",
]
[[package]]
name = "num-conv"
version = "0.2.0"
@ -2101,7 +2144,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "27c6023962132f4b30eb4c172c91ce92d933da334c59c23cddee82358ddafb0b"
dependencies = [
"anyhow",
"itertools",
"itertools 0.14.0",
"proc-macro2",
"quote",
"syn 2.0.101",
@ -2169,7 +2212,7 @@ dependencies = [
"once_cell",
"socket2",
"tracing",
"windows-sys 0.52.0",
"windows-sys 0.59.0",
]
[[package]]
@ -2246,6 +2289,30 @@ dependencies = [
"getrandom 0.3.3",
]
[[package]]
name = "redis"
version = "0.27.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "09d8f99a4090c89cc489a94833c901ead69bfbf3877b4867d5482e321ee875bc"
dependencies = [
"arc-swap",
"async-trait",
"bytes",
"combine",
"futures-util",
"itertools 0.13.0",
"itoa",
"num-bigint",
"percent-encoding",
"pin-project-lite",
"ryu",
"sha1_smol",
"socket2",
"tokio",
"tokio-util",
"url",
]
[[package]]
name = "redox_syscall"
version = "0.5.18"
@ -2413,7 +2480,7 @@ dependencies = [
"errno",
"libc",
"linux-raw-sys",
"windows-sys 0.52.0",
"windows-sys 0.59.0",
]
[[package]]
@ -2751,6 +2818,12 @@ dependencies = [
"syn 2.0.101",
]
[[package]]
name = "sha1_smol"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bbfa15b3dddfee50a0fff136974b3e1bde555604ba463834a7eb7deb6417705d"
[[package]]
name = "sha2"
version = "0.10.9"
@ -2940,7 +3013,7 @@ dependencies = [
"getrandom 0.3.3",
"once_cell",
"rustix",
"windows-sys 0.52.0",
"windows-sys 0.59.0",
]
[[package]]

View file

@ -26,6 +26,7 @@ opentelemetry-stdout = "0.31"
opentelemetry_sdk = { version = "0.31", features = ["rt-tokio"] }
pretty_assertions = "1.4.1"
rand = "0.9.2"
redis = { version = "0.27", features = ["tokio-comp"] }
reqwest = { version = "0.12.15", features = ["stream"] }
serde = { version = "1.0.219", features = ["derive"] }
serde_json = "1.0.140"

View file

@ -1,6 +1,7 @@
pub mod app_state;
pub mod handlers;
pub mod router;
pub mod session_cache;
pub mod signals;
pub mod state;
pub mod streaming;

View file

@ -8,6 +8,8 @@ use brightstaff::handlers::routing_service::routing_decision;
use brightstaff::router::llm::RouterService;
use brightstaff::router::model_metrics::ModelMetricsService;
use brightstaff::router::orchestrator::OrchestratorService;
use brightstaff::session_cache::memory::MemorySessionCache;
use brightstaff::session_cache::SessionCache;
use brightstaff::state::memory::MemoryConversationalStorage;
use brightstaff::state::postgresql::PostgreSQLConversationStorage;
use brightstaff::state::StateStorage;
@ -175,7 +177,7 @@ async fn init_app_state(
.unwrap_or_else(|| DEFAULT_ROUTING_LLM_PROVIDER.to_string());
let session_ttl_seconds = config.routing.as_ref().and_then(|r| r.session_ttl_seconds);
let session_max_entries = config.routing.as_ref().and_then(|r| r.session_max_entries);
let session_cache = init_session_cache(config).await?;
// Validate that top-level routing_preferences requires v0.4.0+.
let config_version = parse_semver(&config.version);
@ -304,7 +306,7 @@ async fn init_app_state(
routing_model_name,
routing_llm_provider,
session_ttl_seconds,
session_max_entries,
session_cache,
));
// Spawn background task to clean up expired session cache entries every 5 minutes
@ -361,6 +363,54 @@ async fn init_app_state(
})
}
/// Initialize the session cache backend based on config.
/// Defaults to the in-memory backend when no `session_cache` is configured.
async fn init_session_cache(
config: &Configuration,
) -> Result<Arc<dyn SessionCache>, Box<dyn std::error::Error + Send + Sync>> {
use common::configuration::SessionCacheType;
use std::time::Duration;
let session_ttl_seconds = config.routing.as_ref().and_then(|r| r.session_ttl_seconds);
let session_max_entries = config.routing.as_ref().and_then(|r| r.session_max_entries);
const DEFAULT_SESSION_TTL_SECONDS: u64 = 600;
const DEFAULT_SESSION_MAX_ENTRIES: usize = 10_000;
const MAX_SESSION_MAX_ENTRIES: usize = 10_000;
let ttl = Duration::from_secs(session_ttl_seconds.unwrap_or(DEFAULT_SESSION_TTL_SECONDS));
let max_entries = session_max_entries
.unwrap_or(DEFAULT_SESSION_MAX_ENTRIES)
.min(MAX_SESSION_MAX_ENTRIES);
let cache_config = config
.routing
.as_ref()
.and_then(|r| r.session_cache.as_ref());
let cache_type = cache_config
.map(|c| &c.cache_type)
.unwrap_or(&SessionCacheType::Memory);
match cache_type {
SessionCacheType::Memory => {
info!(storage_type = "memory", "initialized session cache");
Ok(Arc::new(MemorySessionCache::new(ttl, max_entries)))
}
SessionCacheType::Redis => {
use brightstaff::session_cache::redis::RedisSessionCache;
let url = cache_config
.and_then(|c| c.url.as_ref())
.ok_or("session_cache.url is required when type is redis")?;
info!(storage_type = "redis", url = %url, "initializing session cache");
let cache = RedisSessionCache::new(url)
.await
.map_err(|e| format!("failed to connect to Redis session cache: {e}"))?;
Ok(Arc::new(cache))
}
}
}
/// Initialize the conversation state storage backend (if configured).
async fn init_state_storage(
config: &Configuration,

View file

@ -1,4 +1,4 @@
use std::{collections::HashMap, sync::Arc, time::Duration, time::Instant};
use std::{collections::HashMap, sync::Arc, time::Duration};
use common::{
configuration::TopLevelRoutingPreference,
@ -9,7 +9,6 @@ use super::router_model::{ModelUsagePreference, RoutingPreference};
use hermesllm::apis::openai::Message;
use hyper::header;
use thiserror::Error;
use tokio::sync::RwLock;
use tracing::{debug, info};
use super::http::{self, post_and_extract_content};
@ -17,17 +16,11 @@ use super::model_metrics::ModelMetricsService;
use super::router_model::RouterModel;
use crate::router::router_model_v1;
use crate::session_cache::SessionCache;
pub use crate::session_cache::CachedRoute;
const DEFAULT_SESSION_TTL_SECONDS: u64 = 600;
const DEFAULT_SESSION_MAX_ENTRIES: usize = 10_000;
const MAX_SESSION_MAX_ENTRIES: usize = 10_000;
#[derive(Clone, Debug)]
pub struct CachedRoute {
pub model_name: String,
pub route_name: Option<String>,
pub cached_at: Instant,
}
pub struct RouterService {
router_url: String,
@ -36,9 +29,8 @@ pub struct RouterService {
routing_provider_name: String,
top_level_preferences: HashMap<String, TopLevelRoutingPreference>,
metrics_service: Option<Arc<ModelMetricsService>>,
session_cache: RwLock<HashMap<String, CachedRoute>>,
session_cache: Arc<dyn SessionCache>,
session_ttl: Duration,
session_max_entries: usize,
}
#[derive(Debug, Error)]
@ -60,7 +52,7 @@ impl RouterService {
routing_model_name: String,
routing_provider_name: String,
session_ttl_seconds: Option<u64>,
session_max_entries: Option<usize>,
session_cache: Arc<dyn SessionCache>,
) -> Self {
let top_level_preferences: HashMap<String, TopLevelRoutingPreference> = top_level_prefs
.map_or_else(HashMap::new, |prefs| {
@ -93,9 +85,6 @@ impl RouterService {
let session_ttl =
Duration::from_secs(session_ttl_seconds.unwrap_or(DEFAULT_SESSION_TTL_SECONDS));
let session_max_entries = session_max_entries
.unwrap_or(DEFAULT_SESSION_MAX_ENTRIES)
.min(MAX_SESSION_MAX_ENTRIES);
RouterService {
router_url,
@ -104,65 +93,49 @@ impl RouterService {
routing_provider_name,
top_level_preferences,
metrics_service,
session_cache: RwLock::new(HashMap::new()),
session_cache,
session_ttl,
session_max_entries,
}
}
/// Look up a cached routing decision by session ID.
/// Returns None if not found or expired.
pub async fn get_cached_route(&self, session_id: &str) -> Option<CachedRoute> {
let cache = self.session_cache.read().await;
if let Some(entry) = cache.get(session_id) {
if entry.cached_at.elapsed() < self.session_ttl {
return Some(entry.clone());
}
}
None
self.session_cache.get(session_id).await
}
/// Store a routing decision in the session cache.
/// If at max capacity, evicts the oldest entry.
pub async fn cache_route(
&self,
session_id: String,
model_name: String,
route_name: Option<String>,
) {
let mut cache = self.session_cache.write().await;
if cache.len() >= self.session_max_entries && !cache.contains_key(&session_id) {
if let Some(oldest_key) = cache
.iter()
.min_by_key(|(_, v)| v.cached_at)
.map(|(k, _)| k.clone())
{
cache.remove(&oldest_key);
}
}
cache.insert(
session_id,
CachedRoute {
model_name,
route_name,
cached_at: Instant::now(),
},
);
self.session_cache
.put(
&session_id,
CachedRoute {
model_name,
route_name,
},
self.session_ttl,
)
.await;
}
/// Remove all expired entries from the session cache.
pub async fn cleanup_expired_sessions(&self) {
let mut cache = self.session_cache.write().await;
let before = cache.len();
cache.retain(|_, entry| entry.cached_at.elapsed() < self.session_ttl);
let removed = before - cache.len();
if removed > 0 {
info!(
removed = removed,
remaining = cache.len(),
"cleaned up expired session cache entries"
);
}
self.session_cache.cleanup_expired().await;
}
/// Log a routing decision, used to surface affinity hits in structured logs.
pub fn log_affinity_hit(session_id: &str, model_name: &str, route_name: &Option<String>) {
info!(
session_id = %session_id,
model = %model_name,
route = ?route_name,
"returning pinned routing decision from cache"
);
}
pub async fn determine_route(
@ -283,8 +256,11 @@ impl RouterService {
#[cfg(test)]
mod tests {
use super::*;
use crate::session_cache::memory::MemorySessionCache;
fn make_router_service(ttl_seconds: u64, max_entries: usize) -> RouterService {
let ttl = Duration::from_secs(ttl_seconds);
let session_cache = Arc::new(MemorySessionCache::new(ttl, max_entries));
RouterService::new(
None,
None,
@ -292,7 +268,7 @@ mod tests {
"Arch-Router".to_string(),
"arch-router".to_string(),
Some(ttl_seconds),
Some(max_entries),
session_cache,
)
}
@ -335,8 +311,9 @@ mod tests {
svc.cleanup_expired_sessions().await;
let cache = svc.session_cache.read().await;
assert!(cache.is_empty());
// After cleanup, both expired entries should be gone
assert!(svc.get_cached_route("s1").await.is_none());
assert!(svc.get_cached_route("s2").await.is_none());
}
#[tokio::test]
@ -351,11 +328,10 @@ mod tests {
svc.cache_route("s3".to_string(), "model-c".to_string(), None)
.await;
let cache = svc.session_cache.read().await;
assert_eq!(cache.len(), 2);
assert!(!cache.contains_key("s1"));
assert!(cache.contains_key("s2"));
assert!(cache.contains_key("s3"));
// s1 should be evicted (oldest); s2 and s3 should remain
assert!(svc.get_cached_route("s1").await.is_none());
assert!(svc.get_cached_route("s2").await.is_some());
assert!(svc.get_cached_route("s3").await.is_some());
}
#[tokio::test]
@ -373,8 +349,9 @@ mod tests {
)
.await;
let cache = svc.session_cache.read().await;
assert_eq!(cache.len(), 2);
assert_eq!(cache.get("s1").unwrap().model_name, "model-a-updated");
// Both sessions should still be present
let s1 = svc.get_cached_route("s1").await.unwrap();
assert_eq!(s1.model_name, "model-a-updated");
assert!(svc.get_cached_route("s2").await.is_some());
}
}

View file

@ -0,0 +1,73 @@
use std::{
collections::HashMap,
sync::Arc,
time::{Duration, Instant},
};
use async_trait::async_trait;
use tokio::sync::RwLock;
use tracing::info;
use super::{CachedRoute, SessionCache};
pub struct MemorySessionCache {
store: Arc<RwLock<HashMap<String, (CachedRoute, Instant)>>>,
ttl: Duration,
max_entries: usize,
}
impl MemorySessionCache {
pub fn new(ttl: Duration, max_entries: usize) -> Self {
Self {
store: Arc::new(RwLock::new(HashMap::new())),
ttl,
max_entries,
}
}
}
#[async_trait]
impl SessionCache for MemorySessionCache {
async fn get(&self, session_id: &str) -> Option<CachedRoute> {
let store = self.store.read().await;
if let Some((route, inserted_at)) = store.get(session_id) {
if inserted_at.elapsed() < self.ttl {
return Some(route.clone());
}
}
None
}
async fn put(&self, session_id: &str, route: CachedRoute, _ttl: Duration) {
let mut store = self.store.write().await;
if store.len() >= self.max_entries && !store.contains_key(session_id) {
if let Some(oldest_key) = store
.iter()
.min_by_key(|(_, (_, inserted_at))| *inserted_at)
.map(|(k, _)| k.clone())
{
store.remove(&oldest_key);
}
}
store.insert(session_id.to_string(), (route, Instant::now()));
}
async fn remove(&self, session_id: &str) {
self.store.write().await.remove(session_id);
}
async fn cleanup_expired(&self) {
let ttl = self.ttl;
let mut store = self.store.write().await;
let before = store.len();
store.retain(|_, (_, inserted_at)| inserted_at.elapsed() < ttl);
let removed = before - store.len();
if removed > 0 {
info!(
removed = removed,
remaining = store.len(),
"cleaned up expired session cache entries"
);
}
}
}

View file

@ -0,0 +1,26 @@
use async_trait::async_trait;
use std::time::Duration;
pub mod memory;
pub mod redis;
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct CachedRoute {
pub model_name: String,
pub route_name: Option<String>,
}
#[async_trait]
pub trait SessionCache: Send + Sync {
/// Look up a cached routing decision by session ID.
async fn get(&self, session_id: &str) -> Option<CachedRoute>;
/// Store a routing decision in the session cache with the given TTL.
async fn put(&self, session_id: &str, route: CachedRoute, ttl: Duration);
/// Remove a cached routing decision by session ID.
async fn remove(&self, session_id: &str);
/// Remove all expired entries. No-op for backends that handle expiry natively (e.g. Redis).
async fn cleanup_expired(&self);
}

View file

@ -0,0 +1,46 @@
use std::time::Duration;
use async_trait::async_trait;
use redis::aio::MultiplexedConnection;
use redis::AsyncCommands;
use super::{CachedRoute, SessionCache};
pub struct RedisSessionCache {
conn: MultiplexedConnection,
}
impl RedisSessionCache {
pub async fn new(url: &str) -> Result<Self, redis::RedisError> {
let client = redis::Client::open(url)?;
let conn = client.get_multiplexed_async_connection().await?;
Ok(Self { conn })
}
}
#[async_trait]
impl SessionCache for RedisSessionCache {
async fn get(&self, session_id: &str) -> Option<CachedRoute> {
let mut conn = self.conn.clone();
let value: Option<String> = conn.get(session_id).await.ok()?;
value.and_then(|v| serde_json::from_str(&v).ok())
}
async fn put(&self, session_id: &str, route: CachedRoute, ttl: Duration) {
let mut conn = self.conn.clone();
let Ok(json) = serde_json::to_string(&route) else {
return;
};
let ttl_secs = ttl.as_secs().max(1);
let _: Result<(), _> = conn.set_ex(session_id, json, ttl_secs).await;
}
async fn remove(&self, session_id: &str) {
let mut conn = self.conn.clone();
let _: Result<(), _> = conn.del(session_id).await;
}
async fn cleanup_expired(&self) {
// Redis handles TTL expiry natively via EX — nothing to do here.
}
}

View file

@ -7,12 +7,29 @@ use crate::api::open_ai::{
ChatCompletionTool, FunctionDefinition, FunctionParameter, FunctionParameters, ParameterType,
};
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
#[serde(rename_all = "lowercase")]
pub enum SessionCacheType {
#[default]
Memory,
Redis,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionCacheConfig {
#[serde(rename = "type", default)]
pub cache_type: SessionCacheType,
/// Redis URL, e.g. `redis://localhost:6379`. Required when `type` is `redis`.
pub url: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Routing {
pub llm_provider: Option<String>,
pub model: Option<String>,
pub session_ttl_seconds: Option<u64>,
pub session_max_entries: Option<usize>,
pub session_cache: Option<SessionCacheConfig>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]