diff --git a/crates/brightstaff/src/handlers/integration_tests.rs b/crates/brightstaff/src/handlers/integration_tests.rs index e82bb8b2..bfc8f00e 100644 --- a/crates/brightstaff/src/handlers/integration_tests.rs +++ b/crates/brightstaff/src/handlers/integration_tests.rs @@ -116,17 +116,20 @@ mod tests { }; let headers = HeaderMap::new(); + let request_bytes = serde_json::to_vec(&request).expect("failed to serialize request"); let result = pipeline_processor - .process_filter_chain(&request.messages, &test_pipeline, &agent_map, &headers) + .process_raw_filter_chain(&request_bytes, &test_pipeline, &agent_map, &headers, "/v1/chat/completions") .await; println!("Pipeline processing result: {:?}", result); assert!(result.is_ok()); - let processed_messages = result.unwrap(); - // With empty filter chain, should return the original messages unchanged - assert_eq!(processed_messages.len(), 1); - if let Some(MessageContent::Text(content)) = &processed_messages[0].content { + let processed_bytes = result.unwrap(); + // With empty filter chain, should return the original bytes unchanged + let processed_request: ChatCompletionsRequest = + serde_json::from_slice(&processed_bytes).expect("failed to deserialize response"); + assert_eq!(processed_request.messages.len(), 1); + if let Some(MessageContent::Text(content)) = &processed_request.messages[0].content { assert_eq!(content, "Hello world!"); } else { panic!("Expected text content"); diff --git a/crates/brightstaff/src/handlers/pipeline_processor.rs b/crates/brightstaff/src/handlers/pipeline_processor.rs index c01d7bdc..4cb8531f 100644 --- a/crates/brightstaff/src/handlers/pipeline_processor.rs +++ b/crates/brightstaff/src/handlers/pipeline_processor.rs @@ -36,8 +36,6 @@ pub enum PipelineError { NoResultInResponse(String), #[error("No structured content in response from agent '{0}'")] NoStructuredContentInResponse(String), - #[error("No messages in response from agent '{0}'")] - NoMessagesInResponse(String), #[error("Client 
error from agent '{agent}' (HTTP {status}): {body}")] ClientError { agent: String, @@ -80,68 +78,6 @@ impl PipelineProcessor { } } - // /// Process the filter chain of agents (all except the terminal agent) - // #[instrument( - // skip(self, chat_history, agent_filter_chain, agent_map, request_headers), - // fields( - // filter_count = agent_filter_chain.input_filters.as_ref().map(|fc| fc.len()).unwrap_or(0), - // message_count = chat_history.len() - // ) - // )] - #[allow(clippy::too_many_arguments)] - pub async fn process_filter_chain( - &mut self, - chat_history: &[Message], - agent_filter_chain: &AgentFilterChain, - agent_map: &HashMap, - request_headers: &HeaderMap, - ) -> Result, PipelineError> { - let mut chat_history_updated = chat_history.to_vec(); - - // If filter_chain is None or empty, proceed without filtering - let filter_chain = match agent_filter_chain.input_filters.as_ref() { - Some(fc) if !fc.is_empty() => fc, - _ => return Ok(chat_history_updated), - }; - - for agent_name in filter_chain { - debug!(agent = %agent_name, "processing filter agent"); - - let agent = agent_map - .get(agent_name) - .ok_or_else(|| PipelineError::AgentNotFound(agent_name.clone()))?; - - let tool_name = agent.tool.as_deref().unwrap_or(&agent.id); - - info!( - agent = %agent_name, - tool = %tool_name, - url = %agent.url, - agent_type = %agent.agent_type.as_deref().unwrap_or("mcp"), - conversation_len = chat_history.len(), - "executing filter" - ); - - if agent.agent_type.as_deref().unwrap_or("mcp") == "mcp" { - chat_history_updated = self - .execute_mcp_filter(&chat_history_updated, agent, request_headers) - .await?; - } else { - chat_history_updated = self - .execute_http_filter(&chat_history_updated, agent, request_headers) - .await?; - } - - info!( - agent = %agent_name, - updated_len = chat_history_updated.len(), - "filter completed" - ); - } - - Ok(chat_history_updated) - } - /// Build common MCP headers for requests fn build_mcp_headers( &self, @@ -262,27 +198,6 @@ 
impl PipelineProcessor { Ok(response) } - /// Build a tools/call JSON-RPC request - fn build_tool_call_request( - &self, - tool_name: &str, - messages: &[Message], - ) -> Result { - let mut arguments = HashMap::new(); - arguments.insert("messages".to_string(), serde_json::to_value(messages)?); - - let mut params = HashMap::new(); - params.insert("name".to_string(), serde_json::to_value(tool_name)?); - params.insert("arguments".to_string(), serde_json::to_value(arguments)?); - - Ok(JsonRpcRequest { - jsonrpc: JSON_RPC_VERSION.to_string(), - id: JsonRpcId::String(Uuid::new_v4().to_string()), - method: TOOL_CALL_METHOD.to_string(), - params: Some(params), - }) - } - /// Build a tools/call JSON-RPC request with a full body dict and path hint. /// Used by execute_mcp_filter_raw so MCP tools receive the same contract as HTTP filters. fn build_tool_call_request_with_body( @@ -307,130 +222,7 @@ impl PipelineProcessor { }) } - /// Send request to a specific agent and return the response content - #[instrument( - skip(self, messages, agent, request_headers), - fields( - agent_id = %agent.id, - filter_name = %agent.id, - message_count = messages.len() - ) - )] - async fn execute_mcp_filter( - &mut self, - messages: &[Message], - agent: &Agent, - request_headers: &HeaderMap, - ) -> Result, PipelineError> { - // Set service name for this filter span - set_service_name(operation_component::AGENT_FILTER); - - // Update current span name to include filter name - use opentelemetry::trace::get_active_span; - get_active_span(|span| { - span.update_name(format!("execute_mcp_filter ({})", agent.id)); - }); - - // Get or create MCP session - let mcp_session_id = if let Some(session_id) = self.agent_id_session_map.get(&agent.id) { - session_id.clone() - } else { - let session_id = self.get_new_session_id(&agent.id, request_headers).await; - self.agent_id_session_map - .insert(agent.id.clone(), session_id.clone()); - session_id - }; - - info!( - "Using MCP session ID {} for agent {}", - 
mcp_session_id, agent.id - ); - - // Build JSON-RPC request - let tool_name = agent.tool.as_deref().unwrap_or(&agent.id); - let json_rpc_request = self.build_tool_call_request(tool_name, messages)?; - - // Build headers - let agent_headers = - self.build_mcp_headers(request_headers, &agent.id, Some(&mcp_session_id))?; - - let response = self - .send_mcp_request(&json_rpc_request, &agent_headers, &agent.id) - .await?; - let http_status = response.status(); - let response_bytes = response.bytes().await?; - - // Handle HTTP errors - if !http_status.is_success() { - let error_body = String::from_utf8_lossy(&response_bytes).to_string(); - return Err(if http_status.is_client_error() { - PipelineError::ClientError { - agent: agent.id.clone(), - status: http_status.as_u16(), - body: error_body, - } - } else { - PipelineError::ServerError { - agent: agent.id.clone(), - status: http_status.as_u16(), - body: error_body, - } - }); - } - - info!( - "Response from agent {}: {}", - agent.id, - String::from_utf8_lossy(&response_bytes) - ); - - // Parse SSE response - let data_chunk = self.parse_sse_response(&response_bytes, &agent.id)?; - let response: JsonRpcResponse = serde_json::from_str(&data_chunk)?; - let response_result = response - .result - .ok_or_else(|| PipelineError::NoResultInResponse(agent.id.clone()))?; - - // Check if error field is set in response result - if response_result - .get("isError") - .and_then(|v| v.as_bool()) - .unwrap_or(false) - { - let error_message = response_result - .get("content") - .and_then(|v| v.as_array()) - .and_then(|arr| arr.first()) - .and_then(|v| v.get("text")) - .and_then(|v| v.as_str()) - .unwrap_or("unknown_error") - .to_string(); - - return Err(PipelineError::ClientError { - agent: agent.id.clone(), - status: hyper::StatusCode::BAD_REQUEST.as_u16(), - body: error_message, - }); - } - - // Extract structured content and parse messages - let response_json = response_result - .get("structuredContent") - .ok_or_else(|| 
PipelineError::NoStructuredContentInResponse(agent.id.clone()))?; - - let messages: Vec = response_json - .get("result") - .and_then(|v| v.as_array()) - .ok_or_else(|| PipelineError::NoMessagesInResponse(agent.id.clone()))? - .iter() - .map(|msg_value| serde_json::from_value(msg_value.clone())) - .collect::, _>>() - .map_err(PipelineError::ParseError)?; - - Ok(messages) - } - - /// Like execute_mcp_filter but passes the full raw body dict + path hint as MCP tool arguments. + /// Execute an MCP filter, passing the full raw body dict + path hint as MCP tool arguments. /// The MCP tool receives (body: dict, path: str) and returns the modified body dict. async fn execute_mcp_filter_raw( &mut self, @@ -519,11 +311,24 @@ impl PipelineProcessor { }); } - let result = response_result + // FastMCP puts structured Pydantic return values in structuredContent.result, + // but plain dicts land in content[0].text as a JSON string. Try both. + let result = if let Some(structured) = response_result .get("structuredContent") .and_then(|v| v.get("result")) .cloned() - .ok_or_else(|| PipelineError::NoStructuredContentInResponse(agent.id.clone()))?; + { + structured + } else { + let text = response_result + .get("content") + .and_then(|v| v.as_array()) + .and_then(|arr| arr.first()) + .and_then(|v| v.get("text")) + .and_then(|v| v.as_str()) + .ok_or_else(|| PipelineError::NoStructuredContentInResponse(agent.id.clone()))?; + serde_json::from_str(text).map_err(PipelineError::ParseError)?
+ }; Ok(Bytes::from( serde_json::to_vec(&result).map_err(PipelineError::ParseError)?, @@ -624,112 +429,6 @@ impl PipelineProcessor { session_id } - /// Execute a HTTP-based filter agent - #[instrument( - skip(self, messages, agent, request_headers), - fields( - agent_id = %agent.id, - agent_url = %agent.url, - filter_name = %agent.id, - message_count = messages.len() - ) - )] - async fn execute_http_filter( - &mut self, - messages: &[Message], - agent: &Agent, - request_headers: &HeaderMap, - ) -> Result, PipelineError> { - // Set service name for this filter span - set_service_name(operation_component::AGENT_FILTER); - - // Update current span name to include filter name - use opentelemetry::trace::get_active_span; - get_active_span(|span| { - span.update_name(format!("execute_http_filter ({})", agent.id)); - }); - - // Build headers - let mut agent_headers = request_headers.clone(); - agent_headers.remove(hyper::header::CONTENT_LENGTH); - - // Inject OpenTelemetry trace context automatically - agent_headers.remove(TRACE_PARENT_HEADER); - global::get_text_map_propagator(|propagator| { - let cx = - tracing_opentelemetry::OpenTelemetrySpanExt::context(&tracing::Span::current()); - propagator.inject_context(&cx, &mut HeaderInjector(&mut agent_headers)); - }); - - agent_headers.insert( - ARCH_UPSTREAM_HOST_HEADER, - hyper::header::HeaderValue::from_str(&agent.id) - .map_err(|_| PipelineError::AgentNotFound(agent.id.clone()))?, - ); - - agent_headers.insert( - ENVOY_RETRY_HEADER, - hyper::header::HeaderValue::from_str("3").unwrap(), - ); - - agent_headers.insert( - "Accept", - hyper::header::HeaderValue::from_static("application/json"), - ); - - agent_headers.insert( - "Content-Type", - hyper::header::HeaderValue::from_static("application/json"), - ); - - debug!( - "Sending HTTP request to agent {} at URL: {}", - agent.id, agent.url - ); - - // Send messages array directly as request body - let response = self - .client - .post(&agent.url) - .headers(agent_headers) - 
.json(&messages) - .send() - .await?; - - let http_status = response.status(); - let response_bytes = response.bytes().await?; - - // Handle HTTP errors - if !http_status.is_success() { - let error_body = String::from_utf8_lossy(&response_bytes).to_string(); - return Err(if http_status.is_client_error() { - PipelineError::ClientError { - agent: agent.id.clone(), - status: http_status.as_u16(), - body: error_body, - } - } else { - PipelineError::ServerError { - agent: agent.id.clone(), - status: http_status.as_u16(), - body: error_body, - } - }); - } - - debug!( - "Response from HTTP agent {}: {}", - agent.id, - String::from_utf8_lossy(&response_bytes) - ); - - // Parse response - expecting array of messages directly - let messages: Vec = - serde_json::from_slice(&response_bytes).map_err(PipelineError::ParseError)?; - - Ok(messages) - } - /// Execute a raw bytes filter — POST bytes to agent.url, receive bytes back. /// Used for input and output filters where the full raw request/response is passed through. /// No MCP protocol wrapping; agent_type is ignored. 
@@ -925,20 +624,9 @@ impl PipelineProcessor { #[cfg(test)] mod tests { use super::*; - use hermesllm::apis::openai::{Message, MessageContent, Role}; use mockito::Server; use std::collections::HashMap; - fn create_test_message(role: Role, content: &str) -> Message { - Message { - role, - content: Some(MessageContent::Text(content.to_string())), - name: None, - tool_calls: None, - tool_call_id: None, - } - } - fn create_test_pipeline(agents: Vec<&str>) -> AgentFilterChain { AgentFilterChain { id: "test-agent".to_string(), @@ -954,12 +642,19 @@ mod tests { let agent_map = HashMap::new(); let request_headers = HeaderMap::new(); - let messages = vec![create_test_message(Role::User, "Hello")]; + let body = serde_json::json!({"messages": [{"role": "user", "content": "Hello"}]}); + let raw_bytes = serde_json::to_vec(&body).unwrap(); let pipeline = create_test_pipeline(vec!["nonexistent-agent", "terminal-agent"]); let result = processor - .process_filter_chain(&messages, &pipeline, &agent_map, &request_headers) + .process_raw_filter_chain( + &raw_bytes, + &pipeline, + &agent_map, + &request_headers, + "/v1/chat/completions", + ) .await; assert!(result.is_err()); @@ -989,11 +684,12 @@ mod tests { agent_type: None, }; - let messages = vec![create_test_message(Role::User, "Hello")]; + let body = serde_json::json!({"messages": [{"role": "user", "content": "Hello"}]}); + let raw_bytes = serde_json::to_vec(&body).unwrap(); let request_headers = HeaderMap::new(); let result = processor - .execute_mcp_filter(&messages, &agent, &request_headers) + .execute_mcp_filter_raw(&raw_bytes, &agent, &request_headers, "/v1/chat/completions") .await; match result { @@ -1028,11 +724,12 @@ mod tests { agent_type: None, }; - let messages = vec![create_test_message(Role::User, "Ping")]; + let body = serde_json::json!({"messages": [{"role": "user", "content": "Ping"}]}); + let raw_bytes = serde_json::to_vec(&body).unwrap(); let request_headers = HeaderMap::new(); let result = processor - 
.execute_mcp_filter(&messages, &agent, &request_headers) + .execute_mcp_filter_raw(&raw_bytes, &agent, &request_headers, "/v1/chat/completions") .await; match result { @@ -1080,11 +777,12 @@ mod tests { agent_type: None, }; - let messages = vec![create_test_message(Role::User, "Hi")]; + let body = serde_json::json!({"messages": [{"role": "user", "content": "Hi"}]}); + let raw_bytes = serde_json::to_vec(&body).unwrap(); let request_headers = HeaderMap::new(); let result = processor - .execute_mcp_filter(&messages, &agent, &request_headers) + .execute_mcp_filter_raw(&raw_bytes, &agent, &request_headers, "/v1/chat/completions") .await; match result { diff --git a/demos/filter_chains/mcp_filter/src/rag_agent/input_guards.py b/demos/filter_chains/mcp_filter/src/rag_agent/input_guards.py index 4067e143..c4677cbb 100644 --- a/demos/filter_chains/mcp_filter/src/rag_agent/input_guards.py +++ b/demos/filter_chains/mcp_filter/src/rag_agent/input_guards.py @@ -1,8 +1,5 @@ -import asyncio import json -import time -from typing import List, Optional, Dict, Any -import uuid +from typing import Optional, Dict, Any from fastmcp.exceptions import ToolError from openai import AsyncOpenAI import os diff --git a/demos/filter_chains/mcp_filter/src/rag_agent/query_rewriter.py b/demos/filter_chains/mcp_filter/src/rag_agent/query_rewriter.py index 8481f3a7..d175d123 100644 --- a/demos/filter_chains/mcp_filter/src/rag_agent/query_rewriter.py +++ b/demos/filter_chains/mcp_filter/src/rag_agent/query_rewriter.py @@ -1,8 +1,4 @@ -import asyncio -import json -import time -from typing import List, Optional, Dict, Any -import uuid +from typing import List, Optional from openai import AsyncOpenAI import os import logging