mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-06 20:15:17 +02:00
refactor(agents): delete deliverable dead twins in shared/tools; fix live image api_base bug
The deliverables subagent runs its own generate_image/podcast/report/resume/video_presentation (via tools/index.py); the shared/tools copies had zero production importers — classic dead twins. Removed them so deliverable tools live only in their vertical slice. While repointing the 2 stranded unit tests at the LIVE deliverables modules, found the OpenRouter empty-api_base defense (resolve_api_base) existed ONLY in the dead shared generate_image, never propagated to the live multi-agent copy. Ported the fix into deliverables/tools/generate_image.py (both the global-config and user-DB-config branches) so an empty api_base no longer falls through to LiteLLM's global api_base (Azure) and 404s. Tests now exercise the live Command/receipt-returning tools (invoke the raw coroutine with a hand-built ToolRuntime; resume progress events neutralized).
This commit is contained in:
parent
64512c604d
commit
8d0090c6a1
10 changed files with 104 additions and 2519 deletions
|
|
@ -25,6 +25,7 @@ from app.services.image_gen_router_service import (
|
||||||
ImageGenRouterService,
|
ImageGenRouterService,
|
||||||
is_image_gen_auto_mode,
|
is_image_gen_auto_mode,
|
||||||
)
|
)
|
||||||
|
from app.services.provider_api_base import resolve_api_base
|
||||||
from app.utils.signed_image_urls import generate_image_token
|
from app.utils.signed_image_urls import generate_image_token
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -43,13 +44,16 @@ _PROVIDER_MAP = {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_provider_prefix(provider: str, custom_provider: str | None) -> str:
|
||||||
|
if custom_provider:
|
||||||
|
return custom_provider
|
||||||
|
return _PROVIDER_MAP.get(provider.upper(), provider.lower())
|
||||||
|
|
||||||
|
|
||||||
def _build_model_string(
|
def _build_model_string(
|
||||||
provider: str, model_name: str, custom_provider: str | None
|
provider: str, model_name: str, custom_provider: str | None
|
||||||
) -> str:
|
) -> str:
|
||||||
if custom_provider:
|
return f"{_resolve_provider_prefix(provider, custom_provider)}/{model_name}"
|
||||||
return f"{custom_provider}/{model_name}"
|
|
||||||
prefix = _PROVIDER_MAP.get(provider.upper(), provider.lower())
|
|
||||||
return f"{prefix}/{model_name}"
|
|
||||||
|
|
||||||
|
|
||||||
def _get_global_image_gen_config(config_id: int) -> dict | None:
|
def _get_global_image_gen_config(config_id: int) -> dict | None:
|
||||||
|
|
@ -163,14 +167,20 @@ def create_generate_image_tool(
|
||||||
err = f"Image generation config {config_id} not found"
|
err = f"Image generation config {config_id} not found"
|
||||||
return _failed({"error": err}, error=err)
|
return _failed({"error": err}, error=err)
|
||||||
|
|
||||||
model_string = _build_model_string(
|
provider_prefix = _resolve_provider_prefix(
|
||||||
cfg.get("provider", ""),
|
cfg.get("provider", ""), cfg.get("custom_provider")
|
||||||
cfg["model_name"],
|
|
||||||
cfg.get("custom_provider"),
|
|
||||||
)
|
)
|
||||||
|
model_string = f"{provider_prefix}/{cfg['model_name']}"
|
||||||
gen_kwargs["api_key"] = cfg.get("api_key")
|
gen_kwargs["api_key"] = cfg.get("api_key")
|
||||||
if cfg.get("api_base"):
|
# Defense-in-depth: an empty ``api_base`` must not fall
|
||||||
gen_kwargs["api_base"] = cfg["api_base"]
|
# through to LiteLLM's global ``api_base`` (e.g. Azure).
|
||||||
|
api_base = resolve_api_base(
|
||||||
|
provider=cfg.get("provider"),
|
||||||
|
provider_prefix=provider_prefix,
|
||||||
|
config_api_base=cfg.get("api_base"),
|
||||||
|
)
|
||||||
|
if api_base:
|
||||||
|
gen_kwargs["api_base"] = api_base
|
||||||
if cfg.get("api_version"):
|
if cfg.get("api_version"):
|
||||||
gen_kwargs["api_version"] = cfg["api_version"]
|
gen_kwargs["api_version"] = cfg["api_version"]
|
||||||
if cfg.get("litellm_params"):
|
if cfg.get("litellm_params"):
|
||||||
|
|
@ -191,14 +201,20 @@ def create_generate_image_tool(
|
||||||
err = f"Image generation config {config_id} not found"
|
err = f"Image generation config {config_id} not found"
|
||||||
return _failed({"error": err}, error=err)
|
return _failed({"error": err}, error=err)
|
||||||
|
|
||||||
model_string = _build_model_string(
|
provider_prefix = _resolve_provider_prefix(
|
||||||
db_cfg.provider.value,
|
db_cfg.provider.value, db_cfg.custom_provider
|
||||||
db_cfg.model_name,
|
|
||||||
db_cfg.custom_provider,
|
|
||||||
)
|
)
|
||||||
|
model_string = f"{provider_prefix}/{db_cfg.model_name}"
|
||||||
gen_kwargs["api_key"] = db_cfg.api_key
|
gen_kwargs["api_key"] = db_cfg.api_key
|
||||||
if db_cfg.api_base:
|
# Defense-in-depth: an empty ``api_base`` must not fall
|
||||||
gen_kwargs["api_base"] = db_cfg.api_base
|
# through to LiteLLM's global ``api_base`` (e.g. Azure).
|
||||||
|
api_base = resolve_api_base(
|
||||||
|
provider=db_cfg.provider.value,
|
||||||
|
provider_prefix=provider_prefix,
|
||||||
|
config_api_base=db_cfg.api_base,
|
||||||
|
)
|
||||||
|
if api_base:
|
||||||
|
gen_kwargs["api_base"] = api_base
|
||||||
if db_cfg.api_version:
|
if db_cfg.api_version:
|
||||||
gen_kwargs["api_version"] = db_cfg.api_version
|
gen_kwargs["api_version"] = db_cfg.api_version
|
||||||
if db_cfg.litellm_params:
|
if db_cfg.litellm_params:
|
||||||
|
|
|
||||||
|
|
@ -1,37 +1,24 @@
|
||||||
"""
|
"""Cross-agent shared tools and tool metadata.
|
||||||
Tools module for SurfSense deep agent.
|
|
||||||
|
|
||||||
This module contains all the tools available to the SurfSense agent.
|
Tool *implementations* live with the agents that own them (e.g. deliverable
|
||||||
To add a new tool, see the documentation in registry.py.
|
generators under ``subagents/builtins/deliverables/tools``). This package
|
||||||
|
holds only the genuinely shared pieces: the display-metadata catalog and the
|
||||||
Available tools:
|
knowledge-base helpers used across agents.
|
||||||
- generate_podcast: Generate audio podcasts from content
|
|
||||||
- generate_video_presentation: Generate video presentations with slides and narration
|
|
||||||
- generate_image: Generate images from text descriptions using AI models
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Registry exports
|
from .catalog import TOOL_CATALOG, ToolMetadata
|
||||||
# Tool factory exports (for direct use)
|
|
||||||
from .generate_image import create_generate_image_tool
|
|
||||||
from .knowledge_base import (
|
from .knowledge_base import (
|
||||||
CONNECTOR_DESCRIPTIONS,
|
CONNECTOR_DESCRIPTIONS,
|
||||||
format_documents_for_context,
|
format_documents_for_context,
|
||||||
search_knowledge_base_async,
|
search_knowledge_base_async,
|
||||||
)
|
)
|
||||||
from .catalog import TOOL_CATALOG, ToolMetadata
|
|
||||||
from .podcast import create_generate_podcast_tool
|
|
||||||
from .video_presentation import create_generate_video_presentation_tool
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
# Tool catalog (display metadata)
|
# Tool catalog (display metadata)
|
||||||
"TOOL_CATALOG",
|
"TOOL_CATALOG",
|
||||||
|
"ToolMetadata",
|
||||||
# Knowledge base utilities
|
# Knowledge base utilities
|
||||||
"CONNECTOR_DESCRIPTIONS",
|
"CONNECTOR_DESCRIPTIONS",
|
||||||
"ToolMetadata",
|
|
||||||
# Tool factories
|
|
||||||
"create_generate_image_tool",
|
|
||||||
"create_generate_podcast_tool",
|
|
||||||
"create_generate_video_presentation_tool",
|
|
||||||
"format_documents_for_context",
|
"format_documents_for_context",
|
||||||
"search_knowledge_base_async",
|
"search_knowledge_base_async",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -1,280 +0,0 @@
|
||||||
"""
|
|
||||||
Image generation tool for the SurfSense agent.
|
|
||||||
|
|
||||||
This module provides a tool that generates images using litellm.aimage_generation()
|
|
||||||
and returns the result directly in a format the frontend Image component can render.
|
|
||||||
|
|
||||||
Config resolution:
|
|
||||||
1. Uses the search space's image_generation_config_id preference
|
|
||||||
2. Falls back to Auto mode (router load balancing) if available
|
|
||||||
3. Supports global YAML configs (negative IDs) and user DB configs (positive IDs)
|
|
||||||
"""
|
|
||||||
|
|
||||||
import hashlib
|
|
||||||
import logging
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from langchain_core.tools import tool
|
|
||||||
from litellm import aimage_generation
|
|
||||||
from sqlalchemy import select
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
|
|
||||||
from app.config import config
|
|
||||||
from app.db import (
|
|
||||||
ImageGeneration,
|
|
||||||
ImageGenerationConfig,
|
|
||||||
SearchSpace,
|
|
||||||
shielded_async_session,
|
|
||||||
)
|
|
||||||
from app.services.image_gen_router_service import (
|
|
||||||
IMAGE_GEN_AUTO_MODE_ID,
|
|
||||||
ImageGenRouterService,
|
|
||||||
is_image_gen_auto_mode,
|
|
||||||
)
|
|
||||||
from app.services.provider_api_base import resolve_api_base
|
|
||||||
from app.utils.signed_image_urls import generate_image_token
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# Provider mapping (same as routes)
|
|
||||||
_PROVIDER_MAP = {
|
|
||||||
"OPENAI": "openai",
|
|
||||||
"AZURE_OPENAI": "azure",
|
|
||||||
"GOOGLE": "gemini",
|
|
||||||
"VERTEX_AI": "vertex_ai",
|
|
||||||
"BEDROCK": "bedrock",
|
|
||||||
"RECRAFT": "recraft",
|
|
||||||
"OPENROUTER": "openrouter",
|
|
||||||
"XINFERENCE": "xinference",
|
|
||||||
"NSCALE": "nscale",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def _resolve_provider_prefix(provider: str, custom_provider: str | None) -> str:
|
|
||||||
if custom_provider:
|
|
||||||
return custom_provider
|
|
||||||
return _PROVIDER_MAP.get(provider.upper(), provider.lower())
|
|
||||||
|
|
||||||
|
|
||||||
def _build_model_string(
    provider: str, model_name: str, custom_provider: str | None
) -> str:
    """Compose the ``<prefix>/<model_name>`` model identifier.

    The prefix comes from ``_resolve_provider_prefix``: ``custom_provider``
    when set, otherwise the prefix derived from ``provider``.
    """
    return f"{_resolve_provider_prefix(provider, custom_provider)}/{model_name}"
|
|
||||||
|
|
||||||
|
|
||||||
def _get_global_image_gen_config(config_id: int) -> dict | None:
    """Look up a global image-generation config by its (negative) ID.

    Scans ``config.GLOBAL_IMAGE_GEN_CONFIGS`` in order and returns the first
    entry whose ``"id"`` equals ``config_id``, or ``None`` when none matches.
    """
    matches = (
        entry
        for entry in config.GLOBAL_IMAGE_GEN_CONFIGS
        if entry.get("id") == config_id
    )
    return next(matches, None)
|
|
||||||
|
|
||||||
|
|
||||||
def create_generate_image_tool(
    search_space_id: int,
    db_session: AsyncSession,
):
    """
    Factory function to create the generate_image tool.

    Args:
        search_space_id: The search space ID (for config resolution)
        db_session: Reserved for compatibility with the tool registry.
            The streaming task's ``AsyncSession`` is shared by every tool;
            because AsyncSession is not concurrency-safe, parallel tool calls
            would interleave flushes (e.g. podcast + image in the same step)
            and poison the transaction. This tool opens its own session.

    Returns:
        The ``generate_image`` coroutine wrapped by ``@tool``. On success the
        coroutine returns a dict describing the first generated image
        (``id``, ``src``, ``alt``, ``prompt``, ``image_count``, ...). Every
        failure path returns a dict with an ``error`` key instead of raising.
    """
    del db_session  # use a fresh per-call session, see below

    # NOTE(review): the docstring below is presumably consumed by ``@tool`` as
    # the tool description shown to the model, so treat its text as runtime
    # data rather than as ordinary documentation.
    @tool
    async def generate_image(
        prompt: str,
        n: int = 1,
    ) -> dict[str, Any]:
        """
        Generate an image from a text description using AI image models.

        Use this tool when the user asks you to create, generate, draw, or make an image.
        The generated image will be displayed directly in the chat.

        Args:
            prompt: A detailed text description of the image to generate.
                    Be specific about subject, style, colors, composition, and mood.
            n: Number of images to generate (1-4). Default: 1

        Returns:
            A dictionary containing the generated image(s) for display in the chat.
        """
        try:
            # Use a per-call session so concurrent tool calls don't share an
            # AsyncSession (which is not concurrency-safe). The streaming
            # task's session is shared across every tool; without isolation,
            # autoflushes from a concurrent writer poison this tool too.
            async with shielded_async_session() as session:
                result = await session.execute(
                    select(SearchSpace).filter(SearchSpace.id == search_space_id)
                )
                search_space = result.scalars().first()
                if not search_space:
                    return {"error": "Search space not found"}

                # A falsy preference (unset or 0) falls back to Auto mode.
                config_id = (
                    search_space.image_generation_config_id or IMAGE_GEN_AUTO_MODE_ID
                )

                # Build generation kwargs
                # NOTE: size, quality, and style are intentionally NOT passed.
                # Different models support different values for these params
                # (e.g. DALL-E 3 wants "hd"/"standard" for quality while
                # gpt-image-1 wants "high"/"medium"/"low"; size options also
                # differ). Letting the model use its own defaults avoids errors.
                gen_kwargs: dict[str, Any] = {}
                # NOTE(review): the docstring advertises 1-4, but no upper
                # bound is enforced here; any n > 1 is forwarded unchanged.
                if n is not None and n > 1:
                    gen_kwargs["n"] = n

                # Call litellm based on config type
                if is_image_gen_auto_mode(config_id):
                    if not ImageGenRouterService.is_initialized():
                        return {
                            "error": "No image generation models configured. "
                            "Please add an image model in Settings > Image Models."
                        }
                    response = await ImageGenRouterService.aimage_generation(
                        prompt=prompt, model="auto", **gen_kwargs
                    )
                elif config_id < 0:
                    # Negative ID = global config resolved from app config.
                    cfg = _get_global_image_gen_config(config_id)
                    if not cfg:
                        return {
                            "error": f"Image generation config {config_id} not found"
                        }

                    provider_prefix = _resolve_provider_prefix(
                        cfg.get("provider", ""), cfg.get("custom_provider")
                    )
                    model_string = f"{provider_prefix}/{cfg['model_name']}"
                    gen_kwargs["api_key"] = cfg.get("api_key")
                    # ``api_base`` is only forwarded when the resolver yields
                    # a truthy value; otherwise the key is left out entirely.
                    api_base = resolve_api_base(
                        provider=cfg.get("provider"),
                        provider_prefix=provider_prefix,
                        config_api_base=cfg.get("api_base"),
                    )
                    if api_base:
                        gen_kwargs["api_base"] = api_base
                    if cfg.get("api_version"):
                        gen_kwargs["api_version"] = cfg["api_version"]
                    # Applied last, so keys in ``litellm_params`` override any
                    # of the values set above.
                    if cfg.get("litellm_params"):
                        gen_kwargs.update(cfg["litellm_params"])

                    response = await aimage_generation(
                        prompt=prompt, model=model_string, **gen_kwargs
                    )
                else:
                    # Positive ID = user-created ImageGenerationConfig
                    cfg_result = await session.execute(
                        select(ImageGenerationConfig).filter(
                            ImageGenerationConfig.id == config_id
                        )
                    )
                    db_cfg = cfg_result.scalars().first()
                    if not db_cfg:
                        return {
                            "error": f"Image generation config {config_id} not found"
                        }

                    provider_prefix = _resolve_provider_prefix(
                        db_cfg.provider.value, db_cfg.custom_provider
                    )
                    model_string = f"{provider_prefix}/{db_cfg.model_name}"
                    gen_kwargs["api_key"] = db_cfg.api_key
                    # Same api_base handling as the global-config branch.
                    api_base = resolve_api_base(
                        provider=db_cfg.provider.value,
                        provider_prefix=provider_prefix,
                        config_api_base=db_cfg.api_base,
                    )
                    if api_base:
                        gen_kwargs["api_base"] = api_base
                    if db_cfg.api_version:
                        gen_kwargs["api_version"] = db_cfg.api_version
                    # Applied last, so these can override the keys set above.
                    if db_cfg.litellm_params:
                        gen_kwargs.update(db_cfg.litellm_params)

                    response = await aimage_generation(
                        prompt=prompt, model=model_string, **gen_kwargs
                    )

                # Parse the response and store in DB
                response_dict = (
                    response.model_dump()
                    if hasattr(response, "model_dump")
                    else dict(response)
                )

                # Generate a random access token for this image
                access_token = generate_image_token()

                # Save to image_generations table for history
                # The row is committed before the response is checked for
                # images, so it is stored even when an error is returned below.
                db_image_gen = ImageGeneration(
                    prompt=prompt,
                    model=getattr(response, "_hidden_params", {}).get("model"),
                    n=n,
                    image_generation_config_id=config_id,
                    response_data=response_dict,
                    search_space_id=search_space_id,
                    access_token=access_token,
                )
                session.add(db_image_gen)
                await session.commit()
                await session.refresh(db_image_gen)
                # Copy the primary key into a plain local for use after this
                # point without touching the ORM object again.
                db_image_gen_id = db_image_gen.id

            # Extract image URLs from response
            images = response_dict.get("data", [])
            if not images:
                return {"error": "No images were generated"}

            # Only the first image is surfaced; ``image_count`` reports the rest.
            first_image = images[0]
            revised_prompt = first_image.get("revised_prompt", prompt)

            # Resolve image URL:
            # - If the API returned a URL, use it directly.
            # - If the API returned b64_json (e.g. gpt-image-1), serve the
            #   image through our backend endpoint to avoid bloating the
            #   LLM context with megabytes of base64 data.
            if first_image.get("url"):
                image_url = first_image["url"]
            elif first_image.get("b64_json"):
                backend_url = config.BACKEND_URL or "http://localhost:8000"
                image_url = (
                    f"{backend_url}/api/v1/image-generations/"
                    f"{db_image_gen_id}/image?token={access_token}"
                )
            else:
                return {"error": "No displayable image data in the response"}

            # Short identifier derived from the URL (first 12 hex chars of its
            # MD5 digest).
            image_id = f"image-{hashlib.md5(image_url.encode()).hexdigest()[:12]}"

            return {
                "id": image_id,
                "assetId": image_url,
                "src": image_url,
                "alt": revised_prompt or prompt,
                "title": "Generated Image",
                "description": revised_prompt if revised_prompt != prompt else None,
                "domain": "ai-generated",
                "ratio": "auto",
                "generated": True,
                "prompt": prompt,
                "image_count": len(images),
            }

        except Exception as e:
            # Tool boundary: log with traceback and return an error payload
            # instead of propagating the exception to the caller.
            logger.exception("Image generation failed in tool")
            return {
                "error": f"Image generation failed: {e!s}",
                "prompt": prompt,
            }

    return generate_image
|
|
||||||
|
|
@ -1,160 +0,0 @@
|
||||||
"""
|
|
||||||
Podcast generation tool for the SurfSense agent.
|
|
||||||
|
|
||||||
This module provides a factory function for creating the generate_podcast tool
|
|
||||||
that submits a Celery task for background podcast generation. The tool then
|
|
||||||
polls the podcast row until it reaches a terminal status (READY/FAILED) and
|
|
||||||
returns that status. The wait is bounded by the chat's HTTP / process
|
|
||||||
lifetime; see app.agents.shared.deliverable_wait for details.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from langchain_core.tools import tool
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
|
|
||||||
from app.agents.shared.deliverable_wait import wait_for_deliverable
|
|
||||||
from app.db import Podcast, PodcastStatus, shielded_async_session
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def create_generate_podcast_tool(
    search_space_id: int,
    db_session: AsyncSession,
    thread_id: int | None = None,
):
    """
    Factory function to create the generate_podcast tool with injected dependencies.

    Pre-creates podcast record with pending status so podcast_id is available
    immediately for frontend polling.

    Args:
        search_space_id: The user's search space ID
        db_session: Reserved for future read-side use; the row is written via a
            fresh, tool-local session so parallel tool calls (e.g. podcast +
            video presentation in the same agent step) don't share an
            ``AsyncSession`` (which is not concurrency-safe).
        thread_id: The chat thread ID for associating the podcast

    Returns:
        A configured tool function for generating podcasts
    """
    del db_session  # writes use a fresh tool-local session, see below

    # NOTE(review): the docstring below is presumably consumed by ``@tool`` as
    # the tool description, so its text is left untouched. Its "Returns"
    # section lists pending/generating/failed, but the code below only ever
    # returns READY or FAILED; confirm and update the description.
    @tool
    async def generate_podcast(
        source_content: str,
        podcast_title: str = "SurfSense Podcast",
        user_prompt: str | None = None,
    ) -> dict[str, Any]:
        """
        Generate a podcast from the provided content.

        Use this tool when the user asks to create, generate, or make a podcast.
        Common triggers include phrases like:
        - "Give me a podcast about this"
        - "Create a podcast from this conversation"
        - "Generate a podcast summary"
        - "Make a podcast about..."
        - "Turn this into a podcast"

        Args:
            source_content: The text content to convert into a podcast.
            podcast_title: Title for the podcast (default: "SurfSense Podcast")
            user_prompt: Optional instructions for podcast style, tone, or format.

        Returns:
            A dictionary containing:
            - status: PodcastStatus value (pending, generating, or failed)
            - podcast_id: The podcast ID for polling (when status is pending or generating)
            - title: The podcast title
            - message: Status message (or "error" field if status is failed)
        """
        try:
            # Open a fresh session per call. The streaming task's session is
            # shared between every tool, and ``AsyncSession`` is NOT safe for
            # concurrent use: when the LLM emits parallel tool calls, two
            # concurrent ``add()`` / ``commit()`` paths interleave and the
            # second one hits "Session.add() during flush" → the transaction
            # is poisoned for both tools.
            async with shielded_async_session() as session:
                podcast = Podcast(
                    title=podcast_title,
                    status=PodcastStatus.PENDING,
                    search_space_id=search_space_id,
                    thread_id=thread_id,
                )
                session.add(podcast)
                await session.commit()
                await session.refresh(podcast)
                # Copy the primary key into a plain local for use after this
                # point without touching the ORM object again.
                podcast_id = podcast.id

            # Imported at call time rather than at module import.
            # NOTE(review): presumably to avoid a circular import with the
            # Celery task module; confirm.
            from app.tasks.celery_tasks.podcast_tasks import (
                generate_content_podcast_task,
            )

            # Hand the generation work to a Celery worker; only the row id and
            # the raw inputs are passed.
            task = generate_content_podcast_task.delay(
                podcast_id=podcast_id,
                source_content=source_content,
                search_space_id=search_space_id,
                user_prompt=user_prompt,
            )

            logger.info(
                "[generate_podcast] Created podcast %s, task: %s",
                podcast_id,
                task.id,
            )

            # Wait until the Celery worker flips the row to a terminal
            # state. No internal budget — see deliverable_wait module.
            terminal_status, columns, elapsed = await wait_for_deliverable(
                model=Podcast,
                row_id=podcast_id,
                columns=[Podcast.status, Podcast.file_location],
                terminal_statuses={PodcastStatus.READY, PodcastStatus.FAILED},
            )

            if terminal_status == PodcastStatus.READY:
                # ``columns`` follows the order requested above, so index 1 is
                # ``file_location``; guard against an empty result.
                file_location = columns[1] if columns else None
                logger.info(
                    "[generate_podcast] Podcast %s READY in %.2fs (file=%s)",
                    podcast_id,
                    elapsed,
                    file_location,
                )
                return {
                    "status": PodcastStatus.READY.value,
                    "podcast_id": podcast_id,
                    "title": podcast_title,
                    "file_location": file_location,
                    "message": ("Podcast generated and saved to your podcast panel."),
                }

            # Only other terminal state is FAILED.
            logger.warning(
                "[generate_podcast] Podcast %s FAILED in %.2fs",
                podcast_id,
                elapsed,
            )
            return {
                "status": PodcastStatus.FAILED.value,
                "podcast_id": podcast_id,
                "title": podcast_title,
                "error": ("Background worker reported FAILED status for this podcast."),
            }

        except Exception as e:
            # Tool boundary: log with traceback and return a FAILED payload
            # instead of propagating. ``podcast_id`` is reported as None here
            # even if the row was already created before the failure.
            error_message = str(e)
            logger.exception("[generate_podcast] Error: %s", error_message)
            return {
                "status": PodcastStatus.FAILED.value,
                "error": error_message,
                "title": podcast_title,
                "podcast_id": None,
            }

    return generate_podcast
|
|
||||||
File diff suppressed because it is too large
Load diff
|
|
@ -1,812 +0,0 @@
|
||||||
"""
|
|
||||||
Resume generation tool for the SurfSense agent.
|
|
||||||
|
|
||||||
Generates a structured resume as Typst source code using the rendercv package.
|
|
||||||
The LLM outputs only the content body (= heading, sections, entries) while
|
|
||||||
the template header (import + show rule) is hardcoded and prepended by the
|
|
||||||
backend. This eliminates LLM errors in the complex configuration block.
|
|
||||||
|
|
||||||
Templates are stored in a registry so new designs can be added by defining
|
|
||||||
a new entry in _TEMPLATES.
|
|
||||||
|
|
||||||
Uses the same short-lived session pattern as generate_report so no DB
|
|
||||||
connection is held during the long LLM call.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import io
|
|
||||||
import logging
|
|
||||||
import re
|
|
||||||
from datetime import UTC, datetime
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import pypdf
|
|
||||||
import typst
|
|
||||||
from langchain_core.callbacks import dispatch_custom_event
|
|
||||||
from langchain_core.messages import HumanMessage
|
|
||||||
from langchain_core.tools import tool
|
|
||||||
|
|
||||||
from app.db import Report, shielded_async_session
|
|
||||||
from app.services.llm_service import get_document_summary_llm
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
# ─── Template Registry ───────────────────────────────────────────────────────
|
|
||||||
# Each template defines:
|
|
||||||
# header - Typst import + show rule with {name}, {year}, {month}, {day} placeholders
|
|
||||||
# component_reference - component docs shown to the LLM
|
|
||||||
# rules - generation rules for the LLM
|
|
||||||
|
|
||||||
_TEMPLATES: dict[str, dict[str, str]] = {
|
|
||||||
"classic": {
|
|
||||||
"header": """\
|
|
||||||
#import "@preview/rendercv:0.3.0": *
|
|
||||||
|
|
||||||
#show: rendercv.with(
|
|
||||||
name: "{name}",
|
|
||||||
title: "{name} - Resume",
|
|
||||||
footer: context {{ [#emph[{name} -- #str(here().page())\\/#str(counter(page).final().first())]] }},
|
|
||||||
top-note: [ #emph[Last updated in {month_name} {year}] ],
|
|
||||||
locale-catalog-language: "en",
|
|
||||||
text-direction: ltr,
|
|
||||||
page-size: "us-letter",
|
|
||||||
page-top-margin: 0.7in,
|
|
||||||
page-bottom-margin: 0.7in,
|
|
||||||
page-left-margin: 0.7in,
|
|
||||||
page-right-margin: 0.7in,
|
|
||||||
page-show-footer: false,
|
|
||||||
page-show-top-note: true,
|
|
||||||
colors-body: rgb(0, 0, 0),
|
|
||||||
colors-name: rgb(0, 0, 0),
|
|
||||||
colors-headline: rgb(0, 0, 0),
|
|
||||||
colors-connections: rgb(0, 0, 0),
|
|
||||||
colors-section-titles: rgb(0, 0, 0),
|
|
||||||
colors-links: rgb(0, 0, 0),
|
|
||||||
colors-footer: rgb(128, 128, 128),
|
|
||||||
colors-top-note: rgb(128, 128, 128),
|
|
||||||
typography-line-spacing: 0.6em,
|
|
||||||
typography-alignment: "justified",
|
|
||||||
typography-date-and-location-column-alignment: right,
|
|
||||||
typography-font-family-body: "XCharter",
|
|
||||||
typography-font-family-name: "XCharter",
|
|
||||||
typography-font-family-headline: "XCharter",
|
|
||||||
typography-font-family-connections: "XCharter",
|
|
||||||
typography-font-family-section-titles: "XCharter",
|
|
||||||
typography-font-size-body: 10pt,
|
|
||||||
typography-font-size-name: 25pt,
|
|
||||||
typography-font-size-headline: 10pt,
|
|
||||||
typography-font-size-connections: 10pt,
|
|
||||||
typography-font-size-section-titles: 1.2em,
|
|
||||||
typography-small-caps-name: false,
|
|
||||||
typography-small-caps-headline: false,
|
|
||||||
typography-small-caps-connections: false,
|
|
||||||
typography-small-caps-section-titles: false,
|
|
||||||
typography-bold-name: false,
|
|
||||||
typography-bold-headline: false,
|
|
||||||
typography-bold-connections: false,
|
|
||||||
typography-bold-section-titles: true,
|
|
||||||
links-underline: true,
|
|
||||||
links-show-external-link-icon: false,
|
|
||||||
header-alignment: center,
|
|
||||||
header-photo-width: 3.5cm,
|
|
||||||
header-space-below-name: 0.7cm,
|
|
||||||
header-space-below-headline: 0.7cm,
|
|
||||||
header-space-below-connections: 0.7cm,
|
|
||||||
header-connections-hyperlink: true,
|
|
||||||
header-connections-show-icons: false,
|
|
||||||
header-connections-display-urls-instead-of-usernames: true,
|
|
||||||
header-connections-separator: "|",
|
|
||||||
header-connections-space-between-connections: 0.5cm,
|
|
||||||
section-titles-type: "with_full_line",
|
|
||||||
section-titles-line-thickness: 0.5pt,
|
|
||||||
section-titles-space-above: 0.5cm,
|
|
||||||
section-titles-space-below: 0.3cm,
|
|
||||||
sections-allow-page-break: true,
|
|
||||||
sections-space-between-text-based-entries: 0.15cm,
|
|
||||||
sections-space-between-regular-entries: 0.42cm,
|
|
||||||
entries-date-and-location-width: 4.15cm,
|
|
||||||
entries-side-space: 0cm,
|
|
||||||
entries-space-between-columns: 0.1cm,
|
|
||||||
entries-allow-page-break: false,
|
|
||||||
entries-short-second-row: false,
|
|
||||||
entries-degree-width: 1cm,
|
|
||||||
entries-summary-space-left: 0cm,
|
|
||||||
entries-summary-space-above: 0.08cm,
|
|
||||||
entries-highlights-bullet: text(13pt, [\\u{2022}], baseline: -0.6pt),
|
|
||||||
entries-highlights-nested-bullet: text(13pt, [\\u{2022}], baseline: -0.6pt),
|
|
||||||
entries-highlights-space-left: 0cm,
|
|
||||||
entries-highlights-space-above: 0.08cm,
|
|
||||||
entries-highlights-space-between-items: 0.02cm,
|
|
||||||
entries-highlights-space-between-bullet-and-text: 0.3em,
|
|
||||||
date: datetime(
|
|
||||||
year: {year},
|
|
||||||
month: {month},
|
|
||||||
day: {day},
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
""",
|
|
||||||
"component_reference": """\
|
|
||||||
Available components (use ONLY these):
|
|
||||||
|
|
||||||
= Full Name // Top-level heading — person's full name
|
|
||||||
|
|
||||||
#connections( // Contact info row (pipe-separated)
|
|
||||||
[City, Country],
|
|
||||||
[#link("mailto:email@example.com", icon: false, if-underline: false, if-color: false)[email\\@example.com]],
|
|
||||||
[#link("https://linkedin.com/in/user", icon: false, if-underline: false, if-color: false)[linkedin.com\\/in\\/user]],
|
|
||||||
[#link("https://github.com/user", icon: false, if-underline: false, if-color: false)[github.com\\/user]],
|
|
||||||
)
|
|
||||||
|
|
||||||
== Section Title // Section heading (arbitrary name)
|
|
||||||
|
|
||||||
#regular-entry( // Work experience, projects, publications, etc.
|
|
||||||
[
|
|
||||||
#strong[Role/Title], Company Name -- Location
|
|
||||||
],
|
|
||||||
[
|
|
||||||
Start -- End
|
|
||||||
],
|
|
||||||
main-column-second-row: [
|
|
||||||
- Achievement or responsibility
|
|
||||||
- Another bullet point
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
#education-entry( // Education entries
|
|
||||||
[
|
|
||||||
#strong[Institution], Degree in Field -- Location
|
|
||||||
],
|
|
||||||
[
|
|
||||||
Start -- End
|
|
||||||
],
|
|
||||||
main-column-second-row: [
|
|
||||||
- GPA, honours, relevant coursework
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
#summary([Short paragraph summary]) // Optional summary inside an entry
|
|
||||||
#content-area([Free-form content]) // Freeform text block
|
|
||||||
|
|
||||||
For skills sections, use one bullet per category label:
|
|
||||||
- #strong[Category:] item1, item2, item3
|
|
||||||
|
|
||||||
For simple list sections (e.g. Honors), use plain bullet points:
|
|
||||||
- Item one
|
|
||||||
- Item two
|
|
||||||
""",
|
|
||||||
"rules": """\
|
|
||||||
RULES:
|
|
||||||
- Do NOT include any #import or #show lines. Start directly with = Full Name.
|
|
||||||
- Output ONLY valid Typst content. No explanatory text before or after.
|
|
||||||
- Do NOT wrap output in ```typst code fences.
|
|
||||||
- The = heading MUST use the person's COMPLETE full name exactly as provided. NEVER shorten or abbreviate.
|
|
||||||
- Escape @ symbols inside link labels with a backslash: email\\@example.com
|
|
||||||
- Escape forward slashes in link display text: linkedin.com\\/in\\/user
|
|
||||||
- Every section MUST use == heading.
|
|
||||||
- Use #regular-entry() for experience, projects, publications, certifications, and similar entries.
|
|
||||||
- Use #education-entry() for education.
|
|
||||||
- For skills sections, use one bullet line per category with a bold label.
|
|
||||||
- Keep content professional, concise, and achievement-oriented.
|
|
||||||
- Use action verbs for bullet points (Led, Built, Designed, Reduced, etc.).
|
|
||||||
- This template works for ALL professions — adapt sections to the user's field.
|
|
||||||
- Default behavior should prioritize concise one-page content.
|
|
||||||
""",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
# Template ID used when the caller does not name one (see _get_template).
DEFAULT_TEMPLATE = "classic"
# Inclusive bounds accepted for the ``max_pages`` tool argument; the upper
# bound is also the hard page cap checked after compression.
MIN_RESUME_PAGES = 1
MAX_RESUME_PAGES = 5
# How many LLM compression passes may follow the initial compile when the
# rendered PDF is longer than the requested page target.
MAX_COMPRESSION_ATTEMPTS = 2
|
|
||||||
|
|
||||||
|
|
||||||
# ─── Template Helpers ─────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def _get_template(template_id: str | None = None) -> dict[str, str]:
    """Look up a resume template by ID.

    A missing/empty ``template_id`` selects ``DEFAULT_TEMPLATE``; an ID that
    is not registered in ``_TEMPLATES`` also resolves to the default template.
    """
    fallback = _TEMPLATES[DEFAULT_TEMPLATE]
    wanted = template_id if template_id else DEFAULT_TEMPLATE
    return _TEMPLATES.get(wanted, fallback)
|
|
||||||
|
|
||||||
|
|
||||||
_MONTH_NAMES = [
|
|
||||||
"",
|
|
||||||
"Jan",
|
|
||||||
"Feb",
|
|
||||||
"Mar",
|
|
||||||
"Apr",
|
|
||||||
"May",
|
|
||||||
"Jun",
|
|
||||||
"Jul",
|
|
||||||
"Aug",
|
|
||||||
"Sep",
|
|
||||||
"Oct",
|
|
||||||
"Nov",
|
|
||||||
"Dec",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def _build_header(template: dict[str, str], name: str) -> str:
|
|
||||||
"""Build the template header with the person's name and current date."""
|
|
||||||
now = datetime.now(tz=UTC)
|
|
||||||
return (
|
|
||||||
template["header"]
|
|
||||||
.replace("{name}", name)
|
|
||||||
.replace("{year}", str(now.year))
|
|
||||||
.replace("{month}", str(now.month))
|
|
||||||
.replace("{day}", str(now.day))
|
|
||||||
.replace("{month_name}", _MONTH_NAMES[now.month])
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _strip_header(full_source: str) -> str:
    """Return the resume body that follows the ``#show: rendercv.with(...)`` rule.

    The start of the rule is located with a regex; from there characters are
    scanned while counting parentheses until the one balancing the opening
    ``(`` has been consumed. Everything after that point is returned with
    leading newlines removed. Source without the rule is returned unchanged,
    and an unterminated rule yields an empty string.
    """
    opener = re.search(r"#show:\s*rendercv\.with\(", full_source)
    if opener is None:
        return full_source

    # The regex already consumed the opening parenthesis: one level is open.
    open_parens = 1
    pos = opener.end()
    total = len(full_source)
    while pos < total and open_parens > 0:
        ch = full_source[pos]
        if ch == "(":
            open_parens += 1
        elif ch == ")":
            open_parens -= 1
        pos += 1

    return full_source[pos:].lstrip("\n")
|
|
||||||
|
|
||||||
|
|
||||||
def _extract_name(body: str) -> str | None:
|
|
||||||
"""Extract the person's full name from the = heading in the body."""
|
|
||||||
match = re.search(r"^=\s+(.+)$", body, re.MULTILINE)
|
|
||||||
return match.group(1).strip() if match else None
|
|
||||||
|
|
||||||
|
|
||||||
def _strip_imports(body: str) -> str:
    """Drop template boilerplate the LLM may have echoed into the body.

    Removed are every line starting with ``#import`` and any ``#show:`` rule
    that mentions ``rendercv``, including the continuation lines of a
    multi-line rule (tracked by per-line parenthesis balance). Other
    ``#show:`` lines are kept. The result is stripped of surrounding
    whitespace.
    """

    def _paren_delta(text: str) -> int:
        # Net number of parentheses opened on this line.
        return text.count("(") - text.count(")")

    kept: list[str] = []
    open_parens = 0
    inside_show_rule = False

    for raw_line in body.split("\n"):
        text = raw_line.strip()

        if text.startswith("#import"):
            continue

        if inside_show_rule:
            # Still inside a multi-line rendercv show rule: discard the line
            # and leave the rule once its parentheses are balanced.
            open_parens += _paren_delta(text)
            inside_show_rule = open_parens > 0
            continue

        if text.startswith("#show:") and "rendercv" in text:
            open_parens = _paren_delta(text)
            inside_show_rule = open_parens > 0
            continue

        kept.append(raw_line)

    return "\n".join(kept).strip()
|
|
||||||
|
|
||||||
|
|
||||||
def _build_llm_reference(template: dict[str, str]) -> str:
    """Assemble the Typst reference text embedded in every resume prompt.

    The result is a fixed three-line preamble, a blank line, the template's
    ``component_reference``, another blank line, and the template's ``rules``.
    """
    preamble = (
        "You MUST output valid Typst content for a resume.\n"
        "Do NOT include any #import or #show lines — those are handled automatically.\n"
        "Start directly with the = Full Name heading.\n"
    )
    components = template["component_reference"]
    rules = template["rules"]
    return preamble + "\n" + components + "\n\n" + rules
|
|
||||||
|
|
||||||
|
|
||||||
# ─── Prompts ─────────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
_RESUME_PROMPT = """\
|
|
||||||
You are an expert resume writer. Generate professional resume content as Typst markup.
|
|
||||||
|
|
||||||
{llm_reference}
|
|
||||||
|
|
||||||
**User Information:**
|
|
||||||
{user_info}
|
|
||||||
|
|
||||||
**Target Maximum Pages:** {max_pages}
|
|
||||||
|
|
||||||
{user_instructions_section}
|
|
||||||
|
|
||||||
Generate the resume content now (starting with = Full Name):
|
|
||||||
"""
|
|
||||||
|
|
||||||
_REVISION_PROMPT = """\
|
|
||||||
You are an expert resume editor. Modify the existing resume according to the instructions.
|
|
||||||
Apply ONLY the requested changes — do NOT rewrite sections that are not affected.
|
|
||||||
|
|
||||||
{llm_reference}
|
|
||||||
|
|
||||||
**Target Maximum Pages:** {max_pages}
|
|
||||||
|
|
||||||
**Modification Instructions:** {user_instructions}
|
|
||||||
|
|
||||||
**EXISTING RESUME CONTENT:**
|
|
||||||
|
|
||||||
{previous_content}
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
Output the complete, updated resume content with the changes applied (starting with = Full Name):
|
|
||||||
"""
|
|
||||||
|
|
||||||
_FIX_COMPILE_PROMPT = """\
|
|
||||||
The resume content you generated failed to compile. Fix the error while preserving all content.
|
|
||||||
|
|
||||||
{llm_reference}
|
|
||||||
|
|
||||||
**Compilation Error:**
|
|
||||||
{error}
|
|
||||||
|
|
||||||
**Full Typst Source (for context — error line numbers refer to this):**
|
|
||||||
{full_source}
|
|
||||||
|
|
||||||
**Your content starts after the template header. Output ONLY the content portion \
|
|
||||||
(starting with = Full Name), NOT the #import or #show rule:**
|
|
||||||
"""
|
|
||||||
|
|
||||||
_COMPRESS_TO_PAGE_LIMIT_PROMPT = """\
|
|
||||||
The resume compiles, but it exceeds the maximum allowed page count.
|
|
||||||
Compress the resume while preserving high-impact accomplishments and role relevance.
|
|
||||||
|
|
||||||
{llm_reference}
|
|
||||||
|
|
||||||
**Target Maximum Pages:** {max_pages}
|
|
||||||
**Current Page Count:** {actual_pages}
|
|
||||||
**Compression Attempt:** {attempt_number}
|
|
||||||
|
|
||||||
Compression priorities (in this order):
|
|
||||||
1) Keep recent, high-impact, role-relevant bullets.
|
|
||||||
2) Remove low-impact or redundant bullets.
|
|
||||||
3) Shorten verbose wording while preserving meaning.
|
|
||||||
4) Trim older or less relevant details before recent ones.
|
|
||||||
|
|
||||||
Return the complete updated Typst content (starting with = Full Name), and keep it at or below the target pages.
|
|
||||||
|
|
||||||
**EXISTING RESUME CONTENT:**
|
|
||||||
{previous_content}
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
# ─── Helpers ─────────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def _strip_typst_fences(text: str) -> str:
    """Unwrap LLM output that was enclosed in a Markdown code fence.

    The input is whitespace-stripped first. If it then opens with a fence of
    three or more backticks (optionally tagged ``typst`` or ``typ``) and ends
    with the same run of backticks, both fences are removed and trailing
    whitespace is trimmed. Otherwise the stripped text is returned as is.
    """
    content = text.strip()
    opening = re.match(r"^(`{3,})(?:typst|typ)?\s*\n", content)
    if opening is None:
        return content

    ticks = opening.group(1)
    if not content.endswith(ticks):
        # Opening fence without a matching closing fence: leave untouched.
        return content

    remainder = content[opening.end() :]
    return remainder[: -len(ticks)].rstrip()
|
|
||||||
|
|
||||||
|
|
||||||
def _compile_typst(source: str) -> bytes:
    """Compile Typst source to PDF bytes. Raises on failure.

    The source is UTF-8 encoded and handed to ``typst.compile``; whatever
    that call returns or raises is passed straight through to the caller.
    """
    return typst.compile(source.encode("utf-8"))
|
|
||||||
|
|
||||||
|
|
||||||
def _count_pdf_pages(pdf_bytes: bytes) -> int:
    """Return how many pages the given in-memory PDF contains.

    The bytes are wrapped in an ``io.BytesIO`` stream and parsed with
    ``pypdf.PdfReader``; the stream is closed before returning.
    """
    buffer = io.BytesIO(pdf_bytes)
    try:
        # Read the page count while the stream is still open.
        return len(pypdf.PdfReader(buffer).pages)
    finally:
        buffer.close()
|
|
||||||
|
|
||||||
|
|
||||||
def _validate_max_pages(max_pages: int) -> int:
    """Check that ``max_pages`` lies within the allowed page range.

    Returns:
        ``max_pages`` unchanged when it is between ``MIN_RESUME_PAGES`` and
        ``MAX_RESUME_PAGES`` (both inclusive).

    Raises:
        ValueError: If the value is outside that range.
    """
    if not (MIN_RESUME_PAGES <= max_pages <= MAX_RESUME_PAGES):
        raise ValueError(
            f"max_pages must be between {MIN_RESUME_PAGES} and "
            f"{MAX_RESUME_PAGES}. Received: {max_pages}"
        )
    return max_pages
|
|
||||||
|
|
||||||
|
|
||||||
# ─── Tool Factory ───────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def create_generate_resume_tool(
    search_space_id: int,
    thread_id: int | None = None,
):
    """
    Factory function to create the generate_resume tool.

    Generates a Typst-based resume, validates it via compilation,
    and stores the source in the Report table with content_type='typst'.
    The LLM generates only the content body; the template header is
    prepended by the backend.

    Args:
        search_space_id: Stored on every Report row the tool writes and
            passed to the LLM lookup.
        thread_id: Stored on every Report row the tool writes.

    Returns:
        The ``generate_resume`` coroutine wrapped by ``@tool``.
    """

    @tool
    async def generate_resume(
        user_info: str,
        user_instructions: str | None = None,
        parent_report_id: int | None = None,
        max_pages: int = 1,
    ) -> dict[str, Any]:
        """
        Generate a professional resume as a Typst document.

        Use this tool when the user asks to create, build, generate, write,
        or draft a resume or CV. Also use it when the user wants to modify,
        update, or revise an existing resume generated in this conversation.

        Trigger phrases include:
        - "build me a resume", "create my resume", "generate a CV"
        - "update my resume", "change my title", "add my new job"
        - "make my resume more concise", "reformat my resume"

        Do NOT use this tool for:
        - General questions about resumes or career advice
        - Reviewing or critiquing a resume without changes
        - Cover letters (use generate_report instead)

        VERSIONING — parent_report_id:
        - Set parent_report_id when the user wants to MODIFY an existing
          resume that was already generated in this conversation.
        - Leave as None for new resumes.

        Args:
            user_info: The user's resume content — work experience,
                education, skills, contact info, etc. Can be structured
                or unstructured text.
            user_instructions: Optional style or content preferences
                (e.g. "emphasize leadership", "keep it to one page",
                "use a modern style"). For revisions, describe what to change.
            parent_report_id: ID of a previous resume to revise (creates
                new version in the same version group).
            max_pages: Maximum number of pages for the generated resume.
                Defaults to 1. Allowed range: 1-5.

        Returns:
            Dict with status, report_id, title, and content_type.
        """
        # Both stay None for a brand-new resume; they are filled from the
        # parent report when revising. _save_failed_report reads
        # report_group_id through its closure.
        report_group_id: int | None = None
        parent_content: str | None = None

        template = _get_template()
        llm_reference = _build_llm_reference(template)

        async def _save_failed_report(error_msg: str) -> int | None:
            # Persist a Report row with status "failed" in its own session and
            # return its id; returns None when saving itself fails, so error
            # reporting never raises.
            try:
                async with shielded_async_session() as session:
                    failed = Report(
                        title="Resume",
                        content=None,
                        content_type="typst",
                        report_metadata={
                            "status": "failed",
                            "error_message": error_msg,
                        },
                        report_style="resume",
                        search_space_id=search_space_id,
                        thread_id=thread_id,
                        report_group_id=report_group_id,
                    )
                    session.add(failed)
                    await session.commit()
                    await session.refresh(failed)
                    # No group yet: the row starts its own group, keyed by its id.
                    if not failed.report_group_id:
                        failed.report_group_id = failed.id
                        await session.commit()
                    logger.info(
                        f"[generate_resume] Saved failed report {failed.id}: {error_msg}"
                    )
                    return failed.id
            except Exception:
                logger.exception(
                    "[generate_resume] Could not persist failed report row"
                )
                return None

        try:
            # Reject an out-of-range max_pages before any LLM work.
            try:
                validated_max_pages = _validate_max_pages(max_pages)
            except ValueError as e:
                error_msg = str(e)
                report_id = await _save_failed_report(error_msg)
                return {
                    "status": "failed",
                    "error": error_msg,
                    "report_id": report_id,
                    "title": "Resume",
                    "content_type": "typst",
                }

            # ── Phase 1: READ ─────────────────────────────────────────────
            async with shielded_async_session() as read_session:
                if parent_report_id:
                    parent_report = await read_session.get(Report, parent_report_id)
                    # A missing parent row is not an error: parent_content stays
                    # None and generation takes the new-resume path below.
                    if parent_report:
                        report_group_id = parent_report.report_group_id
                        parent_content = parent_report.content
                        logger.info(
                            f"[generate_resume] Revising from parent {parent_report_id} "
                            f"(group {report_group_id})"
                        )

                llm = await get_document_summary_llm(read_session, search_space_id)

            if not llm:
                error_msg = (
                    "No LLM configured. Please configure a language model in Settings."
                )
                report_id = await _save_failed_report(error_msg)
                return {
                    "status": "failed",
                    "error": error_msg,
                    "report_id": report_id,
                    "title": "Resume",
                    "content_type": "typst",
                }

            # ── Phase 2: LLM GENERATION ───────────────────────────────────

            user_instructions_section = ""
            if user_instructions:
                user_instructions_section = (
                    f"**Additional Instructions:** {user_instructions}"
                )

            if parent_content:
                dispatch_custom_event(
                    "report_progress",
                    {"phase": "writing", "message": "Updating your resume"},
                )
                # The stored source includes the template header; the LLM only
                # ever sees and returns the body.
                parent_body = _strip_header(parent_content)
                prompt = _REVISION_PROMPT.format(
                    llm_reference=llm_reference,
                    max_pages=validated_max_pages,
                    user_instructions=user_instructions
                    or "Improve and refine the resume.",
                    previous_content=parent_body,
                )
            else:
                dispatch_custom_event(
                    "report_progress",
                    {"phase": "writing", "message": "Building your resume"},
                )
                prompt = _RESUME_PROMPT.format(
                    llm_reference=llm_reference,
                    user_info=user_info,
                    max_pages=validated_max_pages,
                    user_instructions_section=user_instructions_section,
                )

            response = await llm.ainvoke([HumanMessage(content=prompt)])
            body = response.content

            if not body or not isinstance(body, str):
                error_msg = "LLM returned empty or invalid content"
                report_id = await _save_failed_report(error_msg)
                return {
                    "status": "failed",
                    "error": error_msg,
                    "report_id": report_id,
                    "title": "Resume",
                    "content_type": "typst",
                }

            body = _strip_typst_fences(body)
            body = _strip_imports(body)

            # ── Phase 3: ASSEMBLE + COMPILE ───────────────────────────────
            dispatch_custom_event(
                "report_progress",
                {"phase": "compiling", "message": "Compiling resume..."},
            )

            name = _extract_name(body) or "Resume"
            typst_source = ""
            actual_pages = 0
            compression_attempts = 0
            target_page_met = False

            # Round 0 compiles the initial draft; each later round compiles a
            # body the LLM was asked to shorten.
            for compression_round in range(MAX_COMPRESSION_ATTEMPTS + 1):
                header = _build_header(template, name)
                typst_source = header + body
                compile_error: str | None = None
                pdf_bytes: bytes | None = None

                # At most two compiles per round: the current source, then one
                # LLM-repaired version of it.
                for compile_attempt in range(2):
                    try:
                        pdf_bytes = _compile_typst(typst_source)
                        compile_error = None
                        break
                    except Exception as e:
                        compile_error = str(e)
                        logger.warning(
                            "[generate_resume] Compile attempt %s failed: %s",
                            compile_attempt + 1,
                            compile_error,
                        )

                        # Ask the LLM for a fix only after the first failure;
                        # a second failure is final for this call.
                        if compile_attempt == 0:
                            dispatch_custom_event(
                                "report_progress",
                                {
                                    "phase": "fixing",
                                    "message": "Fixing compilation issue...",
                                },
                            )
                            fix_prompt = _FIX_COMPILE_PROMPT.format(
                                llm_reference=llm_reference,
                                error=compile_error,
                                full_source=typst_source,
                            )
                            fix_response = await llm.ainvoke(
                                [HumanMessage(content=fix_prompt)]
                            )
                            # An unusable fix response leaves the source as is,
                            # so the retry compiles the same text again.
                            if fix_response.content and isinstance(
                                fix_response.content, str
                            ):
                                body = _strip_typst_fences(fix_response.content)
                                body = _strip_imports(body)
                                name = _extract_name(body) or name
                                header = _build_header(template, name)
                                typst_source = header + body

                if compile_error or not pdf_bytes:
                    error_msg = (
                        "Typst compilation failed after 2 attempts: "
                        f"{compile_error or 'Unknown compile error'}"
                    )
                    report_id = await _save_failed_report(error_msg)
                    return {
                        "status": "failed",
                        "error": error_msg,
                        "report_id": report_id,
                        "title": "Resume",
                        "content_type": "typst",
                    }

                actual_pages = _count_pdf_pages(pdf_bytes)
                if actual_pages <= validated_max_pages:
                    target_page_met = True
                    break

                # Out of compression passes: keep the over-length result and
                # let the hard-limit check below decide.
                if compression_round >= MAX_COMPRESSION_ATTEMPTS:
                    break

                compression_attempts += 1
                dispatch_custom_event(
                    "report_progress",
                    {
                        "phase": "compressing",
                        "message": f"Condensing resume to {validated_max_pages} page(s)...",
                    },
                )
                compress_prompt = _COMPRESS_TO_PAGE_LIMIT_PROMPT.format(
                    llm_reference=llm_reference,
                    max_pages=validated_max_pages,
                    actual_pages=actual_pages,
                    attempt_number=compression_attempts,
                    previous_content=body,
                )
                compress_response = await llm.ainvoke(
                    [HumanMessage(content=compress_prompt)]
                )
                if not compress_response.content or not isinstance(
                    compress_response.content, str
                ):
                    error_msg = "LLM returned empty content while compressing resume"
                    report_id = await _save_failed_report(error_msg)
                    return {
                        "status": "failed",
                        "error": error_msg,
                        "report_id": report_id,
                        "title": "Resume",
                        "content_type": "typst",
                    }

                body = _strip_typst_fences(compress_response.content)
                body = _strip_imports(body)
                name = _extract_name(body) or name

            # Missing the requested target is tolerated (reported in the
            # result message); exceeding the absolute cap is a failure.
            if actual_pages > MAX_RESUME_PAGES:
                error_msg = (
                    "Resume exceeds hard page limit after compression retries. "
                    f"Hard limit: <= {MAX_RESUME_PAGES} page(s), actual: {actual_pages}."
                )
                report_id = await _save_failed_report(error_msg)
                return {
                    "status": "failed",
                    "error": error_msg,
                    "report_id": report_id,
                    "title": "Resume",
                    "content_type": "typst",
                }

            # ── Phase 4: SAVE ─────────────────────────────────────────────
            dispatch_custom_event(
                "report_progress",
                {"phase": "saving", "message": "Saving your resume"},
            )

            resume_title = f"{name} - Resume" if name != "Resume" else "Resume"

            metadata: dict[str, Any] = {
                "status": "ready",
                "word_count": len(typst_source.split()),
                "char_count": len(typst_source),
                "target_max_pages": validated_max_pages,
                "actual_page_count": actual_pages,
                "page_limit_enforced": True,
                "compression_attempts": compression_attempts,
                "target_page_met": target_page_met,
            }

            async with shielded_async_session() as write_session:
                report = Report(
                    title=resume_title,
                    content=typst_source,
                    content_type="typst",
                    report_metadata=metadata,
                    report_style="resume",
                    search_space_id=search_space_id,
                    thread_id=thread_id,
                    report_group_id=report_group_id,
                )
                write_session.add(report)
                await write_session.commit()
                await write_session.refresh(report)

                # A new resume (no parent group) starts its own version group.
                if not report.report_group_id:
                    report.report_group_id = report.id
                    await write_session.commit()

                saved_id = report.id

            logger.info(f"[generate_resume] Created resume {saved_id}: {resume_title}")

            return {
                "status": "ready",
                "report_id": saved_id,
                "title": resume_title,
                "content_type": "typst",
                "is_revision": bool(parent_content),
                "message": (
                    f"Resume generated successfully: {resume_title}"
                    if target_page_met
                    else (
                        f"Resume generated, but could not fit the target of <= {validated_max_pages} "
                        f"page(s). Final length: {actual_pages} page(s)."
                    )
                ),
            }

        except Exception as e:
            # Catch-all boundary: any unexpected error is logged, recorded as
            # a failed report and returned as a "failed" payload.
            error_message = str(e)
            logger.exception(f"[generate_resume] Error: {error_message}")
            report_id = await _save_failed_report(error_message)
            return {
                "status": "failed",
                "error": error_message,
                "report_id": report_id,
                "title": "Resume",
                "content_type": "typst",
            }

    return generate_resume
|
|
||||||
|
|
@ -1,138 +0,0 @@
|
||||||
"""
|
|
||||||
Video presentation generation tool for the SurfSense agent.
|
|
||||||
|
|
||||||
This module provides a factory function for creating the generate_video_presentation
|
|
||||||
tool that submits a Celery task for background video presentation generation. The
|
|
||||||
tool then polls the row until it reaches a terminal status (READY/FAILED) and
|
|
||||||
returns that status. The wait is bounded by the chat's HTTP / process lifetime;
|
|
||||||
see app.agents.shared.deliverable_wait for details.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from langchain_core.tools import tool
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
|
|
||||||
from app.agents.shared.deliverable_wait import wait_for_deliverable
|
|
||||||
from app.db import VideoPresentation, VideoPresentationStatus, shielded_async_session
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def create_generate_video_presentation_tool(
    search_space_id: int,
    db_session: AsyncSession,
    thread_id: int | None = None,
):
    """
    Factory function to create the generate_video_presentation tool with injected dependencies.

    Pre-creates video presentation record with pending status so the ID is available
    immediately for frontend polling. The row is written via a fresh, tool-local
    session so parallel tool calls (e.g. video + podcast in the same agent step)
    don't share an ``AsyncSession`` (which is not concurrency-safe).

    Args:
        search_space_id: Stored on the created row and forwarded to the
            background task.
        db_session: Accepted but unused; it is discarded immediately.
        thread_id: Stored on the created row.

    Returns:
        The ``generate_video_presentation`` coroutine wrapped by ``@tool``.
    """
    del db_session  # writes use a fresh tool-local session, see below

    @tool
    async def generate_video_presentation(
        source_content: str,
        video_title: str = "SurfSense Presentation",
        user_prompt: str | None = None,
    ) -> dict[str, Any]:
        """Generate a video presentation from the provided content.

        Use this tool when the user asks to create a video, presentation, slides, or slide deck.

        Args:
            source_content: The text content to turn into a presentation.
            video_title: Title for the presentation (default: "SurfSense Presentation")
            user_prompt: Optional style/tone instructions.
        """
        try:
            # See podcast.py for the rationale: parallel tool calls share the
            # streaming session, and AsyncSession is not concurrency-safe —
            # interleaved flushes produce "Session.add() during flush" and
            # poison the transaction for every concurrent tool.
            async with shielded_async_session() as session:
                video_pres = VideoPresentation(
                    title=video_title,
                    status=VideoPresentationStatus.PENDING,
                    search_space_id=search_space_id,
                    thread_id=thread_id,
                )
                session.add(video_pres)
                await session.commit()
                await session.refresh(video_pres)
                video_pres_id = video_pres.id

            # Imported at call time rather than at module import
            # (presumably to avoid an import cycle — TODO confirm).
            from app.tasks.celery_tasks.video_presentation_tasks import (
                generate_video_presentation_task,
            )

            task = generate_video_presentation_task.delay(
                video_presentation_id=video_pres_id,
                source_content=source_content,
                search_space_id=search_space_id,
                user_prompt=user_prompt,
            )

            logger.info(
                "[generate_video_presentation] Created video presentation %s, task: %s",
                video_pres_id,
                task.id,
            )

            # Wait until the Celery worker flips the row to a terminal
            # state. No internal budget — see deliverable_wait module.
            terminal_status, _columns, elapsed = await wait_for_deliverable(
                model=VideoPresentation,
                row_id=video_pres_id,
                columns=[VideoPresentation.status],
                terminal_statuses={
                    VideoPresentationStatus.READY,
                    VideoPresentationStatus.FAILED,
                },
            )

            if terminal_status == VideoPresentationStatus.READY:
                logger.info(
                    "[generate_video_presentation] %s READY in %.2fs",
                    video_pres_id,
                    elapsed,
                )
                return {
                    "status": VideoPresentationStatus.READY.value,
                    "video_presentation_id": video_pres_id,
                    "title": video_title,
                    "message": "Video presentation generated and saved.",
                }

            # Only other terminal state is FAILED.
            logger.warning(
                "[generate_video_presentation] %s FAILED in %.2fs",
                video_pres_id,
                elapsed,
            )
            return {
                "status": VideoPresentationStatus.FAILED.value,
                "video_presentation_id": video_pres_id,
                "title": video_title,
                "error": (
                    "Background worker reported FAILED status for this "
                    "video presentation."
                ),
            }

        except Exception as e:
            # NOTE(review): this handler also covers failures after the row
            # was committed (task submission, waiting); the payload then
            # still reports video_presentation_id=None and the row's status
            # is not updated here.
            error_message = str(e)
            logger.exception("[generate_video_presentation] Error: %s", error_message)
            return {
                "status": VideoPresentationStatus.FAILED.value,
                "error": error_message,
                "title": video_title,
                "video_presentation_id": None,
            }

    return generate_video_presentation
|
|
||||||
|
|
@ -56,7 +56,7 @@ logger = logging.getLogger(__name__)
|
||||||
# class-body init time. ``app.agents.shared.llm_config`` re-exports
|
# class-body init time. ``app.agents.shared.llm_config`` re-exports
|
||||||
# this constant under the historical ``PROVIDER_MAP`` name; placing the
|
# this constant under the historical ``PROVIDER_MAP`` name; placing the
|
||||||
# map there directly would re-introduce the
|
# map there directly would re-introduce the
|
||||||
# ``app.config -> ... -> app.agents.shared.tools.generate_image ->
|
# ``app.config -> ... -> deliverables/tools/generate_image ->
|
||||||
# app.config`` cycle that prompted the move.
|
# app.config`` cycle that prompted the move.
|
||||||
_PROVIDER_PREFIX_MAP: dict[str, str] = {
|
_PROVIDER_PREFIX_MAP: dict[str, str] = {
|
||||||
"OPENAI": "openai",
|
"OPENAI": "openai",
|
||||||
|
|
|
||||||
|
|
@ -1,17 +1,58 @@
|
||||||
"""Unit tests for resume page-limit helpers and enforcement flow."""
|
"""Unit tests for resume page-limit helpers and enforcement flow.
|
||||||
|
|
||||||
|
Targets the live deliverables resume tool. The tool returns a
|
||||||
|
``Command`` (payload JSON-encoded in ``update["messages"][0].content``
|
||||||
|
plus a receipt), so flow tests invoke it via a ToolCall dict and unwrap
|
||||||
|
the payload.
|
||||||
|
"""
|
||||||
|
|
||||||
import io
|
import io
|
||||||
|
import json
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
from unittest.mock import AsyncMock
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
import pypdf
|
import pypdf
|
||||||
import pytest
|
import pytest
|
||||||
|
from langchain.tools import ToolRuntime
|
||||||
|
|
||||||
from app.agents.shared.tools import resume as resume_tool
|
from app.agents.multi_agent_chat.subagents.builtins.deliverables.tools import (
|
||||||
|
resume as resume_tool,
|
||||||
|
)
|
||||||
|
|
||||||
pytestmark = pytest.mark.unit
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _silence_progress_events(monkeypatch):
|
||||||
|
"""The live tool emits ``dispatch_custom_event`` progress updates that
|
||||||
|
require a langgraph run context; neutralize them for direct unit calls."""
|
||||||
|
monkeypatch.setattr(resume_tool, "dispatch_custom_event", lambda *a, **k: None)
|
||||||
|
|
||||||
|
|
||||||
|
def _runtime(tool_call_id: str = "call-1") -> ToolRuntime:
|
||||||
|
"""Minimal ToolRuntime; the resume tool only reads ``tool_call_id``."""
|
||||||
|
return ToolRuntime(
|
||||||
|
state={},
|
||||||
|
context=None,
|
||||||
|
config={},
|
||||||
|
stream_writer=None,
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
store=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _invoke(tool, args: dict) -> dict:
|
||||||
|
"""Drive a Command-returning tool and return its decoded payload.
|
||||||
|
|
||||||
|
These tools take an injected ``ToolRuntime`` and return a
|
||||||
|
``Command``; invoke the raw coroutine with a hand-built runtime
|
||||||
|
(the repo's pattern for unit-testing such tools) and decode the
|
||||||
|
ToolMessage payload.
|
||||||
|
"""
|
||||||
|
command = await tool.coroutine(runtime=_runtime(), **args)
|
||||||
|
return json.loads(command.update["messages"][0].content)
|
||||||
|
|
||||||
|
|
||||||
class _FakeReport:
|
class _FakeReport:
|
||||||
_next_id = 1000
|
_next_id = 1000
|
||||||
|
|
||||||
|
|
@ -108,7 +149,7 @@ async def test_generate_resume_defaults_to_one_page_target(monkeypatch) -> None:
|
||||||
monkeypatch.setattr(resume_tool, "_count_pdf_pages", lambda _pdf: 1)
|
monkeypatch.setattr(resume_tool, "_count_pdf_pages", lambda _pdf: 1)
|
||||||
|
|
||||||
tool = resume_tool.create_generate_resume_tool(search_space_id=1, thread_id=1)
|
tool = resume_tool.create_generate_resume_tool(search_space_id=1, thread_id=1)
|
||||||
result = await tool.ainvoke({"user_info": "Jane Doe experience"})
|
result = await _invoke(tool, {"user_info": "Jane Doe experience"})
|
||||||
|
|
||||||
assert result["status"] == "ready"
|
assert result["status"] == "ready"
|
||||||
assert prompts
|
assert prompts
|
||||||
|
|
@ -138,7 +179,7 @@ async def test_generate_resume_compresses_when_over_limit(monkeypatch) -> None:
|
||||||
monkeypatch.setattr(resume_tool, "_count_pdf_pages", lambda _pdf: next(page_counts))
|
monkeypatch.setattr(resume_tool, "_count_pdf_pages", lambda _pdf: next(page_counts))
|
||||||
|
|
||||||
tool = resume_tool.create_generate_resume_tool(search_space_id=1, thread_id=1)
|
tool = resume_tool.create_generate_resume_tool(search_space_id=1, thread_id=1)
|
||||||
result = await tool.ainvoke({"user_info": "Jane Doe experience", "max_pages": 1})
|
result = await _invoke(tool, {"user_info": "Jane Doe experience", "max_pages": 1})
|
||||||
|
|
||||||
assert result["status"] == "ready"
|
assert result["status"] == "ready"
|
||||||
assert write_session.added, "Expected successful report write"
|
assert write_session.added, "Expected successful report write"
|
||||||
|
|
@ -173,7 +214,7 @@ async def test_generate_resume_returns_ready_when_target_not_met(monkeypatch) ->
|
||||||
monkeypatch.setattr(resume_tool, "_count_pdf_pages", lambda _pdf: next(page_counts))
|
monkeypatch.setattr(resume_tool, "_count_pdf_pages", lambda _pdf: next(page_counts))
|
||||||
|
|
||||||
tool = resume_tool.create_generate_resume_tool(search_space_id=1, thread_id=1)
|
tool = resume_tool.create_generate_resume_tool(search_space_id=1, thread_id=1)
|
||||||
result = await tool.ainvoke({"user_info": "Jane Doe experience", "max_pages": 1})
|
result = await _invoke(tool, {"user_info": "Jane Doe experience", "max_pages": 1})
|
||||||
|
|
||||||
assert result["status"] == "ready"
|
assert result["status"] == "ready"
|
||||||
assert "could not fit the target" in (result["message"] or "").lower()
|
assert "could not fit the target" in (result["message"] or "").lower()
|
||||||
|
|
@ -206,7 +247,7 @@ async def test_generate_resume_fails_when_hard_limit_exceeded(monkeypatch) -> No
|
||||||
monkeypatch.setattr(resume_tool, "_count_pdf_pages", lambda _pdf: next(page_counts))
|
monkeypatch.setattr(resume_tool, "_count_pdf_pages", lambda _pdf: next(page_counts))
|
||||||
|
|
||||||
tool = resume_tool.create_generate_resume_tool(search_space_id=1, thread_id=1)
|
tool = resume_tool.create_generate_resume_tool(search_space_id=1, thread_id=1)
|
||||||
result = await tool.ainvoke({"user_info": "Jane Doe experience", "max_pages": 1})
|
result = await _invoke(tool, {"user_info": "Jane Doe experience", "max_pages": 1})
|
||||||
|
|
||||||
assert result["status"] == "failed"
|
assert result["status"] == "failed"
|
||||||
assert "hard page limit" in (result["error"] or "").lower()
|
assert "hard page limit" in (result["error"] or "").lower()
|
||||||
|
|
|
||||||
|
|
@ -20,6 +20,7 @@ from __future__ import annotations
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from langchain.tools import ToolRuntime
|
||||||
|
|
||||||
pytestmark = pytest.mark.unit
|
pytestmark = pytest.mark.unit
|
||||||
|
|
||||||
|
|
@ -90,7 +91,9 @@ async def test_global_openrouter_image_gen_sets_api_base_when_config_empty():
|
||||||
async def test_generate_image_tool_global_sets_api_base_when_config_empty():
|
async def test_generate_image_tool_global_sets_api_base_when_config_empty():
|
||||||
"""Same defense at the agent tool entry point — both surfaces share
|
"""Same defense at the agent tool entry point — both surfaces share
|
||||||
the same OpenRouter config payloads."""
|
the same OpenRouter config payloads."""
|
||||||
from app.agents.shared.tools import generate_image as gi_module
|
from app.agents.multi_agent_chat.subagents.builtins.deliverables.tools import (
|
||||||
|
generate_image as gi_module,
|
||||||
|
)
|
||||||
|
|
||||||
cfg = {
|
cfg = {
|
||||||
"id": -20_001,
|
"id": -20_001,
|
||||||
|
|
@ -150,7 +153,19 @@ async def test_generate_image_tool_global_sets_api_base_when_config_empty():
|
||||||
tool = gi_module.create_generate_image_tool(
|
tool = gi_module.create_generate_image_tool(
|
||||||
search_space_id=1, db_session=MagicMock()
|
search_space_id=1, db_session=MagicMock()
|
||||||
)
|
)
|
||||||
await tool.ainvoke({"prompt": "a cat", "n": 1})
|
# The live tool takes an injected ToolRuntime and returns a Command;
|
||||||
|
# drive the raw coroutine with a minimal runtime (the tool only reads
|
||||||
|
# ``tool_call_id``). We assert on what was forwarded to litellm, not
|
||||||
|
# on the return value.
|
||||||
|
runtime = ToolRuntime(
|
||||||
|
state={},
|
||||||
|
context=None,
|
||||||
|
config={},
|
||||||
|
stream_writer=None,
|
||||||
|
tool_call_id="call-1",
|
||||||
|
store=None,
|
||||||
|
)
|
||||||
|
await tool.coroutine(prompt="a cat", n=1, runtime=runtime)
|
||||||
|
|
||||||
assert captured.get("api_base") == "https://openrouter.ai/api/v1"
|
assert captured.get("api_base") == "https://openrouter.ai/api/v1"
|
||||||
assert captured["model"] == "openrouter/openai/gpt-image-1"
|
assert captured["model"] == "openrouter/openai/gpt-image-1"
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue