add MCP raw filter support: body+path args, update mcp_filter demo handlers

This commit is contained in:
Adil Hafeez 2026-03-17 13:40:31 -07:00
parent d26abbfb9c
commit b88bdb94f2
No known key found for this signature in database
GPG key ID: 9B18EF7691369645
16 changed files with 226 additions and 116 deletions

View file

@ -42,7 +42,7 @@ listeners:
agents:
- id: rag_agent
description: virtual assistant for retrieval augmented generation tasks
filter_chain:
input_filters:
- input_guards
- query_rewriter
- context_builder

View file

@ -195,11 +195,11 @@ async def augment_query_with_context(
load_knowledge_base()
@app.post("/")
async def context_builder(
messages: List[ChatMessage], request: Request
) -> List[ChatMessage]:
@app.post("/{path:path}")
async def context_builder(path: str, request: Request) -> dict:
"""MCP tool that augments user queries with relevant context from the knowledge base."""
body = await request.json()
messages = [ChatMessage(**m) for m in body.get("messages", [])]
logger.info(f"Received chat completion request with {len(messages)} messages")
# Get traceparent header from MCP request
@ -219,8 +219,7 @@ async def context_builder(
messages, traceparent_header, request_id
)
# Return as dict to minimize text serialization
return [{"role": msg.role, "content": msg.content} for msg in updated_messages]
return {**body, "messages": [{"role": msg.role, "content": msg.content} for msg in updated_messages]}
# Register MCP tool only if mcp is available

View file

@ -126,14 +126,14 @@ Respond in JSON format:
# @mcp.tool
@app.post("/")
async def input_guards(
messages: List[ChatMessage], request: Request
) -> List[ChatMessage]:
@app.post("/{path:path}")
async def input_guards(path: str, request: Request) -> dict:
"""Input guard that validates queries are within TechCorp's domain.
If the query is out of scope, replaces the user message with a rejection notice.
"""
body = await request.json()
messages = [ChatMessage(**m) for m in body.get("messages", [])]
logger.info(f"Received request with {len(messages)} messages")
# Get traceparent header from HTTP request using FastMCP's dependency function
@ -164,7 +164,7 @@ async def input_guards(
)
logger.info("Query validation passed - forwarding to next filter")
return messages
return body
@app.get("/health")

View file

@ -81,11 +81,11 @@ Return only the rewritten query, nothing else."""
return ""
@app.post("/")
async def query_rewriter_http(
messages: List[ChatMessage], request: Request
) -> List[ChatMessage]:
@app.post("/{path:path}")
async def query_rewriter_http(path: str, request: Request) -> dict:
"""HTTP filter endpoint used by Plano (type: http)."""
body = await request.json()
messages = [ChatMessage(**m) for m in body.get("messages", [])]
logger.info(f"Received request with {len(messages)} messages")
traceparent_header = request.headers.get("traceparent")
@ -99,25 +99,20 @@ async def query_rewriter_http(
rewritten_query = await rewrite_query_with_plano(
messages, traceparent_header, request_id
)
# Create updated messages with the rewritten query
updated_messages = messages.copy()
# Find and update the last user message with the rewritten query
updated_messages = [m.model_dump() for m in messages]
for i in range(len(updated_messages) - 1, -1, -1):
if updated_messages[i].role == "user":
original_query = updated_messages[i].content
updated_messages[i] = ChatMessage(role="user", content=rewritten_query)
if updated_messages[i]["role"] == "user":
original_query = updated_messages[i]["content"]
updated_messages[i]["content"] = rewritten_query
logger.info(
f"Updated user query from '{original_query}' to '{rewritten_query}'"
)
break
updated_messages_data = [
{"role": msg.role, "content": msg.content} for msg in updated_messages
]
updated_messages = [ChatMessage(**msg) for msg in updated_messages_data]
logger.info("Returning rewritten chat completion response")
return updated_messages
return {**body, "messages": updated_messages}
@app.get("/health")

View file

@ -39,7 +39,7 @@ listeners:
agents:
- id: rag_agent
description: virtual assistant for retrieval augmented generation tasks
filter_chain:
input_filters:
- input_guards
- query_rewriter
- context_builder

View file

@ -195,9 +195,14 @@ async def augment_query_with_context(
load_knowledge_base()
async def context_builder(messages: List[ChatMessage]) -> List[ChatMessage]:
"""MCP tool that augments user queries with relevant context from the knowledge base."""
logger.info(f"Received chat completion request with {len(messages)} messages")
async def context_builder(body: dict, path: str) -> dict:
"""MCP tool that augments user queries with relevant context from the knowledge base.
Receives the full request body dict and the API path hint (e.g. /v1/chat/completions).
Returns the body with the last user message augmented with retrieved context.
"""
messages = [ChatMessage(**m) for m in body.get("messages", [])]
logger.info(f"Received request with {len(messages)} messages at path {path}")
# Get traceparent header from MCP request
headers = get_http_headers()
@ -215,8 +220,7 @@ async def context_builder(messages: List[ChatMessage]) -> List[ChatMessage]:
messages, traceparent_header, request_id
)
# Return as dict to minimize text serialization
return [{"role": msg.role, "content": msg.content} for msg in updated_messages]
return {**body, "messages": [{"role": msg.role, "content": msg.content} for msg in updated_messages]}
# Register MCP tool only if mcp is available

View file

@ -3,13 +3,12 @@ import json
import time
from typing import List, Optional, Dict, Any
import uuid
from fastapi import FastAPI, Depends, Request
from fastmcp.exceptions import ToolError
from openai import AsyncOpenAI
import os
import logging
from .api import ChatCompletionRequest, ChatCompletionResponse, ChatMessage
from .api import ChatMessage
from . import mcp
from fastmcp.server.dependencies import get_http_headers
@ -30,8 +29,6 @@ plano_client = AsyncOpenAI(
api_key="EMPTY", # Plano doesn't require a real API key
)
app = FastAPI()
async def validate_query_scope(
messages: List[ChatMessage],
@ -127,12 +124,14 @@ Respond in JSON format:
@mcp.tool
async def input_guards(messages: List[ChatMessage]) -> List[ChatMessage]:
async def input_guards(body: dict, path: str) -> dict:
"""Input guard that validates queries are within TechCorp's domain.
If the query is out of scope, replaces the user message with a rejection notice.
Receives the full request body dict and the API path hint (e.g. /v1/chat/completions).
If the query is out of scope, raises a ToolError to block the request.
"""
logger.info(f"Received request with {len(messages)} messages")
messages = [ChatMessage(**m) for m in body.get("messages", [])]
logger.info(f"Received request with {len(messages)} messages at path {path}")
# Get traceparent header from HTTP request using FastMCP's dependency function
headers = get_http_headers()
@ -153,9 +152,8 @@ async def input_guards(messages: List[ChatMessage]) -> List[ChatMessage]:
reason = validation_result.get("reason", "Query is outside TechCorp's domain")
logger.warning(f"Query rejected: {reason}")
# Raise ToolError so the out-of-scope request is blocked instead of forwarded
error_message = f"I apologize, but I can only assist with questions related to TechCorp and its services. Your query appears to be outside this scope. {reason}\n\nPlease ask me about TechCorp's products, services, pricing, SLAs, or technical support."
raise ToolError(error_message)
logger.info("Query validation passed - forwarding to next filter")
return messages
return body

View file

@ -3,12 +3,11 @@ import json
import time
from typing import List, Optional, Dict, Any
import uuid
from fastapi import FastAPI, Depends, Request
from openai import AsyncOpenAI
import os
import logging
from .api import ChatCompletionRequest, ChatCompletionResponse, ChatMessage
from .api import ChatMessage
from . import mcp
from fastmcp.server.dependencies import get_http_headers
@ -29,9 +28,6 @@ plano_client = AsyncOpenAI(
api_key="EMPTY", # Plano doesn't require a real API key
)
app = FastAPI()
async def rewrite_query_with_plano(
messages: List[ChatMessage],
traceparent_header: str,
@ -87,12 +83,14 @@ async def rewrite_query_with_plano(
return ""
async def query_rewriter(messages: List[ChatMessage]) -> List[ChatMessage]:
"""Chat completions endpoint that rewrites the last user query using Plano.
async def query_rewriter(body: dict, path: str) -> dict:
"""Rewrites the last user query in the request body using Plano.
Returns a dict with a 'messages' key containing the updated message list.
Receives the full request body dict and the API path hint (e.g. /v1/chat/completions).
Returns the body with the last user message rewritten for better retrieval.
"""
logger.info(f"Received chat completion request with {len(messages)} messages")
messages = [ChatMessage(**m) for m in body.get("messages", [])]
logger.info(f"Received request with {len(messages)} messages at path {path}")
# Get traceparent header from HTTP request using FastMCP's dependency function
headers = get_http_headers()
@ -109,57 +107,20 @@ async def query_rewriter(messages: List[ChatMessage]) -> List[ChatMessage]:
messages, traceparent_header, request_id
)
# Create updated messages with the rewritten query
updated_messages = messages.copy()
# Find and update the last user message with the rewritten query
updated_messages = [m.model_dump() for m in messages]
for i in range(len(updated_messages) - 1, -1, -1):
if updated_messages[i].role == "user":
original_query = updated_messages[i].content
updated_messages[i] = ChatMessage(role="user", content=rewritten_query)
if updated_messages[i]["role"] == "user":
logger.info(
f"Updated user query from '{original_query}' to '{rewritten_query}'"
f"Updated user query from '{updated_messages[i]['content']}' to '{rewritten_query}'"
)
updated_messages[i]["content"] = rewritten_query
break
# Return as dict to minimize text serialization
return [{"role": msg.role, "content": msg.content} for msg in updated_messages]
logger.info("Returning rewritten chat completion response")
return {**body, "messages": updated_messages}
# Register MCP tool only if mcp is available
if mcp is not None:
mcp.tool()(query_rewriter)
@app.post("/")
async def chat_completions_endpoint(
    request_messages: List[ChatMessage], request: Request
) -> List[ChatMessage]:
    """Handle a POSTed list of chat messages by running query rewriting.

    Logs the incoming message count and whether a ``traceparent`` header is
    present, delegates to ``query_rewriter``, and returns its result rebuilt
    as ``ChatMessage`` objects.
    """
    message_count = len(request_messages)
    logger.info(
        f"Received /v1/chat/completions request with {message_count} messages"
    )

    # The header is only logged here; it is not passed on to query_rewriter.
    trace_header = request.headers.get("traceparent")
    if not trace_header:
        logger.info("No traceparent header found")
    else:
        logger.info(f"Received traceparent header: {trace_header}")

    # Each entry returned by query_rewriter is expanded as keyword arguments
    # into ChatMessage (presumably role/content pairs -- confirm against
    # query_rewriter's return shape).
    rewritten = await query_rewriter(request_messages)
    response_messages = [ChatMessage(**entry) for entry in rewritten]

    logger.info("Returning rewritten chat completion response")
    return response_messages
def start_server(host: str = "0.0.0.0", port: int = 10501):
    """Serve the query rewriter's ``app`` with uvicorn.

    Args:
        host: Interface address handed to ``uvicorn.run``.
        port: TCP port handed to ``uvicorn.run``.
    """
    # Function-scope import, kept from the original: uvicorn is loaded only
    # when this function is actually called.
    import uvicorn

    bind_address = f"{host}:{port}"
    logger.info(f"Starting Query Rewriter REST server on {bind_address}")
    uvicorn.run(app, port=port, host=host)