diff --git a/demos/use_cases/rag_agent/README.md b/demos/use_cases/rag_agent/README.md index e69de29b..66102f6f 100644 --- a/demos/use_cases/rag_agent/README.md +++ b/demos/use_cases/rag_agent/README.md @@ -0,0 +1,28 @@ +# RAG Agent Query Parser + +A FastAPI service that rewrites user queries using archgw and gpt-4o-mini for better retrieval accuracy. + +## How it Works + +1. Receives a chat completion request with conversation history +2. Calls archgw's LLM gateway with gpt-4o-mini to rewrite the last user query +3. Returns the conversation messages with the last user query replaced by the rewritten query + +## Setup and Running + +1. **Start archgw**: + ```bash + archgw up --foreground + ``` + +2. **Start the query rewriter service**: + ```bash + uv run uvicorn rag_agent.query_rewriter_agent:app + ``` + +## Configuration + +```bash +# archgw LLM Gateway base URL (default: http://localhost:12000/v1) +export LLM_GATEWAY_ENDPOINT="http://localhost:12000/v1" +``` diff --git a/demos/use_cases/rag_agent/arch_config.yaml b/demos/use_cases/rag_agent/arch_config.yaml index b75bdc0c..c351fd9f 100644 --- a/demos/use_cases/rag_agent/arch_config.yaml +++ b/demos/use_cases/rag_agent/arch_config.yaml @@ -24,11 +24,6 @@ listeners: - query_rewriter - context_builder - response_generator - - name: research_agent - description: agent to research and gather information from various sources.
- filter_chain: - - research_agent - - response_generator port: 8001 - name: egress_traffic @@ -38,3 +33,5 @@ listeners: llm_providers: - access_key: ${OPENAI_API_KEY} model: openai/gpt-4o + - access_key: ${OPENAI_API_KEY} + model: openai/gpt-4o-mini diff --git a/demos/use_cases/rag_agent/pyproject.toml b/demos/use_cases/rag_agent/pyproject.toml index 3ca7faea..4e608192 100644 --- a/demos/use_cases/rag_agent/pyproject.toml +++ b/demos/use_cases/rag_agent/pyproject.toml @@ -11,6 +11,7 @@ dependencies = [ "pydantic>=2.11.7", "fastapi>=0.104.1", "uvicorn>=0.24.0", + "openai>=1.0.0", ] [project.scripts] diff --git a/demos/use_cases/rag_agent/src/rag_agent/api.py b/demos/use_cases/rag_agent/src/rag_agent/api.py new file mode 100644 index 00000000..292451c2 --- /dev/null +++ b/demos/use_cases/rag_agent/src/rag_agent/api.py @@ -0,0 +1,28 @@ +from pydantic import BaseModel +from typing import List, Optional, Dict, Any + + +class ChatMessage(BaseModel): + role: str + content: str + + +class ChatCompletionRequest(BaseModel): + model: str + messages: List[ChatMessage] + temperature: Optional[float] = 1.0 + max_tokens: Optional[int] = None + top_p: Optional[float] = 1.0 + frequency_penalty: Optional[float] = 0.0 + presence_penalty: Optional[float] = 0.0 + stream: Optional[bool] = False + stop: Optional[List[str]] = None + + +class ChatCompletionResponse(BaseModel): + id: str + object: str = "chat.completion" + created: int + model: str + choices: List[Dict[str, Any]] + usage: Dict[str, int] diff --git a/demos/use_cases/rag_agent/src/rag_agent/content_builder_agent.py b/demos/use_cases/rag_agent/src/rag_agent/content_builder_agent.py new file mode 100644 index 00000000..76998a7a --- /dev/null +++ b/demos/use_cases/rag_agent/src/rag_agent/content_builder_agent.py @@ -0,0 +1,223 @@ +from pydantic import BaseModel +from typing import List, Optional, Dict, Any +from fastapi import FastAPI, HTTPException +from openai import AsyncOpenAI +import os +import logging +import csv +from 
pathlib import Path + +from .api import ChatMessage, ChatCompletionRequest, ChatCompletionResponse + + +# Set up logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +# Configuration for archgw LLM gateway +LLM_GATEWAY_ENDPOINT = os.getenv("LLM_GATEWAY_ENDPOINT", "http://localhost:12000/v1") +RAG_MODEL = "gpt-4o-mini" + +# Initialize OpenAI client for archgw +archgw_client = AsyncOpenAI( + base_url=LLM_GATEWAY_ENDPOINT, + api_key="EMPTY", # archgw doesn't require a real API key +) + +# Global variable to store the knowledge base +knowledge_base = [] + + +def load_knowledge_base(): + """Load the basis_of_truth.csv file into memory on startup.""" + global knowledge_base + + # Get the path to the CSV file relative to this script + current_dir = Path(__file__).parent + csv_path = current_dir / "basis_of_truth.csv" + + try: + knowledge_base = [] + with open(csv_path, "r", encoding="utf-8") as file: + csv_reader = csv.DictReader(file) + for row in csv_reader: + knowledge_base.append({"path": row["path"], "content": row["content"]}) + + logger.info(f"Loaded {len(knowledge_base)} documents from knowledge base") + + except Exception as e: + logger.error(f"Error loading knowledge base: {e}") + knowledge_base = [] + + +async def find_relevant_passages(query: str, top_k: int = 3) -> List[Dict[str, str]]: + """Use the LLM to find the most relevant passages from the knowledge base.""" + + if not knowledge_base: + logger.warning("Knowledge base is empty") + return [] + + # Create a system prompt for passage selection + system_prompt = f"""You are a retrieval assistant that selects the most relevant document passages for a given query. + + Given a user query and a list of document passages, identify the {top_k} most relevant passages that would help answer the query. 
+ + Query: {query} + + Available passages: + """ + + # Add all passages with indices + for i, doc in enumerate(knowledge_base): + system_prompt += ( + f"\n[{i}] Path: {doc['path']}\nContent: {doc['content'][:500]}...\n" + ) + + system_prompt += f""" + + Please respond with ONLY the indices of the {top_k} most relevant passages, separated by commas (e.g., "0,3,7"). + If fewer than {top_k} passages are relevant, return only the relevant ones. + If no passages are relevant, return "NONE".""" + + try: + # Call archgw to select relevant passages + logger.info(f"Calling archgw to find relevant passages for query: '{query}'") + response = await archgw_client.chat.completions.create( + model=RAG_MODEL, + messages=[{"role": "system", "content": system_prompt}], + temperature=0.1, + max_tokens=50, + ) + + result = response.choices[0].message.content.strip() + logger.info(f"LLM selected passages: {result}") + + # Parse the indices + if result.upper() == "NONE": + return [] + + selected_passages = [] + indices = [ + int(idx.strip()) for idx in result.split(",") if idx.strip().isdigit() + ] + + for idx in indices: + if 0 <= idx < len(knowledge_base): + selected_passages.append(knowledge_base[idx]) + + logger.info(f"Selected {len(selected_passages)} relevant passages") + return selected_passages + + except Exception as e: + logger.error(f"Error finding relevant passages: {e}") + return [] + + +async def augment_query_with_context(messages: List[ChatMessage]) -> List[ChatMessage]: + """Extract user query, find relevant context, and augment the messages.""" + + # Find the last user message + last_user_message = None + last_user_index = -1 + + for i in range(len(messages) - 1, -1, -1): + if messages[i].role == "user": + last_user_message = messages[i].content + last_user_index = i + break + + if not last_user_message: + logger.warning("No user message found in conversation") + return messages + + logger.info(f"Processing user query: '{last_user_message}'") + + # Find relevant 
passages + relevant_passages = await find_relevant_passages(last_user_message) + + if not relevant_passages: + logger.info("No relevant passages found, returning original messages") + return messages + + # Build context from relevant passages + context_parts = [] + for i, passage in enumerate(relevant_passages): + context_parts.append( + f"Document {i+1} ({passage['path']}):\n{passage['content']}" + ) + + context = "\n\n".join(context_parts) + + # Create augmented content with original query and context + augmented_content = f"""{last_user_message} RELEVANT CONTEXT: + {context}""" + + # Create updated messages with the augmented query + updated_messages = messages.copy() + updated_messages[last_user_index] = ChatMessage( + role="user", content=augmented_content + ) + + logger.info(f"Augmented user query with {len(relevant_passages)} relevant passages") + + return updated_messages + + +class Response(BaseModel): + query: str + metadata: dict + + +# FastAPI app for REST server +app = FastAPI(title="RAG Content Builder Agent", version="1.0.0") + + +@app.post("/v1/chat/completions") +async def chat_completions(request: ChatCompletionRequest): + """Chat completions endpoint that augments user queries with relevant context from the knowledge base.""" + import time + import uuid + + logger.info( + f"Received chat completion request with {len(request.messages)} messages" + ) + + # Augment the user query with relevant context + updated_messages = await augment_query_with_context(request.messages) + + response = ChatCompletionResponse( + id=f"chatcmpl-{uuid.uuid4().hex[:8]}", + created=int(time.time()), + model=request.model, + choices=[ + { + "index": 0, + "messages": [ + {"role": msg.role, "content": msg.content} + for msg in updated_messages + ], + "finish_reason": "stop", + } + ], + usage={ + "prompt_tokens": sum(len(msg.content.split()) for msg in updated_messages), + "completion_tokens": len("Context added to user query.".split()), + "total_tokens": 
sum(len(msg.content.split()) for msg in updated_messages) + + len("Context added to user query.".split()), + }, + ) + + return response + + +def main(): + """Main function to initialize the knowledge base and start the server.""" + load_knowledge_base() + import uvicorn + + uvicorn.run(app, host="0.0.0.0", port=8000) + + +if __name__ == "__main__": + main() diff --git a/demos/use_cases/rag_agent/src/rag_agent/document_store.py b/demos/use_cases/rag_agent/src/rag_agent/document_store.py deleted file mode 100644 index 93dfc228..00000000 --- a/demos/use_cases/rag_agent/src/rag_agent/document_store.py +++ /dev/null @@ -1,18 +0,0 @@ -from pydantic import BaseModel -from . import mcp - - -class QueryRequest(BaseModel): - query: str - metadata: dict | None = None - - -class QueryResponse(BaseModel): - query: str - results: list - - -@mcp.tool() -def query_rag_store(request: QueryRequest): - """Query the RAG document store.""" - return {"query": request.query, "results": []} diff --git a/demos/use_cases/rag_agent/src/rag_agent/query_parser.py b/demos/use_cases/rag_agent/src/rag_agent/query_parser.py deleted file mode 100644 index 81f87063..00000000 --- a/demos/use_cases/rag_agent/src/rag_agent/query_parser.py +++ /dev/null @@ -1,101 +0,0 @@ -from pydantic import BaseModel -from typing import List, Optional, Dict, Any -from fastapi import FastAPI, HTTPException -import uvicorn - - -# OpenAI Chat Completions API models -class ChatMessage(BaseModel): - role: str - content: str - - -class ChatCompletionRequest(BaseModel): - model: str - messages: List[ChatMessage] - temperature: Optional[float] = 1.0 - max_tokens: Optional[int] = None - top_p: Optional[float] = 1.0 - frequency_penalty: Optional[float] = 0.0 - presence_penalty: Optional[float] = 0.0 - stream: Optional[bool] = False - stop: Optional[List[str]] = None - - -class ChatCompletionResponse(BaseModel): - id: str - object: str = "chat.completion" - created: int - model: str - choices: List[Dict[str, Any]] - usage: 
Dict[str, int] - - -class Response(BaseModel): - query: str - metadata: dict - - -# FastAPI app for REST server -app = FastAPI(title="RAG Agent Query Parser", version="1.0.0") - - -@app.post("/v1/chat/completions") -async def chat_completions(request: ChatCompletionRequest): - """Chat completions endpoint that passes through the request as-is.""" - import time - import uuid - - # Pass-through: return the last user message as the assistant response - last_user_message = "" - for message in reversed(request.messages): - if message.role == "user": - last_user_message = message.content - break - - response = ChatCompletionResponse( - id=f"chatcmpl-{uuid.uuid4().hex[:8]}", - created=int(time.time()), - model=request.model, - choices=[ - { - "index": 0, - "message": {"role": "assistant", "content": last_user_message}, - "finish_reason": "stop", - } - ], - usage={ - "prompt_tokens": sum(len(msg.content.split()) for msg in request.messages), - "completion_tokens": len(last_user_message.split()), - "total_tokens": sum(len(msg.content.split()) for msg in request.messages) - + len(last_user_message.split()), - }, - ) - - return response - - -@app.get("/health") -async def health_check(): - """Health check endpoint.""" - return {"status": "healthy"} - - -def parse_query(query): - """Parse the user query and returns metadata extracted from query.""" - return Response(query=query, metadata={"is_valid": True}) - - -# Register MCP tool only if mcp is available -try: - from . 
import mcp - - if mcp is not None: - mcp.tool()(parse_query) -except (ImportError, AttributeError): - pass - - -def start_server(host: str = "localhost", port: int = 8000): - """Start the REST server.""" - uvicorn.run(app, host=host, port=port) diff --git a/demos/use_cases/rag_agent/src/rag_agent/query_rewriter_agent.py b/demos/use_cases/rag_agent/src/rag_agent/query_rewriter_agent.py new file mode 100644 index 00000000..eadc66b6 --- /dev/null +++ b/demos/use_cases/rag_agent/src/rag_agent/query_rewriter_agent.py @@ -0,0 +1,140 @@ +from pydantic import BaseModel +from typing import List, Optional, Dict, Any +from fastapi import FastAPI, HTTPException +from openai import AsyncOpenAI +import os +import logging + +from .api import ChatMessage, ChatCompletionRequest, ChatCompletionResponse + + +# Set up logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +# Configuration for archgw LLM gateway +LLM_GATEWAY_ENDPOINT = os.getenv("LLM_GATEWAY_ENDPOINT", "http://localhost:12000/v1") +QUERY_REWRITE_MODEL = "gpt-4o-mini" + +# Initialize OpenAI client for archgw +archgw_client = AsyncOpenAI( + base_url=LLM_GATEWAY_ENDPOINT, + api_key="EMPTY", # archgw doesn't require a real API key +) + + +async def rewrite_query_with_archgw(messages: List[ChatMessage]) -> str: + # Prepare the system prompt for query rewriting + system_prompt = """You are a query rewriter that improves user queries for better retrieval. + + Given a conversation history, rewrite the last user message to be more specific and context-aware. + The rewritten query should: + 1. Include relevant context from previous messages + 2. Be clear and specific for information retrieval + 3. Maintain the user's intent + 4. 
Be concise but comprehensive + + Return only the rewritten query, nothing else.""" + + # Prepare messages for the query rewriter - just add system prompt to existing messages + rewrite_messages = [{"role": "system", "content": system_prompt}] + + # Add conversation history + for msg in messages: + rewrite_messages.append({"role": msg.role, "content": msg.content}) + + try: + # Call archgw using OpenAI client + logger.info(f"Calling archgw at {LLM_GATEWAY_ENDPOINT} to rewrite query") + response = await archgw_client.chat.completions.create( + model=QUERY_REWRITE_MODEL, + messages=rewrite_messages, + temperature=0.3, + max_tokens=200, + ) + + rewritten_query = response.choices[0].message.content.strip() + logger.info(f"Query rewritten successfully: '{rewritten_query}'") + return rewritten_query + + except Exception as e: + logger.error(f"Error rewriting query: {e}") + + # If rewriting fails, return the original last user message + logger.info("Falling back to original user message") + for message in reversed(messages): + if message.role == "user": + return message.content + return "" + + +class Response(BaseModel): + query: str + metadata: dict + + +# FastAPI app for REST server +app = FastAPI(title="RAG Agent Query Parser", version="1.0.0") + + +@app.post("/v1/chat/completions") +async def chat_completions(request: ChatCompletionRequest): + """Chat completions endpoint that rewrites the last user query using archgw.""" + import time + import uuid + + logger.info( + f"Received chat completion request with {len(request.messages)} messages" + ) + + # Call archgw to rewrite the last user query + rewritten_query = await rewrite_query_with_archgw(request.messages) + + # Create updated messages with the rewritten query + updated_messages = request.messages.copy() + + # Find and update the last user message with the rewritten query + for i in range(len(updated_messages) - 1, -1, -1): + if updated_messages[i].role == "user": + original_query = updated_messages[i].content + 
updated_messages[i] = ChatMessage(role="user", content=rewritten_query) + logger.info( + f"Updated user query from '{original_query}' to '{rewritten_query}'" + ) + break + + response = ChatCompletionResponse( + id=f"chatcmpl-{uuid.uuid4().hex[:8]}", + created=int(time.time()), + model=request.model, + choices=[ + { + "index": 0, + "messages": [ + {"role": msg.role, "content": msg.content} + for msg in updated_messages + ], + "finish_reason": "stop", + } + ], + usage={ + "prompt_tokens": sum(len(msg.content.split()) for msg in updated_messages), + "completion_tokens": len("Updated query for better retrieval.".split()), + "total_tokens": sum(len(msg.content.split()) for msg in updated_messages) + + len("Updated query for better retrieval.".split()), + }, + ) + + return response + + +@app.get("/health") +async def health_check(): + """Health check endpoint.""" + return {"status": "healthy"} + + +def parse_query(query): + """Parse the user query and returns metadata extracted from query.""" + return Response(query=query, metadata={"is_valid": True}) diff --git a/demos/use_cases/rag_agent/src/rag_agent/response_generator.py b/demos/use_cases/rag_agent/src/rag_agent/response_generator.py deleted file mode 100644 index a612c626..00000000 --- a/demos/use_cases/rag_agent/src/rag_agent/response_generator.py +++ /dev/null @@ -1,11 +0,0 @@ -from . 
import mcp - - -@mcp.tool() -def generate_response(query, context): - """Generate a response based on the user query and context.""" - return { - "query": query, - "context": context, - "response": "This is a generated response.", - } diff --git a/demos/use_cases/rag_agent/src/rag_agent/response_generator_agent.py b/demos/use_cases/rag_agent/src/rag_agent/response_generator_agent.py new file mode 100644 index 00000000..b3189ace --- /dev/null +++ b/demos/use_cases/rag_agent/src/rag_agent/response_generator_agent.py @@ -0,0 +1,122 @@ +from fastapi import FastAPI +from openai import AsyncOpenAI +import os +import logging +import time +import uuid + +from .api import ChatCompletionRequest, ChatCompletionResponse + +# Set up logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +# Configuration for archgw LLM gateway +LLM_GATEWAY_ENDPOINT = os.getenv("LLM_GATEWAY_ENDPOINT", "http://localhost:12000/v1") +RESPONSE_MODEL = "gpt-4o" + +# Initialize OpenAI client for archgw +archgw_client = AsyncOpenAI( + base_url=LLM_GATEWAY_ENDPOINT, + api_key="EMPTY", # archgw doesn't require a real API key +) + +# FastAPI app for REST server +app = FastAPI(title="RAG Agent Response Generator", version="1.0.0") + + +@app.post("/v1/chat/completions") +async def chat_completions(request: ChatCompletionRequest): + """Chat completions endpoint that generates a coherent response based on all context.""" + logger.info( + f"Received chat completion request with {len(request.messages)} messages" + ) + + # Prepare the system prompt for response generation + system_prompt = """You are a helpful assistant that generates coherent, contextual responses. + + Given a conversation history, generate a helpful and relevant response based on all the context available in the messages. + Your response should: + 1. Be contextually aware of the entire conversation + 2. Address the user's needs appropriately + 3. Be helpful and informative + 4. 
Maintain a natural conversational tone + + Generate a complete response to assist the user.""" + + # Prepare messages for response generation + response_messages = [{"role": "system", "content": system_prompt}] + + # Add conversation history + for msg in request.messages: + response_messages.append({"role": msg.role, "content": msg.content}) + + try: + # Call archgw using OpenAI client + logger.info(f"Calling archgw at {LLM_GATEWAY_ENDPOINT} to generate response") + response = await archgw_client.chat.completions.create( + model=RESPONSE_MODEL, + messages=response_messages, + temperature=request.temperature or 0.7, + max_tokens=request.max_tokens or 1000, + ) + + generated_response = response.choices[0].message.content.strip() + logger.info(f"Response generated successfully") + + return ChatCompletionResponse( + id=f"chatcmpl-{uuid.uuid4().hex[:8]}", + created=int(time.time()), + model=request.model, + choices=[ + { + "index": 0, + "message": {"role": "assistant", "content": generated_response}, + "finish_reason": "stop", + } + ], + usage={ + "prompt_tokens": sum( + len(msg.content.split()) for msg in request.messages + ), + "completion_tokens": len(generated_response.split()), + "total_tokens": sum( + len(msg.content.split()) for msg in request.messages + ) + + len(generated_response.split()), + }, + ) + + except Exception as e: + logger.error(f"Error generating response: {e}") + + # Fallback response + fallback_message = "I apologize, but I'm having trouble generating a response right now. Please try again." 
+ return ChatCompletionResponse( + id=f"chatcmpl-{uuid.uuid4().hex[:8]}", + created=int(time.time()), + model=request.model, + choices=[ + { + "index": 0, + "message": {"role": "assistant", "content": fallback_message}, + "finish_reason": "stop", + } + ], + usage={ + "prompt_tokens": sum( + len(msg.content.split()) for msg in request.messages + ), + "completion_tokens": len(fallback_message.split()), + "total_tokens": sum( + len(msg.content.split()) for msg in request.messages + ) + + len(fallback_message.split()), + }, + ) + + +@app.get("/health") +async def health_check(): + """Health check endpoint.""" + return {"status": "healthy"}