add MCP raw filter support: body+path args, update mcp_filter demo handlers

This commit is contained in:
Adil Hafeez 2026-03-17 13:40:31 -07:00
parent d26abbfb9c
commit b88bdb94f2
No known key found for this signature in database
GPG key ID: 9B18EF7691369645
16 changed files with 226 additions and 116 deletions

View file

@ -85,7 +85,7 @@ properties:
type: string type: string
default: default:
type: boolean type: boolean
filter_chain: input_filters:
type: array type: array
items: items:
type: string type: string

View file

@ -332,16 +332,39 @@ async fn handle_agent_chat_inner(
"processing agent" "processing agent"
); );
// Process the filter chain // Process input filters — serialize current request as OpenAI chat completions body,
let chat_history = pipeline_processor // pass raw bytes through each filter, then extract updated messages from the result.
.process_filter_chain( let chat_history = if selected_agent
&current_messages, .input_filters
.as_ref()
.map(|f| !f.is_empty())
.unwrap_or(false)
{
let filter_body = serde_json::json!({
"model": client_request.model(),
"messages": current_messages,
});
let filter_bytes =
serde_json::to_vec(&filter_body).map_err(PipelineError::ParseError)?;
let filtered_bytes = pipeline_processor
.process_raw_filter_chain(
&filter_bytes,
selected_agent, selected_agent,
&agent_map, &agent_map,
&request_headers, &request_headers,
"/v1/chat/completions",
) )
.await?; .await?;
let filtered_body: serde_json::Value =
serde_json::from_slice(&filtered_bytes).map_err(PipelineError::ParseError)?;
serde_json::from_value(filtered_body["messages"].clone())
.map_err(PipelineError::ParseError)?
} else {
current_messages.clone()
};
// Get agent details and invoke // Get agent details and invoke
let agent = agent_map.get(&agent_name).unwrap(); let agent = agent_map.get(&agent_name).unwrap();

View file

@ -187,7 +187,7 @@ mod tests {
id: name.to_string(), id: name.to_string(),
description: Some(description.to_string()), description: Some(description.to_string()),
default: Some(is_default), default: Some(is_default),
filter_chain: Some(vec![name.to_string()]), input_filters: Some(vec![name.to_string()]),
} }
} }

View file

