mirror of
https://github.com/katanemo/plano.git
synced 2026-05-21 13:55:15 +02:00
add MCP raw filter support: body+path args, update mcp_filter demo handlers
This commit is contained in:
parent
d26abbfb9c
commit
b88bdb94f2
16 changed files with 226 additions and 116 deletions
|
|
@ -42,7 +42,7 @@ listeners:
|
|||
agents:
|
||||
- id: rag_agent
|
||||
description: virtual assistant for retrieval augmented generation tasks
|
||||
filter_chain:
|
||||
input_filters:
|
||||
- input_guards
|
||||
- query_rewriter
|
||||
- context_builder
|
||||
|
|
|
|||
|
|
@ -195,11 +195,11 @@ async def augment_query_with_context(
|
|||
load_knowledge_base()
|
||||
|
||||
|
||||
@app.post("/")
|
||||
async def context_builder(
|
||||
messages: List[ChatMessage], request: Request
|
||||
) -> List[ChatMessage]:
|
||||
@app.post("/{path:path}")
|
||||
async def context_builder(path: str, request: Request) -> dict:
|
||||
"""MCP tool that augments user queries with relevant context from the knowledge base."""
|
||||
body = await request.json()
|
||||
messages = [ChatMessage(**m) for m in body.get("messages", [])]
|
||||
logger.info(f"Received chat completion request with {len(messages)} messages")
|
||||
|
||||
# Get traceparent header from MCP request
|
||||
|
|
@ -219,8 +219,7 @@ async def context_builder(
|
|||
messages, traceparent_header, request_id
|
||||
)
|
||||
|
||||
# Return as dict to minimize text serialization
|
||||
return [{"role": msg.role, "content": msg.content} for msg in updated_messages]
|
||||
return {**body, "messages": [{"role": msg.role, "content": msg.content} for msg in updated_messages]}
|
||||
|
||||
|
||||
# Register MCP tool only if mcp is available
|
||||
|
|
|
|||
|
|
@ -126,14 +126,14 @@ Respond in JSON format:
|
|||
|
||||
|
||||
# @mcp.tool
|
||||
@app.post("/")
|
||||
async def input_guards(
|
||||
messages: List[ChatMessage], request: Request
|
||||
) -> List[ChatMessage]:
|
||||
@app.post("/{path:path}")
|
||||
async def input_guards(path: str, request: Request) -> dict:
|
||||
"""Input guard that validates queries are within TechCorp's domain.
|
||||
|
||||
If the query is out of scope, replaces the user message with a rejection notice.
|
||||
"""
|
||||
body = await request.json()
|
||||
messages = [ChatMessage(**m) for m in body.get("messages", [])]
|
||||
logger.info(f"Received request with {len(messages)} messages")
|
||||
|
||||
# Get traceparent header from HTTP request using FastMCP's dependency function
|
||||
|
|
@ -164,7 +164,7 @@ async def input_guards(
|
|||
)
|
||||
|
||||
logger.info("Query validation passed - forwarding to next filter")
|
||||
return messages
|
||||
return body
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
|
|
|
|||
|
|
@ -81,11 +81,11 @@ Return only the rewritten query, nothing else."""
|
|||
return ""
|
||||
|
||||
|
||||
@app.post("/")
|
||||
async def query_rewriter_http(
|
||||
messages: List[ChatMessage], request: Request
|
||||
) -> List[ChatMessage]:
|
||||
@app.post("/{path:path}")
|
||||
async def query_rewriter_http(path: str, request: Request) -> dict:
|
||||
"""HTTP filter endpoint used by Plano (type: http)."""
|
||||
body = await request.json()
|
||||
messages = [ChatMessage(**m) for m in body.get("messages", [])]
|
||||
logger.info(f"Received request with {len(messages)} messages")
|
||||
|
||||
traceparent_header = request.headers.get("traceparent")
|
||||
|
|
@ -99,25 +99,20 @@ async def query_rewriter_http(
|
|||
rewritten_query = await rewrite_query_with_plano(
|
||||
messages, traceparent_header, request_id
|
||||
)
|
||||
# Create updated messages with the rewritten query
|
||||
updated_messages = messages.copy()
|
||||
|
||||
# Find and update the last user message with the rewritten query
|
||||
updated_messages = [m.model_dump() for m in messages]
|
||||
for i in range(len(updated_messages) - 1, -1, -1):
|
||||
if updated_messages[i].role == "user":
|
||||
original_query = updated_messages[i].content
|
||||
updated_messages[i] = ChatMessage(role="user", content=rewritten_query)
|
||||
if updated_messages[i]["role"] == "user":
|
||||
original_query = updated_messages[i]["content"]
|
||||
updated_messages[i]["content"] = rewritten_query
|
||||
logger.info(
|
||||
f"Updated user query from '{original_query}' to '{rewritten_query}'"
|
||||
)
|
||||
break
|
||||
updated_messages_data = [
|
||||
{"role": msg.role, "content": msg.content} for msg in updated_messages
|
||||
]
|
||||
updated_messages = [ChatMessage(**msg) for msg in updated_messages_data]
|
||||
|
||||
logger.info("Returning rewritten chat completion response")
|
||||
return updated_messages
|
||||
return {**body, "messages": updated_messages}
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
|
|
|
|||
|
|
@ -39,7 +39,7 @@ listeners:
|
|||
agents:
|
||||
- id: rag_agent
|
||||
description: virtual assistant for retrieval augmented generation tasks
|
||||
filter_chain:
|
||||
input_filters:
|
||||
- input_guards
|
||||
- query_rewriter
|
||||
- context_builder
|
||||
|
|
|
|||
|
|
@ -195,9 +195,14 @@ async def augment_query_with_context(
|
|||
load_knowledge_base()
|
||||
|
||||
|
||||
async def context_builder(messages: List[ChatMessage]) -> List[ChatMessage]:
|
||||
"""MCP tool that augments user queries with relevant context from the knowledge base."""
|
||||
logger.info(f"Received chat completion request with {len(messages)} messages")
|
||||
async def context_builder(body: dict, path: str) -> dict:
|
||||
"""MCP tool that augments user queries with relevant context from the knowledge base.
|
||||
|
||||
Receives the full request body dict and the API path hint (e.g. /v1/chat/completions).
|
||||
Returns the body with the last user message augmented with retrieved context.
|
||||
"""
|
||||
messages = [ChatMessage(**m) for m in body.get("messages", [])]
|
||||
logger.info(f"Received request with {len(messages)} messages at path {path}")
|
||||
|
||||
# Get traceparent header from MCP request
|
||||
headers = get_http_headers()
|
||||
|
|
@ -215,8 +220,7 @@ async def context_builder(messages: List[ChatMessage]) -> List[ChatMessage]:
|
|||
messages, traceparent_header, request_id
|
||||
)
|
||||
|
||||
# Return as dict to minimize text serialization
|
||||
return [{"role": msg.role, "content": msg.content} for msg in updated_messages]
|
||||
return {**body, "messages": [{"role": msg.role, "content": msg.content} for msg in updated_messages]}
|
||||
|
||||
|
||||
# Register MCP tool only if mcp is available
|
||||
|
|
|
|||
|
|
@ -3,13 +3,12 @@ import json
|
|||
import time
|
||||
from typing import List, Optional, Dict, Any
|
||||
import uuid
|
||||
from fastapi import FastAPI, Depends, Request
|
||||
from fastmcp.exceptions import ToolError
|
||||
from openai import AsyncOpenAI
|
||||
import os
|
||||
import logging
|
||||
|
||||
from .api import ChatCompletionRequest, ChatCompletionResponse, ChatMessage
|
||||
from .api import ChatMessage
|
||||
from . import mcp
|
||||
from fastmcp.server.dependencies import get_http_headers
|
||||
|
||||
|
|
@ -30,8 +29,6 @@ plano_client = AsyncOpenAI(
|
|||
api_key="EMPTY", # Plano doesn't require a real API key
|
||||
)
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
async def validate_query_scope(
|
||||
messages: List[ChatMessage],
|
||||
|
|
@ -127,12 +124,14 @@ Respond in JSON format:
|
|||
|
||||
|
||||
@mcp.tool
|
||||
async def input_guards(messages: List[ChatMessage]) -> List[ChatMessage]:
|
||||
async def input_guards(body: dict, path: str) -> dict:
|
||||
"""Input guard that validates queries are within TechCorp's domain.
|
||||
|
||||
If the query is out of scope, replaces the user message with a rejection notice.
|
||||
Receives the full request body dict and the API path hint (e.g. /v1/chat/completions).
|
||||
If the query is out of scope, raises a ToolError to block the request.
|
||||
"""
|
||||
logger.info(f"Received request with {len(messages)} messages")
|
||||
messages = [ChatMessage(**m) for m in body.get("messages", [])]
|
||||
logger.info(f"Received request with {len(messages)} messages at path {path}")
|
||||
|
||||
# Get traceparent header from HTTP request using FastMCP's dependency function
|
||||
headers = get_http_headers()
|
||||
|
|
@ -153,9 +152,8 @@ async def input_guards(messages: List[ChatMessage]) -> List[ChatMessage]:
|
|||
reason = validation_result.get("reason", "Query is outside TechCorp's domain")
|
||||
logger.warning(f"Query rejected: {reason}")
|
||||
|
||||
# Throw ToolError
|
||||
error_message = f"I apologize, but I can only assist with questions related to TechCorp and its services. Your query appears to be outside this scope. {reason}\n\nPlease ask me about TechCorp's products, services, pricing, SLAs, or technical support."
|
||||
raise ToolError(error_message)
|
||||
|
||||
logger.info("Query validation passed - forwarding to next filter")
|
||||
return messages
|
||||
return body
|
||||
|
|
|
|||
|
|
@ -3,12 +3,11 @@ import json
|
|||
import time
|
||||
from typing import List, Optional, Dict, Any
|
||||
import uuid
|
||||
from fastapi import FastAPI, Depends, Request
|
||||
from openai import AsyncOpenAI
|
||||
import os
|
||||
import logging
|
||||
|
||||
from .api import ChatCompletionRequest, ChatCompletionResponse, ChatMessage
|
||||
from .api import ChatMessage
|
||||
from . import mcp
|
||||
from fastmcp.server.dependencies import get_http_headers
|
||||
|
||||
|
|
@ -29,9 +28,6 @@ plano_client = AsyncOpenAI(
|
|||
api_key="EMPTY", # Plano doesn't require a real API key
|
||||
)
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
async def rewrite_query_with_plano(
|
||||
messages: List[ChatMessage],
|
||||
traceparent_header: str,
|
||||
|
|
@ -87,12 +83,14 @@ async def rewrite_query_with_plano(
|
|||
return ""
|
||||
|
||||
|
||||
async def query_rewriter(messages: List[ChatMessage]) -> List[ChatMessage]:
|
||||
"""Chat completions endpoint that rewrites the last user query using Plano.
|
||||
async def query_rewriter(body: dict, path: str) -> dict:
|
||||
"""Rewrites the last user query in the request body using Plano.
|
||||
|
||||
Returns a dict with a 'messages' key containing the updated message list.
|
||||
Receives the full request body dict and the API path hint (e.g. /v1/chat/completions).
|
||||
Returns the body with the last user message rewritten for better retrieval.
|
||||
"""
|
||||
logger.info(f"Received chat completion request with {len(messages)} messages")
|
||||
messages = [ChatMessage(**m) for m in body.get("messages", [])]
|
||||
logger.info(f"Received request with {len(messages)} messages at path {path}")
|
||||
|
||||
# Get traceparent header from HTTP request using FastMCP's dependency function
|
||||
headers = get_http_headers()
|
||||
|
|
@ -109,57 +107,20 @@ async def query_rewriter(messages: List[ChatMessage]) -> List[ChatMessage]:
|
|||
messages, traceparent_header, request_id
|
||||
)
|
||||
|
||||
# Create updated messages with the rewritten query
|
||||
updated_messages = messages.copy()
|
||||
|
||||
# Find and update the last user message with the rewritten query
|
||||
updated_messages = [m.model_dump() for m in messages]
|
||||
for i in range(len(updated_messages) - 1, -1, -1):
|
||||
if updated_messages[i].role == "user":
|
||||
original_query = updated_messages[i].content
|
||||
updated_messages[i] = ChatMessage(role="user", content=rewritten_query)
|
||||
if updated_messages[i]["role"] == "user":
|
||||
logger.info(
|
||||
f"Updated user query from '{original_query}' to '{rewritten_query}'"
|
||||
f"Updated user query from '{updated_messages[i]['content']}' to '{rewritten_query}'"
|
||||
)
|
||||
updated_messages[i]["content"] = rewritten_query
|
||||
break
|
||||
|
||||
# Return as dict to minimize text serialization
|
||||
return [{"role": msg.role, "content": msg.content} for msg in updated_messages]
|
||||
logger.info("Returning rewritten chat completion response")
|
||||
return {**body, "messages": updated_messages}
|
||||
|
||||
|
||||
# Register MCP tool only if mcp is available
|
||||
if mcp is not None:
|
||||
mcp.tool()(query_rewriter)
|
||||
|
||||
|
||||
@app.post("/")
|
||||
async def chat_completions_endpoint(
|
||||
request_messages: List[ChatMessage], request: Request
|
||||
) -> List[ChatMessage]:
|
||||
"""FastAPI endpoint for chat completions with query rewriting."""
|
||||
logger.info(
|
||||
f"Received /v1/chat/completions request with {len(request_messages)} messages"
|
||||
)
|
||||
|
||||
# Extract traceparent header
|
||||
traceparent_header = request.headers.get("traceparent")
|
||||
if traceparent_header:
|
||||
logger.info(f"Received traceparent header: {traceparent_header}")
|
||||
else:
|
||||
logger.info("No traceparent header found")
|
||||
|
||||
# Call the query rewriter tool
|
||||
updated_messages_data = await query_rewriter(request_messages)
|
||||
|
||||
# Convert back to ChatMessage objects
|
||||
updated_messages = [ChatMessage(**msg) for msg in updated_messages_data]
|
||||
|
||||
logger.info("Returning rewritten chat completion response")
|
||||
return updated_messages
|
||||
|
||||
|
||||
def start_server(host: str = "0.0.0.0", port: int = 10501):
|
||||
"""Start the FastAPI server for query rewriter."""
|
||||
import uvicorn
|
||||
|
||||
logger.info(f"Starting Query Rewriter REST server on {host}:{port}")
|
||||
uvicorn.run(app, host=host, port=port)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue