SurfSense/surfsense_backend/app/services/public_chat_service.py
CREDO23 ecb5572e69 fix(backend): remove inaccessible podcast references when cloning chats
When a podcast can't be cloned (not READY), remove the podcast_id from
the cloned message to prevent 403 errors when users try to access it.
2026-01-28 19:25:15 +02:00

379 lines
11 KiB
Python

"""
Service layer for public chat sharing and cloning.
"""
import re
import secrets
from uuid import UUID
from fastapi import HTTPException
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from app.db import NewChatThread, User
# Names of the tools whose "tool-call" content parts are kept by
# sanitize_content_for_public; tool calls with any other name are dropped.
UI_TOOLS = {
    "display_image",
    "link_preview",
    "generate_podcast",
    "scrape_webpage",
    "multi_link_preview",
}
def strip_citations(text: str) -> str:
    """
    Remove [citation:X] and [citation:doc-X] patterns from text.

    X is a run of digits. Both ASCII ``[...]`` and full-width ``【...】``
    brackets are matched, with an optional zero-width space (U+200B) just
    inside either bracket. Newlines are preserved to maintain markdown
    formatting.

    Args:
        text: Message text that may contain citation markers.

    Returns:
        The text with citation markers removed, runs of spaces/tabs collapsed
        to one space, spaces adjacent to newlines removed, runs of three or
        more newlines reduced to two, and surrounding whitespace stripped.
    """
    # Remove citation patterns
    text = re.sub(r"[\[【]\u200B?citation:(doc-)?\d+\u200B?[\]】]", "", text)
    # Collapse multiple spaces/tabs (but NOT newlines) into single space
    text = re.sub(r"[^\S\n]+", " ", text)
    # Clean up spaces around newlines. This must run BEFORE the blank-line
    # normalisation: a line left holding only spaces (e.g. one that contained
    # nothing but a citation) otherwise hides a run of 3+ newlines from it.
    text = re.sub(r" *\n *", "\n", text)
    # Normalize excessive blank lines (3+ newlines → 2)
    text = re.sub(r"\n{3,}", "\n\n", text)
    return text.strip()
def sanitize_content_for_public(content: list | str | None) -> list:
    """
    Filter message content for public view.

    Strips citations from text and keeps only tool calls whose ``toolName``
    is listed in ``UI_TOOLS``.

    Args:
        content: Stored message content: a plain string, a list of part
            dicts (``{"type": "text", ...}`` / ``{"type": "tool-call", ...}``),
            or None.

    Returns:
        A new list of content parts. Text parts are fresh dicts holding the
        cleaned text; parts whose cleaned text is empty are omitted. Kept
        tool-call parts are the SAME dict objects as in ``content`` (not
        copies), so callers that intend to modify them should copy first.
        Any other input type, non-dict parts and parts of any other ``type``
        yield nothing.
    """
    if content is None:
        return []
    if isinstance(content, str):
        clean_text = strip_citations(content)
        return [{"type": "text", "text": clean_text}] if clean_text else []
    if not isinstance(content, list):
        return []

    sanitized = []
    for part in content:
        if not isinstance(part, dict):
            continue
        part_type = part.get("type")
        if part_type == "text":
            raw_text = part.get("text")
            # A part may carry "text": None (or another non-string value);
            # handing that to strip_citations would raise TypeError and take
            # the whole message list down with it, so such parts are skipped.
            if not isinstance(raw_text, str):
                continue
            clean_text = strip_citations(raw_text)
            if clean_text:
                sanitized.append({"type": "text", "text": clean_text})
        elif part_type == "tool-call":
            if part.get("toolName") not in UI_TOOLS:
                continue
            sanitized.append(part)
    return sanitized
async def get_author_display(
    session: AsyncSession,
    author_id: UUID | None,
    user_cache: dict[UUID, dict],
) -> dict | None:
    """
    Resolve an author UUID to its public display info.

    Args:
        session: Database session used to look the user up on a cache miss.
        author_id: Id of the message author, or None.
        user_cache: Per-request memo of already resolved authors; it is
            updated in place on a cache miss.

    Returns:
        None when ``author_id`` is None, otherwise a dict with
        ``display_name`` and ``avatar_url`` (a placeholder entry is cached
        and returned when no matching user row exists).
    """
    if author_id is None:
        return None

    # Served from the cache: no query needed.
    if author_id in user_cache:
        return user_cache[author_id]

    lookup = await session.execute(select(User).filter(User.id == author_id))
    found = lookup.scalars().first()
    if found:
        display = {
            "display_name": found.display_name or "User",
            "avatar_url": found.avatar_url,
        }
    else:
        display = {
            "display_name": "Unknown User",
            "avatar_url": None,
        }
    user_cache[author_id] = display
    return display
async def toggle_public_share(
    session: AsyncSession,
    thread_id: int,
    enabled: bool,
    user: User,
    base_url: str,
) -> dict:
    """
    Enable or disable public sharing for a thread.

    Only the user who created the thread can toggle public sharing.
    When enabling, generates a new token if one doesn't exist.
    When disabling, keeps the token for potential re-enable.

    Args:
        session: Database session; the change is committed here.
        thread_id: Id of the thread to update.
        enabled: Desired sharing state.
        user: The user making the request.
        base_url: Base of the public URL; trailing slashes are ignored.

    Returns:
        A dict with ``enabled``, ``public_url`` and ``share_token``. The last
        two are None when sharing is disabled.

    Raises:
        HTTPException: 404 if the thread does not exist, 403 if ``user`` is
            not the thread's creator.
    """
    result = await session.execute(
        select(NewChatThread).filter(NewChatThread.id == thread_id)
    )
    thread = result.scalars().first()
    if not thread:
        raise HTTPException(status_code=404, detail="Thread not found")
    if thread.created_by_id != user.id:
        raise HTTPException(
            status_code=403,
            detail="Only the creator of this chat can manage public sharing",
        )

    # An existing token is reused so a previously shared link keeps working
    # after sharing is switched off and on again.
    if enabled and not thread.public_share_token:
        thread.public_share_token = secrets.token_urlsafe(48)
    thread.public_share_enabled = enabled
    await session.commit()
    await session.refresh(thread)

    if not enabled:
        return {
            "enabled": False,
            "public_url": None,
            "share_token": None,
        }

    # Tolerate a base URL that ends with "/" so the link never contains
    # "//public/...".
    normalized_base = base_url.rstrip("/")
    return {
        "enabled": True,
        "public_url": f"{normalized_base}/public/{thread.public_share_token}",
        "share_token": thread.public_share_token,
    }
async def get_public_chat(
    session: AsyncSession,
    share_token: str,
) -> dict:
    """
    Get a public chat by share token.

    Returns sanitized content suitable for public viewing.

    Args:
        session: Database session used for the thread and author lookups.
        share_token: Public share token of the thread.

    Returns:
        A dict with ``thread`` (title, created_at) and ``messages``: one
        entry per message, ordered by ``created_at``, each holding ``role``,
        sanitized ``content``, ``author`` display info and ``created_at``.

    Raises:
        HTTPException: 404 when no thread has this token with sharing enabled.
    """
    # Reuse the shared lookup so "token matches AND sharing enabled" is
    # defined in exactly one place.
    thread = await get_thread_by_share_token(session, share_token)
    if not thread:
        raise HTTPException(status_code=404, detail="Not found")

    # Each distinct author is resolved at most once per request.
    user_cache: dict[UUID, dict] = {}
    messages = []
    for msg in sorted(thread.messages, key=lambda m: m.created_at):
        author = await get_author_display(session, msg.author_id, user_cache)
        messages.append(
            {
                "role": msg.role,
                "content": sanitize_content_for_public(msg.content),
                "author": author,
                "created_at": msg.created_at,
            }
        )

    return {
        "thread": {
            "title": thread.title,
            "created_at": thread.created_at,
        },
        "messages": messages,
    }
async def get_thread_by_share_token(
    session: AsyncSession,
    share_token: str,
) -> NewChatThread | None:
    """
    Look up a publicly shared thread by its share token.

    The thread's messages are eagerly loaded. Returns None when no thread
    has this token or when sharing is not enabled on it.
    """
    token_matches = NewChatThread.public_share_token == share_token
    sharing_is_on = NewChatThread.public_share_enabled.is_(True)
    stmt = (
        select(NewChatThread)
        .options(selectinload(NewChatThread.messages))
        .filter(token_matches, sharing_is_on)
    )
    rows = await session.execute(stmt)
    return rows.scalars().first()
async def get_user_default_search_space(
    session: AsyncSession,
    user_id: UUID,
) -> int | None:
    """
    Pick the search space a cloned chat should go into for this user.

    Returns the ``search_space_id`` of one membership where the user is the
    owner, or None if the user owns no search space. The query has no
    ORDER BY, so with several owned spaces the choice is left to the database.
    """
    from app.db import SearchSpaceMembership

    stmt = (
        select(SearchSpaceMembership)
        .filter(
            SearchSpaceMembership.user_id == user_id,
            SearchSpaceMembership.is_owner.is_(True),
        )
        .limit(1)
    )
    owned = (await session.execute(stmt)).scalars().first()
    return owned.search_space_id if owned else None
async def complete_clone_content(
    session: AsyncSession,
    target_thread: NewChatThread,
    source_thread_id: int,
    target_search_space_id: int,
) -> int:
    """
    Copy messages and podcasts from source thread to target thread.

    Message content is sanitized for public view before copying. Every
    ``generate_podcast`` tool call whose result carries a ``podcast_id`` has
    that podcast cloned into the target search space and its id rewritten;
    if the podcast cannot be cloned the ``podcast_id`` key is removed from
    the copied message. Sets clone_pending=False and
    needs_history_bootstrap=True on the target thread and commits.

    Args:
        session: Database session; committed at the end.
        target_thread: Thread that receives the copied messages.
        source_thread_id: Id of the thread to copy from.
        target_search_space_id: Search space the cloned podcasts belong to.

    Returns:
        The number of messages copied.

    Raises:
        ValueError: If the source thread does not exist.
    """
    import copy

    from app.db import NewChatMessage

    result = await session.execute(
        select(NewChatThread)
        .options(selectinload(NewChatThread.messages))
        .filter(NewChatThread.id == source_thread_id)
    )
    source_thread = result.scalars().first()
    if not source_thread:
        raise ValueError("Source thread not found")

    # Old podcast id -> cloned podcast id, or None when cloning was refused.
    # Failures are remembered too, so a podcast referenced several times is
    # looked up only once and handled the same way at every occurrence.
    podcast_id_map: dict[int, int | None] = {}
    message_count = 0

    for msg in sorted(source_thread.messages, key=lambda m: m.created_at):
        # sanitize_content_for_public hands tool-call parts back by reference.
        # Deep-copy before rewriting podcast ids so the source message's
        # stored content is never modified by the clone.
        new_content = copy.deepcopy(sanitize_content_for_public(msg.content))

        for part in new_content:
            if not isinstance(part, dict):
                continue
            if (
                part.get("type") != "tool-call"
                or part.get("toolName") != "generate_podcast"
            ):
                continue

            # "result" may be absent or not a dict; nothing to remap then.
            result_data = part.get("result")
            if not isinstance(result_data, dict):
                continue

            old_podcast_id = result_data.get("podcast_id")
            if not old_podcast_id:
                continue

            if old_podcast_id not in podcast_id_map:
                podcast_id_map[old_podcast_id] = await _clone_podcast(
                    session,
                    old_podcast_id,
                    target_search_space_id,
                    target_thread.id,
                )

            new_podcast_id = podcast_id_map[old_podcast_id]
            if new_podcast_id:
                result_data["podcast_id"] = new_podcast_id
            else:
                # Podcast couldn't be cloned (not ready), remove reference
                result_data.pop("podcast_id", None)

        new_message = NewChatMessage(
            thread_id=target_thread.id,
            role=msg.role,
            content=new_content,
            author_id=msg.author_id,
            created_at=msg.created_at,
        )
        session.add(new_message)
        message_count += 1

    target_thread.clone_pending = False
    target_thread.needs_history_bootstrap = True
    await session.commit()
    return message_count
async def _clone_podcast(
    session: AsyncSession,
    podcast_id: int,
    target_search_space_id: int,
    target_thread_id: int,
) -> int | None:
    """
    Clone a podcast record and its audio file. Only clones ready podcasts.

    Args:
        session: Database session; the new row is added and flushed, not
            committed.
        podcast_id: Id of the podcast to clone.
        target_search_space_id: Search space of the new podcast.
        target_thread_id: Thread the new podcast is attached to.

    Returns:
        The id of the new podcast, or None when the original does not exist
        or its status is not READY. If the original has no audio file on
        disk, the clone is still created with ``file_location`` set to None.
    """
    import asyncio
    import shutil
    import uuid
    from pathlib import Path

    from app.db import Podcast, PodcastStatus

    result = await session.execute(select(Podcast).filter(Podcast.id == podcast_id))
    original = result.scalars().first()
    if not original or original.status != PodcastStatus.READY:
        return None

    new_file_path = None
    if original.file_location:
        original_path = Path(original.file_location)
        if original_path.exists():
            new_filename = f"{uuid.uuid4()}_podcast.mp3"
            # Relative path: resolved against the process working directory.
            new_dir = Path("podcasts")
            new_dir.mkdir(parents=True, exist_ok=True)
            new_file_path = str(new_dir / new_filename)
            # Copying an audio file is blocking disk I/O; run it in a worker
            # thread so the event loop keeps serving other requests.
            await asyncio.to_thread(
                shutil.copy2, original.file_location, new_file_path
            )

    new_podcast = Podcast(
        title=original.title,
        podcast_transcript=original.podcast_transcript,
        file_location=new_file_path,
        status=PodcastStatus.READY,
        search_space_id=target_search_space_id,
        thread_id=target_thread_id,
    )
    session.add(new_podcast)
    # Flush so the database assigns the primary key returned below.
    await session.flush()
    return new_podcast.id
async def is_podcast_publicly_accessible(
    session: AsyncSession,
    podcast_id: int,
) -> bool:
    """
    Tell whether a podcast is attached to a thread with public sharing on.

    The podcast is loaded together with its ``thread`` relationship. Returns
    False when the podcast does not exist or has no thread; otherwise returns
    the thread's ``public_share_enabled`` value.
    """
    from app.db import Podcast

    stmt = (
        select(Podcast)
        .options(selectinload(Podcast.thread))
        .filter(Podcast.id == podcast_id)
    )
    podcast = (await session.execute(stmt)).scalars().first()
    if podcast and podcast.thread:
        return podcast.thread.public_share_enabled
    return False