"""Add status and thread_id to podcasts

Revision ID: 82
Revises: 81
Create Date: 2026-01-27

Adds status enum and thread_id FK to podcasts.
"""

from collections.abc import Sequence

from alembic import op

# NOTE(review): revision identifiers reconstructed from the module docstring
# ("Revision ID: 82", "Revises: 81") — confirm against the original file.
revision: str = "82"
down_revision: str | None = "81"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None


def upgrade() -> None:
    """Create the podcast_status enum, add status + thread_id to podcasts."""
    # Raw SQL is used (not op.add_column) so IF NOT EXISTS keeps the
    # migration idempotent on partially-migrated databases.
    op.execute(
        """
        CREATE TYPE podcast_status AS ENUM ('pending', 'generating', 'ready', 'failed');
        """
    )

    # Existing podcasts predate the status column and are complete, so the
    # backfill default is 'ready'.
    op.execute(
        """
        ALTER TABLE podcasts
        ADD COLUMN IF NOT EXISTS status podcast_status NOT NULL DEFAULT 'ready';
        """
    )

    # NOTE(review): the thread_id ADD COLUMN statement and the
    # ix_podcasts_thread_id index creation sit between the visible hunks of
    # the diff — keep them exactly as in the previous revision of this file.
    # The statements below are a placeholder reconstruction (column/index
    # names taken from downgrade()); verify the FK target and ON DELETE rule.
    op.execute(
        """
        ALTER TABLE podcasts
        ADD COLUMN IF NOT EXISTS thread_id INTEGER;
        """
    )

    op.execute(
        """
        CREATE INDEX IF NOT EXISTS ix_podcasts_thread_id
        ON podcasts(thread_id);
        """
    )

    # Status is polled by the frontend; index it for the lookup.
    op.execute(
        """
        CREATE INDEX IF NOT EXISTS ix_podcasts_status
        ON podcasts(status);
        """
    )


def downgrade() -> None:
    """Remove status and thread_id columns, indexes, and the enum type."""
    # Drop order matters: indexes, then columns, then the enum type the
    # status column depended on.
    op.execute("DROP INDEX IF EXISTS ix_podcasts_status")
    op.execute("DROP INDEX IF EXISTS ix_podcasts_thread_id")
    op.execute("ALTER TABLE podcasts DROP COLUMN IF EXISTS thread_id")
    op.execute("ALTER TABLE podcasts DROP COLUMN IF EXISTS status")
    op.execute("DROP TYPE IF EXISTS podcast_status")
+18,8 @@ import redis from langchain_core.tools import tool from sqlalchemy.ext.asyncio import AsyncSession +from app.db import Podcast, PodcastStatus + # Redis connection for tracking active podcast tasks # Uses the same Redis instance as Celery REDIS_URL = os.getenv("CELERY_BROKER_URL", "redis://localhost:6379/0") @@ -32,38 +34,27 @@ def get_redis_client() -> redis.Redis: return _redis_client -def get_active_podcast_key(search_space_id: int) -> str: - """Generate Redis key for tracking active podcast task.""" - return f"podcast:active:{search_space_id}" +def _redis_key(search_space_id: int) -> str: + return f"podcast:generating:{search_space_id}" -def get_active_podcast_task(search_space_id: int) -> str | None: - """Check if there's an active podcast task for this search space.""" +def get_generating_podcast_id(search_space_id: int) -> int | None: + """Get the podcast ID currently being generated for this search space.""" try: client = get_redis_client() - return client.get(get_active_podcast_key(search_space_id)) + value = client.get(_redis_key(search_space_id)) + return int(value) if value else None except Exception: - # If Redis is unavailable, allow the request (fail open) return None -def set_active_podcast_task(search_space_id: int, task_id: str) -> None: - """Mark a podcast task as active for this search space.""" +def set_generating_podcast(search_space_id: int, podcast_id: int) -> None: + """Mark a podcast as currently generating for this search space.""" try: client = get_redis_client() - # Set with 30-minute expiry as safety net (podcast should complete before this) - client.setex(get_active_podcast_key(search_space_id), 1800, task_id) + client.setex(_redis_key(search_space_id), 1800, str(podcast_id)) except Exception as e: - print(f"[generate_podcast] Warning: Could not set active task in Redis: {e}") - - -def clear_active_podcast_task(search_space_id: int) -> None: - """Clear the active podcast task for this search space.""" - try: - client = 
def create_generate_podcast_tool(
    search_space_id: int,
    db_session: AsyncSession,
    thread_id: int | None = None,
):
    """
    Factory function to create the generate_podcast tool with injected dependencies.

    Pre-creates podcast record with pending status so podcast_id is available
    immediately for frontend polling.

    Args:
        search_space_id: The user's search space ID
        db_session: Database session for creating the podcast record
        thread_id: The chat thread ID for associating the podcast

    Returns:
        The generate_podcast tool.
    """
    # NOTE(review): both signatures in this factory are reconstructed from the
    # diff context and docstrings — confirm parameter order/defaults against
    # the original file. The inner docstring doubles as the LLM-facing tool
    # description; restore any hidden leading lines from the original.

    @tool
    async def generate_podcast(
        source_content: str,
        podcast_title: str = "SurfSense Podcast",
        user_prompt: str | None = None,
    ) -> dict:
        """
        Generate an audio podcast from the provided source content.

        Use this tool when the user asks things like:
        - "Make a podcast about..."
        - "Turn this into a podcast"

        Args:
            source_content: The text content to convert into a podcast.
            podcast_title: Title for the podcast (default: "SurfSense Podcast")
            user_prompt: Optional instructions for podcast style, tone, or format.

        Returns:
            A dictionary containing:
            - status: PodcastStatus value (pending, generating, or failed)
            - podcast_id: The podcast ID for polling (when status is pending or generating)
            - title: The podcast title
            - message: Status message (or "error" field if status is failed)
        """
        podcast: Podcast | None = None
        try:
            # Only one podcast may generate per search space at a time.
            generating_podcast_id = get_generating_podcast_id(search_space_id)
            if generating_podcast_id:
                print(
                    f"[generate_podcast] Blocked duplicate request. Generating podcast: {generating_podcast_id}"
                )
                return {
                    "status": PodcastStatus.GENERATING.value,
                    "podcast_id": generating_podcast_id,
                    "title": podcast_title,
                    "message": "A podcast is already being generated. Please wait for it to complete.",
                }

            # Create the row first so the frontend has a podcast_id to poll
            # before the Celery worker even picks the job up.
            podcast = Podcast(
                title=podcast_title,
                status=PodcastStatus.PENDING,
                search_space_id=search_space_id,
                thread_id=thread_id,
            )
            db_session.add(podcast)
            await db_session.commit()
            await db_session.refresh(podcast)

            # Imported here to avoid a circular import at module load time.
            from app.tasks.celery_tasks.podcast_tasks import (
                generate_content_podcast_task,
            )

            task = generate_content_podcast_task.delay(
                podcast_id=podcast.id,
                source_content=source_content,
                search_space_id=search_space_id,
                user_prompt=user_prompt,
            )

            set_generating_podcast(search_space_id, podcast.id)

            print(f"[generate_podcast] Created podcast {podcast.id}, task: {task.id}")

            return {
                "status": PodcastStatus.PENDING.value,
                "podcast_id": podcast.id,
                "title": podcast_title,
                "message": "Podcast generation started. This may take a few minutes.",
            }
        except Exception as e:
            error_message = str(e)
            print(f"[generate_podcast] Error: {error_message}")
            # FIX: if the row was created before the failure (e.g. the Celery
            # submit raised), mark it FAILED so it does not sit in PENDING
            # forever with no worker ever picking it up.
            if podcast is not None and podcast.id is not None:
                try:
                    podcast.status = PodcastStatus.FAILED
                    await db_session.commit()
                except Exception:
                    await db_session.rollback()
            return {
                "status": PodcastStatus.FAILED.value,
                "error": error_message,
                "title": podcast_title,
                "podcast_id": None,
            }

    return generate_podcast
class PodcastStatus(str, Enum):
    """Lifecycle states of a podcast generation job."""

    PENDING = "pending"
    GENERATING = "generating"
    READY = "ready"
    FAILED = "failed"


