mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-11 00:32:38 +02:00
feat: implement background podcast generation with Celery and task polling in UI
This commit is contained in:
parent
4c4e4b3c4c
commit
e79e1187b2
4 changed files with 335 additions and 134 deletions
|
|
@ -2,8 +2,8 @@
|
|||
Podcast generation tool for the new chat agent.
|
||||
|
||||
This module provides a factory function for creating the generate_podcast tool
|
||||
that integrates with the existing podcaster agent. Podcasts are saved to the
|
||||
database like the old system, providing authentication and persistence.
|
||||
that submits a Celery task for background podcast generation. The frontend
|
||||
polls for completion and auto-updates when the podcast is ready.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
|
@ -11,10 +11,6 @@ from typing import Any
|
|||
from langchain_core.tools import tool
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.podcaster.graph import graph as podcaster_graph
|
||||
from app.agents.podcaster.state import State as PodcasterState
|
||||
from app.db import Podcast
|
||||
|
||||
|
||||
def create_generate_podcast_tool(
|
||||
search_space_id: int,
|
||||
|
|
@ -26,7 +22,7 @@ def create_generate_podcast_tool(
|
|||
|
||||
Args:
|
||||
search_space_id: The user's search space ID
|
||||
db_session: Database session
|
||||
db_session: Database session (not used - Celery creates its own)
|
||||
user_id: The user's ID (as string)
|
||||
|
||||
Returns:
|
||||
|
|
@ -50,8 +46,8 @@ def create_generate_podcast_tool(
|
|||
- "Make a podcast about..."
|
||||
- "Turn this into a podcast"
|
||||
|
||||
The tool will generate a complete audio podcast with two speakers
|
||||
discussing the provided content in an engaging conversational format.
|
||||
The tool will start generating a podcast in the background.
|
||||
The podcast will be available once generation completes.
|
||||
|
||||
Args:
|
||||
source_content: The text content to convert into a podcast.
|
||||
|
|
@ -63,108 +59,43 @@ def create_generate_podcast_tool(
|
|||
|
||||
Returns:
|
||||
A dictionary containing:
|
||||
- status: "success" or "error"
|
||||
- podcast_id: The database ID of the saved podcast (for API access)
|
||||
- status: "processing" (task submitted) or "error"
|
||||
- task_id: The Celery task ID for polling status
|
||||
- title: The podcast title
|
||||
- transcript: Full podcast transcript with all dialogue entries
|
||||
- duration_ms: Estimated podcast duration in milliseconds
|
||||
- transcript_entries: Number of dialogue entries
|
||||
"""
|
||||
try:
|
||||
# Configure the podcaster graph
|
||||
config = {
|
||||
"configurable": {
|
||||
"podcast_title": podcast_title,
|
||||
"user_id": str(user_id),
|
||||
"search_space_id": search_space_id,
|
||||
"user_prompt": user_prompt,
|
||||
}
|
||||
}
|
||||
# Import Celery task here to avoid circular imports
|
||||
from app.tasks.celery_tasks.podcast_tasks import (
|
||||
generate_content_podcast_task,
|
||||
)
|
||||
|
||||
# Initialize the podcaster state with the source content
|
||||
initial_state = PodcasterState(
|
||||
# Submit Celery task for background processing
|
||||
task = generate_content_podcast_task.delay(
|
||||
source_content=source_content,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# Run the podcaster graph
|
||||
result = await podcaster_graph.ainvoke(initial_state, config=config)
|
||||
|
||||
# Extract results
|
||||
podcast_transcript = result.get("podcast_transcript", [])
|
||||
file_path = result.get("final_podcast_file_path", "")
|
||||
|
||||
# Calculate estimated duration (rough estimate: ~150 words per minute)
|
||||
total_words = sum(
|
||||
len(entry.dialog.split()) if hasattr(entry, "dialog") else len(entry.get("dialog", "").split())
|
||||
for entry in podcast_transcript
|
||||
)
|
||||
estimated_duration_ms = int((total_words / 150) * 60 * 1000)
|
||||
|
||||
# Create full transcript for display (all entries, complete dialog)
|
||||
full_transcript = []
|
||||
for entry in podcast_transcript:
|
||||
if hasattr(entry, "speaker_id"):
|
||||
speaker = f"Speaker {entry.speaker_id + 1}"
|
||||
dialog = entry.dialog
|
||||
else:
|
||||
speaker = f"Speaker {entry.get('speaker_id', 0) + 1}"
|
||||
dialog = entry.get("dialog", "")
|
||||
full_transcript.append(f"{speaker}: {dialog}")
|
||||
|
||||
# Convert podcast transcript entries to serializable format (like old system)
|
||||
serializable_transcript = []
|
||||
for entry in podcast_transcript:
|
||||
if hasattr(entry, "speaker_id"):
|
||||
serializable_transcript.append({
|
||||
"speaker_id": entry.speaker_id,
|
||||
"dialog": entry.dialog
|
||||
})
|
||||
else:
|
||||
serializable_transcript.append({
|
||||
"speaker_id": entry.get("speaker_id", 0),
|
||||
"dialog": entry.get("dialog", "")
|
||||
})
|
||||
|
||||
# Save podcast to database (like old system)
|
||||
# This provides authentication and persistence
|
||||
podcast = Podcast(
|
||||
title=podcast_title,
|
||||
podcast_transcript=serializable_transcript,
|
||||
file_location=file_path,
|
||||
search_space_id=search_space_id,
|
||||
# chat_id is None since new-chat uses LangGraph threads, not DB chats
|
||||
chat_id=None,
|
||||
chat_state_version=None,
|
||||
user_id=str(user_id),
|
||||
podcast_title=podcast_title,
|
||||
user_prompt=user_prompt,
|
||||
)
|
||||
db_session.add(podcast)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(podcast)
|
||||
|
||||
# Return podcast_id - frontend will use it to call the API endpoint
|
||||
# GET /api/v1/podcasts/{podcast_id}/stream (like the old system)
|
||||
print(f"[generate_podcast] Submitted Celery task: {task.id}")
|
||||
|
||||
# Return immediately with task_id for polling
|
||||
return {
|
||||
"status": "success",
|
||||
"podcast_id": podcast.id,
|
||||
"status": "processing",
|
||||
"task_id": task.id,
|
||||
"title": podcast_title,
|
||||
"transcript": "\n\n".join(full_transcript),
|
||||
"duration_ms": estimated_duration_ms,
|
||||
"transcript_entries": len(podcast_transcript),
|
||||
"message": "Podcast generation started. This may take a few minutes.",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
error_message = str(e)
|
||||
print(f"[generate_podcast] Error: {error_message}")
|
||||
# Rollback on error
|
||||
await db_session.rollback()
|
||||
print(f"[generate_podcast] Error submitting task: {error_message}")
|
||||
return {
|
||||
"status": "error",
|
||||
"error": error_message,
|
||||
"title": podcast_title,
|
||||
"podcast_id": None,
|
||||
"duration_ms": 0,
|
||||
"transcript_entries": 0,
|
||||
"task_id": None,
|
||||
}
|
||||
|
||||
return generate_podcast
|
||||
|
||||
|
|
|
|||
|
|
@ -444,3 +444,66 @@ async def get_podcast_by_chat_id(
|
|||
raise HTTPException(
|
||||
status_code=500, detail=f"Error fetching podcast: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/podcasts/task/{task_id}/status")
async def get_podcast_task_status(
    task_id: str,
    user: User = Depends(current_active_user),
):
    """
    Report the state of a background podcast generation task.

    The new-chat frontend polls this endpoint until the task finishes.

    Args:
        task_id: ID of the Celery task to look up.
        user: Resolved through the current_active_user dependency; not
            otherwise read by this handler.

    Returns:
        A dict whose "status" is one of:
        - "processing": the task is not ready yet; "state" holds the raw
          Celery state.
        - "success": the task returned a success payload; "podcast_id",
          "title" and "transcript_entries" are copied from that payload.
        - "error": the task failed, or returned something other than a
          success payload; "error" holds the message.

    Raises:
        HTTPException: 500 when the status lookup itself raises.
    """
    try:
        # Imported inside the handler rather than at module level.
        from celery.result import AsyncResult

        from app.celery_app import celery_app

        result = AsyncResult(task_id, app=celery_app)

        # Not finished yet: report the raw state so the client keeps polling.
        if not result.ready():
            return {
                "status": "processing",
                "state": result.state,
            }

        # Finished, but the task itself raised.
        if not result.successful():
            if result.result:
                failure_message = str(result.result)
            else:
                failure_message = "Task failed"
            return {
                "status": "error",
                "error": failure_message,
            }

        # Finished normally: the task reports its own outcome in a dict.
        payload = result.result
        if not isinstance(payload, dict):
            return {
                "status": "error",
                "error": "Unexpected task result format",
            }

        if payload.get("status") != "success":
            return {
                "status": "error",
                "error": payload.get("error", "Unknown error"),
            }

        return {
            "status": "success",
            "podcast_id": payload.get("podcast_id"),
            "title": payload.get("title"),
            "transcript_entries": payload.get("transcript_entries"),
        }

    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Error checking task status: {e!s}"
        ) from e
|
||||
|
|
|
|||
|
|
@ -11,6 +11,11 @@ from app.celery_app import celery_app
|
|||
from app.config import config
|
||||
from app.tasks.podcast_tasks import generate_chat_podcast
|
||||
|
||||
# Import for content-based podcast (new-chat)
|
||||
from app.agents.podcaster.graph import graph as podcaster_graph
|
||||
from app.agents.podcaster.state import State as PodcasterState
|
||||
from app.db import Podcast
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if sys.platform.startswith("win"):
|
||||
|
|
@ -86,3 +91,131 @@ async def _generate_chat_podcast(
|
|||
except Exception as e:
|
||||
logger.error(f"Error generating podcast from chat: {e!s}")
|
||||
raise
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Content-based podcast generation (for new-chat)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@celery_app.task(name="generate_content_podcast", bind=True)
def generate_content_podcast_task(
    self,
    source_content: str,
    search_space_id: int,
    user_id: str,
    podcast_title: str = "SurfSense Podcast",
    user_prompt: str | None = None,
) -> dict:
    """
    Celery task to generate a podcast from source content (for new-chat).

    Unlike generate_chat_podcast, which requires a chat_id, this task
    generates a podcast directly from the provided content. The async
    implementation is driven on a dedicated event loop created for this call.

    Args:
        source_content: The text content to convert into a podcast
        search_space_id: ID of the search space
        user_id: ID of the user (as string)
        podcast_title: Title for the podcast
        user_prompt: Optional instructions for podcast style/tone

    Returns:
        The dict returned by _generate_content_podcast on success, or
        {"status": "error", "error": <message>} when generation raises.
        Generation errors are returned, not raised.
    """
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)

    try:
        return loop.run_until_complete(
            _generate_content_podcast(
                source_content,
                search_space_id,
                user_id,
                podcast_title,
                user_prompt,
            )
        )
    except Exception as e:
        logger.error(f"Error generating content podcast: {e!s}")
        return {"status": "error", "error": str(e)}
    finally:
        # Finalize async generators on every exit path, not only on success,
        # and never let a cleanup failure replace the task's result.
        try:
            loop.run_until_complete(loop.shutdown_asyncgens())
        except Exception as cleanup_error:
            logger.error(
                f"Error shutting down async generators: {cleanup_error!s}"
            )
        finally:
            asyncio.set_event_loop(None)
            loop.close()
|
||||
|
||||
|
||||
async def _generate_content_podcast(
    source_content: str,
    search_space_id: int,
    user_id: str,
    podcast_title: str = "SurfSense Podcast",
    user_prompt: str | None = None,
) -> dict:
    """
    Run the podcaster graph on the given content and save the result.

    A session is opened from get_celery_session_maker() for this call. The
    graph output's transcript entries are converted to plain dicts, a Podcast
    row is added and committed, and a summary of that row is returned.

    Args:
        source_content: Text handed to the podcaster graph as its input.
        search_space_id: Search space the podcast is stored under.
        user_id: User ID; passed to the graph config as a string.
        podcast_title: Title used for the graph config and the saved row.
        user_prompt: Optional prompt passed to the graph config.

    Returns:
        dict with "status" set to "success", plus "podcast_id", "title" and
        "transcript_entries" (number of transcript entries saved).

    Raises:
        Exception: Any error raised while running the graph or writing to
            the database is logged, the session is rolled back, and the
            error is re-raised.
    """
    async with get_celery_session_maker()() as session:
        try:
            run_settings = {
                "configurable": {
                    "podcast_title": podcast_title,
                    "user_id": str(user_id),
                    "search_space_id": search_space_id,
                    "user_prompt": user_prompt,
                }
            }

            # The graph receives the content plus this call's session.
            seed_state = PodcasterState(
                source_content=source_content,
                db_session=session,
            )

            graph_output = await podcaster_graph.ainvoke(
                seed_state, config=run_settings
            )

            raw_entries = graph_output.get("podcast_transcript", [])
            audio_path = graph_output.get("final_podcast_file_path", "")

            # Entries may be objects with attributes or plain mappings;
            # either way they are flattened to dicts before being stored.
            transcript_rows = [
                {"speaker_id": entry.speaker_id, "dialog": entry.dialog}
                if hasattr(entry, "speaker_id")
                else {
                    "speaker_id": entry.get("speaker_id", 0),
                    "dialog": entry.get("dialog", ""),
                }
                for entry in raw_entries
            ]

            podcast = Podcast(
                title=podcast_title,
                podcast_transcript=transcript_rows,
                file_location=audio_path,
                search_space_id=search_space_id,
                chat_id=None,  # No chat_id for new-chat podcasts
                chat_state_version=None,
            )
            session.add(podcast)
            await session.commit()
            await session.refresh(podcast)

            logger.info(f"Successfully generated content podcast: {podcast.id}")

            return {
                "status": "success",
                "podcast_id": podcast.id,
                "title": podcast_title,
                "transcript_entries": len(transcript_rows),
            }

        except Exception as e:
            logger.error(f"Error in _generate_content_podcast: {e!s}")
            await session.rollback()
            raise
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ import { makeAssistantToolUI } from "@assistant-ui/react";
|
|||
import { AlertCircleIcon, Loader2Icon, MicIcon } from "lucide-react";
|
||||
import { useCallback, useEffect, useRef, useState } from "react";
|
||||
import { Audio } from "@/components/tool-ui/audio";
|
||||
import { baseApiService } from "@/lib/apis/base-api.service";
|
||||
import { podcastsApiService } from "@/lib/apis/podcasts-api.service";
|
||||
|
||||
/**
|
||||
|
|
@ -16,12 +17,21 @@ interface GeneratePodcastArgs {
|
|||
}
|
||||
|
||||
/**
 * Result payload of the generate_podcast tool call.
 *
 * `status` is declared once, as the union of every value the tool and the
 * UI use: "processing" while the background task runs, then "success" or
 * "error".
 */
interface GeneratePodcastResult {
	status: "processing" | "success" | "error";
	/** Celery task ID used to poll for completion. */
	task_id?: string;
	/** ID of the saved podcast, when one is available. */
	podcast_id?: number;
	title?: string;
	transcript?: string;
	duration_ms?: number;
	transcript_entries?: number;
	/** Human-readable progress message sent with the tool result. */
	message?: string;
	/** Error message when `status` is "error". */
	error?: string;
}
|
||||
|
||||
/**
 * Response of GET /api/v1/podcasts/task/{task_id}/status.
 *
 * Which optional fields are present depends on `status`.
 */
interface TaskStatusResponse {
	status: "processing" | "success" | "error";

	// Sent while status is "processing": the raw task state.
	state?: string;

	// Sent when status is "success".
	podcast_id?: number;
	title?: string;
	transcript_entries?: number;

	// Sent when status is "error".
	error?: string;
}
|
||||
|
||||
|
|
@ -106,15 +116,11 @@ function PodcastPlayer({
|
|||
title,
|
||||
description,
|
||||
durationMs,
|
||||
transcript,
|
||||
transcriptEntries,
|
||||
}: {
|
||||
podcastId: number;
|
||||
title: string;
|
||||
description: string;
|
||||
durationMs?: number;
|
||||
transcript?: string;
|
||||
transcriptEntries?: number;
|
||||
}) {
|
||||
const [audioSrc, setAudioSrc] = useState<string | null>(null);
|
||||
const [isLoading, setIsLoading] = useState(true);
|
||||
|
|
@ -194,29 +200,96 @@ function PodcastPlayer({
|
|||
durationMs={durationMs}
|
||||
className="w-full"
|
||||
/>
|
||||
{/* Full transcript */}
|
||||
{transcript && (
|
||||
<details className="mt-3 rounded-lg border bg-muted/30 p-3">
|
||||
<summary className="cursor-pointer font-medium text-muted-foreground text-sm hover:text-foreground">
|
||||
View full transcript{transcriptEntries ? ` (${transcriptEntries} entries)` : ""}
|
||||
</summary>
|
||||
<pre className="mt-2 max-h-96 overflow-y-auto whitespace-pre-wrap text-muted-foreground text-xs">
|
||||
{transcript}
|
||||
</pre>
|
||||
</details>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
/**
 * Polls the podcast task status endpoint and renders the matching UI:
 * the generating state while the task is processing, the error state when
 * it fails, and the player once a podcast ID is available.
 *
 * Polling starts immediately, repeats every 5 seconds, and stops as soon as
 * the reported status is no longer "processing". Request failures are logged
 * and polling continues.
 */
function PodcastTaskPoller({
	taskId,
	title,
}: {
	taskId: string;
	title: string;
}) {
	const [taskStatus, setTaskStatus] = useState<TaskStatusResponse>({ status: "processing" });
	const pollingRef = useRef<NodeJS.Timeout | null>(null);

	// Poll for task status
	useEffect(() => {
		// Set by the cleanup below. A response that resolves after unmount, or
		// after taskId changed, must not overwrite state or stop the interval
		// that belongs to the newer effect.
		let cancelled = false;

		const stopPolling = () => {
			if (pollingRef.current) {
				clearInterval(pollingRef.current);
				pollingRef.current = null;
			}
		};

		const pollStatus = async () => {
			try {
				const response = await baseApiService.get<TaskStatusResponse>(
					`/api/v1/podcasts/task/${taskId}/status`
				);
				if (cancelled) {
					return;
				}
				setTaskStatus(response);

				// Stop polling if task is complete or errored
				if (response.status !== "processing") {
					stopPolling();
				}
			} catch (err) {
				// Don't stop polling on request errors; the next tick retries.
				console.error("Error polling task status:", err);
			}
		};

		// Initial poll
		pollStatus();

		// Poll every 5 seconds
		pollingRef.current = setInterval(pollStatus, 5000);

		return () => {
			cancelled = true;
			stopPolling();
		};
	}, [taskId]);

	// Show loading state while processing
	if (taskStatus.status === "processing") {
		return <PodcastGeneratingState title={title} />;
	}

	// Show error state
	if (taskStatus.status === "error") {
		return <PodcastErrorState title={title} error={taskStatus.error || "Generation failed"} />;
	}

	// Show player when complete
	if (taskStatus.status === "success" && taskStatus.podcast_id) {
		return (
			<PodcastPlayer
				podcastId={taskStatus.podcast_id}
				title={taskStatus.title || title}
				description={
					taskStatus.transcript_entries
						? `${taskStatus.transcript_entries} dialogue entries`
						: "SurfSense AI-generated podcast"
				}
			/>
		);
	}

	// Fallback: "success" was reported without a podcast ID
	return <PodcastErrorState title={title} error="Unexpected state" />;
}
|
||||
|
||||
/**
|
||||
* Generate Podcast Tool UI Component
|
||||
*
|
||||
* This component is registered with assistant-ui to render custom UI
|
||||
* when the generate_podcast tool is called by the agent.
|
||||
*
|
||||
* It fetches the podcast audio with authentication (like the old system)
|
||||
* and displays it using the Audio component.
|
||||
* It polls for task completion and auto-updates when the podcast is ready.
|
||||
*/
|
||||
export const GeneratePodcastToolUI = makeAssistantToolUI<
|
||||
GeneratePodcastArgs,
|
||||
|
|
@ -226,7 +299,7 @@ export const GeneratePodcastToolUI = makeAssistantToolUI<
|
|||
render: function GeneratePodcastUI({ args, result, status }) {
|
||||
const title = args.podcast_title || "SurfSense Podcast";
|
||||
|
||||
// Loading state - podcast is being generated
|
||||
// Loading state - tool is still running (agent processing)
|
||||
if (status.type === "running" || status.type === "requires-action") {
|
||||
return <PodcastGeneratingState title={title} />;
|
||||
}
|
||||
|
|
@ -263,26 +336,27 @@ export const GeneratePodcastToolUI = makeAssistantToolUI<
|
|||
return <PodcastErrorState title={title} error={result.error || "Unknown error"} />;
|
||||
}
|
||||
|
||||
// Success - need podcast_id to fetch with auth
|
||||
if (!result.podcast_id) {
|
||||
return <PodcastErrorState title={title} error="Missing podcast ID" />;
|
||||
// Processing - poll for completion
|
||||
if (result.status === "processing" && result.task_id) {
|
||||
return <PodcastTaskPoller taskId={result.task_id} title={result.title || title} />;
|
||||
}
|
||||
|
||||
// Render the podcast player (handles auth fetch internally)
|
||||
return (
|
||||
<PodcastPlayer
|
||||
podcastId={result.podcast_id}
|
||||
title={result.title || title}
|
||||
description={
|
||||
result.transcript_entries
|
||||
? `${result.transcript_entries} dialogue entries`
|
||||
: "SurfSense AI-generated podcast"
|
||||
}
|
||||
durationMs={result.duration_ms}
|
||||
transcript={result.transcript}
|
||||
transcriptEntries={result.transcript_entries}
|
||||
/>
|
||||
);
|
||||
// Success with podcast_id (direct result, not via polling)
|
||||
if (result.status === "success" && result.podcast_id) {
|
||||
return (
|
||||
<PodcastPlayer
|
||||
podcastId={result.podcast_id}
|
||||
title={result.title || title}
|
||||
description={
|
||||
result.transcript_entries
|
||||
? `${result.transcript_entries} dialogue entries`
|
||||
: "SurfSense AI-generated podcast"
|
||||
}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
// Fallback - missing required data
|
||||
return <PodcastErrorState title={title} error="Missing task ID or podcast ID" />;
|
||||
},
|
||||
});
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue