mirror of
https://github.com/katanemo/plano.git
synced 2026-05-11 16:52:41 +02:00
bidirectional streaming for output filter chains
Replace per-chunk HTTP requests to output filters with a single bidirectional streaming connection per filter. This eliminates the 50-200+ round-trips per streaming LLM response. Filters opt in via `streaming: true` in config. When all output filters in a chain support streaming, brightstaff opens one POST per filter with a streaming request body (`Body::wrap_stream`) and reads the streaming response. If any filter in the chain does not opt in, the whole chain falls back to the existing per-chunk behavior. Updates the PII deanonymizer demo as the reference implementation with `request.stream()` + `StreamingResponse` support. Made-with: Cursor
This commit is contained in:
parent
1f23c573bf
commit
42d3de8906
10 changed files with 613 additions and 133 deletions
|
|
@ -26,7 +26,7 @@ opentelemetry-stdout = "0.31"
|
|||
opentelemetry_sdk = { version = "0.31", features = ["rt-tokio"] }
|
||||
pretty_assertions = "1.4.1"
|
||||
rand = "0.9.2"
|
||||
reqwest = { version = "0.12.15", features = ["stream"] }
|
||||
reqwest = { version = "0.12.15", features = ["stream", "http2"] }
|
||||
serde = { version = "1.0.219", features = ["derive"] }
|
||||
serde_json = "1.0.140"
|
||||
serde_with = "3.13.0"
|
||||
|
|
|
|||
|
|
@ -210,6 +210,7 @@ mod tests {
|
|||
url: "http://localhost:8080".to_string(),
|
||||
tool: None,
|
||||
transport: None,
|
||||
streaming: None,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -52,6 +52,7 @@ mod tests {
|
|||
url: "http://localhost:8081".to_string(),
|
||||
tool: None,
|
||||
transport: None,
|
||||
streaming: None,
|
||||
},
|
||||
Agent {
|
||||
id: "terminal-agent".to_string(),
|
||||
|
|
@ -59,6 +60,7 @@ mod tests {
|
|||
url: "http://localhost:8082".to_string(),
|
||||
tool: None,
|
||||
transport: None,
|
||||
streaming: None,
|
||||
},
|
||||
];
|
||||
|
||||
|
|
|
|||
|
|
@ -10,6 +10,9 @@ use hermesllm::{ProviderRequest, ProviderRequestType};
|
|||
use hyper::header::HeaderMap;
|
||||
use opentelemetry::global;
|
||||
use opentelemetry_http::HeaderInjector;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio_stream::wrappers::ReceiverStream;
|
||||
use tokio_stream::StreamExt;
|
||||
use tracing::{debug, info, instrument, warn};
|
||||
|
||||
use crate::handlers::jsonrpc::{
|
||||
|
|
@ -50,6 +53,18 @@ pub enum PipelineError {
|
|||
},
|
||||
}
|
||||
|
||||
/// A live streaming filter pipeline. LLM chunks go into `input_tx`;
|
||||
/// processed chunks come out of `output_rx`. Each filter in the chain
|
||||
/// is connected via a single bidirectional streaming HTTP connection.
|
||||
#[derive(Debug)]
|
||||
pub struct StreamingFilterPipeline {
|
||||
/// Sender whose chunks become the request body of the first filter in the chain.
pub input_tx: mpsc::Sender<Bytes>,
|
||||
/// Receiver yielding the response chunks forwarded from the last filter in the chain.
pub output_rx: mpsc::Receiver<Bytes>,
|
||||
/// Join handles of the forwarding tasks spawned while building the pipeline
/// (one is pushed per filter).
pub handles: Vec<tokio::task::JoinHandle<()>>,
|
||||
}
|
||||
|
||||
/// Capacity of every bounded channel created for the streaming pipeline.
const STREAMING_PIPELINE_BUFFER: usize = 16;
|
||||
|
||||
/// Service for processing agent pipelines
|
||||
pub struct PipelineProcessor {
|
||||
client: reqwest::Client,
|
||||
|
|
@ -429,6 +444,130 @@ impl PipelineProcessor {
|
|||
session_id
|
||||
}
|
||||
|
||||
/// Build headers for an HTTP raw filter request (shared by per-chunk and streaming paths).
|
||||
fn build_raw_filter_headers(
|
||||
request_headers: &HeaderMap,
|
||||
agent_id: &str,
|
||||
) -> Result<HeaderMap, PipelineError> {
|
||||
let mut headers = request_headers.clone();
|
||||
headers.remove(hyper::header::CONTENT_LENGTH);
|
||||
|
||||
headers.remove(TRACE_PARENT_HEADER);
|
||||
global::get_text_map_propagator(|propagator| {
|
||||
let cx =
|
||||
tracing_opentelemetry::OpenTelemetrySpanExt::context(&tracing::Span::current());
|
||||
propagator.inject_context(&cx, &mut HeaderInjector(&mut headers));
|
||||
});
|
||||
|
||||
headers.insert(
|
||||
ARCH_UPSTREAM_HOST_HEADER,
|
||||
hyper::header::HeaderValue::from_str(agent_id)
|
||||
.map_err(|_| PipelineError::AgentNotFound(agent_id.to_string()))?,
|
||||
);
|
||||
headers.insert(
|
||||
ENVOY_RETRY_HEADER,
|
||||
hyper::header::HeaderValue::from_str("3").unwrap(),
|
||||
);
|
||||
headers.insert(
|
||||
"Accept",
|
||||
hyper::header::HeaderValue::from_static("application/json, text/event-stream"),
|
||||
);
|
||||
headers.insert(
|
||||
"Content-Type",
|
||||
hyper::header::HeaderValue::from_static("application/octet-stream"),
|
||||
);
|
||||
|
||||
Ok(headers)
|
||||
}
|
||||
|
||||
/// Set up a bidirectional streaming output filter pipeline.
|
||||
///
|
||||
/// Opens one streaming POST per filter (using chunked transfer encoding)
|
||||
/// and chains them: the response stream of filter N feeds the request body
|
||||
/// of filter N+1. Returns a pipeline where the caller pushes LLM chunks
|
||||
/// into `input_tx` and reads processed chunks from `output_rx`.
|
||||
pub async fn start_streaming_output_pipeline(
|
||||
agents: &[&Agent],
|
||||
request_headers: &HeaderMap,
|
||||
request_path: &str,
|
||||
) -> Result<StreamingFilterPipeline, PipelineError> {
|
||||
let client = reqwest::Client::builder()
|
||||
.build()
|
||||
.map_err(PipelineError::RequestFailed)?;
|
||||
|
||||
let (input_tx, first_rx) = mpsc::channel::<Bytes>(STREAMING_PIPELINE_BUFFER);
|
||||
let mut current_rx = first_rx;
|
||||
let mut handles = Vec::new();
|
||||
|
||||
for agent in agents {
|
||||
let url = format!("{}{}", agent.url, request_path);
|
||||
let headers = Self::build_raw_filter_headers(request_headers, &agent.id)?;
|
||||
|
||||
let body_stream = ReceiverStream::new(current_rx).map(Ok::<_, std::io::Error>);
|
||||
let body = reqwest::Body::wrap_stream(body_stream);
|
||||
|
||||
debug!(agent = %agent.id, url = %url, "opening streaming filter connection");
|
||||
|
||||
let response = client.post(&url).headers(headers).body(body).send().await?;
|
||||
|
||||
let http_status = response.status();
|
||||
if !http_status.is_success() {
|
||||
let error_body = response
|
||||
.text()
|
||||
.await
|
||||
.unwrap_or_else(|_| "<unreadable>".to_string());
|
||||
return Err(if http_status.is_client_error() {
|
||||
PipelineError::ClientError {
|
||||
agent: agent.id.clone(),
|
||||
status: http_status.as_u16(),
|
||||
body: error_body,
|
||||
}
|
||||
} else {
|
||||
PipelineError::ServerError {
|
||||
agent: agent.id.clone(),
|
||||
status: http_status.as_u16(),
|
||||
body: error_body,
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
let (next_tx, next_rx) = mpsc::channel::<Bytes>(STREAMING_PIPELINE_BUFFER);
|
||||
let agent_id = agent.id.clone();
|
||||
|
||||
let handle = tokio::spawn(async move {
|
||||
let mut resp_stream = response.bytes_stream();
|
||||
while let Some(item) = resp_stream.next().await {
|
||||
match item {
|
||||
Ok(chunk) => {
|
||||
if next_tx.send(chunk).await.is_err() {
|
||||
debug!(agent = %agent_id, "streaming pipeline receiver dropped");
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!(agent = %agent_id, error = %e, "streaming filter response error");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
debug!(agent = %agent_id, "streaming filter stage completed");
|
||||
});
|
||||
|
||||
handles.push(handle);
|
||||
current_rx = next_rx;
|
||||
}
|
||||
|
||||
info!(
|
||||
filter_count = agents.len(),
|
||||
"streaming output filter pipeline established"
|
||||
);
|
||||
Ok(StreamingFilterPipeline {
|
||||
input_tx,
|
||||
output_rx: current_rx,
|
||||
handles,
|
||||
})
|
||||
}
|
||||
|
||||
/// Execute a raw bytes filter — POST bytes to agent.url, receive bytes back.
|
||||
/// Used for input and output filters where the full raw request/response is passed through.
|
||||
/// No MCP protocol wrapping; agent_type is ignored.
|
||||
|
|
@ -454,25 +593,7 @@ impl PipelineProcessor {
|
|||
span.update_name(format!("execute_raw_filter ({})", agent.id));
|
||||
});
|
||||
|
||||
let mut agent_headers = request_headers.clone();
|
||||
agent_headers.remove(hyper::header::CONTENT_LENGTH);
|
||||
|
||||
agent_headers.remove(TRACE_PARENT_HEADER);
|
||||
global::get_text_map_propagator(|propagator| {
|
||||
let cx =
|
||||
tracing_opentelemetry::OpenTelemetrySpanExt::context(&tracing::Span::current());
|
||||
propagator.inject_context(&cx, &mut HeaderInjector(&mut agent_headers));
|
||||
});
|
||||
|
||||
agent_headers.insert(
|
||||
ARCH_UPSTREAM_HOST_HEADER,
|
||||
hyper::header::HeaderValue::from_str(&agent.id)
|
||||
.map_err(|_| PipelineError::AgentNotFound(agent.id.clone()))?,
|
||||
);
|
||||
agent_headers.insert(
|
||||
ENVOY_RETRY_HEADER,
|
||||
hyper::header::HeaderValue::from_str("3").unwrap(),
|
||||
);
|
||||
let mut agent_headers = Self::build_raw_filter_headers(request_headers, &agent.id)?;
|
||||
agent_headers.insert(
|
||||
"Accept",
|
||||
hyper::header::HeaderValue::from_static("application/json"),
|
||||
|
|
@ -482,9 +603,6 @@ impl PipelineProcessor {
|
|||
hyper::header::HeaderValue::from_static("application/json"),
|
||||
);
|
||||
|
||||
// Append the original request path so the filter endpoint encodes the API format.
|
||||
// e.g. agent.url="http://host/anonymize" + request_path="/v1/chat/completions"
|
||||
// -> POST http://host/anonymize/v1/chat/completions
|
||||
let url = format!("{}{}", agent.url, request_path);
|
||||
debug!(agent = %agent.id, url = %url, "sending raw filter request");
|
||||
|
||||
|
|
@ -682,6 +800,7 @@ mod tests {
|
|||
tool: None,
|
||||
url: server_url,
|
||||
agent_type: None,
|
||||
streaming: None,
|
||||
};
|
||||
|
||||
let body = serde_json::json!({"messages": [{"role": "user", "content": "Hello"}]});
|
||||
|
|
@ -722,6 +841,7 @@ mod tests {
|
|||
tool: None,
|
||||
url: server_url,
|
||||
agent_type: None,
|
||||
streaming: None,
|
||||
};
|
||||
|
||||
let body = serde_json::json!({"messages": [{"role": "user", "content": "Ping"}]});
|
||||
|
|
@ -775,6 +895,7 @@ mod tests {
|
|||
tool: None,
|
||||
url: server_url,
|
||||
agent_type: None,
|
||||
streaming: None,
|
||||
};
|
||||
|
||||
let body = serde_json::json!({"messages": [{"role": "user", "content": "Hi"}]});
|
||||
|
|
@ -793,4 +914,29 @@ mod tests {
|
|||
_ => panic!("Expected client error when isError flag is set"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_streaming_pipeline_connection_refused() {
|
||||
let agent = Agent {
|
||||
id: "unreachable".to_string(),
|
||||
transport: None,
|
||||
tool: None,
|
||||
url: "http://127.0.0.1:1".to_string(),
|
||||
agent_type: Some("http".to_string()),
|
||||
streaming: Some(true),
|
||||
};
|
||||
let headers = HeaderMap::new();
|
||||
let result = PipelineProcessor::start_streaming_output_pipeline(
|
||||
&[&agent],
|
||||
&headers,
|
||||
"/v1/chat/completions",
|
||||
)
|
||||
.await;
|
||||
|
||||
assert!(result.is_err());
|
||||
assert!(matches!(
|
||||
result.unwrap_err(),
|
||||
PipelineError::RequestFailed(_)
|
||||
));
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -294,11 +294,13 @@ where
|
|||
}
|
||||
|
||||
/// Creates a streaming response that processes each raw chunk through output filters.
|
||||
/// Filters receive the raw LLM response bytes and request path (any API shape; not limited to
|
||||
/// chat completions). On filter error mid-stream the original chunk is passed through (headers already sent).
|
||||
///
|
||||
/// If all filters in the chain have `streaming: true`, uses a single bidirectional
|
||||
/// HTTP/2 connection per filter (no per-chunk overhead). Otherwise falls back to
|
||||
/// per-chunk HTTP requests (the original behavior).
|
||||
pub fn create_streaming_response_with_output_filter<S, P>(
|
||||
mut byte_stream: S,
|
||||
mut inner_processor: P,
|
||||
byte_stream: S,
|
||||
inner_processor: P,
|
||||
output_chain: ResolvedFilterChain,
|
||||
request_headers: HeaderMap,
|
||||
request_path: String,
|
||||
|
|
@ -307,84 +309,33 @@ where
|
|||
S: StreamExt<Item = Result<Bytes, reqwest::Error>> + Send + Unpin + 'static,
|
||||
P: StreamProcessor,
|
||||
{
|
||||
let use_streaming = output_chain.all_support_streaming();
|
||||
let (tx, rx) = mpsc::channel::<Bytes>(STREAM_BUFFER_SIZE);
|
||||
let current_span = tracing::Span::current();
|
||||
|
||||
let processor_handle = tokio::spawn(
|
||||
async move {
|
||||
let mut is_first_chunk = true;
|
||||
let mut pipeline_processor = PipelineProcessor::default();
|
||||
let chain = output_chain.to_agent_filter_chain("output_filter");
|
||||
|
||||
while let Some(item) = byte_stream.next().await {
|
||||
let chunk = match item {
|
||||
Ok(chunk) => chunk,
|
||||
Err(err) => {
|
||||
let err_msg = format!("Error receiving chunk: {:?}", err);
|
||||
warn!(error = %err_msg, "stream error");
|
||||
inner_processor.on_error(&err_msg);
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
if is_first_chunk {
|
||||
inner_processor.on_first_bytes();
|
||||
is_first_chunk = false;
|
||||
}
|
||||
|
||||
// Pass raw chunk bytes through the output filter chain
|
||||
let processed_chunk = match pipeline_processor
|
||||
.process_raw_filter_chain(
|
||||
&chunk,
|
||||
&chain,
|
||||
&output_chain.agents,
|
||||
&request_headers,
|
||||
&request_path,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(filtered) => filtered,
|
||||
Err(PipelineError::ClientError {
|
||||
agent,
|
||||
status,
|
||||
body,
|
||||
}) => {
|
||||
warn!(
|
||||
agent = %agent,
|
||||
status = %status,
|
||||
body = %body,
|
||||
"output filter client error, passing through original chunk"
|
||||
);
|
||||
chunk
|
||||
}
|
||||
Err(e) => {
|
||||
warn!(error = %e, "output filter error, passing through original chunk");
|
||||
chunk
|
||||
}
|
||||
};
|
||||
|
||||
// Pass through inner processor for metrics/observability
|
||||
match inner_processor.process_chunk(processed_chunk) {
|
||||
Ok(Some(final_chunk)) => {
|
||||
if tx.send(final_chunk).await.is_err() {
|
||||
warn!("receiver dropped");
|
||||
break;
|
||||
}
|
||||
}
|
||||
Ok(None) => continue,
|
||||
Err(err) => {
|
||||
warn!("processor error: {}", err);
|
||||
inner_processor.on_error(&err);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
inner_processor.on_complete();
|
||||
debug!("output filter streaming completed");
|
||||
}
|
||||
.instrument(current_span),
|
||||
);
|
||||
let processor_handle = if use_streaming {
|
||||
info!("using bidirectional streaming output filter pipeline");
|
||||
spawn_streaming_output_filter(
|
||||
byte_stream,
|
||||
inner_processor,
|
||||
output_chain,
|
||||
request_headers,
|
||||
request_path,
|
||||
tx,
|
||||
current_span,
|
||||
)
|
||||
} else {
|
||||
debug!("using per-chunk output filter pipeline");
|
||||
spawn_per_chunk_output_filter(
|
||||
byte_stream,
|
||||
inner_processor,
|
||||
output_chain,
|
||||
request_headers,
|
||||
request_path,
|
||||
tx,
|
||||
current_span,
|
||||
)
|
||||
};
|
||||
|
||||
let stream = ReceiverStream::new(rx).map(|chunk| Ok::<_, hyper::Error>(Frame::data(chunk)));
|
||||
let stream_body = BoxBody::new(StreamBody::new(stream));
|
||||
|
|
@ -395,6 +346,216 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
/// Bidirectional streaming path: one HTTP/2 connection per filter for the entire
|
||||
/// LLM response. Falls back to per-chunk mode if the pipeline fails to establish.
|
||||
fn spawn_streaming_output_filter<S, P>(
|
||||
mut byte_stream: S,
|
||||
mut inner_processor: P,
|
||||
output_chain: ResolvedFilterChain,
|
||||
request_headers: HeaderMap,
|
||||
request_path: String,
|
||||
tx: mpsc::Sender<Bytes>,
|
||||
current_span: tracing::Span,
|
||||
) -> tokio::task::JoinHandle<()>
|
||||
where
|
||||
S: StreamExt<Item = Result<Bytes, reqwest::Error>> + Send + Unpin + 'static,
|
||||
P: StreamProcessor,
|
||||
{
|
||||
tokio::spawn(
|
||||
async move {
|
||||
let agents = output_chain.streaming_agents();
|
||||
let pipeline = PipelineProcessor::start_streaming_output_pipeline(
|
||||
&agents,
|
||||
&request_headers,
|
||||
&request_path,
|
||||
)
|
||||
.await;
|
||||
|
||||
let pipeline = match pipeline {
|
||||
Ok(p) => p,
|
||||
Err(e) => {
|
||||
warn!(error = %e, "failed to establish streaming pipeline, falling back to per-chunk");
|
||||
run_per_chunk_loop(
|
||||
byte_stream,
|
||||
inner_processor,
|
||||
output_chain,
|
||||
request_headers,
|
||||
request_path,
|
||||
tx,
|
||||
)
|
||||
.await;
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let input_tx = pipeline.input_tx;
|
||||
let mut output_rx = pipeline.output_rx;
|
||||
let _handles = pipeline.handles;
|
||||
let mut is_first_chunk = true;
|
||||
|
||||
// Writer: LLM chunks → pipeline input
|
||||
let writer = async {
|
||||
while let Some(item) = byte_stream.next().await {
|
||||
match item {
|
||||
Ok(chunk) => {
|
||||
if input_tx.send(chunk).await.is_err() {
|
||||
debug!("streaming pipeline input closed");
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
warn!(error = %format!("{err:?}"), "LLM stream error");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
drop(input_tx);
|
||||
};
|
||||
|
||||
// Reader: pipeline output → client
|
||||
let reader = async {
|
||||
while let Some(processed) = output_rx.recv().await {
|
||||
if is_first_chunk {
|
||||
inner_processor.on_first_bytes();
|
||||
is_first_chunk = false;
|
||||
}
|
||||
|
||||
match inner_processor.process_chunk(processed) {
|
||||
Ok(Some(final_chunk)) => {
|
||||
if tx.send(final_chunk).await.is_err() {
|
||||
warn!("client receiver dropped");
|
||||
break;
|
||||
}
|
||||
}
|
||||
Ok(None) => continue,
|
||||
Err(err) => {
|
||||
warn!("processor error: {}", err);
|
||||
inner_processor.on_error(&err);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
tokio::join!(writer, reader);
|
||||
inner_processor.on_complete();
|
||||
debug!("streaming output filter pipeline completed");
|
||||
}
|
||||
.instrument(current_span),
|
||||
)
|
||||
}
|
||||
|
||||
/// Per-chunk path: one HTTP request per chunk per filter (original behavior).
|
||||
fn spawn_per_chunk_output_filter<S, P>(
|
||||
byte_stream: S,
|
||||
inner_processor: P,
|
||||
output_chain: ResolvedFilterChain,
|
||||
request_headers: HeaderMap,
|
||||
request_path: String,
|
||||
tx: mpsc::Sender<Bytes>,
|
||||
current_span: tracing::Span,
|
||||
) -> tokio::task::JoinHandle<()>
|
||||
where
|
||||
S: StreamExt<Item = Result<Bytes, reqwest::Error>> + Send + Unpin + 'static,
|
||||
P: StreamProcessor,
|
||||
{
|
||||
tokio::spawn(
|
||||
async move {
|
||||
run_per_chunk_loop(
|
||||
byte_stream,
|
||||
inner_processor,
|
||||
output_chain,
|
||||
request_headers,
|
||||
request_path,
|
||||
tx,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
.instrument(current_span),
|
||||
)
|
||||
}
|
||||
|
||||
async fn run_per_chunk_loop<S, P>(
|
||||
mut byte_stream: S,
|
||||
mut inner_processor: P,
|
||||
output_chain: ResolvedFilterChain,
|
||||
request_headers: HeaderMap,
|
||||
request_path: String,
|
||||
tx: mpsc::Sender<Bytes>,
|
||||
) where
|
||||
S: StreamExt<Item = Result<Bytes, reqwest::Error>> + Send + Unpin + 'static,
|
||||
P: StreamProcessor,
|
||||
{
|
||||
let mut is_first_chunk = true;
|
||||
let mut pipeline_processor = PipelineProcessor::default();
|
||||
let chain = output_chain.to_agent_filter_chain("output_filter");
|
||||
|
||||
while let Some(item) = byte_stream.next().await {
|
||||
let chunk = match item {
|
||||
Ok(chunk) => chunk,
|
||||
Err(err) => {
|
||||
let err_msg = format!("Error receiving chunk: {:?}", err);
|
||||
warn!(error = %err_msg, "stream error");
|
||||
inner_processor.on_error(&err_msg);
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
if is_first_chunk {
|
||||
inner_processor.on_first_bytes();
|
||||
is_first_chunk = false;
|
||||
}
|
||||
|
||||
let processed_chunk = match pipeline_processor
|
||||
.process_raw_filter_chain(
|
||||
&chunk,
|
||||
&chain,
|
||||
&output_chain.agents,
|
||||
&request_headers,
|
||||
&request_path,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(filtered) => filtered,
|
||||
Err(PipelineError::ClientError {
|
||||
agent,
|
||||
status,
|
||||
body,
|
||||
}) => {
|
||||
warn!(
|
||||
agent = %agent,
|
||||
status = %status,
|
||||
body = %body,
|
||||
"output filter client error, passing through original chunk"
|
||||
);
|
||||
chunk
|
||||
}
|
||||
Err(e) => {
|
||||
warn!(error = %e, "output filter error, passing through original chunk");
|
||||
chunk
|
||||
}
|
||||
};
|
||||
|
||||
match inner_processor.process_chunk(processed_chunk) {
|
||||
Ok(Some(final_chunk)) => {
|
||||
if tx.send(final_chunk).await.is_err() {
|
||||
warn!("receiver dropped");
|
||||
break;
|
||||
}
|
||||
}
|
||||
Ok(None) => continue,
|
||||
Err(err) => {
|
||||
warn!("processor error: {}", err);
|
||||
inner_processor.on_error(&err);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
inner_processor.on_complete();
|
||||
debug!("output filter streaming completed");
|
||||
}
|
||||
|
||||
/// Truncates a message to the specified maximum length, adding "..." if truncated.
|
||||
pub fn truncate_message(message: &str, max_length: usize) -> String {
|
||||
if message.chars().count() > max_length {
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ pub struct Agent {
|
|||
pub url: String,
|
||||
#[serde(rename = "type")]
|
||||
pub agent_type: Option<String>,
|
||||
pub streaming: Option<bool>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
|
|
@ -43,6 +44,29 @@ impl ResolvedFilterChain {
|
|||
self.filter_ids.is_empty()
|
||||
}
|
||||
|
||||
/// True when every filter in the chain is an HTTP filter with `streaming: true`.
|
||||
/// MCP filters and filters without the streaming flag use per-chunk mode.
|
||||
pub fn all_support_streaming(&self) -> bool {
|
||||
!self.filter_ids.is_empty()
|
||||
&& self.filter_ids.iter().all(|id| {
|
||||
self.agents
|
||||
.get(id)
|
||||
.map(|a| {
|
||||
a.streaming.unwrap_or(false)
|
||||
&& a.agent_type.as_deref().unwrap_or("mcp") != "mcp"
|
||||
})
|
||||
.unwrap_or(false)
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns references to the ordered agents for the streaming pipeline.
|
||||
pub fn streaming_agents(&self) -> Vec<&Agent> {
|
||||
self.filter_ids
|
||||
.iter()
|
||||
.filter_map(|id| self.agents.get(id))
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn to_agent_filter_chain(&self, id: &str) -> AgentFilterChain {
|
||||
AgentFilterChain {
|
||||
id: id.to_string(),
|
||||
|
|
@ -542,7 +566,7 @@ mod test {
|
|||
use pretty_assertions::assert_eq;
|
||||
use std::fs;
|
||||
|
||||
use super::{IntoModels, LlmProvider, LlmProviderType};
|
||||
use super::{Agent, IntoModels, LlmProvider, LlmProviderType, ResolvedFilterChain};
|
||||
use crate::api::open_ai::ToolType;
|
||||
|
||||
#[test]
|
||||
|
|
@ -663,4 +687,81 @@ mod test {
|
|||
assert!(!model_ids.contains(&"arch-router".to_string()));
|
||||
assert!(!model_ids.contains(&"plano-orchestrator".to_string()));
|
||||
}
|
||||
|
||||
fn make_agent(id: &str, agent_type: Option<&str>, streaming: Option<bool>) -> Agent {
|
||||
Agent {
|
||||
id: id.to_string(),
|
||||
url: format!("http://localhost:10501/{id}"),
|
||||
agent_type: agent_type.map(String::from),
|
||||
transport: None,
|
||||
tool: None,
|
||||
streaming,
|
||||
}
|
||||
}
|
||||
|
||||
/// A chain of HTTP filters that all set `streaming: true` supports streaming.
#[test]
fn test_all_support_streaming_all_http_streaming() {
    let first = make_agent("a", Some("http"), Some(true));
    let second = make_agent("b", Some("http"), Some(true));
    let chain = ResolvedFilterChain {
        filter_ids: vec!["a".into(), "b".into()],
        agents: [("a".into(), first), ("b".into(), second)].into(),
    };

    assert!(chain.all_support_streaming());
}
|
||||
|
||||
/// One filter without the streaming flag disables streaming for the whole chain.
#[test]
fn test_all_support_streaming_one_missing_flag() {
    let streaming_filter = make_agent("a", Some("http"), Some(true));
    let unflagged_filter = make_agent("b", Some("http"), None);
    let chain = ResolvedFilterChain {
        filter_ids: vec!["a".into(), "b".into()],
        agents: [
            ("a".into(), streaming_filter),
            ("b".into(), unflagged_filter),
        ]
        .into(),
    };

    assert!(!chain.all_support_streaming());
}
|
||||
|
||||
/// An MCP filter never counts as streaming-capable, even with `streaming: true`.
#[test]
fn test_all_support_streaming_mcp_filter() {
    let mcp_filter = make_agent("a", Some("mcp"), Some(true));
    let chain = ResolvedFilterChain {
        filter_ids: vec!["a".into()],
        agents: [("a".into(), mcp_filter)].into(),
    };

    assert!(!chain.all_support_streaming());
}
|
||||
|
||||
/// A filter without an explicit type is treated as MCP and so does not stream.
#[test]
fn test_all_support_streaming_default_type_is_mcp() {
    let untyped_filter = make_agent("a", None, Some(true));
    let chain = ResolvedFilterChain {
        filter_ids: vec!["a".into()],
        agents: [("a".into(), untyped_filter)].into(),
    };

    assert!(!chain.all_support_streaming());
}
|
||||
|
||||
/// An empty chain does not qualify for the streaming path.
#[test]
fn test_all_support_streaming_empty_chain() {
    let empty_chain = ResolvedFilterChain::default();

    assert!(!empty_chain.all_support_streaming());
}
|
||||
|
||||
/// `streaming_agents` follows the order of `filter_ids`, not the map's order.
#[test]
fn test_streaming_agents_ordered() {
    let agent_a = make_agent("a", Some("http"), Some(true));
    let agent_b = make_agent("b", Some("http"), Some(true));
    let chain = ResolvedFilterChain {
        filter_ids: vec!["b".into(), "a".into()],
        agents: [("a".into(), agent_a), ("b".into(), agent_b)].into(),
    };

    let ordered = chain.streaming_agents();

    assert_eq!(ordered.len(), 2);
    let ids: Vec<&str> = ordered.iter().map(|agent| agent.id.as_str()).collect();
    assert_eq!(ids, vec!["b", "a"]);
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue