From 3ec30d94cefeced76c9665d24c14ac7fbbc8a2d6 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Sat, 14 Feb 2026 18:48:36 +0530 Subject: [PATCH] refactor: optimize report generation tool to use short-lived database sessions for improved performance and connection management --- .../app/agents/new_chat/tools/registry.py | 5 +- .../app/agents/new_chat/tools/report.py | 164 ++++++++++-------- 2 files changed, 93 insertions(+), 76 deletions(-) diff --git a/surfsense_backend/app/agents/new_chat/tools/registry.py b/surfsense_backend/app/agents/new_chat/tools/registry.py index 275b674ec..3f9783b86 100644 --- a/surfsense_backend/app/agents/new_chat/tools/registry.py +++ b/surfsense_backend/app/agents/new_chat/tools/registry.py @@ -119,16 +119,15 @@ BUILTIN_TOOLS: list[ToolDefinition] = [ ), requires=["search_space_id", "db_session", "thread_id"], ), - # Report generation tool (inline, no Celery) + # Report generation tool (inline, short-lived sessions for DB ops) ToolDefinition( name="generate_report", description="Generate a structured Markdown report from provided content", factory=lambda deps: create_generate_report_tool( search_space_id=deps["search_space_id"], - db_session=deps["db_session"], thread_id=deps["thread_id"], ), - requires=["search_space_id", "db_session", "thread_id"], + requires=["search_space_id", "thread_id"], ), # Link preview tool - fetches Open Graph metadata for URLs ToolDefinition( diff --git a/surfsense_backend/app/agents/new_chat/tools/report.py b/surfsense_backend/app/agents/new_chat/tools/report.py index 85449f8e3..7a186b99b 100644 --- a/surfsense_backend/app/agents/new_chat/tools/report.py +++ b/surfsense_backend/app/agents/new_chat/tools/report.py @@ -6,18 +6,20 @@ that generates a structured Markdown report inline (no Celery). The LLM is called within the tool, the result is saved to the database, and the tool returns immediately with a ready status. -This follows the same inline pattern as generate_image and display_image, -NOT the Celery-based podcast pattern. +Uses short-lived database sessions to avoid holding connections during long +LLM calls (30-120+ seconds). Each DB operation (read config, save report) +opens and closes its own session, ensuring no connection is held idle during +the LLM API call. """ import logging import re from typing import Any +from langchain_core.messages import HumanMessage from langchain_core.tools import tool -from sqlalchemy.ext.asyncio import AsyncSession -from app.db import Report +from app.db import Report, async_session_maker from app.services.llm_service import get_document_summary_llm logger = logging.getLogger(__name__) @@ -97,7 +99,6 @@ def _extract_metadata(content: str) -> dict[str, Any]: def create_generate_report_tool( search_space_id: int, - db_session: AsyncSession, thread_id: int | None = None, ): """ @@ -106,9 +107,11 @@ def create_generate_report_tool( The tool generates a Markdown report inline using the search space's document summary LLM, saves it to the database, and returns immediately. + Uses short-lived database sessions for each DB operation so no connection + is held during the long LLM API call. + Args: search_space_id: The user's search space ID - db_session: Database session for creating the report record thread_id: The chat thread ID for associating the report Returns: @@ -229,50 +232,37 @@ def create_generate_report_tool( - word_count: Number of words in the report - message: Status message (or "error" field if failed) """ - # Resolve the parent report and its group (if versioning) - parent_report: Report | None = None + # Initialize version tracking variables (used by _save_failed_report closure) + parent_report_content: str | None = None report_group_id: int | None = None - if parent_report_id: - parent_report = await db_session.get(Report, parent_report_id) - if parent_report: - report_group_id = parent_report.report_group_id - logger.info( - f"[generate_report] Creating new version from parent {parent_report_id} " - f"(group {report_group_id})" - ) - else: - logger.warning( - f"[generate_report] parent_report_id={parent_report_id} not found, " - "creating standalone report" - ) - async def _save_failed_report(error_msg: str) -> int | None: - """Persist a failed report row so the error is visible later.""" + """Persist a failed report row using a short-lived session.""" try: - failed_report = Report( - title=topic, - content=None, - report_metadata={ - "status": "failed", - "error_message": error_msg, - }, - report_style=report_style, - search_space_id=search_space_id, - thread_id=thread_id, - report_group_id=report_group_id, - ) - db_session.add(failed_report) - await db_session.commit() - await db_session.refresh(failed_report) - # If this is a new group (v1 failed), set group to self - if not failed_report.report_group_id: - failed_report.report_group_id = failed_report.id - await db_session.commit() - logger.info( - f"[generate_report] Saved failed report {failed_report.id}: {error_msg}" - ) - return failed_report.id + async with async_session_maker() as session: + failed_report = Report( + title=topic, + content=None, + report_metadata={ + "status": "failed", + "error_message": error_msg, + }, + report_style=report_style, + search_space_id=search_space_id, + thread_id=thread_id, + report_group_id=report_group_id, + ) + session.add(failed_report) + await session.commit() + await session.refresh(failed_report) + # If this is a new group (v1 failed), set group to self + if not failed_report.report_group_id: + failed_report.report_group_id = failed_report.id + await session.commit() + logger.info( + f"[generate_report] Saved failed report {failed_report.id}: {error_msg}" + ) + return failed_report.id except Exception: logger.exception( "[generate_report] Could not persist failed report row" @@ -280,8 +270,32 @@ def create_generate_report_tool( return None try: - # Get the LLM instance for this search space - llm = await get_document_summary_llm(db_session, search_space_id) + # ── Phase 1: READ (short-lived session) ────────────────────── + # Fetch parent report and LLM config, then close the session + # so no DB connection is held during the long LLM call. + async with async_session_maker() as read_session: + if parent_report_id: + parent_report = await read_session.get( + Report, parent_report_id + ) + if parent_report: + report_group_id = parent_report.report_group_id + parent_report_content = parent_report.content + logger.info( + f"[generate_report] Creating new version from parent {parent_report_id} " + f"(group {report_group_id})" + ) + else: + logger.warning( + f"[generate_report] parent_report_id={parent_report_id} not found, " + "creating standalone report" + ) + + llm = await get_document_summary_llm( + read_session, search_space_id + ) + # read_session closed — connection returned to pool + if not llm: error_msg = ( "No LLM configured. Please configure a language model in Settings." @@ -303,11 +317,11 @@ def create_generate_report_tool( # If revising, include previous version content previous_version_section = "" - if parent_report and parent_report.content: + if parent_report_content: previous_version_section = ( "**Previous Version of This Report (refine this based on the instructions above — " "preserve structure and quality, apply only the requested changes):**\n\n" - f"{parent_report.content}" + f"{parent_report_content}" ) prompt = _REPORT_PROMPT.format( @@ -318,9 +332,7 @@ def create_generate_report_tool( source_content=source_content[:100000], # Cap source content ) - # Call the LLM inline - from langchain_core.messages import HumanMessage - + # ── Phase 2: LLM CALL (no DB connection held) ──────────────── response = await llm.ainvoke([HumanMessage(content=prompt)]) report_content = response.content @@ -351,35 +363,41 @@ def create_generate_report_tool( # Extract metadata (includes "status": "ready") metadata = _extract_metadata(report_content) - # Save to database - report = Report( - title=topic, - content=report_content, - report_metadata=metadata, - report_style=report_style, - search_space_id=search_space_id, - thread_id=thread_id, - report_group_id=report_group_id, # None for v1, inherited for v2+ - ) - db_session.add(report) - await db_session.commit() - await db_session.refresh(report) + # ── Phase 3: WRITE (short-lived session) ───────────────────── + # Save the report to the database, then close the session. + async with async_session_maker() as write_session: + report = Report( + title=topic, + content=report_content, + report_metadata=metadata, + report_style=report_style, + search_space_id=search_space_id, + thread_id=thread_id, + report_group_id=report_group_id, + ) + write_session.add(report) + await write_session.commit() + await write_session.refresh(report) - # If this is a brand-new report (v1), set report_group_id = own id - if not report.report_group_id: - report.report_group_id = report.id - await db_session.commit() + # If this is a brand-new report (v1), set report_group_id = own id + if not report.report_group_id: + report.report_group_id = report.id + await write_session.commit() + + saved_report_id = report.id + saved_group_id = report.report_group_id + # write_session closed — connection returned to pool logger.info( - f"[generate_report] Created report {report.id} " - f"(group={report.report_group_id}): " + f"[generate_report] Created report {saved_report_id} " + f"(group={saved_group_id}): " f"{metadata.get('word_count', 0)} words, " f"{metadata.get('section_count', 0)} sections" ) return { "status": "ready", - "report_id": report.id, + "report_id": saved_report_id, "title": topic, "word_count": metadata.get("word_count", 0), "message": f"Report generated successfully: {topic}",