support session pinning for consistent model selection in routing (#813)

This commit is contained in:
Adil Hafeez 2026-03-13 17:32:32 -07:00
parent 785bf7e021
commit 46a5bfd82d
No known key found for this signature in database
GPG key ID: 9B18EF7691369645
8 changed files with 406 additions and 3 deletions

View file

@ -1,6 +1,6 @@
use bytes::Bytes;
use common::configuration::{ModelUsagePreference, SpanAttributes};
use common::consts::{REQUEST_ID_HEADER, TRACE_PARENT_HEADER};
use common::consts::{REQUEST_ID_HEADER, SESSION_ID_HEADER, TRACE_PARENT_HEADER};
use common::errors::BrightStaffError;
use hermesllm::clients::SupportedAPIsFromClient;
use hermesllm::ProviderRequestType;
@ -66,6 +66,9 @@ struct RoutingDecisionResponse {
model: String,
route: Option<String>,
trace_id: String,
#[serde(skip_serializing_if = "Option::is_none")]
session_id: Option<String>,
pinned: bool,
}
pub async fn routing_decision(
@ -81,6 +84,11 @@ pub async fn routing_decision(
.map(|s| s.to_string())
.unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
let session_id: Option<String> = request_headers
.get(SESSION_ID_HEADER)
.and_then(|h| h.to_str().ok())
.map(|s| s.to_string());
let custom_attrs =
collect_custom_trace_attributes(&request_headers, span_attributes.as_ref().as_ref());
@ -99,6 +107,7 @@ pub async fn routing_decision(
request_path,
request_headers,
custom_attrs,
session_id,
)
.instrument(request_span)
.await
@ -111,6 +120,7 @@ async fn routing_decision_inner(
request_path: String,
request_headers: hyper::HeaderMap,
custom_attrs: std::collections::HashMap<String, String>,
session_id: Option<String>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
set_service_name(operation_component::ROUTING);
opentelemetry::trace::get_active_span(|span| {
@ -144,6 +154,34 @@ async fn routing_decision_inner(
.unwrap_or("unknown")
.to_string();
// Session pinning: check cache before doing any routing work
if let Some(ref sid) = session_id {
if let Some(cached) = router_service.get_cached_route(sid).await {
info!(
session_id = %sid,
model = %cached.model_name,
route = ?cached.route_name,
"returning pinned routing decision from cache"
);
let response = RoutingDecisionResponse {
model: cached.model_name,
route: cached.route_name,
trace_id,
session_id: Some(sid.clone()),
pinned: true,
};
let json = serde_json::to_string(&response).unwrap();
let body = Full::new(Bytes::from(json))
.map_err(|never| match never {})
.boxed();
return Ok(Response::builder()
.status(StatusCode::OK)
.header("Content-Type", "application/json")
.body(body)
.unwrap());
}
}
// Parse request body
let raw_bytes = request.collect().await?.to_bytes();
@ -182,7 +220,7 @@ async fn routing_decision_inner(
// Call the existing routing logic with inline preferences
let routing_result = router_chat_get_upstream_model(
router_service,
Arc::clone(&router_service),
client_request,
&traceparent,
&request_path,
@ -193,10 +231,23 @@ async fn routing_decision_inner(
match routing_result {
Ok(result) => {
// Cache the result if session_id is present
if let Some(ref sid) = session_id {
router_service
.cache_route(
sid.clone(),
result.model_name.clone(),
result.route_name.clone(),
)
.await;
}
let response = RoutingDecisionResponse {
model: result.model_name,
route: result.route_name,
trace_id,
session_id,
pinned: false,
};
info!(
@ -334,12 +385,16 @@ mod tests {
model: "openai/gpt-4o".to_string(),
route: Some("code_generation".to_string()),
trace_id: "abc123".to_string(),
session_id: Some("sess-abc".to_string()),
pinned: true,
};
let json = serde_json::to_string(&response).unwrap();
let parsed: serde_json::Value = serde_json::from_str(&json).unwrap();
assert_eq!(parsed["model"], "openai/gpt-4o");
assert_eq!(parsed["route"], "code_generation");
assert_eq!(parsed["trace_id"], "abc123");
assert_eq!(parsed["session_id"], "sess-abc");
assert_eq!(parsed["pinned"], true);
}
#[test]
@ -348,10 +403,14 @@ mod tests {
model: "none".to_string(),
route: None,
trace_id: "abc123".to_string(),
session_id: None,
pinned: false,
};
let json = serde_json::to_string(&response).unwrap();
let parsed: serde_json::Value = serde_json::from_str(&json).unwrap();
assert_eq!(parsed["model"], "none");
assert!(parsed["route"].is_null());
assert!(parsed.get("session_id").is_none());
assert_eq!(parsed["pinned"], false);
}
}

View file

@ -102,13 +102,37 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
.and_then(|r| r.model_provider.clone())
.unwrap_or_else(|| DEFAULT_ROUTING_LLM_PROVIDER.to_string());
let session_ttl_seconds = plano_config
.routing
.as_ref()
.and_then(|r| r.session_ttl_seconds);
let session_max_entries = plano_config
.routing
.as_ref()
.and_then(|r| r.session_max_entries);
let router_service: Arc<RouterService> = Arc::new(RouterService::new(
plano_config.model_providers.clone(),
format!("{llm_provider_url}{CHAT_COMPLETIONS_PATH}"),
routing_model_name,
routing_llm_provider,
session_ttl_seconds,
session_max_entries,
));
// Spawn background task to clean up expired session cache entries every 5 minutes
{
let router_service = Arc::clone(&router_service);
tokio::spawn(async move {
let mut interval = tokio::time::interval(std::time::Duration::from_secs(300));
loop {
interval.tick().await;
router_service.cleanup_expired_sessions().await;
}
});
}
let orchestrator_service: Arc<OrchestratorService> = Arc::new(OrchestratorService::new(
format!("{llm_provider_url}{CHAT_COMPLETIONS_PATH}"),
PLANO_ORCHESTRATOR_MODEL_NAME.to_string(),

View file

@ -1,4 +1,4 @@
use std::{collections::HashMap, sync::Arc};
use std::{collections::HashMap, sync::Arc, time::Duration, time::Instant};
use common::{
configuration::{LlmProvider, ModelUsagePreference, RoutingPreference},
@ -7,12 +7,23 @@ use common::{
use hermesllm::apis::openai::{ChatCompletionsResponse, Message};
use hyper::header;
use thiserror::Error;
use tokio::sync::RwLock;
use tracing::{debug, info, warn};
use crate::router::router_model_v1::{self};
use super::router_model::RouterModel;
const DEFAULT_SESSION_TTL_SECONDS: u64 = 600;
const DEFAULT_SESSION_MAX_ENTRIES: usize = 10_000;
#[derive(Clone, Debug)]
pub struct CachedRoute {
pub model_name: String,
pub route_name: Option<String>,
pub cached_at: Instant,
}
pub struct RouterService {
router_url: String,
client: reqwest::Client,
@ -20,6 +31,9 @@ pub struct RouterService {
#[allow(dead_code)]
routing_provider_name: String,
llm_usage_defined: bool,
session_cache: RwLock<HashMap<String, CachedRoute>>,
session_ttl: Duration,
session_max_entries: usize,
}
#[derive(Debug, Error)]
@ -42,6 +56,8 @@ impl RouterService {
router_url: String,
routing_model_name: String,
routing_provider_name: String,
session_ttl_seconds: Option<u64>,
session_max_entries: Option<usize>,
) -> Self {
let providers_with_usage = providers
.iter()
@ -65,12 +81,75 @@ impl RouterService {
router_model_v1::MAX_TOKEN_LEN,
));
let session_ttl =
Duration::from_secs(session_ttl_seconds.unwrap_or(DEFAULT_SESSION_TTL_SECONDS));
let session_max_entries = session_max_entries.unwrap_or(DEFAULT_SESSION_MAX_ENTRIES);
RouterService {
router_url,
client: reqwest::Client::new(),
router_model,
routing_provider_name,
llm_usage_defined: !providers_with_usage.is_empty(),
session_cache: RwLock::new(HashMap::new()),
session_ttl,
session_max_entries,
}
}
/// Look up a cached routing decision by session ID.
/// Returns None if not found or expired.
pub async fn get_cached_route(&self, session_id: &str) -> Option<CachedRoute> {
let cache = self.session_cache.read().await;
if let Some(entry) = cache.get(session_id) {
if entry.cached_at.elapsed() < self.session_ttl {
return Some(entry.clone());
}
}
None
}
/// Store a routing decision in the session cache.
/// If at max capacity, evicts the oldest entry.
pub async fn cache_route(
&self,
session_id: String,
model_name: String,
route_name: Option<String>,
) {
let mut cache = self.session_cache.write().await;
if cache.len() >= self.session_max_entries && !cache.contains_key(&session_id) {
// Evict the oldest entry
if let Some(oldest_key) = cache
.iter()
.min_by_key(|(_, v)| v.cached_at)
.map(|(k, _)| k.clone())
{
cache.remove(&oldest_key);
}
}
cache.insert(
session_id,
CachedRoute {
model_name,
route_name,
cached_at: Instant::now(),
},
);
}
/// Remove all expired entries from the session cache.
pub async fn cleanup_expired_sessions(&self) {
let mut cache = self.session_cache.write().await;
let before = cache.len();
cache.retain(|_, entry| entry.cached_at.elapsed() < self.session_ttl);
let removed = before - cache.len();
if removed > 0 {
info!(
removed = removed,
remaining = cache.len(),
"cleaned up expired session cache entries"
);
}
}
@ -185,3 +264,105 @@ impl RouterService {
}
}
}
#[cfg(test)]
mod tests {
use super::*;
fn make_router_service(ttl_seconds: u64, max_entries: usize) -> RouterService {
RouterService::new(
vec![],
"http://localhost:12001/v1/chat/completions".to_string(),
"Arch-Router".to_string(),
"arch-router".to_string(),
Some(ttl_seconds),
Some(max_entries),
)
}
#[tokio::test]
async fn test_cache_miss_returns_none() {
let svc = make_router_service(600, 100);
assert!(svc.get_cached_route("unknown-session").await.is_none());
}
#[tokio::test]
async fn test_cache_hit_returns_cached_route() {
let svc = make_router_service(600, 100);
svc.cache_route(
"s1".to_string(),
"gpt-4o".to_string(),
Some("code".to_string()),
)
.await;
let cached = svc.get_cached_route("s1").await.unwrap();
assert_eq!(cached.model_name, "gpt-4o");
assert_eq!(cached.route_name, Some("code".to_string()));
}
#[tokio::test]
async fn test_cache_expired_entry_returns_none() {
let svc = make_router_service(0, 100);
svc.cache_route("s1".to_string(), "gpt-4o".to_string(), None)
.await;
// TTL is 0 seconds, so the entry should be expired immediately
assert!(svc.get_cached_route("s1").await.is_none());
}
#[tokio::test]
async fn test_cleanup_removes_expired() {
let svc = make_router_service(0, 100);
svc.cache_route("s1".to_string(), "gpt-4o".to_string(), None)
.await;
svc.cache_route("s2".to_string(), "claude".to_string(), None)
.await;
svc.cleanup_expired_sessions().await;
let cache = svc.session_cache.read().await;
assert!(cache.is_empty());
}
#[tokio::test]
async fn test_cache_evicts_oldest_when_full() {
let svc = make_router_service(600, 2);
svc.cache_route("s1".to_string(), "model-a".to_string(), None)
.await;
// Small delay so s2 has a later cached_at
tokio::time::sleep(Duration::from_millis(10)).await;
svc.cache_route("s2".to_string(), "model-b".to_string(), None)
.await;
// Cache is full (2 entries). Adding s3 should evict s1 (oldest).
svc.cache_route("s3".to_string(), "model-c".to_string(), None)
.await;
let cache = svc.session_cache.read().await;
assert_eq!(cache.len(), 2);
assert!(!cache.contains_key("s1"));
assert!(cache.contains_key("s2"));
assert!(cache.contains_key("s3"));
}
#[tokio::test]
async fn test_cache_update_existing_session_does_not_evict() {
let svc = make_router_service(600, 2);
svc.cache_route("s1".to_string(), "model-a".to_string(), None)
.await;
svc.cache_route("s2".to_string(), "model-b".to_string(), None)
.await;
// Updating s1 should not trigger eviction since key already exists
svc.cache_route(
"s1".to_string(),
"model-a-updated".to_string(),
Some("route".to_string()),
)
.await;
let cache = svc.session_cache.read().await;
assert_eq!(cache.len(), 2);
assert_eq!(cache.get("s1").unwrap().model_name, "model-a-updated");
}
}