consolidate model filter chain into single struct, move validation to config generator

This commit is contained in:
Adil Hafeez 2026-03-13 13:38:40 -07:00
parent e41e9e1cf4
commit edf782f07a
No known key found for this signature in database
GPG key ID: 9B18EF7691369645
4 changed files with 54 additions and 54 deletions

View file

@ -426,6 +426,13 @@ def validate_and_render_schema():
"Please provide model_providers either under listeners or at root level, not both. Currently we don't support multiple listeners with model_providers" "Please provide model_providers either under listeners or at root level, not both. Currently we don't support multiple listeners with model_providers"
) )
# Validate at most one model listener
model_listeners = [l for l in listeners if l.get("type") == "model"]
if len(model_listeners) > 1:
raise Exception(
f"Only one model listener is allowed, found {len(model_listeners)}"
)
# Validate filter_chain IDs on listeners reference valid agent/filter IDs # Validate filter_chain IDs on listeners reference valid agent/filter IDs
for listener in listeners: for listener in listeners:
listener_filter_chain = listener.get("filter_chain", []) listener_filter_chain = listener.get("filter_chain", [])

View file

@ -1,5 +1,5 @@
use bytes::Bytes; use bytes::Bytes;
use common::configuration::{Agent, AgentFilterChain, ModelAlias, SpanAttributes}; use common::configuration::{AgentFilterChain, ModelAlias, ModelFilterChain, SpanAttributes};
use common::consts::{ use common::consts::{
ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER, REQUEST_ID_HEADER, TRACE_PARENT_HEADER, ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER, REQUEST_ID_HEADER, TRACE_PARENT_HEADER,
}; };
@ -45,8 +45,7 @@ pub async fn llm_chat(
llm_providers: Arc<RwLock<LlmProviders>>, llm_providers: Arc<RwLock<LlmProviders>>,
span_attributes: Arc<Option<SpanAttributes>>, span_attributes: Arc<Option<SpanAttributes>>,
state_storage: Option<Arc<dyn StateStorage>>, state_storage: Option<Arc<dyn StateStorage>>,
filter_chain: Arc<Option<Vec<String>>>, model_filter_chain: Arc<Option<ModelFilterChain>>,
filter_agents: Arc<HashMap<String, Agent>>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> { ) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
let request_path = request.uri().path().to_string(); let request_path = request.uri().path().to_string();
let request_headers = request.headers().clone(); let request_headers = request.headers().clone();
@ -86,8 +85,7 @@ pub async fn llm_chat(
request_id, request_id,
request_path, request_path,
request_headers, request_headers,
filter_chain, model_filter_chain,
filter_agents,
) )
.instrument(request_span) .instrument(request_span)
.await .await
@ -105,8 +103,7 @@ async fn llm_chat_inner(
request_id: String, request_id: String,
request_path: String, request_path: String,
mut request_headers: hyper::HeaderMap, mut request_headers: hyper::HeaderMap,
filter_chain: Arc<Option<Vec<String>>>, model_filter_chain: Arc<Option<ModelFilterChain>>,
filter_agents: Arc<HashMap<String, Agent>>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> { ) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
// Set service name for LLM operations // Set service name for LLM operations
set_service_name(operation_component::LLM); set_service_name(operation_component::LLM);
@ -261,29 +258,27 @@ async fn llm_chat_inner(
} }
// === Filter chain processing for model listener === // === Filter chain processing for model listener ===
{ if let Some(ref mfc) = *model_filter_chain {
if let Some(ref fc) = *filter_chain { {
if !fc.is_empty() { debug!(filter_ids = ?mfc.filter_ids, "processing model listener filter chain");
debug!(filter_chain = ?fc, "processing model listener filter chain");
// Create a temporary AgentFilterChain to reuse PipelineProcessor let temp_filter_chain = AgentFilterChain {
let temp_filter_chain = AgentFilterChain { id: "model_listener".to_string(),
id: "model_listener".to_string(), default: None,
default: None, description: None,
description: None, filter_chain: Some(mfc.filter_ids.clone()),
filter_chain: Some(fc.clone()), };
};
let mut pipeline_processor = PipelineProcessor::default(); let mut pipeline_processor = PipelineProcessor::default();
let messages = client_request.get_messages(); let messages = client_request.get_messages();
match pipeline_processor match pipeline_processor
.process_filter_chain( .process_filter_chain(
&messages, &messages,
&temp_filter_chain, &temp_filter_chain,
&filter_agents, &mfc.agents,
&request_headers, &request_headers,
) )
.await .await
{ {
Ok(filtered_messages) => { Ok(filtered_messages) => {
client_request.set_messages(&filtered_messages); client_request.set_messages(&filtered_messages);
@ -326,7 +321,6 @@ async fn llm_chat_inner(
*internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; *internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
return Ok(internal_error); return Ok(internal_error);
} }
}
} }
} }
} }

View file

