feat: Added Podcast Feature, and it's actually fast.

- Fully Async
This commit is contained in:
DESKTOP-RTLN3BA\$punk 2025-05-05 23:18:12 -07:00
parent 10d56acaa8
commit b4bee887bd
19 changed files with 1676 additions and 75 deletions

View file

@ -0,0 +1,44 @@
"""Change podcast_content to podcast_transcript with JSON type
Revision ID: 6
Revises: 5
Create Date: 2023-08-15 00:00:00.000000
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import JSON
# revision identifiers, used by Alembic.
revision: str = '6'
down_revision: Union[str, None] = '5'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Replace the text column ``podcast_content`` with the JSON column ``podcast_transcript``.

    The change is done as add / copy / drop rather than an in-place alter:
    the new column is created first, existing text is copied across, and
    only then is the old column removed.
    """
    # New column starts out as an empty JSON object for every existing row.
    transcript_column = sa.Column(
        'podcast_transcript',
        JSON,
        nullable=False,
        server_default='{}',
    )
    op.add_column('podcasts', transcript_column)

    # Carry over non-empty legacy text by wrapping it under the "text" key.
    # NOTE(review): the Pydantic schema in this commit types podcast_transcript
    # as a list, while migrated rows hold an object -- confirm readers cope.
    wrap_legacy_text = (
        "UPDATE podcasts SET podcast_transcript = jsonb_build_object('text', podcast_content) WHERE podcast_content != ''"
    )
    op.execute(wrap_legacy_text)

    # The legacy column is no longer needed once its data has been copied.
    op.drop_column('podcasts', 'podcast_content')
def downgrade() -> None:
    """Restore the text column ``podcast_content`` and drop ``podcast_transcript``.

    Rows whose transcript has no ``text`` key end up with an empty string.
    """
    # Recreate the legacy column, defaulting every row to the empty string.
    content_column = sa.Column(
        'podcast_content',
        sa.Text(),
        nullable=False,
        server_default='',
    )
    op.add_column('podcasts', content_column)

    # Pull the "text" field back out of the JSON value; COALESCE covers rows
    # where that field is absent.
    unwrap_text = (
        "UPDATE podcasts SET podcast_content = COALESCE((podcast_transcript->>'text'), '')"
    )
    op.execute(unwrap_text)

    # Remove the JSON column introduced by the upgrade.
    op.drop_column('podcasts', 'podcast_transcript')

View file

@ -0,0 +1,28 @@
"""Remove is_generated column from podcasts table
Revision ID: 7
Revises: 6
Create Date: 2023-08-15 01:00:00.000000
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = '7'
down_revision: Union[str, None] = '6'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Remove the ``is_generated`` column from the ``podcasts`` table."""
    table_name, column_name = 'podcasts', 'is_generated'
    op.drop_column(table_name, column_name)
def downgrade() -> None:
    """Re-create the ``is_generated`` column on the ``podcasts`` table.

    The column comes back as a non-nullable boolean with a server-side
    default of ``false``, so existing rows receive that value.
    """
    is_generated_column = sa.Column(
        'is_generated',
        sa.Boolean(),
        nullable=False,
        server_default='false',
    )
    op.add_column('podcasts', is_generated_column)

View file

@ -6,18 +6,26 @@ from .state import State
from .nodes import create_merged_podcast_audio, create_podcast_transcript
# Define a new graph
workflow = StateGraph(State, config_schema=Configuration)
# Add the node to the graph
workflow.add_node("create_podcast_transcript", create_podcast_transcript)
workflow.add_node("create_merged_podcast_audio", create_merged_podcast_audio)
def build_graph():
# Define a new graph
workflow = StateGraph(State, config_schema=Configuration)
# Set the entrypoint as `call_model`
workflow.add_edge("__start__", "create_podcast_transcript")
workflow.add_edge("create_podcast_transcript", "create_merged_podcast_audio")
workflow.add_edge("create_merged_podcast_audio", "__end__")
# Add the node to the graph
workflow.add_node("create_podcast_transcript", create_podcast_transcript)
workflow.add_node("create_merged_podcast_audio", create_merged_podcast_audio)
# Compile the workflow into an executable graph
graph = workflow.compile()
graph.name = "Surfsense Podcaster" # This defines the custom name in LangSmith
# Set the entrypoint as `call_model`
workflow.add_edge("__start__", "create_podcast_transcript")
workflow.add_edge("create_podcast_transcript", "create_merged_podcast_audio")
workflow.add_edge("create_merged_podcast_audio", "__end__")
# Compile the workflow into an executable graph
graph = workflow.compile()
graph.name = "Surfsense Podcaster" # This defines the custom name in LangSmith
return graph
# Compile the graph once when the module is loaded
graph = build_graph()

View file

@ -28,7 +28,7 @@ async def create_podcast_transcript(state: State, config: RunnableConfig) -> Dic
# Create the messages
messages = [
SystemMessage(content=prompt),
HumanMessage(content=state.source_content)
HumanMessage(content=f"<source_content>{state.source_content}</source_content>")
]
# Generate the podcast transcript

View file

@ -106,6 +106,6 @@ Output:
}}
</examples>
Transform the source material into a lively and engaging podcast conversation. Craft dialogue that showcases authentic host chemistry and natural interaction (including occasional disagreement, building on points, or asking follow-up questions). Use varied speech patterns reflecting real human conversation, ensuring the final script effectively educates *and* entertains the listener while keeping within a 3-minute audio duration.
Transform the source material into a lively and engaging podcast conversation. Craft dialogue that showcases authentic host chemistry and natural interaction (including occasional disagreement, building on points, or asking follow-up questions). Use varied speech patterns reflecting real human conversation, ensuring the final script effectively educates *and* entertains the listener while keeping within a 5-minute audio duration.
</podcast_generation_system>
"""

View file

@ -110,8 +110,7 @@ class Podcast(BaseModel, TimestampMixin):
__tablename__ = "podcasts"
title = Column(String, nullable=False, index=True)
is_generated = Column(Boolean, nullable=False, default=False)
podcast_content = Column(Text, nullable=False, default="")
podcast_transcript = Column(JSON, nullable=False, default={})
file_location = Column(String(500), nullable=False, default="")
search_space_id = Column(Integer, ForeignKey("searchspaces.id", ondelete='CASCADE'), nullable=False)

View file

@ -1,12 +1,16 @@
from fastapi import APIRouter, Depends, HTTPException
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.exc import IntegrityError, SQLAlchemyError
from typing import List
from app.db import get_async_session, User, SearchSpace, Podcast
from app.schemas import PodcastCreate, PodcastUpdate, PodcastRead
from app.db import get_async_session, User, SearchSpace, Podcast, Chat
from app.schemas import PodcastCreate, PodcastUpdate, PodcastRead, PodcastGenerateRequest
from app.users import current_active_user
from app.utils.check_ownership import check_ownership
from app.tasks.podcast_tasks import generate_chat_podcast
from fastapi.responses import StreamingResponse
import os
from pathlib import Path
router = APIRouter()
@ -119,4 +123,121 @@ async def delete_podcast(
raise he
except SQLAlchemyError:
await session.rollback()
raise HTTPException(status_code=500, detail="Database error occurred while deleting podcast")
raise HTTPException(status_code=500, detail="Database error occurred while deleting podcast")
async def generate_chat_podcast_with_new_session(
    chat_id: int,
    search_space_id: int,
    podcast_title: str = "SurfSense Podcast"
):
    """Generate a podcast for one chat using a freshly opened DB session.

    Opens its own session from ``async_session_maker`` instead of taking one
    as an argument, then delegates to ``generate_chat_podcast``.

    Args:
        chat_id: ID of the chat to turn into a podcast.
        search_space_id: ID of the search space the chat belongs to.
        podcast_title: Title stored on the resulting podcast.

    Any exception raised during generation is logged and swallowed; nothing
    is re-raised to the caller.
    """
    import logging

    from app.db import async_session_maker

    async with async_session_maker() as session:
        try:
            await generate_chat_podcast(session, chat_id, search_space_id, podcast_title)
        except Exception as e:
            # logging.exception records the traceback alongside the message,
            # which a plain logging.error call would drop.
            logging.exception("Error generating podcast from chat: %s", e)
@router.post("/podcasts/generate/")
async def generate_podcast(
    request: PodcastGenerateRequest,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
    fastapi_background_tasks: BackgroundTasks = BackgroundTasks()
):
    """Schedule background podcast generation for the requested chats.

    Verifies that the user owns the search space, then (for ``CHAT``
    requests) that every requested chat belongs to that user and search
    space, and queues one background task per chat.

    Returns:
        A dict with a ``message`` key confirming that generation started.

    Raises:
        HTTPException: 403 when a requested chat is not owned by the user or
            is outside the search space; 400 on an integrity error; 500 on
            any other database or unexpected error.
    """
    try:
        # Check if the user owns the search space
        await check_ownership(session, SearchSpace, request.search_space_id, user)

        if request.type == "CHAT":
            # De-duplicate the requested IDs (keeping their order). The query
            # below returns each chat at most once, so comparing its row
            # count against a list containing repeats would wrongly yield 403.
            requested_chat_ids = list(dict.fromkeys(request.ids))

            # Verify that all chat IDs belong to this user and search space
            query = select(Chat).filter(
                Chat.id.in_(requested_chat_ids),
                Chat.search_space_id == request.search_space_id
            ).join(SearchSpace).filter(SearchSpace.user_id == user.id)
            result = await session.execute(query)
            valid_chat_ids = [chat.id for chat in result.scalars().all()]

            # If any requested ID is not in valid IDs, raise error immediately
            if len(valid_chat_ids) != len(requested_chat_ids):
                raise HTTPException(
                    status_code=403,
                    detail="One or more chat IDs do not belong to this user or search space"
                )

            # Queue one background task per validated chat
            for chat_id in valid_chat_ids:
                fastapi_background_tasks.add_task(
                    generate_chat_podcast_with_new_session,
                    chat_id,
                    request.search_space_id,
                    request.podcast_title
                )

        # NOTE(review): requests with type "DOCUMENT" reach this point without
        # scheduling any work yet still receive the success message -- confirm
        # whether that is intended until document podcasts are implemented.
        return {
            "message": "Podcast generation started",
        }
    except HTTPException:
        raise
    except IntegrityError:
        await session.rollback()
        raise HTTPException(status_code=400, detail="Podcast generation failed due to constraint violation")
    except SQLAlchemyError:
        await session.rollback()
        raise HTTPException(status_code=500, detail="Database error occurred while generating podcast")
    except Exception as e:
        await session.rollback()
        raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {str(e)}")
@router.get("/podcasts/{podcast_id}/stream")
async def stream_podcast(
    podcast_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user)
):
    """Stream a podcast audio file.

    Looks up the podcast through its search space so that only the owning
    user can reach it, then streams the file stored at ``file_location``.

    Raises:
        HTTPException: 404 when the podcast is missing, not owned by the
            user, or its audio file is not on disk; 500 on any other error.
    """
    try:
        # Get the podcast and check if user has access
        result = await session.execute(
            select(Podcast)
            .join(SearchSpace)
            .filter(Podcast.id == podcast_id, SearchSpace.user_id == user.id)
        )
        podcast = result.scalars().first()

        if not podcast:
            raise HTTPException(
                status_code=404,
                detail="Podcast not found or you don't have permission to access it"
            )

        # Get the file path
        file_path = podcast.file_location

        # Check if the file exists
        if not os.path.isfile(file_path):
            raise HTTPException(status_code=404, detail="Podcast audio file not found")

        # Stream the file in fixed-size chunks. Iterating a binary file
        # directly splits on newline bytes, which gives arbitrary and often
        # tiny chunk sizes for audio data.
        def iterfile(chunk_size: int = 64 * 1024):
            with open(file_path, mode="rb") as file_like:
                while True:
                    chunk = file_like.read(chunk_size)
                    if not chunk:
                        break
                    yield chunk

        # Return a streaming response with appropriate headers
        # NOTE(review): "Accept-Ranges: bytes" is advertised, but this handler
        # never reads a Range header and always sends the whole file.
        return StreamingResponse(
            iterfile(),
            media_type="audio/mpeg",
            headers={
                "Accept-Ranges": "bytes",
                "Content-Disposition": f"inline; filename={Path(file_path).name}"
            }
        )
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error streaming podcast: {str(e)}")

View file

@ -10,7 +10,7 @@ from .documents import (
DocumentRead,
)
from .chunks import ChunkBase, ChunkCreate, ChunkUpdate, ChunkRead
from .podcasts import PodcastBase, PodcastCreate, PodcastUpdate, PodcastRead
from .podcasts import PodcastBase, PodcastCreate, PodcastUpdate, PodcastRead, PodcastGenerateRequest
from .chats import ChatBase, ChatCreate, ChatUpdate, ChatRead, AISDKChatRequest
from .search_source_connector import SearchSourceConnectorBase, SearchSourceConnectorCreate, SearchSourceConnectorUpdate, SearchSourceConnectorRead
@ -39,6 +39,7 @@ __all__ = [
"PodcastCreate",
"PodcastUpdate",
"PodcastRead",
"PodcastGenerateRequest",
"ChatBase",
"ChatCreate",
"ChatUpdate",

View file

@ -1,8 +1,10 @@
from typing import Any, Dict, List, Optional
from pydantic import BaseModel
from sqlalchemy import JSON
from .base import IDModel, TimestampModel
from app.db import ChatType
from pydantic import BaseModel
from .base import IDModel, TimestampModel
class ChatBase(BaseModel):
type: ChatType

View file

@ -1,10 +1,10 @@
from pydantic import BaseModel
from typing import Any, List, Literal
from .base import IDModel, TimestampModel
class PodcastBase(BaseModel):
title: str
is_generated: bool = False
podcast_content: str = ""
podcast_transcript: List[Any]
file_location: str = ""
search_space_id: int
@ -16,4 +16,10 @@ class PodcastUpdate(PodcastBase):
class PodcastRead(PodcastBase, IDModel, TimestampModel):
class Config:
from_attributes = True
from_attributes = True
# Request body for the podcast generation endpoint.
class PodcastGenerateRequest(BaseModel):
    # Kind of source the IDs refer to; only these two values are accepted.
    type: Literal["DOCUMENT", "CHAT"]
    # IDs of the source items (of the kind given by `type`).
    ids: List[int]
    # Search space the source items are looked up in.
    search_space_id: int
    # Title for the generated podcast; optional, with a fixed default.
    podcast_title: str = "SurfSense Podcast"

View file

@ -0,0 +1,94 @@
from sqlalchemy.ext.asyncio import AsyncSession
from app.schemas import PodcastGenerateRequest
from typing import List
from sqlalchemy import select
from app.db import Chat, Podcast
from app.agents.podcaster.graph import graph as podcaster_graph
# Import State via the same `app` package root as the other imports in this
# module, so it resolves to the same module the podcaster graph itself uses.
from app.agents.podcaster.state import State
async def generate_document_podcast(
    session: AsyncSession,
    document_id: int,
    search_space_id: int,
    user_id: int
):
    """Placeholder for document-based podcast generation (not implemented).

    Currently does nothing and returns ``None``.
    """
    # TODO: Need to fetch the document chunks, then concatenate them and pass them to the podcast generation model
    return None
def _extract_answer_text(message: dict) -> str:
    """Return the answer text from an assistant message's last annotation."""
    annotations = message.get("annotations") or []
    if not annotations:
        # Messages without annotations contribute an empty answer rather
        # than aborting the whole generation with KeyError/IndexError.
        return ""

    # Only the last annotation is inspected, and only if it is an "ANSWER".
    answer_annotation = annotations[-1]
    if answer_annotation.get("type") != "ANSWER":
        return ""

    answer_text = answer_annotation.get("content", "")
    # If content is a list, join it into a single string
    if isinstance(answer_text, list):
        return "\n".join(answer_text)
    return answer_text


def _build_chat_history(messages) -> str:
    """Render chat messages as a ``<chat_history>`` tagged string."""
    parts = ["<chat_history>"]
    for message in messages or []:
        role = message.get("role")
        if role == "user":
            parts.append(f"<user_message>{message.get('content', '')}</user_message>")
        elif role == "assistant":
            parts.append(f"<assistant_message>{_extract_answer_text(message)}</assistant_message>")
    parts.append("</chat_history>")
    return "".join(parts)


async def generate_chat_podcast(
    session: AsyncSession,
    chat_id: int,
    search_space_id: int,
    podcast_title: str
):
    """Generate a podcast from a chat and persist it as a ``Podcast`` row.

    Args:
        session: Async database session used to load the chat and store the
            podcast.
        chat_id: ID of the chat to convert.
        search_space_id: Search space the chat must belong to; also stored
            on the new podcast.
        podcast_title: Title stored on the new podcast.

    Returns:
        The newly created and refreshed ``Podcast`` instance.

    Raises:
        ValueError: If no chat with ``chat_id`` exists in the search space.
    """
    # Fetch the chat with the specified ID
    query = select(Chat).filter(
        Chat.id == chat_id,
        Chat.search_space_id == search_space_id
    )
    result = await session.execute(query)
    chat = result.scalars().first()

    if not chat:
        raise ValueError(f"Chat with id {chat_id} not found in search space {search_space_id}")

    # Create chat history structure
    chat_history_str = _build_chat_history(chat.messages)

    # Pass it to the SurfSense Podcaster
    # NOTE(review): the graph config uses a fixed "Surfsense" title, while the
    # caller-supplied podcast_title is only used for the stored row -- confirm.
    config = {
        "configurable": {
            "podcast_title" : "Surfsense",
        }
    }

    # Initialize the graph state with the rendered chat history
    initial_state = State(
        source_content=chat_history_str,
    )

    # Run the graph directly
    result = await podcaster_graph.ainvoke(initial_state, config=config)

    # Convert podcast transcript entries to serializable format
    serializable_transcript = [
        {"speaker_id": entry.speaker_id, "dialog": entry.dialog}
        for entry in result["podcast_transcript"]
    ]

    # Create a new podcast entry
    podcast = Podcast(
        title=podcast_title,
        podcast_transcript=serializable_transcript,
        file_location=result["final_podcast_file_path"],
        search_space_id=search_space_id
    )

    # Add to session and commit
    session.add(podcast)
    await session.commit()
    await session.refresh(podcast)

    return podcast