Add support for output filter chain on listener

- add demo for pii redaction
This commit is contained in:
Adil Hafeez 2026-03-12 15:49:40 -07:00
parent 8ae4901735
commit 73f15feea3
No known key found for this signature in database
GPG key ID: 9B18EF7691369645
16 changed files with 1218 additions and 4 deletions

View file

@ -196,6 +196,7 @@ mod tests {
name: name.to_string(),
agents: Some(agents),
filter_chain: None,
output_filter_chain: None,
port: 8080,
router: None,
}

View file

@ -76,6 +76,7 @@ mod tests {
name: "test-listener".to_string(),
agents: Some(vec![agent_pipeline.clone()]),
filter_chain: None,
output_filter_chain: None,
port: 8080,
router: None,
};

View file

@ -23,7 +23,8 @@ use super::pipeline_processor::PipelineProcessor;
use crate::handlers::router_chat::router_chat_get_upstream_model;
use crate::handlers::utils::{
create_streaming_response, truncate_message, ObservableStreamProcessor,
create_streaming_response, create_streaming_response_with_output_filter, truncate_message,
ObservableStreamProcessor,
};
use crate::router::llm_router::RouterService;
use crate::state::response_state_processor::ResponsesStateProcessor;
@ -47,6 +48,8 @@ pub async fn llm_chat(
state_storage: Option<Arc<dyn StateStorage>>,
filter_chain: Arc<Option<Vec<String>>>,
filter_agents: Arc<HashMap<String, Agent>>,
output_filter_chain: Arc<Option<Vec<String>>>,
output_filter_agents: Arc<HashMap<String, Agent>>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
let request_path = request.uri().path().to_string();
let request_headers = request.headers().clone();
@ -88,6 +91,8 @@ pub async fn llm_chat(
request_headers,
filter_chain,
filter_agents,
output_filter_chain,
output_filter_agents,
)
.instrument(request_span)
.await
@ -107,6 +112,8 @@ async fn llm_chat_inner(
mut request_headers: hyper::HeaderMap,
filter_chain: Arc<Option<Vec<String>>>,
filter_agents: Arc<HashMap<String, Agent>>,
output_filter_chain: Arc<Option<Vec<String>>>,
output_filter_agents: Arc<HashMap<String, Agent>>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
// Set service name for LLM operations
set_service_name(operation_component::LLM);
@ -501,6 +508,20 @@ async fn llm_chat_inner(
propagator.inject_context(&cx, &mut HeaderInjector(&mut request_headers));
});
// Determine if output filter chain is configured
let has_output_filter = output_filter_chain
.as_ref()
.as_ref()
.map(|fc| !fc.is_empty())
.unwrap_or(false);
// Save request headers for output filter chain (before they're consumed by upstream request)
let output_filter_request_headers = if has_output_filter {
Some(request_headers.clone())
} else {
None
};
// Capture start time right before sending request to upstream
let request_start_time = std::time::Instant::now();
let _request_start_system_time = std::time::SystemTime::now();
@ -567,7 +588,31 @@ async fn llm_chat_inner(
content_encoding,
request_id,
);
create_streaming_response(byte_stream, state_processor, 16)
if has_output_filter {
let ofc = output_filter_chain.as_ref().as_ref().unwrap().clone();
let ofa = (*output_filter_agents).clone();
create_streaming_response_with_output_filter(
byte_stream,
state_processor,
16,
ofc,
ofa,
output_filter_request_headers.unwrap(),
)
} else {
create_streaming_response(byte_stream, state_processor, 16)
}
} else if has_output_filter {
let ofc = output_filter_chain.as_ref().as_ref().unwrap().clone();
let ofa = (*output_filter_agents).clone();
create_streaming_response_with_output_filter(
byte_stream,
base_processor,
16,
ofc,
ofa,
output_filter_request_headers.unwrap(),
)
} else {
// Use base processor without state management
create_streaming_response(byte_stream, base_processor, 16)

View file

@ -1,19 +1,23 @@
use bytes::Bytes;
use common::configuration::{Agent, AgentFilterChain};
use http_body_util::combinators::BoxBody;
use http_body_util::StreamBody;
use hyper::body::Frame;
use hyper::header::HeaderMap;
use opentelemetry::trace::TraceContextExt;
use opentelemetry::KeyValue;
use std::collections::HashMap;
use std::time::Instant;
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
use tokio_stream::StreamExt;
use tracing::{info, warn, Instrument};
use tracing::{debug, info, warn, Instrument};
use tracing_opentelemetry::OpenTelemetrySpanExt;
use super::pipeline_processor::PipelineProcessor;
use crate::signals::{InteractionQuality, SignalAnalyzer, TextBasedSignalAnalyzer, FLAG_MARKER};
use crate::tracing::{llm, set_service_name, signals as signal_constants};
use hermesllm::apis::openai::Message;
use hermesllm::apis::openai::{Message, MessageContent, Role};
/// Trait for processing streaming chunks
/// Implementors can inject custom logic during streaming (e.g., hallucination detection, logging)
@ -277,6 +281,286 @@ where
}
}
/// Extract content text from an SSE chunk line (the JSON part after "data: ").
/// Returns the content string and whether it was found.
fn extract_sse_content(json_str: &str) -> Option<String> {
let value: serde_json::Value = serde_json::from_str(json_str).ok()?;
value
.get("choices")?
.get(0)?
.get("delta")?
.get("content")?
.as_str()
.map(|s| s.to_string())
}
/// Replace content in an SSE JSON chunk with new content.
fn replace_sse_content(json_str: &str, new_content: &str) -> Option<String> {
let mut value: serde_json::Value = serde_json::from_str(json_str).ok()?;
value
.get_mut("choices")?
.get_mut(0)?
.get_mut("delta")?
.as_object_mut()?
.insert(
"content".to_string(),
serde_json::Value::String(new_content.to_string()),
);
serde_json::to_string(&value).ok()
}
/// Process an SSE chunk through the output filter chain.
/// Parses each `data: ` line, extracts content, sends through filters, and reconstructs.
async fn filter_sse_chunk(
chunk_str: &str,
pipeline_processor: &mut PipelineProcessor,
filter_chain: &AgentFilterChain,
filter_agents: &HashMap<String, Agent>,
request_headers: &HeaderMap,
) -> String {
let mut result = String::new();
for line in chunk_str.split('\n') {
if let Some(json_str) = line.strip_prefix("data: ") {
if json_str.trim() == "[DONE]" {
result.push_str(line);
result.push('\n');
continue;
}
if let Some(content) = extract_sse_content(json_str) {
if content.is_empty() {
result.push_str(line);
result.push('\n');
continue;
}
// Send content through output filter chain
let messages = vec![Message {
role: Role::Assistant,
content: Some(MessageContent::Text(content)),
name: None,
tool_calls: None,
tool_call_id: None,
}];
match pipeline_processor
.process_filter_chain(&messages, filter_chain, filter_agents, request_headers)
.await
{
Ok(filtered_messages) => {
if let Some(msg) = filtered_messages.first() {
let filtered_content = match &msg.content {
Some(MessageContent::Text(t)) => Some(t.clone()),
_ => None,
};
if let Some(filtered_content) = filtered_content {
if let Some(new_json) =
replace_sse_content(json_str, &filtered_content)
{
result.push_str("data: ");
result.push_str(&new_json);
result.push('\n');
continue;
}
}
}
// Fallback: pass through original
result.push_str(line);
result.push('\n');
}
Err(e) => {
warn!(error = %e, "output filter chain error, passing through original chunk");
result.push_str(line);
result.push('\n');
}
}
} else {
// No content in this SSE line, pass through
result.push_str(line);
result.push('\n');
}
} else {
result.push_str(line);
result.push('\n');
}
}
// Remove trailing extra newline if the original didn't end with one
if !chunk_str.ends_with('\n') && result.ends_with('\n') {
result.pop();
}
result
}
/// Process a non-streaming JSON response through the output filter chain.
/// Extracts assistant message content, filters it, and reconstructs the response.
pub async fn filter_non_streaming_response(
response_bytes: &[u8],
pipeline_processor: &mut PipelineProcessor,
filter_chain: &AgentFilterChain,
filter_agents: &HashMap<String, Agent>,
request_headers: &HeaderMap,
) -> Bytes {
let response_str = match std::str::from_utf8(response_bytes) {
Ok(s) => s,
Err(_) => return Bytes::from(response_bytes.to_vec()),
};
let mut value: serde_json::Value = match serde_json::from_str(response_str) {
Ok(v) => v,
Err(_) => return Bytes::from(response_bytes.to_vec()),
};
// Extract content from choices[0].message.content
let content = value
.get("choices")
.and_then(|c| c.get(0))
.and_then(|c| c.get("message"))
.and_then(|m| m.get("content"))
.and_then(|c| c.as_str())
.map(|s| s.to_string());
if let Some(content) = content {
let messages = vec![Message {
role: Role::Assistant,
content: Some(MessageContent::Text(content)),
name: None,
tool_calls: None,
tool_call_id: None,
}];
match pipeline_processor
.process_filter_chain(&messages, filter_chain, filter_agents, request_headers)
.await
{
Ok(filtered_messages) => {
if let Some(msg) = filtered_messages.first() {
let filtered_content = match &msg.content {
Some(MessageContent::Text(t)) => Some(t.clone()),
_ => None,
};
if let Some(filtered_content) = filtered_content {
if let Some(choices) = value.get_mut("choices") {
if let Some(choice) = choices.get_mut(0) {
if let Some(message) = choice.get_mut("message") {
message.as_object_mut().unwrap().insert(
"content".to_string(),
serde_json::Value::String(filtered_content),
);
}
}
}
}
}
}
Err(e) => {
warn!(error = %e, "output filter chain error on non-streaming response");
}
}
}
Bytes::from(serde_json::to_string(&value).unwrap_or_else(|_| response_str.to_string()))
}
/// Creates a streaming response that processes each chunk through an output filter chain.
/// The output filter is called asynchronously for each SSE chunk's content.
pub fn create_streaming_response_with_output_filter<S, P>(
mut byte_stream: S,
mut inner_processor: P,
buffer_size: usize,
output_filter_chain: Vec<String>,
output_filter_agents: HashMap<String, Agent>,
request_headers: HeaderMap,
) -> StreamingResponse
where
S: StreamExt<Item = Result<Bytes, reqwest::Error>> + Send + Unpin + 'static,
P: StreamProcessor,
{
let (tx, rx) = mpsc::channel::<Bytes>(buffer_size);
let current_span = tracing::Span::current();
let processor_handle = tokio::spawn(
async move {
let mut is_first_chunk = true;
let mut pipeline_processor = PipelineProcessor::default();
let temp_filter_chain = AgentFilterChain {
id: "output_filter".to_string(),
default: None,
description: None,
filter_chain: Some(output_filter_chain),
};
while let Some(item) = byte_stream.next().await {
let chunk = match item {
Ok(chunk) => chunk,
Err(err) => {
let err_msg = format!("Error receiving chunk: {:?}", err);
warn!(error = %err_msg, "stream error");
inner_processor.on_error(&err_msg);
break;
}
};
if is_first_chunk {
inner_processor.on_first_bytes();
is_first_chunk = false;
}
// Try to process through output filter chain
let processed_chunk = if let Ok(chunk_str) = std::str::from_utf8(&chunk) {
if chunk_str.contains("data: ") {
let filtered = filter_sse_chunk(
chunk_str,
&mut pipeline_processor,
&temp_filter_chain,
&output_filter_agents,
&request_headers,
)
.await;
Bytes::from(filtered)
} else {
// Non-SSE chunk (could be non-streaming JSON response)
let filtered = filter_non_streaming_response(
&chunk,
&mut pipeline_processor,
&temp_filter_chain,
&output_filter_agents,
&request_headers,
)
.await;
filtered
}
} else {
chunk
};
// Pass through inner processor for metrics/observability
match inner_processor.process_chunk(processed_chunk) {
Ok(Some(final_chunk)) => {
if tx.send(final_chunk).await.is_err() {
warn!("receiver dropped");
break;
}
}
Ok(None) => continue,
Err(err) => {
warn!("processor error: {}", err);
inner_processor.on_error(&err);
break;
}
}
}
inner_processor.on_complete();
debug!("output filter streaming completed");
}
.instrument(current_span),
);
let stream = ReceiverStream::new(rx).map(|chunk| Ok::<_, hyper::Error>(Frame::data(chunk)));
let stream_body = BoxBody::new(StreamBody::new(stream));
StreamingResponse {
body: stream_body,
processor_handle,
}
}
/// Truncates a message to the specified maximum length, adding "..." if truncated.
pub fn truncate_message(message: &str, max_length: usize) -> String {
if message.chars().count() > max_length {

View file

@ -121,6 +121,19 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
})
.unwrap_or_default(),
);
let model_output_filter_chain: Arc<Option<Vec<String>>> =
Arc::new(model_listener.and_then(|l| l.output_filter_chain.clone()));
let model_output_filter_agents: Arc<HashMap<String, Agent>> = Arc::new(
model_output_filter_chain
.as_ref()
.as_ref()
.map(|fc| {
fc.iter()
.filter_map(|id| global_agent_map.get(id).map(|a| (id.clone(), a.clone())))
.collect()
})
.unwrap_or_default(),
);
let listeners = Arc::new(RwLock::new(plano_config.listeners.clone()));
let llm_provider_url =
env::var("LLM_PROVIDER_ENDPOINT").unwrap_or_else(|_| "http://localhost:12001".to_string());
@ -217,6 +230,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let agents_list = combined_agents_filters_list.clone();
let model_filter_chain = model_filter_chain.clone();
let model_filter_agents = model_filter_agents.clone();
let model_output_filter_chain = model_output_filter_chain.clone();
let model_output_filter_agents = model_output_filter_agents.clone();
let listeners = listeners.clone();
let span_attributes = span_attributes.clone();
let state_storage = state_storage.clone();
@ -230,6 +245,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let agents_list = agents_list.clone();
let model_filter_chain = model_filter_chain.clone();
let model_filter_agents = model_filter_agents.clone();
let model_output_filter_chain = model_output_filter_chain.clone();
let model_output_filter_agents = model_output_filter_agents.clone();
let listeners = listeners.clone();
let span_attributes = span_attributes.clone();
let state_storage = state_storage.clone();
@ -290,6 +307,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
state_storage,
model_filter_chain,
model_filter_agents,
model_output_filter_chain,
model_output_filter_agents,
)
.with_context(parent_cx)
.await