@ -10,7 +10,7 @@ use brightstaff::state::postgresql::PostgreSQLConversationStorage;
use brightstaff::state::StateStorage; use brightstaff::state::StateStorage;
use brightstaff::utils::tracing::init_tracer; use brightstaff::utils::tracing::init_tracer;
use bytes::Bytes; use bytes::Bytes;
use common::configuration::{Agent, Configuration, ListenerType}; use common::configuration::{Agent, Configuration, ListenerType, ModelFilterChain};
use common::consts::{ use common::consts::{
CHAT_COMPLETIONS_PATH, MESSAGES_PATH, OPENAI_RESPONSES_API_PATH, PLANO_ORCHESTRATOR_MODEL_NAME, CHAT_COMPLETIONS_PATH, MESSAGES_PATH, OPENAI_RESPONSES_API_PATH, PLANO_ORCHESTRATOR_MODEL_NAME,
}; };
@ -94,32 +94,26 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let combined_agents_filters_list = Arc::new(RwLock::new(Some(all_agents))); let combined_agents_filters_list = Arc::new(RwLock::new(Some(all_agents)));
// Resolve model listener filter chain and agents at startup // Resolve model listener filter chain and agents at startup
let model_listener_count = plano_config
.listeners
.iter()
.filter(|l| l.listener_type == ListenerType::Model)
.count();
assert!(
model_listener_count <= 1,
"only one model listener is allowed, found {}",
model_listener_count
);
let model_listener = plano_config let model_listener = plano_config
.listeners .listeners
.iter() .iter()
.find(|l| l.listener_type == ListenerType::Model); .find(|l| l.listener_type == ListenerType::Model);
let model_filter_chain: Arc<Option<Vec<String>>> = let model_filter_chain: Arc<Option<ModelFilterChain>> = Arc::new(
Arc::new(model_listener.and_then(|l| l.filter_chain.clone())); model_listener
let model_filter_agents: Arc<HashMap<String, Agent>> = Arc::new( .and_then(|l| l.filter_chain.clone())
model_filter_chain .filter(|fc| !fc.is_empty())
.as_ref()
.as_ref()
.map(|fc| { .map(|fc| {
fc.iter() let agents = fc
.filter_map(|id| global_agent_map.get(id).map(|a| (id.clone(), a.clone()))) .iter()
.collect() .filter_map(|id| {
}) global_agent_map.get(id).map(|a| (id.clone(), a.clone()))
.unwrap_or_default(), })
.collect();
ModelFilterChain {
filter_ids: fc,
agents,
}
}),
); );
let listeners = Arc::new(RwLock::new(plano_config.listeners.clone())); let listeners = Arc::new(RwLock::new(plano_config.listeners.clone()));
let llm_provider_url = let llm_provider_url =
@ -216,7 +210,6 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let llm_providers = llm_providers.clone(); let llm_providers = llm_providers.clone();
let agents_list = combined_agents_filters_list.clone(); let agents_list = combined_agents_filters_list.clone();
let model_filter_chain = model_filter_chain.clone(); let model_filter_chain = model_filter_chain.clone();
let model_filter_agents = model_filter_agents.clone();
let listeners = listeners.clone(); let listeners = listeners.clone();
let span_attributes = span_attributes.clone(); let span_attributes = span_attributes.clone();
let state_storage = state_storage.clone(); let state_storage = state_storage.clone();
@ -229,7 +222,6 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let model_aliases = Arc::clone(&model_aliases); let model_aliases = Arc::clone(&model_aliases);
let agents_list = agents_list.clone(); let agents_list = agents_list.clone();
let model_filter_chain = model_filter_chain.clone(); let model_filter_chain = model_filter_chain.clone();
let model_filter_agents = model_filter_agents.clone();
let listeners = listeners.clone(); let listeners = listeners.clone();
let span_attributes = span_attributes.clone(); let span_attributes = span_attributes.clone();
let state_storage = state_storage.clone(); let state_storage = state_storage.clone();
@ -289,7 +281,6 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
span_attributes, span_attributes,
state_storage, state_storage,
model_filter_chain, model_filter_chain,
model_filter_agents,
) )
.with_context(parent_cx) .with_context(parent_cx)
.await .await

View file

@ -36,6 +36,14 @@ pub struct AgentFilterChain {
pub filter_chain: Option<Vec<String>>, pub filter_chain: Option<Vec<String>>,
} }
/// Resolved filter chain for a model listener: the ordered filter IDs
/// together with the agent definitions they reference.
///
/// Bundling both pieces lets callers pass a single value instead of a
/// separate ID list and agent map.
#[derive(Debug, Clone)]
pub struct ModelFilterChain {
/// Filter IDs in the order they should be applied; used as the
/// `filter_chain` of the temporary `AgentFilterChain` when processing.
pub filter_ids: Vec<String>,
/// Agent definitions keyed by filter ID.
/// NOTE(review): the startup code that builds this map skips IDs with no
/// matching agent, so it may hold fewer entries than `filter_ids` — confirm
/// that config validation rejects unknown IDs before this point.
pub agents: HashMap<String, Agent>,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")] #[serde(rename_all = "lowercase")]
pub enum ListenerType { pub enum ListenerType {