SurfSense/surfsense_backend/app/agents/new_chat/tools/podcast.py
DESKTOP-RTLN3BA\$punk 94e834134f chore: linting
2026-05-28 19:21:29 -07:00

160 lines
6.1 KiB
Python

"""
Podcast generation tool for the SurfSense agent.
This module provides a factory function for creating the generate_podcast tool
that submits a Celery task for background podcast generation. The tool then
polls the podcast row until it reaches a terminal status (READY/FAILED) and
returns that status. The wait is bounded by the chat's HTTP / process
lifetime; see app.agents.shared.deliverable_wait for details.
"""
import logging
from typing import Any
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.shared.deliverable_wait import wait_for_deliverable
from app.db import Podcast, PodcastStatus, shielded_async_session
logger = logging.getLogger(__name__)
def create_generate_podcast_tool(
    search_space_id: int,
    db_session: AsyncSession,
    thread_id: int | None = None,
):
    """
    Factory function to create the generate_podcast tool with injected dependencies.

    The returned tool pre-creates a podcast record with PENDING status,
    submits the Celery generation task for it, and then waits (via
    ``wait_for_deliverable``) until the row reaches a terminal status
    (READY or FAILED) before returning.

    Args:
        search_space_id: The user's search space ID
        db_session: Reserved for future read-side use; the row is written via a
            fresh, tool-local session so parallel tool calls (e.g. podcast +
            video presentation in the same agent step) don't share an
            ``AsyncSession`` (which is not concurrency-safe).
        thread_id: The chat thread ID for associating the podcast

    Returns:
        A configured tool function for generating podcasts
    """
    del db_session  # writes use a fresh tool-local session, see below

    @tool
    async def generate_podcast(
        source_content: str,
        podcast_title: str = "SurfSense Podcast",
        user_prompt: str | None = None,
    ) -> dict[str, Any]:
        """
        Generate a podcast from the provided content.

        Use this tool when the user asks to create, generate, or make a podcast.
        Common triggers include phrases like:
        - "Give me a podcast about this"
        - "Create a podcast from this conversation"
        - "Generate a podcast summary"
        - "Make a podcast about..."
        - "Turn this into a podcast"

        Args:
            source_content: The text content to convert into a podcast.
            podcast_title: Title for the podcast (default: "SurfSense Podcast")
            user_prompt: Optional instructions for podcast style, tone, or format.

        Returns:
            A dictionary containing:
            - status: Terminal PodcastStatus value (ready or failed)
            - podcast_id: The podcast ID (None if the record could not be created)
            - title: The podcast title
            - file_location: Stored location of the generated podcast (only when status is ready)
            - message: Success message (only when status is ready)
            - error: Failure description (only when status is failed)
        """
        # Tracked outside the ``try`` so the error payload can still reference
        # the row when a failure happens *after* it has been committed (e.g.
        # the Celery enqueue or the wait raised).
        podcast_id: int | None = None
        try:
            # Open a fresh session per call. The streaming task's session is
            # shared between every tool, and ``AsyncSession`` is NOT safe for
            # concurrent use: when the LLM emits parallel tool calls, two
            # concurrent ``add()`` / ``commit()`` paths interleave and the
            # second one hits "Session.add() during flush" → the transaction
            # is poisoned for both tools.
            async with shielded_async_session() as session:
                podcast = Podcast(
                    title=podcast_title,
                    status=PodcastStatus.PENDING,
                    search_space_id=search_space_id,
                    thread_id=thread_id,
                )
                session.add(podcast)
                await session.commit()
                # Refresh so the database-generated primary key is loaded.
                await session.refresh(podcast)
                podcast_id = podcast.id

            # Imported at call time rather than at module import time.
            from app.tasks.celery_tasks.podcast_tasks import (
                generate_content_podcast_task,
            )

            task = generate_content_podcast_task.delay(
                podcast_id=podcast_id,
                source_content=source_content,
                search_space_id=search_space_id,
                user_prompt=user_prompt,
            )
            logger.info(
                "[generate_podcast] Created podcast %s, task: %s",
                podcast_id,
                task.id,
            )

            # Wait until the Celery worker flips the row to a terminal
            # state. No internal budget — see deliverable_wait module.
            terminal_status, columns, elapsed = await wait_for_deliverable(
                model=Podcast,
                row_id=podcast_id,
                columns=[Podcast.status, Podcast.file_location],
                terminal_statuses={PodcastStatus.READY, PodcastStatus.FAILED},
            )

            if terminal_status == PodcastStatus.READY:
                # Index 1 matches the position of Podcast.file_location in
                # the ``columns`` list requested above.
                file_location = columns[1] if columns else None
                logger.info(
                    "[generate_podcast] Podcast %s READY in %.2fs (file=%s)",
                    podcast_id,
                    elapsed,
                    file_location,
                )
                return {
                    "status": PodcastStatus.READY.value,
                    "podcast_id": podcast_id,
                    "title": podcast_title,
                    "file_location": file_location,
                    "message": ("Podcast generated and saved to your podcast panel."),
                }

            # Only other terminal state is FAILED.
            logger.warning(
                "[generate_podcast] Podcast %s FAILED in %.2fs",
                podcast_id,
                elapsed,
            )
            return {
                "status": PodcastStatus.FAILED.value,
                "podcast_id": podcast_id,
                "title": podcast_title,
                "error": ("Background worker reported FAILED status for this podcast."),
            }
        except Exception as e:
            # Tool boundary: report the failure to the agent as a FAILED
            # payload instead of raising. Exceptions without a message
            # stringify to "", so fall back to the exception class name.
            error_message = str(e) or type(e).__name__
            logger.exception("[generate_podcast] Error: %s", error_message)
            return {
                "status": PodcastStatus.FAILED.value,
                "error": error_message,
                "title": podcast_title,
                # None only when the failure happened before the row was
                # committed; otherwise the id of the already-created row.
                "podcast_id": podcast_id,
            }

    return generate_podcast