changes to the agents

This commit is contained in:
Salman Paracha 2025-09-11 13:25:09 -07:00
parent c1e142f55f
commit 32838584cf
10 changed files with 544 additions and 135 deletions

View file

@ -0,0 +1,28 @@
# RAG Agent Query Parser
A FastAPI service that rewrites user queries using archgw and gpt-4o-mini for better retrieval accuracy.
## How it Works
1. Receives a chat completion request with conversation history
2. Calls archgw's LLM gateway with gpt-4o-mini to rewrite the last user query
3. Returns the conversation messages with the last user message replaced by the rewritten query
## Setup and Running
1. **Start archgw**:
```bash
archgw up --foreground
```
2. **Start the query parser service**:
```bash
uv run python -m rag_agent.query_parser
```
## Configuration
```bash
# archgw LLM Gateway base URL (default: http://localhost:12000/v1)
export LLM_GATEWAY_ENDPOINT="http://localhost:12000/v1"
```

View file

@ -24,11 +24,6 @@ listeners:
- query_rewriter
- context_builder
- response_generator
- name: research_agent
description: agent to research and gather information from various sources.
filter_chain:
- research_agent
- response_generator
port: 8001
- name: egress_traffic
@ -38,3 +33,5 @@ listeners:
llm_providers:
- access_key: ${OPENAI_API_KEY}
model: openai/gpt-4o
- access_key: ${OPENAI_API_KEY}
model: openai/gpt-4o-mini

View file

@ -11,6 +11,7 @@ dependencies = [
"pydantic>=2.11.7",
"fastapi>=0.104.1",
"uvicorn>=0.24.0",
"openai>=1.0.0",
]
[project.scripts]

View file

@ -0,0 +1,28 @@
from pydantic import BaseModel
from typing import List, Optional, Dict, Any
class ChatMessage(BaseModel):
    """One conversation turn: who spoke (``role``) and what was said (``content``)."""

    role: str
    content: str
class ChatCompletionRequest(BaseModel):
    """Chat completion request payload.

    Only ``model`` and ``messages`` are required; every sampling option
    carries the default declared below.
    """

    # Required fields.
    model: str
    messages: List[ChatMessage]

    # Optional sampling / generation controls.
    temperature: Optional[float] = 1.0
    max_tokens: Optional[int] = None
    top_p: Optional[float] = 1.0
    frequency_penalty: Optional[float] = 0.0
    presence_penalty: Optional[float] = 0.0
    stream: Optional[bool] = False
    stop: Optional[List[str]] = None
class ChatCompletionResponse(BaseModel):
    """Chat completion response payload.

    ``choices`` is a list of free-form dicts and ``usage`` maps counter names
    to integers; ``object`` defaults to ``"chat.completion"``.
    """

    id: str
    object: str = "chat.completion"
    created: int
    model: str
    choices: List[Dict[str, Any]]
    usage: Dict[str, int]

View file

@ -0,0 +1,223 @@
from pydantic import BaseModel
from typing import List, Optional, Dict, Any
from fastapi import FastAPI, HTTPException
from openai import AsyncOpenAI
import os
import logging
import csv
from pathlib import Path
from .api import ChatMessage, ChatCompletionRequest, ChatCompletionResponse
# Set up logging: configure the root logger at INFO when this module is
# imported, and create the module-level logger used by the functions below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Configuration for archgw LLM gateway. The base URL can be overridden with
# the LLM_GATEWAY_ENDPOINT environment variable; the model name is fixed.
LLM_GATEWAY_ENDPOINT = os.getenv("LLM_GATEWAY_ENDPOINT", "http://localhost:12000/v1")
RAG_MODEL = "gpt-4o-mini"
# Initialize OpenAI client for archgw (shared by every request in this module)
archgw_client = AsyncOpenAI(
    base_url=LLM_GATEWAY_ENDPOINT,
    api_key="EMPTY",  # archgw doesn't require a real API key
)
# Global variable to store the knowledge base: a list of
# {"path": ..., "content": ...} dicts filled in by load_knowledge_base().
knowledge_base = []
def load_knowledge_base():
    """Load ``basis_of_truth.csv`` (next to this module) into ``knowledge_base``.

    Every CSV row must provide ``path`` and ``content`` columns; only those
    two columns are kept. The global list is replaced in one step once the
    whole file has parsed. On any failure the error is logged and the
    knowledge base is left empty instead of raising.
    """
    global knowledge_base

    # Get the path to the CSV file relative to this script
    csv_path = Path(__file__).parent / "basis_of_truth.csv"
    try:
        # newline="" is what the csv module requires so that line breaks
        # inside quoted fields reach the parser untouched.
        with open(csv_path, "r", encoding="utf-8", newline="") as file:
            documents = [
                {"path": row["path"], "content": row["content"]}
                for row in csv.DictReader(file)
            ]
    except Exception as e:
        logger.error(f"Error loading knowledge base: {e}")
        knowledge_base = []
    else:
        knowledge_base = documents
        logger.info(f"Loaded {len(knowledge_base)} documents from knowledge base")
async def find_relevant_passages(query: str, top_k: int = 3) -> List[Dict[str, str]]:
    """Use the LLM to find the most relevant passages from the knowledge base.

    The prompt lists every knowledge-base entry (content truncated to 500
    characters) with its index and asks the model for a comma-separated list
    of indices, or ``NONE``.

    Args:
        query: The user question to retrieve context for.
        top_k: Maximum number of passages to return.

    Returns:
        At most ``top_k`` distinct knowledge-base entries, in the order the
        model listed them. Returns an empty list when the knowledge base is
        empty, the model answers ``NONE``, nothing parseable comes back, or
        the gateway call fails (the error is logged, not raised).
    """
    if not knowledge_base:
        logger.warning("Knowledge base is empty")
        return []

    # Create a system prompt for passage selection
    header = f"""You are a retrieval assistant that selects the most relevant document passages for a given query.
Given a user query and a list of document passages, identify the {top_k} most relevant passages that would help answer the query.
Query: {query}
Available passages:
"""
    # Add all passages with indices (joined once instead of repeated +=)
    listing = "".join(
        f"\n[{i}] Path: {doc['path']}\nContent: {doc['content'][:500]}...\n"
        for i, doc in enumerate(knowledge_base)
    )
    footer = f"""
Please respond with ONLY the indices of the {top_k} most relevant passages, separated by commas (e.g., "0,3,7").
If fewer than {top_k} passages are relevant, return only the relevant ones.
If no passages are relevant, return "NONE"."""
    system_prompt = header + listing + footer

    try:
        # Call archgw to select relevant passages
        logger.info(f"Calling archgw to find relevant passages for query: '{query}'")
        response = await archgw_client.chat.completions.create(
            model=RAG_MODEL,
            messages=[{"role": "system", "content": system_prompt}],
            temperature=0.1,
            max_tokens=50,
        )
        # The message content may be None; treat that as "nothing selected".
        result = (response.choices[0].message.content or "").strip()
        logger.info(f"LLM selected passages: {result}")

        if result.upper() == "NONE":
            return []

        selected_passages = []
        seen = set()
        # Split on commas and whitespace, and tolerate light decoration such
        # as "[0]" or a trailing period around each index.
        for token in result.replace(",", " ").split():
            if len(selected_passages) >= top_k:
                break  # never hand back more than the caller asked for
            token = token.strip("[](){}\"'.")
            if not token.isdecimal():
                continue
            idx = int(token)
            if idx in seen or idx >= len(knowledge_base):
                continue
            seen.add(idx)
            selected_passages.append(knowledge_base[idx])

        logger.info(f"Selected {len(selected_passages)} relevant passages")
        return selected_passages
    except Exception as e:
        # Retrieval is best-effort: a gateway failure degrades to "no context".
        logger.error(f"Error finding relevant passages: {e}")
        return []
async def augment_query_with_context(messages: List[ChatMessage]) -> List[ChatMessage]:
    """Append retrieved knowledge-base context to the last user message.

    Returns ``messages`` itself when there is no (non-empty) user message or
    no relevant passage is found; otherwise returns a shallow copy in which
    the last user message carries the original text plus the context.
    """
    # Locate the most recent user turn.
    user_positions = [pos for pos, msg in enumerate(messages) if msg.role == "user"]
    target = user_positions[-1] if user_positions else -1
    question = messages[target].content if user_positions else None
    if not question:
        logger.warning("No user message found in conversation")
        return messages

    logger.info(f"Processing user query: '{question}'")

    passages = await find_relevant_passages(question)
    if not passages:
        logger.info("No relevant passages found, returning original messages")
        return messages

    # One "Document N (path):" section per passage, separated by blank lines.
    context = "\n\n".join(
        f"Document {number} ({passage['path']}):\n{passage['content']}"
        for number, passage in enumerate(passages, start=1)
    )

    enriched = list(messages)
    enriched[target] = ChatMessage(
        role="user",
        content=f"{question} RELEVANT CONTEXT:\n{context}",
    )
    logger.info(f"Augmented user query with {len(passages)} relevant passages")
    return enriched
class Response(BaseModel):
    """A query string paired with a free-form metadata dict."""

    query: str
    metadata: dict
# FastAPI application exposing the context-builder endpoint
app = FastAPI(title="RAG Content Builder Agent", version="1.0.0")


@app.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    """Return the request's messages with knowledge-base context added.

    The single choice carries the full (possibly augmented) message list
    under ``"messages"``. Usage figures are whitespace-separated word counts.
    """
    import time
    import uuid

    logger.info(
        f"Received chat completion request with {len(request.messages)} messages"
    )

    updated_messages = await augment_query_with_context(request.messages)

    prompt_tokens = sum(len(msg.content.split()) for msg in updated_messages)
    completion_tokens = len("Context added to user query.".split())
    choice = {
        "index": 0,
        "messages": [
            {"role": msg.role, "content": msg.content} for msg in updated_messages
        ],
        "finish_reason": "stop",
    }
    return ChatCompletionResponse(
        id=f"chatcmpl-{uuid.uuid4().hex[:8]}",
        created=int(time.time()),
        model=request.model,
        choices=[choice],
        usage={
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        },
    )
def main():
    """Load the knowledge base, then serve ``app`` with uvicorn on 0.0.0.0:8000."""
    load_knowledge_base()

    import uvicorn

    bind = {"host": "0.0.0.0", "port": 8000}
    uvicorn.run(app, **bind)


if __name__ == "__main__":
    main()

View file

@ -1,18 +0,0 @@
from pydantic import BaseModel
from . import mcp
class QueryRequest(BaseModel):
    """Input for a RAG store lookup: the query text plus optional metadata."""

    query: str
    metadata: dict | None = None
class QueryResponse(BaseModel):
    """Result of a RAG store lookup: the query text and a list of results."""

    query: str
    results: list
@mcp.tool()
def query_rag_store(request: QueryRequest):
    """Query the RAG document store.

    Placeholder: echoes the query back with an empty result list.
    """
    payload = dict(query=request.query, results=[])
    return payload

View file

@ -1,101 +0,0 @@
from pydantic import BaseModel
from typing import List, Optional, Dict, Any
from fastapi import FastAPI, HTTPException
import uvicorn
# OpenAI Chat Completions API models
class ChatMessage(BaseModel):
    """One conversation turn: who spoke (``role``) and what was said (``content``)."""

    role: str
    content: str
class ChatCompletionRequest(BaseModel):
    """Chat completion request payload.

    Only ``model`` and ``messages`` are required; every sampling option
    carries the default declared below.
    """

    # Required fields.
    model: str
    messages: List[ChatMessage]

    # Optional sampling / generation controls.
    temperature: Optional[float] = 1.0
    max_tokens: Optional[int] = None
    top_p: Optional[float] = 1.0
    frequency_penalty: Optional[float] = 0.0
    presence_penalty: Optional[float] = 0.0
    stream: Optional[bool] = False
    stop: Optional[List[str]] = None
class ChatCompletionResponse(BaseModel):
    """Chat completion response payload.

    ``choices`` is a list of free-form dicts and ``usage`` maps counter names
    to integers; ``object`` defaults to ``"chat.completion"``.
    """

    id: str
    object: str = "chat.completion"
    created: int
    model: str
    choices: List[Dict[str, Any]]
    usage: Dict[str, int]
class Response(BaseModel):
    """A query string paired with a free-form metadata dict."""

    query: str
    metadata: dict
# FastAPI application exposing the pass-through endpoint
app = FastAPI(title="RAG Agent Query Parser", version="1.0.0")


@app.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    """Echo the last user message back as the assistant reply.

    If the request holds no user message the reply content is an empty
    string. Usage figures are whitespace-separated word counts.
    """
    import time
    import uuid

    # Pass-through: the most recent user turn becomes the assistant content.
    echoed = next(
        (msg.content for msg in reversed(request.messages) if msg.role == "user"),
        "",
    )

    prompt_tokens = sum(len(msg.content.split()) for msg in request.messages)
    completion_tokens = len(echoed.split())
    return ChatCompletionResponse(
        id=f"chatcmpl-{uuid.uuid4().hex[:8]}",
        created=int(time.time()),
        model=request.model,
        choices=[
            {
                "index": 0,
                "message": {"role": "assistant", "content": echoed},
                "finish_reason": "stop",
            }
        ],
        usage={
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        },
    )
@app.get("/health")
async def health_check():
    """Health probe: always answers with a "healthy" status."""
    status_payload = {"status": "healthy"}
    return status_payload
def parse_query(query):
    """Wrap *query* in a ``Response`` whose metadata marks it as valid."""
    metadata = {"is_valid": True}
    return Response(query=query, metadata=metadata)
# Best-effort MCP registration: expose parse_query as a tool when the package
# provides a usable `mcp` object; otherwise carry on without it.
try:
    from . import mcp

    if mcp is not None:
        mcp.tool()(parse_query)
except (ImportError, AttributeError):
    # No MCP support available; the REST endpoint still works.
    pass
def start_server(host: str = "localhost", port: int = 8000):
    """Serve ``app`` with uvicorn on the given host and port."""
    bind = {"host": host, "port": port}
    uvicorn.run(app, **bind)

View file

@ -0,0 +1,140 @@
from pydantic import BaseModel
from typing import List, Optional, Dict, Any
from fastapi import FastAPI, HTTPException
from openai import AsyncOpenAI
import os
import logging
from .api import ChatMessage, ChatCompletionRequest, ChatCompletionResponse
# Set up logging: configure the root logger at INFO when this module is
# imported, and create the module-level logger used by the functions below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Configuration for archgw LLM gateway. The base URL can be overridden with
# the LLM_GATEWAY_ENDPOINT environment variable; the model name is fixed.
LLM_GATEWAY_ENDPOINT = os.getenv("LLM_GATEWAY_ENDPOINT", "http://localhost:12000/v1")
QUERY_REWRITE_MODEL = "gpt-4o-mini"
# Initialize OpenAI client for archgw (shared by every request in this module)
archgw_client = AsyncOpenAI(
    base_url=LLM_GATEWAY_ENDPOINT,
    api_key="EMPTY",  # archgw doesn't require a real API key
)
def _last_user_content(messages: List[ChatMessage]) -> str:
    """Return the content of the most recent user message, or "" if there is none."""
    return next((msg.content for msg in reversed(messages) if msg.role == "user"), "")


async def rewrite_query_with_archgw(messages: List[ChatMessage]) -> str:
    """Ask the gateway model to rewrite the last user message for retrieval.

    Args:
        messages: The conversation so far; it is sent in full after the
            rewriting system prompt.

    Returns:
        The rewritten query. If the gateway call fails, or the model returns
        an empty answer, the original last user message is returned instead
        ("" when the conversation has no user message), so a bad rewrite can
        never blank out the user's question.
    """
    # Prepare the system prompt for query rewriting
    system_prompt = """You are a query rewriter that improves user queries for better retrieval.
Given a conversation history, rewrite the last user message to be more specific and context-aware.
The rewritten query should:
1. Include relevant context from previous messages
2. Be clear and specific for information retrieval
3. Maintain the user's intent
4. Be concise but comprehensive
Return only the rewritten query, nothing else."""

    # System prompt first, then the conversation history unchanged.
    rewrite_messages = [{"role": "system", "content": system_prompt}]
    rewrite_messages.extend(
        {"role": msg.role, "content": msg.content} for msg in messages
    )

    try:
        # Call archgw using OpenAI client
        logger.info(f"Calling archgw at {LLM_GATEWAY_ENDPOINT} to rewrite query")
        response = await archgw_client.chat.completions.create(
            model=QUERY_REWRITE_MODEL,
            messages=rewrite_messages,
            temperature=0.3,
            max_tokens=200,
        )
        # The message content may be None; normalise it to "" and handle below.
        rewritten_query = (response.choices[0].message.content or "").strip()
    except Exception as e:
        logger.error(f"Error rewriting query: {e}")
        # If rewriting fails, return the original last user message
        logger.info("Falling back to original user message")
        return _last_user_content(messages)

    if not rewritten_query:
        logger.warning("Rewriter returned an empty query")
        logger.info("Falling back to original user message")
        return _last_user_content(messages)

    logger.info(f"Query rewritten successfully: '{rewritten_query}'")
    return rewritten_query
class Response(BaseModel):
    """A query string paired with a free-form metadata dict."""

    query: str
    metadata: dict
# FastAPI application exposing the query-rewriting endpoint
app = FastAPI(title="RAG Agent Query Parser", version="1.0.0")


@app.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    """Return the request's messages with the last user query rewritten.

    The single choice carries the full message list under ``"messages"``;
    only the most recent user message is replaced. Usage figures are
    whitespace-separated word counts.
    """
    import time
    import uuid

    logger.info(
        f"Received chat completion request with {len(request.messages)} messages"
    )

    rewritten_query = await rewrite_query_with_archgw(request.messages)

    # Work on a shallow copy so the request's own list is left untouched.
    updated_messages = list(request.messages)
    last_user = next(
        (
            pos
            for pos in reversed(range(len(updated_messages)))
            if updated_messages[pos].role == "user"
        ),
        None,
    )
    if last_user is not None:
        original_query = updated_messages[last_user].content
        updated_messages[last_user] = ChatMessage(role="user", content=rewritten_query)
        logger.info(
            f"Updated user query from '{original_query}' to '{rewritten_query}'"
        )

    prompt_tokens = sum(len(msg.content.split()) for msg in updated_messages)
    completion_tokens = len("Updated query for better retrieval.".split())
    choice = {
        "index": 0,
        "messages": [
            {"role": msg.role, "content": msg.content} for msg in updated_messages
        ],
        "finish_reason": "stop",
    }
    return ChatCompletionResponse(
        id=f"chatcmpl-{uuid.uuid4().hex[:8]}",
        created=int(time.time()),
        model=request.model,
        choices=[choice],
        usage={
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        },
    )
@app.get("/health")
async def health_check():
    """Health probe: always answers with a "healthy" status."""
    status_payload = {"status": "healthy"}
    return status_payload
def parse_query(query):
    """Wrap *query* in a ``Response`` whose metadata marks it as valid."""
    metadata = {"is_valid": True}
    return Response(query=query, metadata=metadata)

View file

@ -1,11 +0,0 @@
from . import mcp
@mcp.tool()
def generate_response(query, context):
    """Generate a response based on the user query and context.

    Placeholder: echoes both inputs alongside a fixed response string.
    """
    result = dict(query=query, context=context)
    result["response"] = "This is a generated response."
    return result

