mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-04-26 09:16:22 +02:00
refactor: optimize report generation tool to use short-lived database sessions for improved performance and connection management
This commit is contained in:
parent
e1124d170d
commit
3ec30d94ce
2 changed files with 93 additions and 76 deletions
|
|
@ -119,16 +119,15 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
),
|
||||
requires=["search_space_id", "db_session", "thread_id"],
|
||||
),
|
||||
# Report generation tool (inline, no Celery)
|
||||
# Report generation tool (inline, short-lived sessions for DB ops)
|
||||
ToolDefinition(
|
||||
name="generate_report",
|
||||
description="Generate a structured Markdown report from provided content",
|
||||
factory=lambda deps: create_generate_report_tool(
|
||||
search_space_id=deps["search_space_id"],
|
||||
db_session=deps["db_session"],
|
||||
thread_id=deps["thread_id"],
|
||||
),
|
||||
requires=["search_space_id", "db_session", "thread_id"],
|
||||
requires=["search_space_id", "thread_id"],
|
||||
),
|
||||
# Link preview tool - fetches Open Graph metadata for URLs
|
||||
ToolDefinition(
|
||||
|
|
|
|||
|
|
@ -6,18 +6,20 @@ that generates a structured Markdown report inline (no Celery). The LLM is
|
|||
called within the tool, the result is saved to the database, and the tool
|
||||
returns immediately with a ready status.
|
||||
|
||||
This follows the same inline pattern as generate_image and display_image,
|
||||
NOT the Celery-based podcast pattern.
|
||||
Uses short-lived database sessions to avoid holding connections during long
|
||||
LLM calls (30-120+ seconds). Each DB operation (read config, save report)
|
||||
opens and closes its own session, ensuring no connection is held idle during
|
||||
the LLM API call.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.tools import tool
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import Report
|
||||
from app.db import Report, async_session_maker
|
||||
from app.services.llm_service import get_document_summary_llm
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -97,7 +99,6 @@ def _extract_metadata(content: str) -> dict[str, Any]:
|
|||
|
||||
def create_generate_report_tool(
|
||||
search_space_id: int,
|
||||
db_session: AsyncSession,
|
||||
thread_id: int | None = None,
|
||||
):
|
||||
"""
|
||||
|
|
@ -106,9 +107,11 @@ def create_generate_report_tool(
|
|||
The tool generates a Markdown report inline using the search space's
|
||||
document summary LLM, saves it to the database, and returns immediately.
|
||||
|
||||
Uses short-lived database sessions for each DB operation so no connection
|
||||
is held during the long LLM API call.
|
||||
|
||||
Args:
|
||||
search_space_id: The user's search space ID
|
||||
db_session: Database session for creating the report record
|
||||
thread_id: The chat thread ID for associating the report
|
||||
|
||||
Returns:
|
||||
|
|
@ -229,50 +232,37 @@ def create_generate_report_tool(
|
|||
- word_count: Number of words in the report
|
||||
- message: Status message (or "error" field if failed)
|
||||
"""
|
||||
# Resolve the parent report and its group (if versioning)
|
||||
parent_report: Report | None = None
|
||||
# Initialize version tracking variables (used by _save_failed_report closure)
|
||||
parent_report_content: str | None = None
|
||||
report_group_id: int | None = None
|
||||
|
||||
if parent_report_id:
|
||||
parent_report = await db_session.get(Report, parent_report_id)
|
||||
if parent_report:
|
||||
report_group_id = parent_report.report_group_id
|
||||
logger.info(
|
||||
f"[generate_report] Creating new version from parent {parent_report_id} "
|
||||
f"(group {report_group_id})"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"[generate_report] parent_report_id={parent_report_id} not found, "
|
||||
"creating standalone report"
|
||||
)
|
||||
|
||||
async def _save_failed_report(error_msg: str) -> int | None:
|
||||
"""Persist a failed report row so the error is visible later."""
|
||||
"""Persist a failed report row using a short-lived session."""
|
||||
try:
|
||||
failed_report = Report(
|
||||
title=topic,
|
||||
content=None,
|
||||
report_metadata={
|
||||
"status": "failed",
|
||||
"error_message": error_msg,
|
||||
},
|
||||
report_style=report_style,
|
||||
search_space_id=search_space_id,
|
||||
thread_id=thread_id,
|
||||
report_group_id=report_group_id,
|
||||
)
|
||||
db_session.add(failed_report)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(failed_report)
|
||||
# If this is a new group (v1 failed), set group to self
|
||||
if not failed_report.report_group_id:
|
||||
failed_report.report_group_id = failed_report.id
|
||||
await db_session.commit()
|
||||
logger.info(
|
||||
f"[generate_report] Saved failed report {failed_report.id}: {error_msg}"
|
||||
)
|
||||
return failed_report.id
|
||||
async with async_session_maker() as session:
|
||||
failed_report = Report(
|
||||
title=topic,
|
||||
content=None,
|
||||
report_metadata={
|
||||
"status": "failed",
|
||||
"error_message": error_msg,
|
||||
},
|
||||
report_style=report_style,
|
||||
search_space_id=search_space_id,
|
||||
thread_id=thread_id,
|
||||
report_group_id=report_group_id,
|
||||
)
|
||||
session.add(failed_report)
|
||||
await session.commit()
|
||||
await session.refresh(failed_report)
|
||||
# If this is a new group (v1 failed), set group to self
|
||||
if not failed_report.report_group_id:
|
||||
failed_report.report_group_id = failed_report.id
|
||||
await session.commit()
|
||||
logger.info(
|
||||
f"[generate_report] Saved failed report {failed_report.id}: {error_msg}"
|
||||
)
|
||||
return failed_report.id
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"[generate_report] Could not persist failed report row"
|
||||
|
|
@ -280,8 +270,32 @@ def create_generate_report_tool(
|
|||
return None
|
||||
|
||||
try:
|
||||
# Get the LLM instance for this search space
|
||||
llm = await get_document_summary_llm(db_session, search_space_id)
|
||||
# ── Phase 1: READ (short-lived session) ──────────────────────
|
||||
# Fetch parent report and LLM config, then close the session
|
||||
# so no DB connection is held during the long LLM call.
|
||||
async with async_session_maker() as read_session:
|
||||
if parent_report_id:
|
||||
parent_report = await read_session.get(
|
||||
Report, parent_report_id
|
||||
)
|
||||
if parent_report:
|
||||
report_group_id = parent_report.report_group_id
|
||||
parent_report_content = parent_report.content
|
||||
logger.info(
|
||||
f"[generate_report] Creating new version from parent {parent_report_id} "
|
||||
f"(group {report_group_id})"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"[generate_report] parent_report_id={parent_report_id} not found, "
|
||||
"creating standalone report"
|
||||
)
|
||||
|
||||
llm = await get_document_summary_llm(
|
||||
read_session, search_space_id
|
||||
)
|
||||
# read_session closed — connection returned to pool
|
||||
|
||||
if not llm:
|
||||
error_msg = (
|
||||
"No LLM configured. Please configure a language model in Settings."
|
||||
|
|
@ -303,11 +317,11 @@ def create_generate_report_tool(
|
|||
|
||||
# If revising, include previous version content
|
||||
previous_version_section = ""
|
||||
if parent_report and parent_report.content:
|
||||
if parent_report_content:
|
||||
previous_version_section = (
|
||||
"**Previous Version of This Report (refine this based on the instructions above — "
|
||||
"preserve structure and quality, apply only the requested changes):**\n\n"
|
||||
f"{parent_report.content}"
|
||||
f"{parent_report_content}"
|
||||
)
|
||||
|
||||
prompt = _REPORT_PROMPT.format(
|
||||
|
|
@ -318,9 +332,7 @@ def create_generate_report_tool(
|
|||
source_content=source_content[:100000], # Cap source content
|
||||
)
|
||||
|
||||
# Call the LLM inline
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
# ── Phase 2: LLM CALL (no DB connection held) ────────────────
|
||||
response = await llm.ainvoke([HumanMessage(content=prompt)])
|
||||
report_content = response.content
|
||||
|
||||
|
|
@ -351,35 +363,41 @@ def create_generate_report_tool(
|
|||
# Extract metadata (includes "status": "ready")
|
||||
metadata = _extract_metadata(report_content)
|
||||
|
||||
# Save to database
|
||||
report = Report(
|
||||
title=topic,
|
||||
content=report_content,
|
||||
report_metadata=metadata,
|
||||
report_style=report_style,
|
||||
search_space_id=search_space_id,
|
||||
thread_id=thread_id,
|
||||
report_group_id=report_group_id, # None for v1, inherited for v2+
|
||||
)
|
||||
db_session.add(report)
|
||||
await db_session.commit()
|
||||
await db_session.refresh(report)
|
||||
# ── Phase 3: WRITE (short-lived session) ─────────────────────
|
||||
# Save the report to the database, then close the session.
|
||||
async with async_session_maker() as write_session:
|
||||
report = Report(
|
||||
title=topic,
|
||||
content=report_content,
|
||||
report_metadata=metadata,
|
||||
report_style=report_style,
|
||||
search_space_id=search_space_id,
|
||||
thread_id=thread_id,
|
||||
report_group_id=report_group_id,
|
||||
)
|
||||
write_session.add(report)
|
||||
await write_session.commit()
|
||||
await write_session.refresh(report)
|
||||
|
||||
# If this is a brand-new report (v1), set report_group_id = own id
|
||||
if not report.report_group_id:
|
||||
report.report_group_id = report.id
|
||||
await db_session.commit()
|
||||
# If this is a brand-new report (v1), set report_group_id = own id
|
||||
if not report.report_group_id:
|
||||
report.report_group_id = report.id
|
||||
await write_session.commit()
|
||||
|
||||
saved_report_id = report.id
|
||||
saved_group_id = report.report_group_id
|
||||
# write_session closed — connection returned to pool
|
||||
|
||||
logger.info(
|
||||
f"[generate_report] Created report {report.id} "
|
||||
f"(group={report.report_group_id}): "
|
||||
f"[generate_report] Created report {saved_report_id} "
|
||||
f"(group={saved_group_id}): "
|
||||
f"{metadata.get('word_count', 0)} words, "
|
||||
f"{metadata.get('section_count', 0)} sections"
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "ready",
|
||||
"report_id": report.id,
|
||||
"report_id": saved_report_id,
|
||||
"title": topic,
|
||||
"word_count": metadata.get("word_count", 0),
|
||||
"message": f"Report generated successfully: {topic}",
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue