feat: implement background podcast generation with Celery and task polling in UI

This commit is contained in:
Anish Sarkar 2025-12-21 19:35:00 +05:30
parent 4c4e4b3c4c
commit e79e1187b2
4 changed files with 335 additions and 134 deletions

View file

@ -2,8 +2,8 @@
Podcast generation tool for the new chat agent. Podcast generation tool for the new chat agent.
This module provides a factory function for creating the generate_podcast tool This module provides a factory function for creating the generate_podcast tool
that integrates with the existing podcaster agent. Podcasts are saved to the that submits a Celery task for background podcast generation. The frontend
database like the old system, providing authentication and persistence. polls for completion and auto-updates when the podcast is ready.
""" """
from typing import Any from typing import Any
@ -11,10 +11,6 @@ from typing import Any
from langchain_core.tools import tool from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.podcaster.graph import graph as podcaster_graph
from app.agents.podcaster.state import State as PodcasterState
from app.db import Podcast
def create_generate_podcast_tool( def create_generate_podcast_tool(
search_space_id: int, search_space_id: int,
@ -26,7 +22,7 @@ def create_generate_podcast_tool(
Args: Args:
search_space_id: The user's search space ID search_space_id: The user's search space ID
db_session: Database session db_session: Database session (not used - Celery creates its own)
user_id: The user's ID (as string) user_id: The user's ID (as string)
Returns: Returns:
@ -50,8 +46,8 @@ def create_generate_podcast_tool(
- "Make a podcast about..." - "Make a podcast about..."
- "Turn this into a podcast" - "Turn this into a podcast"
The tool will generate a complete audio podcast with two speakers The tool will start generating a podcast in the background.
discussing the provided content in an engaging conversational format. The podcast will be available once generation completes.
Args: Args:
source_content: The text content to convert into a podcast. source_content: The text content to convert into a podcast.
@ -63,108 +59,43 @@ def create_generate_podcast_tool(
Returns: Returns:
A dictionary containing: A dictionary containing:
- status: "success" or "error" - status: "processing" (task submitted) or "error"
- podcast_id: The database ID of the saved podcast (for API access) - task_id: The Celery task ID for polling status
- title: The podcast title - title: The podcast title
- transcript: Full podcast transcript with all dialogue entries
- duration_ms: Estimated podcast duration in milliseconds
- transcript_entries: Number of dialogue entries
""" """
try: try:
# Configure the podcaster graph # Import Celery task here to avoid circular imports
config = { from app.tasks.celery_tasks.podcast_tasks import (
"configurable": { generate_content_podcast_task,
"podcast_title": podcast_title, )
"user_id": str(user_id),
"search_space_id": search_space_id,
"user_prompt": user_prompt,
}
}
# Initialize the podcaster state with the source content # Submit Celery task for background processing
initial_state = PodcasterState( task = generate_content_podcast_task.delay(
source_content=source_content, source_content=source_content,
db_session=db_session,
)
# Run the podcaster graph
result = await podcaster_graph.ainvoke(initial_state, config=config)
# Extract results
podcast_transcript = result.get("podcast_transcript", [])
file_path = result.get("final_podcast_file_path", "")
# Calculate estimated duration (rough estimate: ~150 words per minute)
total_words = sum(
len(entry.dialog.split()) if hasattr(entry, "dialog") else len(entry.get("dialog", "").split())
for entry in podcast_transcript
)
estimated_duration_ms = int((total_words / 150) * 60 * 1000)
# Create full transcript for display (all entries, complete dialog)
full_transcript = []
for entry in podcast_transcript:
if hasattr(entry, "speaker_id"):
speaker = f"Speaker {entry.speaker_id + 1}"
dialog = entry.dialog
else:
speaker = f"Speaker {entry.get('speaker_id', 0) + 1}"
dialog = entry.get("dialog", "")
full_transcript.append(f"{speaker}: {dialog}")
# Convert podcast transcript entries to serializable format (like old system)
serializable_transcript = []
for entry in podcast_transcript:
if hasattr(entry, "speaker_id"):
serializable_transcript.append({
"speaker_id": entry.speaker_id,
"dialog": entry.dialog
})
else:
serializable_transcript.append({
"speaker_id": entry.get("speaker_id", 0),
"dialog": entry.get("dialog", "")
})
# Save podcast to database (like old system)
# This provides authentication and persistence
podcast = Podcast(
title=podcast_title,
podcast_transcript=serializable_transcript,
file_location=file_path,
search_space_id=search_space_id, search_space_id=search_space_id,
# chat_id is None since new-chat uses LangGraph threads, not DB chats user_id=str(user_id),
chat_id=None, podcast_title=podcast_title,
chat_state_version=None, user_prompt=user_prompt,
) )
db_session.add(podcast)
await db_session.commit()
await db_session.refresh(podcast)
# Return podcast_id - frontend will use it to call the API endpoint print(f"[generate_podcast] Submitted Celery task: {task.id}")
# GET /api/v1/podcasts/{podcast_id}/stream (like the old system)
# Return immediately with task_id for polling
return { return {
"status": "success", "status": "processing",
"podcast_id": podcast.id, "task_id": task.id,
"title": podcast_title, "title": podcast_title,
"transcript": "\n\n".join(full_transcript), "message": "Podcast generation started. This may take a few minutes.",
"duration_ms": estimated_duration_ms,
"transcript_entries": len(podcast_transcript),
} }
except Exception as e: except Exception as e:
error_message = str(e) error_message = str(e)
print(f"[generate_podcast] Error: {error_message}") print(f"[generate_podcast] Error submitting task: {error_message}")
# Rollback on error
await db_session.rollback()
return { return {
"status": "error", "status": "error",
"error": error_message, "error": error_message,
"title": podcast_title, "title": podcast_title,
"podcast_id": None, "task_id": None,
"duration_ms": 0,
"transcript_entries": 0,
} }
return generate_podcast return generate_podcast

View file

@ -444,3 +444,66 @@ async def get_podcast_by_chat_id(
raise HTTPException( raise HTTPException(
status_code=500, detail=f"Error fetching podcast: {e!s}" status_code=500, detail=f"Error fetching podcast: {e!s}"
) from e ) from e
@router.get("/podcasts/task/{task_id}/status")
async def get_podcast_task_status(
task_id: str,
user: User = Depends(current_active_user),
):
"""
Get the status of a podcast generation task.
Used by new-chat frontend to poll for completion.
Returns:
- status: "processing" | "success" | "error"
- podcast_id: (only if status == "success")
- title: (only if status == "success")
- error: (only if status == "error")
"""
try:
from celery.result import AsyncResult
from app.celery_app import celery_app
result = AsyncResult(task_id, app=celery_app)
if result.ready():
# Task completed
if result.successful():
task_result = result.result
if isinstance(task_result, dict):
if task_result.get("status") == "success":
return {
"status": "success",
"podcast_id": task_result.get("podcast_id"),
"title": task_result.get("title"),
"transcript_entries": task_result.get("transcript_entries"),
}
else:
return {
"status": "error",
"error": task_result.get("error", "Unknown error"),
}
else:
return {
"status": "error",
"error": "Unexpected task result format",
}
else:
# Task failed
return {
"status": "error",
"error": str(result.result) if result.result else "Task failed",
}
else:
# Task still processing
return {
"status": "processing",
"state": result.state,
}
except Exception as e:
raise HTTPException(
status_code=500, detail=f"Error checking task status: {e!s}"
) from e

View file

@ -11,6 +11,11 @@ from app.celery_app import celery_app
from app.config import config from app.config import config
from app.tasks.podcast_tasks import generate_chat_podcast from app.tasks.podcast_tasks import generate_chat_podcast
# Import for content-based podcast (new-chat)
from app.agents.podcaster.graph import graph as podcaster_graph
from app.agents.podcaster.state import State as PodcasterState
from app.db import Podcast
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
if sys.platform.startswith("win"): if sys.platform.startswith("win"):
@ -86,3 +91,131 @@ async def _generate_chat_podcast(
except Exception as e: except Exception as e:
logger.error(f"Error generating podcast from chat: {e!s}") logger.error(f"Error generating podcast from chat: {e!s}")
raise raise
# =============================================================================
# Content-based podcast generation (for new-chat)
# =============================================================================
@celery_app.task(name="generate_content_podcast", bind=True)
def generate_content_podcast_task(
self,
source_content: str,
search_space_id: int,
user_id: str,
podcast_title: str = "SurfSense Podcast",
user_prompt: str | None = None,
) -> dict:
"""
Celery task to generate podcast from source content (for new-chat).
Unlike generate_chat_podcast which requires a chat_id, this task
generates a podcast directly from provided content.
Args:
source_content: The text content to convert into a podcast
search_space_id: ID of the search space
user_id: ID of the user (as string)
podcast_title: Title for the podcast
user_prompt: Optional instructions for podcast style/tone
Returns:
dict with podcast_id on success, or error info on failure
"""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
result = loop.run_until_complete(
_generate_content_podcast(
source_content,
search_space_id,
user_id,
podcast_title,
user_prompt,
)
)
loop.run_until_complete(loop.shutdown_asyncgens())
return result
except Exception as e:
logger.error(f"Error generating content podcast: {e!s}")
return {"status": "error", "error": str(e)}
finally:
asyncio.set_event_loop(None)
loop.close()
async def _generate_content_podcast(
source_content: str,
search_space_id: int,
user_id: str,
podcast_title: str = "SurfSense Podcast",
user_prompt: str | None = None,
) -> dict:
"""Generate content-based podcast with new session."""
async with get_celery_session_maker()() as session:
try:
# Configure the podcaster graph
graph_config = {
"configurable": {
"podcast_title": podcast_title,
"user_id": str(user_id),
"search_space_id": search_space_id,
"user_prompt": user_prompt,
}
}
# Initialize the podcaster state with the source content
initial_state = PodcasterState(
source_content=source_content,
db_session=session,
)
# Run the podcaster graph
result = await podcaster_graph.ainvoke(initial_state, config=graph_config)
# Extract results
podcast_transcript = result.get("podcast_transcript", [])
file_path = result.get("final_podcast_file_path", "")
# Convert transcript to serializable format
serializable_transcript = []
for entry in podcast_transcript:
if hasattr(entry, "speaker_id"):
serializable_transcript.append({
"speaker_id": entry.speaker_id,
"dialog": entry.dialog
})
else:
serializable_transcript.append({
"speaker_id": entry.get("speaker_id", 0),
"dialog": entry.get("dialog", "")
})
# Save podcast to database
podcast = Podcast(
title=podcast_title,
podcast_transcript=serializable_transcript,
file_location=file_path,
search_space_id=search_space_id,
chat_id=None, # No chat_id for new-chat podcasts
chat_state_version=None,
)
session.add(podcast)
await session.commit()
await session.refresh(podcast)
logger.info(f"Successfully generated content podcast: {podcast.id}")
return {
"status": "success",
"podcast_id": podcast.id,
"title": podcast_title,
"transcript_entries": len(serializable_transcript),
}
except Exception as e:
logger.error(f"Error in _generate_content_podcast: {e!s}")
await session.rollback()
raise

View file

@ -4,6 +4,7 @@ import { makeAssistantToolUI } from "@assistant-ui/react";
import { AlertCircleIcon, Loader2Icon, MicIcon } from "lucide-react"; import { AlertCircleIcon, Loader2Icon, MicIcon } from "lucide-react";
import { useCallback, useEffect, useRef, useState } from "react"; import { useCallback, useEffect, useRef, useState } from "react";
import { Audio } from "@/components/tool-ui/audio"; import { Audio } from "@/components/tool-ui/audio";
import { baseApiService } from "@/lib/apis/base-api.service";
import { podcastsApiService } from "@/lib/apis/podcasts-api.service"; import { podcastsApiService } from "@/lib/apis/podcasts-api.service";
/** /**
@ -16,12 +17,21 @@ interface GeneratePodcastArgs {
} }
interface GeneratePodcastResult { interface GeneratePodcastResult {
status: "success" | "error"; status: "processing" | "success" | "error";
task_id?: string;
podcast_id?: number; podcast_id?: number;
title?: string; title?: string;
transcript?: string;
duration_ms?: number;
transcript_entries?: number; transcript_entries?: number;
message?: string;
error?: string;
}
interface TaskStatusResponse {
status: "processing" | "success" | "error";
podcast_id?: number;
title?: string;
transcript_entries?: number;
state?: string;
error?: string; error?: string;
} }
@ -106,15 +116,11 @@ function PodcastPlayer({
title, title,
description, description,
durationMs, durationMs,
transcript,
transcriptEntries,
}: { }: {
podcastId: number; podcastId: number;
title: string; title: string;
description: string; description: string;
durationMs?: number; durationMs?: number;
transcript?: string;
transcriptEntries?: number;
}) { }) {
const [audioSrc, setAudioSrc] = useState<string | null>(null); const [audioSrc, setAudioSrc] = useState<string | null>(null);
const [isLoading, setIsLoading] = useState(true); const [isLoading, setIsLoading] = useState(true);
@ -194,29 +200,96 @@ function PodcastPlayer({
durationMs={durationMs} durationMs={durationMs}
className="w-full" className="w-full"
/> />
{/* Full transcript */}
{transcript && (
<details className="mt-3 rounded-lg border bg-muted/30 p-3">
<summary className="cursor-pointer font-medium text-muted-foreground text-sm hover:text-foreground">
View full transcript{transcriptEntries ? ` (${transcriptEntries} entries)` : ""}
</summary>
<pre className="mt-2 max-h-96 overflow-y-auto whitespace-pre-wrap text-muted-foreground text-xs">
{transcript}
</pre>
</details>
)}
</div> </div>
); );
} }
/**
* Polling component that checks task status and shows player when complete
*/
function PodcastTaskPoller({
taskId,
title,
}: {
taskId: string;
title: string;
}) {
const [taskStatus, setTaskStatus] = useState<TaskStatusResponse>({ status: "processing" });
const [pollCount, setPollCount] = useState(0);
const pollingRef = useRef<NodeJS.Timeout | null>(null);
// Poll for task status
useEffect(() => {
const pollStatus = async () => {
try {
const response = await baseApiService.get<TaskStatusResponse>(
`/api/v1/podcasts/task/${taskId}/status`
);
setTaskStatus(response);
// Stop polling if task is complete or errored
if (response.status !== "processing") {
if (pollingRef.current) {
clearInterval(pollingRef.current);
pollingRef.current = null;
}
}
} catch (err) {
console.error("Error polling task status:", err);
// Don't stop polling on network errors, just increment count
}
setPollCount((prev) => prev + 1);
};
// Initial poll
pollStatus();
// Poll every 5 seconds
pollingRef.current = setInterval(pollStatus, 5000);
return () => {
if (pollingRef.current) {
clearInterval(pollingRef.current);
}
};
}, [taskId]);
// Show loading state while processing
if (taskStatus.status === "processing") {
return <PodcastGeneratingState title={title} />;
}
// Show error state
if (taskStatus.status === "error") {
return <PodcastErrorState title={title} error={taskStatus.error || "Generation failed"} />;
}
// Show player when complete
if (taskStatus.status === "success" && taskStatus.podcast_id) {
return (
<PodcastPlayer
podcastId={taskStatus.podcast_id}
title={taskStatus.title || title}
description={
taskStatus.transcript_entries
? `${taskStatus.transcript_entries} dialogue entries`
: "SurfSense AI-generated podcast"
}
/>
);
}
// Fallback
return <PodcastErrorState title={title} error="Unexpected state" />;
}
/** /**
* Generate Podcast Tool UI Component * Generate Podcast Tool UI Component
* *
* This component is registered with assistant-ui to render custom UI * This component is registered with assistant-ui to render custom UI
* when the generate_podcast tool is called by the agent. * when the generate_podcast tool is called by the agent.
* *
* It fetches the podcast audio with authentication (like the old system) * It polls for task completion and auto-updates when the podcast is ready.
* and displays it using the Audio component.
*/ */
export const GeneratePodcastToolUI = makeAssistantToolUI< export const GeneratePodcastToolUI = makeAssistantToolUI<
GeneratePodcastArgs, GeneratePodcastArgs,
@ -226,7 +299,7 @@ export const GeneratePodcastToolUI = makeAssistantToolUI<
render: function GeneratePodcastUI({ args, result, status }) { render: function GeneratePodcastUI({ args, result, status }) {
const title = args.podcast_title || "SurfSense Podcast"; const title = args.podcast_title || "SurfSense Podcast";
// Loading state - podcast is being generated // Loading state - tool is still running (agent processing)
if (status.type === "running" || status.type === "requires-action") { if (status.type === "running" || status.type === "requires-action") {
return <PodcastGeneratingState title={title} />; return <PodcastGeneratingState title={title} />;
} }
@ -263,26 +336,27 @@ export const GeneratePodcastToolUI = makeAssistantToolUI<
return <PodcastErrorState title={title} error={result.error || "Unknown error"} />; return <PodcastErrorState title={title} error={result.error || "Unknown error"} />;
} }
// Success - need podcast_id to fetch with auth // Processing - poll for completion
if (!result.podcast_id) { if (result.status === "processing" && result.task_id) {
return <PodcastErrorState title={title} error="Missing podcast ID" />; return <PodcastTaskPoller taskId={result.task_id} title={result.title || title} />;
} }
// Render the podcast player (handles auth fetch internally) // Success with podcast_id (direct result, not via polling)
return ( if (result.status === "success" && result.podcast_id) {
<PodcastPlayer return (
podcastId={result.podcast_id} <PodcastPlayer
title={result.title || title} podcastId={result.podcast_id}
description={ title={result.title || title}
result.transcript_entries description={
? `${result.transcript_entries} dialogue entries` result.transcript_entries
: "SurfSense AI-generated podcast" ? `${result.transcript_entries} dialogue entries`
} : "SurfSense AI-generated podcast"
durationMs={result.duration_ms} }
transcript={result.transcript} />
transcriptEntries={result.transcript_entries} );
/> }
);
// Fallback - missing required data
return <PodcastErrorState title={title} error="Missing task ID or podcast ID" />;
}, },
}); });