SurfSense/surfsense_backend/app/services/public_chat_service.py

826 lines
26 KiB
Python
Raw Normal View History

2026-01-26 13:07:46 +02:00
"""
Service layer for public chat sharing via immutable snapshots.
Key concepts:
- Snapshots are frozen copies of a chat at a specific point in time
- Content hash enables deduplication (same content = same URL)
- Podcasts are embedded in snapshot_data for self-contained public views
- Single-phase clone reads directly from snapshot_data
2026-01-26 13:07:46 +02:00
"""
2026-02-01 21:17:24 -08:00
import contextlib
import hashlib
import json
2026-01-26 13:07:46 +02:00
import re
import secrets
from datetime import UTC, datetime
2026-01-26 13:07:46 +02:00
from uuid import UUID
from fastapi import HTTPException
from sqlalchemy import delete, select
2026-01-26 13:07:46 +02:00
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from app.db import (
ChatVisibility,
NewChatMessage,
NewChatThread,
Permission,
Podcast,
PodcastStatus,
PublicChatSnapshot,
Report,
SearchSpaceMembership,
User,
)
from app.utils.rbac import check_permission
2026-01-26 13:07:46 +02:00
# Tool names whose tool-call parts are safe to expose in a public snapshot.
# Any tool-call whose toolName is not in this set is dropped during
# sanitization (see sanitize_content_for_public).
UI_TOOLS = {
    "display_image",
    "link_preview",
    "generate_podcast",
    "generate_report",
    "scrape_webpage",
    "multi_link_preview",
}
def strip_citations(text: str) -> str:
    """
    Remove [citation:X] and [citation:doc-X] patterns from text.

    Preserves newlines to maintain markdown formatting.
    """
    cleanup_steps = (
        # Drop citation markers (ASCII or CJK brackets, optional zero-width space)
        (r"[\[【]\u200B?citation:(doc-)?\d+\u200B?[\]】]", ""),
        # Collapse runs of horizontal whitespace (not newlines) to one space
        (r"[^\S\n]+", " "),
        # Cap consecutive blank lines: 3+ newlines become exactly 2
        (r"\n{3,}", "\n\n"),
        # Trim spaces that hug a newline on either side
        (r" *\n *", "\n"),
    )
    cleaned = text
    for pattern, replacement in cleanup_steps:
        cleaned = re.sub(pattern, replacement, cleaned)
    return cleaned.strip()
def sanitize_content_for_public(content: list | str | None) -> list:
    """
    Filter message content for public view.

    Strips citations and filters to UI-relevant tools.
    """
    # Missing content: nothing to expose.
    if content is None:
        return []
    # Bare string content becomes a single text part (or nothing, if empty
    # after citation stripping).
    if isinstance(content, str):
        stripped = strip_citations(content)
        return [{"type": "text", "text": stripped}] if stripped else []
    # Anything that is neither str nor list is unexpected; drop it.
    if not isinstance(content, list):
        return []

    public_parts: list = []
    for part in content:
        if not isinstance(part, dict):
            continue
        kind = part.get("type")
        if kind == "text":
            stripped = strip_citations(part.get("text", ""))
            if stripped:
                public_parts.append({"type": "text", "text": stripped})
        elif kind == "tool-call" and part.get("toolName") in UI_TOOLS:
            # Only allow-listed UI tools survive into the public snapshot.
            public_parts.append(part)
    return public_parts
async def get_author_display(
    session: AsyncSession,
    author_id: UUID | None,
    user_cache: dict[UUID, dict],
) -> dict | None:
    """Transform author UUID to display info."""
    if author_id is None:
        return None

    # Serve from the per-call cache when possible (values are always dicts,
    # so a None result means "not cached yet").
    cached = user_cache.get(author_id)
    if cached is not None:
        return cached

    row = await session.execute(select(User).filter(User.id == author_id))
    found = row.scalars().first()
    if found:
        entry = {
            "display_name": found.display_name or "User",
            "avatar_url": found.avatar_url,
        }
    else:
        # Author no longer exists in the DB; show a placeholder.
        entry = {
            "display_name": "Unknown User",
            "avatar_url": None,
        }
    user_cache[author_id] = entry
    return entry
# =============================================================================
# Content Hashing
# =============================================================================
def compute_content_hash(messages: list[dict]) -> str:
    """
    Compute SHA-256 hash of message content for deduplication.

    The hash is based on message IDs and content, ensuring that:
    - Same messages = same hash = same URL (deduplication)
    - Any change = different hash = new URL
    """
    # Canonical ordering: sort by message ID so input order never matters.
    ordered = sorted(messages, key=lambda m: m.get("id", 0))
    # Keep only the fields that define snapshot identity.
    normalized = [
        {"id": m.get("id"), "role": m.get("role"), "content": m.get("content")}
        for m in ordered
    ]
    # Compact, key-sorted JSON gives a stable byte representation.
    payload = json.dumps(normalized, sort_keys=True, separators=(",", ":"))
    return hashlib.sha256(payload.encode()).hexdigest()
# =============================================================================
# Snapshot Creation
# =============================================================================
async def create_snapshot(
    session: AsyncSession,
    thread_id: int,
    user: User,
) -> dict:
    """
    Create a public snapshot of a chat thread.

    Loads the thread (with messages), checks the caller's sharing permission,
    sanitizes every message for public display, and embeds referenced
    podcasts/reports directly into snapshot_data so the public view is
    self-contained.

    Returns existing snapshot if content unchanged (same hash).
    Returns new snapshot with unique URL if content changed.

    Raises:
        HTTPException 404: thread does not exist.
        HTTPException 400: thread has no shareable messages.
        (check_permission may also raise if the user lacks
        PUBLIC_SHARING_CREATE on the thread's search space.)
    """
    from app.config import config

    frontend_url = (config.NEXT_FRONTEND_URL or "").rstrip("/")

    result = await session.execute(
        select(NewChatThread)
        .options(selectinload(NewChatThread.messages))
        .filter(NewChatThread.id == thread_id)
    )
    thread = result.scalars().first()
    if not thread:
        raise HTTPException(status_code=404, detail="Thread not found")

    await check_permission(
        session,
        user,
        thread.search_space_id,
        Permission.PUBLIC_SHARING_CREATE.value,
        "You don't have permission to create public share links",
    )

    # Build snapshot data
    user_cache: dict[UUID, dict] = {}
    messages_data = []
    message_ids = []
    podcasts_data = []
    podcast_ids_seen: set[int] = set()
    reports_data = []
    report_ids_seen: set[int] = set()

    # Walk messages in chronological order so snapshot_data reads top-down.
    for msg in sorted(thread.messages, key=lambda m: m.created_at):
        author = await get_author_display(session, msg.author_id, user_cache)
        sanitized_content = sanitize_content_for_public(msg.content)

        # Extract podcast/report references and update status to "ready" for completed ones
        if isinstance(sanitized_content, list):
            for part in sanitized_content:
                if not isinstance(part, dict) or part.get("type") != "tool-call":
                    continue
                tool_name = part.get("toolName")
                if tool_name == "generate_podcast":
                    result_data = part.get("result", {})
                    podcast_id = result_data.get("podcast_id")
                    # Embed each podcast only once, regardless of how many
                    # messages reference it.
                    if podcast_id and podcast_id not in podcast_ids_seen:
                        podcast_info = await _get_podcast_for_snapshot(
                            session, podcast_id
                        )
                        if podcast_info:
                            podcasts_data.append(podcast_info)
                            podcast_ids_seen.add(podcast_id)
                            # Update status to "ready" so frontend renders PodcastPlayer
                            part["result"] = {**result_data, "status": "ready"}

                elif tool_name == "generate_report":
                    result_data = part.get("result", {})
                    report_id = result_data.get("report_id")
                    if report_id and report_id not in report_ids_seen:
                        report_info = await _get_report_for_snapshot(session, report_id)
                        if report_info:
                            reports_data.append(report_info)
                            report_ids_seen.add(report_id)
                            # Update status to "ready" so frontend renders ReportCard
                            part["result"] = {**result_data, "status": "ready"}

        messages_data.append(
            {
                "id": msg.id,
                # role may be an Enum or a plain string depending on the row
                "role": msg.role.value if hasattr(msg.role, "value") else str(msg.role),
                "content": sanitized_content,
                "author": author,
                "author_id": str(msg.author_id) if msg.author_id else None,
                "created_at": msg.created_at.isoformat() if msg.created_at else None,
            }
        )
        message_ids.append(msg.id)

    if not messages_data:
        raise HTTPException(status_code=400, detail="Cannot share an empty chat")

    # Compute content hash for deduplication
    content_hash = compute_content_hash(messages_data)

    # Check if identical snapshot already exists
    existing_result = await session.execute(
        select(PublicChatSnapshot).filter(
            PublicChatSnapshot.thread_id == thread_id,
            PublicChatSnapshot.content_hash == content_hash,
        )
    )
    existing = existing_result.scalars().first()
    if existing:
        # Return existing snapshot URL
        return {
            "snapshot_id": existing.id,
            "share_token": existing.share_token,
            "public_url": f"{frontend_url}/public/{existing.share_token}",
            "is_new": False,
        }

    # Get thread author info
    thread_author = await get_author_display(session, thread.created_by_id, user_cache)

    # Create snapshot data
    snapshot_data = {
        "title": thread.title,
        "snapshot_at": datetime.now(UTC).isoformat(),
        "author": thread_author,
        "messages": messages_data,
        "podcasts": podcasts_data,
        "reports": reports_data,
    }

    # Create new snapshot; token_urlsafe(48) gives an unguessable share URL.
    share_token = secrets.token_urlsafe(48)
    snapshot = PublicChatSnapshot(
        thread_id=thread_id,
        share_token=share_token,
        content_hash=content_hash,
        snapshot_data=snapshot_data,
        message_ids=message_ids,
        created_by_user_id=user.id,
    )
    session.add(snapshot)
    await session.commit()
    await session.refresh(snapshot)

    return {
        "snapshot_id": snapshot.id,
        "share_token": snapshot.share_token,
        "public_url": f"{frontend_url}/public/{snapshot.share_token}",
        "is_new": True,
    }
async def _get_podcast_for_snapshot(
    session: AsyncSession,
    podcast_id: int,
) -> dict | None:
    """Get podcast info for embedding in snapshot_data."""
    query = select(Podcast).filter(Podcast.id == podcast_id)
    podcast = (await session.execute(query)).scalars().first()

    # Only fully-generated podcasts are embeddable in a public snapshot.
    if podcast is None or podcast.status != PodcastStatus.READY:
        return None

    return {
        "original_id": podcast.id,
        "title": podcast.title,
        "transcript": podcast.podcast_transcript,
        "file_path": podcast.file_location,
    }
async def _get_report_for_snapshot(
    session: AsyncSession,
    report_id: int,
) -> dict | None:
    """Get report info for embedding in snapshot_data."""
    query = select(Report).filter(Report.id == report_id)
    report = (await session.execute(query)).scalars().first()
    if report is None:
        return None

    return {
        "original_id": report.id,
        "title": report.title,
        "content": report.content,
        "report_metadata": report.report_metadata,
        "report_group_id": report.report_group_id,
        "created_at": report.created_at.isoformat() if report.created_at else None,
    }
# =============================================================================
# Snapshot Retrieval
# =============================================================================
async def get_snapshot_by_token(
    session: AsyncSession,
    share_token: str,
) -> PublicChatSnapshot | None:
    """Get a snapshot by its share token."""
    query = select(PublicChatSnapshot).filter(
        PublicChatSnapshot.share_token == share_token
    )
    return (await session.execute(query)).scalars().first()
2026-01-26 13:07:46 +02:00
async def get_public_chat(
    session: AsyncSession,
    share_token: str,
) -> dict:
    """
    Get public chat data from a snapshot.

    Returns sanitized content suitable for public viewing.
    """
    snapshot = await get_snapshot_by_token(session, share_token)
    if snapshot is None:
        raise HTTPException(status_code=404, detail="Not found")

    payload = snapshot.snapshot_data
    thread_info = {
        "title": payload.get("title", "Untitled"),
        "created_at": payload.get("snapshot_at"),
    }
    return {
        "thread": thread_info,
        "messages": payload.get("messages", []),
    }
async def list_snapshots_for_thread(
    session: AsyncSession,
    thread_id: int,
    user: User,
) -> list[dict]:
    """List all public snapshots for a thread."""
    from app.config import config

    thread_row = await session.execute(
        select(NewChatThread).filter(NewChatThread.id == thread_id)
    )
    thread = thread_row.scalars().first()
    if thread is None:
        raise HTTPException(status_code=404, detail="Thread not found")

    # Check permission to view public share links
    await check_permission(
        session,
        user,
        thread.search_space_id,
        Permission.PUBLIC_SHARING_VIEW.value,
        "You don't have permission to view public share links",
    )

    snapshot_rows = await session.execute(
        select(PublicChatSnapshot)
        .filter(PublicChatSnapshot.thread_id == thread_id)
        .order_by(PublicChatSnapshot.created_at.desc())
    )
    base_url = (config.NEXT_FRONTEND_URL or "").rstrip("/")

    listing = []
    for snap in snapshot_rows.scalars().all():
        listing.append(
            {
                "id": snap.id,
                "share_token": snap.share_token,
                "public_url": f"{base_url}/public/{snap.share_token}",
                "created_at": snap.created_at.isoformat() if snap.created_at else None,
                "message_count": len(snap.message_ids) if snap.message_ids else 0,
            }
        )
    return listing
async def list_snapshots_for_search_space(
    session: AsyncSession,
    search_space_id: int,
    user: User,
) -> list[dict]:
    """List all public snapshots for a search space.

    Joins snapshots to their threads to scope by search space, then looks up
    thread titles in a single batched query.

    Args:
        session: Active async DB session.
        search_space_id: Search space to list snapshots for.
        user: Caller; must hold PUBLIC_SHARING_VIEW on the search space.

    Returns:
        List of snapshot summaries (newest first), each including the thread
        id/title and the creating user's id (as a string) when recorded.
    """
    from app.config import config

    await check_permission(
        session,
        user,
        search_space_id,
        Permission.PUBLIC_SHARING_VIEW.value,
        "You don't have permission to view public share links",
    )

    result = await session.execute(
        select(PublicChatSnapshot)
        .join(NewChatThread, PublicChatSnapshot.thread_id == NewChatThread.id)
        .filter(NewChatThread.search_space_id == search_space_id)
        .order_by(PublicChatSnapshot.created_at.desc())
    )
    snapshots = result.scalars().all()

    # Fix: previously the title lookup ran even with zero snapshots, issuing
    # a pointless `id IN ()` query. Bail out early instead.
    if not snapshots:
        return []

    snapshot_thread_ids = [s.thread_id for s in snapshots]
    thread_result = await session.execute(
        select(NewChatThread.id, NewChatThread.title).filter(
            NewChatThread.id.in_(snapshot_thread_ids)
        )
    )
    thread_titles = {row[0]: row[1] for row in thread_result.fetchall()}

    frontend_url = (config.NEXT_FRONTEND_URL or "").rstrip("/")
    return [
        {
            "id": s.id,
            "share_token": s.share_token,
            "public_url": f"{frontend_url}/public/{s.share_token}",
            "created_at": s.created_at.isoformat() if s.created_at else None,
            "message_count": len(s.message_ids) if s.message_ids else 0,
            "thread_id": s.thread_id,
            "thread_title": thread_titles.get(s.thread_id, "Untitled"),
            "created_by_user_id": str(s.created_by_user_id)
            if s.created_by_user_id
            else None,
        }
        for s in snapshots
    ]
2026-01-26 13:07:46 +02:00
# =============================================================================
# Snapshot Deletion
# =============================================================================
2026-01-26 13:07:46 +02:00
async def delete_snapshot(
    session: AsyncSession,
    thread_id: int,
    snapshot_id: int,
    user: User,
) -> bool:
    """Delete a specific snapshot. Only thread owner can delete."""
    # Load the snapshot together with its thread (needed for the
    # permission check on the thread's search space).
    lookup = await session.execute(
        select(PublicChatSnapshot)
        .options(selectinload(PublicChatSnapshot.thread))
        .filter(
            PublicChatSnapshot.id == snapshot_id,
            PublicChatSnapshot.thread_id == thread_id,
        )
    )
    target = lookup.scalars().first()
    if target is None:
        raise HTTPException(status_code=404, detail="Snapshot not found")

    await check_permission(
        session,
        user,
        target.thread.search_space_id,
        Permission.PUBLIC_SHARING_DELETE.value,
        "You don't have permission to delete public share links",
    )

    await session.delete(target)
    await session.commit()
    return True
async def delete_affected_snapshots(
    session: AsyncSession,
    thread_id: int,
    message_ids: list[int],
) -> int:
    """
    Delete snapshots that contain any of the given message IDs.

    Called when messages are edited/deleted/regenerated.
    Uses independent session to work reliably in streaming response cleanup.
    """
    if not message_ids:
        return 0

    from sqlalchemy.dialects.postgresql import array

    from app.db import async_session_maker

    # Postgres && (array overlap) matches snapshots referencing any edited id.
    stmt = (
        delete(PublicChatSnapshot)
        .where(PublicChatSnapshot.thread_id == thread_id)
        .where(PublicChatSnapshot.message_ids.op("&&")(array(message_ids)))
        .returning(PublicChatSnapshot.id)
    )

    async with async_session_maker() as cleanup_session:
        outcome = await cleanup_session.execute(stmt)
        removed_ids = outcome.scalars().all()
        await cleanup_session.commit()
    return len(removed_ids)
# =============================================================================
# Cloning from Snapshot
# =============================================================================
2026-01-26 15:03:28 +02:00
async def get_user_default_search_space(
    session: AsyncSession,
    user_id: UUID,
) -> int | None:
    """
    Get user's default search space for cloning.

    Returns the first search space where user is owner, or None if not found.
    """
    query = (
        select(SearchSpaceMembership)
        .filter(
            SearchSpaceMembership.user_id == user_id,
            SearchSpaceMembership.is_owner.is_(True),
        )
        .limit(1)
    )
    owner_membership = (await session.execute(query)).scalars().first()
    return owner_membership.search_space_id if owner_membership else None
async def clone_from_snapshot(
    session: AsyncSession,
    share_token: str,
    user: User,
) -> dict:
    """
    Copy messages and podcasts from source thread to target thread.

    Creates thread and copies messages from snapshot_data.
    When encountering generate_podcast tool-calls, creates cloned podcast records
    and updates the podcast_id references inline.
    generate_report tool-calls are handled the same way: a cloned Report row
    is created and the report_id reference is rewritten.

    Returns the new thread info.

    Raises:
        HTTPException 404: share token does not resolve to a snapshot.
        HTTPException 400: user has no owned search space to clone into.
    """
    import copy

    snapshot = await get_snapshot_by_token(session, share_token)
    if not snapshot:
        raise HTTPException(
            status_code=404, detail="Chat not found or no longer public"
        )

    # Clones land in the user's first owned search space.
    target_search_space_id = await get_user_default_search_space(session, user.id)
    if target_search_space_id is None:
        raise HTTPException(status_code=400, detail="No search space found for user")

    data = snapshot.snapshot_data
    messages_data = data.get("messages", [])
    # Lookup tables from embedded snapshot media, keyed by original DB id.
    podcasts_lookup = {p.get("original_id"): p for p in data.get("podcasts", [])}
    reports_lookup = {r.get("original_id"): r for r in data.get("reports", [])}

    new_thread = NewChatThread(
        title=data.get("title", "Cloned Chat"),
        archived=False,
        visibility=ChatVisibility.PRIVATE,
        search_space_id=target_search_space_id,
        created_by_id=user.id,
        cloned_from_thread_id=snapshot.thread_id,
        cloned_from_snapshot_id=snapshot.id,
        cloned_at=datetime.now(UTC),
        needs_history_bootstrap=True,
    )
    session.add(new_thread)
    # Flush so new_thread.id is available for message/podcast/report FKs.
    await session.flush()

    # original id -> cloned id, so repeated references reuse the same clone
    podcast_id_mapping: dict[int, int] = {}
    report_id_mapping: dict[int, int] = {}

    # Check which authors from snapshot still exist in DB
    author_ids_from_snapshot: set[UUID] = set()
    for msg_data in messages_data:
        if author_str := msg_data.get("author_id"):
            with contextlib.suppress(ValueError, TypeError):
                author_ids_from_snapshot.add(UUID(author_str))

    existing_authors: set[UUID] = set()
    if author_ids_from_snapshot:
        result = await session.execute(
            select(User.id).where(User.id.in_(author_ids_from_snapshot))
        )
        existing_authors = {row[0] for row in result.fetchall()}

    for msg_data in messages_data:
        role = msg_data.get("role", "user")

        # Use original author if exists, otherwise None
        author_id = None
        if author_str := msg_data.get("author_id"):
            try:
                parsed_id = UUID(author_str)
                if parsed_id in existing_authors:
                    author_id = parsed_id
            except (ValueError, TypeError):
                pass

        # Deep-copy so id rewrites below never mutate snapshot_data itself.
        content = copy.deepcopy(msg_data.get("content", []))
        if isinstance(content, list):
            for part in content:
                if (
                    isinstance(part, dict)
                    and part.get("type") == "tool-call"
                    and part.get("toolName") == "generate_podcast"
                ):
                    result = part.get("result", {})
                    old_podcast_id = result.get("podcast_id")
                    # First reference: clone the embedded podcast row.
                    if old_podcast_id and old_podcast_id not in podcast_id_mapping:
                        podcast_info = podcasts_lookup.get(old_podcast_id)
                        if podcast_info:
                            new_podcast = Podcast(
                                title=podcast_info.get("title", "Cloned Podcast"),
                                podcast_transcript=podcast_info.get("transcript"),
                                file_location=podcast_info.get("file_path"),
                                status=PodcastStatus.READY,
                                search_space_id=target_search_space_id,
                                thread_id=new_thread.id,
                            )
                            session.add(new_podcast)
                            # Flush to obtain the clone's primary key.
                            await session.flush()
                            podcast_id_mapping[old_podcast_id] = new_podcast.id
                    # Rewrite the reference to point at the cloned podcast.
                    if old_podcast_id and old_podcast_id in podcast_id_mapping:
                        part["result"] = {
                            **result,
                            "podcast_id": podcast_id_mapping[old_podcast_id],
                        }
                if (
                    isinstance(part, dict)
                    and part.get("type") == "tool-call"
                    and part.get("toolName") == "generate_report"
                ):
                    result = part.get("result", {})
                    old_report_id = result.get("report_id")
                    if old_report_id and old_report_id not in report_id_mapping:
                        report_info = reports_lookup.get(old_report_id)
                        if report_info:
                            new_report = Report(
                                title=report_info.get("title", "Cloned Report"),
                                content=report_info.get("content"),
                                report_metadata=report_info.get("report_metadata"),
                                search_space_id=target_search_space_id,
                                thread_id=new_thread.id,
                            )
                            session.add(new_report)
                            await session.flush()
                            # For cloned reports, set report_group_id = own id
                            # (each cloned report starts as its own v1)
                            new_report.report_group_id = new_report.id
                            report_id_mapping[old_report_id] = new_report.id
                    if old_report_id and old_report_id in report_id_mapping:
                        part["result"] = {
                            **result,
                            "report_id": report_id_mapping[old_report_id],
                        }

        new_message = NewChatMessage(
            thread_id=new_thread.id,
            role=role,
            content=content,
            author_id=author_id,
        )
        session.add(new_message)

    await session.commit()
    await session.refresh(new_thread)

    return {
        "thread_id": new_thread.id,
        "search_space_id": target_search_space_id,
    }
2026-01-26 15:03:28 +02:00
async def get_snapshot_podcast(
    session: AsyncSession,
    share_token: str,
    podcast_id: int,
) -> dict | None:
    """
    Get podcast info from a snapshot by original podcast ID.

    Used for streaming podcast audio from public view.
    Looks up the podcast by its original_id in the snapshot's podcasts array.
    """
    snapshot = await get_snapshot_by_token(session, share_token)
    if snapshot is None:
        return None

    embedded = snapshot.snapshot_data.get("podcasts", [])
    return next(
        (entry for entry in embedded if entry.get("original_id") == podcast_id),
        None,
    )
async def get_snapshot_report(
    session: AsyncSession,
    share_token: str,
    report_id: int,
) -> dict | None:
    """
    Get report info from a snapshot by original report ID.

    Used for displaying report content in public view.
    Looks up the report by its original_id in the snapshot's reports array.
    """
    snapshot = await get_snapshot_by_token(session, share_token)
    if snapshot is None:
        return None

    embedded = snapshot.snapshot_data.get("reports", [])
    return next(
        (entry for entry in embedded if entry.get("original_id") == report_id),
        None,
    )
async def get_snapshot_report_versions(
    session: AsyncSession,
    share_token: str,
    report_group_id: int | None,
) -> list[dict]:
    """
    Get all report versions in the same group from a snapshot.

    Returns a list of lightweight version entries (id + created_at)
    for the version switcher UI, sorted by original_id (insertion order).
    """
    if not report_group_id:
        return []

    snapshot = await get_snapshot_by_token(session, share_token)
    if snapshot is None:
        return []

    all_reports = snapshot.snapshot_data.get("reports", [])
    # Ascending original_id approximates creation order for the switcher.
    group_members = sorted(
        (r for r in all_reports if r.get("report_group_id") == report_group_id),
        key=lambda r: r.get("original_id", 0),
    )
    return [
        {"id": r.get("original_id"), "created_at": r.get("created_at")}
        for r in group_members
    ]