mirror of
https://github.com/katanemo/plano.git
synced 2026-05-21 13:55:15 +02:00
pass raw bytes through input/output filter chains
This commit is contained in:
parent
80dfb41cad
commit
d26abbfb9c
11 changed files with 610 additions and 448 deletions
|
|
@ -268,12 +268,13 @@ async fn llm_chat_inner(
|
|||
}
|
||||
|
||||
// === Input filters processing for model listener ===
|
||||
// Filters receive the raw request bytes and return (possibly modified) raw bytes.
|
||||
// The returned bytes are re-parsed into a ProviderRequestType to continue the request.
|
||||
{
|
||||
if let Some(ref fc) = *input_filters {
|
||||
if !fc.is_empty() {
|
||||
debug!(input_filters = ?fc, "processing model listener input filters");
|
||||
|
||||
// Create a temporary AgentFilterChain to reuse PipelineProcessor
|
||||
let temp_filter_chain = AgentFilterChain {
|
||||
id: "model_listener".to_string(),
|
||||
default: None,
|
||||
|
|
@ -282,23 +283,34 @@ async fn llm_chat_inner(
|
|||
};
|
||||
|
||||
let mut pipeline_processor = PipelineProcessor::default();
|
||||
let messages = client_request.get_messages();
|
||||
match pipeline_processor
|
||||
.process_filter_chain(
|
||||
&messages,
|
||||
.process_raw_filter_chain(
|
||||
&chat_request_bytes,
|
||||
&temp_filter_chain,
|
||||
&input_filter_agents,
|
||||
&request_headers,
|
||||
&request_path,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(filtered_messages) => {
|
||||
client_request.set_messages(&filtered_messages);
|
||||
info!(
|
||||
original_count = messages.len(),
|
||||
filtered_count = filtered_messages.len(),
|
||||
"filter chain processed successfully"
|
||||
);
|
||||
Ok(filtered_bytes) => {
|
||||
match ProviderRequestType::try_from((
|
||||
&filtered_bytes[..],
|
||||
&SupportedAPIsFromClient::from_endpoint(request_path.as_str()).unwrap(),
|
||||
)) {
|
||||
Ok(updated_request) => {
|
||||
client_request = updated_request;
|
||||
info!("input filter chain processed successfully");
|
||||
}
|
||||
Err(parse_err) => {
|
||||
warn!(error = %parse_err, "input filter returned invalid request JSON");
|
||||
return Ok(BrightStaffError::InvalidRequest(format!(
|
||||
"Input filter returned invalid request: {}",
|
||||
parse_err
|
||||
))
|
||||
.into_response());
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(super::pipeline_processor::PipelineError::ClientError {
|
||||
agent,
|
||||
|
|
@ -508,21 +520,25 @@ async fn llm_chat_inner(
|
|||
propagator.inject_context(&cx, &mut HeaderInjector(&mut request_headers));
|
||||
});
|
||||
|
||||
// Output filters are only supported for /v1/chat/completions — the SSE content
|
||||
// extraction logic is specific to that API shape (choices[].delta.content).
|
||||
let output_filters_configured = output_filters
|
||||
.as_ref()
|
||||
.as_ref()
|
||||
.map(|fc| !fc.is_empty())
|
||||
.unwrap_or(false);
|
||||
let has_output_filter = output_filters_configured
|
||||
&& request_path == common::consts::CHAT_COMPLETIONS_PATH;
|
||||
if output_filters_configured && !has_output_filter {
|
||||
warn!(
|
||||
path = %request_path,
|
||||
"output filters are configured but only supported for /v1/chat/completions, skipping"
|
||||
);
|
||||
}
|
||||
let has_output_filter = output_filters_configured;
|
||||
|
||||
// Extract the upstream API path (e.g. "/v1/messages" from "https://api.anthropic.com/v1/messages").
|
||||
// Output filters are called at <agent.url><upstream_api_path> so they know the exact byte format.
|
||||
let upstream_api_path = {
|
||||
let after_scheme = full_qualified_llm_provider_url
|
||||
.find("://")
|
||||
.map(|i| &full_qualified_llm_provider_url[i + 3..])
|
||||
.unwrap_or(&full_qualified_llm_provider_url);
|
||||
after_scheme
|
||||
.find('/')
|
||||
.map(|i| after_scheme[i..].to_string())
|
||||
.unwrap_or_else(|| "/".to_string())
|
||||
};
|
||||
|
||||
// Save request headers for output filters (before they're consumed by upstream request)
|
||||
let output_filter_request_headers = if has_output_filter {
|
||||
|
|
@ -607,6 +623,7 @@ async fn llm_chat_inner(
|
|||
ofc,
|
||||
ofa,
|
||||
output_filter_request_headers.unwrap(),
|
||||
upstream_api_path.clone(),
|
||||
)
|
||||
} else {
|
||||
create_streaming_response(byte_stream, state_processor, 16)
|
||||
|
|
@ -621,6 +638,7 @@ async fn llm_chat_inner(
|
|||
ofc,
|
||||
ofa,
|
||||
output_filter_request_headers.unwrap(),
|
||||
upstream_api_path,
|
||||
)
|
||||
} else {
|
||||
// Use base processor without state management
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
use std::collections::HashMap;
|
||||
|
||||
use bytes::Bytes;
|
||||
use common::configuration::{Agent, AgentFilterChain};
|
||||
use common::consts::{
|
||||
ARCH_UPSTREAM_HOST_HEADER, BRIGHT_STAFF_SERVICE_NAME, ENVOY_RETRY_HEADER, TRACE_PARENT_HEADER,
|
||||
|
|
@ -605,6 +606,139 @@ impl PipelineProcessor {
|
|||
Ok(messages)
|
||||
}
|
||||
|
||||
/// Execute a raw bytes filter — POST bytes to agent.url, receive bytes back.
|
||||
/// Used for input and output filters where the full raw request/response is passed through.
|
||||
/// No MCP protocol wrapping; agent_type is ignored.
|
||||
#[instrument(
|
||||
skip(self, raw_bytes, agent, request_headers),
|
||||
fields(
|
||||
agent_id = %agent.id,
|
||||
agent_url = %agent.url,
|
||||
filter_name = %agent.id,
|
||||
bytes_len = raw_bytes.len()
|
||||
)
|
||||
)]
|
||||
async fn execute_raw_filter(
|
||||
&mut self,
|
||||
raw_bytes: &[u8],
|
||||
agent: &Agent,
|
||||
request_headers: &HeaderMap,
|
||||
request_path: &str,
|
||||
) -> Result<Bytes, PipelineError> {
|
||||
set_service_name(operation_component::AGENT_FILTER);
|
||||
use opentelemetry::trace::get_active_span;
|
||||
get_active_span(|span| {
|
||||
span.update_name(format!("execute_raw_filter ({})", agent.id));
|
||||
});
|
||||
|
||||
let mut agent_headers = request_headers.clone();
|
||||
agent_headers.remove(hyper::header::CONTENT_LENGTH);
|
||||
|
||||
agent_headers.remove(TRACE_PARENT_HEADER);
|
||||
global::get_text_map_propagator(|propagator| {
|
||||
let cx =
|
||||
tracing_opentelemetry::OpenTelemetrySpanExt::context(&tracing::Span::current());
|
||||
propagator.inject_context(&cx, &mut HeaderInjector(&mut agent_headers));
|
||||
});
|
||||
|
||||
agent_headers.insert(
|
||||
ARCH_UPSTREAM_HOST_HEADER,
|
||||
hyper::header::HeaderValue::from_str(&agent.id)
|
||||
.map_err(|_| PipelineError::AgentNotFound(agent.id.clone()))?,
|
||||
);
|
||||
agent_headers.insert(
|
||||
ENVOY_RETRY_HEADER,
|
||||
hyper::header::HeaderValue::from_str("3").unwrap(),
|
||||
);
|
||||
agent_headers.insert(
|
||||
"Accept",
|
||||
hyper::header::HeaderValue::from_static("application/json"),
|
||||
);
|
||||
agent_headers.insert(
|
||||
"Content-Type",
|
||||
hyper::header::HeaderValue::from_static("application/json"),
|
||||
);
|
||||
|
||||
// Append the original request path so the filter endpoint encodes the API format.
|
||||
// e.g. agent.url="http://host/anonymize" + request_path="/v1/chat/completions"
|
||||
// -> POST http://host/anonymize/v1/chat/completions
|
||||
let url = format!("{}{}", agent.url, request_path);
|
||||
debug!(agent = %agent.id, url = %url, "sending raw filter request");
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post(&url)
|
||||
.headers(agent_headers)
|
||||
.body(raw_bytes.to_vec())
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let http_status = response.status();
|
||||
let response_bytes = response.bytes().await?;
|
||||
|
||||
if !http_status.is_success() {
|
||||
let error_body = String::from_utf8_lossy(&response_bytes).to_string();
|
||||
return Err(if http_status.is_client_error() {
|
||||
PipelineError::ClientError {
|
||||
agent: agent.id.clone(),
|
||||
status: http_status.as_u16(),
|
||||
body: error_body,
|
||||
}
|
||||
} else {
|
||||
PipelineError::ServerError {
|
||||
agent: agent.id.clone(),
|
||||
status: http_status.as_u16(),
|
||||
body: error_body,
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
debug!(agent = %agent.id, bytes_len = response_bytes.len(), "raw filter response received");
|
||||
Ok(response_bytes)
|
||||
}
|
||||
|
||||
/// Process a chain of raw-bytes filters sequentially.
|
||||
/// Input: raw request or response bytes. Output: filtered bytes.
|
||||
/// Each agent receives the output of the previous one.
|
||||
pub async fn process_raw_filter_chain(
|
||||
&mut self,
|
||||
raw_bytes: &[u8],
|
||||
agent_filter_chain: &AgentFilterChain,
|
||||
agent_map: &HashMap<String, Agent>,
|
||||
request_headers: &HeaderMap,
|
||||
request_path: &str,
|
||||
) -> Result<Bytes, PipelineError> {
|
||||
let filter_chain = match agent_filter_chain.filter_chain.as_ref() {
|
||||
Some(fc) if !fc.is_empty() => fc,
|
||||
_ => return Ok(Bytes::copy_from_slice(raw_bytes)),
|
||||
};
|
||||
|
||||
let mut current_bytes = Bytes::copy_from_slice(raw_bytes);
|
||||
|
||||
for agent_name in filter_chain {
|
||||
debug!(agent = %agent_name, "processing raw filter agent");
|
||||
|
||||
let agent = agent_map
|
||||
.get(agent_name)
|
||||
.ok_or_else(|| PipelineError::AgentNotFound(agent_name.clone()))?;
|
||||
|
||||
info!(
|
||||
agent = %agent_name,
|
||||
url = %agent.url,
|
||||
bytes_len = current_bytes.len(),
|
||||
"executing raw filter"
|
||||
);
|
||||
|
||||
current_bytes = self
|
||||
.execute_raw_filter(¤t_bytes, agent, request_headers, request_path)
|
||||
.await?;
|
||||
|
||||
info!(agent = %agent_name, bytes_len = current_bytes.len(), "raw filter completed");
|
||||
}
|
||||
|
||||
Ok(current_bytes)
|
||||
}
|
||||
|
||||
/// Send request to terminal agent and return the raw response for streaming
|
||||
/// Note: The caller is responsible for creating the plano(agent) span that wraps
|
||||
/// both this call and the subsequent response consumption.
|
||||
|
|
|
|||
|
|
@ -14,10 +14,10 @@ use tokio_stream::StreamExt;
|
|||
use tracing::{debug, info, warn, Instrument};
|
||||
use tracing_opentelemetry::OpenTelemetrySpanExt;
|
||||
|
||||
use super::pipeline_processor::PipelineProcessor;
|
||||
use super::pipeline_processor::{PipelineError, PipelineProcessor};
|
||||
use crate::signals::{InteractionQuality, SignalAnalyzer, TextBasedSignalAnalyzer, FLAG_MARKER};
|
||||
use crate::tracing::{llm, set_service_name, signals as signal_constants};
|
||||
use hermesllm::apis::openai::{Message, MessageContent, Role};
|
||||
use hermesllm::apis::openai::Message;
|
||||
|
||||
/// Trait for processing streaming chunks
|
||||
/// Implementors can inject custom logic during streaming (e.g., hallucination detection, logging)
|
||||
|
|
@ -281,184 +281,9 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
/// Extract content text from an SSE chunk line (the JSON part after "data: ").
|
||||
/// Returns the content string and whether it was found.
|
||||
fn extract_sse_content(json_str: &str) -> Option<String> {
|
||||
let value: serde_json::Value = serde_json::from_str(json_str).ok()?;
|
||||
value
|
||||
.get("choices")?
|
||||
.get(0)?
|
||||
.get("delta")?
|
||||
.get("content")?
|
||||
.as_str()
|
||||
.map(|s| s.to_string())
|
||||
}
|
||||
|
||||
/// Replace content in an SSE JSON chunk with new content.
|
||||
fn replace_sse_content(json_str: &str, new_content: &str) -> Option<String> {
|
||||
let mut value: serde_json::Value = serde_json::from_str(json_str).ok()?;
|
||||
value
|
||||
.get_mut("choices")?
|
||||
.get_mut(0)?
|
||||
.get_mut("delta")?
|
||||
.as_object_mut()?
|
||||
.insert(
|
||||
"content".to_string(),
|
||||
serde_json::Value::String(new_content.to_string()),
|
||||
);
|
||||
serde_json::to_string(&value).ok()
|
||||
}
|
||||
|
||||
/// Process an SSE chunk through the output filter chain.
|
||||
/// Parses each `data: ` line, extracts content, sends through filters, and reconstructs.
|
||||
async fn filter_sse_chunk(
|
||||
chunk_str: &str,
|
||||
pipeline_processor: &mut PipelineProcessor,
|
||||
filter_chain: &AgentFilterChain,
|
||||
filter_agents: &HashMap<String, Agent>,
|
||||
request_headers: &HeaderMap,
|
||||
) -> String {
|
||||
let mut result = String::new();
|
||||
for line in chunk_str.split('\n') {
|
||||
if let Some(json_str) = line.strip_prefix("data: ") {
|
||||
if json_str.trim() == "[DONE]" {
|
||||
result.push_str(line);
|
||||
result.push('\n');
|
||||
continue;
|
||||
}
|
||||
if let Some(content) = extract_sse_content(json_str) {
|
||||
if content.is_empty() {
|
||||
result.push_str(line);
|
||||
result.push('\n');
|
||||
continue;
|
||||
}
|
||||
// Send content through output filter chain
|
||||
let messages = vec![Message {
|
||||
role: Role::Assistant,
|
||||
content: Some(MessageContent::Text(content)),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
}];
|
||||
match pipeline_processor
|
||||
.process_filter_chain(&messages, filter_chain, filter_agents, request_headers)
|
||||
.await
|
||||
{
|
||||
Ok(filtered_messages) => {
|
||||
if let Some(msg) = filtered_messages.first() {
|
||||
let filtered_content = match &msg.content {
|
||||
Some(MessageContent::Text(t)) => Some(t.clone()),
|
||||
_ => None,
|
||||
};
|
||||
if let Some(filtered_content) = filtered_content {
|
||||
if let Some(new_json) =
|
||||
replace_sse_content(json_str, &filtered_content)
|
||||
{
|
||||
result.push_str("data: ");
|
||||
result.push_str(&new_json);
|
||||
result.push('\n');
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
// Fallback: pass through original
|
||||
result.push_str(line);
|
||||
result.push('\n');
|
||||
}
|
||||
Err(e) => {
|
||||
warn!(error = %e, "output filter chain error, passing through original chunk");
|
||||
result.push_str(line);
|
||||
result.push('\n');
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// No content in this SSE line, pass through
|
||||
result.push_str(line);
|
||||
result.push('\n');
|
||||
}
|
||||
} else {
|
||||
result.push_str(line);
|
||||
result.push('\n');
|
||||
}
|
||||
}
|
||||
// Remove trailing extra newline if the original didn't end with one
|
||||
if !chunk_str.ends_with('\n') && result.ends_with('\n') {
|
||||
result.pop();
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
/// Process a non-streaming JSON response through the output filter chain.
|
||||
/// Extracts assistant message content, filters it, and reconstructs the response.
|
||||
pub async fn filter_non_streaming_response(
|
||||
response_bytes: &[u8],
|
||||
pipeline_processor: &mut PipelineProcessor,
|
||||
filter_chain: &AgentFilterChain,
|
||||
filter_agents: &HashMap<String, Agent>,
|
||||
request_headers: &HeaderMap,
|
||||
) -> Bytes {
|
||||
let response_str = match std::str::from_utf8(response_bytes) {
|
||||
Ok(s) => s,
|
||||
Err(_) => return Bytes::from(response_bytes.to_vec()),
|
||||
};
|
||||
|
||||
let mut value: serde_json::Value = match serde_json::from_str(response_str) {
|
||||
Ok(v) => v,
|
||||
Err(_) => return Bytes::from(response_bytes.to_vec()),
|
||||
};
|
||||
|
||||
// Extract content from choices[0].message.content
|
||||
let content = value
|
||||
.get("choices")
|
||||
.and_then(|c| c.get(0))
|
||||
.and_then(|c| c.get("message"))
|
||||
.and_then(|m| m.get("content"))
|
||||
.and_then(|c| c.as_str())
|
||||
.map(|s| s.to_string());
|
||||
|
||||
if let Some(content) = content {
|
||||
let messages = vec![Message {
|
||||
role: Role::Assistant,
|
||||
content: Some(MessageContent::Text(content)),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
}];
|
||||
match pipeline_processor
|
||||
.process_filter_chain(&messages, filter_chain, filter_agents, request_headers)
|
||||
.await
|
||||
{
|
||||
Ok(filtered_messages) => {
|
||||
if let Some(msg) = filtered_messages.first() {
|
||||
let filtered_content = match &msg.content {
|
||||
Some(MessageContent::Text(t)) => Some(t.clone()),
|
||||
_ => None,
|
||||
};
|
||||
if let Some(filtered_content) = filtered_content {
|
||||
if let Some(choices) = value.get_mut("choices") {
|
||||
if let Some(choice) = choices.get_mut(0) {
|
||||
if let Some(message) = choice.get_mut("message") {
|
||||
message.as_object_mut().unwrap().insert(
|
||||
"content".to_string(),
|
||||
serde_json::Value::String(filtered_content),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!(error = %e, "output filter chain error on non-streaming response");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Bytes::from(serde_json::to_string(&value).unwrap_or_else(|_| response_str.to_string()))
|
||||
}
|
||||
|
||||
/// Creates a streaming response that processes each chunk through output filters.
|
||||
/// The output filter is called asynchronously for each SSE chunk's content.
|
||||
/// Creates a streaming response that processes each raw chunk through output filters.
|
||||
/// Filters receive the raw LLM response bytes and return (possibly modified) bytes.
|
||||
/// On filter error mid-stream the original chunk is passed through (headers already sent).
|
||||
pub fn create_streaming_response_with_output_filter<S, P>(
|
||||
mut byte_stream: S,
|
||||
mut inner_processor: P,
|
||||
|
|
@ -466,6 +291,7 @@ pub fn create_streaming_response_with_output_filter<S, P>(
|
|||
output_filters: Vec<String>,
|
||||
output_filter_agents: HashMap<String, Agent>,
|
||||
request_headers: HeaderMap,
|
||||
upstream_path: String,
|
||||
) -> StreamingResponse
|
||||
where
|
||||
S: StreamExt<Item = Result<Bytes, reqwest::Error>> + Send + Unpin + 'static,
|
||||
|
|
@ -501,32 +327,35 @@ where
|
|||
is_first_chunk = false;
|
||||
}
|
||||
|
||||
// Try to process through output filter chain
|
||||
let processed_chunk = if let Ok(chunk_str) = std::str::from_utf8(&chunk) {
|
||||
if chunk_str.contains("data: ") {
|
||||
let filtered = filter_sse_chunk(
|
||||
chunk_str,
|
||||
&mut pipeline_processor,
|
||||
&temp_filter_chain,
|
||||
&output_filter_agents,
|
||||
&request_headers,
|
||||
)
|
||||
.await;
|
||||
Bytes::from(filtered)
|
||||
} else {
|
||||
// Non-SSE chunk (could be non-streaming JSON response)
|
||||
let filtered = filter_non_streaming_response(
|
||||
&chunk,
|
||||
&mut pipeline_processor,
|
||||
&temp_filter_chain,
|
||||
&output_filter_agents,
|
||||
&request_headers,
|
||||
)
|
||||
.await;
|
||||
filtered
|
||||
// Pass raw chunk bytes through the output filter chain
|
||||
let processed_chunk = match pipeline_processor
|
||||
.process_raw_filter_chain(
|
||||
&chunk,
|
||||
&temp_filter_chain,
|
||||
&output_filter_agents,
|
||||
&request_headers,
|
||||
&upstream_path,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(filtered) => filtered,
|
||||
Err(PipelineError::ClientError {
|
||||
agent,
|
||||
status,
|
||||
body,
|
||||
}) => {
|
||||
warn!(
|
||||
agent = %agent,
|
||||
status = %status,
|
||||
body = %body,
|
||||
"output filter client error, passing through original chunk"
|
||||
);
|
||||
chunk
|
||||
}
|
||||
Err(e) => {
|
||||
warn!(error = %e, "output filter error, passing through original chunk");
|
||||
chunk
|
||||
}
|
||||
} else {
|
||||
chunk
|
||||
};
|
||||
|
||||
// Pass through inner processor for metrics/observability
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue