feat: add clone tracking and history bootstrap for cloned chats

This commit is contained in:
CREDO23 2026-01-27 13:33:36 +02:00
parent a7145b2c63
commit 3c40c6e365
8 changed files with 225 additions and 70 deletions

View file

@ -1,4 +1,3 @@
{
"biome.configurationPath": "./surfsense_web/biome.json",
"deepscan.ignoreConfirmWarning": true
"biome.configurationPath": "./surfsense_web/biome.json"
}

View file

@ -0,0 +1,105 @@
"""Add public chat sharing and cloning features to new_chat_threads
Revision ID: 81
Revises: 80
Create Date: 2026-01-23
Adds columns for:
1. Public sharing via tokenized URLs (public_share_token, public_share_enabled)
2. Clone tracking for audit (cloned_from_thread_id, cloned_at)
3. History bootstrap flag for cloned chats (needs_history_bootstrap)
"""
from collections.abc import Sequence
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "81"
down_revision: str | None = "80"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
    """Add public sharing and cloning columns to new_chat_threads.

    Every statement is idempotent (IF NOT EXISTS), so the migration can be
    re-applied safely against a partially migrated database.
    """
    ddl_statements = (
        # Nullable token backing the public share URL.
        """
        ALTER TABLE new_chat_threads
        ADD COLUMN IF NOT EXISTS public_share_token VARCHAR(64);
        """,
        # Sharing toggle; existing rows default to not shared.
        """
        ALTER TABLE new_chat_threads
        ADD COLUMN IF NOT EXISTS public_share_enabled BOOLEAN NOT NULL DEFAULT FALSE;
        """,
        # Tokens must be unique, but only when present (partial unique index).
        """
        CREATE UNIQUE INDEX IF NOT EXISTS ix_new_chat_threads_public_share_token
        ON new_chat_threads(public_share_token)
        WHERE public_share_token IS NOT NULL;
        """,
        # Fast lookup of the (typically small) set of publicly shared threads.
        """
        CREATE INDEX IF NOT EXISTS ix_new_chat_threads_public_share_enabled
        ON new_chat_threads(public_share_enabled)
        WHERE public_share_enabled = TRUE;
        """,
        # Self-referential audit link to the source thread of a clone;
        # SET NULL keeps clones alive when their source is deleted.
        """
        ALTER TABLE new_chat_threads
        ADD COLUMN IF NOT EXISTS cloned_from_thread_id INTEGER
        REFERENCES new_chat_threads(id) ON DELETE SET NULL;
        """,
        # Timestamp of the clone operation (timezone-aware).
        """
        ALTER TABLE new_chat_threads
        ADD COLUMN IF NOT EXISTS cloned_at TIMESTAMP WITH TIME ZONE;
        """,
        # One-shot flag: seed LangGraph history from DB messages for clones.
        """
        ALTER TABLE new_chat_threads
        ADD COLUMN IF NOT EXISTS needs_history_bootstrap BOOLEAN NOT NULL DEFAULT FALSE;
        """,
        # Partial index: only clones carry a non-null source thread id.
        """
        CREATE INDEX IF NOT EXISTS ix_new_chat_threads_cloned_from_thread_id
        ON new_chat_threads(cloned_from_thread_id)
        WHERE cloned_from_thread_id IS NOT NULL;
        """,
    )
    for ddl in ddl_statements:
        op.execute(ddl)
def downgrade() -> None:
    """Remove public sharing and cloning columns from new_chat_threads.

    Statements run in reverse order of upgrade(): indexes are dropped before
    the columns they cover, and every statement is idempotent (IF EXISTS).
    """
    for ddl in (
        "DROP INDEX IF EXISTS ix_new_chat_threads_cloned_from_thread_id",
        "ALTER TABLE new_chat_threads DROP COLUMN IF EXISTS needs_history_bootstrap",
        "ALTER TABLE new_chat_threads DROP COLUMN IF EXISTS cloned_at",
        "ALTER TABLE new_chat_threads DROP COLUMN IF EXISTS cloned_from_thread_id",
        "DROP INDEX IF EXISTS ix_new_chat_threads_public_share_enabled",
        "DROP INDEX IF EXISTS ix_new_chat_threads_public_share_token",
        "ALTER TABLE new_chat_threads DROP COLUMN IF EXISTS public_share_enabled",
        "ALTER TABLE new_chat_threads DROP COLUMN IF EXISTS public_share_token",
    ):
        op.execute(ddl)

View file

@ -1,66 +0,0 @@
"""Add public sharing columns to new_chat_threads
Revision ID: 81
Revises: 80
Create Date: 2026-01-23
Adds public_share_token and public_share_enabled columns to enable
public sharing of chat threads via secure tokenized URLs.
"""
from collections.abc import Sequence
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "81"
down_revision: str | None = "80"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
    """Add public sharing columns to new_chat_threads.

    All statements use IF NOT EXISTS so the migration is safe to re-run.
    """
    statements = (
        # Nullable token backing the public share URL.
        """
        ALTER TABLE new_chat_threads
        ADD COLUMN IF NOT EXISTS public_share_token VARCHAR(64);
        """,
        # Sharing toggle; existing rows default to not shared.
        """
        ALTER TABLE new_chat_threads
        ADD COLUMN IF NOT EXISTS public_share_enabled BOOLEAN NOT NULL DEFAULT FALSE;
        """,
        # Unique partial index: token uniqueness enforced only for non-null values.
        """
        CREATE UNIQUE INDEX IF NOT EXISTS ix_new_chat_threads_public_share_token
        ON new_chat_threads(public_share_token)
        WHERE public_share_token IS NOT NULL;
        """,
        # Partial index for fast queries over the shared (enabled=TRUE) subset.
        """
        CREATE INDEX IF NOT EXISTS ix_new_chat_threads_public_share_enabled
        ON new_chat_threads(public_share_enabled)
        WHERE public_share_enabled = TRUE;
        """,
    )
    for statement in statements:
        op.execute(statement)
def downgrade() -> None:
    """Remove public sharing columns from new_chat_threads.

    Indexes are dropped before the columns they cover; all statements are
    idempotent (IF EXISTS).
    """
    for statement in (
        "DROP INDEX IF EXISTS ix_new_chat_threads_public_share_enabled",
        "DROP INDEX IF EXISTS ix_new_chat_threads_public_share_token",
        "ALTER TABLE new_chat_threads DROP COLUMN IF EXISTS public_share_enabled",
        "ALTER TABLE new_chat_threads DROP COLUMN IF EXISTS public_share_token",
    ):
        op.execute(statement)

View file

@ -412,6 +412,25 @@ class NewChatThread(BaseModel, TimestampMixin):
server_default="false",
)
# Clone tracking - for audit and history bootstrap
# Self-referential FK to the thread this one was cloned from. ON DELETE
# SET NULL means a clone survives deletion of its source thread.
# NOTE(review): index=True would autogenerate a full index named
# ix_new_chat_threads_cloned_from_thread_id, but migration 81 already
# creates a partial index under that same name — confirm no conflict.
cloned_from_thread_id = Column(
Integer,
ForeignKey("new_chat_threads.id", ondelete="SET NULL"),
nullable=True,
index=True,
)
# Timezone-aware timestamp of when the clone was created; NULL for
# threads that are not clones.
cloned_at = Column(
TIMESTAMP(timezone=True),
nullable=True,
)
# Flag to bootstrap LangGraph checkpointer with DB messages on first message
# Defaults to FALSE both in Python (default=) and in the database
# (server_default=) so pre-existing rows never trigger a bootstrap.
needs_history_bootstrap = Column(
Boolean,
nullable=False,
default=False,
server_default="false",
)
# Relationships
search_space = relationship("SearchSpace", back_populates="new_chat_threads")
created_by = relationship("User", back_populates="new_chat_threads")

View file

@ -1027,6 +1027,7 @@ async def handle_new_chat(
attachments=request.attachments,
mentioned_document_ids=request.mentioned_document_ids,
mentioned_surfsense_doc_ids=request.mentioned_surfsense_doc_ids,
needs_history_bootstrap=thread.needs_history_bootstrap,
),
media_type="text/event-stream",
headers={
@ -1254,6 +1255,7 @@ async def regenerate_response(
mentioned_document_ids=request.mentioned_document_ids,
mentioned_surfsense_doc_ids=request.mentioned_surfsense_doc_ids,
checkpoint_id=target_checkpoint_id,
needs_history_bootstrap=thread.needs_history_bootstrap,
):
yield chunk
# If we get here, streaming completed successfully

View file

@ -4,6 +4,7 @@ Service layer for public chat sharing and cloning.
import re
import secrets
from datetime import UTC, datetime
from uuid import UUID
from fastapi import HTTPException
@ -283,6 +284,9 @@ async def clone_public_chat(
search_space_id=target_search_space_id,
created_by_id=user_id,
public_share_enabled=False,
cloned_from_thread_id=source_thread.id,
cloned_at=datetime.now(UTC),
needs_history_bootstrap=True,
)
session.add(new_thread)
await session.flush()

View file

@ -18,6 +18,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.agents.new_chat.chat_deepagent import create_surfsense_deep_agent
from app.utils.content_utils import bootstrap_history_from_db
from app.agents.new_chat.checkpointer import get_checkpointer
from app.agents.new_chat.llm_config import (
AgentConfig,
@ -205,13 +206,13 @@ async def stream_new_chat(
mentioned_document_ids: list[int] | None = None,
mentioned_surfsense_doc_ids: list[int] | None = None,
checkpoint_id: str | None = None,
needs_history_bootstrap: bool = False,
) -> AsyncGenerator[str, None]:
"""
Stream chat responses from the new SurfSense deep agent.
This uses the Vercel AI SDK Data Stream Protocol (SSE format) for streaming.
The chat_id is used as LangGraph's thread_id for memory/checkpointing.
Message history can be passed from the frontend for context.
Args:
user_query: The user's query
@ -221,6 +222,7 @@ async def stream_new_chat(
user_id: The current user's UUID string (for memory tools and session state)
llm_config_id: The LLM configuration ID (default: -1 for first global config)
attachments: Optional attachments with extracted content
needs_history_bootstrap: If True, load message history from DB (for cloned chats)
mentioned_document_ids: Optional list of document IDs mentioned with @ in the chat
mentioned_surfsense_doc_ids: Optional list of SurfSense doc IDs mentioned with @ in the chat
checkpoint_id: Optional checkpoint ID to rewind/fork from (for edit/reload operations)
@ -305,9 +307,24 @@ async def stream_new_chat(
firecrawl_api_key=firecrawl_api_key, # Pass Firecrawl API key if configured
)
# Build input with message history from frontend
# Build input with message history
langchain_messages = []
# Bootstrap history for cloned chats (no LangGraph checkpoint exists yet)
if needs_history_bootstrap:
langchain_messages = await bootstrap_history_from_db(session, chat_id)
# Clear the flag so we don't bootstrap again on next message
from app.db import NewChatThread
thread_result = await session.execute(
select(NewChatThread).filter(NewChatThread.id == chat_id)
)
thread = thread_result.scalars().first()
if thread:
thread.needs_history_bootstrap = False
await session.commit()
# Fetch mentioned documents if any (with chunks for proper citations)
mentioned_documents: list[Document] = []
if mentioned_document_ids:

View file

@ -0,0 +1,75 @@
"""
Utilities for working with message content.
Message content in new_chat_messages can be stored in various formats:
- String: Simple text content
- List: Array of content parts [{"type": "text", "text": "..."}, {"type": "tool-call", ...}]
- Dict: Single content object
These utilities help extract and transform content for different use cases.
"""
from langchain_core.messages import AIMessage, HumanMessage
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
def extract_text_content(content: str | dict | list) -> str:
"""Extract plain text content from various message formats."""
if isinstance(content, str):
return content
if isinstance(content, dict):
# Handle dict with 'text' key
if "text" in content:
return content["text"]
return str(content)
if isinstance(content, list):
# Handle list of parts (e.g., [{"type": "text", "text": "..."}])
texts = []
for part in content:
if isinstance(part, dict) and part.get("type") == "text":
texts.append(part.get("text", ""))
elif isinstance(part, str):
texts.append(part)
return "\n".join(texts) if texts else ""
return ""
async def bootstrap_history_from_db(
    session: AsyncSession,
    thread_id: int,
) -> list[HumanMessage | AIMessage]:
    """
    Load message history from database and convert to LangChain format.

    Used for cloned chats where the LangGraph checkpointer has no state,
    but we have messages in the database that should be used as context.

    Args:
        session: Database session
        thread_id: The chat thread ID

    Returns:
        List of LangChain messages (HumanMessage/AIMessage) in creation
        order; rows with no extractable text or an unrecognized role are
        skipped.
    """
    # Local import to avoid a circular dependency with app.db at module load.
    from app.db import NewChatMessage

    rows = (
        await session.execute(
            select(NewChatMessage)
            .filter(NewChatMessage.thread_id == thread_id)
            .order_by(NewChatMessage.created_at)
        )
    ).scalars().all()

    # Map DB roles onto LangChain message classes; anything else is dropped.
    role_to_message = {"user": HumanMessage, "assistant": AIMessage}

    history: list[HumanMessage | AIMessage] = []
    for row in rows:
        text = extract_text_content(row.content)
        message_cls = role_to_message.get(row.role)
        if text and message_cls is not None:
            history.append(message_cls(content=text))
    return history