mirror of
https://github.com/katanemo/plano.git
synced 2026-04-25 00:36:34 +02:00
feat: add policy provider integration and caching mechanism
This commit is contained in:
parent
6610097659
commit
5aeb69e034
9 changed files with 674 additions and 36 deletions
|
|
@ -19,6 +19,7 @@ use std::sync::Arc;
|
|||
use tokio::sync::RwLock;
|
||||
use tracing::{debug, info, info_span, warn, Instrument};
|
||||
|
||||
use crate::handlers::policy_provider::PolicyProviderClient;
|
||||
use crate::handlers::router_chat::router_chat_get_upstream_model;
|
||||
use crate::handlers::utils::{
|
||||
create_streaming_response, truncate_message, ObservableStreamProcessor,
|
||||
|
|
@ -34,9 +35,11 @@ use crate::tracing::{
|
|||
|
||||
use common::errors::BrightStaffError;
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub async fn llm_chat(
|
||||
request: Request<hyper::body::Incoming>,
|
||||
router_service: Arc<RouterService>,
|
||||
policy_provider: Option<Arc<PolicyProviderClient>>,
|
||||
full_qualified_llm_provider_url: String,
|
||||
model_aliases: Arc<Option<HashMap<String, ModelAlias>>>,
|
||||
llm_providers: Arc<RwLock<LlmProviders>>,
|
||||
|
|
@ -73,6 +76,7 @@ pub async fn llm_chat(
|
|||
llm_chat_inner(
|
||||
request,
|
||||
router_service,
|
||||
policy_provider,
|
||||
full_qualified_llm_provider_url,
|
||||
model_aliases,
|
||||
llm_providers,
|
||||
|
|
@ -90,6 +94,7 @@ pub async fn llm_chat(
|
|||
async fn llm_chat_inner(
|
||||
request: Request<hyper::body::Incoming>,
|
||||
router_service: Arc<RouterService>,
|
||||
policy_provider: Option<Arc<PolicyProviderClient>>,
|
||||
full_qualified_llm_provider_url: String,
|
||||
model_aliases: Arc<Option<HashMap<String, ModelAlias>>>,
|
||||
llm_providers: Arc<RwLock<LlmProviders>>,
|
||||
|
|
@ -134,7 +139,7 @@ async fn llm_chat_inner(
|
|||
);
|
||||
|
||||
// Extract routing_policy from request body if present
|
||||
let (chat_request_bytes, inline_routing_policy) =
|
||||
let (chat_request_bytes, inline_routing_policy, policy_id) =
|
||||
match crate::handlers::routing_service::extract_routing_policy(&raw_bytes, false) {
|
||||
Ok(result) => result,
|
||||
Err(err) => {
|
||||
|
|
@ -355,6 +360,8 @@ async fn llm_chat_inner(
|
|||
&request_path,
|
||||
&request_id,
|
||||
inline_routing_policy,
|
||||
policy_id,
|
||||
policy_provider,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ pub mod jsonrpc;
|
|||
pub mod llm;
|
||||
pub mod models;
|
||||
pub mod pipeline_processor;
|
||||
pub mod policy_provider;
|
||||
pub mod response_handler;
|
||||
pub mod router_chat;
|
||||
pub mod routing_service;
|
||||
|
|
|
|||
291
crates/brightstaff/src/handlers/policy_provider.rs
Normal file
291
crates/brightstaff/src/handlers/policy_provider.rs
Normal file
|
|
@ -0,0 +1,291 @@
|
|||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use common::configuration::{ModelUsagePreference, RoutingPolicyProvider};
|
||||
use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
|
||||
use serde::Deserialize;
|
||||
use tracing::warn;
|
||||
|
||||
use crate::state::policy_cache::PolicyCache;
|
||||
|
||||
const DEFAULT_POLICY_TTL_SECONDS: u64 = 300;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ExternalPolicyResponse {
|
||||
policy_id: String,
|
||||
routing_preferences: Vec<ModelUsagePreference>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum PolicyFetchError {
|
||||
Transient(String),
|
||||
Invalid(String),
|
||||
}
|
||||
|
||||
impl PolicyFetchError {
|
||||
pub fn is_transient(&self) -> bool {
|
||||
matches!(self, PolicyFetchError::Transient(_))
|
||||
}
|
||||
|
||||
pub fn message(&self) -> &str {
|
||||
match self {
|
||||
PolicyFetchError::Transient(msg) | PolicyFetchError::Invalid(msg) => msg,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PolicyProviderClient {
|
||||
client: reqwest::Client,
|
||||
config: RoutingPolicyProvider,
|
||||
cache: Arc<PolicyCache>,
|
||||
ttl: Duration,
|
||||
}
|
||||
|
||||
impl PolicyProviderClient {
|
||||
pub fn new(config: RoutingPolicyProvider, cache: Arc<PolicyCache>) -> Self {
|
||||
let ttl = Duration::from_secs(config.ttl_seconds.unwrap_or(DEFAULT_POLICY_TTL_SECONDS));
|
||||
Self {
|
||||
client: reqwest::Client::new(),
|
||||
config,
|
||||
cache,
|
||||
ttl,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn fetch_policy(
|
||||
&self,
|
||||
policy_id: &str,
|
||||
) -> Result<Vec<ModelUsagePreference>, PolicyFetchError> {
|
||||
if let Some(cached) = self.cache.get_valid(policy_id).await {
|
||||
return Ok(cached);
|
||||
}
|
||||
|
||||
let headers = self.build_headers()?;
|
||||
let response = self
|
||||
.client
|
||||
.get(&self.config.url)
|
||||
.query(&[("policy_id", policy_id)])
|
||||
.headers(headers)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|err| PolicyFetchError::Transient(format!("policy fetch failed: {}", err)))?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
return if response.status().is_server_error() {
|
||||
Err(PolicyFetchError::Transient(format!(
|
||||
"policy provider returned {}",
|
||||
response.status()
|
||||
)))
|
||||
} else {
|
||||
Err(PolicyFetchError::Invalid(format!(
|
||||
"policy provider returned non-success status {}",
|
||||
response.status()
|
||||
)))
|
||||
};
|
||||
}
|
||||
|
||||
let payload: ExternalPolicyResponse = response
|
||||
.json()
|
||||
.await
|
||||
.map_err(|err| PolicyFetchError::Invalid(format!("invalid policy payload: {}", err)))?;
|
||||
|
||||
if payload.policy_id != policy_id {
|
||||
return Err(PolicyFetchError::Invalid(format!(
|
||||
"policy_id mismatch in provider response: expected '{}', got '{}'",
|
||||
policy_id, payload.policy_id
|
||||
)));
|
||||
}
|
||||
|
||||
if payload.routing_preferences.is_empty() {
|
||||
warn!(
|
||||
policy_id,
|
||||
"policy provider returned empty routing preferences"
|
||||
);
|
||||
}
|
||||
|
||||
self.cache
|
||||
.insert(
|
||||
policy_id.to_string(),
|
||||
payload.routing_preferences.clone(),
|
||||
self.ttl,
|
||||
)
|
||||
.await;
|
||||
Ok(payload.routing_preferences)
|
||||
}
|
||||
|
||||
fn build_headers(&self) -> Result<HeaderMap, PolicyFetchError> {
|
||||
let mut headers = HeaderMap::new();
|
||||
if let Some(configured_headers) = &self.config.headers {
|
||||
for (name, value) in configured_headers {
|
||||
let header_name = HeaderName::from_bytes(name.as_bytes()).map_err(|err| {
|
||||
PolicyFetchError::Invalid(format!(
|
||||
"invalid header name '{}' in routing.policy_provider.headers: {}",
|
||||
name, err
|
||||
))
|
||||
})?;
|
||||
let header_value = HeaderValue::from_str(value).map_err(|err| {
|
||||
PolicyFetchError::Invalid(format!(
|
||||
"invalid header value for '{}' in routing.policy_provider.headers: {}",
|
||||
name, err
|
||||
))
|
||||
})?;
|
||||
headers.insert(header_name, header_value);
|
||||
}
|
||||
}
|
||||
Ok(headers)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use common::configuration::RoutingPolicyProvider;
|
||||
use mockito::{Matcher, Server};
|
||||
|
||||
use crate::handlers::policy_provider::{PolicyFetchError, PolicyProviderClient};
|
||||
use crate::state::policy_cache::PolicyCache;
|
||||
|
||||
fn provider_config(url: String, ttl_seconds: Option<u64>) -> RoutingPolicyProvider {
|
||||
RoutingPolicyProvider {
|
||||
url,
|
||||
headers: None,
|
||||
ttl_seconds,
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn fetches_policy_and_populates_cache() {
|
||||
let mut server = Server::new_async().await;
|
||||
let _mock = server
|
||||
.mock("GET", "/v1/routing-policy")
|
||||
.match_query(Matcher::UrlEncoded(
|
||||
"policy_id".to_string(),
|
||||
"customer-abc".to_string(),
|
||||
))
|
||||
.with_status(200)
|
||||
.with_header("content-type", "application/json")
|
||||
.with_body(
|
||||
r#"{
|
||||
"policy_id":"customer-abc",
|
||||
"routing_preferences":[
|
||||
{
|
||||
"model":"openai/gpt-4o",
|
||||
"routing_preferences":[{"name":"quick response","description":"fast"}]
|
||||
}
|
||||
]
|
||||
}"#,
|
||||
)
|
||||
.expect(1)
|
||||
.create_async()
|
||||
.await;
|
||||
|
||||
let cache = Arc::new(PolicyCache::new());
|
||||
let client = PolicyProviderClient::new(
|
||||
provider_config(format!("{}/v1/routing-policy", server.url()), Some(300)),
|
||||
cache,
|
||||
);
|
||||
|
||||
let first = client.fetch_policy("customer-abc").await.unwrap();
|
||||
let second = client.fetch_policy("customer-abc").await.unwrap();
|
||||
assert_eq!(first.len(), 1);
|
||||
assert_eq!(second[0].model, "openai/gpt-4o");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn returns_invalid_on_policy_id_mismatch() {
|
||||
let mut server = Server::new_async().await;
|
||||
let _mock = server
|
||||
.mock("GET", "/v1/routing-policy")
|
||||
.match_query(Matcher::Any)
|
||||
.with_status(200)
|
||||
.with_header("content-type", "application/json")
|
||||
.with_body(
|
||||
r#"{
|
||||
"policy_id":"different-id",
|
||||
"routing_preferences":[]
|
||||
}"#,
|
||||
)
|
||||
.create_async()
|
||||
.await;
|
||||
|
||||
let cache = Arc::new(PolicyCache::new());
|
||||
let client = PolicyProviderClient::new(
|
||||
provider_config(format!("{}/v1/routing-policy", server.url()), Some(300)),
|
||||
cache,
|
||||
);
|
||||
|
||||
let err = client.fetch_policy("customer-abc").await.unwrap_err();
|
||||
assert!(matches!(err, PolicyFetchError::Invalid(_)));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn returns_transient_on_server_error() {
|
||||
let mut server = Server::new_async().await;
|
||||
let _mock = server
|
||||
.mock("GET", "/v1/routing-policy")
|
||||
.match_query(Matcher::Any)
|
||||
.with_status(500)
|
||||
.create_async()
|
||||
.await;
|
||||
|
||||
let cache = Arc::new(PolicyCache::new());
|
||||
let client = PolicyProviderClient::new(
|
||||
provider_config(format!("{}/v1/routing-policy", server.url()), Some(300)),
|
||||
cache,
|
||||
);
|
||||
|
||||
let err = client.fetch_policy("customer-abc").await.unwrap_err();
|
||||
assert!(err.is_transient());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn returns_invalid_on_client_error_status() {
|
||||
let mut server = Server::new_async().await;
|
||||
let _mock = server
|
||||
.mock("GET", "/v1/routing-policy")
|
||||
.match_query(Matcher::Any)
|
||||
.with_status(404)
|
||||
.create_async()
|
||||
.await;
|
||||
|
||||
let cache = Arc::new(PolicyCache::new());
|
||||
let client = PolicyProviderClient::new(
|
||||
provider_config(format!("{}/v1/routing-policy", server.url()), Some(300)),
|
||||
cache,
|
||||
);
|
||||
|
||||
let err = client.fetch_policy("customer-abc").await.unwrap_err();
|
||||
assert!(matches!(err, PolicyFetchError::Invalid(_)));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn supports_headers() {
|
||||
let mut server = Server::new_async().await;
|
||||
let _mock = server
|
||||
.mock("GET", "/v1/routing-policy")
|
||||
.match_header("authorization", "Bearer token")
|
||||
.match_query(Matcher::Any)
|
||||
.with_status(200)
|
||||
.with_header("content-type", "application/json")
|
||||
.with_body(r#"{"policy_id":"customer-abc","routing_preferences":[]}"#)
|
||||
.create_async()
|
||||
.await;
|
||||
|
||||
let mut headers = HashMap::new();
|
||||
headers.insert("Authorization".to_string(), "Bearer token".to_string());
|
||||
let cache = Arc::new(PolicyCache::new());
|
||||
let client = PolicyProviderClient::new(
|
||||
RoutingPolicyProvider {
|
||||
url: format!("{}/v1/routing-policy", server.url()),
|
||||
headers: Some(headers),
|
||||
ttl_seconds: Some(Duration::from_secs(300).as_secs()),
|
||||
},
|
||||
cache,
|
||||
);
|
||||
|
||||
let _ = client.fetch_policy("customer-abc").await.unwrap();
|
||||
}
|
||||
}
|
||||
|
|
@ -2,9 +2,12 @@ use common::configuration::ModelUsagePreference;
|
|||
use hermesllm::clients::endpoints::SupportedUpstreamAPIs;
|
||||
use hermesllm::{ProviderRequest, ProviderRequestType};
|
||||
use hyper::StatusCode;
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use crate::handlers::policy_provider::PolicyProviderClient;
|
||||
use crate::router::llm_router::RouterService;
|
||||
use crate::tracing::routing;
|
||||
|
||||
|
|
@ -13,6 +16,7 @@ pub struct RoutingResult {
|
|||
pub route_name: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct RoutingError {
|
||||
pub message: String,
|
||||
pub status_code: StatusCode,
|
||||
|
|
@ -25,6 +29,60 @@ impl RoutingError {
|
|||
status_code: StatusCode::INTERNAL_SERVER_ERROR,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn bad_request(message: String) -> Self {
|
||||
Self {
|
||||
message,
|
||||
status_code: StatusCode::BAD_REQUEST,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn resolve_usage_preferences(
|
||||
inline_usage_preferences: Option<Vec<ModelUsagePreference>>,
|
||||
policy_id: Option<&str>,
|
||||
policy_provider: Option<&PolicyProviderClient>,
|
||||
routing_metadata: Option<&HashMap<String, Value>>,
|
||||
) -> Result<Option<Vec<ModelUsagePreference>>, RoutingError> {
|
||||
if let Some(inline_preferences) = inline_usage_preferences {
|
||||
info!("using inline routing_policy from request body");
|
||||
return Ok(Some(inline_preferences));
|
||||
}
|
||||
|
||||
if let (Some(policy_id), Some(policy_provider_client)) = (policy_id, policy_provider) {
|
||||
match policy_provider_client.fetch_policy(policy_id).await {
|
||||
Ok(preferences) => {
|
||||
info!(
|
||||
policy_id,
|
||||
"using routing policy from external policy provider"
|
||||
);
|
||||
return Ok(Some(preferences));
|
||||
}
|
||||
Err(err) if err.is_transient() => {
|
||||
warn!(
|
||||
policy_id,
|
||||
error = %err.message(),
|
||||
"policy provider fetch failed, falling back to metadata/config routing preferences"
|
||||
);
|
||||
}
|
||||
Err(err) => {
|
||||
return Err(RoutingError::bad_request(format!(
|
||||
"Failed to load routing policy for policy_id '{}': {}",
|
||||
policy_id,
|
||||
err.message()
|
||||
)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let usage_preferences_str: Option<String> = routing_metadata.and_then(|metadata| {
|
||||
metadata
|
||||
.get("plano_preference_config")
|
||||
.map(|value| value.to_string())
|
||||
});
|
||||
Ok(usage_preferences_str
|
||||
.as_ref()
|
||||
.and_then(|s| serde_yaml::from_str(s).ok()))
|
||||
}
|
||||
|
||||
/// Determines the routing decision if
|
||||
|
|
@ -32,6 +90,7 @@ impl RoutingError {
|
|||
/// # Returns
|
||||
/// * `Ok(RoutingResult)` - Contains the selected model name and span ID
|
||||
/// * `Err(RoutingError)` - Contains error details and optional span ID
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub async fn router_chat_get_upstream_model(
|
||||
router_service: Arc<RouterService>,
|
||||
client_request: ProviderRequestType,
|
||||
|
|
@ -39,6 +98,8 @@ pub async fn router_chat_get_upstream_model(
|
|||
request_path: &str,
|
||||
request_id: &str,
|
||||
inline_usage_preferences: Option<Vec<ModelUsagePreference>>,
|
||||
policy_id: Option<String>,
|
||||
policy_provider: Option<Arc<PolicyProviderClient>>,
|
||||
) -> Result<RoutingResult, RoutingError> {
|
||||
// Clone metadata for routing before converting (which consumes client_request)
|
||||
let routing_metadata = client_request.metadata().clone();
|
||||
|
|
@ -77,21 +138,13 @@ pub async fn router_chat_get_upstream_model(
|
|||
"router request"
|
||||
);
|
||||
|
||||
// Use inline preferences if provided, otherwise fall back to metadata extraction
|
||||
let usage_preferences: Option<Vec<ModelUsagePreference>> = if inline_usage_preferences.is_some()
|
||||
{
|
||||
inline_usage_preferences
|
||||
} else {
|
||||
let usage_preferences_str: Option<String> =
|
||||
routing_metadata.as_ref().and_then(|metadata| {
|
||||
metadata
|
||||
.get("plano_preference_config")
|
||||
.map(|value| value.to_string())
|
||||
});
|
||||
usage_preferences_str
|
||||
.as_ref()
|
||||
.and_then(|s| serde_yaml::from_str(s).ok())
|
||||
};
|
||||
let usage_preferences = resolve_usage_preferences(
|
||||
inline_usage_preferences,
|
||||
policy_id.as_deref(),
|
||||
policy_provider.as_deref(),
|
||||
routing_metadata.as_ref(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
// Prepare log message with latest message from chat request
|
||||
let latest_message_for_log = chat_request
|
||||
|
|
@ -168,3 +221,109 @@ pub async fn router_chat_get_upstream_model(
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::resolve_usage_preferences;
|
||||
use crate::handlers::policy_provider::PolicyProviderClient;
|
||||
use crate::state::policy_cache::PolicyCache;
|
||||
use common::configuration::{ModelUsagePreference, RoutingPolicyProvider, RoutingPreference};
|
||||
use mockito::{Matcher, Server};
|
||||
use serde_json::json;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
fn inline_policy(name: &str) -> Vec<ModelUsagePreference> {
|
||||
vec![ModelUsagePreference {
|
||||
model: "openai/gpt-4o".to_string(),
|
||||
routing_preferences: vec![RoutingPreference {
|
||||
name: name.to_string(),
|
||||
description: "desc".to_string(),
|
||||
}],
|
||||
}]
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn resolve_usage_preferences_prioritizes_inline_policy() {
|
||||
let inline = inline_policy("inline");
|
||||
let mut metadata = HashMap::new();
|
||||
metadata.insert(
|
||||
"plano_preference_config".to_string(),
|
||||
json!(
|
||||
[{"model":"openai/gpt-4o-mini","routing_preferences":[{"name":"metadata","description":"desc"}]}]
|
||||
),
|
||||
);
|
||||
|
||||
let result = resolve_usage_preferences(
|
||||
Some(inline.clone()),
|
||||
Some("policy-a"),
|
||||
None,
|
||||
Some(&metadata),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(result.unwrap()[0].routing_preferences[0].name, "inline");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn resolve_usage_preferences_falls_back_to_metadata_on_transient_policy_error() {
|
||||
let mut server = Server::new_async().await;
|
||||
let _mock = server
|
||||
.mock("GET", "/policy")
|
||||
.match_query(Matcher::Any)
|
||||
.with_status(500)
|
||||
.create_async()
|
||||
.await;
|
||||
|
||||
let provider = PolicyProviderClient::new(
|
||||
RoutingPolicyProvider {
|
||||
url: format!("{}/policy", server.url()),
|
||||
headers: None,
|
||||
ttl_seconds: Some(60),
|
||||
},
|
||||
Arc::new(PolicyCache::new()),
|
||||
);
|
||||
let mut metadata = HashMap::new();
|
||||
metadata.insert(
|
||||
"plano_preference_config".to_string(),
|
||||
json!(
|
||||
[{"model":"openai/gpt-4o-mini","routing_preferences":[{"name":"metadata","description":"desc"}]}]
|
||||
),
|
||||
);
|
||||
|
||||
let result =
|
||||
resolve_usage_preferences(None, Some("customer-a"), Some(&provider), Some(&metadata))
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(result[0].routing_preferences[0].name, "metadata");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn resolve_usage_preferences_returns_bad_request_on_policy_mismatch() {
|
||||
let mut server = Server::new_async().await;
|
||||
let _mock = server
|
||||
.mock("GET", "/policy")
|
||||
.match_query(Matcher::Any)
|
||||
.with_status(200)
|
||||
.with_header("content-type", "application/json")
|
||||
.with_body(r#"{"policy_id":"different","routing_preferences":[]}"#)
|
||||
.create_async()
|
||||
.await;
|
||||
|
||||
let provider = PolicyProviderClient::new(
|
||||
RoutingPolicyProvider {
|
||||
url: format!("{}/policy", server.url()),
|
||||
headers: None,
|
||||
ttl_seconds: Some(60),
|
||||
},
|
||||
Arc::new(PolicyCache::new()),
|
||||
);
|
||||
|
||||
let err = resolve_usage_preferences(None, Some("expected"), Some(&provider), None)
|
||||
.await
|
||||
.unwrap_err();
|
||||
assert_eq!(err.status_code, hyper::StatusCode::BAD_REQUEST);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -10,11 +10,13 @@ use hyper::{Request, Response, StatusCode};
|
|||
use std::sync::Arc;
|
||||
use tracing::{debug, info, info_span, warn, Instrument};
|
||||
|
||||
use crate::handlers::policy_provider::PolicyProviderClient;
|
||||
use crate::handlers::router_chat::router_chat_get_upstream_model;
|
||||
use crate::router::llm_router::RouterService;
|
||||
use crate::tracing::{collect_custom_trace_attributes, operation_component, set_service_name};
|
||||
|
||||
const ROUTING_POLICY_SIZE_WARNING_BYTES: usize = 5120;
|
||||
type ExtractRoutingPolicyResult = (Bytes, Option<Vec<ModelUsagePreference>>, Option<String>);
|
||||
|
||||
/// Extracts `routing_policy` from a JSON body, returning the cleaned body bytes
|
||||
/// and parsed preferences. The `routing_policy` field is removed from the JSON
|
||||
|
|
@ -24,10 +26,20 @@ const ROUTING_POLICY_SIZE_WARNING_BYTES: usize = 5120;
|
|||
pub fn extract_routing_policy(
|
||||
raw_bytes: &[u8],
|
||||
warn_on_size: bool,
|
||||
) -> Result<(Bytes, Option<Vec<ModelUsagePreference>>), String> {
|
||||
) -> Result<ExtractRoutingPolicyResult, String> {
|
||||
let mut json_body: serde_json::Value = serde_json::from_slice(raw_bytes)
|
||||
.map_err(|err| format!("Failed to parse JSON: {}", err))?;
|
||||
|
||||
let policy_id = json_body
|
||||
.as_object_mut()
|
||||
.and_then(|obj| obj.remove("policy_id"))
|
||||
.map(|policy_id_value| match policy_id_value {
|
||||
serde_json::Value::String(policy_id) if !policy_id.trim().is_empty() => Ok(policy_id),
|
||||
serde_json::Value::String(_) => Err("policy_id cannot be empty".to_string()),
|
||||
_ => Err("policy_id must be a string".to_string()),
|
||||
})
|
||||
.transpose()?;
|
||||
|
||||
let preferences = json_body
|
||||
.as_object_mut()
|
||||
.and_then(|obj| obj.remove("routing_policy"))
|
||||
|
|
@ -58,7 +70,7 @@ pub fn extract_routing_policy(
|
|||
});
|
||||
|
||||
let bytes = Bytes::from(serde_json::to_vec(&json_body).unwrap());
|
||||
Ok((bytes, preferences))
|
||||
Ok((bytes, preferences, policy_id))
|
||||
}
|
||||
|
||||
#[derive(serde::Serialize)]
|
||||
|
|
@ -71,6 +83,7 @@ struct RoutingDecisionResponse {
|
|||
pub async fn routing_decision(
|
||||
request: Request<hyper::body::Incoming>,
|
||||
router_service: Arc<RouterService>,
|
||||
policy_provider: Option<Arc<PolicyProviderClient>>,
|
||||
request_path: String,
|
||||
span_attributes: Arc<Option<SpanAttributes>>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
|
|
@ -95,6 +108,7 @@ pub async fn routing_decision(
|
|||
routing_decision_inner(
|
||||
request,
|
||||
router_service,
|
||||
policy_provider,
|
||||
request_id,
|
||||
request_path,
|
||||
request_headers,
|
||||
|
|
@ -107,6 +121,7 @@ pub async fn routing_decision(
|
|||
async fn routing_decision_inner(
|
||||
request: Request<hyper::body::Incoming>,
|
||||
router_service: Arc<RouterService>,
|
||||
policy_provider: Option<Arc<PolicyProviderClient>>,
|
||||
request_id: String,
|
||||
request_path: String,
|
||||
request_headers: hyper::HeaderMap,
|
||||
|
|
@ -153,17 +168,18 @@ async fn routing_decision_inner(
|
|||
);
|
||||
|
||||
// Extract routing_policy from request body before parsing as ProviderRequestType
|
||||
let (chat_request_bytes, inline_preferences) = match extract_routing_policy(&raw_bytes, true) {
|
||||
Ok(result) => result,
|
||||
Err(err) => {
|
||||
warn!(error = %err, "failed to parse request JSON");
|
||||
return Ok(BrightStaffError::InvalidRequest(format!(
|
||||
"Failed to parse request JSON: {}",
|
||||
err
|
||||
))
|
||||
.into_response());
|
||||
}
|
||||
};
|
||||
let (chat_request_bytes, inline_preferences, policy_id) =
|
||||
match extract_routing_policy(&raw_bytes, true) {
|
||||
Ok(result) => result,
|
||||
Err(err) => {
|
||||
warn!(error = %err, "failed to parse request JSON");
|
||||
return Ok(BrightStaffError::InvalidRequest(format!(
|
||||
"Failed to parse request JSON: {}",
|
||||
err
|
||||
))
|
||||
.into_response());
|
||||
}
|
||||
};
|
||||
|
||||
let client_request = match ProviderRequestType::try_from((
|
||||
&chat_request_bytes[..],
|
||||
|
|
@ -188,6 +204,8 @@ async fn routing_decision_inner(
|
|||
&request_path,
|
||||
&request_id,
|
||||
inline_preferences,
|
||||
policy_id,
|
||||
policy_provider,
|
||||
)
|
||||
.await;
|
||||
|
||||
|
|
@ -218,7 +236,11 @@ async fn routing_decision_inner(
|
|||
}
|
||||
Err(err) => {
|
||||
warn!(error = %err.message, "routing decision failed");
|
||||
Ok(BrightStaffError::InternalServerError(err.message).into_response())
|
||||
Ok(BrightStaffError::ForwardedError {
|
||||
status_code: err.status_code,
|
||||
message: err.message,
|
||||
}
|
||||
.into_response())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -243,9 +265,10 @@ mod tests {
|
|||
#[test]
|
||||
fn extract_routing_policy_no_policy() {
|
||||
let body = make_chat_body("");
|
||||
let (cleaned, prefs) = extract_routing_policy(&body, false).unwrap();
|
||||
let (cleaned, prefs, policy_id) = extract_routing_policy(&body, false).unwrap();
|
||||
|
||||
assert!(prefs.is_none());
|
||||
assert!(policy_id.is_none());
|
||||
let cleaned_json: serde_json::Value = serde_json::from_slice(&cleaned).unwrap();
|
||||
assert_eq!(cleaned_json["model"], "gpt-4o-mini");
|
||||
assert!(cleaned_json.get("routing_policy").is_none());
|
||||
|
|
@ -268,7 +291,7 @@ mod tests {
|
|||
}
|
||||
]"#;
|
||||
let body = make_chat_body(policy);
|
||||
let (cleaned, prefs) = extract_routing_policy(&body, false).unwrap();
|
||||
let (cleaned, prefs, policy_id) = extract_routing_policy(&body, false).unwrap();
|
||||
|
||||
let prefs = prefs.expect("should have parsed preferences");
|
||||
assert_eq!(prefs.len(), 2);
|
||||
|
|
@ -280,6 +303,7 @@ mod tests {
|
|||
// routing_policy should be stripped from cleaned body
|
||||
let cleaned_json: serde_json::Value = serde_json::from_slice(&cleaned).unwrap();
|
||||
assert!(cleaned_json.get("routing_policy").is_none());
|
||||
assert!(policy_id.is_none());
|
||||
assert_eq!(cleaned_json["model"], "gpt-4o-mini");
|
||||
}
|
||||
|
||||
|
|
@ -288,13 +312,14 @@ mod tests {
|
|||
// routing_policy is present but has wrong shape
|
||||
let policy = r#""routing_policy": "not-an-array""#;
|
||||
let body = make_chat_body(policy);
|
||||
let (cleaned, prefs) = extract_routing_policy(&body, false).unwrap();
|
||||
let (cleaned, prefs, policy_id) = extract_routing_policy(&body, false).unwrap();
|
||||
|
||||
// Invalid policy should be ignored (returns None), not error
|
||||
assert!(prefs.is_none());
|
||||
// routing_policy should still be stripped from cleaned body
|
||||
let cleaned_json: serde_json::Value = serde_json::from_slice(&cleaned).unwrap();
|
||||
assert!(cleaned_json.get("routing_policy").is_none());
|
||||
assert!(policy_id.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -309,23 +334,44 @@ mod tests {
|
|||
fn extract_routing_policy_empty_array() {
|
||||
let policy = r#""routing_policy": []"#;
|
||||
let body = make_chat_body(policy);
|
||||
let (_, prefs) = extract_routing_policy(&body, false).unwrap();
|
||||
let (_, prefs, policy_id) = extract_routing_policy(&body, false).unwrap();
|
||||
|
||||
let prefs = prefs.expect("empty array is valid");
|
||||
assert_eq!(prefs.len(), 0);
|
||||
assert!(policy_id.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_routing_policy_preserves_other_fields() {
|
||||
let policy = r#""routing_policy": [{"model": "gpt-4o", "routing_preferences": [{"name": "test", "description": "test"}]}], "temperature": 0.5, "max_tokens": 100"#;
|
||||
let body = make_chat_body(policy);
|
||||
let (cleaned, prefs) = extract_routing_policy(&body, false).unwrap();
|
||||
let (cleaned, prefs, policy_id) = extract_routing_policy(&body, false).unwrap();
|
||||
|
||||
assert!(prefs.is_some());
|
||||
let cleaned_json: serde_json::Value = serde_json::from_slice(&cleaned).unwrap();
|
||||
assert_eq!(cleaned_json["temperature"], 0.5);
|
||||
assert_eq!(cleaned_json["max_tokens"], 100);
|
||||
assert!(cleaned_json.get("routing_policy").is_none());
|
||||
assert!(policy_id.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_routing_policy_extracts_and_strips_policy_id() {
|
||||
let body = make_chat_body(r#""policy_id": "customer-abc-123""#);
|
||||
let (cleaned, prefs, policy_id) = extract_routing_policy(&body, false).unwrap();
|
||||
|
||||
assert!(prefs.is_none());
|
||||
assert_eq!(policy_id, Some("customer-abc-123".to_string()));
|
||||
let cleaned_json: serde_json::Value = serde_json::from_slice(&cleaned).unwrap();
|
||||
assert!(cleaned_json.get("policy_id").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_routing_policy_rejects_non_string_policy_id() {
|
||||
let body = make_chat_body(r#""policy_id": 123"#);
|
||||
let result = extract_routing_policy(&body, false);
|
||||
assert!(result.is_err());
|
||||
assert!(result.unwrap_err().contains("policy_id must be a string"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
|
|||
|
|
@ -2,10 +2,12 @@ use brightstaff::handlers::agent_chat_completions::agent_chat;
|
|||
use brightstaff::handlers::function_calling::function_calling_chat_handler;
|
||||
use brightstaff::handlers::llm::llm_chat;
|
||||
use brightstaff::handlers::models::list_models;
|
||||
use brightstaff::handlers::policy_provider::PolicyProviderClient;
|
||||
use brightstaff::handlers::routing_service::routing_decision;
|
||||
use brightstaff::router::llm_router::RouterService;
|
||||
use brightstaff::router::plano_orchestrator::OrchestratorService;
|
||||
use brightstaff::state::memory::MemoryConversationalStorage;
|
||||
use brightstaff::state::policy_cache::PolicyCache;
|
||||
use brightstaff::state::postgresql::PostgreSQLConversationStorage;
|
||||
use brightstaff::state::StateStorage;
|
||||
use brightstaff::utils::tracing::init_tracer;
|
||||
|
|
@ -108,6 +110,16 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
routing_model_name,
|
||||
routing_llm_provider,
|
||||
));
|
||||
let policy_provider: Option<Arc<PolicyProviderClient>> = plano_config
|
||||
.routing
|
||||
.as_ref()
|
||||
.and_then(|routing| routing.policy_provider.clone())
|
||||
.map(|policy_provider_config| {
|
||||
Arc::new(PolicyProviderClient::new(
|
||||
policy_provider_config,
|
||||
Arc::new(PolicyCache::new()),
|
||||
))
|
||||
});
|
||||
|
||||
let orchestrator_service: Arc<OrchestratorService> = Arc::new(OrchestratorService::new(
|
||||
format!("{llm_provider_url}{CHAT_COMPLETIONS_PATH}"),
|
||||
|
|
@ -172,6 +184,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
|
||||
let router_service: Arc<RouterService> = Arc::clone(&router_service);
|
||||
let orchestrator_service: Arc<OrchestratorService> = Arc::clone(&orchestrator_service);
|
||||
let policy_provider = policy_provider.clone();
|
||||
let model_aliases: Arc<
|
||||
Option<std::collections::HashMap<String, common::configuration::ModelAlias>>,
|
||||
> = Arc::clone(&model_aliases);
|
||||
|
|
@ -185,6 +198,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
let service = service_fn(move |req| {
|
||||
let router_service = Arc::clone(&router_service);
|
||||
let orchestrator_service = Arc::clone(&orchestrator_service);
|
||||
let policy_provider = policy_provider.clone();
|
||||
let parent_cx = extract_context_from_request(&req);
|
||||
let llm_provider_url = llm_provider_url.clone();
|
||||
let llm_providers = llm_providers.clone();
|
||||
|
|
@ -227,6 +241,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
return routing_decision(
|
||||
req,
|
||||
router_service,
|
||||
policy_provider,
|
||||
stripped_path,
|
||||
span_attributes,
|
||||
)
|
||||
|
|
@ -243,6 +258,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
llm_chat(
|
||||
req,
|
||||
router_service,
|
||||
policy_provider,
|
||||
fully_qualified_url,
|
||||
model_aliases,
|
||||
llm_providers,
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ use std::sync::Arc;
|
|||
use tracing::debug;
|
||||
|
||||
pub mod memory;
|
||||
pub mod policy_cache;
|
||||
pub mod postgresql;
|
||||
pub mod response_state_processor;
|
||||
|
||||
|
|
|
|||
108
crates/brightstaff/src/state/policy_cache.rs
Normal file
108
crates/brightstaff/src/state/policy_cache.rs
Normal file
|
|
@ -0,0 +1,108 @@
|
|||
use common::configuration::ModelUsagePreference;
|
||||
use std::collections::HashMap;
|
||||
use std::time::{Duration, Instant};
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
#[derive(Clone)]
|
||||
struct CachedPolicy {
|
||||
preferences: Vec<ModelUsagePreference>,
|
||||
expires_at: Instant,
|
||||
}
|
||||
|
||||
pub struct PolicyCache {
|
||||
entries: RwLock<HashMap<String, CachedPolicy>>,
|
||||
}
|
||||
|
||||
impl Default for PolicyCache {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl PolicyCache {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
entries: RwLock::new(HashMap::new()),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_valid(&self, policy_id: &str) -> Option<Vec<ModelUsagePreference>> {
|
||||
let now = Instant::now();
|
||||
let cached = {
|
||||
let entries = self.entries.read().await;
|
||||
entries.get(policy_id).cloned()
|
||||
};
|
||||
|
||||
let cached = cached?;
|
||||
if cached.expires_at > now {
|
||||
return Some(cached.preferences);
|
||||
}
|
||||
|
||||
self.entries.write().await.remove(policy_id);
|
||||
None
|
||||
}
|
||||
|
||||
pub async fn insert(
|
||||
&self,
|
||||
policy_id: String,
|
||||
preferences: Vec<ModelUsagePreference>,
|
||||
ttl: Duration,
|
||||
) {
|
||||
let expires_at = Instant::now() + ttl;
|
||||
self.entries.write().await.insert(
|
||||
policy_id,
|
||||
CachedPolicy {
|
||||
preferences,
|
||||
expires_at,
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::PolicyCache;
|
||||
use common::configuration::{ModelUsagePreference, RoutingPreference};
|
||||
use std::time::Duration;
|
||||
|
||||
fn sample_preferences() -> Vec<ModelUsagePreference> {
|
||||
vec![ModelUsagePreference {
|
||||
model: "openai/gpt-4o".to_string(),
|
||||
routing_preferences: vec![RoutingPreference {
|
||||
name: "quick response".to_string(),
|
||||
description: "fast lightweight responses".to_string(),
|
||||
}],
|
||||
}]
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn returns_cached_policy_before_expiry() {
|
||||
let cache = PolicyCache::new();
|
||||
cache
|
||||
.insert(
|
||||
"customer-a".to_string(),
|
||||
sample_preferences(),
|
||||
Duration::from_secs(10),
|
||||
)
|
||||
.await;
|
||||
|
||||
let cached = cache.get_valid("customer-a").await;
|
||||
assert!(cached.is_some());
|
||||
assert_eq!(cached.unwrap()[0].model, "openai/gpt-4o");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn expires_cached_policy_after_ttl() {
|
||||
let cache = PolicyCache::new();
|
||||
cache
|
||||
.insert(
|
||||
"customer-a".to_string(),
|
||||
sample_preferences(),
|
||||
Duration::from_millis(5),
|
||||
)
|
||||
.await;
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(10)).await;
|
||||
assert!(cache.get_valid("customer-a").await.is_none());
|
||||
}
|
||||
}
|
||||
|
|
@ -9,8 +9,17 @@ use crate::api::open_ai::{
|
|||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Routing {
|
||||
#[serde(alias = "llm_provider")]
|
||||
pub model_provider: Option<String>,
|
||||
pub model: Option<String>,
|
||||
pub policy_provider: Option<RoutingPolicyProvider>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct RoutingPolicyProvider {
|
||||
pub url: String,
|
||||
pub headers: Option<HashMap<String, String>>,
|
||||
pub ttl_seconds: Option<u64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
|
|
@ -270,7 +279,7 @@ impl LlmProviderType {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ModelUsagePreference {
|
||||
pub model: String,
|
||||
pub routing_preferences: Vec<RoutingPreference>,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue