mirror of https://github.com/MODSetter/SurfSense.git
synced 2026-05-07 14:52:39 +02:00

Rename package: multi_agent_chat

This commit is contained in:
parent d675d4df3f
commit 972650909c

241 changed files with 114 additions and 114 deletions

@@ -0,0 +1,19 @@
"""Registry-backed subagent builders and helpers."""

from __future__ import annotations

from .registry import (
    SUBAGENT_BUILDERS_BY_NAME,
    SubagentBuilder,
    build_subagents,
    get_subagents_to_exclude,
    main_prompt_registry_subagent_lines,
)

__all__ = [
    "SUBAGENT_BUILDERS_BY_NAME",
    "SubagentBuilder",
    "build_subagents",
    "get_subagents_to_exclude",
    "main_prompt_registry_subagent_lines",
]

@@ -0,0 +1,55 @@
"""`deliverables` route: ``SubAgent`` spec for deepagents."""

from __future__ import annotations

from collections.abc import Sequence
from typing import Any

from deepagents import SubAgent
from langchain_core.language_models import BaseChatModel

from app.agents.multi_agent_chat.subagents.shared.md_file_reader import (
    read_md_file,
)
from app.agents.multi_agent_chat.subagents.shared.permissions import (
    ToolsPermissions,
    merge_tools_permissions,
    middleware_gated_interrupt_on,
)
from app.agents.multi_agent_chat.subagents.shared.subagent_builder import (
    pack_subagent,
)

from .tools.index import load_tools

NAME = "deliverables"


def build_subagent(
    *,
    dependencies: dict[str, Any],
    model: BaseChatModel | None = None,
    extra_middleware: Sequence[Any] | None = None,
    extra_tools_bucket: ToolsPermissions | None = None,
) -> SubAgent:
    buckets = load_tools(dependencies=dependencies)
    merged_tools_bucket = merge_tools_permissions(buckets, extra_tools_bucket)
    tools = [
        row["tool"]
        for row in (*merged_tools_bucket["allow"], *merged_tools_bucket["ask"])
        if row.get("tool") is not None
    ]
    interrupt_on = middleware_gated_interrupt_on(merged_tools_bucket)
    description = read_md_file(__package__, "description").strip()
    if not description:
        description = "Handles deliverables tasks for this workspace."
    system_prompt = read_md_file(__package__, "system_prompt").strip()
    return pack_subagent(
        name=NAME,
        description=description,
        system_prompt=system_prompt,
        tools=tools,
        interrupt_on=interrupt_on,
        model=model,
        extra_middleware=extra_middleware,
    )
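
A minimal usage sketch for this factory (hypothetical caller code: the IDs are placeholders, `session` stands in for an `AsyncSession` the app already holds, and the dependency keys mirror what `load_tools` in `tools/index.py` reads):

```python
# Hypothetical wiring; not part of this commit.
subagent = build_subagent(
    dependencies={
        "search_space_id": 1,   # placeholder ID
        "db_session": session,  # AsyncSession from the app factory
        "thread_id": 42,        # placeholder ID
    },
)
# `subagent` is a deepagents SubAgent spec the supervisor can register.
```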

@@ -0,0 +1 @@
Use for deliverables and shareable artifacts: generated reports, podcasts, video presentations, resumes, and images—not for routine lookups or single small edits elsewhere.

@@ -0,0 +1,55 @@
You are the SurfSense deliverables operations sub-agent.
You receive delegated instructions from a supervisor agent and return structured results for supervisor synthesis.

<goal>
Produce **deliverables**: shareable **artifacts** the user keeps (reports, slide-style video presentations, podcasts, resumes, images). Use explicit constraints and reliable proof of what was generated.
</goal>

<available_tools>
- `generate_report`
- `generate_podcast`
- `generate_video_presentation`
- `generate_resume`
- `generate_image`
</available_tools>

<tool_policy>
- Use only tools in `<available_tools>`.
- Require essential generation constraints (audience, format, tone, core content).
- If critical constraints are missing, return `status=blocked` with `missing_fields`.
- Never claim artifact generation success without tool confirmation.
</tool_policy>

<out_of_scope>
- Do not perform connector data mutations unrelated to artifact generation.
</out_of_scope>

<safety>
- Avoid generating artifacts with missing critical constraints.
- Prefer one complete artifact over partial multi-artifact output.
</safety>

<failure_policy>
- On generation failure, return `status=error` with best retry guidance.
- On missing constraints, return `status=blocked` with required fields.
</failure_policy>

<output_contract>
Return **only** one JSON object (no markdown/prose):
{
  "status": "success" | "partial" | "blocked" | "error",
  "action_summary": string,
  "evidence": {
    "artifact_type": "report" | "podcast" | "video_presentation" | "resume" | "image" | null,
    "artifact_id": string | null,
    "artifact_location": string | null
  },
  "next_step": string | null,
  "missing_fields": string[] | null,
  "assumptions": string[] | null
}
Rules:
- `status=success` -> `next_step=null`, `missing_fields=null`.
- `status=partial|blocked|error` -> `next_step` must be non-null.
- `status=blocked` due to missing required inputs -> `missing_fields` must be non-null.
</output_contract>
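
For illustration, a hypothetical response that satisfies this contract, shown as a Python literal with invented values:

```python
# A valid `status=success` payload under the contract above.
example_response = {
    "status": "success",
    "action_summary": "Generated a podcast from the provided meeting notes.",
    "evidence": {
        "artifact_type": "podcast",
        "artifact_id": "123",  # placeholder ID
        "artifact_location": None,
    },
    "next_step": None,         # must be null on success
    "missing_fields": None,    # must be null on success
    "assumptions": ["Default voice and length settings were acceptable."],
}
```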

@@ -0,0 +1,15 @@
"""Deliverable generators: reports, podcasts, video decks, resumes, images."""

from .generate_image import create_generate_image_tool
from .podcast import create_generate_podcast_tool
from .report import create_generate_report_tool
from .resume import create_generate_resume_tool
from .video_presentation import create_generate_video_presentation_tool

__all__ = [
    "create_generate_image_tool",
    "create_generate_podcast_tool",
    "create_generate_report_tool",
    "create_generate_resume_tool",
    "create_generate_video_presentation_tool",
]

@@ -0,0 +1,247 @@
"""Image generation via litellm; resolves model config from the search space and returns UI-ready payloads."""

import hashlib
import logging
from typing import Any

from langchain_core.tools import tool
from litellm import aimage_generation
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.config import config
from app.db import (
    ImageGeneration,
    ImageGenerationConfig,
    SearchSpace,
    shielded_async_session,
)
from app.services.image_gen_router_service import (
    IMAGE_GEN_AUTO_MODE_ID,
    ImageGenRouterService,
    is_image_gen_auto_mode,
)
from app.utils.signed_image_urls import generate_image_token

logger = logging.getLogger(__name__)

# Provider mapping (same as routes)
_PROVIDER_MAP = {
    "OPENAI": "openai",
    "AZURE_OPENAI": "azure",
    "GOOGLE": "gemini",
    "VERTEX_AI": "vertex_ai",
    "BEDROCK": "bedrock",
    "RECRAFT": "recraft",
    "OPENROUTER": "openrouter",
    "XINFERENCE": "xinference",
    "NSCALE": "nscale",
}


def _build_model_string(
    provider: str, model_name: str, custom_provider: str | None
) -> str:
    if custom_provider:
        return f"{custom_provider}/{model_name}"
    prefix = _PROVIDER_MAP.get(provider.upper(), provider.lower())
    return f"{prefix}/{model_name}"


def _get_global_image_gen_config(config_id: int) -> dict | None:
    """Get a global image gen config by negative ID."""
    for cfg in config.GLOBAL_IMAGE_GEN_CONFIGS:
        if cfg.get("id") == config_id:
            return cfg
    return None


def create_generate_image_tool(
    search_space_id: int,
    db_session: AsyncSession,
):
    """Create ``generate_image`` with bound search space; DB work uses a per-call session."""
    del db_session  # use a fresh per-call session, see below

    @tool
    async def generate_image(
        prompt: str,
        n: int = 1,
    ) -> dict[str, Any]:
        """
        Generate an image from a text description using AI image models.

        Use this tool when the user asks you to create, generate, draw, or make an image.
        The generated image will be displayed directly in the chat.

        Args:
            prompt: A detailed text description of the image to generate.
                Be specific about subject, style, colors, composition, and mood.
            n: Number of images to generate (1-4). Default: 1

        Returns:
            A dictionary containing the generated image(s) for display in the chat.
        """
        try:
            # Use a per-call session so concurrent tool calls don't share an
            # AsyncSession (which is not concurrency-safe). The streaming
            # task's session is shared across every tool; without isolation,
            # autoflushes from a concurrent writer poison this tool too.
            async with shielded_async_session() as session:
                result = await session.execute(
                    select(SearchSpace).filter(SearchSpace.id == search_space_id)
                )
                search_space = result.scalars().first()
                if not search_space:
                    return {"error": "Search space not found"}

                config_id = (
                    search_space.image_generation_config_id or IMAGE_GEN_AUTO_MODE_ID
                )

                # Build generation kwargs
                # NOTE: size, quality, and style are intentionally NOT passed.
                # Different models support different values for these params
                # (e.g. DALL-E 3 wants "hd"/"standard" for quality while
                # gpt-image-1 wants "high"/"medium"/"low"; size options also
                # differ). Letting the model use its own defaults avoids errors.
                gen_kwargs: dict[str, Any] = {}
                if n is not None and n > 1:
                    gen_kwargs["n"] = n

                # Call litellm based on config type
                if is_image_gen_auto_mode(config_id):
                    if not ImageGenRouterService.is_initialized():
                        return {
                            "error": "No image generation models configured. "
                            "Please add an image model in Settings > Image Models."
                        }
                    response = await ImageGenRouterService.aimage_generation(
                        prompt=prompt, model="auto", **gen_kwargs
                    )
                elif config_id < 0:
                    cfg = _get_global_image_gen_config(config_id)
                    if not cfg:
                        return {
                            "error": f"Image generation config {config_id} not found"
                        }

                    model_string = _build_model_string(
                        cfg.get("provider", ""),
                        cfg["model_name"],
                        cfg.get("custom_provider"),
                    )
                    gen_kwargs["api_key"] = cfg.get("api_key")
                    if cfg.get("api_base"):
                        gen_kwargs["api_base"] = cfg["api_base"]
                    if cfg.get("api_version"):
                        gen_kwargs["api_version"] = cfg["api_version"]
                    if cfg.get("litellm_params"):
                        gen_kwargs.update(cfg["litellm_params"])

                    response = await aimage_generation(
                        prompt=prompt, model=model_string, **gen_kwargs
                    )
                else:
                    # Positive ID = user-created ImageGenerationConfig
                    cfg_result = await session.execute(
                        select(ImageGenerationConfig).filter(
                            ImageGenerationConfig.id == config_id
                        )
                    )
                    db_cfg = cfg_result.scalars().first()
                    if not db_cfg:
                        return {
                            "error": f"Image generation config {config_id} not found"
                        }

                    model_string = _build_model_string(
                        db_cfg.provider.value,
                        db_cfg.model_name,
                        db_cfg.custom_provider,
                    )
                    gen_kwargs["api_key"] = db_cfg.api_key
                    if db_cfg.api_base:
                        gen_kwargs["api_base"] = db_cfg.api_base
                    if db_cfg.api_version:
                        gen_kwargs["api_version"] = db_cfg.api_version
                    if db_cfg.litellm_params:
                        gen_kwargs.update(db_cfg.litellm_params)

                    response = await aimage_generation(
                        prompt=prompt, model=model_string, **gen_kwargs
                    )

                # Parse the response and store in DB
                response_dict = (
                    response.model_dump()
                    if hasattr(response, "model_dump")
                    else dict(response)
                )

                # Generate a random access token for this image
                access_token = generate_image_token()

                # Save to image_generations table for history
                db_image_gen = ImageGeneration(
                    prompt=prompt,
                    model=getattr(response, "_hidden_params", {}).get("model"),
                    n=n,
                    image_generation_config_id=config_id,
                    response_data=response_dict,
                    search_space_id=search_space_id,
                    access_token=access_token,
                )
                session.add(db_image_gen)
                await session.commit()
                await session.refresh(db_image_gen)
                db_image_gen_id = db_image_gen.id

                # Extract image URLs from response
                images = response_dict.get("data", [])
                if not images:
                    return {"error": "No images were generated"}

                first_image = images[0]
                revised_prompt = first_image.get("revised_prompt", prompt)

                # Resolve image URL:
                # - If the API returned a URL, use it directly.
                # - If the API returned b64_json (e.g. gpt-image-1), serve the
                #   image through our backend endpoint to avoid bloating the
                #   LLM context with megabytes of base64 data.
                if first_image.get("url"):
                    image_url = first_image["url"]
                elif first_image.get("b64_json"):
                    backend_url = config.BACKEND_URL or "http://localhost:8000"
                    image_url = (
                        f"{backend_url}/api/v1/image-generations/"
                        f"{db_image_gen_id}/image?token={access_token}"
                    )
                else:
                    return {"error": "No displayable image data in the response"}

                image_id = f"image-{hashlib.md5(image_url.encode()).hexdigest()[:12]}"

                return {
                    "id": image_id,
                    "assetId": image_url,
                    "src": image_url,
                    "alt": revised_prompt or prompt,
                    "title": "Generated Image",
                    "description": revised_prompt if revised_prompt != prompt else None,
                    "domain": "ai-generated",
                    "ratio": "auto",
                    "generated": True,
                    "prompt": prompt,
                    "image_count": len(images),
                }

        except Exception as e:
            logger.exception("Image generation failed in tool")
            return {
                "error": f"Image generation failed: {e!s}",
                "prompt": prompt,
            }

    return generate_image
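
To make the provider mapping concrete, `_build_model_string` resolves litellm model strings as follows (the model names here are placeholders, not endorsements of what the configs contain):

```python
_build_model_string("OPENAI", "dall-e-3", None)        # -> "openai/dall-e-3"
_build_model_string("GOOGLE", "imagen-model", None)    # -> "gemini/imagen-model"
_build_model_string("UNKNOWN", "some-model", None)     # -> "unknown/some-model" (lowercase fallback)
_build_model_string("OPENAI", "my-model", "my-proxy")  # -> "my-proxy/my-model" (custom provider wins)
```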

@@ -0,0 +1,52 @@
from __future__ import annotations

from typing import Any

from app.agents.multi_agent_chat.subagents.shared.permissions import (
    ToolsPermissions,
)

from .generate_image import create_generate_image_tool
from .podcast import create_generate_podcast_tool
from .report import create_generate_report_tool
from .resume import create_generate_resume_tool
from .video_presentation import create_generate_video_presentation_tool


def load_tools(*, dependencies: dict[str, Any] | None = None, **kwargs: Any) -> ToolsPermissions:
    resolved_dependencies = {**(dependencies or {}), **kwargs}
    podcast = create_generate_podcast_tool(
        search_space_id=resolved_dependencies["search_space_id"],
        db_session=resolved_dependencies["db_session"],
        thread_id=resolved_dependencies["thread_id"],
    )
    video = create_generate_video_presentation_tool(
        search_space_id=resolved_dependencies["search_space_id"],
        db_session=resolved_dependencies["db_session"],
        thread_id=resolved_dependencies["thread_id"],
    )
    report = create_generate_report_tool(
        search_space_id=resolved_dependencies["search_space_id"],
        thread_id=resolved_dependencies["thread_id"],
        connector_service=resolved_dependencies.get("connector_service"),
        available_connectors=resolved_dependencies.get("available_connectors"),
        available_document_types=resolved_dependencies.get("available_document_types"),
    )
    resume = create_generate_resume_tool(
        search_space_id=resolved_dependencies["search_space_id"],
        thread_id=resolved_dependencies["thread_id"],
    )
    image = create_generate_image_tool(
        search_space_id=resolved_dependencies["search_space_id"],
        db_session=resolved_dependencies["db_session"],
    )
    return {
        "allow": [
            {"name": getattr(podcast, "name", "") or "", "tool": podcast},
            {"name": getattr(video, "name", "") or "", "tool": video},
            {"name": getattr(report, "name", "") or "", "tool": report},
            {"name": getattr(resume, "name", "") or "", "tool": resume},
            {"name": getattr(image, "name", "") or "", "tool": image},
        ],
        "ask": [],
    }
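
A sketch of how this permissions bucket might be consumed (hypothetical caller; every generator lands in `allow`, so no human-in-the-loop gate is requested for these tools):

```python
# Hypothetical caller; `session` is a placeholder AsyncSession.
buckets = load_tools(
    search_space_id=1,   # placeholder
    db_session=session,
    thread_id=42,        # placeholder
)
allowed_names = [row["name"] for row in buckets["allow"]]
# e.g. ["generate_podcast", "generate_video_presentation", ...]
```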

@@ -0,0 +1,92 @@
"""Factory for a podcast-generation tool that queues background work and returns an ID for polling."""

from typing import Any

from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession

from app.db import Podcast, PodcastStatus, shielded_async_session


def create_generate_podcast_tool(
    search_space_id: int,
    db_session: AsyncSession,
    thread_id: int | None = None,
):
    """Create ``generate_podcast`` with bound search space and thread; DB writes use a tool-local session."""
    del db_session  # writes use a fresh tool-local session, see below

    @tool
    async def generate_podcast(
        source_content: str,
        podcast_title: str = "SurfSense Podcast",
        user_prompt: str | None = None,
    ) -> dict[str, Any]:
        """
        Generate a podcast from the provided content.

        Use this tool when the user asks to create, generate, or make a podcast.
        Common triggers include phrases like:
        - "Give me a podcast about this"
        - "Create a podcast from this conversation"
        - "Generate a podcast summary"
        - "Make a podcast about..."
        - "Turn this into a podcast"

        Args:
            source_content: The text content to convert into a podcast.
            podcast_title: Title for the podcast (default: "SurfSense Podcast")
            user_prompt: Optional instructions for podcast style, tone, or format.

        Returns:
            A dictionary containing:
            - status: PodcastStatus value (pending, generating, or failed)
            - podcast_id: The podcast ID for polling (when status is pending or generating)
            - title: The podcast title
            - message: Status message (or "error" field if status is failed)
        """
        try:
            # One DB session per tool call so parallel invocations never share an AsyncSession.
            async with shielded_async_session() as session:
                podcast = Podcast(
                    title=podcast_title,
                    status=PodcastStatus.PENDING,
                    search_space_id=search_space_id,
                    thread_id=thread_id,
                )
                session.add(podcast)
                await session.commit()
                await session.refresh(podcast)
                podcast_id = podcast.id

            from app.tasks.celery_tasks.podcast_tasks import (
                generate_content_podcast_task,
            )

            task = generate_content_podcast_task.delay(
                podcast_id=podcast_id,
                source_content=source_content,
                search_space_id=search_space_id,
                user_prompt=user_prompt,
            )

            print(f"[generate_podcast] Created podcast {podcast_id}, task: {task.id}")

            return {
                "status": PodcastStatus.PENDING.value,
                "podcast_id": podcast_id,
                "title": podcast_title,
                "message": "Podcast generation started. This may take a few minutes.",
            }

        except Exception as e:
            error_message = str(e)
            print(f"[generate_podcast] Error: {error_message}")
            return {
                "status": PodcastStatus.FAILED.value,
                "error": error_message,
                "title": podcast_title,
                "podcast_id": None,
            }

    return generate_podcast
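
A hedged invocation sketch (LangChain tools built with `@tool` expose `ainvoke`; the IDs and content below are placeholders):

```python
generate_podcast = create_generate_podcast_tool(
    search_space_id=1,   # placeholder
    db_session=session,  # placeholder AsyncSession
    thread_id=42,        # placeholder
)
result = await generate_podcast.ainvoke(
    {"source_content": "Notes from today's standup...", "podcast_title": "Standup Recap"}
)
# result["podcast_id"] is then polled until the status moves past "pending".
```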

File diff suppressed because it is too large

@@ -0,0 +1,799 @@
"""Resume as Typst: LLM fills the body; backend prepends a template from ``_TEMPLATES`` and compiles."""

import io
import logging
import re
from datetime import UTC, datetime
from typing import Any

import pypdf
import typst
from langchain_core.callbacks import dispatch_custom_event
from langchain_core.messages import HumanMessage
from langchain_core.tools import tool

from app.db import Report, shielded_async_session
from app.services.llm_service import get_document_summary_llm

logger = logging.getLogger(__name__)


# ─── Template Registry ───────────────────────────────────────────────────────
# Each template defines:
#   header              - Typst import + show rule with {name}, {year}, {month}, {day} placeholders
#   component_reference - component docs shown to the LLM
#   rules               - generation rules for the LLM

_TEMPLATES: dict[str, dict[str, str]] = {
    "classic": {
        "header": """\
#import "@preview/rendercv:0.3.0": *

#show: rendercv.with(
  name: "{name}",
  title: "{name} - Resume",
  footer: context {{ [#emph[{name} -- #str(here().page())\\/#str(counter(page).final().first())]] }},
  top-note: [ #emph[Last updated in {month_name} {year}] ],
  locale-catalog-language: "en",
  text-direction: ltr,
  page-size: "us-letter",
  page-top-margin: 0.7in,
  page-bottom-margin: 0.7in,
  page-left-margin: 0.7in,
  page-right-margin: 0.7in,
  page-show-footer: false,
  page-show-top-note: true,
  colors-body: rgb(0, 0, 0),
  colors-name: rgb(0, 0, 0),
  colors-headline: rgb(0, 0, 0),
  colors-connections: rgb(0, 0, 0),
  colors-section-titles: rgb(0, 0, 0),
  colors-links: rgb(0, 0, 0),
  colors-footer: rgb(128, 128, 128),
  colors-top-note: rgb(128, 128, 128),
  typography-line-spacing: 0.6em,
  typography-alignment: "justified",
  typography-date-and-location-column-alignment: right,
  typography-font-family-body: "XCharter",
  typography-font-family-name: "XCharter",
  typography-font-family-headline: "XCharter",
  typography-font-family-connections: "XCharter",
  typography-font-family-section-titles: "XCharter",
  typography-font-size-body: 10pt,
  typography-font-size-name: 25pt,
  typography-font-size-headline: 10pt,
  typography-font-size-connections: 10pt,
  typography-font-size-section-titles: 1.2em,
  typography-small-caps-name: false,
  typography-small-caps-headline: false,
  typography-small-caps-connections: false,
  typography-small-caps-section-titles: false,
  typography-bold-name: false,
  typography-bold-headline: false,
  typography-bold-connections: false,
  typography-bold-section-titles: true,
  links-underline: true,
  links-show-external-link-icon: false,
  header-alignment: center,
  header-photo-width: 3.5cm,
  header-space-below-name: 0.7cm,
  header-space-below-headline: 0.7cm,
  header-space-below-connections: 0.7cm,
  header-connections-hyperlink: true,
  header-connections-show-icons: false,
  header-connections-display-urls-instead-of-usernames: true,
  header-connections-separator: "|",
  header-connections-space-between-connections: 0.5cm,
  section-titles-type: "with_full_line",
  section-titles-line-thickness: 0.5pt,
  section-titles-space-above: 0.5cm,
  section-titles-space-below: 0.3cm,
  sections-allow-page-break: true,
  sections-space-between-text-based-entries: 0.15cm,
  sections-space-between-regular-entries: 0.42cm,
  entries-date-and-location-width: 4.15cm,
  entries-side-space: 0cm,
  entries-space-between-columns: 0.1cm,
  entries-allow-page-break: false,
  entries-short-second-row: false,
  entries-degree-width: 1cm,
  entries-summary-space-left: 0cm,
  entries-summary-space-above: 0.08cm,
  entries-highlights-bullet: text(13pt, [\\u{2022}], baseline: -0.6pt),
  entries-highlights-nested-bullet: text(13pt, [\\u{2022}], baseline: -0.6pt),
  entries-highlights-space-left: 0cm,
  entries-highlights-space-above: 0.08cm,
  entries-highlights-space-between-items: 0.02cm,
  entries-highlights-space-between-bullet-and-text: 0.3em,
  date: datetime(
    year: {year},
    month: {month},
    day: {day},
  ),
)

""",
        "component_reference": """\
Available components (use ONLY these):

= Full Name        // Top-level heading — person's full name

#connections(      // Contact info row (pipe-separated)
  [City, Country],
  [#link("mailto:email@example.com", icon: false, if-underline: false, if-color: false)[email\\@example.com]],
  [#link("https://linkedin.com/in/user", icon: false, if-underline: false, if-color: false)[linkedin.com\\/in\\/user]],
  [#link("https://github.com/user", icon: false, if-underline: false, if-color: false)[github.com\\/user]],
)

== Section Title   // Section heading (arbitrary name)

#regular-entry(    // Work experience, projects, publications, etc.
  [
    #strong[Role/Title], Company Name -- Location
  ],
  [
    Start -- End
  ],
  main-column-second-row: [
    - Achievement or responsibility
    - Another bullet point
  ],
)

#education-entry(  // Education entries
  [
    #strong[Institution], Degree in Field -- Location
  ],
  [
    Start -- End
  ],
  main-column-second-row: [
    - GPA, honours, relevant coursework
  ],
)

#summary([Short paragraph summary])  // Optional summary inside an entry
#content-area([Free-form content])   // Freeform text block

For skills sections, use one bullet per category label:
- #strong[Category:] item1, item2, item3

For simple list sections (e.g. Honors), use plain bullet points:
- Item one
- Item two
""",
        "rules": """\
RULES:
- Do NOT include any #import or #show lines. Start directly with = Full Name.
- Output ONLY valid Typst content. No explanatory text before or after.
- Do NOT wrap output in ```typst code fences.
- The = heading MUST use the person's COMPLETE full name exactly as provided. NEVER shorten or abbreviate.
- Escape @ symbols inside link labels with a backslash: email\\@example.com
- Escape forward slashes in link display text: linkedin.com\\/in\\/user
- Every section MUST use == heading.
- Use #regular-entry() for experience, projects, publications, certifications, and similar entries.
- Use #education-entry() for education.
- For skills sections, use one bullet line per category with a bold label.
- Keep content professional, concise, and achievement-oriented.
- Use action verbs for bullet points (Led, Built, Designed, Reduced, etc.).
- This template works for ALL professions — adapt sections to the user's field.
- Default behavior should prioritize concise one-page content.
""",
    },
}

DEFAULT_TEMPLATE = "classic"
MIN_RESUME_PAGES = 1
MAX_RESUME_PAGES = 5
MAX_COMPRESSION_ATTEMPTS = 2


# ─── Template Helpers ─────────────────────────────────────────────────────────


def _get_template(template_id: str | None = None) -> dict[str, str]:
    """Get a template by ID, falling back to default."""
    return _TEMPLATES.get(template_id or DEFAULT_TEMPLATE, _TEMPLATES[DEFAULT_TEMPLATE])


_MONTH_NAMES = [
    "",
    "Jan",
    "Feb",
    "Mar",
    "Apr",
    "May",
    "Jun",
    "Jul",
    "Aug",
    "Sep",
    "Oct",
    "Nov",
    "Dec",
]


def _build_header(template: dict[str, str], name: str) -> str:
    """Build the template header with the person's name and current date."""
    now = datetime.now(tz=UTC)
    return (
        template["header"]
        .replace("{name}", name)
        .replace("{year}", str(now.year))
        .replace("{month}", str(now.month))
        .replace("{day}", str(now.day))
        .replace("{month_name}", _MONTH_NAMES[now.month])
    )


def _strip_header(full_source: str) -> str:
    """Strip the import + show rule from stored source to get the body only.

    Finds the closing parenthesis of the rendercv.with(...) block by tracking
    nesting depth, then returns everything after it.
    """
    show_match = re.search(r"#show:\s*rendercv\.with\(", full_source)
    if not show_match:
        return full_source

    start = show_match.end()
    depth = 1
    i = start
    while i < len(full_source) and depth > 0:
        if full_source[i] == "(":
            depth += 1
        elif full_source[i] == ")":
            depth -= 1
        i += 1

    return full_source[i:].lstrip("\n")


def _extract_name(body: str) -> str | None:
    """Extract the person's full name from the = heading in the body."""
    match = re.search(r"^=\s+(.+)$", body, re.MULTILINE)
    return match.group(1).strip() if match else None


def _strip_imports(body: str) -> str:
    """Remove any #import or #show lines the LLM might accidentally include."""
    lines = body.split("\n")
    cleaned: list[str] = []
    skip_show = False
    depth = 0

    for line in lines:
        stripped = line.strip()

        if stripped.startswith("#import"):
            continue

        if skip_show:
            depth += stripped.count("(") - stripped.count(")")
            if depth <= 0:
                skip_show = False
            continue

        if stripped.startswith("#show:") and "rendercv" in stripped:
            depth = stripped.count("(") - stripped.count(")")
            if depth > 0:
                skip_show = True
            continue

        cleaned.append(line)

    result = "\n".join(cleaned).strip()
    return result


def _build_llm_reference(template: dict[str, str]) -> str:
    """Build the LLM prompt reference from a template."""
    return f"""\
You MUST output valid Typst content for a resume.
Do NOT include any #import or #show lines — those are handled automatically.
Start directly with the = Full Name heading.

{template["component_reference"]}

{template["rules"]}"""


# ─── Prompts ─────────────────────────────────────────────────────────────────

_RESUME_PROMPT = """\
You are an expert resume writer. Generate professional resume content as Typst markup.

{llm_reference}

**User Information:**
{user_info}

**Target Maximum Pages:** {max_pages}

{user_instructions_section}

Generate the resume content now (starting with = Full Name):
"""

_REVISION_PROMPT = """\
You are an expert resume editor. Modify the existing resume according to the instructions.
Apply ONLY the requested changes — do NOT rewrite sections that are not affected.

{llm_reference}

**Target Maximum Pages:** {max_pages}

**Modification Instructions:** {user_instructions}

**EXISTING RESUME CONTENT:**

{previous_content}

---

Output the complete, updated resume content with the changes applied (starting with = Full Name):
"""

_FIX_COMPILE_PROMPT = """\
The resume content you generated failed to compile. Fix the error while preserving all content.

{llm_reference}

**Compilation Error:**
{error}

**Full Typst Source (for context — error line numbers refer to this):**
{full_source}

**Your content starts after the template header. Output ONLY the content portion \
(starting with = Full Name), NOT the #import or #show rule:**
"""

_COMPRESS_TO_PAGE_LIMIT_PROMPT = """\
The resume compiles, but it exceeds the maximum allowed page count.
Compress the resume while preserving high-impact accomplishments and role relevance.

{llm_reference}

**Target Maximum Pages:** {max_pages}
**Current Page Count:** {actual_pages}
**Compression Attempt:** {attempt_number}

Compression priorities (in this order):
1) Keep recent, high-impact, role-relevant bullets.
2) Remove low-impact or redundant bullets.
3) Shorten verbose wording while preserving meaning.
4) Trim older or less relevant details before recent ones.

Return the complete updated Typst content (starting with = Full Name), and keep it at or below the target pages.

**EXISTING RESUME CONTENT:**
{previous_content}
"""


# ─── Helpers ─────────────────────────────────────────────────────────────────


def _strip_typst_fences(text: str) -> str:
    """Remove wrapping ```typst ... ``` fences that LLMs sometimes add."""
    stripped = text.strip()
    m = re.match(r"^(`{3,})(?:typst|typ)?\s*\n", stripped)
    if m:
        fence = m.group(1)
        if stripped.endswith(fence):
            stripped = stripped[m.end() :]
            stripped = stripped[: -len(fence)].rstrip()
    return stripped


def _compile_typst(source: str) -> bytes:
    """Compile Typst source to PDF bytes. Raises on failure."""
    return typst.compile(source.encode("utf-8"))


def _count_pdf_pages(pdf_bytes: bytes) -> int:
    """Count the number of pages in compiled PDF bytes."""
    with io.BytesIO(pdf_bytes) as pdf_stream:
        reader = pypdf.PdfReader(pdf_stream)
        return len(reader.pages)


def _validate_max_pages(max_pages: int) -> int:
    """Validate and normalize max_pages input."""
    if MIN_RESUME_PAGES <= max_pages <= MAX_RESUME_PAGES:
        return max_pages
    msg = (
        f"max_pages must be between {MIN_RESUME_PAGES} and "
        f"{MAX_RESUME_PAGES}. Received: {max_pages}"
    )
    raise ValueError(msg)


# ─── Tool Factory ───────────────────────────────────────────────────────────


def create_generate_resume_tool(
    search_space_id: int,
    thread_id: int | None = None,
):
    """
    Factory function to create the generate_resume tool.

    Generates a Typst-based resume, validates it via compilation,
    and stores the source in the Report table with content_type='typst'.
    The LLM generates only the content body; the template header is
    prepended by the backend.
    """

    @tool
    async def generate_resume(
        user_info: str,
        user_instructions: str | None = None,
        parent_report_id: int | None = None,
        max_pages: int = 1,
    ) -> dict[str, Any]:
        """
        Generate a professional resume as a Typst document.

        Use this tool when the user asks to create, build, generate, write,
        or draft a resume or CV. Also use it when the user wants to modify,
        update, or revise an existing resume generated in this conversation.

        Trigger phrases include:
        - "build me a resume", "create my resume", "generate a CV"
        - "update my resume", "change my title", "add my new job"
        - "make my resume more concise", "reformat my resume"

        Do NOT use this tool for:
        - General questions about resumes or career advice
        - Reviewing or critiquing a resume without changes
        - Cover letters (use generate_report instead)

        VERSIONING — parent_report_id:
        - Set parent_report_id when the user wants to MODIFY an existing
          resume that was already generated in this conversation.
        - Leave as None for new resumes.

        Args:
            user_info: The user's resume content — work experience,
                education, skills, contact info, etc. Can be structured
                or unstructured text.
            user_instructions: Optional style or content preferences
                (e.g. "emphasize leadership", "keep it to one page",
                "use a modern style"). For revisions, describe what to change.
            parent_report_id: ID of a previous resume to revise (creates
                new version in the same version group).
            max_pages: Maximum number of pages for the generated resume.
                Defaults to 1. Allowed range: 1-5.

        Returns:
            Dict with status, report_id, title, and content_type.
        """
        report_group_id: int | None = None
        parent_content: str | None = None

        template = _get_template()
        llm_reference = _build_llm_reference(template)

        async def _save_failed_report(error_msg: str) -> int | None:
            try:
                async with shielded_async_session() as session:
                    failed = Report(
                        title="Resume",
                        content=None,
                        content_type="typst",
                        report_metadata={
                            "status": "failed",
                            "error_message": error_msg,
                        },
                        report_style="resume",
                        search_space_id=search_space_id,
                        thread_id=thread_id,
                        report_group_id=report_group_id,
                    )
                    session.add(failed)
                    await session.commit()
                    await session.refresh(failed)
                    if not failed.report_group_id:
                        failed.report_group_id = failed.id
                        await session.commit()
                    logger.info(
                        f"[generate_resume] Saved failed report {failed.id}: {error_msg}"
                    )
                    return failed.id
            except Exception:
                logger.exception(
                    "[generate_resume] Could not persist failed report row"
                )
                return None

        try:
            try:
                validated_max_pages = _validate_max_pages(max_pages)
            except ValueError as e:
                error_msg = str(e)
                report_id = await _save_failed_report(error_msg)
                return {
                    "status": "failed",
                    "error": error_msg,
                    "report_id": report_id,
                    "title": "Resume",
                    "content_type": "typst",
                }

            # ── Phase 1: READ ─────────────────────────────────────────────
            async with shielded_async_session() as read_session:
                if parent_report_id:
                    parent_report = await read_session.get(Report, parent_report_id)
                    if parent_report:
                        report_group_id = parent_report.report_group_id
                        parent_content = parent_report.content
                        logger.info(
                            f"[generate_resume] Revising from parent {parent_report_id} "
                            f"(group {report_group_id})"
                        )

                llm = await get_document_summary_llm(read_session, search_space_id)

            if not llm:
                error_msg = (
                    "No LLM configured. Please configure a language model in Settings."
                )
                report_id = await _save_failed_report(error_msg)
                return {
                    "status": "failed",
                    "error": error_msg,
                    "report_id": report_id,
                    "title": "Resume",
                    "content_type": "typst",
                }

            # ── Phase 2: LLM GENERATION ───────────────────────────────────

            user_instructions_section = ""
            if user_instructions:
                user_instructions_section = (
                    f"**Additional Instructions:** {user_instructions}"
                )

            if parent_content:
                dispatch_custom_event(
                    "report_progress",
                    {"phase": "writing", "message": "Updating your resume"},
                )
                parent_body = _strip_header(parent_content)
                prompt = _REVISION_PROMPT.format(
                    llm_reference=llm_reference,
                    max_pages=validated_max_pages,
                    user_instructions=user_instructions
                    or "Improve and refine the resume.",
                    previous_content=parent_body,
                )
            else:
                dispatch_custom_event(
                    "report_progress",
                    {"phase": "writing", "message": "Building your resume"},
                )
                prompt = _RESUME_PROMPT.format(
                    llm_reference=llm_reference,
                    user_info=user_info,
                    max_pages=validated_max_pages,
                    user_instructions_section=user_instructions_section,
                )

            response = await llm.ainvoke([HumanMessage(content=prompt)])
            body = response.content

            if not body or not isinstance(body, str):
                error_msg = "LLM returned empty or invalid content"
                report_id = await _save_failed_report(error_msg)
                return {
                    "status": "failed",
                    "error": error_msg,
                    "report_id": report_id,
                    "title": "Resume",
                    "content_type": "typst",
                }

            body = _strip_typst_fences(body)
            body = _strip_imports(body)

            # ── Phase 3: ASSEMBLE + COMPILE ───────────────────────────────
            dispatch_custom_event(
                "report_progress",
                {"phase": "compiling", "message": "Compiling resume..."},
            )

            name = _extract_name(body) or "Resume"
            typst_source = ""
            actual_pages = 0
            compression_attempts = 0
            target_page_met = False

            for compression_round in range(MAX_COMPRESSION_ATTEMPTS + 1):
                header = _build_header(template, name)
                typst_source = header + body
                compile_error: str | None = None
                pdf_bytes: bytes | None = None

                for compile_attempt in range(2):
                    try:
                        pdf_bytes = _compile_typst(typst_source)
                        compile_error = None
                        break
                    except Exception as e:
                        compile_error = str(e)
                        logger.warning(
                            "[generate_resume] Compile attempt %s failed: %s",
                            compile_attempt + 1,
                            compile_error,
                        )

                        if compile_attempt == 0:
                            dispatch_custom_event(
                                "report_progress",
                                {
                                    "phase": "fixing",
                                    "message": "Fixing compilation issue...",
                                },
                            )
                            fix_prompt = _FIX_COMPILE_PROMPT.format(
                                llm_reference=llm_reference,
                                error=compile_error,
                                full_source=typst_source,
                            )
                            fix_response = await llm.ainvoke(
                                [HumanMessage(content=fix_prompt)]
                            )
                            if fix_response.content and isinstance(
                                fix_response.content, str
                            ):
                                body = _strip_typst_fences(fix_response.content)
                                body = _strip_imports(body)
                                name = _extract_name(body) or name
                                header = _build_header(template, name)
                                typst_source = header + body

                if compile_error or not pdf_bytes:
                    error_msg = (
                        "Typst compilation failed after 2 attempts: "
                        f"{compile_error or 'Unknown compile error'}"
                    )
                    report_id = await _save_failed_report(error_msg)
                    return {
                        "status": "failed",
                        "error": error_msg,
                        "report_id": report_id,
                        "title": "Resume",
                        "content_type": "typst",
                    }

                actual_pages = _count_pdf_pages(pdf_bytes)
                if actual_pages <= validated_max_pages:
                    target_page_met = True
                    break

                if compression_round >= MAX_COMPRESSION_ATTEMPTS:
                    break

                compression_attempts += 1
                dispatch_custom_event(
                    "report_progress",
                    {
                        "phase": "compressing",
                        "message": f"Condensing resume to {validated_max_pages} page(s)...",
                    },
                )
                compress_prompt = _COMPRESS_TO_PAGE_LIMIT_PROMPT.format(
                    llm_reference=llm_reference,
                    max_pages=validated_max_pages,
                    actual_pages=actual_pages,
                    attempt_number=compression_attempts,
                    previous_content=body,
                )
                compress_response = await llm.ainvoke(
                    [HumanMessage(content=compress_prompt)]
                )
                if not compress_response.content or not isinstance(
                    compress_response.content, str
                ):
                    error_msg = "LLM returned empty content while compressing resume"
                    report_id = await _save_failed_report(error_msg)
                    return {
                        "status": "failed",
                        "error": error_msg,
                        "report_id": report_id,
                        "title": "Resume",
                        "content_type": "typst",
                    }

                body = _strip_typst_fences(compress_response.content)
                body = _strip_imports(body)
                name = _extract_name(body) or name

            if actual_pages > MAX_RESUME_PAGES:
                error_msg = (
                    "Resume exceeds hard page limit after compression retries. "
                    f"Hard limit: <= {MAX_RESUME_PAGES} page(s), actual: {actual_pages}."
                )
                report_id = await _save_failed_report(error_msg)
                return {
                    "status": "failed",
                    "error": error_msg,
                    "report_id": report_id,
                    "title": "Resume",
                    "content_type": "typst",
                }

            # ── Phase 4: SAVE ─────────────────────────────────────────────
            dispatch_custom_event(
                "report_progress",
                {"phase": "saving", "message": "Saving your resume"},
            )

            resume_title = f"{name} - Resume" if name != "Resume" else "Resume"

            metadata: dict[str, Any] = {
                "status": "ready",
                "word_count": len(typst_source.split()),
                "char_count": len(typst_source),
                "target_max_pages": validated_max_pages,
                "actual_page_count": actual_pages,
                "page_limit_enforced": True,
                "compression_attempts": compression_attempts,
                "target_page_met": target_page_met,
            }

            async with shielded_async_session() as write_session:
                report = Report(
                    title=resume_title,
                    content=typst_source,
                    content_type="typst",
                    report_metadata=metadata,
                    report_style="resume",
                    search_space_id=search_space_id,
                    thread_id=thread_id,
                    report_group_id=report_group_id,
                )
                write_session.add(report)
                await write_session.commit()
                await write_session.refresh(report)

                if not report.report_group_id:
                    report.report_group_id = report.id
                    await write_session.commit()

                saved_id = report.id

            logger.info(f"[generate_resume] Created resume {saved_id}: {resume_title}")

            return {
                "status": "ready",
                "report_id": saved_id,
                "title": resume_title,
                "content_type": "typst",
                "is_revision": bool(parent_content),
                "message": (
                    f"Resume generated successfully: {resume_title}"
                    if target_page_met
                    else (
                        f"Resume generated, but could not fit the target of <= {validated_max_pages} "
                        f"page(s). Final length: {actual_pages} page(s)."
                    )
                ),
            }

        except Exception as e:
            error_message = str(e)
            logger.exception(f"[generate_resume] Error: {error_message}")
            report_id = await _save_failed_report(error_message)
            return {
                "status": "failed",
                "error": error_message,
                "report_id": report_id,
                "title": "Resume",
                "content_type": "typst",
            }

    return generate_resume
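
A small demo of the header/body round-trip that `_strip_header` performs, assuming the stored source shape produced by `_build_header` (the name and source below are invented):

```python
stored = (
    '#import "@preview/rendercv:0.3.0": *\n'
    "\n"
    "#show: rendercv.with(\n"
    '  name: "Ada Lovelace",\n'
    ")\n"
    "\n"
    "= Ada Lovelace"
)
# Everything up to the show rule's closing paren is dropped, so only the
# LLM-authored body survives and can later be re-wrapped by _build_header().
assert _strip_header(stored) == "= Ada Lovelace"
```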

@@ -0,0 +1,80 @@
"""Factory for a video-presentation tool that queues background work and returns an ID for polling."""

from typing import Any

from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession

from app.db import VideoPresentation, VideoPresentationStatus, shielded_async_session


def create_generate_video_presentation_tool(
    search_space_id: int,
    db_session: AsyncSession,
    thread_id: int | None = None,
):
    """Create ``generate_video_presentation`` with bound search space and thread; writes use a tool-local session."""
    del db_session  # writes use a fresh tool-local session, see below

    @tool
    async def generate_video_presentation(
        source_content: str,
        video_title: str = "SurfSense Presentation",
        user_prompt: str | None = None,
    ) -> dict[str, Any]:
        """Generate a video presentation from the provided content.

        Use this tool when the user asks to create a video, presentation, slides, or slide deck.

        Args:
            source_content: The text content to turn into a presentation.
            video_title: Title for the presentation (default: "SurfSense Presentation")
            user_prompt: Optional style/tone instructions.
        """
        try:
            # One DB session per tool call so parallel invocations never share an AsyncSession.
            async with shielded_async_session() as session:
                video_pres = VideoPresentation(
                    title=video_title,
                    status=VideoPresentationStatus.PENDING,
                    search_space_id=search_space_id,
                    thread_id=thread_id,
                )
                session.add(video_pres)
                await session.commit()
                await session.refresh(video_pres)
                video_pres_id = video_pres.id

            from app.tasks.celery_tasks.video_presentation_tasks import (
                generate_video_presentation_task,
            )

            task = generate_video_presentation_task.delay(
                video_presentation_id=video_pres_id,
                source_content=source_content,
                search_space_id=search_space_id,
                user_prompt=user_prompt,
            )

            print(
                f"[generate_video_presentation] Created video presentation {video_pres_id}, task: {task.id}"
            )

            return {
                "status": VideoPresentationStatus.PENDING.value,
                "video_presentation_id": video_pres_id,
                "title": video_title,
                "message": "Video presentation generation started. This may take a few minutes.",
            }

        except Exception as e:
            error_message = str(e)
            print(f"[generate_video_presentation] Error: {error_message}")
            return {
                "status": VideoPresentationStatus.FAILED.value,
                "error": error_message,
                "title": video_title,
                "video_presentation_id": None,
            }

    return generate_video_presentation

@@ -0,0 +1,55 @@
"""`memory` route: ``SubAgent`` spec for deepagents."""

from __future__ import annotations

from collections.abc import Sequence
from typing import Any

from deepagents import SubAgent
from langchain_core.language_models import BaseChatModel

from app.agents.multi_agent_chat.subagents.shared.md_file_reader import (
    read_md_file,
)
from app.agents.multi_agent_chat.subagents.shared.permissions import (
    ToolsPermissions,
    merge_tools_permissions,
    middleware_gated_interrupt_on,
)
from app.agents.multi_agent_chat.subagents.shared.subagent_builder import (
    pack_subagent,
)

from .tools.index import load_tools

NAME = "memory"


def build_subagent(
    *,
    dependencies: dict[str, Any],
    model: BaseChatModel | None = None,
    extra_middleware: Sequence[Any] | None = None,
    extra_tools_bucket: ToolsPermissions | None = None,
) -> SubAgent:
    buckets = load_tools(dependencies=dependencies)
    merged_tools_bucket = merge_tools_permissions(buckets, extra_tools_bucket)
    tools = [
        row["tool"]
        for row in (*merged_tools_bucket["allow"], *merged_tools_bucket["ask"])
        if row.get("tool") is not None
    ]
    interrupt_on = middleware_gated_interrupt_on(merged_tools_bucket)
    description = read_md_file(__package__, "description").strip()
    if not description:
        description = "Handles memory tasks for this workspace."
    system_prompt = read_md_file(__package__, "system_prompt").strip()
    return pack_subagent(
        name=NAME,
        description=description,
        system_prompt=system_prompt,
        tools=tools,
        interrupt_on=interrupt_on,
        model=model,
        extra_middleware=extra_middleware,
    )

@@ -0,0 +1 @@
Use for storing durable user memory (private or team variant selected at runtime).

@@ -0,0 +1,56 @@
You are the SurfSense memory operations sub-agent.
You receive delegated instructions from a supervisor agent and return structured results for supervisor synthesis.

<goal>
Persist durable preferences/facts/instructions with `update_memory` while avoiding transient or unsafe storage.
</goal>

<visibility_scope>
{{MEMORY_VISIBILITY_POLICY}}
</visibility_scope>

<available_tools>
- `update_memory`
</available_tools>

<tool_policy>
- Save only durable information with future value.
- Do not store transient chatter.
- Do not store secrets unless explicitly instructed.
- If memory intent is unclear, return `status=blocked` with the missing intent signal.
</tool_policy>

<out_of_scope>
- Do not execute non-memory tool actions.
- Do not store irrelevant, transient, or speculative information.
</out_of_scope>

<safety>
- Prefer minimal memory writes; avoid over-collection.
- Never claim memory was updated unless `update_memory` succeeded.
</safety>

<failure_policy>
- On tool failure, return `status=error` with concise recovery steps.
- When intent is ambiguous, return `status=blocked` with required disambiguation fields.
</failure_policy>

<output_contract>
Return **only** one JSON object (no markdown/prose):
{
  "status": "success" | "partial" | "blocked" | "error",
  "action_summary": string,
  "evidence": {
    "memory_updated": boolean,
    "memory_category": "preference" | "fact" | "instruction" | null,
    "stored_summary": string | null
  },
  "next_step": string | null,
  "missing_fields": string[] | null,
  "assumptions": string[] | null
}
Rules:
- `status=success` -> `next_step=null`, `missing_fields=null`.
- `status=partial|blocked|error` -> `next_step` must be non-null.
- `status=blocked` due to missing required inputs -> `missing_fields` must be non-null.
</output_contract>
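
A hypothetical `status=blocked` payload under this contract (Python literal, invented values), showing the non-null `next_step`/`missing_fields` rules in action:

```python
blocked_response = {
    "status": "blocked",
    "action_summary": "Could not tell whether the note is a durable preference.",
    "evidence": {
        "memory_updated": False,
        "memory_category": None,
        "stored_summary": None,
    },
    "next_step": "Ask the user to confirm this should be remembered long-term.",
    "missing_fields": ["memory_intent"],   # required because status is blocked
    "assumptions": None,
}
```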

@@ -0,0 +1,8 @@
"""Memory tools: persist user or team markdown memory for later turns."""

from .update_memory import create_update_memory_tool, create_update_team_memory_tool

__all__ = [
    "create_update_memory_tool",
    "create_update_team_memory_tool",
]

@@ -0,0 +1,27 @@
from __future__ import annotations

from typing import Any

from app.agents.multi_agent_chat.subagents.shared.permissions import (
    ToolsPermissions,
)
from app.db import ChatVisibility

from .update_memory import create_update_memory_tool, create_update_team_memory_tool


def load_tools(*, dependencies: dict[str, Any] | None = None, **kwargs: Any) -> ToolsPermissions:
    resolved_dependencies = {**(dependencies or {}), **kwargs}
    if resolved_dependencies.get("thread_visibility") == ChatVisibility.SEARCH_SPACE:
        mem = create_update_team_memory_tool(
            search_space_id=resolved_dependencies["search_space_id"],
            db_session=resolved_dependencies["db_session"],
            llm=resolved_dependencies.get("llm"),
        )
        return {"allow": [{"name": getattr(mem, "name", "") or "", "tool": mem}], "ask": []}
    mem = create_update_memory_tool(
        user_id=resolved_dependencies["user_id"],
        db_session=resolved_dependencies["db_session"],
        llm=resolved_dependencies.get("llm"),
    )
    return {"allow": [{"name": getattr(mem, "name", "") or "", "tool": mem}], "ask": []}
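
A small sketch of the visibility routing (hypothetical caller; the session and IDs are placeholders):

```python
# Team-visible threads get the shared team memory tool...
team_bucket = load_tools(
    thread_visibility=ChatVisibility.SEARCH_SPACE,
    search_space_id=1,    # placeholder
    db_session=session,   # placeholder AsyncSession
)
# ...anything else falls back to the private per-user tool.
user_bucket = load_tools(
    user_id=user_id,      # placeholder UUID
    db_session=session,
)
```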

@@ -0,0 +1,375 @@
"""Overwrite one markdown memory document per user or team, with size and shrink guards."""

from __future__ import annotations

import logging
import re
from typing import Any, Literal
from uuid import UUID

from langchain_core.messages import HumanMessage
from langchain_core.tools import tool
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.db import SearchSpace, User

logger = logging.getLogger(__name__)

MEMORY_SOFT_LIMIT = 18_000
MEMORY_HARD_LIMIT = 25_000

_SECTION_HEADING_RE = re.compile(r"^##\s+(.+)$", re.MULTILINE)
_HEADING_NORMALIZE_RE = re.compile(r"\s+")

_MARKER_RE = re.compile(r"\[(fact|pref|instr)\]")
_BULLET_FORMAT_RE = re.compile(r"^- \(\d{4}-\d{2}-\d{2}\) \[(fact|pref|instr)\] .+$")
_PERSONAL_ONLY_MARKERS = {"pref", "instr"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Diff validation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _extract_headings(memory: str) -> set[str]:
|
||||
"""Return all ``## …`` heading texts (without the ``## `` prefix)."""
|
||||
return set(_SECTION_HEADING_RE.findall(memory))
|
||||
|
||||
|
||||
def _normalize_heading(heading: str) -> str:
|
||||
"""Normalize heading text for robust scope checks."""
|
||||
return _HEADING_NORMALIZE_RE.sub(" ", heading.strip().lower())
|
||||
|
||||
|
||||
def _validate_memory_scope(
|
||||
content: str, scope: Literal["user", "team"]
|
||||
) -> dict[str, Any] | None:
|
||||
"""Reject personal-only markers ([pref], [instr]) in team memory."""
|
||||
if scope != "team":
|
||||
return None
|
||||
|
||||
markers = set(_MARKER_RE.findall(content))
|
||||
leaked = sorted(markers & _PERSONAL_ONLY_MARKERS)
|
||||
if leaked:
|
||||
tags = ", ".join(f"[{m}]" for m in leaked)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": (
|
||||
f"Team memory cannot include personal markers: {tags}. "
|
||||
"Use [fact] only in team memory."
|
||||
),
|
||||
}
|
||||
return None
|
||||
|
||||
|
||||
def _validate_bullet_format(content: str) -> list[str]:
|
||||
"""Return warnings for bullet lines that don't match the required format.
|
||||
|
||||
Expected: ``- (YYYY-MM-DD) [fact|pref|instr] text``
|
||||
"""
|
||||
warnings: list[str] = []
|
||||
for line in content.splitlines():
|
||||
stripped = line.strip()
|
||||
if not stripped.startswith("- "):
|
||||
continue
|
||||
if not _BULLET_FORMAT_RE.match(stripped):
|
||||
short = stripped[:80] + ("..." if len(stripped) > 80 else "")
|
||||
warnings.append(f"Malformed bullet: {short}")
|
||||
return warnings
|
||||
|
||||
|
||||
def _validate_diff(old_memory: str | None, new_memory: str) -> list[str]:
|
||||
"""Return a list of warning strings about suspicious changes."""
|
||||
if not old_memory:
|
||||
return []
|
||||
|
||||
warnings: list[str] = []
|
||||
old_headings = _extract_headings(old_memory)
|
||||
new_headings = _extract_headings(new_memory)
|
||||
dropped = old_headings - new_headings
|
||||
if dropped:
|
||||
names = ", ".join(sorted(dropped))
|
||||
warnings.append(
|
||||
f"Sections removed: {names}. "
|
||||
"If unintentional, the user can restore from the settings page."
|
||||
)
|
||||
|
||||
old_len = len(old_memory)
|
||||
new_len = len(new_memory)
|
||||
if old_len > 0 and new_len < old_len * 0.4:
|
||||
warnings.append(
|
||||
f"Memory shrank significantly ({old_len:,} -> {new_len:,} chars). "
|
||||
"Possible data loss."
|
||||
)
|
||||
return warnings
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Size validation & soft warning
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _validate_memory_size(content: str) -> dict[str, Any] | None:
|
||||
"""Return an error/warning dict if *content* is too large, else None."""
|
||||
length = len(content)
|
||||
if length > MEMORY_HARD_LIMIT:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": (
|
||||
f"Memory exceeds {MEMORY_HARD_LIMIT:,} character limit "
|
||||
f"({length:,} chars). Consolidate by merging related items, "
|
||||
"removing outdated entries, and shortening descriptions. "
|
||||
"Then call update_memory again."
|
||||
),
|
||||
}
|
||||
return None
|
||||
|
||||
|
||||
def _soft_warning(content: str) -> str | None:
|
||||
"""Return a warning string if content exceeds the soft limit."""
|
||||
length = len(content)
|
||||
if length > MEMORY_SOFT_LIMIT:
|
||||
return (
|
||||
f"Memory is at {length:,}/{MEMORY_HARD_LIMIT:,} characters. "
|
||||
"Consolidate by merging related items and removing less important "
|
||||
"entries on your next update."
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Forced rewrite when memory exceeds the hard limit
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_FORCED_REWRITE_PROMPT = """\
|
||||
You are a memory curator. The following memory document exceeds the character \
|
||||
limit and must be shortened.
|
||||
|
||||
RULES:
|
||||
1. Rewrite the document to be under {target} characters.
|
||||
2. Preserve existing ## headings. Every entry must remain under a heading. You may merge
|
||||
or rename headings to consolidate, but keep names personal and descriptive.
|
||||
3. Priority for keeping content: [instr] > [pref] > [fact].
|
||||
4. Merge duplicate entries, remove outdated entries, shorten verbose descriptions.
|
||||
5. Every bullet MUST have format: - (YYYY-MM-DD) [fact|pref|instr] text
|
||||
6. Preserve the user's first name in entries — do not replace it with "the user".
|
||||
7. Output ONLY the consolidated markdown — no explanations, no wrapping.
|
||||
|
||||
<memory_document>
|
||||
{content}
|
||||
</memory_document>"""
|
||||
|
||||
|
||||
async def _forced_rewrite(content: str, llm: Any) -> str | None:
|
||||
"""Use a focused LLM call to compress *content* under the hard limit.
|
||||
|
||||
Returns the rewritten string, or ``None`` if the call fails.
|
||||
"""
|
||||
try:
|
||||
prompt = _FORCED_REWRITE_PROMPT.format(
|
||||
target=MEMORY_HARD_LIMIT, content=content
|
||||
)
|
||||
response = await llm.ainvoke(
|
||||
[HumanMessage(content=prompt)],
|
||||
config={"tags": ["surfsense:internal"]},
|
||||
)
|
||||
text = (
|
||||
response.content
|
||||
if isinstance(response.content, str)
|
||||
else str(response.content)
|
||||
)
|
||||
return text.strip()
|
||||
except Exception:
|
||||
logger.exception("Forced rewrite LLM call failed")
|
||||
return None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shared save-and-respond logic
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _save_memory(
|
||||
*,
|
||||
updated_memory: str,
|
||||
old_memory: str | None,
|
||||
llm: Any | None,
|
||||
apply_fn,
|
||||
commit_fn,
|
||||
rollback_fn,
|
||||
label: str,
|
||||
scope: Literal["user", "team"],
|
||||
) -> dict[str, Any]:
|
||||
"""Validate, optionally force-rewrite if over the hard limit, save, and
|
||||
return a response dict.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
updated_memory : str
|
||||
The new document the agent submitted.
|
||||
old_memory : str | None
|
||||
The previously persisted document (for diff checks).
|
||||
llm : Any | None
|
||||
LLM instance for forced rewrite (may be ``None``).
|
||||
apply_fn : callable(str) -> None
|
||||
Callback that sets the new memory on the ORM object.
|
||||
commit_fn : coroutine
|
||||
``session.commit``.
|
||||
rollback_fn : coroutine
|
||||
``session.rollback``.
|
||||
label : str
|
||||
Human label for log messages (e.g. "user memory", "team memory").
|
||||
"""
|
||||
content = updated_memory
|
||||
|
||||
# --- forced rewrite if over the hard limit ---
|
||||
if len(content) > MEMORY_HARD_LIMIT and llm is not None:
|
||||
rewritten = await _forced_rewrite(content, llm)
|
||||
if rewritten is not None and len(rewritten) < len(content):
|
||||
content = rewritten
|
||||
|
||||
# --- hard-limit gate (reject if still too large after rewrite) ---
|
||||
size_err = _validate_memory_size(content)
|
||||
if size_err:
|
||||
return size_err
|
||||
|
||||
scope_err = _validate_memory_scope(content, scope)
|
||||
if scope_err:
|
||||
return scope_err
|
||||
|
||||
# --- persist ---
|
||||
try:
|
||||
apply_fn(content)
|
||||
await commit_fn()
|
||||
except Exception as e:
|
||||
logger.exception("Failed to update %s: %s", label, e)
|
||||
await rollback_fn()
|
||||
return {"status": "error", "message": f"Failed to update {label}: {e}"}
|
||||
|
||||
# --- build response ---
|
||||
resp: dict[str, Any] = {
|
||||
"status": "saved",
|
||||
"message": f"{label.capitalize()} updated.",
|
||||
}
|
||||
|
||||
if content is not updated_memory:
|
||||
resp["notice"] = "Memory was automatically rewritten to fit within limits."
|
||||
|
||||
diff_warnings = _validate_diff(old_memory, content)
|
||||
if diff_warnings:
|
||||
resp["diff_warnings"] = diff_warnings
|
||||
|
||||
format_warnings = _validate_bullet_format(content)
|
||||
if format_warnings:
|
||||
resp["format_warnings"] = format_warnings
|
||||
|
||||
warning = _soft_warning(content)
|
||||
if warning:
|
||||
resp["warning"] = warning
|
||||
|
||||
return resp
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tool factories
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def create_update_memory_tool(
|
||||
user_id: str | UUID,
|
||||
db_session: AsyncSession,
|
||||
llm: Any | None = None,
|
||||
):
|
||||
uid = UUID(user_id) if isinstance(user_id, str) else user_id
|
||||
|
||||
@tool
|
||||
async def update_memory(updated_memory: str) -> dict[str, Any]:
|
||||
"""Update the user's personal memory document.
|
||||
|
||||
Your current memory is shown in <user_memory> in the system prompt.
|
||||
When the user shares important long-term information (preferences,
|
||||
facts, instructions, context), rewrite the memory document to include
|
||||
the new information. Merge new facts with existing ones, update
|
||||
contradictions, remove outdated entries, and keep it concise.
|
||||
|
||||
Args:
|
||||
updated_memory: The FULL updated markdown document (not a diff).
|
||||
"""
|
||||
try:
|
||||
result = await db_session.execute(select(User).where(User.id == uid))
|
||||
user = result.scalars().first()
|
||||
if not user:
|
||||
return {"status": "error", "message": "User not found."}
|
||||
|
||||
old_memory = user.memory_md
|
||||
|
||||
return await _save_memory(
|
||||
updated_memory=updated_memory,
|
||||
old_memory=old_memory,
|
||||
llm=llm,
|
||||
apply_fn=lambda content: setattr(user, "memory_md", content),
|
||||
commit_fn=db_session.commit,
|
||||
rollback_fn=db_session.rollback,
|
||||
label="memory",
|
||||
scope="user",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("Failed to update user memory: %s", e)
|
||||
await db_session.rollback()
|
||||
return {
|
||||
"status": "error",
|
||||
"message": f"Failed to update memory: {e}",
|
||||
}
|
||||
|
||||
return update_memory
|
||||
|
||||
|
||||
def create_update_team_memory_tool(
|
||||
search_space_id: int,
|
||||
db_session: AsyncSession,
|
||||
llm: Any | None = None,
|
||||
):
|
||||
@tool
|
||||
async def update_memory(updated_memory: str) -> dict[str, Any]:
|
||||
"""Update the team's shared memory document for this search space.
|
||||
|
||||
Your current team memory is shown in <team_memory> in the system
|
||||
prompt. When the team shares important long-term information
|
||||
(decisions, conventions, key facts, priorities), rewrite the memory
|
||||
document to include the new information. Merge new facts with
|
||||
existing ones, update contradictions, remove outdated entries, and
|
||||
keep it concise.
|
||||
|
||||
Args:
|
||||
updated_memory: The FULL updated markdown document (not a diff).
|
||||
"""
|
||||
try:
|
||||
result = await db_session.execute(
|
||||
select(SearchSpace).where(SearchSpace.id == search_space_id)
|
||||
)
|
||||
space = result.scalars().first()
|
||||
if not space:
|
||||
return {"status": "error", "message": "Search space not found."}
|
||||
|
||||
old_memory = space.shared_memory_md
|
||||
|
||||
return await _save_memory(
|
||||
updated_memory=updated_memory,
|
||||
old_memory=old_memory,
|
||||
llm=llm,
|
||||
apply_fn=lambda content: setattr(space, "shared_memory_md", content),
|
||||
commit_fn=db_session.commit,
|
||||
rollback_fn=db_session.rollback,
|
||||
label="team memory",
|
||||
scope="team",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("Failed to update team memory: %s", e)
|
||||
await db_session.rollback()
|
||||
return {
|
||||
"status": "error",
|
||||
"message": f"Failed to update team memory: {e}",
|
||||
}
|
||||
|
||||
return update_memory
|
||||
|
|
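A small sketch exercising the pure validation helpers defined above (the sample documents are made up; the assertions follow directly from the regexes and limits):

```python
# Illustrative only: a well-formed document passes every guard.
doc = "## Preferences\n- (2026-01-15) [pref] Prefers concise answers.\n"
assert _validate_memory_size(doc) is None      # far below the 25,000-char hard limit
assert _soft_warning(doc) is None              # and below the 18,000-char soft limit
assert _validate_bullet_format(doc) == []      # matches "- (YYYY-MM-DD) [tag] text"
assert _validate_memory_scope(doc, "user") is None

# A bullet missing its date fails the format check with a warning...
bad = "## Preferences\n- [pref] Prefers concise answers.\n"
assert _validate_bullet_format(bad) == ["Malformed bullet: - [pref] Prefers concise answers."]

# ...and a [pref] marker is rejected outright in team scope.
assert _validate_memory_scope(doc, "team")["status"] == "error"
```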
@@ -0,0 +1,55 @@
"""`research` route: ``SubAgent`` spec for deepagents."""

from __future__ import annotations

from collections.abc import Sequence
from typing import Any

from deepagents import SubAgent
from langchain_core.language_models import BaseChatModel

from app.agents.multi_agent_chat.subagents.shared.md_file_reader import (
    read_md_file,
)
from app.agents.multi_agent_chat.subagents.shared.permissions import (
    ToolsPermissions,
    merge_tools_permissions,
    middleware_gated_interrupt_on,
)
from app.agents.multi_agent_chat.subagents.shared.subagent_builder import (
    pack_subagent,
)

from .tools.index import load_tools

NAME = "research"


def build_subagent(
    *,
    dependencies: dict[str, Any],
    model: BaseChatModel | None = None,
    extra_middleware: Sequence[Any] | None = None,
    extra_tools_bucket: ToolsPermissions | None = None,
) -> SubAgent:
    buckets = load_tools(dependencies=dependencies)
    merged_tools_bucket = merge_tools_permissions(buckets, extra_tools_bucket)
    tools = [
        row["tool"]
        for row in (*merged_tools_bucket["allow"], *merged_tools_bucket["ask"])
        if row.get("tool") is not None
    ]
    interrupt_on = middleware_gated_interrupt_on(merged_tools_bucket)
    description = read_md_file(__package__, "description").strip()
    if not description:
        description = "Handles research tasks for this workspace."
    system_prompt = read_md_file(__package__, "system_prompt").strip()
    return pack_subagent(
        name=NAME,
        description=description,
        system_prompt=system_prompt,
        tools=tools,
        interrupt_on=interrupt_on,
        model=model,
        extra_middleware=extra_middleware,
    )
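For orientation, a hedged sketch of how the registry might call this builder; the dependency keys match what the research tool loader below reads, and the concrete values are placeholders from the caller's request context:

```python
# Hypothetical call site (the real registry wiring lives elsewhere).
subagent = build_subagent(
    dependencies={
        "db_session": db_session,  # AsyncSession, used by search_surfsense_docs
        "search_space_id": 7,  # placeholder workspace id
        "available_connectors": ["TAVILY_API"],  # optional live-search engines
        "firecrawl_api_key": None,  # optional premium scraping key
    },
)
```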
@@ -0,0 +1 @@
Use for external research: find sources on the web, extract evidence, and answer documentation questions.
@@ -0,0 +1,53 @@
You are the SurfSense research operations sub-agent.
You receive delegated instructions from a supervisor agent and return structured results for supervisor synthesis.

<goal>
Gather and synthesize evidence using SurfSense research tools with clear citations and uncertainty reporting.
</goal>

<available_tools>
- `web_search`
- `scrape_webpage`
- `search_surfsense_docs`
</available_tools>

<tool_policy>
- Use only tools in `<available_tools>`.
- Prefer primary and recent sources when recency matters.
- If the delegated request is underspecified, return `status=blocked` with the missing research constraints.
- Never fabricate facts, citations, URLs, or quoted text.
</tool_policy>

<out_of_scope>
- Do not execute connector mutations (email/calendar/docs/chat writes) or deliverable generation.
</out_of_scope>

<safety>
- Report uncertainty explicitly when evidence is incomplete or conflicting.
- Never present unverified claims as facts.
</safety>

<failure_policy>
- On tool failure, return `status=error` with a concise recovery `next_step`.
- On no useful evidence, return `status=blocked` with recommended narrower filters.
</failure_policy>

<output_contract>
Return **only** one JSON object (no markdown/prose):
{
  "status": "success" | "partial" | "blocked" | "error",
  "action_summary": string,
  "evidence": {
    "findings": string[],
    "sources": string[],
    "confidence": "high" | "medium" | "low"
  },
  "next_step": string | null,
  "missing_fields": string[] | null,
  "assumptions": string[] | null
}
Rules:
- `status=success` -> `next_step=null`, `missing_fields=null`.
- `status=partial|blocked|error` -> `next_step` must be non-null.
- `status=blocked` due to missing required inputs -> `missing_fields` must be non-null.
</output_contract>
@@ -0,0 +1,11 @@
"""Research-stage tools: web search, scrape, and in-product doc search."""

from .scrape_webpage import create_scrape_webpage_tool
from .search_surfsense_docs import create_search_surfsense_docs_tool
from .web_search import create_web_search_tool

__all__ = [
    "create_scrape_webpage_tool",
    "create_search_surfsense_docs_tool",
    "create_web_search_tool",
]
@@ -0,0 +1,29 @@
from __future__ import annotations

from typing import Any

from app.agents.multi_agent_chat.subagents.shared.permissions import (
    ToolsPermissions,
)

from .scrape_webpage import create_scrape_webpage_tool
from .search_surfsense_docs import create_search_surfsense_docs_tool
from .web_search import create_web_search_tool


def load_tools(*, dependencies: dict[str, Any] | None = None, **kwargs: Any) -> ToolsPermissions:
    resolved_dependencies = {**(dependencies or {}), **kwargs}
    web = create_web_search_tool(
        search_space_id=resolved_dependencies.get("search_space_id"),
        available_connectors=resolved_dependencies.get("available_connectors"),
    )
    scrape = create_scrape_webpage_tool(firecrawl_api_key=resolved_dependencies.get("firecrawl_api_key"))
    docs = create_search_surfsense_docs_tool(db_session=resolved_dependencies["db_session"])
    return {
        "allow": [
            {"name": getattr(web, "name", "") or "", "tool": web},
            {"name": getattr(scrape, "name", "") or "", "tool": scrape},
            {"name": getattr(docs, "name", "") or "", "tool": docs},
        ],
        "ask": [],
    }
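The permission buckets use a simple two-list shape. A hedged sketch of what `ToolsPermissions` appears to contain, inferred only from the literals returned by these loaders (the actual definition lives in the shared permissions module):

```python
# Assumption: reconstructed from usage, not from the shared module itself.
from typing import Any, TypedDict


class ToolRow(TypedDict):
    name: str
    tool: Any  # a LangChain tool instance, or None


class ToolsPermissionsSketch(TypedDict):
    allow: list[ToolRow]  # tools that run without human approval
    ask: list[ToolRow]  # tools gated behind an interrupt/approval step
```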
@@ -0,0 +1,300 @@
"""Scrape pages via WebCrawlerConnector; YouTube URLs use the transcript API instead of an HTML crawl."""

import hashlib
import logging
from typing import Any
from urllib.parse import urlparse

import aiohttp
from fake_useragent import UserAgent
from langchain_core.tools import tool
from requests import Session
from youtube_transcript_api import YouTubeTranscriptApi

from app.connectors.webcrawler_connector import WebCrawlerConnector
from app.tasks.document_processors.youtube_processor import get_youtube_video_id
from app.utils.proxy_config import get_requests_proxies

logger = logging.getLogger(__name__)


def extract_domain(url: str) -> str:
    """Extract the domain from a URL."""
    try:
        parsed = urlparse(url)
        domain = parsed.netloc
        # Remove 'www.' prefix if present
        if domain.startswith("www."):
            domain = domain[4:]
        return domain
    except Exception:
        return ""


def generate_scrape_id(url: str) -> str:
    """Generate a unique ID for a scraped webpage."""
    hash_val = hashlib.md5(url.encode()).hexdigest()[:12]
    return f"scrape-{hash_val}"


def truncate_content(content: str, max_length: int = 50000) -> tuple[str, bool]:
    """
    Truncate content to a maximum length.

    Returns:
        Tuple of (truncated_content, was_truncated)
    """
    if len(content) <= max_length:
        return content, False

    # Try to truncate at a sentence boundary
    truncated = content[:max_length]
    last_period = truncated.rfind(".")
    last_newline = truncated.rfind("\n\n")

    # Use the later of the two boundaries, or just truncate
    boundary = max(last_period, last_newline)
    if boundary > max_length * 0.8:  # Only use boundary if it's not too far back
        truncated = content[: boundary + 1]

    return truncated + "\n\n[Content truncated...]", True


async def _scrape_youtube_video(
    url: str, video_id: str, max_length: int
) -> dict[str, Any]:
    """
    Fetch YouTube video metadata and transcript via the YouTubeTranscriptApi.

    Returns a result dict in the same shape as the regular scrape_webpage output.
    """
    scrape_id = generate_scrape_id(url)
    domain = "youtube.com"

    # --- Video metadata via oEmbed ---
    residential_proxies = get_requests_proxies()

    params = {
        "format": "json",
        "url": f"https://www.youtube.com/watch?v={video_id}",
    }
    oembed_url = "https://www.youtube.com/oembed"

    try:
        async with (
            aiohttp.ClientSession() as http_session,
            http_session.get(
                oembed_url,
                params=params,
                proxy=residential_proxies["http"] if residential_proxies else None,
            ) as response,
        ):
            video_data = await response.json()
    except Exception:
        video_data = {}

    title = video_data.get("title", "YouTube Video")
    author = video_data.get("author_name", "Unknown")

    # --- Transcript via YouTubeTranscriptApi ---
    try:
        ua = UserAgent()
        http_client = Session()
        http_client.headers.update({"User-Agent": ua.random})
        if residential_proxies:
            http_client.proxies.update(residential_proxies)
        ytt_api = YouTubeTranscriptApi(http_client=http_client)

        # List all available transcripts and pick the first one
        # (the video's primary language) instead of defaulting to English
        transcript_list = ytt_api.list(video_id)
        transcript = next(iter(transcript_list))
        captions = transcript.fetch()

        logger.info(
            f"[scrape_webpage] Fetched transcript for {video_id} "
            f"in {transcript.language} ({transcript.language_code})"
        )

        transcript_segments = []
        for line in captions:
            start_time = line.start
            duration = line.duration
            text = line.text
            timestamp = f"[{start_time:.2f}s-{start_time + duration:.2f}s]"
            transcript_segments.append(f"{timestamp} {text}")
        transcript_text = "\n".join(transcript_segments)
    except Exception as e:
        logger.warning(f"[scrape_webpage] No transcript for video {video_id}: {e}")
        transcript_text = f"No captions available for this video. Error: {e!s}"

    # Build combined content
    content = f"# {title}\n\n**Author:** {author}\n**Video ID:** {video_id}\n\n## Transcript\n\n{transcript_text}"

    # Truncate if needed
    content, was_truncated = truncate_content(content, max_length)
    word_count = len(content.split())

    description = f"YouTube video by {author}"

    return {
        "id": scrape_id,
        "assetId": url,
        "kind": "article",
        "href": url,
        "title": title,
        "description": description,
        "content": content,
        "domain": domain,
        "word_count": word_count,
        "was_truncated": was_truncated,
        "crawler_type": "youtube_transcript",
        "author": author,
    }


def create_scrape_webpage_tool(firecrawl_api_key: str | None = None):
    """
    Factory function to create the scrape_webpage tool.

    Args:
        firecrawl_api_key: Optional Firecrawl API key for premium web scraping.
            Falls back to Chromium/Trafilatura if not provided.

    Returns:
        A configured tool function for scraping webpages.
    """

    @tool
    async def scrape_webpage(
        url: str,
        max_length: int = 50000,
    ) -> dict[str, Any]:
        """
        Scrape and extract the main content from a webpage.

        Use this tool when the user wants you to read, summarize, or answer
        questions about a specific webpage's content. This tool actually
        fetches and reads the full page content. For YouTube video URLs it
        fetches the transcript directly instead of crawling the page.

        Common triggers:
        - "Read this article and summarize it"
        - "What does this page say about X?"
        - "Summarize this blog post for me"
        - "Tell me the key points from this article"
        - "What's in this webpage?"

        Args:
            url: The URL of the webpage to scrape (must be HTTP/HTTPS)
            max_length: Maximum content length to return (default: 50000 chars)

        Returns:
            A dictionary containing:
            - id: Unique identifier for this scrape
            - assetId: The URL (for deduplication)
            - kind: "article" (type of content)
            - href: The URL to open when clicked
            - title: Page title
            - description: Brief description or excerpt
            - content: The extracted main content (markdown format)
            - domain: The domain name
            - word_count: Approximate word count
            - was_truncated: Whether content was truncated
            - error: Error message (if scraping failed)
        """
        # Validate and normalize the URL first, so the scrape ID and domain
        # are derived from the same URL that is actually crawled (urlparse
        # yields an empty netloc for scheme-less input).
        if not url.startswith(("http://", "https://")):
            url = f"https://{url}"

        scrape_id = generate_scrape_id(url)
        domain = extract_domain(url)

        try:
            # Check if this is a YouTube URL and use transcript API instead
            video_id = get_youtube_video_id(url)
            if video_id:
                return await _scrape_youtube_video(url, video_id, max_length)

            # Create webcrawler connector
            connector = WebCrawlerConnector(firecrawl_api_key=firecrawl_api_key)

            # Crawl the URL
            result, error = await connector.crawl_url(url, formats=["markdown"])

            if error:
                return {
                    "id": scrape_id,
                    "assetId": url,
                    "kind": "article",
                    "href": url,
                    "title": domain or "Webpage",
                    "domain": domain,
                    "error": error,
                }

            if not result:
                return {
                    "id": scrape_id,
                    "assetId": url,
                    "kind": "article",
                    "href": url,
                    "title": domain or "Webpage",
                    "domain": domain,
                    "error": "No content returned from crawler",
                }

            # Extract content and metadata
            content = result.get("content", "")
            metadata = result.get("metadata", {})

            # Get title from metadata
            title = metadata.get("title", "")
            if not title:
                title = domain or url.split("/")[-1] or "Webpage"

            # Get description from metadata
            description = metadata.get("description", "")
            if not description and content:
                # Use first paragraph as description
                first_para = content.split("\n\n")[0] if content else ""
                description = (
                    first_para[:300] + "..." if len(first_para) > 300 else first_para
                )

            # Truncate content if needed
            content, was_truncated = truncate_content(content, max_length)

            # Calculate word count
            word_count = len(content.split())

            return {
                "id": scrape_id,
                "assetId": url,
                "kind": "article",
                "href": url,
                "title": title,
                "description": description,
                "content": content,
                "domain": domain,
                "word_count": word_count,
                "was_truncated": was_truncated,
                "crawler_type": result.get("crawler_type", "unknown"),
                "author": metadata.get("author"),
                "date": metadata.get("date"),
            }

        except Exception as e:
            error_message = str(e)
            logger.error(f"[scrape_webpage] Error scraping {url}: {error_message}")
            return {
                "id": scrape_id,
                "assetId": url,
                "kind": "article",
                "href": url,
                "title": domain or "Webpage",
                "domain": domain,
                "error": f"Failed to scrape: {error_message[:100]}",
            }

    return scrape_webpage
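A quick check of the boundary logic in `truncate_content`, with values chosen to make the arithmetic visible:

```python
# Illustrative only: the function cuts at the last "." or "\n\n" boundary,
# but only when that boundary lies past 80% of max_length; otherwise it cuts
# hard at max_length. A truncation marker is appended either way.
text = "First sentence. " * 10  # 160 characters
body, was_truncated = truncate_content(text, max_length=100)
assert was_truncated
# The last "." within the first 100 characters sits at index 94 (> 0.8 * 100),
# so the cut lands on that sentence boundary before the marker is appended.
assert body == text[:95] + "\n\n[Content truncated...]"
```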
@@ -0,0 +1,143 @@
"""Semantic search over pre-indexed in-app documentation chunks for user how-to questions."""

import asyncio
import json

from langchain_core.tools import tool
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.db import SurfsenseDocsChunk, SurfsenseDocsDocument
from app.utils.document_converters import embed_text


def format_surfsense_docs_results(results: list[tuple]) -> str:
    """Format (chunk, document) rows as XML with ``doc-`` chunk IDs for citations and UI routing."""
    if not results:
        return "No relevant Surfsense documentation found for your query."

    # Group chunks by document
    grouped: dict[int, dict] = {}
    for chunk, doc in results:
        if doc.id not in grouped:
            grouped[doc.id] = {
                "document_id": f"doc-{doc.id}",
                "document_type": "SURFSENSE_DOCS",
                "title": doc.title,
                "url": doc.source,
                "metadata": {"source": doc.source},
                "chunks": [],
            }
        grouped[doc.id]["chunks"].append(
            {
                "chunk_id": f"doc-{chunk.id}",
                "content": chunk.content,
            }
        )

    # Render XML matching format_documents_for_context structure
    parts: list[str] = []
    for g in grouped.values():
        metadata_json = json.dumps(g["metadata"], ensure_ascii=False)

        parts.append("<document>")
        parts.append("<document_metadata>")
        parts.append(f"  <document_id>{g['document_id']}</document_id>")
        parts.append(f"  <document_type>{g['document_type']}</document_type>")
        parts.append(f"  <title><![CDATA[{g['title']}]]></title>")
        parts.append(f"  <url><![CDATA[{g['url']}]]></url>")
        parts.append(f"  <metadata_json><![CDATA[{metadata_json}]]></metadata_json>")
        parts.append("</document_metadata>")
        parts.append("")
        parts.append("<document_content>")

        for ch in g["chunks"]:
            parts.append(
                f"  <chunk id='{ch['chunk_id']}'><![CDATA[{ch['content']}]]></chunk>"
            )

        parts.append("</document_content>")
        parts.append("</document>")
        parts.append("")

    return "\n".join(parts).strip()


async def search_surfsense_docs_async(
    query: str,
    db_session: AsyncSession,
    top_k: int = 10,
) -> str:
    """
    Search Surfsense documentation using vector similarity.

    Args:
        query: The search query about Surfsense usage
        db_session: Database session for executing queries
        top_k: Number of results to return

    Returns:
        Formatted string with relevant documentation content
    """
    # Get embedding for the query
    query_embedding = await asyncio.to_thread(embed_text, query)

    # Vector similarity search on chunks, joining with documents
    stmt = (
        select(SurfsenseDocsChunk, SurfsenseDocsDocument)
        .join(
            SurfsenseDocsDocument,
            SurfsenseDocsChunk.document_id == SurfsenseDocsDocument.id,
        )
        .order_by(SurfsenseDocsChunk.embedding.op("<=>")(query_embedding))
        .limit(top_k)
    )

    result = await db_session.execute(stmt)
    rows = result.all()

    return format_surfsense_docs_results(rows)


def create_search_surfsense_docs_tool(db_session: AsyncSession):
    """
    Factory function to create the search_surfsense_docs tool.

    Args:
        db_session: Database session for executing queries

    Returns:
        A configured tool function for searching Surfsense documentation
    """

    @tool
    async def search_surfsense_docs(query: str, top_k: int = 10) -> str:
        """
        Search Surfsense documentation for help with using the application.

        Use this tool when the user asks questions about:
        - How to use Surfsense features
        - Installation and setup instructions
        - Configuration options and settings
        - Troubleshooting common issues
        - Available connectors and integrations
        - Browser extension usage
        - API documentation

        This searches the official Surfsense documentation that was indexed
        at deployment time. It does NOT search the user's personal knowledge base.

        Args:
            query: The search query about Surfsense usage or features
            top_k: Number of documentation chunks to retrieve (default: 10)

        Returns:
            Relevant documentation content formatted with chunk IDs for citations
        """
        return await search_surfsense_docs_async(
            query=query,
            db_session=db_session,
            top_k=top_k,
        )

    return search_surfsense_docs
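The `.op("<=>")` call maps to pgvector's cosine-distance operator, so ascending order returns the most similar chunks first. A minimal sketch of the formatter's output shape, using `SimpleNamespace` stand-ins for the ORM rows (the IDs and URL are made up):

```python
from types import SimpleNamespace

# Stand-ins for (SurfsenseDocsChunk, SurfsenseDocsDocument) rows.
chunk = SimpleNamespace(id=12, content="Enable the extension from the settings page.")
doc = SimpleNamespace(id=3, title="Browser Extension", source="https://docs.example/extension")

xml = format_surfsense_docs_results([(chunk, doc)])
assert "<document_id>doc-3</document_id>" in xml
assert "<chunk id='doc-12'>" in xml  # chunk IDs share the "doc-" prefix for citation routing
```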
@@ -0,0 +1,241 @@
"""Real-time web search: SearXNG plus configured live-search connectors (Tavily, Linkup, Baidu, etc.)."""

import asyncio
import json
import time
from typing import Any

from langchain_core.tools import StructuredTool
from pydantic import BaseModel, Field

from app.db import shielded_async_session
from app.services.connector_service import ConnectorService
from app.utils.perf import get_perf_logger

_LIVE_SEARCH_CONNECTORS: set[str] = {
    "TAVILY_API",
    "LINKUP_API",
    "BAIDU_SEARCH_API",
}

_LIVE_CONNECTOR_SPECS: dict[str, tuple[str, bool, bool, dict[str, Any]]] = {
    "TAVILY_API": ("search_tavily", False, True, {}),
    "LINKUP_API": ("search_linkup", False, False, {"mode": "standard"}),
    "BAIDU_SEARCH_API": ("search_baidu", False, True, {}),
}

_CONNECTOR_LABELS: dict[str, str] = {
    "TAVILY_API": "Tavily",
    "LINKUP_API": "Linkup",
    "BAIDU_SEARCH_API": "Baidu",
}


class WebSearchInput(BaseModel):
    """Input schema for the web_search tool."""

    query: str = Field(
        description="The search query to look up on the web. Use specific, descriptive terms.",
    )
    top_k: int = Field(
        default=10,
        description="Number of results to retrieve (default: 10, max: 50).",
    )


def _format_web_results(
    documents: list[dict[str, Any]],
    *,
    max_chars: int = 50_000,
) -> str:
    """Format web search results into XML suitable for the LLM context."""
    if not documents:
        return "No web search results found."

    parts: list[str] = []
    total_chars = 0

    for doc in documents:
        doc_info = doc.get("document") or {}
        metadata = doc_info.get("metadata") or {}
        title = doc_info.get("title") or "Web Result"
        url = metadata.get("url") or ""
        content = (doc.get("content") or "").strip()
        source = metadata.get("document_type") or doc.get("source") or "WEB_SEARCH"
        if not content:
            continue

        metadata_json = json.dumps(metadata, ensure_ascii=False)
        doc_xml = "\n".join(
            [
                "<document>",
                "<document_metadata>",
                f"  <document_type>{source}</document_type>",
                f"  <title><![CDATA[{title}]]></title>",
                f"  <url><![CDATA[{url}]]></url>",
                f"  <metadata_json><![CDATA[{metadata_json}]]></metadata_json>",
                "</document_metadata>",
                "<document_content>",
                f"  <chunk id='{url}'><![CDATA[{content}]]></chunk>",
                "</document_content>",
                "</document>",
                "",
            ]
        )

        if total_chars + len(doc_xml) > max_chars:
            parts.append("<!-- Output truncated to fit context window -->")
            break

        parts.append(doc_xml)
        total_chars += len(doc_xml)

    return "\n".join(parts).strip() or "No web search results found."


async def _search_live_connector(
    connector: str,
    query: str,
    search_space_id: int,
    top_k: int,
    semaphore: asyncio.Semaphore,
) -> list[dict[str, Any]]:
    """Dispatch a single live-search connector (Tavily / Linkup / Baidu)."""
    perf = get_perf_logger()
    spec = _LIVE_CONNECTOR_SPECS.get(connector)
    if spec is None:
        return []

    method_name, _includes_date_range, includes_top_k, extra_kwargs = spec
    kwargs: dict[str, Any] = {
        "user_query": query,
        "search_space_id": search_space_id,
        **extra_kwargs,
    }
    if includes_top_k:
        kwargs["top_k"] = top_k

    try:
        t0 = time.perf_counter()
        async with semaphore, shielded_async_session() as session:
            svc = ConnectorService(session, search_space_id)
            _, chunks = await getattr(svc, method_name)(**kwargs)
            perf.info(
                "[web_search] connector=%s results=%d in %.3fs",
                connector,
                len(chunks),
                time.perf_counter() - t0,
            )
            return chunks
    except Exception as e:
        perf.warning("[web_search] connector=%s FAILED: %s", connector, e)
        return []


def create_web_search_tool(
    search_space_id: int | None = None,
    available_connectors: list[str] | None = None,
) -> StructuredTool:
    """Factory for the ``web_search`` tool.

    Dispatches in parallel to the platform SearXNG instance and any
    user-configured live-search connectors (Tavily, Linkup, Baidu).
    """
    active_live_connectors: list[str] = []
    if available_connectors:
        active_live_connectors = [
            c for c in available_connectors if c in _LIVE_SEARCH_CONNECTORS
        ]

    engine_names = ["SearXNG (platform default)"]
    engine_names.extend(_CONNECTOR_LABELS.get(c, c) for c in active_live_connectors)
    engines_summary = ", ".join(engine_names)

    description = (
        "Search the web for real-time information. "
        "Use this for current events, news, prices, weather, public facts, or any "
        "question that requires up-to-date information from the internet.\n\n"
        f"Active search engines: {engines_summary}.\n"
        "All configured engines are queried in parallel and results are merged."
    )

    _search_space_id = search_space_id
    _active_live = active_live_connectors

    async def _web_search_impl(query: str, top_k: int = 10) -> str:
        from app.services import web_search_service

        perf = get_perf_logger()
        t0 = time.perf_counter()
        clamped_top_k = min(max(1, top_k), 50)

        semaphore = asyncio.Semaphore(4)
        tasks: list[asyncio.Task[list[dict[str, Any]]]] = []

        if web_search_service.is_available():

            async def _searxng() -> list[dict[str, Any]]:
                async with semaphore:
                    _result_obj, docs = await web_search_service.search(
                        query=query,
                        top_k=clamped_top_k,
                    )
                    return docs

            tasks.append(asyncio.ensure_future(_searxng()))

        if _search_space_id is not None:
            for connector in _active_live:
                tasks.append(
                    asyncio.ensure_future(
                        _search_live_connector(
                            connector=connector,
                            query=query,
                            search_space_id=_search_space_id,
                            top_k=clamped_top_k,
                            semaphore=semaphore,
                        )
                    )
                )

        if not tasks:
            return "Web search is not available — no search engines are configured."

        results_lists = await asyncio.gather(*tasks, return_exceptions=True)

        all_documents: list[dict[str, Any]] = []
        for result in results_lists:
            if isinstance(result, BaseException):
                perf.warning("[web_search] a search engine failed: %s", result)
                continue
            all_documents.extend(result)

        seen_urls: set[str] = set()
        deduplicated: list[dict[str, Any]] = []
        for doc in all_documents:
            url = ((doc.get("document") or {}).get("metadata") or {}).get("url", "")
            if url and url in seen_urls:
                continue
            if url:
                seen_urls.add(url)
            deduplicated.append(doc)

        formatted = _format_web_results(deduplicated)

        perf.info(
            "[web_search] query=%r engines=%d results=%d deduped=%d chars=%d in %.3fs",
            query[:60],
            len(tasks),
            len(all_documents),
            len(deduplicated),
            len(formatted),
            time.perf_counter() - t0,
        )
        return formatted

    return StructuredTool(
        name="web_search",
        description=description,
        coroutine=_web_search_impl,
        args_schema=WebSearchInput,
    )
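The merge step keeps the first result seen per URL and always keeps results that carry no URL. A standalone sketch of that rule with toy result dicts:

```python
# Illustrative only: mirrors the dedup loop inside _web_search_impl.
def dedupe_by_url(documents: list[dict]) -> list[dict]:
    seen: set[str] = set()
    kept: list[dict] = []
    for doc in documents:
        url = ((doc.get("document") or {}).get("metadata") or {}).get("url", "")
        if url and url in seen:
            continue  # a later engine returned the same page; keep the first hit
        if url:
            seen.add(url)
        kept.append(doc)  # URL-less results are always kept
    return kept


docs = [
    {"document": {"metadata": {"url": "https://a.example"}}, "content": "from SearXNG"},
    {"document": {"metadata": {"url": "https://a.example"}}, "content": "from Tavily"},
    {"document": {"metadata": {}}, "content": "no url"},
]
assert [d["content"] for d in dedupe_by_url(docs)] == ["from SearXNG", "no url"]
```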
@@ -0,0 +1,55 @@
"""`airtable` route: ``SubAgent`` spec for deepagents."""

from __future__ import annotations

from collections.abc import Sequence
from typing import Any

from deepagents import SubAgent
from langchain_core.language_models import BaseChatModel

from app.agents.multi_agent_chat.subagents.shared.md_file_reader import (
    read_md_file,
)
from app.agents.multi_agent_chat.subagents.shared.permissions import (
    ToolsPermissions,
    merge_tools_permissions,
    middleware_gated_interrupt_on,
)
from app.agents.multi_agent_chat.subagents.shared.subagent_builder import (
    pack_subagent,
)

from .tools.index import load_tools

NAME = "airtable"


def build_subagent(
    *,
    dependencies: dict[str, Any],
    model: BaseChatModel | None = None,
    extra_middleware: Sequence[Any] | None = None,
    extra_tools_bucket: ToolsPermissions | None = None,
) -> SubAgent:
    buckets = load_tools(dependencies=dependencies)
    merged_tools_bucket = merge_tools_permissions(buckets, extra_tools_bucket)
    tools = [
        row["tool"]
        for row in (*merged_tools_bucket["allow"], *merged_tools_bucket["ask"])
        if row.get("tool") is not None
    ]
    interrupt_on = middleware_gated_interrupt_on(merged_tools_bucket)
    description = read_md_file(__package__, "description").strip()
    if not description:
        description = "Handles airtable tasks for this workspace."
    system_prompt = read_md_file(__package__, "system_prompt").strip()
    return pack_subagent(
        name=NAME,
        description=description,
        system_prompt=system_prompt,
        tools=tools,
        interrupt_on=interrupt_on,
        model=model,
        extra_middleware=extra_middleware,
    )
@@ -0,0 +1 @@
Use for Airtable structured data operations: locate bases/tables and create/read/update records.
@@ -0,0 +1,46 @@
You are the Airtable MCP operations sub-agent.
You receive delegated instructions from a supervisor agent and return structured results for supervisor synthesis.

<goal>
Execute Airtable MCP base/table/record operations accurately.
</goal>

<available_tools>
- Runtime-provided Airtable MCP tools for bases, tables, and records.
</available_tools>

<tool_policy>
- Resolve base and table targets before record-level actions.
- Do not guess IDs or schema fields.
- If targets are ambiguous, return `status=blocked` with candidate options.
- Never claim mutation success without tool confirmation.
</tool_policy>

<out_of_scope>
- Do not execute non-Airtable tasks.
</out_of_scope>

<safety>
- Never claim record mutations succeeded without tool confirmation.
</safety>

<failure_policy>
- On tool failure, return `status=error` with a concise recovery `next_step`.
- On unresolved target/schema ambiguity, return `status=blocked` with required options.
</failure_policy>

<output_contract>
Return **only** one JSON object (no markdown/prose):
{
  "status": "success" | "partial" | "blocked" | "error",
  "action_summary": string,
  "evidence": { "items": object | null },
  "next_step": string | null,
  "missing_fields": string[] | null,
  "assumptions": string[] | null
}
Rules:
- `status=success` -> `next_step=null`, `missing_fields=null`.
- `status=partial|blocked|error` -> `next_step` must be non-null.
- `status=blocked` due to missing required inputs -> `missing_fields` must be non-null.
</output_contract>
@@ -0,0 +1,3 @@
"""Airtable route: native tool factories are empty; MCP supplies tools when configured."""

__all__: list[str] = []
@@ -0,0 +1,12 @@
from __future__ import annotations

from typing import Any

from app.agents.multi_agent_chat.subagents.shared.permissions import (
    ToolsPermissions,
)


def load_tools(*, dependencies: dict[str, Any] | None = None, **kwargs: Any) -> ToolsPermissions:
    _ = {**(dependencies or {}), **kwargs}
    return {"allow": [], "ask": []}
@@ -0,0 +1,55 @@
"""`calendar` route: ``SubAgent`` spec for deepagents."""

from __future__ import annotations

from collections.abc import Sequence
from typing import Any

from deepagents import SubAgent
from langchain_core.language_models import BaseChatModel

from app.agents.multi_agent_chat.subagents.shared.md_file_reader import (
    read_md_file,
)
from app.agents.multi_agent_chat.subagents.shared.permissions import (
    ToolsPermissions,
    merge_tools_permissions,
    middleware_gated_interrupt_on,
)
from app.agents.multi_agent_chat.subagents.shared.subagent_builder import (
    pack_subagent,
)

from .tools.index import load_tools

NAME = "calendar"


def build_subagent(
    *,
    dependencies: dict[str, Any],
    model: BaseChatModel | None = None,
    extra_middleware: Sequence[Any] | None = None,
    extra_tools_bucket: ToolsPermissions | None = None,
) -> SubAgent:
    buckets = load_tools(dependencies=dependencies)
    merged_tools_bucket = merge_tools_permissions(buckets, extra_tools_bucket)
    tools = [
        row["tool"]
        for row in (*merged_tools_bucket["allow"], *merged_tools_bucket["ask"])
        if row.get("tool") is not None
    ]
    interrupt_on = middleware_gated_interrupt_on(merged_tools_bucket)
    description = read_md_file(__package__, "description").strip()
    if not description:
        description = "Handles calendar tasks for this workspace."
    system_prompt = read_md_file(__package__, "system_prompt").strip()
    return pack_subagent(
        name=NAME,
        description=description,
        system_prompt=system_prompt,
        tools=tools,
        interrupt_on=interrupt_on,
        model=model,
        extra_middleware=extra_middleware,
    )
@@ -0,0 +1 @@
Use for calendar planning and scheduling: check availability, read event details, create events, and update events.
@@ -0,0 +1,62 @@
You are the Google Calendar operations sub-agent.
You receive delegated instructions from a supervisor agent and return structured results for supervisor synthesis.

<goal>
Execute calendar event operations (search, create, update, delete) accurately with timezone-safe scheduling.
</goal>

<available_tools>
- `search_calendar_events`
- `create_calendar_event`
- `update_calendar_event`
- `delete_calendar_event`
</available_tools>

<tool_policy>
- Use only tools in `<available_tools>`.
- Resolve relative dates against the current runtime timestamp.
- If required fields (date/time/timezone/target event) are missing or ambiguous, return `status=blocked` with `missing_fields` and a supervisor `next_step`.
- Never invent event IDs or mutation results.
</tool_policy>

<out_of_scope>
- Do not perform non-calendar tasks.
</out_of_scope>

<safety>
- Before update/delete, ensure the event target is explicit.
- Never claim event mutation success without tool confirmation.
</safety>

<failure_policy>
- On tool failure, return `status=error` with a concise recovery `next_step`.
- On ambiguity, return `status=blocked` with top event candidates.
</failure_policy>

<output_contract>
Return **only** one JSON object (no markdown/prose):
{
  "status": "success" | "partial" | "blocked" | "error",
  "action_summary": string,
  "evidence": {
    "event_id": string | null,
    "title": string | null,
    "start_at": string (ISO 8601 with timezone) | null,
    "end_at": string (ISO 8601 with timezone) | null,
    "matched_candidates": [
      {
        "event_id": string,
        "title": string | null,
        "start_at": string (ISO 8601 with timezone) | null
      }
    ] | null
  },
  "next_step": string | null,
  "missing_fields": string[] | null,
  "assumptions": string[] | null
}
Rules:
- `status=success` -> `next_step=null`, `missing_fields=null`.
- `status=partial|blocked|error` -> `next_step` must be non-null.
- `status=blocked` due to missing required inputs -> `missing_fields` must be non-null.
</output_contract>
@@ -0,0 +1,19 @@
from app.agents.new_chat.tools.google_calendar.create_event import (
    create_create_calendar_event_tool,
)
from app.agents.new_chat.tools.google_calendar.delete_event import (
    create_delete_calendar_event_tool,
)
from app.agents.new_chat.tools.google_calendar.search_events import (
    create_search_calendar_events_tool,
)
from app.agents.new_chat.tools.google_calendar.update_event import (
    create_update_calendar_event_tool,
)

__all__ = [
    "create_create_calendar_event_tool",
    "create_delete_calendar_event_tool",
    "create_search_calendar_events_tool",
    "create_update_calendar_event_tool",
]
@@ -0,0 +1,324 @@
import asyncio
import logging
from datetime import datetime
from typing import Any

from google.oauth2.credentials import Credentials
from googleapiclient.discovery import build
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession

from app.agents.new_chat.tools.hitl import request_approval
from app.services.google_calendar import GoogleCalendarToolMetadataService

logger = logging.getLogger(__name__)


def create_create_calendar_event_tool(
    db_session: AsyncSession | None = None,
    search_space_id: int | None = None,
    user_id: str | None = None,
):
    @tool
    async def create_calendar_event(
        summary: str,
        start_datetime: str,
        end_datetime: str,
        description: str | None = None,
        location: str | None = None,
        attendees: list[str] | None = None,
    ) -> dict[str, Any]:
        """Create a new event on Google Calendar.

        Use when the user asks to schedule, create, or add a calendar event.
        Ask for event details if not provided.

        Args:
            summary: The event title.
            start_datetime: Start time in ISO 8601 format (e.g. "2026-03-20T10:00:00").
            end_datetime: End time in ISO 8601 format (e.g. "2026-03-20T11:00:00").
            description: Optional event description.
            location: Optional event location.
            attendees: Optional list of attendee email addresses.

        Returns:
            Dictionary with:
            - status: "success", "rejected", "auth_error", or "error"
            - event_id: Google Calendar event ID (if success)
            - html_link: URL to open the event (if success)
            - message: Result message

        IMPORTANT:
        - If status is "rejected", the user explicitly declined the action.
          Respond with a brief acknowledgment and do NOT retry or suggest alternatives.

        Examples:
        - "Schedule a meeting with John tomorrow at 10am"
        - "Create a calendar event for the team standup"
        """
        logger.info(
            f"create_calendar_event called: summary='{summary}', start='{start_datetime}', end='{end_datetime}'"
        )

        if db_session is None or search_space_id is None or user_id is None:
            return {
                "status": "error",
                "message": "Google Calendar tool not properly configured. Please contact support.",
            }

        try:
            metadata_service = GoogleCalendarToolMetadataService(db_session)
            context = await metadata_service.get_creation_context(
                search_space_id, user_id
            )

            if "error" in context:
                logger.error(f"Failed to fetch creation context: {context['error']}")
                return {"status": "error", "message": context["error"]}

            accounts = context.get("accounts", [])
            if accounts and all(a.get("auth_expired") for a in accounts):
                logger.warning(
                    "All Google Calendar accounts have expired authentication"
                )
                return {
                    "status": "auth_error",
                    "message": "All connected Google Calendar accounts need re-authentication. Please re-authenticate in your connector settings.",
                    "connector_type": "google_calendar",
                }

            logger.info(
                f"Requesting approval for creating calendar event: summary='{summary}'"
            )
            result = request_approval(
                action_type="google_calendar_event_creation",
                tool_name="create_calendar_event",
                params={
                    "summary": summary,
                    "start_datetime": start_datetime,
                    "end_datetime": end_datetime,
                    "description": description,
                    "location": location,
                    "attendees": attendees,
                    "timezone": context.get("timezone"),
                    "connector_id": None,
                },
                context=context,
            )

            if result.rejected:
                return {
                    "status": "rejected",
                    "message": "User declined. The event was not created. Do not ask again or suggest alternatives.",
                }

            final_summary = result.params.get("summary", summary)
            final_start_datetime = result.params.get("start_datetime", start_datetime)
            final_end_datetime = result.params.get("end_datetime", end_datetime)
            final_description = result.params.get("description", description)
            final_location = result.params.get("location", location)
            final_attendees = result.params.get("attendees", attendees)
            final_connector_id = result.params.get("connector_id")

            if not final_summary or not final_summary.strip():
                return {"status": "error", "message": "Event summary cannot be empty."}

            from sqlalchemy.future import select

            from app.db import SearchSourceConnector, SearchSourceConnectorType

            _calendar_types = [
                SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR,
                SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
            ]

            if final_connector_id is not None:
                result = await db_session.execute(
                    select(SearchSourceConnector).filter(
                        SearchSourceConnector.id == final_connector_id,
                        SearchSourceConnector.search_space_id == search_space_id,
                        SearchSourceConnector.user_id == user_id,
                        SearchSourceConnector.connector_type.in_(_calendar_types),
                    )
                )
                connector = result.scalars().first()
                if not connector:
                    return {
                        "status": "error",
                        "message": "Selected Google Calendar connector is invalid or has been disconnected.",
                    }
                actual_connector_id = connector.id
            else:
                result = await db_session.execute(
                    select(SearchSourceConnector).filter(
                        SearchSourceConnector.search_space_id == search_space_id,
                        SearchSourceConnector.user_id == user_id,
                        SearchSourceConnector.connector_type.in_(_calendar_types),
                    )
                )
                connector = result.scalars().first()
                if not connector:
                    return {
                        "status": "error",
                        "message": "No Google Calendar connector found. Please connect Google Calendar in your workspace settings.",
                    }
                actual_connector_id = connector.id

            logger.info(
                f"Creating calendar event: summary='{final_summary}', connector={actual_connector_id}"
            )

            if (
                connector.connector_type
                == SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
            ):
                from app.utils.google_credentials import build_composio_credentials

                cca_id = connector.config.get("composio_connected_account_id")
                if cca_id:
                    creds = build_composio_credentials(cca_id)
                else:
                    return {
                        "status": "error",
                        "message": "Composio connected account ID not found for this connector.",
                    }
            else:
                config_data = dict(connector.config)

                from app.config import config as app_config
                from app.utils.oauth_security import TokenEncryption

                token_encrypted = config_data.get("_token_encrypted", False)
                if token_encrypted and app_config.SECRET_KEY:
                    token_encryption = TokenEncryption(app_config.SECRET_KEY)
                    for key in ("token", "refresh_token", "client_secret"):
                        if config_data.get(key):
                            config_data[key] = token_encryption.decrypt_token(
                                config_data[key]
                            )

                exp = config_data.get("expiry", "")
                if exp:
                    exp = exp.replace("Z", "")

                creds = Credentials(
                    token=config_data.get("token"),
                    refresh_token=config_data.get("refresh_token"),
                    token_uri=config_data.get("token_uri"),
                    client_id=config_data.get("client_id"),
                    client_secret=config_data.get("client_secret"),
                    scopes=config_data.get("scopes", []),
                    expiry=datetime.fromisoformat(exp) if exp else None,
                )

            service = await asyncio.get_event_loop().run_in_executor(
                None, lambda: build("calendar", "v3", credentials=creds)
            )

            tz = context.get("timezone", "UTC")
            event_body: dict[str, Any] = {
                "summary": final_summary,
                "start": {"dateTime": final_start_datetime, "timeZone": tz},
                "end": {"dateTime": final_end_datetime, "timeZone": tz},
            }
            if final_description:
                event_body["description"] = final_description
            if final_location:
                event_body["location"] = final_location
            if final_attendees:
                event_body["attendees"] = [
                    {"email": e.strip()} for e in final_attendees if e.strip()
                ]

            try:
                created = await asyncio.get_event_loop().run_in_executor(
                    None,
                    lambda: (
                        service.events()
                        .insert(calendarId="primary", body=event_body)
                        .execute()
                    ),
                )
            except Exception as api_err:
                from googleapiclient.errors import HttpError

                if isinstance(api_err, HttpError) and api_err.resp.status == 403:
                    logger.warning(
                        f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
                    )
                    try:
                        from sqlalchemy.orm.attributes import flag_modified

                        _res = await db_session.execute(
|
||||
select(SearchSourceConnector).where(
|
||||
SearchSourceConnector.id == actual_connector_id
|
||||
)
|
||||
)
|
||||
_conn = _res.scalar_one_or_none()
|
||||
if _conn and not _conn.config.get("auth_expired"):
|
||||
_conn.config = {**_conn.config, "auth_expired": True}
|
||||
flag_modified(_conn, "config")
|
||||
await db_session.commit()
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to persist auth_expired for connector %s",
|
||||
actual_connector_id,
|
||||
exc_info=True,
|
||||
)
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"connector_id": actual_connector_id,
|
||||
"message": "This Google Calendar account needs additional permissions. Please re-authenticate in connector settings.",
|
||||
}
|
||||
raise
|
||||
|
||||
logger.info(
|
||||
f"Calendar event created: id={created.get('id')}, summary={created.get('summary')}"
|
||||
)
|
||||
|
||||
kb_message_suffix = ""
|
||||
try:
|
||||
from app.services.google_calendar import GoogleCalendarKBSyncService
|
||||
|
||||
kb_service = GoogleCalendarKBSyncService(db_session)
|
||||
kb_result = await kb_service.sync_after_create(
|
||||
event_id=created.get("id"),
|
||||
event_summary=final_summary,
|
||||
calendar_id="primary",
|
||||
start_time=final_start_datetime,
|
||||
end_time=final_end_datetime,
|
||||
location=final_location,
|
||||
html_link=created.get("htmlLink"),
|
||||
description=final_description,
|
||||
connector_id=actual_connector_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
if kb_result["status"] == "success":
|
||||
kb_message_suffix = " Your knowledge base has also been updated."
|
||||
else:
|
||||
kb_message_suffix = " This event will be added to your knowledge base in the next scheduled sync."
|
||||
except Exception as kb_err:
|
||||
logger.warning(f"KB sync after create failed: {kb_err}")
|
||||
kb_message_suffix = " This event will be added to your knowledge base in the next scheduled sync."
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"event_id": created.get("id"),
|
||||
"html_link": created.get("htmlLink"),
|
||||
"message": f"Successfully created '{final_summary}' on Google Calendar.{kb_message_suffix}",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
||||
if isinstance(e, GraphInterrupt):
|
||||
raise
|
||||
|
||||
logger.error(f"Error creating calendar event: {e}", exc_info=True)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Something went wrong while creating the event. Please try again.",
|
||||
}
|
||||
|
||||
return create_calendar_event
|
||||
|
|
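A minimal usage sketch for the factory above (everything except `create_create_calendar_event_tool` is a placeholder; in the app the session objects come from the request context, and the call runs inside an async function):

# Illustrative only: wiring the factory with hypothetical session values.
create_event_tool = create_create_calendar_event_tool(
    db_session=session,  # AsyncSession bound to the current request
    search_space_id=42,  # hypothetical workspace ID
    user_id="user-123",  # hypothetical user ID
)
result = await create_event_tool.ainvoke(
    {
        "summary": "Team standup",
        "start_datetime": "2026-04-01T09:00:00",
        "end_datetime": "2026-04-01T09:15:00",
    }
)
# result["status"] is one of: success, rejected, auth_error,
# insufficient_permissions, error.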
@@ -0,0 +1,304 @@
import asyncio
import logging
from datetime import datetime
from typing import Any

from google.oauth2.credentials import Credentials
from googleapiclient.discovery import build
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession

from app.agents.new_chat.tools.hitl import request_approval
from app.services.google_calendar import GoogleCalendarToolMetadataService

logger = logging.getLogger(__name__)


def create_delete_calendar_event_tool(
    db_session: AsyncSession | None = None,
    search_space_id: int | None = None,
    user_id: str | None = None,
):
    @tool
    async def delete_calendar_event(
        event_title_or_id: str,
        delete_from_kb: bool = False,
    ) -> dict[str, Any]:
        """Delete a Google Calendar event.

        Use when the user asks to delete, remove, or cancel a calendar event.

        Args:
            event_title_or_id: The exact title or event ID of the event to delete.
            delete_from_kb: Whether to also remove the event from the knowledge base.
                Default is False. Set to True to remove from both Google Calendar
                and the knowledge base.

        Returns:
            Dictionary with:
            - status: "success", "rejected", "not_found", "auth_error", or "error"
            - event_id: Google Calendar event ID (if success)
            - deleted_from_kb: whether the document was removed from the knowledge base
            - message: Result message

        IMPORTANT:
        - If status is "rejected", the user explicitly declined. Respond with a brief
          acknowledgment and do NOT retry or suggest alternatives.
        - If status is "not_found", relay the exact message to the user and ask them
          to verify the event name or check if it has been indexed.

        Examples:
        - "Delete the team standup event"
        - "Cancel my dentist appointment on Friday"
        """
        logger.info(
            f"delete_calendar_event called: event_ref='{event_title_or_id}', delete_from_kb={delete_from_kb}"
        )

        if db_session is None or search_space_id is None or user_id is None:
            return {
                "status": "error",
                "message": "Google Calendar tool not properly configured. Please contact support.",
            }

        try:
            metadata_service = GoogleCalendarToolMetadataService(db_session)
            context = await metadata_service.get_deletion_context(
                search_space_id, user_id, event_title_or_id
            )

            if "error" in context:
                error_msg = context["error"]
                if "not found" in error_msg.lower():
                    logger.warning(f"Event not found: {error_msg}")
                    return {"status": "not_found", "message": error_msg}
                logger.error(f"Failed to fetch deletion context: {error_msg}")
                return {"status": "error", "message": error_msg}

            account = context.get("account", {})
            if account.get("auth_expired"):
                logger.warning(
                    "Google Calendar account %s has expired authentication",
                    account.get("id"),
                )
                return {
                    "status": "auth_error",
                    "message": "The Google Calendar account for this event needs re-authentication. Please re-authenticate in your connector settings.",
                    "connector_type": "google_calendar",
                }

            event = context["event"]
            event_id = event["event_id"]
            document_id = event.get("document_id")
            connector_id_from_context = context["account"]["id"]

            if not event_id:
                return {
                    "status": "error",
                    "message": "Event ID is missing from the indexed document. Please re-index the event and try again.",
                }

            logger.info(
                f"Requesting approval for deleting calendar event: '{event_title_or_id}' (event_id={event_id}, delete_from_kb={delete_from_kb})"
            )
            result = request_approval(
                action_type="google_calendar_event_deletion",
                tool_name="delete_calendar_event",
                params={
                    "event_id": event_id,
                    "connector_id": connector_id_from_context,
                    "delete_from_kb": delete_from_kb,
                },
                context=context,
            )

            if result.rejected:
                return {
                    "status": "rejected",
                    "message": "User declined. The event was not deleted. Do not ask again or suggest alternatives.",
                }

            final_event_id = result.params.get("event_id", event_id)
            final_connector_id = result.params.get(
                "connector_id", connector_id_from_context
            )
            final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb)

            if not final_connector_id:
                return {
                    "status": "error",
                    "message": "No connector found for this event.",
                }

            from sqlalchemy.future import select

            from app.db import SearchSourceConnector, SearchSourceConnectorType

            _calendar_types = [
                SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR,
                SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
            ]

            result = await db_session.execute(
                select(SearchSourceConnector).filter(
                    SearchSourceConnector.id == final_connector_id,
                    SearchSourceConnector.search_space_id == search_space_id,
                    SearchSourceConnector.user_id == user_id,
                    SearchSourceConnector.connector_type.in_(_calendar_types),
                )
            )
            connector = result.scalars().first()
            if not connector:
                return {
                    "status": "error",
                    "message": "Selected Google Calendar connector is invalid or has been disconnected.",
                }

            actual_connector_id = connector.id

            logger.info(
                f"Deleting calendar event: event_id='{final_event_id}', connector={actual_connector_id}"
            )

            if (
                connector.connector_type
                == SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
            ):
                from app.utils.google_credentials import build_composio_credentials

                cca_id = connector.config.get("composio_connected_account_id")
                if cca_id:
                    creds = build_composio_credentials(cca_id)
                else:
                    return {
                        "status": "error",
                        "message": "Composio connected account ID not found for this connector.",
                    }
            else:
                config_data = dict(connector.config)

                from app.config import config as app_config
                from app.utils.oauth_security import TokenEncryption

                token_encrypted = config_data.get("_token_encrypted", False)
                if token_encrypted and app_config.SECRET_KEY:
                    token_encryption = TokenEncryption(app_config.SECRET_KEY)
                    for key in ("token", "refresh_token", "client_secret"):
                        if config_data.get(key):
                            config_data[key] = token_encryption.decrypt_token(
                                config_data[key]
                            )

                exp = config_data.get("expiry", "")
                if exp:
                    exp = exp.replace("Z", "")

                creds = Credentials(
                    token=config_data.get("token"),
                    refresh_token=config_data.get("refresh_token"),
                    token_uri=config_data.get("token_uri"),
                    client_id=config_data.get("client_id"),
                    client_secret=config_data.get("client_secret"),
                    scopes=config_data.get("scopes", []),
                    expiry=datetime.fromisoformat(exp) if exp else None,
                )

            service = await asyncio.get_event_loop().run_in_executor(
                None, lambda: build("calendar", "v3", credentials=creds)
            )

            try:
                await asyncio.get_event_loop().run_in_executor(
                    None,
                    lambda: (
                        service.events()
                        .delete(calendarId="primary", eventId=final_event_id)
                        .execute()
                    ),
                )
            except Exception as api_err:
                from googleapiclient.errors import HttpError

                if isinstance(api_err, HttpError) and api_err.resp.status == 403:
                    logger.warning(
                        f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
                    )
                    try:
                        from sqlalchemy.orm.attributes import flag_modified

                        _res = await db_session.execute(
                            select(SearchSourceConnector).where(
                                SearchSourceConnector.id == actual_connector_id
                            )
                        )
                        _conn = _res.scalar_one_or_none()
                        if _conn and not _conn.config.get("auth_expired"):
                            _conn.config = {**_conn.config, "auth_expired": True}
                            flag_modified(_conn, "config")
                            await db_session.commit()
                    except Exception:
                        logger.warning(
                            "Failed to persist auth_expired for connector %s",
                            actual_connector_id,
                            exc_info=True,
                        )
                    return {
                        "status": "insufficient_permissions",
                        "connector_id": actual_connector_id,
                        "message": "This Google Calendar account needs additional permissions. Please re-authenticate in connector settings.",
                    }
                raise

            logger.info(f"Calendar event deleted: event_id={final_event_id}")

            delete_result: dict[str, Any] = {
                "status": "success",
                "event_id": final_event_id,
                "message": f"Successfully deleted the calendar event '{event.get('summary', event_title_or_id)}'.",
            }

            deleted_from_kb = False
            if final_delete_from_kb and document_id:
                try:
                    from app.db import Document

                    doc_result = await db_session.execute(
                        select(Document).filter(Document.id == document_id)
                    )
                    document = doc_result.scalars().first()
                    if document:
                        await db_session.delete(document)
                        await db_session.commit()
                        deleted_from_kb = True
                        logger.info(
                            f"Deleted document {document_id} from knowledge base"
                        )
                    else:
                        logger.warning(f"Document {document_id} not found in KB")
                except Exception as e:
                    logger.error(f"Failed to delete document from KB: {e}")
                    await db_session.rollback()
                    delete_result["warning"] = (
                        f"Event deleted, but failed to remove from knowledge base: {e!s}"
                    )

            delete_result["deleted_from_kb"] = deleted_from_kb
            if deleted_from_kb:
                delete_result["message"] = (
                    f"{delete_result.get('message', '')} (also removed from knowledge base)"
                )

            return delete_result

        except Exception as e:
            from langgraph.errors import GraphInterrupt

            if isinstance(e, GraphInterrupt):
                raise

            logger.error(f"Error deleting calendar event: {e}", exc_info=True)
            return {
                "status": "error",
                "message": "Something went wrong while deleting the event. Please try again.",
            }

    return delete_calendar_event
@@ -0,0 +1,33 @@
from __future__ import annotations

from typing import Any

from app.agents.multi_agent_chat.subagents.shared.permissions import (
    ToolsPermissions,
)

from .create_event import create_create_calendar_event_tool
from .delete_event import create_delete_calendar_event_tool
from .search_events import create_search_calendar_events_tool
from .update_event import create_update_calendar_event_tool


def load_tools(*, dependencies: dict[str, Any] | None = None, **kwargs: Any) -> ToolsPermissions:
    resolved_dependencies = {**(dependencies or {}), **kwargs}
    session_dependencies = {
        "db_session": resolved_dependencies["db_session"],
        "search_space_id": resolved_dependencies["search_space_id"],
        "user_id": resolved_dependencies["user_id"],
    }
    search = create_search_calendar_events_tool(**session_dependencies)
    create = create_create_calendar_event_tool(**session_dependencies)
    update = create_update_calendar_event_tool(**session_dependencies)
    delete = create_delete_calendar_event_tool(**session_dependencies)
    return {
        "allow": [{"name": getattr(search, "name", "") or "", "tool": search}],
        "ask": [
            {"name": getattr(create, "name", "") or "", "tool": create},
            {"name": getattr(update, "name", "") or "", "tool": update},
            {"name": getattr(delete, "name", "") or "", "tool": delete},
        ],
    }
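The split above is what drives the human-in-the-loop gating: the read-only search tool goes in `allow`, every mutating tool in `ask`. A hedged sketch of how a route builder consumes the buckets (dependency values are placeholders; `middleware_gated_interrupt_on` is the shared helper the route builders import):

# Illustrative: allow/ask buckets feed the tool list and interrupt gating.
buckets = load_tools(
    dependencies={"db_session": session, "search_space_id": 42, "user_id": "user-123"}
)
tools = [row["tool"] for row in (*buckets["allow"], *buckets["ask"])]
interrupt_on = middleware_gated_interrupt_on(buckets)  # gates only the "ask" tools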
@@ -0,0 +1,132 @@
import logging
from typing import Any

from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select

from app.agents.new_chat.tools.gmail.search_emails import _build_credentials
from app.db import SearchSourceConnector, SearchSourceConnectorType

logger = logging.getLogger(__name__)

_CALENDAR_TYPES = [
    SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR,
    SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
]


def create_search_calendar_events_tool(
    db_session: AsyncSession | None = None,
    search_space_id: int | None = None,
    user_id: str | None = None,
):
    @tool
    async def search_calendar_events(
        start_date: str,
        end_date: str,
        max_results: int = 25,
    ) -> dict[str, Any]:
        """Search Google Calendar events within a date range.

        Args:
            start_date: Start date in YYYY-MM-DD format (e.g. "2026-04-01").
            end_date: End date in YYYY-MM-DD format (e.g. "2026-04-30").
            max_results: Maximum number of events to return (default 25, max 50).

        Returns:
            Dictionary with status and a list of events including
            event_id, summary, start, end, location, attendees.
        """
        if db_session is None or search_space_id is None or user_id is None:
            return {
                "status": "error",
                "message": "Calendar tool not properly configured.",
            }

        max_results = min(max_results, 50)

        try:
            result = await db_session.execute(
                select(SearchSourceConnector).filter(
                    SearchSourceConnector.search_space_id == search_space_id,
                    SearchSourceConnector.user_id == user_id,
                    SearchSourceConnector.connector_type.in_(_CALENDAR_TYPES),
                )
            )
            connector = result.scalars().first()
            if not connector:
                return {
                    "status": "error",
                    "message": "No Google Calendar connector found. Please connect Google Calendar in your workspace settings.",
                }

            creds = _build_credentials(connector)

            from app.connectors.google_calendar_connector import GoogleCalendarConnector

            cal = GoogleCalendarConnector(
                credentials=creds,
                session=db_session,
                user_id=user_id,
                connector_id=connector.id,
            )

            events_raw, error = await cal.get_all_primary_calendar_events(
                start_date=start_date,
                end_date=end_date,
                max_results=max_results,
            )

            if error:
                if (
                    "re-authenticate" in error.lower()
                    or "authentication failed" in error.lower()
                ):
                    return {
                        "status": "auth_error",
                        "message": error,
                        "connector_type": "google_calendar",
                    }
                if "no events found" in error.lower():
                    return {
                        "status": "success",
                        "events": [],
                        "total": 0,
                        "message": error,
                    }
                return {"status": "error", "message": error}

            events = []
            for ev in events_raw:
                start = ev.get("start", {})
                end = ev.get("end", {})
                attendees_raw = ev.get("attendees", [])
                events.append(
                    {
                        "event_id": ev.get("id"),
                        "summary": ev.get("summary", "No Title"),
                        "start": start.get("dateTime") or start.get("date", ""),
                        "end": end.get("dateTime") or end.get("date", ""),
                        "location": ev.get("location", ""),
                        "description": ev.get("description", ""),
                        "html_link": ev.get("htmlLink", ""),
                        "attendees": [a.get("email", "") for a in attendees_raw[:10]],
                        "status": ev.get("status", ""),
                    }
                )

            return {"status": "success", "events": events, "total": len(events)}

        except Exception as e:
            from langgraph.errors import GraphInterrupt

            if isinstance(e, GraphInterrupt):
                raise
            logger.error("Error searching calendar events: %s", e, exc_info=True)
            return {
                "status": "error",
                "message": "Failed to search calendar events. Please try again.",
            }

    return search_calendar_events
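A hedged example of invoking the search tool for a one-month window (`ainvoke` is the standard async entry point for a LangChain `@tool`; the session values are placeholders and the call runs inside an async function):

# Illustrative: list April 2026 events from the primary calendar.
search_tool = create_search_calendar_events_tool(
    db_session=session, search_space_id=42, user_id="user-123"
)
result = await search_tool.ainvoke(
    {"start_date": "2026-04-01", "end_date": "2026-04-30", "max_results": 10}
)
for ev in result.get("events", []):
    print(ev["summary"], ev["start"], ev["end"])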
@@ -0,0 +1,356 @@
import asyncio
import logging
from datetime import datetime
from typing import Any

from google.oauth2.credentials import Credentials
from googleapiclient.discovery import build
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession

from app.agents.new_chat.tools.hitl import request_approval
from app.services.google_calendar import GoogleCalendarToolMetadataService

logger = logging.getLogger(__name__)


def _is_date_only(value: str) -> bool:
    """Return True when *value* looks like a bare date (YYYY-MM-DD) with no time component."""
    return len(value) <= 10 and "T" not in value


def _build_time_body(value: str, context: dict[str, Any] | Any) -> dict[str, str]:
    """Build a Google Calendar start/end body using ``date`` for all-day
    events and ``dateTime`` for timed events."""
    if _is_date_only(value):
        return {"date": value}
    tz = context.get("timezone", "UTC") if isinstance(context, dict) else "UTC"
    return {"dateTime": value, "timeZone": tz}
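
# Illustrative behavior of the two helpers above (example values, not tests):
#   _is_date_only("2026-04-01")           -> True   (all-day event)
#   _is_date_only("2026-04-01T09:00:00")  -> False  (timed event)
#   _build_time_body("2026-04-01", {"timezone": "UTC"})
#       -> {"date": "2026-04-01"}
#   _build_time_body("2026-04-01T09:00:00", {"timezone": "Europe/Berlin"})
#       -> {"dateTime": "2026-04-01T09:00:00", "timeZone": "Europe/Berlin"}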

def create_update_calendar_event_tool(
    db_session: AsyncSession | None = None,
    search_space_id: int | None = None,
    user_id: str | None = None,
):
    @tool
    async def update_calendar_event(
        event_title_or_id: str,
        new_summary: str | None = None,
        new_start_datetime: str | None = None,
        new_end_datetime: str | None = None,
        new_description: str | None = None,
        new_location: str | None = None,
        new_attendees: list[str] | None = None,
    ) -> dict[str, Any]:
        """Update an existing Google Calendar event.

        Use when the user asks to modify, reschedule, or change a calendar event.

        Args:
            event_title_or_id: The exact title or event ID of the event to update.
            new_summary: New event title (if changing).
            new_start_datetime: New start time in ISO 8601 format (if rescheduling).
            new_end_datetime: New end time in ISO 8601 format (if rescheduling).
            new_description: New event description (if changing).
            new_location: New event location (if changing).
            new_attendees: New list of attendee email addresses (if changing).

        Returns:
            Dictionary with:
            - status: "success", "rejected", "not_found", "auth_error", or "error"
            - event_id: Google Calendar event ID (if success)
            - html_link: URL to open the event (if success)
            - message: Result message

        IMPORTANT:
        - If status is "rejected", the user explicitly declined. Respond with a brief
          acknowledgment and do NOT retry or suggest alternatives.
        - If status is "not_found", relay the exact message to the user and ask them
          to verify the event name or check if it has been indexed.

        Examples:
        - "Reschedule the team standup to 3pm"
        - "Change the location of my dentist appointment"
        """
        logger.info(f"update_calendar_event called: event_ref='{event_title_or_id}'")

        if db_session is None or search_space_id is None or user_id is None:
            return {
                "status": "error",
                "message": "Google Calendar tool not properly configured. Please contact support.",
            }

        try:
            metadata_service = GoogleCalendarToolMetadataService(db_session)
            context = await metadata_service.get_update_context(
                search_space_id, user_id, event_title_or_id
            )

            if "error" in context:
                error_msg = context["error"]
                if "not found" in error_msg.lower():
                    logger.warning(f"Event not found: {error_msg}")
                    return {"status": "not_found", "message": error_msg}
                logger.error(f"Failed to fetch update context: {error_msg}")
                return {"status": "error", "message": error_msg}

            if context.get("auth_expired"):
                logger.warning("Google Calendar account has expired authentication")
                return {
                    "status": "auth_error",
                    "message": "The Google Calendar account for this event needs re-authentication. Please re-authenticate in your connector settings.",
                    "connector_type": "google_calendar",
                }

            event = context["event"]
            event_id = event["event_id"]
            document_id = event.get("document_id")
            connector_id_from_context = context["account"]["id"]

            if not event_id:
                return {
                    "status": "error",
                    "message": "Event ID is missing from the indexed document. Please re-index the event and try again.",
                }

            logger.info(
                f"Requesting approval for updating calendar event: '{event_title_or_id}' (event_id={event_id})"
            )
            result = request_approval(
                action_type="google_calendar_event_update",
                tool_name="update_calendar_event",
                params={
                    "event_id": event_id,
                    "document_id": document_id,
                    "connector_id": connector_id_from_context,
                    "new_summary": new_summary,
                    "new_start_datetime": new_start_datetime,
                    "new_end_datetime": new_end_datetime,
                    "new_description": new_description,
                    "new_location": new_location,
                    "new_attendees": new_attendees,
                },
                context=context,
            )

            if result.rejected:
                return {
                    "status": "rejected",
                    "message": "User declined. The event was not updated. Do not ask again or suggest alternatives.",
                }

            final_event_id = result.params.get("event_id", event_id)
            final_connector_id = result.params.get(
                "connector_id", connector_id_from_context
            )
            final_new_summary = result.params.get("new_summary", new_summary)
            final_new_start_datetime = result.params.get(
                "new_start_datetime", new_start_datetime
            )
            final_new_end_datetime = result.params.get(
                "new_end_datetime", new_end_datetime
            )
            final_new_description = result.params.get(
                "new_description", new_description
            )
            final_new_location = result.params.get("new_location", new_location)
            final_new_attendees = result.params.get("new_attendees", new_attendees)

            if not final_connector_id:
                return {
                    "status": "error",
                    "message": "No connector found for this event.",
                }

            from sqlalchemy.future import select

            from app.db import SearchSourceConnector, SearchSourceConnectorType

            _calendar_types = [
                SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR,
                SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
            ]

            result = await db_session.execute(
                select(SearchSourceConnector).filter(
                    SearchSourceConnector.id == final_connector_id,
                    SearchSourceConnector.search_space_id == search_space_id,
                    SearchSourceConnector.user_id == user_id,
                    SearchSourceConnector.connector_type.in_(_calendar_types),
                )
            )
            connector = result.scalars().first()
            if not connector:
                return {
                    "status": "error",
                    "message": "Selected Google Calendar connector is invalid or has been disconnected.",
                }

            actual_connector_id = connector.id

            logger.info(
                f"Updating calendar event: event_id='{final_event_id}', connector={actual_connector_id}"
            )

            if (
                connector.connector_type
                == SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
            ):
                from app.utils.google_credentials import build_composio_credentials

                cca_id = connector.config.get("composio_connected_account_id")
                if cca_id:
                    creds = build_composio_credentials(cca_id)
                else:
                    return {
                        "status": "error",
                        "message": "Composio connected account ID not found for this connector.",
                    }
            else:
                config_data = dict(connector.config)

                from app.config import config as app_config
                from app.utils.oauth_security import TokenEncryption

                token_encrypted = config_data.get("_token_encrypted", False)
                if token_encrypted and app_config.SECRET_KEY:
                    token_encryption = TokenEncryption(app_config.SECRET_KEY)
                    for key in ("token", "refresh_token", "client_secret"):
                        if config_data.get(key):
                            config_data[key] = token_encryption.decrypt_token(
                                config_data[key]
                            )

                exp = config_data.get("expiry", "")
                if exp:
                    exp = exp.replace("Z", "")

                creds = Credentials(
                    token=config_data.get("token"),
                    refresh_token=config_data.get("refresh_token"),
                    token_uri=config_data.get("token_uri"),
                    client_id=config_data.get("client_id"),
                    client_secret=config_data.get("client_secret"),
                    scopes=config_data.get("scopes", []),
                    expiry=datetime.fromisoformat(exp) if exp else None,
                )

            service = await asyncio.get_event_loop().run_in_executor(
                None, lambda: build("calendar", "v3", credentials=creds)
            )

            update_body: dict[str, Any] = {}
            if final_new_summary is not None:
                update_body["summary"] = final_new_summary
            if final_new_start_datetime is not None:
                update_body["start"] = _build_time_body(
                    final_new_start_datetime, context
                )
            if final_new_end_datetime is not None:
                update_body["end"] = _build_time_body(final_new_end_datetime, context)
            if final_new_description is not None:
                update_body["description"] = final_new_description
            if final_new_location is not None:
                update_body["location"] = final_new_location
            if final_new_attendees is not None:
                update_body["attendees"] = [
                    {"email": e.strip()} for e in final_new_attendees if e.strip()
                ]

            if not update_body:
                return {
                    "status": "error",
                    "message": "No changes specified. Please provide at least one field to update.",
                }

            try:
                updated = await asyncio.get_event_loop().run_in_executor(
                    None,
                    lambda: (
                        service.events()
                        .patch(
                            calendarId="primary",
                            eventId=final_event_id,
                            body=update_body,
                        )
                        .execute()
                    ),
                )
            except Exception as api_err:
                from googleapiclient.errors import HttpError

                if isinstance(api_err, HttpError) and api_err.resp.status == 403:
                    logger.warning(
                        f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
                    )
                    try:
                        from sqlalchemy.orm.attributes import flag_modified

                        _res = await db_session.execute(
                            select(SearchSourceConnector).where(
                                SearchSourceConnector.id == actual_connector_id
                            )
                        )
                        _conn = _res.scalar_one_or_none()
                        if _conn and not _conn.config.get("auth_expired"):
                            _conn.config = {**_conn.config, "auth_expired": True}
                            flag_modified(_conn, "config")
                            await db_session.commit()
                    except Exception:
                        logger.warning(
                            "Failed to persist auth_expired for connector %s",
                            actual_connector_id,
                            exc_info=True,
                        )
                    return {
                        "status": "insufficient_permissions",
                        "connector_id": actual_connector_id,
                        "message": "This Google Calendar account needs additional permissions. Please re-authenticate in connector settings.",
                    }
                raise

            logger.info(f"Calendar event updated: event_id={final_event_id}")

            kb_message_suffix = ""
            if document_id is not None:
                try:
                    from app.services.google_calendar import GoogleCalendarKBSyncService

                    kb_service = GoogleCalendarKBSyncService(db_session)
                    kb_result = await kb_service.sync_after_update(
                        document_id=document_id,
                        event_id=final_event_id,
                        connector_id=actual_connector_id,
                        search_space_id=search_space_id,
                        user_id=user_id,
                    )
                    if kb_result["status"] == "success":
                        kb_message_suffix = (
                            " Your knowledge base has also been updated."
                        )
                    else:
                        kb_message_suffix = " The knowledge base will be updated in the next scheduled sync."
                except Exception as kb_err:
                    logger.warning(f"KB sync after update failed: {kb_err}")
                    kb_message_suffix = " The knowledge base will be updated in the next scheduled sync."

            return {
                "status": "success",
                "event_id": final_event_id,
                "html_link": updated.get("htmlLink"),
                "message": f"Successfully updated the calendar event.{kb_message_suffix}",
            }

        except Exception as e:
            from langgraph.errors import GraphInterrupt

            if isinstance(e, GraphInterrupt):
                raise

            logger.error(f"Error updating calendar event: {e}", exc_info=True)
            return {
                "status": "error",
                "message": "Something went wrong while updating the event. Please try again.",
            }

    return update_calendar_event
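Because `update_body` is built only from non-None parameters and sent via `events().patch`, fields the caller omits keep their current values on the event. A hedged sketch of a reschedule-only call (placeholder session values, inside an async function):

# Illustrative: reschedule only; title, location, and attendees are untouched.
update_tool = create_update_calendar_event_tool(
    db_session=session, search_space_id=42, user_id="user-123"
)
result = await update_tool.ainvoke(
    {
        "event_title_or_id": "Team standup",
        "new_start_datetime": "2026-04-01T15:00:00",
        "new_end_datetime": "2026-04-01T15:15:00",
    }
)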
@@ -0,0 +1,55 @@
"""`clickup` route: ``SubAgent`` spec for deepagents."""

from __future__ import annotations

from collections.abc import Sequence
from typing import Any

from deepagents import SubAgent
from langchain_core.language_models import BaseChatModel

from app.agents.multi_agent_chat.subagents.shared.md_file_reader import (
    read_md_file,
)
from app.agents.multi_agent_chat.subagents.shared.permissions import (
    ToolsPermissions,
    merge_tools_permissions,
    middleware_gated_interrupt_on,
)
from app.agents.multi_agent_chat.subagents.shared.subagent_builder import (
    pack_subagent,
)

from .tools.index import load_tools

NAME = "clickup"


def build_subagent(
    *,
    dependencies: dict[str, Any],
    model: BaseChatModel | None = None,
    extra_middleware: Sequence[Any] | None = None,
    extra_tools_bucket: ToolsPermissions | None = None,
) -> SubAgent:
    buckets = load_tools(dependencies=dependencies)
    merged_tools_bucket = merge_tools_permissions(buckets, extra_tools_bucket)
    tools = [
        row["tool"]
        for row in (*merged_tools_bucket["allow"], *merged_tools_bucket["ask"])
        if row.get("tool") is not None
    ]
    interrupt_on = middleware_gated_interrupt_on(merged_tools_bucket)
    description = read_md_file(__package__, "description").strip()
    if not description:
        description = "Handles clickup tasks for this workspace."
    system_prompt = read_md_file(__package__, "system_prompt").strip()
    return pack_subagent(
        name=NAME,
        description=description,
        system_prompt=system_prompt,
        tools=tools,
        interrupt_on=interrupt_on,
        model=model,
        extra_middleware=extra_middleware,
    )
@@ -0,0 +1 @@
Use for ClickUp task management: find tasks/lists, update task fields, and track execution progress.
@@ -0,0 +1,45 @@
You are the ClickUp MCP operations sub-agent.
You receive delegated instructions from a supervisor agent and return structured results for supervisor synthesis.

<goal>
Execute ClickUp MCP operations accurately using only runtime-provided tools.
</goal>

<available_tools>
- Runtime-provided ClickUp MCP tools for task/workspace search and mutation.
</available_tools>

<tool_policy>
- Follow tool descriptions exactly.
- If the task/workspace target is ambiguous or missing, return `status=blocked` with the required disambiguation fields.
- Never claim mutation success without tool confirmation.
</tool_policy>

<out_of_scope>
- Do not execute non-ClickUp tasks.
</out_of_scope>

<safety>
- Never claim update/create success without tool confirmation.
</safety>

<failure_policy>
- On tool failure, return `status=error` with a concise recovery `next_step`.
- On unresolved ambiguity, return `status=blocked` with candidate options.
</failure_policy>

<output_contract>
Return **only** one JSON object (no markdown/prose):
{
  "status": "success" | "partial" | "blocked" | "error",
  "action_summary": string,
  "evidence": { "items": object | null },
  "next_step": string | null,
  "missing_fields": string[] | null,
  "assumptions": string[] | null
}
Rules:
- `status=success` -> `next_step=null`, `missing_fields=null`.
- `status=partial|blocked|error` -> `next_step` must be non-null.
- `status=blocked` due to missing required inputs -> `missing_fields` must be non-null.
</output_contract>
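As an illustration of the contract above, a blocked disambiguation reply might look like this (all values invented for the example):

{
  "status": "blocked",
  "action_summary": "Two ClickUp tasks match 'fix login bug'; need disambiguation.",
  "evidence": { "items": { "candidates": ["task_8a1", "task_8a2"] } },
  "next_step": "Ask the user which matching task to act on.",
  "missing_fields": ["task_id"],
  "assumptions": null
}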
@@ -0,0 +1,3 @@
"""ClickUp route: native tool factories are empty; MCP supplies tools when configured."""

__all__: list[str] = []
@@ -0,0 +1,12 @@
from __future__ import annotations

from typing import Any

from app.agents.multi_agent_chat.subagents.shared.permissions import (
    ToolsPermissions,
)


def load_tools(*, dependencies: dict[str, Any] | None = None, **kwargs: Any) -> ToolsPermissions:
    _ = {**(dependencies or {}), **kwargs}
    return {"allow": [], "ask": []}
@@ -0,0 +1,55 @@
"""`confluence` route: ``SubAgent`` spec for deepagents."""

from __future__ import annotations

from collections.abc import Sequence
from typing import Any

from deepagents import SubAgent
from langchain_core.language_models import BaseChatModel

from app.agents.multi_agent_chat.subagents.shared.md_file_reader import (
    read_md_file,
)
from app.agents.multi_agent_chat.subagents.shared.permissions import (
    ToolsPermissions,
    merge_tools_permissions,
    middleware_gated_interrupt_on,
)
from app.agents.multi_agent_chat.subagents.shared.subagent_builder import (
    pack_subagent,
)

from .tools.index import load_tools

NAME = "confluence"


def build_subagent(
    *,
    dependencies: dict[str, Any],
    model: BaseChatModel | None = None,
    extra_middleware: Sequence[Any] | None = None,
    extra_tools_bucket: ToolsPermissions | None = None,
) -> SubAgent:
    buckets = load_tools(dependencies=dependencies)
    merged_tools_bucket = merge_tools_permissions(buckets, extra_tools_bucket)
    tools = [
        row["tool"]
        for row in (*merged_tools_bucket["allow"], *merged_tools_bucket["ask"])
        if row.get("tool") is not None
    ]
    interrupt_on = middleware_gated_interrupt_on(merged_tools_bucket)
    description = read_md_file(__package__, "description").strip()
    if not description:
        description = "Handles confluence tasks for this workspace."
    system_prompt = read_md_file(__package__, "system_prompt").strip()
    return pack_subagent(
        name=NAME,
        description=description,
        system_prompt=system_prompt,
        tools=tools,
        interrupt_on=interrupt_on,
        model=model,
        extra_middleware=extra_middleware,
    )
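Each route exposes the same `build_subagent` surface, so the registry can construct them uniformly. A hedged sketch of direct construction (the dependency keys mirror what this route's `tools/index.py` expects; values are placeholders):

# Illustrative: building the confluence SubAgent spec directly.
subagent = build_subagent(
    dependencies={
        "db_session": session,
        "search_space_id": 42,
        "user_id": "user-123",
        "connector_id": None,  # optional; tools resolve a connector when omitted
    }
)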
@@ -0,0 +1 @@
Use for Confluence knowledge pages: search/read existing pages, create new pages, and update page content.
@@ -0,0 +1,55 @@
You are the Confluence operations sub-agent.
You receive delegated instructions from a supervisor agent and return structured results for supervisor synthesis.

<goal>
Execute Confluence page operations accurately in the connected space.
</goal>

<available_tools>
- `create_confluence_page`
- `update_confluence_page`
- `delete_confluence_page`
</available_tools>

<tool_policy>
- Use only tools in `<available_tools>`.
- Verify the target page and intended mutation before update/delete.
- If the target page is ambiguous, return `status=blocked` with candidate options for supervisor disambiguation.
- Never invent page IDs, titles, or mutation outcomes.
</tool_policy>

<out_of_scope>
- Do not perform non-Confluence tasks.
</out_of_scope>

<safety>
- Never claim page mutation success without tool confirmation.
- If a destructive action appears already completed in this session, do not repeat it; return the prior evidence with an `assumptions` note.
</safety>

<failure_policy>
- On tool failure, return `status=error` with a concise retry/recovery `next_step`.
- On unresolved page ambiguity, return `status=blocked` with candidates.
</failure_policy>

<output_contract>
Return **only** one JSON object (no markdown/prose):
{
  "status": "success" | "partial" | "blocked" | "error",
  "action_summary": string,
  "evidence": {
    "page_id": string | null,
    "page_title": string | null,
    "matched_candidates": [
      { "page_id": string, "page_title": string | null }
    ] | null
  },
  "next_step": string | null,
  "missing_fields": string[] | null,
  "assumptions": string[] | null
}
Rules:
- `status=success` -> `next_step=null`, `missing_fields=null`.
- `status=partial|blocked|error` -> `next_step` must be non-null.
- `status=blocked` due to missing required inputs -> `missing_fields` must be non-null.
</output_contract>
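For instance, an ambiguous-page reply under this contract might look like the following (all values invented for the example):

{
  "status": "blocked",
  "action_summary": "Multiple pages titled 'Release Notes' found; awaiting disambiguation.",
  "evidence": {
    "page_id": null,
    "page_title": null,
    "matched_candidates": [
      { "page_id": "98304", "page_title": "Release Notes (Q1)" },
      { "page_id": "98310", "page_title": "Release Notes (Q2)" }
    ]
  },
  "next_step": "Ask the user which page to update.",
  "missing_fields": ["page_id"],
  "assumptions": null
}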
@@ -0,0 +1,11 @@
"""Confluence tools for creating, updating, and deleting pages."""

from .create_page import create_create_confluence_page_tool
from .delete_page import create_delete_confluence_page_tool
from .update_page import create_update_confluence_page_tool

__all__ = [
    "create_create_confluence_page_tool",
    "create_delete_confluence_page_tool",
    "create_update_confluence_page_tool",
]
@@ -0,0 +1,211 @@
import logging
from typing import Any

from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm.attributes import flag_modified

from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.confluence_history import ConfluenceHistoryConnector
from app.services.confluence import ConfluenceToolMetadataService

logger = logging.getLogger(__name__)


def create_create_confluence_page_tool(
    db_session: AsyncSession | None = None,
    search_space_id: int | None = None,
    user_id: str | None = None,
    connector_id: int | None = None,
):
    @tool
    async def create_confluence_page(
        title: str,
        content: str | None = None,
        space_id: str | None = None,
    ) -> dict[str, Any]:
        """Create a new page in Confluence.

        Use this tool when the user explicitly asks to create a new Confluence page.

        Args:
            title: Title of the page.
            content: Optional HTML/storage format content for the page body.
            space_id: Optional Confluence space ID to create the page in.

        Returns:
            Dictionary with status, page_id, and message.

        IMPORTANT:
        - If status is "rejected", do NOT retry.
        - If status is "insufficient_permissions", inform the user to re-authenticate.
        """
        logger.info(f"create_confluence_page called: title='{title}'")

        if db_session is None or search_space_id is None or user_id is None:
            return {
                "status": "error",
                "message": "Confluence tool not properly configured.",
            }

        try:
            metadata_service = ConfluenceToolMetadataService(db_session)
            context = await metadata_service.get_creation_context(
                search_space_id, user_id
            )

            if "error" in context:
                return {"status": "error", "message": context["error"]}

            accounts = context.get("accounts", [])
            if accounts and all(a.get("auth_expired") for a in accounts):
                return {
                    "status": "auth_error",
                    "message": "All connected Confluence accounts need re-authentication.",
                    "connector_type": "confluence",
                }

            result = request_approval(
                action_type="confluence_page_creation",
                tool_name="create_confluence_page",
                params={
                    "title": title,
                    "content": content,
                    "space_id": space_id,
                    "connector_id": connector_id,
                },
                context=context,
            )

            if result.rejected:
                return {
                    "status": "rejected",
                    "message": "User declined. Do not retry or suggest alternatives.",
                }

            final_title = result.params.get("title", title)
            final_content = result.params.get("content", content) or ""
            final_space_id = result.params.get("space_id", space_id)
            final_connector_id = result.params.get("connector_id", connector_id)

            if not final_title or not final_title.strip():
                return {"status": "error", "message": "Page title cannot be empty."}
            if not final_space_id:
                return {"status": "error", "message": "A space must be selected."}

            from sqlalchemy.future import select

            from app.db import SearchSourceConnector, SearchSourceConnectorType

            actual_connector_id = final_connector_id
            if actual_connector_id is None:
                result = await db_session.execute(
                    select(SearchSourceConnector).filter(
                        SearchSourceConnector.search_space_id == search_space_id,
                        SearchSourceConnector.user_id == user_id,
                        SearchSourceConnector.connector_type
                        == SearchSourceConnectorType.CONFLUENCE_CONNECTOR,
                    )
                )
                connector = result.scalars().first()
                if not connector:
                    return {
                        "status": "error",
                        "message": "No Confluence connector found.",
                    }
                actual_connector_id = connector.id
            else:
                result = await db_session.execute(
                    select(SearchSourceConnector).filter(
                        SearchSourceConnector.id == actual_connector_id,
                        SearchSourceConnector.search_space_id == search_space_id,
                        SearchSourceConnector.user_id == user_id,
                        SearchSourceConnector.connector_type
                        == SearchSourceConnectorType.CONFLUENCE_CONNECTOR,
                    )
                )
                connector = result.scalars().first()
                if not connector:
                    return {
                        "status": "error",
                        "message": "Selected Confluence connector is invalid.",
                    }

            try:
                client = ConfluenceHistoryConnector(
                    session=db_session, connector_id=actual_connector_id
                )
                api_result = await client.create_page(
                    space_id=final_space_id,
                    title=final_title,
                    body=final_content,
                )
                await client.close()
            except Exception as api_err:
                if (
                    "http 403" in str(api_err).lower()
                    or "status code 403" in str(api_err).lower()
                ):
                    try:
                        _conn = connector
                        _conn.config = {**_conn.config, "auth_expired": True}
                        flag_modified(_conn, "config")
                        await db_session.commit()
                    except Exception:
                        pass
                    return {
                        "status": "insufficient_permissions",
                        "connector_id": actual_connector_id,
                        "message": "This Confluence account needs additional permissions. Please re-authenticate in connector settings.",
                    }
                raise

            page_id = str(api_result.get("id", ""))
            page_links = (
                api_result.get("_links", {}) if isinstance(api_result, dict) else {}
            )
            page_url = ""
            if page_links.get("base") and page_links.get("webui"):
                page_url = f"{page_links['base']}{page_links['webui']}"

            kb_message_suffix = ""
            try:
                from app.services.confluence import ConfluenceKBSyncService

                kb_service = ConfluenceKBSyncService(db_session)
                kb_result = await kb_service.sync_after_create(
                    page_id=page_id,
                    page_title=final_title,
                    space_id=final_space_id,
                    body_content=final_content,
                    connector_id=actual_connector_id,
                    search_space_id=search_space_id,
                    user_id=user_id,
                )
                if kb_result["status"] == "success":
                    kb_message_suffix = " Your knowledge base has also been updated."
                else:
                    kb_message_suffix = " This page will be added to your knowledge base in the next scheduled sync."
            except Exception as kb_err:
                logger.warning(f"KB sync after create failed: {kb_err}")
                kb_message_suffix = " This page will be added to your knowledge base in the next scheduled sync."

            return {
                "status": "success",
                "page_id": page_id,
                "page_url": page_url,
                "message": f"Confluence page '{final_title}' created successfully.{kb_message_suffix}",
            }

        except Exception as e:
            from langgraph.errors import GraphInterrupt

            if isinstance(e, GraphInterrupt):
                raise
            logger.error(f"Error creating Confluence page: {e}", exc_info=True)
            return {
                "status": "error",
                "message": "Something went wrong while creating the page.",
            }

    return create_confluence_page
@@ -0,0 +1,189 @@
|
|||
import logging
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
from app.agents.new_chat.tools.hitl import request_approval
|
||||
from app.connectors.confluence_history import ConfluenceHistoryConnector
|
||||
from app.services.confluence import ConfluenceToolMetadataService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_delete_confluence_page_tool(
|
||||
db_session: AsyncSession | None = None,
|
||||
    search_space_id: int | None = None,
    user_id: str | None = None,
    connector_id: int | None = None,
):
    @tool
    async def delete_confluence_page(
        page_title_or_id: str,
        delete_from_kb: bool = False,
    ) -> dict[str, Any]:
        """Delete a Confluence page.

        Use this tool when the user asks to delete or remove a Confluence page.

        Args:
            page_title_or_id: The page title or ID to identify the page.
            delete_from_kb: Whether to also remove from the knowledge base.

        Returns:
            Dictionary with status, message, and deleted_from_kb.

        IMPORTANT:
        - If status is "rejected", do NOT retry.
        - If status is "not_found", relay the message to the user.
        - If status is "insufficient_permissions", inform user to re-authenticate.
        """
        logger.info(
            f"delete_confluence_page called: page_title_or_id='{page_title_or_id}'"
        )

        if db_session is None or search_space_id is None or user_id is None:
            return {
                "status": "error",
                "message": "Confluence tool not properly configured.",
            }

        try:
            metadata_service = ConfluenceToolMetadataService(db_session)
            context = await metadata_service.get_deletion_context(
                search_space_id, user_id, page_title_or_id
            )

            if "error" in context:
                error_msg = context["error"]
                if context.get("auth_expired"):
                    return {
                        "status": "auth_error",
                        "message": error_msg,
                        "connector_id": context.get("connector_id"),
                        "connector_type": "confluence",
                    }
                if "not found" in error_msg.lower():
                    return {"status": "not_found", "message": error_msg}
                return {"status": "error", "message": error_msg}

            page_data = context["page"]
            page_id = page_data["page_id"]
            page_title = page_data.get("page_title", "")
            document_id = page_data["document_id"]
            connector_id_from_context = context.get("account", {}).get("id")

            result = request_approval(
                action_type="confluence_page_deletion",
                tool_name="delete_confluence_page",
                params={
                    "page_id": page_id,
                    "connector_id": connector_id_from_context,
                    "delete_from_kb": delete_from_kb,
                },
                context=context,
            )

            if result.rejected:
                return {
                    "status": "rejected",
                    "message": "User declined. Do not retry or suggest alternatives.",
                }

            final_page_id = result.params.get("page_id", page_id)
            final_connector_id = result.params.get(
                "connector_id", connector_id_from_context
            )
            final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb)

            from sqlalchemy.future import select

            from app.db import SearchSourceConnector, SearchSourceConnectorType

            if not final_connector_id:
                return {
                    "status": "error",
                    "message": "No connector found for this page.",
                }

            result = await db_session.execute(
                select(SearchSourceConnector).filter(
                    SearchSourceConnector.id == final_connector_id,
                    SearchSourceConnector.search_space_id == search_space_id,
                    SearchSourceConnector.user_id == user_id,
                    SearchSourceConnector.connector_type
                    == SearchSourceConnectorType.CONFLUENCE_CONNECTOR,
                )
            )
            connector = result.scalars().first()
            if not connector:
                return {
                    "status": "error",
                    "message": "Selected Confluence connector is invalid.",
                }

            try:
                client = ConfluenceHistoryConnector(
                    session=db_session, connector_id=final_connector_id
                )
                await client.delete_page(final_page_id)
                await client.close()
            except Exception as api_err:
                if (
                    "http 403" in str(api_err).lower()
                    or "status code 403" in str(api_err).lower()
                ):
                    try:
                        connector.config = {**connector.config, "auth_expired": True}
                        flag_modified(connector, "config")
                        await db_session.commit()
                    except Exception:
                        pass
                    return {
                        "status": "insufficient_permissions",
                        "connector_id": final_connector_id,
                        "message": "This Confluence account needs additional permissions. Please re-authenticate in connector settings.",
                    }
                raise

            deleted_from_kb = False
            if final_delete_from_kb and document_id:
                try:
                    from app.db import Document

                    doc_result = await db_session.execute(
                        select(Document).filter(Document.id == document_id)
                    )
                    document = doc_result.scalars().first()
                    if document:
                        await db_session.delete(document)
                        await db_session.commit()
                        deleted_from_kb = True
                except Exception as e:
                    logger.error(f"Failed to delete document from KB: {e}")
                    await db_session.rollback()

            message = f"Confluence page '{page_title}' deleted successfully."
            if deleted_from_kb:
                message += " Also removed from the knowledge base."

            return {
                "status": "success",
                "page_id": final_page_id,
                "deleted_from_kb": deleted_from_kb,
                "message": message,
            }

        except Exception as e:
            from langgraph.errors import GraphInterrupt

            if isinstance(e, GraphInterrupt):
                raise
            logger.error(f"Error deleting Confluence page: {e}", exc_info=True)
            return {
                "status": "error",
                "message": "Something went wrong while deleting the page.",
            }

    return delete_confluence_page
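Worth noting before the next file: every mutating tool in this commit repeats the same human-in-the-loop shape, where `request_approval` surfaces the proposed params and the returned object may carry user-edited values that override them. A self-contained sketch of that assumed contract (the `ApprovalResult` class and `apply_approval` helper are hypothetical; only `.rejected` and `.params` are taken from the code above):

# Hedged sketch of the HITL contract these tools rely on; not the real hitl module.
from dataclasses import dataclass, field
from typing import Any


@dataclass
class ApprovalResult:
    rejected: bool
    params: dict[str, Any] = field(default_factory=dict)


def apply_approval(result: ApprovalResult, proposed: dict[str, Any]) -> dict[str, Any] | None:
    """Return the effective params after approval, or None when the user declined."""
    if result.rejected:
        return None  # callers must stop here and never retry
    # User-edited values from the approval UI take precedence over proposals.
    return {key: result.params.get(key, value) for key, value in proposed.items()}


assert apply_approval(
    ApprovalResult(rejected=False, params={"page_id": "99"}),
    {"page_id": "12", "delete_from_kb": False},
) == {"page_id": "99", "delete_from_kb": False}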
@ -0,0 +1,32 @@
from __future__ import annotations

from typing import Any

from app.agents.multi_agent_chat.subagents.shared.permissions import (
    ToolsPermissions,
)

from .create_page import create_create_confluence_page_tool
from .delete_page import create_delete_confluence_page_tool
from .update_page import create_update_confluence_page_tool


def load_tools(*, dependencies: dict[str, Any] | None = None, **kwargs: Any) -> ToolsPermissions:
    resolved_dependencies = {**(dependencies or {}), **kwargs}
    session_dependencies = {
        "db_session": resolved_dependencies["db_session"],
        "search_space_id": resolved_dependencies["search_space_id"],
        "user_id": resolved_dependencies["user_id"],
        "connector_id": resolved_dependencies.get("connector_id"),
    }
    create = create_create_confluence_page_tool(**session_dependencies)
    update = create_update_confluence_page_tool(**session_dependencies)
    delete = create_delete_confluence_page_tool(**session_dependencies)
    return {
        "allow": [],
        "ask": [
            {"name": getattr(create, "name", "") or "", "tool": create},
            {"name": getattr(update, "name", "") or "", "tool": update},
            {"name": getattr(delete, "name", "") or "", "tool": delete},
        ],
    }
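The `allow`/`ask` split returned above is the permission model every `load_tools` in this commit uses: `allow` entries run uninterrupted, `ask` entries are gated behind approval middleware. A minimal sketch of the assumed bucket shape and a plausible merge, not the real `shared.permissions` implementation:

# Hedged sketch of the assumed ToolsPermissions shape (hypothetical reimplementation).
from typing import Any, TypedDict


class ToolsPermissions(TypedDict):
    allow: list[dict[str, Any]]  # run without interruption
    ask: list[dict[str, Any]]    # pause for human approval first


def merge_tools_permissions(*buckets: ToolsPermissions | None) -> ToolsPermissions:
    merged: ToolsPermissions = {"allow": [], "ask": []}
    for bucket in buckets:
        if bucket is None:
            continue
        merged["allow"].extend(bucket["allow"])
        merged["ask"].extend(bucket["ask"])
    return merged


base: ToolsPermissions = {"allow": [], "ask": [{"name": "delete_confluence_page", "tool": None}]}
extra: ToolsPermissions = {"allow": [{"name": "search", "tool": None}], "ask": []}
assert merge_tools_permissions(base, extra)["allow"][0]["name"] == "search"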
@ -0,0 +1,218 @@
import logging
from typing import Any

from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm.attributes import flag_modified

from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.confluence_history import ConfluenceHistoryConnector
from app.services.confluence import ConfluenceToolMetadataService

logger = logging.getLogger(__name__)


def create_update_confluence_page_tool(
    db_session: AsyncSession | None = None,
    search_space_id: int | None = None,
    user_id: str | None = None,
    connector_id: int | None = None,
):
    @tool
    async def update_confluence_page(
        page_title_or_id: str,
        new_title: str | None = None,
        new_content: str | None = None,
    ) -> dict[str, Any]:
        """Update an existing Confluence page.

        Use this tool when the user asks to modify or edit a Confluence page.

        Args:
            page_title_or_id: The page title or ID to identify the page.
            new_title: Optional new title for the page.
            new_content: Optional new HTML/storage format content.

        Returns:
            Dictionary with status and message.

        IMPORTANT:
        - If status is "rejected", do NOT retry.
        - If status is "not_found", relay the message to the user.
        - If status is "insufficient_permissions", inform user to re-authenticate.
        """
        logger.info(
            f"update_confluence_page called: page_title_or_id='{page_title_or_id}'"
        )

        if db_session is None or search_space_id is None or user_id is None:
            return {
                "status": "error",
                "message": "Confluence tool not properly configured.",
            }

        try:
            metadata_service = ConfluenceToolMetadataService(db_session)
            context = await metadata_service.get_update_context(
                search_space_id, user_id, page_title_or_id
            )

            if "error" in context:
                error_msg = context["error"]
                if context.get("auth_expired"):
                    return {
                        "status": "auth_error",
                        "message": error_msg,
                        "connector_id": context.get("connector_id"),
                        "connector_type": "confluence",
                    }
                if "not found" in error_msg.lower():
                    return {"status": "not_found", "message": error_msg}
                return {"status": "error", "message": error_msg}

            page_data = context["page"]
            page_id = page_data["page_id"]
            current_title = page_data["page_title"]
            current_body = page_data.get("body", "")
            current_version = page_data.get("version", 1)
            document_id = page_data.get("document_id")
            connector_id_from_context = context.get("account", {}).get("id")

            result = request_approval(
                action_type="confluence_page_update",
                tool_name="update_confluence_page",
                params={
                    "page_id": page_id,
                    "document_id": document_id,
                    "new_title": new_title,
                    "new_content": new_content,
                    "version": current_version,
                    "connector_id": connector_id_from_context,
                },
                context=context,
            )

            if result.rejected:
                return {
                    "status": "rejected",
                    "message": "User declined. Do not retry or suggest alternatives.",
                }

            final_page_id = result.params.get("page_id", page_id)
            final_title = result.params.get("new_title", new_title) or current_title
            final_content = result.params.get("new_content", new_content)
            if final_content is None:
                final_content = current_body
            final_version = result.params.get("version", current_version)
            final_connector_id = result.params.get(
                "connector_id", connector_id_from_context
            )
            final_document_id = result.params.get("document_id", document_id)

            from sqlalchemy.future import select

            from app.db import SearchSourceConnector, SearchSourceConnectorType

            if not final_connector_id:
                return {
                    "status": "error",
                    "message": "No connector found for this page.",
                }

            result = await db_session.execute(
                select(SearchSourceConnector).filter(
                    SearchSourceConnector.id == final_connector_id,
                    SearchSourceConnector.search_space_id == search_space_id,
                    SearchSourceConnector.user_id == user_id,
                    SearchSourceConnector.connector_type
                    == SearchSourceConnectorType.CONFLUENCE_CONNECTOR,
                )
            )
            connector = result.scalars().first()
            if not connector:
                return {
                    "status": "error",
                    "message": "Selected Confluence connector is invalid.",
                }

            try:
                client = ConfluenceHistoryConnector(
                    session=db_session, connector_id=final_connector_id
                )
                api_result = await client.update_page(
                    page_id=final_page_id,
                    title=final_title,
                    body=final_content,
                    version_number=final_version + 1,
                )
                await client.close()
            except Exception as api_err:
                if (
                    "http 403" in str(api_err).lower()
                    or "status code 403" in str(api_err).lower()
                ):
                    try:
                        connector.config = {**connector.config, "auth_expired": True}
                        flag_modified(connector, "config")
                        await db_session.commit()
                    except Exception:
                        pass
                    return {
                        "status": "insufficient_permissions",
                        "connector_id": final_connector_id,
                        "message": "This Confluence account needs additional permissions. Please re-authenticate in connector settings.",
                    }
                raise

            page_links = (
                api_result.get("_links", {}) if isinstance(api_result, dict) else {}
            )
            page_url = ""
            if page_links.get("base") and page_links.get("webui"):
                page_url = f"{page_links['base']}{page_links['webui']}"

            kb_message_suffix = ""
            if final_document_id:
                try:
                    from app.services.confluence import ConfluenceKBSyncService

                    kb_service = ConfluenceKBSyncService(db_session)
                    kb_result = await kb_service.sync_after_update(
                        document_id=final_document_id,
                        page_id=final_page_id,
                        user_id=user_id,
                        search_space_id=search_space_id,
                    )
                    if kb_result["status"] == "success":
                        kb_message_suffix = (
                            " Your knowledge base has also been updated."
                        )
                    else:
                        kb_message_suffix = (
                            " The knowledge base will be updated in the next sync."
                        )
                except Exception as kb_err:
                    logger.warning(f"KB sync after update failed: {kb_err}")
                    kb_message_suffix = (
                        " The knowledge base will be updated in the next sync."
                    )

            return {
                "status": "success",
                "page_id": final_page_id,
                "page_url": page_url,
                "message": f"Confluence page '{final_title}' updated successfully.{kb_message_suffix}",
            }

        except Exception as e:
            from langgraph.errors import GraphInterrupt

            if isinstance(e, GraphInterrupt):
                raise
            logger.error(f"Error updating Confluence page: {e}", exc_info=True)
            return {
                "status": "error",
                "message": "Something went wrong while updating the page.",
            }

    return update_confluence_page
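One detail worth calling out in the update path: `version_number=final_version + 1` follows Confluence's optimistic-concurrency rule, where an update must name the next version number or the API rejects it as a conflict. A tiny illustration of the invariant (hypothetical helper, no Confluence client involved):

def next_version(current_version: int) -> int:
    """Confluence accepts an update only if it advances the stored version by one."""
    return current_version + 1


assert next_version(7) == 8  # a stale writer still holding version 6 would send 7 and conflict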
@ -0,0 +1,55 @@
"""`discord` route: ``SubAgent`` spec for deepagents."""

from __future__ import annotations

from collections.abc import Sequence
from typing import Any

from deepagents import SubAgent
from langchain_core.language_models import BaseChatModel

from app.agents.multi_agent_chat.subagents.shared.md_file_reader import (
    read_md_file,
)
from app.agents.multi_agent_chat.subagents.shared.permissions import (
    ToolsPermissions,
    merge_tools_permissions,
    middleware_gated_interrupt_on,
)
from app.agents.multi_agent_chat.subagents.shared.subagent_builder import (
    pack_subagent,
)

from .tools.index import load_tools

NAME = "discord"


def build_subagent(
    *,
    dependencies: dict[str, Any],
    model: BaseChatModel | None = None,
    extra_middleware: Sequence[Any] | None = None,
    extra_tools_bucket: ToolsPermissions | None = None,
) -> SubAgent:
    buckets = load_tools(dependencies=dependencies)
    merged_tools_bucket = merge_tools_permissions(buckets, extra_tools_bucket)
    tools = [
        row["tool"]
        for row in (*merged_tools_bucket["allow"], *merged_tools_bucket["ask"])
        if row.get("tool") is not None
    ]
    interrupt_on = middleware_gated_interrupt_on(merged_tools_bucket)
    description = read_md_file(__package__, "description").strip()
    if not description:
        description = "Handles discord tasks for this workspace."
    system_prompt = read_md_file(__package__, "system_prompt").strip()
    return pack_subagent(
        name=NAME,
        description=description,
        system_prompt=system_prompt,
        tools=tools,
        interrupt_on=interrupt_on,
        model=model,
        extra_middleware=extra_middleware,
    )
@ -0,0 +1 @@
Use for Discord communication: read channel/thread messages, gather context, and send replies.
@ -0,0 +1,56 @@
You are the Discord operations sub-agent.
You receive delegated instructions from a supervisor agent and return structured results for supervisor synthesis.

<goal>
Execute Discord reads and sends accurately in the connected server/workspace.
</goal>

<available_tools>
- `list_discord_channels`
- `read_discord_messages`
- `send_discord_message`
</available_tools>

<tool_policy>
- Use only tools in `<available_tools>`.
- Resolve channel/thread targets before reads/sends.
- If target is ambiguous, return `status=blocked` with candidate channels/threads.
- Never invent message content, sender identity, timestamps, or delivery results.
</tool_policy>

<out_of_scope>
- Do not perform non-Discord tasks.
</out_of_scope>

<safety>
- Before send, verify destination and message intent match delegated instructions.
- Never claim send success without tool confirmation.
</safety>

<failure_policy>
- On tool failure, return `status=error` with concise recovery `next_step`.
- On unresolved destination ambiguity, return `status=blocked` with candidate options.
</failure_policy>

<output_contract>
Return **only** one JSON object (no markdown/prose):
{
  "status": "success" | "partial" | "blocked" | "error",
  "action_summary": string,
  "evidence": {
    "channel_id": string | null,
    "thread_id": string | null,
    "message_id": string | null,
    "matched_candidates": [
      { "channel_id": string, "thread_id": string | null, "label": string | null }
    ] | null
  },
  "next_step": string | null,
  "missing_fields": string[] | null,
  "assumptions": string[] | null
}
Rules:
- `status=success` -> `next_step=null`, `missing_fields=null`.
- `status=partial|blocked|error` -> `next_step` must be non-null.
- `status=blocked` due to missing required inputs -> `missing_fields` must be non-null.
</output_contract>
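For illustration only (not part of the contract), a hypothetical response satisfying these rules when two channels match an ambiguous target; all IDs and labels are made up:

{
  "status": "blocked",
  "action_summary": "Two channels match 'general'; need disambiguation before sending.",
  "evidence": {
    "channel_id": null,
    "thread_id": null,
    "message_id": null,
    "matched_candidates": [
      { "channel_id": "111", "thread_id": null, "label": "#general" },
      { "channel_id": "222", "thread_id": null, "label": "#general-dev" }
    ]
  },
  "next_step": "Ask the supervisor which channel to use.",
  "missing_fields": ["channel_id"],
  "assumptions": null
}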
@ -0,0 +1,15 @@
from app.agents.new_chat.tools.discord.list_channels import (
    create_list_discord_channels_tool,
)
from app.agents.new_chat.tools.discord.read_messages import (
    create_read_discord_messages_tool,
)
from app.agents.new_chat.tools.discord.send_message import (
    create_send_discord_message_tool,
)

__all__ = [
    "create_list_discord_channels_tool",
    "create_read_discord_messages_tool",
    "create_send_discord_message_tool",
]
@ -0,0 +1,43 @@
"""Builds Discord REST API auth headers for connector-backed tools."""

from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select

from app.config import config
from app.db import SearchSourceConnector, SearchSourceConnectorType
from app.utils.oauth_security import TokenEncryption

DISCORD_API = "https://discord.com/api/v10"


async def get_discord_connector(
    db_session: AsyncSession,
    search_space_id: int,
    user_id: str,
) -> SearchSourceConnector | None:
    result = await db_session.execute(
        select(SearchSourceConnector).filter(
            SearchSourceConnector.search_space_id == search_space_id,
            SearchSourceConnector.user_id == user_id,
            SearchSourceConnector.connector_type
            == SearchSourceConnectorType.DISCORD_CONNECTOR,
        )
    )
    return result.scalars().first()


def get_bot_token(connector: SearchSourceConnector) -> str:
    """Extract and decrypt the bot token from connector config."""
    cfg = dict(connector.config)
    if cfg.get("_token_encrypted") and config.SECRET_KEY:
        enc = TokenEncryption(config.SECRET_KEY)
        if cfg.get("bot_token"):
            cfg["bot_token"] = enc.decrypt_token(cfg["bot_token"])
    token = cfg.get("bot_token")
    if not token:
        raise ValueError("Discord bot token not found in connector config.")
    return token


def get_guild_id(connector: SearchSourceConnector) -> str | None:
    return connector.config.get("guild_id")
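A hedged sketch of how these helpers are meant to compose — essentially the preamble each tool below repeats (the endpoint is a real Discord v10 route; the wrapper function itself is hypothetical):

# Hypothetical composition of the helpers above; the tools that follow inline this flow.
import httpx


async def fetch_guild_channels(db_session, search_space_id: int, user_id: str) -> list[dict]:
    connector = await get_discord_connector(db_session, search_space_id, user_id)
    if connector is None:
        raise RuntimeError("No Discord connector configured.")
    token = get_bot_token(connector)    # decrypted bot token
    guild_id = get_guild_id(connector)
    async with httpx.AsyncClient() as client:
        # Bot tokens use the "Bot <token>" scheme, not "Bearer".
        resp = await client.get(
            f"{DISCORD_API}/guilds/{guild_id}/channels",
            headers={"Authorization": f"Bot {token}"},
            timeout=15.0,
        )
    resp.raise_for_status()
    return resp.json()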
@ -0,0 +1,30 @@
from __future__ import annotations

from typing import Any

from app.agents.multi_agent_chat.subagents.shared.permissions import (
    ToolsPermissions,
)

from .list_channels import create_list_discord_channels_tool
from .read_messages import create_read_discord_messages_tool
from .send_message import create_send_discord_message_tool


def load_tools(*, dependencies: dict[str, Any] | None = None, **kwargs: Any) -> ToolsPermissions:
    d = {**(dependencies or {}), **kwargs}
    common = {
        "db_session": d["db_session"],
        "search_space_id": d["search_space_id"],
        "user_id": d["user_id"],
    }
    list_ch = create_list_discord_channels_tool(**common)
    read_msg = create_read_discord_messages_tool(**common)
    send = create_send_discord_message_tool(**common)
    return {
        "allow": [
            {"name": getattr(list_ch, "name", "") or "", "tool": list_ch},
            {"name": getattr(read_msg, "name", "") or "", "tool": read_msg},
        ],
        "ask": [{"name": getattr(send, "name", "") or "", "tool": send}],
    }
@ -0,0 +1,87 @@
import logging
from typing import Any

import httpx
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession

from ._auth import DISCORD_API, get_bot_token, get_discord_connector, get_guild_id

logger = logging.getLogger(__name__)


def create_list_discord_channels_tool(
    db_session: AsyncSession | None = None,
    search_space_id: int | None = None,
    user_id: str | None = None,
):
    @tool
    async def list_discord_channels() -> dict[str, Any]:
        """List text channels in the connected Discord server.

        Returns:
            Dictionary with status and a list of channels (id, name).
        """
        if db_session is None or search_space_id is None or user_id is None:
            return {
                "status": "error",
                "message": "Discord tool not properly configured.",
            }

        try:
            connector = await get_discord_connector(
                db_session, search_space_id, user_id
            )
            if not connector:
                return {"status": "error", "message": "No Discord connector found."}

            guild_id = get_guild_id(connector)
            if not guild_id:
                return {
                    "status": "error",
                    "message": "No guild ID in Discord connector config.",
                }

            token = get_bot_token(connector)

            async with httpx.AsyncClient() as client:
                resp = await client.get(
                    f"{DISCORD_API}/guilds/{guild_id}/channels",
                    headers={"Authorization": f"Bot {token}"},
                    timeout=15.0,
                )

            if resp.status_code == 401:
                return {
                    "status": "auth_error",
                    "message": "Discord bot token is invalid.",
                    "connector_type": "discord",
                }
            if resp.status_code != 200:
                return {
                    "status": "error",
                    "message": f"Discord API error: {resp.status_code}",
                }

            # Type 0 = text channel
            channels = [
                {"id": ch["id"], "name": ch["name"]}
                for ch in resp.json()
                if ch.get("type") == 0
            ]
            return {
                "status": "success",
                "guild_id": guild_id,
                "channels": channels,
                "total": len(channels),
            }

        except Exception as e:
            from langgraph.errors import GraphInterrupt

            if isinstance(e, GraphInterrupt):
                raise
            logger.error("Error listing Discord channels: %s", e, exc_info=True)
            return {"status": "error", "message": "Failed to list Discord channels."}

    return list_discord_channels
@ -0,0 +1,100 @@
import logging
from typing import Any

import httpx
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession

from ._auth import DISCORD_API, get_bot_token, get_discord_connector

logger = logging.getLogger(__name__)


def create_read_discord_messages_tool(
    db_session: AsyncSession | None = None,
    search_space_id: int | None = None,
    user_id: str | None = None,
):
    @tool
    async def read_discord_messages(
        channel_id: str,
        limit: int = 25,
    ) -> dict[str, Any]:
        """Read recent messages from a Discord text channel.

        Args:
            channel_id: The Discord channel ID (from list_discord_channels).
            limit: Number of messages to fetch (default 25, max 50).

        Returns:
            Dictionary with status and a list of messages including
            id, author, content, timestamp.
        """
        if db_session is None or search_space_id is None or user_id is None:
            return {
                "status": "error",
                "message": "Discord tool not properly configured.",
            }

        limit = min(limit, 50)

        try:
            connector = await get_discord_connector(
                db_session, search_space_id, user_id
            )
            if not connector:
                return {"status": "error", "message": "No Discord connector found."}

            token = get_bot_token(connector)

            async with httpx.AsyncClient() as client:
                resp = await client.get(
                    f"{DISCORD_API}/channels/{channel_id}/messages",
                    headers={"Authorization": f"Bot {token}"},
                    params={"limit": limit},
                    timeout=15.0,
                )

            if resp.status_code == 401:
                return {
                    "status": "auth_error",
                    "message": "Discord bot token is invalid.",
                    "connector_type": "discord",
                }
            if resp.status_code == 403:
                return {
                    "status": "error",
                    "message": "Bot lacks permission to read this channel.",
                }
            if resp.status_code != 200:
                return {
                    "status": "error",
                    "message": f"Discord API error: {resp.status_code}",
                }

            messages = [
                {
                    "id": m["id"],
                    "author": m.get("author", {}).get("username", "Unknown"),
                    "content": m.get("content", ""),
                    "timestamp": m.get("timestamp", ""),
                }
                for m in resp.json()
            ]

            return {
                "status": "success",
                "channel_id": channel_id,
                "messages": messages,
                "total": len(messages),
            }

        except Exception as e:
            from langgraph.errors import GraphInterrupt

            if isinstance(e, GraphInterrupt):
                raise
            logger.error("Error reading Discord messages: %s", e, exc_info=True)
            return {"status": "error", "message": "Failed to read Discord messages."}

    return read_discord_messages
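One behavioral detail the code leaves implicit: Discord's messages endpoint returns results newest-first, so a consumer that wants chronological order must reverse the list. A trivial illustration with made-up data:

# Discord returns messages newest-first; reverse for chronological rendering.
messages = [
    {"id": "3", "timestamp": "2024-01-03T00:00:00+00:00"},
    {"id": "2", "timestamp": "2024-01-02T00:00:00+00:00"},
]
chronological = list(reversed(messages))
assert [m["id"] for m in chronological] == ["2", "3"]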
@ -0,0 +1,117 @@
import logging
from typing import Any

import httpx
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession

from app.agents.new_chat.tools.hitl import request_approval

from ._auth import DISCORD_API, get_bot_token, get_discord_connector

logger = logging.getLogger(__name__)


def create_send_discord_message_tool(
    db_session: AsyncSession | None = None,
    search_space_id: int | None = None,
    user_id: str | None = None,
):
    @tool
    async def send_discord_message(
        channel_id: str,
        content: str,
    ) -> dict[str, Any]:
        """Send a message to a Discord text channel.

        Args:
            channel_id: The Discord channel ID (from list_discord_channels).
            content: The message text (max 2000 characters).

        Returns:
            Dictionary with status, message_id on success.

        IMPORTANT:
        - If status is "rejected", the user explicitly declined. Do NOT retry.
        """
        if db_session is None or search_space_id is None or user_id is None:
            return {
                "status": "error",
                "message": "Discord tool not properly configured.",
            }

        if len(content) > 2000:
            return {
                "status": "error",
                "message": "Message exceeds Discord's 2000-character limit.",
            }

        try:
            connector = await get_discord_connector(
                db_session, search_space_id, user_id
            )
            if not connector:
                return {"status": "error", "message": "No Discord connector found."}

            result = request_approval(
                action_type="discord_send_message",
                tool_name="send_discord_message",
                params={"channel_id": channel_id, "content": content},
                context={"connector_id": connector.id},
            )

            if result.rejected:
                return {
                    "status": "rejected",
                    "message": "User declined. Message was not sent.",
                }

            final_content = result.params.get("content", content)
            final_channel = result.params.get("channel_id", channel_id)

            token = get_bot_token(connector)

            async with httpx.AsyncClient() as client:
                resp = await client.post(
                    f"{DISCORD_API}/channels/{final_channel}/messages",
                    headers={
                        "Authorization": f"Bot {token}",
                        "Content-Type": "application/json",
                    },
                    json={"content": final_content},
                    timeout=15.0,
                )

            if resp.status_code == 401:
                return {
                    "status": "auth_error",
                    "message": "Discord bot token is invalid.",
                    "connector_type": "discord",
                }
            if resp.status_code == 403:
                return {
                    "status": "error",
                    "message": "Bot lacks permission to send messages in this channel.",
                }
            if resp.status_code not in (200, 201):
                return {
                    "status": "error",
                    "message": f"Discord API error: {resp.status_code}",
                }

            msg_data = resp.json()
            return {
                "status": "success",
                "message_id": msg_data.get("id"),
                "message": f"Message sent to channel {final_channel}.",
            }

        except Exception as e:
            from langgraph.errors import GraphInterrupt

            if isinstance(e, GraphInterrupt):
                raise
            logger.error("Error sending Discord message: %s", e, exc_info=True)
            return {"status": "error", "message": "Failed to send Discord message."}

    return send_discord_message
@ -0,0 +1,55 @@
"""`dropbox` route: ``SubAgent`` spec for deepagents."""

from __future__ import annotations

from collections.abc import Sequence
from typing import Any

from deepagents import SubAgent
from langchain_core.language_models import BaseChatModel

from app.agents.multi_agent_chat.subagents.shared.md_file_reader import (
    read_md_file,
)
from app.agents.multi_agent_chat.subagents.shared.permissions import (
    ToolsPermissions,
    merge_tools_permissions,
    middleware_gated_interrupt_on,
)
from app.agents.multi_agent_chat.subagents.shared.subagent_builder import (
    pack_subagent,
)

from .tools.index import load_tools

NAME = "dropbox"


def build_subagent(
    *,
    dependencies: dict[str, Any],
    model: BaseChatModel | None = None,
    extra_middleware: Sequence[Any] | None = None,
    extra_tools_bucket: ToolsPermissions | None = None,
) -> SubAgent:
    buckets = load_tools(dependencies=dependencies)
    merged_tools_bucket = merge_tools_permissions(buckets, extra_tools_bucket)
    tools = [
        row["tool"]
        for row in (*merged_tools_bucket["allow"], *merged_tools_bucket["ask"])
        if row.get("tool") is not None
    ]
    interrupt_on = middleware_gated_interrupt_on(merged_tools_bucket)
    description = read_md_file(__package__, "description").strip()
    if not description:
        description = "Handles dropbox tasks for this workspace."
    system_prompt = read_md_file(__package__, "system_prompt").strip()
    return pack_subagent(
        name=NAME,
        description=description,
        system_prompt=system_prompt,
        tools=tools,
        interrupt_on=interrupt_on,
        model=model,
        extra_middleware=extra_middleware,
    )
@ -0,0 +1 @@
Use for Dropbox file storage tasks: browse folders, read files, and manage Dropbox file content.
@ -0,0 +1,52 @@
You are the Dropbox operations sub-agent.
You receive delegated instructions from a supervisor agent and return structured results for supervisor synthesis.

<goal>
Execute Dropbox file create/delete actions accurately in the connected account.
</goal>

<available_tools>
- `create_dropbox_file`
- `delete_dropbox_file`
</available_tools>

<tool_policy>
- Use only tools in `<available_tools>`.
- Ensure target path/file identity is explicit before mutate actions.
- If target is ambiguous, return `status=blocked` with candidate paths.
- Never invent file IDs/paths or mutation outcomes.
</tool_policy>

<out_of_scope>
- Do not perform non-Dropbox tasks.
</out_of_scope>

<safety>
- Never claim file mutation success without tool confirmation.
</safety>

<failure_policy>
- On tool failure, return `status=error` with concise recovery `next_step`.
- On target ambiguity, return `status=blocked` with candidate paths.
</failure_policy>

<output_contract>
Return **only** one JSON object (no markdown/prose):
{
  "status": "success" | "partial" | "blocked" | "error",
  "action_summary": string,
  "evidence": {
    "file_path": string | null,
    "file_id": string | null,
    "operation": "create" | "delete" | null,
    "matched_candidates": string[] | null
  },
  "next_step": string | null,
  "missing_fields": string[] | null,
  "assumptions": string[] | null
}
Rules:
- `status=success` -> `next_step=null`, `missing_fields=null`.
- `status=partial|blocked|error` -> `next_step` must be non-null.
- `status=blocked` due to missing required inputs -> `missing_fields` must be non-null.
</output_contract>
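For illustration only (not part of the contract), a hypothetical response for a completed create; the path and ID are made up:

{
  "status": "success",
  "action_summary": "Created 'notes.paper' in the Dropbox root folder.",
  "evidence": {
    "file_path": "/notes.paper",
    "file_id": "id:abc123",
    "operation": "create",
    "matched_candidates": null
  },
  "next_step": null,
  "missing_fields": null,
  "assumptions": null
}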
@ -0,0 +1,11 @@
from app.agents.new_chat.tools.dropbox.create_file import (
    create_create_dropbox_file_tool,
)
from app.agents.new_chat.tools.dropbox.trash_file import (
    create_delete_dropbox_file_tool,
)

__all__ = [
    "create_create_dropbox_file_tool",
    "create_delete_dropbox_file_tool",
]
@ -0,0 +1,275 @@
import logging
import os
import tempfile
from pathlib import Path
from typing import Any, Literal

from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select

from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.dropbox.client import DropboxClient
from app.db import SearchSourceConnector, SearchSourceConnectorType

logger = logging.getLogger(__name__)

DOCX_MIME = "application/vnd.openxmlformats-officedocument.wordprocessingml.document"

_FILE_TYPE_LABELS = {
    "paper": "Dropbox Paper (.paper)",
    "docx": "Word Document (.docx)",
}

_SUPPORTED_TYPES = [
    {"value": "paper", "label": "Dropbox Paper (.paper)"},
    {"value": "docx", "label": "Word Document (.docx)"},
]


def _ensure_extension(name: str, file_type: str) -> str:
    """Strip any existing extension and append the correct one."""
    stem = Path(name).stem
    ext = ".paper" if file_type == "paper" else ".docx"
    return f"{stem}{ext}"


def _markdown_to_docx(markdown_text: str) -> bytes:
    """Convert a markdown string to DOCX bytes using pypandoc."""
    import pypandoc

    fd, tmp_path = tempfile.mkstemp(suffix=".docx")
    os.close(fd)
    try:
        pypandoc.convert_text(
            markdown_text,
            "docx",
            format="gfm",
            extra_args=["--standalone"],
            outputfile=tmp_path,
        )
        with open(tmp_path, "rb") as f:
            return f.read()
    finally:
        os.unlink(tmp_path)


def create_create_dropbox_file_tool(
    db_session: AsyncSession | None = None,
    search_space_id: int | None = None,
    user_id: str | None = None,
):
    @tool
    async def create_dropbox_file(
        name: str,
        file_type: Literal["paper", "docx"] = "paper",
        content: str | None = None,
    ) -> dict[str, Any]:
        """Create a new document in Dropbox.

        Use this tool when the user explicitly asks to create a new document
        in Dropbox. The user MUST specify a topic before you call this tool.

        Args:
            name: The document title (without extension).
            file_type: Either "paper" (Dropbox Paper, default) or "docx" (Word document).
            content: Optional initial content as markdown.

        Returns:
            Dictionary with status, file_id, name, web_url, and message.
        """
        logger.info(
            f"create_dropbox_file called: name='{name}', file_type='{file_type}'"
        )

        if db_session is None or search_space_id is None or user_id is None:
            return {
                "status": "error",
                "message": "Dropbox tool not properly configured.",
            }

        try:
            result = await db_session.execute(
                select(SearchSourceConnector).filter(
                    SearchSourceConnector.search_space_id == search_space_id,
                    SearchSourceConnector.user_id == user_id,
                    SearchSourceConnector.connector_type
                    == SearchSourceConnectorType.DROPBOX_CONNECTOR,
                )
            )
            connectors = result.scalars().all()

            if not connectors:
                return {
                    "status": "error",
                    "message": "No Dropbox connector found. Please connect Dropbox in your workspace settings.",
                }

            accounts = []
            for c in connectors:
                cfg = c.config or {}
                accounts.append(
                    {
                        "id": c.id,
                        "name": c.name,
                        "user_email": cfg.get("user_email"),
                        "auth_expired": cfg.get("auth_expired", False),
                    }
                )

            if all(a.get("auth_expired") for a in accounts):
                return {
                    "status": "auth_error",
                    "message": "All connected Dropbox accounts need re-authentication.",
                    "connector_type": "dropbox",
                }

            parent_folders: dict[int, list[dict[str, str]]] = {}
            for acc in accounts:
                cid = acc["id"]
                if acc.get("auth_expired"):
                    parent_folders[cid] = []
                    continue
                try:
                    client = DropboxClient(session=db_session, connector_id=cid)
                    items, err = await client.list_folder("")
                    if err:
                        logger.warning(
                            "Failed to list folders for connector %s: %s", cid, err
                        )
                        parent_folders[cid] = []
                    else:
                        parent_folders[cid] = [
                            {
                                "folder_path": item.get("path_lower", ""),
                                "name": item["name"],
                            }
                            for item in items
                            if item.get(".tag") == "folder" and item.get("name")
                        ]
                except Exception:
                    logger.warning(
                        "Error fetching folders for connector %s", cid, exc_info=True
                    )
                    parent_folders[cid] = []

            context: dict[str, Any] = {
                "accounts": accounts,
                "parent_folders": parent_folders,
                "supported_types": _SUPPORTED_TYPES,
            }

            result = request_approval(
                action_type="dropbox_file_creation",
                tool_name="create_dropbox_file",
                params={
                    "name": name,
                    "file_type": file_type,
                    "content": content,
                    "connector_id": None,
                    "parent_folder_path": None,
                },
                context=context,
            )

            if result.rejected:
                return {
                    "status": "rejected",
                    "message": "User declined. Do not retry or suggest alternatives.",
                }

            final_name = result.params.get("name", name)
            final_file_type = result.params.get("file_type", file_type)
            final_content = result.params.get("content", content)
            final_connector_id = result.params.get("connector_id")
            final_parent_folder_path = result.params.get("parent_folder_path")

            if not final_name or not final_name.strip():
                return {"status": "error", "message": "File name cannot be empty."}

            final_name = _ensure_extension(final_name, final_file_type)

            if final_connector_id is not None:
                result = await db_session.execute(
                    select(SearchSourceConnector).filter(
                        SearchSourceConnector.id == final_connector_id,
                        SearchSourceConnector.search_space_id == search_space_id,
                        SearchSourceConnector.user_id == user_id,
                        SearchSourceConnector.connector_type
                        == SearchSourceConnectorType.DROPBOX_CONNECTOR,
                    )
                )
                connector = result.scalars().first()
            else:
                connector = connectors[0]

            if not connector:
                return {
                    "status": "error",
                    "message": "Selected Dropbox connector is invalid.",
                }

            client = DropboxClient(session=db_session, connector_id=connector.id)

            parent_path = final_parent_folder_path or ""
            file_path = (
                f"{parent_path}/{final_name}" if parent_path else f"/{final_name}"
            )

            if final_file_type == "paper":
                created = await client.create_paper_doc(file_path, final_content or "")
                file_id = created.get("file_id", "")
                web_url = created.get("url", "")
            else:
                docx_bytes = _markdown_to_docx(final_content or "")
                created = await client.upload_file(
                    file_path, docx_bytes, mode="add", autorename=True
                )
                file_id = created.get("id", "")
                web_url = ""

            logger.info(f"Dropbox file created: id={file_id}, name={final_name}")

            kb_message_suffix = ""
            try:
                from app.services.dropbox import DropboxKBSyncService

                kb_service = DropboxKBSyncService(db_session)
                kb_result = await kb_service.sync_after_create(
                    file_id=file_id,
                    file_name=final_name,
                    file_path=file_path,
                    web_url=web_url,
                    content=final_content,
                    connector_id=connector.id,
                    search_space_id=search_space_id,
                    user_id=user_id,
                )
                if kb_result["status"] == "success":
                    kb_message_suffix = " Your knowledge base has also been updated."
                else:
                    kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."
            except Exception as kb_err:
                logger.warning(f"KB sync after create failed: {kb_err}")
                kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."

            return {
                "status": "success",
                "file_id": file_id,
                "name": final_name,
                "web_url": web_url,
                "message": f"Successfully created '{final_name}' in Dropbox.{kb_message_suffix}",
            }

        except Exception as e:
            from langgraph.errors import GraphInterrupt

            if isinstance(e, GraphInterrupt):
                raise
            logger.error(f"Error creating Dropbox file: {e}", exc_info=True)
            return {
                "status": "error",
                "message": "Something went wrong while creating the file. Please try again.",
            }

    return create_dropbox_file
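A deployment note on `_markdown_to_docx`: pypandoc is a thin wrapper around the external `pandoc` binary, so the conversion above fails at runtime unless pandoc is installed on the host (or fetched once via `pypandoc.download_pandoc()`). A minimal standalone check of the same conversion path, with a hypothetical scratch location:

# Standalone sanity check for the pypandoc -> DOCX path used above.
# Assumes the pandoc binary is available; pypandoc only shells out to it.
import pypandoc


def markdown_roundtrip_ok() -> bool:
    docx_path = "/tmp/sanity.docx"  # hypothetical scratch location
    pypandoc.convert_text(
        "# Hello\n\nGenerated from GitHub-flavored markdown.",
        "docx",
        format="gfm",
        extra_args=["--standalone"],
        outputfile=docx_path,
    )
    with open(docx_path, "rb") as f:
        return f.read(2) == b"PK"  # DOCX files are ZIP containers


print(markdown_roundtrip_ok())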
@ -0,0 +1,28 @@
from __future__ import annotations

from typing import Any

from app.agents.multi_agent_chat.subagents.shared.permissions import (
    ToolsPermissions,
)

from .create_file import create_create_dropbox_file_tool
from .trash_file import create_delete_dropbox_file_tool


def load_tools(*, dependencies: dict[str, Any] | None = None, **kwargs: Any) -> ToolsPermissions:
    d = {**(dependencies or {}), **kwargs}
    common = {
        "db_session": d["db_session"],
        "search_space_id": d["search_space_id"],
        "user_id": d["user_id"],
    }
    create = create_create_dropbox_file_tool(**common)
    delete = create_delete_dropbox_file_tool(**common)
    return {
        "allow": [],
        "ask": [
            {"name": getattr(create, "name", "") or "", "tool": create},
            {"name": getattr(delete, "name", "") or "", "tool": delete},
        ],
    }
@ -0,0 +1,277 @@
import logging
from typing import Any

from langchain_core.tools import tool
from sqlalchemy import String, and_, cast, func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select

from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.dropbox.client import DropboxClient
from app.db import (
    Document,
    DocumentType,
    SearchSourceConnector,
    SearchSourceConnectorType,
)

logger = logging.getLogger(__name__)


def create_delete_dropbox_file_tool(
    db_session: AsyncSession | None = None,
    search_space_id: int | None = None,
    user_id: str | None = None,
):
    @tool
    async def delete_dropbox_file(
        file_name: str,
        delete_from_kb: bool = False,
    ) -> dict[str, Any]:
        """Delete a file from Dropbox.

        Use this tool when the user explicitly asks to delete, remove, or trash
        a file in Dropbox.

        Args:
            file_name: The exact name of the file to delete.
            delete_from_kb: Whether to also remove the file from the knowledge base.
                Default is False.

        Returns:
            Dictionary with:
            - status: "success", "rejected", "not_found", or "error"
            - file_id: Dropbox file ID (if success)
            - deleted_from_kb: whether the document was removed from the knowledge base
            - message: Result message

        IMPORTANT:
        - If status is "rejected", the user explicitly declined. Respond with a brief
          acknowledgment and do NOT retry or suggest alternatives.
        - If status is "not_found", relay the exact message to the user and ask them
          to verify the file name or check if it has been indexed.
        """
        logger.info(
            f"delete_dropbox_file called: file_name='{file_name}', delete_from_kb={delete_from_kb}"
        )

        if db_session is None or search_space_id is None or user_id is None:
            return {
                "status": "error",
                "message": "Dropbox tool not properly configured.",
            }

        try:
            doc_result = await db_session.execute(
                select(Document)
                .join(
                    SearchSourceConnector,
                    Document.connector_id == SearchSourceConnector.id,
                )
                .filter(
                    and_(
                        Document.search_space_id == search_space_id,
                        Document.document_type == DocumentType.DROPBOX_FILE,
                        func.lower(Document.title) == func.lower(file_name),
                        SearchSourceConnector.user_id == user_id,
                    )
                )
                .order_by(Document.updated_at.desc().nullslast())
                .limit(1)
            )
            document = doc_result.scalars().first()

            if not document:
                doc_result = await db_session.execute(
                    select(Document)
                    .join(
                        SearchSourceConnector,
                        Document.connector_id == SearchSourceConnector.id,
                    )
                    .filter(
                        and_(
                            Document.search_space_id == search_space_id,
                            Document.document_type == DocumentType.DROPBOX_FILE,
                            func.lower(
                                cast(
                                    Document.document_metadata["dropbox_file_name"],
                                    String,
                                )
                            )
                            == func.lower(file_name),
                            SearchSourceConnector.user_id == user_id,
                        )
                    )
                    .order_by(Document.updated_at.desc().nullslast())
                    .limit(1)
                )
                document = doc_result.scalars().first()

            if not document:
                return {
                    "status": "not_found",
                    "message": (
                        f"File '{file_name}' not found in your indexed Dropbox files. "
                        "This could mean: (1) the file doesn't exist, (2) it hasn't been indexed yet, "
                        "or (3) the file name is different."
                    ),
                }

            if not document.connector_id:
                return {
                    "status": "error",
                    "message": "Document has no associated connector.",
                }

            meta = document.document_metadata or {}
            file_path = meta.get("dropbox_path")
            file_id = meta.get("dropbox_file_id")
            document_id = document.id

            if not file_path:
                return {
                    "status": "error",
                    "message": "File path is missing. Please re-index the file.",
                }

            conn_result = await db_session.execute(
                select(SearchSourceConnector).filter(
                    and_(
                        SearchSourceConnector.id == document.connector_id,
                        SearchSourceConnector.search_space_id == search_space_id,
                        SearchSourceConnector.user_id == user_id,
                        SearchSourceConnector.connector_type
                        == SearchSourceConnectorType.DROPBOX_CONNECTOR,
                    )
                )
            )
            connector = conn_result.scalars().first()
            if not connector:
                return {
                    "status": "error",
                    "message": "Dropbox connector not found or access denied.",
                }

            cfg = connector.config or {}
            if cfg.get("auth_expired"):
                return {
                    "status": "auth_error",
                    "message": "Dropbox account needs re-authentication. Please re-authenticate in your connector settings.",
                    "connector_type": "dropbox",
                }

            context = {
                "file": {
                    "file_id": file_id,
                    "file_path": file_path,
                    "name": file_name,
                    "document_id": document_id,
                },
                "account": {
                    "id": connector.id,
                    "name": connector.name,
                    "user_email": cfg.get("user_email"),
                },
            }

            result = request_approval(
                action_type="dropbox_file_trash",
                tool_name="delete_dropbox_file",
                params={
                    "file_path": file_path,
                    "connector_id": connector.id,
                    "delete_from_kb": delete_from_kb,
                },
                context=context,
            )

            if result.rejected:
                return {
                    "status": "rejected",
                    "message": "User declined. Do not retry or suggest alternatives.",
                }

            final_file_path = result.params.get("file_path", file_path)
            final_connector_id = result.params.get("connector_id", connector.id)
            final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb)

            if final_connector_id != connector.id:
                result = await db_session.execute(
                    select(SearchSourceConnector).filter(
                        and_(
                            SearchSourceConnector.id == final_connector_id,
                            SearchSourceConnector.search_space_id == search_space_id,
                            SearchSourceConnector.user_id == user_id,
                            SearchSourceConnector.connector_type
                            == SearchSourceConnectorType.DROPBOX_CONNECTOR,
                        )
                    )
                )
                validated_connector = result.scalars().first()
                if not validated_connector:
                    return {
                        "status": "error",
                        "message": "Selected Dropbox connector is invalid or has been disconnected.",
                    }
                actual_connector_id = validated_connector.id
            else:
                actual_connector_id = connector.id

            logger.info(
                f"Deleting Dropbox file: path='{final_file_path}', connector={actual_connector_id}"
            )

            client = DropboxClient(session=db_session, connector_id=actual_connector_id)
            await client.delete_file(final_file_path)

            logger.info(f"Dropbox file deleted: path={final_file_path}")

            trash_result: dict[str, Any] = {
                "status": "success",
                "file_id": file_id,
                "message": f"Successfully deleted '{file_name}' from Dropbox.",
            }

            deleted_from_kb = False
            if final_delete_from_kb and document_id:
                try:
                    doc_result = await db_session.execute(
                        select(Document).filter(Document.id == document_id)
                    )
                    doc = doc_result.scalars().first()
                    if doc:
                        await db_session.delete(doc)
                        await db_session.commit()
                        deleted_from_kb = True
                        logger.info(
                            f"Deleted document {document_id} from knowledge base"
                        )
                    else:
                        logger.warning(f"Document {document_id} not found in KB")
                except Exception as e:
                    logger.error(f"Failed to delete document from KB: {e}")
                    await db_session.rollback()
                    trash_result["warning"] = (
                        f"File deleted, but failed to remove from knowledge base: {e!s}"
                    )

            trash_result["deleted_from_kb"] = deleted_from_kb
            if deleted_from_kb:
                trash_result["message"] = (
                    f"{trash_result.get('message', '')} (also removed from knowledge base)"
                )

            return trash_result

        except Exception as e:
            from langgraph.errors import GraphInterrupt

            if isinstance(e, GraphInterrupt):
                raise
            logger.error(f"Error deleting Dropbox file: {e}", exc_info=True)
            return {
                "status": "error",
                "message": "Something went wrong while deleting the file. Please try again.",
            }

    return delete_dropbox_file
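The lookup above is deliberately two-phase: an exact case-insensitive match on `Document.title`, then a fallback against the `dropbox_file_name` key stored in JSON metadata — hence the `cast(..., String)`, since comparing a JSON value requires coercing it to text first. A condensed sketch of that fallback (hypothetical usage; note SQLAlchemy's PostgreSQL JSON/JSONB columns also offer `.astext` for string comparison, which avoids the quoting surprises a raw CAST can hit on some column types):

# Condensed sketch of the JSON-field fallback match; Document is the model above.
from sqlalchemy import String, cast, func, select

stmt = select(Document).filter(
    func.lower(cast(Document.document_metadata["dropbox_file_name"], String))
    == func.lower(file_name)
)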
@ -0,0 +1,55 @@
"""`gmail` route: ``SubAgent`` spec for deepagents."""

from __future__ import annotations

from collections.abc import Sequence
from typing import Any

from deepagents import SubAgent
from langchain_core.language_models import BaseChatModel

from app.agents.multi_agent_chat.subagents.shared.md_file_reader import (
    read_md_file,
)
from app.agents.multi_agent_chat.subagents.shared.permissions import (
    ToolsPermissions,
    merge_tools_permissions,
    middleware_gated_interrupt_on,
)
from app.agents.multi_agent_chat.subagents.shared.subagent_builder import (
    pack_subagent,
)

from .tools.index import load_tools

NAME = "gmail"


def build_subagent(
    *,
    dependencies: dict[str, Any],
    model: BaseChatModel | None = None,
    extra_middleware: Sequence[Any] | None = None,
    extra_tools_bucket: ToolsPermissions | None = None,
) -> SubAgent:
    buckets = load_tools(dependencies=dependencies)
    merged_tools_bucket = merge_tools_permissions(buckets, extra_tools_bucket)
    tools = [
        row["tool"]
        for row in (*merged_tools_bucket["allow"], *merged_tools_bucket["ask"])
        if row.get("tool") is not None
    ]
    interrupt_on = middleware_gated_interrupt_on(merged_tools_bucket)
    description = read_md_file(__package__, "description").strip()
    if not description:
        description = "Handles gmail tasks for this workspace."
    system_prompt = read_md_file(__package__, "system_prompt").strip()
    return pack_subagent(
        name=NAME,
        description=description,
        system_prompt=system_prompt,
        tools=tools,
        interrupt_on=interrupt_on,
        model=model,
        extra_middleware=extra_middleware,
    )
@ -0,0 +1 @@
Use for Gmail inbox actions: search/read emails, draft or update replies, send messages, and trash emails.
@ -0,0 +1,82 @@
|
|||
You are the Gmail operations sub-agent.
|
||||
You receive delegated instructions from a supervisor agent and return structured results for supervisor synthesis.
|
||||
|
||||
<goal>
|
||||
Execute Gmail operations accurately: search/read emails, prepare drafts, send, and trash.
|
||||
</goal>
|
||||
|
||||
<available_tools>
|
||||
- `search_gmail`: find candidate emails with query constraints.
|
||||
- `read_gmail_email`: read one message in full detail.
|
||||
- `create_gmail_draft`: create a new draft.
|
||||
- `update_gmail_draft`: modify an existing draft.
|
||||
- `send_gmail_email`: send an email.
|
||||
- `trash_gmail_email`: move an email to trash.
|
||||
</available_tools>
|
||||
|
||||
<tool_policy>
|
||||
- Use only tools in `<available_tools>`.
|
||||
- Build precise search queries using Gmail operators when possible (`from:`, `to:`, `subject:`, `after:`, `before:`, `has:attachment`, `is:unread`, `label:`).
|
||||
- Resolve relative dates against runtime timestamp; prefer narrower interpretation.
|
||||
- For reply requests, identify the target thread/email via search + read before drafting.
|
||||
- If required fields are missing or target selection is ambiguous, return `status=blocked` with `missing_fields` and disambiguation candidates.
|
||||
- Never invent IDs, recipients, timestamps, quoted text, or tool outcomes.
|
||||
</tool_policy>

<out_of_scope>
- Do not perform non-Gmail work.
- Filing operations not represented in `<available_tools>` (archive/label/mark-read/move-folder) are unsupported here.
</out_of_scope>

<safety>
- For send: verify draft `to`, `subject`, and `body` match delegated instructions.
- If any send-critical field was inferred, do not send; return `status=blocked` with inferred values in `assumptions`.
- For trash: ensure explicit target match before deletion.
- If a destructive action appears already completed this session, do not repeat; return prior evidence.
</safety>

<failure_policy>
- On tool failure, return `status=error` with a concise recovery `next_step`.
- If search has no strong match, return `status=blocked` with suggested tighter filters.
- If multiple strong candidates remain for risky actions, return `status=blocked` with top options.
</failure_policy>

<output_contract>
Return **only** one JSON object (no markdown/prose):

{
  "status": "success" | "partial" | "blocked" | "error",
  "action_summary": string,
  "evidence": {
    "email_id": string | null,
    "thread_id": string | null,
    "subject": string | null,
    "sender": string | null,
    "recipients": string[] | null,
    "received_at": string (ISO 8601 with timezone) | null,
    "sent_message": {
      "id": string,
      "to": string[],
      "subject": string | null,
      "sent_at": string (ISO 8601 with timezone) | null
    } | null,
    "matched_candidates": [
      {
        "email_id": string,
        "subject": string | null,
        "sender": string | null,
        "received_at": string (ISO 8601 with timezone) | null
      }
    ] | null
  },
  "next_step": string | null,
  "missing_fields": string[] | null,
  "assumptions": string[] | null
}

Rules:
- `status=success` -> `next_step=null`, `missing_fields=null`.
- `status=partial|blocked|error` -> `next_step` must be non-null.
- `status=blocked` due to missing required inputs -> `missing_fields` must be non-null.
- For blocked ambiguity, include options in `evidence.matched_candidates`.
- For trash actions, `evidence.email_id` is the trashed message.
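
Illustrative example of a blocked disambiguation response (all values hypothetical):

{
  "status": "blocked",
  "action_summary": "Two plausible targets matched 'trash the invoice email'; user confirmation needed.",
  "evidence": {
    "email_id": null,
    "thread_id": null,
    "subject": null,
    "sender": null,
    "recipients": null,
    "received_at": null,
    "sent_message": null,
    "matched_candidates": [
      {"email_id": "18c2f0a1b2c3d4e5", "subject": "Invoice #1042", "sender": "billing@acme.test", "received_at": "2024-03-02T09:15:00+00:00"},
      {"email_id": "18c2f0a1b2c3d4e6", "subject": "Invoice #1043", "sender": "billing@acme.test", "received_at": "2024-03-09T09:15:00+00:00"}
    ]
  },
  "next_step": "Ask the user which invoice email to trash.",
  "missing_fields": null,
  "assumptions": null
}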
</output_contract>
@@ -0,0 +1,27 @@
from app.agents.new_chat.tools.gmail.create_draft import (
    create_create_gmail_draft_tool,
)
from app.agents.new_chat.tools.gmail.read_email import (
    create_read_gmail_email_tool,
)
from app.agents.new_chat.tools.gmail.search_emails import (
    create_search_gmail_tool,
)
from app.agents.new_chat.tools.gmail.send_email import (
    create_send_gmail_email_tool,
)
from app.agents.new_chat.tools.gmail.trash_email import (
    create_trash_gmail_email_tool,
)
from app.agents.new_chat.tools.gmail.update_draft import (
    create_update_gmail_draft_tool,
)

__all__ = [
    "create_create_gmail_draft_tool",
    "create_read_gmail_email_tool",
    "create_search_gmail_tool",
    "create_send_gmail_email_tool",
    "create_trash_gmail_email_tool",
    "create_update_gmail_draft_tool",
]
@@ -0,0 +1,313 @@
import asyncio
import base64
import logging
from datetime import datetime
from email.mime.text import MIMEText
from typing import Any

from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession

from app.agents.new_chat.tools.hitl import request_approval
from app.services.gmail import GmailToolMetadataService

logger = logging.getLogger(__name__)


def create_create_gmail_draft_tool(
    db_session: AsyncSession | None = None,
    search_space_id: int | None = None,
    user_id: str | None = None,
):
    @tool
    async def create_gmail_draft(
        to: str,
        subject: str,
        body: str,
        cc: str | None = None,
        bcc: str | None = None,
    ) -> dict[str, Any]:
        """Create a draft email in Gmail.

        Use when the user asks to draft, compose, or prepare an email without
        sending it.

        Args:
            to: Recipient email address.
            subject: Email subject line.
            body: Email body content.
            cc: Optional CC recipient(s), comma-separated.
            bcc: Optional BCC recipient(s), comma-separated.

        Returns:
            Dictionary with:
            - status: "success", "rejected", or "error"
            - draft_id: Gmail draft ID (if success)
            - message: Result message

        IMPORTANT:
        - If status is "rejected", the user explicitly declined the action.
          Respond with a brief acknowledgment and do NOT retry or suggest alternatives.
        - If status is "insufficient_permissions", the connector lacks the required OAuth scope.
          Inform the user they need to re-authenticate and do NOT retry the action.

        Examples:
        - "Draft an email to alice@example.com about the meeting"
        - "Compose a reply to Bob about the project update"
        """
        logger.info(f"create_gmail_draft called: to='{to}', subject='{subject}'")

        if db_session is None or search_space_id is None or user_id is None:
            return {
                "status": "error",
                "message": "Gmail tool not properly configured. Please contact support.",
            }

        try:
            metadata_service = GmailToolMetadataService(db_session)
            context = await metadata_service.get_creation_context(
                search_space_id, user_id
            )

            if "error" in context:
                logger.error(f"Failed to fetch creation context: {context['error']}")
                return {"status": "error", "message": context["error"]}

            accounts = context.get("accounts", [])
            if accounts and all(a.get("auth_expired") for a in accounts):
                logger.warning("All Gmail accounts have expired authentication")
                return {
                    "status": "auth_error",
                    "message": "All connected Gmail accounts need re-authentication. Please re-authenticate in your connector settings.",
                    "connector_type": "gmail",
                }

            logger.info(
                f"Requesting approval for creating Gmail draft: to='{to}', subject='{subject}'"
            )
            result = request_approval(
                action_type="gmail_draft_creation",
                tool_name="create_gmail_draft",
                params={
                    "to": to,
                    "subject": subject,
                    "body": body,
                    "cc": cc,
                    "bcc": bcc,
                    "connector_id": None,
                },
                context=context,
            )

            if result.rejected:
                return {
                    "status": "rejected",
                    "message": "User declined. The draft was not created. Do not ask again or suggest alternatives.",
                }

            # The approval card may return user-edited params; prefer those values.
            final_to = result.params.get("to", to)
            final_subject = result.params.get("subject", subject)
            final_body = result.params.get("body", body)
            final_cc = result.params.get("cc", cc)
            final_bcc = result.params.get("bcc", bcc)
            final_connector_id = result.params.get("connector_id")

            from sqlalchemy.future import select

            from app.db import SearchSourceConnector, SearchSourceConnectorType

            _gmail_types = [
                SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
                SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
            ]

            if final_connector_id is not None:
                result = await db_session.execute(
                    select(SearchSourceConnector).filter(
                        SearchSourceConnector.id == final_connector_id,
                        SearchSourceConnector.search_space_id == search_space_id,
                        SearchSourceConnector.user_id == user_id,
                        SearchSourceConnector.connector_type.in_(_gmail_types),
                    )
                )
                connector = result.scalars().first()
                if not connector:
                    return {
                        "status": "error",
                        "message": "Selected Gmail connector is invalid or has been disconnected.",
                    }
                actual_connector_id = connector.id
            else:
                result = await db_session.execute(
                    select(SearchSourceConnector).filter(
                        SearchSourceConnector.search_space_id == search_space_id,
                        SearchSourceConnector.user_id == user_id,
                        SearchSourceConnector.connector_type.in_(_gmail_types),
                    )
                )
                connector = result.scalars().first()
                if not connector:
                    return {
                        "status": "error",
                        "message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
                    }
                actual_connector_id = connector.id

            logger.info(
                f"Creating Gmail draft: to='{final_to}', subject='{final_subject}', connector={actual_connector_id}"
            )

            if (
                connector.connector_type
                == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
            ):
                from app.utils.google_credentials import build_composio_credentials

                cca_id = connector.config.get("composio_connected_account_id")
                if cca_id:
                    creds = build_composio_credentials(cca_id)
                else:
                    return {
                        "status": "error",
                        "message": "Composio connected account ID not found for this Gmail connector.",
                    }
            else:
                from google.oauth2.credentials import Credentials

                from app.config import config
                from app.utils.oauth_security import TokenEncryption

                config_data = dict(connector.config)
                token_encrypted = config_data.get("_token_encrypted", False)
                if token_encrypted and config.SECRET_KEY:
                    token_encryption = TokenEncryption(config.SECRET_KEY)
                    if config_data.get("token"):
                        config_data["token"] = token_encryption.decrypt_token(
                            config_data["token"]
                        )
                    if config_data.get("refresh_token"):
                        config_data["refresh_token"] = token_encryption.decrypt_token(
                            config_data["refresh_token"]
                        )
                    if config_data.get("client_secret"):
                        config_data["client_secret"] = token_encryption.decrypt_token(
                            config_data["client_secret"]
                        )

                exp = config_data.get("expiry", "")
                if exp:
                    exp = exp.replace("Z", "")

                creds = Credentials(
                    token=config_data.get("token"),
                    refresh_token=config_data.get("refresh_token"),
                    token_uri=config_data.get("token_uri"),
                    client_id=config_data.get("client_id"),
                    client_secret=config_data.get("client_secret"),
                    scopes=config_data.get("scopes", []),
                    expiry=datetime.fromisoformat(exp) if exp else None,
                )

            from googleapiclient.discovery import build

            gmail_service = build("gmail", "v1", credentials=creds)

            # Gmail's API expects the message as a base64url-encoded RFC 2822 string.
            message = MIMEText(final_body)
            message["to"] = final_to
            message["subject"] = final_subject
            if final_cc:
                message["cc"] = final_cc
            if final_bcc:
                message["bcc"] = final_bcc
            raw = base64.urlsafe_b64encode(message.as_bytes()).decode()

            try:
                created = await asyncio.get_event_loop().run_in_executor(
                    None,
                    lambda: (
                        gmail_service.users()
                        .drafts()
                        .create(userId="me", body={"message": {"raw": raw}})
                        .execute()
                    ),
                )
            except Exception as api_err:
                from googleapiclient.errors import HttpError

                if isinstance(api_err, HttpError) and api_err.resp.status == 403:
                    logger.warning(
                        f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
                    )
                    try:
                        from sqlalchemy.orm.attributes import flag_modified

                        _res = await db_session.execute(
                            select(SearchSourceConnector).where(
                                SearchSourceConnector.id == actual_connector_id
                            )
                        )
                        _conn = _res.scalar_one_or_none()
                        if _conn and not _conn.config.get("auth_expired"):
                            _conn.config = {**_conn.config, "auth_expired": True}
                            flag_modified(_conn, "config")
                            await db_session.commit()
                    except Exception:
                        logger.warning(
                            "Failed to persist auth_expired for connector %s",
                            actual_connector_id,
                            exc_info=True,
                        )
                    return {
                        "status": "insufficient_permissions",
                        "connector_id": actual_connector_id,
                        "message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.",
                    }
                raise

            logger.info(f"Gmail draft created: id={created.get('id')}")

            kb_message_suffix = ""
            try:
                from app.services.gmail import GmailKBSyncService

                kb_service = GmailKBSyncService(db_session)
                draft_message = created.get("message", {})
                kb_result = await kb_service.sync_after_create(
                    message_id=draft_message.get("id", ""),
                    thread_id=draft_message.get("threadId", ""),
                    subject=final_subject,
                    sender="me",
                    date_str=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                    body_text=final_body,
                    connector_id=actual_connector_id,
                    search_space_id=search_space_id,
                    user_id=user_id,
                    draft_id=created.get("id"),
                )
                if kb_result["status"] == "success":
                    kb_message_suffix = " Your knowledge base has also been updated."
                else:
                    kb_message_suffix = " This draft will be added to your knowledge base in the next scheduled sync."
            except Exception as kb_err:
                logger.warning(f"KB sync after create failed: {kb_err}")
                kb_message_suffix = " This draft will be added to your knowledge base in the next scheduled sync."

            return {
                "status": "success",
                "draft_id": created.get("id"),
                "message": f"Successfully created Gmail draft with subject '{final_subject}'.{kb_message_suffix}",
            }

        except Exception as e:
            from langgraph.errors import GraphInterrupt

            if isinstance(e, GraphInterrupt):
                raise

            logger.error(f"Error creating Gmail draft: {e}", exc_info=True)
            return {
                "status": "error",
                "message": "Something went wrong while creating the draft. Please try again.",
            }

    return create_gmail_draft
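
A minimal wiring sketch for the factory above; `session` and the IDs are placeholder values, not ones taken from this repo:

# Hypothetical wiring; assumes an existing AsyncSession and request-scoped IDs.
create_draft = create_create_gmail_draft_tool(
    db_session=session,
    search_space_id=42,
    user_id="user-123",
)
result = await create_draft.ainvoke(
    {"to": "alice@example.com", "subject": "Hello", "body": "Hi Alice!"}
)
# result["status"] is "success", "rejected", "auth_error",
# "insufficient_permissions", or "error", per the docstring contract.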
@@ -0,0 +1,41 @@
from __future__ import annotations

from typing import Any

from app.agents.multi_agent_chat.subagents.shared.permissions import (
    ToolsPermissions,
)

from .create_draft import create_create_gmail_draft_tool
from .read_email import create_read_gmail_email_tool
from .search_emails import create_search_gmail_tool
from .send_email import create_send_gmail_email_tool
from .trash_email import create_trash_gmail_email_tool
from .update_draft import create_update_gmail_draft_tool


def load_tools(*, dependencies: dict[str, Any] | None = None, **kwargs: Any) -> ToolsPermissions:
    d = {**(dependencies or {}), **kwargs}
    common = {
        "db_session": d["db_session"],
        "search_space_id": d["search_space_id"],
        "user_id": d["user_id"],
    }
    search = create_search_gmail_tool(**common)
    read = create_read_gmail_email_tool(**common)
    draft = create_create_gmail_draft_tool(**common)
    send = create_send_gmail_email_tool(**common)
    trash = create_trash_gmail_email_tool(**common)
    updraft = create_update_gmail_draft_tool(**common)
    return {
        "allow": [
            {"name": getattr(search, "name", "") or "", "tool": search},
            {"name": getattr(read, "name", "") or "", "tool": read},
        ],
        "ask": [
            {"name": getattr(draft, "name", "") or "", "tool": draft},
            {"name": getattr(send, "name", "") or "", "tool": send},
            {"name": getattr(trash, "name", "") or "", "tool": trash},
            {"name": getattr(updraft, "name", "") or "", "tool": updraft},
        ],
    }
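
A sketch of consuming the returned buckets; the dependency values below are placeholders:

# Hypothetical consumption sketch: "allow" tools run freely, "ask" tools are
# gated behind human approval by the permissions middleware.
buckets = load_tools(
    dependencies={"db_session": session, "search_space_id": 42, "user_id": "u1"}
)
read_only_names = [row["name"] for row in buckets["allow"]]  # search, read
gated_names = [row["name"] for row in buckets["ask"]]  # draft, send, trash, update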
@@ -0,0 +1,100 @@
import logging
from typing import Any

from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select

from app.db import SearchSourceConnector, SearchSourceConnectorType

logger = logging.getLogger(__name__)

_GMAIL_TYPES = [
    SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
    SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
]


def create_read_gmail_email_tool(
    db_session: AsyncSession | None = None,
    search_space_id: int | None = None,
    user_id: str | None = None,
):
    @tool
    async def read_gmail_email(message_id: str) -> dict[str, Any]:
        """Read the full content of a specific Gmail email by its message ID.

        Use after search_gmail to get the complete body of an email.

        Args:
            message_id: The Gmail message ID (from search_gmail results).

        Returns:
            Dictionary with status and the full email content formatted as markdown.
        """
        if db_session is None or search_space_id is None or user_id is None:
            return {"status": "error", "message": "Gmail tool not properly configured."}

        try:
            result = await db_session.execute(
                select(SearchSourceConnector).filter(
                    SearchSourceConnector.search_space_id == search_space_id,
                    SearchSourceConnector.user_id == user_id,
                    SearchSourceConnector.connector_type.in_(_GMAIL_TYPES),
                )
            )
            connector = result.scalars().first()
            if not connector:
                return {
                    "status": "error",
                    "message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
                }

            from app.agents.new_chat.tools.gmail.search_emails import _build_credentials

            creds = _build_credentials(connector)

            from app.connectors.google_gmail_connector import GoogleGmailConnector

            gmail = GoogleGmailConnector(
                credentials=creds,
                session=db_session,
                user_id=user_id,
                connector_id=connector.id,
            )

            detail, error = await gmail.get_message_details(message_id)
            if error:
                if (
                    "re-authenticate" in error.lower()
                    or "authentication failed" in error.lower()
                ):
                    return {
                        "status": "auth_error",
                        "message": error,
                        "connector_type": "gmail",
                    }
                return {"status": "error", "message": error}

            if not detail:
                return {
                    "status": "not_found",
                    "message": f"Email with ID '{message_id}' not found.",
                }

            content = gmail.format_message_to_markdown(detail)

            return {"status": "success", "message_id": message_id, "content": content}

        except Exception as e:
            from langgraph.errors import GraphInterrupt

            if isinstance(e, GraphInterrupt):
                raise
            logger.error("Error reading Gmail email: %s", e, exc_info=True)
            return {
                "status": "error",
                "message": "Failed to read email. Please try again.",
            }

    return read_gmail_email
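
A sketch of the search-then-read flow the docstring describes; the tools are assumed built via the factories in this diff, and the inputs are placeholders:

# Hypothetical two-step flow: find candidates, then fetch one in full.
found = await search_gmail.ainvoke({"query": "from:alice@example.com is:unread"})
if found["status"] == "success" and found["emails"]:
    first_id = found["emails"][0]["message_id"]
    full = await read_gmail_email.ainvoke({"message_id": first_id})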
@@ -0,0 +1,182 @@
import logging
from datetime import datetime
from typing import Any

from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select

from app.db import SearchSourceConnector, SearchSourceConnectorType

logger = logging.getLogger(__name__)

_GMAIL_TYPES = [
    SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
    SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
]

_token_encryption_cache: object | None = None


def _get_token_encryption():
    global _token_encryption_cache
    if _token_encryption_cache is None:
        from app.config import config
        from app.utils.oauth_security import TokenEncryption

        if not config.SECRET_KEY:
            raise RuntimeError("SECRET_KEY not configured for token decryption.")
        _token_encryption_cache = TokenEncryption(config.SECRET_KEY)
    return _token_encryption_cache


def _build_credentials(connector: SearchSourceConnector):
    """Build Google OAuth Credentials from a connector's stored config.

    Handles both native OAuth connectors (with encrypted tokens) and
    Composio-backed connectors. Shared by Gmail and Calendar tools.
    """
    from app.utils.google_credentials import COMPOSIO_GOOGLE_CONNECTOR_TYPES

    if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES:
        from app.utils.google_credentials import build_composio_credentials

        cca_id = connector.config.get("composio_connected_account_id")
        if not cca_id:
            raise ValueError("Composio connected account ID not found.")
        return build_composio_credentials(cca_id)

    from google.oauth2.credentials import Credentials

    cfg = dict(connector.config)
    if cfg.get("_token_encrypted"):
        enc = _get_token_encryption()
        for key in ("token", "refresh_token", "client_secret"):
            if cfg.get(key):
                cfg[key] = enc.decrypt_token(cfg[key])

    # datetime.fromisoformat (pre-3.11) rejects a trailing "Z", so strip it.
    exp = (cfg.get("expiry") or "").replace("Z", "")
    return Credentials(
        token=cfg.get("token"),
        refresh_token=cfg.get("refresh_token"),
        token_uri=cfg.get("token_uri"),
        client_id=cfg.get("client_id"),
        client_secret=cfg.get("client_secret"),
        scopes=cfg.get("scopes", []),
        expiry=datetime.fromisoformat(exp) if exp else None,
    )


def create_search_gmail_tool(
    db_session: AsyncSession | None = None,
    search_space_id: int | None = None,
    user_id: str | None = None,
):
    @tool
    async def search_gmail(
        query: str,
        max_results: int = 10,
    ) -> dict[str, Any]:
        """Search emails in the user's Gmail inbox using Gmail search syntax.

        Args:
            query: Gmail search query, same syntax as the Gmail search bar.
                Examples: "from:alice@example.com", "subject:meeting",
                "is:unread", "after:2024/01/01 before:2024/02/01",
                "has:attachment", "in:sent".
            max_results: Number of emails to return (default 10, max 20).

        Returns:
            Dictionary with status and a list of email summaries including
            message_id, subject, from, date, snippet.
        """
        if db_session is None or search_space_id is None or user_id is None:
            return {"status": "error", "message": "Gmail tool not properly configured."}

        max_results = min(max_results, 20)

        try:
            result = await db_session.execute(
                select(SearchSourceConnector).filter(
                    SearchSourceConnector.search_space_id == search_space_id,
                    SearchSourceConnector.user_id == user_id,
                    SearchSourceConnector.connector_type.in_(_GMAIL_TYPES),
                )
            )
            connector = result.scalars().first()
            if not connector:
                return {
                    "status": "error",
                    "message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
                }

            creds = _build_credentials(connector)

            from app.connectors.google_gmail_connector import GoogleGmailConnector

            gmail = GoogleGmailConnector(
                credentials=creds,
                session=db_session,
                user_id=user_id,
                connector_id=connector.id,
            )

            messages_list, error = await gmail.get_messages_list(
                max_results=max_results, query=query
            )
            if error:
                if (
                    "re-authenticate" in error.lower()
                    or "authentication failed" in error.lower()
                ):
                    return {
                        "status": "auth_error",
                        "message": error,
                        "connector_type": "gmail",
                    }
                return {"status": "error", "message": error}

            if not messages_list:
                return {
                    "status": "success",
                    "emails": [],
                    "total": 0,
                    "message": "No emails found.",
                }

            emails = []
            for msg in messages_list:
                detail, err = await gmail.get_message_details(msg["id"])
                if err:
                    continue
                headers = {
                    h["name"].lower(): h["value"]
                    for h in detail.get("payload", {}).get("headers", [])
                }
                emails.append(
                    {
                        "message_id": detail.get("id"),
                        "thread_id": detail.get("threadId"),
                        "subject": headers.get("subject", "No Subject"),
                        "from": headers.get("from", "Unknown"),
                        "to": headers.get("to", ""),
                        "date": headers.get("date", ""),
                        "snippet": detail.get("snippet", ""),
                        "labels": detail.get("labelIds", []),
                    }
                )

            return {"status": "success", "emails": emails, "total": len(emails)}

        except Exception as e:
            from langgraph.errors import GraphInterrupt

            if isinstance(e, GraphInterrupt):
                raise
            logger.error("Error searching Gmail: %s", e, exc_info=True)
            return {
                "status": "error",
                "message": "Failed to search Gmail. Please try again.",
            }

    return search_gmail
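
A sketch of invoking the search tool directly, assuming the factory was called with live dependencies (all values are placeholders):

# Hypothetical invocation using the query operators documented above.
search_gmail = create_search_gmail_tool(
    db_session=session, search_space_id=42, user_id="u1"
)
out = await search_gmail.ainvoke(
    {"query": "subject:meeting after:2024/01/01 has:attachment", "max_results": 5}
)
print(out["total"], [e["subject"] for e in out.get("emails", [])])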
@@ -0,0 +1,315 @@
import asyncio
import base64
import logging
from datetime import datetime
from email.mime.text import MIMEText
from typing import Any

from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession

from app.agents.new_chat.tools.hitl import request_approval
from app.services.gmail import GmailToolMetadataService

logger = logging.getLogger(__name__)


def create_send_gmail_email_tool(
    db_session: AsyncSession | None = None,
    search_space_id: int | None = None,
    user_id: str | None = None,
):
    @tool
    async def send_gmail_email(
        to: str,
        subject: str,
        body: str,
        cc: str | None = None,
        bcc: str | None = None,
    ) -> dict[str, Any]:
        """Send an email via Gmail.

        Use when the user explicitly asks to send an email. This sends the
        email immediately - it cannot be unsent.

        Args:
            to: Recipient email address.
            subject: Email subject line.
            body: Email body content.
            cc: Optional CC recipient(s), comma-separated.
            bcc: Optional BCC recipient(s), comma-separated.

        Returns:
            Dictionary with:
            - status: "success", "rejected", or "error"
            - message_id: Gmail message ID (if success)
            - thread_id: Gmail thread ID (if success)
            - message: Result message

        IMPORTANT:
        - If status is "rejected", the user explicitly declined the action.
          Respond with a brief acknowledgment and do NOT retry or suggest alternatives.
        - If status is "insufficient_permissions", the connector lacks the required OAuth scope.
          Inform the user they need to re-authenticate and do NOT retry the action.

        Examples:
        - "Send an email to alice@example.com about the meeting"
        - "Email Bob the project update"
        """
        logger.info(f"send_gmail_email called: to='{to}', subject='{subject}'")

        if db_session is None or search_space_id is None or user_id is None:
            return {
                "status": "error",
                "message": "Gmail tool not properly configured. Please contact support.",
            }

        try:
            metadata_service = GmailToolMetadataService(db_session)
            context = await metadata_service.get_creation_context(
                search_space_id, user_id
            )

            if "error" in context:
                logger.error(f"Failed to fetch creation context: {context['error']}")
                return {"status": "error", "message": context["error"]}

            accounts = context.get("accounts", [])
            if accounts and all(a.get("auth_expired") for a in accounts):
                logger.warning("All Gmail accounts have expired authentication")
                return {
                    "status": "auth_error",
                    "message": "All connected Gmail accounts need re-authentication. Please re-authenticate in your connector settings.",
                    "connector_type": "gmail",
                }

            logger.info(
                f"Requesting approval for sending Gmail email: to='{to}', subject='{subject}'"
            )
            result = request_approval(
                action_type="gmail_email_send",
                tool_name="send_gmail_email",
                params={
                    "to": to,
                    "subject": subject,
                    "body": body,
                    "cc": cc,
                    "bcc": bcc,
                    "connector_id": None,
                },
                context=context,
            )

            if result.rejected:
                return {
                    "status": "rejected",
                    "message": "User declined. The email was not sent. Do not ask again or suggest alternatives.",
                }

            final_to = result.params.get("to", to)
            final_subject = result.params.get("subject", subject)
            final_body = result.params.get("body", body)
            final_cc = result.params.get("cc", cc)
            final_bcc = result.params.get("bcc", bcc)
            final_connector_id = result.params.get("connector_id")

            from sqlalchemy.future import select

            from app.db import SearchSourceConnector, SearchSourceConnectorType

            _gmail_types = [
                SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
                SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
            ]

            if final_connector_id is not None:
                result = await db_session.execute(
                    select(SearchSourceConnector).filter(
                        SearchSourceConnector.id == final_connector_id,
                        SearchSourceConnector.search_space_id == search_space_id,
                        SearchSourceConnector.user_id == user_id,
                        SearchSourceConnector.connector_type.in_(_gmail_types),
                    )
                )
                connector = result.scalars().first()
                if not connector:
                    return {
                        "status": "error",
                        "message": "Selected Gmail connector is invalid or has been disconnected.",
                    }
                actual_connector_id = connector.id
            else:
                result = await db_session.execute(
                    select(SearchSourceConnector).filter(
                        SearchSourceConnector.search_space_id == search_space_id,
                        SearchSourceConnector.user_id == user_id,
                        SearchSourceConnector.connector_type.in_(_gmail_types),
                    )
                )
                connector = result.scalars().first()
                if not connector:
                    return {
                        "status": "error",
                        "message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
                    }
                actual_connector_id = connector.id

            logger.info(
                f"Sending Gmail email: to='{final_to}', subject='{final_subject}', connector={actual_connector_id}"
            )

            if (
                connector.connector_type
                == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
            ):
                from app.utils.google_credentials import build_composio_credentials

                cca_id = connector.config.get("composio_connected_account_id")
                if cca_id:
                    creds = build_composio_credentials(cca_id)
                else:
                    return {
                        "status": "error",
                        "message": "Composio connected account ID not found for this Gmail connector.",
                    }
            else:
                from google.oauth2.credentials import Credentials

                from app.config import config
                from app.utils.oauth_security import TokenEncryption

                config_data = dict(connector.config)
                token_encrypted = config_data.get("_token_encrypted", False)
                if token_encrypted and config.SECRET_KEY:
                    token_encryption = TokenEncryption(config.SECRET_KEY)
                    if config_data.get("token"):
                        config_data["token"] = token_encryption.decrypt_token(
                            config_data["token"]
                        )
                    if config_data.get("refresh_token"):
                        config_data["refresh_token"] = token_encryption.decrypt_token(
                            config_data["refresh_token"]
                        )
                    if config_data.get("client_secret"):
                        config_data["client_secret"] = token_encryption.decrypt_token(
                            config_data["client_secret"]
                        )

                exp = config_data.get("expiry", "")
                if exp:
                    exp = exp.replace("Z", "")

                creds = Credentials(
                    token=config_data.get("token"),
                    refresh_token=config_data.get("refresh_token"),
                    token_uri=config_data.get("token_uri"),
                    client_id=config_data.get("client_id"),
                    client_secret=config_data.get("client_secret"),
                    scopes=config_data.get("scopes", []),
                    expiry=datetime.fromisoformat(exp) if exp else None,
                )

            from googleapiclient.discovery import build

            gmail_service = build("gmail", "v1", credentials=creds)

            message = MIMEText(final_body)
            message["to"] = final_to
            message["subject"] = final_subject
            if final_cc:
                message["cc"] = final_cc
            if final_bcc:
                message["bcc"] = final_bcc
            raw = base64.urlsafe_b64encode(message.as_bytes()).decode()

            try:
                sent = await asyncio.get_event_loop().run_in_executor(
                    None,
                    lambda: (
                        gmail_service.users()
                        .messages()
                        .send(userId="me", body={"raw": raw})
                        .execute()
                    ),
                )
            except Exception as api_err:
                from googleapiclient.errors import HttpError

                if isinstance(api_err, HttpError) and api_err.resp.status == 403:
                    logger.warning(
                        f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
                    )
                    try:
                        from sqlalchemy.orm.attributes import flag_modified

                        _res = await db_session.execute(
                            select(SearchSourceConnector).where(
                                SearchSourceConnector.id == actual_connector_id
                            )
                        )
                        _conn = _res.scalar_one_or_none()
                        if _conn and not _conn.config.get("auth_expired"):
                            _conn.config = {**_conn.config, "auth_expired": True}
                            flag_modified(_conn, "config")
                            await db_session.commit()
                    except Exception:
                        logger.warning(
                            "Failed to persist auth_expired for connector %s",
                            actual_connector_id,
                            exc_info=True,
                        )
                    return {
                        "status": "insufficient_permissions",
                        "connector_id": actual_connector_id,
                        "message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.",
                    }
                raise

            logger.info(
                f"Gmail email sent: id={sent.get('id')}, threadId={sent.get('threadId')}"
            )

            kb_message_suffix = ""
            try:
                from app.services.gmail import GmailKBSyncService

                kb_service = GmailKBSyncService(db_session)
                kb_result = await kb_service.sync_after_create(
                    message_id=sent.get("id", ""),
                    thread_id=sent.get("threadId", ""),
                    subject=final_subject,
                    sender="me",
                    date_str=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                    body_text=final_body,
                    connector_id=actual_connector_id,
                    search_space_id=search_space_id,
                    user_id=user_id,
                )
                if kb_result["status"] == "success":
                    kb_message_suffix = " Your knowledge base has also been updated."
                else:
                    kb_message_suffix = " This email will be added to your knowledge base in the next scheduled sync."
            except Exception as kb_err:
                logger.warning(f"KB sync after send failed: {kb_err}")
                kb_message_suffix = " This email will be added to your knowledge base in the next scheduled sync."

            return {
                "status": "success",
                "message_id": sent.get("id"),
                "thread_id": sent.get("threadId"),
                "message": f"Successfully sent email to '{final_to}' with subject '{final_subject}'.{kb_message_suffix}",
            }

        except Exception as e:
            from langgraph.errors import GraphInterrupt

            if isinstance(e, GraphInterrupt):
                raise

            logger.error(f"Error sending Gmail email: {e}", exc_info=True)
            return {
                "status": "error",
                "message": "Something went wrong while sending the email. Please try again.",
            }

    return send_gmail_email
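
The draft and send paths build their payload identically; a standalone sketch of that encoding step (addresses are placeholders):

# Gmail's API takes the full RFC 2822 message, base64url-encoded, in "raw".
import base64
from email.mime.text import MIMEText

msg = MIMEText("Hi Alice!")
msg["to"] = "alice@example.com"
msg["subject"] = "Hello"
raw = base64.urlsafe_b64encode(msg.as_bytes()).decode()
payload = {"raw": raw}  # passed as body= to messages().send(userId="me", ...)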
@@ -0,0 +1,309 @@
import asyncio
import logging
from datetime import datetime
from typing import Any

from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession

from app.agents.new_chat.tools.hitl import request_approval
from app.services.gmail import GmailToolMetadataService

logger = logging.getLogger(__name__)


def create_trash_gmail_email_tool(
    db_session: AsyncSession | None = None,
    search_space_id: int | None = None,
    user_id: str | None = None,
):
    @tool
    async def trash_gmail_email(
        email_subject_or_id: str,
        delete_from_kb: bool = False,
    ) -> dict[str, Any]:
        """Move an email or draft to trash in Gmail.

        Use when the user asks to delete, remove, or trash an email or draft.

        Args:
            email_subject_or_id: The exact subject line or message ID of the
                email to trash (as it appears in the inbox).
            delete_from_kb: Whether to also remove the email from the knowledge base.
                Default is False. Set to True to remove from both Gmail and
                knowledge base.

        Returns:
            Dictionary with:
            - status: "success", "rejected", "not_found", or "error"
            - message_id: Gmail message ID (if success)
            - deleted_from_kb: whether the document was removed from the knowledge base
            - message: Result message

        IMPORTANT:
        - If status is "rejected", the user explicitly declined. Respond with a brief
          acknowledgment and do NOT retry or suggest alternatives.
        - If status is "not_found", relay the exact message to the user and ask them
          to verify the email subject or check if it has been indexed.
        - If status is "insufficient_permissions", the connector lacks the required OAuth scope.
          Inform the user they need to re-authenticate and do NOT retry this tool.

        Examples:
        - "Delete the email about 'Meeting Cancelled'"
        - "Trash the email from Bob about the project"
        """
        logger.info(
            f"trash_gmail_email called: email_subject_or_id='{email_subject_or_id}', delete_from_kb={delete_from_kb}"
        )

        if db_session is None or search_space_id is None or user_id is None:
            return {
                "status": "error",
                "message": "Gmail tool not properly configured. Please contact support.",
            }

        try:
            metadata_service = GmailToolMetadataService(db_session)
            context = await metadata_service.get_trash_context(
                search_space_id, user_id, email_subject_or_id
            )

            if "error" in context:
                error_msg = context["error"]
                if "not found" in error_msg.lower():
                    logger.warning(f"Email not found: {error_msg}")
                    return {"status": "not_found", "message": error_msg}
                logger.error(f"Failed to fetch trash context: {error_msg}")
                return {"status": "error", "message": error_msg}

            account = context.get("account", {})
            if account.get("auth_expired"):
                logger.warning(
                    "Gmail account %s has expired authentication",
                    account.get("id"),
                )
                return {
                    "status": "auth_error",
                    "message": "The Gmail account for this email needs re-authentication. Please re-authenticate in your connector settings.",
                    "connector_type": "gmail",
                }

            email = context["email"]
            message_id = email["message_id"]
            document_id = email.get("document_id")
            connector_id_from_context = context["account"]["id"]

            if not message_id:
                return {
                    "status": "error",
                    "message": "Message ID is missing from the indexed document. Please re-index the email and try again.",
                }

            logger.info(
                f"Requesting approval for trashing Gmail email: '{email_subject_or_id}' (message_id={message_id}, delete_from_kb={delete_from_kb})"
            )
            result = request_approval(
                action_type="gmail_email_trash",
                tool_name="trash_gmail_email",
                params={
                    "message_id": message_id,
                    "connector_id": connector_id_from_context,
                    "delete_from_kb": delete_from_kb,
                },
                context=context,
            )

            if result.rejected:
                return {
                    "status": "rejected",
                    "message": "User declined. The email was not trashed. Do not ask again or suggest alternatives.",
                }

            final_message_id = result.params.get("message_id", message_id)
            final_connector_id = result.params.get(
                "connector_id", connector_id_from_context
            )
            final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb)

            if not final_connector_id:
                return {
                    "status": "error",
                    "message": "No connector found for this email.",
                }

            from sqlalchemy.future import select

            from app.db import SearchSourceConnector, SearchSourceConnectorType

            _gmail_types = [
                SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
                SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
            ]

            result = await db_session.execute(
                select(SearchSourceConnector).filter(
                    SearchSourceConnector.id == final_connector_id,
                    SearchSourceConnector.search_space_id == search_space_id,
                    SearchSourceConnector.user_id == user_id,
                    SearchSourceConnector.connector_type.in_(_gmail_types),
                )
            )
            connector = result.scalars().first()
            if not connector:
                return {
                    "status": "error",
                    "message": "Selected Gmail connector is invalid or has been disconnected.",
                }

            logger.info(
                f"Trashing Gmail email: message_id='{final_message_id}', connector={final_connector_id}"
            )

            if (
                connector.connector_type
                == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
            ):
                from app.utils.google_credentials import build_composio_credentials

                cca_id = connector.config.get("composio_connected_account_id")
                if cca_id:
                    creds = build_composio_credentials(cca_id)
                else:
                    return {
                        "status": "error",
                        "message": "Composio connected account ID not found for this Gmail connector.",
                    }
            else:
                from google.oauth2.credentials import Credentials

                from app.config import config
                from app.utils.oauth_security import TokenEncryption

                config_data = dict(connector.config)
                token_encrypted = config_data.get("_token_encrypted", False)
                if token_encrypted and config.SECRET_KEY:
                    token_encryption = TokenEncryption(config.SECRET_KEY)
                    if config_data.get("token"):
                        config_data["token"] = token_encryption.decrypt_token(
                            config_data["token"]
                        )
                    if config_data.get("refresh_token"):
                        config_data["refresh_token"] = token_encryption.decrypt_token(
                            config_data["refresh_token"]
                        )
                    if config_data.get("client_secret"):
                        config_data["client_secret"] = token_encryption.decrypt_token(
                            config_data["client_secret"]
                        )

                exp = config_data.get("expiry", "")
                if exp:
                    exp = exp.replace("Z", "")

                creds = Credentials(
                    token=config_data.get("token"),
                    refresh_token=config_data.get("refresh_token"),
                    token_uri=config_data.get("token_uri"),
                    client_id=config_data.get("client_id"),
                    client_secret=config_data.get("client_secret"),
                    scopes=config_data.get("scopes", []),
                    expiry=datetime.fromisoformat(exp) if exp else None,
                )

            from googleapiclient.discovery import build

            gmail_service = build("gmail", "v1", credentials=creds)

            try:
                await asyncio.get_event_loop().run_in_executor(
                    None,
                    lambda: (
                        gmail_service.users()
                        .messages()
                        .trash(userId="me", id=final_message_id)
                        .execute()
                    ),
                )
            except Exception as api_err:
                from googleapiclient.errors import HttpError

                if isinstance(api_err, HttpError) and api_err.resp.status == 403:
                    logger.warning(
                        f"Insufficient permissions for connector {connector.id}: {api_err}"
                    )
                    try:
                        from sqlalchemy.orm.attributes import flag_modified

                        if not connector.config.get("auth_expired"):
                            connector.config = {
                                **connector.config,
                                "auth_expired": True,
                            }
                            flag_modified(connector, "config")
                            await db_session.commit()
                    except Exception:
                        logger.warning(
                            "Failed to persist auth_expired for connector %s",
                            connector.id,
                            exc_info=True,
                        )
                    return {
                        "status": "insufficient_permissions",
                        "connector_id": connector.id,
                        "message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.",
                    }
                raise

            logger.info(f"Gmail email trashed: message_id={final_message_id}")

            trash_result: dict[str, Any] = {
                "status": "success",
                "message_id": final_message_id,
                "message": f"Successfully moved email '{email.get('subject', email_subject_or_id)}' to trash.",
            }

            deleted_from_kb = False
            if final_delete_from_kb and document_id:
                try:
                    from app.db import Document

                    doc_result = await db_session.execute(
                        select(Document).filter(Document.id == document_id)
                    )
                    document = doc_result.scalars().first()
                    if document:
                        await db_session.delete(document)
                        await db_session.commit()
                        deleted_from_kb = True
                        logger.info(
                            f"Deleted document {document_id} from knowledge base"
                        )
                    else:
                        logger.warning(f"Document {document_id} not found in KB")
                except Exception as e:
                    logger.error(f"Failed to delete document from KB: {e}")
                    await db_session.rollback()
                    trash_result["warning"] = (
                        f"Email trashed, but failed to remove from knowledge base: {e!s}"
                    )

            trash_result["deleted_from_kb"] = deleted_from_kb
            if deleted_from_kb:
                trash_result["message"] = (
                    f"{trash_result.get('message', '')} (also removed from knowledge base)"
                )

            return trash_result

        except Exception as e:
            from langgraph.errors import GraphInterrupt

            if isinstance(e, GraphInterrupt):
                raise

            logger.error(f"Error trashing Gmail email: {e}", exc_info=True)
            return {
                "status": "error",
                "message": "Something went wrong while trashing the email. Please try again.",
            }

    return trash_gmail_email
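
The human-in-the-loop step above follows the pattern shared by every write tool in this diff: the approval card may return user-edited params, so each field is reread with a fallback. A condensed sketch (`proposed` and `ctx` are placeholders):

# Condensed sketch of the shared approval pattern.
result = request_approval(
    action_type="gmail_email_trash",
    tool_name="trash_gmail_email",
    params=proposed,
    context=ctx,
)
if result.rejected:
    ...  # report "rejected" and stop; never retry or re-ask
final = {k: result.params.get(k, v) for k, v in proposed.items()}  # user edits win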
@@ -0,0 +1,410 @@
import asyncio
|
||||
import base64
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from email.mime.text import MIMEText
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.new_chat.tools.hitl import request_approval
|
||||
from app.services.gmail import GmailToolMetadataService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_update_gmail_draft_tool(
|
||||
db_session: AsyncSession | None = None,
|
||||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
@tool
|
||||
async def update_gmail_draft(
|
||||
draft_subject_or_id: str,
|
||||
body: str,
|
||||
to: str | None = None,
|
||||
subject: str | None = None,
|
||||
cc: str | None = None,
|
||||
bcc: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Update an existing Gmail draft.
|
||||
|
||||
Use when the user asks to modify, edit, or add content to an existing
|
||||
email draft. This replaces the draft content with the new version.
|
||||
The user will be able to review and edit the content before it is applied.
|
||||
|
||||
If the user simply wants to "edit" a draft without specifying exact changes,
|
||||
generate the body yourself using your best understanding of the conversation
|
||||
context. The user will review and can freely edit the content in the approval
|
||||
card before confirming.
|
||||
|
||||
IMPORTANT: This tool is ONLY for modifying Gmail draft content, NOT for
|
||||
deleting/trashing drafts (use trash_gmail_email instead), Notion pages,
|
||||
calendar events, or any other content type.
|
||||
|
||||
Args:
|
||||
draft_subject_or_id: The exact subject line of the draft to update
|
||||
(as it appears in Gmail drafts).
|
||||
body: The full updated body content for the draft. Generate this
|
||||
yourself based on the user's request and conversation context.
|
||||
to: Optional new recipient email address (keeps original if omitted).
|
||||
subject: Optional new subject line (keeps original if omitted).
|
||||
cc: Optional CC recipient(s), comma-separated.
|
||||
bcc: Optional BCC recipient(s), comma-separated.
|
||||
|
||||
Returns:
|
||||
Dictionary with:
|
||||
- status: "success", "rejected", "not_found", or "error"
|
||||
- draft_id: Gmail draft ID (if success)
|
||||
- message: Result message
|
||||
|
||||
IMPORTANT:
|
||||
- If status is "rejected", the user explicitly declined the action.
|
||||
Respond with a brief acknowledgment and do NOT retry or suggest alternatives.
|
||||
- If status is "not_found", relay the exact message to the user and ask them
|
||||
to verify the draft subject or check if it has been indexed.
|
||||
- If status is "insufficient_permissions", the connector lacks the required OAuth scope.
|
||||
Inform the user they need to re-authenticate and do NOT retry the action.
|
||||
|
||||
Examples:
|
||||
- "Update the Kurseong Plan draft with the new itinerary details"
|
||||
- "Edit my draft about the project proposal and change the recipient"
|
||||
- "Let me edit the meeting notes draft" (call with current body content so user can edit in the approval card)
|
||||
"""
|
||||
logger.info(
|
||||
f"update_gmail_draft called: draft_subject_or_id='{draft_subject_or_id}'"
|
||||
)
|
||||
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Gmail tool not properly configured. Please contact support.",
|
||||
}
|
||||
|
||||
try:
|
||||
metadata_service = GmailToolMetadataService(db_session)
|
||||
context = await metadata_service.get_update_context(
|
||||
search_space_id, user_id, draft_subject_or_id
|
||||
)
|
||||
|
||||
if "error" in context:
|
||||
error_msg = context["error"]
|
||||
if "not found" in error_msg.lower():
|
||||
logger.warning(f"Draft not found: {error_msg}")
|
||||
return {"status": "not_found", "message": error_msg}
|
||||
logger.error(f"Failed to fetch update context: {error_msg}")
|
||||
return {"status": "error", "message": error_msg}
|
||||
|
||||
account = context.get("account", {})
|
||||
if account.get("auth_expired"):
|
||||
logger.warning(
|
||||
"Gmail account %s has expired authentication",
|
||||
account.get("id"),
|
||||
)
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"message": "The Gmail account for this draft needs re-authentication. Please re-authenticate in your connector settings.",
|
||||
"connector_type": "gmail",
|
||||
}
|
||||
|
||||
email = context["email"]
|
||||
message_id = email["message_id"]
|
||||
document_id = email.get("document_id")
|
||||
connector_id_from_context = account["id"]
|
||||
draft_id_from_context = context.get("draft_id")
|
||||
|
||||
original_subject = email.get("subject", draft_subject_or_id)
|
||||
final_subject_default = subject if subject else original_subject
|
||||
final_to_default = to if to else ""
|
||||
|
||||
logger.info(
|
||||
f"Requesting approval for updating Gmail draft: '{original_subject}' "
|
||||
f"(message_id={message_id}, draft_id={draft_id_from_context})"
|
||||
)
|
||||
result = request_approval(
|
||||
action_type="gmail_draft_update",
|
||||
tool_name="update_gmail_draft",
|
||||
params={
|
||||
"message_id": message_id,
|
||||
"draft_id": draft_id_from_context,
|
||||
"to": final_to_default,
|
||||
"subject": final_subject_default,
|
||||
"body": body,
|
||||
"cc": cc,
|
||||
"bcc": bcc,
|
||||
"connector_id": connector_id_from_context,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
if result.rejected:
|
||||
return {
|
||||
"status": "rejected",
|
||||
"message": "User declined. The draft was not updated. Do not ask again or suggest alternatives.",
|
||||
}
|
||||
|
||||
final_to = result.params.get("to", final_to_default)
|
||||
final_subject = result.params.get("subject", final_subject_default)
|
||||
final_body = result.params.get("body", body)
|
||||
final_cc = result.params.get("cc", cc)
|
||||
final_bcc = result.params.get("bcc", bcc)
|
||||
final_connector_id = result.params.get(
|
||||
"connector_id", connector_id_from_context
|
||||
)
|
||||
final_draft_id = result.params.get("draft_id", draft_id_from_context)
|
||||
|
||||
if not final_connector_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No connector found for this draft.",
|
||||
}
|
||||
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
_gmail_types = [
|
||||
SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
|
||||
SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
|
||||
]
|
||||
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == final_connector_id,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type.in_(_gmail_types),
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Selected Gmail connector is invalid or has been disconnected.",
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"Updating Gmail draft: subject='{final_subject}', connector={final_connector_id}"
|
||||
)
|
||||
|
||||
if (
|
||||
connector.connector_type
|
||||
== SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
|
||||
):
|
||||
from app.utils.google_credentials import build_composio_credentials
|
||||
|
||||
cca_id = connector.config.get("composio_connected_account_id")
|
||||
if cca_id:
|
||||
creds = build_composio_credentials(cca_id)
|
||||
else:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Composio connected account ID not found for this Gmail connector.",
|
||||
}
|
||||
else:
|
||||
from google.oauth2.credentials import Credentials
|
||||
|
||||
from app.config import config
|
||||
from app.utils.oauth_security import TokenEncryption
|
||||
|
||||
config_data = dict(connector.config)
|
||||
token_encrypted = config_data.get("_token_encrypted", False)
|
||||
if token_encrypted and config.SECRET_KEY:
|
||||
token_encryption = TokenEncryption(config.SECRET_KEY)
|
||||
if config_data.get("token"):
|
||||
config_data["token"] = token_encryption.decrypt_token(
|
||||
config_data["token"]
|
||||
)
|
||||
if config_data.get("refresh_token"):
|
||||
config_data["refresh_token"] = token_encryption.decrypt_token(
|
||||
config_data["refresh_token"]
|
||||
)
|
||||
if config_data.get("client_secret"):
|
||||
config_data["client_secret"] = token_encryption.decrypt_token(
|
||||
config_data["client_secret"]
|
||||
)
|
||||
|
||||
exp = config_data.get("expiry", "")
|
||||
if exp:
|
||||
exp = exp.replace("Z", "")
|
||||
|
||||
creds = Credentials(
|
||||
token=config_data.get("token"),
|
||||
refresh_token=config_data.get("refresh_token"),
|
||||
token_uri=config_data.get("token_uri"),
|
||||
client_id=config_data.get("client_id"),
|
||||
client_secret=config_data.get("client_secret"),
|
||||
scopes=config_data.get("scopes", []),
|
||||
expiry=datetime.fromisoformat(exp) if exp else None,
|
||||
)
|
||||
|
||||
from googleapiclient.discovery import build
|
||||
|
||||
gmail_service = build("gmail", "v1", credentials=creds)
|
||||
|
||||
# Resolve draft_id if not already available
|
||||
if not final_draft_id:
|
||||
logger.info(
|
||||
f"draft_id not in metadata, looking up via drafts.list for message_id={message_id}"
|
||||
)
|
||||
final_draft_id = await _find_draft_id_by_message(
|
||||
gmail_service, message_id
|
||||
)
|
||||
|
||||
if not final_draft_id:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": (
|
||||
"Could not find this draft in Gmail. "
|
||||
"It may have already been sent or deleted."
|
||||
),
|
||||
}
|
||||
|
||||
message = MIMEText(final_body)
|
||||
if final_to:
|
||||
message["to"] = final_to
|
||||
message["subject"] = final_subject
|
||||
if final_cc:
|
||||
message["cc"] = final_cc
|
||||
if final_bcc:
|
||||
message["bcc"] = final_bcc
|
||||
raw = base64.urlsafe_b64encode(message.as_bytes()).decode()
|
||||
|
||||
            try:
                updated = await asyncio.get_event_loop().run_in_executor(
                    None,
                    lambda: (
                        gmail_service.users()
                        .drafts()
                        .update(
                            userId="me",
                            id=final_draft_id,
                            body={"message": {"raw": raw}},
                        )
                        .execute()
                    ),
                )
            except Exception as api_err:
                from googleapiclient.errors import HttpError

                if isinstance(api_err, HttpError) and api_err.resp.status == 403:
                    logger.warning(
                        f"Insufficient permissions for connector {connector.id}: {api_err}"
                    )
                    try:
                        from sqlalchemy.orm.attributes import flag_modified

                        if not connector.config.get("auth_expired"):
                            connector.config = {
                                **connector.config,
                                "auth_expired": True,
                            }
                            flag_modified(connector, "config")
                            await db_session.commit()
                    except Exception:
                        logger.warning(
                            "Failed to persist auth_expired for connector %s",
                            connector.id,
                            exc_info=True,
                        )
                    return {
                        "status": "insufficient_permissions",
                        "connector_id": connector.id,
                        "message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.",
                    }
                if isinstance(api_err, HttpError) and api_err.resp.status == 404:
                    return {
                        "status": "error",
                        "message": "Draft no longer exists in Gmail. It may have been sent or deleted.",
                    }
                raise

            logger.info(f"Gmail draft updated: id={updated.get('id')}")

            kb_message_suffix = ""
            if document_id:
                try:
                    from sqlalchemy.future import select as sa_select
                    from sqlalchemy.orm.attributes import flag_modified

                    from app.db import Document

                    doc_result = await db_session.execute(
                        sa_select(Document).filter(Document.id == document_id)
                    )
                    document = doc_result.scalars().first()
                    if document:
                        document.source_markdown = final_body
                        document.title = final_subject
                        meta = dict(document.document_metadata or {})
                        meta["subject"] = final_subject
                        meta["draft_id"] = updated.get("id", final_draft_id)
                        updated_msg = updated.get("message", {})
                        if updated_msg.get("id"):
                            meta["message_id"] = updated_msg["id"]
                        document.document_metadata = meta
                        flag_modified(document, "document_metadata")
                        await db_session.commit()
                        kb_message_suffix = (
                            " Your knowledge base has also been updated."
                        )
                        logger.info(
                            f"KB document {document_id} updated for draft {final_draft_id}"
                        )
                    else:
                        kb_message_suffix = " This draft will be fully updated in your knowledge base in the next scheduled sync."
                except Exception as kb_err:
                    logger.warning(f"KB update after draft edit failed: {kb_err}")
                    await db_session.rollback()
                    kb_message_suffix = " This draft will be fully updated in your knowledge base in the next scheduled sync."

            return {
                "status": "success",
                "draft_id": updated.get("id"),
                "message": f"Successfully updated Gmail draft with subject '{final_subject}'.{kb_message_suffix}",
            }

        except Exception as e:
            from langgraph.errors import GraphInterrupt

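            # Human-in-the-loop approval pauses the graph by raising
            # GraphInterrupt; it must propagate instead of being swallowed here.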
            if isinstance(e, GraphInterrupt):
                raise

            logger.error(f"Error updating Gmail draft: {e}", exc_info=True)
            return {
                "status": "error",
                "message": "Something went wrong while updating the draft. Please try again.",
            }

    return update_gmail_draft


async def _find_draft_id_by_message(gmail_service: Any, message_id: str) -> str | None:
    """Look up a draft's ID by its message ID via the Gmail API."""
    try:
        page_token = None
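        # drafts.list is paginated; walk nextPageToken until the draft that
        # wraps this message ID is found (or the pages run out).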
        while True:
            kwargs: dict[str, Any] = {"userId": "me", "maxResults": 100}
            if page_token:
                kwargs["pageToken"] = page_token

            response = await asyncio.get_event_loop().run_in_executor(
                None,
                lambda kwargs=kwargs: (
                    gmail_service.users().drafts().list(**kwargs).execute()
                ),
            )

            for draft in response.get("drafts", []):
                if draft.get("message", {}).get("id") == message_id:
                    return draft["id"]

            page_token = response.get("nextPageToken")
            if not page_token:
                break

        return None
    except Exception as e:
        logger.warning(f"Failed to look up draft by message_id: {e}")
        return None
@@ -0,0 +1,55 @@
"""`google_drive` route: ``SubAgent`` spec for deepagents."""

from __future__ import annotations

from collections.abc import Sequence
from typing import Any

from deepagents import SubAgent
from langchain_core.language_models import BaseChatModel

from app.agents.multi_agent_chat.subagents.shared.md_file_reader import (
    read_md_file,
)
from app.agents.multi_agent_chat.subagents.shared.permissions import (
    ToolsPermissions,
    merge_tools_permissions,
    middleware_gated_interrupt_on,
)
from app.agents.multi_agent_chat.subagents.shared.subagent_builder import (
    pack_subagent,
)

from .tools.index import load_tools

NAME = "google_drive"


def build_subagent(
    *,
    dependencies: dict[str, Any],
    model: BaseChatModel | None = None,
    extra_middleware: Sequence[Any] | None = None,
    extra_tools_bucket: ToolsPermissions | None = None,
) -> SubAgent:
    buckets = load_tools(dependencies=dependencies)
    merged_tools_bucket = merge_tools_permissions(buckets, extra_tools_bucket)
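    # Expose both auto-approved ("allow") and approval-gated ("ask") tools;
    # gating for the "ask" bucket is applied via interrupt_on below.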
    tools = [
        row["tool"]
        for row in (*merged_tools_bucket["allow"], *merged_tools_bucket["ask"])
        if row.get("tool") is not None
    ]
    interrupt_on = middleware_gated_interrupt_on(merged_tools_bucket)
    description = read_md_file(__package__, "description").strip()
    if not description:
        description = "Handles google drive tasks for this workspace."
    system_prompt = read_md_file(__package__, "system_prompt").strip()
    return pack_subagent(
        name=NAME,
        description=description,
        system_prompt=system_prompt,
        tools=tools,
        interrupt_on=interrupt_on,
        model=model,
        extra_middleware=extra_middleware,
    )
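A minimal sketch of how the subagent registry might invoke this builder (the import path and dependency values are assumptions for illustration; only `build_subagent` and its keyword arguments are defined in this diff):

    # Hypothetical wiring; module path and dependency values are illustrative.
    from app.agents.multi_agent_chat.subagents.google_drive import build_subagent

    subagent = build_subagent(
        dependencies={
            "db_session": db_session,          # AsyncSession for this request
            "search_space_id": search_space_id,
            "user_id": user_id,
        },
    )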
@@ -0,0 +1 @@
Use for Google Drive document/file tasks: locate files, inspect content, and manage Drive files or folders.
@@ -0,0 +1,54 @@
You are the Google Drive operations sub-agent.
You receive delegated instructions from a supervisor agent and return structured results for supervisor synthesis.

<goal>
Execute Google Drive file operations accurately in the connected account.
</goal>

<available_tools>
- `create_google_drive_file`
- `delete_google_drive_file`
</available_tools>

<tool_policy>
- Use only tools in `<available_tools>`.
- Ensure the target file's identity/path is explicit before any mutating action.
- If the target is ambiguous, return `status=blocked` with candidate files.
- Never invent file IDs/names or mutation outcomes.
</tool_policy>

<out_of_scope>
- Do not perform non-Google-Drive tasks.
</out_of_scope>

<safety>
- Never claim file mutation success without tool confirmation.
</safety>

<failure_policy>
- On tool failure, return `status=error` with a concise recovery `next_step`.
- On target ambiguity, return `status=blocked` with candidate files.
</failure_policy>

<output_contract>
Return **only** one JSON object (no markdown/prose):
{
  "status": "success" | "partial" | "blocked" | "error",
  "action_summary": string,
  "evidence": {
    "file_id": string | null,
    "file_name": string | null,
    "operation": "create" | "delete" | null,
    "matched_candidates": [
      { "file_id": string, "file_name": string | null }
    ] | null
  },
  "next_step": string | null,
  "missing_fields": string[] | null,
  "assumptions": string[] | null
}
Rules:
- `status=success` -> `next_step=null`, `missing_fields=null`.
- `status=partial|blocked|error` -> `next_step` must be non-null.
- `status=blocked` due to missing required inputs -> `missing_fields` must be non-null.
</output_contract>
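To make the contract concrete, a conforming payload for a successful create might look like this (all values illustrative, not produced by the code in this diff):

    {
      "status": "success",
      "action_summary": "Created Google Doc 'Q3 Planning Notes' in the connected Drive.",
      "evidence": {
        "file_id": "1AbCdEfG",
        "file_name": "Q3 Planning Notes",
        "operation": "create",
        "matched_candidates": null
      },
      "next_step": null,
      "missing_fields": null,
      "assumptions": null
    }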
@@ -0,0 +1,11 @@
from app.agents.new_chat.tools.google_drive.create_file import (
    create_create_google_drive_file_tool,
)
from app.agents.new_chat.tools.google_drive.trash_file import (
    create_delete_google_drive_file_tool,
)

__all__ = [
    "create_create_google_drive_file_tool",
    "create_delete_google_drive_file_tool",
]
@@ -0,0 +1,283 @@
import logging
from typing import Any, Literal

from googleapiclient.errors import HttpError
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession

from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.google_drive.client import GoogleDriveClient
from app.connectors.google_drive.file_types import GOOGLE_DOC, GOOGLE_SHEET
from app.services.google_drive import GoogleDriveToolMetadataService

logger = logging.getLogger(__name__)

_MIME_MAP: dict[str, str] = {
    "google_doc": GOOGLE_DOC,
    "google_sheet": GOOGLE_SHEET,
}


def create_create_google_drive_file_tool(
    db_session: AsyncSession | None = None,
    search_space_id: int | None = None,
    user_id: str | None = None,
):
    @tool
    async def create_google_drive_file(
        name: str,
        file_type: Literal["google_doc", "google_sheet"],
        content: str | None = None,
    ) -> dict[str, Any]:
        """Create a new Google Doc or Google Sheet in Google Drive.

        Use this tool when the user explicitly asks to create a new document
        or spreadsheet in Google Drive. The user MUST specify a topic before
        you call this tool. If the request does not contain a topic (e.g.
        "create a drive doc" or "make a Google Sheet"), ask what the file
        should be about. Never call this tool without a clear topic from the user.

        Args:
            name: The file name (without extension).
            file_type: Either "google_doc" or "google_sheet".
            content: Optional initial content. Generate from the user's topic.
                For google_doc, provide markdown text. For google_sheet, provide CSV-formatted text.

        Returns:
            Dictionary with:
            - status: "success", "rejected", "auth_error", "insufficient_permissions", or "error"
            - file_id: Google Drive file ID (if success)
            - name: File name (if success)
            - web_view_link: URL to open the file (if success)
            - message: Result message

        IMPORTANT:
        - If status is "rejected", the user explicitly declined the action.
          Respond with a brief acknowledgment and do NOT retry or suggest alternatives.
        - If status is "insufficient_permissions", the connector lacks the required OAuth scope.
          Inform the user they need to re-authenticate and do NOT retry the action.

        Examples:
        - "Create a Google Doc with today's meeting notes"
        - "Create a spreadsheet for the 2026 budget"
        """
        logger.info(
            f"create_google_drive_file called: name='{name}', type='{file_type}'"
        )

        if db_session is None or search_space_id is None or user_id is None:
            return {
                "status": "error",
                "message": "Google Drive tool not properly configured. Please contact support.",
            }

        if file_type not in _MIME_MAP:
            return {
                "status": "error",
                "message": f"Unsupported file type '{file_type}'. Use 'google_doc' or 'google_sheet'.",
            }

        try:
            metadata_service = GoogleDriveToolMetadataService(db_session)
            context = await metadata_service.get_creation_context(
                search_space_id, user_id
            )

            if "error" in context:
                logger.error(f"Failed to fetch creation context: {context['error']}")
                return {"status": "error", "message": context["error"]}

            accounts = context.get("accounts", [])
            if accounts and all(a.get("auth_expired") for a in accounts):
                logger.warning("All Google Drive accounts have expired authentication")
                return {
                    "status": "auth_error",
                    "message": "All connected Google Drive accounts need re-authentication. Please re-authenticate in your connector settings.",
                    "connector_type": "google_drive",
                }

            logger.info(
                f"Requesting approval for creating Google Drive file: name='{name}', type='{file_type}'"
            )
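            # Human-in-the-loop gate: request_approval interrupts the graph and,
            # on resume, result.params may carry user-edited parameter values.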
            result = request_approval(
                action_type="google_drive_file_creation",
                tool_name="create_google_drive_file",
                params={
                    "name": name,
                    "file_type": file_type,
                    "content": content,
                    "connector_id": None,
                    "parent_folder_id": None,
                },
                context=context,
            )

            if result.rejected:
                return {
                    "status": "rejected",
                    "message": "User declined. The file was not created. Do not ask again or suggest alternatives.",
                }

            final_name = result.params.get("name", name)
            final_file_type = result.params.get("file_type", file_type)
            final_content = result.params.get("content", content)
            final_connector_id = result.params.get("connector_id")
            final_parent_folder_id = result.params.get("parent_folder_id")

            if not final_name or not final_name.strip():
                return {"status": "error", "message": "File name cannot be empty."}

            mime_type = _MIME_MAP.get(final_file_type)
            if not mime_type:
                return {
                    "status": "error",
                    "message": f"Unsupported file type '{final_file_type}'.",
                }

            from sqlalchemy.future import select

            from app.db import SearchSourceConnector, SearchSourceConnectorType

            _drive_types = [
                SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,
                SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR,
            ]

            if final_connector_id is not None:
                result = await db_session.execute(
                    select(SearchSourceConnector).filter(
                        SearchSourceConnector.id == final_connector_id,
                        SearchSourceConnector.search_space_id == search_space_id,
                        SearchSourceConnector.user_id == user_id,
                        SearchSourceConnector.connector_type.in_(_drive_types),
                    )
                )
                connector = result.scalars().first()
                if not connector:
                    return {
                        "status": "error",
                        "message": "Selected Google Drive connector is invalid or has been disconnected.",
                    }
                actual_connector_id = connector.id
            else:
                result = await db_session.execute(
                    select(SearchSourceConnector).filter(
                        SearchSourceConnector.search_space_id == search_space_id,
                        SearchSourceConnector.user_id == user_id,
                        SearchSourceConnector.connector_type.in_(_drive_types),
                    )
                )
                connector = result.scalars().first()
                if not connector:
                    return {
                        "status": "error",
                        "message": "No Google Drive connector found. Please connect Google Drive in your workspace settings.",
                    }
                actual_connector_id = connector.id

            logger.info(
                f"Creating Google Drive file: name='{final_name}', type='{final_file_type}', connector={actual_connector_id}"
            )

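            # Composio-managed connectors carry their own OAuth; build
            # credentials from the connected-account ID instead of native tokens.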
            pre_built_creds = None
            if (
                connector.connector_type
                == SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
            ):
                from app.utils.google_credentials import build_composio_credentials

                cca_id = connector.config.get("composio_connected_account_id")
                if cca_id:
                    pre_built_creds = build_composio_credentials(cca_id)

            client = GoogleDriveClient(
                session=db_session,
                connector_id=actual_connector_id,
                credentials=pre_built_creds,
            )
            try:
                created = await client.create_file(
                    name=final_name,
                    mime_type=mime_type,
                    parent_folder_id=final_parent_folder_id,
                    content=final_content,
                )
            except HttpError as http_err:
                if http_err.resp.status == 403:
                    logger.warning(
                        f"Insufficient permissions for connector {actual_connector_id}: {http_err}"
                    )
                    try:
                        from sqlalchemy.orm.attributes import flag_modified

                        _res = await db_session.execute(
                            select(SearchSourceConnector).where(
                                SearchSourceConnector.id == actual_connector_id
                            )
                        )
                        _conn = _res.scalar_one_or_none()
                        if _conn and not _conn.config.get("auth_expired"):
                            _conn.config = {**_conn.config, "auth_expired": True}
                            flag_modified(_conn, "config")
                            await db_session.commit()
                    except Exception:
                        logger.warning(
                            "Failed to persist auth_expired for connector %s",
                            actual_connector_id,
                            exc_info=True,
                        )
                    return {
                        "status": "insufficient_permissions",
                        "connector_id": actual_connector_id,
                        "message": "This Google Drive account needs additional permissions. Please re-authenticate in connector settings.",
                    }
                raise

            logger.info(
                f"Google Drive file created: id={created.get('id')}, name={created.get('name')}"
            )

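            # KB sync is best-effort: on failure the file is still created and
            # will be picked up by the next scheduled sync.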
kb_message_suffix = ""
|
||||
try:
|
||||
from app.services.google_drive import GoogleDriveKBSyncService
|
||||
|
||||
kb_service = GoogleDriveKBSyncService(db_session)
|
||||
kb_result = await kb_service.sync_after_create(
|
||||
file_id=created.get("id"),
|
||||
file_name=created.get("name", final_name),
|
||||
mime_type=mime_type,
|
||||
web_view_link=created.get("webViewLink"),
|
||||
content=final_content,
|
||||
connector_id=actual_connector_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
if kb_result["status"] == "success":
|
||||
kb_message_suffix = " Your knowledge base has also been updated."
|
||||
else:
|
||||
kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."
|
||||
except Exception as kb_err:
|
||||
logger.warning(f"KB sync after create failed: {kb_err}")
|
||||
kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"file_id": created.get("id"),
|
||||
"name": created.get("name"),
|
||||
"web_view_link": created.get("webViewLink"),
|
||||
"message": f"Successfully created '{created.get('name')}' in Google Drive.{kb_message_suffix}",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
||||
if isinstance(e, GraphInterrupt):
|
||||
raise
|
||||
|
||||
logger.error(f"Error creating Google Drive file: {e}", exc_info=True)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "Something went wrong while creating the file. Please try again.",
|
||||
}
|
||||
|
||||
return create_google_drive_file
|
||||
|
|
@@ -0,0 +1,28 @@
from __future__ import annotations

from typing import Any

from app.agents.multi_agent_chat.subagents.shared.permissions import (
    ToolsPermissions,
)

from .create_file import create_create_google_drive_file_tool
from .trash_file import create_delete_google_drive_file_tool


def load_tools(*, dependencies: dict[str, Any] | None = None, **kwargs: Any) -> ToolsPermissions:
    d = {**(dependencies or {}), **kwargs}
    common = {
        "db_session": d["db_session"],
        "search_space_id": d["search_space_id"],
        "user_id": d["user_id"],
    }
    create = create_create_google_drive_file_tool(**common)
    delete = create_delete_google_drive_file_tool(**common)
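    # Both Drive tools mutate user data, so neither goes in "allow";
    # placing them in "ask" routes every call through human approval.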
    return {
        "allow": [],
        "ask": [
            {"name": getattr(create, "name", "") or "", "tool": create},
            {"name": getattr(delete, "name", "") or "", "tool": delete},
        ],
    }
@@ -0,0 +1,262 @@
import logging
from typing import Any

from googleapiclient.errors import HttpError
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession

from app.agents.new_chat.tools.hitl import request_approval
from app.connectors.google_drive.client import GoogleDriveClient
from app.services.google_drive import GoogleDriveToolMetadataService

logger = logging.getLogger(__name__)


def create_delete_google_drive_file_tool(
    db_session: AsyncSession | None = None,
    search_space_id: int | None = None,
    user_id: str | None = None,
):
    @tool
    async def delete_google_drive_file(
        file_name: str,
        delete_from_kb: bool = False,
    ) -> dict[str, Any]:
        """Move a Google Drive file to trash.

        Use this tool when the user explicitly asks to delete, remove, or trash
        a file in Google Drive.

        Args:
            file_name: The exact name of the file to trash (as it appears in Drive).
            delete_from_kb: Whether to also remove the file from the knowledge base.
                Default is False.
                Set to True to remove from both Google Drive and knowledge base.

        Returns:
            Dictionary with:
            - status: "success", "rejected", "not_found", "auth_error", "insufficient_permissions", or "error"
            - file_id: Google Drive file ID (if success)
            - deleted_from_kb: whether the document was removed from the knowledge base
            - message: Result message

        IMPORTANT:
        - If status is "rejected", the user explicitly declined. Respond with a brief
          acknowledgment and do NOT retry or suggest alternatives.
        - If status is "not_found", relay the exact message to the user and ask them
          to verify the file name or check if it has been indexed.
        - If status is "insufficient_permissions", the connector lacks the required OAuth scope.
          Inform the user they need to re-authenticate and do NOT retry this tool.

        Examples:
        - "Delete the 'Meeting Notes' file from Google Drive"
        - "Trash the 'Old Budget' spreadsheet"
        """
        logger.info(
            f"delete_google_drive_file called: file_name='{file_name}', delete_from_kb={delete_from_kb}"
        )

        if db_session is None or search_space_id is None or user_id is None:
            return {
                "status": "error",
                "message": "Google Drive tool not properly configured. Please contact support.",
            }

        try:
            metadata_service = GoogleDriveToolMetadataService(db_session)
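            # The user supplies a display name, not a Drive file ID; resolve it
            # through the indexed knowledge-base metadata for this search space.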
            context = await metadata_service.get_trash_context(
                search_space_id, user_id, file_name
            )

            if "error" in context:
                error_msg = context["error"]
                if "not found" in error_msg.lower():
                    logger.warning(f"File not found: {error_msg}")
                    return {"status": "not_found", "message": error_msg}
                logger.error(f"Failed to fetch trash context: {error_msg}")
                return {"status": "error", "message": error_msg}

            account = context.get("account", {})
            if account.get("auth_expired"):
                logger.warning(
                    "Google Drive account %s has expired authentication",
                    account.get("id"),
                )
                return {
                    "status": "auth_error",
                    "message": "The Google Drive account for this file needs re-authentication. Please re-authenticate in your connector settings.",
                    "connector_type": "google_drive",
                }

            file = context["file"]
            file_id = file["file_id"]
            document_id = file.get("document_id")
            connector_id_from_context = context["account"]["id"]

            if not file_id:
                return {
                    "status": "error",
                    "message": "File ID is missing from the indexed document. Please re-index the file and try again.",
                }

            logger.info(
                f"Requesting approval for deleting Google Drive file: '{file_name}' (file_id={file_id}, delete_from_kb={delete_from_kb})"
            )
            result = request_approval(
                action_type="google_drive_file_trash",
                tool_name="delete_google_drive_file",
                params={
                    "file_id": file_id,
                    "connector_id": connector_id_from_context,
                    "delete_from_kb": delete_from_kb,
                },
                context=context,
            )

            if result.rejected:
                return {
                    "status": "rejected",
                    "message": "User declined. The file was not trashed. Do not ask again or suggest alternatives.",
                }

            final_file_id = result.params.get("file_id", file_id)
            final_connector_id = result.params.get(
                "connector_id", connector_id_from_context
            )
            final_delete_from_kb = result.params.get("delete_from_kb", delete_from_kb)

            if not final_connector_id:
                return {
                    "status": "error",
                    "message": "No connector found for this file.",
                }

            from sqlalchemy.future import select

            from app.db import SearchSourceConnector, SearchSourceConnectorType

            _drive_types = [
                SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,
                SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR,
            ]

            result = await db_session.execute(
                select(SearchSourceConnector).filter(
                    SearchSourceConnector.id == final_connector_id,
                    SearchSourceConnector.search_space_id == search_space_id,
                    SearchSourceConnector.user_id == user_id,
                    SearchSourceConnector.connector_type.in_(_drive_types),
                )
            )
            connector = result.scalars().first()
            if not connector:
                return {
                    "status": "error",
                    "message": "Selected Google Drive connector is invalid or has been disconnected.",
                }

            logger.info(
                f"Deleting Google Drive file: file_id='{final_file_id}', connector={final_connector_id}"
            )

            pre_built_creds = None
            if (
                connector.connector_type
                == SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
            ):
                from app.utils.google_credentials import build_composio_credentials

                cca_id = connector.config.get("composio_connected_account_id")
                if cca_id:
                    pre_built_creds = build_composio_credentials(cca_id)

            client = GoogleDriveClient(
                session=db_session,
                connector_id=connector.id,
                credentials=pre_built_creds,
            )
            try:
                await client.trash_file(file_id=final_file_id)
            except HttpError as http_err:
                if http_err.resp.status == 403:
                    logger.warning(
                        f"Insufficient permissions for connector {connector.id}: {http_err}"
                    )
                    try:
                        from sqlalchemy.orm.attributes import flag_modified

                        if not connector.config.get("auth_expired"):
                            connector.config = {
                                **connector.config,
                                "auth_expired": True,
                            }
                            flag_modified(connector, "config")
                            await db_session.commit()
                    except Exception:
                        logger.warning(
                            "Failed to persist auth_expired for connector %s",
                            connector.id,
                            exc_info=True,
                        )
                    return {
                        "status": "insufficient_permissions",
                        "connector_id": connector.id,
                        "message": "This Google Drive account needs additional permissions. Please re-authenticate in connector settings.",
                    }
                raise

            logger.info(
                f"Google Drive file deleted (moved to trash): file_id={final_file_id}"
            )

            trash_result: dict[str, Any] = {
                "status": "success",
                "file_id": final_file_id,
                "message": f"Successfully moved '{file['name']}' to trash.",
            }

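            # Removing the KB document is also best-effort; the file stays
            # trashed in Drive even if the local delete fails.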
            deleted_from_kb = False
            if final_delete_from_kb and document_id:
                try:
                    from app.db import Document

                    doc_result = await db_session.execute(
                        select(Document).filter(Document.id == document_id)
                    )
                    document = doc_result.scalars().first()
                    if document:
                        await db_session.delete(document)
                        await db_session.commit()
                        deleted_from_kb = True
                        logger.info(
                            f"Deleted document {document_id} from knowledge base"
                        )
                    else:
                        logger.warning(f"Document {document_id} not found in KB")
                except Exception as e:
                    logger.error(f"Failed to delete document from KB: {e}")
                    await db_session.rollback()
                    trash_result["warning"] = (
                        f"File moved to trash, but failed to remove from knowledge base: {e!s}"
                    )

            trash_result["deleted_from_kb"] = deleted_from_kb
            if deleted_from_kb:
                trash_result["message"] = (
                    f"{trash_result.get('message', '')} (also removed from knowledge base)"
                )

            return trash_result

        except Exception as e:
            from langgraph.errors import GraphInterrupt

            if isinstance(e, GraphInterrupt):
                raise

            logger.error(f"Error deleting Google Drive file: {e}", exc_info=True)
            return {
                "status": "error",
                "message": "Something went wrong while trashing the file. Please try again.",
            }

    return delete_google_drive_file