mirror of
https://github.com/katanemo/plano.git
synced 2026-07-02 15:51:02 +02:00
feat: add policy provider integration and caching mechanism
This commit is contained in:
parent
6610097659
commit
5aeb69e034
9 changed files with 674 additions and 36 deletions
|
|
@ -19,6 +19,7 @@ use std::sync::Arc;
|
||||||
use tokio::sync::RwLock;
|
use tokio::sync::RwLock;
|
||||||
use tracing::{debug, info, info_span, warn, Instrument};
|
use tracing::{debug, info, info_span, warn, Instrument};
|
||||||
|
|
||||||
|
use crate::handlers::policy_provider::PolicyProviderClient;
|
||||||
use crate::handlers::router_chat::router_chat_get_upstream_model;
|
use crate::handlers::router_chat::router_chat_get_upstream_model;
|
||||||
use crate::handlers::utils::{
|
use crate::handlers::utils::{
|
||||||
create_streaming_response, truncate_message, ObservableStreamProcessor,
|
create_streaming_response, truncate_message, ObservableStreamProcessor,
|
||||||
|
|
@ -34,9 +35,11 @@ use crate::tracing::{
|
||||||
|
|
||||||
use common::errors::BrightStaffError;
|
use common::errors::BrightStaffError;
|
||||||
|
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub async fn llm_chat(
|
pub async fn llm_chat(
|
||||||
request: Request<hyper::body::Incoming>,
|
request: Request<hyper::body::Incoming>,
|
||||||
router_service: Arc<RouterService>,
|
router_service: Arc<RouterService>,
|
||||||
|
policy_provider: Option<Arc<PolicyProviderClient>>,
|
||||||
full_qualified_llm_provider_url: String,
|
full_qualified_llm_provider_url: String,
|
||||||
model_aliases: Arc<Option<HashMap<String, ModelAlias>>>,
|
model_aliases: Arc<Option<HashMap<String, ModelAlias>>>,
|
||||||
llm_providers: Arc<RwLock<LlmProviders>>,
|
llm_providers: Arc<RwLock<LlmProviders>>,
|
||||||
|
|
@ -73,6 +76,7 @@ pub async fn llm_chat(
|
||||||
llm_chat_inner(
|
llm_chat_inner(
|
||||||
request,
|
request,
|
||||||
router_service,
|
router_service,
|
||||||
|
policy_provider,
|
||||||
full_qualified_llm_provider_url,
|
full_qualified_llm_provider_url,
|
||||||
model_aliases,
|
model_aliases,
|
||||||
llm_providers,
|
llm_providers,
|
||||||
|
|
@ -90,6 +94,7 @@ pub async fn llm_chat(
|
||||||
async fn llm_chat_inner(
|
async fn llm_chat_inner(
|
||||||
request: Request<hyper::body::Incoming>,
|
request: Request<hyper::body::Incoming>,
|
||||||
router_service: Arc<RouterService>,
|
router_service: Arc<RouterService>,
|
||||||
|
policy_provider: Option<Arc<PolicyProviderClient>>,
|
||||||
full_qualified_llm_provider_url: String,
|
full_qualified_llm_provider_url: String,
|
||||||
model_aliases: Arc<Option<HashMap<String, ModelAlias>>>,
|
model_aliases: Arc<Option<HashMap<String, ModelAlias>>>,
|
||||||
llm_providers: Arc<RwLock<LlmProviders>>,
|
llm_providers: Arc<RwLock<LlmProviders>>,
|
||||||
|
|
@ -134,7 +139,7 @@ async fn llm_chat_inner(
|
||||||
);
|
);
|
||||||
|
|
||||||
// Extract routing_policy from request body if present
|
// Extract routing_policy from request body if present
|
||||||
let (chat_request_bytes, inline_routing_policy) =
|
let (chat_request_bytes, inline_routing_policy, policy_id) =
|
||||||
match crate::handlers::routing_service::extract_routing_policy(&raw_bytes, false) {
|
match crate::handlers::routing_service::extract_routing_policy(&raw_bytes, false) {
|
||||||
Ok(result) => result,
|
Ok(result) => result,
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
|
|
@ -355,6 +360,8 @@ async fn llm_chat_inner(
|
||||||
&request_path,
|
&request_path,
|
||||||
&request_id,
|
&request_id,
|
||||||
inline_routing_policy,
|
inline_routing_policy,
|
||||||
|
policy_id,
|
||||||
|
policy_provider,
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,7 @@ pub mod jsonrpc;
|
||||||
pub mod llm;
|
pub mod llm;
|
||||||
pub mod models;
|
pub mod models;
|
||||||
pub mod pipeline_processor;
|
pub mod pipeline_processor;
|
||||||
|
pub mod policy_provider;
|
||||||
pub mod response_handler;
|
pub mod response_handler;
|
||||||
pub mod router_chat;
|
pub mod router_chat;
|
||||||
pub mod routing_service;
|
pub mod routing_service;
|
||||||
|
|
|
||||||
291
crates/brightstaff/src/handlers/policy_provider.rs
Normal file
291
crates/brightstaff/src/handlers/policy_provider.rs
Normal file
|
|
@ -0,0 +1,291 @@
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
|
use common::configuration::{ModelUsagePreference, RoutingPolicyProvider};
|
||||||
|
use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
|
||||||
|
use serde::Deserialize;
|
||||||
|
use tracing::warn;
|
||||||
|
|
||||||
|
use crate::state::policy_cache::PolicyCache;
|
||||||
|
|
||||||
|
const DEFAULT_POLICY_TTL_SECONDS: u64 = 300;
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct ExternalPolicyResponse {
|
||||||
|
policy_id: String,
|
||||||
|
routing_preferences: Vec<ModelUsagePreference>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub enum PolicyFetchError {
|
||||||
|
Transient(String),
|
||||||
|
Invalid(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PolicyFetchError {
|
||||||
|
pub fn is_transient(&self) -> bool {
|
||||||
|
matches!(self, PolicyFetchError::Transient(_))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn message(&self) -> &str {
|
||||||
|
match self {
|
||||||
|
PolicyFetchError::Transient(msg) | PolicyFetchError::Invalid(msg) => msg,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct PolicyProviderClient {
|
||||||
|
client: reqwest::Client,
|
||||||
|
config: RoutingPolicyProvider,
|
||||||
|
cache: Arc<PolicyCache>,
|
||||||
|
ttl: Duration,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PolicyProviderClient {
|
||||||
|
pub fn new(config: RoutingPolicyProvider, cache: Arc<PolicyCache>) -> Self {
|
||||||
|
let ttl = Duration::from_secs(config.ttl_seconds.unwrap_or(DEFAULT_POLICY_TTL_SECONDS));
|
||||||
|
Self {
|
||||||
|
client: reqwest::Client::new(),
|
||||||
|
config,
|
||||||
|
cache,
|
||||||
|
ttl,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn fetch_policy(
|
||||||
|
&self,
|
||||||
|
policy_id: &str,
|
||||||
|
) -> Result<Vec<ModelUsagePreference>, PolicyFetchError> {
|
||||||
|
if let Some(cached) = self.cache.get_valid(policy_id).await {
|
||||||
|
return Ok(cached);
|
||||||
|
}
|
||||||
|
|
||||||
|
let headers = self.build_headers()?;
|
||||||
|
let response = self
|
||||||
|
.client
|
||||||
|
.get(&self.config.url)
|
||||||
|
.query(&[("policy_id", policy_id)])
|
||||||
|
.headers(headers)
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.map_err(|err| PolicyFetchError::Transient(format!("policy fetch failed: {}", err)))?;
|
||||||
|
|
||||||
|
if !response.status().is_success() {
|
||||||
|
return if response.status().is_server_error() {
|
||||||
|
Err(PolicyFetchError::Transient(format!(
|
||||||
|
"policy provider returned {}",
|
||||||
|
response.status()
|
||||||
|
)))
|
||||||
|
} else {
|
||||||
|
Err(PolicyFetchError::Invalid(format!(
|
||||||
|
"policy provider returned non-success status {}",
|
||||||
|
response.status()
|
||||||
|
)))
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
let payload: ExternalPolicyResponse = response
|
||||||
|
.json()
|
||||||
|
.await
|
||||||
|
.map_err(|err| PolicyFetchError::Invalid(format!("invalid policy payload: {}", err)))?;
|
||||||
|
|
||||||
|
if payload.policy_id != policy_id {
|
||||||
|
return Err(PolicyFetchError::Invalid(format!(
|
||||||
|
"policy_id mismatch in provider response: expected '{}', got '{}'",
|
||||||
|
policy_id, payload.policy_id
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
|
||||||
|
if payload.routing_preferences.is_empty() {
|
||||||
|
warn!(
|
||||||
|
policy_id,
|
||||||
|
"policy provider returned empty routing preferences"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
self.cache
|
||||||
|
.insert(
|
||||||
|
policy_id.to_string(),
|
||||||
|
payload.routing_preferences.clone(),
|
||||||
|
self.ttl,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
Ok(payload.routing_preferences)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build_headers(&self) -> Result<HeaderMap, PolicyFetchError> {
|
||||||
|
let mut headers = HeaderMap::new();
|
||||||
|
if let Some(configured_headers) = &self.config.headers {
|
||||||
|
for (name, value) in configured_headers {
|
||||||
|
let header_name = HeaderName::from_bytes(name.as_bytes()).map_err(|err| {
|
||||||
|
PolicyFetchError::Invalid(format!(
|
||||||
|
"invalid header name '{}' in routing.policy_provider.headers: {}",
|
||||||
|
name, err
|
||||||
|
))
|
||||||
|
})?;
|
||||||
|
let header_value = HeaderValue::from_str(value).map_err(|err| {
|
||||||
|
PolicyFetchError::Invalid(format!(
|
||||||
|
"invalid header value for '{}' in routing.policy_provider.headers: {}",
|
||||||
|
name, err
|
||||||
|
))
|
||||||
|
})?;
|
||||||
|
headers.insert(header_name, header_value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(headers)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
|
use common::configuration::RoutingPolicyProvider;
|
||||||
|
use mockito::{Matcher, Server};
|
||||||
|
|
||||||
|
use crate::handlers::policy_provider::{PolicyFetchError, PolicyProviderClient};
|
||||||
|
use crate::state::policy_cache::PolicyCache;
|
||||||
|
|
||||||
|
fn provider_config(url: String, ttl_seconds: Option<u64>) -> RoutingPolicyProvider {
|
||||||
|
RoutingPolicyProvider {
|
||||||
|
url,
|
||||||
|
headers: None,
|
||||||
|
ttl_seconds,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn fetches_policy_and_populates_cache() {
|
||||||
|
let mut server = Server::new_async().await;
|
||||||
|
let _mock = server
|
||||||
|
.mock("GET", "/v1/routing-policy")
|
||||||
|
.match_query(Matcher::UrlEncoded(
|
||||||
|
"policy_id".to_string(),
|
||||||
|
"customer-abc".to_string(),
|
||||||
|
))
|
||||||
|
.with_status(200)
|
||||||
|
.with_header("content-type", "application/json")
|
||||||
|
.with_body(
|
||||||
|
r#"{
|
||||||
|
"policy_id":"customer-abc",
|
||||||
|
"routing_preferences":[
|
||||||
|
{
|
||||||
|
"model":"openai/gpt-4o",
|
||||||
|
"routing_preferences":[{"name":"quick response","description":"fast"}]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}"#,
|
||||||
|
)
|
||||||
|
.expect(1)
|
||||||
|
.create_async()
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let cache = Arc::new(PolicyCache::new());
|
||||||
|
let client = PolicyProviderClient::new(
|
||||||
|
provider_config(format!("{}/v1/routing-policy", server.url()), Some(300)),
|
||||||
|
cache,
|
||||||
|
);
|
||||||
|
|
||||||
|
let first = client.fetch_policy("customer-abc").await.unwrap();
|
||||||
|
let second = client.fetch_policy("customer-abc").await.unwrap();
|
||||||
|
assert_eq!(first.len(), 1);
|
||||||
|
assert_eq!(second[0].model, "openai/gpt-4o");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn returns_invalid_on_policy_id_mismatch() {
|
||||||
|
let mut server = Server::new_async().await;
|
||||||
|
let _mock = server
|
||||||
|
.mock("GET", "/v1/routing-policy")
|
||||||
|
.match_query(Matcher::Any)
|
||||||
|
.with_status(200)
|
||||||
|
.with_header("content-type", "application/json")
|
||||||
|
.with_body(
|
||||||
|
r#"{
|
||||||
|
"policy_id":"different-id",
|
||||||
|
"routing_preferences":[]
|
||||||
|
}"#,
|
||||||
|
)
|
||||||
|
.create_async()
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let cache = Arc::new(PolicyCache::new());
|
||||||
|
let client = PolicyProviderClient::new(
|
||||||
|
provider_config(format!("{}/v1/routing-policy", server.url()), Some(300)),
|
||||||
|
cache,
|
||||||
|
);
|
||||||
|
|
||||||
|
let err = client.fetch_policy("customer-abc").await.unwrap_err();
|
||||||
|
assert!(matches!(err, PolicyFetchError::Invalid(_)));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn returns_transient_on_server_error() {
|
||||||
|
let mut server = Server::new_async().await;
|
||||||
|
let _mock = server
|
||||||
|
.mock("GET", "/v1/routing-policy")
|
||||||
|
.match_query(Matcher::Any)
|
||||||
|
.with_status(500)
|
||||||
|
.create_async()
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let cache = Arc::new(PolicyCache::new());
|
||||||
|
let client = PolicyProviderClient::new(
|
||||||
|
provider_config(format!("{}/v1/routing-policy", server.url()), Some(300)),
|
||||||
|
cache,
|
||||||
|
);
|
||||||
|
|
||||||
|
let err = client.fetch_policy("customer-abc").await.unwrap_err();
|
||||||
|
assert!(err.is_transient());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn returns_invalid_on_client_error_status() {
|
||||||
|
let mut server = Server::new_async().await;
|
||||||
|
let _mock = server
|
||||||
|
.mock("GET", "/v1/routing-policy")
|
||||||
|
.match_query(Matcher::Any)
|
||||||
|
.with_status(404)
|
||||||
|
.create_async()
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let cache = Arc::new(PolicyCache::new());
|
||||||
|
let client = PolicyProviderClient::new(
|
||||||
|
provider_config(format!("{}/v1/routing-policy", server.url()), Some(300)),
|
||||||
|
cache,
|
||||||
|
);
|
||||||
|
|
||||||
|
let err = client.fetch_policy("customer-abc").await.unwrap_err();
|
||||||
|
assert!(matches!(err, PolicyFetchError::Invalid(_)));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn supports_headers() {
|
||||||
|
let mut server = Server::new_async().await;
|
||||||
|
let _mock = server
|
||||||
|
.mock("GET", "/v1/routing-policy")
|
||||||
|
.match_header("authorization", "Bearer token")
|
||||||
|
.match_query(Matcher::Any)
|
||||||
|
.with_status(200)
|
||||||
|
.with_header("content-type", "application/json")
|
||||||
|
.with_body(r#"{"policy_id":"customer-abc","routing_preferences":[]}"#)
|
||||||
|
.create_async()
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let mut headers = HashMap::new();
|
||||||
|
headers.insert("Authorization".to_string(), "Bearer token".to_string());
|
||||||
|
let cache = Arc::new(PolicyCache::new());
|
||||||
|
let client = PolicyProviderClient::new(
|
||||||
|
RoutingPolicyProvider {
|
||||||
|
url: format!("{}/v1/routing-policy", server.url()),
|
||||||
|
headers: Some(headers),
|
||||||
|
ttl_seconds: Some(Duration::from_secs(300).as_secs()),
|
||||||
|
},
|
||||||
|
cache,
|
||||||
|
);
|
||||||
|
|
||||||
|
let _ = client.fetch_policy("customer-abc").await.unwrap();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -2,9 +2,12 @@ use common::configuration::ModelUsagePreference;
|
||||||
use hermesllm::clients::endpoints::SupportedUpstreamAPIs;
|
use hermesllm::clients::endpoints::SupportedUpstreamAPIs;
|
||||||
use hermesllm::{ProviderRequest, ProviderRequestType};
|
use hermesllm::{ProviderRequest, ProviderRequestType};
|
||||||
use hyper::StatusCode;
|
use hyper::StatusCode;
|
||||||
|
use serde_json::Value;
|
||||||
|
use std::collections::HashMap;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use tracing::{debug, info, warn};
|
use tracing::{debug, info, warn};
|
||||||
|
|
||||||
|
use crate::handlers::policy_provider::PolicyProviderClient;
|
||||||
use crate::router::llm_router::RouterService;
|
use crate::router::llm_router::RouterService;
|
||||||
use crate::tracing::routing;
|
use crate::tracing::routing;
|
||||||
|
|
||||||
|
|
@ -13,6 +16,7 @@ pub struct RoutingResult {
|
||||||
pub route_name: Option<String>,
|
pub route_name: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
pub struct RoutingError {
|
pub struct RoutingError {
|
||||||
pub message: String,
|
pub message: String,
|
||||||
pub status_code: StatusCode,
|
pub status_code: StatusCode,
|
||||||
|
|
@ -25,6 +29,60 @@ impl RoutingError {
|
||||||
status_code: StatusCode::INTERNAL_SERVER_ERROR,
|
status_code: StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn bad_request(message: String) -> Self {
|
||||||
|
Self {
|
||||||
|
message,
|
||||||
|
status_code: StatusCode::BAD_REQUEST,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn resolve_usage_preferences(
|
||||||
|
inline_usage_preferences: Option<Vec<ModelUsagePreference>>,
|
||||||
|
policy_id: Option<&str>,
|
||||||
|
policy_provider: Option<&PolicyProviderClient>,
|
||||||
|
routing_metadata: Option<&HashMap<String, Value>>,
|
||||||
|
) -> Result<Option<Vec<ModelUsagePreference>>, RoutingError> {
|
||||||
|
if let Some(inline_preferences) = inline_usage_preferences {
|
||||||
|
info!("using inline routing_policy from request body");
|
||||||
|
return Ok(Some(inline_preferences));
|
||||||
|
}
|
||||||
|
|
||||||
|
if let (Some(policy_id), Some(policy_provider_client)) = (policy_id, policy_provider) {
|
||||||
|
match policy_provider_client.fetch_policy(policy_id).await {
|
||||||
|
Ok(preferences) => {
|
||||||
|
info!(
|
||||||
|
policy_id,
|
||||||
|
"using routing policy from external policy provider"
|
||||||
|
);
|
||||||
|
return Ok(Some(preferences));
|
||||||
|
}
|
||||||
|
Err(err) if err.is_transient() => {
|
||||||
|
warn!(
|
||||||
|
policy_id,
|
||||||
|
error = %err.message(),
|
||||||
|
"policy provider fetch failed, falling back to metadata/config routing preferences"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
Err(err) => {
|
||||||
|
return Err(RoutingError::bad_request(format!(
|
||||||
|
"Failed to load routing policy for policy_id '{}': {}",
|
||||||
|
policy_id,
|
||||||
|
err.message()
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let usage_preferences_str: Option<String> = routing_metadata.and_then(|metadata| {
|
||||||
|
metadata
|
||||||
|
.get("plano_preference_config")
|
||||||
|
.map(|value| value.to_string())
|
||||||
|
});
|
||||||
|
Ok(usage_preferences_str
|
||||||
|
.as_ref()
|
||||||
|
.and_then(|s| serde_yaml::from_str(s).ok()))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Determines the routing decision if
|
/// Determines the routing decision if
|
||||||
|
|
@ -32,6 +90,7 @@ impl RoutingError {
|
||||||
/// # Returns
|
/// # Returns
|
||||||
/// * `Ok(RoutingResult)` - Contains the selected model name and span ID
|
/// * `Ok(RoutingResult)` - Contains the selected model name and span ID
|
||||||
/// * `Err(RoutingError)` - Contains error details and optional span ID
|
/// * `Err(RoutingError)` - Contains error details and optional span ID
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub async fn router_chat_get_upstream_model(
|
pub async fn router_chat_get_upstream_model(
|
||||||
router_service: Arc<RouterService>,
|
router_service: Arc<RouterService>,
|
||||||
client_request: ProviderRequestType,
|
client_request: ProviderRequestType,
|
||||||
|
|
@ -39,6 +98,8 @@ pub async fn router_chat_get_upstream_model(
|
||||||
request_path: &str,
|
request_path: &str,
|
||||||
request_id: &str,
|
request_id: &str,
|
||||||
inline_usage_preferences: Option<Vec<ModelUsagePreference>>,
|
inline_usage_preferences: Option<Vec<ModelUsagePreference>>,
|
||||||
|
policy_id: Option<String>,
|
||||||
|
policy_provider: Option<Arc<PolicyProviderClient>>,
|
||||||
) -> Result<RoutingResult, RoutingError> {
|
) -> Result<RoutingResult, RoutingError> {
|
||||||
// Clone metadata for routing before converting (which consumes client_request)
|
// Clone metadata for routing before converting (which consumes client_request)
|
||||||
let routing_metadata = client_request.metadata().clone();
|
let routing_metadata = client_request.metadata().clone();
|
||||||
|
|
@ -77,21 +138,13 @@ pub async fn router_chat_get_upstream_model(
|
||||||
"router request"
|
"router request"
|
||||||
);
|
);
|
||||||
|
|
||||||
// Use inline preferences if provided, otherwise fall back to metadata extraction
|
let usage_preferences = resolve_usage_preferences(
|
||||||
let usage_preferences: Option<Vec<ModelUsagePreference>> = if inline_usage_preferences.is_some()
|
inline_usage_preferences,
|
||||||
{
|
policy_id.as_deref(),
|
||||||
inline_usage_preferences
|
policy_provider.as_deref(),
|
||||||
} else {
|
routing_metadata.as_ref(),
|
||||||
let usage_preferences_str: Option<String> =
|
)
|
||||||
routing_metadata.as_ref().and_then(|metadata| {
|
.await?;
|
||||||
metadata
|
|
||||||
.get("plano_preference_config")
|
|
||||||
.map(|value| value.to_string())
|
|
||||||
});
|
|
||||||
usage_preferences_str
|
|
||||||
.as_ref()
|
|
||||||
.and_then(|s| serde_yaml::from_str(s).ok())
|
|
||||||
};
|
|
||||||
|
|
||||||
// Prepare log message with latest message from chat request
|
// Prepare log message with latest message from chat request
|
||||||
let latest_message_for_log = chat_request
|
let latest_message_for_log = chat_request
|
||||||
|
|
@ -168,3 +221,109 @@ pub async fn router_chat_get_upstream_model(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::resolve_usage_preferences;
|
||||||
|
use crate::handlers::policy_provider::PolicyProviderClient;
|
||||||
|
use crate::state::policy_cache::PolicyCache;
|
||||||
|
use common::configuration::{ModelUsagePreference, RoutingPolicyProvider, RoutingPreference};
|
||||||
|
use mockito::{Matcher, Server};
|
||||||
|
use serde_json::json;
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
fn inline_policy(name: &str) -> Vec<ModelUsagePreference> {
|
||||||
|
vec![ModelUsagePreference {
|
||||||
|
model: "openai/gpt-4o".to_string(),
|
||||||
|
routing_preferences: vec![RoutingPreference {
|
||||||
|
name: name.to_string(),
|
||||||
|
description: "desc".to_string(),
|
||||||
|
}],
|
||||||
|
}]
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn resolve_usage_preferences_prioritizes_inline_policy() {
|
||||||
|
let inline = inline_policy("inline");
|
||||||
|
let mut metadata = HashMap::new();
|
||||||
|
metadata.insert(
|
||||||
|
"plano_preference_config".to_string(),
|
||||||
|
json!(
|
||||||
|
[{"model":"openai/gpt-4o-mini","routing_preferences":[{"name":"metadata","description":"desc"}]}]
|
||||||
|
),
|
||||||
|
);
|
||||||
|
|
||||||
|
let result = resolve_usage_preferences(
|
||||||
|
Some(inline.clone()),
|
||||||
|
Some("policy-a"),
|
||||||
|
None,
|
||||||
|
Some(&metadata),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(result.unwrap()[0].routing_preferences[0].name, "inline");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn resolve_usage_preferences_falls_back_to_metadata_on_transient_policy_error() {
|
||||||
|
let mut server = Server::new_async().await;
|
||||||
|
let _mock = server
|
||||||
|
.mock("GET", "/policy")
|
||||||
|
.match_query(Matcher::Any)
|
||||||
|
.with_status(500)
|
||||||
|
.create_async()
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let provider = PolicyProviderClient::new(
|
||||||
|
RoutingPolicyProvider {
|
||||||
|
url: format!("{}/policy", server.url()),
|
||||||
|
headers: None,
|
||||||
|
ttl_seconds: Some(60),
|
||||||
|
},
|
||||||
|
Arc::new(PolicyCache::new()),
|
||||||
|
);
|
||||||
|
let mut metadata = HashMap::new();
|
||||||
|
metadata.insert(
|
||||||
|
"plano_preference_config".to_string(),
|
||||||
|
json!(
|
||||||
|
[{"model":"openai/gpt-4o-mini","routing_preferences":[{"name":"metadata","description":"desc"}]}]
|
||||||
|
),
|
||||||
|
);
|
||||||
|
|
||||||
|
let result =
|
||||||
|
resolve_usage_preferences(None, Some("customer-a"), Some(&provider), Some(&metadata))
|
||||||
|
.await
|
||||||
|
.unwrap()
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert_eq!(result[0].routing_preferences[0].name, "metadata");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn resolve_usage_preferences_returns_bad_request_on_policy_mismatch() {
|
||||||
|
let mut server = Server::new_async().await;
|
||||||
|
let _mock = server
|
||||||
|
.mock("GET", "/policy")
|
||||||
|
.match_query(Matcher::Any)
|
||||||
|
.with_status(200)
|
||||||
|
.with_header("content-type", "application/json")
|
||||||
|
.with_body(r#"{"policy_id":"different","routing_preferences":[]}"#)
|
||||||
|
.create_async()
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let provider = PolicyProviderClient::new(
|
||||||
|
RoutingPolicyProvider {
|
||||||
|
url: format!("{}/policy", server.url()),
|
||||||
|
headers: None,
|
||||||
|
ttl_seconds: Some(60),
|
||||||
|
},
|
||||||
|
Arc::new(PolicyCache::new()),
|
||||||
|
);
|
||||||
|
|
||||||
|
let err = resolve_usage_preferences(None, Some("expected"), Some(&provider), None)
|
||||||
|
.await
|
||||||
|
.unwrap_err();
|
||||||
|
assert_eq!(err.status_code, hyper::StatusCode::BAD_REQUEST);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -10,11 +10,13 @@ use hyper::{Request, Response, StatusCode};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use tracing::{debug, info, info_span, warn, Instrument};
|
use tracing::{debug, info, info_span, warn, Instrument};
|
||||||
|
|
||||||
|
use crate::handlers::policy_provider::PolicyProviderClient;
|
||||||
use crate::handlers::router_chat::router_chat_get_upstream_model;
|
use crate::handlers::router_chat::router_chat_get_upstream_model;
|
||||||
use crate::router::llm_router::RouterService;
|
use crate::router::llm_router::RouterService;
|
||||||
use crate::tracing::{collect_custom_trace_attributes, operation_component, set_service_name};
|
use crate::tracing::{collect_custom_trace_attributes, operation_component, set_service_name};
|
||||||
|
|
||||||
const ROUTING_POLICY_SIZE_WARNING_BYTES: usize = 5120;
|
const ROUTING_POLICY_SIZE_WARNING_BYTES: usize = 5120;
|
||||||
|
type ExtractRoutingPolicyResult = (Bytes, Option<Vec<ModelUsagePreference>>, Option<String>);
|
||||||
|
|
||||||
/// Extracts `routing_policy` from a JSON body, returning the cleaned body bytes
|
/// Extracts `routing_policy` from a JSON body, returning the cleaned body bytes
|
||||||
/// and parsed preferences. The `routing_policy` field is removed from the JSON
|
/// and parsed preferences. The `routing_policy` field is removed from the JSON
|
||||||
|
|
@ -24,10 +26,20 @@ const ROUTING_POLICY_SIZE_WARNING_BYTES: usize = 5120;
|
||||||
pub fn extract_routing_policy(
|
pub fn extract_routing_policy(
|
||||||
raw_bytes: &[u8],
|
raw_bytes: &[u8],
|
||||||
warn_on_size: bool,
|
warn_on_size: bool,
|
||||||
) -> Result<(Bytes, Option<Vec<ModelUsagePreference>>), String> {
|
) -> Result<ExtractRoutingPolicyResult, String> {
|
||||||
let mut json_body: serde_json::Value = serde_json::from_slice(raw_bytes)
|
let mut json_body: serde_json::Value = serde_json::from_slice(raw_bytes)
|
||||||
.map_err(|err| format!("Failed to parse JSON: {}", err))?;
|
.map_err(|err| format!("Failed to parse JSON: {}", err))?;
|
||||||
|
|
||||||
|
let policy_id = json_body
|
||||||
|
.as_object_mut()
|
||||||
|
.and_then(|obj| obj.remove("policy_id"))
|
||||||
|
.map(|policy_id_value| match policy_id_value {
|
||||||
|
serde_json::Value::String(policy_id) if !policy_id.trim().is_empty() => Ok(policy_id),
|
||||||
|
serde_json::Value::String(_) => Err("policy_id cannot be empty".to_string()),
|
||||||
|
_ => Err("policy_id must be a string".to_string()),
|
||||||
|
})
|
||||||
|
.transpose()?;
|
||||||
|
|
||||||
let preferences = json_body
|
let preferences = json_body
|
||||||
.as_object_mut()
|
.as_object_mut()
|
||||||
.and_then(|obj| obj.remove("routing_policy"))
|
.and_then(|obj| obj.remove("routing_policy"))
|
||||||
|
|
@ -58,7 +70,7 @@ pub fn extract_routing_policy(
|
||||||
});
|
});
|
||||||
|
|
||||||
let bytes = Bytes::from(serde_json::to_vec(&json_body).unwrap());
|
let bytes = Bytes::from(serde_json::to_vec(&json_body).unwrap());
|
||||||
Ok((bytes, preferences))
|
Ok((bytes, preferences, policy_id))
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(serde::Serialize)]
|
#[derive(serde::Serialize)]
|
||||||
|
|
@ -71,6 +83,7 @@ struct RoutingDecisionResponse {
|
||||||
pub async fn routing_decision(
|
pub async fn routing_decision(
|
||||||
request: Request<hyper::body::Incoming>,
|
request: Request<hyper::body::Incoming>,
|
||||||
router_service: Arc<RouterService>,
|
router_service: Arc<RouterService>,
|
||||||
|
policy_provider: Option<Arc<PolicyProviderClient>>,
|
||||||
request_path: String,
|
request_path: String,
|
||||||
span_attributes: Arc<Option<SpanAttributes>>,
|
span_attributes: Arc<Option<SpanAttributes>>,
|
||||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||||
|
|
@ -95,6 +108,7 @@ pub async fn routing_decision(
|
||||||
routing_decision_inner(
|
routing_decision_inner(
|
||||||
request,
|
request,
|
||||||
router_service,
|
router_service,
|
||||||
|
policy_provider,
|
||||||
request_id,
|
request_id,
|
||||||
request_path,
|
request_path,
|
||||||
request_headers,
|
request_headers,
|
||||||
|
|
@ -107,6 +121,7 @@ pub async fn routing_decision(
|
||||||
async fn routing_decision_inner(
|
async fn routing_decision_inner(
|
||||||
request: Request<hyper::body::Incoming>,
|
request: Request<hyper::body::Incoming>,
|
||||||
router_service: Arc<RouterService>,
|
router_service: Arc<RouterService>,
|
||||||
|
policy_provider: Option<Arc<PolicyProviderClient>>,
|
||||||
request_id: String,
|
request_id: String,
|
||||||
request_path: String,
|
request_path: String,
|
||||||
request_headers: hyper::HeaderMap,
|
request_headers: hyper::HeaderMap,
|
||||||
|
|
@ -153,17 +168,18 @@ async fn routing_decision_inner(
|
||||||
);
|
);
|
||||||
|
|
||||||
// Extract routing_policy from request body before parsing as ProviderRequestType
|
// Extract routing_policy from request body before parsing as ProviderRequestType
|
||||||
let (chat_request_bytes, inline_preferences) = match extract_routing_policy(&raw_bytes, true) {
|
let (chat_request_bytes, inline_preferences, policy_id) =
|
||||||
Ok(result) => result,
|
match extract_routing_policy(&raw_bytes, true) {
|
||||||
Err(err) => {
|
Ok(result) => result,
|
||||||
warn!(error = %err, "failed to parse request JSON");
|
Err(err) => {
|
||||||
return Ok(BrightStaffError::InvalidRequest(format!(
|
warn!(error = %err, "failed to parse request JSON");
|
||||||
"Failed to parse request JSON: {}",
|
return Ok(BrightStaffError::InvalidRequest(format!(
|
||||||
err
|
"Failed to parse request JSON: {}",
|
||||||
))
|
err
|
||||||
.into_response());
|
))
|
||||||
}
|
.into_response());
|
||||||
};
|
}
|
||||||
|
};
|
||||||
|
|
||||||
let client_request = match ProviderRequestType::try_from((
|
let client_request = match ProviderRequestType::try_from((
|
||||||
&chat_request_bytes[..],
|
&chat_request_bytes[..],
|
||||||
|
|
@ -188,6 +204,8 @@ async fn routing_decision_inner(
|
||||||
&request_path,
|
&request_path,
|
||||||
&request_id,
|
&request_id,
|
||||||
inline_preferences,
|
inline_preferences,
|
||||||
|
policy_id,
|
||||||
|
policy_provider,
|
||||||
)
|
)
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
|
|
@ -218,7 +236,11 @@ async fn routing_decision_inner(
|
||||||
}
|
}
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
warn!(error = %err.message, "routing decision failed");
|
warn!(error = %err.message, "routing decision failed");
|
||||||
Ok(BrightStaffError::InternalServerError(err.message).into_response())
|
Ok(BrightStaffError::ForwardedError {
|
||||||
|
status_code: err.status_code,
|
||||||
|
message: err.message,
|
||||||
|
}
|
||||||
|
.into_response())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -243,9 +265,10 @@ mod tests {
|
||||||
#[test]
|
#[test]
|
||||||
fn extract_routing_policy_no_policy() {
|
fn extract_routing_policy_no_policy() {
|
||||||
let body = make_chat_body("");
|
let body = make_chat_body("");
|
||||||
let (cleaned, prefs) = extract_routing_policy(&body, false).unwrap();
|
let (cleaned, prefs, policy_id) = extract_routing_policy(&body, false).unwrap();
|
||||||
|
|
||||||
assert!(prefs.is_none());
|
assert!(prefs.is_none());
|
||||||
|
assert!(policy_id.is_none());
|
||||||
let cleaned_json: serde_json::Value = serde_json::from_slice(&cleaned).unwrap();
|
let cleaned_json: serde_json::Value = serde_json::from_slice(&cleaned).unwrap();
|
||||||
assert_eq!(cleaned_json["model"], "gpt-4o-mini");
|
assert_eq!(cleaned_json["model"], "gpt-4o-mini");
|
||||||
assert!(cleaned_json.get("routing_policy").is_none());
|
assert!(cleaned_json.get("routing_policy").is_none());
|
||||||
|
|
@ -268,7 +291,7 @@ mod tests {
|
||||||
}
|
}
|
||||||
]"#;
|
]"#;
|
||||||
let body = make_chat_body(policy);
|
let body = make_chat_body(policy);
|
||||||
let (cleaned, prefs) = extract_routing_policy(&body, false).unwrap();
|
let (cleaned, prefs, policy_id) = extract_routing_policy(&body, false).unwrap();
|
||||||
|
|
||||||
let prefs = prefs.expect("should have parsed preferences");
|
let prefs = prefs.expect("should have parsed preferences");
|
||||||
assert_eq!(prefs.len(), 2);
|
assert_eq!(prefs.len(), 2);
|
||||||
|
|
@ -280,6 +303,7 @@ mod tests {
|
||||||
// routing_policy should be stripped from cleaned body
|
// routing_policy should be stripped from cleaned body
|
||||||
let cleaned_json: serde_json::Value = serde_json::from_slice(&cleaned).unwrap();
|
let cleaned_json: serde_json::Value = serde_json::from_slice(&cleaned).unwrap();
|
||||||
assert!(cleaned_json.get("routing_policy").is_none());
|
assert!(cleaned_json.get("routing_policy").is_none());
|
||||||
|
assert!(policy_id.is_none());
|
||||||
assert_eq!(cleaned_json["model"], "gpt-4o-mini");
|
assert_eq!(cleaned_json["model"], "gpt-4o-mini");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -288,13 +312,14 @@ mod tests {
|
||||||
// routing_policy is present but has wrong shape
|
// routing_policy is present but has wrong shape
|
||||||
let policy = r#""routing_policy": "not-an-array""#;
|
let policy = r#""routing_policy": "not-an-array""#;
|
||||||
let body = make_chat_body(policy);
|
let body = make_chat_body(policy);
|
||||||
let (cleaned, prefs) = extract_routing_policy(&body, false).unwrap();
|
let (cleaned, prefs, policy_id) = extract_routing_policy(&body, false).unwrap();
|
||||||
|
|
||||||
// Invalid policy should be ignored (returns None), not error
|
// Invalid policy should be ignored (returns None), not error
|
||||||
assert!(prefs.is_none());
|
assert!(prefs.is_none());
|
||||||
// routing_policy should still be stripped from cleaned body
|
// routing_policy should still be stripped from cleaned body
|
||||||
let cleaned_json: serde_json::Value = serde_json::from_slice(&cleaned).unwrap();
|
let cleaned_json: serde_json::Value = serde_json::from_slice(&cleaned).unwrap();
|
||||||
assert!(cleaned_json.get("routing_policy").is_none());
|
assert!(cleaned_json.get("routing_policy").is_none());
|
||||||
|
assert!(policy_id.is_none());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
@ -309,23 +334,44 @@ mod tests {
|
||||||
fn extract_routing_policy_empty_array() {
|
fn extract_routing_policy_empty_array() {
|
||||||
let policy = r#""routing_policy": []"#;
|
let policy = r#""routing_policy": []"#;
|
||||||
let body = make_chat_body(policy);
|
let body = make_chat_body(policy);
|
||||||
let (_, prefs) = extract_routing_policy(&body, false).unwrap();
|
let (_, prefs, policy_id) = extract_routing_policy(&body, false).unwrap();
|
||||||
|
|
||||||
let prefs = prefs.expect("empty array is valid");
|
let prefs = prefs.expect("empty array is valid");
|
||||||
assert_eq!(prefs.len(), 0);
|
assert_eq!(prefs.len(), 0);
|
||||||
|
assert!(policy_id.is_none());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn extract_routing_policy_preserves_other_fields() {
|
fn extract_routing_policy_preserves_other_fields() {
|
||||||
let policy = r#""routing_policy": [{"model": "gpt-4o", "routing_preferences": [{"name": "test", "description": "test"}]}], "temperature": 0.5, "max_tokens": 100"#;
|
let policy = r#""routing_policy": [{"model": "gpt-4o", "routing_preferences": [{"name": "test", "description": "test"}]}], "temperature": 0.5, "max_tokens": 100"#;
|
||||||
let body = make_chat_body(policy);
|
let body = make_chat_body(policy);
|
||||||
let (cleaned, prefs) = extract_routing_policy(&body, false).unwrap();
|
let (cleaned, prefs, policy_id) = extract_routing_policy(&body, false).unwrap();
|
||||||
|
|
||||||
assert!(prefs.is_some());
|
assert!(prefs.is_some());
|
||||||
let cleaned_json: serde_json::Value = serde_json::from_slice(&cleaned).unwrap();
|
let cleaned_json: serde_json::Value = serde_json::from_slice(&cleaned).unwrap();
|
||||||
assert_eq!(cleaned_json["temperature"], 0.5);
|
assert_eq!(cleaned_json["temperature"], 0.5);
|
||||||
assert_eq!(cleaned_json["max_tokens"], 100);
|
assert_eq!(cleaned_json["max_tokens"], 100);
|
||||||
assert!(cleaned_json.get("routing_policy").is_none());
|
assert!(cleaned_json.get("routing_policy").is_none());
|
||||||
|
assert!(policy_id.is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn extract_routing_policy_extracts_and_strips_policy_id() {
|
||||||
|
let body = make_chat_body(r#""policy_id": "customer-abc-123""#);
|
||||||
|
let (cleaned, prefs, policy_id) = extract_routing_policy(&body, false).unwrap();
|
||||||
|
|
||||||
|
assert!(prefs.is_none());
|
||||||
|
assert_eq!(policy_id, Some("customer-abc-123".to_string()));
|
||||||
|
let cleaned_json: serde_json::Value = serde_json::from_slice(&cleaned).unwrap();
|
||||||
|
assert!(cleaned_json.get("policy_id").is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn extract_routing_policy_rejects_non_string_policy_id() {
|
||||||
|
let body = make_chat_body(r#""policy_id": 123"#);
|
||||||
|
let result = extract_routing_policy(&body, false);
|
||||||
|
assert!(result.is_err());
|
||||||
|
assert!(result.unwrap_err().contains("policy_id must be a string"));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
|
||||||
|
|
@ -2,10 +2,12 @@ use brightstaff::handlers::agent_chat_completions::agent_chat;
|
||||||
use brightstaff::handlers::function_calling::function_calling_chat_handler;
|
use brightstaff::handlers::function_calling::function_calling_chat_handler;
|
||||||
use brightstaff::handlers::llm::llm_chat;
|
use brightstaff::handlers::llm::llm_chat;
|
||||||
use brightstaff::handlers::models::list_models;
|
use brightstaff::handlers::models::list_models;
|
||||||
|
use brightstaff::handlers::policy_provider::PolicyProviderClient;
|
||||||
use brightstaff::handlers::routing_service::routing_decision;
|
use brightstaff::handlers::routing_service::routing_decision;
|
||||||
use brightstaff::router::llm_router::RouterService;
|
use brightstaff::router::llm_router::RouterService;
|
||||||
use brightstaff::router::plano_orchestrator::OrchestratorService;
|
use brightstaff::router::plano_orchestrator::OrchestratorService;
|
||||||
use brightstaff::state::memory::MemoryConversationalStorage;
|
use brightstaff::state::memory::MemoryConversationalStorage;
|
||||||
|
use brightstaff::state::policy_cache::PolicyCache;
|
||||||
use brightstaff::state::postgresql::PostgreSQLConversationStorage;
|
use brightstaff::state::postgresql::PostgreSQLConversationStorage;
|
||||||
use brightstaff::state::StateStorage;
|
use brightstaff::state::StateStorage;
|
||||||
use brightstaff::utils::tracing::init_tracer;
|
use brightstaff::utils::tracing::init_tracer;
|
||||||
|
|
@ -108,6 +110,16 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||||
routing_model_name,
|
routing_model_name,
|
||||||
routing_llm_provider,
|
routing_llm_provider,
|
||||||
));
|
));
|
||||||
|
let policy_provider: Option<Arc<PolicyProviderClient>> = plano_config
|
||||||
|
.routing
|
||||||
|
.as_ref()
|
||||||
|
.and_then(|routing| routing.policy_provider.clone())
|
||||||
|
.map(|policy_provider_config| {
|
||||||
|
Arc::new(PolicyProviderClient::new(
|
||||||
|
policy_provider_config,
|
||||||
|
Arc::new(PolicyCache::new()),
|
||||||
|
))
|
||||||
|
});
|
||||||
|
|
||||||
let orchestrator_service: Arc<OrchestratorService> = Arc::new(OrchestratorService::new(
|
let orchestrator_service: Arc<OrchestratorService> = Arc::new(OrchestratorService::new(
|
||||||
format!("{llm_provider_url}{CHAT_COMPLETIONS_PATH}"),
|
format!("{llm_provider_url}{CHAT_COMPLETIONS_PATH}"),
|
||||||
|
|
@ -172,6 +184,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||||
|
|
||||||
let router_service: Arc<RouterService> = Arc::clone(&router_service);
|
let router_service: Arc<RouterService> = Arc::clone(&router_service);
|
||||||
let orchestrator_service: Arc<OrchestratorService> = Arc::clone(&orchestrator_service);
|
let orchestrator_service: Arc<OrchestratorService> = Arc::clone(&orchestrator_service);
|
||||||
|
let policy_provider = policy_provider.clone();
|
||||||
let model_aliases: Arc<
|
let model_aliases: Arc<
|
||||||
Option<std::collections::HashMap<String, common::configuration::ModelAlias>>,
|
Option<std::collections::HashMap<String, common::configuration::ModelAlias>>,
|
||||||
> = Arc::clone(&model_aliases);
|
> = Arc::clone(&model_aliases);
|
||||||
|
|
@ -185,6 +198,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||||
let service = service_fn(move |req| {
|
let service = service_fn(move |req| {
|
||||||
let router_service = Arc::clone(&router_service);
|
let router_service = Arc::clone(&router_service);
|
||||||
let orchestrator_service = Arc::clone(&orchestrator_service);
|
let orchestrator_service = Arc::clone(&orchestrator_service);
|
||||||
|
let policy_provider = policy_provider.clone();
|
||||||
let parent_cx = extract_context_from_request(&req);
|
let parent_cx = extract_context_from_request(&req);
|
||||||
let llm_provider_url = llm_provider_url.clone();
|
let llm_provider_url = llm_provider_url.clone();
|
||||||
let llm_providers = llm_providers.clone();
|
let llm_providers = llm_providers.clone();
|
||||||
|
|
@ -227,6 +241,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||||
return routing_decision(
|
return routing_decision(
|
||||||
req,
|
req,
|
||||||
router_service,
|
router_service,
|
||||||
|
policy_provider,
|
||||||
stripped_path,
|
stripped_path,
|
||||||
span_attributes,
|
span_attributes,
|
||||||
)
|
)
|
||||||
|
|
@ -243,6 +258,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||||
llm_chat(
|
llm_chat(
|
||||||
req,
|
req,
|
||||||
router_service,
|
router_service,
|
||||||
|
policy_provider,
|
||||||
fully_qualified_url,
|
fully_qualified_url,
|
||||||
model_aliases,
|
model_aliases,
|
||||||
llm_providers,
|
llm_providers,
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,7 @@ use std::sync::Arc;
|
||||||
use tracing::debug;
|
use tracing::debug;
|
||||||
|
|
||||||
pub mod memory;
|
pub mod memory;
|
||||||
|
pub mod policy_cache;
|
||||||
pub mod postgresql;
|
pub mod postgresql;
|
||||||
pub mod response_state_processor;
|
pub mod response_state_processor;
|
||||||
|
|
||||||
|
|
|
||||||
108
crates/brightstaff/src/state/policy_cache.rs
Normal file
108
crates/brightstaff/src/state/policy_cache.rs
Normal file
|
|
@ -0,0 +1,108 @@
|
||||||
|
use common::configuration::ModelUsagePreference;
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::time::{Duration, Instant};
|
||||||
|
use tokio::sync::RwLock;
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
struct CachedPolicy {
|
||||||
|
preferences: Vec<ModelUsagePreference>,
|
||||||
|
expires_at: Instant,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct PolicyCache {
|
||||||
|
entries: RwLock<HashMap<String, CachedPolicy>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for PolicyCache {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PolicyCache {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
entries: RwLock::new(HashMap::new()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn get_valid(&self, policy_id: &str) -> Option<Vec<ModelUsagePreference>> {
|
||||||
|
let now = Instant::now();
|
||||||
|
let cached = {
|
||||||
|
let entries = self.entries.read().await;
|
||||||
|
entries.get(policy_id).cloned()
|
||||||
|
};
|
||||||
|
|
||||||
|
let cached = cached?;
|
||||||
|
if cached.expires_at > now {
|
||||||
|
return Some(cached.preferences);
|
||||||
|
}
|
||||||
|
|
||||||
|
self.entries.write().await.remove(policy_id);
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn insert(
|
||||||
|
&self,
|
||||||
|
policy_id: String,
|
||||||
|
preferences: Vec<ModelUsagePreference>,
|
||||||
|
ttl: Duration,
|
||||||
|
) {
|
||||||
|
let expires_at = Instant::now() + ttl;
|
||||||
|
self.entries.write().await.insert(
|
||||||
|
policy_id,
|
||||||
|
CachedPolicy {
|
||||||
|
preferences,
|
||||||
|
expires_at,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::PolicyCache;
|
||||||
|
use common::configuration::{ModelUsagePreference, RoutingPreference};
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
|
fn sample_preferences() -> Vec<ModelUsagePreference> {
|
||||||
|
vec![ModelUsagePreference {
|
||||||
|
model: "openai/gpt-4o".to_string(),
|
||||||
|
routing_preferences: vec![RoutingPreference {
|
||||||
|
name: "quick response".to_string(),
|
||||||
|
description: "fast lightweight responses".to_string(),
|
||||||
|
}],
|
||||||
|
}]
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn returns_cached_policy_before_expiry() {
|
||||||
|
let cache = PolicyCache::new();
|
||||||
|
cache
|
||||||
|
.insert(
|
||||||
|
"customer-a".to_string(),
|
||||||
|
sample_preferences(),
|
||||||
|
Duration::from_secs(10),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let cached = cache.get_valid("customer-a").await;
|
||||||
|
assert!(cached.is_some());
|
||||||
|
assert_eq!(cached.unwrap()[0].model, "openai/gpt-4o");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn expires_cached_policy_after_ttl() {
|
||||||
|
let cache = PolicyCache::new();
|
||||||
|
cache
|
||||||
|
.insert(
|
||||||
|
"customer-a".to_string(),
|
||||||
|
sample_preferences(),
|
||||||
|
Duration::from_millis(5),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
tokio::time::sleep(Duration::from_millis(10)).await;
|
||||||
|
assert!(cache.get_valid("customer-a").await.is_none());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -9,8 +9,17 @@ use crate::api::open_ai::{
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
pub struct Routing {
|
pub struct Routing {
|
||||||
|
#[serde(alias = "llm_provider")]
|
||||||
pub model_provider: Option<String>,
|
pub model_provider: Option<String>,
|
||||||
pub model: Option<String>,
|
pub model: Option<String>,
|
||||||
|
pub policy_provider: Option<RoutingPolicyProvider>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct RoutingPolicyProvider {
|
||||||
|
pub url: String,
|
||||||
|
pub headers: Option<HashMap<String, String>>,
|
||||||
|
pub ttl_seconds: Option<u64>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
|
@ -270,7 +279,7 @@ impl LlmProviderType {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
pub struct ModelUsagePreference {
|
pub struct ModelUsagePreference {
|
||||||
pub model: String,
|
pub model: String,
|
||||||
pub routing_preferences: Vec<RoutingPreference>,
|
pub routing_preferences: Vec<RoutingPreference>,
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue