mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
add support for type=rest/mcp
This commit is contained in:
parent
fc3045cb03
commit
6cfd630a05
8 changed files with 292 additions and 62 deletions
|
|
@ -36,6 +36,7 @@ properties:
|
|||
type: string
|
||||
enum:
|
||||
- mcp
|
||||
- rest
|
||||
transport:
|
||||
type: string
|
||||
enum:
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ use common::configuration::{Agent, AgentFilterChain};
|
|||
use common::consts::{
|
||||
ARCH_UPSTREAM_HOST_HEADER, BRIGHT_STAFF_SERVICE_NAME, ENVOY_RETRY_HEADER, TRACE_PARENT_HEADER,
|
||||
};
|
||||
use common::traces::{SpanBuilder, SpanKind, generate_random_span_id};
|
||||
use common::traces::{generate_random_span_id, SpanBuilder, SpanKind};
|
||||
use hermesllm::apis::openai::Message;
|
||||
use hermesllm::{ProviderRequest, ProviderRequestType};
|
||||
use hyper::header::HeaderMap;
|
||||
|
|
@ -216,10 +216,11 @@ impl PipelineProcessor {
|
|||
let tool_name = agent.tool.as_deref().unwrap_or(&agent.id);
|
||||
|
||||
info!(
|
||||
"executing filter: {}/{}, url: {}, conversation length: {}",
|
||||
"executing filter: {}/{}, url: {}, type: {}, conversation length: {}",
|
||||
agent_name,
|
||||
tool_name,
|
||||
agent.url,
|
||||
agent.agent_type.as_deref().unwrap_or("mcp"),
|
||||
chat_history.len()
|
||||
);
|
||||
|
||||
|
|
@ -229,16 +230,29 @@ impl PipelineProcessor {
|
|||
// Generate filter span ID before execution so MCP spans can use it as parent
|
||||
let filter_span_id = generate_random_span_id();
|
||||
|
||||
chat_history_updated = self
|
||||
.execute_filter(
|
||||
&chat_history_updated,
|
||||
agent,
|
||||
request_headers,
|
||||
trace_collector,
|
||||
trace_id.clone(),
|
||||
filter_span_id.clone(),
|
||||
)
|
||||
.await?;
|
||||
if agent.agent_type.as_deref().unwrap_or("mcp") == "mcp" {
|
||||
chat_history_updated = self
|
||||
.execute_mcp_filter(
|
||||
&chat_history_updated,
|
||||
agent,
|
||||
request_headers,
|
||||
trace_collector,
|
||||
trace_id.clone(),
|
||||
filter_span_id.clone(),
|
||||
)
|
||||
.await?;
|
||||
} else {
|
||||
chat_history_updated = self
|
||||
.execute_rest_filter(
|
||||
&chat_history_updated,
|
||||
agent,
|
||||
request_headers,
|
||||
trace_collector,
|
||||
trace_id.clone(),
|
||||
filter_span_id.clone(),
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
|
||||
let end_time = SystemTime::now();
|
||||
let elapsed = start_instant.elapsed();
|
||||
|
|
@ -412,7 +426,7 @@ impl PipelineProcessor {
|
|||
}
|
||||
|
||||
/// Send request to a specific agent and return the response content
|
||||
async fn execute_filter(
|
||||
async fn execute_mcp_filter(
|
||||
&mut self,
|
||||
messages: &[Message],
|
||||
agent: &Agent,
|
||||
|
|
@ -426,11 +440,7 @@ impl PipelineProcessor {
|
|||
session_id.clone()
|
||||
} else {
|
||||
let session_id = self
|
||||
.get_new_session_id(
|
||||
&agent.id,
|
||||
trace_id.clone(),
|
||||
filter_span_id.clone(),
|
||||
)
|
||||
.get_new_session_id(&agent.id, trace_id.clone(), filter_span_id.clone())
|
||||
.await;
|
||||
self.agent_id_session_map
|
||||
.insert(agent.id.clone(), session_id.clone());
|
||||
|
|
@ -450,19 +460,20 @@ impl PipelineProcessor {
|
|||
let mcp_span_id = generate_random_span_id();
|
||||
|
||||
// Build headers
|
||||
let agent_headers =
|
||||
self.build_mcp_headers(request_headers, &agent.id, Some(&mcp_session_id), trace_id.clone(), mcp_span_id.clone())?;
|
||||
let agent_headers = self.build_mcp_headers(
|
||||
request_headers,
|
||||
&agent.id,
|
||||
Some(&mcp_session_id),
|
||||
trace_id.clone(),
|
||||
mcp_span_id.clone(),
|
||||
)?;
|
||||
|
||||
// Send request with tracing
|
||||
let start_time = SystemTime::now();
|
||||
let start_instant = Instant::now();
|
||||
|
||||
let response = self
|
||||
.send_mcp_request(
|
||||
&json_rpc_request,
|
||||
agent_headers,
|
||||
&agent.id,
|
||||
)
|
||||
.send_mcp_request(&json_rpc_request, agent_headers, &agent.id)
|
||||
.await?;
|
||||
let http_status = response.status();
|
||||
let response_bytes = response.bytes().await?;
|
||||
|
|
@ -604,7 +615,13 @@ impl PipelineProcessor {
|
|||
let notification_body = serde_json::to_string(&initialized_notification)?;
|
||||
debug!("Sending initialized notification for agent {}", agent_id);
|
||||
|
||||
let headers = self.build_mcp_headers(&HeaderMap::new(), agent_id, Some(session_id), trace_id.clone(), parent_span_id.clone())?;
|
||||
let headers = self.build_mcp_headers(
|
||||
&HeaderMap::new(),
|
||||
agent_id,
|
||||
Some(session_id),
|
||||
trace_id.clone(),
|
||||
parent_span_id.clone(),
|
||||
)?;
|
||||
|
||||
let response = self
|
||||
.client
|
||||
|
|
@ -632,7 +649,13 @@ impl PipelineProcessor {
|
|||
|
||||
let initialize_request = self.build_initialize_request();
|
||||
let headers = self
|
||||
.build_mcp_headers(&HeaderMap::new(), agent_id, None, trace_id.clone(), parent_span_id.clone())
|
||||
.build_mcp_headers(
|
||||
&HeaderMap::new(),
|
||||
agent_id,
|
||||
None,
|
||||
trace_id.clone(),
|
||||
parent_span_id.clone(),
|
||||
)
|
||||
.expect("Failed to build headers for initialization");
|
||||
|
||||
let response = self
|
||||
|
|
@ -667,6 +690,129 @@ impl PipelineProcessor {
|
|||
session_id
|
||||
}
|
||||
|
||||
/// Execute a REST-based filter agent
|
||||
async fn execute_rest_filter(
|
||||
&mut self,
|
||||
messages: &[Message],
|
||||
agent: &Agent,
|
||||
request_headers: &HeaderMap,
|
||||
trace_collector: Option<&std::sync::Arc<common::traces::TraceCollector>>,
|
||||
trace_id: String,
|
||||
filter_span_id: String,
|
||||
) -> Result<Vec<Message>, PipelineError> {
|
||||
let tool_name = agent.tool.as_deref().unwrap_or(&agent.id);
|
||||
|
||||
// Generate span ID for this REST call (child of filter span)
|
||||
let rest_span_id = generate_random_span_id();
|
||||
|
||||
// Build headers
|
||||
let trace_parent = format!("00-{}-{}-01", trace_id, rest_span_id);
|
||||
let mut agent_headers = request_headers.clone();
|
||||
agent_headers.remove(hyper::header::CONTENT_LENGTH);
|
||||
|
||||
agent_headers.remove(TRACE_PARENT_HEADER);
|
||||
agent_headers.insert(
|
||||
TRACE_PARENT_HEADER,
|
||||
hyper::header::HeaderValue::from_str(&trace_parent).unwrap(),
|
||||
);
|
||||
|
||||
agent_headers.insert(
|
||||
ARCH_UPSTREAM_HOST_HEADER,
|
||||
hyper::header::HeaderValue::from_str(&agent.id)
|
||||
.map_err(|_| PipelineError::AgentNotFound(agent.id.clone()))?,
|
||||
);
|
||||
|
||||
agent_headers.insert(
|
||||
ENVOY_RETRY_HEADER,
|
||||
hyper::header::HeaderValue::from_str("3").unwrap(),
|
||||
);
|
||||
|
||||
agent_headers.insert(
|
||||
"Accept",
|
||||
hyper::header::HeaderValue::from_static("application/json"),
|
||||
);
|
||||
|
||||
agent_headers.insert(
|
||||
"Content-Type",
|
||||
hyper::header::HeaderValue::from_static("application/json"),
|
||||
);
|
||||
|
||||
// Send request with tracing
|
||||
let start_time = SystemTime::now();
|
||||
let start_instant = Instant::now();
|
||||
|
||||
debug!(
|
||||
"Sending REST request to agent {} at URL: {}",
|
||||
agent.id, agent.url
|
||||
);
|
||||
|
||||
// Send messages array directly as request body
|
||||
let response = self
|
||||
.client
|
||||
.post(&agent.url)
|
||||
.headers(agent_headers)
|
||||
.json(&messages)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let http_status = response.status();
|
||||
let response_bytes = response.bytes().await?;
|
||||
|
||||
let end_time = SystemTime::now();
|
||||
let elapsed = start_instant.elapsed();
|
||||
|
||||
// Record REST call span
|
||||
if let Some(collector) = trace_collector {
|
||||
let mut attrs = HashMap::new();
|
||||
attrs.insert("rest.tool_name", tool_name.to_string());
|
||||
attrs.insert("rest.url", agent.url.clone());
|
||||
attrs.insert("http.status_code", http_status.as_u16().to_string());
|
||||
|
||||
self.record_mcp_span(
|
||||
collector,
|
||||
"rest_call",
|
||||
&agent.id,
|
||||
start_time,
|
||||
end_time,
|
||||
elapsed,
|
||||
Some(attrs),
|
||||
trace_id.clone(),
|
||||
filter_span_id.clone(),
|
||||
Some(rest_span_id),
|
||||
);
|
||||
}
|
||||
|
||||
// Handle HTTP errors
|
||||
if !http_status.is_success() {
|
||||
let error_body = String::from_utf8_lossy(&response_bytes).to_string();
|
||||
return Err(if http_status.is_client_error() {
|
||||
PipelineError::ClientError {
|
||||
agent: agent.id.clone(),
|
||||
status: http_status.as_u16(),
|
||||
body: error_body,
|
||||
}
|
||||
} else {
|
||||
PipelineError::ServerError {
|
||||
agent: agent.id.clone(),
|
||||
status: http_status.as_u16(),
|
||||
body: error_body,
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
info!(
|
||||
"Response from REST agent {}: {}",
|
||||
agent.id,
|
||||
String::from_utf8_lossy(&response_bytes)
|
||||
);
|
||||
|
||||
// Parse response - expecting array of messages directly
|
||||
let messages: Vec<Message> =
|
||||
serde_json::from_slice(&response_bytes).map_err(PipelineError::ParseError)?;
|
||||
|
||||
Ok(messages)
|
||||
}
|
||||
|
||||
/// Send request to terminal agent and return the raw response for streaming
|
||||
pub async fn invoke_agent(
|
||||
&self,
|
||||
|
|
@ -757,7 +903,15 @@ mod tests {
|
|||
let pipeline = create_test_pipeline(vec!["nonexistent-agent", "terminal-agent"]);
|
||||
|
||||
let result = processor
|
||||
.process_filter_chain(&messages, &pipeline, &agent_map, &request_headers, None, String::new(), String::new())
|
||||
.process_filter_chain(
|
||||
&messages,
|
||||
&pipeline,
|
||||
&agent_map,
|
||||
&request_headers,
|
||||
None,
|
||||
String::new(),
|
||||
String::new(),
|
||||
)
|
||||
.await;
|
||||
|
||||
assert!(result.is_err());
|
||||
|
|
@ -791,7 +945,14 @@ mod tests {
|
|||
let request_headers = HeaderMap::new();
|
||||
|
||||
let result = processor
|
||||
.execute_filter(&messages, &agent, &request_headers, None, "trace-123".to_string(), "span-123".to_string())
|
||||
.execute_mcp_filter(
|
||||
&messages,
|
||||
&agent,
|
||||
&request_headers,
|
||||
None,
|
||||
"trace-123".to_string(),
|
||||
"span-123".to_string(),
|
||||
)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
|
|
@ -830,7 +991,14 @@ mod tests {
|
|||
let request_headers = HeaderMap::new();
|
||||
|
||||
let result = processor
|
||||
.execute_filter(&messages, &agent, &request_headers, None, "trace-456".to_string(), "span-456".to_string())
|
||||
.execute_mcp_filter(
|
||||
&messages,
|
||||
&agent,
|
||||
&request_headers,
|
||||
None,
|
||||
"trace-456".to_string(),
|
||||
"span-456".to_string(),
|
||||
)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
|
|
@ -882,7 +1050,14 @@ mod tests {
|
|||
let request_headers = HeaderMap::new();
|
||||
|
||||
let result = processor
|
||||
.execute_filter(&messages, &agent, &request_headers, None, "trace-789".to_string(), "span-789".to_string())
|
||||
.execute_mcp_filter(
|
||||
&messages,
|
||||
&agent,
|
||||
&request_headers,
|
||||
None,
|
||||
"trace-789".to_string(),
|
||||
"span-789".to_string(),
|
||||
)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
|
|
|
|||
|
|
@ -6,8 +6,9 @@ agents:
|
|||
|
||||
filters:
|
||||
- id: query_rewriter
|
||||
url: http://host.docker.internal:10501
|
||||
# type: mcp # default is mcp
|
||||
url: http://host.docker.internal:10500
|
||||
type: rest
|
||||
# type: rest or mcp, mcp is default
|
||||
# transport: streamable-http # default is streamable-http
|
||||
# tool: query_rewriter # default name is the filter id
|
||||
- id: context_builder
|
||||
|
|
|
|||
|
|
@ -54,18 +54,26 @@ def main(host, port, agent, transport, agent_name, rest_server, rest_port):
|
|||
mcp_name = agent_name or default_name
|
||||
|
||||
if rest_server:
|
||||
# Only response_generator supports REST server mode
|
||||
if agent != "response_generator":
|
||||
# REST server mode - supported for query_rewriter and response_generator
|
||||
if agent == "response_generator":
|
||||
print(f"Starting REST server on {host}:{rest_port} for agent: {agent}")
|
||||
from rag_agent.rag_agent import start_server
|
||||
|
||||
start_server(host=host, port=rest_port)
|
||||
return
|
||||
elif agent == "query_rewriter":
|
||||
print(f"Starting REST server on {host}:{rest_port} for agent: {agent}")
|
||||
from rag_agent.query_rewriter import start_server
|
||||
|
||||
start_server(host=host, port=rest_port)
|
||||
return
|
||||
else:
|
||||
print(f"Error: Agent '{agent}' does not support REST server mode.")
|
||||
print(f"REST server is only supported for: response_generator")
|
||||
print(
|
||||
f"REST server is only supported for: query_rewriter, response_generator"
|
||||
)
|
||||
print(f"Remove --rest-server flag to start {agent} as an MCP server.")
|
||||
return
|
||||
|
||||
print(f"Starting REST server on {host}:{rest_port} for agent: {agent}")
|
||||
from rag_agent.rag_agent import start_server
|
||||
|
||||
start_server(host=host, port=rest_port)
|
||||
return
|
||||
else:
|
||||
# Only query_rewriter and context_builder support MCP
|
||||
if agent not in ["query_rewriter", "context_builder"]:
|
||||
|
|
|
|||
|
|
@ -184,7 +184,6 @@ async def augment_query_with_context(
|
|||
load_knowledge_base()
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def context_builder(messages: List[ChatMessage]) -> List[ChatMessage]:
|
||||
"""MCP tool that augments user queries with relevant context from the knowledge base."""
|
||||
logger.info(f"Received chat completion request with {len(messages)} messages")
|
||||
|
|
@ -203,3 +202,8 @@ async def context_builder(messages: List[ChatMessage]) -> List[ChatMessage]:
|
|||
|
||||
# Return as dict to minimize text serialization
|
||||
return [{"role": msg.role, "content": msg.content} for msg in updated_messages]
|
||||
|
||||
|
||||
# Register MCP tool only if mcp is available
|
||||
if mcp is not None:
|
||||
mcp.tool()(context_builder)
|
||||
|
|
|
|||
|
|
@ -1,11 +1,14 @@
|
|||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from typing import List, Optional, Dict, Any
|
||||
import uuid
|
||||
from fastapi import FastAPI, Depends, Request
|
||||
from openai import AsyncOpenAI
|
||||
import os
|
||||
import logging
|
||||
|
||||
from .api import ChatMessage
|
||||
from .api import ChatCompletionRequest, ChatCompletionResponse, ChatMessage
|
||||
from . import mcp
|
||||
from fastmcp.server.dependencies import get_http_headers
|
||||
|
||||
|
|
@ -16,7 +19,6 @@ logging.basicConfig(
|
|||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Configuration for archgw LLM gateway
|
||||
LLM_GATEWAY_ENDPOINT = os.getenv("LLM_GATEWAY_ENDPOINT", "http://localhost:12000/v1")
|
||||
QUERY_REWRITE_MODEL = "gpt-4o-mini"
|
||||
|
|
@ -27,6 +29,8 @@ archgw_client = AsyncOpenAI(
|
|||
api_key="EMPTY", # archgw doesn't require a real API key
|
||||
)
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
async def rewrite_query_with_archgw(
|
||||
messages: List[ChatMessage], traceparent_header: str
|
||||
|
|
@ -79,15 +83,11 @@ async def rewrite_query_with_archgw(
|
|||
return ""
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def query_rewriter(messages: List[ChatMessage]) -> List[ChatMessage]:
|
||||
"""Chat completions endpoint that rewrites the last user query using archgw.
|
||||
|
||||
Returns a dict with a 'messages' key containing the updated message list.
|
||||
"""
|
||||
import time
|
||||
import uuid
|
||||
|
||||
logger.info(f"Received chat completion request with {len(messages)} messages")
|
||||
|
||||
# Get traceparent header from HTTP request using FastMCP's dependency function
|
||||
|
|
@ -117,3 +117,42 @@ async def query_rewriter(messages: List[ChatMessage]) -> List[ChatMessage]:
|
|||
|
||||
# Return as dict to minimize text serialization
|
||||
return [{"role": msg.role, "content": msg.content} for msg in updated_messages]
|
||||
|
||||
|
||||
# Register MCP tool only if mcp is available
|
||||
if mcp is not None:
|
||||
mcp.tool()(query_rewriter)
|
||||
|
||||
|
||||
@app.post("/")
|
||||
async def chat_completions_endpoint(
|
||||
request_messages: List[ChatMessage], request: Request
|
||||
) -> List[ChatMessage]:
|
||||
"""FastAPI endpoint for chat completions with query rewriting."""
|
||||
logger.info(
|
||||
f"Received /v1/chat/completions request with {len(request_messages)} messages"
|
||||
)
|
||||
|
||||
# Extract traceparent header
|
||||
traceparent_header = request.headers.get("traceparent")
|
||||
if traceparent_header:
|
||||
logger.info(f"Received traceparent header: {traceparent_header}")
|
||||
else:
|
||||
logger.info("No traceparent header found")
|
||||
|
||||
# Call the query rewriter tool
|
||||
updated_messages_data = await query_rewriter(request_messages)
|
||||
|
||||
# Convert back to ChatMessage objects
|
||||
updated_messages = [ChatMessage(**msg) for msg in updated_messages_data]
|
||||
|
||||
logger.info("Returning rewritten chat completion response")
|
||||
return updated_messages
|
||||
|
||||
|
||||
def start_server(host: str = "0.0.0.0", port: int = 10501):
|
||||
"""Start the FastAPI server for query rewriter."""
|
||||
import uvicorn
|
||||
|
||||
logger.info(f"Starting Query Rewriter REST server on {host}:{port}")
|
||||
uvicorn.run(app, host=host, port=port)
|
||||
|
|
|
|||
|
|
@ -26,11 +26,16 @@ trap cleanup EXIT
|
|||
# WAIT_FOR_PIDS+=($!)
|
||||
|
||||
|
||||
log "Starting query_parser agent on port 10501..."
|
||||
log "Starting query_rewriter agent on port 10500/http..."
|
||||
uv run python -m rag_agent --rest-server --host 0.0.0.0 --rest-port 10500 --agent query_rewriter &
|
||||
WAIT_FOR_PIDS+=($!)
|
||||
|
||||
|
||||
log "Starting query_parser agent on port 10501/mcp..."
|
||||
uv run python -m rag_agent --host 0.0.0.0 --port 10501 --agent query_rewriter &
|
||||
WAIT_FOR_PIDS+=($!)
|
||||
|
||||
log "Starting context_builder agent on port 10502..."
|
||||
log "Starting context_builder agent on port 10502/mcp..."
|
||||
uv run python -m rag_agent --host 0.0.0.0 --port 10502 --agent context_builder &
|
||||
WAIT_FOR_PIDS+=($!)
|
||||
|
||||
|
|
|
|||
|
|
@ -52,19 +52,16 @@ Content-Type: application/json
|
|||
"stream": true
|
||||
}
|
||||
|
||||
### send request to context builder agent
|
||||
POST http://localhost:10501/v1/chat/completions
|
||||
### send request to query_rewriter agent
|
||||
POST http://localhost:10500/
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"model": "gpt-4o-mini",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What is the guaranteed uptime percentage for TechCorp's cloud services?"
|
||||
}
|
||||
]
|
||||
}
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What is the guaranteed uptime percentage for TechCorp's cloud services?"
|
||||
}
|
||||
]
|
||||
|
||||
### test fast-llm
|
||||
POST http://localhost:12000/v1/chat/completions
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue