SurfSense/surfsense_backend/app/agents/podcaster/nodes.py

196 lines
6.9 KiB
Python
Raw Normal View History

import asyncio
import json
import os
import uuid
from pathlib import Path
from typing import Any
from ffmpeg.asyncio import FFmpeg
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.runnables import RunnableConfig
from litellm import aspeech
from app.config import config as app_config
from app.services.kokoro_tts_service import get_kokoro_tts_service
from app.services.llm_service import get_agent_llm
from app.utils.content_utils import extract_text_content, strip_markdown_fences
from .configuration import Configuration
from .prompts import get_podcast_generation_prompt
from .state import PodcastTranscriptEntry, PodcastTranscripts, State
from .utils import get_voice_for_provider
async def create_podcast_transcript(
    state: State, config: RunnableConfig
) -> dict[str, Any]:
    """Generate the podcast transcript from the source content.

    Builds the generation prompt from the configured user prompt, sends it to
    the search space's agent LLM together with ``state.source_content`` and
    validates the reply as a ``PodcastTranscripts`` payload. If the reply is
    not directly parseable, the span from the first ``{`` to the last ``}``
    is tried as a fallback.

    Args:
        state: Graph state; ``db_session`` and ``source_content`` are read.
        config: Runnable config from which ``Configuration`` is derived.

    Returns:
        A dict with the single key ``"podcast_transcript"`` holding the
        validated transcript entries.

    Raises:
        RuntimeError: If no agent LLM is configured for the search space.
        ValueError: If the reply cannot be parsed or validated, even with the
            fallback (``json.JSONDecodeError`` is a ``ValueError`` subclass).
        TypeError: If parsing fails with a type error in the fallback.
    """
    configuration = Configuration.from_runnable_config(config)
    search_space_id = configuration.search_space_id
    user_prompt = configuration.user_prompt

    llm = await get_agent_llm(state.db_session, search_space_id)
    if not llm:
        error_message = f"No agent LLM configured for search space {search_space_id}"
        print(error_message)
        raise RuntimeError(error_message)

    prompt = get_podcast_generation_prompt(user_prompt)
    messages = [
        SystemMessage(content=prompt),
        HumanMessage(
            content=f"<source_content>{state.source_content}</source_content>"
        ),
    ]
    llm_response = await llm.ainvoke(messages)

    # Reasoning models may return content as blocks; normalise to a string.
    content = strip_markdown_fences(extract_text_content(llm_response.content))

    try:
        podcast_transcript = PodcastTranscripts.model_validate(json.loads(content))
    except (json.JSONDecodeError, TypeError, ValueError) as e:
        print(f"Direct JSON parsing failed, trying fallback approach: {e!s}")
        try:
            # The model may wrap the JSON in extra text; keep only the
            # outermost ``{...}`` span and try again.
            json_start = content.find("{")
            json_end = content.rfind("}") + 1
            if json_start < 0 or json_end <= json_start:
                error_message = f"Could not find valid JSON in LLM response. Raw response: {content}"
                print(error_message)
                raise ValueError(error_message)

            parsed_data = json.loads(content[json_start:json_end])
            podcast_transcript = PodcastTranscripts.model_validate(parsed_data)
            print("Successfully parsed podcast transcript using fallback approach")
        except (json.JSONDecodeError, TypeError, ValueError) as e2:
            print(f"Error parsing LLM response: {e2!s}")
            print(f"Raw response: {content}")
            raise

    return {"podcast_transcript": podcast_transcript.podcast_transcripts}
async def _merge_audio_files(audio_files: list[Any], output_path: str) -> None:
    """Concatenate ``audio_files`` in order into ``output_path`` via FFmpeg."""
    ffmpeg = FFmpeg().option("y")
    for audio_file in audio_files:
        ffmpeg = ffmpeg.input(audio_file)

    # One "[i:0]" label per input, fed to a single audio-only concat filter.
    input_labels = "".join(f"[{i}:0]" for i in range(len(audio_files)))
    filter_complex_str = (
        f"{input_labels}concat=n={len(audio_files)}:v=0:a=1[outa]"
    )
    ffmpeg = ffmpeg.option("filter_complex", filter_complex_str)
    ffmpeg = ffmpeg.output(output_path, map="[outa]")
    await ffmpeg.execute()


def _remove_audio_files(audio_files: list[Any]) -> None:
    """Best-effort removal of temporary segment files; failures are only logged."""
    for audio_file in audio_files:
        try:
            os.remove(audio_file)
        except Exception as e:
            print(f"Error removing audio file {audio_file}: {e!s}")


async def create_merged_podcast_audio(
    state: State, config: RunnableConfig
) -> dict[str, Any]:
    """Generate audio for each transcript and merge them into a single podcast file.

    A fixed welcome line (speaker 1) is prepended to the transcript. Every
    entry is synthesised concurrently into ``temp_audio/`` using the TTS
    service named by ``app_config.TTS_SERVICE``, the segments are concatenated
    with FFmpeg into ``podcasts/<uuid>_podcast.mp3``, and the temporary
    segment files are removed afterwards, whether or not synthesis or merging
    succeeded.

    Args:
        state: Graph state; ``podcast_transcript`` is read. It may be an
            object exposing ``podcast_transcripts`` or already a list of
            entries (objects or dicts with ``speaker_id`` / ``dialog``).
        config: Runnable config (not read by this node).

    Returns:
        A dict with ``"podcast_transcript"`` (the transcript including the
        welcome entry) and ``"final_podcast_file_path"`` (the merged file).

    Raises:
        Exception: The first error raised while generating any segment, or
            the error raised by FFmpeg while merging.
    """
    starting_transcript = PodcastTranscriptEntry(
        speaker_id=1, dialog="Welcome to Surfsense Podcast."
    )

    transcript = state.podcast_transcript
    # transcript may be a PodcastTranscripts object or already a list.
    if hasattr(transcript, "podcast_transcripts"):
        transcript_entries = transcript.podcast_transcripts
    else:
        transcript_entries = transcript
    merged_transcript = [starting_transcript, *transcript_entries]

    temp_dir = Path("temp_audio")
    temp_dir.mkdir(exist_ok=True)

    session_id = str(uuid.uuid4())
    output_path = f"podcasts/{session_id}_podcast.mp3"
    os.makedirs("podcasts", exist_ok=True)

    use_kokoro = app_config.TTS_SERVICE == "local/kokoro"
    extension = "wav" if use_kokoro else "mp3"

    async def generate_speech_for_segment(segment, index):
        # Entries may be model objects or plain dicts.
        if hasattr(segment, "speaker_id"):
            speaker_id = segment.speaker_id
            dialog = segment.dialog
        else:
            speaker_id = segment.get("speaker_id", 0)
            dialog = segment.get("dialog", "")

        voice = get_voice_for_provider(app_config.TTS_SERVICE, speaker_id)
        filename = f"{temp_dir}/{session_id}_{index}.{extension}"

        try:
            if use_kokoro:
                kokoro_service = await get_kokoro_tts_service(
                    lang_code="a"
                )  # American English
                return await kokoro_service.generate_speech(
                    text=dialog, voice=voice, speed=1.0, output_path=filename
                )

            speech_kwargs = {
                "model": app_config.TTS_SERVICE,
                "api_key": app_config.TTS_SERVICE_API_KEY,
                "voice": voice,
                "input": dialog,
                "max_retries": 2,
                "timeout": 600,
            }
            # Only pass api_base when one is configured.
            if app_config.TTS_SERVICE_API_BASE:
                speech_kwargs["api_base"] = app_config.TTS_SERVICE_API_BASE
            response = await aspeech(**speech_kwargs)

            with open(filename, "wb") as f:
                f.write(response.content)
            return filename
        except Exception as e:
            print(f"Error generating speech for segment {index}: {e!s}")
            # Drop any partially written file so a failed segment leaves
            # nothing behind; a missing file is the expected common case.
            try:
                os.remove(filename)
            except OSError:
                pass
            raise

    # Collect every outcome instead of aborting on the first failure, so that
    # files produced by the segments that did succeed can still be cleaned up
    # and no sibling task keeps writing after this coroutine has returned.
    results = await asyncio.gather(
        *(
            generate_speech_for_segment(segment, i)
            for i, segment in enumerate(merged_transcript)
        ),
        return_exceptions=True,
    )
    audio_files = [r for r in results if not isinstance(r, BaseException)]
    first_failure = next(
        (r for r in results if isinstance(r, BaseException)), None
    )

    try:
        if first_failure is not None:
            raise first_failure

        try:
            await _merge_audio_files(audio_files, output_path)
            print(f"Successfully created podcast audio: {output_path}")
        except Exception as e:
            print(f"Error merging audio files: {e!s}")
            raise
    finally:
        _remove_audio_files(audio_files)

    return {
        "podcast_transcript": merged_transcript,
        "final_podcast_file_path": output_path,
    }