diff --git a/crates/brightstaff/src/handlers/llm.rs b/crates/brightstaff/src/handlers/llm.rs index 5d0ee07f..02e4fcf3 100644 --- a/crates/brightstaff/src/handlers/llm.rs +++ b/crates/brightstaff/src/handlers/llm.rs @@ -1,5 +1,5 @@ use bytes::Bytes; -use common::configuration::{Agent, AgentFilterChain, ModelAlias, SpanAttributes}; +use common::configuration::{FilterPipeline, ModelAlias, SpanAttributes}; use common::consts::{ ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER, REQUEST_ID_HEADER, TRACE_PARENT_HEADER, }; @@ -22,9 +22,9 @@ use tracing::{debug, info, info_span, warn, Instrument}; use super::pipeline_processor::PipelineProcessor; use crate::handlers::router_chat::router_chat_get_upstream_model; -use crate::handlers::utils::{ +use crate::handlers::streaming::{ create_streaming_response, create_streaming_response_with_output_filter, truncate_message, - ObservableStreamProcessor, + ObservableStreamProcessor, StreamProcessor, }; use crate::router::llm_router::RouterService; use crate::state::response_state_processor::ResponsesStateProcessor; @@ -46,10 +46,7 @@ pub async fn llm_chat( llm_providers: Arc>, span_attributes: Arc>, state_storage: Option>, - input_filters: Arc>>, - input_filter_agents: Arc>, - output_filters: Arc>>, - output_filter_agents: Arc>, + filter_pipeline: Arc, ) -> Result>, hyper::Error> { let request_path = request.uri().path().to_string(); let request_headers = request.headers().clone(); @@ -89,10 +86,7 @@ pub async fn llm_chat( request_id, request_path, request_headers, - input_filters, - input_filter_agents, - output_filters, - output_filter_agents, + filter_pipeline, ) .instrument(request_span) .await @@ -110,10 +104,7 @@ async fn llm_chat_inner( request_id: String, request_path: String, mut request_headers: hyper::HeaderMap, - input_filters: Arc>>, - input_filter_agents: Arc>, - output_filters: Arc>>, - output_filter_agents: Arc>, + filter_pipeline: Arc, ) -> Result>, hyper::Error> { // Set service name for LLM operations set_service_name(operation_component::LLM); @@ -271,23 +262,18 @@ async fn llm_chat_inner( // Filters receive the raw request bytes and return (possibly modified) raw bytes. // The returned bytes are re-parsed into a ProviderRequestType to continue the request. { - if let Some(ref fc) = *input_filters { - if !fc.is_empty() { - debug!(input_filters = ?fc, "processing model listener input filters"); + if let Some(ref input_chain) = filter_pipeline.input { + if !input_chain.is_empty() { + debug!(input_filters = ?input_chain.filter_ids, "processing model listener input filters"); - let temp_filter_chain = AgentFilterChain { - id: "model_listener".to_string(), - default: None, - description: None, - input_filters: Some(fc.clone()), - }; + let chain = input_chain.to_agent_filter_chain("model_listener"); let mut pipeline_processor = PipelineProcessor::default(); match pipeline_processor .process_raw_filter_chain( &chat_request_bytes, - &temp_filter_chain, - &input_filter_agents, + &chain, + &input_chain.agents, &request_headers, &request_path, ) @@ -523,12 +509,7 @@ async fn llm_chat_inner( // Output filters run for any API shape that reaches this handler (e.g. /v1/chat/completions, // /v1/messages, /v1/responses). Brightstaff does inbound translation and llm_gateway does // outbound translation; filters receive raw response bytes and request path. - let output_filters_configured = output_filters - .as_ref() - .as_ref() - .map(|fc| !fc.is_empty()) - .unwrap_or(false); - let has_output_filter = output_filters_configured; + let has_output_filter = filter_pipeline.has_output_filters(); // Save request headers for output filters (before they're consumed by upstream request) let output_filter_request_headers = if has_output_filter { @@ -579,20 +560,18 @@ async fn llm_chat_inner( ); // === v1/responses state management: Wrap with ResponsesStateProcessor === - // Only wrap if we need to manage state (client is ResponsesAPI AND upstream is NOT ResponsesAPI AND state_storage is configured) - let streaming_response = if let (true, false, Some(state_store)) = ( + // Pick the right processor: state-aware if needed, otherwise base metrics-only. + let processor: Box = if let (true, false, Some(state_store)) = ( should_manage_state, original_input_items.is_empty(), state_storage, ) { - // Extract Content-Encoding header to handle decompression for state parsing let content_encoding = response_headers .get("content-encoding") .and_then(|v| v.to_str().ok()) .map(|s| s.to_string()); - // Wrap with state management processor to store state after response completes - let state_processor = ResponsesStateProcessor::new( + Box::new(ResponsesStateProcessor::new( base_processor, state_store, original_input_items, @@ -602,37 +581,23 @@ async fn llm_chat_inner( false, // Not OpenAI upstream since should_manage_state is true content_encoding, request_id, - ); - if has_output_filter { - let ofc = output_filters.as_ref().as_ref().unwrap().clone(); - let ofa = (*output_filter_agents).clone(); - create_streaming_response_with_output_filter( - byte_stream, - state_processor, - 16, - ofc, - ofa, - output_filter_request_headers.unwrap(), - request_path.clone(), - ) - } else { - create_streaming_response(byte_stream, state_processor, 16) - } - } else if has_output_filter { - let ofc = output_filters.as_ref().as_ref().unwrap().clone(); - let ofa = (*output_filter_agents).clone(); + )) + } else { + Box::new(base_processor) + }; + + // Apply output filters if configured, then build the streaming response. + let streaming_response = if has_output_filter { + let output_chain = filter_pipeline.output.as_ref().unwrap().clone(); create_streaming_response_with_output_filter( byte_stream, - base_processor, - 16, - ofc, - ofa, + processor, + output_chain, output_filter_request_headers.unwrap(), request_path.clone(), ) } else { - // Use base processor without state management - create_streaming_response(byte_stream, base_processor, 16) + create_streaming_response(byte_stream, processor) }; match response.body(streaming_response.body) { diff --git a/crates/brightstaff/src/handlers/mod.rs b/crates/brightstaff/src/handlers/mod.rs index 9c602e93..b2161e43 100644 --- a/crates/brightstaff/src/handlers/mod.rs +++ b/crates/brightstaff/src/handlers/mod.rs @@ -8,7 +8,7 @@ pub mod pipeline_processor; pub mod response_handler; pub mod router_chat; pub mod routing_service; -pub mod utils; +pub mod streaming; #[cfg(test)] mod integration_tests; diff --git a/crates/brightstaff/src/handlers/utils.rs b/crates/brightstaff/src/handlers/streaming.rs similarity index 93% rename from crates/brightstaff/src/handlers/utils.rs rename to crates/brightstaff/src/handlers/streaming.rs index 00707356..0cea182e 100644 --- a/crates/brightstaff/src/handlers/utils.rs +++ b/crates/brightstaff/src/handlers/streaming.rs @@ -1,12 +1,11 @@ use bytes::Bytes; -use common::configuration::{Agent, AgentFilterChain}; +use common::configuration::ResolvedFilterChain; use http_body_util::combinators::BoxBody; use http_body_util::StreamBody; use hyper::body::Frame; use hyper::header::HeaderMap; use opentelemetry::trace::TraceContextExt; use opentelemetry::KeyValue; -use std::collections::HashMap; use std::time::Instant; use tokio::sync::mpsc; use tokio_stream::wrappers::ReceiverStream; @@ -15,6 +14,8 @@ use tracing::{debug, info, warn, Instrument}; use tracing_opentelemetry::OpenTelemetrySpanExt; use super::pipeline_processor::{PipelineError, PipelineProcessor}; + +const STREAM_BUFFER_SIZE: usize = 16; use crate::signals::{InteractionQuality, SignalAnalyzer, TextBasedSignalAnalyzer, FLAG_MARKER}; use crate::tracing::{llm, set_service_name, signals as signal_constants}; use hermesllm::apis::openai::Message; @@ -35,6 +36,21 @@ pub trait StreamProcessor: Send + 'static { fn on_error(&mut self, _error: &str) {} } +impl StreamProcessor for Box { + fn process_chunk(&mut self, chunk: Bytes) -> Result, String> { + (**self).process_chunk(chunk) + } + fn on_first_bytes(&mut self) { + (**self).on_first_bytes() + } + fn on_complete(&mut self) { + (**self).on_complete() + } + fn on_error(&mut self, error: &str) { + (**self).on_error(error) + } +} + /// A processor that tracks streaming metrics pub struct ObservableStreamProcessor { service_name: String, @@ -210,16 +226,12 @@ pub struct StreamingResponse { pub processor_handle: tokio::task::JoinHandle<()>, } -pub fn create_streaming_response( - mut byte_stream: S, - mut processor: P, - buffer_size: usize, -) -> StreamingResponse +pub fn create_streaming_response(mut byte_stream: S, mut processor: P) -> StreamingResponse where S: StreamExt> + Send + Unpin + 'static, P: StreamProcessor, { - let (tx, rx) = mpsc::channel::(buffer_size); + let (tx, rx) = mpsc::channel::(STREAM_BUFFER_SIZE); // Capture the current span so the spawned task inherits the request context let current_span = tracing::Span::current(); @@ -287,29 +299,22 @@ where pub fn create_streaming_response_with_output_filter( mut byte_stream: S, mut inner_processor: P, - buffer_size: usize, - output_filters: Vec, - output_filter_agents: HashMap, + output_chain: ResolvedFilterChain, request_headers: HeaderMap, - upstream_path: String, + request_path: String, ) -> StreamingResponse where S: StreamExt> + Send + Unpin + 'static, P: StreamProcessor, { - let (tx, rx) = mpsc::channel::(buffer_size); + let (tx, rx) = mpsc::channel::(STREAM_BUFFER_SIZE); let current_span = tracing::Span::current(); let processor_handle = tokio::spawn( async move { let mut is_first_chunk = true; let mut pipeline_processor = PipelineProcessor::default(); - let temp_filter_chain = AgentFilterChain { - id: "output_filter".to_string(), - default: None, - description: None, - input_filters: Some(output_filters), - }; + let chain = output_chain.to_agent_filter_chain("output_filter"); while let Some(item) = byte_stream.next().await { let chunk = match item { @@ -331,10 +336,10 @@ where let processed_chunk = match pipeline_processor .process_raw_filter_chain( &chunk, - &temp_filter_chain, - &output_filter_agents, + &chain, + &output_chain.agents, &request_headers, - &upstream_path, + &request_path, ) .await { diff --git a/crates/brightstaff/src/main.rs b/crates/brightstaff/src/main.rs index 96a66c60..391aed03 100644 --- a/crates/brightstaff/src/main.rs +++ b/crates/brightstaff/src/main.rs @@ -10,7 +10,9 @@ use brightstaff::state::postgresql::PostgreSQLConversationStorage; use brightstaff::state::StateStorage; use brightstaff::utils::tracing::init_tracer; use bytes::Bytes; -use common::configuration::{Agent, Configuration, ListenerType}; +use common::configuration::{ + Agent, Configuration, FilterPipeline, ListenerType, ResolvedFilterChain, +}; use common::consts::{CHAT_COMPLETIONS_PATH, MESSAGES_PATH, OPENAI_RESPONSES_API_PATH}; use common::llm_providers::LlmProviders; use http_body_util::{combinators::BoxBody, BodyExt, Empty}; @@ -108,32 +110,22 @@ async fn main() -> Result<(), Box> { .listeners .iter() .find(|l| l.listener_type == ListenerType::Model); - let model_input_filters: Arc>> = - Arc::new(model_listener.and_then(|l| l.input_filters.clone())); - let model_input_filter_agents: Arc> = Arc::new( - model_input_filters - .as_ref() - .as_ref() - .map(|fc| { - fc.iter() - .filter_map(|id| global_agent_map.get(id).map(|a| (id.clone(), a.clone()))) - .collect() - }) - .unwrap_or_default(), - ); - let model_output_filters: Arc>> = - Arc::new(model_listener.and_then(|l| l.output_filters.clone())); - let model_output_filter_agents: Arc> = Arc::new( - model_output_filters - .as_ref() - .as_ref() - .map(|fc| { - fc.iter() - .filter_map(|id| global_agent_map.get(id).map(|a| (id.clone(), a.clone()))) - .collect() - }) - .unwrap_or_default(), - ); + let resolve_chain = |filter_ids: Option>| -> Option { + filter_ids.map(|ids| { + let agents = ids + .iter() + .filter_map(|id| global_agent_map.get(id).map(|a| (id.clone(), a.clone()))) + .collect(); + ResolvedFilterChain { + filter_ids: ids, + agents, + } + }) + }; + let filter_pipeline = Arc::new(FilterPipeline { + input: resolve_chain(model_listener.and_then(|l| l.input_filters.clone())), + output: resolve_chain(model_listener.and_then(|l| l.output_filters.clone())), + }); let listeners = Arc::new(RwLock::new(plano_config.listeners.clone())); let llm_provider_url = env::var("LLM_PROVIDER_ENDPOINT").unwrap_or_else(|_| "http://localhost:12001".to_string()); @@ -249,10 +241,7 @@ async fn main() -> Result<(), Box> { let llm_providers = llm_providers.clone(); let agents_list = combined_agents_filters_list.clone(); - let model_input_filters = model_input_filters.clone(); - let model_input_filter_agents = model_input_filter_agents.clone(); - let model_output_filters = model_output_filters.clone(); - let model_output_filter_agents = model_output_filter_agents.clone(); + let filter_pipeline = filter_pipeline.clone(); let listeners = listeners.clone(); let span_attributes = span_attributes.clone(); let state_storage = state_storage.clone(); @@ -264,10 +253,7 @@ async fn main() -> Result<(), Box> { let llm_providers = llm_providers.clone(); let model_aliases = Arc::clone(&model_aliases); let agents_list = agents_list.clone(); - let model_input_filters = model_input_filters.clone(); - let model_input_filter_agents = model_input_filter_agents.clone(); - let model_output_filters = model_output_filters.clone(); - let model_output_filter_agents = model_output_filter_agents.clone(); + let filter_pipeline = filter_pipeline.clone(); let listeners = listeners.clone(); let span_attributes = span_attributes.clone(); let state_storage = state_storage.clone(); @@ -326,10 +312,7 @@ async fn main() -> Result<(), Box> { llm_providers, span_attributes, state_storage, - model_input_filters, - model_input_filter_agents, - model_output_filters, - model_output_filter_agents, + filter_pipeline, ) .with_context(parent_cx) .await diff --git a/crates/brightstaff/src/state/response_state_processor.rs b/crates/brightstaff/src/state/response_state_processor.rs index 0920324c..6f6c7b62 100644 --- a/crates/brightstaff/src/state/response_state_processor.rs +++ b/crates/brightstaff/src/state/response_state_processor.rs @@ -7,7 +7,7 @@ use std::io::Read; use std::sync::Arc; use tracing::{debug, info, warn}; -use crate::handlers::utils::StreamProcessor; +use crate::handlers::streaming::StreamProcessor; use crate::state::{OpenAIConversationState, StateStorage}; /// Processor that wraps another processor and handles v1/responses state management diff --git a/crates/common/src/configuration.rs b/crates/common/src/configuration.rs index 6bdaa01e..df179059 100644 --- a/crates/common/src/configuration.rs +++ b/crates/common/src/configuration.rs @@ -30,6 +30,46 @@ pub struct AgentFilterChain { pub input_filters: Option>, } +/// A filter chain with its agent references resolved to concrete Agent objects. +/// Bundles the ordered filter IDs with the agent lookup map so they stay in sync. +#[derive(Debug, Clone, Default)] +pub struct ResolvedFilterChain { + pub filter_ids: Vec, + pub agents: HashMap, +} + +impl ResolvedFilterChain { + pub fn is_empty(&self) -> bool { + self.filter_ids.is_empty() + } + + pub fn to_agent_filter_chain(&self, id: &str) -> AgentFilterChain { + AgentFilterChain { + id: id.to_string(), + default: None, + description: None, + input_filters: Some(self.filter_ids.clone()), + } + } +} + +/// Holds resolved input and output filter chains for a model listener. +#[derive(Debug, Clone, Default)] +pub struct FilterPipeline { + pub input: Option, + pub output: Option, +} + +impl FilterPipeline { + pub fn has_input_filters(&self) -> bool { + self.input.as_ref().is_some_and(|c| !c.is_empty()) + } + + pub fn has_output_filters(&self) -> bool { + self.output.as_ref().is_some_and(|c| !c.is_empty()) + } +} + #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] #[serde(rename_all = "lowercase")] pub enum ListenerType { diff --git a/demos/filter_chains/pii_anonymizer/pii_anonymizer.py b/demos/filter_chains/pii_anonymizer/pii_anonymizer.py index f8e2e613..9c1450cc 100644 --- a/demos/filter_chains/pii_anonymizer/pii_anonymizer.py +++ b/demos/filter_chains/pii_anonymizer/pii_anonymizer.py @@ -118,7 +118,7 @@ async def deanonymize(path: str, request: Request) -> Response: body_str = raw_body.decode("utf-8", errors="replace") - if "data: " in body_str: + if "data: " in body_str or "event: " in body_str: return deanonymize_sse(request_id, body_str, mapping, is_anthropic) return deanonymize_json(request_id, raw_body, body_str, mapping, is_anthropic)