feat: add policy provider integration and caching mechanism

This commit is contained in:
Musa 2026-03-11 07:39:57 -07:00
parent 6610097659
commit 5aeb69e034
No known key found for this signature in database
9 changed files with 674 additions and 36 deletions

View file

@ -19,6 +19,7 @@ use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::{debug, info, info_span, warn, Instrument};
use crate::handlers::policy_provider::PolicyProviderClient;
use crate::handlers::router_chat::router_chat_get_upstream_model;
use crate::handlers::utils::{
create_streaming_response, truncate_message, ObservableStreamProcessor,
@ -34,9 +35,11 @@ use crate::tracing::{
use common::errors::BrightStaffError;
#[allow(clippy::too_many_arguments)]
pub async fn llm_chat(
request: Request<hyper::body::Incoming>,
router_service: Arc<RouterService>,
policy_provider: Option<Arc<PolicyProviderClient>>,
full_qualified_llm_provider_url: String,
model_aliases: Arc<Option<HashMap<String, ModelAlias>>>,
llm_providers: Arc<RwLock<LlmProviders>>,
@ -73,6 +76,7 @@ pub async fn llm_chat(
llm_chat_inner(
request,
router_service,
policy_provider,
full_qualified_llm_provider_url,
model_aliases,
llm_providers,
@ -90,6 +94,7 @@ pub async fn llm_chat(
async fn llm_chat_inner(
request: Request<hyper::body::Incoming>,
router_service: Arc<RouterService>,
policy_provider: Option<Arc<PolicyProviderClient>>,
full_qualified_llm_provider_url: String,
model_aliases: Arc<Option<HashMap<String, ModelAlias>>>,
llm_providers: Arc<RwLock<LlmProviders>>,
@ -134,7 +139,7 @@ async fn llm_chat_inner(
);
// Extract routing_policy from request body if present
let (chat_request_bytes, inline_routing_policy) =
let (chat_request_bytes, inline_routing_policy, policy_id) =
match crate::handlers::routing_service::extract_routing_policy(&raw_bytes, false) {
Ok(result) => result,
Err(err) => {
@ -355,6 +360,8 @@ async fn llm_chat_inner(
&request_path,
&request_id,
inline_routing_policy,
policy_id,
policy_provider,
)
.await
}

View file

@ -5,6 +5,7 @@ pub mod jsonrpc;
pub mod llm;
pub mod models;
pub mod pipeline_processor;
pub mod policy_provider;
pub mod response_handler;
pub mod router_chat;
pub mod routing_service;

View file

@ -0,0 +1,291 @@
use std::sync::Arc;
use std::time::Duration;
use common::configuration::{ModelUsagePreference, RoutingPolicyProvider};
use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
use serde::Deserialize;
use tracing::warn;
use crate::state::policy_cache::PolicyCache;
/// Cache lifetime, in seconds, used when the provider config leaves `ttl_seconds` unset.
const DEFAULT_POLICY_TTL_SECONDS: u64 = 300;
/// JSON body expected from the external policy provider.
#[derive(Debug, Deserialize)]
struct ExternalPolicyResponse {
/// Identifier echoed back by the provider; `fetch_policy` rejects the
/// response when it differs from the id that was requested.
policy_id: String,
/// Per-model routing preferences that make up the policy.
routing_preferences: Vec<ModelUsagePreference>,
}
/// Failure modes of a routing-policy lookup.
///
/// The variant tells the caller whether the failure may go away on its own
/// ([`PolicyFetchError::Transient`]) or whether the request/response is
/// unusable as it stands ([`PolicyFetchError::Invalid`]).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PolicyFetchError {
    /// The lookup failed for a reason that may succeed on a later attempt,
    /// e.g. a network failure or a provider-side server error.
    Transient(String),
    /// The lookup cannot succeed as issued, e.g. the provider rejected it or
    /// returned a payload that could not be used.
    Invalid(String),
}

impl PolicyFetchError {
    /// Returns `true` when the error is [`PolicyFetchError::Transient`].
    pub fn is_transient(&self) -> bool {
        matches!(self, PolicyFetchError::Transient(_))
    }

    /// Returns the human-readable description carried by either variant.
    pub fn message(&self) -> &str {
        match self {
            PolicyFetchError::Transient(msg) | PolicyFetchError::Invalid(msg) => msg,
        }
    }
}

impl std::fmt::Display for PolicyFetchError {
    /// Writes the same text that [`PolicyFetchError::message`] returns.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.message())
    }
}

// Lets the error flow through `?` into `Box<dyn Error>` and similar sinks.
impl std::error::Error for PolicyFetchError {}
/// Client for an external routing-policy provider, fronted by a TTL cache.
pub struct PolicyProviderClient {
/// HTTP client used for provider requests.
client: reqwest::Client,
/// Provider settings: endpoint URL, optional extra headers, optional TTL.
config: RoutingPolicyProvider,
/// Cache consulted before, and populated after, each provider request.
cache: Arc<PolicyCache>,
/// Lifetime given to entries this client inserts into `cache`.
ttl: Duration,
}
impl PolicyProviderClient {
    /// Upper bound on one provider round trip. Policy lookups run while a
    /// request is being routed, so an unresponsive provider must not be able
    /// to stall that request indefinitely.
    const REQUEST_TIMEOUT: Duration = Duration::from_secs(5);

    /// Creates a client for the provider described by `config`.
    ///
    /// Fetched policies are stored in `cache` for `config.ttl_seconds`
    /// seconds, or `DEFAULT_POLICY_TTL_SECONDS` when that is unset.
    ///
    /// # Panics
    ///
    /// Panics if the HTTP client cannot be constructed, which is the same
    /// condition under which `reqwest::Client::new()` panics.
    pub fn new(config: RoutingPolicyProvider, cache: Arc<PolicyCache>) -> Self {
        let ttl = Duration::from_secs(config.ttl_seconds.unwrap_or(DEFAULT_POLICY_TTL_SECONDS));
        let client = reqwest::Client::builder()
            .timeout(Self::REQUEST_TIMEOUT)
            .build()
            .expect("failed to build HTTP client for routing policy provider");
        Self {
            client,
            config,
            cache,
            ttl,
        }
    }

    /// Returns the routing preferences registered under `policy_id`.
    ///
    /// A still-valid cache entry is returned without contacting the provider.
    /// Otherwise the provider is queried with `GET <url>?policy_id=<policy_id>`
    /// and a successful answer is cached before being returned.
    ///
    /// # Errors
    ///
    /// * [`PolicyFetchError::Transient`] when the request cannot be completed
    ///   (including timeouts) or the provider answers with a 5xx, 429 or 408
    ///   status.
    /// * [`PolicyFetchError::Invalid`] when a configured header is malformed,
    ///   the provider answers with any other non-success status, the payload
    ///   cannot be decoded, or the payload's `policy_id` does not match.
    pub async fn fetch_policy(
        &self,
        policy_id: &str,
    ) -> Result<Vec<ModelUsagePreference>, PolicyFetchError> {
        if let Some(cached) = self.cache.get_valid(policy_id).await {
            return Ok(cached);
        }

        let headers = self.build_headers()?;
        let response = self
            .client
            .get(&self.config.url)
            .query(&[("policy_id", policy_id)])
            .headers(headers)
            .send()
            .await
            .map_err(|err| PolicyFetchError::Transient(format!("policy fetch failed: {}", err)))?;

        let status = response.status();
        if !status.is_success() {
            // Throttling and request timeouts say nothing about the policy
            // itself, so they are retryable just like server-side failures.
            let retryable = status.is_server_error()
                || status == reqwest::StatusCode::TOO_MANY_REQUESTS
                || status == reqwest::StatusCode::REQUEST_TIMEOUT;
            return Err(if retryable {
                PolicyFetchError::Transient(format!("policy provider returned {}", status))
            } else {
                PolicyFetchError::Invalid(format!(
                    "policy provider returned non-success status {}",
                    status
                ))
            });
        }

        let payload: ExternalPolicyResponse = response.json().await.map_err(|err| {
            // The body is streamed, so the timeout can still fire here; that
            // is a delivery problem, not a malformed policy.
            if err.is_timeout() {
                PolicyFetchError::Transient(format!("policy fetch failed: {}", err))
            } else {
                PolicyFetchError::Invalid(format!("invalid policy payload: {}", err))
            }
        })?;

        if payload.policy_id != policy_id {
            return Err(PolicyFetchError::Invalid(format!(
                "policy_id mismatch in provider response: expected '{}', got '{}'",
                policy_id, payload.policy_id
            )));
        }

        if payload.routing_preferences.is_empty() {
            warn!(
                policy_id,
                "policy provider returned empty routing preferences"
            );
        }

        self.cache
            .insert(
                policy_id.to_string(),
                payload.routing_preferences.clone(),
                self.ttl,
            )
            .await;
        Ok(payload.routing_preferences)
    }

    /// Converts the configured header map into a `HeaderMap`.
    ///
    /// Returns an empty map when no headers are configured.
    ///
    /// # Errors
    ///
    /// [`PolicyFetchError::Invalid`] if any configured name or value is not a
    /// legal HTTP header name or value.
    fn build_headers(&self) -> Result<HeaderMap, PolicyFetchError> {
        let mut headers = HeaderMap::new();
        if let Some(configured_headers) = &self.config.headers {
            for (name, value) in configured_headers {
                let header_name = HeaderName::from_bytes(name.as_bytes()).map_err(|err| {
                    PolicyFetchError::Invalid(format!(
                        "invalid header name '{}' in routing.policy_provider.headers: {}",
                        name, err
                    ))
                })?;
                let header_value = HeaderValue::from_str(value).map_err(|err| {
                    PolicyFetchError::Invalid(format!(
                        "invalid header value for '{}' in routing.policy_provider.headers: {}",
                        name, err
                    ))
                })?;
                headers.insert(header_name, header_value);
            }
        }
        Ok(headers)
    }
}
#[cfg(test)]
mod tests {
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use common::configuration::RoutingPolicyProvider;
use mockito::{Matcher, Server};
use crate::handlers::policy_provider::{PolicyFetchError, PolicyProviderClient};
use crate::state::policy_cache::PolicyCache;
/// Builds a provider config for `url` with no extra headers.
fn provider_config(url: String, ttl_seconds: Option<u64>) -> RoutingPolicyProvider {
    let headers = None;
    RoutingPolicyProvider {
        url,
        headers,
        ttl_seconds,
    }
}
/// Two lookups of the same policy must reach the provider exactly once; the
/// second one has to be answered from the cache.
#[tokio::test]
async fn fetches_policy_and_populates_cache() {
    let mut server = Server::new_async().await;
    let mock = server
        .mock("GET", "/v1/routing-policy")
        .match_query(Matcher::UrlEncoded(
            "policy_id".to_string(),
            "customer-abc".to_string(),
        ))
        .with_status(200)
        .with_header("content-type", "application/json")
        .with_body(
            r#"{
"policy_id":"customer-abc",
"routing_preferences":[
{
"model":"openai/gpt-4o",
"routing_preferences":[{"name":"quick response","description":"fast"}]
}
]
}"#,
        )
        .expect(1)
        .create_async()
        .await;

    let cache = Arc::new(PolicyCache::new());
    let client = PolicyProviderClient::new(
        provider_config(format!("{}/v1/routing-policy", server.url()), Some(300)),
        cache,
    );

    let first = client.fetch_policy("customer-abc").await.unwrap();
    let second = client.fetch_policy("customer-abc").await.unwrap();

    assert_eq!(first.len(), 1);
    assert_eq!(second[0].model, "openai/gpt-4o");
    // `.expect(1)` is only enforced when the mock is asserted; without this
    // the test would pass even if every call went to the provider.
    mock.assert_async().await;
}
/// A 200 response whose `policy_id` differs from the requested one is invalid.
#[tokio::test]
async fn returns_invalid_on_policy_id_mismatch() {
    let mismatched_body = r#"{
"policy_id":"different-id",
"routing_preferences":[]
}"#;

    let mut server = Server::new_async().await;
    let endpoint = format!("{}/v1/routing-policy", server.url());
    let _mock = server
        .mock("GET", "/v1/routing-policy")
        .match_query(Matcher::Any)
        .with_status(200)
        .with_header("content-type", "application/json")
        .with_body(mismatched_body)
        .create_async()
        .await;

    let client = PolicyProviderClient::new(
        provider_config(endpoint, Some(300)),
        Arc::new(PolicyCache::new()),
    );

    let outcome = client.fetch_policy("customer-abc").await;
    assert!(matches!(outcome, Err(PolicyFetchError::Invalid(_))));
}
#[tokio::test]
async fn returns_transient_on_server_error() {
let mut server = Server::new_async().await;
let _mock = server
.mock("GET", "/v1/routing-policy")
.match_query(Matcher::Any)
.with_status(500)
.create_async()
.await;
let cache = Arc::new(PolicyCache::new());
let client = PolicyProviderClient::new(
provider_config(format!("{}/v1/routing-policy", server.url()), Some(300)),
cache,
);
let err = client.fetch_policy("customer-abc").await.unwrap_err();
assert!(err.is_transient());
}
/// A 404 from the provider is reported as an invalid (non-retryable) failure.
#[tokio::test]
async fn returns_invalid_on_client_error_status() {
    let mut server = Server::new_async().await;
    let endpoint = format!("{}/v1/routing-policy", server.url());
    let _mock = server
        .mock("GET", "/v1/routing-policy")
        .match_query(Matcher::Any)
        .with_status(404)
        .create_async()
        .await;

    let client = PolicyProviderClient::new(
        provider_config(endpoint, Some(300)),
        Arc::new(PolicyCache::new()),
    );

    let outcome = client.fetch_policy("customer-abc").await;
    assert!(matches!(outcome, Err(PolicyFetchError::Invalid(_))));
}
#[tokio::test]
async fn supports_headers() {
let mut server = Server::new_async().await;
let _mock = server
.mock("GET", "/v1/routing-policy")
.match_header("authorization", "Bearer token")
.match_query(Matcher::Any)
.with_status(200)
.with_header("content-type", "application/json")
.with_body(r#"{"policy_id":"customer-abc","routing_preferences":[]}"#)
.create_async()
.await;
let mut headers = HashMap::new();
headers.insert("Authorization".to_string(), "Bearer token".to_string());
let cache = Arc::new(PolicyCache::new());
let client = PolicyProviderClient::new(
RoutingPolicyProvider {
url: format!("{}/v1/routing-policy", server.url()),
headers: Some(headers),
ttl_seconds: Some(Duration::from_secs(300).as_secs()),
},
cache,
);
let _ = client.fetch_policy("customer-abc").await.unwrap();
}
}

View file

@ -2,9 +2,12 @@ use common::configuration::ModelUsagePreference;
use hermesllm::clients::endpoints::SupportedUpstreamAPIs;
use hermesllm::{ProviderRequest, ProviderRequestType};
use hyper::StatusCode;
use serde_json::Value;
use std::collections::HashMap;
use std::sync::Arc;
use tracing::{debug, info, warn};
use crate::handlers::policy_provider::PolicyProviderClient;
use crate::router::llm_router::RouterService;
use crate::tracing::routing;
@ -13,6 +16,7 @@ pub struct RoutingResult {
pub route_name: Option<String>,
}
#[derive(Debug)]
pub struct RoutingError {
pub message: String,
pub status_code: StatusCode,
@ -25,6 +29,60 @@ impl RoutingError {
status_code: StatusCode::INTERNAL_SERVER_ERROR,
}
}
/// Builds a routing error that is reported with HTTP 400 Bad Request.
pub fn bad_request(message: String) -> Self {
    let status_code = StatusCode::BAD_REQUEST;
    Self {
        message,
        status_code,
    }
}
}
/// Chooses the routing preferences that apply to a request.
///
/// Sources are tried in order of precedence:
/// 1. `inline_usage_preferences` taken from the request body;
/// 2. the policy stored under `policy_id` at the external provider, when both
///    an id and a provider client are available;
/// 3. the `plano_preference_config` entry of `routing_metadata`.
///
/// Returns `Ok(None)` when none of the sources yields preferences.
///
/// # Errors
///
/// Returns a 400 [`RoutingError`] when the provider lookup fails with a
/// non-transient error. Transient provider failures are logged and fall
/// through to the metadata source instead.
async fn resolve_usage_preferences(
    inline_usage_preferences: Option<Vec<ModelUsagePreference>>,
    policy_id: Option<&str>,
    policy_provider: Option<&PolicyProviderClient>,
    routing_metadata: Option<&HashMap<String, Value>>,
) -> Result<Option<Vec<ModelUsagePreference>>, RoutingError> {
    if let Some(inline_preferences) = inline_usage_preferences {
        info!("using inline routing_policy from request body");
        return Ok(Some(inline_preferences));
    }

    match (policy_id, policy_provider) {
        (Some(policy_id), Some(policy_provider_client)) => {
            match policy_provider_client.fetch_policy(policy_id).await {
                Ok(preferences) => {
                    info!(
                        policy_id,
                        "using routing policy from external policy provider"
                    );
                    return Ok(Some(preferences));
                }
                Err(err) if err.is_transient() => {
                    warn!(
                        policy_id,
                        error = %err.message(),
                        "policy provider fetch failed, falling back to metadata/config routing preferences"
                    );
                }
                Err(err) => {
                    return Err(RoutingError::bad_request(format!(
                        "Failed to load routing policy for policy_id '{}': {}",
                        policy_id,
                        err.message()
                    )));
                }
            }
        }
        (Some(policy_id), None) => {
            // The caller asked for a specific policy that cannot be honoured;
            // leave a trace rather than dropping the id silently.
            warn!(
                policy_id,
                "request supplied policy_id but no routing policy provider is configured, ignoring it"
            );
        }
        (None, _) => {}
    }

    let raw_metadata_config = match routing_metadata
        .and_then(|metadata| metadata.get("plano_preference_config"))
    {
        Some(value) => value.to_string(),
        None => return Ok(None),
    };

    match serde_yaml::from_str::<Vec<ModelUsagePreference>>(&raw_metadata_config) {
        Ok(preferences) => Ok(Some(preferences)),
        Err(err) => {
            // Still best-effort: a malformed entry is ignored, but no longer
            // without a trace.
            warn!(
                error = %err,
                "failed to parse plano_preference_config from request metadata, ignoring it"
            );
            Ok(None)
        }
    }
}
/// Determines the routing decision if
@ -32,6 +90,7 @@ impl RoutingError {
/// # Returns
/// * `Ok(RoutingResult)` - Contains the selected model name and span ID
/// * `Err(RoutingError)` - Contains error details and optional span ID
#[allow(clippy::too_many_arguments)]
pub async fn router_chat_get_upstream_model(
router_service: Arc<RouterService>,
client_request: ProviderRequestType,
@ -39,6 +98,8 @@ pub async fn router_chat_get_upstream_model(
request_path: &str,
request_id: &str,
inline_usage_preferences: Option<Vec<ModelUsagePreference>>,
policy_id: Option<String>,
policy_provider: Option<Arc<PolicyProviderClient>>,
) -> Result<RoutingResult, RoutingError> {
// Clone metadata for routing before converting (which consumes client_request)
let routing_metadata = client_request.metadata().clone();
@ -77,21 +138,13 @@ pub async fn router_chat_get_upstream_model(
"router request"
);
// Use inline preferences if provided, otherwise fall back to metadata extraction
let usage_preferences: Option<Vec<ModelUsagePreference>> = if inline_usage_preferences.is_some()
{
inline_usage_preferences
} else {
let usage_preferences_str: Option<String> =
routing_metadata.as_ref().and_then(|metadata| {
metadata
.get("plano_preference_config")
.map(|value| value.to_string())
});
usage_preferences_str
.as_ref()
.and_then(|s| serde_yaml::from_str(s).ok())
};
let usage_preferences = resolve_usage_preferences(
inline_usage_preferences,
policy_id.as_deref(),
policy_provider.as_deref(),
routing_metadata.as_ref(),
)
.await?;
// Prepare log message with latest message from chat request
let latest_message_for_log = chat_request
@ -168,3 +221,109 @@ pub async fn router_chat_get_upstream_model(
}
}
}
#[cfg(test)]
mod tests {
use super::resolve_usage_preferences;
use crate::handlers::policy_provider::PolicyProviderClient;
use crate::state::policy_cache::PolicyCache;
use common::configuration::{ModelUsagePreference, RoutingPolicyProvider, RoutingPreference};
use mockito::{Matcher, Server};
use serde_json::json;
use std::collections::HashMap;
use std::sync::Arc;
/// Builds a one-model policy whose single routing preference is called `name`.
fn inline_policy(name: &str) -> Vec<ModelUsagePreference> {
    let preference = RoutingPreference {
        name: name.to_string(),
        description: "desc".to_string(),
    };
    let usage = ModelUsagePreference {
        model: "openai/gpt-4o".to_string(),
        routing_preferences: vec![preference],
    };
    vec![usage]
}
/// Inline preferences win even when a policy id and metadata are also present.
#[tokio::test]
async fn resolve_usage_preferences_prioritizes_inline_policy() {
    let metadata = HashMap::from([(
        "plano_preference_config".to_string(),
        json!(
            [{"model":"openai/gpt-4o-mini","routing_preferences":[{"name":"metadata","description":"desc"}]}]
        ),
    )]);

    let resolved = resolve_usage_preferences(
        Some(inline_policy("inline")),
        Some("policy-a"),
        None,
        Some(&metadata),
    )
    .await
    .unwrap();

    let preferences = resolved.unwrap();
    assert_eq!(preferences[0].routing_preferences[0].name, "inline");
}
/// When the provider answers 500, the metadata preferences are used instead.
#[tokio::test]
async fn resolve_usage_preferences_falls_back_to_metadata_on_transient_policy_error() {
    let mut server = Server::new_async().await;
    let provider_url = format!("{}/policy", server.url());
    let _mock = server
        .mock("GET", "/policy")
        .match_query(Matcher::Any)
        .with_status(500)
        .create_async()
        .await;

    let provider_config = RoutingPolicyProvider {
        url: provider_url,
        headers: None,
        ttl_seconds: Some(60),
    };
    let provider = PolicyProviderClient::new(provider_config, Arc::new(PolicyCache::new()));

    let metadata = HashMap::from([(
        "plano_preference_config".to_string(),
        json!(
            [{"model":"openai/gpt-4o-mini","routing_preferences":[{"name":"metadata","description":"desc"}]}]
        ),
    )]);

    let resolved =
        resolve_usage_preferences(None, Some("customer-a"), Some(&provider), Some(&metadata))
            .await;
    let preferences = resolved.unwrap().unwrap();
    assert_eq!(preferences[0].routing_preferences[0].name, "metadata");
}
/// A provider response for the wrong policy id is surfaced as HTTP 400.
#[tokio::test]
async fn resolve_usage_preferences_returns_bad_request_on_policy_mismatch() {
    let mut server = Server::new_async().await;
    let provider_url = format!("{}/policy", server.url());
    let _mock = server
        .mock("GET", "/policy")
        .match_query(Matcher::Any)
        .with_status(200)
        .with_header("content-type", "application/json")
        .with_body(r#"{"policy_id":"different","routing_preferences":[]}"#)
        .create_async()
        .await;

    let provider_config = RoutingPolicyProvider {
        url: provider_url,
        headers: None,
        ttl_seconds: Some(60),
    };
    let provider = PolicyProviderClient::new(provider_config, Arc::new(PolicyCache::new()));

    let outcome = resolve_usage_preferences(None, Some("expected"), Some(&provider), None).await;
    let err = outcome.unwrap_err();
    assert_eq!(err.status_code, hyper::StatusCode::BAD_REQUEST);
}
}

View file

@ -10,11 +10,13 @@ use hyper::{Request, Response, StatusCode};
use std::sync::Arc;
use tracing::{debug, info, info_span, warn, Instrument};
use crate::handlers::policy_provider::PolicyProviderClient;
use crate::handlers::router_chat::router_chat_get_upstream_model;
use crate::router::llm_router::RouterService;
use crate::tracing::{collect_custom_trace_attributes, operation_component, set_service_name};
const ROUTING_POLICY_SIZE_WARNING_BYTES: usize = 5120;
type ExtractRoutingPolicyResult = (Bytes, Option<Vec<ModelUsagePreference>>, Option<String>);
/// Extracts `routing_policy` from a JSON body, returning the cleaned body bytes
/// and parsed preferences. The `routing_policy` field is removed from the JSON
@ -24,10 +26,20 @@ const ROUTING_POLICY_SIZE_WARNING_BYTES: usize = 5120;
pub fn extract_routing_policy(
raw_bytes: &[u8],
warn_on_size: bool,
) -> Result<(Bytes, Option<Vec<ModelUsagePreference>>), String> {
) -> Result<ExtractRoutingPolicyResult, String> {
let mut json_body: serde_json::Value = serde_json::from_slice(raw_bytes)
.map_err(|err| format!("Failed to parse JSON: {}", err))?;
let policy_id = json_body
.as_object_mut()
.and_then(|obj| obj.remove("policy_id"))
.map(|policy_id_value| match policy_id_value {
serde_json::Value::String(policy_id) if !policy_id.trim().is_empty() => Ok(policy_id),
serde_json::Value::String(_) => Err("policy_id cannot be empty".to_string()),
_ => Err("policy_id must be a string".to_string()),
})
.transpose()?;
let preferences = json_body
.as_object_mut()
.and_then(|obj| obj.remove("routing_policy"))
@ -58,7 +70,7 @@ pub fn extract_routing_policy(
});
let bytes = Bytes::from(serde_json::to_vec(&json_body).unwrap());
Ok((bytes, preferences))
Ok((bytes, preferences, policy_id))
}
#[derive(serde::Serialize)]
@ -71,6 +83,7 @@ struct RoutingDecisionResponse {
pub async fn routing_decision(
request: Request<hyper::body::Incoming>,
router_service: Arc<RouterService>,
policy_provider: Option<Arc<PolicyProviderClient>>,
request_path: String,
span_attributes: Arc<Option<SpanAttributes>>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
@ -95,6 +108,7 @@ pub async fn routing_decision(
routing_decision_inner(
request,
router_service,
policy_provider,
request_id,
request_path,
request_headers,
@ -107,6 +121,7 @@ pub async fn routing_decision(
async fn routing_decision_inner(
request: Request<hyper::body::Incoming>,
router_service: Arc<RouterService>,
policy_provider: Option<Arc<PolicyProviderClient>>,
request_id: String,
request_path: String,
request_headers: hyper::HeaderMap,
@ -153,17 +168,18 @@ async fn routing_decision_inner(
);
// Extract routing_policy from request body before parsing as ProviderRequestType
let (chat_request_bytes, inline_preferences) = match extract_routing_policy(&raw_bytes, true) {
Ok(result) => result,
Err(err) => {
warn!(error = %err, "failed to parse request JSON");
return Ok(BrightStaffError::InvalidRequest(format!(
"Failed to parse request JSON: {}",
err
))
.into_response());
}
};
let (chat_request_bytes, inline_preferences, policy_id) =
match extract_routing_policy(&raw_bytes, true) {
Ok(result) => result,
Err(err) => {
warn!(error = %err, "failed to parse request JSON");
return Ok(BrightStaffError::InvalidRequest(format!(
"Failed to parse request JSON: {}",
err
))
.into_response());
}
};
let client_request = match ProviderRequestType::try_from((
&chat_request_bytes[..],
@ -188,6 +204,8 @@ async fn routing_decision_inner(
&request_path,
&request_id,
inline_preferences,
policy_id,
policy_provider,
)
.await;
@ -218,7 +236,11 @@ async fn routing_decision_inner(
}
Err(err) => {
warn!(error = %err.message, "routing decision failed");
Ok(BrightStaffError::InternalServerError(err.message).into_response())
Ok(BrightStaffError::ForwardedError {
status_code: err.status_code,
message: err.message,
}
.into_response())
}
}
}
@ -243,9 +265,10 @@ mod tests {
#[test]
fn extract_routing_policy_no_policy() {
let body = make_chat_body("");
let (cleaned, prefs) = extract_routing_policy(&body, false).unwrap();
let (cleaned, prefs, policy_id) = extract_routing_policy(&body, false).unwrap();
assert!(prefs.is_none());
assert!(policy_id.is_none());
let cleaned_json: serde_json::Value = serde_json::from_slice(&cleaned).unwrap();
assert_eq!(cleaned_json["model"], "gpt-4o-mini");
assert!(cleaned_json.get("routing_policy").is_none());
@ -268,7 +291,7 @@ mod tests {
}
]"#;
let body = make_chat_body(policy);
let (cleaned, prefs) = extract_routing_policy(&body, false).unwrap();
let (cleaned, prefs, policy_id) = extract_routing_policy(&body, false).unwrap();
let prefs = prefs.expect("should have parsed preferences");
assert_eq!(prefs.len(), 2);
@ -280,6 +303,7 @@ mod tests {
// routing_policy should be stripped from cleaned body
let cleaned_json: serde_json::Value = serde_json::from_slice(&cleaned).unwrap();
assert!(cleaned_json.get("routing_policy").is_none());
assert!(policy_id.is_none());
assert_eq!(cleaned_json["model"], "gpt-4o-mini");
}
@ -288,13 +312,14 @@ mod tests {
// routing_policy is present but has wrong shape
let policy = r#""routing_policy": "not-an-array""#;
let body = make_chat_body(policy);
let (cleaned, prefs) = extract_routing_policy(&body, false).unwrap();
let (cleaned, prefs, policy_id) = extract_routing_policy(&body, false).unwrap();
// Invalid policy should be ignored (returns None), not error
assert!(prefs.is_none());
// routing_policy should still be stripped from cleaned body
let cleaned_json: serde_json::Value = serde_json::from_slice(&cleaned).unwrap();
assert!(cleaned_json.get("routing_policy").is_none());
assert!(policy_id.is_none());
}
#[test]
@ -309,23 +334,44 @@ mod tests {
fn extract_routing_policy_empty_array() {
let policy = r#""routing_policy": []"#;
let body = make_chat_body(policy);
let (_, prefs) = extract_routing_policy(&body, false).unwrap();
let (_, prefs, policy_id) = extract_routing_policy(&body, false).unwrap();
let prefs = prefs.expect("empty array is valid");
assert_eq!(prefs.len(), 0);
assert!(policy_id.is_none());
}
#[test]
fn extract_routing_policy_preserves_other_fields() {
let policy = r#""routing_policy": [{"model": "gpt-4o", "routing_preferences": [{"name": "test", "description": "test"}]}], "temperature": 0.5, "max_tokens": 100"#;
let body = make_chat_body(policy);
let (cleaned, prefs) = extract_routing_policy(&body, false).unwrap();
let (cleaned, prefs, policy_id) = extract_routing_policy(&body, false).unwrap();
assert!(prefs.is_some());
let cleaned_json: serde_json::Value = serde_json::from_slice(&cleaned).unwrap();
assert_eq!(cleaned_json["temperature"], 0.5);
assert_eq!(cleaned_json["max_tokens"], 100);
assert!(cleaned_json.get("routing_policy").is_none());
assert!(policy_id.is_none());
}
/// `policy_id` is returned to the caller and removed from the forwarded body.
#[test]
fn extract_routing_policy_extracts_and_strips_policy_id() {
    let body = make_chat_body(r#""policy_id": "customer-abc-123""#);

    let (cleaned, prefs, policy_id) = extract_routing_policy(&body, false).unwrap();
    let cleaned_json: serde_json::Value = serde_json::from_slice(&cleaned).unwrap();

    assert!(prefs.is_none());
    assert_eq!(policy_id.as_deref(), Some("customer-abc-123"));
    assert!(cleaned_json.get("policy_id").is_none());
}
/// A numeric `policy_id` is rejected with an explanatory message.
#[test]
fn extract_routing_policy_rejects_non_string_policy_id() {
    let body = make_chat_body(r#""policy_id": 123"#);

    let outcome = extract_routing_policy(&body, false);
    // `unwrap_err` fails the test if extraction unexpectedly succeeded.
    let message = outcome.unwrap_err();
    assert!(message.contains("policy_id must be a string"));
}
#[test]

View file

@ -2,10 +2,12 @@ use brightstaff::handlers::agent_chat_completions::agent_chat;
use brightstaff::handlers::function_calling::function_calling_chat_handler;
use brightstaff::handlers::llm::llm_chat;
use brightstaff::handlers::models::list_models;
use brightstaff::handlers::policy_provider::PolicyProviderClient;
use brightstaff::handlers::routing_service::routing_decision;
use brightstaff::router::llm_router::RouterService;
use brightstaff::router::plano_orchestrator::OrchestratorService;
use brightstaff::state::memory::MemoryConversationalStorage;
use brightstaff::state::policy_cache::PolicyCache;
use brightstaff::state::postgresql::PostgreSQLConversationStorage;
use brightstaff::state::StateStorage;
use brightstaff::utils::tracing::init_tracer;
@ -108,6 +110,16 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
routing_model_name,
routing_llm_provider,
));
let policy_provider: Option<Arc<PolicyProviderClient>> = plano_config
.routing
.as_ref()
.and_then(|routing| routing.policy_provider.clone())
.map(|policy_provider_config| {
Arc::new(PolicyProviderClient::new(
policy_provider_config,
Arc::new(PolicyCache::new()),
))
});
let orchestrator_service: Arc<OrchestratorService> = Arc::new(OrchestratorService::new(
format!("{llm_provider_url}{CHAT_COMPLETIONS_PATH}"),
@ -172,6 +184,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let router_service: Arc<RouterService> = Arc::clone(&router_service);
let orchestrator_service: Arc<OrchestratorService> = Arc::clone(&orchestrator_service);
let policy_provider = policy_provider.clone();
let model_aliases: Arc<
Option<std::collections::HashMap<String, common::configuration::ModelAlias>>,
> = Arc::clone(&model_aliases);
@ -185,6 +198,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let service = service_fn(move |req| {
let router_service = Arc::clone(&router_service);
let orchestrator_service = Arc::clone(&orchestrator_service);
let policy_provider = policy_provider.clone();
let parent_cx = extract_context_from_request(&req);
let llm_provider_url = llm_provider_url.clone();
let llm_providers = llm_providers.clone();
@ -227,6 +241,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
return routing_decision(
req,
router_service,
policy_provider,
stripped_path,
span_attributes,
)
@ -243,6 +258,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
llm_chat(
req,
router_service,
policy_provider,
fully_qualified_url,
model_aliases,
llm_providers,

View file

@ -9,6 +9,7 @@ use std::sync::Arc;
use tracing::debug;
pub mod memory;
pub mod policy_cache;
pub mod postgresql;
pub mod response_state_processor;

View file

@ -0,0 +1,108 @@
use common::configuration::ModelUsagePreference;
use std::collections::HashMap;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
/// A cached policy together with the instant from which it counts as stale.
#[derive(Clone)]
struct CachedPolicy {
/// Routing preferences stored for the policy.
preferences: Vec<ModelUsagePreference>,
/// The entry is served only while this instant is still in the future.
expires_at: Instant,
}
/// In-memory cache of routing policies keyed by policy id, with per-entry expiry.
pub struct PolicyCache {
/// Cached policies by policy id, behind an async read/write lock.
entries: RwLock<HashMap<String, CachedPolicy>>,
}
impl Default for PolicyCache {
fn default() -> Self {
Self::new()
}
}
impl PolicyCache {
    /// Longest lifetime an entry may be given. Larger TTLs are clamped to this
    /// so that computing the expiry instant cannot overflow.
    const MAX_TTL: Duration = Duration::from_secs(365 * 24 * 60 * 60);

    /// Creates an empty cache.
    pub fn new() -> Self {
        Self {
            entries: RwLock::new(HashMap::new()),
        }
    }

    /// Returns the preferences cached for `policy_id` if the entry has not
    /// expired, and `None` otherwise.
    ///
    /// An expired entry is evicted as a side effect.
    pub async fn get_valid(&self, policy_id: &str) -> Option<Vec<ModelUsagePreference>> {
        let now = Instant::now();
        {
            // Fast path: hits and plain misses only need the read lock.
            let entries = self.entries.read().await;
            match entries.get(policy_id) {
                Some(entry) if entry.expires_at > now => {
                    return Some(entry.preferences.clone());
                }
                Some(_) => {} // expired: evict below under the write lock
                None => return None,
            }
        }

        // The read lock was released above, so another task may have stored a
        // fresh entry in the meantime. Re-check before evicting so that a
        // just-refreshed policy is returned rather than deleted.
        let mut entries = self.entries.write().await;
        let now = Instant::now();
        let fresh = entries
            .get(policy_id)
            .filter(|entry| entry.expires_at > now)
            .map(|entry| entry.preferences.clone());
        if fresh.is_none() {
            entries.remove(policy_id);
        }
        fresh
    }

    /// Stores `preferences` under `policy_id` for `ttl`, replacing any
    /// existing entry.
    ///
    /// `ttl` is clamped to [`Self::MAX_TTL`]. If the expiry instant still
    /// cannot be represented the value is not cached. Entries that have
    /// already expired are purged while the write lock is held, so policies
    /// that are never requested again do not accumulate.
    pub async fn insert(
        &self,
        policy_id: String,
        preferences: Vec<ModelUsagePreference>,
        ttl: Duration,
    ) {
        let now = Instant::now();
        // `Instant + Duration` panics on overflow; `checked_add` does not.
        let expires_at = match now.checked_add(ttl.min(Self::MAX_TTL)) {
            Some(deadline) => deadline,
            None => return,
        };

        let mut entries = self.entries.write().await;
        entries.retain(|_, entry| entry.expires_at > now);
        entries.insert(
            policy_id,
            CachedPolicy {
                preferences,
                expires_at,
            },
        );
    }
}
#[cfg(test)]
mod tests {
use super::PolicyCache;
use common::configuration::{ModelUsagePreference, RoutingPreference};
use std::time::Duration;
/// One-model fixture used by the cache tests.
fn sample_preferences() -> Vec<ModelUsagePreference> {
    let quick_response = RoutingPreference {
        name: "quick response".to_string(),
        description: "fast lightweight responses".to_string(),
    };
    let usage = ModelUsagePreference {
        model: "openai/gpt-4o".to_string(),
        routing_preferences: vec![quick_response],
    };
    vec![usage]
}
/// An entry is served while its TTL has not elapsed.
#[tokio::test]
async fn returns_cached_policy_before_expiry() {
    let cache = PolicyCache::new();
    let ttl = Duration::from_secs(10);
    cache
        .insert("customer-a".to_string(), sample_preferences(), ttl)
        .await;

    let hit = cache.get_valid("customer-a").await;
    assert!(matches!(
        hit.as_deref(),
        Some([first, ..]) if first.model == "openai/gpt-4o"
    ));
}
/// An entry is no longer served once its TTL has elapsed.
#[tokio::test]
async fn expires_cached_policy_after_ttl() {
    let ttl = Duration::from_millis(5);
    let cache = PolicyCache::new();
    cache
        .insert("customer-a".to_string(), sample_preferences(), ttl)
        .await;

    // Wait for twice the TTL so the entry is certainly stale.
    tokio::time::sleep(ttl * 2).await;

    let lookup = cache.get_valid("customer-a").await;
    assert!(lookup.is_none());
}
}

View file

@ -9,8 +9,17 @@ use crate::api::open_ai::{
/// Routing-related configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Routing {
/// Also accepted under the key `llm_provider` when deserializing.
#[serde(alias = "llm_provider")]
pub model_provider: Option<String>,
// NOTE(review): presumably the name of the routing model — confirm against the consumer.
pub model: Option<String>,
/// Optional external source of routing policies, looked up by policy id.
pub policy_provider: Option<RoutingPolicyProvider>,
}
/// Settings for an external service that serves routing policies by id.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RoutingPolicyProvider {
/// Endpoint queried with `GET` and a `policy_id` query parameter.
pub url: String,
/// Extra HTTP headers sent with every provider request.
pub headers: Option<HashMap<String, String>>,
/// How long a fetched policy is cached, in seconds; 300 when unset.
pub ttl_seconds: Option<u64>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
@ -270,7 +279,7 @@ impl LlmProviderType {
}
}
#[derive(Serialize, Deserialize, Debug)]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelUsagePreference {
pub model: String,
pub routing_preferences: Vec<RoutingPreference>,