refactor(agents): move tools package to app/agents/shared (slice 6)

Relocate the entire new_chat/tools/ package (62 files incl. registry, hitl, MCP
cluster, and all connector subpackages: gmail/slack/discord/teams/drive/etc.)
to the shared kernel. The package turned out to be a clean cohesive cluster:
its only references to non-tools new_chat modules were comments, and its
middleware deps were already flipped to shared in slice 5c.

Flip 33 live importers (multi-agent, flows, routes, services, anonymous_agent,
tests). Re-export shims remain for the frozen single-agent stack: a package
__init__ mirroring the public surface (new_chat.__init__ imports it) plus
invalid_tool + registry submodule shims (chat_deepagent imports those).

Resolves slice 5c's two transient back-edges: shared/middleware/action_log
(TYPE_CHECKING ToolDefinition) and tool_call_repair (local INVALID_TOOL_NAME)
now point at app.agents.shared.tools.
This commit is contained in:
CREDO23 2026-06-04 13:11:56 +02:00
parent a7fde2a48e
commit aab95b9130
98 changed files with 1232 additions and 1152 deletions

View file

@ -3,7 +3,7 @@
Wraps every tool call via :meth:`AgentMiddleware.awrap_tool_call` and writes
a row to :class:`~app.db.AgentActionLog` after the tool returns. Tools opt
into reversibility by declaring a ``reverse`` callable on their
:class:`~app.agents.new_chat.tools.registry.ToolDefinition`; the rendered
:class:`~app.agents.shared.tools.registry.ToolDefinition`; the rendered
descriptor is persisted in ``reverse_descriptor`` for use by
``/api/threads/{thread_id}/revert/{action_id}``.
@ -42,7 +42,7 @@ if TYPE_CHECKING: # pragma: no cover - type-only
# Type-only import: keeping it lazy avoids a module-load cycle through the
# frozen single-agent package (new_chat.__init__ -> chat_deepagent ->
# middleware shim). Resolves to app.agents.shared.tools once tools migrate.
from app.agents.new_chat.tools.registry import ToolDefinition
from app.agents.shared.tools.registry import ToolDefinition
logger = logging.getLogger(__name__)

View file

@ -121,7 +121,7 @@ class ToolCallNameRepairMiddleware(
# Local import avoids a module-load cycle through the frozen single-agent
# package (new_chat.__init__ -> chat_deepagent -> middleware shim).
# Resolves to app.agents.shared.tools once tools migrate.
from app.agents.new_chat.tools.invalid_tool import INVALID_TOOL_NAME
from app.agents.shared.tools.invalid_tool import INVALID_TOOL_NAME
if INVALID_TOOL_NAME in registered:
original_args = call.get("args") or {}

View file

@ -0,0 +1,55 @@
"""
Tools module for SurfSense deep agent.
This module contains all the tools available to the SurfSense agent.
To add a new tool, see the documentation in registry.py.
Available tools:
- generate_podcast: Generate audio podcasts from content
- generate_video_presentation: Generate video presentations with slides and narration
- generate_image: Generate images from text descriptions using AI models
- scrape_webpage: Extract content from webpages
- update_memory: Update the user's / team's memory document
"""
# Re-exported public surface: tool factories, knowledge-base helpers, and the
# registry API (imports are kept in sorted order, not grouped by category).
from .generate_image import create_generate_image_tool
from .knowledge_base import (
CONNECTOR_DESCRIPTIONS,
format_documents_for_context,
search_knowledge_base_async,
)
from .podcast import create_generate_podcast_tool
from .registry import (
BUILTIN_TOOLS,
ToolDefinition,
build_tools,
get_all_tool_names,
get_default_enabled_tools,
get_tool_by_name,
)
from .scrape_webpage import create_scrape_webpage_tool
from .update_memory import create_update_memory_tool, create_update_team_memory_tool
from .video_presentation import create_generate_video_presentation_tool
# Public API of the tools package. Entries are kept in sorted order, so the
# three categories below are interleaved rather than contiguous:
#   - registry API (imported from .registry): BUILTIN_TOOLS, ToolDefinition,
#     build_tools, get_all_tool_names, get_default_enabled_tools,
#     get_tool_by_name
#   - knowledge-base helpers (imported from .knowledge_base):
#     CONNECTOR_DESCRIPTIONS, format_documents_for_context,
#     search_knowledge_base_async
#   - tool factories: the create_*_tool callables
__all__ = [
    "BUILTIN_TOOLS",
    "CONNECTOR_DESCRIPTIONS",
    "ToolDefinition",
    "build_tools",
    "create_generate_image_tool",
    "create_generate_podcast_tool",
    "create_generate_video_presentation_tool",
    "create_scrape_webpage_tool",
    "create_update_memory_tool",
    "create_update_team_memory_tool",
    "format_documents_for_context",
    "get_all_tool_names",
    "get_default_enabled_tools",
    "get_tool_by_name",
    "search_knowledge_base_async",
]

View file

@ -0,0 +1,11 @@
"""Confluence tools for creating, updating, and deleting pages."""
from .create_page import create_create_confluence_page_tool
from .delete_page import create_delete_confluence_page_tool
from .update_page import create_update_confluence_page_tool
# Public API of the Confluence tools package: one factory per page operation
# (create / delete / update), each imported from its own submodule above.
__all__ = [
    "create_create_confluence_page_tool",
    "create_delete_confluence_page_tool",
    "create_update_confluence_page_tool",
]

View file

@ -0,0 +1,232 @@
import logging
from typing import Any
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm.attributes import flag_modified
from app.agents.shared.tools.hitl import request_approval
from app.connectors.confluence_history import ConfluenceHistoryConnector
from app.db import async_session_maker
from app.services.confluence import ConfluenceToolMetadataService
logger = logging.getLogger(__name__)
def create_create_confluence_page_tool(
    db_session: AsyncSession | None = None,
    search_space_id: int | None = None,
    user_id: str | None = None,
    connector_id: int | None = None,
):
    """
    Factory function to create the create_confluence_page tool.

    The tool acquires its own short-lived ``AsyncSession`` per call via
    :data:`async_session_maker` so the closure is safe to share across
    HTTP requests by the compiled-agent cache. Capturing a per-request
    session here would surface stale/closed sessions on cache hits.

    Args:
        db_session: Reserved for registry compatibility. Per-call sessions
            are opened via :data:`async_session_maker` inside the tool body.
        search_space_id: Search space the connector lookup is scoped to. The
            tool returns an error result when this is ``None``.
        user_id: User the connector lookup is scoped to. The tool returns an
            error result when this is ``None``.
        connector_id: Connector id proposed in the approval request. When the
            approved parameters carry no connector id, the first Confluence
            connector matching the search space and user is used.

    Returns:
        Configured create_confluence_page tool
    """
    del db_session  # per-call session — see docstring

    deferred_sync_note = (
        " This page will be added to your knowledge base in the next scheduled sync."
    )

    @tool
    async def create_confluence_page(
        title: str,
        content: str | None = None,
        space_id: str | None = None,
    ) -> dict[str, Any]:
        """Create a new page in Confluence.

        Use this tool when the user explicitly asks to create a new Confluence page.

        Args:
            title: Title of the page.
            content: Optional HTML/storage format content for the page body.
            space_id: Optional Confluence space ID to create the page in.

        Returns:
            Dictionary with status, page_id, and message.

        IMPORTANT:
            - If status is "rejected", do NOT retry.
            - If status is "insufficient_permissions", inform user to re-authenticate.
        """
        logger.info("create_confluence_page called: title='%s'", title)

        if search_space_id is None or user_id is None:
            return {
                "status": "error",
                "message": "Confluence tool not properly configured.",
            }

        try:
            async with async_session_maker() as db_session:
                metadata_service = ConfluenceToolMetadataService(db_session)
                context = await metadata_service.get_creation_context(
                    search_space_id, user_id
                )
                if "error" in context:
                    return {"status": "error", "message": context["error"]}

                accounts = context.get("accounts", [])
                if accounts and all(a.get("auth_expired") for a in accounts):
                    return {
                        "status": "auth_error",
                        "message": "All connected Confluence accounts need re-authentication.",
                        "connector_type": "confluence",
                    }

                # Human-in-the-loop gate: the reviewer may edit any parameter,
                # so everything below reads the approved values.
                approval = request_approval(
                    action_type="confluence_page_creation",
                    tool_name="create_confluence_page",
                    params={
                        "title": title,
                        "content": content,
                        "space_id": space_id,
                        "connector_id": connector_id,
                    },
                    context=context,
                )
                if approval.rejected:
                    return {
                        "status": "rejected",
                        "message": "User declined. Do not retry or suggest alternatives.",
                    }

                final_title = approval.params.get("title", title)
                final_content = approval.params.get("content", content) or ""
                final_space_id = approval.params.get("space_id", space_id)
                final_connector_id = approval.params.get("connector_id", connector_id)

                if not final_title or not final_title.strip():
                    return {"status": "error", "message": "Page title cannot be empty."}
                if not final_space_id:
                    return {"status": "error", "message": "A space must be selected."}

                from sqlalchemy.future import select

                from app.db import SearchSourceConnector, SearchSourceConnectorType

                # Always scope the lookup to this user's Confluence connectors
                # in this search space; pin the id only when one was chosen.
                conditions = [
                    SearchSourceConnector.search_space_id == search_space_id,
                    SearchSourceConnector.user_id == user_id,
                    SearchSourceConnector.connector_type
                    == SearchSourceConnectorType.CONFLUENCE_CONNECTOR,
                ]
                if final_connector_id is not None:
                    conditions.append(SearchSourceConnector.id == final_connector_id)

                lookup = await db_session.execute(
                    select(SearchSourceConnector).filter(*conditions)
                )
                connector = lookup.scalars().first()
                if not connector:
                    return {
                        "status": "error",
                        "message": (
                            "No Confluence connector found."
                            if final_connector_id is None
                            else "Selected Confluence connector is invalid."
                        ),
                    }
                actual_connector_id = (
                    connector.id if final_connector_id is None else final_connector_id
                )

                client = ConfluenceHistoryConnector(
                    session=db_session, connector_id=actual_connector_id
                )
                try:
                    api_result = await client.create_page(
                        space_id=final_space_id,
                        title=final_title,
                        body=final_content,
                    )
                except Exception as api_err:
                    err_text = str(api_err).lower()
                    if "http 403" in err_text or "status code 403" in err_text:
                        # Best effort: remember that this connector needs
                        # re-authentication. A failure here must not hide the
                        # permission result from the caller, but it is logged.
                        try:
                            connector.config = {
                                **connector.config,
                                "auth_expired": True,
                            }
                            flag_modified(connector, "config")
                            await db_session.commit()
                        except Exception:
                            logger.warning(
                                "Could not persist auth_expired flag for Confluence connector %s",
                                actual_connector_id,
                                exc_info=True,
                            )
                        return {
                            "status": "insufficient_permissions",
                            "connector_id": actual_connector_id,
                            "message": "This Confluence account needs additional permissions. Please re-authenticate in connector settings.",
                        }
                    raise
                finally:
                    # Release the client on every path (success, 403, other
                    # errors). A close failure is logged instead of raised so
                    # it cannot turn an already-created page into an error.
                    try:
                        await client.close()
                    except Exception:
                        logger.warning(
                            "Failed to close Confluence client", exc_info=True
                        )

                # Normalise once so a non-dict API response cannot raise below.
                page_info = api_result if isinstance(api_result, dict) else {}
                page_id = str(page_info.get("id", ""))
                page_links = page_info.get("_links") or {}
                page_url = ""
                if page_links.get("base") and page_links.get("webui"):
                    page_url = f"{page_links['base']}{page_links['webui']}"

                # KB sync is optional: any failure falls back to the note that
                # the scheduled sync will pick the page up later.
                kb_message_suffix = deferred_sync_note
                try:
                    from app.services.confluence import ConfluenceKBSyncService

                    kb_service = ConfluenceKBSyncService(db_session)
                    kb_result = await kb_service.sync_after_create(
                        page_id=page_id,
                        page_title=final_title,
                        space_id=final_space_id,
                        body_content=final_content,
                        connector_id=actual_connector_id,
                        search_space_id=search_space_id,
                        user_id=user_id,
                    )
                    if kb_result["status"] == "success":
                        kb_message_suffix = (
                            " Your knowledge base has also been updated."
                        )
                except Exception as kb_err:
                    logger.warning("KB sync after create failed: %s", kb_err)

                return {
                    "status": "success",
                    "page_id": page_id,
                    "page_url": page_url,
                    "message": f"Confluence page '{final_title}' created successfully.{kb_message_suffix}",
                }
        except Exception as e:
            from langgraph.errors import GraphInterrupt

            # The approval request pauses the graph by raising; let it through.
            if isinstance(e, GraphInterrupt):
                raise
            logger.error("Error creating Confluence page: %s", e, exc_info=True)
            return {
                "status": "error",
                "message": "Something went wrong while creating the page.",
            }

    return create_confluence_page

View file

@ -0,0 +1,213 @@
import logging
from typing import Any
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm.attributes import flag_modified
from app.agents.shared.tools.hitl import request_approval
from app.connectors.confluence_history import ConfluenceHistoryConnector
from app.db import async_session_maker
from app.services.confluence import ConfluenceToolMetadataService
logger = logging.getLogger(__name__)
def create_delete_confluence_page_tool(
    db_session: AsyncSession | None = None,
    search_space_id: int | None = None,
    user_id: str | None = None,
    connector_id: int | None = None,
):
    """
    Factory function to create the delete_confluence_page tool.

    The tool acquires its own short-lived ``AsyncSession`` per call via
    :data:`async_session_maker` so the closure is safe to share across
    HTTP requests by the compiled-agent cache. Capturing a per-request
    session here would surface stale/closed sessions on cache hits.

    Args:
        db_session: Reserved for registry compatibility. Per-call sessions
            are opened via :data:`async_session_maker` inside the tool body.
        search_space_id: Search space the page and connector lookups are
            scoped to. The tool returns an error result when this is ``None``.
        user_id: User the lookups are scoped to. The tool returns an error
            result when this is ``None``.
        connector_id: Accepted for signature parity with the other Confluence
            factories; the tool takes the connector from the deletion context.

    Returns:
        Configured delete_confluence_page tool
    """
    del db_session  # per-call session — see docstring

    @tool
    async def delete_confluence_page(
        page_title_or_id: str,
        delete_from_kb: bool = False,
    ) -> dict[str, Any]:
        """Delete a Confluence page.

        Use this tool when the user asks to delete or remove a Confluence page.

        Args:
            page_title_or_id: The page title or ID to identify the page.
            delete_from_kb: Whether to also remove from the knowledge base.

        Returns:
            Dictionary with status, message, and deleted_from_kb.

        IMPORTANT:
            - If status is "rejected", do NOT retry.
            - If status is "not_found", relay the message to the user.
            - If status is "insufficient_permissions", inform user to re-authenticate.
        """
        logger.info(
            "delete_confluence_page called: page_title_or_id='%s'", page_title_or_id
        )

        if search_space_id is None or user_id is None:
            return {
                "status": "error",
                "message": "Confluence tool not properly configured.",
            }

        try:
            async with async_session_maker() as db_session:
                metadata_service = ConfluenceToolMetadataService(db_session)
                context = await metadata_service.get_deletion_context(
                    search_space_id, user_id, page_title_or_id
                )

                if "error" in context:
                    error_msg = context["error"]
                    if context.get("auth_expired"):
                        return {
                            "status": "auth_error",
                            "message": error_msg,
                            "connector_id": context.get("connector_id"),
                            "connector_type": "confluence",
                        }
                    if "not found" in error_msg.lower():
                        return {"status": "not_found", "message": error_msg}
                    return {"status": "error", "message": error_msg}

                page_data = context["page"]
                page_id = page_data["page_id"]
                page_title = page_data.get("page_title", "")
                document_id = page_data["document_id"]
                connector_id_from_context = context.get("account", {}).get("id")

                # Human-in-the-loop gate: the reviewer may edit any parameter,
                # so everything below reads the approved values.
                approval = request_approval(
                    action_type="confluence_page_deletion",
                    tool_name="delete_confluence_page",
                    params={
                        "page_id": page_id,
                        "connector_id": connector_id_from_context,
                        "delete_from_kb": delete_from_kb,
                    },
                    context=context,
                )
                if approval.rejected:
                    return {
                        "status": "rejected",
                        "message": "User declined. Do not retry or suggest alternatives.",
                    }

                final_page_id = approval.params.get("page_id", page_id)
                final_connector_id = approval.params.get(
                    "connector_id", connector_id_from_context
                )
                final_delete_from_kb = approval.params.get(
                    "delete_from_kb", delete_from_kb
                )

                from sqlalchemy.future import select

                from app.db import SearchSourceConnector, SearchSourceConnectorType

                if not final_connector_id:
                    return {
                        "status": "error",
                        "message": "No connector found for this page.",
                    }

                # Re-validate ownership: the connector must belong to this
                # user and search space and be a Confluence connector.
                lookup = await db_session.execute(
                    select(SearchSourceConnector).filter(
                        SearchSourceConnector.id == final_connector_id,
                        SearchSourceConnector.search_space_id == search_space_id,
                        SearchSourceConnector.user_id == user_id,
                        SearchSourceConnector.connector_type
                        == SearchSourceConnectorType.CONFLUENCE_CONNECTOR,
                    )
                )
                connector = lookup.scalars().first()
                if not connector:
                    return {
                        "status": "error",
                        "message": "Selected Confluence connector is invalid.",
                    }

                client = ConfluenceHistoryConnector(
                    session=db_session, connector_id=final_connector_id
                )
                try:
                    await client.delete_page(final_page_id)
                except Exception as api_err:
                    err_text = str(api_err).lower()
                    if "http 403" in err_text or "status code 403" in err_text:
                        # Best effort: remember that this connector needs
                        # re-authentication. A failure here must not hide the
                        # permission result from the caller, but it is logged.
                        try:
                            connector.config = {
                                **connector.config,
                                "auth_expired": True,
                            }
                            flag_modified(connector, "config")
                            await db_session.commit()
                        except Exception:
                            logger.warning(
                                "Could not persist auth_expired flag for Confluence connector %s",
                                final_connector_id,
                                exc_info=True,
                            )
                        return {
                            "status": "insufficient_permissions",
                            "connector_id": final_connector_id,
                            "message": "This Confluence account needs additional permissions. Please re-authenticate in connector settings.",
                        }
                    raise
                finally:
                    # Release the client on every path (success, 403, other
                    # errors). A close failure is logged instead of raised so
                    # it cannot turn a completed deletion into an error.
                    try:
                        await client.close()
                    except Exception:
                        logger.warning(
                            "Failed to close Confluence client", exc_info=True
                        )

                # Optional knowledge-base cleanup; the remote page is already
                # gone, so a failure here only affects the reported flag.
                deleted_from_kb = False
                if final_delete_from_kb and document_id:
                    try:
                        from app.db import Document

                        doc_result = await db_session.execute(
                            select(Document).filter(Document.id == document_id)
                        )
                        document = doc_result.scalars().first()
                        if document:
                            await db_session.delete(document)
                            await db_session.commit()
                            deleted_from_kb = True
                    except Exception as e:
                        logger.error("Failed to delete document from KB: %s", e)
                        await db_session.rollback()

                message = f"Confluence page '{page_title}' deleted successfully."
                if deleted_from_kb:
                    message += " Also removed from the knowledge base."

                return {
                    "status": "success",
                    "page_id": final_page_id,
                    "deleted_from_kb": deleted_from_kb,
                    "message": message,
                }
        except Exception as e:
            from langgraph.errors import GraphInterrupt

            # The approval request pauses the graph by raising; let it through.
            if isinstance(e, GraphInterrupt):
                raise
            logger.error("Error deleting Confluence page: %s", e, exc_info=True)
            return {
                "status": "error",
                "message": "Something went wrong while deleting the page.",
            }

    return delete_confluence_page

View file

@ -0,0 +1,240 @@
import logging
from typing import Any
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm.attributes import flag_modified
from app.agents.shared.tools.hitl import request_approval
from app.connectors.confluence_history import ConfluenceHistoryConnector
from app.db import async_session_maker
from app.services.confluence import ConfluenceToolMetadataService
logger = logging.getLogger(__name__)
def create_update_confluence_page_tool(
    db_session: AsyncSession | None = None,
    search_space_id: int | None = None,
    user_id: str | None = None,
    connector_id: int | None = None,
):
    """
    Factory function to create the update_confluence_page tool.

    The tool acquires its own short-lived ``AsyncSession`` per call via
    :data:`async_session_maker` so the closure is safe to share across
    HTTP requests by the compiled-agent cache. Capturing a per-request
    session here would surface stale/closed sessions on cache hits.

    Args:
        db_session: Reserved for registry compatibility. Per-call sessions
            are opened via :data:`async_session_maker` inside the tool body.
        search_space_id: Search space the page and connector lookups are
            scoped to. The tool returns an error result when this is ``None``.
        user_id: User the lookups are scoped to. The tool returns an error
            result when this is ``None``.
        connector_id: Accepted for signature parity with the other Confluence
            factories; the tool takes the connector from the update context.

    Returns:
        Configured update_confluence_page tool
    """
    del db_session  # per-call session — see docstring

    deferred_sync_note = " The knowledge base will be updated in the next sync."

    @tool
    async def update_confluence_page(
        page_title_or_id: str,
        new_title: str | None = None,
        new_content: str | None = None,
    ) -> dict[str, Any]:
        """Update an existing Confluence page.

        Use this tool when the user asks to modify or edit a Confluence page.

        Args:
            page_title_or_id: The page title or ID to identify the page.
            new_title: Optional new title for the page.
            new_content: Optional new HTML/storage format content.

        Returns:
            Dictionary with status and message.

        IMPORTANT:
            - If status is "rejected", do NOT retry.
            - If status is "not_found", relay the message to the user.
            - If status is "insufficient_permissions", inform user to re-authenticate.
        """
        logger.info(
            "update_confluence_page called: page_title_or_id='%s'", page_title_or_id
        )

        if search_space_id is None or user_id is None:
            return {
                "status": "error",
                "message": "Confluence tool not properly configured.",
            }

        try:
            async with async_session_maker() as db_session:
                metadata_service = ConfluenceToolMetadataService(db_session)
                context = await metadata_service.get_update_context(
                    search_space_id, user_id, page_title_or_id
                )

                if "error" in context:
                    error_msg = context["error"]
                    if context.get("auth_expired"):
                        return {
                            "status": "auth_error",
                            "message": error_msg,
                            "connector_id": context.get("connector_id"),
                            "connector_type": "confluence",
                        }
                    if "not found" in error_msg.lower():
                        return {"status": "not_found", "message": error_msg}
                    return {"status": "error", "message": error_msg}

                page_data = context["page"]
                page_id = page_data["page_id"]
                current_title = page_data["page_title"]
                current_body = page_data.get("body", "")
                current_version = page_data.get("version", 1)
                document_id = page_data.get("document_id")
                connector_id_from_context = context.get("account", {}).get("id")

                # Human-in-the-loop gate: the reviewer may edit any parameter,
                # so everything below reads the approved values.
                approval = request_approval(
                    action_type="confluence_page_update",
                    tool_name="update_confluence_page",
                    params={
                        "page_id": page_id,
                        "document_id": document_id,
                        "new_title": new_title,
                        "new_content": new_content,
                        "version": current_version,
                        "connector_id": connector_id_from_context,
                    },
                    context=context,
                )
                if approval.rejected:
                    return {
                        "status": "rejected",
                        "message": "User declined. Do not retry or suggest alternatives.",
                    }

                final_page_id = approval.params.get("page_id", page_id)
                # An empty/missing title keeps the current one; content only
                # falls back when it is None, so "" can blank the body.
                final_title = (
                    approval.params.get("new_title", new_title) or current_title
                )
                final_content = approval.params.get("new_content", new_content)
                if final_content is None:
                    final_content = current_body
                final_version = approval.params.get("version", current_version)
                final_connector_id = approval.params.get(
                    "connector_id", connector_id_from_context
                )
                final_document_id = approval.params.get("document_id", document_id)

                from sqlalchemy.future import select

                from app.db import SearchSourceConnector, SearchSourceConnectorType

                if not final_connector_id:
                    return {
                        "status": "error",
                        "message": "No connector found for this page.",
                    }

                # Re-validate ownership: the connector must belong to this
                # user and search space and be a Confluence connector.
                lookup = await db_session.execute(
                    select(SearchSourceConnector).filter(
                        SearchSourceConnector.id == final_connector_id,
                        SearchSourceConnector.search_space_id == search_space_id,
                        SearchSourceConnector.user_id == user_id,
                        SearchSourceConnector.connector_type
                        == SearchSourceConnectorType.CONFLUENCE_CONNECTOR,
                    )
                )
                connector = lookup.scalars().first()
                if not connector:
                    return {
                        "status": "error",
                        "message": "Selected Confluence connector is invalid.",
                    }

                client = ConfluenceHistoryConnector(
                    session=db_session, connector_id=final_connector_id
                )
                try:
                    api_result = await client.update_page(
                        page_id=final_page_id,
                        title=final_title,
                        body=final_content,
                        # The update is submitted as the next version number.
                        version_number=final_version + 1,
                    )
                except Exception as api_err:
                    err_text = str(api_err).lower()
                    if "http 403" in err_text or "status code 403" in err_text:
                        # Best effort: remember that this connector needs
                        # re-authentication. A failure here must not hide the
                        # permission result from the caller, but it is logged.
                        try:
                            connector.config = {
                                **connector.config,
                                "auth_expired": True,
                            }
                            flag_modified(connector, "config")
                            await db_session.commit()
                        except Exception:
                            logger.warning(
                                "Could not persist auth_expired flag for Confluence connector %s",
                                final_connector_id,
                                exc_info=True,
                            )
                        return {
                            "status": "insufficient_permissions",
                            "connector_id": final_connector_id,
                            "message": "This Confluence account needs additional permissions. Please re-authenticate in connector settings.",
                        }
                    raise
                finally:
                    # Release the client on every path (success, 403, other
                    # errors). A close failure is logged instead of raised so
                    # it cannot turn an applied update into an error.
                    try:
                        await client.close()
                    except Exception:
                        logger.warning(
                            "Failed to close Confluence client", exc_info=True
                        )

                page_links = (
                    api_result.get("_links", {}) if isinstance(api_result, dict) else {}
                )
                page_url = ""
                if page_links.get("base") and page_links.get("webui"):
                    page_url = f"{page_links['base']}{page_links['webui']}"

                # KB sync only applies when the page is already indexed as a
                # document; any failure defers to the next sync.
                kb_message_suffix = ""
                if final_document_id:
                    kb_message_suffix = deferred_sync_note
                    try:
                        from app.services.confluence import ConfluenceKBSyncService

                        kb_service = ConfluenceKBSyncService(db_session)
                        kb_result = await kb_service.sync_after_update(
                            document_id=final_document_id,
                            page_id=final_page_id,
                            user_id=user_id,
                            search_space_id=search_space_id,
                        )
                        if kb_result["status"] == "success":
                            kb_message_suffix = (
                                " Your knowledge base has also been updated."
                            )
                    except Exception as kb_err:
                        logger.warning("KB sync after update failed: %s", kb_err)

                return {
                    "status": "success",
                    "page_id": final_page_id,
                    "page_url": page_url,
                    "message": f"Confluence page '{final_title}' updated successfully.{kb_message_suffix}",
                }
        except Exception as e:
            from langgraph.errors import GraphInterrupt

            # The approval request pauses the graph by raising; let it through.
            if isinstance(e, GraphInterrupt):
                raise
            logger.error("Error updating Confluence page: %s", e, exc_info=True)
            return {
                "status": "error",
                "message": "Something went wrong while updating the page.",
            }

    return update_confluence_page

View file

@ -0,0 +1,135 @@
"""Connected-accounts discovery tool.
Lets the LLM discover which accounts are connected for a given service
(e.g. "jira", "linear", "slack") and retrieve the metadata it needs to
call action tools such as Jira's ``cloudId``.
The tool returns **only** non-sensitive fields explicitly listed in the
service's ``account_metadata_keys`` (see ``registry.py``), plus the
always-present ``display_name`` and ``connector_id``.
"""
import logging
from typing import Any
from langchain_core.tools import StructuredTool
from pydantic import BaseModel, Field
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType, async_session_maker
from app.services.mcp_oauth.registry import MCP_SERVICES
logger = logging.getLogger(__name__)
# Reverse index of MCP_SERVICES: connector_type value -> service key.
# Built once at import time; if two services declared the same
# connector_type, the one iterated last would win.
_SERVICE_KEY_BY_CONNECTOR_TYPE: dict[str, str] = {
    cfg.connector_type: key for key, cfg in MCP_SERVICES.items()
}
# Argument schema for the get_connected_accounts tool (passed as its
# ``args_schema``). The field description lists the valid service keys, i.e.
# the sorted keys of MCP_SERVICES, and is computed once when this class body
# is evaluated at import time.
class GetConnectedAccountsInput(BaseModel):
    # Key into MCP_SERVICES identifying the service to look up.
    service: str = Field(
        description=(
            "Service key to look up connected accounts for. "
            "Valid values: " + ", ".join(sorted(MCP_SERVICES.keys()))
        ),
    )
def _extract_display_name(connector: SearchSourceConnector) -> str:
    """Return a best-effort human-readable label for *connector*.

    Preference order: an explicit ``display_name`` from the connector config,
    then the connector name qualified by ``base_url``, then the connector name
    qualified by ``organization_name``, and finally the bare connector name.
    A missing (``None``) config is treated as empty.
    """
    settings = connector.config or {}

    explicit_label = settings.get("display_name")
    if explicit_label:
        return explicit_label

    # The first non-empty qualifier wins, checked in priority order.
    for qualifier_key in ("base_url", "organization_name"):
        qualifier = settings.get(qualifier_key)
        if qualifier:
            return f"{connector.name} ({qualifier})"

    return connector.name
def create_get_connected_accounts_tool(
    db_session: AsyncSession,
    search_space_id: int,
    user_id: str,
) -> StructuredTool:
    """Build the ``get_connected_accounts`` discovery tool.

    Each invocation opens its own short-lived ``AsyncSession`` through
    :data:`async_session_maker`, which keeps the returned closure safe to
    reuse across HTTP requests when the compiled agent is cached. Holding on
    to a per-request session here would surface stale/closed sessions on
    cache hits.

    Args:
        db_session: Reserved for registry compatibility. Per-call sessions
            are opened via :data:`async_session_maker` inside the tool body.
        search_space_id: Search space ID to scope account discovery to.
        user_id: User ID to scope account discovery to.

    Returns:
        Configured StructuredTool for connected-accounts discovery.
    """
    del db_session  # per-call session — see docstring

    async def _run(service: str) -> list[dict[str, Any]]:
        # Every failure is reported as a single-element list holding an
        # "error" entry, so the return shape is always a list of dicts.
        service_cfg = MCP_SERVICES.get(service)
        if not service_cfg:
            valid_keys = ", ".join(sorted(MCP_SERVICES.keys()))
            return [{"error": f"Unknown service '{service}'. Valid: {valid_keys}"}]

        try:
            wanted_type = SearchSourceConnectorType(service_cfg.connector_type)
        except ValueError:
            return [
                {"error": f"Connector type '{service_cfg.connector_type}' not found."}
            ]

        query = select(SearchSourceConnector).filter(
            SearchSourceConnector.search_space_id == search_space_id,
            SearchSourceConnector.user_id == user_id,
            SearchSourceConnector.connector_type == wanted_type,
        )

        async with async_session_maker() as session:
            rows = (await session.execute(query)).scalars().all()
            if not rows:
                return [
                    {
                        "error": f"No {service_cfg.name} accounts connected. Ask the user to connect one in settings."
                    }
                ]

            several_accounts = len(rows) > 1

            def describe(row: SearchSourceConnector) -> dict[str, Any]:
                """Build the non-sensitive description of one connector."""
                row_cfg = row.config or {}
                described: dict[str, Any] = {
                    "connector_id": row.id,
                    "display_name": _extract_display_name(row),
                    "service": service,
                }
                # A per-connector prefix is only emitted when there is more
                # than one account for the service.
                if several_accounts:
                    described["tool_prefix"] = f"{service}_{row.id}"
                # Only keys whitelisted by the service config are exposed.
                described.update(
                    {
                        meta_key: row_cfg[meta_key]
                        for meta_key in service_cfg.account_metadata_keys
                        if meta_key in row_cfg
                    }
                )
                return described

            return [describe(row) for row in rows]

    return StructuredTool(
        name="get_connected_accounts",
        description=(
            "Discover which accounts are connected for a service (e.g. jira, linear, slack, clickup, airtable). "
            "Returns display names and service-specific metadata the action tools need "
            "(e.g. Jira's cloudId). Call this BEFORE using a service's action tools when "
            "you need an account identifier or are unsure which account to use."
        ),
        coroutine=_run,
        args_schema=GetConnectedAccountsInput,
        metadata={"hitl": False},
    )

View file

@ -0,0 +1,15 @@
from app.agents.shared.tools.discord.list_channels import (
create_list_discord_channels_tool,
)
from app.agents.shared.tools.discord.read_messages import (
create_read_discord_messages_tool,
)
from app.agents.shared.tools.discord.send_message import (
create_send_discord_message_tool,
)
# Public API of the Discord tools package: one factory per operation
# (list channels / read messages / send message), imported above.
__all__ = [
    "create_list_discord_channels_tool",
    "create_read_discord_messages_tool",
    "create_send_discord_message_tool",
]

View file

@ -0,0 +1,43 @@
"""Shared auth helper for Discord agent tools (REST API, not gateway bot)."""
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.config import config
from app.db import SearchSourceConnector, SearchSourceConnectorType
from app.utils.oauth_security import TokenEncryption
# Base URL of the Discord REST API (version 10 per the path).
DISCORD_API = "https://discord.com/api/v10"
async def get_discord_connector(
    db_session: AsyncSession,
    search_space_id: int,
    user_id: str,
) -> SearchSourceConnector | None:
    """Look up a Discord connector for the given search space and user.

    Args:
        db_session: Session used to run the query.
        search_space_id: Search space the connector must belong to.
        user_id: Owner the connector must belong to.

    Returns:
        The first matching connector of type ``DISCORD_CONNECTOR``, or
        ``None`` when there is no match. No ordering is applied, so which
        row is "first" is up to the database when several match.
    """
    query = select(SearchSourceConnector).filter(
        SearchSourceConnector.search_space_id == search_space_id,
        SearchSourceConnector.user_id == user_id,
        SearchSourceConnector.connector_type
        == SearchSourceConnectorType.DISCORD_CONNECTOR,
    )
    rows = await db_session.execute(query)
    return rows.scalars().first()
def get_bot_token(connector: SearchSourceConnector) -> str:
    """Extract and decrypt the bot token from connector config.

    The stored ``bot_token`` is decrypted only when the config is flagged
    with ``_token_encrypted`` and ``config.SECRET_KEY`` is set; otherwise it
    is returned as stored. The connector's config is not modified.

    Args:
        connector: Connector whose ``config`` mapping holds ``bot_token``.
            A missing (``None``) config is treated as empty.

    Returns:
        The usable bot token.

    Raises:
        ValueError: If the config has no (non-empty) ``bot_token``.
    """
    # Treat a None config like an empty one so callers get the documented
    # ValueError rather than a TypeError.
    cfg = connector.config or {}
    token = cfg.get("bot_token")
    if token and cfg.get("_token_encrypted") and config.SECRET_KEY:
        token = TokenEncryption(config.SECRET_KEY).decrypt_token(token)
    if not token:
        raise ValueError("Discord bot token not found in connector config.")
    return token
def get_guild_id(connector: SearchSourceConnector) -> str | None:
    """Return the ``guild_id`` stored in the connector config, if any.

    Args:
        connector: Connector whose ``config`` mapping may hold ``guild_id``.

    Returns:
        The configured guild ID, or ``None`` when the key is absent or the
        connector has no config at all.
    """
    # A None config means "no guild configured", not an AttributeError.
    return (connector.config or {}).get("guild_id")

View file

@ -0,0 +1,107 @@
import logging
from typing import Any
import httpx
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import async_session_maker
from ._auth import DISCORD_API, get_bot_token, get_discord_connector, get_guild_id
logger = logging.getLogger(__name__)
def create_list_discord_channels_tool(
    db_session: AsyncSession | None = None,
    search_space_id: int | None = None,
    user_id: str | None = None,
):
    """
    Factory function to create the list_discord_channels tool.

    The tool acquires its own short-lived ``AsyncSession`` per call via
    :data:`async_session_maker` so the closure is safe to share across
    HTTP requests by the compiled-agent cache. Capturing a per-request
    session here would surface stale/closed sessions on cache hits.

    Args:
        db_session: Reserved for registry compatibility. Per-call sessions
            are opened via :data:`async_session_maker` inside the tool body.
        search_space_id: Search space the connector lookup is scoped to. The
            tool returns an error result when this is ``None``.
        user_id: User the connector lookup is scoped to. The tool returns an
            error result when this is ``None``.

    Returns:
        Configured list_discord_channels tool
    """
    del db_session  # per-call session — see docstring

    @tool
    async def list_discord_channels() -> dict[str, Any]:
        """List text channels in the connected Discord server.

        Returns:
            Dictionary with status and a list of channels (id, name).
        """
        if search_space_id is None or user_id is None:
            return {
                "status": "error",
                "message": "Discord tool not properly configured.",
            }

        try:
            # The session is only needed to load the connector. Pull out the
            # plain values and release it before the network round-trip, so
            # no database connection is held while waiting on Discord.
            async with async_session_maker() as db_session:
                connector = await get_discord_connector(
                    db_session, search_space_id, user_id
                )
                if not connector:
                    return {"status": "error", "message": "No Discord connector found."}

                guild_id = get_guild_id(connector)
                if not guild_id:
                    return {
                        "status": "error",
                        "message": "No guild ID in Discord connector config.",
                    }

                token = get_bot_token(connector)

            async with httpx.AsyncClient() as client:
                resp = await client.get(
                    f"{DISCORD_API}/guilds/{guild_id}/channels",
                    headers={"Authorization": f"Bot {token}"},
                    timeout=15.0,
                )

            if resp.status_code == 401:
                return {
                    "status": "auth_error",
                    "message": "Discord bot token is invalid.",
                    "connector_type": "discord",
                }
            if resp.status_code != 200:
                return {
                    "status": "error",
                    "message": f"Discord API error: {resp.status_code}",
                }

            # Type 0 = text channel
            channels = [
                {"id": ch["id"], "name": ch["name"]}
                for ch in resp.json()
                if ch.get("type") == 0
            ]
            return {
                "status": "success",
                "guild_id": guild_id,
                "channels": channels,
                "total": len(channels),
            }
        except Exception as e:
            from langgraph.errors import GraphInterrupt

            # Graph interrupts are control flow, not failures; let them through.
            if isinstance(e, GraphInterrupt):
                raise
            logger.error("Error listing Discord channels: %s", e, exc_info=True)
            return {"status": "error", "message": "Failed to list Discord channels."}

    return list_discord_channels

View file

@ -0,0 +1,120 @@
import logging
from typing import Any
import httpx
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import async_session_maker
from ._auth import DISCORD_API, get_bot_token, get_discord_connector
logger = logging.getLogger(__name__)
def create_read_discord_messages_tool(
    db_session: AsyncSession | None = None,
    search_space_id: int | None = None,
    user_id: str | None = None,
):
    """
    Factory function to create the read_discord_messages tool.

    The tool acquires its own short-lived ``AsyncSession`` per call via
    :data:`async_session_maker` so the closure is safe to share across
    HTTP requests by the compiled-agent cache. Capturing a per-request
    session here would surface stale/closed sessions on cache hits.

    Args:
        db_session: Reserved for registry compatibility. Per-call sessions
            are opened via :data:`async_session_maker` inside the tool body.
        search_space_id: Search space whose Discord connector is used. The
            tool returns a configuration error when this is ``None``.
        user_id: Owner of the connector. The tool returns a configuration
            error when this is ``None``.

    Returns:
        Configured read_discord_messages tool
    """
    del db_session  # per-call session — see docstring

    @tool
    async def read_discord_messages(
        channel_id: str,
        limit: int = 25,
    ) -> dict[str, Any]:
        """Read recent messages from a Discord text channel.

        Args:
            channel_id: The Discord channel ID (from list_discord_channels).
            limit: Number of messages to fetch (default 25, max 50).

        Returns:
            Dictionary with status and a list of messages including
            id, author, content, timestamp.
        """
        if search_space_id is None or user_id is None:
            return {
                "status": "error",
                "message": "Discord tool not properly configured.",
            }

        # Clamp to 1..50: the upper bound is the documented tool maximum, the
        # lower bound keeps a zero/negative value supplied by the model from
        # being forwarded to Discord and coming back as an opaque API error.
        limit = max(1, min(limit, 50))

        try:
            async with async_session_maker() as db_session:
                connector = await get_discord_connector(
                    db_session, search_space_id, user_id
                )
                if not connector:
                    return {"status": "error", "message": "No Discord connector found."}

                token = get_bot_token(connector)

                async with httpx.AsyncClient() as client:
                    resp = await client.get(
                        f"{DISCORD_API}/channels/{channel_id}/messages",
                        headers={"Authorization": f"Bot {token}"},
                        params={"limit": limit},
                        timeout=15.0,
                    )

                if resp.status_code == 401:
                    return {
                        "status": "auth_error",
                        "message": "Discord bot token is invalid.",
                        "connector_type": "discord",
                    }
                if resp.status_code == 403:
                    return {
                        "status": "error",
                        "message": "Bot lacks permission to read this channel.",
                    }
                if resp.status_code != 200:
                    return {
                        "status": "error",
                        "message": f"Discord API error: {resp.status_code}",
                    }

                # Keep only the fields the agent needs from each message.
                messages = [
                    {
                        "id": m["id"],
                        "author": m.get("author", {}).get("username", "Unknown"),
                        "content": m.get("content", ""),
                        "timestamp": m.get("timestamp", ""),
                    }
                    for m in resp.json()
                ]
                return {
                    "status": "success",
                    "channel_id": channel_id,
                    "messages": messages,
                    "total": len(messages),
                }
        except Exception as e:
            # GraphInterrupt is control flow for the agent graph and must
            # propagate instead of being reported as a tool failure.
            from langgraph.errors import GraphInterrupt

            if isinstance(e, GraphInterrupt):
                raise
            logger.error("Error reading Discord messages: %s", e, exc_info=True)
            return {"status": "error", "message": "Failed to read Discord messages."}

    return read_discord_messages

View file

@ -0,0 +1,136 @@
import logging
from typing import Any
import httpx
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.shared.tools.hitl import request_approval
from app.db import async_session_maker
from ._auth import DISCORD_API, get_bot_token, get_discord_connector
logger = logging.getLogger(__name__)
def create_send_discord_message_tool(
    db_session: AsyncSession | None = None,
    search_space_id: int | None = None,
    user_id: str | None = None,
):
    """
    Factory function to create the send_discord_message tool.

    The tool acquires its own short-lived ``AsyncSession`` per call via
    :data:`async_session_maker` so the closure is safe to share across
    HTTP requests by the compiled-agent cache. Capturing a per-request
    session here would surface stale/closed sessions on cache hits.

    Args:
        db_session: Reserved for registry compatibility. Per-call sessions
            are opened via :data:`async_session_maker` inside the tool body.
        search_space_id: Search space whose Discord connector is used. The
            tool returns a configuration error when this is ``None``.
        user_id: Owner of the connector. The tool returns a configuration
            error when this is ``None``.

    Returns:
        Configured send_discord_message tool
    """
    del db_session  # per-call session — see docstring

    @tool
    async def send_discord_message(
        channel_id: str,
        content: str,
    ) -> dict[str, Any]:
        """Send a message to a Discord text channel.

        Args:
            channel_id: The Discord channel ID (from list_discord_channels).
            content: The message text (max 2000 characters).

        Returns:
            Dictionary with status, message_id on success.

        IMPORTANT:
            - If status is "rejected", the user explicitly declined. Do NOT retry.
        """
        if search_space_id is None or user_id is None:
            return {
                "status": "error",
                "message": "Discord tool not properly configured.",
            }

        # Reject over-long drafts before bothering the user with an approval.
        if len(content) > 2000:
            return {
                "status": "error",
                "message": "Message exceeds Discord's 2000-character limit.",
            }

        try:
            async with async_session_maker() as db_session:
                connector = await get_discord_connector(
                    db_session, search_space_id, user_id
                )
                if not connector:
                    return {"status": "error", "message": "No Discord connector found."}

                result = request_approval(
                    action_type="discord_send_message",
                    tool_name="send_discord_message",
                    params={"channel_id": channel_id, "content": content},
                    context={"connector_id": connector.id},
                )
                if result.rejected:
                    return {
                        "status": "rejected",
                        "message": "User declined. Message was not sent.",
                    }

                # The approval step hands back the parameters to use, which
                # may differ from what the model proposed.
                final_content = result.params.get("content", content)
                final_channel = result.params.get("channel_id", channel_id)

                # Re-apply the length limit to the post-approval text; the
                # check above only covered the model's original draft.
                if len(final_content or "") > 2000:
                    return {
                        "status": "error",
                        "message": "Message exceeds Discord's 2000-character limit.",
                    }

                token = get_bot_token(connector)

                async with httpx.AsyncClient() as client:
                    resp = await client.post(
                        f"{DISCORD_API}/channels/{final_channel}/messages",
                        headers={
                            "Authorization": f"Bot {token}",
                            "Content-Type": "application/json",
                        },
                        json={"content": final_content},
                        timeout=15.0,
                    )

                if resp.status_code == 401:
                    return {
                        "status": "auth_error",
                        "message": "Discord bot token is invalid.",
                        "connector_type": "discord",
                    }
                if resp.status_code == 403:
                    return {
                        "status": "error",
                        "message": "Bot lacks permission to send messages in this channel.",
                    }
                if resp.status_code not in (200, 201):
                    return {
                        "status": "error",
                        "message": f"Discord API error: {resp.status_code}",
                    }

                msg_data = resp.json()
                return {
                    "status": "success",
                    "message_id": msg_data.get("id"),
                    "message": f"Message sent to channel {final_channel}.",
                }
        except Exception as e:
            # GraphInterrupt is control flow for the agent graph (raised while
            # waiting on approval) and must propagate untouched.
            from langgraph.errors import GraphInterrupt

            if isinstance(e, GraphInterrupt):
                raise
            logger.error("Error sending Discord message: %s", e, exc_info=True)
            return {"status": "error", "message": "Failed to send Discord message."}

    return send_discord_message

View file

@ -0,0 +1,11 @@
from app.agents.shared.tools.dropbox.create_file import (
create_create_dropbox_file_tool,
)
from app.agents.shared.tools.dropbox.trash_file import (
create_delete_dropbox_file_tool,
)
__all__ = [
"create_create_dropbox_file_tool",
"create_delete_dropbox_file_tool",
]

View file

@ -0,0 +1,299 @@
import logging
import os
import tempfile
from pathlib import Path
from typing import Any, Literal
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.agents.shared.tools.hitl import request_approval
from app.connectors.dropbox.client import DropboxClient
from app.db import SearchSourceConnector, SearchSourceConnectorType, async_session_maker
logger = logging.getLogger(__name__)
# MIME type of Word (.docx) documents.
DOCX_MIME = "application/vnd.openxmlformats-officedocument.wordprocessingml.document"

# Human-readable label for each file type the create tool accepts.
_FILE_TYPE_LABELS = {
    "paper": "Dropbox Paper (.paper)",
    "docx": "Word Document (.docx)",
}

# Same information as a list of value/label entries, in the shape handed to
# the approval context of the create tool.
_SUPPORTED_TYPES = [
    {"value": type_key, "label": type_label}
    for type_key, type_label in _FILE_TYPE_LABELS.items()
]
def _ensure_extension(name: str, file_type: str) -> str:
    """Return *name* with its final suffix replaced by the one for *file_type*.

    ``"paper"`` maps to ``.paper``; any other value maps to ``.docx``. Only
    the stem of the last path component of *name* is kept.
    """
    suffix = ".docx" if file_type != "paper" else ".paper"
    base = Path(name).stem
    return base + suffix
def _markdown_to_docx(markdown_text: str) -> bytes:
    """Render GitHub-flavoured markdown to DOCX via pypandoc and return the bytes.

    pypandoc writes to a file, so a temporary ``.docx`` is created, read back
    and always removed afterwards.
    """
    import pypandoc

    # mkstemp hands back an open descriptor; close it so pandoc can write the
    # file by path.
    handle, docx_path = tempfile.mkstemp(suffix=".docx")
    os.close(handle)
    try:
        pypandoc.convert_text(
            markdown_text,
            "docx",
            format="gfm",
            extra_args=["--standalone"],
            outputfile=docx_path,
        )
        return Path(docx_path).read_bytes()
    finally:
        os.unlink(docx_path)
def create_create_dropbox_file_tool(
    db_session: AsyncSession | None = None,
    search_space_id: int | None = None,
    user_id: str | None = None,
):
    """
    Factory function to create the create_dropbox_file tool.

    The tool acquires its own short-lived ``AsyncSession`` per call via
    :data:`async_session_maker` so the closure is safe to share across
    HTTP requests by the compiled-agent cache. Capturing a per-request
    session here would surface stale/closed sessions on cache hits.

    Args:
        db_session: Reserved for registry compatibility. Per-call sessions
            are opened via :data:`async_session_maker` inside the tool body.
        search_space_id: Search space whose Dropbox connectors are used. The
            tool returns a configuration error when this is ``None``.
        user_id: Owner of the connectors. The tool returns a configuration
            error when this is ``None``.

    Returns:
        Configured create_dropbox_file tool
    """
    del db_session  # per-call session — see docstring

    @tool
    async def create_dropbox_file(
        name: str,
        file_type: Literal["paper", "docx"] = "paper",
        content: str | None = None,
    ) -> dict[str, Any]:
        """Create a new document in Dropbox.

        Use this tool when the user explicitly asks to create a new document
        in Dropbox. The user MUST specify a topic before you call this tool.

        Args:
            name: The document title (without extension).
            file_type: Either "paper" (Dropbox Paper, default) or "docx" (Word document).
            content: Optional initial content as markdown.

        Returns:
            Dictionary with status, file_id, name, web_url, and message.
        """
        import asyncio

        logger.info(
            f"create_dropbox_file called: name='{name}', file_type='{file_type}'"
        )

        if search_space_id is None or user_id is None:
            return {
                "status": "error",
                "message": "Dropbox tool not properly configured.",
            }

        try:
            async with async_session_maker() as db_session:
                connector_rows = await db_session.execute(
                    select(SearchSourceConnector).filter(
                        SearchSourceConnector.search_space_id == search_space_id,
                        SearchSourceConnector.user_id == user_id,
                        SearchSourceConnector.connector_type
                        == SearchSourceConnectorType.DROPBOX_CONNECTOR,
                    )
                )
                connectors = connector_rows.scalars().all()

                if not connectors:
                    return {
                        "status": "error",
                        "message": "No Dropbox connector found. Please connect Dropbox in your workspace settings.",
                    }

                # Summary of every connected account, shown in the approval UI.
                accounts = []
                for c in connectors:
                    cfg = c.config or {}
                    accounts.append(
                        {
                            "id": c.id,
                            "name": c.name,
                            "user_email": cfg.get("user_email"),
                            "auth_expired": cfg.get("auth_expired", False),
                        }
                    )

                if all(a.get("auth_expired") for a in accounts):
                    return {
                        "status": "auth_error",
                        "message": "All connected Dropbox accounts need re-authentication.",
                        "connector_type": "dropbox",
                    }

                # Root-level folders per account so the user can pick a
                # destination; failures degrade to an empty list rather than
                # aborting the whole tool call.
                parent_folders: dict[int, list[dict[str, str]]] = {}
                for acc in accounts:
                    cid = acc["id"]
                    if acc.get("auth_expired"):
                        parent_folders[cid] = []
                        continue
                    try:
                        client = DropboxClient(session=db_session, connector_id=cid)
                        items, err = await client.list_folder("")
                        if err:
                            logger.warning(
                                "Failed to list folders for connector %s: %s", cid, err
                            )
                            parent_folders[cid] = []
                        else:
                            parent_folders[cid] = [
                                {
                                    "folder_path": item.get("path_lower", ""),
                                    "name": item["name"],
                                }
                                for item in items
                                if item.get(".tag") == "folder" and item.get("name")
                            ]
                    except Exception:
                        logger.warning(
                            "Error fetching folders for connector %s",
                            cid,
                            exc_info=True,
                        )
                        parent_folders[cid] = []

                context: dict[str, Any] = {
                    "accounts": accounts,
                    "parent_folders": parent_folders,
                    "supported_types": _SUPPORTED_TYPES,
                }

                approval = request_approval(
                    action_type="dropbox_file_creation",
                    tool_name="create_dropbox_file",
                    params={
                        "name": name,
                        "file_type": file_type,
                        "content": content,
                        "connector_id": None,
                        "parent_folder_path": None,
                    },
                    context=context,
                )

                if approval.rejected:
                    return {
                        "status": "rejected",
                        "message": "User declined. Do not retry or suggest alternatives.",
                    }

                # Parameters as returned by the approval step; they may differ
                # from what the model proposed.
                final_name = approval.params.get("name", name)
                final_file_type = approval.params.get("file_type", file_type)
                final_content = approval.params.get("content", content)
                final_connector_id = approval.params.get("connector_id")
                final_parent_folder_path = approval.params.get("parent_folder_path")

                if not final_name or not final_name.strip():
                    return {"status": "error", "message": "File name cannot be empty."}

                final_name = _ensure_extension(final_name, final_file_type)

                if final_connector_id is not None:
                    # Re-validate the chosen connector against this user and
                    # search space before using it.
                    selected = await db_session.execute(
                        select(SearchSourceConnector).filter(
                            SearchSourceConnector.id == final_connector_id,
                            SearchSourceConnector.search_space_id == search_space_id,
                            SearchSourceConnector.user_id == user_id,
                            SearchSourceConnector.connector_type
                            == SearchSourceConnectorType.DROPBOX_CONNECTOR,
                        )
                    )
                    connector = selected.scalars().first()
                else:
                    # No explicit choice: default to the first account that is
                    # not flagged auth_expired. At least one exists because the
                    # all-expired case returned above; connectors[0] is only a
                    # defensive fallback.
                    connector = next(
                        (
                            c
                            for c in connectors
                            if not (c.config or {}).get("auth_expired", False)
                        ),
                        connectors[0],
                    )

                if not connector:
                    return {
                        "status": "error",
                        "message": "Selected Dropbox connector is invalid.",
                    }

                client = DropboxClient(session=db_session, connector_id=connector.id)

                parent_path = final_parent_folder_path or ""
                file_path = (
                    f"{parent_path}/{final_name}" if parent_path else f"/{final_name}"
                )

                if final_file_type == "paper":
                    created = await client.create_paper_doc(
                        file_path, final_content or ""
                    )
                    file_id = created.get("file_id", "")
                    web_url = created.get("url", "")
                else:
                    # The markdown -> DOCX conversion is synchronous; run it in
                    # a worker thread so it does not stall the event loop.
                    docx_bytes = await asyncio.to_thread(
                        _markdown_to_docx, final_content or ""
                    )
                    created = await client.upload_file(
                        file_path, docx_bytes, mode="add", autorename=True
                    )
                    file_id = created.get("id", "")
                    web_url = ""

                logger.info(f"Dropbox file created: id={file_id}, name={final_name}")

                # Best-effort knowledge-base sync: a failure here must not turn
                # a successful file creation into a tool error.
                kb_message_suffix = ""
                try:
                    from app.services.dropbox import DropboxKBSyncService

                    kb_service = DropboxKBSyncService(db_session)
                    kb_result = await kb_service.sync_after_create(
                        file_id=file_id,
                        file_name=final_name,
                        file_path=file_path,
                        web_url=web_url,
                        content=final_content,
                        connector_id=connector.id,
                        search_space_id=search_space_id,
                        user_id=user_id,
                    )
                    if kb_result["status"] == "success":
                        kb_message_suffix = (
                            " Your knowledge base has also been updated."
                        )
                    else:
                        kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."
                except Exception as kb_err:
                    logger.warning(f"KB sync after create failed: {kb_err}")
                    kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."

                return {
                    "status": "success",
                    "file_id": file_id,
                    "name": final_name,
                    "web_url": web_url,
                    "message": f"Successfully created '{final_name}' in Dropbox.{kb_message_suffix}",
                }
        except Exception as e:
            # GraphInterrupt is control flow for the agent graph (raised while
            # waiting on approval) and must propagate untouched.
            from langgraph.errors import GraphInterrupt

            if isinstance(e, GraphInterrupt):
                raise
            logger.error(f"Error creating Dropbox file: {e}", exc_info=True)
            return {
                "status": "error",
                "message": "Something went wrong while creating the file. Please try again.",
            }

    return create_dropbox_file

View file

@ -0,0 +1,301 @@
import logging
from typing import Any
from langchain_core.tools import tool
from sqlalchemy import String, and_, cast, func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.agents.shared.tools.hitl import request_approval
from app.connectors.dropbox.client import DropboxClient
from app.db import (
Document,
DocumentType,
SearchSourceConnector,
SearchSourceConnectorType,
async_session_maker,
)
logger = logging.getLogger(__name__)
def create_delete_dropbox_file_tool(
    db_session: AsyncSession | None = None,
    search_space_id: int | None = None,
    user_id: str | None = None,
):
    """
    Build the ``delete_dropbox_file`` agent tool.

    Each invocation opens its own short-lived ``AsyncSession`` through
    :data:`async_session_maker`, so the returned closure can be reused across
    HTTP requests by the compiled-agent cache without holding on to a stale
    or closed per-request session.

    Args:
        db_session: Reserved for registry compatibility. Per-call sessions
            are opened via :data:`async_session_maker` inside the tool body.
        search_space_id: Search space the indexed document must belong to.
        user_id: Owner of the Dropbox connector.

    Returns:
        Configured delete_dropbox_file tool
    """
    del db_session  # per-call session — see docstring

    @tool
    async def delete_dropbox_file(
        file_name: str,
        delete_from_kb: bool = False,
    ) -> dict[str, Any]:
        """Delete a file from Dropbox.

        Use this tool when the user explicitly asks to delete, remove, or trash
        a file in Dropbox.

        Args:
            file_name: The exact name of the file to delete.
            delete_from_kb: Whether to also remove the file from the knowledge base.
                Default is False.

        Returns:
            Dictionary with:
            - status: "success", "rejected", "not_found", or "error"
            - file_id: Dropbox file ID (if success)
            - deleted_from_kb: whether the document was removed from the knowledge base
            - message: Result message

        IMPORTANT:
            - If status is "rejected", the user explicitly declined. Respond with a brief
              acknowledgment and do NOT retry or suggest alternatives.
            - If status is "not_found", relay the exact message to the user and ask them
              to verify the file name or check if it has been indexed.
        """
        logger.info(
            f"delete_dropbox_file called: file_name='{file_name}', delete_from_kb={delete_from_kb}"
        )

        if search_space_id is None or user_id is None:
            return {
                "status": "error",
                "message": "Dropbox tool not properly configured.",
            }

        try:
            async with async_session_maker() as db_session:

                async def find_document(name_expr):
                    # Most recently updated Dropbox document of this user and
                    # search space whose *name_expr* equals the requested
                    # name, compared case-insensitively.
                    rows = await db_session.execute(
                        select(Document)
                        .join(
                            SearchSourceConnector,
                            Document.connector_id == SearchSourceConnector.id,
                        )
                        .filter(
                            and_(
                                Document.search_space_id == search_space_id,
                                Document.document_type == DocumentType.DROPBOX_FILE,
                                name_expr == func.lower(file_name),
                                SearchSourceConnector.user_id == user_id,
                            )
                        )
                        .order_by(Document.updated_at.desc().nullslast())
                        .limit(1)
                    )
                    return rows.scalars().first()

                async def load_connector(connector_id):
                    # Dropbox connector with this id, scoped to the current
                    # user and search space; None when it does not match.
                    rows = await db_session.execute(
                        select(SearchSourceConnector).filter(
                            and_(
                                SearchSourceConnector.id == connector_id,
                                SearchSourceConnector.search_space_id
                                == search_space_id,
                                SearchSourceConnector.user_id == user_id,
                                SearchSourceConnector.connector_type
                                == SearchSourceConnectorType.DROPBOX_CONNECTOR,
                            )
                        )
                    )
                    return rows.scalars().first()

                # Match on the document title first, then fall back to the
                # file name recorded in the document metadata.
                document = await find_document(func.lower(Document.title))
                if not document:
                    document = await find_document(
                        func.lower(
                            cast(
                                Document.document_metadata["dropbox_file_name"],
                                String,
                            )
                        )
                    )

                if not document:
                    return {
                        "status": "not_found",
                        "message": (
                            f"File '{file_name}' not found in your indexed Dropbox files. "
                            "This could mean: (1) the file doesn't exist, (2) it hasn't been indexed yet, "
                            "or (3) the file name is different."
                        ),
                    }

                if not document.connector_id:
                    return {
                        "status": "error",
                        "message": "Document has no associated connector.",
                    }

                metadata = document.document_metadata or {}
                file_path = metadata.get("dropbox_path")
                file_id = metadata.get("dropbox_file_id")
                document_id = document.id

                if not file_path:
                    return {
                        "status": "error",
                        "message": "File path is missing. Please re-index the file.",
                    }

                connector = await load_connector(document.connector_id)
                if not connector:
                    return {
                        "status": "error",
                        "message": "Dropbox connector not found or access denied.",
                    }

                cfg = connector.config or {}
                if cfg.get("auth_expired"):
                    return {
                        "status": "auth_error",
                        "message": "Dropbox account needs re-authentication. Please re-authenticate in your connector settings.",
                        "connector_type": "dropbox",
                    }

                approval_context = {
                    "file": {
                        "file_id": file_id,
                        "file_path": file_path,
                        "name": file_name,
                        "document_id": document_id,
                    },
                    "account": {
                        "id": connector.id,
                        "name": connector.name,
                        "user_email": cfg.get("user_email"),
                    },
                }

                approval = request_approval(
                    action_type="dropbox_file_trash",
                    tool_name="delete_dropbox_file",
                    params={
                        "file_path": file_path,
                        "connector_id": connector.id,
                        "delete_from_kb": delete_from_kb,
                    },
                    context=approval_context,
                )

                if approval.rejected:
                    return {
                        "status": "rejected",
                        "message": "User declined. Do not retry or suggest alternatives.",
                    }

                approved = approval.params
                final_file_path = approved.get("file_path", file_path)
                final_connector_id = approved.get("connector_id", connector.id)
                final_delete_from_kb = approved.get("delete_from_kb", delete_from_kb)

                if final_connector_id == connector.id:
                    actual_connector_id = connector.id
                else:
                    # A different connector came back from approval:
                    # re-validate it before acting on its behalf.
                    validated_connector = await load_connector(final_connector_id)
                    if not validated_connector:
                        return {
                            "status": "error",
                            "message": "Selected Dropbox connector is invalid or has been disconnected.",
                        }
                    actual_connector_id = validated_connector.id

                logger.info(
                    f"Deleting Dropbox file: path='{final_file_path}', connector={actual_connector_id}"
                )

                client = DropboxClient(
                    session=db_session, connector_id=actual_connector_id
                )
                await client.delete_file(final_file_path)

                logger.info(f"Dropbox file deleted: path={final_file_path}")

                trash_result: dict[str, Any] = {
                    "status": "success",
                    "file_id": file_id,
                    "message": f"Successfully deleted '{file_name}' from Dropbox.",
                }

                # Optional knowledge-base cleanup. The Dropbox delete already
                # happened, so a failure here is reported as a warning only.
                deleted_from_kb = False
                if final_delete_from_kb and document_id:
                    try:
                        kb_rows = await db_session.execute(
                            select(Document).filter(Document.id == document_id)
                        )
                        kb_doc = kb_rows.scalars().first()
                        if kb_doc:
                            await db_session.delete(kb_doc)
                            await db_session.commit()
                            deleted_from_kb = True
                            logger.info(
                                f"Deleted document {document_id} from knowledge base"
                            )
                        else:
                            logger.warning(f"Document {document_id} not found in KB")
                    except Exception as e:
                        logger.error(f"Failed to delete document from KB: {e}")
                        await db_session.rollback()
                        trash_result["warning"] = (
                            f"File deleted, but failed to remove from knowledge base: {e!s}"
                        )

                trash_result["deleted_from_kb"] = deleted_from_kb
                if deleted_from_kb:
                    trash_result["message"] = (
                        f"{trash_result.get('message', '')} (also removed from knowledge base)"
                    )

                return trash_result
        except Exception as e:
            # GraphInterrupt is control flow for the agent graph (raised while
            # waiting on approval) and must propagate untouched.
            from langgraph.errors import GraphInterrupt

            if isinstance(e, GraphInterrupt):
                raise
            logger.error(f"Error deleting Dropbox file: {e}", exc_info=True)
            return {
                "status": "error",
                "message": "Something went wrong while deleting the file. Please try again.",
            }

    return delete_dropbox_file

View file

@ -0,0 +1,280 @@
"""
Image generation tool for the SurfSense agent.
This module provides a tool that generates images using litellm.aimage_generation()
and returns the result directly in a format the frontend Image component can render.
Config resolution:
1. Uses the search space's image_generation_config_id preference
2. Falls back to Auto mode (router load balancing) if available
3. Supports global YAML configs (negative IDs) and user DB configs (positive IDs)
"""
import hashlib
import logging
from typing import Any
from langchain_core.tools import tool
from litellm import aimage_generation
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import config
from app.db import (
ImageGeneration,
ImageGenerationConfig,
SearchSpace,
shielded_async_session,
)
from app.services.image_gen_router_service import (
IMAGE_GEN_AUTO_MODE_ID,
ImageGenRouterService,
is_image_gen_auto_mode,
)
from app.services.provider_api_base import resolve_api_base
from app.utils.signed_image_urls import generate_image_token
logger = logging.getLogger(__name__)
# Provider mapping (same as routes).
# Keys are upper-case provider identifiers; values are the prefix placed in
# front of the model name ("<prefix>/<model_name>") when building the model
# string. Providers missing from this table fall back to their lower-cased
# name in _resolve_provider_prefix.
_PROVIDER_MAP = {
    "OPENAI": "openai",
    "AZURE_OPENAI": "azure",
    "GOOGLE": "gemini",
    "VERTEX_AI": "vertex_ai",
    "BEDROCK": "bedrock",
    "RECRAFT": "recraft",
    "OPENROUTER": "openrouter",
    "XINFERENCE": "xinference",
    "NSCALE": "nscale",
}
def _resolve_provider_prefix(provider: str, custom_provider: str | None) -> str:
if custom_provider:
return custom_provider
return _PROVIDER_MAP.get(provider.upper(), provider.lower())
def _build_model_string(
    provider: str, model_name: str, custom_provider: str | None
) -> str:
    """Return ``"<prefix>/<model_name>"`` using the resolved provider prefix."""
    return "{}/{}".format(
        _resolve_provider_prefix(provider, custom_provider), model_name
    )
def _get_global_image_gen_config(config_id: int) -> dict | None:
    """Return the first global image-gen config whose ``id`` equals *config_id*.

    Returns ``None`` when no entry of ``config.GLOBAL_IMAGE_GEN_CONFIGS``
    matches.
    """
    return next(
        (
            candidate
            for candidate in config.GLOBAL_IMAGE_GEN_CONFIGS
            if candidate.get("id") == config_id
        ),
        None,
    )
def create_generate_image_tool(
    search_space_id: int,
    db_session: AsyncSession,
):
    """
    Factory function to create the generate_image tool.

    Args:
        search_space_id: The search space ID (for config resolution)
        db_session: Reserved for compatibility with the tool registry.
            The streaming task's ``AsyncSession`` is shared by every tool;
            because AsyncSession is not concurrency-safe, parallel tool calls
            would interleave flushes (e.g. podcast + image in the same step)
            and poison the transaction. This tool opens its own session.

    Returns:
        Configured generate_image tool
    """
    del db_session  # use a fresh per-call session, see below

    # Upper bound on images per call, matching the "1-4" range advertised to
    # the model in the tool docstring.
    max_images = 4

    @tool
    async def generate_image(
        prompt: str,
        n: int = 1,
    ) -> dict[str, Any]:
        """
        Generate an image from a text description using AI image models.

        Use this tool when the user asks you to create, generate, draw, or make an image.
        The generated image will be displayed directly in the chat.

        Args:
            prompt: A detailed text description of the image to generate.
                    Be specific about subject, style, colors, composition, and mood.
            n: Number of images to generate (1-4). Default: 1

        Returns:
            A dictionary containing the generated image(s) for display in the chat.
        """
        try:
            # Enforce the documented 1-4 range instead of forwarding whatever
            # count the model supplied to the provider; a missing value means 1.
            image_count = 1 if n is None else max(1, min(n, max_images))

            # Use a per-call session so concurrent tool calls don't share an
            # AsyncSession (which is not concurrency-safe). The streaming
            # task's session is shared across every tool; without isolation,
            # autoflushes from a concurrent writer poison this tool too.
            async with shielded_async_session() as session:
                result = await session.execute(
                    select(SearchSpace).filter(SearchSpace.id == search_space_id)
                )
                search_space = result.scalars().first()
                if not search_space:
                    return {"error": "Search space not found"}

                config_id = (
                    search_space.image_generation_config_id or IMAGE_GEN_AUTO_MODE_ID
                )

                # Build generation kwargs
                # NOTE: size, quality, and style are intentionally NOT passed.
                # Different models support different values for these params
                # (e.g. DALL-E 3 wants "hd"/"standard" for quality while
                # gpt-image-1 wants "high"/"medium"/"low"; size options also
                # differ). Letting the model use its own defaults avoids errors.
                gen_kwargs: dict[str, Any] = {}
                if image_count > 1:
                    gen_kwargs["n"] = image_count

                # Call litellm based on config type
                if is_image_gen_auto_mode(config_id):
                    if not ImageGenRouterService.is_initialized():
                        return {
                            "error": "No image generation models configured. "
                            "Please add an image model in Settings > Image Models."
                        }
                    response = await ImageGenRouterService.aimage_generation(
                        prompt=prompt, model="auto", **gen_kwargs
                    )
                elif config_id < 0:
                    # Negative ID = global (YAML) config
                    cfg = _get_global_image_gen_config(config_id)
                    if not cfg:
                        return {
                            "error": f"Image generation config {config_id} not found"
                        }

                    provider_prefix = _resolve_provider_prefix(
                        cfg.get("provider", ""), cfg.get("custom_provider")
                    )
                    model_string = f"{provider_prefix}/{cfg['model_name']}"
                    gen_kwargs["api_key"] = cfg.get("api_key")
                    api_base = resolve_api_base(
                        provider=cfg.get("provider"),
                        provider_prefix=provider_prefix,
                        config_api_base=cfg.get("api_base"),
                    )
                    if api_base:
                        gen_kwargs["api_base"] = api_base
                    if cfg.get("api_version"):
                        gen_kwargs["api_version"] = cfg["api_version"]
                    if cfg.get("litellm_params"):
                        gen_kwargs.update(cfg["litellm_params"])

                    response = await aimage_generation(
                        prompt=prompt, model=model_string, **gen_kwargs
                    )
                else:
                    # Positive ID = user-created ImageGenerationConfig
                    cfg_result = await session.execute(
                        select(ImageGenerationConfig).filter(
                            ImageGenerationConfig.id == config_id
                        )
                    )
                    db_cfg = cfg_result.scalars().first()
                    if not db_cfg:
                        return {
                            "error": f"Image generation config {config_id} not found"
                        }

                    provider_prefix = _resolve_provider_prefix(
                        db_cfg.provider.value, db_cfg.custom_provider
                    )
                    model_string = f"{provider_prefix}/{db_cfg.model_name}"
                    gen_kwargs["api_key"] = db_cfg.api_key
                    api_base = resolve_api_base(
                        provider=db_cfg.provider.value,
                        provider_prefix=provider_prefix,
                        config_api_base=db_cfg.api_base,
                    )
                    if api_base:
                        gen_kwargs["api_base"] = api_base
                    if db_cfg.api_version:
                        gen_kwargs["api_version"] = db_cfg.api_version
                    if db_cfg.litellm_params:
                        gen_kwargs.update(db_cfg.litellm_params)

                    response = await aimage_generation(
                        prompt=prompt, model=model_string, **gen_kwargs
                    )

                # Parse the response and store in DB
                response_dict = (
                    response.model_dump()
                    if hasattr(response, "model_dump")
                    else dict(response)
                )

                # Generate a random access token for this image
                access_token = generate_image_token()

                # Save to image_generations table for history
                db_image_gen = ImageGeneration(
                    prompt=prompt,
                    model=getattr(response, "_hidden_params", {}).get("model"),
                    n=image_count,
                    image_generation_config_id=config_id,
                    response_data=response_dict,
                    search_space_id=search_space_id,
                    access_token=access_token,
                )
                session.add(db_image_gen)
                await session.commit()
                await session.refresh(db_image_gen)
                db_image_gen_id = db_image_gen.id

            # Extract image URLs from response
            images = response_dict.get("data", [])
            if not images:
                return {"error": "No images were generated"}

            first_image = images[0]
            revised_prompt = first_image.get("revised_prompt", prompt)

            # Resolve image URL:
            # - If the API returned a URL, use it directly.
            # - If the API returned b64_json (e.g. gpt-image-1), serve the
            #   image through our backend endpoint to avoid bloating the
            #   LLM context with megabytes of base64 data.
            if first_image.get("url"):
                image_url = first_image["url"]
            elif first_image.get("b64_json"):
                backend_url = config.BACKEND_URL or "http://localhost:8000"
                image_url = (
                    f"{backend_url}/api/v1/image-generations/"
                    f"{db_image_gen_id}/image?token={access_token}"
                )
            else:
                return {"error": "No displayable image data in the response"}

            # Short identifier derived from the URL (not security-sensitive).
            image_id = f"image-{hashlib.md5(image_url.encode()).hexdigest()[:12]}"

            return {
                "id": image_id,
                "assetId": image_url,
                "src": image_url,
                "alt": revised_prompt or prompt,
                "title": "Generated Image",
                "description": revised_prompt if revised_prompt != prompt else None,
                "domain": "ai-generated",
                "ratio": "auto",
                "generated": True,
                "prompt": prompt,
                "image_count": len(images),
            }
        except Exception as e:
            logger.exception("Image generation failed in tool")
            return {
                "error": f"Image generation failed: {e!s}",
                "prompt": prompt,
            }

    return generate_image

View file

@ -0,0 +1,27 @@
from app.agents.shared.tools.gmail.create_draft import (
create_create_gmail_draft_tool,
)
from app.agents.shared.tools.gmail.read_email import (
create_read_gmail_email_tool,
)
from app.agents.shared.tools.gmail.search_emails import (
create_search_gmail_tool,
)
from app.agents.shared.tools.gmail.send_email import (
create_send_gmail_email_tool,
)
from app.agents.shared.tools.gmail.trash_email import (
create_trash_gmail_email_tool,
)
from app.agents.shared.tools.gmail.update_draft import (
create_update_gmail_draft_tool,
)
__all__ = [
"create_create_gmail_draft_tool",
"create_read_gmail_email_tool",
"create_search_gmail_tool",
"create_send_gmail_email_tool",
"create_trash_gmail_email_tool",
"create_update_gmail_draft_tool",
]

View file

@ -0,0 +1,41 @@
from typing import Any
from app.db import SearchSourceConnector
from app.services.composio_service import ComposioService
def split_recipients(value: str | None) -> list[str]:
if not value:
return []
return [recipient.strip() for recipient in value.split(",") if recipient.strip()]
def unwrap_composio_data(data: Any) -> Any:
    """Peel the ``data`` / ``response_data`` envelopes off a Composio payload.

    Non-dict input is returned unchanged. For a dict, the value under
    ``"data"`` is taken (or the dict itself when the key is absent); if that
    value is again a dict, its ``"response_data"`` entry is returned when
    present, otherwise the value itself.
    """
    if not isinstance(data, dict):
        return data
    payload = data.get("data", data)
    if not isinstance(payload, dict):
        return payload
    return payload.get("response_data", payload)
async def execute_composio_gmail_tool(
    connector: SearchSourceConnector,
    user_id: str,
    tool_name: str,
    params: dict[str, Any],
) -> tuple[Any, str | None]:
    """Execute a Composio Gmail tool for *connector* and unwrap its payload.

    Args:
        connector: Gmail connector whose ``config`` holds the
            ``composio_connected_account_id``.
        user_id: User identifier; used to build the ``surfsense_<user_id>``
            entity id passed to Composio.
        tool_name: Name of the Composio tool to execute.
        params: Parameters forwarded to the tool.

    Returns:
        ``(data, None)`` on success, where ``data`` is the unwrapped response
        payload, or ``(None, error_message)`` when the connected-account id is
        missing or the execution did not report success.
    """
    # ``config`` may be None on a connector row; treat that like an empty
    # config so the caller gets the regular "not found" error instead of an
    # AttributeError.
    cca_id = (connector.config or {}).get("composio_connected_account_id")
    if not cca_id:
        return None, "Composio connected account ID not found for this Gmail connector."

    result = await ComposioService().execute_tool(
        connected_account_id=cca_id,
        tool_name=tool_name,
        params=params,
        entity_id=f"surfsense_{user_id}",
    )
    if not result.get("success"):
        return None, result.get("error", "Unknown Composio Gmail error")

    return unwrap_composio_data(result.get("data")), None

View file

@ -0,0 +1,361 @@
import asyncio
import base64
import logging
from datetime import datetime
from email.mime.text import MIMEText
from typing import Any
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.shared.tools.hitl import request_approval
from app.db import async_session_maker
from app.services.gmail import GmailToolMetadataService
logger = logging.getLogger(__name__)
def create_create_gmail_draft_tool(
    db_session: AsyncSession | None = None,
    search_space_id: int | None = None,
    user_id: str | None = None,
):
    """Build the ``create_gmail_draft`` tool.

    The tool acquires its own short-lived ``AsyncSession`` per call via
    :data:`async_session_maker` so the closure is safe to share across
    HTTP requests by the compiled-agent cache. Capturing a per-request
    session here would surface stale/closed sessions on cache hits.

    Args:
        db_session: Reserved for registry compatibility and discarded.
            Per-call sessions are opened inside the tool body.
        search_space_id: Search space whose Gmail connector is used. The tool
            returns an error result when this is ``None``.
        user_id: Owner of the connector. The tool returns an error result
            when this is ``None``.

    Returns:
        Configured create_gmail_draft tool
    """
    del db_session  # per-call session — see docstring

    @tool
    async def create_gmail_draft(
        to: str,
        subject: str,
        body: str,
        cc: str | None = None,
        bcc: str | None = None,
    ) -> dict[str, Any]:
        """Create a draft email in Gmail.

        Use when the user asks to draft, compose, or prepare an email without
        sending it.

        Args:
            to: Recipient email address.
            subject: Email subject line.
            body: Email body content.
            cc: Optional CC recipient(s), comma-separated.
            bcc: Optional BCC recipient(s), comma-separated.

        Returns:
            Dictionary with:
            - status: "success", "rejected", or "error"
            - draft_id: Gmail draft ID (if success)
            - message: Result message

        IMPORTANT:
            - If status is "rejected", the user explicitly declined the action.
              Respond with a brief acknowledgment and do NOT retry or suggest alternatives.
            - If status is "insufficient_permissions", the connector lacks the required OAuth scope.
              Inform the user they need to re-authenticate and do NOT retry the action.

        Examples:
            - "Draft an email to alice@example.com about the meeting"
            - "Compose a reply to Bob about the project update"
        """
        logger.info(f"create_gmail_draft called: to='{to}', subject='{subject}'")

        if search_space_id is None or user_id is None:
            return {
                "status": "error",
                "message": "Gmail tool not properly configured. Please contact support.",
            }

        try:
            async with async_session_maker() as db_session:
                metadata_service = GmailToolMetadataService(db_session)
                context = await metadata_service.get_creation_context(
                    search_space_id, user_id
                )

                if "error" in context:
                    logger.error(
                        f"Failed to fetch creation context: {context['error']}"
                    )
                    return {"status": "error", "message": context["error"]}

                accounts = context.get("accounts", [])
                if accounts and all(a.get("auth_expired") for a in accounts):
                    logger.warning("All Gmail accounts have expired authentication")
                    return {
                        "status": "auth_error",
                        "message": "All connected Gmail accounts need re-authentication. Please re-authenticate in your connector settings.",
                        "connector_type": "gmail",
                    }

                logger.info(
                    f"Requesting approval for creating Gmail draft: to='{to}', subject='{subject}'"
                )
                approval = request_approval(
                    action_type="gmail_draft_creation",
                    tool_name="create_gmail_draft",
                    params={
                        "to": to,
                        "subject": subject,
                        "body": body,
                        "cc": cc,
                        "bcc": bcc,
                        "connector_id": None,
                    },
                    context=context,
                )

                if approval.rejected:
                    return {
                        "status": "rejected",
                        "message": "User declined. The draft was not created. Do not ask again or suggest alternatives.",
                    }

                # The approval result may carry edited values; fall back to
                # the original arguments for anything it does not include.
                final_to = approval.params.get("to", to)
                final_subject = approval.params.get("subject", subject)
                final_body = approval.params.get("body", body)
                final_cc = approval.params.get("cc", cc)
                final_bcc = approval.params.get("bcc", bcc)
                final_connector_id = approval.params.get("connector_id")

                from sqlalchemy.future import select

                from app.db import SearchSourceConnector, SearchSourceConnectorType

                _gmail_types = [
                    SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
                    SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
                ]

                # Always scope the lookup to this user's search space; narrow
                # it to one connector only when the approval params name one.
                conditions = [
                    SearchSourceConnector.search_space_id == search_space_id,
                    SearchSourceConnector.user_id == user_id,
                    SearchSourceConnector.connector_type.in_(_gmail_types),
                ]
                if final_connector_id is not None:
                    conditions.insert(
                        0, SearchSourceConnector.id == final_connector_id
                    )
                lookup = await db_session.execute(
                    select(SearchSourceConnector).filter(*conditions)
                )
                connector = lookup.scalars().first()
                if not connector:
                    if final_connector_id is not None:
                        not_found_message = "Selected Gmail connector is invalid or has been disconnected."
                    else:
                        not_found_message = "No Gmail connector found. Please connect Gmail in your workspace settings."
                    return {"status": "error", "message": not_found_message}
                actual_connector_id = connector.id

                logger.info(
                    f"Creating Gmail draft: to='{final_to}', subject='{final_subject}', connector={actual_connector_id}"
                )

                is_composio_gmail = (
                    connector.connector_type
                    == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
                )
                if is_composio_gmail:
                    cca_id = connector.config.get("composio_connected_account_id")
                    if not cca_id:
                        return {
                            "status": "error",
                            "message": "Composio connected account ID not found for this Gmail connector.",
                        }
                else:
                    # Use the credential builder shared with the read/search
                    # tools. It raises when tokens are stored encrypted but no
                    # SECRET_KEY is configured, instead of silently sending
                    # ciphertext to Google.
                    from app.agents.shared.tools.gmail.search_emails import (
                        _build_credentials,
                    )

                    creds = _build_credentials(connector)

                    # The raw MIME payload is only needed by the native Gmail
                    # API; the Composio path sends structured fields.
                    message = MIMEText(final_body)
                    message["to"] = final_to
                    message["subject"] = final_subject
                    if final_cc:
                        message["cc"] = final_cc
                    if final_bcc:
                        message["bcc"] = final_bcc
                    raw = base64.urlsafe_b64encode(message.as_bytes()).decode()

                try:
                    if is_composio_gmail:
                        from app.agents.shared.tools.gmail.composio_helpers import (
                            execute_composio_gmail_tool,
                            split_recipients,
                        )

                        created, error = await execute_composio_gmail_tool(
                            connector,
                            user_id,
                            "GMAIL_CREATE_EMAIL_DRAFT",
                            {
                                "user_id": "me",
                                "recipient_email": final_to,
                                "subject": final_subject,
                                "body": final_body,
                                "cc": split_recipients(final_cc),
                                "bcc": split_recipients(final_bcc),
                                "is_html": False,
                            },
                        )
                        if error:
                            raise RuntimeError(error)
                        if not isinstance(created, dict):
                            created = {}
                    else:
                        from googleapiclient.discovery import build

                        gmail_service = build("gmail", "v1", credentials=creds)
                        # The Google client is synchronous; run the request in
                        # the default executor so the event loop is not blocked.
                        created = await asyncio.get_running_loop().run_in_executor(
                            None,
                            lambda: (
                                gmail_service.users()
                                .drafts()
                                .create(userId="me", body={"message": {"raw": raw}})
                                .execute()
                            ),
                        )
                except Exception as api_err:
                    from googleapiclient.errors import HttpError

                    if isinstance(api_err, HttpError) and api_err.resp.status == 403:
                        logger.warning(
                            f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
                        )
                        # Best effort: flag the connector so later calls can
                        # short-circuit. A failure here must not mask the
                        # permissions result returned below.
                        try:
                            from sqlalchemy.orm.attributes import flag_modified

                            _res = await db_session.execute(
                                select(SearchSourceConnector).where(
                                    SearchSourceConnector.id == actual_connector_id
                                )
                            )
                            _conn = _res.scalar_one_or_none()
                            if _conn and not _conn.config.get("auth_expired"):
                                _conn.config = {**_conn.config, "auth_expired": True}
                                flag_modified(_conn, "config")
                                await db_session.commit()
                        except Exception:
                            logger.warning(
                                "Failed to persist auth_expired for connector %s",
                                actual_connector_id,
                                exc_info=True,
                            )
                        return {
                            "status": "insufficient_permissions",
                            "connector_id": actual_connector_id,
                            "message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.",
                        }
                    raise

                logger.info(f"Gmail draft created: id={created.get('id')}")

                # Knowledge-base sync is best effort; the draft already exists,
                # so a sync failure only changes the wording of the message.
                kb_message_suffix = ""
                try:
                    from app.services.gmail import GmailKBSyncService

                    kb_service = GmailKBSyncService(db_session)
                    draft_message = created.get("message", {})
                    kb_result = await kb_service.sync_after_create(
                        message_id=draft_message.get("id", ""),
                        thread_id=draft_message.get("threadId", ""),
                        subject=final_subject,
                        sender="me",
                        date_str=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                        body_text=final_body,
                        connector_id=actual_connector_id,
                        search_space_id=search_space_id,
                        user_id=user_id,
                        draft_id=created.get("id"),
                    )
                    if kb_result["status"] == "success":
                        kb_message_suffix = (
                            " Your knowledge base has also been updated."
                        )
                    else:
                        kb_message_suffix = " This draft will be added to your knowledge base in the next scheduled sync."
                except Exception as kb_err:
                    logger.warning(f"KB sync after create failed: {kb_err}")
                    kb_message_suffix = " This draft will be added to your knowledge base in the next scheduled sync."

                return {
                    "status": "success",
                    "draft_id": created.get("id"),
                    "message": f"Successfully created Gmail draft with subject '{final_subject}'.{kb_message_suffix}",
                }

        except Exception as e:
            from langgraph.errors import GraphInterrupt

            # Graph interrupts must propagate rather than be reported as a
            # tool error (presumably raised by the approval step — confirm
            # against request_approval).
            if isinstance(e, GraphInterrupt):
                raise

            logger.error(f"Error creating Gmail draft: {e}", exc_info=True)
            return {
                "status": "error",
                "message": "Something went wrong while creating the draft. Please try again.",
            }

    return create_gmail_draft

View file

@ -0,0 +1,172 @@
import logging
from typing import Any
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType, async_session_maker
logger = logging.getLogger(__name__)

# Connector types accepted as a Gmail connector by the lookup in this module:
# the Google Gmail connector and the Composio Gmail connector.
_GMAIL_TYPES = [
    SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
    SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
]
def create_read_gmail_email_tool(
    db_session: AsyncSession | None = None,
    search_space_id: int | None = None,
    user_id: str | None = None,
):
    """Build the ``read_gmail_email`` tool.

    Each invocation of the tool opens a fresh ``AsyncSession`` through
    :data:`async_session_maker`, so the returned closure can be reused across
    HTTP requests by the compiled-agent cache without holding on to a
    stale or closed per-request session.

    Args:
        db_session: Accepted only for registry compatibility and discarded.
        search_space_id: Search space whose Gmail connector is queried; the
            tool returns an error result when this is ``None``.
        user_id: Owner of the connector; the tool returns an error result
            when this is ``None``.

    Returns:
        Configured read_gmail_email tool
    """
    del db_session  # per-call session — see docstring

    @tool
    async def read_gmail_email(message_id: str) -> dict[str, Any]:
        """Read the full content of a specific Gmail email by its message ID.

        Use after search_gmail to get the complete body of an email.

        Args:
            message_id: The Gmail message ID (from search_gmail results).

        Returns:
            Dictionary with status and the full email content formatted as markdown.
        """
        if search_space_id is None or user_id is None:
            return {"status": "error", "message": "Gmail tool not properly configured."}

        # Shared "no such email" result for both connector flavours.
        not_found = {
            "status": "not_found",
            "message": f"Email with ID '{message_id}' not found.",
        }

        try:
            async with async_session_maker() as session:
                lookup = await session.execute(
                    select(SearchSourceConnector).filter(
                        SearchSourceConnector.search_space_id == search_space_id,
                        SearchSourceConnector.user_id == user_id,
                        SearchSourceConnector.connector_type.in_(_GMAIL_TYPES),
                    )
                )
                connector = lookup.scalars().first()
                if not connector:
                    return {
                        "status": "error",
                        "message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
                    }

                uses_composio = (
                    connector.connector_type
                    == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
                )
                if uses_composio:
                    account_id = connector.config.get("composio_connected_account_id")
                    if not account_id:
                        return {
                            "status": "error",
                            "message": "Composio connected account ID not found.",
                        }

                    from app.agents.shared.tools.gmail.search_emails import (
                        _format_gmail_summary,
                    )
                    from app.services.composio_service import ComposioService

                    detail, error = await ComposioService().get_gmail_message_detail(
                        connected_account_id=account_id,
                        entity_id=f"surfsense_{user_id}",
                        message_id=message_id,
                    )
                    if error:
                        return {"status": "error", "message": error}
                    if not detail:
                        return not_found

                    summary = _format_gmail_summary(detail)
                    # Prefer the full text and fall back to the snippet.
                    body_text = detail.get("messageText") or detail.get("snippet") or ""
                    content = (
                        f"# {summary['subject']}\n\n"
                        f"**From:** {summary['from']}\n"
                        f"**To:** {summary['to']}\n"
                        f"**Date:** {summary['date']}\n\n"
                        f"## Message Content\n\n"
                        f"{body_text}\n\n"
                        f"## Message Details\n\n"
                        f"- **Message ID:** {summary['message_id']}\n"
                        f"- **Thread ID:** {summary['thread_id']}\n"
                    )
                    return {
                        "status": "success",
                        "message_id": summary["message_id"] or message_id,
                        "content": content,
                    }

                # Native Google connector: talk to Gmail with stored OAuth
                # credentials.
                from app.agents.shared.tools.gmail.search_emails import (
                    _build_credentials,
                )

                credentials = _build_credentials(connector)

                from app.connectors.google_gmail_connector import GoogleGmailConnector

                gmail = GoogleGmailConnector(
                    credentials=credentials,
                    session=session,
                    user_id=user_id,
                    connector_id=connector.id,
                )

                detail, error = await gmail.get_message_details(message_id)
                if error:
                    lowered = error.lower()
                    auth_markers = ("re-authenticate", "authentication failed")
                    if any(marker in lowered for marker in auth_markers):
                        return {
                            "status": "auth_error",
                            "message": error,
                            "connector_type": "gmail",
                        }
                    return {"status": "error", "message": error}
                if not detail:
                    return not_found

                return {
                    "status": "success",
                    "message_id": message_id,
                    "content": gmail.format_message_to_markdown(detail),
                }

        except Exception as exc:
            from langgraph.errors import GraphInterrupt

            # Graph interrupts are re-raised instead of being turned into a
            # tool error result.
            if isinstance(exc, GraphInterrupt):
                raise

            logger.error("Error reading Gmail email: %s", exc, exc_info=True)
            return {
                "status": "error",
                "message": "Failed to read email. Please try again.",
            }

    return read_gmail_email

View file

@ -0,0 +1,260 @@
import logging
from datetime import datetime
from typing import Any
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType, async_session_maker
logger = logging.getLogger(__name__)

# Connector types accepted as a Gmail connector by the lookup in this module:
# the Google Gmail connector and the Composio Gmail connector.
_GMAIL_TYPES = [
    SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
    SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
]

# Module-level cache for the TokenEncryption instance; filled on first use by
# _get_token_encryption().
_token_encryption_cache: object | None = None
def _get_token_encryption():
    """Return the cached ``TokenEncryption`` instance, creating it on first use.

    The instance is stored in the module-level ``_token_encryption_cache``.

    Raises:
        RuntimeError: If ``config.SECRET_KEY`` is not set when the instance
            has to be created.
    """
    global _token_encryption_cache
    if _token_encryption_cache is not None:
        return _token_encryption_cache

    from app.config import config
    from app.utils.oauth_security import TokenEncryption

    secret_key = config.SECRET_KEY
    if not secret_key:
        raise RuntimeError("SECRET_KEY not configured for token decryption.")

    _token_encryption_cache = TokenEncryption(secret_key)
    return _token_encryption_cache
def _build_credentials(connector: SearchSourceConnector):
    """Build Google OAuth ``Credentials`` from a connector's stored config.

    Only non-Composio connectors are supported: a connector whose type is in
    ``COMPOSIO_GOOGLE_CONNECTOR_TYPES`` is rejected. When the config is marked
    ``_token_encrypted``, the ``token``, ``refresh_token`` and
    ``client_secret`` values are decrypted on a copy of the config before use;
    the connector's own config is not modified. A trailing ``"Z"`` in the
    stored ``expiry`` is removed before parsing it as an ISO datetime.

    Raises:
        ValueError: If *connector* is a Composio connector.
        RuntimeError: Propagated from ``_get_token_encryption`` when tokens are
            encrypted but no ``SECRET_KEY`` is configured.
    """
    from app.utils.google_credentials import COMPOSIO_GOOGLE_CONNECTOR_TYPES

    if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES:
        raise ValueError("Composio connectors must use Composio tool execution.")

    from google.oauth2.credentials import Credentials

    stored = dict(connector.config)
    if stored.get("_token_encrypted"):
        cipher = _get_token_encryption()
        secret_fields = ("token", "refresh_token", "client_secret")
        stored.update(
            {
                field: cipher.decrypt_token(stored[field])
                for field in secret_fields
                if stored.get(field)
            }
        )

    expiry_text = (stored.get("expiry") or "").replace("Z", "")
    expiry = datetime.fromisoformat(expiry_text) if expiry_text else None

    return Credentials(
        token=stored.get("token"),
        refresh_token=stored.get("refresh_token"),
        token_uri=stored.get("token_uri"),
        client_id=stored.get("client_id"),
        client_secret=stored.get("client_secret"),
        scopes=stored.get("scopes", []),
        expiry=expiry,
    )
def _gmail_headers(message: dict[str, Any]) -> dict[str, str]:
headers = message.get("payload", {}).get("headers", [])
return {
header.get("name", "").lower(): header.get("value", "")
for header in headers
if isinstance(header, dict)
}
def _format_gmail_summary(message: dict[str, Any]) -> dict[str, Any]:
    """Normalise a Gmail message dict into the summary shape returned to the agent.

    Top-level keys (``messageId``, ``subject``, ``sender``, ``to``,
    ``messageTimestamp``, ``messageText``) take precedence; when they are
    missing or empty, the corresponding payload header or a default is used.

    Returns:
        Dict with ``message_id``, ``thread_id``, ``subject``, ``from``, ``to``,
        ``date``, ``snippet`` and ``labels``. ``snippet`` falls back to the
        first 300 characters of ``messageText``.
    """
    headers = _gmail_headers(message)
    # ``messageText`` present with a ``None`` value would make the slice raise
    # TypeError, because ``dict.get``'s default only covers missing keys.
    snippet = message.get("snippet") or (message.get("messageText") or "")[:300]
    return {
        "message_id": message.get("id") or message.get("messageId"),
        "thread_id": message.get("threadId"),
        "subject": message.get("subject") or headers.get("subject", "No Subject"),
        "from": message.get("sender") or headers.get("from", "Unknown"),
        "to": message.get("to") or headers.get("to", ""),
        "date": message.get("messageTimestamp") or headers.get("date", ""),
        "snippet": snippet,
        "labels": message.get("labelIds", []),
    }
async def _search_composio_gmail(
    connector: SearchSourceConnector,
    user_id: str,
    query: str,
    max_results: int,
) -> dict[str, Any]:
    """Search Gmail through Composio and return summaries in the tool's result shape.

    Args:
        connector: Connector whose ``config`` holds the
            ``composio_connected_account_id``.
        user_id: Used to build the Composio entity id (``surfsense_<user_id>``).
        query: Search query forwarded to ``get_gmail_messages``.
        max_results: Result limit forwarded to ``get_gmail_messages``.

    Returns:
        ``{"status": "error", "message": ...}`` when the account id is missing
        or the service reports an error; otherwise a ``"success"`` dict with
        ``emails``, ``total`` and a ``message`` that is ``"No emails found."``
        for an empty result and ``None`` otherwise.
    """
    account_id = connector.config.get("composio_connected_account_id")
    if not account_id:
        return {
            "status": "error",
            "message": "Composio connected account ID not found.",
        }

    from app.services.composio_service import ComposioService

    fetched = await ComposioService().get_gmail_messages(
        connected_account_id=account_id,
        entity_id=f"surfsense_{user_id}",
        query=query,
        max_results=max_results,
    )
    # Only the messages and the error are used; the other two values are ignored.
    messages, _next_token, _estimate, error = fetched
    if error:
        return {"status": "error", "message": error}

    summaries = list(map(_format_gmail_summary, messages))
    return {
        "status": "success",
        "emails": summaries,
        "total": len(summaries),
        "message": None if summaries else "No emails found.",
    }
def create_search_gmail_tool(
    db_session: AsyncSession | None = None,
    search_space_id: int | None = None,
    user_id: str | None = None,
):
    """Build the ``search_gmail`` tool.

    The tool acquires its own short-lived ``AsyncSession`` per call via
    :data:`async_session_maker` so the closure is safe to share across
    HTTP requests by the compiled-agent cache. Capturing a per-request
    session here would surface stale/closed sessions on cache hits.

    Args:
        db_session: Reserved for registry compatibility and discarded.
            Per-call sessions are opened inside the tool body.
        search_space_id: Search space whose Gmail connector is queried. The
            tool returns an error result when this is ``None``.
        user_id: Owner of the connector. The tool returns an error result
            when this is ``None``.

    Returns:
        Configured search_gmail tool
    """
    del db_session  # per-call session — see docstring

    @tool
    async def search_gmail(
        query: str,
        max_results: int = 10,
    ) -> dict[str, Any]:
        """Search emails in the user's Gmail inbox using Gmail search syntax.

        Args:
            query: Gmail search query, same syntax as the Gmail search bar.
                Examples: "from:alice@example.com", "subject:meeting",
                "is:unread", "after:2024/01/01 before:2024/02/01",
                "has:attachment", "in:sent".
            max_results: Number of emails to return (default 10, max 20).

        Returns:
            Dictionary with status and a list of email summaries including
            message_id, subject, from, date, snippet.
        """
        if search_space_id is None or user_id is None:
            return {"status": "error", "message": "Gmail tool not properly configured."}

        # Clamp to 1..20: the upper bound is the documented cap, the lower
        # bound keeps a zero or negative value from reaching the backend.
        max_results = max(1, min(max_results, 20))

        try:
            async with async_session_maker() as db_session:
                result = await db_session.execute(
                    select(SearchSourceConnector).filter(
                        SearchSourceConnector.search_space_id == search_space_id,
                        SearchSourceConnector.user_id == user_id,
                        SearchSourceConnector.connector_type.in_(_GMAIL_TYPES),
                    )
                )
                connector = result.scalars().first()
                if not connector:
                    return {
                        "status": "error",
                        "message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
                    }

                if (
                    connector.connector_type
                    == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
                ):
                    return await _search_composio_gmail(
                        connector, str(user_id), query, max_results
                    )

                creds = _build_credentials(connector)

                from app.connectors.google_gmail_connector import GoogleGmailConnector

                gmail = GoogleGmailConnector(
                    credentials=creds,
                    session=db_session,
                    user_id=user_id,
                    connector_id=connector.id,
                )

                messages_list, error = await gmail.get_messages_list(
                    max_results=max_results, query=query
                )
                if error:
                    if (
                        "re-authenticate" in error.lower()
                        or "authentication failed" in error.lower()
                    ):
                        return {
                            "status": "auth_error",
                            "message": error,
                            "connector_type": "gmail",
                        }
                    return {"status": "error", "message": error}

                if not messages_list:
                    return {
                        "status": "success",
                        "emails": [],
                        "total": 0,
                        "message": "No emails found.",
                    }

                emails = []
                for msg in messages_list:
                    detail, err = await gmail.get_message_details(msg["id"])
                    # Skip messages whose details could not be loaded: either
                    # an error was reported or no detail came back. One bad
                    # message should not fail the whole search.
                    if err or not detail:
                        continue
                    headers = {
                        h["name"].lower(): h["value"]
                        for h in detail.get("payload", {}).get("headers", [])
                    }
                    emails.append(
                        {
                            "message_id": detail.get("id"),
                            "thread_id": detail.get("threadId"),
                            "subject": headers.get("subject", "No Subject"),
                            "from": headers.get("from", "Unknown"),
                            "to": headers.get("to", ""),
                            "date": headers.get("date", ""),
                            "snippet": detail.get("snippet", ""),
                            "labels": detail.get("labelIds", []),
                        }
                    )

                return {"status": "success", "emails": emails, "total": len(emails)}

        except Exception as e:
            from langgraph.errors import GraphInterrupt

            # Graph interrupts are re-raised instead of being turned into a
            # tool error result.
            if isinstance(e, GraphInterrupt):
                raise

            logger.error("Error searching Gmail: %s", e, exc_info=True)
            return {
                "status": "error",
                "message": "Failed to search Gmail. Please try again.",
            }

    return search_gmail

View file

@ -0,0 +1,363 @@
import asyncio
import base64
import logging
from datetime import datetime
from email.mime.text import MIMEText
from typing import Any
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.shared.tools.hitl import request_approval
from app.db import async_session_maker
from app.services.gmail import GmailToolMetadataService
logger = logging.getLogger(__name__)
def create_send_gmail_email_tool(
    db_session: AsyncSession | None = None,
    search_space_id: int | None = None,
    user_id: str | None = None,
):
    """Build the ``send_gmail_email`` tool.

    The tool acquires its own short-lived ``AsyncSession`` per call via
    :data:`async_session_maker` so the closure is safe to share across
    HTTP requests by the compiled-agent cache. Capturing a per-request
    session here would surface stale/closed sessions on cache hits.

    Args:
        db_session: Reserved for registry compatibility and discarded.
            Per-call sessions are opened inside the tool body.
        search_space_id: Search space whose Gmail connector is used. The tool
            returns an error result when this is ``None``.
        user_id: Owner of the connector. The tool returns an error result
            when this is ``None``.

    Returns:
        Configured send_gmail_email tool
    """
    del db_session  # per-call session — see docstring

    @tool
    async def send_gmail_email(
        to: str,
        subject: str,
        body: str,
        cc: str | None = None,
        bcc: str | None = None,
    ) -> dict[str, Any]:
        """Send an email via Gmail.

        Use when the user explicitly asks to send an email. This sends the
        email immediately - it cannot be unsent.

        Args:
            to: Recipient email address.
            subject: Email subject line.
            body: Email body content.
            cc: Optional CC recipient(s), comma-separated.
            bcc: Optional BCC recipient(s), comma-separated.

        Returns:
            Dictionary with:
            - status: "success", "rejected", or "error"
            - message_id: Gmail message ID (if success)
            - thread_id: Gmail thread ID (if success)
            - message: Result message

        IMPORTANT:
            - If status is "rejected", the user explicitly declined the action.
              Respond with a brief acknowledgment and do NOT retry or suggest alternatives.
            - If status is "insufficient_permissions", the connector lacks the required OAuth scope.
              Inform the user they need to re-authenticate and do NOT retry the action.

        Examples:
            - "Send an email to alice@example.com about the meeting"
            - "Email Bob the project update"
        """
        logger.info(f"send_gmail_email called: to='{to}', subject='{subject}'")

        if search_space_id is None or user_id is None:
            return {
                "status": "error",
                "message": "Gmail tool not properly configured. Please contact support.",
            }

        try:
            async with async_session_maker() as db_session:
                metadata_service = GmailToolMetadataService(db_session)
                context = await metadata_service.get_creation_context(
                    search_space_id, user_id
                )

                if "error" in context:
                    logger.error(
                        f"Failed to fetch creation context: {context['error']}"
                    )
                    return {"status": "error", "message": context["error"]}

                accounts = context.get("accounts", [])
                if accounts and all(a.get("auth_expired") for a in accounts):
                    logger.warning("All Gmail accounts have expired authentication")
                    return {
                        "status": "auth_error",
                        "message": "All connected Gmail accounts need re-authentication. Please re-authenticate in your connector settings.",
                        "connector_type": "gmail",
                    }

                logger.info(
                    f"Requesting approval for sending Gmail email: to='{to}', subject='{subject}'"
                )
                approval = request_approval(
                    action_type="gmail_email_send",
                    tool_name="send_gmail_email",
                    params={
                        "to": to,
                        "subject": subject,
                        "body": body,
                        "cc": cc,
                        "bcc": bcc,
                        "connector_id": None,
                    },
                    context=context,
                )

                if approval.rejected:
                    return {
                        "status": "rejected",
                        "message": "User declined. The email was not sent. Do not ask again or suggest alternatives.",
                    }

                # The approval result may carry edited values; fall back to
                # the original arguments for anything it does not include.
                final_to = approval.params.get("to", to)
                final_subject = approval.params.get("subject", subject)
                final_body = approval.params.get("body", body)
                final_cc = approval.params.get("cc", cc)
                final_bcc = approval.params.get("bcc", bcc)
                final_connector_id = approval.params.get("connector_id")

                from sqlalchemy.future import select

                from app.db import SearchSourceConnector, SearchSourceConnectorType

                _gmail_types = [
                    SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
                    SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
                ]

                # Always scope the lookup to this user's search space; narrow
                # it to one connector only when the approval params name one.
                conditions = [
                    SearchSourceConnector.search_space_id == search_space_id,
                    SearchSourceConnector.user_id == user_id,
                    SearchSourceConnector.connector_type.in_(_gmail_types),
                ]
                if final_connector_id is not None:
                    conditions.insert(
                        0, SearchSourceConnector.id == final_connector_id
                    )
                lookup = await db_session.execute(
                    select(SearchSourceConnector).filter(*conditions)
                )
                connector = lookup.scalars().first()
                if not connector:
                    if final_connector_id is not None:
                        not_found_message = "Selected Gmail connector is invalid or has been disconnected."
                    else:
                        not_found_message = "No Gmail connector found. Please connect Gmail in your workspace settings."
                    return {"status": "error", "message": not_found_message}
                actual_connector_id = connector.id

                logger.info(
                    f"Sending Gmail email: to='{final_to}', subject='{final_subject}', connector={actual_connector_id}"
                )

                is_composio_gmail = (
                    connector.connector_type
                    == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
                )
                if is_composio_gmail:
                    cca_id = connector.config.get("composio_connected_account_id")
                    if not cca_id:
                        return {
                            "status": "error",
                            "message": "Composio connected account ID not found for this Gmail connector.",
                        }
                else:
                    # Use the credential builder shared with the read/search
                    # tools. It raises when tokens are stored encrypted but no
                    # SECRET_KEY is configured, instead of silently sending
                    # ciphertext to Google.
                    from app.agents.shared.tools.gmail.search_emails import (
                        _build_credentials,
                    )

                    creds = _build_credentials(connector)

                    # The raw MIME payload is only needed by the native Gmail
                    # API; the Composio path sends structured fields.
                    message = MIMEText(final_body)
                    message["to"] = final_to
                    message["subject"] = final_subject
                    if final_cc:
                        message["cc"] = final_cc
                    if final_bcc:
                        message["bcc"] = final_bcc
                    raw = base64.urlsafe_b64encode(message.as_bytes()).decode()

                try:
                    if is_composio_gmail:
                        from app.agents.shared.tools.gmail.composio_helpers import (
                            execute_composio_gmail_tool,
                            split_recipients,
                        )

                        sent, error = await execute_composio_gmail_tool(
                            connector,
                            user_id,
                            "GMAIL_SEND_EMAIL",
                            {
                                "user_id": "me",
                                "recipient_email": final_to,
                                "subject": final_subject,
                                "body": final_body,
                                "cc": split_recipients(final_cc),
                                "bcc": split_recipients(final_bcc),
                                "is_html": False,
                            },
                        )
                        if error:
                            raise RuntimeError(error)
                        if not isinstance(sent, dict):
                            sent = {}
                    else:
                        from googleapiclient.discovery import build

                        gmail_service = build("gmail", "v1", credentials=creds)
                        # The Google client is synchronous; run the request in
                        # the default executor so the event loop is not blocked.
                        sent = await asyncio.get_running_loop().run_in_executor(
                            None,
                            lambda: (
                                gmail_service.users()
                                .messages()
                                .send(userId="me", body={"raw": raw})
                                .execute()
                            ),
                        )
                except Exception as api_err:
                    from googleapiclient.errors import HttpError

                    if isinstance(api_err, HttpError) and api_err.resp.status == 403:
                        logger.warning(
                            f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
                        )
                        # Best effort: flag the connector so later calls can
                        # short-circuit. A failure here must not mask the
                        # permissions result returned below.
                        try:
                            from sqlalchemy.orm.attributes import flag_modified

                            _res = await db_session.execute(
                                select(SearchSourceConnector).where(
                                    SearchSourceConnector.id == actual_connector_id
                                )
                            )
                            _conn = _res.scalar_one_or_none()
                            if _conn and not _conn.config.get("auth_expired"):
                                _conn.config = {**_conn.config, "auth_expired": True}
                                flag_modified(_conn, "config")
                                await db_session.commit()
                        except Exception:
                            logger.warning(
                                "Failed to persist auth_expired for connector %s",
                                actual_connector_id,
                                exc_info=True,
                            )
                        return {
                            "status": "insufficient_permissions",
                            "connector_id": actual_connector_id,
                            "message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.",
                        }
                    raise

                logger.info(
                    f"Gmail email sent: id={sent.get('id')}, threadId={sent.get('threadId')}"
                )

                # Knowledge-base sync is best effort; the email is already
                # sent, so a sync failure only changes the wording of the
                # message.
                kb_message_suffix = ""
                try:
                    from app.services.gmail import GmailKBSyncService

                    kb_service = GmailKBSyncService(db_session)
                    kb_result = await kb_service.sync_after_create(
                        message_id=sent.get("id", ""),
                        thread_id=sent.get("threadId", ""),
                        subject=final_subject,
                        sender="me",
                        date_str=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                        body_text=final_body,
                        connector_id=actual_connector_id,
                        search_space_id=search_space_id,
                        user_id=user_id,
                    )
                    if kb_result["status"] == "success":
                        kb_message_suffix = (
                            " Your knowledge base has also been updated."
                        )
                    else:
                        kb_message_suffix = " This email will be added to your knowledge base in the next scheduled sync."
                except Exception as kb_err:
                    logger.warning(f"KB sync after send failed: {kb_err}")
                    kb_message_suffix = " This email will be added to your knowledge base in the next scheduled sync."

                return {
                    "status": "success",
                    "message_id": sent.get("id"),
                    "thread_id": sent.get("threadId"),
                    "message": f"Successfully sent email to '{final_to}' with subject '{final_subject}'.{kb_message_suffix}",
                }

        except Exception as e:
            from langgraph.errors import GraphInterrupt

            # Graph interrupts must propagate rather than be reported as a
            # tool error (presumably raised by the approval step — confirm
            # against request_approval).
            if isinstance(e, GraphInterrupt):
                raise

            logger.error(f"Error sending Gmail email: {e}", exc_info=True)
            return {
                "status": "error",
                "message": "Something went wrong while sending the email. Please try again.",
            }

    return send_gmail_email

View file

@ -0,0 +1,344 @@
import asyncio
import logging
from datetime import datetime
from typing import Any
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.shared.tools.hitl import request_approval
from app.db import async_session_maker
from app.services.gmail import GmailToolMetadataService
logger = logging.getLogger(__name__)
def create_trash_gmail_email_tool(
    db_session: AsyncSession | None = None,
    search_space_id: int | None = None,
    user_id: str | None = None,
):
    """
    Factory function to create the trash_gmail_email tool.

    The tool acquires its own short-lived ``AsyncSession`` per call via
    :data:`async_session_maker` so the closure is safe to share across
    HTTP requests by the compiled-agent cache. Capturing a per-request
    session here would surface stale/closed sessions on cache hits.

    Args:
        db_session: Reserved for registry compatibility. Per-call sessions
            are opened via :data:`async_session_maker` inside the tool body.
        search_space_id: Search space used to scope the metadata lookup and
            the connector query. The tool returns an ``"error"`` result on
            every call when this is ``None``.
        user_id: Owning user, used to scope the same lookups. The tool
            returns an ``"error"`` result on every call when this is ``None``.

    Returns:
        Configured trash_gmail_email tool
    """
    del db_session  # per-call session — see docstring

    # NOTE: the docstring of the decorated function below is kept word-for-word
    # because it is part of the tool definition, not just developer
    # documentation. Besides the statuses it lists, the body can also return
    # "auth_error" and "insufficient_permissions".
    @tool
    async def trash_gmail_email(
        email_subject_or_id: str,
        delete_from_kb: bool = False,
    ) -> dict[str, Any]:
        """Move an email or draft to trash in Gmail.

        Use when the user asks to delete, remove, or trash an email or draft.

        Args:
            email_subject_or_id: The exact subject line or message ID of the
                email to trash (as it appears in the inbox).
            delete_from_kb: Whether to also remove the email from the knowledge base.
                Default is False.
                Set to True to remove from both Gmail and knowledge base.

        Returns:
            Dictionary with:
            - status: "success", "rejected", "not_found", or "error"
            - message_id: Gmail message ID (if success)
            - deleted_from_kb: whether the document was removed from the knowledge base
            - message: Result message

        IMPORTANT:
            - If status is "rejected", the user explicitly declined. Respond with a brief
              acknowledgment and do NOT retry or suggest alternatives.
            - If status is "not_found", relay the exact message to the user and ask them
              to verify the email subject or check if it has been indexed.
            - If status is "insufficient_permissions", the connector lacks the required OAuth scope.
              Inform the user they need to re-authenticate and do NOT retry this tool.

        Examples:
            - "Delete the email about 'Meeting Cancelled'"
            - "Trash the email from Bob about the project"
        """
        logger.info(
            f"trash_gmail_email called: email_subject_or_id='{email_subject_or_id}', delete_from_kb={delete_from_kb}"
        )

        # The factory can be called without scoping ids; refuse to run then.
        if search_space_id is None or user_id is None:
            return {
                "status": "error",
                "message": "Gmail tool not properly configured. Please contact support.",
            }

        try:
            async with async_session_maker() as db_session:
                # 1) Resolve the email (and its owning account) for this user.
                metadata_service = GmailToolMetadataService(db_session)
                context = await metadata_service.get_trash_context(
                    search_space_id, user_id, email_subject_or_id
                )

                if "error" in context:
                    error_msg = context["error"]
                    if "not found" in error_msg.lower():
                        logger.warning(f"Email not found: {error_msg}")
                        return {"status": "not_found", "message": error_msg}
                    logger.error(f"Failed to fetch trash context: {error_msg}")
                    return {"status": "error", "message": error_msg}

                account = context.get("account", {})
                if account.get("auth_expired"):
                    logger.warning(
                        "Gmail account %s has expired authentication",
                        account.get("id"),
                    )
                    return {
                        "status": "auth_error",
                        "message": "The Gmail account for this email needs re-authentication. Please re-authenticate in your connector settings.",
                        "connector_type": "gmail",
                    }

                email = context["email"]
                message_id = email["message_id"]
                document_id = email.get("document_id")
                connector_id_from_context = context["account"]["id"]

                if not message_id:
                    return {
                        "status": "error",
                        "message": "Message ID is missing from the indexed document. Please re-index the email and try again.",
                    }

                # 2) Human-in-the-loop approval. The values used afterwards are
                # read back from the approval result, falling back to the
                # proposed values when a key is absent.
                logger.info(
                    f"Requesting approval for trashing Gmail email: '{email_subject_or_id}' (message_id={message_id}, delete_from_kb={delete_from_kb})"
                )
                result = request_approval(
                    action_type="gmail_email_trash",
                    tool_name="trash_gmail_email",
                    params={
                        "message_id": message_id,
                        "connector_id": connector_id_from_context,
                        "delete_from_kb": delete_from_kb,
                    },
                    context=context,
                )

                if result.rejected:
                    return {
                        "status": "rejected",
                        "message": "User declined. The email was not trashed. Do not ask again or suggest alternatives.",
                    }

                final_message_id = result.params.get("message_id", message_id)
                final_connector_id = result.params.get(
                    "connector_id", connector_id_from_context
                )
                final_delete_from_kb = result.params.get(
                    "delete_from_kb", delete_from_kb
                )

                if not final_connector_id:
                    return {
                        "status": "error",
                        "message": "No connector found for this email.",
                    }

                # 3) Re-validate the connector against this user/search space,
                # since the id comes back from the approval payload.
                from sqlalchemy.future import select

                from app.db import SearchSourceConnector, SearchSourceConnectorType

                _gmail_types = [
                    SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
                    SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
                ]

                # Separate name so the approval result above is not shadowed.
                connector_result = await db_session.execute(
                    select(SearchSourceConnector).filter(
                        SearchSourceConnector.id == final_connector_id,
                        SearchSourceConnector.search_space_id == search_space_id,
                        SearchSourceConnector.user_id == user_id,
                        SearchSourceConnector.connector_type.in_(_gmail_types),
                    )
                )
                connector = connector_result.scalars().first()
                if not connector:
                    return {
                        "status": "error",
                        "message": "Selected Gmail connector is invalid or has been disconnected.",
                    }

                logger.info(
                    f"Trashing Gmail email: message_id='{final_message_id}', connector={final_connector_id}"
                )

                # 4) Prepare credentials: Composio connectors only need their
                # connected-account id; native connectors need OAuth creds.
                is_composio_gmail = (
                    connector.connector_type
                    == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
                )
                if is_composio_gmail:
                    cca_id = connector.config.get("composio_connected_account_id")
                    if not cca_id:
                        return {
                            "status": "error",
                            "message": "Composio connected account ID not found for this Gmail connector.",
                        }
                else:
                    from google.oauth2.credentials import Credentials

                    from app.config import config
                    from app.utils.oauth_security import TokenEncryption

                    # Work on a copy so decrypted secrets never land on the ORM row.
                    config_data = dict(connector.config)
                    token_encrypted = config_data.get("_token_encrypted", False)
                    if token_encrypted and config.SECRET_KEY:
                        token_encryption = TokenEncryption(config.SECRET_KEY)
                        for secret_key in ("token", "refresh_token", "client_secret"):
                            if config_data.get(secret_key):
                                config_data[secret_key] = (
                                    token_encryption.decrypt_token(
                                        config_data[secret_key]
                                    )
                                )

                    # Drop a trailing "Z" so datetime.fromisoformat accepts the value.
                    exp = config_data.get("expiry", "")
                    if exp:
                        exp = exp.replace("Z", "")

                    creds = Credentials(
                        token=config_data.get("token"),
                        refresh_token=config_data.get("refresh_token"),
                        token_uri=config_data.get("token_uri"),
                        client_id=config_data.get("client_id"),
                        client_secret=config_data.get("client_secret"),
                        scopes=config_data.get("scopes", []),
                        expiry=datetime.fromisoformat(exp) if exp else None,
                    )

                # 5) Perform the trash call.
                try:
                    if is_composio_gmail:
                        from app.agents.shared.tools.gmail.composio_helpers import (
                            execute_composio_gmail_tool,
                        )

                        _trashed, error = await execute_composio_gmail_tool(
                            connector,
                            user_id,
                            "GMAIL_MOVE_TO_TRASH",
                            {"user_id": "me", "message_id": final_message_id},
                        )
                        if error:
                            raise RuntimeError(error)
                    else:
                        from googleapiclient.discovery import build

                        gmail_service = build("gmail", "v1", credentials=creds)
                        # The Google client is synchronous; run it in the
                        # default executor. get_running_loop() is the preferred
                        # accessor inside a coroutine.
                        await asyncio.get_running_loop().run_in_executor(
                            None,
                            lambda: (
                                gmail_service.users()
                                .messages()
                                .trash(userId="me", id=final_message_id)
                                .execute()
                            ),
                        )
                except Exception as api_err:
                    from googleapiclient.errors import HttpError

                    if isinstance(api_err, HttpError) and api_err.resp.status == 403:
                        logger.warning(
                            f"Insufficient permissions for connector {connector.id}: {api_err}"
                        )
                        # Best effort: flag the connector so later calls
                        # short-circuit on auth_expired. A failure here is only
                        # logged; the permissions result is still returned.
                        try:
                            from sqlalchemy.orm.attributes import flag_modified

                            if not connector.config.get("auth_expired"):
                                connector.config = {
                                    **connector.config,
                                    "auth_expired": True,
                                }
                                flag_modified(connector, "config")
                                await db_session.commit()
                        except Exception:
                            logger.warning(
                                "Failed to persist auth_expired for connector %s",
                                connector.id,
                                exc_info=True,
                            )
                        return {
                            "status": "insufficient_permissions",
                            "connector_id": connector.id,
                            "message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.",
                        }
                    # Anything else is handled by the outer generic handler.
                    raise

                logger.info(f"Gmail email trashed: message_id={final_message_id}")

                trash_result: dict[str, Any] = {
                    "status": "success",
                    "message_id": final_message_id,
                    "message": f"Successfully moved email '{email.get('subject', email_subject_or_id)}' to trash.",
                }

                # 6) Optional knowledge-base cleanup. The email is already
                # trashed at this point, so a failure only adds a warning.
                deleted_from_kb = False
                if final_delete_from_kb and document_id:
                    try:
                        from app.db import Document

                        doc_result = await db_session.execute(
                            select(Document).filter(Document.id == document_id)
                        )
                        document = doc_result.scalars().first()
                        if document:
                            await db_session.delete(document)
                            await db_session.commit()
                            deleted_from_kb = True
                            logger.info(
                                f"Deleted document {document_id} from knowledge base"
                            )
                        else:
                            logger.warning(f"Document {document_id} not found in KB")
                    except Exception as e:
                        logger.error(f"Failed to delete document from KB: {e}")
                        await db_session.rollback()
                        trash_result["warning"] = (
                            f"Email trashed, but failed to remove from knowledge base: {e!s}"
                        )

                trash_result["deleted_from_kb"] = deleted_from_kb
                if deleted_from_kb:
                    trash_result["message"] = (
                        f"{trash_result.get('message', '')} (also removed from knowledge base)"
                    )

                return trash_result

        except Exception as e:
            from langgraph.errors import GraphInterrupt

            # GraphInterrupt must propagate instead of being reported as a tool
            # error (presumably it is how the approval step pauses the graph —
            # confirm against request_approval).
            if isinstance(e, GraphInterrupt):
                raise

            logger.error(f"Error trashing Gmail email: {e}", exc_info=True)
            return {
                "status": "error",
                "message": "Something went wrong while trashing the email. Please try again.",
            }

    return trash_gmail_email

View file

@ -0,0 +1,495 @@
import asyncio
import base64
import logging
from datetime import datetime
from email.mime.text import MIMEText
from typing import Any
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.shared.tools.hitl import request_approval
from app.db import async_session_maker
from app.services.gmail import GmailToolMetadataService
logger = logging.getLogger(__name__)
def create_update_gmail_draft_tool(
    db_session: AsyncSession | None = None,
    search_space_id: int | None = None,
    user_id: str | None = None,
):
    """
    Factory function to create the update_gmail_draft tool.

    The tool acquires its own short-lived ``AsyncSession`` per call via
    :data:`async_session_maker` so the closure is safe to share across
    HTTP requests by the compiled-agent cache. Capturing a per-request
    session here would surface stale/closed sessions on cache hits.

    Args:
        db_session: Reserved for registry compatibility. Per-call sessions
            are opened via :data:`async_session_maker` inside the tool body.
        search_space_id: Search space used to scope the metadata lookup and
            the connector query. The tool returns an ``"error"`` result on
            every call when this is ``None``.
        user_id: Owning user, used to scope the same lookups. The tool
            returns an ``"error"`` result on every call when this is ``None``.

    Returns:
        Configured update_gmail_draft tool
    """
    del db_session  # per-call session — see docstring

    # NOTE: the docstring of the decorated function below is kept word-for-word
    # because it is part of the tool definition, not just developer
    # documentation. Besides the statuses it lists, the body can also return
    # "auth_error" and "insufficient_permissions".
    @tool
    async def update_gmail_draft(
        draft_subject_or_id: str,
        body: str,
        to: str | None = None,
        subject: str | None = None,
        cc: str | None = None,
        bcc: str | None = None,
    ) -> dict[str, Any]:
        """Update an existing Gmail draft.

        Use when the user asks to modify, edit, or add content to an existing
        email draft. This replaces the draft content with the new version.
        The user will be able to review and edit the content before it is applied.

        If the user simply wants to "edit" a draft without specifying exact changes,
        generate the body yourself using your best understanding of the conversation
        context. The user will review and can freely edit the content in the approval
        card before confirming.

        IMPORTANT: This tool is ONLY for modifying Gmail draft content, NOT for
        deleting/trashing drafts (use trash_gmail_email instead), Notion pages,
        calendar events, or any other content type.

        Args:
            draft_subject_or_id: The exact subject line of the draft to update
                (as it appears in Gmail drafts).
            body: The full updated body content for the draft. Generate this
                yourself based on the user's request and conversation context.
            to: Optional new recipient email address (keeps original if omitted).
            subject: Optional new subject line (keeps original if omitted).
            cc: Optional CC recipient(s), comma-separated.
            bcc: Optional BCC recipient(s), comma-separated.

        Returns:
            Dictionary with:
            - status: "success", "rejected", "not_found", or "error"
            - draft_id: Gmail draft ID (if success)
            - message: Result message

        IMPORTANT:
            - If status is "rejected", the user explicitly declined the action.
              Respond with a brief acknowledgment and do NOT retry or suggest alternatives.
            - If status is "not_found", relay the exact message to the user and ask them
              to verify the draft subject or check if it has been indexed.
            - If status is "insufficient_permissions", the connector lacks the required OAuth scope.
              Inform the user they need to re-authenticate and do NOT retry the action.

        Examples:
            - "Update the Kurseong Plan draft with the new itinerary details"
            - "Edit my draft about the project proposal and change the recipient"
            - "Let me edit the meeting notes draft" (call with current body content so user can edit in the approval card)
        """
        logger.info(
            f"update_gmail_draft called: draft_subject_or_id='{draft_subject_or_id}'"
        )

        # The factory can be called without scoping ids; refuse to run then.
        if search_space_id is None or user_id is None:
            return {
                "status": "error",
                "message": "Gmail tool not properly configured. Please contact support.",
            }

        try:
            async with async_session_maker() as db_session:
                # 1) Resolve the draft (and its owning account) for this user.
                metadata_service = GmailToolMetadataService(db_session)
                context = await metadata_service.get_update_context(
                    search_space_id, user_id, draft_subject_or_id
                )

                if "error" in context:
                    error_msg = context["error"]
                    if "not found" in error_msg.lower():
                        logger.warning(f"Draft not found: {error_msg}")
                        return {"status": "not_found", "message": error_msg}
                    logger.error(f"Failed to fetch update context: {error_msg}")
                    return {"status": "error", "message": error_msg}

                account = context.get("account", {})
                if account.get("auth_expired"):
                    logger.warning(
                        "Gmail account %s has expired authentication",
                        account.get("id"),
                    )
                    return {
                        "status": "auth_error",
                        "message": "The Gmail account for this draft needs re-authentication. Please re-authenticate in your connector settings.",
                        "connector_type": "gmail",
                    }

                email = context["email"]
                message_id = email["message_id"]
                document_id = email.get("document_id")
                connector_id_from_context = account["id"]
                draft_id_from_context = context.get("draft_id")

                original_subject = email.get("subject", draft_subject_or_id)
                final_subject_default = subject if subject else original_subject
                # NOTE(review): the tool docstring says the original recipient
                # is kept when `to` is omitted, but the default proposed here is
                # empty and the rebuilt MIME message below then carries no "to"
                # header — confirm this is intended.
                final_to_default = to if to else ""

                # 2) Human-in-the-loop approval. The values used afterwards are
                # read back from the approval result, falling back to the
                # proposed values when a key is absent.
                logger.info(
                    f"Requesting approval for updating Gmail draft: '{original_subject}' "
                    f"(message_id={message_id}, draft_id={draft_id_from_context})"
                )
                result = request_approval(
                    action_type="gmail_draft_update",
                    tool_name="update_gmail_draft",
                    params={
                        "message_id": message_id,
                        "draft_id": draft_id_from_context,
                        "to": final_to_default,
                        "subject": final_subject_default,
                        "body": body,
                        "cc": cc,
                        "bcc": bcc,
                        "connector_id": connector_id_from_context,
                    },
                    context=context,
                )

                if result.rejected:
                    return {
                        "status": "rejected",
                        "message": "User declined. The draft was not updated. Do not ask again or suggest alternatives.",
                    }

                final_to = result.params.get("to", final_to_default)
                final_subject = result.params.get("subject", final_subject_default)
                final_body = result.params.get("body", body)
                final_cc = result.params.get("cc", cc)
                final_bcc = result.params.get("bcc", bcc)
                final_connector_id = result.params.get(
                    "connector_id", connector_id_from_context
                )
                final_draft_id = result.params.get("draft_id", draft_id_from_context)

                if not final_connector_id:
                    return {
                        "status": "error",
                        "message": "No connector found for this draft.",
                    }

                # 3) Re-validate the connector against this user/search space,
                # since the id comes back from the approval payload.
                from sqlalchemy.future import select

                from app.db import SearchSourceConnector, SearchSourceConnectorType

                _gmail_types = [
                    SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
                    SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
                ]

                # Separate name so the approval result above is not shadowed.
                connector_result = await db_session.execute(
                    select(SearchSourceConnector).filter(
                        SearchSourceConnector.id == final_connector_id,
                        SearchSourceConnector.search_space_id == search_space_id,
                        SearchSourceConnector.user_id == user_id,
                        SearchSourceConnector.connector_type.in_(_gmail_types),
                    )
                )
                connector = connector_result.scalars().first()
                if not connector:
                    return {
                        "status": "error",
                        "message": "Selected Gmail connector is invalid or has been disconnected.",
                    }

                logger.info(
                    f"Updating Gmail draft: subject='{final_subject}', connector={final_connector_id}"
                )

                # 4) Prepare credentials: Composio connectors only need their
                # connected-account id; native connectors need OAuth creds.
                is_composio_gmail = (
                    connector.connector_type
                    == SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR
                )
                if is_composio_gmail:
                    cca_id = connector.config.get("composio_connected_account_id")
                    if not cca_id:
                        return {
                            "status": "error",
                            "message": "Composio connected account ID not found for this Gmail connector.",
                        }
                else:
                    from google.oauth2.credentials import Credentials

                    from app.config import config
                    from app.utils.oauth_security import TokenEncryption

                    # Work on a copy so decrypted secrets never land on the ORM row.
                    config_data = dict(connector.config)
                    token_encrypted = config_data.get("_token_encrypted", False)
                    if token_encrypted and config.SECRET_KEY:
                        token_encryption = TokenEncryption(config.SECRET_KEY)
                        for secret_key in ("token", "refresh_token", "client_secret"):
                            if config_data.get(secret_key):
                                config_data[secret_key] = (
                                    token_encryption.decrypt_token(
                                        config_data[secret_key]
                                    )
                                )

                    # Drop a trailing "Z" so datetime.fromisoformat accepts the value.
                    exp = config_data.get("expiry", "")
                    if exp:
                        exp = exp.replace("Z", "")

                    creds = Credentials(
                        token=config_data.get("token"),
                        refresh_token=config_data.get("refresh_token"),
                        token_uri=config_data.get("token_uri"),
                        client_id=config_data.get("client_id"),
                        client_secret=config_data.get("client_secret"),
                        scopes=config_data.get("scopes", []),
                        expiry=datetime.fromisoformat(exp) if exp else None,
                    )

                # Resolve draft_id if not already available
                if not final_draft_id:
                    logger.info(
                        f"draft_id not in metadata, looking up via drafts.list for message_id={message_id}"
                    )
                    if is_composio_gmail:
                        final_draft_id = await _find_composio_draft_id_by_message(
                            connector, user_id, message_id
                        )
                    else:
                        from googleapiclient.discovery import build

                        gmail_service = build("gmail", "v1", credentials=creds)
                        final_draft_id = await _find_draft_id_by_message(
                            gmail_service, message_id
                        )

                if not final_draft_id:
                    return {
                        "status": "error",
                        "message": (
                            "Could not find this draft in Gmail. "
                            "It may have already been sent or deleted."
                        ),
                    }

                # 5) Build the replacement message. Headers are only set for
                # non-empty values; the subject is always written.
                message = MIMEText(final_body)
                if final_to:
                    message["to"] = final_to
                message["subject"] = final_subject
                if final_cc:
                    message["cc"] = final_cc
                if final_bcc:
                    message["bcc"] = final_bcc

                raw = base64.urlsafe_b64encode(message.as_bytes()).decode()

                # 6) Push the update.
                try:
                    if is_composio_gmail:
                        from app.agents.shared.tools.gmail.composio_helpers import (
                            execute_composio_gmail_tool,
                            split_recipients,
                        )

                        updated, error = await execute_composio_gmail_tool(
                            connector,
                            user_id,
                            "GMAIL_UPDATE_DRAFT",
                            {
                                "user_id": "me",
                                "draft_id": final_draft_id,
                                "recipient_email": final_to,
                                "subject": final_subject,
                                "body": final_body,
                                "cc": split_recipients(final_cc),
                                "bcc": split_recipients(final_bcc),
                                "is_html": False,
                            },
                        )
                        if error:
                            raise RuntimeError(error)
                        # Normalise so the .get() calls below are always safe.
                        if not isinstance(updated, dict):
                            updated = {}
                    else:
                        from googleapiclient.discovery import build

                        gmail_service = build("gmail", "v1", credentials=creds)
                        # The Google client is synchronous; run it in the
                        # default executor. get_running_loop() is the preferred
                        # accessor inside a coroutine.
                        updated = await asyncio.get_running_loop().run_in_executor(
                            None,
                            lambda: (
                                gmail_service.users()
                                .drafts()
                                .update(
                                    userId="me",
                                    id=final_draft_id,
                                    body={"message": {"raw": raw}},
                                )
                                .execute()
                            ),
                        )
                except Exception as api_err:
                    from googleapiclient.errors import HttpError

                    if isinstance(api_err, HttpError) and api_err.resp.status == 403:
                        logger.warning(
                            f"Insufficient permissions for connector {connector.id}: {api_err}"
                        )
                        # Best effort: flag the connector so later calls
                        # short-circuit on auth_expired. A failure here is only
                        # logged; the permissions result is still returned.
                        try:
                            from sqlalchemy.orm.attributes import flag_modified

                            if not connector.config.get("auth_expired"):
                                connector.config = {
                                    **connector.config,
                                    "auth_expired": True,
                                }
                                flag_modified(connector, "config")
                                await db_session.commit()
                        except Exception:
                            logger.warning(
                                "Failed to persist auth_expired for connector %s",
                                connector.id,
                                exc_info=True,
                            )
                        return {
                            "status": "insufficient_permissions",
                            "connector_id": connector.id,
                            "message": "This Gmail account needs additional permissions. Please re-authenticate in connector settings.",
                        }
                    if isinstance(api_err, HttpError) and api_err.resp.status == 404:
                        return {
                            "status": "error",
                            "message": "Draft no longer exists in Gmail. It may have been sent or deleted.",
                        }
                    # Anything else is handled by the outer generic handler.
                    raise

                logger.info(f"Gmail draft updated: id={updated.get('id')}")

                # 7) Mirror the change into the indexed KB document, if any.
                # The draft is already updated, so failures only change the
                # suffix appended to the success message.
                kb_message_suffix = ""
                if document_id:
                    try:
                        from sqlalchemy.orm.attributes import flag_modified

                        from app.db import Document

                        # `select` is already bound by the import in step 3.
                        doc_result = await db_session.execute(
                            select(Document).filter(Document.id == document_id)
                        )
                        document = doc_result.scalars().first()
                        if document:
                            document.source_markdown = final_body
                            document.title = final_subject
                            meta = dict(document.document_metadata or {})
                            meta["subject"] = final_subject
                            meta["draft_id"] = updated.get("id", final_draft_id)
                            updated_msg = updated.get("message", {})
                            if updated_msg.get("id"):
                                meta["message_id"] = updated_msg["id"]
                            document.document_metadata = meta
                            flag_modified(document, "document_metadata")
                            await db_session.commit()
                            kb_message_suffix = (
                                " Your knowledge base has also been updated."
                            )
                            logger.info(
                                f"KB document {document_id} updated for draft {final_draft_id}"
                            )
                        else:
                            kb_message_suffix = " This draft will be fully updated in your knowledge base in the next scheduled sync."
                    except Exception as kb_err:
                        logger.warning(f"KB update after draft edit failed: {kb_err}")
                        await db_session.rollback()
                        kb_message_suffix = " This draft will be fully updated in your knowledge base in the next scheduled sync."

                return {
                    "status": "success",
                    # The Composio response is normalised to {} when it is not
                    # a dict, so fall back to the id that was just updated
                    # rather than reporting None on success.
                    "draft_id": updated.get("id") or final_draft_id,
                    "message": f"Successfully updated Gmail draft with subject '{final_subject}'.{kb_message_suffix}",
                }

        except Exception as e:
            from langgraph.errors import GraphInterrupt

            # GraphInterrupt must propagate instead of being reported as a tool
            # error (presumably it is how the approval step pauses the graph —
            # confirm against request_approval).
            if isinstance(e, GraphInterrupt):
                raise

            logger.error(f"Error updating Gmail draft: {e}", exc_info=True)
            return {
                "status": "error",
                "message": "Something went wrong while updating the draft. Please try again.",
            }

    return update_gmail_draft
async def _find_draft_id_by_message(gmail_service: Any, message_id: str) -> str | None:
"""Look up a draft's ID by its message ID via the Gmail API."""
try:
page_token = None
while True:
kwargs: dict[str, Any] = {"userId": "me", "maxResults": 100}
if page_token:
kwargs["pageToken"] = page_token
response = await asyncio.get_event_loop().run_in_executor(
None,
lambda kwargs=kwargs: (
gmail_service.users().drafts().list(**kwargs).execute()
),
)
for draft in response.get("drafts", []):
if draft.get("message", {}).get("id") == message_id:
return draft["id"]
page_token = response.get("nextPageToken")
if not page_token:
break
return None
except Exception as e:
logger.warning(f"Failed to look up draft by message_id: {e}")
return None
async def _find_composio_draft_id_by_message(
    connector: Any, user_id: str, message_id: str
) -> str | None:
    """Look up a draft's ID by its message ID through the Composio Gmail tool.

    Issues ``GMAIL_LIST_DRAFTS`` requests (100 drafts per page), following the
    page token from each response until a draft whose ``message.id`` equals
    *message_id* is found.

    Args:
        connector: Connector passed through to ``execute_composio_gmail_tool``.
        user_id: User on whose behalf the Composio tool is executed.
        message_id: Message ID to match against each listed draft.

    Returns:
        The matching draft's ``id`` value, or ``None`` when the tool reports an
        error, returns a non-dict payload, or no listed draft matches.
    """
    from app.agents.shared.tools.gmail.composio_helpers import (
        execute_composio_gmail_tool,
    )

    next_token = ""
    while True:
        request: dict[str, Any] = {
            "user_id": "me",
            "max_results": 100,
            "verbose": False,
        }
        if next_token:
            request["page_token"] = next_token

        payload, failure = await execute_composio_gmail_tool(
            connector, user_id, "GMAIL_LIST_DRAFTS", request
        )
        if failure or not isinstance(payload, dict):
            return None

        # First draft on this page whose message id matches, if any.
        hit = next(
            (
                candidate
                for candidate in payload.get("drafts", [])
                if candidate.get("message", {}).get("id") == message_id
            ),
            None,
        )
        if hit is not None:
            return hit.get("id")

        # The response may spell the token in camelCase or snake_case.
        next_token = (
            payload.get("nextPageToken") or payload.get("next_page_token") or ""
        )
        if not next_token:
            return None

View file

@ -0,0 +1,19 @@
from app.agents.shared.tools.google_calendar.create_event import (
create_create_calendar_event_tool,
)
from app.agents.shared.tools.google_calendar.delete_event import (
create_delete_calendar_event_tool,
)
from app.agents.shared.tools.google_calendar.search_events import (
create_search_calendar_events_tool,
)
from app.agents.shared.tools.google_calendar.update_event import (
create_update_calendar_event_tool,
)
# Names exported by this package: the four Google Calendar tool factory
# functions imported above (create, delete, search and update event).
__all__ = [
    "create_create_calendar_event_tool",
    "create_delete_calendar_event_tool",
    "create_search_calendar_events_tool",
    "create_update_calendar_event_tool",
]

View file

@ -0,0 +1,382 @@
import asyncio
import logging
from datetime import datetime
from typing import Any
from google.oauth2.credentials import Credentials
from googleapiclient.discovery import build
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.shared.tools.hitl import request_approval
from app.db import async_session_maker
from app.services.google_calendar import GoogleCalendarToolMetadataService
logger = logging.getLogger(__name__)
def create_create_calendar_event_tool(
db_session: AsyncSession | None = None,
search_space_id: int | None = None,
user_id: str | None = None,
):
"""
Factory function to create the create_calendar_event tool.
The tool acquires its own short-lived ``AsyncSession`` per call via
:data:`async_session_maker` so the closure is safe to share across
HTTP requests by the compiled-agent cache. Capturing a per-request
session here would surface stale/closed sessions on cache hits.
Args:
db_session: Reserved for registry compatibility. Per-call sessions
are opened via :data:`async_session_maker` inside the tool body.
Returns:
Configured create_calendar_event tool
"""
del db_session # per-call session — see docstring
@tool
async def create_calendar_event(
summary: str,
start_datetime: str,
end_datetime: str,
description: str | None = None,
location: str | None = None,
attendees: list[str] | None = None,
) -> dict[str, Any]:
"""Create a new event on Google Calendar.
Use when the user asks to schedule, create, or add a calendar event.
Ask for event details if not provided.
Args:
summary: The event title.
start_datetime: Start time in ISO 8601 format (e.g. "2026-03-20T10:00:00").
end_datetime: End time in ISO 8601 format (e.g. "2026-03-20T11:00:00").
description: Optional event description.
location: Optional event location.
attendees: Optional list of attendee email addresses.
Returns:
Dictionary with:
- status: "success", "rejected", "auth_error", or "error"
- event_id: Google Calendar event ID (if success)
- html_link: URL to open the event (if success)
- message: Result message
IMPORTANT:
- If status is "rejected", the user explicitly declined the action.
Respond with a brief acknowledgment and do NOT retry or suggest alternatives.
Examples:
- "Schedule a meeting with John tomorrow at 10am"
- "Create a calendar event for the team standup"
"""
logger.info(
f"create_calendar_event called: summary='{summary}', start='{start_datetime}', end='{end_datetime}'"
)
if search_space_id is None or user_id is None:
return {
"status": "error",
"message": "Google Calendar tool not properly configured. Please contact support.",
}
try:
async with async_session_maker() as db_session:
metadata_service = GoogleCalendarToolMetadataService(db_session)
context = await metadata_service.get_creation_context(
search_space_id, user_id
)
if "error" in context:
logger.error(
f"Failed to fetch creation context: {context['error']}"
)
return {"status": "error", "message": context["error"]}
accounts = context.get("accounts", [])
if accounts and all(a.get("auth_expired") for a in accounts):
logger.warning(
"All Google Calendar accounts have expired authentication"
)
return {
"status": "auth_error",
"message": "All connected Google Calendar accounts need re-authentication. Please re-authenticate in your connector settings.",
"connector_type": "google_calendar",
}
logger.info(
f"Requesting approval for creating calendar event: summary='{summary}'"
)
result = request_approval(
action_type="google_calendar_event_creation",
tool_name="create_calendar_event",
params={
"summary": summary,
"start_datetime": start_datetime,
"end_datetime": end_datetime,
"description": description,
"location": location,
"attendees": attendees,
"timezone": context.get("timezone"),
"connector_id": None,
},
context=context,
)
if result.rejected:
return {
"status": "rejected",
"message": "User declined. The event was not created. Do not ask again or suggest alternatives.",
}
final_summary = result.params.get("summary", summary)
final_start_datetime = result.params.get(
"start_datetime", start_datetime
)
final_end_datetime = result.params.get("end_datetime", end_datetime)
final_description = result.params.get("description", description)
final_location = result.params.get("location", location)
final_attendees = result.params.get("attendees", attendees)
final_connector_id = result.params.get("connector_id")
if not final_summary or not final_summary.strip():
return {
"status": "error",
"message": "Event summary cannot be empty.",
}
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
_calendar_types = [
SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR,
SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
]
if final_connector_id is not None:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == final_connector_id,
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type.in_(_calendar_types),
)
)
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "Selected Google Calendar connector is invalid or has been disconnected.",
}
actual_connector_id = connector.id
else:
result = await db_session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.search_space_id == search_space_id,
SearchSourceConnector.user_id == user_id,
SearchSourceConnector.connector_type.in_(_calendar_types),
)
)
connector = result.scalars().first()
if not connector:
return {
"status": "error",
"message": "No Google Calendar connector found. Please connect Google Calendar in your workspace settings.",
}
actual_connector_id = connector.id
logger.info(
f"Creating calendar event: summary='{final_summary}', connector={actual_connector_id}"
)
is_composio_calendar = (
connector.connector_type
== SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
)
if is_composio_calendar:
cca_id = connector.config.get("composio_connected_account_id")
if not cca_id:
return {
"status": "error",
"message": "Composio connected account ID not found for this connector.",
}
else:
config_data = dict(connector.config)
from app.config import config as app_config
from app.utils.oauth_security import TokenEncryption
token_encrypted = config_data.get("_token_encrypted", False)
if token_encrypted and app_config.SECRET_KEY:
token_encryption = TokenEncryption(app_config.SECRET_KEY)
for key in ("token", "refresh_token", "client_secret"):
if config_data.get(key):
config_data[key] = token_encryption.decrypt_token(
config_data[key]
)
exp = config_data.get("expiry", "")
if exp:
exp = exp.replace("Z", "")
creds = Credentials(
token=config_data.get("token"),
refresh_token=config_data.get("refresh_token"),
token_uri=config_data.get("token_uri"),
client_id=config_data.get("client_id"),
client_secret=config_data.get("client_secret"),
scopes=config_data.get("scopes", []),
expiry=datetime.fromisoformat(exp) if exp else None,
)
tz = context.get("timezone", "UTC")
event_body: dict[str, Any] = {
"summary": final_summary,
"start": {"dateTime": final_start_datetime, "timeZone": tz},
"end": {"dateTime": final_end_datetime, "timeZone": tz},
}
if final_description:
event_body["description"] = final_description
if final_location:
event_body["location"] = final_location
if final_attendees:
event_body["attendees"] = [
{"email": e.strip()} for e in final_attendees if e.strip()
]
try:
if is_composio_calendar:
from app.services.composio_service import ComposioService
composio_params = {
"calendar_id": "primary",
"summary": final_summary,
"start_datetime": final_start_datetime,
"end_datetime": final_end_datetime,
"timezone": tz,
"attendees": final_attendees or [],
}
if final_description:
composio_params["description"] = final_description
if final_location:
composio_params["location"] = final_location
composio_result = await ComposioService().execute_tool(
connected_account_id=cca_id,
tool_name="GOOGLECALENDAR_CREATE_EVENT",
params=composio_params,
entity_id=f"surfsense_{user_id}",
)
if not composio_result.get("success"):
raise RuntimeError(
composio_result.get(
"error", "Unknown Composio Calendar error"
)
)
created = composio_result.get("data", {})
if isinstance(created, dict):
created = created.get("data", created)
if isinstance(created, dict):
created = created.get("response_data", created)
else:
service = await asyncio.get_event_loop().run_in_executor(
None, lambda: build("calendar", "v3", credentials=creds)
)
created = await asyncio.get_event_loop().run_in_executor(
None,
lambda: (
service.events()
.insert(calendarId="primary", body=event_body)
.execute()
),
)
except Exception as api_err:
from googleapiclient.errors import HttpError
if isinstance(api_err, HttpError) and api_err.resp.status == 403:
logger.warning(
f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
)
try:
from sqlalchemy.orm.attributes import flag_modified
_res = await db_session.execute(
select(SearchSourceConnector).where(
SearchSourceConnector.id == actual_connector_id
)
)
_conn = _res.scalar_one_or_none()
if _conn and not _conn.config.get("auth_expired"):
_conn.config = {**_conn.config, "auth_expired": True}
flag_modified(_conn, "config")
await db_session.commit()
except Exception:
logger.warning(
"Failed to persist auth_expired for connector %s",
actual_connector_id,
exc_info=True,
)
return {
"status": "insufficient_permissions",
"connector_id": actual_connector_id,
"message": "This Google Calendar account needs additional permissions. Please re-authenticate in connector settings.",
}
raise
logger.info(
f"Calendar event created: id={created.get('id')}, summary={created.get('summary')}"
)
kb_message_suffix = ""
try:
from app.services.google_calendar import GoogleCalendarKBSyncService
kb_service = GoogleCalendarKBSyncService(db_session)
kb_result = await kb_service.sync_after_create(
event_id=created.get("id"),
event_summary=final_summary,
calendar_id="primary",
start_time=final_start_datetime,
end_time=final_end_datetime,
location=final_location,
html_link=created.get("htmlLink"),
description=final_description,
connector_id=actual_connector_id,
search_space_id=search_space_id,
user_id=user_id,
)
if kb_result["status"] == "success":
kb_message_suffix = (
" Your knowledge base has also been updated."
)
else:
kb_message_suffix = " This event will be added to your knowledge base in the next scheduled sync."
except Exception as kb_err:
logger.warning(f"KB sync after create failed: {kb_err}")
kb_message_suffix = " This event will be added to your knowledge base in the next scheduled sync."
return {
"status": "success",
"event_id": created.get("id"),
"html_link": created.get("htmlLink"),
"message": f"Successfully created '{final_summary}' on Google Calendar.{kb_message_suffix}",
}
except Exception as e:
from langgraph.errors import GraphInterrupt
if isinstance(e, GraphInterrupt):
raise
logger.error(f"Error creating calendar event: {e}", exc_info=True)
return {
"status": "error",
"message": "Something went wrong while creating the event. Please try again.",
}
return create_calendar_event

View file

@ -0,0 +1,340 @@
import asyncio
import logging
from datetime import datetime
from typing import Any
from google.oauth2.credentials import Credentials
from googleapiclient.discovery import build
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.shared.tools.hitl import request_approval
from app.db import async_session_maker
from app.services.google_calendar import GoogleCalendarToolMetadataService
logger = logging.getLogger(__name__)
def create_delete_calendar_event_tool(
    db_session: AsyncSession | None = None,
    search_space_id: int | None = None,
    user_id: str | None = None,
):
    """
    Factory function to create the delete_calendar_event tool.

    The tool acquires its own short-lived ``AsyncSession`` per call via
    :data:`async_session_maker` so the closure is safe to share across
    HTTP requests by the compiled-agent cache. Capturing a per-request
    session here would surface stale/closed sessions on cache hits.

    Args:
        db_session: Reserved for registry compatibility. Per-call sessions
            are opened via :data:`async_session_maker` inside the tool body.
        search_space_id: Search space used to resolve the event and to scope
            the connector lookup. The tool returns an error result when None.
        user_id: Owner used to scope the connector lookup and to build the
            Composio entity id. The tool returns an error result when None.

    Returns:
        Configured delete_calendar_event tool
    """
    del db_session  # per-call session — see docstring

    @tool
    async def delete_calendar_event(
        event_title_or_id: str,
        delete_from_kb: bool = False,
    ) -> dict[str, Any]:
        """Delete a Google Calendar event.

        Use when the user asks to delete, remove, or cancel a calendar event.

        Args:
            event_title_or_id: The exact title or event ID of the event to delete.
            delete_from_kb: Whether to also remove the event from the knowledge base.
                          Default is False.
                          Set to True to remove from both Google Calendar and knowledge base.

        Returns:
            Dictionary with:
            - status: "success", "rejected", "not_found", "auth_error", or "error"
            - event_id: Google Calendar event ID (if success)
            - deleted_from_kb: whether the document was removed from the knowledge base
            - message: Result message

            IMPORTANT:
            - If status is "rejected", the user explicitly declined. Respond with a brief
              acknowledgment and do NOT retry or suggest alternatives.
            - If status is "not_found", relay the exact message to the user and ask them
              to verify the event name or check if it has been indexed.

        Examples:
            - "Delete the team standup event"
            - "Cancel my dentist appointment on Friday"
        """
        logger.info(
            f"delete_calendar_event called: event_ref='{event_title_or_id}', delete_from_kb={delete_from_kb}"
        )

        if search_space_id is None or user_id is None:
            return {
                "status": "error",
                "message": "Google Calendar tool not properly configured. Please contact support.",
            }

        try:
            async with async_session_maker() as db_session:
                # Resolve the event reference to its stored metadata and account.
                metadata_service = GoogleCalendarToolMetadataService(db_session)
                context = await metadata_service.get_deletion_context(
                    search_space_id, user_id, event_title_or_id
                )

                if "error" in context:
                    error_msg = context["error"]
                    if "not found" in error_msg.lower():
                        logger.warning(f"Event not found: {error_msg}")
                        return {"status": "not_found", "message": error_msg}
                    logger.error(f"Failed to fetch deletion context: {error_msg}")
                    return {"status": "error", "message": error_msg}

                account = context.get("account", {})
                if account.get("auth_expired"):
                    logger.warning(
                        "Google Calendar account %s has expired authentication",
                        account.get("id"),
                    )
                    return {
                        "status": "auth_error",
                        "message": "The Google Calendar account for this event needs re-authentication. Please re-authenticate in your connector settings.",
                        "connector_type": "google_calendar",
                    }

                event = context["event"]
                event_id = event["event_id"]
                document_id = event.get("document_id")
                connector_id_from_context = context["account"]["id"]

                if not event_id:
                    return {
                        "status": "error",
                        "message": "Event ID is missing from the indexed document. Please re-index the event and try again.",
                    }

                logger.info(
                    f"Requesting approval for deleting calendar event: '{event_title_or_id}' (event_id={event_id}, delete_from_kb={delete_from_kb})"
                )
                # Human-in-the-loop gate: the reviewer may reject or edit params.
                approval = request_approval(
                    action_type="google_calendar_event_deletion",
                    tool_name="delete_calendar_event",
                    params={
                        "event_id": event_id,
                        "connector_id": connector_id_from_context,
                        "delete_from_kb": delete_from_kb,
                    },
                    context=context,
                )

                if approval.rejected:
                    return {
                        "status": "rejected",
                        "message": "User declined. The event was not deleted. Do not ask again or suggest alternatives.",
                    }

                # Prefer the (possibly edited) approved params over the originals.
                final_event_id = approval.params.get("event_id", event_id)
                final_connector_id = approval.params.get(
                    "connector_id", connector_id_from_context
                )
                final_delete_from_kb = approval.params.get(
                    "delete_from_kb", delete_from_kb
                )

                if not final_connector_id:
                    return {
                        "status": "error",
                        "message": "No connector found for this event.",
                    }

                from sqlalchemy.future import select

                from app.db import SearchSourceConnector, SearchSourceConnectorType

                _calendar_types = [
                    SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR,
                    SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
                ]

                # Re-validate the connector against this user and search space,
                # since the id may have been changed during approval.
                connector_rows = await db_session.execute(
                    select(SearchSourceConnector).filter(
                        SearchSourceConnector.id == final_connector_id,
                        SearchSourceConnector.search_space_id == search_space_id,
                        SearchSourceConnector.user_id == user_id,
                        SearchSourceConnector.connector_type.in_(_calendar_types),
                    )
                )
                connector = connector_rows.scalars().first()
                if not connector:
                    return {
                        "status": "error",
                        "message": "Selected Google Calendar connector is invalid or has been disconnected.",
                    }

                actual_connector_id = connector.id

                logger.info(
                    f"Deleting calendar event: event_id='{final_event_id}', connector={actual_connector_id}"
                )

                is_composio_calendar = (
                    connector.connector_type
                    == SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
                )
                if is_composio_calendar:
                    cca_id = connector.config.get("composio_connected_account_id")
                    if not cca_id:
                        return {
                            "status": "error",
                            "message": "Composio connected account ID not found for this connector.",
                        }
                else:
                    # Native OAuth connector: rebuild Credentials from the stored
                    # config, decrypting token fields when they are flagged as
                    # encrypted.
                    config_data = dict(connector.config)

                    from app.config import config as app_config
                    from app.utils.oauth_security import TokenEncryption

                    token_encrypted = config_data.get("_token_encrypted", False)
                    if token_encrypted and app_config.SECRET_KEY:
                        token_encryption = TokenEncryption(app_config.SECRET_KEY)
                        for key in ("token", "refresh_token", "client_secret"):
                            if config_data.get(key):
                                config_data[key] = token_encryption.decrypt_token(
                                    config_data[key]
                                )

                    exp = config_data.get("expiry", "")
                    if exp:
                        # Drop the trailing "Z" before datetime.fromisoformat().
                        exp = exp.replace("Z", "")

                    creds = Credentials(
                        token=config_data.get("token"),
                        refresh_token=config_data.get("refresh_token"),
                        token_uri=config_data.get("token_uri"),
                        client_id=config_data.get("client_id"),
                        client_secret=config_data.get("client_secret"),
                        scopes=config_data.get("scopes", []),
                        expiry=datetime.fromisoformat(exp) if exp else None,
                    )

                try:
                    if is_composio_calendar:
                        from app.services.composio_service import ComposioService

                        composio_result = await ComposioService().execute_tool(
                            connected_account_id=cca_id,
                            tool_name="GOOGLECALENDAR_DELETE_EVENT",
                            params={
                                "calendar_id": "primary",
                                "event_id": final_event_id,
                            },
                            entity_id=f"surfsense_{user_id}",
                        )
                        if not composio_result.get("success"):
                            raise RuntimeError(
                                composio_result.get(
                                    "error", "Unknown Composio Calendar error"
                                )
                            )
                    else:
                        # The Google client is blocking, so run it in the default
                        # executor. We are inside a coroutine, so the running
                        # loop is the right handle.
                        loop = asyncio.get_running_loop()
                        service = await loop.run_in_executor(
                            None, lambda: build("calendar", "v3", credentials=creds)
                        )
                        await loop.run_in_executor(
                            None,
                            lambda: (
                                service.events()
                                .delete(calendarId="primary", eventId=final_event_id)
                                .execute()
                            ),
                        )
                except Exception as api_err:
                    from googleapiclient.errors import HttpError

                    if isinstance(api_err, HttpError) and api_err.resp.status == 403:
                        logger.warning(
                            f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
                        )
                        # Best effort: flag the connector so later calls can
                        # report the auth problem; failure here is only logged.
                        try:
                            from sqlalchemy.orm.attributes import flag_modified

                            _res = await db_session.execute(
                                select(SearchSourceConnector).where(
                                    SearchSourceConnector.id == actual_connector_id
                                )
                            )
                            _conn = _res.scalar_one_or_none()
                            if _conn and not _conn.config.get("auth_expired"):
                                _conn.config = {**_conn.config, "auth_expired": True}
                                flag_modified(_conn, "config")
                                await db_session.commit()
                        except Exception:
                            logger.warning(
                                "Failed to persist auth_expired for connector %s",
                                actual_connector_id,
                                exc_info=True,
                            )
                        return {
                            "status": "insufficient_permissions",
                            "connector_id": actual_connector_id,
                            "message": "This Google Calendar account needs additional permissions. Please re-authenticate in connector settings.",
                        }
                    raise

                logger.info(f"Calendar event deleted: event_id={final_event_id}")

                delete_result: dict[str, Any] = {
                    "status": "success",
                    "event_id": final_event_id,
                    "message": f"Successfully deleted the calendar event '{event.get('summary', event_title_or_id)}'.",
                }

                deleted_from_kb = False
                if final_delete_from_kb and document_id:
                    # The remote delete already succeeded, so a KB failure is
                    # reported as a warning instead of failing the whole call.
                    try:
                        from app.db import Document

                        doc_result = await db_session.execute(
                            select(Document).filter(Document.id == document_id)
                        )
                        document = doc_result.scalars().first()
                        if document:
                            await db_session.delete(document)
                            await db_session.commit()
                            deleted_from_kb = True
                            logger.info(
                                f"Deleted document {document_id} from knowledge base"
                            )
                        else:
                            logger.warning(f"Document {document_id} not found in KB")
                    except Exception as kb_err:
                        logger.error(f"Failed to delete document from KB: {kb_err}")
                        await db_session.rollback()
                        delete_result["warning"] = (
                            f"Event deleted, but failed to remove from knowledge base: {kb_err!s}"
                        )

                delete_result["deleted_from_kb"] = deleted_from_kb
                if deleted_from_kb:
                    delete_result["message"] = (
                        f"{delete_result.get('message', '')} (also removed from knowledge base)"
                    )

                return delete_result

        except Exception as e:
            from langgraph.errors import GraphInterrupt

            # Approval interrupts must propagate to the graph runtime.
            if isinstance(e, GraphInterrupt):
                raise

            logger.error(f"Error deleting calendar event: {e}", exc_info=True)
            return {
                "status": "error",
                "message": "Something went wrong while deleting the event. Please try again.",
            }

    return delete_calendar_event

View file

@ -0,0 +1,187 @@
import logging
from typing import Any
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.agents.shared.tools.gmail.search_emails import _build_credentials
from app.db import SearchSourceConnector, SearchSourceConnectorType, async_session_maker
logger = logging.getLogger(__name__)
# Connector types accepted as a Google Calendar source by the search tool:
# the native Google connector and the Composio-backed one.
_CALENDAR_TYPES = [
    SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR,
    SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
]
def _to_calendar_boundary(value: str, *, is_end: bool) -> str:
if "T" in value:
return value
time = "23:59:59" if is_end else "00:00:00"
return f"{value}T{time}Z"
def _format_calendar_events(events_raw: list[dict[str, Any]]) -> list[dict[str, Any]]:
events = []
for ev in events_raw:
start = ev.get("start", {})
end = ev.get("end", {})
attendees_raw = ev.get("attendees", [])
events.append(
{
"event_id": ev.get("id"),
"summary": ev.get("summary", "No Title"),
"start": start.get("dateTime") or start.get("date", ""),
"end": end.get("dateTime") or end.get("date", ""),
"location": ev.get("location", ""),
"description": ev.get("description", ""),
"html_link": ev.get("htmlLink", ""),
"attendees": [a.get("email", "") for a in attendees_raw[:10]],
"status": ev.get("status", ""),
}
)
return events
def create_search_calendar_events_tool(
    db_session: AsyncSession | None = None,
    search_space_id: int | None = None,
    user_id: str | None = None,
):
    """
    Factory function to create the search_calendar_events tool.

    The tool acquires its own short-lived ``AsyncSession`` per call via
    :data:`async_session_maker` so the closure is safe to share across
    HTTP requests by the compiled-agent cache. Capturing a per-request
    session here would surface stale/closed sessions on cache hits.

    Args:
        db_session: Reserved for registry compatibility. Per-call sessions
            are opened via :data:`async_session_maker` inside the tool body.
        search_space_id: Search space whose calendar connector is queried.
            The tool returns an error result when None.
        user_id: Owner used to scope the connector lookup and to build the
            Composio entity id. The tool returns an error result when None.

    Returns:
        Configured search_calendar_events tool
    """
    del db_session  # per-call session — see docstring

    @tool
    async def search_calendar_events(
        start_date: str,
        end_date: str,
        max_results: int = 25,
    ) -> dict[str, Any]:
        """Search Google Calendar events within a date range.

        Args:
            start_date: Start date in YYYY-MM-DD format (e.g. "2026-04-01").
            end_date: End date in YYYY-MM-DD format (e.g. "2026-04-30").
            max_results: Maximum number of events to return (default 25, max 50).

        Returns:
            Dictionary with status and a list of events including
            event_id, summary, start, end, location, attendees.
        """
        if search_space_id is None or user_id is None:
            return {
                "status": "error",
                "message": "Calendar tool not properly configured.",
            }

        # Clamp to [1, 50]: the cap keeps responses small, and the floor keeps
        # a zero or negative value from the model from reaching the backend.
        max_results = max(1, min(max_results, 50))

        try:
            async with async_session_maker() as db_session:
                result = await db_session.execute(
                    select(SearchSourceConnector).filter(
                        SearchSourceConnector.search_space_id == search_space_id,
                        SearchSourceConnector.user_id == user_id,
                        SearchSourceConnector.connector_type.in_(_CALENDAR_TYPES),
                    )
                )
                connector = result.scalars().first()
                if not connector:
                    return {
                        "status": "error",
                        "message": "No Google Calendar connector found. Please connect Google Calendar in your workspace settings.",
                    }

                if (
                    connector.connector_type
                    == SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
                ):
                    cca_id = connector.config.get("composio_connected_account_id")
                    if not cca_id:
                        return {
                            "status": "error",
                            "message": "Composio connected account ID not found for this connector.",
                        }

                    from app.services.composio_service import ComposioService

                    # The Composio call receives full timestamps, so bare
                    # dates are expanded to start/end-of-day boundaries.
                    events_raw, error = await ComposioService().get_calendar_events(
                        connected_account_id=cca_id,
                        entity_id=f"surfsense_{user_id}",
                        time_min=_to_calendar_boundary(start_date, is_end=False),
                        time_max=_to_calendar_boundary(end_date, is_end=True),
                        max_results=max_results,
                    )
                    # Normalise "empty, no error" into the same message the
                    # shared error handling below maps to an empty success.
                    if not events_raw and not error:
                        error = "No events found in the specified date range."
                else:
                    creds = _build_credentials(connector)

                    from app.connectors.google_calendar_connector import (
                        GoogleCalendarConnector,
                    )

                    cal = GoogleCalendarConnector(
                        credentials=creds,
                        session=db_session,
                        user_id=user_id,
                        connector_id=connector.id,
                    )

                    events_raw, error = await cal.get_all_primary_calendar_events(
                        start_date=start_date,
                        end_date=end_date,
                        max_results=max_results,
                    )

                if error:
                    if (
                        "re-authenticate" in error.lower()
                        or "authentication failed" in error.lower()
                    ):
                        return {
                            "status": "auth_error",
                            "message": error,
                            "connector_type": "google_calendar",
                        }
                    # An empty range is a successful search, not a failure.
                    if "no events found" in error.lower():
                        return {
                            "status": "success",
                            "events": [],
                            "total": 0,
                            "message": error,
                        }
                    return {"status": "error", "message": error}

                events = _format_calendar_events(events_raw)

                return {"status": "success", "events": events, "total": len(events)}

        except Exception as e:
            from langgraph.errors import GraphInterrupt

            # Graph interrupts must propagate to the graph runtime.
            if isinstance(e, GraphInterrupt):
                raise
            logger.error("Error searching calendar events: %s", e, exc_info=True)
            return {
                "status": "error",
                "message": "Failed to search calendar events. Please try again.",
            }

    return search_calendar_events

View file

@ -0,0 +1,419 @@
import asyncio
import logging
from datetime import datetime
from typing import Any
from google.oauth2.credentials import Credentials
from googleapiclient.discovery import build
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.shared.tools.hitl import request_approval
from app.db import async_session_maker
from app.services.google_calendar import GoogleCalendarToolMetadataService
logger = logging.getLogger(__name__)
def _is_date_only(value: str) -> bool:
"""Return True when *value* looks like a bare date (YYYY-MM-DD) with no time component."""
return len(value) <= 10 and "T" not in value
def _build_time_body(value: str, context: dict[str, Any] | Any) -> dict[str, str]:
    """Return the ``start``/``end`` payload for a Google Calendar event.

    Date-only values produce ``{"date": value}`` (all-day events). Anything
    else produces ``{"dateTime": value, "timeZone": tz}``, where ``tz`` is
    taken from ``context["timezone"]`` when *context* is a dict and defaults
    to ``"UTC"`` otherwise.
    """
    if not _is_date_only(value):
        zone = "UTC"
        if isinstance(context, dict):
            zone = context.get("timezone", "UTC")
        return {"dateTime": value, "timeZone": zone}
    return {"date": value}
def create_update_calendar_event_tool(
    db_session: AsyncSession | None = None,
    search_space_id: int | None = None,
    user_id: str | None = None,
):
    """
    Factory function to create the update_calendar_event tool.

    The tool acquires its own short-lived ``AsyncSession`` per call via
    :data:`async_session_maker` so the closure is safe to share across
    HTTP requests by the compiled-agent cache. Capturing a per-request
    session here would surface stale/closed sessions on cache hits.

    Args:
        db_session: Reserved for registry compatibility. Per-call sessions
            are opened via :data:`async_session_maker` inside the tool body.
        search_space_id: Search space used to resolve the event and to scope
            the connector lookup. The tool returns an error result when None.
        user_id: Owner used to scope the connector lookup and to build the
            Composio entity id. The tool returns an error result when None.

    Returns:
        Configured update_calendar_event tool
    """
    del db_session  # per-call session — see docstring

    @tool
    async def update_calendar_event(
        event_title_or_id: str,
        new_summary: str | None = None,
        new_start_datetime: str | None = None,
        new_end_datetime: str | None = None,
        new_description: str | None = None,
        new_location: str | None = None,
        new_attendees: list[str] | None = None,
    ) -> dict[str, Any]:
        """Update an existing Google Calendar event.

        Use when the user asks to modify, reschedule, or change a calendar event.

        Args:
            event_title_or_id: The exact title or event ID of the event to update.
            new_summary: New event title (if changing).
            new_start_datetime: New start time in ISO 8601 format (if rescheduling).
            new_end_datetime: New end time in ISO 8601 format (if rescheduling).
            new_description: New event description (if changing).
            new_location: New event location (if changing).
            new_attendees: New list of attendee email addresses (if changing).

        Returns:
            Dictionary with:
            - status: "success", "rejected", "not_found", "auth_error", or "error"
            - event_id: Google Calendar event ID (if success)
            - html_link: URL to open the event (if success)
            - message: Result message

            IMPORTANT:
            - If status is "rejected", the user explicitly declined. Respond with a brief
              acknowledgment and do NOT retry or suggest alternatives.
            - If status is "not_found", relay the exact message to the user and ask them
              to verify the event name or check if it has been indexed.

        Examples:
            - "Reschedule the team standup to 3pm"
            - "Change the location of my dentist appointment"
        """
        logger.info(f"update_calendar_event called: event_ref='{event_title_or_id}'")

        if search_space_id is None or user_id is None:
            return {
                "status": "error",
                "message": "Google Calendar tool not properly configured. Please contact support.",
            }

        try:
            async with async_session_maker() as db_session:
                # Resolve the event reference to its stored metadata and account.
                metadata_service = GoogleCalendarToolMetadataService(db_session)
                context = await metadata_service.get_update_context(
                    search_space_id, user_id, event_title_or_id
                )

                if "error" in context:
                    error_msg = context["error"]
                    if "not found" in error_msg.lower():
                        logger.warning(f"Event not found: {error_msg}")
                        return {"status": "not_found", "message": error_msg}
                    logger.error(f"Failed to fetch update context: {error_msg}")
                    return {"status": "error", "message": error_msg}

                # Honour the flag at either level: the sibling delete tool reads
                # it from ``context["account"]``, while this tool historically
                # read it from the top level.
                # NOTE(review): confirm which shape get_update_context emits.
                account = context.get("account") or {}
                account_auth_expired = isinstance(account, dict) and account.get(
                    "auth_expired"
                )
                if context.get("auth_expired") or account_auth_expired:
                    logger.warning("Google Calendar account has expired authentication")
                    return {
                        "status": "auth_error",
                        "message": "The Google Calendar account for this event needs re-authentication. Please re-authenticate in your connector settings.",
                        "connector_type": "google_calendar",
                    }

                event = context["event"]
                event_id = event["event_id"]
                document_id = event.get("document_id")
                connector_id_from_context = context["account"]["id"]

                if not event_id:
                    return {
                        "status": "error",
                        "message": "Event ID is missing from the indexed document. Please re-index the event and try again.",
                    }

                logger.info(
                    f"Requesting approval for updating calendar event: '{event_title_or_id}' (event_id={event_id})"
                )
                # Human-in-the-loop gate: the reviewer may reject or edit params.
                approval = request_approval(
                    action_type="google_calendar_event_update",
                    tool_name="update_calendar_event",
                    params={
                        "event_id": event_id,
                        "document_id": document_id,
                        "connector_id": connector_id_from_context,
                        "new_summary": new_summary,
                        "new_start_datetime": new_start_datetime,
                        "new_end_datetime": new_end_datetime,
                        "new_description": new_description,
                        "new_location": new_location,
                        "new_attendees": new_attendees,
                    },
                    context=context,
                )

                if approval.rejected:
                    return {
                        "status": "rejected",
                        "message": "User declined. The event was not updated. Do not ask again or suggest alternatives.",
                    }

                # Prefer the (possibly edited) approved params over the originals.
                final_event_id = approval.params.get("event_id", event_id)
                final_connector_id = approval.params.get(
                    "connector_id", connector_id_from_context
                )
                final_new_summary = approval.params.get("new_summary", new_summary)
                final_new_start_datetime = approval.params.get(
                    "new_start_datetime", new_start_datetime
                )
                final_new_end_datetime = approval.params.get(
                    "new_end_datetime", new_end_datetime
                )
                final_new_description = approval.params.get(
                    "new_description", new_description
                )
                final_new_location = approval.params.get("new_location", new_location)
                final_new_attendees = approval.params.get("new_attendees", new_attendees)

                if not final_connector_id:
                    return {
                        "status": "error",
                        "message": "No connector found for this event.",
                    }

                from sqlalchemy.future import select

                from app.db import SearchSourceConnector, SearchSourceConnectorType

                _calendar_types = [
                    SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR,
                    SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
                ]

                # Re-validate the connector against this user and search space,
                # since the id may have been changed during approval.
                connector_rows = await db_session.execute(
                    select(SearchSourceConnector).filter(
                        SearchSourceConnector.id == final_connector_id,
                        SearchSourceConnector.search_space_id == search_space_id,
                        SearchSourceConnector.user_id == user_id,
                        SearchSourceConnector.connector_type.in_(_calendar_types),
                    )
                )
                connector = connector_rows.scalars().first()
                if not connector:
                    return {
                        "status": "error",
                        "message": "Selected Google Calendar connector is invalid or has been disconnected.",
                    }

                actual_connector_id = connector.id

                logger.info(
                    f"Updating calendar event: event_id='{final_event_id}', connector={actual_connector_id}"
                )

                is_composio_calendar = (
                    connector.connector_type
                    == SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR
                )
                if is_composio_calendar:
                    cca_id = connector.config.get("composio_connected_account_id")
                    if not cca_id:
                        return {
                            "status": "error",
                            "message": "Composio connected account ID not found for this connector.",
                        }
                else:
                    # Native OAuth connector: rebuild Credentials from the stored
                    # config, decrypting token fields when they are flagged as
                    # encrypted.
                    config_data = dict(connector.config)

                    from app.config import config as app_config
                    from app.utils.oauth_security import TokenEncryption

                    token_encrypted = config_data.get("_token_encrypted", False)
                    if token_encrypted and app_config.SECRET_KEY:
                        token_encryption = TokenEncryption(app_config.SECRET_KEY)
                        for key in ("token", "refresh_token", "client_secret"):
                            if config_data.get(key):
                                config_data[key] = token_encryption.decrypt_token(
                                    config_data[key]
                                )

                    exp = config_data.get("expiry", "")
                    if exp:
                        # Drop the trailing "Z" before datetime.fromisoformat().
                        exp = exp.replace("Z", "")

                    creds = Credentials(
                        token=config_data.get("token"),
                        refresh_token=config_data.get("refresh_token"),
                        token_uri=config_data.get("token_uri"),
                        client_id=config_data.get("client_id"),
                        client_secret=config_data.get("client_secret"),
                        scopes=config_data.get("scopes", []),
                        expiry=datetime.fromisoformat(exp) if exp else None,
                    )

                # Only fields that were explicitly provided are patched.
                update_body: dict[str, Any] = {}
                if final_new_summary is not None:
                    update_body["summary"] = final_new_summary
                if final_new_start_datetime is not None:
                    update_body["start"] = _build_time_body(
                        final_new_start_datetime, context
                    )
                if final_new_end_datetime is not None:
                    update_body["end"] = _build_time_body(
                        final_new_end_datetime, context
                    )
                if final_new_description is not None:
                    update_body["description"] = final_new_description
                if final_new_location is not None:
                    update_body["location"] = final_new_location
                if final_new_attendees is not None:
                    update_body["attendees"] = [
                        {"email": e.strip()} for e in final_new_attendees if e.strip()
                    ]

                if not update_body:
                    return {
                        "status": "error",
                        "message": "No changes specified. Please provide at least one field to update.",
                    }

                try:
                    if is_composio_calendar:
                        from app.services.composio_service import ComposioService

                        composio_params: dict[str, Any] = {
                            "calendar_id": "primary",
                            "event_id": final_event_id,
                        }
                        if final_new_summary is not None:
                            composio_params["summary"] = final_new_summary
                        if final_new_start_datetime is not None:
                            composio_params["start_time"] = final_new_start_datetime
                        if final_new_end_datetime is not None:
                            composio_params["end_time"] = final_new_end_datetime
                        if final_new_description is not None:
                            composio_params["description"] = final_new_description
                        if final_new_location is not None:
                            composio_params["location"] = final_new_location
                        if final_new_attendees is not None:
                            composio_params["attendees"] = [
                                e.strip() for e in final_new_attendees if e.strip()
                            ]
                        # A timezone is sent only for timed values. When no
                        # time is being changed, the "" fallback counts as
                        # date-only, so no timezone is sent.
                        if not _is_date_only(
                            final_new_start_datetime or final_new_end_datetime or ""
                        ):
                            composio_params["timezone"] = context.get("timezone", "UTC")

                        composio_result = await ComposioService().execute_tool(
                            connected_account_id=cca_id,
                            tool_name="GOOGLECALENDAR_PATCH_EVENT",
                            params=composio_params,
                            entity_id=f"surfsense_{user_id}",
                        )
                        if not composio_result.get("success"):
                            raise RuntimeError(
                                composio_result.get(
                                    "error", "Unknown Composio Calendar error"
                                )
                            )
                        # Unwrap the optional "data" / "response_data" envelopes.
                        updated = composio_result.get("data", {})
                        if isinstance(updated, dict):
                            updated = updated.get("data", updated)
                            if isinstance(updated, dict):
                                updated = updated.get("response_data", updated)
                    else:
                        # The Google client is blocking, so run it in the default
                        # executor. We are inside a coroutine, so the running
                        # loop is the right handle.
                        loop = asyncio.get_running_loop()
                        service = await loop.run_in_executor(
                            None, lambda: build("calendar", "v3", credentials=creds)
                        )
                        updated = await loop.run_in_executor(
                            None,
                            lambda: (
                                service.events()
                                .patch(
                                    calendarId="primary",
                                    eventId=final_event_id,
                                    body=update_body,
                                )
                                .execute()
                            ),
                        )
                except Exception as api_err:
                    from googleapiclient.errors import HttpError

                    if isinstance(api_err, HttpError) and api_err.resp.status == 403:
                        logger.warning(
                            f"Insufficient permissions for connector {actual_connector_id}: {api_err}"
                        )
                        # Best effort: flag the connector so later calls can
                        # report the auth problem; failure here is only logged.
                        try:
                            from sqlalchemy.orm.attributes import flag_modified

                            _res = await db_session.execute(
                                select(SearchSourceConnector).where(
                                    SearchSourceConnector.id == actual_connector_id
                                )
                            )
                            _conn = _res.scalar_one_or_none()
                            if _conn and not _conn.config.get("auth_expired"):
                                _conn.config = {**_conn.config, "auth_expired": True}
                                flag_modified(_conn, "config")
                                await db_session.commit()
                        except Exception:
                            logger.warning(
                                "Failed to persist auth_expired for connector %s",
                                actual_connector_id,
                                exc_info=True,
                            )
                        return {
                            "status": "insufficient_permissions",
                            "connector_id": actual_connector_id,
                            "message": "This Google Calendar account needs additional permissions. Please re-authenticate in connector settings.",
                        }
                    raise

                logger.info(f"Calendar event updated: event_id={final_event_id}")

                kb_message_suffix = ""
                if document_id is not None:
                    # The remote update already succeeded, so a KB sync failure
                    # only changes the message suffix.
                    try:
                        from app.services.google_calendar import (
                            GoogleCalendarKBSyncService,
                        )

                        kb_service = GoogleCalendarKBSyncService(db_session)
                        kb_result = await kb_service.sync_after_update(
                            document_id=document_id,
                            event_id=final_event_id,
                            connector_id=actual_connector_id,
                            search_space_id=search_space_id,
                            user_id=user_id,
                        )
                        if kb_result["status"] == "success":
                            kb_message_suffix = (
                                " Your knowledge base has also been updated."
                            )
                        else:
                            kb_message_suffix = " The knowledge base will be updated in the next scheduled sync."
                    except Exception as kb_err:
                        logger.warning(f"KB sync after update failed: {kb_err}")
                        kb_message_suffix = " The knowledge base will be updated in the next scheduled sync."

                # The event is already patched at this point. A payload that is
                # not a dict must not turn the success into an "error" result,
                # which could make the agent retry.
                html_link = updated.get("htmlLink") if isinstance(updated, dict) else None

                return {
                    "status": "success",
                    "event_id": final_event_id,
                    "html_link": html_link,
                    "message": f"Successfully updated the calendar event.{kb_message_suffix}",
                }

        except Exception as e:
            from langgraph.errors import GraphInterrupt

            # Approval interrupts must propagate to the graph runtime.
            if isinstance(e, GraphInterrupt):
                raise

            logger.error(f"Error updating calendar event: {e}", exc_info=True)
            return {
                "status": "error",
                "message": "Something went wrong while updating the event. Please try again.",
            }

    return update_calendar_event

View file

@ -0,0 +1,11 @@
from app.agents.shared.tools.google_drive.create_file import (
create_create_google_drive_file_tool,
)
from app.agents.shared.tools.google_drive.trash_file import (
create_delete_google_drive_file_tool,
)
# Public surface of this package: the Google Drive tool factory functions
# imported above.
__all__ = [
    "create_create_google_drive_file_tool",
    "create_delete_google_drive_file_tool",
]

View file

@ -0,0 +1,340 @@
import logging
from typing import Any, Literal
from googleapiclient.errors import HttpError
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.shared.tools.hitl import request_approval
from app.connectors.google_drive.client import GoogleDriveClient
from app.connectors.google_drive.file_types import GOOGLE_DOC, GOOGLE_SHEET
from app.db import async_session_maker
from app.services.google_drive import GoogleDriveToolMetadataService
logger = logging.getLogger(__name__)
# Maps the ``file_type`` values accepted by ``create_google_drive_file`` to the
# matching constants from ``app.connectors.google_drive.file_types``
# (presumably Drive MIME type strings — confirm against that module).
# Membership in this dict is also the tool's file-type validation.
_MIME_MAP: dict[str, str] = {
    "google_doc": GOOGLE_DOC,
    "google_sheet": GOOGLE_SHEET,
}
def create_create_google_drive_file_tool(
    db_session: AsyncSession | None = None,
    search_space_id: int | None = None,
    user_id: str | None = None,
):
    """
    Factory function to create the create_google_drive_file tool.

    The tool acquires its own short-lived ``AsyncSession`` per call via
    :data:`async_session_maker` so the closure is safe to share across
    HTTP requests by the compiled-agent cache. Capturing a per-request
    session here would surface stale/closed sessions on cache hits.

    Args:
        db_session: Reserved for registry compatibility. Per-call sessions
            are opened via :data:`async_session_maker` inside the tool body.
        search_space_id: Search space ID to find the Google Drive connector
        user_id: User ID for fetching user-specific context

    Returns:
        Configured create_google_drive_file tool
    """
    del db_session  # per-call session — see docstring

    @tool
    async def create_google_drive_file(
        name: str,
        file_type: Literal["google_doc", "google_sheet"],
        content: str | None = None,
    ) -> dict[str, Any]:
        """Create a new Google Doc or Google Sheet in Google Drive.

        Use this tool when the user explicitly asks to create a new document
        or spreadsheet in Google Drive. The user MUST specify a topic before
        you call this tool. If the request does not contain a topic (e.g.
        "create a drive doc" or "make a Google Sheet"), ask what the file
        should be about. Never call this tool without a clear topic from the user.

        Args:
            name: The file name (without extension).
            file_type: Either "google_doc" or "google_sheet".
            content: Optional initial content. Generate from the user's topic.
                For google_doc, provide markdown text. For google_sheet, provide CSV-formatted text.

        Returns:
            Dictionary with:
            - status: "success", "rejected", "auth_error",
              "insufficient_permissions", or "error"
            - file_id: Google Drive file ID (if success)
            - name: File name (if success)
            - web_view_link: URL to open the file (if success)
            - message: Result message

        IMPORTANT:
        - If status is "rejected", the user explicitly declined the action.
          Respond with a brief acknowledgment and do NOT retry or suggest alternatives.
        - If status is "insufficient_permissions", the connector lacks the required OAuth scope.
          Inform the user they need to re-authenticate and do NOT retry the action.

        Examples:
        - "Create a Google Doc with today's meeting notes"
        - "Create a spreadsheet for the 2026 budget"
        """
        logger.info(
            f"create_google_drive_file called: name='{name}', type='{file_type}'"
        )

        if search_space_id is None or user_id is None:
            return {
                "status": "error",
                "message": "Google Drive tool not properly configured. Please contact support.",
            }

        if file_type not in _MIME_MAP:
            return {
                "status": "error",
                "message": f"Unsupported file type '{file_type}'. Use 'google_doc' or 'google_sheet'.",
            }

        try:
            async with async_session_maker() as db_session:
                metadata_service = GoogleDriveToolMetadataService(db_session)
                context = await metadata_service.get_creation_context(
                    search_space_id, user_id
                )

                if "error" in context:
                    logger.error(
                        f"Failed to fetch creation context: {context['error']}"
                    )
                    return {"status": "error", "message": context["error"]}

                # Only bail out when *every* account is expired; a single
                # healthy account is enough to proceed.
                accounts = context.get("accounts", [])
                if accounts and all(a.get("auth_expired") for a in accounts):
                    logger.warning(
                        "All Google Drive accounts have expired authentication"
                    )
                    return {
                        "status": "auth_error",
                        "message": "All connected Google Drive accounts need re-authentication. Please re-authenticate in your connector settings.",
                        "connector_type": "google_drive",
                    }

                logger.info(
                    f"Requesting approval for creating Google Drive file: name='{name}', type='{file_type}'"
                )
                # May raise GraphInterrupt to pause the graph; the outer
                # handler below re-raises it untouched.
                approval = request_approval(
                    action_type="google_drive_file_creation",
                    tool_name="create_google_drive_file",
                    params={
                        "name": name,
                        "file_type": file_type,
                        "content": content,
                        "connector_id": None,
                        "parent_folder_id": None,
                    },
                    context=context,
                )

                if approval.rejected:
                    return {
                        "status": "rejected",
                        "message": "User declined. The file was not created. Do not ask again or suggest alternatives.",
                    }

                # The approval step may have edited any of these values.
                final_name = approval.params.get("name", name)
                final_file_type = approval.params.get("file_type", file_type)
                final_content = approval.params.get("content", content)
                final_connector_id = approval.params.get("connector_id")
                final_parent_folder_id = approval.params.get("parent_folder_id")

                if not final_name or not final_name.strip():
                    return {"status": "error", "message": "File name cannot be empty."}

                # Re-validate: the file type may have been edited during approval.
                mime_type = _MIME_MAP.get(final_file_type)
                if not mime_type:
                    return {
                        "status": "error",
                        "message": f"Unsupported file type '{final_file_type}'.",
                    }

                from sqlalchemy.future import select

                from app.db import SearchSourceConnector, SearchSourceConnectorType

                _drive_types = [
                    SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,
                    SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR,
                ]

                # The lookup is always scoped to this user's search space and
                # to Drive connector types; an explicitly selected connector
                # only adds an id filter and changes the failure message.
                connector_filters = [
                    SearchSourceConnector.search_space_id == search_space_id,
                    SearchSourceConnector.user_id == user_id,
                    SearchSourceConnector.connector_type.in_(_drive_types),
                ]
                if final_connector_id is not None:
                    connector_filters.append(
                        SearchSourceConnector.id == final_connector_id
                    )
                    missing_connector_message = "Selected Google Drive connector is invalid or has been disconnected."
                else:
                    missing_connector_message = "No Google Drive connector found. Please connect Google Drive in your workspace settings."

                connector_result = await db_session.execute(
                    select(SearchSourceConnector).filter(*connector_filters)
                )
                connector = connector_result.scalars().first()
                if not connector:
                    return {"status": "error", "message": missing_connector_message}
                actual_connector_id = connector.id

                logger.info(
                    f"Creating Google Drive file: name='{final_name}', type='{final_file_type}', connector={actual_connector_id}"
                )

                is_composio_drive = (
                    connector.connector_type
                    == SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
                )
                if is_composio_drive:
                    cca_id = connector.config.get("composio_connected_account_id")
                    if not cca_id:
                        return {
                            "status": "error",
                            "message": "Composio connected account ID not found for this Drive connector.",
                        }

                client = GoogleDriveClient(
                    session=db_session,
                    connector_id=actual_connector_id,
                )

                try:
                    if is_composio_drive:
                        from app.services.composio_service import ComposioService

                        drive_params: dict[str, Any] = {
                            "name": final_name,
                            "mimeType": mime_type,
                            "fields": "id,name,webViewLink,mimeType",
                        }
                        if final_parent_folder_id:
                            drive_params["parents"] = [final_parent_folder_id]
                        if final_content:
                            # On this path the content is sent as the file's
                            # description, truncated to 4096 characters.
                            drive_params["description"] = final_content[:4096]

                        composio_result = await ComposioService().execute_tool(
                            connected_account_id=cca_id,
                            tool_name="GOOGLEDRIVE_CREATE_FILE",
                            params=drive_params,
                            entity_id=f"surfsense_{user_id}",
                        )
                        if not composio_result.get("success"):
                            raise RuntimeError(
                                composio_result.get(
                                    "error", "Unknown Composio Drive error"
                                )
                            )

                        # Unwrap the optional "data" / "response_data"
                        # envelopes; anything that is not a dict collapses to
                        # an empty payload.
                        created = composio_result.get("data", {})
                        if isinstance(created, dict):
                            created = created.get("data", created)
                        if isinstance(created, dict):
                            created = created.get("response_data", created)
                        if not isinstance(created, dict):
                            created = {}
                    else:
                        created = await client.create_file(
                            name=final_name,
                            mime_type=mime_type,
                            parent_folder_id=final_parent_folder_id,
                            content=final_content,
                        )
                except HttpError as http_err:
                    if http_err.resp.status == 403:
                        logger.warning(
                            f"Insufficient permissions for connector {actual_connector_id}: {http_err}"
                        )
                        # Best effort: flag the connector so later calls can
                        # short-circuit; a failure here must not hide the
                        # permissions response.
                        try:
                            from sqlalchemy.orm.attributes import flag_modified

                            _res = await db_session.execute(
                                select(SearchSourceConnector).where(
                                    SearchSourceConnector.id == actual_connector_id
                                )
                            )
                            _conn = _res.scalar_one_or_none()
                            if _conn and not _conn.config.get("auth_expired"):
                                _conn.config = {**_conn.config, "auth_expired": True}
                                flag_modified(_conn, "config")
                                await db_session.commit()
                        except Exception:
                            logger.warning(
                                "Failed to persist auth_expired for connector %s",
                                actual_connector_id,
                                exc_info=True,
                            )
                        return {
                            "status": "insufficient_permissions",
                            "connector_id": actual_connector_id,
                            "message": "This Google Drive account needs additional permissions. Please re-authenticate in connector settings.",
                        }
                    raise

                # The response may lack a name (e.g. an unparseable Composio
                # payload); fall back to the requested one so the caller never
                # sees "None".
                created_name = created.get("name") or final_name

                logger.info(
                    f"Google Drive file created: id={created.get('id')}, name={created_name}"
                )

                # KB sync is best effort; the file already exists in Drive, so
                # a failure here only changes the message suffix.
                kb_message_suffix = ""
                try:
                    from app.services.google_drive import GoogleDriveKBSyncService

                    kb_service = GoogleDriveKBSyncService(db_session)
                    kb_result = await kb_service.sync_after_create(
                        file_id=created.get("id"),
                        file_name=created_name,
                        mime_type=mime_type,
                        web_view_link=created.get("webViewLink"),
                        content=final_content,
                        connector_id=actual_connector_id,
                        search_space_id=search_space_id,
                        user_id=user_id,
                    )
                    if kb_result["status"] == "success":
                        kb_message_suffix = (
                            " Your knowledge base has also been updated."
                        )
                    else:
                        kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."
                except Exception as kb_err:
                    logger.warning(f"KB sync after create failed: {kb_err}")
                    kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."

                return {
                    "status": "success",
                    "file_id": created.get("id"),
                    "name": created_name,
                    "web_view_link": created.get("webViewLink"),
                    "message": f"Successfully created '{created_name}' in Google Drive.{kb_message_suffix}",
                }
        except Exception as e:
            from langgraph.errors import GraphInterrupt

            # The approval pause must propagate to the LangGraph runtime.
            if isinstance(e, GraphInterrupt):
                raise

            logger.error(f"Error creating Google Drive file: {e}", exc_info=True)
            return {
                "status": "error",
                "message": "Something went wrong while creating the file. Please try again.",
            }

    return create_google_drive_file

View file

@ -0,0 +1,299 @@
import logging
from typing import Any
from googleapiclient.errors import HttpError
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.shared.tools.hitl import request_approval
from app.connectors.google_drive.client import GoogleDriveClient
from app.db import async_session_maker
from app.services.google_drive import GoogleDriveToolMetadataService
logger = logging.getLogger(__name__)
def create_delete_google_drive_file_tool(
    db_session: AsyncSession | None = None,
    search_space_id: int | None = None,
    user_id: str | None = None,
):
    """
    Factory function to create the delete_google_drive_file tool.

    The tool acquires its own short-lived ``AsyncSession`` per call via
    :data:`async_session_maker` so the closure is safe to share across
    HTTP requests by the compiled-agent cache. Capturing a per-request
    session here would surface stale/closed sessions on cache hits.

    Args:
        db_session: Reserved for registry compatibility. Per-call sessions
            are opened via :data:`async_session_maker` inside the tool body.
        search_space_id: Search space ID to find the Google Drive connector
        user_id: User ID for fetching user-specific context

    Returns:
        Configured delete_google_drive_file tool
    """
    del db_session  # per-call session — see docstring

    @tool
    async def delete_google_drive_file(
        file_name: str,
        delete_from_kb: bool = False,
    ) -> dict[str, Any]:
        """Move a Google Drive file to trash.

        Use this tool when the user explicitly asks to delete, remove, or trash
        a file in Google Drive.

        Args:
            file_name: The exact name of the file to trash (as it appears in Drive).
            delete_from_kb: Whether to also remove the file from the knowledge base.
                Default is False.
                Set to True to remove from both Google Drive and knowledge base.

        Returns:
            Dictionary with:
            - status: "success", "rejected", "not_found", "auth_error",
              "insufficient_permissions", or "error"
            - file_id: Google Drive file ID (if success)
            - deleted_from_kb: whether the document was removed from the knowledge base
            - message: Result message

        IMPORTANT:
        - If status is "rejected", the user explicitly declined. Respond with a brief
          acknowledgment and do NOT retry or suggest alternatives.
        - If status is "not_found", relay the exact message to the user and ask them
          to verify the file name or check if it has been indexed.
        - If status is "insufficient_permissions", the connector lacks the required OAuth scope.
          Inform the user they need to re-authenticate and do NOT retry this tool.

        Examples:
        - "Delete the 'Meeting Notes' file from Google Drive"
        - "Trash the 'Old Budget' spreadsheet"
        """
        logger.info(
            f"delete_google_drive_file called: file_name='{file_name}', delete_from_kb={delete_from_kb}"
        )

        if search_space_id is None or user_id is None:
            return {
                "status": "error",
                "message": "Google Drive tool not properly configured. Please contact support.",
            }

        try:
            async with async_session_maker() as db_session:
                metadata_service = GoogleDriveToolMetadataService(db_session)
                context = await metadata_service.get_trash_context(
                    search_space_id, user_id, file_name
                )

                if "error" in context:
                    error_msg = context["error"]
                    # "not found" is reported with its own status so the
                    # model can relay the message instead of a generic error.
                    if "not found" in error_msg.lower():
                        logger.warning(f"File not found: {error_msg}")
                        return {"status": "not_found", "message": error_msg}
                    logger.error(f"Failed to fetch trash context: {error_msg}")
                    return {"status": "error", "message": error_msg}

                account = context.get("account") or {}
                if account.get("auth_expired"):
                    logger.warning(
                        "Google Drive account %s has expired authentication",
                        account.get("id"),
                    )
                    return {
                        "status": "auth_error",
                        "message": "The Google Drive account for this file needs re-authentication. Please re-authenticate in your connector settings.",
                        "connector_type": "google_drive",
                    }

                file_info = context["file"]
                file_id = file_info["file_id"]
                document_id = file_info.get("document_id")
                connector_id_from_context = account.get("id")

                if not file_id:
                    return {
                        "status": "error",
                        "message": "File ID is missing from the indexed document. Please re-index the file and try again.",
                    }

                # Fail before prompting the user when the context carries no
                # connector to act through.
                if not connector_id_from_context:
                    return {
                        "status": "error",
                        "message": "No connector found for this file.",
                    }

                logger.info(
                    f"Requesting approval for deleting Google Drive file: '{file_name}' (file_id={file_id}, delete_from_kb={delete_from_kb})"
                )
                # May raise GraphInterrupt to pause the graph; the outer
                # handler below re-raises it untouched.
                approval = request_approval(
                    action_type="google_drive_file_trash",
                    tool_name="delete_google_drive_file",
                    params={
                        "file_id": file_id,
                        "connector_id": connector_id_from_context,
                        "delete_from_kb": delete_from_kb,
                    },
                    context=context,
                )

                if approval.rejected:
                    return {
                        "status": "rejected",
                        "message": "User declined. The file was not trashed. Do not ask again or suggest alternatives.",
                    }

                # The approval step may have edited any of these values.
                final_file_id = approval.params.get("file_id", file_id)
                final_connector_id = approval.params.get(
                    "connector_id", connector_id_from_context
                )
                final_delete_from_kb = approval.params.get(
                    "delete_from_kb", delete_from_kb
                )

                if not final_connector_id:
                    return {
                        "status": "error",
                        "message": "No connector found for this file.",
                    }

                from sqlalchemy.future import select

                from app.db import SearchSourceConnector, SearchSourceConnectorType

                _drive_types = [
                    SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR,
                    SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR,
                ]

                # Scope the lookup to this user's search space so an edited
                # connector id cannot reach someone else's connector.
                connector_result = await db_session.execute(
                    select(SearchSourceConnector).filter(
                        SearchSourceConnector.id == final_connector_id,
                        SearchSourceConnector.search_space_id == search_space_id,
                        SearchSourceConnector.user_id == user_id,
                        SearchSourceConnector.connector_type.in_(_drive_types),
                    )
                )
                connector = connector_result.scalars().first()
                if not connector:
                    return {
                        "status": "error",
                        "message": "Selected Google Drive connector is invalid or has been disconnected.",
                    }

                logger.info(
                    f"Deleting Google Drive file: file_id='{final_file_id}', connector={final_connector_id}"
                )

                is_composio_drive = (
                    connector.connector_type
                    == SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR
                )
                if is_composio_drive:
                    cca_id = connector.config.get("composio_connected_account_id")
                    if not cca_id:
                        return {
                            "status": "error",
                            "message": "Composio connected account ID not found for this Drive connector.",
                        }

                client = GoogleDriveClient(
                    session=db_session,
                    connector_id=connector.id,
                )

                try:
                    if is_composio_drive:
                        from app.services.composio_service import ComposioService

                        composio_result = await ComposioService().execute_tool(
                            connected_account_id=cca_id,
                            tool_name="GOOGLEDRIVE_TRASH_FILE",
                            params={"file_id": final_file_id},
                            entity_id=f"surfsense_{user_id}",
                        )
                        if not composio_result.get("success"):
                            raise RuntimeError(
                                composio_result.get(
                                    "error", "Unknown Composio Drive error"
                                )
                            )
                    else:
                        await client.trash_file(file_id=final_file_id)
                except HttpError as http_err:
                    if http_err.resp.status == 403:
                        logger.warning(
                            f"Insufficient permissions for connector {connector.id}: {http_err}"
                        )
                        # Best effort: flag the connector so later calls can
                        # short-circuit; a failure here must not hide the
                        # permissions response.
                        try:
                            from sqlalchemy.orm.attributes import flag_modified

                            if not connector.config.get("auth_expired"):
                                connector.config = {
                                    **connector.config,
                                    "auth_expired": True,
                                }
                                flag_modified(connector, "config")
                                await db_session.commit()
                        except Exception:
                            logger.warning(
                                "Failed to persist auth_expired for connector %s",
                                connector.id,
                                exc_info=True,
                            )
                        return {
                            "status": "insufficient_permissions",
                            "connector_id": connector.id,
                            "message": "This Google Drive account needs additional permissions. Please re-authenticate in connector settings.",
                        }
                    raise

                logger.info(
                    f"Google Drive file deleted (moved to trash): file_id={final_file_id}"
                )

                # The file is already in the trash at this point, so building
                # the success payload must not be able to fail: fall back to
                # the requested name when the context has none.
                display_name = file_info.get("name") or file_name
                trash_result: dict[str, Any] = {
                    "status": "success",
                    "file_id": final_file_id,
                    "message": f"Successfully moved '{display_name}' to trash.",
                }

                deleted_from_kb = False
                if final_delete_from_kb and document_id:
                    try:
                        from app.db import Document

                        doc_result = await db_session.execute(
                            select(Document).filter(Document.id == document_id)
                        )
                        document = doc_result.scalars().first()
                        if document:
                            await db_session.delete(document)
                            await db_session.commit()
                            deleted_from_kb = True
                            logger.info(
                                f"Deleted document {document_id} from knowledge base"
                            )
                        else:
                            logger.warning(f"Document {document_id} not found in KB")
                    except Exception as e:
                        # KB cleanup is secondary: keep the success status and
                        # attach a warning instead.
                        logger.error(
                            f"Failed to delete document from KB: {e}", exc_info=True
                        )
                        await db_session.rollback()
                        trash_result["warning"] = (
                            f"File moved to trash, but failed to remove from knowledge base: {e!s}"
                        )

                trash_result["deleted_from_kb"] = deleted_from_kb
                if deleted_from_kb:
                    trash_result["message"] = (
                        f"{trash_result.get('message', '')} (also removed from knowledge base)"
                    )

                return trash_result
        except Exception as e:
            from langgraph.errors import GraphInterrupt

            # The approval pause must propagate to the LangGraph runtime.
            if isinstance(e, GraphInterrupt):
                raise

            logger.error(f"Error deleting Google Drive file: {e}", exc_info=True)
            return {
                "status": "error",
                "message": "Something went wrong while trashing the file. Please try again.",
            }

    return delete_google_drive_file

View file

@ -0,0 +1,187 @@
"""Unified HITL (Human-in-the-Loop) approval utility.
Provides a single ``request_approval()`` function that encapsulates the
interrupt payload creation, decision parsing, and parameter merging logic
shared by every sensitive tool (native connectors and MCP tools alike).
Usage inside a tool::
from app.agents.shared.tools.hitl import request_approval
result = request_approval(
action_type="gmail_email_send",
tool_name="send_gmail_email",
params={"to": to, "subject": subject, "body": body},
context=context,
)
if result.rejected:
return {"status": "rejected", "message": "User declined."}
# result.params contains the final (possibly edited) parameters
"""
from __future__ import annotations
import logging
from dataclasses import dataclass, field
from typing import Any
from langgraph.types import interrupt
logger = logging.getLogger(__name__)
# Tools that mirror the safety profile of ``write_file`` against the
# SurfSense KB: each call creates ONE artifact in the user's own workspace
# with no external visibility (drafts aren't sent; new files aren't shared
# unless the user shares them later). These are auto-approved by default
# so the agent can compose drafts and seed scratch files without a popup
# on every call.
#
# Members of this set still call ``request_approval`` exactly as before;
# the function returns immediately with ``decision_type="auto_approved"``
# and a copy of the original params. This preserves the call-site shape
# (logging, metadata fetching, account fallbacks) so the only behavior
# change is "no interrupt fires".
#
# To re-enable prompting, the future per-search-space rules table
# (``agent_permission_rules``) takes precedence — see the ``# (future)``
# layer-3 comment in :mod:`app.agents.new_chat.chat_deepagent`.
#
# Entries are matched against the ``tool_name`` argument of
# ``request_approval``, i.e. the registered LangChain tool name.
DEFAULT_AUTO_APPROVED_TOOLS: frozenset[str] = frozenset(
    {
        "create_gmail_draft",
        "update_gmail_draft",
        "create_calendar_event",
        "create_notion_page",
        "create_confluence_page",
        "create_google_drive_file",
        "create_dropbox_file",
        "create_onedrive_file",
    }
)
@dataclass(frozen=True, slots=True)
class HITLResult:
    """Outcome of a human-in-the-loop approval request.

    Built by :func:`request_approval`. The dataclass is frozen, so fields
    cannot be reassigned after construction.
    """

    # True when the action must not run: the user rejected it, or no usable
    # decision could be parsed from the interrupt response.
    rejected: bool
    # How the outcome was reached: "trusted", "auto_approved", "error",
    # "reject", or the decision type taken from the user's response.
    decision_type: str
    # Parameters the tool should proceed with: the originals, optionally
    # overlaid with user edits.
    params: dict[str, Any] = field(default_factory=dict)
def _parse_decision(approval: Any) -> tuple[str, dict[str, Any]]:
    """Pull the first usable decision out of an interrupt response.

    ``approval`` is expected to be a dict whose ``"decisions"`` entry is a
    decision dict or a list of them; entries that are not dicts are ignored
    and only the first remaining one is considered.

    Returns:
        ``(decision_type, edited_params)``. *decision_type* comes from the
        decision's ``"type"`` key, then ``"decision_type"``, and defaults to
        ``"approve"`` when neither is set. *edited_params* is taken from
        ``edited_action["args"]`` when ``edited_action`` is a dict, otherwise
        from a top-level ``"args"`` dict; it is empty when no dict of
        arguments is found.

    Raises:
        ValueError: when no decision dict is present.
    """
    raw = approval.get("decisions", []) if isinstance(approval, dict) else []
    candidates = raw if isinstance(raw, list) else [raw]
    chosen = next((item for item in candidates if isinstance(item, dict)), None)
    if chosen is None:
        raise ValueError("No approval decision received")

    kind: str = chosen.get("type") or chosen.get("decision_type") or "approve"

    action = chosen.get("edited_action")
    if isinstance(action, dict):
        # A dict ``edited_action`` wins outright: top-level ``args`` are not
        # consulted even when its own ``args`` entry is unusable.
        action_args = action.get("args")
        overrides: dict[str, Any] = action_args if isinstance(action_args, dict) else {}
    else:
        plain_args = chosen.get("args")
        overrides = plain_args if isinstance(plain_args, dict) else {}

    return kind, overrides
def request_approval(
    *,
    action_type: str,
    tool_name: str,
    params: dict[str, Any],
    context: dict[str, Any] | None = None,
    trusted_tools: list[str] | None = None,
) -> HITLResult:
    """Ask the user to approve a tool call and report what they decided.

    This helper is deliberately **synchronous**: ``langgraph.types.interrupt``
    is itself synchronous — it raises a ``GraphInterrupt`` exception that the
    LangGraph runtime catches.

    Parameters
    ----------
    action_type:
        Label the frontend uses to pick the approval card
        (e.g. ``"gmail_email_send"``, ``"mcp_tool_call"``).
    tool_name:
        The registered LangChain tool name (e.g. ``"send_gmail_email"``).
    params:
        The original tool arguments. They are sent in the interrupt payload
        and act as defaults for anything the user does not edit.
    context:
        Rich metadata from a ``*ToolMetadataService`` (accounts, folders,
        labels, etc.). For MCP tools this can hold the server name and
        tool description. ``None`` is sent as an empty dict.
    trusted_tools:
        Allow-list of tool names the user has previously marked as
        "Always Allow". When *tool_name* is in it, no interrupt fires.

    Returns
    -------
    HITLResult
        ``result.rejected`` is ``True`` when the user denied the action or
        when no decision could be parsed from the response. Otherwise
        ``result.params`` holds the final parameters — either the originals
        or the originals with the user's edits merged on top.
    """
    # 1) Explicit per-user trust short-circuits everything else.
    if trusted_tools and tool_name in trusted_tools:
        logger.info("Tool '%s' is user-trusted — skipping HITL", tool_name)
        return HITLResult(rejected=False, decision_type="trusted", params=dict(params))

    # 2) Default policy: low-stakes creation tools (drafts + new-file
    #    creates) skip HITL because they're as recoverable as a local
    #    ``write_file`` against the SurfSense KB.
    if tool_name in DEFAULT_AUTO_APPROVED_TOOLS:
        logger.info(
            "Tool '%s' is in DEFAULT_AUTO_APPROVED_TOOLS — skipping HITL",
            tool_name,
        )
        return HITLResult(
            rejected=False, decision_type="auto_approved", params=dict(params)
        )

    # 3) Everything else pauses the graph until a decision comes back.
    payload = {
        "type": action_type,
        "action": {"tool": tool_name, "params": params},
        "context": context or {},
    }
    approval = interrupt(payload)

    try:
        decision_type, edited_params = _parse_decision(approval)
    except ValueError:
        # Fail closed: an unreadable response is treated as a denial.
        logger.warning(
            "No approval decision received for %s — rejecting for safety", tool_name
        )
        return HITLResult(rejected=True, decision_type="error", params=params)

    logger.info("User decision for %s: %s", tool_name, decision_type)

    if decision_type == "reject":
        return HITLResult(rejected=True, decision_type="reject", params=params)

    # Overlay user edits (if any) on a copy of the original arguments.
    merged = dict(params)
    merged.update(edited_params)
    return HITLResult(rejected=False, decision_type=decision_type, params=merged)

View file

@ -0,0 +1,53 @@
"""
The ``invalid`` fallback tool.
When the model emits a tool call whose name doesn't match any registered
tool, :class:`ToolCallNameRepairMiddleware` rewrites the call to ``invalid``
with the original name and a parser/validation error string. This tool's
execution then returns that error to the model so it can self-correct.
Ported from OpenCode's ``packages/opencode/src/tool/invalid.ts`` —
LangChain has no equivalent fallback path; the default behavior on an
unknown tool name is a hard ``ToolNotFoundError`` which kills the turn.
Critically, the :class:`ToolDefinition` for this tool is **excluded** from
the system-prompt tool list and from ``LLMToolSelectorMiddleware`` selection
(see ``ToolDefinition.always_include`` filtering in the registry), so the
model never advertises ``invalid`` as a callable. It only ever shows up
in the tool registry so LangGraph can dispatch the rewritten call.
"""
from __future__ import annotations
from langchain_core.tools import tool
# Registered name of the fallback tool; per the module docstring this is the
# name the repair middleware rewrites unmatched tool calls to.
INVALID_TOOL_NAME = "invalid"
# Description attached to the fallback tool when it is registered.
INVALID_TOOL_DESCRIPTION = "Do not use"
def _format_invalid_message(tool: str | None, error: str | None) -> str:
    """Build the error text handed back to the model. Mirrors ``invalid.ts``.

    A missing or empty *tool* is shown as ``<unknown>``; a missing or empty
    *error* is shown as ``(no error message provided)``.
    """
    shown_tool = tool if tool else "<unknown>"
    shown_error = error if error else "(no error message provided)"
    headline = (
        f"The arguments provided to the tool `{shown_tool}` are invalid: {shown_error}\n"
    )
    hint = "Read the tool's docstring carefully and try again with valid arguments."
    return headline + hint
@tool(name_or_callable=INVALID_TOOL_NAME, description=INVALID_TOOL_DESCRIPTION)
def invalid_tool(tool: str | None = None, error: str | None = None) -> str:
    """Explain a failed tool call back to the model.

    Runs only after :class:`ToolCallNameRepairMiddleware` has rewritten a
    failed call to ``invalid``, passing along the originally requested tool
    name and the validation error text.
    """
    explanation = _format_invalid_message(tool, error)
    return explanation
__all__ = [
"INVALID_TOOL_DESCRIPTION",
"INVALID_TOOL_NAME",
"invalid_tool",
]

View file

@ -0,0 +1,817 @@
"""
Knowledge base search tool for the SurfSense agent.
This module provides:
- Connector constants and normalization
- Async knowledge base search across multiple connectors
- Document formatting for LLM context
"""
import asyncio
import contextlib
import json
import re
import time
from datetime import datetime
from typing import Any
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import NATIVE_TO_LEGACY_DOCTYPE, shielded_async_session
from app.services.connector_service import ConnectorService
from app.utils.perf import get_perf_logger
# Connectors that call external live-search APIs. These are handled by the
# ``web_search`` tool and must be excluded from knowledge-base searches.
# ``_normalize_connectors`` removes them from every list it returns.
_LIVE_SEARCH_CONNECTORS: set[str] = {
    "TAVILY_API",
    "LINKUP_API",
    "BAIDU_SEARCH_API",
}

# Patterns that indicate the query has no meaningful search signal.
# plainto_tsquery('english', '*') produces an empty tsquery and an embedding
# of '*' is random noise, so both keyword and semantic search degrade to
# arbitrary ordering — large documents (many chunks) dominate by chance.
# The pattern is anchored at both ends, so it only matches queries made up
# entirely of the listed characters.
_DEGENERATE_QUERY_RE = re.compile(
    r"^[\s*?_.#@!\-/\\]+$"  # only wildcards, punctuation, whitespace
)

# Max chunks per document when doing a recency-based browse instead of
# a real search. We want breadth (many docs) over depth (many chunks).
_BROWSE_MAX_CHUNKS_PER_DOC = 5
def _is_degenerate_query(query: str) -> bool:
    """Report whether *query* carries no meaningful search signal.

    After trimming surrounding whitespace, a query is degenerate when it is
    empty or consists solely of characters accepted by
    ``_DEGENERATE_QUERY_RE`` (for example ``*`` or ``**``). Such queries make
    both keyword search (empty tsquery) and semantic search (meaningless
    embedding) return effectively random results.
    """
    trimmed = query.strip()
    return not trimmed or _DEGENERATE_QUERY_RE.match(trimmed) is not None
async def _browse_recent_documents(
    search_space_id: int,
    document_type: str | list[str] | None,
    top_k: int,
    start_date: datetime | None,
    end_date: datetime | None,
) -> list[dict[str, Any]]:
    """List the most recently updated documents, with no search ranking.

    Fallback for degenerate queries (e.g. ``*``) where semantic / keyword
    search would produce arbitrary results. Documents are ordered by
    ``updated_at`` descending and capped at *top_k*; each carries at most
    ``_BROWSE_MAX_CHUNKS_PER_DOC`` chunks. The result is document-grouped and
    meant to match the shape produced by ``_combined_rrf_search`` so the rest
    of the pipeline works unchanged.

    Args:
        search_space_id: Search space whose documents are listed.
        document_type: Optional type name(s) to restrict to. Names that are
            not ``DocumentType`` members are ignored; if none remain, the
            result is empty.
        top_k: Maximum number of documents to return.
        start_date: Optional inclusive lower bound on ``updated_at``.
        end_date: Optional inclusive upper bound on ``updated_at``.

    Returns:
        One dict per document with ``document_id``, ``content`` (the kept
        chunks joined by blank lines), ``score`` (always ``0.0``), ``chunks``,
        ``document`` and ``source``.
    """
    from sqlalchemy import select
    from sqlalchemy.orm import joinedload

    from app.db import Chunk, Document, DocumentType

    perf = get_perf_logger()
    started_at = time.perf_counter()

    filters = [Document.search_space_id == search_space_id]

    if document_type is not None:
        requested = (
            document_type if isinstance(document_type, list) else [document_type]
        )
        wanted_types = []
        for entry in requested:
            if not isinstance(entry, str):
                wanted_types.append(entry)
                continue
            # Unknown type names are dropped rather than raising.
            with contextlib.suppress(KeyError):
                wanted_types.append(DocumentType[entry])

        if not wanted_types:
            return []
        if len(wanted_types) == 1:
            filters.append(Document.document_type == wanted_types[0])
        else:
            filters.append(Document.document_type.in_(wanted_types))

    if start_date is not None:
        filters.append(Document.updated_at >= start_date)
    if end_date is not None:
        filters.append(Document.updated_at <= end_date)

    async with shielded_async_session() as session:
        recent_docs_stmt = (
            select(Document)
            .options(joinedload(Document.search_space))
            .where(*filters)
            .order_by(Document.updated_at.desc())
            .limit(top_k)
        )
        documents = (await session.execute(recent_docs_stmt)).scalars().unique().all()

        if not documents:
            return []

        chunks_stmt = (
            select(Chunk)
            .where(Chunk.document_id.in_([doc.id for doc in documents]))
            .order_by(Chunk.document_id, Chunk.id)
        )
        raw_chunks = (await session.execute(chunks_stmt)).scalars().all()

        # Keep only the first few chunks of each document (breadth over
        # depth); the bucket length doubles as the per-document counter.
        chunks_by_doc: dict[int, list[dict]] = {doc.id: [] for doc in documents}
        for chunk in raw_chunks:
            bucket = chunks_by_doc[chunk.document_id]
            if len(bucket) < _BROWSE_MAX_CHUNKS_PER_DOC:
                bucket.append({"chunk_id": chunk.id, "content": chunk.content})

        results: list[dict[str, Any]] = []
        for doc in documents:
            kept_chunks = chunks_by_doc[doc.id]
            type_value = (
                doc.document_type.value
                if getattr(doc, "document_type", None)
                else None
            )
            joined_content = "\n\n".join(
                item["content"] for item in kept_chunks if item.get("content")
            )
            results.append(
                {
                    "document_id": doc.id,
                    "content": joined_content,
                    "score": 0.0,
                    "chunks": kept_chunks,
                    "document": {
                        "id": doc.id,
                        "title": doc.title,
                        "document_type": type_value,
                        "metadata": doc.document_metadata or {},
                    },
                    "source": type_value,
                }
            )

    perf.info(
        "[kb_browse] recency browse in %.3fs docs=%d space=%d type=%s",
        time.perf_counter() - started_at,
        len(results),
        search_space_id,
        document_type,
    )
    return results
# =============================================================================
# Connector Constants and Normalization
# =============================================================================
# Canonical connector values used internally by ConnectorService
# Includes all document types and search source connectors
# ``_normalize_connectors`` treats this list as the set of known names (any
# requested value not listed here is dropped) and as the default search
# scope when no available-connector list is supplied.
_ALL_CONNECTORS: list[str] = [
    "EXTENSION",
    "FILE",
    "SLACK_CONNECTOR",
    "TEAMS_CONNECTOR",
    "NOTION_CONNECTOR",
    "YOUTUBE_VIDEO",
    "GITHUB_CONNECTOR",
    "ELASTICSEARCH_CONNECTOR",
    "LINEAR_CONNECTOR",
    "JIRA_CONNECTOR",
    "CONFLUENCE_CONNECTOR",
    "CLICKUP_CONNECTOR",
    "GOOGLE_CALENDAR_CONNECTOR",
    "GOOGLE_GMAIL_CONNECTOR",
    "GOOGLE_DRIVE_FILE",
    "DISCORD_CONNECTOR",
    "AIRTABLE_CONNECTOR",
    "LUMA_CONNECTOR",
    "NOTE",
    "BOOKSTACK_CONNECTOR",
    "CRAWLED_URL",
    "CIRCLEBACK",
    "OBSIDIAN_CONNECTOR",
    "ONEDRIVE_FILE",
    "DROPBOX_FILE",
]
# Human-readable descriptions for each connector type
# Used for generating dynamic docstrings and informing the LLM
# NOTE: ``WEBCRAWLER_CONNECTOR`` appears here but not in ``_ALL_CONNECTORS``;
# it is a user-facing alias that ``_normalize_connectors`` maps to
# ``CRAWLED_URL``, which is why both keys share one description.
CONNECTOR_DESCRIPTIONS: dict[str, str] = {
    "EXTENSION": "Web content saved via SurfSense browser extension (personal browsing history)",
    "FILE": "User-uploaded documents (PDFs, Word, etc.) (personal files)",
    "NOTE": "SurfSense Notes (notes created inside SurfSense)",
    "SLACK_CONNECTOR": "Slack conversations and shared content (personal workspace communications)",
    "TEAMS_CONNECTOR": "Microsoft Teams messages and conversations (personal Teams communications)",
    "NOTION_CONNECTOR": "Notion workspace pages and databases (personal knowledge management)",
    "YOUTUBE_VIDEO": "YouTube video transcripts and metadata (personally saved videos)",
    "GITHUB_CONNECTOR": "GitHub repository content and issues (personal repositories and interactions)",
    "ELASTICSEARCH_CONNECTOR": "Elasticsearch indexed documents and data (personal Elasticsearch instances)",
    "LINEAR_CONNECTOR": "Linear project issues and discussions (personal project management)",
    "JIRA_CONNECTOR": "Jira project issues, tickets, and comments (personal project tracking)",
    "CONFLUENCE_CONNECTOR": "Confluence pages and comments (personal project documentation)",
    "CLICKUP_CONNECTOR": "ClickUp tasks and project data (personal task management)",
    "GOOGLE_CALENDAR_CONNECTOR": "Google Calendar events, meetings, and schedules (personal calendar)",
    "GOOGLE_GMAIL_CONNECTOR": "Google Gmail emails and conversations (personal emails)",
    "GOOGLE_DRIVE_FILE": "Google Drive files and documents (personal cloud storage)",
    "DISCORD_CONNECTOR": "Discord server conversations and shared content (personal community)",
    "AIRTABLE_CONNECTOR": "Airtable records, tables, and database content (personal data)",
    "LUMA_CONNECTOR": "Luma events and meetings",
    "WEBCRAWLER_CONNECTOR": "Webpages indexed by SurfSense (personally selected websites)",
    "CRAWLED_URL": "Webpages indexed by SurfSense (personally selected websites)",
    "BOOKSTACK_CONNECTOR": "BookStack pages (personal documentation)",
    "CIRCLEBACK": "Circleback meeting notes, transcripts, and action items",
    "OBSIDIAN_CONNECTOR": "Obsidian vault notes and markdown files (personal notes)",
    "ONEDRIVE_FILE": "Microsoft OneDrive files and documents (personal cloud storage)",
    "DROPBOX_FILE": "Dropbox files and documents (cloud storage)",
}
def _normalize_connectors(
    connectors_to_search: list[str] | None,
    available_connectors: list[str] | None = None,
) -> list[str]:
    """
    Turn the model-supplied connector names into a clean list to search.

    Rules applied, in order:

    - Names are stripped and upper-cased; blank entries are ignored.
    - The user-facing alias ``WEBCRAWLER_CONNECTOR`` becomes ``CRAWLED_URL``.
    - A name is kept only if it is in ``_ALL_CONNECTORS``, is not a live-search
      connector, and (when ``available_connectors`` is given) is available.
    - Duplicates are removed, first occurrence wins.
    - When nothing is requested, or nothing survives filtering, the default
      selection is returned: ``available_connectors`` (or ``_ALL_CONNECTORS``
      when that is empty/None) minus the live-search connectors.

    Args:
        connectors_to_search: Connector names requested by the model.
        available_connectors: Connectors actually available in the search space.

    Returns:
        Ordered list of connector names to search.
    """

    def _default_selection() -> list[str]:
        # Everything that is available (or known), minus live-search connectors.
        pool = (
            list(available_connectors)
            if available_connectors
            else list(_ALL_CONNECTORS)
        )
        return [name for name in pool if name not in _LIVE_SEARCH_CONNECTORS]

    # Names that are allowed to appear in the result.
    searchable = set(available_connectors or _ALL_CONNECTORS) - _LIVE_SEARCH_CONNECTORS

    if not connectors_to_search:
        return _default_selection()

    selected: list[str] = []
    already_added: set[str] = set()
    for raw_name in connectors_to_search:
        name = (raw_name or "").strip().upper()
        if not name:
            continue
        # Map the user-facing alias to its canonical name.
        if name == "WEBCRAWLER_CONNECTOR":
            name = "CRAWLED_URL"
        if name in already_added:
            continue
        # Must be both a known connector and an allowed one.
        if name in _ALL_CONNECTORS and name in searchable:
            already_added.add(name)
            selected.append(name)

    # Nothing matched: fall back to the default selection.
    return selected or _default_selection()
# =============================================================================
# Document Formatting
# =============================================================================
# Share of the model's context window that a single tool result may occupy.
# The remainder is reserved for system prompt, conversation history, and
# model output. The budget is computed in characters as
# max_input_tokens * _CHARS_PER_TOKEN * _TOOL_OUTPUT_CONTEXT_FRACTION, i.e.
# with ~4 chars/token a tool result is ≈ 25 % of the context budget in tokens.
_TOOL_OUTPUT_CONTEXT_FRACTION = 0.25
_CHARS_PER_TOKEN = 4
# Hard floor / ceiling so the budget is always sensible regardless of what
# the model reports. The floor is also the fallback when the context size is
# unknown.
_MIN_TOOL_OUTPUT_CHARS = 20_000  # ~5K tokens
_MAX_TOOL_OUTPUT_CHARS = 200_000  # ~50K tokens
# Default per-chunk character cap used by format_documents_for_context.
_MAX_CHUNK_CHARS = 8_000
# Rank-adaptive per-document budget allocation.
# Top-ranked (most relevant) documents get a larger share of the budget so
# we pack as much high-quality context as possible.
#
# fraction(rank) = _TOP_DOC_BUDGET_FRACTION / (1 + rank * _RANK_DECAY)
#
# Examples (128K budget, 8K chunk cap, 500 chars of XML overhead per document):
# rank 0 → 40% → 6 chunks | rank 3 → 19% → 3 chunks (floor)
# rank 1 → 30% → 4 chunks | rank 10 → 9% → 3 chunks (floor)
# rank 2 → 24% → 3 chunks |
_TOP_DOC_BUDGET_FRACTION = 0.40
_RANK_DECAY = 0.35
# Minimum number of chunks a document is allowed, whatever its rank.
_MIN_CHUNKS_PER_DOC = 3
def _compute_tool_output_budget(max_input_tokens: int | None) -> int:
    """Translate a context-window size into a character budget for tool output.

    Args:
        max_input_tokens: The model's context window in tokens, as supplied by
            the caller. ``None`` or a non-positive value means "unknown".

    Returns:
        ``max_input_tokens * _CHARS_PER_TOKEN * _TOOL_OUTPUT_CONTEXT_FRACTION``
        truncated to an int and clamped to
        ``[_MIN_TOOL_OUTPUT_CHARS, _MAX_TOOL_OUTPUT_CHARS]``; the minimum when
        the context size is unknown.
    """
    context_known = max_input_tokens is not None and max_input_tokens > 0
    if not context_known:
        # Conservative fallback when the window size is unavailable.
        return _MIN_TOOL_OUTPUT_CHARS

    raw_budget = int(
        max_input_tokens * _CHARS_PER_TOKEN * _TOOL_OUTPUT_CONTEXT_FRACTION
    )
    capped = min(raw_budget, _MAX_TOOL_OUTPUT_CHARS)
    return max(_MIN_TOOL_OUTPUT_CHARS, capped)
# Metadata keys that format_documents_for_context removes from a document's
# metadata before serialising it into the <metadata_json> element.
_INTERNAL_METADATA_KEYS: frozenset[str] = frozenset(
{
"message_id",
"thread_id",
"event_id",
"calendar_id",
"google_drive_file_id",
"onedrive_file_id",
"dropbox_file_id",
"page_id",
"issue_id",
"connector_id",
}
)
def format_documents_for_context(
    documents: list[dict[str, Any]],
    *,
    max_chars: int = _MAX_TOOL_OUTPUT_CHARS,
    max_chunk_chars: int = _MAX_CHUNK_CHARS,
    max_chunks_per_doc: int = 0,
) -> str:
    """
    Format retrieved documents into an XML-like context string for the LLM.

    Results are grouped per document, rendered in the order given (highest
    relevance first) and appended until the character budget is reached.
    Individual chunks are capped at ``max_chunk_chars`` and each document is
    limited to a chunk cap so a single large document cannot monopolize the
    output.

    The first document is always emitted, even if it alone exceeds the budget;
    in that case the final output is hard-truncated to ``max_chars``.

    Args:
        documents: Result dictionaries from connector search. Either
            document-grouped (``{"document": {...}, "chunks": [...],
            "source": ...}``) or flat chunk-like dicts with ``content``.
            Entries that are not dicts are ignored.
        max_chars: Approximate character budget for the entire output.
        max_chunk_chars: Per-chunk character cap; content beyond the cap is
            dropped and a ``...(truncated)`` marker is appended.
        max_chunks_per_doc: Maximum chunks per document. ``0`` (default) means
            auto-compute per document using a rank-adaptive formula so
            higher-ranked documents receive more chunks.

    Returns:
        Formatted string with document contents and metadata, or ``""`` when
        ``documents`` is empty.
    """
    if not documents:
        return ""

    # Group chunks by document id (preferred) to produce the XML structure.
    # chunk_id is preserved so citations like [citation:123] remain possible.
    grouped: dict[str, dict[str, Any]] = {}
    for doc in documents:
        if not isinstance(doc, dict):
            # Nothing usable can be extracted from a non-dict entry; skipping it
            # avoids rendering an empty placeholder document.
            continue

        document_info = doc.get("document")
        if not isinstance(document_info, dict):
            document_info = {}
        # Some result shapes place metadata at the top level.
        metadata = document_info.get("metadata") or doc.get("metadata") or {}
        if not isinstance(metadata, dict):
            metadata = {}

        source = (
            doc.get("source")
            or document_info.get("document_type")
            or metadata.get("document_type")
            or "UNKNOWN"
        )

        # Document identity: prefer the document id, else type + title + url.
        document_id_val = document_info.get("id")
        title = (
            document_info.get("title") or metadata.get("title") or "Untitled Document"
        )
        url = (
            metadata.get("url")
            or metadata.get("source")
            or metadata.get("page_url")
            or ""
        )
        doc_key = (
            str(document_id_val)
            if document_id_val is not None
            else f"{source}::{title}::{url}"
        )

        group = grouped.get(doc_key)
        if group is None:
            group = {
                "document_id": document_id_val
                if document_id_val is not None
                else doc_key,
                "document_type": metadata.get("document_type") or source,
                "title": title,
                "url": url,
                "metadata": metadata,
                "chunks": [],
            }
            grouped[doc_key] = group

        # Prefer document-grouped chunks; otherwise treat the entry itself as
        # a single flat chunk.
        chunks_list = doc.get("chunks")
        if isinstance(chunks_list, list) and chunks_list:
            raw_chunks = [ch for ch in chunks_list if isinstance(ch, dict)]
        else:
            raw_chunks = [doc]

        for ch in raw_chunks:
            chunk_id = ch.get("chunk_id") or ch.get("id")
            content = (ch.get("content") or "").strip()
            if content:
                group["chunks"].append({"chunk_id": chunk_id, "content": content})

    # Live search connectors whose results should be cited by URL rather than
    # a numeric chunk_id (the numeric IDs are meaningless auto-incremented counters).
    live_search_connectors = {
        "TAVILY_API",
        "LINKUP_API",
        "BAIDU_SEARCH_API",
    }

    # Render XML expected by citation instructions, respecting the char budget.
    parts: list[str] = []
    total_chars = 0
    total_docs = len(grouped)

    for doc_idx, g in enumerate(grouped.values()):
        metadata_clean = {
            k: v for k, v in g["metadata"].items() if k not in _INTERNAL_METADATA_KEYS
        }
        metadata_json = json.dumps(metadata_clean, ensure_ascii=False)
        is_live_search = g["document_type"] in live_search_connectors

        doc_lines: list[str] = [
            "<document>",
            "<document_metadata>",
            f"  <document_id>{g['document_id']}</document_id>",
            f"  <document_type>{g['document_type']}</document_type>",
            f"  <title><![CDATA[{g['title']}]]></title>",
            f"  <url><![CDATA[{g['url']}]]></url>",
            f"  <metadata_json><![CDATA[{metadata_json}]]></metadata_json>",
            "</document_metadata>",
            "",
            "<document_content>",
        ]

        # Rank-adaptive per-document chunk cap: top results get more chunks.
        if max_chunks_per_doc > 0:
            chunks_allowed = max_chunks_per_doc
        else:
            doc_fraction = _TOP_DOC_BUDGET_FRACTION / (1 + doc_idx * _RANK_DECAY)
            max_doc_chars = int(max_chars * doc_fraction)
            xml_overhead = 500
            chunks_allowed = max(
                (max_doc_chars - xml_overhead) // max(max_chunk_chars, 1),
                _MIN_CHUNKS_PER_DOC,
            )

        for ch in g["chunks"][:chunks_allowed]:
            ch_content = ch["content"]
            if max_chunk_chars and len(ch_content) > max_chunk_chars:
                ch_content = ch_content[:max_chunk_chars] + "\n...(truncated)"
            ch_id = g["url"] if (is_live_search and g["url"]) else ch["chunk_id"]
            if ch_id is None:
                doc_lines.append(f"  <chunk><![CDATA[{ch_content}]]></chunk>")
            else:
                doc_lines.append(
                    f"  <chunk id='{ch_id}'><![CDATA[{ch_content}]]></chunk>"
                )

        doc_lines.extend(["</document_content>", "</document>", ""])
        doc_xml = "\n".join(doc_lines)
        doc_len = len(doc_xml)

        if total_chars + doc_len > max_chars:
            if doc_idx == 0:
                # Always emit the top-ranked document; the hard truncation
                # below trims it if it is over budget on its own.
                parts.append(doc_xml)
                total_chars += doc_len
                omitted = total_docs - 1
            else:
                omitted = total_docs - doc_idx
            # Only report documents that were genuinely left out.
            if omitted > 0:
                parts.append(
                    f"<!-- Output truncated: {omitted} more document(s) omitted "
                    f"(budget {max_chars} chars). Refine your query or reduce top_k "
                    f"to retrieve different results. -->"
                )
            break

        parts.append(doc_xml)
        total_chars += doc_len

    result = "\n".join(parts).strip()

    # Hard safety net: if the result is still over budget (e.g. a single massive
    # first document), forcibly truncate with a closing comment.
    if len(result) > max_chars:
        truncation_msg = "\n<!-- ...output forcibly truncated to fit context window -->"
        # Clamp at 0 so a tiny budget never becomes a negative slice index.
        keep = max(max_chars - len(truncation_msg), 0)
        result = result[:keep] + truncation_msg

    return result
# =============================================================================
# Knowledge Base Search
# =============================================================================
async def search_knowledge_base_async(
    query: str,
    search_space_id: int,
    db_session: AsyncSession,
    connector_service: ConnectorService,
    connectors_to_search: list[str] | None = None,
    top_k: int = 10,
    start_date: datetime | None = None,
    end_date: datetime | None = None,
    available_connectors: list[str] | None = None,
    available_document_types: list[str] | None = None,
    max_input_tokens: int | None = None,
) -> str:
    """
    Search the knowledge base and return the results formatted for the LLM.

    Delegates retrieval to ``search_knowledge_base_raw_async`` and then renders
    the documents with ``format_documents_for_context`` inside a character
    budget derived from ``max_input_tokens``.

    Args:
        query: The search query.
        search_space_id: The user's search space ID.
        db_session: Database session (forwarded to the raw search).
        connector_service: Connector service (forwarded to the raw search).
        connectors_to_search: Optional connector types to search.
        top_k: Number of results per connector.
        start_date: Optional start datetime for filtering documents.
        end_date: Optional end datetime for filtering documents.
        available_connectors: Optional list of connectors available in the
            search space.
        available_document_types: Optional list of document types that have
            indexed data.
        max_input_tokens: Model context window size in tokens; sizes the output.

    Returns:
        The formatted search results, or a fixed "no documents" message when
        the raw search returns nothing.
    """
    perf = get_perf_logger()
    started_at = time.perf_counter()

    documents = await search_knowledge_base_raw_async(
        query=query,
        search_space_id=search_space_id,
        db_session=db_session,
        connector_service=connector_service,
        connectors_to_search=connectors_to_search,
        top_k=top_k,
        start_date=start_date,
        end_date=end_date,
        available_connectors=available_connectors,
        available_document_types=available_document_types,
    )
    if not documents:
        return "No documents found in the knowledge base. The search space has no indexed content yet."

    # Degenerate (browse-style) queries use the fixed browse chunk cap;
    # 0 selects the formatter's rank-adaptive chunking.
    if _is_degenerate_query(query):
        chunk_cap = _BROWSE_MAX_CHUNKS_PER_DOC
    else:
        chunk_cap = 0

    output_budget = _compute_tool_output_budget(max_input_tokens)
    formatted = format_documents_for_context(
        documents,
        max_chars=output_budget,
        max_chunks_per_doc=chunk_cap,
    )
    formatted_len = len(formatted)
    doc_count = len(documents)

    if formatted_len > output_budget:
        perf.warning(
            "[kb_search] output STILL exceeds budget after format (%d > %d), "
            "hard truncation should have fired",
            formatted_len,
            output_budget,
        )

    perf.info(
        "[kb_search] TOTAL in %.3fs total_docs=%d deduped=%d output_chars=%d "
        "budget=%d max_input_tokens=%s space=%d",
        time.perf_counter() - started_at,
        doc_count,
        doc_count,
        formatted_len,
        output_budget,
        max_input_tokens,
        search_space_id,
    )
    return formatted
async def search_knowledge_base_raw_async(
query: str,
search_space_id: int,
db_session: AsyncSession,
connector_service: ConnectorService,
connectors_to_search: list[str] | None = None,
top_k: int = 10,
start_date: datetime | None = None,
end_date: datetime | None = None,
available_connectors: list[str] | None = None,
available_document_types: list[str] | None = None,
query_embedding: list[float] | None = None,
) -> list[dict[str, Any]]:
"""Search knowledge base and return raw document dicts (no XML formatting).

Flow: resolve the date range, normalize and filter the connector list, then
either browse recent documents (degenerate query) or run one combined RRF
search per connector. Results are de-duplicated and sorted by ``score``
(descending, missing score treated as 0).

Args:
query: The search query.
search_space_id: The search space to query.
db_session: Unused here; kept so the signature stays compatible.
connector_service: Unused here; kept so the signature stays compatible.
connectors_to_search: Optional connector types requested by the caller.
top_k: Passed through to each browse / search call.
start_date: Optional start of the date filter.
end_date: Optional end of the date filter.
available_connectors: Optional connectors available in the search space.
available_document_types: Optional document types to restrict connectors to.
query_embedding: Optional precomputed embedding; computed from ``query``
when omitted and a semantic search is performed.

Returns:
De-duplicated list of document dicts.
"""
perf = get_perf_logger()
t0 = time.perf_counter()
all_documents: list[dict[str, Any]] = []
# Preserve the public signature for compatibility even if values are unused.
_ = (db_session, connector_service)
# Function-scope import of the date-range helper.
from app.agents.shared.utils import resolve_date_range
resolved_start_date, resolved_end_date = resolve_date_range(
start_date=start_date,
end_date=end_date,
)
connectors = _normalize_connectors(connectors_to_search, available_connectors)
# Keep only connectors whose type, or its legacy document type from
# NATIVE_TO_LEGACY_DOCTYPE, appears in available_document_types.
if available_document_types:
doc_types_set = set(available_document_types)
connectors = [
c
for c in connectors
if c in doc_types_set
or NATIVE_TO_LEGACY_DOCTYPE.get(c, "") in doc_types_set
]
# NOTE(review): indentation is not visible in this view. If this early return
# sits at function level, the `[None]` browse fallback below is unreachable;
# if it is nested in the filter above, the fallback is live - confirm.
if not connectors:
return []
if _is_degenerate_query(query):
perf.info(
"[kb_search_raw] degenerate query %r detected - recency browse",
query,
)
browse_connectors = connectors if connectors else [None] # type: ignore[list-item]
# Connectors with a legacy document type are browsed under both names by
# passing a [native, legacy] pair as the document_type.
expanded_browse = []
for connector in browse_connectors:
if connector is not None and connector in NATIVE_TO_LEGACY_DOCTYPE:
expanded_browse.append([connector, NATIVE_TO_LEGACY_DOCTYPE[connector]])
else:
expanded_browse.append(connector)
# One browse call per connector, awaited together.
browse_results = await asyncio.gather(
*[
_browse_recent_documents(
search_space_id=search_space_id,
document_type=connector,
top_k=top_k,
start_date=resolved_start_date,
end_date=resolved_end_date,
)
for connector in expanded_browse
]
)
for docs in browse_results:
all_documents.extend(docs)
else:
# Embed the query once and share the vector across all connector searches.
# NOTE(review): embed() is called without await, so it presumably runs
# synchronously on the event loop - confirm this is acceptable.
if query_embedding is None:
from app.config import config as app_config
query_embedding = app_config.embedding_model_instance.embed(query)
# At most this many connector searches hold the semaphore at once.
max_parallel_searches = 4
semaphore = asyncio.Semaphore(max_parallel_searches)
async def _search_one_connector(connector: str) -> list[dict[str, Any]]:
# Each search opens its own session and ConnectorService; a failure is
# logged and treated as "no results" so other connectors still return.
try:
async with semaphore, shielded_async_session() as isolated_session:
svc = ConnectorService(isolated_session, search_space_id)
return await svc._combined_rrf_search(
query_text=query,
search_space_id=search_space_id,
document_type=connector,
top_k=top_k,
start_date=resolved_start_date,
end_date=resolved_end_date,
query_embedding=query_embedding,
)
except Exception as exc:
perf.warning("[kb_search_raw] connector=%s FAILED: %s", connector, exc)
return []
connector_results = await asyncio.gather(
*[_search_one_connector(connector) for connector in connectors]
)
for docs in connector_results:
all_documents.extend(docs)
# De-duplication state: document ids seen so far, and content fingerprints
# for documents that carry no id.
seen_doc_ids: set[Any] = set()
seen_content_hashes: set[int] = set()
deduplicated: list[dict[str, Any]] = []
def _content_fingerprint(document: dict[str, Any]) -> int | None:
# Hash of the joined non-empty chunk contents, else of the flat "content"
# field; None when the document has no text at all.
chunks = document.get("chunks")
if isinstance(chunks, list):
chunk_texts = []
for chunk in chunks:
if not isinstance(chunk, dict):
continue
chunk_content = (chunk.get("content") or "").strip()
if chunk_content:
chunk_texts.append(chunk_content)
if chunk_texts:
return hash("||".join(chunk_texts))
flat_content = (document.get("content") or "").strip()
if flat_content:
return hash(flat_content)
return None
for doc in all_documents:
# Documents with an id are de-duplicated by id only.
doc_id = (doc.get("document", {}) or {}).get("id")
if doc_id is not None:
if doc_id in seen_doc_ids:
continue
seen_doc_ids.add(doc_id)
deduplicated.append(doc)
continue
# Without an id, fall back to the content fingerprint; documents with
# neither id nor content are always kept.
content_hash = _content_fingerprint(doc)
if content_hash is not None and content_hash in seen_content_hashes:
continue
if content_hash is not None:
seen_content_hashes.add(content_hash)
deduplicated.append(doc)
# Highest score first; a missing score counts as 0.
deduplicated.sort(key=lambda doc: doc.get("score", 0), reverse=True)
perf.info(
"[kb_search_raw] done in %.3fs total=%d deduped=%d",
time.perf_counter() - t0,
len(all_documents),
len(deduplicated),
)
return deduplicated

View file

@ -0,0 +1,15 @@
from app.agents.shared.tools.luma.create_event import (
create_create_luma_event_tool,
)
from app.agents.shared.tools.luma.list_events import (
create_list_luma_events_tool,
)
from app.agents.shared.tools.luma.read_event import (
create_read_luma_event_tool,
)
# Public surface of the Luma tools package: the three tool factory functions
# imported above.
__all__ = [
"create_create_luma_event_tool",
"create_list_luma_events_tool",
"create_read_luma_event_tool",
]

View file

@ -0,0 +1,39 @@
"""Shared auth helper for Luma agent tools."""
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
# Base URL that the Luma tools prefix to every endpoint path they request.
LUMA_API = "https://public-api.luma.com/v1"
async def get_luma_connector(
    db_session: AsyncSession,
    search_space_id: int,
    user_id: str,
) -> SearchSourceConnector | None:
    """Look up the Luma connector for a user within a search space.

    Args:
        db_session: Session used to run the query.
        search_space_id: Search space the connector must belong to.
        user_id: Owner of the connector.

    Returns:
        The first matching ``SearchSourceConnector`` of type
        ``LUMA_CONNECTOR``, or ``None`` when there is none.
    """
    query = select(SearchSourceConnector).filter(
        SearchSourceConnector.search_space_id == search_space_id,
        SearchSourceConnector.user_id == user_id,
        SearchSourceConnector.connector_type
        == SearchSourceConnectorType.LUMA_CONNECTOR,
    )
    rows = await db_session.execute(query)
    return rows.scalars().first()
def get_api_key(connector: SearchSourceConnector) -> str:
    """Extract the Luma API key from a connector's config.

    Both key names are accepted: ``api_key`` is tried first, then
    ``LUMA_API_KEY``.

    Args:
        connector: Connector whose ``config`` mapping holds the key.

    Returns:
        The API key.

    Raises:
        ValueError: If the config is missing/empty or contains no non-empty key.
    """
    # A missing config should surface as the documented ValueError rather
    # than an AttributeError on None.
    config = connector.config or {}
    key = config.get("api_key") or config.get("LUMA_API_KEY")
    if not key:
        raise ValueError("Luma API key not found in connector config.")
    return key
def luma_headers(api_key: str) -> dict[str, str]:
    """Build the HTTP headers for a Luma API request.

    Args:
        api_key: Key sent in the ``x-luma-api-key`` header.

    Returns:
        Headers declaring a JSON content type and carrying the API key.
    """
    headers = {"Content-Type": "application/json"}
    headers["x-luma-api-key"] = api_key
    return headers

View file

@ -0,0 +1,150 @@
import logging
from typing import Any
import httpx
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.shared.tools.hitl import request_approval
from app.db import async_session_maker
from ._auth import LUMA_API, get_api_key, get_luma_connector, luma_headers
logger = logging.getLogger(__name__)
def create_create_luma_event_tool(
    db_session: AsyncSession | None = None,
    search_space_id: int | None = None,
    user_id: str | None = None,
):
    """
    Factory function to create the create_luma_event tool.

    The tool acquires its own short-lived ``AsyncSession`` per call via
    :data:`async_session_maker` so the closure is safe to share across
    HTTP requests by the compiled-agent cache. Capturing a per-request
    session here would surface stale/closed sessions on cache hits.

    Args:
        db_session: Reserved for registry compatibility. Per-call sessions
            are opened via :data:`async_session_maker` inside the tool body.
        search_space_id: Search space whose Luma connector is used.
        user_id: Owner of the Luma connector.

    Returns:
        Configured create_luma_event tool
    """
    del db_session  # per-call session — see docstring

    @tool
    async def create_luma_event(
        name: str,
        start_at: str,
        end_at: str,
        description: str | None = None,
        timezone: str = "UTC",
    ) -> dict[str, Any]:
        """Create a new event on Luma.

        Args:
            name: The event title.
            start_at: Start time in ISO 8601 format (e.g. "2026-05-01T18:00:00").
            end_at: End time in ISO 8601 format (e.g. "2026-05-01T20:00:00").
            description: Optional event description (markdown supported).
            timezone: Timezone string (default "UTC", e.g. "America/New_York").

        Returns:
            Dictionary with status, event_id on success.

        IMPORTANT:
            - If status is "rejected", the user explicitly declined. Do NOT retry.
        """
        if search_space_id is None or user_id is None:
            return {"status": "error", "message": "Luma tool not properly configured."}

        try:
            async with async_session_maker() as db_session:
                connector = await get_luma_connector(
                    db_session, search_space_id, user_id
                )
                if not connector:
                    return {"status": "error", "message": "No Luma connector found."}

                # Ask for approval before touching the Luma API.
                result = request_approval(
                    action_type="luma_create_event",
                    tool_name="create_luma_event",
                    params={
                        "name": name,
                        "start_at": start_at,
                        "end_at": end_at,
                        "description": description,
                        "timezone": timezone,
                    },
                    context={"connector_id": connector.id},
                )
                if result.rejected:
                    return {
                        "status": "rejected",
                        "message": "User declined. Event was not created.",
                    }

                # The approval result may carry edited parameters; fall back
                # to the original arguments for anything it does not include.
                final_name = result.params.get("name", name)
                final_start = result.params.get("start_at", start_at)
                final_end = result.params.get("end_at", end_at)
                final_desc = result.params.get("description", description)
                final_tz = result.params.get("timezone", timezone)

                api_key = get_api_key(connector)
                headers = luma_headers(api_key)

                body: dict[str, Any] = {
                    "name": final_name,
                    "start_at": final_start,
                    "end_at": final_end,
                    "timezone": final_tz,
                }
                if final_desc:
                    body["description_md"] = final_desc

                async with httpx.AsyncClient(timeout=20.0) as client:
                    resp = await client.post(
                        f"{LUMA_API}/event/create",
                        headers=headers,
                        json=body,
                    )

                if resp.status_code == 401:
                    return {
                        "status": "auth_error",
                        "message": "Luma API key is invalid.",
                        "connector_type": "luma",
                    }
                if resp.status_code == 403:
                    return {
                        "status": "error",
                        "message": "Luma Plus subscription required to create events via API.",
                    }
                if resp.status_code not in (200, 201):
                    return {
                        "status": "error",
                        "message": f"Luma API error: {resp.status_code} - {resp.text[:200]}",
                    }

                # The request succeeded, so the event exists. A body we cannot
                # parse must not be reported as a failure, otherwise a retry
                # could create a duplicate event.
                try:
                    data = resp.json()
                except ValueError:
                    data = {}
                if not isinstance(data, dict):
                    data = {}
                nested_event = data.get("event")
                if not isinstance(nested_event, dict):
                    nested_event = {}
                event_id = data.get("api_id") or nested_event.get("api_id")

                return {
                    "status": "success",
                    "event_id": event_id,
                    "message": f"Event '{final_name}' created on Luma.",
                }
        except Exception as e:
            # GraphInterrupt must propagate; every other error becomes a
            # generic error result.
            from langgraph.errors import GraphInterrupt

            if isinstance(e, GraphInterrupt):
                raise
            logger.error("Error creating Luma event: %s", e, exc_info=True)
            return {"status": "error", "message": "Failed to create Luma event."}

    return create_luma_event

View file

@ -0,0 +1,133 @@
import logging
from typing import Any
import httpx
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import async_session_maker
from ._auth import LUMA_API, get_api_key, get_luma_connector, luma_headers
logger = logging.getLogger(__name__)
def create_list_luma_events_tool(
    db_session: AsyncSession | None = None,
    search_space_id: int | None = None,
    user_id: str | None = None,
):
    """
    Factory function to create the list_luma_events tool.

    The tool acquires its own short-lived ``AsyncSession`` per call via
    :data:`async_session_maker` so the closure is safe to share across
    HTTP requests by the compiled-agent cache. Capturing a per-request
    session here would surface stale/closed sessions on cache hits.

    Args:
        db_session: Reserved for registry compatibility. Per-call sessions
            are opened via :data:`async_session_maker` inside the tool body.
        search_space_id: Search space whose Luma connector is used.
        user_id: Owner of the Luma connector.

    Returns:
        Configured list_luma_events tool
    """
    del db_session  # per-call session — see docstring

    @tool
    async def list_luma_events(
        max_results: int = 25,
    ) -> dict[str, Any]:
        """List upcoming and recent Luma events.

        Args:
            max_results: Maximum events to return (default 25, max 50).

        Returns:
            Dictionary with status and a list of events including
            event_id, name, start_at, end_at, location, url.
        """
        if search_space_id is None or user_id is None:
            return {"status": "error", "message": "Luma tool not properly configured."}

        max_results = min(max_results, 50)

        try:
            async with async_session_maker() as db_session:
                connector = await get_luma_connector(
                    db_session, search_space_id, user_id
                )
                if not connector:
                    return {"status": "error", "message": "No Luma connector found."}

                api_key = get_api_key(connector)
                headers = luma_headers(api_key)

                all_entries: list[dict] = []
                cursor = None

                async with httpx.AsyncClient(timeout=20.0) as client:
                    # Follow the pagination cursor until enough entries are
                    # collected or the API reports no further page.
                    while len(all_entries) < max_results:
                        params: dict[str, Any] = {
                            "limit": min(100, max_results - len(all_entries))
                        }
                        if cursor:
                            params["cursor"] = cursor

                        resp = await client.get(
                            f"{LUMA_API}/calendar/list-events",
                            headers=headers,
                            params=params,
                        )

                        if resp.status_code == 401:
                            return {
                                "status": "auth_error",
                                "message": "Luma API key is invalid.",
                                "connector_type": "luma",
                            }
                        if resp.status_code != 200:
                            return {
                                "status": "error",
                                "message": f"Luma API error: {resp.status_code}",
                            }

                        data = resp.json()
                        # `or []` also covers an explicit JSON null.
                        entries = data.get("entries") or []
                        if not entries:
                            break
                        all_entries.extend(entries)

                        next_cursor = data.get("next_cursor")
                        if not next_cursor:
                            break
                        cursor = next_cursor

                events = []
                for entry in all_entries[:max_results]:
                    # A key that is present but null must not abort the whole
                    # listing, so `or {}` is used instead of a .get default.
                    ev = entry.get("event") or {}
                    geo = ev.get("geo_info") or {}
                    events.append(
                        {
                            "event_id": entry.get("api_id"),
                            "name": ev.get("name", "Untitled"),
                            "start_at": ev.get("start_at", ""),
                            "end_at": ev.get("end_at", ""),
                            "timezone": ev.get("timezone", ""),
                            "location": geo.get("name", ""),
                            "url": ev.get("url", ""),
                            "visibility": ev.get("visibility", ""),
                        }
                    )

                return {"status": "success", "events": events, "total": len(events)}
        except Exception as e:
            # GraphInterrupt must propagate; every other error becomes a
            # generic error result.
            from langgraph.errors import GraphInterrupt

            if isinstance(e, GraphInterrupt):
                raise
            logger.error("Error listing Luma events: %s", e, exc_info=True)
            return {"status": "error", "message": "Failed to list Luma events."}

    return list_luma_events

View file

@ -0,0 +1,114 @@
import logging
from typing import Any
import httpx
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import async_session_maker
from ._auth import LUMA_API, get_api_key, get_luma_connector, luma_headers
logger = logging.getLogger(__name__)
def create_read_luma_event_tool(
    db_session: AsyncSession | None = None,
    search_space_id: int | None = None,
    user_id: str | None = None,
):
    """
    Factory function to create the read_luma_event tool.

    The tool acquires its own short-lived ``AsyncSession`` per call via
    :data:`async_session_maker` so the closure is safe to share across
    HTTP requests by the compiled-agent cache. Capturing a per-request
    session here would surface stale/closed sessions on cache hits.

    Args:
        db_session: Reserved for registry compatibility. Per-call sessions
            are opened via :data:`async_session_maker` inside the tool body.
        search_space_id: Search space whose Luma connector is used.
        user_id: Owner of the Luma connector.

    Returns:
        Configured read_luma_event tool
    """
    del db_session  # per-call session — see docstring

    @tool
    async def read_luma_event(event_id: str) -> dict[str, Any]:
        """Read detailed information about a specific Luma event.

        Args:
            event_id: The Luma event API ID (from list_luma_events).

        Returns:
            Dictionary with status and full event details including
            description, schedule, location, event URL and meeting URL.
        """
        # Local import keeps this block self-contained.
        from urllib.parse import quote

        if search_space_id is None or user_id is None:
            return {"status": "error", "message": "Luma tool not properly configured."}

        try:
            async with async_session_maker() as db_session:
                connector = await get_luma_connector(
                    db_session, search_space_id, user_id
                )
                if not connector:
                    return {"status": "error", "message": "No Luma connector found."}

                api_key = get_api_key(connector)
                headers = luma_headers(api_key)

                # event_id comes from the model: percent-encode it so it stays
                # a single path segment ("/", "?", ".." cannot alter the URL).
                safe_event_id = quote(event_id, safe="")

                async with httpx.AsyncClient(timeout=15.0) as client:
                    resp = await client.get(
                        f"{LUMA_API}/events/{safe_event_id}",
                        headers=headers,
                    )

                if resp.status_code == 401:
                    return {
                        "status": "auth_error",
                        "message": "Luma API key is invalid.",
                        "connector_type": "luma",
                    }
                if resp.status_code == 404:
                    return {
                        "status": "not_found",
                        "message": f"Event '{event_id}' not found.",
                    }
                if resp.status_code != 200:
                    return {
                        "status": "error",
                        "message": f"Luma API error: {resp.status_code}",
                    }

                data = resp.json()
                # The event may be nested under "event" or be the payload
                # itself; `or` also covers an explicit JSON null.
                ev = data.get("event") or data
                geo = ev.get("geo_info") or {}

                event_detail = {
                    "event_id": event_id,
                    "name": ev.get("name", ""),
                    "description": ev.get("description", ""),
                    "start_at": ev.get("start_at", ""),
                    "end_at": ev.get("end_at", ""),
                    "timezone": ev.get("timezone", ""),
                    "location_name": geo.get("name", ""),
                    "address": geo.get("address", ""),
                    "url": ev.get("url", ""),
                    "meeting_url": ev.get("meeting_url", ""),
                    "visibility": ev.get("visibility", ""),
                    "cover_url": ev.get("cover_url", ""),
                }

                return {"status": "success", "event": event_detail}
        except Exception as e:
            # GraphInterrupt must propagate; every other error becomes a
            # generic error result.
            from langgraph.errors import GraphInterrupt

            if isinstance(e, GraphInterrupt):
                raise
            logger.error("Error reading Luma event: %s", e, exc_info=True)
            return {"status": "error", "message": "Failed to read Luma event."}

    return read_luma_event

View file

@ -0,0 +1,326 @@
"""MCP Client Wrapper.
This module provides a client for communicating with MCP servers via stdio and HTTP transports.
It handles server lifecycle management, tool discovery, and tool execution.
"""
import asyncio
import logging
import os
from contextlib import asynccontextmanager
from typing import Any
from mcp import ClientSession
from mcp.client.stdio import StdioServerParameters, stdio_client
from mcp.client.streamable_http import streamablehttp_client
logger = logging.getLogger(__name__)
# Retry configuration for MCPClient.connect
MAX_RETRIES = 3  # default number of connection attempts
RETRY_DELAY = 1.0  # seconds to wait before the first retry
RETRY_BACKOFF = 2.0  # the delay is multiplied by this after each failed attempt
class MCPClient:
"""Client for communicating with an MCP server."""
def __init__(
self, command: str, args: list[str], env: dict[str, str] | None = None
):
"""Initialize MCP client.
Args:
command: Command to spawn the MCP server (e.g., "uvx", "node")
args: Arguments for the command (e.g., ["mcp-server-git"])
env: Optional environment variables for the server process
"""
self.command = command
self.args = args
self.env = env or {}
self.session: ClientSession | None = None
@asynccontextmanager
async def connect(self, max_retries: int = MAX_RETRIES):
"""Connect to the MCP server and manage its lifecycle.
Retries only apply to the **connection** phase (spawning the process,
initialising the session). Once the session is yielded to the caller,
any exception raised by the caller propagates normally -- the context
manager will NOT retry after ``yield``.
Previous implementation wrapped both connection AND yield inside the
retry loop. Because ``@asynccontextmanager`` only allows a single
``yield``, a failure after yield caused the generator to attempt a
second yield on retry, triggering
``RuntimeError("generator didn't stop after athrow()")`` and orphaning
the stdio subprocess.
Between failed attempts the method sleeps, starting at ``RETRY_DELAY``
seconds and multiplying the delay by ``RETRY_BACKOFF`` each time.
``self.session`` is set while the session is active and reset to ``None``
on exit or failure.
Args:
max_retries: Maximum number of connection retry attempts
Yields:
ClientSession: Active MCP session for making requests
Raises:
RuntimeError: If all connection attempts fail
"""
last_error = None
delay = RETRY_DELAY
# Becomes True once the session is initialised; from then on errors are
# re-raised instead of retried.
connected = False
for attempt in range(max_retries):
try:
# The server process gets the current environment plus the
# client-specific overrides.
server_env = os.environ.copy()
server_env.update(self.env)
server_params = StdioServerParameters(
command=self.command, args=self.args, env=server_env
)
async with stdio_client(server=server_params) as (read, write): # noqa: SIM117
async with ClientSession(read, write) as session:
await session.initialize()
self.session = session
connected = True
if attempt > 0:
logger.info(
"Connected to MCP server on attempt %d: %s %s",
attempt + 1,
self.command,
" ".join(self.args),
)
else:
logger.info(
"Connected to MCP server: %s %s",
self.command,
" ".join(self.args),
)
try:
yield session
finally:
self.session = None
# Leave the generator after the caller's block finishes so the retry
# loop never runs again (and never reaches a second yield).
return
except Exception as e:
self.session = None
# A failure after the session was established is not a connection
# failure: propagate it unchanged.
if connected:
raise
last_error = e
if attempt < max_retries - 1:
logger.warning(
"MCP server connection failed (attempt %d/%d): %s. Retrying in %.1fs...",
attempt + 1,
max_retries,
e,
delay,
)
await asyncio.sleep(delay)
delay *= RETRY_BACKOFF
else:
logger.error(
"Failed to connect to MCP server after %d attempts: %s",
max_retries,
e,
exc_info=True,
)
# Reached only when no attempt produced a session (including
# max_retries == 0, where the loop body never runs).
error_msg = f"Failed to connect to MCP server '{self.command}' after {max_retries} attempts"
if last_error:
error_msg += f": {last_error}"
logger.error(error_msg)
raise RuntimeError(error_msg) from last_error
async def list_tools(self) -> list[dict[str, Any]]:
    """Fetch the server's tool catalogue as plain dictionaries.

    Returns:
        One dict per tool with ``name``, ``description`` (empty string when
        the server supplies none) and ``input_schema`` (empty dict when the
        tool object has no ``inputSchema`` attribute).

    Raises:
        RuntimeError: If no session is active on this client.
        Exception: Any failure of the underlying ``list_tools`` request is
            logged and re-raised unchanged.
    """
    if not self.session:
        raise RuntimeError(
            "Not connected to MCP server. Use 'async with client.connect():'"
        )

    try:
        # tools/list RPC
        listing = await self.session.list_tools()
        tool_defs = [
            {
                "name": entry.name,
                "description": entry.description or "",
                "input_schema": getattr(entry, "inputSchema", {}),
            }
            for entry in listing.tools
        ]
        logger.info("Listed %d tools from MCP server", len(tool_defs))
        return tool_defs
    except Exception as e:
        logger.error("Failed to list tools from MCP server: %s", e, exc_info=True)
        raise
async def call_tool(
    self,
    tool_name: str,
    arguments: dict[str, Any],
    timeout: float = 60.0,
) -> Any:
    """Invoke a tool on the connected MCP server and return its text output.

    Args:
        tool_name: Name of the tool to call.
        arguments: Arguments forwarded to the tool.
        timeout: Maximum number of seconds to wait for the tool's response.

    Returns:
        The response content items joined with newlines: ``.text`` when an
        item has it, otherwise ``str(.data)``, otherwise ``str(item)``.
        On timeout, on a "Invalid structured content" RuntimeError, or on a
        ValueError/TypeError/AttributeError/KeyError, a descriptive string
        is returned instead of raising.

    Raises:
        RuntimeError: If no session is active, or any other RuntimeError
            raised while calling the tool.
    """
    if not self.session:
        raise RuntimeError(
            "Not connected to MCP server. Use 'async with client.connect():'"
        )

    try:
        logger.info(
            "Calling MCP tool '%s' with arguments: %s", tool_name, arguments
        )
        response = await asyncio.wait_for(
            self.session.call_tool(tool_name, arguments=arguments),
            timeout=timeout,
        )

        # Render each content item to text, preferring .text, then .data.
        fragments = [
            item.text
            if hasattr(item, "text")
            else str(item.data)
            if hasattr(item, "data")
            else str(item)
            for item in response.content
        ]
        result_str = "\n".join(fragments)
        logger.info("MCP tool '%s' succeeded: %s", tool_name, result_str[:200])
        return result_str
    except TimeoutError:
        logger.error("MCP tool '%s' timed out after %.0fs", tool_name, timeout)
        return f"Error: MCP tool '{tool_name}' timed out after {timeout:.0f}s"
    except RuntimeError as e:
        # Only the schema-mismatch complaint is tolerated; everything else
        # propagates to the caller.
        if "Invalid structured content" not in str(e):
            raise
        logger.warning(
            "MCP server returned data not matching its schema, but continuing: %s",
            e,
        )
        return "Operation completed (server returned unexpected format)"
    except (ValueError, TypeError, AttributeError, KeyError) as e:
        logger.error(
            "Failed to call MCP tool '%s': %s", tool_name, e, exc_info=True
        )
        return f"Error calling tool: {e!s}"
async def test_mcp_connection(
    command: str, args: list[str], env: dict[str, str] | None = None
) -> dict[str, Any]:
    """Spawn a stdio MCP server, connect to it and report the tools it exposes.

    Args:
        command: Command used to spawn the MCP server.
        args: Arguments for the command.
        env: Optional environment variables for the server process.

    Returns:
        Dict with ``status`` ("success" or "error"), a human-readable
        ``message`` and the ``tools`` list (empty on failure).
        RuntimeError, ConnectionError, TimeoutError and OSError are turned
        into the error dict; other exceptions propagate.
    """
    probe = MCPClient(command, args, env)

    try:
        async with probe.connect():
            discovered = await probe.list_tools()
            summary = f"Connected successfully. Found {len(discovered)} tools."
            return {"status": "success", "message": summary, "tools": discovered}
    except (RuntimeError, ConnectionError, TimeoutError, OSError) as e:
        failure = f"Failed to connect: {e!s}"
        return {"status": "error", "message": failure, "tools": []}
async def test_mcp_http_connection(
    url: str, headers: dict[str, str] | None = None, transport: str = "streamable-http"
) -> dict[str, Any]:
    """Probe an HTTP MCP server and report the tools it exposes.

    Args:
        url: URL of the MCP server.
        headers: Optional HTTP headers (e.g. for authentication); an empty
            dict is sent when omitted.
        transport: Transport label ("streamable-http", "http", or "sse").
            It is only logged here: the streamable HTTP client is used for
            every value.

    Returns:
        Dict with ``status`` ("success" or "error"), a human-readable
        ``message`` and the ``tools`` list (empty on failure). Any
        ``Exception`` is logged and converted into the error dict.
    """
    try:
        logger.info(
            "Testing HTTP MCP connection to: %s (transport: %s)", url, transport
        )

        # Use streamable HTTP client for all HTTP-based transports
        async with (
            streamablehttp_client(url, headers=headers or {}) as (read, write, _),
            ClientSession(read, write) as session,
        ):
            await session.initialize()

            listing = await session.list_tools()
            discovered = [
                {
                    "name": entry.name,
                    "description": entry.description or "",
                    "input_schema": getattr(entry, "inputSchema", {}),
                }
                for entry in listing.tools
            ]

            logger.info(
                "HTTP MCP connection successful. Found %d tools.", len(discovered)
            )
            summary = f"Connected successfully. Found {len(discovered)} tools."
            return {"status": "success", "message": summary, "tools": discovered}
    except Exception as e:
        logger.error("Failed to connect to HTTP MCP server: %s", e, exc_info=True)
        failure = f"Failed to connect: {e!s}"
        return {"status": "error", "message": failure, "tools": []}

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,145 @@
"""Persist MCP ``list_tools`` results in ``SearchSourceConnector.config.cached_tools``."""
from __future__ import annotations
import asyncio
import logging
from datetime import UTC, datetime
from typing import Any
from pydantic import BaseModel, Field, ValidationError
from sqlalchemy import select
from sqlalchemy.orm.attributes import flag_modified
from app.db import SearchSourceConnector, async_session_maker
logger = logging.getLogger(__name__)
# Holds a reference to every scheduled background prefetch task until it
# finishes: refresh_mcp_tools_cache_for_connector adds the task here and a
# done-callback discards it again.
_pending_prefetch_tasks: set[asyncio.Task[None]] = set()
class CachedMCPToolDef(BaseModel):
    # One cached tool entry. write_cached_tools validates every element of
    # its ``tool_definitions`` argument against this model.
    # Only ``name`` is required; the other fields fall back to empty values.
    name: str
    description: str = ""
    input_schema: dict[str, Any] = Field(default_factory=dict)
class CachedMCPTools(BaseModel):
    # Envelope stored under ``config["cached_tools"]`` of a connector row
    # (written by write_cached_tools, parsed back by read_cached_tools).
    # Set to the current UTC time when the payload is written.
    discovered_at: datetime
    # Optional server/transport metadata supplied by the writer's caller.
    server_version: str | None = None
    server_name: str | None = None
    transport: str | None = None
    # The tool definitions captured at discovery time.
    tools: list[CachedMCPToolDef]
def read_cached_tools(connector: SearchSourceConnector) -> CachedMCPTools | None:
    """Parse ``config["cached_tools"]`` of *connector*, if present and valid.

    Returns ``None`` when the entry is missing, empty, not a dict, or fails
    validation (a warning is logged in the last case) so the caller can fall
    back to live discovery.
    """
    raw = (connector.config or {}).get("cached_tools")
    if not (raw and isinstance(raw, dict)):
        return None

    try:
        parsed = CachedMCPTools.model_validate(raw)
    except ValidationError as exc:
        logger.warning(
            "MCP connector %d has corrupt cached_tools — falling back to live discovery: %s",
            connector.id,
            exc,
        )
        return None
    return parsed
async def write_cached_tools(
    connector_id: int,
    tool_definitions: list[dict[str, Any]],
    *,
    server_name: str | None = None,
    server_version: str | None = None,
    transport: str | None = None,
) -> None:
    """Best-effort persist of discovered tools into ``config["cached_tools"]``.

    Uses its own session so a write failure cannot poison the caller's
    transaction. Every failure, including a tool definition that does not
    validate, is logged as a warning and swallowed; nothing is raised.

    Args:
        connector_id: Primary key of the ``SearchSourceConnector`` row to update.
            If no such row exists the function returns without writing.
        tool_definitions: Tool dicts to cache; each is validated against
            ``CachedMCPToolDef``.
        server_name: Optional server name stored alongside the tools.
        server_version: Optional server version stored alongside the tools.
        transport: Optional transport label stored alongside the tools.
    """
    try:
        # Built inside the guarded region: a malformed tool definition must
        # not escape a function that is documented as best-effort.
        payload = CachedMCPTools(
            discovered_at=datetime.now(UTC),
            server_version=server_version,
            server_name=server_name,
            transport=transport,
            tools=[CachedMCPToolDef.model_validate(td) for td in tool_definitions],
        )

        async with async_session_maker() as session:
            result = await session.execute(
                select(SearchSourceConnector).filter(
                    SearchSourceConnector.id == connector_id,
                )
            )
            connector = result.scalars().first()
            if connector is None:
                return

            # Work on a copy and reassign, then flag the column explicitly so
            # the JSON change is picked up by the ORM.
            cfg = dict(connector.config or {})
            cfg["cached_tools"] = payload.model_dump(mode="json")
            connector.config = cfg
            flag_modified(connector, "config")
            await session.commit()

            logger.info(
                "Persisted cached_tools for MCP connector %d (%d tools)",
                connector_id,
                len(payload.tools),
            )
    except Exception:
        logger.warning(
            "Failed to persist cached_tools for MCP connector %d",
            connector_id,
            exc_info=True,
        )
def refresh_mcp_tools_cache_for_connector(
    connector_id: int,
    search_space_id: int,
) -> None:
    """Refresh MCP tool caches after a lifecycle event on one connector.

    Two steps, both best-effort:

    1. Evict the in-process cache for ``search_space_id`` right away via
       ``invalidate_mcp_tools_cache``. A failure is only logged at debug level.
    2. If an event loop is running, schedule a background discovery for
       ``connector_id`` alone so its persisted ``cached_tools`` can be
       refreshed. Without a running loop this step is silently skipped.

    The scheduled task is parked in ``_pending_prefetch_tasks`` until it
    completes.
    """
    try:
        from app.agents.shared.tools.mcp_tool import invalidate_mcp_tools_cache

        invalidate_mcp_tools_cache(search_space_id)
    except Exception:
        logger.debug(
            "MCP in-process cache eviction skipped for space %d",
            search_space_id,
            exc_info=True,
        )

    try:
        running_loop = asyncio.get_running_loop()
    except RuntimeError:
        # No loop in this thread: nothing to schedule on.
        return

    prefetch = running_loop.create_task(_run_connector_prefetch(connector_id))
    _pending_prefetch_tasks.add(prefetch)
    prefetch.add_done_callback(_pending_prefetch_tasks.discard)
async def _run_connector_prefetch(connector_id: int) -> None:
    """Run live discovery for one connector, logging instead of raising.

    The discovery helper is imported lazily inside the coroutine. Any
    ``Exception`` raised by ``discover_single_mcp_connector`` is logged as a
    warning with traceback and swallowed.
    """
    from app.agents.shared.tools.mcp_tool import (
        discover_single_mcp_connector as discover,
    )

    try:
        await discover(connector_id)
    except Exception:
        logger.warning(
            "MCP background prefetch failed for connector_id=%d",
            connector_id,
            exc_info=True,
        )

View file

@ -0,0 +1,11 @@
"""Notion tools for creating, updating, and deleting pages."""
from .create_page import create_create_notion_page_tool
from .delete_page import create_delete_notion_page_tool
from .update_page import create_update_notion_page_tool
# Public API of the package: the three Notion tool factories imported above.
__all__ = [
    "create_create_notion_page_tool",
    "create_delete_notion_page_tool",
    "create_update_notion_page_tool",
]

View file

@ -0,0 +1,258 @@
import logging
from typing import Any
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.shared.tools.hitl import request_approval
from app.connectors.notion_history import NotionAPIError, NotionHistoryConnector
from app.db import async_session_maker
from app.services.notion import NotionToolMetadataService
logger = logging.getLogger(__name__)
def create_create_notion_page_tool(
    db_session: AsyncSession | None = None,
    search_space_id: int | None = None,
    user_id: str | None = None,
    connector_id: int | None = None,
):
    """
    Build the ``create_notion_page`` tool bound to a search space and user.

    The returned tool opens a fresh ``AsyncSession`` from
    :data:`async_session_maker` on every call instead of capturing one here.
    The compiled agent graph (and with it this closure) is reused across
    HTTP requests, so a captured per-request session would be stale or
    closed on cache hits. A per-call session also keeps the request's outer
    transaction free of the Notion API round-trips.

    Args:
        db_session: Accepted for registry compatibility only; discarded.
        search_space_id: Search space whose Notion connector is used.
        user_id: Owner of the connector / requesting user.
        connector_id: Optional connector to pre-select in the approval
            request; the approver's choice takes precedence.

    Returns:
        The configured ``create_notion_page`` tool.
    """
    del db_session  # per-call session — see docstring

    @tool
    async def create_notion_page(
        title: str,
        content: str | None = None,
    ) -> dict[str, Any]:
        """Create a new page in Notion with the given title and content.
        Use this tool when the user asks you to create, save, or publish
        something to Notion. The page will be created in the user's
        configured Notion workspace. The user MUST specify a topic before you
        call this tool. If the request does not contain a topic (e.g. "create a
        notion page"), ask what the page should be about. Never call this tool
        without a clear topic from the user.
        Args:
        title: The title of the Notion page.
        content: Optional markdown content for the page body (supports headings, lists, paragraphs).
        Generate this yourself based on the user's topic.
        Returns:
        Dictionary with:
        - status: "success", "rejected", or "error"
        - page_id: Created page ID (if success)
        - url: URL to the created page (if success)
        - title: Page title (if success)
        - message: Result message
        IMPORTANT: If status is "rejected", the user explicitly declined the action.
        Respond with a brief acknowledgment (e.g., "Understood, I didn't create the page.")
        and move on. Do NOT troubleshoot or suggest alternatives.
        Examples:
        - "Create a Notion page about our Q2 roadmap"
        - "Save a summary of today's discussion to Notion"
        """
        logger.info(f"create_notion_page called: title='{title}'")

        if search_space_id is None or user_id is None:
            logger.error(
                "Notion tool not properly configured - missing required parameters"
            )
            return {
                "status": "error",
                "message": "Notion tool not properly configured. Please contact support.",
            }

        try:
            async with async_session_maker() as db_session:
                metadata_service = NotionToolMetadataService(db_session)
                context = await metadata_service.get_creation_context(
                    search_space_id, user_id
                )

                if "error" in context:
                    logger.error(
                        f"Failed to fetch creation context: {context['error']}"
                    )
                    return {"status": "error", "message": context["error"]}

                accounts = context.get("accounts", [])
                if accounts and all(a.get("auth_expired") for a in accounts):
                    logger.warning("All Notion accounts have expired authentication")
                    return {
                        "status": "auth_error",
                        "message": "All connected Notion accounts need re-authentication. Please re-authenticate in your connector settings.",
                        "connector_type": "notion",
                    }

                # Human-in-the-loop gate: the approver may edit any parameter.
                logger.info(f"Requesting approval for creating Notion page: '{title}'")
                approval = request_approval(
                    action_type="notion_page_creation",
                    tool_name="create_notion_page",
                    params={
                        "title": title,
                        "content": content,
                        "parent_page_id": None,
                        "connector_id": connector_id,
                    },
                    context=context,
                )

                if approval.rejected:
                    logger.info("Notion page creation rejected by user")
                    return {
                        "status": "rejected",
                        "message": "User declined. Do not retry or suggest alternatives.",
                    }

                approved = approval.params
                final_title = approved.get("title", title)
                final_content = approved.get("content", content)
                final_parent_page_id = approved.get("parent_page_id")
                final_connector_id = approved.get("connector_id", connector_id)

                if not final_title or not final_title.strip():
                    logger.error("Title is empty or contains only whitespace")
                    return {
                        "status": "error",
                        "message": "Page title cannot be empty. Please provide a valid title.",
                    }

                logger.info(
                    f"Creating Notion page with final params: title='{final_title}'"
                )

                from sqlalchemy.future import select

                from app.db import SearchSourceConnector, SearchSourceConnectorType

                # One lookup serves both cases: with an explicit id we validate
                # it, without one we take the first Notion connector owned by
                # this user in this search space.
                actual_connector_id = final_connector_id
                account_chosen = actual_connector_id is not None
                lookup_filters = [
                    SearchSourceConnector.search_space_id == search_space_id,
                    SearchSourceConnector.user_id == user_id,
                    SearchSourceConnector.connector_type
                    == SearchSourceConnectorType.NOTION_CONNECTOR,
                ]
                if account_chosen:
                    lookup_filters.insert(
                        0, SearchSourceConnector.id == actual_connector_id
                    )

                lookup = await db_session.execute(
                    select(SearchSourceConnector).filter(*lookup_filters)
                )
                connector = lookup.scalars().first()

                if not connector:
                    if account_chosen:
                        logger.error(
                            f"Invalid connector_id={actual_connector_id} for search_space_id={search_space_id}"
                        )
                        return {
                            "status": "error",
                            "message": "Selected Notion account is invalid or has been disconnected. Please select a valid account.",
                        }
                    logger.warning(
                        f"No Notion connector found for search_space_id={search_space_id}"
                    )
                    return {
                        "status": "error",
                        "message": "No Notion connector found. Please connect Notion in your workspace settings.",
                    }

                if account_chosen:
                    logger.info(f"Validated Notion connector: id={actual_connector_id}")
                else:
                    actual_connector_id = connector.id
                    logger.info(f"Found Notion connector: id={actual_connector_id}")

                notion_connector = NotionHistoryConnector(
                    session=db_session,
                    connector_id=actual_connector_id,
                )
                creation = await notion_connector.create_page(
                    title=final_title,
                    content=final_content,
                    parent_page_id=final_parent_page_id,
                )
                logger.info(
                    f"create_page result: {creation.get('status')} - {creation.get('message', '')}"
                )

                if creation.get("status") == "success":
                    # KB sync is a courtesy step: failure must not turn a
                    # successful page creation into an error.
                    deferred_note = " This page will be added to your knowledge base in the next scheduled sync."
                    kb_message_suffix = ""
                    try:
                        from app.services.notion import NotionKBSyncService

                        kb_service = NotionKBSyncService(db_session)
                        kb_result = await kb_service.sync_after_create(
                            page_id=creation.get("page_id"),
                            page_title=creation.get("title", final_title),
                            page_url=creation.get("url"),
                            content=final_content,
                            connector_id=actual_connector_id,
                            search_space_id=search_space_id,
                            user_id=user_id,
                        )
                        if kb_result["status"] == "success":
                            kb_message_suffix = (
                                " Your knowledge base has also been updated."
                            )
                        else:
                            kb_message_suffix = deferred_note
                    except Exception as kb_err:
                        logger.warning(f"KB sync after create failed: {kb_err}")
                        kb_message_suffix = deferred_note

                    creation["message"] = (
                        creation.get("message", "") + kb_message_suffix
                    )

                return creation

        except Exception as e:
            from langgraph.errors import GraphInterrupt

            # The approval interrupt must reach the graph runtime untouched.
            if isinstance(e, GraphInterrupt):
                raise

            logger.error(f"Error creating Notion page: {e}", exc_info=True)
            if isinstance(e, ValueError | NotionAPIError):
                message = str(e)
            else:
                message = (
                    "Something went wrong while creating the page. Please try again."
                )
            return {"status": "error", "message": message}

    return create_notion_page

View file

@ -0,0 +1,273 @@
import logging
from typing import Any
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.shared.tools.hitl import request_approval
from app.connectors.notion_history import NotionAPIError, NotionHistoryConnector
from app.db import async_session_maker
from app.services.notion.tool_metadata_service import NotionToolMetadataService
logger = logging.getLogger(__name__)
def create_delete_notion_page_tool(
    db_session: AsyncSession | None = None,
    search_space_id: int | None = None,
    user_id: str | None = None,
    connector_id: int | None = None,
):
    """
    Factory function to create the delete_notion_page tool.

    The tool acquires its own short-lived ``AsyncSession`` per call via
    :data:`async_session_maker` so the closure is safe to share across
    HTTP requests by the compiled-agent cache. Capturing a per-request
    session here would surface stale/closed sessions on cache hits.

    Args:
        db_session: Reserved for registry compatibility. Per-call sessions
            are opened via :data:`async_session_maker` inside the tool body.
        search_space_id: Search space ID to find the Notion connector
        user_id: User ID for finding the correct Notion connector
        connector_id: Accepted for signature parity with the other Notion
            factories; the tool body does not read it. The connector comes
            from the page's delete context and the approval parameters.

    Returns:
        Configured delete_notion_page tool
    """
    del db_session  # per-call session — see docstring

    @tool
    async def delete_notion_page(
        page_title: str,
        delete_from_kb: bool = False,
    ) -> dict[str, Any]:
        """Delete (archive) a Notion page.
        Use this tool when the user asks you to delete, remove, or archive
        a Notion page. Note that Notion doesn't permanently delete pages,
        it archives them (they can be restored from trash).
        Args:
        page_title: The title of the Notion page to delete.
        delete_from_kb: Whether to also remove the page from the knowledge base.
        Default is False.
        Set to True to permanently remove from both Notion and knowledge base.
        Returns:
        Dictionary with:
        - status: "success", "rejected", "not_found", or "error"
        - page_id: Deleted page ID (if success)
        - message: Success or error message
        - deleted_from_kb: Whether the page was also removed from knowledge base (if success)
        Examples:
        - "Delete the 'Meeting Notes' Notion page"
        - "Remove the 'Old Project Plan' Notion page"
        - "Archive the 'Draft Ideas' Notion page"
        """
        logger.info(
            f"delete_notion_page called: page_title='{page_title}', delete_from_kb={delete_from_kb}"
        )

        if search_space_id is None or user_id is None:
            logger.error(
                "Notion tool not properly configured - missing required parameters"
            )
            return {
                "status": "error",
                "message": "Notion tool not properly configured. Please contact support.",
            }

        # Bound before the try block so the exception handler below can report
        # it even when the failure happens before the context was fetched.
        connector_id_from_context = None

        try:
            async with async_session_maker() as db_session:
                # Get page context (page_id, account, title) from indexed data
                metadata_service = NotionToolMetadataService(db_session)
                context = await metadata_service.get_delete_context(
                    search_space_id, user_id, page_title
                )

                if "error" in context:
                    error_msg = context["error"]
                    # Check if it's a "not found" error (softer handling for LLM)
                    if "not found" in error_msg.lower():
                        logger.warning(f"Page not found: {error_msg}")
                        return {
                            "status": "not_found",
                            "message": error_msg,
                        }
                    logger.error(f"Failed to fetch delete context: {error_msg}")
                    return {
                        "status": "error",
                        "message": error_msg,
                    }

                account = context.get("account", {})
                connector_id_from_context = account.get("id")

                if account.get("auth_expired"):
                    logger.warning(
                        "Notion account %s has expired authentication",
                        connector_id_from_context,
                    )
                    # Same payload shape as the auth_error produced by the
                    # exception handler at the end of this function.
                    return {
                        "status": "auth_error",
                        "message": "The Notion account for this page needs re-authentication. Please re-authenticate in your connector settings.",
                        "connector_id": connector_id_from_context,
                        "connector_type": "notion",
                    }

                page_id = context.get("page_id")
                document_id = context.get("document_id")

                logger.info(
                    f"Requesting approval for deleting Notion page: '{page_title}' (page_id={page_id}, delete_from_kb={delete_from_kb})"
                )
                approval = request_approval(
                    action_type="notion_page_deletion",
                    tool_name="delete_notion_page",
                    params={
                        "page_id": page_id,
                        "connector_id": connector_id_from_context,
                        "delete_from_kb": delete_from_kb,
                    },
                    context=context,
                )

                if approval.rejected:
                    logger.info("Notion page deletion rejected by user")
                    return {
                        "status": "rejected",
                        "message": "User declined. Do not retry or suggest alternatives.",
                    }

                # The approver may have edited any of the parameters.
                final_page_id = approval.params.get("page_id", page_id)
                final_connector_id = approval.params.get(
                    "connector_id", connector_id_from_context
                )
                final_delete_from_kb = approval.params.get(
                    "delete_from_kb", delete_from_kb
                )

                logger.info(
                    f"Deleting Notion page with final params: page_id={final_page_id}, connector_id={final_connector_id}, delete_from_kb={final_delete_from_kb}"
                )

                from sqlalchemy.future import select

                from app.db import SearchSourceConnector, SearchSourceConnectorType

                # Validate the connector
                if not final_connector_id:
                    logger.error("No connector found for this page")
                    return {
                        "status": "error",
                        "message": "No connector found for this page.",
                    }

                lookup = await db_session.execute(
                    select(SearchSourceConnector).filter(
                        SearchSourceConnector.id == final_connector_id,
                        SearchSourceConnector.search_space_id == search_space_id,
                        SearchSourceConnector.user_id == user_id,
                        SearchSourceConnector.connector_type
                        == SearchSourceConnectorType.NOTION_CONNECTOR,
                    )
                )
                connector = lookup.scalars().first()
                if not connector:
                    logger.error(
                        f"Invalid connector_id={final_connector_id} for search_space_id={search_space_id}"
                    )
                    return {
                        "status": "error",
                        "message": "Selected Notion account is invalid or has been disconnected. Please select a valid account.",
                    }
                actual_connector_id = connector.id
                logger.info(f"Validated Notion connector: id={actual_connector_id}")

                # Create connector instance
                notion_connector = NotionHistoryConnector(
                    session=db_session,
                    connector_id=actual_connector_id,
                )

                # Delete the page from Notion
                deletion = await notion_connector.delete_page(page_id=final_page_id)
                logger.info(
                    f"delete_page result: {deletion.get('status')} - {deletion.get('message', '')}"
                )

                # If deletion was successful and user wants to delete from KB
                deleted_from_kb = False
                if (
                    deletion.get("status") == "success"
                    and final_delete_from_kb
                    and document_id
                ):
                    try:
                        from app.db import Document

                        # Get the document
                        doc_result = await db_session.execute(
                            select(Document).filter(Document.id == document_id)
                        )
                        document = doc_result.scalars().first()
                        if document:
                            await db_session.delete(document)
                            await db_session.commit()
                            deleted_from_kb = True
                            logger.info(
                                f"Deleted document {document_id} from knowledge base"
                            )
                        else:
                            logger.warning(f"Document {document_id} not found in KB")
                    except Exception as kb_err:
                        # The Notion page is already archived at this point, so
                        # report the KB failure as a warning, not as an error.
                        logger.error(f"Failed to delete document from KB: {kb_err}")
                        await db_session.rollback()
                        deletion["warning"] = (
                            f"Page deleted from Notion, but failed to remove from knowledge base: {kb_err!s}"
                        )

                # Update result with KB deletion status
                if deletion.get("status") == "success":
                    deletion["deleted_from_kb"] = deleted_from_kb
                    if deleted_from_kb:
                        deletion["message"] = (
                            f"{deletion.get('message', '')} (also removed from knowledge base)"
                        )

                return deletion

        except Exception as e:
            from langgraph.errors import GraphInterrupt

            # The approval interrupt must reach the graph runtime untouched.
            if isinstance(e, GraphInterrupt):
                raise

            logger.error(f"Error deleting Notion page: {e}", exc_info=True)
            error_str = str(e).lower()
            if isinstance(e, NotionAPIError) and (
                "401" in error_str or "unauthorized" in error_str
            ):
                return {
                    "status": "auth_error",
                    "message": str(e),
                    "connector_id": connector_id_from_context,
                    "connector_type": "notion",
                }
            if isinstance(e, ValueError | NotionAPIError):
                message = str(e)
            else:
                message = (
                    "Something went wrong while deleting the page. Please try again."
                )
            return {"status": "error", "message": message}

    return delete_notion_page

View file

@ -0,0 +1,276 @@
import logging
from typing import Any
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.shared.tools.hitl import request_approval
from app.connectors.notion_history import NotionAPIError, NotionHistoryConnector
from app.db import async_session_maker
from app.services.notion import NotionToolMetadataService
logger = logging.getLogger(__name__)
def create_update_notion_page_tool(
    db_session: AsyncSession | None = None,
    search_space_id: int | None = None,
    user_id: str | None = None,
    connector_id: int | None = None,
):
    """
    Factory function to create the update_notion_page tool.

    The tool acquires its own short-lived ``AsyncSession`` per call via
    :data:`async_session_maker` so the closure is safe to share across
    HTTP requests by the compiled-agent cache; a captured per-request
    session would be stale or closed on cache hits.

    Args:
        db_session: Reserved for registry compatibility. Per-call sessions
            are opened via :data:`async_session_maker` inside the tool body.
        search_space_id: Search space ID to find the Notion connector
        user_id: User ID for fetching user-specific context
        connector_id: Accepted for signature parity with the other Notion
            factories; the tool body does not read it. The connector comes
            from the page's update context and the approval parameters.

    Returns:
        Configured update_notion_page tool
    """
    del db_session  # per-call session — see docstring

    @tool
    async def update_notion_page(
        page_title: str,
        content: str | None = None,
    ) -> dict[str, Any]:
        """Update an existing Notion page by appending new content.
        Use this tool when the user asks you to add content to, modify, or update
        a Notion page. The new content will be appended to the existing page content.
        The user MUST specify what to add before you call this tool. If the
        request is vague, ask what content they want added.
        Args:
        page_title: The title of the Notion page to update.
        content: Optional markdown content to append to the page body (supports headings, lists, paragraphs).
        Generate this yourself based on the user's request.
        Returns:
        Dictionary with:
        - status: "success", "rejected", "not_found", or "error"
        - page_id: Updated page ID (if success)
        - url: URL to the updated page (if success)
        - title: Current page title (if success)
        - message: Result message
        IMPORTANT:
        - If status is "rejected", the user explicitly declined the action.
        Respond with a brief acknowledgment (e.g., "Understood, I didn't update the page.")
        and move on. Do NOT ask for alternatives or troubleshoot.
        - If status is "not_found", inform the user conversationally using the exact message provided.
        Example: "I couldn't find the page '[page_title]' in your indexed Notion pages. [message details]"
        Do NOT treat this as an error. Do NOT invent information. Simply relay the message and
        ask the user to verify the page title or check if it's been indexed.
        Examples:
        - "Add today's meeting notes to the 'Meeting Notes' Notion page"
        - "Update the 'Project Plan' page with a status update on phase 1"
        """
        logger.info(
            f"update_notion_page called: page_title='{page_title}', content_length={len(content) if content else 0}"
        )

        if search_space_id is None or user_id is None:
            logger.error(
                "Notion tool not properly configured - missing required parameters"
            )
            return {
                "status": "error",
                "message": "Notion tool not properly configured. Please contact support.",
            }

        if not content or not content.strip():
            logger.error(f"Empty content provided for page '{page_title}'")
            return {
                "status": "error",
                "message": "Content is required to update the page. Please provide the actual content you want to add.",
            }

        # Bound before the try block so the exception handler below can report
        # it even when the failure happens before the context was fetched.
        connector_id_from_context = None

        try:
            async with async_session_maker() as db_session:
                metadata_service = NotionToolMetadataService(db_session)
                context = await metadata_service.get_update_context(
                    search_space_id, user_id, page_title
                )

                if "error" in context:
                    error_msg = context["error"]
                    # Check if it's a "not found" error (softer handling for LLM)
                    if "not found" in error_msg.lower():
                        logger.warning(f"Page not found: {error_msg}")
                        return {
                            "status": "not_found",
                            "message": error_msg,
                        }
                    logger.error(f"Failed to fetch update context: {error_msg}")
                    return {
                        "status": "error",
                        "message": error_msg,
                    }

                account = context.get("account", {})
                connector_id_from_context = account.get("id")

                if account.get("auth_expired"):
                    logger.warning(
                        "Notion account %s has expired authentication",
                        connector_id_from_context,
                    )
                    # Same payload shape as the auth_error produced by the
                    # exception handler at the end of this function.
                    return {
                        "status": "auth_error",
                        "message": "The Notion account for this page needs re-authentication. Please re-authenticate in your connector settings.",
                        "connector_id": connector_id_from_context,
                        "connector_type": "notion",
                    }

                page_id = context.get("page_id")
                document_id = context.get("document_id")

                logger.info(
                    f"Requesting approval for updating Notion page: '{page_title}' (page_id={page_id})"
                )
                approval = request_approval(
                    action_type="notion_page_update",
                    tool_name="update_notion_page",
                    params={
                        "page_id": page_id,
                        "content": content,
                        "connector_id": connector_id_from_context,
                    },
                    context=context,
                )

                if approval.rejected:
                    logger.info("Notion page update rejected by user")
                    return {
                        "status": "rejected",
                        "message": "User declined. Do not retry or suggest alternatives.",
                    }

                # The approver may have edited any of the parameters.
                final_page_id = approval.params.get("page_id", page_id)
                final_content = approval.params.get("content", content)
                final_connector_id = approval.params.get(
                    "connector_id", connector_id_from_context
                )

                logger.info(
                    f"Updating Notion page with final params: page_id={final_page_id}, has_content={final_content is not None}"
                )

                from sqlalchemy.future import select

                from app.db import SearchSourceConnector, SearchSourceConnectorType

                if not final_connector_id:
                    logger.error("No connector found for this page")
                    return {
                        "status": "error",
                        "message": "No connector found for this page.",
                    }

                lookup = await db_session.execute(
                    select(SearchSourceConnector).filter(
                        SearchSourceConnector.id == final_connector_id,
                        SearchSourceConnector.search_space_id == search_space_id,
                        SearchSourceConnector.user_id == user_id,
                        SearchSourceConnector.connector_type
                        == SearchSourceConnectorType.NOTION_CONNECTOR,
                    )
                )
                connector = lookup.scalars().first()
                if not connector:
                    logger.error(
                        f"Invalid connector_id={final_connector_id} for search_space_id={search_space_id}"
                    )
                    return {
                        "status": "error",
                        "message": "Selected Notion account is invalid or has been disconnected. Please select a valid account.",
                    }
                actual_connector_id = connector.id
                logger.info(f"Validated Notion connector: id={actual_connector_id}")

                notion_connector = NotionHistoryConnector(
                    session=db_session,
                    connector_id=actual_connector_id,
                )
                update_result = await notion_connector.update_page(
                    page_id=final_page_id,
                    content=final_content,
                )
                logger.info(
                    f"update_page result: {update_result.get('status')} - {update_result.get('message', '')}"
                )

                if update_result.get("status") == "success" and document_id is not None:
                    # The content is already appended in Notion. A KB sync
                    # failure must not be reported as a failed update, or the
                    # model may retry and append the same content twice.
                    kb_note = "Your knowledge base will be updated in the next scheduled sync."
                    try:
                        from app.services.notion import NotionKBSyncService

                        logger.info(
                            f"Updating knowledge base for document {document_id}..."
                        )
                        kb_service = NotionKBSyncService(db_session)
                        kb_result = await kb_service.sync_after_update(
                            document_id=document_id,
                            appended_content=final_content,
                            user_id=user_id,
                            search_space_id=search_space_id,
                            appended_block_ids=update_result.get("appended_block_ids"),
                        )

                        if kb_result["status"] == "success":
                            kb_note = "Your knowledge base has also been updated."
                            logger.info(
                                f"Knowledge base successfully updated for page {final_page_id}"
                            )
                        elif kb_result["status"] == "not_indexed":
                            kb_note = "This page will be added to your knowledge base in the next scheduled sync."
                        else:
                            logger.warning(
                                f"KB update failed for page {final_page_id}: {kb_result.get('message')}"
                            )
                    except Exception as kb_err:
                        logger.warning(
                            f"KB sync after update failed for page {final_page_id}: {kb_err}",
                            exc_info=True,
                        )

                    base_message = update_result.get("message", "")
                    update_result["message"] = (
                        f"{base_message}. {kb_note}" if base_message else kb_note
                    )

                return update_result

        except Exception as e:
            from langgraph.errors import GraphInterrupt

            # The approval interrupt must reach the graph runtime untouched.
            if isinstance(e, GraphInterrupt):
                raise

            logger.error(f"Error updating Notion page: {e}", exc_info=True)
            error_str = str(e).lower()
            if isinstance(e, NotionAPIError) and (
                "401" in error_str or "unauthorized" in error_str
            ):
                return {
                    "status": "auth_error",
                    "message": str(e),
                    "connector_id": connector_id_from_context,
                    "connector_type": "notion",
                }
            if isinstance(e, ValueError | NotionAPIError):
                message = str(e)
            else:
                message = (
                    "Something went wrong while updating the page. Please try again."
                )
            return {"status": "error", "message": message}

    return update_notion_page

View file

@ -0,0 +1,11 @@
from app.agents.shared.tools.onedrive.create_file import (
create_create_onedrive_file_tool,
)
from app.agents.shared.tools.onedrive.trash_file import (
create_delete_onedrive_file_tool,
)
__all__ = [
"create_create_onedrive_file_tool",
"create_delete_onedrive_file_tool",
]

View file

@ -0,0 +1,274 @@
import logging
import os
import tempfile
from pathlib import Path
from typing import Any
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.agents.shared.tools.hitl import request_approval
from app.connectors.onedrive.client import OneDriveClient
from app.db import SearchSourceConnector, SearchSourceConnectorType, async_session_maker
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# MIME type for Word (.docx) documents; sent with the uploaded bytes and
# forwarded to the knowledge-base sync after a file is created.
DOCX_MIME = "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
def _ensure_docx_extension(name: str) -> str:
"""Strip any existing extension and append .docx."""
stem = Path(name).stem
return f"{stem}.docx"
def _markdown_to_docx(markdown_text: str) -> bytes:
    """Render GitHub-flavoured markdown to a Word document and return its bytes.

    pypandoc writes its output to a file, so the conversion goes through a
    scratch ``.docx`` file that is always removed afterwards.
    """
    import pypandoc

    # mkstemp hands back an open descriptor; close it right away because
    # pandoc opens the path itself.
    handle, scratch_path = tempfile.mkstemp(suffix=".docx")
    os.close(handle)
    try:
        pypandoc.convert_text(
            markdown_text,
            "docx",
            format="gfm",
            extra_args=["--standalone"],
            outputfile=scratch_path,
        )
        return Path(scratch_path).read_bytes()
    finally:
        os.unlink(scratch_path)
def create_create_onedrive_file_tool(
    db_session: AsyncSession | None = None,
    search_space_id: int | None = None,
    user_id: str | None = None,
):
    """
    Factory function to create the create_onedrive_file tool.

    The tool acquires its own short-lived ``AsyncSession`` per call via
    :data:`async_session_maker` so the closure is safe to share across
    HTTP requests by the compiled-agent cache. Capturing a per-request
    session here would surface stale/closed sessions on cache hits.

    Args:
        db_session: Reserved for registry compatibility. Per-call sessions
            are opened via :data:`async_session_maker` inside the tool body.
        search_space_id: Search space whose OneDrive connectors may be used.
            When ``None`` the tool returns a configuration error.
        user_id: Owner of the connectors. When ``None`` the tool returns a
            configuration error.

    Returns:
        Configured create_onedrive_file tool
    """
    # Function-scope import: keeps this block self-contained without
    # touching the file-level import section.
    import asyncio

    del db_session  # per-call session — see docstring

    @tool
    async def create_onedrive_file(
        name: str,
        content: str | None = None,
    ) -> dict[str, Any]:
        """Create a new Word document (.docx) in Microsoft OneDrive.

        Use this tool when the user explicitly asks to create a new document
        in OneDrive. The user MUST specify a topic before you call this tool.

        The file is always saved as a .docx Word document. Provide content as
        markdown and it will be automatically converted to a formatted Word file.

        Args:
            name: The document title (without extension). Extension will be set to .docx automatically.
            content: Optional initial content as markdown. Will be converted to a formatted Word document.

        Returns:
            Dictionary with status, file_id, name, web_url, and message.
        """
        logger.info(f"create_onedrive_file called: name='{name}'")

        if search_space_id is None or user_id is None:
            return {
                "status": "error",
                "message": "OneDrive tool not properly configured.",
            }

        try:
            async with async_session_maker() as db_session:
                # 1. Collect every OneDrive connector the user owns in this space.
                connector_rows = await db_session.execute(
                    select(SearchSourceConnector).filter(
                        SearchSourceConnector.search_space_id == search_space_id,
                        SearchSourceConnector.user_id == user_id,
                        SearchSourceConnector.connector_type
                        == SearchSourceConnectorType.ONEDRIVE_CONNECTOR,
                    )
                )
                connectors = connector_rows.scalars().all()

                if not connectors:
                    return {
                        "status": "error",
                        "message": "No OneDrive connector found. Please connect OneDrive in your workspace settings.",
                    }

                accounts = []
                for c in connectors:
                    cfg = c.config or {}
                    accounts.append(
                        {
                            "id": c.id,
                            "name": c.name,
                            "user_email": cfg.get("user_email"),
                            "auth_expired": cfg.get("auth_expired", False),
                        }
                    )

                if all(a.get("auth_expired") for a in accounts):
                    return {
                        "status": "auth_error",
                        "message": "All connected OneDrive accounts need re-authentication.",
                        "connector_type": "onedrive",
                    }

                # 2. Best-effort listing of root folders per account so the
                #    approval UI can offer a destination. Failures degrade to
                #    an empty list instead of aborting the tool.
                parent_folders: dict[int, list[dict[str, str]]] = {}
                for acc in accounts:
                    cid = acc["id"]
                    if acc.get("auth_expired"):
                        parent_folders[cid] = []
                        continue
                    try:
                        client = OneDriveClient(session=db_session, connector_id=cid)
                        items, err = await client.list_children("root")
                        if err:
                            logger.warning(
                                "Failed to list folders for connector %s: %s", cid, err
                            )
                            parent_folders[cid] = []
                        else:
                            parent_folders[cid] = [
                                {"folder_id": item["id"], "name": item["name"]}
                                for item in items
                                if item.get("folder") is not None
                                and item.get("id")
                                and item.get("name")
                            ]
                    except Exception:
                        logger.warning(
                            "Error fetching folders for connector %s",
                            cid,
                            exc_info=True,
                        )
                        parent_folders[cid] = []

                context: dict[str, Any] = {
                    "accounts": accounts,
                    "parent_folders": parent_folders,
                }

                # 3. Human-in-the-loop approval; the reviewer may edit params.
                approval = request_approval(
                    action_type="onedrive_file_creation",
                    tool_name="create_onedrive_file",
                    params={
                        "name": name,
                        "content": content,
                        "connector_id": None,
                        "parent_folder_id": None,
                    },
                    context=context,
                )

                if approval.rejected:
                    return {
                        "status": "rejected",
                        "message": "User declined. Do not retry or suggest alternatives.",
                    }

                final_name = approval.params.get("name", name)
                final_content = approval.params.get("content", content)
                final_connector_id = approval.params.get("connector_id")
                final_parent_folder_id = approval.params.get("parent_folder_id")

                if not final_name or not final_name.strip():
                    return {"status": "error", "message": "File name cannot be empty."}

                final_name = _ensure_docx_extension(final_name)

                # 4. Resolve the connector to write with.
                if final_connector_id is not None:
                    # Re-validate ownership of the id that came back from the UI.
                    selected = await db_session.execute(
                        select(SearchSourceConnector).filter(
                            SearchSourceConnector.id == final_connector_id,
                            SearchSourceConnector.search_space_id == search_space_id,
                            SearchSourceConnector.user_id == user_id,
                            SearchSourceConnector.connector_type
                            == SearchSourceConnectorType.ONEDRIVE_CONNECTOR,
                        )
                    )
                    connector = selected.scalars().first()
                else:
                    # No explicit choice: prefer the first account that does not
                    # need re-authentication. At least one exists (the all-expired
                    # case returned above); connectors[0] is only a safety net.
                    connector = next(
                        (
                            c
                            for c in connectors
                            if not (c.config or {}).get("auth_expired", False)
                        ),
                        connectors[0],
                    )

                if not connector:
                    return {
                        "status": "error",
                        "message": "Selected OneDrive connector is invalid.",
                    }

                # The pandoc conversion is blocking; run it in a worker thread
                # so the event loop keeps serving other requests meanwhile.
                docx_bytes = await asyncio.to_thread(
                    _markdown_to_docx, final_content or ""
                )

                client = OneDriveClient(session=db_session, connector_id=connector.id)
                created = await client.create_file(
                    name=final_name,
                    parent_id=final_parent_folder_id,
                    content=docx_bytes,
                    mime_type=DOCX_MIME,
                )

                logger.info(
                    f"OneDrive file created: id={created.get('id')}, name={created.get('name')}"
                )

                # 5. Knowledge-base sync is best-effort: the file already
                #    exists in OneDrive, so a sync failure only changes the
                #    wording of the success message.
                kb_message_suffix = ""
                try:
                    from app.services.onedrive import OneDriveKBSyncService

                    kb_service = OneDriveKBSyncService(db_session)
                    kb_result = await kb_service.sync_after_create(
                        file_id=created.get("id"),
                        file_name=created.get("name", final_name),
                        mime_type=DOCX_MIME,
                        web_url=created.get("webUrl"),
                        content=final_content,
                        connector_id=connector.id,
                        search_space_id=search_space_id,
                        user_id=user_id,
                    )
                    if kb_result["status"] == "success":
                        kb_message_suffix = (
                            " Your knowledge base has also been updated."
                        )
                    else:
                        kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."
                except Exception as kb_err:
                    logger.warning(f"KB sync after create failed: {kb_err}")
                    kb_message_suffix = " This file will be added to your knowledge base in the next scheduled sync."

                return {
                    "status": "success",
                    "file_id": created.get("id"),
                    "name": created.get("name"),
                    "web_url": created.get("webUrl"),
                    "message": f"Successfully created '{created.get('name')}' in OneDrive.{kb_message_suffix}",
                }

        except Exception as e:
            from langgraph.errors import GraphInterrupt

            # The approval request pauses the graph by raising; let it through.
            if isinstance(e, GraphInterrupt):
                raise
            logger.error(f"Error creating OneDrive file: {e}", exc_info=True)
            return {
                "status": "error",
                "message": "Something went wrong while creating the file. Please try again.",
            }

    return create_onedrive_file

View file

@ -0,0 +1,305 @@
import logging
from typing import Any
from langchain_core.tools import tool
from sqlalchemy import String, and_, cast, func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.agents.shared.tools.hitl import request_approval
from app.connectors.onedrive.client import OneDriveClient
from app.db import (
Document,
DocumentType,
SearchSourceConnector,
SearchSourceConnectorType,
async_session_maker,
)
logger = logging.getLogger(__name__)
def create_delete_onedrive_file_tool(
    db_session: AsyncSession | None = None,
    search_space_id: int | None = None,
    user_id: str | None = None,
):
    """
    Factory function to create the delete_onedrive_file tool.

    The tool acquires its own short-lived ``AsyncSession`` per call via
    :data:`async_session_maker` so the closure is safe to share across
    HTTP requests by the compiled-agent cache. Capturing a per-request
    session here would surface stale/closed sessions on cache hits.

    Args:
        db_session: Reserved for registry compatibility. Per-call sessions
            are opened via :data:`async_session_maker` inside the tool body.
        search_space_id: Search space the indexed document must belong to.
            When ``None`` the tool returns a configuration error.
        user_id: Owner of the connector the document was indexed through.
            When ``None`` the tool returns a configuration error.

    Returns:
        Configured delete_onedrive_file tool
    """
    del db_session  # per-call session — see docstring

    async def _find_indexed_document(
        session: AsyncSession, name_expr: Any, wanted_name: str
    ):
        """Return the most recently updated OneDrive document whose
        ``name_expr`` equals ``wanted_name`` case-insensitively, or ``None``."""
        rows = await session.execute(
            select(Document)
            .join(
                SearchSourceConnector,
                Document.connector_id == SearchSourceConnector.id,
            )
            .filter(
                and_(
                    Document.search_space_id == search_space_id,
                    Document.document_type == DocumentType.ONEDRIVE_FILE,
                    func.lower(name_expr) == func.lower(wanted_name),
                    SearchSourceConnector.user_id == user_id,
                )
            )
            .order_by(Document.updated_at.desc().nullslast())
            .limit(1)
        )
        return rows.scalars().first()

    @tool
    async def delete_onedrive_file(
        file_name: str,
        delete_from_kb: bool = False,
    ) -> dict[str, Any]:
        """Move a OneDrive file to the recycle bin.

        Use this tool when the user explicitly asks to delete, remove, or trash
        a file in OneDrive.

        Args:
            file_name: The exact name of the file to trash.
            delete_from_kb: Whether to also remove the file from the knowledge base.
                          Default is False.
                          Set to True to remove from both OneDrive and knowledge base.

        Returns:
            Dictionary with:
            - status: "success", "rejected", "not_found", or "error"
            - file_id: OneDrive file ID (if success)
            - deleted_from_kb: whether the document was removed from the knowledge base
            - message: Result message

            IMPORTANT:
            - If status is "rejected", the user explicitly declined. Respond with a brief
              acknowledgment and do NOT retry or suggest alternatives.
            - If status is "not_found", relay the exact message to the user and ask them
              to verify the file name or check if it has been indexed.
        """
        logger.info(
            f"delete_onedrive_file called: file_name='{file_name}', delete_from_kb={delete_from_kb}"
        )

        if search_space_id is None or user_id is None:
            return {
                "status": "error",
                "message": "OneDrive tool not properly configured.",
            }

        try:
            async with async_session_maker() as db_session:
                # 1. Locate the indexed document: first by title, then by the
                #    original OneDrive file name stored in the metadata.
                document = await _find_indexed_document(
                    db_session, Document.title, file_name
                )
                if not document:
                    # ``as_string()`` extracts the JSON element as plain text.
                    # A bare CAST of the JSON element yields the quoted JSON
                    # representation, which would never equal the file name.
                    # NOTE(review): assumes ``document_metadata`` is a
                    # SQLAlchemy JSON/JSONB column — confirm against the model.
                    document = await _find_indexed_document(
                        db_session,
                        Document.document_metadata["onedrive_file_name"].as_string(),
                        file_name,
                    )

                if not document:
                    return {
                        "status": "not_found",
                        "message": (
                            f"File '{file_name}' not found in your indexed OneDrive files. "
                            "This could mean: (1) the file doesn't exist, (2) it hasn't been indexed yet, "
                            "or (3) the file name is different."
                        ),
                    }

                if not document.connector_id:
                    return {
                        "status": "error",
                        "message": "Document has no associated connector.",
                    }

                meta = document.document_metadata or {}
                file_id = meta.get("onedrive_file_id")
                document_id = document.id

                if not file_id:
                    return {
                        "status": "error",
                        "message": "File ID is missing. Please re-index the file.",
                    }

                # 2. Make sure the connector behind the document still belongs
                #    to this user and search space.
                conn_result = await db_session.execute(
                    select(SearchSourceConnector).filter(
                        and_(
                            SearchSourceConnector.id == document.connector_id,
                            SearchSourceConnector.search_space_id == search_space_id,
                            SearchSourceConnector.user_id == user_id,
                            SearchSourceConnector.connector_type
                            == SearchSourceConnectorType.ONEDRIVE_CONNECTOR,
                        )
                    )
                )
                connector = conn_result.scalars().first()
                if not connector:
                    return {
                        "status": "error",
                        "message": "OneDrive connector not found or access denied.",
                    }

                cfg = connector.config or {}
                if cfg.get("auth_expired"):
                    return {
                        "status": "auth_error",
                        "message": "OneDrive account needs re-authentication. Please re-authenticate in your connector settings.",
                        "connector_type": "onedrive",
                    }

                context = {
                    "file": {
                        "file_id": file_id,
                        "name": file_name,
                        "document_id": document_id,
                        "web_url": meta.get("web_url"),
                    },
                    "account": {
                        "id": connector.id,
                        "name": connector.name,
                        "user_email": cfg.get("user_email"),
                    },
                }

                # 3. Human-in-the-loop approval; the reviewer may edit params.
                approval = request_approval(
                    action_type="onedrive_file_trash",
                    tool_name="delete_onedrive_file",
                    params={
                        "file_id": file_id,
                        "connector_id": connector.id,
                        "delete_from_kb": delete_from_kb,
                    },
                    context=context,
                )

                if approval.rejected:
                    return {
                        "status": "rejected",
                        "message": "User declined. Do not retry or suggest alternatives.",
                    }

                final_file_id = approval.params.get("file_id", file_id)
                final_connector_id = approval.params.get("connector_id", connector.id)
                final_delete_from_kb = approval.params.get(
                    "delete_from_kb", delete_from_kb
                )

                # 4. If the reviewer switched accounts, re-validate ownership
                #    of the newly selected connector before using it.
                if final_connector_id != connector.id:
                    override_result = await db_session.execute(
                        select(SearchSourceConnector).filter(
                            and_(
                                SearchSourceConnector.id == final_connector_id,
                                SearchSourceConnector.search_space_id
                                == search_space_id,
                                SearchSourceConnector.user_id == user_id,
                                SearchSourceConnector.connector_type
                                == SearchSourceConnectorType.ONEDRIVE_CONNECTOR,
                            )
                        )
                    )
                    validated_connector = override_result.scalars().first()
                    if not validated_connector:
                        return {
                            "status": "error",
                            "message": "Selected OneDrive connector is invalid or has been disconnected.",
                        }
                    actual_connector_id = validated_connector.id
                else:
                    actual_connector_id = connector.id

                logger.info(
                    f"Deleting OneDrive file: file_id='{final_file_id}', connector={actual_connector_id}"
                )

                client = OneDriveClient(
                    session=db_session, connector_id=actual_connector_id
                )
                await client.trash_file(final_file_id)

                logger.info(
                    f"OneDrive file deleted (moved to recycle bin): file_id={final_file_id}"
                )

                trash_result: dict[str, Any] = {
                    "status": "success",
                    "file_id": final_file_id,
                    "message": f"Successfully moved '{file_name}' to the recycle bin.",
                }

                # 5. Optional knowledge-base cleanup. The file is already in
                #    the recycle bin, so a failure here is reported as a
                #    warning on an otherwise successful result.
                deleted_from_kb = False
                if final_delete_from_kb and document_id:
                    try:
                        doc_result = await db_session.execute(
                            select(Document).filter(Document.id == document_id)
                        )
                        doc = doc_result.scalars().first()
                        if doc:
                            await db_session.delete(doc)
                            await db_session.commit()
                            deleted_from_kb = True
                            logger.info(
                                f"Deleted document {document_id} from knowledge base"
                            )
                        else:
                            logger.warning(f"Document {document_id} not found in KB")
                    except Exception as e:
                        logger.error(f"Failed to delete document from KB: {e}")
                        await db_session.rollback()
                        trash_result["warning"] = (
                            f"File moved to recycle bin, but failed to remove from knowledge base: {e!s}"
                        )

                trash_result["deleted_from_kb"] = deleted_from_kb
                if deleted_from_kb:
                    trash_result["message"] = (
                        f"{trash_result.get('message', '')} (also removed from knowledge base)"
                    )

                return trash_result

        except Exception as e:
            from langgraph.errors import GraphInterrupt

            # The approval request pauses the graph by raising; let it through.
            if isinstance(e, GraphInterrupt):
                raise
            logger.error(f"Error deleting OneDrive file: {e}", exc_info=True)
            return {
                "status": "error",
                "message": "Something went wrong while trashing the file. Please try again.",
            }

    return delete_onedrive_file

View file

@ -0,0 +1,160 @@
"""
Podcast generation tool for the SurfSense agent.
This module provides a factory function for creating the generate_podcast tool
that submits a Celery task for background podcast generation. The tool then
polls the podcast row until it reaches a terminal status (READY/FAILED) and
returns that status. The wait is bounded by the chat's HTTP / process
lifetime; see app.agents.shared.deliverable_wait for details.
"""
import logging
from typing import Any
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.shared.deliverable_wait import wait_for_deliverable
from app.db import Podcast, PodcastStatus, shielded_async_session
logger = logging.getLogger(__name__)
def create_generate_podcast_tool(
    search_space_id: int,
    db_session: AsyncSession,
    thread_id: int | None = None,
):
    """
    Factory function to create the generate_podcast tool with injected dependencies.

    The tool pre-creates a podcast row with pending status, submits the
    Celery generation task, and then waits until the row reaches a terminal
    status (READY or FAILED) before returning.

    Args:
        search_space_id: The user's search space ID
        db_session: Reserved for future read-side use; the row is written via a
            fresh, tool-local session so parallel tool calls (e.g. podcast +
            video presentation in the same agent step) don't share an
            ``AsyncSession`` (which is not concurrency-safe).
        thread_id: The chat thread ID for associating the podcast

    Returns:
        A configured tool function for generating podcasts
    """
    del db_session  # writes use a fresh tool-local session, see below

    @tool
    async def generate_podcast(
        source_content: str,
        podcast_title: str = "SurfSense Podcast",
        user_prompt: str | None = None,
    ) -> dict[str, Any]:
        """
        Generate a podcast from the provided content.

        Use this tool when the user asks to create, generate, or make a podcast.
        Common triggers include phrases like:
        - "Give me a podcast about this"
        - "Create a podcast from this conversation"
        - "Generate a podcast summary"
        - "Make a podcast about..."
        - "Turn this into a podcast"

        Args:
            source_content: The text content to convert into a podcast.
            podcast_title: Title for the podcast (default: "SurfSense Podcast")
            user_prompt: Optional instructions for podcast style, tone, or format.

        Returns:
            A dictionary containing:
            - status: Final PodcastStatus value (ready or failed); the tool waits
              for generation to finish before returning
            - podcast_id: The podcast ID (None if an unexpected error occurred)
            - title: The podcast title
            - file_location: Location of the generated audio (when status is ready)
            - message: Status message (or "error" field if status is failed)
        """
        try:
            # Open a fresh session per call. The streaming task's session is
            # shared between every tool, and ``AsyncSession`` is NOT safe for
            # concurrent use: when the LLM emits parallel tool calls, two
            # concurrent ``add()`` / ``commit()`` paths interleave and the
            # second one hits "Session.add() during flush" → the transaction
            # is poisoned for both tools.
            async with shielded_async_session() as session:
                podcast = Podcast(
                    title=podcast_title,
                    status=PodcastStatus.PENDING,
                    search_space_id=search_space_id,
                    thread_id=thread_id,
                )
                session.add(podcast)
                await session.commit()
                await session.refresh(podcast)
                podcast_id = podcast.id

            # Imported here rather than at module load time.
            from app.tasks.celery_tasks.podcast_tasks import (
                generate_content_podcast_task,
            )

            task = generate_content_podcast_task.delay(
                podcast_id=podcast_id,
                source_content=source_content,
                search_space_id=search_space_id,
                user_prompt=user_prompt,
            )

            logger.info(
                "[generate_podcast] Created podcast %s, task: %s",
                podcast_id,
                task.id,
            )

            # Wait until the Celery worker flips the row to a terminal
            # state. No internal budget — see deliverable_wait module.
            terminal_status, columns, elapsed = await wait_for_deliverable(
                model=Podcast,
                row_id=podcast_id,
                columns=[Podcast.status, Podcast.file_location],
                terminal_statuses={PodcastStatus.READY, PodcastStatus.FAILED},
            )

            if terminal_status == PodcastStatus.READY:
                # columns follows the order requested above, so index 1 is
                # Podcast.file_location.
                file_location = columns[1] if columns else None
                logger.info(
                    "[generate_podcast] Podcast %s READY in %.2fs (file=%s)",
                    podcast_id,
                    elapsed,
                    file_location,
                )
                return {
                    "status": PodcastStatus.READY.value,
                    "podcast_id": podcast_id,
                    "title": podcast_title,
                    "file_location": file_location,
                    "message": ("Podcast generated and saved to your podcast panel."),
                }

            # Only other terminal state is FAILED.
            logger.warning(
                "[generate_podcast] Podcast %s FAILED in %.2fs",
                podcast_id,
                elapsed,
            )
            return {
                "status": PodcastStatus.FAILED.value,
                "podcast_id": podcast_id,
                "title": podcast_title,
                "error": ("Background worker reported FAILED status for this podcast."),
            }

        except Exception as e:
            error_message = str(e)
            logger.exception("[generate_podcast] Error: %s", error_message)
            return {
                "status": PodcastStatus.FAILED.value,
                "error": error_message,
                "title": podcast_title,
                "podcast_id": None,
            }

    return generate_podcast

View file

@ -0,0 +1,962 @@
"""Tools registry for SurfSense deep agent.
This module provides a registry pattern for managing tools in the SurfSense agent.
It makes it easy for OSS contributors to add new tools by:
1. Creating a tool factory function in a new file in this directory
2. Registering the tool in the BUILTIN_TOOLS list below
Example of adding a new tool:
------------------------------
1. Create your tool file (e.g., `tools/my_tool.py`):
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
def create_my_tool(search_space_id: int, db_session: AsyncSession):
@tool
async def my_tool(param: str) -> dict:
'''My tool description.'''
# Your implementation
return {"result": "success"}
return my_tool
2. Import and register in this file:
from .my_tool import create_my_tool
# Add to BUILTIN_TOOLS list:
ToolDefinition(
name="my_tool",
description="Description of what your tool does",
factory=lambda deps: create_my_tool(
search_space_id=deps["search_space_id"],
db_session=deps["db_session"],
),
requires=["search_space_id", "db_session"],
),
"""
import logging
from collections.abc import Callable
from dataclasses import dataclass, field
from typing import Any
from langchain_core.tools import BaseTool
from app.agents.shared.middleware.dedup_tool_calls import (
wrap_dedup_key_by_arg_name,
)
from app.db import ChatVisibility
from .confluence import (
create_create_confluence_page_tool,
create_delete_confluence_page_tool,
create_update_confluence_page_tool,
)
from .connected_accounts import create_get_connected_accounts_tool
from .discord import (
create_list_discord_channels_tool,
create_read_discord_messages_tool,
create_send_discord_message_tool,
)
from .dropbox import (
create_create_dropbox_file_tool,
create_delete_dropbox_file_tool,
)
from .generate_image import create_generate_image_tool
from .gmail import (
create_create_gmail_draft_tool,
create_read_gmail_email_tool,
create_search_gmail_tool,
create_send_gmail_email_tool,
create_trash_gmail_email_tool,
create_update_gmail_draft_tool,
)
from .google_calendar import (
create_create_calendar_event_tool,
create_delete_calendar_event_tool,
create_search_calendar_events_tool,
create_update_calendar_event_tool,
)
from .google_drive import (
create_create_google_drive_file_tool,
create_delete_google_drive_file_tool,
)
from .luma import (
create_create_luma_event_tool,
create_list_luma_events_tool,
create_read_luma_event_tool,
)
from .mcp_tool import load_mcp_tools
from .notion import (
create_create_notion_page_tool,
create_delete_notion_page_tool,
create_update_notion_page_tool,
)
from .onedrive import (
create_create_onedrive_file_tool,
create_delete_onedrive_file_tool,
)
from .podcast import create_generate_podcast_tool
from .report import create_generate_report_tool
from .resume import create_generate_resume_tool
from .scrape_webpage import create_scrape_webpage_tool
from .teams import (
create_list_teams_channels_tool,
create_read_teams_messages_tool,
create_send_teams_message_tool,
)
from .update_memory import create_update_memory_tool, create_update_team_memory_tool
from .video_presentation import create_generate_video_presentation_tool
from .web_search import create_web_search_tool
logger = logging.getLogger(__name__)
# =============================================================================
# Tool Definition
# =============================================================================
@dataclass
class ToolDefinition:
    """Definition of a tool that can be added to the agent.

    Entries of this type make up the ``BUILTIN_TOOLS`` registry list; each
    one pairs a tool name with the factory that builds it.

    Attributes:
        name: Unique identifier for the tool
        description: Human-readable description of what the tool does
        factory: Callable that creates the tool. Receives a dict of dependencies.
        requires: List of dependency names this tool needs (e.g., "search_space_id", "db_session")
        enabled_by_default: Whether the tool is enabled when no explicit config is provided
        hidden: Defaults to ``False``. NOTE(review): not read anywhere in the
            visible part of this module; presumably keeps the tool out of
            user-facing tool listings — confirm against the consumers.
        required_connector: Searchable type string (e.g. ``"LINEAR_CONNECTOR"``)
            that must be in ``available_connectors`` for the tool to be enabled.
        dedup_key: Optional callable that maps a tool's ``args`` dict to a
            string signature used by :class:`DedupHITLToolCallsMiddleware`
            to drop duplicate calls within a single LLM response.
        reverse: Optional callable that, given the tool's ``(args, result)``,
            returns a ``ReverseDescriptor`` describing the inverse tool
            invocation. Consumed by the snapshot/revert pipeline.
    """

    name: str
    description: str
    factory: Callable[[dict[str, Any]], BaseTool]
    requires: list[str] = field(default_factory=list)
    enabled_by_default: bool = True
    hidden: bool = False
    required_connector: str | None = None
    dedup_key: Callable[[dict[str, Any]], str] | None = None
    reverse: Callable[[dict[str, Any], Any], dict[str, Any]] | None = None
# =============================================================================
# Deferred-import factories
# =============================================================================
# Used for tools whose impls live under ``multi_agent_chat``. Importing those
# at module-load time would cycle (``multi_agent_chat`` middleware imports
# this registry). The import inside the factory runs only when
# ``build_tools`` is called, by which point ``multi_agent_chat`` is fully
# initialised.
def _build_create_automation_tool(deps: dict[str, Any]) -> BaseTool:
    """Build the ``create_automation`` tool from the dependency dict.

    The implementation is imported inside the function rather than at module
    load time, because importing ``multi_agent_chat`` while this registry is
    still loading would create an import cycle.

    Args:
        deps: Dependency mapping; must contain ``search_space_id``,
            ``user_id`` and ``llm``.

    Returns:
        The tool produced by ``create_create_automation_tool``.

    Raises:
        KeyError: If one of the required dependency keys is missing.
    """
    from app.agents.multi_agent_chat.main_agent.tools.automation import (
        create_create_automation_tool as build_tool,
    )

    space_id = deps["search_space_id"]
    owner_id = deps["user_id"]
    model = deps["llm"]
    return build_tool(search_space_id=space_id, user_id=owner_id, llm=model)
# =============================================================================
# Built-in Tools Registry
# =============================================================================
# Registry of all built-in tools
# Contributors: Add your new tools here!
BUILTIN_TOOLS: list[ToolDefinition] = [
    # Each entry pairs a tool name with a factory that receives the shared
    # dependency dict. ``build_tools`` iterates this list in order, so the
    # order here is the order tools are handed to the agent.
    #
    # Podcast generation tool
    ToolDefinition(
        name="generate_podcast",
        description="Generate an audio podcast from provided content",
        factory=lambda deps: create_generate_podcast_tool(
            search_space_id=deps["search_space_id"],
            db_session=deps["db_session"],
            thread_id=deps["thread_id"],
        ),
        requires=["search_space_id", "db_session", "thread_id"],
    ),
    # Video presentation generation tool
    ToolDefinition(
        name="generate_video_presentation",
        description="Generate a video presentation with slides and narration from provided content",
        factory=lambda deps: create_generate_video_presentation_tool(
            search_space_id=deps["search_space_id"],
            db_session=deps["db_session"],
            thread_id=deps["thread_id"],
        ),
        requires=["search_space_id", "db_session", "thread_id"],
    ),
    # Report generation tool (inline, short-lived sessions for DB ops)
    # Supports internal KB search via source_strategy so the agent does not
    # need a separate search step before generating.
    ToolDefinition(
        name="generate_report",
        description="Generate a structured report from provided content and export it",
        factory=lambda deps: create_generate_report_tool(
            search_space_id=deps["search_space_id"],
            thread_id=deps["thread_id"],
            connector_service=deps.get("connector_service"),
            available_connectors=deps.get("available_connectors"),
            available_document_types=deps.get("available_document_types"),
        ),
        requires=["search_space_id", "thread_id"],
        # connector_service, available_connectors, and available_document_types
        # are optional — when missing, source_strategy="kb_search" degrades
        # gracefully to "provided"
    ),
    # Resume generation tool (Typst-based, uses rendercv package)
    ToolDefinition(
        name="generate_resume",
        description="Generate a professional resume as a Typst document",
        factory=lambda deps: create_generate_resume_tool(
            search_space_id=deps["search_space_id"],
            thread_id=deps["thread_id"],
        ),
        requires=["search_space_id", "thread_id"],
    ),
    # Generate image tool - creates images using AI models (DALL-E, GPT Image, etc.)
    ToolDefinition(
        name="generate_image",
        description="Generate images from text descriptions using AI image models",
        factory=lambda deps: create_generate_image_tool(
            search_space_id=deps["search_space_id"],
            db_session=deps["db_session"],
        ),
        requires=["search_space_id", "db_session"],
    ),
    # Web scraping tool - extracts content from webpages
    ToolDefinition(
        name="scrape_webpage",
        description="Scrape and extract the main content from a webpage",
        factory=lambda deps: create_scrape_webpage_tool(
            firecrawl_api_key=deps.get("firecrawl_api_key"),
        ),
        requires=[],  # firecrawl_api_key is optional (read with deps.get)
    ),
    # Web search tool — real-time web search via SearXNG + user-configured engines
    ToolDefinition(
        name="web_search",
        description="Search the web for real-time information using configured search engines",
        factory=lambda deps: create_web_search_tool(
            search_space_id=deps.get("search_space_id"),
            available_connectors=deps.get("available_connectors"),
        ),
        # Both values are read with deps.get, so neither key is mandatory.
        requires=[],
    ),
    # =========================================================================
    # SERVICE ACCOUNT DISCOVERY
    # Generic tool for the LLM to discover connected accounts and resolve
    # service-specific identifiers (e.g. Jira cloudId, Slack team, etc.)
    # =========================================================================
    ToolDefinition(
        name="get_connected_accounts",
        description="Discover connected accounts for a service and their metadata",
        factory=lambda deps: create_get_connected_accounts_tool(
            db_session=deps["db_session"],
            search_space_id=deps["search_space_id"],
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
    ),
    # =========================================================================
    # AUTOMATION AUTHORING - single HITL tool. The tool takes an NL ``intent``
    # from the main agent, drafts the full AutomationCreate JSON via a focused
    # sub-LLM, surfaces it on an approval card, and persists on approval. The
    # factory defers its import because the impl lives under ``multi_agent_chat``
    # and that package transitively pulls this registry via middleware;
    # deferring to ``build_tools`` call-time breaks the cycle without a
    # parallel registry.
    # =========================================================================
    ToolDefinition(
        name="create_automation",
        description="Draft an automation from an NL intent; user approves the card; tool saves",
        factory=_build_create_automation_tool,
        requires=["search_space_id", "user_id", "llm"],
    ),
    # =========================================================================
    # MEMORY TOOL - single update_memory, private or team by thread_visibility
    # =========================================================================
    ToolDefinition(
        name="update_memory",
        description="Save important long-term facts, preferences, and instructions to the (personal or team) memory",
        # Team memory tool when the thread is visible to the whole search
        # space; otherwise the per-user memory tool.
        factory=lambda deps: (
            create_update_team_memory_tool(
                search_space_id=deps["search_space_id"],
                db_session=deps["db_session"],
                llm=deps.get("llm"),
            )
            if deps["thread_visibility"] == ChatVisibility.SEARCH_SPACE
            else create_update_memory_tool(
                user_id=deps["user_id"],
                db_session=deps["db_session"],
                llm=deps.get("llm"),
            )
        ),
        # NOTE(review): "llm" is listed as required although the factory reads
        # it with deps.get — the key must exist, but its value may be None.
        requires=[
            "user_id",
            "search_space_id",
            "db_session",
            "thread_visibility",
            "llm",
        ],
    ),
    # =========================================================================
    # NOTION TOOLS - create, update, delete pages
    # Auto-disabled when no Notion connector is configured (see chat_deepagent.py)
    # =========================================================================
    ToolDefinition(
        name="create_notion_page",
        description="Create a new page in the user's Notion workspace",
        factory=lambda deps: create_create_notion_page_tool(
            db_session=deps["db_session"],
            search_space_id=deps["search_space_id"],
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
        required_connector="NOTION_CONNECTOR",
        dedup_key=wrap_dedup_key_by_arg_name("title"),
    ),
    ToolDefinition(
        name="update_notion_page",
        description="Append new content to an existing Notion page",
        factory=lambda deps: create_update_notion_page_tool(
            db_session=deps["db_session"],
            search_space_id=deps["search_space_id"],
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
        required_connector="NOTION_CONNECTOR",
        dedup_key=wrap_dedup_key_by_arg_name("page_title"),
    ),
    ToolDefinition(
        name="delete_notion_page",
        description="Delete an existing Notion page",
        factory=lambda deps: create_delete_notion_page_tool(
            db_session=deps["db_session"],
            search_space_id=deps["search_space_id"],
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
        required_connector="NOTION_CONNECTOR",
        dedup_key=wrap_dedup_key_by_arg_name("page_title"),
    ),
    # =========================================================================
    # GOOGLE DRIVE TOOLS - create files, move files to trash
    # Auto-disabled when no Google Drive connector is configured (see chat_deepagent.py)
    # =========================================================================
    ToolDefinition(
        name="create_google_drive_file",
        description="Create a new Google Doc or Google Sheet in Google Drive",
        factory=lambda deps: create_create_google_drive_file_tool(
            db_session=deps["db_session"],
            search_space_id=deps["search_space_id"],
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
        required_connector="GOOGLE_DRIVE_FILE",
        dedup_key=wrap_dedup_key_by_arg_name("file_name"),
    ),
    ToolDefinition(
        name="delete_google_drive_file",
        description="Move an indexed Google Drive file to trash",
        factory=lambda deps: create_delete_google_drive_file_tool(
            db_session=deps["db_session"],
            search_space_id=deps["search_space_id"],
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
        required_connector="GOOGLE_DRIVE_FILE",
        dedup_key=wrap_dedup_key_by_arg_name("file_name"),
    ),
    # =========================================================================
    # DROPBOX TOOLS - create and delete files
    # Auto-disabled when no Dropbox connector is configured (see chat_deepagent.py)
    # =========================================================================
    ToolDefinition(
        name="create_dropbox_file",
        description="Create a new file in Dropbox",
        factory=lambda deps: create_create_dropbox_file_tool(
            db_session=deps["db_session"],
            search_space_id=deps["search_space_id"],
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
        required_connector="DROPBOX_FILE",
        dedup_key=wrap_dedup_key_by_arg_name("file_name"),
    ),
    ToolDefinition(
        name="delete_dropbox_file",
        description="Delete a file from Dropbox",
        factory=lambda deps: create_delete_dropbox_file_tool(
            db_session=deps["db_session"],
            search_space_id=deps["search_space_id"],
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
        required_connector="DROPBOX_FILE",
        dedup_key=wrap_dedup_key_by_arg_name("file_name"),
    ),
    # =========================================================================
    # ONEDRIVE TOOLS - create files, move files to the recycle bin
    # Auto-disabled when no OneDrive connector is configured (see chat_deepagent.py)
    # =========================================================================
    ToolDefinition(
        name="create_onedrive_file",
        description="Create a new file in Microsoft OneDrive",
        factory=lambda deps: create_create_onedrive_file_tool(
            db_session=deps["db_session"],
            search_space_id=deps["search_space_id"],
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
        required_connector="ONEDRIVE_FILE",
        dedup_key=wrap_dedup_key_by_arg_name("file_name"),
    ),
    ToolDefinition(
        name="delete_onedrive_file",
        description="Move a OneDrive file to the recycle bin",
        factory=lambda deps: create_delete_onedrive_file_tool(
            db_session=deps["db_session"],
            search_space_id=deps["search_space_id"],
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
        required_connector="ONEDRIVE_FILE",
        dedup_key=wrap_dedup_key_by_arg_name("file_name"),
    ),
    # =========================================================================
    # GOOGLE CALENDAR TOOLS - search, create, update, delete events
    # Auto-disabled when no Google Calendar connector is configured
    # =========================================================================
    ToolDefinition(
        name="search_calendar_events",
        description="Search Google Calendar events within a date range",
        factory=lambda deps: create_search_calendar_events_tool(
            db_session=deps["db_session"],
            search_space_id=deps["search_space_id"],
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
        required_connector="GOOGLE_CALENDAR_CONNECTOR",
    ),
    ToolDefinition(
        name="create_calendar_event",
        description="Create a new event on Google Calendar",
        factory=lambda deps: create_create_calendar_event_tool(
            db_session=deps["db_session"],
            search_space_id=deps["search_space_id"],
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
        required_connector="GOOGLE_CALENDAR_CONNECTOR",
        dedup_key=wrap_dedup_key_by_arg_name("title"),
    ),
    ToolDefinition(
        name="update_calendar_event",
        description="Update an existing indexed Google Calendar event",
        factory=lambda deps: create_update_calendar_event_tool(
            db_session=deps["db_session"],
            search_space_id=deps["search_space_id"],
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
        required_connector="GOOGLE_CALENDAR_CONNECTOR",
        dedup_key=wrap_dedup_key_by_arg_name("event_title_or_id"),
    ),
    ToolDefinition(
        name="delete_calendar_event",
        description="Delete an existing indexed Google Calendar event",
        factory=lambda deps: create_delete_calendar_event_tool(
            db_session=deps["db_session"],
            search_space_id=deps["search_space_id"],
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
        required_connector="GOOGLE_CALENDAR_CONNECTOR",
        dedup_key=wrap_dedup_key_by_arg_name("event_title_or_id"),
    ),
    # =========================================================================
    # GMAIL TOOLS - search, read, create drafts, update drafts, send, trash
    # Auto-disabled when no Gmail connector is configured
    # =========================================================================
    ToolDefinition(
        name="search_gmail",
        description="Search emails in Gmail using Gmail search syntax",
        factory=lambda deps: create_search_gmail_tool(
            db_session=deps["db_session"],
            search_space_id=deps["search_space_id"],
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
        required_connector="GOOGLE_GMAIL_CONNECTOR",
    ),
    ToolDefinition(
        name="read_gmail_email",
        description="Read the full content of a specific Gmail email",
        factory=lambda deps: create_read_gmail_email_tool(
            db_session=deps["db_session"],
            search_space_id=deps["search_space_id"],
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
        required_connector="GOOGLE_GMAIL_CONNECTOR",
    ),
    ToolDefinition(
        name="create_gmail_draft",
        description="Create a draft email in Gmail",
        factory=lambda deps: create_create_gmail_draft_tool(
            db_session=deps["db_session"],
            search_space_id=deps["search_space_id"],
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
        required_connector="GOOGLE_GMAIL_CONNECTOR",
        dedup_key=wrap_dedup_key_by_arg_name("subject"),
    ),
    ToolDefinition(
        name="send_gmail_email",
        description="Send an email via Gmail",
        factory=lambda deps: create_send_gmail_email_tool(
            db_session=deps["db_session"],
            search_space_id=deps["search_space_id"],
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
        required_connector="GOOGLE_GMAIL_CONNECTOR",
        dedup_key=wrap_dedup_key_by_arg_name("subject"),
    ),
    ToolDefinition(
        name="trash_gmail_email",
        description="Move an indexed email to trash in Gmail",
        factory=lambda deps: create_trash_gmail_email_tool(
            db_session=deps["db_session"],
            search_space_id=deps["search_space_id"],
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
        required_connector="GOOGLE_GMAIL_CONNECTOR",
        dedup_key=wrap_dedup_key_by_arg_name("email_subject_or_id"),
    ),
    ToolDefinition(
        name="update_gmail_draft",
        description="Update an existing Gmail draft",
        factory=lambda deps: create_update_gmail_draft_tool(
            db_session=deps["db_session"],
            search_space_id=deps["search_space_id"],
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
        required_connector="GOOGLE_GMAIL_CONNECTOR",
        dedup_key=wrap_dedup_key_by_arg_name("draft_subject_or_id"),
    ),
    # =========================================================================
    # CONFLUENCE TOOLS - create, update, delete pages
    # Auto-disabled when no Confluence connector is configured (see chat_deepagent.py)
    # =========================================================================
    ToolDefinition(
        name="create_confluence_page",
        description="Create a new page in the user's Confluence space",
        factory=lambda deps: create_create_confluence_page_tool(
            db_session=deps["db_session"],
            search_space_id=deps["search_space_id"],
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
        required_connector="CONFLUENCE_CONNECTOR",
        dedup_key=wrap_dedup_key_by_arg_name("title"),
    ),
    ToolDefinition(
        name="update_confluence_page",
        description="Update an existing indexed Confluence page",
        factory=lambda deps: create_update_confluence_page_tool(
            db_session=deps["db_session"],
            search_space_id=deps["search_space_id"],
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
        required_connector="CONFLUENCE_CONNECTOR",
        dedup_key=wrap_dedup_key_by_arg_name("page_title_or_id"),
    ),
    ToolDefinition(
        name="delete_confluence_page",
        description="Delete an existing indexed Confluence page",
        factory=lambda deps: create_delete_confluence_page_tool(
            db_session=deps["db_session"],
            search_space_id=deps["search_space_id"],
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
        required_connector="CONFLUENCE_CONNECTOR",
        dedup_key=wrap_dedup_key_by_arg_name("page_title_or_id"),
    ),
    # =========================================================================
    # DISCORD TOOLS - list channels, read messages, send messages
    # Auto-disabled when no Discord connector is configured
    # NOTE(review): send_discord_message declares no dedup_key, unlike the
    # write tools in the sections above — confirm this is intentional.
    # =========================================================================
    ToolDefinition(
        name="list_discord_channels",
        description="List text channels in the connected Discord server",
        factory=lambda deps: create_list_discord_channels_tool(
            db_session=deps["db_session"],
            search_space_id=deps["search_space_id"],
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
        required_connector="DISCORD_CONNECTOR",
    ),
    ToolDefinition(
        name="read_discord_messages",
        description="Read recent messages from a Discord text channel",
        factory=lambda deps: create_read_discord_messages_tool(
            db_session=deps["db_session"],
            search_space_id=deps["search_space_id"],
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
        required_connector="DISCORD_CONNECTOR",
    ),
    ToolDefinition(
        name="send_discord_message",
        description="Send a message to a Discord text channel",
        factory=lambda deps: create_send_discord_message_tool(
            db_session=deps["db_session"],
            search_space_id=deps["search_space_id"],
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
        required_connector="DISCORD_CONNECTOR",
    ),
    # =========================================================================
    # TEAMS TOOLS - list channels, read messages, send messages
    # Auto-disabled when no Teams connector is configured
    # NOTE(review): send_teams_message declares no dedup_key — confirm this
    # is intentional.
    # =========================================================================
    ToolDefinition(
        name="list_teams_channels",
        description="List Microsoft Teams and their channels",
        factory=lambda deps: create_list_teams_channels_tool(
            db_session=deps["db_session"],
            search_space_id=deps["search_space_id"],
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
        required_connector="TEAMS_CONNECTOR",
    ),
    ToolDefinition(
        name="read_teams_messages",
        description="Read recent messages from a Microsoft Teams channel",
        factory=lambda deps: create_read_teams_messages_tool(
            db_session=deps["db_session"],
            search_space_id=deps["search_space_id"],
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
        required_connector="TEAMS_CONNECTOR",
    ),
    ToolDefinition(
        name="send_teams_message",
        description="Send a message to a Microsoft Teams channel",
        factory=lambda deps: create_send_teams_message_tool(
            db_session=deps["db_session"],
            search_space_id=deps["search_space_id"],
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
        required_connector="TEAMS_CONNECTOR",
    ),
    # =========================================================================
    # LUMA TOOLS - list events, read event details, create events
    # Auto-disabled when no Luma connector is configured
    # NOTE(review): create_luma_event declares no dedup_key — confirm this
    # is intentional.
    # =========================================================================
    ToolDefinition(
        name="list_luma_events",
        description="List upcoming and recent Luma events",
        factory=lambda deps: create_list_luma_events_tool(
            db_session=deps["db_session"],
            search_space_id=deps["search_space_id"],
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
        required_connector="LUMA_CONNECTOR",
    ),
    ToolDefinition(
        name="read_luma_event",
        description="Read detailed information about a specific Luma event",
        factory=lambda deps: create_read_luma_event_tool(
            db_session=deps["db_session"],
            search_space_id=deps["search_space_id"],
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
        required_connector="LUMA_CONNECTOR",
    ),
    ToolDefinition(
        name="create_luma_event",
        description="Create a new event on Luma",
        factory=lambda deps: create_create_luma_event_tool(
            db_session=deps["db_session"],
            search_space_id=deps["search_space_id"],
            user_id=deps["user_id"],
        ),
        requires=["db_session", "search_space_id", "user_id"],
        required_connector="LUMA_CONNECTOR",
    ),
]
# =============================================================================
# Registry Functions
# =============================================================================
def get_tool_by_name(name: str) -> ToolDefinition | None:
    """Look up a registered tool definition by its exact name.

    Args:
        name: Tool name to search for.

    Returns:
        The first entry of ``BUILTIN_TOOLS`` whose ``name`` equals *name*,
        or ``None`` when no entry matches.
    """
    matching_entries = (entry for entry in BUILTIN_TOOLS if entry.name == name)
    return next(matching_entries, None)
def get_connector_gated_tools(
    available_connectors: list[str] | None,
) -> list[str]:
    """List the tools that must be disabled because their connector is absent.

    Args:
        available_connectors: Connector type strings that are available.
            ``None`` is treated the same as an empty list.

    Returns:
        Names of registry tools that declare a ``required_connector`` which is
        not in *available_connectors*, in registry order. Tools without a
        required connector are never returned.
    """
    if available_connectors is not None:
        connected = set(available_connectors)
    else:
        connected = set()
    return [
        entry.name
        for entry in BUILTIN_TOOLS
        if entry.required_connector and entry.required_connector not in connected
    ]
def get_all_tool_names() -> list[str]:
    """Return the name of every entry in ``BUILTIN_TOOLS``, in registry order.

    No filtering is applied: entries are included regardless of their
    ``hidden``, ``enabled_by_default`` or ``required_connector`` settings.
    """
    return [entry.name for entry in BUILTIN_TOOLS]
def get_default_enabled_tools() -> list[str]:
    """Get names of tools that are enabled by default (excludes hidden tools).

    Returns:
        Names of the ``BUILTIN_TOOLS`` entries whose ``enabled_by_default`` is
        true and whose ``hidden`` flag is false, in registry order.
    """
    # Hidden entries are filtered here as well as in ``build_tools`` so that
    # callers listing the defaults see exactly the set that can be built.
    return [
        tool_def.name
        for tool_def in BUILTIN_TOOLS
        if tool_def.enabled_by_default and not tool_def.hidden
    ]
def build_tools(
    dependencies: dict[str, Any],
    enabled_tools: list[str] | None = None,
    disabled_tools: list[str] | None = None,
    additional_tools: list[BaseTool] | None = None,
) -> list[BaseTool]:
    """Instantiate the selected registry tools for an agent.

    Args:
        dependencies: Dict handed to every tool factory. Each selected tool's
            ``requires`` keys must be present, for example:
            - search_space_id: The search space ID
            - db_session: Database session
            - connector_service: Connector service instance
            - firecrawl_api_key: Optional Firecrawl API key
        enabled_tools: Explicit list of tool names to enable. When ``None``
            the default-enabled set is used.
        disabled_tools: Tool names removed from the selection after
            ``enabled_tools`` has been applied.
        additional_tools: Ready-made tools appended after the registry tools
            (e.g. custom tools not in the registry).

    Returns:
        Tool instances in registry order, followed by ``additional_tools``.

    Raises:
        ValueError: A selected tool lists a dependency key that is missing
            from ``dependencies``.

    Example:
        # Use all default tools
        tools = build_tools(deps)
        # Use only specific tools
        tools = build_tools(deps, enabled_tools=["generate_report"])
        # Use defaults but disable podcast
        tools = build_tools(deps, disabled_tools=["generate_podcast"])
        # Add custom tools
        tools = build_tools(deps, additional_tools=[my_custom_tool])
    """
    # Work out which tool names are selected.
    selected_names = set(
        get_default_enabled_tools() if enabled_tools is None else enabled_tools
    )
    if disabled_tools:
        selected_names.difference_update(set(disabled_tools))

    built: list[BaseTool] = []
    for entry in BUILTIN_TOOLS:
        # Hidden/WIP tools are skipped even when explicitly enabled.
        if entry.hidden or entry.name not in selected_names:
            continue

        # Fail before calling the factory if a dependency key is absent.
        missing_keys = [key for key in entry.requires if key not in dependencies]
        if missing_keys:
            msg = f"Tool '{entry.name}' requires dependencies: {missing_keys}"
            raise ValueError(msg)

        instance = entry.factory(dependencies)

        # Copy the registry-level resolvers onto ``tool.metadata`` so that
        # middleware (e.g. ``DedupHITLToolCallsMiddleware``) and the
        # action-log/revert pipeline can find them without re-importing
        # :data:`BUILTIN_TOOLS`. Existing metadata keys win (``setdefault``).
        resolvers = {
            key: resolver
            for key, resolver in (
                ("dedup_key", entry.dedup_key),
                ("reverse", entry.reverse),
            )
            if resolver is not None
        }
        if resolvers:
            combined_meta = dict(getattr(instance, "metadata", None) or {})
            for key, resolver in resolvers.items():
                combined_meta.setdefault(key, resolver)
            try:
                instance.metadata = combined_meta
            except Exception:
                logger.debug(
                    "Tool %s rejected metadata mutation; relying on registry lookup",
                    entry.name,
                )

        built.append(instance)

    # Caller-supplied tools go last.
    if additional_tools:
        built.extend(additional_tools)
    return built
async def build_tools_async(
    dependencies: dict[str, Any],
    enabled_tools: list[str] | None = None,
    disabled_tools: list[str] | None = None,
    additional_tools: list[BaseTool] | None = None,
    include_mcp_tools: bool = True,
) -> list[BaseTool]:
    """Async version of build_tools that also loads MCP tools from database.

    Design Note:
        This function exists because MCP tools require database queries to load
        user configs, while built-in tools are created synchronously from static
        code.

        Alternative: We could make build_tools() itself async and always query
        the database, but that would force async everywhere even when only using
        built-in tools. The current design keeps the simple case (static tools
        only) synchronous while supporting dynamic database-loaded tools through
        this async wrapper.

    Phase 1.3: built-in tool construction (CPU; runs in a thread pool to
    avoid event-loop stalls) and MCP tool loading (HTTP/DB I/O; runs on
    the event loop) are kicked off concurrently. Cold-path savings are
    bounded by the slower of the two (typically MCP at ~200ms-1.7s), so
    the parallelization recovers the ~50-200ms previously spent serially
    on built-in construction.

    Args:
        dependencies: Dict containing all possible dependencies
        enabled_tools: Explicit list of tool names to enable. If None, uses defaults.
        disabled_tools: List of tool names to disable (applied after enabled_tools).
        additional_tools: Extra tools to add (e.g., custom tools not in registry).
        include_mcp_tools: Whether to load user's MCP tools from database. MCP
            loading is also skipped when ``dependencies`` lacks ``db_session``
            or ``search_space_id``.

    Returns:
        List of configured tool instances ready for the agent, including MCP tools.

    Raises:
        BaseException: Whatever ``build_tools`` raised; a built-in
            construction failure is re-raised unchanged. MCP loading failures
            are logged and otherwise ignored.
    """
    import asyncio
    import time

    _perf_log = logging.getLogger("surfsense.perf")
    _perf_log.setLevel(logging.DEBUG)

    can_load_mcp = (
        include_mcp_tools
        and "db_session" in dependencies
        and "search_space_id" in dependencies
    )

    # Built-in tool construction is synchronous + CPU-only. Off-loop it so
    # MCP's HTTP/DB I/O can fire concurrently. ``build_tools`` is pure
    # function over its inputs — safe to thread-shift.
    _t0 = time.perf_counter()
    builtin_task = asyncio.create_task(
        asyncio.to_thread(
            build_tools, dependencies, enabled_tools, disabled_tools, additional_tools
        )
    )

    mcp_task: asyncio.Task | None = None
    if can_load_mcp:
        mcp_task = asyncio.create_task(
            load_mcp_tools(
                dependencies["db_session"],
                dependencies["search_space_id"],
            )
        )

    # Surface failures from each task independently so a flaky MCP
    # endpoint never poisons built-in tool registration. ``return_exceptions``
    # gives us per-task exceptions instead of dropping the second result
    # when the first raises.
    if mcp_task is not None:
        builtin_result, mcp_result = await asyncio.gather(
            builtin_task, mcp_task, return_exceptions=True
        )
    else:
        builtin_result = await builtin_task
        mcp_result = None

    if isinstance(builtin_result, BaseException):
        raise builtin_result  # built-in registration failure is non-recoverable

    tools: list[BaseTool] = builtin_result
    # The clock is read after both tasks have been awaited, so when MCP
    # loading ran this figure is the elapsed time of the whole gather.
    _perf_log.info(
        "[build_tools_async] Built-in tools in %.3fs (%d tools, parallel)",
        time.perf_counter() - _t0,
        len(tools),
    )

    if mcp_task is not None:
        if isinstance(mcp_result, BaseException):
            # ``return_exceptions=True`` captures the exception out-of-band,
            # so ``sys.exc_info()`` is empty here. Pass the captured
            # exception via ``exc_info=`` to get a real traceback.
            logger.error(
                "Failed to load MCP tools: %s", mcp_result, exc_info=mcp_result
            )
        else:
            mcp_tools = mcp_result or []
            _perf_log.info(
                "[build_tools_async] MCP tools loaded in %.3fs (%d tools, parallel)",
                time.perf_counter() - _t0,
                len(mcp_tools),
            )
            tools.extend(mcp_tools)
            logger.info(
                "Registered %d MCP tools: %s",
                len(mcp_tools),
                [t.name for t in mcp_tools],
            )

    # Separator added: the previous "%d%s" format glued the count directly
    # onto the list of names.
    logger.info(
        "Total tools for agent: %d - %s",
        len(tools),
        [t.name for t in tools],
    )
    return tools

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,812 @@
"""
Resume generation tool for the SurfSense agent.
Generates a structured resume as Typst source code using the rendercv package.
The LLM outputs only the content body (= heading, sections, entries) while
the template header (import + show rule) is hardcoded and prepended by the
backend. This eliminates LLM errors in the complex configuration block.
Templates are stored in a registry so new designs can be added by defining
a new entry in _TEMPLATES.
Uses the same short-lived session pattern as generate_report so no DB
connection is held during the long LLM call.
"""
import io
import logging
import re
from datetime import UTC, datetime
from typing import Any
import pypdf
import typst
from langchain_core.callbacks import dispatch_custom_event
from langchain_core.messages import HumanMessage
from langchain_core.tools import tool
from app.db import Report, shielded_async_session
from app.services.llm_service import get_document_summary_llm
logger = logging.getLogger(__name__)
# ─── Template Registry ───────────────────────────────────────────────────────
# Each template defines:
#   header             - Typst import + show rule with {name}, {month_name}, {year}, {month}, {day} placeholders
# component_reference - component docs shown to the LLM
# rules - generation rules for the LLM
_TEMPLATES: dict[str, dict[str, str]] = {
"classic": {
"header": """\
#import "@preview/rendercv:0.3.0": *
#show: rendercv.with(
name: "{name}",
title: "{name} - Resume",
footer: context {{ [#emph[{name} -- #str(here().page())\\/#str(counter(page).final().first())]] }},
top-note: [ #emph[Last updated in {month_name} {year}] ],
locale-catalog-language: "en",
text-direction: ltr,
page-size: "us-letter",
page-top-margin: 0.7in,
page-bottom-margin: 0.7in,
page-left-margin: 0.7in,
page-right-margin: 0.7in,
page-show-footer: false,
page-show-top-note: true,
colors-body: rgb(0, 0, 0),
colors-name: rgb(0, 0, 0),
colors-headline: rgb(0, 0, 0),
colors-connections: rgb(0, 0, 0),
colors-section-titles: rgb(0, 0, 0),
colors-links: rgb(0, 0, 0),
colors-footer: rgb(128, 128, 128),
colors-top-note: rgb(128, 128, 128),
typography-line-spacing: 0.6em,
typography-alignment: "justified",
typography-date-and-location-column-alignment: right,
typography-font-family-body: "XCharter",
typography-font-family-name: "XCharter",
typography-font-family-headline: "XCharter",
typography-font-family-connections: "XCharter",
typography-font-family-section-titles: "XCharter",
typography-font-size-body: 10pt,
typography-font-size-name: 25pt,
typography-font-size-headline: 10pt,
typography-font-size-connections: 10pt,
typography-font-size-section-titles: 1.2em,
typography-small-caps-name: false,
typography-small-caps-headline: false,
typography-small-caps-connections: false,
typography-small-caps-section-titles: false,
typography-bold-name: false,
typography-bold-headline: false,
typography-bold-connections: false,
typography-bold-section-titles: true,
links-underline: true,
links-show-external-link-icon: false,
header-alignment: center,
header-photo-width: 3.5cm,
header-space-below-name: 0.7cm,
header-space-below-headline: 0.7cm,
header-space-below-connections: 0.7cm,
header-connections-hyperlink: true,
header-connections-show-icons: false,
header-connections-display-urls-instead-of-usernames: true,
header-connections-separator: "|",
header-connections-space-between-connections: 0.5cm,
section-titles-type: "with_full_line",
section-titles-line-thickness: 0.5pt,
section-titles-space-above: 0.5cm,
section-titles-space-below: 0.3cm,
sections-allow-page-break: true,
sections-space-between-text-based-entries: 0.15cm,
sections-space-between-regular-entries: 0.42cm,
entries-date-and-location-width: 4.15cm,
entries-side-space: 0cm,
entries-space-between-columns: 0.1cm,
entries-allow-page-break: false,
entries-short-second-row: false,
entries-degree-width: 1cm,
entries-summary-space-left: 0cm,
entries-summary-space-above: 0.08cm,
entries-highlights-bullet: text(13pt, [\\u{2022}], baseline: -0.6pt),
entries-highlights-nested-bullet: text(13pt, [\\u{2022}], baseline: -0.6pt),
entries-highlights-space-left: 0cm,
entries-highlights-space-above: 0.08cm,
entries-highlights-space-between-items: 0.02cm,
entries-highlights-space-between-bullet-and-text: 0.3em,
date: datetime(
year: {year},
month: {month},
day: {day},
),
)
""",
"component_reference": """\
Available components (use ONLY these):
= Full Name // Top-level heading person's full name
#connections( // Contact info row (pipe-separated)
[City, Country],
[#link("mailto:email@example.com", icon: false, if-underline: false, if-color: false)[email\\@example.com]],
[#link("https://linkedin.com/in/user", icon: false, if-underline: false, if-color: false)[linkedin.com\\/in\\/user]],
[#link("https://github.com/user", icon: false, if-underline: false, if-color: false)[github.com\\/user]],
)
== Section Title // Section heading (arbitrary name)
#regular-entry( // Work experience, projects, publications, etc.
[
#strong[Role/Title], Company Name -- Location
],
[
Start -- End
],
main-column-second-row: [
- Achievement or responsibility
- Another bullet point
],
)
#education-entry( // Education entries
[
#strong[Institution], Degree in Field -- Location
],
[
Start -- End
],
main-column-second-row: [
- GPA, honours, relevant coursework
],
)
#summary([Short paragraph summary]) // Optional summary inside an entry
#content-area([Free-form content]) // Freeform text block
For skills sections, use one bullet per category label:
- #strong[Category:] item1, item2, item3
For simple list sections (e.g. Honors), use plain bullet points:
- Item one
- Item two
""",
"rules": """\
RULES:
- Do NOT include any #import or #show lines. Start directly with = Full Name.
- Output ONLY valid Typst content. No explanatory text before or after.
- Do NOT wrap output in ```typst code fences.
- The = heading MUST use the person's COMPLETE full name exactly as provided. NEVER shorten or abbreviate.
- Escape @ symbols inside link labels with a backslash: email\\@example.com
- Escape forward slashes in link display text: linkedin.com\\/in\\/user
- Every section MUST use == heading.
- Use #regular-entry() for experience, projects, publications, certifications, and similar entries.
- Use #education-entry() for education.
- For skills sections, use one bullet line per category with a bold label.
- Keep content professional, concise, and achievement-oriented.
- Use action verbs for bullet points (Led, Built, Designed, Reduced, etc.).
- This template works for ALL professions adapt sections to the user's field.
- Default behavior should prioritize concise one-page content.
""",
},
}
# Key into _TEMPLATES used when no (or an unknown) template ID is requested.
DEFAULT_TEMPLATE = "classic"
# Inclusive bounds accepted for the ``max_pages`` tool argument; the upper
# bound is also the hard page limit checked after compression retries.
MIN_RESUME_PAGES = 1
MAX_RESUME_PAGES = 5
# Number of LLM "compress" rounds attempted when the compiled PDF is longer
# than the requested page target.
MAX_COMPRESSION_ATTEMPTS = 2
# ─── Template Helpers ─────────────────────────────────────────────────────────
def _get_template(template_id: str | None = None) -> dict[str, str]:
    """Return the template registered under ``template_id``.

    A missing, empty, or unknown ID resolves to the ``DEFAULT_TEMPLATE`` entry.
    """
    fallback = _TEMPLATES[DEFAULT_TEMPLATE]
    key = template_id if template_id else DEFAULT_TEMPLATE
    return _TEMPLATES.get(key, fallback)
# Abbreviated month names indexed by calendar month number (1-12).
# Index 0 is an empty placeholder so ``_MONTH_NAMES[now.month]`` works directly.
_MONTH_NAMES = [
    "",
    "Jan",
    "Feb",
    "Mar",
    "Apr",
    "May",
    "Jun",
    "Jul",
    "Aug",
    "Sep",
    "Oct",
    "Nov",
    "Dec",
]
def _build_header(template: dict[str, str], name: str) -> str:
    """Fill the template header's placeholders with ``name`` and today's UTC date.

    Placeholders are substituted one after another, in the order listed below,
    using plain string replacement on ``template["header"]``.
    """
    today = datetime.now(tz=UTC)
    substitutions = (
        ("{name}", name),
        ("{year}", str(today.year)),
        ("{month}", str(today.month)),
        ("{day}", str(today.day)),
        ("{month_name}", _MONTH_NAMES[today.month]),
    )
    header = template["header"]
    for placeholder, value in substitutions:
        header = header.replace(placeholder, value)
    return header
def _strip_header(full_source: str) -> str:
"""Strip the import + show rule from stored source to get the body only.
Finds the closing parenthesis of the rendercv.with(...) block by tracking
nesting depth, then returns everything after it.
"""
show_match = re.search(r"#show:\s*rendercv\.with\(", full_source)
if not show_match:
return full_source
start = show_match.end()
depth = 1
i = start
while i < len(full_source) and depth > 0:
if full_source[i] == "(":
depth += 1
elif full_source[i] == ")":
depth -= 1
i += 1
return full_source[i:].lstrip("\n")
def _extract_name(body: str) -> str | None:
"""Extract the person's full name from the = heading in the body."""
match = re.search(r"^=\s+(.+)$", body, re.MULTILINE)
return match.group(1).strip() if match else None
def _strip_imports(body: str) -> str:
"""Remove any #import or #show lines the LLM might accidentally include."""
lines = body.split("\n")
cleaned: list[str] = []
skip_show = False
depth = 0
for line in lines:
stripped = line.strip()
if stripped.startswith("#import"):
continue
if skip_show:
depth += stripped.count("(") - stripped.count(")")
if depth <= 0:
skip_show = False
continue
if stripped.startswith("#show:") and "rendercv" in stripped:
depth = stripped.count("(") - stripped.count(")")
if depth > 0:
skip_show = True
continue
cleaned.append(line)
result = "\n".join(cleaned).strip()
return result
def _build_llm_reference(template: dict[str, str]) -> str:
"""Build the LLM prompt reference from a template."""
return f"""\
You MUST output valid Typst content for a resume.
Do NOT include any #import or #show lines — those are handled automatically.
Start directly with the = Full Name heading.
{template["component_reference"]}
{template["rules"]}"""
# ─── Prompts ─────────────────────────────────────────────────────────────────
# Prompt for generating a brand-new resume body from the user's raw details.
# Format fields: llm_reference, user_info, max_pages, user_instructions_section.
_RESUME_PROMPT = """\
You are an expert resume writer. Generate professional resume content as Typst markup.
{llm_reference}
**User Information:**
{user_info}
**Target Maximum Pages:** {max_pages}
{user_instructions_section}
Generate the resume content now (starting with = Full Name):
"""
# Prompt for revising an existing resume body (used when a parent report
# with content was found).
# Format fields: llm_reference, max_pages, user_instructions, previous_content.
_REVISION_PROMPT = """\
You are an expert resume editor. Modify the existing resume according to the instructions.
Apply ONLY the requested changes do NOT rewrite sections that are not affected.
{llm_reference}
**Target Maximum Pages:** {max_pages}
**Modification Instructions:** {user_instructions}
**EXISTING RESUME CONTENT:**
{previous_content}
---
Output the complete, updated resume content with the changes applied (starting with = Full Name):
"""
# Prompt sent after a failed Typst compile; the model sees the compiler error
# plus the full source (header + body) and must return only the body.
# Format fields: llm_reference, error, full_source.
_FIX_COMPILE_PROMPT = """\
The resume content you generated failed to compile. Fix the error while preserving all content.
{llm_reference}
**Compilation Error:**
{error}
**Full Typst Source (for context error line numbers refer to this):**
{full_source}
**Your content starts after the template header. Output ONLY the content portion \
(starting with = Full Name), NOT the #import or #show rule:**
"""
# Prompt sent when the compiled PDF has more pages than the requested target.
# Format fields: llm_reference, max_pages, actual_pages, attempt_number,
# previous_content.
_COMPRESS_TO_PAGE_LIMIT_PROMPT = """\
The resume compiles, but it exceeds the maximum allowed page count.
Compress the resume while preserving high-impact accomplishments and role relevance.
{llm_reference}
**Target Maximum Pages:** {max_pages}
**Current Page Count:** {actual_pages}
**Compression Attempt:** {attempt_number}
Compression priorities (in this order):
1) Keep recent, high-impact, role-relevant bullets.
2) Remove low-impact or redundant bullets.
3) Shorten verbose wording while preserving meaning.
4) Trim older or less relevant details before recent ones.
Return the complete updated Typst content (starting with = Full Name), and keep it at or below the target pages.
**EXISTING RESUME CONTENT:**
{previous_content}
"""
# ─── Helpers ─────────────────────────────────────────────────────────────────
def _strip_typst_fences(text: str) -> str:
"""Remove wrapping ```typst ... ``` fences that LLMs sometimes add."""
stripped = text.strip()
m = re.match(r"^(`{3,})(?:typst|typ)?\s*\n", stripped)
if m:
fence = m.group(1)
if stripped.endswith(fence):
stripped = stripped[m.end() :]
stripped = stripped[: -len(fence)].rstrip()
return stripped
def _compile_typst(source: str) -> bytes:
    """Compile Typst markup into PDF bytes.

    The source is UTF-8 encoded and handed to ``typst.compile``; any exception
    it raises propagates to the caller.
    """
    encoded_source = source.encode("utf-8")
    return typst.compile(encoded_source)
def _count_pdf_pages(pdf_bytes: bytes) -> int:
    """Return how many pages the given in-memory PDF contains."""
    with io.BytesIO(pdf_bytes) as buffer:
        return len(pypdf.PdfReader(buffer).pages)
def _validate_max_pages(max_pages: int) -> int:
    """Return ``max_pages`` unchanged when it is within the allowed range.

    Raises:
        ValueError: If the value is below ``MIN_RESUME_PAGES`` or above
            ``MAX_RESUME_PAGES``.
    """
    if max_pages < MIN_RESUME_PAGES or max_pages > MAX_RESUME_PAGES:
        raise ValueError(
            f"max_pages must be between {MIN_RESUME_PAGES} and "
            f"{MAX_RESUME_PAGES}. Received: {max_pages}"
        )
    return max_pages
# ─── Tool Factory ───────────────────────────────────────────────────────────
def create_generate_resume_tool(
    search_space_id: int,
    thread_id: int | None = None,
):
    """
    Factory function to create the generate_resume tool.

    Generates a Typst-based resume, validates it via compilation,
    and stores the source in the Report table with content_type='typst'.

    The LLM generates only the content body; the template header is
    prepended by the backend.

    Args:
        search_space_id: Stored on every Report row the tool writes and used
            to look up the LLM.
        thread_id: Stored on every Report row the tool writes.

    Returns:
        The ``generate_resume`` tool bound to the given IDs.
    """

    @tool
    async def generate_resume(
        user_info: str,
        user_instructions: str | None = None,
        parent_report_id: int | None = None,
        max_pages: int = 1,
    ) -> dict[str, Any]:
        """
        Generate a professional resume as a Typst document.

        Use this tool when the user asks to create, build, generate, write,
        or draft a resume or CV. Also use it when the user wants to modify,
        update, or revise an existing resume generated in this conversation.

        Trigger phrases include:
        - "build me a resume", "create my resume", "generate a CV"
        - "update my resume", "change my title", "add my new job"
        - "make my resume more concise", "reformat my resume"

        Do NOT use this tool for:
        - General questions about resumes or career advice
        - Reviewing or critiquing a resume without changes
        - Cover letters (use generate_report instead)

        VERSIONING parent_report_id:
        - Set parent_report_id when the user wants to MODIFY an existing
        resume that was already generated in this conversation.
        - Leave as None for new resumes.

        Args:
            user_info: The user's resume content — work experience,
                education, skills, contact info, etc. Can be structured
                or unstructured text.
            user_instructions: Optional style or content preferences
                (e.g. "emphasize leadership", "keep it to one page",
                "use a modern style"). For revisions, describe what to change.
            parent_report_id: ID of a previous resume to revise (creates
                new version in the same version group).
            max_pages: Maximum number of pages for the generated resume.
                Defaults to 1. Allowed range: 1-5.

        Returns:
            Dict with status, report_id, title, and content_type.
        """
        # Both values are filled in during Phase 1; report_group_id is also
        # read by the _save_failed_report closure below.
        report_group_id: int | None = None
        parent_content: str | None = None
        template = _get_template()
        llm_reference = _build_llm_reference(template)

        async def _save_failed_report(error_msg: str) -> int | None:
            # Persist a "failed" Report row so the failure is recorded.
            # Returns the new row's ID, or None if the row could not be
            # written; this helper never raises.
            try:
                async with shielded_async_session() as session:
                    failed = Report(
                        title="Resume",
                        content=None,
                        content_type="typst",
                        report_metadata={
                            "status": "failed",
                            "error_message": error_msg,
                        },
                        report_style="resume",
                        search_space_id=search_space_id,
                        thread_id=thread_id,
                        report_group_id=report_group_id,
                    )
                    session.add(failed)
                    await session.commit()
                    await session.refresh(failed)
                    # A row without a group becomes the root of its own group.
                    if not failed.report_group_id:
                        failed.report_group_id = failed.id
                        await session.commit()
                    logger.info(
                        f"[generate_resume] Saved failed report {failed.id}: {error_msg}"
                    )
                    return failed.id
            except Exception:
                logger.exception(
                    "[generate_resume] Could not persist failed report row"
                )
                return None

        try:
            # An out-of-range max_pages is reported as a failed result
            # instead of propagating the ValueError.
            try:
                validated_max_pages = _validate_max_pages(max_pages)
            except ValueError as e:
                error_msg = str(e)
                report_id = await _save_failed_report(error_msg)
                return {
                    "status": "failed",
                    "error": error_msg,
                    "report_id": report_id,
                    "title": "Resume",
                    "content_type": "typst",
                }
            # ── Phase 1: READ ─────────────────────────────────────────────
            async with shielded_async_session() as read_session:
                if parent_report_id:
                    # NOTE(review): the parent is fetched by primary key only;
                    # nothing here checks that it belongs to search_space_id.
                    # Confirm whether access is enforced elsewhere.
                    parent_report = await read_session.get(Report, parent_report_id)
                    # A missing parent is not an error: the tool falls through
                    # to generating a fresh resume.
                    if parent_report:
                        report_group_id = parent_report.report_group_id
                        parent_content = parent_report.content
                        logger.info(
                            f"[generate_resume] Revising from parent {parent_report_id} "
                            f"(group {report_group_id})"
                        )
                llm = await get_document_summary_llm(read_session, search_space_id)
            if not llm:
                error_msg = (
                    "No LLM configured. Please configure a language model in Settings."
                )
                report_id = await _save_failed_report(error_msg)
                return {
                    "status": "failed",
                    "error": error_msg,
                    "report_id": report_id,
                    "title": "Resume",
                    "content_type": "typst",
                }
            # ── Phase 2: LLM GENERATION ───────────────────────────────────
            user_instructions_section = ""
            if user_instructions:
                user_instructions_section = (
                    f"**Additional Instructions:** {user_instructions}"
                )
            if parent_content:
                # Revision path: the stored source includes the template
                # header, so strip it and hand the LLM only the body.
                dispatch_custom_event(
                    "report_progress",
                    {"phase": "writing", "message": "Updating your resume"},
                )
                parent_body = _strip_header(parent_content)
                prompt = _REVISION_PROMPT.format(
                    llm_reference=llm_reference,
                    max_pages=validated_max_pages,
                    user_instructions=user_instructions
                    or "Improve and refine the resume.",
                    previous_content=parent_body,
                )
            else:
                dispatch_custom_event(
                    "report_progress",
                    {"phase": "writing", "message": "Building your resume"},
                )
                prompt = _RESUME_PROMPT.format(
                    llm_reference=llm_reference,
                    user_info=user_info,
                    max_pages=validated_max_pages,
                    user_instructions_section=user_instructions_section,
                )
            response = await llm.ainvoke([HumanMessage(content=prompt)])
            body = response.content
            if not body or not isinstance(body, str):
                error_msg = "LLM returned empty or invalid content"
                report_id = await _save_failed_report(error_msg)
                return {
                    "status": "failed",
                    "error": error_msg,
                    "report_id": report_id,
                    "title": "Resume",
                    "content_type": "typst",
                }
            # Defensive cleanup: drop code fences and any header lines the
            # model emitted despite the prompt rules.
            body = _strip_typst_fences(body)
            body = _strip_imports(body)
            # ── Phase 3: ASSEMBLE + COMPILE ───────────────────────────────
            dispatch_custom_event(
                "report_progress",
                {"phase": "compiling", "message": "Compiling resume..."},
            )
            name = _extract_name(body) or "Resume"
            typst_source = ""
            actual_pages = 0
            compression_attempts = 0
            target_page_met = False
            # One initial round plus up to MAX_COMPRESSION_ATTEMPTS rounds on
            # compressed content. Each round compiles the current body.
            for compression_round in range(MAX_COMPRESSION_ATTEMPTS + 1):
                header = _build_header(template, name)
                typst_source = header + body
                compile_error: str | None = None
                pdf_bytes: bytes | None = None
                # Two compile attempts per round: after the first failure the
                # LLM is asked for a fix; a second failure ends the tool call.
                for compile_attempt in range(2):
                    try:
                        pdf_bytes = _compile_typst(typst_source)
                        compile_error = None
                        break
                    except Exception as e:
                        compile_error = str(e)
                        logger.warning(
                            "[generate_resume] Compile attempt %s failed: %s",
                            compile_attempt + 1,
                            compile_error,
                        )
                        if compile_attempt == 0:
                            dispatch_custom_event(
                                "report_progress",
                                {
                                    "phase": "fixing",
                                    "message": "Fixing compilation issue...",
                                },
                            )
                            fix_prompt = _FIX_COMPILE_PROMPT.format(
                                llm_reference=llm_reference,
                                error=compile_error,
                                full_source=typst_source,
                            )
                            fix_response = await llm.ainvoke(
                                [HumanMessage(content=fix_prompt)]
                            )
                            # If the fix response is unusable, the second
                            # attempt recompiles the unchanged source.
                            if fix_response.content and isinstance(
                                fix_response.content, str
                            ):
                                body = _strip_typst_fences(fix_response.content)
                                body = _strip_imports(body)
                                name = _extract_name(body) or name
                                header = _build_header(template, name)
                                typst_source = header + body
                if compile_error or not pdf_bytes:
                    error_msg = (
                        "Typst compilation failed after 2 attempts: "
                        f"{compile_error or 'Unknown compile error'}"
                    )
                    report_id = await _save_failed_report(error_msg)
                    return {
                        "status": "failed",
                        "error": error_msg,
                        "report_id": report_id,
                        "title": "Resume",
                        "content_type": "typst",
                    }
                actual_pages = _count_pdf_pages(pdf_bytes)
                if actual_pages <= validated_max_pages:
                    target_page_met = True
                    break
                # Final round: no further compression is requested, so
                # typst_source stays in sync with the body just compiled.
                if compression_round >= MAX_COMPRESSION_ATTEMPTS:
                    break
                compression_attempts += 1
                dispatch_custom_event(
                    "report_progress",
                    {
                        "phase": "compressing",
                        "message": f"Condensing resume to {validated_max_pages} page(s)...",
                    },
                )
                compress_prompt = _COMPRESS_TO_PAGE_LIMIT_PROMPT.format(
                    llm_reference=llm_reference,
                    max_pages=validated_max_pages,
                    actual_pages=actual_pages,
                    attempt_number=compression_attempts,
                    previous_content=body,
                )
                compress_response = await llm.ainvoke(
                    [HumanMessage(content=compress_prompt)]
                )
                if not compress_response.content or not isinstance(
                    compress_response.content, str
                ):
                    error_msg = "LLM returned empty content while compressing resume"
                    report_id = await _save_failed_report(error_msg)
                    return {
                        "status": "failed",
                        "error": error_msg,
                        "report_id": report_id,
                        "title": "Resume",
                        "content_type": "typst",
                    }
                body = _strip_typst_fences(compress_response.content)
                body = _strip_imports(body)
                name = _extract_name(body) or name
            # Missing the requested target is tolerated (the resume is saved
            # with target_page_met=False); exceeding the hard limit is not.
            if actual_pages > MAX_RESUME_PAGES:
                error_msg = (
                    "Resume exceeds hard page limit after compression retries. "
                    f"Hard limit: <= {MAX_RESUME_PAGES} page(s), actual: {actual_pages}."
                )
                report_id = await _save_failed_report(error_msg)
                return {
                    "status": "failed",
                    "error": error_msg,
                    "report_id": report_id,
                    "title": "Resume",
                    "content_type": "typst",
                }
            # ── Phase 4: SAVE ─────────────────────────────────────────────
            dispatch_custom_event(
                "report_progress",
                {"phase": "saving", "message": "Saving your resume"},
            )
            resume_title = f"{name} - Resume" if name != "Resume" else "Resume"
            # word_count / char_count are computed over the full source,
            # i.e. template header plus body.
            metadata: dict[str, Any] = {
                "status": "ready",
                "word_count": len(typst_source.split()),
                "char_count": len(typst_source),
                "target_max_pages": validated_max_pages,
                "actual_page_count": actual_pages,
                "page_limit_enforced": True,
                "compression_attempts": compression_attempts,
                "target_page_met": target_page_met,
            }
            async with shielded_async_session() as write_session:
                report = Report(
                    title=resume_title,
                    content=typst_source,
                    content_type="typst",
                    report_metadata=metadata,
                    report_style="resume",
                    search_space_id=search_space_id,
                    thread_id=thread_id,
                    report_group_id=report_group_id,
                )
                write_session.add(report)
                await write_session.commit()
                await write_session.refresh(report)
                # A brand-new resume (no parent group) starts its own group.
                if not report.report_group_id:
                    report.report_group_id = report.id
                    await write_session.commit()
                saved_id = report.id
            logger.info(f"[generate_resume] Created resume {saved_id}: {resume_title}")
            return {
                "status": "ready",
                "report_id": saved_id,
                "title": resume_title,
                "content_type": "typst",
                "is_revision": bool(parent_content),
                "message": (
                    f"Resume generated successfully: {resume_title}"
                    if target_page_met
                    else (
                        f"Resume generated, but could not fit the target of <= {validated_max_pages} "
                        f"page(s). Final length: {actual_pages} page(s)."
                    )
                ),
            }
        except Exception as e:
            # Last-resort boundary: any unexpected error becomes a failed
            # result dict (and a failed Report row) instead of propagating.
            error_message = str(e)
            logger.exception(f"[generate_resume] Error: {error_message}")
            report_id = await _save_failed_report(error_message)
            return {
                "status": "failed",
                "error": error_message,
                "report_id": report_id,
                "title": "Resume",
                "content_type": "typst",
            }

    return generate_resume

View file

@ -0,0 +1,306 @@
"""
Web scraping tool for the SurfSense agent.
This module provides a tool for scraping and extracting content from webpages
using the existing WebCrawlerConnector. For YouTube URLs, it fetches the
transcript directly via the YouTubeTranscriptApi instead of crawling the page.
"""
import hashlib
import logging
from typing import Any
from urllib.parse import urlparse
import aiohttp
from fake_useragent import UserAgent
from langchain_core.tools import tool
from requests import Session
from youtube_transcript_api import YouTubeTranscriptApi
from app.connectors.webcrawler_connector import WebCrawlerConnector
from app.tasks.document_processors.youtube_processor import get_youtube_video_id
from app.utils.proxy_config import get_requests_proxies
logger = logging.getLogger(__name__)
def extract_domain(url: str) -> str:
    """Return the network location of ``url`` without a leading ``www.``.

    Any exception raised while parsing yields an empty string. A URL with no
    ``//`` authority part (e.g. ``example.com/page``) has an empty network
    location and therefore also yields an empty string.
    """
    try:
        return urlparse(url).netloc.removeprefix("www.")
    except Exception:
        return ""
def generate_scrape_id(url: str) -> str:
    """Derive a stable identifier for a scraped URL.

    The ID is ``scrape-`` followed by the first 12 hex digits of the URL's
    MD5 digest, so the same URL string always maps to the same ID.
    """
    digest = hashlib.md5(url.encode()).hexdigest()
    return "scrape-" + digest[:12]
def truncate_content(content: str, max_length: int = 50000) -> tuple[str, bool]:
    """
    Cap ``content`` at ``max_length`` characters, preferring a clean break.

    When truncation is needed, the cut is moved back to the last ``.`` or
    blank line inside the window, provided that break lies in the final 20%
    of the window; a truncation marker is then appended.

    Returns:
        Tuple of (truncated_content, was_truncated)
    """
    if len(content) <= max_length:
        return content, False

    window = content[:max_length]
    # Latest sentence end or paragraph break inside the window (-1 if none).
    cut = max(window.rfind("."), window.rfind("\n\n"))
    if cut > max_length * 0.8:
        window = content[: cut + 1]
    return window + "\n\n[Content truncated...]", True
async def _scrape_youtube_video(
    url: str, video_id: str, max_length: int
) -> dict[str, Any]:
    """
    Fetch YouTube video metadata and transcript via the YouTubeTranscriptApi.

    Metadata (title, author) comes from YouTube's oEmbed endpoint; if that
    request fails for any reason, placeholder values are used. The transcript
    is the first one listed for the video; if it cannot be fetched, the
    content carries an explanatory message instead of captions.

    Args:
        url: The URL as requested; used for the result's id/assetId/href.
        video_id: YouTube video ID extracted from ``url``.
        max_length: Maximum number of content characters before truncation.

    Returns:
        A result dict in the same shape as the regular scrape_webpage output.
    """
    scrape_id = generate_scrape_id(url)
    domain = "youtube.com"

    # --- Video metadata via oEmbed ---
    residential_proxies = get_requests_proxies()
    params = {
        "format": "json",
        "url": f"https://www.youtube.com/watch?v={video_id}",
    }
    oembed_url = "https://www.youtube.com/oembed"
    try:
        async with (
            aiohttp.ClientSession() as http_session,
            http_session.get(
                oembed_url,
                params=params,
                proxy=residential_proxies["http"] if residential_proxies else None,
            ) as response,
        ):
            video_data = await response.json()
    except Exception:
        # Metadata is best-effort; fall back to the defaults below.
        video_data = {}
    title = video_data.get("title", "YouTube Video")
    author = video_data.get("author_name", "Unknown")

    # --- Transcript via YouTubeTranscriptApi ---
    # NOTE(review): the calls below go through a synchronous ``requests``
    # session and are not awaited, so they run on the event loop thread.
    try:
        ua = UserAgent()
        # Use the session as a context manager so its connection pool is
        # always released, including when listing/fetching raises.
        with Session() as http_client:
            http_client.headers.update({"User-Agent": ua.random})
            if residential_proxies:
                http_client.proxies.update(residential_proxies)
            ytt_api = YouTubeTranscriptApi(http_client=http_client)
            # List all available transcripts and pick the first one
            # (the video's primary language) instead of defaulting to English
            transcript_list = ytt_api.list(video_id)
            transcript = next(iter(transcript_list))
            captions = transcript.fetch()
            logger.info(
                f"[scrape_webpage] Fetched transcript for {video_id} "
                f"in {transcript.language} ({transcript.language_code})"
            )
            # One "[start-end] text" line per caption snippet.
            transcript_text = "\n".join(
                f"[{line.start:.2f}s-{line.start + line.duration:.2f}s] {line.text}"
                for line in captions
            )
    except Exception as e:
        logger.warning(f"[scrape_webpage] No transcript for video {video_id}: {e}")
        transcript_text = f"No captions available for this video. Error: {e!s}"

    # Build combined content
    content = f"# {title}\n\n**Author:** {author}\n**Video ID:** {video_id}\n\n## Transcript\n\n{transcript_text}"
    # Truncate if needed
    content, was_truncated = truncate_content(content, max_length)
    word_count = len(content.split())
    description = f"YouTube video by {author}"
    return {
        "id": scrape_id,
        "assetId": url,
        "kind": "article",
        "href": url,
        "title": title,
        "description": description,
        "content": content,
        "domain": domain,
        "word_count": word_count,
        "was_truncated": was_truncated,
        "crawler_type": "youtube_transcript",
        "author": author,
    }
def create_scrape_webpage_tool(firecrawl_api_key: str | None = None):
    """
    Factory function to create the scrape_webpage tool.

    Args:
        firecrawl_api_key: Optional Firecrawl API key for premium web scraping.
            Falls back to Chromium/Trafilatura if not provided.

    Returns:
        A configured tool function for scraping webpages.
    """

    @tool
    async def scrape_webpage(
        url: str,
        max_length: int = 50000,
    ) -> dict[str, Any]:
        """
        Scrape and extract the main content from a webpage.

        Use this tool when the user wants you to read, summarize, or answer
        questions about a specific webpage's content. This tool actually
        fetches and reads the full page content. For YouTube video URLs it
        fetches the transcript directly instead of crawling the page.

        Common triggers:
        - "Read this article and summarize it"
        - "What does this page say about X?"
        - "Summarize this blog post for me"
        - "Tell me the key points from this article"
        - "What's in this webpage?"

        Args:
            url: The URL of the webpage to scrape (must be HTTP/HTTPS)
            max_length: Maximum content length to return (default: 50000 chars)

        Returns:
            A dictionary containing:
            - id: Unique identifier for this scrape
            - assetId: The URL (for deduplication)
            - kind: "article" (type of content)
            - href: The URL to open when clicked
            - title: Page title
            - description: Brief description or excerpt
            - content: The extracted main content (markdown format)
            - domain: The domain name
            - word_count: Approximate word count
            - was_truncated: Whether content was truncated
            - error: Error message (if scraping failed)
        """
        # Normalize the URL BEFORE deriving the id and domain. A scheme-less
        # input such as "example.com/page" has no network location for
        # urlparse, so deriving the domain first produced "" and the id was
        # hashed from a different string than the assetId/href we return.
        if not url.startswith(("http://", "https://")):
            url = f"https://{url}"
        scrape_id = generate_scrape_id(url)
        domain = extract_domain(url)

        def _failure(error: str) -> dict[str, Any]:
            # Shared shape for every error result returned by this tool.
            return {
                "id": scrape_id,
                "assetId": url,
                "kind": "article",
                "href": url,
                "title": domain or "Webpage",
                "domain": domain,
                "error": error,
            }

        try:
            # Check if this is a YouTube URL and use transcript API instead
            video_id = get_youtube_video_id(url)
            if video_id:
                return await _scrape_youtube_video(url, video_id, max_length)

            # Create webcrawler connector
            connector = WebCrawlerConnector(firecrawl_api_key=firecrawl_api_key)
            # Crawl the URL
            result, error = await connector.crawl_url(url, formats=["markdown"])
            if error:
                return _failure(error)
            if not result:
                return _failure("No content returned from crawler")

            # Extract content and metadata
            content = result.get("content", "")
            metadata = result.get("metadata", {})

            # Get title from metadata, falling back to the domain, then the
            # last path segment, then a generic label.
            title = metadata.get("title", "")
            if not title:
                title = domain or url.split("/")[-1] or "Webpage"

            # Get description from metadata
            description = metadata.get("description", "")
            if not description and content:
                # Use first paragraph as description
                first_para = content.split("\n\n")[0]
                description = (
                    first_para[:300] + "..." if len(first_para) > 300 else first_para
                )

            # Truncate content if needed
            content, was_truncated = truncate_content(content, max_length)
            # Calculate word count
            word_count = len(content.split())
            return {
                "id": scrape_id,
                "assetId": url,
                "kind": "article",
                "href": url,
                "title": title,
                "description": description,
                "content": content,
                "domain": domain,
                "word_count": word_count,
                "was_truncated": was_truncated,
                "crawler_type": result.get("crawler_type", "unknown"),
                "author": metadata.get("author"),
                "date": metadata.get("date"),
            }
        except Exception as e:
            error_message = str(e)
            logger.error(f"[scrape_webpage] Error scraping {url}: {error_message}")
            return _failure(f"Failed to scrape: {error_message[:100]}")

    return scrape_webpage

View file

@ -0,0 +1,15 @@
from app.agents.shared.tools.teams.list_channels import (
create_list_teams_channels_tool,
)
from app.agents.shared.tools.teams.read_messages import (
create_read_teams_messages_tool,
)
from app.agents.shared.tools.teams.send_message import (
create_send_teams_message_tool,
)
__all__ = [
"create_list_teams_channels_tool",
"create_read_teams_messages_tool",
"create_send_teams_message_tool",
]

View file

@ -0,0 +1,38 @@
"""Shared auth helper for Teams agent tools (Microsoft Graph REST API)."""
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
# Base URL (v1.0) that the Teams tools prefix to every Microsoft Graph path.
GRAPH_API = "https://graph.microsoft.com/v1.0"
async def get_teams_connector(
    db_session: AsyncSession,
    search_space_id: int,
    user_id: str,
) -> SearchSourceConnector | None:
    """Look up the Teams connector for a user within a search space.

    Returns:
        The first connector row of type ``TEAMS_CONNECTOR`` matching both
        ``search_space_id`` and ``user_id``, or ``None`` when there is none.
    """
    query = select(SearchSourceConnector).filter(
        SearchSourceConnector.search_space_id == search_space_id,
        SearchSourceConnector.user_id == user_id,
        SearchSourceConnector.connector_type
        == SearchSourceConnectorType.TEAMS_CONNECTOR,
    )
    rows = await db_session.execute(query)
    return rows.scalars().first()
async def get_access_token(
    db_session: AsyncSession,
    connector: SearchSourceConnector,
) -> str:
    """Return a Microsoft Graph access token for ``connector``.

    Builds a ``TeamsConnector`` for the connector's ID on the given session
    and delegates to its ``_get_valid_token()``, which is expected to refresh
    the token when it has expired.
    """
    # Imported at call time rather than at module import.
    from app.connectors.teams_connector import TeamsConnector

    teams = TeamsConnector(session=db_session, connector_id=connector.id)
    token = await teams._get_valid_token()
    return token

View file

@ -0,0 +1,114 @@
import logging
from typing import Any
import httpx
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import async_session_maker
from ._auth import GRAPH_API, get_access_token, get_teams_connector
logger = logging.getLogger(__name__)
def create_list_teams_channels_tool(
    db_session: AsyncSession | None = None,
    search_space_id: int | None = None,
    user_id: str | None = None,
):
    """
    Build the list_teams_channels tool for one user and search space.

    The returned tool never uses the session passed to this factory. Each
    invocation opens its own short-lived ``AsyncSession`` through
    :data:`async_session_maker`, which keeps the closure safe to reuse across
    HTTP requests when the compiled agent is cached (a captured per-request
    session would be stale or closed on a cache hit).

    Args:
        db_session: Accepted only for registry compatibility; discarded.
        search_space_id: Search space whose Teams connector is used.
        user_id: Owner of the Teams connector.

    Returns:
        Configured list_teams_channels tool
    """
    del db_session  # per-call session — see docstring

    @tool
    async def list_teams_channels() -> dict[str, Any]:
        """List all Microsoft Teams and their channels the user has access to.

        Returns:
            Dictionary with status and a list of teams, each containing
            team_id, team_name, and a list of channels (id, name).
        """
        if search_space_id is None or user_id is None:
            return {"status": "error", "message": "Teams tool not properly configured."}

        try:
            # The DB session is only needed to resolve the connector and token.
            async with async_session_maker() as session:
                connector = await get_teams_connector(
                    session, search_space_id, user_id
                )
                if not connector:
                    return {"status": "error", "message": "No Teams connector found."}
                token = await get_access_token(session, connector)

            auth_headers = {"Authorization": f"Bearer {token}"}
            async with httpx.AsyncClient(timeout=20.0) as client:
                teams_resp = await client.get(
                    f"{GRAPH_API}/me/joinedTeams", headers=auth_headers
                )

            if teams_resp.status_code == 401:
                return {
                    "status": "auth_error",
                    "message": "Teams token expired. Please re-authenticate.",
                    "connector_type": "teams",
                }
            if teams_resp.status_code != 200:
                return {
                    "status": "error",
                    "message": f"Graph API error: {teams_resp.status_code}",
                }

            joined_teams = teams_resp.json().get("value", [])
            summaries = []
            async with httpx.AsyncClient(timeout=20.0) as client:
                for team in joined_teams:
                    team_id = team["id"]
                    channels_resp = await client.get(
                        f"{GRAPH_API}/teams/{team_id}/channels",
                        headers=auth_headers,
                    )
                    # A failed channel lookup leaves the team listed with
                    # an empty channel list rather than failing the call.
                    if channels_resp.status_code == 200:
                        channel_list = [
                            {"id": ch["id"], "name": ch.get("displayName", "")}
                            for ch in channels_resp.json().get("value", [])
                        ]
                    else:
                        channel_list = []
                    summaries.append(
                        {
                            "team_id": team_id,
                            "team_name": team.get("displayName", ""),
                            "channels": channel_list,
                        }
                    )

            return {
                "status": "success",
                "teams": summaries,
                "total_teams": len(summaries),
            }
        except Exception as e:
            from langgraph.errors import GraphInterrupt

            # Graph interrupts must propagate; everything else is reported
            # as a generic tool error.
            if isinstance(e, GraphInterrupt):
                raise
            logger.error("Error listing Teams channels: %s", e, exc_info=True)
            return {"status": "error", "message": "Failed to list Teams channels."}

    return list_teams_channels

View file

@ -0,0 +1,125 @@
import logging
from typing import Any
import httpx
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import async_session_maker
from ._auth import GRAPH_API, get_access_token, get_teams_connector
logger = logging.getLogger(__name__)
def create_read_teams_messages_tool(
    db_session: AsyncSession | None = None,
    search_space_id: int | None = None,
    user_id: str | None = None,
):
    """
    Factory function to create the read_teams_messages tool.

    The tool acquires its own short-lived ``AsyncSession`` per call via
    :data:`async_session_maker` so the closure is safe to share across
    HTTP requests by the compiled-agent cache. Capturing a per-request
    session here would surface stale/closed sessions on cache hits.

    Args:
        db_session: Reserved for registry compatibility. Per-call sessions
            are opened via :data:`async_session_maker` inside the tool body.
        search_space_id: Search space whose Teams connector is used.
        user_id: Owner of the Teams connector.

    Returns:
        Configured read_teams_messages tool
    """
    del db_session  # per-call session — see docstring

    @tool
    async def read_teams_messages(
        team_id: str,
        channel_id: str,
        limit: int = 25,
    ) -> dict[str, Any]:
        """Read recent messages from a Microsoft Teams channel.

        Args:
            team_id: The team ID (from list_teams_channels).
            channel_id: The channel ID (from list_teams_channels).
            limit: Number of messages to fetch (default 25, max 50).

        Returns:
            Dictionary with status and a list of messages including
            id, sender, content, timestamp.
        """
        if search_space_id is None or user_id is None:
            return {"status": "error", "message": "Teams tool not properly configured."}

        # Clamp to 1..50: an upper bound alone let 0 or negative values
        # through to the ``$top`` query parameter.
        limit = max(1, min(limit, 50))
        try:
            async with async_session_maker() as db_session:
                connector = await get_teams_connector(
                    db_session, search_space_id, user_id
                )
                if not connector:
                    return {"status": "error", "message": "No Teams connector found."}
                token = await get_access_token(db_session, connector)

            async with httpx.AsyncClient(timeout=20.0) as client:
                resp = await client.get(
                    f"{GRAPH_API}/teams/{team_id}/channels/{channel_id}/messages",
                    headers={"Authorization": f"Bearer {token}"},
                    params={"$top": limit},
                )

            if resp.status_code == 401:
                return {
                    "status": "auth_error",
                    "message": "Teams token expired. Please re-authenticate.",
                    "connector_type": "teams",
                }
            if resp.status_code == 403:
                return {
                    "status": "error",
                    "message": "Insufficient permissions to read this channel.",
                }
            if resp.status_code != 200:
                return {
                    "status": "error",
                    "message": f"Graph API error: {resp.status_code}",
                }

            raw_msgs = resp.json().get("value", [])
            messages = []
            for m in raw_msgs:
                # ``from``, ``from.user`` and ``body`` may be present with a
                # JSON null value (e.g. messages not sent by a user);
                # ``dict.get(key, {})`` would return None in that case and
                # the following ``.get`` would raise, failing the whole read.
                sender = m.get("from") or {}
                user_info = sender.get("user") or {}
                body = m.get("body") or {}
                messages.append(
                    {
                        "id": m.get("id"),
                        "sender": user_info.get("displayName") or "Unknown",
                        "content": body.get("content", ""),
                        "content_type": body.get("contentType", "text"),
                        "timestamp": m.get("createdDateTime", ""),
                    }
                )

            return {
                "status": "success",
                "team_id": team_id,
                "channel_id": channel_id,
                "messages": messages,
                "total": len(messages),
            }
        except Exception as e:
            from langgraph.errors import GraphInterrupt

            # Graph interrupts must propagate; everything else is reported
            # as a generic tool error.
            if isinstance(e, GraphInterrupt):
                raise
            logger.error("Error reading Teams messages: %s", e, exc_info=True)
            return {"status": "error", "message": "Failed to read Teams messages."}

    return read_teams_messages

View file

@ -0,0 +1,136 @@
import logging
from typing import Any
import httpx
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.shared.tools.hitl import request_approval
from app.db import async_session_maker
from ._auth import GRAPH_API, get_access_token, get_teams_connector
logger = logging.getLogger(__name__)
def create_send_teams_message_tool(
    db_session: AsyncSession | None = None,
    search_space_id: int | None = None,
    user_id: str | None = None,
):
    """
    Factory function to create the send_teams_message tool.

    The tool acquires its own short-lived ``AsyncSession`` per call via
    :data:`async_session_maker` so the closure is safe to share across
    HTTP requests by the compiled-agent cache. Capturing a per-request
    session here would surface stale/closed sessions on cache hits.

    Args:
        db_session: Reserved for registry compatibility. Per-call sessions
            are opened via :data:`async_session_maker` inside the tool body.
        search_space_id: Search space used to look up the Teams connector.
            When ``None`` the tool returns a configuration error.
        user_id: User used to look up the Teams connector. When ``None``
            the tool returns a configuration error.

    Returns:
        Configured send_teams_message tool
    """
    del db_session  # per-call session — see docstring

    @tool
    async def send_teams_message(
        team_id: str,
        channel_id: str,
        content: str,
    ) -> dict[str, Any]:
        """Send a message to a Microsoft Teams channel.

        Requires the ChannelMessage.Send OAuth scope. If the user gets a
        permission error, they may need to re-authenticate with updated scopes.

        Args:
            team_id: The team ID (from list_teams_channels).
            channel_id: The channel ID (from list_teams_channels).
            content: The message text (HTML supported).

        Returns:
            Dictionary with status, message_id on success.

        IMPORTANT:
            - If status is "rejected", the user explicitly declined. Do NOT retry.
        """
        if search_space_id is None or user_id is None:
            return {"status": "error", "message": "Teams tool not properly configured."}

        try:
            async with async_session_maker() as db_session:
                connector = await get_teams_connector(
                    db_session, search_space_id, user_id
                )
                if not connector:
                    return {"status": "error", "message": "No Teams connector found."}

                # Human-in-the-loop gate: nothing is sent before approval.
                result = request_approval(
                    action_type="teams_send_message",
                    tool_name="send_teams_message",
                    params={
                        "team_id": team_id,
                        "channel_id": channel_id,
                        "content": content,
                    },
                    context={"connector_id": connector.id},
                )
                if result.rejected:
                    return {
                        "status": "rejected",
                        "message": "User declined. Message was not sent.",
                    }

                # The approval result may carry edited parameters; fall back
                # to the values the model supplied when a key is absent.
                final_content = result.params.get("content", content)
                final_team = result.params.get("team_id", team_id)
                final_channel = result.params.get("channel_id", channel_id)

                token = await get_access_token(db_session, connector)

            # NOTE(review): no ``contentType`` is sent with the body, so Graph
            # presumably applies its default — confirm this matches the
            # "HTML supported" wording in the tool description.
            async with httpx.AsyncClient(timeout=20.0) as client:
                resp = await client.post(
                    f"{GRAPH_API}/teams/{final_team}/channels/{final_channel}/messages",
                    headers={
                        "Authorization": f"Bearer {token}",
                        "Content-Type": "application/json",
                    },
                    json={"body": {"content": final_content}},
                )

            if resp.status_code == 401:
                return {
                    "status": "auth_error",
                    "message": "Teams token expired. Please re-authenticate.",
                    "connector_type": "teams",
                }
            if resp.status_code == 403:
                return {
                    "status": "insufficient_permissions",
                    "message": "Missing ChannelMessage.Send permission. Please re-authenticate with updated scopes.",
                }
            if resp.status_code not in (200, 201):
                # Separate the status code from the (truncated) response body;
                # previously they were fused, e.g. ``400{"error": ...``.
                return {
                    "status": "error",
                    "message": f"Graph API error: {resp.status_code} - {resp.text[:200]}",
                }

            msg_data = resp.json()
            return {
                "status": "success",
                "message_id": msg_data.get("id"),
                "message": "Message sent to Teams channel.",
            }
        except Exception as e:
            # GraphInterrupt is a control-flow signal and must propagate.
            from langgraph.errors import GraphInterrupt

            if isinstance(e, GraphInterrupt):
                raise
            logger.error("Error sending Teams message: %s", e, exc_info=True)
            return {"status": "error", "message": "Failed to send Teams message."}

    return send_teams_message

View file

@ -0,0 +1,93 @@
"""Memory update tools backed by the canonical memory service."""
from __future__ import annotations
import logging
from typing import Any
from uuid import UUID
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import async_session_maker
from app.services.memory import MemoryScope, save_memory
logger = logging.getLogger(__name__)
def create_update_memory_tool(
    user_id: str | UUID,
    db_session: AsyncSession,
    llm: Any | None = None,
):
    """Build the ``update_memory`` tool bound to one user's memory document.

    The passed-in ``db_session`` is discarded on purpose: each invocation
    opens its own session through ``async_session_maker`` so a cached,
    compiled agent never holds on to a request-scoped session.

    Args:
        user_id: Owner of the memory document; a string is parsed into a
            ``UUID`` once, at factory time.
        db_session: Accepted for signature compatibility and not used.
        llm: Optional model handle forwarded unchanged to ``save_memory``.

    Returns:
        The configured ``update_memory`` tool.
    """
    del db_session
    owner_id = user_id if not isinstance(user_id, str) else UUID(user_id)

    @tool
    async def update_memory(updated_memory: str) -> dict[str, Any]:
        """Update the user's personal memory document.

        The current memory is shown in <user_memory>. Pass the FULL updated
        markdown document, not a diff.
        """
        try:
            async with async_session_maker() as session:
                outcome = await save_memory(
                    scope=MemoryScope.USER,
                    target_id=owner_id,
                    content=updated_memory,
                    session=session,
                    llm=llm,
                )
                return outcome.to_dict()
        except Exception as exc:
            # Failures are reported back to the agent instead of raised.
            logger.exception("Failed to update user memory: %s", exc)
            return {"status": "error", "message": f"Failed to update memory: {exc}"}

    return update_memory
def create_update_team_memory_tool(
    search_space_id: int,
    db_session: AsyncSession,
    llm: Any | None = None,
):
    """Build the ``update_memory`` tool bound to a search space's team memory.

    The passed-in ``db_session`` is discarded; every call opens its own
    session through ``async_session_maker``.

    Args:
        search_space_id: Search space whose team memory is written.
        db_session: Accepted for signature compatibility and not used.
        llm: Optional model handle forwarded unchanged to ``save_memory``.

    Returns:
        The configured ``update_memory`` tool (team scope).
    """
    del db_session

    @tool
    async def update_memory(updated_memory: str) -> dict[str, Any]:
        """Update the team's shared memory document for this search space.

        The current team memory is shown in <team_memory>. Pass the FULL updated
        markdown document, not a diff.
        """
        try:
            async with async_session_maker() as session:
                outcome = await save_memory(
                    scope=MemoryScope.TEAM,
                    target_id=search_space_id,
                    content=updated_memory,
                    session=session,
                    llm=llm,
                )
                return outcome.to_dict()
        except Exception as exc:
            # Failures are reported back to the agent instead of raised.
            logger.exception("Failed to update team memory: %s", exc)
            failure: dict[str, Any] = {
                "status": "error",
                "message": f"Failed to update team memory: {exc}",
            }
            return failure

    return update_memory
__all__ = [
"create_update_memory_tool",
"create_update_team_memory_tool",
]

View file

@ -0,0 +1,138 @@
"""
Video presentation generation tool for the SurfSense agent.
This module provides a factory function for creating the generate_video_presentation
tool that submits a Celery task for background video presentation generation. The
tool then polls the row until it reaches a terminal status (READY/FAILED) and
returns that status. The wait is bounded by the chat's HTTP / process lifetime;
see app.agents.shared.deliverable_wait for details.
"""
import logging
from typing import Any
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.shared.deliverable_wait import wait_for_deliverable
from app.db import VideoPresentation, VideoPresentationStatus, shielded_async_session
logger = logging.getLogger(__name__)
def create_generate_video_presentation_tool(
    search_space_id: int,
    db_session: AsyncSession,
    thread_id: int | None = None,
):
    """
    Build the generate_video_presentation tool with its dependencies bound.

    A ``VideoPresentation`` row is inserted in PENDING state before the
    background task is queued, so its ID exists from the start. The insert
    goes through a session opened inside the tool, not the injected one,
    because ``AsyncSession`` is not concurrency-safe and several tools may
    run in the same agent step.

    Args:
        search_space_id: Search space the presentation row belongs to.
        db_session: Accepted for signature compatibility and not used.
        thread_id: Optional chat thread stored on the row.

    Returns:
        The configured generate_video_presentation tool.
    """
    del db_session  # writes use a fresh tool-local session, see below

    @tool
    async def generate_video_presentation(
        source_content: str,
        video_title: str = "SurfSense Presentation",
        user_prompt: str | None = None,
    ) -> dict[str, Any]:
        """Generate a video presentation from the provided content.

        Use this tool when the user asks to create a video, presentation, slides, or slide deck.

        Args:
            source_content: The text content to turn into a presentation.
            video_title: Title for the presentation (default: "SurfSense Presentation")
            user_prompt: Optional style/tone instructions.
        """
        try:
            # Tool-local session: sharing the streaming session between
            # parallel tool calls interleaves flushes and poisons the
            # transaction for every concurrent tool (same rationale as
            # podcast.py).
            async with shielded_async_session() as session:
                presentation_row = VideoPresentation(
                    title=video_title,
                    status=VideoPresentationStatus.PENDING,
                    search_space_id=search_space_id,
                    thread_id=thread_id,
                )
                session.add(presentation_row)
                await session.commit()
                await session.refresh(presentation_row)
                presentation_id = presentation_row.id

            from app.tasks.celery_tasks.video_presentation_tasks import (
                generate_video_presentation_task,
            )

            queued = generate_video_presentation_task.delay(
                video_presentation_id=presentation_id,
                source_content=source_content,
                search_space_id=search_space_id,
                user_prompt=user_prompt,
            )
            logger.info(
                "[generate_video_presentation] Created video presentation %s, task: %s",
                presentation_id,
                queued.id,
            )

            # Block until the worker moves the row to READY or FAILED; the
            # wait has no budget of its own (see deliverable_wait).
            final_states = {
                VideoPresentationStatus.READY,
                VideoPresentationStatus.FAILED,
            }
            terminal_status, _columns, elapsed = await wait_for_deliverable(
                model=VideoPresentation,
                row_id=presentation_id,
                columns=[VideoPresentation.status],
                terminal_statuses=final_states,
            )

            if terminal_status != VideoPresentationStatus.READY:
                # FAILED is the only other terminal state.
                logger.warning(
                    "[generate_video_presentation] %s FAILED in %.2fs",
                    presentation_id,
                    elapsed,
                )
                return {
                    "status": VideoPresentationStatus.FAILED.value,
                    "video_presentation_id": presentation_id,
                    "title": video_title,
                    "error": (
                        "Background worker reported FAILED status for this "
                        "video presentation."
                    ),
                }

            logger.info(
                "[generate_video_presentation] %s READY in %.2fs",
                presentation_id,
                elapsed,
            )
            return {
                "status": VideoPresentationStatus.READY.value,
                "video_presentation_id": presentation_id,
                "title": video_title,
                "message": "Video presentation generated and saved.",
            }
        except Exception as exc:
            reason = str(exc)
            logger.exception("[generate_video_presentation] Error: %s", reason)
            return {
                "status": VideoPresentationStatus.FAILED.value,
                "error": reason,
                "title": video_title,
                "video_presentation_id": None,
            }

    return generate_video_presentation

View file

@ -0,0 +1,247 @@
"""
Web search tool for the SurfSense agent.
Provides a unified tool for real-time web searches that dispatches to all
configured search engines: the platform SearXNG instance (always available)
plus any user-configured live-search connectors (Tavily, Linkup, Baidu).
"""
import asyncio
import json
import time
from typing import Any
from langchain_core.tools import StructuredTool
from pydantic import BaseModel, Field
from app.db import shielded_async_session
from app.services.connector_service import ConnectorService
from app.utils.perf import get_perf_logger
# Connector type names that count as live web-search engines.
# ``create_web_search_tool`` keeps only the entries of ``available_connectors``
# that appear in this set.
_LIVE_SEARCH_CONNECTORS: set[str] = {
    "TAVILY_API",
    "LINKUP_API",
    "BAIDU_SEARCH_API",
}
# Per-connector dispatch spec, unpacked by ``_search_live_connector`` as
# ``(method_name, includes_date_range, includes_top_k, extra_kwargs)``:
#   - method_name: attribute looked up on ``ConnectorService`` and awaited.
#   - includes_date_range: unpacked but not read by ``_search_live_connector``.
#   - includes_top_k: when true, ``top_k`` is added to the call's kwargs.
#   - extra_kwargs: fixed keyword arguments merged into every call.
_LIVE_CONNECTOR_SPECS: dict[str, tuple[str, bool, bool, dict[str, Any]]] = {
    "TAVILY_API": ("search_tavily", False, True, {}),
    "LINKUP_API": ("search_linkup", False, False, {"mode": "standard"}),
    "BAIDU_SEARCH_API": ("search_baidu", False, True, {}),
}
# Human-readable engine names listed in the ``web_search`` tool description;
# connectors without an entry fall back to their raw type name.
_CONNECTOR_LABELS: dict[str, str] = {
    "TAVILY_API": "Tavily",
    "LINKUP_API": "Linkup",
    "BAIDU_SEARCH_API": "Baidu",
}
class WebSearchInput(BaseModel):
    """Input schema for the web_search tool."""

    # Free-text query; passed as-is to every active search engine.
    query: str = Field(
        description="The search query to look up on the web. Use specific, descriptive terms.",
    )
    # The schema itself does not bound this value; the 1..50 range is applied
    # by the clamp inside ``_web_search_impl``.
    top_k: int = Field(
        default=10,
        description="Number of results to retrieve (default: 10, max: 50).",
    )
def _format_web_results(
documents: list[dict[str, Any]],
*,
max_chars: int = 50_000,
) -> str:
"""Format web search results into XML suitable for the LLM context."""
if not documents:
return "No web search results found."
parts: list[str] = []
total_chars = 0
for doc in documents:
doc_info = doc.get("document") or {}
metadata = doc_info.get("metadata") or {}
title = doc_info.get("title") or "Web Result"
url = metadata.get("url") or ""
content = (doc.get("content") or "").strip()
source = metadata.get("document_type") or doc.get("source") or "WEB_SEARCH"
if not content:
continue
metadata_json = json.dumps(metadata, ensure_ascii=False)
doc_xml = "\n".join(
[
"<document>",
"<document_metadata>",
f" <document_type>{source}</document_type>",
f" <title><![CDATA[{title}]]></title>",
f" <url><![CDATA[{url}]]></url>",
f" <metadata_json><![CDATA[{metadata_json}]]></metadata_json>",
"</document_metadata>",
"<document_content>",
f" <chunk id='{url}'><![CDATA[{content}]]></chunk>",
"</document_content>",
"</document>",
"",
]
)
if total_chars + len(doc_xml) > max_chars:
parts.append("<!-- Output truncated to fit context window -->")
break
parts.append(doc_xml)
total_chars += len(doc_xml)
return "\n".join(parts).strip() or "No web search results found."
async def _search_live_connector(
    connector: str,
    query: str,
    search_space_id: int,
    top_k: int,
    semaphore: asyncio.Semaphore,
) -> list[dict[str, Any]]:
    """Run one live-search connector (Tavily / Linkup / Baidu) and return its chunks.

    Best effort: an unknown connector or any exception raised by the search
    yields an empty list (failures are logged as warnings).

    Args:
        connector: Connector type name, a key of ``_LIVE_CONNECTOR_SPECS``.
        query: The user's search query.
        search_space_id: Search space the connector is configured in.
        top_k: Result count; forwarded only when the connector's spec says so.
        semaphore: Held for the duration of the search call.

    Returns:
        The chunks returned by the connector, or ``[]``.
    """
    perf = get_perf_logger()
    if connector not in _LIVE_CONNECTOR_SPECS:
        return []
    (
        method_name,
        _includes_date_range,
        includes_top_k,
        extra_kwargs,
    ) = _LIVE_CONNECTOR_SPECS[connector]

    call_kwargs: dict[str, Any] = {
        "user_query": query,
        "search_space_id": search_space_id,
    }
    call_kwargs.update(extra_kwargs)
    if includes_top_k:
        call_kwargs["top_k"] = top_k

    try:
        started = time.perf_counter()
        # Each connector call gets its own session, taken under the semaphore.
        async with semaphore, shielded_async_session() as session:
            service = ConnectorService(session, search_space_id)
            search = getattr(service, method_name)
            _, chunks = await search(**call_kwargs)
        perf.info(
            "[web_search] connector=%s results=%d in %.3fs",
            connector,
            len(chunks),
            time.perf_counter() - started,
        )
        return chunks
    except Exception as exc:
        perf.warning("[web_search] connector=%s FAILED: %s", connector, exc)
        return []
def create_web_search_tool(
    search_space_id: int | None = None,
    available_connectors: list[str] | None = None,
) -> StructuredTool:
    """Build the ``web_search`` tool.

    Each call queries the platform SearXNG instance (when the web search
    service reports itself available) and every user-configured
    live-search connector (Tavily, Linkup, Baidu) concurrently, merges the
    results, drops repeated URLs and returns the formatted text.

    Args:
        search_space_id: Search space for the live connectors; when ``None``
            no live connector is queried.
        available_connectors: Connector type names configured for the user;
            only those in ``_LIVE_SEARCH_CONNECTORS`` are used.

    Returns:
        A ``StructuredTool`` named ``web_search``.
    """
    live_connectors: list[str] = []
    if available_connectors:
        live_connectors = [
            name for name in available_connectors if name in _LIVE_SEARCH_CONNECTORS
        ]

    engine_labels = [
        "SearXNG (platform default)",
        *(_CONNECTOR_LABELS.get(name, name) for name in live_connectors),
    ]
    engines_summary = ", ".join(engine_labels)
    description = (
        "Search the web for real-time information. "
        "Use this for current events, news, prices, weather, public facts, or any "
        "question that requires up-to-date information from the internet.\n\n"
        f"Active search engines: {engines_summary}.\n"
        "All configured engines are queried in parallel and results are merged."
    )

    async def _web_search_impl(query: str, top_k: int = 10) -> str:
        from app.services import web_search_service

        perf = get_perf_logger()
        started = time.perf_counter()
        bounded_top_k = max(1, min(top_k, 50))
        # One semaphore per invocation, shared by every engine task below.
        gate = asyncio.Semaphore(4)
        tasks: list[asyncio.Task[list[dict[str, Any]]]] = []

        if web_search_service.is_available():

            async def _searxng() -> list[dict[str, Any]]:
                async with gate:
                    _result_obj, docs = await web_search_service.search(
                        query=query,
                        top_k=bounded_top_k,
                    )
                return docs

            tasks.append(asyncio.ensure_future(_searxng()))

        if search_space_id is not None:
            tasks.extend(
                asyncio.ensure_future(
                    _search_live_connector(
                        connector=name,
                        query=query,
                        search_space_id=search_space_id,
                        top_k=bounded_top_k,
                        semaphore=gate,
                    )
                )
                for name in live_connectors
            )

        if not tasks:
            return "Web search is not available — no search engines are configured."

        outcomes = await asyncio.gather(*tasks, return_exceptions=True)

        # A failing engine is logged and skipped; the others still contribute.
        all_documents: list[dict[str, Any]] = []
        for outcome in outcomes:
            if isinstance(outcome, BaseException):
                perf.warning("[web_search] a search engine failed: %s", outcome)
            else:
                all_documents.extend(outcome)

        # First occurrence of a URL wins; documents without a URL are kept.
        seen_urls: set[str] = set()
        unique_documents: list[dict[str, Any]] = []
        for doc in all_documents:
            url = ((doc.get("document") or {}).get("metadata") or {}).get("url", "")
            if url:
                if url in seen_urls:
                    continue
                seen_urls.add(url)
            unique_documents.append(doc)

        formatted = _format_web_results(unique_documents)
        perf.info(
            "[web_search] query=%r engines=%d results=%d deduped=%d chars=%d in %.3fs",
            query[:60],
            len(tasks),
            len(all_documents),
            len(unique_documents),
            len(formatted),
            time.perf_counter() - started,
        )
        return formatted

    return StructuredTool(
        name="web_search",
        description=description,
        coroutine=_web_search_impl,
        args_schema=WebSearchInput,
    )