class Podcast(BaseModel, TimestampMixin):
    __tablename__ = "podcasts"

    title = Column(String(500), nullable=False)
    podcast_transcript = Column(JSONB, nullable=True)  # list of transcript entries
    file_location = Column(Text, nullable=True)  # path to the generated audio file
    # FIX: SQLAlchemy's Enum type persists member *names* ("PENDING") by
    # default, but the podcast_status PG type created by migration 82 holds
    # the lowercase *values* ("pending", ...). values_callable makes
    # SQLAlchemy emit the values so inserts/updates match the database enum.
    status = Column(
        SQLAlchemyEnum(
            PodcastStatus,
            name="podcast_status",
            create_type=False,  # type is created/dropped by migration 82
            values_callable=lambda enum_cls: [m.value for m in enum_cls],
        ),
        nullable=False,
        default=PodcastStatus.READY,
        server_default="ready",
        index=True,
    )
    search_space_id = Column(
        Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False
    )
    # NOTE(review): the thread_id FK column and any relationships that follow
    # in the original model are outside the visible hunk — keep them unchanged.
class PodcastStatusEnum(str, Enum):
    """Lifecycle states of a podcast, mirrored from app.db.PodcastStatus.

    Duplicated here (rather than imported) so the API schema layer has no
    dependency on the database models module.
    """

    PENDING = "pending"
    GENERATING = "generating"
    READY = "ready"
    FAILED = "failed"
