add MCP raw filter support: body+path args, update mcp_filter demo handlers

This commit is contained in:
Adil Hafeez 2026-03-17 13:40:31 -07:00
parent d26abbfb9c
commit b88bdb94f2
No known key found for this signature in database
GPG key ID: 9B18EF7691369645
16 changed files with 226 additions and 116 deletions

View file

@ -332,15 +332,38 @@ async fn handle_agent_chat_inner(
"processing agent"
);
// Process the filter chain
let chat_history = pipeline_processor
.process_filter_chain(
&current_messages,
selected_agent,
&agent_map,
&request_headers,
)
.await?;
// Process input filters — serialize current request as OpenAI chat completions body,
// pass raw bytes through each filter, then extract updated messages from the result.
let chat_history = if selected_agent
.input_filters
.as_ref()
.map(|f| !f.is_empty())
.unwrap_or(false)
{
let filter_body = serde_json::json!({
"model": client_request.model(),
"messages": current_messages,
});
let filter_bytes =
serde_json::to_vec(&filter_body).map_err(PipelineError::ParseError)?;
let filtered_bytes = pipeline_processor
.process_raw_filter_chain(
&filter_bytes,
selected_agent,
&agent_map,
&request_headers,
"/v1/chat/completions",
)
.await?;
let filtered_body: serde_json::Value =
serde_json::from_slice(&filtered_bytes).map_err(PipelineError::ParseError)?;
serde_json::from_value(filtered_body["messages"].clone())
.map_err(PipelineError::ParseError)?
} else {
current_messages.clone()
};
// Get agent details and invoke
let agent = agent_map.get(&agent_name).unwrap();

View file

@ -187,7 +187,7 @@ mod tests {
id: name.to_string(),
description: Some(description.to_string()),
default: Some(is_default),
filter_chain: Some(vec![name.to_string()]),
input_filters: Some(vec![name.to_string()]),
}
}

View file

@ -64,7 +64,7 @@ mod tests {
let agent_pipeline = AgentFilterChain {
id: "terminal-agent".to_string(),
filter_chain: Some(vec![
input_filters: Some(vec![
"filter-agent".to_string(),
"terminal-agent".to_string(),
]),
@ -110,7 +110,7 @@ mod tests {
// Create a pipeline with empty filter chain to avoid network calls
let test_pipeline = AgentFilterChain {
id: "terminal-agent".to_string(),
filter_chain: Some(vec![]), // Empty filter chain - no network calls needed
input_filters: Some(vec![]), // Empty filter chain - no network calls needed
description: None,
default: None,
};

View file

@ -279,7 +279,7 @@ async fn llm_chat_inner(
id: "model_listener".to_string(),
default: None,
description: None,
filter_chain: Some(fc.clone()),
input_filters: Some(fc.clone()),
};
let mut pipeline_processor = PipelineProcessor::default();

View file

@ -84,7 +84,7 @@ impl PipelineProcessor {
// #[instrument(
// skip(self, chat_history, agent_filter_chain, agent_map, request_headers),
// fields(
// filter_count = agent_filter_chain.filter_chain.as_ref().map(|fc| fc.len()).unwrap_or(0),
// filter_count = agent_filter_chain.input_filters.as_ref().map(|fc| fc.len()).unwrap_or(0),
// message_count = chat_history.len()
// )
// )]
@ -99,7 +99,7 @@ impl PipelineProcessor {
let mut chat_history_updated = chat_history.to_vec();
// If filter_chain is None or empty, proceed without filtering
let filter_chain = match agent_filter_chain.filter_chain.as_ref() {
let filter_chain = match agent_filter_chain.input_filters.as_ref() {
Some(fc) if !fc.is_empty() => fc,
_ => return Ok(chat_history_updated),
};
@ -283,6 +283,30 @@ impl PipelineProcessor {
})
}
/// Build a tools/call JSON-RPC request with a full body dict and path hint.
/// Used by execute_mcp_filter_raw so MCP tools receive the same contract as HTTP filters.
fn build_tool_call_request_with_body(
&self,
tool_name: &str,
body: &serde_json::Value,
path: &str,
) -> Result<JsonRpcRequest, PipelineError> {
let mut arguments = HashMap::new();
arguments.insert("body".to_string(), serde_json::to_value(body)?);
arguments.insert("path".to_string(), serde_json::to_value(path)?);
let mut params = HashMap::new();
params.insert("name".to_string(), serde_json::to_value(tool_name)?);
params.insert("arguments".to_string(), serde_json::to_value(arguments)?);
Ok(JsonRpcRequest {
jsonrpc: JSON_RPC_VERSION.to_string(),
id: JsonRpcId::String(Uuid::new_v4().to_string()),
method: TOOL_CALL_METHOD.to_string(),
params: Some(params),
})
}
/// Send request to a specific agent and return the response content
#[instrument(
skip(self, messages, agent, request_headers),
@ -406,6 +430,106 @@ impl PipelineProcessor {
Ok(messages)
}
/// Like execute_mcp_filter but passes the full raw body dict + path hint as MCP tool arguments.
/// The MCP tool receives (body: dict, path: str) and returns the modified body dict.
async fn execute_mcp_filter_raw(
&mut self,
raw_bytes: &[u8],
agent: &Agent,
request_headers: &HeaderMap,
request_path: &str,
) -> Result<Bytes, PipelineError> {
set_service_name(operation_component::AGENT_FILTER);
use opentelemetry::trace::get_active_span;
get_active_span(|span| {
span.update_name(format!("execute_mcp_filter_raw ({})", agent.id));
});
let body: serde_json::Value =
serde_json::from_slice(raw_bytes).map_err(PipelineError::ParseError)?;
let mcp_session_id = if let Some(session_id) = self.agent_id_session_map.get(&agent.id) {
session_id.clone()
} else {
let session_id = self.get_new_session_id(&agent.id, request_headers).await;
self.agent_id_session_map
.insert(agent.id.clone(), session_id.clone());
session_id
};
info!(
"Using MCP session ID {} for agent {}",
mcp_session_id, agent.id
);
let tool_name = agent.tool.as_deref().unwrap_or(&agent.id);
let json_rpc_request =
self.build_tool_call_request_with_body(tool_name, &body, request_path)?;
let agent_headers =
self.build_mcp_headers(request_headers, &agent.id, Some(&mcp_session_id))?;
let response = self
.send_mcp_request(&json_rpc_request, &agent_headers, &agent.id)
.await?;
let http_status = response.status();
let response_bytes = response.bytes().await?;
if !http_status.is_success() {
let error_body = String::from_utf8_lossy(&response_bytes).to_string();
return Err(if http_status.is_client_error() {
PipelineError::ClientError {
agent: agent.id.clone(),
status: http_status.as_u16(),
body: error_body,
}
} else {
PipelineError::ServerError {
agent: agent.id.clone(),
status: http_status.as_u16(),
body: error_body,
}
});
}
let data_chunk = self.parse_sse_response(&response_bytes, &agent.id)?;
let response: JsonRpcResponse = serde_json::from_str(&data_chunk)?;
let response_result = response
.result
.ok_or_else(|| PipelineError::NoResultInResponse(agent.id.clone()))?;
if response_result
.get("isError")
.and_then(|v| v.as_bool())
.unwrap_or(false)
{
let error_message = response_result
.get("content")
.and_then(|v| v.as_array())
.and_then(|arr| arr.first())
.and_then(|v| v.get("text"))
.and_then(|v| v.as_str())
.unwrap_or("unknown_error")
.to_string();
return Err(PipelineError::ClientError {
agent: agent.id.clone(),
status: hyper::StatusCode::BAD_REQUEST.as_u16(),
body: error_message,
});
}
let result = response_result
.get("structuredContent")
.and_then(|v| v.get("result"))
.cloned()
.ok_or_else(|| PipelineError::NoStructuredContentInResponse(agent.id.clone()))?;
Ok(Bytes::from(
serde_json::to_vec(&result).map_err(PipelineError::ParseError)?,
))
}
/// Build an initialize JSON-RPC request
fn build_initialize_request(&self) -> JsonRpcRequest {
JsonRpcRequest {
@ -708,7 +832,7 @@ impl PipelineProcessor {
request_headers: &HeaderMap,
request_path: &str,
) -> Result<Bytes, PipelineError> {
let filter_chain = match agent_filter_chain.filter_chain.as_ref() {
let filter_chain = match agent_filter_chain.input_filters.as_ref() {
Some(fc) if !fc.is_empty() => fc,
_ => return Ok(Bytes::copy_from_slice(raw_bytes)),
};
@ -722,16 +846,22 @@ impl PipelineProcessor {
.get(agent_name)
.ok_or_else(|| PipelineError::AgentNotFound(agent_name.clone()))?;
let agent_type = agent.agent_type.as_deref().unwrap_or("mcp");
info!(
agent = %agent_name,
url = %agent.url,
agent_type = %agent_type,
bytes_len = current_bytes.len(),
"executing raw filter"
);
current_bytes = self
.execute_raw_filter(&current_bytes, agent, request_headers, request_path)
.await?;
current_bytes = if agent_type == "mcp" {
self.execute_mcp_filter_raw(&current_bytes, agent, request_headers, request_path)
.await?
} else {
self.execute_raw_filter(&current_bytes, agent, request_headers, request_path)
.await?
};
info!(agent = %agent_name, bytes_len = current_bytes.len(), "raw filter completed");
}
@ -812,7 +942,7 @@ mod tests {
fn create_test_pipeline(agents: Vec<&str>) -> AgentFilterChain {
AgentFilterChain {
id: "test-agent".to_string(),
filter_chain: Some(agents.iter().map(|s| s.to_string()).collect()),
input_filters: Some(agents.iter().map(|s| s.to_string()).collect()),
description: None,
default: None,
}

View file

@ -308,7 +308,7 @@ where
id: "output_filter".to_string(),
default: None,
description: None,
filter_chain: Some(output_filters),
input_filters: Some(output_filters),
};
while let Some(item) = byte_stream.next().await {

View file

@ -27,7 +27,7 @@ pub struct AgentFilterChain {
pub id: String,
pub default: Option<bool>,
pub description: Option<String>,
pub filter_chain: Option<Vec<String>>,
pub input_filters: Option<Vec<String>>,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]