mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-04-29 02:46:25 +02:00
refactor: make the report generation tool open a short-lived database session for each DB operation (read config, save report) so that no connection is held idle during the long LLM call
This commit is contained in:
parent
e1124d170d
commit
3ec30d94ce
2 changed files with 93 additions and 76 deletions
|
|
@ -119,16 +119,15 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
||||||
),
|
),
|
||||||
requires=["search_space_id", "db_session", "thread_id"],
|
requires=["search_space_id", "db_session", "thread_id"],
|
||||||
),
|
),
|
||||||
# Report generation tool (inline, no Celery)
|
# Report generation tool (inline, short-lived sessions for DB ops)
|
||||||
ToolDefinition(
|
ToolDefinition(
|
||||||
name="generate_report",
|
name="generate_report",
|
||||||
description="Generate a structured Markdown report from provided content",
|
description="Generate a structured Markdown report from provided content",
|
||||||
factory=lambda deps: create_generate_report_tool(
|
factory=lambda deps: create_generate_report_tool(
|
||||||
search_space_id=deps["search_space_id"],
|
search_space_id=deps["search_space_id"],
|
||||||
db_session=deps["db_session"],
|
|
||||||
thread_id=deps["thread_id"],
|
thread_id=deps["thread_id"],
|
||||||
),
|
),
|
||||||
requires=["search_space_id", "db_session", "thread_id"],
|
requires=["search_space_id", "thread_id"],
|
||||||
),
|
),
|
||||||
# Link preview tool - fetches Open Graph metadata for URLs
|
# Link preview tool - fetches Open Graph metadata for URLs
|
||||||
ToolDefinition(
|
ToolDefinition(
|
||||||
|
|
|
||||||
|
|
@ -6,18 +6,20 @@ that generates a structured Markdown report inline (no Celery). The LLM is
|
||||||
called within the tool, the result is saved to the database, and the tool
|
called within the tool, the result is saved to the database, and the tool
|
||||||
returns immediately with a ready status.
|
returns immediately with a ready status.
|
||||||
|
|
||||||
This follows the same inline pattern as generate_image and display_image,
|
Uses short-lived database sessions to avoid holding connections during long
|
||||||
NOT the Celery-based podcast pattern.
|
LLM calls (30-120+ seconds). Each DB operation (read config, save report)
|
||||||
|
opens and closes its own session, ensuring no connection is held idle during
|
||||||
|
the LLM API call.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain_core.messages import HumanMessage
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
|
|
||||||
from app.db import Report
|
from app.db import Report, async_session_maker
|
||||||
from app.services.llm_service import get_document_summary_llm
|
from app.services.llm_service import get_document_summary_llm
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -97,7 +99,6 @@ def _extract_metadata(content: str) -> dict[str, Any]:
|
||||||
|
|
||||||
def create_generate_report_tool(
|
def create_generate_report_tool(
|
||||||
search_space_id: int,
|
search_space_id: int,
|
||||||
db_session: AsyncSession,
|
|
||||||
thread_id: int | None = None,
|
thread_id: int | None = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
|
|
@ -106,9 +107,11 @@ def create_generate_report_tool(
|
||||||
The tool generates a Markdown report inline using the search space's
|
The tool generates a Markdown report inline using the search space's
|
||||||
document summary LLM, saves it to the database, and returns immediately.
|
document summary LLM, saves it to the database, and returns immediately.
|
||||||
|
|
||||||
|
Uses short-lived database sessions for each DB operation so no connection
|
||||||
|
is held during the long LLM API call.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
search_space_id: The user's search space ID
|
search_space_id: The user's search space ID
|
||||||
db_session: Database session for creating the report record
|
|
||||||
thread_id: The chat thread ID for associating the report
|
thread_id: The chat thread ID for associating the report
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
|
@ -229,50 +232,37 @@ def create_generate_report_tool(
|
||||||
- word_count: Number of words in the report
|
- word_count: Number of words in the report
|
||||||
- message: Status message (or "error" field if failed)
|
- message: Status message (or "error" field if failed)
|
||||||
"""
|
"""
|
||||||
# Resolve the parent report and its group (if versioning)
|
# Initialize version tracking variables (used by _save_failed_report closure)
|
||||||
parent_report: Report | None = None
|
parent_report_content: str | None = None
|
||||||
report_group_id: int | None = None
|
report_group_id: int | None = None
|
||||||
|
|
||||||
if parent_report_id:
|
|
||||||
parent_report = await db_session.get(Report, parent_report_id)
|
|
||||||
if parent_report:
|
|
||||||
report_group_id = parent_report.report_group_id
|
|
||||||
logger.info(
|
|
||||||
f"[generate_report] Creating new version from parent {parent_report_id} "
|
|
||||||
f"(group {report_group_id})"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logger.warning(
|
|
||||||
f"[generate_report] parent_report_id={parent_report_id} not found, "
|
|
||||||
"creating standalone report"
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _save_failed_report(error_msg: str) -> int | None:
|
async def _save_failed_report(error_msg: str) -> int | None:
|
||||||
"""Persist a failed report row so the error is visible later."""
|
"""Persist a failed report row using a short-lived session."""
|
||||||
try:
|
try:
|
||||||
failed_report = Report(
|
async with async_session_maker() as session:
|
||||||
title=topic,
|
failed_report = Report(
|
||||||
content=None,
|
title=topic,
|
||||||
report_metadata={
|
content=None,
|
||||||
"status": "failed",
|
report_metadata={
|
||||||
"error_message": error_msg,
|
"status": "failed",
|
||||||
},
|
"error_message": error_msg,
|
||||||
report_style=report_style,
|
},
|
||||||
search_space_id=search_space_id,
|
report_style=report_style,
|
||||||
thread_id=thread_id,
|
search_space_id=search_space_id,
|
||||||
report_group_id=report_group_id,
|
thread_id=thread_id,
|
||||||
)
|
report_group_id=report_group_id,
|
||||||
db_session.add(failed_report)
|
)
|
||||||
await db_session.commit()
|
session.add(failed_report)
|
||||||
await db_session.refresh(failed_report)
|
await session.commit()
|
||||||
# If this is a new group (v1 failed), set group to self
|
await session.refresh(failed_report)
|
||||||
if not failed_report.report_group_id:
|
# If this is a new group (v1 failed), set group to self
|
||||||
failed_report.report_group_id = failed_report.id
|
if not failed_report.report_group_id:
|
||||||
await db_session.commit()
|
failed_report.report_group_id = failed_report.id
|
||||||
logger.info(
|
await session.commit()
|
||||||
f"[generate_report] Saved failed report {failed_report.id}: {error_msg}"
|
logger.info(
|
||||||
)
|
f"[generate_report] Saved failed report {failed_report.id}: {error_msg}"
|
||||||
return failed_report.id
|
)
|
||||||
|
return failed_report.id
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception(
|
logger.exception(
|
||||||
"[generate_report] Could not persist failed report row"
|
"[generate_report] Could not persist failed report row"
|
||||||
|
|
@ -280,8 +270,32 @@ def create_generate_report_tool(
|
||||||
return None
|
return None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Get the LLM instance for this search space
|
# ── Phase 1: READ (short-lived session) ──────────────────────
|
||||||
llm = await get_document_summary_llm(db_session, search_space_id)
|
# Fetch parent report and LLM config, then close the session
|
||||||
|
# so no DB connection is held during the long LLM call.
|
||||||
|
async with async_session_maker() as read_session:
|
||||||
|
if parent_report_id:
|
||||||
|
parent_report = await read_session.get(
|
||||||
|
Report, parent_report_id
|
||||||
|
)
|
||||||
|
if parent_report:
|
||||||
|
report_group_id = parent_report.report_group_id
|
||||||
|
parent_report_content = parent_report.content
|
||||||
|
logger.info(
|
||||||
|
f"[generate_report] Creating new version from parent {parent_report_id} "
|
||||||
|
f"(group {report_group_id})"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
f"[generate_report] parent_report_id={parent_report_id} not found, "
|
||||||
|
"creating standalone report"
|
||||||
|
)
|
||||||
|
|
||||||
|
llm = await get_document_summary_llm(
|
||||||
|
read_session, search_space_id
|
||||||
|
)
|
||||||
|
# read_session closed — connection returned to pool
|
||||||
|
|
||||||
if not llm:
|
if not llm:
|
||||||
error_msg = (
|
error_msg = (
|
||||||
"No LLM configured. Please configure a language model in Settings."
|
"No LLM configured. Please configure a language model in Settings."
|
||||||
|
|
@ -303,11 +317,11 @@ def create_generate_report_tool(
|
||||||
|
|
||||||
# If revising, include previous version content
|
# If revising, include previous version content
|
||||||
previous_version_section = ""
|
previous_version_section = ""
|
||||||
if parent_report and parent_report.content:
|
if parent_report_content:
|
||||||
previous_version_section = (
|
previous_version_section = (
|
||||||
"**Previous Version of This Report (refine this based on the instructions above — "
|
"**Previous Version of This Report (refine this based on the instructions above — "
|
||||||
"preserve structure and quality, apply only the requested changes):**\n\n"
|
"preserve structure and quality, apply only the requested changes):**\n\n"
|
||||||
f"{parent_report.content}"
|
f"{parent_report_content}"
|
||||||
)
|
)
|
||||||
|
|
||||||
prompt = _REPORT_PROMPT.format(
|
prompt = _REPORT_PROMPT.format(
|
||||||
|
|
@ -318,9 +332,7 @@ def create_generate_report_tool(
|
||||||
source_content=source_content[:100000], # Cap source content
|
source_content=source_content[:100000], # Cap source content
|
||||||
)
|
)
|
||||||
|
|
||||||
# Call the LLM inline
|
# ── Phase 2: LLM CALL (no DB connection held) ────────────────
|
||||||
from langchain_core.messages import HumanMessage
|
|
||||||
|
|
||||||
response = await llm.ainvoke([HumanMessage(content=prompt)])
|
response = await llm.ainvoke([HumanMessage(content=prompt)])
|
||||||
report_content = response.content
|
report_content = response.content
|
||||||
|
|
||||||
|
|
@ -351,35 +363,41 @@ def create_generate_report_tool(
|
||||||
# Extract metadata (includes "status": "ready")
|
# Extract metadata (includes "status": "ready")
|
||||||
metadata = _extract_metadata(report_content)
|
metadata = _extract_metadata(report_content)
|
||||||
|
|
||||||
# Save to database
|
# ── Phase 3: WRITE (short-lived session) ─────────────────────
|
||||||
report = Report(
|
# Save the report to the database, then close the session.
|
||||||
title=topic,
|
async with async_session_maker() as write_session:
|
||||||
content=report_content,
|
report = Report(
|
||||||
report_metadata=metadata,
|
title=topic,
|
||||||
report_style=report_style,
|
content=report_content,
|
||||||
search_space_id=search_space_id,
|
report_metadata=metadata,
|
||||||
thread_id=thread_id,
|
report_style=report_style,
|
||||||
report_group_id=report_group_id, # None for v1, inherited for v2+
|
search_space_id=search_space_id,
|
||||||
)
|
thread_id=thread_id,
|
||||||
db_session.add(report)
|
report_group_id=report_group_id,
|
||||||
await db_session.commit()
|
)
|
||||||
await db_session.refresh(report)
|
write_session.add(report)
|
||||||
|
await write_session.commit()
|
||||||
|
await write_session.refresh(report)
|
||||||
|
|
||||||
# If this is a brand-new report (v1), set report_group_id = own id
|
# If this is a brand-new report (v1), set report_group_id = own id
|
||||||
if not report.report_group_id:
|
if not report.report_group_id:
|
||||||
report.report_group_id = report.id
|
report.report_group_id = report.id
|
||||||
await db_session.commit()
|
await write_session.commit()
|
||||||
|
|
||||||
|
saved_report_id = report.id
|
||||||
|
saved_group_id = report.report_group_id
|
||||||
|
# write_session closed — connection returned to pool
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[generate_report] Created report {report.id} "
|
f"[generate_report] Created report {saved_report_id} "
|
||||||
f"(group={report.report_group_id}): "
|
f"(group={saved_group_id}): "
|
||||||
f"{metadata.get('word_count', 0)} words, "
|
f"{metadata.get('word_count', 0)} words, "
|
||||||
f"{metadata.get('section_count', 0)} sections"
|
f"{metadata.get('section_count', 0)} sections"
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"status": "ready",
|
"status": "ready",
|
||||||
"report_id": report.id,
|
"report_id": saved_report_id,
|
||||||
"title": topic,
|
"title": topic,
|
||||||
"word_count": metadata.get("word_count", 0),
|
"word_count": metadata.get("word_count", 0),
|
||||||
"message": f"Report generated successfully: {topic}",
|
"message": f"Report generated successfully: {topic}",
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue