From a45412abad6bbfe52465382d114fd1576a5dbb37 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Thu, 29 Jan 2026 20:24:50 +0200 Subject: [PATCH] refactor: rewrite public_chat_service for immutable snapshots --- .../app/services/public_chat_service.py | 579 ++++++++++++------ 1 file changed, 390 insertions(+), 189 deletions(-) diff --git a/surfsense_backend/app/services/public_chat_service.py b/surfsense_backend/app/services/public_chat_service.py index a5b8c9ffe..e58329cf4 100644 --- a/surfsense_backend/app/services/public_chat_service.py +++ b/surfsense_backend/app/services/public_chat_service.py @@ -1,17 +1,36 @@ """ -Service layer for public chat sharing and cloning. +Service layer for public chat sharing via immutable snapshots. + +Key concepts: +- Snapshots are frozen copies of a chat at a specific point in time +- Content hash enables deduplication (same content = same URL) +- Podcasts are embedded in snapshot_data for self-contained public views +- Single-phase clone reads directly from snapshot_data """ +import contextlib +import hashlib +import json import re import secrets +from datetime import UTC, datetime from uuid import UUID from fastapi import HTTPException -from sqlalchemy import select +from sqlalchemy import delete, select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload -from app.db import NewChatThread, User +from app.db import ( + ChatVisibility, + NewChatMessage, + NewChatThread, + Podcast, + PodcastStatus, + PublicChatSnapshot, + SearchSpaceMembership, + User, +) UI_TOOLS = { "display_image", @@ -100,20 +119,241 @@ async def get_author_display( return user_cache[author_id] -async def toggle_public_share( +# ============================================================================= +# Content Hashing +# ============================================================================= + + +def compute_content_hash(messages: list[dict]) -> str: + """ + Compute SHA-256 hash of message content for deduplication. + + The hash is based on message IDs and content, ensuring that: + - Same messages = same hash = same URL (deduplication) + - Any change = different hash = new URL + """ + # Sort by message ID to ensure consistent ordering + sorted_messages = sorted(messages, key=lambda m: m.get("id", 0)) + + # Create normalized representation + normalized = [] + for msg in sorted_messages: + normalized.append( + { + "id": msg.get("id"), + "role": msg.get("role"), + "content": msg.get("content"), + } + ) + + content_str = json.dumps(normalized, sort_keys=True, separators=(",", ":")) + return hashlib.sha256(content_str.encode()).hexdigest() + + +# ============================================================================= +# Snapshot Creation +# ============================================================================= + + +async def create_snapshot( session: AsyncSession, thread_id: int, - enabled: bool, user: User, base_url: str, ) -> dict: """ - Enable or disable public sharing for a thread. + Create a public snapshot of a chat thread. - Only the thread owner can toggle public sharing. - When enabling, generates a new token if one doesn't exist. - When disabling, keeps the token for potential re-enable. + Returns existing snapshot if content unchanged (same hash). + Returns new snapshot with unique URL if content changed. """ + result = await session.execute( + select(NewChatThread) + .options(selectinload(NewChatThread.messages)) + .filter(NewChatThread.id == thread_id) + ) + thread = result.scalars().first() + + if not thread: + raise HTTPException(status_code=404, detail="Thread not found") + + if thread.created_by_id != user.id: + raise HTTPException( + status_code=403, + detail="Only the creator of this chat can create public snapshots", + ) + + # Build snapshot data + user_cache: dict[UUID, dict] = {} + messages_data = [] + message_ids = [] + podcasts_data = [] + podcast_ids_seen: set[int] = set() + + for msg in sorted(thread.messages, key=lambda m: m.created_at): + author = await get_author_display(session, msg.author_id, user_cache) + sanitized_content = sanitize_content_for_public(msg.content) + + # Extract podcast references (keep original podcast_id unchanged) + if isinstance(sanitized_content, list): + for part in sanitized_content: + if ( + isinstance(part, dict) + and part.get("type") == "tool-call" + and part.get("toolName") == "generate_podcast" + ): + result_data = part.get("result", {}) + podcast_id = result_data.get("podcast_id") + if podcast_id and podcast_id not in podcast_ids_seen: + + podcast_info = await _get_podcast_for_snapshot( + session, podcast_id + ) + if podcast_info: + podcasts_data.append(podcast_info) + podcast_ids_seen.add(podcast_id) + + + messages_data.append( + { + "id": msg.id, + "role": msg.role.value if hasattr(msg.role, "value") else str(msg.role), + "content": sanitized_content, + "author": author, + "author_id": str(msg.author_id) if msg.author_id else None, + "created_at": msg.created_at.isoformat() if msg.created_at else None, + } + ) + message_ids.append(msg.id) + + if not messages_data: + raise HTTPException(status_code=400, detail="Cannot share an empty chat") + + # Compute content hash for deduplication + content_hash = compute_content_hash(messages_data) + + # Check if identical snapshot already exists + existing_result = await session.execute( + select(PublicChatSnapshot).filter( + PublicChatSnapshot.thread_id == thread_id, + PublicChatSnapshot.content_hash == content_hash, + ) + ) + existing = existing_result.scalars().first() + + if existing: + # Return existing snapshot URL + return { + "snapshot_id": existing.id, + "share_token": existing.share_token, + "public_url": f"{base_url}/public/{existing.share_token}", + "is_new": False, + } + + # Get thread author info + thread_author = await get_author_display(session, thread.created_by_id, user_cache) + + # Create snapshot data + snapshot_data = { + "title": thread.title, + "snapshot_at": datetime.now(UTC).isoformat(), + "author": thread_author, + "messages": messages_data, + "podcasts": podcasts_data, + } + + # Create new snapshot + share_token = secrets.token_urlsafe(48) + snapshot = PublicChatSnapshot( + thread_id=thread_id, + share_token=share_token, + content_hash=content_hash, + snapshot_data=snapshot_data, + message_ids=message_ids, + created_by_user_id=user.id, + ) + session.add(snapshot) + await session.commit() + await session.refresh(snapshot) + + return { + "snapshot_id": snapshot.id, + "share_token": snapshot.share_token, + "public_url": f"{base_url}/public/{snapshot.share_token}", + "is_new": True, + } + + +async def _get_podcast_for_snapshot( + session: AsyncSession, + podcast_id: int, +) -> dict | None: + """Get podcast info for embedding in snapshot_data.""" + result = await session.execute(select(Podcast).filter(Podcast.id == podcast_id)) + podcast = result.scalars().first() + + if not podcast or podcast.status != PodcastStatus.READY: + return None + + return { + "original_id": podcast.id, + "title": podcast.title, + "transcript": podcast.podcast_transcript, + "file_path": podcast.file_location, + } + + +# ============================================================================= +# Snapshot Retrieval +# ============================================================================= + + +async def get_snapshot_by_token( + session: AsyncSession, + share_token: str, +) -> PublicChatSnapshot | None: + """Get a snapshot by its share token.""" + result = await session.execute( + select(PublicChatSnapshot).filter( + PublicChatSnapshot.share_token == share_token + ) + ) + return result.scalars().first() + + +async def get_public_chat( + session: AsyncSession, + share_token: str, +) -> dict: + """ + Get public chat data from a snapshot. + + Returns sanitized content suitable for public viewing. + """ + snapshot = await get_snapshot_by_token(session, share_token) + + if not snapshot: + raise HTTPException(status_code=404, detail="Not found") + + data = snapshot.snapshot_data + + return { + "thread": { + "title": data.get("title", "Untitled"), + "created_at": data.get("snapshot_at"), + }, + "messages": data.get("messages", []), + } + + +async def list_snapshots_for_thread( + session: AsyncSession, + thread_id: int, + user: User, + base_url: str, +) -> list[dict]: + """List all public snapshots for a thread.""" + # Verify ownership result = await session.execute( select(NewChatThread).filter(NewChatThread.id == thread_id) ) @@ -125,92 +365,99 @@ async def toggle_public_share( if thread.created_by_id != user.id: raise HTTPException( status_code=403, - detail="Only the creator of this chat can manage public sharing", + detail="Only the creator can view snapshots", ) - if enabled and not thread.public_share_token: - thread.public_share_token = secrets.token_urlsafe(48) + # Get snapshots + result = await session.execute( + select(PublicChatSnapshot) + .filter(PublicChatSnapshot.thread_id == thread_id) + .order_by(PublicChatSnapshot.created_at.desc()) + ) + snapshots = result.scalars().all() - thread.public_share_enabled = enabled - - await session.commit() - await session.refresh(thread) - - if enabled: - return { - "enabled": True, - "public_url": f"{base_url}/public/{thread.public_share_token}", - "share_token": thread.public_share_token, + return [ + { + "id": s.id, + "share_token": s.share_token, + "public_url": f"{base_url}/public/{s.share_token}", + "created_at": s.created_at.isoformat() if s.created_at else None, + "message_count": len(s.message_ids) if s.message_ids else 0, } - - return { - "enabled": False, - "public_url": None, - "share_token": None, - } + for s in snapshots + ] -async def get_public_chat( +# ============================================================================= +# Snapshot Deletion +# ============================================================================= + + +async def delete_snapshot( session: AsyncSession, - share_token: str, -) -> dict: - """ - Get a public chat by share token. - - Returns sanitized content suitable for public viewing. - """ + thread_id: int, + snapshot_id: int, + user: User, +) -> bool: + """Delete a specific snapshot. Only thread owner can delete.""" + # Get snapshot with thread result = await session.execute( - select(NewChatThread) - .options(selectinload(NewChatThread.messages)) + select(PublicChatSnapshot) + .options(selectinload(PublicChatSnapshot.thread)) .filter( - NewChatThread.public_share_token == share_token, - NewChatThread.public_share_enabled.is_(True), + PublicChatSnapshot.id == snapshot_id, + PublicChatSnapshot.thread_id == thread_id, ) ) - thread = result.scalars().first() + snapshot = result.scalars().first() - if not thread: - raise HTTPException(status_code=404, detail="Not found") + if not snapshot: + raise HTTPException(status_code=404, detail="Snapshot not found") - user_cache: dict[UUID, dict] = {} - - messages = [] - for msg in sorted(thread.messages, key=lambda m: m.created_at): - author = await get_author_display(session, msg.author_id, user_cache) - sanitized_content = sanitize_content_for_public(msg.content) - - messages.append( - { - "role": msg.role, - "content": sanitized_content, - "author": author, - "created_at": msg.created_at, - } + if snapshot.thread.created_by_id != user.id: + raise HTTPException( + status_code=403, + detail="Only the creator can delete snapshots", ) - return { - "thread": { - "title": thread.title, - "created_at": thread.created_at, - }, - "messages": messages, - } + await session.delete(snapshot) + await session.commit() + return True -async def get_thread_by_share_token( +async def delete_affected_snapshots( session: AsyncSession, - share_token: str, -) -> NewChatThread | None: - """Get a thread by its public share token if sharing is enabled.""" + thread_id: int, + message_ids: list[int], +) -> int: + """ + Delete snapshots that contain any of the given message IDs. + + Called when messages are edited/deleted/regenerated. + + Returns the number of deleted snapshots. + """ + if not message_ids: + return 0 + + # Use raw SQL for array overlap query + # The && operator checks if arrays have any elements in common result = await session.execute( - select(NewChatThread) - .options(selectinload(NewChatThread.messages)) - .filter( - NewChatThread.public_share_token == share_token, - NewChatThread.public_share_enabled.is_(True), - ) + delete(PublicChatSnapshot) + .where(PublicChatSnapshot.thread_id == thread_id) + .where(PublicChatSnapshot.message_ids.overlap(message_ids)) + .returning(PublicChatSnapshot.id) ) - return result.scalars().first() + + deleted_ids = result.scalars().all() + await session.commit() + + return len(deleted_ids) + + +# ============================================================================= +# Cloning from Snapshot +# ============================================================================= async def get_user_default_search_space( @@ -222,8 +469,6 @@ async def get_user_default_search_space( Returns the first search space where user is owner, or None if not found. """ - from app.db import SearchSpaceMembership - result = await session.execute( select(SearchSpaceMembership) .filter( @@ -240,140 +485,96 @@ async def get_user_default_search_space( return None -async def complete_clone_content( +async def clone_from_snapshot( session: AsyncSession, - target_thread: NewChatThread, - source_thread_id: int, - target_search_space_id: int, -) -> int: + share_token: str, + user: User, +) -> dict: """ Copy messages and podcasts from source thread to target thread. - Sets clone_pending=False and needs_history_bootstrap=True when done. - Returns the number of messages copied. + Creates thread and copies messages from snapshot_data. + Returns the new thread info. """ - from app.db import NewChatMessage + # Get snapshot + snapshot = await get_snapshot_by_token(session, share_token) - result = await session.execute( - select(NewChatThread) - .options(selectinload(NewChatThread.messages)) - .filter(NewChatThread.id == source_thread_id) + if not snapshot: + raise HTTPException( + status_code=404, detail="Chat not found or no longer public" + ) + + # Get user's default search space + target_search_space_id = await get_user_default_search_space(session, user.id) + + if target_search_space_id is None: + raise HTTPException(status_code=400, detail="No search space found for user") + + # Get snapshot data + data = snapshot.snapshot_data + messages_data = data.get("messages", []) + + # Create new thread + new_thread = NewChatThread( + title=data.get("title", "Cloned Chat"), + archived=False, + visibility=ChatVisibility.PRIVATE, + search_space_id=target_search_space_id, + created_by_id=user.id, + cloned_from_thread_id=snapshot.thread_id, + cloned_at=datetime.now(UTC), + needs_history_bootstrap=True, ) - source_thread = result.scalars().first() + session.add(new_thread) + await session.flush() # Get thread ID - if not source_thread: - raise ValueError("Source thread not found") - - podcast_id_map: dict[int, int] = {} - message_count = 0 - - for msg in sorted(source_thread.messages, key=lambda m: m.created_at): - new_content = sanitize_content_for_public(msg.content) - - if isinstance(new_content, list): - for part in new_content: - if ( - isinstance(part, dict) - and part.get("type") == "tool-call" - and part.get("toolName") == "generate_podcast" - ): - result_data = part.get("result", {}) - old_podcast_id = result_data.get("podcast_id") - if old_podcast_id and old_podcast_id not in podcast_id_map: - new_podcast_id = await _clone_podcast( - session, - old_podcast_id, - target_search_space_id, - target_thread.id, - ) - if new_podcast_id: - podcast_id_map[old_podcast_id] = new_podcast_id - - if old_podcast_id and old_podcast_id in podcast_id_map: - result_data["podcast_id"] = podcast_id_map[old_podcast_id] - elif old_podcast_id: - # Podcast couldn't be cloned (not ready), remove reference - result_data.pop("podcast_id", None) + # Copy messages from snapshot_data (preserve original authors) + for msg_data in messages_data: + # Parse original author_id if present + original_author_id = None + author_id_str = msg_data.get("author_id") + if author_id_str: + with contextlib.suppress(ValueError, TypeError): + original_author_id = UUID(author_id_str) new_message = NewChatMessage( - thread_id=target_thread.id, - role=msg.role, - content=new_content, - author_id=msg.author_id, - created_at=msg.created_at, + thread_id=new_thread.id, + role=msg_data.get("role", "user"), + content=msg_data.get("content", []), + author_id=original_author_id, ) session.add(new_message) - message_count += 1 - - target_thread.clone_pending = False - target_thread.needs_history_bootstrap = True await session.commit() + await session.refresh(new_thread) - return message_count + return { + "thread_id": new_thread.id, + "search_space_id": target_search_space_id, + } -async def _clone_podcast( +async def get_snapshot_podcast( session: AsyncSession, + share_token: str, podcast_id: int, - target_search_space_id: int, - target_thread_id: int, -) -> int | None: - """Clone a podcast record and its audio file. Only clones ready podcasts.""" - import shutil - import uuid - from pathlib import Path +) -> dict | None: + """ + Get podcast info from a snapshot by original podcast ID. - from app.db import Podcast, PodcastStatus + Used for streaming podcast audio from public view. + Looks up the podcast by its original_id in the snapshot's podcasts array. + """ + snapshot = await get_snapshot_by_token(session, share_token) - result = await session.execute(select(Podcast).filter(Podcast.id == podcast_id)) - original = result.scalars().first() - if not original or original.status != PodcastStatus.READY: + if not snapshot: return None - new_file_path = None - if original.file_location: - original_path = Path(original.file_location) - if original_path.exists(): - new_filename = f"{uuid.uuid4()}_podcast.mp3" - new_dir = Path("podcasts") - new_dir.mkdir(parents=True, exist_ok=True) - new_file_path = str(new_dir / new_filename) - shutil.copy2(original.file_location, new_file_path) + podcasts = snapshot.snapshot_data.get("podcasts", []) - new_podcast = Podcast( - title=original.title, - podcast_transcript=original.podcast_transcript, - file_location=new_file_path, - status=PodcastStatus.READY, - search_space_id=target_search_space_id, - thread_id=target_thread_id, - ) - session.add(new_podcast) - await session.flush() + # Find podcast by original_id + for podcast in podcasts: + if podcast.get("original_id") == podcast_id: + return podcast - return new_podcast.id - - -async def is_podcast_publicly_accessible( - session: AsyncSession, - podcast_id: int, -) -> bool: - """ - Check if a podcast belongs to a publicly shared thread. - - Uses the thread_id foreign key for efficient lookup. - """ - from app.db import Podcast - - result = await session.execute( - select(Podcast) - .options(selectinload(Podcast.thread)) - .filter(Podcast.id == podcast_id) - ) - podcast = result.scalars().first() - - if not podcast or not podcast.thread: - return False - - return podcast.thread.public_share_enabled + return None