mirror of https://github.com/MODSetter/SurfSense.git
synced 2026-04-25 16:56:22 +02:00
825 lines · 26 KiB · Python
"""
|
|
Service layer for public chat sharing via immutable snapshots.
|
|
|
|
Key concepts:
|
|
- Snapshots are frozen copies of a chat at a specific point in time
|
|
- Content hash enables deduplication (same content = same URL)
|
|
- Podcasts are embedded in snapshot_data for self-contained public views
|
|
- Single-phase clone reads directly from snapshot_data
|
|
"""
|
|
|
|
import contextlib
|
|
import hashlib
|
|
import json
|
|
import re
|
|
import secrets
|
|
from datetime import UTC, datetime
|
|
from uuid import UUID
|
|
|
|
from fastapi import HTTPException
|
|
from sqlalchemy import delete, select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy.orm import selectinload
|
|
|
|
from app.db import (
|
|
ChatVisibility,
|
|
NewChatMessage,
|
|
NewChatThread,
|
|
Permission,
|
|
Podcast,
|
|
PodcastStatus,
|
|
PublicChatSnapshot,
|
|
Report,
|
|
SearchSpaceMembership,
|
|
User,
|
|
)
|
|
from app.utils.rbac import check_permission
|
|
|
|
# Whitelist of tool names whose tool-call parts may appear in public
# snapshots; sanitize_content_for_public() drops every other tool-call.
UI_TOOLS = {
    "display_image",
    "link_preview",
    "generate_podcast",
    "generate_report",
    "scrape_webpage",
    "multi_link_preview",
}
|
|
|
|
|
|
def strip_citations(text: str) -> str:
    """
    Strip [citation:X] / [citation:doc-X] markers out of *text*.

    Handles both ASCII and full-width CJK brackets, plus optional
    zero-width spaces around the marker. Newlines are preserved so the
    markdown layout survives.
    """
    # Drop each citation marker (either bracket style, optional ZWSPs).
    cleaned = re.sub(r"[\[【]\u200B?citation:(doc-)?\d+\u200B?[\]】]", "", text)
    # Squash runs of horizontal whitespace to a single space; the
    # negated class deliberately excludes newlines.
    cleaned = re.sub(r"[^\S\n]+", " ", cleaned)
    # Cap consecutive blank lines: three-plus newlines collapse to two.
    cleaned = re.sub(r"\n{3,}", "\n\n", cleaned)
    # Remove stray spaces hugging either side of a newline.
    cleaned = re.sub(r" *\n *", "\n", cleaned)
    return cleaned.strip()
|
|
|
|
|
|
def sanitize_content_for_public(content: list | str | None) -> list:
    """
    Prepare message content for anonymous public rendering.

    Text parts have their citation markers stripped; tool-call parts are
    kept only when the tool is in the UI_TOOLS whitelist. Anything of an
    unrecognized shape is silently dropped.
    """
    if content is None:
        return []

    if isinstance(content, str):
        # Plain-string content becomes a single text part, or nothing.
        stripped = strip_citations(content)
        return [{"type": "text", "text": stripped}] if stripped else []

    if not isinstance(content, list):
        # Unknown payload shape: expose nothing publicly.
        return []

    public_parts = []
    for item in content:
        if not isinstance(item, dict):
            continue

        kind = item.get("type")
        if kind == "text":
            stripped = strip_citations(item.get("text", ""))
            if stripped:
                public_parts.append({"type": "text", "text": stripped})
        elif kind == "tool-call" and item.get("toolName") in UI_TOOLS:
            public_parts.append(item)

    return public_parts
|
|
|
|
|
|
async def get_author_display(
    session: AsyncSession,
    author_id: UUID | None,
    user_cache: dict[UUID, dict],
) -> dict | None:
    """
    Resolve an author UUID to public display info, memoized in *user_cache*.

    Returns None for anonymous messages (author_id is None); otherwise a
    dict with "display_name" and "avatar_url". Missing users get a
    placeholder entry so repeated lookups stay cheap.
    """
    if author_id is None:
        return None

    cached = user_cache.get(author_id)
    if cached is not None:
        return cached

    lookup = await session.execute(select(User).filter(User.id == author_id))
    user = lookup.scalars().first()

    if user:
        entry = {
            "display_name": user.display_name or "User",
            "avatar_url": user.avatar_url,
        }
    else:
        # Author row no longer exists (e.g. deleted account).
        entry = {
            "display_name": "Unknown User",
            "avatar_url": None,
        }

    user_cache[author_id] = entry
    return entry
|
|
|
|
|
|
# =============================================================================
|
|
# Content Hashing
|
|
# =============================================================================
|
|
|
|
|
|
def compute_content_hash(messages: list[dict]) -> str:
    """
    Derive a deterministic SHA-256 hex digest for a set of messages.

    Only id, role, and content feed the hash, and messages are
    canonically ordered by id first, so:
    - identical content always maps to the same hash (and public URL),
    - any edit yields a new hash and therefore a new URL.
    """
    # Canonical ordering: sort by message id (missing ids sort as 0).
    ordered = sorted(messages, key=lambda m: m.get("id", 0))

    # Reduce each message to the fields that define its public identity.
    normalized = [
        {
            "id": m.get("id"),
            "role": m.get("role"),
            "content": m.get("content"),
        }
        for m in ordered
    ]

    # Compact, key-sorted JSON gives a stable byte representation.
    payload = json.dumps(normalized, sort_keys=True, separators=(",", ":"))
    return hashlib.sha256(payload.encode()).hexdigest()
|
|
|
|
|
|
# =============================================================================
|
|
# Snapshot Creation
|
|
# =============================================================================
|
|
|
|
|
|
async def create_snapshot(
    session: AsyncSession,
    thread_id: int,
    user: User,
) -> dict:
    """
    Create a public snapshot of a chat thread.

    Returns existing snapshot if content unchanged (same hash).
    Returns new snapshot with unique URL if content changed.

    Args:
        session: Active async DB session.
        thread_id: Thread to snapshot.
        user: Caller; must hold PUBLIC_SHARING_CREATE in the thread's
            search space (check_permission raises otherwise).

    Returns:
        Dict with snapshot_id, share_token, public_url, and is_new
        (False when an identical snapshot was reused).

    Raises:
        HTTPException: 404 if thread missing, 400 if the thread has no
        shareable messages.
    """
    # Imported lazily to avoid import cycles at module load time.
    from app.config import config

    frontend_url = (config.NEXT_FRONTEND_URL or "").rstrip("/")
    result = await session.execute(
        select(NewChatThread)
        .options(selectinload(NewChatThread.messages))
        .filter(NewChatThread.id == thread_id)
    )
    thread = result.scalars().first()

    if not thread:
        raise HTTPException(status_code=404, detail="Thread not found")

    await check_permission(
        session,
        user,
        thread.search_space_id,
        Permission.PUBLIC_SHARING_CREATE.value,
        "You don't have permission to create public share links",
    )

    # Build snapshot data
    user_cache: dict[UUID, dict] = {}
    messages_data = []
    message_ids = []
    podcasts_data = []
    podcast_ids_seen: set[int] = set()
    reports_data = []
    report_ids_seen: set[int] = set()

    # Walk messages in chronological order so the public transcript
    # reads the same as the original chat.
    for msg in sorted(thread.messages, key=lambda m: m.created_at):
        author = await get_author_display(session, msg.author_id, user_cache)
        sanitized_content = sanitize_content_for_public(msg.content)

        # Extract podcast/report references and update status to "ready" for completed ones
        if isinstance(sanitized_content, list):
            for part in sanitized_content:
                if not isinstance(part, dict) or part.get("type") != "tool-call":
                    continue

                tool_name = part.get("toolName")

                if tool_name == "generate_podcast":
                    result_data = part.get("result", {})
                    podcast_id = result_data.get("podcast_id")
                    # Each podcast is embedded once, keyed by its original id.
                    if podcast_id and podcast_id not in podcast_ids_seen:
                        podcast_info = await _get_podcast_for_snapshot(
                            session, podcast_id
                        )
                        if podcast_info:
                            podcasts_data.append(podcast_info)
                            podcast_ids_seen.add(podcast_id)
                            # Update status to "ready" so frontend renders PodcastPlayer
                            part["result"] = {**result_data, "status": "ready"}

                elif tool_name == "generate_report":
                    result_data = part.get("result", {})
                    report_id = result_data.get("report_id")
                    # Same dedup scheme as podcasts.
                    if report_id and report_id not in report_ids_seen:
                        report_info = await _get_report_for_snapshot(session, report_id)
                        if report_info:
                            reports_data.append(report_info)
                            report_ids_seen.add(report_id)
                            # Update status to "ready" so frontend renders ReportCard
                            part["result"] = {**result_data, "status": "ready"}

        messages_data.append(
            {
                "id": msg.id,
                # role may be an Enum or a plain string depending on the row.
                "role": msg.role.value if hasattr(msg.role, "value") else str(msg.role),
                "content": sanitized_content,
                "author": author,
                "author_id": str(msg.author_id) if msg.author_id else None,
                "created_at": msg.created_at.isoformat() if msg.created_at else None,
            }
        )
        message_ids.append(msg.id)

    if not messages_data:
        raise HTTPException(status_code=400, detail="Cannot share an empty chat")

    # Compute content hash for deduplication
    content_hash = compute_content_hash(messages_data)

    # Check if identical snapshot already exists
    existing_result = await session.execute(
        select(PublicChatSnapshot).filter(
            PublicChatSnapshot.thread_id == thread_id,
            PublicChatSnapshot.content_hash == content_hash,
        )
    )
    existing = existing_result.scalars().first()

    if existing:
        # Return existing snapshot URL
        return {
            "snapshot_id": existing.id,
            "share_token": existing.share_token,
            "public_url": f"{frontend_url}/public/{existing.share_token}",
            "is_new": False,
        }

    # Get thread author info
    thread_author = await get_author_display(session, thread.created_by_id, user_cache)

    # Create snapshot data (self-contained: podcasts/reports are embedded
    # so the public page needs no joins against live tables).
    snapshot_data = {
        "title": thread.title,
        "snapshot_at": datetime.now(UTC).isoformat(),
        "author": thread_author,
        "messages": messages_data,
        "podcasts": podcasts_data,
        "reports": reports_data,
    }

    # Create new snapshot; token_urlsafe(48) gives an unguessable URL.
    share_token = secrets.token_urlsafe(48)
    snapshot = PublicChatSnapshot(
        thread_id=thread_id,
        share_token=share_token,
        content_hash=content_hash,
        snapshot_data=snapshot_data,
        message_ids=message_ids,
        created_by_user_id=user.id,
    )
    session.add(snapshot)
    await session.commit()
    await session.refresh(snapshot)

    return {
        "snapshot_id": snapshot.id,
        "share_token": snapshot.share_token,
        "public_url": f"{frontend_url}/public/{snapshot.share_token}",
        "is_new": True,
    }
|
|
|
|
|
|
async def _get_podcast_for_snapshot(
    session: AsyncSession,
    podcast_id: int,
) -> dict | None:
    """
    Load a podcast row and shape it for embedding in snapshot_data.

    Returns None when the podcast is missing or its status is not READY,
    so incomplete podcasts never leak into public snapshots.
    """
    lookup = await session.execute(select(Podcast).filter(Podcast.id == podcast_id))
    podcast = lookup.scalars().first()

    if podcast is None or podcast.status != PodcastStatus.READY:
        return None

    return {
        "original_id": podcast.id,
        "title": podcast.title,
        "transcript": podcast.podcast_transcript,
        "file_path": podcast.file_location,
    }
|
|
|
|
|
|
async def _get_report_for_snapshot(
    session: AsyncSession,
    report_id: int,
) -> dict | None:
    """
    Load a report row and shape it for embedding in snapshot_data.

    Returns None when no such report exists.
    """
    lookup = await session.execute(select(Report).filter(Report.id == report_id))
    report = lookup.scalars().first()

    if report is None:
        return None

    return {
        "original_id": report.id,
        "title": report.title,
        "content": report.content,
        "report_metadata": report.report_metadata,
        "report_group_id": report.report_group_id,
        "created_at": report.created_at.isoformat() if report.created_at else None,
    }
|
|
|
|
|
|
# =============================================================================
|
|
# Snapshot Retrieval
|
|
# =============================================================================
|
|
|
|
|
|
async def get_snapshot_by_token(
    session: AsyncSession,
    share_token: str,
) -> PublicChatSnapshot | None:
    """Look up a public snapshot by its opaque share token, or None."""
    query = select(PublicChatSnapshot).filter(
        PublicChatSnapshot.share_token == share_token
    )
    result = await session.execute(query)
    return result.scalars().first()
|
|
|
|
|
|
async def get_public_chat(
    session: AsyncSession,
    share_token: str,
) -> dict:
    """
    Resolve a share token to publicly viewable chat data.

    The payload comes straight from the frozen snapshot_data (already
    sanitized at snapshot time).

    Raises:
        HTTPException: 404 when the token matches no snapshot.
    """
    snapshot = await get_snapshot_by_token(session, share_token)
    if snapshot is None:
        raise HTTPException(status_code=404, detail="Not found")

    data = snapshot.snapshot_data
    thread_info = {
        "title": data.get("title", "Untitled"),
        "created_at": data.get("snapshot_at"),
    }
    return {"thread": thread_info, "messages": data.get("messages", [])}
|
|
|
|
|
|
async def list_snapshots_for_thread(
    session: AsyncSession,
    thread_id: int,
    user: User,
) -> list[dict]:
    """
    Return summaries of every public snapshot for a thread, newest first.

    Raises:
        HTTPException: 404 when the thread is missing; permission errors
        are raised by check_permission.
    """
    # Imported lazily to avoid import cycles at module load time.
    from app.config import config

    thread_lookup = await session.execute(
        select(NewChatThread).filter(NewChatThread.id == thread_id)
    )
    thread = thread_lookup.scalars().first()

    if thread is None:
        raise HTTPException(status_code=404, detail="Thread not found")

    # Check permission to view public share links
    await check_permission(
        session,
        user,
        thread.search_space_id,
        Permission.PUBLIC_SHARING_VIEW.value,
        "You don't have permission to view public share links",
    )

    snapshot_lookup = await session.execute(
        select(PublicChatSnapshot)
        .filter(PublicChatSnapshot.thread_id == thread_id)
        .order_by(PublicChatSnapshot.created_at.desc())
    )

    base_url = (config.NEXT_FRONTEND_URL or "").rstrip("/")

    summaries = []
    for snap in snapshot_lookup.scalars().all():
        summaries.append(
            {
                "id": snap.id,
                "share_token": snap.share_token,
                "public_url": f"{base_url}/public/{snap.share_token}",
                "created_at": snap.created_at.isoformat() if snap.created_at else None,
                "message_count": len(snap.message_ids) if snap.message_ids else 0,
            }
        )
    return summaries
|
|
|
|
|
|
async def list_snapshots_for_search_space(
    session: AsyncSession,
    search_space_id: int,
    user: User,
) -> list[dict]:
    """
    List all public snapshots for a search space, newest first.

    Each entry carries the owning thread's id and title plus the creator's
    user id, so an admin overview can render without extra lookups.

    Raises:
        HTTPException: raised by check_permission when the caller lacks
        PUBLIC_SHARING_VIEW in the search space.
    """
    # Imported lazily to avoid import cycles at module load time.
    from app.config import config

    await check_permission(
        session,
        user,
        search_space_id,
        Permission.PUBLIC_SHARING_VIEW.value,
        "You don't have permission to view public share links",
    )

    result = await session.execute(
        select(PublicChatSnapshot)
        .join(NewChatThread, PublicChatSnapshot.thread_id == NewChatThread.id)
        .filter(NewChatThread.search_space_id == search_space_id)
        .order_by(PublicChatSnapshot.created_at.desc())
    )
    snapshots = result.scalars().all()

    # Fix: only query titles when there are snapshots — previously an
    # empty result set still issued a degenerate `IN ()` query.
    thread_titles: dict[int, str] = {}
    if snapshots:
        # Deduplicate thread ids (several snapshots can share a thread).
        snapshot_thread_ids = {s.thread_id for s in snapshots}
        thread_result = await session.execute(
            select(NewChatThread.id, NewChatThread.title).filter(
                NewChatThread.id.in_(snapshot_thread_ids)
            )
        )
        thread_titles = {row[0]: row[1] for row in thread_result.fetchall()}

    frontend_url = (config.NEXT_FRONTEND_URL or "").rstrip("/")

    return [
        {
            "id": s.id,
            "share_token": s.share_token,
            "public_url": f"{frontend_url}/public/{s.share_token}",
            "created_at": s.created_at.isoformat() if s.created_at else None,
            "message_count": len(s.message_ids) if s.message_ids else 0,
            "thread_id": s.thread_id,
            "thread_title": thread_titles.get(s.thread_id, "Untitled"),
            "created_by_user_id": str(s.created_by_user_id)
            if s.created_by_user_id
            else None,
        }
        for s in snapshots
    ]
|
|
|
|
|
|
# =============================================================================
|
|
# Snapshot Deletion
|
|
# =============================================================================
|
|
|
|
|
|
async def delete_snapshot(
    session: AsyncSession,
    thread_id: int,
    snapshot_id: int,
    user: User,
) -> bool:
    """
    Delete a specific snapshot after verifying the caller's permission.

    Raises:
        HTTPException: 404 when the snapshot/thread pair doesn't exist;
        permission errors are raised by check_permission.
    """
    # Eager-load the thread: its search_space_id drives the RBAC check.
    query = (
        select(PublicChatSnapshot)
        .options(selectinload(PublicChatSnapshot.thread))
        .filter(
            PublicChatSnapshot.id == snapshot_id,
            PublicChatSnapshot.thread_id == thread_id,
        )
    )
    result = await session.execute(query)
    snapshot = result.scalars().first()

    if snapshot is None:
        raise HTTPException(status_code=404, detail="Snapshot not found")

    await check_permission(
        session,
        user,
        snapshot.thread.search_space_id,
        Permission.PUBLIC_SHARING_DELETE.value,
        "You don't have permission to delete public share links",
    )

    await session.delete(snapshot)
    await session.commit()
    return True
|
|
|
|
|
|
async def delete_affected_snapshots(
    session: AsyncSession,
    thread_id: int,
    message_ids: list[int],
) -> int:
    """
    Delete every snapshot of *thread_id* that references any given message id.

    Called when messages are edited/deleted/regenerated so stale public
    views disappear. Runs in its own session so it works reliably even
    during streaming-response cleanup. Returns how many were deleted.
    """
    if not message_ids:
        return 0

    from sqlalchemy.dialects.postgresql import array

    from app.db import async_session_maker

    # Postgres `&&` (array overlap): true when the snapshot's message_ids
    # shares at least one element with the ids passed in.
    stmt = (
        delete(PublicChatSnapshot)
        .where(PublicChatSnapshot.thread_id == thread_id)
        .where(PublicChatSnapshot.message_ids.op("&&")(array(message_ids)))
        .returning(PublicChatSnapshot.id)
    )

    async with async_session_maker() as own_session:
        result = await own_session.execute(stmt)
        removed_ids = result.scalars().all()
        await own_session.commit()

    return len(removed_ids)
|
|
|
|
|
|
# =============================================================================
|
|
# Cloning from Snapshot
|
|
# =============================================================================
|
|
|
|
|
|
async def get_user_default_search_space(
    session: AsyncSession,
    user_id: UUID,
) -> int | None:
    """
    Pick the user's default search space for cloning.

    Returns the id of the first search space the user owns, or None when
    they own none.
    """
    query = (
        select(SearchSpaceMembership)
        .filter(
            SearchSpaceMembership.user_id == user_id,
            SearchSpaceMembership.is_owner.is_(True),
        )
        .limit(1)
    )
    result = await session.execute(query)
    membership = result.scalars().first()
    return membership.search_space_id if membership else None
|
|
|
|
|
|
async def clone_from_snapshot(
    session: AsyncSession,
    share_token: str,
    user: User,
) -> dict:
    """
    Copy messages and podcasts from source thread to target thread.

    Creates thread and copies messages from snapshot_data.
    When encountering generate_podcast tool-calls, creates cloned podcast records
    and updates the podcast_id references inline.
    Returns the new thread info.

    Args:
        session: Active async DB session.
        share_token: Token identifying the public snapshot to clone.
        user: The user receiving the clone; the thread lands in their
            default (owned) search space.

    Returns:
        Dict with the new thread_id and target search_space_id.

    Raises:
        HTTPException: 404 when the token matches no snapshot; 400 when
        the user owns no search space to clone into.
    """
    import copy

    snapshot = await get_snapshot_by_token(session, share_token)

    if not snapshot:
        raise HTTPException(
            status_code=404, detail="Chat not found or no longer public"
        )

    target_search_space_id = await get_user_default_search_space(session, user.id)

    if target_search_space_id is None:
        raise HTTPException(status_code=400, detail="No search space found for user")

    data = snapshot.snapshot_data
    messages_data = data.get("messages", [])
    # Embedded media indexed by original id for O(1) lookup during cloning.
    podcasts_lookup = {p.get("original_id"): p for p in data.get("podcasts", [])}
    reports_lookup = {r.get("original_id"): r for r in data.get("reports", [])}

    new_thread = NewChatThread(
        title=data.get("title", "Cloned Chat"),
        archived=False,
        visibility=ChatVisibility.PRIVATE,
        search_space_id=target_search_space_id,
        created_by_id=user.id,
        cloned_from_thread_id=snapshot.thread_id,
        cloned_from_snapshot_id=snapshot.id,
        cloned_at=datetime.now(UTC),
        needs_history_bootstrap=True,
    )
    session.add(new_thread)
    # Flush (not commit) so new_thread.id exists for the FK references below.
    await session.flush()

    # old id -> newly created id, so repeated references map consistently.
    podcast_id_mapping: dict[int, int] = {}
    report_id_mapping: dict[int, int] = {}

    # Check which authors from snapshot still exist in DB
    author_ids_from_snapshot: set[UUID] = set()
    for msg_data in messages_data:
        if author_str := msg_data.get("author_id"):
            # Malformed UUID strings in old snapshots are simply skipped.
            with contextlib.suppress(ValueError, TypeError):
                author_ids_from_snapshot.add(UUID(author_str))

    existing_authors: set[UUID] = set()
    if author_ids_from_snapshot:
        result = await session.execute(
            select(User.id).where(User.id.in_(author_ids_from_snapshot))
        )
        existing_authors = {row[0] for row in result.fetchall()}

    for msg_data in messages_data:
        role = msg_data.get("role", "user")

        # Use original author if exists, otherwise None
        author_id = None
        if author_str := msg_data.get("author_id"):
            try:
                parsed_id = UUID(author_str)
                if parsed_id in existing_authors:
                    author_id = parsed_id
            except (ValueError, TypeError):
                pass

        # Deep copy: we mutate tool-call results below and must not touch
        # the snapshot's stored JSON.
        content = copy.deepcopy(msg_data.get("content", []))

        if isinstance(content, list):
            for part in content:
                if (
                    isinstance(part, dict)
                    and part.get("type") == "tool-call"
                    and part.get("toolName") == "generate_podcast"
                ):
                    result = part.get("result", {})
                    old_podcast_id = result.get("podcast_id")

                    # First sighting of this podcast: clone the row.
                    if old_podcast_id and old_podcast_id not in podcast_id_mapping:
                        podcast_info = podcasts_lookup.get(old_podcast_id)
                        if podcast_info:
                            new_podcast = Podcast(
                                title=podcast_info.get("title", "Cloned Podcast"),
                                podcast_transcript=podcast_info.get("transcript"),
                                file_location=podcast_info.get("file_path"),
                                status=PodcastStatus.READY,
                                search_space_id=target_search_space_id,
                                thread_id=new_thread.id,
                            )
                            session.add(new_podcast)
                            # Flush to obtain the new podcast's id.
                            await session.flush()
                            podcast_id_mapping[old_podcast_id] = new_podcast.id

                    # Rewrite the reference whenever a mapping exists.
                    if old_podcast_id and old_podcast_id in podcast_id_mapping:
                        part["result"] = {
                            **result,
                            "podcast_id": podcast_id_mapping[old_podcast_id],
                        }

                if (
                    isinstance(part, dict)
                    and part.get("type") == "tool-call"
                    and part.get("toolName") == "generate_report"
                ):
                    result = part.get("result", {})
                    old_report_id = result.get("report_id")

                    if old_report_id and old_report_id not in report_id_mapping:
                        report_info = reports_lookup.get(old_report_id)
                        if report_info:
                            new_report = Report(
                                title=report_info.get("title", "Cloned Report"),
                                content=report_info.get("content"),
                                report_metadata=report_info.get("report_metadata"),
                                search_space_id=target_search_space_id,
                                thread_id=new_thread.id,
                            )
                            session.add(new_report)
                            await session.flush()
                            # For cloned reports, set report_group_id = own id
                            # (each cloned report starts as its own v1)
                            new_report.report_group_id = new_report.id
                            report_id_mapping[old_report_id] = new_report.id

                    if old_report_id and old_report_id in report_id_mapping:
                        part["result"] = {
                            **result,
                            "report_id": report_id_mapping[old_report_id],
                        }

        new_message = NewChatMessage(
            thread_id=new_thread.id,
            role=role,
            content=content,
            author_id=author_id,
        )
        session.add(new_message)

    # Single commit covers thread, messages, and all cloned media.
    await session.commit()
    await session.refresh(new_thread)

    return {
        "thread_id": new_thread.id,
        "search_space_id": target_search_space_id,
    }
|
|
|
|
|
|
async def get_snapshot_podcast(
    session: AsyncSession,
    share_token: str,
    podcast_id: int,
) -> dict | None:
    """
    Fetch the embedded podcast entry whose original_id equals *podcast_id*.

    Backs audio streaming from the public view. Returns None when the
    token is invalid or no matching podcast is embedded in the snapshot.
    """
    snapshot = await get_snapshot_by_token(session, share_token)
    if snapshot is None:
        return None

    return next(
        (
            entry
            for entry in snapshot.snapshot_data.get("podcasts", [])
            if entry.get("original_id") == podcast_id
        ),
        None,
    )
|
|
|
|
|
|
async def get_snapshot_report(
    session: AsyncSession,
    share_token: str,
    report_id: int,
) -> dict | None:
    """
    Fetch the embedded report entry whose original_id equals *report_id*.

    Used when displaying report content on the public view. Returns None
    when the token is invalid or no matching report is embedded.
    """
    snapshot = await get_snapshot_by_token(session, share_token)
    if snapshot is None:
        return None

    return next(
        (
            entry
            for entry in snapshot.snapshot_data.get("reports", [])
            if entry.get("original_id") == report_id
        ),
        None,
    )
|
|
|
|
|
|
async def get_snapshot_report_versions(
    session: AsyncSession,
    share_token: str,
    report_group_id: int | None,
) -> list[dict]:
    """
    List sibling report versions for the version-switcher UI.

    Returns lightweight entries ({"id", "created_at"}) for every embedded
    report sharing *report_group_id*, sorted by original_id ascending
    (which approximates creation order). Empty list for a falsy group id
    or an unknown token.
    """
    if not report_group_id:
        return []

    snapshot = await get_snapshot_by_token(session, share_token)
    if snapshot is None:
        return []

    group_members = sorted(
        (
            entry
            for entry in snapshot.snapshot_data.get("reports", [])
            if entry.get("report_group_id") == report_group_id
        ),
        key=lambda entry: entry.get("original_id", 0),
    )

    return [
        {"id": entry.get("original_id"), "created_at": entry.get("created_at")}
        for entry in group_members
    ]
|