+ """ if content is None: return [] @@ -67,13 +70,6 @@ def sanitize_content_for_public(content: list | str | None) -> list: tool_name = part.get("toolName") if tool_name not in UI_TOOLS: continue - - # Skip podcasts that are still processing (would cause auth errors) - if tool_name == "generate_podcast": - result = part.get("result", {}) - if result.get("status") in ("processing", "already_generating"): - continue - sanitized.append(part) return sanitized @@ -355,16 +351,16 @@ async def _clone_podcast( target_search_space_id: int, target_thread_id: int, ) -> int | None: - """Clone a podcast record and its audio file.""" + """Clone a podcast record and its audio file. Only clones ready podcasts.""" import shutil import uuid from pathlib import Path - from app.db import Podcast + from app.db import Podcast, PodcastStatus result = await session.execute(select(Podcast).filter(Podcast.id == podcast_id)) original = result.scalars().first() - if not original: + if not original or original.status != PodcastStatus.READY: return None new_file_path = None @@ -381,6 +377,7 @@ async def _clone_podcast( title=original.title, podcast_transcript=original.podcast_transcript, file_location=new_file_path, + status=PodcastStatus.READY, search_space_id=target_search_space_id, thread_id=target_thread_id, ) diff --git a/surfsense_backend/app/tasks/celery_tasks/podcast_tasks.py b/surfsense_backend/app/tasks/celery_tasks/podcast_tasks.py index 862234b46..0ce714cdc 100644 --- a/surfsense_backend/app/tasks/celery_tasks/podcast_tasks.py +++ b/surfsense_backend/app/tasks/celery_tasks/podcast_tasks.py @@ -4,15 +4,15 @@ import asyncio import logging import sys +from sqlalchemy import select from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine from sqlalchemy.pool import NullPool -# Import for content-based podcast (new-chat) from app.agents.podcaster.graph import graph as podcaster_graph from app.agents.podcaster.state import State as PodcasterState from app.celery_app 
def _clear_generating_podcast(search_space_id: int) -> None:
    """Clear the generating podcast marker from Redis when task completes.

    Best-effort: Redis being unavailable must never fail the task itself.
    """
    import os

    import redis

    try:
        redis_url = os.getenv("CELERY_BROKER_URL", "redis://localhost:6379/0")
        client = redis.from_url(redis_url, decode_responses=True)
        # Must match the key written by app.agents.new_chat.tools.podcast.
        key = f"podcast:generating:{search_space_id}"
        client.delete(key)
        logger.info(f"Cleared generating podcast key for search_space_id={search_space_id}")
    except Exception as e:
        logger.warning(f"Could not clear generating podcast key: {e}")


@celery_app.task(name="generate_content_podcast", bind=True)
def generate_content_podcast_task(
    self,
    podcast_id: int,
    source_content: str,
    search_space_id: int,
    user_prompt: str | None = None,
) -> dict:
    """
    Celery task to generate podcast from source content.
    Updates existing podcast record created by the tool.

    Args:
        podcast_id: ID of the pre-created podcast row to update
        source_content: The text content to convert into a podcast
        search_space_id: ID of the search space
        user_prompt: Optional instructions for podcast style/tone

    Returns:
        dict with status "ready" and podcast metadata on success,
        or {"status": "failed", "podcast_id": ...} on failure.
    """
    # Celery workers are synchronous; run the async pipeline on a private loop.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)

    try:
        result = loop.run_until_complete(
            _generate_content_podcast(
                podcast_id,
                source_content,
                search_space_id,
                user_prompt,
            )
        )
        loop.run_until_complete(loop.shutdown_asyncgens())
        return result
    except Exception as e:
        logger.error(f"Error generating content podcast: {e!s}")
        # FIX: guard the failure marker — if marking failed also raises
        # (e.g. DB down), still return a failed result instead of crashing
        # out of the except handler.
        try:
            loop.run_until_complete(_mark_podcast_failed(podcast_id))
        except Exception:
            logger.exception("Could not mark podcast %s as failed", podcast_id)
        return {"status": "failed", "podcast_id": podcast_id}
    finally:
        # Always release the Redis "generating" marker and close the loop,
        # on success or failure alike.
        _clear_generating_podcast(search_space_id)
        asyncio.set_event_loop(None)
        loop.close()


async def _mark_podcast_failed(podcast_id: int) -> None:
    """Mark a podcast as failed in the database (best-effort, own session)."""
    async with get_celery_session_maker()() as session:
        try:
            result = await session.execute(
                select(Podcast).filter(Podcast.id == podcast_id)
            )
            podcast = result.scalars().first()
            if podcast:
                podcast.status = PodcastStatus.FAILED
                await session.commit()
        except Exception as e:
            logger.error(f"Failed to mark podcast as failed: {e}")


async def _generate_content_podcast(
    podcast_id: int,
    source_content: str,
    search_space_id: int,
    user_prompt: str | None = None,
) -> dict:
    """Generate content-based podcast and update existing record."""
    async with get_celery_session_maker()() as session:
        result = await session.execute(
            select(Podcast).filter(Podcast.id == podcast_id)
        )
        podcast = result.scalars().first()

        if not podcast:
            raise ValueError(f"Podcast {podcast_id} not found")

        try:
            # Flip to GENERATING immediately so pollers see progress.
            podcast.status = PodcastStatus.GENERATING
            await session.commit()

            graph_config = {
                "configurable": {
                    "podcast_title": podcast.title,
                    "search_space_id": search_space_id,
                    "user_prompt": user_prompt,
                }
            }

            initial_state = PodcasterState(
                source_content=source_content,
                db_session=session,
            )

            graph_result = await podcaster_graph.ainvoke(
                initial_state, config=graph_config
            )

            podcast_transcript = graph_result.get("podcast_transcript", [])
            file_path = graph_result.get("final_podcast_file_path", "")

            # Convert transcript entries to plain JSON-serializable dicts.
            serializable_transcript = []
            for entry in podcast_transcript:
                # NOTE(review): the exact entry serialization is outside the
                # visible hunk — this mirrors the apparent shape (speaker_id
                # plus dialog text); confirm field names against the original.
                if hasattr(entry, "speaker_id"):
                    serializable_transcript.append(
                        {
                            "speaker_id": entry.speaker_id,
                            "dialog": getattr(entry, "dialog", ""),
                        }
                    )

            podcast.podcast_transcript = serializable_transcript
            podcast.file_location = file_path
            podcast.status = PodcastStatus.READY
            await session.commit()

            logger.info(f"Successfully generated podcast: {podcast.id}")

            return {
                "status": "ready",
                "podcast_id": podcast.id,
                "title": podcast.title,
                "transcript_entries": len(serializable_transcript),
            }
        except Exception as e:
            logger.error(f"Error in _generate_content_podcast: {e!s}")
            # FIX: roll back first — the session may hold a failed
            # transaction, in which case commit() would raise here and mask
            # the original error. Rollback does not undo the earlier
            # committed GENERATING flip.
            await session.rollback()
            podcast.status = PodcastStatus.FAILED
            await session.commit()
            raise