@ -64,7 +64,7 @@ mod tests {
let agent_pipeline = AgentFilterChain { let agent_pipeline = AgentFilterChain {
id: "terminal-agent".to_string(), id: "terminal-agent".to_string(),
filter_chain: Some(vec![ input_filters: Some(vec![
"filter-agent".to_string(), "filter-agent".to_string(),
"terminal-agent".to_string(), "terminal-agent".to_string(),
]), ]),
@ -110,7 +110,7 @@ mod tests {
// Create a pipeline with empty filter chain to avoid network calls // Create a pipeline with empty filter chain to avoid network calls
let test_pipeline = AgentFilterChain { let test_pipeline = AgentFilterChain {
id: "terminal-agent".to_string(), id: "terminal-agent".to_string(),
filter_chain: Some(vec![]), // Empty filter chain - no network calls needed input_filters: Some(vec![]), // Empty filter chain - no network calls needed
description: None, description: None,
default: None, default: None,
}; };

View file

@ -279,7 +279,7 @@ async fn llm_chat_inner(
id: "model_listener".to_string(), id: "model_listener".to_string(),
default: None, default: None,
description: None, description: None,
filter_chain: Some(fc.clone()), input_filters: Some(fc.clone()),
}; };
let mut pipeline_processor = PipelineProcessor::default(); let mut pipeline_processor = PipelineProcessor::default();

View file

@ -84,7 +84,7 @@ impl PipelineProcessor {
// #[instrument( // #[instrument(
// skip(self, chat_history, agent_filter_chain, agent_map, request_headers), // skip(self, chat_history, agent_filter_chain, agent_map, request_headers),
// fields( // fields(
// filter_count = agent_filter_chain.filter_chain.as_ref().map(|fc| fc.len()).unwrap_or(0), // filter_count = agent_filter_chain.input_filters.as_ref().map(|fc| fc.len()).unwrap_or(0),
// message_count = chat_history.len() // message_count = chat_history.len()
// ) // )
// )] // )]
@ -99,7 +99,7 @@ impl PipelineProcessor {
let mut chat_history_updated = chat_history.to_vec(); let mut chat_history_updated = chat_history.to_vec();
// If filter_chain is None or empty, proceed without filtering // If filter_chain is None or empty, proceed without filtering
let filter_chain = match agent_filter_chain.filter_chain.as_ref() { let filter_chain = match agent_filter_chain.input_filters.as_ref() {
Some(fc) if !fc.is_empty() => fc, Some(fc) if !fc.is_empty() => fc,
_ => return Ok(chat_history_updated), _ => return Ok(chat_history_updated),
}; };
@ -283,6 +283,30 @@ impl PipelineProcessor {
}) })
} }
/// Build a tools/call JSON-RPC request with a full body dict and path hint.
/// Used by execute_mcp_filter_raw so MCP tools receive the same contract as HTTP filters.
fn build_tool_call_request_with_body(
&self,
tool_name: &str,
body: &serde_json::Value,
path: &str,
) -> Result<JsonRpcRequest, PipelineError> {
let mut arguments = HashMap::new();
arguments.insert("body".to_string(), serde_json::to_value(body)?);
arguments.insert("path".to_string(), serde_json::to_value(path)?);
let mut params = HashMap::new();
params.insert("name".to_string(), serde_json::to_value(tool_name)?);
params.insert("arguments".to_string(), serde_json::to_value(arguments)?);
Ok(JsonRpcRequest {
jsonrpc: JSON_RPC_VERSION.to_string(),
id: JsonRpcId::String(Uuid::new_v4().to_string()),
method: TOOL_CALL_METHOD.to_string(),
params: Some(params),
})
}
/// Send request to a specific agent and return the response content /// Send request to a specific agent and return the response content
#[instrument( #[instrument(
skip(self, messages, agent, request_headers), skip(self, messages, agent, request_headers),
@ -406,6 +430,106 @@ impl PipelineProcessor {
Ok(messages) Ok(messages)
} }
/// Like execute_mcp_filter but passes the full raw body dict + path hint as MCP tool arguments.
/// The MCP tool receives (body: dict, path: str) and returns the modified body dict.
async fn execute_mcp_filter_raw(
&mut self,
raw_bytes: &[u8],
agent: &Agent,
request_headers: &HeaderMap,
request_path: &str,
) -> Result<Bytes, PipelineError> {
set_service_name(operation_component::AGENT_FILTER);
use opentelemetry::trace::get_active_span;
get_active_span(|span| {
span.update_name(format!("execute_mcp_filter_raw ({})", agent.id));
});
let body: serde_json::Value =
serde_json::from_slice(raw_bytes).map_err(PipelineError::ParseError)?;
let mcp_session_id = if let Some(session_id) = self.agent_id_session_map.get(&agent.id) {
session_id.clone()
} else {
let session_id = self.get_new_session_id(&agent.id, request_headers).await;
self.agent_id_session_map
.insert(agent.id.clone(), session_id.clone());
session_id
};
info!(
"Using MCP session ID {} for agent {}",
mcp_session_id, agent.id
);
let tool_name = agent.tool.as_deref().unwrap_or(&agent.id);
let json_rpc_request =
self.build_tool_call_request_with_body(tool_name, &body, request_path)?;
let agent_headers =
self.build_mcp_headers(request_headers, &agent.id, Some(&mcp_session_id))?;
let response = self
.send_mcp_request(&json_rpc_request, &agent_headers, &agent.id)
.await?;
let http_status = response.status();
let response_bytes = response.bytes().await?;
if !http_status.is_success() {
let error_body = String::from_utf8_lossy(&response_bytes).to_string();
return Err(if http_status.is_client_error() {
PipelineError::ClientError {
agent: agent.id.clone(),
status: http_status.as_u16(),
body: error_body,
}
} else {
PipelineError::ServerError {
agent: agent.id.clone(),
status: http_status.as_u16(),
body: error_body,
}
});
}
let data_chunk = self.parse_sse_response(&response_bytes, &agent.id)?;
let response: JsonRpcResponse = serde_json::from_str(&data_chunk)?;
let response_result = response
.result
.ok_or_else(|| PipelineError::NoResultInResponse(agent.id.clone()))?;
if response_result
.get("isError")
.and_then(|v| v.as_bool())
.unwrap_or(false)
{
let error_message = response_result
.get("content")
.and_then(|v| v.as_array())
.and_then(|arr| arr.first())
.and_then(|v| v.get("text"))
.and_then(|v| v.as_str())
.unwrap_or("unknown_error")
.to_string();
return Err(PipelineError::ClientError {
agent: agent.id.clone(),
status: hyper::StatusCode::BAD_REQUEST.as_u16(),
body: error_message,
});
}
let result = response_result
.get("structuredContent")
.and_then(|v| v.get("result"))
.cloned()
.ok_or_else(|| PipelineError::NoStructuredContentInResponse(agent.id.clone()))?;
Ok(Bytes::from(
serde_json::to_vec(&result).map_err(PipelineError::ParseError)?,
))
}
/// Build an initialize JSON-RPC request /// Build an initialize JSON-RPC request
fn build_initialize_request(&self) -> JsonRpcRequest { fn build_initialize_request(&self) -> JsonRpcRequest {
JsonRpcRequest { JsonRpcRequest {
@ -708,7 +832,7 @@ impl PipelineProcessor {
request_headers: &HeaderMap, request_headers: &HeaderMap,
request_path: &str, request_path: &str,
) -> Result<Bytes, PipelineError> { ) -> Result<Bytes, PipelineError> {
let filter_chain = match agent_filter_chain.filter_chain.as_ref() { let filter_chain = match agent_filter_chain.input_filters.as_ref() {
Some(fc) if !fc.is_empty() => fc, Some(fc) if !fc.is_empty() => fc,
_ => return Ok(Bytes::copy_from_slice(raw_bytes)), _ => return Ok(Bytes::copy_from_slice(raw_bytes)),
}; };
@ -722,16 +846,22 @@ impl PipelineProcessor {
.get(agent_name) .get(agent_name)
.ok_or_else(|| PipelineError::AgentNotFound(agent_name.clone()))?; .ok_or_else(|| PipelineError::AgentNotFound(agent_name.clone()))?;
let agent_type = agent.agent_type.as_deref().unwrap_or("mcp");
info!( info!(
agent = %agent_name, agent = %agent_name,
url = %agent.url, url = %agent.url,
agent_type = %agent_type,
bytes_len = current_bytes.len(), bytes_len = current_bytes.len(),
"executing raw filter" "executing raw filter"
); );
current_bytes = self current_bytes = if agent_type == "mcp" {
.execute_raw_filter(&current_bytes, agent, request_headers, request_path) self.execute_mcp_filter_raw(&current_bytes, agent, request_headers, request_path)
.await?; .await?
} else {
self.execute_raw_filter(&current_bytes, agent, request_headers, request_path)
.await?
};
info!(agent = %agent_name, bytes_len = current_bytes.len(), "raw filter completed"); info!(agent = %agent_name, bytes_len = current_bytes.len(), "raw filter completed");
} }
@ -812,7 +942,7 @@ mod tests {
fn create_test_pipeline(agents: Vec<&str>) -> AgentFilterChain { fn create_test_pipeline(agents: Vec<&str>) -> AgentFilterChain {
AgentFilterChain { AgentFilterChain {
id: "test-agent".to_string(), id: "test-agent".to_string(),
filter_chain: Some(agents.iter().map(|s| s.to_string()).collect()), input_filters: Some(agents.iter().map(|s| s.to_string()).collect()),
description: None, description: None,
default: None, default: None,
} }

View file

@ -308,7 +308,7 @@ where
id: "output_filter".to_string(), id: "output_filter".to_string(),
default: None, default: None,
description: None, description: None,
filter_chain: Some(output_filters), input_filters: Some(output_filters),
}; };
while let Some(item) = byte_stream.next().await { while let Some(item) = byte_stream.next().await {

View file

@ -27,7 +27,7 @@ pub struct AgentFilterChain {
pub id: String, pub id: String,
pub default: Option<bool>, pub default: Option<bool>,
pub description: Option<String>, pub description: Option<String>,
pub filter_chain: Option<Vec<String>>, pub input_filters: Option<Vec<String>>,
} }
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]

View file

@ -42,7 +42,7 @@ listeners:
agents: agents:
- id: rag_agent - id: rag_agent
description: virtual assistant for retrieval augmented generation tasks description: virtual assistant for retrieval augmented generation tasks
filter_chain: input_filters:
- input_guards - input_guards
- query_rewriter - query_rewriter
- context_builder - context_builder

View file

@ -195,11 +195,11 @@ async def augment_query_with_context(
load_knowledge_base() load_knowledge_base()
@app.post("/") @app.post("/{path:path}")
async def context_builder( async def context_builder(path: str, request: Request) -> dict:
messages: List[ChatMessage], request: Request
) -> List[ChatMessage]:
"""MCP tool that augments user queries with relevant context from the knowledge base.""" """MCP tool that augments user queries with relevant context from the knowledge base."""
body = await request.json()
messages = [ChatMessage(**m) for m in body.get("messages", [])]
logger.info(f"Received chat completion request with {len(messages)} messages") logger.info(f"Received chat completion request with {len(messages)} messages")
# Get traceparent header from MCP request # Get traceparent header from MCP request
@ -219,8 +219,7 @@ async def context_builder(
messages, traceparent_header, request_id messages, traceparent_header, request_id
) )
# Return as dict to minimize text serialization return {**body, "messages": [{"role": msg.role, "content": msg.content} for msg in updated_messages]}
return [{"role": msg.role, "content": msg.content} for msg in updated_messages]
# Register MCP tool only if mcp is available # Register MCP tool only if mcp is available

View file

@ -126,14 +126,14 @@ Respond in JSON format:
# @mcp.tool # @mcp.tool
@app.post("/") @app.post("/{path:path}")
async def input_guards( async def input_guards(path: str, request: Request) -> dict:
messages: List[ChatMessage], request: Request
) -> List[ChatMessage]:
"""Input guard that validates queries are within TechCorp's domain. """Input guard that validates queries are within TechCorp's domain.
If the query is out of scope, replaces the user message with a rejection notice. If the query is out of scope, replaces the user message with a rejection notice.
""" """
body = await request.json()
messages = [ChatMessage(**m) for m in body.get("messages", [])]
logger.info(f"Received request with {len(messages)} messages") logger.info(f"Received request with {len(messages)} messages")
# Get traceparent header from HTTP request using FastMCP's dependency function # Get traceparent header from HTTP request using FastMCP's dependency function
@ -164,7 +164,7 @@ async def input_guards(
) )
logger.info("Query validation passed - forwarding to next filter") logger.info("Query validation passed - forwarding to next filter")
return messages return body
@app.get("/health") @app.get("/health")

View file

@ -81,11 +81,11 @@ Return only the rewritten query, nothing else."""
return "" return ""
@app.post("/") @app.post("/{path:path}")
async def query_rewriter_http( async def query_rewriter_http(path: str, request: Request) -> dict:
messages: List[ChatMessage], request: Request
) -> List[ChatMessage]:
"""HTTP filter endpoint used by Plano (type: http).""" """HTTP filter endpoint used by Plano (type: http)."""
body = await request.json()
messages = [ChatMessage(**m) for m in body.get("messages", [])]
logger.info(f"Received request with {len(messages)} messages") logger.info(f"Received request with {len(messages)} messages")
traceparent_header = request.headers.get("traceparent") traceparent_header = request.headers.get("traceparent")
@ -99,25 +99,20 @@ async def query_rewriter_http(
rewritten_query = await rewrite_query_with_plano( rewritten_query = await rewrite_query_with_plano(
messages, traceparent_header, request_id messages, traceparent_header, request_id
) )
# Create updated messages with the rewritten query
updated_messages = messages.copy()
# Find and update the last user message with the rewritten query # Find and update the last user message with the rewritten query
updated_messages = [m.model_dump() for m in messages]
for i in range(len(updated_messages) - 1, -1, -1): for i in range(len(updated_messages) - 1, -1, -1):
if updated_messages[i].role == "user": if updated_messages[i]["role"] == "user":
original_query = updated_messages[i].content original_query = updated_messages[i]["content"]
updated_messages[i] = ChatMessage(role="user", content=rewritten_query) updated_messages[i]["content"] = rewritten_query
logger.info( logger.info(
f"Updated user query from '{original_query}' to '{rewritten_query}'" f"Updated user query from '{original_query}' to '{rewritten_query}'"
) )
break break
updated_messages_data = [
{"role": msg.role, "content": msg.content} for msg in updated_messages
]
updated_messages = [ChatMessage(**msg) for msg in updated_messages_data]
logger.info("Returning rewritten chat completion response") logger.info("Returning rewritten chat completion response")
return updated_messages return {**body, "messages": updated_messages}
@app.get("/health") @app.get("/health")

View file

@ -39,7 +39,7 @@ listeners:
agents: agents:
- id: rag_agent - id: rag_agent
description: virtual assistant for retrieval augmented generation tasks description: virtual assistant for retrieval augmented generation tasks
filter_chain: input_filters:
- input_guards - input_guards
- query_rewriter - query_rewriter
- context_builder - context_builder

View file

@ -195,9 +195,14 @@ async def augment_query_with_context(
load_knowledge_base() load_knowledge_base()
async def context_builder(messages: List[ChatMessage]) -> List[ChatMessage]: async def context_builder(body: dict, path: str) -> dict:
"""MCP tool that augments user queries with relevant context from the knowledge base.""" """MCP tool that augments user queries with relevant context from the knowledge base.
logger.info(f"Received chat completion request with {len(messages)} messages")
Receives the full request body dict and the API path hint (e.g. /v1/chat/completions).
Returns the body with the last user message augmented with retrieved context.
"""
messages = [ChatMessage(**m) for m in body.get("messages", [])]
logger.info(f"Received request with {len(messages)} messages at path {path}")
# Get traceparent header from MCP request # Get traceparent header from MCP request
headers = get_http_headers() headers = get_http_headers()
@ -215,8 +220,7 @@ async def context_builder(messages: List[ChatMessage]) -> List[ChatMessage]:
messages, traceparent_header, request_id messages, traceparent_header, request_id
) )
# Return as dict to minimize text serialization return {**body, "messages": [{"role": msg.role, "content": msg.content} for msg in updated_messages]}
return [{"role": msg.role, "content": msg.content} for msg in updated_messages]
# Register MCP tool only if mcp is available # Register MCP tool only if mcp is available

View file

@ -3,13 +3,12 @@ import json
import time import time
from typing import List, Optional, Dict, Any from typing import List, Optional, Dict, Any
import uuid import uuid
from fastapi import FastAPI, Depends, Request
from fastmcp.exceptions import ToolError from fastmcp.exceptions import ToolError
from openai import AsyncOpenAI from openai import AsyncOpenAI
import os import os
import logging import logging
from .api import ChatCompletionRequest, ChatCompletionResponse, ChatMessage from .api import ChatMessage
from . import mcp from . import mcp
from fastmcp.server.dependencies import get_http_headers from fastmcp.server.dependencies import get_http_headers
@ -30,8 +29,6 @@ plano_client = AsyncOpenAI(
api_key="EMPTY", # Plano doesn't require a real API key api_key="EMPTY", # Plano doesn't require a real API key
) )
app = FastAPI()
async def validate_query_scope( async def validate_query_scope(
messages: List[ChatMessage], messages: List[ChatMessage],
@ -127,12 +124,14 @@ Respond in JSON format:
@mcp.tool @mcp.tool
async def input_guards(messages: List[ChatMessage]) -> List[ChatMessage]: async def input_guards(body: dict, path: str) -> dict:
"""Input guard that validates queries are within TechCorp's domain. """Input guard that validates queries are within TechCorp's domain.
If the query is out of scope, replaces the user message with a rejection notice. Receives the full request body dict and the API path hint (e.g. /v1/chat/completions).
If the query is out of scope, raises a ToolError to block the request.
""" """
logger.info(f"Received request with {len(messages)} messages") messages = [ChatMessage(**m) for m in body.get("messages", [])]
logger.info(f"Received request with {len(messages)} messages at path {path}")
# Get traceparent header from HTTP request using FastMCP's dependency function # Get traceparent header from HTTP request using FastMCP's dependency function
headers = get_http_headers() headers = get_http_headers()
@ -153,9 +152,8 @@ async def input_guards(messages: List[ChatMessage]) -> List[ChatMessage]:
reason = validation_result.get("reason", "Query is outside TechCorp's domain") reason = validation_result.get("reason", "Query is outside TechCorp's domain")
logger.warning(f"Query rejected: {reason}") logger.warning(f"Query rejected: {reason}")
# Throw ToolError
error_message = f"I apologize, but I can only assist with questions related to TechCorp and its services. Your query appears to be outside this scope. {reason}\n\nPlease ask me about TechCorp's products, services, pricing, SLAs, or technical support." error_message = f"I apologize, but I can only assist with questions related to TechCorp and its services. Your query appears to be outside this scope. {reason}\n\nPlease ask me about TechCorp's products, services, pricing, SLAs, or technical support."
raise ToolError(error_message) raise ToolError(error_message)
logger.info("Query validation passed - forwarding to next filter") logger.info("Query validation passed - forwarding to next filter")
return messages return body

View file

@ -3,12 +3,11 @@ import json
import time import time
from typing import List, Optional, Dict, Any from typing import List, Optional, Dict, Any
import uuid import uuid
from fastapi import FastAPI, Depends, Request
from openai import AsyncOpenAI from openai import AsyncOpenAI
import os import os
import logging import logging
from .api import ChatCompletionRequest, ChatCompletionResponse, ChatMessage from .api import ChatMessage
from . import mcp from . import mcp
from fastmcp.server.dependencies import get_http_headers from fastmcp.server.dependencies import get_http_headers
@ -29,9 +28,6 @@ plano_client = AsyncOpenAI(
api_key="EMPTY", # Plano doesn't require a real API key api_key="EMPTY", # Plano doesn't require a real API key
) )
app = FastAPI()
async def rewrite_query_with_plano( async def rewrite_query_with_plano(
messages: List[ChatMessage], messages: List[ChatMessage],
traceparent_header: str, traceparent_header: str,
@ -87,12 +83,14 @@ async def rewrite_query_with_plano(
return "" return ""
async def query_rewriter(messages: List[ChatMessage]) -> List[ChatMessage]: async def query_rewriter(body: dict, path: str) -> dict:
"""Chat completions endpoint that rewrites the last user query using Plano. """Rewrites the last user query in the request body using Plano.
Returns a dict with a 'messages' key containing the updated message list. Receives the full request body dict and the API path hint (e.g. /v1/chat/completions).
Returns the body with the last user message rewritten for better retrieval.
""" """
logger.info(f"Received chat completion request with {len(messages)} messages") messages = [ChatMessage(**m) for m in body.get("messages", [])]
logger.info(f"Received request with {len(messages)} messages at path {path}")
# Get traceparent header from HTTP request using FastMCP's dependency function # Get traceparent header from HTTP request using FastMCP's dependency function
headers = get_http_headers() headers = get_http_headers()
@ -109,57 +107,20 @@ async def query_rewriter(messages: List[ChatMessage]) -> List[ChatMessage]:
messages, traceparent_header, request_id messages, traceparent_header, request_id
) )
# Create updated messages with the rewritten query
updated_messages = messages.copy()
# Find and update the last user message with the rewritten query # Find and update the last user message with the rewritten query
updated_messages = [m.model_dump() for m in messages]
for i in range(len(updated_messages) - 1, -1, -1): for i in range(len(updated_messages) - 1, -1, -1):
if updated_messages[i].role == "user": if updated_messages[i]["role"] == "user":
original_query = updated_messages[i].content
updated_messages[i] = ChatMessage(role="user", content=rewritten_query)
logger.info( logger.info(
f"Updated user query from '{original_query}' to '{rewritten_query}'" f"Updated user query from '{updated_messages[i]['content']}' to '{rewritten_query}'"
) )
updated_messages[i]["content"] = rewritten_query
break break
# Return as dict to minimize text serialization logger.info("Returning rewritten chat completion response")
return [{"role": msg.role, "content": msg.content} for msg in updated_messages] return {**body, "messages": updated_messages}
# Register MCP tool only if mcp is available # Register MCP tool only if mcp is available
if mcp is not None: if mcp is not None:
mcp.tool()(query_rewriter) mcp.tool()(query_rewriter)
@app.post("/")
async def chat_completions_endpoint(
request_messages: List[ChatMessage], request: Request
) -> List[ChatMessage]:
"""FastAPI endpoint for chat completions with query rewriting."""
logger.info(
f"Received /v1/chat/completions request with {len(request_messages)} messages"
)
# Extract traceparent header
traceparent_header = request.headers.get("traceparent")
if traceparent_header:
logger.info(f"Received traceparent header: {traceparent_header}")
else:
logger.info("No traceparent header found")
# Call the query rewriter tool
updated_messages_data = await query_rewriter(request_messages)
# Convert back to ChatMessage objects
updated_messages = [ChatMessage(**msg) for msg in updated_messages_data]
logger.info("Returning rewritten chat completion response")
return updated_messages
def start_server(host: str = "0.0.0.0", port: int = 10501):
"""Start the FastAPI server for query rewriter."""
import uvicorn
logger.info(f"Starting Query Rewriter REST server on {host}:{port}")
uvicorn.run(app, host=host, port=port)