mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-06 20:15:17 +02:00
refactor(agents): delete deliverable dead twins in shared/tools; fix live image api_base bug
The deliverables subagent runs its own generate_image/podcast/report/resume/video_presentation tools (via tools/index.py); the shared/tools copies had zero production importers — classic dead twins. Removed them so deliverable tools live only in their vertical slice. While repointing the 2 stranded unit tests at the LIVE deliverables modules, I found that the OpenRouter empty-api_base defense (resolve_api_base) existed ONLY in the dead shared generate_image and had never been propagated to the live multi-agent copy. Ported the fix into deliverables/tools/generate_image.py (both the global-config and user-DB-config branches) so an empty api_base no longer falls through to LiteLLM's global api_base (Azure) and 404s. Tests now exercise the live Command/receipt-returning tools (invoke the raw coroutine with a hand-built ToolRuntime; resume progress events neutralized).
This commit is contained in:
parent
64512c604d
commit
8d0090c6a1
10 changed files with 104 additions and 2519 deletions
|
|
@ -25,6 +25,7 @@ from app.services.image_gen_router_service import (
|
|||
ImageGenRouterService,
|
||||
is_image_gen_auto_mode,
|
||||
)
|
||||
from app.services.provider_api_base import resolve_api_base
|
||||
from app.utils.signed_image_urls import generate_image_token
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -43,13 +44,16 @@ _PROVIDER_MAP = {
|
|||
}
|
||||
|
||||
|
||||
def _resolve_provider_prefix(provider: str, custom_provider: str | None) -> str:
|
||||
if custom_provider:
|
||||
return custom_provider
|
||||
return _PROVIDER_MAP.get(provider.upper(), provider.lower())
|
||||
|
||||
|
||||
def _build_model_string(
|
||||
provider: str, model_name: str, custom_provider: str | None
|
||||
) -> str:
|
||||
if custom_provider:
|
||||
return f"{custom_provider}/{model_name}"
|
||||
prefix = _PROVIDER_MAP.get(provider.upper(), provider.lower())
|
||||
return f"{prefix}/{model_name}"
|
||||
return f"{_resolve_provider_prefix(provider, custom_provider)}/{model_name}"
|
||||
|
||||
|
||||
def _get_global_image_gen_config(config_id: int) -> dict | None:
|
||||
|
|
@ -163,14 +167,20 @@ def create_generate_image_tool(
|
|||
err = f"Image generation config {config_id} not found"
|
||||
return _failed({"error": err}, error=err)
|
||||
|
||||
model_string = _build_model_string(
|
||||
cfg.get("provider", ""),
|
||||
cfg["model_name"],
|
||||
cfg.get("custom_provider"),
|
||||
provider_prefix = _resolve_provider_prefix(
|
||||
cfg.get("provider", ""), cfg.get("custom_provider")
|
||||
)
|
||||
model_string = f"{provider_prefix}/{cfg['model_name']}"
|
||||
gen_kwargs["api_key"] = cfg.get("api_key")
|
||||
if cfg.get("api_base"):
|
||||
gen_kwargs["api_base"] = cfg["api_base"]
|
||||
# Defense-in-depth: an empty ``api_base`` must not fall
|
||||
# through to LiteLLM's global ``api_base`` (e.g. Azure).
|
||||
api_base = resolve_api_base(
|
||||
provider=cfg.get("provider"),
|
||||
provider_prefix=provider_prefix,
|
||||
config_api_base=cfg.get("api_base"),
|
||||
)
|
||||
if api_base:
|
||||
gen_kwargs["api_base"] = api_base
|
||||
if cfg.get("api_version"):
|
||||
gen_kwargs["api_version"] = cfg["api_version"]
|
||||
if cfg.get("litellm_params"):
|
||||
|
|
@ -191,14 +201,20 @@ def create_generate_image_tool(
|
|||
err = f"Image generation config {config_id} not found"
|
||||
return _failed({"error": err}, error=err)
|
||||
|
||||
model_string = _build_model_string(
|
||||
db_cfg.provider.value,
|
||||
db_cfg.model_name,
|
||||
db_cfg.custom_provider,
|
||||
provider_prefix = _resolve_provider_prefix(
|
||||
db_cfg.provider.value, db_cfg.custom_provider
|
||||
)
|
||||
model_string = f"{provider_prefix}/{db_cfg.model_name}"
|
||||
gen_kwargs["api_key"] = db_cfg.api_key
|
||||
if db_cfg.api_base:
|
||||
gen_kwargs["api_base"] = db_cfg.api_base
|
||||
# Defense-in-depth: an empty ``api_base`` must not fall
|
||||
# through to LiteLLM's global ``api_base`` (e.g. Azure).
|
||||
api_base = resolve_api_base(
|
||||
provider=db_cfg.provider.value,
|
||||
provider_prefix=provider_prefix,
|
||||
config_api_base=db_cfg.api_base,
|
||||
)
|
||||
if api_base:
|
||||
gen_kwargs["api_base"] = api_base
|
||||
if db_cfg.api_version:
|
||||
gen_kwargs["api_version"] = db_cfg.api_version
|
||||
if db_cfg.litellm_params:
|
||||
|
|
|
|||
|
|
@ -1,37 +1,24 @@
|
|||
"""
|
||||
Tools module for SurfSense deep agent.
|
||||
"""Cross-agent shared tools and tool metadata.
|
||||
|
||||
This module contains all the tools available to the SurfSense agent.
|
||||
To add a new tool, see the documentation in registry.py.
|
||||
|
||||
Available tools:
|
||||
- generate_podcast: Generate audio podcasts from content
|
||||
- generate_video_presentation: Generate video presentations with slides and narration
|
||||
- generate_image: Generate images from text descriptions using AI models
|
||||
Tool *implementations* live with the agents that own them (e.g. deliverable
|
||||
generators under ``subagents/builtins/deliverables/tools``). This package
|
||||
holds only the genuinely shared pieces: the display-metadata catalog and the
|
||||
knowledge-base helpers used across agents.
|
||||
"""
|
||||
|
||||
# Registry exports
|
||||
# Tool factory exports (for direct use)
|
||||
from .generate_image import create_generate_image_tool
|
||||
from .catalog import TOOL_CATALOG, ToolMetadata
|
||||
from .knowledge_base import (
|
||||
CONNECTOR_DESCRIPTIONS,
|
||||
format_documents_for_context,
|
||||
search_knowledge_base_async,
|
||||
)
|
||||
from .catalog import TOOL_CATALOG, ToolMetadata
|
||||
from .podcast import create_generate_podcast_tool
|
||||
from .video_presentation import create_generate_video_presentation_tool
|
||||
|
||||
__all__ = [
|
||||
# Tool catalog (display metadata)
|
||||
"TOOL_CATALOG",
|
||||
"ToolMetadata",
|
||||
# Knowledge base utilities
|
||||
"CONNECTOR_DESCRIPTIONS",
|
||||
"ToolMetadata",
|
||||
# Tool factories
|
||||
"create_generate_image_tool",
|
||||
"create_generate_podcast_tool",
|
||||
"create_generate_video_presentation_tool",
|
||||
"format_documents_for_context",
|
||||
"search_knowledge_base_async",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,280 +0,0 @@
|
|||
"""
|
||||
Image generation tool for the SurfSense agent.
|
||||
|
||||
This module provides a tool that generates images using litellm.aimage_generation()
|
||||
and returns the result directly in a format the frontend Image component can render.
|
||||
|
||||
Config resolution:
|
||||
1. Uses the search space's image_generation_config_id preference
|
||||
2. Falls back to Auto mode (router load balancing) if available
|
||||
3. Supports global YAML configs (negative IDs) and user DB configs (positive IDs)
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from litellm import aimage_generation
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import config
|
||||
from app.db import (
|
||||
ImageGeneration,
|
||||
ImageGenerationConfig,
|
||||
SearchSpace,
|
||||
shielded_async_session,
|
||||
)
|
||||
from app.services.image_gen_router_service import (
|
||||
IMAGE_GEN_AUTO_MODE_ID,
|
||||
ImageGenRouterService,
|
||||
is_image_gen_auto_mode,
|
||||
)
|
||||
from app.services.provider_api_base import resolve_api_base
|
||||
from app.utils.signed_image_urls import generate_image_token
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Provider mapping (same as routes)
|
||||
_PROVIDER_MAP = {
|
||||
"OPENAI": "openai",
|
||||
"AZURE_OPENAI": "azure",
|
||||
"GOOGLE": "gemini",
|
||||
"VERTEX_AI": "vertex_ai",
|
||||
"BEDROCK": "bedrock",
|
||||
"RECRAFT": "recraft",
|
||||
"OPENROUTER": "openrouter",
|
||||
"XINFERENCE": "xinference",
|
||||
"NSCALE": "nscale",
|
||||
}
|
||||
|
||||
|
||||
def _resolve_provider_prefix(provider: str, custom_provider: str | None) -> str:
|
||||
if custom_provider:
|
||||
return custom_provider
|
||||
return _PROVIDER_MAP.get(provider.upper(), provider.lower())
|
||||
|
||||
|
||||
def _build_model_string(
    provider: str, model_name: str, custom_provider: str | None
) -> str:
    """Return ``"<prefix>/<model_name>"``, with the prefix chosen by ``_resolve_provider_prefix``."""
    return f"{_resolve_provider_prefix(provider, custom_provider)}/{model_name}"
|
||||
|
||||
|
||||
def _get_global_image_gen_config(config_id: int) -> dict | None:
    """Look up a global image gen config by its (negative) ID.

    Returns the first entry of ``config.GLOBAL_IMAGE_GEN_CONFIGS`` whose
    ``"id"`` equals ``config_id``, or ``None`` when no entry matches.
    """
    matches = (
        entry
        for entry in config.GLOBAL_IMAGE_GEN_CONFIGS
        if entry.get("id") == config_id
    )
    return next(matches, None)
|
||||
|
||||
|
||||
# Upper bound advertised to the model in the tool docstring ("1-4").
_MAX_IMAGES_PER_CALL = 4


def _litellm_image_call_args(
    *,
    provider: str | None,
    model_name: str,
    custom_provider: str | None,
    api_key: str | None,
    api_base: str | None,
    api_version: str | None,
    litellm_params: dict[str, Any] | None,
) -> tuple[str, dict[str, Any]]:
    """Turn one image-gen config into ``(model_string, kwargs)`` for litellm.

    Shared by the global (YAML) and user (DB) config branches so both build
    the call identically. ``litellm_params`` is applied last, so it overrides
    the other keys.
    """
    provider_prefix = _resolve_provider_prefix(provider or "", custom_provider)
    call_kwargs: dict[str, Any] = {"api_key": api_key}

    # Defense-in-depth: an empty ``api_base`` must not fall through to
    # LiteLLM's global ``api_base`` (e.g. Azure).
    resolved_api_base = resolve_api_base(
        provider=provider,
        provider_prefix=provider_prefix,
        config_api_base=api_base,
    )
    if resolved_api_base:
        call_kwargs["api_base"] = resolved_api_base
    if api_version:
        call_kwargs["api_version"] = api_version
    if litellm_params:
        call_kwargs.update(litellm_params)

    return f"{provider_prefix}/{model_name}", call_kwargs


def create_generate_image_tool(
    search_space_id: int,
    db_session: AsyncSession,
):
    """
    Factory function to create the generate_image tool.

    Args:
        search_space_id: The search space ID (for config resolution)
        db_session: Reserved for compatibility with the tool registry.
            The streaming task's ``AsyncSession`` is shared by every tool;
            because AsyncSession is not concurrency-safe, parallel tool calls
            would interleave flushes (e.g. podcast + image in the same step)
            and poison the transaction. This tool opens its own session.

    Returns:
        The ``generate_image`` tool bound to ``search_space_id``.
    """
    del db_session  # use a fresh per-call session, see below

    @tool
    async def generate_image(
        prompt: str,
        n: int = 1,
    ) -> dict[str, Any]:
        """
        Generate an image from a text description using AI image models.

        Use this tool when the user asks you to create, generate, draw, or make an image.
        The generated image will be displayed directly in the chat.

        Args:
            prompt: A detailed text description of the image to generate.
                Be specific about subject, style, colors, composition, and mood.
            n: Number of images to generate (1-4). Default: 1

        Returns:
            A dictionary containing the generated image(s) for display in the chat.
        """
        try:
            # Enforce the documented 1-4 range; the value comes from the
            # model and was previously forwarded to the provider unchecked.
            image_count = 1 if n is None else max(1, min(n, _MAX_IMAGES_PER_CALL))

            # Use a per-call session so concurrent tool calls don't share an
            # AsyncSession (which is not concurrency-safe). The streaming
            # task's session is shared across every tool; without isolation,
            # autoflushes from a concurrent writer poison this tool too.
            async with shielded_async_session() as session:
                result = await session.execute(
                    select(SearchSpace).filter(SearchSpace.id == search_space_id)
                )
                search_space = result.scalars().first()
                if not search_space:
                    return {"error": "Search space not found"}

                config_id = (
                    search_space.image_generation_config_id or IMAGE_GEN_AUTO_MODE_ID
                )

                # Build generation kwargs
                # NOTE: size, quality, and style are intentionally NOT passed.
                # Different models support different values for these params
                # (e.g. DALL-E 3 wants "hd"/"standard" for quality while
                # gpt-image-1 wants "high"/"medium"/"low"; size options also
                # differ). Letting the model use its own defaults avoids errors.
                gen_kwargs: dict[str, Any] = {}
                if image_count > 1:
                    gen_kwargs["n"] = image_count

                # Call litellm based on config type
                if is_image_gen_auto_mode(config_id):
                    if not ImageGenRouterService.is_initialized():
                        return {
                            "error": "No image generation models configured. "
                            "Please add an image model in Settings > Image Models."
                        }
                    response = await ImageGenRouterService.aimage_generation(
                        prompt=prompt, model="auto", **gen_kwargs
                    )
                else:
                    if config_id < 0:
                        # Negative ID = global YAML config
                        cfg = _get_global_image_gen_config(config_id)
                        if not cfg:
                            return {
                                "error": f"Image generation config {config_id} not found"
                            }
                        model_string, call_kwargs = _litellm_image_call_args(
                            provider=cfg.get("provider"),
                            model_name=cfg["model_name"],
                            custom_provider=cfg.get("custom_provider"),
                            api_key=cfg.get("api_key"),
                            api_base=cfg.get("api_base"),
                            api_version=cfg.get("api_version"),
                            litellm_params=cfg.get("litellm_params"),
                        )
                    else:
                        # Positive ID = user-created ImageGenerationConfig
                        cfg_result = await session.execute(
                            select(ImageGenerationConfig).filter(
                                ImageGenerationConfig.id == config_id
                            )
                        )
                        db_cfg = cfg_result.scalars().first()
                        if not db_cfg:
                            return {
                                "error": f"Image generation config {config_id} not found"
                            }
                        model_string, call_kwargs = _litellm_image_call_args(
                            provider=db_cfg.provider.value,
                            model_name=db_cfg.model_name,
                            custom_provider=db_cfg.custom_provider,
                            api_key=db_cfg.api_key,
                            api_base=db_cfg.api_base,
                            api_version=db_cfg.api_version,
                            litellm_params=db_cfg.litellm_params,
                        )

                    # Config-derived kwargs go on top of ``n`` so that
                    # ``litellm_params`` keeps the final say, as before.
                    gen_kwargs.update(call_kwargs)
                    response = await aimage_generation(
                        prompt=prompt, model=model_string, **gen_kwargs
                    )

                # Parse the response and store in DB
                response_dict = (
                    response.model_dump()
                    if hasattr(response, "model_dump")
                    else dict(response)
                )

                # Generate a random access token for this image
                access_token = generate_image_token()

                # ``_hidden_params`` may be absent or None on the response.
                hidden_params = getattr(response, "_hidden_params", None) or {}

                # Save to image_generations table for history
                db_image_gen = ImageGeneration(
                    prompt=prompt,
                    model=hidden_params.get("model"),
                    n=image_count,
                    image_generation_config_id=config_id,
                    response_data=response_dict,
                    search_space_id=search_space_id,
                    access_token=access_token,
                )
                session.add(db_image_gen)
                await session.commit()
                await session.refresh(db_image_gen)
                db_image_gen_id = db_image_gen.id

            # Extract image URLs from response
            images = response_dict.get("data", [])
            if not images:
                return {"error": "No images were generated"}

            first_image = images[0]
            revised_prompt = first_image.get("revised_prompt", prompt)

            # Resolve image URL:
            # - If the API returned a URL, use it directly.
            # - If the API returned b64_json (e.g. gpt-image-1), serve the
            #   image through our backend endpoint to avoid bloating the
            #   LLM context with megabytes of base64 data.
            if first_image.get("url"):
                image_url = first_image["url"]
            elif first_image.get("b64_json"):
                backend_url = config.BACKEND_URL or "http://localhost:8000"
                image_url = (
                    f"{backend_url}/api/v1/image-generations/"
                    f"{db_image_gen_id}/image?token={access_token}"
                )
            else:
                return {"error": "No displayable image data in the response"}

            # MD5 is only a short, stable display id here, not a security
            # primitive; say so explicitly so restricted (FIPS) builds allow it.
            url_digest = hashlib.md5(
                image_url.encode(), usedforsecurity=False
            ).hexdigest()
            image_id = f"image-{url_digest[:12]}"

            return {
                "id": image_id,
                "assetId": image_url,
                "src": image_url,
                "alt": revised_prompt or prompt,
                "title": "Generated Image",
                "description": revised_prompt if revised_prompt != prompt else None,
                "domain": "ai-generated",
                "ratio": "auto",
                "generated": True,
                "prompt": prompt,
                "image_count": len(images),
            }

        except Exception as e:
            # Tool boundary: report the failure to the agent instead of raising.
            logger.exception("Image generation failed in tool")
            return {
                "error": f"Image generation failed: {e!s}",
                "prompt": prompt,
            }

    return generate_image
|
||||
|
|
@ -1,160 +0,0 @@
|
|||
"""
|
||||
Podcast generation tool for the SurfSense agent.
|
||||
|
||||
This module provides a factory function for creating the generate_podcast tool
|
||||
that submits a Celery task for background podcast generation. The tool then
|
||||
polls the podcast row until it reaches a terminal status (READY/FAILED) and
|
||||
returns that status. The wait is bounded by the chat's HTTP / process
|
||||
lifetime; see app.agents.shared.deliverable_wait for details.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.shared.deliverable_wait import wait_for_deliverable
|
||||
from app.db import Podcast, PodcastStatus, shielded_async_session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_generate_podcast_tool(
    search_space_id: int,
    db_session: AsyncSession,
    thread_id: int | None = None,
):
    """
    Factory function to create the generate_podcast tool with injected dependencies.

    Pre-creates podcast record with pending status so podcast_id is available
    immediately for frontend polling.

    Args:
        search_space_id: The user's search space ID
        db_session: Reserved for future read-side use; the row is written via a
            fresh, tool-local session so parallel tool calls (e.g. podcast +
            video presentation in the same agent step) don't share an
            ``AsyncSession`` (which is not concurrency-safe).
        thread_id: The chat thread ID for associating the podcast

    Returns:
        A configured tool function for generating podcasts
    """
    del db_session  # writes use a fresh tool-local session, see below

    @tool
    async def generate_podcast(
        source_content: str,
        podcast_title: str = "SurfSense Podcast",
        user_prompt: str | None = None,
    ) -> dict[str, Any]:
        """
        Generate a podcast from the provided content.

        Use this tool when the user asks to create, generate, or make a podcast.
        Common triggers include phrases like:
        - "Give me a podcast about this"
        - "Create a podcast from this conversation"
        - "Generate a podcast summary"
        - "Make a podcast about..."
        - "Turn this into a podcast"

        Args:
            source_content: The text content to convert into a podcast.
            podcast_title: Title for the podcast (default: "SurfSense Podcast")
            user_prompt: Optional instructions for podcast style, tone, or format.

        Returns:
            A dictionary containing:
            - status: PodcastStatus value (READY or FAILED)
            - podcast_id: The podcast ID (None only if the row could not be created)
            - title: The podcast title
            - file_location: The podcast's stored file location (when status is READY)
            - message: Status message (or "error" field if status is FAILED)
        """
        # Tracked outside the ``try`` so the error result can still point at a
        # row that was committed before a later step failed.
        podcast_id: int | None = None
        try:
            # Open a fresh session per call. The streaming task's session is
            # shared between every tool, and ``AsyncSession`` is NOT safe for
            # concurrent use: when the LLM emits parallel tool calls, two
            # concurrent ``add()`` / ``commit()`` paths interleave and the
            # second one hits "Session.add() during flush" → the transaction
            # is poisoned for both tools.
            async with shielded_async_session() as session:
                podcast = Podcast(
                    title=podcast_title,
                    status=PodcastStatus.PENDING,
                    search_space_id=search_space_id,
                    thread_id=thread_id,
                )
                session.add(podcast)
                await session.commit()
                await session.refresh(podcast)
                podcast_id = podcast.id

            # Imported at call time, as in the original (not at module import).
            from app.tasks.celery_tasks.podcast_tasks import (
                generate_content_podcast_task,
            )

            task = generate_content_podcast_task.delay(
                podcast_id=podcast_id,
                source_content=source_content,
                search_space_id=search_space_id,
                user_prompt=user_prompt,
            )

            logger.info(
                "[generate_podcast] Created podcast %s, task: %s",
                podcast_id,
                task.id,
            )

            # Wait until the Celery worker flips the row to a terminal
            # state. No internal budget — see deliverable_wait module.
            terminal_status, columns, elapsed = await wait_for_deliverable(
                model=Podcast,
                row_id=podcast_id,
                columns=[Podcast.status, Podcast.file_location],
                terminal_statuses={PodcastStatus.READY, PodcastStatus.FAILED},
            )

            if terminal_status == PodcastStatus.READY:
                # ``columns`` follows the order requested above: [status, file_location].
                file_location = columns[1] if columns else None
                logger.info(
                    "[generate_podcast] Podcast %s READY in %.2fs (file=%s)",
                    podcast_id,
                    elapsed,
                    file_location,
                )
                return {
                    "status": PodcastStatus.READY.value,
                    "podcast_id": podcast_id,
                    "title": podcast_title,
                    "file_location": file_location,
                    "message": ("Podcast generated and saved to your podcast panel."),
                }

            # Only other terminal state is FAILED.
            logger.warning(
                "[generate_podcast] Podcast %s FAILED in %.2fs",
                podcast_id,
                elapsed,
            )
            return {
                "status": PodcastStatus.FAILED.value,
                "podcast_id": podcast_id,
                "title": podcast_title,
                "error": ("Background worker reported FAILED status for this podcast."),
            }

        except Exception as e:
            # Tool boundary: report the failure to the agent instead of raising.
            error_message = str(e)
            logger.exception("[generate_podcast] Error: %s", error_message)
            return {
                "status": PodcastStatus.FAILED.value,
                "error": error_message,
                "title": podcast_title,
                # Still None when the failure happened before the row existed.
                "podcast_id": podcast_id,
            }

    return generate_podcast
|
||||
File diff suppressed because it is too large
Load diff
|
|
@ -1,812 +0,0 @@
|
|||
"""
|
||||
Resume generation tool for the SurfSense agent.
|
||||
|
||||
Generates a structured resume as Typst source code using the rendercv package.
|
||||
The LLM outputs only the content body (= heading, sections, entries) while
|
||||
the template header (import + show rule) is hardcoded and prepended by the
|
||||
backend. This eliminates LLM errors in the complex configuration block.
|
||||
|
||||
Templates are stored in a registry so new designs can be added by defining
|
||||
a new entry in _TEMPLATES.
|
||||
|
||||
Uses the same short-lived session pattern as generate_report so no DB
|
||||
connection is held during the long LLM call.
|
||||
"""
|
||||
|
||||
import io
|
||||
import logging
|
||||
import re
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
import pypdf
|
||||
import typst
|
||||
from langchain_core.callbacks import dispatch_custom_event
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from app.db import Report, shielded_async_session
|
||||
from app.services.llm_service import get_document_summary_llm
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ─── Template Registry ───────────────────────────────────────────────────────
|
||||
# Each template defines:
|
||||
# header - Typst import + show rule with {name}, {year}, {month}, {day} placeholders
|
||||
# component_reference - component docs shown to the LLM
|
||||
# rules - generation rules for the LLM
|
||||
|
||||
_TEMPLATES: dict[str, dict[str, str]] = {
|
||||
"classic": {
|
||||
"header": """\
|
||||
#import "@preview/rendercv:0.3.0": *
|
||||
|
||||
#show: rendercv.with(
|
||||
name: "{name}",
|
||||
title: "{name} - Resume",
|
||||
footer: context {{ [#emph[{name} -- #str(here().page())\\/#str(counter(page).final().first())]] }},
|
||||
top-note: [ #emph[Last updated in {month_name} {year}] ],
|
||||
locale-catalog-language: "en",
|
||||
text-direction: ltr,
|
||||
page-size: "us-letter",
|
||||
page-top-margin: 0.7in,
|
||||
page-bottom-margin: 0.7in,
|
||||
page-left-margin: 0.7in,
|
||||
page-right-margin: 0.7in,
|
||||
page-show-footer: false,
|
||||
page-show-top-note: true,
|
||||
colors-body: rgb(0, 0, 0),
|
||||
colors-name: rgb(0, 0, 0),
|
||||
colors-headline: rgb(0, 0, 0),
|
||||
colors-connections: rgb(0, 0, 0),
|
||||
colors-section-titles: rgb(0, 0, 0),
|
||||
colors-links: rgb(0, 0, 0),
|
||||
colors-footer: rgb(128, 128, 128),
|
||||
colors-top-note: rgb(128, 128, 128),
|
||||
typography-line-spacing: 0.6em,
|
||||
typography-alignment: "justified",
|
||||
typography-date-and-location-column-alignment: right,
|
||||
typography-font-family-body: "XCharter",
|
||||
typography-font-family-name: "XCharter",
|
||||
typography-font-family-headline: "XCharter",
|
||||
typography-font-family-connections: "XCharter",
|
||||
typography-font-family-section-titles: "XCharter",
|
||||
typography-font-size-body: 10pt,
|
||||
typography-font-size-name: 25pt,
|
||||
typography-font-size-headline: 10pt,
|
||||
typography-font-size-connections: 10pt,
|
||||
typography-font-size-section-titles: 1.2em,
|
||||
typography-small-caps-name: false,
|
||||
typography-small-caps-headline: false,
|
||||
typography-small-caps-connections: false,
|
||||
typography-small-caps-section-titles: false,
|
||||
typography-bold-name: false,
|
||||
typography-bold-headline: false,
|
||||
typography-bold-connections: false,
|
||||
typography-bold-section-titles: true,
|
||||
links-underline: true,
|
||||
links-show-external-link-icon: false,
|
||||
header-alignment: center,
|
||||
header-photo-width: 3.5cm,
|
||||
header-space-below-name: 0.7cm,
|
||||
header-space-below-headline: 0.7cm,
|
||||
header-space-below-connections: 0.7cm,
|
||||
header-connections-hyperlink: true,
|
||||
header-connections-show-icons: false,
|
||||
header-connections-display-urls-instead-of-usernames: true,
|
||||
header-connections-separator: "|",
|
||||
header-connections-space-between-connections: 0.5cm,
|
||||
section-titles-type: "with_full_line",
|
||||
section-titles-line-thickness: 0.5pt,
|
||||
section-titles-space-above: 0.5cm,
|
||||
section-titles-space-below: 0.3cm,
|
||||
sections-allow-page-break: true,
|
||||
sections-space-between-text-based-entries: 0.15cm,
|
||||
sections-space-between-regular-entries: 0.42cm,
|
||||
entries-date-and-location-width: 4.15cm,
|
||||
entries-side-space: 0cm,
|
||||
entries-space-between-columns: 0.1cm,
|
||||
entries-allow-page-break: false,
|
||||
entries-short-second-row: false,
|
||||
entries-degree-width: 1cm,
|
||||
entries-summary-space-left: 0cm,
|
||||
entries-summary-space-above: 0.08cm,
|
||||
entries-highlights-bullet: text(13pt, [\\u{2022}], baseline: -0.6pt),
|
||||
entries-highlights-nested-bullet: text(13pt, [\\u{2022}], baseline: -0.6pt),
|
||||
entries-highlights-space-left: 0cm,
|
||||
entries-highlights-space-above: 0.08cm,
|
||||
entries-highlights-space-between-items: 0.02cm,
|
||||
entries-highlights-space-between-bullet-and-text: 0.3em,
|
||||
date: datetime(
|
||||
year: {year},
|
||||
month: {month},
|
||||
day: {day},
|
||||
),
|
||||
)
|
||||
|
||||
""",
|
||||
"component_reference": """\
|
||||
Available components (use ONLY these):
|
||||
|
||||
= Full Name // Top-level heading — person's full name
|
||||
|
||||
#connections( // Contact info row (pipe-separated)
|
||||
[City, Country],
|
||||
[#link("mailto:email@example.com", icon: false, if-underline: false, if-color: false)[email\\@example.com]],
|
||||
[#link("https://linkedin.com/in/user", icon: false, if-underline: false, if-color: false)[linkedin.com\\/in\\/user]],
|
||||
[#link("https://github.com/user", icon: false, if-underline: false, if-color: false)[github.com\\/user]],
|
||||
)
|
||||
|
||||
== Section Title // Section heading (arbitrary name)
|
||||
|
||||
#regular-entry( // Work experience, projects, publications, etc.
|
||||
[
|
||||
#strong[Role/Title], Company Name -- Location
|
||||
],
|
||||
[
|
||||
Start -- End
|
||||
],
|
||||
main-column-second-row: [
|
||||
- Achievement or responsibility
|
||||
- Another bullet point
|
||||
],
|
||||
)
|
||||
|
||||
#education-entry( // Education entries
|
||||
[
|
||||
#strong[Institution], Degree in Field -- Location
|
||||
],
|
||||
[
|
||||
Start -- End
|
||||
],
|
||||
main-column-second-row: [
|
||||
- GPA, honours, relevant coursework
|
||||
],
|
||||
)
|
||||
|
||||
#summary([Short paragraph summary]) // Optional summary inside an entry
|
||||
#content-area([Free-form content]) // Freeform text block
|
||||
|
||||
For skills sections, use one bullet per category label:
|
||||
- #strong[Category:] item1, item2, item3
|
||||
|
||||
For simple list sections (e.g. Honors), use plain bullet points:
|
||||
- Item one
|
||||
- Item two
|
||||
""",
|
||||
"rules": """\
|
||||
RULES:
|
||||
- Do NOT include any #import or #show lines. Start directly with = Full Name.
|
||||
- Output ONLY valid Typst content. No explanatory text before or after.
|
||||
- Do NOT wrap output in ```typst code fences.
|
||||
- The = heading MUST use the person's COMPLETE full name exactly as provided. NEVER shorten or abbreviate.
|
||||
- Escape @ symbols inside link labels with a backslash: email\\@example.com
|
||||
- Escape forward slashes in link display text: linkedin.com\\/in\\/user
|
||||
- Every section MUST use == heading.
|
||||
- Use #regular-entry() for experience, projects, publications, certifications, and similar entries.
|
||||
- Use #education-entry() for education.
|
||||
- For skills sections, use one bullet line per category with a bold label.
|
||||
- Keep content professional, concise, and achievement-oriented.
|
||||
- Use action verbs for bullet points (Led, Built, Designed, Reduced, etc.).
|
||||
- This template works for ALL professions — adapt sections to the user's field.
|
||||
- Default behavior should prioritize concise one-page content.
|
||||
""",
|
||||
},
|
||||
}
|
||||
|
||||
DEFAULT_TEMPLATE = "classic"
|
||||
MIN_RESUME_PAGES = 1
|
||||
MAX_RESUME_PAGES = 5
|
||||
MAX_COMPRESSION_ATTEMPTS = 2
|
||||
|
||||
|
||||
# ─── Template Helpers ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _get_template(template_id: str | None = None) -> dict[str, str]:
    """Return the template registered under ``template_id``.

    A missing, empty, or unknown ID yields the ``DEFAULT_TEMPLATE`` entry.
    """
    fallback = _TEMPLATES[DEFAULT_TEMPLATE]
    if not template_id:
        return fallback
    return _TEMPLATES.get(template_id, fallback)
|
||||
|
||||
|
||||
_MONTH_NAMES = [
|
||||
"",
|
||||
"Jan",
|
||||
"Feb",
|
||||
"Mar",
|
||||
"Apr",
|
||||
"May",
|
||||
"Jun",
|
||||
"Jul",
|
||||
"Aug",
|
||||
"Sep",
|
||||
"Oct",
|
||||
"Nov",
|
||||
"Dec",
|
||||
]
|
||||
|
||||
|
||||
def _build_header(template: dict[str, str], name: str) -> str:
|
||||
"""Build the template header with the person's name and current date."""
|
||||
now = datetime.now(tz=UTC)
|
||||
return (
|
||||
template["header"]
|
||||
.replace("{name}", name)
|
||||
.replace("{year}", str(now.year))
|
||||
.replace("{month}", str(now.month))
|
||||
.replace("{day}", str(now.day))
|
||||
.replace("{month_name}", _MONTH_NAMES[now.month])
|
||||
)
|
||||
|
||||
|
||||
def _strip_header(full_source: str) -> str:
|
||||
"""Strip the import + show rule from stored source to get the body only.
|
||||
|
||||
Finds the closing parenthesis of the rendercv.with(...) block by tracking
|
||||
nesting depth, then returns everything after it.
|
||||
"""
|
||||
show_match = re.search(r"#show:\s*rendercv\.with\(", full_source)
|
||||
if not show_match:
|
||||
return full_source
|
||||
|
||||
start = show_match.end()
|
||||
depth = 1
|
||||
i = start
|
||||
while i < len(full_source) and depth > 0:
|
||||
if full_source[i] == "(":
|
||||
depth += 1
|
||||
elif full_source[i] == ")":
|
||||
depth -= 1
|
||||
i += 1
|
||||
|
||||
return full_source[i:].lstrip("\n")
|
||||
|
||||
|
||||
def _extract_name(body: str) -> str | None:
|
||||
"""Extract the person's full name from the = heading in the body."""
|
||||
match = re.search(r"^=\s+(.+)$", body, re.MULTILINE)
|
||||
return match.group(1).strip() if match else None
|
||||
|
||||
|
||||
def _strip_imports(body: str) -> str:
|
||||
"""Remove any #import or #show lines the LLM might accidentally include."""
|
||||
lines = body.split("\n")
|
||||
cleaned: list[str] = []
|
||||
skip_show = False
|
||||
depth = 0
|
||||
|
||||
for line in lines:
|
||||
stripped = line.strip()
|
||||
|
||||
if stripped.startswith("#import"):
|
||||
continue
|
||||
|
||||
if skip_show:
|
||||
depth += stripped.count("(") - stripped.count(")")
|
||||
if depth <= 0:
|
||||
skip_show = False
|
||||
continue
|
||||
|
||||
if stripped.startswith("#show:") and "rendercv" in stripped:
|
||||
depth = stripped.count("(") - stripped.count(")")
|
||||
if depth > 0:
|
||||
skip_show = True
|
||||
continue
|
||||
|
||||
cleaned.append(line)
|
||||
|
||||
result = "\n".join(cleaned).strip()
|
||||
return result
|
||||
|
||||
|
||||
def _build_llm_reference(template: dict[str, str]) -> str:
|
||||
"""Build the LLM prompt reference from a template."""
|
||||
return f"""\
|
||||
You MUST output valid Typst content for a resume.
|
||||
Do NOT include any #import or #show lines — those are handled automatically.
|
||||
Start directly with the = Full Name heading.
|
||||
|
||||
{template["component_reference"]}
|
||||
|
||||
{template["rules"]}"""
|
||||
|
||||
|
||||
# ─── Prompts ─────────────────────────────────────────────────────────────────
|
||||
|
||||
_RESUME_PROMPT = """\
|
||||
You are an expert resume writer. Generate professional resume content as Typst markup.
|
||||
|
||||
{llm_reference}
|
||||
|
||||
**User Information:**
|
||||
{user_info}
|
||||
|
||||
**Target Maximum Pages:** {max_pages}
|
||||
|
||||
{user_instructions_section}
|
||||
|
||||
Generate the resume content now (starting with = Full Name):
|
||||
"""
|
||||
|
||||
_REVISION_PROMPT = """\
|
||||
You are an expert resume editor. Modify the existing resume according to the instructions.
|
||||
Apply ONLY the requested changes — do NOT rewrite sections that are not affected.
|
||||
|
||||
{llm_reference}
|
||||
|
||||
**Target Maximum Pages:** {max_pages}
|
||||
|
||||
**Modification Instructions:** {user_instructions}
|
||||
|
||||
**EXISTING RESUME CONTENT:**
|
||||
|
||||
{previous_content}
|
||||
|
||||
---
|
||||
|
||||
Output the complete, updated resume content with the changes applied (starting with = Full Name):
|
||||
"""
|
||||
|
||||
_FIX_COMPILE_PROMPT = """\
|
||||
The resume content you generated failed to compile. Fix the error while preserving all content.
|
||||
|
||||
{llm_reference}
|
||||
|
||||
**Compilation Error:**
|
||||
{error}
|
||||
|
||||
**Full Typst Source (for context — error line numbers refer to this):**
|
||||
{full_source}
|
||||
|
||||
**Your content starts after the template header. Output ONLY the content portion \
|
||||
(starting with = Full Name), NOT the #import or #show rule:**
|
||||
"""
|
||||
|
||||
_COMPRESS_TO_PAGE_LIMIT_PROMPT = """\
|
||||
The resume compiles, but it exceeds the maximum allowed page count.
|
||||
Compress the resume while preserving high-impact accomplishments and role relevance.
|
||||
|
||||
{llm_reference}
|
||||
|
||||
**Target Maximum Pages:** {max_pages}
|
||||
**Current Page Count:** {actual_pages}
|
||||
**Compression Attempt:** {attempt_number}
|
||||
|
||||
Compression priorities (in this order):
|
||||
1) Keep recent, high-impact, role-relevant bullets.
|
||||
2) Remove low-impact or redundant bullets.
|
||||
3) Shorten verbose wording while preserving meaning.
|
||||
4) Trim older or less relevant details before recent ones.
|
||||
|
||||
Return the complete updated Typst content (starting with = Full Name), and keep it at or below the target pages.
|
||||
|
||||
**EXISTING RESUME CONTENT:**
|
||||
{previous_content}
|
||||
"""
|
||||
|
||||
|
||||
# ─── Helpers ─────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _strip_typst_fences(text: str) -> str:
|
||||
"""Remove wrapping ```typst ... ``` fences that LLMs sometimes add."""
|
||||
stripped = text.strip()
|
||||
m = re.match(r"^(`{3,})(?:typst|typ)?\s*\n", stripped)
|
||||
if m:
|
||||
fence = m.group(1)
|
||||
if stripped.endswith(fence):
|
||||
stripped = stripped[m.end() :]
|
||||
stripped = stripped[: -len(fence)].rstrip()
|
||||
return stripped
|
||||
|
||||
|
||||
def _compile_typst(source: str) -> bytes:
    """Run the Typst compiler on ``source`` and return its output.

    The source is handed to ``typst.compile`` as UTF-8 bytes; any exception
    the compiler raises propagates to the caller.
    """
    encoded_source = source.encode("utf-8")
    return typst.compile(encoded_source)
|
||||
|
||||
|
||||
def _count_pdf_pages(pdf_bytes: bytes) -> int:
    """Return how many pages the given in-memory PDF contains."""
    with io.BytesIO(pdf_bytes) as buffer:
        return len(pypdf.PdfReader(buffer).pages)
|
||||
|
||||
|
||||
def _validate_max_pages(max_pages: int) -> int:
    """Check that ``max_pages`` lies within the allowed page range.

    Returns:
        ``max_pages`` unchanged when it is between ``MIN_RESUME_PAGES`` and
        ``MAX_RESUME_PAGES`` inclusive.

    Raises:
        ValueError: If ``max_pages`` falls outside that range.
    """
    within_bounds = MIN_RESUME_PAGES <= max_pages <= MAX_RESUME_PAGES
    if not within_bounds:
        raise ValueError(
            f"max_pages must be between {MIN_RESUME_PAGES} and "
            f"{MAX_RESUME_PAGES}. Received: {max_pages}"
        )
    return max_pages
|
||||
|
||||
|
||||
# ─── Tool Factory ───────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def create_generate_resume_tool(
    search_space_id: int,
    thread_id: int | None = None,
):
    """
    Factory function to create the generate_resume tool.

    Generates a Typst-based resume, validates it via compilation,
    and stores the source in the Report table with content_type='typst'.
    The LLM generates only the content body; the template header is
    prepended by the backend.

    Args:
        search_space_id: Search space the saved Report rows are attached to
            and the LLM is looked up for.
        thread_id: Optional thread the saved Report rows are attached to.

    Returns:
        The ``generate_resume`` tool, closed over the two IDs above.
    """

    @tool
    async def generate_resume(
        user_info: str,
        user_instructions: str | None = None,
        parent_report_id: int | None = None,
        max_pages: int = 1,
    ) -> dict[str, Any]:
        """
        Generate a professional resume as a Typst document.

        Use this tool when the user asks to create, build, generate, write,
        or draft a resume or CV. Also use it when the user wants to modify,
        update, or revise an existing resume generated in this conversation.

        Trigger phrases include:
        - "build me a resume", "create my resume", "generate a CV"
        - "update my resume", "change my title", "add my new job"
        - "make my resume more concise", "reformat my resume"

        Do NOT use this tool for:
        - General questions about resumes or career advice
        - Reviewing or critiquing a resume without changes
        - Cover letters (use generate_report instead)

        VERSIONING — parent_report_id:
        - Set parent_report_id when the user wants to MODIFY an existing
          resume that was already generated in this conversation.
        - Leave as None for new resumes.

        Args:
            user_info: The user's resume content — work experience,
                education, skills, contact info, etc. Can be structured
                or unstructured text.
            user_instructions: Optional style or content preferences
                (e.g. "emphasize leadership", "keep it to one page",
                "use a modern style"). For revisions, describe what to change.
            parent_report_id: ID of a previous resume to revise (creates
                new version in the same version group).
            max_pages: Maximum number of pages for the generated resume.
                Defaults to 1. Allowed range: 1-5.

        Returns:
            Dict with status, report_id, title, and content_type.
        """
        report_group_id: int | None = None
        parent_content: str | None = None

        template = _get_template()
        llm_reference = _build_llm_reference(template)

        async def _save_failed_report(error_msg: str) -> int | None:
            # Best-effort: persist a "failed" Report row so the failure has an
            # ID; returns None if even that write fails.
            try:
                async with shielded_async_session() as session:
                    failed = Report(
                        title="Resume",
                        content=None,
                        content_type="typst",
                        report_metadata={
                            "status": "failed",
                            "error_message": error_msg,
                        },
                        report_style="resume",
                        search_space_id=search_space_id,
                        thread_id=thread_id,
                        report_group_id=report_group_id,
                    )
                    session.add(failed)
                    await session.commit()
                    await session.refresh(failed)
                    if not failed.report_group_id:
                        # First report of a group: the group ID is its own ID.
                        failed.report_group_id = failed.id
                        await session.commit()
                    logger.info(
                        "[generate_resume] Saved failed report %s: %s",
                        failed.id,
                        error_msg,
                    )
                    return failed.id
            except Exception:
                logger.exception(
                    "[generate_resume] Could not persist failed report row"
                )
                return None

        async def _fail(error_msg: str) -> dict[str, Any]:
            # Single place that records a failure and builds the tool's
            # "failed" payload, so every error path returns the same shape.
            report_id = await _save_failed_report(error_msg)
            return {
                "status": "failed",
                "error": error_msg,
                "report_id": report_id,
                "title": "Resume",
                "content_type": "typst",
            }

        try:
            try:
                validated_max_pages = _validate_max_pages(max_pages)
            except ValueError as e:
                return await _fail(str(e))

            # ── Phase 1: READ ─────────────────────────────────────────────
            async with shielded_async_session() as read_session:
                if parent_report_id:
                    parent_report = await read_session.get(Report, parent_report_id)
                    if parent_report:
                        report_group_id = parent_report.report_group_id
                        parent_content = parent_report.content
                        logger.info(
                            "[generate_resume] Revising from parent %s (group %s)",
                            parent_report_id,
                            report_group_id,
                        )
                    else:
                        # Behaviour is unchanged (falls through to a fresh
                        # resume); the warning just makes it traceable.
                        logger.warning(
                            "[generate_resume] Parent report %s not found; "
                            "generating a new resume instead",
                            parent_report_id,
                        )

                llm = await get_document_summary_llm(read_session, search_space_id)

            if not llm:
                return await _fail(
                    "No LLM configured. Please configure a language model in Settings."
                )

            # ── Phase 2: LLM GENERATION ───────────────────────────────────

            user_instructions_section = ""
            if user_instructions:
                user_instructions_section = (
                    f"**Additional Instructions:** {user_instructions}"
                )

            if parent_content:
                dispatch_custom_event(
                    "report_progress",
                    {"phase": "writing", "message": "Updating your resume"},
                )
                parent_body = _strip_header(parent_content)
                prompt = _REVISION_PROMPT.format(
                    llm_reference=llm_reference,
                    max_pages=validated_max_pages,
                    user_instructions=user_instructions
                    or "Improve and refine the resume.",
                    previous_content=parent_body,
                )
            else:
                dispatch_custom_event(
                    "report_progress",
                    {"phase": "writing", "message": "Building your resume"},
                )
                prompt = _RESUME_PROMPT.format(
                    llm_reference=llm_reference,
                    user_info=user_info,
                    max_pages=validated_max_pages,
                    user_instructions_section=user_instructions_section,
                )

            response = await llm.ainvoke([HumanMessage(content=prompt)])
            body = response.content

            if not body or not isinstance(body, str):
                return await _fail("LLM returned empty or invalid content")

            body = _strip_typst_fences(body)
            body = _strip_imports(body)

            # ── Phase 3: ASSEMBLE + COMPILE ───────────────────────────────
            dispatch_custom_event(
                "report_progress",
                {"phase": "compiling", "message": "Compiling resume..."},
            )

            name = _extract_name(body) or "Resume"
            typst_source = ""
            actual_pages = 0
            compression_attempts = 0
            target_page_met = False

            # One initial round plus up to MAX_COMPRESSION_ATTEMPTS retries.
            for compression_round in range(MAX_COMPRESSION_ATTEMPTS + 1):
                header = _build_header(template, name)
                typst_source = header + body
                compile_error: str | None = None
                pdf_bytes: bytes | None = None

                # Two compile tries: the second runs on the LLM's fix.
                for compile_attempt in range(2):
                    try:
                        pdf_bytes = _compile_typst(typst_source)
                        compile_error = None
                        break
                    except Exception as e:
                        compile_error = str(e)
                        logger.warning(
                            "[generate_resume] Compile attempt %s failed: %s",
                            compile_attempt + 1,
                            compile_error,
                        )

                        if compile_attempt == 0:
                            dispatch_custom_event(
                                "report_progress",
                                {
                                    "phase": "fixing",
                                    "message": "Fixing compilation issue...",
                                },
                            )
                            fix_prompt = _FIX_COMPILE_PROMPT.format(
                                llm_reference=llm_reference,
                                error=compile_error,
                                full_source=typst_source,
                            )
                            fix_response = await llm.ainvoke(
                                [HumanMessage(content=fix_prompt)]
                            )
                            if fix_response.content and isinstance(
                                fix_response.content, str
                            ):
                                body = _strip_typst_fences(fix_response.content)
                                body = _strip_imports(body)
                                name = _extract_name(body) or name
                                header = _build_header(template, name)
                                typst_source = header + body

                if compile_error or not pdf_bytes:
                    return await _fail(
                        "Typst compilation failed after 2 attempts: "
                        f"{compile_error or 'Unknown compile error'}"
                    )

                actual_pages = _count_pdf_pages(pdf_bytes)
                if actual_pages <= validated_max_pages:
                    target_page_met = True
                    break

                # Out of compression budget: keep the last compiled result.
                if compression_round >= MAX_COMPRESSION_ATTEMPTS:
                    break

                compression_attempts += 1
                dispatch_custom_event(
                    "report_progress",
                    {
                        "phase": "compressing",
                        "message": f"Condensing resume to {validated_max_pages} page(s)...",
                    },
                )
                compress_prompt = _COMPRESS_TO_PAGE_LIMIT_PROMPT.format(
                    llm_reference=llm_reference,
                    max_pages=validated_max_pages,
                    actual_pages=actual_pages,
                    attempt_number=compression_attempts,
                    previous_content=body,
                )
                compress_response = await llm.ainvoke(
                    [HumanMessage(content=compress_prompt)]
                )
                if not compress_response.content or not isinstance(
                    compress_response.content, str
                ):
                    return await _fail(
                        "LLM returned empty content while compressing resume"
                    )

                body = _strip_typst_fences(compress_response.content)
                body = _strip_imports(body)
                name = _extract_name(body) or name

            # Missing the requested target is tolerated (reported below);
            # exceeding the global hard limit is not.
            if actual_pages > MAX_RESUME_PAGES:
                return await _fail(
                    "Resume exceeds hard page limit after compression retries. "
                    f"Hard limit: <= {MAX_RESUME_PAGES} page(s), actual: {actual_pages}."
                )

            # ── Phase 4: SAVE ─────────────────────────────────────────────
            dispatch_custom_event(
                "report_progress",
                {"phase": "saving", "message": "Saving your resume"},
            )

            resume_title = f"{name} - Resume" if name != "Resume" else "Resume"

            metadata: dict[str, Any] = {
                "status": "ready",
                "word_count": len(typst_source.split()),
                "char_count": len(typst_source),
                "target_max_pages": validated_max_pages,
                "actual_page_count": actual_pages,
                "page_limit_enforced": True,
                "compression_attempts": compression_attempts,
                "target_page_met": target_page_met,
            }

            async with shielded_async_session() as write_session:
                report = Report(
                    title=resume_title,
                    content=typst_source,
                    content_type="typst",
                    report_metadata=metadata,
                    report_style="resume",
                    search_space_id=search_space_id,
                    thread_id=thread_id,
                    report_group_id=report_group_id,
                )
                write_session.add(report)
                await write_session.commit()
                await write_session.refresh(report)

                if not report.report_group_id:
                    # New resume: start a version group keyed by its own ID.
                    report.report_group_id = report.id
                    await write_session.commit()

                saved_id = report.id

            logger.info(
                "[generate_resume] Created resume %s: %s", saved_id, resume_title
            )

            return {
                "status": "ready",
                "report_id": saved_id,
                "title": resume_title,
                "content_type": "typst",
                "is_revision": bool(parent_content),
                "message": (
                    f"Resume generated successfully: {resume_title}"
                    if target_page_met
                    else (
                        f"Resume generated, but could not fit the target of <= {validated_max_pages} "
                        f"page(s). Final length: {actual_pages} page(s)."
                    )
                ),
            }

        except Exception as e:
            # Top-level boundary: the tool always returns a payload.
            error_message = str(e)
            logger.exception("[generate_resume] Error: %s", error_message)
            return await _fail(error_message)

    return generate_resume
|
||||
|
|
@ -1,138 +0,0 @@
|
|||
"""
|
||||
Video presentation generation tool for the SurfSense agent.
|
||||
|
||||
This module provides a factory function for creating the generate_video_presentation
|
||||
tool that submits a Celery task for background video presentation generation. The
|
||||
tool then polls the row until it reaches a terminal status (READY/FAILED) and
|
||||
returns that status. The wait is bounded by the chat's HTTP / process lifetime;
|
||||
see app.agents.shared.deliverable_wait for details.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.shared.deliverable_wait import wait_for_deliverable
|
||||
from app.db import VideoPresentation, VideoPresentationStatus, shielded_async_session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_generate_video_presentation_tool(
    search_space_id: int,
    db_session: AsyncSession,
    thread_id: int | None = None,
):
    """
    Build the generate_video_presentation tool bound to a search space/thread.

    The tool first inserts a PENDING VideoPresentation row so an ID exists
    right away, then hands generation to a Celery task and waits for the row
    to reach READY or FAILED. ``db_session`` is accepted but unused: the row
    is written through a fresh tool-local session so parallel tool calls do
    not share one ``AsyncSession`` (which is not concurrency-safe).
    """
    del db_session  # writes use a fresh tool-local session, see below

    @tool
    async def generate_video_presentation(
        source_content: str,
        video_title: str = "SurfSense Presentation",
        user_prompt: str | None = None,
    ) -> dict[str, Any]:
        """Generate a video presentation from the provided content.

        Use this tool when the user asks to create a video, presentation, slides, or slide deck.

        Args:
            source_content: The text content to turn into a presentation.
            video_title: Title for the presentation (default: "SurfSense Presentation")
            user_prompt: Optional style/tone instructions.
        """
        try:
            # A dedicated session per call: sharing the streaming session
            # across parallel tools interleaves flushes ("Session.add()
            # during flush") and poisons the transaction for all of them.
            async with shielded_async_session() as session:
                row = VideoPresentation(
                    title=video_title,
                    status=VideoPresentationStatus.PENDING,
                    search_space_id=search_space_id,
                    thread_id=thread_id,
                )
                session.add(row)
                await session.commit()
                await session.refresh(row)
                presentation_id = row.id

            from app.tasks.celery_tasks.video_presentation_tasks import (
                generate_video_presentation_task,
            )

            celery_task = generate_video_presentation_task.delay(
                video_presentation_id=presentation_id,
                source_content=source_content,
                search_space_id=search_space_id,
                user_prompt=user_prompt,
            )

            logger.info(
                "[generate_video_presentation] Created video presentation %s, task: %s",
                presentation_id,
                celery_task.id,
            )

            # Block until the worker moves the row to a terminal state.
            watched_columns = [VideoPresentation.status]
            end_states = {
                VideoPresentationStatus.READY,
                VideoPresentationStatus.FAILED,
            }
            final_status, _columns, seconds = await wait_for_deliverable(
                model=VideoPresentation,
                row_id=presentation_id,
                columns=watched_columns,
                terminal_statuses=end_states,
            )

            if final_status != VideoPresentationStatus.READY:
                # FAILED is the only other terminal state.
                logger.warning(
                    "[generate_video_presentation] %s FAILED in %.2fs",
                    presentation_id,
                    seconds,
                )
                return {
                    "status": VideoPresentationStatus.FAILED.value,
                    "video_presentation_id": presentation_id,
                    "title": video_title,
                    "error": (
                        "Background worker reported FAILED status for this "
                        "video presentation."
                    ),
                }

            logger.info(
                "[generate_video_presentation] %s READY in %.2fs",
                presentation_id,
                seconds,
            )
            return {
                "status": VideoPresentationStatus.READY.value,
                "video_presentation_id": presentation_id,
                "title": video_title,
                "message": "Video presentation generated and saved.",
            }

        except Exception as exc:
            failure_text = str(exc)
            logger.exception("[generate_video_presentation] Error: %s", failure_text)
            return {
                "status": VideoPresentationStatus.FAILED.value,
                "error": failure_text,
                "title": video_title,
                "video_presentation_id": None,
            }

    return generate_video_presentation
|
||||
|
|
@ -56,7 +56,7 @@ logger = logging.getLogger(__name__)
|
|||
# class-body init time. ``app.agents.shared.llm_config`` re-exports
|
||||
# this constant under the historical ``PROVIDER_MAP`` name; placing the
|
||||
# map there directly would re-introduce the
|
||||
# ``app.config -> ... -> app.agents.shared.tools.generate_image ->
|
||||
# ``app.config -> ... -> deliverables/tools/generate_image ->
|
||||
# app.config`` cycle that prompted the move.
|
||||
_PROVIDER_PREFIX_MAP: dict[str, str] = {
|
||||
"OPENAI": "openai",
|
||||
|
|
|
|||
|
|
@ -1,17 +1,58 @@
|
|||
"""Unit tests for resume page-limit helpers and enforcement flow."""
|
||||
"""Unit tests for resume page-limit helpers and enforcement flow.
|
||||
|
||||
Targets the live deliverables resume tool. The tool returns a
|
||||
``Command`` (payload JSON-encoded in ``update["messages"][0].content``
|
||||
plus a receipt), so flow tests invoke it via a ToolCall dict and unwrap
|
||||
the payload.
|
||||
"""
|
||||
|
||||
import io
|
||||
import json
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pypdf
|
||||
import pytest
|
||||
from langchain.tools import ToolRuntime
|
||||
|
||||
from app.agents.shared.tools import resume as resume_tool
|
||||
from app.agents.multi_agent_chat.subagents.builtins.deliverables.tools import (
|
||||
resume as resume_tool,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _silence_progress_events(monkeypatch):
    """Replace the tool module's ``dispatch_custom_event`` with a no-op.

    The live tool emits progress events that need a langgraph run context,
    which direct unit calls do not have.
    """

    def _noop(*_args, **_kwargs):
        return None

    monkeypatch.setattr(resume_tool, "dispatch_custom_event", _noop)
|
||||
|
||||
|
||||
def _runtime(tool_call_id: str = "call-1") -> ToolRuntime:
    """Build a bare ToolRuntime carrying only the given ``tool_call_id``.

    Every other field is left empty or ``None``.
    """
    fields = {
        "state": {},
        "context": None,
        "config": {},
        "stream_writer": None,
        "tool_call_id": tool_call_id,
        "store": None,
    }
    return ToolRuntime(**fields)
|
||||
|
||||
|
||||
async def _invoke(tool, args: dict) -> dict:
    """Call a Command-returning tool directly and decode its JSON payload.

    The tool's raw coroutine is awaited with a hand-built runtime plus
    ``args``; the payload is read from the content of the first message in
    the returned Command's ``update["messages"]``.
    """
    runtime = _runtime()
    command = await tool.coroutine(runtime=runtime, **args)
    first_message = command.update["messages"][0]
    return json.loads(first_message.content)
|
||||
|
||||
|
||||
class _FakeReport:
|
||||
_next_id = 1000
|
||||
|
||||
|
|
@ -108,7 +149,7 @@ async def test_generate_resume_defaults_to_one_page_target(monkeypatch) -> None:
|
|||
monkeypatch.setattr(resume_tool, "_count_pdf_pages", lambda _pdf: 1)
|
||||
|
||||
tool = resume_tool.create_generate_resume_tool(search_space_id=1, thread_id=1)
|
||||
result = await tool.ainvoke({"user_info": "Jane Doe experience"})
|
||||
result = await _invoke(tool, {"user_info": "Jane Doe experience"})
|
||||
|
||||
assert result["status"] == "ready"
|
||||
assert prompts
|
||||
|
|
@ -138,7 +179,7 @@ async def test_generate_resume_compresses_when_over_limit(monkeypatch) -> None:
|
|||
monkeypatch.setattr(resume_tool, "_count_pdf_pages", lambda _pdf: next(page_counts))
|
||||
|
||||
tool = resume_tool.create_generate_resume_tool(search_space_id=1, thread_id=1)
|
||||
result = await tool.ainvoke({"user_info": "Jane Doe experience", "max_pages": 1})
|
||||
result = await _invoke(tool, {"user_info": "Jane Doe experience", "max_pages": 1})
|
||||
|
||||
assert result["status"] == "ready"
|
||||
assert write_session.added, "Expected successful report write"
|
||||
|
|
@ -173,7 +214,7 @@ async def test_generate_resume_returns_ready_when_target_not_met(monkeypatch) ->
|
|||
monkeypatch.setattr(resume_tool, "_count_pdf_pages", lambda _pdf: next(page_counts))
|
||||
|
||||
tool = resume_tool.create_generate_resume_tool(search_space_id=1, thread_id=1)
|
||||
result = await tool.ainvoke({"user_info": "Jane Doe experience", "max_pages": 1})
|
||||
result = await _invoke(tool, {"user_info": "Jane Doe experience", "max_pages": 1})
|
||||
|
||||
assert result["status"] == "ready"
|
||||
assert "could not fit the target" in (result["message"] or "").lower()
|
||||
|
|
@ -206,7 +247,7 @@ async def test_generate_resume_fails_when_hard_limit_exceeded(monkeypatch) -> No
|
|||
monkeypatch.setattr(resume_tool, "_count_pdf_pages", lambda _pdf: next(page_counts))
|
||||
|
||||
tool = resume_tool.create_generate_resume_tool(search_space_id=1, thread_id=1)
|
||||
result = await tool.ainvoke({"user_info": "Jane Doe experience", "max_pages": 1})
|
||||
result = await _invoke(tool, {"user_info": "Jane Doe experience", "max_pages": 1})
|
||||
|
||||
assert result["status"] == "failed"
|
||||
assert "hard page limit" in (result["error"] or "").lower()
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ from __future__ import annotations
|
|||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from langchain.tools import ToolRuntime
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
|
@ -90,7 +91,9 @@ async def test_global_openrouter_image_gen_sets_api_base_when_config_empty():
|
|||
async def test_generate_image_tool_global_sets_api_base_when_config_empty():
|
||||
"""Same defense at the agent tool entry point — both surfaces share
|
||||
the same OpenRouter config payloads."""
|
||||
from app.agents.shared.tools import generate_image as gi_module
|
||||
from app.agents.multi_agent_chat.subagents.builtins.deliverables.tools import (
|
||||
generate_image as gi_module,
|
||||
)
|
||||
|
||||
cfg = {
|
||||
"id": -20_001,
|
||||
|
|
@ -150,7 +153,19 @@ async def test_generate_image_tool_global_sets_api_base_when_config_empty():
|
|||
tool = gi_module.create_generate_image_tool(
|
||||
search_space_id=1, db_session=MagicMock()
|
||||
)
|
||||
await tool.ainvoke({"prompt": "a cat", "n": 1})
|
||||
# The live tool takes an injected ToolRuntime and returns a Command;
|
||||
# drive the raw coroutine with a minimal runtime (the tool only reads
|
||||
# ``tool_call_id``). We assert on what was forwarded to litellm, not
|
||||
# on the return value.
|
||||
runtime = ToolRuntime(
|
||||
state={},
|
||||
context=None,
|
||||
config={},
|
||||
stream_writer=None,
|
||||
tool_call_id="call-1",
|
||||
store=None,
|
||||
)
|
||||
await tool.coroutine(prompt="a cat", n=1, runtime=runtime)
|
||||
|
||||
assert captured.get("api_base") == "https://openrouter.ai/api/v1"
|
||||
assert captured["model"] == "openrouter/openai/gpt-image-1"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue