feat: add SessionCache trait with memory and redis backends for model affinity

Agent-Logs-Url: https://github.com/katanemo/plano/sessions/6a75240b-4578-409d-b8c7-eff47dba8a03

Co-authored-by: adilhafeez <13196462+adilhafeez@users.noreply.github.com>
This commit is contained in:
copilot-swe-agent[bot] 2026-04-09 03:44:57 +00:00 committed by GitHub
parent b822e27957
commit 66f8230dd5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 1168 additions and 658 deletions

View file

@ -441,6 +441,20 @@ properties:
minimum: 1
maximum: 10000
description: Maximum number of session-pinned routing cache entries. Default 10000.
session_cache:
type: object
description: Session cache backend configuration. Defaults to in-memory if omitted.
properties:
type:
type: string
enum: [memory, redis]
description: "Cache backend type. 'memory' (default) is in-process; 'redis' is shared across replicas."
url:
type: string
description: "Redis connection URL (e.g. redis://localhost:6379). Required when type is 'redis'."
required:
- type
additionalProperties: false
additionalProperties: false
state_storage:
type: object

1308
crates/Cargo.lock generated

File diff suppressed because it is too large Load diff

View file

@ -26,6 +26,7 @@ opentelemetry-stdout = "0.31"
opentelemetry_sdk = { version = "0.31", features = ["rt-tokio"] }
pretty_assertions = "1.4.1"
rand = "0.9.2"
redis = { version = "0.27.6", features = ["tokio-comp"] }
reqwest = { version = "0.12.15", features = ["stream"] }
serde = { version = "1.0.219", features = ["derive"] }
serde_json = "1.0.140"

View file

@ -8,6 +8,7 @@ use brightstaff::handlers::routing_service::routing_decision;
use brightstaff::router::llm::RouterService;
use brightstaff::router::model_metrics::ModelMetricsService;
use brightstaff::router::orchestrator::OrchestratorService;
use brightstaff::router::session_cache::{MemorySessionCache, RedisSessionCache, SessionCache};
use brightstaff::state::memory::MemoryConversationalStorage;
use brightstaff::state::postgresql::PostgreSQLConversationStorage;
use brightstaff::state::StateStorage;
@ -176,6 +177,10 @@ async fn init_app_state(
let session_ttl_seconds = config.routing.as_ref().and_then(|r| r.session_ttl_seconds);
let session_max_entries = config.routing.as_ref().and_then(|r| r.session_max_entries);
let session_cache_config = config
.routing
.as_ref()
.and_then(|r| r.session_cache.clone());
// Validate that top-level routing_preferences requires v0.4.0+.
let config_version = parse_semver(&config.version);
@ -297,17 +302,24 @@ async fn init_app_state(
}
}
let router_service = Arc::new(RouterService::new(
let session_cache = init_session_cache(
session_cache_config,
session_ttl_seconds,
session_max_entries,
)
.await?;
let router_service = Arc::new(RouterService::with_cache(
config.routing_preferences.clone(),
metrics_service,
format!("{llm_provider_url}{CHAT_COMPLETIONS_PATH}"),
routing_model_name,
routing_llm_provider,
session_ttl_seconds,
session_max_entries,
session_cache,
));
// Spawn background task to clean up expired session cache entries every 5 minutes
// Spawn background task to clean up expired session cache entries every 5 minutes.
// For the Redis backend this is a no-op because Redis handles TTL natively.
{
let router_service = Arc::clone(&router_service);
tokio::spawn(async move {
@ -361,6 +373,59 @@ async fn init_app_state(
})
}
/// Initialize the session cache backend from the routing config.
///
/// - No config → memory backend with defaults
/// - `type: memory` → memory backend (explicit)
/// - `type: redis` → Redis backend; `url` is required
///
/// `session_ttl_seconds` falls back to `DEFAULT_SESSION_TTL_SECONDS` and
/// `session_max_entries` falls back to `DEFAULT_SESSION_MAX_ENTRIES`, capped at
/// `MAX_SESSION_MAX_ENTRIES`. The entry limit is only passed to the memory
/// backend; the Redis backend is constructed from the URL and TTL alone.
///
/// # Errors
///
/// Returns an error when `type: redis` is selected without a `url`, or when the
/// Redis backend cannot be initialized.
async fn init_session_cache(
    cache_config: Option<common::configuration::SessionCacheConfig>,
    session_ttl_seconds: Option<u64>,
    session_max_entries: Option<usize>,
) -> Result<Arc<dyn SessionCache>, Box<dyn std::error::Error + Send + Sync>> {
    use brightstaff::router::llm::{
        DEFAULT_SESSION_MAX_ENTRIES, DEFAULT_SESSION_TTL_SECONDS, MAX_SESSION_MAX_ENTRIES,
    };
    use common::configuration::SessionCacheType;

    let ttl_secs = session_ttl_seconds.unwrap_or(DEFAULT_SESSION_TTL_SECONDS);
    let max_entries = session_max_entries
        .unwrap_or(DEFAULT_SESSION_MAX_ENTRIES)
        .min(MAX_SESSION_MAX_ENTRIES);

    let cache_type = cache_config
        .as_ref()
        .map(|c| &c.cache_type)
        .unwrap_or(&SessionCacheType::Memory);

    let cache: Arc<dyn SessionCache> = match cache_type {
        SessionCacheType::Memory => {
            info!(cache_type = "memory", "initialized session cache");
            Arc::new(MemorySessionCache::new(
                std::time::Duration::from_secs(ttl_secs),
                max_entries,
            ))
        }
        SessionCacheType::Redis => {
            let url = cache_config
                .as_ref()
                .and_then(|c| c.url.as_deref())
                .ok_or("session_cache.url is required for redis session cache")?;

            // A Redis URL may embed credentials (`redis://user:password@host`),
            // so only a redacted form is ever written to the logs.
            let safe_url = redact_session_cache_url(url);
            debug!(url = %safe_url, "redis session cache connection");
            info!(cache_type = "redis", url = %safe_url, "initializing session cache");

            Arc::new(
                RedisSessionCache::new(url, ttl_secs)
                    .await
                    .map_err(|e| format!("failed to initialize Redis session cache: {e}"))?,
            )
        }
    };

    Ok(cache)
}

/// Return `url` with any userinfo section (everything before the last `@`)
/// replaced by `***`, so that the result is safe to log.
///
/// URLs without an `@` are returned unchanged.
fn redact_session_cache_url(url: &str) -> String {
    let (scheme, rest) = match url.split_once("://") {
        Some((scheme, rest)) => (Some(scheme), rest),
        None => (None, url),
    };

    // Split on the *last* `@` so that an unescaped `@` inside the password
    // cannot cause part of the secret to leak into the output.
    let Some((_, host_and_path)) = rest.rsplit_once('@') else {
        return url.to_string();
    };

    match scheme {
        Some(scheme) => format!("{scheme}://***@{host_and_path}"),
        None => format!("***@{host_and_path}"),
    }
}
/// Initialize the conversation state storage backend (if configured).
async fn init_state_storage(
config: &Configuration,

View file

@ -1,4 +1,4 @@
use std::{collections::HashMap, sync::Arc, time::Duration, time::Instant};
use std::{collections::HashMap, sync::Arc, time::Duration};
use common::{
configuration::TopLevelRoutingPreference,
@ -6,10 +6,10 @@ use common::{
};
use super::router_model::{ModelUsagePreference, RoutingPreference};
use super::session_cache::{CachedRoute, MemorySessionCache, SessionCache};
use hermesllm::apis::openai::Message;
use hyper::header;
use thiserror::Error;
use tokio::sync::RwLock;
use tracing::{debug, info};
use super::http::{self, post_and_extract_content};
@ -18,16 +18,9 @@ use super::router_model::RouterModel;
use crate::router::router_model_v1;
const DEFAULT_SESSION_TTL_SECONDS: u64 = 600;
const DEFAULT_SESSION_MAX_ENTRIES: usize = 10_000;
const MAX_SESSION_MAX_ENTRIES: usize = 10_000;
#[derive(Clone, Debug)]
pub struct CachedRoute {
pub model_name: String,
pub route_name: Option<String>,
pub cached_at: Instant,
}
pub const DEFAULT_SESSION_TTL_SECONDS: u64 = 600;
pub const DEFAULT_SESSION_MAX_ENTRIES: usize = 10_000;
pub const MAX_SESSION_MAX_ENTRIES: usize = 10_000;
pub struct RouterService {
router_url: String,
@ -36,9 +29,7 @@ pub struct RouterService {
routing_provider_name: String,
top_level_preferences: HashMap<String, TopLevelRoutingPreference>,
metrics_service: Option<Arc<ModelMetricsService>>,
session_cache: RwLock<HashMap<String, CachedRoute>>,
session_ttl: Duration,
session_max_entries: usize,
session_cache: Arc<dyn SessionCache>,
}
#[derive(Debug, Error)]
@ -61,6 +52,34 @@ impl RouterService {
routing_provider_name: String,
session_ttl_seconds: Option<u64>,
session_max_entries: Option<usize>,
) -> Self {
let session_ttl =
Duration::from_secs(session_ttl_seconds.unwrap_or(DEFAULT_SESSION_TTL_SECONDS));
let session_max_entries = session_max_entries
.unwrap_or(DEFAULT_SESSION_MAX_ENTRIES)
.min(MAX_SESSION_MAX_ENTRIES);
let session_cache: Arc<dyn SessionCache> =
Arc::new(MemorySessionCache::new(session_ttl, session_max_entries));
RouterService::with_cache(
top_level_prefs,
metrics_service,
router_url,
routing_model_name,
routing_provider_name,
session_cache,
)
}
/// Create a `RouterService` with an explicit `SessionCache` backend.
pub fn with_cache(
top_level_prefs: Option<Vec<TopLevelRoutingPreference>>,
metrics_service: Option<Arc<ModelMetricsService>>,
router_url: String,
routing_model_name: String,
routing_provider_name: String,
session_cache: Arc<dyn SessionCache>,
) -> Self {
let top_level_preferences: HashMap<String, TopLevelRoutingPreference> = top_level_prefs
.map_or_else(HashMap::new, |prefs| {
@ -91,12 +110,6 @@ impl RouterService {
router_model_v1::MAX_TOKEN_LEN,
));
let session_ttl =
Duration::from_secs(session_ttl_seconds.unwrap_or(DEFAULT_SESSION_TTL_SECONDS));
let session_max_entries = session_max_entries
.unwrap_or(DEFAULT_SESSION_MAX_ENTRIES)
.min(MAX_SESSION_MAX_ENTRIES);
RouterService {
router_url,
client: reqwest::Client::new(),
@ -104,65 +117,37 @@ impl RouterService {
routing_provider_name,
top_level_preferences,
metrics_service,
session_cache: RwLock::new(HashMap::new()),
session_ttl,
session_max_entries,
session_cache,
}
}
/// Look up a cached routing decision by session ID.
/// Returns None if not found or expired.
pub async fn get_cached_route(&self, session_id: &str) -> Option<CachedRoute> {
let cache = self.session_cache.read().await;
if let Some(entry) = cache.get(session_id) {
if entry.cached_at.elapsed() < self.session_ttl {
return Some(entry.clone());
}
}
None
self.session_cache.get(session_id).await
}
/// Store a routing decision in the session cache.
/// If at max capacity, evicts the oldest entry.
pub async fn cache_route(
&self,
session_id: String,
model_name: String,
route_name: Option<String>,
) {
let mut cache = self.session_cache.write().await;
if cache.len() >= self.session_max_entries && !cache.contains_key(&session_id) {
if let Some(oldest_key) = cache
.iter()
.min_by_key(|(_, v)| v.cached_at)
.map(|(k, _)| k.clone())
{
cache.remove(&oldest_key);
}
}
cache.insert(
session_id,
CachedRoute {
model_name,
route_name,
cached_at: Instant::now(),
},
);
let route = CachedRoute {
model_name,
route_name,
cached_at_ms: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or(Duration::ZERO)
.as_millis() as u64,
};
self.session_cache.put(&session_id, route).await;
}
/// Remove all expired entries from the session cache.
pub async fn cleanup_expired_sessions(&self) {
let mut cache = self.session_cache.write().await;
let before = cache.len();
cache.retain(|_, entry| entry.cached_at.elapsed() < self.session_ttl);
let removed = before - cache.len();
if removed > 0 {
info!(
removed = removed,
remaining = cache.len(),
"cleaned up expired session cache entries"
);
}
self.session_cache.cleanup_expired().await;
}
pub async fn determine_route(
@ -335,8 +320,8 @@ mod tests {
svc.cleanup_expired_sessions().await;
let cache = svc.session_cache.read().await;
assert!(cache.is_empty());
assert!(svc.get_cached_route("s1").await.is_none());
assert!(svc.get_cached_route("s2").await.is_none());
}
#[tokio::test]
@ -351,11 +336,12 @@ mod tests {
svc.cache_route("s3".to_string(), "model-c".to_string(), None)
.await;
let cache = svc.session_cache.read().await;
assert_eq!(cache.len(), 2);
assert!(!cache.contains_key("s1"));
assert!(cache.contains_key("s2"));
assert!(cache.contains_key("s3"));
assert!(
svc.get_cached_route("s1").await.is_none(),
"s1 should have been evicted"
);
assert!(svc.get_cached_route("s2").await.is_some());
assert!(svc.get_cached_route("s3").await.is_some());
}
#[tokio::test]
@ -373,8 +359,11 @@ mod tests {
)
.await;
let cache = svc.session_cache.read().await;
assert_eq!(cache.len(), 2);
assert_eq!(cache.get("s1").unwrap().model_name, "model-a-updated");
assert!(svc.get_cached_route("s1").await.is_some());
assert!(svc.get_cached_route("s2").await.is_some());
assert_eq!(
svc.get_cached_route("s1").await.unwrap().model_name,
"model-a-updated"
);
}
}

View file

@ -6,3 +6,4 @@ pub mod orchestrator_model;
pub mod orchestrator_model_v1;
pub mod router_model;
pub mod router_model_v1;
pub mod session_cache;

View file

@ -0,0 +1,283 @@
use async_trait::async_trait;
use redis::aio::MultiplexedConnection;
use redis::AsyncCommands;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::RwLock;
use tracing::info;
/// A cached routing decision stored by session ID.
///
/// Derives `Serialize`/`Deserialize` so that the Redis backend can store the
/// entry as a JSON string.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct CachedRoute {
/// Name of the model recorded for the session.
pub model_name: String,
/// Name of the route recorded for the session, if one was recorded.
pub route_name: Option<String>,
/// Milliseconds since the UNIX epoch when this entry was created.
/// Used only by the memory backend for TTL checks and eviction ordering;
/// Redis uses native key expiry and ignores this field.
pub cached_at_ms: u64,
}
/// Abstracts the session-affinity cache so it can be backed by in-memory
/// storage (default, single-replica) or Redis (multi-replica).
///
/// Implementations must be `Send + Sync` because the cache is shared behind an
/// `Arc<dyn SessionCache>`. None of the methods return an error: a backend
/// that cannot serve a request reports a miss (`get`) or does nothing.
#[async_trait]
pub trait SessionCache: Send + Sync {
/// Return a cached route for `session_id`, or `None` if absent/expired.
async fn get(&self, session_id: &str) -> Option<CachedRoute>;
/// Store a routing decision for `session_id`.
async fn put(&self, session_id: &str, route: CachedRoute);
/// Remove a session entry explicitly.
async fn remove(&self, session_id: &str);
/// Evict all expired entries (no-op for backends with native TTL such as Redis).
async fn cleanup_expired(&self);
}
// ---------------------------------------------------------------------------
// In-memory backend
// ---------------------------------------------------------------------------
/// Session cache that lives entirely inside this process, stored in a
/// `RwLock<HashMap>` keyed by session ID.
///
/// This is the default backend and mirrors what `RouterService` used to do
/// itself. Nothing is shared between processes, so it only fits deployments
/// that run a single replica.
pub struct MemorySessionCache {
    /// Session ID → cached route.
    inner: RwLock<HashMap<String, CachedRoute>>,
    /// Age at which an entry stops being returned and becomes eligible for cleanup.
    ttl: Duration,
    /// Entry count at which inserting a new key evicts the oldest entry.
    max_entries: usize,
}

impl MemorySessionCache {
    /// Create an empty cache whose entries live for `ttl` and which starts
    /// evicting once it holds `max_entries` entries.
    pub fn new(ttl: Duration, max_entries: usize) -> Self {
        let inner = RwLock::new(HashMap::new());
        MemorySessionCache {
            inner,
            ttl,
            max_entries,
        }
    }
}
/// Returns milliseconds since the UNIX epoch (for TTL bookkeeping).
///
/// Yields `0` if the system clock reads earlier than the epoch, and saturates
/// at `u64::MAX` rather than silently truncating the `u128` millisecond count.
fn unix_now_ms() -> u64 {
    let since_epoch = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or(Duration::ZERO);
    u64::try_from(since_epoch.as_millis()).unwrap_or(u64::MAX)
}
#[async_trait]
impl SessionCache for MemorySessionCache {
    /// Look up `session_id` and return a clone of its entry while the entry is
    /// younger than the configured TTL. Expired entries are reported as a miss
    /// but are left in the map.
    async fn get(&self, session_id: &str) -> Option<CachedRoute> {
        let guard = self.inner.read().await;
        let found = guard.get(session_id)?;
        let age = Duration::from_millis(unix_now_ms().saturating_sub(found.cached_at_ms));
        (age < self.ttl).then(|| found.clone())
    }

    /// Insert or overwrite the entry for `session_id`.
    ///
    /// When the map is at capacity and `session_id` is a new key, the entry
    /// with the smallest `cached_at_ms` is dropped first.
    async fn put(&self, session_id: &str, route: CachedRoute) {
        let mut guard = self.inner.write().await;
        let is_new_key = !guard.contains_key(session_id);
        if is_new_key && guard.len() >= self.max_entries {
            let victim = guard
                .iter()
                .min_by_key(|(_, cached)| cached.cached_at_ms)
                .map(|(key, _)| key.to_owned());
            if let Some(victim) = victim {
                guard.remove(&victim);
            }
        }
        guard.insert(session_id.to_owned(), route);
    }

    /// Drop the entry for `session_id`, if there is one.
    async fn remove(&self, session_id: &str) {
        let mut guard = self.inner.write().await;
        guard.remove(session_id);
    }

    /// Drop every entry whose age has reached the TTL and log how many were
    /// removed (nothing is logged when no entry was removed).
    async fn cleanup_expired(&self) {
        let mut guard = self.inner.write().await;
        let len_before = guard.len();
        let ttl_ms = self.ttl.as_millis() as u64;
        let now_ms = unix_now_ms();
        guard.retain(|_, cached| now_ms.saturating_sub(cached.cached_at_ms) < ttl_ms);
        let remaining = guard.len();
        let removed = len_before - remaining;
        if removed == 0 {
            return;
        }
        info!(
            removed = removed,
            remaining = remaining,
            "cleaned up expired session cache entries"
        );
    }
}
// ---------------------------------------------------------------------------
// Redis backend
// ---------------------------------------------------------------------------
/// Session cache whose state lives in Redis, so several replicas can share it.
///
/// Entries are written with `SET … EX`, which lets Redis expire them on its
/// own — this backend therefore needs no background cleanup task.
pub struct RedisSessionCache {
    /// Multiplexed Redis connection shared by all cache operations.
    conn: Arc<RwLock<MultiplexedConnection>>,
    /// Expiry, in seconds, applied to every key written by this cache.
    ttl_secs: u64,
}

impl RedisSessionCache {
    /// Open a client for `url`, establish a multiplexed connection and wrap it
    /// in a cache whose keys expire after `ttl_secs` seconds.
    ///
    /// # Errors
    ///
    /// Fails if `url` is rejected by the Redis client or the connection cannot
    /// be established.
    pub async fn new(
        url: &str,
        ttl_secs: u64,
    ) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
        let connection = redis::Client::open(url)?
            .get_multiplexed_tokio_connection()
            .await?;
        let cache = RedisSessionCache {
            conn: Arc::new(RwLock::new(connection)),
            ttl_secs,
        };
        Ok(cache)
    }
}
#[async_trait]
impl SessionCache for RedisSessionCache {
async fn get(&self, session_id: &str) -> Option<CachedRoute> {
let mut conn = self.conn.write().await;
let raw: Option<String> = conn.get(session_id).await.ok()?;
let entry: CachedRoute = serde_json::from_str(&raw?).ok()?;
Some(entry)
}
async fn put(&self, session_id: &str, route: CachedRoute) {
let Ok(json) = serde_json::to_string(&route) else {
return;
};
let mut conn = self.conn.write().await;
let _: redis::RedisResult<()> = conn.set_ex(session_id, json, self.ttl_secs).await;
}
async fn remove(&self, session_id: &str) {
let mut conn = self.conn.write().await;
let _: redis::RedisResult<()> = conn.del(session_id).await;
}
/// Redis handles expiry natively — this is a no-op.
async fn cleanup_expired(&self) {}
}
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a route for `model` with an explicit creation timestamp.
    fn make_route_at(model: &str, cached_at_ms: u64) -> CachedRoute {
        CachedRoute {
            model_name: model.to_string(),
            route_name: None,
            cached_at_ms,
        }
    }

    /// Build a route for `model` stamped with the current time.
    fn make_route(model: &str) -> CachedRoute {
        make_route_at(model, unix_now_ms())
    }

    #[tokio::test]
    async fn memory_cache_miss_returns_none() {
        let cache = MemorySessionCache::new(Duration::from_secs(600), 100);
        assert!(cache.get("unknown").await.is_none());
    }

    #[tokio::test]
    async fn memory_cache_hit_returns_entry() {
        let cache = MemorySessionCache::new(Duration::from_secs(600), 100);
        cache.put("s1", make_route("gpt-4o")).await;
        let hit = cache.get("s1").await.unwrap();
        assert_eq!(hit.model_name, "gpt-4o");
    }

    #[tokio::test]
    async fn memory_cache_expired_returns_none() {
        // A zero TTL makes every entry expired as soon as it is stored.
        let cache = MemorySessionCache::new(Duration::ZERO, 100);
        cache.put("s1", make_route("gpt-4o")).await;
        assert!(cache.get("s1").await.is_none());
    }

    #[tokio::test]
    async fn memory_cache_cleanup_removes_expired() {
        let cache = MemorySessionCache::new(Duration::ZERO, 100);
        cache.put("s1", make_route("gpt-4o")).await;
        cache.put("s2", make_route("claude")).await;
        cache.cleanup_expired().await;
        assert!(cache.inner.read().await.is_empty());
    }

    #[tokio::test]
    async fn memory_cache_evicts_oldest_when_full() {
        let cache = MemorySessionCache::new(Duration::from_secs(600), 2);

        // Explicit, strictly increasing timestamps make the eviction order
        // deterministic without sleeping on the wall clock.
        let now = unix_now_ms();
        cache
            .put("s1", make_route_at("model-a", now.saturating_sub(2)))
            .await;
        cache
            .put("s2", make_route_at("model-b", now.saturating_sub(1)))
            .await;
        cache.put("s3", make_route_at("model-c", now)).await;

        let inner = cache.inner.read().await;
        assert_eq!(inner.len(), 2);
        assert!(!inner.contains_key("s1"), "s1 should have been evicted");
        assert!(inner.contains_key("s2"));
        assert!(inner.contains_key("s3"));
    }

    #[tokio::test]
    async fn memory_cache_overwrite_at_capacity_does_not_evict() {
        let cache = MemorySessionCache::new(Duration::from_secs(600), 2);

        let now = unix_now_ms();
        cache
            .put("s1", make_route_at("model-a", now.saturating_sub(1)))
            .await;
        cache.put("s2", make_route_at("model-b", now)).await;
        // Re-putting an existing key must replace it in place, not evict.
        cache.put("s1", make_route_at("model-a-updated", now)).await;

        let inner = cache.inner.read().await;
        assert_eq!(inner.len(), 2);
        assert_eq!(
            inner.get("s1").map(|route| route.model_name.as_str()),
            Some("model-a-updated")
        );
        assert!(inner.contains_key("s2"));
    }

    #[tokio::test]
    async fn memory_cache_remove_deletes_entry() {
        let cache = MemorySessionCache::new(Duration::from_secs(600), 100);
        cache.put("s1", make_route("gpt-4o")).await;
        cache.remove("s1").await;
        assert!(cache.get("s1").await.is_none());
    }

    #[tokio::test]
    async fn cached_route_serializes_round_trip() {
        let original = CachedRoute {
            model_name: "claude-3".to_string(),
            route_name: Some("code".to_string()),
            cached_at_ms: 1_700_000_000_000,
        };
        let json = serde_json::to_string(&original).unwrap();
        let decoded: CachedRoute = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded.model_name, original.model_name);
        assert_eq!(decoded.route_name, original.route_name);
        assert_eq!(decoded.cached_at_ms, original.cached_at_ms);
    }
}

View file

@ -7,12 +7,27 @@ use crate::api::open_ai::{
ChatCompletionTool, FunctionDefinition, FunctionParameter, FunctionParameters, ParameterType,
};
/// Which backend stores session-pinned routing decisions.
///
/// Variants are (de)serialized in lowercase, i.e. `memory` and `redis`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum SessionCacheType {
/// In-process cache; state is local to a single process.
Memory,
/// Redis-backed cache; state can be shared across replicas.
Redis,
}
/// Configuration of the session cache backend (`routing.session_cache`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionCacheConfig {
/// Backend selector; appears as the `type` key in the config file.
#[serde(rename = "type")]
pub cache_type: SessionCacheType,
/// Redis connection URL (e.g. `redis://localhost:6379`). Optional here, but
/// initialization fails for the Redis backend when it is missing.
pub url: Option<String>,
}
/// Routing section of the configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Routing {
// NOTE(review): presumably the provider used for routing decisions — confirm against the consumer.
pub llm_provider: Option<String>,
// NOTE(review): presumably the name of the routing model — confirm against the consumer.
pub model: Option<String>,
/// Lifetime of a session-pinned routing entry, in seconds. A default is
/// applied by the application when this is omitted.
pub session_ttl_seconds: Option<u64>,
/// Maximum number of session-pinned routing entries. A default and an upper
/// cap are applied by the application.
pub session_max_entries: Option<usize>,
/// Session cache backend configuration; the in-memory backend is used when omitted.
pub session_cache: Option<SessionCacheConfig>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]