mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-04-25 16:56:22 +02:00
183 lines
6.1 KiB
Python
183 lines
6.1 KiB
Python
"""Celery tasks for podcast generation."""
|
|
|
|
import asyncio
|
|
import logging
|
|
import sys
|
|
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
|
from sqlalchemy.pool import NullPool
|
|
|
|
from app.agents.podcaster.graph import graph as podcaster_graph
|
|
from app.agents.podcaster.state import State as PodcasterState
|
|
from app.celery_app import celery_app
|
|
from app.config import config
|
|
from app.db import Podcast, PodcastStatus
|
|
|
|
logger = logging.getLogger(__name__)

# On Windows, asyncio's default (selector) event loop cannot run async
# subprocesses; the proactor policy is required for that support. Guarded
# with AttributeError because the policy class is not present on every
# platform/Python build.
if sys.platform.startswith("win"):
    try:
        asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy())
    except AttributeError:
        logger.warning(
            "WindowsProactorEventLoopPolicy is unavailable; async subprocess support may fail."
        )
|
|
|
|
|
|
def get_celery_session_maker():
    """Build a fresh async session factory for use inside Celery tasks.

    Celery tasks each run in their own event loop, so the application's
    default session maker (bound to the main loop) cannot be reused here.

    Returns:
        An ``async_sessionmaker`` backed by a brand-new engine.
    """
    # NullPool: hand out a new connection per session instead of pooling,
    # so connections are never shared across task event loops.
    task_engine = create_async_engine(
        config.DATABASE_URL,
        poolclass=NullPool,
        echo=False,
    )
    return async_sessionmaker(task_engine, expire_on_commit=False)
|
|
|
|
|
|
# =============================================================================
|
|
# Content-based podcast generation (for new-chat)
|
|
# =============================================================================
|
|
|
|
|
|
def _clear_generating_podcast(search_space_id: int) -> None:
    """Clear the generating podcast marker from Redis when task completes.

    Best-effort cleanup: every failure (including a missing redis package)
    is logged and swallowed so it never masks the task's own outcome.

    Args:
        search_space_id: Search space whose "generating" marker is removed.
    """
    client = None
    try:
        # Imported lazily so module import does not require redis.
        import redis

        client = redis.from_url(config.REDIS_APP_URL, decode_responses=True)
        key = f"podcast:generating:{search_space_id}"
        client.delete(key)
        logger.info(
            f"Cleared generating podcast key for search_space_id={search_space_id}"
        )
    except Exception as e:
        logger.warning(f"Could not clear generating podcast key: {e}")
    finally:
        # Release the client's connection pool; otherwise long-running
        # workers accumulate one idle connection per finished task.
        if client is not None:
            try:
                client.close()
            except Exception:
                pass
|
|
|
|
|
|
@celery_app.task(name="generate_content_podcast", bind=True)
def generate_content_podcast_task(
    self,  # bound Celery task instance (bind=True); unused but required
    podcast_id: int,
    source_content: str,
    search_space_id: int,
    user_prompt: str | None = None,
) -> dict:
    """
    Celery task to generate podcast from source content.

    Updates existing podcast record created by the tool.

    Args:
        podcast_id: Primary key of the pre-created Podcast row to update.
        source_content: Raw content the podcast is generated from.
        search_space_id: Search space the podcast belongs to.
        user_prompt: Optional extra instructions forwarded to the graph.

    Returns:
        ``{"status": "ready", ...}`` on success, or
        ``{"status": "failed", "podcast_id": ...}`` on error.
    """
    # A dedicated loop per task: Celery workers have no running loop, and
    # reusing the app's loop across processes is unsafe.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)

    try:
        return loop.run_until_complete(
            _generate_content_podcast(
                podcast_id,
                source_content,
                search_space_id,
                user_prompt,
            )
        )
    except Exception as e:
        logger.error(f"Error generating content podcast: {e!s}")
        # Marking FAILED is itself best-effort: a second DB error here must
        # not mask the original failure or skip the structured return value.
        try:
            loop.run_until_complete(_mark_podcast_failed(podcast_id))
        except Exception as mark_error:
            logger.error(
                f"Could not mark podcast {podcast_id} as failed: {mark_error}"
            )
        return {"status": "failed", "podcast_id": podcast_id}
    finally:
        _clear_generating_podcast(search_space_id)
        # Shut down async generators on BOTH success and failure paths
        # (the original only did so on success) before closing the loop.
        try:
            loop.run_until_complete(loop.shutdown_asyncgens())
        except Exception:
            logger.warning("Failed to shut down async generators cleanly.")
        asyncio.set_event_loop(None)
        loop.close()
