consolidate model filter chain into single struct, move validation to config generator

This commit is contained in:
Adil Hafeez 2026-03-13 13:38:40 -07:00
parent e41e9e1cf4
commit edf782f07a
No known key found for this signature in database
GPG key ID: 9B18EF7691369645
4 changed files with 54 additions and 54 deletions

View file

@ -426,6 +426,13 @@ def validate_and_render_schema():
"Please provide model_providers either under listeners or at root level, not both. Currently we don't support multiple listeners with model_providers"
)
# Validate at most one model listener
model_listeners = [l for l in listeners if l.get("type") == "model"]
if len(model_listeners) > 1:
raise Exception(
f"Only one model listener is allowed, found {len(model_listeners)}"
)
# Validate filter_chain IDs on listeners reference valid agent/filter IDs
for listener in listeners:
listener_filter_chain = listener.get("filter_chain", [])

View file

@ -1,5 +1,5 @@
use bytes::Bytes;
use common::configuration::{Agent, AgentFilterChain, ModelAlias, SpanAttributes};
use common::configuration::{AgentFilterChain, ModelAlias, ModelFilterChain, SpanAttributes};
use common::consts::{
ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER, REQUEST_ID_HEADER, TRACE_PARENT_HEADER,
};
@ -45,8 +45,7 @@ pub async fn llm_chat(
llm_providers: Arc<RwLock<LlmProviders>>,
span_attributes: Arc<Option<SpanAttributes>>,
state_storage: Option<Arc<dyn StateStorage>>,
filter_chain: Arc<Option<Vec<String>>>,
filter_agents: Arc<HashMap<String, Agent>>,
model_filter_chain: Arc<Option<ModelFilterChain>>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
let request_path = request.uri().path().to_string();
let request_headers = request.headers().clone();
@ -86,8 +85,7 @@ pub async fn llm_chat(
request_id,
request_path,
request_headers,
filter_chain,
filter_agents,
model_filter_chain,
)
.instrument(request_span)
.await
@ -105,8 +103,7 @@ async fn llm_chat_inner(
request_id: String,
request_path: String,
mut request_headers: hyper::HeaderMap,
filter_chain: Arc<Option<Vec<String>>>,
filter_agents: Arc<HashMap<String, Agent>>,
model_filter_chain: Arc<Option<ModelFilterChain>>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
// Set service name for LLM operations
set_service_name(operation_component::LLM);
@ -261,29 +258,27 @@ async fn llm_chat_inner(
}
// === Filter chain processing for model listener ===
{
if let Some(ref fc) = *filter_chain {
if !fc.is_empty() {
debug!(filter_chain = ?fc, "processing model listener filter chain");
if let Some(ref mfc) = *model_filter_chain {
{
debug!(filter_ids = ?mfc.filter_ids, "processing model listener filter chain");
// Create a temporary AgentFilterChain to reuse PipelineProcessor
let temp_filter_chain = AgentFilterChain {
id: "model_listener".to_string(),
default: None,
description: None,
filter_chain: Some(fc.clone()),
};
let temp_filter_chain = AgentFilterChain {
id: "model_listener".to_string(),
default: None,
description: None,
filter_chain: Some(mfc.filter_ids.clone()),
};
let mut pipeline_processor = PipelineProcessor::default();
let messages = client_request.get_messages();
match pipeline_processor
.process_filter_chain(
&messages,
&temp_filter_chain,
&filter_agents,
&request_headers,
)
.await
let mut pipeline_processor = PipelineProcessor::default();
let messages = client_request.get_messages();
match pipeline_processor
.process_filter_chain(
&messages,
&temp_filter_chain,
&mfc.agents,
&request_headers,
)
.await
{
Ok(filtered_messages) => {
client_request.set_messages(&filtered_messages);
@ -326,7 +321,6 @@ async fn llm_chat_inner(
*internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
return Ok(internal_error);
}
}
}
}
}

View file

@ -10,7 +10,7 @@ use brightstaff::state::postgresql::PostgreSQLConversationStorage;
use brightstaff::state::StateStorage;
use brightstaff::utils::tracing::init_tracer;
use bytes::Bytes;
use common::configuration::{Agent, Configuration, ListenerType};
use common::configuration::{Agent, Configuration, ListenerType, ModelFilterChain};
use common::consts::{
CHAT_COMPLETIONS_PATH, MESSAGES_PATH, OPENAI_RESPONSES_API_PATH, PLANO_ORCHESTRATOR_MODEL_NAME,
};
@ -94,32 +94,26 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let combined_agents_filters_list = Arc::new(RwLock::new(Some(all_agents)));
// Resolve model listener filter chain and agents at startup
let model_listener_count = plano_config
.listeners
.iter()
.filter(|l| l.listener_type == ListenerType::Model)
.count();
assert!(
model_listener_count <= 1,
"only one model listener is allowed, found {}",
model_listener_count
);
let model_listener = plano_config
.listeners
.iter()
.find(|l| l.listener_type == ListenerType::Model);
let model_filter_chain: Arc<Option<Vec<String>>> =
Arc::new(model_listener.and_then(|l| l.filter_chain.clone()));
let model_filter_agents: Arc<HashMap<String, Agent>> = Arc::new(
model_filter_chain
.as_ref()
.as_ref()
let model_filter_chain: Arc<Option<ModelFilterChain>> = Arc::new(
model_listener
.and_then(|l| l.filter_chain.clone())
.filter(|fc| !fc.is_empty())
.map(|fc| {
fc.iter()
.filter_map(|id| global_agent_map.get(id).map(|a| (id.clone(), a.clone())))
.collect()
})
.unwrap_or_default(),
let agents = fc
.iter()
.filter_map(|id| {
global_agent_map.get(id).map(|a| (id.clone(), a.clone()))
})
.collect();
ModelFilterChain {
filter_ids: fc,
agents,
}
}),
);
let listeners = Arc::new(RwLock::new(plano_config.listeners.clone()));
let llm_provider_url =
@ -216,7 +210,6 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let llm_providers = llm_providers.clone();
let agents_list = combined_agents_filters_list.clone();
let model_filter_chain = model_filter_chain.clone();
let model_filter_agents = model_filter_agents.clone();
let listeners = listeners.clone();
let span_attributes = span_attributes.clone();
let state_storage = state_storage.clone();
@ -229,7 +222,6 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let model_aliases = Arc::clone(&model_aliases);
let agents_list = agents_list.clone();
let model_filter_chain = model_filter_chain.clone();
let model_filter_agents = model_filter_agents.clone();
let listeners = listeners.clone();
let span_attributes = span_attributes.clone();
let state_storage = state_storage.clone();
@ -289,7 +281,6 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
span_attributes,
state_storage,
model_filter_chain,
model_filter_agents,
)
.with_context(parent_cx)
.await

View file

@ -36,6 +36,14 @@ pub struct AgentFilterChain {
pub filter_chain: Option<Vec<String>>,
}
/// Resolved filter chain for a model listener: the ordered filter IDs
/// together with the agent definitions they reference.
///
/// This type bundles what was previously passed around as two separate
/// values (a list of filter IDs and a map of agents) so request handlers
/// receive a single, already-resolved value.
#[derive(Debug, Clone)]
pub struct ModelFilterChain {
/// Filter IDs in the order they were listed on the listener's
/// `filter_chain`.
// NOTE(review): the startup code that builds this value only does so for a
// non-empty chain, so this is presumably never empty — confirm before
// relying on it, since the type itself does not enforce it.
pub filter_ids: Vec<String>,
/// Agent definitions keyed by filter ID.
// NOTE(review): the startup code populates this by looking each ID up in
// the global agent map and skipping IDs that are not found, so an entry
// may be missing for an ID present in `filter_ids` — verify that upstream
// config validation rules this out.
pub agents: HashMap<String, Agent>,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum ListenerType {