Merge remote-tracking branch 'origin/main' into adil/refactor_brightstaff

Made-with: Cursor

# Conflicts:
#	crates/brightstaff/src/handlers/agents/orchestrator.rs
#	crates/brightstaff/src/handlers/agents/pipeline.rs
#	crates/brightstaff/src/handlers/llm.rs
#	crates/brightstaff/src/main.rs
This commit is contained in:
Adil Hafeez 2026-03-18 18:14:38 -07:00
commit 8ed4b36087
60 changed files with 2948 additions and 2618 deletions

View file

@ -1,7 +1,7 @@
use std::collections::HashMap;
use std::sync::Arc;
use common::configuration::{Agent, Listener, ModelAlias, SpanAttributes};
use common::configuration::{Agent, FilterPipeline, Listener, ModelAlias, SpanAttributes};
use common::llm_providers::LlmProviders;
use tokio::sync::RwLock;
@ -25,4 +25,5 @@ pub struct AppState {
pub span_attributes: Arc<Option<SpanAttributes>>,
/// Shared HTTP client for upstream LLM requests (connection pooling / keep-alive).
pub http_client: reqwest::Client,
pub filter_pipeline: Arc<FilterPipeline>,
}

View file

@ -289,14 +289,36 @@ async fn execute_agent_chain(
"processing agent"
);
let chat_history = pipeline_processor
.process_filter_chain(
&current_messages,
selected_agent,
agent_map,
request_headers,
)
.await?;
let chat_history = if selected_agent
.input_filters
.as_ref()
.map(|f| !f.is_empty())
.unwrap_or(false)
{
let filter_body = serde_json::json!({
"model": client_request.model(),
"messages": current_messages,
});
let filter_bytes =
serde_json::to_vec(&filter_body).map_err(PipelineError::ParseError)?;
let filtered_bytes = pipeline_processor
.process_raw_filter_chain(
&filter_bytes,
selected_agent,
agent_map,
request_headers,
"/v1/chat/completions",
)
.await?;
let filtered_body: serde_json::Value =
serde_json::from_slice(&filtered_bytes).map_err(PipelineError::ParseError)?;
serde_json::from_value(filtered_body["messages"].clone())
.map_err(PipelineError::ParseError)?
} else {
current_messages.clone()
};
let agent = agent_map.get(&agent_name).ok_or_else(|| {
AgentFilterChainError::RequestParsing(serde_json::Error::custom(format!(

View file

@ -1,5 +1,6 @@
use std::collections::HashMap;
use bytes::Bytes;
use common::configuration::{Agent, AgentFilterChain};
use common::consts::{
ARCH_UPSTREAM_HOST_HEADER, BRIGHT_STAFF_SERVICE_NAME, ENVOY_RETRY_HEADER, TRACE_PARENT_HEADER,
@ -35,8 +36,6 @@ pub enum PipelineError {
NoResultInResponse(String),
#[error("No structured content in response from agent '{0}'")]
NoStructuredContentInResponse(String),
#[error("No messages in response from agent '{0}'")]
NoMessagesInResponse(String),
#[error("Client error from agent '{agent}' (HTTP {status}): {body}")]
ClientError {
agent: String,
@ -79,68 +78,6 @@ impl PipelineProcessor {
}
}
// /// Process the filter chain of agents (all except the terminal agent)
// #[instrument(
// skip(self, chat_history, agent_filter_chain, agent_map, request_headers),
// fields(
// filter_count = agent_filter_chain.filter_chain.as_ref().map(|fc| fc.len()).unwrap_or(0),
// message_count = chat_history.len()
// )
// )]
#[allow(clippy::too_many_arguments)]
pub async fn process_filter_chain(
&mut self,
chat_history: &[Message],
agent_filter_chain: &AgentFilterChain,
agent_map: &HashMap<String, Agent>,
request_headers: &HeaderMap,
) -> Result<Vec<Message>, PipelineError> {
let mut chat_history_updated = chat_history.to_vec();
// If filter_chain is None or empty, proceed without filtering
let filter_chain = match agent_filter_chain.filter_chain.as_ref() {
Some(fc) if !fc.is_empty() => fc,
_ => return Ok(chat_history_updated),
};
for agent_name in filter_chain {
debug!(agent = %agent_name, "processing filter agent");
let agent = agent_map
.get(agent_name)
.ok_or_else(|| PipelineError::AgentNotFound(agent_name.clone()))?;
let tool_name = agent.tool.as_deref().unwrap_or(&agent.id);
info!(
agent = %agent_name,
tool = %tool_name,
url = %agent.url,
agent_type = %agent.agent_type.as_deref().unwrap_or("mcp"),
conversation_len = chat_history.len(),
"executing filter"
);
if agent.agent_type.as_deref().unwrap_or("mcp") == "mcp" {
chat_history_updated = self
.execute_mcp_filter(&chat_history_updated, agent, request_headers)
.await?;
} else {
chat_history_updated = self
.execute_http_filter(&chat_history_updated, agent, request_headers)
.await?;
}
info!(
agent = %agent_name,
updated_len = chat_history_updated.len(),
"filter completed"
);
}
Ok(chat_history_updated)
}
/// Prepare headers shared by all agent/filter requests: removes
/// content-length, injects trace context, sets upstream host and retry.
fn build_agent_headers(
@ -270,14 +207,17 @@ impl PipelineProcessor {
Ok(response)
}
/// Build a tools/call JSON-RPC request
fn build_tool_call_request(
/// Build a tools/call JSON-RPC request with a full body dict and path hint.
/// Used by execute_mcp_filter_raw so MCP tools receive the same contract as HTTP filters.
fn build_tool_call_request_with_body(
&self,
tool_name: &str,
messages: &[Message],
body: &serde_json::Value,
path: &str,
) -> Result<JsonRpcRequest, PipelineError> {
let mut arguments = HashMap::new();
arguments.insert("messages".to_string(), serde_json::to_value(messages)?);
arguments.insert("body".to_string(), serde_json::to_value(body)?);
arguments.insert("path".to_string(), serde_json::to_value(path)?);
let mut params = HashMap::new();
params.insert("name".to_string(), serde_json::to_value(tool_name)?);
@ -291,31 +231,24 @@ impl PipelineProcessor {
})
}
/// Send request to a specific agent and return the response content
#[instrument(
skip(self, messages, agent, request_headers),
fields(
agent_id = %agent.id,
filter_name = %agent.id,
message_count = messages.len()
)
)]
async fn execute_mcp_filter(
/// Execute an MCP filter, passing the full raw body dict + path hint as MCP tool arguments.
/// The MCP tool receives (body: dict, path: str) and returns the modified body dict.
async fn execute_mcp_filter_raw(
&mut self,
messages: &[Message],
raw_bytes: &[u8],
agent: &Agent,
request_headers: &HeaderMap,
) -> Result<Vec<Message>, PipelineError> {
// Set service name for this filter span
request_path: &str,
) -> Result<Bytes, PipelineError> {
set_service_name(operation_component::AGENT_FILTER);
// Update current span name to include filter name
use opentelemetry::trace::get_active_span;
get_active_span(|span| {
span.update_name(format!("execute_mcp_filter ({})", agent.id));
span.update_name(format!("execute_mcp_filter_raw ({})", agent.id));
});
// Get or create MCP session
let body: serde_json::Value =
serde_json::from_slice(raw_bytes).map_err(PipelineError::ParseError)?;
let mcp_session_id = if let Some(session_id) = self.agent_id_session_map.get(&agent.id) {
session_id.clone()
} else {
@ -330,11 +263,10 @@ impl PipelineProcessor {
mcp_session_id, agent.id
);
// Build JSON-RPC request
let tool_name = agent.tool.as_deref().unwrap_or(&agent.id);
let json_rpc_request = self.build_tool_call_request(tool_name, messages)?;
let json_rpc_request =
self.build_tool_call_request_with_body(tool_name, &body, request_path)?;
// Build headers
let agent_headers =
self.build_mcp_headers(request_headers, &agent.id, Some(&mcp_session_id))?;
@ -344,7 +276,6 @@ impl PipelineProcessor {
let http_status = response.status();
let response_bytes = response.bytes().await?;
// Handle HTTP errors
if !http_status.is_success() {
let error_body = String::from_utf8_lossy(&response_bytes).to_string();
return Err(if http_status.is_client_error() {
@ -362,20 +293,12 @@ impl PipelineProcessor {
});
}
info!(
"Response from agent {}: {}",
agent.id,
String::from_utf8_lossy(&response_bytes)
);
// Parse SSE response
let data_chunk = self.parse_sse_response(&response_bytes, &agent.id)?;
let response: JsonRpcResponse = serde_json::from_str(&data_chunk)?;
let response_result = response
.result
.ok_or_else(|| PipelineError::NoResultInResponse(agent.id.clone()))?;
// Check if error field is set in response result
if response_result
.get("isError")
.and_then(|v| v.as_bool())
@ -397,21 +320,28 @@ impl PipelineProcessor {
});
}
// Extract structured content and parse messages
let response_json = response_result
// FastMCP puts structured Pydantic return values in structuredContent.result,
// but plain dicts land in content[0].text as a JSON string. Try both.
let result = if let Some(structured) = response_result
.get("structuredContent")
.ok_or_else(|| PipelineError::NoStructuredContentInResponse(agent.id.clone()))?;
.and_then(|v| v.get("result"))
.cloned()
{
structured
} else {
let text = response_result
.get("content")
.and_then(|v| v.as_array())
.and_then(|arr| arr.first())
.and_then(|v| v.get("text"))
.and_then(|v| v.as_str())
.ok_or_else(|| PipelineError::NoStructuredContentInResponse(agent.id.clone()))?;
serde_json::from_str(text).map_err(PipelineError::ParseError)?
};
let messages: Vec<Message> = response_json
.get("result")
.and_then(|v| v.as_array())
.ok_or_else(|| PipelineError::NoMessagesInResponse(agent.id.clone()))?
.iter()
.map(|msg_value| serde_json::from_value(msg_value.clone()))
.collect::<Result<Vec<Message>, _>>()
.map_err(PipelineError::ParseError)?;
Ok(messages)
Ok(Bytes::from(
serde_json::to_vec(&result).map_err(PipelineError::ParseError)?,
))
}
/// Build an initialize JSON-RPC request
@ -513,32 +443,31 @@ impl PipelineProcessor {
Ok(session_id)
}
/// Execute a HTTP-based filter agent
/// Execute a raw bytes filter — POST bytes to agent.url + request_path, receive bytes back.
/// Used for input and output filters where the full raw request/response is passed through.
/// No MCP protocol wrapping; agent_type is ignored.
#[instrument(
skip(self, messages, agent, request_headers),
skip(self, raw_bytes, agent, request_headers),
fields(
agent_id = %agent.id,
agent_url = %agent.url,
filter_name = %agent.id,
message_count = messages.len()
bytes_len = raw_bytes.len()
)
)]
async fn execute_http_filter(
async fn execute_raw_filter(
&mut self,
messages: &[Message],
raw_bytes: &[u8],
agent: &Agent,
request_headers: &HeaderMap,
) -> Result<Vec<Message>, PipelineError> {
// Set service name for this filter span
request_path: &str,
) -> Result<Bytes, PipelineError> {
set_service_name(operation_component::AGENT_FILTER);
// Update current span name to include filter name
use opentelemetry::trace::get_active_span;
get_active_span(|span| {
span.update_name(format!("execute_http_filter ({})", agent.id));
span.update_name(format!("execute_raw_filter ({})", agent.id));
});
// Build headers
let mut agent_headers = Self::build_agent_headers(request_headers, &agent.id)?;
agent_headers.insert(
"Accept",
@ -549,24 +478,23 @@ impl PipelineProcessor {
hyper::header::HeaderValue::from_static("application/json"),
);
debug!(
"Sending HTTP request to agent {} at URL: {}",
agent.id, agent.url
);
// Append the original request path so the filter endpoint encodes the API format.
// e.g. agent.url="http://host/anonymize" + request_path="/v1/chat/completions"
// -> POST http://host/anonymize/v1/chat/completions
let url = format!("{}{}", agent.url, request_path);
debug!(agent = %agent.id, url = %url, "sending raw filter request");
// Send messages array directly as request body
let response = self
.client
.post(&agent.url)
.post(&url)
.headers(agent_headers)
.json(&messages)
.body(raw_bytes.to_vec())
.send()
.await?;
let http_status = response.status();
let response_bytes = response.bytes().await?;
// Handle HTTP errors
if !http_status.is_success() {
let error_body = String::from_utf8_lossy(&response_bytes).to_string();
return Err(if http_status.is_client_error() {
@ -584,17 +512,56 @@ impl PipelineProcessor {
});
}
debug!(
"Response from HTTP agent {}: {}",
agent.id,
String::from_utf8_lossy(&response_bytes)
);
debug!(agent = %agent.id, bytes_len = response_bytes.len(), "raw filter response received");
Ok(response_bytes)
}
// Parse response - expecting array of messages directly
let messages: Vec<Message> =
serde_json::from_slice(&response_bytes).map_err(PipelineError::ParseError)?;
/// Process a chain of raw-bytes filters sequentially.
/// Input: raw request or response bytes. Output: filtered bytes.
/// Each agent receives the output of the previous one.
pub async fn process_raw_filter_chain(
&mut self,
raw_bytes: &[u8],
agent_filter_chain: &AgentFilterChain,
agent_map: &HashMap<String, Agent>,
request_headers: &HeaderMap,
request_path: &str,
) -> Result<Bytes, PipelineError> {
let filter_chain = match agent_filter_chain.input_filters.as_ref() {
Some(fc) if !fc.is_empty() => fc,
_ => return Ok(Bytes::copy_from_slice(raw_bytes)),
};
Ok(messages)
let mut current_bytes = Bytes::copy_from_slice(raw_bytes);
for agent_name in filter_chain {
debug!(agent = %agent_name, "processing raw filter agent");
let agent = agent_map
.get(agent_name)
.ok_or_else(|| PipelineError::AgentNotFound(agent_name.clone()))?;
let agent_type = agent.agent_type.as_deref().unwrap_or("mcp");
info!(
agent = %agent_name,
url = %agent.url,
agent_type = %agent_type,
bytes_len = current_bytes.len(),
"executing raw filter"
);
current_bytes = if agent_type == "mcp" {
self.execute_mcp_filter_raw(&current_bytes, agent, request_headers, request_path)
.await?
} else {
self.execute_raw_filter(&current_bytes, agent, request_headers, request_path)
.await?
};
info!(agent = %agent_name, bytes_len = current_bytes.len(), "raw filter completed");
}
Ok(current_bytes)
}
/// Send request to terminal agent and return the raw response for streaming
@ -633,24 +600,13 @@ impl PipelineProcessor {
#[cfg(test)]
mod tests {
use super::*;
use hermesllm::apis::openai::{Message, MessageContent, Role};
use mockito::Server;
use std::collections::HashMap;
fn create_test_message(role: Role, content: &str) -> Message {
Message {
role,
content: Some(MessageContent::Text(content.to_string())),
name: None,
tool_calls: None,
tool_call_id: None,
}
}
fn create_test_pipeline(agents: Vec<&str>) -> AgentFilterChain {
AgentFilterChain {
id: "test-agent".to_string(),
filter_chain: Some(agents.iter().map(|s| s.to_string()).collect()),
input_filters: Some(agents.iter().map(|s| s.to_string()).collect()),
description: None,
default: None,
}
@ -662,12 +618,19 @@ mod tests {
let agent_map = HashMap::new();
let request_headers = HeaderMap::new();
let messages = vec![create_test_message(Role::User, "Hello")];
let body = serde_json::json!({"messages": [{"role": "user", "content": "Hello"}]});
let raw_bytes = serde_json::to_vec(&body).unwrap();
let pipeline = create_test_pipeline(vec!["nonexistent-agent", "terminal-agent"]);
let result = processor
.process_filter_chain(&messages, &pipeline, &agent_map, &request_headers)
.process_raw_filter_chain(
&raw_bytes,
&pipeline,
&agent_map,
&request_headers,
"/v1/chat/completions",
)
.await;
assert!(result.is_err());
@ -697,11 +660,12 @@ mod tests {
agent_type: None,
};
let messages = vec![create_test_message(Role::User, "Hello")];
let body = serde_json::json!({"messages": [{"role": "user", "content": "Hello"}]});
let raw_bytes = serde_json::to_vec(&body).unwrap();
let request_headers = HeaderMap::new();
let result = processor
.execute_mcp_filter(&messages, &agent, &request_headers)
.execute_mcp_filter_raw(&raw_bytes, &agent, &request_headers, "/v1/chat/completions")
.await;
match result {
@ -736,11 +700,12 @@ mod tests {
agent_type: None,
};
let messages = vec![create_test_message(Role::User, "Ping")];
let body = serde_json::json!({"messages": [{"role": "user", "content": "Ping"}]});
let raw_bytes = serde_json::to_vec(&body).unwrap();
let request_headers = HeaderMap::new();
let result = processor
.execute_mcp_filter(&messages, &agent, &request_headers)
.execute_mcp_filter_raw(&raw_bytes, &agent, &request_headers, "/v1/chat/completions")
.await;
match result {
@ -788,11 +753,12 @@ mod tests {
agent_type: None,
};
let messages = vec![create_test_message(Role::User, "Hi")];
let body = serde_json::json!({"messages": [{"role": "user", "content": "Hi"}]});
let raw_bytes = serde_json::to_vec(&body).unwrap();
let request_headers = HeaderMap::new();
let result = processor
.execute_mcp_filter(&messages, &agent, &request_headers)
.execute_mcp_filter_raw(&raw_bytes, &agent, &request_headers, "/v1/chat/completions")
.await;
match result {

View file

@ -172,7 +172,7 @@ impl AgentSelector {
#[cfg(test)]
mod tests {
use super::*;
use common::configuration::{AgentFilterChain, Listener};
use common::configuration::{AgentFilterChain, Listener, ListenerType};
fn create_test_orchestrator_service() -> Arc<OrchestratorService> {
Arc::new(OrchestratorService::new(
@ -187,14 +187,17 @@ mod tests {
id: name.to_string(),
description: Some(description.to_string()),
default: Some(is_default),
filter_chain: Some(vec![name.to_string()]),
input_filters: Some(vec![name.to_string()]),
}
}
fn create_test_listener(name: &str, agents: Vec<AgentFilterChain>) -> Listener {
Listener {
listener_type: ListenerType::Agent,
name: name.to_string(),
agents: Some(agents),
input_filters: None,
output_filters: None,
port: 8080,
router: None,
}

View file

@ -16,7 +16,7 @@ use crate::router::orchestrator::OrchestratorService;
#[cfg(test)]
mod tests {
use super::*;
use common::configuration::{Agent, AgentFilterChain, Listener};
use common::configuration::{Agent, AgentFilterChain, Listener, ListenerType};
fn create_test_orchestrator_service() -> Arc<OrchestratorService> {
Arc::new(OrchestratorService::new(
@ -63,7 +63,7 @@ mod tests {
let agent_pipeline = AgentFilterChain {
id: "terminal-agent".to_string(),
filter_chain: Some(vec![
input_filters: Some(vec![
"filter-agent".to_string(),
"terminal-agent".to_string(),
]),
@ -72,8 +72,11 @@ mod tests {
};
let listener = Listener {
listener_type: ListenerType::Agent,
name: "test-listener".to_string(),
agents: Some(vec![agent_pipeline.clone()]),
input_filters: None,
output_filters: None,
port: 8080,
router: None,
};
@ -106,23 +109,32 @@ mod tests {
// Create a pipeline with empty filter chain to avoid network calls
let test_pipeline = AgentFilterChain {
id: "terminal-agent".to_string(),
filter_chain: Some(vec![]), // Empty filter chain - no network calls needed
input_filters: Some(vec![]), // Empty filter chain - no network calls needed
description: None,
default: None,
};
let headers = HeaderMap::new();
let request_bytes = serde_json::to_vec(&request).expect("failed to serialize request");
let result = pipeline_processor
.process_filter_chain(&request.messages, &test_pipeline, &agent_map, &headers)
.process_raw_filter_chain(
&request_bytes,
&test_pipeline,
&agent_map,
&headers,
"/v1/chat/completions",
)
.await;
println!("Pipeline processing result: {:?}", result);
assert!(result.is_ok());
let processed_messages = result.unwrap();
// With empty filter chain, should return the original messages unchanged
assert_eq!(processed_messages.len(), 1);
if let Some(MessageContent::Text(content)) = &processed_messages[0].content {
let processed_bytes = result.unwrap();
// With empty filter chain, should return the original bytes unchanged
let processed_request: ChatCompletionsRequest =
serde_json::from_slice(&processed_bytes).expect("failed to deserialize response");
assert_eq!(processed_request.messages.len(), 1);
if let Some(MessageContent::Text(content)) = &processed_request.messages[0].content {
assert_eq!(content, "Hello world!");
} else {
panic!("Expected text content");

View file

@ -1,5 +1,5 @@
use bytes::Bytes;
use common::configuration::ModelAlias;
use common::configuration::{FilterPipeline, ModelAlias};
use common::consts::{ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER, TRACE_PARENT_HEADER};
use common::llm_providers::LlmProviders;
use hermesllm::apis::openai::Message;
@ -21,9 +21,11 @@ use tracing::{debug, info, info_span, warn, Instrument};
pub(crate) mod router;
use crate::app_state::AppState;
use crate::handlers::agents::pipeline::PipelineProcessor;
use crate::handlers::request::extract_request_id;
use crate::handlers::utils::{
create_streaming_response, truncate_message, ObservableStreamProcessor,
use crate::handlers::streaming::{
create_streaming_response, create_streaming_response_with_output_filter, truncate_message,
ObservableStreamProcessor, StreamProcessor,
};
use crate::state::response_state_processor::ResponsesStateProcessor;
use crate::state::{
@ -111,6 +113,7 @@ async fn llm_chat_inner(
let PreparedRequest {
mut client_request,
chat_request_bytes,
model_from_request,
alias_resolved_model,
model_name_only,
@ -121,6 +124,8 @@ async fn llm_chat_inner(
tool_names,
user_message_preview,
inline_routing_policy,
client_api,
provider_id,
} = parsed;
// Record LLM-specific span attributes
@ -140,6 +145,81 @@ async fn llm_chat_inner(
span.record(tracing_llm::USER_MESSAGE_PREVIEW, preview.as_str());
}
// --- Phase 1b: Input filter processing for model listener ---
if let Some(ref input_chain) = state.filter_pipeline.input {
if !input_chain.is_empty() {
debug!(input_filters = ?input_chain.filter_ids, "processing model listener input filters");
let chain = input_chain.to_agent_filter_chain("model_listener");
let mut pipeline_processor = PipelineProcessor::default();
match pipeline_processor
.process_raw_filter_chain(
&chat_request_bytes,
&chain,
&input_chain.agents,
&request_headers,
&request_path,
)
.await
{
Ok(filtered_bytes) => {
let api_type =
SupportedAPIsFromClient::from_endpoint(request_path.as_str()).unwrap();
match ProviderRequestType::try_from((&filtered_bytes[..], &api_type)) {
Ok(updated_request) => {
client_request = updated_request;
info!("input filter chain processed successfully");
}
Err(parse_err) => {
warn!(error = %parse_err, "input filter returned invalid request JSON");
return Ok(
common::errors::BrightStaffError::InvalidRequest(format!(
"Input filter returned invalid request: {}",
parse_err
))
.into_response(),
);
}
}
}
Err(super::agents::pipeline::PipelineError::ClientError {
agent,
status,
body,
}) => {
warn!(agent = %agent, status = %status, body = %body, "client error from filter chain");
let error_json = serde_json::json!({
"error": "FilterChainError",
"agent": agent,
"status": status,
"agent_response": body
});
let mut error_response = Response::new(full(error_json.to_string()));
*error_response.status_mut() =
StatusCode::from_u16(status).unwrap_or(StatusCode::BAD_REQUEST);
error_response.headers_mut().insert(
hyper::header::CONTENT_TYPE,
"application/json".parse().unwrap(),
);
return Ok(error_response);
}
Err(err) => {
warn!(error = %err, "filter chain processing failed");
let mut internal_error =
Response::new(full(format!("Filter chain processing failed: {}", err)));
*internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
return Ok(internal_error);
}
}
}
}
// Normalize for upstream after input filters
if let Some(ref client_api_kind) = client_api {
let upstream_api =
provider_id.compatible_api_for_client(client_api_kind, is_streaming_request);
client_request.normalize_for_upstream(provider_id, &upstream_api);
}
// --- Phase 2: Resolve conversation state (v1/responses API) ---
let state_ctx = match resolve_conversation_state(
&mut client_request,
@ -227,6 +307,7 @@ async fn llm_chat_inner(
state_ctx,
state.state_storage.clone(),
request_id,
&state.filter_pipeline,
)
.await
}
@ -238,6 +319,7 @@ async fn llm_chat_inner(
/// All pre-validated request data extracted from the raw HTTP request.
struct PreparedRequest {
client_request: ProviderRequestType,
chat_request_bytes: Bytes,
model_from_request: String,
alias_resolved_model: String,
model_name_only: String,
@ -248,6 +330,8 @@ struct PreparedRequest {
tool_names: Option<Vec<String>>,
user_message_preview: Option<String>,
inline_routing_policy: Option<Vec<common::configuration::ModelUsagePreference>>,
client_api: Option<SupportedAPIsFromClient>,
provider_id: hermesllm::ProviderId,
}
/// Parse the body, resolve the model alias, and validate the model exists.
@ -350,14 +434,10 @@ async fn parse_and_validate_request(
if client_request.remove_metadata_key("plano_preference_config") {
debug!("removed plano_preference_config from metadata");
}
if let Some(ref client_api_kind) = client_api {
let upstream_api =
provider_id.compatible_api_for_client(client_api_kind, is_streaming_request);
client_request.normalize_for_upstream(provider_id, &upstream_api);
}
Ok(PreparedRequest {
client_request,
chat_request_bytes,
model_from_request,
alias_resolved_model,
model_name_only,
@ -368,6 +448,8 @@ async fn parse_and_validate_request(
tool_names,
user_message_preview,
inline_routing_policy,
client_api,
provider_id,
})
}
@ -501,6 +583,7 @@ async fn send_upstream(
state_ctx: ConversationStateContext,
state_storage: Option<Arc<dyn StateStorage>>,
request_id: String,
filter_pipeline: &Arc<FilterPipeline>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
let span_name = if model_from_request == resolved_model {
format!("POST {} {}", request_path, resolved_model)
@ -578,8 +661,15 @@ async fn send_upstream(
messages_for_signals,
);
// Wrap with state management processor when needed
let streaming_response = if let (true, false, Some(state_store)) = (
let has_output_filter = filter_pipeline.has_output_filters();
let output_filter_request_headers = if has_output_filter {
Some(request_headers.clone())
} else {
None
};
// Pick the right processor: state-aware if needed, otherwise base metrics-only.
let processor: Box<dyn StreamProcessor> = if let (true, false, Some(state_store)) = (
state_ctx.should_manage_state,
state_ctx.original_input_items.is_empty(),
state_storage,
@ -589,7 +679,7 @@ async fn send_upstream(
.and_then(|v| v.to_str().ok())
.map(|s| s.to_string());
let state_processor = ResponsesStateProcessor::new(
Box::new(ResponsesStateProcessor::new(
base_processor,
state_store,
state_ctx.original_input_items,
@ -599,10 +689,22 @@ async fn send_upstream(
false,
content_encoding,
request_id,
);
create_streaming_response(byte_stream, state_processor, 16)
))
} else {
create_streaming_response(byte_stream, base_processor, 16)
Box::new(base_processor)
};
let streaming_response = if has_output_filter {
let output_chain = filter_pipeline.output.as_ref().unwrap().clone();
create_streaming_response_with_output_filter(
byte_stream,
processor,
output_chain,
output_filter_request_headers.unwrap(),
request_path.to_string(),
)
} else {
create_streaming_response(byte_stream, processor)
};
match response.body(streaming_response.body) {

View file

@ -6,7 +6,7 @@ pub mod models;
pub mod request;
pub mod response;
pub mod routing_service;
pub mod utils;
pub mod streaming;
#[cfg(test)]
mod integration_tests;

View file

@ -1,16 +1,21 @@
use bytes::Bytes;
use common::configuration::ResolvedFilterChain;
use http_body_util::combinators::BoxBody;
use http_body_util::StreamBody;
use hyper::body::Frame;
use hyper::header::HeaderMap;
use opentelemetry::trace::TraceContextExt;
use opentelemetry::KeyValue;
use std::time::Instant;
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
use tokio_stream::StreamExt;
use tracing::{info, warn, Instrument};
use tracing::{debug, info, warn, Instrument};
use tracing_opentelemetry::OpenTelemetrySpanExt;
use super::agents::pipeline::{PipelineError, PipelineProcessor};
const STREAM_BUFFER_SIZE: usize = 16;
use crate::signals::{InteractionQuality, SignalAnalyzer, TextBasedSignalAnalyzer, FLAG_MARKER};
use crate::tracing::{llm, set_service_name, signals as signal_constants};
use hermesllm::apis::openai::Message;
@ -31,6 +36,21 @@ pub trait StreamProcessor: Send + 'static {
fn on_error(&mut self, _error: &str) {}
}
impl StreamProcessor for Box<dyn StreamProcessor> {
fn process_chunk(&mut self, chunk: Bytes) -> Result<Option<Bytes>, String> {
(**self).process_chunk(chunk)
}
fn on_first_bytes(&mut self) {
(**self).on_first_bytes()
}
fn on_complete(&mut self) {
(**self).on_complete()
}
fn on_error(&mut self, error: &str) {
(**self).on_error(error)
}
}
/// A processor that tracks streaming metrics
pub struct ObservableStreamProcessor {
service_name: String,
@ -206,16 +226,12 @@ pub struct StreamingResponse {
pub processor_handle: tokio::task::JoinHandle<()>,
}
pub fn create_streaming_response<S, P>(
mut byte_stream: S,
mut processor: P,
buffer_size: usize,
) -> StreamingResponse
pub fn create_streaming_response<S, P>(mut byte_stream: S, mut processor: P) -> StreamingResponse
where
S: StreamExt<Item = Result<Bytes, reqwest::Error>> + Send + Unpin + 'static,
P: StreamProcessor,
{
let (tx, rx) = mpsc::channel::<Bytes>(buffer_size);
let (tx, rx) = mpsc::channel::<Bytes>(STREAM_BUFFER_SIZE);
// Capture the current span so the spawned task inherits the request context
let current_span = tracing::Span::current();
@ -277,6 +293,108 @@ where
}
}
/// Creates a streaming response that processes each raw chunk through output filters.
/// Filters receive the raw LLM response bytes and request path (any API shape; not limited to
/// chat completions). On filter error mid-stream the original chunk is passed through (headers already sent).
pub fn create_streaming_response_with_output_filter<S, P>(
mut byte_stream: S,
mut inner_processor: P,
output_chain: ResolvedFilterChain,
request_headers: HeaderMap,
request_path: String,
) -> StreamingResponse
where
S: StreamExt<Item = Result<Bytes, reqwest::Error>> + Send + Unpin + 'static,
P: StreamProcessor,
{
let (tx, rx) = mpsc::channel::<Bytes>(STREAM_BUFFER_SIZE);
let current_span = tracing::Span::current();
let processor_handle = tokio::spawn(
async move {
let mut is_first_chunk = true;
let mut pipeline_processor = PipelineProcessor::default();
let chain = output_chain.to_agent_filter_chain("output_filter");
while let Some(item) = byte_stream.next().await {
let chunk = match item {
Ok(chunk) => chunk,
Err(err) => {
let err_msg = format!("Error receiving chunk: {:?}", err);
warn!(error = %err_msg, "stream error");
inner_processor.on_error(&err_msg);
break;
}
};
if is_first_chunk {
inner_processor.on_first_bytes();
is_first_chunk = false;
}
// Pass raw chunk bytes through the output filter chain
let processed_chunk = match pipeline_processor
.process_raw_filter_chain(
&chunk,
&chain,
&output_chain.agents,
&request_headers,
&request_path,
)
.await
{
Ok(filtered) => filtered,
Err(PipelineError::ClientError {
agent,
status,
body,
}) => {
warn!(
agent = %agent,
status = %status,
body = %body,
"output filter client error, passing through original chunk"
);
chunk
}
Err(e) => {
warn!(error = %e, "output filter error, passing through original chunk");
chunk
}
};
// Pass through inner processor for metrics/observability
match inner_processor.process_chunk(processed_chunk) {
Ok(Some(final_chunk)) => {
if tx.send(final_chunk).await.is_err() {
warn!("receiver dropped");
break;
}
}
Ok(None) => continue,
Err(err) => {
warn!("processor error: {}", err);
inner_processor.on_error(&err);
break;
}
}
}
inner_processor.on_complete();
debug!("output filter streaming completed");
}
.instrument(current_span),
);
let stream = ReceiverStream::new(rx).map(|chunk| Ok::<_, hyper::Error>(Frame::data(chunk)));
let stream_body = BoxBody::new(StreamBody::new(stream));
StreamingResponse {
body: stream_body,
processor_handle,
}
}
/// Truncates a message to the specified maximum length, adding "..." if truncated.
pub fn truncate_message(message: &str, max_length: usize) -> String {
if message.chars().count() > max_length {

View file

@ -11,7 +11,9 @@ use brightstaff::state::postgresql::PostgreSQLConversationStorage;
use brightstaff::state::StateStorage;
use brightstaff::tracing::init_tracer;
use bytes::Bytes;
use common::configuration::Configuration;
use common::configuration::{
Agent, Configuration, FilterPipeline, ListenerType, ResolvedFilterChain,
};
use common::consts::{CHAT_COMPLETIONS_PATH, MESSAGES_PATH, OPENAI_RESPONSES_API_PATH};
use common::llm_providers::LlmProviders;
use http_body_util::{combinators::BoxBody, BodyExt, Empty};
@ -24,6 +26,7 @@ use hyper_util::rt::TokioIo;
use opentelemetry::global;
use opentelemetry::trace::FutureExt;
use opentelemetry_http::HeaderExtractor;
use std::collections::HashMap;
use std::sync::Arc;
use std::{env, fs};
use tokio::net::TcpListener;
@ -99,7 +102,7 @@ async fn init_app_state(
env::var("LLM_PROVIDER_ENDPOINT").unwrap_or_else(|_| "http://localhost:12001".to_string());
// Combine agents and filters into a single list
let all_agents = config
let all_agents: Vec<Agent> = config
.agents
.as_deref()
.unwrap_or_default()
@ -108,9 +111,47 @@ async fn init_app_state(
.cloned()
.collect();
let global_agent_map: HashMap<String, Agent> = all_agents
.iter()
.map(|a| (a.id.clone(), a.clone()))
.collect();
let llm_providers = LlmProviders::try_from(config.model_providers.clone())
.map_err(|e| format!("failed to create LlmProviders: {e}"))?;
let model_listener_count = config
.listeners
.iter()
.filter(|l| l.listener_type == ListenerType::Model)
.count();
if model_listener_count > 1 {
return Err(format!(
"only one model listener is allowed, found {}",
model_listener_count
)
.into());
}
let model_listener = config
.listeners
.iter()
.find(|l| l.listener_type == ListenerType::Model);
let resolve_chain = |filter_ids: Option<Vec<String>>| -> Option<ResolvedFilterChain> {
filter_ids.map(|ids| {
let agents = ids
.iter()
.filter_map(|id| global_agent_map.get(id).map(|a: &Agent| (id.clone(), a.clone())))
.collect();
ResolvedFilterChain {
filter_ids: ids,
agents,
}
})
};
let filter_pipeline = Arc::new(FilterPipeline {
input: resolve_chain(model_listener.and_then(|l| l.input_filters.clone())),
output: resolve_chain(model_listener.and_then(|l| l.output_filters.clone())),
});
let overrides = config.overrides.clone().unwrap_or_default();
let routing_model_name: String = overrides
@ -174,6 +215,7 @@ async fn init_app_state(
llm_provider_url,
span_attributes,
http_client: reqwest::Client::new(),
filter_pipeline,
})
}

View file

@ -7,7 +7,7 @@ use std::io::Read;
use std::sync::Arc;
use tracing::{debug, info, warn};
use crate::handlers::utils::StreamProcessor;
use crate::handlers::streaming::StreamProcessor;
use crate::state::{OpenAIConversationState, StateStorage};
/// Processor that wraps another processor and handles v1/responses state management