SurfSense/surfsense_backend/app/routes/new_chat_routes.py
DESKTOP-RTLN3BA\$punk aabc24f82c feat: enhance performance logging and caching in various components
- Introduced slow callback logging in FastAPI to identify blocking calls.
- Added performance logging for agent creation and tool loading processes.
- Implemented caching for MCP tools to reduce redundant server calls.
- Enhanced sandbox management with in-process caching for improved efficiency.
- Refactored several functions for better readability and performance tracking.
- Updated tests to ensure proper functionality of new features and optimizations.
2026-02-26 13:00:31 -08:00

1471 lines
50 KiB
Python

"""
Routes for the new chat feature with assistant-ui integration.
These endpoints support the ThreadHistoryAdapter pattern from assistant-ui:
- GET /threads - List threads for sidebar (ThreadListPrimitive)
- POST /threads - Create a new thread
- GET /threads/{thread_id} - Get thread with messages (load)
- PUT /threads/{thread_id} - Update thread (rename, archive)
- DELETE /threads/{thread_id} - Delete thread
- POST /threads/{thread_id}/messages - Append message
"""
import asyncio
import logging
from datetime import UTC, datetime
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import StreamingResponse
from sqlalchemy import func, or_
from sqlalchemy.exc import IntegrityError, OperationalError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload
from app.db import (
ChatComment,
ChatVisibility,
NewChatMessage,
NewChatMessageRole,
NewChatThread,
Permission,
SearchSpace,
User,
get_async_session,
)
from app.schemas.new_chat import (
NewChatMessageAppend,
NewChatMessageRead,
NewChatRequest,
NewChatThreadCreate,
NewChatThreadRead,
NewChatThreadUpdate,
NewChatThreadVisibilityUpdate,
NewChatThreadWithMessages,
PublicChatSnapshotCreateResponse,
PublicChatSnapshotListResponse,
RegenerateRequest,
ResumeRequest,
ThreadHistoryLoadResponse,
ThreadListItem,
ThreadListResponse,
)
from app.tasks.chat.stream_new_chat import stream_new_chat, stream_resume_chat
from app.users import current_active_user
from app.utils.rbac import check_permission
# Module-level logger, named after this module.
_logger = logging.getLogger(__name__)
# Tasks scheduled by _try_delete_sandbox are stored here while they run;
# each task removes itself again through its done-callback.
_background_tasks: set[asyncio.Task] = set()
# Router on which every endpoint in this module is registered.
router = APIRouter()
def _try_delete_sandbox(thread_id: int) -> None:
    """
    Schedule best-effort sandbox cleanup for a thread without blocking the caller.

    When sandboxes are enabled, a background task is created on the running
    event loop that first deletes the sandbox and then removes the local
    sandbox files. Each step is attempted independently; a failure is logged
    as a warning and never propagated.

    If there is no running event loop, the cleanup cannot be scheduled; this
    is logged and the function returns without raising.

    Args:
        thread_id: ID of the thread whose sandbox resources should be removed.
    """
    # Imported lazily, as in the rest of this module's optional dependencies.
    from app.agents.new_chat.sandbox import (
        delete_local_sandbox_files,
        delete_sandbox,
        is_sandbox_enabled,
    )

    if not is_sandbox_enabled():
        return

    async def _bg() -> None:
        # The two steps are isolated so a failed remote delete does not
        # prevent the local file cleanup from running.
        try:
            await delete_sandbox(thread_id)
        except Exception:
            _logger.warning(
                "Background sandbox delete failed for thread %s",
                thread_id,
                exc_info=True,
            )
        try:
            delete_local_sandbox_files(thread_id)
        except Exception:
            _logger.warning(
                "Local sandbox file cleanup failed for thread %s",
                thread_id,
                exc_info=True,
            )

    # Only get_running_loop() is expected to raise RuntimeError here, so the
    # try block is limited to that call and the skip is made visible in logs.
    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        _logger.warning(
            "No running event loop; sandbox cleanup skipped for thread %s",
            thread_id,
        )
        return

    task = loop.create_task(_bg())
    # Keep a reference until the task completes, then drop it.
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)
async def check_thread_access(
    session: AsyncSession,
    thread: NewChatThread,
    user: User,
    require_ownership: bool = False,
) -> bool:
    """
    Decide whether ``user`` may act on ``thread``; raise 403 when they may not.

    Rules, evaluated in this order:
    1. ``require_ownership=True``: only the thread's creator passes.
    2. Threads with SEARCH_SPACE visibility pass for everyone.
    3. Legacy threads (``created_by_id`` is NULL) pass only for the owner of
       the thread's search space.
    4. Otherwise the thread is private and only its creator passes.

    Args:
        session: Database session used to look up the search space for
            legacy threads.
        thread: The thread being accessed.
        user: The user requesting access.
        require_ownership: When True, restrict the action to the creator
            regardless of visibility.

    Returns:
        True when access is granted.

    Raises:
        HTTPException: 403 when access is denied.
    """
    user_is_creator = thread.created_by_id == user.id

    # Creator-only operations short-circuit every visibility rule below.
    if require_ownership:
        if user_is_creator:
            return True
        raise HTTPException(
            status_code=403,
            detail="Only the creator of this chat can perform this action",
        )

    # Shared threads are open for read/update operations.
    if thread.visibility == ChatVisibility.SEARCH_SPACE:
        return True

    # Threads without a recorded creator fall back to search-space ownership.
    if thread.created_by_id is None:
        space_lookup = await session.execute(
            select(SearchSpace).filter(SearchSpace.id == thread.search_space_id)
        )
        space = space_lookup.scalar_one_or_none()
        if space and space.user_id == user.id:
            return True
        raise HTTPException(
            status_code=403,
            detail="You don't have access to this chat",
        )

    # Remaining case: a private thread with a known creator.
    if not user_is_creator:
        raise HTTPException(
            status_code=403,
            detail="You don't have access to this private chat",
        )
    return True
# =============================================================================
# Thread Endpoints
# =============================================================================
@router.get("/threads", response_model=ThreadListResponse)
async def list_threads(
    search_space_id: int,
    limit: int | None = None,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """
    Return the threads of a search space that the current user may see.

    The response separates active threads from archived ones. A thread is
    included when it was created by the user, when its visibility is
    SEARCH_SPACE, or, for the search space owner only, when it is a legacy
    thread without a creator. Results are ordered by ``updated_at`` descending.

    Args:
        search_space_id: The search space whose threads are listed.
        limit: Optional cap on the number of active threads returned; values
            that are None or not positive leave the list untouched. Archived
            threads are never truncated.

    Raises:
        HTTPException: 503 on a database operational error, 500 on any other
            unexpected error; HTTPExceptions raised inside are passed through.

    Requires CHATS_READ permission.
    """
    try:
        await check_permission(
            session,
            user,
            search_space_id,
            Permission.CHATS_READ.value,
            "You don't have permission to read chats in this search space",
        )

        # Legacy threads are only exposed to the owner of the search space.
        space_row = (
            await session.execute(
                select(SearchSpace).filter(SearchSpace.id == search_space_id)
            )
        ).scalar_one_or_none()
        owns_space = space_row and space_row.user_id == user.id

        # Any one of these conditions makes a thread visible to the user.
        visible_to_user = [
            NewChatThread.created_by_id == user.id,
            NewChatThread.visibility == ChatVisibility.SEARCH_SPACE,
        ]
        if owns_space:
            visible_to_user.append(NewChatThread.created_by_id.is_(None))

        rows = (
            (
                await session.execute(
                    select(NewChatThread)
                    .filter(
                        NewChatThread.search_space_id == search_space_id,
                        or_(*visible_to_user),
                    )
                    .order_by(NewChatThread.updated_at.desc())
                )
            )
            .scalars()
            .all()
        )

        active_items = []
        archived_items = []
        for row in rows:
            bucket = archived_items if row.archived else active_items
            bucket.append(
                ThreadListItem(
                    id=row.id,
                    title=row.title,
                    archived=row.archived,
                    visibility=row.visibility,
                    created_by_id=row.created_by_id,
                    # Creator-less threads count as the owner's own threads.
                    is_own_thread=row.created_by_id == user.id
                    or (row.created_by_id is None and owns_space),
                    created_at=row.created_at,
                    updated_at=row.updated_at,
                )
            )

        # The limit only trims the active list.
        if limit is not None and limit > 0:
            active_items = active_items[:limit]

        return ThreadListResponse(
            threads=active_items, archived_threads=archived_items
        )
    except HTTPException:
        raise
    except OperationalError:
        raise HTTPException(
            status_code=503, detail="Database operation failed. Please try again later."
        ) from None
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"An unexpected error occurred while fetching threads: {e!s}",
        ) from None
@router.get("/threads/search", response_model=list[ThreadListItem])
async def search_threads(
    search_space_id: int,
    title: str,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """
    Search accessible threads by title in a search space.

    A user can search threads that are:
    - Created by them (regardless of visibility)
    - Shared with the search space (visibility = SEARCH_SPACE)
    - Legacy threads with no creator (created_by_id is NULL) - only if user is
      search space owner

    The query is matched literally: LIKE wildcards (``%``, ``_``) and the
    escape character in ``title`` are escaped before the pattern is built.

    Args:
        search_space_id: The search space to search in
        title: The search query (case-insensitive partial match)

    Returns:
        Matching threads ordered by ``updated_at`` descending.

    Raises:
        HTTPException: 503 on a database operational error, 500 on any other
            unexpected error; HTTPExceptions raised inside are passed through.

    Requires CHATS_READ permission.
    """
    try:
        await check_permission(
            session,
            user,
            search_space_id,
            Permission.CHATS_READ.value,
            "You don't have permission to read chats in this search space",
        )

        # Check if user is the search space owner (for legacy thread visibility)
        search_space_query = select(SearchSpace).filter(
            SearchSpace.id == search_space_id
        )
        search_space_result = await session.execute(search_space_query)
        search_space = search_space_result.scalar_one_or_none()
        is_search_space_owner = search_space and search_space.user_id == user.id

        # Build filter conditions
        filter_conditions = [
            NewChatThread.created_by_id == user.id,
            NewChatThread.visibility == ChatVisibility.SEARCH_SPACE,
        ]
        # Only include legacy threads for the search space owner
        if is_search_space_owner:
            filter_conditions.append(NewChatThread.created_by_id.is_(None))

        # Escape the escape character first, then the LIKE wildcards, so that
        # user input such as "100%" or "a_b" is treated as plain text.
        escaped_title = (
            title.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
        )

        # Search accessible threads by title (case-insensitive)
        query = (
            select(NewChatThread)
            .filter(
                NewChatThread.search_space_id == search_space_id,
                NewChatThread.title.ilike(f"%{escaped_title}%", escape="\\"),
                or_(*filter_conditions),
            )
            .order_by(NewChatThread.updated_at.desc())
        )
        result = await session.execute(query)
        threads = result.scalars().all()

        return [
            ThreadListItem(
                id=thread.id,
                title=thread.title,
                archived=thread.archived,
                visibility=thread.visibility,
                created_by_id=thread.created_by_id,
                # Legacy threads (no creator) are treated as own threads for owner
                is_own_thread=(
                    thread.created_by_id == user.id
                    or (thread.created_by_id is None and is_search_space_owner)
                ),
                created_at=thread.created_at,
                updated_at=thread.updated_at,
            )
            for thread in threads
        ]
    except HTTPException:
        raise
    except OperationalError:
        raise HTTPException(
            status_code=503, detail="Database operation failed. Please try again later."
        ) from None
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"An unexpected error occurred while searching threads: {e!s}",
        ) from None
@router.post("/threads", response_model=NewChatThreadRead)
async def create_thread(
    thread: NewChatThreadCreate,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """
    Persist a new chat thread owned by the current user.

    Title, archived flag, visibility and search space come from the request
    payload; the creator is set to the current user and ``updated_at`` to the
    current UTC time.

    Returns:
        The freshly stored thread, refreshed from the database.

    Raises:
        HTTPException: 400 on an integrity error, 503 on a database
            operational error, 500 on any other unexpected error. The session
            is rolled back in each of these cases.

    Requires CHATS_CREATE permission.
    """
    try:
        await check_permission(
            session,
            user,
            thread.search_space_id,
            Permission.CHATS_CREATE.value,
            "You don't have permission to create chats in this search space",
        )

        new_row = NewChatThread(
            title=thread.title,
            archived=thread.archived,
            visibility=thread.visibility,
            search_space_id=thread.search_space_id,
            created_by_id=user.id,
            updated_at=datetime.now(UTC),
        )
        session.add(new_row)
        await session.commit()
        # Reload so database-generated fields are populated on the response.
        await session.refresh(new_row)
        return new_row
    except HTTPException:
        raise
    except IntegrityError:
        await session.rollback()
        raise HTTPException(
            status_code=400,
            detail="Database constraint violation. Please check your input data.",
        ) from None
    except OperationalError:
        await session.rollback()
        raise HTTPException(
            status_code=503, detail="Database operation failed. Please try again later."
        ) from None
    except Exception as e:
        await session.rollback()
        raise HTTPException(
            status_code=500,
            detail=f"An unexpected error occurred while creating the thread: {e!s}",
        ) from None
@router.get("/threads/{thread_id}", response_model=ThreadHistoryLoadResponse)
async def get_thread_messages(
    thread_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """
    Load every message of a thread, including author display data.

    Used by ThreadHistoryAdapter.load() to restore a conversation. Messages
    are returned in ``created_at`` order.

    Access is checked in two steps: CHATS_READ permission on the thread's
    search space, then thread-level access via ``check_thread_access``.

    Raises:
        HTTPException: 404 when the thread does not exist, 503 on a database
            operational error, 500 on any other unexpected error.

    Requires CHATS_READ permission.
    """
    try:
        thread_lookup = await session.execute(
            select(NewChatThread).filter(NewChatThread.id == thread_id)
        )
        thread = thread_lookup.scalars().first()
        if not thread:
            raise HTTPException(status_code=404, detail="Thread not found")

        # Search-space level permission first, then thread-level visibility.
        await check_permission(
            session,
            user,
            thread.search_space_id,
            Permission.CHATS_READ.value,
            "You don't have permission to read chats in this search space",
        )
        await check_thread_access(session, thread, user)

        # Authors are eager-loaded so they can be read without extra queries.
        rows = (
            (
                await session.execute(
                    select(NewChatMessage)
                    .options(selectinload(NewChatMessage.author))
                    .filter(NewChatMessage.thread_id == thread_id)
                    .order_by(NewChatMessage.created_at)
                )
            )
            .scalars()
            .all()
        )

        history = []
        for row in rows:
            author = row.author
            history.append(
                NewChatMessageRead(
                    id=row.id,
                    thread_id=row.thread_id,
                    role=row.role,
                    content=row.content,
                    created_at=row.created_at,
                    author_id=row.author_id,
                    author_display_name=author.display_name if author else None,
                    author_avatar_url=author.avatar_url if author else None,
                )
            )
        return ThreadHistoryLoadResponse(messages=history)
    except HTTPException:
        raise
    except OperationalError:
        raise HTTPException(
            status_code=503, detail="Database operation failed. Please try again later."
        ) from None
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"An unexpected error occurred while fetching the thread: {e!s}",
        ) from None
@router.get("/threads/{thread_id}/full", response_model=NewChatThreadWithMessages)
async def get_thread_full(
    thread_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """
    Return a thread's attributes together with its messages.

    The response additionally carries ``has_comments``, which is true when at
    least one comment exists on any message of the thread.

    Access is checked in two steps: CHATS_READ permission on the thread's
    search space, then thread-level access via ``check_thread_access``.

    Raises:
        HTTPException: 404 when the thread does not exist, 503 on a database
            operational error, 500 on any other unexpected error.

    Requires CHATS_READ permission.
    """
    try:
        thread_lookup = await session.execute(
            select(NewChatThread)
            .options(selectinload(NewChatThread.messages))
            .filter(NewChatThread.id == thread_id)
        )
        thread = thread_lookup.scalars().first()
        if not thread:
            raise HTTPException(status_code=404, detail="Thread not found")

        await check_permission(
            session,
            user,
            thread.search_space_id,
            Permission.CHATS_READ.value,
            "You don't have permission to read chats in this search space",
        )
        await check_thread_access(session, thread, user)

        # Count comments attached to any message of this thread.
        comments_stmt = (
            select(func.count())
            .select_from(ChatComment)
            .join(NewChatMessage, ChatComment.message_id == NewChatMessage.id)
            .where(NewChatMessage.thread_id == thread.id)
        )
        comment_count = await session.scalar(comments_stmt)

        # Start from the ORM instance's attribute dict and add the extras.
        payload = dict(thread.__dict__)
        payload["messages"] = thread.messages
        payload["has_comments"] = (comment_count or 0) > 0
        return payload
    except HTTPException:
        raise
    except OperationalError:
        raise HTTPException(
            status_code=503, detail="Database operation failed. Please try again later."
        ) from None
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"An unexpected error occurred while fetching the thread: {e!s}",
        ) from None
@router.put("/threads/{thread_id}", response_model=NewChatThreadRead)
async def update_thread(
    thread_id: int,
    thread_update: NewChatThreadUpdate,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """
    Update a thread (title, archived status).

    Used for renaming and archiving threads.
    - PRIVATE threads: Only the creator can update
    - PRIVATE legacy threads (created_by_id is NULL): Only the search space
      owner can update, matching the rule applied when deleting threads
    - SEARCH_SPACE threads: Any member with CHATS_UPDATE permission can update

    Only the fields explicitly set in the request are written; ``updated_at``
    is always refreshed.

    Raises:
        HTTPException: 404 when the thread does not exist, 403 when access is
            denied, 400 on an integrity error, 503 on a database operational
            error, 500 on any other unexpected error.

    Requires CHATS_UPDATE permission.
    """
    try:
        result = await session.execute(
            select(NewChatThread).filter(NewChatThread.id == thread_id)
        )
        db_thread = result.scalars().first()
        if not db_thread:
            raise HTTPException(status_code=404, detail="Thread not found")

        await check_permission(
            session,
            user,
            db_thread.search_space_id,
            Permission.CHATS_UPDATE.value,
            "You don't have permission to update chats in this search space",
        )

        # For PRIVATE threads, only the creator can update.
        # For SEARCH_SPACE threads, any member with permission can update.
        # Legacy threads have no recorded creator, so strict ownership would
        # lock everyone out; skip it and fall through to the legacy handling
        # in check_thread_access, which admits the search space owner only.
        if db_thread.visibility == ChatVisibility.PRIVATE:
            await check_thread_access(
                session,
                db_thread,
                user,
                require_ownership=(db_thread.created_by_id is not None),
            )

        # Update only the fields the client actually sent
        update_data = thread_update.model_dump(exclude_unset=True)
        for key, value in update_data.items():
            setattr(db_thread, key, value)
        db_thread.updated_at = datetime.now(UTC)

        await session.commit()
        await session.refresh(db_thread)
        return db_thread
    except HTTPException:
        raise
    except IntegrityError:
        await session.rollback()
        raise HTTPException(
            status_code=400,
            detail="Database constraint violation. Please check your input data.",
        ) from None
    except OperationalError:
        await session.rollback()
        raise HTTPException(
            status_code=503, detail="Database operation failed. Please try again later."
        ) from None
    except Exception as e:
        await session.rollback()
        raise HTTPException(
            status_code=500,
            detail=f"An unexpected error occurred while updating the thread: {e!s}",
        ) from None
@router.delete("/threads/{thread_id}", response_model=dict)
async def delete_thread(
    thread_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """
    Remove a thread and schedule cleanup of its sandbox.

    - PRIVATE threads with a creator: only the creator can delete
    - PRIVATE legacy threads (no creator): only the search space owner can delete
    - SEARCH_SPACE threads: any member with CHATS_DELETE permission can delete

    Sandbox cleanup is started only after the delete has been committed.

    Raises:
        HTTPException: 404 when the thread does not exist, 403 when access is
            denied, 400 on an integrity error, 503 on a database operational
            error, 500 on any other unexpected error.

    Requires CHATS_DELETE permission.
    """
    try:
        lookup = await session.execute(
            select(NewChatThread).filter(NewChatThread.id == thread_id)
        )
        doomed = lookup.scalars().first()
        if not doomed:
            raise HTTPException(status_code=404, detail="Thread not found")

        await check_permission(
            session,
            user,
            doomed.search_space_id,
            Permission.CHATS_DELETE.value,
            "You don't have permission to delete chats in this search space",
        )

        if doomed.visibility == ChatVisibility.PRIVATE:
            # Strict ownership only makes sense when a creator was recorded;
            # creator-less legacy threads use the legacy rule instead, which
            # lets the search space owner through.
            has_creator = doomed.created_by_id is not None
            await check_thread_access(
                session,
                doomed,
                user,
                require_ownership=has_creator,
            )

        await session.delete(doomed)
        await session.commit()

        # Fire-and-forget: does not delay the response.
        _try_delete_sandbox(thread_id)
        return {"message": "Thread deleted successfully"}
    except HTTPException:
        raise
    except IntegrityError:
        await session.rollback()
        raise HTTPException(
            status_code=400, detail="Cannot delete thread due to existing dependencies."
        ) from None
    except OperationalError:
        await session.rollback()
        raise HTTPException(
            status_code=503, detail="Database operation failed. Please try again later."
        ) from None
    except Exception as e:
        await session.rollback()
        raise HTTPException(
            status_code=500,
            detail=f"An unexpected error occurred while deleting the thread: {e!s}",
        ) from None
@router.patch("/threads/{thread_id}/visibility", response_model=NewChatThreadRead)
async def update_thread_visibility(
    thread_id: int,
    visibility_update: NewChatThreadVisibilityUpdate,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """
    Change who can see a thread.

    Only the thread's creator may change visibility; this is enforced with
    ``check_thread_access(..., require_ownership=True)``.
    - PRIVATE: Only the creator can access the thread (default)
    - SEARCH_SPACE: All members of the search space can access the thread

    Raises:
        HTTPException: 404 when the thread does not exist, 403 when the user
            is not the creator, 400 on an integrity error, 503 on a database
            operational error, 500 on any other unexpected error.

    Requires CHATS_UPDATE permission.
    """
    try:
        lookup = await session.execute(
            select(NewChatThread).filter(NewChatThread.id == thread_id)
        )
        target = lookup.scalars().first()
        if not target:
            raise HTTPException(status_code=404, detail="Thread not found")

        await check_permission(
            session,
            user,
            target.search_space_id,
            Permission.CHATS_UPDATE.value,
            "You don't have permission to update chats in this search space",
        )
        # Creator-only, independent of the thread's current visibility.
        await check_thread_access(session, target, user, require_ownership=True)

        target.visibility = visibility_update.visibility
        target.updated_at = datetime.now(UTC)

        await session.commit()
        await session.refresh(target)
        return target
    except HTTPException:
        raise
    except IntegrityError:
        await session.rollback()
        raise HTTPException(
            status_code=400,
            detail="Database constraint violation. Please check your input data.",
        ) from None
    except OperationalError:
        await session.rollback()
        raise HTTPException(
            status_code=503, detail="Database operation failed. Please try again later."
        ) from None
    except Exception as e:
        await session.rollback()
        raise HTTPException(
            status_code=500,
            detail=f"An unexpected error occurred while updating thread visibility: {e!s}",
        ) from None
# =============================================================================
# Snapshot Endpoints
# =============================================================================
@router.post(
    "/threads/{thread_id}/snapshots", response_model=PublicChatSnapshotCreateResponse
)
async def create_thread_snapshot(
    thread_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """
    Create a public snapshot of the thread.

    All work, including any access checks and deduplication of unchanged
    content, is delegated to ``public_chat_service.create_snapshot``; this
    route only forwards the session, thread ID and current user.
    """
    from app.services.public_chat_service import create_snapshot

    snapshot_response = await create_snapshot(
        session=session,
        thread_id=thread_id,
        user=user,
    )
    return snapshot_response
@router.get(
    "/threads/{thread_id}/snapshots", response_model=PublicChatSnapshotListResponse
)
async def list_thread_snapshots(
    thread_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """
    List all public snapshots for this thread.

    The lookup, and the rule that only the thread owner may view snapshots,
    is delegated to ``public_chat_service.list_snapshots_for_thread``; the
    result is wrapped in a ``PublicChatSnapshotListResponse``.
    """
    from app.services.public_chat_service import list_snapshots_for_thread

    found = await list_snapshots_for_thread(
        session=session,
        thread_id=thread_id,
        user=user,
    )
    return PublicChatSnapshotListResponse(snapshots=found)
@router.delete("/threads/{thread_id}/snapshots/{snapshot_id}")
async def delete_thread_snapshot(
    thread_id: int,
    snapshot_id: int,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """
    Delete a specific snapshot.

    The deletion, and the rule that only the thread owner may delete
    snapshots, is delegated to ``public_chat_service.delete_snapshot``.

    Returns:
        A confirmation message once the service call has completed.
    """
    from app.services.public_chat_service import delete_snapshot

    service_kwargs = {
        "session": session,
        "thread_id": thread_id,
        "snapshot_id": snapshot_id,
        "user": user,
    }
    await delete_snapshot(**service_kwargs)
    return {"message": "Snapshot deleted successfully"}
# =============================================================================
# Message Endpoints
# =============================================================================
@router.post("/threads/{thread_id}/messages", response_model=NewChatMessageRead)
async def append_message(
    thread_id: int,
    request: Request,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """
    Append a message to a thread.

    This is used by ThreadHistoryAdapter.append() to persist messages.
    The raw JSON body is read directly so that extra fields sent by the
    client are ignored; only ``role`` and ``content`` are used.

    Access is granted if:
    - User is the creator of the thread
    - Thread visibility is SEARCH_SPACE

    Raises:
        HTTPException: 400 when the body is not valid JSON, is not a JSON
            object, lacks ``role``/``content``, fails schema validation or
            names an unknown role; 404 when the thread does not exist; 400 on
            an integrity error; 503 on a database operational error; 500 on
            any other unexpected error.

    Requires CHATS_UPDATE permission.
    """
    try:
        # Parse raw body - extract only role and content, ignoring extra fields.
        # JSON decoding errors are ValueError subclasses; report them as a
        # client error instead of letting them become a 500.
        try:
            raw_body = await request.json()
        except ValueError:
            raise HTTPException(
                status_code=400, detail="Request body must be valid JSON"
            ) from None
        if not isinstance(raw_body, dict):
            raise HTTPException(
                status_code=400, detail="Request body must be a JSON object"
            )

        role = raw_body.get("role")
        content = raw_body.get("content")
        if not role:
            raise HTTPException(status_code=400, detail="Missing required field: role")
        if content is None:
            raise HTTPException(
                status_code=400, detail="Missing required field: content"
            )

        # Create message object manually. Pydantic's ValidationError derives
        # from ValueError, so wrongly typed fields are reported as 400.
        try:
            message = NewChatMessageAppend(role=role, content=content)
        except ValueError:
            raise HTTPException(
                status_code=400, detail="Invalid message payload"
            ) from None

        # Get thread
        result = await session.execute(
            select(NewChatThread).filter(NewChatThread.id == thread_id)
        )
        thread = result.scalars().first()
        if not thread:
            raise HTTPException(status_code=404, detail="Thread not found")

        await check_permission(
            session,
            user,
            thread.search_space_id,
            Permission.CHATS_UPDATE.value,
            "You don't have permission to update chats in this search space",
        )
        # Check thread-level access based on visibility
        await check_thread_access(session, thread, user)

        # Convert string role to enum
        role_str = (
            message.role.lower() if isinstance(message.role, str) else message.role
        )
        try:
            message_role = NewChatMessageRole(role_str)
        except ValueError:
            raise HTTPException(
                status_code=400,
                detail=f"Invalid role: {message.role}. Must be 'user', 'assistant', or 'system'.",
            ) from None

        # Create message
        db_message = NewChatMessage(
            thread_id=thread_id,
            role=message_role,
            content=message.content,
            author_id=user.id,
        )
        session.add(db_message)

        # Update thread's updated_at timestamp
        thread.updated_at = datetime.now(UTC)

        # Note: Title generation now happens in stream_new_chat.py after the first response
        # using LLM to generate a descriptive title (with truncation as fallback)
        await session.commit()
        await session.refresh(db_message)
        return db_message
    except HTTPException:
        raise
    except IntegrityError:
        await session.rollback()
        raise HTTPException(
            status_code=400,
            detail="Database constraint violation. Please check your input data.",
        ) from None
    except OperationalError:
        await session.rollback()
        raise HTTPException(
            status_code=503, detail="Database operation failed. Please try again later."
        ) from None
    except Exception as e:
        await session.rollback()
        raise HTTPException(
            status_code=500,
            detail=f"An unexpected error occurred while appending the message: {e!s}",
        ) from None
@router.get("/threads/{thread_id}/messages", response_model=list[NewChatMessageRead])
async def list_messages(
    thread_id: int,
    skip: int = 0,
    limit: int = 100,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """
    List messages in a thread with pagination.

    Messages are ordered by ``created_at``.

    Access is granted if:
    - User is the creator of the thread
    - Thread visibility is SEARCH_SPACE

    Args:
        thread_id: The thread whose messages are listed.
        skip: Number of messages to skip; must not be negative.
        limit: Maximum number of messages to return; must not be negative.

    Raises:
        HTTPException: 400 when ``skip`` or ``limit`` is negative, 404 when
            the thread does not exist, 503 on a database operational error,
            500 on any other unexpected error.

    Requires CHATS_READ permission.
    """
    try:
        # Reject negative paging values up front: they would otherwise be
        # passed straight into OFFSET/LIMIT and surface as a 500 when the
        # database refuses them.
        if skip < 0:
            raise HTTPException(
                status_code=400, detail="skip must be greater than or equal to 0"
            )
        if limit < 0:
            raise HTTPException(
                status_code=400, detail="limit must be greater than or equal to 0"
            )

        # Verify thread exists and user has access
        result = await session.execute(
            select(NewChatThread).filter(NewChatThread.id == thread_id)
        )
        thread = result.scalars().first()
        if not thread:
            raise HTTPException(status_code=404, detail="Thread not found")

        await check_permission(
            session,
            user,
            thread.search_space_id,
            Permission.CHATS_READ.value,
            "You don't have permission to read chats in this search space",
        )
        # Check thread-level access based on visibility
        await check_thread_access(session, thread, user)

        # Get messages
        query = (
            select(NewChatMessage)
            .filter(NewChatMessage.thread_id == thread_id)
            .order_by(NewChatMessage.created_at)
            .offset(skip)
            .limit(limit)
        )
        result = await session.execute(query)
        return result.scalars().all()
    except HTTPException:
        raise
    except OperationalError:
        raise HTTPException(
            status_code=503, detail="Database operation failed. Please try again later."
        ) from None
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"An unexpected error occurred while fetching messages: {e!s}",
        ) from None
# =============================================================================
# Chat Streaming Endpoint
# =============================================================================
@router.post("/new_chat")
async def handle_new_chat(
    request: NewChatRequest,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """
    Stream chat responses from the deep agent.

    This endpoint handles the new chat functionality with streaming responses
    using Server-Sent Events (SSE) format compatible with Vercel AI SDK.

    Access is granted if:
    - User is the creator of the thread
    - Thread visibility is SEARCH_SPACE

    The ``search_space_id`` in the request must be the search space the thread
    belongs to. Permissions are verified against the thread's search space,
    so a different ID would run the agent against a space that was never
    authorised for this user.

    Raises:
        HTTPException: 404 when the thread or search space does not exist,
            403 when access is denied, 400 when the request's search space
            does not match the thread's, 500 on any other unexpected error.

    Requires CHATS_CREATE permission.
    """
    try:
        # Verify thread exists and user has permission
        result = await session.execute(
            select(NewChatThread).filter(NewChatThread.id == request.chat_id)
        )
        thread = result.scalars().first()
        if not thread:
            raise HTTPException(status_code=404, detail="Thread not found")

        await check_permission(
            session,
            user,
            thread.search_space_id,
            Permission.CHATS_CREATE.value,
            "You don't have permission to chat in this search space",
        )
        # Check thread-level access based on visibility
        await check_thread_access(session, thread, user)

        # The permission checks above cover thread.search_space_id only, while
        # everything below uses request.search_space_id; require them to agree.
        if request.search_space_id != thread.search_space_id:
            raise HTTPException(
                status_code=400,
                detail="Thread does not belong to the specified search space",
            )

        # Get search space to check LLM config preferences
        search_space_result = await session.execute(
            select(SearchSpace).filter(SearchSpace.id == request.search_space_id)
        )
        search_space = search_space_result.scalars().first()
        if not search_space:
            raise HTTPException(status_code=404, detail="Search space not found")

        # Use agent_llm_id from search space for chat operations
        # Positive IDs load from NewLLMConfig database table
        # Negative IDs load from YAML global configs
        # Falls back to -1 (first global config) if not configured
        llm_config_id = (
            search_space.agent_llm_id if search_space.agent_llm_id is not None else -1
        )

        # Release the read-transaction so we don't hold ACCESS SHARE locks
        # on searchspaces/documents for the entire duration of the stream.
        # expire_on_commit=False keeps loaded ORM attrs usable.
        await session.commit()

        return StreamingResponse(
            stream_new_chat(
                user_query=request.user_query,
                search_space_id=request.search_space_id,
                chat_id=request.chat_id,
                session=session,
                user_id=str(user.id),
                llm_config_id=llm_config_id,
                mentioned_document_ids=request.mentioned_document_ids,
                mentioned_surfsense_doc_ids=request.mentioned_surfsense_doc_ids,
                needs_history_bootstrap=thread.needs_history_bootstrap,
                thread_visibility=thread.visibility,
                current_user_display_name=user.display_name or "A team member",
            ),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "X-Accel-Buffering": "no",
            },
        )
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"An unexpected error occurred: {e!s}",
        ) from None
# =============================================================================
# Chat Regeneration Endpoint (Edit/Reload)
# =============================================================================
def _checkpoint_state_messages(cp_tuple):
    """Return the ``messages`` channel stored in a checkpoint tuple, or ``[]``."""
    channel_values = cp_tuple.checkpoint.get("channel_values", {})
    return channel_values.get("messages", [])


def _first_text_from_content(content):
    """Return the first text found in a stored message ``content`` value.

    ``content`` may be a plain string, or a list whose parts are either
    strings or dicts shaped like ``{"type": "text", "text": ...}``.
    Returns ``None`` when no text part is found (or for any other type).
    """
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        for part in content:
            if isinstance(part, dict) and part.get("type") == "text":
                return part.get("text", "")
            if isinstance(part, str):
                return part
    return None


@router.post("/threads/{thread_id}/regenerate")
async def regenerate_response(
    thread_id: int,
    request: RegenerateRequest,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """
    Regenerate the AI response for a chat thread.

    This endpoint supports two operations:
    1. **Edit**: Provide a new `user_query` to replace the last user message and regenerate
    2. **Reload**: Leave `user_query` empty (or None) to regenerate with the same query

    Both operations:
    - Pick a LangGraph checkpoint to restart from and stream a new response from it
    - Delete the two most recent stored messages of the thread, but only after
      the new response has streamed to completion (nothing is deleted if
      streaming fails or is aborted)

    For a reload, the query is recovered from the checkpoints when possible and
    otherwise from the most recent USER message stored in the database.

    Access is granted if:
    - User is the creator of the thread
    - Thread visibility is SEARCH_SPACE

    Requires CHATS_UPDATE permission.

    Raises:
        HTTPException 404: the thread or the requested search space does not exist.
        HTTPException 400: the thread has no checkpoints, or no user query
            could be determined.
        HTTPException 500: any other unexpected error while preparing the stream.
    """
    from langchain_core.messages import HumanMessage

    from app.agents.new_chat.checkpointer import get_checkpointer

    try:
        # Verify thread exists and user has permission
        result = await session.execute(
            select(NewChatThread).filter(NewChatThread.id == thread_id)
        )
        thread = result.scalars().first()
        if not thread:
            raise HTTPException(status_code=404, detail="Thread not found")

        await check_permission(
            session,
            user,
            thread.search_space_id,
            Permission.CHATS_UPDATE.value,
            "You don't have permission to update chats in this search space",
        )

        # Check thread-level access based on visibility
        await check_thread_access(session, thread, user)

        # Collect every checkpoint tuple for this thread.
        # CheckpointTuple has: config, checkpoint (dict with channel_values),
        # metadata, parent_config.
        checkpointer = await get_checkpointer()
        config = {"configurable": {"thread_id": str(thread_id)}}
        checkpoint_tuples = [
            cp_tuple async for cp_tuple in checkpointer.alist(config)
        ]
        if not checkpoint_tuples:
            raise HTTPException(
                status_code=400, detail="No conversation history found for this thread"
            )

        target_checkpoint_id = None
        # The documented contract treats an empty query the same as None
        # (reload), so normalise it here instead of sending "" to the agent.
        user_query_to_use = request.user_query or None

        # Checkpoints are assumed to be newest-first. Select the first one
        # whose last message is NOT a HumanMessage.
        # NOTE(review): if the newest checkpoint already ends with an AI
        # message it is selected at i == 0, i.e. no rewind happens and the
        # reload query falls through to the database lookup below. Confirm
        # this is the intended behaviour.
        for i, cp_tuple in enumerate(checkpoint_tuples):
            state_messages = _checkpoint_state_messages(cp_tuple)
            if not state_messages or isinstance(state_messages[-1], HumanMessage):
                continue

            if user_query_to_use is None and i > 0:
                # Reload: recover the query from a more recent checkpoint,
                # taking the last HumanMessage it contains.
                for newer_cp_tuple in checkpoint_tuples[:i]:
                    for msg in reversed(_checkpoint_state_messages(newer_cp_tuple)):
                        if isinstance(msg, HumanMessage):
                            user_query_to_use = msg.content
                            break
                    if user_query_to_use:
                        break

            target_checkpoint_id = cp_tuple.config["configurable"]["checkpoint_id"]
            break

        # No suitable checkpoint found: fall back.
        if target_checkpoint_id is None:
            if len(checkpoint_tuples) == 1:
                # Only one checkpoint: keep target_checkpoint_id as None and
                # take the query from it. Scan newest-to-oldest so the LAST
                # user message is used, matching the loop above.
                if user_query_to_use is None:
                    only_messages = _checkpoint_state_messages(checkpoint_tuples[0])
                    for msg in reversed(only_messages):
                        if isinstance(msg, HumanMessage):
                            user_query_to_use = msg.content
                            break
            else:
                # Use the oldest checkpoint
                target_checkpoint_id = checkpoint_tuples[-1].config["configurable"][
                    "checkpoint_id"
                ]

        # If we still don't have a user query, get it from the database
        if user_query_to_use is None:
            last_user_msg_result = await session.execute(
                select(NewChatMessage)
                .filter(
                    NewChatMessage.thread_id == thread_id,
                    NewChatMessage.role == NewChatMessageRole.USER,
                )
                .order_by(NewChatMessage.created_at.desc())
                .limit(1)
            )
            last_user_msg = last_user_msg_result.scalars().first()
            if last_user_msg:
                user_query_to_use = _first_text_from_content(last_user_msg.content)

        if user_query_to_use is None:
            raise HTTPException(
                status_code=400,
                detail="Could not determine user query for regeneration. Please provide a user_query.",
            )

        # Get the last two messages to delete AFTER streaming succeeds
        # This prevents data loss if streaming fails
        last_messages_result = await session.execute(
            select(NewChatMessage)
            .filter(NewChatMessage.thread_id == thread_id)
            .order_by(NewChatMessage.created_at.desc())
            .limit(2)
        )
        messages_to_delete = list(last_messages_result.scalars().all())
        message_ids_to_delete = [msg.id for msg in messages_to_delete]

        # Get search space for LLM config
        # NOTE(review): permission was checked against thread.search_space_id
        # but the search space is loaded from request.search_space_id; confirm
        # the two are guaranteed to match.
        search_space_result = await session.execute(
            select(SearchSpace).filter(SearchSpace.id == request.search_space_id)
        )
        search_space = search_space_result.scalars().first()
        if not search_space:
            raise HTTPException(status_code=404, detail="Search space not found")

        # -1 is passed when the search space has no agent LLM configured.
        llm_config_id = (
            search_space.agent_llm_id if search_space.agent_llm_id is not None else -1
        )

        # Release the read-transaction so we don't hold ACCESS SHARE locks
        # on searchspaces/documents for the entire duration of the stream.
        # expire_on_commit=False keeps loaded ORM attrs (including messages_to_delete PKs) usable.
        await session.commit()

        # Wrapper generator that deletes the old messages only AFTER streaming
        # succeeds, preventing data loss on network/LLM errors.
        async def stream_with_cleanup():
            streaming_completed = False
            try:
                async for chunk in stream_new_chat(
                    user_query=user_query_to_use,
                    search_space_id=request.search_space_id,
                    chat_id=thread_id,
                    session=session,
                    user_id=str(user.id),
                    llm_config_id=llm_config_id,
                    mentioned_document_ids=request.mentioned_document_ids,
                    mentioned_surfsense_doc_ids=request.mentioned_surfsense_doc_ids,
                    checkpoint_id=target_checkpoint_id,
                    needs_history_bootstrap=thread.needs_history_bootstrap,
                    thread_visibility=thread.visibility,
                    current_user_display_name=user.display_name or "A team member",
                ):
                    yield chunk
                # Only reached when the stream was exhausted without error.
                streaming_completed = True
            finally:
                if streaming_completed and messages_to_delete:
                    try:
                        for msg in messages_to_delete:
                            await session.delete(msg)
                        await session.commit()

                        # Delete any public snapshots that contain the modified messages
                        from app.services.public_chat_service import (
                            delete_affected_snapshots,
                        )

                        await delete_affected_snapshots(
                            session, thread_id, message_ids_to_delete
                        )
                    except Exception as cleanup_error:
                        # Log but don't fail - the new messages are already streamed
                        logging.getLogger(__name__).warning(
                            "[regenerate] Failed to delete old messages: %s",
                            cleanup_error,
                            exc_info=True,
                        )

        # Return streaming response with checkpoint_id for rewinding
        return StreamingResponse(
            stream_with_cleanup(),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "X-Accel-Buffering": "no",
            },
        )
    except HTTPException:
        raise
    except Exception as e:
        # Record the full traceback through logging before converting to a 500.
        logging.getLogger(__name__).exception(
            "Unexpected error during regeneration for thread %s", thread_id
        )
        raise HTTPException(
            status_code=500,
            detail=f"An unexpected error occurred during regeneration: {e!s}",
        ) from None
# =============================================================================
# Resume Interrupted Chat Endpoint
# =============================================================================
@router.post("/threads/{thread_id}/resume")
async def resume_chat(
    thread_id: int,
    request: ResumeRequest,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """
    Resume a chat thread and stream the continuation.

    The decisions supplied in the request body are converted to plain dicts
    and forwarded to `stream_resume_chat` together with the LLM configuration
    of the requested search space. The result is returned as a
    `text/event-stream` response.

    Requires CHATS_CREATE permission on the thread's search space, plus
    thread-level access as enforced by `check_thread_access`.

    Raises:
        HTTPException 404: the thread or the requested search space does not exist.
        HTTPException 500: any other unexpected error while preparing the stream.
    """
    try:
        result = await session.execute(
            select(NewChatThread).filter(NewChatThread.id == thread_id)
        )
        thread = result.scalars().first()
        if not thread:
            raise HTTPException(status_code=404, detail="Thread not found")

        await check_permission(
            session,
            user,
            thread.search_space_id,
            Permission.CHATS_CREATE.value,
            "You don't have permission to chat in this search space",
        )
        await check_thread_access(session, thread, user)

        # NOTE(review): permission was checked against thread.search_space_id
        # but the search space is loaded from request.search_space_id; confirm
        # the two are guaranteed to match.
        search_space_result = await session.execute(
            select(SearchSpace).filter(SearchSpace.id == request.search_space_id)
        )
        search_space = search_space_result.scalars().first()
        if not search_space:
            raise HTTPException(status_code=404, detail="Search space not found")

        # -1 is passed when the search space has no agent LLM configured.
        llm_config_id = (
            search_space.agent_llm_id if search_space.agent_llm_id is not None else -1
        )

        decisions = [d.model_dump() for d in request.decisions]

        # Release the read-transaction so we don't hold ACCESS SHARE locks
        # on searchspaces/documents for the entire duration of the stream.
        await session.commit()

        return StreamingResponse(
            stream_resume_chat(
                chat_id=thread_id,
                search_space_id=request.search_space_id,
                decisions=decisions,
                session=session,
                user_id=str(user.id),
                llm_config_id=llm_config_id,
                thread_visibility=thread.visibility,
            ),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "X-Accel-Buffering": "no",
            },
        )
    except HTTPException:
        raise
    except Exception as e:
        # Record the full traceback through logging before converting to a 500.
        logging.getLogger(__name__).exception(
            "Unexpected error during resume for thread %s", thread_id
        )
        raise HTTPException(
            status_code=500,
            detail=f"An unexpected error occurred during resume: {e!s}",
        ) from None