diff --git a/surfsense_backend/app/routes/new_chat_routes.py b/surfsense_backend/app/routes/new_chat_routes.py
index c997cba68..8952907a0 100644
--- a/surfsense_backend/app/routes/new_chat_routes.py
+++ b/surfsense_backend/app/routes/new_chat_routes.py
@@ -31,6 +31,7 @@ from app.db import (
     Permission,
     SearchSpace,
     User,
+    async_session_maker,
     get_async_session,
 )
 from app.schemas.new_chat import (
@@ -1092,13 +1093,18 @@ async def handle_new_chat(
     # on searchspaces/documents for the entire duration of the stream.
     # expire_on_commit=False keeps loaded ORM attrs usable.
     await session.commit()
+    # Close the dependency session now so its connection returns to
+    # the pool before streaming begins. Without this, Starlette's
+    # BaseHTTPMiddleware cancels the scope on client disconnect and
+    # the dependency generator's __aexit__ never runs, orphaning the
+    # connection (the "Exception terminating connection" errors).
+    await session.close()

     return StreamingResponse(
         stream_new_chat(
             user_query=request.user_query,
             search_space_id=request.search_space_id,
             chat_id=request.chat_id,
-            session=session,
             user_id=str(user.id),
             llm_config_id=llm_config_id,
             mentioned_document_ids=request.mentioned_document_ids,
@@ -1323,6 +1329,7 @@ async def regenerate_response(
     # on searchspaces/documents for the entire duration of the stream.
     # expire_on_commit=False keeps loaded ORM attrs (including messages_to_delete PKs) usable.
     await session.commit()
+    await session.close()

     # Create a wrapper generator that deletes messages only AFTER streaming succeeds
     # This prevents data loss if streaming fails (network error, LLM error, etc.)
@@ -1333,7 +1340,6 @@ async def regenerate_response(
                 user_query=user_query_to_use,
                 search_space_id=request.search_space_id,
                 chat_id=thread_id,
-                session=session,
                 user_id=str(user.id),
                 llm_config_id=llm_config_id,
                 mentioned_document_ids=request.mentioned_document_ids,
@@ -1344,29 +1350,35 @@
                 current_user_display_name=user.display_name or "A team member",
             ):
                 yield chunk
-            # If we get here, streaming completed successfully
             streaming_completed = True
         finally:
-            # Only delete old messages if streaming completed successfully
-            # This ensures we don't lose data on streaming failures
-            if streaming_completed and messages_to_delete:
+            # Only delete old messages if streaming completed successfully.
+            # Uses a fresh session since stream_new_chat manages its own.
+            if streaming_completed and message_ids_to_delete:
                 try:
-                    for msg in messages_to_delete:
-                        await session.delete(msg)
-                    await session.commit()
+                    async with async_session_maker() as cleanup_session:
+                        for msg_id in message_ids_to_delete:
+                            _res = await cleanup_session.execute(
+                                select(NewChatMessage).filter(
+                                    NewChatMessage.id == msg_id
+                                )
+                            )
+                            _msg = _res.scalars().first()
+                            if _msg:
+                                await cleanup_session.delete(_msg)
+                        await cleanup_session.commit()

-                    # Delete any public snapshots that contain the modified messages
-                    from app.services.public_chat_service import (
-                        delete_affected_snapshots,
-                    )
+                        from app.services.public_chat_service import (
+                            delete_affected_snapshots,
+                        )

-                    await delete_affected_snapshots(
-                        session, thread_id, message_ids_to_delete
-                    )
+                        await delete_affected_snapshots(
+                            cleanup_session, thread_id, message_ids_to_delete
+                        )
                 except Exception as cleanup_error:
-                    # Log but don't fail - the new messages are already streamed
-                    print(
-                        f"[regenerate] Warning: Failed to delete old messages: {cleanup_error}"
+                    _logger.warning(
+                        "[regenerate] Failed to delete old messages: %s",
+                        cleanup_error,
                     )

     # Return streaming response with checkpoint_id for rewinding
@@ -1440,13 +1452,13 @@ async def resume_chat(
     # Release the read-transaction so we don't hold ACCESS SHARE locks
     # on searchspaces/documents for the entire duration of the stream.
     await session.commit()
+    await session.close()

     return StreamingResponse(
         stream_resume_chat(
             chat_id=thread_id,
             search_space_id=request.search_space_id,
             decisions=decisions,
-            session=session,
             user_id=str(user.id),
             llm_config_id=llm_config_id,
             thread_visibility=thread.visibility,
diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py
index ae7001a40..34ea6ec82 100644
--- a/surfsense_backend/app/tasks/chat/stream_new_chat.py
+++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py
@@ -10,6 +10,7 @@ Supports loading LLM configurations from:
 """

 import asyncio
+import contextlib
 import gc
 import json
 import logging
@@ -20,9 +21,9 @@ from dataclasses import dataclass, field
 from typing import Any
 from uuid import UUID

+import anyio
 from langchain_core.messages import HumanMessage
 from sqlalchemy import func
-from sqlalchemy.ext.asyncio import AsyncSession
 from sqlalchemy.future import select
 from sqlalchemy.orm import selectinload

@@ -1012,7 +1013,6 @@ async def stream_new_chat(
     user_query: str,
     search_space_id: int,
     chat_id: int,
-    session: AsyncSession,
     user_id: str | None = None,
     llm_config_id: int = -1,
     mentioned_document_ids: list[int] | None = None,
@@ -1028,11 +1028,13 @@
     This uses the Vercel AI SDK Data Stream Protocol (SSE format) for streaming.
     The chat_id is used as LangGraph's thread_id for memory/checkpointing.

+    The function creates and manages its own database session to guarantee proper
+    cleanup even when Starlette's middleware cancels the task on client disconnect.
+
     Args:
         user_query: The user's query
         search_space_id: The search space ID
         chat_id: The chat ID (used as LangGraph thread_id for memory)
-        session: The database session
         user_id: The current user's UUID string (for memory tools and session state)
         llm_config_id: The LLM configuration ID (default: -1 for first global config)
         needs_history_bootstrap: If True, load message history from DB (for cloned chats)
@@ -1048,6 +1050,7 @@
     _t_total = time.perf_counter()
     log_system_snapshot("stream_new_chat_START")

+    session = async_session_maker()
     try:
         # Mark AI as responding to this user for live collaboration
         if user_id:
@@ -1283,6 +1286,12 @@
         # short-lived transactions (or use isolated sessions).
         await session.commit()

+        # Detach heavy ORM objects (documents with chunks, reports, etc.)
+        # from the session identity map now that we've extracted the data
+        # we need. This prevents them from accumulating in memory for the
+        # entire duration of LLM streaming (which can be several minutes).
+        session.expunge_all()
+
         _perf_log.info(
             "[stream_new_chat] Total pre-stream setup in %.3fs (chat_id=%s)",
             time.perf_counter() - _t_total,
@@ -1459,23 +1468,35 @@
             yield streaming_service.format_done()
     finally:
-        # Clear AI responding state for live collaboration.
-        # The original session may be broken (client disconnect / CancelledError
-        # can corrupt the underlying DB connection), so we try a rollback first
-        # and fall back to a fresh session if the original is unusable.
-        try:
-            await session.rollback()
-            await clear_ai_responding(session, chat_id)
-        except Exception:
+        # Shield the ENTIRE async cleanup from anyio cancel-scope
+        # cancellation. Starlette's BaseHTTPMiddleware uses anyio task
+        # groups; on client disconnect, it cancels the scope with
+        # level-triggered cancellation — every unshielded `await` inside
+        # the cancelled scope raises CancelledError immediately. Without
+        # this shield the very first `await` (session.rollback) would
+        # raise CancelledError, `except Exception` wouldn't catch it
+        # (CancelledError is a BaseException), and the rest of the
+        # finally block — including session.close() — would never run.
+        with anyio.CancelScope(shield=True):
             try:
-                async with async_session_maker() as fresh_session:
-                    await clear_ai_responding(fresh_session, chat_id)
+                await session.rollback()
+                await clear_ai_responding(session, chat_id)
             except Exception:
-                logging.getLogger(__name__).warning(
-                    "Failed to clear AI responding state for thread %s", chat_id
-                )
+                try:
+                    async with async_session_maker() as fresh_session:
+                        await clear_ai_responding(fresh_session, chat_id)
+                except Exception:
+                    logging.getLogger(__name__).warning(
+                        "Failed to clear AI responding state for thread %s", chat_id
+                    )

-        _try_persist_and_delete_sandbox(chat_id, stream_result.sandbox_files)
+            _try_persist_and_delete_sandbox(chat_id, stream_result.sandbox_files)
+
+            with contextlib.suppress(Exception):
+                session.expunge_all()
+
+            with contextlib.suppress(Exception):
+                await session.close()

         # Break circular refs held by the agent graph, tools, and LLM
         # wrappers so the GC can reclaim them in a single pass.
@@ -1483,6 +1504,7 @@
         mentioned_documents = mentioned_surfsense_docs = None
         recent_reports = langchain_messages = input_state = None
         stream_result = None
+        session = None

         collected = gc.collect(0) + gc.collect(1) + gc.collect(2)
         if collected:
@@ -1498,7 +1520,6 @@ async def stream_resume_chat(
     chat_id: int,
     search_space_id: int,
     decisions: list[dict],
-    session: AsyncSession,
     user_id: str | None = None,
     llm_config_id: int = -1,
     thread_visibility: ChatVisibility | None = None,
@@ -1507,6 +1528,7 @@
     stream_result = StreamResult()
     _t_total = time.perf_counter()

+    session = async_session_maker()
     try:
         if user_id:
             await set_ai_responding(session, chat_id, UUID(user_id))
@@ -1603,6 +1625,7 @@
         # Release the transaction before streaming (same rationale as stream_new_chat).
         await session.commit()
+        session.expunge_all()

         _perf_log.info(
             "[stream_resume] Total pre-stream setup in %.3fs (chat_id=%s)",
@@ -1666,22 +1689,30 @@
             yield streaming_service.format_done()
     finally:
-        try:
-            await session.rollback()
-            await clear_ai_responding(session, chat_id)
-        except Exception:
+        with anyio.CancelScope(shield=True):
             try:
-                async with async_session_maker() as fresh_session:
-                    await clear_ai_responding(fresh_session, chat_id)
+                await session.rollback()
+                await clear_ai_responding(session, chat_id)
             except Exception:
-                logging.getLogger(__name__).warning(
-                    "Failed to clear AI responding state for thread %s", chat_id
-                )
+                try:
+                    async with async_session_maker() as fresh_session:
+                        await clear_ai_responding(fresh_session, chat_id)
+                except Exception:
+                    logging.getLogger(__name__).warning(
+                        "Failed to clear AI responding state for thread %s", chat_id
+                    )

-        _try_persist_and_delete_sandbox(chat_id, stream_result.sandbox_files)
+            _try_persist_and_delete_sandbox(chat_id, stream_result.sandbox_files)
+
+            with contextlib.suppress(Exception):
+                session.expunge_all()
+
+            with contextlib.suppress(Exception):
+                await session.close()

         agent = llm = connector_service = sandbox_backend = None
         stream_result = None
+        session = None

         collected = gc.collect(0) + gc.collect(1) + gc.collect(2)
         if collected:
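
Reviewer note: the route-level `await session.close()` calls and the `session = async_session_maker()` lines in stream_new_chat/stream_resume_chat are two halves of one pattern: when a handler returns a StreamingResponse, the dependency-injected session is committed and closed before streaming starts, and the streaming generator owns a session whose lifetime matches the stream. Below is a minimal sketch of that pattern, assuming FastAPI plus SQLAlchemy 2.0 async; the route, engine URL, and stream body are invented for illustration, and only `async_session_maker`, `get_async_session`, and `StreamingResponse` correspond to names in this patch.

from collections.abc import AsyncIterator

from fastapi import Depends, FastAPI
from fastapi.responses import StreamingResponse
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine

# Illustrative engine/factory; the real app builds these in app.db.
engine = create_async_engine("sqlite+aiosqlite:///:memory:")
async_session_maker = async_sessionmaker(engine, expire_on_commit=False)

app = FastAPI()


async def get_async_session() -> AsyncIterator[AsyncSession]:
    # Typical FastAPI dependency. Under BaseHTTPMiddleware, a client
    # disconnect can cancel the scope so the code after `yield` never runs.
    async with async_session_maker() as session:
        yield session


async def stream_events(chat_id: int) -> AsyncIterator[str]:
    # The generator owns its session, so cleanup belongs to the stream,
    # not to the (already-finished) request dependency.
    session = async_session_maker()
    try:
        yield f"data: chat {chat_id} started\n\n"
        # ... query via `session`, stream more events ...
    finally:
        # The patch additionally shields this await; see the next note.
        await session.close()


@app.post("/chat/{chat_id}")
async def chat(chat_id: int, session: AsyncSession = Depends(get_async_session)):
    # ... validate the request and load what the stream needs ...
    await session.commit()
    await session.close()  # hand the connection back to the pool pre-stream
    return StreamingResponse(stream_events(chat_id), media_type="text/event-stream")

Closing the dependency session in the handler is safe even though the dependency's own cleanup may later close it again: SQLAlchemy's Session.close() is idempotent.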
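Reviewer note: the `anyio.CancelScope(shield=True)` blocks rely on two facts the patch comments cite: anyio cancellation is level-triggered, and CancelledError subclasses BaseException, so `except Exception` never catches it. The standalone sketch below (all names invented) reproduces the hazard and shows the shield letting cleanup complete after the enclosing scope is cancelled.

import anyio


async def stream() -> None:
    try:
        while True:
            yield "chunk"
            await anyio.sleep(0.1)
    finally:
        # Once the enclosing scope is cancelled, every unshielded await
        # here would raise CancelledError immediately, skipping the rest
        # of this block. The shield lets cleanup awaits run to completion.
        with anyio.CancelScope(shield=True):
            await anyio.sleep(0)  # stands in for session.rollback()/close()
            print("cleanup completed despite cancellation")


async def consume() -> None:
    async with anyio.create_task_group() as tg:

        async def run() -> None:
            async for _ in stream():
                pass

        tg.start_soon(run)
        await anyio.sleep(0.25)
        tg.cancel_scope.cancel()  # simulates the client disconnecting


anyio.run(consume)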
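Reviewer note: `session.expunge_all()` before streaming is only safe because the session factory uses expire_on_commit=False (as the routes' comments state); already-loaded attributes stay readable on the detached objects instead of triggering a lazy refresh, which would raise DetachedInstanceError. A self-contained sketch, with an illustrative in-memory SQLite model standing in for the app's heavy Document rows:

import asyncio

from sqlalchemy import String
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class Document(Base):  # stand-in for the app's Document model
    __tablename__ = "documents"
    id: Mapped[int] = mapped_column(primary_key=True)
    title: Mapped[str] = mapped_column(String(100))


async def main() -> None:
    engine = create_async_engine("sqlite+aiosqlite:///:memory:")
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)

    maker = async_sessionmaker(engine, expire_on_commit=False)
    async with maker() as session:
        session.add(Document(id=1, title="report"))
        await session.commit()

        doc = await session.get(Document, 1)
        session.expunge_all()  # detach: the session no longer pins `doc`
        # Loaded attributes remain readable on the detached object because
        # expire_on_commit=False left them unexpired.
        print(doc.title)


asyncio.run(main())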