feat: add podcast status tracking

This commit is contained in:
CREDO23 2026-01-27 17:51:36 +02:00
parent c65cda24d7
commit 87c7d92672
7 changed files with 165 additions and 193 deletions

View file

@ -1,9 +1,10 @@
"""Add thread_id to podcasts
"""Add status and thread_id to podcasts
Revision ID: 82
Revises: 81
Create Date: 2026-01-23
Create Date: 2026-01-27
Adds status enum and thread_id FK to podcasts.
"""
from collections.abc import Sequence
@ -17,7 +18,19 @@ depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
"""Add thread_id column to podcasts."""
op.execute(
"""
CREATE TYPE podcast_status AS ENUM ('pending', 'generating', 'ready', 'failed');
"""
)
op.execute(
"""
ALTER TABLE podcasts
ADD COLUMN IF NOT EXISTS status podcast_status NOT NULL DEFAULT 'ready';
"""
)
op.execute(
"""
ALTER TABLE podcasts
@ -33,8 +46,17 @@ def upgrade() -> None:
"""
)
op.execute(
"""
CREATE INDEX IF NOT EXISTS ix_podcasts_status
ON podcasts(status);
"""
)
def downgrade() -> None:
"""Remove thread_id column from podcasts."""
op.execute("DROP INDEX IF EXISTS ix_podcasts_status")
op.execute("DROP INDEX IF EXISTS ix_podcasts_thread_id")
op.execute("ALTER TABLE podcasts DROP COLUMN IF EXISTS thread_id")
op.execute("ALTER TABLE podcasts DROP COLUMN IF EXISTS status")
op.execute("DROP TYPE IF EXISTS podcast_status")

View file

@ -18,6 +18,8 @@ import redis
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import Podcast, PodcastStatus
# Redis connection for tracking active podcast tasks
# Uses the same Redis instance as Celery
REDIS_URL = os.getenv("CELERY_BROKER_URL", "redis://localhost:6379/0")
@ -32,38 +34,27 @@ def get_redis_client() -> redis.Redis:
return _redis_client
def get_active_podcast_key(search_space_id: int) -> str:
"""Generate Redis key for tracking active podcast task."""
return f"podcast:active:{search_space_id}"
def _redis_key(search_space_id: int) -> str:
return f"podcast:generating:{search_space_id}"
def get_active_podcast_task(search_space_id: int) -> str | None:
"""Check if there's an active podcast task for this search space."""
def get_generating_podcast_id(search_space_id: int) -> int | None:
"""Get the podcast ID currently being generated for this search space."""
try:
client = get_redis_client()
return client.get(get_active_podcast_key(search_space_id))
value = client.get(_redis_key(search_space_id))
return int(value) if value else None
except Exception:
# If Redis is unavailable, allow the request (fail open)
return None
def set_active_podcast_task(search_space_id: int, task_id: str) -> None:
"""Mark a podcast task as active for this search space."""
def set_generating_podcast(search_space_id: int, podcast_id: int) -> None:
"""Mark a podcast as currently generating for this search space."""
try:
client = get_redis_client()
# Set with 30-minute expiry as safety net (podcast should complete before this)
client.setex(get_active_podcast_key(search_space_id), 1800, task_id)
client.setex(_redis_key(search_space_id), 1800, str(podcast_id))
except Exception as e:
print(f"[generate_podcast] Warning: Could not set active task in Redis: {e}")
def clear_active_podcast_task(search_space_id: int) -> None:
"""Clear the active podcast task for this search space."""
try:
client = get_redis_client()
client.delete(get_active_podcast_key(search_space_id))
except Exception as e:
print(f"[generate_podcast] Warning: Could not clear active task in Redis: {e}")
print(f"[generate_podcast] Warning: Could not set generating podcast in Redis: {e}")
def create_generate_podcast_tool(
@ -74,9 +65,12 @@ def create_generate_podcast_tool(
"""
Factory function to create the generate_podcast tool with injected dependencies.
Pre-creates podcast record with pending status so podcast_id is available
immediately for frontend polling.
Args:
search_space_id: The user's search space ID
db_session: Database session (not used - Celery creates its own)
db_session: Database session for creating the podcast record
thread_id: The chat thread ID for associating the podcast
Returns:
@ -100,77 +94,71 @@ def create_generate_podcast_tool(
- "Make a podcast about..."
- "Turn this into a podcast"
The tool will start generating a podcast in the background.
The podcast will be available once generation completes.
IMPORTANT: Only one podcast can be generated at a time. If a podcast
is already being generated, this tool will return a message asking
the user to wait.
Args:
source_content: The text content to convert into a podcast.
This can be a summary, research findings, or any text
the user wants transformed into an audio podcast.
podcast_title: Title for the podcast (default: "SurfSense Podcast")
user_prompt: Optional instructions for podcast style, tone, or format.
For example: "Make it casual and fun" or "Focus on the key insights"
Returns:
A dictionary containing:
- status: "processing" (task submitted), "already_generating", or "error"
- task_id: The Celery task ID for polling status (if processing)
- status: PodcastStatus value (pending, generating, or failed)
- podcast_id: The podcast ID for polling (when status is pending or generating)
- title: The podcast title
- message: Status message for the user
- message: Status message (or "error" field if status is failed)
"""
try:
# Check if a podcast is already being generated for this search space
active_task_id = get_active_podcast_task(search_space_id)
if active_task_id:
generating_podcast_id = get_generating_podcast_id(search_space_id)
if generating_podcast_id:
print(
f"[generate_podcast] Blocked duplicate request. Active task: {active_task_id}"
f"[generate_podcast] Blocked duplicate request. Generating podcast: {generating_podcast_id}"
)
return {
"status": "already_generating",
"task_id": active_task_id,
"status": PodcastStatus.GENERATING.value,
"podcast_id": generating_podcast_id,
"title": podcast_title,
"message": "A podcast is already being generated. Please wait for it to complete before requesting another one.",
"message": "A podcast is already being generated. Please wait for it to complete.",
}
# Import Celery task here to avoid circular imports
podcast = Podcast(
title=podcast_title,
status=PodcastStatus.PENDING,
search_space_id=search_space_id,
thread_id=thread_id,
)
db_session.add(podcast)
await db_session.commit()
await db_session.refresh(podcast)
from app.tasks.celery_tasks.podcast_tasks import (
generate_content_podcast_task,
)
# Submit Celery task for background processing
task = generate_content_podcast_task.delay(
podcast_id=podcast.id,
source_content=source_content,
search_space_id=search_space_id,
podcast_title=podcast_title,
user_prompt=user_prompt,
thread_id=thread_id,
)
# Mark this task as active
set_active_podcast_task(search_space_id, task.id)
set_generating_podcast(search_space_id, podcast.id)
print(f"[generate_podcast] Submitted Celery task: {task.id}")
print(f"[generate_podcast] Created podcast {podcast.id}, task: {task.id}")
# Return immediately with podcast_id for polling
return {
"status": "processing",
"task_id": task.id,
"status": PodcastStatus.PENDING.value,
"podcast_id": podcast.id,
"title": podcast_title,
"message": "Podcast generation started. This may take a few minutes.",
}
except Exception as e:
error_message = str(e)
print(f"[generate_podcast] Error submitting task: {error_message}")
print(f"[generate_podcast] Error: {error_message}")
return {
"status": "error",
"status": PodcastStatus.FAILED.value,
"error": error_message,
"title": podcast_title,
"task_id": None,
"podcast_id": None,
}
return generate_podcast

View file

@ -93,6 +93,13 @@ class SearchSourceConnectorType(str, Enum):
COMPOSIO_GOOGLE_CALENDAR_CONNECTOR = "COMPOSIO_GOOGLE_CALENDAR_CONNECTOR"
class PodcastStatus(str, Enum):
PENDING = "pending"
GENERATING = "generating"
READY = "ready"
FAILED = "failed"
class LiteLLMProvider(str, Enum):
"""
Enum for LLM providers supported by LiteLLM.
@ -743,8 +750,15 @@ class Podcast(BaseModel, TimestampMixin):
__tablename__ = "podcasts"
title = Column(String(500), nullable=False)
podcast_transcript = Column(JSONB, nullable=True) # List of transcript entries
file_location = Column(Text, nullable=True) # Path to the audio file
podcast_transcript = Column(JSONB, nullable=True)
file_location = Column(Text, nullable=True)
status = Column(
SQLAlchemyEnum(PodcastStatus, name="podcast_status", create_type=False),
nullable=False,
default=PodcastStatus.READY,
server_default="ready",
index=True,
)
search_space_id = Column(
Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False

View file

@ -1,21 +1,19 @@
"""
Podcast routes for task status polling and audio retrieval.
Podcast routes for CRUD operations and audio streaming.
These routes support the podcast generation feature in new-chat.
Note: The old Chat-based podcast generation has been removed.
Frontend polls GET /podcasts/{podcast_id} to check status field.
"""
import os
from pathlib import Path
from celery.result import AsyncResult
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import StreamingResponse
from sqlalchemy import select
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
from app.celery_app import celery_app
from app.db import (
Permission,
Podcast,
@ -228,62 +226,3 @@ async def stream_podcast(
raise HTTPException(
status_code=500, detail=f"Error streaming podcast: {e!s}"
) from e
@router.get("/podcasts/task/{task_id}/status")
async def get_podcast_task_status(
task_id: str,
user: User = Depends(current_active_user),
):
"""
Get the status of a podcast generation task.
Used by new-chat frontend to poll for completion.
Returns:
- status: "processing" | "success" | "error"
- podcast_id: (only if status == "success")
- title: (only if status == "success")
- error: (only if status == "error")
"""
try:
result = AsyncResult(task_id, app=celery_app)
if result.ready():
# Task completed
if result.successful():
task_result = result.result
if isinstance(task_result, dict):
if task_result.get("status") == "success":
return {
"status": "success",
"podcast_id": task_result.get("podcast_id"),
"title": task_result.get("title"),
"transcript_entries": task_result.get("transcript_entries"),
}
else:
return {
"status": "error",
"error": task_result.get("error", "Unknown error"),
}
else:
return {
"status": "error",
"error": "Unexpected task result format",
}
else:
# Task failed
return {
"status": "error",
"error": str(result.result) if result.result else "Task failed",
}
else:
# Task still processing
return {
"status": "processing",
"state": result.state,
}
except Exception as e:
raise HTTPException(
status_code=500, detail=f"Error checking task status: {e!s}"
) from e

View file

@ -1,11 +1,19 @@
"""Podcast schemas for API responses."""
from datetime import datetime
from enum import Enum
from typing import Any
from pydantic import BaseModel
class PodcastStatusEnum(str, Enum):
PENDING = "pending"
GENERATING = "generating"
READY = "ready"
FAILED = "failed"
class PodcastBase(BaseModel):
"""Base podcast schema."""
@ -33,6 +41,7 @@ class PodcastRead(PodcastBase):
"""Schema for reading a podcast."""
id: int
status: PodcastStatusEnum = PodcastStatusEnum.READY
created_at: datetime
class Config:

View file

@ -40,7 +40,10 @@ def strip_citations(text: str) -> str:
def sanitize_content_for_public(content: list | str | None) -> list:
"""Filter message content for public view."""
"""
Filter message content for public view.
Strips citations and filters to UI-relevant tools.
"""
if content is None:
return []
@ -67,13 +70,6 @@ def sanitize_content_for_public(content: list | str | None) -> list:
tool_name = part.get("toolName")
if tool_name not in UI_TOOLS:
continue
# Skip podcasts that are still processing (would cause auth errors)
if tool_name == "generate_podcast":
result = part.get("result", {})
if result.get("status") in ("processing", "already_generating"):
continue
sanitized.append(part)
return sanitized
@ -355,16 +351,16 @@ async def _clone_podcast(
target_search_space_id: int,
target_thread_id: int,
) -> int | None:
"""Clone a podcast record and its audio file."""
"""Clone a podcast record and its audio file. Only clones ready podcasts."""
import shutil
import uuid
from pathlib import Path
from app.db import Podcast
from app.db import Podcast, PodcastStatus
result = await session.execute(select(Podcast).filter(Podcast.id == podcast_id))
original = result.scalars().first()
if not original:
if not original or original.status != PodcastStatus.READY:
return None
new_file_path = None
@ -381,6 +377,7 @@ async def _clone_podcast(
title=original.title,
podcast_transcript=original.podcast_transcript,
file_location=new_file_path,
status=PodcastStatus.READY,
search_space_id=target_search_space_id,
thread_id=target_thread_id,
)

View file

@ -4,15 +4,15 @@ import asyncio
import logging
import sys
from sqlalchemy import select
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.pool import NullPool
# Import for content-based podcast (new-chat)
from app.agents.podcaster.graph import graph as podcaster_graph
from app.agents.podcaster.state import State as PodcasterState
from app.celery_app import celery_app
from app.config import config
from app.db import Podcast
from app.db import Podcast, PodcastStatus
logger = logging.getLogger(__name__)
@ -44,8 +44,8 @@ def get_celery_session_maker():
# =============================================================================
def _clear_active_podcast_redis_key(search_space_id: int) -> None:
"""Clear the active podcast task key from Redis when task completes."""
def _clear_generating_podcast(search_space_id: int) -> None:
"""Clear the generating podcast marker from Redis when task completes."""
import os
import redis
@ -53,36 +53,24 @@ def _clear_active_podcast_redis_key(search_space_id: int) -> None:
try:
redis_url = os.getenv("CELERY_BROKER_URL", "redis://localhost:6379/0")
client = redis.from_url(redis_url, decode_responses=True)
key = f"podcast:active:{search_space_id}"
key = f"podcast:generating:{search_space_id}"
client.delete(key)
logger.info(f"Cleared active podcast key for search_space_id={search_space_id}")
logger.info(f"Cleared generating podcast key for search_space_id={search_space_id}")
except Exception as e:
logger.warning(f"Could not clear active podcast key: {e}")
logger.warning(f"Could not clear generating podcast key: {e}")
@celery_app.task(name="generate_content_podcast", bind=True)
def generate_content_podcast_task(
self,
podcast_id: int,
source_content: str,
search_space_id: int,
podcast_title: str = "SurfSense Podcast",
user_prompt: str | None = None,
thread_id: int | None = None,
) -> dict:
"""
Celery task to generate podcast from source content (for new-chat).
This task generates a podcast directly from provided content.
Args:
source_content: The text content to convert into a podcast
search_space_id: ID of the search space
podcast_title: Title for the podcast
user_prompt: Optional instructions for podcast style/tone
thread_id: Optional ID of the chat thread that generated this podcast
Returns:
dict with podcast_id on success, or error info on failure
Celery task to generate podcast from source content.
Updates existing podcast record created by the tool.
"""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
@ -90,58 +78,79 @@ def generate_content_podcast_task(
try:
result = loop.run_until_complete(
_generate_content_podcast(
podcast_id,
source_content,
search_space_id,
podcast_title,
user_prompt,
thread_id,
)
)
loop.run_until_complete(loop.shutdown_asyncgens())
return result
except Exception as e:
logger.error(f"Error generating content podcast: {e!s}")
return {"status": "error", "error": str(e)}
loop.run_until_complete(_mark_podcast_failed(podcast_id))
return {"status": "failed", "podcast_id": podcast_id}
finally:
# Always clear the active podcast key when task completes (success or failure)
_clear_active_podcast_redis_key(search_space_id)
_clear_generating_podcast(search_space_id)
asyncio.set_event_loop(None)
loop.close()
async def _generate_content_podcast(
source_content: str,
search_space_id: int,
podcast_title: str = "SurfSense Podcast",
user_prompt: str | None = None,
thread_id: int | None = None,
) -> dict:
"""Generate content-based podcast with new session."""
async def _mark_podcast_failed(podcast_id: int) -> None:
"""Mark a podcast as failed in the database."""
async with get_celery_session_maker()() as session:
try:
# Configure the podcaster graph
result = await session.execute(
select(Podcast).filter(Podcast.id == podcast_id)
)
podcast = result.scalars().first()
if podcast:
podcast.status = PodcastStatus.FAILED
await session.commit()
except Exception as e:
logger.error(f"Failed to mark podcast as failed: {e}")
async def _generate_content_podcast(
podcast_id: int,
source_content: str,
search_space_id: int,
user_prompt: str | None = None,
) -> dict:
"""Generate content-based podcast and update existing record."""
async with get_celery_session_maker()() as session:
result = await session.execute(
select(Podcast).filter(Podcast.id == podcast_id)
)
podcast = result.scalars().first()
if not podcast:
raise ValueError(f"Podcast {podcast_id} not found")
try:
podcast.status = PodcastStatus.GENERATING
await session.commit()
graph_config = {
"configurable": {
"podcast_title": podcast_title,
"podcast_title": podcast.title,
"search_space_id": search_space_id,
"user_prompt": user_prompt,
}
}
# Initialize the podcaster state with the source content
initial_state = PodcasterState(
source_content=source_content,
db_session=session,
)
# Run the podcaster graph
result = await podcaster_graph.ainvoke(initial_state, config=graph_config)
graph_result = await podcaster_graph.ainvoke(
initial_state, config=graph_config
)
# Extract results
podcast_transcript = result.get("podcast_transcript", [])
file_path = result.get("final_podcast_file_path", "")
podcast_transcript = graph_result.get("podcast_transcript", [])
file_path = graph_result.get("final_podcast_file_path", "")
# Convert transcript to serializable format
serializable_transcript = []
for entry in podcast_transcript:
if hasattr(entry, "speaker_id"):
@ -156,28 +165,22 @@ async def _generate_content_podcast(
}
)
# Save podcast to database
podcast = Podcast(
title=podcast_title,
podcast_transcript=serializable_transcript,
file_location=file_path,
search_space_id=search_space_id,
thread_id=thread_id,
)
session.add(podcast)
podcast.podcast_transcript = serializable_transcript
podcast.file_location = file_path
podcast.status = PodcastStatus.READY
await session.commit()
await session.refresh(podcast)
logger.info(f"Successfully generated content podcast: {podcast.id}")
logger.info(f"Successfully generated podcast: {podcast.id}")
return {
"status": "success",
"status": "ready",
"podcast_id": podcast.id,
"title": podcast_title,
"title": podcast.title,
"transcript_entries": len(serializable_transcript),
}
except Exception as e:
logger.error(f"Error in _generate_content_podcast: {e!s}")
await session.rollback()
podcast.status = PodcastStatus.FAILED
await session.commit()
raise