# plano/demos/use_cases/orchestrating_agents/main.py

import logging
import json
from typing import List, Dict, Any
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
import openai

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("uvicorn.error")

app = FastAPI()


class Message(BaseModel):
    role: str
    content: str


class ChatCompletionsRequest(BaseModel):
    messages: List[Message]
    model: str
    metadata: Dict[str, Any] = {}
    stream: bool = False
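
# Illustrative request body for the schema above; field values are made up
# for this demo, and "agent-name" is expected to match a key in AGENTS below:
# {
#   "model": "gpt-4o-mini",
#   "messages": [{"role": "user", "content": "My order arrived broken"}],
#   "metadata": {"agent-name": "issues_and_repairs"},
#   "stream": true
# }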

openai_client = openai.OpenAI(
    api_key="None",  # archgw picks the API key from the config file
    base_url="http://host.docker.internal:12000/v1",
)


def call_openai(messages: List[Dict[str, str]], stream: bool, model: str):
    logger.info(f"llm agent model: {model}")
    completion = openai_client.chat.completions.create(
        model=model,
        messages=messages,
        stream=stream,
    )
    if stream:
        # Re-emit the OpenAI stream as Server-Sent Events; named event_stream
        # so it does not shadow the `stream` parameter above.
        def event_stream():
            for chunk in completion:
                if chunk.choices and chunk.choices[0].delta:
                    chunk_response_str = json.dumps(chunk.model_dump())
                    yield "data: " + chunk_response_str + "\n\n"
            yield "data: [DONE]" + "\n\n"

        return StreamingResponse(event_stream(), media_type="text/event-stream")
    else:
        return completion
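
# Each streamed chunk above is framed as a Server-Sent Event, roughly:
#   data: {"id": "...", "choices": [{"delta": {"content": "Hi"}, ...}], ...}
#   data: [DONE]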


class Agent:
    def __init__(self, role: str, instructions: str, model: str = ""):
        self.model = model
        self.system_prompt = f"You are a {role}.\n{instructions}"

    def handle(self, req: ChatCompletionsRequest):
        messages = [{"role": "system", "content": self.get_system_prompt()}] + [
            message.model_dump() for message in req.messages
        ]
        # Prefer the agent's own model; otherwise use the model from the request.
        model = self.model or req.model
        return call_openai(messages, req.stream, model)

    def get_system_prompt(self) -> str:
        return self.system_prompt


# Define your agents
AGENTS = {
    "sales_agent": Agent(
        role="sales agent",
        instructions=(
            "Always answer in a sentence or less.\n"
            "Follow the following routine with the user:\n"
            "1. Engage\n"
            "2. Quote ridiculous price\n"
            "3. Reveal caveat if user agrees."
        ),
        model="gpt-4o-mini",
    ),
    "issues_and_repairs": Agent(
        role="issues and repairs agent",
        instructions="Propose a solution, offer refund if necessary.",
        model="gpt-4o",
    ),
    "escalate_to_human": Agent(
        role="human escalation agent",
        instructions="Escalate issues to a human.",
        # skipping model name here as arch gateway will pick the default model from the config file
    ),
    "unknown_agent": Agent(
        role="general assistant", instructions="Assist the user in general queries."
    ),
}
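
# Note: the keys above are assumed to match the "agent-name" values that the
# gateway places in request metadata (see completion_api below).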


@app.post("/v1/chat/completions")
def completion_api(req: ChatCompletionsRequest, request: Request):
    agent_name = req.metadata.get("agent-name", "unknown_agent")
    # Fall back to the general assistant if the name is not a known agent.
    agent = AGENTS.get(agent_name, AGENTS["unknown_agent"])
    logger.info(f"Routing to agent: {agent_name}")
    return agent.handle(req)
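
# Example direct call to this endpoint (a sketch: archgw normally sits in front
# and supplies the metadata; host/port here are assumptions for a local run):
#   curl -s http://localhost:8000/v1/chat/completions \
#     -H 'Content-Type: application/json' \
#     -d '{"model": "gpt-4o-mini",
#          "messages": [{"role": "user", "content": "Hi"}],
#          "metadata": {"agent-name": "sales_agent"}}'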


@app.get("/healthz")
async def healthz():
    return {"status": "ok"}
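

# Minimal local entry point; a sketch assuming uvicorn is installed and that
# port 8000 is where archgw is configured to reach this service. In the demo,
# the app is typically launched via the uvicorn CLI / Docker instead.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)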