mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-04-25 16:56:22 +02:00
- Enhanced the new chat agent module to allow for configurable tools, enabling users to customize their experience with various functionalities. - Removed outdated tools including display image, knowledge base search, link preview, podcast generation, and web scraping, streamlining the codebase. - Updated the system prompt and agent factory to reflect these changes, ensuring a more cohesive and efficient architecture.
178 lines
6 KiB
Python
178 lines
6 KiB
Python
"""Celery tasks for podcast generation."""
|
|
|
|
import asyncio
|
|
import logging
|
|
import sys
|
|
|
|
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
|
from sqlalchemy.pool import NullPool
|
|
|
|
# Import for content-based podcast (new-chat)
|
|
from app.agents.podcaster.graph import graph as podcaster_graph
|
|
from app.agents.podcaster.state import State as PodcasterState
|
|
from app.celery_app import celery_app
|
|
from app.config import config
|
|
from app.db import Podcast
|
|
|
|
# Module-level logger used by every task and helper in this file.
logger = logging.getLogger(__name__)

# Evaluated once at import time; only Windows hosts change the loop policy.
_IS_WINDOWS = sys.platform.startswith("win")

if _IS_WINDOWS:
    # Install the Proactor policy so event loops created later in this process
    # (e.g. via asyncio.new_event_loop()) use it. If the policy class is not
    # available, keep the current policy and emit a warning instead.
    try:
        asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy())
    except AttributeError:
        logger.warning(
            "WindowsProactorEventLoopPolicy is unavailable; "
            "async subprocess support may fail."
        )
|
|
|
|
|
|
def get_celery_session_maker():
    """Build a fresh async session factory for use inside a Celery task.

    Celery tasks run in their own event loop, so they cannot reuse the
    application's default session maker, which is bound to the main app's
    event loop. Every call creates a brand-new engine on
    ``config.DATABASE_URL`` with ``NullPool`` (no connection pooling) and SQL
    echo turned off.

    Returns:
        An ``async_sessionmaker`` bound to that new engine, configured with
        ``expire_on_commit=False`` so instances are not expired after commit.
    """
    task_engine = create_async_engine(
        config.DATABASE_URL,
        echo=False,
        # No connection pooling for Celery tasks.
        poolclass=NullPool,
    )
    return async_sessionmaker(bind=task_engine, expire_on_commit=False)
|
|
|
|
|
|
# =============================================================================
|
|
# Content-based podcast generation (for new-chat)
|
|
# =============================================================================
|
|
|
|
|
|
def _clear_active_podcast_redis_key(search_space_id: int) -> None:
    """Best-effort removal of the active-podcast marker for a search space.

    Deletes the Redis key ``podcast:active:<search_space_id>`` on the Redis
    instance named by the ``CELERY_BROKER_URL`` environment variable
    (default ``redis://localhost:6379/0``). This is called from the podcast
    task's ``finally`` block, so it never raises. Every failure, including
    the ``redis`` package not being importable, is logged as a warning and
    swallowed.

    Args:
        search_space_id: Search space whose active-podcast key is cleared.
    """
    import os

    key = f"podcast:active:{search_space_id}"
    try:
        # Imported inside the try so an import failure is logged here rather
        # than escaping from the caller's ``finally`` block.
        import redis

        redis_url = os.getenv("CELERY_BROKER_URL", "redis://localhost:6379/0")
        client = redis.from_url(redis_url, decode_responses=True)
        try:
            client.delete(key)
        finally:
            # from_url() creates a connection pool owned by this client alone;
            # close its sockets now instead of leaking one pool per task run.
            client.connection_pool.disconnect()
        logger.info(f"Cleared active podcast key for search_space_id={search_space_id}")
    except Exception as e:
        logger.warning(f"Could not clear active podcast key: {e}")
|
|
|
|
|
|
@celery_app.task(name="generate_content_podcast", bind=True)
def generate_content_podcast_task(
    self,
    source_content: str,
    search_space_id: int,
    podcast_title: str = "SurfSense Podcast",
    user_prompt: str | None = None,
) -> dict:
    """
    Celery task to generate podcast from source content (for new-chat).

    This task generates a podcast directly from provided content. The async
    work runs to completion on a private event loop created for this call.
    That loop is always torn down afterwards: async generators are finalized,
    the loop is unset and then closed. The active-podcast Redis key for the
    search space is cleared whether the run succeeds or fails.

    Args:
        source_content: The text content to convert into a podcast
        search_space_id: ID of the search space
        podcast_title: Title for the podcast
        user_prompt: Optional instructions for podcast style/tone

    Returns:
        dict with podcast_id on success, or error info on failure. On failure
        the exception is logged and reported as
        ``{"status": "error", "error": <message>}`` rather than re-raised.
    """
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)

    try:
        return loop.run_until_complete(
            _generate_content_podcast(
                source_content,
                search_space_id,
                podcast_title,
                user_prompt,
            )
        )
    except Exception as e:
        # Log with the traceback; the failure is reported via the return value.
        logger.exception(f"Error generating content podcast: {e!s}")
        return {"status": "error", "error": str(e)}
    finally:
        # Always clear the active podcast key when task completes (success or failure)
        _clear_active_podcast_redis_key(search_space_id)
        try:
            # Finalize async generators on every exit path, not only after a
            # successful run, before the loop is closed.
            loop.run_until_complete(loop.shutdown_asyncgens())
        finally:
            asyncio.set_event_loop(None)
            loop.close()
|
|
|
|
|
|
def _serialize_transcript(entries) -> list[dict]:
    """Convert podcaster transcript entries into plain ``speaker_id``/``dialog`` dicts.

    Entries that expose a ``speaker_id`` attribute are read through attributes
    (``speaker_id`` and ``dialog``). Any other entry is read as a mapping with
    ``.get``, defaulting to speaker ``0`` and empty dialog.
    """
    return [
        {"speaker_id": item.speaker_id, "dialog": item.dialog}
        if hasattr(item, "speaker_id")
        else {
            "speaker_id": item.get("speaker_id", 0),
            "dialog": item.get("dialog", ""),
        }
        for item in entries
    ]


async def _generate_content_podcast(
    source_content: str,
    search_space_id: int,
    podcast_title: str = "SurfSense Podcast",
    user_prompt: str | None = None,
) -> dict:
    """Run the podcaster graph over ``source_content`` and persist the result.

    Opens a session from a freshly built Celery session maker. It invokes the
    podcaster graph, passing title, search space and optional prompt in the
    ``configurable`` section. It then stores a ``Podcast`` row holding the
    serialized transcript and the final podcast file path reported by the
    graph.

    Returns:
        ``{"status": "success", "podcast_id": ..., "title": ...,
        "transcript_entries": ...}``.

    Raises:
        Any exception raised while running the graph or writing to the
        database. The error is logged and the session rolled back before the
        exception is re-raised.
    """
    session_factory = get_celery_session_maker()
    async with session_factory() as session:
        try:
            run_config = {
                "configurable": {
                    "podcast_title": podcast_title,
                    "search_space_id": search_space_id,
                    "user_prompt": user_prompt,
                }
            }
            # The graph receives the raw content plus the live DB session.
            start_state = PodcasterState(
                source_content=source_content,
                db_session=session,
            )
            graph_output = await podcaster_graph.ainvoke(start_state, config=run_config)

            raw_entries = graph_output.get("podcast_transcript", [])
            podcast_file = graph_output.get("final_podcast_file_path", "")
            transcript = _serialize_transcript(raw_entries)

            podcast = Podcast(
                title=podcast_title,
                podcast_transcript=transcript,
                file_location=podcast_file,
                search_space_id=search_space_id,
            )
            session.add(podcast)
            await session.commit()
            await session.refresh(podcast)

            logger.info(f"Successfully generated content podcast: {podcast.id}")

            return {
                "status": "success",
                "podcast_id": podcast.id,
                "title": podcast_title,
                "transcript_entries": len(transcript),
            }
        except Exception as e:
            logger.error(f"Error in _generate_content_podcast: {e!s}")
            await session.rollback()
            raise
|