mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
cleanup demo
This commit is contained in:
parent
58cf6b29bc
commit
ab185b241a
5 changed files with 36 additions and 165 deletions
|
|
@ -3,15 +3,13 @@ version: v0.3.0
|
|||
agents:
|
||||
- id: rag_agent
|
||||
url: http://host.docker.internal:10505
|
||||
- id: travel_agent
|
||||
url: http://host.docker.internal:10401
|
||||
|
||||
agent_filters:
|
||||
- id: query_rewriter
|
||||
url: http://host.docker.internal:10501
|
||||
# type: mcp
|
||||
# transport: streamable-http
|
||||
# tool: query_rewriter
|
||||
# type: mcp # default is mcp
|
||||
# transport: streamable-http # default is streamable-http
|
||||
# tool: query_rewriter # default name is the filter id
|
||||
- id: context_builder
|
||||
url: http://host.docker.internal:10502
|
||||
|
||||
|
|
@ -39,11 +37,5 @@ listeners:
|
|||
filter_chain:
|
||||
- query_rewriter
|
||||
- context_builder
|
||||
|
||||
# - id: travel_agent
|
||||
# description: virtual assistant for travel bookings and recommendations
|
||||
# filter_chain:
|
||||
# - input_guards
|
||||
|
||||
tracing:
|
||||
random_sampling: 100
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@ mcp = None
|
|||
)
|
||||
@click.option("--rest-port", "rest_port", default=8000, help="Port for REST server")
|
||||
def main(host, port, agent, transport, agent_name, rest_server, rest_port):
|
||||
"""Start a RAG agent as an MCP server."""
|
||||
"""Start a RAG agent as an MCP server or REST server."""
|
||||
|
||||
# Map friendly names to agent modules
|
||||
agent_map = {
|
||||
|
|
@ -45,29 +45,38 @@ def main(host, port, agent, transport, agent_name, rest_server, rest_port):
|
|||
),
|
||||
}
|
||||
|
||||
module_name, default_name = agent_map[agent]
|
||||
mcp_name = agent_name or default_name
|
||||
|
||||
global mcp
|
||||
mcp = FastMCP(mcp_name, host=host, port=port)
|
||||
|
||||
if agent not in agent_map:
|
||||
print(f"Error: Unknown agent '{agent}'")
|
||||
print(f"Available agents: {', '.join(agent_map.keys())}")
|
||||
return
|
||||
|
||||
module_name, default_name = agent_map[agent]
|
||||
mcp_name = agent_name or default_name
|
||||
|
||||
if rest_server:
|
||||
# Only response_generator supports REST server mode
|
||||
if agent != "response_generator":
|
||||
print(f"Error: Agent '{agent}' does not support REST server mode.")
|
||||
print(f"REST server is only supported for: response_generator")
|
||||
print(f"Remove --rest-server flag to start {agent} as an MCP server.")
|
||||
return
|
||||
|
||||
print(f"Starting REST server on {host}:{rest_port} for agent: {agent}")
|
||||
from rag_agent.rag_agent import start_server
|
||||
|
||||
if agent == "response_generator":
|
||||
from rag_agent.rag_agent import start_server
|
||||
|
||||
start_server(host=host, port=rest_port)
|
||||
return
|
||||
else:
|
||||
print("Please specify an agent to start with --agent option.")
|
||||
return
|
||||
start_server(host=host, port=rest_port)
|
||||
return
|
||||
else:
|
||||
# Only query_rewriter and context_builder support MCP
|
||||
if agent not in ["query_rewriter", "context_builder"]:
|
||||
print(f"Error: Agent '{agent}' does not support MCP mode.")
|
||||
print(f"MCP is only supported for: query_rewriter, context_builder")
|
||||
print(f"Use --rest-server flag to start {agent} as a REST server.")
|
||||
return
|
||||
|
||||
global mcp
|
||||
mcp = FastMCP(mcp_name, host=host, port=port)
|
||||
|
||||
print(f"Starting MCP server: {mcp_name}")
|
||||
print(f" Agent: {agent}")
|
||||
print(f" Transport: {transport}")
|
||||
|
|
|
|||
|
|
@ -1,15 +1,12 @@
|
|||
import json
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Optional, Dict, Any
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from openai import AsyncOpenAI
|
||||
import os
|
||||
import logging
|
||||
import csv
|
||||
from pathlib import Path
|
||||
import uvicorn
|
||||
|
||||
from .api import ChatMessage, ChatCompletionRequest, ChatCompletionResponse
|
||||
from .api import ChatMessage
|
||||
from . import mcp
|
||||
from fastmcp.server.dependencies import get_http_headers
|
||||
|
||||
|
|
@ -183,25 +180,16 @@ async def augment_query_with_context(
|
|||
return updated_messages
|
||||
|
||||
|
||||
class Response(BaseModel):
|
||||
query: str
|
||||
metadata: dict
|
||||
|
||||
|
||||
# FastAPI app for REST server
|
||||
app = FastAPI(title="RAG Content Builder Agent", version="1.0.0")
|
||||
# Load knowledge base on module import
|
||||
load_knowledge_base()
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
@app.post("/v1/chat/completions")
|
||||
async def context_builder(messages: List[ChatMessage]) -> List[ChatMessage]:
|
||||
"""chat completions endpoint that augments user queries with relevant context from the knowledge base."""
|
||||
import time
|
||||
import uuid
|
||||
|
||||
"""MCP tool that augments user queries with relevant context from the knowledge base."""
|
||||
logger.info(f"Received chat completion request with {len(messages)} messages")
|
||||
|
||||
# Get traceparent header from HTTP request using FastMCP's dependency function
|
||||
# Get traceparent header from MCP request
|
||||
headers = get_http_headers()
|
||||
traceparent_header = headers.get("traceparent")
|
||||
|
||||
|
|
@ -215,45 +203,3 @@ async def context_builder(messages: List[ChatMessage]) -> List[ChatMessage]:
|
|||
|
||||
# Return as dict to minimize text serialization
|
||||
return [{"role": msg.role, "content": msg.content} for msg in updated_messages]
|
||||
|
||||
|
||||
def main():
|
||||
"""Main function to initialize the knowledge base and start the server."""
|
||||
load_knowledge_base()
|
||||
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
|
||||
def start_server(host: str = "localhost", port: int = 8000):
|
||||
"""Start the REST server."""
|
||||
load_knowledge_base()
|
||||
# Configure uvicorn logging: [CONTEXT_BUILDER]-tagged format, stdout handler, root level INFO
|
||||
uvicorn.run(
|
||||
app,
|
||||
host=host,
|
||||
port=port,
|
||||
log_config={
|
||||
"version": 1,
|
||||
"disable_existing_loggers": False,
|
||||
"formatters": {
|
||||
"default": {
|
||||
"format": "%(asctime)s - [CONTEXT_BUILDER] - %(levelname)s - %(message)s",
|
||||
},
|
||||
},
|
||||
"handlers": {
|
||||
"default": {
|
||||
"formatter": "default",
|
||||
"class": "logging.StreamHandler",
|
||||
"stream": "ext://sys.stdout",
|
||||
},
|
||||
},
|
||||
"root": {
|
||||
"level": "INFO",
|
||||
"handlers": ["default"],
|
||||
},
|
||||
},
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,20 +1,14 @@
|
|||
import asyncio
|
||||
import json
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Optional, Dict, Any
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from openai import AsyncOpenAI
|
||||
import os
|
||||
import logging
|
||||
import uvicorn
|
||||
|
||||
from .api import ChatMessage, ChatCompletionRequest, ChatCompletionResponse
|
||||
from .api import ChatMessage
|
||||
from . import mcp
|
||||
from fastmcp.server.dependencies import get_http_headers
|
||||
|
||||
from fastmcp.dependencies import CurrentContext
|
||||
from fastmcp.server.context import Context
|
||||
|
||||
# Set up logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
|
|
@ -85,15 +79,6 @@ async def rewrite_query_with_archgw(
|
|||
return ""
|
||||
|
||||
|
||||
class Response(BaseModel):
|
||||
query: str
|
||||
metadata: dict
|
||||
|
||||
|
||||
# FastAPI app for REST server
|
||||
app = FastAPI(title="RAG Agent Query Parser", version="1.0.0")
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def query_rewriter(messages: List[ChatMessage]) -> List[ChatMessage]:
|
||||
"""Chat completions endpoint that rewrites the last user query using archgw.
|
||||
|
|
@ -132,43 +117,3 @@ async def query_rewriter(messages: List[ChatMessage]) -> List[ChatMessage]:
|
|||
|
||||
# Return as dict to minimize text serialization
|
||||
return [{"role": msg.role, "content": msg.content} for msg in updated_messages]
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
"""Health check endpoint."""
|
||||
return {"status": "healthy"}
|
||||
|
||||
|
||||
def parse_query(query):
|
||||
"""Parse the user query and return metadata extracted from the query."""
|
||||
return Response(query=query, metadata={"is_valid": True})
|
||||
|
||||
|
||||
def start_server(host: str = "localhost", port: int = 8000):
|
||||
"""Start the REST server."""
|
||||
uvicorn.run(
|
||||
app,
|
||||
host=host,
|
||||
port=port,
|
||||
log_config={
|
||||
"version": 1,
|
||||
"disable_existing_loggers": False,
|
||||
"formatters": {
|
||||
"default": {
|
||||
"format": "%(asctime)s - [QUERY_REWRITER] - %(levelname)s - %(message)s",
|
||||
},
|
||||
},
|
||||
"handlers": {
|
||||
"default": {
|
||||
"formatter": "default",
|
||||
"class": "logging.StreamHandler",
|
||||
"stream": "ext://sys.stdout",
|
||||
},
|
||||
},
|
||||
"root": {
|
||||
"level": "INFO",
|
||||
"handlers": ["default"],
|
||||
},
|
||||
},
|
||||
)
|
||||
|
|
|
|||
|
|
@ -15,9 +15,6 @@ from .api import (
|
|||
ChatCompletionStreamResponse,
|
||||
)
|
||||
|
||||
from . import mcp
|
||||
from fastmcp.server.dependencies import get_http_headers
|
||||
|
||||
# Set up logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
|
|
@ -63,15 +60,14 @@ def prepare_response_messages(request_body: ChatCompletionRequest):
|
|||
|
||||
|
||||
@app.post("/v1/chat/completions")
|
||||
async def chat_completion_http(request_body: ChatCompletionRequest):
|
||||
async def chat_completion_http(request: Request, request_body: ChatCompletionRequest):
|
||||
"""HTTP endpoint for chat completions with streaming support."""
|
||||
logger.info(
|
||||
f"Received chat completion request with {len(request_body.messages)} messages"
|
||||
)
|
||||
|
||||
# Get traceparent header from HTTP request using FastMCP's dependency function
|
||||
headers = get_http_headers()
|
||||
traceparent_header = headers.get("traceparent")
|
||||
# Get traceparent header from HTTP request
|
||||
traceparent_header = request.headers.get("traceparent")
|
||||
|
||||
if traceparent_header:
|
||||
logger.info(f"Received traceparent header: {traceparent_header}")
|
||||
|
|
@ -91,23 +87,6 @@ async def chat_completion_http(request_body: ChatCompletionRequest):
|
|||
return await non_streaming_chat_completions(request_body, traceparent_header)
|
||||
|
||||
|
||||
@mcp.tool(name="invoke")
|
||||
async def chat_completion(request_body: ChatCompletionRequest):
|
||||
"""Chat completions endpoint that generates a coherent response based on all context.
|
||||
|
||||
For MCP calls, streaming is collected and returned as a complete response.
|
||||
"""
|
||||
logger.info(
|
||||
f"[MCP] Received chat completion request with {len(request_body.messages)} messages"
|
||||
)
|
||||
|
||||
# For MCP, always use non-streaming to return a complete response
|
||||
response = await non_streaming_chat_completions(
|
||||
request_body, traceparent_header=None
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
async def stream_chat_completions(
|
||||
request_body: ChatCompletionRequest, traceparent_header: str = None
|
||||
):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue