diff --git a/crates/brightstaff/src/handlers/agent_selector.rs b/crates/brightstaff/src/handlers/agent_selector.rs index 45fc74bd..a6b22964 100644 --- a/crates/brightstaff/src/handlers/agent_selector.rs +++ b/crates/brightstaff/src/handlers/agent_selector.rs @@ -172,7 +172,7 @@ impl AgentSelector { #[cfg(test)] mod tests { use super::*; - use common::configuration::{AgentFilterChain, Listener}; + use common::configuration::{AgentFilterChain, Listener, ListenerType}; fn create_test_orchestrator_service() -> Arc { Arc::new(OrchestratorService::new( @@ -192,12 +192,12 @@ mod tests { fn create_test_listener(name: &str, agents: Vec) -> Listener { Listener { + listener_type: ListenerType::Agent, name: name.to_string(), agents: Some(agents), filter_chain: None, port: 8080, router: None, - filter_agents: None, } } diff --git a/crates/brightstaff/src/handlers/integration_tests.rs b/crates/brightstaff/src/handlers/integration_tests.rs index bf88ee86..6d8b99d0 100644 --- a/crates/brightstaff/src/handlers/integration_tests.rs +++ b/crates/brightstaff/src/handlers/integration_tests.rs @@ -17,7 +17,7 @@ use hyper::StatusCode; #[cfg(test)] mod tests { use super::*; - use common::configuration::{Agent, AgentFilterChain, Listener}; + use common::configuration::{Agent, AgentFilterChain, Listener, ListenerType}; fn create_test_orchestrator_service() -> Arc { Arc::new(OrchestratorService::new( @@ -72,12 +72,12 @@ mod tests { }; let listener = Listener { + listener_type: ListenerType::Agent, name: "test-listener".to_string(), agents: Some(vec![agent_pipeline.clone()]), filter_chain: None, port: 8080, router: None, - filter_agents: None, }; let listeners = vec![listener]; diff --git a/crates/brightstaff/src/handlers/llm.rs b/crates/brightstaff/src/handlers/llm.rs index 0852b0ed..59e35606 100644 --- a/crates/brightstaff/src/handlers/llm.rs +++ b/crates/brightstaff/src/handlers/llm.rs @@ -1,5 +1,5 @@ use bytes::Bytes; -use common::configuration::{AgentFilterChain, Listener, ModelAlias, SpanAttributes}; +use common::configuration::{Agent, AgentFilterChain, ModelAlias, SpanAttributes}; use common::consts::{ ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER, REQUEST_ID_HEADER, TRACE_PARENT_HEADER, }; @@ -45,7 +45,8 @@ pub async fn llm_chat( llm_providers: Arc>, span_attributes: Arc>, state_storage: Option>, - listeners: Arc>>, + filter_chain: Arc>>, + filter_agents: Arc>, ) -> Result>, hyper::Error> { let request_path = request.uri().path().to_string(); let request_headers = request.headers().clone(); @@ -85,7 +86,8 @@ pub async fn llm_chat( request_id, request_path, request_headers, - listeners, + filter_chain, + filter_agents, ) .instrument(request_span) .await @@ -103,7 +105,8 @@ async fn llm_chat_inner( request_id: String, request_path: String, mut request_headers: hyper::HeaderMap, - listeners: Arc>>, + filter_chain: Arc>>, + filter_agents: Arc>, ) -> Result>, hyper::Error> { // Set service name for LLM operations set_service_name(operation_component::LLM); @@ -257,20 +260,9 @@ async fn llm_chat_inner( debug!("removed plano_preference_config from metadata"); } - // === Filter chain processing for model listeners === - // Check if any model listener (no agents) has a filter_chain configured + // === Filter chain processing for model listener === { - let listeners_guard = listeners.read().await; - let model_listener = listeners_guard - .iter() - .find(|l| l.agents.is_none() && l.filter_chain.is_some()); - - let filter_chain = model_listener.and_then(|l| l.filter_chain.clone()); - let agent_map = model_listener - .and_then(|l| l.filter_agents.clone()) - .unwrap_or_default(); - - if let Some(ref fc) = filter_chain { + if let Some(ref fc) = *filter_chain { if !fc.is_empty() { debug!(filter_chain = ?fc, "processing model listener filter chain"); @@ -288,7 +280,7 @@ async fn llm_chat_inner( .process_filter_chain( &messages, &temp_filter_chain, - &agent_map, + &filter_agents, &request_headers, ) .await diff --git a/crates/brightstaff/src/main.rs b/crates/brightstaff/src/main.rs index 872951c8..b1715624 100644 --- a/crates/brightstaff/src/main.rs +++ b/crates/brightstaff/src/main.rs @@ -10,7 +10,7 @@ use brightstaff::state::postgresql::PostgreSQLConversationStorage; use brightstaff::state::StateStorage; use brightstaff::utils::tracing::init_tracer; use bytes::Bytes; -use common::configuration::{Agent, Configuration}; +use common::configuration::{Agent, Configuration, ListenerType}; use common::consts::{ CHAT_COMPLETIONS_PATH, MESSAGES_PATH, OPENAI_RESPONSES_API_PATH, PLANO_ORCHESTRATOR_MODEL_NAME, }; @@ -87,26 +87,41 @@ async fn main() -> Result<(), Box> { .map(|a| (a.id.clone(), a.clone())) .collect(); - // Resolve filter_agents on each listener at startup - let mut listeners_resolved = plano_config.listeners.clone(); - for listener in &mut listeners_resolved { - if let Some(ref fc) = listener.filter_chain { - let filter_agents: HashMap = fc - .iter() - .filter_map(|id| global_agent_map.get(id).map(|a| (id.clone(), a.clone()))) - .collect(); - if !filter_agents.is_empty() { - listener.filter_agents = Some(filter_agents); - } - } - } - // Create expanded provider list for /v1/models endpoint let llm_providers = LlmProviders::try_from(plano_config.model_providers.clone()) .expect("Failed to create LlmProviders"); let llm_providers = Arc::new(RwLock::new(llm_providers)); let combined_agents_filters_list = Arc::new(RwLock::new(Some(all_agents))); - let listeners = Arc::new(RwLock::new(listeners_resolved)); + + // Resolve model listener filter chain and agents at startup + let model_listener_count = plano_config + .listeners + .iter() + .filter(|l| l.listener_type == ListenerType::Model) + .count(); + assert!( + model_listener_count <= 1, + "only one model listener is allowed, found {}", + model_listener_count + ); + let model_listener = plano_config + .listeners + .iter() + .find(|l| l.listener_type == ListenerType::Model); + let model_filter_chain: Arc>> = + Arc::new(model_listener.and_then(|l| l.filter_chain.clone())); + let model_filter_agents: Arc> = Arc::new( + model_filter_chain + .as_ref() + .as_ref() + .map(|fc| { + fc.iter() + .filter_map(|id| global_agent_map.get(id).map(|a| (id.clone(), a.clone()))) + .collect() + }) + .unwrap_or_default(), + ); + let listeners = Arc::new(RwLock::new(plano_config.listeners.clone())); let llm_provider_url = env::var("LLM_PROVIDER_ENDPOINT").unwrap_or_else(|_| "http://localhost:12001".to_string()); @@ -200,6 +215,8 @@ async fn main() -> Result<(), Box> { let llm_providers = llm_providers.clone(); let agents_list = combined_agents_filters_list.clone(); + let model_filter_chain = model_filter_chain.clone(); + let model_filter_agents = model_filter_agents.clone(); let listeners = listeners.clone(); let span_attributes = span_attributes.clone(); let state_storage = state_storage.clone(); @@ -211,6 +228,8 @@ async fn main() -> Result<(), Box> { let llm_providers = llm_providers.clone(); let model_aliases = Arc::clone(&model_aliases); let agents_list = agents_list.clone(); + let model_filter_chain = model_filter_chain.clone(); + let model_filter_agents = model_filter_agents.clone(); let listeners = listeners.clone(); let span_attributes = span_attributes.clone(); let state_storage = state_storage.clone(); @@ -269,7 +288,8 @@ async fn main() -> Result<(), Box> { llm_providers, span_attributes, state_storage, - listeners, + model_filter_chain, + model_filter_agents, ) .with_context(parent_cx) .await diff --git a/crates/common/src/configuration.rs b/crates/common/src/configuration.rs index b54453f3..18945a28 100644 --- a/crates/common/src/configuration.rs +++ b/crates/common/src/configuration.rs @@ -36,15 +36,23 @@ pub struct AgentFilterChain { pub filter_chain: Option>, } +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "lowercase")] +pub enum ListenerType { + Model, + Agent, + Prompt, +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Listener { + #[serde(rename = "type")] + pub listener_type: ListenerType, pub name: String, pub router: Option, pub agents: Option>, pub filter_chain: Option>, pub port: u16, - #[serde(skip)] - pub filter_agents: Option>, } #[derive(Debug, Clone, Serialize, Deserialize)]