diff --git a/cli/planoai/config_generator.py b/cli/planoai/config_generator.py index 522968c9..e8bb516b 100644 --- a/cli/planoai/config_generator.py +++ b/cli/planoai/config_generator.py @@ -426,6 +426,16 @@ def validate_and_render_schema(): "Please provide model_providers either under listeners or at root level, not both. Currently we don't support multiple listeners with model_providers" ) + # Validate filter_chain IDs on listeners reference valid agent/filter IDs + for listener in listeners: + listener_filter_chain = listener.get("filter_chain", []) + for fc_id in listener_filter_chain: + if fc_id not in agent_id_keys: + raise Exception( + f"Listener '{listener.get('name', 'unknown')}' references filter_chain id '{fc_id}' " + f"which is not defined in agents or filters. Available ids: {', '.join(sorted(agent_id_keys))}" + ) + # Validate model aliases if present if "model_aliases" in config_yaml: model_aliases = config_yaml["model_aliases"] diff --git a/config/plano_config_schema.yaml b/config/plano_config_schema.yaml index cd736eb6..1f763577 100644 --- a/config/plano_config_schema.yaml +++ b/config/plano_config_schema.yaml @@ -93,6 +93,10 @@ properties: required: - id - description + filter_chain: + type: array + items: + type: string additionalProperties: false required: - type diff --git a/crates/brightstaff/src/handlers/agent_selector.rs b/crates/brightstaff/src/handlers/agent_selector.rs index faa734ee..a555646d 100644 --- a/crates/brightstaff/src/handlers/agent_selector.rs +++ b/crates/brightstaff/src/handlers/agent_selector.rs @@ -194,6 +194,7 @@ mod tests { Listener { name: name.to_string(), agents: Some(agents), + filter_chain: None, port: 8080, router: None, } diff --git a/crates/brightstaff/src/handlers/integration_tests.rs b/crates/brightstaff/src/handlers/integration_tests.rs index 70eaacd7..65ff9546 100644 --- a/crates/brightstaff/src/handlers/integration_tests.rs +++ b/crates/brightstaff/src/handlers/integration_tests.rs @@ -73,6 +73,7 @@ mod tests { let listener = Listener { name: "test-listener".to_string(), agents: Some(vec![agent_pipeline.clone()]), + filter_chain: None, port: 8080, router: None, }; diff --git a/crates/brightstaff/src/handlers/llm.rs b/crates/brightstaff/src/handlers/llm.rs index 435fb6f5..572ec368 100644 --- a/crates/brightstaff/src/handlers/llm.rs +++ b/crates/brightstaff/src/handlers/llm.rs @@ -1,5 +1,5 @@ use bytes::Bytes; -use common::configuration::ModelAlias; +use common::configuration::{Agent, AgentFilterChain, Listener, ModelAlias}; use common::consts::{ ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER, REQUEST_ID_HEADER, TRACE_PARENT_HEADER, }; @@ -19,6 +19,8 @@ use std::sync::Arc; use tokio::sync::RwLock; use tracing::{debug, info, info_span, warn, Instrument}; +use super::pipeline_processor::PipelineProcessor; + use crate::handlers::router_chat::router_chat_get_upstream_model; use crate::handlers::utils::{ create_streaming_response, truncate_message, ObservableStreamProcessor, @@ -36,6 +38,7 @@ fn full>(chunk: T) -> BoxBody { .boxed() } +#[allow(clippy::too_many_arguments)] pub async fn llm_chat( request: Request, router_service: Arc, @@ -43,6 +46,8 @@ pub async fn llm_chat( model_aliases: Arc>>, llm_providers: Arc>, state_storage: Option>, + listeners: Arc>>, + agents_list: Arc>>>, ) -> Result>, hyper::Error> { let request_path = request.uri().path().to_string(); let request_headers = request.headers().clone(); @@ -79,6 +84,8 @@ pub async fn llm_chat( request_id, request_path, request_headers, + listeners, + agents_list, ) .instrument(request_span) .await @@ -95,6 +102,8 @@ async fn llm_chat_inner( request_id: String, request_path: String, mut request_headers: hyper::HeaderMap, + listeners: Arc>>, + agents_list: Arc>>>, ) -> Result>, hyper::Error> { // Set service name for LLM operations set_service_name(operation_component::LLM); @@ -235,6 +244,96 @@ async fn llm_chat_inner( debug!("removed plano_preference_config from metadata"); } + // === Filter chain processing for model listeners === + // Check if any model listener (no agents) has a filter_chain configured + { + let listeners_guard = listeners.read().await; + let filter_chain: Option> = listeners_guard + .iter() + .find(|l| l.agents.is_none() && l.filter_chain.is_some()) + .and_then(|l| l.filter_chain.clone()); + + if let Some(ref fc) = filter_chain { + if !fc.is_empty() { + debug!(filter_chain = ?fc, "processing model listener filter chain"); + + // Build agent map from agents_list + let agents_guard = agents_list.read().await; + let agent_map: HashMap = agents_guard + .as_ref() + .map(|agents| { + agents + .iter() + .map(|a| (a.id.clone(), a.clone())) + .collect() + }) + .unwrap_or_default(); + + // Create a temporary AgentFilterChain to reuse PipelineProcessor + let temp_filter_chain = AgentFilterChain { + id: "model_listener".to_string(), + default: None, + description: None, + filter_chain: Some(fc.clone()), + }; + + let mut pipeline_processor = PipelineProcessor::default(); + let messages = client_request.get_messages(); + match pipeline_processor + .process_filter_chain( + &messages, + &temp_filter_chain, + &agent_map, + &request_headers, + ) + .await + { + Ok(filtered_messages) => { + client_request.set_messages(&filtered_messages); + info!( + original_count = messages.len(), + filtered_count = filtered_messages.len(), + "filter chain processed successfully" + ); + } + Err(super::pipeline_processor::PipelineError::ClientError { + agent, + status, + body, + }) => { + warn!( + agent = %agent, + status = %status, + body = %body, + "client error from filter chain" + ); + let error_json = serde_json::json!({ + "error": "FilterChainError", + "agent": agent, + "status": status, + "agent_response": body + }); + let mut error_response = Response::new(full(error_json.to_string())); + *error_response.status_mut() = + StatusCode::from_u16(status).unwrap_or(StatusCode::BAD_REQUEST); + error_response.headers_mut().insert( + hyper::header::CONTENT_TYPE, + "application/json".parse().unwrap(), + ); + return Ok(error_response); + } + Err(err) => { + warn!(error = %err, "filter chain processing failed"); + let err_msg = format!("Filter chain processing failed: {}", err); + let mut internal_error = Response::new(full(err_msg)); + *internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; + return Ok(internal_error); + } + } + } + } + } + // === v1/responses state management: Determine upstream API and combine input if needed === // Do this BEFORE routing since routing consumes the request // Only process state if state_storage is configured diff --git a/crates/brightstaff/src/main.rs b/crates/brightstaff/src/main.rs index 87deda6a..6123a91c 100644 --- a/crates/brightstaff/src/main.rs +++ b/crates/brightstaff/src/main.rs @@ -221,6 +221,8 @@ async fn main() -> Result<(), Box> { model_aliases, llm_providers, state_storage, + listeners, + agents_list, ) .with_context(parent_cx) .await diff --git a/crates/common/src/configuration.rs b/crates/common/src/configuration.rs index 0a683b8b..dfa40d71 100644 --- a/crates/common/src/configuration.rs +++ b/crates/common/src/configuration.rs @@ -41,6 +41,7 @@ pub struct Listener { pub name: String, pub router: Option, pub agents: Option>, + pub filter_chain: Option>, pub port: u16, } diff --git a/docs/source/resources/includes/plano_config_full_reference_rendered.yaml b/docs/source/resources/includes/plano_config_full_reference_rendered.yaml index abd909a0..68505e83 100644 --- a/docs/source/resources/includes/plano_config_full_reference_rendered.yaml +++ b/docs/source/resources/includes/plano_config_full_reference_rendered.yaml @@ -65,6 +65,8 @@ listeners: port: 443 protocol: https provider_interface: openai + filter_chain: + - input_guards name: model_1 port: 12000 type: model