pass raw bytes through input/output filter chains

This commit is contained in:
Adil Hafeez 2026-03-17 05:10:07 -07:00
parent 80dfb41cad
commit d26abbfb9c
No known key found for this signature in database
GPG key ID: 9B18EF7691369645
11 changed files with 610 additions and 448 deletions

View file

@ -268,12 +268,13 @@ async fn llm_chat_inner(
}
// === Input filters processing for model listener ===
// Filters receive the raw request bytes and return (possibly modified) raw bytes.
// The returned bytes are re-parsed into a ProviderRequestType to continue the request.
{
if let Some(ref fc) = *input_filters {
if !fc.is_empty() {
debug!(input_filters = ?fc, "processing model listener input filters");
// Create a temporary AgentFilterChain to reuse PipelineProcessor
let temp_filter_chain = AgentFilterChain {
id: "model_listener".to_string(),
default: None,
@ -282,23 +283,34 @@ async fn llm_chat_inner(
};
let mut pipeline_processor = PipelineProcessor::default();
let messages = client_request.get_messages();
match pipeline_processor
.process_filter_chain(
&messages,
.process_raw_filter_chain(
&chat_request_bytes,
&temp_filter_chain,
&input_filter_agents,
&request_headers,
&request_path,
)
.await
{
Ok(filtered_messages) => {
client_request.set_messages(&filtered_messages);
info!(
original_count = messages.len(),
filtered_count = filtered_messages.len(),
"filter chain processed successfully"
);
Ok(filtered_bytes) => {
match ProviderRequestType::try_from((
&filtered_bytes[..],
&SupportedAPIsFromClient::from_endpoint(request_path.as_str()).unwrap(),
)) {
Ok(updated_request) => {
client_request = updated_request;
info!("input filter chain processed successfully");
}
Err(parse_err) => {
warn!(error = %parse_err, "input filter returned invalid request JSON");
return Ok(BrightStaffError::InvalidRequest(format!(
"Input filter returned invalid request: {}",
parse_err
))
.into_response());
}
}
}
Err(super::pipeline_processor::PipelineError::ClientError {
agent,
@ -508,21 +520,25 @@ async fn llm_chat_inner(
propagator.inject_context(&cx, &mut HeaderInjector(&mut request_headers));
});
// Output filters are only supported for /v1/chat/completions — the SSE content
// extraction logic is specific to that API shape (choices[].delta.content).
let output_filters_configured = output_filters
.as_ref()
.as_ref()
.map(|fc| !fc.is_empty())
.unwrap_or(false);
let has_output_filter = output_filters_configured
&& request_path == common::consts::CHAT_COMPLETIONS_PATH;
if output_filters_configured && !has_output_filter {
warn!(
path = %request_path,
"output filters are configured but only supported for /v1/chat/completions, skipping"
);
}
let has_output_filter = output_filters_configured;
// Extract the upstream API path (e.g. "/v1/messages" from "https://api.anthropic.com/v1/messages").
// Output filters are called at <agent.url><upstream_api_path> so they know the exact byte format.
let upstream_api_path = {
let after_scheme = full_qualified_llm_provider_url
.find("://")
.map(|i| &full_qualified_llm_provider_url[i + 3..])
.unwrap_or(&full_qualified_llm_provider_url);
after_scheme
.find('/')
.map(|i| after_scheme[i..].to_string())
.unwrap_or_else(|| "/".to_string())
};
// Save request headers for output filters (before they're consumed by upstream request)
let output_filter_request_headers = if has_output_filter {
@ -607,6 +623,7 @@ async fn llm_chat_inner(
ofc,
ofa,
output_filter_request_headers.unwrap(),
upstream_api_path.clone(),
)
} else {
create_streaming_response(byte_stream, state_processor, 16)
@ -621,6 +638,7 @@ async fn llm_chat_inner(
ofc,
ofa,
output_filter_request_headers.unwrap(),
upstream_api_path,
)
} else {
// Use base processor without state management

View file

@ -1,5 +1,6 @@
use std::collections::HashMap;
use bytes::Bytes;
use common::configuration::{Agent, AgentFilterChain};
use common::consts::{
ARCH_UPSTREAM_HOST_HEADER, BRIGHT_STAFF_SERVICE_NAME, ENVOY_RETRY_HEADER, TRACE_PARENT_HEADER,
@ -605,6 +606,139 @@ impl PipelineProcessor {
Ok(messages)
}
/// Execute a raw bytes filter — POST bytes to agent.url, receive bytes back.
/// Used for input and output filters where the full raw request/response is passed through.
/// No MCP protocol wrapping; agent_type is ignored.
#[instrument(
skip(self, raw_bytes, agent, request_headers),
fields(
agent_id = %agent.id,
agent_url = %agent.url,
filter_name = %agent.id,
bytes_len = raw_bytes.len()
)
)]
async fn execute_raw_filter(
&mut self,
raw_bytes: &[u8],
agent: &Agent,
request_headers: &HeaderMap,
request_path: &str,
) -> Result<Bytes, PipelineError> {
set_service_name(operation_component::AGENT_FILTER);
use opentelemetry::trace::get_active_span;
get_active_span(|span| {
span.update_name(format!("execute_raw_filter ({})", agent.id));
});
let mut agent_headers = request_headers.clone();
agent_headers.remove(hyper::header::CONTENT_LENGTH);
agent_headers.remove(TRACE_PARENT_HEADER);
global::get_text_map_propagator(|propagator| {
let cx =
tracing_opentelemetry::OpenTelemetrySpanExt::context(&tracing::Span::current());
propagator.inject_context(&cx, &mut HeaderInjector(&mut agent_headers));
});
agent_headers.insert(
ARCH_UPSTREAM_HOST_HEADER,
hyper::header::HeaderValue::from_str(&agent.id)
.map_err(|_| PipelineError::AgentNotFound(agent.id.clone()))?,
);
agent_headers.insert(
ENVOY_RETRY_HEADER,
hyper::header::HeaderValue::from_str("3").unwrap(),
);
agent_headers.insert(
"Accept",
hyper::header::HeaderValue::from_static("application/json"),
);
agent_headers.insert(
"Content-Type",
hyper::header::HeaderValue::from_static("application/json"),
);
// Append the original request path so the filter endpoint encodes the API format.
// e.g. agent.url="http://host/anonymize" + request_path="/v1/chat/completions"
// -> POST http://host/anonymize/v1/chat/completions
let url = format!("{}{}", agent.url, request_path);
debug!(agent = %agent.id, url = %url, "sending raw filter request");
let response = self
.client
.post(&url)
.headers(agent_headers)
.body(raw_bytes.to_vec())
.send()
.await?;
let http_status = response.status();
let response_bytes = response.bytes().await?;
if !http_status.is_success() {
let error_body = String::from_utf8_lossy(&response_bytes).to_string();
return Err(if http_status.is_client_error() {
PipelineError::ClientError {
agent: agent.id.clone(),
status: http_status.as_u16(),
body: error_body,
}
} else {
PipelineError::ServerError {
agent: agent.id.clone(),
status: http_status.as_u16(),
body: error_body,
}
});
}
debug!(agent = %agent.id, bytes_len = response_bytes.len(), "raw filter response received");
Ok(response_bytes)
}
/// Process a chain of raw-bytes filters sequentially.
/// Input: raw request or response bytes. Output: filtered bytes.
/// Each agent receives the output of the previous one.
pub async fn process_raw_filter_chain(
&mut self,
raw_bytes: &[u8],
agent_filter_chain: &AgentFilterChain,
agent_map: &HashMap<String, Agent>,
request_headers: &HeaderMap,
request_path: &str,
) -> Result<Bytes, PipelineError> {
let filter_chain = match agent_filter_chain.filter_chain.as_ref() {
Some(fc) if !fc.is_empty() => fc,
_ => return Ok(Bytes::copy_from_slice(raw_bytes)),
};
let mut current_bytes = Bytes::copy_from_slice(raw_bytes);
for agent_name in filter_chain {
debug!(agent = %agent_name, "processing raw filter agent");
let agent = agent_map
.get(agent_name)
.ok_or_else(|| PipelineError::AgentNotFound(agent_name.clone()))?;
info!(
agent = %agent_name,
url = %agent.url,
bytes_len = current_bytes.len(),
"executing raw filter"
);
current_bytes = self
.execute_raw_filter(&current_bytes, agent, request_headers, request_path)
.await?;
info!(agent = %agent_name, bytes_len = current_bytes.len(), "raw filter completed");
}
Ok(current_bytes)
}
/// Send request to terminal agent and return the raw response for streaming
/// Note: The caller is responsible for creating the plano(agent) span that wraps
/// both this call and the subsequent response consumption.

View file

@ -14,10 +14,10 @@ use tokio_stream::StreamExt;
use tracing::{debug, info, warn, Instrument};
use tracing_opentelemetry::OpenTelemetrySpanExt;
use super::pipeline_processor::PipelineProcessor;
use super::pipeline_processor::{PipelineError, PipelineProcessor};
use crate::signals::{InteractionQuality, SignalAnalyzer, TextBasedSignalAnalyzer, FLAG_MARKER};
use crate::tracing::{llm, set_service_name, signals as signal_constants};
use hermesllm::apis::openai::{Message, MessageContent, Role};
use hermesllm::apis::openai::Message;
/// Trait for processing streaming chunks
/// Implementors can inject custom logic during streaming (e.g., hallucination detection, logging)
@ -281,184 +281,9 @@ where
}
}
/// Extract content text from an SSE chunk line (the JSON part after "data: ").
/// Returns the content string and whether it was found.
fn extract_sse_content(json_str: &str) -> Option<String> {
let value: serde_json::Value = serde_json::from_str(json_str).ok()?;
value
.get("choices")?
.get(0)?
.get("delta")?
.get("content")?
.as_str()
.map(|s| s.to_string())
}
/// Replace content in an SSE JSON chunk with new content.
fn replace_sse_content(json_str: &str, new_content: &str) -> Option<String> {
let mut value: serde_json::Value = serde_json::from_str(json_str).ok()?;
value
.get_mut("choices")?
.get_mut(0)?
.get_mut("delta")?
.as_object_mut()?
.insert(
"content".to_string(),
serde_json::Value::String(new_content.to_string()),
);
serde_json::to_string(&value).ok()
}
/// Process an SSE chunk through the output filter chain.
/// Parses each `data: ` line, extracts content, sends through filters, and reconstructs.
async fn filter_sse_chunk(
chunk_str: &str,
pipeline_processor: &mut PipelineProcessor,
filter_chain: &AgentFilterChain,
filter_agents: &HashMap<String, Agent>,
request_headers: &HeaderMap,
) -> String {
let mut result = String::new();
for line in chunk_str.split('\n') {
if let Some(json_str) = line.strip_prefix("data: ") {
if json_str.trim() == "[DONE]" {
result.push_str(line);
result.push('\n');
continue;
}
if let Some(content) = extract_sse_content(json_str) {
if content.is_empty() {
result.push_str(line);
result.push('\n');
continue;
}
// Send content through output filter chain
let messages = vec![Message {
role: Role::Assistant,
content: Some(MessageContent::Text(content)),
name: None,
tool_calls: None,
tool_call_id: None,
}];
match pipeline_processor
.process_filter_chain(&messages, filter_chain, filter_agents, request_headers)
.await
{
Ok(filtered_messages) => {
if let Some(msg) = filtered_messages.first() {
let filtered_content = match &msg.content {
Some(MessageContent::Text(t)) => Some(t.clone()),
_ => None,
};
if let Some(filtered_content) = filtered_content {
if let Some(new_json) =
replace_sse_content(json_str, &filtered_content)
{
result.push_str("data: ");
result.push_str(&new_json);
result.push('\n');
continue;
}
}
}
// Fallback: pass through original
result.push_str(line);
result.push('\n');
}
Err(e) => {
warn!(error = %e, "output filter chain error, passing through original chunk");
result.push_str(line);
result.push('\n');
}
}
} else {
// No content in this SSE line, pass through
result.push_str(line);
result.push('\n');
}
} else {
result.push_str(line);
result.push('\n');
}
}
// Remove trailing extra newline if the original didn't end with one
if !chunk_str.ends_with('\n') && result.ends_with('\n') {
result.pop();
}
result
}
/// Process a non-streaming JSON response through the output filter chain.
/// Extracts assistant message content, filters it, and reconstructs the response.
pub async fn filter_non_streaming_response(
response_bytes: &[u8],
pipeline_processor: &mut PipelineProcessor,
filter_chain: &AgentFilterChain,
filter_agents: &HashMap<String, Agent>,
request_headers: &HeaderMap,
) -> Bytes {
let response_str = match std::str::from_utf8(response_bytes) {
Ok(s) => s,
Err(_) => return Bytes::from(response_bytes.to_vec()),
};
let mut value: serde_json::Value = match serde_json::from_str(response_str) {
Ok(v) => v,
Err(_) => return Bytes::from(response_bytes.to_vec()),
};
// Extract content from choices[0].message.content
let content = value
.get("choices")
.and_then(|c| c.get(0))
.and_then(|c| c.get("message"))
.and_then(|m| m.get("content"))
.and_then(|c| c.as_str())
.map(|s| s.to_string());
if let Some(content) = content {
let messages = vec![Message {
role: Role::Assistant,
content: Some(MessageContent::Text(content)),
name: None,
tool_calls: None,
tool_call_id: None,
}];
match pipeline_processor
.process_filter_chain(&messages, filter_chain, filter_agents, request_headers)
.await
{
Ok(filtered_messages) => {
if let Some(msg) = filtered_messages.first() {
let filtered_content = match &msg.content {
Some(MessageContent::Text(t)) => Some(t.clone()),
_ => None,
};
if let Some(filtered_content) = filtered_content {
if let Some(choices) = value.get_mut("choices") {
if let Some(choice) = choices.get_mut(0) {
if let Some(message) = choice.get_mut("message") {
message.as_object_mut().unwrap().insert(
"content".to_string(),
serde_json::Value::String(filtered_content),
);
}
}
}
}
}
}
Err(e) => {
warn!(error = %e, "output filter chain error on non-streaming response");
}
}
}
Bytes::from(serde_json::to_string(&value).unwrap_or_else(|_| response_str.to_string()))
}
/// Creates a streaming response that processes each chunk through output filters.
/// The output filter is called asynchronously for each SSE chunk's content.
/// Creates a streaming response that processes each raw chunk through output filters.
/// Filters receive the raw LLM response bytes and return (possibly modified) bytes.
/// On filter error mid-stream the original chunk is passed through (headers already sent).
pub fn create_streaming_response_with_output_filter<S, P>(
mut byte_stream: S,
mut inner_processor: P,
@ -466,6 +291,7 @@ pub fn create_streaming_response_with_output_filter<S, P>(
output_filters: Vec<String>,
output_filter_agents: HashMap<String, Agent>,
request_headers: HeaderMap,
upstream_path: String,
) -> StreamingResponse
where
S: StreamExt<Item = Result<Bytes, reqwest::Error>> + Send + Unpin + 'static,
@ -501,32 +327,35 @@ where
is_first_chunk = false;
}
// Try to process through output filter chain
let processed_chunk = if let Ok(chunk_str) = std::str::from_utf8(&chunk) {
if chunk_str.contains("data: ") {
let filtered = filter_sse_chunk(
chunk_str,
&mut pipeline_processor,
&temp_filter_chain,
&output_filter_agents,
&request_headers,
)
.await;
Bytes::from(filtered)
} else {
// Non-SSE chunk (could be non-streaming JSON response)
let filtered = filter_non_streaming_response(
&chunk,
&mut pipeline_processor,
&temp_filter_chain,
&output_filter_agents,
&request_headers,
)
.await;
filtered
// Pass raw chunk bytes through the output filter chain
let processed_chunk = match pipeline_processor
.process_raw_filter_chain(
&chunk,
&temp_filter_chain,
&output_filter_agents,
&request_headers,
&upstream_path,
)
.await
{
Ok(filtered) => filtered,
Err(PipelineError::ClientError {
agent,
status,
body,
}) => {
warn!(
agent = %agent,
status = %status,
body = %body,
"output filter client error, passing through original chunk"
);
chunk
}
Err(e) => {
warn!(error = %e, "output filter error, passing through original chunk");
chunk
}
} else {
chunk
};
// Pass through inner processor for metrics/observability