diff --git a/surfsense_backend/app/routes/new_chat_routes.py b/surfsense_backend/app/routes/new_chat_routes.py index f2c26e449..5e1bc238f 100644 --- a/surfsense_backend/app/routes/new_chat_routes.py +++ b/surfsense_backend/app/routes/new_chat_routes.py @@ -1273,6 +1273,8 @@ async def regenerate_response( .limit(2) ) messages_to_delete = list(last_messages_result.scalars().all()) + + message_ids_to_delete = [msg.id for msg in messages_to_delete] # Get search space for LLM config search_space_result = await session.execute( @@ -1313,9 +1315,6 @@ async def regenerate_response( # This ensures we don't lose data on streaming failures if streaming_completed and messages_to_delete: try: - # Get message IDs before deletion for snapshot cleanup - deleted_message_ids = [msg.id for msg in messages_to_delete] - for msg in messages_to_delete: await session.delete(msg) await session.commit() @@ -1326,7 +1325,7 @@ async def regenerate_response( ) await delete_affected_snapshots( - session, thread_id, deleted_message_ids + session, thread_id, message_ids_to_delete ) except Exception as cleanup_error: # Log but don't fail - the new messages are already streamed diff --git a/surfsense_backend/app/services/public_chat_service.py b/surfsense_backend/app/services/public_chat_service.py index f0d8eb18e..5e8580642 100644 --- a/surfsense_backend/app/services/public_chat_service.py +++ b/surfsense_backend/app/services/public_chat_service.py @@ -426,7 +426,7 @@ async def delete_snapshot( async def delete_affected_snapshots( - session: AsyncSession, + session: AsyncSession, # noqa: ARG001 - kept for API compatibility thread_id: int, message_ids: list[int], ) -> int: @@ -434,25 +434,27 @@ async def delete_affected_snapshots( Delete snapshots that contain any of the given message IDs. Called when messages are edited/deleted/regenerated. - - Returns the number of deleted snapshots. + Uses independent session to work reliably in streaming response cleanup. """ if not message_ids: return 0 - # Use raw SQL for array overlap query - # The && operator checks if arrays have any elements in common - result = await session.execute( - delete(PublicChatSnapshot) - .where(PublicChatSnapshot.thread_id == thread_id) - .where(PublicChatSnapshot.message_ids.overlap(message_ids)) - .returning(PublicChatSnapshot.id) - ) + from sqlalchemy.dialects.postgresql import array - deleted_ids = result.scalars().all() - await session.commit() + from app.db import async_session_maker - return len(deleted_ids) + async with async_session_maker() as independent_session: + result = await independent_session.execute( + delete(PublicChatSnapshot) + .where(PublicChatSnapshot.thread_id == thread_id) + .where(PublicChatSnapshot.message_ids.op("&&")(array(message_ids))) + .returning(PublicChatSnapshot.id) + ) + + deleted_ids = result.scalars().all() + await independent_session.commit() + + return len(deleted_ids) # =============================================================================