mirror of
https://github.com/katanemo/plano.git
synced 2026-04-25 00:36:34 +02:00
address PR review feedback on session cache
This commit is contained in:
parent
90810078da
commit
e9e6e1765a
7 changed files with 1179 additions and 760 deletions
1684
crates/Cargo.lock
generated
1684
crates/Cargo.lock
generated
File diff suppressed because it is too large
Load diff
|
|
@ -26,6 +26,7 @@ opentelemetry-stdout = "0.31"
|
|||
opentelemetry_sdk = { version = "0.31", features = ["rt-tokio"] }
|
||||
pretty_assertions = "1.4.1"
|
||||
rand = "0.9.2"
|
||||
lru = "0.12"
|
||||
redis = { version = "0.27", features = ["tokio-comp"] }
|
||||
reqwest = { version = "0.12.15", features = ["stream"] }
|
||||
serde = { version = "1.0.219", features = ["derive"] }
|
||||
|
|
|
|||
|
|
@ -8,8 +8,7 @@ use brightstaff::handlers::routing_service::routing_decision;
|
|||
use brightstaff::router::llm::RouterService;
|
||||
use brightstaff::router::model_metrics::ModelMetricsService;
|
||||
use brightstaff::router::orchestrator::OrchestratorService;
|
||||
use brightstaff::session_cache::memory::MemorySessionCache;
|
||||
use brightstaff::session_cache::SessionCache;
|
||||
use brightstaff::session_cache::init_session_cache;
|
||||
use brightstaff::state::memory::MemoryConversationalStorage;
|
||||
use brightstaff::state::postgresql::PostgreSQLConversationStorage;
|
||||
use brightstaff::state::StateStorage;
|
||||
|
|
@ -309,18 +308,6 @@ async fn init_app_state(
|
|||
session_cache,
|
||||
));
|
||||
|
||||
// Spawn background task to clean up expired session cache entries every 5 minutes
|
||||
{
|
||||
let router_service = Arc::clone(&router_service);
|
||||
tokio::spawn(async move {
|
||||
let mut interval = tokio::time::interval(std::time::Duration::from_secs(300));
|
||||
loop {
|
||||
interval.tick().await;
|
||||
router_service.cleanup_expired_sessions().await;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
let orchestrator_model_name: String = overrides
|
||||
.agent_orchestration_model
|
||||
.as_deref()
|
||||
|
|
@ -363,54 +350,6 @@ async fn init_app_state(
|
|||
})
|
||||
}
|
||||
|
||||
/// Initialize the session cache backend based on config.
|
||||
/// Defaults to the in-memory backend when no `session_cache` is configured.
|
||||
async fn init_session_cache(
|
||||
config: &Configuration,
|
||||
) -> Result<Arc<dyn SessionCache>, Box<dyn std::error::Error + Send + Sync>> {
|
||||
use common::configuration::SessionCacheType;
|
||||
use std::time::Duration;
|
||||
|
||||
let session_ttl_seconds = config.routing.as_ref().and_then(|r| r.session_ttl_seconds);
|
||||
let session_max_entries = config.routing.as_ref().and_then(|r| r.session_max_entries);
|
||||
|
||||
const DEFAULT_SESSION_TTL_SECONDS: u64 = 600;
|
||||
const DEFAULT_SESSION_MAX_ENTRIES: usize = 10_000;
|
||||
const MAX_SESSION_MAX_ENTRIES: usize = 10_000;
|
||||
|
||||
let ttl = Duration::from_secs(session_ttl_seconds.unwrap_or(DEFAULT_SESSION_TTL_SECONDS));
|
||||
let max_entries = session_max_entries
|
||||
.unwrap_or(DEFAULT_SESSION_MAX_ENTRIES)
|
||||
.min(MAX_SESSION_MAX_ENTRIES);
|
||||
|
||||
let cache_config = config
|
||||
.routing
|
||||
.as_ref()
|
||||
.and_then(|r| r.session_cache.as_ref());
|
||||
|
||||
let cache_type = cache_config
|
||||
.map(|c| &c.cache_type)
|
||||
.unwrap_or(&SessionCacheType::Memory);
|
||||
|
||||
match cache_type {
|
||||
SessionCacheType::Memory => {
|
||||
info!(storage_type = "memory", "initialized session cache");
|
||||
Ok(Arc::new(MemorySessionCache::new(ttl, max_entries)))
|
||||
}
|
||||
SessionCacheType::Redis => {
|
||||
use brightstaff::session_cache::redis::RedisSessionCache;
|
||||
let url = cache_config
|
||||
.and_then(|c| c.url.as_ref())
|
||||
.ok_or("session_cache.url is required when type is redis")?;
|
||||
info!(storage_type = "redis", url = %url, "initializing session cache");
|
||||
let cache = RedisSessionCache::new(url)
|
||||
.await
|
||||
.map_err(|e| format!("failed to connect to Redis session cache: {e}"))?;
|
||||
Ok(Arc::new(cache))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Initialize the conversation state storage backend (if configured).
|
||||
async fn init_state_storage(
|
||||
config: &Configuration,
|
||||
|
|
|
|||
|
|
@ -123,11 +123,6 @@ impl RouterService {
|
|||
.await;
|
||||
}
|
||||
|
||||
/// Remove all expired entries from the session cache.
|
||||
pub async fn cleanup_expired_sessions(&self) {
|
||||
self.session_cache.cleanup_expired().await;
|
||||
}
|
||||
|
||||
/// Log a routing decision, used to surface affinity hits in structured logs.
|
||||
pub fn log_affinity_hit(session_id: &str, model_name: &str, route_name: &Option<String>) {
|
||||
info!(
|
||||
|
|
@ -259,8 +254,7 @@ mod tests {
|
|||
use crate::session_cache::memory::MemorySessionCache;
|
||||
|
||||
fn make_router_service(ttl_seconds: u64, max_entries: usize) -> RouterService {
|
||||
let ttl = Duration::from_secs(ttl_seconds);
|
||||
let session_cache = Arc::new(MemorySessionCache::new(ttl, max_entries));
|
||||
let session_cache = Arc::new(MemorySessionCache::new(max_entries));
|
||||
RouterService::new(
|
||||
None,
|
||||
None,
|
||||
|
|
@ -302,16 +296,14 @@ mod tests {
|
|||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_cleanup_removes_expired() {
|
||||
async fn test_expired_entries_not_returned() {
|
||||
let svc = make_router_service(0, 100);
|
||||
svc.cache_route("s1".to_string(), "gpt-4o".to_string(), None)
|
||||
.await;
|
||||
svc.cache_route("s2".to_string(), "claude".to_string(), None)
|
||||
.await;
|
||||
|
||||
svc.cleanup_expired_sessions().await;
|
||||
|
||||
// After cleanup, both expired entries should be gone
|
||||
// Entries with TTL=0 should be expired immediately
|
||||
assert!(svc.get_cached_route("s1").await.is_none());
|
||||
assert!(svc.get_cached_route("s2").await.is_none());
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,73 +1,82 @@
|
|||
use std::{
|
||||
collections::HashMap,
|
||||
num::NonZeroUsize,
|
||||
sync::Arc,
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use tokio::sync::RwLock;
|
||||
use lru::LruCache;
|
||||
use tokio::sync::Mutex;
|
||||
use tracing::info;
|
||||
|
||||
use super::{CachedRoute, SessionCache};
|
||||
|
||||
type CacheStore = Mutex<LruCache<String, (CachedRoute, Instant, Duration)>>;
|
||||
|
||||
pub struct MemorySessionCache {
|
||||
store: Arc<RwLock<HashMap<String, (CachedRoute, Instant)>>>,
|
||||
ttl: Duration,
|
||||
max_entries: usize,
|
||||
store: Arc<CacheStore>,
|
||||
}
|
||||
|
||||
impl MemorySessionCache {
|
||||
pub fn new(ttl: Duration, max_entries: usize) -> Self {
|
||||
Self {
|
||||
store: Arc::new(RwLock::new(HashMap::new())),
|
||||
ttl,
|
||||
max_entries,
|
||||
pub fn new(max_entries: usize) -> Self {
|
||||
let capacity = NonZeroUsize::new(max_entries)
|
||||
.unwrap_or_else(|| NonZeroUsize::new(10_000).expect("10_000 is non-zero"));
|
||||
let store = Arc::new(Mutex::new(LruCache::new(capacity)));
|
||||
|
||||
// Spawn a background task to evict TTL-expired entries every 5 minutes.
|
||||
let store_clone = Arc::clone(&store);
|
||||
tokio::spawn(async move {
|
||||
let mut interval = tokio::time::interval(Duration::from_secs(300));
|
||||
loop {
|
||||
interval.tick().await;
|
||||
Self::evict_expired(&store_clone).await;
|
||||
}
|
||||
});
|
||||
|
||||
Self { store }
|
||||
}
|
||||
|
||||
async fn evict_expired(store: &CacheStore) {
|
||||
let mut cache = store.lock().await;
|
||||
let expired: Vec<String> = cache
|
||||
.iter()
|
||||
.filter(|(_, (_, inserted_at, ttl))| inserted_at.elapsed() >= *ttl)
|
||||
.map(|(k, _)| k.clone())
|
||||
.collect();
|
||||
let removed = expired.len();
|
||||
for key in &expired {
|
||||
cache.pop(key.as_str());
|
||||
}
|
||||
if removed > 0 {
|
||||
info!(
|
||||
removed = removed,
|
||||
remaining = cache.len(),
|
||||
"cleaned up expired session cache entries"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl SessionCache for MemorySessionCache {
|
||||
async fn get(&self, session_id: &str) -> Option<CachedRoute> {
|
||||
let store = self.store.read().await;
|
||||
if let Some((route, inserted_at)) = store.get(session_id) {
|
||||
if inserted_at.elapsed() < self.ttl {
|
||||
async fn get(&self, key: &str) -> Option<CachedRoute> {
|
||||
let mut cache = self.store.lock().await;
|
||||
if let Some((route, inserted_at, ttl)) = cache.get(key) {
|
||||
if inserted_at.elapsed() < *ttl {
|
||||
return Some(route.clone());
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
async fn put(&self, session_id: &str, route: CachedRoute, _ttl: Duration) {
|
||||
let mut store = self.store.write().await;
|
||||
if store.len() >= self.max_entries && !store.contains_key(session_id) {
|
||||
if let Some(oldest_key) = store
|
||||
.iter()
|
||||
.min_by_key(|(_, (_, inserted_at))| *inserted_at)
|
||||
.map(|(k, _)| k.clone())
|
||||
{
|
||||
store.remove(&oldest_key);
|
||||
}
|
||||
}
|
||||
store.insert(session_id.to_string(), (route, Instant::now()));
|
||||
async fn put(&self, key: &str, route: CachedRoute, ttl: Duration) {
|
||||
self.store
|
||||
.lock()
|
||||
.await
|
||||
.put(key.to_string(), (route, Instant::now(), ttl));
|
||||
}
|
||||
|
||||
async fn remove(&self, session_id: &str) {
|
||||
self.store.write().await.remove(session_id);
|
||||
}
|
||||
|
||||
async fn cleanup_expired(&self) {
|
||||
let ttl = self.ttl;
|
||||
let mut store = self.store.write().await;
|
||||
let before = store.len();
|
||||
store.retain(|_, (_, inserted_at)| inserted_at.elapsed() < ttl);
|
||||
let removed = before - store.len();
|
||||
if removed > 0 {
|
||||
info!(
|
||||
removed = removed,
|
||||
remaining = store.len(),
|
||||
"cleaned up expired session cache entries"
|
||||
);
|
||||
}
|
||||
async fn remove(&self, key: &str) {
|
||||
self.store.lock().await.pop(key);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,5 +1,9 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use common::configuration::Configuration;
|
||||
use std::time::Duration;
|
||||
use tracing::{debug, info};
|
||||
|
||||
pub mod memory;
|
||||
pub mod redis;
|
||||
|
|
@ -12,15 +16,55 @@ pub struct CachedRoute {
|
|||
|
||||
#[async_trait]
|
||||
pub trait SessionCache: Send + Sync {
|
||||
/// Look up a cached routing decision by session ID.
|
||||
async fn get(&self, session_id: &str) -> Option<CachedRoute>;
|
||||
/// Look up a cached routing decision by key.
|
||||
async fn get(&self, key: &str) -> Option<CachedRoute>;
|
||||
|
||||
/// Store a routing decision in the session cache with the given TTL.
|
||||
async fn put(&self, session_id: &str, route: CachedRoute, ttl: Duration);
|
||||
async fn put(&self, key: &str, route: CachedRoute, ttl: Duration);
|
||||
|
||||
/// Remove a cached routing decision by session ID.
|
||||
async fn remove(&self, session_id: &str);
|
||||
|
||||
/// Remove all expired entries. No-op for backends that handle expiry natively (e.g. Redis).
|
||||
async fn cleanup_expired(&self);
|
||||
/// Remove a cached routing decision by key.
|
||||
async fn remove(&self, key: &str);
|
||||
}
|
||||
|
||||
/// Initialize the session cache backend from config.
|
||||
/// Defaults to the in-memory backend when no `session_cache` block is configured.
|
||||
pub async fn init_session_cache(
|
||||
config: &Configuration,
|
||||
) -> Result<Arc<dyn SessionCache>, Box<dyn std::error::Error + Send + Sync>> {
|
||||
use common::configuration::SessionCacheType;
|
||||
|
||||
let session_max_entries = config.routing.as_ref().and_then(|r| r.session_max_entries);
|
||||
|
||||
const DEFAULT_SESSION_MAX_ENTRIES: usize = 10_000;
|
||||
const MAX_SESSION_MAX_ENTRIES: usize = 10_000;
|
||||
|
||||
let max_entries = session_max_entries
|
||||
.unwrap_or(DEFAULT_SESSION_MAX_ENTRIES)
|
||||
.min(MAX_SESSION_MAX_ENTRIES);
|
||||
|
||||
let cache_config = config
|
||||
.routing
|
||||
.as_ref()
|
||||
.and_then(|r| r.session_cache.as_ref());
|
||||
|
||||
let cache_type = cache_config
|
||||
.map(|c| &c.cache_type)
|
||||
.unwrap_or(&SessionCacheType::Memory);
|
||||
|
||||
match cache_type {
|
||||
SessionCacheType::Memory => {
|
||||
info!(storage_type = "memory", "initialized session cache");
|
||||
Ok(Arc::new(memory::MemorySessionCache::new(max_entries)))
|
||||
}
|
||||
SessionCacheType::Redis => {
|
||||
let url = cache_config
|
||||
.and_then(|c| c.url.as_ref())
|
||||
.ok_or("session_cache.url is required when type is redis")?;
|
||||
debug!(storage_type = "redis", url = %url, "initializing session cache");
|
||||
let cache = redis::RedisSessionCache::new(url)
|
||||
.await
|
||||
.map_err(|e| format!("failed to connect to Redis session cache: {e}"))?;
|
||||
Ok(Arc::new(cache))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,6 +6,8 @@ use redis::AsyncCommands;
|
|||
|
||||
use super::{CachedRoute, SessionCache};
|
||||
|
||||
const KEY_PREFIX: &str = "plano:affinity:";
|
||||
|
||||
pub struct RedisSessionCache {
|
||||
conn: MultiplexedConnection,
|
||||
}
|
||||
|
|
@ -16,31 +18,31 @@ impl RedisSessionCache {
|
|||
let conn = client.get_multiplexed_async_connection().await?;
|
||||
Ok(Self { conn })
|
||||
}
|
||||
|
||||
fn make_key(key: &str) -> String {
|
||||
format!("{KEY_PREFIX}{key}")
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl SessionCache for RedisSessionCache {
|
||||
async fn get(&self, session_id: &str) -> Option<CachedRoute> {
|
||||
async fn get(&self, key: &str) -> Option<CachedRoute> {
|
||||
let mut conn = self.conn.clone();
|
||||
let value: Option<String> = conn.get(session_id).await.ok()?;
|
||||
let value: Option<String> = conn.get(Self::make_key(key)).await.ok()?;
|
||||
value.and_then(|v| serde_json::from_str(&v).ok())
|
||||
}
|
||||
|
||||
async fn put(&self, session_id: &str, route: CachedRoute, ttl: Duration) {
|
||||
async fn put(&self, key: &str, route: CachedRoute, ttl: Duration) {
|
||||
let mut conn = self.conn.clone();
|
||||
let Ok(json) = serde_json::to_string(&route) else {
|
||||
return;
|
||||
};
|
||||
let ttl_secs = ttl.as_secs().max(1);
|
||||
let _: Result<(), _> = conn.set_ex(session_id, json, ttl_secs).await;
|
||||
let _: Result<(), _> = conn.set_ex(Self::make_key(key), json, ttl_secs).await;
|
||||
}
|
||||
|
||||
async fn remove(&self, session_id: &str) {
|
||||
async fn remove(&self, key: &str) {
|
||||
let mut conn = self.conn.clone();
|
||||
let _: Result<(), _> = conn.del(session_id).await;
|
||||
}
|
||||
|
||||
async fn cleanup_expired(&self) {
|
||||
// Redis handles TTL expiry natively via EX — nothing to do here.
|
||||
let _: Result<(), _> = conn.del(Self::make_key(key)).await;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue