mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-11 00:32:38 +02:00
feat: enhance video presentation agent with parallel theme assignment and watermarking
This commit is contained in:
parent
0fe5e034fe
commit
d90b6d35ce
9 changed files with 123 additions and 197 deletions
|
|
@ -4,60 +4,15 @@ Podcast generation tool for the SurfSense agent.
|
|||
This module provides a factory function for creating the generate_podcast tool
|
||||
that submits a Celery task for background podcast generation. The frontend
|
||||
polls for completion and auto-updates when the podcast is ready.
|
||||
|
||||
Duplicate request prevention:
|
||||
- Only one podcast can be generated at a time per search space
|
||||
- Uses Redis to track active podcast tasks
|
||||
- Returns a friendly message if a podcast is already being generated
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import redis
|
||||
from langchain_core.tools import tool
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import config
|
||||
from app.db import Podcast, PodcastStatus
|
||||
|
||||
# Redis connection for tracking active podcast tasks
|
||||
# Defaults to the Celery broker when REDIS_APP_URL is not set
|
||||
REDIS_URL = config.REDIS_APP_URL
|
||||
_redis_client: redis.Redis | None = None
|
||||
|
||||
|
||||
def get_redis_client() -> redis.Redis:
|
||||
"""Get or create Redis client for podcast task tracking."""
|
||||
global _redis_client
|
||||
if _redis_client is None:
|
||||
_redis_client = redis.from_url(REDIS_URL, decode_responses=True)
|
||||
return _redis_client
|
||||
|
||||
|
||||
def _redis_key(search_space_id: int) -> str:
|
||||
return f"podcast:generating:{search_space_id}"
|
||||
|
||||
|
||||
def get_generating_podcast_id(search_space_id: int) -> int | None:
|
||||
"""Get the podcast ID currently being generated for this search space."""
|
||||
try:
|
||||
client = get_redis_client()
|
||||
value = client.get(_redis_key(search_space_id))
|
||||
return int(value) if value else None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def set_generating_podcast(search_space_id: int, podcast_id: int) -> None:
|
||||
"""Mark a podcast as currently generating for this search space."""
|
||||
try:
|
||||
client = get_redis_client()
|
||||
client.setex(_redis_key(search_space_id), 1800, str(podcast_id))
|
||||
except Exception as e:
|
||||
print(
|
||||
f"[generate_podcast] Warning: Could not set generating podcast in Redis: {e}"
|
||||
)
|
||||
|
||||
|
||||
def create_generate_podcast_tool(
|
||||
search_space_id: int,
|
||||
|
|
@ -109,18 +64,6 @@ def create_generate_podcast_tool(
|
|||
- message: Status message (or "error" field if status is failed)
|
||||
"""
|
||||
try:
|
||||
generating_podcast_id = get_generating_podcast_id(search_space_id)
|
||||
if generating_podcast_id:
|
||||
print(
|
||||
f"[generate_podcast] Blocked duplicate request. Generating podcast: {generating_podcast_id}"
|
||||
)
|
||||
return {
|
||||
"status": PodcastStatus.GENERATING.value,
|
||||
"podcast_id": generating_podcast_id,
|
||||
"title": podcast_title,
|
||||
"message": "A podcast is already being generated. Please wait for it to complete.",
|
||||
}
|
||||
|
||||
podcast = Podcast(
|
||||
title=podcast_title,
|
||||
status=PodcastStatus.PENDING,
|
||||
|
|
@ -142,8 +85,6 @@ def create_generate_podcast_tool(
|
|||
user_prompt=user_prompt,
|
||||
)
|
||||
|
||||
set_generating_podcast(search_space_id, podcast.id)
|
||||
|
||||
print(f"[generate_podcast] Created podcast {podcast.id}, task: {task.id}")
|
||||
|
||||
return {
|
||||
|
|
|
|||
|
|
@ -4,70 +4,15 @@ Video presentation generation tool for the SurfSense agent.
|
|||
This module provides a factory function for creating the generate_video_presentation
|
||||
tool that submits a Celery task for background video presentation generation.
|
||||
The frontend polls for completion and auto-updates when the presentation is ready.
|
||||
|
||||
Duplicate request prevention:
|
||||
- Only one video presentation can be generated at a time per search space
|
||||
- Uses Redis to track active video presentation tasks
|
||||
- Validates the Redis marker against actual DB status to avoid stale locks
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import redis
|
||||
from langchain_core.tools import tool
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import config
|
||||
from app.db import VideoPresentation, VideoPresentationStatus
|
||||
|
||||
REDIS_URL = config.REDIS_APP_URL
|
||||
_redis_client: redis.Redis | None = None
|
||||
|
||||
|
||||
def get_redis_client() -> redis.Redis:
|
||||
"""Get or create Redis client for video presentation task tracking."""
|
||||
global _redis_client
|
||||
if _redis_client is None:
|
||||
_redis_client = redis.from_url(REDIS_URL, decode_responses=True)
|
||||
return _redis_client
|
||||
|
||||
|
||||
def _redis_key(search_space_id: int) -> str:
|
||||
return f"video_presentation:generating:{search_space_id}"
|
||||
|
||||
|
||||
def get_generating_video_presentation_id(search_space_id: int) -> int | None:
|
||||
"""Get the video presentation ID currently being generated for this search space."""
|
||||
try:
|
||||
client = get_redis_client()
|
||||
value = client.get(_redis_key(search_space_id))
|
||||
return int(value) if value else None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def clear_generating_video_presentation(search_space_id: int) -> None:
|
||||
"""Clear the generating marker (used when we detect a stale lock)."""
|
||||
try:
|
||||
client = get_redis_client()
|
||||
client.delete(_redis_key(search_space_id))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def set_generating_video_presentation(
|
||||
search_space_id: int, video_presentation_id: int
|
||||
) -> None:
|
||||
"""Mark a video presentation as currently generating for this search space."""
|
||||
try:
|
||||
client = get_redis_client()
|
||||
client.setex(_redis_key(search_space_id), 1800, str(video_presentation_id))
|
||||
except Exception as e:
|
||||
print(
|
||||
f"[generate_video_presentation] Warning: Could not set generating video presentation in Redis: {e}"
|
||||
)
|
||||
|
||||
|
||||
def create_generate_video_presentation_tool(
|
||||
search_space_id: int,
|
||||
|
|
@ -97,33 +42,6 @@ def create_generate_video_presentation_tool(
|
|||
user_prompt: Optional style/tone instructions.
|
||||
"""
|
||||
try:
|
||||
generating_id = get_generating_video_presentation_id(search_space_id)
|
||||
if generating_id:
|
||||
result = await db_session.execute(
|
||||
select(VideoPresentation).filter(
|
||||
VideoPresentation.id == generating_id
|
||||
)
|
||||
)
|
||||
existing = result.scalars().first()
|
||||
|
||||
if existing and existing.status == VideoPresentationStatus.GENERATING:
|
||||
print(
|
||||
f"[generate_video_presentation] Blocked duplicate — "
|
||||
f"presentation {generating_id} is actively generating"
|
||||
)
|
||||
return {
|
||||
"status": VideoPresentationStatus.GENERATING.value,
|
||||
"video_presentation_id": generating_id,
|
||||
"title": video_title,
|
||||
"message": "A video presentation is already being generated. Please wait for it to complete.",
|
||||
}
|
||||
|
||||
print(
|
||||
f"[generate_video_presentation] Stale Redis lock for presentation {generating_id} "
|
||||
f"(status={existing.status if existing else 'not found'}). Clearing and proceeding."
|
||||
)
|
||||
clear_generating_video_presentation(search_space_id)
|
||||
|
||||
video_pres = VideoPresentation(
|
||||
title=video_title,
|
||||
status=VideoPresentationStatus.PENDING,
|
||||
|
|
@ -145,8 +63,6 @@ def create_generate_video_presentation_tool(
|
|||
user_prompt=user_prompt,
|
||||
)
|
||||
|
||||
set_generating_video_presentation(search_space_id, video_pres.id)
|
||||
|
||||
print(
|
||||
f"[generate_video_presentation] Created video presentation {video_pres.id}, task: {task.id}"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ from langgraph.graph import StateGraph
|
|||
|
||||
from .configuration import Configuration
|
||||
from .nodes import (
|
||||
assign_slide_themes,
|
||||
create_presentation_slides,
|
||||
create_slide_audio,
|
||||
generate_slide_scene_codes,
|
||||
|
|
@ -14,11 +15,19 @@ def build_graph():
|
|||
|
||||
workflow.add_node("create_presentation_slides", create_presentation_slides)
|
||||
workflow.add_node("create_slide_audio", create_slide_audio)
|
||||
workflow.add_node("assign_slide_themes", assign_slide_themes)
|
||||
workflow.add_node("generate_slide_scene_codes", generate_slide_scene_codes)
|
||||
|
||||
# Fan-out: after slides are parsed, run audio generation and theme
|
||||
# assignment in parallel (themes only need slide metadata, not audio).
|
||||
workflow.add_edge("__start__", "create_presentation_slides")
|
||||
workflow.add_edge("create_presentation_slides", "create_slide_audio")
|
||||
workflow.add_edge("create_presentation_slides", "assign_slide_themes")
|
||||
|
||||
# Fan-in: scene code generation waits for both audio and themes.
|
||||
workflow.add_edge("create_slide_audio", "generate_slide_scene_codes")
|
||||
workflow.add_edge("assign_slide_themes", "generate_slide_scene_codes")
|
||||
|
||||
workflow.add_edge("generate_slide_scene_codes", "__end__")
|
||||
|
||||
graph = workflow.compile()
|
||||
|
|
|
|||
|
|
@ -178,18 +178,29 @@ async def create_slide_audio(state: State, config: RunnableConfig) -> dict[str,
|
|||
|
||||
chunk_paths: list[str] = []
|
||||
try:
|
||||
for i, text in enumerate(slide.speaker_transcripts):
|
||||
chunk_path = str(
|
||||
chunk_paths = [
|
||||
str(
|
||||
temp_dir
|
||||
/ f"{session_id}_slide_{slide.slide_number}_chunk_{i}.{ext}"
|
||||
)
|
||||
for i in range(len(slide.speaker_transcripts))
|
||||
]
|
||||
|
||||
for i, text in enumerate(slide.speaker_transcripts):
|
||||
print(
|
||||
f" Slide {slide.slide_number} chunk {i + 1}/"
|
||||
f"{len(slide.speaker_transcripts)}: "
|
||||
f'"{text[:60]}..."'
|
||||
)
|
||||
await _generate_tts_chunk(text, chunk_path)
|
||||
chunk_paths.append(chunk_path)
|
||||
|
||||
await asyncio.gather(
|
||||
*[
|
||||
_generate_tts_chunk(text, path)
|
||||
for text, path in zip(
|
||||
slide.speaker_transcripts, chunk_paths, strict=False
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
if len(chunk_paths) == 1:
|
||||
shutil.move(chunk_paths[0], output_file)
|
||||
|
|
@ -340,13 +351,30 @@ async def _assign_themes_with_llm(
|
|||
}
|
||||
|
||||
|
||||
async def assign_slide_themes(state: State, config: RunnableConfig) -> dict[str, Any]:
|
||||
"""Assign a theme preset + dark/light mode to every slide via a single LLM call.
|
||||
|
||||
Runs in parallel with audio generation since it only needs slide metadata.
|
||||
"""
|
||||
configuration = Configuration.from_runnable_config(config)
|
||||
search_space_id = configuration.search_space_id
|
||||
|
||||
llm = await get_agent_llm(state.db_session, search_space_id)
|
||||
if not llm:
|
||||
raise RuntimeError(f"No LLM configured for search space {search_space_id}")
|
||||
|
||||
slides = state.slides or []
|
||||
assignments = await _assign_themes_with_llm(llm, slides)
|
||||
return {"slide_theme_assignments": assignments}
|
||||
|
||||
|
||||
async def generate_slide_scene_codes(
|
||||
state: State, config: RunnableConfig
|
||||
) -> dict[str, Any]:
|
||||
"""Generate Remotion component code for each slide using LLM.
|
||||
|
||||
First assigns a theme+mode to every slide via a single LLM call,
|
||||
then generates scene code per slide with the assigned theme.
|
||||
Reads pre-assigned themes from state (produced by the parallel
|
||||
assign_slide_themes node) and generates scene code concurrently.
|
||||
"""
|
||||
|
||||
configuration = Configuration.from_runnable_config(config)
|
||||
|
|
@ -362,11 +390,9 @@ async def generate_slide_scene_codes(
|
|||
audio_map: dict[int, SlideAudioResult] = {r.slide_number: r for r in audio_results}
|
||||
total_slides = len(slides)
|
||||
|
||||
theme_assignments = await _assign_themes_with_llm(llm, slides)
|
||||
theme_assignments = state.slide_theme_assignments or {}
|
||||
|
||||
scene_codes: list[SlideSceneCode] = []
|
||||
|
||||
for slide in slides:
|
||||
async def _generate_scene_for_slide(slide: SlideContent) -> SlideSceneCode:
|
||||
audio = audio_map.get(slide.slide_number)
|
||||
duration = audio.duration_in_frames if audio else DEFAULT_DURATION_IN_FRAMES
|
||||
|
||||
|
|
@ -402,15 +428,17 @@ async def generate_slide_scene_codes(
|
|||
|
||||
code = await _refine_if_needed(llm, code, slide.slide_number)
|
||||
|
||||
scene_codes.append(
|
||||
SlideSceneCode(
|
||||
slide_number=slide.slide_number,
|
||||
code=code,
|
||||
title=scene_title or slide.title,
|
||||
)
|
||||
print(f"Scene code ready for slide {slide.slide_number} ({len(code)} chars)")
|
||||
|
||||
return SlideSceneCode(
|
||||
slide_number=slide.slide_number,
|
||||
code=code,
|
||||
title=scene_title or slide.title,
|
||||
)
|
||||
|
||||
print(f"Scene code ready for slide {slide.slide_number} ({len(code)} chars)")
|
||||
scene_codes = list(
|
||||
await asyncio.gather(*[_generate_scene_for_slide(s) for s in slides])
|
||||
)
|
||||
|
||||
return {"slide_scene_codes": scene_codes}
|
||||
|
||||
|
|
|
|||
|
|
@ -60,7 +60,7 @@ class SlideSceneCode(BaseModel):
|
|||
class State:
|
||||
"""State for the video presentation agent graph.
|
||||
|
||||
Pipeline: parse slides → generate per-slide TTS audio → generate per-slide Remotion code
|
||||
Pipeline: parse slides → (TTS audio ∥ theme assignment) → generate Remotion code
|
||||
The frontend receives the slides + code + audio and handles compilation/rendering.
|
||||
"""
|
||||
|
||||
|
|
@ -69,4 +69,5 @@ class State:
|
|||
|
||||
slides: list[SlideContent] | None = None
|
||||
slide_audio_results: list[SlideAudioResult] | None = None
|
||||
slide_theme_assignments: dict[int, tuple[str, str]] | None = None
|
||||
slide_scene_codes: list[SlideSceneCode] | None = None
|
||||
|
|
|
|||
|
|
@ -9,7 +9,6 @@ from sqlalchemy import select
|
|||
from app.agents.podcaster.graph import graph as podcaster_graph
|
||||
from app.agents.podcaster.state import State as PodcasterState
|
||||
from app.celery_app import celery_app
|
||||
from app.config import config
|
||||
from app.db import Podcast, PodcastStatus
|
||||
from app.tasks.celery_tasks import get_celery_session_maker
|
||||
|
||||
|
|
@ -29,21 +28,6 @@ if sys.platform.startswith("win"):
|
|||
# =============================================================================
|
||||
|
||||
|
||||
def _clear_generating_podcast(search_space_id: int) -> None:
|
||||
"""Clear the generating podcast marker from Redis when task completes."""
|
||||
import redis
|
||||
|
||||
try:
|
||||
client = redis.from_url(config.REDIS_APP_URL, decode_responses=True)
|
||||
key = f"podcast:generating:{search_space_id}"
|
||||
client.delete(key)
|
||||
logger.info(
|
||||
f"Cleared generating podcast key for search_space_id={search_space_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not clear generating podcast key: {e}")
|
||||
|
||||
|
||||
@celery_app.task(name="generate_content_podcast", bind=True)
|
||||
def generate_content_podcast_task(
|
||||
self,
|
||||
|
|
@ -75,7 +59,6 @@ def generate_content_podcast_task(
|
|||
loop.run_until_complete(_mark_podcast_failed(podcast_id))
|
||||
return {"status": "failed", "podcast_id": podcast_id}
|
||||
finally:
|
||||
_clear_generating_podcast(search_space_id)
|
||||
asyncio.set_event_loop(None)
|
||||
loop.close()
|
||||
|
||||
|
|
|
|||
|
|
@ -9,7 +9,6 @@ from sqlalchemy import select
|
|||
from app.agents.video_presentation.graph import graph as video_presentation_graph
|
||||
from app.agents.video_presentation.state import State as VideoPresentationState
|
||||
from app.celery_app import celery_app
|
||||
from app.config import config
|
||||
from app.db import VideoPresentation, VideoPresentationStatus
|
||||
from app.tasks.celery_tasks import get_celery_session_maker
|
||||
|
||||
|
|
@ -24,21 +23,6 @@ if sys.platform.startswith("win"):
|
|||
)
|
||||
|
||||
|
||||
def _clear_generating_video_presentation(search_space_id: int) -> None:
|
||||
"""Clear the generating video presentation marker from Redis when task completes."""
|
||||
import redis
|
||||
|
||||
try:
|
||||
client = redis.from_url(config.REDIS_APP_URL, decode_responses=True)
|
||||
key = f"video_presentation:generating:{search_space_id}"
|
||||
client.delete(key)
|
||||
logger.info(
|
||||
f"Cleared generating video presentation key for search_space_id={search_space_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not clear generating video presentation key: {e}")
|
||||
|
||||
|
||||
@celery_app.task(name="generate_video_presentation", bind=True)
|
||||
def generate_video_presentation_task(
|
||||
self,
|
||||
|
|
@ -70,7 +54,6 @@ def generate_video_presentation_task(
|
|||
loop.run_until_complete(_mark_video_presentation_failed(video_presentation_id))
|
||||
return {"status": "failed", "video_presentation_id": video_presentation_id}
|
||||
finally:
|
||||
_clear_generating_video_presentation(search_space_id)
|
||||
asyncio.set_event_loop(None)
|
||||
loop.close()
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
import React, { useMemo } from "react";
|
||||
import { Player } from "@remotion/player";
|
||||
import { Sequence, AbsoluteFill } from "remotion";
|
||||
import { Sequence, AbsoluteFill, useCurrentFrame, useVideoConfig, interpolate } from "remotion";
|
||||
import { Audio } from "@remotion/media";
|
||||
import { FPS } from "@/lib/remotion/constants";
|
||||
|
||||
|
|
@ -14,6 +14,68 @@ export interface CompiledSlide {
|
|||
audioUrl?: string;
|
||||
}
|
||||
|
||||
const WATERMARK_STYLES = {
|
||||
container: {
|
||||
position: "absolute" as const,
|
||||
bottom: 28,
|
||||
right: 36,
|
||||
display: "flex",
|
||||
alignItems: "center",
|
||||
gap: 8,
|
||||
padding: "6px 14px 6px 10px",
|
||||
borderRadius: 9999,
|
||||
background: "rgba(0, 0, 0, 0.35)",
|
||||
backdropFilter: "blur(12px)",
|
||||
WebkitBackdropFilter: "blur(12px)",
|
||||
border: "1px solid rgba(255, 255, 255, 0.12)",
|
||||
boxShadow: "0 2px 8px rgba(0, 0, 0, 0.15)",
|
||||
pointerEvents: "none" as const,
|
||||
zIndex: 9999,
|
||||
},
|
||||
logo: {
|
||||
width: 22,
|
||||
height: 22,
|
||||
filter: "brightness(0) invert(1)",
|
||||
},
|
||||
text: {
|
||||
fontFamily: "Inter, system-ui, -apple-system, sans-serif",
|
||||
fontSize: 15,
|
||||
fontWeight: 600,
|
||||
color: "rgba(255, 255, 255, 0.95)",
|
||||
letterSpacing: "0.01em",
|
||||
lineHeight: 1,
|
||||
},
|
||||
};
|
||||
|
||||
function Watermark() {
|
||||
const frame = useCurrentFrame();
|
||||
const { fps } = useVideoConfig();
|
||||
|
||||
const opacity = interpolate(frame, [0, fps * 0.5], [0, 1], {
|
||||
extrapolateRight: "clamp",
|
||||
});
|
||||
|
||||
return (
|
||||
<div style={{ ...WATERMARK_STYLES.container, opacity }}>
|
||||
{/* eslint-disable-next-line @next/next/no-img-element */}
|
||||
<img src="/icon-128.svg" alt="" style={WATERMARK_STYLES.logo} />
|
||||
<span style={WATERMARK_STYLES.text}>SurfSense</span>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export function buildSlideWithWatermark(
|
||||
SlideComponent: React.ComponentType,
|
||||
): React.FC {
|
||||
const Wrapped: React.FC = () => (
|
||||
<AbsoluteFill>
|
||||
<SlideComponent />
|
||||
<Watermark />
|
||||
</AbsoluteFill>
|
||||
);
|
||||
return Wrapped;
|
||||
}
|
||||
|
||||
function CombinedComposition({ scenes }: { scenes: CompiledSlide[] }) {
|
||||
let offset = 0;
|
||||
|
||||
|
|
@ -29,6 +91,7 @@ function CombinedComposition({ scenes }: { scenes: CompiledSlide[] }) {
|
|||
</Sequence>
|
||||
);
|
||||
})}
|
||||
<Watermark />
|
||||
</AbsoluteFill>
|
||||
);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ import { FPS } from "@/lib/remotion/constants";
|
|||
import {
|
||||
CombinedPlayer,
|
||||
buildCompositionComponent,
|
||||
buildSlideWithWatermark,
|
||||
type CompiledSlide,
|
||||
} from "./combined-player";
|
||||
|
||||
|
|
@ -397,11 +398,12 @@ function VideoPresentationPlayer({
|
|||
|
||||
const holdFrame = Math.floor(slide.durationInFrames * 0.3);
|
||||
const root = createRoot(wrapper);
|
||||
const SlideWithWatermark = buildSlideWithWatermark(slide.component);
|
||||
|
||||
flushSync(() => {
|
||||
root.render(
|
||||
React.createElement(Thumbnail, {
|
||||
component: slide.component,
|
||||
component: SlideWithWatermark,
|
||||
compositionWidth: 1920,
|
||||
compositionHeight: 1080,
|
||||
frameToDisplay: holdFrame,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue