refactor: optimize report generation tool to use short-lived database sessions for improved performance and connection management

This commit is contained in:
Anish Sarkar 2026-02-14 18:48:36 +05:30
parent e1124d170d
commit 3ec30d94ce
2 changed files with 93 additions and 76 deletions

View file

@ -119,16 +119,15 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
),
requires=["search_space_id", "db_session", "thread_id"],
),
# Report generation tool (inline, no Celery)
# Report generation tool (inline, short-lived sessions for DB ops)
ToolDefinition(
name="generate_report",
description="Generate a structured Markdown report from provided content",
factory=lambda deps: create_generate_report_tool(
search_space_id=deps["search_space_id"],
db_session=deps["db_session"],
thread_id=deps["thread_id"],
),
requires=["search_space_id", "db_session", "thread_id"],
requires=["search_space_id", "thread_id"],
),
# Link preview tool - fetches Open Graph metadata for URLs
ToolDefinition(

View file

@ -6,18 +6,20 @@ that generates a structured Markdown report inline (no Celery). The LLM is
called within the tool, the result is saved to the database, and the tool
returns immediately with a ready status.
This follows the same inline pattern as generate_image and display_image,
NOT the Celery-based podcast pattern.
Uses short-lived database sessions to avoid holding connections during long
LLM calls (30-120+ seconds). Each DB operation (reading the parent report and
LLM config, saving the report) opens and closes its own session, so no
connection is held idle during the LLM API call.
"""
import logging
import re
from typing import Any
from langchain_core.messages import HumanMessage
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import Report
from app.db import Report, async_session_maker
from app.services.llm_service import get_document_summary_llm
logger = logging.getLogger(__name__)
@ -97,7 +99,6 @@ def _extract_metadata(content: str) -> dict[str, Any]:
def create_generate_report_tool(
search_space_id: int,
db_session: AsyncSession,
thread_id: int | None = None,
):
"""
@ -106,9 +107,11 @@ def create_generate_report_tool(
The tool generates a Markdown report inline using the search space's
document summary LLM, saves it to the database, and returns immediately.
Uses short-lived database sessions for each DB operation so no connection
is held during the long LLM API call.
Args:
search_space_id: The user's search space ID
db_session: Database session for creating the report record
thread_id: The chat thread ID for associating the report
Returns:
@ -229,50 +232,37 @@ def create_generate_report_tool(
- word_count: Number of words in the report
- message: Status message (or "error" field if failed)
"""
# Resolve the parent report and its group (if versioning)
parent_report: Report | None = None
# Initialize version-tracking state: report_group_id is read by the _save_failed_report closure; parent_report_content feeds the revision prompt
parent_report_content: str | None = None
report_group_id: int | None = None
if parent_report_id:
parent_report = await db_session.get(Report, parent_report_id)
if parent_report:
report_group_id = parent_report.report_group_id
logger.info(
f"[generate_report] Creating new version from parent {parent_report_id} "
f"(group {report_group_id})"
)
else:
logger.warning(
f"[generate_report] parent_report_id={parent_report_id} not found, "
"creating standalone report"
)
async def _save_failed_report(error_msg: str) -> int | None:
"""Persist a failed report row so the error is visible later."""
"""Persist a failed report row using a short-lived session."""
try:
failed_report = Report(
title=topic,
content=None,
report_metadata={
"status": "failed",
"error_message": error_msg,
},
report_style=report_style,
search_space_id=search_space_id,
thread_id=thread_id,
report_group_id=report_group_id,
)
db_session.add(failed_report)
await db_session.commit()
await db_session.refresh(failed_report)
# If this is a new group (v1 failed), set group to self
if not failed_report.report_group_id:
failed_report.report_group_id = failed_report.id
await db_session.commit()
logger.info(
f"[generate_report] Saved failed report {failed_report.id}: {error_msg}"
)
return failed_report.id
async with async_session_maker() as session:
failed_report = Report(
title=topic,
content=None,
report_metadata={
"status": "failed",
"error_message": error_msg,
},
report_style=report_style,
search_space_id=search_space_id,
thread_id=thread_id,
report_group_id=report_group_id,
)
session.add(failed_report)
await session.commit()
await session.refresh(failed_report)
# If this is a new group (v1 failed), set group to self
if not failed_report.report_group_id:
failed_report.report_group_id = failed_report.id
await session.commit()
logger.info(
f"[generate_report] Saved failed report {failed_report.id}: {error_msg}"
)
return failed_report.id
except Exception:
logger.exception(
"[generate_report] Could not persist failed report row"
@ -280,8 +270,32 @@ def create_generate_report_tool(
return None
try:
# Get the LLM instance for this search space
llm = await get_document_summary_llm(db_session, search_space_id)
# ── Phase 1: READ (short-lived session) ──────────────────────
# Fetch parent report and LLM config, then close the session
# so no DB connection is held during the long LLM call.
async with async_session_maker() as read_session:
if parent_report_id:
parent_report = await read_session.get(
Report, parent_report_id
)
if parent_report:
report_group_id = parent_report.report_group_id
parent_report_content = parent_report.content
logger.info(
f"[generate_report] Creating new version from parent {parent_report_id} "
f"(group {report_group_id})"
)
else:
logger.warning(
f"[generate_report] parent_report_id={parent_report_id} not found, "
"creating standalone report"
)
llm = await get_document_summary_llm(
read_session, search_space_id
)
# read_session closed — connection returned to pool
if not llm:
error_msg = (
"No LLM configured. Please configure a language model in Settings."
@ -303,11 +317,11 @@ def create_generate_report_tool(
# If revising, include previous version content
previous_version_section = ""
if parent_report and parent_report.content:
if parent_report_content:
previous_version_section = (
"**Previous Version of This Report (refine this based on the instructions above — "
"preserve structure and quality, apply only the requested changes):**\n\n"
f"{parent_report.content}"
f"{parent_report_content}"
)
prompt = _REPORT_PROMPT.format(
@ -318,9 +332,7 @@ def create_generate_report_tool(
source_content=source_content[:100000], # Cap source content
)
# Call the LLM inline
from langchain_core.messages import HumanMessage
# ── Phase 2: LLM CALL (no DB connection held) ────────────────
response = await llm.ainvoke([HumanMessage(content=prompt)])
report_content = response.content
@ -351,35 +363,41 @@ def create_generate_report_tool(
# Extract metadata (includes "status": "ready")
metadata = _extract_metadata(report_content)
# Save to database
report = Report(
title=topic,
content=report_content,
report_metadata=metadata,
report_style=report_style,
search_space_id=search_space_id,
thread_id=thread_id,
report_group_id=report_group_id, # None for v1, inherited for v2+
)
db_session.add(report)
await db_session.commit()
await db_session.refresh(report)
# ── Phase 3: WRITE (short-lived session) ─────────────────────
# Save the report to the database, then close the session.
async with async_session_maker() as write_session:
report = Report(
title=topic,
content=report_content,
report_metadata=metadata,
report_style=report_style,
search_space_id=search_space_id,
thread_id=thread_id,
report_group_id=report_group_id,
)
write_session.add(report)
await write_session.commit()
await write_session.refresh(report)
# If this is a brand-new report (v1), set report_group_id = own id
if not report.report_group_id:
report.report_group_id = report.id
await db_session.commit()
# If this is a brand-new report (v1), set report_group_id = own id
if not report.report_group_id:
report.report_group_id = report.id
await write_session.commit()
saved_report_id = report.id
saved_group_id = report.report_group_id
# write_session closed — connection returned to pool
logger.info(
f"[generate_report] Created report {report.id} "
f"(group={report.report_group_id}): "
f"[generate_report] Created report {saved_report_id} "
f"(group={saved_group_id}): "
f"{metadata.get('word_count', 0)} words, "
f"{metadata.get('section_count', 0)} sections"
)
return {
"status": "ready",
"report_id": report.id,
"report_id": saved_report_id,
"title": topic,
"word_count": metadata.get("word_count", 0),
"message": f"Report generated successfully: {topic}",