mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-04-27 01:36:30 +02:00
Merge branch 'dev' of https://github.com/MODSetter/SurfSense into dev
This commit is contained in:
commit
8301e0169c
71 changed files with 2889 additions and 732 deletions
|
|
@ -32,7 +32,7 @@ PROVIDER_MAP = {
|
|||
"GROQ": "groq",
|
||||
"COHERE": "cohere",
|
||||
"GOOGLE": "gemini",
|
||||
"OLLAMA": "ollama",
|
||||
"OLLAMA": "ollama_chat",
|
||||
"MISTRAL": "mistral",
|
||||
"AZURE_OPENAI": "azure",
|
||||
"OPENROUTER": "openrouter",
|
||||
|
|
|
|||
|
|
@ -94,7 +94,7 @@ async def validate_llm_config(
|
|||
"GROQ": "groq",
|
||||
"COHERE": "cohere",
|
||||
"GOOGLE": "gemini",
|
||||
"OLLAMA": "ollama",
|
||||
"OLLAMA": "ollama_chat",
|
||||
"MISTRAL": "mistral",
|
||||
"AZURE_OPENAI": "azure",
|
||||
"OPENROUTER": "openrouter",
|
||||
|
|
@ -241,7 +241,7 @@ async def get_search_space_llm_instance(
|
|||
"GROQ": "groq",
|
||||
"COHERE": "cohere",
|
||||
"GOOGLE": "gemini",
|
||||
"OLLAMA": "ollama",
|
||||
"OLLAMA": "ollama_chat",
|
||||
"MISTRAL": "mistral",
|
||||
"AZURE_OPENAI": "azure",
|
||||
"OPENROUTER": "openrouter",
|
||||
|
|
@ -311,7 +311,7 @@ async def get_search_space_llm_instance(
|
|||
"GROQ": "groq",
|
||||
"COHERE": "cohere",
|
||||
"GOOGLE": "gemini",
|
||||
"OLLAMA": "ollama",
|
||||
"OLLAMA": "ollama_chat",
|
||||
"MISTRAL": "mistral",
|
||||
"AZURE_OPENAI": "azure",
|
||||
"OPENROUTER": "openrouter",
|
||||
|
|
|
|||
|
|
@ -1,17 +1,35 @@
|
|||
"""
|
||||
Service layer for public chat sharing and cloning.
|
||||
Service layer for public chat sharing via immutable snapshots.
|
||||
|
||||
Key concepts:
|
||||
- Snapshots are frozen copies of a chat at a specific point in time
|
||||
- Content hash enables deduplication (same content = same URL)
|
||||
- Podcasts are embedded in snapshot_data for self-contained public views
|
||||
- Single-phase clone reads directly from snapshot_data
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import re
|
||||
import secrets
|
||||
from datetime import UTC, datetime
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.db import NewChatThread, User
|
||||
from app.db import (
|
||||
ChatVisibility,
|
||||
NewChatMessage,
|
||||
NewChatThread,
|
||||
Podcast,
|
||||
PodcastStatus,
|
||||
PublicChatSnapshot,
|
||||
SearchSpaceMembership,
|
||||
User,
|
||||
)
|
||||
|
||||
UI_TOOLS = {
|
||||
"display_image",
|
||||
|
|
@ -100,20 +118,242 @@ async def get_author_display(
|
|||
return user_cache[author_id]
|
||||
|
||||
|
||||
async def toggle_public_share(
|
||||
# =============================================================================
|
||||
# Content Hashing
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def compute_content_hash(messages: list[dict]) -> str:
    """
    Compute a SHA-256 hex digest of message content for deduplication.

    Only each message's "id", "role" and "content" feed the hash; every other
    key is ignored. Messages are ordered by "id" before hashing, so the digest
    does not depend on the order of the input list as long as ids are distinct:
    - Same messages = same hash = same URL (deduplication)
    - Any change to id, role or content = different hash = new URL

    Args:
        messages: Message dicts, read via ``.get`` for "id", "role", "content".
            The values must be JSON-serializable.

    Returns:
        The 64-character hexadecimal SHA-256 digest.
    """

    def _sort_key(message: dict):
        # A missing id and an explicit None both sort as 0. Without this, a
        # None id mixed with integer ids makes sorted() raise TypeError.
        message_id = message.get("id")
        return 0 if message_id is None else message_id

    # Normalized representation: only the fields that define "same content".
    normalized = [
        {
            "id": message.get("id"),
            "role": message.get("role"),
            "content": message.get("content"),
        }
        for message in sorted(messages, key=_sort_key)
    ]

    # sort_keys + compact separators give a canonical string for equal input.
    content_str = json.dumps(normalized, sort_keys=True, separators=(",", ":"))
    return hashlib.sha256(content_str.encode()).hexdigest()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Snapshot Creation
|
||||
# =============================================================================
|
||||
|
||||
|
||||
async def create_snapshot(
|
||||
session: AsyncSession,
|
||||
thread_id: int,
|
||||
enabled: bool,
|
||||
user: User,
|
||||
base_url: str,
|
||||
) -> dict:
|
||||
"""
|
||||
Enable or disable public sharing for a thread.
|
||||
Create a public snapshot of a chat thread.
|
||||
|
||||
Only the thread owner can toggle public sharing.
|
||||
When enabling, generates a new token if one doesn't exist.
|
||||
When disabling, keeps the token for potential re-enable.
|
||||
Returns existing snapshot if content unchanged (same hash).
|
||||
Returns new snapshot with unique URL if content changed.
|
||||
"""
|
||||
result = await session.execute(
|
||||
select(NewChatThread)
|
||||
.options(selectinload(NewChatThread.messages))
|
||||
.filter(NewChatThread.id == thread_id)
|
||||
)
|
||||
thread = result.scalars().first()
|
||||
|
||||
if not thread:
|
||||
raise HTTPException(status_code=404, detail="Thread not found")
|
||||
|
||||
if thread.created_by_id != user.id:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Only the creator of this chat can create public snapshots",
|
||||
)
|
||||
|
||||
# Build snapshot data
|
||||
user_cache: dict[UUID, dict] = {}
|
||||
messages_data = []
|
||||
message_ids = []
|
||||
podcasts_data = []
|
||||
podcast_ids_seen: set[int] = set()
|
||||
|
||||
for msg in sorted(thread.messages, key=lambda m: m.created_at):
|
||||
author = await get_author_display(session, msg.author_id, user_cache)
|
||||
sanitized_content = sanitize_content_for_public(msg.content)
|
||||
|
||||
# Extract podcast references and update status to "ready" for completed podcasts
|
||||
if isinstance(sanitized_content, list):
|
||||
for part in sanitized_content:
|
||||
if (
|
||||
isinstance(part, dict)
|
||||
and part.get("type") == "tool-call"
|
||||
and part.get("toolName") == "generate_podcast"
|
||||
):
|
||||
result_data = part.get("result", {})
|
||||
podcast_id = result_data.get("podcast_id")
|
||||
if podcast_id and podcast_id not in podcast_ids_seen:
|
||||
podcast_info = await _get_podcast_for_snapshot(
|
||||
session, podcast_id
|
||||
)
|
||||
if podcast_info:
|
||||
podcasts_data.append(podcast_info)
|
||||
podcast_ids_seen.add(podcast_id)
|
||||
# Update status to "ready" so frontend renders PodcastPlayer
|
||||
part["result"] = {**result_data, "status": "ready"}
|
||||
|
||||
|
||||
messages_data.append(
|
||||
{
|
||||
"id": msg.id,
|
||||
"role": msg.role.value if hasattr(msg.role, "value") else str(msg.role),
|
||||
"content": sanitized_content,
|
||||
"author": author,
|
||||
"author_id": str(msg.author_id) if msg.author_id else None,
|
||||
"created_at": msg.created_at.isoformat() if msg.created_at else None,
|
||||
}
|
||||
)
|
||||
message_ids.append(msg.id)
|
||||
|
||||
if not messages_data:
|
||||
raise HTTPException(status_code=400, detail="Cannot share an empty chat")
|
||||
|
||||
# Compute content hash for deduplication
|
||||
content_hash = compute_content_hash(messages_data)
|
||||
|
||||
# Check if identical snapshot already exists
|
||||
existing_result = await session.execute(
|
||||
select(PublicChatSnapshot).filter(
|
||||
PublicChatSnapshot.thread_id == thread_id,
|
||||
PublicChatSnapshot.content_hash == content_hash,
|
||||
)
|
||||
)
|
||||
existing = existing_result.scalars().first()
|
||||
|
||||
if existing:
|
||||
# Return existing snapshot URL
|
||||
return {
|
||||
"snapshot_id": existing.id,
|
||||
"share_token": existing.share_token,
|
||||
"public_url": f"{base_url}/public/{existing.share_token}",
|
||||
"is_new": False,
|
||||
}
|
||||
|
||||
# Get thread author info
|
||||
thread_author = await get_author_display(session, thread.created_by_id, user_cache)
|
||||
|
||||
# Create snapshot data
|
||||
snapshot_data = {
|
||||
"title": thread.title,
|
||||
"snapshot_at": datetime.now(UTC).isoformat(),
|
||||
"author": thread_author,
|
||||
"messages": messages_data,
|
||||
"podcasts": podcasts_data,
|
||||
}
|
||||
|
||||
# Create new snapshot
|
||||
share_token = secrets.token_urlsafe(48)
|
||||
snapshot = PublicChatSnapshot(
|
||||
thread_id=thread_id,
|
||||
share_token=share_token,
|
||||
content_hash=content_hash,
|
||||
snapshot_data=snapshot_data,
|
||||
message_ids=message_ids,
|
||||
created_by_user_id=user.id,
|
||||
)
|
||||
session.add(snapshot)
|
||||
await session.commit()
|
||||
await session.refresh(snapshot)
|
||||
|
||||
return {
|
||||
"snapshot_id": snapshot.id,
|
||||
"share_token": snapshot.share_token,
|
||||
"public_url": f"{base_url}/public/{snapshot.share_token}",
|
||||
"is_new": True,
|
||||
}
|
||||
|
||||
|
||||
async def _get_podcast_for_snapshot(
    session: AsyncSession,
    podcast_id: int,
) -> dict | None:
    """
    Build the podcast entry that is embedded in snapshot_data.

    Returns None when no podcast has the given id, or when the podcast's
    status is not PodcastStatus.READY. Otherwise returns a dict with the
    podcast's id (as "original_id"), title, transcript and file location.
    """
    query = select(Podcast).filter(Podcast.id == podcast_id)
    record = (await session.execute(query)).scalars().first()

    # Guard clauses: nothing to embed for unknown or unfinished podcasts.
    if not record:
        return None
    if record.status != PodcastStatus.READY:
        return None

    entry = {
        "original_id": record.id,
        "title": record.title,
        "transcript": record.podcast_transcript,
        "file_path": record.file_location,
    }
    return entry
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Snapshot Retrieval
|
||||
# =============================================================================
|
||||
|
||||
|
||||
async def get_snapshot_by_token(
    session: AsyncSession,
    share_token: str,
) -> PublicChatSnapshot | None:
    """
    Look up a public chat snapshot by its share token.

    Returns the first PublicChatSnapshot whose share_token equals the given
    value, or None when the query yields no row.
    """
    token_matches = PublicChatSnapshot.share_token == share_token
    statement = select(PublicChatSnapshot).filter(token_matches)
    rows = await session.execute(statement)
    return rows.scalars().first()
|
||||
|
||||
|
||||
async def get_public_chat(
    session: AsyncSession,
    share_token: str,
) -> dict:
    """
    Return the public view of a chat from its stored snapshot.

    The response is built only from the snapshot's snapshot_data: the title
    (defaulting to "Untitled"), the snapshot timestamp exposed as the thread's
    "created_at", and the stored message list (defaulting to an empty list).

    Raises:
        HTTPException: 404 when no snapshot exists for the share token.
    """
    snapshot = await get_snapshot_by_token(session, share_token)
    if not snapshot:
        raise HTTPException(status_code=404, detail="Not found")

    payload = snapshot.snapshot_data

    thread_info = {
        "title": payload.get("title", "Untitled"),
        "created_at": payload.get("snapshot_at"),
    }
    return {
        "thread": thread_info,
        "messages": payload.get("messages", []),
    }
|
||||
|
||||
|
||||
async def list_snapshots_for_thread(
|
||||
session: AsyncSession,
|
||||
thread_id: int,
|
||||
user: User,
|
||||
base_url: str,
|
||||
) -> list[dict]:
|
||||
"""List all public snapshots for a thread."""
|
||||
# Verify ownership
|
||||
result = await session.execute(
|
||||
select(NewChatThread).filter(NewChatThread.id == thread_id)
|
||||
)
|
||||
|
|
@ -125,92 +365,101 @@ async def toggle_public_share(
|
|||
if thread.created_by_id != user.id:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Only the creator of this chat can manage public sharing",
|
||||
detail="Only the creator can view snapshots",
|
||||
)
|
||||
|
||||
if enabled and not thread.public_share_token:
|
||||
thread.public_share_token = secrets.token_urlsafe(48)
|
||||
# Get snapshots
|
||||
result = await session.execute(
|
||||
select(PublicChatSnapshot)
|
||||
.filter(PublicChatSnapshot.thread_id == thread_id)
|
||||
.order_by(PublicChatSnapshot.created_at.desc())
|
||||
)
|
||||
snapshots = result.scalars().all()
|
||||
|
||||
thread.public_share_enabled = enabled
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(thread)
|
||||
|
||||
if enabled:
|
||||
return {
|
||||
"enabled": True,
|
||||
"public_url": f"{base_url}/public/{thread.public_share_token}",
|
||||
"share_token": thread.public_share_token,
|
||||
return [
|
||||
{
|
||||
"id": s.id,
|
||||
"share_token": s.share_token,
|
||||
"public_url": f"{base_url}/public/{s.share_token}",
|
||||
"created_at": s.created_at.isoformat() if s.created_at else None,
|
||||
"message_count": len(s.message_ids) if s.message_ids else 0,
|
||||
}
|
||||
|
||||
return {
|
||||
"enabled": False,
|
||||
"public_url": None,
|
||||
"share_token": None,
|
||||
}
|
||||
for s in snapshots
|
||||
]
|
||||
|
||||
|
||||
async def get_public_chat(
|
||||
# =============================================================================
|
||||
# Snapshot Deletion
|
||||
# =============================================================================
|
||||
|
||||
|
||||
async def delete_snapshot(
|
||||
session: AsyncSession,
|
||||
share_token: str,
|
||||
) -> dict:
|
||||
"""
|
||||
Get a public chat by share token.
|
||||
|
||||
Returns sanitized content suitable for public viewing.
|
||||
"""
|
||||
thread_id: int,
|
||||
snapshot_id: int,
|
||||
user: User,
|
||||
) -> bool:
|
||||
"""Delete a specific snapshot. Only thread owner can delete."""
|
||||
# Get snapshot with thread
|
||||
result = await session.execute(
|
||||
select(NewChatThread)
|
||||
.options(selectinload(NewChatThread.messages))
|
||||
select(PublicChatSnapshot)
|
||||
.options(selectinload(PublicChatSnapshot.thread))
|
||||
.filter(
|
||||
NewChatThread.public_share_token == share_token,
|
||||
NewChatThread.public_share_enabled.is_(True),
|
||||
PublicChatSnapshot.id == snapshot_id,
|
||||
PublicChatSnapshot.thread_id == thread_id,
|
||||
)
|
||||
)
|
||||
thread = result.scalars().first()
|
||||
snapshot = result.scalars().first()
|
||||
|
||||
if not thread:
|
||||
raise HTTPException(status_code=404, detail="Not found")
|
||||
if not snapshot:
|
||||
raise HTTPException(status_code=404, detail="Snapshot not found")
|
||||
|
||||
user_cache: dict[UUID, dict] = {}
|
||||
|
||||
messages = []
|
||||
for msg in sorted(thread.messages, key=lambda m: m.created_at):
|
||||
author = await get_author_display(session, msg.author_id, user_cache)
|
||||
sanitized_content = sanitize_content_for_public(msg.content)
|
||||
|
||||
messages.append(
|
||||
{
|
||||
"role": msg.role,
|
||||
"content": sanitized_content,
|
||||
"author": author,
|
||||
"created_at": msg.created_at,
|
||||
}
|
||||
if snapshot.thread.created_by_id != user.id:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Only the creator can delete snapshots",
|
||||
)
|
||||
|
||||
return {
|
||||
"thread": {
|
||||
"title": thread.title,
|
||||
"created_at": thread.created_at,
|
||||
},
|
||||
"messages": messages,
|
||||
}
|
||||
await session.delete(snapshot)
|
||||
await session.commit()
|
||||
return True
|
||||
|
||||
|
||||
async def get_thread_by_share_token(
|
||||
session: AsyncSession,
|
||||
share_token: str,
|
||||
) -> NewChatThread | None:
|
||||
"""Get a thread by its public share token if sharing is enabled."""
|
||||
result = await session.execute(
|
||||
select(NewChatThread)
|
||||
.options(selectinload(NewChatThread.messages))
|
||||
.filter(
|
||||
NewChatThread.public_share_token == share_token,
|
||||
NewChatThread.public_share_enabled.is_(True),
|
||||
async def delete_affected_snapshots(
|
||||
session: AsyncSession, # noqa: ARG001 - kept for API compatibility
|
||||
thread_id: int,
|
||||
message_ids: list[int],
|
||||
) -> int:
|
||||
"""
|
||||
Delete snapshots that contain any of the given message IDs.
|
||||
|
||||
Called when messages are edited/deleted/regenerated.
|
||||
Uses independent session to work reliably in streaming response cleanup.
|
||||
"""
|
||||
if not message_ids:
|
||||
return 0
|
||||
|
||||
from sqlalchemy.dialects.postgresql import array
|
||||
|
||||
from app.db import async_session_maker
|
||||
|
||||
async with async_session_maker() as independent_session:
|
||||
result = await independent_session.execute(
|
||||
delete(PublicChatSnapshot)
|
||||
.where(PublicChatSnapshot.thread_id == thread_id)
|
||||
.where(PublicChatSnapshot.message_ids.op("&&")(array(message_ids)))
|
||||
.returning(PublicChatSnapshot.id)
|
||||
)
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
deleted_ids = result.scalars().all()
|
||||
await independent_session.commit()
|
||||
|
||||
return len(deleted_ids)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Cloning from Snapshot
|
||||
# =============================================================================
|
||||
|
||||
|
||||
async def get_user_default_search_space(
|
||||
|
|
@ -222,8 +471,6 @@ async def get_user_default_search_space(
|
|||
|
||||
Returns the first search space where user is owner, or None if not found.
|
||||
"""
|
||||
from app.db import SearchSpaceMembership
|
||||
|
||||
result = await session.execute(
|
||||
select(SearchSpaceMembership)
|
||||
.filter(
|
||||
|
|
@ -240,140 +487,153 @@ async def get_user_default_search_space(
|
|||
return None
|
||||
|
||||
|
||||
async def complete_clone_content(
|
||||
async def clone_from_snapshot(
|
||||
session: AsyncSession,
|
||||
target_thread: NewChatThread,
|
||||
source_thread_id: int,
|
||||
target_search_space_id: int,
|
||||
) -> int:
|
||||
share_token: str,
|
||||
user: User,
|
||||
) -> dict:
|
||||
"""
|
||||
Copy messages and podcasts from source thread to target thread.
|
||||
|
||||
Sets clone_pending=False and needs_history_bootstrap=True when done.
|
||||
Returns the number of messages copied.
|
||||
Creates thread and copies messages from snapshot_data.
|
||||
When encountering generate_podcast tool-calls, creates cloned podcast records
|
||||
and updates the podcast_id references inline.
|
||||
Returns the new thread info.
|
||||
"""
|
||||
from app.db import NewChatMessage
|
||||
import copy
|
||||
|
||||
result = await session.execute(
|
||||
select(NewChatThread)
|
||||
.options(selectinload(NewChatThread.messages))
|
||||
.filter(NewChatThread.id == source_thread_id)
|
||||
snapshot = await get_snapshot_by_token(session, share_token)
|
||||
|
||||
if not snapshot:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Chat not found or no longer public"
|
||||
)
|
||||
|
||||
target_search_space_id = await get_user_default_search_space(session, user.id)
|
||||
|
||||
if target_search_space_id is None:
|
||||
raise HTTPException(status_code=400, detail="No search space found for user")
|
||||
|
||||
data = snapshot.snapshot_data
|
||||
messages_data = data.get("messages", [])
|
||||
podcasts_lookup = {p.get("original_id"): p for p in data.get("podcasts", [])}
|
||||
|
||||
new_thread = NewChatThread(
|
||||
title=data.get("title", "Cloned Chat"),
|
||||
archived=False,
|
||||
visibility=ChatVisibility.PRIVATE,
|
||||
search_space_id=target_search_space_id,
|
||||
created_by_id=user.id,
|
||||
cloned_from_thread_id=snapshot.thread_id,
|
||||
cloned_from_snapshot_id=snapshot.id,
|
||||
cloned_at=datetime.now(UTC),
|
||||
needs_history_bootstrap=True,
|
||||
)
|
||||
source_thread = result.scalars().first()
|
||||
session.add(new_thread)
|
||||
await session.flush()
|
||||
|
||||
if not source_thread:
|
||||
raise ValueError("Source thread not found")
|
||||
podcast_id_mapping: dict[int, int] = {}
|
||||
|
||||
podcast_id_map: dict[int, int] = {}
|
||||
message_count = 0
|
||||
# Check which authors from snapshot still exist in DB
|
||||
author_ids_from_snapshot: set[UUID] = set()
|
||||
for msg_data in messages_data:
|
||||
if author_str := msg_data.get("author_id"):
|
||||
try:
|
||||
author_ids_from_snapshot.add(UUID(author_str))
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
for msg in sorted(source_thread.messages, key=lambda m: m.created_at):
|
||||
new_content = sanitize_content_for_public(msg.content)
|
||||
existing_authors: set[UUID] = set()
|
||||
if author_ids_from_snapshot:
|
||||
result = await session.execute(
|
||||
select(User.id).where(User.id.in_(author_ids_from_snapshot))
|
||||
)
|
||||
existing_authors = {row[0] for row in result.fetchall()}
|
||||
|
||||
if isinstance(new_content, list):
|
||||
for part in new_content:
|
||||
for msg_data in messages_data:
|
||||
role = msg_data.get("role", "user")
|
||||
|
||||
# Use original author if exists, otherwise None
|
||||
author_id = None
|
||||
if author_str := msg_data.get("author_id"):
|
||||
try:
|
||||
parsed_id = UUID(author_str)
|
||||
if parsed_id in existing_authors:
|
||||
author_id = parsed_id
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
content = copy.deepcopy(msg_data.get("content", []))
|
||||
|
||||
if isinstance(content, list):
|
||||
for part in content:
|
||||
if (
|
||||
isinstance(part, dict)
|
||||
and part.get("type") == "tool-call"
|
||||
and part.get("toolName") == "generate_podcast"
|
||||
):
|
||||
result_data = part.get("result", {})
|
||||
old_podcast_id = result_data.get("podcast_id")
|
||||
if old_podcast_id and old_podcast_id not in podcast_id_map:
|
||||
new_podcast_id = await _clone_podcast(
|
||||
session,
|
||||
old_podcast_id,
|
||||
target_search_space_id,
|
||||
target_thread.id,
|
||||
)
|
||||
if new_podcast_id:
|
||||
podcast_id_map[old_podcast_id] = new_podcast_id
|
||||
result = part.get("result", {})
|
||||
old_podcast_id = result.get("podcast_id")
|
||||
|
||||
if old_podcast_id and old_podcast_id in podcast_id_map:
|
||||
result_data["podcast_id"] = podcast_id_map[old_podcast_id]
|
||||
elif old_podcast_id:
|
||||
# Podcast couldn't be cloned (not ready), remove reference
|
||||
result_data.pop("podcast_id", None)
|
||||
if old_podcast_id and old_podcast_id not in podcast_id_mapping:
|
||||
podcast_info = podcasts_lookup.get(old_podcast_id)
|
||||
if podcast_info:
|
||||
new_podcast = Podcast(
|
||||
title=podcast_info.get("title", "Cloned Podcast"),
|
||||
podcast_transcript=podcast_info.get("transcript"),
|
||||
file_location=podcast_info.get("file_path"),
|
||||
status=PodcastStatus.READY,
|
||||
search_space_id=target_search_space_id,
|
||||
thread_id=new_thread.id,
|
||||
)
|
||||
session.add(new_podcast)
|
||||
await session.flush()
|
||||
podcast_id_mapping[old_podcast_id] = new_podcast.id
|
||||
|
||||
if old_podcast_id and old_podcast_id in podcast_id_mapping:
|
||||
part["result"] = {
|
||||
**result,
|
||||
"podcast_id": podcast_id_mapping[old_podcast_id],
|
||||
}
|
||||
|
||||
new_message = NewChatMessage(
|
||||
thread_id=target_thread.id,
|
||||
role=msg.role,
|
||||
content=new_content,
|
||||
author_id=msg.author_id,
|
||||
created_at=msg.created_at,
|
||||
thread_id=new_thread.id,
|
||||
role=role,
|
||||
content=content,
|
||||
author_id=author_id,
|
||||
)
|
||||
session.add(new_message)
|
||||
message_count += 1
|
||||
|
||||
target_thread.clone_pending = False
|
||||
target_thread.needs_history_bootstrap = True
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(new_thread)
|
||||
|
||||
return message_count
|
||||
return {
|
||||
"thread_id": new_thread.id,
|
||||
"search_space_id": target_search_space_id,
|
||||
}
|
||||
|
||||
|
||||
async def _clone_podcast(
|
||||
async def get_snapshot_podcast(
|
||||
session: AsyncSession,
|
||||
share_token: str,
|
||||
podcast_id: int,
|
||||
target_search_space_id: int,
|
||||
target_thread_id: int,
|
||||
) -> int | None:
|
||||
"""Clone a podcast record and its audio file. Only clones ready podcasts."""
|
||||
import shutil
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
) -> dict | None:
|
||||
"""
|
||||
Get podcast info from a snapshot by original podcast ID.
|
||||
|
||||
from app.db import Podcast, PodcastStatus
|
||||
Used for streaming podcast audio from public view.
|
||||
Looks up the podcast by its original_id in the snapshot's podcasts array.
|
||||
"""
|
||||
snapshot = await get_snapshot_by_token(session, share_token)
|
||||
|
||||
result = await session.execute(select(Podcast).filter(Podcast.id == podcast_id))
|
||||
original = result.scalars().first()
|
||||
if not original or original.status != PodcastStatus.READY:
|
||||
if not snapshot:
|
||||
return None
|
||||
|
||||
new_file_path = None
|
||||
if original.file_location:
|
||||
original_path = Path(original.file_location)
|
||||
if original_path.exists():
|
||||
new_filename = f"{uuid.uuid4()}_podcast.mp3"
|
||||
new_dir = Path("podcasts")
|
||||
new_dir.mkdir(parents=True, exist_ok=True)
|
||||
new_file_path = str(new_dir / new_filename)
|
||||
shutil.copy2(original.file_location, new_file_path)
|
||||
podcasts = snapshot.snapshot_data.get("podcasts", [])
|
||||
|
||||
new_podcast = Podcast(
|
||||
title=original.title,
|
||||
podcast_transcript=original.podcast_transcript,
|
||||
file_location=new_file_path,
|
||||
status=PodcastStatus.READY,
|
||||
search_space_id=target_search_space_id,
|
||||
thread_id=target_thread_id,
|
||||
)
|
||||
session.add(new_podcast)
|
||||
await session.flush()
|
||||
# Find podcast by original_id
|
||||
for podcast in podcasts:
|
||||
if podcast.get("original_id") == podcast_id:
|
||||
return podcast
|
||||
|
||||
return new_podcast.id
|
||||
|
||||
|
||||
async def is_podcast_publicly_accessible(
|
||||
session: AsyncSession,
|
||||
podcast_id: int,
|
||||
) -> bool:
|
||||
"""
|
||||
Check if a podcast belongs to a publicly shared thread.
|
||||
|
||||
Uses the thread_id foreign key for efficient lookup.
|
||||
"""
|
||||
from app.db import Podcast
|
||||
|
||||
result = await session.execute(
|
||||
select(Podcast)
|
||||
.options(selectinload(Podcast.thread))
|
||||
.filter(Podcast.id == podcast_id)
|
||||
)
|
||||
podcast = result.scalars().first()
|
||||
|
||||
if not podcast or not podcast.thread:
|
||||
return False
|
||||
|
||||
return podcast.thread.public_share_enabled
|
||||
return None
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue