mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-03 21:02:40 +02:00
feat: add podcast status tracking
This commit is contained in:
parent
c65cda24d7
commit
87c7d92672
7 changed files with 165 additions and 193 deletions
|
|
@ -4,15 +4,15 @@ import asyncio
|
|||
import logging
|
||||
import sys
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.pool import NullPool
|
||||
|
||||
# Import for content-based podcast (new-chat)
|
||||
from app.agents.podcaster.graph import graph as podcaster_graph
|
||||
from app.agents.podcaster.state import State as PodcasterState
|
||||
from app.celery_app import celery_app
|
||||
from app.config import config
|
||||
from app.db import Podcast
|
||||
from app.db import Podcast, PodcastStatus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -44,8 +44,8 @@ def get_celery_session_maker():
|
|||
# =============================================================================
|
||||
|
||||
|
||||
def _clear_active_podcast_redis_key(search_space_id: int) -> None:
|
||||
"""Clear the active podcast task key from Redis when task completes."""
|
||||
def _clear_generating_podcast(search_space_id: int) -> None:
|
||||
"""Clear the generating podcast marker from Redis when task completes."""
|
||||
import os
|
||||
|
||||
import redis
|
||||
|
|
@ -53,36 +53,24 @@ def _clear_active_podcast_redis_key(search_space_id: int) -> None:
|
|||
try:
|
||||
redis_url = os.getenv("CELERY_BROKER_URL", "redis://localhost:6379/0")
|
||||
client = redis.from_url(redis_url, decode_responses=True)
|
||||
key = f"podcast:active:{search_space_id}"
|
||||
key = f"podcast:generating:{search_space_id}"
|
||||
client.delete(key)
|
||||
logger.info(f"Cleared active podcast key for search_space_id={search_space_id}")
|
||||
logger.info(f"Cleared generating podcast key for search_space_id={search_space_id}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not clear active podcast key: {e}")
|
||||
logger.warning(f"Could not clear generating podcast key: {e}")
|
||||
|
||||
|
||||
@celery_app.task(name="generate_content_podcast", bind=True)
|
||||
def generate_content_podcast_task(
|
||||
self,
|
||||
podcast_id: int,
|
||||
source_content: str,
|
||||
search_space_id: int,
|
||||
podcast_title: str = "SurfSense Podcast",
|
||||
user_prompt: str | None = None,
|
||||
thread_id: int | None = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Celery task to generate podcast from source content (for new-chat).
|
||||
|
||||
This task generates a podcast directly from provided content.
|
||||
|
||||
Args:
|
||||
source_content: The text content to convert into a podcast
|
||||
search_space_id: ID of the search space
|
||||
podcast_title: Title for the podcast
|
||||
user_prompt: Optional instructions for podcast style/tone
|
||||
thread_id: Optional ID of the chat thread that generated this podcast
|
||||
|
||||
Returns:
|
||||
dict with podcast_id on success, or error info on failure
|
||||
Celery task to generate podcast from source content.
|
||||
Updates existing podcast record created by the tool.
|
||||
"""
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
|
@ -90,58 +78,79 @@ def generate_content_podcast_task(
|
|||
try:
|
||||
result = loop.run_until_complete(
|
||||
_generate_content_podcast(
|
||||
podcast_id,
|
||||
source_content,
|
||||
search_space_id,
|
||||
podcast_title,
|
||||
user_prompt,
|
||||
thread_id,
|
||||
)
|
||||
)
|
||||
loop.run_until_complete(loop.shutdown_asyncgens())
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating content podcast: {e!s}")
|
||||
return {"status": "error", "error": str(e)}
|
||||
loop.run_until_complete(_mark_podcast_failed(podcast_id))
|
||||
return {"status": "failed", "podcast_id": podcast_id}
|
||||
finally:
|
||||
# Always clear the active podcast key when task completes (success or failure)
|
||||
_clear_active_podcast_redis_key(search_space_id)
|
||||
_clear_generating_podcast(search_space_id)
|
||||
asyncio.set_event_loop(None)
|
||||
loop.close()
|
||||
|
||||
|
||||
async def _generate_content_podcast(
|
||||
source_content: str,
|
||||
search_space_id: int,
|
||||
podcast_title: str = "SurfSense Podcast",
|
||||
user_prompt: str | None = None,
|
||||
thread_id: int | None = None,
|
||||
) -> dict:
|
||||
"""Generate content-based podcast with new session."""
|
||||
async def _mark_podcast_failed(podcast_id: int) -> None:
|
||||
"""Mark a podcast as failed in the database."""
|
||||
async with get_celery_session_maker()() as session:
|
||||
try:
|
||||
# Configure the podcaster graph
|
||||
result = await session.execute(
|
||||
select(Podcast).filter(Podcast.id == podcast_id)
|
||||
)
|
||||
podcast = result.scalars().first()
|
||||
if podcast:
|
||||
podcast.status = PodcastStatus.FAILED
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to mark podcast as failed: {e}")
|
||||
|
||||
|
||||
async def _generate_content_podcast(
|
||||
podcast_id: int,
|
||||
source_content: str,
|
||||
search_space_id: int,
|
||||
user_prompt: str | None = None,
|
||||
) -> dict:
|
||||
"""Generate content-based podcast and update existing record."""
|
||||
async with get_celery_session_maker()() as session:
|
||||
result = await session.execute(
|
||||
select(Podcast).filter(Podcast.id == podcast_id)
|
||||
)
|
||||
podcast = result.scalars().first()
|
||||
|
||||
if not podcast:
|
||||
raise ValueError(f"Podcast {podcast_id} not found")
|
||||
|
||||
try:
|
||||
podcast.status = PodcastStatus.GENERATING
|
||||
await session.commit()
|
||||
|
||||
graph_config = {
|
||||
"configurable": {
|
||||
"podcast_title": podcast_title,
|
||||
"podcast_title": podcast.title,
|
||||
"search_space_id": search_space_id,
|
||||
"user_prompt": user_prompt,
|
||||
}
|
||||
}
|
||||
|
||||
# Initialize the podcaster state with the source content
|
||||
initial_state = PodcasterState(
|
||||
source_content=source_content,
|
||||
db_session=session,
|
||||
)
|
||||
|
||||
# Run the podcaster graph
|
||||
result = await podcaster_graph.ainvoke(initial_state, config=graph_config)
|
||||
graph_result = await podcaster_graph.ainvoke(
|
||||
initial_state, config=graph_config
|
||||
)
|
||||
|
||||
# Extract results
|
||||
podcast_transcript = result.get("podcast_transcript", [])
|
||||
file_path = result.get("final_podcast_file_path", "")
|
||||
podcast_transcript = graph_result.get("podcast_transcript", [])
|
||||
file_path = graph_result.get("final_podcast_file_path", "")
|
||||
|
||||
# Convert transcript to serializable format
|
||||
serializable_transcript = []
|
||||
for entry in podcast_transcript:
|
||||
if hasattr(entry, "speaker_id"):
|
||||
|
|
@ -156,28 +165,22 @@ async def _generate_content_podcast(
|
|||
}
|
||||
)
|
||||
|
||||
# Save podcast to database
|
||||
podcast = Podcast(
|
||||
title=podcast_title,
|
||||
podcast_transcript=serializable_transcript,
|
||||
file_location=file_path,
|
||||
search_space_id=search_space_id,
|
||||
thread_id=thread_id,
|
||||
)
|
||||
session.add(podcast)
|
||||
podcast.podcast_transcript = serializable_transcript
|
||||
podcast.file_location = file_path
|
||||
podcast.status = PodcastStatus.READY
|
||||
await session.commit()
|
||||
await session.refresh(podcast)
|
||||
|
||||
logger.info(f"Successfully generated content podcast: {podcast.id}")
|
||||
logger.info(f"Successfully generated podcast: {podcast.id}")
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"status": "ready",
|
||||
"podcast_id": podcast.id,
|
||||
"title": podcast_title,
|
||||
"title": podcast.title,
|
||||
"transcript_entries": len(serializable_transcript),
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in _generate_content_podcast: {e!s}")
|
||||
await session.rollback()
|
||||
podcast.status = PodcastStatus.FAILED
|
||||
await session.commit()
|
||||
raise
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue