bidirectional streaming for output filter chains

Replace per-chunk HTTP requests to output filters with a single
bidirectional streaming connection per filter. This eliminates
the 50-200+ round-trips per streaming LLM response.

Filters opt in via streaming: true in config. When all output filters
support streaming, brightstaff opens one POST per filter with a streaming
request body (Body::wrap_stream) and reads the streaming response. Filters
that don't opt in fall back to the existing per-chunk behavior.

Updates the PII deanonymizer demo as the reference implementation with
request.stream() + StreamingResponse support.

Made-with: Cursor
This commit is contained in:
Adil Hafeez 2026-03-19 02:27:26 -07:00
parent 1f23c573bf
commit 42d3de8906
10 changed files with 613 additions and 133 deletions

View file

@ -43,6 +43,8 @@ properties:
- streamable-http
tool:
type: string
streaming:
type: boolean
additionalProperties: false
required:
- id

View file

@ -26,7 +26,7 @@ opentelemetry-stdout = "0.31"
opentelemetry_sdk = { version = "0.31", features = ["rt-tokio"] }
pretty_assertions = "1.4.1"
rand = "0.9.2"
reqwest = { version = "0.12.15", features = ["stream"] }
reqwest = { version = "0.12.15", features = ["stream", "http2"] }
serde = { version = "1.0.219", features = ["derive"] }
serde_json = "1.0.140"
serde_with = "3.13.0"

View file

@ -210,6 +210,7 @@ mod tests {
url: "http://localhost:8080".to_string(),
tool: None,
transport: None,
streaming: None,
}
}

View file

@ -52,6 +52,7 @@ mod tests {
url: "http://localhost:8081".to_string(),
tool: None,
transport: None,
streaming: None,
},
Agent {
id: "terminal-agent".to_string(),
@ -59,6 +60,7 @@ mod tests {
url: "http://localhost:8082".to_string(),
tool: None,
transport: None,
streaming: None,
},
];

View file

@ -10,6 +10,9 @@ use hermesllm::{ProviderRequest, ProviderRequestType};
use hyper::header::HeaderMap;
use opentelemetry::global;
use opentelemetry_http::HeaderInjector;
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
use tokio_stream::StreamExt;
use tracing::{debug, info, instrument, warn};
use crate::handlers::jsonrpc::{
@ -50,6 +53,18 @@ pub enum PipelineError {
},
}
/// A live streaming filter pipeline. LLM chunks go into `input_tx`;
/// processed chunks come out of `output_rx`. Each filter in the chain
/// is connected via a single bidirectional streaming HTTP connection.
#[derive(Debug)]
pub struct StreamingFilterPipeline {
    /// Sender half: the caller pushes raw LLM response chunks here.
    pub input_tx: mpsc::Sender<Bytes>,
    /// Receiver half: fully filtered chunks arrive here, in order.
    pub output_rx: mpsc::Receiver<Bytes>,
    /// One forwarding task per filter stage (copies response bytes to the
    /// next stage). Dropping a `JoinHandle` detaches the task; it does not abort it.
    pub handles: Vec<tokio::task::JoinHandle<()>>,
}

