SurfSense/surfsense_backend/app/routes/new_chat_routes.py

2419 lines
91 KiB
Python
Raw Normal View History

"""
Routes for the new chat feature with assistant-ui integration.
These endpoints support the ThreadHistoryAdapter pattern from assistant-ui:
- GET /threads - List threads for sidebar (ThreadListPrimitive)
- POST /threads - Create a new thread
- GET /threads/{thread_id} - Get thread with messages (load)
- PUT /threads/{thread_id} - Update thread (rename, archive)
- DELETE /threads/{thread_id} - Delete thread
- POST /threads/{thread_id}/messages - Append message
"""
import asyncio
2026-04-29 07:20:31 -07:00
import json
import logging
from datetime import UTC, datetime
from fastapi import APIRouter, Depends, HTTPException, Request, Response
2025-12-21 19:33:52 -08:00
from fastapi.responses import StreamingResponse
from sqlalchemy import func, or_, text as sa_text
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.exc import IntegrityError, OperationalError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload
from app.agents.new_chat.filesystem_selection import (
ClientPlatform,
FilesystemMode,
FilesystemSelection,
LocalFilesystemMount,
)
from app.agents.new_chat.middleware.busy_mutex import (
get_cancel_state,
is_cancel_requested,
manager,
request_cancel,
)
from app.config import config
from app.db import (
ChatComment,
2026-01-13 00:17:12 -08:00
ChatVisibility,
NewChatMessage,
NewChatMessageRole,
NewChatThread,
Permission,
2025-12-21 19:33:52 -08:00
SearchSpace,
TokenUsage,
User,
get_async_session,
shielded_async_session,
)
from app.schemas.new_chat import (
AgentToolInfo,
CancelActiveTurnResponse,
LocalFilesystemMountPayload,
NewChatMessageRead,
2025-12-21 19:33:52 -08:00
NewChatRequest,
NewChatThreadCreate,
NewChatThreadRead,
NewChatThreadUpdate,
2026-01-13 00:17:12 -08:00
NewChatThreadVisibilityUpdate,
NewChatThreadWithMessages,
PublicChatSnapshotCreateResponse,
PublicChatSnapshotListResponse,
RegenerateRequest,
ResumeRequest,
ThreadHistoryLoadResponse,
ThreadListItem,
ThreadListResponse,
TokenUsageSummary,
TurnStatusResponse,
)
from app.tasks.chat.stream_new_chat import stream_new_chat, stream_resume_chat
from app.users import current_active_user
from app.utils.perf import get_perf_logger
from app.utils.rbac import check_permission
from app.utils.user_message_multimodal import (
split_langchain_human_content,
split_persisted_user_content_parts,
)
# Module-level logger used by the helpers and route handlers in this file.
_logger = logging.getLogger(__name__)
# Logger obtained from ``app.utils.perf.get_perf_logger`` (presumably for
# performance instrumentation; it is not used in the portion of the file
# reviewed here — TODO confirm usage further down).
_perf_log = get_perf_logger()
# Strong references to fire-and-forget tasks. ``_try_delete_sandbox`` adds a
# task here when it is created and discards it from a done-callback.
_background_tasks: set[asyncio.Task] = set()
# Parameters of the bounded exponential backoff computed by
# ``_compute_turn_cancelling_retry_delay``: first delay, growth factor per
# attempt, and the upper bound (all delays in milliseconds).
TURN_CANCELLING_INITIAL_DELAY_MS = 200
TURN_CANCELLING_BACKOFF_FACTOR = 2
TURN_CANCELLING_MAX_DELAY_MS = 1500
# Router on which every endpoint in this module is registered.
router = APIRouter()
def _resolve_filesystem_selection(
    *,
    mode: str,
    client_platform: str,
    local_mounts: list[LocalFilesystemMountPayload] | None,
) -> FilesystemSelection:
    """Turn the raw filesystem settings of a request into a ``FilesystemSelection``.

    Args:
        mode: Raw filesystem mode string; must be a valid ``FilesystemMode`` value.
        client_platform: Raw platform string; must be a valid ``ClientPlatform`` value.
        local_mounts: Mount payloads supplied by the client. Only consulted in
            ``desktop_local_folder`` mode.

    Returns:
        A cloud selection for every mode other than ``desktop_local_folder``;
        otherwise a desktop selection carrying the cleaned-up mounts.

    Raises:
        HTTPException: 400 when the mode or platform is unknown, when desktop
            mode is disabled by configuration, when the client is not the
            desktop runtime, or when no usable mount remains after cleanup.
    """
    try:
        resolved_mode = FilesystemMode(mode)
    except ValueError as exc:
        raise HTTPException(status_code=400, detail="Invalid filesystem_mode") from exc

    try:
        resolved_platform = ClientPlatform(client_platform)
    except ValueError as exc:
        raise HTTPException(status_code=400, detail="Invalid client_platform") from exc

    # Anything that is not the desktop-local mode resolves to the cloud
    # filesystem; mounts are ignored in that case.
    if resolved_mode != FilesystemMode.DESKTOP_LOCAL_FOLDER:
        return FilesystemSelection(
            mode=FilesystemMode.CLOUD,
            client_platform=resolved_platform,
        )

    if not config.ENABLE_DESKTOP_LOCAL_FILESYSTEM:
        raise HTTPException(
            status_code=400,
            detail="Desktop local filesystem mode is disabled on this deployment.",
        )
    if resolved_platform != ClientPlatform.DESKTOP:
        raise HTTPException(
            status_code=400,
            detail="desktop_local_folder mode is only available on desktop runtime.",
        )

    # Insertion-ordered mapping: drops blank entries and keeps only the first
    # occurrence of each (whitespace-stripped) mount id.
    root_by_mount_id: dict[str, str] = {}
    for payload in local_mounts or []:
        mount_id = payload.mount_id.strip()
        root_path = payload.root_path.strip()
        if mount_id and root_path and mount_id not in root_by_mount_id:
            root_by_mount_id[mount_id] = root_path

    if not root_by_mount_id:
        raise HTTPException(
            status_code=400,
            detail=(
                "local_filesystem_mounts must include at least one mount for "
                "desktop_local_folder mode."
            ),
        )

    return FilesystemSelection(
        mode=resolved_mode,
        client_platform=resolved_platform,
        local_mounts=tuple(
            LocalFilesystemMount(mount_id=mount_id, root_path=root_path)
            for mount_id, root_path in root_by_mount_id.items()
        ),
    )
def _compute_turn_cancelling_retry_delay(attempt: int) -> int:
    """Return the retry hint in milliseconds for a TURN_CANCELLING response.

    The delay starts at ``TURN_CANCELLING_INITIAL_DELAY_MS``, is multiplied by
    ``TURN_CANCELLING_BACKOFF_FACTOR`` for every attempt after the first, and
    never exceeds ``TURN_CANCELLING_MAX_DELAY_MS``. Attempts below 1 are
    treated as the first attempt.
    """
    exponent = max(attempt, 1) - 1
    uncapped_delay = TURN_CANCELLING_INITIAL_DELAY_MS * (
        TURN_CANCELLING_BACKOFF_FACTOR**exponent
    )
    return min(uncapped_delay, TURN_CANCELLING_MAX_DELAY_MS)
def _build_turn_status_payload(thread_id: int) -> dict[str, object]:
    """Describe the current turn state of a thread.

    Returns:
        ``{"status": "idle"}`` when the thread's lock is not held,
        ``{"status": "busy"}`` when it is held and no cancel was requested,
        or a ``"cancelling"`` payload that also carries ``retry_after_ms``
        (backoff delay) and ``retry_after_at`` (epoch milliseconds at which a
        retry is suggested).
    """
    thread_key = str(thread_id)

    if not manager.lock_for(thread_key).locked():
        return {"status": "idle"}

    if not is_cancel_requested(thread_key):
        return {"status": "busy"}

    cancel_state = get_cancel_state(thread_key)
    # First element of the cancel state is used as the attempt counter;
    # fall back to the first attempt when no state is available.
    attempt = cancel_state[0] if cancel_state else 1
    delay_ms = _compute_turn_cancelling_retry_delay(attempt)
    now_ms = int(datetime.now(UTC).timestamp() * 1000)
    return {
        "status": "cancelling",
        "retry_after_ms": delay_ms,
        "retry_after_at": now_ms + delay_ms,
    }
def _set_retry_after_headers(response: Response, retry_after_ms: int) -> None:
    """Attach retry hints to ``response``.

    Sets ``retry-after-ms`` to the delay in milliseconds and ``Retry-After``
    to the same delay rounded up to whole seconds, with a minimum of 1.
    """
    whole_seconds = max(1, (retry_after_ms + 999) // 1000)
    response.headers["retry-after-ms"] = f"{retry_after_ms}"
    response.headers["Retry-After"] = f"{whole_seconds}"
def _raise_if_thread_busy_for_start(thread_id: int) -> None:
    """Reject the start of a new turn while the thread is not idle.

    Returns normally when the thread is idle.

    Raises:
        HTTPException: 409 with ``errorCode`` ``TURN_CANCELLING`` (plus retry
            hints in the detail and, when the delay is positive, in the
            response headers) while a previous turn is being cancelled, or
            409 with ``errorCode`` ``THREAD_BUSY`` for any other non-idle state.
    """
    payload = _build_turn_status_payload(thread_id)
    state = payload["status"]

    if state == "idle":
        return

    if state != "cancelling":
        raise HTTPException(
            status_code=409,
            detail={
                "errorCode": "THREAD_BUSY",
                "message": "Another response is still finishing for this thread. Please try again in a moment.",
            },
        )

    delay_ms = int(payload.get("retry_after_ms") or 0)
    if delay_ms > 0:
        delay_hint = delay_ms
        retry_headers = {
            "retry-after-ms": str(delay_ms),
            "Retry-After": str(max(1, (delay_ms + 999) // 1000)),
        }
    else:
        # No usable delay: omit both the hint value and the headers.
        delay_hint = None
        retry_headers = None

    raise HTTPException(
        status_code=409,
        detail={
            "errorCode": "TURN_CANCELLING",
            "message": "A previous response is still stopping. Please try again in a moment.",
            "retry_after_ms": delay_hint,
            "retry_after_at": payload.get("retry_after_at"),
        },
        headers=retry_headers,
    )
2026-04-29 07:20:31 -07:00
def _find_pre_turn_checkpoint_id(
checkpoint_tuples: list,
*,
turn_id: str,
) -> str | None:
"""Locate the LangGraph checkpoint immediately before ``turn_id`` started.
``checkpoint_tuples`` arrives newest-first from
``checkpointer.alist(config)``. We walk OLDEST-first (``reversed``)
and remember the most recent checkpoint that does NOT belong to the
edited turn. As soon as we cross into the edited turn (a checkpoint
whose ``turn_id`` matches), we return the previously-tracked
checkpoint that's the state immediately before ``turn_id`` began.
The naive "newest-first, return first non-matching" approach is
INCORRECT when later turns exist after ``turn_id``: their
checkpoints also satisfy ``cp_turn_id != turn_id`` and would be
returned before the real pre-turn boundary is reached.
Reads from ``cp_tuple.metadata`` (the durable surface promoted from
``configurable`` at write time) rather than ``config["configurable"]``
so the lookup is portable across checkpointer implementations.
Returns ``None`` when no eligible pre-turn checkpoint exists (e.g.
the edited turn is the very first turn of the thread). Callers fall
back to the oldest available checkpoint in that case.
"""
last_pre_turn_target: str | None = None
for cp_tuple in reversed(checkpoint_tuples): # oldest -> newest
metadata = getattr(cp_tuple, "metadata", None) or {}
cp_turn_id = metadata.get("turn_id") if isinstance(metadata, dict) else None
if cp_turn_id == turn_id:
# Crossed into the edited turn; the previous tracked
# checkpoint is the rewind target. May be ``None`` if we hit
# the edited turn on the very first iteration.
return last_pre_turn_target
try:
last_pre_turn_target = cp_tuple.config["configurable"]["checkpoint_id"]
except (KeyError, TypeError):
continue
return last_pre_turn_target
async def _revert_turns_for_regenerate(
    *,
    thread_id: int,
    chat_turn_ids: list[str],
    requester_user_id: str,
) -> dict:
    """Best-effort revert of every logged action of the given chat turns.

    Intended to run before the regenerate stream so the caller can surface
    partial-rollback feedback alongside the new assistant turn. Each action is
    reverted inside its own SAVEPOINT (``session.begin_nested()``), so one
    failing action does not undo the others. The operation as a whole is NOT
    atomic with the regeneration that follows; partial state is reported to
    the caller through the returned payload.

    Args:
        thread_id: Thread whose action log is inspected.
        chat_turn_ids: Turns whose actions should be reverted, processed in
            the given order; within a turn, newest actions are reverted first.
        requester_user_id: User on whose behalf the revert is attempted; used
            for the permission check and passed to the revert service.

    Returns:
        A dict with ``status`` (``"ok"`` or ``"partial"``), the input
        ``chat_turn_ids``, ``total``, one counter per outcome
        (``reverted``, ``already_reverted``, ``not_reversible``,
        ``permission_denied``, ``failed``, ``skipped``) and ``results``
        (one dumped ``RevertTurnActionResult`` per action).

        If the final commit fails, the session is rolled back and every
        action that had been reported as ``reverted`` is re-reported as
        ``failed`` so the payload never claims work that was not persisted.
    """
    from app.routes.agent_revert_route import (
        RevertTurnActionResult,
        _classify_outcome,
        _OutcomeRollbackError,
        _was_already_reverted,
        _was_already_reverted_batch,
    )
    from app.services.revert_service import (
        can_revert,
        revert_action,
    )

    aggregated_results: list[dict] = []
    # Exhaustive counters keep the response invariant
    # ``total == sum(counters)`` true for ``data-revert-results``.
    counts = {
        "reverted": 0,
        "already_reverted": 0,
        "not_reversible": 0,
        "permission_denied": 0,
        "failed": 0,
        "skipped": 0,
    }

    # Local import keeps the route module's existing imports tidy and
    # avoids a circular dependency at module-load time.
    from app.db import AgentActionLog as _AgentActionLog

    async with shielded_async_session() as session:
        for chat_turn_id in chat_turn_ids:
            rows_stmt = (
                select(_AgentActionLog)
                .where(
                    _AgentActionLog.thread_id == thread_id,
                    _AgentActionLog.chat_turn_id == chat_turn_id,
                )
                .order_by(
                    _AgentActionLog.created_at.desc(),
                    _AgentActionLog.id.desc(),
                )
            )
            rows = (await session.execute(rows_stmt)).scalars().all()

            # Batch idempotency probe across the turn (single SELECT
            # instead of one per row).
            eligible_ids = [r.id for r in rows if r.reverse_of is None]
            already_reverted_map = await _was_already_reverted_batch(
                session, action_ids=eligible_ids
            )

            for action in rows:
                if action.reverse_of is not None:
                    counts["skipped"] += 1
                    aggregated_results.append(
                        RevertTurnActionResult(
                            action_id=action.id,
                            tool_name=action.tool_name,
                            status="skipped",
                            message="Row is itself a revert action; skipped.",
                        ).model_dump()
                    )
                    continue

                existing_revert_id = already_reverted_map.get(action.id)
                if existing_revert_id is not None:
                    counts["already_reverted"] += 1
                    aggregated_results.append(
                        RevertTurnActionResult(
                            action_id=action.id,
                            tool_name=action.tool_name,
                            status="already_reverted",
                            new_action_id=existing_revert_id,
                        ).model_dump()
                    )
                    continue

                if not can_revert(
                    requester_user_id=requester_user_id,
                    action=action,
                    is_admin=False,
                ):
                    counts["permission_denied"] += 1
                    aggregated_results.append(
                        RevertTurnActionResult(
                            action_id=action.id,
                            tool_name=action.tool_name,
                            status="permission_denied",
                            message="You are not allowed to revert this action.",
                        ).model_dump()
                    )
                    continue

                try:
                    # One SAVEPOINT per action: raising inside the block rolls
                    # back only this action's changes.
                    async with session.begin_nested():
                        outcome = await revert_action(
                            session,
                            action=action,
                            requester_user_id=requester_user_id,
                        )
                        if outcome.status != "ok":
                            raise _OutcomeRollbackError(outcome)
                except _OutcomeRollbackError as rollback:
                    outcome = rollback.outcome
                    classified = _classify_outcome(outcome)
                    if classified == "permission_denied":
                        counts["permission_denied"] += 1
                    else:
                        counts["not_reversible"] += 1
                    aggregated_results.append(
                        RevertTurnActionResult(
                            action_id=action.id,
                            tool_name=action.tool_name,
                            status=classified,
                            message=outcome.message,
                        ).model_dump()
                    )
                    continue
                except IntegrityError:
                    # Concurrent revert won the race against the
                    # pre-flight ``_was_already_reverted`` SELECT.
                    # Surface the winning revert id so the client can
                    # treat this as a successful idempotent op.
                    existing_revert_id = await _was_already_reverted(
                        session, action_id=action.id
                    )
                    counts["already_reverted"] += 1
                    aggregated_results.append(
                        RevertTurnActionResult(
                            action_id=action.id,
                            tool_name=action.tool_name,
                            status="already_reverted",
                            new_action_id=existing_revert_id,
                        ).model_dump()
                    )
                    continue
                except Exception as err:  # pragma: no cover — defensive
                    _logger.exception(
                        "Unexpected revert failure during regenerate batch "
                        "for action_id=%s",
                        action.id,
                    )
                    counts["failed"] += 1
                    aggregated_results.append(
                        RevertTurnActionResult(
                            action_id=action.id,
                            tool_name=action.tool_name,
                            status="failed",
                            error=str(err) or err.__class__.__name__,
                        ).model_dump()
                    )
                    continue

                counts["reverted"] += 1
                aggregated_results.append(
                    RevertTurnActionResult(
                        action_id=action.id,
                        tool_name=action.tool_name,
                        status="reverted",
                        message=outcome.message,
                        new_action_id=outcome.new_action_id,
                    ).model_dump()
                )

        try:
            await session.commit()
        except Exception:
            _logger.exception(
                "[regenerate-revert] Final commit failed; rolling back batch."
            )
            await session.rollback()
            # The rollback discards every revert written through this session
            # (released SAVEPOINTs belong to the enclosing transaction), so
            # none of the "reverted" entries were persisted. Report them as
            # failed instead of claiming success, and drop the ids of revert
            # rows that no longer exist. Entries classified as
            # "already_reverted" refer to reverts made elsewhere and stay.
            # NOTE(review): assumes ``revert_action`` persists its effects only
            # through ``session`` — confirm against the revert service.
            for result in aggregated_results:
                if result["status"] == "reverted":
                    result["status"] = "failed"
                    result["error"] = (
                        "Revert could not be committed and was rolled back."
                    )
                    result["new_action_id"] = None
            counts["failed"] += counts["reverted"]
            counts["reverted"] = 0

    has_partial = (
        counts["failed"] > 0
        or counts["not_reversible"] > 0
        or counts["permission_denied"] > 0
    )
    return {
        "status": "partial" if has_partial else "ok",
        "chat_turn_ids": chat_turn_ids,
        "total": len(aggregated_results),
        "reverted": counts["reverted"],
        "already_reverted": counts["already_reverted"],
        "not_reversible": counts["not_reversible"],
        "permission_denied": counts["permission_denied"],
        "failed": counts["failed"],
        "skipped": counts["skipped"],
        "results": aggregated_results,
    }
def _try_delete_sandbox(thread_id: int) -> None:
    """Schedule sandbox and local sandbox file deletion without blocking the caller.

    Does nothing when sandboxes are disabled. Otherwise a background task is
    created on the running event loop; it first deletes the sandbox and then
    the local sandbox files, logging (not raising) any failure of either
    step. The task is kept in ``_background_tasks`` until it finishes.

    When no event loop is running the cleanup cannot be scheduled; this is
    logged and the function returns without raising.
    """
    from app.agents.new_chat.sandbox import (
        delete_local_sandbox_files,
        delete_sandbox,
        is_sandbox_enabled,
    )

    if not is_sandbox_enabled():
        return

    async def _bg() -> None:
        # The two cleanup steps are independent: a failure of the first must
        # not prevent the second from running.
        try:
            await delete_sandbox(thread_id)
        except Exception:
            _logger.warning(
                "Background sandbox delete failed for thread %s",
                thread_id,
                exc_info=True,
            )
        try:
            delete_local_sandbox_files(thread_id)
        except Exception:
            _logger.warning(
                "Local sandbox file cleanup failed for thread %s",
                thread_id,
                exc_info=True,
            )

    try:
        loop = asyncio.get_running_loop()
        task = loop.create_task(_bg())
        # Hold a strong reference so the task is not garbage-collected
        # before it completes; the callback releases it afterwards.
        _background_tasks.add(task)
        task.add_done_callback(_background_tasks.discard)
    except RuntimeError:
        # Best-effort by design, but leave a trace instead of failing silently.
        _logger.warning(
            "Sandbox cleanup for thread %s was not scheduled: no running event loop",
            thread_id,
        )
2026-01-13 00:17:12 -08:00
async def check_thread_access(
    session: AsyncSession,
    thread: NewChatThread,
    user: User,
    require_ownership: bool = False,
) -> bool:
    """
    Decide whether ``user`` may access ``thread``; raise when they may not.

    Rules, evaluated in this order:
    1. ``require_ownership=True``: only the creator passes, regardless of
       visibility (used for creator-only operations such as changing visibility).
    2. Threads with SEARCH_SPACE visibility are accessible to everyone who
       reaches this check.
    3. Legacy threads (``created_by_id`` is NULL) are accessible only to the
       owner of the thread's search space.
    4. Otherwise the thread is private: only its creator passes.

    Args:
        session: Database session (used to look up the search space for legacy threads)
        thread: The thread to check access for
        user: The user requesting access
        require_ownership: If True, ONLY the creator can perform the action

    Returns:
        True if access is granted

    Raises:
        HTTPException: 403 if access is denied
    """
    creator_id = thread.created_by_id
    user_is_creator = creator_id == user.id

    # Creator-only operations short-circuit every visibility rule.
    if require_ownership:
        if user_is_creator:
            return True
        raise HTTPException(
            status_code=403,
            detail="Only the creator of this chat can perform this action",
        )

    # Shared threads are open for read/update operations.
    if thread.visibility == ChatVisibility.SEARCH_SPACE:
        return True

    # Legacy threads predate creator tracking; fall back to search space ownership.
    if creator_id is None:
        owner_lookup = await session.execute(
            select(SearchSpace).filter(SearchSpace.id == thread.search_space_id)
        )
        space = owner_lookup.scalar_one_or_none()
        if space and space.user_id == user.id:
            return True
        raise HTTPException(
            status_code=403,
            detail="You don't have access to this chat",
        )

    # Private thread: the creator is the only one allowed in.
    if user_is_creator:
        return True

    raise HTTPException(
        status_code=403,
        detail="You don't have access to this private chat",
    )
# =============================================================================
# Thread Endpoints
# =============================================================================
@router.get("/threads", response_model=ThreadListResponse)
async def list_threads(
    search_space_id: int,
    limit: int | None = None,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """
    List all accessible threads for the current user in a search space.
    Returns threads and archived_threads for ThreadListPrimitive.

    A user can see threads that are:
    - Created by them (regardless of visibility)
    - Shared with the search space (visibility = SEARCH_SPACE)
    - Legacy threads with no creator (created_by_id is NULL) - only if user is search space owner

    Args:
        search_space_id: The search space to list threads for
        limit: Optional limit on number of threads to return (applies to active
            threads only; ignored when not positive)

    Raises:
        HTTPException: 503 on database operational errors, 500 on any other
            unexpected error; permission failures raised by the permission
            check are propagated unchanged.

    Requires CHATS_READ permission.
    """
    try:
        await check_permission(
            session,
            user,
            search_space_id,
            Permission.CHATS_READ.value,
            "You don't have permission to read chats in this search space",
        )

        # Check if user is the search space owner (for legacy thread visibility)
        search_space_query = select(SearchSpace).filter(
            SearchSpace.id == search_space_id
        )
        search_space_result = await session.execute(search_space_query)
        search_space = search_space_result.scalar_one_or_none()
        # Always a real bool: a missing search space must yield False, not
        # None, because the flag is forwarded into ``is_own_thread`` below.
        is_search_space_owner = (
            search_space is not None and search_space.user_id == user.id
        )

        # Build filter conditions:
        # 1. Created by the current user (any visibility)
        # 2. Shared with the search space (visibility = SEARCH_SPACE)
        # 3. Legacy threads (created_by_id is NULL) - only visible to search space owner
        filter_conditions = [
            NewChatThread.created_by_id == user.id,
            NewChatThread.visibility == ChatVisibility.SEARCH_SPACE,
        ]
        # Only include legacy threads for the search space owner
        if is_search_space_owner:
            filter_conditions.append(NewChatThread.created_by_id.is_(None))

        query = (
            select(NewChatThread)
            .filter(
                NewChatThread.search_space_id == search_space_id,
                or_(*filter_conditions),
            )
            .order_by(NewChatThread.updated_at.desc())
        )
        result = await session.execute(query)
        all_threads = result.scalars().all()

        # Separate active and archived threads
        threads = []
        archived_threads = []
        for thread in all_threads:
            # Legacy threads (no creator) are treated as own threads for owner
            is_own_thread = thread.created_by_id == user.id or (
                thread.created_by_id is None and is_search_space_owner
            )
            item = ThreadListItem(
                id=thread.id,
                title=thread.title,
                archived=thread.archived,
                visibility=thread.visibility,
                created_by_id=thread.created_by_id,
                is_own_thread=is_own_thread,
                created_at=thread.created_at,
                updated_at=thread.updated_at,
            )
            if thread.archived:
                archived_threads.append(item)
            else:
                threads.append(item)

        # Apply limit to active threads if specified
        if limit is not None and limit > 0:
            threads = threads[:limit]

        return ThreadListResponse(threads=threads, archived_threads=archived_threads)
    except HTTPException:
        raise
    except OperationalError:
        raise HTTPException(
            status_code=503, detail="Database operation failed. Please try again later."
        ) from None
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"An unexpected error occurred while fetching threads: {e!s}",
        ) from None
@router.get("/threads/search", response_model=list[ThreadListItem])
async def search_threads(
    search_space_id: int,
    title: str,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """
    Search accessible threads by title in a search space.

    A user can search threads that are:
    - Created by them (regardless of visibility)
    - Shared with the search space (visibility = SEARCH_SPACE)
    - Legacy threads with no creator (created_by_id is NULL) - only if user is search space owner

    Args:
        search_space_id: The search space to search in
        title: The search query (case-insensitive partial match). The text is
            matched literally: ``%`` and ``_`` are not treated as wildcards.

    Raises:
        HTTPException: 503 on database operational errors, 500 on any other
            unexpected error; permission failures raised by the permission
            check are propagated unchanged.

    Requires CHATS_READ permission.
    """
    try:
        await check_permission(
            session,
            user,
            search_space_id,
            Permission.CHATS_READ.value,
            "You don't have permission to read chats in this search space",
        )

        # Check if user is the search space owner (for legacy thread visibility)
        search_space_query = select(SearchSpace).filter(
            SearchSpace.id == search_space_id
        )
        search_space_result = await session.execute(search_space_query)
        search_space = search_space_result.scalar_one_or_none()
        # Always a real bool: a missing search space must yield False, not
        # None, because the flag is forwarded into ``is_own_thread`` below.
        is_search_space_owner = (
            search_space is not None and search_space.user_id == user.id
        )

        # Build filter conditions
        filter_conditions = [
            NewChatThread.created_by_id == user.id,
            NewChatThread.visibility == ChatVisibility.SEARCH_SPACE,
        ]
        # Only include legacy threads for the search space owner
        if is_search_space_owner:
            filter_conditions.append(NewChatThread.created_by_id.is_(None))

        # Neutralize LIKE metacharacters in the user-supplied text so that a
        # query such as "100%" or "a_b" is matched literally. The escape
        # character itself is escaped first.
        like_escape = "/"
        escaped_title = (
            title.replace(like_escape, like_escape * 2)
            .replace("%", f"{like_escape}%")
            .replace("_", f"{like_escape}_")
        )

        # Search accessible threads by title (case-insensitive)
        query = (
            select(NewChatThread)
            .filter(
                NewChatThread.search_space_id == search_space_id,
                NewChatThread.title.ilike(f"%{escaped_title}%", escape=like_escape),
                or_(*filter_conditions),
            )
            .order_by(NewChatThread.updated_at.desc())
        )
        result = await session.execute(query)
        threads = result.scalars().all()

        return [
            ThreadListItem(
                id=thread.id,
                title=thread.title,
                archived=thread.archived,
                visibility=thread.visibility,
                created_by_id=thread.created_by_id,
                # Legacy threads (no creator) are treated as own threads for owner
                is_own_thread=(
                    thread.created_by_id == user.id
                    or (thread.created_by_id is None and is_search_space_owner)
                ),
                created_at=thread.created_at,
                updated_at=thread.updated_at,
            )
            for thread in threads
        ]
    except HTTPException:
        raise
    except OperationalError:
        raise HTTPException(
            status_code=503, detail="Database operation failed. Please try again later."
        ) from None
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"An unexpected error occurred while searching threads: {e!s}",
        ) from None
@router.post("/threads", response_model=NewChatThreadRead)
async def create_thread(
    thread: NewChatThreadCreate,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """
    Create a chat thread in the search space named by the payload.

    The thread takes its title, archived flag and visibility from the
    payload, and the current user is stored as its creator.

    Raises:
        HTTPException: 400 on a database constraint violation, 503 on a
            database operational error, 500 on any other unexpected error
            (the session is rolled back in all three cases).

    Requires CHATS_CREATE permission.
    """
    try:
        await check_permission(
            session,
            user,
            thread.search_space_id,
            Permission.CHATS_CREATE.value,
            "You don't have permission to create chats in this search space",
        )

        new_thread = NewChatThread(
            title=thread.title,
            archived=thread.archived,
            visibility=thread.visibility,
            search_space_id=thread.search_space_id,
            created_by_id=user.id,
            updated_at=datetime.now(UTC),
        )
        session.add(new_thread)
        await session.commit()
        # Reload the row after the commit before handing it back.
        await session.refresh(new_thread)
        return new_thread
    except HTTPException:
        raise
    except IntegrityError:
        await session.rollback()
        raise HTTPException(
            status_code=400,
            detail="Database constraint violation. Please check your input data.",
        ) from None
    except OperationalError:
        await session.rollback()
        raise HTTPException(
            status_code=503, detail="Database operation failed. Please try again later."
        ) from None
    except Exception as exc:
        await session.rollback()
        raise HTTPException(
            status_code=500,
            detail=f"An unexpected error occurred while creating the thread: {exc!s}",
        ) from None
@router.get("/threads/{thread_id}", response_model=ThreadHistoryLoadResponse)
async def get_thread_messages(
    thread_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """
    Get a thread with all its messages.
    This is used by ThreadHistoryAdapter.load() to restore conversation.

    Access is granted if:
    - User is the creator of the thread
    - Thread visibility is SEARCH_SPACE

    Messages are returned oldest-first; messages sharing the same creation
    timestamp are ordered by id so the result is stable across requests.

    Raises:
        HTTPException: 404 if the thread does not exist, 403 if access is
            denied, 503 on database operational errors, 500 on any other
            unexpected error.

    Requires CHATS_READ permission.
    """
    try:
        # Get thread first
        result = await session.execute(
            select(NewChatThread).filter(NewChatThread.id == thread_id)
        )
        thread = result.scalars().first()
        if not thread:
            raise HTTPException(status_code=404, detail="Thread not found")

        # Check permission to read chats in this search space
        await check_permission(
            session,
            user,
            thread.search_space_id,
            Permission.CHATS_READ.value,
            "You don't have permission to read chats in this search space",
        )

        # Check thread-level access based on visibility
        await check_thread_access(session, thread, user)

        # Get messages with their authors and token usage loaded
        messages_result = await session.execute(
            select(NewChatMessage)
            .options(
                selectinload(NewChatMessage.author),
                selectinload(NewChatMessage.token_usage),
            )
            .filter(NewChatMessage.thread_id == thread_id)
            # ``created_at`` alone leaves the order of equal timestamps up to
            # the database; the id tiebreaker makes the history deterministic.
            # NOTE(review): assumes message ids increase with insertion order.
            .order_by(NewChatMessage.created_at, NewChatMessage.id)
        )
        db_messages = messages_result.scalars().all()

        # Return messages in the format expected by assistant-ui
        messages = [
            NewChatMessageRead(
                id=msg.id,
                thread_id=msg.thread_id,
                role=msg.role,
                content=msg.content,
                created_at=msg.created_at,
                author_id=msg.author_id,
                author_display_name=msg.author.display_name if msg.author else None,
                author_avatar_url=msg.author.avatar_url if msg.author else None,
                token_usage=TokenUsageSummary.model_validate(msg.token_usage)
                if msg.token_usage
                else None,
                turn_id=msg.turn_id,
            )
            for msg in db_messages
        ]
        return ThreadHistoryLoadResponse(messages=messages)
    except HTTPException:
        raise
    except OperationalError:
        raise HTTPException(
            status_code=503, detail="Database operation failed. Please try again later."
        ) from None
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"An unexpected error occurred while fetching the thread: {e!s}",
        ) from None
@router.get("/threads/{thread_id}/full", response_model=NewChatThreadWithMessages)
async def get_thread_full(
    thread_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """
    Return a thread together with its messages and a ``has_comments`` flag.

    The thread is loaded with its messages (and each message's token usage)
    eagerly; ``has_comments`` is true when at least one comment is attached
    to any message of the thread.

    Access is granted if:
    - User is the creator of the thread
    - Thread visibility is SEARCH_SPACE

    Raises:
        HTTPException: 404 if the thread does not exist, 403 if access is
            denied, 503 on database operational errors, 500 on any other
            unexpected error.

    Requires CHATS_READ permission.
    """
    try:
        thread_query = (
            select(NewChatThread)
            .options(
                selectinload(NewChatThread.messages).selectinload(
                    NewChatMessage.token_usage
                ),
            )
            .filter(NewChatThread.id == thread_id)
        )
        thread = (await session.execute(thread_query)).scalars().first()
        if not thread:
            raise HTTPException(status_code=404, detail="Thread not found")

        await check_permission(
            session,
            user,
            thread.search_space_id,
            Permission.CHATS_READ.value,
            "You don't have permission to read chats in this search space",
        )

        # Thread-level access based on visibility
        await check_thread_access(session, thread, user)

        # Count comments attached to any message of this thread
        comment_count_query = (
            select(func.count())
            .select_from(ChatComment)
            .join(NewChatMessage, ChatComment.message_id == NewChatMessage.id)
            .where(NewChatMessage.thread_id == thread.id)
        )
        comment_count = await session.scalar(comment_count_query)

        # Start from the instance's attribute dict and layer the extra
        # response fields on top of it.
        payload = dict(thread.__dict__)
        payload["messages"] = thread.messages
        payload["has_comments"] = (comment_count or 0) > 0
        return payload
    except HTTPException:
        raise
    except OperationalError:
        raise HTTPException(
            status_code=503, detail="Database operation failed. Please try again later."
        ) from None
    except Exception as exc:
        raise HTTPException(
            status_code=500,
            detail=f"An unexpected error occurred while fetching the thread: {exc!s}",
        ) from None
@router.put("/threads/{thread_id}", response_model=NewChatThreadRead)
async def update_thread(
    thread_id: int,
    thread_update: NewChatThreadUpdate,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """
    Update a thread (title, archived status).
    Used for renaming and archiving threads.

    - PRIVATE threads: Only the creator can update
    - Legacy PRIVATE threads (no recorded creator): Only the search space
      owner can update, matching the rule applied when deleting threads
    - SEARCH_SPACE threads: Any member with CHATS_UPDATE permission can update

    Only the fields explicitly present in the request payload are written;
    ``updated_at`` is always refreshed.

    Raises:
        HTTPException: 404 if the thread does not exist, 403 if access is
            denied, 400 on a database constraint violation, 503 on a database
            operational error, 500 on any other unexpected error.

    Requires CHATS_UPDATE permission.
    """
    try:
        result = await session.execute(
            select(NewChatThread).filter(NewChatThread.id == thread_id)
        )
        db_thread = result.scalars().first()
        if not db_thread:
            raise HTTPException(status_code=404, detail="Thread not found")

        await check_permission(
            session,
            user,
            db_thread.search_space_id,
            Permission.CHATS_UPDATE.value,
            "You don't have permission to update chats in this search space",
        )

        # For PRIVATE threads, only the creator can update
        # For SEARCH_SPACE threads, any member with permission can update
        # Legacy threads (created_by_id is NULL) have no recorded creator, so
        # strict ownership could never be satisfied; skip it and fall through
        # to the legacy handling, which admits the search space owner only.
        if db_thread.visibility == ChatVisibility.PRIVATE:
            await check_thread_access(
                session,
                db_thread,
                user,
                require_ownership=(db_thread.created_by_id is not None),
            )

        # Update fields
        update_data = thread_update.model_dump(exclude_unset=True)
        for key, value in update_data.items():
            setattr(db_thread, key, value)
        db_thread.updated_at = datetime.now(UTC)

        await session.commit()
        await session.refresh(db_thread)
        return db_thread
    except HTTPException:
        raise
    except IntegrityError:
        await session.rollback()
        raise HTTPException(
            status_code=400,
            detail="Database constraint violation. Please check your input data.",
        ) from None
    except OperationalError:
        await session.rollback()
        raise HTTPException(
            status_code=503, detail="Database operation failed. Please try again later."
        ) from None
    except Exception as e:
        await session.rollback()
        raise HTTPException(
            status_code=500,
            detail=f"An unexpected error occurred while updating the thread: {e!s}",
        ) from None
@router.delete("/threads/{thread_id}", response_model=dict)
async def delete_thread(
    thread_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """
    Delete a thread, then schedule cleanup of its sandbox.

    - PRIVATE threads: Only the creator can delete
    - Legacy PRIVATE threads (no recorded creator): Only the search space owner can delete
    - SEARCH_SPACE threads: Any member with CHATS_DELETE permission can delete

    Raises:
        HTTPException: 404 if the thread does not exist, 403 if access is
            denied, 400 if the delete violates a database constraint, 503 on
            a database operational error, 500 on any other unexpected error.

    Requires CHATS_DELETE permission.
    """
    try:
        lookup = await session.execute(
            select(NewChatThread).filter(NewChatThread.id == thread_id)
        )
        target = lookup.scalars().first()
        if not target:
            raise HTTPException(status_code=404, detail="Thread not found")

        await check_permission(
            session,
            user,
            target.search_space_id,
            Permission.CHATS_DELETE.value,
            "You don't have permission to delete chats in this search space",
        )

        if target.visibility == ChatVisibility.PRIVATE:
            # Strict ownership only makes sense when a creator was recorded.
            # Legacy threads (created_by_id is NULL) go through the regular
            # access check instead, which admits the search space owner.
            has_recorded_creator = target.created_by_id is not None
            await check_thread_access(
                session,
                target,
                user,
                require_ownership=has_recorded_creator,
            )

        await session.delete(target)
        await session.commit()

        # Sandbox cleanup runs in the background, after the row is gone.
        _try_delete_sandbox(thread_id)

        return {"message": "Thread deleted successfully"}
    except HTTPException:
        raise
    except IntegrityError:
        await session.rollback()
        raise HTTPException(
            status_code=400, detail="Cannot delete thread due to existing dependencies."
        ) from None
    except OperationalError:
        await session.rollback()
        raise HTTPException(
            status_code=503, detail="Database operation failed. Please try again later."
        ) from None
    except Exception as exc:
        await session.rollback()
        raise HTTPException(
            status_code=500,
            detail=f"An unexpected error occurred while deleting the thread: {exc!s}",
        ) from None
2026-01-13 00:17:12 -08:00
@router.patch("/threads/{thread_id}/visibility", response_model=NewChatThreadRead)
async def update_thread_visibility(
    thread_id: int,
    visibility_update: NewChatThreadVisibilityUpdate,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """
    Change who can see a thread.

    Only the creator of the thread can change its visibility.

    - PRIVATE: Only the creator can access the thread (default)
    - SEARCH_SPACE: All members of the search space can access the thread

    Requires CHATS_UPDATE permission.

    Raises:
        HTTPException: 404 when the thread does not exist, 400 on a database
            constraint violation, 503 on a database operational error, 500 on
            any other unexpected failure. Errors raised by the permission and
            ownership checks propagate unchanged.
    """
    try:
        thread_query = select(NewChatThread).filter(NewChatThread.id == thread_id)
        target_thread = (await session.execute(thread_query)).scalars().first()
        if not target_thread:
            raise HTTPException(status_code=404, detail="Thread not found")

        # Search-space level permission first, then thread ownership.
        await check_permission(
            session,
            user,
            target_thread.search_space_id,
            Permission.CHATS_UPDATE.value,
            "You don't have permission to update chats in this search space",
        )
        await check_thread_access(
            session, target_thread, user, require_ownership=True
        )

        # Apply the new visibility and bump the activity timestamp.
        target_thread.visibility = visibility_update.visibility
        target_thread.updated_at = datetime.now(UTC)
        await session.commit()
        await session.refresh(target_thread)
    except HTTPException:
        raise
    except IntegrityError:
        await session.rollback()
        raise HTTPException(
            status_code=400,
            detail="Database constraint violation. Please check your input data.",
        ) from None
    except OperationalError:
        await session.rollback()
        raise HTTPException(
            status_code=503, detail="Database operation failed. Please try again later."
        ) from None
    except Exception as exc:
        await session.rollback()
        raise HTTPException(
            status_code=500,
            detail=f"An unexpected error occurred while updating thread visibility: {exc!s}",
        ) from None
    else:
        return target_thread
# =============================================================================
# Snapshot Endpoints
# =============================================================================
2026-02-03 13:32:11 -08:00
@router.post(
    "/threads/{thread_id}/snapshots", response_model=PublicChatSnapshotCreateResponse
)
async def create_thread_snapshot(
    thread_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """
    Create a public snapshot of the thread.

    Returns existing snapshot URL if content unchanged (deduplication).
    The work is delegated to ``public_chat_service.create_snapshot``.
    """
    # Imported lazily, at call time rather than module import time.
    from app.services.public_chat_service import create_snapshot

    snapshot_response = await create_snapshot(
        session=session,
        thread_id=thread_id,
        user=user,
    )
    return snapshot_response
2026-02-03 13:32:11 -08:00
@router.get(
    "/threads/{thread_id}/snapshots", response_model=PublicChatSnapshotListResponse
)
async def list_thread_snapshots(
    thread_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """
    List all public snapshots for this thread.

    Only the thread owner can view snapshots.
    The lookup is delegated to ``public_chat_service.list_snapshots_for_thread``.
    """
    # Imported lazily, at call time rather than module import time.
    from app.services.public_chat_service import list_snapshots_for_thread

    thread_snapshots = await list_snapshots_for_thread(
        session=session,
        thread_id=thread_id,
        user=user,
    )
    return PublicChatSnapshotListResponse(snapshots=thread_snapshots)
@router.delete("/threads/{thread_id}/snapshots/{snapshot_id}")
async def delete_thread_snapshot(
    thread_id: int,
    snapshot_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """
    Delete a specific snapshot.

    Only the thread owner can delete snapshots.
    The deletion is delegated to ``public_chat_service.delete_snapshot``;
    on success a confirmation message is returned.
    """
    # Imported lazily, at call time rather than module import time.
    from app.services.public_chat_service import delete_snapshot

    deletion_kwargs = {
        "session": session,
        "thread_id": thread_id,
        "snapshot_id": snapshot_id,
        "user": user,
    }
    await delete_snapshot(**deletion_kwargs)
    return {"message": "Snapshot deleted successfully"}
# =============================================================================
# Message Endpoints
# =============================================================================
@router.post("/threads/{thread_id}/messages", response_model=NewChatMessageRead)
async def append_message(
    thread_id: int,
    request: Request,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """
    Append a message to a thread.

    .. deprecated:: 2026-05
        Replaced by the **SSE-based message ID handshake**. The streaming
        generator (`stream_new_chat` / `stream_resume_chat`) now persists
        both the user and assistant rows server-side via
        ``persist_user_turn`` / ``persist_assistant_shell`` and emits
        ``data-user-message-id`` / ``data-assistant-message-id`` SSE events
        so the frontend can rename its optimistic IDs in real time. The
        new FE bundle no longer calls this route.

    This handler is retained as a **silent no-op for legacy / cached
    FE bundles**: the underlying ``INSERT ... ON CONFLICT DO NOTHING``
    pattern means a stale bundle hitting this route after the SSE
    handshake already wrote the row simply returns the existing row
    (200 OK) without raising or duplicating data. After a 2-week soak
    (target: ``[persist_user_turn] outcome=race_recovered`` rate ~0)
    this entire route and the FE ``appendMessage`` function is
    earmarked for removal.

    This is used by ThreadHistoryAdapter.append() to persist messages.

    The request body must be a JSON object carrying ``role`` and
    ``content``. ``turn_id`` and ``token_usage`` are optional; any other
    fields are ignored.

    Access is granted if:
    - User is the creator of the thread
    - Thread visibility is SEARCH_SPACE

    Requires CHATS_UPDATE permission.

    Raises:
        HTTPException: 400 when the body is not a valid JSON object, a
            required field is missing, the role is invalid, or a database
            constraint is violated; 404 when the thread does not exist;
            503 on a database operational error; 500 on any other
            unexpected failure. Errors raised by the permission and access
            checks propagate unchanged.
    """
    try:
        # Capture ``user.id`` as a primitive UUID up front. The
        # ``current_active_user`` dependency hands us a ``User`` ORM
        # row bound to ``session``; if the outer ``except
        # IntegrityError`` block below ever fires (an unexpected
        # constraint like a foreign key violation — the common
        # ``(thread_id, turn_id, role)`` race is now handled silently
        # by ``ON CONFLICT DO NOTHING`` so it never raises) it calls
        # ``session.rollback()``, which expires every attached ORM
        # row including this user. Any later ``user.id`` access would
        # then trigger a lazy PK reload — which on async SQLAlchemy
        # fails with ``MissingGreenlet`` because the reload happens
        # outside the awaitable greenlet boundary. Reading ``id``
        # once here pins the value as a plain UUID so all downstream
        # uses (TokenUsage insert, response build) are immune.
        user_uuid = user.id

        # Parse raw body - extract only role and content, ignoring extra fields.
        # A body that is not valid JSON raises ``ValueError`` (JSONDecodeError /
        # UnicodeDecodeError are subclasses); report that as a client error
        # instead of letting it fall through to the generic 500 handler.
        try:
            raw_body = await request.json()
        except ValueError:
            raise HTTPException(
                status_code=400, detail="Request body must be valid JSON."
            ) from None
        # Valid JSON that is not an object (list, string, number, null) has no
        # ``.get`` and would otherwise crash with AttributeError.
        if not isinstance(raw_body, dict):
            raise HTTPException(
                status_code=400, detail="Request body must be a JSON object."
            )

        role = raw_body.get("role")
        content = raw_body.get("content")
        if not role:
            raise HTTPException(status_code=400, detail="Missing required field: role")
        if content is None:
            raise HTTPException(
                status_code=400, detail="Missing required field: content"
            )

        # Validate role early (before any DB work)
        role_str = role.lower() if isinstance(role, str) else role
        try:
            message_role = NewChatMessageRole(role_str)
        except ValueError:
            raise HTTPException(
                status_code=400,
                detail=f"Invalid role: {role}. Must be 'user', 'assistant', or 'system'.",
            ) from None

        # Get thread
        result = await session.execute(
            select(NewChatThread).filter(NewChatThread.id == thread_id)
        )
        thread = result.scalars().first()
        if not thread:
            raise HTTPException(status_code=404, detail="Thread not found")

        await check_permission(
            session,
            user,
            thread.search_space_id,
            Permission.CHATS_UPDATE.value,
            "You don't have permission to update chats in this search space",
        )

        # Check thread-level access based on visibility
        await check_thread_access(session, thread, user)

        # Create message. ``turn_id`` is the per-turn correlation id from
        # ``configurable.turn_id`` (added in migration 136) — when the
        # client streams it back to ``appendMessage``, we persist it so
        # C1's edit-from-arbitrary-position can later map this message
        # back to the LangGraph checkpoint that produced its turn.
        raw_turn_id = raw_body.get("turn_id")
        turn_id_value = (
            str(raw_turn_id).strip()
            if isinstance(raw_turn_id, str) and raw_turn_id.strip()
            else None
        )

        # Update thread's updated_at timestamp (always — both insert
        # and recovery paths represent thread activity).
        thread.updated_at = datetime.now(UTC)

        # Insert the new message via ``INSERT ... ON CONFLICT DO NOTHING``
        # keyed on the ``(thread_id, turn_id, role)`` partial unique
        # index from migration 141 (``WHERE turn_id IS NOT NULL``).
        #
        # Why ON CONFLICT instead of ``session.add() + flush() + except
        # IntegrityError``:
        #   1. The conflict between this legacy FE ``appendMessage``
        #      round-trip and the server-side
        #      ``finalize_assistant_turn`` writer is a NORMAL,
        #      *expected* race — every assistant turn fires it. Using
        #      catch-and-recover means asyncpg raises
        #      ``UniqueViolationError`` -> SQLAlchemy wraps it as
        #      ``IntegrityError`` -> our handler catches and recovers.
        #      Functionally fine, but every ``raise`` event lights up
        #      VS Code's debugger (debugpy's ``justMyCode=false`` mode
        #      loses track of the catch frame across SQLAlchemy's
        #      async greenlet boundary, so even ``Raised Exceptions``
        #      being unchecked doesn't reliably suppress the pause).
        #      ON CONFLICT pushes the conflict resolution into Postgres
        #      where no Python exception is constructed at all.
        #   2. No ``session.rollback()`` -> no expiring of attached
        #      ORM rows -> no risk of ``MissingGreenlet`` from
        #      lazy-loading expired user/thread state later in the
        #      handler.
        #   3. Cleaner production logs (no SQLAlchemy ``IntegrityError``
        #      tracebacks emitted by uvicorn's logger between the
        #      ``raise`` and our ``except``).
        #
        # When ``turn_id_value`` is ``None`` the partial index doesn't
        # apply and the INSERT proceeds normally. Other constraint
        # violations (FK, NOT NULL, etc.) still raise ``IntegrityError``
        # and are caught by the outer ``except IntegrityError`` block
        # to preserve the legacy 400 behavior.
        #
        # Note on ``content``: when we recover the existing row, we
        # intentionally discard the FE's ``content`` payload from
        # ``raw_body`` and return the row's existing ``content``. The
        # streaming task is now the *authoritative writer* for
        # assistant ``ContentPart[]`` shape (mid-stream
        # ``AssistantContentBuilder`` -> ``finalize_assistant_turn``)
        # so the FE's later ``appendMessage`` is just a stale snapshot
        # of the same data — keeping the server-built rich content
        # (with full tool-call args / argsText / langchainToolCallId)
        # is correct, not lossy.
        insert_stmt = (
            pg_insert(NewChatMessage)
            .values(
                thread_id=thread_id,
                role=message_role,
                content=content,
                author_id=user_uuid,
                turn_id=turn_id_value,
            )
            .on_conflict_do_nothing(
                index_elements=["thread_id", "turn_id", "role"],
                index_where=sa_text("turn_id IS NOT NULL"),
            )
            .returning(NewChatMessage.id)
        )
        inserted_id = (await session.execute(insert_stmt)).scalar()

        if inserted_id is None:
            # Conflict on partial unique index — server-side stream
            # already wrote this row. Look it up and reuse it.
            if turn_id_value is None:
                # Defensive: ON CONFLICT only fires for ``turn_id IS
                # NOT NULL`` rows, so this branch should be
                # unreachable. Preserve the legacy 400 just in case
                # Postgres ever surprises us.
                raise HTTPException(
                    status_code=400,
                    detail="Database constraint violation. Please check your input data.",
                )
            lookup = await session.execute(
                select(NewChatMessage).filter(
                    NewChatMessage.thread_id == thread_id,
                    NewChatMessage.turn_id == turn_id_value,
                    NewChatMessage.role == message_role,
                )
            )
            existing_message = lookup.scalars().first()
            if existing_message is None:
                # Conflict reported but the row vanished between
                # INSERT and SELECT — extremely unlikely (would
                # require a concurrent DELETE within the same
                # transaction visibility), but preserve safe
                # behavior.
                raise HTTPException(
                    status_code=400,
                    detail="Database constraint violation. Please check your input data.",
                )
            db_message = existing_message
            # Perf signal: counts how often the legacy FE round-trip
            # races the server-side ``finalize_assistant_turn``. A
            # rising rate after the rework is OK (it's exactly the
            # ghost-thread fix's recovery path firing); a sudden drop
            # to zero would mean the FE isn't posting appendMessage
            # at all (different bug).
            _perf_log.info(
                "[append_message] outcome=recovered_via_unique_index "
                "thread_id=%s turn_id=%s role=%s message_id=%s",
                thread_id,
                turn_id_value,
                message_role.value,
                db_message.id,
            )
        else:
            # INSERT succeeded — load the full ORM row so the
            # response can include server-side-defaulted columns
            # (``created_at``, etc.) and the relationship surface
            # stays consistent with the recovery path.
            inserted_row = await session.get(NewChatMessage, inserted_id)
            if inserted_row is None:
                # Should be impossible: we just inserted it in this
                # same transaction. Fail loud if it happens.
                raise HTTPException(
                    status_code=500,
                    detail="Inserted message could not be loaded.",
                )
            db_message = inserted_row

        # Persist token usage if provided (for assistant messages).
        # ``cost_micros`` is the provider USD cost reported by LiteLLM,
        # forwarded by the FE through the appendMessage round-trip so
        # the historical TokenUsage row matches the credit debit applied
        # at finalize time.
        #
        # De-dup: ``finalize_assistant_turn`` may also race to write a
        # token_usage row for this same ``message_id`` (cross-session,
        # cross-shielded). Use ``INSERT ... ON CONFLICT DO NOTHING`` keyed
        # on the ``uq_token_usage_message_id`` partial unique index
        # (migration 142). The loser silently drops its insert; exactly
        # one row results regardless of which writer commits first.
        #
        # ``token_usage`` is optional, best-effort telemetry from the
        # client: a payload that is not a JSON object is skipped rather
        # than crashing on ``.get`` and rolling back the message insert.
        token_usage_data = raw_body.get("token_usage")
        if (
            isinstance(token_usage_data, dict)
            and token_usage_data
            and message_role == NewChatMessageRole.ASSISTANT
        ):
            usage_stmt = (
                pg_insert(TokenUsage)
                .values(
                    usage_type="chat",
                    prompt_tokens=token_usage_data.get("prompt_tokens", 0),
                    completion_tokens=token_usage_data.get("completion_tokens", 0),
                    total_tokens=token_usage_data.get("total_tokens", 0),
                    cost_micros=token_usage_data.get("cost_micros", 0),
                    model_breakdown=token_usage_data.get("usage"),
                    call_details=token_usage_data.get("call_details"),
                    thread_id=thread_id,
                    message_id=db_message.id,
                    search_space_id=thread.search_space_id,
                    user_id=user_uuid,
                )
                .on_conflict_do_nothing(
                    index_elements=["message_id"],
                    index_where=sa_text("message_id IS NOT NULL"),
                )
            )
            await session.execute(usage_stmt)

        await session.commit()

        # Build response manually to avoid lazy-loading the token_usage
        # relationship after commit (which would trigger MissingGreenlet).
        return NewChatMessageRead(
            id=db_message.id,
            thread_id=db_message.thread_id,
            role=db_message.role,
            content=db_message.content,
            created_at=db_message.created_at,
            author_id=db_message.author_id,
            token_usage=None,
            turn_id=db_message.turn_id,
        )
    except HTTPException:
        raise
    except IntegrityError:
        # Any IntegrityError that escaped the inline handler above
        # comes from a *different* constraint (foreign key, etc.) —
        # preserve the legacy 400 path.
        await session.rollback()
        raise HTTPException(
            status_code=400,
            detail="Database constraint violation. Please check your input data.",
        ) from None
    except OperationalError:
        await session.rollback()
        raise HTTPException(
            status_code=503, detail="Database operation failed. Please try again later."
        ) from None
    except Exception as e:
        await session.rollback()
        raise HTTPException(
            status_code=500,
            detail=f"An unexpected error occurred while appending the message: {e!s}",
        ) from None
@router.get("/threads/{thread_id}/messages", response_model=list[NewChatMessageRead])
async def list_messages(
    thread_id: int,
    skip: int = 0,
    limit: int = 100,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """
    List messages in a thread with pagination.

    Messages are returned oldest first, with their ``token_usage``
    relationship eagerly loaded. ``skip`` is the number of rows to skip and
    ``limit`` the maximum number of rows to return.

    Access is granted if:
    - User is the creator of the thread
    - Thread visibility is SEARCH_SPACE

    Requires CHATS_READ permission.

    Raises:
        HTTPException: 400 when ``skip`` or ``limit`` is negative, 404 when
            the thread does not exist, 503 on a database operational error,
            500 on any other unexpected failure. Errors raised by the
            permission and access checks propagate unchanged.
    """
    try:
        # Reject negative paging values up front: PostgreSQL refuses a
        # negative OFFSET / LIMIT, which would otherwise surface as a 500.
        if skip < 0 or limit < 0:
            raise HTTPException(
                status_code=400,
                detail="skip and limit must be non-negative integers.",
            )

        # Verify thread exists and user has access
        result = await session.execute(
            select(NewChatThread).filter(NewChatThread.id == thread_id)
        )
        thread = result.scalars().first()
        if not thread:
            raise HTTPException(status_code=404, detail="Thread not found")

        await check_permission(
            session,
            user,
            thread.search_space_id,
            Permission.CHATS_READ.value,
            "You don't have permission to read chats in this search space",
        )

        # Check thread-level access based on visibility
        await check_thread_access(session, thread, user)

        # Get messages. ``id`` is a tiebreaker so rows sharing the same
        # ``created_at`` keep a stable order across pages; without it
        # OFFSET/LIMIT paging can repeat or skip such rows.
        query = (
            select(NewChatMessage)
            .options(selectinload(NewChatMessage.token_usage))
            .filter(NewChatMessage.thread_id == thread_id)
            .order_by(NewChatMessage.created_at, NewChatMessage.id)
            .offset(skip)
            .limit(limit)
        )
        result = await session.execute(query)
        return result.scalars().all()
    except HTTPException:
        raise
    except OperationalError:
        raise HTTPException(
            status_code=503, detail="Database operation failed. Please try again later."
        ) from None
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"An unexpected error occurred while fetching messages: {e!s}",
        ) from None
2025-12-21 19:33:52 -08:00
# =============================================================================
# Agent Tools Endpoint
# =============================================================================
@router.get("/agent/tools", response_model=list[AgentToolInfo])
async def list_agent_tools(
    _user: User = Depends(current_active_user),
):
    """Return the list of built-in agent tools with their metadata.

    Hidden (WIP) tools are excluded from the response. The ``_user``
    dependency is unused in the body; it only forces authentication.
    """
    # Imported lazily, at call time rather than module import time.
    from app.agents.new_chat.tools.registry import BUILTIN_TOOLS

    visible_tools = []
    for tool in BUILTIN_TOOLS:
        if tool.hidden:
            continue
        visible_tools.append(
            AgentToolInfo(
                name=tool.name,
                description=tool.description,
                enabled_by_default=tool.enabled_by_default,
            )
        )
    return visible_tools
2025-12-21 19:33:52 -08:00
# =============================================================================
# Chat Streaming Endpoint
# =============================================================================
@router.post("/new_chat")
async def handle_new_chat(
    request: NewChatRequest,
    http_request: Request,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """
    Stream chat responses from the deep agent.

    This endpoint handles the new chat functionality with streaming responses
    using Server-Sent Events (SSE) format compatible with Vercel AI SDK.

    Access is granted if:
    - User is the creator of the thread
    - Thread visibility is SEARCH_SPACE

    Requires CHATS_CREATE permission in the thread's search space and, when
    ``request.search_space_id`` names a different search space, in that
    search space as well.

    Raises:
        HTTPException: 404 when the thread or the requested search space does
            not exist; 500 on any unexpected failure while preparing the
            stream. Errors raised by the permission, access, busy-thread and
            filesystem-selection helpers propagate unchanged.
    """
    try:
        # Verify thread exists and user has permission
        result = await session.execute(
            select(NewChatThread).filter(NewChatThread.id == request.chat_id)
        )
        thread = result.scalars().first()
        if not thread:
            raise HTTPException(status_code=404, detail="Thread not found")

        await check_permission(
            session,
            user,
            thread.search_space_id,
            Permission.CHATS_CREATE.value,
            "You don't have permission to chat in this search space",
        )

        # Check thread-level access based on visibility
        await check_thread_access(session, thread, user)

        # The LLM config lookup and the stream below run against the
        # client-supplied ``request.search_space_id``, while the check above
        # only covers the thread's own search space. When the two differ,
        # authorize the requested search space too so a caller cannot point
        # the agent at a search space they have no access to.
        if request.search_space_id != thread.search_space_id:
            await check_permission(
                session,
                user,
                request.search_space_id,
                Permission.CHATS_CREATE.value,
                "You don't have permission to chat in this search space",
            )

        _raise_if_thread_busy_for_start(request.chat_id)
        filesystem_selection = _resolve_filesystem_selection(
            mode=request.filesystem_mode,
            client_platform=request.client_platform,
            local_mounts=request.local_filesystem_mounts,
        )

        # Get search space to check LLM config preferences
        search_space_result = await session.execute(
            select(SearchSpace).filter(SearchSpace.id == request.search_space_id)
        )
        search_space = search_space_result.scalars().first()
        if not search_space:
            raise HTTPException(status_code=404, detail="Search space not found")

        # Use agent_llm_id from search space for chat operations
        # Positive IDs load from NewLLMConfig database table
        # Negative IDs load from YAML global configs
        # Falls back to -1 (first global config) if not configured
        llm_config_id = (
            search_space.agent_llm_id if search_space.agent_llm_id is not None else -1
        )

        # Release the read-transaction so we don't hold ACCESS SHARE locks
        # on searchspaces/documents for the entire duration of the stream.
        # expire_on_commit=False keeps loaded ORM attrs usable.
        await session.commit()

        # Close the dependency session now so its connection returns to
        # the pool before streaming begins. Without this, Starlette's
        # BaseHTTPMiddleware cancels the scope on client disconnect and
        # the dependency generator's __aexit__ never runs, orphaning the
        # connection (the "Exception terminating connection" errors).
        await session.close()

        image_urls = (
            [p.as_data_url() for p in request.user_images]
            if request.user_images
            else None
        )
        mentioned_documents_payload = (
            [doc.model_dump() for doc in request.mentioned_documents]
            if request.mentioned_documents
            else None
        )

        return StreamingResponse(
            stream_new_chat(
                user_query=request.user_query,
                search_space_id=request.search_space_id,
                chat_id=request.chat_id,
                user_id=str(user.id),
                llm_config_id=llm_config_id,
                mentioned_document_ids=request.mentioned_document_ids,
                mentioned_surfsense_doc_ids=request.mentioned_surfsense_doc_ids,
                mentioned_documents=mentioned_documents_payload,
                needs_history_bootstrap=thread.needs_history_bootstrap,
                thread_visibility=thread.visibility,
                current_user_display_name=user.display_name or "A team member",
                disabled_tools=request.disabled_tools,
                filesystem_selection=filesystem_selection,
                request_id=getattr(http_request.state, "request_id", "unknown"),
                user_image_data_urls=image_urls,
            ),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "X-Accel-Buffering": "no",
            },
        )
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"An unexpected error occurred: {e!s}",
        ) from None
2025-12-21 22:26:33 -08:00
@router.post(
    "/threads/{thread_id}/cancel-active-turn",
    response_model=CancelActiveTurnResponse,
)
async def cancel_active_turn(
    thread_id: int,
    response: Response,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Signal cancellation for the currently running turn on ``thread_id``.

    Requires CHATS_UPDATE permission and thread-level access.

    Returns:
        ``status="idle"`` / ``NO_ACTIVE_TURN`` when the turn status payload
        reports no active turn (response status code left at its default).
        Otherwise cancellation is requested, the response status code is set
        to 202 and ``status="cancelling"`` / ``TURN_CANCELLING`` is returned
        together with any retry hints found in the refreshed status payload.

    Raises:
        HTTPException: 404 when the thread does not exist. Errors raised by
            the permission and access checks propagate unchanged.
    """
    result = await session.execute(
        select(NewChatThread).filter(NewChatThread.id == thread_id)
    )
    thread = result.scalars().first()
    if not thread:
        raise HTTPException(status_code=404, detail="Thread not found")

    await check_permission(
        session,
        user,
        thread.search_space_id,
        Permission.CHATS_UPDATE.value,
        "You don't have permission to update chats in this search space",
    )
    await check_thread_access(session, thread, user)

    status_payload = _build_turn_status_payload(thread_id)
    if status_payload["status"] == "idle":
        return CancelActiveTurnResponse(
            status="idle",
            error_code="NO_ACTIVE_TURN",
        )

    request_cancel(str(thread_id))
    response.status_code = 202

    # Re-read the status after requesting cancellation to pick up retry hints.
    updated_payload = _build_turn_status_payload(thread_id)
    retry_after_ms = int(updated_payload.get("retry_after_ms") or 0)
    # Treat a missing key and an explicit ``None`` the same way, mirroring the
    # ``or 0`` guard used for ``retry_after_ms``; ``int(None)`` would raise.
    raw_retry_after_at = updated_payload.get("retry_after_at")
    retry_after_at = (
        int(raw_retry_after_at) if raw_retry_after_at is not None else None
    )
    if retry_after_ms > 0:
        _set_retry_after_headers(response, retry_after_ms)

    return CancelActiveTurnResponse(
        status="cancelling",
        error_code="TURN_CANCELLING",
        retry_after_ms=retry_after_ms if retry_after_ms > 0 else None,
        retry_after_at=retry_after_at,
    )
@router.get(
    "/threads/{thread_id}/turn-status",
    response_model=TurnStatusResponse,
)
async def get_turn_status(
    thread_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Report the current turn status for ``thread_id``.

    Requires CHATS_READ permission and thread-level access. The status and
    retry hints come from ``_build_turn_status_payload``; ``active_turn_id``
    is always returned as ``None``.

    Raises:
        HTTPException: 404 when the thread does not exist. Errors raised by
            the permission and access checks propagate unchanged.
    """
    thread_query = select(NewChatThread).filter(NewChatThread.id == thread_id)
    thread = (await session.execute(thread_query)).scalars().first()
    if not thread:
        raise HTTPException(status_code=404, detail="Thread not found")

    await check_permission(
        session,
        user,
        thread.search_space_id,
        Permission.CHATS_READ.value,
        "You don't have permission to view chats in this search space",
    )
    await check_thread_access(session, thread, user)

    payload = _build_turn_status_payload(thread_id)
    current_status = payload["status"]
    retry_ms = payload.get("retry_after_ms")
    retry_at = payload.get("retry_after_at")
    return TurnStatusResponse(
        status=current_status,  # type: ignore[arg-type]
        active_turn_id=None,
        retry_after_ms=retry_ms,  # type: ignore[arg-type]
        retry_after_at=retry_at,  # type: ignore[arg-type]
    )
# =============================================================================
# Chat Regeneration Endpoint (Edit/Reload)
# =============================================================================
@router.post("/threads/{thread_id}/regenerate")
async def regenerate_response(
thread_id: int,
request: RegenerateRequest,
http_request: Request,
session: AsyncSession = Depends(get_async_session),
user: User = Depends(current_active_user),
):
"""
Regenerate the AI response for a chat thread.
This endpoint supports two operations:
1. **Edit**: Provide a new `user_query` to replace the last user message and regenerate
2. **Reload**: Leave `user_query` empty (or None) to regenerate with the same query
Both operations:
- Rewind the LangGraph checkpointer to the state before the last AI response
- Delete the last user message and AI response from the database
- Stream a new response from that checkpoint
Access is granted if:
- User is the creator of the thread
- Thread visibility is SEARCH_SPACE
Requires CHATS_UPDATE permission.
"""
from langchain_core.messages import HumanMessage
from app.agents.new_chat.checkpointer import get_checkpointer
try:
# Verify thread exists and user has permission
result = await session.execute(
select(NewChatThread).filter(NewChatThread.id == thread_id)
)
thread = result.scalars().first()
if not thread:
raise HTTPException(status_code=404, detail="Thread not found")
await check_permission(
session,
user,
thread.search_space_id,
Permission.CHATS_UPDATE.value,
"You don't have permission to update chats in this search space",
)
# Check thread-level access based on visibility
await check_thread_access(session, thread, user)
_raise_if_thread_busy_for_start(thread_id)
filesystem_selection = _resolve_filesystem_selection(
mode=request.filesystem_mode,
client_platform=request.client_platform,
local_mounts=request.local_filesystem_mounts,
)
# Get the checkpointer and state history
checkpointer = await get_checkpointer()
config = {"configurable": {"thread_id": str(thread_id)}}
# Collect checkpoint tuples from the async iterator
# CheckpointTuple has: config, checkpoint (dict with channel_values), metadata, parent_config
checkpoint_tuples = []
async for cp_tuple in checkpointer.alist(config):
checkpoint_tuples.append(cp_tuple)
if not checkpoint_tuples:
raise HTTPException(
status_code=400, detail="No conversation history found for this thread"
)
# Find the checkpoint to rewind to
# Checkpoints are in reverse chronological order (newest first)
# We need to find a checkpoint before the last user message was added
#
# The checkpointer stores states after each node execution.
# For a typical conversation flow:
# - User sends message -> state 1 (with HumanMessage)
# - Agent responds -> state 2 (with HumanMessage + AIMessage)
#
# To regenerate, we need the state BEFORE the last HumanMessage was processed
target_checkpoint_id = None
user_query_to_use = request.user_query
regenerate_image_urls: list[str] = []
2026-04-29 07:20:31 -07:00
# ---------------------------------------------------------------
# Edit-from-arbitrary-position. When the client passes
# ``from_message_id`` we look up its persisted ``turn_id`` (added
# in migration 136) and pick the checkpoint immediately before
# that turn started.
#
# Legacy graceful-degradation contract:
# * Rows persisted BEFORE migration 136 have ``turn_id IS NULL``.
# Returning 400 in that case is the wrong UX — the user is
# editing an old message in an existing thread and just wants
# it to work. We instead skip the checkpoint rewind (the
# stream falls back to the latest state) and skip the revert
# pass (no chat_turn_id available to walk). Deletion still
# uses ``created_at``, so the messages-after-cursor slice is
# correct on both legacy and post-136 rows.
# ---------------------------------------------------------------
from_message_turn_id: str | None = None
from_message_created_at: datetime | None = None
legacy_from_message: bool = False
if request.from_message_id is not None:
from_msg_row = await session.execute(
select(NewChatMessage).filter(
NewChatMessage.id == request.from_message_id,
NewChatMessage.thread_id == thread_id,
)
)
from_msg = from_msg_row.scalars().first()
if from_msg is None:
raise HTTPException(
status_code=404,
detail="from_message_id not found in this thread.",
)
from_message_created_at = from_msg.created_at
if not from_msg.turn_id:
# Legacy row — surface the degradation in logs but let
# the request proceed with the slice-based delete and a
# cold-start checkpoint.
legacy_from_message = True
_logger.warning(
"[regenerate] from_message_id=%s on thread=%s has no "
"turn_id (legacy row pre-migration-136). Falling back "
"to slice-based delete without checkpoint rewind. "
"revert_actions=%s will be ignored.",
request.from_message_id,
thread_id,
request.revert_actions,
)
else:
from_message_turn_id = from_msg.turn_id
# Walk oldest-to-newest and pick the LAST checkpoint whose
# ``turn_id`` differs from the edited turn — that's the state
# immediately before this turn started running. We read from
# ``metadata`` (the durable surface) rather than
# ``config["configurable"]`` so the lookup works across
# checkpointer implementations.
target_checkpoint_id = _find_pre_turn_checkpoint_id(
checkpoint_tuples,
turn_id=from_message_turn_id,
)
if target_checkpoint_id is None and len(checkpoint_tuples) > 0:
# Fall back to the oldest checkpoint — better than
# 400ing when the agent didn't checkpoint pre-turn
# (e.g. very first turn of the thread).
target_checkpoint_id = checkpoint_tuples[-1].config["configurable"][
"checkpoint_id"
]
# Look through checkpoints to find the right one
2026-04-29 07:20:31 -07:00
# We want to find the checkpoint just before the last HumanMessage.
# We enter this branch when:
# * the client did NOT pin ``from_message_id`` (legacy reload/edit), OR
# * the client pinned ``from_message_id`` but the row is a
# legacy pre-migration-136 row with no ``turn_id`` (we
# downgraded to the same heuristic as a regular reload).
# We DO skip it when a real turn_id pinned ``target_checkpoint_id``
# — that's the C1 happy path and the heuristic below would just
# re-derive a worse target.
if request.from_message_id is None or legacy_from_message:
for i, cp_tuple in enumerate(checkpoint_tuples):
# Access the checkpoint's channel_values which contains "messages"
checkpoint_data = cp_tuple.checkpoint
channel_values = checkpoint_data.get("channel_values", {})
state_messages = channel_values.get("messages", [])
if state_messages:
last_msg = state_messages[-1]
# Find a checkpoint where the last message is NOT a HumanMessage
# This means we're at a state before the user's last message
if not isinstance(last_msg, HumanMessage):
# If no new user_query provided (reload), extract from a later checkpoint
if user_query_to_use is None and i > 0:
# Get the user query from a more recent checkpoint
for prev_cp_tuple in checkpoint_tuples[:i]:
prev_checkpoint_data = prev_cp_tuple.checkpoint
prev_channel_values = prev_checkpoint_data.get(
"channel_values", {}
)
prev_messages = prev_channel_values.get("messages", [])
for msg in reversed(prev_messages):
if isinstance(msg, HumanMessage):
q, imgs = split_langchain_human_content(
msg.content
)
user_query_to_use = q
regenerate_image_urls = imgs
break
if user_query_to_use is not None and (
str(user_query_to_use).strip()
or regenerate_image_urls
):
break
2026-04-29 07:20:31 -07:00
target_checkpoint_id = cp_tuple.config["configurable"][
"checkpoint_id"
]
break
# If we couldn't find a good checkpoint, try alternative approaches
if target_checkpoint_id is None and checkpoint_tuples:
if len(checkpoint_tuples) == 1:
# Only one checkpoint - get the user query from it if not provided
if user_query_to_use is None:
checkpoint_data = checkpoint_tuples[0].checkpoint
channel_values = checkpoint_data.get("channel_values", {})
state_messages = channel_values.get("messages", [])
for msg in state_messages:
if isinstance(msg, HumanMessage):
q, imgs = split_langchain_human_content(msg.content)
user_query_to_use = q
regenerate_image_urls = imgs
break
else:
# Use the oldest checkpoint
target_checkpoint_id = checkpoint_tuples[-1].config["configurable"][
"checkpoint_id"
]
# If we still don't have a user query, get it from the database
if user_query_to_use is None:
# Get the last user message from the database
last_user_msg_result = await session.execute(
select(NewChatMessage)
.filter(
NewChatMessage.thread_id == thread_id,
NewChatMessage.role == NewChatMessageRole.USER,
)
.order_by(NewChatMessage.created_at.desc())
.limit(1)
)
last_user_msg = last_user_msg_result.scalars().first()
if last_user_msg:
content = last_user_msg.content
if isinstance(content, str):
user_query_to_use = content
elif isinstance(content, list):
plain, imgs = split_persisted_user_content_parts(content)
user_query_to_use = plain
regenerate_image_urls = imgs
if isinstance(user_query_to_use, list):
user_query_to_use, regenerate_image_urls = split_langchain_human_content(
user_query_to_use
)
if request.user_images is not None:
regenerate_image_urls = [p.as_data_url() for p in request.user_images]
if user_query_to_use is None:
raise HTTPException(
status_code=400,
detail="Could not determine user query for regeneration. Please provide a user_query.",
)
if not str(user_query_to_use).strip() and not regenerate_image_urls:
raise HTTPException(
status_code=400,
detail="Could not determine user query for regeneration. Please provide a user_query.",
)
2026-04-29 07:20:31 -07:00
# Get the messages to delete AFTER streaming succeeds.
# This prevents data loss if streaming fails.
#
# When ``from_message_id`` is set we slice from that message
# forward (using ``created_at`` so we also catch any tool/system
# messages persisted into the same turn). Otherwise
# we keep the legacy "last 2 messages" rewind.
if request.from_message_id is not None and from_message_created_at is not None:
last_messages_result = await session.execute(
select(NewChatMessage)
.filter(
NewChatMessage.thread_id == thread_id,
NewChatMessage.created_at >= from_message_created_at,
)
.order_by(NewChatMessage.created_at.desc())
)
else:
last_messages_result = await session.execute(
select(NewChatMessage)
.filter(NewChatMessage.thread_id == thread_id)
.order_by(NewChatMessage.created_at.desc())
.limit(2)
)
messages_to_delete = list(last_messages_result.scalars().all())
2026-02-01 21:17:24 -08:00
message_ids_to_delete = [msg.id for msg in messages_to_delete]
2026-04-29 07:20:31 -07:00
# When revert_actions is requested, collect the set of
# ``chat_turn_id``s present in the slice we're about to delete.
# Each one will be reverted (best-effort) BEFORE the regenerate
# stream begins. Legacy rows have ``turn_id=None`` and silently
# contribute nothing — we already logged the degradation above.
revert_turn_ids: list[str] = []
if (
request.revert_actions
and request.from_message_id is not None
and not legacy_from_message
):
seen_turns: set[str] = set()
for msg in messages_to_delete:
tid = msg.turn_id
if tid and tid not in seen_turns:
seen_turns.add(tid)
revert_turn_ids.append(tid)
# Get search space for LLM config
search_space_result = await session.execute(
select(SearchSpace).filter(SearchSpace.id == request.search_space_id)
)
search_space = search_space_result.scalars().first()
if not search_space:
raise HTTPException(status_code=404, detail="Search space not found")
llm_config_id = (
search_space.agent_llm_id if search_space.agent_llm_id is not None else -1
)
# Release the read-transaction so we don't hold ACCESS SHARE locks
# on searchspaces/documents for the entire duration of the stream.
# expire_on_commit=False keeps loaded ORM attrs (including messages_to_delete PKs) usable.
await session.commit()
await session.close()
# Create a wrapper generator that deletes messages only AFTER streaming succeeds
# This prevents data loss if streaming fails (network error, LLM error, etc.)
async def stream_with_cleanup():
    """Stream the regenerated turn, then delete the superseded messages.

    Order of operations:

    1. If ``revert_turn_ids`` is non-empty, revert those turns and emit the
       per-action results as one ``data-revert-results`` SSE event.
    2. Relay every chunk produced by ``stream_new_chat`` unchanged.
    3. Only after the stream has been fully consumed, delete the messages
       listed in ``message_ids_to_delete`` and the snapshots affected by
       them. Cleanup errors are logged and swallowed.

    If the stream raises or the generator is closed early (e.g. the client
    disconnects), step 3 is never reached and the old messages are kept, so
    a failed regeneration cannot lose data.
    """
    # Best-effort revert pass BEFORE the regenerate stream begins.
    # Each turn is reverted independently (per-row SAVEPOINTs
    # inside the route helper) and the per-action results are surfaced
    # on a single ``data-revert-results`` SSE event so the frontend
    # can render any failed rows alongside the new turn. Failures here
    # do NOT abort the regeneration — partial rollback is documented
    # behaviour.
    if revert_turn_ids:
        revert_results = await _revert_turns_for_regenerate(
            thread_id=thread_id,
            chat_turn_ids=revert_turn_ids,
            requester_user_id=str(user.id),
        )
        envelope = {
            "type": "data-revert-results",
            "data": revert_results,
        }
        yield f"data: {json.dumps(envelope, default=str)}\n\n".encode()

    mentioned_documents_payload = (
        [doc.model_dump() for doc in request.mentioned_documents]
        if request.mentioned_documents
        else None
    )

    async for chunk in stream_new_chat(
        user_query=str(user_query_to_use),
        search_space_id=request.search_space_id,
        chat_id=thread_id,
        user_id=str(user.id),
        llm_config_id=llm_config_id,
        mentioned_document_ids=request.mentioned_document_ids,
        mentioned_surfsense_doc_ids=request.mentioned_surfsense_doc_ids,
        mentioned_documents=mentioned_documents_payload,
        checkpoint_id=target_checkpoint_id,
        needs_history_bootstrap=thread.needs_history_bootstrap,
        thread_visibility=thread.visibility,
        current_user_display_name=user.display_name or "A team member",
        disabled_tools=request.disabled_tools,
        filesystem_selection=filesystem_selection,
        request_id=getattr(http_request.state, "request_id", "unknown"),
        user_image_data_urls=regenerate_image_urls or None,
        flow="regenerate",
    ):
        yield chunk

    # Execution only gets here when the stream above was exhausted without
    # an exception: an error, a cancellation, or an early close of this
    # generator all unwind before this point, so the old messages survive.
    if not message_ids_to_delete:
        return

    try:
        # Uses a fresh session since stream_new_chat manages its own.
        async with shielded_async_session() as cleanup_session:
            # Load every stale row in one round trip instead of one SELECT
            # per id. Rows are still deleted through the ORM so that any
            # relationship cascades configured on the model keep applying.
            stale_result = await cleanup_session.execute(
                select(NewChatMessage).filter(
                    NewChatMessage.id.in_(message_ids_to_delete)
                )
            )
            for stale_msg in stale_result.scalars().all():
                await cleanup_session.delete(stale_msg)
            await cleanup_session.commit()

            from app.services.public_chat_service import (
                delete_affected_snapshots,
            )

            await delete_affected_snapshots(
                cleanup_session, thread_id, message_ids_to_delete
            )
    except Exception as cleanup_error:
        # Cleanup is best-effort: the new turn has already been streamed,
        # so a failure here is logged rather than raised.
        _logger.warning(
            "[regenerate] Failed to delete old messages: %s",
            cleanup_error,
        )
# Return streaming response with checkpoint_id for rewinding
return StreamingResponse(
stream_with_cleanup(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
},
)
except HTTPException:
raise
except Exception as e:
import traceback
traceback.print_exc()
raise HTTPException(
status_code=500,
detail=f"An unexpected error occurred during regeneration: {e!s}",
) from None
# =============================================================================
# Resume Interrupted Chat Endpoint
# =============================================================================
@router.post("/threads/{thread_id}/resume")
async def resume_chat(
    thread_id: int,
    request: ResumeRequest,
    http_request: Request,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """Resume an interrupted chat turn and stream the continuation as SSE.

    Steps:

    1. Load the thread and verify the caller holds ``CHATS_CREATE`` on the
       thread's search space and has access to the thread itself.
    2. Run the busy check for the thread and resolve the filesystem
       selection from the request.
    3. Load the search space named in the request to pick the agent LLM
       config (``-1`` when the search space has none configured).
    4. Commit and close the request session, then return a
       ``StreamingResponse`` wrapping ``stream_resume_chat``.

    Args:
        thread_id: Id of the thread whose turn is resumed.
        request: Resume payload (decisions, search space id, filesystem
            settings, disabled tools).
        http_request: Incoming request; only ``state.request_id`` is read.
        session: Request-scoped database session.
        user: The authenticated user.

    Returns:
        A ``text/event-stream`` ``StreamingResponse``.

    Raises:
        HTTPException: 404 when the thread or the search space does not
            exist; 500 for any unexpected error. ``HTTPException``s raised
            by the permission/access/busy helpers are re-raised unchanged.
    """
    try:
        result = await session.execute(
            select(NewChatThread).filter(NewChatThread.id == thread_id)
        )
        thread = result.scalars().first()
        if not thread:
            raise HTTPException(status_code=404, detail="Thread not found")

        await check_permission(
            session,
            user,
            thread.search_space_id,
            Permission.CHATS_CREATE.value,
            "You don't have permission to chat in this search space",
        )
        await check_thread_access(session, thread, user)
        _raise_if_thread_busy_for_start(thread_id)

        filesystem_selection = _resolve_filesystem_selection(
            mode=request.filesystem_mode,
            client_platform=request.client_platform,
            local_mounts=request.local_filesystem_mounts,
        )

        search_space_result = await session.execute(
            select(SearchSpace).filter(SearchSpace.id == request.search_space_id)
        )
        search_space = search_space_result.scalars().first()
        if not search_space:
            raise HTTPException(status_code=404, detail="Search space not found")

        # The permission check above covers the thread's own search space,
        # but the stream below runs against the client-supplied
        # ``request.search_space_id``. When the two differ, authorize the
        # requested space as well so a caller cannot point a thread they
        # own at a search space they have no rights in.
        if request.search_space_id != thread.search_space_id:
            await check_permission(
                session,
                user,
                request.search_space_id,
                Permission.CHATS_CREATE.value,
                "You don't have permission to chat in this search space",
            )

        llm_config_id = (
            search_space.agent_llm_id if search_space.agent_llm_id is not None else -1
        )
        decisions = [d.model_dump() for d in request.decisions]

        # Read ORM attributes before the session is closed so the stream
        # never has to touch a detached instance.
        thread_visibility = thread.visibility

        # Release the read-transaction so we don't hold ACCESS SHARE locks
        # on searchspaces/documents for the entire duration of the stream.
        await session.commit()
        await session.close()

        return StreamingResponse(
            stream_resume_chat(
                chat_id=thread_id,
                search_space_id=request.search_space_id,
                decisions=decisions,
                user_id=str(user.id),
                llm_config_id=llm_config_id,
                thread_visibility=thread_visibility,
                filesystem_selection=filesystem_selection,
                request_id=getattr(http_request.state, "request_id", "unknown"),
                disabled_tools=request.disabled_tools,
            ),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "X-Accel-Buffering": "no",
            },
        )
    except HTTPException:
        raise
    except Exception as e:
        # Record the traceback through the logging system (instead of
        # printing to stderr) so it carries level/handler configuration.
        logging.getLogger(__name__).exception(
            "Unexpected error during resume for thread %s", thread_id
        )
        raise HTTPException(
            status_code=500,
            detail=f"An unexpected error occurred during resume: {e!s}",
        ) from None