mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-26 21:39:43 +02:00
refactor: improve session management and cleanup in chat streaming
- Added proper session closure to prevent connection leaks during streaming. - Implemented a fresh session for cleanup tasks to ensure data integrity. - Enhanced error handling during session operations to improve robustness. - Removed unnecessary session parameters from function signatures for clarity.
This commit is contained in:
parent
cbf9bc6bc9
commit
dd3da2bc36
2 changed files with 91 additions and 48 deletions
|
|
@ -31,6 +31,7 @@ from app.db import (
|
||||||
Permission,
|
Permission,
|
||||||
SearchSpace,
|
SearchSpace,
|
||||||
User,
|
User,
|
||||||
|
async_session_maker,
|
||||||
get_async_session,
|
get_async_session,
|
||||||
)
|
)
|
||||||
from app.schemas.new_chat import (
|
from app.schemas.new_chat import (
|
||||||
|
|
@ -1092,13 +1093,18 @@ async def handle_new_chat(
|
||||||
# on searchspaces/documents for the entire duration of the stream.
|
# on searchspaces/documents for the entire duration of the stream.
|
||||||
# expire_on_commit=False keeps loaded ORM attrs usable.
|
# expire_on_commit=False keeps loaded ORM attrs usable.
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
# Close the dependency session now so its connection returns to
|
||||||
|
# the pool before streaming begins. Without this, Starlette's
|
||||||
|
# BaseHTTPMiddleware cancels the scope on client disconnect and
|
||||||
|
# the dependency generator's __aexit__ never runs, orphaning the
|
||||||
|
# connection (the "Exception terminating connection" errors).
|
||||||
|
await session.close()
|
||||||
|
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
stream_new_chat(
|
stream_new_chat(
|
||||||
user_query=request.user_query,
|
user_query=request.user_query,
|
||||||
search_space_id=request.search_space_id,
|
search_space_id=request.search_space_id,
|
||||||
chat_id=request.chat_id,
|
chat_id=request.chat_id,
|
||||||
session=session,
|
|
||||||
user_id=str(user.id),
|
user_id=str(user.id),
|
||||||
llm_config_id=llm_config_id,
|
llm_config_id=llm_config_id,
|
||||||
mentioned_document_ids=request.mentioned_document_ids,
|
mentioned_document_ids=request.mentioned_document_ids,
|
||||||
|
|
@ -1323,6 +1329,7 @@ async def regenerate_response(
|
||||||
# on searchspaces/documents for the entire duration of the stream.
|
# on searchspaces/documents for the entire duration of the stream.
|
||||||
# expire_on_commit=False keeps loaded ORM attrs (including messages_to_delete PKs) usable.
|
# expire_on_commit=False keeps loaded ORM attrs (including messages_to_delete PKs) usable.
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
await session.close()
|
||||||
|
|
||||||
# Create a wrapper generator that deletes messages only AFTER streaming succeeds
|
# Create a wrapper generator that deletes messages only AFTER streaming succeeds
|
||||||
# This prevents data loss if streaming fails (network error, LLM error, etc.)
|
# This prevents data loss if streaming fails (network error, LLM error, etc.)
|
||||||
|
|
@ -1333,7 +1340,6 @@ async def regenerate_response(
|
||||||
user_query=user_query_to_use,
|
user_query=user_query_to_use,
|
||||||
search_space_id=request.search_space_id,
|
search_space_id=request.search_space_id,
|
||||||
chat_id=thread_id,
|
chat_id=thread_id,
|
||||||
session=session,
|
|
||||||
user_id=str(user.id),
|
user_id=str(user.id),
|
||||||
llm_config_id=llm_config_id,
|
llm_config_id=llm_config_id,
|
||||||
mentioned_document_ids=request.mentioned_document_ids,
|
mentioned_document_ids=request.mentioned_document_ids,
|
||||||
|
|
@ -1344,29 +1350,35 @@ async def regenerate_response(
|
||||||
current_user_display_name=user.display_name or "A team member",
|
current_user_display_name=user.display_name or "A team member",
|
||||||
):
|
):
|
||||||
yield chunk
|
yield chunk
|
||||||
# If we get here, streaming completed successfully
|
|
||||||
streaming_completed = True
|
streaming_completed = True
|
||||||
finally:
|
finally:
|
||||||
# Only delete old messages if streaming completed successfully
|
# Only delete old messages if streaming completed successfully.
|
||||||
# This ensures we don't lose data on streaming failures
|
# Uses a fresh session since stream_new_chat manages its own.
|
||||||
if streaming_completed and messages_to_delete:
|
if streaming_completed and message_ids_to_delete:
|
||||||
try:
|
try:
|
||||||
for msg in messages_to_delete:
|
async with async_session_maker() as cleanup_session:
|
||||||
await session.delete(msg)
|
for msg_id in message_ids_to_delete:
|
||||||
await session.commit()
|
_res = await cleanup_session.execute(
|
||||||
|
select(NewChatMessage).filter(
|
||||||
|
NewChatMessage.id == msg_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
_msg = _res.scalars().first()
|
||||||
|
if _msg:
|
||||||
|
await cleanup_session.delete(_msg)
|
||||||
|
await cleanup_session.commit()
|
||||||
|
|
||||||
# Delete any public snapshots that contain the modified messages
|
from app.services.public_chat_service import (
|
||||||
from app.services.public_chat_service import (
|
delete_affected_snapshots,
|
||||||
delete_affected_snapshots,
|
)
|
||||||
)
|
|
||||||
|
|
||||||
await delete_affected_snapshots(
|
await delete_affected_snapshots(
|
||||||
session, thread_id, message_ids_to_delete
|
cleanup_session, thread_id, message_ids_to_delete
|
||||||
)
|
)
|
||||||
except Exception as cleanup_error:
|
except Exception as cleanup_error:
|
||||||
# Log but don't fail - the new messages are already streamed
|
_logger.warning(
|
||||||
print(
|
"[regenerate] Failed to delete old messages: %s",
|
||||||
f"[regenerate] Warning: Failed to delete old messages: {cleanup_error}"
|
cleanup_error,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Return streaming response with checkpoint_id for rewinding
|
# Return streaming response with checkpoint_id for rewinding
|
||||||
|
|
@ -1440,13 +1452,13 @@ async def resume_chat(
|
||||||
# Release the read-transaction so we don't hold ACCESS SHARE locks
|
# Release the read-transaction so we don't hold ACCESS SHARE locks
|
||||||
# on searchspaces/documents for the entire duration of the stream.
|
# on searchspaces/documents for the entire duration of the stream.
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
await session.close()
|
||||||
|
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
stream_resume_chat(
|
stream_resume_chat(
|
||||||
chat_id=thread_id,
|
chat_id=thread_id,
|
||||||
search_space_id=request.search_space_id,
|
search_space_id=request.search_space_id,
|
||||||
decisions=decisions,
|
decisions=decisions,
|
||||||
session=session,
|
|
||||||
user_id=str(user.id),
|
user_id=str(user.id),
|
||||||
llm_config_id=llm_config_id,
|
llm_config_id=llm_config_id,
|
||||||
thread_visibility=thread.visibility,
|
thread_visibility=thread.visibility,
|
||||||
|
|
|
||||||
|
|
@ -10,6 +10,7 @@ Supports loading LLM configurations from:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import contextlib
|
||||||
import gc
|
import gc
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
|
@ -20,9 +21,9 @@ from dataclasses import dataclass, field
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
|
import anyio
|
||||||
from langchain_core.messages import HumanMessage
|
from langchain_core.messages import HumanMessage
|
||||||
from sqlalchemy import func
|
from sqlalchemy import func
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
from sqlalchemy.future import select
|
from sqlalchemy.future import select
|
||||||
from sqlalchemy.orm import selectinload
|
from sqlalchemy.orm import selectinload
|
||||||
|
|
||||||
|
|
@ -1012,7 +1013,6 @@ async def stream_new_chat(
|
||||||
user_query: str,
|
user_query: str,
|
||||||
search_space_id: int,
|
search_space_id: int,
|
||||||
chat_id: int,
|
chat_id: int,
|
||||||
session: AsyncSession,
|
|
||||||
user_id: str | None = None,
|
user_id: str | None = None,
|
||||||
llm_config_id: int = -1,
|
llm_config_id: int = -1,
|
||||||
mentioned_document_ids: list[int] | None = None,
|
mentioned_document_ids: list[int] | None = None,
|
||||||
|
|
@ -1028,11 +1028,13 @@ async def stream_new_chat(
|
||||||
This uses the Vercel AI SDK Data Stream Protocol (SSE format) for streaming.
|
This uses the Vercel AI SDK Data Stream Protocol (SSE format) for streaming.
|
||||||
The chat_id is used as LangGraph's thread_id for memory/checkpointing.
|
The chat_id is used as LangGraph's thread_id for memory/checkpointing.
|
||||||
|
|
||||||
|
The function creates and manages its own database session to guarantee proper
|
||||||
|
cleanup even when Starlette's middleware cancels the task on client disconnect.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_query: The user's query
|
user_query: The user's query
|
||||||
search_space_id: The search space ID
|
search_space_id: The search space ID
|
||||||
chat_id: The chat ID (used as LangGraph thread_id for memory)
|
chat_id: The chat ID (used as LangGraph thread_id for memory)
|
||||||
session: The database session
|
|
||||||
user_id: The current user's UUID string (for memory tools and session state)
|
user_id: The current user's UUID string (for memory tools and session state)
|
||||||
llm_config_id: The LLM configuration ID (default: -1 for first global config)
|
llm_config_id: The LLM configuration ID (default: -1 for first global config)
|
||||||
needs_history_bootstrap: If True, load message history from DB (for cloned chats)
|
needs_history_bootstrap: If True, load message history from DB (for cloned chats)
|
||||||
|
|
@ -1048,6 +1050,7 @@ async def stream_new_chat(
|
||||||
_t_total = time.perf_counter()
|
_t_total = time.perf_counter()
|
||||||
log_system_snapshot("stream_new_chat_START")
|
log_system_snapshot("stream_new_chat_START")
|
||||||
|
|
||||||
|
session = async_session_maker()
|
||||||
try:
|
try:
|
||||||
# Mark AI as responding to this user for live collaboration
|
# Mark AI as responding to this user for live collaboration
|
||||||
if user_id:
|
if user_id:
|
||||||
|
|
@ -1283,6 +1286,12 @@ async def stream_new_chat(
|
||||||
# short-lived transactions (or use isolated sessions).
|
# short-lived transactions (or use isolated sessions).
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
|
# Detach heavy ORM objects (documents with chunks, reports, etc.)
|
||||||
|
# from the session identity map now that we've extracted the data
|
||||||
|
# we need. This prevents them from accumulating in memory for the
|
||||||
|
# entire duration of LLM streaming (which can be several minutes).
|
||||||
|
session.expunge_all()
|
||||||
|
|
||||||
_perf_log.info(
|
_perf_log.info(
|
||||||
"[stream_new_chat] Total pre-stream setup in %.3fs (chat_id=%s)",
|
"[stream_new_chat] Total pre-stream setup in %.3fs (chat_id=%s)",
|
||||||
time.perf_counter() - _t_total,
|
time.perf_counter() - _t_total,
|
||||||
|
|
@ -1459,23 +1468,35 @@ async def stream_new_chat(
|
||||||
yield streaming_service.format_done()
|
yield streaming_service.format_done()
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
# Clear AI responding state for live collaboration.
|
# Shield the ENTIRE async cleanup from anyio cancel-scope
|
||||||
# The original session may be broken (client disconnect / CancelledError
|
# cancellation. Starlette's BaseHTTPMiddleware uses anyio task
|
||||||
# can corrupt the underlying DB connection), so we try a rollback first
|
# groups; on client disconnect, it cancels the scope with
|
||||||
# and fall back to a fresh session if the original is unusable.
|
# level-triggered cancellation — every unshielded `await` inside
|
||||||
try:
|
# the cancelled scope raises CancelledError immediately. Without
|
||||||
await session.rollback()
|
# this shield the very first `await` (session.rollback) would
|
||||||
await clear_ai_responding(session, chat_id)
|
# raise CancelledError, `except Exception` wouldn't catch it
|
||||||
except Exception:
|
# (CancelledError is a BaseException), and the rest of the
|
||||||
|
# finally block — including session.close() — would never run.
|
||||||
|
with anyio.CancelScope(shield=True):
|
||||||
try:
|
try:
|
||||||
async with async_session_maker() as fresh_session:
|
await session.rollback()
|
||||||
await clear_ai_responding(fresh_session, chat_id)
|
await clear_ai_responding(session, chat_id)
|
||||||
except Exception:
|
except Exception:
|
||||||
logging.getLogger(__name__).warning(
|
try:
|
||||||
"Failed to clear AI responding state for thread %s", chat_id
|
async with async_session_maker() as fresh_session:
|
||||||
)
|
await clear_ai_responding(fresh_session, chat_id)
|
||||||
|
except Exception:
|
||||||
|
logging.getLogger(__name__).warning(
|
||||||
|
"Failed to clear AI responding state for thread %s", chat_id
|
||||||
|
)
|
||||||
|
|
||||||
_try_persist_and_delete_sandbox(chat_id, stream_result.sandbox_files)
|
_try_persist_and_delete_sandbox(chat_id, stream_result.sandbox_files)
|
||||||
|
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
|
session.expunge_all()
|
||||||
|
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
|
await session.close()
|
||||||
|
|
||||||
# Break circular refs held by the agent graph, tools, and LLM
|
# Break circular refs held by the agent graph, tools, and LLM
|
||||||
# wrappers so the GC can reclaim them in a single pass.
|
# wrappers so the GC can reclaim them in a single pass.
|
||||||
|
|
@ -1483,6 +1504,7 @@ async def stream_new_chat(
|
||||||
mentioned_documents = mentioned_surfsense_docs = None
|
mentioned_documents = mentioned_surfsense_docs = None
|
||||||
recent_reports = langchain_messages = input_state = None
|
recent_reports = langchain_messages = input_state = None
|
||||||
stream_result = None
|
stream_result = None
|
||||||
|
session = None
|
||||||
|
|
||||||
collected = gc.collect(0) + gc.collect(1) + gc.collect(2)
|
collected = gc.collect(0) + gc.collect(1) + gc.collect(2)
|
||||||
if collected:
|
if collected:
|
||||||
|
|
@ -1498,7 +1520,6 @@ async def stream_resume_chat(
|
||||||
chat_id: int,
|
chat_id: int,
|
||||||
search_space_id: int,
|
search_space_id: int,
|
||||||
decisions: list[dict],
|
decisions: list[dict],
|
||||||
session: AsyncSession,
|
|
||||||
user_id: str | None = None,
|
user_id: str | None = None,
|
||||||
llm_config_id: int = -1,
|
llm_config_id: int = -1,
|
||||||
thread_visibility: ChatVisibility | None = None,
|
thread_visibility: ChatVisibility | None = None,
|
||||||
|
|
@ -1507,6 +1528,7 @@ async def stream_resume_chat(
|
||||||
stream_result = StreamResult()
|
stream_result = StreamResult()
|
||||||
_t_total = time.perf_counter()
|
_t_total = time.perf_counter()
|
||||||
|
|
||||||
|
session = async_session_maker()
|
||||||
try:
|
try:
|
||||||
if user_id:
|
if user_id:
|
||||||
await set_ai_responding(session, chat_id, UUID(user_id))
|
await set_ai_responding(session, chat_id, UUID(user_id))
|
||||||
|
|
@ -1603,6 +1625,7 @@ async def stream_resume_chat(
|
||||||
|
|
||||||
# Release the transaction before streaming (same rationale as stream_new_chat).
|
# Release the transaction before streaming (same rationale as stream_new_chat).
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
session.expunge_all()
|
||||||
|
|
||||||
_perf_log.info(
|
_perf_log.info(
|
||||||
"[stream_resume] Total pre-stream setup in %.3fs (chat_id=%s)",
|
"[stream_resume] Total pre-stream setup in %.3fs (chat_id=%s)",
|
||||||
|
|
@ -1666,22 +1689,30 @@ async def stream_resume_chat(
|
||||||
yield streaming_service.format_done()
|
yield streaming_service.format_done()
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
try:
|
with anyio.CancelScope(shield=True):
|
||||||
await session.rollback()
|
|
||||||
await clear_ai_responding(session, chat_id)
|
|
||||||
except Exception:
|
|
||||||
try:
|
try:
|
||||||
async with async_session_maker() as fresh_session:
|
await session.rollback()
|
||||||
await clear_ai_responding(fresh_session, chat_id)
|
await clear_ai_responding(session, chat_id)
|
||||||
except Exception:
|
except Exception:
|
||||||
logging.getLogger(__name__).warning(
|
try:
|
||||||
"Failed to clear AI responding state for thread %s", chat_id
|
async with async_session_maker() as fresh_session:
|
||||||
)
|
await clear_ai_responding(fresh_session, chat_id)
|
||||||
|
except Exception:
|
||||||
|
logging.getLogger(__name__).warning(
|
||||||
|
"Failed to clear AI responding state for thread %s", chat_id
|
||||||
|
)
|
||||||
|
|
||||||
_try_persist_and_delete_sandbox(chat_id, stream_result.sandbox_files)
|
_try_persist_and_delete_sandbox(chat_id, stream_result.sandbox_files)
|
||||||
|
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
|
session.expunge_all()
|
||||||
|
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
|
await session.close()
|
||||||
|
|
||||||
agent = llm = connector_service = sandbox_backend = None
|
agent = llm = connector_service = sandbox_backend = None
|
||||||
stream_result = None
|
stream_result = None
|
||||||
|
session = None
|
||||||
|
|
||||||
collected = gc.collect(0) + gc.collect(1) + gc.collect(2)
|
collected = gc.collect(0) + gc.collect(1) + gc.collect(2)
|
||||||
if collected:
|
if collected:
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue