From e79e1187b2593b1225200b90cfb1dc9662700386 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Sun, 21 Dec 2025 19:35:00 +0530 Subject: [PATCH] feat: implement background podcast generation with Celery and task polling in UI --- .../app/agents/new_chat/podcast.py | 117 +++---------- .../app/routes/podcasts_routes.py | 63 +++++++ .../app/tasks/celery_tasks/podcast_tasks.py | 133 +++++++++++++++ .../components/tool-ui/generate-podcast.tsx | 156 +++++++++++++----- 4 files changed, 335 insertions(+), 134 deletions(-) diff --git a/surfsense_backend/app/agents/new_chat/podcast.py b/surfsense_backend/app/agents/new_chat/podcast.py index ed4116bfb..2205227b1 100644 --- a/surfsense_backend/app/agents/new_chat/podcast.py +++ b/surfsense_backend/app/agents/new_chat/podcast.py @@ -2,8 +2,8 @@ Podcast generation tool for the new chat agent. This module provides a factory function for creating the generate_podcast tool -that integrates with the existing podcaster agent. Podcasts are saved to the -database like the old system, providing authentication and persistence. +that submits a Celery task for background podcast generation. The frontend +polls for completion and auto-updates when the podcast is ready. """ from typing import Any @@ -11,10 +11,6 @@ from typing import Any from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession -from app.agents.podcaster.graph import graph as podcaster_graph -from app.agents.podcaster.state import State as PodcasterState -from app.db import Podcast - def create_generate_podcast_tool( search_space_id: int, @@ -26,7 +22,7 @@ def create_generate_podcast_tool( Args: search_space_id: The user's search space ID - db_session: Database session + db_session: Database session (not used - Celery creates its own) user_id: The user's ID (as string) Returns: @@ -50,8 +46,8 @@ def create_generate_podcast_tool( - "Make a podcast about..." - "Turn this into a podcast" - The tool will generate a complete audio podcast with two speakers - discussing the provided content in an engaging conversational format. + The tool will start generating a podcast in the background. + The podcast will be available once generation completes. Args: source_content: The text content to convert into a podcast. @@ -63,108 +59,43 @@ def create_generate_podcast_tool( Returns: A dictionary containing: - - status: "success" or "error" - - podcast_id: The database ID of the saved podcast (for API access) + - status: "processing" (task submitted) or "error" + - task_id: The Celery task ID for polling status - title: The podcast title - - transcript: Full podcast transcript with all dialogue entries - - duration_ms: Estimated podcast duration in milliseconds - - transcript_entries: Number of dialogue entries """ try: - # Configure the podcaster graph - config = { - "configurable": { - "podcast_title": podcast_title, - "user_id": str(user_id), - "search_space_id": search_space_id, - "user_prompt": user_prompt, - } - } + # Import Celery task here to avoid circular imports + from app.tasks.celery_tasks.podcast_tasks import ( + generate_content_podcast_task, + ) - # Initialize the podcaster state with the source content - initial_state = PodcasterState( + # Submit Celery task for background processing + task = generate_content_podcast_task.delay( source_content=source_content, - db_session=db_session, - ) - - # Run the podcaster graph - result = await podcaster_graph.ainvoke(initial_state, config=config) - - # Extract results - podcast_transcript = result.get("podcast_transcript", []) - file_path = result.get("final_podcast_file_path", "") - - # Calculate estimated duration (rough estimate: ~150 words per minute) - total_words = sum( - len(entry.dialog.split()) if hasattr(entry, "dialog") else len(entry.get("dialog", "").split()) - for entry in podcast_transcript - ) - estimated_duration_ms = int((total_words / 150) * 60 * 1000) - - # Create full transcript for display (all entries, complete dialog) - full_transcript = [] - for entry in podcast_transcript: - if hasattr(entry, "speaker_id"): - speaker = f"Speaker {entry.speaker_id + 1}" - dialog = entry.dialog - else: - speaker = f"Speaker {entry.get('speaker_id', 0) + 1}" - dialog = entry.get("dialog", "") - full_transcript.append(f"{speaker}: {dialog}") - - # Convert podcast transcript entries to serializable format (like old system) - serializable_transcript = [] - for entry in podcast_transcript: - if hasattr(entry, "speaker_id"): - serializable_transcript.append({ - "speaker_id": entry.speaker_id, - "dialog": entry.dialog - }) - else: - serializable_transcript.append({ - "speaker_id": entry.get("speaker_id", 0), - "dialog": entry.get("dialog", "") - }) - - # Save podcast to database (like old system) - # This provides authentication and persistence - podcast = Podcast( - title=podcast_title, - podcast_transcript=serializable_transcript, - file_location=file_path, search_space_id=search_space_id, - # chat_id is None since new-chat uses LangGraph threads, not DB chats - chat_id=None, - chat_state_version=None, + user_id=str(user_id), + podcast_title=podcast_title, + user_prompt=user_prompt, ) - db_session.add(podcast) - await db_session.commit() - await db_session.refresh(podcast) - # Return podcast_id - frontend will use it to call the API endpoint - # GET /api/v1/podcasts/{podcast_id}/stream (like the old system) + print(f"[generate_podcast] Submitted Celery task: {task.id}") + + # Return immediately with task_id for polling return { - "status": "success", - "podcast_id": podcast.id, + "status": "processing", + "task_id": task.id, "title": podcast_title, - "transcript": "\n\n".join(full_transcript), - "duration_ms": estimated_duration_ms, - "transcript_entries": len(podcast_transcript), + "message": "Podcast generation started. This may take a few minutes.", } except Exception as e: error_message = str(e) - print(f"[generate_podcast] Error: {error_message}") - # Rollback on error - await db_session.rollback() + print(f"[generate_podcast] Error submitting task: {error_message}") return { "status": "error", "error": error_message, "title": podcast_title, - "podcast_id": None, - "duration_ms": 0, - "transcript_entries": 0, + "task_id": None, } return generate_podcast - diff --git a/surfsense_backend/app/routes/podcasts_routes.py b/surfsense_backend/app/routes/podcasts_routes.py index deb9d9744..904de20a3 100644 --- a/surfsense_backend/app/routes/podcasts_routes.py +++ b/surfsense_backend/app/routes/podcasts_routes.py @@ -444,3 +444,66 @@ async def get_podcast_by_chat_id( raise HTTPException( status_code=500, detail=f"Error fetching podcast: {e!s}" ) from e + + +@router.get("/podcasts/task/{task_id}/status") +async def get_podcast_task_status( + task_id: str, + user: User = Depends(current_active_user), +): + """ + Get the status of a podcast generation task. + Used by new-chat frontend to poll for completion. + + Returns: + - status: "processing" | "success" | "error" + - podcast_id: (only if status == "success") + - title: (only if status == "success") + - error: (only if status == "error") + """ + try: + from celery.result import AsyncResult + + from app.celery_app import celery_app + + result = AsyncResult(task_id, app=celery_app) + + if result.ready(): + # Task completed + if result.successful(): + task_result = result.result + if isinstance(task_result, dict): + if task_result.get("status") == "success": + return { + "status": "success", + "podcast_id": task_result.get("podcast_id"), + "title": task_result.get("title"), + "transcript_entries": task_result.get("transcript_entries"), + } + else: + return { + "status": "error", + "error": task_result.get("error", "Unknown error"), + } + else: + return { + "status": "error", + "error": "Unexpected task result format", + } + else: + # Task failed + return { + "status": "error", + "error": str(result.result) if result.result else "Task failed", + } + else: + # Task still processing + return { + "status": "processing", + "state": result.state, + } + + except Exception as e: + raise HTTPException( + status_code=500, detail=f"Error checking task status: {e!s}" + ) from e diff --git a/surfsense_backend/app/tasks/celery_tasks/podcast_tasks.py b/surfsense_backend/app/tasks/celery_tasks/podcast_tasks.py index 65cdb886b..994f67be7 100644 --- a/surfsense_backend/app/tasks/celery_tasks/podcast_tasks.py +++ b/surfsense_backend/app/tasks/celery_tasks/podcast_tasks.py @@ -11,6 +11,11 @@ from app.celery_app import celery_app from app.config import config from app.tasks.podcast_tasks import generate_chat_podcast +# Import for content-based podcast (new-chat) +from app.agents.podcaster.graph import graph as podcaster_graph +from app.agents.podcaster.state import State as PodcasterState +from app.db import Podcast + logger = logging.getLogger(__name__) if sys.platform.startswith("win"): @@ -86,3 +91,131 @@ async def _generate_chat_podcast( except Exception as e: logger.error(f"Error generating podcast from chat: {e!s}") raise + + +# ============================================================================= +# Content-based podcast generation (for new-chat) +# ============================================================================= + + +@celery_app.task(name="generate_content_podcast", bind=True) +def generate_content_podcast_task( + self, + source_content: str, + search_space_id: int, + user_id: str, + podcast_title: str = "SurfSense Podcast", + user_prompt: str | None = None, +) -> dict: + """ + Celery task to generate podcast from source content (for new-chat). + + Unlike generate_chat_podcast which requires a chat_id, this task + generates a podcast directly from provided content. + + Args: + source_content: The text content to convert into a podcast + search_space_id: ID of the search space + user_id: ID of the user (as string) + podcast_title: Title for the podcast + user_prompt: Optional instructions for podcast style/tone + + Returns: + dict with podcast_id on success, or error info on failure + """ + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + try: + result = loop.run_until_complete( + _generate_content_podcast( + source_content, + search_space_id, + user_id, + podcast_title, + user_prompt, + ) + ) + loop.run_until_complete(loop.shutdown_asyncgens()) + return result + except Exception as e: + logger.error(f"Error generating content podcast: {e!s}") + return {"status": "error", "error": str(e)} + finally: + asyncio.set_event_loop(None) + loop.close() + + +async def _generate_content_podcast( + source_content: str, + search_space_id: int, + user_id: str, + podcast_title: str = "SurfSense Podcast", + user_prompt: str | None = None, +) -> dict: + """Generate content-based podcast with new session.""" + async with get_celery_session_maker()() as session: + try: + # Configure the podcaster graph + graph_config = { + "configurable": { + "podcast_title": podcast_title, + "user_id": str(user_id), + "search_space_id": search_space_id, + "user_prompt": user_prompt, + } + } + + # Initialize the podcaster state with the source content + initial_state = PodcasterState( + source_content=source_content, + db_session=session, + ) + + # Run the podcaster graph + result = await podcaster_graph.ainvoke(initial_state, config=graph_config) + + # Extract results + podcast_transcript = result.get("podcast_transcript", []) + file_path = result.get("final_podcast_file_path", "") + + # Convert transcript to serializable format + serializable_transcript = [] + for entry in podcast_transcript: + if hasattr(entry, "speaker_id"): + serializable_transcript.append({ + "speaker_id": entry.speaker_id, + "dialog": entry.dialog + }) + else: + serializable_transcript.append({ + "speaker_id": entry.get("speaker_id", 0), + "dialog": entry.get("dialog", "") + }) + + # Save podcast to database + podcast = Podcast( + title=podcast_title, + podcast_transcript=serializable_transcript, + file_location=file_path, + search_space_id=search_space_id, + chat_id=None, # No chat_id for new-chat podcasts + chat_state_version=None, + ) + session.add(podcast) + await session.commit() + await session.refresh(podcast) + + logger.info(f"Successfully generated content podcast: {podcast.id}") + + return { + "status": "success", + "podcast_id": podcast.id, + "title": podcast_title, + "transcript_entries": len(serializable_transcript), + } + + except Exception as e: + logger.error(f"Error in _generate_content_podcast: {e!s}") + await session.rollback() + raise diff --git a/surfsense_web/components/tool-ui/generate-podcast.tsx b/surfsense_web/components/tool-ui/generate-podcast.tsx index 0aa50aea2..fc40d85b9 100644 --- a/surfsense_web/components/tool-ui/generate-podcast.tsx +++ b/surfsense_web/components/tool-ui/generate-podcast.tsx @@ -4,6 +4,7 @@ import { makeAssistantToolUI } from "@assistant-ui/react"; import { AlertCircleIcon, Loader2Icon, MicIcon } from "lucide-react"; import { useCallback, useEffect, useRef, useState } from "react"; import { Audio } from "@/components/tool-ui/audio"; +import { baseApiService } from "@/lib/apis/base-api.service"; import { podcastsApiService } from "@/lib/apis/podcasts-api.service"; /** @@ -16,12 +17,21 @@ interface GeneratePodcastArgs { } interface GeneratePodcastResult { - status: "success" | "error"; + status: "processing" | "success" | "error"; + task_id?: string; podcast_id?: number; title?: string; - transcript?: string; - duration_ms?: number; transcript_entries?: number; + message?: string; + error?: string; +} + +interface TaskStatusResponse { + status: "processing" | "success" | "error"; + podcast_id?: number; + title?: string; + transcript_entries?: number; + state?: string; error?: string; } @@ -106,15 +116,11 @@ function PodcastPlayer({ title, description, durationMs, - transcript, - transcriptEntries, }: { podcastId: number; title: string; description: string; durationMs?: number; - transcript?: string; - transcriptEntries?: number; }) { const [audioSrc, setAudioSrc] = useState(null); const [isLoading, setIsLoading] = useState(true); @@ -194,29 +200,96 @@ function PodcastPlayer({ durationMs={durationMs} className="w-full" /> - {/* Full transcript */} - {transcript && ( -
- - View full transcript{transcriptEntries ? ` (${transcriptEntries} entries)` : ""} - -
-						{transcript}
-					
-
- )} ); } +/** + * Polling component that checks task status and shows player when complete + */ +function PodcastTaskPoller({ + taskId, + title, +}: { + taskId: string; + title: string; +}) { + const [taskStatus, setTaskStatus] = useState({ status: "processing" }); + const [pollCount, setPollCount] = useState(0); + const pollingRef = useRef(null); + + // Poll for task status + useEffect(() => { + const pollStatus = async () => { + try { + const response = await baseApiService.get( + `/api/v1/podcasts/task/${taskId}/status` + ); + setTaskStatus(response); + + // Stop polling if task is complete or errored + if (response.status !== "processing") { + if (pollingRef.current) { + clearInterval(pollingRef.current); + pollingRef.current = null; + } + } + } catch (err) { + console.error("Error polling task status:", err); + // Don't stop polling on network errors, just increment count + } + setPollCount((prev) => prev + 1); + }; + + // Initial poll + pollStatus(); + + // Poll every 5 seconds + pollingRef.current = setInterval(pollStatus, 5000); + + return () => { + if (pollingRef.current) { + clearInterval(pollingRef.current); + } + }; + }, [taskId]); + + // Show loading state while processing + if (taskStatus.status === "processing") { + return ; + } + + // Show error state + if (taskStatus.status === "error") { + return ; + } + + // Show player when complete + if (taskStatus.status === "success" && taskStatus.podcast_id) { + return ( + + ); + } + + // Fallback + return ; +} + /** * Generate Podcast Tool UI Component * * This component is registered with assistant-ui to render custom UI * when the generate_podcast tool is called by the agent. - * - * It fetches the podcast audio with authentication (like the old system) - * and displays it using the Audio component. + * + * It polls for task completion and auto-updates when the podcast is ready. */ export const GeneratePodcastToolUI = makeAssistantToolUI< GeneratePodcastArgs, @@ -226,7 +299,7 @@ export const GeneratePodcastToolUI = makeAssistantToolUI< render: function GeneratePodcastUI({ args, result, status }) { const title = args.podcast_title || "SurfSense Podcast"; - // Loading state - podcast is being generated + // Loading state - tool is still running (agent processing) if (status.type === "running" || status.type === "requires-action") { return ; } @@ -263,26 +336,27 @@ export const GeneratePodcastToolUI = makeAssistantToolUI< return ; } - // Success - need podcast_id to fetch with auth - if (!result.podcast_id) { - return ; + // Processing - poll for completion + if (result.status === "processing" && result.task_id) { + return ; } - // Render the podcast player (handles auth fetch internally) - return ( - - ); + // Success with podcast_id (direct result, not via polling) + if (result.status === "success" && result.podcast_id) { + return ( + + ); + } + + // Fallback - missing required data + return ; }, }); -