View file

@ -0,0 +1,122 @@
from fastapi import FastAPI
from openai import AsyncOpenAI
import os
import logging
import time
import uuid
from .api import ChatCompletionRequest, ChatCompletionResponse
# Set up logging: configure the root logger at INFO when this module is
# imported, and create the module-level logger used by the endpoint below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Configuration for archgw LLM gateway. The base URL can be overridden with
# the LLM_GATEWAY_ENDPOINT environment variable; the model name is fixed.
LLM_GATEWAY_ENDPOINT = os.getenv("LLM_GATEWAY_ENDPOINT", "http://localhost:12000/v1")
RESPONSE_MODEL = "gpt-4o"
# Initialize OpenAI client for archgw (shared by every request in this module)
archgw_client = AsyncOpenAI(
    base_url=LLM_GATEWAY_ENDPOINT,
    api_key="EMPTY",  # archgw doesn't require a real API key
)
# FastAPI app for REST server; the route decorators below register on it
app = FastAPI(title="RAG Agent Response Generator", version="1.0.0")
def _completion_response(
    request: ChatCompletionRequest, content: str
) -> ChatCompletionResponse:
    """Wrap *content* as a single-choice assistant completion for *request*.

    Usage figures are whitespace-separated word counts, not real token counts.
    """
    prompt_tokens = sum(len(msg.content.split()) for msg in request.messages)
    completion_tokens = len(content.split())
    return ChatCompletionResponse(
        id=f"chatcmpl-{uuid.uuid4().hex[:8]}",
        created=int(time.time()),
        model=request.model,
        choices=[
            {
                "index": 0,
                "message": {"role": "assistant", "content": content},
                "finish_reason": "stop",
            }
        ],
        usage={
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        },
    )


@app.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    """Chat completions endpoint that generates a coherent response based on all context.

    The conversation is forwarded to the gateway behind a fixed system
    prompt. Any failure while generating is logged and answered with a fixed
    apology message instead of an error, so the endpoint always returns a
    completion-shaped body.
    """
    logger.info(
        f"Received chat completion request with {len(request.messages)} messages"
    )

    # Prepare the system prompt for response generation
    system_prompt = """You are a helpful assistant that generates coherent, contextual responses.
Given a conversation history, generate a helpful and relevant response based on all the context available in the messages.
Your response should:
1. Be contextually aware of the entire conversation
2. Address the user's needs appropriately
3. Be helpful and informative
4. Maintain a natural conversational tone
Generate a complete response to assist the user."""

    # System prompt first, then the conversation history unchanged.
    response_messages = [{"role": "system", "content": system_prompt}]
    response_messages.extend(
        {"role": msg.role, "content": msg.content} for msg in request.messages
    )

    # Only a missing temperature gets the 0.7 default: 0.0 is a legitimate
    # value and must not be overridden by a truthiness test.
    temperature = request.temperature if request.temperature is not None else 0.7
    # A missing (or zero) max_tokens falls back to 1000, as before.
    max_tokens = request.max_tokens or 1000

    try:
        # Call archgw using OpenAI client
        logger.info(f"Calling archgw at {LLM_GATEWAY_ENDPOINT} to generate response")
        response = await archgw_client.chat.completions.create(
            model=RESPONSE_MODEL,
            messages=response_messages,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        # A None content raises here and lands in the fallback below.
        generated_response = response.choices[0].message.content.strip()
        logger.info("Response generated successfully")
        return _completion_response(request, generated_response)
    except Exception as e:
        logger.error(f"Error generating response: {e}")
        # Fallback response
        fallback_message = "I apologize, but I'm having trouble generating a response right now. Please try again."
        return _completion_response(request, fallback_message)
@app.get("/health")
async def health_check():
    """Health probe: always answers with a "healthy" status."""
    status_payload = {"status": "healthy"}
    return status_payload