Validate model listener filter references before serving traffic (#947)

* Validate output filter references

* ci: trigger workflows for org member
This commit is contained in:
mukeshbaphna 2026-05-19 13:53:41 -07:00 committed by GitHub
parent 5a4487fc6e
commit 241a181d3a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 160 additions and 26 deletions

View file

@ -142,25 +142,19 @@ async fn init_app_state(
.listeners
.iter()
.find(|l| l.listener_type == ListenerType::Model);
let resolve_chain = |filter_ids: Option<Vec<String>>| -> Option<ResolvedFilterChain> {
filter_ids.map(|ids| {
let agents = ids
.iter()
.filter_map(|id| {
global_agent_map
.get(id)
.map(|a: &Agent| (id.clone(), a.clone()))
})
.collect();
ResolvedFilterChain {
filter_ids: ids,
agents,
}
})
};
let filter_pipeline = Arc::new(FilterPipeline {
input: resolve_chain(model_listener.and_then(|l| l.input_filters.clone())),
output: resolve_chain(model_listener.and_then(|l| l.output_filters.clone())),
input: resolve_filter_chain(
"input_filters",
model_listener.and_then(|l| l.input_filters.clone()),
&global_agent_map,
)
.map_err(|e| format!("failed to resolve model listener input filters: {e}"))?,
output: resolve_filter_chain(
"output_filters",
model_listener.and_then(|l| l.output_filters.clone()),
&global_agent_map,
)
.map_err(|e| format!("failed to resolve model listener output filters: {e}"))?,
});
let overrides = config.overrides.clone().unwrap_or_default();
@ -350,6 +344,29 @@ async fn init_app_state(
})
}
fn resolve_filter_chain(
field_name: &str,
filter_ids: Option<Vec<String>>,
global_agent_map: &HashMap<String, Agent>,
) -> Result<Option<ResolvedFilterChain>, String> {
let Some(ids) = filter_ids else {
return Ok(None);
};
let mut agents = HashMap::new();
for id in &ids {
let agent = global_agent_map
.get(id)
.ok_or_else(|| format!("{field_name} id '{id}' is not defined in agents or filters"))?;
agents.insert(id.clone(), agent.clone());
}
Ok(Some(ResolvedFilterChain {
filter_ids: ids,
agents,
}))
}
/// Initialize the conversation state storage backend (if configured).
async fn init_state_storage(
config: &Configuration,
@ -588,3 +605,63 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let state = Arc::new(init_app_state(&config).await?);
run_server(state).await
}
#[cfg(test)]
mod tests {
use super::*;
fn test_agent(id: &str) -> Agent {
Agent {
id: id.to_string(),
transport: None,
tool: None,
url: "http://localhost:10500".to_string(),
agent_type: Some("http".to_string()),
}
}
#[test]
fn resolve_filter_chain_keeps_valid_filter_references() {
let agent = test_agent("output_guard");
let global_agent_map = HashMap::from([(agent.id.clone(), agent)]);
let resolved = resolve_filter_chain(
"output_filters",
Some(vec!["output_guard".to_string()]),
&global_agent_map,
)
.expect("filter chain should resolve")
.expect("filter chain should be present");
assert_eq!(resolved.filter_ids, vec!["output_guard".to_string()]);
assert!(resolved.agents.contains_key("output_guard"));
}
#[test]
fn resolve_filter_chain_errors_on_missing_output_filter_reference() {
let global_agent_map = HashMap::new();
let err = resolve_filter_chain(
"output_filters",
Some(vec!["missing_output_guard".to_string()]),
&global_agent_map,
)
.expect_err("missing output filter should fail closed");
assert!(err.contains("output_filters id 'missing_output_guard'"));
}
#[test]
fn resolve_filter_chain_errors_on_missing_input_filter_reference() {
let global_agent_map = HashMap::new();
let err = resolve_filter_chain(
"input_filters",
Some(vec!["missing_input_guard".to_string()]),
&global_agent_map,
)
.expect_err("missing input filter should fail closed");
assert!(err.contains("input_filters id 'missing_input_guard'"));
}
}