mirror of
https://github.com/katanemo/plano.git
synced 2026-05-24 14:05:14 +02:00
pass raw bytes through input/output filter chains
This commit is contained in:
parent
80dfb41cad
commit
d26abbfb9c
11 changed files with 610 additions and 448 deletions
|
|
@ -14,10 +14,10 @@ use tokio_stream::StreamExt;
|
|||
use tracing::{debug, info, warn, Instrument};
|
||||
use tracing_opentelemetry::OpenTelemetrySpanExt;
|
||||
|
||||
use super::pipeline_processor::PipelineProcessor;
|
||||
use super::pipeline_processor::{PipelineError, PipelineProcessor};
|
||||
use crate::signals::{InteractionQuality, SignalAnalyzer, TextBasedSignalAnalyzer, FLAG_MARKER};
|
||||
use crate::tracing::{llm, set_service_name, signals as signal_constants};
|
||||
use hermesllm::apis::openai::{Message, MessageContent, Role};
|
||||
use hermesllm::apis::openai::Message;
|
||||
|
||||
/// Trait for processing streaming chunks
|
||||
/// Implementors can inject custom logic during streaming (e.g., hallucination detection, logging)
|
||||
|
|
@ -281,184 +281,9 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
/// Extract content text from an SSE chunk line (the JSON part after "data: ").
|
||||
/// Returns the content string and whether it was found.
|
||||
fn extract_sse_content(json_str: &str) -> Option<String> {
|
||||
let value: serde_json::Value = serde_json::from_str(json_str).ok()?;
|
||||
value
|
||||
.get("choices")?
|
||||
.get(0)?
|
||||
.get("delta")?
|
||||
.get("content")?
|
||||
.as_str()
|
||||
.map(|s| s.to_string())
|
||||
}
|
||||
|
||||
/// Replace content in an SSE JSON chunk with new content.
|
||||
fn replace_sse_content(json_str: &str, new_content: &str) -> Option<String> {
|
||||
let mut value: serde_json::Value = serde_json::from_str(json_str).ok()?;
|
||||
value
|
||||
.get_mut("choices")?
|
||||
.get_mut(0)?
|
||||
.get_mut("delta")?
|
||||
.as_object_mut()?
|
||||
.insert(
|
||||
"content".to_string(),
|
||||
serde_json::Value::String(new_content.to_string()),
|
||||
);
|
||||
serde_json::to_string(&value).ok()
|
||||
}
|
||||
|
||||
/// Process an SSE chunk through the output filter chain.
|
||||
/// Parses each `data: ` line, extracts content, sends through filters, and reconstructs.
|
||||
async fn filter_sse_chunk(
|
||||
chunk_str: &str,
|
||||
pipeline_processor: &mut PipelineProcessor,
|
||||
filter_chain: &AgentFilterChain,
|
||||
filter_agents: &HashMap<String, Agent>,
|
||||
request_headers: &HeaderMap,
|
||||
) -> String {
|
||||
let mut result = String::new();
|
||||
for line in chunk_str.split('\n') {
|
||||
if let Some(json_str) = line.strip_prefix("data: ") {
|
||||
if json_str.trim() == "[DONE]" {
|
||||
result.push_str(line);
|
||||
result.push('\n');
|
||||
continue;
|
||||
}
|
||||
if let Some(content) = extract_sse_content(json_str) {
|
||||
if content.is_empty() {
|
||||
result.push_str(line);
|
||||
result.push('\n');
|
||||
continue;
|
||||
}
|
||||
// Send content through output filter chain
|
||||
let messages = vec![Message {
|
||||
role: Role::Assistant,
|
||||
content: Some(MessageContent::Text(content)),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
}];
|
||||
match pipeline_processor
|
||||
.process_filter_chain(&messages, filter_chain, filter_agents, request_headers)
|
||||
.await
|
||||
{
|
||||
Ok(filtered_messages) => {
|
||||
if let Some(msg) = filtered_messages.first() {
|
||||
let filtered_content = match &msg.content {
|
||||
Some(MessageContent::Text(t)) => Some(t.clone()),
|
||||
_ => None,
|
||||
};
|
||||
if let Some(filtered_content) = filtered_content {
|
||||
if let Some(new_json) =
|
||||
replace_sse_content(json_str, &filtered_content)
|
||||
{
|
||||
result.push_str("data: ");
|
||||
result.push_str(&new_json);
|
||||
result.push('\n');
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
// Fallback: pass through original
|
||||
result.push_str(line);
|
||||
result.push('\n');
|
||||
}
|
||||
Err(e) => {
|
||||
warn!(error = %e, "output filter chain error, passing through original chunk");
|
||||
result.push_str(line);
|
||||
result.push('\n');
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// No content in this SSE line, pass through
|
||||
result.push_str(line);
|
||||
result.push('\n');
|
||||
}
|
||||
} else {
|
||||
result.push_str(line);
|
||||
result.push('\n');
|
||||
}
|
||||
}
|
||||
// Remove trailing extra newline if the original didn't end with one
|
||||
if !chunk_str.ends_with('\n') && result.ends_with('\n') {
|
||||
result.pop();
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
/// Process a non-streaming JSON response through the output filter chain.
|
||||
/// Extracts assistant message content, filters it, and reconstructs the response.
|
||||
pub async fn filter_non_streaming_response(
|
||||
response_bytes: &[u8],
|
||||
pipeline_processor: &mut PipelineProcessor,
|
||||
filter_chain: &AgentFilterChain,
|
||||
filter_agents: &HashMap<String, Agent>,
|
||||
request_headers: &HeaderMap,
|
||||
) -> Bytes {
|
||||
let response_str = match std::str::from_utf8(response_bytes) {
|
||||
Ok(s) => s,
|
||||
Err(_) => return Bytes::from(response_bytes.to_vec()),
|
||||
};
|
||||
|
||||
let mut value: serde_json::Value = match serde_json::from_str(response_str) {
|
||||
Ok(v) => v,
|
||||
Err(_) => return Bytes::from(response_bytes.to_vec()),
|
||||
};
|
||||
|
||||
// Extract content from choices[0].message.content
|
||||
let content = value
|
||||
.get("choices")
|
||||
.and_then(|c| c.get(0))
|
||||
.and_then(|c| c.get("message"))
|
||||
.and_then(|m| m.get("content"))
|
||||
.and_then(|c| c.as_str())
|
||||
.map(|s| s.to_string());
|
||||
|
||||
if let Some(content) = content {
|
||||
let messages = vec![Message {
|
||||
role: Role::Assistant,
|
||||
content: Some(MessageContent::Text(content)),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
}];
|
||||
match pipeline_processor
|
||||
.process_filter_chain(&messages, filter_chain, filter_agents, request_headers)
|
||||
.await
|
||||
{
|
||||
Ok(filtered_messages) => {
|
||||
if let Some(msg) = filtered_messages.first() {
|
||||
let filtered_content = match &msg.content {
|
||||
Some(MessageContent::Text(t)) => Some(t.clone()),
|
||||
_ => None,
|
||||
};
|
||||
if let Some(filtered_content) = filtered_content {
|
||||
if let Some(choices) = value.get_mut("choices") {
|
||||
if let Some(choice) = choices.get_mut(0) {
|
||||
if let Some(message) = choice.get_mut("message") {
|
||||
message.as_object_mut().unwrap().insert(
|
||||
"content".to_string(),
|
||||
serde_json::Value::String(filtered_content),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!(error = %e, "output filter chain error on non-streaming response");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Bytes::from(serde_json::to_string(&value).unwrap_or_else(|_| response_str.to_string()))
|
||||
}
|
||||
|
||||
/// Creates a streaming response that processes each chunk through output filters.
|
||||
/// The output filter is called asynchronously for each SSE chunk's content.
|
||||
/// Creates a streaming response that processes each raw chunk through output filters.
|
||||
/// Filters receive the raw LLM response bytes and return (possibly modified) bytes.
|
||||
/// On filter error mid-stream the original chunk is passed through (headers already sent).
|
||||
pub fn create_streaming_response_with_output_filter<S, P>(
|
||||
mut byte_stream: S,
|
||||
mut inner_processor: P,
|
||||
|
|
@ -466,6 +291,7 @@ pub fn create_streaming_response_with_output_filter<S, P>(
|
|||
output_filters: Vec<String>,
|
||||
output_filter_agents: HashMap<String, Agent>,
|
||||
request_headers: HeaderMap,
|
||||
upstream_path: String,
|
||||
) -> StreamingResponse
|
||||
where
|
||||
S: StreamExt<Item = Result<Bytes, reqwest::Error>> + Send + Unpin + 'static,
|
||||
|
|
@ -501,32 +327,35 @@ where
|
|||
is_first_chunk = false;
|
||||
}
|
||||
|
||||
// Try to process through output filter chain
|
||||
let processed_chunk = if let Ok(chunk_str) = std::str::from_utf8(&chunk) {
|
||||
if chunk_str.contains("data: ") {
|
||||
let filtered = filter_sse_chunk(
|
||||
chunk_str,
|
||||
&mut pipeline_processor,
|
||||
&temp_filter_chain,
|
||||
&output_filter_agents,
|
||||
&request_headers,
|
||||
)
|
||||
.await;
|
||||
Bytes::from(filtered)
|
||||
} else {
|
||||
// Non-SSE chunk (could be non-streaming JSON response)
|
||||
let filtered = filter_non_streaming_response(
|
||||
&chunk,
|
||||
&mut pipeline_processor,
|
||||
&temp_filter_chain,
|
||||
&output_filter_agents,
|
||||
&request_headers,
|
||||
)
|
||||
.await;
|
||||
filtered
|
||||
// Pass raw chunk bytes through the output filter chain
|
||||
let processed_chunk = match pipeline_processor
|
||||
.process_raw_filter_chain(
|
||||
&chunk,
|
||||
&temp_filter_chain,
|
||||
&output_filter_agents,
|
||||
&request_headers,
|
||||
&upstream_path,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(filtered) => filtered,
|
||||
Err(PipelineError::ClientError {
|
||||
agent,
|
||||
status,
|
||||
body,
|
||||
}) => {
|
||||
warn!(
|
||||
agent = %agent,
|
||||
status = %status,
|
||||
body = %body,
|
||||
"output filter client error, passing through original chunk"
|
||||
);
|
||||
chunk
|
||||
}
|
||||
Err(e) => {
|
||||
warn!(error = %e, "output filter error, passing through original chunk");
|
||||
chunk
|
||||
}
|
||||
} else {
|
||||
chunk
|
||||
};
|
||||
|
||||
// Pass through inner processor for metrics/observability
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue