diff --git a/crates/brightstaff/src/handlers/integration_tests.rs b/crates/brightstaff/src/handlers/integration_tests.rs index a0ff4499..047204ca 100644 --- a/crates/brightstaff/src/handlers/integration_tests.rs +++ b/crates/brightstaff/src/handlers/integration_tests.rs @@ -111,7 +111,13 @@ mod integration_tests { let headers = HeaderMap::new(); let result = pipeline_processor - .process_filter_chain(&request.messages, &test_pipeline, &agent_map, &headers) + .process_filter_chain( + &request.messages, + &test_pipeline, + &agent_map, + &headers, + None, + ) .await; println!("Pipeline processing result: {:?}", result); diff --git a/crates/brightstaff/src/handlers/pipeline_processor.rs b/crates/brightstaff/src/handlers/pipeline_processor.rs index 8d4c1fe8..00225ee8 100644 --- a/crates/brightstaff/src/handlers/pipeline_processor.rs +++ b/crates/brightstaff/src/handlers/pipeline_processor.rs @@ -158,139 +158,62 @@ impl PipelineProcessor { Ok(chat_history_updated) } - /// Send request to a specific agent and return the response content - async fn execute_filter( - &mut self, - messages: &[Message], - agent: &Agent, + /// Build common MCP headers for requests + fn build_mcp_headers( + &self, request_headers: &HeaderMap, - ) -> Result, PipelineError> { - let mcp_session_id = if let Some(session_id) = self.agent_id_session_map.get(&agent.id) { - session_id.clone() - } else { - let session_id = self.get_new_session_id(&agent.id).await; - self.agent_id_session_map - .insert(agent.id.clone(), session_id.clone()); - session_id - }; - - // let mut request = original_request.clone(); - // request.messages = messages.to_vec(); - - let tool_name = agent.tool.as_deref().unwrap_or(&agent.id); - - let arguments = serde_json::json!({ - "messages": messages - }); - - let params = serde_json::json!({ - "name": tool_name, - "arguments": arguments - }); - - let json_rpc_request = JsonRpcRequest { - jsonrpc: "2.0".to_string(), - id: JsonRpcId::String(Uuid::new_v4().to_string()), - method: "tools/call".to_string(), - params: Some(serde_json::from_value(params)?), - }; - - let request_body = serde_json::to_string(&json_rpc_request)?; - info!("Sending request to agent {}", agent.id); - info!("Request body: {}", request_body); - - // Pretty print for debugging - let pretty_body = serde_json::to_string_pretty(&json_rpc_request)?; - info!("Request body (pretty):\n{}", pretty_body); - - let mut agent_headers = request_headers.clone(); - info!( - "Using MCP session ID {} for agent {}", - mcp_session_id, agent.id - ); - - // Log all headers being sent - info!("Headers being sent:"); - for (key, value) in agent_headers.iter() { - info!(" {}: {:?}", key, value); - } - - agent_headers.insert( - "mcp-session-id", - hyper::header::HeaderValue::from_str(&mcp_session_id).unwrap(), - ); - agent_headers.remove(hyper::header::CONTENT_LENGTH); - agent_headers.insert( + agent_id: &str, + session_id: Option<&str>, + ) -> Result { + let mut headers = request_headers.clone(); + headers.remove(hyper::header::CONTENT_LENGTH); + + headers.insert( ARCH_UPSTREAM_HOST_HEADER, - hyper::header::HeaderValue::from_str(&agent.id) - .map_err(|_| PipelineError::AgentNotFound(agent.id.clone()))?, + hyper::header::HeaderValue::from_str(agent_id) + .map_err(|_| PipelineError::AgentNotFound(agent_id.to_string()))?, ); - agent_headers.insert( + headers.insert( ENVOY_RETRY_HEADER, hyper::header::HeaderValue::from_str("3").unwrap(), ); - agent_headers.insert( + headers.insert( "Accept", hyper::header::HeaderValue::from_static("application/json, text/event-stream"), ); - agent_headers.insert( + headers.insert( "Content-Type", hyper::header::HeaderValue::from_static("application/json"), ); - info!("Final headers being sent:"); - for (key, value) in agent_headers.iter() { - info!(" {}: {:?}", key, value); + if let Some(sid) = session_id { + headers.insert( + "mcp-session-id", + hyper::header::HeaderValue::from_str(sid).unwrap(), + ); } - let response = self - .client - .post(format!("{}/mcp", self.url)) - .headers(agent_headers) - .body(request_body) - .send() - .await?; + Ok(headers) + } - let http_status = response.status(); - let response_bytes = response.bytes().await?; - - if !http_status.is_success() { - let error_body = String::from_utf8_lossy(&response_bytes).to_string(); - - if http_status.is_client_error() { - // 4xx errors - cascade back to developer - return Err(PipelineError::ClientError { - agent: agent.id.clone(), - status: http_status.as_u16(), - body: error_body, - }); - } else if http_status.is_server_error() { - // 5xx errors - server/agent error - return Err(PipelineError::ServerError { - agent: agent.id.clone(), - status: http_status.as_u16(), - body: error_body, - }); - } - } - - info!( - "response bytes in str: {}", - String::from_utf8_lossy(&response_bytes) - ); - - let response_str = String::from_utf8_lossy(&response_bytes); + /// Parse SSE formatted response and extract JSON-RPC data + fn parse_sse_response(&self, response_bytes: &[u8], agent_id: &str) -> Result { + let response_str = String::from_utf8_lossy(response_bytes); let lines: Vec<&str> = response_str.lines().collect(); // Validate SSE format: first line should be "event: message" if lines.is_empty() || lines[0] != "event: message" { - warn!("Invalid SSE response format from agent {}: expected 'event: message' as first line, got: {:?}", agent.id, lines.first()); + warn!( + "Invalid SSE response format from agent {}: expected 'event: message' as first line, got: {:?}", + agent_id, + lines.first() + ); return Err(PipelineError::NoContentInResponse(format!( "Invalid SSE response format from agent {}: expected 'event: message' as first line", - agent.id + agent_id ))); } @@ -304,38 +227,137 @@ impl PipelineProcessor { if data_lines.len() != 1 { warn!( "Expected exactly one 'data:' line from agent {}, found {}", - agent.id, + agent_id, data_lines.len() ); return Err(PipelineError::NoContentInResponse(format!( "Expected exactly one 'data:' line from agent {}, found {}", - agent.id, + agent_id, data_lines.len() ))); } - let data_chunk = &data_lines[0][6..]; // Skip "data: " prefix + // Skip "data: " prefix + Ok(data_lines[0][6..].to_string()) + } - let response: JsonRpcResponse = serde_json::from_str(data_chunk)?; + /// Send an MCP request and return the response + async fn send_mcp_request( + &self, + json_rpc_request: &JsonRpcRequest, + headers: HeaderMap, + agent_id: &str, + ) -> Result { + let request_body = serde_json::to_string(json_rpc_request)?; + + debug!("Sending MCP request to agent {}: {}", agent_id, request_body); + + let response = self + .client + .post(format!("{}/mcp", self.url)) + .headers(headers) + .body(request_body) + .send() + .await?; + + Ok(response) + } + + /// Build a tools/call JSON-RPC request + fn build_tool_call_request( + &self, + tool_name: &str, + messages: &[Message], + ) -> Result { + let arguments = serde_json::json!({ + "messages": messages + }); + + let params = serde_json::json!({ + "name": tool_name, + "arguments": arguments + }); + + Ok(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: JsonRpcId::String(Uuid::new_v4().to_string()), + method: "tools/call".to_string(), + params: Some(serde_json::from_value(params)?), + }) + } + + /// Send request to a specific agent and return the response content + async fn execute_filter( + &mut self, + messages: &[Message], + agent: &Agent, + request_headers: &HeaderMap, + ) -> Result, PipelineError> { + // Get or create MCP session + let mcp_session_id = if let Some(session_id) = self.agent_id_session_map.get(&agent.id) { + session_id.clone() + } else { + let session_id = self.get_new_session_id(&agent.id).await; + self.agent_id_session_map + .insert(agent.id.clone(), session_id.clone()); + session_id + }; + + info!("Using MCP session ID {} for agent {}", mcp_session_id, agent.id); + + // Build JSON-RPC request + let tool_name = agent.tool.as_deref().unwrap_or(&agent.id); + let json_rpc_request = self.build_tool_call_request(tool_name, messages)?; + + // Build headers + let agent_headers = self.build_mcp_headers( + request_headers, + &agent.id, + Some(&mcp_session_id), + )?; + + // Send request + let response = self.send_mcp_request(&json_rpc_request, agent_headers, &agent.id).await?; + let http_status = response.status(); + let response_bytes = response.bytes().await?; + + // Handle HTTP errors + if !http_status.is_success() { + let error_body = String::from_utf8_lossy(&response_bytes).to_string(); + return Err(if http_status.is_client_error() { + PipelineError::ClientError { + agent: agent.id.clone(), + status: http_status.as_u16(), + body: error_body, + } + } else { + PipelineError::ServerError { + agent: agent.id.clone(), + status: http_status.as_u16(), + body: error_body, + } + }); + } + + info!("Response from agent {}: {}", agent.id, String::from_utf8_lossy(&response_bytes)); + + // Parse SSE response + let data_chunk = self.parse_sse_response(&response_bytes, &agent.id)?; + let response: JsonRpcResponse = serde_json::from_str(&data_chunk)?; let response_result = response .result .ok_or_else(|| PipelineError::NoResultInResponse(agent.id.clone()))?; - // check if error field is set in response result - let mcp_error = response_result - .get("isError") - .and_then(|v| v.as_bool()) - .unwrap_or(false); - - if mcp_error { + // Check if error field is set in response result + if response_result.get("isError").and_then(|v| v.as_bool()).unwrap_or(false) { let error_message = response_result .get("content") .and_then(|v| v.as_array()) - .and_then(|arr| arr.get(0)) + .and_then(|arr| arr.first()) .and_then(|v| v.get("text")) .and_then(|v| v.as_str()) - .map(|s| s.to_string()) - .unwrap_or("unknown_error".to_string()); + .unwrap_or("unknown_error") + .to_string(); return Err(PipelineError::ClientError { agent: agent.id.clone(), @@ -344,11 +366,10 @@ impl PipelineProcessor { }); } + // Extract structured content and parse messages let response_json = response_result .get("structuredContent") .ok_or_else(|| PipelineError::NoStructuredContentInResponse(agent.id.clone()))?; - // Parse the response as JSON to extract the content - // let response_json: serde_json::Value = serde_json::from_slice(&response_bytes)?; let messages: Vec = response_json .get("result") @@ -362,8 +383,9 @@ impl PipelineProcessor { Ok(messages) } - async fn get_new_session_id(&self, agent_id: &str) -> String { - let initialize_request = JsonRpcRequest { + /// Build an initialize JSON-RPC request + fn build_initialize_request(&self) -> JsonRpcRequest { + JsonRpcRequest { jsonrpc: "2.0".to_string(), id: JsonRpcId::Number(1), method: "initialize".to_string(), @@ -383,26 +405,47 @@ impl PipelineProcessor { ); params }), + } + } + + /// Send initialized notification after session creation + async fn send_initialized_notification(&self, agent_id: &str, session_id: &str) -> Result<(), PipelineError> { + let initialized_notification = JsonRpcNotification { + jsonrpc: "2.0".to_string(), + method: "notifications/initialized".to_string(), + params: None, }; - let request_body = serde_json::to_string(&initialize_request).unwrap(); - - info!("Initializing MCP session for agent {}", agent_id); - info!("Initialize request body: {}", request_body); + let notification_body = serde_json::to_string(&initialized_notification)?; + debug!("Sending initialized notification for agent {}", agent_id); + let headers = self.build_mcp_headers(&HeaderMap::new(), agent_id, Some(session_id))?; + let response = self .client .post(format!("{}/mcp", self.url)) - .header("Content-Type", "application/json") - .header("Accept", "application/json, text/event-stream") - .header(ARCH_UPSTREAM_HOST_HEADER, agent_id) - .body(request_body) + .headers(headers) + .body(notification_body) .send() + .await?; + + info!("Initialized notification response status: {}", response.status()); + Ok(()) + } + + async fn get_new_session_id(&self, agent_id: &str) -> String { + info!("Initializing MCP session for agent {}", agent_id); + + let initialize_request = self.build_initialize_request(); + let headers = self.build_mcp_headers(&HeaderMap::new(), agent_id, None) + .expect("Failed to build headers for initialization"); + + let response = self + .send_mcp_request(&initialize_request, headers, agent_id) .await .expect("Failed to initialize MCP session"); info!("Initialize response status: {}", response.status()); - info!("Initialize response headers: {:?}", response.headers()); let session_id = response .headers() @@ -411,39 +454,13 @@ impl PipelineProcessor { .expect("No mcp-session-id in response") .to_string(); - info!( - "Created new MCP session for agent {}: {}", - agent_id, session_id - ); + info!("Created new MCP session for agent {}: {}", agent_id, session_id); - // Send initialized notification (without id field per JSON-RPC 2.0 spec) - let initialized_notification = JsonRpcNotification { - jsonrpc: "2.0".to_string(), - method: "notifications/initialized".to_string(), - params: None, - }; - - let notification_body = serde_json::to_string(&initialized_notification).unwrap(); - - info!("Sending initialized notification: {}", notification_body); - - let notif_response = self - .client - .post(format!("{}/mcp", self.url)) - .header("Content-Type", "application/json") - .header("Accept", "application/json, text/event-stream") - .header("mcp-session-id", &session_id) - .header(ARCH_UPSTREAM_HOST_HEADER, agent_id) - .body(notification_body) - .send() + // Send initialized notification + self.send_initialized_notification(agent_id, &session_id) .await .expect("Failed to send initialized notification"); - info!( - "Initialized notification response status: {}", - notif_response.status() - ); - session_id } @@ -490,6 +507,7 @@ impl PipelineProcessor { mod tests { use super::*; use hermesllm::apis::openai::{Message, MessageContent, Role}; + use mockito::Server; use std::collections::HashMap; fn create_test_message(role: Role, content: &str) -> Message { @@ -538,4 +556,134 @@ mod tests { assert!(result.is_err()); matches!(result.unwrap_err(), PipelineError::AgentNotFound(_)); } + + #[tokio::test] + async fn test_execute_filter_http_status_error() { + let mut server = Server::new_async().await; + let _m = server + .mock("POST", "/mcp") + .with_status(500) + .with_body("boom") + .create(); + + let server_url = server.url(); + let mut processor = PipelineProcessor::new(server_url.clone()); + processor + .agent_id_session_map + .insert("agent-1".to_string(), "session-1".to_string()); + + let agent = Agent { + id: "agent-1".to_string(), + transport: None, + tool: None, + url: server_url, + kind: None, + }; + + let messages = vec![create_test_message(Role::User, "Hello")]; + let request_headers = HeaderMap::new(); + + let result = processor + .execute_filter(&messages, &agent, &request_headers) + .await; + + match result { + Err(PipelineError::ServerError { status, body, .. }) => { + assert_eq!(status, 500); + assert_eq!(body, "boom"); + } + _ => panic!("Expected server error for 500 status"), + } + } + + #[tokio::test] + async fn test_execute_filter_http_client_error() { + let mut server = Server::new_async().await; + let _m = server + .mock("POST", "/mcp") + .with_status(400) + .with_body("bad request") + .create(); + + let server_url = server.url(); + let mut processor = PipelineProcessor::new(server_url.clone()); + processor + .agent_id_session_map + .insert("agent-3".to_string(), "session-3".to_string()); + + let agent = Agent { + id: "agent-3".to_string(), + transport: None, + tool: None, + url: server_url, + kind: None, + }; + + let messages = vec![create_test_message(Role::User, "Ping")]; + let request_headers = HeaderMap::new(); + + let result = processor + .execute_filter(&messages, &agent, &request_headers) + .await; + + match result { + Err(PipelineError::ClientError { status, body, .. }) => { + assert_eq!(status, 400); + assert_eq!(body, "bad request"); + } + _ => panic!("Expected client error for 400 status"), + } + } + + #[tokio::test] + async fn test_execute_filter_mcp_error_flag() { + let rpc_body = serde_json::json!({ + "jsonrpc": "2.0", + "id": "1", + "result": { + "isError": true, + "content": [ + { "text": "bad tool call" } + ] + } + }); + + let sse_body = format!("event: message\ndata: {}\n\n", rpc_body.to_string()); + + let mut server = Server::new_async().await; + let _m = server + .mock("POST", "/mcp") + .with_status(200) + .with_body(sse_body) + .create(); + + let server_url = server.url(); + let mut processor = PipelineProcessor::new(server_url.clone()); + processor + .agent_id_session_map + .insert("agent-2".to_string(), "session-2".to_string()); + + let agent = Agent { + id: "agent-2".to_string(), + transport: None, + tool: None, + url: server_url, + kind: None, + }; + + let messages = vec![create_test_message(Role::User, "Hi")]; + let request_headers = HeaderMap::new(); + + let result = processor + .execute_filter(&messages, &agent, &request_headers) + .await; + + match result { + Err(PipelineError::ClientError { status, body, .. }) => { + assert_eq!(status, 200); + assert_eq!(body, "bad tool call"); + } + _ => panic!("Expected client error when isError flag is set"), + } + } }