mirror of
https://github.com/katanemo/plano.git
synced 2026-05-11 00:32:42 +02:00
Redis-backed session cache for cross-replica model affinity (#879)
Some checks failed
CI / pre-commit (push) Has been cancelled
CI / plano-tools-tests (push) Has been cancelled
CI / native-smoke-test (push) Has been cancelled
CI / docker-build (push) Has been cancelled
CI / validate-config (push) Has been cancelled
Publish docker image (latest) / build-arm64 (push) Has been cancelled
Publish docker image (latest) / build-amd64 (push) Has been cancelled
Build and Deploy Documentation / build (push) Has been cancelled
CI / security-scan (push) Has been cancelled
CI / test-prompt-gateway (push) Has been cancelled
CI / test-model-alias-routing (push) Has been cancelled
CI / test-responses-api-with-state (push) Has been cancelled
CI / e2e-plano-tests (3.10) (push) Has been cancelled
CI / e2e-plano-tests (3.11) (push) Has been cancelled
CI / e2e-plano-tests (3.12) (push) Has been cancelled
CI / e2e-plano-tests (3.13) (push) Has been cancelled
CI / e2e-plano-tests (3.14) (push) Has been cancelled
CI / e2e-demo-preference (push) Has been cancelled
CI / e2e-demo-currency (push) Has been cancelled
Publish docker image (latest) / create-manifest (push) Has been cancelled
Some checks failed
CI / pre-commit (push) Has been cancelled
CI / plano-tools-tests (push) Has been cancelled
CI / native-smoke-test (push) Has been cancelled
CI / docker-build (push) Has been cancelled
CI / validate-config (push) Has been cancelled
Publish docker image (latest) / build-arm64 (push) Has been cancelled
Publish docker image (latest) / build-amd64 (push) Has been cancelled
Build and Deploy Documentation / build (push) Has been cancelled
CI / security-scan (push) Has been cancelled
CI / test-prompt-gateway (push) Has been cancelled
CI / test-model-alias-routing (push) Has been cancelled
CI / test-responses-api-with-state (push) Has been cancelled
CI / e2e-plano-tests (3.10) (push) Has been cancelled
CI / e2e-plano-tests (3.11) (push) Has been cancelled
CI / e2e-plano-tests (3.12) (push) Has been cancelled
CI / e2e-plano-tests (3.13) (push) Has been cancelled
CI / e2e-plano-tests (3.14) (push) Has been cancelled
CI / e2e-demo-preference (push) Has been cancelled
CI / e2e-demo-currency (push) Has been cancelled
Publish docker image (latest) / create-manifest (push) Has been cancelled
* add pluggable session cache with Redis backend
* add Redis session affinity demos (Docker Compose and Kubernetes)
* address PR review feedback on session cache
* document Redis session cache backend for model affinity
* sync rendered config reference with session_cache addition
* add tenant-scoped Redis session cache keys and remove dead log_affinity_hit
- Add tenant_header to SessionCacheConfig; when set, cache keys are scoped
as plano:affinity:{tenant_id}:{session_id} for multi-tenant isolation
- Thread tenant_id through RouterService, routing_service, and llm handlers
- Use Cow<'_, str> in session_key to avoid allocation when no tenant is set
- Remove unused log_affinity_hit (logging was already inlined at call sites)
* remove session_affinity_redis and session_affinity_redis_k8s demos
This commit is contained in:
parent
128059e7c1
commit
980faef6be
15 changed files with 1538 additions and 729 deletions
|
|
@ -99,10 +99,16 @@ async fn llm_chat_inner(
|
|||
.get(MODEL_AFFINITY_HEADER)
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.map(|s| s.to_string());
|
||||
let tenant_id: Option<String> = state
|
||||
.router_service
|
||||
.tenant_header()
|
||||
.and_then(|hdr| request_headers.get(hdr))
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(|s| s.to_string());
|
||||
let pinned_model: Option<String> = if let Some(ref sid) = session_id {
|
||||
state
|
||||
.router_service
|
||||
.get_cached_route(sid)
|
||||
.get_cached_route(sid, tenant_id.as_deref())
|
||||
.await
|
||||
.map(|c| c.model_name)
|
||||
} else {
|
||||
|
|
@ -313,7 +319,7 @@ async fn llm_chat_inner(
|
|||
if let Some(ref sid) = session_id {
|
||||
state
|
||||
.router_service
|
||||
.cache_route(sid.clone(), model.clone(), route_name)
|
||||
.cache_route(sid.clone(), tenant_id.as_deref(), model.clone(), route_name)
|
||||
.await;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -76,6 +76,12 @@ pub async fn routing_decision(
|
|||
.and_then(|h| h.to_str().ok())
|
||||
.map(|s| s.to_string());
|
||||
|
||||
let tenant_id: Option<String> = router_service
|
||||
.tenant_header()
|
||||
.and_then(|hdr| request_headers.get(hdr))
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(|s| s.to_string());
|
||||
|
||||
let custom_attrs = collect_custom_trace_attributes(&request_headers, span_attributes.as_ref());
|
||||
|
||||
let request_span = info_span!(
|
||||
|
|
@ -94,11 +100,13 @@ pub async fn routing_decision(
|
|||
request_headers,
|
||||
custom_attrs,
|
||||
session_id,
|
||||
tenant_id,
|
||||
)
|
||||
.instrument(request_span)
|
||||
.await
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn routing_decision_inner(
|
||||
request: Request<hyper::body::Incoming>,
|
||||
router_service: Arc<RouterService>,
|
||||
|
|
@ -107,6 +115,7 @@ async fn routing_decision_inner(
|
|||
request_headers: hyper::HeaderMap,
|
||||
custom_attrs: std::collections::HashMap<String, String>,
|
||||
session_id: Option<String>,
|
||||
tenant_id: Option<String>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
set_service_name(operation_component::ROUTING);
|
||||
opentelemetry::trace::get_active_span(|span| {
|
||||
|
|
@ -126,7 +135,10 @@ async fn routing_decision_inner(
|
|||
|
||||
// Session pinning: check cache before doing any routing work
|
||||
if let Some(ref sid) = session_id {
|
||||
if let Some(cached) = router_service.get_cached_route(sid).await {
|
||||
if let Some(cached) = router_service
|
||||
.get_cached_route(sid, tenant_id.as_deref())
|
||||
.await
|
||||
{
|
||||
info!(
|
||||
session_id = %sid,
|
||||
model = %cached.model_name,
|
||||
|
|
@ -206,6 +218,7 @@ async fn routing_decision_inner(
|
|||
router_service
|
||||
.cache_route(
|
||||
sid.clone(),
|
||||
tenant_id.as_deref(),
|
||||
result.model_name.clone(),
|
||||
result.route_name.clone(),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
pub mod app_state;
|
||||
pub mod handlers;
|
||||
pub mod router;
|
||||
pub mod session_cache;
|
||||
pub mod signals;
|
||||
pub mod state;
|
||||
pub mod streaming;
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ use brightstaff::handlers::routing_service::routing_decision;
|
|||
use brightstaff::router::llm::RouterService;
|
||||
use brightstaff::router::model_metrics::ModelMetricsService;
|
||||
use brightstaff::router::orchestrator::OrchestratorService;
|
||||
use brightstaff::session_cache::init_session_cache;
|
||||
use brightstaff::state::memory::MemoryConversationalStorage;
|
||||
use brightstaff::state::postgresql::PostgreSQLConversationStorage;
|
||||
use brightstaff::state::StateStorage;
|
||||
|
|
@ -175,7 +176,7 @@ async fn init_app_state(
|
|||
.unwrap_or_else(|| DEFAULT_ROUTING_LLM_PROVIDER.to_string());
|
||||
|
||||
let session_ttl_seconds = config.routing.as_ref().and_then(|r| r.session_ttl_seconds);
|
||||
let session_max_entries = config.routing.as_ref().and_then(|r| r.session_max_entries);
|
||||
let session_cache = init_session_cache(config).await?;
|
||||
|
||||
// Validate that top-level routing_preferences requires v0.4.0+.
|
||||
let config_version = parse_semver(&config.version);
|
||||
|
|
@ -297,6 +298,12 @@ async fn init_app_state(
|
|||
}
|
||||
}
|
||||
|
||||
let session_tenant_header = config
|
||||
.routing
|
||||
.as_ref()
|
||||
.and_then(|r| r.session_cache.as_ref())
|
||||
.and_then(|c| c.tenant_header.clone());
|
||||
|
||||
let router_service = Arc::new(RouterService::new(
|
||||
config.routing_preferences.clone(),
|
||||
metrics_service,
|
||||
|
|
@ -304,21 +311,10 @@ async fn init_app_state(
|
|||
routing_model_name,
|
||||
routing_llm_provider,
|
||||
session_ttl_seconds,
|
||||
session_max_entries,
|
||||
session_cache,
|
||||
session_tenant_header,
|
||||
));
|
||||
|
||||
// Spawn background task to clean up expired session cache entries every 5 minutes
|
||||
{
|
||||
let router_service = Arc::clone(&router_service);
|
||||
tokio::spawn(async move {
|
||||
let mut interval = tokio::time::interval(std::time::Duration::from_secs(300));
|
||||
loop {
|
||||
interval.tick().await;
|
||||
router_service.cleanup_expired_sessions().await;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
let orchestrator_model_name: String = overrides
|
||||
.agent_orchestration_model
|
||||
.as_deref()
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
use std::{collections::HashMap, sync::Arc, time::Duration, time::Instant};
|
||||
use std::{borrow::Cow, collections::HashMap, sync::Arc, time::Duration};
|
||||
|
||||
use common::{
|
||||
configuration::TopLevelRoutingPreference,
|
||||
|
|
@ -9,7 +9,6 @@ use super::router_model::{ModelUsagePreference, RoutingPreference};
|
|||
use hermesllm::apis::openai::Message;
|
||||
use hyper::header;
|
||||
use thiserror::Error;
|
||||
use tokio::sync::RwLock;
|
||||
use tracing::{debug, info};
|
||||
|
||||
use super::http::{self, post_and_extract_content};
|
||||
|
|
@ -17,17 +16,11 @@ use super::model_metrics::ModelMetricsService;
|
|||
use super::router_model::RouterModel;
|
||||
|
||||
use crate::router::router_model_v1;
|
||||
use crate::session_cache::SessionCache;
|
||||
|
||||
pub use crate::session_cache::CachedRoute;
|
||||
|
||||
const DEFAULT_SESSION_TTL_SECONDS: u64 = 600;
|
||||
const DEFAULT_SESSION_MAX_ENTRIES: usize = 10_000;
|
||||
const MAX_SESSION_MAX_ENTRIES: usize = 10_000;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct CachedRoute {
|
||||
pub model_name: String,
|
||||
pub route_name: Option<String>,
|
||||
pub cached_at: Instant,
|
||||
}
|
||||
|
||||
pub struct RouterService {
|
||||
router_url: String,
|
||||
|
|
@ -36,9 +29,9 @@ pub struct RouterService {
|
|||
routing_provider_name: String,
|
||||
top_level_preferences: HashMap<String, TopLevelRoutingPreference>,
|
||||
metrics_service: Option<Arc<ModelMetricsService>>,
|
||||
session_cache: RwLock<HashMap<String, CachedRoute>>,
|
||||
session_cache: Arc<dyn SessionCache>,
|
||||
session_ttl: Duration,
|
||||
session_max_entries: usize,
|
||||
tenant_header: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
|
|
@ -53,6 +46,7 @@ pub enum RoutingError {
|
|||
pub type Result<T> = std::result::Result<T, RoutingError>;
|
||||
|
||||
impl RouterService {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn new(
|
||||
top_level_prefs: Option<Vec<TopLevelRoutingPreference>>,
|
||||
metrics_service: Option<Arc<ModelMetricsService>>,
|
||||
|
|
@ -60,7 +54,8 @@ impl RouterService {
|
|||
routing_model_name: String,
|
||||
routing_provider_name: String,
|
||||
session_ttl_seconds: Option<u64>,
|
||||
session_max_entries: Option<usize>,
|
||||
session_cache: Arc<dyn SessionCache>,
|
||||
tenant_header: Option<String>,
|
||||
) -> Self {
|
||||
let top_level_preferences: HashMap<String, TopLevelRoutingPreference> = top_level_prefs
|
||||
.map_or_else(HashMap::new, |prefs| {
|
||||
|
|
@ -93,9 +88,6 @@ impl RouterService {
|
|||
|
||||
let session_ttl =
|
||||
Duration::from_secs(session_ttl_seconds.unwrap_or(DEFAULT_SESSION_TTL_SECONDS));
|
||||
let session_max_entries = session_max_entries
|
||||
.unwrap_or(DEFAULT_SESSION_MAX_ENTRIES)
|
||||
.min(MAX_SESSION_MAX_ENTRIES);
|
||||
|
||||
RouterService {
|
||||
router_url,
|
||||
|
|
@ -104,65 +96,57 @@ impl RouterService {
|
|||
routing_provider_name,
|
||||
top_level_preferences,
|
||||
metrics_service,
|
||||
session_cache: RwLock::new(HashMap::new()),
|
||||
session_cache,
|
||||
session_ttl,
|
||||
session_max_entries,
|
||||
tenant_header,
|
||||
}
|
||||
}
|
||||
|
||||
/// Name of the HTTP header used to scope cache keys by tenant, if configured.
|
||||
#[must_use]
|
||||
pub fn tenant_header(&self) -> Option<&str> {
|
||||
self.tenant_header.as_deref()
|
||||
}
|
||||
|
||||
/// Build the cache key, optionally scoped by tenant: `{tenant_id}:{session_id}` or `{session_id}`.
|
||||
/// Returns a borrowed key when no tenant prefix is needed, avoiding an allocation.
|
||||
fn session_key<'a>(tenant_id: Option<&str>, session_id: &'a str) -> Cow<'a, str> {
|
||||
match tenant_id {
|
||||
Some(t) => Cow::Owned(format!("{t}:{session_id}")),
|
||||
None => Cow::Borrowed(session_id),
|
||||
}
|
||||
}
|
||||
|
||||
/// Look up a cached routing decision by session ID.
|
||||
/// Returns None if not found or expired.
|
||||
pub async fn get_cached_route(&self, session_id: &str) -> Option<CachedRoute> {
|
||||
let cache = self.session_cache.read().await;
|
||||
if let Some(entry) = cache.get(session_id) {
|
||||
if entry.cached_at.elapsed() < self.session_ttl {
|
||||
return Some(entry.clone());
|
||||
}
|
||||
}
|
||||
None
|
||||
pub async fn get_cached_route(
|
||||
&self,
|
||||
session_id: &str,
|
||||
tenant_id: Option<&str>,
|
||||
) -> Option<CachedRoute> {
|
||||
self.session_cache
|
||||
.get(&Self::session_key(tenant_id, session_id))
|
||||
.await
|
||||
}
|
||||
|
||||
/// Store a routing decision in the session cache.
|
||||
/// If at max capacity, evicts the oldest entry.
|
||||
pub async fn cache_route(
|
||||
&self,
|
||||
session_id: String,
|
||||
tenant_id: Option<&str>,
|
||||
model_name: String,
|
||||
route_name: Option<String>,
|
||||
) {
|
||||
let mut cache = self.session_cache.write().await;
|
||||
if cache.len() >= self.session_max_entries && !cache.contains_key(&session_id) {
|
||||
if let Some(oldest_key) = cache
|
||||
.iter()
|
||||
.min_by_key(|(_, v)| v.cached_at)
|
||||
.map(|(k, _)| k.clone())
|
||||
{
|
||||
cache.remove(&oldest_key);
|
||||
}
|
||||
}
|
||||
cache.insert(
|
||||
session_id,
|
||||
CachedRoute {
|
||||
model_name,
|
||||
route_name,
|
||||
cached_at: Instant::now(),
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
/// Remove all expired entries from the session cache.
|
||||
pub async fn cleanup_expired_sessions(&self) {
|
||||
let mut cache = self.session_cache.write().await;
|
||||
let before = cache.len();
|
||||
cache.retain(|_, entry| entry.cached_at.elapsed() < self.session_ttl);
|
||||
let removed = before - cache.len();
|
||||
if removed > 0 {
|
||||
info!(
|
||||
removed = removed,
|
||||
remaining = cache.len(),
|
||||
"cleaned up expired session cache entries"
|
||||
);
|
||||
}
|
||||
self.session_cache
|
||||
.put(
|
||||
&Self::session_key(tenant_id, &session_id),
|
||||
CachedRoute {
|
||||
model_name,
|
||||
route_name,
|
||||
},
|
||||
self.session_ttl,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
pub async fn determine_route(
|
||||
|
|
@ -283,8 +267,10 @@ impl RouterService {
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::session_cache::memory::MemorySessionCache;
|
||||
|
||||
fn make_router_service(ttl_seconds: u64, max_entries: usize) -> RouterService {
|
||||
let session_cache = Arc::new(MemorySessionCache::new(max_entries));
|
||||
RouterService::new(
|
||||
None,
|
||||
None,
|
||||
|
|
@ -292,14 +278,18 @@ mod tests {
|
|||
"Arch-Router".to_string(),
|
||||
"arch-router".to_string(),
|
||||
Some(ttl_seconds),
|
||||
Some(max_entries),
|
||||
session_cache,
|
||||
None,
|
||||
)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_cache_miss_returns_none() {
|
||||
let svc = make_router_service(600, 100);
|
||||
assert!(svc.get_cached_route("unknown-session").await.is_none());
|
||||
assert!(svc
|
||||
.get_cached_route("unknown-session", None)
|
||||
.await
|
||||
.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
|
@ -307,12 +297,13 @@ mod tests {
|
|||
let svc = make_router_service(600, 100);
|
||||
svc.cache_route(
|
||||
"s1".to_string(),
|
||||
None,
|
||||
"gpt-4o".to_string(),
|
||||
Some("code".to_string()),
|
||||
)
|
||||
.await;
|
||||
|
||||
let cached = svc.get_cached_route("s1").await.unwrap();
|
||||
let cached = svc.get_cached_route("s1", None).await.unwrap();
|
||||
assert_eq!(cached.model_name, "gpt-4o");
|
||||
assert_eq!(cached.route_name, Some("code".to_string()));
|
||||
}
|
||||
|
|
@ -320,61 +311,61 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn test_cache_expired_entry_returns_none() {
|
||||
let svc = make_router_service(0, 100);
|
||||
svc.cache_route("s1".to_string(), "gpt-4o".to_string(), None)
|
||||
svc.cache_route("s1".to_string(), None, "gpt-4o".to_string(), None)
|
||||
.await;
|
||||
assert!(svc.get_cached_route("s1").await.is_none());
|
||||
assert!(svc.get_cached_route("s1", None).await.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_cleanup_removes_expired() {
|
||||
async fn test_expired_entries_not_returned() {
|
||||
let svc = make_router_service(0, 100);
|
||||
svc.cache_route("s1".to_string(), "gpt-4o".to_string(), None)
|
||||
svc.cache_route("s1".to_string(), None, "gpt-4o".to_string(), None)
|
||||
.await;
|
||||
svc.cache_route("s2".to_string(), "claude".to_string(), None)
|
||||
svc.cache_route("s2".to_string(), None, "claude".to_string(), None)
|
||||
.await;
|
||||
|
||||
svc.cleanup_expired_sessions().await;
|
||||
|
||||
let cache = svc.session_cache.read().await;
|
||||
assert!(cache.is_empty());
|
||||
// Entries with TTL=0 should be expired immediately
|
||||
assert!(svc.get_cached_route("s1", None).await.is_none());
|
||||
assert!(svc.get_cached_route("s2", None).await.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_cache_evicts_oldest_when_full() {
|
||||
let svc = make_router_service(600, 2);
|
||||
svc.cache_route("s1".to_string(), "model-a".to_string(), None)
|
||||
svc.cache_route("s1".to_string(), None, "model-a".to_string(), None)
|
||||
.await;
|
||||
tokio::time::sleep(Duration::from_millis(10)).await;
|
||||
svc.cache_route("s2".to_string(), "model-b".to_string(), None)
|
||||
svc.cache_route("s2".to_string(), None, "model-b".to_string(), None)
|
||||
.await;
|
||||
|
||||
svc.cache_route("s3".to_string(), "model-c".to_string(), None)
|
||||
svc.cache_route("s3".to_string(), None, "model-c".to_string(), None)
|
||||
.await;
|
||||
|
||||
let cache = svc.session_cache.read().await;
|
||||
assert_eq!(cache.len(), 2);
|
||||
assert!(!cache.contains_key("s1"));
|
||||
assert!(cache.contains_key("s2"));
|
||||
assert!(cache.contains_key("s3"));
|
||||
// s1 should be evicted (oldest); s2 and s3 should remain
|
||||
assert!(svc.get_cached_route("s1", None).await.is_none());
|
||||
assert!(svc.get_cached_route("s2", None).await.is_some());
|
||||
assert!(svc.get_cached_route("s3", None).await.is_some());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_cache_update_existing_session_does_not_evict() {
|
||||
let svc = make_router_service(600, 2);
|
||||
svc.cache_route("s1".to_string(), "model-a".to_string(), None)
|
||||
svc.cache_route("s1".to_string(), None, "model-a".to_string(), None)
|
||||
.await;
|
||||
svc.cache_route("s2".to_string(), "model-b".to_string(), None)
|
||||
svc.cache_route("s2".to_string(), None, "model-b".to_string(), None)
|
||||
.await;
|
||||
|
||||
svc.cache_route(
|
||||
"s1".to_string(),
|
||||
None,
|
||||
"model-a-updated".to_string(),
|
||||
Some("route".to_string()),
|
||||
)
|
||||
.await;
|
||||
|
||||
let cache = svc.session_cache.read().await;
|
||||
assert_eq!(cache.len(), 2);
|
||||
assert_eq!(cache.get("s1").unwrap().model_name, "model-a-updated");
|
||||
// Both sessions should still be present
|
||||
let s1 = svc.get_cached_route("s1", None).await.unwrap();
|
||||
assert_eq!(s1.model_name, "model-a-updated");
|
||||
assert!(svc.get_cached_route("s2", None).await.is_some());
|
||||
}
|
||||
}
|
||||
|
|
|
|||
82
crates/brightstaff/src/session_cache/memory.rs
Normal file
82
crates/brightstaff/src/session_cache/memory.rs
Normal file
|
|
@ -0,0 +1,82 @@
|
|||
use std::{
|
||||
num::NonZeroUsize,
|
||||
sync::Arc,
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use lru::LruCache;
|
||||
use tokio::sync::Mutex;
|
||||
use tracing::info;
|
||||
|
||||
use super::{CachedRoute, SessionCache};
|
||||
|
||||
type CacheStore = Mutex<LruCache<String, (CachedRoute, Instant, Duration)>>;
|
||||
|
||||
pub struct MemorySessionCache {
|
||||
store: Arc<CacheStore>,
|
||||
}
|
||||
|
||||
impl MemorySessionCache {
|
||||
pub fn new(max_entries: usize) -> Self {
|
||||
let capacity = NonZeroUsize::new(max_entries)
|
||||
.unwrap_or_else(|| NonZeroUsize::new(10_000).expect("10_000 is non-zero"));
|
||||
let store = Arc::new(Mutex::new(LruCache::new(capacity)));
|
||||
|
||||
// Spawn a background task to evict TTL-expired entries every 5 minutes.
|
||||
let store_clone = Arc::clone(&store);
|
||||
tokio::spawn(async move {
|
||||
let mut interval = tokio::time::interval(Duration::from_secs(300));
|
||||
loop {
|
||||
interval.tick().await;
|
||||
Self::evict_expired(&store_clone).await;
|
||||
}
|
||||
});
|
||||
|
||||
Self { store }
|
||||
}
|
||||
|
||||
async fn evict_expired(store: &CacheStore) {
|
||||
let mut cache = store.lock().await;
|
||||
let expired: Vec<String> = cache
|
||||
.iter()
|
||||
.filter(|(_, (_, inserted_at, ttl))| inserted_at.elapsed() >= *ttl)
|
||||
.map(|(k, _)| k.clone())
|
||||
.collect();
|
||||
let removed = expired.len();
|
||||
for key in &expired {
|
||||
cache.pop(key.as_str());
|
||||
}
|
||||
if removed > 0 {
|
||||
info!(
|
||||
removed = removed,
|
||||
remaining = cache.len(),
|
||||
"cleaned up expired session cache entries"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl SessionCache for MemorySessionCache {
|
||||
async fn get(&self, key: &str) -> Option<CachedRoute> {
|
||||
let mut cache = self.store.lock().await;
|
||||
if let Some((route, inserted_at, ttl)) = cache.get(key) {
|
||||
if inserted_at.elapsed() < *ttl {
|
||||
return Some(route.clone());
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
async fn put(&self, key: &str, route: CachedRoute, ttl: Duration) {
|
||||
self.store
|
||||
.lock()
|
||||
.await
|
||||
.put(key.to_string(), (route, Instant::now(), ttl));
|
||||
}
|
||||
|
||||
async fn remove(&self, key: &str) {
|
||||
self.store.lock().await.pop(key);
|
||||
}
|
||||
}
|
||||
70
crates/brightstaff/src/session_cache/mod.rs
Normal file
70
crates/brightstaff/src/session_cache/mod.rs
Normal file
|
|
@ -0,0 +1,70 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use common::configuration::Configuration;
|
||||
use std::time::Duration;
|
||||
use tracing::{debug, info};
|
||||
|
||||
pub mod memory;
|
||||
pub mod redis;
|
||||
|
||||
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
|
||||
pub struct CachedRoute {
|
||||
pub model_name: String,
|
||||
pub route_name: Option<String>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait SessionCache: Send + Sync {
|
||||
/// Look up a cached routing decision by key.
|
||||
async fn get(&self, key: &str) -> Option<CachedRoute>;
|
||||
|
||||
/// Store a routing decision in the session cache with the given TTL.
|
||||
async fn put(&self, key: &str, route: CachedRoute, ttl: Duration);
|
||||
|
||||
/// Remove a cached routing decision by key.
|
||||
async fn remove(&self, key: &str);
|
||||
}
|
||||
|
||||
/// Initialize the session cache backend from config.
|
||||
/// Defaults to the in-memory backend when no `session_cache` block is configured.
|
||||
pub async fn init_session_cache(
|
||||
config: &Configuration,
|
||||
) -> Result<Arc<dyn SessionCache>, Box<dyn std::error::Error + Send + Sync>> {
|
||||
use common::configuration::SessionCacheType;
|
||||
|
||||
let session_max_entries = config.routing.as_ref().and_then(|r| r.session_max_entries);
|
||||
|
||||
const DEFAULT_SESSION_MAX_ENTRIES: usize = 10_000;
|
||||
const MAX_SESSION_MAX_ENTRIES: usize = 10_000;
|
||||
|
||||
let max_entries = session_max_entries
|
||||
.unwrap_or(DEFAULT_SESSION_MAX_ENTRIES)
|
||||
.min(MAX_SESSION_MAX_ENTRIES);
|
||||
|
||||
let cache_config = config
|
||||
.routing
|
||||
.as_ref()
|
||||
.and_then(|r| r.session_cache.as_ref());
|
||||
|
||||
let cache_type = cache_config
|
||||
.map(|c| &c.cache_type)
|
||||
.unwrap_or(&SessionCacheType::Memory);
|
||||
|
||||
match cache_type {
|
||||
SessionCacheType::Memory => {
|
||||
info!(storage_type = "memory", "initialized session cache");
|
||||
Ok(Arc::new(memory::MemorySessionCache::new(max_entries)))
|
||||
}
|
||||
SessionCacheType::Redis => {
|
||||
let url = cache_config
|
||||
.and_then(|c| c.url.as_ref())
|
||||
.ok_or("session_cache.url is required when type is redis")?;
|
||||
debug!(storage_type = "redis", url = %url, "initializing session cache");
|
||||
let cache = redis::RedisSessionCache::new(url)
|
||||
.await
|
||||
.map_err(|e| format!("failed to connect to Redis session cache: {e}"))?;
|
||||
Ok(Arc::new(cache))
|
||||
}
|
||||
}
|
||||
}
|
||||
48
crates/brightstaff/src/session_cache/redis.rs
Normal file
48
crates/brightstaff/src/session_cache/redis.rs
Normal file
|
|
@ -0,0 +1,48 @@
|
|||
use std::time::Duration;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use redis::aio::MultiplexedConnection;
|
||||
use redis::AsyncCommands;
|
||||
|
||||
use super::{CachedRoute, SessionCache};
|
||||
|
||||
const KEY_PREFIX: &str = "plano:affinity:";
|
||||
|
||||
pub struct RedisSessionCache {
|
||||
conn: MultiplexedConnection,
|
||||
}
|
||||
|
||||
impl RedisSessionCache {
|
||||
pub async fn new(url: &str) -> Result<Self, redis::RedisError> {
|
||||
let client = redis::Client::open(url)?;
|
||||
let conn = client.get_multiplexed_async_connection().await?;
|
||||
Ok(Self { conn })
|
||||
}
|
||||
|
||||
fn make_key(key: &str) -> String {
|
||||
format!("{KEY_PREFIX}{key}")
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl SessionCache for RedisSessionCache {
|
||||
async fn get(&self, key: &str) -> Option<CachedRoute> {
|
||||
let mut conn = self.conn.clone();
|
||||
let value: Option<String> = conn.get(Self::make_key(key)).await.ok()?;
|
||||
value.and_then(|v| serde_json::from_str(&v).ok())
|
||||
}
|
||||
|
||||
async fn put(&self, key: &str, route: CachedRoute, ttl: Duration) {
|
||||
let mut conn = self.conn.clone();
|
||||
let Ok(json) = serde_json::to_string(&route) else {
|
||||
return;
|
||||
};
|
||||
let ttl_secs = ttl.as_secs().max(1);
|
||||
let _: Result<(), _> = conn.set_ex(Self::make_key(key), json, ttl_secs).await;
|
||||
}
|
||||
|
||||
async fn remove(&self, key: &str) {
|
||||
let mut conn = self.conn.clone();
|
||||
let _: Result<(), _> = conn.del(Self::make_key(key)).await;
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue