mirror of
https://github.com/katanemo/plano.git
synced 2026-04-25 00:36:34 +02:00
bidirectional streaming for output filter chains
Replace per-chunk HTTP requests to output filters with a single bidirectional streaming connection per filter. This eliminates the 50-200+ round-trips per streaming LLM response. Filters opt in via streaming: true in config. When all output filters support streaming, brightstaff opens one POST per filter with a streaming request body (Body::wrap_stream) and reads the streaming response. Filters that don't opt in fall back to the existing per-chunk behavior. Updates the PII deanonymizer demo as the reference implementation with request.stream() + StreamingResponse support. Made-with: Cursor
This commit is contained in:
parent
1f23c573bf
commit
42d3de8906
10 changed files with 613 additions and 133 deletions
|
|
@ -43,6 +43,8 @@ properties:
|
|||
- streamable-http
|
||||
tool:
|
||||
type: string
|
||||
streaming:
|
||||
type: boolean
|
||||
additionalProperties: false
|
||||
required:
|
||||
- id
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ opentelemetry-stdout = "0.31"
|
|||
opentelemetry_sdk = { version = "0.31", features = ["rt-tokio"] }
|
||||
pretty_assertions = "1.4.1"
|
||||
rand = "0.9.2"
|
||||
reqwest = { version = "0.12.15", features = ["stream"] }
|
||||
reqwest = { version = "0.12.15", features = ["stream", "http2"] }
|
||||
serde = { version = "1.0.219", features = ["derive"] }
|
||||
serde_json = "1.0.140"
|
||||
serde_with = "3.13.0"
|
||||
|
|
|
|||
|
|
@ -210,6 +210,7 @@ mod tests {
|
|||
url: "http://localhost:8080".to_string(),
|
||||
tool: None,
|
||||
transport: None,
|
||||
streaming: None,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -52,6 +52,7 @@ mod tests {
|
|||
url: "http://localhost:8081".to_string(),
|
||||
tool: None,
|
||||
transport: None,
|
||||
streaming: None,
|
||||
},
|
||||
Agent {
|
||||
id: "terminal-agent".to_string(),
|
||||
|
|
@ -59,6 +60,7 @@ mod tests {
|
|||
url: "http://localhost:8082".to_string(),
|
||||
tool: None,
|
||||
transport: None,
|
||||
streaming: None,
|
||||
},
|
||||
];
|
||||
|
||||
|
|
|
|||
|
|
@ -10,6 +10,9 @@ use hermesllm::{ProviderRequest, ProviderRequestType};
|
|||
use hyper::header::HeaderMap;
|
||||
use opentelemetry::global;
|
||||
use opentelemetry_http::HeaderInjector;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio_stream::wrappers::ReceiverStream;
|
||||
use tokio_stream::StreamExt;
|
||||
use tracing::{debug, info, instrument, warn};
|
||||
|
||||
use crate::handlers::jsonrpc::{
|
||||
|
|
@ -50,6 +53,18 @@ pub enum PipelineError {
|
|||
},
|
||||
}
|
||||
|
||||
/// A live streaming filter pipeline. LLM chunks go into `input_tx`;
|
||||
/// processed chunks come out of `output_rx`. Each filter in the chain
|
||||
/// is connected via a single bidirectional streaming HTTP connection.
|
||||
#[derive(Debug)]
|
||||
pub struct StreamingFilterPipeline {
|
||||
pub input_tx: mpsc::Sender<Bytes>,
|
||||
pub output_rx: mpsc::Receiver<Bytes>,
|
||||
pub handles: Vec<tokio::task::JoinHandle<()>>,
|
||||
}
|
||||
|
||||
const STREAMING_PIPELINE_BUFFER: usize = 16;
|
||||
|
||||
/// Service for processing agent pipelines
|
||||
pub struct PipelineProcessor {
|
||||
client: reqwest::Client,
|
||||
|
|
@ -429,6 +444,130 @@ impl PipelineProcessor {
|
|||
session_id
|
||||
}
|
||||
|
||||
/// Build headers for an HTTP raw filter request (shared by per-chunk and streaming paths).
|
||||
fn build_raw_filter_headers(
|
||||
request_headers: &HeaderMap,
|
||||
agent_id: &str,
|
||||
) -> Result<HeaderMap, PipelineError> {
|
||||
let mut headers = request_headers.clone();
|
||||
headers.remove(hyper::header::CONTENT_LENGTH);
|
||||
|
||||
headers.remove(TRACE_PARENT_HEADER);
|
||||
global::get_text_map_propagator(|propagator| {
|
||||
let cx =
|
||||
tracing_opentelemetry::OpenTelemetrySpanExt::context(&tracing::Span::current());
|
||||
propagator.inject_context(&cx, &mut HeaderInjector(&mut headers));
|
||||
});
|
||||
|
||||
headers.insert(
|
||||
ARCH_UPSTREAM_HOST_HEADER,
|
||||
hyper::header::HeaderValue::from_str(agent_id)
|
||||
.map_err(|_| PipelineError::AgentNotFound(agent_id.to_string()))?,
|
||||
);
|
||||
headers.insert(
|
||||
ENVOY_RETRY_HEADER,
|
||||
hyper::header::HeaderValue::from_str("3").unwrap(),
|
||||
);
|
||||
headers.insert(
|
||||
"Accept",
|
||||
hyper::header::HeaderValue::from_static("application/json, text/event-stream"),
|
||||
);
|
||||
headers.insert(
|
||||
"Content-Type",
|
||||
hyper::header::HeaderValue::from_static("application/octet-stream"),
|
||||
);
|
||||
|
||||
Ok(headers)
|
||||
}
|
||||
|
||||
/// Set up a bidirectional streaming output filter pipeline.
|
||||
///
|
||||
/// Opens one streaming POST per filter (using chunked transfer encoding)
|
||||
/// and chains them: the response stream of filter N feeds the request body
|
||||
/// of filter N+1. Returns a pipeline where the caller pushes LLM chunks
|
||||
/// into `input_tx` and reads processed chunks from `output_rx`.
|
||||
pub async fn start_streaming_output_pipeline(
|
||||
agents: &[&Agent],
|
||||
request_headers: &HeaderMap,
|
||||
request_path: &str,
|
||||
) -> Result<StreamingFilterPipeline, PipelineError> {
|
||||
let client = reqwest::Client::builder()
|
||||
.build()
|
||||
.map_err(PipelineError::RequestFailed)?;
|
||||
|
||||
let (input_tx, first_rx) = mpsc::channel::<Bytes>(STREAMING_PIPELINE_BUFFER);
|
||||
let mut current_rx = first_rx;
|
||||
let mut handles = Vec::new();
|
||||
|
||||
for agent in agents {
|
||||
let url = format!("{}{}", agent.url, request_path);
|
||||
let headers = Self::build_raw_filter_headers(request_headers, &agent.id)?;
|
||||
|
||||
let body_stream = ReceiverStream::new(current_rx).map(Ok::<_, std::io::Error>);
|
||||
let body = reqwest::Body::wrap_stream(body_stream);
|
||||
|
||||
debug!(agent = %agent.id, url = %url, "opening streaming filter connection");
|
||||
|
||||
let response = client.post(&url).headers(headers).body(body).send().await?;
|
||||
|
||||
let http_status = response.status();
|
||||
if !http_status.is_success() {
|
||||
let error_body = response
|
||||
.text()
|
||||
.await
|
||||
.unwrap_or_else(|_| "<unreadable>".to_string());
|
||||
return Err(if http_status.is_client_error() {
|
||||
PipelineError::ClientError {
|
||||
agent: agent.id.clone(),
|
||||
status: http_status.as_u16(),
|
||||
body: error_body,
|
||||
}
|
||||
} else {
|
||||
PipelineError::ServerError {
|
||||
agent: agent.id.clone(),
|
||||
status: http_status.as_u16(),
|
||||
body: error_body,
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
let (next_tx, next_rx) = mpsc::channel::<Bytes>(STREAMING_PIPELINE_BUFFER);
|
||||
let agent_id = agent.id.clone();
|
||||
|
||||
let handle = tokio::spawn(async move {
|
||||
let mut resp_stream = response.bytes_stream();
|
||||
while let Some(item) = resp_stream.next().await {
|
||||
match item {
|
||||
Ok(chunk) => {
|
||||
if next_tx.send(chunk).await.is_err() {
|
||||
debug!(agent = %agent_id, "streaming pipeline receiver dropped");
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!(agent = %agent_id, error = %e, "streaming filter response error");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
debug!(agent = %agent_id, "streaming filter stage completed");
|
||||
});
|
||||
|
||||
handles.push(handle);
|
||||
current_rx = next_rx;
|
||||
}
|
||||
|
||||
info!(
|
||||
filter_count = agents.len(),
|
||||
"streaming output filter pipeline established"
|
||||
);
|
||||
Ok(StreamingFilterPipeline {
|
||||
input_tx,
|
||||
output_rx: current_rx,
|
||||
handles,
|
||||
})
|
||||
}
|
||||
|
||||
/// Execute a raw bytes filter — POST bytes to agent.url, receive bytes back.
|
||||
/// Used for input and output filters where the full raw request/response is passed through.
|
||||
/// No MCP protocol wrapping; agent_type is ignored.
|
||||
|
|
@ -454,25 +593,7 @@ impl PipelineProcessor {
|
|||
span.update_name(format!("execute_raw_filter ({})", agent.id));
|
||||
});
|
||||
|
||||
let mut agent_headers = request_headers.clone();
|
||||
agent_headers.remove(hyper::header::CONTENT_LENGTH);
|
||||
|
||||
agent_headers.remove(TRACE_PARENT_HEADER);
|
||||
global::get_text_map_propagator(|propagator| {
|
||||
let cx =
|
||||
tracing_opentelemetry::OpenTelemetrySpanExt::context(&tracing::Span::current());
|
||||
propagator.inject_context(&cx, &mut HeaderInjector(&mut agent_headers));
|
||||
});
|
||||
|
||||
agent_headers.insert(
|
||||
ARCH_UPSTREAM_HOST_HEADER,
|
||||
hyper::header::HeaderValue::from_str(&agent.id)
|
||||
.map_err(|_| PipelineError::AgentNotFound(agent.id.clone()))?,
|
||||
);
|
||||
agent_headers.insert(
|
||||
ENVOY_RETRY_HEADER,
|
||||
hyper::header::HeaderValue::from_str("3").unwrap(),
|
||||
);
|
||||
let mut agent_headers = Self::build_raw_filter_headers(request_headers, &agent.id)?;
|
||||
agent_headers.insert(
|
||||
"Accept",
|
||||
hyper::header::HeaderValue::from_static("application/json"),
|
||||
|
|
@ -482,9 +603,6 @@ impl PipelineProcessor {
|
|||
hyper::header::HeaderValue::from_static("application/json"),
|
||||
);
|
||||
|
||||
// Append the original request path so the filter endpoint encodes the API format.
|
||||
// e.g. agent.url="http://host/anonymize" + request_path="/v1/chat/completions"
|
||||
// -> POST http://host/anonymize/v1/chat/completions
|
||||
let url = format!("{}{}", agent.url, request_path);
|
||||
debug!(agent = %agent.id, url = %url, "sending raw filter request");
|
||||
|
||||
|
|
@ -682,6 +800,7 @@ mod tests {
|
|||
tool: None,
|
||||
url: server_url,
|
||||
agent_type: None,
|
||||
streaming: None,
|
||||
};
|
||||
|
||||
let body = serde_json::json!({"messages": [{"role": "user", "content": "Hello"}]});
|
||||
|
|
@ -722,6 +841,7 @@ mod tests {
|
|||
tool: None,
|
||||
url: server_url,
|
||||
agent_type: None,
|
||||
streaming: None,
|
||||
};
|
||||
|
||||
let body = serde_json::json!({"messages": [{"role": "user", "content": "Ping"}]});
|
||||
|
|
@ -775,6 +895,7 @@ mod tests {
|
|||
tool: None,
|
||||
url: server_url,
|
||||
agent_type: None,
|
||||
streaming: None,
|
||||
};
|
||||
|
||||
let body = serde_json::json!({"messages": [{"role": "user", "content": "Hi"}]});
|
||||
|
|
@ -793,4 +914,29 @@ mod tests {
|
|||
_ => panic!("Expected client error when isError flag is set"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_streaming_pipeline_connection_refused() {
|
||||
let agent = Agent {
|
||||
id: "unreachable".to_string(),
|
||||
transport: None,
|
||||
tool: None,
|
||||
url: "http://127.0.0.1:1".to_string(),
|
||||
agent_type: Some("http".to_string()),
|
||||
streaming: Some(true),
|
||||
};
|
||||
let headers = HeaderMap::new();
|
||||
let result = PipelineProcessor::start_streaming_output_pipeline(
|
||||
&[&agent],
|
||||
&headers,
|
||||
"/v1/chat/completions",
|
||||
)
|
||||
.await;
|
||||
|
||||
assert!(result.is_err());
|
||||
assert!(matches!(
|
||||
result.unwrap_err(),
|
||||
PipelineError::RequestFailed(_)
|
||||
));
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -294,11 +294,13 @@ where
|
|||
}
|
||||
|
||||
/// Creates a streaming response that processes each raw chunk through output filters.
|
||||
/// Filters receive the raw LLM response bytes and request path (any API shape; not limited to
|
||||
/// chat completions). On filter error mid-stream the original chunk is passed through (headers already sent).
|
||||
///
|
||||
/// If all filters in the chain have `streaming: true`, uses a single bidirectional
|
||||
/// HTTP/2 connection per filter (no per-chunk overhead). Otherwise falls back to
|
||||
/// per-chunk HTTP requests (the original behavior).
|
||||
pub fn create_streaming_response_with_output_filter<S, P>(
|
||||
mut byte_stream: S,
|
||||
mut inner_processor: P,
|
||||
byte_stream: S,
|
||||
inner_processor: P,
|
||||
output_chain: ResolvedFilterChain,
|
||||
request_headers: HeaderMap,
|
||||
request_path: String,
|
||||
|
|
@ -307,84 +309,33 @@ where
|
|||
S: StreamExt<Item = Result<Bytes, reqwest::Error>> + Send + Unpin + 'static,
|
||||
P: StreamProcessor,
|
||||
{
|
||||
let use_streaming = output_chain.all_support_streaming();
|
||||
let (tx, rx) = mpsc::channel::<Bytes>(STREAM_BUFFER_SIZE);
|
||||
let current_span = tracing::Span::current();
|
||||
|
||||
let processor_handle = tokio::spawn(
|
||||
async move {
|
||||
let mut is_first_chunk = true;
|
||||
let mut pipeline_processor = PipelineProcessor::default();
|
||||
let chain = output_chain.to_agent_filter_chain("output_filter");
|
||||
|
||||
while let Some(item) = byte_stream.next().await {
|
||||
let chunk = match item {
|
||||
Ok(chunk) => chunk,
|
||||
Err(err) => {
|
||||
let err_msg = format!("Error receiving chunk: {:?}", err);
|
||||
warn!(error = %err_msg, "stream error");
|
||||
inner_processor.on_error(&err_msg);
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
if is_first_chunk {
|
||||
inner_processor.on_first_bytes();
|
||||
is_first_chunk = false;
|
||||
}
|
||||
|
||||
// Pass raw chunk bytes through the output filter chain
|
||||
let processed_chunk = match pipeline_processor
|
||||
.process_raw_filter_chain(
|
||||
&chunk,
|
||||
&chain,
|
||||
&output_chain.agents,
|
||||
&request_headers,
|
||||
&request_path,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(filtered) => filtered,
|
||||
Err(PipelineError::ClientError {
|
||||
agent,
|
||||
status,
|
||||
body,
|
||||
}) => {
|
||||
warn!(
|
||||
agent = %agent,
|
||||
status = %status,
|
||||
body = %body,
|
||||
"output filter client error, passing through original chunk"
|
||||
);
|
||||
chunk
|
||||
}
|
||||
Err(e) => {
|
||||
warn!(error = %e, "output filter error, passing through original chunk");
|
||||
chunk
|
||||
}
|
||||
};
|
||||
|
||||
// Pass through inner processor for metrics/observability
|
||||
match inner_processor.process_chunk(processed_chunk) {
|
||||
Ok(Some(final_chunk)) => {
|
||||
if tx.send(final_chunk).await.is_err() {
|
||||
warn!("receiver dropped");
|
||||
break;
|
||||
}
|
||||
}
|
||||
Ok(None) => continue,
|
||||
Err(err) => {
|
||||
warn!("processor error: {}", err);
|
||||
inner_processor.on_error(&err);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
inner_processor.on_complete();
|
||||
debug!("output filter streaming completed");
|
||||
}
|
||||
.instrument(current_span),
|
||||
);
|
||||
let processor_handle = if use_streaming {
|
||||
info!("using bidirectional streaming output filter pipeline");
|
||||
spawn_streaming_output_filter(
|
||||
byte_stream,
|
||||
inner_processor,
|
||||
output_chain,
|
||||
request_headers,
|
||||
request_path,
|
||||
tx,
|
||||
current_span,
|
||||
)
|
||||
} else {
|
||||
debug!("using per-chunk output filter pipeline");
|
||||
spawn_per_chunk_output_filter(
|
||||
byte_stream,
|
||||
inner_processor,
|
||||
output_chain,
|
||||
request_headers,
|
||||
request_path,
|
||||
tx,
|
||||
current_span,
|
||||
)
|
||||
};
|
||||
|
||||
let stream = ReceiverStream::new(rx).map(|chunk| Ok::<_, hyper::Error>(Frame::data(chunk)));
|
||||
let stream_body = BoxBody::new(StreamBody::new(stream));
|
||||
|
|
@ -395,6 +346,216 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
/// Bidirectional streaming path: one HTTP/2 connection per filter for the entire
|
||||
/// LLM response. Falls back to per-chunk mode if the pipeline fails to establish.
|
||||
fn spawn_streaming_output_filter<S, P>(
|
||||
mut byte_stream: S,
|
||||
mut inner_processor: P,
|
||||
output_chain: ResolvedFilterChain,
|
||||
request_headers: HeaderMap,
|
||||
request_path: String,
|
||||
tx: mpsc::Sender<Bytes>,
|
||||
current_span: tracing::Span,
|
||||
) -> tokio::task::JoinHandle<()>
|
||||
where
|
||||
S: StreamExt<Item = Result<Bytes, reqwest::Error>> + Send + Unpin + 'static,
|
||||
P: StreamProcessor,
|
||||
{
|
||||
tokio::spawn(
|
||||
async move {
|
||||
let agents = output_chain.streaming_agents();
|
||||
let pipeline = PipelineProcessor::start_streaming_output_pipeline(
|
||||
&agents,
|
||||
&request_headers,
|
||||
&request_path,
|
||||
)
|
||||
.await;
|
||||
|
||||
let pipeline = match pipeline {
|
||||
Ok(p) => p,
|
||||
Err(e) => {
|
||||
warn!(error = %e, "failed to establish streaming pipeline, falling back to per-chunk");
|
||||
run_per_chunk_loop(
|
||||
byte_stream,
|
||||
inner_processor,
|
||||
output_chain,
|
||||
request_headers,
|
||||
request_path,
|
||||
tx,
|
||||
)
|
||||
.await;
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let input_tx = pipeline.input_tx;
|
||||
let mut output_rx = pipeline.output_rx;
|
||||
let _handles = pipeline.handles;
|
||||
let mut is_first_chunk = true;
|
||||
|
||||
// Writer: LLM chunks → pipeline input
|
||||
let writer = async {
|
||||
while let Some(item) = byte_stream.next().await {
|
||||
match item {
|
||||
Ok(chunk) => {
|
||||
if input_tx.send(chunk).await.is_err() {
|
||||
debug!("streaming pipeline input closed");
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
warn!(error = %format!("{err:?}"), "LLM stream error");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
drop(input_tx);
|
||||
};
|
||||
|
||||
// Reader: pipeline output → client
|
||||
let reader = async {
|
||||
while let Some(processed) = output_rx.recv().await {
|
||||
if is_first_chunk {
|
||||
inner_processor.on_first_bytes();
|
||||
is_first_chunk = false;
|
||||
}
|
||||
|
||||
match inner_processor.process_chunk(processed) {
|
||||
Ok(Some(final_chunk)) => {
|
||||
if tx.send(final_chunk).await.is_err() {
|
||||
warn!("client receiver dropped");
|
||||
break;
|
||||
}
|
||||
}
|
||||
Ok(None) => continue,
|
||||
Err(err) => {
|
||||
warn!("processor error: {}", err);
|
||||
inner_processor.on_error(&err);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
tokio::join!(writer, reader);
|
||||
inner_processor.on_complete();
|
||||
debug!("streaming output filter pipeline completed");
|
||||
}
|
||||
.instrument(current_span),
|
||||
)
|
||||
}
|
||||
|
||||
/// Per-chunk path: one HTTP request per chunk per filter (original behavior).
|
||||
fn spawn_per_chunk_output_filter<S, P>(
|
||||
byte_stream: S,
|
||||
inner_processor: P,
|
||||
output_chain: ResolvedFilterChain,
|
||||
request_headers: HeaderMap,
|
||||
request_path: String,
|
||||
tx: mpsc::Sender<Bytes>,
|
||||
current_span: tracing::Span,
|
||||
) -> tokio::task::JoinHandle<()>
|
||||
where
|
||||
S: StreamExt<Item = Result<Bytes, reqwest::Error>> + Send + Unpin + 'static,
|
||||
P: StreamProcessor,
|
||||
{
|
||||
tokio::spawn(
|
||||
async move {
|
||||
run_per_chunk_loop(
|
||||
byte_stream,
|
||||
inner_processor,
|
||||
output_chain,
|
||||
request_headers,
|
||||
request_path,
|
||||
tx,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
.instrument(current_span),
|
||||
)
|
||||
}
|
||||
|
||||
async fn run_per_chunk_loop<S, P>(
|
||||
mut byte_stream: S,
|
||||
mut inner_processor: P,
|
||||
output_chain: ResolvedFilterChain,
|
||||
request_headers: HeaderMap,
|
||||
request_path: String,
|
||||
tx: mpsc::Sender<Bytes>,
|
||||
) where
|
||||
S: StreamExt<Item = Result<Bytes, reqwest::Error>> + Send + Unpin + 'static,
|
||||
P: StreamProcessor,
|
||||
{
|
||||
let mut is_first_chunk = true;
|
||||
let mut pipeline_processor = PipelineProcessor::default();
|
||||
let chain = output_chain.to_agent_filter_chain("output_filter");
|
||||
|
||||
while let Some(item) = byte_stream.next().await {
|
||||
let chunk = match item {
|
||||
Ok(chunk) => chunk,
|
||||
Err(err) => {
|
||||
let err_msg = format!("Error receiving chunk: {:?}", err);
|
||||
warn!(error = %err_msg, "stream error");
|
||||
inner_processor.on_error(&err_msg);
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
if is_first_chunk {
|
||||
inner_processor.on_first_bytes();
|
||||
is_first_chunk = false;
|
||||
}
|
||||
|
||||
let processed_chunk = match pipeline_processor
|
||||
.process_raw_filter_chain(
|
||||
&chunk,
|
||||
&chain,
|
||||
&output_chain.agents,
|
||||
&request_headers,
|
||||
&request_path,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(filtered) => filtered,
|
||||
Err(PipelineError::ClientError {
|
||||
agent,
|
||||
status,
|
||||
body,
|
||||
}) => {
|
||||
warn!(
|
||||
agent = %agent,
|
||||
status = %status,
|
||||
body = %body,
|
||||
"output filter client error, passing through original chunk"
|
||||
);
|
||||
chunk
|
||||
}
|
||||
Err(e) => {
|
||||
warn!(error = %e, "output filter error, passing through original chunk");
|
||||
chunk
|
||||
}
|
||||
};
|
||||
|
||||
match inner_processor.process_chunk(processed_chunk) {
|
||||
Ok(Some(final_chunk)) => {
|
||||
if tx.send(final_chunk).await.is_err() {
|
||||
warn!("receiver dropped");
|
||||
break;
|
||||
}
|
||||
}
|
||||
Ok(None) => continue,
|
||||
Err(err) => {
|
||||
warn!("processor error: {}", err);
|
||||
inner_processor.on_error(&err);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
inner_processor.on_complete();
|
||||
debug!("output filter streaming completed");
|
||||
}
|
||||
|
||||
/// Truncates a message to the specified maximum length, adding "..." if truncated.
|
||||
pub fn truncate_message(message: &str, max_length: usize) -> String {
|
||||
if message.chars().count() > max_length {
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ pub struct Agent {
|
|||
pub url: String,
|
||||
#[serde(rename = "type")]
|
||||
pub agent_type: Option<String>,
|
||||
pub streaming: Option<bool>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
|
|
@ -43,6 +44,29 @@ impl ResolvedFilterChain {
|
|||
self.filter_ids.is_empty()
|
||||
}
|
||||
|
||||
/// True when every filter in the chain is an HTTP filter with `streaming: true`.
|
||||
/// MCP filters and filters without the streaming flag use per-chunk mode.
|
||||
pub fn all_support_streaming(&self) -> bool {
|
||||
!self.filter_ids.is_empty()
|
||||
&& self.filter_ids.iter().all(|id| {
|
||||
self.agents
|
||||
.get(id)
|
||||
.map(|a| {
|
||||
a.streaming.unwrap_or(false)
|
||||
&& a.agent_type.as_deref().unwrap_or("mcp") != "mcp"
|
||||
})
|
||||
.unwrap_or(false)
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns references to the ordered agents for the streaming pipeline.
|
||||
pub fn streaming_agents(&self) -> Vec<&Agent> {
|
||||
self.filter_ids
|
||||
.iter()
|
||||
.filter_map(|id| self.agents.get(id))
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn to_agent_filter_chain(&self, id: &str) -> AgentFilterChain {
|
||||
AgentFilterChain {
|
||||
id: id.to_string(),
|
||||
|
|
@ -542,7 +566,7 @@ mod test {
|
|||
use pretty_assertions::assert_eq;
|
||||
use std::fs;
|
||||
|
||||
use super::{IntoModels, LlmProvider, LlmProviderType};
|
||||
use super::{Agent, IntoModels, LlmProvider, LlmProviderType, ResolvedFilterChain};
|
||||
use crate::api::open_ai::ToolType;
|
||||
|
||||
#[test]
|
||||
|
|
@ -663,4 +687,81 @@ mod test {
|
|||
assert!(!model_ids.contains(&"arch-router".to_string()));
|
||||
assert!(!model_ids.contains(&"plano-orchestrator".to_string()));
|
||||
}
|
||||
|
||||
fn make_agent(id: &str, agent_type: Option<&str>, streaming: Option<bool>) -> Agent {
|
||||
Agent {
|
||||
id: id.to_string(),
|
||||
url: format!("http://localhost:10501/{id}"),
|
||||
agent_type: agent_type.map(String::from),
|
||||
transport: None,
|
||||
tool: None,
|
||||
streaming,
|
||||
}
|
||||
}
|
||||
|
||||
/// A chain made only of HTTP filters with `streaming: true` qualifies.
#[test]
fn test_all_support_streaming_all_http_streaming() {
    let first = make_agent("a", Some("http"), Some(true));
    let second = make_agent("b", Some("http"), Some(true));
    let chain = ResolvedFilterChain {
        filter_ids: vec!["a".into(), "b".into()],
        agents: [("a".into(), first), ("b".into(), second)].into(),
    };

    let supported = chain.all_support_streaming();
    assert!(supported);
}
|
||||
|
||||
/// One filter without the `streaming` flag disqualifies the whole chain.
#[test]
fn test_all_support_streaming_one_missing_flag() {
    let opted_in = make_agent("a", Some("http"), Some(true));
    let flag_missing = make_agent("b", Some("http"), None);
    let chain = ResolvedFilterChain {
        filter_ids: vec!["a".into(), "b".into()],
        agents: [("a".into(), opted_in), ("b".into(), flag_missing)].into(),
    };

    let supported = chain.all_support_streaming();
    assert!(!supported);
}
|
||||
|
||||
/// An MCP filter never qualifies, even when it sets `streaming: true`.
#[test]
fn test_all_support_streaming_mcp_filter() {
    let mcp_filter = make_agent("a", Some("mcp"), Some(true));
    let chain = ResolvedFilterChain {
        filter_ids: vec!["a".into()],
        agents: [("a".into(), mcp_filter)].into(),
    };

    let supported = chain.all_support_streaming();
    assert!(!supported);
}
|
||||
|
||||
/// A filter without an explicit type is treated as MCP and does not qualify.
#[test]
fn test_all_support_streaming_default_type_is_mcp() {
    let untyped = make_agent("a", None, Some(true));
    let chain = ResolvedFilterChain {
        filter_ids: vec!["a".into()],
        agents: [("a".into(), untyped)].into(),
    };

    let supported = chain.all_support_streaming();
    assert!(!supported);
}
|
||||
|
||||
/// An empty chain does not select the streaming pipeline.
#[test]
fn test_all_support_streaming_empty_chain() {
    let chain: ResolvedFilterChain = Default::default();

    let supported = chain.all_support_streaming();
    assert!(!supported);
}
|
||||
|
||||
/// `streaming_agents` follows the order of `filter_ids`, not of the agent map.
#[test]
fn test_streaming_agents_ordered() {
    let agent_a = make_agent("a", Some("http"), Some(true));
    let agent_b = make_agent("b", Some("http"), Some(true));
    let chain = ResolvedFilterChain {
        filter_ids: vec!["b".into(), "a".into()],
        agents: [("a".into(), agent_a), ("b".into(), agent_b)].into(),
    };

    let ordered_ids: Vec<&str> = chain
        .streaming_agents()
        .into_iter()
        .map(|agent| agent.id.as_str())
        .collect();

    assert_eq!(ordered_ids, vec!["b", "a"]);
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ filters:
|
|||
- id: pii_deanonymizer
|
||||
url: http://localhost:10501/deanonymize
|
||||
type: http
|
||||
streaming: true
|
||||
|
||||
model_providers:
|
||||
- model: openai/gpt-4o-mini
|
||||
|
|
|
|||
|
|
@ -21,10 +21,16 @@ import logging
|
|||
from typing import Any, Dict
|
||||
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.responses import Response
|
||||
from fastapi.responses import Response, StreamingResponse
|
||||
|
||||
from pii import anonymize_text, anonymize_message_content
|
||||
from store import get_mapping, store_mapping, deanonymize_sse, deanonymize_json
|
||||
from store import (
|
||||
get_mapping,
|
||||
store_mapping,
|
||||
deanonymize_sse,
|
||||
deanonymize_sse_stream,
|
||||
deanonymize_json,
|
||||
)
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
|
|
@ -105,11 +111,36 @@ async def deanonymize(path: str, request: Request) -> Response:
|
|||
/deanonymize/v1/chat/completions — OpenAI chat completions
|
||||
/deanonymize/v1/messages — Anthropic messages
|
||||
/deanonymize/v1/responses — OpenAI responses API
|
||||
|
||||
Supports two modes:
|
||||
- Bidirectional streaming: request body is streamed (Content-Type: application/octet-stream).
|
||||
Reads via request.stream(), processes SSE events incrementally, returns StreamingResponse.
|
||||
- Per-chunk / full body: reads entire body, processes, returns complete Response.
|
||||
"""
|
||||
endpoint = f"/{path}"
|
||||
is_anthropic = endpoint == "/v1/messages"
|
||||
request_id = request.headers.get("x-request-id", "unknown")
|
||||
mapping = get_mapping(request_id)
|
||||
|
||||
content_type = request.headers.get("content-type", "")
|
||||
is_streaming = "application/octet-stream" in content_type
|
||||
|
||||
if is_streaming:
|
||||
if not mapping:
|
||||
logger.info("request_id=%s streaming, no mapping — passthrough", request_id)
|
||||
|
||||
async def passthrough():
|
||||
async for chunk in request.stream():
|
||||
yield chunk
|
||||
|
||||
return StreamingResponse(passthrough(), media_type="text/event-stream")
|
||||
|
||||
logger.info("request_id=%s streaming deanonymize", request_id)
|
||||
return StreamingResponse(
|
||||
deanonymize_sse_stream(request_id, request.stream(), mapping, is_anthropic),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
|
||||
raw_body = await request.body()
|
||||
|
||||
if not mapping:
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import json
|
|||
import logging
|
||||
import threading
|
||||
import time
|
||||
from typing import Dict, Optional, Tuple
|
||||
from typing import AsyncIterator, Dict, Optional, Tuple
|
||||
|
||||
from fastapi.responses import Response
|
||||
|
||||
|
|
@ -59,36 +59,71 @@ def restore_streaming(request_id: str, content: str, mapping: Dict[str, str]) ->
|
|||
def deanonymize_sse(
|
||||
request_id: str, body_str: str, mapping: Dict[str, str], is_anthropic: bool
|
||||
) -> Response:
|
||||
result_lines = []
|
||||
for line in body_str.split("\n"):
|
||||
stripped = line.strip()
|
||||
if not (stripped.startswith("data: ") and stripped[6:] != "[DONE]"):
|
||||
result_lines.append(line)
|
||||
continue
|
||||
try:
|
||||
chunk = json.loads(stripped[6:])
|
||||
if is_anthropic:
|
||||
# {"type": "content_block_delta", "delta": {"type": "text_delta", "text": "..."}}
|
||||
if chunk.get("type") == "content_block_delta":
|
||||
delta = chunk.get("delta", {})
|
||||
if delta.get("type") == "text_delta" and delta.get("text"):
|
||||
delta["text"] = restore_streaming(
|
||||
request_id, delta["text"], mapping
|
||||
)
|
||||
else:
|
||||
# {"choices": [{"delta": {"content": "..."}}]}
|
||||
for choice in chunk.get("choices", []):
|
||||
delta = choice.get("delta", {})
|
||||
if delta.get("content"):
|
||||
delta["content"] = restore_streaming(
|
||||
request_id, delta["content"], mapping
|
||||
)
|
||||
result_lines.append("data: " + json.dumps(chunk))
|
||||
except json.JSONDecodeError:
|
||||
result_lines.append(line)
|
||||
result_lines = [
|
||||
_process_sse_line(request_id, line, mapping, is_anthropic)
|
||||
for line in body_str.split("\n")
|
||||
]
|
||||
return Response(content="\n".join(result_lines), media_type="text/plain")
|
||||
|
||||
|
||||
def _process_sse_line(
|
||||
request_id: str, line: str, mapping: Dict[str, str], is_anthropic: bool
|
||||
) -> str:
|
||||
"""Process a single SSE line, restoring PII in data payloads."""
|
||||
stripped = line.strip()
|
||||
if not (stripped.startswith("data: ") and stripped[6:] != "[DONE]"):
|
||||
return line
|
||||
try:
|
||||
chunk = json.loads(stripped[6:])
|
||||
if is_anthropic:
|
||||
if chunk.get("type") == "content_block_delta":
|
||||
delta = chunk.get("delta", {})
|
||||
if delta.get("type") == "text_delta" and delta.get("text"):
|
||||
delta["text"] = restore_streaming(
|
||||
request_id, delta["text"], mapping
|
||||
)
|
||||
else:
|
||||
for choice in chunk.get("choices", []):
|
||||
delta = choice.get("delta", {})
|
||||
if delta.get("content"):
|
||||
delta["content"] = restore_streaming(
|
||||
request_id, delta["content"], mapping
|
||||
)
|
||||
return "data: " + json.dumps(chunk)
|
||||
except json.JSONDecodeError:
|
||||
return line
|
||||
|
||||
|
||||
async def deanonymize_sse_stream(
    request_id: str,
    byte_stream: AsyncIterator[bytes],
    mapping: Dict[str, str],
    is_anthropic: bool,
):
    """Async generator that reads SSE events from a streaming request body,
    de-anonymizes them, and yields processed events as they become complete.

    Incoming bytes are decoded incrementally, so a multi-byte UTF-8 character
    split across two chunks is reassembled instead of being replaced.
    Partial data is buffered and split on SSE event boundaries (blank lines);
    each complete event is yielded followed by a blank line. Any non-blank
    data left when the input ends is processed and yielded without a
    trailing delimiter.
    """
    import codecs

    def restore_event(text: str) -> str:
        # Every line of the event goes through the per-line de-anonymizer.
        return "\n".join(
            _process_sse_line(request_id, line, mapping, is_anthropic)
            for line in text.split("\n")
        )

    # The incremental decoder holds back an incomplete trailing byte sequence
    # until the next chunk completes it.
    decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
    buffer = ""
    async for raw_chunk in byte_stream:
        buffer += decoder.decode(raw_chunk)
        # Yield each complete SSE event (delimited by double newline)
        while "\n\n" in buffer:
            event, buffer = buffer.split("\n\n", 1)
            yield restore_event(event) + "\n\n"
    # Flush bytes the decoder is still holding (truncated input), then any
    # trailing data.
    buffer += decoder.decode(b"", final=True)
    if buffer.strip():
        yield restore_event(buffer)
|
||||
|
||||
|
||||
def deanonymize_json(
|
||||
request_id: str,
|
||||
raw_body: bytes,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue