diff --git a/cli/planoai/config_generator.py b/cli/planoai/config_generator.py index cb07767e..c46b8917 100644 --- a/cli/planoai/config_generator.py +++ b/cli/planoai/config_generator.py @@ -562,15 +562,15 @@ def validate_and_render_schema(): "Please provide model_providers either under listeners or at root level, not both. Currently we don't support multiple listeners with model_providers" ) - # Validate input_filters IDs on listeners reference valid agent/filter IDs + # Validate listener-level filter IDs reference valid agent/filter IDs. for listener in listeners: - listener_input_filters = listener.get("input_filters", []) - for fc_id in listener_input_filters: - if fc_id not in agent_id_keys: - raise Exception( - f"Listener '{listener.get('name', 'unknown')}' references input_filters id '{fc_id}' " - f"which is not defined in agents or filters. Available ids: {', '.join(sorted(agent_id_keys))}" - ) + for filter_field in ("input_filters", "output_filters"): + for fc_id in listener.get(filter_field, []): + if fc_id not in agent_id_keys: + raise Exception( + f"Listener '{listener.get('name', 'unknown')}' references {filter_field} id '{fc_id}' " + f"which is not defined in agents or filters. Available ids: {', '.join(sorted(agent_id_keys))}" + ) # Validate model aliases if present if "model_aliases" in config_yaml: diff --git a/cli/test/test_config_generator.py b/cli/test/test_config_generator.py index 77b5b480..78c12e93 100644 --- a/cli/test/test_config_generator.py +++ b/cli/test/test_config_generator.py @@ -327,6 +327,63 @@ routing_preferences: tracing: random_sampling: 100 +""", + }, + { + "id": "unknown_listener_output_filter", + "expected_error": "references output_filters id 'missing_output_guard'", + "plano_config": """ +version: v0.4.0 + +filters: + - id: input_guard + url: http://localhost:10500 + type: http + +listeners: + - name: llm + type: model + port: 12000 + input_filters: + - input_guard + output_filters: + - missing_output_guard + +model_providers: + - model: openai/gpt-4o-mini + access_key: $OPENAI_API_KEY + default: true + +""", + }, + { + "id": "valid_listener_output_filter", + "expected_error": None, + "plano_config": """ +version: v0.4.0 + +filters: + - id: input_guard + url: http://localhost:10500 + type: http + - id: output_guard + url: http://localhost:10501 + type: http + +listeners: + - name: llm + type: model + port: 12000 + input_filters: + - input_guard + output_filters: + - output_guard + +model_providers: + - model: openai/gpt-4o-mini + access_key: $OPENAI_API_KEY + default: true + """, }, ] diff --git a/crates/brightstaff/src/main.rs b/crates/brightstaff/src/main.rs index b1e17e42..90ed84c3 100644 --- a/crates/brightstaff/src/main.rs +++ b/crates/brightstaff/src/main.rs @@ -142,25 +142,19 @@ async fn init_app_state( .listeners .iter() .find(|l| l.listener_type == ListenerType::Model); - let resolve_chain = |filter_ids: Option>| -> Option { - filter_ids.map(|ids| { - let agents = ids - .iter() - .filter_map(|id| { - global_agent_map - .get(id) - .map(|a: &Agent| (id.clone(), a.clone())) - }) - .collect(); - ResolvedFilterChain { - filter_ids: ids, - agents, - } - }) - }; let filter_pipeline = Arc::new(FilterPipeline { - input: resolve_chain(model_listener.and_then(|l| l.input_filters.clone())), - output: resolve_chain(model_listener.and_then(|l| l.output_filters.clone())), + input: resolve_filter_chain( + "input_filters", + model_listener.and_then(|l| l.input_filters.clone()), + &global_agent_map, + ) + .map_err(|e| format!("failed to resolve model listener input filters: {e}"))?, + output: resolve_filter_chain( + "output_filters", + model_listener.and_then(|l| l.output_filters.clone()), + &global_agent_map, + ) + .map_err(|e| format!("failed to resolve model listener output filters: {e}"))?, }); let overrides = config.overrides.clone().unwrap_or_default(); @@ -350,6 +344,29 @@ async fn init_app_state( }) } +fn resolve_filter_chain( + field_name: &str, + filter_ids: Option>, + global_agent_map: &HashMap, +) -> Result, String> { + let Some(ids) = filter_ids else { + return Ok(None); + }; + + let mut agents = HashMap::new(); + for id in &ids { + let agent = global_agent_map + .get(id) + .ok_or_else(|| format!("{field_name} id '{id}' is not defined in agents or filters"))?; + agents.insert(id.clone(), agent.clone()); + } + + Ok(Some(ResolvedFilterChain { + filter_ids: ids, + agents, + })) +} + /// Initialize the conversation state storage backend (if configured). async fn init_state_storage( config: &Configuration, @@ -588,3 +605,63 @@ async fn main() -> Result<(), Box> { let state = Arc::new(init_app_state(&config).await?); run_server(state).await } + +#[cfg(test)] +mod tests { + use super::*; + + fn test_agent(id: &str) -> Agent { + Agent { + id: id.to_string(), + transport: None, + tool: None, + url: "http://localhost:10500".to_string(), + agent_type: Some("http".to_string()), + } + } + + #[test] + fn resolve_filter_chain_keeps_valid_filter_references() { + let agent = test_agent("output_guard"); + let global_agent_map = HashMap::from([(agent.id.clone(), agent)]); + + let resolved = resolve_filter_chain( + "output_filters", + Some(vec!["output_guard".to_string()]), + &global_agent_map, + ) + .expect("filter chain should resolve") + .expect("filter chain should be present"); + + assert_eq!(resolved.filter_ids, vec!["output_guard".to_string()]); + assert!(resolved.agents.contains_key("output_guard")); + } + + #[test] + fn resolve_filter_chain_errors_on_missing_output_filter_reference() { + let global_agent_map = HashMap::new(); + + let err = resolve_filter_chain( + "output_filters", + Some(vec!["missing_output_guard".to_string()]), + &global_agent_map, + ) + .expect_err("missing output filter should fail closed"); + + assert!(err.contains("output_filters id 'missing_output_guard'")); + } + + #[test] + fn resolve_filter_chain_errors_on_missing_input_filter_reference() { + let global_agent_map = HashMap::new(); + + let err = resolve_filter_chain( + "input_filters", + Some(vec!["missing_input_guard".to_string()]), + &global_agent_map, + ) + .expect_err("missing input filter should fail closed"); + + assert!(err.contains("input_filters id 'missing_input_guard'")); + } +}