refactor filter pipeline: introduce ResolvedFilterChain and FilterPipeline types

- Replace 4 separate filter params with single Arc<FilterPipeline> in llm_chat
- Add ResolvedFilterChain (filter_ids + agents) and FilterPipeline (input + output)
- Rename utils.rs to streaming.rs, extract STREAM_BUFFER_SIZE constant
- Deduplicate output filter logic via Box<dyn StreamProcessor>
- Rename upstream_path param to request_path for consistency

Made-with: Cursor
This commit is contained in:
Adil Hafeez 2026-03-18 16:47:16 -07:00
parent 15c6ce7d64
commit 1605d2caa6
7 changed files with 119 additions and 126 deletions

View file

@ -1,5 +1,5 @@
use bytes::Bytes; use bytes::Bytes;
use common::configuration::{Agent, AgentFilterChain, ModelAlias, SpanAttributes}; use common::configuration::{FilterPipeline, ModelAlias, SpanAttributes};
use common::consts::{ use common::consts::{
ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER, REQUEST_ID_HEADER, TRACE_PARENT_HEADER, ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER, REQUEST_ID_HEADER, TRACE_PARENT_HEADER,
}; };
@ -22,9 +22,9 @@ use tracing::{debug, info, info_span, warn, Instrument};
use super::pipeline_processor::PipelineProcessor; use super::pipeline_processor::PipelineProcessor;
use crate::handlers::router_chat::router_chat_get_upstream_model; use crate::handlers::router_chat::router_chat_get_upstream_model;
use crate::handlers::utils::{ use crate::handlers::streaming::{
create_streaming_response, create_streaming_response_with_output_filter, truncate_message, create_streaming_response, create_streaming_response_with_output_filter, truncate_message,
ObservableStreamProcessor, ObservableStreamProcessor, StreamProcessor,
}; };
use crate::router::llm_router::RouterService; use crate::router::llm_router::RouterService;
use crate::state::response_state_processor::ResponsesStateProcessor; use crate::state::response_state_processor::ResponsesStateProcessor;
@ -46,10 +46,7 @@ pub async fn llm_chat(
llm_providers: Arc<RwLock<LlmProviders>>, llm_providers: Arc<RwLock<LlmProviders>>,
span_attributes: Arc<Option<SpanAttributes>>, span_attributes: Arc<Option<SpanAttributes>>,
state_storage: Option<Arc<dyn StateStorage>>, state_storage: Option<Arc<dyn StateStorage>>,
input_filters: Arc<Option<Vec<String>>>, filter_pipeline: Arc<FilterPipeline>,
input_filter_agents: Arc<HashMap<String, Agent>>,
output_filters: Arc<Option<Vec<String>>>,
output_filter_agents: Arc<HashMap<String, Agent>>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> { ) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
let request_path = request.uri().path().to_string(); let request_path = request.uri().path().to_string();
let request_headers = request.headers().clone(); let request_headers = request.headers().clone();
@ -89,10 +86,7 @@ pub async fn llm_chat(
request_id, request_id,
request_path, request_path,
request_headers, request_headers,
input_filters, filter_pipeline,
input_filter_agents,
output_filters,
output_filter_agents,
) )
.instrument(request_span) .instrument(request_span)
.await .await
@ -110,10 +104,7 @@ async fn llm_chat_inner(
request_id: String, request_id: String,
request_path: String, request_path: String,
mut request_headers: hyper::HeaderMap, mut request_headers: hyper::HeaderMap,
input_filters: Arc<Option<Vec<String>>>, filter_pipeline: Arc<FilterPipeline>,
input_filter_agents: Arc<HashMap<String, Agent>>,
output_filters: Arc<Option<Vec<String>>>,
output_filter_agents: Arc<HashMap<String, Agent>>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> { ) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
// Set service name for LLM operations // Set service name for LLM operations
set_service_name(operation_component::LLM); set_service_name(operation_component::LLM);
@ -271,23 +262,18 @@ async fn llm_chat_inner(
// Filters receive the raw request bytes and return (possibly modified) raw bytes. // Filters receive the raw request bytes and return (possibly modified) raw bytes.
// The returned bytes are re-parsed into a ProviderRequestType to continue the request. // The returned bytes are re-parsed into a ProviderRequestType to continue the request.
{ {
if let Some(ref fc) = *input_filters { if let Some(ref input_chain) = filter_pipeline.input {
if !fc.is_empty() { if !input_chain.is_empty() {
debug!(input_filters = ?fc, "processing model listener input filters"); debug!(input_filters = ?input_chain.filter_ids, "processing model listener input filters");
let temp_filter_chain = AgentFilterChain { let chain = input_chain.to_agent_filter_chain("model_listener");
id: "model_listener".to_string(),
default: None,
description: None,
input_filters: Some(fc.clone()),
};
let mut pipeline_processor = PipelineProcessor::default(); let mut pipeline_processor = PipelineProcessor::default();
match pipeline_processor match pipeline_processor
.process_raw_filter_chain( .process_raw_filter_chain(
&chat_request_bytes, &chat_request_bytes,
&temp_filter_chain, &chain,
&input_filter_agents, &input_chain.agents,
&request_headers, &request_headers,
&request_path, &request_path,
) )
@ -523,12 +509,7 @@ async fn llm_chat_inner(
// Output filters run for any API shape that reaches this handler (e.g. /v1/chat/completions, // Output filters run for any API shape that reaches this handler (e.g. /v1/chat/completions,
// /v1/messages, /v1/responses). Brightstaff does inbound translation and llm_gateway does // /v1/messages, /v1/responses). Brightstaff does inbound translation and llm_gateway does
// outbound translation; filters receive raw response bytes and request path. // outbound translation; filters receive raw response bytes and request path.
let output_filters_configured = output_filters let has_output_filter = filter_pipeline.has_output_filters();
.as_ref()
.as_ref()
.map(|fc| !fc.is_empty())
.unwrap_or(false);
let has_output_filter = output_filters_configured;
// Save request headers for output filters (before they're consumed by upstream request) // Save request headers for output filters (before they're consumed by upstream request)
let output_filter_request_headers = if has_output_filter { let output_filter_request_headers = if has_output_filter {
@ -579,20 +560,18 @@ async fn llm_chat_inner(
); );
// === v1/responses state management: Wrap with ResponsesStateProcessor === // === v1/responses state management: Wrap with ResponsesStateProcessor ===
// Only wrap if we need to manage state (client is ResponsesAPI AND upstream is NOT ResponsesAPI AND state_storage is configured) // Pick the right processor: state-aware if needed, otherwise base metrics-only.
let streaming_response = if let (true, false, Some(state_store)) = ( let processor: Box<dyn StreamProcessor> = if let (true, false, Some(state_store)) = (
should_manage_state, should_manage_state,
original_input_items.is_empty(), original_input_items.is_empty(),
state_storage, state_storage,
) { ) {
// Extract Content-Encoding header to handle decompression for state parsing
let content_encoding = response_headers let content_encoding = response_headers
.get("content-encoding") .get("content-encoding")
.and_then(|v| v.to_str().ok()) .and_then(|v| v.to_str().ok())
.map(|s| s.to_string()); .map(|s| s.to_string());
// Wrap with state management processor to store state after response completes Box::new(ResponsesStateProcessor::new(
let state_processor = ResponsesStateProcessor::new(
base_processor, base_processor,
state_store, state_store,
original_input_items, original_input_items,
@ -602,37 +581,23 @@ async fn llm_chat_inner(
false, // Not OpenAI upstream since should_manage_state is true false, // Not OpenAI upstream since should_manage_state is true
content_encoding, content_encoding,
request_id, request_id,
); ))
if has_output_filter { } else {
let ofc = output_filters.as_ref().as_ref().unwrap().clone(); Box::new(base_processor)
let ofa = (*output_filter_agents).clone(); };
// Apply output filters if configured, then build the streaming response.
let streaming_response = if has_output_filter {
let output_chain = filter_pipeline.output.as_ref().unwrap().clone();
create_streaming_response_with_output_filter( create_streaming_response_with_output_filter(
byte_stream, byte_stream,
state_processor, processor,
16, output_chain,
ofc,
ofa,
output_filter_request_headers.unwrap(), output_filter_request_headers.unwrap(),
request_path.clone(), request_path.clone(),
) )
} else { } else {
create_streaming_response(byte_stream, state_processor, 16) create_streaming_response(byte_stream, processor)
}
} else if has_output_filter {
let ofc = output_filters.as_ref().as_ref().unwrap().clone();
let ofa = (*output_filter_agents).clone();
create_streaming_response_with_output_filter(
byte_stream,
base_processor,
16,
ofc,
ofa,
output_filter_request_headers.unwrap(),
request_path.clone(),
)
} else {
// Use base processor without state management
create_streaming_response(byte_stream, base_processor, 16)
}; };
match response.body(streaming_response.body) { match response.body(streaming_response.body) {

View file

@ -8,7 +8,7 @@ pub mod pipeline_processor;
pub mod response_handler; pub mod response_handler;
pub mod router_chat; pub mod router_chat;
pub mod routing_service; pub mod routing_service;
pub mod utils; pub mod streaming;
#[cfg(test)] #[cfg(test)]
mod integration_tests; mod integration_tests;

View file

@ -1,12 +1,11 @@
use bytes::Bytes; use bytes::Bytes;
use common::configuration::{Agent, AgentFilterChain}; use common::configuration::ResolvedFilterChain;
use http_body_util::combinators::BoxBody; use http_body_util::combinators::BoxBody;
use http_body_util::StreamBody; use http_body_util::StreamBody;
use hyper::body::Frame; use hyper::body::Frame;
use hyper::header::HeaderMap; use hyper::header::HeaderMap;
use opentelemetry::trace::TraceContextExt; use opentelemetry::trace::TraceContextExt;
use opentelemetry::KeyValue; use opentelemetry::KeyValue;
use std::collections::HashMap;
use std::time::Instant; use std::time::Instant;
use tokio::sync::mpsc; use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream; use tokio_stream::wrappers::ReceiverStream;
@ -15,6 +14,8 @@ use tracing::{debug, info, warn, Instrument};
use tracing_opentelemetry::OpenTelemetrySpanExt; use tracing_opentelemetry::OpenTelemetrySpanExt;
use super::pipeline_processor::{PipelineError, PipelineProcessor}; use super::pipeline_processor::{PipelineError, PipelineProcessor};
const STREAM_BUFFER_SIZE: usize = 16;
use crate::signals::{InteractionQuality, SignalAnalyzer, TextBasedSignalAnalyzer, FLAG_MARKER}; use crate::signals::{InteractionQuality, SignalAnalyzer, TextBasedSignalAnalyzer, FLAG_MARKER};
use crate::tracing::{llm, set_service_name, signals as signal_constants}; use crate::tracing::{llm, set_service_name, signals as signal_constants};
use hermesllm::apis::openai::Message; use hermesllm::apis::openai::Message;
@ -35,6 +36,21 @@ pub trait StreamProcessor: Send + 'static {
fn on_error(&mut self, _error: &str) {} fn on_error(&mut self, _error: &str) {}
} }
impl StreamProcessor for Box<dyn StreamProcessor> {
fn process_chunk(&mut self, chunk: Bytes) -> Result<Option<Bytes>, String> {
(**self).process_chunk(chunk)
}
fn on_first_bytes(&mut self) {
(**self).on_first_bytes()
}
fn on_complete(&mut self) {
(**self).on_complete()
}
fn on_error(&mut self, error: &str) {
(**self).on_error(error)
}
}
/// A processor that tracks streaming metrics /// A processor that tracks streaming metrics
pub struct ObservableStreamProcessor { pub struct ObservableStreamProcessor {
service_name: String, service_name: String,
@ -210,16 +226,12 @@ pub struct StreamingResponse {
pub processor_handle: tokio::task::JoinHandle<()>, pub processor_handle: tokio::task::JoinHandle<()>,
} }
pub fn create_streaming_response<S, P>( pub fn create_streaming_response<S, P>(mut byte_stream: S, mut processor: P) -> StreamingResponse
mut byte_stream: S,
mut processor: P,
buffer_size: usize,
) -> StreamingResponse
where where
S: StreamExt<Item = Result<Bytes, reqwest::Error>> + Send + Unpin + 'static, S: StreamExt<Item = Result<Bytes, reqwest::Error>> + Send + Unpin + 'static,
P: StreamProcessor, P: StreamProcessor,
{ {
let (tx, rx) = mpsc::channel::<Bytes>(buffer_size); let (tx, rx) = mpsc::channel::<Bytes>(STREAM_BUFFER_SIZE);
// Capture the current span so the spawned task inherits the request context // Capture the current span so the spawned task inherits the request context
let current_span = tracing::Span::current(); let current_span = tracing::Span::current();
@ -287,29 +299,22 @@ where
pub fn create_streaming_response_with_output_filter<S, P>( pub fn create_streaming_response_with_output_filter<S, P>(
mut byte_stream: S, mut byte_stream: S,
mut inner_processor: P, mut inner_processor: P,
buffer_size: usize, output_chain: ResolvedFilterChain,
output_filters: Vec<String>,
output_filter_agents: HashMap<String, Agent>,
request_headers: HeaderMap, request_headers: HeaderMap,
upstream_path: String, request_path: String,
) -> StreamingResponse ) -> StreamingResponse
where where
S: StreamExt<Item = Result<Bytes, reqwest::Error>> + Send + Unpin + 'static, S: StreamExt<Item = Result<Bytes, reqwest::Error>> + Send + Unpin + 'static,
P: StreamProcessor, P: StreamProcessor,
{ {
let (tx, rx) = mpsc::channel::<Bytes>(buffer_size); let (tx, rx) = mpsc::channel::<Bytes>(STREAM_BUFFER_SIZE);
let current_span = tracing::Span::current(); let current_span = tracing::Span::current();
let processor_handle = tokio::spawn( let processor_handle = tokio::spawn(
async move { async move {
let mut is_first_chunk = true; let mut is_first_chunk = true;
let mut pipeline_processor = PipelineProcessor::default(); let mut pipeline_processor = PipelineProcessor::default();
let temp_filter_chain = AgentFilterChain { let chain = output_chain.to_agent_filter_chain("output_filter");
id: "output_filter".to_string(),
default: None,
description: None,
input_filters: Some(output_filters),
};
while let Some(item) = byte_stream.next().await { while let Some(item) = byte_stream.next().await {
let chunk = match item { let chunk = match item {
@ -331,10 +336,10 @@ where
let processed_chunk = match pipeline_processor let processed_chunk = match pipeline_processor
.process_raw_filter_chain( .process_raw_filter_chain(
&chunk, &chunk,
&temp_filter_chain, &chain,
&output_filter_agents, &output_chain.agents,
&request_headers, &request_headers,
&upstream_path, &request_path,
) )
.await .await
{ {

View file

@ -10,7 +10,9 @@ use brightstaff::state::postgresql::PostgreSQLConversationStorage;
use brightstaff::state::StateStorage; use brightstaff::state::StateStorage;
use brightstaff::utils::tracing::init_tracer; use brightstaff::utils::tracing::init_tracer;
use bytes::Bytes; use bytes::Bytes;
use common::configuration::{Agent, Configuration, ListenerType}; use common::configuration::{
Agent, Configuration, FilterPipeline, ListenerType, ResolvedFilterChain,
};
use common::consts::{CHAT_COMPLETIONS_PATH, MESSAGES_PATH, OPENAI_RESPONSES_API_PATH}; use common::consts::{CHAT_COMPLETIONS_PATH, MESSAGES_PATH, OPENAI_RESPONSES_API_PATH};
use common::llm_providers::LlmProviders; use common::llm_providers::LlmProviders;
use http_body_util::{combinators::BoxBody, BodyExt, Empty}; use http_body_util::{combinators::BoxBody, BodyExt, Empty};
@ -108,32 +110,22 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
.listeners .listeners
.iter() .iter()
.find(|l| l.listener_type == ListenerType::Model); .find(|l| l.listener_type == ListenerType::Model);
let model_input_filters: Arc<Option<Vec<String>>> = let resolve_chain = |filter_ids: Option<Vec<String>>| -> Option<ResolvedFilterChain> {
Arc::new(model_listener.and_then(|l| l.input_filters.clone())); filter_ids.map(|ids| {
let model_input_filter_agents: Arc<HashMap<String, Agent>> = Arc::new( let agents = ids
model_input_filters .iter()
.as_ref()
.as_ref()
.map(|fc| {
fc.iter()
.filter_map(|id| global_agent_map.get(id).map(|a| (id.clone(), a.clone()))) .filter_map(|id| global_agent_map.get(id).map(|a| (id.clone(), a.clone())))
.collect() .collect();
ResolvedFilterChain {
filter_ids: ids,
agents,
}
}) })
.unwrap_or_default(), };
); let filter_pipeline = Arc::new(FilterPipeline {
let model_output_filters: Arc<Option<Vec<String>>> = input: resolve_chain(model_listener.and_then(|l| l.input_filters.clone())),
Arc::new(model_listener.and_then(|l| l.output_filters.clone())); output: resolve_chain(model_listener.and_then(|l| l.output_filters.clone())),
let model_output_filter_agents: Arc<HashMap<String, Agent>> = Arc::new( });
model_output_filters
.as_ref()
.as_ref()
.map(|fc| {
fc.iter()
.filter_map(|id| global_agent_map.get(id).map(|a| (id.clone(), a.clone())))
.collect()
})
.unwrap_or_default(),
);
let listeners = Arc::new(RwLock::new(plano_config.listeners.clone())); let listeners = Arc::new(RwLock::new(plano_config.listeners.clone()));
let llm_provider_url = let llm_provider_url =
env::var("LLM_PROVIDER_ENDPOINT").unwrap_or_else(|_| "http://localhost:12001".to_string()); env::var("LLM_PROVIDER_ENDPOINT").unwrap_or_else(|_| "http://localhost:12001".to_string());
@ -249,10 +241,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let llm_providers = llm_providers.clone(); let llm_providers = llm_providers.clone();
let agents_list = combined_agents_filters_list.clone(); let agents_list = combined_agents_filters_list.clone();
let model_input_filters = model_input_filters.clone(); let filter_pipeline = filter_pipeline.clone();
let model_input_filter_agents = model_input_filter_agents.clone();
let model_output_filters = model_output_filters.clone();
let model_output_filter_agents = model_output_filter_agents.clone();
let listeners = listeners.clone(); let listeners = listeners.clone();
let span_attributes = span_attributes.clone(); let span_attributes = span_attributes.clone();
let state_storage = state_storage.clone(); let state_storage = state_storage.clone();
@ -264,10 +253,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let llm_providers = llm_providers.clone(); let llm_providers = llm_providers.clone();
let model_aliases = Arc::clone(&model_aliases); let model_aliases = Arc::clone(&model_aliases);
let agents_list = agents_list.clone(); let agents_list = agents_list.clone();
let model_input_filters = model_input_filters.clone(); let filter_pipeline = filter_pipeline.clone();
let model_input_filter_agents = model_input_filter_agents.clone();
let model_output_filters = model_output_filters.clone();
let model_output_filter_agents = model_output_filter_agents.clone();
let listeners = listeners.clone(); let listeners = listeners.clone();
let span_attributes = span_attributes.clone(); let span_attributes = span_attributes.clone();
let state_storage = state_storage.clone(); let state_storage = state_storage.clone();
@ -326,10 +312,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
llm_providers, llm_providers,
span_attributes, span_attributes,
state_storage, state_storage,
model_input_filters, filter_pipeline,
model_input_filter_agents,
model_output_filters,
model_output_filter_agents,
) )
.with_context(parent_cx) .with_context(parent_cx)
.await .await

View file

@ -7,7 +7,7 @@ use std::io::Read;
use std::sync::Arc; use std::sync::Arc;
use tracing::{debug, info, warn}; use tracing::{debug, info, warn};
use crate::handlers::utils::StreamProcessor; use crate::handlers::streaming::StreamProcessor;
use crate::state::{OpenAIConversationState, StateStorage}; use crate::state::{OpenAIConversationState, StateStorage};
/// Processor that wraps another processor and handles v1/responses state management /// Processor that wraps another processor and handles v1/responses state management

View file

@ -30,6 +30,46 @@ pub struct AgentFilterChain {
pub input_filters: Option<Vec<String>>, pub input_filters: Option<Vec<String>>,
} }
/// An ordered filter chain bundled with the agents its filter IDs refer to.
///
/// Carrying the IDs and the agent lookup map in one value keeps the two in
/// sync instead of passing them around as separate parameters.
#[derive(Debug, Clone, Default)]
pub struct ResolvedFilterChain {
    /// Filter IDs, in the order they are listed.
    pub filter_ids: Vec<String>,
    /// Agent lookup map for the chain, keyed by `String`.
    pub agents: HashMap<String, Agent>,
}

impl ResolvedFilterChain {
    /// Returns `true` when the chain lists no filter IDs.
    ///
    /// Only `filter_ids` is consulted; the `agents` map is not inspected.
    pub fn is_empty(&self) -> bool {
        let Self { filter_ids, .. } = self;
        filter_ids.is_empty()
    }

    /// Builds an [`AgentFilterChain`] named `id` whose `input_filters` is a
    /// copy of this chain's filter IDs.
    ///
    /// `default` and `description` are left as `None`.
    pub fn to_agent_filter_chain(&self, id: &str) -> AgentFilterChain {
        let ordered_ids = self.filter_ids.to_vec();
        AgentFilterChain {
            input_filters: Some(ordered_ids),
            description: None,
            default: None,
            id: id.to_owned(),
        }
    }
}
/// The resolved input and output filter chains of a model listener.
///
/// Either side is `None` when no chain was resolved for it.
#[derive(Debug, Clone, Default)]
pub struct FilterPipeline {
    /// Resolved input filter chain, if any.
    pub input: Option<ResolvedFilterChain>,
    /// Resolved output filter chain, if any.
    pub output: Option<ResolvedFilterChain>,
}

impl FilterPipeline {
    /// Returns `true` only when an input chain is present and non-empty.
    pub fn has_input_filters(&self) -> bool {
        matches!(&self.input, Some(chain) if !chain.is_empty())
    }

    /// Returns `true` only when an output chain is present and non-empty.
    pub fn has_output_filters(&self) -> bool {
        matches!(&self.output, Some(chain) if !chain.is_empty())
    }
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")] #[serde(rename_all = "lowercase")]
pub enum ListenerType { pub enum ListenerType {

View file

@ -118,7 +118,7 @@ async def deanonymize(path: str, request: Request) -> Response:
body_str = raw_body.decode("utf-8", errors="replace") body_str = raw_body.decode("utf-8", errors="replace")
if "data: " in body_str: if "data: " in body_str or "event: " in body_str:
return deanonymize_sse(request_id, body_str, mapping, is_anthropic) return deanonymize_sse(request_id, body_str, mapping, is_anthropic)
return deanonymize_json(request_id, raw_body, body_str, mapping, is_anthropic) return deanonymize_json(request_id, raw_body, body_str, mapping, is_anthropic)