feat: implement background podcast generation with Celery and task polling in UI

This commit is contained in:
Anish Sarkar 2025-12-21 19:35:00 +05:30
parent 4c4e4b3c4c
commit e79e1187b2
4 changed files with 335 additions and 134 deletions

View file

@ -2,8 +2,8 @@
Podcast generation tool for the new chat agent.
This module provides a factory function for creating the generate_podcast tool
that integrates with the existing podcaster agent. Podcasts are saved to the
database like the old system, providing authentication and persistence.
that submits a Celery task for background podcast generation. The frontend
polls for completion and auto-updates when the podcast is ready.
"""
from typing import Any
@ -11,10 +11,6 @@ from typing import Any
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.podcaster.graph import graph as podcaster_graph
from app.agents.podcaster.state import State as PodcasterState
from app.db import Podcast
def create_generate_podcast_tool(
search_space_id: int,
@ -26,7 +22,7 @@ def create_generate_podcast_tool(
Args:
search_space_id: The user's search space ID
db_session: Database session
db_session: Database session (not used - Celery creates its own)
user_id: The user's ID (as string)
Returns:
@ -50,8 +46,8 @@ def create_generate_podcast_tool(
- "Make a podcast about..."
- "Turn this into a podcast"
The tool will generate a complete audio podcast with two speakers
discussing the provided content in an engaging conversational format.
The tool will start generating a podcast in the background.
The podcast will be available once generation completes.
Args:
source_content: The text content to convert into a podcast.
@ -63,108 +59,43 @@ def create_generate_podcast_tool(
Returns:
A dictionary containing:
- status: "success" or "error"
- podcast_id: The database ID of the saved podcast (for API access)
- status: "processing" (task submitted) or "error"
- task_id: The Celery task ID for polling status
- title: The podcast title
- transcript: Full podcast transcript with all dialogue entries
- duration_ms: Estimated podcast duration in milliseconds
- transcript_entries: Number of dialogue entries
"""
try:
# Configure the podcaster graph
config = {
"configurable": {
"podcast_title": podcast_title,
"user_id": str(user_id),
"search_space_id": search_space_id,
"user_prompt": user_prompt,
}
}
# Import Celery task here to avoid circular imports
from app.tasks.celery_tasks.podcast_tasks import (
generate_content_podcast_task,
)
# Initialize the podcaster state with the source content
initial_state = PodcasterState(
# Submit Celery task for background processing
task = generate_content_podcast_task.delay(
source_content=source_content,
db_session=db_session,
)
# Run the podcaster graph
result = await podcaster_graph.ainvoke(initial_state, config=config)
# Extract results
podcast_transcript = result.get("podcast_transcript", [])
file_path = result.get("final_podcast_file_path", "")
# Calculate estimated duration (rough estimate: ~150 words per minute)
total_words = sum(
len(entry.dialog.split()) if hasattr(entry, "dialog") else len(entry.get("dialog", "").split())
for entry in podcast_transcript
)
estimated_duration_ms = int((total_words / 150) * 60 * 1000)
# Create full transcript for display (all entries, complete dialog)
full_transcript = []
for entry in podcast_transcript:
if hasattr(entry, "speaker_id"):
speaker = f"Speaker {entry.speaker_id + 1}"
dialog = entry.dialog
else:
speaker = f"Speaker {entry.get('speaker_id', 0) + 1}"
dialog = entry.get("dialog", "")
full_transcript.append(f"{speaker}: {dialog}")
# Convert podcast transcript entries to serializable format (like old system)
serializable_transcript = []
for entry in podcast_transcript:
if hasattr(entry, "speaker_id"):
serializable_transcript.append({
"speaker_id": entry.speaker_id,
"dialog": entry.dialog
})
else:
serializable_transcript.append({
"speaker_id": entry.get("speaker_id", 0),
"dialog": entry.get("dialog", "")
})
# Save podcast to database (like old system)
# This provides authentication and persistence
podcast = Podcast(
title=podcast_title,
podcast_transcript=serializable_transcript,
file_location=file_path,
search_space_id=search_space_id,
# chat_id is None since new-chat uses LangGraph threads, not DB chats
chat_id=None,
chat_state_version=None,
user_id=str(user_id),
podcast_title=podcast_title,
user_prompt=user_prompt,
)
db_session.add(podcast)
await db_session.commit()
await db_session.refresh(podcast)
# Return podcast_id - frontend will use it to call the API endpoint
# GET /api/v1/podcasts/{podcast_id}/stream (like the old system)
print(f"[generate_podcast] Submitted Celery task: {task.id}")
# Return immediately with task_id for polling
return {
"status": "success",
"podcast_id": podcast.id,
"status": "processing",
"task_id": task.id,
"title": podcast_title,
"transcript": "\n\n".join(full_transcript),
"duration_ms": estimated_duration_ms,
"transcript_entries": len(podcast_transcript),
"message": "Podcast generation started. This may take a few minutes.",
}
except Exception as e:
error_message = str(e)
print(f"[generate_podcast] Error: {error_message}")
# Rollback on error
await db_session.rollback()
print(f"[generate_podcast] Error submitting task: {error_message}")
return {
"status": "error",
"error": error_message,
"title": podcast_title,
"podcast_id": None,
"duration_ms": 0,
"transcript_entries": 0,
"task_id": None,
}
return generate_podcast