diff --git a/surfsense_backend/alembic/versions/81_add_public_chat_features.py b/surfsense_backend/alembic/versions/81_add_public_chat_features.py index ab73b06bb..8d7e95df7 100644 --- a/surfsense_backend/alembic/versions/81_add_public_chat_features.py +++ b/surfsense_backend/alembic/versions/81_add_public_chat_features.py @@ -8,6 +8,7 @@ Adds columns for: 1. Public sharing via tokenized URLs (public_share_token, public_share_enabled) 2. Clone tracking for audit (cloned_from_thread_id, cloned_at) 3. History bootstrap flag for cloned chats (needs_history_bootstrap) +4. Clone pending flag for two-phase clone (clone_pending) """ from collections.abc import Sequence @@ -76,6 +77,13 @@ def upgrade() -> None: """ ) + op.execute( + """ + ALTER TABLE new_chat_threads + ADD COLUMN IF NOT EXISTS clone_pending BOOLEAN NOT NULL DEFAULT FALSE; + """ + ) + op.execute( """ CREATE INDEX IF NOT EXISTS ix_new_chat_threads_cloned_from_thread_id @@ -89,6 +97,7 @@ def downgrade() -> None: """Remove public sharing and cloning columns from new_chat_threads.""" op.execute("DROP INDEX IF EXISTS ix_new_chat_threads_cloned_from_thread_id") + op.execute("ALTER TABLE new_chat_threads DROP COLUMN IF EXISTS clone_pending") op.execute( "ALTER TABLE new_chat_threads DROP COLUMN IF EXISTS needs_history_bootstrap" ) diff --git a/surfsense_backend/app/celery_app.py b/surfsense_backend/app/celery_app.py index b4869d23f..f7bea8cc3 100644 --- a/surfsense_backend/app/celery_app.py +++ b/surfsense_backend/app/celery_app.py @@ -65,7 +65,6 @@ celery_app = Celery( "app.tasks.celery_tasks.schedule_checker_task", "app.tasks.celery_tasks.blocknote_migration_tasks", "app.tasks.celery_tasks.document_reindex_tasks", - "app.tasks.celery_tasks.clone_chat_tasks", ], ) diff --git a/surfsense_backend/app/db.py b/surfsense_backend/app/db.py index 5a74cddeb..8c6942e44 100644 --- a/surfsense_backend/app/db.py +++ b/surfsense_backend/app/db.py @@ -437,6 +437,13 @@ class NewChatThread(BaseModel, TimestampMixin): default=False, server_default="false", ) + # Flag indicating content clone is pending (two-phase clone) + clone_pending = Column( + Boolean, + nullable=False, + default=False, + server_default="false", + ) # Relationships search_space = relationship("SearchSpace", back_populates="new_chat_threads") diff --git a/surfsense_backend/app/routes/new_chat_routes.py b/surfsense_backend/app/routes/new_chat_routes.py index db371a81c..541e25a75 100644 --- a/surfsense_backend/app/routes/new_chat_routes.py +++ b/surfsense_backend/app/routes/new_chat_routes.py @@ -37,6 +37,7 @@ from app.db import ( get_async_session, ) from app.schemas.new_chat import ( + CompleteCloneResponse, NewChatMessageAppend, NewChatMessageRead, NewChatRequest, @@ -669,6 +670,62 @@ async def delete_thread( ) from None +@router.post("/threads/{thread_id}/complete-clone", response_model=CompleteCloneResponse) +async def complete_clone( + thread_id: int, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """ + Complete the cloning process for a thread. + + Copies messages and podcasts from the source thread. + Sets clone_pending=False and needs_history_bootstrap=True when done. + + Requires authentication and ownership of the thread. + """ + from app.services.public_chat_service import complete_clone_content + + try: + result = await session.execute( + select(NewChatThread).filter(NewChatThread.id == thread_id) + ) + thread = result.scalars().first() + + if not thread: + raise HTTPException(status_code=404, detail="Thread not found") + + if thread.created_by_id != user.id: + raise HTTPException(status_code=403, detail="Not authorized") + + if not thread.clone_pending: + raise HTTPException(status_code=400, detail="Clone already completed") + + if not thread.cloned_from_thread_id: + raise HTTPException(status_code=400, detail="No source thread to clone from") + + message_count = await complete_clone_content( + session=session, + target_thread=thread, + source_thread_id=thread.cloned_from_thread_id, + target_search_space_id=thread.search_space_id, + ) + + return CompleteCloneResponse( + status="success", + message_count=message_count, + ) + + except HTTPException: + raise + except Exception as e: + await session.rollback() + raise HTTPException( + status_code=500, + detail=f"An unexpected error occurred while completing clone: {e!s}", + ) from None + + @router.patch("/threads/{thread_id}/visibility", response_model=NewChatThreadRead) async def update_thread_visibility( thread_id: int, diff --git a/surfsense_backend/app/routes/public_chat_routes.py b/surfsense_backend/app/routes/public_chat_routes.py index ca70e911a..8b4f42559 100644 --- a/surfsense_backend/app/routes/public_chat_routes.py +++ b/surfsense_backend/app/routes/public_chat_routes.py @@ -2,17 +2,20 @@ Routes for public chat access (unauthenticated and mixed-auth endpoints). """ +from datetime import UTC, datetime + from fastapi import APIRouter, Depends, HTTPException from sqlalchemy.ext.asyncio import AsyncSession -from app.db import User, get_async_session +from app.db import ChatVisibility, NewChatThread, User, get_async_session from app.schemas.new_chat import ( - CloneInitiatedResponse, + CloneInitResponse, PublicChatResponse, ) from app.services.public_chat_service import ( get_public_chat, get_thread_by_share_token, + get_user_default_search_space, ) from app.users import current_active_user @@ -33,32 +36,47 @@ async def read_public_chat( return await get_public_chat(session, share_token) -@router.post("/{share_token}/clone", response_model=CloneInitiatedResponse) +@router.post("/{share_token}/clone", response_model=CloneInitResponse) async def clone_public_chat_endpoint( share_token: str, session: AsyncSession = Depends(get_async_session), user: User = Depends(current_active_user), ): """ - Clone a public chat to the user's account. + Initialize cloning a public chat to the user's account. + + Creates an empty thread with clone_pending=True. + Frontend should redirect to the new thread and call /complete-clone. Requires authentication. - Initiates a background job to copy the chat. """ - from app.tasks.celery_tasks.clone_chat_tasks import clone_public_chat_task + source_thread = await get_thread_by_share_token(session, share_token) - thread = await get_thread_by_share_token(session, share_token) + if not source_thread: + raise HTTPException(status_code=404, detail="Chat not found or no longer public") - if not thread: - raise HTTPException(status_code=404, detail="Not found") + target_search_space_id = await get_user_default_search_space(session, user.id) - task_result = clone_public_chat_task.delay( + if target_search_space_id is None: + raise HTTPException(status_code=400, detail="No search space found for user") + + new_thread = NewChatThread( + title=source_thread.title, + archived=False, + visibility=ChatVisibility.PRIVATE, + search_space_id=target_search_space_id, + created_by_id=user.id, + public_share_enabled=False, + cloned_from_thread_id=source_thread.id, + cloned_at=datetime.now(UTC), + clone_pending=True, + ) + session.add(new_thread) + await session.commit() + await session.refresh(new_thread) + + return CloneInitResponse( + thread_id=new_thread.id, + search_space_id=target_search_space_id, share_token=share_token, - user_id=str(user.id), - ) - - return CloneInitiatedResponse( - status="processing", - task_id=task_result.id, - message="Copying chat to your account...", ) diff --git a/surfsense_backend/app/schemas/new_chat.py b/surfsense_backend/app/schemas/new_chat.py index 5e9d44beb..b420b1b91 100644 --- a/surfsense_backend/app/schemas/new_chat.py +++ b/surfsense_backend/app/schemas/new_chat.py @@ -97,6 +97,7 @@ class NewChatThreadRead(NewChatThreadBase, IDModel): created_by_id: UUID | None = None public_share_enabled: bool = False public_share_token: str | None = None + clone_pending: bool = False created_at: datetime updated_at: datetime @@ -255,7 +256,15 @@ class PublicChatResponse(BaseModel): messages: list[PublicChatMessage] -class CloneInitiatedResponse(BaseModel): - status: str = "processing" - task_id: str - message: str = "Copying chat to your account..." +class CloneInitResponse(BaseModel): + + + thread_id: int + search_space_id: int + share_token: str + + +class CompleteCloneResponse(BaseModel): + + status: str + message_count: int diff --git a/surfsense_backend/app/services/public_chat_service.py b/surfsense_backend/app/services/public_chat_service.py index 1dcc97a11..79618974f 100644 --- a/surfsense_backend/app/services/public_chat_service.py +++ b/surfsense_backend/app/services/public_chat_service.py @@ -4,7 +4,6 @@ Service layer for public chat sharing and cloning. import re import secrets -from datetime import UTC, datetime from uuid import UUID from fastapi import HTTPException @@ -241,108 +240,74 @@ async def get_user_default_search_space( return None -async def clone_public_chat( +async def complete_clone_content( session: AsyncSession, - share_token: str, - user_id: UUID, -) -> dict: + target_thread: NewChatThread, + source_thread_id: int, + target_search_space_id: int, +) -> int: """ - Clone a public chat to user's account. + Copy messages and podcasts from source thread to target thread. - Creates a new private thread with all messages and podcasts. - Citations are stripped since they reference the original user's documents. + Sets clone_pending=False and needs_history_bootstrap=True when done. + Returns the number of messages copied. """ - from app.db import ( - ChatVisibility, - NewChatMessage, + from app.db import NewChatMessage + + result = await session.execute( + select(NewChatThread) + .options(selectinload(NewChatThread.messages)) + .filter(NewChatThread.id == source_thread_id) ) + source_thread = result.scalars().first() - source_thread = await get_thread_by_share_token(session, share_token) if not source_thread: - await _create_clone_failure_notification( - session, user_id, share_token, "Chat not found or no longer public" + raise ValueError("Source thread not found") + + podcast_id_map: dict[int, int] = {} + message_count = 0 + + for msg in sorted(source_thread.messages, key=lambda m: m.created_at): + new_content = sanitize_content_for_public(msg.content) + + if isinstance(new_content, list): + for part in new_content: + if ( + isinstance(part, dict) + and part.get("type") == "tool-call" + and part.get("toolName") == "generate_podcast" + ): + result_data = part.get("result", {}) + old_podcast_id = result_data.get("podcast_id") + if old_podcast_id and old_podcast_id not in podcast_id_map: + new_podcast_id = await _clone_podcast( + session, + old_podcast_id, + target_search_space_id, + target_thread.id, + ) + if new_podcast_id: + podcast_id_map[old_podcast_id] = new_podcast_id + + if old_podcast_id and old_podcast_id in podcast_id_map: + result_data["podcast_id"] = podcast_id_map[old_podcast_id] + + new_message = NewChatMessage( + thread_id=target_thread.id, + role=msg.role, + content=new_content, + author_id=msg.author_id, + created_at=msg.created_at, ) - return {"status": "error", "error": "Chat not found or no longer public"} + session.add(new_message) + message_count += 1 - try: - target_search_space_id = await get_user_default_search_space(session, user_id) + target_thread.clone_pending = False + target_thread.needs_history_bootstrap = True - if target_search_space_id is None: - await _create_clone_failure_notification( - session, user_id, share_token, "No search space found" - ) - return {"status": "error", "error": "No search space found"} + await session.commit() - new_thread = NewChatThread( - title=source_thread.title, - archived=False, - visibility=ChatVisibility.PRIVATE, - search_space_id=target_search_space_id, - created_by_id=user_id, - public_share_enabled=False, - cloned_from_thread_id=source_thread.id, - cloned_at=datetime.now(UTC), - needs_history_bootstrap=True, - ) - session.add(new_thread) - await session.flush() - - podcast_id_map: dict[int, int] = {} - - for msg in sorted(source_thread.messages, key=lambda m: m.created_at): - new_content = sanitize_content_for_public(msg.content) - - if isinstance(new_content, list): - for part in new_content: - if ( - isinstance(part, dict) - and part.get("type") == "tool-call" - and part.get("toolName") == "generate_podcast" - ): - result = part.get("result", {}) - old_podcast_id = result.get("podcast_id") - if old_podcast_id and old_podcast_id not in podcast_id_map: - new_podcast_id = await _clone_podcast( - session, - old_podcast_id, - target_search_space_id, - new_thread.id, - ) - if new_podcast_id: - podcast_id_map[old_podcast_id] = new_podcast_id - - if old_podcast_id and old_podcast_id in podcast_id_map: - result["podcast_id"] = podcast_id_map[old_podcast_id] - - new_message = NewChatMessage( - thread_id=new_thread.id, - role=msg.role, - content=new_content, - author_id=msg.author_id, - created_at=msg.created_at, - ) - session.add(new_message) - - await session.commit() - - await _create_clone_success_notification( - session, - user_id, - new_thread.id, - target_search_space_id, - source_thread.title, - ) - - return { - "status": "success", - "thread_id": new_thread.id, - "search_space_id": target_search_space_id, - } - - except Exception as e: - await session.rollback() - await _create_clone_failure_notification(session, user_id, share_token, str(e)) - return {"status": "error", "error": str(e)} + return message_count async def _clone_podcast( @@ -387,54 +352,6 @@ async def _clone_podcast( return new_podcast.id -async def _create_clone_success_notification( - session: AsyncSession, - user_id: UUID, - thread_id: int, - search_space_id: int, - original_title: str, -) -> None: - """Create success notification for clone operation.""" - from app.db import Notification - - notification = Notification( - user_id=user_id, - search_space_id=search_space_id, - type="chat_cloned", - title="Chat copied successfully", - message=f"Your copy of '{original_title}' is ready", - notification_metadata={ - "thread_id": thread_id, - "search_space_id": search_space_id, - }, - ) - session.add(notification) - await session.commit() - - -async def _create_clone_failure_notification( - session: AsyncSession, - user_id: UUID, - share_token: str, - error: str, -) -> None: - """Create failure notification for clone operation.""" - from app.db import Notification - - notification = Notification( - user_id=user_id, - type="chat_clone_failed", - title="Failed to copy chat", - message="Could not copy the chat. Please try again.", - notification_metadata={ - "share_token": share_token, - "error": error, - }, - ) - session.add(notification) - await session.commit() - - async def is_podcast_publicly_accessible( session: AsyncSession, podcast_id: int, diff --git a/surfsense_backend/app/tasks/celery_tasks/clone_chat_tasks.py b/surfsense_backend/app/tasks/celery_tasks/clone_chat_tasks.py deleted file mode 100644 index b846ee555..000000000 --- a/surfsense_backend/app/tasks/celery_tasks/clone_chat_tasks.py +++ /dev/null @@ -1,66 +0,0 @@ -"""Celery tasks for cloning public chats.""" - -import asyncio -import logging - -from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine -from sqlalchemy.pool import NullPool - -from app.celery_app import celery_app -from app.config import config - -logger = logging.getLogger(__name__) - - -def get_celery_session_maker(): - """Create a new async session maker for Celery tasks.""" - engine = create_async_engine( - config.DATABASE_URL, - poolclass=NullPool, - echo=False, - ) - return async_sessionmaker(engine, expire_on_commit=False) - - -@celery_app.task(name="clone_public_chat", bind=True) -def clone_public_chat_task( - self, - share_token: str, - user_id: str, -) -> dict: - """ - Celery task to clone a public chat to user's account. - - Args: - share_token: Public share token of the chat to clone - user_id: UUID string of the user cloning the chat - - Returns: - dict with status and thread_id on success, or error info on failure - """ - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - result = loop.run_until_complete(_run_clone(share_token, user_id)) - return result - except Exception as e: - logger.error(f"Error cloning public chat: {e!s}") - return {"status": "error", "error": str(e)} - finally: - asyncio.set_event_loop(None) - loop.close() - - -async def _run_clone(share_token: str, user_id: str) -> dict: - """Run the clone operation with a fresh database session.""" - from uuid import UUID - - from app.services.public_chat_service import clone_public_chat - - async with get_celery_session_maker()() as session: - return await clone_public_chat( - session=session, - share_token=share_token, - user_id=UUID(user_id), - )