SurfSense/surfsense_backend/app/routes/new_chat_routes.py
2026-05-02 14:34:23 -07:00

2240 lines
81 KiB
Python

"""
Routes for the new chat feature with assistant-ui integration.
These endpoints support the ThreadHistoryAdapter pattern from assistant-ui:
- GET /threads - List threads for sidebar (ThreadListPrimitive)
- POST /threads - Create a new thread
- GET /threads/{thread_id} - Get thread with messages (load)
- PUT /threads/{thread_id} - Update thread (rename, archive)
- DELETE /threads/{thread_id} - Delete thread
- POST /threads/{thread_id}/messages - Append message
"""
import asyncio
import json
import logging
from datetime import UTC, datetime
from fastapi import APIRouter, Depends, HTTPException, Request, Response
from fastapi.responses import StreamingResponse
from sqlalchemy import func, or_
from sqlalchemy.exc import IntegrityError, OperationalError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload
from app.agents.new_chat.filesystem_selection import (
ClientPlatform,
FilesystemMode,
FilesystemSelection,
LocalFilesystemMount,
)
from app.agents.new_chat.middleware.busy_mutex import (
get_cancel_state,
is_cancel_requested,
manager,
request_cancel,
)
from app.config import config
from app.db import (
ChatComment,
ChatVisibility,
NewChatMessage,
NewChatMessageRole,
NewChatThread,
Permission,
SearchSpace,
User,
get_async_session,
shielded_async_session,
)
from app.schemas.new_chat import (
AgentToolInfo,
CancelActiveTurnResponse,
LocalFilesystemMountPayload,
NewChatMessageRead,
NewChatRequest,
NewChatThreadCreate,
NewChatThreadRead,
NewChatThreadUpdate,
NewChatThreadVisibilityUpdate,
NewChatThreadWithMessages,
PublicChatSnapshotCreateResponse,
PublicChatSnapshotListResponse,
RegenerateRequest,
ResumeRequest,
ThreadHistoryLoadResponse,
ThreadListItem,
ThreadListResponse,
TokenUsageSummary,
TurnStatusResponse,
)
from app.services.token_tracking_service import record_token_usage
from app.tasks.chat.stream_new_chat import stream_new_chat, stream_resume_chat
from app.users import current_active_user
from app.utils.rbac import check_permission
from app.utils.user_message_multimodal import (
split_langchain_human_content,
split_persisted_user_content_parts,
)
_logger = logging.getLogger(__name__)
# Strong references to fire-and-forget tasks scheduled by this module
# (``_try_delete_sandbox``); each task discards itself via a done-callback.
_background_tasks: set[asyncio.Task] = set()
# Backoff schedule (milliseconds) for the ``retry_after_ms`` hint returned while
# a previous turn is still being cancelled: 200, 400, 800, then capped at 1500.
TURN_CANCELLING_INITIAL_DELAY_MS = 200
TURN_CANCELLING_BACKOFF_FACTOR = 2
TURN_CANCELLING_MAX_DELAY_MS = 1500
router = APIRouter()
def _resolve_filesystem_selection(
    *,
    mode: str,
    client_platform: str,
    local_mounts: list[LocalFilesystemMountPayload] | None,
) -> FilesystemSelection:
    """Turn the raw filesystem fields of a chat request into a ``FilesystemSelection``.

    Unknown ``mode`` / ``client_platform`` strings are rejected with HTTP 400.
    Any mode other than ``desktop_local_folder`` resolves to a cloud selection
    and ignores ``local_mounts``. Desktop-local mode additionally requires the
    deployment flag, a desktop client and at least one usable mount. Mounts
    with a blank id or path are dropped, and for a repeated mount id only the
    first entry is kept.
    """
    try:
        fs_mode = FilesystemMode(mode)
    except ValueError as exc:
        raise HTTPException(status_code=400, detail="Invalid filesystem_mode") from exc
    try:
        fs_platform = ClientPlatform(client_platform)
    except ValueError as exc:
        raise HTTPException(status_code=400, detail="Invalid client_platform") from exc

    if fs_mode != FilesystemMode.DESKTOP_LOCAL_FOLDER:
        return FilesystemSelection(
            mode=FilesystemMode.CLOUD,
            client_platform=fs_platform,
        )

    if not config.ENABLE_DESKTOP_LOCAL_FILESYSTEM:
        raise HTTPException(
            status_code=400,
            detail="Desktop local filesystem mode is disabled on this deployment.",
        )
    if fs_platform != ClientPlatform.DESKTOP:
        raise HTTPException(
            status_code=400,
            detail="desktop_local_folder mode is only available on desktop runtime.",
        )

    # dicts keep insertion order, so the first occurrence of each id wins.
    roots_by_mount: dict[str, str] = {}
    for payload in local_mounts or []:
        candidate_id = payload.mount_id.strip()
        candidate_root = payload.root_path.strip()
        if candidate_id and candidate_root and candidate_id not in roots_by_mount:
            roots_by_mount[candidate_id] = candidate_root

    if not roots_by_mount:
        raise HTTPException(
            status_code=400,
            detail=(
                "local_filesystem_mounts must include at least one mount for "
                "desktop_local_folder mode."
            ),
        )

    return FilesystemSelection(
        mode=fs_mode,
        client_platform=fs_platform,
        local_mounts=tuple(
            LocalFilesystemMount(mount_id=known_id, root_path=known_root)
            for known_id, known_root in roots_by_mount.items()
        ),
    )
def _compute_turn_cancelling_retry_delay(attempt: int) -> int:
    """Return the TURN_CANCELLING retry hint, in milliseconds, for ``attempt``.

    Attempts below 1 are treated as 1. The delay starts at
    ``TURN_CANCELLING_INITIAL_DELAY_MS``, is multiplied by
    ``TURN_CANCELLING_BACKOFF_FACTOR`` for each further attempt, and never
    exceeds ``TURN_CANCELLING_MAX_DELAY_MS``.
    """
    exponent = max(attempt, 1) - 1
    uncapped_ms = TURN_CANCELLING_INITIAL_DELAY_MS * (
        TURN_CANCELLING_BACKOFF_FACTOR**exponent
    )
    return min(uncapped_ms, TURN_CANCELLING_MAX_DELAY_MS)
def _build_turn_status_payload(thread_id: int) -> dict[str, object]:
    """Describe whether a turn is currently running on ``thread_id``.

    Returns ``{"status": "idle"}`` when the thread's busy lock is free and
    ``{"status": "busy"}`` when it is held. When it is held and a cancel has
    been requested, returns ``{"status": "cancelling"}`` plus ``retry_after_ms``
    and ``retry_after_at`` (an epoch timestamp in milliseconds).
    """
    key = str(thread_id)
    if not manager.lock_for(key).locked():
        return {"status": "idle"}
    if not is_cancel_requested(key):
        return {"status": "busy"}

    state = get_cancel_state(key)
    delay_ms = _compute_turn_cancelling_retry_delay(state[0] if state else 1)
    now_ms = int(datetime.now(UTC).timestamp() * 1000)
    return {
        "status": "cancelling",
        "retry_after_ms": delay_ms,
        "retry_after_at": now_ms + delay_ms,
    }
def _set_retry_after_headers(response: Response, retry_after_ms: int) -> None:
    """Attach retry hints to ``response``.

    Sets ``retry-after-ms`` to the exact millisecond value, and sets the
    standard ``Retry-After`` header to whole seconds, rounded up, with a
    minimum of 1.
    """
    rounded_up_seconds = (retry_after_ms + 999) // 1000
    headers = response.headers
    headers["retry-after-ms"] = str(retry_after_ms)
    headers["Retry-After"] = str(max(1, rounded_up_seconds))
def _raise_if_thread_busy_for_start(thread_id: int) -> None:
    """Reject starting a new turn while ``thread_id`` already has one in flight.

    Returns normally when the thread is idle. Otherwise it raises HTTP 409 with
    a structured ``detail``:

    - ``TURN_CANCELLING`` while a cancel is in progress. This carries retry
      hints, plus ``retry-after-ms`` / ``Retry-After`` headers when a positive
      delay is known.
    - ``THREAD_BUSY`` in every other case.
    """
    payload = _build_turn_status_payload(thread_id)
    state = payload["status"]
    if state == "idle":
        return

    if state != "cancelling":
        raise HTTPException(
            status_code=409,
            detail={
                "errorCode": "THREAD_BUSY",
                "message": "Another response is still finishing for this thread. Please try again in a moment.",
            },
        )

    delay_ms = int(payload.get("retry_after_ms") or 0)
    has_delay = delay_ms > 0
    hint_headers = None
    if has_delay:
        hint_headers = {
            "retry-after-ms": str(delay_ms),
            "Retry-After": str(max(1, (delay_ms + 999) // 1000)),
        }
    raise HTTPException(
        status_code=409,
        detail={
            "errorCode": "TURN_CANCELLING",
            "message": "A previous response is still stopping. Please try again in a moment.",
            "retry_after_ms": delay_ms if has_delay else None,
            "retry_after_at": payload.get("retry_after_at"),
        },
        headers=hint_headers,
    )
def _find_pre_turn_checkpoint_id(
checkpoint_tuples: list,
*,
turn_id: str,
) -> str | None:
"""Locate the LangGraph checkpoint immediately before ``turn_id`` started.
``checkpoint_tuples`` arrives newest-first from
``checkpointer.alist(config)``. We walk OLDEST-first (``reversed``)
and remember the most recent checkpoint that does NOT belong to the
edited turn. As soon as we cross into the edited turn (a checkpoint
whose ``turn_id`` matches), we return the previously-tracked
checkpoint — that's the state immediately before ``turn_id`` began.
The naive "newest-first, return first non-matching" approach is
INCORRECT when later turns exist after ``turn_id``: their
checkpoints also satisfy ``cp_turn_id != turn_id`` and would be
returned before the real pre-turn boundary is reached.
Reads from ``cp_tuple.metadata`` (the durable surface promoted from
``configurable`` at write time) rather than ``config["configurable"]``
so the lookup is portable across checkpointer implementations.
Returns ``None`` when no eligible pre-turn checkpoint exists (e.g.
the edited turn is the very first turn of the thread). Callers fall
back to the oldest available checkpoint in that case.
"""
last_pre_turn_target: str | None = None
for cp_tuple in reversed(checkpoint_tuples): # oldest -> newest
metadata = getattr(cp_tuple, "metadata", None) or {}
cp_turn_id = metadata.get("turn_id") if isinstance(metadata, dict) else None
if cp_turn_id == turn_id:
# Crossed into the edited turn; the previous tracked
# checkpoint is the rewind target. May be ``None`` if we hit
# the edited turn on the very first iteration.
return last_pre_turn_target
try:
last_pre_turn_target = cp_tuple.config["configurable"]["checkpoint_id"]
except (KeyError, TypeError):
continue
return last_pre_turn_target
async def _revert_turns_for_regenerate(
    *,
    thread_id: int,
    chat_turn_ids: list[str],
    requester_user_id: str,
) -> dict:
    """Best-effort revert pass for every ``chat_turn_id`` in ``chat_turn_ids``.

    Runs BEFORE the regenerate stream so the frontend can surface
    partial-rollback feedback alongside the new assistant turn. Each action
    row is reverted inside its own SAVEPOINT (``session.begin_nested()``), so
    a single failure never poisons the batch. Successful reverts are then
    persisted by one final commit.

    Sequencing inside the request: revert THEN regenerate. The operation is
    NOT atomic, and partial state IS surfaced.

    Args:
        thread_id: Thread whose ``AgentActionLog`` rows are reverted.
        chat_turn_ids: Turns to revert, processed in the given order. Within a
            turn, rows are visited newest first (``created_at``, then ``id``).
        requester_user_id: User performing the revert. Every row is checked
            with ``can_revert`` (never as admin).

    Returns:
        A ``data-revert-results`` payload with these keys:

        - ``status``: ``"partial"`` when any row failed, was not reversible,
          or was denied; otherwise ``"ok"``.
        - ``chat_turn_ids``: the input list.
        - ``total``: the number of results.
        - one counter per outcome; ``total`` always equals their sum.
        - ``results``: the serialized per-row results.

        If the final commit fails, rows previously reported as ``reverted``
        are re-reported as ``failed``, because the rollback discarded them.
    """
    from app.routes.agent_revert_route import (
        RevertTurnActionResult,
        _classify_outcome,
        _OutcomeRollbackError,
        _was_already_reverted,
        _was_already_reverted_batch,
    )
    from app.services.revert_service import (
        can_revert,
        revert_action,
    )

    # Local import keeps the route module's existing imports tidy and
    # avoids a circular dependency at module-load time.
    from app.db import AgentActionLog as _AgentActionLog

    aggregated_results: list[dict] = []
    # Exhaustive counters keep the response invariant
    # ``total == sum(counters)`` true for ``data-revert-results``.
    counts = {
        "reverted": 0,
        "already_reverted": 0,
        "not_reversible": 0,
        "permission_denied": 0,
        "failed": 0,
        "skipped": 0,
    }
    # (index into aggregated_results, action id, tool name) for every row
    # reported as "reverted". Primitive values are captured up front because
    # ORM attributes expire after a rollback.
    reverted_entries: list[tuple] = []

    def _record(counter: str, **result_fields) -> None:
        """Count one outcome under ``counter`` and append its serialized result."""
        counts[counter] += 1
        aggregated_results.append(
            RevertTurnActionResult(**result_fields).model_dump()
        )

    async with shielded_async_session() as session:
        for chat_turn_id in chat_turn_ids:
            rows_stmt = (
                select(_AgentActionLog)
                .where(
                    _AgentActionLog.thread_id == thread_id,
                    _AgentActionLog.chat_turn_id == chat_turn_id,
                )
                .order_by(
                    _AgentActionLog.created_at.desc(),
                    _AgentActionLog.id.desc(),
                )
            )
            rows = (await session.execute(rows_stmt)).scalars().all()
            # Batch idempotency probe across the turn (single SELECT
            # instead of one per row).
            eligible_ids = [r.id for r in rows if r.reverse_of is None]
            already_reverted_map = await _was_already_reverted_batch(
                session, action_ids=eligible_ids
            )
            for action in rows:
                if action.reverse_of is not None:
                    _record(
                        "skipped",
                        action_id=action.id,
                        tool_name=action.tool_name,
                        status="skipped",
                        message="Row is itself a revert action; skipped.",
                    )
                    continue
                existing_revert_id = already_reverted_map.get(action.id)
                if existing_revert_id is not None:
                    _record(
                        "already_reverted",
                        action_id=action.id,
                        tool_name=action.tool_name,
                        status="already_reverted",
                        new_action_id=existing_revert_id,
                    )
                    continue
                if not can_revert(
                    requester_user_id=requester_user_id,
                    action=action,
                    is_admin=False,
                ):
                    _record(
                        "permission_denied",
                        action_id=action.id,
                        tool_name=action.tool_name,
                        status="permission_denied",
                        message="You are not allowed to revert this action.",
                    )
                    continue
                try:
                    async with session.begin_nested():
                        outcome = await revert_action(
                            session,
                            action=action,
                            requester_user_id=requester_user_id,
                        )
                        if outcome.status != "ok":
                            # Raising inside the savepoint rolls back any
                            # partial writes made by revert_action.
                            raise _OutcomeRollbackError(outcome)
                except _OutcomeRollbackError as rollback:
                    outcome = rollback.outcome
                    classified = _classify_outcome(outcome)
                    _record(
                        "permission_denied"
                        if classified == "permission_denied"
                        else "not_reversible",
                        action_id=action.id,
                        tool_name=action.tool_name,
                        status=classified,
                        message=outcome.message,
                    )
                    continue
                except IntegrityError:
                    # A concurrent revert won the race against the pre-flight
                    # ``_was_already_reverted_batch`` SELECT. Surface the
                    # winning revert id so the client can treat this as a
                    # successful idempotent op.
                    existing_revert_id = await _was_already_reverted(
                        session, action_id=action.id
                    )
                    _record(
                        "already_reverted",
                        action_id=action.id,
                        tool_name=action.tool_name,
                        status="already_reverted",
                        new_action_id=existing_revert_id,
                    )
                    continue
                except Exception as err:  # pragma: no cover — defensive
                    _logger.exception(
                        "Unexpected revert failure during regenerate batch "
                        "for action_id=%s",
                        action.id,
                    )
                    _record(
                        "failed",
                        action_id=action.id,
                        tool_name=action.tool_name,
                        status="failed",
                        error=str(err) or err.__class__.__name__,
                    )
                    continue
                reverted_entries.append(
                    (len(aggregated_results), action.id, action.tool_name)
                )
                _record(
                    "reverted",
                    action_id=action.id,
                    tool_name=action.tool_name,
                    status="reverted",
                    message=outcome.message,
                    new_action_id=outcome.new_action_id,
                )
        try:
            await session.commit()
        except Exception:
            _logger.exception(
                "[regenerate-revert] Final commit failed; rolling back batch."
            )
            await session.rollback()
            # The rollback discards every savepoint that succeeded above, so
            # those rows must not be reported as reverted.
            # NOTE(review): side effects of ``revert_action`` outside this DB
            # session (if any) are not undone by the rollback — confirm.
            for position, action_id, tool_name in reverted_entries:
                aggregated_results[position] = RevertTurnActionResult(
                    action_id=action_id,
                    tool_name=tool_name,
                    status="failed",
                    error="Revert could not be saved; the batch was rolled back.",
                ).model_dump()
            counts["reverted"] -= len(reverted_entries)
            counts["failed"] += len(reverted_entries)

    has_partial = (
        counts["failed"] > 0
        or counts["not_reversible"] > 0
        or counts["permission_denied"] > 0
    )
    return {
        "status": "partial" if has_partial else "ok",
        "chat_turn_ids": chat_turn_ids,
        "total": len(aggregated_results),
        "reverted": counts["reverted"],
        "already_reverted": counts["already_reverted"],
        "not_reversible": counts["not_reversible"],
        "permission_denied": counts["permission_denied"],
        "failed": counts["failed"],
        "skipped": counts["skipped"],
        "results": aggregated_results,
    }
def _try_delete_sandbox(thread_id: int) -> None:
    """Schedule best-effort deletion of a thread's sandbox and its local files.

    Returns immediately. The work runs as a background task on the current
    event loop, so the HTTP response isn't blocked. Does nothing when the
    sandbox feature is disabled. Failures are logged and never raised.
    """
    from app.agents.new_chat.sandbox import (
        delete_local_sandbox_files,
        delete_sandbox,
        is_sandbox_enabled,
    )

    if not is_sandbox_enabled():
        return

    async def _bg() -> None:
        """Delete the sandbox, then the local files; each step fails independently."""
        try:
            await delete_sandbox(thread_id)
        except Exception:
            _logger.warning(
                "Background sandbox delete failed for thread %s",
                thread_id,
                exc_info=True,
            )
        try:
            # The local cleanup is a synchronous call. Run it in a worker
            # thread so it does not stall the event loop serving other requests.
            await asyncio.to_thread(delete_local_sandbox_files, thread_id)
        except Exception:
            _logger.warning(
                "Local sandbox file cleanup failed for thread %s",
                thread_id,
                exc_info=True,
            )

    try:
        loop = asyncio.get_running_loop()
        task = loop.create_task(_bg())
        # Hold a strong reference until the task finishes, so it is not
        # garbage-collected mid-flight.
        _background_tasks.add(task)
        task.add_done_callback(_background_tasks.discard)
    except RuntimeError:
        # No usable running event loop. Cleanup stays best-effort, but the
        # skip is now visible in the logs instead of being silently dropped.
        _logger.warning(
            "Could not schedule sandbox cleanup for thread %s (no running event loop)",
            thread_id,
            exc_info=True,
        )
async def check_thread_access(
    session: AsyncSession,
    thread: NewChatThread,
    user: User,
    require_ownership: bool = False,
) -> bool:
    """
    Decide whether ``user`` may act on ``thread``, raising HTTP 403 if not.

    Rules, evaluated in this order:

    1. ``require_ownership=True`` (e.g. changing visibility): only the
       recorded creator passes. No other rule is consulted.
    2. SEARCH_SPACE-visible threads: every caller passes.
    3. Private threads with a recorded creator: only that creator passes.
    4. Legacy threads (``created_by_id`` is NULL): only the owner of the
       thread's search space passes.

    Args:
        session: Database session (used only for the legacy-owner lookup).
        thread: The thread being accessed.
        user: The user requesting access.
        require_ownership: Restrict the action to the thread's creator.

    Returns:
        True when access is granted.

    Raises:
        HTTPException: 403 when access is denied.
    """
    created_by_user = thread.created_by_id == user.id

    # Ownership-gated operations short-circuit every visibility rule below.
    if require_ownership:
        if created_by_user:
            return True
        raise HTTPException(
            status_code=403,
            detail="Only the creator of this chat can perform this action",
        )

    if thread.visibility == ChatVisibility.SEARCH_SPACE:
        return True

    if thread.created_by_id is not None:
        # Private thread with a known creator.
        if created_by_user:
            return True
        raise HTTPException(
            status_code=403,
            detail="You don't have access to this private chat",
        )

    # Legacy thread created before the visibility feature: the search
    # space owner is the only one who may open it.
    owner_lookup = await session.execute(
        select(SearchSpace).filter(SearchSpace.id == thread.search_space_id)
    )
    space = owner_lookup.scalar_one_or_none()
    if space and space.user_id == user.id:
        return True
    raise HTTPException(
        status_code=403,
        detail="You don't have access to this chat",
    )
# =============================================================================
# Thread Endpoints
# =============================================================================
@router.get("/threads", response_model=ThreadListResponse)
async def list_threads(
    search_space_id: int,
    limit: int | None = None,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """
    List all accessible threads for the current user in a search space.

    Returns ``threads`` and ``archived_threads`` for ThreadListPrimitive, each
    ordered by ``updated_at``, newest first.

    A user can see threads that are:
    - Created by them (regardless of visibility)
    - Shared with the search space (visibility = SEARCH_SPACE)
    - Legacy threads with no creator (created_by_id is NULL) - only if the
      user is the search space owner

    Args:
        search_space_id: The search space to list threads for.
        limit: Optional cap on the number of threads returned. It applies to
            active threads only; ``None`` or a value <= 0 means no cap.

    Raises:
        HTTPException: Permission errors from ``check_permission`` propagate
            unchanged. Database operational errors become 503, and any other
            unexpected error becomes 500. Both 503 and 500 are logged with
            the traceback.

    Requires CHATS_READ permission.
    """
    try:
        await check_permission(
            session,
            user,
            search_space_id,
            Permission.CHATS_READ.value,
            "You don't have permission to read chats in this search space",
        )
        # Check if user is the search space owner (for legacy thread visibility)
        search_space_query = select(SearchSpace).filter(
            SearchSpace.id == search_space_id
        )
        search_space_result = await session.execute(search_space_query)
        search_space = search_space_result.scalar_one_or_none()
        is_search_space_owner = search_space and search_space.user_id == user.id
        # Build filter conditions:
        # 1. Created by the current user (any visibility)
        # 2. Shared with the search space (visibility = SEARCH_SPACE)
        # 3. Legacy threads (created_by_id is NULL) - only visible to search space owner
        filter_conditions = [
            NewChatThread.created_by_id == user.id,
            NewChatThread.visibility == ChatVisibility.SEARCH_SPACE,
        ]
        # Only include legacy threads for the search space owner
        if is_search_space_owner:
            filter_conditions.append(NewChatThread.created_by_id.is_(None))
        query = (
            select(NewChatThread)
            .filter(
                NewChatThread.search_space_id == search_space_id,
                or_(*filter_conditions),
            )
            .order_by(NewChatThread.updated_at.desc())
        )
        result = await session.execute(query)
        all_threads = result.scalars().all()
        # Separate active and archived threads
        threads = []
        archived_threads = []
        for thread in all_threads:
            # Legacy threads (no creator) are treated as own threads for owner
            is_own_thread = thread.created_by_id == user.id or (
                thread.created_by_id is None and is_search_space_owner
            )
            item = ThreadListItem(
                id=thread.id,
                title=thread.title,
                archived=thread.archived,
                visibility=thread.visibility,
                created_by_id=thread.created_by_id,
                is_own_thread=is_own_thread,
                created_at=thread.created_at,
                updated_at=thread.updated_at,
            )
            if thread.archived:
                archived_threads.append(item)
            else:
                threads.append(item)
        # Apply limit to active threads if specified
        if limit is not None and limit > 0:
            threads = threads[:limit]
        return ThreadListResponse(threads=threads, archived_threads=archived_threads)
    except HTTPException:
        raise
    except OperationalError:
        # ``from None`` below drops the cause, so log it here or it is lost.
        _logger.exception(
            "Database error while listing threads for search space %s",
            search_space_id,
        )
        raise HTTPException(
            status_code=503, detail="Database operation failed. Please try again later."
        ) from None
    except Exception as e:
        _logger.exception(
            "Unexpected error while listing threads for search space %s",
            search_space_id,
        )
        raise HTTPException(
            status_code=500,
            detail=f"An unexpected error occurred while fetching threads: {e!s}",
        ) from None
@router.get("/threads/search", response_model=list[ThreadListItem])
async def search_threads(
    search_space_id: int,
    title: str,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """
    Search accessible threads by title in a search space.

    A user can search threads that are:
    - Created by them (regardless of visibility)
    - Shared with the search space (visibility = SEARCH_SPACE)
    - Legacy threads with no creator (created_by_id is NULL) - only if the
      user is the search space owner

    Args:
        search_space_id: The search space to search in.
        title: The search text. It is matched as a case-insensitive
            substring. LIKE metacharacters (``%``, ``_``, ``\\``) in it are
            matched literally.

    Returns:
        Matching threads, ordered by ``updated_at``, newest first.

    Raises:
        HTTPException: Permission errors propagate unchanged. Database
            operational errors become 503, and any other unexpected error
            becomes 500. Both 503 and 500 are logged with the traceback.

    Requires CHATS_READ permission.
    """
    try:
        await check_permission(
            session,
            user,
            search_space_id,
            Permission.CHATS_READ.value,
            "You don't have permission to read chats in this search space",
        )
        # Check if user is the search space owner (for legacy thread visibility)
        search_space_query = select(SearchSpace).filter(
            SearchSpace.id == search_space_id
        )
        search_space_result = await session.execute(search_space_query)
        search_space = search_space_result.scalar_one_or_none()
        is_search_space_owner = search_space and search_space.user_id == user.id
        # Build filter conditions
        filter_conditions = [
            NewChatThread.created_by_id == user.id,
            NewChatThread.visibility == ChatVisibility.SEARCH_SPACE,
        ]
        # Only include legacy threads for the search space owner
        if is_search_space_owner:
            filter_conditions.append(NewChatThread.created_by_id.is_(None))
        # Escape LIKE metacharacters so input such as "50%" or "a_b" is matched
        # literally instead of acting as a wildcard pattern.
        escaped_title = (
            title.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
        )
        # Search accessible threads by title (case-insensitive)
        query = (
            select(NewChatThread)
            .filter(
                NewChatThread.search_space_id == search_space_id,
                NewChatThread.title.ilike(f"%{escaped_title}%", escape="\\"),
                or_(*filter_conditions),
            )
            .order_by(NewChatThread.updated_at.desc())
        )
        result = await session.execute(query)
        threads = result.scalars().all()
        return [
            ThreadListItem(
                id=thread.id,
                title=thread.title,
                archived=thread.archived,
                visibility=thread.visibility,
                created_by_id=thread.created_by_id,
                # Legacy threads (no creator) are treated as own threads for owner
                is_own_thread=(
                    thread.created_by_id == user.id
                    or (thread.created_by_id is None and is_search_space_owner)
                ),
                created_at=thread.created_at,
                updated_at=thread.updated_at,
            )
            for thread in threads
        ]
    except HTTPException:
        raise
    except OperationalError:
        _logger.exception(
            "Database error while searching threads in search space %s",
            search_space_id,
        )
        raise HTTPException(
            status_code=503, detail="Database operation failed. Please try again later."
        ) from None
    except Exception as e:
        _logger.exception(
            "Unexpected error while searching threads in search space %s",
            search_space_id,
        )
        raise HTTPException(
            status_code=500,
            detail=f"An unexpected error occurred while searching threads: {e!s}",
        ) from None
@router.post("/threads", response_model=NewChatThreadRead)
async def create_thread(
    thread: NewChatThreadCreate,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """
    Create a new chat thread.

    The thread takes its title, archived flag and visibility from the request
    payload. The current user is recorded as its creator, and ``updated_at``
    is set to now (UTC).

    Raises:
        HTTPException: Permission errors propagate unchanged. Constraint
            violations become 400, database operational errors become 503,
            and anything else becomes 500. The session is rolled back on
            each of these; 503 and 500 are logged with the traceback.

    Requires CHATS_CREATE permission.
    """
    try:
        await check_permission(
            session,
            user,
            thread.search_space_id,
            Permission.CHATS_CREATE.value,
            "You don't have permission to create chats in this search space",
        )
        now = datetime.now(UTC)
        db_thread = NewChatThread(
            title=thread.title,
            archived=thread.archived,
            visibility=thread.visibility,
            search_space_id=thread.search_space_id,
            created_by_id=user.id,
            updated_at=now,
        )
        session.add(db_thread)
        await session.commit()
        await session.refresh(db_thread)
        return db_thread
    except HTTPException:
        raise
    except IntegrityError:
        await session.rollback()
        raise HTTPException(
            status_code=400,
            detail="Database constraint violation. Please check your input data.",
        ) from None
    except OperationalError:
        _logger.exception(
            "Database error while creating a thread in search space %s",
            thread.search_space_id,
        )
        await session.rollback()
        raise HTTPException(
            status_code=503, detail="Database operation failed. Please try again later."
        ) from None
    except Exception as e:
        _logger.exception(
            "Unexpected error while creating a thread in search space %s",
            thread.search_space_id,
        )
        await session.rollback()
        raise HTTPException(
            status_code=500,
            detail=f"An unexpected error occurred while creating the thread: {e!s}",
        ) from None
@router.get("/threads/{thread_id}", response_model=ThreadHistoryLoadResponse)
async def get_thread_messages(
    thread_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """
    Get a thread's messages, oldest first.

    This is used by ThreadHistoryAdapter.load() to restore a conversation.
    Each message includes author display data and token usage when present.

    Access is granted if (see ``check_thread_access``):
    - User is the creator of the thread
    - Thread visibility is SEARCH_SPACE
    - Thread is a legacy thread (no creator) and the user owns the search space

    Raises:
        HTTPException: 404 if the thread does not exist. Permission and
            access errors propagate unchanged. Database operational errors
            become 503, and anything else becomes 500; both are logged with
            the traceback.

    Requires CHATS_READ permission.
    """
    try:
        # Get thread first
        result = await session.execute(
            select(NewChatThread).filter(NewChatThread.id == thread_id)
        )
        thread = result.scalars().first()
        if not thread:
            raise HTTPException(status_code=404, detail="Thread not found")
        # Check permission to read chats in this search space
        await check_permission(
            session,
            user,
            thread.search_space_id,
            Permission.CHATS_READ.value,
            "You don't have permission to read chats in this search space",
        )
        # Check thread-level access based on visibility
        await check_thread_access(session, thread, user)
        # Eager-load authors and token usage; lazy loads are not available on
        # an AsyncSession.
        messages_result = await session.execute(
            select(NewChatMessage)
            .options(
                selectinload(NewChatMessage.author),
                selectinload(NewChatMessage.token_usage),
            )
            .filter(NewChatMessage.thread_id == thread_id)
            .order_by(NewChatMessage.created_at)
        )
        db_messages = messages_result.scalars().all()
        # Return messages in the format expected by assistant-ui
        messages = [
            NewChatMessageRead(
                id=msg.id,
                thread_id=msg.thread_id,
                role=msg.role,
                content=msg.content,
                created_at=msg.created_at,
                author_id=msg.author_id,
                author_display_name=msg.author.display_name if msg.author else None,
                author_avatar_url=msg.author.avatar_url if msg.author else None,
                token_usage=TokenUsageSummary.model_validate(msg.token_usage)
                if msg.token_usage
                else None,
                turn_id=msg.turn_id,
            )
            for msg in db_messages
        ]
        return ThreadHistoryLoadResponse(messages=messages)
    except HTTPException:
        raise
    except OperationalError:
        _logger.exception("Database error while loading thread %s", thread_id)
        raise HTTPException(
            status_code=503, detail="Database operation failed. Please try again later."
        ) from None
    except Exception as e:
        _logger.exception("Unexpected error while loading thread %s", thread_id)
        raise HTTPException(
            status_code=500,
            detail=f"An unexpected error occurred while fetching the thread: {e!s}",
        ) from None
@router.get("/threads/{thread_id}/full", response_model=NewChatThreadWithMessages)
async def get_thread_full(
    thread_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """
    Get full thread details with all messages.

    The response is the thread's columns, plus its messages (with token usage
    eager-loaded), plus ``has_comments``. ``has_comments`` is true when any
    message in the thread has at least one ``ChatComment``.

    Access is granted if (see ``check_thread_access``):
    - User is the creator of the thread
    - Thread visibility is SEARCH_SPACE
    - Thread is a legacy thread (no creator) and the user owns the search space

    Raises:
        HTTPException: 404 if the thread does not exist. Permission and
            access errors propagate unchanged. Database operational errors
            become 503, and anything else becomes 500; both are logged with
            the traceback.

    Requires CHATS_READ permission.
    """
    try:
        result = await session.execute(
            select(NewChatThread)
            .options(
                selectinload(NewChatThread.messages).selectinload(
                    NewChatMessage.token_usage
                ),
            )
            .filter(NewChatThread.id == thread_id)
        )
        thread = result.scalars().first()
        if not thread:
            raise HTTPException(status_code=404, detail="Thread not found")
        await check_permission(
            session,
            user,
            thread.search_space_id,
            Permission.CHATS_READ.value,
            "You don't have permission to read chats in this search space",
        )
        # Check thread-level access based on visibility
        await check_thread_access(session, thread, user)
        # Count comments attached to any message of this thread
        comment_count = await session.scalar(
            select(func.count())
            .select_from(ChatComment)
            .join(NewChatMessage, ChatComment.message_id == NewChatMessage.id)
            .where(NewChatMessage.thread_id == thread.id)
        )
        return {
            **thread.__dict__,
            "messages": thread.messages,
            "has_comments": (comment_count or 0) > 0,
        }
    except HTTPException:
        raise
    except OperationalError:
        _logger.exception("Database error while loading full thread %s", thread_id)
        raise HTTPException(
            status_code=503, detail="Database operation failed. Please try again later."
        ) from None
    except Exception as e:
        _logger.exception("Unexpected error while loading full thread %s", thread_id)
        raise HTTPException(
            status_code=500,
            detail=f"An unexpected error occurred while fetching the thread: {e!s}",
        ) from None
@router.put("/threads/{thread_id}", response_model=NewChatThreadRead)
async def update_thread(
    thread_id: int,
    thread_update: NewChatThreadUpdate,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """
    Update a thread (title, archived status).

    Used for renaming and archiving threads. Only fields explicitly set in
    the payload are written, and ``updated_at`` is bumped to now (UTC).

    - PRIVATE threads: only the creator can update. For a legacy private
      thread with no recorded creator, only the search space owner can.
    - SEARCH_SPACE threads: any member with CHATS_UPDATE permission can update.

    Raises:
        HTTPException: 404 if the thread does not exist. Permission and
            access errors propagate unchanged. Constraint violations become
            400, database operational errors 503, and anything else 500.
            The session is rolled back on each of these; 503 and 500 are
            logged with the traceback.

    Requires CHATS_UPDATE permission.
    """
    try:
        result = await session.execute(
            select(NewChatThread).filter(NewChatThread.id == thread_id)
        )
        db_thread = result.scalars().first()
        if not db_thread:
            raise HTTPException(status_code=404, detail="Thread not found")
        await check_permission(
            session,
            user,
            db_thread.search_space_id,
            Permission.CHATS_UPDATE.value,
            "You don't have permission to update chats in this search space",
        )
        # PRIVATE threads are creator-only. Legacy threads (created_by_id is
        # NULL) have no creator to match, so, as in delete_thread, we skip
        # strict ownership. check_thread_access then applies its legacy rule,
        # which only admits the search space owner.
        # SEARCH_SPACE threads: any member with permission can update.
        if db_thread.visibility == ChatVisibility.PRIVATE:
            await check_thread_access(
                session,
                db_thread,
                user,
                require_ownership=(db_thread.created_by_id is not None),
            )
        # Update only the fields the client actually sent
        update_data = thread_update.model_dump(exclude_unset=True)
        for key, value in update_data.items():
            setattr(db_thread, key, value)
        db_thread.updated_at = datetime.now(UTC)
        await session.commit()
        await session.refresh(db_thread)
        return db_thread
    except HTTPException:
        raise
    except IntegrityError:
        await session.rollback()
        raise HTTPException(
            status_code=400,
            detail="Database constraint violation. Please check your input data.",
        ) from None
    except OperationalError:
        _logger.exception("Database error while updating thread %s", thread_id)
        await session.rollback()
        raise HTTPException(
            status_code=503, detail="Database operation failed. Please try again later."
        ) from None
    except Exception as e:
        _logger.exception("Unexpected error while updating thread %s", thread_id)
        await session.rollback()
        raise HTTPException(
            status_code=500,
            detail=f"An unexpected error occurred while updating the thread: {e!s}",
        ) from None
@router.delete("/threads/{thread_id}", response_model=dict)
async def delete_thread(
    thread_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """
    Delete a thread and all its messages.

    After a successful commit, the thread's sandbox cleanup is scheduled in
    the background via ``_try_delete_sandbox``.

    - PRIVATE threads: only the creator can delete. For a legacy private
      thread with no recorded creator, only the search space owner can.
    - SEARCH_SPACE threads: any member with CHATS_DELETE permission can delete.

    Raises:
        HTTPException: 404 if the thread does not exist. Permission and
            access errors propagate unchanged. Integrity errors become 400,
            database operational errors 503, and anything else 500. The
            session is rolled back on each of these; 503 and 500 are logged
            with the traceback.

    Requires CHATS_DELETE permission.
    """
    try:
        result = await session.execute(
            select(NewChatThread).filter(NewChatThread.id == thread_id)
        )
        db_thread = result.scalars().first()
        if not db_thread:
            raise HTTPException(status_code=404, detail="Thread not found")
        await check_permission(
            session,
            user,
            db_thread.search_space_id,
            Permission.CHATS_DELETE.value,
            "You don't have permission to delete chats in this search space",
        )
        # For PRIVATE threads, only the creator can delete.
        # For SEARCH_SPACE threads, any member with permission can delete.
        # Legacy threads (created_by_id is NULL) have no recorded creator, so
        # we skip strict ownership. The legacy handling in check_thread_access
        # then lets only the search space owner delete them.
        if db_thread.visibility == ChatVisibility.PRIVATE:
            await check_thread_access(
                session,
                db_thread,
                user,
                require_ownership=(db_thread.created_by_id is not None),
            )
        await session.delete(db_thread)
        await session.commit()
        _try_delete_sandbox(thread_id)
        return {"message": "Thread deleted successfully"}
    except HTTPException:
        raise
    except IntegrityError:
        await session.rollback()
        raise HTTPException(
            status_code=400, detail="Cannot delete thread due to existing dependencies."
        ) from None
    except OperationalError:
        _logger.exception("Database error while deleting thread %s", thread_id)
        await session.rollback()
        raise HTTPException(
            status_code=503, detail="Database operation failed. Please try again later."
        ) from None
    except Exception as e:
        _logger.exception("Unexpected error while deleting thread %s", thread_id)
        await session.rollback()
        raise HTTPException(
            status_code=500,
            detail=f"An unexpected error occurred while deleting the thread: {e!s}",
        ) from None
@router.patch("/threads/{thread_id}/visibility", response_model=NewChatThreadRead)
async def update_thread_visibility(
    thread_id: int,
    visibility_update: NewChatThreadVisibilityUpdate,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """
    Update the visibility/sharing settings of a thread.

    Only the recorded creator of the thread can change its visibility. Legacy
    threads without a creator therefore cannot be changed. ``updated_at`` is
    bumped to now (UTC).

    - PRIVATE: Only the creator can access the thread (default)
    - SEARCH_SPACE: All members of the search space can access the thread

    Raises:
        HTTPException: 404 if the thread does not exist. Permission and
            ownership errors propagate unchanged. Constraint violations
            become 400, database operational errors 503, and anything else
            500. The session is rolled back on each of these; 503 and 500
            are logged with the traceback.

    Requires CHATS_UPDATE permission.
    """
    try:
        result = await session.execute(
            select(NewChatThread).filter(NewChatThread.id == thread_id)
        )
        db_thread = result.scalars().first()
        if not db_thread:
            raise HTTPException(status_code=404, detail="Thread not found")
        await check_permission(
            session,
            user,
            db_thread.search_space_id,
            Permission.CHATS_UPDATE.value,
            "You don't have permission to update chats in this search space",
        )
        # Only the creator can change visibility
        await check_thread_access(session, db_thread, user, require_ownership=True)
        # Update visibility
        db_thread.visibility = visibility_update.visibility
        db_thread.updated_at = datetime.now(UTC)
        await session.commit()
        await session.refresh(db_thread)
        return db_thread
    except HTTPException:
        raise
    except IntegrityError:
        await session.rollback()
        raise HTTPException(
            status_code=400,
            detail="Database constraint violation. Please check your input data.",
        ) from None
    except OperationalError:
        _logger.exception(
            "Database error while updating visibility of thread %s", thread_id
        )
        await session.rollback()
        raise HTTPException(
            status_code=503, detail="Database operation failed. Please try again later."
        ) from None
    except Exception as e:
        _logger.exception(
            "Unexpected error while updating visibility of thread %s", thread_id
        )
        await session.rollback()
        raise HTTPException(
            status_code=500,
            detail=f"An unexpected error occurred while updating thread visibility: {e!s}",
        ) from None
# =============================================================================
# Snapshot Endpoints
# =============================================================================
@router.post(
    "/threads/{thread_id}/snapshots", response_model=PublicChatSnapshotCreateResponse
)
async def create_thread_snapshot(
    thread_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """
    Publish a public, read-only snapshot of a thread.

    All work is delegated to ``public_chat_service.create_snapshot``, which
    returns the existing snapshot URL when the content is unchanged
    (deduplication).
    """
    # Imported at call time rather than module import time (presumably to
    # avoid an import cycle; unverified).
    from app.services.public_chat_service import create_snapshot as _create

    snapshot = await _create(session=session, thread_id=thread_id, user=user)
    return snapshot
@router.get(
    "/threads/{thread_id}/snapshots", response_model=PublicChatSnapshotListResponse
)
async def list_thread_snapshots(
    thread_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """
    Return every public snapshot that exists for this thread.

    Only the thread owner may view snapshots. This route performs no checks
    of its own; the rule is expected to be enforced by
    ``public_chat_service.list_snapshots_for_thread``.
    """
    from app.services.public_chat_service import list_snapshots_for_thread

    snapshots = await list_snapshots_for_thread(
        session=session,
        thread_id=thread_id,
        user=user,
    )
    return PublicChatSnapshotListResponse(snapshots=snapshots)
@router.delete("/threads/{thread_id}/snapshots/{snapshot_id}")
async def delete_thread_snapshot(
    thread_id: int,
    snapshot_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """
    Remove one public snapshot of a thread.

    Only the thread owner may delete snapshots. This route performs no checks
    of its own; the rule is expected to be enforced by
    ``public_chat_service.delete_snapshot``. On success a fixed confirmation
    payload is returned.
    """
    from app.services.public_chat_service import delete_snapshot as _delete_snapshot

    await _delete_snapshot(
        session=session,
        thread_id=thread_id,
        snapshot_id=snapshot_id,
        user=user,
    )
    confirmation = {"message": "Snapshot deleted successfully"}
    return confirmation
# =============================================================================
# Message Endpoints
# =============================================================================
@router.post("/threads/{thread_id}/messages", response_model=NewChatMessageRead)
async def append_message(
    thread_id: int,
    request: Request,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """
    Append a message to a thread.

    This is used by ThreadHistoryAdapter.append() to persist messages.

    The JSON body is parsed manually. Only these fields are read; anything
    else is ignored:

    - ``role`` (required): ``user``, ``assistant`` or ``system`` (case-insensitive)
    - ``content`` (required): stored as-is
    - ``turn_id`` (optional): per-turn correlation id; blank values are dropped
    - ``token_usage`` (optional JSON object): recorded only for assistant messages

    Access is granted if:
    - User is the creator of the thread
    - Thread visibility is SEARCH_SPACE

    Requires CHATS_UPDATE permission.

    Raises:
        HTTPException 400: the body is not valid JSON or not a JSON object,
            ``role``/``content`` is missing or ``role`` is invalid,
            ``token_usage`` is not an object, or a DB constraint is violated.
        HTTPException 404: the thread does not exist.
        HTTPException 503: operational database failure.
        HTTPException 500: any other unexpected error.
    """
    try:
        # Malformed JSON or a non-object body is a client error, not a 500.
        try:
            raw_body = await request.json()
        except ValueError:
            # json.JSONDecodeError and UnicodeDecodeError both subclass ValueError.
            raise HTTPException(
                status_code=400, detail="Request body must be valid JSON."
            ) from None
        if not isinstance(raw_body, dict):
            raise HTTPException(
                status_code=400, detail="Request body must be a JSON object."
            )
        role = raw_body.get("role")
        content = raw_body.get("content")
        if not role:
            raise HTTPException(status_code=400, detail="Missing required field: role")
        if content is None:
            raise HTTPException(
                status_code=400, detail="Missing required field: content"
            )
        # Validate role early (before any DB work)
        role_str = role.lower() if isinstance(role, str) else role
        try:
            message_role = NewChatMessageRole(role_str)
        except ValueError:
            raise HTTPException(
                status_code=400,
                detail=f"Invalid role: {role}. Must be 'user', 'assistant', or 'system'.",
            ) from None
        # token_usage is only consumed for assistant messages. Validate its
        # shape up front so a malformed payload cannot fail halfway through
        # the write with an AttributeError (previously a 500).
        token_usage_data = raw_body.get("token_usage")
        if (
            token_usage_data
            and message_role == NewChatMessageRole.ASSISTANT
            and not isinstance(token_usage_data, dict)
        ):
            raise HTTPException(
                status_code=400, detail="token_usage must be a JSON object."
            )
        # Get thread
        result = await session.execute(
            select(NewChatThread).filter(NewChatThread.id == thread_id)
        )
        thread = result.scalars().first()
        if not thread:
            raise HTTPException(status_code=404, detail="Thread not found")
        await check_permission(
            session,
            user,
            thread.search_space_id,
            Permission.CHATS_UPDATE.value,
            "You don't have permission to update chats in this search space",
        )
        # Check thread-level access based on visibility
        await check_thread_access(session, thread, user)
        # Create message. ``turn_id`` is the per-turn correlation id from
        # ``configurable.turn_id`` (added in migration 136). When the
        # client streams it back to ``appendMessage``, we persist it so
        # C1's edit-from-arbitrary-position can later map this message
        # back to the LangGraph checkpoint that produced its turn.
        raw_turn_id = raw_body.get("turn_id")
        turn_id_value = (
            raw_turn_id.strip()
            if isinstance(raw_turn_id, str) and raw_turn_id.strip()
            else None
        )
        db_message = NewChatMessage(
            thread_id=thread_id,
            role=message_role,
            content=content,
            author_id=user.id,
            turn_id=turn_id_value,
        )
        session.add(db_message)
        # Update thread's updated_at timestamp
        thread.updated_at = datetime.now(UTC)
        # flush assigns the PK/defaults without a round-trip SELECT
        await session.flush()
        # Persist token usage if provided (for assistant messages).
        # ``cost_micros`` is the provider USD cost reported by LiteLLM,
        # forwarded by the FE through the appendMessage round-trip so
        # the historical TokenUsage row matches the credit debit applied
        # at finalize time.
        if token_usage_data and message_role == NewChatMessageRole.ASSISTANT:
            await record_token_usage(
                session,
                usage_type="chat",
                search_space_id=thread.search_space_id,
                user_id=user.id,
                prompt_tokens=token_usage_data.get("prompt_tokens", 0),
                completion_tokens=token_usage_data.get("completion_tokens", 0),
                total_tokens=token_usage_data.get("total_tokens", 0),
                cost_micros=token_usage_data.get("cost_micros", 0),
                model_breakdown=token_usage_data.get("usage"),
                call_details=token_usage_data.get("call_details"),
                thread_id=thread_id,
                message_id=db_message.id,
            )
        await session.commit()
        # Build response manually to avoid lazy-loading the token_usage
        # relationship after commit (which would trigger MissingGreenlet).
        return NewChatMessageRead(
            id=db_message.id,
            thread_id=db_message.thread_id,
            role=db_message.role,
            content=db_message.content,
            created_at=db_message.created_at,
            author_id=db_message.author_id,
            token_usage=None,
            turn_id=db_message.turn_id,
        )
    except HTTPException:
        raise
    except IntegrityError:
        await session.rollback()
        raise HTTPException(
            status_code=400,
            detail="Database constraint violation. Please check your input data.",
        ) from None
    except OperationalError:
        await session.rollback()
        raise HTTPException(
            status_code=503, detail="Database operation failed. Please try again later."
        ) from None
    except Exception as e:
        await session.rollback()
        raise HTTPException(
            status_code=500,
            detail=f"An unexpected error occurred while appending the message: {e!s}",
        ) from None
@router.get("/threads/{thread_id}/messages", response_model=list[NewChatMessageRead])
async def list_messages(
    thread_id: int,
    skip: int = 0,
    limit: int = 100,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """
    List messages in a thread with pagination, ordered by ``created_at``
    (oldest first).

    Args:
        thread_id: id of the thread whose messages are listed.
        skip: number of messages to skip (SQL OFFSET); must be >= 0.
        limit: maximum number of messages to return (SQL LIMIT); must be >= 0.

    Access is granted if:
    - User is the creator of the thread
    - Thread visibility is SEARCH_SPACE

    Requires CHATS_READ permission.

    Raises:
        HTTPException 400: ``skip`` or ``limit`` is negative.
        HTTPException 404: the thread does not exist.
        HTTPException 503: operational database failure.
        HTTPException 500: any other unexpected error.
    """
    # Reject negative pagination up front. The database (PostgreSQL) refuses
    # a negative OFFSET/LIMIT, which would otherwise surface as an opaque 500.
    if skip < 0:
        raise HTTPException(status_code=400, detail="skip must be >= 0")
    if limit < 0:
        raise HTTPException(status_code=400, detail="limit must be >= 0")
    try:
        # Verify thread exists and user has access
        result = await session.execute(
            select(NewChatThread).filter(NewChatThread.id == thread_id)
        )
        thread = result.scalars().first()
        if not thread:
            raise HTTPException(status_code=404, detail="Thread not found")
        await check_permission(
            session,
            user,
            thread.search_space_id,
            Permission.CHATS_READ.value,
            "You don't have permission to read chats in this search space",
        )
        # Check thread-level access based on visibility
        await check_thread_access(session, thread, user)
        # Eager-load token_usage so serialization does not lazy-load it.
        query = (
            select(NewChatMessage)
            .options(selectinload(NewChatMessage.token_usage))
            .filter(NewChatMessage.thread_id == thread_id)
            .order_by(NewChatMessage.created_at)
            .offset(skip)
            .limit(limit)
        )
        result = await session.execute(query)
        return result.scalars().all()
    except HTTPException:
        raise
    except OperationalError:
        raise HTTPException(
            status_code=503, detail="Database operation failed. Please try again later."
        ) from None
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"An unexpected error occurred while fetching messages: {e!s}",
        ) from None
# =============================================================================
# Agent Tools Endpoint
# =============================================================================
@router.get("/agent/tools", response_model=list[AgentToolInfo])
async def list_agent_tools(
    _user: User = Depends(current_active_user),
):
    """Describe the built-in agent tools available to the caller.

    Only non-hidden tools are returned; tools flagged ``hidden`` (work in
    progress) are filtered out. The user dependency is used only for
    authentication.
    """
    from app.agents.new_chat.tools.registry import BUILTIN_TOOLS

    visible_tools = [tool for tool in BUILTIN_TOOLS if not tool.hidden]
    return [
        AgentToolInfo(
            name=tool.name,
            description=tool.description,
            enabled_by_default=tool.enabled_by_default,
        )
        for tool in visible_tools
    ]
# =============================================================================
# Chat Streaming Endpoint
# =============================================================================
@router.post("/new_chat")
async def handle_new_chat(
    request: NewChatRequest,
    http_request: Request,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """
    Stream chat responses from the deep agent.

    This endpoint handles the new chat functionality with streaming responses
    using Server-Sent Events (SSE) format compatible with Vercel AI SDK.

    Access is granted if:
    - User is the creator of the thread
    - Thread visibility is SEARCH_SPACE

    Requires CHATS_CREATE permission. ``request.search_space_id`` must be the
    search space that the thread ``request.chat_id`` belongs to.

    Raises:
        HTTPException 400: ``search_space_id`` does not match the thread's
            search space.
        HTTPException 404: the thread or the search space does not exist.
        HTTPException 500: any other unexpected error before streaming starts.
    """
    try:
        # Verify thread exists and user has permission
        result = await session.execute(
            select(NewChatThread).filter(NewChatThread.id == request.chat_id)
        )
        thread = result.scalars().first()
        if not thread:
            raise HTTPException(status_code=404, detail="Thread not found")
        await check_permission(
            session,
            user,
            thread.search_space_id,
            Permission.CHATS_CREATE.value,
            "You don't have permission to chat in this search space",
        )
        # Check thread-level access based on visibility
        await check_thread_access(session, thread, user)
        # Permission was checked against the thread's search space. Below,
        # the LLM config is loaded from ``request.search_space_id`` and that
        # id is forwarded to stream_new_chat. Reject a mismatch so the checks
        # above actually cover the search space being used.
        if request.search_space_id != thread.search_space_id:
            raise HTTPException(
                status_code=400,
                detail="search_space_id does not match the thread's search space",
            )
        _raise_if_thread_busy_for_start(request.chat_id)
        filesystem_selection = _resolve_filesystem_selection(
            mode=request.filesystem_mode,
            client_platform=request.client_platform,
            local_mounts=request.local_filesystem_mounts,
        )
        # Get search space to check LLM config preferences
        search_space_result = await session.execute(
            select(SearchSpace).filter(SearchSpace.id == request.search_space_id)
        )
        search_space = search_space_result.scalars().first()
        if not search_space:
            raise HTTPException(status_code=404, detail="Search space not found")
        # Use agent_llm_id from search space for chat operations
        # Positive IDs load from NewLLMConfig database table
        # Negative IDs load from YAML global configs
        # Falls back to -1 (first global config) if not configured
        llm_config_id = (
            search_space.agent_llm_id if search_space.agent_llm_id is not None else -1
        )
        # Release the read-transaction so we don't hold ACCESS SHARE locks
        # on searchspaces/documents for the entire duration of the stream.
        # expire_on_commit=False keeps loaded ORM attrs usable.
        await session.commit()
        # Close the dependency session now so its connection returns to
        # the pool before streaming begins. Without this, Starlette's
        # BaseHTTPMiddleware cancels the scope on client disconnect and
        # the dependency generator's __aexit__ never runs, orphaning the
        # connection (the "Exception terminating connection" errors).
        await session.close()
        image_urls = (
            [p.as_data_url() for p in request.user_images]
            if request.user_images
            else None
        )
        return StreamingResponse(
            stream_new_chat(
                user_query=request.user_query,
                search_space_id=request.search_space_id,
                chat_id=request.chat_id,
                user_id=str(user.id),
                llm_config_id=llm_config_id,
                mentioned_document_ids=request.mentioned_document_ids,
                mentioned_surfsense_doc_ids=request.mentioned_surfsense_doc_ids,
                needs_history_bootstrap=thread.needs_history_bootstrap,
                thread_visibility=thread.visibility,
                current_user_display_name=user.display_name or "A team member",
                disabled_tools=request.disabled_tools,
                filesystem_selection=filesystem_selection,
                request_id=getattr(http_request.state, "request_id", "unknown"),
                user_image_data_urls=image_urls,
            ),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "X-Accel-Buffering": "no",
            },
        )
    except HTTPException:
        raise
    except Exception as e:
        _logger.exception(
            "[new_chat] Unexpected error before streaming for chat_id=%s",
            request.chat_id,
        )
        raise HTTPException(
            status_code=500,
            detail=f"An unexpected error occurred: {e!s}",
        ) from None
@router.post(
    "/threads/{thread_id}/cancel-active-turn",
    response_model=CancelActiveTurnResponse,
)
async def cancel_active_turn(
    thread_id: int,
    response: Response,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Signal cancellation for the currently running turn on ``thread_id``.

    Requires CHATS_UPDATE permission plus thread-level access.

    Returns:
        * ``status="idle"`` / ``error_code="NO_ACTIVE_TURN"`` with the default
          status code when the status payload reports no running turn.
        * HTTP 202 with ``status="cancelling"`` / ``error_code="TURN_CANCELLING"``
          after cancellation is requested. When the refreshed status payload
          carries a positive ``retry_after_ms``, the retry headers are set and
          the retry hints are echoed in the body.

    Raises:
        HTTPException 404: the thread does not exist.
    """
    result = await session.execute(
        select(NewChatThread).filter(NewChatThread.id == thread_id)
    )
    thread = result.scalars().first()
    if not thread:
        raise HTTPException(status_code=404, detail="Thread not found")
    await check_permission(
        session,
        user,
        thread.search_space_id,
        Permission.CHATS_UPDATE.value,
        "You don't have permission to update chats in this search space",
    )
    await check_thread_access(session, thread, user)
    status_payload = _build_turn_status_payload(thread_id)
    if status_payload["status"] == "idle":
        return CancelActiveTurnResponse(
            status="idle",
            error_code="NO_ACTIVE_TURN",
        )
    request_cancel(str(thread_id))
    response.status_code = 202
    # Re-read the status after requesting cancellation and take retry hints
    # from that fresh snapshot.
    updated_payload = _build_turn_status_payload(thread_id)
    retry_after_ms = int(updated_payload.get("retry_after_ms") or 0)
    # Treat an explicit None the same as a missing key; int(None) would raise
    # a TypeError and turn a successful cancel request into a 500.
    raw_retry_after_at = updated_payload.get("retry_after_at")
    retry_after_at = (
        int(raw_retry_after_at) if raw_retry_after_at is not None else None
    )
    if retry_after_ms > 0:
        _set_retry_after_headers(response, retry_after_ms)
    return CancelActiveTurnResponse(
        status="cancelling",
        error_code="TURN_CANCELLING",
        retry_after_ms=retry_after_ms if retry_after_ms > 0 else None,
        retry_after_at=retry_after_at,
    )
@router.get(
    "/threads/{thread_id}/turn-status",
    response_model=TurnStatusResponse,
)
async def get_turn_status(
    thread_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Report the current turn status of ``thread_id``.

    Requires CHATS_READ permission plus thread-level access. ``active_turn_id``
    is always ``None`` in this response. ``retry_after_ms`` and
    ``retry_after_at`` are passed through from the internal status payload
    (``None`` when absent).

    Raises:
        HTTPException 404: the thread does not exist.
    """
    lookup = await session.execute(
        select(NewChatThread).filter(NewChatThread.id == thread_id)
    )
    thread = lookup.scalars().first()
    if not thread:
        raise HTTPException(status_code=404, detail="Thread not found")

    await check_permission(
        session,
        user,
        thread.search_space_id,
        Permission.CHATS_READ.value,
        "You don't have permission to view chats in this search space",
    )
    await check_thread_access(session, thread, user)

    payload = _build_turn_status_payload(thread_id)
    current_status = payload["status"]
    return TurnStatusResponse(
        status=current_status,  # type: ignore[arg-type]
        active_turn_id=None,
        retry_after_ms=payload.get("retry_after_ms"),  # type: ignore[arg-type]
        retry_after_at=payload.get("retry_after_at"),  # type: ignore[arg-type]
    )
# =============================================================================
# Chat Regeneration Endpoint (Edit/Reload)
# =============================================================================
@router.post("/threads/{thread_id}/regenerate")
async def regenerate_response(
    thread_id: int,
    request: RegenerateRequest,
    http_request: Request,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """
    Regenerate the AI response for a chat thread.

    This endpoint supports two operations:
    1. **Edit**: Provide a new `user_query` to replace the last user message and regenerate
    2. **Reload**: Leave `user_query` empty (or None) to regenerate with the same query

    When ``from_message_id`` is given, the rewind and delete start at that
    message instead of at the end of the thread (edit-from-arbitrary-position).

    Both operations:
    - Rewind the LangGraph checkpointer to the state before the last AI response
    - Delete the last user message and AI response from the database (or every
      message from ``from_message_id`` onwards), but only after the new stream
      completes successfully
    - Stream a new response from that checkpoint

    Access is granted if:
    - User is the creator of the thread
    - Thread visibility is SEARCH_SPACE

    Requires CHATS_UPDATE permission. ``request.search_space_id`` must be the
    thread's own search space.

    Raises:
        HTTPException 400: search space mismatch, no checkpoint history, or
            no user query could be determined.
        HTTPException 404: thread, ``from_message_id`` or search space not found.
        HTTPException 500: any other unexpected error before streaming starts.
    """
    from langchain_core.messages import HumanMessage

    from app.agents.new_chat.checkpointer import get_checkpointer

    try:
        # Verify thread exists and user has permission
        result = await session.execute(
            select(NewChatThread).filter(NewChatThread.id == thread_id)
        )
        thread = result.scalars().first()
        if not thread:
            raise HTTPException(status_code=404, detail="Thread not found")
        await check_permission(
            session,
            user,
            thread.search_space_id,
            Permission.CHATS_UPDATE.value,
            "You don't have permission to update chats in this search space",
        )
        # Check thread-level access based on visibility
        await check_thread_access(session, thread, user)
        # Permission was checked against the thread's search space. Below,
        # the LLM config is loaded from ``request.search_space_id`` and that
        # id is forwarded to stream_new_chat. Reject a mismatch so the checks
        # above actually cover the search space being used.
        if request.search_space_id != thread.search_space_id:
            raise HTTPException(
                status_code=400,
                detail="search_space_id does not match the thread's search space",
            )
        _raise_if_thread_busy_for_start(thread_id)
        filesystem_selection = _resolve_filesystem_selection(
            mode=request.filesystem_mode,
            client_platform=request.client_platform,
            local_mounts=request.local_filesystem_mounts,
        )
        # Get the checkpointer and state history. The local name avoids
        # shadowing the module-level ``config`` import.
        checkpointer = await get_checkpointer()
        checkpoint_config = {"configurable": {"thread_id": str(thread_id)}}
        # Collect checkpoint tuples from the async iterator
        # CheckpointTuple has: config, checkpoint (dict with channel_values), metadata, parent_config
        checkpoint_tuples = [
            cp_tuple async for cp_tuple in checkpointer.alist(checkpoint_config)
        ]
        if not checkpoint_tuples:
            raise HTTPException(
                status_code=400, detail="No conversation history found for this thread"
            )
        # Find the checkpoint to rewind to
        # Checkpoints are in reverse chronological order (newest first)
        # We need to find a checkpoint before the last user message was added
        #
        # The checkpointer stores states after each node execution.
        # For a typical conversation flow:
        # - User sends message -> state 1 (with HumanMessage)
        # - Agent responds -> state 2 (with HumanMessage + AIMessage)
        #
        # To regenerate, we need the state BEFORE the last HumanMessage was processed
        target_checkpoint_id = None
        user_query_to_use = request.user_query
        regenerate_image_urls: list[str] = []
        # ---------------------------------------------------------------
        # Edit-from-arbitrary-position. When the client passes
        # ``from_message_id`` we look up its persisted ``turn_id`` (added
        # in migration 136) and pick the checkpoint immediately before
        # that turn started.
        #
        # Legacy graceful-degradation contract:
        #   * Rows persisted BEFORE migration 136 have ``turn_id IS NULL``.
        #     Returning 400 in that case is the wrong UX: the user is
        #     editing an old message in an existing thread and just wants
        #     it to work. We instead skip the checkpoint rewind (the
        #     stream falls back to the latest state) and skip the revert
        #     pass (no chat_turn_id available to walk). Deletion still
        #     uses ``created_at``, so the messages-after-cursor slice is
        #     correct on both legacy and post-136 rows.
        # ---------------------------------------------------------------
        from_message_turn_id: str | None = None
        from_message_created_at: datetime | None = None
        legacy_from_message: bool = False
        if request.from_message_id is not None:
            from_msg_row = await session.execute(
                select(NewChatMessage).filter(
                    NewChatMessage.id == request.from_message_id,
                    NewChatMessage.thread_id == thread_id,
                )
            )
            from_msg = from_msg_row.scalars().first()
            if from_msg is None:
                raise HTTPException(
                    status_code=404,
                    detail="from_message_id not found in this thread.",
                )
            from_message_created_at = from_msg.created_at
            if not from_msg.turn_id:
                # Legacy row: surface the degradation in logs but let
                # the request proceed with the slice-based delete and a
                # cold-start checkpoint.
                legacy_from_message = True
                _logger.warning(
                    "[regenerate] from_message_id=%s on thread=%s has no "
                    "turn_id (legacy row pre-migration-136). Falling back "
                    "to slice-based delete without checkpoint rewind. "
                    "revert_actions=%s will be ignored.",
                    request.from_message_id,
                    thread_id,
                    request.revert_actions,
                )
            else:
                from_message_turn_id = from_msg.turn_id
                # Walk oldest-to-newest and pick the LAST checkpoint whose
                # ``turn_id`` differs from the edited turn; that's the state
                # immediately before this turn started running. We read from
                # ``metadata`` (the durable surface) rather than
                # ``config["configurable"]`` so the lookup works across
                # checkpointer implementations.
                target_checkpoint_id = _find_pre_turn_checkpoint_id(
                    checkpoint_tuples,
                    turn_id=from_message_turn_id,
                )
                if target_checkpoint_id is None:
                    # Fall back to the oldest checkpoint. That's better than
                    # returning 400 when the agent didn't checkpoint pre-turn
                    # (e.g. very first turn of the thread). checkpoint_tuples
                    # is known to be non-empty here (checked above).
                    target_checkpoint_id = checkpoint_tuples[-1].config["configurable"][
                        "checkpoint_id"
                    ]
        # Look through checkpoints to find the right one
        # We want to find the checkpoint just before the last HumanMessage.
        # We enter this branch when:
        #   * the client did NOT pin ``from_message_id`` (legacy reload/edit), OR
        #   * the client pinned ``from_message_id`` but the row is a
        #     legacy pre-migration-136 row with no ``turn_id`` (we
        #     downgraded to the same heuristic as a regular reload).
        # We DO skip it when a real turn_id pinned ``target_checkpoint_id``.
        # That's the C1 happy path, and the heuristic below would just
        # re-derive a worse target.
        if request.from_message_id is None or legacy_from_message:
            for i, cp_tuple in enumerate(checkpoint_tuples):
                # Access the checkpoint's channel_values which contains "messages"
                checkpoint_data = cp_tuple.checkpoint
                channel_values = checkpoint_data.get("channel_values", {})
                state_messages = channel_values.get("messages", [])
                if state_messages:
                    last_msg = state_messages[-1]
                    # Find a checkpoint where the last message is NOT a HumanMessage
                    # This means we're at a state before the user's last message
                    if not isinstance(last_msg, HumanMessage):
                        # If no new user_query provided (reload), extract from a later checkpoint
                        if user_query_to_use is None and i > 0:
                            # Get the user query from a more recent checkpoint
                            for prev_cp_tuple in checkpoint_tuples[:i]:
                                prev_checkpoint_data = prev_cp_tuple.checkpoint
                                prev_channel_values = prev_checkpoint_data.get(
                                    "channel_values", {}
                                )
                                prev_messages = prev_channel_values.get("messages", [])
                                for msg in reversed(prev_messages):
                                    if isinstance(msg, HumanMessage):
                                        q, imgs = split_langchain_human_content(
                                            msg.content
                                        )
                                        user_query_to_use = q
                                        regenerate_image_urls = imgs
                                        break
                                if user_query_to_use is not None and (
                                    str(user_query_to_use).strip()
                                    or regenerate_image_urls
                                ):
                                    break
                        target_checkpoint_id = cp_tuple.config["configurable"][
                            "checkpoint_id"
                        ]
                        break
        # If we couldn't find a good checkpoint, try alternative approaches
        if target_checkpoint_id is None and checkpoint_tuples:
            if len(checkpoint_tuples) == 1:
                # Only one checkpoint - get the user query from it if not provided
                if user_query_to_use is None:
                    checkpoint_data = checkpoint_tuples[0].checkpoint
                    channel_values = checkpoint_data.get("channel_values", {})
                    state_messages = channel_values.get("messages", [])
                    for msg in state_messages:
                        if isinstance(msg, HumanMessage):
                            q, imgs = split_langchain_human_content(msg.content)
                            user_query_to_use = q
                            regenerate_image_urls = imgs
                            break
            else:
                # Use the oldest checkpoint
                target_checkpoint_id = checkpoint_tuples[-1].config["configurable"][
                    "checkpoint_id"
                ]
        # If we still don't have a user query, get it from the database
        if user_query_to_use is None:
            # Take the latest user message in the thread. When editing from an
            # arbitrary position, only consider messages at or before the
            # cursor. Everything after it is about to be deleted, so using it
            # would regenerate the wrong turn.
            last_user_msg_query = select(NewChatMessage).filter(
                NewChatMessage.thread_id == thread_id,
                NewChatMessage.role == NewChatMessageRole.USER,
            )
            if from_message_created_at is not None:
                last_user_msg_query = last_user_msg_query.filter(
                    NewChatMessage.created_at <= from_message_created_at
                )
            last_user_msg_result = await session.execute(
                last_user_msg_query.order_by(NewChatMessage.created_at.desc()).limit(1)
            )
            last_user_msg = last_user_msg_result.scalars().first()
            if last_user_msg:
                content = last_user_msg.content
                if isinstance(content, str):
                    user_query_to_use = content
                elif isinstance(content, list):
                    plain, imgs = split_persisted_user_content_parts(content)
                    user_query_to_use = plain
                    regenerate_image_urls = imgs
        if isinstance(user_query_to_use, list):
            user_query_to_use, regenerate_image_urls = split_langchain_human_content(
                user_query_to_use
            )
        if request.user_images is not None:
            regenerate_image_urls = [p.as_data_url() for p in request.user_images]
        if user_query_to_use is None:
            raise HTTPException(
                status_code=400,
                detail="Could not determine user query for regeneration. Please provide a user_query.",
            )
        if not str(user_query_to_use).strip() and not regenerate_image_urls:
            raise HTTPException(
                status_code=400,
                detail="Could not determine user query for regeneration. Please provide a user_query.",
            )
        # Select the messages to delete now, but delete them only AFTER
        # streaming succeeds. This prevents data loss if streaming fails.
        #
        # When ``from_message_id`` is set we slice from that message
        # forward (using ``created_at`` so we also catch any tool/system
        # messages persisted into the same turn). Otherwise
        # we keep the legacy "last 2 messages" rewind.
        if request.from_message_id is not None and from_message_created_at is not None:
            last_messages_result = await session.execute(
                select(NewChatMessage)
                .filter(
                    NewChatMessage.thread_id == thread_id,
                    NewChatMessage.created_at >= from_message_created_at,
                )
                .order_by(NewChatMessage.created_at.desc())
            )
        else:
            last_messages_result = await session.execute(
                select(NewChatMessage)
                .filter(NewChatMessage.thread_id == thread_id)
                .order_by(NewChatMessage.created_at.desc())
                .limit(2)
            )
        messages_to_delete = list(last_messages_result.scalars().all())
        message_ids_to_delete = [msg.id for msg in messages_to_delete]
        # When revert_actions is requested, collect the set of
        # ``chat_turn_id``s present in the slice we're about to delete.
        # Each one will be reverted (best-effort) BEFORE the regenerate
        # stream begins. Legacy rows have ``turn_id=None`` and silently
        # contribute nothing; we already logged the degradation above.
        revert_turn_ids: list[str] = []
        if (
            request.revert_actions
            and request.from_message_id is not None
            and not legacy_from_message
        ):
            seen_turns: set[str] = set()
            for msg in messages_to_delete:
                tid = msg.turn_id
                if tid and tid not in seen_turns:
                    seen_turns.add(tid)
                    revert_turn_ids.append(tid)
        # Get search space for LLM config
        search_space_result = await session.execute(
            select(SearchSpace).filter(SearchSpace.id == request.search_space_id)
        )
        search_space = search_space_result.scalars().first()
        if not search_space:
            raise HTTPException(status_code=404, detail="Search space not found")
        llm_config_id = (
            search_space.agent_llm_id if search_space.agent_llm_id is not None else -1
        )
        # Release the read-transaction so we don't hold ACCESS SHARE locks
        # on searchspaces/documents for the entire duration of the stream.
        # expire_on_commit=False keeps loaded ORM attrs (including messages_to_delete PKs) usable.
        await session.commit()
        await session.close()

        # Create a wrapper generator that deletes messages only AFTER streaming succeeds
        # This prevents data loss if streaming fails (network error, LLM error, etc.)
        async def stream_with_cleanup():
            streaming_completed = False
            # Best-effort revert pass BEFORE the regenerate stream begins.
            # Each turn is reverted independently (per-row SAVEPOINTs
            # inside the route helper) and the per-action results are surfaced
            # on a single ``data-revert-results`` SSE event so the frontend
            # can render any failed rows alongside the new turn. Failures here
            # do NOT abort the regeneration; partial rollback is documented
            # behaviour.
            if revert_turn_ids:
                revert_results = await _revert_turns_for_regenerate(
                    thread_id=thread_id,
                    chat_turn_ids=revert_turn_ids,
                    requester_user_id=str(user.id),
                )
                envelope = {
                    "type": "data-revert-results",
                    "data": revert_results,
                }
                yield f"data: {json.dumps(envelope, default=str)}\n\n".encode()
            try:
                async for chunk in stream_new_chat(
                    user_query=str(user_query_to_use),
                    search_space_id=request.search_space_id,
                    chat_id=thread_id,
                    user_id=str(user.id),
                    llm_config_id=llm_config_id,
                    mentioned_document_ids=request.mentioned_document_ids,
                    mentioned_surfsense_doc_ids=request.mentioned_surfsense_doc_ids,
                    checkpoint_id=target_checkpoint_id,
                    needs_history_bootstrap=thread.needs_history_bootstrap,
                    thread_visibility=thread.visibility,
                    current_user_display_name=user.display_name or "A team member",
                    disabled_tools=request.disabled_tools,
                    filesystem_selection=filesystem_selection,
                    request_id=getattr(http_request.state, "request_id", "unknown"),
                    user_image_data_urls=regenerate_image_urls or None,
                    flow="regenerate",
                ):
                    yield chunk
                streaming_completed = True
            finally:
                # Only delete old messages if streaming completed successfully.
                # Uses a fresh session since stream_new_chat manages its own.
                if streaming_completed and message_ids_to_delete:
                    try:
                        async with shielded_async_session() as cleanup_session:
                            # Load all stale rows in one query (instead of one
                            # SELECT per id) and delete them through the ORM.
                            stale_result = await cleanup_session.execute(
                                select(NewChatMessage).filter(
                                    NewChatMessage.id.in_(message_ids_to_delete)
                                )
                            )
                            for stale_msg in stale_result.scalars().all():
                                await cleanup_session.delete(stale_msg)
                            await cleanup_session.commit()
                            from app.services.public_chat_service import (
                                delete_affected_snapshots,
                            )

                            await delete_affected_snapshots(
                                cleanup_session, thread_id, message_ids_to_delete
                            )
                    except Exception as cleanup_error:
                        _logger.warning(
                            "[regenerate] Failed to delete old messages: %s",
                            cleanup_error,
                        )

        # Return streaming response with checkpoint_id for rewinding
        return StreamingResponse(
            stream_with_cleanup(),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "X-Accel-Buffering": "no",
            },
        )
    except HTTPException:
        raise
    except Exception as e:
        # Route the traceback through logging instead of bare stderr.
        _logger.exception(
            "[regenerate] Unexpected error while regenerating thread=%s", thread_id
        )
        raise HTTPException(
            status_code=500,
            detail=f"An unexpected error occurred during regeneration: {e!s}",
        ) from None
# =============================================================================
# Resume Interrupted Chat Endpoint
# =============================================================================
@router.post("/threads/{thread_id}/resume")
async def resume_chat(
    thread_id: int,
    request: ResumeRequest,
    http_request: Request,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """
    Resume an interrupted turn on ``thread_id``, streaming the result (SSE).

    ``request.decisions`` are serialized with ``model_dump()`` and forwarded
    to ``stream_resume_chat``.

    Requires CHATS_CREATE permission plus thread-level access, and
    ``request.search_space_id`` must be the thread's own search space.

    Raises:
        HTTPException 400: ``search_space_id`` does not match the thread's
            search space.
        HTTPException 404: the thread or the search space does not exist.
        HTTPException 500: any other unexpected error before streaming starts.
    """
    try:
        result = await session.execute(
            select(NewChatThread).filter(NewChatThread.id == thread_id)
        )
        thread = result.scalars().first()
        if not thread:
            raise HTTPException(status_code=404, detail="Thread not found")
        await check_permission(
            session,
            user,
            thread.search_space_id,
            Permission.CHATS_CREATE.value,
            "You don't have permission to chat in this search space",
        )
        await check_thread_access(session, thread, user)
        # Permission was checked against the thread's search space. Below,
        # the LLM config is loaded from ``request.search_space_id`` and that
        # id is forwarded to stream_resume_chat. Reject a mismatch so the
        # checks above actually cover the search space being used.
        if request.search_space_id != thread.search_space_id:
            raise HTTPException(
                status_code=400,
                detail="search_space_id does not match the thread's search space",
            )
        _raise_if_thread_busy_for_start(thread_id)
        filesystem_selection = _resolve_filesystem_selection(
            mode=request.filesystem_mode,
            client_platform=request.client_platform,
            local_mounts=request.local_filesystem_mounts,
        )
        search_space_result = await session.execute(
            select(SearchSpace).filter(SearchSpace.id == request.search_space_id)
        )
        search_space = search_space_result.scalars().first()
        if not search_space:
            raise HTTPException(status_code=404, detail="Search space not found")
        llm_config_id = (
            search_space.agent_llm_id if search_space.agent_llm_id is not None else -1
        )
        decisions = [d.model_dump() for d in request.decisions]
        # Release the read-transaction so we don't hold ACCESS SHARE locks
        # on searchspaces/documents for the entire duration of the stream.
        await session.commit()
        await session.close()
        return StreamingResponse(
            stream_resume_chat(
                chat_id=thread_id,
                search_space_id=request.search_space_id,
                decisions=decisions,
                user_id=str(user.id),
                llm_config_id=llm_config_id,
                thread_visibility=thread.visibility,
                filesystem_selection=filesystem_selection,
                request_id=getattr(http_request.state, "request_id", "unknown"),
            ),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "X-Accel-Buffering": "no",
            },
        )
    except HTTPException:
        raise
    except Exception as e:
        # Route the traceback through logging instead of bare stderr.
        _logger.exception(
            "[resume] Unexpected error while resuming thread=%s", thread_id
        )
        raise HTTPException(
            status_code=500,
            detail=f"An unexpected error occurred during resume: {e!s}",
        ) from None