mirror of
https://github.com/katanemo/plano.git
synced 2026-05-11 00:32:42 +02:00
Musa/demo fix (#676)
* fix demo with travel agent * Update .gitignore * remove sse chunk rendering
This commit is contained in:
parent
745b36fdef
commit
b45c7aba86
3 changed files with 212 additions and 278 deletions
2
.gitignore
vendored
2
.gitignore
vendored
|
|
@ -147,3 +147,5 @@ apps/*/dist/
|
||||||
.vercel
|
.vercel
|
||||||
|
|
||||||
*.logs
|
*.logs
|
||||||
|
|
||||||
|
.cursor/
|
||||||
|
|
|
||||||
|
|
@ -4,104 +4,86 @@ from fastapi.responses import StreamingResponse
|
||||||
from openai import AsyncOpenAI
|
from openai import AsyncOpenAI
|
||||||
import os
|
import os
|
||||||
import logging
|
import logging
|
||||||
import time
|
|
||||||
import uuid
|
|
||||||
import uvicorn
|
import uvicorn
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime
|
||||||
import httpx
|
import httpx
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from opentelemetry.propagate import extract, inject
|
from opentelemetry.propagate import extract, inject
|
||||||
|
|
||||||
# Set up logging
|
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
level=logging.INFO,
|
level=logging.INFO,
|
||||||
format="%(asctime)s - [FLIGHT_AGENT] - %(levelname)s - %(message)s",
|
format="%(asctime)s - [FLIGHT_AGENT] - %(levelname)s - %(message)s",
|
||||||
)
|
)
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Configuration
|
|
||||||
LLM_GATEWAY_ENDPOINT = os.getenv(
|
LLM_GATEWAY_ENDPOINT = os.getenv(
|
||||||
"LLM_GATEWAY_ENDPOINT", "http://host.docker.internal:12000/v1"
|
"LLM_GATEWAY_ENDPOINT", "http://host.docker.internal:12000/v1"
|
||||||
)
|
)
|
||||||
FLIGHT_MODEL = "openai/gpt-4o"
|
FLIGHT_MODEL = "openai/gpt-4o"
|
||||||
EXTRACTION_MODEL = "openai/gpt-4o-mini"
|
EXTRACTION_MODEL = "openai/gpt-4o-mini"
|
||||||
|
|
||||||
# FlightAware AeroAPI configuration
|
|
||||||
AEROAPI_BASE_URL = "https://aeroapi.flightaware.com/aeroapi"
|
AEROAPI_BASE_URL = "https://aeroapi.flightaware.com/aeroapi"
|
||||||
AEROAPI_KEY = os.getenv("AEROAPI_KEY")
|
AEROAPI_KEY = os.getenv("AEROAPI_KEY")
|
||||||
|
|
||||||
# HTTP client for API calls
|
|
||||||
http_client = httpx.AsyncClient(timeout=30.0)
|
http_client = httpx.AsyncClient(timeout=30.0)
|
||||||
|
openai_client = AsyncOpenAI(base_url=LLM_GATEWAY_ENDPOINT, api_key="EMPTY")
|
||||||
|
|
||||||
# Initialize OpenAI client
|
SYSTEM_PROMPT = """You are a travel planning assistant specializing in flight information. You support both direct flights AND multi-leg connecting flights.
|
||||||
openai_client_via_plano = AsyncOpenAI(
|
|
||||||
base_url=LLM_GATEWAY_ENDPOINT,
|
|
||||||
api_key="EMPTY",
|
|
||||||
)
|
|
||||||
|
|
||||||
# System prompt for flight agent
|
Flight data fields:
|
||||||
SYSTEM_PROMPT = """You are a travel planning assistant specializing in flight information in a multi-agent system. You will receive flight data in JSON format with these fields:
|
- airline: Full airline name (e.g., "Delta Air Lines")
|
||||||
|
- flight_number: Flight identifier (e.g., "DL123")
|
||||||
- "airline": Full airline name (e.g., "Delta Air Lines")
|
- departure_time/arrival_time: ISO 8601 timestamps
|
||||||
- "flight_number": Flight identifier (e.g., "DL123")
|
- origin/destination: Airport IATA codes
|
||||||
- "departure_time": ISO 8601 timestamp for scheduled departure (e.g., "2025-12-24T23:00:00Z")
|
- aircraft_type: Aircraft model code (e.g., "B739")
|
||||||
- "arrival_time": ISO 8601 timestamp for scheduled arrival (e.g., "2025-12-25T04:40:00Z")
|
- status: Flight status (e.g., "Scheduled", "Delayed")
|
||||||
- "origin": Origin airport IATA code (e.g., "ATL")
|
- terminal_origin/gate_origin: Departure terminal and gate (may be null)
|
||||||
- "destination": Destination airport IATA code (e.g., "SEA")
|
|
||||||
- "aircraft_type": Aircraft model code (e.g., "A21N", "B739")
|
|
||||||
- "status": Flight status (e.g., "Scheduled", "Delayed")
|
|
||||||
- "terminal_origin": Departure terminal (may be null)
|
|
||||||
- "gate_origin": Departure gate (may be null)
|
|
||||||
|
|
||||||
Your task:
|
Your task:
|
||||||
1. Read the JSON flight data carefully
|
1. Present flights clearly with airline, flight number, readable times, airports, and aircraft
|
||||||
2. Present each flight clearly with: airline, flight number, departure/arrival times (convert from ISO format to readable time), airports, and aircraft type
|
2. Organize chronologically by departure time
|
||||||
3. Organize flights chronologically by departure time
|
3. Convert ISO timestamps to readable format (e.g., "11:00 AM")
|
||||||
4. Convert ISO timestamps to readable format (e.g., "11:00 PM" or "23:00")
|
4. Include terminal/gate info when available
|
||||||
5. Include terminal/gate info when available
|
5. For multi-leg flights: present each leg separately with connection timing
|
||||||
6. Use natural, conversational language
|
|
||||||
|
|
||||||
Important: If the conversation includes information from other agents (like weather details), acknowledge and build upon that context naturally. Your primary focus is flights, but maintain awareness of the full conversation.
|
Multi-agent context: If the conversation includes information from other sources, incorporate it naturally into your response."""
|
||||||
|
|
||||||
Remember: All the data you need is in the JSON. Use it directly."""
|
ROUTE_EXTRACTION_PROMPT = """Extract flight route and travel date. Support direct AND multi-leg flights.
|
||||||
|
|
||||||
|
Rules:
|
||||||
|
1. Patterns: "flight from X to Y", "X to Y to Z", "fly from X through Y to Z"
|
||||||
|
2. For multi-leg (e.g., "Seattle to Dubai to Lahore"), extract ALL cities in order
|
||||||
|
3. Extract dates: "tomorrow", "next week", "December 25", "12/25", "on Monday"
|
||||||
|
4. Use conversation context for missing details
|
||||||
|
|
||||||
|
Output format: {"cities": ["City1", "City2", ...], "date": "YYYY-MM-DD" or null}
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
- "Flight from Seattle to Atlanta tomorrow" → {"cities": ["Seattle", "Atlanta"], "date": "2026-01-07"}
|
||||||
|
- "Seattle to Dubai to Lahore" → {"cities": ["Seattle", "Dubai", "Lahore"], "date": null}
|
||||||
|
- "Flights from LA through Chicago to NYC" → {"cities": ["LA", "Chicago", "NYC"], "date": null}
|
||||||
|
|
||||||
|
Today is January 6, 2026. Extract flight route:"""
|
||||||
|
|
||||||
|
|
||||||
async def extract_flight_route(messages: list, request: Request) -> dict:
|
async def extract_flight_route(messages: list, request: Request) -> dict:
|
||||||
"""Extract origin, destination, and date from conversation using LLM."""
|
|
||||||
|
|
||||||
extraction_prompt = """Extract flight origin, destination cities, and travel date from the conversation.
|
|
||||||
|
|
||||||
Rules:
|
|
||||||
1. Look for patterns: "flight from X to Y", "flights to Y", "fly from X"
|
|
||||||
2. Extract dates like "tomorrow", "next week", "December 25", "12/25", "on Monday"
|
|
||||||
3. Use conversation context to fill in missing details
|
|
||||||
4. Return JSON: {"origin": "City" or null, "destination": "City" or null, "date": "YYYY-MM-DD" or null}
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
- "Flight from Seattle to Atlanta tomorrow" → {"origin": "Seattle", "destination": "Atlanta", "date": "2025-12-24"}
|
|
||||||
- "What flights go to New York?" → {"origin": null, "destination": "New York", "date": null}
|
|
||||||
- "Flights to Miami on Christmas" → {"origin": null, "destination": "Miami", "date": "2025-12-25"}
|
|
||||||
- "Show me flights from LA to NYC next Monday" → {"origin": "LA", "destination": "NYC", "date": "2025-12-30"}
|
|
||||||
|
|
||||||
Today is December 23, 2025. Extract flight route and date:"""
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
ctx = extract(request.headers)
|
ctx = extract(request.headers)
|
||||||
extra_headers = {}
|
extra_headers = {}
|
||||||
inject(extra_headers, context=ctx)
|
inject(extra_headers, context=ctx)
|
||||||
|
|
||||||
response = await openai_client_via_plano.chat.completions.create(
|
response = await openai_client.chat.completions.create(
|
||||||
model=EXTRACTION_MODEL,
|
model=EXTRACTION_MODEL,
|
||||||
messages=[
|
messages=[
|
||||||
{"role": "system", "content": extraction_prompt},
|
{"role": "system", "content": ROUTE_EXTRACTION_PROMPT},
|
||||||
*[
|
*[
|
||||||
{"role": msg.get("role"), "content": msg.get("content")}
|
{"role": m.get("role"), "content": m.get("content")}
|
||||||
for msg in messages[-5:]
|
for m in messages[-5:]
|
||||||
],
|
],
|
||||||
],
|
],
|
||||||
temperature=0.1,
|
temperature=0.1,
|
||||||
max_tokens=100,
|
max_tokens=100,
|
||||||
extra_headers=extra_headers if extra_headers else None,
|
extra_headers=extra_headers or None,
|
||||||
)
|
)
|
||||||
|
|
||||||
result = response.choices[0].message.content.strip()
|
result = response.choices[0].message.content.strip()
|
||||||
|
|
@ -111,18 +93,19 @@ async def extract_flight_route(messages: list, request: Request) -> dict:
|
||||||
result = result.split("```")[1].split("```")[0].strip()
|
result = result.split("```")[1].split("```")[0].strip()
|
||||||
|
|
||||||
route = json.loads(result)
|
route = json.loads(result)
|
||||||
return {
|
cities = route.get("cities", [])
|
||||||
"origin": route.get("origin"),
|
|
||||||
"destination": route.get("destination"),
|
if not cities and (route.get("origin") or route.get("destination")):
|
||||||
"date": route.get("date"),
|
cities = [c for c in [route.get("origin"), route.get("destination")] if c]
|
||||||
}
|
|
||||||
|
return {"cities": cities, "date": route.get("date")}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error extracting flight route: {e}")
|
logger.error(f"Error extracting flight route: {e}")
|
||||||
return {"origin": None, "destination": None, "date": None}
|
return {"cities": [], "date": None}
|
||||||
|
|
||||||
|
|
||||||
async def resolve_airport_code(city_name: str, request: Request) -> Optional[str]:
|
async def resolve_airport_code(city_name: str, request: Request) -> Optional[str]:
|
||||||
"""Convert city name to airport code using LLM."""
|
|
||||||
if not city_name:
|
if not city_name:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
@ -131,64 +114,52 @@ async def resolve_airport_code(city_name: str, request: Request) -> Optional[str
|
||||||
extra_headers = {}
|
extra_headers = {}
|
||||||
inject(extra_headers, context=ctx)
|
inject(extra_headers, context=ctx)
|
||||||
|
|
||||||
response = await openai_client_via_plano.chat.completions.create(
|
response = await openai_client.chat.completions.create(
|
||||||
model=EXTRACTION_MODEL,
|
model=EXTRACTION_MODEL,
|
||||||
messages=[
|
messages=[
|
||||||
{
|
{
|
||||||
"role": "system",
|
"role": "system",
|
||||||
"content": "Convert city names to primary airport IATA codes. Return only the 3-letter code. Examples: Seattle→SEA, Atlanta→ATL, New York→JFK, London→LHR",
|
"content": "Convert city names to primary airport IATA codes. Return only the 3-letter code. Examples: Seattle→SEA, Atlanta→ATL, New York→JFK, Dubai→DXB, Lahore→LHE",
|
||||||
},
|
},
|
||||||
{"role": "user", "content": city_name},
|
{"role": "user", "content": city_name},
|
||||||
],
|
],
|
||||||
temperature=0.1,
|
temperature=0.1,
|
||||||
max_tokens=10,
|
max_tokens=10,
|
||||||
extra_headers=extra_headers if extra_headers else None,
|
extra_headers=extra_headers or None,
|
||||||
)
|
)
|
||||||
|
|
||||||
code = response.choices[0].message.content.strip().upper()
|
code = response.choices[0].message.content.strip().upper()
|
||||||
code = code.strip("\"'`.,!? \n\t")
|
code = code.strip("\"'`.,!? \n\t")
|
||||||
return code if len(code) == 3 else None
|
return code if len(code) == 3 else None
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error resolving airport code for {city_name}: {e}")
|
logger.error(f"Error resolving airport code for {city_name}: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
async def get_flights(
|
async def fetch_flights(
|
||||||
origin_code: str, dest_code: str, travel_date: Optional[str] = None
|
origin_code: str, dest_code: str, travel_date: Optional[str] = None
|
||||||
) -> Optional[dict]:
|
) -> dict:
|
||||||
"""Get flights between two airports using FlightAware API.
|
"""Fetch flights between two airports. Note: FlightAware limits to 2 days ahead."""
|
||||||
|
search_date = travel_date or datetime.now().strftime("%Y-%m-%d")
|
||||||
|
|
||||||
Args:
|
search_date_obj = datetime.strptime(search_date, "%Y-%m-%d")
|
||||||
origin_code: Origin airport IATA code
|
today = datetime.now().replace(hour=0, minute=0, second=0, microsecond=0)
|
||||||
dest_code: Destination airport IATA code
|
days_ahead = (search_date_obj - today).days
|
||||||
travel_date: Travel date in YYYY-MM-DD format, defaults to today
|
|
||||||
|
if days_ahead > 2:
|
||||||
|
logger.warning(
|
||||||
|
f"Date {search_date} is {days_ahead} days ahead, exceeds FlightAware limit"
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"origin_code": origin_code,
|
||||||
|
"destination_code": dest_code,
|
||||||
|
"flights": [],
|
||||||
|
"count": 0,
|
||||||
|
"error": f"FlightAware API only provides data up to 2 days ahead. Requested date ({search_date}) is {days_ahead} days away.",
|
||||||
|
}
|
||||||
|
|
||||||
Note: FlightAware API limits searches to 2 days in the future.
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
# Use provided date or default to today
|
|
||||||
if travel_date:
|
|
||||||
search_date = travel_date
|
|
||||||
else:
|
|
||||||
search_date = datetime.now().strftime("%Y-%m-%d")
|
|
||||||
|
|
||||||
# Validate date is not too far in the future (FlightAware limit: 2 days)
|
|
||||||
search_date_obj = datetime.strptime(search_date, "%Y-%m-%d")
|
|
||||||
today = datetime.now().replace(hour=0, minute=0, second=0, microsecond=0)
|
|
||||||
days_ahead = (search_date_obj - today).days
|
|
||||||
|
|
||||||
if days_ahead > 2:
|
|
||||||
logger.warning(
|
|
||||||
f"Requested date {search_date} is {days_ahead} days ahead, exceeds FlightAware 2-day limit"
|
|
||||||
)
|
|
||||||
return {
|
|
||||||
"origin_code": origin_code,
|
|
||||||
"destination_code": dest_code,
|
|
||||||
"flights": [],
|
|
||||||
"count": 0,
|
|
||||||
"error": f"FlightAware API only provides flight data up to 2 days in the future. The requested date ({search_date}) is {days_ahead} days ahead. Please search for today, tomorrow, or the day after.",
|
|
||||||
}
|
|
||||||
|
|
||||||
url = f"{AEROAPI_BASE_URL}/airports/{origin_code}/flights/to/{dest_code}"
|
url = f"{AEROAPI_BASE_URL}/airports/{origin_code}/flights/to/{dest_code}"
|
||||||
headers = {"x-apikey": AEROAPI_KEY}
|
headers = {"x-apikey": AEROAPI_KEY}
|
||||||
params = {
|
params = {
|
||||||
|
|
@ -204,43 +175,34 @@ async def get_flights(
|
||||||
logger.error(
|
logger.error(
|
||||||
f"FlightAware API error {response.status_code}: {response.text}"
|
f"FlightAware API error {response.status_code}: {response.text}"
|
||||||
)
|
)
|
||||||
return None
|
return {
|
||||||
|
"origin_code": origin_code,
|
||||||
|
"destination_code": dest_code,
|
||||||
|
"flights": [],
|
||||||
|
"count": 0,
|
||||||
|
}
|
||||||
|
|
||||||
data = response.json()
|
data = response.json()
|
||||||
flights = []
|
flights = []
|
||||||
|
|
||||||
# Log raw API response for debugging
|
for flight_group in data.get("flights", [])[:5]:
|
||||||
logger.info(f"FlightAware API returned {len(data.get('flights', []))} flights")
|
|
||||||
|
|
||||||
for idx, flight_group in enumerate(
|
|
||||||
data.get("flights", [])[:5]
|
|
||||||
): # Limit to 5 flights
|
|
||||||
# FlightAware API nests data in segments array
|
|
||||||
segments = flight_group.get("segments", [])
|
segments = flight_group.get("segments", [])
|
||||||
if not segments:
|
if not segments:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
flight = segments[0] # Get first segment (direct flights only have one)
|
flight = segments[0]
|
||||||
|
|
||||||
# Extract airport codes from nested objects
|
|
||||||
flight_origin = None
|
|
||||||
flight_dest = None
|
|
||||||
|
|
||||||
if isinstance(flight.get("origin"), dict):
|
|
||||||
flight_origin = flight["origin"].get("code_iata")
|
|
||||||
|
|
||||||
if isinstance(flight.get("destination"), dict):
|
|
||||||
flight_dest = flight["destination"].get("code_iata")
|
|
||||||
|
|
||||||
# Build flight object
|
|
||||||
flights.append(
|
flights.append(
|
||||||
{
|
{
|
||||||
"airline": flight.get("operator"),
|
"airline": flight.get("operator"),
|
||||||
"flight_number": flight.get("ident_iata") or flight.get("ident"),
|
"flight_number": flight.get("ident_iata") or flight.get("ident"),
|
||||||
"departure_time": flight.get("scheduled_out"),
|
"departure_time": flight.get("scheduled_out"),
|
||||||
"arrival_time": flight.get("scheduled_in"),
|
"arrival_time": flight.get("scheduled_in"),
|
||||||
"origin": flight_origin,
|
"origin": flight["origin"].get("code_iata")
|
||||||
"destination": flight_dest,
|
if isinstance(flight.get("origin"), dict)
|
||||||
|
else None,
|
||||||
|
"destination": flight["destination"].get("code_iata")
|
||||||
|
if isinstance(flight.get("destination"), dict)
|
||||||
|
else None,
|
||||||
"aircraft_type": flight.get("aircraft_type"),
|
"aircraft_type": flight.get("aircraft_type"),
|
||||||
"status": flight.get("status"),
|
"status": flight.get("status"),
|
||||||
"terminal_origin": flight.get("terminal_origin"),
|
"terminal_origin": flight.get("terminal_origin"),
|
||||||
|
|
@ -248,15 +210,67 @@ async def get_flights(
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logger.info(f"Found {len(flights)} flights from {origin_code} to {dest_code}")
|
||||||
return {
|
return {
|
||||||
"origin_code": origin_code,
|
"origin_code": origin_code,
|
||||||
"destination_code": dest_code,
|
"destination_code": dest_code,
|
||||||
"flights": flights,
|
"flights": flights,
|
||||||
"count": len(flights),
|
"count": len(flights),
|
||||||
}
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error fetching flights: {e}")
|
logger.error(f"Error fetching flights: {e}")
|
||||||
return None
|
return {
|
||||||
|
"origin_code": origin_code,
|
||||||
|
"destination_code": dest_code,
|
||||||
|
"flights": [],
|
||||||
|
"count": 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def build_flight_context(cities: list, airport_codes: list, legs_data: list) -> str:
|
||||||
|
if len(cities) == 2:
|
||||||
|
leg = legs_data[0]
|
||||||
|
flight_data = {
|
||||||
|
"flights": leg["flights"],
|
||||||
|
"count": len(leg["flights"]),
|
||||||
|
"origin_code": leg["origin_code"],
|
||||||
|
"destination_code": leg["dest_code"],
|
||||||
|
}
|
||||||
|
if leg["flights"]:
|
||||||
|
return f"""
|
||||||
|
Flight search results from {leg['origin']} ({leg['origin_code']}) to {leg['destination']} ({leg['dest_code']}):
|
||||||
|
|
||||||
|
Flight data in JSON format:
|
||||||
|
{json.dumps(flight_data, indent=2)}
|
||||||
|
|
||||||
|
Present these {len(leg['flights'])} flight(s) to the user clearly."""
|
||||||
|
else:
|
||||||
|
error = leg.get("error") or "No direct flights found"
|
||||||
|
return f"""
|
||||||
|
Flight search from {leg['origin']} ({leg['origin_code']}) to {leg['destination']} ({leg['dest_code']}):
|
||||||
|
|
||||||
|
Result: {error}
|
||||||
|
|
||||||
|
Let the user know and suggest alternatives if appropriate."""
|
||||||
|
|
||||||
|
route_str = " → ".join(
|
||||||
|
[f"{city} ({code})" for city, code in zip(cities, airport_codes)]
|
||||||
|
)
|
||||||
|
context = f"\nMulti-leg flight search: {route_str}\n\n"
|
||||||
|
|
||||||
|
for leg in legs_data:
|
||||||
|
context += f"**Leg {leg['leg']}: {leg['origin']} ({leg['origin_code']}) → {leg['destination']} ({leg['dest_code']})**\n"
|
||||||
|
if leg["flights"]:
|
||||||
|
leg_data = {"flights": leg["flights"], "count": len(leg["flights"])}
|
||||||
|
context += f"Flight data:\n{json.dumps(leg_data, indent=2)}\n\n"
|
||||||
|
elif leg.get("error"):
|
||||||
|
context += f"Error: {leg['error']}\n\n"
|
||||||
|
else:
|
||||||
|
context += "No direct flights found for this leg.\n\n"
|
||||||
|
|
||||||
|
context += "Present this itinerary clearly. For each leg, show available flights by departure time. Note connection timing between legs."
|
||||||
|
return context
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI(title="Flight Information Agent", version="1.0.0")
|
app = FastAPI(title="Flight Information Agent", version="1.0.0")
|
||||||
|
|
@ -264,143 +278,80 @@ app = FastAPI(title="Flight Information Agent", version="1.0.0")
|
||||||
|
|
||||||
@app.post("/v1/chat/completions")
|
@app.post("/v1/chat/completions")
|
||||||
async def handle_request(request: Request):
|
async def handle_request(request: Request):
|
||||||
"""HTTP endpoint for chat completions with streaming support."""
|
|
||||||
request_body = await request.json()
|
request_body = await request.json()
|
||||||
messages = request_body.get("messages", [])
|
|
||||||
|
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
invoke_flight_agent(request, request_body),
|
invoke_flight_agent(request, request_body),
|
||||||
media_type="text/plain",
|
media_type="text/event-stream",
|
||||||
headers={"content-type": "text/event-stream"},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def invoke_flight_agent(request: Request, request_body: dict):
|
async def invoke_flight_agent(request: Request, request_body: dict):
|
||||||
"""Generate streaming chat completions."""
|
|
||||||
messages = request_body.get("messages", [])
|
messages = request_body.get("messages", [])
|
||||||
|
|
||||||
# Step 1: Extract origin, destination, and date
|
|
||||||
route = await extract_flight_route(messages, request)
|
route = await extract_flight_route(messages, request)
|
||||||
origin = route.get("origin")
|
cities = route.get("cities", [])
|
||||||
destination = route.get("destination")
|
|
||||||
travel_date = route.get("date")
|
travel_date = route.get("date")
|
||||||
|
|
||||||
# Step 2: Short circuit if missing origin or destination
|
# Build context based on what we could extract
|
||||||
if not origin or not destination:
|
if len(cities) < 2:
|
||||||
missing = []
|
flight_context = """
|
||||||
if not origin:
|
Could not extract a complete flight route from the user's request.
|
||||||
missing.append("origin city")
|
|
||||||
if not destination:
|
|
||||||
missing.append("destination city")
|
|
||||||
|
|
||||||
error_message = f"I need both origin and destination cities to search for flights. Please provide the {' and '.join(missing)}. For example: 'Flights from Seattle to Atlanta'"
|
Ask the user to provide both origin and destination cities.
|
||||||
|
Example: 'Flights from Seattle to Atlanta' or 'Seattle to Dubai to Lahore'"""
|
||||||
|
airport_codes = []
|
||||||
|
legs_data = []
|
||||||
|
else:
|
||||||
|
airport_codes = []
|
||||||
|
failed_city = None
|
||||||
|
for city in cities:
|
||||||
|
code = await resolve_airport_code(city, request)
|
||||||
|
if not code:
|
||||||
|
failed_city = city
|
||||||
|
break
|
||||||
|
airport_codes.append(code)
|
||||||
|
|
||||||
error_chunk = {
|
if failed_city:
|
||||||
"id": f"chatcmpl-{uuid.uuid4().hex[:8]}",
|
flight_context = f"""
|
||||||
"object": "chat.completion.chunk",
|
Could not find airport code for "{failed_city}".
|
||||||
"created": int(time.time()),
|
|
||||||
"model": request_body.get("model", FLIGHT_MODEL),
|
|
||||||
"choices": [
|
|
||||||
{
|
|
||||||
"index": 0,
|
|
||||||
"delta": {"content": error_message},
|
|
||||||
"finish_reason": "stop",
|
|
||||||
}
|
|
||||||
],
|
|
||||||
}
|
|
||||||
yield f"data: {json.dumps(error_chunk)}\n\n"
|
|
||||||
yield "data: [DONE]\n\n"
|
|
||||||
return
|
|
||||||
|
|
||||||
# Step 3: Resolve airport codes
|
Ask the user to check the city name or provide a different city."""
|
||||||
origin_code = await resolve_airport_code(origin, request)
|
legs_data = []
|
||||||
dest_code = await resolve_airport_code(destination, request)
|
|
||||||
|
|
||||||
if not origin_code or not dest_code:
|
|
||||||
error_chunk = {
|
|
||||||
"id": f"chatcmpl-{uuid.uuid4().hex[:8]}",
|
|
||||||
"object": "chat.completion.chunk",
|
|
||||||
"created": int(time.time()),
|
|
||||||
"model": request_body.get("model", FLIGHT_MODEL),
|
|
||||||
"choices": [
|
|
||||||
{
|
|
||||||
"index": 0,
|
|
||||||
"delta": {
|
|
||||||
"content": f"I couldn't find airport codes for {origin if not origin_code else destination}. Please check the city name."
|
|
||||||
},
|
|
||||||
"finish_reason": "stop",
|
|
||||||
}
|
|
||||||
],
|
|
||||||
}
|
|
||||||
yield f"data: {json.dumps(error_chunk)}\n\n"
|
|
||||||
yield "data: [DONE]\n\n"
|
|
||||||
return
|
|
||||||
|
|
||||||
# Step 4: Get live flight data
|
|
||||||
flight_data = await get_flights(origin_code, dest_code, travel_date)
|
|
||||||
|
|
||||||
# Determine date display for messages
|
|
||||||
date_display = travel_date if travel_date else "today"
|
|
||||||
|
|
||||||
if not flight_data or not flight_data.get("flights"):
|
|
||||||
# Check if there's a specific error message (e.g., date too far in future)
|
|
||||||
error_detail = flight_data.get("error") if flight_data else None
|
|
||||||
if error_detail:
|
|
||||||
no_flights_message = error_detail
|
|
||||||
else:
|
else:
|
||||||
no_flights_message = f"No direct flights found from {origin} ({origin_code}) to {destination} ({dest_code}) for {date_display}."
|
legs_data = []
|
||||||
|
for i in range(len(cities) - 1):
|
||||||
|
flight_data = await fetch_flights(
|
||||||
|
airport_codes[i], airport_codes[i + 1], travel_date
|
||||||
|
)
|
||||||
|
legs_data.append(
|
||||||
|
{
|
||||||
|
"leg": i + 1,
|
||||||
|
"origin": cities[i],
|
||||||
|
"origin_code": airport_codes[i],
|
||||||
|
"destination": cities[i + 1],
|
||||||
|
"dest_code": airport_codes[i + 1],
|
||||||
|
"flights": flight_data.get("flights", []),
|
||||||
|
"error": flight_data.get("error"),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
error_chunk = {
|
flight_context = build_flight_context(cities, airport_codes, legs_data)
|
||||||
"id": f"chatcmpl-{uuid.uuid4().hex[:8]}",
|
|
||||||
"object": "chat.completion.chunk",
|
|
||||||
"created": int(time.time()),
|
|
||||||
"model": request_body.get("model", FLIGHT_MODEL),
|
|
||||||
"choices": [
|
|
||||||
{
|
|
||||||
"index": 0,
|
|
||||||
"delta": {"content": no_flights_message},
|
|
||||||
"finish_reason": "stop",
|
|
||||||
}
|
|
||||||
],
|
|
||||||
}
|
|
||||||
yield f"data: {json.dumps(error_chunk)}\n\n"
|
|
||||||
yield "data: [DONE]\n\n"
|
|
||||||
return
|
|
||||||
|
|
||||||
# Step 5: Prepare context for LLM - append flight data to last user message
|
|
||||||
flight_context = f"""
|
|
||||||
|
|
||||||
Flight search results from {origin} ({origin_code}) to {destination} ({dest_code}):
|
|
||||||
|
|
||||||
Flight data in JSON format:
|
|
||||||
{json.dumps(flight_data, indent=2)}
|
|
||||||
|
|
||||||
Present these {len(flight_data.get('flights', []))} flight(s) to the user in a clear, readable format."""
|
|
||||||
|
|
||||||
# Build message history with flight data appended to the last user message
|
|
||||||
response_messages = [{"role": "system", "content": SYSTEM_PROMPT}]
|
response_messages = [{"role": "system", "content": SYSTEM_PROMPT}]
|
||||||
|
|
||||||
for i, msg in enumerate(messages):
|
for i, msg in enumerate(messages):
|
||||||
# Append flight data to the last user message
|
content = msg.get("content", "")
|
||||||
if i == len(messages) - 1 and msg.get("role") == "user":
|
if i == len(messages) - 1 and msg.get("role") == "user":
|
||||||
response_messages.append(
|
content += flight_context
|
||||||
{"role": "user", "content": msg.get("content") + flight_context}
|
response_messages.append({"role": msg.get("role"), "content": content})
|
||||||
)
|
|
||||||
else:
|
|
||||||
response_messages.append(
|
|
||||||
{"role": msg.get("role"), "content": msg.get("content")}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Log what we're sending to the LLM for debugging
|
logger.info(f"Sending {len(response_messages)} messages to LLM")
|
||||||
logger.info(f"Sending messages to LLM: {json.dumps(response_messages, indent=2)}")
|
|
||||||
|
|
||||||
# Step 6: Stream response
|
|
||||||
try:
|
try:
|
||||||
ctx = extract(request.headers)
|
ctx = extract(request.headers)
|
||||||
extra_headers = {"x-envoy-max-retries": "3"}
|
extra_headers = {"x-envoy-max-retries": "3"}
|
||||||
inject(extra_headers, context=ctx)
|
inject(extra_headers, context=ctx)
|
||||||
|
|
||||||
stream = await openai_client_via_plano.chat.completions.create(
|
stream = await openai_client.chat.completions.create(
|
||||||
model=FLIGHT_MODEL,
|
model=FLIGHT_MODEL,
|
||||||
messages=response_messages,
|
messages=response_messages,
|
||||||
temperature=request_body.get("temperature", 0.7),
|
temperature=request_body.get("temperature", 0.7),
|
||||||
|
|
@ -416,34 +367,16 @@ Present these {len(flight_data.get('flights', []))} flight(s) to the user in a c
|
||||||
yield "data: [DONE]\n\n"
|
yield "data: [DONE]\n\n"
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error generating flight response: {e}")
|
logger.error(f"Error generating response: {e}")
|
||||||
error_chunk = {
|
|
||||||
"id": f"chatcmpl-{uuid.uuid4().hex[:8]}",
|
|
||||||
"object": "chat.completion.chunk",
|
|
||||||
"created": int(time.time()),
|
|
||||||
"model": request_body.get("model", FLIGHT_MODEL),
|
|
||||||
"choices": [
|
|
||||||
{
|
|
||||||
"index": 0,
|
|
||||||
"delta": {
|
|
||||||
"content": "I apologize, but I'm having trouble retrieving flight information right now. Please try again."
|
|
||||||
},
|
|
||||||
"finish_reason": "stop",
|
|
||||||
}
|
|
||||||
],
|
|
||||||
}
|
|
||||||
yield f"data: {json.dumps(error_chunk)}\n\n"
|
|
||||||
yield "data: [DONE]\n\n"
|
yield "data: [DONE]\n\n"
|
||||||
|
|
||||||
|
|
||||||
@app.get("/health")
|
@app.get("/health")
|
||||||
async def health_check():
|
async def health_check():
|
||||||
"""Health check endpoint."""
|
|
||||||
return {"status": "healthy", "agent": "flight_information"}
|
return {"status": "healthy", "agent": "flight_information"}
|
||||||
|
|
||||||
|
|
||||||
def start_server(host: str = "localhost", port: int = 10520):
|
def start_server(host: str = "0.0.0.0", port: int = 10520):
|
||||||
"""Start the REST server."""
|
|
||||||
uvicorn.run(
|
uvicorn.run(
|
||||||
app,
|
app,
|
||||||
host=host,
|
host=host,
|
||||||
|
|
@ -453,23 +386,20 @@ def start_server(host: str = "localhost", port: int = 10520):
|
||||||
"disable_existing_loggers": False,
|
"disable_existing_loggers": False,
|
||||||
"formatters": {
|
"formatters": {
|
||||||
"default": {
|
"default": {
|
||||||
"format": "%(asctime)s - [FLIGHT_AGENT] - %(levelname)s - %(message)s",
|
"format": "%(asctime)s - [FLIGHT_AGENT] - %(levelname)s - %(message)s"
|
||||||
},
|
}
|
||||||
},
|
},
|
||||||
"handlers": {
|
"handlers": {
|
||||||
"default": {
|
"default": {
|
||||||
"formatter": "default",
|
"formatter": "default",
|
||||||
"class": "logging.StreamHandler",
|
"class": "logging.StreamHandler",
|
||||||
"stream": "ext://sys.stdout",
|
"stream": "ext://sys.stdout",
|
||||||
},
|
}
|
||||||
},
|
|
||||||
"root": {
|
|
||||||
"level": "INFO",
|
|
||||||
"handlers": ["default"],
|
|
||||||
},
|
},
|
||||||
|
"root": {"level": "INFO", "handlers": ["default"]},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
start_server(host="0.0.0.0", port=10520)
|
start_server()
|
||||||
|
|
|
||||||
|
|
@ -70,26 +70,22 @@ async def get_weather_data(request: Request, messages: list, days: int = 1):
|
||||||
|
|
||||||
Currently returns only current day weather. Want to add multi-day forecasts?
|
Currently returns only current day weather. Want to add multi-day forecasts?
|
||||||
"""
|
"""
|
||||||
|
instructions = """You are a city name extractor. Look at the FINAL user message ONLY and extract the city name.
|
||||||
|
|
||||||
instructions = """Extract the location for WEATHER queries. Return just the city name.
|
The FINAL user message will be the LAST message with role "user" in the conversation.
|
||||||
|
|
||||||
Rules:
|
IMPORTANT: Ignore all previous messages. Focus ONLY on the FINAL user message.
|
||||||
1. For multi-part queries, extract ONLY the location mentioned with weather keywords ("weather in [location]")
|
|
||||||
2. If user says "there" or "that city", it typically refers to the DESTINATION city in travel contexts (not the origin)
|
|
||||||
3. For flight queries with weather, "there" means the destination city where they're traveling TO
|
|
||||||
4. Return plain text (e.g., "London", "New York", "Paris, France")
|
|
||||||
5. If no weather location found, return "NOT_FOUND"
|
|
||||||
|
|
||||||
Examples:
|
Examples of what to extract from the FINAL user message:
|
||||||
- "What's the weather in London?" → "London"
|
- "What's the weather in Seattle?" → Seattle
|
||||||
- "Flights from Seattle to Atlanta, and show me the weather there" → "Atlanta"
|
- "What's the weather in San Francisco?" → San Francisco
|
||||||
- "Can you get me flights from Seattle to Atlanta tomorrow, and also please show me the weather there" → "Atlanta"
|
- "What about Dubai?" → Dubai
|
||||||
- "What's the weather in Seattle, and what is one flight that goes direct to Atlanta?" → "Seattle"
|
- "How's the weather in Tokyo today?" → Tokyo
|
||||||
- User asked about flights to Atlanta, then "what's the weather like there?" → "Atlanta"
|
- "Tell me about Lahore" → Lahore
|
||||||
- "I'm going to Seattle" → "Seattle"
|
- "What about there?" → Look at conversation for the last mentioned city
|
||||||
- "What's happening?" → "NOT_FOUND"
|
|
||||||
|
|
||||||
Extract location:"""
|
Output ONLY the city name. Nothing else. One word or city name only.
|
||||||
|
If no city can be found, output: NOT_FOUND"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
user_messages = [
|
user_messages = [
|
||||||
|
|
@ -114,7 +110,7 @@ async def get_weather_data(request: Request, messages: list, days: int = 1):
|
||||||
],
|
],
|
||||||
],
|
],
|
||||||
temperature=0.1,
|
temperature=0.1,
|
||||||
max_tokens=50,
|
max_tokens=10,
|
||||||
extra_headers=extra_headers if extra_headers else None,
|
extra_headers=extra_headers if extra_headers else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -265,12 +261,16 @@ async def handle_request(request: Request):
|
||||||
|
|
||||||
request_body = await request.json()
|
request_body = await request.json()
|
||||||
messages = request_body.get("messages", [])
|
messages = request_body.get("messages", [])
|
||||||
|
# Respect the stream parameter - orchestrator controls this based on agent position in chain
|
||||||
|
is_streaming = request_body.get("stream", True)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"messages detail json dumps: %s",
|
"messages detail json dumps: %s",
|
||||||
json.dumps(messages, indent=2),
|
json.dumps(messages, indent=2),
|
||||||
)
|
)
|
||||||
|
|
||||||
traceparent_header = request.headers.get("traceparent")
|
traceparent_header = request.headers.get("traceparent")
|
||||||
|
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
invoke_weather_agent(request, request_body, traceparent_header),
|
invoke_weather_agent(request, request_body, traceparent_header),
|
||||||
media_type="text/plain",
|
media_type="text/plain",
|
||||||
|
|
@ -311,7 +311,9 @@ async def invoke_weather_agent(
|
||||||
weather_context = f"""
|
weather_context = f"""
|
||||||
|
|
||||||
Weather data for {weather_data['location']} ({forecast_type}):
|
Weather data for {weather_data['location']} ({forecast_type}):
|
||||||
{json.dumps(weather_data, indent=2)}"""
|
{json.dumps(weather_data, indent=2)}
|
||||||
|
|
||||||
|
Present the weather information to the user in a clear, readable format. If there is information from other agents, start your response with a summary of that information."""
|
||||||
|
|
||||||
# System prompt for weather agent
|
# System prompt for weather agent
|
||||||
instructions = """You are a weather assistant in a multi-agent system. You will receive weather data in JSON format with these fields:
|
instructions = """You are a weather assistant in a multi-agent system. You will receive weather data in JSON format with these fields:
|
||||||
|
|
@ -328,7 +330,7 @@ Weather data for {weather_data['location']} ({forecast_type}):
|
||||||
5. Describe conditions naturally based on weather_code
|
5. Describe conditions naturally based on weather_code
|
||||||
6. Use conversational language
|
6. Use conversational language
|
||||||
|
|
||||||
Important: If the conversation includes information from other agents (like flight details), acknowledge and build upon that context naturally. Your primary focus is weather, but maintain awareness of the full conversation.
|
Multi-agent context: You are part of a larger system. If the conversation includes additional context or information from other sources, acknowledge and incorporate it naturally into your response. Your primary focus is weather, but be aware of the full conversation context.
|
||||||
|
|
||||||
Remember: Only use the provided data. If fields are null, mention data is unavailable."""
|
Remember: Only use the provided data. If fields are null, mention data is unavailable."""
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue