mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-17 18:35:19 +02:00
feat: enhance video presentation agent with parallel theme assignment and watermarking
This commit is contained in:
parent
0fe5e034fe
commit
d90b6d35ce
9 changed files with 123 additions and 197 deletions
|
|
@ -4,60 +4,15 @@ Podcast generation tool for the SurfSense agent.
|
|||
This module provides a factory function for creating the generate_podcast tool
|
||||
that submits a Celery task for background podcast generation. The frontend
|
||||
polls for completion and auto-updates when the podcast is ready.
|
||||
|
||||
Duplicate request prevention:
|
||||
- Only one podcast can be generated at a time per search space
|
||||
- Uses Redis to track active podcast tasks
|
||||
- Returns a friendly message if a podcast is already being generated
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import redis
|
||||
from langchain_core.tools import tool
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import config
|
||||
from app.db import Podcast, PodcastStatus
|
||||
|
||||
# Redis connection for tracking active podcast tasks
|
||||
# Defaults to the Celery broker when REDIS_APP_URL is not set
|
||||
REDIS_URL = config.REDIS_APP_URL
|
||||
_redis_client: redis.Redis | None = None
|
||||
|
||||
|
||||
def get_redis_client() -> redis.Redis:
|
||||
"""Get or create Redis client for podcast task tracking."""
|
||||
global _redis_client
|
||||
if _redis_client is None:
|
||||
_redis_client = redis.from_url(REDIS_URL, decode_responses=True)
|
||||
return _redis_client
|
||||
|
||||
|
||||
def _redis_key(search_space_id: int) -> str:
|
||||
return f"podcast:generating:{search_space_id}"
|
||||
|
||||
|
||||
def get_generating_podcast_id(search_space_id: int) -> int | None:
|
||||
"""Get the podcast ID currently being generated for this search space."""
|
||||
try:
|
||||
client = get_redis_client()
|
||||
value = client.get(_redis_key(search_space_id))
|
||||
return int(value) if value else None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def set_generating_podcast(search_space_id: int, podcast_id: int) -> None:
|
||||
"""Mark a podcast as currently generating for this search space."""
|
||||
try:
|
||||
client = get_redis_client()
|
||||
client.setex(_redis_key(search_space_id), 1800, str(podcast_id))
|
||||
except Exception as e:
|
||||
print(
|
||||
f"[generate_podcast] Warning: Could not set generating podcast in Redis: {e}"
|
||||
)
|
||||
|
||||
|
||||
def create_generate_podcast_tool(
|
||||
search_space_id: int,
|
||||
|
|
@ -109,18 +64,6 @@ def create_generate_podcast_tool(
|
|||
- message: Status message (or "error" field if status is failed)
|
||||
"""
|
||||
try:
|
||||
generating_podcast_id = get_generating_podcast_id(search_space_id)
|
||||
if generating_podcast_id:
|
||||
print(
|
||||
f"[generate_podcast] Blocked duplicate request. Generating podcast: {generating_podcast_id}"
|
||||
)
|
||||
return {
|
||||
"status": PodcastStatus.GENERATING.value,
|
||||
"podcast_id": generating_podcast_id,
|
||||
"title": podcast_title,
|
||||
"message": "A podcast is already being generated. Please wait for it to complete.",
|
||||
}
|
||||
|
||||
podcast = Podcast(
|
||||
title=podcast_title,
|
||||
status=PodcastStatus.PENDING,
|
||||
|
|
@ -142,8 +85,6 @@ def create_generate_podcast_tool(
|
|||
user_prompt=user_prompt,
|
||||
)
|
||||
|
||||
set_generating_podcast(search_space_id, podcast.id)
|
||||
|
||||
print(f"[generate_podcast] Created podcast {podcast.id}, task: {task.id}")
|
||||
|
||||
return {
|
||||
|
|
|
|||
|
|
@ -4,70 +4,15 @@ Video presentation generation tool for the SurfSense agent.
|
|||
This module provides a factory function for creating the generate_video_presentation
|
||||
tool that submits a Celery task for background video presentation generation.
|
||||
The frontend polls for completion and auto-updates when the presentation is ready.
|
||||
|
||||
Duplicate request prevention:
|
||||
- Only one video presentation can be generated at a time per search space
|
||||
- Uses Redis to track active video presentation tasks
|
||||
- Validates the Redis marker against actual DB status to avoid stale locks
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import redis
|
||||
from langchain_core.tools import tool
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import config
|
||||
from app.db import VideoPresentation, VideoPresentationStatus
|
||||
|
||||
REDIS_URL = config.REDIS_APP_URL
|
||||
_redis_client: redis.Redis | None = None
|
||||
|
||||
|
||||
def get_redis_client() -> redis.Redis:
|
||||
"""Get or create Redis client for video presentation task tracking."""
|
||||
global _redis_client
|
||||
if _redis_client is None:
|
||||
_redis_client = redis.from_url(REDIS_URL, decode_responses=True)
|
||||
return _redis_client
|
||||
|
||||
|
||||
def _redis_key(search_space_id: int) -> str:
|
||||
return f"video_presentation:generating:{search_space_id}"
|
||||
|
||||
|
||||
def get_generating_video_presentation_id(search_space_id: int) -> int | None:
|
||||
"""Get the video presentation ID currently being generated for this search space."""
|
||||
try:
|
||||
client = get_redis_client()
|
||||
value = client.get(_redis_key(search_space_id))
|
||||
return int(value) if value else None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def clear_generating_video_presentation(search_space_id: int) -> None:
|
||||
"""Clear the generating marker (used when we detect a stale lock)."""
|
||||
try:
|
||||
client = get_redis_client()
|
||||
client.delete(_redis_key(search_space_id))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def set_generating_video_presentation(
|
||||
search_space_id: int, video_presentation_id: int
|
||||
) -> None:
|
||||
"""Mark a video presentation as currently generating for this search space."""
|
||||
try:
|
||||
client = get_redis_client()
|
||||
client.setex(_redis_key(search_space_id), 1800, str(video_presentation_id))
|
||||
except Exception as e:
|
||||
print(
|
||||
f"[generate_video_presentation] Warning: Could not set generating video presentation in Redis: {e}"
|
||||
)
|
||||
|
||||
|
||||
def create_generate_video_presentation_tool(
|
||||
search_space_id: int,
|
||||
|
|
@ -97,33 +42,6 @@ def create_generate_video_presentation_tool(
|
|||
user_prompt: Optional style/tone instructions.
|
||||
"""
|
||||
try:
|
||||
generating_id = get_generating_video_presentation_id(search_space_id)
|
||||
if generating_id:
|
||||
result = await db_session.execute(
|
||||
select(VideoPresentation).filter(
|
||||
VideoPresentation.id == generating_id
|
||||
)
|
||||
)
|
||||
existing = result.scalars().first()
|
||||
|
||||
if existing and existing.status == VideoPresentationStatus.GENERATING:
|
||||
print(
|
||||
f"[generate_video_presentation] Blocked duplicate — "
|
||||
f"presentation {generating_id} is actively generating"
|
||||
)
|
||||
return {
|
||||
"status": VideoPresentationStatus.GENERATING.value,
|
||||
"video_presentation_id": generating_id,
|
||||
"title": video_title,
|
||||
"message": "A video presentation is already being generated. Please wait for it to complete.",
|
||||
}
|
||||
|
||||
print(
|
||||
f"[generate_video_presentation] Stale Redis lock for presentation {generating_id} "
|
||||
f"(status={existing.status if existing else 'not found'}). Clearing and proceeding."
|
||||
)
|
||||
clear_generating_video_presentation(search_space_id)
|
||||
|
||||
video_pres = VideoPresentation(
|
||||
title=video_title,
|
||||
status=VideoPresentationStatus.PENDING,
|
||||
|
|
@ -145,8 +63,6 @@ def create_generate_video_presentation_tool(
|
|||
user_prompt=user_prompt,
|
||||
)
|
||||
|
||||
set_generating_video_presentation(search_space_id, video_pres.id)
|
||||
|
||||
print(
|
||||
f"[generate_video_presentation] Created video presentation {video_pres.id}, task: {task.id}"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ from langgraph.graph import StateGraph
|
|||
|
||||
from .configuration import Configuration
|
||||
from .nodes import (
|
||||
assign_slide_themes,
|
||||
create_presentation_slides,
|
||||
create_slide_audio,
|
||||
generate_slide_scene_codes,
|
||||
|
|
@ -14,11 +15,19 @@ def build_graph():
|
|||
|
||||
workflow.add_node("create_presentation_slides", create_presentation_slides)
|
||||
workflow.add_node("create_slide_audio", create_slide_audio)
|
||||
workflow.add_node("assign_slide_themes", assign_slide_themes)
|
||||
workflow.add_node("generate_slide_scene_codes", generate_slide_scene_codes)
|
||||
|
||||
# Fan-out: after slides are parsed, run audio generation and theme
|
||||
# assignment in parallel (themes only need slide metadata, not audio).
|
||||
workflow.add_edge("__start__", "create_presentation_slides")
|
||||
workflow.add_edge("create_presentation_slides", "create_slide_audio")
|
||||
workflow.add_edge("create_presentation_slides", "assign_slide_themes")
|
||||
|
||||
# Fan-in: scene code generation waits for both audio and themes.
|
||||
workflow.add_edge("create_slide_audio", "generate_slide_scene_codes")
|
||||
workflow.add_edge("assign_slide_themes", "generate_slide_scene_codes")
|
||||
|
||||
workflow.add_edge("generate_slide_scene_codes", "__end__")
|
||||
|
||||
graph = workflow.compile()
|
||||
|
|
|
|||
|
|
@ -178,18 +178,29 @@ async def create_slide_audio(state: State, config: RunnableConfig) -> dict[str,
|
|||
|
||||
chunk_paths: list[str] = []
|
||||
try:
|
||||
for i, text in enumerate(slide.speaker_transcripts):
|
||||
chunk_path = str(
|
||||
chunk_paths = [
|
||||
str(
|
||||
temp_dir
|
||||
/ f"{session_id}_slide_{slide.slide_number}_chunk_{i}.{ext}"
|
||||
)
|
||||
for i in range(len(slide.speaker_transcripts))
|
||||
]
|
||||
|
||||
for i, text in enumerate(slide.speaker_transcripts):
|
||||
print(
|
||||
f" Slide {slide.slide_number} chunk {i + 1}/"
|
||||
f"{len(slide.speaker_transcripts)}: "
|
||||
f'"{text[:60]}..."'
|
||||
)
|
||||
await _generate_tts_chunk(text, chunk_path)
|
||||
chunk_paths.append(chunk_path)
|
||||
|
||||
await asyncio.gather(
|
||||
*[
|
||||
_generate_tts_chunk(text, path)
|
||||
for text, path in zip(
|
||||
slide.speaker_transcripts, chunk_paths, strict=False
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
if len(chunk_paths) == 1:
|
||||
shutil.move(chunk_paths[0], output_file)
|
||||
|
|
@ -340,13 +351,30 @@ async def _assign_themes_with_llm(
|
|||
}
|
||||
|
||||
|
||||
async def assign_slide_themes(state: State, config: RunnableConfig) -> dict[str, Any]:
|
||||
"""Assign a theme preset + dark/light mode to every slide via a single LLM call.
|
||||
|
||||
Runs in parallel with audio generation since it only needs slide metadata.
|
||||
"""
|
||||
configuration = Configuration.from_runnable_config(config)
|
||||
search_space_id = configuration.search_space_id
|
||||
|
||||
llm = await get_agent_llm(state.db_session, search_space_id)
|
||||
if not llm:
|
||||
raise RuntimeError(f"No LLM configured for search space {search_space_id}")
|
||||
|
||||
slides = state.slides or []
|
||||
assignments = await _assign_themes_with_llm(llm, slides)
|
||||
return {"slide_theme_assignments": assignments}
|
||||
|
||||
|
||||
async def generate_slide_scene_codes(
|
||||
state: State, config: RunnableConfig
|
||||
) -> dict[str, Any]:
|
||||
"""Generate Remotion component code for each slide using LLM.
|
||||
|
||||
First assigns a theme+mode to every slide via a single LLM call,
|
||||
then generates scene code per slide with the assigned theme.
|
||||
Reads pre-assigned themes from state (produced by the parallel
|
||||
assign_slide_themes node) and generates scene code concurrently.
|
||||
"""
|
||||
|
||||
configuration = Configuration.from_runnable_config(config)
|
||||
|
|
@ -362,11 +390,9 @@ async def generate_slide_scene_codes(
|
|||
audio_map: dict[int, SlideAudioResult] = {r.slide_number: r for r in audio_results}
|
||||
total_slides = len(slides)
|
||||
|
||||
theme_assignments = await _assign_themes_with_llm(llm, slides)
|
||||
theme_assignments = state.slide_theme_assignments or {}
|
||||
|
||||
scene_codes: list[SlideSceneCode] = []
|
||||
|
||||
for slide in slides:
|
||||
async def _generate_scene_for_slide(slide: SlideContent) -> SlideSceneCode:
|
||||
audio = audio_map.get(slide.slide_number)
|
||||
duration = audio.duration_in_frames if audio else DEFAULT_DURATION_IN_FRAMES
|
||||
|
||||
|
|
@ -402,15 +428,17 @@ async def generate_slide_scene_codes(
|
|||
|
||||
code = await _refine_if_needed(llm, code, slide.slide_number)
|
||||
|
||||
scene_codes.append(
|
||||
SlideSceneCode(
|
||||
slide_number=slide.slide_number,
|
||||
code=code,
|
||||
title=scene_title or slide.title,
|
||||
)
|
||||
print(f"Scene code ready for slide {slide.slide_number} ({len(code)} chars)")
|
||||
|
||||
return SlideSceneCode(
|
||||
slide_number=slide.slide_number,
|
||||
code=code,
|
||||
title=scene_title or slide.title,
|
||||
)
|
||||
|
||||
print(f"Scene code ready for slide {slide.slide_number} ({len(code)} chars)")
|
||||
scene_codes = list(
|
||||
await asyncio.gather(*[_generate_scene_for_slide(s) for s in slides])
|
||||
)
|
||||
|
||||
return {"slide_scene_codes": scene_codes}
|
||||
|
||||
|
|
|
|||
|
|
@ -60,7 +60,7 @@ class SlideSceneCode(BaseModel):
|
|||
class State:
|
||||
"""State for the video presentation agent graph.
|
||||
|
||||
Pipeline: parse slides → generate per-slide TTS audio → generate per-slide Remotion code
|
||||
Pipeline: parse slides → (TTS audio ∥ theme assignment) → generate Remotion code
|
||||
The frontend receives the slides + code + audio and handles compilation/rendering.
|
||||
"""
|
||||
|
||||
|
|
@ -69,4 +69,5 @@ class State:
|
|||
|
||||
slides: list[SlideContent] | None = None
|
||||
slide_audio_results: list[SlideAudioResult] | None = None
|
||||
slide_theme_assignments: dict[int, tuple[str, str]] | None = None
|
||||
slide_scene_codes: list[SlideSceneCode] | None = None
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue