mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-11 16:52:38 +02:00
feat: implement background podcast generation with Celery and task polling in UI
This commit is contained in:
parent
4c4e4b3c4c
commit
e79e1187b2
4 changed files with 335 additions and 134 deletions
|
|
@ -2,8 +2,8 @@
|
||||||
Podcast generation tool for the new chat agent.
|
Podcast generation tool for the new chat agent.
|
||||||
|
|
||||||
This module provides a factory function for creating the generate_podcast tool
|
This module provides a factory function for creating the generate_podcast tool
|
||||||
that integrates with the existing podcaster agent. Podcasts are saved to the
|
that submits a Celery task for background podcast generation. The frontend
|
||||||
database like the old system, providing authentication and persistence.
|
polls for completion and auto-updates when the podcast is ready.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
@ -11,10 +11,6 @@ from typing import Any
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.agents.podcaster.graph import graph as podcaster_graph
|
|
||||||
from app.agents.podcaster.state import State as PodcasterState
|
|
||||||
from app.db import Podcast
|
|
||||||
|
|
||||||
|
|
||||||
def create_generate_podcast_tool(
|
def create_generate_podcast_tool(
|
||||||
search_space_id: int,
|
search_space_id: int,
|
||||||
|
|
@ -26,7 +22,7 @@ def create_generate_podcast_tool(
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
search_space_id: The user's search space ID
|
search_space_id: The user's search space ID
|
||||||
db_session: Database session
|
db_session: Database session (not used - Celery creates its own)
|
||||||
user_id: The user's ID (as string)
|
user_id: The user's ID (as string)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
|
@ -50,8 +46,8 @@ def create_generate_podcast_tool(
|
||||||
- "Make a podcast about..."
|
- "Make a podcast about..."
|
||||||
- "Turn this into a podcast"
|
- "Turn this into a podcast"
|
||||||
|
|
||||||
The tool will generate a complete audio podcast with two speakers
|
The tool will start generating a podcast in the background.
|
||||||
discussing the provided content in an engaging conversational format.
|
The podcast will be available once generation completes.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
source_content: The text content to convert into a podcast.
|
source_content: The text content to convert into a podcast.
|
||||||
|
|
@ -63,108 +59,43 @@ def create_generate_podcast_tool(
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A dictionary containing:
|
A dictionary containing:
|
||||||
- status: "success" or "error"
|
- status: "processing" (task submitted) or "error"
|
||||||
- podcast_id: The database ID of the saved podcast (for API access)
|
- task_id: The Celery task ID for polling status
|
||||||
- title: The podcast title
|
- title: The podcast title
|
||||||
- transcript: Full podcast transcript with all dialogue entries
|
|
||||||
- duration_ms: Estimated podcast duration in milliseconds
|
|
||||||
- transcript_entries: Number of dialogue entries
|
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# Configure the podcaster graph
|
# Import Celery task here to avoid circular imports
|
||||||
config = {
|
from app.tasks.celery_tasks.podcast_tasks import (
|
||||||
"configurable": {
|
generate_content_podcast_task,
|
||||||
"podcast_title": podcast_title,
|
)
|
||||||
"user_id": str(user_id),
|
|
||||||
"search_space_id": search_space_id,
|
|
||||||
"user_prompt": user_prompt,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
# Initialize the podcaster state with the source content
|
# Submit Celery task for background processing
|
||||||
initial_state = PodcasterState(
|
task = generate_content_podcast_task.delay(
|
||||||
source_content=source_content,
|
source_content=source_content,
|
||||||
db_session=db_session,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Run the podcaster graph
|
|
||||||
result = await podcaster_graph.ainvoke(initial_state, config=config)
|
|
||||||
|
|
||||||
# Extract results
|
|
||||||
podcast_transcript = result.get("podcast_transcript", [])
|
|
||||||
file_path = result.get("final_podcast_file_path", "")
|
|
||||||
|
|
||||||
# Calculate estimated duration (rough estimate: ~150 words per minute)
|
|
||||||
total_words = sum(
|
|
||||||
len(entry.dialog.split()) if hasattr(entry, "dialog") else len(entry.get("dialog", "").split())
|
|
||||||
for entry in podcast_transcript
|
|
||||||
)
|
|
||||||
estimated_duration_ms = int((total_words / 150) * 60 * 1000)
|
|
||||||
|
|
||||||
# Create full transcript for display (all entries, complete dialog)
|
|
||||||
full_transcript = []
|
|
||||||
for entry in podcast_transcript:
|
|
||||||
if hasattr(entry, "speaker_id"):
|
|
||||||
speaker = f"Speaker {entry.speaker_id + 1}"
|
|
||||||
dialog = entry.dialog
|
|
||||||
else:
|
|
||||||
speaker = f"Speaker {entry.get('speaker_id', 0) + 1}"
|
|
||||||
dialog = entry.get("dialog", "")
|
|
||||||
full_transcript.append(f"{speaker}: {dialog}")
|
|
||||||
|
|
||||||
# Convert podcast transcript entries to serializable format (like old system)
|
|
||||||
serializable_transcript = []
|
|
||||||
for entry in podcast_transcript:
|
|
||||||
if hasattr(entry, "speaker_id"):
|
|
||||||
serializable_transcript.append({
|
|
||||||
"speaker_id": entry.speaker_id,
|
|
||||||
"dialog": entry.dialog
|
|
||||||
})
|
|
||||||
else:
|
|
||||||
serializable_transcript.append({
|
|
||||||
"speaker_id": entry.get("speaker_id", 0),
|
|
||||||
"dialog": entry.get("dialog", "")
|
|
||||||
})
|
|
||||||
|
|
||||||
# Save podcast to database (like old system)
|
|
||||||
# This provides authentication and persistence
|
|
||||||
podcast = Podcast(
|
|
||||||
title=podcast_title,
|
|
||||||
podcast_transcript=serializable_transcript,
|
|
||||||
file_location=file_path,
|
|
||||||
search_space_id=search_space_id,
|
search_space_id=search_space_id,
|
||||||
# chat_id is None since new-chat uses LangGraph threads, not DB chats
|
user_id=str(user_id),
|
||||||
chat_id=None,
|
podcast_title=podcast_title,
|
||||||
chat_state_version=None,
|
user_prompt=user_prompt,
|
||||||
)
|
)
|
||||||
db_session.add(podcast)
|
|
||||||
await db_session.commit()
|
|
||||||
await db_session.refresh(podcast)
|
|
||||||
|
|
||||||
# Return podcast_id - frontend will use it to call the API endpoint
|
print(f"[generate_podcast] Submitted Celery task: {task.id}")
|
||||||
# GET /api/v1/podcasts/{podcast_id}/stream (like the old system)
|
|
||||||
|
# Return immediately with task_id for polling
|
||||||
return {
|
return {
|
||||||
"status": "success",
|
"status": "processing",
|
||||||
"podcast_id": podcast.id,
|
"task_id": task.id,
|
||||||
"title": podcast_title,
|
"title": podcast_title,
|
||||||
"transcript": "\n\n".join(full_transcript),
|
"message": "Podcast generation started. This may take a few minutes.",
|
||||||
"duration_ms": estimated_duration_ms,
|
|
||||||
"transcript_entries": len(podcast_transcript),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_message = str(e)
|
error_message = str(e)
|
||||||
print(f"[generate_podcast] Error: {error_message}")
|
print(f"[generate_podcast] Error submitting task: {error_message}")
|
||||||
# Rollback on error
|
|
||||||
await db_session.rollback()
|
|
||||||
return {
|
return {
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"error": error_message,
|
"error": error_message,
|
||||||
"title": podcast_title,
|
"title": podcast_title,
|
||||||
"podcast_id": None,
|
"task_id": None,
|
||||||
"duration_ms": 0,
|
|
||||||
"transcript_entries": 0,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return generate_podcast
|
return generate_podcast
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -444,3 +444,66 @@ async def get_podcast_by_chat_id(
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=500, detail=f"Error fetching podcast: {e!s}"
|
status_code=500, detail=f"Error fetching podcast: {e!s}"
|
||||||
) from e
|
) from e
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/podcasts/task/{task_id}/status")
|
||||||
|
async def get_podcast_task_status(
|
||||||
|
task_id: str,
|
||||||
|
user: User = Depends(current_active_user),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Get the status of a podcast generation task.
|
||||||
|
Used by new-chat frontend to poll for completion.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
- status: "processing" | "success" | "error"
|
||||||
|
- podcast_id: (only if status == "success")
|
||||||
|
- title: (only if status == "success")
|
||||||
|
- error: (only if status == "error")
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from celery.result import AsyncResult
|
||||||
|
|
||||||
|
from app.celery_app import celery_app
|
||||||
|
|
||||||
|
result = AsyncResult(task_id, app=celery_app)
|
||||||
|
|
||||||
|
if result.ready():
|
||||||
|
# Task completed
|
||||||
|
if result.successful():
|
||||||
|
task_result = result.result
|
||||||
|
if isinstance(task_result, dict):
|
||||||
|
if task_result.get("status") == "success":
|
||||||
|
return {
|
||||||
|
"status": "success",
|
||||||
|
"podcast_id": task_result.get("podcast_id"),
|
||||||
|
"title": task_result.get("title"),
|
||||||
|
"transcript_entries": task_result.get("transcript_entries"),
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"error": task_result.get("error", "Unknown error"),
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"error": "Unexpected task result format",
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
# Task failed
|
||||||
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"error": str(result.result) if result.result else "Task failed",
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
# Task still processing
|
||||||
|
return {
|
||||||
|
"status": "processing",
|
||||||
|
"state": result.state,
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500, detail=f"Error checking task status: {e!s}"
|
||||||
|
) from e
|
||||||
|
|
|
||||||
|
|
@ -11,6 +11,11 @@ from app.celery_app import celery_app
|
||||||
from app.config import config
|
from app.config import config
|
||||||
from app.tasks.podcast_tasks import generate_chat_podcast
|
from app.tasks.podcast_tasks import generate_chat_podcast
|
||||||
|
|
||||||
|
# Import for content-based podcast (new-chat)
|
||||||
|
from app.agents.podcaster.graph import graph as podcaster_graph
|
||||||
|
from app.agents.podcaster.state import State as PodcasterState
|
||||||
|
from app.db import Podcast
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
if sys.platform.startswith("win"):
|
if sys.platform.startswith("win"):
|
||||||
|
|
@ -86,3 +91,131 @@ async def _generate_chat_podcast(
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error generating podcast from chat: {e!s}")
|
logger.error(f"Error generating podcast from chat: {e!s}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Content-based podcast generation (for new-chat)
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
@celery_app.task(name="generate_content_podcast", bind=True)
|
||||||
|
def generate_content_podcast_task(
|
||||||
|
self,
|
||||||
|
source_content: str,
|
||||||
|
search_space_id: int,
|
||||||
|
user_id: str,
|
||||||
|
podcast_title: str = "SurfSense Podcast",
|
||||||
|
user_prompt: str | None = None,
|
||||||
|
) -> dict:
|
||||||
|
"""
|
||||||
|
Celery task to generate podcast from source content (for new-chat).
|
||||||
|
|
||||||
|
Unlike generate_chat_podcast which requires a chat_id, this task
|
||||||
|
generates a podcast directly from provided content.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
source_content: The text content to convert into a podcast
|
||||||
|
search_space_id: ID of the search space
|
||||||
|
user_id: ID of the user (as string)
|
||||||
|
podcast_title: Title for the podcast
|
||||||
|
user_prompt: Optional instructions for podcast style/tone
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict with podcast_id on success, or error info on failure
|
||||||
|
"""
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(loop)
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = loop.run_until_complete(
|
||||||
|
_generate_content_podcast(
|
||||||
|
source_content,
|
||||||
|
search_space_id,
|
||||||
|
user_id,
|
||||||
|
podcast_title,
|
||||||
|
user_prompt,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
loop.run_until_complete(loop.shutdown_asyncgens())
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error generating content podcast: {e!s}")
|
||||||
|
return {"status": "error", "error": str(e)}
|
||||||
|
finally:
|
||||||
|
asyncio.set_event_loop(None)
|
||||||
|
loop.close()
|
||||||
|
|
||||||
|
|
||||||
|
async def _generate_content_podcast(
|
||||||
|
source_content: str,
|
||||||
|
search_space_id: int,
|
||||||
|
user_id: str,
|
||||||
|
podcast_title: str = "SurfSense Podcast",
|
||||||
|
user_prompt: str | None = None,
|
||||||
|
) -> dict:
|
||||||
|
"""Generate content-based podcast with new session."""
|
||||||
|
async with get_celery_session_maker()() as session:
|
||||||
|
try:
|
||||||
|
# Configure the podcaster graph
|
||||||
|
graph_config = {
|
||||||
|
"configurable": {
|
||||||
|
"podcast_title": podcast_title,
|
||||||
|
"user_id": str(user_id),
|
||||||
|
"search_space_id": search_space_id,
|
||||||
|
"user_prompt": user_prompt,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
# Initialize the podcaster state with the source content
|
||||||
|
initial_state = PodcasterState(
|
||||||
|
source_content=source_content,
|
||||||
|
db_session=session,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Run the podcaster graph
|
||||||
|
result = await podcaster_graph.ainvoke(initial_state, config=graph_config)
|
||||||
|
|
||||||
|
# Extract results
|
||||||
|
podcast_transcript = result.get("podcast_transcript", [])
|
||||||
|
file_path = result.get("final_podcast_file_path", "")
|
||||||
|
|
||||||
|
# Convert transcript to serializable format
|
||||||
|
serializable_transcript = []
|
||||||
|
for entry in podcast_transcript:
|
||||||
|
if hasattr(entry, "speaker_id"):
|
||||||
|
serializable_transcript.append({
|
||||||
|
"speaker_id": entry.speaker_id,
|
||||||
|
"dialog": entry.dialog
|
||||||
|
})
|
||||||
|
else:
|
||||||
|
serializable_transcript.append({
|
||||||
|
"speaker_id": entry.get("speaker_id", 0),
|
||||||
|
"dialog": entry.get("dialog", "")
|
||||||
|
})
|
||||||
|
|
||||||
|
# Save podcast to database
|
||||||
|
podcast = Podcast(
|
||||||
|
title=podcast_title,
|
||||||
|
podcast_transcript=serializable_transcript,
|
||||||
|
file_location=file_path,
|
||||||
|
search_space_id=search_space_id,
|
||||||
|
chat_id=None, # No chat_id for new-chat podcasts
|
||||||
|
chat_state_version=None,
|
||||||
|
)
|
||||||
|
session.add(podcast)
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(podcast)
|
||||||
|
|
||||||
|
logger.info(f"Successfully generated content podcast: {podcast.id}")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "success",
|
||||||
|
"podcast_id": podcast.id,
|
||||||
|
"title": podcast_title,
|
||||||
|
"transcript_entries": len(serializable_transcript),
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in _generate_content_podcast: {e!s}")
|
||||||
|
await session.rollback()
|
||||||
|
raise
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,7 @@ import { makeAssistantToolUI } from "@assistant-ui/react";
|
||||||
import { AlertCircleIcon, Loader2Icon, MicIcon } from "lucide-react";
|
import { AlertCircleIcon, Loader2Icon, MicIcon } from "lucide-react";
|
||||||
import { useCallback, useEffect, useRef, useState } from "react";
|
import { useCallback, useEffect, useRef, useState } from "react";
|
||||||
import { Audio } from "@/components/tool-ui/audio";
|
import { Audio } from "@/components/tool-ui/audio";
|
||||||
|
import { baseApiService } from "@/lib/apis/base-api.service";
|
||||||
import { podcastsApiService } from "@/lib/apis/podcasts-api.service";
|
import { podcastsApiService } from "@/lib/apis/podcasts-api.service";
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
@ -16,12 +17,21 @@ interface GeneratePodcastArgs {
|
||||||
}
|
}
|
||||||
|
|
||||||
interface GeneratePodcastResult {
|
interface GeneratePodcastResult {
|
||||||
status: "success" | "error";
|
status: "processing" | "success" | "error";
|
||||||
|
task_id?: string;
|
||||||
podcast_id?: number;
|
podcast_id?: number;
|
||||||
title?: string;
|
title?: string;
|
||||||
transcript?: string;
|
|
||||||
duration_ms?: number;
|
|
||||||
transcript_entries?: number;
|
transcript_entries?: number;
|
||||||
|
message?: string;
|
||||||
|
error?: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
interface TaskStatusResponse {
|
||||||
|
status: "processing" | "success" | "error";
|
||||||
|
podcast_id?: number;
|
||||||
|
title?: string;
|
||||||
|
transcript_entries?: number;
|
||||||
|
state?: string;
|
||||||
error?: string;
|
error?: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -106,15 +116,11 @@ function PodcastPlayer({
|
||||||
title,
|
title,
|
||||||
description,
|
description,
|
||||||
durationMs,
|
durationMs,
|
||||||
transcript,
|
|
||||||
transcriptEntries,
|
|
||||||
}: {
|
}: {
|
||||||
podcastId: number;
|
podcastId: number;
|
||||||
title: string;
|
title: string;
|
||||||
description: string;
|
description: string;
|
||||||
durationMs?: number;
|
durationMs?: number;
|
||||||
transcript?: string;
|
|
||||||
transcriptEntries?: number;
|
|
||||||
}) {
|
}) {
|
||||||
const [audioSrc, setAudioSrc] = useState<string | null>(null);
|
const [audioSrc, setAudioSrc] = useState<string | null>(null);
|
||||||
const [isLoading, setIsLoading] = useState(true);
|
const [isLoading, setIsLoading] = useState(true);
|
||||||
|
|
@ -194,29 +200,96 @@ function PodcastPlayer({
|
||||||
durationMs={durationMs}
|
durationMs={durationMs}
|
||||||
className="w-full"
|
className="w-full"
|
||||||
/>
|
/>
|
||||||
{/* Full transcript */}
|
|
||||||
{transcript && (
|
|
||||||
<details className="mt-3 rounded-lg border bg-muted/30 p-3">
|
|
||||||
<summary className="cursor-pointer font-medium text-muted-foreground text-sm hover:text-foreground">
|
|
||||||
View full transcript{transcriptEntries ? ` (${transcriptEntries} entries)` : ""}
|
|
||||||
</summary>
|
|
||||||
<pre className="mt-2 max-h-96 overflow-y-auto whitespace-pre-wrap text-muted-foreground text-xs">
|
|
||||||
{transcript}
|
|
||||||
</pre>
|
|
||||||
</details>
|
|
||||||
)}
|
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Polling component that checks task status and shows player when complete
|
||||||
|
*/
|
||||||
|
function PodcastTaskPoller({
|
||||||
|
taskId,
|
||||||
|
title,
|
||||||
|
}: {
|
||||||
|
taskId: string;
|
||||||
|
title: string;
|
||||||
|
}) {
|
||||||
|
const [taskStatus, setTaskStatus] = useState<TaskStatusResponse>({ status: "processing" });
|
||||||
|
const [pollCount, setPollCount] = useState(0);
|
||||||
|
const pollingRef = useRef<NodeJS.Timeout | null>(null);
|
||||||
|
|
||||||
|
// Poll for task status
|
||||||
|
useEffect(() => {
|
||||||
|
const pollStatus = async () => {
|
||||||
|
try {
|
||||||
|
const response = await baseApiService.get<TaskStatusResponse>(
|
||||||
|
`/api/v1/podcasts/task/${taskId}/status`
|
||||||
|
);
|
||||||
|
setTaskStatus(response);
|
||||||
|
|
||||||
|
// Stop polling if task is complete or errored
|
||||||
|
if (response.status !== "processing") {
|
||||||
|
if (pollingRef.current) {
|
||||||
|
clearInterval(pollingRef.current);
|
||||||
|
pollingRef.current = null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} catch (err) {
|
||||||
|
console.error("Error polling task status:", err);
|
||||||
|
// Don't stop polling on network errors, just increment count
|
||||||
|
}
|
||||||
|
setPollCount((prev) => prev + 1);
|
||||||
|
};
|
||||||
|
|
||||||
|
// Initial poll
|
||||||
|
pollStatus();
|
||||||
|
|
||||||
|
// Poll every 5 seconds
|
||||||
|
pollingRef.current = setInterval(pollStatus, 5000);
|
||||||
|
|
||||||
|
return () => {
|
||||||
|
if (pollingRef.current) {
|
||||||
|
clearInterval(pollingRef.current);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}, [taskId]);
|
||||||
|
|
||||||
|
// Show loading state while processing
|
||||||
|
if (taskStatus.status === "processing") {
|
||||||
|
return <PodcastGeneratingState title={title} />;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Show error state
|
||||||
|
if (taskStatus.status === "error") {
|
||||||
|
return <PodcastErrorState title={title} error={taskStatus.error || "Generation failed"} />;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Show player when complete
|
||||||
|
if (taskStatus.status === "success" && taskStatus.podcast_id) {
|
||||||
|
return (
|
||||||
|
<PodcastPlayer
|
||||||
|
podcastId={taskStatus.podcast_id}
|
||||||
|
title={taskStatus.title || title}
|
||||||
|
description={
|
||||||
|
taskStatus.transcript_entries
|
||||||
|
? `${taskStatus.transcript_entries} dialogue entries`
|
||||||
|
: "SurfSense AI-generated podcast"
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback
|
||||||
|
return <PodcastErrorState title={title} error="Unexpected state" />;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Generate Podcast Tool UI Component
|
* Generate Podcast Tool UI Component
|
||||||
*
|
*
|
||||||
* This component is registered with assistant-ui to render custom UI
|
* This component is registered with assistant-ui to render custom UI
|
||||||
* when the generate_podcast tool is called by the agent.
|
* when the generate_podcast tool is called by the agent.
|
||||||
*
|
*
|
||||||
* It fetches the podcast audio with authentication (like the old system)
|
* It polls for task completion and auto-updates when the podcast is ready.
|
||||||
* and displays it using the Audio component.
|
|
||||||
*/
|
*/
|
||||||
export const GeneratePodcastToolUI = makeAssistantToolUI<
|
export const GeneratePodcastToolUI = makeAssistantToolUI<
|
||||||
GeneratePodcastArgs,
|
GeneratePodcastArgs,
|
||||||
|
|
@ -226,7 +299,7 @@ export const GeneratePodcastToolUI = makeAssistantToolUI<
|
||||||
render: function GeneratePodcastUI({ args, result, status }) {
|
render: function GeneratePodcastUI({ args, result, status }) {
|
||||||
const title = args.podcast_title || "SurfSense Podcast";
|
const title = args.podcast_title || "SurfSense Podcast";
|
||||||
|
|
||||||
// Loading state - podcast is being generated
|
// Loading state - tool is still running (agent processing)
|
||||||
if (status.type === "running" || status.type === "requires-action") {
|
if (status.type === "running" || status.type === "requires-action") {
|
||||||
return <PodcastGeneratingState title={title} />;
|
return <PodcastGeneratingState title={title} />;
|
||||||
}
|
}
|
||||||
|
|
@ -263,12 +336,13 @@ export const GeneratePodcastToolUI = makeAssistantToolUI<
|
||||||
return <PodcastErrorState title={title} error={result.error || "Unknown error"} />;
|
return <PodcastErrorState title={title} error={result.error || "Unknown error"} />;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Success - need podcast_id to fetch with auth
|
// Processing - poll for completion
|
||||||
if (!result.podcast_id) {
|
if (result.status === "processing" && result.task_id) {
|
||||||
return <PodcastErrorState title={title} error="Missing podcast ID" />;
|
return <PodcastTaskPoller taskId={result.task_id} title={result.title || title} />;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Render the podcast player (handles auth fetch internally)
|
// Success with podcast_id (direct result, not via polling)
|
||||||
|
if (result.status === "success" && result.podcast_id) {
|
||||||
return (
|
return (
|
||||||
<PodcastPlayer
|
<PodcastPlayer
|
||||||
podcastId={result.podcast_id}
|
podcastId={result.podcast_id}
|
||||||
|
|
@ -278,11 +352,11 @@ export const GeneratePodcastToolUI = makeAssistantToolUI<
|
||||||
? `${result.transcript_entries} dialogue entries`
|
? `${result.transcript_entries} dialogue entries`
|
||||||
: "SurfSense AI-generated podcast"
|
: "SurfSense AI-generated podcast"
|
||||||
}
|
}
|
||||||
durationMs={result.duration_ms}
|
|
||||||
transcript={result.transcript}
|
|
||||||
transcriptEntries={result.transcript_entries}
|
|
||||||
/>
|
/>
|
||||||
);
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback - missing required data
|
||||||
|
return <PodcastErrorState title={title} error="Missing task ID or podcast ID" />;
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue