diff --git a/cli/planoai/config_generator.py b/cli/planoai/config_generator.py index e8bb516b..3fa794a8 100644 --- a/cli/planoai/config_generator.py +++ b/cli/planoai/config_generator.py @@ -426,6 +426,13 @@ def validate_and_render_schema(): "Please provide model_providers either under listeners or at root level, not both. Currently we don't support multiple listeners with model_providers" ) + # Validate at most one model listener + model_listeners = [l for l in listeners if l.get("type") == "model"] + if len(model_listeners) > 1: + raise Exception( + f"Only one model listener is allowed, found {len(model_listeners)}" + ) + # Validate filter_chain IDs on listeners reference valid agent/filter IDs for listener in listeners: listener_filter_chain = listener.get("filter_chain", []) diff --git a/crates/brightstaff/src/handlers/llm.rs b/crates/brightstaff/src/handlers/llm.rs index 59e35606..e6f6f89f 100644 --- a/crates/brightstaff/src/handlers/llm.rs +++ b/crates/brightstaff/src/handlers/llm.rs @@ -1,5 +1,5 @@ use bytes::Bytes; -use common::configuration::{Agent, AgentFilterChain, ModelAlias, SpanAttributes}; +use common::configuration::{AgentFilterChain, ModelAlias, ModelFilterChain, SpanAttributes}; use common::consts::{ ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER, REQUEST_ID_HEADER, TRACE_PARENT_HEADER, }; @@ -45,8 +45,7 @@ pub async fn llm_chat( llm_providers: Arc>, span_attributes: Arc>, state_storage: Option>, - filter_chain: Arc>>, - filter_agents: Arc>, + model_filter_chain: Arc>, ) -> Result>, hyper::Error> { let request_path = request.uri().path().to_string(); let request_headers = request.headers().clone(); @@ -86,8 +85,7 @@ pub async fn llm_chat( request_id, request_path, request_headers, - filter_chain, - filter_agents, + model_filter_chain, ) .instrument(request_span) .await @@ -105,8 +103,7 @@ async fn llm_chat_inner( request_id: String, request_path: String, mut request_headers: hyper::HeaderMap, - filter_chain: Arc>>, - filter_agents: Arc>, + model_filter_chain: Arc>, ) -> Result>, hyper::Error> { // Set service name for LLM operations set_service_name(operation_component::LLM); @@ -261,29 +258,27 @@ async fn llm_chat_inner( } // === Filter chain processing for model listener === - { - if let Some(ref fc) = *filter_chain { - if !fc.is_empty() { - debug!(filter_chain = ?fc, "processing model listener filter chain"); + if let Some(ref mfc) = *model_filter_chain { + { + debug!(filter_ids = ?mfc.filter_ids, "processing model listener filter chain"); - // Create a temporary AgentFilterChain to reuse PipelineProcessor - let temp_filter_chain = AgentFilterChain { - id: "model_listener".to_string(), - default: None, - description: None, - filter_chain: Some(fc.clone()), - }; + let temp_filter_chain = AgentFilterChain { + id: "model_listener".to_string(), + default: None, + description: None, + filter_chain: Some(mfc.filter_ids.clone()), + }; - let mut pipeline_processor = PipelineProcessor::default(); - let messages = client_request.get_messages(); - match pipeline_processor - .process_filter_chain( - &messages, - &temp_filter_chain, - &filter_agents, - &request_headers, - ) - .await + let mut pipeline_processor = PipelineProcessor::default(); + let messages = client_request.get_messages(); + match pipeline_processor + .process_filter_chain( + &messages, + &temp_filter_chain, + &mfc.agents, + &request_headers, + ) + .await { Ok(filtered_messages) => { client_request.set_messages(&filtered_messages); @@ -326,7 +321,6 @@ async fn llm_chat_inner( *internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; return Ok(internal_error); } - } } } } diff --git a/crates/brightstaff/src/main.rs b/crates/brightstaff/src/main.rs index b1715624..6df163de 100644 --- a/crates/brightstaff/src/main.rs +++ b/crates/brightstaff/src/main.rs @@ -10,7 +10,7 @@ use brightstaff::state::postgresql::PostgreSQLConversationStorage; use brightstaff::state::StateStorage; use brightstaff::utils::tracing::init_tracer; use bytes::Bytes; -use common::configuration::{Agent, Configuration, ListenerType}; +use common::configuration::{Agent, Configuration, ListenerType, ModelFilterChain}; use common::consts::{ CHAT_COMPLETIONS_PATH, MESSAGES_PATH, OPENAI_RESPONSES_API_PATH, PLANO_ORCHESTRATOR_MODEL_NAME, }; @@ -94,32 +94,26 @@ async fn main() -> Result<(), Box> { let combined_agents_filters_list = Arc::new(RwLock::new(Some(all_agents))); // Resolve model listener filter chain and agents at startup - let model_listener_count = plano_config - .listeners - .iter() - .filter(|l| l.listener_type == ListenerType::Model) - .count(); - assert!( - model_listener_count <= 1, - "only one model listener is allowed, found {}", - model_listener_count - ); let model_listener = plano_config .listeners .iter() .find(|l| l.listener_type == ListenerType::Model); - let model_filter_chain: Arc>> = - Arc::new(model_listener.and_then(|l| l.filter_chain.clone())); - let model_filter_agents: Arc> = Arc::new( - model_filter_chain - .as_ref() - .as_ref() + let model_filter_chain: Arc> = Arc::new( + model_listener + .and_then(|l| l.filter_chain.clone()) + .filter(|fc| !fc.is_empty()) .map(|fc| { - fc.iter() - .filter_map(|id| global_agent_map.get(id).map(|a| (id.clone(), a.clone()))) - .collect() - }) - .unwrap_or_default(), + let agents = fc + .iter() + .filter_map(|id| { + global_agent_map.get(id).map(|a| (id.clone(), a.clone())) + }) + .collect(); + ModelFilterChain { + filter_ids: fc, + agents, + } + }), ); let listeners = Arc::new(RwLock::new(plano_config.listeners.clone())); let llm_provider_url = @@ -216,7 +210,6 @@ async fn main() -> Result<(), Box> { let llm_providers = llm_providers.clone(); let agents_list = combined_agents_filters_list.clone(); let model_filter_chain = model_filter_chain.clone(); - let model_filter_agents = model_filter_agents.clone(); let listeners = listeners.clone(); let span_attributes = span_attributes.clone(); let state_storage = state_storage.clone(); @@ -229,7 +222,6 @@ async fn main() -> Result<(), Box> { let model_aliases = Arc::clone(&model_aliases); let agents_list = agents_list.clone(); let model_filter_chain = model_filter_chain.clone(); - let model_filter_agents = model_filter_agents.clone(); let listeners = listeners.clone(); let span_attributes = span_attributes.clone(); let state_storage = state_storage.clone(); @@ -289,7 +281,6 @@ async fn main() -> Result<(), Box> { span_attributes, state_storage, model_filter_chain, - model_filter_agents, ) .with_context(parent_cx) .await diff --git a/crates/common/src/configuration.rs b/crates/common/src/configuration.rs index 18945a28..e89459f4 100644 --- a/crates/common/src/configuration.rs +++ b/crates/common/src/configuration.rs @@ -36,6 +36,14 @@ pub struct AgentFilterChain { pub filter_chain: Option>, } +/// Resolved filter chain for a model listener: the ordered filter IDs +/// together with the agent definitions they reference. +#[derive(Debug, Clone)] +pub struct ModelFilterChain { + pub filter_ids: Vec, + pub agents: HashMap, +} + #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] #[serde(rename_all = "lowercase")] pub enum ListenerType {