From 8d0090c6a1dace6e12640c9e9b74ac5ccbed6a49 Mon Sep 17 00:00:00 2001 From: CREDO23 Date: Thu, 4 Jun 2026 20:30:30 +0200 Subject: [PATCH] refactor(agents): delete deliverable dead twins in shared/tools; fix live image api_base bug MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The deliverables subagent runs its own generate_image/podcast/report/resume/ video_presentation (via tools/index.py); the shared/tools copies had zero production importers — classic dead twins. Removed them so deliverable tools live only in their vertical slice. While repointing the 2 stranded unit tests at the LIVE deliverables modules, found the OpenRouter empty-api_base defense (resolve_api_base) existed ONLY in the dead shared generate_image, never propagated to the live multi-agent copy. Ported the fix into deliverables/tools/generate_image.py (both the global-config and user-DB-config branches) so an empty api_base no longer falls through to LiteLLM's global api_base (Azure) and 404s. Tests now exercise the live Command/receipt-returning tools (invoke the raw coroutine with a hand-built ToolRuntime; resume progress events neutralized). 
--- .../deliverables/tools/generate_image.py | 48 +- .../app/agents/shared/tools/__init__.py | 27 +- .../app/agents/shared/tools/generate_image.py | 280 ----- .../app/agents/shared/tools/podcast.py | 160 --- .../app/agents/shared/tools/report.py | 1084 ----------------- .../app/agents/shared/tools/resume.py | 812 ------------ .../agents/shared/tools/video_presentation.py | 138 --- .../app/services/provider_capabilities.py | 2 +- .../new_chat/tools/test_resume_page_limits.py | 53 +- .../test_image_gen_api_base_defense.py | 19 +- 10 files changed, 104 insertions(+), 2519 deletions(-) delete mode 100644 surfsense_backend/app/agents/shared/tools/generate_image.py delete mode 100644 surfsense_backend/app/agents/shared/tools/podcast.py delete mode 100644 surfsense_backend/app/agents/shared/tools/report.py delete mode 100644 surfsense_backend/app/agents/shared/tools/resume.py delete mode 100644 surfsense_backend/app/agents/shared/tools/video_presentation.py diff --git a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/deliverables/tools/generate_image.py b/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/deliverables/tools/generate_image.py index 094371760..d7105f903 100644 --- a/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/deliverables/tools/generate_image.py +++ b/surfsense_backend/app/agents/multi_agent_chat/subagents/builtins/deliverables/tools/generate_image.py @@ -25,6 +25,7 @@ from app.services.image_gen_router_service import ( ImageGenRouterService, is_image_gen_auto_mode, ) +from app.services.provider_api_base import resolve_api_base from app.utils.signed_image_urls import generate_image_token logger = logging.getLogger(__name__) @@ -43,13 +44,16 @@ _PROVIDER_MAP = { } +def _resolve_provider_prefix(provider: str, custom_provider: str | None) -> str: + if custom_provider: + return custom_provider + return _PROVIDER_MAP.get(provider.upper(), provider.lower()) + + def _build_model_string( provider: str, model_name: 
str, custom_provider: str | None ) -> str: - if custom_provider: - return f"{custom_provider}/{model_name}" - prefix = _PROVIDER_MAP.get(provider.upper(), provider.lower()) - return f"{prefix}/{model_name}" + return f"{_resolve_provider_prefix(provider, custom_provider)}/{model_name}" def _get_global_image_gen_config(config_id: int) -> dict | None: @@ -163,14 +167,20 @@ def create_generate_image_tool( err = f"Image generation config {config_id} not found" return _failed({"error": err}, error=err) - model_string = _build_model_string( - cfg.get("provider", ""), - cfg["model_name"], - cfg.get("custom_provider"), + provider_prefix = _resolve_provider_prefix( + cfg.get("provider", ""), cfg.get("custom_provider") ) + model_string = f"{provider_prefix}/{cfg['model_name']}" gen_kwargs["api_key"] = cfg.get("api_key") - if cfg.get("api_base"): - gen_kwargs["api_base"] = cfg["api_base"] + # Defense-in-depth: an empty ``api_base`` must not fall + # through to LiteLLM's global ``api_base`` (e.g. Azure). + api_base = resolve_api_base( + provider=cfg.get("provider"), + provider_prefix=provider_prefix, + config_api_base=cfg.get("api_base"), + ) + if api_base: + gen_kwargs["api_base"] = api_base if cfg.get("api_version"): gen_kwargs["api_version"] = cfg["api_version"] if cfg.get("litellm_params"): @@ -191,14 +201,20 @@ def create_generate_image_tool( err = f"Image generation config {config_id} not found" return _failed({"error": err}, error=err) - model_string = _build_model_string( - db_cfg.provider.value, - db_cfg.model_name, - db_cfg.custom_provider, + provider_prefix = _resolve_provider_prefix( + db_cfg.provider.value, db_cfg.custom_provider ) + model_string = f"{provider_prefix}/{db_cfg.model_name}" gen_kwargs["api_key"] = db_cfg.api_key - if db_cfg.api_base: - gen_kwargs["api_base"] = db_cfg.api_base + # Defense-in-depth: an empty ``api_base`` must not fall + # through to LiteLLM's global ``api_base`` (e.g. Azure). 
+ api_base = resolve_api_base( + provider=db_cfg.provider.value, + provider_prefix=provider_prefix, + config_api_base=db_cfg.api_base, + ) + if api_base: + gen_kwargs["api_base"] = api_base if db_cfg.api_version: gen_kwargs["api_version"] = db_cfg.api_version if db_cfg.litellm_params: diff --git a/surfsense_backend/app/agents/shared/tools/__init__.py b/surfsense_backend/app/agents/shared/tools/__init__.py index a7c8c71a3..e4689c25a 100644 --- a/surfsense_backend/app/agents/shared/tools/__init__.py +++ b/surfsense_backend/app/agents/shared/tools/__init__.py @@ -1,37 +1,24 @@ -""" -Tools module for SurfSense deep agent. +"""Cross-agent shared tools and tool metadata. -This module contains all the tools available to the SurfSense agent. -To add a new tool, see the documentation in registry.py. - -Available tools: -- generate_podcast: Generate audio podcasts from content -- generate_video_presentation: Generate video presentations with slides and narration -- generate_image: Generate images from text descriptions using AI models +Tool *implementations* live with the agents that own them (e.g. deliverable +generators under ``subagents/builtins/deliverables/tools``). This package +holds only the genuinely shared pieces: the display-metadata catalog and the +knowledge-base helpers used across agents. 
""" -# Registry exports -# Tool factory exports (for direct use) -from .generate_image import create_generate_image_tool +from .catalog import TOOL_CATALOG, ToolMetadata from .knowledge_base import ( CONNECTOR_DESCRIPTIONS, format_documents_for_context, search_knowledge_base_async, ) -from .catalog import TOOL_CATALOG, ToolMetadata -from .podcast import create_generate_podcast_tool -from .video_presentation import create_generate_video_presentation_tool __all__ = [ # Tool catalog (display metadata) "TOOL_CATALOG", + "ToolMetadata", # Knowledge base utilities "CONNECTOR_DESCRIPTIONS", - "ToolMetadata", - # Tool factories - "create_generate_image_tool", - "create_generate_podcast_tool", - "create_generate_video_presentation_tool", "format_documents_for_context", "search_knowledge_base_async", ] diff --git a/surfsense_backend/app/agents/shared/tools/generate_image.py b/surfsense_backend/app/agents/shared/tools/generate_image.py deleted file mode 100644 index 9e287ac51..000000000 --- a/surfsense_backend/app/agents/shared/tools/generate_image.py +++ /dev/null @@ -1,280 +0,0 @@ -""" -Image generation tool for the SurfSense agent. - -This module provides a tool that generates images using litellm.aimage_generation() -and returns the result directly in a format the frontend Image component can render. - -Config resolution: -1. Uses the search space's image_generation_config_id preference -2. Falls back to Auto mode (router load balancing) if available -3. 
Supports global YAML configs (negative IDs) and user DB configs (positive IDs) -""" - -import hashlib -import logging -from typing import Any - -from langchain_core.tools import tool -from litellm import aimage_generation -from sqlalchemy import select -from sqlalchemy.ext.asyncio import AsyncSession - -from app.config import config -from app.db import ( - ImageGeneration, - ImageGenerationConfig, - SearchSpace, - shielded_async_session, -) -from app.services.image_gen_router_service import ( - IMAGE_GEN_AUTO_MODE_ID, - ImageGenRouterService, - is_image_gen_auto_mode, -) -from app.services.provider_api_base import resolve_api_base -from app.utils.signed_image_urls import generate_image_token - -logger = logging.getLogger(__name__) - -# Provider mapping (same as routes) -_PROVIDER_MAP = { - "OPENAI": "openai", - "AZURE_OPENAI": "azure", - "GOOGLE": "gemini", - "VERTEX_AI": "vertex_ai", - "BEDROCK": "bedrock", - "RECRAFT": "recraft", - "OPENROUTER": "openrouter", - "XINFERENCE": "xinference", - "NSCALE": "nscale", -} - - -def _resolve_provider_prefix(provider: str, custom_provider: str | None) -> str: - if custom_provider: - return custom_provider - return _PROVIDER_MAP.get(provider.upper(), provider.lower()) - - -def _build_model_string( - provider: str, model_name: str, custom_provider: str | None -) -> str: - prefix = _resolve_provider_prefix(provider, custom_provider) - return f"{prefix}/{model_name}" - - -def _get_global_image_gen_config(config_id: int) -> dict | None: - """Get a global image gen config by negative ID.""" - for cfg in config.GLOBAL_IMAGE_GEN_CONFIGS: - if cfg.get("id") == config_id: - return cfg - return None - - -def create_generate_image_tool( - search_space_id: int, - db_session: AsyncSession, -): - """ - Factory function to create the generate_image tool. - - Args: - search_space_id: The search space ID (for config resolution) - db_session: Reserved for compatibility with the tool registry. 
- The streaming task's ``AsyncSession`` is shared by every tool; - because AsyncSession is not concurrency-safe, parallel tool calls - would interleave flushes (e.g. podcast + image in the same step) - and poison the transaction. This tool opens its own session. - """ - del db_session # use a fresh per-call session, see below - - @tool - async def generate_image( - prompt: str, - n: int = 1, - ) -> dict[str, Any]: - """ - Generate an image from a text description using AI image models. - - Use this tool when the user asks you to create, generate, draw, or make an image. - The generated image will be displayed directly in the chat. - - Args: - prompt: A detailed text description of the image to generate. - Be specific about subject, style, colors, composition, and mood. - n: Number of images to generate (1-4). Default: 1 - - Returns: - A dictionary containing the generated image(s) for display in the chat. - """ - try: - # Use a per-call session so concurrent tool calls don't share an - # AsyncSession (which is not concurrency-safe). The streaming - # task's session is shared across every tool; without isolation, - # autoflushes from a concurrent writer poison this tool too. - async with shielded_async_session() as session: - result = await session.execute( - select(SearchSpace).filter(SearchSpace.id == search_space_id) - ) - search_space = result.scalars().first() - if not search_space: - return {"error": "Search space not found"} - - config_id = ( - search_space.image_generation_config_id or IMAGE_GEN_AUTO_MODE_ID - ) - - # Build generation kwargs - # NOTE: size, quality, and style are intentionally NOT passed. - # Different models support different values for these params - # (e.g. DALL-E 3 wants "hd"/"standard" for quality while - # gpt-image-1 wants "high"/"medium"/"low"; size options also - # differ). Letting the model use its own defaults avoids errors. 
- gen_kwargs: dict[str, Any] = {} - if n is not None and n > 1: - gen_kwargs["n"] = n - - # Call litellm based on config type - if is_image_gen_auto_mode(config_id): - if not ImageGenRouterService.is_initialized(): - return { - "error": "No image generation models configured. " - "Please add an image model in Settings > Image Models." - } - response = await ImageGenRouterService.aimage_generation( - prompt=prompt, model="auto", **gen_kwargs - ) - elif config_id < 0: - cfg = _get_global_image_gen_config(config_id) - if not cfg: - return { - "error": f"Image generation config {config_id} not found" - } - - provider_prefix = _resolve_provider_prefix( - cfg.get("provider", ""), cfg.get("custom_provider") - ) - model_string = f"{provider_prefix}/{cfg['model_name']}" - gen_kwargs["api_key"] = cfg.get("api_key") - api_base = resolve_api_base( - provider=cfg.get("provider"), - provider_prefix=provider_prefix, - config_api_base=cfg.get("api_base"), - ) - if api_base: - gen_kwargs["api_base"] = api_base - if cfg.get("api_version"): - gen_kwargs["api_version"] = cfg["api_version"] - if cfg.get("litellm_params"): - gen_kwargs.update(cfg["litellm_params"]) - - response = await aimage_generation( - prompt=prompt, model=model_string, **gen_kwargs - ) - else: - # Positive ID = user-created ImageGenerationConfig - cfg_result = await session.execute( - select(ImageGenerationConfig).filter( - ImageGenerationConfig.id == config_id - ) - ) - db_cfg = cfg_result.scalars().first() - if not db_cfg: - return { - "error": f"Image generation config {config_id} not found" - } - - provider_prefix = _resolve_provider_prefix( - db_cfg.provider.value, db_cfg.custom_provider - ) - model_string = f"{provider_prefix}/{db_cfg.model_name}" - gen_kwargs["api_key"] = db_cfg.api_key - api_base = resolve_api_base( - provider=db_cfg.provider.value, - provider_prefix=provider_prefix, - config_api_base=db_cfg.api_base, - ) - if api_base: - gen_kwargs["api_base"] = api_base - if db_cfg.api_version: - 
gen_kwargs["api_version"] = db_cfg.api_version - if db_cfg.litellm_params: - gen_kwargs.update(db_cfg.litellm_params) - - response = await aimage_generation( - prompt=prompt, model=model_string, **gen_kwargs - ) - - # Parse the response and store in DB - response_dict = ( - response.model_dump() - if hasattr(response, "model_dump") - else dict(response) - ) - - # Generate a random access token for this image - access_token = generate_image_token() - - # Save to image_generations table for history - db_image_gen = ImageGeneration( - prompt=prompt, - model=getattr(response, "_hidden_params", {}).get("model"), - n=n, - image_generation_config_id=config_id, - response_data=response_dict, - search_space_id=search_space_id, - access_token=access_token, - ) - session.add(db_image_gen) - await session.commit() - await session.refresh(db_image_gen) - db_image_gen_id = db_image_gen.id - - # Extract image URLs from response - images = response_dict.get("data", []) - if not images: - return {"error": "No images were generated"} - - first_image = images[0] - revised_prompt = first_image.get("revised_prompt", prompt) - - # Resolve image URL: - # - If the API returned a URL, use it directly. - # - If the API returned b64_json (e.g. gpt-image-1), serve the - # image through our backend endpoint to avoid bloating the - # LLM context with megabytes of base64 data. 
- if first_image.get("url"): - image_url = first_image["url"] - elif first_image.get("b64_json"): - backend_url = config.BACKEND_URL or "http://localhost:8000" - image_url = ( - f"{backend_url}/api/v1/image-generations/" - f"{db_image_gen_id}/image?token={access_token}" - ) - else: - return {"error": "No displayable image data in the response"} - - image_id = f"image-{hashlib.md5(image_url.encode()).hexdigest()[:12]}" - - return { - "id": image_id, - "assetId": image_url, - "src": image_url, - "alt": revised_prompt or prompt, - "title": "Generated Image", - "description": revised_prompt if revised_prompt != prompt else None, - "domain": "ai-generated", - "ratio": "auto", - "generated": True, - "prompt": prompt, - "image_count": len(images), - } - - except Exception as e: - logger.exception("Image generation failed in tool") - return { - "error": f"Image generation failed: {e!s}", - "prompt": prompt, - } - - return generate_image diff --git a/surfsense_backend/app/agents/shared/tools/podcast.py b/surfsense_backend/app/agents/shared/tools/podcast.py deleted file mode 100644 index 83ac98768..000000000 --- a/surfsense_backend/app/agents/shared/tools/podcast.py +++ /dev/null @@ -1,160 +0,0 @@ -""" -Podcast generation tool for the SurfSense agent. - -This module provides a factory function for creating the generate_podcast tool -that submits a Celery task for background podcast generation. The tool then -polls the podcast row until it reaches a terminal status (READY/FAILED) and -returns that status. The wait is bounded by the chat's HTTP / process -lifetime; see app.agents.shared.deliverable_wait for details. 
-""" - -import logging -from typing import Any - -from langchain_core.tools import tool -from sqlalchemy.ext.asyncio import AsyncSession - -from app.agents.shared.deliverable_wait import wait_for_deliverable -from app.db import Podcast, PodcastStatus, shielded_async_session - -logger = logging.getLogger(__name__) - - -def create_generate_podcast_tool( - search_space_id: int, - db_session: AsyncSession, - thread_id: int | None = None, -): - """ - Factory function to create the generate_podcast tool with injected dependencies. - - Pre-creates podcast record with pending status so podcast_id is available - immediately for frontend polling. - - Args: - search_space_id: The user's search space ID - db_session: Reserved for future read-side use; the row is written via a - fresh, tool-local session so parallel tool calls (e.g. podcast + - video presentation in the same agent step) don't share an - ``AsyncSession`` (which is not concurrency-safe). - thread_id: The chat thread ID for associating the podcast - - Returns: - A configured tool function for generating podcasts - """ - del db_session # writes use a fresh tool-local session, see below - - @tool - async def generate_podcast( - source_content: str, - podcast_title: str = "SurfSense Podcast", - user_prompt: str | None = None, - ) -> dict[str, Any]: - """ - Generate a podcast from the provided content. - - Use this tool when the user asks to create, generate, or make a podcast. - Common triggers include phrases like: - - "Give me a podcast about this" - - "Create a podcast from this conversation" - - "Generate a podcast summary" - - "Make a podcast about..." - - "Turn this into a podcast" - - Args: - source_content: The text content to convert into a podcast. - podcast_title: Title for the podcast (default: "SurfSense Podcast") - user_prompt: Optional instructions for podcast style, tone, or format. 
- - Returns: - A dictionary containing: - - status: PodcastStatus value (pending, generating, or failed) - - podcast_id: The podcast ID for polling (when status is pending or generating) - - title: The podcast title - - message: Status message (or "error" field if status is failed) - """ - try: - # Open a fresh session per call. The streaming task's session is - # shared between every tool, and ``AsyncSession`` is NOT safe for - # concurrent use: when the LLM emits parallel tool calls, two - # concurrent ``add()`` / ``commit()`` paths interleave and the - # second one hits "Session.add() during flush" → the transaction - # is poisoned for both tools. - async with shielded_async_session() as session: - podcast = Podcast( - title=podcast_title, - status=PodcastStatus.PENDING, - search_space_id=search_space_id, - thread_id=thread_id, - ) - session.add(podcast) - await session.commit() - await session.refresh(podcast) - podcast_id = podcast.id - - from app.tasks.celery_tasks.podcast_tasks import ( - generate_content_podcast_task, - ) - - task = generate_content_podcast_task.delay( - podcast_id=podcast_id, - source_content=source_content, - search_space_id=search_space_id, - user_prompt=user_prompt, - ) - - logger.info( - "[generate_podcast] Created podcast %s, task: %s", - podcast_id, - task.id, - ) - - # Wait until the Celery worker flips the row to a terminal - # state. No internal budget — see deliverable_wait module. 
- terminal_status, columns, elapsed = await wait_for_deliverable( - model=Podcast, - row_id=podcast_id, - columns=[Podcast.status, Podcast.file_location], - terminal_statuses={PodcastStatus.READY, PodcastStatus.FAILED}, - ) - - if terminal_status == PodcastStatus.READY: - file_location = columns[1] if columns else None - logger.info( - "[generate_podcast] Podcast %s READY in %.2fs (file=%s)", - podcast_id, - elapsed, - file_location, - ) - return { - "status": PodcastStatus.READY.value, - "podcast_id": podcast_id, - "title": podcast_title, - "file_location": file_location, - "message": ("Podcast generated and saved to your podcast panel."), - } - - # Only other terminal state is FAILED. - logger.warning( - "[generate_podcast] Podcast %s FAILED in %.2fs", - podcast_id, - elapsed, - ) - return { - "status": PodcastStatus.FAILED.value, - "podcast_id": podcast_id, - "title": podcast_title, - "error": ("Background worker reported FAILED status for this podcast."), - } - - except Exception as e: - error_message = str(e) - logger.exception("[generate_podcast] Error: %s", error_message) - return { - "status": PodcastStatus.FAILED.value, - "error": error_message, - "title": podcast_title, - "podcast_id": None, - } - - return generate_podcast diff --git a/surfsense_backend/app/agents/shared/tools/report.py b/surfsense_backend/app/agents/shared/tools/report.py deleted file mode 100644 index 6bc1b7d57..000000000 --- a/surfsense_backend/app/agents/shared/tools/report.py +++ /dev/null @@ -1,1084 +0,0 @@ -""" -Report generation tool for the SurfSense agent. - -This module provides a factory function for creating the generate_report tool -that generates a structured Markdown report inline (no Celery). The LLM is -called within the tool, the result is saved to the database, and the tool -returns immediately with a ready status. - -Uses short-lived database sessions to avoid holding connections during long -LLM calls (30-120+ seconds). 
Each DB operation (read config, save report) -opens and closes its own session, ensuring no connection is held idle during -the LLM API call. - -Generation strategies: - - Single-shot generation for all new reports - - Section-level revision for targeted edits (preserves unchanged sections) - - Full-document revision as fallback for global changes - -Source strategies (how source content is collected): - - "provided" — Use only the supplied source_content (default, backward-compat) - - "conversation" — Same as "provided"; agent passes conversation summary - - "kb_search" — Tool searches knowledge base internally with targeted queries - - "auto" — Use source_content if sufficient, else search KB as fallback -""" - -import asyncio -import json -import logging -import re -from typing import Any - -from langchain_core.callbacks import dispatch_custom_event -from langchain_core.messages import HumanMessage -from langchain_core.tools import tool - -from app.db import Report, shielded_async_session -from app.services.connector_service import ConnectorService -from app.services.llm_service import get_document_summary_llm - -logger = logging.getLogger(__name__) - -# ─── Shared Formatting Rules ──────────────────────────────────────────────── -# Reusable formatting instructions appended to section-level and review prompts. - -_FORMATTING_RULES = """\ -- IMPORTANT: Output raw Markdown directly. Do NOT wrap the entire output in a \ -code fence (e.g. ```markdown, ````markdown, or any backtick fence). Individual \ -code examples and diagrams inside the report should still use fenced code blocks, \ -but the report itself must NOT be enclosed in one. -- Maintain proper Markdown formatting throughout. -- When including code examples, ALWAYS format them as proper fenced code blocks \ -with the correct language identifier (e.g. ```java, ```python). Code inside code \ -blocks MUST have proper line breaks and indentation — NEVER put multiple statements \ -on a single line. 
Each statement, brace, and logical block must be on its own line \ -with correct indentation. -- When including Mermaid diagrams, use ```mermaid fenced code blocks. Each Mermaid \ -statement MUST be on its own line — NEVER use semicolons to join multiple statements \ -on one line. For line breaks inside node labels, use
<br> (NOT <br/>
). -- When including mathematical formulas or equations, ALWAYS use LaTeX notation. \ -NEVER use backtick code spans or Unicode symbols for math.""" - -# ─── Standard Report Footer ───────────────────────────────────────────────── -# Appended to every generated report after content generation. - -_REPORT_FOOTER = "Powered by SurfSense AI." - -# ─── Prompt: Single-Shot Report Generation ─────────────────────────────────── - -_REPORT_PROMPT = """You are an expert report writer. Generate a comprehensive Markdown report. - -**Topic:** {topic} -**Report Style:** {report_style} -{user_instructions_section} -{previous_version_section} - -**Source Content:** -{source_content} - ---- - -{length_instruction} - -Write a well-structured Markdown report with a # title, executive summary, organized sections, and conclusion. Cite facts from the source content. Be thorough and professional. - -{formatting_rules} -""" - -# ─── Prompt: Full-Document Revision (fallback when section-level fails) ────── - -_REVISION_PROMPT = """You are an expert report editor. Apply ONLY the requested changes — do NOT rewrite from scratch. - -**Topic:** {topic} -**Report Style:** {report_style} -**Modification Instructions:** {user_instructions_section} - -**Source Content (use if relevant):** -{source_content} - ---- - -**EXISTING REPORT:** - -{previous_report_content} - ---- - -{length_instruction} - -Preserve all structure and content not affected by the modification. - -{formatting_rules} -""" - -# ─── Prompt: Section-Level Revision — Identify Affected Sections ───────────── - -_IDENTIFY_SECTIONS_PROMPT = """You are analyzing a Markdown report to determine which sections need modification based on the user's request. - -**User's Modification Request:** {user_instructions} - -**Report Sections (indexed starting at 0):** -{sections_listing} - ---- - -Determine which sections need to be modified, added, or removed to fulfill the user's request. 
- -Return ONLY a JSON object with these fields: -- "modify": Array of section indices (0-based) that need content changes -- "add": Array of objects like {{"after_index": 2, "heading": "## New Section Title", "description": "What this section should cover"}} for new sections to insert -- "remove": Array of section indices to remove entirely (use sparingly) -- "reasoning": A brief explanation of your decisions - -Guidelines: -- If the change is GLOBAL (e.g., "change the tone", "make the whole report shorter", "translate to Spanish"), include ALL section indices in "modify". -- If the change is TARGETED (e.g., "expand the budget section", "fix the conclusion"), include ONLY the affected section indices. -- For "add a section about X", use the "add" field with the appropriate insertion point. -- Prefer modifying over removing+adding when possible. - -Return ONLY valid JSON, no markdown fences: -""" - -# ─── Prompt: Section-Level Revision — Revise a Single Section ──────────────── - -_REVISE_SECTION_PROMPT = """Revise ONLY this section based on the instructions. If the instructions don't apply, return it UNCHANGED. - -**Modification Instructions:** {user_instructions} - -**Current Section:** -{section_content} - -**Context (surrounding sections — for coherence only, do NOT output them):** -{context_sections} - -**Source Content:** -{source_content} - ---- - -Keep the same heading and heading level. Preserve content not affected by the modification. -{formatting_rules} -""" - -# ─── Prompt: New Section Generation (for section-level add) ───────────────── - -_NEW_SECTION_PROMPT = """You are an expert report writer. Write a new section to be inserted into an existing report. - -**Report Topic:** {topic} -**Report Style:** {report_style} -**Section Heading:** {heading} -**Section Goal:** {description} -**User Instructions:** {user_instructions} - -**Surrounding Context:** -{context_sections} - -**Source Content:** -{source_content} - ---- - -**Rules:** -1. 
Write ONLY this section, starting with the heading "{heading}". -2. Ensure the section flows naturally with the surrounding context. -3. Be comprehensive — cover the topic described above. -{formatting_rules} - -Write the new section now: -""" - - -# ─── Utility Functions ────────────────────────────────────────────────────── - - -def _strip_wrapping_code_fences(text: str) -> str: - """Remove wrapping code fences that LLMs often add around Markdown output. - - Handles patterns like: - ```markdown\\n...content...\\n``` - ````markdown\\n...content...\\n```` - ```md\\n...content...\\n``` - ```\\n...content...\\n``` - ```json\\n...content...\\n``` - Supports 3 or more backticks (LLMs escalate when content has triple-backtick blocks). - """ - stripped = text.strip() - # Match opening fence with 3+ backticks and optional language tag - m = re.match(r"^(`{3,})(?:markdown|md|json)?\s*\n", stripped) - if m: - fence = m.group(1) # e.g. "```" or "````" - if stripped.endswith(fence): - stripped = stripped[m.end() :] # remove opening fence - stripped = stripped[: -len(fence)].rstrip() # remove closing fence - return stripped - - -def _extract_metadata(content: str) -> dict[str, Any]: - """Extract metadata from generated Markdown content.""" - # Count section headings - headings = re.findall(r"^(#{1,6})\s+(.+)$", content, re.MULTILINE) - - # Word count - word_count = len(content.split()) - - # Character count - char_count = len(content) - - return { - "status": "ready", - "word_count": word_count, - "char_count": char_count, - "section_count": len(headings), - } - - -def _parse_sections(content: str) -> list[dict[str, str]]: - """Parse Markdown content into sections split by # and ## headings. - - Returns a list of dicts: [{"heading": "## Title", "body": "content..."}, ...] - Content before the first heading is captured with heading="". - ### and deeper headings are kept inside their parent ## section's body. 
- """ - lines = content.split("\n") - sections: list[dict[str, str]] = [] - current_heading = "" - current_body_lines: list[str] = [] - in_code_block = False - - for line in lines: - # Track code blocks to avoid matching headings inside them - stripped = line.strip() - if stripped.startswith("```"): - in_code_block = not in_code_block - - # Only split on # or ## headings (not ### or deeper) and only outside code blocks - is_section_heading = ( - not in_code_block - and re.match(r"^#{1,2}\s+", line) - and not re.match(r"^#{3,}\s+", line) - ) - - if is_section_heading: - # Save previous section - if current_heading or current_body_lines: - sections.append( - { - "heading": current_heading, - "body": "\n".join(current_body_lines).strip(), - } - ) - current_heading = line.strip() - current_body_lines = [] - else: - current_body_lines.append(line) - - # Save last section - if current_heading or current_body_lines: - sections.append( - { - "heading": current_heading, - "body": "\n".join(current_body_lines).strip(), - } - ) - - return sections - - -def _stitch_sections(sections: list[dict[str, str]]) -> str: - """Stitch parsed sections back into a single Markdown string.""" - parts = [] - for section in sections: - if section["heading"]: - parts.append(section["heading"]) - if section["body"]: - parts.append(section["body"]) - return "\n\n".join(parts) - - -# ─── Async Generation Helpers ─────────────────────────────────────────────── - - -async def _revise_with_sections( - llm: Any, - parent_content: str, - user_instructions: str, - source_content: str, - topic: str, - report_style: str, -) -> str | None: - """Section-level revision: identify affected sections and revise only those. - - Unchanged sections are kept byte-for-byte identical. - Returns the revised content, or None to trigger full-document revision fallback. 
- """ - # Parse report into sections - sections = _parse_sections(parent_content) - if len(sections) < 2: - logger.info( - "[generate_report] Too few sections for section-level revision, using full revision" - ) - return None - - # Build a sections listing for the LLM - sections_listing = "" - for i, sec in enumerate(sections): - heading = sec["heading"] or "(preamble — content before first heading)" - body_preview = ( - sec["body"][:200] + "..." if len(sec["body"]) > 200 else sec["body"] - ) - sections_listing += f"\n[{i}] {heading}\n Preview: {body_preview}\n" - - # Step 1: Ask LLM which sections need modification - identify_prompt = _IDENTIFY_SECTIONS_PROMPT.format( - user_instructions=user_instructions, - sections_listing=sections_listing, - ) - - try: - response = await llm.ainvoke([HumanMessage(content=identify_prompt)]) - raw = response.content - if not raw or not isinstance(raw, str): - return None - - raw = _strip_wrapping_code_fences(raw).strip() - json_match = re.search(r"\{[\s\S]*\}", raw) - if json_match: - raw = json_match.group(0) - - plan = json.loads(raw) - modify_indices: list[int] = plan.get("modify", []) - add_sections: list[dict[str, Any]] = plan.get("add", []) - remove_indices: list[int] = plan.get("remove", []) - reasoning = plan.get("reasoning", "") - - logger.info( - f"[generate_report] Section-level revision plan: " - f"modify={modify_indices}, add={len(add_sections)}, " - f"remove={remove_indices}, reasoning={reasoning}" - ) - except Exception: - logger.warning( - "[generate_report] Failed to identify sections for revision, " - "falling back to full revision", - exc_info=True, - ) - return None - - # If ALL sections need modification, full revision is more efficient and coherent - if len(modify_indices) >= len(sections): - logger.info( - "[generate_report] All sections need modification, deferring to full revision" - ) - return None - - # Compute total operations for progress tracking - total_ops = len(modify_indices) + len(add_sections) 
- current_op = 0 - - # Emit plan summary - parts = [] - if modify_indices: - parts.append( - f"modifying {len(modify_indices)} section{'s' if len(modify_indices) > 1 else ''}" - ) - if add_sections: - parts.append( - f"adding {len(add_sections)} new section{'s' if len(add_sections) > 1 else ''}" - ) - if remove_indices: - parts.append( - f"removing {len(remove_indices)} section{'s' if len(remove_indices) > 1 else ''}" - ) - plan_summary = ", ".join(parts) if parts else "no changes needed" - - dispatch_custom_event( - "report_progress", - { - "phase": "revision_plan", - "message": plan_summary.capitalize(), - "modify_count": len(modify_indices), - "add_count": len(add_sections), - "remove_count": len(remove_indices), - "total_ops": total_ops, - }, - ) - - # Step 2: Revise only the affected sections - revised_sections = list(sections) # shallow copy — unmodified sections stay as-is - - for idx in modify_indices: - if idx < 0 or idx >= len(sections): - continue - - current_op += 1 - sec = sections[idx] - - # Extract plain section name (strip markdown heading markers) - section_name = ( - re.sub(r"^#+\s*", "", sec["heading"]).strip() - if sec["heading"] - else "Preamble" - ) - dispatch_custom_event( - "report_progress", - { - "phase": "revising_section", - "message": f"Revising: {section_name} ({current_op}/{total_ops})...", - }, - ) - - section_content = ( - f"{sec['heading']}\n\n{sec['body']}" if sec["heading"] else sec["body"] - ) - - # Build context from surrounding sections - context_parts = [] - if idx > 0: - prev = sections[idx - 1] - prev_preview = prev["body"][:300] + ( - "..." if len(prev["body"]) > 300 else "" - ) - context_parts.append( - f"**Previous section:** {prev['heading']}\n{prev_preview}" - ) - if idx < len(sections) - 1: - nxt = sections[idx + 1] - nxt_preview = nxt["body"][:300] + ("..." 
if len(nxt["body"]) > 300 else "") - context_parts.append(f"**Next section:** {nxt['heading']}\n{nxt_preview}") - context = ( - "\n\n".join(context_parts) if context_parts else "(No surrounding sections)" - ) - - revise_prompt = _REVISE_SECTION_PROMPT.format( - user_instructions=user_instructions, - section_content=section_content, - context_sections=context, - source_content=source_content[:40000], - formatting_rules=_FORMATTING_RULES, - ) - - resp = await llm.ainvoke([HumanMessage(content=revise_prompt)]) - revised_text = resp.content - if revised_text and isinstance(revised_text, str): - revised_text = _strip_wrapping_code_fences(revised_text).strip() - # Parse the LLM output back into heading + body - revised_parsed = _parse_sections(revised_text) - if revised_parsed: - revised_sections[idx] = revised_parsed[0] - else: - revised_sections[idx] = { - "heading": sec["heading"], - "body": revised_text, - } - - logger.info(f"[generate_report] Revised section [{idx}]: {sec['heading']}") - - # Step 3: Handle new section additions (insert in reverse order to preserve indices) - for add_info in sorted( - add_sections, - key=lambda x: x.get("after_index", len(revised_sections) - 1), - reverse=True, - ): - current_op += 1 - after_idx = add_info.get("after_index", len(revised_sections) - 1) - heading = add_info.get("heading", "## New Section") - description = add_info.get("description", "") - - # Extract plain section name for progress display - plain_heading = re.sub(r"^#+\s*", "", heading).strip() - dispatch_custom_event( - "report_progress", - { - "phase": "adding_section", - "message": f"Adding: {plain_heading} ({current_op}/{total_ops})...", - }, - ) - - # Build context from the surrounding sections at the insertion point - ctx_parts = [] - if 0 <= after_idx < len(revised_sections): - before_sec = revised_sections[after_idx] - ctx_parts.append( - f"**Section before:** {before_sec['heading']}\n{before_sec['body'][:300]}" - ) - insert_idx = min(after_idx + 1, 
len(revised_sections)) - if insert_idx < len(revised_sections): - after_sec = revised_sections[insert_idx] - ctx_parts.append( - f"**Section after:** {after_sec['heading']}\n{after_sec['body'][:300]}" - ) - - new_prompt = _NEW_SECTION_PROMPT.format( - topic=topic, - report_style=report_style, - heading=heading, - description=description, - user_instructions=user_instructions, - context_sections="\n\n".join(ctx_parts) if ctx_parts else "(None)", - source_content=source_content[:30000], - formatting_rules=_FORMATTING_RULES, - ) - - resp = await llm.ainvoke([HumanMessage(content=new_prompt)]) - new_content = resp.content - if new_content and isinstance(new_content, str): - new_content = _strip_wrapping_code_fences(new_content).strip() - new_parsed = _parse_sections(new_content) - if new_parsed: - revised_sections.insert(insert_idx, new_parsed[0]) - else: - revised_sections.insert( - insert_idx, - { - "heading": heading, - "body": new_content, - }, - ) - - logger.info( - f"[generate_report] Added new section after [{after_idx}]: {heading}" - ) - - # Step 4: Handle removals (reverse order to preserve indices) - for idx in sorted(remove_indices, reverse=True): - if 0 <= idx < len(revised_sections): - logger.info( - f"[generate_report] Removed section [{idx}]: " - f"{revised_sections[idx]['heading']}" - ) - revised_sections.pop(idx) - - return _stitch_sections(revised_sections) - - -# ─── Tool Factory ─────────────────────────────────────────────────────────── - - -def create_generate_report_tool( - search_space_id: int, - thread_id: int | None = None, - connector_service: ConnectorService | None = None, - available_connectors: list[str] | None = None, - available_document_types: list[str] | None = None, -): - """ - Factory function to create the generate_report tool with injected dependencies. - - The tool generates a Markdown report inline using the search space's - document summary LLM, saves it to the database, and returns immediately. 
- - Uses short-lived database sessions for each DB operation so no connection - is held during the long LLM API call. - - Generation strategies: - - New reports: single-shot generation (1 LLM call) - - Revisions (targeted edits): section-level (unchanged sections preserved) - - Revisions (global changes): full-document revision fallback - - Source strategies: - - "provided"/"conversation": use only the supplied source_content - - "kb_search": search the knowledge base internally using targeted queries - - "auto": use source_content if sufficient, otherwise fall back to KB search - - Args: - search_space_id: The user's search space ID - thread_id: The chat thread ID for associating the report - connector_service: Optional connector service for internal KB search. - When provided, the tool can search the knowledge base internally - (used by the "kb_search" and "auto" source strategies). - available_connectors: Optional list of connector types available in the - search space (used to scope internal KB searches). - - Returns: - A configured tool function for generating reports - """ - - @tool - async def generate_report( - topic: str, - source_content: str = "", - source_strategy: str = "provided", - search_queries: list[str] | None = None, - report_style: str = "detailed", - user_instructions: str | None = None, - parent_report_id: int | None = None, - ) -> dict[str, Any]: - """ - Generate a structured Markdown report artifact from provided content. - - Use this tool when the user asks to create, generate, write, produce, - draft, or summarize into a report-style deliverable. - - Trigger classes include: - - Direct trigger words WITH creation/modification verb: report, - document, memo, letter, template, article, guide, blog post, - one-pager, briefing, comprehensive guide. - - Creation-intent phrases: "write a report", "generate a document", - "draft a summary", "create an executive summary". 
- - Modification-intent phrases: "revise the report", "update the - report", "make it shorter", "add a section about X", "expand the - budget section", "rewrite in formal tone". - - IMPORTANT — what does NOT count as "asking for a report": - - Questions or discussion about a report or its topic are NOT report - requests. Respond to these conversationally in chat. - Examples: "What other examples to put there?", "What else could be - added?", "Can you explain section 2?", "Is the data accurate?", - "What's missing?", "How could this be improved?", "What other - topics are related?" - - Quick summary requests, explanations, or follow-up questions. - - The test: Does the message contain a creation/modification VERB - (write, create, generate, draft, add, revise, update, expand, - rewrite, make) directed at producing a deliverable? If no verb - → answer in chat. - - FORMAT/EXPORT RULE: - - Always generate the report content in Markdown. - - If the user requests DOCX/Word/PDF or another file format, export - from the generated Markdown report. - - SOURCE STRATEGY (how to collect source material): - - source_strategy="conversation" — The conversation already has - enough context (prior Q&A, filesystem exploration, pasted text, - uploaded files, scraped webpages). Pass a thorough summary as - source_content. - - source_strategy="kb_search" — Search the knowledge base - internally. Provide 1-5 targeted search_queries. The tool - handles searching internally — do NOT manually read and dump - /documents/ files into source_content. - - source_strategy="provided" — Use only what is in source_content - (default, backward-compatible). - - source_strategy="auto" — Use source_content if it has enough - material; otherwise fall back to internal KB search using - search_queries. 
- - CONVERSATION REUSE (HIGH PRIORITY): - - If the user has been asking questions in this chat and the - conversation contains substantive answers/discussion on the - topic, prefer source_strategy="conversation" with a thorough - summary of the full chat history as source_content. - - The user's prior questions and your answers ARE the source - material. Do NOT redundantly search the knowledge base for - information that is already in the chat. - - VERSIONING — parent_report_id: - - Set parent_report_id when the user wants to MODIFY, REVISE, - IMPROVE, UPDATE, EXPAND, or ADD CONTENT TO an existing report - that was already generated in this conversation. - - This includes both explicit AND implicit modification requests. - If the user references the existing report using words like "it", - "this", "here", "the report", or clearly refers to a previously - generated report, treat it as a revision request. - - The value must be the report_id from a previous generate_report - result in this same conversation. - - Do NOT set parent_report_id when: - * The user asks for a report on a completely NEW/DIFFERENT topic - * The user says "generate another report" (new report, not revision) - * There is no prior report to reference - - Examples of when to SET parent_report_id: - User: "Make that report shorter" → parent_report_id = - User: "Add a cost analysis section to the report" → parent_report_id = - User: "Rewrite the report in a more formal tone" → parent_report_id = - User: "I want more details about pricing in here" → parent_report_id = - User: "Include more examples" → parent_report_id = - User: "Can you also cover nutrition in this?" 
→ parent_report_id = - User: "Make it more detailed" → parent_report_id = - User: "Not bad, but expand on the budget section" → parent_report_id = - User: "Also mention the competitor landscape" → parent_report_id = - - Examples of when to LEAVE parent_report_id as None: - User: "Generate a report on climate change" → None (new topic) - User: "Write me a report about the budget" → None (new topic) - User: "Create another report, this time about marketing" → None - User: "Now write one about travel trends in Europe" → None (new topic) - - Args: - topic: Short title for the report (max ~8 words). - source_content: Text to base the report on. Can be empty when - using source_strategy="kb_search". - source_strategy: How to collect source material. One of - "provided", "conversation", "kb_search", or "auto". - search_queries: When source_strategy is "kb_search" or "auto", - provide 1-5 targeted search queries for the knowledge base. - These should be specific, not just the topic repeated. - report_style: "detailed", "deep_research", or "brief". - user_instructions: Optional focus or modification instructions. - When revising (parent_report_id set), describe WHAT TO CHANGE. - parent_report_id: ID of a previous report to revise (creates new - version in the same version group). - - Returns: - Dict with status, report_id, title, word_count, and message. 
- """ - # Initialize version tracking variables (used by _save_failed_report closure) - parent_report_content: str | None = None - report_group_id: int | None = None - - async def _save_failed_report(error_msg: str) -> int | None: - """Persist a failed report row using a short-lived session.""" - try: - async with shielded_async_session() as session: - failed_report = Report( - title=topic, - content=None, - report_metadata={ - "status": "failed", - "error_message": error_msg, - }, - report_style=report_style, - search_space_id=search_space_id, - thread_id=thread_id, - report_group_id=report_group_id, - ) - session.add(failed_report) - await session.commit() - await session.refresh(failed_report) - # If this is a new group (v1 failed), set group to self - if not failed_report.report_group_id: - failed_report.report_group_id = failed_report.id - await session.commit() - logger.info( - f"[generate_report] Saved failed report {failed_report.id}: {error_msg}" - ) - return failed_report.id - except Exception: - logger.exception( - "[generate_report] Could not persist failed report row" - ) - return None - - try: - # ── Phase 1: READ (short-lived session) ────────────────────── - # Fetch parent report and LLM config, then close the session - # so no DB connection is held during the long LLM call. 
- async with shielded_async_session() as read_session: - if parent_report_id: - parent_report = await read_session.get(Report, parent_report_id) - if parent_report: - report_group_id = parent_report.report_group_id - parent_report_content = parent_report.content - logger.info( - f"[generate_report] Creating new version from parent {parent_report_id} " - f"(group {report_group_id})" - ) - else: - logger.warning( - f"[generate_report] parent_report_id={parent_report_id} not found, " - "creating standalone report" - ) - - llm = await get_document_summary_llm(read_session, search_space_id) - # read_session closed — connection returned to pool - - if not llm: - error_msg = ( - "No LLM configured. Please configure a language model in Settings." - ) - report_id = await _save_failed_report(error_msg) - return { - "status": "failed", - "error": error_msg, - "report_id": report_id, - "title": topic, - } - - # Build the user instructions string - user_instructions_section = "" - if user_instructions: - user_instructions_section = ( - f"**Additional Instructions:** {user_instructions}" - ) - - # ── Phase 1b: SOURCE COLLECTION (smart KB search) ──────────── - # Decide whether to augment source_content with KB search results. - effective_source = source_content or "" - - strategy = (source_strategy or "provided").lower().strip() - - needs_kb_search = False - if strategy == "kb_search": - needs_kb_search = True - elif strategy == "auto": - # Heuristic: if source_content has fewer than 200 words, - # it's likely insufficient — augment with KB search. 
- word_count_estimate = len(effective_source.split()) - if word_count_estimate < 200: - needs_kb_search = True - logger.info( - f"[generate_report] auto strategy: source has ~{word_count_estimate} words, " - "triggering KB search" - ) - # "provided" and "conversation" → use source_content as-is - - if needs_kb_search and connector_service and search_queries: - query_count = min(len(search_queries), 5) - dispatch_custom_event( - "report_progress", - { - "phase": "kb_search", - "message": f"Searching knowledge base ({query_count} queries)...", - }, - ) - logger.info( - f"[generate_report] Running internal KB search with " - f"{query_count} queries: {search_queries[:5]}" - ) - try: - from .knowledge_base import search_knowledge_base_async - - # Run all queries in parallel, each with its own session - async def _run_single_query(q: str) -> str: - async with shielded_async_session() as kb_session: - kb_connector_svc = ConnectorService( - kb_session, search_space_id - ) - return await search_knowledge_base_async( - query=q, - search_space_id=search_space_id, - db_session=kb_session, - connector_service=kb_connector_svc, - top_k=10, - available_connectors=available_connectors, - available_document_types=available_document_types, - ) - - kb_results = await asyncio.gather( - *[_run_single_query(q) for q in search_queries[:5]] - ) - - # Merge non-empty results into source_content - kb_text_parts = [r for r in kb_results if r and r.strip()] - if kb_text_parts: - kb_combined = "\n\n---\n\n".join(kb_text_parts) - if effective_source.strip(): - effective_source = ( - effective_source - + "\n\n--- Knowledge Base Search Results ---\n\n" - + kb_combined - ) - else: - effective_source = kb_combined - - # Count docs found (rough: count tags) - doc_count = kb_combined.count("") - dispatch_custom_event( - "report_progress", - { - "phase": "kb_search_done", - "message": f"Found {doc_count} relevant documents" - if doc_count - else f"Found results from {len(kb_text_parts)} queries", - }, 
- ) - logger.info( - f"[generate_report] KB search added ~{len(kb_combined)} chars " - f"from {len(kb_text_parts)} queries" - ) - else: - dispatch_custom_event( - "report_progress", - { - "phase": "kb_search_done", - "message": "No results found in knowledge base", - }, - ) - logger.info("[generate_report] KB search returned no results") - - except Exception as e: - logger.warning( - f"[generate_report] Internal KB search failed: {e}. " - "Proceeding with existing source_content." - ) - elif needs_kb_search and not connector_service: - logger.warning( - "[generate_report] KB search requested but connector_service " - "not available. Using source_content as-is." - ) - elif needs_kb_search and not search_queries: - logger.warning( - "[generate_report] KB search requested but no search_queries " - "provided. Using source_content as-is." - ) - - capped_source = effective_source[:100000] # Cap source content - - # Length constraint — only when user explicitly asks for brevity - length_instruction = "" - if report_style == "brief": - length_instruction = ( - "**LENGTH CONSTRAINT (MANDATORY):** The user wants a SHORT report. " - "Keep it concise — aim for ~400 words (~1 page) unless a different " - "length is specified in the Additional Instructions above. " - "Prioritize brevity over thoroughness. Do NOT write a long report." - ) - - # ── Phase 2: LLM GENERATION (no DB connection held) ────────── - - report_content: str | None = None - - if parent_report_content: - # ─── REVISION MODE ─────────────────────────────────────── - # Strategy: Try section-level revision first (preserves - # unchanged sections byte-for-byte). Falls back to full- - # document revision if section identification fails or if - # all sections need changes. 
- dispatch_custom_event( - "report_progress", - { - "phase": "revision_start", - "message": "Analyzing sections to modify...", - }, - ) - logger.info( - "[generate_report] Revision mode — attempting section-level revision" - ) - report_content = await _revise_with_sections( - llm=llm, - parent_content=parent_report_content, - user_instructions=user_instructions - or "Improve and refine the report.", - source_content=capped_source, - topic=topic, - report_style=report_style, - ) - - if report_content is None: - # Fallback: full-document revision - dispatch_custom_event( - "report_progress", - {"phase": "writing", "message": "Rewriting your full report"}, - ) - logger.info( - "[generate_report] Section-level revision deferred, " - "using full-document revision" - ) - prompt = _REVISION_PROMPT.format( - topic=topic, - report_style=report_style, - user_instructions_section=user_instructions_section - or "Improve and refine the report.", - source_content=capped_source, - previous_report_content=parent_report_content, - length_instruction=length_instruction, - formatting_rules=_FORMATTING_RULES, - ) - response = await llm.ainvoke([HumanMessage(content=prompt)]) - report_content = response.content - - else: - # ─── NEW REPORT MODE ───────────────────────────────────── - # Single-shot generation: one LLM call produces the full - # report. Fast, globally coherent, and cost-efficient. 
- dispatch_custom_event( - "report_progress", - {"phase": "writing", "message": "Writing your report"}, - ) - logger.info( - "[generate_report] New report — using single-shot generation" - ) - prompt = _REPORT_PROMPT.format( - topic=topic, - report_style=report_style, - user_instructions_section=user_instructions_section, - previous_version_section="", - source_content=capped_source, - length_instruction=length_instruction, - formatting_rules=_FORMATTING_RULES, - ) - response = await llm.ainvoke([HumanMessage(content=prompt)]) - report_content = response.content - - # ── Validate LLM output ────────────────────────────────────── - - if not report_content or not isinstance(report_content, str): - error_msg = "LLM returned empty or invalid content" - report_id = await _save_failed_report(error_msg) - return { - "status": "failed", - "error": error_msg, - "report_id": report_id, - "title": topic, - } - - # LLMs often wrap output in ```markdown ... ``` fences — strip them - report_content = _strip_wrapping_code_fences(report_content) - - if not report_content: - error_msg = "LLM returned empty or invalid content" - report_id = await _save_failed_report(error_msg) - return { - "status": "failed", - "error": error_msg, - "report_id": report_id, - "title": topic, - } - - # Strip any existing footer(s) carried over from parent version(s) - while report_content.rstrip().endswith(_REPORT_FOOTER): - idx = report_content.rstrip().rfind(_REPORT_FOOTER) - report_content = report_content[:idx].rstrip() - if report_content.rstrip().endswith("---"): - report_content = report_content.rstrip()[:-3].rstrip() - - # Append exactly one standard disclaimer - report_content += "\n\n---\n\n" + _REPORT_FOOTER - - # Extract metadata (includes "status": "ready") - metadata = _extract_metadata(report_content) - - # ── Phase 3: WRITE (short-lived session) ───────────────────── - # Save the report to the database, then close the session. 
- async with shielded_async_session() as write_session: - report = Report( - title=topic, - content=report_content, - report_metadata=metadata, - report_style=report_style, - search_space_id=search_space_id, - thread_id=thread_id, - report_group_id=report_group_id, - ) - write_session.add(report) - await write_session.commit() - await write_session.refresh(report) - - # If this is a brand-new report (v1), set report_group_id = own id - if not report.report_group_id: - report.report_group_id = report.id - await write_session.commit() - - saved_report_id = report.id - saved_group_id = report.report_group_id - # write_session closed — connection returned to pool - - logger.info( - f"[generate_report] Created report {saved_report_id} " - f"(group={saved_group_id}): " - f"{metadata.get('word_count', 0)} words, " - f"{metadata.get('section_count', 0)} sections" - ) - - return { - "status": "ready", - "report_id": saved_report_id, - "title": topic, - "word_count": metadata.get("word_count", 0), - "is_revision": bool(parent_report_content), - "report_markdown": report_content, - "message": f"Report generated successfully: {topic}", - } - - except Exception as e: - error_message = str(e) - logger.exception(f"[generate_report] Error: {error_message}") - report_id = await _save_failed_report(error_message) - - return { - "status": "failed", - "error": error_message, - "report_id": report_id, - "title": topic, - } - - return generate_report diff --git a/surfsense_backend/app/agents/shared/tools/resume.py b/surfsense_backend/app/agents/shared/tools/resume.py deleted file mode 100644 index 4abe48ba6..000000000 --- a/surfsense_backend/app/agents/shared/tools/resume.py +++ /dev/null @@ -1,812 +0,0 @@ -""" -Resume generation tool for the SurfSense agent. - -Generates a structured resume as Typst source code using the rendercv package. 
-The LLM outputs only the content body (= heading, sections, entries) while -the template header (import + show rule) is hardcoded and prepended by the -backend. This eliminates LLM errors in the complex configuration block. - -Templates are stored in a registry so new designs can be added by defining -a new entry in _TEMPLATES. - -Uses the same short-lived session pattern as generate_report so no DB -connection is held during the long LLM call. -""" - -import io -import logging -import re -from datetime import UTC, datetime -from typing import Any - -import pypdf -import typst -from langchain_core.callbacks import dispatch_custom_event -from langchain_core.messages import HumanMessage -from langchain_core.tools import tool - -from app.db import Report, shielded_async_session -from app.services.llm_service import get_document_summary_llm - -logger = logging.getLogger(__name__) - - -# ─── Template Registry ─────────────────────────────────────────────────────── -# Each template defines: -# header - Typst import + show rule with {name}, {year}, {month}, {day} placeholders -# component_reference - component docs shown to the LLM -# rules - generation rules for the LLM - -_TEMPLATES: dict[str, dict[str, str]] = { - "classic": { - "header": """\ -#import "@preview/rendercv:0.3.0": * - -#show: rendercv.with( - name: "{name}", - title: "{name} - Resume", - footer: context {{ [#emph[{name} -- #str(here().page())\\/#str(counter(page).final().first())]] }}, - top-note: [ #emph[Last updated in {month_name} {year}] ], - locale-catalog-language: "en", - text-direction: ltr, - page-size: "us-letter", - page-top-margin: 0.7in, - page-bottom-margin: 0.7in, - page-left-margin: 0.7in, - page-right-margin: 0.7in, - page-show-footer: false, - page-show-top-note: true, - colors-body: rgb(0, 0, 0), - colors-name: rgb(0, 0, 0), - colors-headline: rgb(0, 0, 0), - colors-connections: rgb(0, 0, 0), - colors-section-titles: rgb(0, 0, 0), - colors-links: rgb(0, 0, 0), - colors-footer: 
rgb(128, 128, 128), - colors-top-note: rgb(128, 128, 128), - typography-line-spacing: 0.6em, - typography-alignment: "justified", - typography-date-and-location-column-alignment: right, - typography-font-family-body: "XCharter", - typography-font-family-name: "XCharter", - typography-font-family-headline: "XCharter", - typography-font-family-connections: "XCharter", - typography-font-family-section-titles: "XCharter", - typography-font-size-body: 10pt, - typography-font-size-name: 25pt, - typography-font-size-headline: 10pt, - typography-font-size-connections: 10pt, - typography-font-size-section-titles: 1.2em, - typography-small-caps-name: false, - typography-small-caps-headline: false, - typography-small-caps-connections: false, - typography-small-caps-section-titles: false, - typography-bold-name: false, - typography-bold-headline: false, - typography-bold-connections: false, - typography-bold-section-titles: true, - links-underline: true, - links-show-external-link-icon: false, - header-alignment: center, - header-photo-width: 3.5cm, - header-space-below-name: 0.7cm, - header-space-below-headline: 0.7cm, - header-space-below-connections: 0.7cm, - header-connections-hyperlink: true, - header-connections-show-icons: false, - header-connections-display-urls-instead-of-usernames: true, - header-connections-separator: "|", - header-connections-space-between-connections: 0.5cm, - section-titles-type: "with_full_line", - section-titles-line-thickness: 0.5pt, - section-titles-space-above: 0.5cm, - section-titles-space-below: 0.3cm, - sections-allow-page-break: true, - sections-space-between-text-based-entries: 0.15cm, - sections-space-between-regular-entries: 0.42cm, - entries-date-and-location-width: 4.15cm, - entries-side-space: 0cm, - entries-space-between-columns: 0.1cm, - entries-allow-page-break: false, - entries-short-second-row: false, - entries-degree-width: 1cm, - entries-summary-space-left: 0cm, - entries-summary-space-above: 0.08cm, - 
entries-highlights-bullet: text(13pt, [\\u{2022}], baseline: -0.6pt), - entries-highlights-nested-bullet: text(13pt, [\\u{2022}], baseline: -0.6pt), - entries-highlights-space-left: 0cm, - entries-highlights-space-above: 0.08cm, - entries-highlights-space-between-items: 0.02cm, - entries-highlights-space-between-bullet-and-text: 0.3em, - date: datetime( - year: {year}, - month: {month}, - day: {day}, - ), -) - -""", - "component_reference": """\ -Available components (use ONLY these): - -= Full Name // Top-level heading — person's full name - -#connections( // Contact info row (pipe-separated) - [City, Country], - [#link("mailto:email@example.com", icon: false, if-underline: false, if-color: false)[email\\@example.com]], - [#link("https://linkedin.com/in/user", icon: false, if-underline: false, if-color: false)[linkedin.com\\/in\\/user]], - [#link("https://github.com/user", icon: false, if-underline: false, if-color: false)[github.com\\/user]], -) - -== Section Title // Section heading (arbitrary name) - -#regular-entry( // Work experience, projects, publications, etc. - [ - #strong[Role/Title], Company Name -- Location - ], - [ - Start -- End - ], - main-column-second-row: [ - - Achievement or responsibility - - Another bullet point - ], -) - -#education-entry( // Education entries - [ - #strong[Institution], Degree in Field -- Location - ], - [ - Start -- End - ], - main-column-second-row: [ - - GPA, honours, relevant coursework - ], -) - -#summary([Short paragraph summary]) // Optional summary inside an entry -#content-area([Free-form content]) // Freeform text block - -For skills sections, use one bullet per category label: -- #strong[Category:] item1, item2, item3 - -For simple list sections (e.g. Honors), use plain bullet points: -- Item one -- Item two -""", - "rules": """\ -RULES: -- Do NOT include any #import or #show lines. Start directly with = Full Name. -- Output ONLY valid Typst content. No explanatory text before or after. 
-- Do NOT wrap output in ```typst code fences. -- The = heading MUST use the person's COMPLETE full name exactly as provided. NEVER shorten or abbreviate. -- Escape @ symbols inside link labels with a backslash: email\\@example.com -- Escape forward slashes in link display text: linkedin.com\\/in\\/user -- Every section MUST use == heading. -- Use #regular-entry() for experience, projects, publications, certifications, and similar entries. -- Use #education-entry() for education. -- For skills sections, use one bullet line per category with a bold label. -- Keep content professional, concise, and achievement-oriented. -- Use action verbs for bullet points (Led, Built, Designed, Reduced, etc.). -- This template works for ALL professions — adapt sections to the user's field. -- Default behavior should prioritize concise one-page content. -""", - }, -} - -DEFAULT_TEMPLATE = "classic" -MIN_RESUME_PAGES = 1 -MAX_RESUME_PAGES = 5 -MAX_COMPRESSION_ATTEMPTS = 2 - - -# ─── Template Helpers ───────────────────────────────────────────────────────── - - -def _get_template(template_id: str | None = None) -> dict[str, str]: - """Get a template by ID, falling back to default.""" - return _TEMPLATES.get(template_id or DEFAULT_TEMPLATE, _TEMPLATES[DEFAULT_TEMPLATE]) - - -_MONTH_NAMES = [ - "", - "Jan", - "Feb", - "Mar", - "Apr", - "May", - "Jun", - "Jul", - "Aug", - "Sep", - "Oct", - "Nov", - "Dec", -] - - -def _build_header(template: dict[str, str], name: str) -> str: - """Build the template header with the person's name and current date.""" - now = datetime.now(tz=UTC) - return ( - template["header"] - .replace("{name}", name) - .replace("{year}", str(now.year)) - .replace("{month}", str(now.month)) - .replace("{day}", str(now.day)) - .replace("{month_name}", _MONTH_NAMES[now.month]) - ) - - -def _strip_header(full_source: str) -> str: - """Strip the import + show rule from stored source to get the body only. - - Finds the closing parenthesis of the rendercv.with(...) 
block by tracking - nesting depth, then returns everything after it. - """ - show_match = re.search(r"#show:\s*rendercv\.with\(", full_source) - if not show_match: - return full_source - - start = show_match.end() - depth = 1 - i = start - while i < len(full_source) and depth > 0: - if full_source[i] == "(": - depth += 1 - elif full_source[i] == ")": - depth -= 1 - i += 1 - - return full_source[i:].lstrip("\n") - - -def _extract_name(body: str) -> str | None: - """Extract the person's full name from the = heading in the body.""" - match = re.search(r"^=\s+(.+)$", body, re.MULTILINE) - return match.group(1).strip() if match else None - - -def _strip_imports(body: str) -> str: - """Remove any #import or #show lines the LLM might accidentally include.""" - lines = body.split("\n") - cleaned: list[str] = [] - skip_show = False - depth = 0 - - for line in lines: - stripped = line.strip() - - if stripped.startswith("#import"): - continue - - if skip_show: - depth += stripped.count("(") - stripped.count(")") - if depth <= 0: - skip_show = False - continue - - if stripped.startswith("#show:") and "rendercv" in stripped: - depth = stripped.count("(") - stripped.count(")") - if depth > 0: - skip_show = True - continue - - cleaned.append(line) - - result = "\n".join(cleaned).strip() - return result - - -def _build_llm_reference(template: dict[str, str]) -> str: - """Build the LLM prompt reference from a template.""" - return f"""\ -You MUST output valid Typst content for a resume. -Do NOT include any #import or #show lines — those are handled automatically. -Start directly with the = Full Name heading. - -{template["component_reference"]} - -{template["rules"]}""" - - -# ─── Prompts ───────────────────────────────────────────────────────────────── - -_RESUME_PROMPT = """\ -You are an expert resume writer. Generate professional resume content as Typst markup. 
- -{llm_reference} - -**User Information:** -{user_info} - -**Target Maximum Pages:** {max_pages} - -{user_instructions_section} - -Generate the resume content now (starting with = Full Name): -""" - -_REVISION_PROMPT = """\ -You are an expert resume editor. Modify the existing resume according to the instructions. -Apply ONLY the requested changes — do NOT rewrite sections that are not affected. - -{llm_reference} - -**Target Maximum Pages:** {max_pages} - -**Modification Instructions:** {user_instructions} - -**EXISTING RESUME CONTENT:** - -{previous_content} - ---- - -Output the complete, updated resume content with the changes applied (starting with = Full Name): -""" - -_FIX_COMPILE_PROMPT = """\ -The resume content you generated failed to compile. Fix the error while preserving all content. - -{llm_reference} - -**Compilation Error:** -{error} - -**Full Typst Source (for context — error line numbers refer to this):** -{full_source} - -**Your content starts after the template header. Output ONLY the content portion \ -(starting with = Full Name), NOT the #import or #show rule:** -""" - -_COMPRESS_TO_PAGE_LIMIT_PROMPT = """\ -The resume compiles, but it exceeds the maximum allowed page count. -Compress the resume while preserving high-impact accomplishments and role relevance. - -{llm_reference} - -**Target Maximum Pages:** {max_pages} -**Current Page Count:** {actual_pages} -**Compression Attempt:** {attempt_number} - -Compression priorities (in this order): -1) Keep recent, high-impact, role-relevant bullets. -2) Remove low-impact or redundant bullets. -3) Shorten verbose wording while preserving meaning. -4) Trim older or less relevant details before recent ones. - -Return the complete updated Typst content (starting with = Full Name), and keep it at or below the target pages. 
- -**EXISTING RESUME CONTENT:** -{previous_content} -""" - - -# ─── Helpers ───────────────────────────────────────────────────────────────── - - -def _strip_typst_fences(text: str) -> str: - """Remove wrapping ```typst ... ``` fences that LLMs sometimes add.""" - stripped = text.strip() - m = re.match(r"^(`{3,})(?:typst|typ)?\s*\n", stripped) - if m: - fence = m.group(1) - if stripped.endswith(fence): - stripped = stripped[m.end() :] - stripped = stripped[: -len(fence)].rstrip() - return stripped - - -def _compile_typst(source: str) -> bytes: - """Compile Typst source to PDF bytes. Raises on failure.""" - return typst.compile(source.encode("utf-8")) - - -def _count_pdf_pages(pdf_bytes: bytes) -> int: - """Count the number of pages in compiled PDF bytes.""" - with io.BytesIO(pdf_bytes) as pdf_stream: - reader = pypdf.PdfReader(pdf_stream) - return len(reader.pages) - - -def _validate_max_pages(max_pages: int) -> int: - """Validate and normalize max_pages input.""" - if MIN_RESUME_PAGES <= max_pages <= MAX_RESUME_PAGES: - return max_pages - msg = ( - f"max_pages must be between {MIN_RESUME_PAGES} and " - f"{MAX_RESUME_PAGES}. Received: {max_pages}" - ) - raise ValueError(msg) - - -# ─── Tool Factory ─────────────────────────────────────────────────────────── - - -def create_generate_resume_tool( - search_space_id: int, - thread_id: int | None = None, -): - """ - Factory function to create the generate_resume tool. - - Generates a Typst-based resume, validates it via compilation, - and stores the source in the Report table with content_type='typst'. - The LLM generates only the content body; the template header is - prepended by the backend. - """ - - @tool - async def generate_resume( - user_info: str, - user_instructions: str | None = None, - parent_report_id: int | None = None, - max_pages: int = 1, - ) -> dict[str, Any]: - """ - Generate a professional resume as a Typst document. 
- - Use this tool when the user asks to create, build, generate, write, - or draft a resume or CV. Also use it when the user wants to modify, - update, or revise an existing resume generated in this conversation. - - Trigger phrases include: - - "build me a resume", "create my resume", "generate a CV" - - "update my resume", "change my title", "add my new job" - - "make my resume more concise", "reformat my resume" - - Do NOT use this tool for: - - General questions about resumes or career advice - - Reviewing or critiquing a resume without changes - - Cover letters (use generate_report instead) - - VERSIONING — parent_report_id: - - Set parent_report_id when the user wants to MODIFY an existing - resume that was already generated in this conversation. - - Leave as None for new resumes. - - Args: - user_info: The user's resume content — work experience, - education, skills, contact info, etc. Can be structured - or unstructured text. - user_instructions: Optional style or content preferences - (e.g. "emphasize leadership", "keep it to one page", - "use a modern style"). For revisions, describe what to change. - parent_report_id: ID of a previous resume to revise (creates - new version in the same version group). - max_pages: Maximum number of pages for the generated resume. - Defaults to 1. Allowed range: 1-5. - - Returns: - Dict with status, report_id, title, and content_type. 
- """ - report_group_id: int | None = None - parent_content: str | None = None - - template = _get_template() - llm_reference = _build_llm_reference(template) - - async def _save_failed_report(error_msg: str) -> int | None: - try: - async with shielded_async_session() as session: - failed = Report( - title="Resume", - content=None, - content_type="typst", - report_metadata={ - "status": "failed", - "error_message": error_msg, - }, - report_style="resume", - search_space_id=search_space_id, - thread_id=thread_id, - report_group_id=report_group_id, - ) - session.add(failed) - await session.commit() - await session.refresh(failed) - if not failed.report_group_id: - failed.report_group_id = failed.id - await session.commit() - logger.info( - f"[generate_resume] Saved failed report {failed.id}: {error_msg}" - ) - return failed.id - except Exception: - logger.exception( - "[generate_resume] Could not persist failed report row" - ) - return None - - try: - try: - validated_max_pages = _validate_max_pages(max_pages) - except ValueError as e: - error_msg = str(e) - report_id = await _save_failed_report(error_msg) - return { - "status": "failed", - "error": error_msg, - "report_id": report_id, - "title": "Resume", - "content_type": "typst", - } - - # ── Phase 1: READ ───────────────────────────────────────────── - async with shielded_async_session() as read_session: - if parent_report_id: - parent_report = await read_session.get(Report, parent_report_id) - if parent_report: - report_group_id = parent_report.report_group_id - parent_content = parent_report.content - logger.info( - f"[generate_resume] Revising from parent {parent_report_id} " - f"(group {report_group_id})" - ) - - llm = await get_document_summary_llm(read_session, search_space_id) - - if not llm: - error_msg = ( - "No LLM configured. Please configure a language model in Settings." 
- ) - report_id = await _save_failed_report(error_msg) - return { - "status": "failed", - "error": error_msg, - "report_id": report_id, - "title": "Resume", - "content_type": "typst", - } - - # ── Phase 2: LLM GENERATION ─────────────────────────────────── - - user_instructions_section = "" - if user_instructions: - user_instructions_section = ( - f"**Additional Instructions:** {user_instructions}" - ) - - if parent_content: - dispatch_custom_event( - "report_progress", - {"phase": "writing", "message": "Updating your resume"}, - ) - parent_body = _strip_header(parent_content) - prompt = _REVISION_PROMPT.format( - llm_reference=llm_reference, - max_pages=validated_max_pages, - user_instructions=user_instructions - or "Improve and refine the resume.", - previous_content=parent_body, - ) - else: - dispatch_custom_event( - "report_progress", - {"phase": "writing", "message": "Building your resume"}, - ) - prompt = _RESUME_PROMPT.format( - llm_reference=llm_reference, - user_info=user_info, - max_pages=validated_max_pages, - user_instructions_section=user_instructions_section, - ) - - response = await llm.ainvoke([HumanMessage(content=prompt)]) - body = response.content - - if not body or not isinstance(body, str): - error_msg = "LLM returned empty or invalid content" - report_id = await _save_failed_report(error_msg) - return { - "status": "failed", - "error": error_msg, - "report_id": report_id, - "title": "Resume", - "content_type": "typst", - } - - body = _strip_typst_fences(body) - body = _strip_imports(body) - - # ── Phase 3: ASSEMBLE + COMPILE ─────────────────────────────── - dispatch_custom_event( - "report_progress", - {"phase": "compiling", "message": "Compiling resume..."}, - ) - - name = _extract_name(body) or "Resume" - typst_source = "" - actual_pages = 0 - compression_attempts = 0 - target_page_met = False - - for compression_round in range(MAX_COMPRESSION_ATTEMPTS + 1): - header = _build_header(template, name) - typst_source = header + body - 
compile_error: str | None = None - pdf_bytes: bytes | None = None - - for compile_attempt in range(2): - try: - pdf_bytes = _compile_typst(typst_source) - compile_error = None - break - except Exception as e: - compile_error = str(e) - logger.warning( - "[generate_resume] Compile attempt %s failed: %s", - compile_attempt + 1, - compile_error, - ) - - if compile_attempt == 0: - dispatch_custom_event( - "report_progress", - { - "phase": "fixing", - "message": "Fixing compilation issue...", - }, - ) - fix_prompt = _FIX_COMPILE_PROMPT.format( - llm_reference=llm_reference, - error=compile_error, - full_source=typst_source, - ) - fix_response = await llm.ainvoke( - [HumanMessage(content=fix_prompt)] - ) - if fix_response.content and isinstance( - fix_response.content, str - ): - body = _strip_typst_fences(fix_response.content) - body = _strip_imports(body) - name = _extract_name(body) or name - header = _build_header(template, name) - typst_source = header + body - - if compile_error or not pdf_bytes: - error_msg = ( - "Typst compilation failed after 2 attempts: " - f"{compile_error or 'Unknown compile error'}" - ) - report_id = await _save_failed_report(error_msg) - return { - "status": "failed", - "error": error_msg, - "report_id": report_id, - "title": "Resume", - "content_type": "typst", - } - - actual_pages = _count_pdf_pages(pdf_bytes) - if actual_pages <= validated_max_pages: - target_page_met = True - break - - if compression_round >= MAX_COMPRESSION_ATTEMPTS: - break - - compression_attempts += 1 - dispatch_custom_event( - "report_progress", - { - "phase": "compressing", - "message": f"Condensing resume to {validated_max_pages} page(s)...", - }, - ) - compress_prompt = _COMPRESS_TO_PAGE_LIMIT_PROMPT.format( - llm_reference=llm_reference, - max_pages=validated_max_pages, - actual_pages=actual_pages, - attempt_number=compression_attempts, - previous_content=body, - ) - compress_response = await llm.ainvoke( - [HumanMessage(content=compress_prompt)] - ) - if not 
compress_response.content or not isinstance( - compress_response.content, str - ): - error_msg = "LLM returned empty content while compressing resume" - report_id = await _save_failed_report(error_msg) - return { - "status": "failed", - "error": error_msg, - "report_id": report_id, - "title": "Resume", - "content_type": "typst", - } - - body = _strip_typst_fences(compress_response.content) - body = _strip_imports(body) - name = _extract_name(body) or name - - if actual_pages > MAX_RESUME_PAGES: - error_msg = ( - "Resume exceeds hard page limit after compression retries. " - f"Hard limit: <= {MAX_RESUME_PAGES} page(s), actual: {actual_pages}." - ) - report_id = await _save_failed_report(error_msg) - return { - "status": "failed", - "error": error_msg, - "report_id": report_id, - "title": "Resume", - "content_type": "typst", - } - - # ── Phase 4: SAVE ───────────────────────────────────────────── - dispatch_custom_event( - "report_progress", - {"phase": "saving", "message": "Saving your resume"}, - ) - - resume_title = f"{name} - Resume" if name != "Resume" else "Resume" - - metadata: dict[str, Any] = { - "status": "ready", - "word_count": len(typst_source.split()), - "char_count": len(typst_source), - "target_max_pages": validated_max_pages, - "actual_page_count": actual_pages, - "page_limit_enforced": True, - "compression_attempts": compression_attempts, - "target_page_met": target_page_met, - } - - async with shielded_async_session() as write_session: - report = Report( - title=resume_title, - content=typst_source, - content_type="typst", - report_metadata=metadata, - report_style="resume", - search_space_id=search_space_id, - thread_id=thread_id, - report_group_id=report_group_id, - ) - write_session.add(report) - await write_session.commit() - await write_session.refresh(report) - - if not report.report_group_id: - report.report_group_id = report.id - await write_session.commit() - - saved_id = report.id - - logger.info(f"[generate_resume] Created resume 
{saved_id}: {resume_title}") - - return { - "status": "ready", - "report_id": saved_id, - "title": resume_title, - "content_type": "typst", - "is_revision": bool(parent_content), - "message": ( - f"Resume generated successfully: {resume_title}" - if target_page_met - else ( - f"Resume generated, but could not fit the target of <= {validated_max_pages} " - f"page(s). Final length: {actual_pages} page(s)." - ) - ), - } - - except Exception as e: - error_message = str(e) - logger.exception(f"[generate_resume] Error: {error_message}") - report_id = await _save_failed_report(error_message) - return { - "status": "failed", - "error": error_message, - "report_id": report_id, - "title": "Resume", - "content_type": "typst", - } - - return generate_resume diff --git a/surfsense_backend/app/agents/shared/tools/video_presentation.py b/surfsense_backend/app/agents/shared/tools/video_presentation.py deleted file mode 100644 index 34f5183ca..000000000 --- a/surfsense_backend/app/agents/shared/tools/video_presentation.py +++ /dev/null @@ -1,138 +0,0 @@ -""" -Video presentation generation tool for the SurfSense agent. - -This module provides a factory function for creating the generate_video_presentation -tool that submits a Celery task for background video presentation generation. The -tool then polls the row until it reaches a terminal status (READY/FAILED) and -returns that status. The wait is bounded by the chat's HTTP / process lifetime; -see app.agents.shared.deliverable_wait for details. 
-""" - -import logging -from typing import Any - -from langchain_core.tools import tool -from sqlalchemy.ext.asyncio import AsyncSession - -from app.agents.shared.deliverable_wait import wait_for_deliverable -from app.db import VideoPresentation, VideoPresentationStatus, shielded_async_session - -logger = logging.getLogger(__name__) - - -def create_generate_video_presentation_tool( - search_space_id: int, - db_session: AsyncSession, - thread_id: int | None = None, -): - """ - Factory function to create the generate_video_presentation tool with injected dependencies. - - Pre-creates video presentation record with pending status so the ID is available - immediately for frontend polling. The row is written via a fresh, tool-local - session so parallel tool calls (e.g. video + podcast in the same agent step) - don't share an ``AsyncSession`` (which is not concurrency-safe). - """ - del db_session # writes use a fresh tool-local session, see below - - @tool - async def generate_video_presentation( - source_content: str, - video_title: str = "SurfSense Presentation", - user_prompt: str | None = None, - ) -> dict[str, Any]: - """Generate a video presentation from the provided content. - - Use this tool when the user asks to create a video, presentation, slides, or slide deck. - - Args: - source_content: The text content to turn into a presentation. - video_title: Title for the presentation (default: "SurfSense Presentation") - user_prompt: Optional style/tone instructions. - """ - try: - # See podcast.py for the rationale: parallel tool calls share the - # streaming session, and AsyncSession is not concurrency-safe — - # interleaved flushes produce "Session.add() during flush" and - # poison the transaction for every concurrent tool. 
- async with shielded_async_session() as session: - video_pres = VideoPresentation( - title=video_title, - status=VideoPresentationStatus.PENDING, - search_space_id=search_space_id, - thread_id=thread_id, - ) - session.add(video_pres) - await session.commit() - await session.refresh(video_pres) - video_pres_id = video_pres.id - - from app.tasks.celery_tasks.video_presentation_tasks import ( - generate_video_presentation_task, - ) - - task = generate_video_presentation_task.delay( - video_presentation_id=video_pres_id, - source_content=source_content, - search_space_id=search_space_id, - user_prompt=user_prompt, - ) - - logger.info( - "[generate_video_presentation] Created video presentation %s, task: %s", - video_pres_id, - task.id, - ) - - # Wait until the Celery worker flips the row to a terminal - # state. No internal budget — see deliverable_wait module. - terminal_status, _columns, elapsed = await wait_for_deliverable( - model=VideoPresentation, - row_id=video_pres_id, - columns=[VideoPresentation.status], - terminal_statuses={ - VideoPresentationStatus.READY, - VideoPresentationStatus.FAILED, - }, - ) - - if terminal_status == VideoPresentationStatus.READY: - logger.info( - "[generate_video_presentation] %s READY in %.2fs", - video_pres_id, - elapsed, - ) - return { - "status": VideoPresentationStatus.READY.value, - "video_presentation_id": video_pres_id, - "title": video_title, - "message": "Video presentation generated and saved.", - } - - # Only other terminal state is FAILED. - logger.warning( - "[generate_video_presentation] %s FAILED in %.2fs", - video_pres_id, - elapsed, - ) - return { - "status": VideoPresentationStatus.FAILED.value, - "video_presentation_id": video_pres_id, - "title": video_title, - "error": ( - "Background worker reported FAILED status for this " - "video presentation." 
- ), - } - - except Exception as e: - error_message = str(e) - logger.exception("[generate_video_presentation] Error: %s", error_message) - return { - "status": VideoPresentationStatus.FAILED.value, - "error": error_message, - "title": video_title, - "video_presentation_id": None, - } - - return generate_video_presentation diff --git a/surfsense_backend/app/services/provider_capabilities.py b/surfsense_backend/app/services/provider_capabilities.py index e68fd53f3..f7eafe11d 100644 --- a/surfsense_backend/app/services/provider_capabilities.py +++ b/surfsense_backend/app/services/provider_capabilities.py @@ -56,7 +56,7 @@ logger = logging.getLogger(__name__) # class-body init time. ``app.agents.shared.llm_config`` re-exports # this constant under the historical ``PROVIDER_MAP`` name; placing the # map there directly would re-introduce the -# ``app.config -> ... -> app.agents.shared.tools.generate_image -> +# ``app.config -> ... -> deliverables/tools/generate_image -> # app.config`` cycle that prompted the move. _PROVIDER_PREFIX_MAP: dict[str, str] = { "OPENAI": "openai", diff --git a/surfsense_backend/tests/unit/agents/new_chat/tools/test_resume_page_limits.py b/surfsense_backend/tests/unit/agents/new_chat/tools/test_resume_page_limits.py index 8bfcb8947..5c4c41b64 100644 --- a/surfsense_backend/tests/unit/agents/new_chat/tools/test_resume_page_limits.py +++ b/surfsense_backend/tests/unit/agents/new_chat/tools/test_resume_page_limits.py @@ -1,17 +1,58 @@ -"""Unit tests for resume page-limit helpers and enforcement flow.""" +"""Unit tests for resume page-limit helpers and enforcement flow. + +Targets the live deliverables resume tool. The tool returns a +``Command`` (payload JSON-encoded in ``update["messages"][0].content`` +plus a receipt), so flow tests invoke it via a ToolCall dict and unwrap +the payload. 
+""" import io +import json from types import SimpleNamespace from unittest.mock import AsyncMock import pypdf import pytest +from langchain.tools import ToolRuntime -from app.agents.shared.tools import resume as resume_tool +from app.agents.multi_agent_chat.subagents.builtins.deliverables.tools import ( + resume as resume_tool, +) pytestmark = pytest.mark.unit +@pytest.fixture(autouse=True) +def _silence_progress_events(monkeypatch): + """The live tool emits ``dispatch_custom_event`` progress updates that + require a langgraph run context; neutralize them for direct unit calls.""" + monkeypatch.setattr(resume_tool, "dispatch_custom_event", lambda *a, **k: None) + + +def _runtime(tool_call_id: str = "call-1") -> ToolRuntime: + """Minimal ToolRuntime; the resume tool only reads ``tool_call_id``.""" + return ToolRuntime( + state={}, + context=None, + config={}, + stream_writer=None, + tool_call_id=tool_call_id, + store=None, + ) + + +async def _invoke(tool, args: dict) -> dict: + """Drive a Command-returning tool and return its decoded payload. + + These tools take an injected ``ToolRuntime`` and return a + ``Command``; invoke the raw coroutine with a hand-built runtime + (the repo's pattern for unit-testing such tools) and decode the + ToolMessage payload. 
+ """ + command = await tool.coroutine(runtime=_runtime(), **args) + return json.loads(command.update["messages"][0].content) + + class _FakeReport: _next_id = 1000 @@ -108,7 +149,7 @@ async def test_generate_resume_defaults_to_one_page_target(monkeypatch) -> None: monkeypatch.setattr(resume_tool, "_count_pdf_pages", lambda _pdf: 1) tool = resume_tool.create_generate_resume_tool(search_space_id=1, thread_id=1) - result = await tool.ainvoke({"user_info": "Jane Doe experience"}) + result = await _invoke(tool, {"user_info": "Jane Doe experience"}) assert result["status"] == "ready" assert prompts @@ -138,7 +179,7 @@ async def test_generate_resume_compresses_when_over_limit(monkeypatch) -> None: monkeypatch.setattr(resume_tool, "_count_pdf_pages", lambda _pdf: next(page_counts)) tool = resume_tool.create_generate_resume_tool(search_space_id=1, thread_id=1) - result = await tool.ainvoke({"user_info": "Jane Doe experience", "max_pages": 1}) + result = await _invoke(tool, {"user_info": "Jane Doe experience", "max_pages": 1}) assert result["status"] == "ready" assert write_session.added, "Expected successful report write" @@ -173,7 +214,7 @@ async def test_generate_resume_returns_ready_when_target_not_met(monkeypatch) -> monkeypatch.setattr(resume_tool, "_count_pdf_pages", lambda _pdf: next(page_counts)) tool = resume_tool.create_generate_resume_tool(search_space_id=1, thread_id=1) - result = await tool.ainvoke({"user_info": "Jane Doe experience", "max_pages": 1}) + result = await _invoke(tool, {"user_info": "Jane Doe experience", "max_pages": 1}) assert result["status"] == "ready" assert "could not fit the target" in (result["message"] or "").lower() @@ -206,7 +247,7 @@ async def test_generate_resume_fails_when_hard_limit_exceeded(monkeypatch) -> No monkeypatch.setattr(resume_tool, "_count_pdf_pages", lambda _pdf: next(page_counts)) tool = resume_tool.create_generate_resume_tool(search_space_id=1, thread_id=1) - result = await tool.ainvoke({"user_info": "Jane Doe 
experience", "max_pages": 1}) + result = await _invoke(tool, {"user_info": "Jane Doe experience", "max_pages": 1}) assert result["status"] == "failed" assert "hard page limit" in (result["error"] or "").lower() diff --git a/surfsense_backend/tests/unit/services/test_image_gen_api_base_defense.py b/surfsense_backend/tests/unit/services/test_image_gen_api_base_defense.py index 575d245c2..6ba66ec57 100644 --- a/surfsense_backend/tests/unit/services/test_image_gen_api_base_defense.py +++ b/surfsense_backend/tests/unit/services/test_image_gen_api_base_defense.py @@ -20,6 +20,7 @@ from __future__ import annotations from unittest.mock import AsyncMock, MagicMock, patch import pytest +from langchain.tools import ToolRuntime pytestmark = pytest.mark.unit @@ -90,7 +91,9 @@ async def test_global_openrouter_image_gen_sets_api_base_when_config_empty(): async def test_generate_image_tool_global_sets_api_base_when_config_empty(): """Same defense at the agent tool entry point — both surfaces share the same OpenRouter config payloads.""" - from app.agents.shared.tools import generate_image as gi_module + from app.agents.multi_agent_chat.subagents.builtins.deliverables.tools import ( + generate_image as gi_module, + ) cfg = { "id": -20_001, @@ -150,7 +153,19 @@ async def test_generate_image_tool_global_sets_api_base_when_config_empty(): tool = gi_module.create_generate_image_tool( search_space_id=1, db_session=MagicMock() ) - await tool.ainvoke({"prompt": "a cat", "n": 1}) + # The live tool takes an injected ToolRuntime and returns a Command; + # drive the raw coroutine with a minimal runtime (the tool only reads + # ``tool_call_id``). We assert on what was forwarded to litellm, not + # on the return value. 
+ runtime = ToolRuntime( + state={}, + context=None, + config={}, + stream_writer=None, + tool_call_id="call-1", + store=None, + ) + await tool.coroutine(prompt="a cat", n=1, runtime=runtime) assert captured.get("api_base") == "https://openrouter.ai/api/v1" assert captured["model"] == "openrouter/openai/gpt-image-1"