mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-04-25 00:36:31 +02:00
Merge pull request #851 from MODSetter/dev
refactor: improve session management and cleanup in chat streaming
This commit is contained in:
commit
d2e06583ca
2 changed files with 91 additions and 48 deletions
|
|
@ -31,6 +31,7 @@ from app.db import (
|
|||
Permission,
|
||||
SearchSpace,
|
||||
User,
|
||||
async_session_maker,
|
||||
get_async_session,
|
||||
)
|
||||
from app.schemas.new_chat import (
|
||||
|
|
@ -1092,13 +1093,18 @@ async def handle_new_chat(
|
|||
# on searchspaces/documents for the entire duration of the stream.
|
||||
# expire_on_commit=False keeps loaded ORM attrs usable.
|
||||
await session.commit()
|
||||
# Close the dependency session now so its connection returns to
|
||||
# the pool before streaming begins. Without this, Starlette's
|
||||
# BaseHTTPMiddleware cancels the scope on client disconnect and
|
||||
# the dependency generator's __aexit__ never runs, orphaning the
|
||||
# connection (the "Exception terminating connection" errors).
|
||||
await session.close()
|
||||
|
||||
return StreamingResponse(
|
||||
stream_new_chat(
|
||||
user_query=request.user_query,
|
||||
search_space_id=request.search_space_id,
|
||||
chat_id=request.chat_id,
|
||||
session=session,
|
||||
user_id=str(user.id),
|
||||
llm_config_id=llm_config_id,
|
||||
mentioned_document_ids=request.mentioned_document_ids,
|
||||
|
|
@ -1323,6 +1329,7 @@ async def regenerate_response(
|
|||
# on searchspaces/documents for the entire duration of the stream.
|
||||
# expire_on_commit=False keeps loaded ORM attrs (including messages_to_delete PKs) usable.
|
||||
await session.commit()
|
||||
await session.close()
|
||||
|
||||
# Create a wrapper generator that deletes messages only AFTER streaming succeeds
|
||||
# This prevents data loss if streaming fails (network error, LLM error, etc.)
|
||||
|
|
@ -1333,7 +1340,6 @@ async def regenerate_response(
|
|||
user_query=user_query_to_use,
|
||||
search_space_id=request.search_space_id,
|
||||
chat_id=thread_id,
|
||||
session=session,
|
||||
user_id=str(user.id),
|
||||
llm_config_id=llm_config_id,
|
||||
mentioned_document_ids=request.mentioned_document_ids,
|
||||
|
|
@ -1344,29 +1350,35 @@ async def regenerate_response(
|
|||
current_user_display_name=user.display_name or "A team member",
|
||||
):
|
||||
yield chunk
|
||||
# If we get here, streaming completed successfully
|
||||
streaming_completed = True
|
||||
finally:
|
||||
# Only delete old messages if streaming completed successfully
|
||||
# This ensures we don't lose data on streaming failures
|
||||
if streaming_completed and messages_to_delete:
|
||||
# Only delete old messages if streaming completed successfully.
|
||||
# Uses a fresh session since stream_new_chat manages its own.
|
||||
if streaming_completed and message_ids_to_delete:
|
||||
try:
|
||||
for msg in messages_to_delete:
|
||||
await session.delete(msg)
|
||||
await session.commit()
|
||||
async with async_session_maker() as cleanup_session:
|
||||
for msg_id in message_ids_to_delete:
|
||||
_res = await cleanup_session.execute(
|
||||
select(NewChatMessage).filter(
|
||||
NewChatMessage.id == msg_id
|
||||
)
|
||||
)
|
||||
_msg = _res.scalars().first()
|
||||
if _msg:
|
||||
await cleanup_session.delete(_msg)
|
||||
await cleanup_session.commit()
|
||||
|
||||
# Delete any public snapshots that contain the modified messages
|
||||
from app.services.public_chat_service import (
|
||||
delete_affected_snapshots,
|
||||
)
|
||||
from app.services.public_chat_service import (
|
||||
delete_affected_snapshots,
|
||||
)
|
||||
|
||||
await delete_affected_snapshots(
|
||||
session, thread_id, message_ids_to_delete
|
||||
)
|
||||
await delete_affected_snapshots(
|
||||
cleanup_session, thread_id, message_ids_to_delete
|
||||
)
|
||||
except Exception as cleanup_error:
|
||||
# Log but don't fail - the new messages are already streamed
|
||||
print(
|
||||
f"[regenerate] Warning: Failed to delete old messages: {cleanup_error}"
|
||||
_logger.warning(
|
||||
"[regenerate] Failed to delete old messages: %s",
|
||||
cleanup_error,
|
||||
)
|
||||
|
||||
# Return streaming response with checkpoint_id for rewinding
|
||||
|
|
@ -1440,13 +1452,13 @@ async def resume_chat(
|
|||
# Release the read-transaction so we don't hold ACCESS SHARE locks
|
||||
# on searchspaces/documents for the entire duration of the stream.
|
||||
await session.commit()
|
||||
await session.close()
|
||||
|
||||
return StreamingResponse(
|
||||
stream_resume_chat(
|
||||
chat_id=thread_id,
|
||||
search_space_id=request.search_space_id,
|
||||
decisions=decisions,
|
||||
session=session,
|
||||
user_id=str(user.id),
|
||||
llm_config_id=llm_config_id,
|
||||
thread_visibility=thread.visibility,
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ Supports loading LLM configurations from:
|
|||
"""
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import gc
|
||||
import json
|
||||
import logging
|
||||
|
|
@ -20,9 +21,9 @@ from dataclasses import dataclass, field
|
|||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
import anyio
|
||||
from langchain_core.messages import HumanMessage
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
|
|
@ -1012,7 +1013,6 @@ async def stream_new_chat(
|
|||
user_query: str,
|
||||
search_space_id: int,
|
||||
chat_id: int,
|
||||
session: AsyncSession,
|
||||
user_id: str | None = None,
|
||||
llm_config_id: int = -1,
|
||||
mentioned_document_ids: list[int] | None = None,
|
||||
|
|
@ -1028,11 +1028,13 @@ async def stream_new_chat(
|
|||
This uses the Vercel AI SDK Data Stream Protocol (SSE format) for streaming.
|
||||
The chat_id is used as LangGraph's thread_id for memory/checkpointing.
|
||||
|
||||
The function creates and manages its own database session to guarantee proper
|
||||
cleanup even when Starlette's middleware cancels the task on client disconnect.
|
||||
|
||||
Args:
|
||||
user_query: The user's query
|
||||
search_space_id: The search space ID
|
||||
chat_id: The chat ID (used as LangGraph thread_id for memory)
|
||||
session: The database session
|
||||
user_id: The current user's UUID string (for memory tools and session state)
|
||||
llm_config_id: The LLM configuration ID (default: -1 for first global config)
|
||||
needs_history_bootstrap: If True, load message history from DB (for cloned chats)
|
||||
|
|
@ -1048,6 +1050,7 @@ async def stream_new_chat(
|
|||
_t_total = time.perf_counter()
|
||||
log_system_snapshot("stream_new_chat_START")
|
||||
|
||||
session = async_session_maker()
|
||||
try:
|
||||
# Mark AI as responding to this user for live collaboration
|
||||
if user_id:
|
||||
|
|
@ -1283,6 +1286,12 @@ async def stream_new_chat(
|
|||
# short-lived transactions (or use isolated sessions).
|
||||
await session.commit()
|
||||
|
||||
# Detach heavy ORM objects (documents with chunks, reports, etc.)
|
||||
# from the session identity map now that we've extracted the data
|
||||
# we need. This prevents them from accumulating in memory for the
|
||||
# entire duration of LLM streaming (which can be several minutes).
|
||||
session.expunge_all()
|
||||
|
||||
_perf_log.info(
|
||||
"[stream_new_chat] Total pre-stream setup in %.3fs (chat_id=%s)",
|
||||
time.perf_counter() - _t_total,
|
||||
|
|
@ -1459,23 +1468,35 @@ async def stream_new_chat(
|
|||
yield streaming_service.format_done()
|
||||
|
||||
finally:
|
||||
# Clear AI responding state for live collaboration.
|
||||
# The original session may be broken (client disconnect / CancelledError
|
||||
# can corrupt the underlying DB connection), so we try a rollback first
|
||||
# and fall back to a fresh session if the original is unusable.
|
||||
try:
|
||||
await session.rollback()
|
||||
await clear_ai_responding(session, chat_id)
|
||||
except Exception:
|
||||
# Shield the ENTIRE async cleanup from anyio cancel-scope
|
||||
# cancellation. Starlette's BaseHTTPMiddleware uses anyio task
|
||||
# groups; on client disconnect, it cancels the scope with
|
||||
# level-triggered cancellation — every unshielded `await` inside
|
||||
# the cancelled scope raises CancelledError immediately. Without
|
||||
# this shield the very first `await` (session.rollback) would
|
||||
# raise CancelledError, `except Exception` wouldn't catch it
|
||||
# (CancelledError is a BaseException), and the rest of the
|
||||
# finally block — including session.close() — would never run.
|
||||
with anyio.CancelScope(shield=True):
|
||||
try:
|
||||
async with async_session_maker() as fresh_session:
|
||||
await clear_ai_responding(fresh_session, chat_id)
|
||||
await session.rollback()
|
||||
await clear_ai_responding(session, chat_id)
|
||||
except Exception:
|
||||
logging.getLogger(__name__).warning(
|
||||
"Failed to clear AI responding state for thread %s", chat_id
|
||||
)
|
||||
try:
|
||||
async with async_session_maker() as fresh_session:
|
||||
await clear_ai_responding(fresh_session, chat_id)
|
||||
except Exception:
|
||||
logging.getLogger(__name__).warning(
|
||||
"Failed to clear AI responding state for thread %s", chat_id
|
||||
)
|
||||
|
||||
_try_persist_and_delete_sandbox(chat_id, stream_result.sandbox_files)
|
||||
_try_persist_and_delete_sandbox(chat_id, stream_result.sandbox_files)
|
||||
|
||||
with contextlib.suppress(Exception):
|
||||
session.expunge_all()
|
||||
|
||||
with contextlib.suppress(Exception):
|
||||
await session.close()
|
||||
|
||||
# Break circular refs held by the agent graph, tools, and LLM
|
||||
# wrappers so the GC can reclaim them in a single pass.
|
||||
|
|
@ -1483,6 +1504,7 @@ async def stream_new_chat(
|
|||
mentioned_documents = mentioned_surfsense_docs = None
|
||||
recent_reports = langchain_messages = input_state = None
|
||||
stream_result = None
|
||||
session = None
|
||||
|
||||
collected = gc.collect(0) + gc.collect(1) + gc.collect(2)
|
||||
if collected:
|
||||
|
|
@ -1498,7 +1520,6 @@ async def stream_resume_chat(
|
|||
chat_id: int,
|
||||
search_space_id: int,
|
||||
decisions: list[dict],
|
||||
session: AsyncSession,
|
||||
user_id: str | None = None,
|
||||
llm_config_id: int = -1,
|
||||
thread_visibility: ChatVisibility | None = None,
|
||||
|
|
@ -1507,6 +1528,7 @@ async def stream_resume_chat(
|
|||
stream_result = StreamResult()
|
||||
_t_total = time.perf_counter()
|
||||
|
||||
session = async_session_maker()
|
||||
try:
|
||||
if user_id:
|
||||
await set_ai_responding(session, chat_id, UUID(user_id))
|
||||
|
|
@ -1603,6 +1625,7 @@ async def stream_resume_chat(
|
|||
|
||||
# Release the transaction before streaming (same rationale as stream_new_chat).
|
||||
await session.commit()
|
||||
session.expunge_all()
|
||||
|
||||
_perf_log.info(
|
||||
"[stream_resume] Total pre-stream setup in %.3fs (chat_id=%s)",
|
||||
|
|
@ -1666,22 +1689,30 @@ async def stream_resume_chat(
|
|||
yield streaming_service.format_done()
|
||||
|
||||
finally:
|
||||
try:
|
||||
await session.rollback()
|
||||
await clear_ai_responding(session, chat_id)
|
||||
except Exception:
|
||||
with anyio.CancelScope(shield=True):
|
||||
try:
|
||||
async with async_session_maker() as fresh_session:
|
||||
await clear_ai_responding(fresh_session, chat_id)
|
||||
await session.rollback()
|
||||
await clear_ai_responding(session, chat_id)
|
||||
except Exception:
|
||||
logging.getLogger(__name__).warning(
|
||||
"Failed to clear AI responding state for thread %s", chat_id
|
||||
)
|
||||
try:
|
||||
async with async_session_maker() as fresh_session:
|
||||
await clear_ai_responding(fresh_session, chat_id)
|
||||
except Exception:
|
||||
logging.getLogger(__name__).warning(
|
||||
"Failed to clear AI responding state for thread %s", chat_id
|
||||
)
|
||||
|
||||
_try_persist_and_delete_sandbox(chat_id, stream_result.sandbox_files)
|
||||
_try_persist_and_delete_sandbox(chat_id, stream_result.sandbox_files)
|
||||
|
||||
with contextlib.suppress(Exception):
|
||||
session.expunge_all()
|
||||
|
||||
with contextlib.suppress(Exception):
|
||||
await session.close()
|
||||
|
||||
agent = llm = connector_service = sandbox_backend = None
|
||||
stream_result = None
|
||||
session = None
|
||||
|
||||
collected = gc.collect(0) + gc.collect(1) + gc.collect(2)
|
||||
if collected:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue