diff --git a/demos/use_cases/travel_agents/src/travel_agents/flight_agent.py b/demos/use_cases/travel_agents/src/travel_agents/flight_agent.py index 50b603f9..7c962481 100644 --- a/demos/use_cases/travel_agents/src/travel_agents/flight_agent.py +++ b/demos/use_cases/travel_agents/src/travel_agents/flight_agent.py @@ -4,8 +4,6 @@ from fastapi.responses import StreamingResponse from openai import AsyncOpenAI import os import logging -import time -import uuid import uvicorn from datetime import datetime import httpx @@ -68,19 +66,6 @@ Examples: Today is January 6, 2026. Extract flight route:""" -def create_sse_chunk(message: str, model: str) -> str: - chunk = { - "id": f"chatcmpl-{uuid.uuid4().hex[:8]}", - "object": "chat.completion.chunk", - "created": int(time.time()), - "model": model, - "choices": [ - {"index": 0, "delta": {"content": message}, "finish_reason": "stop"} - ], - } - return f"data: {json.dumps(chunk)}\n\n" - - async def extract_flight_route(messages: list, request: Request) -> dict: try: ctx = extract(request.headers) @@ -252,13 +237,22 @@ def build_flight_context(cities: list, airport_codes: list, legs_data: list) -> "origin_code": leg["origin_code"], "destination_code": leg["dest_code"], } - return f""" + if leg["flights"]: + return f""" Flight search results from {leg['origin']} ({leg['origin_code']}) to {leg['destination']} ({leg['dest_code']}): Flight data in JSON format: {json.dumps(flight_data, indent=2)} Present these {len(leg['flights'])} flight(s) to the user clearly.""" + else: + error = leg.get("error") or "No direct flights found" + return f""" +Flight search from {leg['origin']} ({leg['origin_code']}) to {leg['destination']} ({leg['dest_code']}): + +Result: {error} + +Let the user know and suggest alternatives if appropriate.""" route_str = " → ".join( [f"{city} ({code})" for city, code in zip(cities, airport_codes)] @@ -293,72 +287,55 @@ async def handle_request(request: Request): async def invoke_flight_agent(request: Request, request_body: dict): messages = request_body.get("messages", []) - model = request_body.get("model", FLIGHT_MODEL) route = await extract_flight_route(messages, request) cities = route.get("cities", []) travel_date = route.get("date") - date_display = travel_date or "today" + # Build context based on what we could extract if len(cities) < 2: - yield create_sse_chunk( - "I need at least an origin and destination to search for flights. " - "Example: 'Flights from Seattle to Atlanta' or 'Seattle to Dubai to Lahore'", - model, - ) - yield "data: [DONE]\n\n" - return + flight_context = """ +Could not extract a complete flight route from the user's request. - airport_codes = [] - for city in cities: - code = await resolve_airport_code(city, request) - if not code: - yield create_sse_chunk( - f"Couldn't find airport code for {city}. Please check the city name.", - model, - ) - yield "data: [DONE]\n\n" - return - airport_codes.append(code) +Ask the user to provide both origin and destination cities. +Example: 'Flights from Seattle to Atlanta' or 'Seattle to Dubai to Lahore'""" + airport_codes = [] + legs_data = [] + else: + airport_codes = [] + failed_city = None + for city in cities: + code = await resolve_airport_code(city, request) + if not code: + failed_city = city + break + airport_codes.append(code) - legs_data = [] - for i in range(len(cities) - 1): - flight_data = await fetch_flights( - airport_codes[i], airport_codes[i + 1], travel_date - ) - legs_data.append( - { - "leg": i + 1, - "origin": cities[i], - "origin_code": airport_codes[i], - "destination": cities[i + 1], - "dest_code": airport_codes[i + 1], - "flights": flight_data.get("flights", []), - "error": flight_data.get("error"), - } - ) + if failed_city: + flight_context = f""" +Could not find airport code for "{failed_city}". - has_any_flights = any(leg["flights"] for leg in legs_data) - - if not has_any_flights: - if len(legs_data) > 1: - parts = [ - f"Leg {leg['leg']} ({leg['origin']} → {leg['destination']}): {leg.get('error') or 'No direct flights'}" - for leg in legs_data - ] - message = "Multi-leg flight search results:\n\n" + "\n\n".join(parts) +Ask the user to check the city name or provide a different city.""" + legs_data = [] else: - leg = legs_data[0] - message = ( - leg.get("error") - or f"No direct flights from {leg['origin']} ({leg['origin_code']}) to {leg['destination']} ({leg['dest_code']}) for {date_display}." - ) + legs_data = [] + for i in range(len(cities) - 1): + flight_data = await fetch_flights( + airport_codes[i], airport_codes[i + 1], travel_date + ) + legs_data.append( + { + "leg": i + 1, + "origin": cities[i], + "origin_code": airport_codes[i], + "destination": cities[i + 1], + "dest_code": airport_codes[i + 1], + "flights": flight_data.get("flights", []), + "error": flight_data.get("error"), + } + ) - yield create_sse_chunk(message, model) - yield "data: [DONE]\n\n" - return - - flight_context = build_flight_context(cities, airport_codes, legs_data) + flight_context = build_flight_context(cities, airport_codes, legs_data) response_messages = [{"role": "system", "content": SYSTEM_PROMPT}] for i, msg in enumerate(messages): @@ -391,9 +368,6 @@ async def invoke_flight_agent(request: Request, request_body: dict): except Exception as e: logger.error(f"Error generating response: {e}") - yield create_sse_chunk( - "I'm having trouble retrieving flight information. Please try again.", model - ) yield "data: [DONE]\n\n"