add pluggable session cache with Redis backend

This commit is contained in:
Spherrrical 2026-04-09 16:32:31 -07:00
parent 8dedf0bec1
commit 50670f843d
10 changed files with 353 additions and 75 deletions

View file

@ -0,0 +1,73 @@
use std::{
collections::HashMap,
sync::Arc,
time::{Duration, Instant},
};
use async_trait::async_trait;
use tokio::sync::RwLock;
use tracing::info;
use super::{CachedRoute, SessionCache};
/// In-process [`SessionCache`] backend that keeps routing decisions in a
/// `HashMap` guarded by an async read/write lock.
///
/// Every entry is stored together with the instant it was inserted; reads
/// compare that instant against the cache-wide `ttl`.
pub struct MemorySessionCache {
    /// Lifetime applied to every entry, measured from its insertion instant.
    ttl: Duration,
    /// Size threshold checked by `put` before a previously unseen key is added.
    max_entries: usize,
    /// Session ID -> (cached route, insertion instant).
    store: Arc<RwLock<HashMap<String, (CachedRoute, Instant)>>>,
}
impl MemorySessionCache {
    /// Creates an empty in-memory session cache.
    ///
    /// * `ttl` - how long an entry remains readable after it was inserted.
    /// * `max_entries` - size at which inserting a new key first evicts the
    ///   oldest entry.
    pub fn new(ttl: Duration, max_entries: usize) -> Self {
        let store = Arc::new(RwLock::new(HashMap::new()));
        Self {
            store,
            ttl,
            max_entries,
        }
    }
}
#[async_trait]
impl SessionCache for MemorySessionCache {
    /// Returns the cached route for `session_id`, or `None` when the session
    /// is unknown or its entry is older than the cache-wide TTL.
    ///
    /// Expired entries are only skipped here (a read lock is held); they are
    /// physically dropped by `cleanup_expired`.
    async fn get(&self, session_id: &str) -> Option<CachedRoute> {
        let store = self.store.read().await;
        store
            .get(session_id)
            .filter(|(_, inserted_at)| inserted_at.elapsed() < self.ttl)
            .map(|(route, _)| route.clone())
    }

    /// Stores `route` for `session_id`, stamping it with the current instant.
    ///
    /// When the cache is full and `session_id` is not already present, the
    /// entry with the oldest insertion instant is evicted first. Overwriting
    /// an existing key never evicts and resets that key's insertion instant.
    ///
    /// A cache built with `max_entries == 0` stores nothing: without this
    /// guard the eviction scan would find nothing to remove and the insert
    /// would still happen, leaving one entry in a zero-capacity cache.
    ///
    /// NOTE(review): the per-call `_ttl` is ignored; expiry always uses the
    /// TTL given to `MemorySessionCache::new`. Confirm callers never rely on
    /// a per-entry TTL with this backend.
    async fn put(&self, session_id: &str, route: CachedRoute, _ttl: Duration) {
        if self.max_entries == 0 {
            return;
        }

        let mut store = self.store.write().await;
        let is_new_key = !store.contains_key(session_id);
        if is_new_key && store.len() >= self.max_entries {
            // Linear scan for the entry that was inserted longest ago.
            let oldest_key = store
                .iter()
                .min_by_key(|(_, (_, inserted_at))| *inserted_at)
                .map(|(key, _)| key.clone());
            if let Some(key) = oldest_key {
                store.remove(&key);
            }
        }
        store.insert(session_id.to_string(), (route, Instant::now()));
    }

    /// Drops the entry for `session_id`, if any.
    async fn remove(&self, session_id: &str) {
        let mut store = self.store.write().await;
        store.remove(session_id);
    }

    /// Removes every entry whose age has reached the cache-wide TTL and logs
    /// how many were removed (only when at least one was).
    async fn cleanup_expired(&self) {
        let ttl = self.ttl;
        let mut store = self.store.write().await;
        let before = store.len();
        store.retain(|_, (_, inserted_at)| inserted_at.elapsed() < ttl);
        let removed = before - store.len();
        if removed > 0 {
            info!(
                removed = removed,
                remaining = store.len(),
                "cleaned up expired session cache entries"
            );
        }
    }
}

View file

@ -0,0 +1,26 @@
use async_trait::async_trait;
use std::time::Duration;
pub mod memory;
pub mod redis;
/// A routing decision remembered for a session.
///
/// The type derives serde `Serialize`/`Deserialize`; the Redis backend stores
/// it as a JSON string, while the in-memory backend keeps clones of the value.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct CachedRoute {
/// Name of the model chosen for the session (presumably the routing target;
/// confirm against the caller that builds this value).
pub model_name: String,
/// Optional name of the route associated with the decision.
pub route_name: Option<String>,
}
/// Storage backend for per-session routing decisions.
///
/// Implementations must be `Send + Sync`. All operations are infallible from
/// the caller's point of view: lookups yield `Option`, and the mutating calls
/// return nothing.
#[async_trait]
pub trait SessionCache: Send + Sync {
/// Look up a cached routing decision by session ID.
///
/// Returns `None` when the backend has no usable entry for `session_id`.
async fn get(&self, session_id: &str) -> Option<CachedRoute>;
/// Store a routing decision in the session cache with the given TTL.
///
/// NOTE(review): the in-memory backend names this argument `_ttl` and applies
/// the TTL it was constructed with instead; only the Redis backend uses the
/// per-call value.
async fn put(&self, session_id: &str, route: CachedRoute, ttl: Duration);
/// Remove a cached routing decision by session ID.
async fn remove(&self, session_id: &str);
/// Remove all expired entries. No-op for backends that handle expiry natively (e.g. Redis).
async fn cleanup_expired(&self);
}

View file

@ -0,0 +1,46 @@
use std::time::Duration;
use async_trait::async_trait;
use redis::aio::MultiplexedConnection;
use redis::AsyncCommands;
use super::{CachedRoute, SessionCache};
/// [`SessionCache`] backend that stores routing decisions in Redis as JSON
/// strings keyed by the session ID.
pub struct RedisSessionCache {
// Multiplexed async connection; each operation works on its own clone of
// this handle because the command methods are called on a `mut` binding.
conn: MultiplexedConnection,
}
impl RedisSessionCache {
    /// Connects to the Redis server at `url` and returns a cache backed by a
    /// multiplexed async connection.
    ///
    /// # Errors
    ///
    /// Returns the `redis::RedisError` produced when the client cannot be
    /// opened for `url` or when the connection cannot be established.
    pub async fn new(url: &str) -> Result<Self, redis::RedisError> {
        let conn = redis::Client::open(url)?
            .get_multiplexed_async_connection()
            .await?;
        Ok(Self { conn })
    }
}
#[async_trait]
impl SessionCache for RedisSessionCache {
async fn get(&self, session_id: &str) -> Option<CachedRoute> {
let mut conn = self.conn.clone();
let value: Option<String> = conn.get(session_id).await.ok()?;
value.and_then(|v| serde_json::from_str(&v).ok())
}
async fn put(&self, session_id: &str, route: CachedRoute, ttl: Duration) {
let mut conn = self.conn.clone();
let Ok(json) = serde_json::to_string(&route) else {
return;
};
let ttl_secs = ttl.as_secs().max(1);
let _: Result<(), _> = conn.set_ex(session_id, json, ttl_secs).await;
}
async fn remove(&self, session_id: &str) {
let mut conn = self.conn.clone();
let _: Result<(), _> = conn.del(session_id).await;
}
async fn cleanup_expired(&self) {
// Redis handles TTL expiry natively via EX — nothing to do here.
}
}