use bytes::Bytes; use common::configuration::{Agent, AgentFilterChain}; use http_body_util::combinators::BoxBody; use http_body_util::StreamBody; use hyper::body::Frame; use hyper::header::HeaderMap; use opentelemetry::trace::TraceContextExt; use opentelemetry::KeyValue; use std::collections::HashMap; use std::time::Instant; use tokio::sync::mpsc; use tokio_stream::wrappers::ReceiverStream; use tokio_stream::StreamExt; use tracing::{debug, info, warn, Instrument}; use tracing_opentelemetry::OpenTelemetrySpanExt; use super::pipeline_processor::PipelineProcessor; use crate::signals::{InteractionQuality, SignalAnalyzer, TextBasedSignalAnalyzer, FLAG_MARKER}; use crate::tracing::{llm, set_service_name, signals as signal_constants}; use hermesllm::apis::openai::{Message, MessageContent, Role}; /// Trait for processing streaming chunks /// Implementors can inject custom logic during streaming (e.g., hallucination detection, logging) pub trait StreamProcessor: Send + 'static { /// Process an incoming chunk of bytes fn process_chunk(&mut self, chunk: Bytes) -> Result, String>; /// Called when the first bytes are received (for time-to-first-token tracking) fn on_first_bytes(&mut self) {} /// Called when streaming completes successfully fn on_complete(&mut self) {} /// Called when streaming encounters an error fn on_error(&mut self, _error: &str) {} } /// A processor that tracks streaming metrics pub struct ObservableStreamProcessor { service_name: String, operation_name: String, total_bytes: usize, chunk_count: usize, start_time: Instant, time_to_first_token: Option, messages: Option>, } impl ObservableStreamProcessor { /// Create a new passthrough processor /// /// # Arguments /// * `service_name` - The service name for this span (e.g., "plano(llm)") /// This will be set as the `service.name.override` attribute on the current span, /// allowing the ServiceNameOverrideExporter to route spans to different services. /// * `operation_name` - The current span operation name (e.g., "POST /v1/chat/completions gpt-4") /// Used to append the flag marker when concerning signals are detected. /// * `start_time` - When the request started (for duration calculation) /// * `messages` - Optional conversation messages for signal analysis pub fn new( service_name: impl Into, operation_name: impl Into, start_time: Instant, messages: Option>, ) -> Self { let service_name = service_name.into(); // Set the service name override on the current span for OpenTelemetry export // This allows the ServiceNameOverrideExporter to route this span to the correct service set_service_name(&service_name); Self { service_name, operation_name: operation_name.into(), total_bytes: 0, chunk_count: 0, start_time, time_to_first_token: None, messages, } } } impl StreamProcessor for ObservableStreamProcessor { fn process_chunk(&mut self, chunk: Bytes) -> Result, String> { self.total_bytes += chunk.len(); self.chunk_count += 1; Ok(Some(chunk)) } fn on_first_bytes(&mut self) { // Record time to first token (only for streaming) if self.time_to_first_token.is_none() { self.time_to_first_token = Some(self.start_time.elapsed().as_millis()); } } fn on_complete(&mut self) { // Record time-to-first-token as an OTel span attribute + event (streaming only) if let Some(ttft) = self.time_to_first_token { let span = tracing::Span::current(); let otel_context = span.context(); let otel_span = otel_context.span(); otel_span.set_attribute(KeyValue::new(llm::TIME_TO_FIRST_TOKEN_MS, ttft as i64)); otel_span.add_event( llm::TIME_TO_FIRST_TOKEN_MS, vec![KeyValue::new(llm::TIME_TO_FIRST_TOKEN_MS, ttft as i64)], ); } // Analyze signals if messages are available and record as span attributes if let Some(ref messages) = self.messages { let analyzer: Box = Box::new(TextBasedSignalAnalyzer::new()); let report = analyzer.analyze(messages); // Get the current OTel span to set signal attributes let span = tracing::Span::current(); let otel_context = span.context(); let otel_span = otel_context.span(); // Add overall quality otel_span.set_attribute(KeyValue::new( signal_constants::QUALITY, format!("{:?}", report.overall_quality), )); // Add repair/follow-up metrics if concerning if report.follow_up.is_concerning || report.follow_up.repair_count > 0 { otel_span.set_attribute(KeyValue::new( signal_constants::REPAIR_COUNT, report.follow_up.repair_count as i64, )); otel_span.set_attribute(KeyValue::new( signal_constants::REPAIR_RATIO, format!("{:.3}", report.follow_up.repair_ratio), )); } // Add frustration metrics if report.frustration.has_frustration { otel_span.set_attribute(KeyValue::new( signal_constants::FRUSTRATION_COUNT, report.frustration.frustration_count as i64, )); otel_span.set_attribute(KeyValue::new( signal_constants::FRUSTRATION_SEVERITY, report.frustration.severity as i64, )); } // Add repetition metrics if report.repetition.has_looping { otel_span.set_attribute(KeyValue::new( signal_constants::REPETITION_COUNT, report.repetition.repetition_count as i64, )); } // Add escalation metrics if report.escalation.escalation_requested { otel_span .set_attribute(KeyValue::new(signal_constants::ESCALATION_REQUESTED, true)); } // Add positive feedback metrics if report.positive_feedback.has_positive_feedback { otel_span.set_attribute(KeyValue::new( signal_constants::POSITIVE_FEEDBACK_COUNT, report.positive_feedback.positive_count as i64, )); } // Flag the span name if any concerning signal is detected let should_flag = report.frustration.has_frustration || report.repetition.has_looping || report.escalation.escalation_requested || matches!( report.overall_quality, InteractionQuality::Poor | InteractionQuality::Severe ); if should_flag { otel_span.update_name(format!("{} {}", self.operation_name, FLAG_MARKER)); } } info!( service = %self.service_name, total_bytes = self.total_bytes, chunk_count = self.chunk_count, duration_ms = self.start_time.elapsed().as_millis(), time_to_first_token_ms = ?self.time_to_first_token, "streaming completed" ); } fn on_error(&mut self, error_msg: &str) { warn!( service = %self.service_name, error = error_msg, duration_ms = self.start_time.elapsed().as_millis(), "stream error" ); } } /// Result of creating a streaming response pub struct StreamingResponse { pub body: BoxBody, pub processor_handle: tokio::task::JoinHandle<()>, } pub fn create_streaming_response( mut byte_stream: S, mut processor: P, buffer_size: usize, ) -> StreamingResponse where S: StreamExt> + Send + Unpin + 'static, P: StreamProcessor, { let (tx, rx) = mpsc::channel::(buffer_size); // Capture the current span so the spawned task inherits the request context let current_span = tracing::Span::current(); // Spawn a task to process and forward chunks let processor_handle = tokio::spawn( async move { let mut is_first_chunk = true; while let Some(item) = byte_stream.next().await { let chunk = match item { Ok(chunk) => chunk, Err(err) => { let err_msg = format!("Error receiving chunk: {:?}", err); warn!(error = %err_msg, "stream error"); processor.on_error(&err_msg); break; } }; // Call on_first_bytes for the first chunk if is_first_chunk { processor.on_first_bytes(); is_first_chunk = false; } // Process the chunk match processor.process_chunk(chunk) { Ok(Some(processed_chunk)) => { if tx.send(processed_chunk).await.is_err() { warn!("receiver dropped"); break; } } Ok(None) => { // Skip this chunk continue; } Err(err) => { warn!("processor error: {}", err); processor.on_error(&err); break; } } } processor.on_complete(); } .instrument(current_span), ); // Convert channel receiver to HTTP stream let stream = ReceiverStream::new(rx).map(|chunk| Ok::<_, hyper::Error>(Frame::data(chunk))); let stream_body = BoxBody::new(StreamBody::new(stream)); StreamingResponse { body: stream_body, processor_handle, } } /// Extract content text from an SSE chunk line (the JSON part after "data: "). /// Returns the content string and whether it was found. fn extract_sse_content(json_str: &str) -> Option { let value: serde_json::Value = serde_json::from_str(json_str).ok()?; value .get("choices")? .get(0)? .get("delta")? .get("content")? .as_str() .map(|s| s.to_string()) } /// Replace content in an SSE JSON chunk with new content. fn replace_sse_content(json_str: &str, new_content: &str) -> Option { let mut value: serde_json::Value = serde_json::from_str(json_str).ok()?; value .get_mut("choices")? .get_mut(0)? .get_mut("delta")? .as_object_mut()? .insert( "content".to_string(), serde_json::Value::String(new_content.to_string()), ); serde_json::to_string(&value).ok() } /// Process an SSE chunk through the output filter chain. /// Parses each `data: ` line, extracts content, sends through filters, and reconstructs. async fn filter_sse_chunk( chunk_str: &str, pipeline_processor: &mut PipelineProcessor, filter_chain: &AgentFilterChain, filter_agents: &HashMap, request_headers: &HeaderMap, ) -> String { let mut result = String::new(); for line in chunk_str.split('\n') { if let Some(json_str) = line.strip_prefix("data: ") { if json_str.trim() == "[DONE]" { result.push_str(line); result.push('\n'); continue; } if let Some(content) = extract_sse_content(json_str) { if content.is_empty() { result.push_str(line); result.push('\n'); continue; } // Send content through output filter chain let messages = vec![Message { role: Role::Assistant, content: Some(MessageContent::Text(content)), name: None, tool_calls: None, tool_call_id: None, }]; match pipeline_processor .process_filter_chain(&messages, filter_chain, filter_agents, request_headers) .await { Ok(filtered_messages) => { if let Some(msg) = filtered_messages.first() { let filtered_content = match &msg.content { Some(MessageContent::Text(t)) => Some(t.clone()), _ => None, }; if let Some(filtered_content) = filtered_content { if let Some(new_json) = replace_sse_content(json_str, &filtered_content) { result.push_str("data: "); result.push_str(&new_json); result.push('\n'); continue; } } } // Fallback: pass through original result.push_str(line); result.push('\n'); } Err(e) => { warn!(error = %e, "output filter chain error, passing through original chunk"); result.push_str(line); result.push('\n'); } } } else { // No content in this SSE line, pass through result.push_str(line); result.push('\n'); } } else { result.push_str(line); result.push('\n'); } } // Remove trailing extra newline if the original didn't end with one if !chunk_str.ends_with('\n') && result.ends_with('\n') { result.pop(); } result } /// Process a non-streaming JSON response through the output filter chain. /// Extracts assistant message content, filters it, and reconstructs the response. pub async fn filter_non_streaming_response( response_bytes: &[u8], pipeline_processor: &mut PipelineProcessor, filter_chain: &AgentFilterChain, filter_agents: &HashMap, request_headers: &HeaderMap, ) -> Bytes { let response_str = match std::str::from_utf8(response_bytes) { Ok(s) => s, Err(_) => return Bytes::from(response_bytes.to_vec()), }; let mut value: serde_json::Value = match serde_json::from_str(response_str) { Ok(v) => v, Err(_) => return Bytes::from(response_bytes.to_vec()), }; // Extract content from choices[0].message.content let content = value .get("choices") .and_then(|c| c.get(0)) .and_then(|c| c.get("message")) .and_then(|m| m.get("content")) .and_then(|c| c.as_str()) .map(|s| s.to_string()); if let Some(content) = content { let messages = vec![Message { role: Role::Assistant, content: Some(MessageContent::Text(content)), name: None, tool_calls: None, tool_call_id: None, }]; match pipeline_processor .process_filter_chain(&messages, filter_chain, filter_agents, request_headers) .await { Ok(filtered_messages) => { if let Some(msg) = filtered_messages.first() { let filtered_content = match &msg.content { Some(MessageContent::Text(t)) => Some(t.clone()), _ => None, }; if let Some(filtered_content) = filtered_content { if let Some(choices) = value.get_mut("choices") { if let Some(choice) = choices.get_mut(0) { if let Some(message) = choice.get_mut("message") { message.as_object_mut().unwrap().insert( "content".to_string(), serde_json::Value::String(filtered_content), ); } } } } } } Err(e) => { warn!(error = %e, "output filter chain error on non-streaming response"); } } } Bytes::from(serde_json::to_string(&value).unwrap_or_else(|_| response_str.to_string())) } /// Creates a streaming response that processes each chunk through output filters. /// The output filter is called asynchronously for each SSE chunk's content. pub fn create_streaming_response_with_output_filter( mut byte_stream: S, mut inner_processor: P, buffer_size: usize, output_filters: Vec, output_filter_agents: HashMap, request_headers: HeaderMap, ) -> StreamingResponse where S: StreamExt> + Send + Unpin + 'static, P: StreamProcessor, { let (tx, rx) = mpsc::channel::(buffer_size); let current_span = tracing::Span::current(); let processor_handle = tokio::spawn( async move { let mut is_first_chunk = true; let mut pipeline_processor = PipelineProcessor::default(); let temp_filter_chain = AgentFilterChain { id: "output_filter".to_string(), default: None, description: None, filter_chain: Some(output_filters), }; while let Some(item) = byte_stream.next().await { let chunk = match item { Ok(chunk) => chunk, Err(err) => { let err_msg = format!("Error receiving chunk: {:?}", err); warn!(error = %err_msg, "stream error"); inner_processor.on_error(&err_msg); break; } }; if is_first_chunk { inner_processor.on_first_bytes(); is_first_chunk = false; } // Try to process through output filter chain let processed_chunk = if let Ok(chunk_str) = std::str::from_utf8(&chunk) { if chunk_str.contains("data: ") { let filtered = filter_sse_chunk( chunk_str, &mut pipeline_processor, &temp_filter_chain, &output_filter_agents, &request_headers, ) .await; Bytes::from(filtered) } else { // Non-SSE chunk (could be non-streaming JSON response) let filtered = filter_non_streaming_response( &chunk, &mut pipeline_processor, &temp_filter_chain, &output_filter_agents, &request_headers, ) .await; filtered } } else { chunk }; // Pass through inner processor for metrics/observability match inner_processor.process_chunk(processed_chunk) { Ok(Some(final_chunk)) => { if tx.send(final_chunk).await.is_err() { warn!("receiver dropped"); break; } } Ok(None) => continue, Err(err) => { warn!("processor error: {}", err); inner_processor.on_error(&err); break; } } } inner_processor.on_complete(); debug!("output filter streaming completed"); } .instrument(current_span), ); let stream = ReceiverStream::new(rx).map(|chunk| Ok::<_, hyper::Error>(Frame::data(chunk))); let stream_body = BoxBody::new(StreamBody::new(stream)); StreamingResponse { body: stream_body, processor_handle, } } /// Truncates a message to the specified maximum length, adding "..." if truncated. pub fn truncate_message(message: &str, max_length: usize) -> String { if message.chars().count() > max_length { let truncated: String = message.chars().take(max_length).collect(); format!("{}...", truncated) } else { message.to_string() } }