feat: enhance video presentation agent with parallel theme assignment and watermarking

This commit is contained in:
DESKTOP-RTLN3BA\$punk 2026-03-21 23:02:09 -07:00
parent 0fe5e034fe
commit d90b6d35ce
9 changed files with 123 additions and 197 deletions

View file

@ -4,60 +4,15 @@ Podcast generation tool for the SurfSense agent.
This module provides a factory function for creating the generate_podcast tool This module provides a factory function for creating the generate_podcast tool
that submits a Celery task for background podcast generation. The frontend that submits a Celery task for background podcast generation. The frontend
polls for completion and auto-updates when the podcast is ready. polls for completion and auto-updates when the podcast is ready.
Duplicate request prevention:
- Only one podcast can be generated at a time per search space
- Uses Redis to track active podcast tasks
- Returns a friendly message if a podcast is already being generated
""" """
from typing import Any from typing import Any
import redis
from langchain_core.tools import tool from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.config import config
from app.db import Podcast, PodcastStatus from app.db import Podcast, PodcastStatus
# Redis connection for tracking active podcast tasks
# Defaults to the Celery broker when REDIS_APP_URL is not set
REDIS_URL = config.REDIS_APP_URL
_redis_client: redis.Redis | None = None
def get_redis_client() -> redis.Redis:
"""Get or create Redis client for podcast task tracking."""
global _redis_client
if _redis_client is None:
_redis_client = redis.from_url(REDIS_URL, decode_responses=True)
return _redis_client
def _redis_key(search_space_id: int) -> str:
return f"podcast:generating:{search_space_id}"
def get_generating_podcast_id(search_space_id: int) -> int | None:
"""Get the podcast ID currently being generated for this search space."""
try:
client = get_redis_client()
value = client.get(_redis_key(search_space_id))
return int(value) if value else None
except Exception:
return None
def set_generating_podcast(search_space_id: int, podcast_id: int) -> None:
"""Mark a podcast as currently generating for this search space."""
try:
client = get_redis_client()
client.setex(_redis_key(search_space_id), 1800, str(podcast_id))
except Exception as e:
print(
f"[generate_podcast] Warning: Could not set generating podcast in Redis: {e}"
)
def create_generate_podcast_tool( def create_generate_podcast_tool(
search_space_id: int, search_space_id: int,
@ -109,18 +64,6 @@ def create_generate_podcast_tool(
- message: Status message (or "error" field if status is failed) - message: Status message (or "error" field if status is failed)
""" """
try: try:
generating_podcast_id = get_generating_podcast_id(search_space_id)
if generating_podcast_id:
print(
f"[generate_podcast] Blocked duplicate request. Generating podcast: {generating_podcast_id}"
)
return {
"status": PodcastStatus.GENERATING.value,
"podcast_id": generating_podcast_id,
"title": podcast_title,
"message": "A podcast is already being generated. Please wait for it to complete.",
}
podcast = Podcast( podcast = Podcast(
title=podcast_title, title=podcast_title,
status=PodcastStatus.PENDING, status=PodcastStatus.PENDING,
@ -142,8 +85,6 @@ def create_generate_podcast_tool(
user_prompt=user_prompt, user_prompt=user_prompt,
) )
set_generating_podcast(search_space_id, podcast.id)
print(f"[generate_podcast] Created podcast {podcast.id}, task: {task.id}") print(f"[generate_podcast] Created podcast {podcast.id}, task: {task.id}")
return { return {

View file

@ -4,70 +4,15 @@ Video presentation generation tool for the SurfSense agent.
This module provides a factory function for creating the generate_video_presentation This module provides a factory function for creating the generate_video_presentation
tool that submits a Celery task for background video presentation generation. tool that submits a Celery task for background video presentation generation.
The frontend polls for completion and auto-updates when the presentation is ready. The frontend polls for completion and auto-updates when the presentation is ready.
Duplicate request prevention:
- Only one video presentation can be generated at a time per search space
- Uses Redis to track active video presentation tasks
- Validates the Redis marker against actual DB status to avoid stale locks
""" """
from typing import Any from typing import Any
import redis
from langchain_core.tools import tool from langchain_core.tools import tool
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.config import config
from app.db import VideoPresentation, VideoPresentationStatus from app.db import VideoPresentation, VideoPresentationStatus
REDIS_URL = config.REDIS_APP_URL
_redis_client: redis.Redis | None = None
def get_redis_client() -> redis.Redis:
"""Get or create Redis client for video presentation task tracking."""
global _redis_client
if _redis_client is None:
_redis_client = redis.from_url(REDIS_URL, decode_responses=True)
return _redis_client
def _redis_key(search_space_id: int) -> str:
return f"video_presentation:generating:{search_space_id}"
def get_generating_video_presentation_id(search_space_id: int) -> int | None:
"""Get the video presentation ID currently being generated for this search space."""
try:
client = get_redis_client()
value = client.get(_redis_key(search_space_id))
return int(value) if value else None
except Exception:
return None
def clear_generating_video_presentation(search_space_id: int) -> None:
"""Clear the generating marker (used when we detect a stale lock)."""
try:
client = get_redis_client()
client.delete(_redis_key(search_space_id))
except Exception:
pass
def set_generating_video_presentation(
search_space_id: int, video_presentation_id: int
) -> None:
"""Mark a video presentation as currently generating for this search space."""
try:
client = get_redis_client()
client.setex(_redis_key(search_space_id), 1800, str(video_presentation_id))
except Exception as e:
print(
f"[generate_video_presentation] Warning: Could not set generating video presentation in Redis: {e}"
)
def create_generate_video_presentation_tool( def create_generate_video_presentation_tool(
search_space_id: int, search_space_id: int,
@ -97,33 +42,6 @@ def create_generate_video_presentation_tool(
user_prompt: Optional style/tone instructions. user_prompt: Optional style/tone instructions.
""" """
try: try:
generating_id = get_generating_video_presentation_id(search_space_id)
if generating_id:
result = await db_session.execute(
select(VideoPresentation).filter(
VideoPresentation.id == generating_id
)
)
existing = result.scalars().first()
if existing and existing.status == VideoPresentationStatus.GENERATING:
print(
f"[generate_video_presentation] Blocked duplicate — "
f"presentation {generating_id} is actively generating"
)
return {
"status": VideoPresentationStatus.GENERATING.value,
"video_presentation_id": generating_id,
"title": video_title,
"message": "A video presentation is already being generated. Please wait for it to complete.",
}
print(
f"[generate_video_presentation] Stale Redis lock for presentation {generating_id} "
f"(status={existing.status if existing else 'not found'}). Clearing and proceeding."
)
clear_generating_video_presentation(search_space_id)
video_pres = VideoPresentation( video_pres = VideoPresentation(
title=video_title, title=video_title,
status=VideoPresentationStatus.PENDING, status=VideoPresentationStatus.PENDING,
@ -145,8 +63,6 @@ def create_generate_video_presentation_tool(
user_prompt=user_prompt, user_prompt=user_prompt,
) )
set_generating_video_presentation(search_space_id, video_pres.id)
print( print(
f"[generate_video_presentation] Created video presentation {video_pres.id}, task: {task.id}" f"[generate_video_presentation] Created video presentation {video_pres.id}, task: {task.id}"
) )

View file

@ -2,6 +2,7 @@ from langgraph.graph import StateGraph
from .configuration import Configuration from .configuration import Configuration
from .nodes import ( from .nodes import (
assign_slide_themes,
create_presentation_slides, create_presentation_slides,
create_slide_audio, create_slide_audio,
generate_slide_scene_codes, generate_slide_scene_codes,
@ -14,11 +15,19 @@ def build_graph():
workflow.add_node("create_presentation_slides", create_presentation_slides) workflow.add_node("create_presentation_slides", create_presentation_slides)
workflow.add_node("create_slide_audio", create_slide_audio) workflow.add_node("create_slide_audio", create_slide_audio)
workflow.add_node("assign_slide_themes", assign_slide_themes)
workflow.add_node("generate_slide_scene_codes", generate_slide_scene_codes) workflow.add_node("generate_slide_scene_codes", generate_slide_scene_codes)
# Fan-out: after slides are parsed, run audio generation and theme
# assignment in parallel (themes only need slide metadata, not audio).
workflow.add_edge("__start__", "create_presentation_slides") workflow.add_edge("__start__", "create_presentation_slides")
workflow.add_edge("create_presentation_slides", "create_slide_audio") workflow.add_edge("create_presentation_slides", "create_slide_audio")
workflow.add_edge("create_presentation_slides", "assign_slide_themes")
# Fan-in: scene code generation waits for both audio and themes.
workflow.add_edge("create_slide_audio", "generate_slide_scene_codes") workflow.add_edge("create_slide_audio", "generate_slide_scene_codes")
workflow.add_edge("assign_slide_themes", "generate_slide_scene_codes")
workflow.add_edge("generate_slide_scene_codes", "__end__") workflow.add_edge("generate_slide_scene_codes", "__end__")
graph = workflow.compile() graph = workflow.compile()

View file

@ -178,18 +178,29 @@ async def create_slide_audio(state: State, config: RunnableConfig) -> dict[str,
chunk_paths: list[str] = [] chunk_paths: list[str] = []
try: try:
for i, text in enumerate(slide.speaker_transcripts): chunk_paths = [
chunk_path = str( str(
temp_dir temp_dir
/ f"{session_id}_slide_{slide.slide_number}_chunk_{i}.{ext}" / f"{session_id}_slide_{slide.slide_number}_chunk_{i}.{ext}"
) )
for i in range(len(slide.speaker_transcripts))
]
for i, text in enumerate(slide.speaker_transcripts):
print( print(
f" Slide {slide.slide_number} chunk {i + 1}/" f" Slide {slide.slide_number} chunk {i + 1}/"
f"{len(slide.speaker_transcripts)}: " f"{len(slide.speaker_transcripts)}: "
f'"{text[:60]}..."' f'"{text[:60]}..."'
) )
await _generate_tts_chunk(text, chunk_path)
chunk_paths.append(chunk_path) await asyncio.gather(
*[
_generate_tts_chunk(text, path)
for text, path in zip(
slide.speaker_transcripts, chunk_paths, strict=False
)
]
)
if len(chunk_paths) == 1: if len(chunk_paths) == 1:
shutil.move(chunk_paths[0], output_file) shutil.move(chunk_paths[0], output_file)
@ -340,13 +351,30 @@ async def _assign_themes_with_llm(
} }
async def assign_slide_themes(state: State, config: RunnableConfig) -> dict[str, Any]:
"""Assign a theme preset + dark/light mode to every slide via a single LLM call.
Runs in parallel with audio generation since it only needs slide metadata.
"""
configuration = Configuration.from_runnable_config(config)
search_space_id = configuration.search_space_id
llm = await get_agent_llm(state.db_session, search_space_id)
if not llm:
raise RuntimeError(f"No LLM configured for search space {search_space_id}")
slides = state.slides or []
assignments = await _assign_themes_with_llm(llm, slides)
return {"slide_theme_assignments": assignments}
async def generate_slide_scene_codes( async def generate_slide_scene_codes(
state: State, config: RunnableConfig state: State, config: RunnableConfig
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Generate Remotion component code for each slide using LLM. """Generate Remotion component code for each slide using LLM.
First assigns a theme+mode to every slide via a single LLM call, Reads pre-assigned themes from state (produced by the parallel
then generates scene code per slide with the assigned theme. assign_slide_themes node) and generates scene code concurrently.
""" """
configuration = Configuration.from_runnable_config(config) configuration = Configuration.from_runnable_config(config)
@ -362,11 +390,9 @@ async def generate_slide_scene_codes(
audio_map: dict[int, SlideAudioResult] = {r.slide_number: r for r in audio_results} audio_map: dict[int, SlideAudioResult] = {r.slide_number: r for r in audio_results}
total_slides = len(slides) total_slides = len(slides)
theme_assignments = await _assign_themes_with_llm(llm, slides) theme_assignments = state.slide_theme_assignments or {}
scene_codes: list[SlideSceneCode] = [] async def _generate_scene_for_slide(slide: SlideContent) -> SlideSceneCode:
for slide in slides:
audio = audio_map.get(slide.slide_number) audio = audio_map.get(slide.slide_number)
duration = audio.duration_in_frames if audio else DEFAULT_DURATION_IN_FRAMES duration = audio.duration_in_frames if audio else DEFAULT_DURATION_IN_FRAMES
@ -402,15 +428,17 @@ async def generate_slide_scene_codes(
code = await _refine_if_needed(llm, code, slide.slide_number) code = await _refine_if_needed(llm, code, slide.slide_number)
scene_codes.append( print(f"Scene code ready for slide {slide.slide_number} ({len(code)} chars)")
SlideSceneCode(
slide_number=slide.slide_number, return SlideSceneCode(
code=code, slide_number=slide.slide_number,
title=scene_title or slide.title, code=code,
) title=scene_title or slide.title,
) )
print(f"Scene code ready for slide {slide.slide_number} ({len(code)} chars)") scene_codes = list(
await asyncio.gather(*[_generate_scene_for_slide(s) for s in slides])
)
return {"slide_scene_codes": scene_codes} return {"slide_scene_codes": scene_codes}

View file

@ -60,7 +60,7 @@ class SlideSceneCode(BaseModel):
class State: class State:
"""State for the video presentation agent graph. """State for the video presentation agent graph.
Pipeline: parse slides generate per-slide TTS audio generate per-slide Remotion code Pipeline: parse slides (TTS audio theme assignment) generate Remotion code
The frontend receives the slides + code + audio and handles compilation/rendering. The frontend receives the slides + code + audio and handles compilation/rendering.
""" """
@ -69,4 +69,5 @@ class State:
slides: list[SlideContent] | None = None slides: list[SlideContent] | None = None
slide_audio_results: list[SlideAudioResult] | None = None slide_audio_results: list[SlideAudioResult] | None = None
slide_theme_assignments: dict[int, tuple[str, str]] | None = None
slide_scene_codes: list[SlideSceneCode] | None = None slide_scene_codes: list[SlideSceneCode] | None = None

View file

@ -9,7 +9,6 @@ from sqlalchemy import select
from app.agents.podcaster.graph import graph as podcaster_graph from app.agents.podcaster.graph import graph as podcaster_graph
from app.agents.podcaster.state import State as PodcasterState from app.agents.podcaster.state import State as PodcasterState
from app.celery_app import celery_app from app.celery_app import celery_app
from app.config import config
from app.db import Podcast, PodcastStatus from app.db import Podcast, PodcastStatus
from app.tasks.celery_tasks import get_celery_session_maker from app.tasks.celery_tasks import get_celery_session_maker
@ -29,21 +28,6 @@ if sys.platform.startswith("win"):
# ============================================================================= # =============================================================================
def _clear_generating_podcast(search_space_id: int) -> None:
"""Clear the generating podcast marker from Redis when task completes."""
import redis
try:
client = redis.from_url(config.REDIS_APP_URL, decode_responses=True)
key = f"podcast:generating:{search_space_id}"
client.delete(key)
logger.info(
f"Cleared generating podcast key for search_space_id={search_space_id}"
)
except Exception as e:
logger.warning(f"Could not clear generating podcast key: {e}")
@celery_app.task(name="generate_content_podcast", bind=True) @celery_app.task(name="generate_content_podcast", bind=True)
def generate_content_podcast_task( def generate_content_podcast_task(
self, self,
@ -75,7 +59,6 @@ def generate_content_podcast_task(
loop.run_until_complete(_mark_podcast_failed(podcast_id)) loop.run_until_complete(_mark_podcast_failed(podcast_id))
return {"status": "failed", "podcast_id": podcast_id} return {"status": "failed", "podcast_id": podcast_id}
finally: finally:
_clear_generating_podcast(search_space_id)
asyncio.set_event_loop(None) asyncio.set_event_loop(None)
loop.close() loop.close()

View file

@ -9,7 +9,6 @@ from sqlalchemy import select
from app.agents.video_presentation.graph import graph as video_presentation_graph from app.agents.video_presentation.graph import graph as video_presentation_graph
from app.agents.video_presentation.state import State as VideoPresentationState from app.agents.video_presentation.state import State as VideoPresentationState
from app.celery_app import celery_app from app.celery_app import celery_app
from app.config import config
from app.db import VideoPresentation, VideoPresentationStatus from app.db import VideoPresentation, VideoPresentationStatus
from app.tasks.celery_tasks import get_celery_session_maker from app.tasks.celery_tasks import get_celery_session_maker
@ -24,21 +23,6 @@ if sys.platform.startswith("win"):
) )
def _clear_generating_video_presentation(search_space_id: int) -> None:
"""Clear the generating video presentation marker from Redis when task completes."""
import redis
try:
client = redis.from_url(config.REDIS_APP_URL, decode_responses=True)
key = f"video_presentation:generating:{search_space_id}"
client.delete(key)
logger.info(
f"Cleared generating video presentation key for search_space_id={search_space_id}"
)
except Exception as e:
logger.warning(f"Could not clear generating video presentation key: {e}")
@celery_app.task(name="generate_video_presentation", bind=True) @celery_app.task(name="generate_video_presentation", bind=True)
def generate_video_presentation_task( def generate_video_presentation_task(
self, self,
@ -70,7 +54,6 @@ def generate_video_presentation_task(
loop.run_until_complete(_mark_video_presentation_failed(video_presentation_id)) loop.run_until_complete(_mark_video_presentation_failed(video_presentation_id))
return {"status": "failed", "video_presentation_id": video_presentation_id} return {"status": "failed", "video_presentation_id": video_presentation_id}
finally: finally:
_clear_generating_video_presentation(search_space_id)
asyncio.set_event_loop(None) asyncio.set_event_loop(None)
loop.close() loop.close()

View file

@ -2,7 +2,7 @@
import React, { useMemo } from "react"; import React, { useMemo } from "react";
import { Player } from "@remotion/player"; import { Player } from "@remotion/player";
import { Sequence, AbsoluteFill } from "remotion"; import { Sequence, AbsoluteFill, useCurrentFrame, useVideoConfig, interpolate } from "remotion";
import { Audio } from "@remotion/media"; import { Audio } from "@remotion/media";
import { FPS } from "@/lib/remotion/constants"; import { FPS } from "@/lib/remotion/constants";
@ -14,6 +14,68 @@ export interface CompiledSlide {
audioUrl?: string; audioUrl?: string;
} }
const WATERMARK_STYLES = {
container: {
position: "absolute" as const,
bottom: 28,
right: 36,
display: "flex",
alignItems: "center",
gap: 8,
padding: "6px 14px 6px 10px",
borderRadius: 9999,
background: "rgba(0, 0, 0, 0.35)",
backdropFilter: "blur(12px)",
WebkitBackdropFilter: "blur(12px)",
border: "1px solid rgba(255, 255, 255, 0.12)",
boxShadow: "0 2px 8px rgba(0, 0, 0, 0.15)",
pointerEvents: "none" as const,
zIndex: 9999,
},
logo: {
width: 22,
height: 22,
filter: "brightness(0) invert(1)",
},
text: {
fontFamily: "Inter, system-ui, -apple-system, sans-serif",
fontSize: 15,
fontWeight: 600,
color: "rgba(255, 255, 255, 0.95)",
letterSpacing: "0.01em",
lineHeight: 1,
},
};
function Watermark() {
const frame = useCurrentFrame();
const { fps } = useVideoConfig();
const opacity = interpolate(frame, [0, fps * 0.5], [0, 1], {
extrapolateRight: "clamp",
});
return (
<div style={{ ...WATERMARK_STYLES.container, opacity }}>
{/* eslint-disable-next-line @next/next/no-img-element */}
<img src="/icon-128.svg" alt="" style={WATERMARK_STYLES.logo} />
<span style={WATERMARK_STYLES.text}>SurfSense</span>
</div>
);
}
export function buildSlideWithWatermark(
SlideComponent: React.ComponentType,
): React.FC {
const Wrapped: React.FC = () => (
<AbsoluteFill>
<SlideComponent />
<Watermark />
</AbsoluteFill>
);
return Wrapped;
}
function CombinedComposition({ scenes }: { scenes: CompiledSlide[] }) { function CombinedComposition({ scenes }: { scenes: CompiledSlide[] }) {
let offset = 0; let offset = 0;
@ -29,6 +91,7 @@ function CombinedComposition({ scenes }: { scenes: CompiledSlide[] }) {
</Sequence> </Sequence>
); );
})} })}
<Watermark />
</AbsoluteFill> </AbsoluteFill>
); );
} }

View file

@ -19,6 +19,7 @@ import { FPS } from "@/lib/remotion/constants";
import { import {
CombinedPlayer, CombinedPlayer,
buildCompositionComponent, buildCompositionComponent,
buildSlideWithWatermark,
type CompiledSlide, type CompiledSlide,
} from "./combined-player"; } from "./combined-player";
@ -397,11 +398,12 @@ function VideoPresentationPlayer({
const holdFrame = Math.floor(slide.durationInFrames * 0.3); const holdFrame = Math.floor(slide.durationInFrames * 0.3);
const root = createRoot(wrapper); const root = createRoot(wrapper);
const SlideWithWatermark = buildSlideWithWatermark(slide.component);
flushSync(() => { flushSync(() => {
root.render( root.render(
React.createElement(Thumbnail, { React.createElement(Thumbnail, {
component: slide.component, component: SlideWithWatermark,
compositionWidth: 1920, compositionWidth: 1920,
compositionHeight: 1080, compositionHeight: 1080,
frameToDisplay: holdFrame, frameToDisplay: holdFrame,