refactor(backend): two-phase synchronous cloning

This commit is contained in:
CREDO23 2026-01-28 00:17:29 +02:00
parent 0fbf5d5bdd
commit 0c8d1f3fef
8 changed files with 178 additions and 228 deletions

View file

@ -4,7 +4,6 @@ Service layer for public chat sharing and cloning.
import re
import secrets
from datetime import UTC, datetime
from uuid import UUID
from fastapi import HTTPException
@ -241,108 +240,74 @@ async def get_user_default_search_space(
return None
async def clone_public_chat(
async def complete_clone_content(
session: AsyncSession,
share_token: str,
user_id: UUID,
) -> dict:
target_thread: NewChatThread,
source_thread_id: int,
target_search_space_id: int,
) -> int:
"""
Clone a public chat to user's account.
Copy messages and podcasts from source thread to target thread.
Creates a new private thread with all messages and podcasts.
Citations are stripped since they reference the original user's documents.
Sets clone_pending=False and needs_history_bootstrap=True when done.
Returns the number of messages copied.
"""
from app.db import (
ChatVisibility,
NewChatMessage,
from app.db import NewChatMessage
result = await session.execute(
select(NewChatThread)
.options(selectinload(NewChatThread.messages))
.filter(NewChatThread.id == source_thread_id)
)
source_thread = result.scalars().first()
source_thread = await get_thread_by_share_token(session, share_token)
if not source_thread:
await _create_clone_failure_notification(
session, user_id, share_token, "Chat not found or no longer public"
raise ValueError("Source thread not found")
podcast_id_map: dict[int, int] = {}
message_count = 0
for msg in sorted(source_thread.messages, key=lambda m: m.created_at):
new_content = sanitize_content_for_public(msg.content)
if isinstance(new_content, list):
for part in new_content:
if (
isinstance(part, dict)
and part.get("type") == "tool-call"
and part.get("toolName") == "generate_podcast"
):
result_data = part.get("result", {})
old_podcast_id = result_data.get("podcast_id")
if old_podcast_id and old_podcast_id not in podcast_id_map:
new_podcast_id = await _clone_podcast(
session,
old_podcast_id,
target_search_space_id,
target_thread.id,
)
if new_podcast_id:
podcast_id_map[old_podcast_id] = new_podcast_id
if old_podcast_id and old_podcast_id in podcast_id_map:
result_data["podcast_id"] = podcast_id_map[old_podcast_id]
new_message = NewChatMessage(
thread_id=target_thread.id,
role=msg.role,
content=new_content,
author_id=msg.author_id,
created_at=msg.created_at,
)
return {"status": "error", "error": "Chat not found or no longer public"}
session.add(new_message)
message_count += 1
try:
target_search_space_id = await get_user_default_search_space(session, user_id)
target_thread.clone_pending = False
target_thread.needs_history_bootstrap = True
if target_search_space_id is None:
await _create_clone_failure_notification(
session, user_id, share_token, "No search space found"
)
return {"status": "error", "error": "No search space found"}
await session.commit()
new_thread = NewChatThread(
title=source_thread.title,
archived=False,
visibility=ChatVisibility.PRIVATE,
search_space_id=target_search_space_id,
created_by_id=user_id,
public_share_enabled=False,
cloned_from_thread_id=source_thread.id,
cloned_at=datetime.now(UTC),
needs_history_bootstrap=True,
)
session.add(new_thread)
await session.flush()
podcast_id_map: dict[int, int] = {}
for msg in sorted(source_thread.messages, key=lambda m: m.created_at):
new_content = sanitize_content_for_public(msg.content)
if isinstance(new_content, list):
for part in new_content:
if (
isinstance(part, dict)
and part.get("type") == "tool-call"
and part.get("toolName") == "generate_podcast"
):
result = part.get("result", {})
old_podcast_id = result.get("podcast_id")
if old_podcast_id and old_podcast_id not in podcast_id_map:
new_podcast_id = await _clone_podcast(
session,
old_podcast_id,
target_search_space_id,
new_thread.id,
)
if new_podcast_id:
podcast_id_map[old_podcast_id] = new_podcast_id
if old_podcast_id and old_podcast_id in podcast_id_map:
result["podcast_id"] = podcast_id_map[old_podcast_id]
new_message = NewChatMessage(
thread_id=new_thread.id,
role=msg.role,
content=new_content,
author_id=msg.author_id,
created_at=msg.created_at,
)
session.add(new_message)
await session.commit()
await _create_clone_success_notification(
session,
user_id,
new_thread.id,
target_search_space_id,
source_thread.title,
)
return {
"status": "success",
"thread_id": new_thread.id,
"search_space_id": target_search_space_id,
}
except Exception as e:
await session.rollback()
await _create_clone_failure_notification(session, user_id, share_token, str(e))
return {"status": "error", "error": str(e)}
return message_count
async def _clone_podcast(
@ -387,54 +352,6 @@ async def _clone_podcast(
return new_podcast.id
async def _create_clone_success_notification(
session: AsyncSession,
user_id: UUID,
thread_id: int,
search_space_id: int,
original_title: str,
) -> None:
"""Create success notification for clone operation."""
from app.db import Notification
notification = Notification(
user_id=user_id,
search_space_id=search_space_id,
type="chat_cloned",
title="Chat copied successfully",
message=f"Your copy of '{original_title}' is ready",
notification_metadata={
"thread_id": thread_id,
"search_space_id": search_space_id,
},
)
session.add(notification)
await session.commit()
async def _create_clone_failure_notification(
session: AsyncSession,
user_id: UUID,
share_token: str,
error: str,
) -> None:
"""Create failure notification for clone operation."""
from app.db import Notification
notification = Notification(
user_id=user_id,
type="chat_clone_failed",
title="Failed to copy chat",
message="Could not copy the chat. Please try again.",
notification_metadata={
"share_token": share_token,
"error": error,
},
)
session.add(notification)
await session.commit()
async def is_podcast_publicly_accessible(
session: AsyncSession,
podcast_id: int,