add output filter chain (#822)

This commit is contained in:
Adil Hafeez 2026-03-18 17:58:20 -07:00 committed by GitHub
parent de2d8847f3
commit 1f23c573bf
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
59 changed files with 2961 additions and 2621 deletions

View file

@ -332,15 +332,38 @@ async fn handle_agent_chat_inner(
"processing agent"
);
// Process the filter chain
let chat_history = pipeline_processor
.process_filter_chain(
&current_messages,
selected_agent,
&agent_map,
&request_headers,
)
.await?;
// Process input filters — serialize current request as OpenAI chat completions body,
// pass raw bytes through each filter, then extract updated messages from the result.
let chat_history = if selected_agent
.input_filters
.as_ref()
.map(|f| !f.is_empty())
.unwrap_or(false)
{
let filter_body = serde_json::json!({
"model": client_request.model(),
"messages": current_messages,
});
let filter_bytes =
serde_json::to_vec(&filter_body).map_err(PipelineError::ParseError)?;
let filtered_bytes = pipeline_processor
.process_raw_filter_chain(
&filter_bytes,
selected_agent,
&agent_map,
&request_headers,
"/v1/chat/completions",
)
.await?;
let filtered_body: serde_json::Value =
serde_json::from_slice(&filtered_bytes).map_err(PipelineError::ParseError)?;
serde_json::from_value(filtered_body["messages"].clone())
.map_err(PipelineError::ParseError)?
} else {
current_messages.clone()
};
// Get agent details and invoke
let agent = agent_map.get(&agent_name).unwrap();

View file

@ -172,7 +172,7 @@ impl AgentSelector {
#[cfg(test)]
mod tests {
use super::*;
use common::configuration::{AgentFilterChain, Listener};
use common::configuration::{AgentFilterChain, Listener, ListenerType};
fn create_test_orchestrator_service() -> Arc<OrchestratorService> {
Arc::new(OrchestratorService::new(
@ -187,14 +187,17 @@ mod tests {
id: name.to_string(),
description: Some(description.to_string()),
default: Some(is_default),
filter_chain: Some(vec![name.to_string()]),
input_filters: Some(vec![name.to_string()]),
}
}
fn create_test_listener(name: &str, agents: Vec<AgentFilterChain>) -> Listener {
Listener {
listener_type: ListenerType::Agent,
name: name.to_string(),
agents: Some(agents),
input_filters: None,
output_filters: None,
port: 8080,
router: None,
}

View file

@ -17,7 +17,7 @@ use hyper::StatusCode;
#[cfg(test)]
mod tests {
use super::*;
use common::configuration::{Agent, AgentFilterChain, Listener};
use common::configuration::{Agent, AgentFilterChain, Listener, ListenerType};
fn create_test_orchestrator_service() -> Arc<OrchestratorService> {
Arc::new(OrchestratorService::new(
@ -64,7 +64,7 @@ mod tests {
let agent_pipeline = AgentFilterChain {
id: "terminal-agent".to_string(),
filter_chain: Some(vec![
input_filters: Some(vec![
"filter-agent".to_string(),
"terminal-agent".to_string(),
]),
@ -73,8 +73,11 @@ mod tests {
};
let listener = Listener {
listener_type: ListenerType::Agent,
name: "test-listener".to_string(),
agents: Some(vec![agent_pipeline.clone()]),
input_filters: None,
output_filters: None,
port: 8080,
router: None,
};
@ -107,23 +110,32 @@ mod tests {
// Create a pipeline with empty filter chain to avoid network calls
let test_pipeline = AgentFilterChain {
id: "terminal-agent".to_string(),
filter_chain: Some(vec![]), // Empty filter chain - no network calls needed
input_filters: Some(vec![]), // Empty filter chain - no network calls needed
description: None,
default: None,
};
let headers = HeaderMap::new();
let request_bytes = serde_json::to_vec(&request).expect("failed to serialize request");
let result = pipeline_processor
.process_filter_chain(&request.messages, &test_pipeline, &agent_map, &headers)
.process_raw_filter_chain(
&request_bytes,
&test_pipeline,
&agent_map,
&headers,
"/v1/chat/completions",
)
.await;
println!("Pipeline processing result: {:?}", result);
assert!(result.is_ok());
let processed_messages = result.unwrap();
// With empty filter chain, should return the original messages unchanged
assert_eq!(processed_messages.len(), 1);
if let Some(MessageContent::Text(content)) = &processed_messages[0].content {
let processed_bytes = result.unwrap();
// With empty filter chain, should return the original bytes unchanged
let processed_request: ChatCompletionsRequest =
serde_json::from_slice(&processed_bytes).expect("failed to deserialize response");
assert_eq!(processed_request.messages.len(), 1);
if let Some(MessageContent::Text(content)) = &processed_request.messages[0].content {
assert_eq!(content, "Hello world!");
} else {
panic!("Expected text content");

View file

@ -1,5 +1,5 @@
use bytes::Bytes;
use common::configuration::{ModelAlias, SpanAttributes};
use common::configuration::{FilterPipeline, ModelAlias, SpanAttributes};
use common::consts::{
ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER, REQUEST_ID_HEADER, TRACE_PARENT_HEADER,
};
@ -8,9 +8,9 @@ use hermesllm::apis::openai_responses::InputParam;
use hermesllm::clients::{SupportedAPIsFromClient, SupportedUpstreamAPIs};
use hermesllm::{ProviderRequest, ProviderRequestType};
use http_body_util::combinators::BoxBody;
use http_body_util::BodyExt;
use http_body_util::{BodyExt, Full};
use hyper::header::{self};
use hyper::{Request, Response};
use hyper::{Request, Response, StatusCode};
use opentelemetry::global;
use opentelemetry::trace::get_active_span;
use opentelemetry_http::HeaderInjector;
@ -19,9 +19,12 @@ use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::{debug, info, info_span, warn, Instrument};
use super::pipeline_processor::PipelineProcessor;
use crate::handlers::router_chat::router_chat_get_upstream_model;
use crate::handlers::utils::{
create_streaming_response, truncate_message, ObservableStreamProcessor,
use crate::handlers::streaming::{
create_streaming_response, create_streaming_response_with_output_filter, truncate_message,
ObservableStreamProcessor, StreamProcessor,
};
use crate::router::llm_router::RouterService;
use crate::state::response_state_processor::ResponsesStateProcessor;
@ -34,6 +37,7 @@ use crate::tracing::{
use common::errors::BrightStaffError;
#[allow(clippy::too_many_arguments)]
pub async fn llm_chat(
request: Request<hyper::body::Incoming>,
router_service: Arc<RouterService>,
@ -42,6 +46,7 @@ pub async fn llm_chat(
llm_providers: Arc<RwLock<LlmProviders>>,
span_attributes: Arc<Option<SpanAttributes>>,
state_storage: Option<Arc<dyn StateStorage>>,
filter_pipeline: Arc<FilterPipeline>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
let request_path = request.uri().path().to_string();
let request_headers = request.headers().clone();
@ -81,6 +86,7 @@ pub async fn llm_chat(
request_id,
request_path,
request_headers,
filter_pipeline,
)
.instrument(request_span)
.await
@ -98,6 +104,7 @@ async fn llm_chat_inner(
request_id: String,
request_path: String,
mut request_headers: hyper::HeaderMap,
filter_pipeline: Arc<FilterPipeline>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
// Set service name for LLM operations
set_service_name(operation_component::LLM);
@ -250,6 +257,85 @@ async fn llm_chat_inner(
if client_request.remove_metadata_key("plano_preference_config") {
debug!("removed plano_preference_config from metadata");
}
// === Input filters processing for model listener ===
// Filters receive the raw request bytes and return (possibly modified) raw bytes.
// The returned bytes are re-parsed into a ProviderRequestType to continue the request.
{
if let Some(ref input_chain) = filter_pipeline.input {
if !input_chain.is_empty() {
debug!(input_filters = ?input_chain.filter_ids, "processing model listener input filters");
let chain = input_chain.to_agent_filter_chain("model_listener");
let mut pipeline_processor = PipelineProcessor::default();
match pipeline_processor
.process_raw_filter_chain(
&chat_request_bytes,
&chain,
&input_chain.agents,
&request_headers,
&request_path,
)
.await
{
Ok(filtered_bytes) => {
match ProviderRequestType::try_from((
&filtered_bytes[..],
&SupportedAPIsFromClient::from_endpoint(request_path.as_str()).unwrap(),
)) {
Ok(updated_request) => {
client_request = updated_request;
info!("input filter chain processed successfully");
}
Err(parse_err) => {
warn!(error = %parse_err, "input filter returned invalid request JSON");
return Ok(BrightStaffError::InvalidRequest(format!(
"Input filter returned invalid request: {}",
parse_err
))
.into_response());
}
}
}
Err(super::pipeline_processor::PipelineError::ClientError {
agent,
status,
body,
}) => {
warn!(
agent = %agent,
status = %status,
body = %body,
"client error from filter chain"
);
let error_json = serde_json::json!({
"error": "FilterChainError",
"agent": agent,
"status": status,
"agent_response": body
});
let mut error_response = Response::new(full(error_json.to_string()));
*error_response.status_mut() =
StatusCode::from_u16(status).unwrap_or(StatusCode::BAD_REQUEST);
error_response.headers_mut().insert(
hyper::header::CONTENT_TYPE,
"application/json".parse().unwrap(),
);
return Ok(error_response);
}
Err(err) => {
warn!(error = %err, "filter chain processing failed");
let err_msg = format!("Filter chain processing failed: {}", err);
let mut internal_error = Response::new(full(err_msg));
*internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
return Ok(internal_error);
}
}
}
}
}
if let Some(ref client_api_kind) = client_api {
let upstream_api =
provider_id.compatible_api_for_client(client_api_kind, is_streaming_request);
@ -420,6 +506,18 @@ async fn llm_chat_inner(
propagator.inject_context(&cx, &mut HeaderInjector(&mut request_headers));
});
// Output filters run for any API shape that reaches this handler (e.g. /v1/chat/completions,
// /v1/messages, /v1/responses). Brightstaff does inbound translation and llm_gateway does
// outbound translation; filters receive raw response bytes and request path.
let has_output_filter = filter_pipeline.has_output_filters();
// Save request headers for output filters (before they're consumed by upstream request)
let output_filter_request_headers = if has_output_filter {
Some(request_headers.clone())
} else {
None
};
// Capture start time right before sending request to upstream
let request_start_time = std::time::Instant::now();
let _request_start_system_time = std::time::SystemTime::now();
@ -462,20 +560,18 @@ async fn llm_chat_inner(
);
// === v1/responses state management: Wrap with ResponsesStateProcessor ===
// Only wrap if we need to manage state (client is ResponsesAPI AND upstream is NOT ResponsesAPI AND state_storage is configured)
let streaming_response = if let (true, false, Some(state_store)) = (
// Pick the right processor: state-aware if needed, otherwise base metrics-only.
let processor: Box<dyn StreamProcessor> = if let (true, false, Some(state_store)) = (
should_manage_state,
original_input_items.is_empty(),
state_storage,
) {
// Extract Content-Encoding header to handle decompression for state parsing
let content_encoding = response_headers
.get("content-encoding")
.and_then(|v| v.to_str().ok())
.map(|s| s.to_string());
// Wrap with state management processor to store state after response completes
let state_processor = ResponsesStateProcessor::new(
Box::new(ResponsesStateProcessor::new(
base_processor,
state_store,
original_input_items,
@ -485,11 +581,23 @@ async fn llm_chat_inner(
false, // Not OpenAI upstream since should_manage_state is true
content_encoding,
request_id,
);
create_streaming_response(byte_stream, state_processor, 16)
))
} else {
// Use base processor without state management
create_streaming_response(byte_stream, base_processor, 16)
Box::new(base_processor)
};
// Apply output filters if configured, then build the streaming response.
let streaming_response = if has_output_filter {
let output_chain = filter_pipeline.output.as_ref().unwrap().clone();
create_streaming_response_with_output_filter(
byte_stream,
processor,
output_chain,
output_filter_request_headers.unwrap(),
request_path.clone(),
)
} else {
create_streaming_response(byte_stream, processor)
};
match response.body(streaming_response.body) {
@ -570,3 +678,9 @@ async fn get_provider_info(
(hermesllm::ProviderId::OpenAI, None)
}
}
fn full<T: Into<Bytes>>(chunk: T) -> BoxBody<Bytes, hyper::Error> {
Full::new(chunk.into())
.map_err(|never| match never {})
.boxed()
}

View file

@ -8,7 +8,7 @@ pub mod pipeline_processor;
pub mod response_handler;
pub mod router_chat;
pub mod routing_service;
pub mod utils;
pub mod streaming;
#[cfg(test)]
mod integration_tests;

View file

@ -1,5 +1,6 @@
use std::collections::HashMap;
use bytes::Bytes;
use common::configuration::{Agent, AgentFilterChain};
use common::consts::{
ARCH_UPSTREAM_HOST_HEADER, BRIGHT_STAFF_SERVICE_NAME, ENVOY_RETRY_HEADER, TRACE_PARENT_HEADER,
@ -35,8 +36,6 @@ pub enum PipelineError {
NoResultInResponse(String),
#[error("No structured content in response from agent '{0}'")]
NoStructuredContentInResponse(String),
#[error("No messages in response from agent '{0}'")]
NoMessagesInResponse(String),
#[error("Client error from agent '{agent}' (HTTP {status}): {body}")]
ClientError {
agent: String,
@ -79,68 +78,6 @@ impl PipelineProcessor {
}
}
// /// Process the filter chain of agents (all except the terminal agent)
// #[instrument(
// skip(self, chat_history, agent_filter_chain, agent_map, request_headers),
// fields(
// filter_count = agent_filter_chain.filter_chain.as_ref().map(|fc| fc.len()).unwrap_or(0),
// message_count = chat_history.len()
// )
// )]
#[allow(clippy::too_many_arguments)]
pub async fn process_filter_chain(
&mut self,
chat_history: &[Message],
agent_filter_chain: &AgentFilterChain,
agent_map: &HashMap<String, Agent>,
request_headers: &HeaderMap,
) -> Result<Vec<Message>, PipelineError> {
let mut chat_history_updated = chat_history.to_vec();
// If filter_chain is None or empty, proceed without filtering
let filter_chain = match agent_filter_chain.filter_chain.as_ref() {
Some(fc) if !fc.is_empty() => fc,
_ => return Ok(chat_history_updated),
};
for agent_name in filter_chain {
debug!(agent = %agent_name, "processing filter agent");
let agent = agent_map
.get(agent_name)
.ok_or_else(|| PipelineError::AgentNotFound(agent_name.clone()))?;
let tool_name = agent.tool.as_deref().unwrap_or(&agent.id);
info!(
agent = %agent_name,
tool = %tool_name,
url = %agent.url,
agent_type = %agent.agent_type.as_deref().unwrap_or("mcp"),
conversation_len = chat_history.len(),
"executing filter"
);
if agent.agent_type.as_deref().unwrap_or("mcp") == "mcp" {
chat_history_updated = self
.execute_mcp_filter(&chat_history_updated, agent, request_headers)
.await?;
} else {
chat_history_updated = self
.execute_http_filter(&chat_history_updated, agent, request_headers)
.await?;
}
info!(
agent = %agent_name,
updated_len = chat_history_updated.len(),
"filter completed"
);
}
Ok(chat_history_updated)
}
/// Build common MCP headers for requests
fn build_mcp_headers(
&self,
@ -261,14 +198,17 @@ impl PipelineProcessor {
Ok(response)
}
/// Build a tools/call JSON-RPC request
fn build_tool_call_request(
/// Build a tools/call JSON-RPC request with a full body dict and path hint.
/// Used by execute_mcp_filter_raw so MCP tools receive the same contract as HTTP filters.
fn build_tool_call_request_with_body(
&self,
tool_name: &str,
messages: &[Message],
body: &serde_json::Value,
path: &str,
) -> Result<JsonRpcRequest, PipelineError> {
let mut arguments = HashMap::new();
arguments.insert("messages".to_string(), serde_json::to_value(messages)?);
arguments.insert("body".to_string(), serde_json::to_value(body)?);
arguments.insert("path".to_string(), serde_json::to_value(path)?);
let mut params = HashMap::new();
params.insert("name".to_string(), serde_json::to_value(tool_name)?);
@ -282,31 +222,24 @@ impl PipelineProcessor {
})
}
/// Send request to a specific agent and return the response content
#[instrument(
skip(self, messages, agent, request_headers),
fields(
agent_id = %agent.id,
filter_name = %agent.id,
message_count = messages.len()
)
)]
async fn execute_mcp_filter(
/// Like execute_mcp_filter_raw but passes the full raw body dict + path hint as MCP tool arguments.
/// The MCP tool receives (body: dict, path: str) and returns the modified body dict.
async fn execute_mcp_filter_raw(
&mut self,
messages: &[Message],
raw_bytes: &[u8],
agent: &Agent,
request_headers: &HeaderMap,
) -> Result<Vec<Message>, PipelineError> {
// Set service name for this filter span
request_path: &str,
) -> Result<Bytes, PipelineError> {
set_service_name(operation_component::AGENT_FILTER);
// Update current span name to include filter name
use opentelemetry::trace::get_active_span;
get_active_span(|span| {
span.update_name(format!("execute_mcp_filter ({})", agent.id));
span.update_name(format!("execute_mcp_filter_raw ({})", agent.id));
});
// Get or create MCP session
let body: serde_json::Value =
serde_json::from_slice(raw_bytes).map_err(PipelineError::ParseError)?;
let mcp_session_id = if let Some(session_id) = self.agent_id_session_map.get(&agent.id) {
session_id.clone()
} else {
@ -321,11 +254,10 @@ impl PipelineProcessor {
mcp_session_id, agent.id
);
// Build JSON-RPC request
let tool_name = agent.tool.as_deref().unwrap_or(&agent.id);
let json_rpc_request = self.build_tool_call_request(tool_name, messages)?;
let json_rpc_request =
self.build_tool_call_request_with_body(tool_name, &body, request_path)?;
// Build headers
let agent_headers =
self.build_mcp_headers(request_headers, &agent.id, Some(&mcp_session_id))?;
@ -335,7 +267,6 @@ impl PipelineProcessor {
let http_status = response.status();
let response_bytes = response.bytes().await?;
// Handle HTTP errors
if !http_status.is_success() {
let error_body = String::from_utf8_lossy(&response_bytes).to_string();
return Err(if http_status.is_client_error() {
@ -353,20 +284,12 @@ impl PipelineProcessor {
});
}
info!(
"Response from agent {}: {}",
agent.id,
String::from_utf8_lossy(&response_bytes)
);
// Parse SSE response
let data_chunk = self.parse_sse_response(&response_bytes, &agent.id)?;
let response: JsonRpcResponse = serde_json::from_str(&data_chunk)?;
let response_result = response
.result
.ok_or_else(|| PipelineError::NoResultInResponse(agent.id.clone()))?;
// Check if error field is set in response result
if response_result
.get("isError")
.and_then(|v| v.as_bool())
@ -388,21 +311,28 @@ impl PipelineProcessor {
});
}
// Extract structured content and parse messages
let response_json = response_result
// FastMCP puts structured Pydantic return values in structuredContent.result,
// but plain dicts land in content[0].text as a JSON string. Try both.
let result = if let Some(structured) = response_result
.get("structuredContent")
.ok_or_else(|| PipelineError::NoStructuredContentInResponse(agent.id.clone()))?;
.and_then(|v| v.get("result"))
.cloned()
{
structured
} else {
let text = response_result
.get("content")
.and_then(|v| v.as_array())
.and_then(|arr| arr.first())
.and_then(|v| v.get("text"))
.and_then(|v| v.as_str())
.ok_or_else(|| PipelineError::NoStructuredContentInResponse(agent.id.clone()))?;
serde_json::from_str(text).map_err(PipelineError::ParseError)?
};
let messages: Vec<Message> = response_json
.get("result")
.and_then(|v| v.as_array())
.ok_or_else(|| PipelineError::NoMessagesInResponse(agent.id.clone()))?
.iter()
.map(|msg_value| serde_json::from_value(msg_value.clone()))
.collect::<Result<Vec<Message>, _>>()
.map_err(PipelineError::ParseError)?;
Ok(messages)
Ok(Bytes::from(
serde_json::to_vec(&result).map_err(PipelineError::ParseError)?,
))
}
/// Build an initialize JSON-RPC request
@ -499,36 +429,34 @@ impl PipelineProcessor {
session_id
}
/// Execute a HTTP-based filter agent
/// Execute a raw bytes filter — POST bytes to agent.url, receive bytes back.
/// Used for input and output filters where the full raw request/response is passed through.
/// No MCP protocol wrapping; agent_type is ignored.
#[instrument(
skip(self, messages, agent, request_headers),
skip(self, raw_bytes, agent, request_headers),
fields(
agent_id = %agent.id,
agent_url = %agent.url,
filter_name = %agent.id,
message_count = messages.len()
bytes_len = raw_bytes.len()
)
)]
async fn execute_http_filter(
async fn execute_raw_filter(
&mut self,
messages: &[Message],
raw_bytes: &[u8],
agent: &Agent,
request_headers: &HeaderMap,
) -> Result<Vec<Message>, PipelineError> {
// Set service name for this filter span
request_path: &str,
) -> Result<Bytes, PipelineError> {
set_service_name(operation_component::AGENT_FILTER);
// Update current span name to include filter name
use opentelemetry::trace::get_active_span;
get_active_span(|span| {
span.update_name(format!("execute_http_filter ({})", agent.id));
span.update_name(format!("execute_raw_filter ({})", agent.id));
});
// Build headers
let mut agent_headers = request_headers.clone();
agent_headers.remove(hyper::header::CONTENT_LENGTH);
// Inject OpenTelemetry trace context automatically
agent_headers.remove(TRACE_PARENT_HEADER);
global::get_text_map_propagator(|propagator| {
let cx =
@ -541,40 +469,36 @@ impl PipelineProcessor {
hyper::header::HeaderValue::from_str(&agent.id)
.map_err(|_| PipelineError::AgentNotFound(agent.id.clone()))?,
);
agent_headers.insert(
ENVOY_RETRY_HEADER,
hyper::header::HeaderValue::from_str("3").unwrap(),
);
agent_headers.insert(
"Accept",
hyper::header::HeaderValue::from_static("application/json"),
);
agent_headers.insert(
"Content-Type",
hyper::header::HeaderValue::from_static("application/json"),
);
debug!(
"Sending HTTP request to agent {} at URL: {}",
agent.id, agent.url
);
// Append the original request path so the filter endpoint encodes the API format.
// e.g. agent.url="http://host/anonymize" + request_path="/v1/chat/completions"
// -> POST http://host/anonymize/v1/chat/completions
let url = format!("{}{}", agent.url, request_path);
debug!(agent = %agent.id, url = %url, "sending raw filter request");
// Send messages array directly as request body
let response = self
.client
.post(&agent.url)
.post(&url)
.headers(agent_headers)
.json(&messages)
.body(raw_bytes.to_vec())
.send()
.await?;
let http_status = response.status();
let response_bytes = response.bytes().await?;
// Handle HTTP errors
if !http_status.is_success() {
let error_body = String::from_utf8_lossy(&response_bytes).to_string();
return Err(if http_status.is_client_error() {
@ -592,17 +516,56 @@ impl PipelineProcessor {
});
}
debug!(
"Response from HTTP agent {}: {}",
agent.id,
String::from_utf8_lossy(&response_bytes)
);
debug!(agent = %agent.id, bytes_len = response_bytes.len(), "raw filter response received");
Ok(response_bytes)
}
// Parse response - expecting array of messages directly
let messages: Vec<Message> =
serde_json::from_slice(&response_bytes).map_err(PipelineError::ParseError)?;
/// Process a chain of raw-bytes filters sequentially.
/// Input: raw request or response bytes. Output: filtered bytes.
/// Each agent receives the output of the previous one.
pub async fn process_raw_filter_chain(
&mut self,
raw_bytes: &[u8],
agent_filter_chain: &AgentFilterChain,
agent_map: &HashMap<String, Agent>,
request_headers: &HeaderMap,
request_path: &str,
) -> Result<Bytes, PipelineError> {
let filter_chain = match agent_filter_chain.input_filters.as_ref() {
Some(fc) if !fc.is_empty() => fc,
_ => return Ok(Bytes::copy_from_slice(raw_bytes)),
};
Ok(messages)
let mut current_bytes = Bytes::copy_from_slice(raw_bytes);
for agent_name in filter_chain {
debug!(agent = %agent_name, "processing raw filter agent");
let agent = agent_map
.get(agent_name)
.ok_or_else(|| PipelineError::AgentNotFound(agent_name.clone()))?;
let agent_type = agent.agent_type.as_deref().unwrap_or("mcp");
info!(
agent = %agent_name,
url = %agent.url,
agent_type = %agent_type,
bytes_len = current_bytes.len(),
"executing raw filter"
);
current_bytes = if agent_type == "mcp" {
self.execute_mcp_filter_raw(&current_bytes, agent, request_headers, request_path)
.await?
} else {
self.execute_raw_filter(&current_bytes, agent, request_headers, request_path)
.await?
};
info!(agent = %agent_name, bytes_len = current_bytes.len(), "raw filter completed");
}
Ok(current_bytes)
}
/// Send request to terminal agent and return the raw response for streaming
@ -661,24 +624,13 @@ impl PipelineProcessor {
#[cfg(test)]
mod tests {
use super::*;
use hermesllm::apis::openai::{Message, MessageContent, Role};
use mockito::Server;
use std::collections::HashMap;
fn create_test_message(role: Role, content: &str) -> Message {
Message {
role,
content: Some(MessageContent::Text(content.to_string())),
name: None,
tool_calls: None,
tool_call_id: None,
}
}
fn create_test_pipeline(agents: Vec<&str>) -> AgentFilterChain {
AgentFilterChain {
id: "test-agent".to_string(),
filter_chain: Some(agents.iter().map(|s| s.to_string()).collect()),
input_filters: Some(agents.iter().map(|s| s.to_string()).collect()),
description: None,
default: None,
}
@ -690,12 +642,19 @@ mod tests {
let agent_map = HashMap::new();
let request_headers = HeaderMap::new();
let messages = vec![create_test_message(Role::User, "Hello")];
let body = serde_json::json!({"messages": [{"role": "user", "content": "Hello"}]});
let raw_bytes = serde_json::to_vec(&body).unwrap();
let pipeline = create_test_pipeline(vec!["nonexistent-agent", "terminal-agent"]);
let result = processor
.process_filter_chain(&messages, &pipeline, &agent_map, &request_headers)
.process_raw_filter_chain(
&raw_bytes,
&pipeline,
&agent_map,
&request_headers,
"/v1/chat/completions",
)
.await;
assert!(result.is_err());
@ -725,11 +684,12 @@ mod tests {
agent_type: None,
};
let messages = vec![create_test_message(Role::User, "Hello")];
let body = serde_json::json!({"messages": [{"role": "user", "content": "Hello"}]});
let raw_bytes = serde_json::to_vec(&body).unwrap();
let request_headers = HeaderMap::new();
let result = processor
.execute_mcp_filter(&messages, &agent, &request_headers)
.execute_mcp_filter_raw(&raw_bytes, &agent, &request_headers, "/v1/chat/completions")
.await;
match result {
@ -764,11 +724,12 @@ mod tests {
agent_type: None,
};
let messages = vec![create_test_message(Role::User, "Ping")];
let body = serde_json::json!({"messages": [{"role": "user", "content": "Ping"}]});
let raw_bytes = serde_json::to_vec(&body).unwrap();
let request_headers = HeaderMap::new();
let result = processor
.execute_mcp_filter(&messages, &agent, &request_headers)
.execute_mcp_filter_raw(&raw_bytes, &agent, &request_headers, "/v1/chat/completions")
.await;
match result {
@ -816,11 +777,12 @@ mod tests {
agent_type: None,
};
let messages = vec![create_test_message(Role::User, "Hi")];
let body = serde_json::json!({"messages": [{"role": "user", "content": "Hi"}]});
let raw_bytes = serde_json::to_vec(&body).unwrap();
let request_headers = HeaderMap::new();
let result = processor
.execute_mcp_filter(&messages, &agent, &request_headers)
.execute_mcp_filter_raw(&raw_bytes, &agent, &request_headers, "/v1/chat/completions")
.await;
match result {

View file

@ -1,16 +1,21 @@
use bytes::Bytes;
use common::configuration::ResolvedFilterChain;
use http_body_util::combinators::BoxBody;
use http_body_util::StreamBody;
use hyper::body::Frame;
use hyper::header::HeaderMap;
use opentelemetry::trace::TraceContextExt;
use opentelemetry::KeyValue;
use std::time::Instant;
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
use tokio_stream::StreamExt;
use tracing::{info, warn, Instrument};
use tracing::{debug, info, warn, Instrument};
use tracing_opentelemetry::OpenTelemetrySpanExt;
use super::pipeline_processor::{PipelineError, PipelineProcessor};
const STREAM_BUFFER_SIZE: usize = 16;
use crate::signals::{InteractionQuality, SignalAnalyzer, TextBasedSignalAnalyzer, FLAG_MARKER};
use crate::tracing::{llm, set_service_name, signals as signal_constants};
use hermesllm::apis::openai::Message;
@ -31,6 +36,21 @@ pub trait StreamProcessor: Send + 'static {
fn on_error(&mut self, _error: &str) {}
}
impl StreamProcessor for Box<dyn StreamProcessor> {
fn process_chunk(&mut self, chunk: Bytes) -> Result<Option<Bytes>, String> {
(**self).process_chunk(chunk)
}
fn on_first_bytes(&mut self) {
(**self).on_first_bytes()
}
fn on_complete(&mut self) {
(**self).on_complete()
}
fn on_error(&mut self, error: &str) {
(**self).on_error(error)
}
}
/// A processor that tracks streaming metrics
pub struct ObservableStreamProcessor {
service_name: String,
@ -206,16 +226,12 @@ pub struct StreamingResponse {
pub processor_handle: tokio::task::JoinHandle<()>,
}
pub fn create_streaming_response<S, P>(
mut byte_stream: S,
mut processor: P,
buffer_size: usize,
) -> StreamingResponse
pub fn create_streaming_response<S, P>(mut byte_stream: S, mut processor: P) -> StreamingResponse
where
S: StreamExt<Item = Result<Bytes, reqwest::Error>> + Send + Unpin + 'static,
P: StreamProcessor,
{
let (tx, rx) = mpsc::channel::<Bytes>(buffer_size);
let (tx, rx) = mpsc::channel::<Bytes>(STREAM_BUFFER_SIZE);
// Capture the current span so the spawned task inherits the request context
let current_span = tracing::Span::current();
@ -277,6 +293,108 @@ where
}
}
/// Creates a streaming response that processes each raw chunk through output filters.
/// Filters receive the raw LLM response bytes and request path (any API shape; not limited to
/// chat completions). On filter error mid-stream the original chunk is passed through (headers already sent).
pub fn create_streaming_response_with_output_filter<S, P>(
mut byte_stream: S,
mut inner_processor: P,
output_chain: ResolvedFilterChain,
request_headers: HeaderMap,
request_path: String,
) -> StreamingResponse
where
S: StreamExt<Item = Result<Bytes, reqwest::Error>> + Send + Unpin + 'static,
P: StreamProcessor,
{
let (tx, rx) = mpsc::channel::<Bytes>(STREAM_BUFFER_SIZE);
let current_span = tracing::Span::current();
let processor_handle = tokio::spawn(
async move {
let mut is_first_chunk = true;
let mut pipeline_processor = PipelineProcessor::default();
let chain = output_chain.to_agent_filter_chain("output_filter");
while let Some(item) = byte_stream.next().await {
let chunk = match item {
Ok(chunk) => chunk,
Err(err) => {
let err_msg = format!("Error receiving chunk: {:?}", err);
warn!(error = %err_msg, "stream error");
inner_processor.on_error(&err_msg);
break;
}
};
if is_first_chunk {
inner_processor.on_first_bytes();
is_first_chunk = false;
}
// Pass raw chunk bytes through the output filter chain
let processed_chunk = match pipeline_processor
.process_raw_filter_chain(
&chunk,
&chain,
&output_chain.agents,
&request_headers,
&request_path,
)
.await
{
Ok(filtered) => filtered,
Err(PipelineError::ClientError {
agent,
status,
body,
}) => {
warn!(
agent = %agent,
status = %status,
body = %body,
"output filter client error, passing through original chunk"
);
chunk
}
Err(e) => {
warn!(error = %e, "output filter error, passing through original chunk");
chunk
}
};
// Pass through inner processor for metrics/observability
match inner_processor.process_chunk(processed_chunk) {
Ok(Some(final_chunk)) => {
if tx.send(final_chunk).await.is_err() {
warn!("receiver dropped");
break;
}
}
Ok(None) => continue,
Err(err) => {
warn!("processor error: {}", err);
inner_processor.on_error(&err);
break;
}
}
}
inner_processor.on_complete();
debug!("output filter streaming completed");
}
.instrument(current_span),
);
let stream = ReceiverStream::new(rx).map(|chunk| Ok::<_, hyper::Error>(Frame::data(chunk)));
let stream_body = BoxBody::new(StreamBody::new(stream));
StreamingResponse {
body: stream_body,
processor_handle,
}
}
/// Truncates a message to the specified maximum length, adding "..." if truncated.
pub fn truncate_message(message: &str, max_length: usize) -> String {
if message.chars().count() > max_length {

View file

@ -10,7 +10,9 @@ use brightstaff::state::postgresql::PostgreSQLConversationStorage;
use brightstaff::state::StateStorage;
use brightstaff::utils::tracing::init_tracer;
use bytes::Bytes;
use common::configuration::{Agent, Configuration};
use common::configuration::{
Agent, Configuration, FilterPipeline, ListenerType, ResolvedFilterChain,
};
use common::consts::{CHAT_COMPLETIONS_PATH, MESSAGES_PATH, OPENAI_RESPONSES_API_PATH};
use common::llm_providers::LlmProviders;
use http_body_util::{combinators::BoxBody, BodyExt, Empty};
@ -22,6 +24,7 @@ use hyper_util::rt::TokioIo;
use opentelemetry::trace::FutureExt;
use opentelemetry::{global, Context};
use opentelemetry_http::HeaderExtractor;
use std::collections::HashMap;
use std::sync::Arc;
use std::{env, fs};
use tokio::net::TcpListener;
@ -80,11 +83,49 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
.cloned()
.collect();
// Build global agent map for resolving filter chain references
let global_agent_map: HashMap<String, Agent> = all_agents
.iter()
.map(|a| (a.id.clone(), a.clone()))
.collect();
// Create expanded provider list for /v1/models endpoint
let llm_providers = LlmProviders::try_from(plano_config.model_providers.clone())
.expect("Failed to create LlmProviders");
let llm_providers = Arc::new(RwLock::new(llm_providers));
let combined_agents_filters_list = Arc::new(RwLock::new(Some(all_agents)));
// Resolve model listener filter chain and agents at startup
let model_listener_count = plano_config
.listeners
.iter()
.filter(|l| l.listener_type == ListenerType::Model)
.count();
assert!(
model_listener_count <= 1,
"only one model listener is allowed, found {}",
model_listener_count
);
let model_listener = plano_config
.listeners
.iter()
.find(|l| l.listener_type == ListenerType::Model);
let resolve_chain = |filter_ids: Option<Vec<String>>| -> Option<ResolvedFilterChain> {
filter_ids.map(|ids| {
let agents = ids
.iter()
.filter_map(|id| global_agent_map.get(id).map(|a| (id.clone(), a.clone())))
.collect();
ResolvedFilterChain {
filter_ids: ids,
agents,
}
})
};
let filter_pipeline = Arc::new(FilterPipeline {
input: resolve_chain(model_listener.and_then(|l| l.input_filters.clone())),
output: resolve_chain(model_listener.and_then(|l| l.output_filters.clone())),
});
let listeners = Arc::new(RwLock::new(plano_config.listeners.clone()));
let llm_provider_url =
env::var("LLM_PROVIDER_ENDPOINT").unwrap_or_else(|_| "http://localhost:12001".to_string());
@ -200,6 +241,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let llm_providers = llm_providers.clone();
let agents_list = combined_agents_filters_list.clone();
let filter_pipeline = filter_pipeline.clone();
let listeners = listeners.clone();
let span_attributes = span_attributes.clone();
let state_storage = state_storage.clone();
@ -211,6 +253,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let llm_providers = llm_providers.clone();
let model_aliases = Arc::clone(&model_aliases);
let agents_list = agents_list.clone();
let filter_pipeline = filter_pipeline.clone();
let listeners = listeners.clone();
let span_attributes = span_attributes.clone();
let state_storage = state_storage.clone();
@ -269,6 +312,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
llm_providers,
span_attributes,
state_storage,
filter_pipeline,
)
.with_context(parent_cx)
.await

View file

@ -7,7 +7,7 @@ use std::io::Read;
use std::sync::Arc;
use tracing::{debug, info, warn};
use crate::handlers::utils::StreamProcessor;
use crate::handlers::streaming::StreamProcessor;
use crate::state::{OpenAIConversationState, StateStorage};
/// Processor that wraps another processor and handles v1/responses state management

View file

@ -27,14 +27,66 @@ pub struct AgentFilterChain {
pub id: String,
pub default: Option<bool>,
pub description: Option<String>,
pub filter_chain: Option<Vec<String>>,
pub input_filters: Option<Vec<String>>,
}
/// A filter chain with its agent references resolved to concrete Agent objects.
/// Bundles the ordered filter IDs with the agent lookup map so they stay in sync.
#[derive(Debug, Clone, Default)]
pub struct ResolvedFilterChain {
pub filter_ids: Vec<String>,
pub agents: HashMap<String, Agent>,
}
impl ResolvedFilterChain {
pub fn is_empty(&self) -> bool {
self.filter_ids.is_empty()
}
pub fn to_agent_filter_chain(&self, id: &str) -> AgentFilterChain {
AgentFilterChain {
id: id.to_string(),
default: None,
description: None,
input_filters: Some(self.filter_ids.clone()),
}
}
}
/// Holds resolved input and output filter chains for a model listener.
#[derive(Debug, Clone, Default)]
pub struct FilterPipeline {
pub input: Option<ResolvedFilterChain>,
pub output: Option<ResolvedFilterChain>,
}
impl FilterPipeline {
pub fn has_input_filters(&self) -> bool {
self.input.as_ref().is_some_and(|c| !c.is_empty())
}
pub fn has_output_filters(&self) -> bool {
self.output.as_ref().is_some_and(|c| !c.is_empty())
}
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum ListenerType {
Model,
Agent,
Prompt,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Listener {
#[serde(rename = "type")]
pub listener_type: ListenerType,
pub name: String,
pub router: Option<String>,
pub agents: Option<Vec<AgentFilterChain>>,
pub input_filters: Option<Vec<String>>,
pub output_filters: Option<Vec<String>>,
pub port: u16,
}