refactor(backend): two-phase synchronous cloning

This commit is contained in:
CREDO23 2026-01-28 00:17:29 +02:00
parent 0fbf5d5bdd
commit 0c8d1f3fef
8 changed files with 178 additions and 228 deletions

View file

@ -8,6 +8,7 @@ Adds columns for:
1. Public sharing via tokenized URLs (public_share_token, public_share_enabled) 1. Public sharing via tokenized URLs (public_share_token, public_share_enabled)
2. Clone tracking for audit (cloned_from_thread_id, cloned_at) 2. Clone tracking for audit (cloned_from_thread_id, cloned_at)
3. History bootstrap flag for cloned chats (needs_history_bootstrap) 3. History bootstrap flag for cloned chats (needs_history_bootstrap)
4. Clone pending flag for two-phase clone (clone_pending)
""" """
from collections.abc import Sequence from collections.abc import Sequence
@ -76,6 +77,13 @@ def upgrade() -> None:
""" """
) )
op.execute(
"""
ALTER TABLE new_chat_threads
ADD COLUMN IF NOT EXISTS clone_pending BOOLEAN NOT NULL DEFAULT FALSE;
"""
)
op.execute( op.execute(
""" """
CREATE INDEX IF NOT EXISTS ix_new_chat_threads_cloned_from_thread_id CREATE INDEX IF NOT EXISTS ix_new_chat_threads_cloned_from_thread_id
@ -89,6 +97,7 @@ def downgrade() -> None:
"""Remove public sharing and cloning columns from new_chat_threads.""" """Remove public sharing and cloning columns from new_chat_threads."""
op.execute("DROP INDEX IF EXISTS ix_new_chat_threads_cloned_from_thread_id") op.execute("DROP INDEX IF EXISTS ix_new_chat_threads_cloned_from_thread_id")
op.execute("ALTER TABLE new_chat_threads DROP COLUMN IF EXISTS clone_pending")
op.execute( op.execute(
"ALTER TABLE new_chat_threads DROP COLUMN IF EXISTS needs_history_bootstrap" "ALTER TABLE new_chat_threads DROP COLUMN IF EXISTS needs_history_bootstrap"
) )

View file

@ -65,7 +65,6 @@ celery_app = Celery(
"app.tasks.celery_tasks.schedule_checker_task", "app.tasks.celery_tasks.schedule_checker_task",
"app.tasks.celery_tasks.blocknote_migration_tasks", "app.tasks.celery_tasks.blocknote_migration_tasks",
"app.tasks.celery_tasks.document_reindex_tasks", "app.tasks.celery_tasks.document_reindex_tasks",
"app.tasks.celery_tasks.clone_chat_tasks",
], ],
) )

View file

@ -437,6 +437,13 @@ class NewChatThread(BaseModel, TimestampMixin):
default=False, default=False,
server_default="false", server_default="false",
) )
# Flag indicating content clone is pending (two-phase clone)
clone_pending = Column(
Boolean,
nullable=False,
default=False,
server_default="false",
)
# Relationships # Relationships
search_space = relationship("SearchSpace", back_populates="new_chat_threads") search_space = relationship("SearchSpace", back_populates="new_chat_threads")

View file

@ -37,6 +37,7 @@ from app.db import (
get_async_session, get_async_session,
) )
from app.schemas.new_chat import ( from app.schemas.new_chat import (
CompleteCloneResponse,
NewChatMessageAppend, NewChatMessageAppend,
NewChatMessageRead, NewChatMessageRead,
NewChatRequest, NewChatRequest,
@ -669,6 +670,62 @@ async def delete_thread(
) from None ) from None
@router.post(
    "/threads/{thread_id}/complete-clone",
    response_model=CompleteCloneResponse,
)
async def complete_clone(
    thread_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """
    Complete the cloning process for a thread (second phase of a clone).

    Looks up the thread, validates that the caller may finish the clone, and
    delegates the content copy to ``complete_clone_content``.

    Args:
        thread_id: ID of the target thread whose clone should be completed.
        session: Database session injected by FastAPI.
        user: Authenticated user; must be the creator of the thread.

    Returns:
        CompleteCloneResponse with ``status="success"`` and the value
        returned by ``complete_clone_content`` as ``message_count``.

    Raises:
        HTTPException 404: the thread does not exist.
        HTTPException 403: the thread was not created by ``user``.
        HTTPException 400: the thread has no source thread, or its clone
            has already been completed.
        HTTPException 500: any other error; the session is rolled back first.
    """
    from app.services.public_chat_service import complete_clone_content

    try:
        result = await session.execute(
            select(NewChatThread).filter(NewChatThread.id == thread_id)
        )
        thread = result.scalars().first()

        if not thread:
            raise HTTPException(status_code=404, detail="Thread not found")

        if thread.created_by_id != user.id:
            raise HTTPException(status_code=403, detail="Not authorized")

        # Check for a source thread first: a thread that was never a clone
        # also has clone_pending == False, and reporting "Clone already
        # completed" for it would be misleading.
        if not thread.cloned_from_thread_id:
            raise HTTPException(
                status_code=400, detail="No source thread to clone from"
            )

        if not thread.clone_pending:
            raise HTTPException(status_code=400, detail="Clone already completed")

        # NOTE(review): nothing here serializes concurrent calls for the same
        # thread; two overlapping requests could both pass the checks above.
        # Confirm whether the caller guarantees a single invocation.
        message_count = await complete_clone_content(
            session=session,
            target_thread=thread,
            source_thread_id=thread.cloned_from_thread_id,
            target_search_space_id=thread.search_space_id,
        )

        return CompleteCloneResponse(
            status="success",
            message_count=message_count,
        )
    except HTTPException:
        # Deliberate HTTP errors raised above pass through untouched.
        raise
    except Exception as e:
        await session.rollback()
        raise HTTPException(
            status_code=500,
            detail=f"An unexpected error occurred while completing clone: {e!s}",
        ) from None
@router.patch("/threads/{thread_id}/visibility", response_model=NewChatThreadRead) @router.patch("/threads/{thread_id}/visibility", response_model=NewChatThreadRead)
async def update_thread_visibility( async def update_thread_visibility(
thread_id: int, thread_id: int,

View file

@ -2,17 +2,20 @@
Routes for public chat access (unauthenticated and mixed-auth endpoints). Routes for public chat access (unauthenticated and mixed-auth endpoints).
""" """
from datetime import UTC, datetime
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.db import User, get_async_session from app.db import ChatVisibility, NewChatThread, User, get_async_session
from app.schemas.new_chat import ( from app.schemas.new_chat import (
CloneInitiatedResponse, CloneInitResponse,
PublicChatResponse, PublicChatResponse,
) )
from app.services.public_chat_service import ( from app.services.public_chat_service import (
get_public_chat, get_public_chat,
get_thread_by_share_token, get_thread_by_share_token,
get_user_default_search_space,
) )
from app.users import current_active_user from app.users import current_active_user
@ -33,32 +36,47 @@ async def read_public_chat(
return await get_public_chat(session, share_token) return await get_public_chat(session, share_token)
@router.post("/{share_token}/clone", response_model=CloneInitiatedResponse) @router.post("/{share_token}/clone", response_model=CloneInitResponse)
async def clone_public_chat_endpoint( async def clone_public_chat_endpoint(
share_token: str, share_token: str,
session: AsyncSession = Depends(get_async_session), session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user), user: User = Depends(current_active_user),
): ):
""" """
Clone a public chat to the user's account. Initialize cloning a public chat to the user's account.
Creates an empty thread with clone_pending=True.
Frontend should redirect to the new thread and call /complete-clone.
Requires authentication. Requires authentication.
Initiates a background job to copy the chat.
""" """
from app.tasks.celery_tasks.clone_chat_tasks import clone_public_chat_task source_thread = await get_thread_by_share_token(session, share_token)
thread = await get_thread_by_share_token(session, share_token) if not source_thread:
raise HTTPException(status_code=404, detail="Chat not found or no longer public")
if not thread: target_search_space_id = await get_user_default_search_space(session, user.id)
raise HTTPException(status_code=404, detail="Not found")
task_result = clone_public_chat_task.delay( if target_search_space_id is None:
raise HTTPException(status_code=400, detail="No search space found for user")
new_thread = NewChatThread(
title=source_thread.title,
archived=False,
visibility=ChatVisibility.PRIVATE,
search_space_id=target_search_space_id,
created_by_id=user.id,
public_share_enabled=False,
cloned_from_thread_id=source_thread.id,
cloned_at=datetime.now(UTC),
clone_pending=True,
)
session.add(new_thread)
await session.commit()
await session.refresh(new_thread)
return CloneInitResponse(
thread_id=new_thread.id,
search_space_id=target_search_space_id,
share_token=share_token, share_token=share_token,
user_id=str(user.id),
)
return CloneInitiatedResponse(
status="processing",
task_id=task_result.id,
message="Copying chat to your account...",
) )

View file

@ -97,6 +97,7 @@ class NewChatThreadRead(NewChatThreadBase, IDModel):
created_by_id: UUID | None = None created_by_id: UUID | None = None
public_share_enabled: bool = False public_share_enabled: bool = False
public_share_token: str | None = None public_share_token: str | None = None
clone_pending: bool = False
created_at: datetime created_at: datetime
updated_at: datetime updated_at: datetime
@ -255,7 +256,15 @@ class PublicChatResponse(BaseModel):
messages: list[PublicChatMessage] messages: list[PublicChatMessage]
class CloneInitiatedResponse(BaseModel): class CloneInitResponse(BaseModel):
status: str = "processing"
task_id: str
message: str = "Copying chat to your account..." thread_id: int
search_space_id: int
share_token: str
class CompleteCloneResponse(BaseModel):
    """Response body for the ``complete-clone`` thread endpoint."""

    # Outcome of the operation; the route handler sets this to "success".
    status: str
    # Number of messages copied into the cloned thread.
    message_count: int

View file

@ -4,7 +4,6 @@ Service layer for public chat sharing and cloning.
import re import re
import secrets import secrets
from datetime import UTC, datetime
from uuid import UUID from uuid import UUID
from fastapi import HTTPException from fastapi import HTTPException
@ -241,108 +240,74 @@ async def get_user_default_search_space(
return None return None
async def clone_public_chat( async def complete_clone_content(
session: AsyncSession, session: AsyncSession,
share_token: str, target_thread: NewChatThread,
user_id: UUID, source_thread_id: int,
) -> dict: target_search_space_id: int,
) -> int:
""" """
Clone a public chat to user's account. Copy messages and podcasts from source thread to target thread.
Creates a new private thread with all messages and podcasts. Sets clone_pending=False and needs_history_bootstrap=True when done.
Citations are stripped since they reference the original user's documents. Returns the number of messages copied.
""" """
from app.db import ( from app.db import NewChatMessage
ChatVisibility,
NewChatMessage, result = await session.execute(
select(NewChatThread)
.options(selectinload(NewChatThread.messages))
.filter(NewChatThread.id == source_thread_id)
) )
source_thread = result.scalars().first()
source_thread = await get_thread_by_share_token(session, share_token)
if not source_thread: if not source_thread:
await _create_clone_failure_notification( raise ValueError("Source thread not found")
session, user_id, share_token, "Chat not found or no longer public"
podcast_id_map: dict[int, int] = {}
message_count = 0
for msg in sorted(source_thread.messages, key=lambda m: m.created_at):
new_content = sanitize_content_for_public(msg.content)
if isinstance(new_content, list):
for part in new_content:
if (
isinstance(part, dict)
and part.get("type") == "tool-call"
and part.get("toolName") == "generate_podcast"
):
result_data = part.get("result", {})
old_podcast_id = result_data.get("podcast_id")
if old_podcast_id and old_podcast_id not in podcast_id_map:
new_podcast_id = await _clone_podcast(
session,
old_podcast_id,
target_search_space_id,
target_thread.id,
)
if new_podcast_id:
podcast_id_map[old_podcast_id] = new_podcast_id
if old_podcast_id and old_podcast_id in podcast_id_map:
result_data["podcast_id"] = podcast_id_map[old_podcast_id]
new_message = NewChatMessage(
thread_id=target_thread.id,
role=msg.role,
content=new_content,
author_id=msg.author_id,
created_at=msg.created_at,
) )
return {"status": "error", "error": "Chat not found or no longer public"} session.add(new_message)
message_count += 1
try: target_thread.clone_pending = False
target_search_space_id = await get_user_default_search_space(session, user_id) target_thread.needs_history_bootstrap = True
if target_search_space_id is None: await session.commit()
await _create_clone_failure_notification(
session, user_id, share_token, "No search space found"
)
return {"status": "error", "error": "No search space found"}
new_thread = NewChatThread( return message_count
title=source_thread.title,
archived=False,
visibility=ChatVisibility.PRIVATE,
search_space_id=target_search_space_id,
created_by_id=user_id,
public_share_enabled=False,
cloned_from_thread_id=source_thread.id,
cloned_at=datetime.now(UTC),
needs_history_bootstrap=True,
)
session.add(new_thread)
await session.flush()
podcast_id_map: dict[int, int] = {}
for msg in sorted(source_thread.messages, key=lambda m: m.created_at):
new_content = sanitize_content_for_public(msg.content)
if isinstance(new_content, list):
for part in new_content:
if (
isinstance(part, dict)
and part.get("type") == "tool-call"
and part.get("toolName") == "generate_podcast"
):
result = part.get("result", {})
old_podcast_id = result.get("podcast_id")
if old_podcast_id and old_podcast_id not in podcast_id_map:
new_podcast_id = await _clone_podcast(
session,
old_podcast_id,
target_search_space_id,
new_thread.id,
)
if new_podcast_id:
podcast_id_map[old_podcast_id] = new_podcast_id
if old_podcast_id and old_podcast_id in podcast_id_map:
result["podcast_id"] = podcast_id_map[old_podcast_id]
new_message = NewChatMessage(
thread_id=new_thread.id,
role=msg.role,
content=new_content,
author_id=msg.author_id,
created_at=msg.created_at,
)
session.add(new_message)
await session.commit()
await _create_clone_success_notification(
session,
user_id,
new_thread.id,
target_search_space_id,
source_thread.title,
)
return {
"status": "success",
"thread_id": new_thread.id,
"search_space_id": target_search_space_id,
}
except Exception as e:
await session.rollback()
await _create_clone_failure_notification(session, user_id, share_token, str(e))
return {"status": "error", "error": str(e)}
async def _clone_podcast( async def _clone_podcast(
@ -387,54 +352,6 @@ async def _clone_podcast(
return new_podcast.id return new_podcast.id
async def _create_clone_success_notification(
    session: AsyncSession,
    user_id: UUID,
    thread_id: int,
    search_space_id: int,
    original_title: str,
) -> None:
    """
    Persist a "chat_cloned" notification for a successful clone.

    Args:
        session: Database session used to add and commit the notification.
        user_id: User the notification is addressed to.
        thread_id: ID of the newly created thread, stored in the metadata.
        search_space_id: Search space of the new thread; stored both on the
            notification and in its metadata.
        original_title: Title of the source chat, embedded in the message.
    """
    from app.db import Notification

    # Metadata carries the identifiers of the newly created thread.
    metadata = {
        "thread_id": thread_id,
        "search_space_id": search_space_id,
    }
    ready_message = f"Your copy of '{original_title}' is ready"

    session.add(
        Notification(
            user_id=user_id,
            search_space_id=search_space_id,
            type="chat_cloned",
            title="Chat copied successfully",
            message=ready_message,
            notification_metadata=metadata,
        )
    )
    await session.commit()
async def _create_clone_failure_notification(
    session: AsyncSession,
    user_id: UUID,
    share_token: str,
    error: str,
) -> None:
    """
    Persist a "chat_clone_failed" notification for a failed clone.

    Args:
        session: Database session used to add and commit the notification.
        user_id: User the notification is addressed to.
        share_token: Share token of the chat that could not be cloned.
        error: Error description; stored in the metadata only, while the
            visible message stays generic.
    """
    from app.db import Notification

    failure_details = {
        "share_token": share_token,
        "error": error,
    }

    failure_notice = Notification(
        user_id=user_id,
        type="chat_clone_failed",
        title="Failed to copy chat",
        message="Could not copy the chat. Please try again.",
        notification_metadata=failure_details,
    )
    session.add(failure_notice)
    await session.commit()
async def is_podcast_publicly_accessible( async def is_podcast_publicly_accessible(
session: AsyncSession, session: AsyncSession,
podcast_id: int, podcast_id: int,

View file

@ -1,66 +0,0 @@
"""Celery tasks for cloning public chats."""
import asyncio
import logging
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.pool import NullPool
from app.celery_app import celery_app
from app.config import config
logger = logging.getLogger(__name__)
def get_celery_session_maker():
    """
    Build an async session factory for use inside Celery tasks.

    A new engine is created on every call, bound to ``config.DATABASE_URL``
    and configured with ``NullPool`` so connections are not pooled.

    Returns:
        An ``async_sessionmaker`` bound to the new engine, with
        ``expire_on_commit`` disabled.
    """
    task_engine = create_async_engine(
        config.DATABASE_URL,
        echo=False,
        poolclass=NullPool,
    )
    session_factory = async_sessionmaker(task_engine, expire_on_commit=False)
    return session_factory
@celery_app.task(name="clone_public_chat", bind=True)
def clone_public_chat_task(
    self,
    share_token: str,
    user_id: str,
) -> dict:
    """
    Celery task to clone a public chat to user's account.

    Runs the async clone operation to completion on an event loop created
    for this invocation.

    Args:
        share_token: Public share token of the chat to clone
        user_id: UUID string of the user cloning the chat

    Returns:
        The dict returned by ``_run_clone`` on success, or
        ``{"status": "error", "error": <message>}`` if it raised.
    """
    # A fresh loop is created here and closed in ``finally``.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)

    try:
        return loop.run_until_complete(_run_clone(share_token, user_id))
    except Exception as e:
        # logger.exception records the traceback, which a plain
        # logger.error call with only the message would drop.
        logger.exception("Error cloning public chat: %s", e)
        return {"status": "error", "error": str(e)}
    finally:
        # Detach the loop before closing it so no closed loop stays current.
        asyncio.set_event_loop(None)
        loop.close()
async def _run_clone(share_token: str, user_id: str) -> dict:
    """
    Execute the clone inside a session obtained from the Celery session maker.

    Args:
        share_token: Public share token of the chat to clone.
        user_id: UUID string of the cloning user; converted to ``UUID``
            before being passed on.

    Returns:
        Whatever ``clone_public_chat`` returns.
    """
    from uuid import UUID

    from app.services.public_chat_service import clone_public_chat

    session_factory = get_celery_session_maker()
    async with session_factory() as db_session:
        clone_result = await clone_public_chat(
            session=db_session,
            share_token=share_token,
            user_id=UUID(user_id),
        )
        return clone_result