pending changes

This commit is contained in:
Adil Hafeez 2025-12-15 18:17:15 -08:00
parent afffa11e91
commit 358fa856c4
No known key found for this signature in database
GPG key ID: 9B18EF7691369645
21 changed files with 1195 additions and 403 deletions

View file

@ -2,23 +2,21 @@ version: v0.3.0
agents:
- id: rag_agent
url: mcp://host.docker.internal:10501
# only sse is supported
# transport: sse or stdio
# optional tool name, defaults to "invoke"
# tool: invoke
url: mcp://host.docker.internal:10505
- id: travel_agent
url: mcp://host.docker.internal:10502
transport: streamable-http
tool: invoke
url: mcp://host.docker.internal:10401
agent_filters:
- id: query_rewriter
url: mcp://host.docker.internal:10500
# tool is optional, defaults to id
# tool: query_rewriter
transport: streamable-http
tool: query_rewriter
url: mcp://host.docker.internal:10501
- id: context_builder
url: mcp://host.docker.internal:10500
- id: input_guards
url: mcp://host.docker.internal:10500
transport: streamable-http
tool: context_builder
url: mcp://host.docker.internal:10502
model_providers:
- model: openai/gpt-4o-mini
@ -35,20 +33,20 @@ model_aliases:
listeners:
- type: agent
name: agent_1
port: 8001
router: arch_agent_router
agents:
- id: rag_agent
description: virtual assistant for retrieval augmented generation tasks
filter_chain:
- input_guards
- query_rewriter
- context_builder
- id: travel_agent
description: virtual assistant for travel bookings and recommendations
filter_chain:
- input_guards
# - id: travel_agent
# description: virtual assistant for travel bookings and recommendations
# filter_chain:
# - input_guards
tracing:
random_sampling: 100

View file

@ -0,0 +1,86 @@
### Initialize MCP Session (SSE)
POST http://localhost:10501/mcp
Content-Type: application/json
Accept: application/json, text/event-stream
{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"capabilities":{},"protocolVersion":"2024-11-05","clientInfo":{"name":"test","version":"1.0.0"}}}
### Send Initialized Notification
POST http://localhost:10501/mcp
Content-Type: application/json
Accept: application/json, text/event-stream
mcp-session-id: e4ec1ae904e14e06b7d194da10e5f74c
{
"jsonrpc": "2.0",
"method": "notifications/initialized"
}
### List Tools
POST http://localhost:10501/mcp
Content-Type: application/json
Accept: application/json, text/event-stream
mcp-session-id: eb10a691b36e4547b6c93c5dc5b47e11
{
"jsonrpc": "2.0",
"id": "list-tools-1",
"method": "tools/list"
}
### Call Query Rewriter Tool
POST http://localhost:10501/mcp
Content-Type: application/json
Accept: application/json, text/event-stream
mcp-session-id: 6b95ff75825a402b90eb3ea07e23fbce
{
"jsonrpc": "2.0",
"id": "3d3b886a-6216-4a26-a422-7a972529c0e7",
"method": "tools/call",
"params": {
"arguments": {
"messages": [
{
"content": "What is the guaranteed uptime percentage for TechCorp's cloud services?",
"role": "user"
}
]
},
"name": "query_rewriter"
}
}
### another test
# Content-Type: application/json
# Accept: application/json, text/event-stream
# mcp-session-id: ed7a81a1d39549ecaadb867a6b2daf1e
POST http://localhost:10501/mcp
content-type: application/json
mcp-session-id: e4ec1ae904e14e06b7d194da10e5f74c
accept: application/json, text/event-stream
{"jsonrpc":"2.0","id":"4bb1043a-2953-4bcd-b801-f270b0ae8c39","method":"tools/call","params":{"arguments":{"messages":[{"content":"What is the guaranteed uptime percentage for TechCorp's cloud services?","role":"user"}]},"name":"query_rewriter"}}
### stream test
POST http://localhost:10501/mcp
content-type: application/json
mcp-session-id: 60be9fb816304cb6b9ecdb91d89cd91f
accept: application/json, text/event-stream
{
"jsonrpc": "2.0",
"id": 1,
"method": "tools/call",
"params": {
"name": "long_job",
"arguments": {
"n": 3
}
}
}

View file

@ -7,7 +7,7 @@ requires-python = ">=3.10"
dependencies = [
"click>=8.2.1",
"mcp>=1.13.1",
"fastmcp>=2.12.2",
"fastmcp>=2.14",
"pydantic>=2.11.7",
"fastapi>=0.104.1",
"uvicorn>=0.24.0",

View file

@ -1,50 +1,88 @@
import click
from mcp.server.fastmcp import FastMCP
from fastmcp import FastMCP
mcp = None
@click.command()
@click.option("--transport", "transport", default="sse", help="Transport type: stdio or sse")
@click.option(
"--transport",
"transport",
default="streamable-http",
help="Transport type: stdio or sse",
)
@click.option("--host", "host", default="localhost", help="Host to bind MCP server to")
@click.option("--port", "port", type=int, default=10500, help="Port for MCP server")
@click.option("--agent", "agent", required=True, help="Agent name: query_rewriter, context_builder, or response_generator")
@click.option("--name", "agent_name", default=None, help="Custom MCP server name (defaults to agent type)")
def main(host, port, agent, transport, agent_name):
@click.option(
"--agent",
"agent",
required=True,
help="Agent name: query_rewriter, context_builder, or response_generator",
)
@click.option(
"--name",
"agent_name",
default=None,
help="Custom MCP server name (defaults to agent type)",
)
@click.option(
"--rest-server",
"rest_server",
is_flag=True,
help="Start REST server instead of MCP server",
)
@click.option("--rest-port", "rest_port", default=8000, help="Port for REST server")
def main(host, port, agent, transport, agent_name, rest_server, rest_port):
"""Start a RAG agent as an MCP server."""
# Map friendly names to agent modules
agent_map = {
"query_rewriter": ("rag_agent.query_rewriter", "Query Rewriter Agent"),
"context_builder": ("rag_agent.context_builder_agent", "Context Builder Agent"),
"response_generator": ("rag_agent.response_generator", "Response Generator Agent"),
"context_builder": ("rag_agent.context_builder", "Context Builder Agent"),
"response_generator": (
"rag_agent.rag_agent",
"Response Generator Agent",
),
}
module_name, default_name = agent_map[agent]
mcp_name = agent_name or default_name
global mcp
mcp = FastMCP(mcp_name, host=host, port=port)
if agent not in agent_map:
print(f"Error: Unknown agent '{agent}'")
print(f"Available agents: {', '.join(agent_map.keys())}")
return
module_name, default_name = agent_map[agent]
mcp_name = agent_name or default_name
print(f"Starting MCP server: {mcp_name}")
print(f" Agent: {agent}")
print(f" Transport: {transport}")
print(f" Host: {host}")
print(f" Port: {port}")
global mcp
mcp = FastMCP(mcp_name, host=host, port=port)
# Import the agent module to register its tools
import importlib
importlib.import_module(module_name)
print(f"Agent '{agent}' loaded successfully")
print(f"MCP server ready on {transport}://{host}:{port}")
mcp.run(transport=transport)
if rest_server:
print(f"Starting REST server on {host}:{rest_port} for agent: {agent}")
if agent == "response_generator":
from rag_agent.rag_agent import start_server
start_server(host=host, port=rest_port)
return
else:
print("Please specify an agent to start with --agent option.")
return
else:
print(f"Starting MCP server: {mcp_name}")
print(f" Agent: {agent}")
print(f" Transport: {transport}")
print(f" Host: {host}")
print(f" Port: {port}")
# Import the agent module to register its tools
import importlib
importlib.import_module(module_name)
print(f"Agent '{agent}' loaded successfully")
print(f"MCP server ready on {transport}://{host}:{port}")
mcp.run(transport=transport)
if __name__ == "__main__":

View file

@ -191,54 +191,30 @@ class Response(BaseModel):
# FastAPI app for REST server
app = FastAPI(title="RAG Content Builder Agent", version="1.0.0")
@mcp.tool()
@app.post("/v1/chat/completions")
async def context_builder(
request_body: ChatCompletionRequest
) -> ChatCompletionResponse:
""" chat completions endpoint that augments user queries with relevant context from the knowledge base."""
async def context_builder(messages: List[ChatMessage]) -> List[ChatMessage]:
"""chat completions endpoint that augments user queries with relevant context from the knowledge base."""
import time
import uuid
logger.info(
f"Received chat completion request with {len(request_body.messages)} messages"
)
logger.info(f"Received chat completion request with {len(messages)} messages")
# Get traceparent header from HTTP request using FastMCP's dependency function
headers = get_http_headers()
traceparent_header = headers.get("traceparent")
if traceparent_header:
logger.info(f"Received traceparent header: {traceparent_header}")
else:
logger.info("No traceparent header found")
# Augment the user query with relevant context
updated_messages = await augment_query_with_context(
request_body.messages, traceparent_header
)
messages_history_json = json.dumps([msg.dict() for msg in updated_messages])
updated_messages = await augment_query_with_context(messages, traceparent_header)
response = ChatCompletionResponse(
id=f"chatcmpl-{uuid.uuid4().hex[:8]}",
created=int(time.time()),
model=request_body.model,
choices=[
{
"index": 0,
"message": {"role": "user", "content": messages_history_json},
"finish_reason": "stop",
}
],
usage={
"prompt_tokens": sum(len(msg.content.split()) for msg in updated_messages),
"completion_tokens": len("Context added to user query.".split()),
"total_tokens": sum(len(msg.content.split()) for msg in updated_messages)
+ len("Context added to user query.".split()),
},
)
return response
# Return as dict to minimize text serialization
return [{"role": msg.role, "content": msg.content} for msg in updated_messages]
def main():

View file

@ -1,3 +1,4 @@
import asyncio
import json
from pydantic import BaseModel
from typing import List, Optional, Dict, Any
@ -11,6 +12,9 @@ from .api import ChatMessage, ChatCompletionRequest, ChatCompletionResponse
from . import mcp
from fastmcp.server.dependencies import get_http_headers
from fastmcp.dependencies import CurrentContext
from fastmcp.server.context import Context
# Set up logging
logging.basicConfig(
level=logging.INFO,
@ -29,10 +33,11 @@ archgw_client = AsyncOpenAI(
api_key="EMPTY", # archgw doesn't require a real API key
)
async def rewrite_query_with_archgw(
messages: List[ChatMessage], traceparent_header: str
) -> str:
""" Rewrite the user query using LLM for better retrieval. """
"""Rewrite the user query using LLM for better retrieval."""
system_prompt = """You are a query rewriter that improves user queries for better retrieval.
Given a conversation history, rewrite the last user message to be more specific and context-aware.
@ -89,33 +94,31 @@ class Response(BaseModel):
app = FastAPI(title="RAG Agent Query Parser", version="1.0.0")
@app.post("/v1/chat/completions")
@mcp.tool()
async def query_rewriter(request_body: ChatCompletionRequest):
"""Chat completions endpoint that rewrites the last user query using archgw."""
async def query_rewriter(messages: List[ChatMessage]) -> List[ChatMessage]:
"""Chat completions endpoint that rewrites the last user query using archgw.
Returns a dict with a 'messages' key containing the updated message list.
"""
import time
import uuid
logger.info(
f"Received chat completion request with {len(request_body.messages)} messages"
)
logger.info(f"Received chat completion request with {len(messages)} messages")
# Get traceparent header from HTTP request using FastMCP's dependency function
headers = get_http_headers()
traceparent_header = headers.get("traceparent")
if traceparent_header:
logger.info(f"Received traceparent header: {traceparent_header}")
else:
logger.info("No traceparent header found")
# Call archgw to rewrite the last user query
rewritten_query = await rewrite_query_with_archgw(
request_body.messages, traceparent_header
)
rewritten_query = await rewrite_query_with_archgw(messages, traceparent_header)
# Create updated messages with the rewritten query
updated_messages = request_body.messages.copy()
updated_messages = messages.copy()
# Find and update the last user message with the rewritten query
for i in range(len(updated_messages) - 1, -1, -1):
@ -127,28 +130,8 @@ async def query_rewriter(request_body: ChatCompletionRequest):
)
break
messages_history_json = json.dumps([msg.dict() for msg in updated_messages])
response = ChatCompletionResponse(
id=f"chatcmpl-{uuid.uuid4().hex[:8]}",
created=int(time.time()),
model=request_body.model,
choices=[
{
"index": 0,
"message": {"role": "user", "content": messages_history_json},
"finish_reason": "stop",
}
],
usage={
"prompt_tokens": sum(len(msg.content.split()) for msg in updated_messages),
"completion_tokens": len("Updated query for better retrieval.".split()),
"total_tokens": sum(len(msg.content.split()) for msg in updated_messages)
+ len("Updated query for better retrieval.".split()),
},
)
return response
# Return as dict to minimize text serialization
return [{"role": msg.role, "content": msg.content} for msg in updated_messages]
@app.get("/health")

View file

@ -63,9 +63,8 @@ def prepare_response_messages(request_body: ChatCompletionRequest):
@app.post("/v1/chat/completions")
@mcp.tool(name="invoke")
async def chat_completion(request_body: ChatCompletionRequest):
"""Chat completions endpoint that generates a coherent response based on all context."""
async def chat_completion_http(request_body: ChatCompletionRequest):
"""HTTP endpoint for chat completions with streaming support."""
logger.info(
f"Received chat completion request with {len(request_body.messages)} messages"
)
@ -73,7 +72,7 @@ async def chat_completion(request_body: ChatCompletionRequest):
# Get traceparent header from HTTP request using FastMCP's dependency function
headers = get_http_headers()
traceparent_header = headers.get("traceparent")
if traceparent_header:
logger.info(f"Received traceparent header: {traceparent_header}")
else:
@ -92,6 +91,23 @@ async def chat_completion(request_body: ChatCompletionRequest):
return await non_streaming_chat_completions(request_body, traceparent_header)
@mcp.tool(name="invoke")
async def chat_completion(request_body: ChatCompletionRequest):
"""Chat completions endpoint that generates a coherent response based on all context.
For MCP calls, streaming is collected and returned as a complete response.
"""
logger.info(
f"[MCP] Received chat completion request with {len(request_body.messages)} messages"
)
# For MCP, always use non-streaming to return a complete response
response = await non_streaming_chat_completions(
request_body, traceparent_header=None
)
return response
async def stream_chat_completions(
request_body: ChatCompletionRequest, traceparent_header: str = None
):

View file

@ -21,16 +21,25 @@ cleanup() {
trap cleanup EXIT
log "Starting query_parser agent on port 10500..."
uv run python -m rag_agent --rest-server --host 0.0.0.0 --rest-port 10500 --agent query_parser &
# log "Starting input guards filter on port 10500..."
# uv run python -m rag_agent --host 0.0.0.0 --port 10500 --agent input_guards &
# WAIT_FOR_PIDS+=($!)
log "Starting query_parser agent on port 10501..."
uv run python -m rag_agent --host 0.0.0.0 --port 10501 --agent query_rewriter &
WAIT_FOR_PIDS+=($!)
log "Starting context_builder agent on port 10501..."
uv run python -m rag_agent --rest-server --host 0.0.0.0 --rest-port 10501 --agent context_builder &
log "Starting context_builder agent on port 10502..."
uv run python -m rag_agent --host 0.0.0.0 --port 10502 --agent context_builder &
WAIT_FOR_PIDS+=($!)
log "Starting response_generator agent on port 10502..."
uv run python -m rag_agent --rest-server --host 0.0.0.0 --rest-port 10502 --agent response_generator &
# log "Starting response_generator agent on port 10400..."
# uv run python -m rag_agent --host 0.0.0.0 --port 10400 --agent response_generator &
# WAIT_FOR_PIDS+=($!)
log "Starting response_generator agent on port 10505..."
uv run python -m rag_agent --rest-server --host 0.0.0.0 --rest-port 10505 --agent response_generator &
WAIT_FOR_PIDS+=($!)
for PID in "${WAIT_FOR_PIDS[@]}"; do

View file

@ -49,7 +49,7 @@ Content-Type: application/json
"content": "What is the guaranteed uptime percentage for TechCorp's cloud services?"
}
],
"stream": false
"stream": true
}
### send request to context builder agent

File diff suppressed because it is too large Load diff