SurfSense/surfsense_backend/app/tasks/celery_tasks/podcast_tasks.py
DESKTOP-RTLN3BA\$punk b14283e300 feat: refactor new chat agent to support configurable tools and remove deprecated components
- Enhanced the new chat agent module to allow for configurable tools, enabling users to customize their experience with various functionalities.
- Removed outdated tools including display image, knowledge base search, link preview, podcast generation, and web scraping, streamlining the codebase.
- Updated the system prompt and agent factory to reflect these changes, ensuring a more cohesive and efficient architecture.
2025-12-22 20:17:08 -08:00

178 lines
6 KiB
Python

"""Celery tasks for podcast generation."""
import asyncio
import logging
import sys
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.pool import NullPool
# Import for content-based podcast (new-chat)
from app.agents.podcaster.graph import graph as podcaster_graph
from app.agents.podcaster.state import State as PodcasterState
from app.celery_app import celery_app
from app.config import config
from app.db import Podcast
logger = logging.getLogger(__name__)
if sys.platform.startswith("win"):
try:
asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy())
except AttributeError:
logger.warning(
"WindowsProactorEventLoopPolicy is unavailable; async subprocess support may fail."
)
def get_celery_session_maker():
"""
Create a new async session maker for Celery tasks.
This is necessary because Celery tasks run in a new event loop,
and the default session maker is bound to the main app's event loop.
"""
engine = create_async_engine(
config.DATABASE_URL,
poolclass=NullPool, # Don't use connection pooling for Celery tasks
echo=False,
)
return async_sessionmaker(engine, expire_on_commit=False)
# =============================================================================
# Content-based podcast generation (for new-chat)
# =============================================================================
def _clear_active_podcast_redis_key(search_space_id: int) -> None:
"""Clear the active podcast task key from Redis when task completes."""
import os
import redis
try:
redis_url = os.getenv("CELERY_BROKER_URL", "redis://localhost:6379/0")
client = redis.from_url(redis_url, decode_responses=True)
key = f"podcast:active:{search_space_id}"
client.delete(key)
logger.info(f"Cleared active podcast key for search_space_id={search_space_id}")
except Exception as e:
logger.warning(f"Could not clear active podcast key: {e}")
@celery_app.task(name="generate_content_podcast", bind=True)
def generate_content_podcast_task(
self,
source_content: str,
search_space_id: int,
podcast_title: str = "SurfSense Podcast",
user_prompt: str | None = None,
) -> dict:
"""
Celery task to generate podcast from source content (for new-chat).
This task generates a podcast directly from provided content.
Args:
source_content: The text content to convert into a podcast
search_space_id: ID of the search space
podcast_title: Title for the podcast
user_prompt: Optional instructions for podcast style/tone
Returns:
dict with podcast_id on success, or error info on failure
"""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
result = loop.run_until_complete(
_generate_content_podcast(
source_content,
search_space_id,
podcast_title,
user_prompt,
)
)
loop.run_until_complete(loop.shutdown_asyncgens())
return result
except Exception as e:
logger.error(f"Error generating content podcast: {e!s}")
return {"status": "error", "error": str(e)}
finally:
# Always clear the active podcast key when task completes (success or failure)
_clear_active_podcast_redis_key(search_space_id)
asyncio.set_event_loop(None)
loop.close()
async def _generate_content_podcast(
source_content: str,
search_space_id: int,
podcast_title: str = "SurfSense Podcast",
user_prompt: str | None = None,
) -> dict:
"""Generate content-based podcast with new session."""
async with get_celery_session_maker()() as session:
try:
# Configure the podcaster graph
graph_config = {
"configurable": {
"podcast_title": podcast_title,
"search_space_id": search_space_id,
"user_prompt": user_prompt,
}
}
# Initialize the podcaster state with the source content
initial_state = PodcasterState(
source_content=source_content,
db_session=session,
)
# Run the podcaster graph
result = await podcaster_graph.ainvoke(initial_state, config=graph_config)
# Extract results
podcast_transcript = result.get("podcast_transcript", [])
file_path = result.get("final_podcast_file_path", "")
# Convert transcript to serializable format
serializable_transcript = []
for entry in podcast_transcript:
if hasattr(entry, "speaker_id"):
serializable_transcript.append(
{"speaker_id": entry.speaker_id, "dialog": entry.dialog}
)
else:
serializable_transcript.append(
{
"speaker_id": entry.get("speaker_id", 0),
"dialog": entry.get("dialog", ""),
}
)
# Save podcast to database
podcast = Podcast(
title=podcast_title,
podcast_transcript=serializable_transcript,
file_location=file_path,
search_space_id=search_space_id,
)
session.add(podcast)
await session.commit()
await session.refresh(podcast)
logger.info(f"Successfully generated content podcast: {podcast.id}")
return {
"status": "success",
"podcast_id": podcast.id,
"title": podcast_title,
"transcript_entries": len(serializable_transcript),
}
except Exception as e:
logger.error(f"Error in _generate_content_podcast: {e!s}")
await session.rollback()
raise