mirror of
https://github.com/katanemo/plano.git
synced 2026-05-21 13:55:15 +02:00
Add support for output filter chain on listener
- add demo for pii redaction
This commit is contained in:
parent
8ae4901735
commit
73f15feea3
16 changed files with 1218 additions and 4 deletions
|
|
@ -196,6 +196,7 @@ mod tests {
|
|||
name: name.to_string(),
|
||||
agents: Some(agents),
|
||||
filter_chain: None,
|
||||
output_filter_chain: None,
|
||||
port: 8080,
|
||||
router: None,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -76,6 +76,7 @@ mod tests {
|
|||
name: "test-listener".to_string(),
|
||||
agents: Some(vec![agent_pipeline.clone()]),
|
||||
filter_chain: None,
|
||||
output_filter_chain: None,
|
||||
port: 8080,
|
||||
router: None,
|
||||
};
|
||||
|
|
|
|||
|
|
@ -23,7 +23,8 @@ use super::pipeline_processor::PipelineProcessor;
|
|||
|
||||
use crate::handlers::router_chat::router_chat_get_upstream_model;
|
||||
use crate::handlers::utils::{
|
||||
create_streaming_response, truncate_message, ObservableStreamProcessor,
|
||||
create_streaming_response, create_streaming_response_with_output_filter, truncate_message,
|
||||
ObservableStreamProcessor,
|
||||
};
|
||||
use crate::router::llm_router::RouterService;
|
||||
use crate::state::response_state_processor::ResponsesStateProcessor;
|
||||
|
|
@ -47,6 +48,8 @@ pub async fn llm_chat(
|
|||
state_storage: Option<Arc<dyn StateStorage>>,
|
||||
filter_chain: Arc<Option<Vec<String>>>,
|
||||
filter_agents: Arc<HashMap<String, Agent>>,
|
||||
output_filter_chain: Arc<Option<Vec<String>>>,
|
||||
output_filter_agents: Arc<HashMap<String, Agent>>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
let request_path = request.uri().path().to_string();
|
||||
let request_headers = request.headers().clone();
|
||||
|
|
@ -88,6 +91,8 @@ pub async fn llm_chat(
|
|||
request_headers,
|
||||
filter_chain,
|
||||
filter_agents,
|
||||
output_filter_chain,
|
||||
output_filter_agents,
|
||||
)
|
||||
.instrument(request_span)
|
||||
.await
|
||||
|
|
@ -107,6 +112,8 @@ async fn llm_chat_inner(
|
|||
mut request_headers: hyper::HeaderMap,
|
||||
filter_chain: Arc<Option<Vec<String>>>,
|
||||
filter_agents: Arc<HashMap<String, Agent>>,
|
||||
output_filter_chain: Arc<Option<Vec<String>>>,
|
||||
output_filter_agents: Arc<HashMap<String, Agent>>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
// Set service name for LLM operations
|
||||
set_service_name(operation_component::LLM);
|
||||
|
|
@ -501,6 +508,20 @@ async fn llm_chat_inner(
|
|||
propagator.inject_context(&cx, &mut HeaderInjector(&mut request_headers));
|
||||
});
|
||||
|
||||
// Determine if output filter chain is configured
|
||||
let has_output_filter = output_filter_chain
|
||||
.as_ref()
|
||||
.as_ref()
|
||||
.map(|fc| !fc.is_empty())
|
||||
.unwrap_or(false);
|
||||
|
||||
// Save request headers for output filter chain (before they're consumed by upstream request)
|
||||
let output_filter_request_headers = if has_output_filter {
|
||||
Some(request_headers.clone())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Capture start time right before sending request to upstream
|
||||
let request_start_time = std::time::Instant::now();
|
||||
let _request_start_system_time = std::time::SystemTime::now();
|
||||
|
|
@ -567,7 +588,31 @@ async fn llm_chat_inner(
|
|||
content_encoding,
|
||||
request_id,
|
||||
);
|
||||
create_streaming_response(byte_stream, state_processor, 16)
|
||||
if has_output_filter {
|
||||
let ofc = output_filter_chain.as_ref().as_ref().unwrap().clone();
|
||||
let ofa = (*output_filter_agents).clone();
|
||||
create_streaming_response_with_output_filter(
|
||||
byte_stream,
|
||||
state_processor,
|
||||
16,
|
||||
ofc,
|
||||
ofa,
|
||||
output_filter_request_headers.unwrap(),
|
||||
)
|
||||
} else {
|
||||
create_streaming_response(byte_stream, state_processor, 16)
|
||||
}
|
||||
} else if has_output_filter {
|
||||
let ofc = output_filter_chain.as_ref().as_ref().unwrap().clone();
|
||||
let ofa = (*output_filter_agents).clone();
|
||||
create_streaming_response_with_output_filter(
|
||||
byte_stream,
|
||||
base_processor,
|
||||
16,
|
||||
ofc,
|
||||
ofa,
|
||||
output_filter_request_headers.unwrap(),
|
||||
)
|
||||
} else {
|
||||
// Use base processor without state management
|
||||
create_streaming_response(byte_stream, base_processor, 16)
|
||||
|
|
|
|||
|
|
@ -1,19 +1,23 @@
|
|||
use bytes::Bytes;
|
||||
use common::configuration::{Agent, AgentFilterChain};
|
||||
use http_body_util::combinators::BoxBody;
|
||||
use http_body_util::StreamBody;
|
||||
use hyper::body::Frame;
|
||||
use hyper::header::HeaderMap;
|
||||
use opentelemetry::trace::TraceContextExt;
|
||||
use opentelemetry::KeyValue;
|
||||
use std::collections::HashMap;
|
||||
use std::time::Instant;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio_stream::wrappers::ReceiverStream;
|
||||
use tokio_stream::StreamExt;
|
||||
use tracing::{info, warn, Instrument};
|
||||
use tracing::{debug, info, warn, Instrument};
|
||||
use tracing_opentelemetry::OpenTelemetrySpanExt;
|
||||
|
||||
use super::pipeline_processor::PipelineProcessor;
|
||||
use crate::signals::{InteractionQuality, SignalAnalyzer, TextBasedSignalAnalyzer, FLAG_MARKER};
|
||||
use crate::tracing::{llm, set_service_name, signals as signal_constants};
|
||||
use hermesllm::apis::openai::Message;
|
||||
use hermesllm::apis::openai::{Message, MessageContent, Role};
|
||||
|
||||
/// Trait for processing streaming chunks
|
||||
/// Implementors can inject custom logic during streaming (e.g., hallucination detection, logging)
|
||||
|
|
@ -277,6 +281,286 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
/// Extract content text from an SSE chunk line (the JSON part after "data: ").
|
||||
/// Returns the content string and whether it was found.
|
||||
fn extract_sse_content(json_str: &str) -> Option<String> {
|
||||
let value: serde_json::Value = serde_json::from_str(json_str).ok()?;
|
||||
value
|
||||
.get("choices")?
|
||||
.get(0)?
|
||||
.get("delta")?
|
||||
.get("content")?
|
||||
.as_str()
|
||||
.map(|s| s.to_string())
|
||||
}
|
||||
|
||||
/// Replace content in an SSE JSON chunk with new content.
|
||||
fn replace_sse_content(json_str: &str, new_content: &str) -> Option<String> {
|
||||
let mut value: serde_json::Value = serde_json::from_str(json_str).ok()?;
|
||||
value
|
||||
.get_mut("choices")?
|
||||
.get_mut(0)?
|
||||
.get_mut("delta")?
|
||||
.as_object_mut()?
|
||||
.insert(
|
||||
"content".to_string(),
|
||||
serde_json::Value::String(new_content.to_string()),
|
||||
);
|
||||
serde_json::to_string(&value).ok()
|
||||
}
|
||||
|
||||
/// Process an SSE chunk through the output filter chain.
|
||||
/// Parses each `data: ` line, extracts content, sends through filters, and reconstructs.
|
||||
async fn filter_sse_chunk(
|
||||
chunk_str: &str,
|
||||
pipeline_processor: &mut PipelineProcessor,
|
||||
filter_chain: &AgentFilterChain,
|
||||
filter_agents: &HashMap<String, Agent>,
|
||||
request_headers: &HeaderMap,
|
||||
) -> String {
|
||||
let mut result = String::new();
|
||||
for line in chunk_str.split('\n') {
|
||||
if let Some(json_str) = line.strip_prefix("data: ") {
|
||||
if json_str.trim() == "[DONE]" {
|
||||
result.push_str(line);
|
||||
result.push('\n');
|
||||
continue;
|
||||
}
|
||||
if let Some(content) = extract_sse_content(json_str) {
|
||||
if content.is_empty() {
|
||||
result.push_str(line);
|
||||
result.push('\n');
|
||||
continue;
|
||||
}
|
||||
// Send content through output filter chain
|
||||
let messages = vec![Message {
|
||||
role: Role::Assistant,
|
||||
content: Some(MessageContent::Text(content)),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
}];
|
||||
match pipeline_processor
|
||||
.process_filter_chain(&messages, filter_chain, filter_agents, request_headers)
|
||||
.await
|
||||
{
|
||||
Ok(filtered_messages) => {
|
||||
if let Some(msg) = filtered_messages.first() {
|
||||
let filtered_content = match &msg.content {
|
||||
Some(MessageContent::Text(t)) => Some(t.clone()),
|
||||
_ => None,
|
||||
};
|
||||
if let Some(filtered_content) = filtered_content {
|
||||
if let Some(new_json) =
|
||||
replace_sse_content(json_str, &filtered_content)
|
||||
{
|
||||
result.push_str("data: ");
|
||||
result.push_str(&new_json);
|
||||
result.push('\n');
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
// Fallback: pass through original
|
||||
result.push_str(line);
|
||||
result.push('\n');
|
||||
}
|
||||
Err(e) => {
|
||||
warn!(error = %e, "output filter chain error, passing through original chunk");
|
||||
result.push_str(line);
|
||||
result.push('\n');
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// No content in this SSE line, pass through
|
||||
result.push_str(line);
|
||||
result.push('\n');
|
||||
}
|
||||
} else {
|
||||
result.push_str(line);
|
||||
result.push('\n');
|
||||
}
|
||||
}
|
||||
// Remove trailing extra newline if the original didn't end with one
|
||||
if !chunk_str.ends_with('\n') && result.ends_with('\n') {
|
||||
result.pop();
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
/// Process a non-streaming JSON response through the output filter chain.
|
||||
/// Extracts assistant message content, filters it, and reconstructs the response.
|
||||
pub async fn filter_non_streaming_response(
|
||||
response_bytes: &[u8],
|
||||
pipeline_processor: &mut PipelineProcessor,
|
||||
filter_chain: &AgentFilterChain,
|
||||
filter_agents: &HashMap<String, Agent>,
|
||||
request_headers: &HeaderMap,
|
||||
) -> Bytes {
|
||||
let response_str = match std::str::from_utf8(response_bytes) {
|
||||
Ok(s) => s,
|
||||
Err(_) => return Bytes::from(response_bytes.to_vec()),
|
||||
};
|
||||
|
||||
let mut value: serde_json::Value = match serde_json::from_str(response_str) {
|
||||
Ok(v) => v,
|
||||
Err(_) => return Bytes::from(response_bytes.to_vec()),
|
||||
};
|
||||
|
||||
// Extract content from choices[0].message.content
|
||||
let content = value
|
||||
.get("choices")
|
||||
.and_then(|c| c.get(0))
|
||||
.and_then(|c| c.get("message"))
|
||||
.and_then(|m| m.get("content"))
|
||||
.and_then(|c| c.as_str())
|
||||
.map(|s| s.to_string());
|
||||
|
||||
if let Some(content) = content {
|
||||
let messages = vec![Message {
|
||||
role: Role::Assistant,
|
||||
content: Some(MessageContent::Text(content)),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
}];
|
||||
match pipeline_processor
|
||||
.process_filter_chain(&messages, filter_chain, filter_agents, request_headers)
|
||||
.await
|
||||
{
|
||||
Ok(filtered_messages) => {
|
||||
if let Some(msg) = filtered_messages.first() {
|
||||
let filtered_content = match &msg.content {
|
||||
Some(MessageContent::Text(t)) => Some(t.clone()),
|
||||
_ => None,
|
||||
};
|
||||
if let Some(filtered_content) = filtered_content {
|
||||
if let Some(choices) = value.get_mut("choices") {
|
||||
if let Some(choice) = choices.get_mut(0) {
|
||||
if let Some(message) = choice.get_mut("message") {
|
||||
message.as_object_mut().unwrap().insert(
|
||||
"content".to_string(),
|
||||
serde_json::Value::String(filtered_content),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!(error = %e, "output filter chain error on non-streaming response");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Bytes::from(serde_json::to_string(&value).unwrap_or_else(|_| response_str.to_string()))
|
||||
}
|
||||
|
||||
/// Creates a streaming response that processes each chunk through an output filter chain.
|
||||
/// The output filter is called asynchronously for each SSE chunk's content.
|
||||
pub fn create_streaming_response_with_output_filter<S, P>(
|
||||
mut byte_stream: S,
|
||||
mut inner_processor: P,
|
||||
buffer_size: usize,
|
||||
output_filter_chain: Vec<String>,
|
||||
output_filter_agents: HashMap<String, Agent>,
|
||||
request_headers: HeaderMap,
|
||||
) -> StreamingResponse
|
||||
where
|
||||
S: StreamExt<Item = Result<Bytes, reqwest::Error>> + Send + Unpin + 'static,
|
||||
P: StreamProcessor,
|
||||
{
|
||||
let (tx, rx) = mpsc::channel::<Bytes>(buffer_size);
|
||||
let current_span = tracing::Span::current();
|
||||
|
||||
let processor_handle = tokio::spawn(
|
||||
async move {
|
||||
let mut is_first_chunk = true;
|
||||
let mut pipeline_processor = PipelineProcessor::default();
|
||||
let temp_filter_chain = AgentFilterChain {
|
||||
id: "output_filter".to_string(),
|
||||
default: None,
|
||||
description: None,
|
||||
filter_chain: Some(output_filter_chain),
|
||||
};
|
||||
|
||||
while let Some(item) = byte_stream.next().await {
|
||||
let chunk = match item {
|
||||
Ok(chunk) => chunk,
|
||||
Err(err) => {
|
||||
let err_msg = format!("Error receiving chunk: {:?}", err);
|
||||
warn!(error = %err_msg, "stream error");
|
||||
inner_processor.on_error(&err_msg);
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
if is_first_chunk {
|
||||
inner_processor.on_first_bytes();
|
||||
is_first_chunk = false;
|
||||
}
|
||||
|
||||
// Try to process through output filter chain
|
||||
let processed_chunk = if let Ok(chunk_str) = std::str::from_utf8(&chunk) {
|
||||
if chunk_str.contains("data: ") {
|
||||
let filtered = filter_sse_chunk(
|
||||
chunk_str,
|
||||
&mut pipeline_processor,
|
||||
&temp_filter_chain,
|
||||
&output_filter_agents,
|
||||
&request_headers,
|
||||
)
|
||||
.await;
|
||||
Bytes::from(filtered)
|
||||
} else {
|
||||
// Non-SSE chunk (could be non-streaming JSON response)
|
||||
let filtered = filter_non_streaming_response(
|
||||
&chunk,
|
||||
&mut pipeline_processor,
|
||||
&temp_filter_chain,
|
||||
&output_filter_agents,
|
||||
&request_headers,
|
||||
)
|
||||
.await;
|
||||
filtered
|
||||
}
|
||||
} else {
|
||||
chunk
|
||||
};
|
||||
|
||||
// Pass through inner processor for metrics/observability
|
||||
match inner_processor.process_chunk(processed_chunk) {
|
||||
Ok(Some(final_chunk)) => {
|
||||
if tx.send(final_chunk).await.is_err() {
|
||||
warn!("receiver dropped");
|
||||
break;
|
||||
}
|
||||
}
|
||||
Ok(None) => continue,
|
||||
Err(err) => {
|
||||
warn!("processor error: {}", err);
|
||||
inner_processor.on_error(&err);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
inner_processor.on_complete();
|
||||
debug!("output filter streaming completed");
|
||||
}
|
||||
.instrument(current_span),
|
||||
);
|
||||
|
||||
let stream = ReceiverStream::new(rx).map(|chunk| Ok::<_, hyper::Error>(Frame::data(chunk)));
|
||||
let stream_body = BoxBody::new(StreamBody::new(stream));
|
||||
|
||||
StreamingResponse {
|
||||
body: stream_body,
|
||||
processor_handle,
|
||||
}
|
||||
}
|
||||
|
||||
/// Truncates a message to the specified maximum length, adding "..." if truncated.
|
||||
pub fn truncate_message(message: &str, max_length: usize) -> String {
|
||||
if message.chars().count() > max_length {
|
||||
|
|
|
|||
|
|
@ -121,6 +121,19 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
})
|
||||
.unwrap_or_default(),
|
||||
);
|
||||
let model_output_filter_chain: Arc<Option<Vec<String>>> =
|
||||
Arc::new(model_listener.and_then(|l| l.output_filter_chain.clone()));
|
||||
let model_output_filter_agents: Arc<HashMap<String, Agent>> = Arc::new(
|
||||
model_output_filter_chain
|
||||
.as_ref()
|
||||
.as_ref()
|
||||
.map(|fc| {
|
||||
fc.iter()
|
||||
.filter_map(|id| global_agent_map.get(id).map(|a| (id.clone(), a.clone())))
|
||||
.collect()
|
||||
})
|
||||
.unwrap_or_default(),
|
||||
);
|
||||
let listeners = Arc::new(RwLock::new(plano_config.listeners.clone()));
|
||||
let llm_provider_url =
|
||||
env::var("LLM_PROVIDER_ENDPOINT").unwrap_or_else(|_| "http://localhost:12001".to_string());
|
||||
|
|
@ -217,6 +230,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
let agents_list = combined_agents_filters_list.clone();
|
||||
let model_filter_chain = model_filter_chain.clone();
|
||||
let model_filter_agents = model_filter_agents.clone();
|
||||
let model_output_filter_chain = model_output_filter_chain.clone();
|
||||
let model_output_filter_agents = model_output_filter_agents.clone();
|
||||
let listeners = listeners.clone();
|
||||
let span_attributes = span_attributes.clone();
|
||||
let state_storage = state_storage.clone();
|
||||
|
|
@ -230,6 +245,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
let agents_list = agents_list.clone();
|
||||
let model_filter_chain = model_filter_chain.clone();
|
||||
let model_filter_agents = model_filter_agents.clone();
|
||||
let model_output_filter_chain = model_output_filter_chain.clone();
|
||||
let model_output_filter_agents = model_output_filter_agents.clone();
|
||||
let listeners = listeners.clone();
|
||||
let span_attributes = span_attributes.clone();
|
||||
let state_storage = state_storage.clone();
|
||||
|
|
@ -290,6 +307,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
state_storage,
|
||||
model_filter_chain,
|
||||
model_filter_agents,
|
||||
model_output_filter_chain,
|
||||
model_output_filter_agents,
|
||||
)
|
||||
.with_context(parent_cx)
|
||||
.await
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue