""" Podcast generation tool for the SurfSense agent. This module provides a factory function for creating the generate_podcast tool that submits a Celery task for background podcast generation. The tool then polls the podcast row until it reaches a terminal status (READY/FAILED) and returns that status. The wait is bounded by the chat's HTTP / process lifetime; see app.agents.shared.deliverable_wait for details. """ import logging from typing import Any from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession from app.agents.shared.deliverable_wait import wait_for_deliverable from app.db import Podcast, PodcastStatus, shielded_async_session logger = logging.getLogger(__name__) def create_generate_podcast_tool( search_space_id: int, db_session: AsyncSession, thread_id: int | None = None, ): """ Factory function to create the generate_podcast tool with injected dependencies. Pre-creates podcast record with pending status so podcast_id is available immediately for frontend polling. Args: search_space_id: The user's search space ID db_session: Reserved for future read-side use; the row is written via a fresh, tool-local session so parallel tool calls (e.g. podcast + video presentation in the same agent step) don't share an ``AsyncSession`` (which is not concurrency-safe). thread_id: The chat thread ID for associating the podcast Returns: A configured tool function for generating podcasts """ del db_session # writes use a fresh tool-local session, see below @tool async def generate_podcast( source_content: str, podcast_title: str = "SurfSense Podcast", user_prompt: str | None = None, ) -> dict[str, Any]: """ Generate a podcast from the provided content. Use this tool when the user asks to create, generate, or make a podcast. Common triggers include phrases like: - "Give me a podcast about this" - "Create a podcast from this conversation" - "Generate a podcast summary" - "Make a podcast about..." - "Turn this into a podcast" Args: source_content: The text content to convert into a podcast. podcast_title: Title for the podcast (default: "SurfSense Podcast") user_prompt: Optional instructions for podcast style, tone, or format. Returns: A dictionary containing: - status: PodcastStatus value (pending, generating, or failed) - podcast_id: The podcast ID for polling (when status is pending or generating) - title: The podcast title - message: Status message (or "error" field if status is failed) """ try: # Open a fresh session per call. The streaming task's session is # shared between every tool, and ``AsyncSession`` is NOT safe for # concurrent use: when the LLM emits parallel tool calls, two # concurrent ``add()`` / ``commit()`` paths interleave and the # second one hits "Session.add() during flush" → the transaction # is poisoned for both tools. async with shielded_async_session() as session: podcast = Podcast( title=podcast_title, status=PodcastStatus.PENDING, search_space_id=search_space_id, thread_id=thread_id, ) session.add(podcast) await session.commit() await session.refresh(podcast) podcast_id = podcast.id from app.tasks.celery_tasks.podcast_tasks import ( generate_content_podcast_task, ) task = generate_content_podcast_task.delay( podcast_id=podcast_id, source_content=source_content, search_space_id=search_space_id, user_prompt=user_prompt, ) logger.info( "[generate_podcast] Created podcast %s, task: %s", podcast_id, task.id, ) # Wait until the Celery worker flips the row to a terminal # state. No internal budget — see deliverable_wait module. terminal_status, columns, elapsed = await wait_for_deliverable( model=Podcast, row_id=podcast_id, columns=[Podcast.status, Podcast.file_location], terminal_statuses={PodcastStatus.READY, PodcastStatus.FAILED}, ) if terminal_status == PodcastStatus.READY: file_location = columns[1] if columns else None logger.info( "[generate_podcast] Podcast %s READY in %.2fs (file=%s)", podcast_id, elapsed, file_location, ) return { "status": PodcastStatus.READY.value, "podcast_id": podcast_id, "title": podcast_title, "file_location": file_location, "message": ( "Podcast generated and saved to your podcast panel." ), } # Only other terminal state is FAILED. logger.warning( "[generate_podcast] Podcast %s FAILED in %.2fs", podcast_id, elapsed, ) return { "status": PodcastStatus.FAILED.value, "podcast_id": podcast_id, "title": podcast_title, "error": ( "Background worker reported FAILED status for this podcast." ), } except Exception as e: error_message = str(e) logger.exception("[generate_podcast] Error: %s", error_message) return { "status": PodcastStatus.FAILED.value, "error": error_message, "title": podcast_title, "podcast_id": None, } return generate_podcast