mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
changes to the agents
This commit is contained in:
parent
c1e142f55f
commit
32838584cf
10 changed files with 544 additions and 135 deletions
|
|
@ -0,0 +1,28 @@
|
|||
# RAG Agent Query Parser
|
||||
|
||||
A FastAPI service that rewrites user queries using archgw and gpt-4o-mini for better retrieval accuracy.
|
||||
|
||||
## How it Works
|
||||
|
||||
1. Receives a chat completion request with conversation history
|
||||
2. Calls archgw's LLM gateway with gpt-4o-mini to rewrite the last user query
|
||||
3. Returns the conversation with the last user message replaced by the rewritten query
|
||||
|
||||
## Setup and Running
|
||||
|
||||
1. **Start archgw**:
|
||||
```bash
|
||||
archgw up --foreground
|
||||
```
|
||||
|
||||
2. **Start the query parser service**:
|
||||
```bash
|
||||
uv run python -m rag_agent.query_parser
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
```bash
|
||||
# archgw LLM Gateway base URL (default: http://localhost:12000/v1)
|
||||
export LLM_GATEWAY_ENDPOINT="http://localhost:12000/v1"
|
||||
```
|
||||
|
|
@ -24,11 +24,6 @@ listeners:
|
|||
- query_rewriter
|
||||
- context_builder
|
||||
- response_generator
|
||||
- name: research_agent
|
||||
description: agent to research and gather information from various sources.
|
||||
filter_chain:
|
||||
- research_agent
|
||||
- response_generator
|
||||
port: 8001
|
||||
|
||||
- name: egress_traffic
|
||||
|
|
@ -38,3 +33,5 @@ listeners:
|
|||
llm_providers:
|
||||
- access_key: ${OPENAI_API_KEY}
|
||||
model: openai/gpt-4o
|
||||
- access_key: ${OPENAI_API_KEY}
|
||||
model: openai/gpt-4o-mini
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ dependencies = [
|
|||
"pydantic>=2.11.7",
|
||||
"fastapi>=0.104.1",
|
||||
"uvicorn>=0.24.0",
|
||||
"openai>=1.0.0",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
|
|
|||
28
demos/use_cases/rag_agent/src/rag_agent/api.py
Normal file
28
demos/use_cases/rag_agent/src/rag_agent/api.py
Normal file
|
|
@ -0,0 +1,28 @@
|
|||
from pydantic import BaseModel
|
||||
from typing import List, Optional, Dict, Any
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
    """A single chat turn: who spoke (``role``) and what was said (``content``)."""

    # Speaker of the turn; the agents in this package compare it against
    # "user" and emit "system" / "assistant" themselves.
    role: str
    # Plain-text body of the turn.
    content: str
|
||||
|
||||
|
||||
class ChatCompletionRequest(BaseModel):
    """Request body accepted by the agents' ``/v1/chat/completions`` endpoints.

    Only ``model`` and ``messages`` are required; every other field is
    optional and carries the default shown below.
    """

    # Model name requested by the caller; echoed back in the response.
    model: str
    # Conversation history, oldest turn first.
    messages: List[ChatMessage]
    # Sampling parameters; each may also be sent as an explicit null.
    temperature: Optional[float] = 1.0
    max_tokens: Optional[int] = None
    top_p: Optional[float] = 1.0
    frequency_penalty: Optional[float] = 0.0
    presence_penalty: Optional[float] = 0.0
    # Streaming flag and stop sequences are accepted by the schema.
    stream: Optional[bool] = False
    stop: Optional[List[str]] = None
|
||||
|
||||
|
||||
class ChatCompletionResponse(BaseModel):
    """Response body returned by the agents' ``/v1/chat/completions`` endpoints."""

    # Identifier of this completion.
    id: str
    # Object type tag; fixed to "chat.completion" unless overridden.
    object: str = "chat.completion"
    # Creation time as an integer timestamp.
    created: int
    # Model name reported back to the caller.
    model: str
    # Free-form choice dictionaries; their keys are not validated here.
    choices: List[Dict[str, Any]]
    # Integer counters keyed by name (e.g. prompt/completion/total tokens).
    usage: Dict[str, int]
|
||||
223
demos/use_cases/rag_agent/src/rag_agent/content_builder_agent.py
Normal file
223
demos/use_cases/rag_agent/src/rag_agent/content_builder_agent.py
Normal file
|
|
@ -0,0 +1,223 @@
|
|||
from pydantic import BaseModel
|
||||
from typing import List, Optional, Dict, Any
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from openai import AsyncOpenAI
|
||||
import os
|
||||
import logging
|
||||
import csv
|
||||
from pathlib import Path
|
||||
|
||||
from .api import ChatMessage, ChatCompletionRequest, ChatCompletionResponse
|
||||
|
||||
|
||||
# Set up logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Configuration for archgw LLM gateway
|
||||
LLM_GATEWAY_ENDPOINT = os.getenv("LLM_GATEWAY_ENDPOINT", "http://localhost:12000/v1")
|
||||
RAG_MODEL = "gpt-4o-mini"
|
||||
|
||||
# Initialize OpenAI client for archgw
|
||||
archgw_client = AsyncOpenAI(
|
||||
base_url=LLM_GATEWAY_ENDPOINT,
|
||||
api_key="EMPTY", # archgw doesn't require a real API key
|
||||
)
|
||||
|
||||
# Global variable to store the knowledge base
|
||||
knowledge_base = []
|
||||
|
||||
|
||||
def load_knowledge_base():
    """Load ``basis_of_truth.csv`` (next to this module) into ``knowledge_base``.

    The CSV must have ``path`` and ``content`` columns. Every complete row
    becomes a ``{"path": ..., "content": ...}`` dict. Rows that are too short
    to carry both values are skipped, because later code slices and formats
    both fields as strings and a ``None`` would break every retrieval call.

    The module-level ``knowledge_base`` is rebound once, after the whole file
    has been read. Any failure (missing file, decoding error, malformed CSV,
    missing columns) is logged with its traceback and leaves
    ``knowledge_base`` empty instead of raising.
    """
    global knowledge_base

    csv_path = Path(__file__).parent / "basis_of_truth.csv"
    documents = []

    try:
        # newline="" is what the csv module requires so that line breaks
        # embedded in quoted fields reach the parser untouched.
        with open(csv_path, "r", encoding="utf-8", newline="") as file:
            csv_reader = csv.DictReader(file)

            missing_columns = {"path", "content"} - set(csv_reader.fieldnames or ())
            if missing_columns:
                raise KeyError(f"missing column(s): {sorted(missing_columns)}")

            for row in csv_reader:
                path, content = row["path"], row["content"]
                # DictReader fills absent trailing fields with None.
                if path is None or content is None:
                    logger.warning(
                        f"Skipping incomplete knowledge base row at line {csv_reader.line_num}"
                    )
                    continue
                documents.append({"path": path, "content": content})
    except Exception:
        # Deliberate best-effort boundary: the service still starts, just
        # without any documents to retrieve from.
        logger.exception("Error loading knowledge base")
        knowledge_base = []
        return

    knowledge_base = documents
    logger.info(f"Loaded {len(knowledge_base)} documents from knowledge base")
|
||||
|
||||
|
||||
def _parse_passage_indices(raw: str, limit: int, top_k: int) -> List[int]:
    """Turn the model's comma-separated answer into at most ``top_k`` valid, unique indices."""
    picked: List[int] = []
    for token in raw.split(","):
        # Tolerate light decoration around a number, e.g. "[0", "3]" or "7.".
        candidate = token.strip().strip("[]()\"'.")
        # isdecimal() only accepts characters that int() can parse; isdigit()
        # also accepts e.g. superscripts, which would raise ValueError.
        if not candidate.isdecimal():
            continue
        index = int(candidate)
        if 0 <= index < limit and index not in picked:
            picked.append(index)
    return picked[: max(top_k, 0)]


async def find_relevant_passages(query: str, top_k: int = 3) -> List[Dict[str, str]]:
    """Ask the LLM which knowledge-base passages best match ``query``.

    The prompt lists every document in ``knowledge_base`` (content truncated
    to 500 characters) with its index and asks the model to answer with a
    comma-separated list of indices, or ``NONE``.

    Args:
        query: The user question to retrieve context for.
        top_k: Maximum number of passages to return.

    Returns:
        Up to ``top_k`` distinct documents from ``knowledge_base``, in the
        order the model listed them. An empty list is returned when the
        knowledge base is empty, the model answers ``NONE`` (or nothing
        usable), or the gateway call fails.
    """
    if not knowledge_base:
        logger.warning("Knowledge base is empty")
        return []

    prompt_header = f"""You are a retrieval assistant that selects the most relevant document passages for a given query.

Given a user query and a list of document passages, identify the {top_k} most relevant passages that would help answer the query.

Query: {query}

Available passages:
"""

    # One join instead of repeated string concatenation in a loop.
    passage_listing = "".join(
        f"\n[{i}] Path: {doc['path']}\nContent: {doc['content'][:500]}...\n"
        for i, doc in enumerate(knowledge_base)
    )

    prompt_footer = f"""

Please respond with ONLY the indices of the {top_k} most relevant passages, separated by commas (e.g., "0,3,7").
If fewer than {top_k} passages are relevant, return only the relevant ones.
If no passages are relevant, return "NONE"."""

    system_prompt = prompt_header + passage_listing + prompt_footer

    try:
        # Call archgw to select relevant passages
        logger.info(f"Calling archgw to find relevant passages for query: '{query}'")
        response = await archgw_client.chat.completions.create(
            model=RAG_MODEL,
            messages=[{"role": "system", "content": system_prompt}],
            temperature=0.1,
            max_tokens=50,
        )

        # The message content may be None; treat that like an empty answer.
        result = (response.choices[0].message.content or "").strip()
        logger.info(f"LLM selected passages: {result}")

        if result.upper() == "NONE":
            return []

        # The model is only *asked* for top_k indices, so enforce the cap,
        # bounds and uniqueness here rather than trusting the answer.
        indices = _parse_passage_indices(result, len(knowledge_base), top_k)
        selected_passages = [knowledge_base[idx] for idx in indices]

        logger.info(f"Selected {len(selected_passages)} relevant passages")
        return selected_passages

    except Exception as e:
        # Retrieval is best-effort: callers continue without added context.
        logger.exception(f"Error finding relevant passages: {e}")
        return []
|
||||
|
||||
|
||||
async def augment_query_with_context(messages: List[ChatMessage]) -> List[ChatMessage]:
    """Return ``messages`` with retrieved context appended to the last user turn.

    The most recent message whose role is ``"user"`` is used as the query
    for ``find_relevant_passages``. When passages come back, that message is
    replaced (in a shallow copy of the list) by one whose content is the
    original text followed by ``RELEVANT CONTEXT:`` and the passages.

    The input list itself is returned unchanged when there is no user
    message, the last user message is empty, or no passages are found.
    """
    # Position of the most recent user turn, if any.
    user_positions = [pos for pos, msg in enumerate(messages) if msg.role == "user"]
    target_index = user_positions[-1] if user_positions else -1
    question = messages[target_index].content if user_positions else None

    if not question:
        logger.warning("No user message found in conversation")
        return messages

    logger.info(f"Processing user query: '{question}'")

    passages = await find_relevant_passages(question)
    if not passages:
        logger.info("No relevant passages found, returning original messages")
        return messages

    # Number the passages from 1 and separate them with blank lines.
    context = "\n\n".join(
        f"Document {number} ({passage['path']}):\n{passage['content']}"
        for number, passage in enumerate(passages, start=1)
    )
    augmented_content = f"{question} RELEVANT CONTEXT:\n{context}"

    # Shallow copy: only the targeted slot is swapped, the caller's list is untouched.
    updated_messages = list(messages)
    updated_messages[target_index] = ChatMessage(role="user", content=augmented_content)

    logger.info(f"Augmented user query with {len(passages)} relevant passages")
    return updated_messages
|
||||
|
||||
|
||||
class Response(BaseModel):
    """A query string paired with a free-form metadata dictionary."""

    # NOTE(review): not referenced by the rest of the visible module — confirm
    # whether it is still needed.
    query: str
    metadata: dict
|
||||
|
||||
|
||||
# FastAPI app for REST server
|
||||
app = FastAPI(title="RAG Content Builder Agent", version="1.0.0")
|
||||
|
||||
|
||||
@app.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    """Augment the last user message with knowledge-base context and return the messages.

    The single returned choice carries the full (possibly augmented) message
    list under the ``messages`` key. Usage numbers are whitespace-split word
    counts, not tokenizer counts.
    """
    import time
    import uuid

    logger.info(
        f"Received chat completion request with {len(request.messages)} messages"
    )

    # Augment the user query with relevant context
    updated_messages = await augment_query_with_context(request.messages)

    prompt_tokens = sum(len(msg.content.split()) for msg in updated_messages)
    completion_tokens = len("Context added to user query.".split())

    serialized_messages = [
        {"role": msg.role, "content": msg.content} for msg in updated_messages
    ]
    choice = {
        "index": 0,
        "messages": serialized_messages,
        "finish_reason": "stop",
    }

    return ChatCompletionResponse(
        id=f"chatcmpl-{uuid.uuid4().hex[:8]}",
        created=int(time.time()),
        model=request.model,
        choices=[choice],
        usage={
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        },
    )
|
||||
|
||||
|
||||
def main():
    """Load the knowledge base, then serve ``app`` with uvicorn on 0.0.0.0:8000."""
    load_knowledge_base()

    import uvicorn

    server_options = {"host": "0.0.0.0", "port": 8000}
    uvicorn.run(app, **server_options)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -1,18 +0,0 @@
|
|||
from pydantic import BaseModel
|
||||
from . import mcp
|
||||
|
||||
|
||||
class QueryRequest(BaseModel):
|
||||
query: str
|
||||
metadata: dict | None = None
|
||||
|
||||
|
||||
class QueryResponse(BaseModel):
|
||||
query: str
|
||||
results: list
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
def query_rag_store(request: QueryRequest):
|
||||
"""Query the RAG document store."""
|
||||
return {"query": request.query, "results": []}
|
||||
|
|
@ -1,101 +0,0 @@
|
|||
from pydantic import BaseModel
|
||||
from typing import List, Optional, Dict, Any
|
||||
from fastapi import FastAPI, HTTPException
|
||||
import uvicorn
|
||||
|
||||
|
||||
# OpenAI Chat Completions API models
|
||||
class ChatMessage(BaseModel):
|
||||
role: str
|
||||
content: str
|
||||
|
||||
|
||||
class ChatCompletionRequest(BaseModel):
|
||||
model: str
|
||||
messages: List[ChatMessage]
|
||||
temperature: Optional[float] = 1.0
|
||||
max_tokens: Optional[int] = None
|
||||
top_p: Optional[float] = 1.0
|
||||
frequency_penalty: Optional[float] = 0.0
|
||||
presence_penalty: Optional[float] = 0.0
|
||||
stream: Optional[bool] = False
|
||||
stop: Optional[List[str]] = None
|
||||
|
||||
|
||||
class ChatCompletionResponse(BaseModel):
|
||||
id: str
|
||||
object: str = "chat.completion"
|
||||
created: int
|
||||
model: str
|
||||
choices: List[Dict[str, Any]]
|
||||
usage: Dict[str, int]
|
||||
|
||||
|
||||
class Response(BaseModel):
|
||||
query: str
|
||||
metadata: dict
|
||||
|
||||
|
||||
# FastAPI app for REST server
|
||||
app = FastAPI(title="RAG Agent Query Parser", version="1.0.0")
|
||||
|
||||
|
||||
@app.post("/v1/chat/completions")
|
||||
async def chat_completions(request: ChatCompletionRequest):
|
||||
"""Chat completions endpoint that passes through the request as-is."""
|
||||
import time
|
||||
import uuid
|
||||
|
||||
# Pass-through: return the last user message as the assistant response
|
||||
last_user_message = ""
|
||||
for message in reversed(request.messages):
|
||||
if message.role == "user":
|
||||
last_user_message = message.content
|
||||
break
|
||||
|
||||
response = ChatCompletionResponse(
|
||||
id=f"chatcmpl-{uuid.uuid4().hex[:8]}",
|
||||
created=int(time.time()),
|
||||
model=request.model,
|
||||
choices=[
|
||||
{
|
||||
"index": 0,
|
||||
"message": {"role": "assistant", "content": last_user_message},
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
],
|
||||
usage={
|
||||
"prompt_tokens": sum(len(msg.content.split()) for msg in request.messages),
|
||||
"completion_tokens": len(last_user_message.split()),
|
||||
"total_tokens": sum(len(msg.content.split()) for msg in request.messages)
|
||||
+ len(last_user_message.split()),
|
||||
},
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
"""Health check endpoint."""
|
||||
return {"status": "healthy"}
|
||||
|
||||
|
||||
def parse_query(query):
|
||||
"""Parse the user query and returns metadata extracted from query."""
|
||||
return Response(query=query, metadata={"is_valid": True})
|
||||
|
||||
|
||||
# Register MCP tool only if mcp is available
|
||||
try:
|
||||
from . import mcp
|
||||
|
||||
if mcp is not None:
|
||||
mcp.tool()(parse_query)
|
||||
except (ImportError, AttributeError):
|
||||
pass
|
||||
|
||||
|
||||
def start_server(host: str = "localhost", port: int = 8000):
|
||||
"""Start the REST server."""
|
||||
uvicorn.run(app, host=host, port=port)
|
||||
140
demos/use_cases/rag_agent/src/rag_agent/query_rewriter_agent.py
Normal file
140
demos/use_cases/rag_agent/src/rag_agent/query_rewriter_agent.py
Normal file
|
|
@ -0,0 +1,140 @@
|
|||
from pydantic import BaseModel
|
||||
from typing import List, Optional, Dict, Any
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from openai import AsyncOpenAI
|
||||
import os
|
||||
import logging
|
||||
|
||||
from .api import ChatMessage, ChatCompletionRequest, ChatCompletionResponse
|
||||
|
||||
|
||||
# Set up logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Configuration for archgw LLM gateway
|
||||
LLM_GATEWAY_ENDPOINT = os.getenv("LLM_GATEWAY_ENDPOINT", "http://localhost:12000/v1")
|
||||
QUERY_REWRITE_MODEL = "gpt-4o-mini"
|
||||
|
||||
# Initialize OpenAI client for archgw
|
||||
archgw_client = AsyncOpenAI(
|
||||
base_url=LLM_GATEWAY_ENDPOINT,
|
||||
api_key="EMPTY", # archgw doesn't require a real API key
|
||||
)
|
||||
|
||||
|
||||
async def rewrite_query_with_archgw(messages: List[ChatMessage]) -> str:
    """Rewrite the last user message into a retrieval-friendly query via archgw.

    The whole conversation is sent to the gateway behind a query-rewriting
    system prompt.

    Args:
        messages: Conversation history, oldest turn first.

    Returns:
        The rewritten query. If the gateway call fails, or the model returns
        no text, the content of the last user message is returned instead
        (or ``""`` when the conversation has no user message), so a failed
        rewrite never blanks out the user's query.
    """
    system_prompt = """You are a query rewriter that improves user queries for better retrieval.

Given a conversation history, rewrite the last user message to be more specific and context-aware.
The rewritten query should:
1. Include relevant context from previous messages
2. Be clear and specific for information retrieval
3. Maintain the user's intent
4. Be concise but comprehensive

Return only the rewritten query, nothing else."""

    # System prompt first, followed by the conversation history as-is.
    rewrite_messages = [{"role": "system", "content": system_prompt}]
    rewrite_messages.extend(
        {"role": msg.role, "content": msg.content} for msg in messages
    )

    # Resolved up front because both failure paths below fall back to it.
    original_query = next(
        (msg.content for msg in reversed(messages) if msg.role == "user"), ""
    )

    try:
        # Call archgw using OpenAI client
        logger.info(f"Calling archgw at {LLM_GATEWAY_ENDPOINT} to rewrite query")
        response = await archgw_client.chat.completions.create(
            model=QUERY_REWRITE_MODEL,
            messages=rewrite_messages,
            temperature=0.3,
            max_tokens=200,
        )
        # The message content may be None; treat that like an empty answer.
        rewritten_query = (response.choices[0].message.content or "").strip()
    except Exception as e:
        logger.error(f"Error rewriting query: {e}")
        logger.info("Falling back to original user message")
        return original_query

    if not rewritten_query:
        # An empty rewrite would erase the user's question downstream.
        logger.warning("Query rewriter returned no text")
        logger.info("Falling back to original user message")
        return original_query

    logger.info(f"Query rewritten successfully: '{rewritten_query}'")
    return rewritten_query
|
||||
|
||||
|
||||
class Response(BaseModel):
    """A query string paired with a free-form metadata dictionary."""

    # The query text the metadata refers to.
    query: str
    # Arbitrary key/value metadata; the schema does not constrain its keys.
    metadata: dict
|
||||
|
||||
|
||||
# FastAPI app for REST server
|
||||
app = FastAPI(title="RAG Agent Query Parser", version="1.0.0")
|
||||
|
||||
|
||||
@app.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    """Rewrite the last user query via archgw and return the updated messages.

    The single returned choice carries the full message list under the
    ``messages`` key, with the most recent user message replaced by the
    rewritten query. Usage numbers are whitespace-split word counts, not
    tokenizer counts.
    """
    import time
    import uuid

    logger.info(
        f"Received chat completion request with {len(request.messages)} messages"
    )

    # Call archgw to rewrite the last user query
    rewritten_query = await rewrite_query_with_archgw(request.messages)

    # Work on a shallow copy so the request's own list is left untouched.
    updated_messages = list(request.messages)

    # Scan from the newest turn backwards and swap only the first user turn found.
    for position in reversed(range(len(updated_messages))):
        current = updated_messages[position]
        if current.role != "user":
            continue
        updated_messages[position] = ChatMessage(role="user", content=rewritten_query)
        logger.info(
            f"Updated user query from '{current.content}' to '{rewritten_query}'"
        )
        break

    prompt_tokens = sum(len(msg.content.split()) for msg in updated_messages)
    completion_tokens = len("Updated query for better retrieval.".split())

    choice = {
        "index": 0,
        "messages": [
            {"role": msg.role, "content": msg.content} for msg in updated_messages
        ],
        "finish_reason": "stop",
    }

    return ChatCompletionResponse(
        id=f"chatcmpl-{uuid.uuid4().hex[:8]}",
        created=int(time.time()),
        model=request.model,
        choices=[choice],
        usage={
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        },
    )
|
||||
|
||||
|
||||
@app.get("/health")
async def health_check():
    """Answer health probes with a fixed ``{"status": "healthy"}`` payload."""
    payload = {"status": "healthy"}
    return payload
|
||||
|
||||
|
||||
def parse_query(query):
    """Wrap ``query`` in a ``Response`` whose metadata marks it as valid.

    No parsing or validation happens here: the metadata is always
    ``{"is_valid": True}``.
    """
    metadata = {"is_valid": True}
    return Response(query=query, metadata=metadata)
|
||||
|
|
@ -1,11 +0,0 @@
|
|||
from . import mcp
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
def generate_response(query, context):
|
||||
"""Generate a response based on the user query and context."""
|
||||
return {
|
||||
"query": query,
|
||||
"context": context,
|
||||
"response": "This is a generated response.",
|
||||
}
|
||||
|
|
@ -0,0 +1,122 @@
|
|||
from fastapi import FastAPI
|
||||
from openai import AsyncOpenAI
|
||||
import os
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
|
||||
from .api import ChatCompletionRequest, ChatCompletionResponse
|
||||
|
||||
# Set up logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Configuration for archgw LLM gateway
|
||||
LLM_GATEWAY_ENDPOINT = os.getenv("LLM_GATEWAY_ENDPOINT", "http://localhost:12000/v1")
|
||||
RESPONSE_MODEL = "gpt-4o"
|
||||
|
||||
# Initialize OpenAI client for archgw
|
||||
archgw_client = AsyncOpenAI(
|
||||
base_url=LLM_GATEWAY_ENDPOINT,
|
||||
api_key="EMPTY", # archgw doesn't require a real API key
|
||||
)
|
||||
|
||||
# FastAPI app for REST server
|
||||
app = FastAPI(title="RAG Agent Response Generator", version="1.0.0")
|
||||
|
||||
|
||||
def _build_completion(request: ChatCompletionRequest, content: str) -> ChatCompletionResponse:
    """Wrap ``content`` as a single-choice assistant completion for ``request``."""
    # Usage numbers are whitespace-split word counts, not tokenizer counts.
    prompt_tokens = sum(len(msg.content.split()) for msg in request.messages)
    completion_tokens = len(content.split())

    return ChatCompletionResponse(
        id=f"chatcmpl-{uuid.uuid4().hex[:8]}",
        created=int(time.time()),
        model=request.model,
        choices=[
            {
                "index": 0,
                "message": {"role": "assistant", "content": content},
                "finish_reason": "stop",
            }
        ],
        usage={
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        },
    )


@app.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    """Generate the assistant reply for the conversation via archgw.

    The request's messages are sent to the gateway behind a response
    generation system prompt, using ``RESPONSE_MODEL``. The caller's
    ``temperature`` and ``max_tokens`` are forwarded when provided.

    Returns:
        A ``ChatCompletionResponse`` with one assistant message. If the
        gateway call fails for any reason, the message is a fixed apology
        instead of an error status.
    """
    logger.info(
        f"Received chat completion request with {len(request.messages)} messages"
    )

    system_prompt = """You are a helpful assistant that generates coherent, contextual responses.

Given a conversation history, generate a helpful and relevant response based on all the context available in the messages.
Your response should:
1. Be contextually aware of the entire conversation
2. Address the user's needs appropriately
3. Be helpful and informative
4. Maintain a natural conversational tone

Generate a complete response to assist the user."""

    # System prompt first, followed by the conversation history as-is.
    response_messages = [{"role": "system", "content": system_prompt}]
    response_messages.extend(
        {"role": msg.role, "content": msg.content} for msg in request.messages
    )

    # Compare against None explicitly: 0.0 is a legitimate temperature and
    # must not be replaced by the 0.7 default the way `or` would.
    temperature = 0.7 if request.temperature is None else request.temperature
    # A zero token budget is not a usable limit, so `or` is fine here.
    max_tokens = request.max_tokens or 1000

    try:
        # Call archgw using OpenAI client
        logger.info(f"Calling archgw at {LLM_GATEWAY_ENDPOINT} to generate response")
        response = await archgw_client.chat.completions.create(
            model=RESPONSE_MODEL,
            messages=response_messages,
            temperature=temperature,
            max_tokens=max_tokens,
        )

        generated_response = response.choices[0].message.content.strip()
        logger.info("Response generated successfully")

        return _build_completion(request, generated_response)

    except Exception as e:
        # Best-effort boundary: log with traceback, answer with an apology.
        logger.exception(f"Error generating response: {e}")

        fallback_message = "I apologize, but I'm having trouble generating a response right now. Please try again."
        return _build_completion(request, fallback_message)
|
||||
|
||||
|
||||
@app.get("/health")
async def health_check():
    """Answer health probes with a fixed ``{"status": "healthy"}`` payload."""
    payload = {"status": "healthy"}
    return payload
|
||||
Loading…
Add table
Add a link
Reference in a new issue