// Bounded capacity between pipeline stages: provides backpressure so a slow
// filter slows the producer instead of buffering chunks without limit.
const STREAMING_PIPELINE_BUFFER: usize = 16;
/// Service for processing agent pipelines
pub struct PipelineProcessor {
client: reqwest::Client,
@ -429,6 +444,130 @@ impl PipelineProcessor {
session_id
}
/// Build headers for an HTTP raw filter request (shared by per-chunk and streaming paths).
///
/// Starts from a clone of the incoming request headers, then:
/// - drops `Content-Length` (the filter body differs from the original body)
/// - strips the inbound trace header and re-injects the current span's trace
///   context so the filter request is parented correctly
/// - routes to the target agent via `ARCH_UPSTREAM_HOST_HEADER`
/// - sets the Envoy retry count, `Accept`, and a binary `Content-Type`
///
/// # Errors
/// Returns `PipelineError::AgentNotFound` when `agent_id` cannot be encoded
/// as a header value.
fn build_raw_filter_headers(
    request_headers: &HeaderMap,
    agent_id: &str,
) -> Result<HeaderMap, PipelineError> {
    let mut headers = request_headers.clone();
    // The body is re-framed (per-chunk or chunked streaming), so the
    // original Content-Length no longer applies.
    headers.remove(hyper::header::CONTENT_LENGTH);
    headers.remove(TRACE_PARENT_HEADER);
    global::get_text_map_propagator(|propagator| {
        let cx =
            tracing_opentelemetry::OpenTelemetrySpanExt::context(&tracing::Span::current());
        propagator.inject_context(&cx, &mut HeaderInjector(&mut headers));
    });
    headers.insert(
        ARCH_UPSTREAM_HOST_HEADER,
        hyper::header::HeaderValue::from_str(agent_id)
            .map_err(|_| PipelineError::AgentNotFound(agent_id.to_string()))?,
    );
    // "3" is a known-valid static value; `from_static` avoids the fallible
    // `from_str(...).unwrap()` panic path.
    headers.insert(
        ENVOY_RETRY_HEADER,
        hyper::header::HeaderValue::from_static("3"),
    );
    headers.insert(
        "Accept",
        hyper::header::HeaderValue::from_static("application/json, text/event-stream"),
    );
    headers.insert(
        "Content-Type",
        hyper::header::HeaderValue::from_static("application/octet-stream"),
    );
    Ok(headers)
}
/// Set up a bidirectional streaming output filter pipeline.
///
/// Opens one streaming POST per filter (using chunked transfer encoding)
/// and chains them: the response stream of filter N feeds the request body
/// of filter N+1. Returns a pipeline where the caller pushes LLM chunks
/// into `input_tx` and reads processed chunks from `output_rx`.
///
/// # Errors
/// - `RequestFailed` when the client cannot be built or a connection fails.
/// - `ClientError` / `ServerError` when a filter rejects the stream up front
///   with a non-2xx status (detected at response-header time, before any
///   chunk flows).
pub async fn start_streaming_output_pipeline(
    agents: &[&Agent],
    request_headers: &HeaderMap,
    request_path: &str,
) -> Result<StreamingFilterPipeline, PipelineError> {
    // NOTE(review): a fresh client is built per pipeline, so connections are
    // not pooled across pipelines — confirm this is intentional vs. reusing
    // the processor's client.
    let client = reqwest::Client::builder()
        .build()
        .map_err(PipelineError::RequestFailed)?;

    // The caller writes into `input_tx`. Each loop iteration consumes the
    // previous stage's receiver as the next request body and produces a new
    // receiver, so after the loop `current_rx` is the last filter's output.
    let (input_tx, first_rx) = mpsc::channel::<Bytes>(STREAMING_PIPELINE_BUFFER);
    let mut current_rx = first_rx;
    let mut handles = Vec::new();

    for agent in agents {
        let url = format!("{}{}", agent.url, request_path);
        let headers = Self::build_raw_filter_headers(request_headers, &agent.id)?;
        // Lazily fed request body: chunks arrive from the previous stage's
        // channel (no Content-Length — chunked transfer encoding).
        let body_stream = ReceiverStream::new(current_rx).map(Ok::<_, std::io::Error>);
        let body = reqwest::Body::wrap_stream(body_stream);
        debug!(agent = %agent.id, url = %url, "opening streaming filter connection");
        // `send()` resolves when response headers arrive, so an up-front
        // rejection by the filter is surfaced here.
        let response = client.post(&url).headers(headers).body(body).send().await?;
        let http_status = response.status();
        if !http_status.is_success() {
            let error_body = response
                .text()
                .await
                .unwrap_or_else(|_| "<unreadable>".to_string());
            return Err(if http_status.is_client_error() {
                PipelineError::ClientError {
                    agent: agent.id.clone(),
                    status: http_status.as_u16(),
                    body: error_body,
                }
            } else {
                PipelineError::ServerError {
                    agent: agent.id.clone(),
                    status: http_status.as_u16(),
                    body: error_body,
                }
            });
        }
        // Forwarding task: copies this filter's response bytes into the
        // channel feeding the next stage (or the pipeline output).
        let (next_tx, next_rx) = mpsc::channel::<Bytes>(STREAMING_PIPELINE_BUFFER);
        let agent_id = agent.id.clone();
        let handle = tokio::spawn(async move {
            let mut resp_stream = response.bytes_stream();
            while let Some(item) = resp_stream.next().await {
                match item {
                    Ok(chunk) => {
                        // send() failing means the downstream consumer is gone;
                        // stop copying rather than erroring.
                        if next_tx.send(chunk).await.is_err() {
                            debug!(agent = %agent_id, "streaming pipeline receiver dropped");
                            break;
                        }
                    }
                    Err(e) => {
                        warn!(agent = %agent_id, error = %e, "streaming filter response error");
                        break;
                    }
                }
            }
            debug!(agent = %agent_id, "streaming filter stage completed");
        });
        handles.push(handle);
        current_rx = next_rx;
    }
    info!(
        filter_count = agents.len(),
        "streaming output filter pipeline established"
    );
    Ok(StreamingFilterPipeline {
        input_tx,
        output_rx: current_rx,
        handles,
    })
}
/// Execute a raw bytes filter — POST bytes to agent.url, receive bytes back.
/// Used for input and output filters where the full raw request/response is passed through.
/// No MCP protocol wrapping; agent_type is ignored.
@ -454,25 +593,7 @@ impl PipelineProcessor {
span.update_name(format!("execute_raw_filter ({})", agent.id));
});
let mut agent_headers = request_headers.clone();
agent_headers.remove(hyper::header::CONTENT_LENGTH);
agent_headers.remove(TRACE_PARENT_HEADER);
global::get_text_map_propagator(|propagator| {
let cx =
tracing_opentelemetry::OpenTelemetrySpanExt::context(&tracing::Span::current());
propagator.inject_context(&cx, &mut HeaderInjector(&mut agent_headers));
});
agent_headers.insert(
ARCH_UPSTREAM_HOST_HEADER,
hyper::header::HeaderValue::from_str(&agent.id)
.map_err(|_| PipelineError::AgentNotFound(agent.id.clone()))?,
);
agent_headers.insert(
ENVOY_RETRY_HEADER,
hyper::header::HeaderValue::from_str("3").unwrap(),
);
let mut agent_headers = Self::build_raw_filter_headers(request_headers, &agent.id)?;
agent_headers.insert(
"Accept",
hyper::header::HeaderValue::from_static("application/json"),
@ -482,9 +603,6 @@ impl PipelineProcessor {
hyper::header::HeaderValue::from_static("application/json"),
);
// Append the original request path so the filter endpoint encodes the API format.
// e.g. agent.url="http://host/anonymize" + request_path="/v1/chat/completions"
// -> POST http://host/anonymize/v1/chat/completions
let url = format!("{}{}", agent.url, request_path);
debug!(agent = %agent.id, url = %url, "sending raw filter request");
@ -682,6 +800,7 @@ mod tests {
tool: None,
url: server_url,
agent_type: None,
streaming: None,
};
let body = serde_json::json!({"messages": [{"role": "user", "content": "Hello"}]});
@ -722,6 +841,7 @@ mod tests {
tool: None,
url: server_url,
agent_type: None,
streaming: None,
};
let body = serde_json::json!({"messages": [{"role": "user", "content": "Ping"}]});
@ -775,6 +895,7 @@ mod tests {
tool: None,
url: server_url,
agent_type: None,
streaming: None,
};
let body = serde_json::json!({"messages": [{"role": "user", "content": "Hi"}]});
@ -793,4 +914,29 @@ mod tests {
_ => panic!("Expected client error when isError flag is set"),
}
}
#[tokio::test]
async fn test_streaming_pipeline_connection_refused() {
    // Port 1 on loopback is essentially never listening, so establishing
    // the pipeline must fail with a transport-level error.
    let unreachable = Agent {
        id: "unreachable".to_string(),
        url: "http://127.0.0.1:1".to_string(),
        agent_type: Some("http".to_string()),
        transport: None,
        tool: None,
        streaming: Some(true),
    };
    let empty_headers = HeaderMap::new();
    let outcome = PipelineProcessor::start_streaming_output_pipeline(
        &[&unreachable],
        &empty_headers,
        "/v1/chat/completions",
    )
    .await;
    match outcome {
        Err(PipelineError::RequestFailed(_)) => {}
        other => panic!("expected RequestFailed, got {other:?}"),
    }
}
}

