mirror of
https://github.com/katanemo/plano.git
synced 2026-06-29 15:49:40 +02:00
changes to the agents
This commit is contained in:
parent
c1e142f55f
commit
32838584cf
10 changed files with 544 additions and 135 deletions
|
|
@ -0,0 +1,28 @@
|
||||||
|
# RAG Agent Query Parser
|
||||||
|
|
||||||
|
A FastAPI service that rewrites user queries using archgw and gpt-4o-mini for better retrieval accuracy.
|
||||||
|
|
||||||
|
## How it Works
|
||||||
|
|
||||||
|
1. Receives a chat completion request with conversation history
|
||||||
|
2. Calls archgw's LLM gateway with gpt-4o-mini to rewrite the last user query
|
||||||
|
3. Returns the rewritten query as the assistant response
|
||||||
|
|
||||||
|
## Setup and Running
|
||||||
|
|
||||||
|
1. **Start archgw**:
|
||||||
|
```bash
|
||||||
|
archgw up --foreground
|
||||||
|
```
|
||||||
|
|
||||||
|
2. **Start the query parser service**:
|
||||||
|
```bash
|
||||||
|
uv run python -m rag_agent.query_parser
|
||||||
|
```
|
||||||
|
|
||||||
|
## Configuration
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# archgw LLM Gateway base URL (default: http://localhost:12000/v1)
|
||||||
|
export LLM_GATEWAY_ENDPOINT="http://localhost:12000/v1"
|
||||||
|
```
|
||||||
|
|
@ -24,11 +24,6 @@ listeners:
|
||||||
- query_rewriter
|
- query_rewriter
|
||||||
- context_builder
|
- context_builder
|
||||||
- response_generator
|
- response_generator
|
||||||
- name: research_agent
|
|
||||||
description: agent to research and gather information from various sources.
|
|
||||||
filter_chain:
|
|
||||||
- research_agent
|
|
||||||
- response_generator
|
|
||||||
port: 8001
|
port: 8001
|
||||||
|
|
||||||
- name: egress_traffic
|
- name: egress_traffic
|
||||||
|
|
@ -38,3 +33,5 @@ listeners:
|
||||||
llm_providers:
|
llm_providers:
|
||||||
- access_key: ${OPENAI_API_KEY}
|
- access_key: ${OPENAI_API_KEY}
|
||||||
model: openai/gpt-4o
|
model: openai/gpt-4o
|
||||||
|
- access_key: ${OPENAI_API_KEY}
|
||||||
|
model: openai/gpt-4o-mini
|
||||||
|
|
|
||||||
|
|
@ -11,6 +11,7 @@ dependencies = [
|
||||||
"pydantic>=2.11.7",
|
"pydantic>=2.11.7",
|
||||||
"fastapi>=0.104.1",
|
"fastapi>=0.104.1",
|
||||||
"uvicorn>=0.24.0",
|
"uvicorn>=0.24.0",
|
||||||
|
"openai>=1.0.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
|
|
|
||||||
28
demos/use_cases/rag_agent/src/rag_agent/api.py
Normal file
28
demos/use_cases/rag_agent/src/rag_agent/api.py
Normal file
|
|
@ -0,0 +1,28 @@
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from typing import List, Optional, Dict, Any
|
||||||
|
|
||||||
|
|
||||||
|
class ChatMessage(BaseModel):
    """A single turn in a chat conversation."""

    # Who produced the turn (the agents in this package match on "user").
    role: str
    # Text of the turn.
    content: str
|
||||||
|
|
||||||
|
|
||||||
|
class ChatCompletionRequest(BaseModel):
    """Request body for a chat-completions style endpoint."""

    # Model name requested by the caller.
    model: str
    # Conversation history, in order.
    messages: List[ChatMessage]

    # Optional generation settings; each carries the default shown.
    temperature: Optional[float] = 1.0
    max_tokens: Optional[int] = None
    top_p: Optional[float] = 1.0
    frequency_penalty: Optional[float] = 0.0
    presence_penalty: Optional[float] = 0.0
    stream: Optional[bool] = False
    stop: Optional[List[str]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class ChatCompletionResponse(BaseModel):
    """Response body for a chat-completions style endpoint."""

    # Identifier of this completion.
    id: str
    # Object tag; defaults to "chat.completion".
    object: str = "chat.completion"
    # Creation time as an integer timestamp.
    created: int
    # Model name reported back to the caller.
    model: str
    # Free-form choice dictionaries; their layout is decided by each endpoint.
    choices: List[Dict[str, Any]]
    # Integer counters keyed by name (e.g. "prompt_tokens").
    usage: Dict[str, int]
|
||||||
223
demos/use_cases/rag_agent/src/rag_agent/content_builder_agent.py
Normal file
223
demos/use_cases/rag_agent/src/rag_agent/content_builder_agent.py
Normal file
|
|
@ -0,0 +1,223 @@
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from typing import List, Optional, Dict, Any
|
||||||
|
from fastapi import FastAPI, HTTPException
|
||||||
|
from openai import AsyncOpenAI
|
||||||
|
import os
|
||||||
|
import logging
|
||||||
|
import csv
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from .api import ChatMessage, ChatCompletionRequest, ChatCompletionResponse
|
||||||
|
|
||||||
|
|
||||||
|
# Set up logging
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# Configuration for archgw LLM gateway
|
||||||
|
LLM_GATEWAY_ENDPOINT = os.getenv("LLM_GATEWAY_ENDPOINT", "http://localhost:12000/v1")
|
||||||
|
RAG_MODEL = "gpt-4o-mini"
|
||||||
|
|
||||||
|
# Initialize OpenAI client for archgw
|
||||||
|
archgw_client = AsyncOpenAI(
|
||||||
|
base_url=LLM_GATEWAY_ENDPOINT,
|
||||||
|
api_key="EMPTY", # archgw doesn't require a real API key
|
||||||
|
)
|
||||||
|
|
||||||
|
# Global variable to store the knowledge base
|
||||||
|
knowledge_base = []
|
||||||
|
|
||||||
|
|
||||||
|
def load_knowledge_base():
    """Load ``basis_of_truth.csv`` into the module-level ``knowledge_base``.

    The CSV is read from the directory containing this module and must have
    ``path`` and ``content`` columns. Each row is stored as a
    ``{"path": ..., "content": ...}`` dict.

    Loading is best-effort: on any failure the error is logged and
    ``knowledge_base`` is left as an empty list so the service can still start.
    """
    global knowledge_base

    # The CSV lives next to this module.
    csv_path = Path(__file__).parent / "basis_of_truth.csv"

    try:
        # newline="" hands line-ending handling to the csv module; without it,
        # quoted fields that contain embedded newlines are parsed incorrectly.
        with open(csv_path, "r", encoding="utf-8", newline="") as file:
            documents = [
                {"path": row["path"], "content": row["content"]}
                for row in csv.DictReader(file)
            ]
    except Exception as e:
        logger.error(f"Error loading knowledge base: {e}")
        knowledge_base = []
        return

    # Rebind only after a complete, successful read.
    knowledge_base = documents
    logger.info(f"Loaded {len(knowledge_base)} documents from knowledge base")
|
||||||
|
|
||||||
|
|
||||||
|
async def find_relevant_passages(query: str, top_k: int = 3) -> List[Dict[str, str]]:
    """Ask the LLM which knowledge-base passages are most relevant to *query*.

    Every document in ``knowledge_base`` is listed in the prompt (content
    truncated to 500 characters) and the model is asked to reply with a
    comma-separated list of indices, or ``NONE``.

    Args:
        query: The user query to match passages against.
        top_k: Maximum number of passages to return.

    Returns:
        Up to ``top_k`` distinct knowledge-base entries, in the order the
        model listed them. An empty list is returned when the knowledge base
        is empty, ``top_k`` is not positive, the model answers ``NONE``, or
        the gateway call fails.
    """
    if not knowledge_base:
        logger.warning("Knowledge base is empty")
        return []

    if top_k <= 0:
        # Nothing can be selected; skip the LLM round-trip.
        return []

    prompt_header = f"""You are a retrieval assistant that selects the most relevant document passages for a given query.

Given a user query and a list of document passages, identify the {top_k} most relevant passages that would help answer the query.

Query: {query}

Available passages:
"""

    # One entry per document, tagged with its index so the model can refer to it.
    passages_listing = "".join(
        f"\n[{i}] Path: {doc['path']}\nContent: {doc['content'][:500]}...\n"
        for i, doc in enumerate(knowledge_base)
    )

    prompt_footer = f"""

Please respond with ONLY the indices of the {top_k} most relevant passages, separated by commas (e.g., "0,3,7").
If fewer than {top_k} passages are relevant, return only the relevant ones.
If no passages are relevant, return "NONE"."""

    system_prompt = prompt_header + passages_listing + prompt_footer

    try:
        # Call archgw to select relevant passages
        logger.info(f"Calling archgw to find relevant passages for query: '{query}'")
        response = await archgw_client.chat.completions.create(
            model=RAG_MODEL,
            messages=[{"role": "system", "content": system_prompt}],
            temperature=0.1,
            max_tokens=50,
        )
        # The message content may be None; treat that as an empty answer.
        result = (response.choices[0].message.content or "").strip()
    except Exception as e:
        logger.error(f"Error finding relevant passages: {e}")
        return []

    logger.info(f"LLM selected passages: {result}")

    if result.upper() == "NONE":
        return []

    selected_passages = []
    seen_indices = set()
    for token in result.split(","):
        # Models often echo the quoted/bracketed form of the example
        # ('"0,3,7"' or '[0, 3]'); peel that decoration off each token.
        token = token.strip(" \t\r\n[]()\"'.")
        # isdecimal() only admits characters that int() can convert.
        if not token.isdecimal():
            continue
        idx = int(token)
        if idx in seen_indices or not 0 <= idx < len(knowledge_base):
            continue
        seen_indices.add(idx)
        selected_passages.append(knowledge_base[idx])
        # Never return more than the caller asked for.
        if len(selected_passages) >= top_k:
            break

    logger.info(f"Selected {len(selected_passages)} relevant passages")
    return selected_passages
|
||||||
|
|
||||||
|
|
||||||
|
async def augment_query_with_context(messages: List[ChatMessage]) -> List[ChatMessage]:
    """Append retrieved knowledge-base context to the most recent user message.

    Returns the input list untouched when there is no (non-empty) user message
    or when no relevant passages are found; otherwise returns a shallow copy
    in which the last user message carries the original text plus the context.
    """
    # Scan from the end so the most recent user turn is picked.
    user_index, user_text = next(
        (
            (position, message.content)
            for position, message in reversed(list(enumerate(messages)))
            if message.role == "user"
        ),
        (-1, None),
    )

    if not user_text:
        logger.warning("No user message found in conversation")
        return messages

    logger.info(f"Processing user query: '{user_text}'")

    passages = await find_relevant_passages(user_text)
    if not passages:
        logger.info("No relevant passages found, returning original messages")
        return messages

    # One labelled section per passage, separated by blank lines.
    context = "\n\n".join(
        f"Document {number} ({passage['path']}):\n{passage['content']}"
        for number, passage in enumerate(passages, start=1)
    )

    augmented_content = f"""{user_text} RELEVANT CONTEXT:
{context}"""

    # Replace only the last user turn; the caller's list is left unmodified.
    result = list(messages)
    result[user_index] = ChatMessage(role="user", content=augmented_content)

    logger.info(f"Augmented user query with {len(passages)} relevant passages")
    return result
|
||||||
|
|
||||||
|
|
||||||
|
class Response(BaseModel):
    """A query string paired with a metadata dictionary."""

    # The query text.
    query: str
    # Arbitrary metadata associated with the query.
    metadata: dict
|
||||||
|
|
||||||
|
|
||||||
|
# FastAPI app for REST server
|
||||||
|
app = FastAPI(title="RAG Content Builder Agent", version="1.0.0")
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    """Augment the last user message with knowledge-base context.

    The reply's single choice carries the full (possibly augmented) message
    list under the ``"messages"`` key. Usage figures are whitespace-split
    word counts, not real token counts.
    """
    import time
    import uuid

    logger.info(
        f"Received chat completion request with {len(request.messages)} messages"
    )

    # Augment the user query with relevant context
    updated_messages = await augment_query_with_context(request.messages)

    prompt_words = sum(len(msg.content.split()) for msg in updated_messages)
    completion_words = len("Context added to user query.".split())

    choice = {
        "index": 0,
        "messages": [
            {"role": msg.role, "content": msg.content} for msg in updated_messages
        ],
        "finish_reason": "stop",
    }

    return ChatCompletionResponse(
        id=f"chatcmpl-{uuid.uuid4().hex[:8]}",
        created=int(time.time()),
        model=request.model,
        choices=[choice],
        usage={
            "prompt_tokens": prompt_words,
            "completion_tokens": completion_words,
            "total_tokens": prompt_words + completion_words,
        },
    )
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Load the knowledge base, then serve ``app`` with uvicorn on 0.0.0.0:8000."""
    # Populate the in-memory knowledge base before accepting requests.
    load_knowledge_base()

    from uvicorn import run as serve

    serve(app, host="0.0.0.0", port=8000)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|
@ -1,18 +0,0 @@
|
||||||
from pydantic import BaseModel
|
|
||||||
from . import mcp
|
|
||||||
|
|
||||||
|
|
||||||
class QueryRequest(BaseModel):
|
|
||||||
query: str
|
|
||||||
metadata: dict | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class QueryResponse(BaseModel):
|
|
||||||
query: str
|
|
||||||
results: list
|
|
||||||
|
|
||||||
|
|
||||||
@mcp.tool()
|
|
||||||
def query_rag_store(request: QueryRequest):
|
|
||||||
"""Query the RAG document store."""
|
|
||||||
return {"query": request.query, "results": []}
|
|
||||||
|
|
@ -1,101 +0,0 @@
|
||||||
from pydantic import BaseModel
|
|
||||||
from typing import List, Optional, Dict, Any
|
|
||||||
from fastapi import FastAPI, HTTPException
|
|
||||||
import uvicorn
|
|
||||||
|
|
||||||
|
|
||||||
# OpenAI Chat Completions API models
|
|
||||||
class ChatMessage(BaseModel):
|
|
||||||
role: str
|
|
||||||
content: str
|
|
||||||
|
|
||||||
|
|
||||||
class ChatCompletionRequest(BaseModel):
|
|
||||||
model: str
|
|
||||||
messages: List[ChatMessage]
|
|
||||||
temperature: Optional[float] = 1.0
|
|
||||||
max_tokens: Optional[int] = None
|
|
||||||
top_p: Optional[float] = 1.0
|
|
||||||
frequency_penalty: Optional[float] = 0.0
|
|
||||||
presence_penalty: Optional[float] = 0.0
|
|
||||||
stream: Optional[bool] = False
|
|
||||||
stop: Optional[List[str]] = None
|
|
||||||
|
|
||||||
|
|
||||||
class ChatCompletionResponse(BaseModel):
|
|
||||||
id: str
|
|
||||||
object: str = "chat.completion"
|
|
||||||
created: int
|
|
||||||
model: str
|
|
||||||
choices: List[Dict[str, Any]]
|
|
||||||
usage: Dict[str, int]
|
|
||||||
|
|
||||||
|
|
||||||
class Response(BaseModel):
|
|
||||||
query: str
|
|
||||||
metadata: dict
|
|
||||||
|
|
||||||
|
|
||||||
# FastAPI app for REST server
|
|
||||||
app = FastAPI(title="RAG Agent Query Parser", version="1.0.0")
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/v1/chat/completions")
|
|
||||||
async def chat_completions(request: ChatCompletionRequest):
|
|
||||||
"""Chat completions endpoint that passes through the request as-is."""
|
|
||||||
import time
|
|
||||||
import uuid
|
|
||||||
|
|
||||||
# Pass-through: return the last user message as the assistant response
|
|
||||||
last_user_message = ""
|
|
||||||
for message in reversed(request.messages):
|
|
||||||
if message.role == "user":
|
|
||||||
last_user_message = message.content
|
|
||||||
break
|
|
||||||
|
|
||||||
response = ChatCompletionResponse(
|
|
||||||
id=f"chatcmpl-{uuid.uuid4().hex[:8]}",
|
|
||||||
created=int(time.time()),
|
|
||||||
model=request.model,
|
|
||||||
choices=[
|
|
||||||
{
|
|
||||||
"index": 0,
|
|
||||||
"message": {"role": "assistant", "content": last_user_message},
|
|
||||||
"finish_reason": "stop",
|
|
||||||
}
|
|
||||||
],
|
|
||||||
usage={
|
|
||||||
"prompt_tokens": sum(len(msg.content.split()) for msg in request.messages),
|
|
||||||
"completion_tokens": len(last_user_message.split()),
|
|
||||||
"total_tokens": sum(len(msg.content.split()) for msg in request.messages)
|
|
||||||
+ len(last_user_message.split()),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
return response
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/health")
|
|
||||||
async def health_check():
|
|
||||||
"""Health check endpoint."""
|
|
||||||
return {"status": "healthy"}
|
|
||||||
|
|
||||||
|
|
||||||
def parse_query(query):
|
|
||||||
"""Parse the user query and returns metadata extracted from query."""
|
|
||||||
return Response(query=query, metadata={"is_valid": True})
|
|
||||||
|
|
||||||
|
|
||||||
# Register MCP tool only if mcp is available
|
|
||||||
try:
|
|
||||||
from . import mcp
|
|
||||||
|
|
||||||
if mcp is not None:
|
|
||||||
mcp.tool()(parse_query)
|
|
||||||
except (ImportError, AttributeError):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def start_server(host: str = "localhost", port: int = 8000):
|
|
||||||
"""Start the REST server."""
|
|
||||||
uvicorn.run(app, host=host, port=port)
|
|
||||||
140
demos/use_cases/rag_agent/src/rag_agent/query_rewriter_agent.py
Normal file
140
demos/use_cases/rag_agent/src/rag_agent/query_rewriter_agent.py
Normal file
|
|
@ -0,0 +1,140 @@
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from typing import List, Optional, Dict, Any
|
||||||
|
from fastapi import FastAPI, HTTPException
|
||||||
|
from openai import AsyncOpenAI
|
||||||
|
import os
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from .api import ChatMessage, ChatCompletionRequest, ChatCompletionResponse
|
||||||
|
|
||||||
|
|
||||||
|
# Set up logging
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# Configuration for archgw LLM gateway
|
||||||
|
LLM_GATEWAY_ENDPOINT = os.getenv("LLM_GATEWAY_ENDPOINT", "http://localhost:12000/v1")
|
||||||
|
QUERY_REWRITE_MODEL = "gpt-4o-mini"
|
||||||
|
|
||||||
|
# Initialize OpenAI client for archgw
|
||||||
|
archgw_client = AsyncOpenAI(
|
||||||
|
base_url=LLM_GATEWAY_ENDPOINT,
|
||||||
|
api_key="EMPTY", # archgw doesn't require a real API key
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def rewrite_query_with_archgw(messages: List[ChatMessage]) -> str:
    """Rewrite the last user message into a retrieval-friendly query via archgw.

    The whole conversation is sent to the gateway behind a system prompt that
    asks for a rewritten version of the last user message.

    Args:
        messages: Conversation history, oldest first.

    Returns:
        The rewritten query. If the gateway call fails or yields an empty
        rewrite, the content of the last user message is returned instead
        (or ``""`` when the conversation has no user message), so a failed
        rewrite never blanks out the user's query.
    """
    # Prepare the system prompt for query rewriting
    system_prompt = """You are a query rewriter that improves user queries for better retrieval.

Given a conversation history, rewrite the last user message to be more specific and context-aware.
The rewritten query should:
1. Include relevant context from previous messages
2. Be clear and specific for information retrieval
3. Maintain the user's intent
4. Be concise but comprehensive

Return only the rewritten query, nothing else."""

    # System prompt first, followed by the conversation history as-is.
    rewrite_messages = [{"role": "system", "content": system_prompt}]
    rewrite_messages.extend(
        {"role": msg.role, "content": msg.content} for msg in messages
    )

    try:
        # Call archgw using OpenAI client
        logger.info(f"Calling archgw at {LLM_GATEWAY_ENDPOINT} to rewrite query")
        response = await archgw_client.chat.completions.create(
            model=QUERY_REWRITE_MODEL,
            messages=rewrite_messages,
            temperature=0.3,
            max_tokens=200,
        )
        # The message content may be None; treat that as "no rewrite".
        rewritten_query = (response.choices[0].message.content or "").strip()
    except Exception as e:
        logger.error(f"Error rewriting query: {e}")
        rewritten_query = ""

    if rewritten_query:
        logger.info(f"Query rewritten successfully: '{rewritten_query}'")
        return rewritten_query

    # No usable rewrite: fall back to the original last user message.
    logger.info("Falling back to original user message")
    return next(
        (message.content for message in reversed(messages) if message.role == "user"),
        "",
    )
|
||||||
|
|
||||||
|
|
||||||
|
class Response(BaseModel):
    """A query string paired with a metadata dictionary."""

    # The query text.
    query: str
    # Arbitrary metadata associated with the query.
    metadata: dict
|
||||||
|
|
||||||
|
|
||||||
|
# FastAPI app for REST server
|
||||||
|
app = FastAPI(title="RAG Agent Query Parser", version="1.0.0")
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    """Replace the last user message with its archgw-rewritten form.

    The reply's single choice carries the full updated message list under the
    ``"messages"`` key. Usage figures are whitespace-split word counts, not
    real token counts.
    """
    import time
    import uuid

    logger.info(
        f"Received chat completion request with {len(request.messages)} messages"
    )

    # Call archgw to rewrite the last user query
    rewritten_query = await rewrite_query_with_archgw(request.messages)

    # Work on a copy so the request object keeps its original messages.
    updated_messages = list(request.messages)

    # Scan from the end; only the most recent user turn is replaced.
    for position in reversed(range(len(updated_messages))):
        if updated_messages[position].role != "user":
            continue
        original_query = updated_messages[position].content
        updated_messages[position] = ChatMessage(role="user", content=rewritten_query)
        logger.info(
            f"Updated user query from '{original_query}' to '{rewritten_query}'"
        )
        break

    prompt_words = sum(len(msg.content.split()) for msg in updated_messages)
    completion_words = len("Updated query for better retrieval.".split())

    choice = {
        "index": 0,
        "messages": [
            {"role": msg.role, "content": msg.content} for msg in updated_messages
        ],
        "finish_reason": "stop",
    }

    return ChatCompletionResponse(
        id=f"chatcmpl-{uuid.uuid4().hex[:8]}",
        created=int(time.time()),
        model=request.model,
        choices=[choice],
        usage={
            "prompt_tokens": prompt_words,
            "completion_tokens": completion_words,
            "total_tokens": prompt_words + completion_words,
        },
    )
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/health")
async def health_check():
    """Report that the service is up."""
    status_payload = {"status": "healthy"}
    return status_payload
|
||||||
|
|
||||||
|
|
||||||
|
def parse_query(query):
    """Wrap *query* in a ``Response`` whose metadata marks it as valid."""
    extracted_metadata = {"is_valid": True}
    return Response(query=query, metadata=extracted_metadata)
|
||||||
|
|
@ -1,11 +0,0 @@
|
||||||
from . import mcp
|
|
||||||
|
|
||||||
|
|
||||||
@mcp.tool()
|
|
||||||
def generate_response(query, context):
|
|
||||||
"""Generate a response based on the user query and context."""
|
|
||||||
return {
|
|
||||||
"query": query,
|
|
||||||
"context": context,
|
|
||||||
"response": "This is a generated response.",
|
|
||||||
}
|
|
||||||
|
|
@ -0,0 +1,122 @@
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from openai import AsyncOpenAI
|
||||||
|
import os
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from .api import ChatCompletionRequest, ChatCompletionResponse
|
||||||
|
|
||||||
|
# Set up logging
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Configuration for archgw LLM gateway
|
||||||
|
LLM_GATEWAY_ENDPOINT = os.getenv("LLM_GATEWAY_ENDPOINT", "http://localhost:12000/v1")
|
||||||
|
RESPONSE_MODEL = "gpt-4o"
|
||||||
|
|
||||||
|
# Initialize OpenAI client for archgw
|
||||||
|
archgw_client = AsyncOpenAI(
|
||||||
|
base_url=LLM_GATEWAY_ENDPOINT,
|
||||||
|
api_key="EMPTY", # archgw doesn't require a real API key
|
||||||
|
)
|
||||||
|
|
||||||
|
# FastAPI app for REST server
|
||||||
|
app = FastAPI(title="RAG Agent Response Generator", version="1.0.0")
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    """Generate the assistant reply for a conversation through archgw.

    The conversation is forwarded to the gateway behind a fixed system prompt.
    If the gateway call fails, a canned apology is returned instead of an
    error. Usage figures are whitespace-split word counts, not real token
    counts.

    Args:
        request: Incoming chat-completion request; ``temperature`` and
            ``max_tokens`` are honoured when provided.

    Returns:
        A ``ChatCompletionResponse`` with a single assistant message.
    """
    logger.info(
        f"Received chat completion request with {len(request.messages)} messages"
    )

    # Prepare the system prompt for response generation
    system_prompt = """You are a helpful assistant that generates coherent, contextual responses.

Given a conversation history, generate a helpful and relevant response based on all the context available in the messages.
Your response should:
1. Be contextually aware of the entire conversation
2. Address the user's needs appropriately
3. Be helpful and informative
4. Maintain a natural conversational tone

Generate a complete response to assist the user."""

    # System prompt first, followed by the conversation history as-is.
    response_messages = [{"role": "system", "content": system_prompt}]
    response_messages.extend(
        {"role": msg.role, "content": msg.content} for msg in request.messages
    )

    # An explicit temperature of 0.0 is a valid request and must be kept;
    # `request.temperature or 0.7` would silently replace it with 0.7.
    temperature = 0.7 if request.temperature is None else request.temperature
    # A max_tokens of 0 is not a usable limit, so `or` is the right test here.
    max_tokens = request.max_tokens or 1000

    prompt_words = sum(len(msg.content.split()) for msg in request.messages)

    def build_response(text):
        """Wrap *text* as a single-choice assistant completion."""
        completion_words = len(text.split())
        return ChatCompletionResponse(
            id=f"chatcmpl-{uuid.uuid4().hex[:8]}",
            created=int(time.time()),
            model=request.model,
            choices=[
                {
                    "index": 0,
                    "message": {"role": "assistant", "content": text},
                    "finish_reason": "stop",
                }
            ],
            usage={
                "prompt_tokens": prompt_words,
                "completion_tokens": completion_words,
                "total_tokens": prompt_words + completion_words,
            },
        )

    try:
        # Call archgw using OpenAI client
        logger.info(f"Calling archgw at {LLM_GATEWAY_ENDPOINT} to generate response")
        response = await archgw_client.chat.completions.create(
            model=RESPONSE_MODEL,
            messages=response_messages,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        # The message content may be None; treat that as an empty reply.
        generated_response = (response.choices[0].message.content or "").strip()
    except Exception as e:
        logger.error(f"Error generating response: {e}")
        # Fallback response
        return build_response(
            "I apologize, but I'm having trouble generating a response right now. Please try again."
        )

    logger.info("Response generated successfully")
    return build_response(generated_response)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/health")
async def health_check():
    """Report that the service is up."""
    status_payload = {"status": "healthy"}
    return status_payload
|
||||||
Loading…
Add table
Add a link
Reference in a new issue