mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-21 18:55:16 +02:00
Merge pull request #1294 from CREDO23/feature/mcp-migration
Some checks are pending
Build and Push Docker Images / tag_release (push) Waiting to run
Build and Push Docker Images / build (./surfsense_backend, ./surfsense_backend/Dockerfile, backend, surfsense-backend, ubuntu-24.04-arm, linux/arm64, arm64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_backend, ./surfsense_backend/Dockerfile, backend, surfsense-backend, ubuntu-latest, linux/amd64, amd64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_web, ./surfsense_web/Dockerfile, web, surfsense-web, ubuntu-24.04-arm, linux/arm64, arm64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_web, ./surfsense_web/Dockerfile, web, surfsense-web, ubuntu-latest, linux/amd64, amd64) (push) Blocked by required conditions
Build and Push Docker Images / create_manifest (backend, surfsense-backend) (push) Blocked by required conditions
Build and Push Docker Images / create_manifest (web, surfsense-web) (push) Blocked by required conditions
Some checks are pending
Build and Push Docker Images / tag_release (push) Waiting to run
Build and Push Docker Images / build (./surfsense_backend, ./surfsense_backend/Dockerfile, backend, surfsense-backend, ubuntu-24.04-arm, linux/arm64, arm64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_backend, ./surfsense_backend/Dockerfile, backend, surfsense-backend, ubuntu-latest, linux/amd64, amd64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_web, ./surfsense_web/Dockerfile, web, surfsense-web, ubuntu-24.04-arm, linux/arm64, arm64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_web, ./surfsense_web/Dockerfile, web, surfsense-web, ubuntu-latest, linux/amd64, amd64) (push) Blocked by required conditions
Build and Push Docker Images / create_manifest (backend, surfsense-backend) (push) Blocked by required conditions
Build and Push Docker Images / create_manifest (web, surfsense-web) (push) Blocked by required conditions
[FEAT] Live connector tools via MCP OAuth and native APIs
This commit is contained in:
commit
7245ab4046
65 changed files with 4175 additions and 1463 deletions
|
|
@ -0,0 +1,109 @@
|
|||
"""Connected-accounts discovery tool.
|
||||
|
||||
Lets the LLM discover which accounts are connected for a given service
|
||||
(e.g. "jira", "linear", "slack") and retrieve the metadata it needs to
|
||||
call action tools — such as Jira's ``cloudId``.
|
||||
|
||||
The tool returns **only** non-sensitive fields explicitly listed in the
|
||||
service's ``account_metadata_keys`` (see ``registry.py``), plus the
|
||||
always-present ``display_name`` and ``connector_id``.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import StructuredTool
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
from app.services.mcp_oauth.registry import MCP_SERVICES
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_SERVICE_KEY_BY_CONNECTOR_TYPE: dict[str, str] = {
|
||||
cfg.connector_type: key for key, cfg in MCP_SERVICES.items()
|
||||
}
|
||||
|
||||
|
||||
class GetConnectedAccountsInput(BaseModel):
|
||||
service: str = Field(
|
||||
description=(
|
||||
"Service key to look up connected accounts for. "
|
||||
"Valid values: " + ", ".join(sorted(MCP_SERVICES.keys()))
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _extract_display_name(connector: SearchSourceConnector) -> str:
|
||||
"""Best-effort human-readable label for a connector."""
|
||||
cfg = connector.config or {}
|
||||
if cfg.get("display_name"):
|
||||
return cfg["display_name"]
|
||||
if cfg.get("base_url"):
|
||||
return f"{connector.name} ({cfg['base_url']})"
|
||||
if cfg.get("organization_name"):
|
||||
return f"{connector.name} ({cfg['organization_name']})"
|
||||
return connector.name
|
||||
|
||||
|
||||
def create_get_connected_accounts_tool(
|
||||
db_session: AsyncSession,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
) -> StructuredTool:
|
||||
|
||||
async def _run(service: str) -> list[dict[str, Any]]:
|
||||
svc_cfg = MCP_SERVICES.get(service)
|
||||
if not svc_cfg:
|
||||
return [{"error": f"Unknown service '{service}'. Valid: {', '.join(sorted(MCP_SERVICES.keys()))}"}]
|
||||
|
||||
try:
|
||||
connector_type = SearchSourceConnectorType(svc_cfg.connector_type)
|
||||
except ValueError:
|
||||
return [{"error": f"Connector type '{svc_cfg.connector_type}' not found."}]
|
||||
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type == connector_type,
|
||||
)
|
||||
)
|
||||
connectors = result.scalars().all()
|
||||
|
||||
if not connectors:
|
||||
return [{"error": f"No {svc_cfg.name} accounts connected. Ask the user to connect one in settings."}]
|
||||
|
||||
is_multi = len(connectors) > 1
|
||||
|
||||
accounts: list[dict[str, Any]] = []
|
||||
for conn in connectors:
|
||||
cfg = conn.config or {}
|
||||
entry: dict[str, Any] = {
|
||||
"connector_id": conn.id,
|
||||
"display_name": _extract_display_name(conn),
|
||||
"service": service,
|
||||
}
|
||||
if is_multi:
|
||||
entry["tool_prefix"] = f"{service}_{conn.id}"
|
||||
for key in svc_cfg.account_metadata_keys:
|
||||
if key in cfg:
|
||||
entry[key] = cfg[key]
|
||||
accounts.append(entry)
|
||||
|
||||
return accounts
|
||||
|
||||
return StructuredTool(
|
||||
name="get_connected_accounts",
|
||||
description=(
|
||||
"Discover which accounts are connected for a service (e.g. jira, linear, slack, clickup, airtable). "
|
||||
"Returns display names and service-specific metadata the action tools need "
|
||||
"(e.g. Jira's cloudId). Call this BEFORE using a service's action tools when "
|
||||
"you need an account identifier or are unsure which account to use."
|
||||
),
|
||||
coroutine=_run,
|
||||
args_schema=GetConnectedAccountsInput,
|
||||
metadata={"hitl": False},
|
||||
)
|
||||
|
|
@ -0,0 +1,15 @@
|
|||
from app.agents.new_chat.tools.discord.list_channels import (
|
||||
create_list_discord_channels_tool,
|
||||
)
|
||||
from app.agents.new_chat.tools.discord.read_messages import (
|
||||
create_read_discord_messages_tool,
|
||||
)
|
||||
from app.agents.new_chat.tools.discord.send_message import (
|
||||
create_send_discord_message_tool,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"create_list_discord_channels_tool",
|
||||
"create_read_discord_messages_tool",
|
||||
"create_send_discord_message_tool",
|
||||
]
|
||||
42
surfsense_backend/app/agents/new_chat/tools/discord/_auth.py
Normal file
42
surfsense_backend/app/agents/new_chat/tools/discord/_auth.py
Normal file
|
|
@ -0,0 +1,42 @@
|
|||
"""Shared auth helper for Discord agent tools (REST API, not gateway bot)."""
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.config import config
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
from app.utils.oauth_security import TokenEncryption
|
||||
|
||||
DISCORD_API = "https://discord.com/api/v10"
|
||||
|
||||
|
||||
async def get_discord_connector(
|
||||
db_session: AsyncSession,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
) -> SearchSourceConnector | None:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type == SearchSourceConnectorType.DISCORD_CONNECTOR,
|
||||
)
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
|
||||
def get_bot_token(connector: SearchSourceConnector) -> str:
|
||||
"""Extract and decrypt the bot token from connector config."""
|
||||
cfg = dict(connector.config)
|
||||
if cfg.get("_token_encrypted") and config.SECRET_KEY:
|
||||
enc = TokenEncryption(config.SECRET_KEY)
|
||||
if cfg.get("bot_token"):
|
||||
cfg["bot_token"] = enc.decrypt_token(cfg["bot_token"])
|
||||
token = cfg.get("bot_token")
|
||||
if not token:
|
||||
raise ValueError("Discord bot token not found in connector config.")
|
||||
return token
|
||||
|
||||
|
||||
def get_guild_id(connector: SearchSourceConnector) -> str | None:
|
||||
return connector.config.get("guild_id")
|
||||
|
|
@ -0,0 +1,67 @@
|
|||
import logging
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from langchain_core.tools import tool
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from ._auth import DISCORD_API, get_bot_token, get_discord_connector, get_guild_id
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_list_discord_channels_tool(
|
||||
db_session: AsyncSession | None = None,
|
||||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
@tool
|
||||
async def list_discord_channels() -> dict[str, Any]:
|
||||
"""List text channels in the connected Discord server.
|
||||
|
||||
Returns:
|
||||
Dictionary with status and a list of channels (id, name).
|
||||
"""
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
return {"status": "error", "message": "Discord tool not properly configured."}
|
||||
|
||||
try:
|
||||
connector = await get_discord_connector(db_session, search_space_id, user_id)
|
||||
if not connector:
|
||||
return {"status": "error", "message": "No Discord connector found."}
|
||||
|
||||
guild_id = get_guild_id(connector)
|
||||
if not guild_id:
|
||||
return {"status": "error", "message": "No guild ID in Discord connector config."}
|
||||
|
||||
token = get_bot_token(connector)
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.get(
|
||||
f"{DISCORD_API}/guilds/{guild_id}/channels",
|
||||
headers={"Authorization": f"Bot {token}"},
|
||||
timeout=15.0,
|
||||
)
|
||||
|
||||
if resp.status_code == 401:
|
||||
return {"status": "auth_error", "message": "Discord bot token is invalid.", "connector_type": "discord"}
|
||||
if resp.status_code != 200:
|
||||
return {"status": "error", "message": f"Discord API error: {resp.status_code}"}
|
||||
|
||||
# Type 0 = text channel
|
||||
channels = [
|
||||
{"id": ch["id"], "name": ch["name"]}
|
||||
for ch in resp.json()
|
||||
if ch.get("type") == 0
|
||||
]
|
||||
return {"status": "success", "guild_id": guild_id, "channels": channels, "total": len(channels)}
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
||||
if isinstance(e, GraphInterrupt):
|
||||
raise
|
||||
logger.error("Error listing Discord channels: %s", e, exc_info=True)
|
||||
return {"status": "error", "message": "Failed to list Discord channels."}
|
||||
|
||||
return list_discord_channels
|
||||
|
|
@ -0,0 +1,80 @@
|
|||
import logging
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from langchain_core.tools import tool
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from ._auth import DISCORD_API, get_bot_token, get_discord_connector
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_read_discord_messages_tool(
|
||||
db_session: AsyncSession | None = None,
|
||||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
@tool
|
||||
async def read_discord_messages(
|
||||
channel_id: str,
|
||||
limit: int = 25,
|
||||
) -> dict[str, Any]:
|
||||
"""Read recent messages from a Discord text channel.
|
||||
|
||||
Args:
|
||||
channel_id: The Discord channel ID (from list_discord_channels).
|
||||
limit: Number of messages to fetch (default 25, max 50).
|
||||
|
||||
Returns:
|
||||
Dictionary with status and a list of messages including
|
||||
id, author, content, timestamp.
|
||||
"""
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
return {"status": "error", "message": "Discord tool not properly configured."}
|
||||
|
||||
limit = min(limit, 50)
|
||||
|
||||
try:
|
||||
connector = await get_discord_connector(db_session, search_space_id, user_id)
|
||||
if not connector:
|
||||
return {"status": "error", "message": "No Discord connector found."}
|
||||
|
||||
token = get_bot_token(connector)
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.get(
|
||||
f"{DISCORD_API}/channels/{channel_id}/messages",
|
||||
headers={"Authorization": f"Bot {token}"},
|
||||
params={"limit": limit},
|
||||
timeout=15.0,
|
||||
)
|
||||
|
||||
if resp.status_code == 401:
|
||||
return {"status": "auth_error", "message": "Discord bot token is invalid.", "connector_type": "discord"}
|
||||
if resp.status_code == 403:
|
||||
return {"status": "error", "message": "Bot lacks permission to read this channel."}
|
||||
if resp.status_code != 200:
|
||||
return {"status": "error", "message": f"Discord API error: {resp.status_code}"}
|
||||
|
||||
messages = [
|
||||
{
|
||||
"id": m["id"],
|
||||
"author": m.get("author", {}).get("username", "Unknown"),
|
||||
"content": m.get("content", ""),
|
||||
"timestamp": m.get("timestamp", ""),
|
||||
}
|
||||
for m in resp.json()
|
||||
]
|
||||
|
||||
return {"status": "success", "channel_id": channel_id, "messages": messages, "total": len(messages)}
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
||||
if isinstance(e, GraphInterrupt):
|
||||
raise
|
||||
logger.error("Error reading Discord messages: %s", e, exc_info=True)
|
||||
return {"status": "error", "message": "Failed to read Discord messages."}
|
||||
|
||||
return read_discord_messages
|
||||
|
|
@ -0,0 +1,96 @@
|
|||
import logging
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from langchain_core.tools import tool
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.new_chat.tools.hitl import request_approval
|
||||
|
||||
from ._auth import DISCORD_API, get_bot_token, get_discord_connector
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_send_discord_message_tool(
|
||||
db_session: AsyncSession | None = None,
|
||||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
@tool
|
||||
async def send_discord_message(
|
||||
channel_id: str,
|
||||
content: str,
|
||||
) -> dict[str, Any]:
|
||||
"""Send a message to a Discord text channel.
|
||||
|
||||
Args:
|
||||
channel_id: The Discord channel ID (from list_discord_channels).
|
||||
content: The message text (max 2000 characters).
|
||||
|
||||
Returns:
|
||||
Dictionary with status, message_id on success.
|
||||
|
||||
IMPORTANT:
|
||||
- If status is "rejected", the user explicitly declined. Do NOT retry.
|
||||
"""
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
return {"status": "error", "message": "Discord tool not properly configured."}
|
||||
|
||||
if len(content) > 2000:
|
||||
return {"status": "error", "message": "Message exceeds Discord's 2000-character limit."}
|
||||
|
||||
try:
|
||||
connector = await get_discord_connector(db_session, search_space_id, user_id)
|
||||
if not connector:
|
||||
return {"status": "error", "message": "No Discord connector found."}
|
||||
|
||||
result = request_approval(
|
||||
action_type="discord_send_message",
|
||||
tool_name="send_discord_message",
|
||||
params={"channel_id": channel_id, "content": content},
|
||||
context={"connector_id": connector.id},
|
||||
)
|
||||
|
||||
if result.rejected:
|
||||
return {"status": "rejected", "message": "User declined. Message was not sent."}
|
||||
|
||||
final_content = result.params.get("content", content)
|
||||
final_channel = result.params.get("channel_id", channel_id)
|
||||
|
||||
token = get_bot_token(connector)
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.post(
|
||||
f"{DISCORD_API}/channels/{final_channel}/messages",
|
||||
headers={
|
||||
"Authorization": f"Bot {token}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
json={"content": final_content},
|
||||
timeout=15.0,
|
||||
)
|
||||
|
||||
if resp.status_code == 401:
|
||||
return {"status": "auth_error", "message": "Discord bot token is invalid.", "connector_type": "discord"}
|
||||
if resp.status_code == 403:
|
||||
return {"status": "error", "message": "Bot lacks permission to send messages in this channel."}
|
||||
if resp.status_code not in (200, 201):
|
||||
return {"status": "error", "message": f"Discord API error: {resp.status_code}"}
|
||||
|
||||
msg_data = resp.json()
|
||||
return {
|
||||
"status": "success",
|
||||
"message_id": msg_data.get("id"),
|
||||
"message": f"Message sent to channel {final_channel}.",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
||||
if isinstance(e, GraphInterrupt):
|
||||
raise
|
||||
logger.error("Error sending Discord message: %s", e, exc_info=True)
|
||||
return {"status": "error", "message": "Failed to send Discord message."}
|
||||
|
||||
return send_discord_message
|
||||
|
|
@ -1,6 +1,12 @@
|
|||
from app.agents.new_chat.tools.gmail.create_draft import (
|
||||
create_create_gmail_draft_tool,
|
||||
)
|
||||
from app.agents.new_chat.tools.gmail.read_email import (
|
||||
create_read_gmail_email_tool,
|
||||
)
|
||||
from app.agents.new_chat.tools.gmail.search_emails import (
|
||||
create_search_gmail_tool,
|
||||
)
|
||||
from app.agents.new_chat.tools.gmail.send_email import (
|
||||
create_send_gmail_email_tool,
|
||||
)
|
||||
|
|
@ -13,6 +19,8 @@ from app.agents.new_chat.tools.gmail.update_draft import (
|
|||
|
||||
__all__ = [
|
||||
"create_create_gmail_draft_tool",
|
||||
"create_read_gmail_email_tool",
|
||||
"create_search_gmail_tool",
|
||||
"create_send_gmail_email_tool",
|
||||
"create_trash_gmail_email_tool",
|
||||
"create_update_gmail_draft_tool",
|
||||
|
|
|
|||
|
|
@ -0,0 +1,87 @@
|
|||
import logging
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_GMAIL_TYPES = [
|
||||
SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
|
||||
SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
|
||||
]
|
||||
|
||||
|
||||
def create_read_gmail_email_tool(
|
||||
db_session: AsyncSession | None = None,
|
||||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
@tool
|
||||
async def read_gmail_email(message_id: str) -> dict[str, Any]:
|
||||
"""Read the full content of a specific Gmail email by its message ID.
|
||||
|
||||
Use after search_gmail to get the complete body of an email.
|
||||
|
||||
Args:
|
||||
message_id: The Gmail message ID (from search_gmail results).
|
||||
|
||||
Returns:
|
||||
Dictionary with status and the full email content formatted as markdown.
|
||||
"""
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
return {"status": "error", "message": "Gmail tool not properly configured."}
|
||||
|
||||
try:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type.in_(_GMAIL_TYPES),
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
|
||||
}
|
||||
|
||||
from app.agents.new_chat.tools.gmail.search_emails import _build_credentials
|
||||
|
||||
creds = _build_credentials(connector)
|
||||
|
||||
from app.connectors.google_gmail_connector import GoogleGmailConnector
|
||||
|
||||
gmail = GoogleGmailConnector(
|
||||
credentials=creds,
|
||||
session=db_session,
|
||||
user_id=user_id,
|
||||
connector_id=connector.id,
|
||||
)
|
||||
|
||||
detail, error = await gmail.get_message_details(message_id)
|
||||
if error:
|
||||
if "re-authenticate" in error.lower() or "authentication failed" in error.lower():
|
||||
return {"status": "auth_error", "message": error, "connector_type": "gmail"}
|
||||
return {"status": "error", "message": error}
|
||||
|
||||
if not detail:
|
||||
return {"status": "not_found", "message": f"Email with ID '{message_id}' not found."}
|
||||
|
||||
content = gmail.format_message_to_markdown(detail)
|
||||
|
||||
return {"status": "success", "message_id": message_id, "content": content}
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
||||
if isinstance(e, GraphInterrupt):
|
||||
raise
|
||||
logger.error("Error reading Gmail email: %s", e, exc_info=True)
|
||||
return {"status": "error", "message": "Failed to read email. Please try again."}
|
||||
|
||||
return read_gmail_email
|
||||
|
|
@ -0,0 +1,165 @@
|
|||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_GMAIL_TYPES = [
|
||||
SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
|
||||
SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
|
||||
]
|
||||
|
||||
_token_encryption_cache: object | None = None
|
||||
|
||||
|
||||
def _get_token_encryption():
|
||||
global _token_encryption_cache
|
||||
if _token_encryption_cache is None:
|
||||
from app.config import config
|
||||
from app.utils.oauth_security import TokenEncryption
|
||||
|
||||
if not config.SECRET_KEY:
|
||||
raise RuntimeError("SECRET_KEY not configured for token decryption.")
|
||||
_token_encryption_cache = TokenEncryption(config.SECRET_KEY)
|
||||
return _token_encryption_cache
|
||||
|
||||
|
||||
def _build_credentials(connector: SearchSourceConnector):
|
||||
"""Build Google OAuth Credentials from a connector's stored config.
|
||||
|
||||
Handles both native OAuth connectors (with encrypted tokens) and
|
||||
Composio-backed connectors. Shared by Gmail and Calendar tools.
|
||||
"""
|
||||
from app.utils.google_credentials import COMPOSIO_GOOGLE_CONNECTOR_TYPES
|
||||
|
||||
if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES:
|
||||
from app.utils.google_credentials import build_composio_credentials
|
||||
|
||||
cca_id = connector.config.get("composio_connected_account_id")
|
||||
if not cca_id:
|
||||
raise ValueError("Composio connected account ID not found.")
|
||||
return build_composio_credentials(cca_id)
|
||||
|
||||
from google.oauth2.credentials import Credentials
|
||||
|
||||
cfg = dict(connector.config)
|
||||
if cfg.get("_token_encrypted"):
|
||||
enc = _get_token_encryption()
|
||||
for key in ("token", "refresh_token", "client_secret"):
|
||||
if cfg.get(key):
|
||||
cfg[key] = enc.decrypt_token(cfg[key])
|
||||
|
||||
exp = (cfg.get("expiry") or "").replace("Z", "")
|
||||
return Credentials(
|
||||
token=cfg.get("token"),
|
||||
refresh_token=cfg.get("refresh_token"),
|
||||
token_uri=cfg.get("token_uri"),
|
||||
client_id=cfg.get("client_id"),
|
||||
client_secret=cfg.get("client_secret"),
|
||||
scopes=cfg.get("scopes", []),
|
||||
expiry=datetime.fromisoformat(exp) if exp else None,
|
||||
)
|
||||
|
||||
|
||||
def create_search_gmail_tool(
|
||||
db_session: AsyncSession | None = None,
|
||||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
@tool
|
||||
async def search_gmail(
|
||||
query: str,
|
||||
max_results: int = 10,
|
||||
) -> dict[str, Any]:
|
||||
"""Search emails in the user's Gmail inbox using Gmail search syntax.
|
||||
|
||||
Args:
|
||||
query: Gmail search query, same syntax as the Gmail search bar.
|
||||
Examples: "from:alice@example.com", "subject:meeting",
|
||||
"is:unread", "after:2024/01/01 before:2024/02/01",
|
||||
"has:attachment", "in:sent".
|
||||
max_results: Number of emails to return (default 10, max 20).
|
||||
|
||||
Returns:
|
||||
Dictionary with status and a list of email summaries including
|
||||
message_id, subject, from, date, snippet.
|
||||
"""
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
return {"status": "error", "message": "Gmail tool not properly configured."}
|
||||
|
||||
max_results = min(max_results, 20)
|
||||
|
||||
try:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type.in_(_GMAIL_TYPES),
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
|
||||
}
|
||||
|
||||
creds = _build_credentials(connector)
|
||||
|
||||
from app.connectors.google_gmail_connector import GoogleGmailConnector
|
||||
|
||||
gmail = GoogleGmailConnector(
|
||||
credentials=creds,
|
||||
session=db_session,
|
||||
user_id=user_id,
|
||||
connector_id=connector.id,
|
||||
)
|
||||
|
||||
messages_list, error = await gmail.get_messages_list(
|
||||
max_results=max_results, query=query
|
||||
)
|
||||
if error:
|
||||
if "re-authenticate" in error.lower() or "authentication failed" in error.lower():
|
||||
return {"status": "auth_error", "message": error, "connector_type": "gmail"}
|
||||
return {"status": "error", "message": error}
|
||||
|
||||
if not messages_list:
|
||||
return {"status": "success", "emails": [], "total": 0, "message": "No emails found."}
|
||||
|
||||
emails = []
|
||||
for msg in messages_list:
|
||||
detail, err = await gmail.get_message_details(msg["id"])
|
||||
if err:
|
||||
continue
|
||||
headers = {
|
||||
h["name"].lower(): h["value"]
|
||||
for h in detail.get("payload", {}).get("headers", [])
|
||||
}
|
||||
emails.append({
|
||||
"message_id": detail.get("id"),
|
||||
"thread_id": detail.get("threadId"),
|
||||
"subject": headers.get("subject", "No Subject"),
|
||||
"from": headers.get("from", "Unknown"),
|
||||
"to": headers.get("to", ""),
|
||||
"date": headers.get("date", ""),
|
||||
"snippet": detail.get("snippet", ""),
|
||||
"labels": detail.get("labelIds", []),
|
||||
})
|
||||
|
||||
return {"status": "success", "emails": emails, "total": len(emails)}
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
||||
if isinstance(e, GraphInterrupt):
|
||||
raise
|
||||
logger.error("Error searching Gmail: %s", e, exc_info=True)
|
||||
return {"status": "error", "message": "Failed to search Gmail. Please try again."}
|
||||
|
||||
return search_gmail
|
||||
|
|
@ -4,6 +4,9 @@ from app.agents.new_chat.tools.google_calendar.create_event import (
|
|||
from app.agents.new_chat.tools.google_calendar.delete_event import (
|
||||
create_delete_calendar_event_tool,
|
||||
)
|
||||
from app.agents.new_chat.tools.google_calendar.search_events import (
|
||||
create_search_calendar_events_tool,
|
||||
)
|
||||
from app.agents.new_chat.tools.google_calendar.update_event import (
|
||||
create_update_calendar_event_tool,
|
||||
)
|
||||
|
|
@ -11,5 +14,6 @@ from app.agents.new_chat.tools.google_calendar.update_event import (
|
|||
__all__ = [
|
||||
"create_create_calendar_event_tool",
|
||||
"create_delete_calendar_event_tool",
|
||||
"create_search_calendar_events_tool",
|
||||
"create_update_calendar_event_tool",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,114 @@
|
|||
import logging
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.agents.new_chat.tools.gmail.search_emails import _build_credentials
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_CALENDAR_TYPES = [
|
||||
SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR,
|
||||
SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
|
||||
]
|
||||
|
||||
|
||||
def create_search_calendar_events_tool(
|
||||
db_session: AsyncSession | None = None,
|
||||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
@tool
|
||||
async def search_calendar_events(
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
max_results: int = 25,
|
||||
) -> dict[str, Any]:
|
||||
"""Search Google Calendar events within a date range.
|
||||
|
||||
Args:
|
||||
start_date: Start date in YYYY-MM-DD format (e.g. "2026-04-01").
|
||||
end_date: End date in YYYY-MM-DD format (e.g. "2026-04-30").
|
||||
max_results: Maximum number of events to return (default 25, max 50).
|
||||
|
||||
Returns:
|
||||
Dictionary with status and a list of events including
|
||||
event_id, summary, start, end, location, attendees.
|
||||
"""
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
return {"status": "error", "message": "Calendar tool not properly configured."}
|
||||
|
||||
max_results = min(max_results, 50)
|
||||
|
||||
try:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type.in_(_CALENDAR_TYPES),
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No Google Calendar connector found. Please connect Google Calendar in your workspace settings.",
|
||||
}
|
||||
|
||||
creds = _build_credentials(connector)
|
||||
|
||||
from app.connectors.google_calendar_connector import GoogleCalendarConnector
|
||||
|
||||
cal = GoogleCalendarConnector(
|
||||
credentials=creds,
|
||||
session=db_session,
|
||||
user_id=user_id,
|
||||
connector_id=connector.id,
|
||||
)
|
||||
|
||||
events_raw, error = await cal.get_all_primary_calendar_events(
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
max_results=max_results,
|
||||
)
|
||||
|
||||
if error:
|
||||
if "re-authenticate" in error.lower() or "authentication failed" in error.lower():
|
||||
return {"status": "auth_error", "message": error, "connector_type": "google_calendar"}
|
||||
if "no events found" in error.lower():
|
||||
return {"status": "success", "events": [], "total": 0, "message": error}
|
||||
return {"status": "error", "message": error}
|
||||
|
||||
events = []
|
||||
for ev in events_raw:
|
||||
start = ev.get("start", {})
|
||||
end = ev.get("end", {})
|
||||
attendees_raw = ev.get("attendees", [])
|
||||
events.append({
|
||||
"event_id": ev.get("id"),
|
||||
"summary": ev.get("summary", "No Title"),
|
||||
"start": start.get("dateTime") or start.get("date", ""),
|
||||
"end": end.get("dateTime") or end.get("date", ""),
|
||||
"location": ev.get("location", ""),
|
||||
"description": ev.get("description", ""),
|
||||
"html_link": ev.get("htmlLink", ""),
|
||||
"attendees": [
|
||||
a.get("email", "") for a in attendees_raw[:10]
|
||||
],
|
||||
"status": ev.get("status", ""),
|
||||
})
|
||||
|
||||
return {"status": "success", "events": events, "total": len(events)}
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
||||
if isinstance(e, GraphInterrupt):
|
||||
raise
|
||||
logger.error("Error searching calendar events: %s", e, exc_info=True)
|
||||
return {"status": "error", "message": "Failed to search calendar events. Please try again."}
|
||||
|
||||
return search_calendar_events
|
||||
|
|
@ -130,8 +130,8 @@ def request_approval(
|
|||
try:
|
||||
decision_type, edited_params = _parse_decision(approval)
|
||||
except ValueError:
|
||||
logger.warning("No approval decision received for %s", tool_name)
|
||||
return HITLResult(rejected=False, decision_type="error", params=params)
|
||||
logger.warning("No approval decision received for %s — rejecting for safety", tool_name)
|
||||
return HITLResult(rejected=True, decision_type="error", params=params)
|
||||
|
||||
logger.info("User decision for %s: %s", tool_name, decision_type)
|
||||
|
||||
|
|
|
|||
15
surfsense_backend/app/agents/new_chat/tools/luma/__init__.py
Normal file
15
surfsense_backend/app/agents/new_chat/tools/luma/__init__.py
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
from app.agents.new_chat.tools.luma.create_event import (
|
||||
create_create_luma_event_tool,
|
||||
)
|
||||
from app.agents.new_chat.tools.luma.list_events import (
|
||||
create_list_luma_events_tool,
|
||||
)
|
||||
from app.agents.new_chat.tools.luma.read_event import (
|
||||
create_read_luma_event_tool,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"create_create_luma_event_tool",
|
||||
"create_list_luma_events_tool",
|
||||
"create_read_luma_event_tool",
|
||||
]
|
||||
38
surfsense_backend/app/agents/new_chat/tools/luma/_auth.py
Normal file
38
surfsense_backend/app/agents/new_chat/tools/luma/_auth.py
Normal file
|
|
@ -0,0 +1,38 @@
|
|||
"""Shared auth helper for Luma agent tools."""
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
LUMA_API = "https://public-api.luma.com/v1"
|
||||
|
||||
|
||||
async def get_luma_connector(
|
||||
db_session: AsyncSession,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
) -> SearchSourceConnector | None:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type == SearchSourceConnectorType.LUMA_CONNECTOR,
|
||||
)
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
|
||||
def get_api_key(connector: SearchSourceConnector) -> str:
|
||||
"""Extract the API key from connector config (handles both key names)."""
|
||||
key = connector.config.get("api_key") or connector.config.get("LUMA_API_KEY")
|
||||
if not key:
|
||||
raise ValueError("Luma API key not found in connector config.")
|
||||
return key
|
||||
|
||||
|
||||
def luma_headers(api_key: str) -> dict[str, str]:
|
||||
return {
|
||||
"Content-Type": "application/json",
|
||||
"x-luma-api-key": api_key,
|
||||
}
|
||||
116
surfsense_backend/app/agents/new_chat/tools/luma/create_event.py
Normal file
116
surfsense_backend/app/agents/new_chat/tools/luma/create_event.py
Normal file
|
|
@ -0,0 +1,116 @@
|
|||
import logging
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from langchain_core.tools import tool
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.new_chat.tools.hitl import request_approval
|
||||
|
||||
from ._auth import LUMA_API, get_api_key, get_luma_connector, luma_headers
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_create_luma_event_tool(
|
||||
db_session: AsyncSession | None = None,
|
||||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
@tool
|
||||
async def create_luma_event(
|
||||
name: str,
|
||||
start_at: str,
|
||||
end_at: str,
|
||||
description: str | None = None,
|
||||
timezone: str = "UTC",
|
||||
) -> dict[str, Any]:
|
||||
"""Create a new event on Luma.
|
||||
|
||||
Args:
|
||||
name: The event title.
|
||||
start_at: Start time in ISO 8601 format (e.g. "2026-05-01T18:00:00").
|
||||
end_at: End time in ISO 8601 format (e.g. "2026-05-01T20:00:00").
|
||||
description: Optional event description (markdown supported).
|
||||
timezone: Timezone string (default "UTC", e.g. "America/New_York").
|
||||
|
||||
Returns:
|
||||
Dictionary with status, event_id on success.
|
||||
|
||||
IMPORTANT:
|
||||
- If status is "rejected", the user explicitly declined. Do NOT retry.
|
||||
"""
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
return {"status": "error", "message": "Luma tool not properly configured."}
|
||||
|
||||
try:
|
||||
connector = await get_luma_connector(db_session, search_space_id, user_id)
|
||||
if not connector:
|
||||
return {"status": "error", "message": "No Luma connector found."}
|
||||
|
||||
result = request_approval(
|
||||
action_type="luma_create_event",
|
||||
tool_name="create_luma_event",
|
||||
params={
|
||||
"name": name,
|
||||
"start_at": start_at,
|
||||
"end_at": end_at,
|
||||
"description": description,
|
||||
"timezone": timezone,
|
||||
},
|
||||
context={"connector_id": connector.id},
|
||||
)
|
||||
|
||||
if result.rejected:
|
||||
return {"status": "rejected", "message": "User declined. Event was not created."}
|
||||
|
||||
final_name = result.params.get("name", name)
|
||||
final_start = result.params.get("start_at", start_at)
|
||||
final_end = result.params.get("end_at", end_at)
|
||||
final_desc = result.params.get("description", description)
|
||||
final_tz = result.params.get("timezone", timezone)
|
||||
|
||||
api_key = get_api_key(connector)
|
||||
headers = luma_headers(api_key)
|
||||
|
||||
body: dict[str, Any] = {
|
||||
"name": final_name,
|
||||
"start_at": final_start,
|
||||
"end_at": final_end,
|
||||
"timezone": final_tz,
|
||||
}
|
||||
if final_desc:
|
||||
body["description_md"] = final_desc
|
||||
|
||||
async with httpx.AsyncClient(timeout=20.0) as client:
|
||||
resp = await client.post(
|
||||
f"{LUMA_API}/event/create",
|
||||
headers=headers,
|
||||
json=body,
|
||||
)
|
||||
|
||||
if resp.status_code == 401:
|
||||
return {"status": "auth_error", "message": "Luma API key is invalid.", "connector_type": "luma"}
|
||||
if resp.status_code == 403:
|
||||
return {"status": "error", "message": "Luma Plus subscription required to create events via API."}
|
||||
if resp.status_code not in (200, 201):
|
||||
return {"status": "error", "message": f"Luma API error: {resp.status_code} — {resp.text[:200]}"}
|
||||
|
||||
data = resp.json()
|
||||
event_id = data.get("api_id") or data.get("event", {}).get("api_id")
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"event_id": event_id,
|
||||
"message": f"Event '{final_name}' created on Luma.",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
||||
if isinstance(e, GraphInterrupt):
|
||||
raise
|
||||
logger.error("Error creating Luma event: %s", e, exc_info=True)
|
||||
return {"status": "error", "message": "Failed to create Luma event."}
|
||||
|
||||
return create_luma_event
|
||||
100
surfsense_backend/app/agents/new_chat/tools/luma/list_events.py
Normal file
100
surfsense_backend/app/agents/new_chat/tools/luma/list_events.py
Normal file
|
|
@ -0,0 +1,100 @@
|
|||
import logging
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from langchain_core.tools import tool
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from ._auth import LUMA_API, get_api_key, get_luma_connector, luma_headers
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_list_luma_events_tool(
|
||||
db_session: AsyncSession | None = None,
|
||||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
@tool
|
||||
async def list_luma_events(
|
||||
max_results: int = 25,
|
||||
) -> dict[str, Any]:
|
||||
"""List upcoming and recent Luma events.
|
||||
|
||||
Args:
|
||||
max_results: Maximum events to return (default 25, max 50).
|
||||
|
||||
Returns:
|
||||
Dictionary with status and a list of events including
|
||||
event_id, name, start_at, end_at, location, url.
|
||||
"""
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
return {"status": "error", "message": "Luma tool not properly configured."}
|
||||
|
||||
max_results = min(max_results, 50)
|
||||
|
||||
try:
|
||||
connector = await get_luma_connector(db_session, search_space_id, user_id)
|
||||
if not connector:
|
||||
return {"status": "error", "message": "No Luma connector found."}
|
||||
|
||||
api_key = get_api_key(connector)
|
||||
headers = luma_headers(api_key)
|
||||
|
||||
all_entries: list[dict] = []
|
||||
cursor = None
|
||||
|
||||
async with httpx.AsyncClient(timeout=20.0) as client:
|
||||
while len(all_entries) < max_results:
|
||||
params: dict[str, Any] = {"limit": min(100, max_results - len(all_entries))}
|
||||
if cursor:
|
||||
params["cursor"] = cursor
|
||||
|
||||
resp = await client.get(
|
||||
f"{LUMA_API}/calendar/list-events",
|
||||
headers=headers,
|
||||
params=params,
|
||||
)
|
||||
|
||||
if resp.status_code == 401:
|
||||
return {"status": "auth_error", "message": "Luma API key is invalid.", "connector_type": "luma"}
|
||||
if resp.status_code != 200:
|
||||
return {"status": "error", "message": f"Luma API error: {resp.status_code}"}
|
||||
|
||||
data = resp.json()
|
||||
entries = data.get("entries", [])
|
||||
if not entries:
|
||||
break
|
||||
all_entries.extend(entries)
|
||||
|
||||
next_cursor = data.get("next_cursor")
|
||||
if not next_cursor:
|
||||
break
|
||||
cursor = next_cursor
|
||||
|
||||
events = []
|
||||
for entry in all_entries[:max_results]:
|
||||
ev = entry.get("event", {})
|
||||
geo = ev.get("geo_info", {})
|
||||
events.append({
|
||||
"event_id": entry.get("api_id"),
|
||||
"name": ev.get("name", "Untitled"),
|
||||
"start_at": ev.get("start_at", ""),
|
||||
"end_at": ev.get("end_at", ""),
|
||||
"timezone": ev.get("timezone", ""),
|
||||
"location": geo.get("name", ""),
|
||||
"url": ev.get("url", ""),
|
||||
"visibility": ev.get("visibility", ""),
|
||||
})
|
||||
|
||||
return {"status": "success", "events": events, "total": len(events)}
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
||||
if isinstance(e, GraphInterrupt):
|
||||
raise
|
||||
logger.error("Error listing Luma events: %s", e, exc_info=True)
|
||||
return {"status": "error", "message": "Failed to list Luma events."}
|
||||
|
||||
return list_luma_events
|
||||
|
|
@ -0,0 +1,82 @@
|
|||
import logging
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from langchain_core.tools import tool
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from ._auth import LUMA_API, get_api_key, get_luma_connector, luma_headers
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_read_luma_event_tool(
|
||||
db_session: AsyncSession | None = None,
|
||||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
@tool
|
||||
async def read_luma_event(event_id: str) -> dict[str, Any]:
|
||||
"""Read detailed information about a specific Luma event.
|
||||
|
||||
Args:
|
||||
event_id: The Luma event API ID (from list_luma_events).
|
||||
|
||||
Returns:
|
||||
Dictionary with status and full event details including
|
||||
description, attendees count, meeting URL.
|
||||
"""
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
return {"status": "error", "message": "Luma tool not properly configured."}
|
||||
|
||||
try:
|
||||
connector = await get_luma_connector(db_session, search_space_id, user_id)
|
||||
if not connector:
|
||||
return {"status": "error", "message": "No Luma connector found."}
|
||||
|
||||
api_key = get_api_key(connector)
|
||||
headers = luma_headers(api_key)
|
||||
|
||||
async with httpx.AsyncClient(timeout=15.0) as client:
|
||||
resp = await client.get(
|
||||
f"{LUMA_API}/events/{event_id}",
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
if resp.status_code == 401:
|
||||
return {"status": "auth_error", "message": "Luma API key is invalid.", "connector_type": "luma"}
|
||||
if resp.status_code == 404:
|
||||
return {"status": "not_found", "message": f"Event '{event_id}' not found."}
|
||||
if resp.status_code != 200:
|
||||
return {"status": "error", "message": f"Luma API error: {resp.status_code}"}
|
||||
|
||||
data = resp.json()
|
||||
ev = data.get("event", data)
|
||||
geo = ev.get("geo_info", {})
|
||||
|
||||
event_detail = {
|
||||
"event_id": event_id,
|
||||
"name": ev.get("name", ""),
|
||||
"description": ev.get("description", ""),
|
||||
"start_at": ev.get("start_at", ""),
|
||||
"end_at": ev.get("end_at", ""),
|
||||
"timezone": ev.get("timezone", ""),
|
||||
"location_name": geo.get("name", ""),
|
||||
"address": geo.get("address", ""),
|
||||
"url": ev.get("url", ""),
|
||||
"meeting_url": ev.get("meeting_url", ""),
|
||||
"visibility": ev.get("visibility", ""),
|
||||
"cover_url": ev.get("cover_url", ""),
|
||||
}
|
||||
|
||||
return {"status": "success", "event": event_detail}
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
||||
if isinstance(e, GraphInterrupt):
|
||||
raise
|
||||
logger.error("Error reading Luma event: %s", e, exc_info=True)
|
||||
return {"status": "error", "message": "Failed to read Luma event."}
|
||||
|
||||
return read_luma_event
|
||||
|
|
@ -14,20 +14,28 @@ clicking "Always Allow", which adds the tool name to the connector's
|
|||
``config.trusted_tools`` allow-list.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.utils.oauth_security import TokenEncryption
|
||||
|
||||
from langchain_core.tools import StructuredTool
|
||||
from mcp import ClientSession
|
||||
from mcp.client.streamable_http import streamablehttp_client
|
||||
from pydantic import BaseModel, create_model
|
||||
from sqlalchemy import select
|
||||
from pydantic import BaseModel, Field, create_model
|
||||
from sqlalchemy import cast, select
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.new_chat.tools.hitl import request_approval
|
||||
from app.agents.new_chat.tools.mcp_client import MCPClient
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
from app.services.mcp_oauth.registry import MCP_SERVICES, get_service_by_connector_type
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -63,18 +71,14 @@ def _create_dynamic_input_model_from_schema(
|
|||
param_description = param_schema.get("description", "")
|
||||
is_required = param_name in required_fields
|
||||
|
||||
from typing import Any as AnyType
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
if is_required:
|
||||
field_definitions[param_name] = (
|
||||
AnyType,
|
||||
Any,
|
||||
Field(..., description=param_description),
|
||||
)
|
||||
else:
|
||||
field_definitions[param_name] = (
|
||||
AnyType | None,
|
||||
Any | None,
|
||||
Field(None, description=param_description),
|
||||
)
|
||||
|
||||
|
|
@ -100,13 +104,13 @@ async def _create_mcp_tool_from_definition_stdio(
|
|||
tool_description = tool_def.get("description", "No description provided")
|
||||
input_schema = tool_def.get("input_schema", {"type": "object", "properties": {}})
|
||||
|
||||
logger.info(f"MCP tool '{tool_name}' input schema: {input_schema}")
|
||||
logger.debug("MCP tool '%s' input schema: %s", tool_name, input_schema)
|
||||
|
||||
input_model = _create_dynamic_input_model_from_schema(tool_name, input_schema)
|
||||
|
||||
async def mcp_tool_call(**kwargs) -> str:
|
||||
"""Execute the MCP tool call via the client with retry support."""
|
||||
logger.info(f"MCP tool '{tool_name}' called with params: {kwargs}")
|
||||
logger.debug("MCP tool '%s' called", tool_name)
|
||||
|
||||
# HITL — OUTSIDE try/except so GraphInterrupt propagates to LangGraph
|
||||
hitl_result = request_approval(
|
||||
|
|
@ -123,20 +127,18 @@ async def _create_mcp_tool_from_definition_stdio(
|
|||
)
|
||||
if hitl_result.rejected:
|
||||
return "Tool call rejected by user."
|
||||
call_kwargs = hitl_result.params
|
||||
call_kwargs = {k: v for k, v in hitl_result.params.items() if v is not None}
|
||||
|
||||
try:
|
||||
async with mcp_client.connect():
|
||||
result = await mcp_client.call_tool(tool_name, call_kwargs)
|
||||
return str(result)
|
||||
except RuntimeError as e:
|
||||
error_msg = f"MCP tool '{tool_name}' connection failed after retries: {e!s}"
|
||||
logger.error(error_msg)
|
||||
return f"Error: {error_msg}"
|
||||
logger.error("MCP tool '%s' connection failed after retries: %s", tool_name, e)
|
||||
return f"Error: MCP tool '{tool_name}' connection failed after retries: {e!s}"
|
||||
except Exception as e:
|
||||
error_msg = f"MCP tool '{tool_name}' execution failed: {e!s}"
|
||||
logger.exception(error_msg)
|
||||
return f"Error: {error_msg}"
|
||||
logger.exception("MCP tool '%s' execution failed: %s", tool_name, e)
|
||||
return f"Error: MCP tool '{tool_name}' execution failed: {e!s}"
|
||||
|
||||
tool = StructuredTool(
|
||||
name=tool_name,
|
||||
|
|
@ -151,7 +153,7 @@ async def _create_mcp_tool_from_definition_stdio(
|
|||
},
|
||||
)
|
||||
|
||||
logger.info(f"Created MCP tool (stdio): '{tool_name}'")
|
||||
logger.debug("Created MCP tool (stdio): '%s'", tool_name)
|
||||
return tool
|
||||
|
||||
|
||||
|
|
@ -163,41 +165,57 @@ async def _create_mcp_tool_from_definition_http(
|
|||
connector_name: str = "",
|
||||
connector_id: int | None = None,
|
||||
trusted_tools: list[str] | None = None,
|
||||
readonly_tools: frozenset[str] | None = None,
|
||||
tool_name_prefix: str | None = None,
|
||||
) -> StructuredTool:
|
||||
"""Create a LangChain tool from an MCP tool definition (HTTP transport).
|
||||
|
||||
All MCP tools are unconditionally wrapped with HITL approval.
|
||||
``request_approval()`` is called OUTSIDE the try/except so that
|
||||
``GraphInterrupt`` propagates cleanly to LangGraph.
|
||||
Write tools are wrapped with HITL approval; read-only tools (listed in
|
||||
``readonly_tools``) execute immediately without user confirmation.
|
||||
|
||||
When ``tool_name_prefix`` is set (multi-account disambiguation), the
|
||||
tool exposed to the LLM gets a prefixed name (e.g. ``linear_25_list_issues``)
|
||||
but the actual MCP ``call_tool`` still uses the original name.
|
||||
"""
|
||||
tool_name = tool_def.get("name", "unnamed_tool")
|
||||
original_tool_name = tool_def.get("name", "unnamed_tool")
|
||||
tool_description = tool_def.get("description", "No description provided")
|
||||
input_schema = tool_def.get("input_schema", {"type": "object", "properties": {}})
|
||||
is_readonly = readonly_tools is not None and original_tool_name in readonly_tools
|
||||
|
||||
logger.info(f"MCP HTTP tool '{tool_name}' input schema: {input_schema}")
|
||||
exposed_name = (
|
||||
f"{tool_name_prefix}_{original_tool_name}"
|
||||
if tool_name_prefix
|
||||
else original_tool_name
|
||||
)
|
||||
if tool_name_prefix:
|
||||
tool_description = f"[Account: {connector_name}] {tool_description}"
|
||||
|
||||
input_model = _create_dynamic_input_model_from_schema(tool_name, input_schema)
|
||||
logger.debug("MCP HTTP tool '%s' input schema: %s", exposed_name, input_schema)
|
||||
|
||||
input_model = _create_dynamic_input_model_from_schema(exposed_name, input_schema)
|
||||
|
||||
async def mcp_http_tool_call(**kwargs) -> str:
|
||||
"""Execute the MCP tool call via HTTP transport."""
|
||||
logger.info(f"MCP HTTP tool '{tool_name}' called with params: {kwargs}")
|
||||
logger.debug("MCP HTTP tool '%s' called", exposed_name)
|
||||
|
||||
# HITL — OUTSIDE try/except so GraphInterrupt propagates to LangGraph
|
||||
hitl_result = request_approval(
|
||||
action_type="mcp_tool_call",
|
||||
tool_name=tool_name,
|
||||
params=kwargs,
|
||||
context={
|
||||
"mcp_server": connector_name,
|
||||
"tool_description": tool_description,
|
||||
"mcp_transport": "http",
|
||||
"mcp_connector_id": connector_id,
|
||||
},
|
||||
trusted_tools=trusted_tools,
|
||||
)
|
||||
if hitl_result.rejected:
|
||||
return "Tool call rejected by user."
|
||||
call_kwargs = hitl_result.params
|
||||
if is_readonly:
|
||||
call_kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
||||
else:
|
||||
hitl_result = request_approval(
|
||||
action_type="mcp_tool_call",
|
||||
tool_name=exposed_name,
|
||||
params=kwargs,
|
||||
context={
|
||||
"mcp_server": connector_name,
|
||||
"tool_description": tool_description,
|
||||
"mcp_transport": "http",
|
||||
"mcp_connector_id": connector_id,
|
||||
},
|
||||
trusted_tools=trusted_tools,
|
||||
)
|
||||
if hitl_result.rejected:
|
||||
return "Tool call rejected by user."
|
||||
call_kwargs = {k: v for k, v in hitl_result.params.items() if v is not None}
|
||||
|
||||
try:
|
||||
async with (
|
||||
|
|
@ -205,7 +223,9 @@ async def _create_mcp_tool_from_definition_http(
|
|||
ClientSession(read, write) as session,
|
||||
):
|
||||
await session.initialize()
|
||||
response = await session.call_tool(tool_name, arguments=call_kwargs)
|
||||
response = await session.call_tool(
|
||||
original_tool_name, arguments=call_kwargs,
|
||||
)
|
||||
|
||||
result = []
|
||||
for content in response.content:
|
||||
|
|
@ -217,18 +237,15 @@ async def _create_mcp_tool_from_definition_http(
|
|||
result.append(str(content))
|
||||
|
||||
result_str = "\n".join(result) if result else ""
|
||||
logger.info(
|
||||
f"MCP HTTP tool '{tool_name}' succeeded: {result_str[:200]}"
|
||||
)
|
||||
logger.debug("MCP HTTP tool '%s' succeeded (len=%d)", exposed_name, len(result_str))
|
||||
return result_str
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"MCP HTTP tool '{tool_name}' execution failed: {e!s}"
|
||||
logger.exception(error_msg)
|
||||
return f"Error: {error_msg}"
|
||||
logger.exception("MCP HTTP tool '%s' execution failed: %s", exposed_name, e)
|
||||
return f"Error: MCP HTTP tool '{exposed_name}' execution failed: {e!s}"
|
||||
|
||||
tool = StructuredTool(
|
||||
name=tool_name,
|
||||
name=exposed_name,
|
||||
description=tool_description,
|
||||
coroutine=mcp_http_tool_call,
|
||||
args_schema=input_model,
|
||||
|
|
@ -236,12 +253,14 @@ async def _create_mcp_tool_from_definition_http(
|
|||
"mcp_input_schema": input_schema,
|
||||
"mcp_transport": "http",
|
||||
"mcp_url": url,
|
||||
"hitl": True,
|
||||
"hitl": not is_readonly,
|
||||
"hitl_dedup_key": next(iter(input_schema.get("required", [])), None),
|
||||
"mcp_original_tool_name": original_tool_name,
|
||||
"mcp_connector_id": connector_id,
|
||||
},
|
||||
)
|
||||
|
||||
logger.info(f"Created MCP tool (HTTP): '{tool_name}'")
|
||||
logger.debug("Created MCP tool (HTTP): '%s'", exposed_name)
|
||||
return tool
|
||||
|
||||
|
||||
|
|
@ -257,21 +276,24 @@ async def _load_stdio_mcp_tools(
|
|||
command = server_config.get("command")
|
||||
if not command or not isinstance(command, str):
|
||||
logger.warning(
|
||||
f"MCP connector {connector_id} (name: '{connector_name}') missing or invalid command field, skipping"
|
||||
"MCP connector %d (name: '%s') missing or invalid command field, skipping",
|
||||
connector_id, connector_name,
|
||||
)
|
||||
return tools
|
||||
|
||||
args = server_config.get("args", [])
|
||||
if not isinstance(args, list):
|
||||
logger.warning(
|
||||
f"MCP connector {connector_id} (name: '{connector_name}') has invalid args field (must be list), skipping"
|
||||
"MCP connector %d (name: '%s') has invalid args field (must be list), skipping",
|
||||
connector_id, connector_name,
|
||||
)
|
||||
return tools
|
||||
|
||||
env = server_config.get("env", {})
|
||||
if not isinstance(env, dict):
|
||||
logger.warning(
|
||||
f"MCP connector {connector_id} (name: '{connector_name}') has invalid env field (must be dict), skipping"
|
||||
"MCP connector %d (name: '%s') has invalid env field (must be dict), skipping",
|
||||
connector_id, connector_name,
|
||||
)
|
||||
return tools
|
||||
|
||||
|
|
@ -281,8 +303,8 @@ async def _load_stdio_mcp_tools(
|
|||
tool_definitions = await mcp_client.list_tools()
|
||||
|
||||
logger.info(
|
||||
f"Discovered {len(tool_definitions)} tools from stdio MCP server "
|
||||
f"'{command}' (connector {connector_id})"
|
||||
"Discovered %d tools from stdio MCP server '%s' (connector %d)",
|
||||
len(tool_definitions), command, connector_id,
|
||||
)
|
||||
|
||||
for tool_def in tool_definitions:
|
||||
|
|
@ -297,8 +319,8 @@ async def _load_stdio_mcp_tools(
|
|||
tools.append(tool)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Failed to create tool '{tool_def.get('name')}' "
|
||||
f"from connector {connector_id}: {e!s}"
|
||||
"Failed to create tool '%s' from connector %d: %s",
|
||||
tool_def.get("name"), connector_id, e,
|
||||
)
|
||||
|
||||
return tools
|
||||
|
|
@ -309,24 +331,40 @@ async def _load_http_mcp_tools(
|
|||
connector_name: str,
|
||||
server_config: dict[str, Any],
|
||||
trusted_tools: list[str] | None = None,
|
||||
allowed_tools: list[str] | None = None,
|
||||
readonly_tools: frozenset[str] | None = None,
|
||||
tool_name_prefix: str | None = None,
|
||||
) -> list[StructuredTool]:
|
||||
"""Load tools from an HTTP-based MCP server."""
|
||||
"""Load tools from an HTTP-based MCP server.
|
||||
|
||||
Args:
|
||||
allowed_tools: If non-empty, only tools whose names appear in this
|
||||
list are loaded. Empty/None means load everything (used for
|
||||
user-managed generic MCP servers).
|
||||
readonly_tools: Tool names that skip HITL approval (read-only operations).
|
||||
tool_name_prefix: If set, each tool name is prefixed for multi-account
|
||||
disambiguation (e.g. ``linear_25``).
|
||||
"""
|
||||
tools: list[StructuredTool] = []
|
||||
|
||||
url = server_config.get("url")
|
||||
if not url or not isinstance(url, str):
|
||||
logger.warning(
|
||||
f"MCP connector {connector_id} (name: '{connector_name}') missing or invalid url field, skipping"
|
||||
"MCP connector %d (name: '%s') missing or invalid url field, skipping",
|
||||
connector_id, connector_name,
|
||||
)
|
||||
return tools
|
||||
|
||||
headers = server_config.get("headers", {})
|
||||
if not isinstance(headers, dict):
|
||||
logger.warning(
|
||||
f"MCP connector {connector_id} (name: '{connector_name}') has invalid headers field (must be dict), skipping"
|
||||
"MCP connector %d (name: '%s') has invalid headers field (must be dict), skipping",
|
||||
connector_id, connector_name,
|
||||
)
|
||||
return tools
|
||||
|
||||
allowed_set = set(allowed_tools) if allowed_tools else None
|
||||
|
||||
try:
|
||||
async with (
|
||||
streamablehttp_client(url, headers=headers) as (read, write, _),
|
||||
|
|
@ -347,10 +385,21 @@ async def _load_http_mcp_tools(
|
|||
}
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Discovered {len(tool_definitions)} tools from HTTP MCP server "
|
||||
f"'{url}' (connector {connector_id})"
|
||||
)
|
||||
total_discovered = len(tool_definitions)
|
||||
|
||||
if allowed_set:
|
||||
tool_definitions = [
|
||||
td for td in tool_definitions if td["name"] in allowed_set
|
||||
]
|
||||
logger.info(
|
||||
"HTTP MCP server '%s' (connector %d): %d/%d tools after allowlist filter",
|
||||
url, connector_id, len(tool_definitions), total_discovered,
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
"Discovered %d tools from HTTP MCP server '%s' (connector %d) — no allowlist, loading all",
|
||||
total_discovered, url, connector_id,
|
||||
)
|
||||
|
||||
for tool_def in tool_definitions:
|
||||
try:
|
||||
|
|
@ -361,22 +410,183 @@ async def _load_http_mcp_tools(
|
|||
connector_name=connector_name,
|
||||
connector_id=connector_id,
|
||||
trusted_tools=trusted_tools,
|
||||
readonly_tools=readonly_tools,
|
||||
tool_name_prefix=tool_name_prefix,
|
||||
)
|
||||
tools.append(tool)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Failed to create HTTP tool '{tool_def.get('name')}' "
|
||||
f"from connector {connector_id}: {e!s}"
|
||||
"Failed to create HTTP tool '%s' from connector %d: %s",
|
||||
tool_def.get("name"), connector_id, e,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Failed to connect to HTTP MCP server at '{url}' (connector {connector_id}): {e!s}"
|
||||
"Failed to connect to HTTP MCP server at '%s' (connector %d): %s",
|
||||
url, connector_id, e,
|
||||
)
|
||||
|
||||
return tools
|
||||
|
||||
|
||||
_TOKEN_REFRESH_BUFFER_SECONDS = 300 # refresh 5 min before expiry
|
||||
|
||||
_token_enc: TokenEncryption | None = None
|
||||
|
||||
|
||||
def _get_token_enc() -> TokenEncryption:
|
||||
global _token_enc
|
||||
if _token_enc is None:
|
||||
from app.config import config as app_config
|
||||
from app.utils.oauth_security import TokenEncryption
|
||||
|
||||
_token_enc = TokenEncryption(app_config.SECRET_KEY)
|
||||
return _token_enc
|
||||
|
||||
|
||||
def _inject_oauth_headers(
|
||||
cfg: dict[str, Any],
|
||||
server_config: dict[str, Any],
|
||||
) -> dict[str, Any] | None:
|
||||
"""Decrypt the MCP OAuth access token and inject it into server_config headers.
|
||||
|
||||
The DB never stores plaintext tokens in ``server_config.headers``. This
|
||||
function decrypts ``mcp_oauth.access_token`` at runtime and returns a
|
||||
*copy* of ``server_config`` with the Authorization header set.
|
||||
"""
|
||||
mcp_oauth = cfg.get("mcp_oauth", {})
|
||||
encrypted_token = mcp_oauth.get("access_token")
|
||||
if not encrypted_token:
|
||||
return server_config
|
||||
|
||||
try:
|
||||
access_token = _get_token_enc().decrypt_token(encrypted_token)
|
||||
|
||||
result = dict(server_config)
|
||||
result["headers"] = {
|
||||
**server_config.get("headers", {}),
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
}
|
||||
return result
|
||||
except Exception:
|
||||
logger.error(
|
||||
"Failed to decrypt MCP OAuth token — connector will be skipped",
|
||||
exc_info=True,
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
async def _maybe_refresh_mcp_oauth_token(
|
||||
session: AsyncSession,
|
||||
connector: "SearchSourceConnector",
|
||||
cfg: dict[str, Any],
|
||||
server_config: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
"""Refresh the access token for an MCP OAuth connector if it is about to expire.
|
||||
|
||||
Returns the (possibly updated) ``server_config``.
|
||||
"""
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
mcp_oauth = cfg.get("mcp_oauth", {})
|
||||
expires_at_str = mcp_oauth.get("expires_at")
|
||||
if not expires_at_str:
|
||||
return server_config
|
||||
|
||||
try:
|
||||
expires_at = datetime.fromisoformat(expires_at_str)
|
||||
if expires_at.tzinfo is None:
|
||||
from datetime import timezone
|
||||
expires_at = expires_at.replace(tzinfo=timezone.utc)
|
||||
|
||||
if datetime.now(UTC) < expires_at - timedelta(seconds=_TOKEN_REFRESH_BUFFER_SECONDS):
|
||||
return server_config
|
||||
except (ValueError, TypeError):
|
||||
return server_config
|
||||
|
||||
refresh_token = mcp_oauth.get("refresh_token")
|
||||
if not refresh_token:
|
||||
logger.warning(
|
||||
"MCP connector %s token expired but no refresh_token available",
|
||||
connector.id,
|
||||
)
|
||||
return server_config
|
||||
|
||||
try:
|
||||
from app.services.mcp_oauth.discovery import refresh_access_token
|
||||
|
||||
enc = _get_token_enc()
|
||||
decrypted_refresh = enc.decrypt_token(refresh_token)
|
||||
decrypted_secret = (
|
||||
enc.decrypt_token(mcp_oauth["client_secret"])
|
||||
if mcp_oauth.get("client_secret")
|
||||
else ""
|
||||
)
|
||||
|
||||
token_json = await refresh_access_token(
|
||||
token_endpoint=mcp_oauth["token_endpoint"],
|
||||
refresh_token=decrypted_refresh,
|
||||
client_id=mcp_oauth["client_id"],
|
||||
client_secret=decrypted_secret,
|
||||
)
|
||||
|
||||
new_access = token_json.get("access_token")
|
||||
if not new_access:
|
||||
logger.warning(
|
||||
"MCP connector %s token refresh returned no access_token",
|
||||
connector.id,
|
||||
)
|
||||
return server_config
|
||||
|
||||
new_expires_at = None
|
||||
if token_json.get("expires_in"):
|
||||
new_expires_at = datetime.now(UTC) + timedelta(
|
||||
seconds=int(token_json["expires_in"])
|
||||
)
|
||||
|
||||
updated_oauth = dict(mcp_oauth)
|
||||
updated_oauth["access_token"] = enc.encrypt_token(new_access)
|
||||
if token_json.get("refresh_token"):
|
||||
updated_oauth["refresh_token"] = enc.encrypt_token(
|
||||
token_json["refresh_token"]
|
||||
)
|
||||
updated_oauth["expires_at"] = (
|
||||
new_expires_at.isoformat() if new_expires_at else None
|
||||
)
|
||||
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
connector.config = {
|
||||
**cfg,
|
||||
"server_config": server_config,
|
||||
"mcp_oauth": updated_oauth,
|
||||
}
|
||||
flag_modified(connector, "config")
|
||||
await session.commit()
|
||||
await session.refresh(connector)
|
||||
|
||||
logger.info("Refreshed MCP OAuth token for connector %s", connector.id)
|
||||
|
||||
# Invalidate cache so next call picks up the new token.
|
||||
invalidate_mcp_tools_cache(connector.search_space_id)
|
||||
|
||||
# Return server_config with the fresh token injected for immediate use.
|
||||
refreshed_config = dict(server_config)
|
||||
refreshed_config["headers"] = {
|
||||
**server_config.get("headers", {}),
|
||||
"Authorization": f"Bearer {new_access}",
|
||||
}
|
||||
return refreshed_config
|
||||
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to refresh MCP OAuth token for connector %s",
|
||||
connector.id,
|
||||
exc_info=True,
|
||||
)
|
||||
return server_config
|
||||
|
||||
|
||||
def invalidate_mcp_tools_cache(search_space_id: int | None = None) -> None:
|
||||
"""Invalidate cached MCP tools.
|
||||
|
||||
|
|
@ -418,27 +628,91 @@ async def load_mcp_tools(
|
|||
return list(cached_tools)
|
||||
|
||||
try:
|
||||
# Find all connectors with MCP server config: generic MCP_CONNECTOR type
|
||||
# and service-specific types (LINEAR_CONNECTOR, etc.) created via MCP OAuth.
|
||||
# Cast JSON -> JSONB so we can use has_key to filter by the presence of "server_config".
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.MCP_CONNECTOR,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
cast(SearchSourceConnector.config, JSONB).has_key("server_config"), # noqa: W601
|
||||
),
|
||||
)
|
||||
|
||||
connectors = list(result.scalars())
|
||||
|
||||
# Group connectors by type to detect multi-account scenarios.
|
||||
# When >1 connector shares the same type, tool names would collide
|
||||
# so we prefix them with "{service_key}_{connector_id}_".
|
||||
type_groups: dict[str, list[SearchSourceConnector]] = defaultdict(list)
|
||||
for connector in connectors:
|
||||
ct = (
|
||||
connector.connector_type.value
|
||||
if hasattr(connector.connector_type, "value")
|
||||
else str(connector.connector_type)
|
||||
)
|
||||
type_groups[ct].append(connector)
|
||||
|
||||
multi_account_types: set[str] = {
|
||||
ct for ct, group in type_groups.items() if len(group) > 1
|
||||
}
|
||||
if multi_account_types:
|
||||
logger.info(
|
||||
"Multi-account detected for connector types: %s",
|
||||
multi_account_types,
|
||||
)
|
||||
|
||||
tools: list[StructuredTool] = []
|
||||
for connector in result.scalars():
|
||||
for connector in connectors:
|
||||
try:
|
||||
config = connector.config or {}
|
||||
server_config = config.get("server_config", {})
|
||||
trusted_tools = config.get("trusted_tools", [])
|
||||
cfg = connector.config or {}
|
||||
server_config = cfg.get("server_config", {})
|
||||
|
||||
if not server_config or not isinstance(server_config, dict):
|
||||
logger.warning(
|
||||
f"MCP connector {connector.id} (name: '{connector.name}') has invalid or missing server_config, skipping"
|
||||
"MCP connector %d (name: '%s') has invalid or missing server_config, skipping",
|
||||
connector.id, connector.name,
|
||||
)
|
||||
continue
|
||||
|
||||
# For MCP OAuth connectors: refresh if needed, then decrypt the
|
||||
# access token and inject it into headers at runtime. The DB
|
||||
# intentionally does NOT store plaintext tokens in server_config.
|
||||
if cfg.get("mcp_oauth"):
|
||||
server_config = await _maybe_refresh_mcp_oauth_token(
|
||||
session, connector, cfg, server_config,
|
||||
)
|
||||
# Re-read cfg after potential refresh (connector was reloaded from DB).
|
||||
cfg = connector.config or {}
|
||||
server_config = _inject_oauth_headers(cfg, server_config)
|
||||
if server_config is None:
|
||||
logger.warning(
|
||||
"Skipping MCP connector %d — OAuth token decryption failed",
|
||||
connector.id,
|
||||
)
|
||||
continue
|
||||
|
||||
trusted_tools = cfg.get("trusted_tools", [])
|
||||
|
||||
ct = (
|
||||
connector.connector_type.value
|
||||
if hasattr(connector.connector_type, "value")
|
||||
else str(connector.connector_type)
|
||||
)
|
||||
|
||||
svc_cfg = get_service_by_connector_type(ct)
|
||||
allowed_tools = svc_cfg.allowed_tools if svc_cfg else []
|
||||
readonly_tools = svc_cfg.readonly_tools if svc_cfg else frozenset()
|
||||
|
||||
# Build a prefix only when multiple accounts share the same type.
|
||||
tool_name_prefix: str | None = None
|
||||
if ct in multi_account_types and svc_cfg:
|
||||
service_key = next(
|
||||
(k for k, v in MCP_SERVICES.items() if v is svc_cfg),
|
||||
None,
|
||||
)
|
||||
if service_key:
|
||||
tool_name_prefix = f"{service_key}_{connector.id}"
|
||||
|
||||
transport = server_config.get("transport", "stdio")
|
||||
|
||||
if transport in ("streamable-http", "http", "sse"):
|
||||
|
|
@ -447,6 +721,9 @@ async def load_mcp_tools(
|
|||
connector.name,
|
||||
server_config,
|
||||
trusted_tools=trusted_tools,
|
||||
allowed_tools=allowed_tools,
|
||||
readonly_tools=readonly_tools,
|
||||
tool_name_prefix=tool_name_prefix,
|
||||
)
|
||||
else:
|
||||
connector_tools = await _load_stdio_mcp_tools(
|
||||
|
|
@ -460,7 +737,8 @@ async def load_mcp_tools(
|
|||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Failed to load tools from MCP connector {connector.id}: {e!s}"
|
||||
"Failed to load tools from MCP connector %d: %s",
|
||||
connector.id, e,
|
||||
)
|
||||
|
||||
_mcp_tools_cache[search_space_id] = (now, tools)
|
||||
|
|
@ -469,9 +747,9 @@ async def load_mcp_tools(
|
|||
oldest_key = min(_mcp_tools_cache, key=lambda k: _mcp_tools_cache[k][0])
|
||||
del _mcp_tools_cache[oldest_key]
|
||||
|
||||
logger.info(f"Loaded {len(tools)} MCP tools for search space {search_space_id}")
|
||||
logger.info("Loaded %d MCP tools for search space %d", len(tools), search_space_id)
|
||||
return tools
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to load MCP tools: {e!s}")
|
||||
logger.exception("Failed to load MCP tools: %s", e)
|
||||
return []
|
||||
|
|
|
|||
|
|
@ -50,6 +50,11 @@ from .confluence import (
|
|||
create_delete_confluence_page_tool,
|
||||
create_update_confluence_page_tool,
|
||||
)
|
||||
from .discord import (
|
||||
create_list_discord_channels_tool,
|
||||
create_read_discord_messages_tool,
|
||||
create_send_discord_message_tool,
|
||||
)
|
||||
from .dropbox import (
|
||||
create_create_dropbox_file_tool,
|
||||
create_delete_dropbox_file_tool,
|
||||
|
|
@ -57,6 +62,8 @@ from .dropbox import (
|
|||
from .generate_image import create_generate_image_tool
|
||||
from .gmail import (
|
||||
create_create_gmail_draft_tool,
|
||||
create_read_gmail_email_tool,
|
||||
create_search_gmail_tool,
|
||||
create_send_gmail_email_tool,
|
||||
create_trash_gmail_email_tool,
|
||||
create_update_gmail_draft_tool,
|
||||
|
|
@ -64,21 +71,18 @@ from .gmail import (
|
|||
from .google_calendar import (
|
||||
create_create_calendar_event_tool,
|
||||
create_delete_calendar_event_tool,
|
||||
create_search_calendar_events_tool,
|
||||
create_update_calendar_event_tool,
|
||||
)
|
||||
from .google_drive import (
|
||||
create_create_google_drive_file_tool,
|
||||
create_delete_google_drive_file_tool,
|
||||
)
|
||||
from .jira import (
|
||||
create_create_jira_issue_tool,
|
||||
create_delete_jira_issue_tool,
|
||||
create_update_jira_issue_tool,
|
||||
)
|
||||
from .linear import (
|
||||
create_create_linear_issue_tool,
|
||||
create_delete_linear_issue_tool,
|
||||
create_update_linear_issue_tool,
|
||||
from .connected_accounts import create_get_connected_accounts_tool
|
||||
from .luma import (
|
||||
create_create_luma_event_tool,
|
||||
create_list_luma_events_tool,
|
||||
create_read_luma_event_tool,
|
||||
)
|
||||
from .mcp_tool import load_mcp_tools
|
||||
from .notion import (
|
||||
|
|
@ -95,6 +99,11 @@ from .report import create_generate_report_tool
|
|||
from .resume import create_generate_resume_tool
|
||||
from .scrape_webpage import create_scrape_webpage_tool
|
||||
from .search_surfsense_docs import create_search_surfsense_docs_tool
|
||||
from .teams import (
|
||||
create_list_teams_channels_tool,
|
||||
create_read_teams_messages_tool,
|
||||
create_send_teams_message_tool,
|
||||
)
|
||||
from .update_memory import create_update_memory_tool, create_update_team_memory_tool
|
||||
from .video_presentation import create_generate_video_presentation_tool
|
||||
from .web_search import create_web_search_tool
|
||||
|
|
@ -114,6 +123,8 @@ class ToolDefinition:
|
|||
factory: Callable that creates the tool. Receives a dict of dependencies.
|
||||
requires: List of dependency names this tool needs (e.g., "search_space_id", "db_session")
|
||||
enabled_by_default: Whether the tool is enabled when no explicit config is provided
|
||||
required_connector: Searchable type string (e.g. ``"LINEAR_CONNECTOR"``)
|
||||
that must be in ``available_connectors`` for the tool to be enabled.
|
||||
|
||||
"""
|
||||
|
||||
|
|
@ -123,6 +134,7 @@ class ToolDefinition:
|
|||
requires: list[str] = field(default_factory=list)
|
||||
enabled_by_default: bool = True
|
||||
hidden: bool = False
|
||||
required_connector: str | None = None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
|
|
@ -221,6 +233,21 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
requires=["db_session"],
|
||||
),
|
||||
# =========================================================================
|
||||
# SERVICE ACCOUNT DISCOVERY
|
||||
# Generic tool for the LLM to discover connected accounts and resolve
|
||||
# service-specific identifiers (e.g. Jira cloudId, Slack team, etc.)
|
||||
# =========================================================================
|
||||
ToolDefinition(
|
||||
name="get_connected_accounts",
|
||||
description="Discover connected accounts for a service and their metadata",
|
||||
factory=lambda deps: create_get_connected_accounts_tool(
|
||||
db_session=deps["db_session"],
|
||||
search_space_id=deps["search_space_id"],
|
||||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
),
|
||||
# =========================================================================
|
||||
# MEMORY TOOL - single update_memory, private or team by thread_visibility
|
||||
# =========================================================================
|
||||
ToolDefinition(
|
||||
|
|
@ -248,40 +275,6 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
],
|
||||
),
|
||||
# =========================================================================
|
||||
# LINEAR TOOLS - create, update, delete issues
|
||||
# Auto-disabled when no Linear connector is configured (see chat_deepagent.py)
|
||||
# =========================================================================
|
||||
ToolDefinition(
|
||||
name="create_linear_issue",
|
||||
description="Create a new issue in the user's Linear workspace",
|
||||
factory=lambda deps: create_create_linear_issue_tool(
|
||||
db_session=deps["db_session"],
|
||||
search_space_id=deps["search_space_id"],
|
||||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
),
|
||||
ToolDefinition(
|
||||
name="update_linear_issue",
|
||||
description="Update an existing indexed Linear issue",
|
||||
factory=lambda deps: create_update_linear_issue_tool(
|
||||
db_session=deps["db_session"],
|
||||
search_space_id=deps["search_space_id"],
|
||||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
),
|
||||
ToolDefinition(
|
||||
name="delete_linear_issue",
|
||||
description="Archive (delete) an existing indexed Linear issue",
|
||||
factory=lambda deps: create_delete_linear_issue_tool(
|
||||
db_session=deps["db_session"],
|
||||
search_space_id=deps["search_space_id"],
|
||||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
),
|
||||
# =========================================================================
|
||||
# NOTION TOOLS - create, update, delete pages
|
||||
# Auto-disabled when no Notion connector is configured (see chat_deepagent.py)
|
||||
# =========================================================================
|
||||
|
|
@ -294,6 +287,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
required_connector="NOTION_CONNECTOR",
|
||||
),
|
||||
ToolDefinition(
|
||||
name="update_notion_page",
|
||||
|
|
@ -304,6 +298,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
required_connector="NOTION_CONNECTOR",
|
||||
),
|
||||
ToolDefinition(
|
||||
name="delete_notion_page",
|
||||
|
|
@ -314,6 +309,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
required_connector="NOTION_CONNECTOR",
|
||||
),
|
||||
# =========================================================================
|
||||
# GOOGLE DRIVE TOOLS - create files, delete files
|
||||
|
|
@ -328,6 +324,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
required_connector="GOOGLE_DRIVE_FILE",
|
||||
),
|
||||
ToolDefinition(
|
||||
name="delete_google_drive_file",
|
||||
|
|
@ -338,6 +335,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
required_connector="GOOGLE_DRIVE_FILE",
|
||||
),
|
||||
# =========================================================================
|
||||
# DROPBOX TOOLS - create and trash files
|
||||
|
|
@ -352,6 +350,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
required_connector="DROPBOX_FILE",
|
||||
),
|
||||
ToolDefinition(
|
||||
name="delete_dropbox_file",
|
||||
|
|
@ -362,6 +361,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
required_connector="DROPBOX_FILE",
|
||||
),
|
||||
# =========================================================================
|
||||
# ONEDRIVE TOOLS - create and trash files
|
||||
|
|
@ -376,6 +376,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
required_connector="ONEDRIVE_FILE",
|
||||
),
|
||||
ToolDefinition(
|
||||
name="delete_onedrive_file",
|
||||
|
|
@ -386,11 +387,23 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
required_connector="ONEDRIVE_FILE",
|
||||
),
|
||||
# =========================================================================
|
||||
# GOOGLE CALENDAR TOOLS - create, update, delete events
|
||||
# GOOGLE CALENDAR TOOLS - search, create, update, delete events
|
||||
# Auto-disabled when no Google Calendar connector is configured
|
||||
# =========================================================================
|
||||
ToolDefinition(
|
||||
name="search_calendar_events",
|
||||
description="Search Google Calendar events within a date range",
|
||||
factory=lambda deps: create_search_calendar_events_tool(
|
||||
db_session=deps["db_session"],
|
||||
search_space_id=deps["search_space_id"],
|
||||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
required_connector="GOOGLE_CALENDAR_CONNECTOR",
|
||||
),
|
||||
ToolDefinition(
|
||||
name="create_calendar_event",
|
||||
description="Create a new event on Google Calendar",
|
||||
|
|
@ -400,6 +413,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
required_connector="GOOGLE_CALENDAR_CONNECTOR",
|
||||
),
|
||||
ToolDefinition(
|
||||
name="update_calendar_event",
|
||||
|
|
@ -410,6 +424,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
required_connector="GOOGLE_CALENDAR_CONNECTOR",
|
||||
),
|
||||
ToolDefinition(
|
||||
name="delete_calendar_event",
|
||||
|
|
@ -420,11 +435,34 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
required_connector="GOOGLE_CALENDAR_CONNECTOR",
|
||||
),
|
||||
# =========================================================================
|
||||
# GMAIL TOOLS - create drafts, update drafts, send emails, trash emails
|
||||
# GMAIL TOOLS - search, read, create drafts, update drafts, send, trash
|
||||
# Auto-disabled when no Gmail connector is configured
|
||||
# =========================================================================
|
||||
ToolDefinition(
|
||||
name="search_gmail",
|
||||
description="Search emails in Gmail using Gmail search syntax",
|
||||
factory=lambda deps: create_search_gmail_tool(
|
||||
db_session=deps["db_session"],
|
||||
search_space_id=deps["search_space_id"],
|
||||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
required_connector="GOOGLE_GMAIL_CONNECTOR",
|
||||
),
|
||||
ToolDefinition(
|
||||
name="read_gmail_email",
|
||||
description="Read the full content of a specific Gmail email",
|
||||
factory=lambda deps: create_read_gmail_email_tool(
|
||||
db_session=deps["db_session"],
|
||||
search_space_id=deps["search_space_id"],
|
||||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
required_connector="GOOGLE_GMAIL_CONNECTOR",
|
||||
),
|
||||
ToolDefinition(
|
||||
name="create_gmail_draft",
|
||||
description="Create a draft email in Gmail",
|
||||
|
|
@ -434,6 +472,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
required_connector="GOOGLE_GMAIL_CONNECTOR",
|
||||
),
|
||||
ToolDefinition(
|
||||
name="send_gmail_email",
|
||||
|
|
@ -444,6 +483,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
required_connector="GOOGLE_GMAIL_CONNECTOR",
|
||||
),
|
||||
ToolDefinition(
|
||||
name="trash_gmail_email",
|
||||
|
|
@ -454,6 +494,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
required_connector="GOOGLE_GMAIL_CONNECTOR",
|
||||
),
|
||||
ToolDefinition(
|
||||
name="update_gmail_draft",
|
||||
|
|
@ -464,40 +505,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
),
|
||||
# =========================================================================
|
||||
# JIRA TOOLS - create, update, delete issues
|
||||
# Auto-disabled when no Jira connector is configured (see chat_deepagent.py)
|
||||
# =========================================================================
|
||||
ToolDefinition(
|
||||
name="create_jira_issue",
|
||||
description="Create a new issue in the user's Jira project",
|
||||
factory=lambda deps: create_create_jira_issue_tool(
|
||||
db_session=deps["db_session"],
|
||||
search_space_id=deps["search_space_id"],
|
||||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
),
|
||||
ToolDefinition(
|
||||
name="update_jira_issue",
|
||||
description="Update an existing indexed Jira issue",
|
||||
factory=lambda deps: create_update_jira_issue_tool(
|
||||
db_session=deps["db_session"],
|
||||
search_space_id=deps["search_space_id"],
|
||||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
),
|
||||
ToolDefinition(
|
||||
name="delete_jira_issue",
|
||||
description="Delete an existing indexed Jira issue",
|
||||
factory=lambda deps: create_delete_jira_issue_tool(
|
||||
db_session=deps["db_session"],
|
||||
search_space_id=deps["search_space_id"],
|
||||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
required_connector="GOOGLE_GMAIL_CONNECTOR",
|
||||
),
|
||||
# =========================================================================
|
||||
# CONFLUENCE TOOLS - create, update, delete pages
|
||||
|
|
@ -512,6 +520,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
required_connector="CONFLUENCE_CONNECTOR",
|
||||
),
|
||||
ToolDefinition(
|
||||
name="update_confluence_page",
|
||||
|
|
@ -522,6 +531,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
required_connector="CONFLUENCE_CONNECTOR",
|
||||
),
|
||||
ToolDefinition(
|
||||
name="delete_confluence_page",
|
||||
|
|
@ -532,6 +542,118 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
required_connector="CONFLUENCE_CONNECTOR",
|
||||
),
|
||||
# =========================================================================
|
||||
# DISCORD TOOLS - list channels, read messages, send messages
|
||||
# Auto-disabled when no Discord connector is configured
|
||||
# =========================================================================
|
||||
ToolDefinition(
|
||||
name="list_discord_channels",
|
||||
description="List text channels in the connected Discord server",
|
||||
factory=lambda deps: create_list_discord_channels_tool(
|
||||
db_session=deps["db_session"],
|
||||
search_space_id=deps["search_space_id"],
|
||||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
required_connector="DISCORD_CONNECTOR",
|
||||
),
|
||||
ToolDefinition(
|
||||
name="read_discord_messages",
|
||||
description="Read recent messages from a Discord text channel",
|
||||
factory=lambda deps: create_read_discord_messages_tool(
|
||||
db_session=deps["db_session"],
|
||||
search_space_id=deps["search_space_id"],
|
||||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
required_connector="DISCORD_CONNECTOR",
|
||||
),
|
||||
ToolDefinition(
|
||||
name="send_discord_message",
|
||||
description="Send a message to a Discord text channel",
|
||||
factory=lambda deps: create_send_discord_message_tool(
|
||||
db_session=deps["db_session"],
|
||||
search_space_id=deps["search_space_id"],
|
||||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
required_connector="DISCORD_CONNECTOR",
|
||||
),
|
||||
# =========================================================================
|
||||
# TEAMS TOOLS - list channels, read messages, send messages
|
||||
# Auto-disabled when no Teams connector is configured
|
||||
# =========================================================================
|
||||
ToolDefinition(
|
||||
name="list_teams_channels",
|
||||
description="List Microsoft Teams and their channels",
|
||||
factory=lambda deps: create_list_teams_channels_tool(
|
||||
db_session=deps["db_session"],
|
||||
search_space_id=deps["search_space_id"],
|
||||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
required_connector="TEAMS_CONNECTOR",
|
||||
),
|
||||
ToolDefinition(
|
||||
name="read_teams_messages",
|
||||
description="Read recent messages from a Microsoft Teams channel",
|
||||
factory=lambda deps: create_read_teams_messages_tool(
|
||||
db_session=deps["db_session"],
|
||||
search_space_id=deps["search_space_id"],
|
||||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
required_connector="TEAMS_CONNECTOR",
|
||||
),
|
||||
ToolDefinition(
|
||||
name="send_teams_message",
|
||||
description="Send a message to a Microsoft Teams channel",
|
||||
factory=lambda deps: create_send_teams_message_tool(
|
||||
db_session=deps["db_session"],
|
||||
search_space_id=deps["search_space_id"],
|
||||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
required_connector="TEAMS_CONNECTOR",
|
||||
),
|
||||
# =========================================================================
|
||||
# LUMA TOOLS - list events, read event details, create events
|
||||
# Auto-disabled when no Luma connector is configured
|
||||
# =========================================================================
|
||||
ToolDefinition(
|
||||
name="list_luma_events",
|
||||
description="List upcoming and recent Luma events",
|
||||
factory=lambda deps: create_list_luma_events_tool(
|
||||
db_session=deps["db_session"],
|
||||
search_space_id=deps["search_space_id"],
|
||||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
required_connector="LUMA_CONNECTOR",
|
||||
),
|
||||
ToolDefinition(
|
||||
name="read_luma_event",
|
||||
description="Read detailed information about a specific Luma event",
|
||||
factory=lambda deps: create_read_luma_event_tool(
|
||||
db_session=deps["db_session"],
|
||||
search_space_id=deps["search_space_id"],
|
||||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
required_connector="LUMA_CONNECTOR",
|
||||
),
|
||||
ToolDefinition(
|
||||
name="create_luma_event",
|
||||
description="Create a new event on Luma",
|
||||
factory=lambda deps: create_create_luma_event_tool(
|
||||
db_session=deps["db_session"],
|
||||
search_space_id=deps["search_space_id"],
|
||||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
required_connector="LUMA_CONNECTOR",
|
||||
),
|
||||
]
|
||||
|
||||
|
|
@ -549,6 +671,22 @@ def get_tool_by_name(name: str) -> ToolDefinition | None:
|
|||
return None
|
||||
|
||||
|
||||
def get_connector_gated_tools(
|
||||
available_connectors: list[str] | None,
|
||||
) -> list[str]:
|
||||
"""Return tool names to disable"""
|
||||
if available_connectors is None:
|
||||
available = set()
|
||||
else:
|
||||
available = set(available_connectors)
|
||||
|
||||
disabled: list[str] = []
|
||||
for tool_def in BUILTIN_TOOLS:
|
||||
if tool_def.required_connector and tool_def.required_connector not in available:
|
||||
disabled.append(tool_def.name)
|
||||
return disabled
|
||||
|
||||
|
||||
def get_all_tool_names() -> list[str]:
|
||||
"""Get names of all registered tools."""
|
||||
return [tool_def.name for tool_def in BUILTIN_TOOLS]
|
||||
|
|
@ -690,15 +828,15 @@ async def build_tools_async(
|
|||
)
|
||||
tools.extend(mcp_tools)
|
||||
logging.info(
|
||||
f"Registered {len(mcp_tools)} MCP tools: {[t.name for t in mcp_tools]}",
|
||||
"Registered %d MCP tools: %s",
|
||||
len(mcp_tools), [t.name for t in mcp_tools],
|
||||
)
|
||||
except Exception as e:
|
||||
# Log error but don't fail - just continue without MCP tools
|
||||
logging.exception(f"Failed to load MCP tools: {e!s}")
|
||||
logging.exception("Failed to load MCP tools: %s", e)
|
||||
|
||||
# Log all tools being returned to agent
|
||||
logging.info(
|
||||
f"Total tools for agent: {len(tools)} - {[t.name for t in tools]}",
|
||||
"Total tools for agent: %d — %s",
|
||||
len(tools), [t.name for t in tools],
|
||||
)
|
||||
|
||||
return tools
|
||||
|
|
|
|||
|
|
@ -0,0 +1,15 @@
|
|||
from app.agents.new_chat.tools.teams.list_channels import (
|
||||
create_list_teams_channels_tool,
|
||||
)
|
||||
from app.agents.new_chat.tools.teams.read_messages import (
|
||||
create_read_teams_messages_tool,
|
||||
)
|
||||
from app.agents.new_chat.tools.teams.send_message import (
|
||||
create_send_teams_message_tool,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"create_list_teams_channels_tool",
|
||||
"create_read_teams_messages_tool",
|
||||
"create_send_teams_message_tool",
|
||||
]
|
||||
37
surfsense_backend/app/agents/new_chat/tools/teams/_auth.py
Normal file
37
surfsense_backend/app/agents/new_chat/tools/teams/_auth.py
Normal file
|
|
@ -0,0 +1,37 @@
|
|||
"""Shared auth helper for Teams agent tools (Microsoft Graph REST API)."""
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
GRAPH_API = "https://graph.microsoft.com/v1.0"
|
||||
|
||||
|
||||
async def get_teams_connector(
|
||||
db_session: AsyncSession,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
) -> SearchSourceConnector | None:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type == SearchSourceConnectorType.TEAMS_CONNECTOR,
|
||||
)
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
|
||||
async def get_access_token(
|
||||
db_session: AsyncSession,
|
||||
connector: SearchSourceConnector,
|
||||
) -> str:
|
||||
"""Get a valid Microsoft Graph access token, refreshing if expired."""
|
||||
from app.connectors.teams_connector import TeamsConnector
|
||||
|
||||
tc = TeamsConnector(
|
||||
session=db_session,
|
||||
connector_id=connector.id,
|
||||
)
|
||||
return await tc._get_valid_token()
|
||||
|
|
@ -0,0 +1,77 @@
|
|||
import logging
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from langchain_core.tools import tool
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from ._auth import GRAPH_API, get_access_token, get_teams_connector
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_list_teams_channels_tool(
|
||||
db_session: AsyncSession | None = None,
|
||||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
@tool
|
||||
async def list_teams_channels() -> dict[str, Any]:
|
||||
"""List all Microsoft Teams and their channels the user has access to.
|
||||
|
||||
Returns:
|
||||
Dictionary with status and a list of teams, each containing
|
||||
team_id, team_name, and a list of channels (id, name).
|
||||
"""
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
return {"status": "error", "message": "Teams tool not properly configured."}
|
||||
|
||||
try:
|
||||
connector = await get_teams_connector(db_session, search_space_id, user_id)
|
||||
if not connector:
|
||||
return {"status": "error", "message": "No Teams connector found."}
|
||||
|
||||
token = await get_access_token(db_session, connector)
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
|
||||
async with httpx.AsyncClient(timeout=20.0) as client:
|
||||
teams_resp = await client.get(f"{GRAPH_API}/me/joinedTeams", headers=headers)
|
||||
|
||||
if teams_resp.status_code == 401:
|
||||
return {"status": "auth_error", "message": "Teams token expired. Please re-authenticate.", "connector_type": "teams"}
|
||||
if teams_resp.status_code != 200:
|
||||
return {"status": "error", "message": f"Graph API error: {teams_resp.status_code}"}
|
||||
|
||||
teams_data = teams_resp.json().get("value", [])
|
||||
result_teams = []
|
||||
|
||||
async with httpx.AsyncClient(timeout=20.0) as client:
|
||||
for team in teams_data:
|
||||
team_id = team["id"]
|
||||
ch_resp = await client.get(
|
||||
f"{GRAPH_API}/teams/{team_id}/channels",
|
||||
headers=headers,
|
||||
)
|
||||
channels = []
|
||||
if ch_resp.status_code == 200:
|
||||
channels = [
|
||||
{"id": ch["id"], "name": ch.get("displayName", "")}
|
||||
for ch in ch_resp.json().get("value", [])
|
||||
]
|
||||
result_teams.append({
|
||||
"team_id": team_id,
|
||||
"team_name": team.get("displayName", ""),
|
||||
"channels": channels,
|
||||
})
|
||||
|
||||
return {"status": "success", "teams": result_teams, "total_teams": len(result_teams)}
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
||||
if isinstance(e, GraphInterrupt):
|
||||
raise
|
||||
logger.error("Error listing Teams channels: %s", e, exc_info=True)
|
||||
return {"status": "error", "message": "Failed to list Teams channels."}
|
||||
|
||||
return list_teams_channels
|
||||
|
|
@ -0,0 +1,91 @@
|
|||
import logging
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from langchain_core.tools import tool
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from ._auth import GRAPH_API, get_access_token, get_teams_connector
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_read_teams_messages_tool(
|
||||
db_session: AsyncSession | None = None,
|
||||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
@tool
|
||||
async def read_teams_messages(
|
||||
team_id: str,
|
||||
channel_id: str,
|
||||
limit: int = 25,
|
||||
) -> dict[str, Any]:
|
||||
"""Read recent messages from a Microsoft Teams channel.
|
||||
|
||||
Args:
|
||||
team_id: The team ID (from list_teams_channels).
|
||||
channel_id: The channel ID (from list_teams_channels).
|
||||
limit: Number of messages to fetch (default 25, max 50).
|
||||
|
||||
Returns:
|
||||
Dictionary with status and a list of messages including
|
||||
id, sender, content, timestamp.
|
||||
"""
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
return {"status": "error", "message": "Teams tool not properly configured."}
|
||||
|
||||
limit = min(limit, 50)
|
||||
|
||||
try:
|
||||
connector = await get_teams_connector(db_session, search_space_id, user_id)
|
||||
if not connector:
|
||||
return {"status": "error", "message": "No Teams connector found."}
|
||||
|
||||
token = await get_access_token(db_session, connector)
|
||||
|
||||
async with httpx.AsyncClient(timeout=20.0) as client:
|
||||
resp = await client.get(
|
||||
f"{GRAPH_API}/teams/{team_id}/channels/{channel_id}/messages",
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
params={"$top": limit},
|
||||
)
|
||||
|
||||
if resp.status_code == 401:
|
||||
return {"status": "auth_error", "message": "Teams token expired. Please re-authenticate.", "connector_type": "teams"}
|
||||
if resp.status_code == 403:
|
||||
return {"status": "error", "message": "Insufficient permissions to read this channel."}
|
||||
if resp.status_code != 200:
|
||||
return {"status": "error", "message": f"Graph API error: {resp.status_code}"}
|
||||
|
||||
raw_msgs = resp.json().get("value", [])
|
||||
messages = []
|
||||
for m in raw_msgs:
|
||||
sender = m.get("from", {})
|
||||
user_info = sender.get("user", {}) if sender else {}
|
||||
body = m.get("body", {})
|
||||
messages.append({
|
||||
"id": m.get("id"),
|
||||
"sender": user_info.get("displayName", "Unknown"),
|
||||
"content": body.get("content", ""),
|
||||
"content_type": body.get("contentType", "text"),
|
||||
"timestamp": m.get("createdDateTime", ""),
|
||||
})
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"team_id": team_id,
|
||||
"channel_id": channel_id,
|
||||
"messages": messages,
|
||||
"total": len(messages),
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
||||
if isinstance(e, GraphInterrupt):
|
||||
raise
|
||||
logger.error("Error reading Teams messages: %s", e, exc_info=True)
|
||||
return {"status": "error", "message": "Failed to read Teams messages."}
|
||||
|
||||
return read_teams_messages
|
||||
|
|
@ -0,0 +1,101 @@
|
|||
import logging
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from langchain_core.tools import tool
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.new_chat.tools.hitl import request_approval
|
||||
|
||||
from ._auth import GRAPH_API, get_access_token, get_teams_connector
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_send_teams_message_tool(
|
||||
db_session: AsyncSession | None = None,
|
||||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
@tool
|
||||
async def send_teams_message(
|
||||
team_id: str,
|
||||
channel_id: str,
|
||||
content: str,
|
||||
) -> dict[str, Any]:
|
||||
"""Send a message to a Microsoft Teams channel.
|
||||
|
||||
Requires the ChannelMessage.Send OAuth scope. If the user gets a
|
||||
permission error, they may need to re-authenticate with updated scopes.
|
||||
|
||||
Args:
|
||||
team_id: The team ID (from list_teams_channels).
|
||||
channel_id: The channel ID (from list_teams_channels).
|
||||
content: The message text (HTML supported).
|
||||
|
||||
Returns:
|
||||
Dictionary with status, message_id on success.
|
||||
|
||||
IMPORTANT:
|
||||
- If status is "rejected", the user explicitly declined. Do NOT retry.
|
||||
"""
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
return {"status": "error", "message": "Teams tool not properly configured."}
|
||||
|
||||
try:
|
||||
connector = await get_teams_connector(db_session, search_space_id, user_id)
|
||||
if not connector:
|
||||
return {"status": "error", "message": "No Teams connector found."}
|
||||
|
||||
result = request_approval(
|
||||
action_type="teams_send_message",
|
||||
tool_name="send_teams_message",
|
||||
params={"team_id": team_id, "channel_id": channel_id, "content": content},
|
||||
context={"connector_id": connector.id},
|
||||
)
|
||||
|
||||
if result.rejected:
|
||||
return {"status": "rejected", "message": "User declined. Message was not sent."}
|
||||
|
||||
final_content = result.params.get("content", content)
|
||||
final_team = result.params.get("team_id", team_id)
|
||||
final_channel = result.params.get("channel_id", channel_id)
|
||||
|
||||
token = await get_access_token(db_session, connector)
|
||||
|
||||
async with httpx.AsyncClient(timeout=20.0) as client:
|
||||
resp = await client.post(
|
||||
f"{GRAPH_API}/teams/{final_team}/channels/{final_channel}/messages",
|
||||
headers={
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
json={"body": {"content": final_content}},
|
||||
)
|
||||
|
||||
if resp.status_code == 401:
|
||||
return {"status": "auth_error", "message": "Teams token expired. Please re-authenticate.", "connector_type": "teams"}
|
||||
if resp.status_code == 403:
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"message": "Missing ChannelMessage.Send permission. Please re-authenticate with updated scopes.",
|
||||
}
|
||||
if resp.status_code not in (200, 201):
|
||||
return {"status": "error", "message": f"Graph API error: {resp.status_code} — {resp.text[:200]}"}
|
||||
|
||||
msg_data = resp.json()
|
||||
return {
|
||||
"status": "success",
|
||||
"message_id": msg_data.get("id"),
|
||||
"message": f"Message sent to Teams channel.",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
||||
if isinstance(e, GraphInterrupt):
|
||||
raise
|
||||
logger.error("Error sending Teams message: %s", e, exc_info=True)
|
||||
return {"status": "error", "message": "Failed to send Teams message."}
|
||||
|
||||
return send_teams_message
|
||||
41
surfsense_backend/app/agents/new_chat/tools/tool_response.py
Normal file
41
surfsense_backend/app/agents/new_chat/tools/tool_response.py
Normal file
|
|
@ -0,0 +1,41 @@
|
|||
"""Standardised response dict factories for LangChain agent tools."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
class ToolResponse:
|
||||
|
||||
@staticmethod
|
||||
def success(message: str, **data: Any) -> dict[str, Any]:
|
||||
return {"status": "success", "message": message, **data}
|
||||
|
||||
@staticmethod
|
||||
def error(error: str, **data: Any) -> dict[str, Any]:
|
||||
return {"status": "error", "error": error, **data}
|
||||
|
||||
@staticmethod
|
||||
def auth_error(service: str, **data: Any) -> dict[str, Any]:
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"error": (
|
||||
f"{service} authentication has expired or been revoked. "
|
||||
"Please re-connect the integration in Settings → Connectors."
|
||||
),
|
||||
**data,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def rejected(message: str = "Action was declined by the user.") -> dict[str, Any]:
|
||||
return {"status": "rejected", "message": message}
|
||||
|
||||
@staticmethod
|
||||
def not_found(
|
||||
resource: str, identifier: str, **data: Any
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
"status": "not_found",
|
||||
"error": f"{resource} '{identifier}' was not found.",
|
||||
**data,
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue