Merge pull request #851 from MODSetter/dev

refactor: improve session management and cleanup in chat streaming
This commit is contained in:
Rohan Verma 2026-02-28 23:18:11 -08:00 committed by GitHub
commit d2e06583ca
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 91 additions and 48 deletions

View file

@ -31,6 +31,7 @@ from app.db import (
Permission,
SearchSpace,
User,
async_session_maker,
get_async_session,
)
from app.schemas.new_chat import (
@ -1092,13 +1093,18 @@ async def handle_new_chat(
# on searchspaces/documents for the entire duration of the stream.
# expire_on_commit=False keeps loaded ORM attrs usable.
await session.commit()
# Close the dependency session now so its connection returns to
# the pool before streaming begins. Without this, Starlette's
# BaseHTTPMiddleware cancels the scope on client disconnect and
# the dependency generator's __aexit__ never runs, orphaning the
# connection (the "Exception terminating connection" errors).
await session.close()
return StreamingResponse(
stream_new_chat(
user_query=request.user_query,
search_space_id=request.search_space_id,
chat_id=request.chat_id,
session=session,
user_id=str(user.id),
llm_config_id=llm_config_id,
mentioned_document_ids=request.mentioned_document_ids,
@ -1323,6 +1329,7 @@ async def regenerate_response(
# on searchspaces/documents for the entire duration of the stream.
# expire_on_commit=False keeps loaded ORM attrs (including messages_to_delete PKs) usable.
await session.commit()
await session.close()
# Create a wrapper generator that deletes messages only AFTER streaming succeeds
# This prevents data loss if streaming fails (network error, LLM error, etc.)
@ -1333,7 +1340,6 @@ async def regenerate_response(
user_query=user_query_to_use,
search_space_id=request.search_space_id,
chat_id=thread_id,
session=session,
user_id=str(user.id),
llm_config_id=llm_config_id,
mentioned_document_ids=request.mentioned_document_ids,
@ -1344,29 +1350,35 @@ async def regenerate_response(
current_user_display_name=user.display_name or "A team member",
):
yield chunk
# If we get here, streaming completed successfully
streaming_completed = True
finally:
# Only delete old messages if streaming completed successfully
# This ensures we don't lose data on streaming failures
if streaming_completed and messages_to_delete:
# Only delete old messages if streaming completed successfully.
# Uses a fresh session since stream_new_chat manages its own.
if streaming_completed and message_ids_to_delete:
try:
for msg in messages_to_delete:
await session.delete(msg)
await session.commit()
async with async_session_maker() as cleanup_session:
for msg_id in message_ids_to_delete:
_res = await cleanup_session.execute(
select(NewChatMessage).filter(
NewChatMessage.id == msg_id
)
)
_msg = _res.scalars().first()
if _msg:
await cleanup_session.delete(_msg)
await cleanup_session.commit()
# Delete any public snapshots that contain the modified messages
from app.services.public_chat_service import (
delete_affected_snapshots,
)
from app.services.public_chat_service import (
delete_affected_snapshots,
)
await delete_affected_snapshots(
session, thread_id, message_ids_to_delete
)
await delete_affected_snapshots(
cleanup_session, thread_id, message_ids_to_delete
)
except Exception as cleanup_error:
# Log but don't fail - the new messages are already streamed
print(
f"[regenerate] Warning: Failed to delete old messages: {cleanup_error}"
_logger.warning(
"[regenerate] Failed to delete old messages: %s",
cleanup_error,
)
# Return streaming response with checkpoint_id for rewinding
@ -1440,13 +1452,13 @@ async def resume_chat(
# Release the read-transaction so we don't hold ACCESS SHARE locks
# on searchspaces/documents for the entire duration of the stream.
await session.commit()
await session.close()
return StreamingResponse(
stream_resume_chat(
chat_id=thread_id,
search_space_id=request.search_space_id,
decisions=decisions,
session=session,
user_id=str(user.id),
llm_config_id=llm_config_id,
thread_visibility=thread.visibility,

View file

@ -10,6 +10,7 @@ Supports loading LLM configurations from:
"""
import asyncio
import contextlib
import gc
import json
import logging
@ -20,9 +21,9 @@ from dataclasses import dataclass, field
from typing import Any
from uuid import UUID
import anyio
from langchain_core.messages import HumanMessage
from sqlalchemy import func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload
@ -1012,7 +1013,6 @@ async def stream_new_chat(
user_query: str,
search_space_id: int,
chat_id: int,
session: AsyncSession,
user_id: str | None = None,
llm_config_id: int = -1,
mentioned_document_ids: list[int] | None = None,
@ -1028,11 +1028,13 @@ async def stream_new_chat(
This uses the Vercel AI SDK Data Stream Protocol (SSE format) for streaming.
The chat_id is used as LangGraph's thread_id for memory/checkpointing.
The function creates and manages its own database session to guarantee proper
cleanup even when Starlette's middleware cancels the task on client disconnect.
Args:
user_query: The user's query
search_space_id: The search space ID
chat_id: The chat ID (used as LangGraph thread_id for memory)
session: The database session
user_id: The current user's UUID string (for memory tools and session state)
llm_config_id: The LLM configuration ID (default: -1 for first global config)
needs_history_bootstrap: If True, load message history from DB (for cloned chats)
@ -1048,6 +1050,7 @@ async def stream_new_chat(
_t_total = time.perf_counter()
log_system_snapshot("stream_new_chat_START")
session = async_session_maker()
try:
# Mark AI as responding to this user for live collaboration
if user_id:
@ -1283,6 +1286,12 @@ async def stream_new_chat(
# short-lived transactions (or use isolated sessions).
await session.commit()
# Detach heavy ORM objects (documents with chunks, reports, etc.)
# from the session identity map now that we've extracted the data
# we need. This prevents them from accumulating in memory for the
# entire duration of LLM streaming (which can be several minutes).
session.expunge_all()
_perf_log.info(
"[stream_new_chat] Total pre-stream setup in %.3fs (chat_id=%s)",
time.perf_counter() - _t_total,
@ -1459,23 +1468,35 @@ async def stream_new_chat(
yield streaming_service.format_done()
finally:
# Clear AI responding state for live collaboration.
# The original session may be broken (client disconnect / CancelledError
# can corrupt the underlying DB connection), so we try a rollback first
# and fall back to a fresh session if the original is unusable.
try:
await session.rollback()
await clear_ai_responding(session, chat_id)
except Exception:
# Shield the ENTIRE async cleanup from anyio cancel-scope
# cancellation. Starlette's BaseHTTPMiddleware uses anyio task
# groups; on client disconnect, it cancels the scope with
# level-triggered cancellation — every unshielded `await` inside
# the cancelled scope raises CancelledError immediately. Without
# this shield the very first `await` (session.rollback) would
# raise CancelledError, `except Exception` wouldn't catch it
# (CancelledError is a BaseException), and the rest of the
# finally block — including session.close() — would never run.
with anyio.CancelScope(shield=True):
try:
async with async_session_maker() as fresh_session:
await clear_ai_responding(fresh_session, chat_id)
await session.rollback()
await clear_ai_responding(session, chat_id)
except Exception:
logging.getLogger(__name__).warning(
"Failed to clear AI responding state for thread %s", chat_id
)
try:
async with async_session_maker() as fresh_session:
await clear_ai_responding(fresh_session, chat_id)
except Exception:
logging.getLogger(__name__).warning(
"Failed to clear AI responding state for thread %s", chat_id
)
_try_persist_and_delete_sandbox(chat_id, stream_result.sandbox_files)
_try_persist_and_delete_sandbox(chat_id, stream_result.sandbox_files)
with contextlib.suppress(Exception):
session.expunge_all()
with contextlib.suppress(Exception):
await session.close()
# Break circular refs held by the agent graph, tools, and LLM
# wrappers so the GC can reclaim them in a single pass.
@ -1483,6 +1504,7 @@ async def stream_new_chat(
mentioned_documents = mentioned_surfsense_docs = None
recent_reports = langchain_messages = input_state = None
stream_result = None
session = None
collected = gc.collect(0) + gc.collect(1) + gc.collect(2)
if collected:
@ -1498,7 +1520,6 @@ async def stream_resume_chat(
chat_id: int,
search_space_id: int,
decisions: list[dict],
session: AsyncSession,
user_id: str | None = None,
llm_config_id: int = -1,
thread_visibility: ChatVisibility | None = None,
@ -1507,6 +1528,7 @@ async def stream_resume_chat(
stream_result = StreamResult()
_t_total = time.perf_counter()
session = async_session_maker()
try:
if user_id:
await set_ai_responding(session, chat_id, UUID(user_id))
@ -1603,6 +1625,7 @@ async def stream_resume_chat(
# Release the transaction before streaming (same rationale as stream_new_chat).
await session.commit()
session.expunge_all()
_perf_log.info(
"[stream_resume] Total pre-stream setup in %.3fs (chat_id=%s)",
@ -1666,22 +1689,30 @@ async def stream_resume_chat(
yield streaming_service.format_done()
finally:
try:
await session.rollback()
await clear_ai_responding(session, chat_id)
except Exception:
with anyio.CancelScope(shield=True):
try:
async with async_session_maker() as fresh_session:
await clear_ai_responding(fresh_session, chat_id)
await session.rollback()
await clear_ai_responding(session, chat_id)
except Exception:
logging.getLogger(__name__).warning(
"Failed to clear AI responding state for thread %s", chat_id
)
try:
async with async_session_maker() as fresh_session:
await clear_ai_responding(fresh_session, chat_id)
except Exception:
logging.getLogger(__name__).warning(
"Failed to clear AI responding state for thread %s", chat_id
)
_try_persist_and_delete_sandbox(chat_id, stream_result.sandbox_files)
_try_persist_and_delete_sandbox(chat_id, stream_result.sandbox_files)
with contextlib.suppress(Exception):
session.expunge_all()
with contextlib.suppress(Exception):
await session.close()
agent = llm = connector_service = sandbox_backend = None
stream_result = None
session = None
collected = gc.collect(0) + gc.collect(1) + gc.collect(2)
if collected: