add support for type=rest/mcp

This commit is contained in:
Adil Hafeez 2025-12-22 15:36:00 -08:00
parent fc3045cb03
commit 6cfd630a05
No known key found for this signature in database
GPG key ID: 9B18EF7691369645
8 changed files with 292 additions and 62 deletions

View file

@ -36,6 +36,7 @@ properties:
type: string
enum:
- mcp
- rest
transport:
type: string
enum:

View file

@ -4,7 +4,7 @@ use common::configuration::{Agent, AgentFilterChain};
use common::consts::{
ARCH_UPSTREAM_HOST_HEADER, BRIGHT_STAFF_SERVICE_NAME, ENVOY_RETRY_HEADER, TRACE_PARENT_HEADER,
};
use common::traces::{SpanBuilder, SpanKind, generate_random_span_id};
use common::traces::{generate_random_span_id, SpanBuilder, SpanKind};
use hermesllm::apis::openai::Message;
use hermesllm::{ProviderRequest, ProviderRequestType};
use hyper::header::HeaderMap;
@ -216,10 +216,11 @@ impl PipelineProcessor {
let tool_name = agent.tool.as_deref().unwrap_or(&agent.id);
info!(
"executing filter: {}/{}, url: {}, conversation length: {}",
"executing filter: {}/{}, url: {}, type: {}, conversation length: {}",
agent_name,
tool_name,
agent.url,
agent.agent_type.as_deref().unwrap_or("mcp"),
chat_history.len()
);
@ -229,16 +230,29 @@ impl PipelineProcessor {
// Generate filter span ID before execution so MCP spans can use it as parent
let filter_span_id = generate_random_span_id();
chat_history_updated = self
.execute_filter(
&chat_history_updated,
agent,
request_headers,
trace_collector,
trace_id.clone(),
filter_span_id.clone(),
)
.await?;
if agent.agent_type.as_deref().unwrap_or("mcp") == "mcp" {
chat_history_updated = self
.execute_mcp_filter(
&chat_history_updated,
agent,
request_headers,
trace_collector,
trace_id.clone(),
filter_span_id.clone(),
)
.await?;
} else {
chat_history_updated = self
.execute_rest_filter(
&chat_history_updated,
agent,
request_headers,
trace_collector,
trace_id.clone(),
filter_span_id.clone(),
)
.await?;
}
let end_time = SystemTime::now();
let elapsed = start_instant.elapsed();
@ -412,7 +426,7 @@ impl PipelineProcessor {
}
/// Send request to a specific agent and return the response content
async fn execute_filter(
async fn execute_mcp_filter(
&mut self,
messages: &[Message],
agent: &Agent,
@ -426,11 +440,7 @@ impl PipelineProcessor {
session_id.clone()
} else {
let session_id = self
.get_new_session_id(
&agent.id,
trace_id.clone(),
filter_span_id.clone(),
)
.get_new_session_id(&agent.id, trace_id.clone(), filter_span_id.clone())
.await;
self.agent_id_session_map
.insert(agent.id.clone(), session_id.clone());
@ -450,19 +460,20 @@ impl PipelineProcessor {
let mcp_span_id = generate_random_span_id();
// Build headers
let agent_headers =
self.build_mcp_headers(request_headers, &agent.id, Some(&mcp_session_id), trace_id.clone(), mcp_span_id.clone())?;
let agent_headers = self.build_mcp_headers(
request_headers,
&agent.id,
Some(&mcp_session_id),
trace_id.clone(),
mcp_span_id.clone(),
)?;
// Send request with tracing
let start_time = SystemTime::now();
let start_instant = Instant::now();
let response = self
.send_mcp_request(
&json_rpc_request,
agent_headers,
&agent.id,
)
.send_mcp_request(&json_rpc_request, agent_headers, &agent.id)
.await?;
let http_status = response.status();
let response_bytes = response.bytes().await?;
@ -604,7 +615,13 @@ impl PipelineProcessor {
let notification_body = serde_json::to_string(&initialized_notification)?;
debug!("Sending initialized notification for agent {}", agent_id);
let headers = self.build_mcp_headers(&HeaderMap::new(), agent_id, Some(session_id), trace_id.clone(), parent_span_id.clone())?;
let headers = self.build_mcp_headers(
&HeaderMap::new(),
agent_id,
Some(session_id),
trace_id.clone(),
parent_span_id.clone(),
)?;
let response = self
.client
@ -632,7 +649,13 @@ impl PipelineProcessor {
let initialize_request = self.build_initialize_request();
let headers = self
.build_mcp_headers(&HeaderMap::new(), agent_id, None, trace_id.clone(), parent_span_id.clone())
.build_mcp_headers(
&HeaderMap::new(),
agent_id,
None,
trace_id.clone(),
parent_span_id.clone(),
)
.expect("Failed to build headers for initialization");
let response = self
@ -667,6 +690,129 @@ impl PipelineProcessor {
session_id
}
/// Execute a REST-based filter agent
async fn execute_rest_filter(
&mut self,
messages: &[Message],
agent: &Agent,
request_headers: &HeaderMap,
trace_collector: Option<&std::sync::Arc<common::traces::TraceCollector>>,
trace_id: String,
filter_span_id: String,
) -> Result<Vec<Message>, PipelineError> {
let tool_name = agent.tool.as_deref().unwrap_or(&agent.id);
// Generate span ID for this REST call (child of filter span)
let rest_span_id = generate_random_span_id();
// Build headers
let trace_parent = format!("00-{}-{}-01", trace_id, rest_span_id);
let mut agent_headers = request_headers.clone();
agent_headers.remove(hyper::header::CONTENT_LENGTH);
agent_headers.remove(TRACE_PARENT_HEADER);
agent_headers.insert(
TRACE_PARENT_HEADER,
hyper::header::HeaderValue::from_str(&trace_parent).unwrap(),
);
agent_headers.insert(
ARCH_UPSTREAM_HOST_HEADER,
hyper::header::HeaderValue::from_str(&agent.id)
.map_err(|_| PipelineError::AgentNotFound(agent.id.clone()))?,
);
agent_headers.insert(
ENVOY_RETRY_HEADER,
hyper::header::HeaderValue::from_str("3").unwrap(),
);
agent_headers.insert(
"Accept",
hyper::header::HeaderValue::from_static("application/json"),
);
agent_headers.insert(
"Content-Type",
hyper::header::HeaderValue::from_static("application/json"),
);
// Send request with tracing
let start_time = SystemTime::now();
let start_instant = Instant::now();
debug!(
"Sending REST request to agent {} at URL: {}",
agent.id, agent.url
);
// Send messages array directly as request body
let response = self
.client
.post(&agent.url)
.headers(agent_headers)
.json(&messages)
.send()
.await?;
let http_status = response.status();
let response_bytes = response.bytes().await?;
let end_time = SystemTime::now();
let elapsed = start_instant.elapsed();
// Record REST call span
if let Some(collector) = trace_collector {
let mut attrs = HashMap::new();
attrs.insert("rest.tool_name", tool_name.to_string());
attrs.insert("rest.url", agent.url.clone());
attrs.insert("http.status_code", http_status.as_u16().to_string());
self.record_mcp_span(
collector,
"rest_call",
&agent.id,
start_time,
end_time,
elapsed,
Some(attrs),
trace_id.clone(),
filter_span_id.clone(),
Some(rest_span_id),
);
}
// Handle HTTP errors
if !http_status.is_success() {
let error_body = String::from_utf8_lossy(&response_bytes).to_string();
return Err(if http_status.is_client_error() {
PipelineError::ClientError {
agent: agent.id.clone(),
status: http_status.as_u16(),
body: error_body,
}
} else {
PipelineError::ServerError {
agent: agent.id.clone(),
status: http_status.as_u16(),
body: error_body,
}
});
}
info!(
"Response from REST agent {}: {}",
agent.id,
String::from_utf8_lossy(&response_bytes)
);
// Parse response - expecting array of messages directly
let messages: Vec<Message> =
serde_json::from_slice(&response_bytes).map_err(PipelineError::ParseError)?;
Ok(messages)
}
/// Send request to terminal agent and return the raw response for streaming
pub async fn invoke_agent(
&self,
@ -757,7 +903,15 @@ mod tests {
let pipeline = create_test_pipeline(vec!["nonexistent-agent", "terminal-agent"]);
let result = processor
.process_filter_chain(&messages, &pipeline, &agent_map, &request_headers, None, String::new(), String::new())
.process_filter_chain(
&messages,
&pipeline,
&agent_map,
&request_headers,
None,
String::new(),
String::new(),
)
.await;
assert!(result.is_err());
@ -791,7 +945,14 @@ mod tests {
let request_headers = HeaderMap::new();
let result = processor
.execute_filter(&messages, &agent, &request_headers, None, "trace-123".to_string(), "span-123".to_string())
.execute_mcp_filter(
&messages,
&agent,
&request_headers,
None,
"trace-123".to_string(),
"span-123".to_string(),
)
.await;
match result {
@ -830,7 +991,14 @@ mod tests {
let request_headers = HeaderMap::new();
let result = processor
.execute_filter(&messages, &agent, &request_headers, None, "trace-456".to_string(), "span-456".to_string())
.execute_mcp_filter(
&messages,
&agent,
&request_headers,
None,
"trace-456".to_string(),
"span-456".to_string(),
)
.await;
match result {
@ -882,7 +1050,14 @@ mod tests {
let request_headers = HeaderMap::new();
let result = processor
.execute_filter(&messages, &agent, &request_headers, None, "trace-789".to_string(), "span-789".to_string())
.execute_mcp_filter(
&messages,
&agent,
&request_headers,
None,
"trace-789".to_string(),
"span-789".to_string(),
)
.await;
match result {

View file

@ -6,8 +6,9 @@ agents:
filters:
- id: query_rewriter
url: http://host.docker.internal:10501
# type: mcp # default is mcp
url: http://host.docker.internal:10500
type: rest
# type: rest or mcp, mcp is default
# transport: streamable-http # default is streamable-http
# tool: query_rewriter # default name is the filter id
- id: context_builder

View file

@ -54,18 +54,26 @@ def main(host, port, agent, transport, agent_name, rest_server, rest_port):
mcp_name = agent_name or default_name
if rest_server:
# Only response_generator supports REST server mode
if agent != "response_generator":
# REST server mode - supported for query_rewriter and response_generator
if agent == "response_generator":
print(f"Starting REST server on {host}:{rest_port} for agent: {agent}")
from rag_agent.rag_agent import start_server
start_server(host=host, port=rest_port)
return
elif agent == "query_rewriter":
print(f"Starting REST server on {host}:{rest_port} for agent: {agent}")
from rag_agent.query_rewriter import start_server
start_server(host=host, port=rest_port)
return
else:
print(f"Error: Agent '{agent}' does not support REST server mode.")
print(f"REST server is only supported for: response_generator")
print(
f"REST server is only supported for: query_rewriter, response_generator"
)
print(f"Remove --rest-server flag to start {agent} as an MCP server.")
return
print(f"Starting REST server on {host}:{rest_port} for agent: {agent}")
from rag_agent.rag_agent import start_server
start_server(host=host, port=rest_port)
return
else:
# Only query_rewriter and context_builder support MCP
if agent not in ["query_rewriter", "context_builder"]:

View file

@ -184,7 +184,6 @@ async def augment_query_with_context(
load_knowledge_base()
@mcp.tool()
async def context_builder(messages: List[ChatMessage]) -> List[ChatMessage]:
"""MCP tool that augments user queries with relevant context from the knowledge base."""
logger.info(f"Received chat completion request with {len(messages)} messages")
@ -203,3 +202,8 @@ async def context_builder(messages: List[ChatMessage]) -> List[ChatMessage]:
# Return as dict to minimize text serialization
return [{"role": msg.role, "content": msg.content} for msg in updated_messages]
# Register MCP tool only if mcp is available
if mcp is not None:
mcp.tool()(context_builder)

View file

@ -1,11 +1,14 @@
import asyncio
import json
import time
from typing import List, Optional, Dict, Any
import uuid
from fastapi import FastAPI, Depends, Request
from openai import AsyncOpenAI
import os
import logging
from .api import ChatMessage
from .api import ChatCompletionRequest, ChatCompletionResponse, ChatMessage
from . import mcp
from fastmcp.server.dependencies import get_http_headers
@ -16,7 +19,6 @@ logging.basicConfig(
)
logger = logging.getLogger(__name__)
# Configuration for archgw LLM gateway
LLM_GATEWAY_ENDPOINT = os.getenv("LLM_GATEWAY_ENDPOINT", "http://localhost:12000/v1")
QUERY_REWRITE_MODEL = "gpt-4o-mini"
@ -27,6 +29,8 @@ archgw_client = AsyncOpenAI(
api_key="EMPTY", # archgw doesn't require a real API key
)
app = FastAPI()
async def rewrite_query_with_archgw(
messages: List[ChatMessage], traceparent_header: str
@ -79,15 +83,11 @@ async def rewrite_query_with_archgw(
return ""
@mcp.tool()
async def query_rewriter(messages: List[ChatMessage]) -> List[ChatMessage]:
"""Chat completions endpoint that rewrites the last user query using archgw.
Returns a dict with a 'messages' key containing the updated message list.
"""
import time
import uuid
logger.info(f"Received chat completion request with {len(messages)} messages")
# Get traceparent header from HTTP request using FastMCP's dependency function
@ -117,3 +117,42 @@ async def query_rewriter(messages: List[ChatMessage]) -> List[ChatMessage]:
# Return as dict to minimize text serialization
return [{"role": msg.role, "content": msg.content} for msg in updated_messages]
# Register MCP tool only if mcp is available
if mcp is not None:
mcp.tool()(query_rewriter)
@app.post("/")
async def chat_completions_endpoint(
request_messages: List[ChatMessage], request: Request
) -> List[ChatMessage]:
"""FastAPI endpoint for chat completions with query rewriting."""
logger.info(
f"Received /v1/chat/completions request with {len(request_messages)} messages"
)
# Extract traceparent header
traceparent_header = request.headers.get("traceparent")
if traceparent_header:
logger.info(f"Received traceparent header: {traceparent_header}")
else:
logger.info("No traceparent header found")
# Call the query rewriter tool
updated_messages_data = await query_rewriter(request_messages)
# Convert back to ChatMessage objects
updated_messages = [ChatMessage(**msg) for msg in updated_messages_data]
logger.info("Returning rewritten chat completion response")
return updated_messages
def start_server(host: str = "0.0.0.0", port: int = 10501):
"""Start the FastAPI server for query rewriter."""
import uvicorn
logger.info(f"Starting Query Rewriter REST server on {host}:{port}")
uvicorn.run(app, host=host, port=port)

View file

@ -26,11 +26,16 @@ trap cleanup EXIT
# WAIT_FOR_PIDS+=($!)
log "Starting query_parser agent on port 10501..."
log "Starting query_rewriter agent on port 10500/http..."
uv run python -m rag_agent --rest-server --host 0.0.0.0 --rest-port 10500 --agent query_rewriter &
WAIT_FOR_PIDS+=($!)
log "Starting query_parser agent on port 10501/mcp..."
uv run python -m rag_agent --host 0.0.0.0 --port 10501 --agent query_rewriter &
WAIT_FOR_PIDS+=($!)
log "Starting context_builder agent on port 10502..."
log "Starting context_builder agent on port 10502/mcp..."
uv run python -m rag_agent --host 0.0.0.0 --port 10502 --agent context_builder &
WAIT_FOR_PIDS+=($!)

View file

@ -52,19 +52,16 @@ Content-Type: application/json
"stream": true
}
### send request to context builder agent
POST http://localhost:10501/v1/chat/completions
### send request to query_rewriter agent
POST http://localhost:10500/
Content-Type: application/json
{
"model": "gpt-4o-mini",
"messages": [
{
"role": "user",
"content": "What is the guaranteed uptime percentage for TechCorp's cloud services?"
}
]
}
[
{
"role": "user",
"content": "What is the guaranteed uptime percentage for TechCorp's cloud services?"
}
]
### test fast-llm
POST http://localhost:12000/v1/chat/completions