refactor: improve session management and cleanup in chat streaming

- Added proper session closure to prevent connection leaks during streaming.
- Implemented a fresh session for cleanup tasks to ensure data integrity.
- Enhanced error handling during session operations to improve robustness.
- Removed unnecessary session parameters from function signatures for clarity.
This commit is contained in:
DESKTOP-RTLN3BA\$punk 2026-02-28 23:17:11 -08:00
parent cbf9bc6bc9
commit dd3da2bc36
2 changed files with 91 additions and 48 deletions

View file

@ -10,6 +10,7 @@ Supports loading LLM configurations from:
"""
import asyncio
import contextlib
import gc
import json
import logging
@ -20,9 +21,9 @@ from dataclasses import dataclass, field
from typing import Any
from uuid import UUID
import anyio
from langchain_core.messages import HumanMessage
from sqlalchemy import func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload
@ -1012,7 +1013,6 @@ async def stream_new_chat(
user_query: str,
search_space_id: int,
chat_id: int,
session: AsyncSession,
user_id: str | None = None,
llm_config_id: int = -1,
mentioned_document_ids: list[int] | None = None,
@ -1028,11 +1028,13 @@ async def stream_new_chat(
This uses the Vercel AI SDK Data Stream Protocol (SSE format) for streaming.
The chat_id is used as LangGraph's thread_id for memory/checkpointing.
The function creates and manages its own database session to guarantee proper
cleanup even when Starlette's middleware cancels the task on client disconnect.
Args:
user_query: The user's query
search_space_id: The search space ID
chat_id: The chat ID (used as LangGraph thread_id for memory)
session: The database session
user_id: The current user's UUID string (for memory tools and session state)
llm_config_id: The LLM configuration ID (default: -1 for first global config)
needs_history_bootstrap: If True, load message history from DB (for cloned chats)
@ -1048,6 +1050,7 @@ async def stream_new_chat(
_t_total = time.perf_counter()
log_system_snapshot("stream_new_chat_START")
session = async_session_maker()
try:
# Mark AI as responding to this user for live collaboration
if user_id:
@ -1283,6 +1286,12 @@ async def stream_new_chat(
# short-lived transactions (or use isolated sessions).
await session.commit()
# Detach heavy ORM objects (documents with chunks, reports, etc.)
# from the session identity map now that we've extracted the data
# we need. This prevents them from accumulating in memory for the
# entire duration of LLM streaming (which can be several minutes).
session.expunge_all()
_perf_log.info(
"[stream_new_chat] Total pre-stream setup in %.3fs (chat_id=%s)",
time.perf_counter() - _t_total,
@ -1459,23 +1468,35 @@ async def stream_new_chat(
yield streaming_service.format_done()
finally:
# Clear AI responding state for live collaboration.
# The original session may be broken (client disconnect / CancelledError
# can corrupt the underlying DB connection), so we try a rollback first
# and fall back to a fresh session if the original is unusable.
try:
await session.rollback()
await clear_ai_responding(session, chat_id)
except Exception:
# Shield the ENTIRE async cleanup from anyio cancel-scope
# cancellation. Starlette's BaseHTTPMiddleware uses anyio task
# groups; on client disconnect, it cancels the scope with
# level-triggered cancellation — every unshielded `await` inside
# the cancelled scope raises CancelledError immediately. Without
# this shield the very first `await` (session.rollback) would
# raise CancelledError, `except Exception` wouldn't catch it
# (CancelledError is a BaseException), and the rest of the
# finally block — including session.close() — would never run.
with anyio.CancelScope(shield=True):
try:
async with async_session_maker() as fresh_session:
await clear_ai_responding(fresh_session, chat_id)
await session.rollback()
await clear_ai_responding(session, chat_id)
except Exception:
logging.getLogger(__name__).warning(
"Failed to clear AI responding state for thread %s", chat_id
)
try:
async with async_session_maker() as fresh_session:
await clear_ai_responding(fresh_session, chat_id)
except Exception:
logging.getLogger(__name__).warning(
"Failed to clear AI responding state for thread %s", chat_id
)
_try_persist_and_delete_sandbox(chat_id, stream_result.sandbox_files)
_try_persist_and_delete_sandbox(chat_id, stream_result.sandbox_files)
with contextlib.suppress(Exception):
session.expunge_all()
with contextlib.suppress(Exception):
await session.close()
# Break circular refs held by the agent graph, tools, and LLM
# wrappers so the GC can reclaim them in a single pass.
@ -1483,6 +1504,7 @@ async def stream_new_chat(
mentioned_documents = mentioned_surfsense_docs = None
recent_reports = langchain_messages = input_state = None
stream_result = None
session = None
collected = gc.collect(0) + gc.collect(1) + gc.collect(2)
if collected:
@ -1498,7 +1520,6 @@ async def stream_resume_chat(
chat_id: int,
search_space_id: int,
decisions: list[dict],
session: AsyncSession,
user_id: str | None = None,
llm_config_id: int = -1,
thread_visibility: ChatVisibility | None = None,
@ -1507,6 +1528,7 @@ async def stream_resume_chat(
stream_result = StreamResult()
_t_total = time.perf_counter()
session = async_session_maker()
try:
if user_id:
await set_ai_responding(session, chat_id, UUID(user_id))
@ -1603,6 +1625,7 @@ async def stream_resume_chat(
# Release the transaction before streaming (same rationale as stream_new_chat).
await session.commit()
session.expunge_all()
_perf_log.info(
"[stream_resume] Total pre-stream setup in %.3fs (chat_id=%s)",
@ -1666,22 +1689,30 @@ async def stream_resume_chat(
yield streaming_service.format_done()
finally:
try:
await session.rollback()
await clear_ai_responding(session, chat_id)
except Exception:
with anyio.CancelScope(shield=True):
try:
async with async_session_maker() as fresh_session:
await clear_ai_responding(fresh_session, chat_id)
await session.rollback()
await clear_ai_responding(session, chat_id)
except Exception:
logging.getLogger(__name__).warning(
"Failed to clear AI responding state for thread %s", chat_id
)
try:
async with async_session_maker() as fresh_session:
await clear_ai_responding(fresh_session, chat_id)
except Exception:
logging.getLogger(__name__).warning(
"Failed to clear AI responding state for thread %s", chat_id
)
_try_persist_and_delete_sandbox(chat_id, stream_result.sandbox_files)
_try_persist_and_delete_sandbox(chat_id, stream_result.sandbox_files)
with contextlib.suppress(Exception):
session.expunge_all()
with contextlib.suppress(Exception):
await session.close()
agent = llm = connector_service = sandbox_backend = None
stream_result = None
session = None
collected = gc.collect(0) + gc.collect(1) + gc.collect(2)
if collected: