This commit is contained in:
DESKTOP-RTLN3BA\$punk 2026-02-01 18:02:27 -08:00
commit 8301e0169c
71 changed files with 2889 additions and 732 deletions

View file

@ -32,7 +32,7 @@ PROVIDER_MAP = {
"GROQ": "groq",
"COHERE": "cohere",
"GOOGLE": "gemini",
"OLLAMA": "ollama",
"OLLAMA": "ollama_chat",
"MISTRAL": "mistral",
"AZURE_OPENAI": "azure",
"OPENROUTER": "openrouter",

View file

@ -94,7 +94,7 @@ async def validate_llm_config(
"GROQ": "groq",
"COHERE": "cohere",
"GOOGLE": "gemini",
"OLLAMA": "ollama",
"OLLAMA": "ollama_chat",
"MISTRAL": "mistral",
"AZURE_OPENAI": "azure",
"OPENROUTER": "openrouter",
@ -241,7 +241,7 @@ async def get_search_space_llm_instance(
"GROQ": "groq",
"COHERE": "cohere",
"GOOGLE": "gemini",
"OLLAMA": "ollama",
"OLLAMA": "ollama_chat",
"MISTRAL": "mistral",
"AZURE_OPENAI": "azure",
"OPENROUTER": "openrouter",
@ -311,7 +311,7 @@ async def get_search_space_llm_instance(
"GROQ": "groq",
"COHERE": "cohere",
"GOOGLE": "gemini",
"OLLAMA": "ollama",
"OLLAMA": "ollama_chat",
"MISTRAL": "mistral",
"AZURE_OPENAI": "azure",
"OPENROUTER": "openrouter",

View file

@ -1,17 +1,35 @@
"""
Service layer for public chat sharing and cloning.
Service layer for public chat sharing via immutable snapshots.
Key concepts:
- Snapshots are frozen copies of a chat at a specific point in time
- Content hash enables deduplication (same content = same URL)
- Podcasts are embedded in snapshot_data for self-contained public views
- Single-phase clone reads directly from snapshot_data
"""
import hashlib
import json
import re
import secrets
from datetime import UTC, datetime
from uuid import UUID
from fastapi import HTTPException
from sqlalchemy import select
from sqlalchemy import delete, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from app.db import NewChatThread, User
from app.db import (
ChatVisibility,
NewChatMessage,
NewChatThread,
Podcast,
PodcastStatus,
PublicChatSnapshot,
SearchSpaceMembership,
User,
)
UI_TOOLS = {
"display_image",
@ -100,20 +118,242 @@ async def get_author_display(
return user_cache[author_id]
async def toggle_public_share(
# =============================================================================
# Content Hashing
# =============================================================================
def compute_content_hash(messages: list[dict]) -> str:
"""
Compute SHA-256 hash of message content for deduplication.
The hash is based on message IDs and content, ensuring that:
- Same messages = same hash = same URL (deduplication)
- Any change = different hash = new URL
"""
# Sort by message ID to ensure consistent ordering
sorted_messages = sorted(messages, key=lambda m: m.get("id", 0))
# Create normalized representation
normalized = []
for msg in sorted_messages:
normalized.append(
{
"id": msg.get("id"),
"role": msg.get("role"),
"content": msg.get("content"),
}
)
content_str = json.dumps(normalized, sort_keys=True, separators=(",", ":"))
return hashlib.sha256(content_str.encode()).hexdigest()
# =============================================================================
# Snapshot Creation
# =============================================================================
async def create_snapshot(
session: AsyncSession,
thread_id: int,
enabled: bool,
user: User,
base_url: str,
) -> dict:
"""
Enable or disable public sharing for a thread.
Create a public snapshot of a chat thread.
Only the thread owner can toggle public sharing.
When enabling, generates a new token if one doesn't exist.
When disabling, keeps the token for potential re-enable.
Returns existing snapshot if content unchanged (same hash).
Returns new snapshot with unique URL if content changed.
"""
result = await session.execute(
select(NewChatThread)
.options(selectinload(NewChatThread.messages))
.filter(NewChatThread.id == thread_id)
)
thread = result.scalars().first()
if not thread:
raise HTTPException(status_code=404, detail="Thread not found")
if thread.created_by_id != user.id:
raise HTTPException(
status_code=403,
detail="Only the creator of this chat can create public snapshots",
)
# Build snapshot data
user_cache: dict[UUID, dict] = {}
messages_data = []
message_ids = []
podcasts_data = []
podcast_ids_seen: set[int] = set()
for msg in sorted(thread.messages, key=lambda m: m.created_at):
author = await get_author_display(session, msg.author_id, user_cache)
sanitized_content = sanitize_content_for_public(msg.content)
# Extract podcast references and update status to "ready" for completed podcasts
if isinstance(sanitized_content, list):
for part in sanitized_content:
if (
isinstance(part, dict)
and part.get("type") == "tool-call"
and part.get("toolName") == "generate_podcast"
):
result_data = part.get("result", {})
podcast_id = result_data.get("podcast_id")
if podcast_id and podcast_id not in podcast_ids_seen:
podcast_info = await _get_podcast_for_snapshot(
session, podcast_id
)
if podcast_info:
podcasts_data.append(podcast_info)
podcast_ids_seen.add(podcast_id)
# Update status to "ready" so frontend renders PodcastPlayer
part["result"] = {**result_data, "status": "ready"}
messages_data.append(
{
"id": msg.id,
"role": msg.role.value if hasattr(msg.role, "value") else str(msg.role),
"content": sanitized_content,
"author": author,
"author_id": str(msg.author_id) if msg.author_id else None,
"created_at": msg.created_at.isoformat() if msg.created_at else None,
}
)
message_ids.append(msg.id)
if not messages_data:
raise HTTPException(status_code=400, detail="Cannot share an empty chat")
# Compute content hash for deduplication
content_hash = compute_content_hash(messages_data)
# Check if identical snapshot already exists
existing_result = await session.execute(
select(PublicChatSnapshot).filter(
PublicChatSnapshot.thread_id == thread_id,
PublicChatSnapshot.content_hash == content_hash,
)
)
existing = existing_result.scalars().first()
if existing:
# Return existing snapshot URL
return {
"snapshot_id": existing.id,
"share_token": existing.share_token,
"public_url": f"{base_url}/public/{existing.share_token}",
"is_new": False,
}
# Get thread author info
thread_author = await get_author_display(session, thread.created_by_id, user_cache)
# Create snapshot data
snapshot_data = {
"title": thread.title,
"snapshot_at": datetime.now(UTC).isoformat(),
"author": thread_author,
"messages": messages_data,
"podcasts": podcasts_data,
}
# Create new snapshot
share_token = secrets.token_urlsafe(48)
snapshot = PublicChatSnapshot(
thread_id=thread_id,
share_token=share_token,
content_hash=content_hash,
snapshot_data=snapshot_data,
message_ids=message_ids,
created_by_user_id=user.id,
)
session.add(snapshot)
await session.commit()
await session.refresh(snapshot)
return {
"snapshot_id": snapshot.id,
"share_token": snapshot.share_token,
"public_url": f"{base_url}/public/{snapshot.share_token}",
"is_new": True,
}
async def _get_podcast_for_snapshot(
session: AsyncSession,
podcast_id: int,
) -> dict | None:
"""Get podcast info for embedding in snapshot_data."""
result = await session.execute(select(Podcast).filter(Podcast.id == podcast_id))
podcast = result.scalars().first()
if not podcast or podcast.status != PodcastStatus.READY:
return None
return {
"original_id": podcast.id,
"title": podcast.title,
"transcript": podcast.podcast_transcript,
"file_path": podcast.file_location,
}
# =============================================================================
# Snapshot Retrieval
# =============================================================================
async def get_snapshot_by_token(
session: AsyncSession,
share_token: str,
) -> PublicChatSnapshot | None:
"""Get a snapshot by its share token."""
result = await session.execute(
select(PublicChatSnapshot).filter(
PublicChatSnapshot.share_token == share_token
)
)
return result.scalars().first()
async def get_public_chat(
session: AsyncSession,
share_token: str,
) -> dict:
"""
Get public chat data from a snapshot.
Returns sanitized content suitable for public viewing.
"""
snapshot = await get_snapshot_by_token(session, share_token)
if not snapshot:
raise HTTPException(status_code=404, detail="Not found")
data = snapshot.snapshot_data
return {
"thread": {
"title": data.get("title", "Untitled"),
"created_at": data.get("snapshot_at"),
},
"messages": data.get("messages", []),
}
async def list_snapshots_for_thread(
session: AsyncSession,
thread_id: int,
user: User,
base_url: str,
) -> list[dict]:
"""List all public snapshots for a thread."""
# Verify ownership
result = await session.execute(
select(NewChatThread).filter(NewChatThread.id == thread_id)
)
@ -125,92 +365,101 @@ async def toggle_public_share(
if thread.created_by_id != user.id:
raise HTTPException(
status_code=403,
detail="Only the creator of this chat can manage public sharing",
detail="Only the creator can view snapshots",
)
if enabled and not thread.public_share_token:
thread.public_share_token = secrets.token_urlsafe(48)
# Get snapshots
result = await session.execute(
select(PublicChatSnapshot)
.filter(PublicChatSnapshot.thread_id == thread_id)
.order_by(PublicChatSnapshot.created_at.desc())
)
snapshots = result.scalars().all()
thread.public_share_enabled = enabled
await session.commit()
await session.refresh(thread)
if enabled:
return {
"enabled": True,
"public_url": f"{base_url}/public/{thread.public_share_token}",
"share_token": thread.public_share_token,
return [
{
"id": s.id,
"share_token": s.share_token,
"public_url": f"{base_url}/public/{s.share_token}",
"created_at": s.created_at.isoformat() if s.created_at else None,
"message_count": len(s.message_ids) if s.message_ids else 0,
}
return {
"enabled": False,
"public_url": None,
"share_token": None,
}
for s in snapshots
]
async def get_public_chat(
# =============================================================================
# Snapshot Deletion
# =============================================================================
async def delete_snapshot(
session: AsyncSession,
share_token: str,
) -> dict:
"""
Get a public chat by share token.
Returns sanitized content suitable for public viewing.
"""
thread_id: int,
snapshot_id: int,
user: User,
) -> bool:
"""Delete a specific snapshot. Only thread owner can delete."""
# Get snapshot with thread
result = await session.execute(
select(NewChatThread)
.options(selectinload(NewChatThread.messages))
select(PublicChatSnapshot)
.options(selectinload(PublicChatSnapshot.thread))
.filter(
NewChatThread.public_share_token == share_token,
NewChatThread.public_share_enabled.is_(True),
PublicChatSnapshot.id == snapshot_id,
PublicChatSnapshot.thread_id == thread_id,
)
)
thread = result.scalars().first()
snapshot = result.scalars().first()
if not thread:
raise HTTPException(status_code=404, detail="Not found")
if not snapshot:
raise HTTPException(status_code=404, detail="Snapshot not found")
user_cache: dict[UUID, dict] = {}
messages = []
for msg in sorted(thread.messages, key=lambda m: m.created_at):
author = await get_author_display(session, msg.author_id, user_cache)
sanitized_content = sanitize_content_for_public(msg.content)
messages.append(
{
"role": msg.role,
"content": sanitized_content,
"author": author,
"created_at": msg.created_at,
}
if snapshot.thread.created_by_id != user.id:
raise HTTPException(
status_code=403,
detail="Only the creator can delete snapshots",
)
return {
"thread": {
"title": thread.title,
"created_at": thread.created_at,
},
"messages": messages,
}
await session.delete(snapshot)
await session.commit()
return True
async def get_thread_by_share_token(
session: AsyncSession,
share_token: str,
) -> NewChatThread | None:
"""Get a thread by its public share token if sharing is enabled."""
result = await session.execute(
select(NewChatThread)
.options(selectinload(NewChatThread.messages))
.filter(
NewChatThread.public_share_token == share_token,
NewChatThread.public_share_enabled.is_(True),
async def delete_affected_snapshots(
session: AsyncSession, # noqa: ARG001 - kept for API compatibility
thread_id: int,
message_ids: list[int],
) -> int:
"""
Delete snapshots that contain any of the given message IDs.
Called when messages are edited/deleted/regenerated.
Uses independent session to work reliably in streaming response cleanup.
"""
if not message_ids:
return 0
from sqlalchemy.dialects.postgresql import array
from app.db import async_session_maker
async with async_session_maker() as independent_session:
result = await independent_session.execute(
delete(PublicChatSnapshot)
.where(PublicChatSnapshot.thread_id == thread_id)
.where(PublicChatSnapshot.message_ids.op("&&")(array(message_ids)))
.returning(PublicChatSnapshot.id)
)
)
return result.scalars().first()
deleted_ids = result.scalars().all()
await independent_session.commit()
return len(deleted_ids)
# =============================================================================
# Cloning from Snapshot
# =============================================================================
async def get_user_default_search_space(
@ -222,8 +471,6 @@ async def get_user_default_search_space(
Returns the first search space where user is owner, or None if not found.
"""
from app.db import SearchSpaceMembership
result = await session.execute(
select(SearchSpaceMembership)
.filter(
@ -240,140 +487,153 @@ async def get_user_default_search_space(
return None
async def complete_clone_content(
async def clone_from_snapshot(
session: AsyncSession,
target_thread: NewChatThread,
source_thread_id: int,
target_search_space_id: int,
) -> int:
share_token: str,
user: User,
) -> dict:
"""
Copy messages and podcasts from source thread to target thread.
Sets clone_pending=False and needs_history_bootstrap=True when done.
Returns the number of messages copied.
Creates thread and copies messages from snapshot_data.
When encountering generate_podcast tool-calls, creates cloned podcast records
and updates the podcast_id references inline.
Returns the new thread info.
"""
from app.db import NewChatMessage
import copy
result = await session.execute(
select(NewChatThread)
.options(selectinload(NewChatThread.messages))
.filter(NewChatThread.id == source_thread_id)
snapshot = await get_snapshot_by_token(session, share_token)
if not snapshot:
raise HTTPException(
status_code=404, detail="Chat not found or no longer public"
)
target_search_space_id = await get_user_default_search_space(session, user.id)
if target_search_space_id is None:
raise HTTPException(status_code=400, detail="No search space found for user")
data = snapshot.snapshot_data
messages_data = data.get("messages", [])
podcasts_lookup = {p.get("original_id"): p for p in data.get("podcasts", [])}
new_thread = NewChatThread(
title=data.get("title", "Cloned Chat"),
archived=False,
visibility=ChatVisibility.PRIVATE,
search_space_id=target_search_space_id,
created_by_id=user.id,
cloned_from_thread_id=snapshot.thread_id,
cloned_from_snapshot_id=snapshot.id,
cloned_at=datetime.now(UTC),
needs_history_bootstrap=True,
)
source_thread = result.scalars().first()
session.add(new_thread)
await session.flush()
if not source_thread:
raise ValueError("Source thread not found")
podcast_id_mapping: dict[int, int] = {}
podcast_id_map: dict[int, int] = {}
message_count = 0
# Check which authors from snapshot still exist in DB
author_ids_from_snapshot: set[UUID] = set()
for msg_data in messages_data:
if author_str := msg_data.get("author_id"):
try:
author_ids_from_snapshot.add(UUID(author_str))
except (ValueError, TypeError):
pass
for msg in sorted(source_thread.messages, key=lambda m: m.created_at):
new_content = sanitize_content_for_public(msg.content)
existing_authors: set[UUID] = set()
if author_ids_from_snapshot:
result = await session.execute(
select(User.id).where(User.id.in_(author_ids_from_snapshot))
)
existing_authors = {row[0] for row in result.fetchall()}
if isinstance(new_content, list):
for part in new_content:
for msg_data in messages_data:
role = msg_data.get("role", "user")
# Use original author if exists, otherwise None
author_id = None
if author_str := msg_data.get("author_id"):
try:
parsed_id = UUID(author_str)
if parsed_id in existing_authors:
author_id = parsed_id
except (ValueError, TypeError):
pass
content = copy.deepcopy(msg_data.get("content", []))
if isinstance(content, list):
for part in content:
if (
isinstance(part, dict)
and part.get("type") == "tool-call"
and part.get("toolName") == "generate_podcast"
):
result_data = part.get("result", {})
old_podcast_id = result_data.get("podcast_id")
if old_podcast_id and old_podcast_id not in podcast_id_map:
new_podcast_id = await _clone_podcast(
session,
old_podcast_id,
target_search_space_id,
target_thread.id,
)
if new_podcast_id:
podcast_id_map[old_podcast_id] = new_podcast_id
result = part.get("result", {})
old_podcast_id = result.get("podcast_id")
if old_podcast_id and old_podcast_id in podcast_id_map:
result_data["podcast_id"] = podcast_id_map[old_podcast_id]
elif old_podcast_id:
# Podcast couldn't be cloned (not ready), remove reference
result_data.pop("podcast_id", None)
if old_podcast_id and old_podcast_id not in podcast_id_mapping:
podcast_info = podcasts_lookup.get(old_podcast_id)
if podcast_info:
new_podcast = Podcast(
title=podcast_info.get("title", "Cloned Podcast"),
podcast_transcript=podcast_info.get("transcript"),
file_location=podcast_info.get("file_path"),
status=PodcastStatus.READY,
search_space_id=target_search_space_id,
thread_id=new_thread.id,
)
session.add(new_podcast)
await session.flush()
podcast_id_mapping[old_podcast_id] = new_podcast.id
if old_podcast_id and old_podcast_id in podcast_id_mapping:
part["result"] = {
**result,
"podcast_id": podcast_id_mapping[old_podcast_id],
}
new_message = NewChatMessage(
thread_id=target_thread.id,
role=msg.role,
content=new_content,
author_id=msg.author_id,
created_at=msg.created_at,
thread_id=new_thread.id,
role=role,
content=content,
author_id=author_id,
)
session.add(new_message)
message_count += 1
target_thread.clone_pending = False
target_thread.needs_history_bootstrap = True
await session.commit()
await session.refresh(new_thread)
return message_count
return {
"thread_id": new_thread.id,
"search_space_id": target_search_space_id,
}
async def _clone_podcast(
async def get_snapshot_podcast(
session: AsyncSession,
share_token: str,
podcast_id: int,
target_search_space_id: int,
target_thread_id: int,
) -> int | None:
"""Clone a podcast record and its audio file. Only clones ready podcasts."""
import shutil
import uuid
from pathlib import Path
) -> dict | None:
"""
Get podcast info from a snapshot by original podcast ID.
from app.db import Podcast, PodcastStatus
Used for streaming podcast audio from public view.
Looks up the podcast by its original_id in the snapshot's podcasts array.
"""
snapshot = await get_snapshot_by_token(session, share_token)
result = await session.execute(select(Podcast).filter(Podcast.id == podcast_id))
original = result.scalars().first()
if not original or original.status != PodcastStatus.READY:
if not snapshot:
return None
new_file_path = None
if original.file_location:
original_path = Path(original.file_location)
if original_path.exists():
new_filename = f"{uuid.uuid4()}_podcast.mp3"
new_dir = Path("podcasts")
new_dir.mkdir(parents=True, exist_ok=True)
new_file_path = str(new_dir / new_filename)
shutil.copy2(original.file_location, new_file_path)
podcasts = snapshot.snapshot_data.get("podcasts", [])
new_podcast = Podcast(
title=original.title,
podcast_transcript=original.podcast_transcript,
file_location=new_file_path,
status=PodcastStatus.READY,
search_space_id=target_search_space_id,
thread_id=target_thread_id,
)
session.add(new_podcast)
await session.flush()
# Find podcast by original_id
for podcast in podcasts:
if podcast.get("original_id") == podcast_id:
return podcast
return new_podcast.id
async def is_podcast_publicly_accessible(
session: AsyncSession,
podcast_id: int,
) -> bool:
"""
Check if a podcast belongs to a publicly shared thread.
Uses the thread_id foreign key for efficient lookup.
"""
from app.db import Podcast
result = await session.execute(
select(Podcast)
.options(selectinload(Podcast.thread))
.filter(Podcast.id == podcast_id)
)
podcast = result.scalars().first()
if not podcast or not podcast.thread:
return False
return podcast.thread.public_share_enabled
return None