mirror of
https://github.com/katanemo/plano.git
synced 2026-05-21 13:55:15 +02:00
Validate model listener filter references before serving traffic (#947)
* Validate output filter references * ci: trigger workflows for org member
This commit is contained in:
parent
5a4487fc6e
commit
241a181d3a
3 changed files with 160 additions and 26 deletions
|
|
@ -562,15 +562,15 @@ def validate_and_render_schema():
|
||||||
"Please provide model_providers either under listeners or at root level, not both. Currently we don't support multiple listeners with model_providers"
|
"Please provide model_providers either under listeners or at root level, not both. Currently we don't support multiple listeners with model_providers"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Validate input_filters IDs on listeners reference valid agent/filter IDs
|
# Validate listener-level filter IDs reference valid agent/filter IDs.
|
||||||
for listener in listeners:
|
for listener in listeners:
|
||||||
listener_input_filters = listener.get("input_filters", [])
|
for filter_field in ("input_filters", "output_filters"):
|
||||||
for fc_id in listener_input_filters:
|
for fc_id in listener.get(filter_field, []):
|
||||||
if fc_id not in agent_id_keys:
|
if fc_id not in agent_id_keys:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"Listener '{listener.get('name', 'unknown')}' references input_filters id '{fc_id}' "
|
f"Listener '{listener.get('name', 'unknown')}' references {filter_field} id '{fc_id}' "
|
||||||
f"which is not defined in agents or filters. Available ids: {', '.join(sorted(agent_id_keys))}"
|
f"which is not defined in agents or filters. Available ids: {', '.join(sorted(agent_id_keys))}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Validate model aliases if present
|
# Validate model aliases if present
|
||||||
if "model_aliases" in config_yaml:
|
if "model_aliases" in config_yaml:
|
||||||
|
|
|
||||||
|
|
@ -327,6 +327,63 @@ routing_preferences:
|
||||||
tracing:
|
tracing:
|
||||||
random_sampling: 100
|
random_sampling: 100
|
||||||
|
|
||||||
|
""",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "unknown_listener_output_filter",
|
||||||
|
"expected_error": "references output_filters id 'missing_output_guard'",
|
||||||
|
"plano_config": """
|
||||||
|
version: v0.4.0
|
||||||
|
|
||||||
|
filters:
|
||||||
|
- id: input_guard
|
||||||
|
url: http://localhost:10500
|
||||||
|
type: http
|
||||||
|
|
||||||
|
listeners:
|
||||||
|
- name: llm
|
||||||
|
type: model
|
||||||
|
port: 12000
|
||||||
|
input_filters:
|
||||||
|
- input_guard
|
||||||
|
output_filters:
|
||||||
|
- missing_output_guard
|
||||||
|
|
||||||
|
model_providers:
|
||||||
|
- model: openai/gpt-4o-mini
|
||||||
|
access_key: $OPENAI_API_KEY
|
||||||
|
default: true
|
||||||
|
|
||||||
|
""",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "valid_listener_output_filter",
|
||||||
|
"expected_error": None,
|
||||||
|
"plano_config": """
|
||||||
|
version: v0.4.0
|
||||||
|
|
||||||
|
filters:
|
||||||
|
- id: input_guard
|
||||||
|
url: http://localhost:10500
|
||||||
|
type: http
|
||||||
|
- id: output_guard
|
||||||
|
url: http://localhost:10501
|
||||||
|
type: http
|
||||||
|
|
||||||
|
listeners:
|
||||||
|
- name: llm
|
||||||
|
type: model
|
||||||
|
port: 12000
|
||||||
|
input_filters:
|
||||||
|
- input_guard
|
||||||
|
output_filters:
|
||||||
|
- output_guard
|
||||||
|
|
||||||
|
model_providers:
|
||||||
|
- model: openai/gpt-4o-mini
|
||||||
|
access_key: $OPENAI_API_KEY
|
||||||
|
default: true
|
||||||
|
|
||||||
""",
|
""",
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -142,25 +142,19 @@ async fn init_app_state(
|
||||||
.listeners
|
.listeners
|
||||||
.iter()
|
.iter()
|
||||||
.find(|l| l.listener_type == ListenerType::Model);
|
.find(|l| l.listener_type == ListenerType::Model);
|
||||||
let resolve_chain = |filter_ids: Option<Vec<String>>| -> Option<ResolvedFilterChain> {
|
|
||||||
filter_ids.map(|ids| {
|
|
||||||
let agents = ids
|
|
||||||
.iter()
|
|
||||||
.filter_map(|id| {
|
|
||||||
global_agent_map
|
|
||||||
.get(id)
|
|
||||||
.map(|a: &Agent| (id.clone(), a.clone()))
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
ResolvedFilterChain {
|
|
||||||
filter_ids: ids,
|
|
||||||
agents,
|
|
||||||
}
|
|
||||||
})
|
|
||||||
};
|
|
||||||
let filter_pipeline = Arc::new(FilterPipeline {
|
let filter_pipeline = Arc::new(FilterPipeline {
|
||||||
input: resolve_chain(model_listener.and_then(|l| l.input_filters.clone())),
|
input: resolve_filter_chain(
|
||||||
output: resolve_chain(model_listener.and_then(|l| l.output_filters.clone())),
|
"input_filters",
|
||||||
|
model_listener.and_then(|l| l.input_filters.clone()),
|
||||||
|
&global_agent_map,
|
||||||
|
)
|
||||||
|
.map_err(|e| format!("failed to resolve model listener input filters: {e}"))?,
|
||||||
|
output: resolve_filter_chain(
|
||||||
|
"output_filters",
|
||||||
|
model_listener.and_then(|l| l.output_filters.clone()),
|
||||||
|
&global_agent_map,
|
||||||
|
)
|
||||||
|
.map_err(|e| format!("failed to resolve model listener output filters: {e}"))?,
|
||||||
});
|
});
|
||||||
|
|
||||||
let overrides = config.overrides.clone().unwrap_or_default();
|
let overrides = config.overrides.clone().unwrap_or_default();
|
||||||
|
|
@ -350,6 +344,29 @@ async fn init_app_state(
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn resolve_filter_chain(
|
||||||
|
field_name: &str,
|
||||||
|
filter_ids: Option<Vec<String>>,
|
||||||
|
global_agent_map: &HashMap<String, Agent>,
|
||||||
|
) -> Result<Option<ResolvedFilterChain>, String> {
|
||||||
|
let Some(ids) = filter_ids else {
|
||||||
|
return Ok(None);
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut agents = HashMap::new();
|
||||||
|
for id in &ids {
|
||||||
|
let agent = global_agent_map
|
||||||
|
.get(id)
|
||||||
|
.ok_or_else(|| format!("{field_name} id '{id}' is not defined in agents or filters"))?;
|
||||||
|
agents.insert(id.clone(), agent.clone());
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(Some(ResolvedFilterChain {
|
||||||
|
filter_ids: ids,
|
||||||
|
agents,
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
/// Initialize the conversation state storage backend (if configured).
|
/// Initialize the conversation state storage backend (if configured).
|
||||||
async fn init_state_storage(
|
async fn init_state_storage(
|
||||||
config: &Configuration,
|
config: &Configuration,
|
||||||
|
|
@ -588,3 +605,63 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||||
let state = Arc::new(init_app_state(&config).await?);
|
let state = Arc::new(init_app_state(&config).await?);
|
||||||
run_server(state).await
|
run_server(state).await
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
fn test_agent(id: &str) -> Agent {
|
||||||
|
Agent {
|
||||||
|
id: id.to_string(),
|
||||||
|
transport: None,
|
||||||
|
tool: None,
|
||||||
|
url: "http://localhost:10500".to_string(),
|
||||||
|
agent_type: Some("http".to_string()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn resolve_filter_chain_keeps_valid_filter_references() {
|
||||||
|
let agent = test_agent("output_guard");
|
||||||
|
let global_agent_map = HashMap::from([(agent.id.clone(), agent)]);
|
||||||
|
|
||||||
|
let resolved = resolve_filter_chain(
|
||||||
|
"output_filters",
|
||||||
|
Some(vec!["output_guard".to_string()]),
|
||||||
|
&global_agent_map,
|
||||||
|
)
|
||||||
|
.expect("filter chain should resolve")
|
||||||
|
.expect("filter chain should be present");
|
||||||
|
|
||||||
|
assert_eq!(resolved.filter_ids, vec!["output_guard".to_string()]);
|
||||||
|
assert!(resolved.agents.contains_key("output_guard"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn resolve_filter_chain_errors_on_missing_output_filter_reference() {
|
||||||
|
let global_agent_map = HashMap::new();
|
||||||
|
|
||||||
|
let err = resolve_filter_chain(
|
||||||
|
"output_filters",
|
||||||
|
Some(vec!["missing_output_guard".to_string()]),
|
||||||
|
&global_agent_map,
|
||||||
|
)
|
||||||
|
.expect_err("missing output filter should fail closed");
|
||||||
|
|
||||||
|
assert!(err.contains("output_filters id 'missing_output_guard'"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn resolve_filter_chain_errors_on_missing_input_filter_reference() {
|
||||||
|
let global_agent_map = HashMap::new();
|
||||||
|
|
||||||
|
let err = resolve_filter_chain(
|
||||||
|
"input_filters",
|
||||||
|
Some(vec!["missing_input_guard".to_string()]),
|
||||||
|
&global_agent_map,
|
||||||
|
)
|
||||||
|
.expect_err("missing input filter should fail closed");
|
||||||
|
|
||||||
|
assert!(err.contains("input_filters id 'missing_input_guard'"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue