feat: added attachment support

This commit is contained in:
DESKTOP-RTLN3BA\$punk 2025-12-21 22:26:33 -08:00
parent bb971460fc
commit c2dcb2045d
62 changed files with 1166 additions and 9012 deletions

View file

@ -9,7 +9,6 @@ from sqlalchemy import (
ARRAY,
JSON,
TIMESTAMP,
BigInteger,
Boolean,
Column,
Enum as SQLAlchemyEnum,
@ -423,6 +422,25 @@ class Chunk(BaseModel, TimestampMixin):
document = relationship("Document", back_populates="chunks")
class Podcast(BaseModel, TimestampMixin):
"""Podcast model for storing generated podcasts."""
__tablename__ = "podcasts"
title = Column(String(500), nullable=False)
podcast_transcript = Column(JSONB, nullable=True) # List of transcript entries
file_location = Column(Text, nullable=True) # Path to the audio file
search_space_id = Column(
Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False
)
search_space = relationship("SearchSpace", back_populates="podcasts")
# Optional: link to chat thread (null for content-based podcasts from new-chat)
chat_id = Column(Integer, nullable=True)
chat_state_version = Column(String(100), nullable=True)
class SearchSpace(BaseModel, TimestampMixin):
__tablename__ = "searchspaces"
@ -459,6 +477,12 @@ class SearchSpace(BaseModel, TimestampMixin):
order_by="NewChatThread.updated_at.desc()",
cascade="all, delete-orphan",
)
podcasts = relationship(
"Podcast",
back_populates="search_space",
order_by="Podcast.id.desc()",
cascade="all, delete-orphan",
)
logs = relationship(
"Log",
back_populates="search_space",

View file

@ -16,6 +16,7 @@ from .logs_routes import router as logs_router
from .luma_add_connector_route import router as luma_add_connector_router
from .new_chat_routes import router as new_chat_router
from .notes_routes import router as notes_router
from .podcasts_routes import router as podcasts_router
from .rbac_routes import router as rbac_router
from .search_source_connectors_routes import router as search_source_connectors_router
from .search_spaces_routes import router as search_spaces_router
@ -28,6 +29,7 @@ router.include_router(editor_router)
router.include_router(documents_router)
router.include_router(notes_router)
router.include_router(new_chat_router) # Chat with assistant-ui persistence
router.include_router(podcasts_router) # Podcast task status and audio
router.include_router(search_source_connectors_router)
router.include_router(google_calendar_add_connector_router)
router.include_router(google_gmail_add_connector_router)

View file

@ -8,11 +8,16 @@ These endpoints support the ThreadHistoryAdapter pattern from assistant-ui:
- PUT /threads/{thread_id} - Update thread (rename, archive)
- DELETE /threads/{thread_id} - Delete thread
- POST /threads/{thread_id}/messages - Append message
- POST /attachments/process - Process attachments for chat context
"""
import contextlib
import os
import tempfile
import uuid
from datetime import UTC, datetime
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile
from fastapi.responses import StreamingResponse
from sqlalchemy.exc import IntegrityError, OperationalError
from sqlalchemy.ext.asyncio import AsyncSession
@ -650,10 +655,10 @@ async def handle_new_chat(
):
"""
Stream chat responses from the deep agent.
This endpoint handles the new chat functionality with streaming responses
using Server-Sent Events (SSE) format compatible with Vercel AI SDK.
Requires CHATS_CREATE permission.
"""
try:
@ -695,6 +700,7 @@ async def handle_new_chat(
session=session,
llm_config_id=llm_config_id,
messages=request.messages,
attachments=request.attachments,
),
media_type="text/event-stream",
headers={
@ -711,3 +717,185 @@ async def handle_new_chat(
status_code=500,
detail=f"An unexpected error occurred: {e!s}",
) from None
# =============================================================================
# Attachment Processing Endpoint
# =============================================================================
@router.post("/attachments/process")
async def process_attachment(
file: UploadFile = File(...),
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
"""
Process an attachment file and extract its content as markdown.
This endpoint uses the configured ETL service to parse files and return
the extracted content that can be used as context in chat messages.
Supported file types depend on the configured ETL_SERVICE:
- Markdown/Text files: .md, .markdown, .txt (always supported)
- Audio files: .mp3, .mp4, .mpeg, .mpga, .m4a, .wav, .webm (if STT configured)
- Documents: .pdf, .docx, .doc, .pptx, .xlsx (depends on ETL service)
Returns:
JSON with attachment id, name, type, and extracted content
"""
from app.config import config as app_config
if not file.filename:
raise HTTPException(status_code=400, detail="No filename provided")
filename = file.filename
attachment_id = str(uuid.uuid4())
try:
# Save file to a temporary location
file_ext = os.path.splitext(filename)[1].lower()
with tempfile.NamedTemporaryFile(delete=False, suffix=file_ext) as temp_file:
temp_path = temp_file.name
content = await file.read()
temp_file.write(content)
extracted_content = ""
# Process based on file type
if file_ext in (".md", ".markdown", ".txt"):
# For text/markdown files, read content directly
with open(temp_path, encoding="utf-8") as f:
extracted_content = f.read()
elif file_ext in (".mp3", ".mp4", ".mpeg", ".mpga", ".m4a", ".wav", ".webm"):
# Audio files - transcribe if STT service is configured
if not app_config.STT_SERVICE:
raise HTTPException(
status_code=422,
detail="Audio transcription is not configured. Please set STT_SERVICE.",
)
stt_service_type = (
"local" if app_config.STT_SERVICE.startswith("local/") else "external"
)
if stt_service_type == "local":
from app.services.stt_service import stt_service
result = stt_service.transcribe_file(temp_path)
extracted_content = result.get("text", "")
else:
from litellm import atranscription
with open(temp_path, "rb") as audio_file:
transcription_kwargs = {
"model": app_config.STT_SERVICE,
"file": audio_file,
"api_key": app_config.STT_SERVICE_API_KEY,
}
if app_config.STT_SERVICE_API_BASE:
transcription_kwargs["api_base"] = (
app_config.STT_SERVICE_API_BASE
)
transcription_response = await atranscription(
**transcription_kwargs
)
extracted_content = transcription_response.get("text", "")
if extracted_content:
extracted_content = (
f"# Transcription of {filename}\n\n{extracted_content}"
)
else:
# Document files - use configured ETL service
if app_config.ETL_SERVICE == "UNSTRUCTURED":
from langchain_unstructured import UnstructuredLoader
from app.utils.document_converters import convert_document_to_markdown
loader = UnstructuredLoader(
temp_path,
mode="elements",
post_processors=[],
languages=["eng"],
include_orig_elements=False,
include_metadata=False,
strategy="auto",
)
docs = await loader.aload()
extracted_content = await convert_document_to_markdown(docs)
elif app_config.ETL_SERVICE == "LLAMACLOUD":
from llama_cloud_services import LlamaParse
from llama_cloud_services.parse.utils import ResultType
parser = LlamaParse(
api_key=app_config.LLAMA_CLOUD_API_KEY,
num_workers=1,
verbose=False,
language="en",
result_type=ResultType.MD,
)
result = await parser.aparse(temp_path)
markdown_documents = await result.aget_markdown_documents(
split_by_page=False
)
if markdown_documents:
extracted_content = "\n\n".join(
doc.text for doc in markdown_documents
)
elif app_config.ETL_SERVICE == "DOCLING":
from app.services.docling_service import create_docling_service
docling_service = create_docling_service()
result = await docling_service.process_document(temp_path, filename)
extracted_content = result.get("content", "")
else:
raise HTTPException(
status_code=422,
detail=f"ETL service not configured or unsupported file type: {file_ext}",
)
# Clean up temp file
with contextlib.suppress(Exception):
os.unlink(temp_path)
if not extracted_content:
raise HTTPException(
status_code=422,
detail=f"Could not extract content from file: {filename}",
)
# Determine attachment type (must be one of: "image", "document", "file")
# assistant-ui only supports these three types
if file_ext in (".png", ".jpg", ".jpeg", ".gif", ".webp"):
attachment_type = "image"
else:
# All other files (including audio, documents, text) are treated as "document"
attachment_type = "document"
return {
"id": attachment_id,
"name": filename,
"type": attachment_type,
"content": extracted_content,
"contentLength": len(extracted_content),
}
except HTTPException:
raise
except Exception as e:
# Clean up temp file on error
with contextlib.suppress(Exception):
os.unlink(temp_path)
raise HTTPException(
status_code=500,
detail=f"Failed to process attachment: {e!s}",
) from e

View file

@ -0,0 +1,277 @@
"""
Podcast routes for task status polling and audio retrieval.
These routes support the podcast generation feature in new-chat.
Note: The old Chat-based podcast generation has been removed.
"""
import os
from pathlib import Path
from celery.result import AsyncResult
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import StreamingResponse
from sqlalchemy import select
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
from app.celery_app import celery_app
from app.db import (
Permission,
Podcast,
SearchSpace,
SearchSpaceMembership,
User,
get_async_session,
)
from app.schemas import PodcastRead
from app.users import current_active_user
from app.utils.rbac import check_permission
router = APIRouter()
@router.get("/podcasts", response_model=list[PodcastRead])
async def read_podcasts(
skip: int = 0,
limit: int = 100,
search_space_id: int | None = None,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
"""
List podcasts the user has access to.
Requires PODCASTS_READ permission for the search space(s).
"""
if skip < 0 or limit < 1:
raise HTTPException(status_code=400, detail="Invalid pagination parameters")
try:
if search_space_id is not None:
# Check permission for specific search space
await check_permission(
session,
user,
search_space_id,
Permission.PODCASTS_READ.value,
"You don't have permission to read podcasts in this search space",
)
result = await session.execute(
select(Podcast)
.filter(Podcast.search_space_id == search_space_id)
.offset(skip)
.limit(limit)
)
else:
# Get podcasts from all search spaces user has membership in
result = await session.execute(
select(Podcast)
.join(SearchSpace)
.join(SearchSpaceMembership)
.filter(SearchSpaceMembership.user_id == user.id)
.offset(skip)
.limit(limit)
)
return result.scalars().all()
except HTTPException:
raise
except SQLAlchemyError:
raise HTTPException(
status_code=500, detail="Database error occurred while fetching podcasts"
) from None
@router.get("/podcasts/{podcast_id}", response_model=PodcastRead)
async def read_podcast(
podcast_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
"""
Get a specific podcast by ID.
Requires PODCASTS_READ permission for the search space.
"""
try:
result = await session.execute(select(Podcast).filter(Podcast.id == podcast_id))
podcast = result.scalars().first()
if not podcast:
raise HTTPException(
status_code=404,
detail="Podcast not found",
)
# Check permission for the search space
await check_permission(
session,
user,
podcast.search_space_id,
Permission.PODCASTS_READ.value,
"You don't have permission to read podcasts in this search space",
)
return podcast
except HTTPException as he:
raise he
except SQLAlchemyError:
raise HTTPException(
status_code=500, detail="Database error occurred while fetching podcast"
) from None
@router.delete("/podcasts/{podcast_id}", response_model=dict)
async def delete_podcast(
podcast_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
"""
Delete a podcast.
Requires PODCASTS_DELETE permission for the search space.
"""
try:
result = await session.execute(select(Podcast).filter(Podcast.id == podcast_id))
db_podcast = result.scalars().first()
if not db_podcast:
raise HTTPException(status_code=404, detail="Podcast not found")
# Check permission for the search space
await check_permission(
session,
user,
db_podcast.search_space_id,
Permission.PODCASTS_DELETE.value,
"You don't have permission to delete podcasts in this search space",
)
await session.delete(db_podcast)
await session.commit()
return {"message": "Podcast deleted successfully"}
except HTTPException as he:
raise he
except SQLAlchemyError:
await session.rollback()
raise HTTPException(
status_code=500, detail="Database error occurred while deleting podcast"
) from None
@router.get("/podcasts/{podcast_id}/stream")
@router.get("/podcasts/{podcast_id}/audio")
async def stream_podcast(
podcast_id: int,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
"""
Stream a podcast audio file.
Requires PODCASTS_READ permission for the search space.
Note: Both /stream and /audio endpoints are supported for compatibility.
"""
try:
result = await session.execute(select(Podcast).filter(Podcast.id == podcast_id))
podcast = result.scalars().first()
if not podcast:
raise HTTPException(
status_code=404,
detail="Podcast not found",
)
# Check permission for the search space
await check_permission(
session,
user,
podcast.search_space_id,
Permission.PODCASTS_READ.value,
"You don't have permission to access podcasts in this search space",
)
# Get the file path
file_path = podcast.file_location
# Check if the file exists
if not file_path or not os.path.isfile(file_path):
raise HTTPException(status_code=404, detail="Podcast audio file not found")
# Define a generator function to stream the file
def iterfile():
with open(file_path, mode="rb") as file_like:
yield from file_like
# Return a streaming response with appropriate headers
return StreamingResponse(
iterfile(),
media_type="audio/mpeg",
headers={
"Accept-Ranges": "bytes",
"Content-Disposition": f"inline; filename={Path(file_path).name}",
},
)
except HTTPException as he:
raise he
except Exception as e:
raise HTTPException(
status_code=500, detail=f"Error streaming podcast: {e!s}"
) from e
@router.get("/podcasts/task/{task_id}/status")
async def get_podcast_task_status(
task_id: str,
user: User = Depends(current_active_user),
):
"""
Get the status of a podcast generation task.
Used by new-chat frontend to poll for completion.
Returns:
- status: "processing" | "success" | "error"
- podcast_id: (only if status == "success")
- title: (only if status == "success")
- error: (only if status == "error")
"""
try:
result = AsyncResult(task_id, app=celery_app)
if result.ready():
# Task completed
if result.successful():
task_result = result.result
if isinstance(task_result, dict):
if task_result.get("status") == "success":
return {
"status": "success",
"podcast_id": task_result.get("podcast_id"),
"title": task_result.get("title"),
"transcript_entries": task_result.get("transcript_entries"),
}
else:
return {
"status": "error",
"error": task_result.get("error", "Unknown error"),
}
else:
return {
"status": "error",
"error": "Unexpected task result format",
}
else:
# Task failed
return {
"status": "error",
"error": str(result.result) if result.result else "Task failed",
}
else:
# Task still processing
return {
"status": "processing",
"state": result.state,
}
except Exception as e:
raise HTTPException(
status_code=500, detail=f"Error checking task status: {e!s}"
) from e

View file

@ -26,6 +26,7 @@ from .new_chat import (
ThreadListItem,
ThreadListResponse,
)
from .podcasts import PodcastBase, PodcastCreate, PodcastRead, PodcastUpdate
from .rbac_schemas import (
InviteAcceptRequest,
InviteAcceptResponse,
@ -61,17 +62,6 @@ from .users import UserCreate, UserRead, UserUpdate
__all__ = [
# Chat schemas (assistant-ui integration)
"ChatMessage",
"NewChatMessageAppend",
"NewChatMessageCreate",
"NewChatMessageRead",
"NewChatRequest",
"NewChatThreadCreate",
"NewChatThreadRead",
"NewChatThreadUpdate",
"NewChatThreadWithMessages",
"ThreadHistoryLoadResponse",
"ThreadListItem",
"ThreadListResponse",
# Chunk schemas
"ChunkBase",
"ChunkCreate",
@ -85,10 +75,15 @@ __all__ = [
"DocumentsCreate",
"ExtensionDocumentContent",
"ExtensionDocumentMetadata",
"PaginatedResponse",
# Base schemas
"IDModel",
"TimestampModel",
# RBAC schemas
"InviteAcceptRequest",
"InviteAcceptResponse",
"InviteCreate",
"InviteInfoResponse",
"InviteRead",
"InviteUpdate",
# LLM Config schemas
"LLMConfigBase",
"LLMConfigCreate",
@ -100,18 +95,25 @@ __all__ = [
"LogFilter",
"LogRead",
"LogUpdate",
# RBAC schemas
"InviteAcceptRequest",
"InviteAcceptResponse",
"InviteCreate",
"InviteInfoResponse",
"InviteRead",
"InviteUpdate",
"MembershipRead",
"MembershipReadWithUser",
"MembershipUpdate",
"NewChatMessageAppend",
"NewChatMessageCreate",
"NewChatMessageRead",
"NewChatRequest",
"NewChatThreadCreate",
"NewChatThreadRead",
"NewChatThreadUpdate",
"NewChatThreadWithMessages",
"PaginatedResponse",
"PermissionInfo",
"PermissionsListResponse",
# Podcast schemas
"PodcastBase",
"PodcastCreate",
"PodcastRead",
"PodcastUpdate",
"RoleCreate",
"RoleRead",
"RoleUpdate",
@ -126,6 +128,10 @@ __all__ = [
"SearchSpaceRead",
"SearchSpaceUpdate",
"SearchSpaceWithStats",
"ThreadHistoryLoadResponse",
"ThreadListItem",
"ThreadListResponse",
"TimestampModel",
# User schemas
"UserCreate",
"UserRead",

View file

@ -141,6 +141,15 @@ class ChatMessage(BaseModel):
content: str
class ChatAttachment(BaseModel):
"""An attachment with its extracted content for chat context."""
id: str # Unique attachment ID
name: str # Original filename
type: str # Attachment type: document, image, audio
content: str # Extracted markdown content from the file
class NewChatRequest(BaseModel):
"""Request schema for the deep agent chat endpoint."""
@ -148,3 +157,6 @@ class NewChatRequest(BaseModel):
user_query: str
search_space_id: int
messages: list[ChatMessage] | None = None # Optional chat history from frontend
attachments: list[ChatAttachment] | None = (
None # Optional attachments with extracted content
)

View file

@ -0,0 +1,41 @@
"""Podcast schemas for API responses."""
from datetime import datetime
from typing import Any
from pydantic import BaseModel
class PodcastBase(BaseModel):
"""Base podcast schema."""
title: str
podcast_transcript: list[dict[str, Any]] | None = None
file_location: str | None = None
search_space_id: int
chat_id: int | None = None
chat_state_version: str | None = None
class PodcastCreate(PodcastBase):
"""Schema for creating a podcast."""
pass
class PodcastUpdate(BaseModel):
"""Schema for updating a podcast."""
title: str | None = None
podcast_transcript: list[dict[str, Any]] | None = None
file_location: str | None = None
class PodcastRead(PodcastBase):
"""Schema for reading a podcast."""
id: int
created_at: datetime
class Config:
from_attributes = True

View file

@ -13,7 +13,6 @@ from app.agents.podcaster.state import State as PodcasterState
from app.celery_app import celery_app
from app.config import config
from app.db import Podcast
from app.tasks.podcast_tasks import generate_chat_podcast
logger = logging.getLogger(__name__)
@ -40,58 +39,6 @@ def get_celery_session_maker():
return async_sessionmaker(engine, expire_on_commit=False)
@celery_app.task(name="generate_chat_podcast", bind=True)
def generate_chat_podcast_task(
self,
chat_id: int,
search_space_id: int,
user_id: int,
podcast_title: str | None = None,
user_prompt: str | None = None,
):
"""
Celery task to generate podcast from chat.
Args:
chat_id: ID of the chat to generate podcast from
search_space_id: ID of the search space
user_id: ID of the user,
podcast_title: Title for the podcast
user_prompt: Optional prompt from the user to guide the podcast generation
"""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_generate_chat_podcast(
chat_id, search_space_id, user_id, podcast_title, user_prompt
)
)
loop.run_until_complete(loop.shutdown_asyncgens())
finally:
asyncio.set_event_loop(None)
loop.close()
async def _generate_chat_podcast(
chat_id: int,
search_space_id: int,
user_id: int,
podcast_title: str | None = None,
user_prompt: str | None = None,
):
"""Generate chat podcast with new session."""
async with get_celery_session_maker()() as session:
try:
await generate_chat_podcast(
session, chat_id, search_space_id, user_id, podcast_title, user_prompt
)
except Exception as e:
logger.error(f"Error generating podcast from chat: {e!s}")
raise
# =============================================================================
# Content-based podcast generation (for new-chat)
# =============================================================================

View file

@ -18,11 +18,28 @@ from app.agents.new_chat.llm_config import (
create_chat_litellm_from_config,
load_llm_config_from_yaml,
)
from app.schemas.new_chat import ChatMessage
from app.schemas.new_chat import ChatAttachment, ChatMessage
from app.services.connector_service import ConnectorService
from app.services.new_streaming_service import VercelStreamingService
def format_attachments_as_context(attachments: list[ChatAttachment]) -> str:
"""Format attachments as context for the agent."""
if not attachments:
return ""
context_parts = ["<user_attachments>"]
for i, attachment in enumerate(attachments, 1):
context_parts.append(
f"<attachment index='{i}' name='{attachment.name}' type='{attachment.type}'>"
)
context_parts.append(f"<![CDATA[{attachment.content}]]>")
context_parts.append("</attachment>")
context_parts.append("</user_attachments>")
return "\n".join(context_parts)
async def stream_new_chat(
user_query: str,
user_id: str | UUID,
@ -31,6 +48,7 @@ async def stream_new_chat(
session: AsyncSession,
llm_config_id: int = -1,
messages: list[ChatMessage] | None = None,
attachments: list[ChatAttachment] | None = None,
) -> AsyncGenerator[str, None]:
"""
Stream chat responses from the new SurfSense deep agent.
@ -96,6 +114,14 @@ async def stream_new_chat(
# Build input with message history from frontend
langchain_messages = []
# Format the user query with attachment context if any
final_query = user_query
if attachments:
attachment_context = format_attachments_as_context(attachments)
final_query = (
f"{attachment_context}\n\n<user_query>{user_query}</user_query>"
)
# if messages:
# # Convert frontend messages to LangChain format
# for msg in messages:
@ -104,8 +130,8 @@ async def stream_new_chat(
# elif msg.role == "assistant":
# langchain_messages.append(AIMessage(content=msg.content))
# else:
# Fallback: just use the current user query
langchain_messages.append(HumanMessage(content=user_query))
# Fallback: just use the current user query with attachment context
langchain_messages.append(HumanMessage(content=final_query))
input_state = {
# Lets not pass this message atm because we are using the checkpointer to manage the conversation history

View file

@ -1,11 +1,15 @@
from sqlalchemy import select
from sqlalchemy.exc import SQLAlchemyError
"""
Legacy podcast task for old chat system.
NOTE: The old Chat model has been removed. This module is kept for backwards
compatibility but the generate_chat_podcast function will raise an error
if called. Use generate_content_podcast_task in celery_tasks/podcast_tasks.py
for new-chat podcast generation instead.
"""
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.podcaster.graph import graph as podcaster_graph
from app.agents.podcaster.state import State
from app.db import Chat, Podcast
from app.services.task_logging_service import TaskLoggingService
from app.db import Podcast # noqa: F401 - imported for backwards compatibility
async def generate_chat_podcast(
@ -16,196 +20,13 @@ async def generate_chat_podcast(
podcast_title: str | None = None,
user_prompt: str | None = None,
):
task_logger = TaskLoggingService(session, search_space_id)
"""
Legacy function for generating podcasts from old chat system.
# Log task start
log_entry = await task_logger.log_task_start(
task_name="generate_chat_podcast",
source="podcast_task",
message=f"Starting podcast generation for chat {chat_id}",
metadata={
"chat_id": chat_id,
"search_space_id": search_space_id,
"podcast_title": podcast_title,
"user_id": str(user_id),
"user_prompt": user_prompt,
},
This function is deprecated as the old Chat model has been removed.
Use generate_content_podcast_task for new-chat podcast generation.
"""
raise NotImplementedError(
"generate_chat_podcast is deprecated. The old Chat model has been removed. "
"Use generate_content_podcast_task for podcast generation from new-chat."
)
try:
# Fetch the chat with the specified ID
await task_logger.log_task_progress(
log_entry, f"Fetching chat {chat_id} from database", {"stage": "fetch_chat"}
)
query = select(Chat).filter(
Chat.id == chat_id, Chat.search_space_id == search_space_id
)
result = await session.execute(query)
chat = result.scalars().first()
if not chat:
await task_logger.log_task_failure(
log_entry,
f"Chat with id {chat_id} not found in search space {search_space_id}",
"Chat not found",
{"error_type": "ChatNotFound"},
)
raise ValueError(
f"Chat with id {chat_id} not found in search space {search_space_id}"
)
# Create chat history structure
await task_logger.log_task_progress(
log_entry,
f"Processing chat history for chat {chat_id}",
{"stage": "process_chat_history", "message_count": len(chat.messages)},
)
chat_history_str = "<chat_history>"
processed_messages = 0
for message in chat.messages:
if message["role"] == "user":
chat_history_str += f"<user_message>{message['content']}</user_message>"
processed_messages += 1
elif message["role"] == "assistant":
chat_history_str += (
f"<assistant_message>{message['content']}</assistant_message>"
)
processed_messages += 1
chat_history_str += "</chat_history>"
# Pass it to the SurfSense Podcaster
await task_logger.log_task_progress(
log_entry,
f"Initializing podcast generation for chat {chat_id}",
{
"stage": "initialize_podcast_generation",
"processed_messages": processed_messages,
"content_length": len(chat_history_str),
},
)
config = {
"configurable": {
"podcast_title": podcast_title or "SurfSense Podcast",
"user_id": str(user_id),
"search_space_id": search_space_id,
"user_prompt": user_prompt,
}
}
# Initialize state with database session and streaming service
initial_state = State(source_content=chat_history_str, db_session=session)
# Run the graph directly
await task_logger.log_task_progress(
log_entry,
f"Running podcast generation graph for chat {chat_id}",
{"stage": "run_podcast_graph"},
)
result = await podcaster_graph.ainvoke(initial_state, config=config)
# Convert podcast transcript entries to serializable format
await task_logger.log_task_progress(
log_entry,
f"Processing podcast transcript for chat {chat_id}",
{
"stage": "process_transcript",
"transcript_entries": len(result["podcast_transcript"]),
},
)
serializable_transcript = []
for entry in result["podcast_transcript"]:
serializable_transcript.append(
{"speaker_id": entry.speaker_id, "dialog": entry.dialog}
)
# Create a new podcast entry
await task_logger.log_task_progress(
log_entry,
f"Creating podcast database entry for chat {chat_id}",
{
"stage": "create_podcast_entry",
"file_location": result.get("final_podcast_file_path"),
},
)
# check if podcast already exists for this chat (re-generation)
existing_podcast = await session.execute(
select(Podcast).filter(Podcast.chat_id == chat_id)
)
existing_podcast = existing_podcast.scalars().first()
if existing_podcast:
existing_podcast.podcast_transcript = serializable_transcript
existing_podcast.file_location = result["final_podcast_file_path"]
existing_podcast.chat_state_version = chat.state_version
await session.commit()
await session.refresh(existing_podcast)
return existing_podcast
else:
podcast = Podcast(
title=f"{podcast_title}",
podcast_transcript=serializable_transcript,
file_location=result["final_podcast_file_path"],
search_space_id=search_space_id,
chat_state_version=chat.state_version,
chat_id=chat.id,
)
# Add to session and commit
session.add(podcast)
await session.commit()
await session.refresh(podcast)
# Log success
await task_logger.log_task_success(
log_entry,
f"Successfully generated podcast for chat {chat_id}",
{
"podcast_id": podcast.id,
"podcast_title": podcast_title,
"transcript_entries": len(serializable_transcript),
"file_location": result.get("final_podcast_file_path"),
"processed_messages": processed_messages,
"content_length": len(chat_history_str),
},
)
return podcast
except ValueError as ve:
# ValueError is already logged above for chat not found
if "not found" not in str(ve):
await task_logger.log_task_failure(
log_entry,
f"Value error during podcast generation for chat {chat_id}",
str(ve),
{"error_type": "ValueError"},
)
raise ve
except SQLAlchemyError as db_error:
await session.rollback()
await task_logger.log_task_failure(
log_entry,
f"Database error during podcast generation for chat {chat_id}",
str(db_error),
{"error_type": "SQLAlchemyError"},
)
raise db_error
except Exception as e:
await session.rollback()
await task_logger.log_task_failure(
log_entry,
f"Unexpected error during podcast generation for chat {chat_id}",
str(e),
{"error_type": type(e).__name__},
)
raise RuntimeError(
f"Failed to generate podcast for chat {chat_id}: {e!s}"
) from e