support filter chains on model listeners

Adil Hafeez 2026-02-19 04:34:26 +00:00
parent 21b043e6d0
commit 8136d7d6ab
8 changed files with 121 additions and 1 deletions

View file

@@ -194,6 +194,7 @@ mod tests {
        Listener {
            name: name.to_string(),
            agents: Some(agents),
            filter_chain: None,
            port: 8080,
            router: None,
        }
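
For comparison, a model listener that opts into the feature would populate the new field with an ordered list of filter ids. A minimal sketch; the filter names here are illustrative, only the field's Option<Vec<String>> shape comes from this diff:

        Listener {
            name: "model-listener".to_string(),
            agents: None,
            filter_chain: Some(vec!["guardrails".to_string(), "pii_redactor".to_string()]),
            port: 8080,
            router: None,
        }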

View file

@@ -73,6 +73,7 @@ mod tests {
        let listener = Listener {
            name: "test-listener".to_string(),
            agents: Some(vec![agent_pipeline.clone()]),
            filter_chain: None,
            port: 8080,
            router: None,
        };

View file

@@ -1,5 +1,5 @@
use bytes::Bytes;
use common::configuration::ModelAlias;
use common::configuration::{Agent, AgentFilterChain, Listener, ModelAlias};
use common::consts::{
    ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER, REQUEST_ID_HEADER, TRACE_PARENT_HEADER,
};
@@ -19,6 +19,8 @@ use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::{debug, info, info_span, warn, Instrument};
use super::pipeline_processor::PipelineProcessor;
use crate::handlers::router_chat::router_chat_get_upstream_model;
use crate::handlers::utils::{
    create_streaming_response, truncate_message, ObservableStreamProcessor,
@@ -36,6 +38,7 @@ fn full<T: Into<Bytes>>(chunk: T) -> BoxBody<Bytes, hyper::Error> {
        .boxed()
}

#[allow(clippy::too_many_arguments)]
pub async fn llm_chat(
    request: Request<hyper::body::Incoming>,
    router_service: Arc<RouterService>,
@@ -43,6 +46,8 @@ pub async fn llm_chat(
    model_aliases: Arc<Option<HashMap<String, ModelAlias>>>,
    llm_providers: Arc<RwLock<LlmProviders>>,
    state_storage: Option<Arc<dyn StateStorage>>,
    listeners: Arc<RwLock<Vec<Listener>>>,
    agents_list: Arc<RwLock<Option<Vec<Agent>>>>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
    let request_path = request.uri().path().to_string();
    let request_headers = request.headers().clone();
@@ -79,6 +84,8 @@ pub async fn llm_chat(
        request_id,
        request_path,
        request_headers,
        listeners,
        agents_list,
    )
    .instrument(request_span)
    .await
@@ -95,6 +102,8 @@ async fn llm_chat_inner(
    request_id: String,
    request_path: String,
    mut request_headers: hyper::HeaderMap,
    listeners: Arc<RwLock<Vec<Listener>>>,
    agents_list: Arc<RwLock<Option<Vec<Agent>>>>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
    // Set service name for LLM operations
    set_service_name(operation_component::LLM);
@@ -235,6 +244,96 @@ async fn llm_chat_inner(
        debug!("removed plano_preference_config from metadata");
    }

    // === Filter chain processing for model listeners ===
    // Check if any model listener (no agents) has a filter_chain configured
    {
        let listeners_guard = listeners.read().await;
        let filter_chain: Option<Vec<String>> = listeners_guard
            .iter()
            .find(|l| l.agents.is_none() && l.filter_chain.is_some())
            .and_then(|l| l.filter_chain.clone());

        if let Some(ref fc) = filter_chain {
            if !fc.is_empty() {
                debug!(filter_chain = ?fc, "processing model listener filter chain");

                // Build agent map from agents_list
                let agents_guard = agents_list.read().await;
                let agent_map: HashMap<String, Agent> = agents_guard
                    .as_ref()
                    .map(|agents| {
                        agents
                            .iter()
                            .map(|a| (a.id.clone(), a.clone()))
                            .collect()
                    })
                    .unwrap_or_default();

                // Create a temporary AgentFilterChain to reuse PipelineProcessor
                let temp_filter_chain = AgentFilterChain {
                    id: "model_listener".to_string(),
                    default: None,
                    description: None,
                    filter_chain: Some(fc.clone()),
                };

                let mut pipeline_processor = PipelineProcessor::default();
                let messages = client_request.get_messages();

                match pipeline_processor
                    .process_filter_chain(
                        &messages,
                        &temp_filter_chain,
                        &agent_map,
                        &request_headers,
                    )
                    .await
                {
                    Ok(filtered_messages) => {
                        client_request.set_messages(&filtered_messages);
                        info!(
                            original_count = messages.len(),
                            filtered_count = filtered_messages.len(),
                            "filter chain processed successfully"
                        );
                    }
                    Err(super::pipeline_processor::PipelineError::ClientError {
                        agent,
                        status,
                        body,
                    }) => {
                        warn!(
                            agent = %agent,
                            status = %status,
                            body = %body,
                            "client error from filter chain"
                        );
                        let error_json = serde_json::json!({
                            "error": "FilterChainError",
                            "agent": agent,
                            "status": status,
                            "agent_response": body
                        });
                        let mut error_response = Response::new(full(error_json.to_string()));
                        *error_response.status_mut() =
                            StatusCode::from_u16(status).unwrap_or(StatusCode::BAD_REQUEST);
                        error_response.headers_mut().insert(
                            hyper::header::CONTENT_TYPE,
                            "application/json".parse().unwrap(),
                        );
                        return Ok(error_response);
                    }
                    Err(err) => {
                        warn!(error = %err, "filter chain processing failed");
                        let err_msg = format!("Filter chain processing failed: {}", err);
                        let mut internal_error = Response::new(full(err_msg));
                        *internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
                        return Ok(internal_error);
                    }
                }
            }
        }
    }

    // === v1/responses state management: Determine upstream API and combine input if needed ===
    // Do this BEFORE routing since routing consumes the request
    // Only process state if state_storage is configured
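
When the pipeline reports a ClientError, the handler above surfaces the agent's verdict directly to the caller. The JSON body a client would receive looks roughly like this; the agent id, status, and response text are illustrative, while the key names come from the serde_json::json! call in the diff:

{
    "error": "FilterChainError",
    "agent": "guardrails",
    "status": 422,
    "agent_response": "request rejected by content filter"
}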

View file

@@ -221,6 +221,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
            model_aliases,
            llm_providers,
            state_storage,
            listeners,
            agents_list,
        )
        .with_context(parent_cx)
        .await
.with_context(parent_cx)
.await