|
|
|
|
|
|
async def _mark_podcast_failed(podcast_id: int) -> None:
    """Set the given podcast's status to FAILED, swallowing DB errors."""
    session_maker = get_celery_session_maker()
    async with session_maker() as session:
        try:
            query = select(Podcast).filter(Podcast.id == podcast_id)
            record = (await session.execute(query)).scalars().first()
            if record is None:
                # Nothing to update; row may have been deleted meanwhile.
                return
            record.status = PodcastStatus.FAILED
            await session.commit()
        except Exception as e:
            # Best-effort: failure to record FAILED must not raise further.
            logger.error(f"Failed to mark podcast as failed: {e}")
|
|
|
|
|
|
def _serialize_transcript(podcast_transcript) -> list[dict]:
    """Normalize transcript entries (attr objects or dicts) into JSON-safe dicts.

    Each output item has the shape ``{"speaker_id": int, "dialog": str}``.
    """
    serializable_transcript = []
    for entry in podcast_transcript:
        if hasattr(entry, "speaker_id"):
            # Pydantic/dataclass-style entry with attributes.
            serializable_transcript.append(
                {"speaker_id": entry.speaker_id, "dialog": entry.dialog}
            )
        else:
            # Plain dict entry; fall back to defaults for missing keys.
            serializable_transcript.append(
                {
                    "speaker_id": entry.get("speaker_id", 0),
                    "dialog": entry.get("dialog", ""),
                }
            )
    return serializable_transcript


async def _generate_content_podcast(
    podcast_id: int,
    source_content: str,
    search_space_id: int,
    user_prompt: str | None = None,
) -> dict:
    """Generate content-based podcast and update existing record.

    Args:
        podcast_id: Primary key of the pre-created Podcast row.
        source_content: Raw content fed to the podcaster graph.
        search_space_id: Search space the podcast belongs to.
        user_prompt: Optional extra instructions for the graph.

    Returns:
        Summary dict with status, podcast id, title and transcript length.

    Raises:
        ValueError: If no podcast row exists for ``podcast_id``.
        Exception: Re-raises any generation error after marking FAILED.
    """
    async with get_celery_session_maker()() as session:
        result = await session.execute(select(Podcast).filter(Podcast.id == podcast_id))
        podcast = result.scalars().first()

        if not podcast:
            raise ValueError(f"Podcast {podcast_id} not found")

        try:
            # Mark as in-progress before the (long) generation step.
            podcast.status = PodcastStatus.GENERATING
            await session.commit()

            graph_config = {
                "configurable": {
                    "podcast_title": podcast.title,
                    "search_space_id": search_space_id,
                    "user_prompt": user_prompt,
                }
            }

            initial_state = PodcasterState(
                source_content=source_content,
                db_session=session,
            )

            graph_result = await podcaster_graph.ainvoke(
                initial_state, config=graph_config
            )

            serializable_transcript = _serialize_transcript(
                graph_result.get("podcast_transcript", [])
            )

            podcast.podcast_transcript = serializable_transcript
            podcast.file_location = graph_result.get("final_podcast_file_path", "")
            podcast.status = PodcastStatus.READY
            await session.commit()

            logger.info(f"Successfully generated podcast: {podcast.id}")

            return {
                "status": "ready",
                "podcast_id": podcast.id,
                "title": podcast.title,
                "transcript_entries": len(serializable_transcript),
            }

        except Exception as e:
            logger.error(f"Error in _generate_content_podcast: {e!s}")
            # Roll back first: if the failure came from a flush/commit, the
            # session's transaction is invalid and a direct commit would
            # raise again, losing the FAILED status update.
            await session.rollback()
            podcast.status = PodcastStatus.FAILED
            await session.commit()
            raise