diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py index 12bdea455..4ba12c171 100644 --- a/surfsense_backend/app/tasks/chat/stream_new_chat.py +++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py @@ -15,6 +15,8 @@ from dataclasses import dataclass from typing import Any from uuid import UUID +import logging + from langchain_core.messages import HumanMessage from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select @@ -28,7 +30,7 @@ from app.agents.new_chat.llm_config import ( load_agent_config, load_llm_config_from_yaml, ) -from app.db import ChatVisibility, Document, Report, SurfsenseDocsDocument +from app.db import ChatVisibility, Document, Report, SurfsenseDocsDocument, async_session_maker from app.prompts import TITLE_GENERATION_PROMPT_TEMPLATE from app.services.chat_session_state_service import ( clear_ai_responding, @@ -1276,8 +1278,21 @@ async def stream_new_chat( yield streaming_service.format_done() finally: - # Clear AI responding state for live collaboration - await clear_ai_responding(session, chat_id) + # Clear AI responding state for live collaboration. + # The original session may be broken (client disconnect / CancelledError + # can corrupt the underlying DB connection), so we try a rollback first + # and fall back to a fresh session if the original is unusable. + try: + await session.rollback() + await clear_ai_responding(session, chat_id) + except Exception: + try: + async with async_session_maker() as fresh_session: + await clear_ai_responding(fresh_session, chat_id) + except Exception: + logging.getLogger(__name__).warning( + "Failed to clear AI responding state for thread %s", chat_id + ) async def stream_resume_chat( @@ -1397,4 +1412,14 @@ async def stream_resume_chat( yield streaming_service.format_done() finally: - await clear_ai_responding(session, chat_id) + try: + await session.rollback() + await clear_ai_responding(session, chat_id) + except Exception: + try: + async with async_session_maker() as fresh_session: + await clear_ai_responding(fresh_session, chat_id) + except Exception: + logging.getLogger(__name__).warning( + "Failed to clear AI responding state for thread %s", chat_id + )