From ab185b241ab530f33059aca7e4f2d6434c87607a Mon Sep 17 00:00:00 2001 From: Adil Hafeez Date: Tue, 16 Dec 2025 12:11:42 -0800 Subject: [PATCH] cleanup demo --- demos/use_cases/rag_agent/arch_config.yaml | 14 +--- .../rag_agent/src/rag_agent/__init__.py | 39 ++++++----- .../src/rag_agent/context_builder.py | 64 ++----------------- .../rag_agent/src/rag_agent/query_rewriter.py | 57 +---------------- .../rag_agent/src/rag_agent/rag_agent.py | 27 +------- 5 files changed, 36 insertions(+), 165 deletions(-) diff --git a/demos/use_cases/rag_agent/arch_config.yaml b/demos/use_cases/rag_agent/arch_config.yaml index d4b1c935..d74fb39d 100644 --- a/demos/use_cases/rag_agent/arch_config.yaml +++ b/demos/use_cases/rag_agent/arch_config.yaml @@ -3,15 +3,13 @@ version: v0.3.0 agents: - id: rag_agent url: http://host.docker.internal:10505 - - id: travel_agent - url: http://host.docker.internal:10401 agent_filters: - id: query_rewriter url: http://host.docker.internal:10501 - # type: mcp - # transport: streamable-http - # tool: query_rewriter + # type: mcp # default is mcp + # transport: streamable-http # default is streamable-http + # tool: query_rewriter # default name is the filter id - id: context_builder url: http://host.docker.internal:10502 @@ -39,11 +37,5 @@ listeners: filter_chain: - query_rewriter - context_builder - - # - id: travel_agent - # description: virtual assistant for travel bookings and recommendations - # filter_chain: - # - input_guards - tracing: random_sampling: 100 diff --git a/demos/use_cases/rag_agent/src/rag_agent/__init__.py b/demos/use_cases/rag_agent/src/rag_agent/__init__.py index 23af756b..08f8e21f 100644 --- a/demos/use_cases/rag_agent/src/rag_agent/__init__.py +++ b/demos/use_cases/rag_agent/src/rag_agent/__init__.py @@ -33,7 +33,7 @@ mcp = None ) @click.option("--rest-port", "rest_port", default=8000, help="Port for REST server") def main(host, port, agent, transport, agent_name, rest_server, rest_port): - """Start a RAG agent as an MCP server.""" + """Start a RAG agent as an MCP server or REST server.""" # Map friendly names to agent modules agent_map = { @@ -45,29 +45,38 @@ def main(host, port, agent, transport, agent_name, rest_server, rest_port): ), } - module_name, default_name = agent_map[agent] - mcp_name = agent_name or default_name - - global mcp - mcp = FastMCP(mcp_name, host=host, port=port) - if agent not in agent_map: print(f"Error: Unknown agent '{agent}'") print(f"Available agents: {', '.join(agent_map.keys())}") return + module_name, default_name = agent_map[agent] + mcp_name = agent_name or default_name + if rest_server: + # Only response_generator supports REST server mode + if agent != "response_generator": + print(f"Error: Agent '{agent}' does not support REST server mode.") + print(f"REST server is only supported for: response_generator") + print(f"Remove --rest-server flag to start {agent} as an MCP server.") + return + print(f"Starting REST server on {host}:{rest_port} for agent: {agent}") + from rag_agent.rag_agent import start_server - if agent == "response_generator": - from rag_agent.rag_agent import start_server - - start_server(host=host, port=rest_port) - return - else: - print("Please specify an agent to start with --agent option.") - return + start_server(host=host, port=rest_port) + return else: + # Only query_rewriter and context_builder support MCP + if agent not in ["query_rewriter", "context_builder"]: + print(f"Error: Agent '{agent}' does not support MCP mode.") + print(f"MCP is only supported for: query_rewriter, context_builder") + print(f"Use --rest-server flag to start {agent} as a REST server.") + return + + global mcp + mcp = FastMCP(mcp_name, host=host, port=port) + print(f"Starting MCP server: {mcp_name}") print(f" Agent: {agent}") print(f" Transport: {transport}") diff --git a/demos/use_cases/rag_agent/src/rag_agent/context_builder.py b/demos/use_cases/rag_agent/src/rag_agent/context_builder.py index 15170674..2fa6e307 100644 --- a/demos/use_cases/rag_agent/src/rag_agent/context_builder.py +++ b/demos/use_cases/rag_agent/src/rag_agent/context_builder.py @@ -1,15 +1,12 @@ import json -from pydantic import BaseModel from typing import List, Optional, Dict, Any -from fastapi import FastAPI, HTTPException, Request from openai import AsyncOpenAI import os import logging import csv from pathlib import Path -import uvicorn -from .api import ChatMessage, ChatCompletionRequest, ChatCompletionResponse +from .api import ChatMessage from . import mcp from fastmcp.server.dependencies import get_http_headers @@ -183,25 +180,16 @@ async def augment_query_with_context( return updated_messages -class Response(BaseModel): - query: str - metadata: dict - - -# FastAPI app for REST server -app = FastAPI(title="RAG Content Builder Agent", version="1.0.0") +# Load knowledge base on module import +load_knowledge_base() @mcp.tool() -@app.post("/v1/chat/completions") async def context_builder(messages: List[ChatMessage]) -> List[ChatMessage]: - """chat completions endpoint that augments user queries with relevant context from the knowledge base.""" - import time - import uuid - + """MCP tool that augments user queries with relevant context from the knowledge base.""" logger.info(f"Received chat completion request with {len(messages)} messages") - # Get traceparent header from HTTP request using FastMCP's dependency function + # Get traceparent header from MCP request headers = get_http_headers() traceparent_header = headers.get("traceparent") @@ -215,45 +203,3 @@ async def context_builder(messages: List[ChatMessage]) -> List[ChatMessage]: # Return as dict to minimize text serialization return [{"role": msg.role, "content": msg.content} for msg in updated_messages] - - -def main(): - """Main function to initialize the knowledge base and start the server.""" - load_knowledge_base() - - uvicorn.run(app, host="0.0.0.0", port=8000) - - -if __name__ == "__main__": - main() - - -def start_server(host: str = "localhost", port: int = 8000): - """Start the REST server.""" - load_knowledge_base() - # Rename the uvicorn.error logger - uvicorn.run( - app, - host=host, - port=port, - log_config={ - "version": 1, - "disable_existing_loggers": False, - "formatters": { - "default": { - "format": "%(asctime)s - [CONTEXT_BUILDER] - %(levelname)s - %(message)s", - }, - }, - "handlers": { - "default": { - "formatter": "default", - "class": "logging.StreamHandler", - "stream": "ext://sys.stdout", - }, - }, - "root": { - "level": "INFO", - "handlers": ["default"], - }, - }, - ) diff --git a/demos/use_cases/rag_agent/src/rag_agent/query_rewriter.py b/demos/use_cases/rag_agent/src/rag_agent/query_rewriter.py index 75f2a4d1..89e5b200 100644 --- a/demos/use_cases/rag_agent/src/rag_agent/query_rewriter.py +++ b/demos/use_cases/rag_agent/src/rag_agent/query_rewriter.py @@ -1,20 +1,14 @@ import asyncio import json -from pydantic import BaseModel from typing import List, Optional, Dict, Any -from fastapi import FastAPI, HTTPException, Request from openai import AsyncOpenAI import os import logging -import uvicorn -from .api import ChatMessage, ChatCompletionRequest, ChatCompletionResponse +from .api import ChatMessage from . import mcp from fastmcp.server.dependencies import get_http_headers -from fastmcp.dependencies import CurrentContext -from fastmcp.server.context import Context - # Set up logging logging.basicConfig( level=logging.INFO, @@ -85,15 +79,6 @@ async def rewrite_query_with_archgw( return "" -class Response(BaseModel): - query: str - metadata: dict - - -# FastAPI app for REST server -app = FastAPI(title="RAG Agent Query Parser", version="1.0.0") - - @mcp.tool() async def query_rewriter(messages: List[ChatMessage]) -> List[ChatMessage]: """Chat completions endpoint that rewrites the last user query using archgw. @@ -132,43 +117,3 @@ async def query_rewriter(messages: List[ChatMessage]) -> List[ChatMessage]: # Return as dict to minimize text serialization return [{"role": msg.role, "content": msg.content} for msg in updated_messages] - - -@app.get("/health") -async def health_check(): - """Health check endpoint.""" - return {"status": "healthy"} - - -def parse_query(query): - """Parse the user query and returns metadata extracted from query.""" - return Response(query=query, metadata={"is_valid": True}) - - -def start_server(host: str = "localhost", port: int = 8000): - """Start the REST server.""" - uvicorn.run( - app, - host=host, - port=port, - log_config={ - "version": 1, - "disable_existing_loggers": False, - "formatters": { - "default": { - "format": "%(asctime)s - [QUERY_REWRITER] - %(levelname)s - %(message)s", - }, - }, - "handlers": { - "default": { - "formatter": "default", - "class": "logging.StreamHandler", - "stream": "ext://sys.stdout", - }, - }, - "root": { - "level": "INFO", - "handlers": ["default"], - }, - }, - ) diff --git a/demos/use_cases/rag_agent/src/rag_agent/rag_agent.py b/demos/use_cases/rag_agent/src/rag_agent/rag_agent.py index e22cca03..b590248a 100644 --- a/demos/use_cases/rag_agent/src/rag_agent/rag_agent.py +++ b/demos/use_cases/rag_agent/src/rag_agent/rag_agent.py @@ -15,9 +15,6 @@ from .api import ( ChatCompletionStreamResponse, ) -from . import mcp -from fastmcp.server.dependencies import get_http_headers - # Set up logging logging.basicConfig( level=logging.INFO, @@ -63,15 +60,14 @@ def prepare_response_messages(request_body: ChatCompletionRequest): @app.post("/v1/chat/completions") -async def chat_completion_http(request_body: ChatCompletionRequest): +async def chat_completion_http(request: Request, request_body: ChatCompletionRequest): """HTTP endpoint for chat completions with streaming support.""" logger.info( f"Received chat completion request with {len(request_body.messages)} messages" ) - # Get traceparent header from HTTP request using FastMCP's dependency function - headers = get_http_headers() - traceparent_header = headers.get("traceparent") + # Get traceparent header from HTTP request + traceparent_header = request.headers.get("traceparent") if traceparent_header: logger.info(f"Received traceparent header: {traceparent_header}") @@ -91,23 +87,6 @@ async def chat_completion_http(request_body: ChatCompletionRequest): return await non_streaming_chat_completions(request_body, traceparent_header) -@mcp.tool(name="invoke") -async def chat_completion(request_body: ChatCompletionRequest): - """Chat completions endpoint that generates a coherent response based on all context. - - For MCP calls, streaming is collected and returned as a complete response. - """ - logger.info( - f"[MCP] Received chat completion request with {len(request_body.messages)} messages" - ) - - # For MCP, always use non-streaming to return a complete response - response = await non_streaming_chat_completions( - request_body, traceparent_header=None - ) - return response - - async def stream_chat_completions( request_body: ChatCompletionRequest, traceparent_header: str = None ):