diff --git a/arch/arch_config_schema.yaml b/arch/arch_config_schema.yaml index 4fba9024..1895a819 100644 --- a/arch/arch_config_schema.yaml +++ b/arch/arch_config_schema.yaml @@ -36,6 +36,7 @@ properties: type: string enum: - mcp + - rest transport: type: string enum: diff --git a/crates/brightstaff/src/handlers/pipeline_processor.rs b/crates/brightstaff/src/handlers/pipeline_processor.rs index c0cd1cef..c1bf5ca4 100644 --- a/crates/brightstaff/src/handlers/pipeline_processor.rs +++ b/crates/brightstaff/src/handlers/pipeline_processor.rs @@ -4,7 +4,7 @@ use common::configuration::{Agent, AgentFilterChain}; use common::consts::{ ARCH_UPSTREAM_HOST_HEADER, BRIGHT_STAFF_SERVICE_NAME, ENVOY_RETRY_HEADER, TRACE_PARENT_HEADER, }; -use common::traces::{SpanBuilder, SpanKind, generate_random_span_id}; +use common::traces::{generate_random_span_id, SpanBuilder, SpanKind}; use hermesllm::apis::openai::Message; use hermesllm::{ProviderRequest, ProviderRequestType}; use hyper::header::HeaderMap; @@ -216,10 +216,11 @@ impl PipelineProcessor { let tool_name = agent.tool.as_deref().unwrap_or(&agent.id); info!( - "executing filter: {}/{}, url: {}, conversation length: {}", + "executing filter: {}/{}, url: {}, type: {}, conversation length: {}", agent_name, tool_name, agent.url, + agent.agent_type.as_deref().unwrap_or("mcp"), chat_history.len() ); @@ -229,16 +230,29 @@ impl PipelineProcessor { // Generate filter span ID before execution so MCP spans can use it as parent let filter_span_id = generate_random_span_id(); - chat_history_updated = self - .execute_filter( - &chat_history_updated, - agent, - request_headers, - trace_collector, - trace_id.clone(), - filter_span_id.clone(), - ) - .await?; + if agent.agent_type.as_deref().unwrap_or("mcp") == "mcp" { + chat_history_updated = self + .execute_mcp_filter( + &chat_history_updated, + agent, + request_headers, + trace_collector, + trace_id.clone(), + filter_span_id.clone(), + ) + .await?; + } else { + chat_history_updated = self + .execute_rest_filter( + &chat_history_updated, + agent, + request_headers, + trace_collector, + trace_id.clone(), + filter_span_id.clone(), + ) + .await?; + } let end_time = SystemTime::now(); let elapsed = start_instant.elapsed(); @@ -412,7 +426,7 @@ impl PipelineProcessor { } /// Send request to a specific agent and return the response content - async fn execute_filter( + async fn execute_mcp_filter( &mut self, messages: &[Message], agent: &Agent, @@ -426,11 +440,7 @@ impl PipelineProcessor { session_id.clone() } else { let session_id = self - .get_new_session_id( - &agent.id, - trace_id.clone(), - filter_span_id.clone(), - ) + .get_new_session_id(&agent.id, trace_id.clone(), filter_span_id.clone()) .await; self.agent_id_session_map .insert(agent.id.clone(), session_id.clone()); @@ -450,19 +460,20 @@ impl PipelineProcessor { let mcp_span_id = generate_random_span_id(); // Build headers - let agent_headers = - self.build_mcp_headers(request_headers, &agent.id, Some(&mcp_session_id), trace_id.clone(), mcp_span_id.clone())?; + let agent_headers = self.build_mcp_headers( + request_headers, + &agent.id, + Some(&mcp_session_id), + trace_id.clone(), + mcp_span_id.clone(), + )?; // Send request with tracing let start_time = SystemTime::now(); let start_instant = Instant::now(); let response = self - .send_mcp_request( - &json_rpc_request, - agent_headers, - &agent.id, - ) + .send_mcp_request(&json_rpc_request, agent_headers, &agent.id) .await?; let http_status = response.status(); let response_bytes = response.bytes().await?; @@ -604,7 +615,13 @@ impl PipelineProcessor { let notification_body = serde_json::to_string(&initialized_notification)?; debug!("Sending initialized notification for agent {}", agent_id); - let headers = self.build_mcp_headers(&HeaderMap::new(), agent_id, Some(session_id), trace_id.clone(), parent_span_id.clone())?; + let headers = self.build_mcp_headers( + &HeaderMap::new(), + agent_id, + Some(session_id), + trace_id.clone(), + parent_span_id.clone(), + )?; let response = self .client @@ -632,7 +649,13 @@ impl PipelineProcessor { let initialize_request = self.build_initialize_request(); let headers = self - .build_mcp_headers(&HeaderMap::new(), agent_id, None, trace_id.clone(), parent_span_id.clone()) + .build_mcp_headers( + &HeaderMap::new(), + agent_id, + None, + trace_id.clone(), + parent_span_id.clone(), + ) .expect("Failed to build headers for initialization"); let response = self @@ -667,6 +690,129 @@ impl PipelineProcessor { session_id } + /// Execute a REST-based filter agent + async fn execute_rest_filter( + &mut self, + messages: &[Message], + agent: &Agent, + request_headers: &HeaderMap, + trace_collector: Option<&std::sync::Arc>, + trace_id: String, + filter_span_id: String, + ) -> Result, PipelineError> { + let tool_name = agent.tool.as_deref().unwrap_or(&agent.id); + + // Generate span ID for this REST call (child of filter span) + let rest_span_id = generate_random_span_id(); + + // Build headers + let trace_parent = format!("00-{}-{}-01", trace_id, rest_span_id); + let mut agent_headers = request_headers.clone(); + agent_headers.remove(hyper::header::CONTENT_LENGTH); + + agent_headers.remove(TRACE_PARENT_HEADER); + agent_headers.insert( + TRACE_PARENT_HEADER, + hyper::header::HeaderValue::from_str(&trace_parent).unwrap(), + ); + + agent_headers.insert( + ARCH_UPSTREAM_HOST_HEADER, + hyper::header::HeaderValue::from_str(&agent.id) + .map_err(|_| PipelineError::AgentNotFound(agent.id.clone()))?, + ); + + agent_headers.insert( + ENVOY_RETRY_HEADER, + hyper::header::HeaderValue::from_str("3").unwrap(), + ); + + agent_headers.insert( + "Accept", + hyper::header::HeaderValue::from_static("application/json"), + ); + + agent_headers.insert( + "Content-Type", + hyper::header::HeaderValue::from_static("application/json"), + ); + + // Send request with tracing + let start_time = SystemTime::now(); + let start_instant = Instant::now(); + + debug!( + "Sending REST request to agent {} at URL: {}", + agent.id, agent.url + ); + + // Send messages array directly as request body + let response = self + .client + .post(&agent.url) + .headers(agent_headers) + .json(&messages) + .send() + .await?; + + let http_status = response.status(); + let response_bytes = response.bytes().await?; + + let end_time = SystemTime::now(); + let elapsed = start_instant.elapsed(); + + // Record REST call span + if let Some(collector) = trace_collector { + let mut attrs = HashMap::new(); + attrs.insert("rest.tool_name", tool_name.to_string()); + attrs.insert("rest.url", agent.url.clone()); + attrs.insert("http.status_code", http_status.as_u16().to_string()); + + self.record_mcp_span( + collector, + "rest_call", + &agent.id, + start_time, + end_time, + elapsed, + Some(attrs), + trace_id.clone(), + filter_span_id.clone(), + Some(rest_span_id), + ); + } + + // Handle HTTP errors + if !http_status.is_success() { + let error_body = String::from_utf8_lossy(&response_bytes).to_string(); + return Err(if http_status.is_client_error() { + PipelineError::ClientError { + agent: agent.id.clone(), + status: http_status.as_u16(), + body: error_body, + } + } else { + PipelineError::ServerError { + agent: agent.id.clone(), + status: http_status.as_u16(), + body: error_body, + } + }); + } + + info!( + "Response from REST agent {}: {}", + agent.id, + String::from_utf8_lossy(&response_bytes) + ); + + // Parse response - expecting array of messages directly + let messages: Vec = + serde_json::from_slice(&response_bytes).map_err(PipelineError::ParseError)?; + + Ok(messages) + } + /// Send request to terminal agent and return the raw response for streaming pub async fn invoke_agent( &self, @@ -757,7 +903,15 @@ mod tests { let pipeline = create_test_pipeline(vec!["nonexistent-agent", "terminal-agent"]); let result = processor - .process_filter_chain(&messages, &pipeline, &agent_map, &request_headers, None, String::new(), String::new()) + .process_filter_chain( + &messages, + &pipeline, + &agent_map, + &request_headers, + None, + String::new(), + String::new(), + ) .await; assert!(result.is_err()); @@ -791,7 +945,14 @@ mod tests { let request_headers = HeaderMap::new(); let result = processor - .execute_filter(&messages, &agent, &request_headers, None, "trace-123".to_string(), "span-123".to_string()) + .execute_mcp_filter( + &messages, + &agent, + &request_headers, + None, + "trace-123".to_string(), + "span-123".to_string(), + ) .await; match result { @@ -830,7 +991,14 @@ mod tests { let request_headers = HeaderMap::new(); let result = processor - .execute_filter(&messages, &agent, &request_headers, None, "trace-456".to_string(), "span-456".to_string()) + .execute_mcp_filter( + &messages, + &agent, + &request_headers, + None, + "trace-456".to_string(), + "span-456".to_string(), + ) .await; match result { @@ -882,7 +1050,14 @@ mod tests { let request_headers = HeaderMap::new(); let result = processor - .execute_filter(&messages, &agent, &request_headers, None, "trace-789".to_string(), "span-789".to_string()) + .execute_mcp_filter( + &messages, + &agent, + &request_headers, + None, + "trace-789".to_string(), + "span-789".to_string(), + ) .await; match result { diff --git a/demos/use_cases/mcp_filter/arch_config.yaml b/demos/use_cases/mcp_filter/arch_config.yaml index e8da6164..a006e969 100644 --- a/demos/use_cases/mcp_filter/arch_config.yaml +++ b/demos/use_cases/mcp_filter/arch_config.yaml @@ -6,8 +6,9 @@ agents: filters: - id: query_rewriter - url: http://host.docker.internal:10501 - # type: mcp # default is mcp + url: http://host.docker.internal:10500 + type: rest + # type: rest or mcp, mcp is default # transport: streamable-http # default is streamable-http # tool: query_rewriter # default name is the filter id - id: context_builder diff --git a/demos/use_cases/mcp_filter/src/rag_agent/__init__.py b/demos/use_cases/mcp_filter/src/rag_agent/__init__.py index 08f8e21f..aa601877 100644 --- a/demos/use_cases/mcp_filter/src/rag_agent/__init__.py +++ b/demos/use_cases/mcp_filter/src/rag_agent/__init__.py @@ -54,18 +54,26 @@ def main(host, port, agent, transport, agent_name, rest_server, rest_port): mcp_name = agent_name or default_name if rest_server: - # Only response_generator supports REST server mode - if agent != "response_generator": + # REST server mode - supported for query_rewriter and response_generator + if agent == "response_generator": + print(f"Starting REST server on {host}:{rest_port} for agent: {agent}") + from rag_agent.rag_agent import start_server + + start_server(host=host, port=rest_port) + return + elif agent == "query_rewriter": + print(f"Starting REST server on {host}:{rest_port} for agent: {agent}") + from rag_agent.query_rewriter import start_server + + start_server(host=host, port=rest_port) + return + else: print(f"Error: Agent '{agent}' does not support REST server mode.") - print(f"REST server is only supported for: response_generator") + print( + f"REST server is only supported for: query_rewriter, response_generator" + ) print(f"Remove --rest-server flag to start {agent} as an MCP server.") return - - print(f"Starting REST server on {host}:{rest_port} for agent: {agent}") - from rag_agent.rag_agent import start_server - - start_server(host=host, port=rest_port) - return else: # Only query_rewriter and context_builder support MCP if agent not in ["query_rewriter", "context_builder"]: diff --git a/demos/use_cases/mcp_filter/src/rag_agent/context_builder.py b/demos/use_cases/mcp_filter/src/rag_agent/context_builder.py index 2fa6e307..5512fcc0 100644 --- a/demos/use_cases/mcp_filter/src/rag_agent/context_builder.py +++ b/demos/use_cases/mcp_filter/src/rag_agent/context_builder.py @@ -184,7 +184,6 @@ async def augment_query_with_context( load_knowledge_base() -@mcp.tool() async def context_builder(messages: List[ChatMessage]) -> List[ChatMessage]: """MCP tool that augments user queries with relevant context from the knowledge base.""" logger.info(f"Received chat completion request with {len(messages)} messages") @@ -203,3 +202,8 @@ async def context_builder(messages: List[ChatMessage]) -> List[ChatMessage]: # Return as dict to minimize text serialization return [{"role": msg.role, "content": msg.content} for msg in updated_messages] + + +# Register MCP tool only if mcp is available +if mcp is not None: + mcp.tool()(context_builder) diff --git a/demos/use_cases/mcp_filter/src/rag_agent/query_rewriter.py b/demos/use_cases/mcp_filter/src/rag_agent/query_rewriter.py index 89e5b200..81362e1f 100644 --- a/demos/use_cases/mcp_filter/src/rag_agent/query_rewriter.py +++ b/demos/use_cases/mcp_filter/src/rag_agent/query_rewriter.py @@ -1,11 +1,14 @@ import asyncio import json +import time from typing import List, Optional, Dict, Any +import uuid +from fastapi import FastAPI, Depends, Request from openai import AsyncOpenAI import os import logging -from .api import ChatMessage +from .api import ChatCompletionRequest, ChatCompletionResponse, ChatMessage from . import mcp from fastmcp.server.dependencies import get_http_headers @@ -16,7 +19,6 @@ logging.basicConfig( ) logger = logging.getLogger(__name__) - # Configuration for archgw LLM gateway LLM_GATEWAY_ENDPOINT = os.getenv("LLM_GATEWAY_ENDPOINT", "http://localhost:12000/v1") QUERY_REWRITE_MODEL = "gpt-4o-mini" @@ -27,6 +29,8 @@ archgw_client = AsyncOpenAI( api_key="EMPTY", # archgw doesn't require a real API key ) +app = FastAPI() + async def rewrite_query_with_archgw( messages: List[ChatMessage], traceparent_header: str @@ -79,15 +83,11 @@ async def rewrite_query_with_archgw( return "" -@mcp.tool() async def query_rewriter(messages: List[ChatMessage]) -> List[ChatMessage]: """Chat completions endpoint that rewrites the last user query using archgw. Returns a dict with a 'messages' key containing the updated message list. """ - import time - import uuid - logger.info(f"Received chat completion request with {len(messages)} messages") # Get traceparent header from HTTP request using FastMCP's dependency function @@ -117,3 +117,42 @@ async def query_rewriter(messages: List[ChatMessage]) -> List[ChatMessage]: # Return as dict to minimize text serialization return [{"role": msg.role, "content": msg.content} for msg in updated_messages] + + +# Register MCP tool only if mcp is available +if mcp is not None: + mcp.tool()(query_rewriter) + + +@app.post("/") +async def chat_completions_endpoint( + request_messages: List[ChatMessage], request: Request +) -> List[ChatMessage]: + """FastAPI endpoint for chat completions with query rewriting.""" + logger.info( + f"Received /v1/chat/completions request with {len(request_messages)} messages" + ) + + # Extract traceparent header + traceparent_header = request.headers.get("traceparent") + if traceparent_header: + logger.info(f"Received traceparent header: {traceparent_header}") + else: + logger.info("No traceparent header found") + + # Call the query rewriter tool + updated_messages_data = await query_rewriter(request_messages) + + # Convert back to ChatMessage objects + updated_messages = [ChatMessage(**msg) for msg in updated_messages_data] + + logger.info("Returning rewritten chat completion response") + return updated_messages + + +def start_server(host: str = "0.0.0.0", port: int = 10501): + """Start the FastAPI server for query rewriter.""" + import uvicorn + + logger.info(f"Starting Query Rewriter REST server on {host}:{port}") + uvicorn.run(app, host=host, port=port) diff --git a/demos/use_cases/mcp_filter/start_agents.sh b/demos/use_cases/mcp_filter/start_agents.sh index e044fdda..2c2f446e 100644 --- a/demos/use_cases/mcp_filter/start_agents.sh +++ b/demos/use_cases/mcp_filter/start_agents.sh @@ -26,11 +26,16 @@ trap cleanup EXIT # WAIT_FOR_PIDS+=($!) -log "Starting query_parser agent on port 10501..." +log "Starting query_rewriter agent on port 10500/http..." +uv run python -m rag_agent --rest-server --host 0.0.0.0 --rest-port 10500 --agent query_rewriter & +WAIT_FOR_PIDS+=($!) + + +log "Starting query_parser agent on port 10501/mcp..." uv run python -m rag_agent --host 0.0.0.0 --port 10501 --agent query_rewriter & WAIT_FOR_PIDS+=($!) -log "Starting context_builder agent on port 10502..." +log "Starting context_builder agent on port 10502/mcp..." uv run python -m rag_agent --host 0.0.0.0 --port 10502 --agent context_builder & WAIT_FOR_PIDS+=($!) diff --git a/demos/use_cases/mcp_filter/test.rest b/demos/use_cases/mcp_filter/test.rest index 13d773b1..daf93b92 100644 --- a/demos/use_cases/mcp_filter/test.rest +++ b/demos/use_cases/mcp_filter/test.rest @@ -52,19 +52,16 @@ Content-Type: application/json "stream": true } -### send request to context builder agent -POST http://localhost:10501/v1/chat/completions +### send request to query_rewriter agent +POST http://localhost:10500/ Content-Type: application/json -{ - "model": "gpt-4o-mini", - "messages": [ - { - "role": "user", - "content": "What is the guaranteed uptime percentage for TechCorp's cloud services?" - } - ] -} +[ + { + "role": "user", + "content": "What is the guaranteed uptime percentage for TechCorp's cloud services?" + } +] ### test fast-llm POST http://localhost:12000/v1/chat/completions