pass raw bytes through input/output filter chains

This commit is contained in:
Adil Hafeez 2026-03-17 05:10:07 -07:00
parent 80dfb41cad
commit d26abbfb9c
No known key found for this signature in database
GPG key ID: 9B18EF7691369645
11 changed files with 610 additions and 448 deletions

View file

@ -14,10 +14,10 @@ use tokio_stream::StreamExt;
use tracing::{debug, info, warn, Instrument};
use tracing_opentelemetry::OpenTelemetrySpanExt;
use super::pipeline_processor::PipelineProcessor;
use super::pipeline_processor::{PipelineError, PipelineProcessor};
use crate::signals::{InteractionQuality, SignalAnalyzer, TextBasedSignalAnalyzer, FLAG_MARKER};
use crate::tracing::{llm, set_service_name, signals as signal_constants};
use hermesllm::apis::openai::{Message, MessageContent, Role};
use hermesllm::apis::openai::Message;
/// Trait for processing streaming chunks
/// Implementors can inject custom logic during streaming (e.g., hallucination detection, logging)
@ -281,184 +281,9 @@ where
}
}
/// Extract content text from an SSE chunk line (the JSON part after "data: ").
/// Returns the content string and whether it was found.
fn extract_sse_content(json_str: &str) -> Option<String> {
let value: serde_json::Value = serde_json::from_str(json_str).ok()?;
value
.get("choices")?
.get(0)?
.get("delta")?
.get("content")?
.as_str()
.map(|s| s.to_string())
}
/// Replace content in an SSE JSON chunk with new content.
fn replace_sse_content(json_str: &str, new_content: &str) -> Option<String> {
let mut value: serde_json::Value = serde_json::from_str(json_str).ok()?;
value
.get_mut("choices")?
.get_mut(0)?
.get_mut("delta")?
.as_object_mut()?
.insert(
"content".to_string(),
serde_json::Value::String(new_content.to_string()),
);
serde_json::to_string(&value).ok()
}
/// Process an SSE chunk through the output filter chain.
/// Parses each `data: ` line, extracts content, sends through filters, and reconstructs.
async fn filter_sse_chunk(
chunk_str: &str,
pipeline_processor: &mut PipelineProcessor,
filter_chain: &AgentFilterChain,
filter_agents: &HashMap<String, Agent>,
request_headers: &HeaderMap,
) -> String {
let mut result = String::new();
for line in chunk_str.split('\n') {
if let Some(json_str) = line.strip_prefix("data: ") {
if json_str.trim() == "[DONE]" {
result.push_str(line);
result.push('\n');
continue;
}
if let Some(content) = extract_sse_content(json_str) {
if content.is_empty() {
result.push_str(line);
result.push('\n');
continue;
}
// Send content through output filter chain
let messages = vec![Message {
role: Role::Assistant,
content: Some(MessageContent::Text(content)),
name: None,
tool_calls: None,
tool_call_id: None,
}];
match pipeline_processor
.process_filter_chain(&messages, filter_chain, filter_agents, request_headers)
.await
{
Ok(filtered_messages) => {
if let Some(msg) = filtered_messages.first() {
let filtered_content = match &msg.content {
Some(MessageContent::Text(t)) => Some(t.clone()),
_ => None,
};
if let Some(filtered_content) = filtered_content {
if let Some(new_json) =
replace_sse_content(json_str, &filtered_content)
{
result.push_str("data: ");
result.push_str(&new_json);
result.push('\n');
continue;
}
}
}
// Fallback: pass through original
result.push_str(line);
result.push('\n');
}
Err(e) => {
warn!(error = %e, "output filter chain error, passing through original chunk");
result.push_str(line);
result.push('\n');
}
}
} else {
// No content in this SSE line, pass through
result.push_str(line);
result.push('\n');
}
} else {
result.push_str(line);
result.push('\n');
}
}
// Remove trailing extra newline if the original didn't end with one
if !chunk_str.ends_with('\n') && result.ends_with('\n') {
result.pop();
}
result
}
/// Process a non-streaming JSON response through the output filter chain.
/// Extracts assistant message content, filters it, and reconstructs the response.
pub async fn filter_non_streaming_response(
response_bytes: &[u8],
pipeline_processor: &mut PipelineProcessor,
filter_chain: &AgentFilterChain,
filter_agents: &HashMap<String, Agent>,
request_headers: &HeaderMap,
) -> Bytes {
let response_str = match std::str::from_utf8(response_bytes) {
Ok(s) => s,
Err(_) => return Bytes::from(response_bytes.to_vec()),
};
let mut value: serde_json::Value = match serde_json::from_str(response_str) {
Ok(v) => v,
Err(_) => return Bytes::from(response_bytes.to_vec()),
};
// Extract content from choices[0].message.content
let content = value
.get("choices")
.and_then(|c| c.get(0))
.and_then(|c| c.get("message"))
.and_then(|m| m.get("content"))
.and_then(|c| c.as_str())
.map(|s| s.to_string());
if let Some(content) = content {
let messages = vec![Message {
role: Role::Assistant,
content: Some(MessageContent::Text(content)),
name: None,
tool_calls: None,
tool_call_id: None,
}];
match pipeline_processor
.process_filter_chain(&messages, filter_chain, filter_agents, request_headers)
.await
{
Ok(filtered_messages) => {
if let Some(msg) = filtered_messages.first() {
let filtered_content = match &msg.content {
Some(MessageContent::Text(t)) => Some(t.clone()),
_ => None,
};
if let Some(filtered_content) = filtered_content {
if let Some(choices) = value.get_mut("choices") {
if let Some(choice) = choices.get_mut(0) {
if let Some(message) = choice.get_mut("message") {
message.as_object_mut().unwrap().insert(
"content".to_string(),
serde_json::Value::String(filtered_content),
);
}
}
}
}
}
}
Err(e) => {
warn!(error = %e, "output filter chain error on non-streaming response");
}
}
}
Bytes::from(serde_json::to_string(&value).unwrap_or_else(|_| response_str.to_string()))
}
/// Creates a streaming response that processes each chunk through output filters.
/// The output filter is called asynchronously for each SSE chunk's content.
/// Creates a streaming response that processes each raw chunk through output filters.
/// Filters receive the raw LLM response bytes and return (possibly modified) bytes.
/// On filter error mid-stream the original chunk is passed through (headers already sent).
pub fn create_streaming_response_with_output_filter<S, P>(
mut byte_stream: S,
mut inner_processor: P,
@ -466,6 +291,7 @@ pub fn create_streaming_response_with_output_filter<S, P>(
output_filters: Vec<String>,
output_filter_agents: HashMap<String, Agent>,
request_headers: HeaderMap,
upstream_path: String,
) -> StreamingResponse
where
S: StreamExt<Item = Result<Bytes, reqwest::Error>> + Send + Unpin + 'static,
@ -501,32 +327,35 @@ where
is_first_chunk = false;
}
// Try to process through output filter chain
let processed_chunk = if let Ok(chunk_str) = std::str::from_utf8(&chunk) {
if chunk_str.contains("data: ") {
let filtered = filter_sse_chunk(
chunk_str,
&mut pipeline_processor,
&temp_filter_chain,
&output_filter_agents,
&request_headers,
)
.await;
Bytes::from(filtered)
} else {
// Non-SSE chunk (could be non-streaming JSON response)
let filtered = filter_non_streaming_response(
&chunk,
&mut pipeline_processor,
&temp_filter_chain,
&output_filter_agents,
&request_headers,
)
.await;
filtered
// Pass raw chunk bytes through the output filter chain
let processed_chunk = match pipeline_processor
.process_raw_filter_chain(
&chunk,
&temp_filter_chain,
&output_filter_agents,
&request_headers,
&upstream_path,
)
.await
{
Ok(filtered) => filtered,
Err(PipelineError::ClientError {
agent,
status,
body,
}) => {
warn!(
agent = %agent,
status = %status,
body = %body,
"output filter client error, passing through original chunk"
);
chunk
}
Err(e) => {
warn!(error = %e, "output filter error, passing through original chunk");
chunk
}
} else {
chunk
};
// Pass through inner processor for metrics/observability