View file

@ -294,11 +294,13 @@ where
}
/// Creates a streaming response that processes each raw chunk through output filters.
/// Filters receive the raw LLM response bytes and request path (any API shape; not limited to
/// chat completions). On filter error mid-stream the original chunk is passed through (headers already sent).
///
/// If all filters in the chain have `streaming: true`, uses a single bidirectional
/// HTTP/2 connection per filter (no per-chunk overhead). Otherwise falls back to
/// per-chunk HTTP requests (the original behavior).
pub fn create_streaming_response_with_output_filter<S, P>(
mut byte_stream: S,
mut inner_processor: P,
byte_stream: S,
inner_processor: P,
output_chain: ResolvedFilterChain,
request_headers: HeaderMap,
request_path: String,
@ -307,84 +309,33 @@ where
S: StreamExt<Item = Result<Bytes, reqwest::Error>> + Send + Unpin + 'static,
P: StreamProcessor,
{
let use_streaming = output_chain.all_support_streaming();
let (tx, rx) = mpsc::channel::<Bytes>(STREAM_BUFFER_SIZE);
let current_span = tracing::Span::current();
let processor_handle = tokio::spawn(
async move {
let mut is_first_chunk = true;
let mut pipeline_processor = PipelineProcessor::default();
let chain = output_chain.to_agent_filter_chain("output_filter");
while let Some(item) = byte_stream.next().await {
let chunk = match item {
Ok(chunk) => chunk,
Err(err) => {
let err_msg = format!("Error receiving chunk: {:?}", err);
warn!(error = %err_msg, "stream error");
inner_processor.on_error(&err_msg);
break;
}
};
if is_first_chunk {
inner_processor.on_first_bytes();
is_first_chunk = false;
}
// Pass raw chunk bytes through the output filter chain
let processed_chunk = match pipeline_processor
.process_raw_filter_chain(
&chunk,
&chain,
&output_chain.agents,
&request_headers,
&request_path,
)
.await
{
Ok(filtered) => filtered,
Err(PipelineError::ClientError {
agent,
status,
body,
}) => {
warn!(
agent = %agent,
status = %status,
body = %body,
"output filter client error, passing through original chunk"
);
chunk
}
Err(e) => {
warn!(error = %e, "output filter error, passing through original chunk");
chunk
}
};
// Pass through inner processor for metrics/observability
match inner_processor.process_chunk(processed_chunk) {
Ok(Some(final_chunk)) => {
if tx.send(final_chunk).await.is_err() {
warn!("receiver dropped");
break;
}
}
Ok(None) => continue,
Err(err) => {
warn!("processor error: {}", err);
inner_processor.on_error(&err);
break;
}
}
}
inner_processor.on_complete();
debug!("output filter streaming completed");
}
.instrument(current_span),
);
let processor_handle = if use_streaming {
info!("using bidirectional streaming output filter pipeline");
spawn_streaming_output_filter(
byte_stream,
inner_processor,
output_chain,
request_headers,
request_path,
tx,
current_span,
)
} else {
debug!("using per-chunk output filter pipeline");
spawn_per_chunk_output_filter(
byte_stream,
inner_processor,
output_chain,
request_headers,
request_path,
tx,
current_span,
)
};
let stream = ReceiverStream::new(rx).map(|chunk| Ok::<_, hyper::Error>(Frame::data(chunk)));
let stream_body = BoxBody::new(StreamBody::new(stream));
@ -395,6 +346,216 @@ where
}
}
/// Bidirectional streaming path: one HTTP/2 connection per filter for the entire
/// LLM response. Falls back to per-chunk mode if the pipeline fails to establish.
///
/// Runs on its own task, instrumented with `current_span`. Inside the task,
/// two sub-futures are driven concurrently via `tokio::join!`:
/// - writer: forwards raw LLM chunks into the pipeline's `input_tx`
/// - reader: drains the pipeline's `output_rx`, runs each chunk through
///   `inner_processor`, and forwards the result to the client channel `tx`
fn spawn_streaming_output_filter<S, P>(
    mut byte_stream: S,
    mut inner_processor: P,
    output_chain: ResolvedFilterChain,
    request_headers: HeaderMap,
    request_path: String,
    tx: mpsc::Sender<Bytes>,
    current_span: tracing::Span,
) -> tokio::task::JoinHandle<()>
where
    S: StreamExt<Item = Result<Bytes, reqwest::Error>> + Send + Unpin + 'static,
    P: StreamProcessor,
{
    tokio::spawn(
        async move {
            let agents = output_chain.streaming_agents();
            let pipeline = PipelineProcessor::start_streaming_output_pipeline(
                &agents,
                &request_headers,
                &request_path,
            )
            .await;
            let pipeline = match pipeline {
                Ok(p) => p,
                Err(e) => {
                    // Could not establish the streaming connections (connect
                    // refused, filter rejected the stream, ...) — degrade to
                    // the per-chunk path instead of failing the response.
                    warn!(error = %e, "failed to establish streaming pipeline, falling back to per-chunk");
                    run_per_chunk_loop(
                        byte_stream,
                        inner_processor,
                        output_chain,
                        request_headers,
                        request_path,
                        tx,
                    )
                    .await;
                    return;
                }
            };
            let input_tx = pipeline.input_tx;
            let mut output_rx = pipeline.output_rx;
            // Keep the per-stage forwarding task handles alive for the
            // duration of the stream (dropping them detaches, not aborts).
            let _handles = pipeline.handles;
            let mut is_first_chunk = true;
            // Writer: LLM chunks → pipeline input
            let writer = async {
                while let Some(item) = byte_stream.next().await {
                    match item {
                        Ok(chunk) => {
                            if input_tx.send(chunk).await.is_err() {
                                debug!("streaming pipeline input closed");
                                break;
                            }
                        }
                        Err(err) => {
                            warn!(error = %format!("{err:?}"), "LLM stream error");
                            break;
                        }
                    }
                }
                // Dropping the sender ends the first filter's request body,
                // letting end-of-stream propagate down the chain.
                drop(input_tx);
            };
            // Reader: pipeline output → client
            let reader = async {
                while let Some(processed) = output_rx.recv().await {
                    if is_first_chunk {
                        // Fires once, on the first chunk that survives the chain.
                        inner_processor.on_first_bytes();
                        is_first_chunk = false;
                    }
                    match inner_processor.process_chunk(processed) {
                        Ok(Some(final_chunk)) => {
                            if tx.send(final_chunk).await.is_err() {
                                warn!("client receiver dropped");
                                break;
                            }
                        }
                        Ok(None) => continue,
                        Err(err) => {
                            warn!("processor error: {}", err);
                            inner_processor.on_error(&err);
                            break;
                        }
                    }
                }
            };
            // Both halves complete before notifying on_complete: the writer
            // ends with the LLM stream, the reader once the last filter closes.
            tokio::join!(writer, reader);
            inner_processor.on_complete();
            debug!("streaming output filter pipeline completed");
        }
        .instrument(current_span),
    )
}
/// Per-chunk path: one HTTP request per chunk per filter (original behavior).
///
/// Thin wrapper that runs [`run_per_chunk_loop`] on a fresh task, instrumented
/// with the caller's span so filter spans stay correctly parented.
fn spawn_per_chunk_output_filter<S, P>(
    byte_stream: S,
    inner_processor: P,
    output_chain: ResolvedFilterChain,
    request_headers: HeaderMap,
    request_path: String,
    tx: mpsc::Sender<Bytes>,
    current_span: tracing::Span,
) -> tokio::task::JoinHandle<()>
where
    S: StreamExt<Item = Result<Bytes, reqwest::Error>> + Send + Unpin + 'static,
    P: StreamProcessor,
{
    // `run_per_chunk_loop` already returns a future; spawn it directly
    // instead of wrapping it in a redundant `async move { ... .await }` block.
    tokio::spawn(
        run_per_chunk_loop(
            byte_stream,
            inner_processor,
            output_chain,
            request_headers,
            request_path,
            tx,
        )
        .instrument(current_span),
    )
}
/// Drives the per-chunk output filter loop: for each chunk from the LLM
/// `byte_stream`, runs the raw filter chain (one HTTP request per filter per
/// chunk), then hands the result to `inner_processor` and forwards it to `tx`.
///
/// Error policy:
/// - LLM stream errors stop the loop and notify `on_error`.
/// - Filter errors pass the ORIGINAL chunk through — response headers have
///   already been sent to the client, so the stream cannot fail cleanly.
/// - Processor errors stop the loop and notify `on_error`.
async fn run_per_chunk_loop<S, P>(
    mut byte_stream: S,
    mut inner_processor: P,
    output_chain: ResolvedFilterChain,
    request_headers: HeaderMap,
    request_path: String,
    tx: mpsc::Sender<Bytes>,
) where
    S: StreamExt<Item = Result<Bytes, reqwest::Error>> + Send + Unpin + 'static,
    P: StreamProcessor,
{
    let mut is_first_chunk = true;
    let mut pipeline_processor = PipelineProcessor::default();
    let chain = output_chain.to_agent_filter_chain("output_filter");
    while let Some(item) = byte_stream.next().await {
        let chunk = match item {
            Ok(chunk) => chunk,
            Err(err) => {
                // Upstream LLM stream broke; notify the processor and stop.
                let err_msg = format!("Error receiving chunk: {:?}", err);
                warn!(error = %err_msg, "stream error");
                inner_processor.on_error(&err_msg);
                break;
            }
        };
        if is_first_chunk {
            // Fires once, on the first successfully received chunk.
            inner_processor.on_first_bytes();
            is_first_chunk = false;
        }
        // Run the chunk through every output filter in order.
        let processed_chunk = match pipeline_processor
            .process_raw_filter_chain(
                &chunk,
                &chain,
                &output_chain.agents,
                &request_headers,
                &request_path,
            )
            .await
        {
            Ok(filtered) => filtered,
            Err(PipelineError::ClientError {
                agent,
                status,
                body,
            }) => {
                warn!(
                    agent = %agent,
                    status = %status,
                    body = %body,
                    "output filter client error, passing through original chunk"
                );
                chunk
            }
            Err(e) => {
                warn!(error = %e, "output filter error, passing through original chunk");
                chunk
            }
        };
        match inner_processor.process_chunk(processed_chunk) {
            Ok(Some(final_chunk)) => {
                if tx.send(final_chunk).await.is_err() {
                    // Client disconnected; no point continuing.
                    warn!("receiver dropped");
                    break;
                }
            }
            Ok(None) => continue,
            Err(err) => {
                warn!("processor error: {}", err);
                inner_processor.on_error(&err);
                break;
            }
        }
    }
    inner_processor.on_complete();
    debug!("output filter streaming completed");
}
/// Truncates a message to the specified maximum length, adding "..." if truncated.
pub fn truncate_message(message: &str, max_length: usize) -> String {
if message.chars().count() > max_length {

View file

@ -20,6 +20,7 @@ pub struct Agent {
pub url: String,
#[serde(rename = "type")]
pub agent_type: Option<String>,
pub streaming: Option<bool>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
@ -43,6 +44,29 @@ impl ResolvedFilterChain {
self.filter_ids.is_empty()
}
/// True when every filter in the chain is an HTTP filter with `streaming: true`.
/// MCP filters and filters without the streaming flag use per-chunk mode.
/// An empty chain — and any id with no matching agent entry — reports `false`.
pub fn all_support_streaming(&self) -> bool {
    !self.filter_ids.is_empty()
        && self.filter_ids.iter().all(|id| {
            // `is_some_and` replaces the `.map(...).unwrap_or(false)` idiom:
            // a missing agent entry disqualifies the chain, as does the
            // default agent type (treated as "mcp").
            self.agents.get(id).is_some_and(|a| {
                a.streaming.unwrap_or(false)
                    && a.agent_type.as_deref().unwrap_or("mcp") != "mcp"
            })
        })
}
/// Returns references to the ordered agents for the streaming pipeline.
/// Order follows `filter_ids`; ids with no matching agent entry are skipped.
pub fn streaming_agents(&self) -> Vec<&Agent> {
    let mut ordered = Vec::with_capacity(self.filter_ids.len());
    for id in &self.filter_ids {
        if let Some(agent) = self.agents.get(id) {
            ordered.push(agent);
        }
    }
    ordered
}
pub fn to_agent_filter_chain(&self, id: &str) -> AgentFilterChain {
AgentFilterChain {
id: id.to_string(),
@ -542,7 +566,7 @@ mod test {
use pretty_assertions::assert_eq;
use std::fs;
use super::{IntoModels, LlmProvider, LlmProviderType};
use super::{Agent, IntoModels, LlmProvider, LlmProviderType, ResolvedFilterChain};
use crate::api::open_ai::ToolType;
#[test]
@ -663,4 +687,81 @@ mod test {
assert!(!model_ids.contains(&"arch-router".to_string()));
assert!(!model_ids.contains(&"plano-orchestrator".to_string()));
}
/// Test helper: builds an `Agent` with the given type/streaming flags and a
/// deterministic localhost URL derived from `id`.
fn make_agent(id: &str, agent_type: Option<&str>, streaming: Option<bool>) -> Agent {
    let url = format!("http://localhost:10501/{id}");
    Agent {
        id: id.to_owned(),
        url,
        agent_type: agent_type.map(|t| t.to_string()),
        transport: None,
        tool: None,
        streaming,
    }
}
#[test]
fn test_all_support_streaming_all_http_streaming() {
    // Two HTTP filters, both opted in — the whole chain can stream.
    let a = make_agent("a", Some("http"), Some(true));
    let b = make_agent("b", Some("http"), Some(true));
    let chain = ResolvedFilterChain {
        filter_ids: vec!["a".into(), "b".into()],
        agents: [("a".into(), a), ("b".into(), b)].into(),
    };
    assert!(chain.all_support_streaming());
}
#[test]
fn test_all_support_streaming_one_missing_flag() {
    // "b" never set `streaming`, so the chain must fall back to per-chunk.
    let chain = ResolvedFilterChain {
        filter_ids: vec!["a".into(), "b".into()],
        agents: [
            ("b".into(), make_agent("b", Some("http"), None)),
            ("a".into(), make_agent("a", Some("http"), Some(true))),
        ]
        .into(),
    };
    assert!(!chain.all_support_streaming());
}
#[test]
fn test_all_support_streaming_mcp_filter() {
    // MCP filters never stream, even with the flag set.
    let mcp_agent = make_agent("a", Some("mcp"), Some(true));
    let chain = ResolvedFilterChain {
        filter_ids: vec!["a".into()],
        agents: [("a".into(), mcp_agent)].into(),
    };
    assert!(!chain.all_support_streaming());
}
#[test]
fn test_all_support_streaming_default_type_is_mcp() {
    // No explicit type defaults to "mcp", which disqualifies streaming.
    let untyped = make_agent("a", None, Some(true));
    let chain = ResolvedFilterChain {
        filter_ids: vec!["a".into()],
        agents: [("a".into(), untyped)].into(),
    };
    assert!(!chain.all_support_streaming());
}
#[test]
fn test_all_support_streaming_empty_chain() {
    // An empty chain has nothing to stream through.
    assert!(!ResolvedFilterChain::default().all_support_streaming());
}
#[test]
fn test_streaming_agents_ordered() {
    // Result must follow filter_ids order ("b" before "a"), not map order.
    let chain = ResolvedFilterChain {
        filter_ids: vec!["b".into(), "a".into()],
        agents: [
            ("a".into(), make_agent("a", Some("http"), Some(true))),
            ("b".into(), make_agent("b", Some("http"), Some(true))),
        ]
        .into(),
    };
    let ids: Vec<&str> = chain
        .streaming_agents()
        .iter()
        .map(|a| a.id.as_str())
        .collect();
    assert_eq!(ids, vec!["b", "a"]);
}
}

View file

@ -7,6 +7,7 @@ filters:
- id: pii_deanonymizer
url: http://localhost:10501/deanonymize
type: http
streaming: true
model_providers:
- model: openai/gpt-4o-mini

View file

@ -21,10 +21,16 @@ import logging
from typing import Any, Dict
from fastapi import FastAPI, Request
from fastapi.responses import Response
from fastapi.responses import Response, StreamingResponse
from pii import anonymize_text, anonymize_message_content
from store import get_mapping, store_mapping, deanonymize_sse, deanonymize_json
from store import (
get_mapping,
store_mapping,
deanonymize_sse,
deanonymize_sse_stream,
deanonymize_json,
)
logging.basicConfig(
level=logging.INFO,
@ -105,11 +111,36 @@ async def deanonymize(path: str, request: Request) -> Response:
/deanonymize/v1/chat/completions OpenAI chat completions
/deanonymize/v1/messages Anthropic messages
/deanonymize/v1/responses OpenAI responses API
Supports two modes:
- Bidirectional streaming: request body is streamed (Content-Type: application/octet-stream).
Reads via request.stream(), processes SSE events incrementally, returns StreamingResponse.
- Per-chunk / full body: reads entire body, processes, returns complete Response.
"""
endpoint = f"/{path}"
is_anthropic = endpoint == "/v1/messages"
request_id = request.headers.get("x-request-id", "unknown")
mapping = get_mapping(request_id)
content_type = request.headers.get("content-type", "")
is_streaming = "application/octet-stream" in content_type
if is_streaming:
if not mapping:
logger.info("request_id=%s streaming, no mapping — passthrough", request_id)
async def passthrough():
async for chunk in request.stream():
yield chunk
return StreamingResponse(passthrough(), media_type="text/event-stream")
logger.info("request_id=%s streaming deanonymize", request_id)
return StreamingResponse(
deanonymize_sse_stream(request_id, request.stream(), mapping, is_anthropic),
media_type="text/event-stream",
)
raw_body = await request.body()
if not mapping:

View file

@ -4,7 +4,7 @@ import json
import logging
import threading
import time
from typing import Dict, Optional, Tuple
from typing import AsyncIterator, Dict, Optional, Tuple
from fastapi.responses import Response
@ -59,36 +59,71 @@ def restore_streaming(request_id: str, content: str, mapping: Dict[str, str]) ->
def deanonymize_sse(
    request_id: str, body_str: str, mapping: Dict[str, str], is_anthropic: bool
) -> Response:
    """De-anonymize a complete (non-streaming) SSE body.

    Every line is run through _process_sse_line, which restores PII in JSON
    `data:` payloads and passes all other lines through untouched.

    Args:
        request_id: Correlation id used to restore PII placeholders.
        body_str: The full SSE payload as text.
        mapping: Placeholder -> original-value mapping for this request.
        is_anthropic: True for Anthropic message events, False for OpenAI chunks.
    """
    # The per-line logic lives in _process_sse_line (shared with the streaming
    # path); the old inline loop that duplicated it has been removed.
    result_lines = [
        _process_sse_line(request_id, line, mapping, is_anthropic)
        for line in body_str.split("\n")
    ]
    return Response(content="\n".join(result_lines), media_type="text/plain")
def _process_sse_line(
request_id: str, line: str, mapping: Dict[str, str], is_anthropic: bool
) -> str:
"""Process a single SSE line, restoring PII in data payloads."""
stripped = line.strip()
if not (stripped.startswith("data: ") and stripped[6:] != "[DONE]"):
return line
try:
chunk = json.loads(stripped[6:])
if is_anthropic:
if chunk.get("type") == "content_block_delta":
delta = chunk.get("delta", {})
if delta.get("type") == "text_delta" and delta.get("text"):
delta["text"] = restore_streaming(
request_id, delta["text"], mapping
)
else:
for choice in chunk.get("choices", []):
delta = choice.get("delta", {})
if delta.get("content"):
delta["content"] = restore_streaming(
request_id, delta["content"], mapping
)
return "data: " + json.dumps(chunk)
except json.JSONDecodeError:
return line
async def deanonymize_sse_stream(
    request_id: str,
    byte_stream: AsyncIterator[bytes],
    mapping: Dict[str, str],
    is_anthropic: bool,
):
    """Async generator that reads SSE events from a streaming request body,
    de-anonymizes them, and yields processed events as they become complete.

    Buffers partial data and splits on SSE event boundaries (blank lines).

    Args:
        request_id: Correlation id used to restore PII placeholders.
        byte_stream: Async iterator over raw request-body byte chunks.
        mapping: Placeholder -> original-value mapping for this request.
        is_anthropic: True for Anthropic message events, False for OpenAI chunks.
    """
    import codecs

    # Chunk boundaries are arbitrary byte offsets, so a multi-byte UTF-8
    # character can be split across chunks. An incremental decoder holds the
    # partial sequence across chunks instead of emitting U+FFFD replacements.
    decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
    buffer = ""
    async for raw_chunk in byte_stream:
        buffer += decoder.decode(raw_chunk)
        # Yield each complete SSE event (delimited by double newline).
        # NOTE(review): assumes "\n\n" event delimiters — confirm upstream
        # never emits "\r\n\r\n".
        while "\n\n" in buffer:
            event, buffer = buffer.split("\n\n", 1)
            processed_lines = [
                _process_sse_line(request_id, line, mapping, is_anthropic)
                for line in event.split("\n")
            ]
            yield "\n".join(processed_lines) + "\n\n"
    # Flush any bytes still held by the decoder, then any trailing text.
    buffer += decoder.decode(b"", final=True)
    if buffer.strip():
        processed_lines = [
            _process_sse_line(request_id, line, mapping, is_anthropic)
            for line in buffer.split("\n")
        ]
        yield "\n".join(processed_lines)
def deanonymize_json(
request_id: str,
raw_body: bytes,