mirror of
https://github.com/katanemo/plano.git
synced 2026-06-11 15:05:14 +02:00
feat: add SessionCache trait with memory and redis backends for model affinity
Agent-Logs-Url: https://github.com/katanemo/plano/sessions/6a75240b-4578-409d-b8c7-eff47dba8a03 Co-authored-by: adilhafeez <13196462+adilhafeez@users.noreply.github.com>
This commit is contained in:
parent
b822e27957
commit
66f8230dd5
8 changed files with 1168 additions and 658 deletions
|
|
@ -441,6 +441,20 @@ properties:
|
|||
minimum: 1
|
||||
maximum: 10000
|
||||
description: Maximum number of session-pinned routing cache entries. Default 10000.
|
||||
session_cache:
|
||||
type: object
|
||||
description: Session cache backend configuration. Defaults to in-memory if omitted.
|
||||
properties:
|
||||
type:
|
||||
type: string
|
||||
enum: [memory, redis]
|
||||
description: "Cache backend type. 'memory' (default) is in-process; 'redis' is shared across replicas."
|
||||
url:
|
||||
type: string
|
||||
description: "Redis connection URL (e.g. redis://localhost:6379). Required when type is 'redis'."
|
||||
required:
|
||||
- type
|
||||
additionalProperties: false
|
||||
additionalProperties: false
|
||||
state_storage:
|
||||
type: object
|
||||
|
|
|
|||
1308
crates/Cargo.lock
generated
1308
crates/Cargo.lock
generated
File diff suppressed because it is too large
Load diff
|
|
@ -26,6 +26,7 @@ opentelemetry-stdout = "0.31"
|
|||
opentelemetry_sdk = { version = "0.31", features = ["rt-tokio"] }
|
||||
pretty_assertions = "1.4.1"
|
||||
rand = "0.9.2"
|
||||
redis = { version = "0.27.6", features = ["tokio-comp"] }
|
||||
reqwest = { version = "0.12.15", features = ["stream"] }
|
||||
serde = { version = "1.0.219", features = ["derive"] }
|
||||
serde_json = "1.0.140"
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ use brightstaff::handlers::routing_service::routing_decision;
|
|||
use brightstaff::router::llm::RouterService;
|
||||
use brightstaff::router::model_metrics::ModelMetricsService;
|
||||
use brightstaff::router::orchestrator::OrchestratorService;
|
||||
use brightstaff::router::session_cache::{MemorySessionCache, RedisSessionCache, SessionCache};
|
||||
use brightstaff::state::memory::MemoryConversationalStorage;
|
||||
use brightstaff::state::postgresql::PostgreSQLConversationStorage;
|
||||
use brightstaff::state::StateStorage;
|
||||
|
|
@ -176,6 +177,10 @@ async fn init_app_state(
|
|||
|
||||
let session_ttl_seconds = config.routing.as_ref().and_then(|r| r.session_ttl_seconds);
|
||||
let session_max_entries = config.routing.as_ref().and_then(|r| r.session_max_entries);
|
||||
let session_cache_config = config
|
||||
.routing
|
||||
.as_ref()
|
||||
.and_then(|r| r.session_cache.clone());
|
||||
|
||||
// Validate that top-level routing_preferences requires v0.4.0+.
|
||||
let config_version = parse_semver(&config.version);
|
||||
|
|
@ -297,17 +302,24 @@ async fn init_app_state(
|
|||
}
|
||||
}
|
||||
|
||||
let router_service = Arc::new(RouterService::new(
|
||||
let session_cache = init_session_cache(
|
||||
session_cache_config,
|
||||
session_ttl_seconds,
|
||||
session_max_entries,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let router_service = Arc::new(RouterService::with_cache(
|
||||
config.routing_preferences.clone(),
|
||||
metrics_service,
|
||||
format!("{llm_provider_url}{CHAT_COMPLETIONS_PATH}"),
|
||||
routing_model_name,
|
||||
routing_llm_provider,
|
||||
session_ttl_seconds,
|
||||
session_max_entries,
|
||||
session_cache,
|
||||
));
|
||||
|
||||
// Spawn background task to clean up expired session cache entries every 5 minutes
|
||||
// Spawn background task to clean up expired session cache entries every 5 minutes.
|
||||
// For the Redis backend this is a no-op because Redis handles TTL natively.
|
||||
{
|
||||
let router_service = Arc::clone(&router_service);
|
||||
tokio::spawn(async move {
|
||||
|
|
@ -361,6 +373,59 @@ async fn init_app_state(
|
|||
})
|
||||
}
|
||||
|
||||
/// Initialize the session cache backend from the routing config.
|
||||
///
|
||||
/// - No config → memory backend with defaults
|
||||
/// - `type: memory` → memory backend (explicit)
|
||||
/// - `type: redis` → Redis backend; `url` is required
|
||||
async fn init_session_cache(
|
||||
cache_config: Option<common::configuration::SessionCacheConfig>,
|
||||
session_ttl_seconds: Option<u64>,
|
||||
session_max_entries: Option<usize>,
|
||||
) -> Result<Arc<dyn SessionCache>, Box<dyn std::error::Error + Send + Sync>> {
|
||||
use brightstaff::router::llm::{
|
||||
DEFAULT_SESSION_MAX_ENTRIES, DEFAULT_SESSION_TTL_SECONDS, MAX_SESSION_MAX_ENTRIES,
|
||||
};
|
||||
use common::configuration::SessionCacheType;
|
||||
|
||||
let ttl_secs = session_ttl_seconds.unwrap_or(DEFAULT_SESSION_TTL_SECONDS);
|
||||
let max_entries = session_max_entries
|
||||
.unwrap_or(DEFAULT_SESSION_MAX_ENTRIES)
|
||||
.min(MAX_SESSION_MAX_ENTRIES);
|
||||
|
||||
let cache_type = cache_config
|
||||
.as_ref()
|
||||
.map(|c| &c.cache_type)
|
||||
.unwrap_or(&SessionCacheType::Memory);
|
||||
|
||||
let cache: Arc<dyn SessionCache> = match cache_type {
|
||||
SessionCacheType::Memory => {
|
||||
info!(cache_type = "memory", "initialized session cache");
|
||||
Arc::new(MemorySessionCache::new(
|
||||
std::time::Duration::from_secs(ttl_secs),
|
||||
max_entries,
|
||||
))
|
||||
}
|
||||
SessionCacheType::Redis => {
|
||||
let url = cache_config
|
||||
.as_ref()
|
||||
.and_then(|c| c.url.as_deref())
|
||||
.ok_or("session_cache.url is required for redis session cache")?;
|
||||
|
||||
debug!(url = %url, "redis session cache connection");
|
||||
info!(cache_type = "redis", url = %url, "initializing session cache");
|
||||
|
||||
Arc::new(
|
||||
RedisSessionCache::new(url, ttl_secs)
|
||||
.await
|
||||
.map_err(|e| format!("failed to initialize Redis session cache: {e}"))?,
|
||||
)
|
||||
}
|
||||
};
|
||||
|
||||
Ok(cache)
|
||||
}
|
||||
|
||||
/// Initialize the conversation state storage backend (if configured).
|
||||
async fn init_state_storage(
|
||||
config: &Configuration,
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
use std::{collections::HashMap, sync::Arc, time::Duration, time::Instant};
|
||||
use std::{collections::HashMap, sync::Arc, time::Duration};
|
||||
|
||||
use common::{
|
||||
configuration::TopLevelRoutingPreference,
|
||||
|
|
@ -6,10 +6,10 @@ use common::{
|
|||
};
|
||||
|
||||
use super::router_model::{ModelUsagePreference, RoutingPreference};
|
||||
use super::session_cache::{CachedRoute, MemorySessionCache, SessionCache};
|
||||
use hermesllm::apis::openai::Message;
|
||||
use hyper::header;
|
||||
use thiserror::Error;
|
||||
use tokio::sync::RwLock;
|
||||
use tracing::{debug, info};
|
||||
|
||||
use super::http::{self, post_and_extract_content};
|
||||
|
|
@ -18,16 +18,9 @@ use super::router_model::RouterModel;
|
|||
|
||||
use crate::router::router_model_v1;
|
||||
|
||||
const DEFAULT_SESSION_TTL_SECONDS: u64 = 600;
|
||||
const DEFAULT_SESSION_MAX_ENTRIES: usize = 10_000;
|
||||
const MAX_SESSION_MAX_ENTRIES: usize = 10_000;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct CachedRoute {
|
||||
pub model_name: String,
|
||||
pub route_name: Option<String>,
|
||||
pub cached_at: Instant,
|
||||
}
|
||||
pub const DEFAULT_SESSION_TTL_SECONDS: u64 = 600;
|
||||
pub const DEFAULT_SESSION_MAX_ENTRIES: usize = 10_000;
|
||||
pub const MAX_SESSION_MAX_ENTRIES: usize = 10_000;
|
||||
|
||||
pub struct RouterService {
|
||||
router_url: String,
|
||||
|
|
@ -36,9 +29,7 @@ pub struct RouterService {
|
|||
routing_provider_name: String,
|
||||
top_level_preferences: HashMap<String, TopLevelRoutingPreference>,
|
||||
metrics_service: Option<Arc<ModelMetricsService>>,
|
||||
session_cache: RwLock<HashMap<String, CachedRoute>>,
|
||||
session_ttl: Duration,
|
||||
session_max_entries: usize,
|
||||
session_cache: Arc<dyn SessionCache>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
|
|
@ -61,6 +52,34 @@ impl RouterService {
|
|||
routing_provider_name: String,
|
||||
session_ttl_seconds: Option<u64>,
|
||||
session_max_entries: Option<usize>,
|
||||
) -> Self {
|
||||
let session_ttl =
|
||||
Duration::from_secs(session_ttl_seconds.unwrap_or(DEFAULT_SESSION_TTL_SECONDS));
|
||||
let session_max_entries = session_max_entries
|
||||
.unwrap_or(DEFAULT_SESSION_MAX_ENTRIES)
|
||||
.min(MAX_SESSION_MAX_ENTRIES);
|
||||
|
||||
let session_cache: Arc<dyn SessionCache> =
|
||||
Arc::new(MemorySessionCache::new(session_ttl, session_max_entries));
|
||||
|
||||
RouterService::with_cache(
|
||||
top_level_prefs,
|
||||
metrics_service,
|
||||
router_url,
|
||||
routing_model_name,
|
||||
routing_provider_name,
|
||||
session_cache,
|
||||
)
|
||||
}
|
||||
|
||||
/// Create a `RouterService` with an explicit `SessionCache` backend.
|
||||
pub fn with_cache(
|
||||
top_level_prefs: Option<Vec<TopLevelRoutingPreference>>,
|
||||
metrics_service: Option<Arc<ModelMetricsService>>,
|
||||
router_url: String,
|
||||
routing_model_name: String,
|
||||
routing_provider_name: String,
|
||||
session_cache: Arc<dyn SessionCache>,
|
||||
) -> Self {
|
||||
let top_level_preferences: HashMap<String, TopLevelRoutingPreference> = top_level_prefs
|
||||
.map_or_else(HashMap::new, |prefs| {
|
||||
|
|
@ -91,12 +110,6 @@ impl RouterService {
|
|||
router_model_v1::MAX_TOKEN_LEN,
|
||||
));
|
||||
|
||||
let session_ttl =
|
||||
Duration::from_secs(session_ttl_seconds.unwrap_or(DEFAULT_SESSION_TTL_SECONDS));
|
||||
let session_max_entries = session_max_entries
|
||||
.unwrap_or(DEFAULT_SESSION_MAX_ENTRIES)
|
||||
.min(MAX_SESSION_MAX_ENTRIES);
|
||||
|
||||
RouterService {
|
||||
router_url,
|
||||
client: reqwest::Client::new(),
|
||||
|
|
@ -104,65 +117,37 @@ impl RouterService {
|
|||
routing_provider_name,
|
||||
top_level_preferences,
|
||||
metrics_service,
|
||||
session_cache: RwLock::new(HashMap::new()),
|
||||
session_ttl,
|
||||
session_max_entries,
|
||||
session_cache,
|
||||
}
|
||||
}
|
||||
|
||||
/// Look up a cached routing decision by session ID.
|
||||
/// Returns None if not found or expired.
|
||||
pub async fn get_cached_route(&self, session_id: &str) -> Option<CachedRoute> {
|
||||
let cache = self.session_cache.read().await;
|
||||
if let Some(entry) = cache.get(session_id) {
|
||||
if entry.cached_at.elapsed() < self.session_ttl {
|
||||
return Some(entry.clone());
|
||||
}
|
||||
}
|
||||
None
|
||||
self.session_cache.get(session_id).await
|
||||
}
|
||||
|
||||
/// Store a routing decision in the session cache.
|
||||
/// If at max capacity, evicts the oldest entry.
|
||||
pub async fn cache_route(
|
||||
&self,
|
||||
session_id: String,
|
||||
model_name: String,
|
||||
route_name: Option<String>,
|
||||
) {
|
||||
let mut cache = self.session_cache.write().await;
|
||||
if cache.len() >= self.session_max_entries && !cache.contains_key(&session_id) {
|
||||
if let Some(oldest_key) = cache
|
||||
.iter()
|
||||
.min_by_key(|(_, v)| v.cached_at)
|
||||
.map(|(k, _)| k.clone())
|
||||
{
|
||||
cache.remove(&oldest_key);
|
||||
}
|
||||
}
|
||||
cache.insert(
|
||||
session_id,
|
||||
CachedRoute {
|
||||
model_name,
|
||||
route_name,
|
||||
cached_at: Instant::now(),
|
||||
},
|
||||
);
|
||||
let route = CachedRoute {
|
||||
model_name,
|
||||
route_name,
|
||||
cached_at_ms: std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or(Duration::ZERO)
|
||||
.as_millis() as u64,
|
||||
};
|
||||
self.session_cache.put(&session_id, route).await;
|
||||
}
|
||||
|
||||
/// Remove all expired entries from the session cache.
|
||||
pub async fn cleanup_expired_sessions(&self) {
|
||||
let mut cache = self.session_cache.write().await;
|
||||
let before = cache.len();
|
||||
cache.retain(|_, entry| entry.cached_at.elapsed() < self.session_ttl);
|
||||
let removed = before - cache.len();
|
||||
if removed > 0 {
|
||||
info!(
|
||||
removed = removed,
|
||||
remaining = cache.len(),
|
||||
"cleaned up expired session cache entries"
|
||||
);
|
||||
}
|
||||
self.session_cache.cleanup_expired().await;
|
||||
}
|
||||
|
||||
pub async fn determine_route(
|
||||
|
|
@ -335,8 +320,8 @@ mod tests {
|
|||
|
||||
svc.cleanup_expired_sessions().await;
|
||||
|
||||
let cache = svc.session_cache.read().await;
|
||||
assert!(cache.is_empty());
|
||||
assert!(svc.get_cached_route("s1").await.is_none());
|
||||
assert!(svc.get_cached_route("s2").await.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
|
@ -351,11 +336,12 @@ mod tests {
|
|||
svc.cache_route("s3".to_string(), "model-c".to_string(), None)
|
||||
.await;
|
||||
|
||||
let cache = svc.session_cache.read().await;
|
||||
assert_eq!(cache.len(), 2);
|
||||
assert!(!cache.contains_key("s1"));
|
||||
assert!(cache.contains_key("s2"));
|
||||
assert!(cache.contains_key("s3"));
|
||||
assert!(
|
||||
svc.get_cached_route("s1").await.is_none(),
|
||||
"s1 should have been evicted"
|
||||
);
|
||||
assert!(svc.get_cached_route("s2").await.is_some());
|
||||
assert!(svc.get_cached_route("s3").await.is_some());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
|
@ -373,8 +359,11 @@ mod tests {
|
|||
)
|
||||
.await;
|
||||
|
||||
let cache = svc.session_cache.read().await;
|
||||
assert_eq!(cache.len(), 2);
|
||||
assert_eq!(cache.get("s1").unwrap().model_name, "model-a-updated");
|
||||
assert!(svc.get_cached_route("s1").await.is_some());
|
||||
assert!(svc.get_cached_route("s2").await.is_some());
|
||||
assert_eq!(
|
||||
svc.get_cached_route("s1").await.unwrap().model_name,
|
||||
"model-a-updated"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,3 +6,4 @@ pub mod orchestrator_model;
|
|||
pub mod orchestrator_model_v1;
|
||||
pub mod router_model;
|
||||
pub mod router_model_v1;
|
||||
pub mod session_cache;
|
||||
|
|
|
|||
283
crates/brightstaff/src/router/session_cache.rs
Normal file
283
crates/brightstaff/src/router/session_cache.rs
Normal file
|
|
@ -0,0 +1,283 @@
|
|||
use async_trait::async_trait;
|
||||
use redis::aio::MultiplexedConnection;
|
||||
use redis::AsyncCommands;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::RwLock;
|
||||
use tracing::info;
|
||||
|
||||
/// A cached routing decision stored by session ID.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct CachedRoute {
|
||||
pub model_name: String,
|
||||
pub route_name: Option<String>,
|
||||
/// Milliseconds since the UNIX epoch when this entry was created.
|
||||
/// Used only by the memory backend for TTL checks and eviction ordering;
|
||||
/// Redis uses native key expiry and ignores this field.
|
||||
pub cached_at_ms: u64,
|
||||
}
|
||||
|
||||
/// Abstracts the session-affinity cache so it can be backed by in-memory
|
||||
/// storage (default, single-replica) or Redis (multi-replica).
|
||||
#[async_trait]
|
||||
pub trait SessionCache: Send + Sync {
|
||||
/// Return a cached route for `session_id`, or `None` if absent/expired.
|
||||
async fn get(&self, session_id: &str) -> Option<CachedRoute>;
|
||||
|
||||
/// Store a routing decision for `session_id`.
|
||||
async fn put(&self, session_id: &str, route: CachedRoute);
|
||||
|
||||
/// Remove a session entry explicitly.
|
||||
async fn remove(&self, session_id: &str);
|
||||
|
||||
/// Evict all expired entries (no-op for backends with native TTL such as Redis).
|
||||
async fn cleanup_expired(&self);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// In-memory backend
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// In-process session cache backed by a `RwLock<HashMap>`.
|
||||
///
|
||||
/// This is the default backend and replicates the previous behaviour of
|
||||
/// `RouterService`. All state is local to the process, so it is only suitable
|
||||
/// for single-replica deployments.
|
||||
pub struct MemorySessionCache {
|
||||
inner: RwLock<HashMap<String, CachedRoute>>,
|
||||
ttl: Duration,
|
||||
max_entries: usize,
|
||||
}
|
||||
|
||||
impl MemorySessionCache {
|
||||
pub fn new(ttl: Duration, max_entries: usize) -> Self {
|
||||
Self {
|
||||
inner: RwLock::new(HashMap::new()),
|
||||
ttl,
|
||||
max_entries,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Current wall-clock time as milliseconds since the UNIX epoch.
///
/// Used for TTL bookkeeping. If the system clock reports a time before the
/// epoch, the result is `0` instead of an error.
fn unix_now_ms() -> u64 {
    let since_epoch = match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
        Ok(elapsed) => elapsed,
        Err(_) => Duration::ZERO,
    };
    since_epoch.as_millis() as u64
}
|
||||
|
||||
#[async_trait]
|
||||
impl SessionCache for MemorySessionCache {
|
||||
async fn get(&self, session_id: &str) -> Option<CachedRoute> {
|
||||
let cache = self.inner.read().await;
|
||||
if let Some(entry) = cache.get(session_id) {
|
||||
let age_ms = unix_now_ms().saturating_sub(entry.cached_at_ms);
|
||||
if Duration::from_millis(age_ms) < self.ttl {
|
||||
return Some(entry.clone());
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
async fn put(&self, session_id: &str, route: CachedRoute) {
|
||||
let mut cache = self.inner.write().await;
|
||||
if cache.len() >= self.max_entries && !cache.contains_key(session_id) {
|
||||
// Evict the oldest entry by `cached_at_ms`.
|
||||
if let Some(oldest_key) = cache
|
||||
.iter()
|
||||
.min_by_key(|(_, v)| v.cached_at_ms)
|
||||
.map(|(k, _)| k.clone())
|
||||
{
|
||||
cache.remove(&oldest_key);
|
||||
}
|
||||
}
|
||||
cache.insert(session_id.to_string(), route);
|
||||
}
|
||||
|
||||
async fn remove(&self, session_id: &str) {
|
||||
self.inner.write().await.remove(session_id);
|
||||
}
|
||||
|
||||
async fn cleanup_expired(&self) {
|
||||
let mut cache = self.inner.write().await;
|
||||
let before = cache.len();
|
||||
let ttl_ms = self.ttl.as_millis() as u64;
|
||||
let now = unix_now_ms();
|
||||
cache.retain(|_, entry| now.saturating_sub(entry.cached_at_ms) < ttl_ms);
|
||||
let removed = before - cache.len();
|
||||
if removed > 0 {
|
||||
info!(
|
||||
removed = removed,
|
||||
remaining = cache.len(),
|
||||
"cleaned up expired session cache entries"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Redis backend
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Shared-state session cache backed by Redis.
|
||||
///
|
||||
/// Uses `SET … EX` for automatic TTL-based expiry so that expired sessions are
|
||||
/// cleaned up by Redis itself — no background cleanup task is needed.
|
||||
pub struct RedisSessionCache {
|
||||
conn: Arc<RwLock<MultiplexedConnection>>,
|
||||
ttl_secs: u64,
|
||||
}
|
||||
|
||||
impl RedisSessionCache {
|
||||
/// Connect to Redis at `url` and return a new cache instance.
|
||||
pub async fn new(
|
||||
url: &str,
|
||||
ttl_secs: u64,
|
||||
) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
|
||||
let client = redis::Client::open(url)?;
|
||||
let conn = client.get_multiplexed_tokio_connection().await?;
|
||||
Ok(Self {
|
||||
conn: Arc::new(RwLock::new(conn)),
|
||||
ttl_secs,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl SessionCache for RedisSessionCache {
|
||||
async fn get(&self, session_id: &str) -> Option<CachedRoute> {
|
||||
let mut conn = self.conn.write().await;
|
||||
let raw: Option<String> = conn.get(session_id).await.ok()?;
|
||||
let entry: CachedRoute = serde_json::from_str(&raw?).ok()?;
|
||||
Some(entry)
|
||||
}
|
||||
|
||||
async fn put(&self, session_id: &str, route: CachedRoute) {
|
||||
let Ok(json) = serde_json::to_string(&route) else {
|
||||
return;
|
||||
};
|
||||
let mut conn = self.conn.write().await;
|
||||
let _: redis::RedisResult<()> = conn.set_ex(session_id, json, self.ttl_secs).await;
|
||||
}
|
||||
|
||||
async fn remove(&self, session_id: &str) {
|
||||
let mut conn = self.conn.write().await;
|
||||
let _: redis::RedisResult<()> = conn.del(session_id).await;
|
||||
}
|
||||
|
||||
/// Redis handles expiry natively — this is a no-op.
|
||||
async fn cleanup_expired(&self) {}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Route for `model` stamped with an explicit creation time.
    fn route_cached_at(model: &str, cached_at_ms: u64) -> CachedRoute {
        CachedRoute {
            model_name: model.to_string(),
            route_name: None,
            cached_at_ms,
        }
    }

    /// Route for `model` stamped with the current time.
    fn make_route(model: &str) -> CachedRoute {
        route_cached_at(model, unix_now_ms())
    }

    #[tokio::test]
    async fn memory_cache_miss_returns_none() {
        let cache = MemorySessionCache::new(Duration::from_secs(600), 100);
        assert!(cache.get("unknown").await.is_none());
    }

    #[tokio::test]
    async fn memory_cache_hit_returns_entry() {
        let cache = MemorySessionCache::new(Duration::from_secs(600), 100);
        cache.put("s1", make_route("gpt-4o")).await;
        let hit = cache.get("s1").await.unwrap();
        assert_eq!(hit.model_name, "gpt-4o");
    }

    #[tokio::test]
    async fn memory_cache_expired_returns_none() {
        // A zero TTL makes every entry expired the moment it is stored.
        let cache = MemorySessionCache::new(Duration::ZERO, 100);
        cache.put("s1", make_route("gpt-4o")).await;
        assert!(cache.get("s1").await.is_none());
    }

    #[tokio::test]
    async fn memory_cache_cleanup_removes_expired() {
        let cache = MemorySessionCache::new(Duration::ZERO, 100);
        cache.put("s1", make_route("gpt-4o")).await;
        cache.put("s2", make_route("claude")).await;
        cache.cleanup_expired().await;
        assert!(cache.inner.read().await.is_empty());
    }

    #[tokio::test]
    async fn memory_cache_evicts_oldest_when_full() {
        let cache = MemorySessionCache::new(Duration::from_secs(600), 2);
        // Eviction orders entries by `cached_at_ms` only, so explicit timestamps
        // make the ordering deterministic without sleeping on the wall clock.
        cache.put("s1", route_cached_at("model-a", 1_000)).await;
        cache.put("s2", route_cached_at("model-b", 2_000)).await;
        cache.put("s3", route_cached_at("model-c", 3_000)).await;

        let inner = cache.inner.read().await;
        assert_eq!(inner.len(), 2);
        assert!(!inner.contains_key("s1"), "s1 should have been evicted");
        assert!(inner.contains_key("s2"));
        assert!(inner.contains_key("s3"));
    }

    #[tokio::test]
    async fn memory_cache_overwrite_at_capacity_does_not_evict() {
        let cache = MemorySessionCache::new(Duration::from_secs(600), 2);
        cache.put("s1", route_cached_at("model-a", 1_000)).await;
        cache.put("s2", route_cached_at("model-b", 2_000)).await;
        // Updating an existing session must not push another session out.
        cache
            .put("s1", route_cached_at("model-a-updated", 3_000))
            .await;

        let inner = cache.inner.read().await;
        assert_eq!(inner.len(), 2);
        assert_eq!(inner.get("s1").unwrap().model_name, "model-a-updated");
        assert!(inner.contains_key("s2"));
    }

    #[tokio::test]
    async fn memory_cache_remove_deletes_entry() {
        let cache = MemorySessionCache::new(Duration::from_secs(600), 100);
        cache.put("s1", make_route("gpt-4o")).await;
        cache.remove("s1").await;
        assert!(cache.get("s1").await.is_none());
    }

    // Pure serde round-trip: no async runtime needed.
    #[test]
    fn cached_route_serializes_round_trip() {
        let original = CachedRoute {
            model_name: "claude-3".to_string(),
            route_name: Some("code".to_string()),
            cached_at_ms: 1_700_000_000_000,
        };
        let json = serde_json::to_string(&original).unwrap();
        let decoded: CachedRoute = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded.model_name, original.model_name);
        assert_eq!(decoded.route_name, original.route_name);
        assert_eq!(decoded.cached_at_ms, original.cached_at_ms);
    }
}
|
||||
|
|
@ -7,12 +7,27 @@ use crate::api::open_ai::{
|
|||
ChatCompletionTool, FunctionDefinition, FunctionParameter, FunctionParameters, ParameterType,
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum SessionCacheType {
|
||||
Memory,
|
||||
Redis,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SessionCacheConfig {
|
||||
#[serde(rename = "type")]
|
||||
pub cache_type: SessionCacheType,
|
||||
pub url: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Routing {
|
||||
pub llm_provider: Option<String>,
|
||||
pub model: Option<String>,
|
||||
pub session_ttl_seconds: Option<u64>,
|
||||
pub session_max_entries: Option<usize>,
|
||||
pub session_cache: Option<SessionCacheConfig>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue