refactor: improve session management and cleanup in chat streaming

- Added proper session closure to prevent connection leaks during streaming.
- Implemented a fresh session for cleanup tasks to ensure data integrity.
- Enhanced error handling during session operations to improve robustness.
- Removed unnecessary session parameters from function signatures for clarity.
This commit is contained in:
DESKTOP-RTLN3BA\$punk 2026-02-28 23:17:11 -08:00
parent cbf9bc6bc9
commit dd3da2bc36
2 changed files with 91 additions and 48 deletions

View file

@ -31,6 +31,7 @@ from app.db import (
Permission, Permission,
SearchSpace, SearchSpace,
User, User,
async_session_maker,
get_async_session, get_async_session,
) )
from app.schemas.new_chat import ( from app.schemas.new_chat import (
@ -1092,13 +1093,18 @@ async def handle_new_chat(
# on searchspaces/documents for the entire duration of the stream. # on searchspaces/documents for the entire duration of the stream.
# expire_on_commit=False keeps loaded ORM attrs usable. # expire_on_commit=False keeps loaded ORM attrs usable.
await session.commit() await session.commit()
# Close the dependency session now so its connection returns to
# the pool before streaming begins. Without this, Starlette's
# BaseHTTPMiddleware cancels the scope on client disconnect and
# the dependency generator's __aexit__ never runs, orphaning the
# connection (the "Exception terminating connection" errors).
await session.close()
return StreamingResponse( return StreamingResponse(
stream_new_chat( stream_new_chat(
user_query=request.user_query, user_query=request.user_query,
search_space_id=request.search_space_id, search_space_id=request.search_space_id,
chat_id=request.chat_id, chat_id=request.chat_id,
session=session,
user_id=str(user.id), user_id=str(user.id),
llm_config_id=llm_config_id, llm_config_id=llm_config_id,
mentioned_document_ids=request.mentioned_document_ids, mentioned_document_ids=request.mentioned_document_ids,
@ -1323,6 +1329,7 @@ async def regenerate_response(
# on searchspaces/documents for the entire duration of the stream. # on searchspaces/documents for the entire duration of the stream.
# expire_on_commit=False keeps loaded ORM attrs (including messages_to_delete PKs) usable. # expire_on_commit=False keeps loaded ORM attrs (including messages_to_delete PKs) usable.
await session.commit() await session.commit()
await session.close()
# Create a wrapper generator that deletes messages only AFTER streaming succeeds # Create a wrapper generator that deletes messages only AFTER streaming succeeds
# This prevents data loss if streaming fails (network error, LLM error, etc.) # This prevents data loss if streaming fails (network error, LLM error, etc.)
@ -1333,7 +1340,6 @@ async def regenerate_response(
user_query=user_query_to_use, user_query=user_query_to_use,
search_space_id=request.search_space_id, search_space_id=request.search_space_id,
chat_id=thread_id, chat_id=thread_id,
session=session,
user_id=str(user.id), user_id=str(user.id),
llm_config_id=llm_config_id, llm_config_id=llm_config_id,
mentioned_document_ids=request.mentioned_document_ids, mentioned_document_ids=request.mentioned_document_ids,
@ -1344,29 +1350,35 @@ async def regenerate_response(
current_user_display_name=user.display_name or "A team member", current_user_display_name=user.display_name or "A team member",
): ):
yield chunk yield chunk
# If we get here, streaming completed successfully
streaming_completed = True streaming_completed = True
finally: finally:
# Only delete old messages if streaming completed successfully # Only delete old messages if streaming completed successfully.
# This ensures we don't lose data on streaming failures # Uses a fresh session since stream_new_chat manages its own.
if streaming_completed and messages_to_delete: if streaming_completed and message_ids_to_delete:
try: try:
for msg in messages_to_delete: async with async_session_maker() as cleanup_session:
await session.delete(msg) for msg_id in message_ids_to_delete:
await session.commit() _res = await cleanup_session.execute(
select(NewChatMessage).filter(
NewChatMessage.id == msg_id
)
)
_msg = _res.scalars().first()
if _msg:
await cleanup_session.delete(_msg)
await cleanup_session.commit()
# Delete any public snapshots that contain the modified messages from app.services.public_chat_service import (
from app.services.public_chat_service import ( delete_affected_snapshots,
delete_affected_snapshots, )
)
await delete_affected_snapshots( await delete_affected_snapshots(
session, thread_id, message_ids_to_delete cleanup_session, thread_id, message_ids_to_delete
) )
except Exception as cleanup_error: except Exception as cleanup_error:
# Log but don't fail - the new messages are already streamed _logger.warning(
print( "[regenerate] Failed to delete old messages: %s",
f"[regenerate] Warning: Failed to delete old messages: {cleanup_error}" cleanup_error,
) )
# Return streaming response with checkpoint_id for rewinding # Return streaming response with checkpoint_id for rewinding
@ -1440,13 +1452,13 @@ async def resume_chat(
# Release the read-transaction so we don't hold ACCESS SHARE locks # Release the read-transaction so we don't hold ACCESS SHARE locks
# on searchspaces/documents for the entire duration of the stream. # on searchspaces/documents for the entire duration of the stream.
await session.commit() await session.commit()
await session.close()
return StreamingResponse( return StreamingResponse(
stream_resume_chat( stream_resume_chat(
chat_id=thread_id, chat_id=thread_id,
search_space_id=request.search_space_id, search_space_id=request.search_space_id,
decisions=decisions, decisions=decisions,
session=session,
user_id=str(user.id), user_id=str(user.id),
llm_config_id=llm_config_id, llm_config_id=llm_config_id,
thread_visibility=thread.visibility, thread_visibility=thread.visibility,

View file

@ -10,6 +10,7 @@ Supports loading LLM configurations from:
""" """
import asyncio import asyncio
import contextlib
import gc import gc
import json import json
import logging import logging
@ -20,9 +21,9 @@ from dataclasses import dataclass, field
from typing import Any from typing import Any
from uuid import UUID from uuid import UUID
import anyio
from langchain_core.messages import HumanMessage from langchain_core.messages import HumanMessage
from sqlalchemy import func from sqlalchemy import func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select from sqlalchemy.future import select
from sqlalchemy.orm import selectinload from sqlalchemy.orm import selectinload
@ -1012,7 +1013,6 @@ async def stream_new_chat(
user_query: str, user_query: str,
search_space_id: int, search_space_id: int,
chat_id: int, chat_id: int,
session: AsyncSession,
user_id: str | None = None, user_id: str | None = None,
llm_config_id: int = -1, llm_config_id: int = -1,
mentioned_document_ids: list[int] | None = None, mentioned_document_ids: list[int] | None = None,
@ -1028,11 +1028,13 @@ async def stream_new_chat(
This uses the Vercel AI SDK Data Stream Protocol (SSE format) for streaming. This uses the Vercel AI SDK Data Stream Protocol (SSE format) for streaming.
The chat_id is used as LangGraph's thread_id for memory/checkpointing. The chat_id is used as LangGraph's thread_id for memory/checkpointing.
The function creates and manages its own database session to guarantee proper
cleanup even when Starlette's middleware cancels the task on client disconnect.
Args: Args:
user_query: The user's query user_query: The user's query
search_space_id: The search space ID search_space_id: The search space ID
chat_id: The chat ID (used as LangGraph thread_id for memory) chat_id: The chat ID (used as LangGraph thread_id for memory)
session: The database session
user_id: The current user's UUID string (for memory tools and session state) user_id: The current user's UUID string (for memory tools and session state)
llm_config_id: The LLM configuration ID (default: -1 for first global config) llm_config_id: The LLM configuration ID (default: -1 for first global config)
needs_history_bootstrap: If True, load message history from DB (for cloned chats) needs_history_bootstrap: If True, load message history from DB (for cloned chats)
@ -1048,6 +1050,7 @@ async def stream_new_chat(
_t_total = time.perf_counter() _t_total = time.perf_counter()
log_system_snapshot("stream_new_chat_START") log_system_snapshot("stream_new_chat_START")
session = async_session_maker()
try: try:
# Mark AI as responding to this user for live collaboration # Mark AI as responding to this user for live collaboration
if user_id: if user_id:
@ -1283,6 +1286,12 @@ async def stream_new_chat(
# short-lived transactions (or use isolated sessions). # short-lived transactions (or use isolated sessions).
await session.commit() await session.commit()
# Detach heavy ORM objects (documents with chunks, reports, etc.)
# from the session identity map now that we've extracted the data
# we need. This prevents them from accumulating in memory for the
# entire duration of LLM streaming (which can be several minutes).
session.expunge_all()
_perf_log.info( _perf_log.info(
"[stream_new_chat] Total pre-stream setup in %.3fs (chat_id=%s)", "[stream_new_chat] Total pre-stream setup in %.3fs (chat_id=%s)",
time.perf_counter() - _t_total, time.perf_counter() - _t_total,
@ -1459,23 +1468,35 @@ async def stream_new_chat(
yield streaming_service.format_done() yield streaming_service.format_done()
finally: finally:
# Clear AI responding state for live collaboration. # Shield the ENTIRE async cleanup from anyio cancel-scope
# The original session may be broken (client disconnect / CancelledError # cancellation. Starlette's BaseHTTPMiddleware uses anyio task
# can corrupt the underlying DB connection), so we try a rollback first # groups; on client disconnect, it cancels the scope with
# and fall back to a fresh session if the original is unusable. # level-triggered cancellation — every unshielded `await` inside
try: # the cancelled scope raises CancelledError immediately. Without
await session.rollback() # this shield the very first `await` (session.rollback) would
await clear_ai_responding(session, chat_id) # raise CancelledError, `except Exception` wouldn't catch it
except Exception: # (CancelledError is a BaseException), and the rest of the
# finally block — including session.close() — would never run.
with anyio.CancelScope(shield=True):
try: try:
async with async_session_maker() as fresh_session: await session.rollback()
await clear_ai_responding(fresh_session, chat_id) await clear_ai_responding(session, chat_id)
except Exception: except Exception:
logging.getLogger(__name__).warning( try:
"Failed to clear AI responding state for thread %s", chat_id async with async_session_maker() as fresh_session:
) await clear_ai_responding(fresh_session, chat_id)
except Exception:
logging.getLogger(__name__).warning(
"Failed to clear AI responding state for thread %s", chat_id
)
_try_persist_and_delete_sandbox(chat_id, stream_result.sandbox_files) _try_persist_and_delete_sandbox(chat_id, stream_result.sandbox_files)
with contextlib.suppress(Exception):
session.expunge_all()
with contextlib.suppress(Exception):
await session.close()
# Break circular refs held by the agent graph, tools, and LLM # Break circular refs held by the agent graph, tools, and LLM
# wrappers so the GC can reclaim them in a single pass. # wrappers so the GC can reclaim them in a single pass.
@ -1483,6 +1504,7 @@ async def stream_new_chat(
mentioned_documents = mentioned_surfsense_docs = None mentioned_documents = mentioned_surfsense_docs = None
recent_reports = langchain_messages = input_state = None recent_reports = langchain_messages = input_state = None
stream_result = None stream_result = None
session = None
collected = gc.collect(0) + gc.collect(1) + gc.collect(2) collected = gc.collect(0) + gc.collect(1) + gc.collect(2)
if collected: if collected:
@ -1498,7 +1520,6 @@ async def stream_resume_chat(
chat_id: int, chat_id: int,
search_space_id: int, search_space_id: int,
decisions: list[dict], decisions: list[dict],
session: AsyncSession,
user_id: str | None = None, user_id: str | None = None,
llm_config_id: int = -1, llm_config_id: int = -1,
thread_visibility: ChatVisibility | None = None, thread_visibility: ChatVisibility | None = None,
@ -1507,6 +1528,7 @@ async def stream_resume_chat(
stream_result = StreamResult() stream_result = StreamResult()
_t_total = time.perf_counter() _t_total = time.perf_counter()
session = async_session_maker()
try: try:
if user_id: if user_id:
await set_ai_responding(session, chat_id, UUID(user_id)) await set_ai_responding(session, chat_id, UUID(user_id))
@ -1603,6 +1625,7 @@ async def stream_resume_chat(
# Release the transaction before streaming (same rationale as stream_new_chat). # Release the transaction before streaming (same rationale as stream_new_chat).
await session.commit() await session.commit()
session.expunge_all()
_perf_log.info( _perf_log.info(
"[stream_resume] Total pre-stream setup in %.3fs (chat_id=%s)", "[stream_resume] Total pre-stream setup in %.3fs (chat_id=%s)",
@ -1666,22 +1689,30 @@ async def stream_resume_chat(
yield streaming_service.format_done() yield streaming_service.format_done()
finally: finally:
try: with anyio.CancelScope(shield=True):
await session.rollback()
await clear_ai_responding(session, chat_id)
except Exception:
try: try:
async with async_session_maker() as fresh_session: await session.rollback()
await clear_ai_responding(fresh_session, chat_id) await clear_ai_responding(session, chat_id)
except Exception: except Exception:
logging.getLogger(__name__).warning( try:
"Failed to clear AI responding state for thread %s", chat_id async with async_session_maker() as fresh_session:
) await clear_ai_responding(fresh_session, chat_id)
except Exception:
logging.getLogger(__name__).warning(
"Failed to clear AI responding state for thread %s", chat_id
)
_try_persist_and_delete_sandbox(chat_id, stream_result.sandbox_files) _try_persist_and_delete_sandbox(chat_id, stream_result.sandbox_files)
with contextlib.suppress(Exception):
session.expunge_all()
with contextlib.suppress(Exception):
await session.close()
agent = llm = connector_service = sandbox_backend = None agent = llm = connector_service = sandbox_backend = None
stream_result = None stream_result = None
session = None
collected = gc.collect(0) + gc.collect(1) + gc.collect(2) collected = gc.collect(0) + gc.collect(1) + gc.collect(2)
if collected: if collected: