refactor: rewrite public_chat_service for immutable snapshots

This commit is contained in:
CREDO23 2026-01-29 20:24:50 +02:00
parent 665354b33d
commit a45412abad

View file

@ -1,17 +1,36 @@
"""
Service layer for public chat sharing and cloning.
Service layer for public chat sharing via immutable snapshots.
Key concepts:
- Snapshots are frozen copies of a chat at a specific point in time
- Content hash enables deduplication (same content = same URL)
- Podcasts are embedded in snapshot_data for self-contained public views
- Single-phase clone reads directly from snapshot_data
"""
import contextlib
import hashlib
import json
import re
import secrets
from datetime import UTC, datetime
from uuid import UUID
from fastapi import HTTPException
from sqlalchemy import select
from sqlalchemy import delete, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from app.db import NewChatThread, User
from app.db import (
ChatVisibility,
NewChatMessage,
NewChatThread,
Podcast,
PodcastStatus,
PublicChatSnapshot,
SearchSpaceMembership,
User,
)
UI_TOOLS = {
"display_image",
@ -100,20 +119,241 @@ async def get_author_display(
return user_cache[author_id]
async def toggle_public_share(
# =============================================================================
# Content Hashing
# =============================================================================
def compute_content_hash(messages: list[dict]) -> str:
"""
Compute SHA-256 hash of message content for deduplication.
The hash is based on message IDs and content, ensuring that:
- Same messages = same hash = same URL (deduplication)
- Any change = different hash = new URL
"""
# Sort by message ID to ensure consistent ordering
sorted_messages = sorted(messages, key=lambda m: m.get("id", 0))
# Create normalized representation
normalized = []
for msg in sorted_messages:
normalized.append(
{
"id": msg.get("id"),
"role": msg.get("role"),
"content": msg.get("content"),
}
)
content_str = json.dumps(normalized, sort_keys=True, separators=(",", ":"))
return hashlib.sha256(content_str.encode()).hexdigest()
# =============================================================================
# Snapshot Creation
# =============================================================================
async def create_snapshot(
session: AsyncSession,
thread_id: int,
enabled: bool,
user: User,
base_url: str,
) -> dict:
"""
Enable or disable public sharing for a thread.
Create a public snapshot of a chat thread.
Only the thread owner can toggle public sharing.
When enabling, generates a new token if one doesn't exist.
When disabling, keeps the token for potential re-enable.
Returns existing snapshot if content unchanged (same hash).
Returns new snapshot with unique URL if content changed.
"""
result = await session.execute(
select(NewChatThread)
.options(selectinload(NewChatThread.messages))
.filter(NewChatThread.id == thread_id)
)
thread = result.scalars().first()
if not thread:
raise HTTPException(status_code=404, detail="Thread not found")
if thread.created_by_id != user.id:
raise HTTPException(
status_code=403,
detail="Only the creator of this chat can create public snapshots",
)
# Build snapshot data
user_cache: dict[UUID, dict] = {}
messages_data = []
message_ids = []
podcasts_data = []
podcast_ids_seen: set[int] = set()
for msg in sorted(thread.messages, key=lambda m: m.created_at):
author = await get_author_display(session, msg.author_id, user_cache)
sanitized_content = sanitize_content_for_public(msg.content)
# Extract podcast references (keep original podcast_id unchanged)
if isinstance(sanitized_content, list):
for part in sanitized_content:
if (
isinstance(part, dict)
and part.get("type") == "tool-call"
and part.get("toolName") == "generate_podcast"
):
result_data = part.get("result", {})
podcast_id = result_data.get("podcast_id")
if podcast_id and podcast_id not in podcast_ids_seen:
podcast_info = await _get_podcast_for_snapshot(
session, podcast_id
)
if podcast_info:
podcasts_data.append(podcast_info)
podcast_ids_seen.add(podcast_id)
messages_data.append(
{
"id": msg.id,
"role": msg.role.value if hasattr(msg.role, "value") else str(msg.role),
"content": sanitized_content,
"author": author,
"author_id": str(msg.author_id) if msg.author_id else None,
"created_at": msg.created_at.isoformat() if msg.created_at else None,
}
)
message_ids.append(msg.id)
if not messages_data:
raise HTTPException(status_code=400, detail="Cannot share an empty chat")
# Compute content hash for deduplication
content_hash = compute_content_hash(messages_data)
# Check if identical snapshot already exists
existing_result = await session.execute(
select(PublicChatSnapshot).filter(
PublicChatSnapshot.thread_id == thread_id,
PublicChatSnapshot.content_hash == content_hash,
)
)
existing = existing_result.scalars().first()
if existing:
# Return existing snapshot URL
return {
"snapshot_id": existing.id,
"share_token": existing.share_token,
"public_url": f"{base_url}/public/{existing.share_token}",
"is_new": False,
}
# Get thread author info
thread_author = await get_author_display(session, thread.created_by_id, user_cache)
# Create snapshot data
snapshot_data = {
"title": thread.title,
"snapshot_at": datetime.now(UTC).isoformat(),
"author": thread_author,
"messages": messages_data,
"podcasts": podcasts_data,
}
# Create new snapshot
share_token = secrets.token_urlsafe(48)
snapshot = PublicChatSnapshot(
thread_id=thread_id,
share_token=share_token,
content_hash=content_hash,
snapshot_data=snapshot_data,
message_ids=message_ids,
created_by_user_id=user.id,
)
session.add(snapshot)
await session.commit()
await session.refresh(snapshot)
return {
"snapshot_id": snapshot.id,
"share_token": snapshot.share_token,
"public_url": f"{base_url}/public/{snapshot.share_token}",
"is_new": True,
}
async def _get_podcast_for_snapshot(
session: AsyncSession,
podcast_id: int,
) -> dict | None:
"""Get podcast info for embedding in snapshot_data."""
result = await session.execute(select(Podcast).filter(Podcast.id == podcast_id))
podcast = result.scalars().first()
if not podcast or podcast.status != PodcastStatus.READY:
return None
return {
"original_id": podcast.id,
"title": podcast.title,
"transcript": podcast.podcast_transcript,
"file_path": podcast.file_location,
}
# =============================================================================
# Snapshot Retrieval
# =============================================================================
async def get_snapshot_by_token(
session: AsyncSession,
share_token: str,
) -> PublicChatSnapshot | None:
"""Get a snapshot by its share token."""
result = await session.execute(
select(PublicChatSnapshot).filter(
PublicChatSnapshot.share_token == share_token
)
)
return result.scalars().first()
async def get_public_chat(
session: AsyncSession,
share_token: str,
) -> dict:
"""
Get public chat data from a snapshot.
Returns sanitized content suitable for public viewing.
"""
snapshot = await get_snapshot_by_token(session, share_token)
if not snapshot:
raise HTTPException(status_code=404, detail="Not found")
data = snapshot.snapshot_data
return {
"thread": {
"title": data.get("title", "Untitled"),
"created_at": data.get("snapshot_at"),
},
"messages": data.get("messages", []),
}
async def list_snapshots_for_thread(
session: AsyncSession,
thread_id: int,
user: User,
base_url: str,
) -> list[dict]:
"""List all public snapshots for a thread."""
# Verify ownership
result = await session.execute(
select(NewChatThread).filter(NewChatThread.id == thread_id)
)
@ -125,92 +365,99 @@ async def toggle_public_share(
if thread.created_by_id != user.id:
raise HTTPException(
status_code=403,
detail="Only the creator of this chat can manage public sharing",
detail="Only the creator can view snapshots",
)
if enabled and not thread.public_share_token:
thread.public_share_token = secrets.token_urlsafe(48)
# Get snapshots
result = await session.execute(
select(PublicChatSnapshot)
.filter(PublicChatSnapshot.thread_id == thread_id)
.order_by(PublicChatSnapshot.created_at.desc())
)
snapshots = result.scalars().all()
thread.public_share_enabled = enabled
await session.commit()
await session.refresh(thread)
if enabled:
return {
"enabled": True,
"public_url": f"{base_url}/public/{thread.public_share_token}",
"share_token": thread.public_share_token,
return [
{
"id": s.id,
"share_token": s.share_token,
"public_url": f"{base_url}/public/{s.share_token}",
"created_at": s.created_at.isoformat() if s.created_at else None,
"message_count": len(s.message_ids) if s.message_ids else 0,
}
return {
"enabled": False,
"public_url": None,
"share_token": None,
}
for s in snapshots
]
async def get_public_chat(
# =============================================================================
# Snapshot Deletion
# =============================================================================
async def delete_snapshot(
session: AsyncSession,
share_token: str,
) -> dict:
"""
Get a public chat by share token.
Returns sanitized content suitable for public viewing.
"""
thread_id: int,
snapshot_id: int,
user: User,
) -> bool:
"""Delete a specific snapshot. Only thread owner can delete."""
# Get snapshot with thread
result = await session.execute(
select(NewChatThread)
.options(selectinload(NewChatThread.messages))
select(PublicChatSnapshot)
.options(selectinload(PublicChatSnapshot.thread))
.filter(
NewChatThread.public_share_token == share_token,
NewChatThread.public_share_enabled.is_(True),
PublicChatSnapshot.id == snapshot_id,
PublicChatSnapshot.thread_id == thread_id,
)
)
thread = result.scalars().first()
snapshot = result.scalars().first()
if not thread:
raise HTTPException(status_code=404, detail="Not found")
if not snapshot:
raise HTTPException(status_code=404, detail="Snapshot not found")
user_cache: dict[UUID, dict] = {}
messages = []
for msg in sorted(thread.messages, key=lambda m: m.created_at):
author = await get_author_display(session, msg.author_id, user_cache)
sanitized_content = sanitize_content_for_public(msg.content)
messages.append(
{
"role": msg.role,
"content": sanitized_content,
"author": author,
"created_at": msg.created_at,
}
if snapshot.thread.created_by_id != user.id:
raise HTTPException(
status_code=403,
detail="Only the creator can delete snapshots",
)
return {
"thread": {
"title": thread.title,
"created_at": thread.created_at,
},
"messages": messages,
}
await session.delete(snapshot)
await session.commit()
return True
async def get_thread_by_share_token(
async def delete_affected_snapshots(
session: AsyncSession,
share_token: str,
) -> NewChatThread | None:
"""Get a thread by its public share token if sharing is enabled."""
thread_id: int,
message_ids: list[int],
) -> int:
"""
Delete snapshots that contain any of the given message IDs.
Called when messages are edited/deleted/regenerated.
Returns the number of deleted snapshots.
"""
if not message_ids:
return 0
# Use raw SQL for array overlap query
# The && operator checks if arrays have any elements in common
result = await session.execute(
select(NewChatThread)
.options(selectinload(NewChatThread.messages))
.filter(
NewChatThread.public_share_token == share_token,
NewChatThread.public_share_enabled.is_(True),
)
delete(PublicChatSnapshot)
.where(PublicChatSnapshot.thread_id == thread_id)
.where(PublicChatSnapshot.message_ids.overlap(message_ids))
.returning(PublicChatSnapshot.id)
)
return result.scalars().first()
deleted_ids = result.scalars().all()
await session.commit()
return len(deleted_ids)
# =============================================================================
# Cloning from Snapshot
# =============================================================================
async def get_user_default_search_space(
@ -222,8 +469,6 @@ async def get_user_default_search_space(
Returns the first search space where user is owner, or None if not found.
"""
from app.db import SearchSpaceMembership
result = await session.execute(
select(SearchSpaceMembership)
.filter(
@ -240,140 +485,96 @@ async def get_user_default_search_space(
return None
async def complete_clone_content(
async def clone_from_snapshot(
session: AsyncSession,
target_thread: NewChatThread,
source_thread_id: int,
target_search_space_id: int,
) -> int:
share_token: str,
user: User,
) -> dict:
"""
Copy messages and podcasts from source thread to target thread.
Sets clone_pending=False and needs_history_bootstrap=True when done.
Returns the number of messages copied.
Creates thread and copies messages from snapshot_data.
Returns the new thread info.
"""
from app.db import NewChatMessage
# Get snapshot
snapshot = await get_snapshot_by_token(session, share_token)
result = await session.execute(
select(NewChatThread)
.options(selectinload(NewChatThread.messages))
.filter(NewChatThread.id == source_thread_id)
if not snapshot:
raise HTTPException(
status_code=404, detail="Chat not found or no longer public"
)
# Get user's default search space
target_search_space_id = await get_user_default_search_space(session, user.id)
if target_search_space_id is None:
raise HTTPException(status_code=400, detail="No search space found for user")
# Get snapshot data
data = snapshot.snapshot_data
messages_data = data.get("messages", [])
# Create new thread
new_thread = NewChatThread(
title=data.get("title", "Cloned Chat"),
archived=False,
visibility=ChatVisibility.PRIVATE,
search_space_id=target_search_space_id,
created_by_id=user.id,
cloned_from_thread_id=snapshot.thread_id,
cloned_at=datetime.now(UTC),
needs_history_bootstrap=True,
)
source_thread = result.scalars().first()
session.add(new_thread)
await session.flush() # Get thread ID
if not source_thread:
raise ValueError("Source thread not found")
podcast_id_map: dict[int, int] = {}
message_count = 0
for msg in sorted(source_thread.messages, key=lambda m: m.created_at):
new_content = sanitize_content_for_public(msg.content)
if isinstance(new_content, list):
for part in new_content:
if (
isinstance(part, dict)
and part.get("type") == "tool-call"
and part.get("toolName") == "generate_podcast"
):
result_data = part.get("result", {})
old_podcast_id = result_data.get("podcast_id")
if old_podcast_id and old_podcast_id not in podcast_id_map:
new_podcast_id = await _clone_podcast(
session,
old_podcast_id,
target_search_space_id,
target_thread.id,
)
if new_podcast_id:
podcast_id_map[old_podcast_id] = new_podcast_id
if old_podcast_id and old_podcast_id in podcast_id_map:
result_data["podcast_id"] = podcast_id_map[old_podcast_id]
elif old_podcast_id:
# Podcast couldn't be cloned (not ready), remove reference
result_data.pop("podcast_id", None)
# Copy messages from snapshot_data (preserve original authors)
for msg_data in messages_data:
# Parse original author_id if present
original_author_id = None
author_id_str = msg_data.get("author_id")
if author_id_str:
with contextlib.suppress(ValueError, TypeError):
original_author_id = UUID(author_id_str)
new_message = NewChatMessage(
thread_id=target_thread.id,
role=msg.role,
content=new_content,
author_id=msg.author_id,
created_at=msg.created_at,
thread_id=new_thread.id,
role=msg_data.get("role", "user"),
content=msg_data.get("content", []),
author_id=original_author_id,
)
session.add(new_message)
message_count += 1
target_thread.clone_pending = False
target_thread.needs_history_bootstrap = True
await session.commit()
await session.refresh(new_thread)
return message_count
return {
"thread_id": new_thread.id,
"search_space_id": target_search_space_id,
}
async def _clone_podcast(
async def get_snapshot_podcast(
session: AsyncSession,
share_token: str,
podcast_id: int,
target_search_space_id: int,
target_thread_id: int,
) -> int | None:
"""Clone a podcast record and its audio file. Only clones ready podcasts."""
import shutil
import uuid
from pathlib import Path
) -> dict | None:
"""
Get podcast info from a snapshot by original podcast ID.
from app.db import Podcast, PodcastStatus
Used for streaming podcast audio from public view.
Looks up the podcast by its original_id in the snapshot's podcasts array.
"""
snapshot = await get_snapshot_by_token(session, share_token)
result = await session.execute(select(Podcast).filter(Podcast.id == podcast_id))
original = result.scalars().first()
if not original or original.status != PodcastStatus.READY:
if not snapshot:
return None
new_file_path = None
if original.file_location:
original_path = Path(original.file_location)
if original_path.exists():
new_filename = f"{uuid.uuid4()}_podcast.mp3"
new_dir = Path("podcasts")
new_dir.mkdir(parents=True, exist_ok=True)
new_file_path = str(new_dir / new_filename)
shutil.copy2(original.file_location, new_file_path)
podcasts = snapshot.snapshot_data.get("podcasts", [])
new_podcast = Podcast(
title=original.title,
podcast_transcript=original.podcast_transcript,
file_location=new_file_path,
status=PodcastStatus.READY,
search_space_id=target_search_space_id,
thread_id=target_thread_id,
)
session.add(new_podcast)
await session.flush()
# Find podcast by original_id
for podcast in podcasts:
if podcast.get("original_id") == podcast_id:
return podcast
return new_podcast.id
async def is_podcast_publicly_accessible(
session: AsyncSession,
podcast_id: int,
) -> bool:
"""
Check if a podcast belongs to a publicly shared thread.
Uses the thread_id foreign key for efficient lookup.
"""
from app.db import Podcast
result = await session.execute(
select(Podcast)
.options(selectinload(Podcast.thread))
.filter(Podcast.id == podcast_id)
)
podcast = result.scalars().first()
if not podcast or not podcast.thread:
return False
return podcast.thread.public_share_enabled
return None