Model affinity for consistent model selection in agentic loops (#827)
Some checks are pending
CI / pre-commit (push) Waiting to run
CI / plano-tools-tests (push) Waiting to run
CI / native-smoke-test (push) Waiting to run
CI / docker-build (push) Waiting to run
CI / validate-config (push) Waiting to run
CI / security-scan (push) Blocked by required conditions
CI / test-prompt-gateway (push) Blocked by required conditions
CI / test-model-alias-routing (push) Blocked by required conditions
CI / test-responses-api-with-state (push) Blocked by required conditions
CI / e2e-plano-tests (3.10) (push) Blocked by required conditions
CI / e2e-plano-tests (3.11) (push) Blocked by required conditions
CI / e2e-plano-tests (3.12) (push) Blocked by required conditions
CI / e2e-plano-tests (3.13) (push) Blocked by required conditions
CI / e2e-plano-tests (3.14) (push) Blocked by required conditions
CI / e2e-demo-preference (push) Blocked by required conditions
CI / e2e-demo-currency (push) Blocked by required conditions
Publish docker image (latest) / build-arm64 (push) Waiting to run
Publish docker image (latest) / build-amd64 (push) Waiting to run
Publish docker image (latest) / create-manifest (push) Blocked by required conditions
Build and Deploy Documentation / build (push) Waiting to run

This commit is contained in:
Adil Hafeez 2026-04-08 17:32:02 -07:00 committed by GitHub
parent 978b1ea722
commit 8dedf0bec1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 614 additions and 43 deletions

View file

@ -1,4 +1,4 @@
use std::{collections::HashMap, sync::Arc};
use std::{collections::HashMap, sync::Arc, time::Duration, time::Instant};
use common::{
configuration::TopLevelRoutingPreference,
@ -9,6 +9,7 @@ use super::router_model::{ModelUsagePreference, RoutingPreference};
use hermesllm::apis::openai::Message;
use hyper::header;
use thiserror::Error;
use tokio::sync::RwLock;
use tracing::{debug, info};
use super::http::{self, post_and_extract_content};
@ -17,6 +18,17 @@ use super::router_model::RouterModel;
use crate::router::router_model_v1;
const DEFAULT_SESSION_TTL_SECONDS: u64 = 600;
const DEFAULT_SESSION_MAX_ENTRIES: usize = 10_000;
const MAX_SESSION_MAX_ENTRIES: usize = 10_000;
#[derive(Clone, Debug)]
pub struct CachedRoute {
pub model_name: String,
pub route_name: Option<String>,
pub cached_at: Instant,
}
pub struct RouterService {
router_url: String,
client: reqwest::Client,
@ -24,6 +36,9 @@ pub struct RouterService {
routing_provider_name: String,
top_level_preferences: HashMap<String, TopLevelRoutingPreference>,
metrics_service: Option<Arc<ModelMetricsService>>,
session_cache: RwLock<HashMap<String, CachedRoute>>,
session_ttl: Duration,
session_max_entries: usize,
}
#[derive(Debug, Error)]
@ -44,6 +59,8 @@ impl RouterService {
router_url: String,
routing_model_name: String,
routing_provider_name: String,
session_ttl_seconds: Option<u64>,
session_max_entries: Option<usize>,
) -> Self {
let top_level_preferences: HashMap<String, TopLevelRoutingPreference> = top_level_prefs
.map_or_else(HashMap::new, |prefs| {
@ -74,6 +91,12 @@ impl RouterService {
router_model_v1::MAX_TOKEN_LEN,
));
let session_ttl =
Duration::from_secs(session_ttl_seconds.unwrap_or(DEFAULT_SESSION_TTL_SECONDS));
let session_max_entries = session_max_entries
.unwrap_or(DEFAULT_SESSION_MAX_ENTRIES)
.min(MAX_SESSION_MAX_ENTRIES);
RouterService {
router_url,
client: reqwest::Client::new(),
@ -81,6 +104,64 @@ impl RouterService {
routing_provider_name,
top_level_preferences,
metrics_service,
session_cache: RwLock::new(HashMap::new()),
session_ttl,
session_max_entries,
}
}
/// Look up a cached routing decision by session ID.
/// Returns None if not found or expired.
pub async fn get_cached_route(&self, session_id: &str) -> Option<CachedRoute> {
let cache = self.session_cache.read().await;
if let Some(entry) = cache.get(session_id) {
if entry.cached_at.elapsed() < self.session_ttl {
return Some(entry.clone());
}
}
None
}
/// Store a routing decision in the session cache.
/// If at max capacity, evicts the oldest entry.
pub async fn cache_route(
&self,
session_id: String,
model_name: String,
route_name: Option<String>,
) {
let mut cache = self.session_cache.write().await;
if cache.len() >= self.session_max_entries && !cache.contains_key(&session_id) {
if let Some(oldest_key) = cache
.iter()
.min_by_key(|(_, v)| v.cached_at)
.map(|(k, _)| k.clone())
{
cache.remove(&oldest_key);
}
}
cache.insert(
session_id,
CachedRoute {
model_name,
route_name,
cached_at: Instant::now(),
},
);
}
/// Remove all expired entries from the session cache.
pub async fn cleanup_expired_sessions(&self) {
let mut cache = self.session_cache.write().await;
let before = cache.len();
cache.retain(|_, entry| entry.cached_at.elapsed() < self.session_ttl);
let removed = before - cache.len();
if removed > 0 {
info!(
removed = removed,
remaining = cache.len(),
"cleaned up expired session cache entries"
);
}
}
@ -198,3 +279,102 @@ impl RouterService {
Ok(result)
}
}
#[cfg(test)]
mod tests {
use super::*;
fn make_router_service(ttl_seconds: u64, max_entries: usize) -> RouterService {
RouterService::new(
None,
None,
"http://localhost:12001/v1/chat/completions".to_string(),
"Arch-Router".to_string(),
"arch-router".to_string(),
Some(ttl_seconds),
Some(max_entries),
)
}
#[tokio::test]
async fn test_cache_miss_returns_none() {
let svc = make_router_service(600, 100);
assert!(svc.get_cached_route("unknown-session").await.is_none());
}
#[tokio::test]
async fn test_cache_hit_returns_cached_route() {
let svc = make_router_service(600, 100);
svc.cache_route(
"s1".to_string(),
"gpt-4o".to_string(),
Some("code".to_string()),
)
.await;
let cached = svc.get_cached_route("s1").await.unwrap();
assert_eq!(cached.model_name, "gpt-4o");
assert_eq!(cached.route_name, Some("code".to_string()));
}
#[tokio::test]
async fn test_cache_expired_entry_returns_none() {
let svc = make_router_service(0, 100);
svc.cache_route("s1".to_string(), "gpt-4o".to_string(), None)
.await;
assert!(svc.get_cached_route("s1").await.is_none());
}
#[tokio::test]
async fn test_cleanup_removes_expired() {
let svc = make_router_service(0, 100);
svc.cache_route("s1".to_string(), "gpt-4o".to_string(), None)
.await;
svc.cache_route("s2".to_string(), "claude".to_string(), None)
.await;
svc.cleanup_expired_sessions().await;
let cache = svc.session_cache.read().await;
assert!(cache.is_empty());
}
#[tokio::test]
async fn test_cache_evicts_oldest_when_full() {
let svc = make_router_service(600, 2);
svc.cache_route("s1".to_string(), "model-a".to_string(), None)
.await;
tokio::time::sleep(Duration::from_millis(10)).await;
svc.cache_route("s2".to_string(), "model-b".to_string(), None)
.await;
svc.cache_route("s3".to_string(), "model-c".to_string(), None)
.await;
let cache = svc.session_cache.read().await;
assert_eq!(cache.len(), 2);
assert!(!cache.contains_key("s1"));
assert!(cache.contains_key("s2"));
assert!(cache.contains_key("s3"));
}
#[tokio::test]
async fn test_cache_update_existing_session_does_not_evict() {
let svc = make_router_service(600, 2);
svc.cache_route("s1".to_string(), "model-a".to_string(), None)
.await;
svc.cache_route("s2".to_string(), "model-b".to_string(), None)
.await;
svc.cache_route(
"s1".to_string(),
"model-a-updated".to_string(),
Some("route".to_string()),
)
.await;
let cache = svc.session_cache.read().await;
assert_eq!(cache.len(), 2);
assert_eq!(cache.get("s1").unwrap().model_name, "model-a-updated");
}
}