diff --git a/config/plano_config_schema.yaml b/config/plano_config_schema.yaml index e0a6eef1..5190fecf 100644 --- a/config/plano_config_schema.yaml +++ b/config/plano_config_schema.yaml @@ -85,7 +85,7 @@ properties: type: string default: type: boolean - filter_chain: + input_filters: type: array items: type: string diff --git a/crates/brightstaff/src/handlers/agent_chat_completions.rs b/crates/brightstaff/src/handlers/agent_chat_completions.rs index 513e0ef2..44bb3235 100644 --- a/crates/brightstaff/src/handlers/agent_chat_completions.rs +++ b/crates/brightstaff/src/handlers/agent_chat_completions.rs @@ -332,15 +332,38 @@ async fn handle_agent_chat_inner( "processing agent" ); - // Process the filter chain - let chat_history = pipeline_processor - .process_filter_chain( - ¤t_messages, - selected_agent, - &agent_map, - &request_headers, - ) - .await?; + // Process input filters — serialize current request as OpenAI chat completions body, + // pass raw bytes through each filter, then extract updated messages from the result. + let chat_history = if selected_agent + .input_filters + .as_ref() + .map(|f| !f.is_empty()) + .unwrap_or(false) + { + let filter_body = serde_json::json!({ + "model": client_request.model(), + "messages": current_messages, + }); + let filter_bytes = + serde_json::to_vec(&filter_body).map_err(PipelineError::ParseError)?; + + let filtered_bytes = pipeline_processor + .process_raw_filter_chain( + &filter_bytes, + selected_agent, + &agent_map, + &request_headers, + "/v1/chat/completions", + ) + .await?; + + let filtered_body: serde_json::Value = + serde_json::from_slice(&filtered_bytes).map_err(PipelineError::ParseError)?; + serde_json::from_value(filtered_body["messages"].clone()) + .map_err(PipelineError::ParseError)? + } else { + current_messages.clone() + }; // Get agent details and invoke let agent = agent_map.get(&agent_name).unwrap(); diff --git a/crates/brightstaff/src/handlers/agent_selector.rs b/crates/brightstaff/src/handlers/agent_selector.rs index 0c9b018e..2341e156 100644 --- a/crates/brightstaff/src/handlers/agent_selector.rs +++ b/crates/brightstaff/src/handlers/agent_selector.rs @@ -187,7 +187,7 @@ mod tests { id: name.to_string(), description: Some(description.to_string()), default: Some(is_default), - filter_chain: Some(vec![name.to_string()]), + input_filters: Some(vec![name.to_string()]), } } diff --git a/crates/brightstaff/src/handlers/integration_tests.rs b/crates/brightstaff/src/handlers/integration_tests.rs index c3153d3d..e82bb8b2 100644 --- a/crates/brightstaff/src/handlers/integration_tests.rs +++ b/crates/brightstaff/src/handlers/integration_tests.rs @@ -64,7 +64,7 @@ mod tests { let agent_pipeline = AgentFilterChain { id: "terminal-agent".to_string(), - filter_chain: Some(vec![ + input_filters: Some(vec![ "filter-agent".to_string(), "terminal-agent".to_string(), ]), @@ -110,7 +110,7 @@ mod tests { // Create a pipeline with empty filter chain to avoid network calls let test_pipeline = AgentFilterChain { id: "terminal-agent".to_string(), - filter_chain: Some(vec![]), // Empty filter chain - no network calls needed + input_filters: Some(vec![]), // Empty filter chain - no network calls needed description: None, default: None, }; diff --git a/crates/brightstaff/src/handlers/llm.rs b/crates/brightstaff/src/handlers/llm.rs index 487a7b22..9382a66d 100644 --- a/crates/brightstaff/src/handlers/llm.rs +++ b/crates/brightstaff/src/handlers/llm.rs @@ -279,7 +279,7 @@ async fn llm_chat_inner( id: "model_listener".to_string(), default: None, description: None, - filter_chain: Some(fc.clone()), + input_filters: Some(fc.clone()), }; let mut pipeline_processor = PipelineProcessor::default(); diff --git a/crates/brightstaff/src/handlers/pipeline_processor.rs b/crates/brightstaff/src/handlers/pipeline_processor.rs index 776a88f5..c01d7bdc 100644 --- a/crates/brightstaff/src/handlers/pipeline_processor.rs +++ b/crates/brightstaff/src/handlers/pipeline_processor.rs @@ -84,7 +84,7 @@ impl PipelineProcessor { // #[instrument( // skip(self, chat_history, agent_filter_chain, agent_map, request_headers), // fields( - // filter_count = agent_filter_chain.filter_chain.as_ref().map(|fc| fc.len()).unwrap_or(0), + // filter_count = agent_filter_chain.input_filters.as_ref().map(|fc| fc.len()).unwrap_or(0), // message_count = chat_history.len() // ) // )] @@ -99,7 +99,7 @@ impl PipelineProcessor { let mut chat_history_updated = chat_history.to_vec(); // If filter_chain is None or empty, proceed without filtering - let filter_chain = match agent_filter_chain.filter_chain.as_ref() { + let filter_chain = match agent_filter_chain.input_filters.as_ref() { Some(fc) if !fc.is_empty() => fc, _ => return Ok(chat_history_updated), }; @@ -283,6 +283,30 @@ impl PipelineProcessor { }) } + /// Build a tools/call JSON-RPC request with a full body dict and path hint. + /// Used by execute_mcp_filter_raw so MCP tools receive the same contract as HTTP filters. + fn build_tool_call_request_with_body( + &self, + tool_name: &str, + body: &serde_json::Value, + path: &str, + ) -> Result { + let mut arguments = HashMap::new(); + arguments.insert("body".to_string(), serde_json::to_value(body)?); + arguments.insert("path".to_string(), serde_json::to_value(path)?); + + let mut params = HashMap::new(); + params.insert("name".to_string(), serde_json::to_value(tool_name)?); + params.insert("arguments".to_string(), serde_json::to_value(arguments)?); + + Ok(JsonRpcRequest { + jsonrpc: JSON_RPC_VERSION.to_string(), + id: JsonRpcId::String(Uuid::new_v4().to_string()), + method: TOOL_CALL_METHOD.to_string(), + params: Some(params), + }) + } + /// Send request to a specific agent and return the response content #[instrument( skip(self, messages, agent, request_headers), @@ -406,6 +430,106 @@ impl PipelineProcessor { Ok(messages) } + /// Like execute_mcp_filter but passes the full raw body dict + path hint as MCP tool arguments. + /// The MCP tool receives (body: dict, path: str) and returns the modified body dict. + async fn execute_mcp_filter_raw( + &mut self, + raw_bytes: &[u8], + agent: &Agent, + request_headers: &HeaderMap, + request_path: &str, + ) -> Result { + set_service_name(operation_component::AGENT_FILTER); + use opentelemetry::trace::get_active_span; + get_active_span(|span| { + span.update_name(format!("execute_mcp_filter_raw ({})", agent.id)); + }); + + let body: serde_json::Value = + serde_json::from_slice(raw_bytes).map_err(PipelineError::ParseError)?; + + let mcp_session_id = if let Some(session_id) = self.agent_id_session_map.get(&agent.id) { + session_id.clone() + } else { + let session_id = self.get_new_session_id(&agent.id, request_headers).await; + self.agent_id_session_map + .insert(agent.id.clone(), session_id.clone()); + session_id + }; + + info!( + "Using MCP session ID {} for agent {}", + mcp_session_id, agent.id + ); + + let tool_name = agent.tool.as_deref().unwrap_or(&agent.id); + let json_rpc_request = + self.build_tool_call_request_with_body(tool_name, &body, request_path)?; + + let agent_headers = + self.build_mcp_headers(request_headers, &agent.id, Some(&mcp_session_id))?; + + let response = self + .send_mcp_request(&json_rpc_request, &agent_headers, &agent.id) + .await?; + let http_status = response.status(); + let response_bytes = response.bytes().await?; + + if !http_status.is_success() { + let error_body = String::from_utf8_lossy(&response_bytes).to_string(); + return Err(if http_status.is_client_error() { + PipelineError::ClientError { + agent: agent.id.clone(), + status: http_status.as_u16(), + body: error_body, + } + } else { + PipelineError::ServerError { + agent: agent.id.clone(), + status: http_status.as_u16(), + body: error_body, + } + }); + } + + let data_chunk = self.parse_sse_response(&response_bytes, &agent.id)?; + let response: JsonRpcResponse = serde_json::from_str(&data_chunk)?; + let response_result = response + .result + .ok_or_else(|| PipelineError::NoResultInResponse(agent.id.clone()))?; + + if response_result + .get("isError") + .and_then(|v| v.as_bool()) + .unwrap_or(false) + { + let error_message = response_result + .get("content") + .and_then(|v| v.as_array()) + .and_then(|arr| arr.first()) + .and_then(|v| v.get("text")) + .and_then(|v| v.as_str()) + .unwrap_or("unknown_error") + .to_string(); + + return Err(PipelineError::ClientError { + agent: agent.id.clone(), + status: hyper::StatusCode::BAD_REQUEST.as_u16(), + body: error_message, + }); + } + + let result = response_result + .get("structuredContent") + .and_then(|v| v.get("result")) + .cloned() + .ok_or_else(|| PipelineError::NoStructuredContentInResponse(agent.id.clone()))?; + + Ok(Bytes::from( + serde_json::to_vec(&result).map_err(PipelineError::ParseError)?, + )) + } + /// Build an initialize JSON-RPC request fn build_initialize_request(&self) -> JsonRpcRequest { JsonRpcRequest { @@ -708,7 +832,7 @@ impl PipelineProcessor { request_headers: &HeaderMap, request_path: &str, ) -> Result { - let filter_chain = match agent_filter_chain.filter_chain.as_ref() { + let filter_chain = match agent_filter_chain.input_filters.as_ref() { Some(fc) if !fc.is_empty() => fc, _ => return Ok(Bytes::copy_from_slice(raw_bytes)), }; @@ -722,16 +846,22 @@ impl PipelineProcessor { .get(agent_name) .ok_or_else(|| PipelineError::AgentNotFound(agent_name.clone()))?; + let agent_type = agent.agent_type.as_deref().unwrap_or("mcp"); info!( agent = %agent_name, url = %agent.url, + agent_type = %agent_type, bytes_len = current_bytes.len(), "executing raw filter" ); - current_bytes = self - .execute_raw_filter(¤t_bytes, agent, request_headers, request_path) - .await?; + current_bytes = if agent_type == "mcp" { + self.execute_mcp_filter_raw(¤t_bytes, agent, request_headers, request_path) + .await? + } else { + self.execute_raw_filter(¤t_bytes, agent, request_headers, request_path) + .await? + }; info!(agent = %agent_name, bytes_len = current_bytes.len(), "raw filter completed"); } @@ -812,7 +942,7 @@ mod tests { fn create_test_pipeline(agents: Vec<&str>) -> AgentFilterChain { AgentFilterChain { id: "test-agent".to_string(), - filter_chain: Some(agents.iter().map(|s| s.to_string()).collect()), + input_filters: Some(agents.iter().map(|s| s.to_string()).collect()), description: None, default: None, } diff --git a/crates/brightstaff/src/handlers/utils.rs b/crates/brightstaff/src/handlers/utils.rs index 18a4adca..2abd1b69 100644 --- a/crates/brightstaff/src/handlers/utils.rs +++ b/crates/brightstaff/src/handlers/utils.rs @@ -308,7 +308,7 @@ where id: "output_filter".to_string(), default: None, description: None, - filter_chain: Some(output_filters), + input_filters: Some(output_filters), }; while let Some(item) = byte_stream.next().await { diff --git a/crates/common/src/configuration.rs b/crates/common/src/configuration.rs index 30187dd8..6bdaa01e 100644 --- a/crates/common/src/configuration.rs +++ b/crates/common/src/configuration.rs @@ -27,7 +27,7 @@ pub struct AgentFilterChain { pub id: String, pub default: Option, pub description: Option, - pub filter_chain: Option>, + pub input_filters: Option>, } #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] diff --git a/demos/filter_chains/http_filter/config.yaml b/demos/filter_chains/http_filter/config.yaml index 014a141a..7d7ba1bb 100644 --- a/demos/filter_chains/http_filter/config.yaml +++ b/demos/filter_chains/http_filter/config.yaml @@ -42,7 +42,7 @@ listeners: agents: - id: rag_agent description: virtual assistant for retrieval augmented generation tasks - filter_chain: + input_filters: - input_guards - query_rewriter - context_builder diff --git a/demos/filter_chains/http_filter/src/rag_agent/context_builder.py b/demos/filter_chains/http_filter/src/rag_agent/context_builder.py index 5da1c43d..75b73ef0 100644 --- a/demos/filter_chains/http_filter/src/rag_agent/context_builder.py +++ b/demos/filter_chains/http_filter/src/rag_agent/context_builder.py @@ -195,11 +195,11 @@ async def augment_query_with_context( load_knowledge_base() -@app.post("/") -async def context_builder( - messages: List[ChatMessage], request: Request -) -> List[ChatMessage]: +@app.post("/{path:path}") +async def context_builder(path: str, request: Request) -> dict: """MCP tool that augments user queries with relevant context from the knowledge base.""" + body = await request.json() + messages = [ChatMessage(**m) for m in body.get("messages", [])] logger.info(f"Received chat completion request with {len(messages)} messages") # Get traceparent header from MCP request @@ -219,8 +219,7 @@ async def context_builder( messages, traceparent_header, request_id ) - # Return as dict to minimize text serialization - return [{"role": msg.role, "content": msg.content} for msg in updated_messages] + return {**body, "messages": [{"role": msg.role, "content": msg.content} for msg in updated_messages]} # Register MCP tool only if mcp is available diff --git a/demos/filter_chains/http_filter/src/rag_agent/input_guards.py b/demos/filter_chains/http_filter/src/rag_agent/input_guards.py index 3b2414ad..98d481d6 100644 --- a/demos/filter_chains/http_filter/src/rag_agent/input_guards.py +++ b/demos/filter_chains/http_filter/src/rag_agent/input_guards.py @@ -126,14 +126,14 @@ Respond in JSON format: # @mcp.tool -@app.post("/") -async def input_guards( - messages: List[ChatMessage], request: Request -) -> List[ChatMessage]: +@app.post("/{path:path}") +async def input_guards(path: str, request: Request) -> dict: """Input guard that validates queries are within TechCorp's domain. If the query is out of scope, replaces the user message with a rejection notice. """ + body = await request.json() + messages = [ChatMessage(**m) for m in body.get("messages", [])] logger.info(f"Received request with {len(messages)} messages") # Get traceparent header from HTTP request using FastMCP's dependency function @@ -164,7 +164,7 @@ async def input_guards( ) logger.info("Query validation passed - forwarding to next filter") - return messages + return body @app.get("/health") diff --git a/demos/filter_chains/http_filter/src/rag_agent/query_rewriter.py b/demos/filter_chains/http_filter/src/rag_agent/query_rewriter.py index 3e05836c..43824473 100644 --- a/demos/filter_chains/http_filter/src/rag_agent/query_rewriter.py +++ b/demos/filter_chains/http_filter/src/rag_agent/query_rewriter.py @@ -81,11 +81,11 @@ Return only the rewritten query, nothing else.""" return "" -@app.post("/") -async def query_rewriter_http( - messages: List[ChatMessage], request: Request -) -> List[ChatMessage]: +@app.post("/{path:path}") +async def query_rewriter_http(path: str, request: Request) -> dict: """HTTP filter endpoint used by Plano (type: http).""" + body = await request.json() + messages = [ChatMessage(**m) for m in body.get("messages", [])] logger.info(f"Received request with {len(messages)} messages") traceparent_header = request.headers.get("traceparent") @@ -99,25 +99,20 @@ async def query_rewriter_http( rewritten_query = await rewrite_query_with_plano( messages, traceparent_header, request_id ) - # Create updated messages with the rewritten query - updated_messages = messages.copy() # Find and update the last user message with the rewritten query + updated_messages = [m.model_dump() for m in messages] for i in range(len(updated_messages) - 1, -1, -1): - if updated_messages[i].role == "user": - original_query = updated_messages[i].content - updated_messages[i] = ChatMessage(role="user", content=rewritten_query) + if updated_messages[i]["role"] == "user": + original_query = updated_messages[i]["content"] + updated_messages[i]["content"] = rewritten_query logger.info( f"Updated user query from '{original_query}' to '{rewritten_query}'" ) break - updated_messages_data = [ - {"role": msg.role, "content": msg.content} for msg in updated_messages - ] - updated_messages = [ChatMessage(**msg) for msg in updated_messages_data] logger.info("Returning rewritten chat completion response") - return updated_messages + return {**body, "messages": updated_messages} @app.get("/health") diff --git a/demos/filter_chains/mcp_filter/config.yaml b/demos/filter_chains/mcp_filter/config.yaml index e07a49dc..1270a2f3 100644 --- a/demos/filter_chains/mcp_filter/config.yaml +++ b/demos/filter_chains/mcp_filter/config.yaml @@ -39,7 +39,7 @@ listeners: agents: - id: rag_agent description: virtual assistant for retrieval augmented generation tasks - filter_chain: + input_filters: - input_guards - query_rewriter - context_builder diff --git a/demos/filter_chains/mcp_filter/src/rag_agent/context_builder.py b/demos/filter_chains/mcp_filter/src/rag_agent/context_builder.py index e50bb76c..df60f70c 100644 --- a/demos/filter_chains/mcp_filter/src/rag_agent/context_builder.py +++ b/demos/filter_chains/mcp_filter/src/rag_agent/context_builder.py @@ -195,9 +195,14 @@ async def augment_query_with_context( load_knowledge_base() -async def context_builder(messages: List[ChatMessage]) -> List[ChatMessage]: - """MCP tool that augments user queries with relevant context from the knowledge base.""" - logger.info(f"Received chat completion request with {len(messages)} messages") +async def context_builder(body: dict, path: str) -> dict: + """MCP tool that augments user queries with relevant context from the knowledge base. + + Receives the full request body dict and the API path hint (e.g. /v1/chat/completions). + Returns the body with the last user message augmented with retrieved context. + """ + messages = [ChatMessage(**m) for m in body.get("messages", [])] + logger.info(f"Received request with {len(messages)} messages at path {path}") # Get traceparent header from MCP request headers = get_http_headers() @@ -215,8 +220,7 @@ async def context_builder(messages: List[ChatMessage]) -> List[ChatMessage]: messages, traceparent_header, request_id ) - # Return as dict to minimize text serialization - return [{"role": msg.role, "content": msg.content} for msg in updated_messages] + return {**body, "messages": [{"role": msg.role, "content": msg.content} for msg in updated_messages]} # Register MCP tool only if mcp is available diff --git a/demos/filter_chains/mcp_filter/src/rag_agent/input_guards.py b/demos/filter_chains/mcp_filter/src/rag_agent/input_guards.py index 607ff035..4067e143 100644 --- a/demos/filter_chains/mcp_filter/src/rag_agent/input_guards.py +++ b/demos/filter_chains/mcp_filter/src/rag_agent/input_guards.py @@ -3,13 +3,12 @@ import json import time from typing import List, Optional, Dict, Any import uuid -from fastapi import FastAPI, Depends, Request from fastmcp.exceptions import ToolError from openai import AsyncOpenAI import os import logging -from .api import ChatCompletionRequest, ChatCompletionResponse, ChatMessage +from .api import ChatMessage from . import mcp from fastmcp.server.dependencies import get_http_headers @@ -30,8 +29,6 @@ plano_client = AsyncOpenAI( api_key="EMPTY", # Plano doesn't require a real API key ) -app = FastAPI() - async def validate_query_scope( messages: List[ChatMessage], @@ -127,12 +124,14 @@ Respond in JSON format: @mcp.tool -async def input_guards(messages: List[ChatMessage]) -> List[ChatMessage]: +async def input_guards(body: dict, path: str) -> dict: """Input guard that validates queries are within TechCorp's domain. - If the query is out of scope, replaces the user message with a rejection notice. + Receives the full request body dict and the API path hint (e.g. /v1/chat/completions). + If the query is out of scope, raises a ToolError to block the request. """ - logger.info(f"Received request with {len(messages)} messages") + messages = [ChatMessage(**m) for m in body.get("messages", [])] + logger.info(f"Received request with {len(messages)} messages at path {path}") # Get traceparent header from HTTP request using FastMCP's dependency function headers = get_http_headers() @@ -153,9 +152,8 @@ async def input_guards(messages: List[ChatMessage]) -> List[ChatMessage]: reason = validation_result.get("reason", "Query is outside TechCorp's domain") logger.warning(f"Query rejected: {reason}") - # Throw ToolError error_message = f"I apologize, but I can only assist with questions related to TechCorp and its services. Your query appears to be outside this scope. {reason}\n\nPlease ask me about TechCorp's products, services, pricing, SLAs, or technical support." raise ToolError(error_message) logger.info("Query validation passed - forwarding to next filter") - return messages + return body diff --git a/demos/filter_chains/mcp_filter/src/rag_agent/query_rewriter.py b/demos/filter_chains/mcp_filter/src/rag_agent/query_rewriter.py index 5939170e..8481f3a7 100644 --- a/demos/filter_chains/mcp_filter/src/rag_agent/query_rewriter.py +++ b/demos/filter_chains/mcp_filter/src/rag_agent/query_rewriter.py @@ -3,12 +3,11 @@ import json import time from typing import List, Optional, Dict, Any import uuid -from fastapi import FastAPI, Depends, Request from openai import AsyncOpenAI import os import logging -from .api import ChatCompletionRequest, ChatCompletionResponse, ChatMessage +from .api import ChatMessage from . import mcp from fastmcp.server.dependencies import get_http_headers @@ -29,9 +28,6 @@ plano_client = AsyncOpenAI( api_key="EMPTY", # Plano doesn't require a real API key ) -app = FastAPI() - - async def rewrite_query_with_plano( messages: List[ChatMessage], traceparent_header: str, @@ -87,12 +83,14 @@ async def rewrite_query_with_plano( return "" -async def query_rewriter(messages: List[ChatMessage]) -> List[ChatMessage]: - """Chat completions endpoint that rewrites the last user query using Plano. +async def query_rewriter(body: dict, path: str) -> dict: + """Rewrites the last user query in the request body using Plano. - Returns a dict with a 'messages' key containing the updated message list. + Receives the full request body dict and the API path hint (e.g. /v1/chat/completions). + Returns the body with the last user message rewritten for better retrieval. """ - logger.info(f"Received chat completion request with {len(messages)} messages") + messages = [ChatMessage(**m) for m in body.get("messages", [])] + logger.info(f"Received request with {len(messages)} messages at path {path}") # Get traceparent header from HTTP request using FastMCP's dependency function headers = get_http_headers() @@ -109,57 +107,20 @@ async def query_rewriter(messages: List[ChatMessage]) -> List[ChatMessage]: messages, traceparent_header, request_id ) - # Create updated messages with the rewritten query - updated_messages = messages.copy() - # Find and update the last user message with the rewritten query + updated_messages = [m.model_dump() for m in messages] for i in range(len(updated_messages) - 1, -1, -1): - if updated_messages[i].role == "user": - original_query = updated_messages[i].content - updated_messages[i] = ChatMessage(role="user", content=rewritten_query) + if updated_messages[i]["role"] == "user": logger.info( - f"Updated user query from '{original_query}' to '{rewritten_query}'" + f"Updated user query from '{updated_messages[i]['content']}' to '{rewritten_query}'" ) + updated_messages[i]["content"] = rewritten_query break - # Return as dict to minimize text serialization - return [{"role": msg.role, "content": msg.content} for msg in updated_messages] + logger.info("Returning rewritten chat completion response") + return {**body, "messages": updated_messages} # Register MCP tool only if mcp is available if mcp is not None: mcp.tool()(query_rewriter) - - -@app.post("/") -async def chat_completions_endpoint( - request_messages: List[ChatMessage], request: Request -) -> List[ChatMessage]: - """FastAPI endpoint for chat completions with query rewriting.""" - logger.info( - f"Received /v1/chat/completions request with {len(request_messages)} messages" - ) - - # Extract traceparent header - traceparent_header = request.headers.get("traceparent") - if traceparent_header: - logger.info(f"Received traceparent header: {traceparent_header}") - else: - logger.info("No traceparent header found") - - # Call the query rewriter tool - updated_messages_data = await query_rewriter(request_messages) - - # Convert back to ChatMessage objects - updated_messages = [ChatMessage(**msg) for msg in updated_messages_data] - - logger.info("Returning rewritten chat completion response") - return updated_messages - - -def start_server(host: str = "0.0.0.0", port: int = 10501): - """Start the FastAPI server for query rewriter.""" - import uvicorn - - logger.info(f"Starting Query Rewriter REST server on {host}:{port}") - uvicorn.run(app, host=host, port=port)