diff --git a/surfsense_backend/app/agents/new_chat/chat_deepagent.py b/surfsense_backend/app/agents/new_chat/chat_deepagent.py index a901a7519..4b204ffa9 100644 --- a/surfsense_backend/app/agents/new_chat/chat_deepagent.py +++ b/surfsense_backend/app/agents/new_chat/chat_deepagent.py @@ -47,7 +47,7 @@ from app.agents.new_chat.system_prompt import ( build_configurable_system_prompt, build_surfsense_system_prompt, ) -from app.agents.new_chat.tools.registry import build_tools_async +from app.agents.new_chat.tools.registry import build_tools_async, get_connector_gated_tools from app.db import ChatVisibility from app.services.connector_service import ConnectorService from app.utils.perf import get_perf_logger @@ -287,105 +287,10 @@ async def create_surfsense_deep_agent( "llm": llm, } - # Disable Notion action tools if no Notion connector is configured modified_disabled_tools = list(disabled_tools) if disabled_tools else [] - has_notion_connector = ( - available_connectors is not None and "NOTION_CONNECTOR" in available_connectors + modified_disabled_tools.extend( + get_connector_gated_tools(available_connectors) ) - if not has_notion_connector: - notion_tools = [ - "create_notion_page", - "update_notion_page", - "delete_notion_page", - ] - modified_disabled_tools.extend(notion_tools) - - # Disable Linear action tools if no Linear connector is configured - has_linear_connector = ( - available_connectors is not None and "LINEAR_CONNECTOR" in available_connectors - ) - if not has_linear_connector: - linear_tools = [ - "create_linear_issue", - "update_linear_issue", - "delete_linear_issue", - ] - modified_disabled_tools.extend(linear_tools) - - # Disable Google Drive action tools if no Google Drive connector is configured - has_google_drive_connector = ( - available_connectors is not None and "GOOGLE_DRIVE_FILE" in available_connectors - ) - if not has_google_drive_connector: - google_drive_tools = [ - "create_google_drive_file", - "delete_google_drive_file", - ] - 
modified_disabled_tools.extend(google_drive_tools) - - has_dropbox_connector = ( - available_connectors is not None and "DROPBOX_FILE" in available_connectors - ) - if not has_dropbox_connector: - modified_disabled_tools.extend(["create_dropbox_file", "delete_dropbox_file"]) - - has_onedrive_connector = ( - available_connectors is not None and "ONEDRIVE_FILE" in available_connectors - ) - if not has_onedrive_connector: - modified_disabled_tools.extend(["create_onedrive_file", "delete_onedrive_file"]) - - # Disable Google Calendar action tools if no Google Calendar connector is configured - has_google_calendar_connector = ( - available_connectors is not None - and "GOOGLE_CALENDAR_CONNECTOR" in available_connectors - ) - if not has_google_calendar_connector: - calendar_tools = [ - "create_calendar_event", - "update_calendar_event", - "delete_calendar_event", - ] - modified_disabled_tools.extend(calendar_tools) - - # Disable Gmail action tools if no Gmail connector is configured - has_gmail_connector = ( - available_connectors is not None - and "GOOGLE_GMAIL_CONNECTOR" in available_connectors - ) - if not has_gmail_connector: - gmail_tools = [ - "create_gmail_draft", - "update_gmail_draft", - "send_gmail_email", - "trash_gmail_email", - ] - modified_disabled_tools.extend(gmail_tools) - - # Disable Jira action tools if no Jira connector is configured - has_jira_connector = ( - available_connectors is not None and "JIRA_CONNECTOR" in available_connectors - ) - if not has_jira_connector: - jira_tools = [ - "create_jira_issue", - "update_jira_issue", - "delete_jira_issue", - ] - modified_disabled_tools.extend(jira_tools) - - # Disable Confluence action tools if no Confluence connector is configured - has_confluence_connector = ( - available_connectors is not None - and "CONFLUENCE_CONNECTOR" in available_connectors - ) - if not has_confluence_connector: - confluence_tools = [ - "create_confluence_page", - "update_confluence_page", - "delete_confluence_page", - ] - 
modified_disabled_tools.extend(confluence_tools) # Remove direct KB search tool; we now pre-seed a scoped filesystem via middleware. if "search_knowledge_base" not in modified_disabled_tools: diff --git a/surfsense_backend/app/agents/new_chat/system_prompt.py b/surfsense_backend/app/agents/new_chat/system_prompt.py index 9b8a7e0f9..3182735d9 100644 --- a/surfsense_backend/app/agents/new_chat/system_prompt.py +++ b/surfsense_backend/app/agents/new_chat/system_prompt.py @@ -38,8 +38,66 @@ CRITICAL RULE — KNOWLEDGE BASE FIRST, NEVER DEFAULT TO GENERAL KNOWLEDGE: * Formatting, summarization, or analysis of content already present in the conversation * Following user instructions that are clearly task-oriented (e.g., "rewrite this in bullet points") * Tool-usage actions like generating reports, podcasts, images, or scraping webpages + * Queries about services that have direct tools (Linear, ClickUp, Jira, Slack, Airtable) — see below + +CRITICAL — You have direct tools for these services: Linear, ClickUp, Jira, Slack, Airtable. +Their data is NEVER in the knowledge base. You MUST call their tools immediately — never +say "I don't see it in the knowledge base" or ask the user if they want you to check. +Ignore any knowledge base results for these services. 
+ +When to use which tool: +- Linear (issues) → list_issues, get_issue, save_issue (create/update) +- ClickUp (tasks) → clickup_search, clickup_get_task +- Jira (issues) → getAccessibleAtlassianResources (cloudId discovery), getVisibleJiraProjects (project discovery), getJiraProjectIssueTypesMetadata (issue type discovery), searchJiraIssuesUsingJql, createJiraIssue, editJiraIssue +- Slack (messages, channels) → slack_search_channels, slack_read_channel, slack_read_thread +- Airtable (bases, tables, records) → list_bases, list_tables_for_base, list_records_for_table +- Knowledge base content (Notion, GitHub, files, notes) → automatically searched +- Real-time public web data → call web_search +- Reading a specific webpage → call scrape_webpage + + + +Some service tools require identifiers or context you do not have (account IDs, +workspace names, channel IDs, project keys, etc.). NEVER ask the user for raw +IDs or technical identifiers — they cannot memorise them. + +Instead, follow this discovery pattern: +1. Call a listing/discovery tool to find available options. +2. ONE result → use it silently, no question to the user. +3. MULTIPLE results → present the options by their display names and let the + user choose. Never show raw UUIDs — always use friendly names. + +Discovery tools by level: +- Which account/workspace? → get_connected_accounts("<service>") +- Which Jira site (cloudId)? → getAccessibleAtlassianResources +- Which Jira project? → getVisibleJiraProjects (after resolving cloudId) +- Which Jira issue type? → getJiraProjectIssueTypesMetadata (after resolving project) +- Which channel? → slack_search_channels +- Which base? → list_bases +- Which table? → list_tables_for_base (after resolving baseId) +- Which task? → clickup_search +- Which issue? → list_issues (Linear) or searchJiraIssuesUsingJql (Jira) + +For Jira specifically: ALWAYS call getAccessibleAtlassianResources first to +obtain the cloudId, then pass it to other Jira tools.
When creating an issue, +chain: getAccessibleAtlassianResources → getVisibleJiraProjects → createJiraIssue. +If there is only one option at each step, use it silently. If multiple, present +friendly names. + +Chain discovery when needed — e.g. for Airtable records: list_bases → pick +base → list_tables_for_base → pick table → list_records_for_table. + +MULTI-ACCOUNT TOOL NAMING: When the user has multiple accounts connected for +the same service, tool names are prefixed to avoid collisions — e.g. +linear_25_list_issues and linear_30_list_issues instead of two list_issues. +Each prefixed tool's description starts with [Account: <display name>] so you +know which account it targets. Use get_connected_accounts("<service>") to see +the full list of accounts with their connector IDs and display names. +When only one account is connected, tools have their normal unprefixed names. + + IMPORTANT — After understanding each user message, ALWAYS check: does this message reveal durable facts about the user (role, interests, preferences, projects, @@ -76,8 +134,66 @@ CRITICAL RULE — KNOWLEDGE BASE FIRST, NEVER DEFAULT TO GENERAL KNOWLEDGE: * Formatting, summarization, or analysis of content already present in the conversation * Following user instructions that are clearly task-oriented (e.g., "rewrite this in bullet points") * Tool-usage actions like generating reports, podcasts, images, or scraping webpages + * Queries about services that have direct tools (Linear, ClickUp, Jira, Slack, Airtable) — see below + +CRITICAL — You have direct tools for these services: Linear, ClickUp, Jira, Slack, Airtable. +Their data is NEVER in the knowledge base. You MUST call their tools immediately — never +say "I don't see it in the knowledge base" or ask if they want you to check. +Ignore any knowledge base results for these services.
+ +When to use which tool: +- Linear (issues) → list_issues, get_issue, save_issue (create/update) +- ClickUp (tasks) → clickup_search, clickup_get_task +- Jira (issues) → getAccessibleAtlassianResources (cloudId discovery), getVisibleJiraProjects (project discovery), getJiraProjectIssueTypesMetadata (issue type discovery), searchJiraIssuesUsingJql, createJiraIssue, editJiraIssue +- Slack (messages, channels) → slack_search_channels, slack_read_channel, slack_read_thread +- Airtable (bases, tables, records) → list_bases, list_tables_for_base, list_records_for_table +- Knowledge base content (Notion, GitHub, files, notes) → automatically searched +- Real-time public web data → call web_search +- Reading a specific webpage → call scrape_webpage + + + +Some service tools require identifiers or context you do not have (account IDs, +workspace names, channel IDs, project keys, etc.). NEVER ask the user for raw +IDs or technical identifiers — they cannot memorise them. + +Instead, follow this discovery pattern: +1. Call a listing/discovery tool to find available options. +2. ONE result → use it silently, no question to the user. +3. MULTIPLE results → present the options by their display names and let the + user choose. Never show raw UUIDs — always use friendly names. + +Discovery tools by level: +- Which account/workspace? → get_connected_accounts("<service>") +- Which Jira site (cloudId)? → getAccessibleAtlassianResources +- Which Jira project? → getVisibleJiraProjects (after resolving cloudId) +- Which Jira issue type? → getJiraProjectIssueTypesMetadata (after resolving project) +- Which channel? → slack_search_channels +- Which base? → list_bases +- Which table? → list_tables_for_base (after resolving baseId) +- Which task? → clickup_search +- Which issue? → list_issues (Linear) or searchJiraIssuesUsingJql (Jira) + +For Jira specifically: ALWAYS call getAccessibleAtlassianResources first to +obtain the cloudId, then pass it to other Jira tools.
When creating an issue, +chain: getAccessibleAtlassianResources → getVisibleJiraProjects → createJiraIssue. +If there is only one option at each step, use it silently. If multiple, present +friendly names. + +Chain discovery when needed — e.g. for Airtable records: list_bases → pick +base → list_tables_for_base → pick table → list_records_for_table. + +MULTI-ACCOUNT TOOL NAMING: When the user has multiple accounts connected for +the same service, tool names are prefixed to avoid collisions — e.g. +linear_25_list_issues and linear_30_list_issues instead of two list_issues. +Each prefixed tool's description starts with [Account: <display name>] so you +know which account it targets. Use get_connected_accounts("<service>") to see +the full list of accounts with their connector IDs and display names. +When only one account is connected, tools have their normal unprefixed names. + + IMPORTANT — After understanding each user message, ALWAYS check: does this message reveal durable facts about the team (decisions, conventions, architecture, processes, diff --git a/surfsense_backend/app/agents/new_chat/tools/connected_accounts.py b/surfsense_backend/app/agents/new_chat/tools/connected_accounts.py new file mode 100644 index 000000000..e0b1978e1 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/tools/connected_accounts.py @@ -0,0 +1,109 @@ +"""Connected-accounts discovery tool. + +Lets the LLM discover which accounts are connected for a given service +(e.g. "jira", "linear", "slack") and retrieve the metadata it needs to +call action tools — such as Jira's ``cloudId``. + +The tool returns **only** non-sensitive fields explicitly listed in the +service's ``account_metadata_keys`` (see ``registry.py``), plus the +always-present ``display_name`` and ``connector_id``.
+""" + +import logging +from typing import Any + +from langchain_core.tools import StructuredTool +from pydantic import BaseModel, Field +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.future import select + +from app.db import SearchSourceConnector, SearchSourceConnectorType +from app.services.mcp_oauth.registry import MCP_SERVICES + +logger = logging.getLogger(__name__) + +_SERVICE_KEY_BY_CONNECTOR_TYPE: dict[str, str] = { + cfg.connector_type: key for key, cfg in MCP_SERVICES.items() +} + + +class GetConnectedAccountsInput(BaseModel): + service: str = Field( + description=( + "Service key to look up connected accounts for. " + "Valid values: " + ", ".join(sorted(MCP_SERVICES.keys())) + ), + ) + + +def _extract_display_name(connector: SearchSourceConnector) -> str: + """Best-effort human-readable label for a connector.""" + cfg = connector.config or {} + if cfg.get("display_name"): + return cfg["display_name"] + if cfg.get("base_url"): + return f"{connector.name} ({cfg['base_url']})" + if cfg.get("organization_name"): + return f"{connector.name} ({cfg['organization_name']})" + return connector.name + + +def create_get_connected_accounts_tool( + db_session: AsyncSession, + search_space_id: int, + user_id: str, +) -> StructuredTool: + + async def _run(service: str) -> list[dict[str, Any]]: + svc_cfg = MCP_SERVICES.get(service) + if not svc_cfg: + return [{"error": f"Unknown service '{service}'. 
Valid: {', '.join(sorted(MCP_SERVICES.keys()))}"}] + + try: + connector_type = SearchSourceConnectorType(svc_cfg.connector_type) + except ValueError: + return [{"error": f"Connector type '{svc_cfg.connector_type}' not found."}] + + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type == connector_type, + ) + ) + connectors = result.scalars().all() + + if not connectors: + return [{"error": f"No {svc_cfg.name} accounts connected. Ask the user to connect one in settings."}] + + is_multi = len(connectors) > 1 + + accounts: list[dict[str, Any]] = [] + for conn in connectors: + cfg = conn.config or {} + entry: dict[str, Any] = { + "connector_id": conn.id, + "display_name": _extract_display_name(conn), + "service": service, + } + if is_multi: + entry["tool_prefix"] = f"{service}_{conn.id}" + for key in svc_cfg.account_metadata_keys: + if key in cfg: + entry[key] = cfg[key] + accounts.append(entry) + + return accounts + + return StructuredTool( + name="get_connected_accounts", + description=( + "Discover which accounts are connected for a service (e.g. jira, linear, slack, clickup, airtable). " + "Returns display names and service-specific metadata the action tools need " + "(e.g. Jira's cloudId). Call this BEFORE using a service's action tools when " + "you need an account identifier or are unsure which account to use." 
+ ), + coroutine=_run, + args_schema=GetConnectedAccountsInput, + metadata={"hitl": False}, + ) diff --git a/surfsense_backend/app/agents/new_chat/tools/discord/__init__.py b/surfsense_backend/app/agents/new_chat/tools/discord/__init__.py new file mode 100644 index 000000000..b4eaec1f0 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/tools/discord/__init__.py @@ -0,0 +1,15 @@ +from app.agents.new_chat.tools.discord.list_channels import ( + create_list_discord_channels_tool, +) +from app.agents.new_chat.tools.discord.read_messages import ( + create_read_discord_messages_tool, +) +from app.agents.new_chat.tools.discord.send_message import ( + create_send_discord_message_tool, +) + +__all__ = [ + "create_list_discord_channels_tool", + "create_read_discord_messages_tool", + "create_send_discord_message_tool", +] diff --git a/surfsense_backend/app/agents/new_chat/tools/discord/_auth.py b/surfsense_backend/app/agents/new_chat/tools/discord/_auth.py new file mode 100644 index 000000000..1f51e3660 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/tools/discord/_auth.py @@ -0,0 +1,42 @@ +"""Shared auth helper for Discord agent tools (REST API, not gateway bot).""" + +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.future import select + +from app.config import config +from app.db import SearchSourceConnector, SearchSourceConnectorType +from app.utils.oauth_security import TokenEncryption + +DISCORD_API = "https://discord.com/api/v10" + + +async def get_discord_connector( + db_session: AsyncSession, + search_space_id: int, + user_id: str, +) -> SearchSourceConnector | None: + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type == SearchSourceConnectorType.DISCORD_CONNECTOR, + ) + ) + return result.scalars().first() + + +def get_bot_token(connector: SearchSourceConnector) -> str: 
+ """Extract and decrypt the bot token from connector config.""" + cfg = dict(connector.config) + if cfg.get("_token_encrypted") and config.SECRET_KEY: + enc = TokenEncryption(config.SECRET_KEY) + if cfg.get("bot_token"): + cfg["bot_token"] = enc.decrypt_token(cfg["bot_token"]) + token = cfg.get("bot_token") + if not token: + raise ValueError("Discord bot token not found in connector config.") + return token + + +def get_guild_id(connector: SearchSourceConnector) -> str | None: + return connector.config.get("guild_id") diff --git a/surfsense_backend/app/agents/new_chat/tools/discord/list_channels.py b/surfsense_backend/app/agents/new_chat/tools/discord/list_channels.py new file mode 100644 index 000000000..a33b88aa0 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/tools/discord/list_channels.py @@ -0,0 +1,67 @@ +import logging +from typing import Any + +import httpx +from langchain_core.tools import tool +from sqlalchemy.ext.asyncio import AsyncSession + +from ._auth import DISCORD_API, get_bot_token, get_discord_connector, get_guild_id + +logger = logging.getLogger(__name__) + + +def create_list_discord_channels_tool( + db_session: AsyncSession | None = None, + search_space_id: int | None = None, + user_id: str | None = None, +): + @tool + async def list_discord_channels() -> dict[str, Any]: + """List text channels in the connected Discord server. + + Returns: + Dictionary with status and a list of channels (id, name). 
+ """ + if db_session is None or search_space_id is None or user_id is None: + return {"status": "error", "message": "Discord tool not properly configured."} + + try: + connector = await get_discord_connector(db_session, search_space_id, user_id) + if not connector: + return {"status": "error", "message": "No Discord connector found."} + + guild_id = get_guild_id(connector) + if not guild_id: + return {"status": "error", "message": "No guild ID in Discord connector config."} + + token = get_bot_token(connector) + + async with httpx.AsyncClient() as client: + resp = await client.get( + f"{DISCORD_API}/guilds/{guild_id}/channels", + headers={"Authorization": f"Bot {token}"}, + timeout=15.0, + ) + + if resp.status_code == 401: + return {"status": "auth_error", "message": "Discord bot token is invalid.", "connector_type": "discord"} + if resp.status_code != 200: + return {"status": "error", "message": f"Discord API error: {resp.status_code}"} + + # Type 0 = text channel + channels = [ + {"id": ch["id"], "name": ch["name"]} + for ch in resp.json() + if ch.get("type") == 0 + ] + return {"status": "success", "guild_id": guild_id, "channels": channels, "total": len(channels)} + + except Exception as e: + from langgraph.errors import GraphInterrupt + + if isinstance(e, GraphInterrupt): + raise + logger.error("Error listing Discord channels: %s", e, exc_info=True) + return {"status": "error", "message": "Failed to list Discord channels."} + + return list_discord_channels diff --git a/surfsense_backend/app/agents/new_chat/tools/discord/read_messages.py b/surfsense_backend/app/agents/new_chat/tools/discord/read_messages.py new file mode 100644 index 000000000..852a9297b --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/tools/discord/read_messages.py @@ -0,0 +1,80 @@ +import logging +from typing import Any + +import httpx +from langchain_core.tools import tool +from sqlalchemy.ext.asyncio import AsyncSession + +from ._auth import DISCORD_API, get_bot_token, 
get_discord_connector + +logger = logging.getLogger(__name__) + + +def create_read_discord_messages_tool( + db_session: AsyncSession | None = None, + search_space_id: int | None = None, + user_id: str | None = None, +): + @tool + async def read_discord_messages( + channel_id: str, + limit: int = 25, + ) -> dict[str, Any]: + """Read recent messages from a Discord text channel. + + Args: + channel_id: The Discord channel ID (from list_discord_channels). + limit: Number of messages to fetch (default 25, max 50). + + Returns: + Dictionary with status and a list of messages including + id, author, content, timestamp. + """ + if db_session is None or search_space_id is None or user_id is None: + return {"status": "error", "message": "Discord tool not properly configured."} + + limit = min(limit, 50) + + try: + connector = await get_discord_connector(db_session, search_space_id, user_id) + if not connector: + return {"status": "error", "message": "No Discord connector found."} + + token = get_bot_token(connector) + + async with httpx.AsyncClient() as client: + resp = await client.get( + f"{DISCORD_API}/channels/{channel_id}/messages", + headers={"Authorization": f"Bot {token}"}, + params={"limit": limit}, + timeout=15.0, + ) + + if resp.status_code == 401: + return {"status": "auth_error", "message": "Discord bot token is invalid.", "connector_type": "discord"} + if resp.status_code == 403: + return {"status": "error", "message": "Bot lacks permission to read this channel."} + if resp.status_code != 200: + return {"status": "error", "message": f"Discord API error: {resp.status_code}"} + + messages = [ + { + "id": m["id"], + "author": m.get("author", {}).get("username", "Unknown"), + "content": m.get("content", ""), + "timestamp": m.get("timestamp", ""), + } + for m in resp.json() + ] + + return {"status": "success", "channel_id": channel_id, "messages": messages, "total": len(messages)} + + except Exception as e: + from langgraph.errors import GraphInterrupt + + if 
isinstance(e, GraphInterrupt): + raise + logger.error("Error reading Discord messages: %s", e, exc_info=True) + return {"status": "error", "message": "Failed to read Discord messages."} + + return read_discord_messages diff --git a/surfsense_backend/app/agents/new_chat/tools/discord/send_message.py b/surfsense_backend/app/agents/new_chat/tools/discord/send_message.py new file mode 100644 index 000000000..be4e6fdb2 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/tools/discord/send_message.py @@ -0,0 +1,96 @@ +import logging +from typing import Any + +import httpx +from langchain_core.tools import tool +from sqlalchemy.ext.asyncio import AsyncSession + +from app.agents.new_chat.tools.hitl import request_approval + +from ._auth import DISCORD_API, get_bot_token, get_discord_connector + +logger = logging.getLogger(__name__) + + +def create_send_discord_message_tool( + db_session: AsyncSession | None = None, + search_space_id: int | None = None, + user_id: str | None = None, +): + @tool + async def send_discord_message( + channel_id: str, + content: str, + ) -> dict[str, Any]: + """Send a message to a Discord text channel. + + Args: + channel_id: The Discord channel ID (from list_discord_channels). + content: The message text (max 2000 characters). + + Returns: + Dictionary with status, message_id on success. + + IMPORTANT: + - If status is "rejected", the user explicitly declined. Do NOT retry. 
+ """ + if db_session is None or search_space_id is None or user_id is None: + return {"status": "error", "message": "Discord tool not properly configured."} + + if len(content) > 2000: + return {"status": "error", "message": "Message exceeds Discord's 2000-character limit."} + + try: + connector = await get_discord_connector(db_session, search_space_id, user_id) + if not connector: + return {"status": "error", "message": "No Discord connector found."} + + result = request_approval( + action_type="discord_send_message", + tool_name="send_discord_message", + params={"channel_id": channel_id, "content": content}, + context={"connector_id": connector.id}, + ) + + if result.rejected: + return {"status": "rejected", "message": "User declined. Message was not sent."} + + final_content = result.params.get("content", content) + final_channel = result.params.get("channel_id", channel_id) + + token = get_bot_token(connector) + + async with httpx.AsyncClient() as client: + resp = await client.post( + f"{DISCORD_API}/channels/{final_channel}/messages", + headers={ + "Authorization": f"Bot {token}", + "Content-Type": "application/json", + }, + json={"content": final_content}, + timeout=15.0, + ) + + if resp.status_code == 401: + return {"status": "auth_error", "message": "Discord bot token is invalid.", "connector_type": "discord"} + if resp.status_code == 403: + return {"status": "error", "message": "Bot lacks permission to send messages in this channel."} + if resp.status_code not in (200, 201): + return {"status": "error", "message": f"Discord API error: {resp.status_code}"} + + msg_data = resp.json() + return { + "status": "success", + "message_id": msg_data.get("id"), + "message": f"Message sent to channel {final_channel}.", + } + + except Exception as e: + from langgraph.errors import GraphInterrupt + + if isinstance(e, GraphInterrupt): + raise + logger.error("Error sending Discord message: %s", e, exc_info=True) + return {"status": "error", "message": "Failed to send 
Discord message."} + + return send_discord_message diff --git a/surfsense_backend/app/agents/new_chat/tools/gmail/__init__.py b/surfsense_backend/app/agents/new_chat/tools/gmail/__init__.py index efb2fb0fa..294840122 100644 --- a/surfsense_backend/app/agents/new_chat/tools/gmail/__init__.py +++ b/surfsense_backend/app/agents/new_chat/tools/gmail/__init__.py @@ -1,6 +1,12 @@ from app.agents.new_chat.tools.gmail.create_draft import ( create_create_gmail_draft_tool, ) +from app.agents.new_chat.tools.gmail.read_email import ( + create_read_gmail_email_tool, +) +from app.agents.new_chat.tools.gmail.search_emails import ( + create_search_gmail_tool, +) from app.agents.new_chat.tools.gmail.send_email import ( create_send_gmail_email_tool, ) @@ -13,6 +19,8 @@ from app.agents.new_chat.tools.gmail.update_draft import ( __all__ = [ "create_create_gmail_draft_tool", + "create_read_gmail_email_tool", + "create_search_gmail_tool", "create_send_gmail_email_tool", "create_trash_gmail_email_tool", "create_update_gmail_draft_tool", diff --git a/surfsense_backend/app/agents/new_chat/tools/gmail/read_email.py b/surfsense_backend/app/agents/new_chat/tools/gmail/read_email.py new file mode 100644 index 000000000..9071f129a --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/tools/gmail/read_email.py @@ -0,0 +1,87 @@ +import logging +from typing import Any + +from langchain_core.tools import tool +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.future import select + +from app.db import SearchSourceConnector, SearchSourceConnectorType + +logger = logging.getLogger(__name__) + +_GMAIL_TYPES = [ + SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR, + SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR, +] + + +def create_read_gmail_email_tool( + db_session: AsyncSession | None = None, + search_space_id: int | None = None, + user_id: str | None = None, +): + @tool + async def read_gmail_email(message_id: str) -> dict[str, Any]: + """Read the full content of a specific 
Gmail email by its message ID. + + Use after search_gmail to get the complete body of an email. + + Args: + message_id: The Gmail message ID (from search_gmail results). + + Returns: + Dictionary with status and the full email content formatted as markdown. + """ + if db_session is None or search_space_id is None or user_id is None: + return {"status": "error", "message": "Gmail tool not properly configured."} + + try: + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type.in_(_GMAIL_TYPES), + ) + ) + connector = result.scalars().first() + if not connector: + return { + "status": "error", + "message": "No Gmail connector found. Please connect Gmail in your workspace settings.", + } + + from app.agents.new_chat.tools.gmail.search_emails import _build_credentials + + creds = _build_credentials(connector) + + from app.connectors.google_gmail_connector import GoogleGmailConnector + + gmail = GoogleGmailConnector( + credentials=creds, + session=db_session, + user_id=user_id, + connector_id=connector.id, + ) + + detail, error = await gmail.get_message_details(message_id) + if error: + if "re-authenticate" in error.lower() or "authentication failed" in error.lower(): + return {"status": "auth_error", "message": error, "connector_type": "gmail"} + return {"status": "error", "message": error} + + if not detail: + return {"status": "not_found", "message": f"Email with ID '{message_id}' not found."} + + content = gmail.format_message_to_markdown(detail) + + return {"status": "success", "message_id": message_id, "content": content} + + except Exception as e: + from langgraph.errors import GraphInterrupt + + if isinstance(e, GraphInterrupt): + raise + logger.error("Error reading Gmail email: %s", e, exc_info=True) + return {"status": "error", "message": "Failed to read email. 
Please try again."} + + return read_gmail_email diff --git a/surfsense_backend/app/agents/new_chat/tools/gmail/search_emails.py b/surfsense_backend/app/agents/new_chat/tools/gmail/search_emails.py new file mode 100644 index 000000000..de43f03d0 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/tools/gmail/search_emails.py @@ -0,0 +1,165 @@ +import logging +from datetime import datetime +from typing import Any + +from langchain_core.tools import tool +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.future import select + +from app.db import SearchSourceConnector, SearchSourceConnectorType + +logger = logging.getLogger(__name__) + +_GMAIL_TYPES = [ + SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR, + SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR, +] + +_token_encryption_cache: object | None = None + + +def _get_token_encryption(): + global _token_encryption_cache + if _token_encryption_cache is None: + from app.config import config + from app.utils.oauth_security import TokenEncryption + + if not config.SECRET_KEY: + raise RuntimeError("SECRET_KEY not configured for token decryption.") + _token_encryption_cache = TokenEncryption(config.SECRET_KEY) + return _token_encryption_cache + + +def _build_credentials(connector: SearchSourceConnector): + """Build Google OAuth Credentials from a connector's stored config. + + Handles both native OAuth connectors (with encrypted tokens) and + Composio-backed connectors. Shared by Gmail and Calendar tools. 
+ """ + from app.utils.google_credentials import COMPOSIO_GOOGLE_CONNECTOR_TYPES + + if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES: + from app.utils.google_credentials import build_composio_credentials + + cca_id = connector.config.get("composio_connected_account_id") + if not cca_id: + raise ValueError("Composio connected account ID not found.") + return build_composio_credentials(cca_id) + + from google.oauth2.credentials import Credentials + + cfg = dict(connector.config) + if cfg.get("_token_encrypted"): + enc = _get_token_encryption() + for key in ("token", "refresh_token", "client_secret"): + if cfg.get(key): + cfg[key] = enc.decrypt_token(cfg[key]) + + exp = (cfg.get("expiry") or "").replace("Z", "") + return Credentials( + token=cfg.get("token"), + refresh_token=cfg.get("refresh_token"), + token_uri=cfg.get("token_uri"), + client_id=cfg.get("client_id"), + client_secret=cfg.get("client_secret"), + scopes=cfg.get("scopes", []), + expiry=datetime.fromisoformat(exp) if exp else None, + ) + + +def create_search_gmail_tool( + db_session: AsyncSession | None = None, + search_space_id: int | None = None, + user_id: str | None = None, +): + @tool + async def search_gmail( + query: str, + max_results: int = 10, + ) -> dict[str, Any]: + """Search emails in the user's Gmail inbox using Gmail search syntax. + + Args: + query: Gmail search query, same syntax as the Gmail search bar. + Examples: "from:alice@example.com", "subject:meeting", + "is:unread", "after:2024/01/01 before:2024/02/01", + "has:attachment", "in:sent". + max_results: Number of emails to return (default 10, max 20). + + Returns: + Dictionary with status and a list of email summaries including + message_id, subject, from, date, snippet. 
+ """ + if db_session is None or search_space_id is None or user_id is None: + return {"status": "error", "message": "Gmail tool not properly configured."} + + max_results = min(max_results, 20) + + try: + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type.in_(_GMAIL_TYPES), + ) + ) + connector = result.scalars().first() + if not connector: + return { + "status": "error", + "message": "No Gmail connector found. Please connect Gmail in your workspace settings.", + } + + creds = _build_credentials(connector) + + from app.connectors.google_gmail_connector import GoogleGmailConnector + + gmail = GoogleGmailConnector( + credentials=creds, + session=db_session, + user_id=user_id, + connector_id=connector.id, + ) + + messages_list, error = await gmail.get_messages_list( + max_results=max_results, query=query + ) + if error: + if "re-authenticate" in error.lower() or "authentication failed" in error.lower(): + return {"status": "auth_error", "message": error, "connector_type": "gmail"} + return {"status": "error", "message": error} + + if not messages_list: + return {"status": "success", "emails": [], "total": 0, "message": "No emails found."} + + emails = [] + for msg in messages_list: + detail, err = await gmail.get_message_details(msg["id"]) + if err: + continue + headers = { + h["name"].lower(): h["value"] + for h in detail.get("payload", {}).get("headers", []) + } + emails.append({ + "message_id": detail.get("id"), + "thread_id": detail.get("threadId"), + "subject": headers.get("subject", "No Subject"), + "from": headers.get("from", "Unknown"), + "to": headers.get("to", ""), + "date": headers.get("date", ""), + "snippet": detail.get("snippet", ""), + "labels": detail.get("labelIds", []), + }) + + return {"status": "success", "emails": emails, "total": len(emails)} + + except Exception as e: + from 
langgraph.errors import GraphInterrupt + + if isinstance(e, GraphInterrupt): + raise + logger.error("Error searching Gmail: %s", e, exc_info=True) + return {"status": "error", "message": "Failed to search Gmail. Please try again."} + + return search_gmail diff --git a/surfsense_backend/app/agents/new_chat/tools/google_calendar/__init__.py b/surfsense_backend/app/agents/new_chat/tools/google_calendar/__init__.py index d1ce4e795..13d4c06cb 100644 --- a/surfsense_backend/app/agents/new_chat/tools/google_calendar/__init__.py +++ b/surfsense_backend/app/agents/new_chat/tools/google_calendar/__init__.py @@ -4,6 +4,9 @@ from app.agents.new_chat.tools.google_calendar.create_event import ( from app.agents.new_chat.tools.google_calendar.delete_event import ( create_delete_calendar_event_tool, ) +from app.agents.new_chat.tools.google_calendar.search_events import ( + create_search_calendar_events_tool, +) from app.agents.new_chat.tools.google_calendar.update_event import ( create_update_calendar_event_tool, ) @@ -11,5 +14,6 @@ from app.agents.new_chat.tools.google_calendar.update_event import ( __all__ = [ "create_create_calendar_event_tool", "create_delete_calendar_event_tool", + "create_search_calendar_events_tool", "create_update_calendar_event_tool", ] diff --git a/surfsense_backend/app/agents/new_chat/tools/google_calendar/search_events.py b/surfsense_backend/app/agents/new_chat/tools/google_calendar/search_events.py new file mode 100644 index 000000000..a622b0efa --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/tools/google_calendar/search_events.py @@ -0,0 +1,114 @@ +import logging +from typing import Any + +from langchain_core.tools import tool +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.future import select + +from app.agents.new_chat.tools.gmail.search_emails import _build_credentials +from app.db import SearchSourceConnector, SearchSourceConnectorType + +logger = logging.getLogger(__name__) + +_CALENDAR_TYPES = [ + 
SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR, + SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR, +] + + +def create_search_calendar_events_tool( + db_session: AsyncSession | None = None, + search_space_id: int | None = None, + user_id: str | None = None, +): + @tool + async def search_calendar_events( + start_date: str, + end_date: str, + max_results: int = 25, + ) -> dict[str, Any]: + """Search Google Calendar events within a date range. + + Args: + start_date: Start date in YYYY-MM-DD format (e.g. "2026-04-01"). + end_date: End date in YYYY-MM-DD format (e.g. "2026-04-30"). + max_results: Maximum number of events to return (default 25, max 50). + + Returns: + Dictionary with status and a list of events including + event_id, summary, start, end, location, attendees. + """ + if db_session is None or search_space_id is None or user_id is None: + return {"status": "error", "message": "Calendar tool not properly configured."} + + max_results = min(max_results, 50) + + try: + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type.in_(_CALENDAR_TYPES), + ) + ) + connector = result.scalars().first() + if not connector: + return { + "status": "error", + "message": "No Google Calendar connector found. 
Please connect Google Calendar in your workspace settings.", + } + + creds = _build_credentials(connector) + + from app.connectors.google_calendar_connector import GoogleCalendarConnector + + cal = GoogleCalendarConnector( + credentials=creds, + session=db_session, + user_id=user_id, + connector_id=connector.id, + ) + + events_raw, error = await cal.get_all_primary_calendar_events( + start_date=start_date, + end_date=end_date, + max_results=max_results, + ) + + if error: + if "re-authenticate" in error.lower() or "authentication failed" in error.lower(): + return {"status": "auth_error", "message": error, "connector_type": "google_calendar"} + if "no events found" in error.lower(): + return {"status": "success", "events": [], "total": 0, "message": error} + return {"status": "error", "message": error} + + events = [] + for ev in events_raw: + start = ev.get("start", {}) + end = ev.get("end", {}) + attendees_raw = ev.get("attendees", []) + events.append({ + "event_id": ev.get("id"), + "summary": ev.get("summary", "No Title"), + "start": start.get("dateTime") or start.get("date", ""), + "end": end.get("dateTime") or end.get("date", ""), + "location": ev.get("location", ""), + "description": ev.get("description", ""), + "html_link": ev.get("htmlLink", ""), + "attendees": [ + a.get("email", "") for a in attendees_raw[:10] + ], + "status": ev.get("status", ""), + }) + + return {"status": "success", "events": events, "total": len(events)} + + except Exception as e: + from langgraph.errors import GraphInterrupt + + if isinstance(e, GraphInterrupt): + raise + logger.error("Error searching calendar events: %s", e, exc_info=True) + return {"status": "error", "message": "Failed to search calendar events. 
Please try again."} + + return search_calendar_events diff --git a/surfsense_backend/app/agents/new_chat/tools/hitl.py b/surfsense_backend/app/agents/new_chat/tools/hitl.py index 64ace547c..89f02abf6 100644 --- a/surfsense_backend/app/agents/new_chat/tools/hitl.py +++ b/surfsense_backend/app/agents/new_chat/tools/hitl.py @@ -130,8 +130,8 @@ def request_approval( try: decision_type, edited_params = _parse_decision(approval) except ValueError: - logger.warning("No approval decision received for %s", tool_name) - return HITLResult(rejected=False, decision_type="error", params=params) + logger.warning("No approval decision received for %s — rejecting for safety", tool_name) + return HITLResult(rejected=True, decision_type="error", params=params) logger.info("User decision for %s: %s", tool_name, decision_type) diff --git a/surfsense_backend/app/agents/new_chat/tools/luma/__init__.py b/surfsense_backend/app/agents/new_chat/tools/luma/__init__.py new file mode 100644 index 000000000..255119bee --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/tools/luma/__init__.py @@ -0,0 +1,15 @@ +from app.agents.new_chat.tools.luma.create_event import ( + create_create_luma_event_tool, +) +from app.agents.new_chat.tools.luma.list_events import ( + create_list_luma_events_tool, +) +from app.agents.new_chat.tools.luma.read_event import ( + create_read_luma_event_tool, +) + +__all__ = [ + "create_create_luma_event_tool", + "create_list_luma_events_tool", + "create_read_luma_event_tool", +] diff --git a/surfsense_backend/app/agents/new_chat/tools/luma/_auth.py b/surfsense_backend/app/agents/new_chat/tools/luma/_auth.py new file mode 100644 index 000000000..1d88161d6 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/tools/luma/_auth.py @@ -0,0 +1,38 @@ +"""Shared auth helper for Luma agent tools.""" + +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.future import select + +from app.db import SearchSourceConnector, SearchSourceConnectorType + +LUMA_API = 
"https://public-api.luma.com/v1" + + +async def get_luma_connector( + db_session: AsyncSession, + search_space_id: int, + user_id: str, +) -> SearchSourceConnector | None: + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type == SearchSourceConnectorType.LUMA_CONNECTOR, + ) + ) + return result.scalars().first() + + +def get_api_key(connector: SearchSourceConnector) -> str: + """Extract the API key from connector config (handles both key names).""" + key = connector.config.get("api_key") or connector.config.get("LUMA_API_KEY") + if not key: + raise ValueError("Luma API key not found in connector config.") + return key + + +def luma_headers(api_key: str) -> dict[str, str]: + return { + "Content-Type": "application/json", + "x-luma-api-key": api_key, + } diff --git a/surfsense_backend/app/agents/new_chat/tools/luma/create_event.py b/surfsense_backend/app/agents/new_chat/tools/luma/create_event.py new file mode 100644 index 000000000..2217d29e6 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/tools/luma/create_event.py @@ -0,0 +1,116 @@ +import logging +from typing import Any + +import httpx +from langchain_core.tools import tool +from sqlalchemy.ext.asyncio import AsyncSession + +from app.agents.new_chat.tools.hitl import request_approval + +from ._auth import LUMA_API, get_api_key, get_luma_connector, luma_headers + +logger = logging.getLogger(__name__) + + +def create_create_luma_event_tool( + db_session: AsyncSession | None = None, + search_space_id: int | None = None, + user_id: str | None = None, +): + @tool + async def create_luma_event( + name: str, + start_at: str, + end_at: str, + description: str | None = None, + timezone: str = "UTC", + ) -> dict[str, Any]: + """Create a new event on Luma. + + Args: + name: The event title. + start_at: Start time in ISO 8601 format (e.g. 
"2026-05-01T18:00:00"). + end_at: End time in ISO 8601 format (e.g. "2026-05-01T20:00:00"). + description: Optional event description (markdown supported). + timezone: Timezone string (default "UTC", e.g. "America/New_York"). + + Returns: + Dictionary with status, event_id on success. + + IMPORTANT: + - If status is "rejected", the user explicitly declined. Do NOT retry. + """ + if db_session is None or search_space_id is None or user_id is None: + return {"status": "error", "message": "Luma tool not properly configured."} + + try: + connector = await get_luma_connector(db_session, search_space_id, user_id) + if not connector: + return {"status": "error", "message": "No Luma connector found."} + + result = request_approval( + action_type="luma_create_event", + tool_name="create_luma_event", + params={ + "name": name, + "start_at": start_at, + "end_at": end_at, + "description": description, + "timezone": timezone, + }, + context={"connector_id": connector.id}, + ) + + if result.rejected: + return {"status": "rejected", "message": "User declined. 
Event was not created."} + + final_name = result.params.get("name", name) + final_start = result.params.get("start_at", start_at) + final_end = result.params.get("end_at", end_at) + final_desc = result.params.get("description", description) + final_tz = result.params.get("timezone", timezone) + + api_key = get_api_key(connector) + headers = luma_headers(api_key) + + body: dict[str, Any] = { + "name": final_name, + "start_at": final_start, + "end_at": final_end, + "timezone": final_tz, + } + if final_desc: + body["description_md"] = final_desc + + async with httpx.AsyncClient(timeout=20.0) as client: + resp = await client.post( + f"{LUMA_API}/event/create", + headers=headers, + json=body, + ) + + if resp.status_code == 401: + return {"status": "auth_error", "message": "Luma API key is invalid.", "connector_type": "luma"} + if resp.status_code == 403: + return {"status": "error", "message": "Luma Plus subscription required to create events via API."} + if resp.status_code not in (200, 201): + return {"status": "error", "message": f"Luma API error: {resp.status_code} — {resp.text[:200]}"} + + data = resp.json() + event_id = data.get("api_id") or data.get("event", {}).get("api_id") + + return { + "status": "success", + "event_id": event_id, + "message": f"Event '{final_name}' created on Luma.", + } + + except Exception as e: + from langgraph.errors import GraphInterrupt + + if isinstance(e, GraphInterrupt): + raise + logger.error("Error creating Luma event: %s", e, exc_info=True) + return {"status": "error", "message": "Failed to create Luma event."} + + return create_luma_event diff --git a/surfsense_backend/app/agents/new_chat/tools/luma/list_events.py b/surfsense_backend/app/agents/new_chat/tools/luma/list_events.py new file mode 100644 index 000000000..cd4721758 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/tools/luma/list_events.py @@ -0,0 +1,100 @@ +import logging +from typing import Any + +import httpx +from langchain_core.tools import tool +from 
sqlalchemy.ext.asyncio import AsyncSession + +from ._auth import LUMA_API, get_api_key, get_luma_connector, luma_headers + +logger = logging.getLogger(__name__) + + +def create_list_luma_events_tool( + db_session: AsyncSession | None = None, + search_space_id: int | None = None, + user_id: str | None = None, +): + @tool + async def list_luma_events( + max_results: int = 25, + ) -> dict[str, Any]: + """List upcoming and recent Luma events. + + Args: + max_results: Maximum events to return (default 25, max 50). + + Returns: + Dictionary with status and a list of events including + event_id, name, start_at, end_at, location, url. + """ + if db_session is None or search_space_id is None or user_id is None: + return {"status": "error", "message": "Luma tool not properly configured."} + + max_results = min(max_results, 50) + + try: + connector = await get_luma_connector(db_session, search_space_id, user_id) + if not connector: + return {"status": "error", "message": "No Luma connector found."} + + api_key = get_api_key(connector) + headers = luma_headers(api_key) + + all_entries: list[dict] = [] + cursor = None + + async with httpx.AsyncClient(timeout=20.0) as client: + while len(all_entries) < max_results: + params: dict[str, Any] = {"limit": min(100, max_results - len(all_entries))} + if cursor: + params["cursor"] = cursor + + resp = await client.get( + f"{LUMA_API}/calendar/list-events", + headers=headers, + params=params, + ) + + if resp.status_code == 401: + return {"status": "auth_error", "message": "Luma API key is invalid.", "connector_type": "luma"} + if resp.status_code != 200: + return {"status": "error", "message": f"Luma API error: {resp.status_code}"} + + data = resp.json() + entries = data.get("entries", []) + if not entries: + break + all_entries.extend(entries) + + next_cursor = data.get("next_cursor") + if not next_cursor: + break + cursor = next_cursor + + events = [] + for entry in all_entries[:max_results]: + ev = entry.get("event", {}) + geo = 
ev.get("geo_info", {}) + events.append({ + "event_id": entry.get("api_id"), + "name": ev.get("name", "Untitled"), + "start_at": ev.get("start_at", ""), + "end_at": ev.get("end_at", ""), + "timezone": ev.get("timezone", ""), + "location": geo.get("name", ""), + "url": ev.get("url", ""), + "visibility": ev.get("visibility", ""), + }) + + return {"status": "success", "events": events, "total": len(events)} + + except Exception as e: + from langgraph.errors import GraphInterrupt + + if isinstance(e, GraphInterrupt): + raise + logger.error("Error listing Luma events: %s", e, exc_info=True) + return {"status": "error", "message": "Failed to list Luma events."} + + return list_luma_events diff --git a/surfsense_backend/app/agents/new_chat/tools/luma/read_event.py b/surfsense_backend/app/agents/new_chat/tools/luma/read_event.py new file mode 100644 index 000000000..eb3ac55c6 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/tools/luma/read_event.py @@ -0,0 +1,82 @@ +import logging +from typing import Any + +import httpx +from langchain_core.tools import tool +from sqlalchemy.ext.asyncio import AsyncSession + +from ._auth import LUMA_API, get_api_key, get_luma_connector, luma_headers + +logger = logging.getLogger(__name__) + + +def create_read_luma_event_tool( + db_session: AsyncSession | None = None, + search_space_id: int | None = None, + user_id: str | None = None, +): + @tool + async def read_luma_event(event_id: str) -> dict[str, Any]: + """Read detailed information about a specific Luma event. + + Args: + event_id: The Luma event API ID (from list_luma_events). + + Returns: + Dictionary with status and full event details including + description, attendees count, meeting URL. 
+ """ + if db_session is None or search_space_id is None or user_id is None: + return {"status": "error", "message": "Luma tool not properly configured."} + + try: + connector = await get_luma_connector(db_session, search_space_id, user_id) + if not connector: + return {"status": "error", "message": "No Luma connector found."} + + api_key = get_api_key(connector) + headers = luma_headers(api_key) + + async with httpx.AsyncClient(timeout=15.0) as client: + resp = await client.get( + f"{LUMA_API}/events/{event_id}", + headers=headers, + ) + + if resp.status_code == 401: + return {"status": "auth_error", "message": "Luma API key is invalid.", "connector_type": "luma"} + if resp.status_code == 404: + return {"status": "not_found", "message": f"Event '{event_id}' not found."} + if resp.status_code != 200: + return {"status": "error", "message": f"Luma API error: {resp.status_code}"} + + data = resp.json() + ev = data.get("event", data) + geo = ev.get("geo_info", {}) + + event_detail = { + "event_id": event_id, + "name": ev.get("name", ""), + "description": ev.get("description", ""), + "start_at": ev.get("start_at", ""), + "end_at": ev.get("end_at", ""), + "timezone": ev.get("timezone", ""), + "location_name": geo.get("name", ""), + "address": geo.get("address", ""), + "url": ev.get("url", ""), + "meeting_url": ev.get("meeting_url", ""), + "visibility": ev.get("visibility", ""), + "cover_url": ev.get("cover_url", ""), + } + + return {"status": "success", "event": event_detail} + + except Exception as e: + from langgraph.errors import GraphInterrupt + + if isinstance(e, GraphInterrupt): + raise + logger.error("Error reading Luma event: %s", e, exc_info=True) + return {"status": "error", "message": "Failed to read Luma event."} + + return read_luma_event diff --git a/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py b/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py index 9743d049d..8f8e5007f 100644 --- a/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py 
+++ b/surfsense_backend/app/agents/new_chat/tools/mcp_tool.py @@ -14,20 +14,28 @@ clicking "Always Allow", which adds the tool name to the connector's ``config.trusted_tools`` allow-list. """ +from __future__ import annotations + import logging import time -from typing import Any +from collections import defaultdict +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from app.utils.oauth_security import TokenEncryption from langchain_core.tools import StructuredTool from mcp import ClientSession from mcp.client.streamable_http import streamablehttp_client -from pydantic import BaseModel, create_model -from sqlalchemy import select +from pydantic import BaseModel, Field, create_model +from sqlalchemy import cast, select +from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.ext.asyncio import AsyncSession from app.agents.new_chat.tools.hitl import request_approval from app.agents.new_chat.tools.mcp_client import MCPClient from app.db import SearchSourceConnector, SearchSourceConnectorType +from app.services.mcp_oauth.registry import MCP_SERVICES, get_service_by_connector_type logger = logging.getLogger(__name__) @@ -63,18 +71,14 @@ def _create_dynamic_input_model_from_schema( param_description = param_schema.get("description", "") is_required = param_name in required_fields - from typing import Any as AnyType - - from pydantic import Field - if is_required: field_definitions[param_name] = ( - AnyType, + Any, Field(..., description=param_description), ) else: field_definitions[param_name] = ( - AnyType | None, + Any | None, Field(None, description=param_description), ) @@ -100,13 +104,13 @@ async def _create_mcp_tool_from_definition_stdio( tool_description = tool_def.get("description", "No description provided") input_schema = tool_def.get("input_schema", {"type": "object", "properties": {}}) - logger.info(f"MCP tool '{tool_name}' input schema: {input_schema}") + logger.debug("MCP tool '%s' input schema: %s", tool_name, input_schema) input_model 
= _create_dynamic_input_model_from_schema(tool_name, input_schema) async def mcp_tool_call(**kwargs) -> str: """Execute the MCP tool call via the client with retry support.""" - logger.info(f"MCP tool '{tool_name}' called with params: {kwargs}") + logger.debug("MCP tool '%s' called", tool_name) # HITL — OUTSIDE try/except so GraphInterrupt propagates to LangGraph hitl_result = request_approval( @@ -123,20 +127,18 @@ async def _create_mcp_tool_from_definition_stdio( ) if hitl_result.rejected: return "Tool call rejected by user." - call_kwargs = hitl_result.params + call_kwargs = {k: v for k, v in hitl_result.params.items() if v is not None} try: async with mcp_client.connect(): result = await mcp_client.call_tool(tool_name, call_kwargs) return str(result) except RuntimeError as e: - error_msg = f"MCP tool '{tool_name}' connection failed after retries: {e!s}" - logger.error(error_msg) - return f"Error: {error_msg}" + logger.error("MCP tool '%s' connection failed after retries: %s", tool_name, e) + return f"Error: MCP tool '{tool_name}' connection failed after retries: {e!s}" except Exception as e: - error_msg = f"MCP tool '{tool_name}' execution failed: {e!s}" - logger.exception(error_msg) - return f"Error: {error_msg}" + logger.exception("MCP tool '%s' execution failed: %s", tool_name, e) + return f"Error: MCP tool '{tool_name}' execution failed: {e!s}" tool = StructuredTool( name=tool_name, @@ -151,7 +153,7 @@ async def _create_mcp_tool_from_definition_stdio( }, ) - logger.info(f"Created MCP tool (stdio): '{tool_name}'") + logger.debug("Created MCP tool (stdio): '%s'", tool_name) return tool @@ -163,41 +165,57 @@ async def _create_mcp_tool_from_definition_http( connector_name: str = "", connector_id: int | None = None, trusted_tools: list[str] | None = None, + readonly_tools: frozenset[str] | None = None, + tool_name_prefix: str | None = None, ) -> StructuredTool: """Create a LangChain tool from an MCP tool definition (HTTP transport). 
- All MCP tools are unconditionally wrapped with HITL approval. - ``request_approval()`` is called OUTSIDE the try/except so that - ``GraphInterrupt`` propagates cleanly to LangGraph. + Write tools are wrapped with HITL approval; read-only tools (listed in + ``readonly_tools``) execute immediately without user confirmation. + + When ``tool_name_prefix`` is set (multi-account disambiguation), the + tool exposed to the LLM gets a prefixed name (e.g. ``linear_25_list_issues``) + but the actual MCP ``call_tool`` still uses the original name. """ - tool_name = tool_def.get("name", "unnamed_tool") + original_tool_name = tool_def.get("name", "unnamed_tool") tool_description = tool_def.get("description", "No description provided") input_schema = tool_def.get("input_schema", {"type": "object", "properties": {}}) + is_readonly = readonly_tools is not None and original_tool_name in readonly_tools - logger.info(f"MCP HTTP tool '{tool_name}' input schema: {input_schema}") + exposed_name = ( + f"{tool_name_prefix}_{original_tool_name}" + if tool_name_prefix + else original_tool_name + ) + if tool_name_prefix: + tool_description = f"[Account: {connector_name}] {tool_description}" - input_model = _create_dynamic_input_model_from_schema(tool_name, input_schema) + logger.debug("MCP HTTP tool '%s' input schema: %s", exposed_name, input_schema) + + input_model = _create_dynamic_input_model_from_schema(exposed_name, input_schema) async def mcp_http_tool_call(**kwargs) -> str: """Execute the MCP tool call via HTTP transport.""" - logger.info(f"MCP HTTP tool '{tool_name}' called with params: {kwargs}") + logger.debug("MCP HTTP tool '%s' called", exposed_name) - # HITL — OUTSIDE try/except so GraphInterrupt propagates to LangGraph - hitl_result = request_approval( - action_type="mcp_tool_call", - tool_name=tool_name, - params=kwargs, - context={ - "mcp_server": connector_name, - "tool_description": tool_description, - "mcp_transport": "http", - "mcp_connector_id": connector_id, - }, - 
trusted_tools=trusted_tools, - ) - if hitl_result.rejected: - return "Tool call rejected by user." - call_kwargs = hitl_result.params + if is_readonly: + call_kwargs = {k: v for k, v in kwargs.items() if v is not None} + else: + hitl_result = request_approval( + action_type="mcp_tool_call", + tool_name=exposed_name, + params=kwargs, + context={ + "mcp_server": connector_name, + "tool_description": tool_description, + "mcp_transport": "http", + "mcp_connector_id": connector_id, + }, + trusted_tools=trusted_tools, + ) + if hitl_result.rejected: + return "Tool call rejected by user." + call_kwargs = {k: v for k, v in hitl_result.params.items() if v is not None} try: async with ( @@ -205,7 +223,9 @@ async def _create_mcp_tool_from_definition_http( ClientSession(read, write) as session, ): await session.initialize() - response = await session.call_tool(tool_name, arguments=call_kwargs) + response = await session.call_tool( + original_tool_name, arguments=call_kwargs, + ) result = [] for content in response.content: @@ -217,18 +237,15 @@ async def _create_mcp_tool_from_definition_http( result.append(str(content)) result_str = "\n".join(result) if result else "" - logger.info( - f"MCP HTTP tool '{tool_name}' succeeded: {result_str[:200]}" - ) + logger.debug("MCP HTTP tool '%s' succeeded (len=%d)", exposed_name, len(result_str)) return result_str except Exception as e: - error_msg = f"MCP HTTP tool '{tool_name}' execution failed: {e!s}" - logger.exception(error_msg) - return f"Error: {error_msg}" + logger.exception("MCP HTTP tool '%s' execution failed: %s", exposed_name, e) + return f"Error: MCP HTTP tool '{exposed_name}' execution failed: {e!s}" tool = StructuredTool( - name=tool_name, + name=exposed_name, description=tool_description, coroutine=mcp_http_tool_call, args_schema=input_model, @@ -236,12 +253,14 @@ async def _create_mcp_tool_from_definition_http( "mcp_input_schema": input_schema, "mcp_transport": "http", "mcp_url": url, - "hitl": True, + "hitl": not 
is_readonly, "hitl_dedup_key": next(iter(input_schema.get("required", [])), None), + "mcp_original_tool_name": original_tool_name, + "mcp_connector_id": connector_id, }, ) - logger.info(f"Created MCP tool (HTTP): '{tool_name}'") + logger.debug("Created MCP tool (HTTP): '%s'", exposed_name) return tool @@ -257,21 +276,24 @@ async def _load_stdio_mcp_tools( command = server_config.get("command") if not command or not isinstance(command, str): logger.warning( - f"MCP connector {connector_id} (name: '{connector_name}') missing or invalid command field, skipping" + "MCP connector %d (name: '%s') missing or invalid command field, skipping", + connector_id, connector_name, ) return tools args = server_config.get("args", []) if not isinstance(args, list): logger.warning( - f"MCP connector {connector_id} (name: '{connector_name}') has invalid args field (must be list), skipping" + "MCP connector %d (name: '%s') has invalid args field (must be list), skipping", + connector_id, connector_name, ) return tools env = server_config.get("env", {}) if not isinstance(env, dict): logger.warning( - f"MCP connector {connector_id} (name: '{connector_name}') has invalid env field (must be dict), skipping" + "MCP connector %d (name: '%s') has invalid env field (must be dict), skipping", + connector_id, connector_name, ) return tools @@ -281,8 +303,8 @@ async def _load_stdio_mcp_tools( tool_definitions = await mcp_client.list_tools() logger.info( - f"Discovered {len(tool_definitions)} tools from stdio MCP server " - f"'{command}' (connector {connector_id})" + "Discovered %d tools from stdio MCP server '%s' (connector %d)", + len(tool_definitions), command, connector_id, ) for tool_def in tool_definitions: @@ -297,8 +319,8 @@ async def _load_stdio_mcp_tools( tools.append(tool) except Exception as e: logger.exception( - f"Failed to create tool '{tool_def.get('name')}' " - f"from connector {connector_id}: {e!s}" + "Failed to create tool '%s' from connector %d: %s", + tool_def.get("name"), 
connector_id, e, ) return tools @@ -309,24 +331,40 @@ async def _load_http_mcp_tools( connector_name: str, server_config: dict[str, Any], trusted_tools: list[str] | None = None, + allowed_tools: list[str] | None = None, + readonly_tools: frozenset[str] | None = None, + tool_name_prefix: str | None = None, ) -> list[StructuredTool]: - """Load tools from an HTTP-based MCP server.""" + """Load tools from an HTTP-based MCP server. + + Args: + allowed_tools: If non-empty, only tools whose names appear in this + list are loaded. Empty/None means load everything (used for + user-managed generic MCP servers). + readonly_tools: Tool names that skip HITL approval (read-only operations). + tool_name_prefix: If set, each tool name is prefixed for multi-account + disambiguation (e.g. ``linear_25``). + """ tools: list[StructuredTool] = [] url = server_config.get("url") if not url or not isinstance(url, str): logger.warning( - f"MCP connector {connector_id} (name: '{connector_name}') missing or invalid url field, skipping" + "MCP connector %d (name: '%s') missing or invalid url field, skipping", + connector_id, connector_name, ) return tools headers = server_config.get("headers", {}) if not isinstance(headers, dict): logger.warning( - f"MCP connector {connector_id} (name: '{connector_name}') has invalid headers field (must be dict), skipping" + "MCP connector %d (name: '%s') has invalid headers field (must be dict), skipping", + connector_id, connector_name, ) return tools + allowed_set = set(allowed_tools) if allowed_tools else None + try: async with ( streamablehttp_client(url, headers=headers) as (read, write, _), @@ -347,10 +385,21 @@ async def _load_http_mcp_tools( } ) - logger.info( - f"Discovered {len(tool_definitions)} tools from HTTP MCP server " - f"'{url}' (connector {connector_id})" - ) + total_discovered = len(tool_definitions) + + if allowed_set: + tool_definitions = [ + td for td in tool_definitions if td["name"] in allowed_set + ] + logger.info( + "HTTP MCP server 
'%s' (connector %d): %d/%d tools after allowlist filter", + url, connector_id, len(tool_definitions), total_discovered, + ) + else: + logger.info( + "Discovered %d tools from HTTP MCP server '%s' (connector %d) — no allowlist, loading all", + total_discovered, url, connector_id, + ) for tool_def in tool_definitions: try: @@ -361,22 +410,183 @@ async def _load_http_mcp_tools( connector_name=connector_name, connector_id=connector_id, trusted_tools=trusted_tools, + readonly_tools=readonly_tools, + tool_name_prefix=tool_name_prefix, ) tools.append(tool) except Exception as e: logger.exception( - f"Failed to create HTTP tool '{tool_def.get('name')}' " - f"from connector {connector_id}: {e!s}" + "Failed to create HTTP tool '%s' from connector %d: %s", + tool_def.get("name"), connector_id, e, ) except Exception as e: logger.exception( - f"Failed to connect to HTTP MCP server at '{url}' (connector {connector_id}): {e!s}" + "Failed to connect to HTTP MCP server at '%s' (connector %d): %s", + url, connector_id, e, ) return tools +_TOKEN_REFRESH_BUFFER_SECONDS = 300 # refresh 5 min before expiry + +_token_enc: TokenEncryption | None = None + + +def _get_token_enc() -> TokenEncryption: + global _token_enc + if _token_enc is None: + from app.config import config as app_config + from app.utils.oauth_security import TokenEncryption + + _token_enc = TokenEncryption(app_config.SECRET_KEY) + return _token_enc + + +def _inject_oauth_headers( + cfg: dict[str, Any], + server_config: dict[str, Any], +) -> dict[str, Any] | None: + """Decrypt the MCP OAuth access token and inject it into server_config headers. + + The DB never stores plaintext tokens in ``server_config.headers``. This + function decrypts ``mcp_oauth.access_token`` at runtime and returns a + *copy* of ``server_config`` with the Authorization header set. 
+ """ + mcp_oauth = cfg.get("mcp_oauth", {}) + encrypted_token = mcp_oauth.get("access_token") + if not encrypted_token: + return server_config + + try: + access_token = _get_token_enc().decrypt_token(encrypted_token) + + result = dict(server_config) + result["headers"] = { + **server_config.get("headers", {}), + "Authorization": f"Bearer {access_token}", + } + return result + except Exception: + logger.error( + "Failed to decrypt MCP OAuth token — connector will be skipped", + exc_info=True, + ) + return None + + +async def _maybe_refresh_mcp_oauth_token( + session: AsyncSession, + connector: "SearchSourceConnector", + cfg: dict[str, Any], + server_config: dict[str, Any], +) -> dict[str, Any]: + """Refresh the access token for an MCP OAuth connector if it is about to expire. + + Returns the (possibly updated) ``server_config``. + """ + from datetime import UTC, datetime, timedelta + + mcp_oauth = cfg.get("mcp_oauth", {}) + expires_at_str = mcp_oauth.get("expires_at") + if not expires_at_str: + return server_config + + try: + expires_at = datetime.fromisoformat(expires_at_str) + if expires_at.tzinfo is None: + from datetime import timezone + expires_at = expires_at.replace(tzinfo=timezone.utc) + + if datetime.now(UTC) < expires_at - timedelta(seconds=_TOKEN_REFRESH_BUFFER_SECONDS): + return server_config + except (ValueError, TypeError): + return server_config + + refresh_token = mcp_oauth.get("refresh_token") + if not refresh_token: + logger.warning( + "MCP connector %s token expired but no refresh_token available", + connector.id, + ) + return server_config + + try: + from app.services.mcp_oauth.discovery import refresh_access_token + + enc = _get_token_enc() + decrypted_refresh = enc.decrypt_token(refresh_token) + decrypted_secret = ( + enc.decrypt_token(mcp_oauth["client_secret"]) + if mcp_oauth.get("client_secret") + else "" + ) + + token_json = await refresh_access_token( + token_endpoint=mcp_oauth["token_endpoint"], + refresh_token=decrypted_refresh, + 
client_id=mcp_oauth["client_id"], + client_secret=decrypted_secret, + ) + + new_access = token_json.get("access_token") + if not new_access: + logger.warning( + "MCP connector %s token refresh returned no access_token", + connector.id, + ) + return server_config + + new_expires_at = None + if token_json.get("expires_in"): + new_expires_at = datetime.now(UTC) + timedelta( + seconds=int(token_json["expires_in"]) + ) + + updated_oauth = dict(mcp_oauth) + updated_oauth["access_token"] = enc.encrypt_token(new_access) + if token_json.get("refresh_token"): + updated_oauth["refresh_token"] = enc.encrypt_token( + token_json["refresh_token"] + ) + updated_oauth["expires_at"] = ( + new_expires_at.isoformat() if new_expires_at else None + ) + + from sqlalchemy.orm.attributes import flag_modified + + connector.config = { + **cfg, + "server_config": server_config, + "mcp_oauth": updated_oauth, + } + flag_modified(connector, "config") + await session.commit() + await session.refresh(connector) + + logger.info("Refreshed MCP OAuth token for connector %s", connector.id) + + # Invalidate cache so next call picks up the new token. + invalidate_mcp_tools_cache(connector.search_space_id) + + # Return server_config with the fresh token injected for immediate use. + refreshed_config = dict(server_config) + refreshed_config["headers"] = { + **server_config.get("headers", {}), + "Authorization": f"Bearer {new_access}", + } + return refreshed_config + + except Exception: + logger.warning( + "Failed to refresh MCP OAuth token for connector %s", + connector.id, + exc_info=True, + ) + return server_config + + def invalidate_mcp_tools_cache(search_space_id: int | None = None) -> None: """Invalidate cached MCP tools. @@ -418,27 +628,91 @@ async def load_mcp_tools( return list(cached_tools) try: + # Find all connectors with MCP server config: generic MCP_CONNECTOR type + # and service-specific types (LINEAR_CONNECTOR, etc.) created via MCP OAuth. 
+ # Cast JSON -> JSONB so we can use has_key to filter by the presence of "server_config". result = await session.execute( select(SearchSourceConnector).filter( - SearchSourceConnector.connector_type - == SearchSourceConnectorType.MCP_CONNECTOR, SearchSourceConnector.search_space_id == search_space_id, + cast(SearchSourceConnector.config, JSONB).has_key("server_config"), # noqa: W601 ), ) + connectors = list(result.scalars()) + + # Group connectors by type to detect multi-account scenarios. + # When >1 connector shares the same type, tool names would collide + # so we prefix them with "{service_key}_{connector_id}_". + type_groups: dict[str, list[SearchSourceConnector]] = defaultdict(list) + for connector in connectors: + ct = ( + connector.connector_type.value + if hasattr(connector.connector_type, "value") + else str(connector.connector_type) + ) + type_groups[ct].append(connector) + + multi_account_types: set[str] = { + ct for ct, group in type_groups.items() if len(group) > 1 + } + if multi_account_types: + logger.info( + "Multi-account detected for connector types: %s", + multi_account_types, + ) + tools: list[StructuredTool] = [] - for connector in result.scalars(): + for connector in connectors: try: - config = connector.config or {} - server_config = config.get("server_config", {}) - trusted_tools = config.get("trusted_tools", []) + cfg = connector.config or {} + server_config = cfg.get("server_config", {}) if not server_config or not isinstance(server_config, dict): logger.warning( - f"MCP connector {connector.id} (name: '{connector.name}') has invalid or missing server_config, skipping" + "MCP connector %d (name: '%s') has invalid or missing server_config, skipping", + connector.id, connector.name, ) continue + # For MCP OAuth connectors: refresh if needed, then decrypt the + # access token and inject it into headers at runtime. The DB + # intentionally does NOT store plaintext tokens in server_config. 
+ if cfg.get("mcp_oauth"): + server_config = await _maybe_refresh_mcp_oauth_token( + session, connector, cfg, server_config, + ) + # Re-read cfg after potential refresh (connector was reloaded from DB). + cfg = connector.config or {} + server_config = _inject_oauth_headers(cfg, server_config) + if server_config is None: + logger.warning( + "Skipping MCP connector %d — OAuth token decryption failed", + connector.id, + ) + continue + + trusted_tools = cfg.get("trusted_tools", []) + + ct = ( + connector.connector_type.value + if hasattr(connector.connector_type, "value") + else str(connector.connector_type) + ) + + svc_cfg = get_service_by_connector_type(ct) + allowed_tools = svc_cfg.allowed_tools if svc_cfg else [] + readonly_tools = svc_cfg.readonly_tools if svc_cfg else frozenset() + + # Build a prefix only when multiple accounts share the same type. + tool_name_prefix: str | None = None + if ct in multi_account_types and svc_cfg: + service_key = next( + (k for k, v in MCP_SERVICES.items() if v is svc_cfg), + None, + ) + if service_key: + tool_name_prefix = f"{service_key}_{connector.id}" + transport = server_config.get("transport", "stdio") if transport in ("streamable-http", "http", "sse"): @@ -447,6 +721,9 @@ async def load_mcp_tools( connector.name, server_config, trusted_tools=trusted_tools, + allowed_tools=allowed_tools, + readonly_tools=readonly_tools, + tool_name_prefix=tool_name_prefix, ) else: connector_tools = await _load_stdio_mcp_tools( @@ -460,7 +737,8 @@ async def load_mcp_tools( except Exception as e: logger.exception( - f"Failed to load tools from MCP connector {connector.id}: {e!s}" + "Failed to load tools from MCP connector %d: %s", + connector.id, e, ) _mcp_tools_cache[search_space_id] = (now, tools) @@ -469,9 +747,9 @@ async def load_mcp_tools( oldest_key = min(_mcp_tools_cache, key=lambda k: _mcp_tools_cache[k][0]) del _mcp_tools_cache[oldest_key] - logger.info(f"Loaded {len(tools)} MCP tools for search space {search_space_id}") + 
logger.info("Loaded %d MCP tools for search space %d", len(tools), search_space_id) return tools except Exception as e: - logger.exception(f"Failed to load MCP tools: {e!s}") + logger.exception("Failed to load MCP tools: %s", e) return [] diff --git a/surfsense_backend/app/agents/new_chat/tools/registry.py b/surfsense_backend/app/agents/new_chat/tools/registry.py index 265aabbbf..85c89b114 100644 --- a/surfsense_backend/app/agents/new_chat/tools/registry.py +++ b/surfsense_backend/app/agents/new_chat/tools/registry.py @@ -50,6 +50,11 @@ from .confluence import ( create_delete_confluence_page_tool, create_update_confluence_page_tool, ) +from .discord import ( + create_list_discord_channels_tool, + create_read_discord_messages_tool, + create_send_discord_message_tool, +) from .dropbox import ( create_create_dropbox_file_tool, create_delete_dropbox_file_tool, @@ -57,6 +62,8 @@ from .dropbox import ( from .generate_image import create_generate_image_tool from .gmail import ( create_create_gmail_draft_tool, + create_read_gmail_email_tool, + create_search_gmail_tool, create_send_gmail_email_tool, create_trash_gmail_email_tool, create_update_gmail_draft_tool, @@ -64,21 +71,18 @@ from .gmail import ( from .google_calendar import ( create_create_calendar_event_tool, create_delete_calendar_event_tool, + create_search_calendar_events_tool, create_update_calendar_event_tool, ) from .google_drive import ( create_create_google_drive_file_tool, create_delete_google_drive_file_tool, ) -from .jira import ( - create_create_jira_issue_tool, - create_delete_jira_issue_tool, - create_update_jira_issue_tool, -) -from .linear import ( - create_create_linear_issue_tool, - create_delete_linear_issue_tool, - create_update_linear_issue_tool, +from .connected_accounts import create_get_connected_accounts_tool +from .luma import ( + create_create_luma_event_tool, + create_list_luma_events_tool, + create_read_luma_event_tool, ) from .mcp_tool import load_mcp_tools from .notion import ( @@ -95,6 
+99,11 @@ from .report import create_generate_report_tool from .resume import create_generate_resume_tool from .scrape_webpage import create_scrape_webpage_tool from .search_surfsense_docs import create_search_surfsense_docs_tool +from .teams import ( + create_list_teams_channels_tool, + create_read_teams_messages_tool, + create_send_teams_message_tool, +) from .update_memory import create_update_memory_tool, create_update_team_memory_tool from .video_presentation import create_generate_video_presentation_tool from .web_search import create_web_search_tool @@ -114,6 +123,8 @@ class ToolDefinition: factory: Callable that creates the tool. Receives a dict of dependencies. requires: List of dependency names this tool needs (e.g., "search_space_id", "db_session") enabled_by_default: Whether the tool is enabled when no explicit config is provided + required_connector: Searchable type string (e.g. ``"LINEAR_CONNECTOR"``) + that must be in ``available_connectors`` for the tool to be enabled. """ @@ -123,6 +134,7 @@ class ToolDefinition: requires: list[str] = field(default_factory=list) enabled_by_default: bool = True hidden: bool = False + required_connector: str | None = None # ============================================================================= @@ -221,6 +233,21 @@ BUILTIN_TOOLS: list[ToolDefinition] = [ requires=["db_session"], ), # ========================================================================= + # SERVICE ACCOUNT DISCOVERY + # Generic tool for the LLM to discover connected accounts and resolve + # service-specific identifiers (e.g. Jira cloudId, Slack team, etc.) 
+ # ========================================================================= + ToolDefinition( + name="get_connected_accounts", + description="Discover connected accounts for a service and their metadata", + factory=lambda deps: create_get_connected_accounts_tool( + db_session=deps["db_session"], + search_space_id=deps["search_space_id"], + user_id=deps["user_id"], + ), + requires=["db_session", "search_space_id", "user_id"], + ), + # ========================================================================= # MEMORY TOOL - single update_memory, private or team by thread_visibility # ========================================================================= ToolDefinition( @@ -248,40 +275,6 @@ BUILTIN_TOOLS: list[ToolDefinition] = [ ], ), # ========================================================================= - # LINEAR TOOLS - create, update, delete issues - # Auto-disabled when no Linear connector is configured (see chat_deepagent.py) - # ========================================================================= - ToolDefinition( - name="create_linear_issue", - description="Create a new issue in the user's Linear workspace", - factory=lambda deps: create_create_linear_issue_tool( - db_session=deps["db_session"], - search_space_id=deps["search_space_id"], - user_id=deps["user_id"], - ), - requires=["db_session", "search_space_id", "user_id"], - ), - ToolDefinition( - name="update_linear_issue", - description="Update an existing indexed Linear issue", - factory=lambda deps: create_update_linear_issue_tool( - db_session=deps["db_session"], - search_space_id=deps["search_space_id"], - user_id=deps["user_id"], - ), - requires=["db_session", "search_space_id", "user_id"], - ), - ToolDefinition( - name="delete_linear_issue", - description="Archive (delete) an existing indexed Linear issue", - factory=lambda deps: create_delete_linear_issue_tool( - db_session=deps["db_session"], - search_space_id=deps["search_space_id"], - user_id=deps["user_id"], - ), - 
requires=["db_session", "search_space_id", "user_id"], - ), - # ========================================================================= # NOTION TOOLS - create, update, delete pages # Auto-disabled when no Notion connector is configured (see chat_deepagent.py) # ========================================================================= @@ -294,6 +287,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [ user_id=deps["user_id"], ), requires=["db_session", "search_space_id", "user_id"], + required_connector="NOTION_CONNECTOR", ), ToolDefinition( name="update_notion_page", @@ -304,6 +298,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [ user_id=deps["user_id"], ), requires=["db_session", "search_space_id", "user_id"], + required_connector="NOTION_CONNECTOR", ), ToolDefinition( name="delete_notion_page", @@ -314,6 +309,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [ user_id=deps["user_id"], ), requires=["db_session", "search_space_id", "user_id"], + required_connector="NOTION_CONNECTOR", ), # ========================================================================= # GOOGLE DRIVE TOOLS - create files, delete files @@ -328,6 +324,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [ user_id=deps["user_id"], ), requires=["db_session", "search_space_id", "user_id"], + required_connector="GOOGLE_DRIVE_FILE", ), ToolDefinition( name="delete_google_drive_file", @@ -338,6 +335,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [ user_id=deps["user_id"], ), requires=["db_session", "search_space_id", "user_id"], + required_connector="GOOGLE_DRIVE_FILE", ), # ========================================================================= # DROPBOX TOOLS - create and trash files @@ -352,6 +350,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [ user_id=deps["user_id"], ), requires=["db_session", "search_space_id", "user_id"], + required_connector="DROPBOX_FILE", ), ToolDefinition( name="delete_dropbox_file", @@ -362,6 +361,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [ user_id=deps["user_id"], ), requires=["db_session", 
"search_space_id", "user_id"], + required_connector="DROPBOX_FILE", ), # ========================================================================= # ONEDRIVE TOOLS - create and trash files @@ -376,6 +376,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [ user_id=deps["user_id"], ), requires=["db_session", "search_space_id", "user_id"], + required_connector="ONEDRIVE_FILE", ), ToolDefinition( name="delete_onedrive_file", @@ -386,11 +387,23 @@ BUILTIN_TOOLS: list[ToolDefinition] = [ user_id=deps["user_id"], ), requires=["db_session", "search_space_id", "user_id"], + required_connector="ONEDRIVE_FILE", ), # ========================================================================= - # GOOGLE CALENDAR TOOLS - create, update, delete events + # GOOGLE CALENDAR TOOLS - search, create, update, delete events # Auto-disabled when no Google Calendar connector is configured # ========================================================================= + ToolDefinition( + name="search_calendar_events", + description="Search Google Calendar events within a date range", + factory=lambda deps: create_search_calendar_events_tool( + db_session=deps["db_session"], + search_space_id=deps["search_space_id"], + user_id=deps["user_id"], + ), + requires=["db_session", "search_space_id", "user_id"], + required_connector="GOOGLE_CALENDAR_CONNECTOR", + ), ToolDefinition( name="create_calendar_event", description="Create a new event on Google Calendar", @@ -400,6 +413,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [ user_id=deps["user_id"], ), requires=["db_session", "search_space_id", "user_id"], + required_connector="GOOGLE_CALENDAR_CONNECTOR", ), ToolDefinition( name="update_calendar_event", @@ -410,6 +424,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [ user_id=deps["user_id"], ), requires=["db_session", "search_space_id", "user_id"], + required_connector="GOOGLE_CALENDAR_CONNECTOR", ), ToolDefinition( name="delete_calendar_event", @@ -420,11 +435,34 @@ BUILTIN_TOOLS: list[ToolDefinition] = [ 
user_id=deps["user_id"], ), requires=["db_session", "search_space_id", "user_id"], + required_connector="GOOGLE_CALENDAR_CONNECTOR", ), # ========================================================================= - # GMAIL TOOLS - create drafts, update drafts, send emails, trash emails + # GMAIL TOOLS - search, read, create drafts, update drafts, send, trash # Auto-disabled when no Gmail connector is configured # ========================================================================= + ToolDefinition( + name="search_gmail", + description="Search emails in Gmail using Gmail search syntax", + factory=lambda deps: create_search_gmail_tool( + db_session=deps["db_session"], + search_space_id=deps["search_space_id"], + user_id=deps["user_id"], + ), + requires=["db_session", "search_space_id", "user_id"], + required_connector="GOOGLE_GMAIL_CONNECTOR", + ), + ToolDefinition( + name="read_gmail_email", + description="Read the full content of a specific Gmail email", + factory=lambda deps: create_read_gmail_email_tool( + db_session=deps["db_session"], + search_space_id=deps["search_space_id"], + user_id=deps["user_id"], + ), + requires=["db_session", "search_space_id", "user_id"], + required_connector="GOOGLE_GMAIL_CONNECTOR", + ), ToolDefinition( name="create_gmail_draft", description="Create a draft email in Gmail", @@ -434,6 +472,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [ user_id=deps["user_id"], ), requires=["db_session", "search_space_id", "user_id"], + required_connector="GOOGLE_GMAIL_CONNECTOR", ), ToolDefinition( name="send_gmail_email", @@ -444,6 +483,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [ user_id=deps["user_id"], ), requires=["db_session", "search_space_id", "user_id"], + required_connector="GOOGLE_GMAIL_CONNECTOR", ), ToolDefinition( name="trash_gmail_email", @@ -454,6 +494,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [ user_id=deps["user_id"], ), requires=["db_session", "search_space_id", "user_id"], + required_connector="GOOGLE_GMAIL_CONNECTOR", ), 
ToolDefinition( name="update_gmail_draft", @@ -464,40 +505,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [ user_id=deps["user_id"], ), requires=["db_session", "search_space_id", "user_id"], - ), - # ========================================================================= - # JIRA TOOLS - create, update, delete issues - # Auto-disabled when no Jira connector is configured (see chat_deepagent.py) - # ========================================================================= - ToolDefinition( - name="create_jira_issue", - description="Create a new issue in the user's Jira project", - factory=lambda deps: create_create_jira_issue_tool( - db_session=deps["db_session"], - search_space_id=deps["search_space_id"], - user_id=deps["user_id"], - ), - requires=["db_session", "search_space_id", "user_id"], - ), - ToolDefinition( - name="update_jira_issue", - description="Update an existing indexed Jira issue", - factory=lambda deps: create_update_jira_issue_tool( - db_session=deps["db_session"], - search_space_id=deps["search_space_id"], - user_id=deps["user_id"], - ), - requires=["db_session", "search_space_id", "user_id"], - ), - ToolDefinition( - name="delete_jira_issue", - description="Delete an existing indexed Jira issue", - factory=lambda deps: create_delete_jira_issue_tool( - db_session=deps["db_session"], - search_space_id=deps["search_space_id"], - user_id=deps["user_id"], - ), - requires=["db_session", "search_space_id", "user_id"], + required_connector="GOOGLE_GMAIL_CONNECTOR", ), # ========================================================================= # CONFLUENCE TOOLS - create, update, delete pages @@ -512,6 +520,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [ user_id=deps["user_id"], ), requires=["db_session", "search_space_id", "user_id"], + required_connector="CONFLUENCE_CONNECTOR", ), ToolDefinition( name="update_confluence_page", @@ -522,6 +531,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [ user_id=deps["user_id"], ), requires=["db_session", "search_space_id", 
"user_id"], + required_connector="CONFLUENCE_CONNECTOR", ), ToolDefinition( name="delete_confluence_page", @@ -532,6 +542,118 @@ BUILTIN_TOOLS: list[ToolDefinition] = [ user_id=deps["user_id"], ), requires=["db_session", "search_space_id", "user_id"], + required_connector="CONFLUENCE_CONNECTOR", + ), + # ========================================================================= + # DISCORD TOOLS - list channels, read messages, send messages + # Auto-disabled when no Discord connector is configured + # ========================================================================= + ToolDefinition( + name="list_discord_channels", + description="List text channels in the connected Discord server", + factory=lambda deps: create_list_discord_channels_tool( + db_session=deps["db_session"], + search_space_id=deps["search_space_id"], + user_id=deps["user_id"], + ), + requires=["db_session", "search_space_id", "user_id"], + required_connector="DISCORD_CONNECTOR", + ), + ToolDefinition( + name="read_discord_messages", + description="Read recent messages from a Discord text channel", + factory=lambda deps: create_read_discord_messages_tool( + db_session=deps["db_session"], + search_space_id=deps["search_space_id"], + user_id=deps["user_id"], + ), + requires=["db_session", "search_space_id", "user_id"], + required_connector="DISCORD_CONNECTOR", + ), + ToolDefinition( + name="send_discord_message", + description="Send a message to a Discord text channel", + factory=lambda deps: create_send_discord_message_tool( + db_session=deps["db_session"], + search_space_id=deps["search_space_id"], + user_id=deps["user_id"], + ), + requires=["db_session", "search_space_id", "user_id"], + required_connector="DISCORD_CONNECTOR", + ), + # ========================================================================= + # TEAMS TOOLS - list channels, read messages, send messages + # Auto-disabled when no Teams connector is configured + # 
========================================================================= + ToolDefinition( + name="list_teams_channels", + description="List Microsoft Teams and their channels", + factory=lambda deps: create_list_teams_channels_tool( + db_session=deps["db_session"], + search_space_id=deps["search_space_id"], + user_id=deps["user_id"], + ), + requires=["db_session", "search_space_id", "user_id"], + required_connector="TEAMS_CONNECTOR", + ), + ToolDefinition( + name="read_teams_messages", + description="Read recent messages from a Microsoft Teams channel", + factory=lambda deps: create_read_teams_messages_tool( + db_session=deps["db_session"], + search_space_id=deps["search_space_id"], + user_id=deps["user_id"], + ), + requires=["db_session", "search_space_id", "user_id"], + required_connector="TEAMS_CONNECTOR", + ), + ToolDefinition( + name="send_teams_message", + description="Send a message to a Microsoft Teams channel", + factory=lambda deps: create_send_teams_message_tool( + db_session=deps["db_session"], + search_space_id=deps["search_space_id"], + user_id=deps["user_id"], + ), + requires=["db_session", "search_space_id", "user_id"], + required_connector="TEAMS_CONNECTOR", + ), + # ========================================================================= + # LUMA TOOLS - list events, read event details, create events + # Auto-disabled when no Luma connector is configured + # ========================================================================= + ToolDefinition( + name="list_luma_events", + description="List upcoming and recent Luma events", + factory=lambda deps: create_list_luma_events_tool( + db_session=deps["db_session"], + search_space_id=deps["search_space_id"], + user_id=deps["user_id"], + ), + requires=["db_session", "search_space_id", "user_id"], + required_connector="LUMA_CONNECTOR", + ), + ToolDefinition( + name="read_luma_event", + description="Read detailed information about a specific Luma event", + factory=lambda deps: 
create_read_luma_event_tool( + db_session=deps["db_session"], + search_space_id=deps["search_space_id"], + user_id=deps["user_id"], + ), + requires=["db_session", "search_space_id", "user_id"], + required_connector="LUMA_CONNECTOR", + ), + ToolDefinition( + name="create_luma_event", + description="Create a new event on Luma", + factory=lambda deps: create_create_luma_event_tool( + db_session=deps["db_session"], + search_space_id=deps["search_space_id"], + user_id=deps["user_id"], + ), + requires=["db_session", "search_space_id", "user_id"], + required_connector="LUMA_CONNECTOR", ), ] @@ -549,6 +671,22 @@ def get_tool_by_name(name: str) -> ToolDefinition | None: return None +def get_connector_gated_tools( + available_connectors: list[str] | None, +) -> list[str]: + """Return tool names to disable""" + if available_connectors is None: + available = set() + else: + available = set(available_connectors) + + disabled: list[str] = [] + for tool_def in BUILTIN_TOOLS: + if tool_def.required_connector and tool_def.required_connector not in available: + disabled.append(tool_def.name) + return disabled + + def get_all_tool_names() -> list[str]: """Get names of all registered tools.""" return [tool_def.name for tool_def in BUILTIN_TOOLS] @@ -690,15 +828,15 @@ async def build_tools_async( ) tools.extend(mcp_tools) logging.info( - f"Registered {len(mcp_tools)} MCP tools: {[t.name for t in mcp_tools]}", + "Registered %d MCP tools: %s", + len(mcp_tools), [t.name for t in mcp_tools], ) except Exception as e: - # Log error but don't fail - just continue without MCP tools - logging.exception(f"Failed to load MCP tools: {e!s}") + logging.exception("Failed to load MCP tools: %s", e) - # Log all tools being returned to agent logging.info( - f"Total tools for agent: {len(tools)} - {[t.name for t in tools]}", + "Total tools for agent: %d — %s", + len(tools), [t.name for t in tools], ) return tools diff --git a/surfsense_backend/app/agents/new_chat/tools/teams/__init__.py 
b/surfsense_backend/app/agents/new_chat/tools/teams/__init__.py new file mode 100644 index 000000000..60e2add49 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/tools/teams/__init__.py @@ -0,0 +1,15 @@ +from app.agents.new_chat.tools.teams.list_channels import ( + create_list_teams_channels_tool, +) +from app.agents.new_chat.tools.teams.read_messages import ( + create_read_teams_messages_tool, +) +from app.agents.new_chat.tools.teams.send_message import ( + create_send_teams_message_tool, +) + +__all__ = [ + "create_list_teams_channels_tool", + "create_read_teams_messages_tool", + "create_send_teams_message_tool", +] diff --git a/surfsense_backend/app/agents/new_chat/tools/teams/_auth.py b/surfsense_backend/app/agents/new_chat/tools/teams/_auth.py new file mode 100644 index 000000000..f24f5502e --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/tools/teams/_auth.py @@ -0,0 +1,37 @@ +"""Shared auth helper for Teams agent tools (Microsoft Graph REST API).""" + +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.future import select + +from app.db import SearchSourceConnector, SearchSourceConnectorType + +GRAPH_API = "https://graph.microsoft.com/v1.0" + + +async def get_teams_connector( + db_session: AsyncSession, + search_space_id: int, + user_id: str, +) -> SearchSourceConnector | None: + result = await db_session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.search_space_id == search_space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type == SearchSourceConnectorType.TEAMS_CONNECTOR, + ) + ) + return result.scalars().first() + + +async def get_access_token( + db_session: AsyncSession, + connector: SearchSourceConnector, +) -> str: + """Get a valid Microsoft Graph access token, refreshing if expired.""" + from app.connectors.teams_connector import TeamsConnector + + tc = TeamsConnector( + session=db_session, + connector_id=connector.id, + ) + return await 
tc._get_valid_token() diff --git a/surfsense_backend/app/agents/new_chat/tools/teams/list_channels.py b/surfsense_backend/app/agents/new_chat/tools/teams/list_channels.py new file mode 100644 index 000000000..a676595c1 --- /dev/null +++ b/surfsense_backend/app/agents/new_chat/tools/teams/list_channels.py @@ -0,0 +1,77 @@ +import logging +from typing import Any + +import httpx +from langchain_core.tools import tool +from sqlalchemy.ext.asyncio import AsyncSession + +from ._auth import GRAPH_API, get_access_token, get_teams_connector + +logger = logging.getLogger(__name__) + + +def create_list_teams_channels_tool( + db_session: AsyncSession | None = None, + search_space_id: int | None = None, + user_id: str | None = None, +): + @tool + async def list_teams_channels() -> dict[str, Any]: + """List all Microsoft Teams and their channels the user has access to. + + Returns: + Dictionary with status and a list of teams, each containing + team_id, team_name, and a list of channels (id, name). + """ + if db_session is None or search_space_id is None or user_id is None: + return {"status": "error", "message": "Teams tool not properly configured."} + + try: + connector = await get_teams_connector(db_session, search_space_id, user_id) + if not connector: + return {"status": "error", "message": "No Teams connector found."} + + token = await get_access_token(db_session, connector) + headers = {"Authorization": f"Bearer {token}"} + + async with httpx.AsyncClient(timeout=20.0) as client: + teams_resp = await client.get(f"{GRAPH_API}/me/joinedTeams", headers=headers) + + if teams_resp.status_code == 401: + return {"status": "auth_error", "message": "Teams token expired. 
import logging
from typing import Any

import httpx
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession

from ._auth import GRAPH_API, get_access_token, get_teams_connector

logger = logging.getLogger(__name__)


def create_read_teams_messages_tool(
    db_session: AsyncSession | None = None,
    search_space_id: int | None = None,
    user_id: str | None = None,
):
    """Build the `read_teams_messages` tool bound to the given session/space/user."""

    @tool
    async def read_teams_messages(
        team_id: str,
        channel_id: str,
        limit: int = 25,
    ) -> dict[str, Any]:
        """Read recent messages from a Microsoft Teams channel.

        Args:
            team_id: The team ID (from list_teams_channels).
            channel_id: The channel ID (from list_teams_channels).
            limit: Number of messages to fetch (default 25, max 50).

        Returns:
            Dictionary with status and a list of messages including
            id, sender, content, timestamp.
        """
        if db_session is None or search_space_id is None or user_id is None:
            return {"status": "error", "message": "Teams tool not properly configured."}

        # Graph's $top for channel messages must be in [1, 50]; the original
        # only capped the upper bound, letting 0/negative values through.
        limit = max(1, min(limit, 50))

        try:
            connector = await get_teams_connector(db_session, search_space_id, user_id)
            if not connector:
                return {"status": "error", "message": "No Teams connector found."}

            token = await get_access_token(db_session, connector)

            async with httpx.AsyncClient(timeout=20.0) as client:
                resp = await client.get(
                    f"{GRAPH_API}/teams/{team_id}/channels/{channel_id}/messages",
                    headers={"Authorization": f"Bearer {token}"},
                    params={"$top": limit},
                )

            if resp.status_code == 401:
                return {
                    "status": "auth_error",
                    "message": "Teams token expired. Please re-authenticate.",
                    "connector_type": "teams",
                }
            if resp.status_code == 403:
                return {
                    "status": "error",
                    "message": "Insufficient permissions to read this channel.",
                }
            if resp.status_code != 200:
                return {"status": "error", "message": f"Graph API error: {resp.status_code}"}

            messages = []
            for m in resp.json().get("value", []):
                # "from" may be absent/None for system messages — guard it.
                sender = m.get("from", {})
                user_info = sender.get("user", {}) if sender else {}
                body = m.get("body", {})
                messages.append({
                    "id": m.get("id"),
                    "sender": user_info.get("displayName", "Unknown"),
                    "content": body.get("content", ""),
                    "content_type": body.get("contentType", "text"),
                    "timestamp": m.get("createdDateTime", ""),
                })

            return {
                "status": "success",
                "team_id": team_id,
                "channel_id": channel_id,
                "messages": messages,
                "total": len(messages),
            }

        except Exception as e:
            from langgraph.errors import GraphInterrupt

            if isinstance(e, GraphInterrupt):
                raise
            logger.error("Error reading Teams messages: %s", e, exc_info=True)
            return {"status": "error", "message": "Failed to read Teams messages."}

    return read_teams_messages
import logging
from typing import Any

import httpx
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession

from app.agents.new_chat.tools.hitl import request_approval

from ._auth import GRAPH_API, get_access_token, get_teams_connector

logger = logging.getLogger(__name__)


def create_send_teams_message_tool(
    db_session: AsyncSession | None = None,
    search_space_id: int | None = None,
    user_id: str | None = None,
):
    """Build the `send_teams_message` tool bound to the given session/space/user."""

    @tool
    async def send_teams_message(
        team_id: str,
        channel_id: str,
        content: str,
    ) -> dict[str, Any]:
        """Send a message to a Microsoft Teams channel.

        Requires the ChannelMessage.Send OAuth scope. If the user gets a
        permission error, they may need to re-authenticate with updated scopes.

        Args:
            team_id: The team ID (from list_teams_channels).
            channel_id: The channel ID (from list_teams_channels).
            content: The message text (HTML supported).

        Returns:
            Dictionary with status, message_id on success.

        IMPORTANT:
        - If status is "rejected", the user explicitly declined. Do NOT retry.
        """
        if db_session is None or search_space_id is None or user_id is None:
            return {"status": "error", "message": "Teams tool not properly configured."}

        try:
            connector = await get_teams_connector(db_session, search_space_id, user_id)
            if not connector:
                return {"status": "error", "message": "No Teams connector found."}

            # Human-in-the-loop gate: the user may edit params or decline.
            result = request_approval(
                action_type="teams_send_message",
                tool_name="send_teams_message",
                params={"team_id": team_id, "channel_id": channel_id, "content": content},
                context={"connector_id": connector.id},
            )

            if result.rejected:
                return {"status": "rejected", "message": "User declined. Message was not sent."}

            # Honor any edits the user made during approval.
            final_content = result.params.get("content", content)
            final_team = result.params.get("team_id", team_id)
            final_channel = result.params.get("channel_id", channel_id)

            token = await get_access_token(db_session, connector)

            # NOTE(review): the docstring advertises HTML support, but no
            # body contentType is sent and Graph defaults to plain text —
            # confirm whether {"contentType": "html"} should be set here.
            async with httpx.AsyncClient(timeout=20.0) as client:
                resp = await client.post(
                    f"{GRAPH_API}/teams/{final_team}/channels/{final_channel}/messages",
                    headers={
                        "Authorization": f"Bearer {token}",
                        "Content-Type": "application/json",
                    },
                    json={"body": {"content": final_content}},
                )

            if resp.status_code == 401:
                return {
                    "status": "auth_error",
                    "message": "Teams token expired. Please re-authenticate.",
                    "connector_type": "teams",
                }
            if resp.status_code == 403:
                return {
                    "status": "insufficient_permissions",
                    "message": "Missing ChannelMessage.Send permission. Please re-authenticate with updated scopes.",
                }
            if resp.status_code not in (200, 201):
                return {"status": "error", "message": f"Graph API error: {resp.status_code} — {resp.text[:200]}"}

            msg_data = resp.json()
            return {
                "status": "success",
                "message_id": msg_data.get("id"),
                # Fixed: was an f-string with no placeholders (F541).
                "message": "Message sent to Teams channel.",
            }

        except Exception as e:
            from langgraph.errors import GraphInterrupt

            if isinstance(e, GraphInterrupt):
                raise
            logger.error("Error sending Teams message: %s", e, exc_info=True)
            return {"status": "error", "message": "Failed to send Teams message."}

    return send_teams_message
"""Standardised response dict factories for LangChain agent tools."""

from typing import Any


class ToolResponse:
    """Factory helpers that build the standard status-dict payloads tools return."""

    @staticmethod
    def success(message: str, **data: Any) -> dict[str, Any]:
        """Successful completion; extra keyword data is merged in."""
        payload: dict[str, Any] = {"status": "success", "message": message}
        payload.update(data)
        return payload

    @staticmethod
    def error(error: str, **data: Any) -> dict[str, Any]:
        """Generic failure."""
        payload: dict[str, Any] = {"status": "error", "error": error}
        payload.update(data)
        return payload

    @staticmethod
    def auth_error(service: str, **data: Any) -> dict[str, Any]:
        """Expired or revoked credentials for *service*."""
        text = (
            f"{service} authentication has expired or been revoked. "
            "Please re-connect the integration in Settings → Connectors."
        )
        payload: dict[str, Any] = {"status": "auth_error", "error": text}
        payload.update(data)
        return payload

    @staticmethod
    def rejected(message: str = "Action was declined by the user.") -> dict[str, Any]:
        """Human-in-the-loop approval was declined."""
        return {"status": "rejected", "message": message}

    @staticmethod
    def not_found(resource: str, identifier: str, **data: Any) -> dict[str, Any]:
        """A named resource could not be located."""
        payload: dict[str, Any] = {
            "status": "not_found",
            "error": f"{resource} '{identifier}' was not found.",
        }
        payload.update(data)
        return payload
+ ), + **data, + } + + @staticmethod + def rejected(message: str = "Action was declined by the user.") -> dict[str, Any]: + return {"status": "rejected", "message": message} + + @staticmethod + def not_found( + resource: str, identifier: str, **data: Any + ) -> dict[str, Any]: + return { + "status": "not_found", + "error": f"{resource} '{identifier}' was not found.", + **data, + } diff --git a/surfsense_backend/app/celery_app.py b/surfsense_backend/app/celery_app.py index c44391528..e3a520c48 100644 --- a/surfsense_backend/app/celery_app.py +++ b/surfsense_backend/app/celery_app.py @@ -135,20 +135,12 @@ celery_app.conf.update( # never block fast user-facing tasks (file uploads, podcasts, etc.) task_routes={ # Connector indexing tasks → connectors queue - "index_slack_messages": {"queue": CONNECTORS_QUEUE}, "index_notion_pages": {"queue": CONNECTORS_QUEUE}, "index_github_repos": {"queue": CONNECTORS_QUEUE}, - "index_linear_issues": {"queue": CONNECTORS_QUEUE}, - "index_jira_issues": {"queue": CONNECTORS_QUEUE}, "index_confluence_pages": {"queue": CONNECTORS_QUEUE}, - "index_clickup_tasks": {"queue": CONNECTORS_QUEUE}, "index_google_calendar_events": {"queue": CONNECTORS_QUEUE}, - "index_airtable_records": {"queue": CONNECTORS_QUEUE}, "index_google_gmail_messages": {"queue": CONNECTORS_QUEUE}, "index_google_drive_files": {"queue": CONNECTORS_QUEUE}, - "index_discord_messages": {"queue": CONNECTORS_QUEUE}, - "index_teams_messages": {"queue": CONNECTORS_QUEUE}, - "index_luma_events": {"queue": CONNECTORS_QUEUE}, "index_elasticsearch_documents": {"queue": CONNECTORS_QUEUE}, "index_crawled_urls": {"queue": CONNECTORS_QUEUE}, "index_bookstack_pages": {"queue": CONNECTORS_QUEUE}, diff --git a/surfsense_backend/app/connectors/exceptions.py b/surfsense_backend/app/connectors/exceptions.py new file mode 100644 index 000000000..32a1e7bdc --- /dev/null +++ b/surfsense_backend/app/connectors/exceptions.py @@ -0,0 +1,98 @@ +"""Standard exception hierarchy for all connectors. 
+ +ConnectorError +├── ConnectorAuthError (401/403 — non-retryable) +├── ConnectorRateLimitError (429 — retryable, carries ``retry_after``) +├── ConnectorTimeoutError (timeout/504 — retryable) +└── ConnectorAPIError (5xx or unexpected — retryable when >= 500) +""" + +from __future__ import annotations + +from typing import Any + + +class ConnectorError(Exception): + + def __init__( + self, + message: str, + *, + service: str = "", + status_code: int | None = None, + response_body: Any = None, + ) -> None: + super().__init__(message) + self.service = service + self.status_code = status_code + self.response_body = response_body + + @property + def retryable(self) -> bool: + return False + + +class ConnectorAuthError(ConnectorError): + """Token expired, revoked, insufficient scopes, or needs re-auth (401/403).""" + + @property + def retryable(self) -> bool: + return False + + +class ConnectorRateLimitError(ConnectorError): + """429 Too Many Requests.""" + + def __init__( + self, + message: str = "Rate limited", + *, + service: str = "", + retry_after: float | None = None, + status_code: int = 429, + response_body: Any = None, + ) -> None: + super().__init__( + message, + service=service, + status_code=status_code, + response_body=response_body, + ) + self.retry_after = retry_after + + @property + def retryable(self) -> bool: + return True + + +class ConnectorTimeoutError(ConnectorError): + """Request timeout or gateway timeout (504).""" + + def __init__( + self, + message: str = "Request timed out", + *, + service: str = "", + status_code: int | None = None, + response_body: Any = None, + ) -> None: + super().__init__( + message, + service=service, + status_code=status_code, + response_body=response_body, + ) + + @property + def retryable(self) -> bool: + return True + + +class ConnectorAPIError(ConnectorError): + """Generic API error (5xx or unexpected status codes).""" + + @property + def retryable(self) -> bool: + if self.status_code is not None: + return 
self.status_code >= 500 + return False diff --git a/surfsense_backend/app/routes/__init__.py b/surfsense_backend/app/routes/__init__.py index ad40666cd..40ca7a7e8 100644 --- a/surfsense_backend/app/routes/__init__.py +++ b/surfsense_backend/app/routes/__init__.py @@ -30,6 +30,7 @@ from .jira_add_connector_route import router as jira_add_connector_router from .linear_add_connector_route import router as linear_add_connector_router from .logs_routes import router as logs_router from .luma_add_connector_route import router as luma_add_connector_router +from .mcp_oauth_route import router as mcp_oauth_router from .memory_routes import router as memory_router from .model_list_routes import router as model_list_router from .new_chat_routes import router as new_chat_router @@ -95,6 +96,7 @@ router.include_router(logs_router) router.include_router(circleback_webhook_router) # Circleback meeting webhooks router.include_router(surfsense_docs_router) # Surfsense documentation for citations router.include_router(notifications_router) # Notifications with Zero sync +router.include_router(mcp_oauth_router) # MCP OAuth 2.1 for Linear, Jira, ClickUp, Slack, Airtable router.include_router(composio_router) # Composio OAuth and toolkit management router.include_router(public_chat_router) # Public chat sharing and cloning router.include_router(incentive_tasks_router) # Incentive tasks for earning free pages diff --git a/surfsense_backend/app/routes/airtable_add_connector_route.py b/surfsense_backend/app/routes/airtable_add_connector_route.py index 1e0b1eb5d..f70b9166b 100644 --- a/surfsense_backend/app/routes/airtable_add_connector_route.py +++ b/surfsense_backend/app/routes/airtable_add_connector_route.py @@ -311,7 +311,7 @@ async def airtable_callback( new_connector = SearchSourceConnector( name=connector_name, connector_type=SearchSourceConnectorType.AIRTABLE_CONNECTOR, - is_indexable=True, + is_indexable=False, config=credentials_dict, search_space_id=space_id, user_id=user_id, 
diff --git a/surfsense_backend/app/routes/clickup_add_connector_route.py b/surfsense_backend/app/routes/clickup_add_connector_route.py index 2cd63eca2..f7b0876e5 100644 --- a/surfsense_backend/app/routes/clickup_add_connector_route.py +++ b/surfsense_backend/app/routes/clickup_add_connector_route.py @@ -301,7 +301,7 @@ async def clickup_callback( # Update existing connector existing_connector.config = connector_config existing_connector.name = "ClickUp Connector" - existing_connector.is_indexable = True + existing_connector.is_indexable = False logger.info( f"Updated existing ClickUp connector for user {user_id} in space {space_id}" ) @@ -310,7 +310,7 @@ async def clickup_callback( new_connector = SearchSourceConnector( name="ClickUp Connector", connector_type=SearchSourceConnectorType.CLICKUP_CONNECTOR, - is_indexable=True, + is_indexable=False, config=connector_config, search_space_id=space_id, user_id=user_id, diff --git a/surfsense_backend/app/routes/discord_add_connector_route.py b/surfsense_backend/app/routes/discord_add_connector_route.py index 27bfffc90..4ab48f544 100644 --- a/surfsense_backend/app/routes/discord_add_connector_route.py +++ b/surfsense_backend/app/routes/discord_add_connector_route.py @@ -326,7 +326,7 @@ async def discord_callback( new_connector = SearchSourceConnector( name=connector_name, connector_type=SearchSourceConnectorType.DISCORD_CONNECTOR, - is_indexable=True, + is_indexable=False, config=connector_config, search_space_id=space_id, user_id=user_id, diff --git a/surfsense_backend/app/routes/google_calendar_add_connector_route.py b/surfsense_backend/app/routes/google_calendar_add_connector_route.py index d7ccf62ca..a143fd50d 100644 --- a/surfsense_backend/app/routes/google_calendar_add_connector_route.py +++ b/surfsense_backend/app/routes/google_calendar_add_connector_route.py @@ -340,7 +340,7 @@ async def calendar_callback( config=creds_dict, search_space_id=space_id, user_id=user_id, - is_indexable=True, + is_indexable=False, ) 
session.add(db_connector) await session.commit() diff --git a/surfsense_backend/app/routes/google_gmail_add_connector_route.py b/surfsense_backend/app/routes/google_gmail_add_connector_route.py index dd8feb1c7..9b807a556 100644 --- a/surfsense_backend/app/routes/google_gmail_add_connector_route.py +++ b/surfsense_backend/app/routes/google_gmail_add_connector_route.py @@ -371,7 +371,7 @@ async def gmail_callback( config=creds_dict, search_space_id=space_id, user_id=user_id, - is_indexable=True, + is_indexable=False, ) session.add(db_connector) await session.commit() diff --git a/surfsense_backend/app/routes/jira_add_connector_route.py b/surfsense_backend/app/routes/jira_add_connector_route.py index 6cd6283d7..eeb4f91d9 100644 --- a/surfsense_backend/app/routes/jira_add_connector_route.py +++ b/surfsense_backend/app/routes/jira_add_connector_route.py @@ -386,7 +386,7 @@ async def jira_callback( new_connector = SearchSourceConnector( name=connector_name, connector_type=SearchSourceConnectorType.JIRA_CONNECTOR, - is_indexable=True, + is_indexable=False, config=connector_config, search_space_id=space_id, user_id=user_id, diff --git a/surfsense_backend/app/routes/linear_add_connector_route.py b/surfsense_backend/app/routes/linear_add_connector_route.py index 9345ae495..f59c17d25 100644 --- a/surfsense_backend/app/routes/linear_add_connector_route.py +++ b/surfsense_backend/app/routes/linear_add_connector_route.py @@ -399,7 +399,7 @@ async def linear_callback( new_connector = SearchSourceConnector( name=connector_name, connector_type=SearchSourceConnectorType.LINEAR_CONNECTOR, - is_indexable=True, + is_indexable=False, config=connector_config, search_space_id=space_id, user_id=user_id, diff --git a/surfsense_backend/app/routes/luma_add_connector_route.py b/surfsense_backend/app/routes/luma_add_connector_route.py index 04d840a08..7040581bc 100644 --- a/surfsense_backend/app/routes/luma_add_connector_route.py +++ b/surfsense_backend/app/routes/luma_add_connector_route.py @@ 
-61,7 +61,7 @@ async def add_luma_connector( if existing_connector: # Update existing connector with new API key existing_connector.config = {"api_key": request.api_key} - existing_connector.is_indexable = True + existing_connector.is_indexable = False await session.commit() await session.refresh(existing_connector) @@ -82,7 +82,7 @@ async def add_luma_connector( config={"api_key": request.api_key}, search_space_id=request.space_id, user_id=user.id, - is_indexable=True, + is_indexable=False, ) session.add(db_connector) diff --git a/surfsense_backend/app/routes/mcp_oauth_route.py b/surfsense_backend/app/routes/mcp_oauth_route.py new file mode 100644 index 000000000..e14be83d0 --- /dev/null +++ b/surfsense_backend/app/routes/mcp_oauth_route.py @@ -0,0 +1,601 @@ +"""Generic MCP OAuth 2.1 route for services with official MCP servers. + +Handles the full flow: discovery → DCR → PKCE authorization → token exchange +→ MCP_CONNECTOR creation. Currently supports Linear, Jira, ClickUp, Slack, +and Airtable. 
async def _fetch_account_metadata(
    service_key: str, access_token: str, token_json: dict[str, Any],
) -> dict[str, Any]:
    """Fetch display-friendly account metadata after a successful token exchange.

    DCR services (Linear, Jira, ClickUp) issue MCP-scoped tokens that cannot
    call their standard REST/GraphQL APIs — metadata discovery for those
    happens at runtime through MCP tools instead.

    Pre-configured services (Slack, Airtable) use standard OAuth tokens that
    *can* call their APIs, so we extract metadata here.

    Failures are logged but never block connector creation.
    """
    from app.services.mcp_oauth.registry import MCP_SERVICES

    svc = MCP_SERVICES.get(service_key)
    if not svc or svc.supports_dcr:
        # Unknown service, or a DCR-issued token that cannot call vendor APIs.
        return {}

    import httpx

    meta: dict[str, Any] = {}

    try:
        if service_key == "slack":
            team = token_json.get("team", {})
            meta["team_id"] = team.get("id", "")
            # TODO: oauth.v2.user.access only returns team.id, not
            # team.name. To populate team_name, add "team:read" scope
            # and call GET /api/team.info here.
            meta["team_name"] = team.get("name", "")
            if meta["team_name"]:
                meta["display_name"] = meta["team_name"]
            elif meta["team_id"]:
                meta["display_name"] = f"Slack ({meta['team_id']})"

        elif service_key == "airtable":
            async with httpx.AsyncClient(timeout=15.0) as client:
                resp = await client.get(
                    "https://api.airtable.com/v0/meta/whoami",
                    headers={"Authorization": f"Bearer {access_token}"},
                )
            if resp.status_code == 200:
                whoami = resp.json()
                meta["user_id"] = whoami.get("id", "")
                meta["user_email"] = whoami.get("email", "")
                meta["display_name"] = whoami.get("email", "Airtable")
            else:
                logger.warning(
                    "Airtable whoami API returned %d (non-blocking)", resp.status_code,
                )

    except Exception:
        logger.warning(
            "Failed to fetch account metadata for %s (non-blocking)",
            service_key,
            exc_info=True,
        )

    return meta
f"success=true&connector={service}-mcp-connector" + if connector_id: + qs += f"&connectorId={connector_id}" + return RedirectResponse( + url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?{qs}" + ) + if error and space_id: + return RedirectResponse( + url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error={error}" + ) + return RedirectResponse(url=f"{config.NEXT_FRONTEND_URL}/dashboard") + + +# --------------------------------------------------------------------------- +# /add — start MCP OAuth flow +# --------------------------------------------------------------------------- + +@router.get("/auth/mcp/{service}/connector/add") +async def connect_mcp_service( + service: str, + space_id: int, + user: User = Depends(current_active_user), +): + from app.services.mcp_oauth.registry import get_service + + svc = get_service(service) + if not svc: + raise HTTPException(status_code=404, detail=f"Unknown MCP service: {service}") + + try: + from app.services.mcp_oauth.discovery import ( + discover_oauth_metadata, + register_client, + ) + + metadata = await discover_oauth_metadata( + svc.mcp_url, origin_override=svc.oauth_discovery_origin, + ) + auth_endpoint = svc.auth_endpoint_override or metadata.get("authorization_endpoint") + token_endpoint = svc.token_endpoint_override or metadata.get("token_endpoint") + registration_endpoint = metadata.get("registration_endpoint") + + if not auth_endpoint or not token_endpoint: + raise HTTPException( + status_code=502, + detail=f"{svc.name} MCP server returned incomplete OAuth metadata.", + ) + + redirect_uri = _build_redirect_uri(service) + + if svc.supports_dcr and registration_endpoint: + dcr = await register_client(registration_endpoint, redirect_uri) + client_id = dcr.get("client_id") + client_secret = dcr.get("client_secret", "") + if not client_id: + raise HTTPException( + status_code=502, + detail=f"DCR for {svc.name} did not return a client_id.", + ) + elif svc.client_id_env: + 
client_id = getattr(config, svc.client_id_env, None) + client_secret = getattr(config, svc.client_secret_env or "", None) or "" + if not client_id: + raise HTTPException( + status_code=500, + detail=f"{svc.name} MCP OAuth not configured ({svc.client_id_env}).", + ) + else: + raise HTTPException( + status_code=502, + detail=f"{svc.name} MCP server has no DCR and no fallback credentials.", + ) + + verifier, challenge = generate_pkce_pair() + enc = _get_token_encryption() + + state = _get_state_manager().generate_secure_state( + space_id, + user.id, + service=service, + code_verifier=verifier, + mcp_client_id=client_id, + mcp_client_secret=enc.encrypt_token(client_secret) if client_secret else "", + mcp_token_endpoint=token_endpoint, + mcp_url=svc.mcp_url, + ) + + auth_params: dict[str, str] = { + "client_id": client_id, + "response_type": "code", + "redirect_uri": redirect_uri, + "code_challenge": challenge, + "code_challenge_method": "S256", + "state": state, + } + if svc.scopes: + auth_params[svc.scope_param] = " ".join(svc.scopes) + + auth_url = f"{auth_endpoint}?{urlencode(auth_params)}" + + logger.info( + "Generated %s MCP OAuth URL for user %s, space %s", + svc.name, user.id, space_id, + ) + return {"auth_url": auth_url} + + except HTTPException: + raise + except Exception as e: + logger.error("Failed to initiate %s MCP OAuth: %s", service, e, exc_info=True) + raise HTTPException( + status_code=500, detail=f"Failed to initiate {service} MCP OAuth.", + ) from e + + +# --------------------------------------------------------------------------- +# /callback — handle OAuth redirect +# --------------------------------------------------------------------------- + +@router.get("/auth/mcp/{service}/connector/callback") +async def mcp_oauth_callback( + service: str, + code: str | None = None, + error: str | None = None, + state: str | None = None, + session: AsyncSession = Depends(get_async_session), +): + if error: + logger.warning("%s MCP OAuth error: %s", service, 
error) + space_id = None + if state: + try: + data = _get_state_manager().validate_state(state) + space_id = data.get("space_id") + except Exception: + pass + return _frontend_redirect( + space_id, error=f"{service}_mcp_oauth_denied", service=service, + ) + + if not code: + raise HTTPException(status_code=400, detail="Missing authorization code") + if not state: + raise HTTPException(status_code=400, detail="Missing state parameter") + + data = _get_state_manager().validate_state(state) + user_id = UUID(data["user_id"]) + space_id = data["space_id"] + svc_key = data.get("service", service) + + if svc_key != service: + raise HTTPException(status_code=400, detail="State/path service mismatch") + + from app.services.mcp_oauth.registry import get_service + + svc = get_service(svc_key) + if not svc: + raise HTTPException(status_code=404, detail=f"Unknown MCP service: {svc_key}") + + try: + from app.services.mcp_oauth.discovery import exchange_code_for_tokens + + enc = _get_token_encryption() + client_id = data["mcp_client_id"] + client_secret = ( + enc.decrypt_token(data["mcp_client_secret"]) + if data.get("mcp_client_secret") + else "" + ) + token_endpoint = data["mcp_token_endpoint"] + code_verifier = data["code_verifier"] + mcp_url = data["mcp_url"] + redirect_uri = _build_redirect_uri(service) + + token_json = await exchange_code_for_tokens( + token_endpoint=token_endpoint, + code=code, + redirect_uri=redirect_uri, + client_id=client_id, + client_secret=client_secret, + code_verifier=code_verifier, + ) + + access_token = token_json.get("access_token") + refresh_token = token_json.get("refresh_token") + expires_in = token_json.get("expires_in") + scope = token_json.get("scope") + + if not access_token and "authed_user" in token_json: + authed = token_json["authed_user"] + access_token = authed.get("access_token") + refresh_token = refresh_token or authed.get("refresh_token") + scope = scope or authed.get("scope") + expires_in = expires_in or authed.get("expires_in") 
+ + if not access_token: + raise HTTPException( + status_code=400, + detail=f"No access token received from {svc.name}.", + ) + + expires_at = None + if expires_in: + expires_at = datetime.now(UTC) + timedelta( + seconds=int(expires_in) + ) + + connector_config = { + "server_config": { + "transport": "streamable-http", + "url": mcp_url, + }, + "mcp_service": svc_key, + "mcp_oauth": { + "client_id": client_id, + "client_secret": enc.encrypt_token(client_secret) if client_secret else "", + "token_endpoint": token_endpoint, + "access_token": enc.encrypt_token(access_token), + "refresh_token": enc.encrypt_token(refresh_token) if refresh_token else None, + "expires_at": expires_at.isoformat() if expires_at else None, + "scope": scope, + }, + "_token_encrypted": True, + } + + account_meta = await _fetch_account_metadata(svc_key, access_token, token_json) + if account_meta: + _SAFE_META_KEYS = {"display_name", "team_id", "team_name", "user_id", "user_email", + "workspace_id", "workspace_name", "organization_name", + "organization_url_key", "cloud_id", "site_name", "base_url"} + for k, v in account_meta.items(): + if k in _SAFE_META_KEYS: + connector_config[k] = v + logger.info( + "Stored account metadata for %s: display_name=%s", + svc_key, account_meta.get("display_name", ""), + ) + + # ---- Re-auth path ---- + db_connector_type = SearchSourceConnectorType(svc.connector_type) + reauth_connector_id = data.get("connector_id") + if reauth_connector_id: + result = await session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == reauth_connector_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.search_space_id == space_id, + SearchSourceConnector.connector_type == db_connector_type, + ) + ) + db_connector = result.scalars().first() + if not db_connector: + raise HTTPException( + status_code=404, + detail="Connector not found during re-auth", + ) + + db_connector.config = connector_config + flag_modified(db_connector, "config") 
+ await session.commit() + await session.refresh(db_connector) + + _invalidate_cache(space_id) + + logger.info( + "Re-authenticated %s MCP connector %s for user %s", + svc.name, db_connector.id, user_id, + ) + reauth_return_url = data.get("return_url") + if reauth_return_url and reauth_return_url.startswith("/") and not reauth_return_url.startswith("//"): + return RedirectResponse( + url=f"{config.NEXT_FRONTEND_URL}{reauth_return_url}" + ) + return _frontend_redirect( + space_id, success=True, connector_id=db_connector.id, service=service, + ) + + # ---- New connector path ---- + naming_identifier = account_meta.get("display_name") + connector_name = await generate_unique_connector_name( + session, + db_connector_type, + space_id, + user_id, + naming_identifier, + ) + + new_connector = SearchSourceConnector( + name=connector_name, + connector_type=db_connector_type, + is_indexable=False, + config=connector_config, + search_space_id=space_id, + user_id=user_id, + ) + session.add(new_connector) + + try: + await session.commit() + except IntegrityError as e: + await session.rollback() + raise HTTPException( + status_code=409, detail="A connector for this service already exists.", + ) from e + + _invalidate_cache(space_id) + + logger.info( + "Created %s MCP connector %s for user %s in space %s", + svc.name, new_connector.id, user_id, space_id, + ) + return _frontend_redirect( + space_id, success=True, connector_id=new_connector.id, service=service, + ) + + except HTTPException: + raise + except Exception as e: + logger.error( + "Failed to complete %s MCP OAuth: %s", service, e, exc_info=True, + ) + raise HTTPException( + status_code=500, + detail=f"Failed to complete {service} MCP OAuth.", + ) from e + + +# --------------------------------------------------------------------------- +# /reauth — re-authenticate an existing MCP connector +# --------------------------------------------------------------------------- + +@router.get("/auth/mcp/{service}/connector/reauth") 
+async def reauth_mcp_service( + service: str, + space_id: int, + connector_id: int, + return_url: str | None = None, + user: User = Depends(current_active_user), + session: AsyncSession = Depends(get_async_session), +): + from app.services.mcp_oauth.registry import get_service + + svc = get_service(service) + if not svc: + raise HTTPException(status_code=404, detail=f"Unknown MCP service: {service}") + + db_connector_type = SearchSourceConnectorType(svc.connector_type) + result = await session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == connector_id, + SearchSourceConnector.user_id == user.id, + SearchSourceConnector.search_space_id == space_id, + SearchSourceConnector.connector_type == db_connector_type, + ) + ) + if not result.scalars().first(): + raise HTTPException( + status_code=404, detail="Connector not found or access denied", + ) + + try: + from app.services.mcp_oauth.discovery import ( + discover_oauth_metadata, + register_client, + ) + + metadata = await discover_oauth_metadata( + svc.mcp_url, origin_override=svc.oauth_discovery_origin, + ) + auth_endpoint = svc.auth_endpoint_override or metadata.get("authorization_endpoint") + token_endpoint = svc.token_endpoint_override or metadata.get("token_endpoint") + registration_endpoint = metadata.get("registration_endpoint") + + if not auth_endpoint or not token_endpoint: + raise HTTPException( + status_code=502, + detail=f"{svc.name} MCP server returned incomplete OAuth metadata.", + ) + + redirect_uri = _build_redirect_uri(service) + + if svc.supports_dcr and registration_endpoint: + dcr = await register_client(registration_endpoint, redirect_uri) + client_id = dcr.get("client_id") + client_secret = dcr.get("client_secret", "") + if not client_id: + raise HTTPException( + status_code=502, + detail=f"DCR for {svc.name} did not return a client_id.", + ) + elif svc.client_id_env: + client_id = getattr(config, svc.client_id_env, None) + client_secret = getattr(config, 
svc.client_secret_env or "", None) or ""
+            if not client_id:
+                raise HTTPException(
+                    status_code=500,
+                    detail=f"{svc.name} MCP OAuth not configured ({svc.client_id_env}).",
+                )
+        else:
+            raise HTTPException(
+                status_code=502,
+                detail=f"{svc.name} MCP server has no DCR and no fallback credentials.",
+            )
+
+        verifier, challenge = generate_pkce_pair()
+        enc = _get_token_encryption()
+
+        extra: dict = {
+            "service": service,
+            "code_verifier": verifier,
+            "mcp_client_id": client_id,
+            "mcp_client_secret": enc.encrypt_token(client_secret) if client_secret else "",
+            "mcp_token_endpoint": token_endpoint,
+            "mcp_url": svc.mcp_url,
+            "connector_id": connector_id,
+        }
+        if return_url and return_url.startswith("/") and not return_url.startswith("//"):
+            extra["return_url"] = return_url
+
+        state = _get_state_manager().generate_secure_state(
+            space_id, user.id, **extra,
+        )
+
+        auth_params: dict[str, str] = {
+            "client_id": client_id,
+            "response_type": "code",
+            "redirect_uri": redirect_uri,
+            "code_challenge": challenge,
+            "code_challenge_method": "S256",
+            "state": state,
+        }
+        if svc.scopes:
+            auth_params[svc.scope_param] = " ".join(svc.scopes)
+
+        auth_url = f"{auth_endpoint}?{urlencode(auth_params)}"
+
+        logger.info(
+            "Initiating %s MCP re-auth for user %s, connector %s",
+            svc.name, user.id, connector_id,
+        )
+        return {"auth_url": auth_url}
+
+    except HTTPException:
+        raise
+    except Exception as e:
+        logger.error(
+            "Failed to initiate %s MCP re-auth: %s", service, e, exc_info=True,
+        )
+        raise HTTPException(
+            status_code=500,
+            detail=f"Failed to initiate {service} MCP re-auth.",
+        ) from e
+
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
+def _invalidate_cache(space_id: int) -> None:
+    try:
+        from app.agents.new_chat.tools.mcp_tool import invalidate_mcp_tools_cache
+
+        invalidate_mcp_tools_cache(space_id)
+    except Exception:
+        logger.debug("MCP cache invalidation 
skipped", exc_info=True) diff --git a/surfsense_backend/app/routes/oauth_connector_base.py b/surfsense_backend/app/routes/oauth_connector_base.py new file mode 100644 index 000000000..0638e8f34 --- /dev/null +++ b/surfsense_backend/app/routes/oauth_connector_base.py @@ -0,0 +1,620 @@ +"""Reusable base for OAuth 2.0 connector routes. + +Subclasses override ``fetch_account_info``, ``build_connector_config``, +and ``get_connector_display_name`` to customise provider-specific behaviour. +Call ``build_router()`` to get a FastAPI ``APIRouter`` with ``/connector/add``, +``/connector/callback``, and ``/connector/reauth`` endpoints. +""" + +from __future__ import annotations + +import base64 +import logging +from datetime import UTC, datetime, timedelta +from typing import Any +from urllib.parse import urlencode +from uuid import UUID + +import httpx +from fastapi import APIRouter, Depends, HTTPException +from fastapi.responses import RedirectResponse +from sqlalchemy import select +from sqlalchemy.exc import IntegrityError +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm.attributes import flag_modified + +from app.config import config +from app.db import ( + SearchSourceConnector, + SearchSourceConnectorType, + User, + get_async_session, +) +from app.users import current_active_user +from app.utils.connector_naming import ( + check_duplicate_connector, + generate_unique_connector_name, +) +from app.utils.oauth_security import OAuthStateManager, TokenEncryption + +logger = logging.getLogger(__name__) + + +class OAuthConnectorRoute: + + def __init__( + self, + *, + provider_name: str, + connector_type: SearchSourceConnectorType, + authorize_url: str, + token_url: str, + client_id_env: str, + client_secret_env: str, + redirect_uri_env: str, + scopes: list[str], + auth_prefix: str, + use_pkce: bool = False, + token_auth_method: str = "body", + is_indexable: bool = True, + extra_auth_params: dict[str, str] | None = None, + ) -> None: + self.provider_name = 
provider_name + self.connector_type = connector_type + self.authorize_url = authorize_url + self.token_url = token_url + self.client_id_env = client_id_env + self.client_secret_env = client_secret_env + self.redirect_uri_env = redirect_uri_env + self.scopes = scopes + self.auth_prefix = auth_prefix.rstrip("/") + self.use_pkce = use_pkce + self.token_auth_method = token_auth_method + self.is_indexable = is_indexable + self.extra_auth_params = extra_auth_params or {} + + self._state_manager: OAuthStateManager | None = None + self._token_encryption: TokenEncryption | None = None + + def _get_client_id(self) -> str: + value = getattr(config, self.client_id_env, None) + if not value: + raise HTTPException( + status_code=500, + detail=f"{self.provider_name.title()} OAuth not configured " + f"({self.client_id_env} missing).", + ) + return value + + def _get_client_secret(self) -> str: + value = getattr(config, self.client_secret_env, None) + if not value: + raise HTTPException( + status_code=500, + detail=f"{self.provider_name.title()} OAuth not configured " + f"({self.client_secret_env} missing).", + ) + return value + + def _get_redirect_uri(self) -> str: + value = getattr(config, self.redirect_uri_env, None) + if not value: + raise HTTPException( + status_code=500, + detail=f"{self.redirect_uri_env} not configured.", + ) + return value + + def _get_state_manager(self) -> OAuthStateManager: + if self._state_manager is None: + if not config.SECRET_KEY: + raise HTTPException( + status_code=500, + detail="SECRET_KEY not configured for OAuth security.", + ) + self._state_manager = OAuthStateManager(config.SECRET_KEY) + return self._state_manager + + def _get_token_encryption(self) -> TokenEncryption: + if self._token_encryption is None: + if not config.SECRET_KEY: + raise HTTPException( + status_code=500, + detail="SECRET_KEY not configured for token encryption.", + ) + self._token_encryption = TokenEncryption(config.SECRET_KEY) + return self._token_encryption + + def 
_frontend_redirect( + self, + space_id: int | None, + *, + success: bool = False, + connector_id: int | None = None, + error: str | None = None, + ) -> RedirectResponse: + if success and space_id: + connector_slug = f"{self.provider_name}-connector" + qs = f"success=true&connector={connector_slug}" + if connector_id: + qs += f"&connectorId={connector_id}" + return RedirectResponse( + url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?{qs}" + ) + if error and space_id: + return RedirectResponse( + url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error={error}" + ) + if error: + return RedirectResponse( + url=f"{config.NEXT_FRONTEND_URL}/dashboard?error={error}" + ) + return RedirectResponse(url=f"{config.NEXT_FRONTEND_URL}/dashboard") + + async def fetch_account_info(self, access_token: str) -> dict[str, Any]: + """Override to fetch account/workspace info after token exchange. + + Return dict is merged into connector config; key ``"name"`` is used + for the display name and dedup. + """ + return {} + + def build_connector_config( + self, + token_json: dict[str, Any], + account_info: dict[str, Any], + encryption: TokenEncryption, + ) -> dict[str, Any]: + """Override for custom config shapes. 
Default: standard encrypted OAuth fields.""" + access_token = token_json.get("access_token", "") + refresh_token = token_json.get("refresh_token") + + expires_at = None + if token_json.get("expires_in"): + expires_at = datetime.now(UTC) + timedelta( + seconds=int(token_json["expires_in"]) + ) + + cfg: dict[str, Any] = { + "access_token": encryption.encrypt_token(access_token), + "refresh_token": ( + encryption.encrypt_token(refresh_token) if refresh_token else None + ), + "token_type": token_json.get("token_type", "Bearer"), + "expires_in": token_json.get("expires_in"), + "expires_at": expires_at.isoformat() if expires_at else None, + "scope": token_json.get("scope"), + "_token_encrypted": True, + } + cfg.update(account_info) + return cfg + + def get_connector_display_name(self, account_info: dict[str, Any]) -> str: + return str(account_info.get("name", self.provider_name.title())) + + async def on_token_refresh_failure( + self, + session: AsyncSession, + connector: SearchSourceConnector, + ) -> None: + try: + connector.config = {**connector.config, "auth_expired": True} + flag_modified(connector, "config") + await session.commit() + await session.refresh(connector) + except Exception: + logger.warning( + "Failed to persist auth_expired flag for connector %s", + connector.id, + exc_info=True, + ) + + async def _exchange_code( + self, code: str, extra_state: dict[str, Any] + ) -> dict[str, Any]: + client_id = self._get_client_id() + client_secret = self._get_client_secret() + redirect_uri = self._get_redirect_uri() + + headers: dict[str, str] = { + "Content-Type": "application/x-www-form-urlencoded", + } + body: dict[str, str] = { + "grant_type": "authorization_code", + "code": code, + "redirect_uri": redirect_uri, + } + + if self.token_auth_method == "basic": + creds = base64.b64encode(f"{client_id}:{client_secret}".encode()).decode() + headers["Authorization"] = f"Basic {creds}" + else: + body["client_id"] = client_id + body["client_secret"] = client_secret + + if 
self.use_pkce: + verifier = extra_state.get("code_verifier") + if verifier: + body["code_verifier"] = verifier + + async with httpx.AsyncClient() as client: + resp = await client.post( + self.token_url, data=body, headers=headers, timeout=30.0 + ) + + if resp.status_code != 200: + detail = resp.text + try: + detail = resp.json().get("error_description", detail) + except Exception: + pass + raise HTTPException( + status_code=400, detail=f"Token exchange failed: {detail}" + ) + + return resp.json() + + async def refresh_token( + self, session: AsyncSession, connector: SearchSourceConnector + ) -> SearchSourceConnector: + encryption = self._get_token_encryption() + is_encrypted = connector.config.get("_token_encrypted", False) + + refresh_tok = connector.config.get("refresh_token") + if is_encrypted and refresh_tok: + try: + refresh_tok = encryption.decrypt_token(refresh_tok) + except Exception as e: + logger.error("Failed to decrypt refresh token: %s", e) + raise HTTPException( + status_code=500, detail="Failed to decrypt stored refresh token" + ) from e + + if not refresh_tok: + await self.on_token_refresh_failure(session, connector) + raise HTTPException( + status_code=400, + detail="No refresh token available. 
Please re-authenticate.", + ) + + client_id = self._get_client_id() + client_secret = self._get_client_secret() + + headers: dict[str, str] = { + "Content-Type": "application/x-www-form-urlencoded", + } + body: dict[str, str] = { + "grant_type": "refresh_token", + "refresh_token": refresh_tok, + } + + if self.token_auth_method == "basic": + creds = base64.b64encode(f"{client_id}:{client_secret}".encode()).decode() + headers["Authorization"] = f"Basic {creds}" + else: + body["client_id"] = client_id + body["client_secret"] = client_secret + + async with httpx.AsyncClient() as client: + resp = await client.post( + self.token_url, data=body, headers=headers, timeout=30.0 + ) + + if resp.status_code != 200: + error_detail = resp.text + try: + ej = resp.json() + error_detail = ej.get("error_description", error_detail) + error_code = ej.get("error", "") + except Exception: + error_code = "" + combined = (error_detail + error_code).lower() + if any(kw in combined for kw in ("invalid_grant", "expired", "revoked")): + await self.on_token_refresh_failure(session, connector) + raise HTTPException( + status_code=401, + detail=f"{self.provider_name.title()} authentication failed. 
" + "Please re-authenticate.", + ) + raise HTTPException( + status_code=400, detail=f"Token refresh failed: {error_detail}" + ) + + token_json = resp.json() + new_access = token_json.get("access_token") + if not new_access: + raise HTTPException( + status_code=400, detail="No access token received from refresh" + ) + + expires_at = None + if token_json.get("expires_in"): + expires_at = datetime.now(UTC) + timedelta( + seconds=int(token_json["expires_in"]) + ) + + updated_config = dict(connector.config) + updated_config["access_token"] = encryption.encrypt_token(new_access) + new_refresh = token_json.get("refresh_token") + if new_refresh: + updated_config["refresh_token"] = encryption.encrypt_token(new_refresh) + updated_config["expires_in"] = token_json.get("expires_in") + updated_config["expires_at"] = expires_at.isoformat() if expires_at else None + updated_config["scope"] = token_json.get("scope", updated_config.get("scope")) + updated_config["_token_encrypted"] = True + updated_config.pop("auth_expired", None) + + connector.config = updated_config + flag_modified(connector, "config") + await session.commit() + await session.refresh(connector) + + logger.info( + "Refreshed %s token for connector %s", + self.provider_name, + connector.id, + ) + return connector + + def build_router(self) -> APIRouter: + router = APIRouter() + oauth = self + + @router.get(f"{oauth.auth_prefix}/connector/add") + async def connect( + space_id: int, + user: User = Depends(current_active_user), + ): + if not space_id: + raise HTTPException(status_code=400, detail="space_id is required") + + client_id = oauth._get_client_id() + state_mgr = oauth._get_state_manager() + + extra_state: dict[str, Any] = {} + auth_params: dict[str, str] = { + "client_id": client_id, + "response_type": "code", + "redirect_uri": oauth._get_redirect_uri(), + "scope": " ".join(oauth.scopes), + } + + if oauth.use_pkce: + from app.utils.oauth_security import generate_pkce_pair + + verifier, challenge = 
generate_pkce_pair() + extra_state["code_verifier"] = verifier + auth_params["code_challenge"] = challenge + auth_params["code_challenge_method"] = "S256" + + auth_params.update(oauth.extra_auth_params) + + state_encoded = state_mgr.generate_secure_state( + space_id, user.id, **extra_state + ) + auth_params["state"] = state_encoded + auth_url = f"{oauth.authorize_url}?{urlencode(auth_params)}" + + logger.info( + "Generated %s OAuth URL for user %s, space %s", + oauth.provider_name, + user.id, + space_id, + ) + return {"auth_url": auth_url} + + @router.get(f"{oauth.auth_prefix}/connector/reauth") + async def reauth( + space_id: int, + connector_id: int, + return_url: str | None = None, + user: User = Depends(current_active_user), + session: AsyncSession = Depends(get_async_session), + ): + result = await session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == connector_id, + SearchSourceConnector.user_id == user.id, + SearchSourceConnector.search_space_id == space_id, + SearchSourceConnector.connector_type == oauth.connector_type, + ) + ) + if not result.scalars().first(): + raise HTTPException( + status_code=404, + detail=f"{oauth.provider_name.title()} connector not found " + "or access denied", + ) + + client_id = oauth._get_client_id() + state_mgr = oauth._get_state_manager() + + extra: dict[str, Any] = {"connector_id": connector_id} + if return_url and return_url.startswith("/") and not return_url.startswith("//"): + extra["return_url"] = return_url + + auth_params: dict[str, str] = { + "client_id": client_id, + "response_type": "code", + "redirect_uri": oauth._get_redirect_uri(), + "scope": " ".join(oauth.scopes), + } + + if oauth.use_pkce: + from app.utils.oauth_security import generate_pkce_pair + + verifier, challenge = generate_pkce_pair() + extra["code_verifier"] = verifier + auth_params["code_challenge"] = challenge + auth_params["code_challenge_method"] = "S256" + + auth_params.update(oauth.extra_auth_params) + + 
state_encoded = state_mgr.generate_secure_state( + space_id, user.id, **extra + ) + auth_params["state"] = state_encoded + auth_url = f"{oauth.authorize_url}?{urlencode(auth_params)}" + + logger.info( + "Initiating %s re-auth for user %s, connector %s", + oauth.provider_name, + user.id, + connector_id, + ) + return {"auth_url": auth_url} + + @router.get(f"{oauth.auth_prefix}/connector/callback") + async def callback( + code: str | None = None, + error: str | None = None, + state: str | None = None, + session: AsyncSession = Depends(get_async_session), + ): + error_label = f"{oauth.provider_name}_oauth_denied" + + if error: + logger.warning("%s OAuth error: %s", oauth.provider_name, error) + space_id = None + if state: + try: + data = oauth._get_state_manager().validate_state(state) + space_id = data.get("space_id") + except Exception: + pass + return oauth._frontend_redirect(space_id, error=error_label) + + if not code: + raise HTTPException( + status_code=400, detail="Missing authorization code" + ) + if not state: + raise HTTPException( + status_code=400, detail="Missing state parameter" + ) + + state_mgr = oauth._get_state_manager() + try: + data = state_mgr.validate_state(state) + except Exception as e: + raise HTTPException( + status_code=400, detail="Invalid or expired state parameter." 
+ ) from e + + user_id = UUID(data["user_id"]) + space_id = data["space_id"] + + token_json = await oauth._exchange_code(code, data) + + access_token = token_json.get("access_token", "") + if not access_token: + raise HTTPException( + status_code=400, + detail=f"No access token received from {oauth.provider_name.title()}", + ) + + account_info = await oauth.fetch_account_info(access_token) + encryption = oauth._get_token_encryption() + connector_config = oauth.build_connector_config( + token_json, account_info, encryption + ) + + display_name = oauth.get_connector_display_name(account_info) + + # --- Re-auth path --- + reauth_connector_id = data.get("connector_id") + reauth_return_url = data.get("return_url") + + if reauth_connector_id: + result = await session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == reauth_connector_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.search_space_id == space_id, + SearchSourceConnector.connector_type == oauth.connector_type, + ) + ) + db_connector = result.scalars().first() + if not db_connector: + raise HTTPException( + status_code=404, + detail="Connector not found or access denied during re-auth", + ) + + db_connector.config = connector_config + flag_modified(db_connector, "config") + await session.commit() + await session.refresh(db_connector) + + logger.info( + "Re-authenticated %s connector %s for user %s", + oauth.provider_name, + db_connector.id, + user_id, + ) + if reauth_return_url and reauth_return_url.startswith("/") and not reauth_return_url.startswith("//"): + return RedirectResponse( + url=f"{config.NEXT_FRONTEND_URL}{reauth_return_url}" + ) + return oauth._frontend_redirect( + space_id, success=True, connector_id=db_connector.id + ) + + # --- New connector path --- + is_dup = await check_duplicate_connector( + session, + oauth.connector_type, + space_id, + user_id, + display_name, + ) + if is_dup: + logger.warning( + "Duplicate %s connector for user %s 
(%s)", + oauth.provider_name, + user_id, + display_name, + ) + return oauth._frontend_redirect( + space_id, + error=f"duplicate_account&connector={oauth.provider_name}-connector", + ) + + connector_name = await generate_unique_connector_name( + session, + oauth.connector_type, + space_id, + user_id, + display_name, + ) + + new_connector = SearchSourceConnector( + name=connector_name, + connector_type=oauth.connector_type, + is_indexable=oauth.is_indexable, + config=connector_config, + search_space_id=space_id, + user_id=user_id, + ) + session.add(new_connector) + + try: + await session.commit() + except IntegrityError as e: + await session.rollback() + raise HTTPException( + status_code=409, detail="A connector for this service already exists." + ) from e + + logger.info( + "Created %s connector %s for user %s in space %s", + oauth.provider_name, + new_connector.id, + user_id, + space_id, + ) + return oauth._frontend_redirect( + space_id, success=True, connector_id=new_connector.id + ) + + return router diff --git a/surfsense_backend/app/routes/search_source_connectors_routes.py b/surfsense_backend/app/routes/search_source_connectors_routes.py index b87ce28c9..989894003 100644 --- a/surfsense_backend/app/routes/search_source_connectors_routes.py +++ b/surfsense_backend/app/routes/search_source_connectors_routes.py @@ -693,27 +693,10 @@ async def index_connector_content( user: User = Depends(current_active_user), ): """ - Index content from a connector to a search space. - Requires CONNECTORS_UPDATE permission (to trigger indexing). + Index content from a KB connector to a search space. 
- Currently supports: - - SLACK_CONNECTOR: Indexes messages from all accessible Slack channels - - TEAMS_CONNECTOR: Indexes messages from all accessible Microsoft Teams channels - - NOTION_CONNECTOR: Indexes pages from all accessible Notion pages - - GITHUB_CONNECTOR: Indexes code and documentation from GitHub repositories - - LINEAR_CONNECTOR: Indexes issues and comments from Linear - - JIRA_CONNECTOR: Indexes issues and comments from Jira - - DISCORD_CONNECTOR: Indexes messages from all accessible Discord channels - - LUMA_CONNECTOR: Indexes events from Luma - - ELASTICSEARCH_CONNECTOR: Indexes documents from Elasticsearch - - WEBCRAWLER_CONNECTOR: Indexes web pages from crawled websites - - Args: - connector_id: ID of the connector to use - search_space_id: ID of the search space to store indexed content - - Returns: - Dictionary with indexing status + Live connectors (Slack, Teams, Linear, Jira, ClickUp, Calendar, Airtable, + Gmail, Discord, Luma) use real-time agent tools instead. """ try: # Get the connector first @@ -770,9 +753,7 @@ async def index_connector_content( # For calendar connectors, default to today but allow future dates if explicitly provided if connector.connector_type in [ - SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR, SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR, - SearchSourceConnectorType.LUMA_CONNECTOR, ]: # Default to today if no end_date provided (users can manually select future dates) indexing_to = today_str if end_date is None else end_date @@ -796,33 +777,22 @@ async def index_connector_content( # For non-calendar connectors, cap at today indexing_to = end_date if end_date else today_str - if connector.connector_type == SearchSourceConnectorType.SLACK_CONNECTOR: - from app.tasks.celery_tasks.connector_tasks import ( - index_slack_messages_task, - ) + from app.services.mcp_oauth.registry import LIVE_CONNECTOR_TYPES - logger.info( - f"Triggering Slack indexing for connector {connector_id} into search space 
{search_space_id} from {indexing_from} to {indexing_to}" - ) - index_slack_messages_task.delay( - connector_id, search_space_id, str(user.id), indexing_from, indexing_to - ) - response_message = "Slack indexing started in the background." + if connector.connector_type in LIVE_CONNECTOR_TYPES: + return { + "message": ( + f"{connector.connector_type.value} uses real-time agent tools; " + "background indexing is disabled." + ), + "indexing_started": False, + "connector_id": connector_id, + "search_space_id": search_space_id, + "indexing_from": indexing_from, + "indexing_to": indexing_to, + } - elif connector.connector_type == SearchSourceConnectorType.TEAMS_CONNECTOR: - from app.tasks.celery_tasks.connector_tasks import ( - index_teams_messages_task, - ) - - logger.info( - f"Triggering Teams indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}" - ) - index_teams_messages_task.delay( - connector_id, search_space_id, str(user.id), indexing_from, indexing_to - ) - response_message = "Teams indexing started in the background." - - elif connector.connector_type == SearchSourceConnectorType.NOTION_CONNECTOR: + if connector.connector_type == SearchSourceConnectorType.NOTION_CONNECTOR: from app.tasks.celery_tasks.connector_tasks import index_notion_pages_task logger.info( @@ -844,28 +814,6 @@ async def index_connector_content( ) response_message = "GitHub indexing started in the background." - elif connector.connector_type == SearchSourceConnectorType.LINEAR_CONNECTOR: - from app.tasks.celery_tasks.connector_tasks import index_linear_issues_task - - logger.info( - f"Triggering Linear indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}" - ) - index_linear_issues_task.delay( - connector_id, search_space_id, str(user.id), indexing_from, indexing_to - ) - response_message = "Linear indexing started in the background." 
- - elif connector.connector_type == SearchSourceConnectorType.JIRA_CONNECTOR: - from app.tasks.celery_tasks.connector_tasks import index_jira_issues_task - - logger.info( - f"Triggering Jira indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}" - ) - index_jira_issues_task.delay( - connector_id, search_space_id, str(user.id), indexing_from, indexing_to - ) - response_message = "Jira indexing started in the background." - elif connector.connector_type == SearchSourceConnectorType.CONFLUENCE_CONNECTOR: from app.tasks.celery_tasks.connector_tasks import ( index_confluence_pages_task, @@ -892,59 +840,6 @@ async def index_connector_content( ) response_message = "BookStack indexing started in the background." - elif connector.connector_type == SearchSourceConnectorType.CLICKUP_CONNECTOR: - from app.tasks.celery_tasks.connector_tasks import index_clickup_tasks_task - - logger.info( - f"Triggering ClickUp indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}" - ) - index_clickup_tasks_task.delay( - connector_id, search_space_id, str(user.id), indexing_from, indexing_to - ) - response_message = "ClickUp indexing started in the background." - - elif ( - connector.connector_type - == SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR - ): - from app.tasks.celery_tasks.connector_tasks import ( - index_google_calendar_events_task, - ) - - logger.info( - f"Triggering Google Calendar indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}" - ) - index_google_calendar_events_task.delay( - connector_id, search_space_id, str(user.id), indexing_from, indexing_to - ) - response_message = "Google Calendar indexing started in the background." 
- elif connector.connector_type == SearchSourceConnectorType.AIRTABLE_CONNECTOR: - from app.tasks.celery_tasks.connector_tasks import ( - index_airtable_records_task, - ) - - logger.info( - f"Triggering Airtable indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}" - ) - index_airtable_records_task.delay( - connector_id, search_space_id, str(user.id), indexing_from, indexing_to - ) - response_message = "Airtable indexing started in the background." - elif ( - connector.connector_type == SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR - ): - from app.tasks.celery_tasks.connector_tasks import ( - index_google_gmail_messages_task, - ) - - logger.info( - f"Triggering Google Gmail indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}" - ) - index_google_gmail_messages_task.delay( - connector_id, search_space_id, str(user.id), indexing_from, indexing_to - ) - response_message = "Google Gmail indexing started in the background." - elif ( connector.connector_type == SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR ): @@ -1089,30 +984,6 @@ async def index_connector_content( ) response_message = "Dropbox indexing started in the background." - elif connector.connector_type == SearchSourceConnectorType.DISCORD_CONNECTOR: - from app.tasks.celery_tasks.connector_tasks import ( - index_discord_messages_task, - ) - - logger.info( - f"Triggering Discord indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}" - ) - index_discord_messages_task.delay( - connector_id, search_space_id, str(user.id), indexing_from, indexing_to - ) - response_message = "Discord indexing started in the background." 
- - elif connector.connector_type == SearchSourceConnectorType.LUMA_CONNECTOR: - from app.tasks.celery_tasks.connector_tasks import index_luma_events_task - - logger.info( - f"Triggering Luma indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}" - ) - index_luma_events_task.delay( - connector_id, search_space_id, str(user.id), indexing_from, indexing_to - ) - response_message = "Luma indexing started in the background." - elif ( connector.connector_type == SearchSourceConnectorType.ELASTICSEARCH_CONNECTOR @@ -1338,57 +1209,6 @@ async def _update_connector_timestamp_by_id(session: AsyncSession, connector_id: await session.rollback() -async def run_slack_indexing_with_new_session( - connector_id: int, - search_space_id: int, - user_id: str, - start_date: str, - end_date: str, -): - """ - Create a new session and run the Slack indexing task. - This prevents session leaks by creating a dedicated session for the background task. - """ - async with async_session_maker() as session: - await run_slack_indexing( - session, connector_id, search_space_id, user_id, start_date, end_date - ) - - -async def run_slack_indexing( - session: AsyncSession, - connector_id: int, - search_space_id: int, - user_id: str, - start_date: str, - end_date: str, -): - """ - Background task to run Slack indexing. 
- - Args: - session: Database session - connector_id: ID of the Slack connector - search_space_id: ID of the search space - user_id: ID of the user - start_date: Start date for indexing - end_date: End date for indexing - """ - from app.tasks.connector_indexers import index_slack_messages - - await _run_indexing_with_notifications( - session=session, - connector_id=connector_id, - search_space_id=search_space_id, - user_id=user_id, - start_date=start_date, - end_date=end_date, - indexing_function=index_slack_messages, - update_timestamp_func=_update_connector_timestamp_by_id, - supports_heartbeat_callback=True, - ) - - _AUTH_ERROR_PATTERNS = ( "failed to refresh linear oauth", "failed to refresh your notion connection", @@ -1927,215 +1747,6 @@ async def run_github_indexing( ) -# Add new helper functions for Linear indexing -async def run_linear_indexing_with_new_session( - connector_id: int, - search_space_id: int, - user_id: str, - start_date: str, - end_date: str, -): - """Wrapper to run Linear indexing with its own database session.""" - logger.info( - f"Background task started: Indexing Linear connector {connector_id} into space {search_space_id} from {start_date} to {end_date}" - ) - async with async_session_maker() as session: - await run_linear_indexing( - session, connector_id, search_space_id, user_id, start_date, end_date - ) - logger.info(f"Background task finished: Indexing Linear connector {connector_id}") - - -async def run_linear_indexing( - session: AsyncSession, - connector_id: int, - search_space_id: int, - user_id: str, - start_date: str, - end_date: str, -): - """ - Background task to run Linear indexing. 
- - Args: - session: Database session - connector_id: ID of the Linear connector - search_space_id: ID of the search space - user_id: ID of the user - start_date: Start date for indexing - end_date: End date for indexing - """ - from app.tasks.connector_indexers import index_linear_issues - - await _run_indexing_with_notifications( - session=session, - connector_id=connector_id, - search_space_id=search_space_id, - user_id=user_id, - start_date=start_date, - end_date=end_date, - indexing_function=index_linear_issues, - update_timestamp_func=_update_connector_timestamp_by_id, - supports_heartbeat_callback=True, - ) - - -# Add new helper functions for discord indexing -async def run_discord_indexing_with_new_session( - connector_id: int, - search_space_id: int, - user_id: str, - start_date: str, - end_date: str, -): - """ - Create a new session and run the Discord indexing task. - This prevents session leaks by creating a dedicated session for the background task. - """ - async with async_session_maker() as session: - await run_discord_indexing( - session, connector_id, search_space_id, user_id, start_date, end_date - ) - - -async def run_discord_indexing( - session: AsyncSession, - connector_id: int, - search_space_id: int, - user_id: str, - start_date: str, - end_date: str, -): - """ - Background task to run Discord indexing. 
- - Args: - session: Database session - connector_id: ID of the Discord connector - search_space_id: ID of the search space - user_id: ID of the user - start_date: Start date for indexing - end_date: End date for indexing - """ - from app.tasks.connector_indexers import index_discord_messages - - await _run_indexing_with_notifications( - session=session, - connector_id=connector_id, - search_space_id=search_space_id, - user_id=user_id, - start_date=start_date, - end_date=end_date, - indexing_function=index_discord_messages, - update_timestamp_func=_update_connector_timestamp_by_id, - supports_heartbeat_callback=True, - ) - - -async def run_teams_indexing_with_new_session( - connector_id: int, - search_space_id: int, - user_id: str, - start_date: str, - end_date: str, -): - """ - Create a new session and run the Microsoft Teams indexing task. - This prevents session leaks by creating a dedicated session for the background task. - """ - async with async_session_maker() as session: - await run_teams_indexing( - session, connector_id, search_space_id, user_id, start_date, end_date - ) - - -async def run_teams_indexing( - session: AsyncSession, - connector_id: int, - search_space_id: int, - user_id: str, - start_date: str, - end_date: str, -): - """ - Background task to run Microsoft Teams indexing. 
- - Args: - session: Database session - connector_id: ID of the Teams connector - search_space_id: ID of the search space - user_id: ID of the user - start_date: Start date for indexing - end_date: End date for indexing - """ - from app.tasks.connector_indexers.teams_indexer import index_teams_messages - - await _run_indexing_with_notifications( - session=session, - connector_id=connector_id, - search_space_id=search_space_id, - user_id=user_id, - start_date=start_date, - end_date=end_date, - indexing_function=index_teams_messages, - update_timestamp_func=_update_connector_timestamp_by_id, - supports_heartbeat_callback=True, - ) - - -# Add new helper functions for Jira indexing -async def run_jira_indexing_with_new_session( - connector_id: int, - search_space_id: int, - user_id: str, - start_date: str, - end_date: str, -): - """Wrapper to run Jira indexing with its own database session.""" - logger.info( - f"Background task started: Indexing Jira connector {connector_id} into space {search_space_id} from {start_date} to {end_date}" - ) - async with async_session_maker() as session: - await run_jira_indexing( - session, connector_id, search_space_id, user_id, start_date, end_date - ) - logger.info(f"Background task finished: Indexing Jira connector {connector_id}") - - -async def run_jira_indexing( - session: AsyncSession, - connector_id: int, - search_space_id: int, - user_id: str, - start_date: str, - end_date: str, -): - """ - Background task to run Jira indexing. 
- - Args: - session: Database session - connector_id: ID of the Jira connector - search_space_id: ID of the search space - user_id: ID of the user - start_date: Start date for indexing - end_date: End date for indexing - """ - from app.tasks.connector_indexers import index_jira_issues - - await _run_indexing_with_notifications( - session=session, - connector_id=connector_id, - search_space_id=search_space_id, - user_id=user_id, - start_date=start_date, - end_date=end_date, - indexing_function=index_jira_issues, - update_timestamp_func=_update_connector_timestamp_by_id, - supports_heartbeat_callback=True, - ) - - # Add new helper functions for Confluence indexing async def run_confluence_indexing_with_new_session( connector_id: int, @@ -2191,112 +1802,6 @@ async def run_confluence_indexing( ) -# Add new helper functions for ClickUp indexing -async def run_clickup_indexing_with_new_session( - connector_id: int, - search_space_id: int, - user_id: str, - start_date: str, - end_date: str, -): - """Wrapper to run ClickUp indexing with its own database session.""" - logger.info( - f"Background task started: Indexing ClickUp connector {connector_id} into space {search_space_id} from {start_date} to {end_date}" - ) - async with async_session_maker() as session: - await run_clickup_indexing( - session, connector_id, search_space_id, user_id, start_date, end_date - ) - logger.info(f"Background task finished: Indexing ClickUp connector {connector_id}") - - -async def run_clickup_indexing( - session: AsyncSession, - connector_id: int, - search_space_id: int, - user_id: str, - start_date: str, - end_date: str, -): - """ - Background task to run ClickUp indexing. 
- - Args: - session: Database session - connector_id: ID of the ClickUp connector - search_space_id: ID of the search space - user_id: ID of the user - start_date: Start date for indexing - end_date: End date for indexing - """ - from app.tasks.connector_indexers import index_clickup_tasks - - await _run_indexing_with_notifications( - session=session, - connector_id=connector_id, - search_space_id=search_space_id, - user_id=user_id, - start_date=start_date, - end_date=end_date, - indexing_function=index_clickup_tasks, - update_timestamp_func=_update_connector_timestamp_by_id, - supports_heartbeat_callback=True, - ) - - -# Add new helper functions for Airtable indexing -async def run_airtable_indexing_with_new_session( - connector_id: int, - search_space_id: int, - user_id: str, - start_date: str, - end_date: str, -): - """Wrapper to run Airtable indexing with its own database session.""" - logger.info( - f"Background task started: Indexing Airtable connector {connector_id} into space {search_space_id} from {start_date} to {end_date}" - ) - async with async_session_maker() as session: - await run_airtable_indexing( - session, connector_id, search_space_id, user_id, start_date, end_date - ) - logger.info(f"Background task finished: Indexing Airtable connector {connector_id}") - - -async def run_airtable_indexing( - session: AsyncSession, - connector_id: int, - search_space_id: int, - user_id: str, - start_date: str, - end_date: str, -): - """ - Background task to run Airtable indexing. 
- - Args: - session: Database session - connector_id: ID of the Airtable connector - search_space_id: ID of the search space - user_id: ID of the user - start_date: Start date for indexing - end_date: End date for indexing - """ - from app.tasks.connector_indexers import index_airtable_records - - await _run_indexing_with_notifications( - session=session, - connector_id=connector_id, - search_space_id=search_space_id, - user_id=user_id, - start_date=start_date, - end_date=end_date, - indexing_function=index_airtable_records, - update_timestamp_func=_update_connector_timestamp_by_id, - supports_heartbeat_callback=True, - ) - - # Add new helper functions for Google Calendar indexing async def run_google_calendar_indexing_with_new_session( connector_id: int, @@ -2835,58 +2340,6 @@ async def run_dropbox_indexing( logger.error(f"Failed to update notification: {notif_error!s}") -# Add new helper functions for luma indexing -async def run_luma_indexing_with_new_session( - connector_id: int, - search_space_id: int, - user_id: str, - start_date: str, - end_date: str, -): - """ - Create a new session and run the Luma indexing task. - This prevents session leaks by creating a dedicated session for the background task. - """ - async with async_session_maker() as session: - await run_luma_indexing( - session, connector_id, search_space_id, user_id, start_date, end_date - ) - - -async def run_luma_indexing( - session: AsyncSession, - connector_id: int, - search_space_id: int, - user_id: str, - start_date: str, - end_date: str, -): - """ - Background task to run Luma indexing. 
- - Args: - session: Database session - connector_id: ID of the Luma connector - search_space_id: ID of the search space - user_id: ID of the user - start_date: Start date for indexing - end_date: End date for indexing - """ - from app.tasks.connector_indexers import index_luma_events - - await _run_indexing_with_notifications( - session=session, - connector_id=connector_id, - search_space_id=search_space_id, - user_id=user_id, - start_date=start_date, - end_date=end_date, - indexing_function=index_luma_events, - update_timestamp_func=_update_connector_timestamp_by_id, - supports_heartbeat_callback=True, - ) - - async def run_elasticsearch_indexing_with_new_session( connector_id: int, search_space_id: int, diff --git a/surfsense_backend/app/routes/slack_add_connector_route.py b/surfsense_backend/app/routes/slack_add_connector_route.py index 405ab2c4f..f6a1458a0 100644 --- a/surfsense_backend/app/routes/slack_add_connector_route.py +++ b/surfsense_backend/app/routes/slack_add_connector_route.py @@ -312,7 +312,7 @@ async def slack_callback( new_connector = SearchSourceConnector( name=connector_name, connector_type=SearchSourceConnectorType.SLACK_CONNECTOR, - is_indexable=True, + is_indexable=False, config=connector_config, search_space_id=space_id, user_id=user_id, diff --git a/surfsense_backend/app/routes/teams_add_connector_route.py b/surfsense_backend/app/routes/teams_add_connector_route.py index 4442307ba..9d0f5144f 100644 --- a/surfsense_backend/app/routes/teams_add_connector_route.py +++ b/surfsense_backend/app/routes/teams_add_connector_route.py @@ -45,6 +45,7 @@ SCOPES = [ "Team.ReadBasic.All", # Read basic team information "Channel.ReadBasic.All", # Read basic channel information "ChannelMessage.Read.All", # Read messages in channels + "ChannelMessage.Send", # Send messages in channels ] # Initialize security utilities @@ -320,7 +321,7 @@ async def teams_callback( new_connector = SearchSourceConnector( name=connector_name, 
connector_type=SearchSourceConnectorType.TEAMS_CONNECTOR, - is_indexable=True, + is_indexable=False, config=connector_config, search_space_id=space_id, user_id=user_id, diff --git a/surfsense_backend/app/services/composio_service.py b/surfsense_backend/app/services/composio_service.py index 13fe37832..a8abe4aa8 100644 --- a/surfsense_backend/app/services/composio_service.py +++ b/surfsense_backend/app/services/composio_service.py @@ -26,7 +26,7 @@ COMPOSIO_TOOLKIT_NAMES = { } # Toolkits that support indexing (Phase 1: Google services only) -INDEXABLE_TOOLKITS = {"googledrive", "gmail", "googlecalendar"} +INDEXABLE_TOOLKITS = {"googledrive"} # Mapping of toolkit IDs to connector types TOOLKIT_TO_CONNECTOR_TYPE = { diff --git a/surfsense_backend/app/services/llm_router_service.py b/surfsense_backend/app/services/llm_router_service.py index c9eeff01b..4bce79a43 100644 --- a/surfsense_backend/app/services/llm_router_service.py +++ b/surfsense_backend/app/services/llm_router_service.py @@ -290,6 +290,12 @@ class LLMRouterService: instance._router = Router(**router_kwargs) instance._initialized = True + + global _cached_context_profile, _cached_context_profile_computed + _cached_context_profile = None + _cached_context_profile_computed = False + _router_instance_cache.clear() + logger.info( "LLM Router initialized with %d deployments, " "strategy: %s, context_window_fallbacks: %s, fallbacks: %s", diff --git a/surfsense_backend/app/services/mcp_oauth/__init__.py b/surfsense_backend/app/services/mcp_oauth/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/surfsense_backend/app/services/mcp_oauth/discovery.py b/surfsense_backend/app/services/mcp_oauth/discovery.py new file mode 100644 index 000000000..b0f3fef2a --- /dev/null +++ b/surfsense_backend/app/services/mcp_oauth/discovery.py @@ -0,0 +1,121 @@ +"""MCP OAuth 2.1 metadata discovery, Dynamic Client Registration, and token exchange.""" + +from __future__ import annotations + +import base64 +import 
logging +from urllib.parse import urlparse + +import httpx + +logger = logging.getLogger(__name__) + + +async def discover_oauth_metadata( + mcp_url: str, + *, + origin_override: str | None = None, + timeout: float = 15.0, +) -> dict: + """Fetch OAuth 2.1 metadata from the MCP server's well-known endpoint. + + Per the MCP spec the discovery document lives at the *origin* of the + MCP server URL. ``origin_override`` can be used when the OAuth server + lives on a different domain (e.g. Airtable: MCP at ``mcp.airtable.com``, + OAuth at ``airtable.com``). + """ + if origin_override: + origin = origin_override.rstrip("/") + else: + parsed = urlparse(mcp_url) + origin = f"{parsed.scheme}://{parsed.netloc}" + discovery_url = f"{origin}/.well-known/oauth-authorization-server" + + async with httpx.AsyncClient(follow_redirects=True) as client: + resp = await client.get(discovery_url, timeout=timeout) + resp.raise_for_status() + return resp.json() + + +async def register_client( + registration_endpoint: str, + redirect_uri: str, + *, + client_name: str = "SurfSense", + timeout: float = 15.0, +) -> dict: + """Perform Dynamic Client Registration (RFC 7591).""" + payload = { + "client_name": client_name, + "redirect_uris": [redirect_uri], + "grant_types": ["authorization_code", "refresh_token"], + "response_types": ["code"], + "token_endpoint_auth_method": "client_secret_basic", + } + + async with httpx.AsyncClient(follow_redirects=True) as client: + resp = await client.post( + registration_endpoint, json=payload, timeout=timeout, + ) + resp.raise_for_status() + return resp.json() + + +async def exchange_code_for_tokens( + token_endpoint: str, + code: str, + redirect_uri: str, + client_id: str, + client_secret: str, + code_verifier: str, + *, + timeout: float = 30.0, +) -> dict: + """Exchange an authorization code for access + refresh tokens.""" + creds = base64.b64encode(f"{client_id}:{client_secret}".encode()).decode() + + async with httpx.AsyncClient(follow_redirects=True) as 
client: + resp = await client.post( + token_endpoint, + data={ + "grant_type": "authorization_code", + "code": code, + "redirect_uri": redirect_uri, + "code_verifier": code_verifier, + }, + headers={ + "Content-Type": "application/x-www-form-urlencoded", + "Authorization": f"Basic {creds}", + }, + timeout=timeout, + ) + resp.raise_for_status() + return resp.json() + + +async def refresh_access_token( + token_endpoint: str, + refresh_token: str, + client_id: str, + client_secret: str, + *, + timeout: float = 30.0, +) -> dict: + """Refresh an expired access token.""" + creds = base64.b64encode(f"{client_id}:{client_secret}".encode()).decode() + + async with httpx.AsyncClient(follow_redirects=True) as client: + resp = await client.post( + token_endpoint, + data={ + "grant_type": "refresh_token", + "refresh_token": refresh_token, + }, + headers={ + "Content-Type": "application/x-www-form-urlencoded", + "Authorization": f"Basic {creds}", + }, + timeout=timeout, + ) + resp.raise_for_status() + return resp.json() diff --git a/surfsense_backend/app/services/mcp_oauth/registry.py b/surfsense_backend/app/services/mcp_oauth/registry.py new file mode 100644 index 000000000..49bc74d3d --- /dev/null +++ b/surfsense_backend/app/services/mcp_oauth/registry.py @@ -0,0 +1,161 @@ +"""Registry of MCP services with OAuth support. + +Each entry maps a URL-safe service key to its MCP server endpoint and +authentication configuration. Services with ``supports_dcr=True`` use +RFC 7591 Dynamic Client Registration (the MCP server issues its own +credentials); the rest use pre-configured credentials via env vars. + +``allowed_tools`` whitelists which MCP tools to expose to the agent. +An empty list means "load every tool the server advertises" (used for +user-managed generic MCP servers). Service-specific entries should +curate this list to keep the agent's tool count low and selection +accuracy high. 
+""" + +from __future__ import annotations + +from dataclasses import dataclass, field + +from app.db import SearchSourceConnectorType + + +@dataclass(frozen=True) +class MCPServiceConfig: + name: str + mcp_url: str + connector_type: str + supports_dcr: bool = True + oauth_discovery_origin: str | None = None + client_id_env: str | None = None + client_secret_env: str | None = None + scopes: list[str] = field(default_factory=list) + scope_param: str = "scope" + auth_endpoint_override: str | None = None + token_endpoint_override: str | None = None + allowed_tools: list[str] = field(default_factory=list) + readonly_tools: frozenset[str] = field(default_factory=frozenset) + account_metadata_keys: list[str] = field(default_factory=list) + """``connector.config`` keys exposed by ``get_connected_accounts``. + + Only listed keys are returned to the LLM — tokens and secrets are + never included. Every service should at least have its + ``display_name`` populated during OAuth; additional service-specific + fields (e.g. Jira ``cloud_id``) are listed here so the LLM can pass + them to action tools. 
+ """ + + +MCP_SERVICES: dict[str, MCPServiceConfig] = { + "linear": MCPServiceConfig( + name="Linear", + mcp_url="https://mcp.linear.app/mcp", + connector_type="LINEAR_CONNECTOR", + allowed_tools=[ + "list_issues", + "get_issue", + "save_issue", + ], + readonly_tools=frozenset({"list_issues", "get_issue"}), + account_metadata_keys=["organization_name", "organization_url_key"], + ), + "jira": MCPServiceConfig( + name="Jira", + mcp_url="https://mcp.atlassian.com/v1/mcp", + connector_type="JIRA_CONNECTOR", + allowed_tools=[ + "getAccessibleAtlassianResources", + "searchJiraIssuesUsingJql", + "getVisibleJiraProjects", + "getJiraProjectIssueTypesMetadata", + "createJiraIssue", + "editJiraIssue", + ], + readonly_tools=frozenset({ + "getAccessibleAtlassianResources", + "searchJiraIssuesUsingJql", + "getVisibleJiraProjects", + "getJiraProjectIssueTypesMetadata", + }), + account_metadata_keys=["cloud_id", "site_name", "base_url"], + ), + "clickup": MCPServiceConfig( + name="ClickUp", + mcp_url="https://mcp.clickup.com/mcp", + connector_type="CLICKUP_CONNECTOR", + allowed_tools=[ + "clickup_search", + "clickup_get_task", + ], + readonly_tools=frozenset({"clickup_search", "clickup_get_task"}), + account_metadata_keys=["workspace_id", "workspace_name"], + ), + "slack": MCPServiceConfig( + name="Slack", + mcp_url="https://mcp.slack.com/mcp", + connector_type="SLACK_CONNECTOR", + supports_dcr=False, + client_id_env="SLACK_CLIENT_ID", + client_secret_env="SLACK_CLIENT_SECRET", + auth_endpoint_override="https://slack.com/oauth/v2_user/authorize", + token_endpoint_override="https://slack.com/api/oauth.v2.user.access", + scopes=[ + "search:read.public", "search:read.private", "search:read.mpim", "search:read.im", + "channels:history", "groups:history", "mpim:history", "im:history", + ], + allowed_tools=[ + "slack_search_channels", + "slack_read_channel", + "slack_read_thread", + ], + readonly_tools=frozenset({"slack_search_channels", "slack_read_channel", "slack_read_thread"}), + # 
TODO: oauth.v2.user.access only returns team.id, not team.name. + # To populate team_name, either add "team:read" scope and call + # GET /api/team.info during OAuth callback, or switch to oauth.v2.access. + account_metadata_keys=["team_id", "team_name"], + ), + "airtable": MCPServiceConfig( + name="Airtable", + mcp_url="https://mcp.airtable.com/mcp", + connector_type="AIRTABLE_CONNECTOR", + supports_dcr=False, + oauth_discovery_origin="https://airtable.com", + client_id_env="AIRTABLE_CLIENT_ID", + client_secret_env="AIRTABLE_CLIENT_SECRET", + scopes=["data.records:read", "schema.bases:read"], + allowed_tools=[ + "list_bases", + "list_tables_for_base", + "list_records_for_table", + ], + readonly_tools=frozenset({"list_bases", "list_tables_for_base", "list_records_for_table"}), + account_metadata_keys=["user_id", "user_email"], + ), +} + +_CONNECTOR_TYPE_TO_SERVICE: dict[str, MCPServiceConfig] = { + svc.connector_type: svc for svc in MCP_SERVICES.values() +} + +LIVE_CONNECTOR_TYPES: frozenset[SearchSourceConnectorType] = frozenset({ + SearchSourceConnectorType.SLACK_CONNECTOR, + SearchSourceConnectorType.TEAMS_CONNECTOR, + SearchSourceConnectorType.LINEAR_CONNECTOR, + SearchSourceConnectorType.JIRA_CONNECTOR, + SearchSourceConnectorType.CLICKUP_CONNECTOR, + SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR, + SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR, + SearchSourceConnectorType.AIRTABLE_CONNECTOR, + SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR, + SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR, + SearchSourceConnectorType.DISCORD_CONNECTOR, + SearchSourceConnectorType.LUMA_CONNECTOR, +}) + + +def get_service(key: str) -> MCPServiceConfig | None: + return MCP_SERVICES.get(key) + + +def get_service_by_connector_type(connector_type: str) -> MCPServiceConfig | None: + """Look up an MCP service config by its ``connector_type`` enum value.""" + return _CONNECTOR_TYPE_TO_SERVICE.get(connector_type) diff --git 
a/surfsense_backend/app/services/notion/tool_metadata_service.py b/surfsense_backend/app/services/notion/tool_metadata_service.py index 097ef3461..19dc1fd89 100644 --- a/surfsense_backend/app/services/notion/tool_metadata_service.py +++ b/surfsense_backend/app/services/notion/tool_metadata_service.py @@ -227,8 +227,6 @@ class NotionToolMetadataService: async def _check_account_health(self, connector_id: int) -> bool: """Check if a Notion connector's token is still valid. - Uses a lightweight ``users.me()`` call to verify the token. - Returns True if the token is expired/invalid, False if healthy. """ try: diff --git a/surfsense_backend/app/tasks/celery_tasks/connector_tasks.py b/surfsense_backend/app/tasks/celery_tasks/connector_tasks.py index 57475c9fd..141d5ffca 100644 --- a/surfsense_backend/app/tasks/celery_tasks/connector_tasks.py +++ b/surfsense_backend/app/tasks/celery_tasks/connector_tasks.py @@ -39,52 +39,6 @@ def _handle_greenlet_error(e: Exception, task_name: str, connector_id: int) -> N ) -@celery_app.task(name="index_slack_messages", bind=True) -def index_slack_messages_task( - self, - connector_id: int, - search_space_id: int, - user_id: str, - start_date: str, - end_date: str, -): - """Celery task to index Slack messages.""" - import asyncio - - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - loop.run_until_complete( - _index_slack_messages( - connector_id, search_space_id, user_id, start_date, end_date - ) - ) - except Exception as e: - _handle_greenlet_error(e, "index_slack_messages", connector_id) - raise - finally: - loop.close() - - -async def _index_slack_messages( - connector_id: int, - search_space_id: int, - user_id: str, - start_date: str, - end_date: str, -): - """Index Slack messages with new session.""" - from app.routes.search_source_connectors_routes import ( - run_slack_indexing, - ) - - async with get_celery_session_maker()() as session: - await run_slack_indexing( - session, connector_id, search_space_id, 
user_id, start_date, end_date - ) - - @celery_app.task(name="index_notion_pages", bind=True) def index_notion_pages_task( self, @@ -174,92 +128,6 @@ async def _index_github_repos( ) -@celery_app.task(name="index_linear_issues", bind=True) -def index_linear_issues_task( - self, - connector_id: int, - search_space_id: int, - user_id: str, - start_date: str, - end_date: str, -): - """Celery task to index Linear issues.""" - import asyncio - - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - loop.run_until_complete( - _index_linear_issues( - connector_id, search_space_id, user_id, start_date, end_date - ) - ) - finally: - loop.close() - - -async def _index_linear_issues( - connector_id: int, - search_space_id: int, - user_id: str, - start_date: str, - end_date: str, -): - """Index Linear issues with new session.""" - from app.routes.search_source_connectors_routes import ( - run_linear_indexing, - ) - - async with get_celery_session_maker()() as session: - await run_linear_indexing( - session, connector_id, search_space_id, user_id, start_date, end_date - ) - - -@celery_app.task(name="index_jira_issues", bind=True) -def index_jira_issues_task( - self, - connector_id: int, - search_space_id: int, - user_id: str, - start_date: str, - end_date: str, -): - """Celery task to index Jira issues.""" - import asyncio - - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - loop.run_until_complete( - _index_jira_issues( - connector_id, search_space_id, user_id, start_date, end_date - ) - ) - finally: - loop.close() - - -async def _index_jira_issues( - connector_id: int, - search_space_id: int, - user_id: str, - start_date: str, - end_date: str, -): - """Index Jira issues with new session.""" - from app.routes.search_source_connectors_routes import ( - run_jira_indexing, - ) - - async with get_celery_session_maker()() as session: - await run_jira_indexing( - session, connector_id, search_space_id, user_id, start_date, end_date - ) - - 
@celery_app.task(name="index_confluence_pages", bind=True) def index_confluence_pages_task( self, @@ -303,49 +171,6 @@ async def _index_confluence_pages( ) -@celery_app.task(name="index_clickup_tasks", bind=True) -def index_clickup_tasks_task( - self, - connector_id: int, - search_space_id: int, - user_id: str, - start_date: str, - end_date: str, -): - """Celery task to index ClickUp tasks.""" - import asyncio - - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - loop.run_until_complete( - _index_clickup_tasks( - connector_id, search_space_id, user_id, start_date, end_date - ) - ) - finally: - loop.close() - - -async def _index_clickup_tasks( - connector_id: int, - search_space_id: int, - user_id: str, - start_date: str, - end_date: str, -): - """Index ClickUp tasks with new session.""" - from app.routes.search_source_connectors_routes import ( - run_clickup_indexing, - ) - - async with get_celery_session_maker()() as session: - await run_clickup_indexing( - session, connector_id, search_space_id, user_id, start_date, end_date - ) - - @celery_app.task(name="index_google_calendar_events", bind=True) def index_google_calendar_events_task( self, @@ -392,49 +217,6 @@ async def _index_google_calendar_events( ) -@celery_app.task(name="index_airtable_records", bind=True) -def index_airtable_records_task( - self, - connector_id: int, - search_space_id: int, - user_id: str, - start_date: str, - end_date: str, -): - """Celery task to index Airtable records.""" - import asyncio - - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - loop.run_until_complete( - _index_airtable_records( - connector_id, search_space_id, user_id, start_date, end_date - ) - ) - finally: - loop.close() - - -async def _index_airtable_records( - connector_id: int, - search_space_id: int, - user_id: str, - start_date: str, - end_date: str, -): - """Index Airtable records with new session.""" - from app.routes.search_source_connectors_routes import ( - 
run_airtable_indexing, - ) - - async with get_celery_session_maker()() as session: - await run_airtable_indexing( - session, connector_id, search_space_id, user_id, start_date, end_date - ) - - @celery_app.task(name="index_google_gmail_messages", bind=True) def index_google_gmail_messages_task( self, @@ -622,135 +404,6 @@ async def _index_dropbox_files( ) -@celery_app.task(name="index_discord_messages", bind=True) -def index_discord_messages_task( - self, - connector_id: int, - search_space_id: int, - user_id: str, - start_date: str, - end_date: str, -): - """Celery task to index Discord messages.""" - import asyncio - - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - loop.run_until_complete( - _index_discord_messages( - connector_id, search_space_id, user_id, start_date, end_date - ) - ) - finally: - loop.close() - - -async def _index_discord_messages( - connector_id: int, - search_space_id: int, - user_id: str, - start_date: str, - end_date: str, -): - """Index Discord messages with new session.""" - from app.routes.search_source_connectors_routes import ( - run_discord_indexing, - ) - - async with get_celery_session_maker()() as session: - await run_discord_indexing( - session, connector_id, search_space_id, user_id, start_date, end_date - ) - - -@celery_app.task(name="index_teams_messages", bind=True) -def index_teams_messages_task( - self, - connector_id: int, - search_space_id: int, - user_id: str, - start_date: str, - end_date: str, -): - """Celery task to index Microsoft Teams messages.""" - import asyncio - - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - loop.run_until_complete( - _index_teams_messages( - connector_id, search_space_id, user_id, start_date, end_date - ) - ) - finally: - loop.close() - - -async def _index_teams_messages( - connector_id: int, - search_space_id: int, - user_id: str, - start_date: str, - end_date: str, -): - """Index Microsoft Teams messages with new session.""" - from 
app.routes.search_source_connectors_routes import ( - run_teams_indexing, - ) - - async with get_celery_session_maker()() as session: - await run_teams_indexing( - session, connector_id, search_space_id, user_id, start_date, end_date - ) - - -@celery_app.task(name="index_luma_events", bind=True) -def index_luma_events_task( - self, - connector_id: int, - search_space_id: int, - user_id: str, - start_date: str, - end_date: str, -): - """Celery task to index Luma events.""" - import asyncio - - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - loop.run_until_complete( - _index_luma_events( - connector_id, search_space_id, user_id, start_date, end_date - ) - ) - finally: - loop.close() - - -async def _index_luma_events( - connector_id: int, - search_space_id: int, - user_id: str, - start_date: str, - end_date: str, -): - """Index Luma events with new session.""" - from app.routes.search_source_connectors_routes import ( - run_luma_indexing, - ) - - async with get_celery_session_maker()() as session: - await run_luma_indexing( - session, connector_id, search_space_id, user_id, start_date, end_date - ) - - @celery_app.task(name="index_elasticsearch_documents", bind=True) def index_elasticsearch_documents_task( self, diff --git a/surfsense_backend/app/tasks/celery_tasks/schedule_checker_task.py b/surfsense_backend/app/tasks/celery_tasks/schedule_checker_task.py index e6890b0a8..373f04b48 100644 --- a/surfsense_backend/app/tasks/celery_tasks/schedule_checker_task.py +++ b/surfsense_backend/app/tasks/celery_tasks/schedule_checker_task.py @@ -51,50 +51,51 @@ async def _check_and_trigger_schedules(): logger.info(f"Found {len(due_connectors)} connectors due for indexing") - # Import all indexing tasks + # Import indexing tasks for KB connectors only. + # Live connectors (Linear, Slack, Jira, ClickUp, Airtable, Discord, + # Teams, Gmail, Calendar, Luma) use real-time tools instead. 
from app.tasks.celery_tasks.connector_tasks import ( - index_airtable_records_task, - index_clickup_tasks_task, index_confluence_pages_task, index_crawled_urls_task, - index_discord_messages_task, index_elasticsearch_documents_task, index_github_repos_task, - index_google_calendar_events_task, index_google_drive_files_task, - index_google_gmail_messages_task, - index_jira_issues_task, - index_linear_issues_task, - index_luma_events_task, index_notion_pages_task, - index_slack_messages_task, ) - # Map connector types to their tasks task_map = { - SearchSourceConnectorType.SLACK_CONNECTOR: index_slack_messages_task, SearchSourceConnectorType.NOTION_CONNECTOR: index_notion_pages_task, SearchSourceConnectorType.GITHUB_CONNECTOR: index_github_repos_task, - SearchSourceConnectorType.LINEAR_CONNECTOR: index_linear_issues_task, - SearchSourceConnectorType.JIRA_CONNECTOR: index_jira_issues_task, SearchSourceConnectorType.CONFLUENCE_CONNECTOR: index_confluence_pages_task, - SearchSourceConnectorType.CLICKUP_CONNECTOR: index_clickup_tasks_task, - SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR: index_google_calendar_events_task, - SearchSourceConnectorType.AIRTABLE_CONNECTOR: index_airtable_records_task, - SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR: index_google_gmail_messages_task, - SearchSourceConnectorType.DISCORD_CONNECTOR: index_discord_messages_task, - SearchSourceConnectorType.LUMA_CONNECTOR: index_luma_events_task, SearchSourceConnectorType.ELASTICSEARCH_CONNECTOR: index_elasticsearch_documents_task, SearchSourceConnectorType.WEBCRAWLER_CONNECTOR: index_crawled_urls_task, SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR: index_google_drive_files_task, - # Composio connector types (unified with native Google tasks) SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR: index_google_drive_files_task, - SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR: index_google_gmail_messages_task, - SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR: 
index_google_calendar_events_task, } + from app.services.mcp_oauth.registry import LIVE_CONNECTOR_TYPES + + # Disable obsolete periodic indexing for live connectors in one batch. + live_disabled = [] + for connector in due_connectors: + if connector.connector_type in LIVE_CONNECTOR_TYPES: + connector.periodic_indexing_enabled = False + connector.next_scheduled_at = None + live_disabled.append(connector) + if live_disabled: + await session.commit() + for c in live_disabled: + logger.info( + "Disabled obsolete periodic indexing for live connector %s (%s)", + c.id, + c.connector_type.value, + ) + # Trigger indexing for each due connector for connector in due_connectors: + if connector in live_disabled: + continue + # Primary guard: Redis lock indicates a task is currently running. if is_connector_indexing_locked(connector.id): logger.info( diff --git a/surfsense_backend/app/tasks/connector_indexers/__init__.py b/surfsense_backend/app/tasks/connector_indexers/__init__.py index 1b032d54a..2b0ad7fa0 100644 --- a/surfsense_backend/app/tasks/connector_indexers/__init__.py +++ b/surfsense_backend/app/tasks/connector_indexers/__init__.py @@ -1,77 +1,31 @@ """ Connector indexers module for background tasks. -This module provides a collection of connector indexers for different platforms -and services. Each indexer is responsible for handling the indexing of content -from a specific connector type. 
- -Available indexers: -- Slack: Index messages from Slack channels -- Notion: Index pages from Notion workspaces -- GitHub: Index repositories and files from GitHub -- Linear: Index issues from Linear workspaces -- Jira: Index issues from Jira projects -- Confluence: Index pages from Confluence spaces -- BookStack: Index pages from BookStack wiki instances -- Discord: Index messages from Discord servers -- ClickUp: Index tasks from ClickUp workspaces -- Google Gmail: Index messages from Google Gmail -- Google Calendar: Index events from Google Calendar -- Luma: Index events from Luma -- Webcrawler: Index crawled URLs -- Elasticsearch: Index documents from Elasticsearch instances +Each indexer handles content indexing from a specific connector type. +Live connectors (Slack, Linear, Jira, ClickUp, Airtable, Discord, Teams, +Luma) now use real-time agent tools instead of background indexing. """ -# Communication platforms -# Calendar and scheduling -from .airtable_indexer import index_airtable_records from .bookstack_indexer import index_bookstack_pages - -# Note: composio_indexer is imported directly in connector_tasks.py to avoid circular imports -from .clickup_indexer import index_clickup_tasks from .confluence_indexer import index_confluence_pages -from .discord_indexer import index_discord_messages - -# Development platforms from .elasticsearch_indexer import index_elasticsearch_documents from .github_indexer import index_github_repos from .google_calendar_indexer import index_google_calendar_events from .google_drive_indexer import index_google_drive_files from .google_gmail_indexer import index_google_gmail_messages -from .jira_indexer import index_jira_issues - -# Issue tracking and project management -from .linear_indexer import index_linear_issues - -# Documentation and knowledge management -from .luma_indexer import index_luma_events from .notion_indexer import index_notion_pages from .obsidian_indexer import index_obsidian_vault -from .slack_indexer 
import index_slack_messages from .webcrawler_indexer import index_crawled_urls -__all__ = [ # noqa: RUF022 - "index_airtable_records", +__all__ = [ "index_bookstack_pages", - # "index_composio_connector", # Imported directly in connector_tasks.py to avoid circular imports - "index_clickup_tasks", "index_confluence_pages", - "index_discord_messages", - # Development platforms "index_elasticsearch_documents", "index_github_repos", - # Calendar and scheduling "index_google_calendar_events", "index_google_drive_files", - "index_luma_events", - "index_jira_issues", - # Issue tracking and project management - "index_linear_issues", - # Documentation and knowledge management + "index_google_gmail_messages", "index_notion_pages", "index_obsidian_vault", "index_crawled_urls", - # Communication platforms - "index_slack_messages", - "index_google_gmail_messages", ] diff --git a/surfsense_backend/app/utils/async_retry.py b/surfsense_backend/app/utils/async_retry.py new file mode 100644 index 000000000..c3bdd5386 --- /dev/null +++ b/surfsense_backend/app/utils/async_retry.py @@ -0,0 +1,129 @@ +"""Async retry decorators for connector API calls, built on tenacity.""" + +from __future__ import annotations + +import logging +from collections.abc import Callable +from typing import TypeVar + +import httpx +from tenacity import ( + before_sleep_log, + retry, + retry_if_exception, + stop_after_attempt, + stop_after_delay, + wait_exponential_jitter, +) + +from app.connectors.exceptions import ( + ConnectorAPIError, + ConnectorAuthError, + ConnectorError, + ConnectorRateLimitError, + ConnectorTimeoutError, +) + +logger = logging.getLogger(__name__) + +F = TypeVar("F", bound=Callable) + + +def _is_retryable(exc: BaseException) -> bool: + if isinstance(exc, ConnectorError): + return exc.retryable + if isinstance(exc, (httpx.TimeoutException, httpx.ConnectError)): + return True + return False + + +def build_retry( + *, + max_attempts: int = 4, + max_delay: float = 60.0, + initial_delay: 
float = 1.0, + total_timeout: float = 180.0, + service: str = "", +) -> Callable: + """Configurable tenacity ``@retry`` decorator with exponential backoff + jitter.""" + _logger = logging.getLogger(f"connector.retry.{service}") if service else logger + + return retry( + retry=retry_if_exception(_is_retryable), + stop=(stop_after_attempt(max_attempts) | stop_after_delay(total_timeout)), + wait=wait_exponential_jitter(initial=initial_delay, max=max_delay), + reraise=True, + before_sleep=before_sleep_log(_logger, logging.WARNING), + ) + + +def retry_on_transient( + *, + service: str = "", + max_attempts: int = 4, +) -> Callable: + """Shorthand: retry up to *max_attempts* on rate-limits, timeouts, and 5xx.""" + return build_retry(max_attempts=max_attempts, service=service) + + +def raise_for_status( + response: httpx.Response, + *, + service: str = "", +) -> None: + """Map non-2xx httpx responses to the appropriate ``ConnectorError``.""" + if response.is_success: + return + + status = response.status_code + + try: + body = response.json() + except Exception: + body = response.text[:500] if response.text else None + + if status == 429: + retry_after_raw = response.headers.get("Retry-After") + retry_after: float | None = None + if retry_after_raw: + try: + retry_after = float(retry_after_raw) + except (ValueError, TypeError): + pass + raise ConnectorRateLimitError( + f"{service} rate limited (429)", + service=service, + retry_after=retry_after, + response_body=body, + ) + + if status in (401, 403): + raise ConnectorAuthError( + f"{service} authentication failed ({status})", + service=service, + status_code=status, + response_body=body, + ) + + if status == 504: + raise ConnectorTimeoutError( + f"{service} gateway timeout (504)", + service=service, + status_code=status, + response_body=body, + ) + + if status >= 500: + raise ConnectorAPIError( + f"{service} server error ({status})", + service=service, + status_code=status, + response_body=body, + ) + + raise 
ConnectorAPIError( + f"{service} request failed ({status})", + service=service, + status_code=status, + response_body=body, + ) diff --git a/surfsense_backend/app/utils/connector_naming.py b/surfsense_backend/app/utils/connector_naming.py index 610be4a22..889bf1464 100644 --- a/surfsense_backend/app/utils/connector_naming.py +++ b/surfsense_backend/app/utils/connector_naming.py @@ -39,7 +39,7 @@ BASE_NAME_FOR_TYPE = { def get_base_name_for_type(connector_type: SearchSourceConnectorType) -> str: """Get a friendly display name for a connector type.""" return BASE_NAME_FOR_TYPE.get( - connector_type, connector_type.replace("_", " ").title() + connector_type, connector_type.value.replace("_", " ").title() ) @@ -231,9 +231,11 @@ async def generate_unique_connector_name( base = get_base_name_for_type(connector_type) if identifier: - return f"{base} - {identifier}" + name = f"{base} - {identifier}" + return await ensure_unique_connector_name( + session, name, search_space_id, user_id, + ) - # Fallback: use counter for uniqueness count = await count_connectors_of_type( session, connector_type, search_space_id, user_id ) diff --git a/surfsense_backend/app/utils/periodic_scheduler.py b/surfsense_backend/app/utils/periodic_scheduler.py index 9ea45df63..923f969d5 100644 --- a/surfsense_backend/app/utils/periodic_scheduler.py +++ b/surfsense_backend/app/utils/periodic_scheduler.py @@ -18,19 +18,9 @@ logger = logging.getLogger(__name__) # Mapping of connector types to their corresponding Celery task names CONNECTOR_TASK_MAP = { - SearchSourceConnectorType.SLACK_CONNECTOR: "index_slack_messages", - SearchSourceConnectorType.TEAMS_CONNECTOR: "index_teams_messages", SearchSourceConnectorType.NOTION_CONNECTOR: "index_notion_pages", SearchSourceConnectorType.GITHUB_CONNECTOR: "index_github_repos", - SearchSourceConnectorType.LINEAR_CONNECTOR: "index_linear_issues", - SearchSourceConnectorType.JIRA_CONNECTOR: "index_jira_issues", SearchSourceConnectorType.CONFLUENCE_CONNECTOR: 
"index_confluence_pages", - SearchSourceConnectorType.CLICKUP_CONNECTOR: "index_clickup_tasks", - SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR: "index_google_calendar_events", - SearchSourceConnectorType.AIRTABLE_CONNECTOR: "index_airtable_records", - SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR: "index_google_gmail_messages", - SearchSourceConnectorType.DISCORD_CONNECTOR: "index_discord_messages", - SearchSourceConnectorType.LUMA_CONNECTOR: "index_luma_events", SearchSourceConnectorType.ELASTICSEARCH_CONNECTOR: "index_elasticsearch_documents", SearchSourceConnectorType.WEBCRAWLER_CONNECTOR: "index_crawled_urls", SearchSourceConnectorType.BOOKSTACK_CONNECTOR: "index_bookstack_pages", @@ -84,40 +74,20 @@ def create_periodic_schedule( f"(frequency: {frequency_minutes} minutes). Triggering first run..." ) - # Import all indexing tasks from app.tasks.celery_tasks.connector_tasks import ( - index_airtable_records_task, index_bookstack_pages_task, - index_clickup_tasks_task, index_confluence_pages_task, index_crawled_urls_task, - index_discord_messages_task, index_elasticsearch_documents_task, index_github_repos_task, - index_google_calendar_events_task, - index_google_gmail_messages_task, - index_jira_issues_task, - index_linear_issues_task, - index_luma_events_task, index_notion_pages_task, index_obsidian_vault_task, - index_slack_messages_task, ) - # Map connector type to task task_map = { - SearchSourceConnectorType.SLACK_CONNECTOR: index_slack_messages_task, SearchSourceConnectorType.NOTION_CONNECTOR: index_notion_pages_task, SearchSourceConnectorType.GITHUB_CONNECTOR: index_github_repos_task, - SearchSourceConnectorType.LINEAR_CONNECTOR: index_linear_issues_task, - SearchSourceConnectorType.JIRA_CONNECTOR: index_jira_issues_task, SearchSourceConnectorType.CONFLUENCE_CONNECTOR: index_confluence_pages_task, - SearchSourceConnectorType.CLICKUP_CONNECTOR: index_clickup_tasks_task, - SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR: 
index_google_calendar_events_task, - SearchSourceConnectorType.AIRTABLE_CONNECTOR: index_airtable_records_task, - SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR: index_google_gmail_messages_task, - SearchSourceConnectorType.DISCORD_CONNECTOR: index_discord_messages_task, - SearchSourceConnectorType.LUMA_CONNECTOR: index_luma_events_task, SearchSourceConnectorType.ELASTICSEARCH_CONNECTOR: index_elasticsearch_documents_task, SearchSourceConnectorType.WEBCRAWLER_CONNECTOR: index_crawled_urls_task, SearchSourceConnectorType.BOOKSTACK_CONNECTOR: index_bookstack_pages_task, diff --git a/surfsense_web/components/assistant-ui/connector-popup/components/connector-card.tsx b/surfsense_web/components/assistant-ui/connector-popup/components/connector-card.tsx index d24057b1c..e0df73e66 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/components/connector-card.tsx +++ b/surfsense_web/components/assistant-ui/connector-popup/components/connector-card.tsx @@ -8,6 +8,7 @@ import { Spinner } from "@/components/ui/spinner"; import { EnumConnectorName } from "@/contracts/enums/connector"; import { getConnectorIcon } from "@/contracts/enums/connectorIcons"; import { cn } from "@/lib/utils"; +import { LIVE_CONNECTOR_TYPES } from "../constants/connector-constants"; import { useConnectorStatus } from "../hooks/use-connector-status"; import { ConnectorStatusBadge } from "./connector-status-badge"; @@ -55,6 +56,7 @@ export const ConnectorCard: FC = ({ onManage, }) => { const isMCP = connectorType === EnumConnectorName.MCP_CONNECTOR; + const isLive = !!connectorType && LIVE_CONNECTOR_TYPES.has(connectorType); // Get connector status const { getConnectorStatus, isConnectorEnabled, getConnectorStatusMessage, shouldShowWarnings } = useConnectorStatus(); @@ -123,14 +125,14 @@ export const ConnectorCard: FC = ({ ) : ( <> - {formatDocumentCount(documentCount)} + {!isLive && {formatDocumentCount(documentCount)}} + {!isLive && accountCount !== undefined && accountCount > 0 && ( + 
+ )} {accountCount !== undefined && accountCount > 0 && ( - <> - - - {accountCount} {accountCount === 1 ? "Account" : "Accounts"} - - + + {accountCount} {accountCount === 1 ? "Account" : "Accounts"} + )} )} diff --git a/surfsense_web/components/assistant-ui/connector-popup/connector-configs/components/discord-config.tsx b/surfsense_web/components/assistant-ui/connector-popup/connector-configs/components/discord-config.tsx index f782a6f4d..c8714ba40 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/connector-configs/components/discord-config.tsx +++ b/surfsense_web/components/assistant-ui/connector-popup/connector-configs/components/discord-config.tsx @@ -53,8 +53,7 @@ export const DiscordConfig: FC = ({ connector }) => { return () => document.removeEventListener("visibilitychange", handleVisibilityChange); }, [connector?.id, fetchChannels]); - // Separate channels by indexing capability - const readyToIndex = channels.filter((ch) => ch.can_index); + const accessible = channels.filter((ch) => ch.can_index); const needsPermissions = channels.filter((ch) => !ch.can_index); // Format last fetched time @@ -80,7 +79,7 @@ export const DiscordConfig: FC = ({ connector }) => {

- The bot needs "Read Message History" permission to index channels. Ask a + The bot needs "Read Message History" permission to access channels. Ask a server admin to grant this permission for channels shown below.

@@ -127,18 +126,18 @@ export const DiscordConfig: FC = ({ connector }) => { ) : (
- {/* Ready to index */} - {readyToIndex.length > 0 && ( + {/* Accessible channels */} + {accessible.length > 0 && (
0 && "border-b border-border")}>
- Ready to index + Accessible - {readyToIndex.length} {readyToIndex.length === 1 ? "channel" : "channels"} + {accessible.length} {accessible.length === 1 ? "channel" : "channels"}
- {readyToIndex.map((channel) => ( + {accessible.map((channel) => ( ))}
@@ -150,7 +149,7 @@ export const DiscordConfig: FC = ({ connector }) => {
- Grant permissions to index + Needs permissions {needsPermissions.length}{" "} {needsPermissions.length === 1 ? "channel" : "channels"} diff --git a/surfsense_web/components/assistant-ui/connector-popup/connector-configs/components/mcp-service-config.tsx b/surfsense_web/components/assistant-ui/connector-popup/connector-configs/components/mcp-service-config.tsx new file mode 100644 index 000000000..71d0e31a8 --- /dev/null +++ b/surfsense_web/components/assistant-ui/connector-popup/connector-configs/components/mcp-service-config.tsx @@ -0,0 +1,28 @@ +"use client"; + +import { CheckCircle2 } from "lucide-react"; +import type { FC } from "react"; +import type { ConnectorConfigProps } from "../index"; + +export const MCPServiceConfig: FC = ({ connector }) => { + const serviceName = connector.config?.mcp_service as string | undefined; + const displayName = serviceName + ? serviceName.charAt(0).toUpperCase() + serviceName.slice(1) + : "this service"; + + return ( +
+
+
+ +
+
+

Connected

+

+ Your agent can search, read, and take actions in {displayName}. +

+
+
+
+ ); +}; diff --git a/surfsense_web/components/assistant-ui/connector-popup/connector-configs/components/teams-config.tsx b/surfsense_web/components/assistant-ui/connector-popup/connector-configs/components/teams-config.tsx index ac08a6c03..e96ddfd29 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/connector-configs/components/teams-config.tsx +++ b/surfsense_web/components/assistant-ui/connector-popup/connector-configs/components/teams-config.tsx @@ -18,9 +18,9 @@ export const TeamsConfig: FC = () => {

Microsoft Teams Access

- SurfSense will index messages from Teams channels that you have access to. The app can - only read messages from teams and channels where you are a member. Make sure you're a - member of the teams you want to index before connecting. + Your agent can search and read messages from Teams channels you have access to, + and send messages on your behalf. Make sure you're a member of the teams + you want to interact with.

diff --git a/surfsense_web/components/assistant-ui/connector-popup/connector-configs/views/connector-edit-view.tsx b/surfsense_web/components/assistant-ui/connector-popup/connector-configs/views/connector-edit-view.tsx index e19600ab2..a69cf968f 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/connector-configs/views/connector-edit-view.tsx +++ b/surfsense_web/components/assistant-ui/connector-popup/connector-configs/views/connector-edit-view.tsx @@ -16,8 +16,9 @@ import { DateRangeSelector } from "../../components/date-range-selector"; import { PeriodicSyncConfig } from "../../components/periodic-sync-config"; import { SummaryConfig } from "../../components/summary-config"; import { VisionLLMConfig } from "../../components/vision-llm-config"; +import { LIVE_CONNECTOR_TYPES } from "../../constants/connector-constants"; import { getConnectorDisplayName } from "../../tabs/all-connectors-tab"; -import { getConnectorConfigComponent } from "../index"; +import { type ConnectorConfigProps, getConnectorConfigComponent } from "../index"; const REAUTH_ENDPOINTS: Partial> = { [EnumConnectorName.LINEAR_CONNECTOR]: "/api/v1/auth/linear/connector/reauth", @@ -118,11 +119,17 @@ export const ConnectorEditView: FC = ({ } }, [searchSpaceId, searchSpaceIdAtom, reauthEndpoint, connector.id]); - // Get connector-specific config component - const ConnectorConfigComponent = useMemo( - () => getConnectorConfigComponent(connector.connector_type), - [connector.connector_type] - ); + const isMCPBacked = Boolean(connector.config?.server_config); + const isLive = isMCPBacked || LIVE_CONNECTOR_TYPES.has(connector.connector_type); + + // Get connector-specific config component (MCP-backed connectors use a generic view) + const ConnectorConfigComponent = useMemo(() => { + if (isMCPBacked) { + const { MCPServiceConfig } = require("../components/mcp-service-config"); + return MCPServiceConfig as FC; + } + return getConnectorConfigComponent(connector.connector_type); + }, 
[connector.connector_type, isMCPBacked]); const [isScrolled, setIsScrolled] = useState(false); const [hasMoreContent, setHasMoreContent] = useState(false); const [showDisconnectConfirm, setShowDisconnectConfirm] = useState(false); @@ -223,12 +230,14 @@ export const ConnectorEditView: FC = ({ {getConnectorDisplayName(connector.name)}

- Manage your connector settings and sync configuration + {isLive + ? "Manage your connected account" + : "Manage your connector settings and sync configuration"}

- {/* Quick Index Button - hidden when auth is expired */} - {connector.is_indexable && onQuickIndex && !isAuthExpired && ( + {/* Quick Index Button - hidden for live connectors and when auth is expired */} + {connector.is_indexable && !isLive && onQuickIndex && !isAuthExpired && ( - ) : ( + ) : !isLive ? ( - )} + ) : null}
); diff --git a/surfsense_web/components/assistant-ui/connector-popup/connector-configs/views/indexing-configuration-view.tsx b/surfsense_web/components/assistant-ui/connector-popup/connector-configs/views/indexing-configuration-view.tsx index 13c257004..c65367e65 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/connector-configs/views/indexing-configuration-view.tsx +++ b/surfsense_web/components/assistant-ui/connector-popup/connector-configs/views/indexing-configuration-view.tsx @@ -11,7 +11,7 @@ import { DateRangeSelector } from "../../components/date-range-selector"; import { PeriodicSyncConfig } from "../../components/periodic-sync-config"; import { SummaryConfig } from "../../components/summary-config"; import { VisionLLMConfig } from "../../components/vision-llm-config"; -import type { IndexingConfigState } from "../../constants/connector-constants"; +import { LIVE_CONNECTOR_TYPES, type IndexingConfigState } from "../../constants/connector-constants"; import { getConnectorDisplayName } from "../../tabs/all-connectors-tab"; import { getConnectorConfigComponent } from "../index"; @@ -58,6 +58,8 @@ export const IndexingConfigurationView: FC = ({ onStartIndexing, onSkip, }) => { + const isLive = LIVE_CONNECTOR_TYPES.has(config.connectorType); + // Get connector-specific config component const ConnectorConfigComponent = useMemo( () => (connector ? getConnectorConfigComponent(connector.connector_type) : null), @@ -138,7 +140,9 @@ export const IndexingConfigurationView: FC = ({ )}

- Configure when to start syncing your data + {isLive + ? "Your account is ready to use" + : "Configure when to start syncing your data"}

@@ -157,8 +161,8 @@ export const IndexingConfigurationView: FC = ({ )} - {/* Summary and sync settings - only shown for indexable connectors */} - {connector?.is_indexable && ( + {/* Summary and sync settings - hidden for live connectors */} + {connector?.is_indexable && !isLive && ( <> {/* AI Summary toggle */} @@ -209,8 +213,8 @@ export const IndexingConfigurationView: FC = ({ )} - {/* Info box - only shown for indexable connectors */} - {connector?.is_indexable && ( + {/* Info box - hidden for live connectors */} + {connector?.is_indexable && !isLive && (
@@ -238,14 +242,20 @@ export const IndexingConfigurationView: FC = ({ {/* Fixed Footer - Action buttons */}
- + {isLive ? ( + + ) : ( + + )}
); diff --git a/surfsense_web/components/assistant-ui/connector-popup/constants/connector-constants.ts b/surfsense_web/components/assistant-ui/connector-popup/constants/connector-constants.ts index 6f60c63d6..05f866d0f 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/constants/connector-constants.ts +++ b/surfsense_web/components/assistant-ui/connector-popup/constants/connector-constants.ts @@ -1,5 +1,24 @@ import { EnumConnectorName } from "@/contracts/enums/connector"; +/** + * Connectors that operate in real time (no background indexing). + * Used to adjust UI: hide sync controls, show "Connected" instead of doc counts. + */ +export const LIVE_CONNECTOR_TYPES = new Set([ + EnumConnectorName.LINEAR_CONNECTOR, + EnumConnectorName.SLACK_CONNECTOR, + EnumConnectorName.JIRA_CONNECTOR, + EnumConnectorName.CLICKUP_CONNECTOR, + EnumConnectorName.AIRTABLE_CONNECTOR, + EnumConnectorName.DISCORD_CONNECTOR, + EnumConnectorName.TEAMS_CONNECTOR, + EnumConnectorName.GOOGLE_CALENDAR_CONNECTOR, + EnumConnectorName.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR, + EnumConnectorName.GOOGLE_GMAIL_CONNECTOR, + EnumConnectorName.COMPOSIO_GMAIL_CONNECTOR, + EnumConnectorName.LUMA_CONNECTOR, +]); + // OAuth Connectors (Quick Connect) export const OAUTH_CONNECTORS = [ { @@ -13,7 +32,7 @@ export const OAUTH_CONNECTORS = [ { id: "google-gmail-connector", title: "Gmail", - description: "Search through your emails", + description: "Search, read, draft, and send emails", connectorType: EnumConnectorName.GOOGLE_GMAIL_CONNECTOR, authEndpoint: "/api/v1/auth/google/gmail/connector/add/", selfHostedOnly: true, @@ -21,7 +40,7 @@ export const OAUTH_CONNECTORS = [ { id: "google-calendar-connector", title: "Google Calendar", - description: "Search through your events", + description: "Search and manage your events", connectorType: EnumConnectorName.GOOGLE_CALENDAR_CONNECTOR, authEndpoint: "/api/v1/auth/google/calendar/connector/add/", selfHostedOnly: true, @@ -29,35 +48,35 @@ export const 
OAUTH_CONNECTORS = [ { id: "airtable-connector", title: "Airtable", - description: "Search your Airtable bases", + description: "Browse bases, tables, and records", connectorType: EnumConnectorName.AIRTABLE_CONNECTOR, - authEndpoint: "/api/v1/auth/airtable/connector/add/", + authEndpoint: "/api/v1/auth/mcp/airtable/connector/add/", }, { id: "notion-connector", title: "Notion", description: "Search your Notion pages", connectorType: EnumConnectorName.NOTION_CONNECTOR, - authEndpoint: "/api/v1/auth/notion/connector/add/", + authEndpoint: "/api/v1/auth/notion/connector/add", }, { id: "linear-connector", title: "Linear", - description: "Search issues & projects", + description: "Search, read, and manage issues & projects", connectorType: EnumConnectorName.LINEAR_CONNECTOR, - authEndpoint: "/api/v1/auth/linear/connector/add/", + authEndpoint: "/api/v1/auth/mcp/linear/connector/add/", }, { id: "slack-connector", title: "Slack", - description: "Search Slack messages", + description: "Search and read channels and threads", connectorType: EnumConnectorName.SLACK_CONNECTOR, - authEndpoint: "/api/v1/auth/slack/connector/add/", + authEndpoint: "/api/v1/auth/mcp/slack/connector/add/", }, { id: "teams-connector", title: "Microsoft Teams", - description: "Search Teams messages", + description: "Search, read, and send messages", connectorType: EnumConnectorName.TEAMS_CONNECTOR, authEndpoint: "/api/v1/auth/teams/connector/add/", }, @@ -78,16 +97,16 @@ export const OAUTH_CONNECTORS = [ { id: "discord-connector", title: "Discord", - description: "Search Discord messages", + description: "Search, read, and send messages", connectorType: EnumConnectorName.DISCORD_CONNECTOR, authEndpoint: "/api/v1/auth/discord/connector/add/", }, { id: "jira-connector", title: "Jira", - description: "Search Jira issues", + description: "Search, read, and manage issues", connectorType: EnumConnectorName.JIRA_CONNECTOR, - authEndpoint: "/api/v1/auth/jira/connector/add/", + authEndpoint: 
"/api/v1/auth/mcp/jira/connector/add/", }, { id: "confluence-connector", @@ -99,9 +118,9 @@ export const OAUTH_CONNECTORS = [ { id: "clickup-connector", title: "ClickUp", - description: "Search ClickUp tasks", + description: "Search and read tasks", connectorType: EnumConnectorName.CLICKUP_CONNECTOR, - authEndpoint: "/api/v1/auth/clickup/connector/add/", + authEndpoint: "/api/v1/auth/mcp/clickup/connector/add/", }, ] as const; @@ -138,7 +157,7 @@ export const OTHER_CONNECTORS = [ { id: "luma-connector", title: "Luma", - description: "Search Luma events", + description: "Browse, read, and create events", connectorType: EnumConnectorName.LUMA_CONNECTOR, }, { @@ -197,14 +216,14 @@ export const COMPOSIO_CONNECTORS = [ { id: "composio-gmail", title: "Gmail", - description: "Search through your emails via Composio", + description: "Search, read, draft, and send emails via Composio", connectorType: EnumConnectorName.COMPOSIO_GMAIL_CONNECTOR, authEndpoint: "/api/v1/auth/composio/connector/add/?toolkit_id=gmail", }, { id: "composio-googlecalendar", title: "Google Calendar", - description: "Search through your events via Composio", + description: "Search and manage your events via Composio", connectorType: EnumConnectorName.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR, authEndpoint: "/api/v1/auth/composio/connector/add/?toolkit_id=googlecalendar", }, @@ -221,14 +240,14 @@ export const COMPOSIO_TOOLKITS = [ { id: "gmail", name: "Gmail", - description: "Search through your emails", - isIndexable: true, + description: "Search, read, draft, and send emails", + isIndexable: false, }, { id: "googlecalendar", name: "Google Calendar", - description: "Search through your events", - isIndexable: true, + description: "Search and manage your events", + isIndexable: false, }, { id: "slack", @@ -258,66 +277,6 @@ export interface AutoIndexConfig { } export const AUTO_INDEX_DEFAULTS: Record = { - [EnumConnectorName.GOOGLE_GMAIL_CONNECTOR]: { - daysBack: 30, - daysForward: 0, - frequencyMinutes: 1440, 
- syncDescription: "Syncing your last 30 days of emails.", - }, - [EnumConnectorName.COMPOSIO_GMAIL_CONNECTOR]: { - daysBack: 30, - daysForward: 0, - frequencyMinutes: 1440, - syncDescription: "Syncing your last 30 days of emails.", - }, - [EnumConnectorName.SLACK_CONNECTOR]: { - daysBack: 30, - daysForward: 0, - frequencyMinutes: 1440, - syncDescription: "Syncing your last 30 days of messages.", - }, - [EnumConnectorName.DISCORD_CONNECTOR]: { - daysBack: 30, - daysForward: 0, - frequencyMinutes: 1440, - syncDescription: "Syncing your last 30 days of messages.", - }, - [EnumConnectorName.TEAMS_CONNECTOR]: { - daysBack: 30, - daysForward: 0, - frequencyMinutes: 1440, - syncDescription: "Syncing your last 30 days of messages.", - }, - [EnumConnectorName.GOOGLE_CALENDAR_CONNECTOR]: { - daysBack: 90, - daysForward: 90, - frequencyMinutes: 1440, - syncDescription: "Syncing 90 days of past and upcoming events.", - }, - [EnumConnectorName.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR]: { - daysBack: 90, - daysForward: 90, - frequencyMinutes: 1440, - syncDescription: "Syncing 90 days of past and upcoming events.", - }, - [EnumConnectorName.LINEAR_CONNECTOR]: { - daysBack: 90, - daysForward: 0, - frequencyMinutes: 1440, - syncDescription: "Syncing your last 90 days of issues.", - }, - [EnumConnectorName.JIRA_CONNECTOR]: { - daysBack: 90, - daysForward: 0, - frequencyMinutes: 1440, - syncDescription: "Syncing your last 90 days of issues.", - }, - [EnumConnectorName.CLICKUP_CONNECTOR]: { - daysBack: 90, - daysForward: 0, - frequencyMinutes: 1440, - syncDescription: "Syncing your last 90 days of tasks.", - }, [EnumConnectorName.NOTION_CONNECTOR]: { daysBack: 365, daysForward: 0, @@ -330,12 +289,6 @@ export const AUTO_INDEX_DEFAULTS: Record = { frequencyMinutes: 1440, syncDescription: "Syncing your documentation.", }, - [EnumConnectorName.AIRTABLE_CONNECTOR]: { - daysBack: 365, - daysForward: 0, - frequencyMinutes: 1440, - syncDescription: "Syncing your bases.", - }, }; export const 
AUTO_INDEX_CONNECTOR_TYPES = new Set(Object.keys(AUTO_INDEX_DEFAULTS)); diff --git a/surfsense_web/components/assistant-ui/connector-popup/hooks/use-connector-dialog.ts b/surfsense_web/components/assistant-ui/connector-popup/hooks/use-connector-dialog.ts index 404ee16f0..a8d395e5c 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/hooks/use-connector-dialog.ts +++ b/surfsense_web/components/assistant-ui/connector-popup/hooks/use-connector-dialog.ts @@ -38,6 +38,7 @@ import { AUTO_INDEX_CONNECTOR_TYPES, AUTO_INDEX_DEFAULTS, COMPOSIO_CONNECTORS, + LIVE_CONNECTOR_TYPES, OAUTH_CONNECTORS, OTHER_CONNECTORS, } from "../constants/connector-constants"; @@ -317,7 +318,12 @@ export const useConnectorDialog = () => { newConnector.id ); - if ( + const isLiveConnector = LIVE_CONNECTOR_TYPES.has(oauthConnector.connectorType); + + if (isLiveConnector) { + toast.success(`${oauthConnector.title} connected successfully!`); + await refetchAllConnectors(); + } else if ( newConnector.is_indexable && AUTO_INDEX_CONNECTOR_TYPES.has(oauthConnector.connectorType) ) { @@ -326,6 +332,9 @@ export const useConnectorDialog = () => { oauthConnector.title, oauthConnector.connectorType ); + } else if (!newConnector.is_indexable) { + toast.success(`${oauthConnector.title} connected successfully!`); + await refetchAllConnectors(); } else { toast.dismiss("auto-index"); const config = validateIndexingConfigState({ diff --git a/surfsense_web/components/assistant-ui/connector-popup/tabs/active-connectors-tab.tsx b/surfsense_web/components/assistant-ui/connector-popup/tabs/active-connectors-tab.tsx index 7a29dd5ca..fe9aab14f 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/tabs/active-connectors-tab.tsx +++ b/surfsense_web/components/assistant-ui/connector-popup/tabs/active-connectors-tab.tsx @@ -9,7 +9,7 @@ import { getConnectorIcon } from "@/contracts/enums/connectorIcons"; import type { SearchSourceConnector } from "@/contracts/types/connector.types"; import { 
getDocumentTypeLabel } from "@/lib/documents/document-type-labels"; import { cn } from "@/lib/utils"; -import { COMPOSIO_CONNECTORS, OAUTH_CONNECTORS } from "../constants/connector-constants"; +import { COMPOSIO_CONNECTORS, LIVE_CONNECTOR_TYPES, OAUTH_CONNECTORS } from "../constants/connector-constants"; import { getDocumentCountForConnector } from "../utils/connector-document-mapping"; import { getConnectorDisplayName } from "./all-connectors-tab"; @@ -156,6 +156,7 @@ export const ActiveConnectorsTab: FC = ({ {/* OAuth Connectors - Grouped by Type */} {filteredOAuthConnectorTypes.map(([connectorType, typeConnectors]) => { const { title } = getOAuthConnectorTypeInfo(connectorType); + const isLive = LIVE_CONNECTOR_TYPES.has(connectorType); const isAnyIndexing = typeConnectors.some((c: SearchSourceConnector) => indexingConnectorIds.has(c.id) ); @@ -202,8 +203,12 @@ export const ActiveConnectorsTab: FC = ({

) : (

- {formatDocumentCount(documentCount)} - + {!isLive && ( + <> + {formatDocumentCount(documentCount)} + + + )} {accountCount} {accountCount === 1 ? "Account" : "Accounts"} @@ -230,6 +235,7 @@ export const ActiveConnectorsTab: FC = ({ documentTypeCounts ); const isMCPConnector = connector.connector_type === "MCP_CONNECTOR"; + const isLive = LIVE_CONNECTOR_TYPES.has(connector.connector_type); return (

= ({ Syncing

- ) : !isMCPConnector ? ( + ) : !isLive && !isMCPConnector ? (

{formatDocumentCount(documentCount)}

diff --git a/surfsense_web/components/assistant-ui/connector-popup/views/connector-accounts-list-view.tsx b/surfsense_web/components/assistant-ui/connector-popup/views/connector-accounts-list-view.tsx index b4c049c5c..b48b14ed2 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/views/connector-accounts-list-view.tsx +++ b/surfsense_web/components/assistant-ui/connector-popup/views/connector-accounts-list-view.tsx @@ -13,6 +13,7 @@ import type { SearchSourceConnector } from "@/contracts/types/connector.types"; import { authenticatedFetch } from "@/lib/auth-utils"; import { formatRelativeDate } from "@/lib/format-date"; import { cn } from "@/lib/utils"; +import { LIVE_CONNECTOR_TYPES } from "../constants/connector-constants"; import { useConnectorStatus } from "../hooks/use-connector-status"; import { getConnectorDisplayName } from "../tabs/all-connectors-tab"; @@ -43,12 +44,8 @@ interface ConnectorAccountsListViewProps { addButtonText?: string; } -/** - * Check if a connector type is indexable - */ -function isIndexableConnector(connectorType: string): boolean { - const nonIndexableTypes = ["MCP_CONNECTOR"]; - return !nonIndexableTypes.includes(connectorType); +function isLiveConnector(connectorType: string): boolean { + return LIVE_CONNECTOR_TYPES.has(connectorType) || connectorType === "MCP_CONNECTOR"; } export const ConnectorAccountsListView: FC = ({ @@ -149,7 +146,7 @@ export const ConnectorAccountsListView: FC = ({ {connectorTitle}

- {statusMessage || "Manage your connector settings and sync configuration"} + {statusMessage || "Manage your connected accounts"}

@@ -234,15 +231,13 @@ export const ConnectorAccountsListView: FC = ({ Syncing

- ) : ( -

- {isIndexableConnector(connector.connector_type) - ? connector.last_indexed_at - ? `Last indexed: ${formatRelativeDate(connector.last_indexed_at)}` - : "Never indexed" - : "Active"} + ) : !isLiveConnector(connector.connector_type) ? ( +

+ {connector.last_indexed_at + ? `Last indexed: ${formatRelativeDate(connector.last_indexed_at)}` + : "Never indexed"}

- )} + ) : null} {isAuthExpired ? (