Merge pull request #1294 from CREDO23/feature/mcp-migration
Some checks are pending
Build and Push Docker Images / tag_release (push) Waiting to run
Build and Push Docker Images / build (./surfsense_backend, ./surfsense_backend/Dockerfile, backend, surfsense-backend, ubuntu-24.04-arm, linux/arm64, arm64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_backend, ./surfsense_backend/Dockerfile, backend, surfsense-backend, ubuntu-latest, linux/amd64, amd64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_web, ./surfsense_web/Dockerfile, web, surfsense-web, ubuntu-24.04-arm, linux/arm64, arm64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_web, ./surfsense_web/Dockerfile, web, surfsense-web, ubuntu-latest, linux/amd64, amd64) (push) Blocked by required conditions
Build and Push Docker Images / create_manifest (backend, surfsense-backend) (push) Blocked by required conditions
Build and Push Docker Images / create_manifest (web, surfsense-web) (push) Blocked by required conditions

[FEAT] Live connector tools via MCP OAuth and native APIs
This commit is contained in:
Rohan Verma 2026-04-22 21:00:28 -07:00 committed by GitHub
commit 7245ab4046
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
65 changed files with 4175 additions and 1463 deletions

View file

@ -47,7 +47,7 @@ from app.agents.new_chat.system_prompt import (
build_configurable_system_prompt,
build_surfsense_system_prompt,
)
from app.agents.new_chat.tools.registry import build_tools_async
from app.agents.new_chat.tools.registry import build_tools_async, get_connector_gated_tools
from app.db import ChatVisibility
from app.services.connector_service import ConnectorService
from app.utils.perf import get_perf_logger
@ -287,105 +287,10 @@ async def create_surfsense_deep_agent(
"llm": llm,
}
# Disable Notion action tools if no Notion connector is configured
modified_disabled_tools = list(disabled_tools) if disabled_tools else []
has_notion_connector = (
available_connectors is not None and "NOTION_CONNECTOR" in available_connectors
modified_disabled_tools.extend(
get_connector_gated_tools(available_connectors)
)
if not has_notion_connector:
notion_tools = [
"create_notion_page",
"update_notion_page",
"delete_notion_page",
]
modified_disabled_tools.extend(notion_tools)
# Disable Linear action tools if no Linear connector is configured
has_linear_connector = (
available_connectors is not None and "LINEAR_CONNECTOR" in available_connectors
)
if not has_linear_connector:
linear_tools = [
"create_linear_issue",
"update_linear_issue",
"delete_linear_issue",
]
modified_disabled_tools.extend(linear_tools)
# Disable Google Drive action tools if no Google Drive connector is configured
has_google_drive_connector = (
available_connectors is not None and "GOOGLE_DRIVE_FILE" in available_connectors
)
if not has_google_drive_connector:
google_drive_tools = [
"create_google_drive_file",
"delete_google_drive_file",
]
modified_disabled_tools.extend(google_drive_tools)
has_dropbox_connector = (
available_connectors is not None and "DROPBOX_FILE" in available_connectors
)
if not has_dropbox_connector:
modified_disabled_tools.extend(["create_dropbox_file", "delete_dropbox_file"])
has_onedrive_connector = (
available_connectors is not None and "ONEDRIVE_FILE" in available_connectors
)
if not has_onedrive_connector:
modified_disabled_tools.extend(["create_onedrive_file", "delete_onedrive_file"])
# Disable Google Calendar action tools if no Google Calendar connector is configured
has_google_calendar_connector = (
available_connectors is not None
and "GOOGLE_CALENDAR_CONNECTOR" in available_connectors
)
if not has_google_calendar_connector:
calendar_tools = [
"create_calendar_event",
"update_calendar_event",
"delete_calendar_event",
]
modified_disabled_tools.extend(calendar_tools)
# Disable Gmail action tools if no Gmail connector is configured
has_gmail_connector = (
available_connectors is not None
and "GOOGLE_GMAIL_CONNECTOR" in available_connectors
)
if not has_gmail_connector:
gmail_tools = [
"create_gmail_draft",
"update_gmail_draft",
"send_gmail_email",
"trash_gmail_email",
]
modified_disabled_tools.extend(gmail_tools)
# Disable Jira action tools if no Jira connector is configured
has_jira_connector = (
available_connectors is not None and "JIRA_CONNECTOR" in available_connectors
)
if not has_jira_connector:
jira_tools = [
"create_jira_issue",
"update_jira_issue",
"delete_jira_issue",
]
modified_disabled_tools.extend(jira_tools)
# Disable Confluence action tools if no Confluence connector is configured
has_confluence_connector = (
available_connectors is not None
and "CONFLUENCE_CONNECTOR" in available_connectors
)
if not has_confluence_connector:
confluence_tools = [
"create_confluence_page",
"update_confluence_page",
"delete_confluence_page",
]
modified_disabled_tools.extend(confluence_tools)
# Remove direct KB search tool; we now pre-seed a scoped filesystem via middleware.
if "search_knowledge_base" not in modified_disabled_tools:

View file

@ -38,8 +38,66 @@ CRITICAL RULE — KNOWLEDGE BASE FIRST, NEVER DEFAULT TO GENERAL KNOWLEDGE:
* Formatting, summarization, or analysis of content already present in the conversation
* Following user instructions that are clearly task-oriented (e.g., "rewrite this in bullet points")
* Tool-usage actions like generating reports, podcasts, images, or scraping webpages
* Queries about services that have direct tools (Linear, ClickUp, Jira, Slack, Airtable) see <tool_routing> below
</knowledge_base_only_policy>
<tool_routing>
CRITICAL You have direct tools for these services: Linear, ClickUp, Jira, Slack, Airtable.
Their data is NEVER in the knowledge base. You MUST call their tools immediately — never
say "I don't see it in the knowledge base" or ask the user if they want you to check.
Ignore any knowledge base results for these services.
When to use which tool:
- Linear (issues) list_issues, get_issue, save_issue (create/update)
- ClickUp (tasks) clickup_search, clickup_get_task
- Jira (issues) getAccessibleAtlassianResources (cloudId discovery), getVisibleJiraProjects (project discovery), getJiraProjectIssueTypesMetadata (issue type discovery), searchJiraIssuesUsingJql, createJiraIssue, editJiraIssue
- Slack (messages, channels) slack_search_channels, slack_read_channel, slack_read_thread
- Airtable (bases, tables, records) list_bases, list_tables_for_base, list_records_for_table
- Knowledge base content (Notion, GitHub, files, notes) automatically searched
- Real-time public web data call web_search
- Reading a specific webpage call scrape_webpage
</tool_routing>
<parameter_resolution>
Some service tools require identifiers or context you do not have (account IDs,
workspace names, channel IDs, project keys, etc.). NEVER ask the user for raw
IDs or technical identifiers — they cannot memorise them.
Instead, follow this discovery pattern:
1. Call a listing/discovery tool to find available options.
2. ONE result → use it silently, no question to the user.
3. MULTIPLE results → present the options by their display names and let the
user choose. Never show raw UUIDs — always use friendly names.
Discovery tools by level:
- Which account/workspace? get_connected_accounts("<service>")
- Which Jira site (cloudId)? getAccessibleAtlassianResources
- Which Jira project? getVisibleJiraProjects (after resolving cloudId)
- Which Jira issue type? getJiraProjectIssueTypesMetadata (after resolving project)
- Which channel? slack_search_channels
- Which base? list_bases
- Which table? list_tables_for_base (after resolving baseId)
- Which task? clickup_search
- Which issue? list_issues (Linear) or searchJiraIssuesUsingJql (Jira)
For Jira specifically: ALWAYS call getAccessibleAtlassianResources first to
obtain the cloudId, then pass it to other Jira tools. When creating an issue,
chain: getAccessibleAtlassianResources → getVisibleJiraProjects → createJiraIssue.
If there is only one option at each step, use it silently. If multiple, present
friendly names.
Chain discovery when needed e.g. for Airtable records: list_bases pick
base list_tables_for_base pick table list_records_for_table.
MULTI-ACCOUNT TOOL NAMING: When the user has multiple accounts connected for
the same service, tool names are prefixed to avoid collisions e.g.
linear_25_list_issues and linear_30_list_issues instead of two list_issues.
Each prefixed tool's description starts with [Account: <display_name>] so you
know which account it targets. Use get_connected_accounts("<service>") to see
the full list of accounts with their connector IDs and display names.
When only one account is connected, tools have their normal unprefixed names.
</parameter_resolution>
<memory_protocol>
IMPORTANT After understanding each user message, ALWAYS check: does this message
reveal durable facts about the user (role, interests, preferences, projects,
@ -76,8 +134,66 @@ CRITICAL RULE — KNOWLEDGE BASE FIRST, NEVER DEFAULT TO GENERAL KNOWLEDGE:
* Formatting, summarization, or analysis of content already present in the conversation
* Following user instructions that are clearly task-oriented (e.g., "rewrite this in bullet points")
* Tool-usage actions like generating reports, podcasts, images, or scraping webpages
* Queries about services that have direct tools (Linear, ClickUp, Jira, Slack, Airtable) see <tool_routing> below
</knowledge_base_only_policy>
<tool_routing>
CRITICAL You have direct tools for these services: Linear, ClickUp, Jira, Slack, Airtable.
Their data is NEVER in the knowledge base. You MUST call their tools immediately never
say "I don't see it in the knowledge base" or ask if they want you to check.
Ignore any knowledge base results for these services.
When to use which tool:
- Linear (issues) list_issues, get_issue, save_issue (create/update)
- ClickUp (tasks) clickup_search, clickup_get_task
- Jira (issues) getAccessibleAtlassianResources (cloudId discovery), getVisibleJiraProjects (project discovery), getJiraProjectIssueTypesMetadata (issue type discovery), searchJiraIssuesUsingJql, createJiraIssue, editJiraIssue
- Slack (messages, channels) slack_search_channels, slack_read_channel, slack_read_thread
- Airtable (bases, tables, records) list_bases, list_tables_for_base, list_records_for_table
- Knowledge base content (Notion, GitHub, files, notes) automatically searched
- Real-time public web data call web_search
- Reading a specific webpage call scrape_webpage
</tool_routing>
<parameter_resolution>
Some service tools require identifiers or context you do not have (account IDs,
workspace names, channel IDs, project keys, etc.). NEVER ask the user for raw
IDs or technical identifiers they cannot memorise them.
Instead, follow this discovery pattern:
1. Call a listing/discovery tool to find available options.
2. ONE result use it silently, no question to the user.
3. MULTIPLE results present the options by their display names and let the
user choose. Never show raw UUIDs always use friendly names.
Discovery tools by level:
- Which account/workspace? get_connected_accounts("<service>")
- Which Jira site (cloudId)? getAccessibleAtlassianResources
- Which Jira project? getVisibleJiraProjects (after resolving cloudId)
- Which Jira issue type? getJiraProjectIssueTypesMetadata (after resolving project)
- Which channel? slack_search_channels
- Which base? list_bases
- Which table? list_tables_for_base (after resolving baseId)
- Which task? clickup_search
- Which issue? list_issues (Linear) or searchJiraIssuesUsingJql (Jira)
For Jira specifically: ALWAYS call getAccessibleAtlassianResources first to
obtain the cloudId, then pass it to other Jira tools. When creating an issue,
chain: getAccessibleAtlassianResources getVisibleJiraProjects createJiraIssue.
If there is only one option at each step, use it silently. If multiple, present
friendly names.
Chain discovery when needed e.g. for Airtable records: list_bases pick
base list_tables_for_base pick table list_records_for_table.
MULTI-ACCOUNT TOOL NAMING: When the user has multiple accounts connected for
the same service, tool names are prefixed to avoid collisions e.g.
linear_25_list_issues and linear_30_list_issues instead of two list_issues.
Each prefixed tool's description starts with [Account: <display_name>] so you
know which account it targets. Use get_connected_accounts("<service>") to see
the full list of accounts with their connector IDs and display names.
When only one account is connected, tools have their normal unprefixed names.
</parameter_resolution>
<memory_protocol>
IMPORTANT After understanding each user message, ALWAYS check: does this message
reveal durable facts about the team (decisions, conventions, architecture, processes,

View file

@ -0,0 +1,109 @@
"""Connected-accounts discovery tool.
Lets the LLM discover which accounts are connected for a given service
(e.g. "jira", "linear", "slack") and retrieve the metadata it needs to
call action tools such as Jira's ``cloudId``.
The tool returns **only** non-sensitive fields explicitly listed in the
service's ``account_metadata_keys`` (see ``registry.py``), plus the
always-present ``display_name`` and ``connector_id``.
"""
import logging
from typing import Any
from langchain_core.tools import StructuredTool
from pydantic import BaseModel, Field
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
from app.services.mcp_oauth.registry import MCP_SERVICES
logger = logging.getLogger(__name__)
# Reverse lookup from a service's DB connector-type string to its registry key
# (e.g. the value of cfg.connector_type -> "jira"). Built once at import time
# from the MCP service registry.
_SERVICE_KEY_BY_CONNECTOR_TYPE: dict[str, str] = {
    cfg.connector_type: key for key, cfg in MCP_SERVICES.items()
}
class GetConnectedAccountsInput(BaseModel):
    """Input schema for the ``get_connected_accounts`` tool."""

    # The description embeds the valid service keys (computed at import time)
    # so the LLM can self-correct an invalid key without a failed round-trip.
    service: str = Field(
        description=(
            "Service key to look up connected accounts for. "
            "Valid values: " + ", ".join(sorted(MCP_SERVICES.keys()))
        ),
    )
def _extract_display_name(connector: SearchSourceConnector) -> str:
    """Derive the friendliest available label for a connector.

    Preference order: an explicit ``display_name`` in the config, then the
    connector name qualified by ``base_url`` or ``organization_name``, and
    finally the bare connector name.
    """
    config = connector.config or {}
    explicit = config.get("display_name")
    if explicit:
        return explicit
    for qualifier_key in ("base_url", "organization_name"):
        qualifier = config.get(qualifier_key)
        if qualifier:
            return f"{connector.name} ({qualifier})"
    return connector.name
def create_get_connected_accounts_tool(
    db_session: AsyncSession,
    search_space_id: int,
    user_id: str,
) -> StructuredTool:
    """Build the ``get_connected_accounts`` discovery tool bound to one
    user and search space.

    The returned tool lists connected accounts for a service, exposing only
    ``connector_id``, ``display_name``, ``service``, the optional multi-account
    ``tool_prefix``, and the whitelisted ``account_metadata_keys``.
    """

    async def _run(service: str) -> list[dict[str, Any]]:
        service_config = MCP_SERVICES.get(service)
        if not service_config:
            return [
                {"error": f"Unknown service '{service}'. Valid: {', '.join(sorted(MCP_SERVICES.keys()))}"}
            ]
        try:
            connector_type = SearchSourceConnectorType(service_config.connector_type)
        except ValueError:
            return [
                {"error": f"Connector type '{service_config.connector_type}' not found."}
            ]
        query = select(SearchSourceConnector).filter(
            SearchSourceConnector.search_space_id == search_space_id,
            SearchSourceConnector.user_id == user_id,
            SearchSourceConnector.connector_type == connector_type,
        )
        matches = (await db_session.execute(query)).scalars().all()
        if not matches:
            return [
                {"error": f"No {service_config.name} accounts connected. Ask the user to connect one in settings."}
            ]
        # Prefixed tool names only exist when several accounts share a service.
        multiple_accounts = len(matches) > 1
        accounts: list[dict[str, Any]] = []
        for connector in matches:
            connector_config = connector.config or {}
            entry: dict[str, Any] = {
                "connector_id": connector.id,
                "display_name": _extract_display_name(connector),
                "service": service,
            }
            if multiple_accounts:
                entry["tool_prefix"] = f"{service}_{connector.id}"
            # Copy through only the explicitly whitelisted, non-sensitive keys.
            for metadata_key in service_config.account_metadata_keys:
                if metadata_key in connector_config:
                    entry[metadata_key] = connector_config[metadata_key]
            accounts.append(entry)
        return accounts

    return StructuredTool(
        name="get_connected_accounts",
        description=(
            "Discover which accounts are connected for a service (e.g. jira, linear, slack, clickup, airtable). "
            "Returns display names and service-specific metadata the action tools need "
            "(e.g. Jira's cloudId). Call this BEFORE using a service's action tools when "
            "you need an account identifier or are unsure which account to use."
        ),
        coroutine=_run,
        args_schema=GetConnectedAccountsInput,
        metadata={"hitl": False},
    )

View file

@ -0,0 +1,15 @@
"""Discord agent tools: channel listing, message reading, and message sending."""

from app.agents.new_chat.tools.discord.list_channels import (
    create_list_discord_channels_tool,
)
from app.agents.new_chat.tools.discord.read_messages import (
    create_read_discord_messages_tool,
)
from app.agents.new_chat.tools.discord.send_message import (
    create_send_discord_message_tool,
)

__all__ = [
    "create_list_discord_channels_tool",
    "create_read_discord_messages_tool",
    "create_send_discord_message_tool",
]

View file

@ -0,0 +1,42 @@
"""Shared auth helper for Discord agent tools (REST API, not gateway bot)."""
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.config import config
from app.db import SearchSourceConnector, SearchSourceConnectorType
from app.utils.oauth_security import TokenEncryption
# Base URL for Discord's REST API (v10). These tools call the REST API with a
# bot token; they do not run a gateway (websocket) bot.
DISCORD_API = "https://discord.com/api/v10"
async def get_discord_connector(
    db_session: AsyncSession,
    search_space_id: int,
    user_id: str,
) -> SearchSourceConnector | None:
    """Fetch the user's Discord connector for this search space, or None."""
    query = select(SearchSourceConnector).filter(
        SearchSourceConnector.search_space_id == search_space_id,
        SearchSourceConnector.user_id == user_id,
        SearchSourceConnector.connector_type
        == SearchSourceConnectorType.DISCORD_CONNECTOR,
    )
    rows = await db_session.execute(query)
    return rows.scalars().first()
def get_bot_token(connector: SearchSourceConnector) -> str:
    """Extract and decrypt the bot token from connector config.

    Raises:
        ValueError: if no token is stored, or the token is marked as encrypted
            but ``SECRET_KEY`` is not configured to decrypt it.
    """
    # Work on a copy; never mutate the stored connector config. Guard against
    # a null config for consistency with the other helpers in this package.
    cfg = dict(connector.config or {})
    token = cfg.get("bot_token")
    if cfg.get("_token_encrypted"):
        # Previously a missing SECRET_KEY silently returned the still-encrypted
        # blob, which surfaced later as a confusing Discord 401. Fail fast with
        # an actionable error instead.
        if not config.SECRET_KEY:
            raise ValueError(
                "Discord bot token is encrypted but SECRET_KEY is not configured."
            )
        if token:
            token = TokenEncryption(config.SECRET_KEY).decrypt_token(token)
    if not token:
        raise ValueError("Discord bot token not found in connector config.")
    return token
def get_guild_id(connector: SearchSourceConnector) -> str | None:
    """Return the configured Discord guild (server) ID, or None if unset.

    Guards against a null ``config``, matching how the sibling helpers in this
    module treat ``connector.config`` as possibly None (``config or {}``); the
    original would raise AttributeError in that case.
    """
    return (connector.config or {}).get("guild_id")

View file

@ -0,0 +1,67 @@
import logging
from typing import Any
import httpx
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from ._auth import DISCORD_API, get_bot_token, get_discord_connector, get_guild_id
logger = logging.getLogger(__name__)
def create_list_discord_channels_tool(
    db_session: AsyncSession | None = None,
    search_space_id: int | None = None,
    user_id: str | None = None,
):
    """Factory: build the ``list_discord_channels`` tool bound to one
    user and search space. With any binding missing, the tool reports a
    configuration error at call time rather than raising at build time."""

    @tool
    async def list_discord_channels() -> dict[str, Any]:
        """List text channels in the connected Discord server.

        Returns:
            Dictionary with status and a list of channels (id, name).
        """
        if db_session is None or search_space_id is None or user_id is None:
            return {"status": "error", "message": "Discord tool not properly configured."}
        try:
            connector = await get_discord_connector(db_session, search_space_id, user_id)
            if connector is None:
                return {"status": "error", "message": "No Discord connector found."}
            guild_id = get_guild_id(connector)
            if not guild_id:
                return {"status": "error", "message": "No guild ID in Discord connector config."}
            token = get_bot_token(connector)
            async with httpx.AsyncClient() as client:
                response = await client.get(
                    f"{DISCORD_API}/guilds/{guild_id}/channels",
                    headers={"Authorization": f"Bot {token}"},
                    timeout=15.0,
                )
            if response.status_code == 401:
                return {"status": "auth_error", "message": "Discord bot token is invalid.", "connector_type": "discord"}
            if response.status_code != 200:
                return {"status": "error", "message": f"Discord API error: {response.status_code}"}
            text_channels = []
            for channel in response.json():
                # Discord channel type 0 is a guild text channel; skip the rest.
                if channel.get("type") == 0:
                    text_channels.append({"id": channel["id"], "name": channel["name"]})
            return {
                "status": "success",
                "guild_id": guild_id,
                "channels": text_channels,
                "total": len(text_channels),
            }
        except Exception as e:
            # GraphInterrupt is LangGraph control flow, not a failure — re-raise.
            from langgraph.errors import GraphInterrupt

            if isinstance(e, GraphInterrupt):
                raise
            logger.error("Error listing Discord channels: %s", e, exc_info=True)
            return {"status": "error", "message": "Failed to list Discord channels."}

    return list_discord_channels

View file

@ -0,0 +1,80 @@
import logging
from typing import Any
import httpx
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from ._auth import DISCORD_API, get_bot_token, get_discord_connector
logger = logging.getLogger(__name__)
def create_read_discord_messages_tool(
    db_session: AsyncSession | None = None,
    search_space_id: int | None = None,
    user_id: str | None = None,
):
    """Factory: build the ``read_discord_messages`` tool bound to one
    user and search space. With any binding missing, the tool reports a
    configuration error at call time rather than raising at build time."""

    @tool
    async def read_discord_messages(
        channel_id: str,
        limit: int = 25,
    ) -> dict[str, Any]:
        """Read recent messages from a Discord text channel.

        Args:
            channel_id: The Discord channel ID (from list_discord_channels).
            limit: Number of messages to fetch (default 25, max 50).

        Returns:
            Dictionary with status and a list of messages including
            id, author, content, timestamp.
        """
        if db_session is None or search_space_id is None or user_id is None:
            return {"status": "error", "message": "Discord tool not properly configured."}
        # Clamp to a valid range: Discord rejects limit < 1 with a 400, and the
        # tool caps results at 50 per its contract. The original only capped the
        # upper bound, so a zero/negative limit produced an opaque API error.
        limit = max(1, min(limit, 50))
        try:
            connector = await get_discord_connector(db_session, search_space_id, user_id)
            if not connector:
                return {"status": "error", "message": "No Discord connector found."}
            token = get_bot_token(connector)
            async with httpx.AsyncClient() as client:
                resp = await client.get(
                    f"{DISCORD_API}/channels/{channel_id}/messages",
                    headers={"Authorization": f"Bot {token}"},
                    params={"limit": limit},
                    timeout=15.0,
                )
            # 401: token rejected — include connector_type so the UI can prompt
            # for re-authentication.
            if resp.status_code == 401:
                return {"status": "auth_error", "message": "Discord bot token is invalid.", "connector_type": "discord"}
            if resp.status_code == 403:
                return {"status": "error", "message": "Bot lacks permission to read this channel."}
            if resp.status_code != 200:
                return {"status": "error", "message": f"Discord API error: {resp.status_code}"}
            messages = [
                {
                    "id": m["id"],
                    "author": m.get("author", {}).get("username", "Unknown"),
                    "content": m.get("content", ""),
                    "timestamp": m.get("timestamp", ""),
                }
                for m in resp.json()
            ]
            return {"status": "success", "channel_id": channel_id, "messages": messages, "total": len(messages)}
        except Exception as e:
            # GraphInterrupt is LangGraph control flow, not a failure — re-raise.
            from langgraph.errors import GraphInterrupt

            if isinstance(e, GraphInterrupt):
                raise
            logger.error("Error reading Discord messages: %s", e, exc_info=True)
            return {"status": "error", "message": "Failed to read Discord messages."}

    return read_discord_messages

View file

@ -0,0 +1,96 @@
import logging
from typing import Any
import httpx
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval
from ._auth import DISCORD_API, get_bot_token, get_discord_connector
logger = logging.getLogger(__name__)
def create_send_discord_message_tool(
    db_session: AsyncSession | None = None,
    search_space_id: int | None = None,
    user_id: str | None = None,
):
    # Factory: binds the tool to one user's Discord connector in one search
    # space. None defaults let the registry build tools uniformly; the tool
    # then reports a configuration error at call time.
    @tool
    async def send_discord_message(
        channel_id: str,
        content: str,
    ) -> dict[str, Any]:
        """Send a message to a Discord text channel.

        Args:
            channel_id: The Discord channel ID (from list_discord_channels).
            content: The message text (max 2000 characters).

        Returns:
            Dictionary with status, message_id on success.

        IMPORTANT:
            - If status is "rejected", the user explicitly declined. Do NOT retry.
        """
        if db_session is None or search_space_id is None or user_id is None:
            return {"status": "error", "message": "Discord tool not properly configured."}
        # Enforce Discord's hard message-size cap before doing any work.
        if len(content) > 2000:
            return {"status": "error", "message": "Message exceeds Discord's 2000-character limit."}
        try:
            connector = await get_discord_connector(db_session, search_space_id, user_id)
            if not connector:
                return {"status": "error", "message": "No Discord connector found."}
            # Human-in-the-loop gate: pauses the agent for user approval BEFORE
            # anything is sent. Called synchronously here — presumably it raises
            # GraphInterrupt internally to suspend the graph (see the except
            # handler below); TODO confirm against request_approval's implementation.
            result = request_approval(
                action_type="discord_send_message",
                tool_name="send_discord_message",
                params={"channel_id": channel_id, "content": content},
                context={"connector_id": connector.id},
            )
            if result.rejected:
                return {"status": "rejected", "message": "User declined. Message was not sent."}
            # The approver may have edited the message or target channel; prefer
            # the approved params, falling back to the originals.
            final_content = result.params.get("content", content)
            final_channel = result.params.get("channel_id", channel_id)
            token = get_bot_token(connector)
            async with httpx.AsyncClient() as client:
                resp = await client.post(
                    f"{DISCORD_API}/channels/{final_channel}/messages",
                    headers={
                        "Authorization": f"Bot {token}",
                        "Content-Type": "application/json",
                    },
                    json={"content": final_content},
                    timeout=15.0,
                )
                # 401: token rejected — include connector_type so the UI can
                # prompt for re-authentication.
                if resp.status_code == 401:
                    return {"status": "auth_error", "message": "Discord bot token is invalid.", "connector_type": "discord"}
                if resp.status_code == 403:
                    return {"status": "error", "message": "Bot lacks permission to send messages in this channel."}
                if resp.status_code not in (200, 201):
                    return {"status": "error", "message": f"Discord API error: {resp.status_code}"}
                msg_data = resp.json()
                return {
                    "status": "success",
                    "message_id": msg_data.get("id"),
                    "message": f"Message sent to channel {final_channel}.",
                }
        except Exception as e:
            # GraphInterrupt is LangGraph control flow (the approval pause), not
            # a failure — re-raise it so the graph can suspend.
            from langgraph.errors import GraphInterrupt
            if isinstance(e, GraphInterrupt):
                raise
            logger.error("Error sending Discord message: %s", e, exc_info=True)
            return {"status": "error", "message": "Failed to send Discord message."}
    return send_discord_message

View file

@ -1,6 +1,12 @@
from app.agents.new_chat.tools.gmail.create_draft import (
create_create_gmail_draft_tool,
)
from app.agents.new_chat.tools.gmail.read_email import (
create_read_gmail_email_tool,
)
from app.agents.new_chat.tools.gmail.search_emails import (
create_search_gmail_tool,
)
from app.agents.new_chat.tools.gmail.send_email import (
create_send_gmail_email_tool,
)
@ -13,6 +19,8 @@ from app.agents.new_chat.tools.gmail.update_draft import (
__all__ = [
"create_create_gmail_draft_tool",
"create_read_gmail_email_tool",
"create_search_gmail_tool",
"create_send_gmail_email_tool",
"create_trash_gmail_email_tool",
"create_update_gmail_draft_tool",

View file

@ -0,0 +1,87 @@
import logging
from typing import Any
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
logger = logging.getLogger(__name__)
# Connector types that can serve Gmail requests: the native OAuth connector
# and the Composio-backed variant.
_GMAIL_TYPES = [
    SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
    SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
]
def create_read_gmail_email_tool(
    db_session: AsyncSession | None = None,
    search_space_id: int | None = None,
    user_id: str | None = None,
):
    """Factory: build the ``read_gmail_email`` tool bound to one user and
    search space. With any binding missing, the tool reports a configuration
    error at call time rather than raising at build time."""

    @tool
    async def read_gmail_email(message_id: str) -> dict[str, Any]:
        """Read the full content of a specific Gmail email by its message ID.

        Use after search_gmail to get the complete body of an email.

        Args:
            message_id: The Gmail message ID (from search_gmail results).

        Returns:
            Dictionary with status and the full email content formatted as markdown.
        """
        if db_session is None or search_space_id is None or user_id is None:
            return {"status": "error", "message": "Gmail tool not properly configured."}
        try:
            query = select(SearchSourceConnector).filter(
                SearchSourceConnector.search_space_id == search_space_id,
                SearchSourceConnector.user_id == user_id,
                SearchSourceConnector.connector_type.in_(_GMAIL_TYPES),
            )
            connector = (await db_session.execute(query)).scalars().first()
            if connector is None:
                return {
                    "status": "error",
                    "message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
                }
            # Imported inside the function, matching the original lazy-import pattern.
            from app.agents.new_chat.tools.gmail.search_emails import _build_credentials
            from app.connectors.google_gmail_connector import GoogleGmailConnector

            gmail_client = GoogleGmailConnector(
                credentials=_build_credentials(connector),
                session=db_session,
                user_id=user_id,
                connector_id=connector.id,
            )
            detail, error = await gmail_client.get_message_details(message_id)
            if error:
                lowered = error.lower()
                if "re-authenticate" in lowered or "authentication failed" in lowered:
                    return {"status": "auth_error", "message": error, "connector_type": "gmail"}
                return {"status": "error", "message": error}
            if not detail:
                return {"status": "not_found", "message": f"Email with ID '{message_id}' not found."}
            return {
                "status": "success",
                "message_id": message_id,
                "content": gmail_client.format_message_to_markdown(detail),
            }
        except Exception as e:
            # GraphInterrupt is LangGraph control flow, not a failure — re-raise.
            from langgraph.errors import GraphInterrupt

            if isinstance(e, GraphInterrupt):
                raise
            logger.error("Error reading Gmail email: %s", e, exc_info=True)
            return {"status": "error", "message": "Failed to read email. Please try again."}

    return read_gmail_email

View file

@ -0,0 +1,165 @@
import logging
from datetime import datetime
from typing import Any
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
logger = logging.getLogger(__name__)
# Connector types that can serve Gmail requests: the native OAuth connector
# and the Composio-backed variant.
_GMAIL_TYPES = [
    SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
    SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
]
# Module-level memo for the TokenEncryption helper (built lazily on first use).
_token_encryption_cache: object | None = None


def _get_token_encryption():
    """Return a process-wide TokenEncryption instance, creating it on demand.

    Raises:
        RuntimeError: if SECRET_KEY is not configured.
    """
    global _token_encryption_cache
    if _token_encryption_cache is not None:
        return _token_encryption_cache
    from app.config import config
    from app.utils.oauth_security import TokenEncryption

    if not config.SECRET_KEY:
        raise RuntimeError("SECRET_KEY not configured for token decryption.")
    _token_encryption_cache = TokenEncryption(config.SECRET_KEY)
    return _token_encryption_cache
def _build_credentials(connector: SearchSourceConnector):
    """Build Google OAuth Credentials from a connector's stored config.

    Handles both native OAuth connectors (with encrypted tokens) and
    Composio-backed connectors. Shared by Gmail and Calendar tools.
    """
    from app.utils.google_credentials import COMPOSIO_GOOGLE_CONNECTOR_TYPES

    # Composio-backed connectors: credential construction is delegated entirely;
    # only the connected-account id is read from our config.
    if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES:
        from app.utils.google_credentials import build_composio_credentials

        cca_id = connector.config.get("composio_connected_account_id")
        if not cca_id:
            raise ValueError("Composio connected account ID not found.")
        return build_composio_credentials(cca_id)

    from google.oauth2.credentials import Credentials

    # Native OAuth path: decrypt sensitive fields on a copy of the config so
    # the stored connector config is never mutated.
    cfg = dict(connector.config)
    if cfg.get("_token_encrypted"):
        enc = _get_token_encryption()
        for key in ("token", "refresh_token", "client_secret"):
            if cfg.get(key):
                cfg[key] = enc.decrypt_token(cfg[key])
    # Strip the "Z" suffix so datetime.fromisoformat accepts the timestamp;
    # the result is a naive datetime — presumably UTC, matching how the OAuth
    # flow writes expiry (TODO confirm against the writer).
    exp = (cfg.get("expiry") or "").replace("Z", "")
    return Credentials(
        token=cfg.get("token"),
        refresh_token=cfg.get("refresh_token"),
        token_uri=cfg.get("token_uri"),
        client_id=cfg.get("client_id"),
        client_secret=cfg.get("client_secret"),
        scopes=cfg.get("scopes", []),
        expiry=datetime.fromisoformat(exp) if exp else None,
    )
def create_search_gmail_tool(
    db_session: AsyncSession | None = None,
    search_space_id: int | None = None,
    user_id: str | None = None,
):
    # Factory: binds the tool to one user's Gmail connector in one search
    # space. None defaults let the registry build tools uniformly; the tool
    # then reports a configuration error at call time.
    @tool
    async def search_gmail(
        query: str,
        max_results: int = 10,
    ) -> dict[str, Any]:
        """Search emails in the user's Gmail inbox using Gmail search syntax.

        Args:
            query: Gmail search query, same syntax as the Gmail search bar.
                Examples: "from:alice@example.com", "subject:meeting",
                "is:unread", "after:2024/01/01 before:2024/02/01",
                "has:attachment", "in:sent".
            max_results: Number of emails to return (default 10, max 20).

        Returns:
            Dictionary with status and a list of email summaries including
            message_id, subject, from, date, snippet.
        """
        if db_session is None or search_space_id is None or user_id is None:
            return {"status": "error", "message": "Gmail tool not properly configured."}
        # Cap result count; each result costs an extra API round-trip below.
        max_results = min(max_results, 20)
        try:
            result = await db_session.execute(
                select(SearchSourceConnector).filter(
                    SearchSourceConnector.search_space_id == search_space_id,
                    SearchSourceConnector.user_id == user_id,
                    SearchSourceConnector.connector_type.in_(_GMAIL_TYPES),
                )
            )
            # If several Gmail connectors match, the first row wins — row order
            # is not specified here (NOTE(review): confirm whether multi-account
            # selection should be deterministic).
            connector = result.scalars().first()
            if not connector:
                return {
                    "status": "error",
                    "message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
                }
            creds = _build_credentials(connector)
            from app.connectors.google_gmail_connector import GoogleGmailConnector
            gmail = GoogleGmailConnector(
                credentials=creds,
                session=db_session,
                user_id=user_id,
                connector_id=connector.id,
            )
            messages_list, error = await gmail.get_messages_list(
                max_results=max_results, query=query
            )
            if error:
                # Auth problems are surfaced with connector_type so the UI can
                # prompt the user to re-authenticate.
                if "re-authenticate" in error.lower() or "authentication failed" in error.lower():
                    return {"status": "auth_error", "message": error, "connector_type": "gmail"}
                return {"status": "error", "message": error}
            if not messages_list:
                return {"status": "success", "emails": [], "total": 0, "message": "No emails found."}
            emails = []
            # NOTE(review): one get_message_details call per hit (sequential
            # N+1). Acceptable at max 20 results; batching would need support
            # in the connector.
            for msg in messages_list:
                detail, err = await gmail.get_message_details(msg["id"])
                if err:
                    # Best-effort: skip messages whose details cannot be fetched.
                    continue
                # Header names are case-insensitive; normalise to lowercase keys.
                headers = {
                    h["name"].lower(): h["value"]
                    for h in detail.get("payload", {}).get("headers", [])
                }
                emails.append({
                    "message_id": detail.get("id"),
                    "thread_id": detail.get("threadId"),
                    "subject": headers.get("subject", "No Subject"),
                    "from": headers.get("from", "Unknown"),
                    "to": headers.get("to", ""),
                    "date": headers.get("date", ""),
                    "snippet": detail.get("snippet", ""),
                    "labels": detail.get("labelIds", []),
                })
            return {"status": "success", "emails": emails, "total": len(emails)}
        except Exception as e:
            # GraphInterrupt is LangGraph control flow, not a failure — re-raise.
            from langgraph.errors import GraphInterrupt
            if isinstance(e, GraphInterrupt):
                raise
            logger.error("Error searching Gmail: %s", e, exc_info=True)
            return {"status": "error", "message": "Failed to search Gmail. Please try again."}
    return search_gmail

View file

@ -4,6 +4,9 @@ from app.agents.new_chat.tools.google_calendar.create_event import (
from app.agents.new_chat.tools.google_calendar.delete_event import (
create_delete_calendar_event_tool,
)
from app.agents.new_chat.tools.google_calendar.search_events import (
create_search_calendar_events_tool,
)
from app.agents.new_chat.tools.google_calendar.update_event import (
create_update_calendar_event_tool,
)
@ -11,5 +14,6 @@ from app.agents.new_chat.tools.google_calendar.update_event import (
__all__ = [
"create_create_calendar_event_tool",
"create_delete_calendar_event_tool",
"create_search_calendar_events_tool",
"create_update_calendar_event_tool",
]

View file

@ -0,0 +1,114 @@
import logging
from typing import Any
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.agents.new_chat.tools.gmail.search_emails import _build_credentials
from app.db import SearchSourceConnector, SearchSourceConnectorType
logger = logging.getLogger(__name__)

# Connector types that can serve Google Calendar data: the native OAuth
# connector and the Composio-managed variant.
_CALENDAR_TYPES = [
    SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR,
    SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
]
def create_search_calendar_events_tool(
    db_session: AsyncSession | None = None,
    search_space_id: int | None = None,
    user_id: str | None = None,
):
    """Build a ``search_calendar_events`` agent tool bound to one workspace/user.

    The connector lookup happens on each invocation, so a newly linked
    calendar is usable without rebuilding the tool. Missing bind arguments
    yield a structured error response instead of an exception.
    """

    @tool
    async def search_calendar_events(
        start_date: str,
        end_date: str,
        max_results: int = 25,
    ) -> dict[str, Any]:
        """Search Google Calendar events within a date range.
        Args:
            start_date: Start date in YYYY-MM-DD format (e.g. "2026-04-01").
            end_date: End date in YYYY-MM-DD format (e.g. "2026-04-30").
            max_results: Maximum number of events to return (default 25, max 50).
        Returns:
            Dictionary with status and a list of events including
            event_id, summary, start, end, location, attendees.
        """
        if db_session is None or search_space_id is None or user_id is None:
            return {"status": "error", "message": "Calendar tool not properly configured."}
        max_results = min(max_results, 50)
        try:
            # Accept either the native Google Calendar connector or the
            # Composio-backed one; the first matching row wins.
            result = await db_session.execute(
                select(SearchSourceConnector).filter(
                    SearchSourceConnector.search_space_id == search_space_id,
                    SearchSourceConnector.user_id == user_id,
                    SearchSourceConnector.connector_type.in_(_CALENDAR_TYPES),
                )
            )
            connector = result.scalars().first()
            if not connector:
                return {
                    "status": "error",
                    "message": "No Google Calendar connector found. Please connect Google Calendar in your workspace settings.",
                }
            creds = _build_credentials(connector)
            # Imported lazily — presumably to avoid an import cycle; confirm.
            from app.connectors.google_calendar_connector import GoogleCalendarConnector

            cal = GoogleCalendarConnector(
                credentials=creds,
                session=db_session,
                user_id=user_id,
                connector_id=connector.id,
            )
            events_raw, error = await cal.get_all_primary_calendar_events(
                start_date=start_date,
                end_date=end_date,
                max_results=max_results,
            )
            if error:
                # Substring checks against the connector's error wording:
                # distinguish auth failures (re-auth prompt) and the benign
                # "no events" case from real errors.
                if "re-authenticate" in error.lower() or "authentication failed" in error.lower():
                    return {"status": "auth_error", "message": error, "connector_type": "google_calendar"}
                if "no events found" in error.lower():
                    return {"status": "success", "events": [], "total": 0, "message": error}
                return {"status": "error", "message": error}
            events = []
            for ev in events_raw:
                # Prefer the timed "dateTime" field, falling back to "date"
                # (used by all-day events in the Google Calendar API).
                start = ev.get("start", {})
                end = ev.get("end", {})
                attendees_raw = ev.get("attendees", [])
                events.append({
                    "event_id": ev.get("id"),
                    "summary": ev.get("summary", "No Title"),
                    "start": start.get("dateTime") or start.get("date", ""),
                    "end": end.get("dateTime") or end.get("date", ""),
                    "location": ev.get("location", ""),
                    "description": ev.get("description", ""),
                    "html_link": ev.get("htmlLink", ""),
                    # Cap attendees to keep tool output compact.
                    "attendees": [
                        a.get("email", "") for a in attendees_raw[:10]
                    ],
                    "status": ev.get("status", ""),
                })
            return {"status": "success", "events": events, "total": len(events)}
        except Exception as e:
            # GraphInterrupt is LangGraph control flow (HITL pause) — re-raise.
            from langgraph.errors import GraphInterrupt

            if isinstance(e, GraphInterrupt):
                raise
            logger.error("Error searching calendar events: %s", e, exc_info=True)
            return {"status": "error", "message": "Failed to search calendar events. Please try again."}

    return search_calendar_events

View file

@ -130,8 +130,8 @@ def request_approval(
try:
decision_type, edited_params = _parse_decision(approval)
except ValueError:
logger.warning("No approval decision received for %s", tool_name)
return HITLResult(rejected=False, decision_type="error", params=params)
logger.warning("No approval decision received for %s — rejecting for safety", tool_name)
return HITLResult(rejected=True, decision_type="error", params=params)
logger.info("User decision for %s: %s", tool_name, decision_type)

View file

@ -0,0 +1,15 @@
from app.agents.new_chat.tools.luma.create_event import (
create_create_luma_event_tool,
)
from app.agents.new_chat.tools.luma.list_events import (
create_list_luma_events_tool,
)
from app.agents.new_chat.tools.luma.read_event import (
create_read_luma_event_tool,
)
__all__ = [
"create_create_luma_event_tool",
"create_list_luma_events_tool",
"create_read_luma_event_tool",
]

View file

@ -0,0 +1,38 @@
"""Shared auth helper for Luma agent tools."""
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
# Base URL of the Luma Public API, version 1.
LUMA_API = "https://public-api.luma.com/v1"
async def get_luma_connector(
    db_session: AsyncSession,
    search_space_id: int,
    user_id: str,
) -> SearchSourceConnector | None:
    """Fetch the Luma connector row for this workspace/user, or None if absent."""
    stmt = select(SearchSourceConnector).filter(
        SearchSourceConnector.search_space_id == search_space_id,
        SearchSourceConnector.user_id == user_id,
        SearchSourceConnector.connector_type == SearchSourceConnectorType.LUMA_CONNECTOR,
    )
    rows = await db_session.execute(stmt)
    return rows.scalars().first()
def get_api_key(connector: SearchSourceConnector) -> str:
    """Return the Luma API key stored in the connector config.

    Both historical key names are accepted: ``api_key`` takes precedence,
    with ``LUMA_API_KEY`` as the fallback.

    Raises:
        ValueError: when neither field holds a non-empty value.
    """
    cfg = connector.config
    for field in ("api_key", "LUMA_API_KEY"):
        value = cfg.get(field)
        if value:
            return value
    raise ValueError("Luma API key not found in connector config.")
def luma_headers(api_key: str) -> dict[str, str]:
    """Build the standard header set for Luma Public API requests."""
    headers = {"Content-Type": "application/json"}
    headers["x-luma-api-key"] = api_key
    return headers

View file

@ -0,0 +1,116 @@
import logging
from typing import Any
import httpx
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval
from ._auth import LUMA_API, get_api_key, get_luma_connector, luma_headers
logger = logging.getLogger(__name__)
def create_create_luma_event_tool(
    db_session: AsyncSession | None = None,
    search_space_id: int | None = None,
    user_id: str | None = None,
):
    """Build a ``create_luma_event`` agent tool bound to one workspace/user.

    The returned tool is HITL-gated: the user must approve (and may edit)
    the event parameters before anything is sent to Luma. Missing bind
    arguments make the tool return a structured error instead of raising.
    """

    @tool
    async def create_luma_event(
        name: str,
        start_at: str,
        end_at: str,
        description: str | None = None,
        timezone: str = "UTC",
    ) -> dict[str, Any]:
        """Create a new event on Luma.
        Args:
            name: The event title.
            start_at: Start time in ISO 8601 format (e.g. "2026-05-01T18:00:00").
            end_at: End time in ISO 8601 format (e.g. "2026-05-01T20:00:00").
            description: Optional event description (markdown supported).
            timezone: Timezone string (default "UTC", e.g. "America/New_York").
        Returns:
            Dictionary with status, event_id on success.
        IMPORTANT:
            - If status is "rejected", the user explicitly declined. Do NOT retry.
        """
        if db_session is None or search_space_id is None or user_id is None:
            return {"status": "error", "message": "Luma tool not properly configured."}
        try:
            connector = await get_luma_connector(db_session, search_space_id, user_id)
            if not connector:
                return {"status": "error", "message": "No Luma connector found."}
            # HITL gate: pause for user approval before any write. This may
            # raise GraphInterrupt, which must propagate (see except below).
            result = request_approval(
                action_type="luma_create_event",
                tool_name="create_luma_event",
                params={
                    "name": name,
                    "start_at": start_at,
                    "end_at": end_at,
                    "description": description,
                    "timezone": timezone,
                },
                context={"connector_id": connector.id},
            )
            if result.rejected:
                return {"status": "rejected", "message": "User declined. Event was not created."}
            # The user may have edited fields in the approval dialog; prefer
            # the approved params, falling back to the original arguments.
            final_name = result.params.get("name", name)
            final_start = result.params.get("start_at", start_at)
            final_end = result.params.get("end_at", end_at)
            final_desc = result.params.get("description", description)
            final_tz = result.params.get("timezone", timezone)
            api_key = get_api_key(connector)
            headers = luma_headers(api_key)
            body: dict[str, Any] = {
                "name": final_name,
                "start_at": final_start,
                "end_at": final_end,
                "timezone": final_tz,
            }
            if final_desc:
                body["description_md"] = final_desc
            async with httpx.AsyncClient(timeout=20.0) as client:
                resp = await client.post(
                    f"{LUMA_API}/event/create",
                    headers=headers,
                    json=body,
                )
            if resp.status_code == 401:
                return {"status": "auth_error", "message": "Luma API key is invalid.", "connector_type": "luma"}
            if resp.status_code == 403:
                return {"status": "error", "message": "Luma Plus subscription required to create events via API."}
            if resp.status_code not in (200, 201):
                # BUGFIX: separate the status code from the response body —
                # previously rendered as e.g. "Luma API error: 400Bad Request".
                return {"status": "error", "message": f"Luma API error: {resp.status_code}: {resp.text[:200]}"}
            data = resp.json()
            # The event id may arrive at the top level or nested under "event".
            event_id = data.get("api_id") or data.get("event", {}).get("api_id")
            return {
                "status": "success",
                "event_id": event_id,
                "message": f"Event '{final_name}' created on Luma.",
            }
        except Exception as e:
            # GraphInterrupt is LangGraph control flow (HITL pause) — re-raise.
            from langgraph.errors import GraphInterrupt

            if isinstance(e, GraphInterrupt):
                raise
            logger.error("Error creating Luma event: %s", e, exc_info=True)
            return {"status": "error", "message": "Failed to create Luma event."}

    return create_luma_event

View file

@ -0,0 +1,100 @@
import logging
from typing import Any
import httpx
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from ._auth import LUMA_API, get_api_key, get_luma_connector, luma_headers
logger = logging.getLogger(__name__)
def create_list_luma_events_tool(
    db_session: AsyncSession | None = None,
    search_space_id: int | None = None,
    user_id: str | None = None,
):
    """Build a ``list_luma_events`` agent tool bound to one workspace/user.

    Read-only: no HITL approval is requested. Missing bind arguments make
    the tool return a structured error instead of raising.
    """

    @tool
    async def list_luma_events(
        max_results: int = 25,
    ) -> dict[str, Any]:
        """List upcoming and recent Luma events.
        Args:
            max_results: Maximum events to return (default 25, max 50).
        Returns:
            Dictionary with status and a list of events including
            event_id, name, start_at, end_at, location, url.
        """
        if db_session is None or search_space_id is None or user_id is None:
            return {"status": "error", "message": "Luma tool not properly configured."}
        max_results = min(max_results, 50)
        try:
            connector = await get_luma_connector(db_session, search_space_id, user_id)
            if not connector:
                return {"status": "error", "message": "No Luma connector found."}
            api_key = get_api_key(connector)
            headers = luma_headers(api_key)
            all_entries: list[dict] = []
            cursor = None
            # Cursor-paginate until enough entries are gathered or the API
            # stops returning entries / a next cursor.
            # NOTE(review): param names "limit"/"cursor" — confirm against the
            # current Luma list-events docs (some versions use
            # pagination_limit/pagination_cursor).
            async with httpx.AsyncClient(timeout=20.0) as client:
                while len(all_entries) < max_results:
                    params: dict[str, Any] = {"limit": min(100, max_results - len(all_entries))}
                    if cursor:
                        params["cursor"] = cursor
                    resp = await client.get(
                        f"{LUMA_API}/calendar/list-events",
                        headers=headers,
                        params=params,
                    )
                    if resp.status_code == 401:
                        return {"status": "auth_error", "message": "Luma API key is invalid.", "connector_type": "luma"}
                    if resp.status_code != 200:
                        return {"status": "error", "message": f"Luma API error: {resp.status_code}"}
                    data = resp.json()
                    entries = data.get("entries", [])
                    if not entries:
                        break
                    all_entries.extend(entries)
                    next_cursor = data.get("next_cursor")
                    if not next_cursor:
                        break
                    cursor = next_cursor
            events = []
            # Flatten the API's entry -> event nesting into compact summaries.
            for entry in all_entries[:max_results]:
                ev = entry.get("event", {})
                geo = ev.get("geo_info", {})
                events.append({
                    "event_id": entry.get("api_id"),
                    "name": ev.get("name", "Untitled"),
                    "start_at": ev.get("start_at", ""),
                    "end_at": ev.get("end_at", ""),
                    "timezone": ev.get("timezone", ""),
                    "location": geo.get("name", ""),
                    "url": ev.get("url", ""),
                    "visibility": ev.get("visibility", ""),
                })
            return {"status": "success", "events": events, "total": len(events)}
        except Exception as e:
            # GraphInterrupt is LangGraph control flow (HITL pause) — re-raise.
            from langgraph.errors import GraphInterrupt

            if isinstance(e, GraphInterrupt):
                raise
            logger.error("Error listing Luma events: %s", e, exc_info=True)
            return {"status": "error", "message": "Failed to list Luma events."}

    return list_luma_events

View file

@ -0,0 +1,82 @@
import logging
from typing import Any
import httpx
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from ._auth import LUMA_API, get_api_key, get_luma_connector, luma_headers
logger = logging.getLogger(__name__)
def create_read_luma_event_tool(
    db_session: AsyncSession | None = None,
    search_space_id: int | None = None,
    user_id: str | None = None,
):
    """Build a ``read_luma_event`` agent tool bound to one workspace/user.

    Read-only: no HITL approval is requested.
    """

    @tool
    async def read_luma_event(event_id: str) -> dict[str, Any]:
        """Read detailed information about a specific Luma event.
        Args:
            event_id: The Luma event API ID (from list_luma_events).
        Returns:
            Dictionary with status and full event details including
            description, attendees count, meeting URL.
        """
        if db_session is None or search_space_id is None or user_id is None:
            return {"status": "error", "message": "Luma tool not properly configured."}
        try:
            connector = await get_luma_connector(db_session, search_space_id, user_id)
            if not connector:
                return {"status": "error", "message": "No Luma connector found."}
            api_key = get_api_key(connector)
            headers = luma_headers(api_key)
            # NOTE(review): GET /events/{id} here vs POST /event/create in the
            # sibling tool — confirm the plural path against the current Luma
            # Public API reference.
            async with httpx.AsyncClient(timeout=15.0) as client:
                resp = await client.get(
                    f"{LUMA_API}/events/{event_id}",
                    headers=headers,
                )
            if resp.status_code == 401:
                return {"status": "auth_error", "message": "Luma API key is invalid.", "connector_type": "luma"}
            if resp.status_code == 404:
                return {"status": "not_found", "message": f"Event '{event_id}' not found."}
            if resp.status_code != 200:
                return {"status": "error", "message": f"Luma API error: {resp.status_code}"}
            data = resp.json()
            # The payload may be nested under "event" or be the top-level dict.
            ev = data.get("event", data)
            geo = ev.get("geo_info", {})
            event_detail = {
                "event_id": event_id,
                "name": ev.get("name", ""),
                "description": ev.get("description", ""),
                "start_at": ev.get("start_at", ""),
                "end_at": ev.get("end_at", ""),
                "timezone": ev.get("timezone", ""),
                "location_name": geo.get("name", ""),
                "address": geo.get("address", ""),
                "url": ev.get("url", ""),
                "meeting_url": ev.get("meeting_url", ""),
                "visibility": ev.get("visibility", ""),
                "cover_url": ev.get("cover_url", ""),
            }
            return {"status": "success", "event": event_detail}
        except Exception as e:
            # GraphInterrupt is LangGraph control flow (HITL pause) — re-raise.
            from langgraph.errors import GraphInterrupt

            if isinstance(e, GraphInterrupt):
                raise
            logger.error("Error reading Luma event: %s", e, exc_info=True)
            return {"status": "error", "message": "Failed to read Luma event."}

    return read_luma_event

View file

@ -14,20 +14,28 @@ clicking "Always Allow", which adds the tool name to the connector's
``config.trusted_tools`` allow-list.
"""
from __future__ import annotations
import logging
import time
from typing import Any
from collections import defaultdict
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from app.utils.oauth_security import TokenEncryption
from langchain_core.tools import StructuredTool
from mcp import ClientSession
from mcp.client.streamable_http import streamablehttp_client
from pydantic import BaseModel, create_model
from sqlalchemy import select
from pydantic import BaseModel, Field, create_model
from sqlalchemy import cast, select
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval
from app.agents.new_chat.tools.mcp_client import MCPClient
from app.db import SearchSourceConnector, SearchSourceConnectorType
from app.services.mcp_oauth.registry import MCP_SERVICES, get_service_by_connector_type
logger = logging.getLogger(__name__)
@ -63,18 +71,14 @@ def _create_dynamic_input_model_from_schema(
param_description = param_schema.get("description", "")
is_required = param_name in required_fields
from typing import Any as AnyType
from pydantic import Field
if is_required:
field_definitions[param_name] = (
AnyType,
Any,
Field(..., description=param_description),
)
else:
field_definitions[param_name] = (
AnyType | None,
Any | None,
Field(None, description=param_description),
)
@ -100,13 +104,13 @@ async def _create_mcp_tool_from_definition_stdio(
tool_description = tool_def.get("description", "No description provided")
input_schema = tool_def.get("input_schema", {"type": "object", "properties": {}})
logger.info(f"MCP tool '{tool_name}' input schema: {input_schema}")
logger.debug("MCP tool '%s' input schema: %s", tool_name, input_schema)
input_model = _create_dynamic_input_model_from_schema(tool_name, input_schema)
async def mcp_tool_call(**kwargs) -> str:
"""Execute the MCP tool call via the client with retry support."""
logger.info(f"MCP tool '{tool_name}' called with params: {kwargs}")
logger.debug("MCP tool '%s' called", tool_name)
# HITL — OUTSIDE try/except so GraphInterrupt propagates to LangGraph
hitl_result = request_approval(
@ -123,20 +127,18 @@ async def _create_mcp_tool_from_definition_stdio(
)
if hitl_result.rejected:
return "Tool call rejected by user."
call_kwargs = hitl_result.params
call_kwargs = {k: v for k, v in hitl_result.params.items() if v is not None}
try:
async with mcp_client.connect():
result = await mcp_client.call_tool(tool_name, call_kwargs)
return str(result)
except RuntimeError as e:
error_msg = f"MCP tool '{tool_name}' connection failed after retries: {e!s}"
logger.error(error_msg)
return f"Error: {error_msg}"
logger.error("MCP tool '%s' connection failed after retries: %s", tool_name, e)
return f"Error: MCP tool '{tool_name}' connection failed after retries: {e!s}"
except Exception as e:
error_msg = f"MCP tool '{tool_name}' execution failed: {e!s}"
logger.exception(error_msg)
return f"Error: {error_msg}"
logger.exception("MCP tool '%s' execution failed: %s", tool_name, e)
return f"Error: MCP tool '{tool_name}' execution failed: {e!s}"
tool = StructuredTool(
name=tool_name,
@ -151,7 +153,7 @@ async def _create_mcp_tool_from_definition_stdio(
},
)
logger.info(f"Created MCP tool (stdio): '{tool_name}'")
logger.debug("Created MCP tool (stdio): '%s'", tool_name)
return tool
@ -163,41 +165,57 @@ async def _create_mcp_tool_from_definition_http(
connector_name: str = "",
connector_id: int | None = None,
trusted_tools: list[str] | None = None,
readonly_tools: frozenset[str] | None = None,
tool_name_prefix: str | None = None,
) -> StructuredTool:
"""Create a LangChain tool from an MCP tool definition (HTTP transport).
All MCP tools are unconditionally wrapped with HITL approval.
``request_approval()`` is called OUTSIDE the try/except so that
``GraphInterrupt`` propagates cleanly to LangGraph.
Write tools are wrapped with HITL approval; read-only tools (listed in
``readonly_tools``) execute immediately without user confirmation.
When ``tool_name_prefix`` is set (multi-account disambiguation), the
tool exposed to the LLM gets a prefixed name (e.g. ``linear_25_list_issues``)
but the actual MCP ``call_tool`` still uses the original name.
"""
tool_name = tool_def.get("name", "unnamed_tool")
original_tool_name = tool_def.get("name", "unnamed_tool")
tool_description = tool_def.get("description", "No description provided")
input_schema = tool_def.get("input_schema", {"type": "object", "properties": {}})
is_readonly = readonly_tools is not None and original_tool_name in readonly_tools
logger.info(f"MCP HTTP tool '{tool_name}' input schema: {input_schema}")
exposed_name = (
f"{tool_name_prefix}_{original_tool_name}"
if tool_name_prefix
else original_tool_name
)
if tool_name_prefix:
tool_description = f"[Account: {connector_name}] {tool_description}"
input_model = _create_dynamic_input_model_from_schema(tool_name, input_schema)
logger.debug("MCP HTTP tool '%s' input schema: %s", exposed_name, input_schema)
input_model = _create_dynamic_input_model_from_schema(exposed_name, input_schema)
async def mcp_http_tool_call(**kwargs) -> str:
"""Execute the MCP tool call via HTTP transport."""
logger.info(f"MCP HTTP tool '{tool_name}' called with params: {kwargs}")
logger.debug("MCP HTTP tool '%s' called", exposed_name)
# HITL — OUTSIDE try/except so GraphInterrupt propagates to LangGraph
hitl_result = request_approval(
action_type="mcp_tool_call",
tool_name=tool_name,
params=kwargs,
context={
"mcp_server": connector_name,
"tool_description": tool_description,
"mcp_transport": "http",
"mcp_connector_id": connector_id,
},
trusted_tools=trusted_tools,
)
if hitl_result.rejected:
return "Tool call rejected by user."
call_kwargs = hitl_result.params
if is_readonly:
call_kwargs = {k: v for k, v in kwargs.items() if v is not None}
else:
hitl_result = request_approval(
action_type="mcp_tool_call",
tool_name=exposed_name,
params=kwargs,
context={
"mcp_server": connector_name,
"tool_description": tool_description,
"mcp_transport": "http",
"mcp_connector_id": connector_id,
},
trusted_tools=trusted_tools,
)
if hitl_result.rejected:
return "Tool call rejected by user."
call_kwargs = {k: v for k, v in hitl_result.params.items() if v is not None}
try:
async with (
@ -205,7 +223,9 @@ async def _create_mcp_tool_from_definition_http(
ClientSession(read, write) as session,
):
await session.initialize()
response = await session.call_tool(tool_name, arguments=call_kwargs)
response = await session.call_tool(
original_tool_name, arguments=call_kwargs,
)
result = []
for content in response.content:
@ -217,18 +237,15 @@ async def _create_mcp_tool_from_definition_http(
result.append(str(content))
result_str = "\n".join(result) if result else ""
logger.info(
f"MCP HTTP tool '{tool_name}' succeeded: {result_str[:200]}"
)
logger.debug("MCP HTTP tool '%s' succeeded (len=%d)", exposed_name, len(result_str))
return result_str
except Exception as e:
error_msg = f"MCP HTTP tool '{tool_name}' execution failed: {e!s}"
logger.exception(error_msg)
return f"Error: {error_msg}"
logger.exception("MCP HTTP tool '%s' execution failed: %s", exposed_name, e)
return f"Error: MCP HTTP tool '{exposed_name}' execution failed: {e!s}"
tool = StructuredTool(
name=tool_name,
name=exposed_name,
description=tool_description,
coroutine=mcp_http_tool_call,
args_schema=input_model,
@ -236,12 +253,14 @@ async def _create_mcp_tool_from_definition_http(
"mcp_input_schema": input_schema,
"mcp_transport": "http",
"mcp_url": url,
"hitl": True,
"hitl": not is_readonly,
"hitl_dedup_key": next(iter(input_schema.get("required", [])), None),
"mcp_original_tool_name": original_tool_name,
"mcp_connector_id": connector_id,
},
)
logger.info(f"Created MCP tool (HTTP): '{tool_name}'")
logger.debug("Created MCP tool (HTTP): '%s'", exposed_name)
return tool
@ -257,21 +276,24 @@ async def _load_stdio_mcp_tools(
command = server_config.get("command")
if not command or not isinstance(command, str):
logger.warning(
f"MCP connector {connector_id} (name: '{connector_name}') missing or invalid command field, skipping"
"MCP connector %d (name: '%s') missing or invalid command field, skipping",
connector_id, connector_name,
)
return tools
args = server_config.get("args", [])
if not isinstance(args, list):
logger.warning(
f"MCP connector {connector_id} (name: '{connector_name}') has invalid args field (must be list), skipping"
"MCP connector %d (name: '%s') has invalid args field (must be list), skipping",
connector_id, connector_name,
)
return tools
env = server_config.get("env", {})
if not isinstance(env, dict):
logger.warning(
f"MCP connector {connector_id} (name: '{connector_name}') has invalid env field (must be dict), skipping"
"MCP connector %d (name: '%s') has invalid env field (must be dict), skipping",
connector_id, connector_name,
)
return tools
@ -281,8 +303,8 @@ async def _load_stdio_mcp_tools(
tool_definitions = await mcp_client.list_tools()
logger.info(
f"Discovered {len(tool_definitions)} tools from stdio MCP server "
f"'{command}' (connector {connector_id})"
"Discovered %d tools from stdio MCP server '%s' (connector %d)",
len(tool_definitions), command, connector_id,
)
for tool_def in tool_definitions:
@ -297,8 +319,8 @@ async def _load_stdio_mcp_tools(
tools.append(tool)
except Exception as e:
logger.exception(
f"Failed to create tool '{tool_def.get('name')}' "
f"from connector {connector_id}: {e!s}"
"Failed to create tool '%s' from connector %d: %s",
tool_def.get("name"), connector_id, e,
)
return tools
@ -309,24 +331,40 @@ async def _load_http_mcp_tools(
connector_name: str,
server_config: dict[str, Any],
trusted_tools: list[str] | None = None,
allowed_tools: list[str] | None = None,
readonly_tools: frozenset[str] | None = None,
tool_name_prefix: str | None = None,
) -> list[StructuredTool]:
"""Load tools from an HTTP-based MCP server."""
"""Load tools from an HTTP-based MCP server.
Args:
allowed_tools: If non-empty, only tools whose names appear in this
list are loaded. Empty/None means load everything (used for
user-managed generic MCP servers).
readonly_tools: Tool names that skip HITL approval (read-only operations).
tool_name_prefix: If set, each tool name is prefixed for multi-account
disambiguation (e.g. ``linear_25``).
"""
tools: list[StructuredTool] = []
url = server_config.get("url")
if not url or not isinstance(url, str):
logger.warning(
f"MCP connector {connector_id} (name: '{connector_name}') missing or invalid url field, skipping"
"MCP connector %d (name: '%s') missing or invalid url field, skipping",
connector_id, connector_name,
)
return tools
headers = server_config.get("headers", {})
if not isinstance(headers, dict):
logger.warning(
f"MCP connector {connector_id} (name: '{connector_name}') has invalid headers field (must be dict), skipping"
"MCP connector %d (name: '%s') has invalid headers field (must be dict), skipping",
connector_id, connector_name,
)
return tools
allowed_set = set(allowed_tools) if allowed_tools else None
try:
async with (
streamablehttp_client(url, headers=headers) as (read, write, _),
@ -347,10 +385,21 @@ async def _load_http_mcp_tools(
}
)
logger.info(
f"Discovered {len(tool_definitions)} tools from HTTP MCP server "
f"'{url}' (connector {connector_id})"
)
total_discovered = len(tool_definitions)
if allowed_set:
tool_definitions = [
td for td in tool_definitions if td["name"] in allowed_set
]
logger.info(
"HTTP MCP server '%s' (connector %d): %d/%d tools after allowlist filter",
url, connector_id, len(tool_definitions), total_discovered,
)
else:
logger.info(
"Discovered %d tools from HTTP MCP server '%s' (connector %d) — no allowlist, loading all",
total_discovered, url, connector_id,
)
for tool_def in tool_definitions:
try:
@ -361,22 +410,183 @@ async def _load_http_mcp_tools(
connector_name=connector_name,
connector_id=connector_id,
trusted_tools=trusted_tools,
readonly_tools=readonly_tools,
tool_name_prefix=tool_name_prefix,
)
tools.append(tool)
except Exception as e:
logger.exception(
f"Failed to create HTTP tool '{tool_def.get('name')}' "
f"from connector {connector_id}: {e!s}"
"Failed to create HTTP tool '%s' from connector %d: %s",
tool_def.get("name"), connector_id, e,
)
except Exception as e:
logger.exception(
f"Failed to connect to HTTP MCP server at '{url}' (connector {connector_id}): {e!s}"
"Failed to connect to HTTP MCP server at '%s' (connector %d): %s",
url, connector_id, e,
)
return tools
# Refresh access tokens this many seconds ahead of their recorded expiry.
_TOKEN_REFRESH_BUFFER_SECONDS = 300  # refresh 5 min before expiry

# Process-wide memoized encryption helper, built lazily on first use.
_token_enc: TokenEncryption | None = None


def _get_token_enc() -> TokenEncryption:
    """Return the shared TokenEncryption instance, creating it on first call."""
    global _token_enc
    if _token_enc is not None:
        return _token_enc
    from app.config import config as app_config
    from app.utils.oauth_security import TokenEncryption

    _token_enc = TokenEncryption(app_config.SECRET_KEY)
    return _token_enc
def _inject_oauth_headers(
cfg: dict[str, Any],
server_config: dict[str, Any],
) -> dict[str, Any] | None:
"""Decrypt the MCP OAuth access token and inject it into server_config headers.
The DB never stores plaintext tokens in ``server_config.headers``. This
function decrypts ``mcp_oauth.access_token`` at runtime and returns a
*copy* of ``server_config`` with the Authorization header set.
"""
mcp_oauth = cfg.get("mcp_oauth", {})
encrypted_token = mcp_oauth.get("access_token")
if not encrypted_token:
return server_config
try:
access_token = _get_token_enc().decrypt_token(encrypted_token)
result = dict(server_config)
result["headers"] = {
**server_config.get("headers", {}),
"Authorization": f"Bearer {access_token}",
}
return result
except Exception:
logger.error(
"Failed to decrypt MCP OAuth token — connector will be skipped",
exc_info=True,
)
return None
async def _maybe_refresh_mcp_oauth_token(
    session: AsyncSession,
    connector: "SearchSourceConnector",
    cfg: dict[str, Any],
    server_config: dict[str, Any],
) -> dict[str, Any]:
    """Refresh the access token for an MCP OAuth connector if it is about to expire.

    Returns the (possibly updated) ``server_config``. Best-effort: every
    failure path (missing metadata, unparsable expiry, refresh error) falls
    back to returning the original ``server_config`` unchanged.
    """
    from datetime import UTC, datetime, timedelta

    mcp_oauth = cfg.get("mcp_oauth", {})
    expires_at_str = mcp_oauth.get("expires_at")
    if not expires_at_str:
        # No expiry recorded — nothing to decide; use the config as-is.
        return server_config
    try:
        expires_at = datetime.fromisoformat(expires_at_str)
        if expires_at.tzinfo is None:
            # Naive timestamps are treated as UTC so the comparison below
            # is always between aware datetimes.
            from datetime import timezone

            expires_at = expires_at.replace(tzinfo=timezone.utc)
        # Still comfortably inside the validity window — skip the refresh.
        if datetime.now(UTC) < expires_at - timedelta(seconds=_TOKEN_REFRESH_BUFFER_SECONDS):
            return server_config
    except (ValueError, TypeError):
        # Malformed expiry string: don't refresh, don't fail the caller.
        return server_config
    refresh_token = mcp_oauth.get("refresh_token")
    if not refresh_token:
        logger.warning(
            "MCP connector %s token expired but no refresh_token available",
            connector.id,
        )
        return server_config
    try:
        from app.services.mcp_oauth.discovery import refresh_access_token

        # Stored refresh token and client secret are encrypted at rest;
        # decrypt them just for the token-endpoint call.
        enc = _get_token_enc()
        decrypted_refresh = enc.decrypt_token(refresh_token)
        decrypted_secret = (
            enc.decrypt_token(mcp_oauth["client_secret"])
            if mcp_oauth.get("client_secret")
            else ""
        )
        token_json = await refresh_access_token(
            token_endpoint=mcp_oauth["token_endpoint"],
            refresh_token=decrypted_refresh,
            client_id=mcp_oauth["client_id"],
            client_secret=decrypted_secret,
        )
        new_access = token_json.get("access_token")
        if not new_access:
            logger.warning(
                "MCP connector %s token refresh returned no access_token",
                connector.id,
            )
            return server_config
        new_expires_at = None
        if token_json.get("expires_in"):
            new_expires_at = datetime.now(UTC) + timedelta(
                seconds=int(token_json["expires_in"])
            )
        # Re-encrypt the new token material before persisting. The server
        # may rotate the refresh token; keep the old one if it did not.
        updated_oauth = dict(mcp_oauth)
        updated_oauth["access_token"] = enc.encrypt_token(new_access)
        if token_json.get("refresh_token"):
            updated_oauth["refresh_token"] = enc.encrypt_token(
                token_json["refresh_token"]
            )
        updated_oauth["expires_at"] = (
            new_expires_at.isoformat() if new_expires_at else None
        )
        from sqlalchemy.orm.attributes import flag_modified

        # Replace the whole config dict and flag it so SQLAlchemy persists
        # the in-place JSON change.
        connector.config = {
            **cfg,
            "server_config": server_config,
            "mcp_oauth": updated_oauth,
        }
        flag_modified(connector, "config")
        await session.commit()
        await session.refresh(connector)
        logger.info("Refreshed MCP OAuth token for connector %s", connector.id)
        # Invalidate cache so next call picks up the new token.
        invalidate_mcp_tools_cache(connector.search_space_id)
        # Return server_config with the fresh token injected for immediate use.
        refreshed_config = dict(server_config)
        refreshed_config["headers"] = {
            **server_config.get("headers", {}),
            "Authorization": f"Bearer {new_access}",
        }
        return refreshed_config
    except Exception:
        logger.warning(
            "Failed to refresh MCP OAuth token for connector %s",
            connector.id,
            exc_info=True,
        )
        return server_config
def invalidate_mcp_tools_cache(search_space_id: int | None = None) -> None:
"""Invalidate cached MCP tools.
@ -418,27 +628,91 @@ async def load_mcp_tools(
return list(cached_tools)
try:
# Find all connectors with MCP server config: generic MCP_CONNECTOR type
# and service-specific types (LINEAR_CONNECTOR, etc.) created via MCP OAuth.
# Cast JSON -> JSONB so we can use has_key to filter by the presence of "server_config".
result = await session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.connector_type
== SearchSourceConnectorType.MCP_CONNECTOR,
SearchSourceConnector.search_space_id == search_space_id,
cast(SearchSourceConnector.config, JSONB).has_key("server_config"), # noqa: W601
),
)
connectors = list(result.scalars())
# Group connectors by type to detect multi-account scenarios.
# When >1 connector shares the same type, tool names would collide
# so we prefix them with "{service_key}_{connector_id}_".
type_groups: dict[str, list[SearchSourceConnector]] = defaultdict(list)
for connector in connectors:
ct = (
connector.connector_type.value
if hasattr(connector.connector_type, "value")
else str(connector.connector_type)
)
type_groups[ct].append(connector)
multi_account_types: set[str] = {
ct for ct, group in type_groups.items() if len(group) > 1
}
if multi_account_types:
logger.info(
"Multi-account detected for connector types: %s",
multi_account_types,
)
tools: list[StructuredTool] = []
for connector in result.scalars():
for connector in connectors:
try:
config = connector.config or {}
server_config = config.get("server_config", {})
trusted_tools = config.get("trusted_tools", [])
cfg = connector.config or {}
server_config = cfg.get("server_config", {})
if not server_config or not isinstance(server_config, dict):
logger.warning(
f"MCP connector {connector.id} (name: '{connector.name}') has invalid or missing server_config, skipping"
"MCP connector %d (name: '%s') has invalid or missing server_config, skipping",
connector.id, connector.name,
)
continue
# For MCP OAuth connectors: refresh if needed, then decrypt the
# access token and inject it into headers at runtime. The DB
# intentionally does NOT store plaintext tokens in server_config.
if cfg.get("mcp_oauth"):
server_config = await _maybe_refresh_mcp_oauth_token(
session, connector, cfg, server_config,
)
# Re-read cfg after potential refresh (connector was reloaded from DB).
cfg = connector.config or {}
server_config = _inject_oauth_headers(cfg, server_config)
if server_config is None:
logger.warning(
"Skipping MCP connector %d — OAuth token decryption failed",
connector.id,
)
continue
trusted_tools = cfg.get("trusted_tools", [])
ct = (
connector.connector_type.value
if hasattr(connector.connector_type, "value")
else str(connector.connector_type)
)
svc_cfg = get_service_by_connector_type(ct)
allowed_tools = svc_cfg.allowed_tools if svc_cfg else []
readonly_tools = svc_cfg.readonly_tools if svc_cfg else frozenset()
# Build a prefix only when multiple accounts share the same type.
tool_name_prefix: str | None = None
if ct in multi_account_types and svc_cfg:
service_key = next(
(k for k, v in MCP_SERVICES.items() if v is svc_cfg),
None,
)
if service_key:
tool_name_prefix = f"{service_key}_{connector.id}"
transport = server_config.get("transport", "stdio")
if transport in ("streamable-http", "http", "sse"):
@ -447,6 +721,9 @@ async def load_mcp_tools(
connector.name,
server_config,
trusted_tools=trusted_tools,
allowed_tools=allowed_tools,
readonly_tools=readonly_tools,
tool_name_prefix=tool_name_prefix,
)
else:
connector_tools = await _load_stdio_mcp_tools(
@ -460,7 +737,8 @@ async def load_mcp_tools(
except Exception as e:
logger.exception(
f"Failed to load tools from MCP connector {connector.id}: {e!s}"
"Failed to load tools from MCP connector %d: %s",
connector.id, e,
)
_mcp_tools_cache[search_space_id] = (now, tools)
@ -469,9 +747,9 @@ async def load_mcp_tools(
oldest_key = min(_mcp_tools_cache, key=lambda k: _mcp_tools_cache[k][0])
del _mcp_tools_cache[oldest_key]
logger.info(f"Loaded {len(tools)} MCP tools for search space {search_space_id}")
logger.info("Loaded %d MCP tools for search space %d", len(tools), search_space_id)
return tools
except Exception as e:
logger.exception(f"Failed to load MCP tools: {e!s}")
logger.exception("Failed to load MCP tools: %s", e)
return []

View file

@ -50,6 +50,11 @@ from .confluence import (
create_delete_confluence_page_tool,
create_update_confluence_page_tool,
)
from .discord import (
create_list_discord_channels_tool,
create_read_discord_messages_tool,
create_send_discord_message_tool,
)
from .dropbox import (
create_create_dropbox_file_tool,
create_delete_dropbox_file_tool,
@ -57,6 +62,8 @@ from .dropbox import (
from .generate_image import create_generate_image_tool
from .gmail import (
create_create_gmail_draft_tool,
create_read_gmail_email_tool,
create_search_gmail_tool,
create_send_gmail_email_tool,
create_trash_gmail_email_tool,
create_update_gmail_draft_tool,
@ -64,21 +71,18 @@ from .gmail import (
from .google_calendar import (
create_create_calendar_event_tool,
create_delete_calendar_event_tool,
create_search_calendar_events_tool,
create_update_calendar_event_tool,
)
from .google_drive import (
create_create_google_drive_file_tool,
create_delete_google_drive_file_tool,
)
from .jira import (
create_create_jira_issue_tool,
create_delete_jira_issue_tool,
create_update_jira_issue_tool,
)
from .linear import (
create_create_linear_issue_tool,
create_delete_linear_issue_tool,
create_update_linear_issue_tool,
from .connected_accounts import create_get_connected_accounts_tool
from .luma import (
create_create_luma_event_tool,
create_list_luma_events_tool,
create_read_luma_event_tool,
)
from .mcp_tool import load_mcp_tools
from .notion import (
@ -95,6 +99,11 @@ from .report import create_generate_report_tool
from .resume import create_generate_resume_tool
from .scrape_webpage import create_scrape_webpage_tool
from .search_surfsense_docs import create_search_surfsense_docs_tool
from .teams import (
create_list_teams_channels_tool,
create_read_teams_messages_tool,
create_send_teams_message_tool,
)
from .update_memory import create_update_memory_tool, create_update_team_memory_tool
from .video_presentation import create_generate_video_presentation_tool
from .web_search import create_web_search_tool
@ -114,6 +123,8 @@ class ToolDefinition:
factory: Callable that creates the tool. Receives a dict of dependencies.
requires: List of dependency names this tool needs (e.g., "search_space_id", "db_session")
enabled_by_default: Whether the tool is enabled when no explicit config is provided
required_connector: Searchable type string (e.g. ``"LINEAR_CONNECTOR"``)
that must be in ``available_connectors`` for the tool to be enabled.
"""
@ -123,6 +134,7 @@ class ToolDefinition:
requires: list[str] = field(default_factory=list)
enabled_by_default: bool = True
hidden: bool = False
required_connector: str | None = None
# =============================================================================
@ -221,6 +233,21 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
requires=["db_session"],
),
# =========================================================================
# SERVICE ACCOUNT DISCOVERY
# Generic tool for the LLM to discover connected accounts and resolve
# service-specific identifiers (e.g. Jira cloudId, Slack team, etc.)
# =========================================================================
ToolDefinition(
name="get_connected_accounts",
description="Discover connected accounts for a service and their metadata",
factory=lambda deps: create_get_connected_accounts_tool(
db_session=deps["db_session"],
search_space_id=deps["search_space_id"],
user_id=deps["user_id"],
),
requires=["db_session", "search_space_id", "user_id"],
),
# =========================================================================
# MEMORY TOOL - single update_memory, private or team by thread_visibility
# =========================================================================
ToolDefinition(
@ -248,40 +275,6 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
],
),
# =========================================================================
# LINEAR TOOLS - create, update, delete issues
# Auto-disabled when no Linear connector is configured (see chat_deepagent.py)
# =========================================================================
ToolDefinition(
name="create_linear_issue",
description="Create a new issue in the user's Linear workspace",
factory=lambda deps: create_create_linear_issue_tool(
db_session=deps["db_session"],
search_space_id=deps["search_space_id"],
user_id=deps["user_id"],
),
requires=["db_session", "search_space_id", "user_id"],
),
ToolDefinition(
name="update_linear_issue",
description="Update an existing indexed Linear issue",
factory=lambda deps: create_update_linear_issue_tool(
db_session=deps["db_session"],
search_space_id=deps["search_space_id"],
user_id=deps["user_id"],
),
requires=["db_session", "search_space_id", "user_id"],
),
ToolDefinition(
name="delete_linear_issue",
description="Archive (delete) an existing indexed Linear issue",
factory=lambda deps: create_delete_linear_issue_tool(
db_session=deps["db_session"],
search_space_id=deps["search_space_id"],
user_id=deps["user_id"],
),
requires=["db_session", "search_space_id", "user_id"],
),
# =========================================================================
# NOTION TOOLS - create, update, delete pages
# Auto-disabled when no Notion connector is configured (see chat_deepagent.py)
# =========================================================================
@ -294,6 +287,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
user_id=deps["user_id"],
),
requires=["db_session", "search_space_id", "user_id"],
required_connector="NOTION_CONNECTOR",
),
ToolDefinition(
name="update_notion_page",
@ -304,6 +298,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
user_id=deps["user_id"],
),
requires=["db_session", "search_space_id", "user_id"],
required_connector="NOTION_CONNECTOR",
),
ToolDefinition(
name="delete_notion_page",
@ -314,6 +309,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
user_id=deps["user_id"],
),
requires=["db_session", "search_space_id", "user_id"],
required_connector="NOTION_CONNECTOR",
),
# =========================================================================
# GOOGLE DRIVE TOOLS - create files, delete files
@ -328,6 +324,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
user_id=deps["user_id"],
),
requires=["db_session", "search_space_id", "user_id"],
required_connector="GOOGLE_DRIVE_FILE",
),
ToolDefinition(
name="delete_google_drive_file",
@ -338,6 +335,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
user_id=deps["user_id"],
),
requires=["db_session", "search_space_id", "user_id"],
required_connector="GOOGLE_DRIVE_FILE",
),
# =========================================================================
# DROPBOX TOOLS - create and trash files
@ -352,6 +350,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
user_id=deps["user_id"],
),
requires=["db_session", "search_space_id", "user_id"],
required_connector="DROPBOX_FILE",
),
ToolDefinition(
name="delete_dropbox_file",
@ -362,6 +361,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
user_id=deps["user_id"],
),
requires=["db_session", "search_space_id", "user_id"],
required_connector="DROPBOX_FILE",
),
# =========================================================================
# ONEDRIVE TOOLS - create and trash files
@ -376,6 +376,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
user_id=deps["user_id"],
),
requires=["db_session", "search_space_id", "user_id"],
required_connector="ONEDRIVE_FILE",
),
ToolDefinition(
name="delete_onedrive_file",
@ -386,11 +387,23 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
user_id=deps["user_id"],
),
requires=["db_session", "search_space_id", "user_id"],
required_connector="ONEDRIVE_FILE",
),
# =========================================================================
# GOOGLE CALENDAR TOOLS - create, update, delete events
# GOOGLE CALENDAR TOOLS - search, create, update, delete events
# Auto-disabled when no Google Calendar connector is configured
# =========================================================================
ToolDefinition(
name="search_calendar_events",
description="Search Google Calendar events within a date range",
factory=lambda deps: create_search_calendar_events_tool(
db_session=deps["db_session"],
search_space_id=deps["search_space_id"],
user_id=deps["user_id"],
),
requires=["db_session", "search_space_id", "user_id"],
required_connector="GOOGLE_CALENDAR_CONNECTOR",
),
ToolDefinition(
name="create_calendar_event",
description="Create a new event on Google Calendar",
@ -400,6 +413,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
user_id=deps["user_id"],
),
requires=["db_session", "search_space_id", "user_id"],
required_connector="GOOGLE_CALENDAR_CONNECTOR",
),
ToolDefinition(
name="update_calendar_event",
@ -410,6 +424,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
user_id=deps["user_id"],
),
requires=["db_session", "search_space_id", "user_id"],
required_connector="GOOGLE_CALENDAR_CONNECTOR",
),
ToolDefinition(
name="delete_calendar_event",
@ -420,11 +435,34 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
user_id=deps["user_id"],
),
requires=["db_session", "search_space_id", "user_id"],
required_connector="GOOGLE_CALENDAR_CONNECTOR",
),
# =========================================================================
# GMAIL TOOLS - create drafts, update drafts, send emails, trash emails
# GMAIL TOOLS - search, read, create drafts, update drafts, send, trash
# Auto-disabled when no Gmail connector is configured
# =========================================================================
ToolDefinition(
name="search_gmail",
description="Search emails in Gmail using Gmail search syntax",
factory=lambda deps: create_search_gmail_tool(
db_session=deps["db_session"],
search_space_id=deps["search_space_id"],
user_id=deps["user_id"],
),
requires=["db_session", "search_space_id", "user_id"],
required_connector="GOOGLE_GMAIL_CONNECTOR",
),
ToolDefinition(
name="read_gmail_email",
description="Read the full content of a specific Gmail email",
factory=lambda deps: create_read_gmail_email_tool(
db_session=deps["db_session"],
search_space_id=deps["search_space_id"],
user_id=deps["user_id"],
),
requires=["db_session", "search_space_id", "user_id"],
required_connector="GOOGLE_GMAIL_CONNECTOR",
),
ToolDefinition(
name="create_gmail_draft",
description="Create a draft email in Gmail",
@ -434,6 +472,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
user_id=deps["user_id"],
),
requires=["db_session", "search_space_id", "user_id"],
required_connector="GOOGLE_GMAIL_CONNECTOR",
),
ToolDefinition(
name="send_gmail_email",
@ -444,6 +483,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
user_id=deps["user_id"],
),
requires=["db_session", "search_space_id", "user_id"],
required_connector="GOOGLE_GMAIL_CONNECTOR",
),
ToolDefinition(
name="trash_gmail_email",
@ -454,6 +494,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
user_id=deps["user_id"],
),
requires=["db_session", "search_space_id", "user_id"],
required_connector="GOOGLE_GMAIL_CONNECTOR",
),
ToolDefinition(
name="update_gmail_draft",
@ -464,40 +505,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
user_id=deps["user_id"],
),
requires=["db_session", "search_space_id", "user_id"],
),
# =========================================================================
# JIRA TOOLS - create, update, delete issues
# Auto-disabled when no Jira connector is configured (see chat_deepagent.py)
# =========================================================================
ToolDefinition(
name="create_jira_issue",
description="Create a new issue in the user's Jira project",
factory=lambda deps: create_create_jira_issue_tool(
db_session=deps["db_session"],
search_space_id=deps["search_space_id"],
user_id=deps["user_id"],
),
requires=["db_session", "search_space_id", "user_id"],
),
ToolDefinition(
name="update_jira_issue",
description="Update an existing indexed Jira issue",
factory=lambda deps: create_update_jira_issue_tool(
db_session=deps["db_session"],
search_space_id=deps["search_space_id"],
user_id=deps["user_id"],
),
requires=["db_session", "search_space_id", "user_id"],
),
ToolDefinition(
name="delete_jira_issue",
description="Delete an existing indexed Jira issue",
factory=lambda deps: create_delete_jira_issue_tool(
db_session=deps["db_session"],
search_space_id=deps["search_space_id"],
user_id=deps["user_id"],
),
requires=["db_session", "search_space_id", "user_id"],
required_connector="GOOGLE_GMAIL_CONNECTOR",
),
# =========================================================================
# CONFLUENCE TOOLS - create, update, delete pages
@ -512,6 +520,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
user_id=deps["user_id"],
),
requires=["db_session", "search_space_id", "user_id"],
required_connector="CONFLUENCE_CONNECTOR",
),
ToolDefinition(
name="update_confluence_page",
@ -522,6 +531,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
user_id=deps["user_id"],
),
requires=["db_session", "search_space_id", "user_id"],
required_connector="CONFLUENCE_CONNECTOR",
),
ToolDefinition(
name="delete_confluence_page",
@ -532,6 +542,118 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
user_id=deps["user_id"],
),
requires=["db_session", "search_space_id", "user_id"],
required_connector="CONFLUENCE_CONNECTOR",
),
# =========================================================================
# DISCORD TOOLS - list channels, read messages, send messages
# Auto-disabled when no Discord connector is configured
# =========================================================================
ToolDefinition(
name="list_discord_channels",
description="List text channels in the connected Discord server",
factory=lambda deps: create_list_discord_channels_tool(
db_session=deps["db_session"],
search_space_id=deps["search_space_id"],
user_id=deps["user_id"],
),
requires=["db_session", "search_space_id", "user_id"],
required_connector="DISCORD_CONNECTOR",
),
ToolDefinition(
name="read_discord_messages",
description="Read recent messages from a Discord text channel",
factory=lambda deps: create_read_discord_messages_tool(
db_session=deps["db_session"],
search_space_id=deps["search_space_id"],
user_id=deps["user_id"],
),
requires=["db_session", "search_space_id", "user_id"],
required_connector="DISCORD_CONNECTOR",
),
ToolDefinition(
name="send_discord_message",
description="Send a message to a Discord text channel",
factory=lambda deps: create_send_discord_message_tool(
db_session=deps["db_session"],
search_space_id=deps["search_space_id"],
user_id=deps["user_id"],
),
requires=["db_session", "search_space_id", "user_id"],
required_connector="DISCORD_CONNECTOR",
),
# =========================================================================
# TEAMS TOOLS - list channels, read messages, send messages
# Auto-disabled when no Teams connector is configured
# =========================================================================
ToolDefinition(
name="list_teams_channels",
description="List Microsoft Teams and their channels",
factory=lambda deps: create_list_teams_channels_tool(
db_session=deps["db_session"],
search_space_id=deps["search_space_id"],
user_id=deps["user_id"],
),
requires=["db_session", "search_space_id", "user_id"],
required_connector="TEAMS_CONNECTOR",
),
ToolDefinition(
name="read_teams_messages",
description="Read recent messages from a Microsoft Teams channel",
factory=lambda deps: create_read_teams_messages_tool(
db_session=deps["db_session"],
search_space_id=deps["search_space_id"],
user_id=deps["user_id"],
),
requires=["db_session", "search_space_id", "user_id"],
required_connector="TEAMS_CONNECTOR",
),
ToolDefinition(
name="send_teams_message",
description="Send a message to a Microsoft Teams channel",
factory=lambda deps: create_send_teams_message_tool(
db_session=deps["db_session"],
search_space_id=deps["search_space_id"],
user_id=deps["user_id"],
),
requires=["db_session", "search_space_id", "user_id"],
required_connector="TEAMS_CONNECTOR",
),
# =========================================================================
# LUMA TOOLS - list events, read event details, create events
# Auto-disabled when no Luma connector is configured
# =========================================================================
ToolDefinition(
name="list_luma_events",
description="List upcoming and recent Luma events",
factory=lambda deps: create_list_luma_events_tool(
db_session=deps["db_session"],
search_space_id=deps["search_space_id"],
user_id=deps["user_id"],
),
requires=["db_session", "search_space_id", "user_id"],
required_connector="LUMA_CONNECTOR",
),
ToolDefinition(
name="read_luma_event",
description="Read detailed information about a specific Luma event",
factory=lambda deps: create_read_luma_event_tool(
db_session=deps["db_session"],
search_space_id=deps["search_space_id"],
user_id=deps["user_id"],
),
requires=["db_session", "search_space_id", "user_id"],
required_connector="LUMA_CONNECTOR",
),
ToolDefinition(
name="create_luma_event",
description="Create a new event on Luma",
factory=lambda deps: create_create_luma_event_tool(
db_session=deps["db_session"],
search_space_id=deps["search_space_id"],
user_id=deps["user_id"],
),
requires=["db_session", "search_space_id", "user_id"],
required_connector="LUMA_CONNECTOR",
),
]
@ -549,6 +671,22 @@ def get_tool_by_name(name: str) -> ToolDefinition | None:
return None
def get_connector_gated_tools(
    available_connectors: list[str] | None,
) -> list[str]:
    """Return names of builtin tools whose required connector is unavailable.

    Args:
        available_connectors: Connector-type strings (e.g. ``"NOTION_CONNECTOR"``)
            configured for the current search space. ``None`` is treated the
            same as an empty list: no connectors available.

    Returns:
        Names of every ``BUILTIN_TOOLS`` entry that declares a
        ``required_connector`` not present in ``available_connectors``;
        callers disable these tools for the session.
    """
    available = set(available_connectors or ())
    return [
        tool_def.name
        for tool_def in BUILTIN_TOOLS
        if tool_def.required_connector and tool_def.required_connector not in available
    ]
def get_all_tool_names() -> list[str]:
    """Names of every registered builtin tool, in registration order."""
    return [definition.name for definition in BUILTIN_TOOLS]
@ -690,15 +828,15 @@ async def build_tools_async(
)
tools.extend(mcp_tools)
logging.info(
f"Registered {len(mcp_tools)} MCP tools: {[t.name for t in mcp_tools]}",
"Registered %d MCP tools: %s",
len(mcp_tools), [t.name for t in mcp_tools],
)
except Exception as e:
# Log error but don't fail - just continue without MCP tools
logging.exception(f"Failed to load MCP tools: {e!s}")
logging.exception("Failed to load MCP tools: %s", e)
# Log all tools being returned to agent
logging.info(
f"Total tools for agent: {len(tools)} - {[t.name for t in tools]}",
"Total tools for agent: %d%s",
len(tools), [t.name for t in tools],
)
return tools

View file

@ -0,0 +1,15 @@
from app.agents.new_chat.tools.teams.list_channels import (
create_list_teams_channels_tool,
)
from app.agents.new_chat.tools.teams.read_messages import (
create_read_teams_messages_tool,
)
from app.agents.new_chat.tools.teams.send_message import (
create_send_teams_message_tool,
)
__all__ = [
"create_list_teams_channels_tool",
"create_read_teams_messages_tool",
"create_send_teams_message_tool",
]

View file

@ -0,0 +1,37 @@
"""Shared auth helper for Teams agent tools (Microsoft Graph REST API)."""
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
GRAPH_API = "https://graph.microsoft.com/v1.0"
async def get_teams_connector(
    db_session: AsyncSession,
    search_space_id: int,
    user_id: str,
) -> SearchSourceConnector | None:
    """Look up the user's Microsoft Teams connector for a search space.

    Returns the first matching ``SearchSourceConnector`` row, or ``None``
    when no Teams connector is configured for this user/search space.
    """
    stmt = select(SearchSourceConnector).filter(
        SearchSourceConnector.search_space_id == search_space_id,
        SearchSourceConnector.user_id == user_id,
        SearchSourceConnector.connector_type
        == SearchSourceConnectorType.TEAMS_CONNECTOR,
    )
    query_result = await db_session.execute(stmt)
    return query_result.scalars().first()
async def get_access_token(
    db_session: AsyncSession,
    connector: SearchSourceConnector,
) -> str:
    """Get a valid Microsoft Graph access token, refreshing if expired.

    Delegates token storage and the refresh flow to ``TeamsConnector``
    (imported locally, presumably to avoid an import cycle — confirm).
    """
    from app.connectors.teams_connector import TeamsConnector

    teams_connector = TeamsConnector(session=db_session, connector_id=connector.id)
    # NOTE(review): reaches into the connector's private _get_valid_token();
    # a public accessor on TeamsConnector would be cleaner.
    return await teams_connector._get_valid_token()

View file

@ -0,0 +1,77 @@
import logging
from typing import Any
import httpx
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from ._auth import GRAPH_API, get_access_token, get_teams_connector
logger = logging.getLogger(__name__)
def create_list_teams_channels_tool(
    db_session: AsyncSession | None = None,
    search_space_id: int | None = None,
    user_id: str | None = None,
):
    """Build the ``list_teams_channels`` agent tool bound to the given deps."""

    @tool
    async def list_teams_channels() -> dict[str, Any]:
        """List all Microsoft Teams and their channels the user has access to.

        Returns:
            Dictionary with status and a list of teams, each containing
            team_id, team_name, and a list of channels (id, name).
        """
        if db_session is None or search_space_id is None or user_id is None:
            return {"status": "error", "message": "Teams tool not properly configured."}
        try:
            connector = await get_teams_connector(db_session, search_space_id, user_id)
            if not connector:
                return {"status": "error", "message": "No Teams connector found."}

            token = await get_access_token(db_session, connector)
            headers = {"Authorization": f"Bearer {token}"}

            # Use a single HTTP client (one connection pool) for the team
            # listing AND the per-team channel requests, instead of opening
            # a second client mid-way through.
            async with httpx.AsyncClient(timeout=20.0) as client:
                teams_resp = await client.get(f"{GRAPH_API}/me/joinedTeams", headers=headers)
                if teams_resp.status_code == 401:
                    return {"status": "auth_error", "message": "Teams token expired. Please re-authenticate.", "connector_type": "teams"}
                if teams_resp.status_code != 200:
                    return {"status": "error", "message": f"Graph API error: {teams_resp.status_code}"}
                teams_data = teams_resp.json().get("value", [])

                result_teams = []
                for team in teams_data:
                    team_id = team["id"]
                    ch_resp = await client.get(
                        f"{GRAPH_API}/teams/{team_id}/channels",
                        headers=headers,
                    )
                    # A failed channel fetch degrades to an empty channel list
                    # for that team rather than failing the whole tool call;
                    # log it so the failure is not silent.
                    channels = []
                    if ch_resp.status_code == 200:
                        channels = [
                            {"id": ch["id"], "name": ch.get("displayName", "")}
                            for ch in ch_resp.json().get("value", [])
                        ]
                    else:
                        logger.warning(
                            "Failed to list channels for team %s: HTTP %d",
                            team_id, ch_resp.status_code,
                        )
                    result_teams.append({
                        "team_id": team_id,
                        "team_name": team.get("displayName", ""),
                        "channels": channels,
                    })

            return {"status": "success", "teams": result_teams, "total_teams": len(result_teams)}
        except Exception as e:
            # GraphInterrupt is LangGraph control flow (HITL), not an error.
            from langgraph.errors import GraphInterrupt

            if isinstance(e, GraphInterrupt):
                raise
            logger.error("Error listing Teams channels: %s", e, exc_info=True)
            return {"status": "error", "message": "Failed to list Teams channels."}

    return list_teams_channels

View file

@ -0,0 +1,91 @@
import logging
from typing import Any
import httpx
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from ._auth import GRAPH_API, get_access_token, get_teams_connector
logger = logging.getLogger(__name__)
def create_read_teams_messages_tool(
    db_session: AsyncSession | None = None,
    search_space_id: int | None = None,
    user_id: str | None = None,
):
    """Build the ``read_teams_messages`` agent tool bound to the given deps."""

    @tool
    async def read_teams_messages(
        team_id: str,
        channel_id: str,
        limit: int = 25,
    ) -> dict[str, Any]:
        """Read recent messages from a Microsoft Teams channel.

        Args:
            team_id: The team ID (from list_teams_channels).
            channel_id: The channel ID (from list_teams_channels).
            limit: Number of messages to fetch (default 25, max 50).

        Returns:
            Dictionary with status and a list of messages including
            id, sender, content, timestamp.
        """
        if db_session is None or search_space_id is None or user_id is None:
            return {"status": "error", "message": "Teams tool not properly configured."}

        # Clamp to [1, 50]: the previous code only capped the upper bound,
        # so a zero/negative limit would be sent as an invalid $top value.
        limit = max(1, min(limit, 50))
        try:
            connector = await get_teams_connector(db_session, search_space_id, user_id)
            if not connector:
                return {"status": "error", "message": "No Teams connector found."}

            token = await get_access_token(db_session, connector)
            async with httpx.AsyncClient(timeout=20.0) as client:
                resp = await client.get(
                    f"{GRAPH_API}/teams/{team_id}/channels/{channel_id}/messages",
                    headers={"Authorization": f"Bearer {token}"},
                    params={"$top": limit},
                )

            if resp.status_code == 401:
                return {"status": "auth_error", "message": "Teams token expired. Please re-authenticate.", "connector_type": "teams"}
            if resp.status_code == 403:
                return {"status": "error", "message": "Insufficient permissions to read this channel."}
            if resp.status_code != 200:
                return {"status": "error", "message": f"Graph API error: {resp.status_code}"}

            raw_msgs = resp.json().get("value", [])
            messages = []
            for m in raw_msgs:
                sender = m.get("from", {})
                # "from" (or its "user" entry) may be null for system/bot
                # messages — guard both before calling .get().
                user_info = (sender.get("user") or {}) if sender else {}
                body = m.get("body", {})
                messages.append({
                    "id": m.get("id"),
                    "sender": user_info.get("displayName", "Unknown"),
                    "content": body.get("content", ""),
                    "content_type": body.get("contentType", "text"),
                    "timestamp": m.get("createdDateTime", ""),
                })

            return {
                "status": "success",
                "team_id": team_id,
                "channel_id": channel_id,
                "messages": messages,
                "total": len(messages),
            }
        except Exception as e:
            # GraphInterrupt is LangGraph control flow (HITL), not an error.
            from langgraph.errors import GraphInterrupt

            if isinstance(e, GraphInterrupt):
                raise
            logger.error("Error reading Teams messages: %s", e, exc_info=True)
            return {"status": "error", "message": "Failed to read Teams messages."}

    return read_teams_messages

View file

@ -0,0 +1,101 @@
import logging
from typing import Any
import httpx
from langchain_core.tools import tool
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.new_chat.tools.hitl import request_approval
from ._auth import GRAPH_API, get_access_token, get_teams_connector
logger = logging.getLogger(__name__)
def create_send_teams_message_tool(
    db_session: AsyncSession | None = None,
    search_space_id: int | None = None,
    user_id: str | None = None,
):
    """Build the ``send_teams_message`` agent tool bound to the given deps."""

    @tool
    async def send_teams_message(
        team_id: str,
        channel_id: str,
        content: str,
    ) -> dict[str, Any]:
        """Send a message to a Microsoft Teams channel.

        Requires the ChannelMessage.Send OAuth scope. If the user gets a
        permission error, they may need to re-authenticate with updated scopes.

        Args:
            team_id: The team ID (from list_teams_channels).
            channel_id: The channel ID (from list_teams_channels).
            content: The message text (HTML supported).

        Returns:
            Dictionary with status, message_id on success.

        IMPORTANT:
            - If status is "rejected", the user explicitly declined. Do NOT retry.
        """
        if db_session is None or search_space_id is None or user_id is None:
            return {"status": "error", "message": "Teams tool not properly configured."}
        try:
            connector = await get_teams_connector(db_session, search_space_id, user_id)
            if not connector:
                return {"status": "error", "message": "No Teams connector found."}

            # Human-in-the-loop gate: the user may approve, edit, or decline.
            result = request_approval(
                action_type="teams_send_message",
                tool_name="send_teams_message",
                params={"team_id": team_id, "channel_id": channel_id, "content": content},
                context={"connector_id": connector.id},
            )
            if result.rejected:
                return {"status": "rejected", "message": "User declined. Message was not sent."}

            # Honor any edits the user made during approval.
            final_content = result.params.get("content", content)
            final_team = result.params.get("team_id", team_id)
            final_channel = result.params.get("channel_id", channel_id)

            token = await get_access_token(db_session, connector)
            async with httpx.AsyncClient(timeout=20.0) as client:
                resp = await client.post(
                    f"{GRAPH_API}/teams/{final_team}/channels/{final_channel}/messages",
                    headers={
                        "Authorization": f"Bearer {token}",
                        "Content-Type": "application/json",
                    },
                    json={"body": {"content": final_content}},
                )

            if resp.status_code == 401:
                return {"status": "auth_error", "message": "Teams token expired. Please re-authenticate.", "connector_type": "teams"}
            if resp.status_code == 403:
                return {
                    "status": "insufficient_permissions",
                    "message": "Missing ChannelMessage.Send permission. Please re-authenticate with updated scopes.",
                }
            if resp.status_code not in (200, 201):
                # Separate the status code from the body excerpt; previously
                # they were fused together (e.g. "Graph API error: 400{...}").
                return {"status": "error", "message": f"Graph API error: {resp.status_code}: {resp.text[:200]}"}

            msg_data = resp.json()
            return {
                "status": "success",
                "message_id": msg_data.get("id"),
                "message": "Message sent to Teams channel.",
            }
        except Exception as e:
            # GraphInterrupt is LangGraph control flow (HITL), not an error.
            from langgraph.errors import GraphInterrupt

            if isinstance(e, GraphInterrupt):
                raise
            logger.error("Error sending Teams message: %s", e, exc_info=True)
            return {"status": "error", "message": "Failed to send Teams message."}

    return send_teams_message

View file

@ -0,0 +1,41 @@
"""Standardised response dict factories for LangChain agent tools."""
from __future__ import annotations
from typing import Any
class ToolResponse:
    """Factory helpers for the standardised tool-result dictionaries.

    Every payload carries a ``status`` discriminator; callers may attach
    arbitrary extra fields via ``**data`` (later keys win on collision).
    """

    @staticmethod
    def success(message: str, **data: Any) -> dict[str, Any]:
        """Successful outcome with a human-readable message."""
        payload: dict[str, Any] = {"status": "success", "message": message}
        payload.update(data)
        return payload

    @staticmethod
    def error(error: str, **data: Any) -> dict[str, Any]:
        """Generic failure with an error description."""
        payload: dict[str, Any] = {"status": "error", "error": error}
        payload.update(data)
        return payload

    @staticmethod
    def auth_error(service: str, **data: Any) -> dict[str, Any]:
        """Credentials for ``service`` have expired or been revoked."""
        detail = (
            f"{service} authentication has expired or been revoked. "
            "Please re-connect the integration in Settings → Connectors."
        )
        payload: dict[str, Any] = {"status": "auth_error", "error": detail}
        payload.update(data)
        return payload

    @staticmethod
    def rejected(message: str = "Action was declined by the user.") -> dict[str, Any]:
        """The user declined a human-in-the-loop approval."""
        return {"status": "rejected", "message": message}

    @staticmethod
    def not_found(
        resource: str, identifier: str, **data: Any
    ) -> dict[str, Any]:
        """A named resource could not be located."""
        payload: dict[str, Any] = {
            "status": "not_found",
            "error": f"{resource} '{identifier}' was not found.",
        }
        payload.update(data)
        return payload

View file

@ -135,20 +135,12 @@ celery_app.conf.update(
# never block fast user-facing tasks (file uploads, podcasts, etc.)
task_routes={
# Connector indexing tasks → connectors queue
"index_slack_messages": {"queue": CONNECTORS_QUEUE},
"index_notion_pages": {"queue": CONNECTORS_QUEUE},
"index_github_repos": {"queue": CONNECTORS_QUEUE},
"index_linear_issues": {"queue": CONNECTORS_QUEUE},
"index_jira_issues": {"queue": CONNECTORS_QUEUE},
"index_confluence_pages": {"queue": CONNECTORS_QUEUE},
"index_clickup_tasks": {"queue": CONNECTORS_QUEUE},
"index_google_calendar_events": {"queue": CONNECTORS_QUEUE},
"index_airtable_records": {"queue": CONNECTORS_QUEUE},
"index_google_gmail_messages": {"queue": CONNECTORS_QUEUE},
"index_google_drive_files": {"queue": CONNECTORS_QUEUE},
"index_discord_messages": {"queue": CONNECTORS_QUEUE},
"index_teams_messages": {"queue": CONNECTORS_QUEUE},
"index_luma_events": {"queue": CONNECTORS_QUEUE},
"index_elasticsearch_documents": {"queue": CONNECTORS_QUEUE},
"index_crawled_urls": {"queue": CONNECTORS_QUEUE},
"index_bookstack_pages": {"queue": CONNECTORS_QUEUE},

View file

@ -0,0 +1,98 @@
"""Standard exception hierarchy for all connectors.

ConnectorError
├── ConnectorAuthError       (401/403, non-retryable)
├── ConnectorRateLimitError  (429, retryable, carries ``retry_after``)
├── ConnectorTimeoutError    (timeout/504, retryable)
└── ConnectorAPIError        (5xx or unexpected; retryable when >= 500)
"""
from __future__ import annotations
from typing import Any
class ConnectorError(Exception):
def __init__(
self,
message: str,
*,
service: str = "",
status_code: int | None = None,
response_body: Any = None,
) -> None:
super().__init__(message)
self.service = service
self.status_code = status_code
self.response_body = response_body
@property
def retryable(self) -> bool:
return False
class ConnectorAuthError(ConnectorError):
"""Token expired, revoked, insufficient scopes, or needs re-auth (401/403)."""
@property
def retryable(self) -> bool:
return False
class ConnectorRateLimitError(ConnectorError):
"""429 Too Many Requests."""
def __init__(
self,
message: str = "Rate limited",
*,
service: str = "",
retry_after: float | None = None,
status_code: int = 429,
response_body: Any = None,
) -> None:
super().__init__(
message,
service=service,
status_code=status_code,
response_body=response_body,
)
self.retry_after = retry_after
@property
def retryable(self) -> bool:
return True
class ConnectorTimeoutError(ConnectorError):
"""Request timeout or gateway timeout (504)."""
def __init__(
self,
message: str = "Request timed out",
*,
service: str = "",
status_code: int | None = None,
response_body: Any = None,
) -> None:
super().__init__(
message,
service=service,
status_code=status_code,
response_body=response_body,
)
@property
def retryable(self) -> bool:
return True
class ConnectorAPIError(ConnectorError):
"""Generic API error (5xx or unexpected status codes)."""
@property
def retryable(self) -> bool:
if self.status_code is not None:
return self.status_code >= 500
return False

View file

@ -30,6 +30,7 @@ from .jira_add_connector_route import router as jira_add_connector_router
from .linear_add_connector_route import router as linear_add_connector_router
from .logs_routes import router as logs_router
from .luma_add_connector_route import router as luma_add_connector_router
from .mcp_oauth_route import router as mcp_oauth_router
from .memory_routes import router as memory_router
from .model_list_routes import router as model_list_router
from .new_chat_routes import router as new_chat_router
@ -95,6 +96,7 @@ router.include_router(logs_router)
router.include_router(circleback_webhook_router) # Circleback meeting webhooks
router.include_router(surfsense_docs_router) # Surfsense documentation for citations
router.include_router(notifications_router) # Notifications with Zero sync
router.include_router(mcp_oauth_router) # MCP OAuth 2.1 for Linear, Jira, ClickUp, Slack, Airtable
router.include_router(composio_router) # Composio OAuth and toolkit management
router.include_router(public_chat_router) # Public chat sharing and cloning
router.include_router(incentive_tasks_router) # Incentive tasks for earning free pages

View file

@ -311,7 +311,7 @@ async def airtable_callback(
new_connector = SearchSourceConnector(
name=connector_name,
connector_type=SearchSourceConnectorType.AIRTABLE_CONNECTOR,
is_indexable=True,
is_indexable=False,
config=credentials_dict,
search_space_id=space_id,
user_id=user_id,

View file

@ -301,7 +301,7 @@ async def clickup_callback(
# Update existing connector
existing_connector.config = connector_config
existing_connector.name = "ClickUp Connector"
existing_connector.is_indexable = True
existing_connector.is_indexable = False
logger.info(
f"Updated existing ClickUp connector for user {user_id} in space {space_id}"
)
@ -310,7 +310,7 @@ async def clickup_callback(
new_connector = SearchSourceConnector(
name="ClickUp Connector",
connector_type=SearchSourceConnectorType.CLICKUP_CONNECTOR,
is_indexable=True,
is_indexable=False,
config=connector_config,
search_space_id=space_id,
user_id=user_id,

View file

@ -326,7 +326,7 @@ async def discord_callback(
new_connector = SearchSourceConnector(
name=connector_name,
connector_type=SearchSourceConnectorType.DISCORD_CONNECTOR,
is_indexable=True,
is_indexable=False,
config=connector_config,
search_space_id=space_id,
user_id=user_id,

View file

@ -340,7 +340,7 @@ async def calendar_callback(
config=creds_dict,
search_space_id=space_id,
user_id=user_id,
is_indexable=True,
is_indexable=False,
)
session.add(db_connector)
await session.commit()

View file

@ -371,7 +371,7 @@ async def gmail_callback(
config=creds_dict,
search_space_id=space_id,
user_id=user_id,
is_indexable=True,
is_indexable=False,
)
session.add(db_connector)
await session.commit()

View file

@ -386,7 +386,7 @@ async def jira_callback(
new_connector = SearchSourceConnector(
name=connector_name,
connector_type=SearchSourceConnectorType.JIRA_CONNECTOR,
is_indexable=True,
is_indexable=False,
config=connector_config,
search_space_id=space_id,
user_id=user_id,

View file

@ -399,7 +399,7 @@ async def linear_callback(
new_connector = SearchSourceConnector(
name=connector_name,
connector_type=SearchSourceConnectorType.LINEAR_CONNECTOR,
is_indexable=True,
is_indexable=False,
config=connector_config,
search_space_id=space_id,
user_id=user_id,

View file

@ -61,7 +61,7 @@ async def add_luma_connector(
if existing_connector:
# Update existing connector with new API key
existing_connector.config = {"api_key": request.api_key}
existing_connector.is_indexable = True
existing_connector.is_indexable = False
await session.commit()
await session.refresh(existing_connector)
@ -82,7 +82,7 @@ async def add_luma_connector(
config={"api_key": request.api_key},
search_space_id=request.space_id,
user_id=user.id,
is_indexable=True,
is_indexable=False,
)
session.add(db_connector)

View file

@ -0,0 +1,601 @@
"""Generic MCP OAuth 2.1 route for services with official MCP servers.
Handles the full flow: discovery → DCR → PKCE authorization → token exchange →
MCP_CONNECTOR creation. Currently supports Linear, Jira, ClickUp, Slack,
and Airtable.
"""
from __future__ import annotations
import logging
from datetime import UTC, datetime, timedelta
from typing import Any
from urllib.parse import urlencode
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import RedirectResponse
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm.attributes import flag_modified
from app.config import config
from app.db import (
SearchSourceConnector,
SearchSourceConnectorType,
User,
get_async_session,
)
from app.users import current_active_user
from app.utils.connector_naming import generate_unique_connector_name
from app.utils.oauth_security import OAuthStateManager, TokenEncryption, generate_pkce_pair
logger = logging.getLogger(__name__)
router = APIRouter()
async def _fetch_account_metadata(
    service_key: str, access_token: str, token_json: dict[str, Any],
) -> dict[str, Any]:
    """Fetch display-friendly account metadata after a successful token exchange.

    DCR services (Linear, Jira, ClickUp) issue MCP-scoped tokens that cannot
    call their standard REST/GraphQL APIs — metadata discovery for those
    happens at runtime through MCP tools instead.

    Pre-configured services (Slack, Airtable) use standard OAuth tokens that
    *can* call their APIs, so we extract metadata here.

    Failures are logged but never block connector creation.

    Args:
        service_key: Registry key of the MCP service (e.g. "slack").
        access_token: Raw (unencrypted) access token from the exchange.
        token_json: Full token-endpoint response body.

    Returns:
        Possibly-empty dict of display keys (team_id, display_name, ...).
    """
    from app.services.mcp_oauth.registry import MCP_SERVICES

    svc = MCP_SERVICES.get(service_key)
    # DCR-based services have no REST-callable token — nothing to fetch.
    if not svc or svc.supports_dcr:
        return {}

    import httpx

    meta: dict[str, Any] = {}
    try:
        if service_key == "slack":
            # Slack embeds workspace info directly in the token response.
            team_info = token_json.get("team", {})
            meta["team_id"] = team_info.get("id", "")
            # TODO: oauth.v2.user.access only returns team.id, not
            # team.name. To populate team_name, add "team:read" scope
            # and call GET /api/team.info here.
            meta["team_name"] = team_info.get("name", "")
            if meta["team_name"]:
                meta["display_name"] = meta["team_name"]
            elif meta["team_id"]:
                meta["display_name"] = f"Slack ({meta['team_id']})"
        elif service_key == "airtable":
            # Airtable requires a whoami round-trip for user identity.
            async with httpx.AsyncClient(timeout=15.0) as client:
                resp = await client.get(
                    "https://api.airtable.com/v0/meta/whoami",
                    headers={"Authorization": f"Bearer {access_token}"},
                )
                if resp.status_code == 200:
                    whoami = resp.json()
                    meta["user_id"] = whoami.get("id", "")
                    meta["user_email"] = whoami.get("email", "")
                    meta["display_name"] = whoami.get("email", "Airtable")
                else:
                    logger.warning(
                        "Airtable whoami API returned %d (non-blocking)", resp.status_code,
                    )
    except Exception:
        # Metadata is cosmetic; never fail the OAuth flow because of it.
        logger.warning(
            "Failed to fetch account metadata for %s (non-blocking)",
            service_key,
            exc_info=True,
        )
    return meta
# Module-level lazy singletons: both depend on SECRET_KEY, which may be
# missing at import time, so construction is deferred to first use.
_state_manager: OAuthStateManager | None = None
_token_encryption: TokenEncryption | None = None


def _get_state_manager() -> OAuthStateManager:
    """Return the shared OAuth state manager, building it on first use."""
    global _state_manager
    if _state_manager is not None:
        return _state_manager
    if not config.SECRET_KEY:
        raise HTTPException(status_code=500, detail="SECRET_KEY not configured.")
    _state_manager = OAuthStateManager(config.SECRET_KEY)
    return _state_manager


def _get_token_encryption() -> TokenEncryption:
    """Return the shared token encrypter, building it on first use."""
    global _token_encryption
    if _token_encryption is not None:
        return _token_encryption
    if not config.SECRET_KEY:
        raise HTTPException(status_code=500, detail="SECRET_KEY not configured.")
    _token_encryption = TokenEncryption(config.SECRET_KEY)
    return _token_encryption
def _build_redirect_uri(service: str) -> str:
    """Absolute backend callback URL for the given MCP service."""
    backend_root = (config.BACKEND_URL or "http://localhost:8000").rstrip("/")
    return f"{backend_root}/api/v1/auth/mcp/{service}/connector/callback"
def _frontend_redirect(
    space_id: int | None,
    *,
    success: bool = False,
    connector_id: int | None = None,
    error: str | None = None,
    service: str = "mcp",
) -> RedirectResponse:
    """Redirect back to the frontend connectors callback page, or the dashboard
    when there is no space to return to."""
    if space_id:
        callback_base = (
            f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback"
        )
        if success:
            query = f"success=true&connector={service}-mcp-connector"
            if connector_id:
                query += f"&connectorId={connector_id}"
            return RedirectResponse(url=f"{callback_base}?{query}")
        if error:
            return RedirectResponse(url=f"{callback_base}?error={error}")
    return RedirectResponse(url=f"{config.NEXT_FRONTEND_URL}/dashboard")
# ---------------------------------------------------------------------------
# /add — start MCP OAuth flow
# ---------------------------------------------------------------------------
@router.get("/auth/mcp/{service}/connector/add")
async def connect_mcp_service(
    service: str,
    space_id: int,
    user: User = Depends(current_active_user),
):
    """Start the MCP OAuth 2.1 flow for ``service`` and return the auth URL.

    Flow: discover the server's OAuth metadata, obtain a client via Dynamic
    Client Registration (or fall back to statically configured credentials),
    generate a PKCE pair, and pack everything the callback needs into a
    signed ``state`` parameter.

    Raises:
        HTTPException: 404 unknown service, 500 missing configuration,
            502 incomplete server metadata or no usable client credentials.
    """
    from app.services.mcp_oauth.registry import get_service

    svc = get_service(service)
    if not svc:
        raise HTTPException(status_code=404, detail=f"Unknown MCP service: {service}")
    try:
        from app.services.mcp_oauth.discovery import (
            discover_oauth_metadata,
            register_client,
        )

        # Some providers serve OAuth metadata from a different origin than
        # the MCP endpoint itself, hence the override.
        metadata = await discover_oauth_metadata(
            svc.mcp_url, origin_override=svc.oauth_discovery_origin,
        )
        auth_endpoint = svc.auth_endpoint_override or metadata.get("authorization_endpoint")
        token_endpoint = svc.token_endpoint_override or metadata.get("token_endpoint")
        registration_endpoint = metadata.get("registration_endpoint")
        if not auth_endpoint or not token_endpoint:
            raise HTTPException(
                status_code=502,
                detail=f"{svc.name} MCP server returned incomplete OAuth metadata.",
            )
        redirect_uri = _build_redirect_uri(service)
        if svc.supports_dcr and registration_endpoint:
            # Dynamic Client Registration: obtain a client per flow.
            dcr = await register_client(registration_endpoint, redirect_uri)
            client_id = dcr.get("client_id")
            client_secret = dcr.get("client_secret", "")
            if not client_id:
                raise HTTPException(
                    status_code=502,
                    detail=f"DCR for {svc.name} did not return a client_id.",
                )
        elif svc.client_id_env:
            # Fallback: pre-configured OAuth app credentials from config.
            client_id = getattr(config, svc.client_id_env, None)
            client_secret = getattr(config, svc.client_secret_env or "", None) or ""
            if not client_id:
                raise HTTPException(
                    status_code=500,
                    detail=f"{svc.name} MCP OAuth not configured ({svc.client_id_env}).",
                )
        else:
            raise HTTPException(
                status_code=502,
                detail=f"{svc.name} MCP server has no DCR and no fallback credentials.",
            )
        verifier, challenge = generate_pkce_pair()
        enc = _get_token_encryption()
        # The callback gets everything through the signed state blob; the
        # client secret is encrypted before it leaves the process.
        state = _get_state_manager().generate_secure_state(
            space_id,
            user.id,
            service=service,
            code_verifier=verifier,
            mcp_client_id=client_id,
            mcp_client_secret=enc.encrypt_token(client_secret) if client_secret else "",
            mcp_token_endpoint=token_endpoint,
            mcp_url=svc.mcp_url,
        )
        auth_params: dict[str, str] = {
            "client_id": client_id,
            "response_type": "code",
            "redirect_uri": redirect_uri,
            "code_challenge": challenge,
            "code_challenge_method": "S256",
            "state": state,
        }
        if svc.scopes:
            auth_params[svc.scope_param] = " ".join(svc.scopes)
        auth_url = f"{auth_endpoint}?{urlencode(auth_params)}"
        logger.info(
            "Generated %s MCP OAuth URL for user %s, space %s",
            svc.name, user.id, space_id,
        )
        return {"auth_url": auth_url}
    except HTTPException:
        raise
    except Exception as e:
        logger.error("Failed to initiate %s MCP OAuth: %s", service, e, exc_info=True)
        raise HTTPException(
            status_code=500, detail=f"Failed to initiate {service} MCP OAuth.",
        ) from e
# ---------------------------------------------------------------------------
# /callback — handle OAuth redirect
# ---------------------------------------------------------------------------
@router.get("/auth/mcp/{service}/connector/callback")
async def mcp_oauth_callback(
    service: str,
    code: str | None = None,
    error: str | None = None,
    state: str | None = None,
    session: AsyncSession = Depends(get_async_session),
):
    """Complete the MCP OAuth flow after the provider redirects back.

    Validates the signed state, exchanges the authorization code for tokens
    (with PKCE), encrypts them into a connector config, then either updates
    an existing connector (re-auth) or creates a new non-indexable
    MCP connector and redirects the browser back to the frontend.

    Raises:
        HTTPException: 400 missing code/state or state mismatch or no token,
            404 unknown service / missing re-auth connector, 409 duplicate
            connector, 500 unexpected failure.
    """
    if error:
        # User denied consent (or provider error): bounce back with an
        # error code; best-effort recovery of space_id from the state.
        logger.warning("%s MCP OAuth error: %s", service, error)
        space_id = None
        if state:
            try:
                data = _get_state_manager().validate_state(state)
                space_id = data.get("space_id")
            except Exception:
                pass
        return _frontend_redirect(
            space_id, error=f"{service}_mcp_oauth_denied", service=service,
        )
    if not code:
        raise HTTPException(status_code=400, detail="Missing authorization code")
    if not state:
        raise HTTPException(status_code=400, detail="Missing state parameter")
    # State is HMAC-signed at /add time; validation rejects tampering.
    data = _get_state_manager().validate_state(state)
    user_id = UUID(data["user_id"])
    space_id = data["space_id"]
    svc_key = data.get("service", service)
    if svc_key != service:
        raise HTTPException(status_code=400, detail="State/path service mismatch")
    from app.services.mcp_oauth.registry import get_service

    svc = get_service(svc_key)
    if not svc:
        raise HTTPException(status_code=404, detail=f"Unknown MCP service: {svc_key}")
    try:
        from app.services.mcp_oauth.discovery import exchange_code_for_tokens

        enc = _get_token_encryption()
        client_id = data["mcp_client_id"]
        # The client secret was encrypted into the state at /add time.
        client_secret = (
            enc.decrypt_token(data["mcp_client_secret"])
            if data.get("mcp_client_secret")
            else ""
        )
        token_endpoint = data["mcp_token_endpoint"]
        code_verifier = data["code_verifier"]
        mcp_url = data["mcp_url"]
        redirect_uri = _build_redirect_uri(service)
        token_json = await exchange_code_for_tokens(
            token_endpoint=token_endpoint,
            code=code,
            redirect_uri=redirect_uri,
            client_id=client_id,
            client_secret=client_secret,
            code_verifier=code_verifier,
        )
        access_token = token_json.get("access_token")
        refresh_token = token_json.get("refresh_token")
        expires_in = token_json.get("expires_in")
        scope = token_json.get("scope")
        # Slack's user-scoped OAuth nests the token under "authed_user".
        if not access_token and "authed_user" in token_json:
            authed = token_json["authed_user"]
            access_token = authed.get("access_token")
            refresh_token = refresh_token or authed.get("refresh_token")
            scope = scope or authed.get("scope")
            expires_in = expires_in or authed.get("expires_in")
        if not access_token:
            raise HTTPException(
                status_code=400,
                detail=f"No access token received from {svc.name}.",
            )
        expires_at = None
        if expires_in:
            expires_at = datetime.now(UTC) + timedelta(
                seconds=int(expires_in)
            )
        # All token material is encrypted at rest inside the connector config.
        connector_config = {
            "server_config": {
                "transport": "streamable-http",
                "url": mcp_url,
            },
            "mcp_service": svc_key,
            "mcp_oauth": {
                "client_id": client_id,
                "client_secret": enc.encrypt_token(client_secret) if client_secret else "",
                "token_endpoint": token_endpoint,
                "access_token": enc.encrypt_token(access_token),
                "refresh_token": enc.encrypt_token(refresh_token) if refresh_token else None,
                "expires_at": expires_at.isoformat() if expires_at else None,
                "scope": scope,
            },
            "_token_encrypted": True,
        }
        account_meta = await _fetch_account_metadata(svc_key, access_token, token_json)
        if account_meta:
            # Copy only vetted display keys into the config — never token material.
            _SAFE_META_KEYS = {"display_name", "team_id", "team_name", "user_id", "user_email",
                               "workspace_id", "workspace_name", "organization_name",
                               "organization_url_key", "cloud_id", "site_name", "base_url"}
            for k, v in account_meta.items():
                if k in _SAFE_META_KEYS:
                    connector_config[k] = v
            logger.info(
                "Stored account metadata for %s: display_name=%s",
                svc_key, account_meta.get("display_name", ""),
            )
        # ---- Re-auth path: update the existing connector in place ----
        db_connector_type = SearchSourceConnectorType(svc.connector_type)
        reauth_connector_id = data.get("connector_id")
        if reauth_connector_id:
            result = await session.execute(
                select(SearchSourceConnector).filter(
                    SearchSourceConnector.id == reauth_connector_id,
                    SearchSourceConnector.user_id == user_id,
                    SearchSourceConnector.search_space_id == space_id,
                    SearchSourceConnector.connector_type == db_connector_type,
                )
            )
            db_connector = result.scalars().first()
            if not db_connector:
                raise HTTPException(
                    status_code=404,
                    detail="Connector not found during re-auth",
                )
            db_connector.config = connector_config
            flag_modified(db_connector, "config")  # JSON column: mark dirty for SQLAlchemy
            await session.commit()
            await session.refresh(db_connector)
            _invalidate_cache(space_id)
            logger.info(
                "Re-authenticated %s MCP connector %s for user %s",
                svc.name, db_connector.id, user_id,
            )
            reauth_return_url = data.get("return_url")
            # Open-redirect guard: honour only same-origin relative paths
            # ("/x"), never protocol-relative ("//host") URLs.
            if reauth_return_url and reauth_return_url.startswith("/") and not reauth_return_url.startswith("//"):
                return RedirectResponse(
                    url=f"{config.NEXT_FRONTEND_URL}{reauth_return_url}"
                )
            return _frontend_redirect(
                space_id, success=True, connector_id=db_connector.id, service=service,
            )
        # ---- New connector path ----
        naming_identifier = account_meta.get("display_name")
        connector_name = await generate_unique_connector_name(
            session,
            db_connector_type,
            space_id,
            user_id,
            naming_identifier,
        )
        new_connector = SearchSourceConnector(
            name=connector_name,
            connector_type=db_connector_type,
            is_indexable=False,  # MCP connectors expose live tools, not indexed content
            config=connector_config,
            search_space_id=space_id,
            user_id=user_id,
        )
        session.add(new_connector)
        try:
            await session.commit()
        except IntegrityError as e:
            await session.rollback()
            raise HTTPException(
                status_code=409, detail="A connector for this service already exists.",
            ) from e
        _invalidate_cache(space_id)
        logger.info(
            "Created %s MCP connector %s for user %s in space %s",
            svc.name, new_connector.id, user_id, space_id,
        )
        return _frontend_redirect(
            space_id, success=True, connector_id=new_connector.id, service=service,
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(
            "Failed to complete %s MCP OAuth: %s", service, e, exc_info=True,
        )
        raise HTTPException(
            status_code=500,
            detail=f"Failed to complete {service} MCP OAuth.",
        ) from e
# ---------------------------------------------------------------------------
# /reauth — re-authenticate an existing MCP connector
# ---------------------------------------------------------------------------
@router.get("/auth/mcp/{service}/connector/reauth")
async def reauth_mcp_service(
    service: str,
    space_id: int,
    connector_id: int,
    return_url: str | None = None,
    user: User = Depends(current_active_user),
    session: AsyncSession = Depends(get_async_session),
):
    """Re-run the MCP OAuth flow for an existing connector.

    Verifies the caller owns the connector, then repeats the discovery /
    DCR / PKCE steps from the /add route, additionally carrying the
    ``connector_id`` (and an optional relative ``return_url``) in the signed
    state so the callback updates the existing row instead of creating one.

    Raises:
        HTTPException: 404 unknown service or connector, 500 missing
            configuration, 502 incomplete server metadata or no credentials.
    """
    from app.services.mcp_oauth.registry import get_service

    svc = get_service(service)
    if not svc:
        raise HTTPException(status_code=404, detail=f"Unknown MCP service: {service}")
    # Ownership check before starting any OAuth round-trip.
    db_connector_type = SearchSourceConnectorType(svc.connector_type)
    result = await session.execute(
        select(SearchSourceConnector).filter(
            SearchSourceConnector.id == connector_id,
            SearchSourceConnector.user_id == user.id,
            SearchSourceConnector.search_space_id == space_id,
            SearchSourceConnector.connector_type == db_connector_type,
        )
    )
    if not result.scalars().first():
        raise HTTPException(
            status_code=404, detail="Connector not found or access denied",
        )
    try:
        from app.services.mcp_oauth.discovery import (
            discover_oauth_metadata,
            register_client,
        )

        metadata = await discover_oauth_metadata(
            svc.mcp_url, origin_override=svc.oauth_discovery_origin,
        )
        auth_endpoint = svc.auth_endpoint_override or metadata.get("authorization_endpoint")
        token_endpoint = svc.token_endpoint_override or metadata.get("token_endpoint")
        registration_endpoint = metadata.get("registration_endpoint")
        if not auth_endpoint or not token_endpoint:
            raise HTTPException(
                status_code=502,
                detail=f"{svc.name} MCP server returned incomplete OAuth metadata.",
            )
        redirect_uri = _build_redirect_uri(service)
        if svc.supports_dcr and registration_endpoint:
            # Dynamic Client Registration: fresh client for this re-auth flow.
            dcr = await register_client(registration_endpoint, redirect_uri)
            client_id = dcr.get("client_id")
            client_secret = dcr.get("client_secret", "")
            if not client_id:
                raise HTTPException(
                    status_code=502,
                    detail=f"DCR for {svc.name} did not return a client_id.",
                )
        elif svc.client_id_env:
            # Fallback: statically configured credentials from config.
            client_id = getattr(config, svc.client_id_env, None)
            client_secret = getattr(config, svc.client_secret_env or "", None) or ""
            if not client_id:
                raise HTTPException(
                    status_code=500,
                    detail=f"{svc.name} MCP OAuth not configured ({svc.client_id_env}).",
                )
        else:
            raise HTTPException(
                status_code=502,
                detail=f"{svc.name} MCP server has no DCR and no fallback credentials.",
            )
        verifier, challenge = generate_pkce_pair()
        enc = _get_token_encryption()
        extra: dict = {
            "service": service,
            "code_verifier": verifier,
            "mcp_client_id": client_id,
            "mcp_client_secret": enc.encrypt_token(client_secret) if client_secret else "",
            "mcp_token_endpoint": token_endpoint,
            "mcp_url": svc.mcp_url,
            # Presence of connector_id tells the callback to update, not create.
            "connector_id": connector_id,
        }
        # Only relative return URLs are trusted (open-redirect guard).
        if return_url and return_url.startswith("/"):
            extra["return_url"] = return_url
        state = _get_state_manager().generate_secure_state(
            space_id, user.id, **extra,
        )
        auth_params: dict[str, str] = {
            "client_id": client_id,
            "response_type": "code",
            "redirect_uri": redirect_uri,
            "code_challenge": challenge,
            "code_challenge_method": "S256",
            "state": state,
        }
        if svc.scopes:
            auth_params[svc.scope_param] = " ".join(svc.scopes)
        auth_url = f"{auth_endpoint}?{urlencode(auth_params)}"
        logger.info(
            "Initiating %s MCP re-auth for user %s, connector %s",
            svc.name, user.id, connector_id,
        )
        return {"auth_url": auth_url}
    except HTTPException:
        raise
    except Exception as e:
        logger.error(
            "Failed to initiate %s MCP re-auth: %s", service, e, exc_info=True,
        )
        raise HTTPException(
            status_code=500,
            detail=f"Failed to initiate {service} MCP re-auth.",
        ) from e
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _invalidate_cache(space_id: int) -> None:
    """Best-effort eviction of the cached MCP tool list for a search space.

    The import is deferred to call time (likely to avoid an import cycle
    with the agent tool module — confirm); any failure is logged at debug
    level and otherwise ignored.
    """
    try:
        from app.agents.new_chat.tools.mcp_tool import invalidate_mcp_tools_cache

        invalidate_mcp_tools_cache(space_id)
    except Exception:
        logger.debug("MCP cache invalidation skipped", exc_info=True)

View file

@ -0,0 +1,620 @@
"""Reusable base for OAuth 2.0 connector routes.
Subclasses override ``fetch_account_info``, ``build_connector_config``,
and ``get_connector_display_name`` to customise provider-specific behaviour.
Call ``build_router()`` to get a FastAPI ``APIRouter`` with ``/connector/add``,
``/connector/callback``, and ``/connector/reauth`` endpoints.
"""
from __future__ import annotations
import base64
import logging
from datetime import UTC, datetime, timedelta
from typing import Any
from urllib.parse import urlencode
from uuid import UUID
import httpx
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import RedirectResponse
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm.attributes import flag_modified
from app.config import config
from app.db import (
SearchSourceConnector,
SearchSourceConnectorType,
User,
get_async_session,
)
from app.users import current_active_user
from app.utils.connector_naming import (
check_duplicate_connector,
generate_unique_connector_name,
)
from app.utils.oauth_security import OAuthStateManager, TokenEncryption
logger = logging.getLogger(__name__)
class OAuthConnectorRoute:
def __init__(
    self,
    *,
    provider_name: str,
    connector_type: SearchSourceConnectorType,
    authorize_url: str,
    token_url: str,
    client_id_env: str,
    client_secret_env: str,
    redirect_uri_env: str,
    scopes: list[str],
    auth_prefix: str,
    use_pkce: bool = False,
    token_auth_method: str = "body",
    is_indexable: bool = True,
    extra_auth_params: dict[str, str] | None = None,
) -> None:
    """Configure one provider's OAuth connector routes.

    Args:
        provider_name: Lowercase provider slug used in redirects and names.
        connector_type: DB connector type created on successful auth.
        authorize_url: Provider's authorization endpoint.
        token_url: Provider's token endpoint.
        client_id_env: Name of the ``config`` attribute holding the client ID.
        client_secret_env: Name of the ``config`` attribute holding the secret.
        redirect_uri_env: Name of the ``config`` attribute holding the redirect URI.
        scopes: OAuth scopes to request.
        auth_prefix: URL prefix for the generated routes (trailing "/" stripped).
        use_pkce: Include a PKCE code verifier in the token exchange.
        token_auth_method: "body" (credentials in form body) or "basic"
            (HTTP Basic header) for the token exchange.
        is_indexable: Whether connectors created by this route are indexable.
        extra_auth_params: Additional static query params for the authorize URL.
    """
    self.provider_name = provider_name
    self.connector_type = connector_type
    self.authorize_url = authorize_url
    self.token_url = token_url
    self.client_id_env = client_id_env
    self.client_secret_env = client_secret_env
    self.redirect_uri_env = redirect_uri_env
    self.scopes = scopes
    self.auth_prefix = auth_prefix.rstrip("/")
    self.use_pkce = use_pkce
    self.token_auth_method = token_auth_method
    self.is_indexable = is_indexable
    self.extra_auth_params = extra_auth_params or {}
    # Lazy singletons — both require SECRET_KEY, so built on first use.
    self._state_manager: OAuthStateManager | None = None
    self._token_encryption: TokenEncryption | None = None
def _get_client_id(self) -> str:
    """Read the provider's OAuth client ID from config, or fail with HTTP 500."""
    client_id = getattr(config, self.client_id_env, None)
    if client_id:
        return client_id
    raise HTTPException(
        status_code=500,
        detail=f"{self.provider_name.title()} OAuth not configured "
        f"({self.client_id_env} missing).",
    )
def _get_client_secret(self) -> str:
    """Read the provider's OAuth client secret from config, or fail with HTTP 500."""
    secret = getattr(config, self.client_secret_env, None)
    if secret:
        return secret
    raise HTTPException(
        status_code=500,
        detail=f"{self.provider_name.title()} OAuth not configured "
        f"({self.client_secret_env} missing).",
    )
def _get_redirect_uri(self) -> str:
    """Read the provider's OAuth redirect URI from config, or fail with HTTP 500."""
    uri = getattr(config, self.redirect_uri_env, None)
    if uri:
        return uri
    raise HTTPException(
        status_code=500,
        detail=f"{self.redirect_uri_env} not configured.",
    )
def _get_state_manager(self) -> OAuthStateManager:
    """Lazily create and cache the signed-state manager.

    Raises:
        HTTPException: 500 when SECRET_KEY is not configured.
    """
    if self._state_manager is None:
        if not config.SECRET_KEY:
            raise HTTPException(
                status_code=500,
                detail="SECRET_KEY not configured for OAuth security.",
            )
        self._state_manager = OAuthStateManager(config.SECRET_KEY)
    return self._state_manager
def _get_token_encryption(self) -> TokenEncryption:
    """Lazily create and cache the token encrypter.

    Raises:
        HTTPException: 500 when SECRET_KEY is not configured.
    """
    if self._token_encryption is None:
        if not config.SECRET_KEY:
            raise HTTPException(
                status_code=500,
                detail="SECRET_KEY not configured for token encryption.",
            )
        self._token_encryption = TokenEncryption(config.SECRET_KEY)
    return self._token_encryption
def _frontend_redirect(
    self,
    space_id: int | None,
    *,
    success: bool = False,
    connector_id: int | None = None,
    error: str | None = None,
) -> RedirectResponse:
    """Send the browser back to the frontend after an OAuth round-trip."""
    if space_id:
        callback = (
            f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback"
        )
        if success:
            query = f"success=true&connector={self.provider_name}-connector"
            if connector_id:
                query += f"&connectorId={connector_id}"
            return RedirectResponse(url=f"{callback}?{query}")
        if error:
            return RedirectResponse(url=f"{callback}?error={error}")
    if error:
        return RedirectResponse(
            url=f"{config.NEXT_FRONTEND_URL}/dashboard?error={error}"
        )
    return RedirectResponse(url=f"{config.NEXT_FRONTEND_URL}/dashboard")
async def fetch_account_info(self, access_token: str) -> dict[str, Any]:
    """Override to fetch account/workspace info after token exchange.

    Return dict is merged into connector config; key ``"name"`` is used
    for the display name and dedup. The base implementation fetches nothing.
    """
    return {}
def build_connector_config(
    self,
    token_json: dict[str, Any],
    account_info: dict[str, Any],
    encryption: TokenEncryption,
) -> dict[str, Any]:
    """Override for custom config shapes. Default: standard encrypted OAuth fields."""
    access_token = token_json.get("access_token", "")
    refresh_token = token_json.get("refresh_token")
    expires_in = token_json.get("expires_in")

    # Compute the absolute expiry timestamp when the provider gave a TTL.
    expires_at = None
    if expires_in:
        expires_at = datetime.now(UTC) + timedelta(seconds=int(expires_in))

    encrypted_refresh = (
        encryption.encrypt_token(refresh_token) if refresh_token else None
    )
    cfg: dict[str, Any] = {
        "access_token": encryption.encrypt_token(access_token),
        "refresh_token": encrypted_refresh,
        "token_type": token_json.get("token_type", "Bearer"),
        "expires_in": expires_in,
        "expires_at": expires_at.isoformat() if expires_at else None,
        "scope": token_json.get("scope"),
        "_token_encrypted": True,
    }
    cfg.update(account_info)
    return cfg
def get_connector_display_name(self, account_info: dict[str, Any]) -> str:
    """Display name for a new connector; falls back to the provider title."""
    fallback = self.provider_name.title()
    return str(account_info.get("name", fallback))
async def on_token_refresh_failure(
    self,
    session: AsyncSession,
    connector: SearchSourceConnector,
) -> None:
    """Mark the connector as needing re-authentication after a failed refresh.

    Sets ``auth_expired`` in the connector config so downstream code can
    prompt for re-auth. Best-effort: persistence failures are logged only.
    """
    try:
        connector.config = {**connector.config, "auth_expired": True}
        flag_modified(connector, "config")  # JSON column: mark dirty for SQLAlchemy
        await session.commit()
        await session.refresh(connector)
    except Exception:
        logger.warning(
            "Failed to persist auth_expired flag for connector %s",
            connector.id,
            exc_info=True,
        )
async def _exchange_code(
    self, code: str, extra_state: dict[str, Any]
) -> dict[str, Any]:
    """Swap an authorization code for the provider's token response.

    Raises HTTPException(400) when the token endpoint rejects the request,
    preferring the provider's ``error_description`` when one is returned.
    """
    cid = self._get_client_id()
    secret = self._get_client_secret()

    request_headers: dict[str, str] = {
        "Content-Type": "application/x-www-form-urlencoded",
    }
    form: dict[str, str] = {
        "grant_type": "authorization_code",
        "code": code,
        "redirect_uri": self._get_redirect_uri(),
    }

    # Client credentials go in an HTTP Basic header or in the form body,
    # depending on the provider's supported token endpoint auth method.
    if self.token_auth_method == "basic":
        encoded = base64.b64encode(f"{cid}:{secret}".encode()).decode()
        request_headers["Authorization"] = f"Basic {encoded}"
    else:
        form["client_id"] = cid
        form["client_secret"] = secret

    # PKCE: echo back the verifier stashed during the authorize step.
    if self.use_pkce:
        pkce_verifier = extra_state.get("code_verifier")
        if pkce_verifier:
            form["code_verifier"] = pkce_verifier

    async with httpx.AsyncClient() as client:
        resp = await client.post(
            self.token_url, data=form, headers=request_headers, timeout=30.0
        )

    if resp.status_code != 200:
        message = resp.text
        try:
            message = resp.json().get("error_description", message)
        except Exception:
            pass  # non-JSON error body: fall back to raw text
        raise HTTPException(
            status_code=400, detail=f"Token exchange failed: {message}"
        )
    return resp.json()
async def refresh_token(
    self, session: AsyncSession, connector: SearchSourceConnector
) -> SearchSourceConnector:
    """Refresh the stored OAuth access token and persist the new credentials.

    Raises:
        HTTPException: 500 when the stored refresh token cannot be decrypted,
            400 when no refresh token exists or the provider rejects the
            request, 401 when the grant is invalid/expired/revoked (the
            connector is flagged via ``on_token_refresh_failure`` first).
    """
    encryption = self._get_token_encryption()
    # Configs written before encryption-at-rest was introduced may hold a
    # plaintext token; only decrypt when the marker flag says it is encrypted.
    is_encrypted = connector.config.get("_token_encrypted", False)
    refresh_tok = connector.config.get("refresh_token")
    if is_encrypted and refresh_tok:
        try:
            refresh_tok = encryption.decrypt_token(refresh_tok)
        except Exception as e:
            logger.error("Failed to decrypt refresh token: %s", e)
            raise HTTPException(
                status_code=500, detail="Failed to decrypt stored refresh token"
            ) from e
    if not refresh_tok:
        # Nothing to refresh with: flag the connector and force re-auth.
        await self.on_token_refresh_failure(session, connector)
        raise HTTPException(
            status_code=400,
            detail="No refresh token available. Please re-authenticate.",
        )
    client_id = self._get_client_id()
    client_secret = self._get_client_secret()
    headers: dict[str, str] = {
        "Content-Type": "application/x-www-form-urlencoded",
    }
    body: dict[str, str] = {
        "grant_type": "refresh_token",
        "refresh_token": refresh_tok,
    }
    # Credentials travel in a Basic auth header or the form body, depending
    # on the provider's token endpoint auth method.
    if self.token_auth_method == "basic":
        creds = base64.b64encode(f"{client_id}:{client_secret}".encode()).decode()
        headers["Authorization"] = f"Basic {creds}"
    else:
        body["client_id"] = client_id
        body["client_secret"] = client_secret
    async with httpx.AsyncClient() as client:
        resp = await client.post(
            self.token_url, data=body, headers=headers, timeout=30.0
        )
    if resp.status_code != 200:
        error_detail = resp.text
        try:
            ej = resp.json()
            error_detail = ej.get("error_description", error_detail)
            error_code = ej.get("error", "")
        except Exception:
            error_code = ""
        # Grant-level failures (invalid_grant / expired / revoked) mean the
        # user must re-authenticate (401); anything else is a 400.
        combined = (error_detail + error_code).lower()
        if any(kw in combined for kw in ("invalid_grant", "expired", "revoked")):
            await self.on_token_refresh_failure(session, connector)
            raise HTTPException(
                status_code=401,
                detail=f"{self.provider_name.title()} authentication failed. "
                "Please re-authenticate.",
            )
        raise HTTPException(
            status_code=400, detail=f"Token refresh failed: {error_detail}"
        )
    token_json = resp.json()
    new_access = token_json.get("access_token")
    if not new_access:
        raise HTTPException(
            status_code=400, detail="No access token received from refresh"
        )
    expires_at = None
    if token_json.get("expires_in"):
        expires_at = datetime.now(UTC) + timedelta(
            seconds=int(token_json["expires_in"])
        )
    # Persist the rotated credentials (encrypted) and clear any stale
    # auth_expired flag. Providers may or may not rotate the refresh token,
    # so the old one is kept unless a new one was returned.
    updated_config = dict(connector.config)
    updated_config["access_token"] = encryption.encrypt_token(new_access)
    new_refresh = token_json.get("refresh_token")
    if new_refresh:
        updated_config["refresh_token"] = encryption.encrypt_token(new_refresh)
    updated_config["expires_in"] = token_json.get("expires_in")
    updated_config["expires_at"] = expires_at.isoformat() if expires_at else None
    updated_config["scope"] = token_json.get("scope", updated_config.get("scope"))
    updated_config["_token_encrypted"] = True
    updated_config.pop("auth_expired", None)
    connector.config = updated_config
    # JSON column: mark dirty explicitly so the UPDATE is emitted.
    flag_modified(connector, "config")
    await session.commit()
    await session.refresh(connector)
    logger.info(
        "Refreshed %s token for connector %s",
        self.provider_name,
        connector.id,
    )
    return connector
def build_router(self) -> APIRouter:
    """Build the FastAPI router exposing this provider's OAuth routes.

    Routes (all under ``self.auth_prefix``):
        GET /connector/add      — start a new OAuth flow, returns the auth URL.
        GET /connector/reauth   — restart the flow for an existing connector.
        GET /connector/callback — provider redirect target; exchanges the code
                                  and creates/updates the connector.
    """
    router = APIRouter()
    # Capture self under a distinct name for the nested route closures.
    oauth = self

    @router.get(f"{oauth.auth_prefix}/connector/add")
    async def connect(
        space_id: int,
        user: User = Depends(current_active_user),
    ):
        # Build the provider authorize URL with a signed state blob that
        # carries space/user identity (and the PKCE verifier, if used).
        if not space_id:
            raise HTTPException(status_code=400, detail="space_id is required")
        client_id = oauth._get_client_id()
        state_mgr = oauth._get_state_manager()
        extra_state: dict[str, Any] = {}
        auth_params: dict[str, str] = {
            "client_id": client_id,
            "response_type": "code",
            "redirect_uri": oauth._get_redirect_uri(),
            "scope": " ".join(oauth.scopes),
        }
        if oauth.use_pkce:
            from app.utils.oauth_security import generate_pkce_pair

            # PKCE: the verifier is stashed in the state so the callback
            # can complete the token exchange.
            verifier, challenge = generate_pkce_pair()
            extra_state["code_verifier"] = verifier
            auth_params["code_challenge"] = challenge
            auth_params["code_challenge_method"] = "S256"
        auth_params.update(oauth.extra_auth_params)
        state_encoded = state_mgr.generate_secure_state(
            space_id, user.id, **extra_state
        )
        auth_params["state"] = state_encoded
        auth_url = f"{oauth.authorize_url}?{urlencode(auth_params)}"
        logger.info(
            "Generated %s OAuth URL for user %s, space %s",
            oauth.provider_name,
            user.id,
            space_id,
        )
        return {"auth_url": auth_url}

    @router.get(f"{oauth.auth_prefix}/connector/reauth")
    async def reauth(
        space_id: int,
        connector_id: int,
        return_url: str | None = None,
        user: User = Depends(current_active_user),
        session: AsyncSession = Depends(get_async_session),
    ):
        # Verify ownership of the connector before issuing a re-auth URL.
        result = await session.execute(
            select(SearchSourceConnector).filter(
                SearchSourceConnector.id == connector_id,
                SearchSourceConnector.user_id == user.id,
                SearchSourceConnector.search_space_id == space_id,
                SearchSourceConnector.connector_type == oauth.connector_type,
            )
        )
        if not result.scalars().first():
            raise HTTPException(
                status_code=404,
                detail=f"{oauth.provider_name.title()} connector not found "
                "or access denied",
            )
        client_id = oauth._get_client_id()
        state_mgr = oauth._get_state_manager()
        extra: dict[str, Any] = {"connector_id": connector_id}
        # Open-redirect guard: only relative paths, and reject "//host".
        if return_url and return_url.startswith("/") and not return_url.startswith("//"):
            extra["return_url"] = return_url
        auth_params: dict[str, str] = {
            "client_id": client_id,
            "response_type": "code",
            "redirect_uri": oauth._get_redirect_uri(),
            "scope": " ".join(oauth.scopes),
        }
        if oauth.use_pkce:
            from app.utils.oauth_security import generate_pkce_pair

            verifier, challenge = generate_pkce_pair()
            extra["code_verifier"] = verifier
            auth_params["code_challenge"] = challenge
            auth_params["code_challenge_method"] = "S256"
        auth_params.update(oauth.extra_auth_params)
        state_encoded = state_mgr.generate_secure_state(
            space_id, user.id, **extra
        )
        auth_params["state"] = state_encoded
        auth_url = f"{oauth.authorize_url}?{urlencode(auth_params)}"
        logger.info(
            "Initiating %s re-auth for user %s, connector %s",
            oauth.provider_name,
            user.id,
            connector_id,
        )
        return {"auth_url": auth_url}

    @router.get(f"{oauth.auth_prefix}/connector/callback")
    async def callback(
        code: str | None = None,
        error: str | None = None,
        state: str | None = None,
        session: AsyncSession = Depends(get_async_session),
    ):
        # Provider redirect target: validate state, exchange the code, then
        # either update an existing connector (re-auth) or create a new one.
        error_label = f"{oauth.provider_name}_oauth_denied"
        if error:
            logger.warning("%s OAuth error: %s", oauth.provider_name, error)
            space_id = None
            if state:
                try:
                    data = oauth._get_state_manager().validate_state(state)
                    space_id = data.get("space_id")
                except Exception:
                    # Best effort: redirect without a space context.
                    pass
            return oauth._frontend_redirect(space_id, error=error_label)
        if not code:
            raise HTTPException(
                status_code=400, detail="Missing authorization code"
            )
        if not state:
            raise HTTPException(
                status_code=400, detail="Missing state parameter"
            )
        state_mgr = oauth._get_state_manager()
        try:
            data = state_mgr.validate_state(state)
        except Exception as e:
            raise HTTPException(
                status_code=400, detail="Invalid or expired state parameter."
            ) from e
        user_id = UUID(data["user_id"])
        space_id = data["space_id"]
        token_json = await oauth._exchange_code(code, data)
        access_token = token_json.get("access_token", "")
        if not access_token:
            raise HTTPException(
                status_code=400,
                detail=f"No access token received from {oauth.provider_name.title()}",
            )
        account_info = await oauth.fetch_account_info(access_token)
        encryption = oauth._get_token_encryption()
        connector_config = oauth.build_connector_config(
            token_json, account_info, encryption
        )
        display_name = oauth.get_connector_display_name(account_info)
        # --- Re-auth path ---
        # A connector_id in the state means this flow refreshes an existing
        # connector's credentials rather than creating a new one.
        reauth_connector_id = data.get("connector_id")
        reauth_return_url = data.get("return_url")
        if reauth_connector_id:
            result = await session.execute(
                select(SearchSourceConnector).filter(
                    SearchSourceConnector.id == reauth_connector_id,
                    SearchSourceConnector.user_id == user_id,
                    SearchSourceConnector.search_space_id == space_id,
                    SearchSourceConnector.connector_type == oauth.connector_type,
                )
            )
            db_connector = result.scalars().first()
            if not db_connector:
                raise HTTPException(
                    status_code=404,
                    detail="Connector not found or access denied during re-auth",
                )
            db_connector.config = connector_config
            flag_modified(db_connector, "config")
            await session.commit()
            await session.refresh(db_connector)
            logger.info(
                "Re-authenticated %s connector %s for user %s",
                oauth.provider_name,
                db_connector.id,
                user_id,
            )
            # Same open-redirect guard as /reauth: relative paths only.
            if reauth_return_url and reauth_return_url.startswith("/") and not reauth_return_url.startswith("//"):
                return RedirectResponse(
                    url=f"{config.NEXT_FRONTEND_URL}{reauth_return_url}"
                )
            return oauth._frontend_redirect(
                space_id, success=True, connector_id=db_connector.id
            )
        # --- New connector path ---
        is_dup = await check_duplicate_connector(
            session,
            oauth.connector_type,
            space_id,
            user_id,
            display_name,
        )
        if is_dup:
            logger.warning(
                "Duplicate %s connector for user %s (%s)",
                oauth.provider_name,
                user_id,
                display_name,
            )
            return oauth._frontend_redirect(
                space_id,
                error=f"duplicate_account&connector={oauth.provider_name}-connector",
            )
        connector_name = await generate_unique_connector_name(
            session,
            oauth.connector_type,
            space_id,
            user_id,
            display_name,
        )
        new_connector = SearchSourceConnector(
            name=connector_name,
            connector_type=oauth.connector_type,
            is_indexable=oauth.is_indexable,
            config=connector_config,
            search_space_id=space_id,
            user_id=user_id,
        )
        session.add(new_connector)
        try:
            await session.commit()
        except IntegrityError as e:
            # Unique-constraint race: an identical connector was committed
            # between the duplicate check above and this commit.
            await session.rollback()
            raise HTTPException(
                status_code=409, detail="A connector for this service already exists."
            ) from e
        logger.info(
            "Created %s connector %s for user %s in space %s",
            oauth.provider_name,
            new_connector.id,
            user_id,
            space_id,
        )
        return oauth._frontend_redirect(
            space_id, success=True, connector_id=new_connector.id
        )

    return router

View file

@ -693,27 +693,10 @@ async def index_connector_content(
user: User = Depends(current_active_user),
):
"""
Index content from a connector to a search space.
Requires CONNECTORS_UPDATE permission (to trigger indexing).
Index content from a KB connector to a search space.
Currently supports:
- SLACK_CONNECTOR: Indexes messages from all accessible Slack channels
- TEAMS_CONNECTOR: Indexes messages from all accessible Microsoft Teams channels
- NOTION_CONNECTOR: Indexes pages from all accessible Notion pages
- GITHUB_CONNECTOR: Indexes code and documentation from GitHub repositories
- LINEAR_CONNECTOR: Indexes issues and comments from Linear
- JIRA_CONNECTOR: Indexes issues and comments from Jira
- DISCORD_CONNECTOR: Indexes messages from all accessible Discord channels
- LUMA_CONNECTOR: Indexes events from Luma
- ELASTICSEARCH_CONNECTOR: Indexes documents from Elasticsearch
- WEBCRAWLER_CONNECTOR: Indexes web pages from crawled websites
Args:
connector_id: ID of the connector to use
search_space_id: ID of the search space to store indexed content
Returns:
Dictionary with indexing status
Live connectors (Slack, Teams, Linear, Jira, ClickUp, Calendar, Airtable,
Gmail, Discord, Luma) use real-time agent tools instead.
"""
try:
# Get the connector first
@ -770,9 +753,7 @@ async def index_connector_content(
# For calendar connectors, default to today but allow future dates if explicitly provided
if connector.connector_type in [
SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR,
SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
SearchSourceConnectorType.LUMA_CONNECTOR,
]:
# Default to today if no end_date provided (users can manually select future dates)
indexing_to = today_str if end_date is None else end_date
@ -796,33 +777,22 @@ async def index_connector_content(
# For non-calendar connectors, cap at today
indexing_to = end_date if end_date else today_str
if connector.connector_type == SearchSourceConnectorType.SLACK_CONNECTOR:
from app.tasks.celery_tasks.connector_tasks import (
index_slack_messages_task,
)
from app.services.mcp_oauth.registry import LIVE_CONNECTOR_TYPES
logger.info(
f"Triggering Slack indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}"
)
index_slack_messages_task.delay(
connector_id, search_space_id, str(user.id), indexing_from, indexing_to
)
response_message = "Slack indexing started in the background."
if connector.connector_type in LIVE_CONNECTOR_TYPES:
return {
"message": (
f"{connector.connector_type.value} uses real-time agent tools; "
"background indexing is disabled."
),
"indexing_started": False,
"connector_id": connector_id,
"search_space_id": search_space_id,
"indexing_from": indexing_from,
"indexing_to": indexing_to,
}
elif connector.connector_type == SearchSourceConnectorType.TEAMS_CONNECTOR:
from app.tasks.celery_tasks.connector_tasks import (
index_teams_messages_task,
)
logger.info(
f"Triggering Teams indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}"
)
index_teams_messages_task.delay(
connector_id, search_space_id, str(user.id), indexing_from, indexing_to
)
response_message = "Teams indexing started in the background."
elif connector.connector_type == SearchSourceConnectorType.NOTION_CONNECTOR:
if connector.connector_type == SearchSourceConnectorType.NOTION_CONNECTOR:
from app.tasks.celery_tasks.connector_tasks import index_notion_pages_task
logger.info(
@ -844,28 +814,6 @@ async def index_connector_content(
)
response_message = "GitHub indexing started in the background."
elif connector.connector_type == SearchSourceConnectorType.LINEAR_CONNECTOR:
from app.tasks.celery_tasks.connector_tasks import index_linear_issues_task
logger.info(
f"Triggering Linear indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}"
)
index_linear_issues_task.delay(
connector_id, search_space_id, str(user.id), indexing_from, indexing_to
)
response_message = "Linear indexing started in the background."
elif connector.connector_type == SearchSourceConnectorType.JIRA_CONNECTOR:
from app.tasks.celery_tasks.connector_tasks import index_jira_issues_task
logger.info(
f"Triggering Jira indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}"
)
index_jira_issues_task.delay(
connector_id, search_space_id, str(user.id), indexing_from, indexing_to
)
response_message = "Jira indexing started in the background."
elif connector.connector_type == SearchSourceConnectorType.CONFLUENCE_CONNECTOR:
from app.tasks.celery_tasks.connector_tasks import (
index_confluence_pages_task,
@ -892,59 +840,6 @@ async def index_connector_content(
)
response_message = "BookStack indexing started in the background."
elif connector.connector_type == SearchSourceConnectorType.CLICKUP_CONNECTOR:
from app.tasks.celery_tasks.connector_tasks import index_clickup_tasks_task
logger.info(
f"Triggering ClickUp indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}"
)
index_clickup_tasks_task.delay(
connector_id, search_space_id, str(user.id), indexing_from, indexing_to
)
response_message = "ClickUp indexing started in the background."
elif (
connector.connector_type
== SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR
):
from app.tasks.celery_tasks.connector_tasks import (
index_google_calendar_events_task,
)
logger.info(
f"Triggering Google Calendar indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}"
)
index_google_calendar_events_task.delay(
connector_id, search_space_id, str(user.id), indexing_from, indexing_to
)
response_message = "Google Calendar indexing started in the background."
elif connector.connector_type == SearchSourceConnectorType.AIRTABLE_CONNECTOR:
from app.tasks.celery_tasks.connector_tasks import (
index_airtable_records_task,
)
logger.info(
f"Triggering Airtable indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}"
)
index_airtable_records_task.delay(
connector_id, search_space_id, str(user.id), indexing_from, indexing_to
)
response_message = "Airtable indexing started in the background."
elif (
connector.connector_type == SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR
):
from app.tasks.celery_tasks.connector_tasks import (
index_google_gmail_messages_task,
)
logger.info(
f"Triggering Google Gmail indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}"
)
index_google_gmail_messages_task.delay(
connector_id, search_space_id, str(user.id), indexing_from, indexing_to
)
response_message = "Google Gmail indexing started in the background."
elif (
connector.connector_type == SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR
):
@ -1089,30 +984,6 @@ async def index_connector_content(
)
response_message = "Dropbox indexing started in the background."
elif connector.connector_type == SearchSourceConnectorType.DISCORD_CONNECTOR:
from app.tasks.celery_tasks.connector_tasks import (
index_discord_messages_task,
)
logger.info(
f"Triggering Discord indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}"
)
index_discord_messages_task.delay(
connector_id, search_space_id, str(user.id), indexing_from, indexing_to
)
response_message = "Discord indexing started in the background."
elif connector.connector_type == SearchSourceConnectorType.LUMA_CONNECTOR:
from app.tasks.celery_tasks.connector_tasks import index_luma_events_task
logger.info(
f"Triggering Luma indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}"
)
index_luma_events_task.delay(
connector_id, search_space_id, str(user.id), indexing_from, indexing_to
)
response_message = "Luma indexing started in the background."
elif (
connector.connector_type
== SearchSourceConnectorType.ELASTICSEARCH_CONNECTOR
@ -1338,57 +1209,6 @@ async def _update_connector_timestamp_by_id(session: AsyncSession, connector_id:
await session.rollback()
async def run_slack_indexing_with_new_session(
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str,
):
"""
Create a new session and run the Slack indexing task.
This prevents session leaks by creating a dedicated session for the background task.
"""
async with async_session_maker() as session:
await run_slack_indexing(
session, connector_id, search_space_id, user_id, start_date, end_date
)
async def run_slack_indexing(
session: AsyncSession,
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str,
):
"""
Background task to run Slack indexing.
Args:
session: Database session
connector_id: ID of the Slack connector
search_space_id: ID of the search space
user_id: ID of the user
start_date: Start date for indexing
end_date: End date for indexing
"""
from app.tasks.connector_indexers import index_slack_messages
await _run_indexing_with_notifications(
session=session,
connector_id=connector_id,
search_space_id=search_space_id,
user_id=user_id,
start_date=start_date,
end_date=end_date,
indexing_function=index_slack_messages,
update_timestamp_func=_update_connector_timestamp_by_id,
supports_heartbeat_callback=True,
)
_AUTH_ERROR_PATTERNS = (
"failed to refresh linear oauth",
"failed to refresh your notion connection",
@ -1927,215 +1747,6 @@ async def run_github_indexing(
)
# Add new helper functions for Linear indexing
async def run_linear_indexing_with_new_session(
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str,
):
"""Wrapper to run Linear indexing with its own database session."""
logger.info(
f"Background task started: Indexing Linear connector {connector_id} into space {search_space_id} from {start_date} to {end_date}"
)
async with async_session_maker() as session:
await run_linear_indexing(
session, connector_id, search_space_id, user_id, start_date, end_date
)
logger.info(f"Background task finished: Indexing Linear connector {connector_id}")
async def run_linear_indexing(
session: AsyncSession,
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str,
):
"""
Background task to run Linear indexing.
Args:
session: Database session
connector_id: ID of the Linear connector
search_space_id: ID of the search space
user_id: ID of the user
start_date: Start date for indexing
end_date: End date for indexing
"""
from app.tasks.connector_indexers import index_linear_issues
await _run_indexing_with_notifications(
session=session,
connector_id=connector_id,
search_space_id=search_space_id,
user_id=user_id,
start_date=start_date,
end_date=end_date,
indexing_function=index_linear_issues,
update_timestamp_func=_update_connector_timestamp_by_id,
supports_heartbeat_callback=True,
)
# Add new helper functions for discord indexing
async def run_discord_indexing_with_new_session(
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str,
):
"""
Create a new session and run the Discord indexing task.
This prevents session leaks by creating a dedicated session for the background task.
"""
async with async_session_maker() as session:
await run_discord_indexing(
session, connector_id, search_space_id, user_id, start_date, end_date
)
async def run_discord_indexing(
session: AsyncSession,
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str,
):
"""
Background task to run Discord indexing.
Args:
session: Database session
connector_id: ID of the Discord connector
search_space_id: ID of the search space
user_id: ID of the user
start_date: Start date for indexing
end_date: End date for indexing
"""
from app.tasks.connector_indexers import index_discord_messages
await _run_indexing_with_notifications(
session=session,
connector_id=connector_id,
search_space_id=search_space_id,
user_id=user_id,
start_date=start_date,
end_date=end_date,
indexing_function=index_discord_messages,
update_timestamp_func=_update_connector_timestamp_by_id,
supports_heartbeat_callback=True,
)
async def run_teams_indexing_with_new_session(
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str,
):
"""
Create a new session and run the Microsoft Teams indexing task.
This prevents session leaks by creating a dedicated session for the background task.
"""
async with async_session_maker() as session:
await run_teams_indexing(
session, connector_id, search_space_id, user_id, start_date, end_date
)
async def run_teams_indexing(
session: AsyncSession,
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str,
):
"""
Background task to run Microsoft Teams indexing.
Args:
session: Database session
connector_id: ID of the Teams connector
search_space_id: ID of the search space
user_id: ID of the user
start_date: Start date for indexing
end_date: End date for indexing
"""
from app.tasks.connector_indexers.teams_indexer import index_teams_messages
await _run_indexing_with_notifications(
session=session,
connector_id=connector_id,
search_space_id=search_space_id,
user_id=user_id,
start_date=start_date,
end_date=end_date,
indexing_function=index_teams_messages,
update_timestamp_func=_update_connector_timestamp_by_id,
supports_heartbeat_callback=True,
)
# Add new helper functions for Jira indexing
async def run_jira_indexing_with_new_session(
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str,
):
"""Wrapper to run Jira indexing with its own database session."""
logger.info(
f"Background task started: Indexing Jira connector {connector_id} into space {search_space_id} from {start_date} to {end_date}"
)
async with async_session_maker() as session:
await run_jira_indexing(
session, connector_id, search_space_id, user_id, start_date, end_date
)
logger.info(f"Background task finished: Indexing Jira connector {connector_id}")
async def run_jira_indexing(
session: AsyncSession,
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str,
):
"""
Background task to run Jira indexing.
Args:
session: Database session
connector_id: ID of the Jira connector
search_space_id: ID of the search space
user_id: ID of the user
start_date: Start date for indexing
end_date: End date for indexing
"""
from app.tasks.connector_indexers import index_jira_issues
await _run_indexing_with_notifications(
session=session,
connector_id=connector_id,
search_space_id=search_space_id,
user_id=user_id,
start_date=start_date,
end_date=end_date,
indexing_function=index_jira_issues,
update_timestamp_func=_update_connector_timestamp_by_id,
supports_heartbeat_callback=True,
)
# Add new helper functions for Confluence indexing
async def run_confluence_indexing_with_new_session(
connector_id: int,
@ -2191,112 +1802,6 @@ async def run_confluence_indexing(
)
# Add new helper functions for ClickUp indexing
async def run_clickup_indexing_with_new_session(
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str,
):
"""Wrapper to run ClickUp indexing with its own database session."""
logger.info(
f"Background task started: Indexing ClickUp connector {connector_id} into space {search_space_id} from {start_date} to {end_date}"
)
async with async_session_maker() as session:
await run_clickup_indexing(
session, connector_id, search_space_id, user_id, start_date, end_date
)
logger.info(f"Background task finished: Indexing ClickUp connector {connector_id}")
async def run_clickup_indexing(
session: AsyncSession,
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str,
):
"""
Background task to run ClickUp indexing.
Args:
session: Database session
connector_id: ID of the ClickUp connector
search_space_id: ID of the search space
user_id: ID of the user
start_date: Start date for indexing
end_date: End date for indexing
"""
from app.tasks.connector_indexers import index_clickup_tasks
await _run_indexing_with_notifications(
session=session,
connector_id=connector_id,
search_space_id=search_space_id,
user_id=user_id,
start_date=start_date,
end_date=end_date,
indexing_function=index_clickup_tasks,
update_timestamp_func=_update_connector_timestamp_by_id,
supports_heartbeat_callback=True,
)
# Add new helper functions for Airtable indexing
async def run_airtable_indexing_with_new_session(
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str,
):
"""Wrapper to run Airtable indexing with its own database session."""
logger.info(
f"Background task started: Indexing Airtable connector {connector_id} into space {search_space_id} from {start_date} to {end_date}"
)
async with async_session_maker() as session:
await run_airtable_indexing(
session, connector_id, search_space_id, user_id, start_date, end_date
)
logger.info(f"Background task finished: Indexing Airtable connector {connector_id}")
async def run_airtable_indexing(
session: AsyncSession,
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str,
):
"""
Background task to run Airtable indexing.
Args:
session: Database session
connector_id: ID of the Airtable connector
search_space_id: ID of the search space
user_id: ID of the user
start_date: Start date for indexing
end_date: End date for indexing
"""
from app.tasks.connector_indexers import index_airtable_records
await _run_indexing_with_notifications(
session=session,
connector_id=connector_id,
search_space_id=search_space_id,
user_id=user_id,
start_date=start_date,
end_date=end_date,
indexing_function=index_airtable_records,
update_timestamp_func=_update_connector_timestamp_by_id,
supports_heartbeat_callback=True,
)
# Add new helper functions for Google Calendar indexing
async def run_google_calendar_indexing_with_new_session(
connector_id: int,
@ -2835,58 +2340,6 @@ async def run_dropbox_indexing(
logger.error(f"Failed to update notification: {notif_error!s}")
# Add new helper functions for luma indexing
async def run_luma_indexing_with_new_session(
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str,
):
"""
Create a new session and run the Luma indexing task.
This prevents session leaks by creating a dedicated session for the background task.
"""
async with async_session_maker() as session:
await run_luma_indexing(
session, connector_id, search_space_id, user_id, start_date, end_date
)
async def run_luma_indexing(
session: AsyncSession,
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str,
):
"""
Background task to run Luma indexing.
Args:
session: Database session
connector_id: ID of the Luma connector
search_space_id: ID of the search space
user_id: ID of the user
start_date: Start date for indexing
end_date: End date for indexing
"""
from app.tasks.connector_indexers import index_luma_events
await _run_indexing_with_notifications(
session=session,
connector_id=connector_id,
search_space_id=search_space_id,
user_id=user_id,
start_date=start_date,
end_date=end_date,
indexing_function=index_luma_events,
update_timestamp_func=_update_connector_timestamp_by_id,
supports_heartbeat_callback=True,
)
async def run_elasticsearch_indexing_with_new_session(
connector_id: int,
search_space_id: int,

View file

@ -312,7 +312,7 @@ async def slack_callback(
new_connector = SearchSourceConnector(
name=connector_name,
connector_type=SearchSourceConnectorType.SLACK_CONNECTOR,
is_indexable=True,
is_indexable=False,
config=connector_config,
search_space_id=space_id,
user_id=user_id,

View file

@ -45,6 +45,7 @@ SCOPES = [
"Team.ReadBasic.All", # Read basic team information
"Channel.ReadBasic.All", # Read basic channel information
"ChannelMessage.Read.All", # Read messages in channels
"ChannelMessage.Send", # Send messages in channels
]
# Initialize security utilities
@ -320,7 +321,7 @@ async def teams_callback(
new_connector = SearchSourceConnector(
name=connector_name,
connector_type=SearchSourceConnectorType.TEAMS_CONNECTOR,
is_indexable=True,
is_indexable=False,
config=connector_config,
search_space_id=space_id,
user_id=user_id,

View file

@ -26,7 +26,7 @@ COMPOSIO_TOOLKIT_NAMES = {
}
# Toolkits that support indexing (Phase 1: Google services only)
INDEXABLE_TOOLKITS = {"googledrive", "gmail", "googlecalendar"}
INDEXABLE_TOOLKITS = {"googledrive"}
# Mapping of toolkit IDs to connector types
TOOLKIT_TO_CONNECTOR_TYPE = {

View file

@ -290,6 +290,12 @@ class LLMRouterService:
instance._router = Router(**router_kwargs)
instance._initialized = True
global _cached_context_profile, _cached_context_profile_computed
_cached_context_profile = None
_cached_context_profile_computed = False
_router_instance_cache.clear()
logger.info(
"LLM Router initialized with %d deployments, "
"strategy: %s, context_window_fallbacks: %s, fallbacks: %s",

View file

@ -0,0 +1,121 @@
"""MCP OAuth 2.1 metadata discovery, Dynamic Client Registration, and token exchange."""
from __future__ import annotations
import base64
import logging
from urllib.parse import urlparse
import httpx
logger = logging.getLogger(__name__)
async def discover_oauth_metadata(
    mcp_url: str,
    *,
    origin_override: str | None = None,
    timeout: float = 15.0,
) -> dict:
    """Retrieve the OAuth 2.1 authorization-server metadata document.

    Per the MCP spec the discovery document lives at the *origin* of the
    MCP server URL. ``origin_override`` supports providers whose OAuth
    server lives on a different domain than the MCP endpoint (e.g.
    Airtable: MCP at ``mcp.airtable.com``, OAuth at ``airtable.com``).
    """
    if origin_override:
        base = origin_override.rstrip("/")
    else:
        pieces = urlparse(mcp_url)
        base = f"{pieces.scheme}://{pieces.netloc}"
    well_known_url = f"{base}/.well-known/oauth-authorization-server"
    async with httpx.AsyncClient(follow_redirects=True) as http:
        response = await http.get(well_known_url, timeout=timeout)
        response.raise_for_status()
        return response.json()
async def register_client(
    registration_endpoint: str,
    redirect_uri: str,
    *,
    client_name: str = "SurfSense",
    timeout: float = 15.0,
) -> dict:
    """Dynamically register an OAuth client per RFC 7591.

    Returns the registration response (``client_id``, ``client_secret``,
    and any other fields the authorization server issues).
    """
    registration_request = {
        "client_name": client_name,
        "redirect_uris": [redirect_uri],
        "grant_types": ["authorization_code", "refresh_token"],
        "response_types": ["code"],
        "token_endpoint_auth_method": "client_secret_basic",
    }
    async with httpx.AsyncClient(follow_redirects=True) as http:
        response = await http.post(
            registration_endpoint,
            json=registration_request,
            timeout=timeout,
        )
        response.raise_for_status()
        return response.json()
async def exchange_code_for_tokens(
    token_endpoint: str,
    code: str,
    redirect_uri: str,
    client_id: str,
    client_secret: str,
    code_verifier: str,
    *,
    timeout: float = 30.0,
) -> dict:
    """Exchange an authorization code for access + refresh tokens.

    Authenticates with HTTP Basic (``client_secret_basic``) and sends the
    PKCE ``code_verifier`` alongside the code, as required by OAuth 2.1.
    """
    basic_auth = base64.b64encode(f"{client_id}:{client_secret}".encode()).decode()
    form = {
        "grant_type": "authorization_code",
        "code": code,
        "redirect_uri": redirect_uri,
        "code_verifier": code_verifier,
    }
    request_headers = {
        "Content-Type": "application/x-www-form-urlencoded",
        "Authorization": f"Basic {basic_auth}",
    }
    async with httpx.AsyncClient(follow_redirects=True) as http:
        response = await http.post(
            token_endpoint,
            data=form,
            headers=request_headers,
            timeout=timeout,
        )
        response.raise_for_status()
        return response.json()
async def refresh_access_token(
    token_endpoint: str,
    refresh_token: str,
    client_id: str,
    client_secret: str,
    *,
    timeout: float = 30.0,
) -> dict:
    """Refresh an expired access token via the ``refresh_token`` grant.

    Uses the same HTTP Basic client authentication as the initial code
    exchange.
    """
    basic_auth = base64.b64encode(f"{client_id}:{client_secret}".encode()).decode()
    form = {
        "grant_type": "refresh_token",
        "refresh_token": refresh_token,
    }
    request_headers = {
        "Content-Type": "application/x-www-form-urlencoded",
        "Authorization": f"Basic {basic_auth}",
    }
    async with httpx.AsyncClient(follow_redirects=True) as http:
        response = await http.post(
            token_endpoint,
            data=form,
            headers=request_headers,
            timeout=timeout,
        )
        response.raise_for_status()
        return response.json()

View file

@ -0,0 +1,161 @@
"""Registry of MCP services with OAuth support.
Each entry maps a URL-safe service key to its MCP server endpoint and
authentication configuration. Services with ``supports_dcr=True`` use
RFC 7591 Dynamic Client Registration (the MCP server issues its own
credentials); the rest use pre-configured credentials via env vars.
``allowed_tools`` whitelists which MCP tools to expose to the agent.
An empty list means "load every tool the server advertises" (used for
user-managed generic MCP servers). Service-specific entries should
curate this list to keep the agent's tool count low and selection
accuracy high.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from app.db import SearchSourceConnectorType
@dataclass(frozen=True)
class MCPServiceConfig:
    """Static configuration for one OAuth-capable MCP service."""

    name: str  # Human-readable service name (e.g. "Linear").
    mcp_url: str  # MCP server endpoint URL.
    connector_type: str  # SearchSourceConnectorType enum value for this service.
    supports_dcr: bool = True  # True: RFC 7591 DCR; False: env-var credentials.
    oauth_discovery_origin: str | None = None  # Override when OAuth lives on another domain.
    client_id_env: str | None = None  # Env var holding the client id (non-DCR services).
    client_secret_env: str | None = None  # Env var holding the client secret (non-DCR services).
    scopes: list[str] = field(default_factory=list)  # OAuth scopes to request.
    scope_param: str = "scope"  # Query-parameter name used to send scopes.
    auth_endpoint_override: str | None = None  # Skip discovery; use this authorize URL.
    token_endpoint_override: str | None = None  # Skip discovery; use this token URL.
    allowed_tools: list[str] = field(default_factory=list)  # Whitelisted MCP tools ([] = load all).
    readonly_tools: frozenset[str] = field(default_factory=frozenset)  # Tools with no side effects.
    account_metadata_keys: list[str] = field(default_factory=list)
    """``connector.config`` keys exposed by ``get_connected_accounts``.

    Only listed keys are returned to the LLM; tokens and secrets are
    never included. Every service should at least have its
    ``display_name`` populated during OAuth; additional service-specific
    fields (e.g. Jira ``cloud_id``) are listed here so the LLM can pass
    them to action tools.
    """
# Registry of OAuth-capable MCP services, keyed by URL-safe service key.
# Keys appear in OAuth redirect URLs, so they must stay stable.
MCP_SERVICES: dict[str, MCPServiceConfig] = {
    # Linear: DCR-capable; curated read tools plus one write tool.
    "linear": MCPServiceConfig(
        name="Linear",
        mcp_url="https://mcp.linear.app/mcp",
        connector_type="LINEAR_CONNECTOR",
        allowed_tools=[
            "list_issues",
            "get_issue",
            "save_issue",
        ],
        readonly_tools=frozenset({"list_issues", "get_issue"}),
        account_metadata_keys=["organization_name", "organization_url_key"],
    ),
    # Jira (Atlassian): DCR-capable; tools are cloud-id scoped, so the
    # resource-discovery tool is included alongside search/create/edit.
    "jira": MCPServiceConfig(
        name="Jira",
        mcp_url="https://mcp.atlassian.com/v1/mcp",
        connector_type="JIRA_CONNECTOR",
        allowed_tools=[
            "getAccessibleAtlassianResources",
            "searchJiraIssuesUsingJql",
            "getVisibleJiraProjects",
            "getJiraProjectIssueTypesMetadata",
            "createJiraIssue",
            "editJiraIssue",
        ],
        readonly_tools=frozenset({
            "getAccessibleAtlassianResources",
            "searchJiraIssuesUsingJql",
            "getVisibleJiraProjects",
            "getJiraProjectIssueTypesMetadata",
        }),
        account_metadata_keys=["cloud_id", "site_name", "base_url"],
    ),
    # ClickUp: DCR-capable; read-only search + task retrieval.
    "clickup": MCPServiceConfig(
        name="ClickUp",
        mcp_url="https://mcp.clickup.com/mcp",
        connector_type="CLICKUP_CONNECTOR",
        allowed_tools=[
            "clickup_search",
            "clickup_get_task",
        ],
        readonly_tools=frozenset({"clickup_search", "clickup_get_task"}),
        account_metadata_keys=["workspace_id", "workspace_name"],
    ),
    # Slack: no DCR — uses pre-configured app credentials and Slack's
    # user-token OAuth endpoints instead of well-known discovery.
    "slack": MCPServiceConfig(
        name="Slack",
        mcp_url="https://mcp.slack.com/mcp",
        connector_type="SLACK_CONNECTOR",
        supports_dcr=False,
        client_id_env="SLACK_CLIENT_ID",
        client_secret_env="SLACK_CLIENT_SECRET",
        auth_endpoint_override="https://slack.com/oauth/v2_user/authorize",
        token_endpoint_override="https://slack.com/api/oauth.v2.user.access",
        scopes=[
            "search:read.public", "search:read.private", "search:read.mpim", "search:read.im",
            "channels:history", "groups:history", "mpim:history", "im:history",
        ],
        allowed_tools=[
            "slack_search_channels",
            "slack_read_channel",
            "slack_read_thread",
        ],
        readonly_tools=frozenset({"slack_search_channels", "slack_read_channel", "slack_read_thread"}),
        # TODO: oauth.v2.user.access only returns team.id, not team.name.
        # To populate team_name, either add "team:read" scope and call
        # GET /api/team.info during OAuth callback, or switch to oauth.v2.access.
        account_metadata_keys=["team_id", "team_name"],
    ),
    # Airtable: no DCR; OAuth server lives on airtable.com while the MCP
    # endpoint is mcp.airtable.com, hence the discovery-origin override.
    "airtable": MCPServiceConfig(
        name="Airtable",
        mcp_url="https://mcp.airtable.com/mcp",
        connector_type="AIRTABLE_CONNECTOR",
        supports_dcr=False,
        oauth_discovery_origin="https://airtable.com",
        client_id_env="AIRTABLE_CLIENT_ID",
        client_secret_env="AIRTABLE_CLIENT_SECRET",
        scopes=["data.records:read", "schema.bases:read"],
        allowed_tools=[
            "list_bases",
            "list_tables_for_base",
            "list_records_for_table",
        ],
        readonly_tools=frozenset({"list_bases", "list_tables_for_base", "list_records_for_table"}),
        account_metadata_keys=["user_id", "user_email"],
    ),
}
# Reverse lookup: connector_type enum value -> its MCP service config.
_CONNECTOR_TYPE_TO_SERVICE: dict[str, MCPServiceConfig] = {
    svc.connector_type: svc for svc in MCP_SERVICES.values()
}

# Connector types served by real-time agent tools instead of background
# indexing. Note this is broader than MCP_SERVICES: it also covers
# native-API and Composio-backed live connectors.
LIVE_CONNECTOR_TYPES: frozenset[SearchSourceConnectorType] = frozenset({
    SearchSourceConnectorType.SLACK_CONNECTOR,
    SearchSourceConnectorType.TEAMS_CONNECTOR,
    SearchSourceConnectorType.LINEAR_CONNECTOR,
    SearchSourceConnectorType.JIRA_CONNECTOR,
    SearchSourceConnectorType.CLICKUP_CONNECTOR,
    SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR,
    SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
    SearchSourceConnectorType.AIRTABLE_CONNECTOR,
    SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
    SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
    SearchSourceConnectorType.DISCORD_CONNECTOR,
    SearchSourceConnectorType.LUMA_CONNECTOR,
})
def get_service(key: str) -> MCPServiceConfig | None:
    """Return the MCP service config for *key*, or ``None`` if unknown."""
    return MCP_SERVICES.get(key)
def get_service_by_connector_type(connector_type: str) -> MCPServiceConfig | None:
    """Look up an MCP service config by its ``connector_type`` enum value.

    Returns ``None`` for connector types that have no MCP service entry.
    """
    return _CONNECTOR_TYPE_TO_SERVICE.get(connector_type)

View file

@ -227,8 +227,6 @@ class NotionToolMetadataService:
async def _check_account_health(self, connector_id: int) -> bool:
"""Check if a Notion connector's token is still valid.
Uses a lightweight ``users.me()`` call to verify the token.
Returns True if the token is expired/invalid, False if healthy.
"""
try:

View file

@ -39,52 +39,6 @@ def _handle_greenlet_error(e: Exception, task_name: str, connector_id: int) -> N
)
@celery_app.task(name="index_slack_messages", bind=True)
def index_slack_messages_task(
self,
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str,
):
"""Celery task to index Slack messages."""
import asyncio
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_index_slack_messages(
connector_id, search_space_id, user_id, start_date, end_date
)
)
except Exception as e:
_handle_greenlet_error(e, "index_slack_messages", connector_id)
raise
finally:
loop.close()
async def _index_slack_messages(
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str,
):
"""Index Slack messages with new session."""
from app.routes.search_source_connectors_routes import (
run_slack_indexing,
)
async with get_celery_session_maker()() as session:
await run_slack_indexing(
session, connector_id, search_space_id, user_id, start_date, end_date
)
@celery_app.task(name="index_notion_pages", bind=True)
def index_notion_pages_task(
self,
@ -174,92 +128,6 @@ async def _index_github_repos(
)
@celery_app.task(name="index_linear_issues", bind=True)
def index_linear_issues_task(
self,
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str,
):
"""Celery task to index Linear issues."""
import asyncio
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_index_linear_issues(
connector_id, search_space_id, user_id, start_date, end_date
)
)
finally:
loop.close()
async def _index_linear_issues(
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str,
):
"""Index Linear issues with new session."""
from app.routes.search_source_connectors_routes import (
run_linear_indexing,
)
async with get_celery_session_maker()() as session:
await run_linear_indexing(
session, connector_id, search_space_id, user_id, start_date, end_date
)
@celery_app.task(name="index_jira_issues", bind=True)
def index_jira_issues_task(
self,
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str,
):
"""Celery task to index Jira issues."""
import asyncio
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_index_jira_issues(
connector_id, search_space_id, user_id, start_date, end_date
)
)
finally:
loop.close()
async def _index_jira_issues(
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str,
):
"""Index Jira issues with new session."""
from app.routes.search_source_connectors_routes import (
run_jira_indexing,
)
async with get_celery_session_maker()() as session:
await run_jira_indexing(
session, connector_id, search_space_id, user_id, start_date, end_date
)
@celery_app.task(name="index_confluence_pages", bind=True)
def index_confluence_pages_task(
self,
@ -303,49 +171,6 @@ async def _index_confluence_pages(
)
@celery_app.task(name="index_clickup_tasks", bind=True)
def index_clickup_tasks_task(
self,
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str,
):
"""Celery task to index ClickUp tasks."""
import asyncio
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_index_clickup_tasks(
connector_id, search_space_id, user_id, start_date, end_date
)
)
finally:
loop.close()
async def _index_clickup_tasks(
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str,
):
"""Index ClickUp tasks with new session."""
from app.routes.search_source_connectors_routes import (
run_clickup_indexing,
)
async with get_celery_session_maker()() as session:
await run_clickup_indexing(
session, connector_id, search_space_id, user_id, start_date, end_date
)
@celery_app.task(name="index_google_calendar_events", bind=True)
def index_google_calendar_events_task(
self,
@ -392,49 +217,6 @@ async def _index_google_calendar_events(
)
@celery_app.task(name="index_airtable_records", bind=True)
def index_airtable_records_task(
self,
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str,
):
"""Celery task to index Airtable records."""
import asyncio
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_index_airtable_records(
connector_id, search_space_id, user_id, start_date, end_date
)
)
finally:
loop.close()
async def _index_airtable_records(
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str,
):
"""Index Airtable records with new session."""
from app.routes.search_source_connectors_routes import (
run_airtable_indexing,
)
async with get_celery_session_maker()() as session:
await run_airtable_indexing(
session, connector_id, search_space_id, user_id, start_date, end_date
)
@celery_app.task(name="index_google_gmail_messages", bind=True)
def index_google_gmail_messages_task(
self,
@ -622,135 +404,6 @@ async def _index_dropbox_files(
)
@celery_app.task(name="index_discord_messages", bind=True)
def index_discord_messages_task(
self,
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str,
):
"""Celery task to index Discord messages."""
import asyncio
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_index_discord_messages(
connector_id, search_space_id, user_id, start_date, end_date
)
)
finally:
loop.close()
async def _index_discord_messages(
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str,
):
"""Index Discord messages with new session."""
from app.routes.search_source_connectors_routes import (
run_discord_indexing,
)
async with get_celery_session_maker()() as session:
await run_discord_indexing(
session, connector_id, search_space_id, user_id, start_date, end_date
)
@celery_app.task(name="index_teams_messages", bind=True)
def index_teams_messages_task(
self,
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str,
):
"""Celery task to index Microsoft Teams messages."""
import asyncio
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_index_teams_messages(
connector_id, search_space_id, user_id, start_date, end_date
)
)
finally:
loop.close()
async def _index_teams_messages(
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str,
):
"""Index Microsoft Teams messages with new session."""
from app.routes.search_source_connectors_routes import (
run_teams_indexing,
)
async with get_celery_session_maker()() as session:
await run_teams_indexing(
session, connector_id, search_space_id, user_id, start_date, end_date
)
@celery_app.task(name="index_luma_events", bind=True)
def index_luma_events_task(
self,
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str,
):
"""Celery task to index Luma events."""
import asyncio
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(
_index_luma_events(
connector_id, search_space_id, user_id, start_date, end_date
)
)
finally:
loop.close()
async def _index_luma_events(
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str,
):
"""Index Luma events with new session."""
from app.routes.search_source_connectors_routes import (
run_luma_indexing,
)
async with get_celery_session_maker()() as session:
await run_luma_indexing(
session, connector_id, search_space_id, user_id, start_date, end_date
)
@celery_app.task(name="index_elasticsearch_documents", bind=True)
def index_elasticsearch_documents_task(
self,

View file

@ -51,50 +51,51 @@ async def _check_and_trigger_schedules():
logger.info(f"Found {len(due_connectors)} connectors due for indexing")
# Import all indexing tasks
# Import indexing tasks for KB connectors only.
# Live connectors (Linear, Slack, Jira, ClickUp, Airtable, Discord,
# Teams, Gmail, Calendar, Luma) use real-time tools instead.
from app.tasks.celery_tasks.connector_tasks import (
index_airtable_records_task,
index_clickup_tasks_task,
index_confluence_pages_task,
index_crawled_urls_task,
index_discord_messages_task,
index_elasticsearch_documents_task,
index_github_repos_task,
index_google_calendar_events_task,
index_google_drive_files_task,
index_google_gmail_messages_task,
index_jira_issues_task,
index_linear_issues_task,
index_luma_events_task,
index_notion_pages_task,
index_slack_messages_task,
)
# Map connector types to their tasks
task_map = {
SearchSourceConnectorType.SLACK_CONNECTOR: index_slack_messages_task,
SearchSourceConnectorType.NOTION_CONNECTOR: index_notion_pages_task,
SearchSourceConnectorType.GITHUB_CONNECTOR: index_github_repos_task,
SearchSourceConnectorType.LINEAR_CONNECTOR: index_linear_issues_task,
SearchSourceConnectorType.JIRA_CONNECTOR: index_jira_issues_task,
SearchSourceConnectorType.CONFLUENCE_CONNECTOR: index_confluence_pages_task,
SearchSourceConnectorType.CLICKUP_CONNECTOR: index_clickup_tasks_task,
SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR: index_google_calendar_events_task,
SearchSourceConnectorType.AIRTABLE_CONNECTOR: index_airtable_records_task,
SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR: index_google_gmail_messages_task,
SearchSourceConnectorType.DISCORD_CONNECTOR: index_discord_messages_task,
SearchSourceConnectorType.LUMA_CONNECTOR: index_luma_events_task,
SearchSourceConnectorType.ELASTICSEARCH_CONNECTOR: index_elasticsearch_documents_task,
SearchSourceConnectorType.WEBCRAWLER_CONNECTOR: index_crawled_urls_task,
SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR: index_google_drive_files_task,
# Composio connector types (unified with native Google tasks)
SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR: index_google_drive_files_task,
SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR: index_google_gmail_messages_task,
SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR: index_google_calendar_events_task,
}
from app.services.mcp_oauth.registry import LIVE_CONNECTOR_TYPES
# Disable obsolete periodic indexing for live connectors in one batch.
live_disabled = []
for connector in due_connectors:
if connector.connector_type in LIVE_CONNECTOR_TYPES:
connector.periodic_indexing_enabled = False
connector.next_scheduled_at = None
live_disabled.append(connector)
if live_disabled:
await session.commit()
for c in live_disabled:
logger.info(
"Disabled obsolete periodic indexing for live connector %s (%s)",
c.id,
c.connector_type.value,
)
# Trigger indexing for each due connector
for connector in due_connectors:
if connector in live_disabled:
continue
# Primary guard: Redis lock indicates a task is currently running.
if is_connector_indexing_locked(connector.id):
logger.info(

View file

@ -1,77 +1,31 @@
"""
Connector indexers module for background tasks.
This module provides a collection of connector indexers for different platforms
and services. Each indexer is responsible for handling the indexing of content
from a specific connector type.
Available indexers:
- Slack: Index messages from Slack channels
- Notion: Index pages from Notion workspaces
- GitHub: Index repositories and files from GitHub
- Linear: Index issues from Linear workspaces
- Jira: Index issues from Jira projects
- Confluence: Index pages from Confluence spaces
- BookStack: Index pages from BookStack wiki instances
- Discord: Index messages from Discord servers
- ClickUp: Index tasks from ClickUp workspaces
- Google Gmail: Index messages from Google Gmail
- Google Calendar: Index events from Google Calendar
- Luma: Index events from Luma
- Webcrawler: Index crawled URLs
- Elasticsearch: Index documents from Elasticsearch instances
Each indexer handles content indexing from a specific connector type.
Live connectors (Slack, Linear, Jira, ClickUp, Airtable, Discord, Teams,
Luma) now use real-time agent tools instead of background indexing.
"""
# Communication platforms
# Calendar and scheduling
from .airtable_indexer import index_airtable_records
from .bookstack_indexer import index_bookstack_pages
# Note: composio_indexer is imported directly in connector_tasks.py to avoid circular imports
from .clickup_indexer import index_clickup_tasks
from .confluence_indexer import index_confluence_pages
from .discord_indexer import index_discord_messages
# Development platforms
from .elasticsearch_indexer import index_elasticsearch_documents
from .github_indexer import index_github_repos
from .google_calendar_indexer import index_google_calendar_events
from .google_drive_indexer import index_google_drive_files
from .google_gmail_indexer import index_google_gmail_messages
from .jira_indexer import index_jira_issues
# Issue tracking and project management
from .linear_indexer import index_linear_issues
# Documentation and knowledge management
from .luma_indexer import index_luma_events
from .notion_indexer import index_notion_pages
from .obsidian_indexer import index_obsidian_vault
from .slack_indexer import index_slack_messages
from .webcrawler_indexer import index_crawled_urls
__all__ = [ # noqa: RUF022
"index_airtable_records",
__all__ = [
"index_bookstack_pages",
# "index_composio_connector", # Imported directly in connector_tasks.py to avoid circular imports
"index_clickup_tasks",
"index_confluence_pages",
"index_discord_messages",
# Development platforms
"index_elasticsearch_documents",
"index_github_repos",
# Calendar and scheduling
"index_google_calendar_events",
"index_google_drive_files",
"index_luma_events",
"index_jira_issues",
# Issue tracking and project management
"index_linear_issues",
# Documentation and knowledge management
"index_google_gmail_messages",
"index_notion_pages",
"index_obsidian_vault",
"index_crawled_urls",
# Communication platforms
"index_slack_messages",
"index_google_gmail_messages",
]

View file

@ -0,0 +1,129 @@
"""Async retry decorators for connector API calls, built on tenacity."""
from __future__ import annotations
import logging
from collections.abc import Callable
from typing import TypeVar
import httpx
from tenacity import (
before_sleep_log,
retry,
retry_if_exception,
stop_after_attempt,
stop_after_delay,
wait_exponential_jitter,
)
from app.connectors.exceptions import (
ConnectorAPIError,
ConnectorAuthError,
ConnectorError,
ConnectorRateLimitError,
ConnectorTimeoutError,
)
logger = logging.getLogger(__name__)
F = TypeVar("F", bound=Callable)
def _is_retryable(exc: BaseException) -> bool:
    """Decide whether *exc* warrants another attempt."""
    if isinstance(exc, ConnectorError):
        # Connector errors carry their own retryability flag.
        return exc.retryable
    # Network-level transport failures are always worth retrying.
    return isinstance(exc, (httpx.TimeoutException, httpx.ConnectError))
def build_retry(
    *,
    max_attempts: int = 4,
    max_delay: float = 60.0,
    initial_delay: float = 1.0,
    total_timeout: float = 180.0,
    service: str = "",
) -> Callable:
    """Configurable tenacity ``@retry`` decorator with exponential backoff + jitter.

    Retries only exceptions ``_is_retryable`` approves; stops after
    *max_attempts* tries or *total_timeout* seconds, whichever comes first,
    and re-raises the last exception unchanged.
    """
    if service:
        retry_logger = logging.getLogger(f"connector.retry.{service}")
    else:
        retry_logger = logger
    stop_policy = stop_after_attempt(max_attempts) | stop_after_delay(total_timeout)
    wait_policy = wait_exponential_jitter(initial=initial_delay, max=max_delay)
    return retry(
        retry=retry_if_exception(_is_retryable),
        stop=stop_policy,
        wait=wait_policy,
        reraise=True,
        before_sleep=before_sleep_log(retry_logger, logging.WARNING),
    )
def retry_on_transient(
    *,
    service: str = "",
    max_attempts: int = 4,
) -> Callable:
    """Shorthand: retry up to *max_attempts* on rate-limits, timeouts, and 5xx.

    Delegates to ``build_retry`` with its default delay/backoff settings.
    """
    return build_retry(max_attempts=max_attempts, service=service)
def raise_for_status(
    response: httpx.Response,
    *,
    service: str = "",
) -> None:
    """Map non-2xx httpx responses to the appropriate ``ConnectorError``.

    Args:
        response: The completed httpx response to inspect.
        service: Human-readable service name used in error messages and
            logger naming.

    Raises:
        ConnectorRateLimitError: on 429, with ``Retry-After`` parsed when
            it is a numeric seconds value.
        ConnectorAuthError: on 401/403.
        ConnectorTimeoutError: on 408 (request timeout) or 504 (gateway
            timeout) — both transient, retryable conditions.
        ConnectorAPIError: on any other non-success status.
    """
    if response.is_success:
        return
    status = response.status_code
    # Capture a bounded body for diagnostics; prefer structured JSON.
    try:
        body = response.json()
    except Exception:
        body = response.text[:500] if response.text else None
    if status == 429:
        # Retry-After may be seconds or an HTTP-date (RFC 7231); only the
        # numeric-seconds form is parsed here, HTTP-dates yield None.
        retry_after_raw = response.headers.get("Retry-After")
        retry_after: float | None = None
        if retry_after_raw:
            try:
                retry_after = float(retry_after_raw)
            except (ValueError, TypeError):
                pass
        raise ConnectorRateLimitError(
            f"{service} rate limited (429)",
            service=service,
            retry_after=retry_after,
            response_body=body,
        )
    if status in (401, 403):
        raise ConnectorAuthError(
            f"{service} authentication failed ({status})",
            service=service,
            status_code=status,
            response_body=body,
        )
    if status == 504:
        raise ConnectorTimeoutError(
            f"{service} gateway timeout (504)",
            service=service,
            status_code=status,
            response_body=body,
        )
    if status == 408:
        # 408 Request Timeout is as transient as 504; without this branch
        # it would fall through to the non-retryable generic API error.
        raise ConnectorTimeoutError(
            f"{service} request timeout (408)",
            service=service,
            status_code=status,
            response_body=body,
        )
    if status >= 500:
        raise ConnectorAPIError(
            f"{service} server error ({status})",
            service=service,
            status_code=status,
            response_body=body,
        )
    raise ConnectorAPIError(
        f"{service} request failed ({status})",
        service=service,
        status_code=status,
        response_body=body,
    )

View file

@ -39,7 +39,7 @@ BASE_NAME_FOR_TYPE = {
def get_base_name_for_type(connector_type: SearchSourceConnectorType) -> str:
"""Get a friendly display name for a connector type."""
return BASE_NAME_FOR_TYPE.get(
connector_type, connector_type.replace("_", " ").title()
connector_type, connector_type.value.replace("_", " ").title()
)
@ -231,9 +231,11 @@ async def generate_unique_connector_name(
base = get_base_name_for_type(connector_type)
if identifier:
return f"{base} - {identifier}"
name = f"{base} - {identifier}"
return await ensure_unique_connector_name(
session, name, search_space_id, user_id,
)
# Fallback: use counter for uniqueness
count = await count_connectors_of_type(
session, connector_type, search_space_id, user_id
)

View file

@ -18,19 +18,9 @@ logger = logging.getLogger(__name__)
# Mapping of connector types to their corresponding Celery task names
CONNECTOR_TASK_MAP = {
SearchSourceConnectorType.SLACK_CONNECTOR: "index_slack_messages",
SearchSourceConnectorType.TEAMS_CONNECTOR: "index_teams_messages",
SearchSourceConnectorType.NOTION_CONNECTOR: "index_notion_pages",
SearchSourceConnectorType.GITHUB_CONNECTOR: "index_github_repos",
SearchSourceConnectorType.LINEAR_CONNECTOR: "index_linear_issues",
SearchSourceConnectorType.JIRA_CONNECTOR: "index_jira_issues",
SearchSourceConnectorType.CONFLUENCE_CONNECTOR: "index_confluence_pages",
SearchSourceConnectorType.CLICKUP_CONNECTOR: "index_clickup_tasks",
SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR: "index_google_calendar_events",
SearchSourceConnectorType.AIRTABLE_CONNECTOR: "index_airtable_records",
SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR: "index_google_gmail_messages",
SearchSourceConnectorType.DISCORD_CONNECTOR: "index_discord_messages",
SearchSourceConnectorType.LUMA_CONNECTOR: "index_luma_events",
SearchSourceConnectorType.ELASTICSEARCH_CONNECTOR: "index_elasticsearch_documents",
SearchSourceConnectorType.WEBCRAWLER_CONNECTOR: "index_crawled_urls",
SearchSourceConnectorType.BOOKSTACK_CONNECTOR: "index_bookstack_pages",
@ -84,40 +74,20 @@ def create_periodic_schedule(
f"(frequency: {frequency_minutes} minutes). Triggering first run..."
)
# Import all indexing tasks
from app.tasks.celery_tasks.connector_tasks import (
index_airtable_records_task,
index_bookstack_pages_task,
index_clickup_tasks_task,
index_confluence_pages_task,
index_crawled_urls_task,
index_discord_messages_task,
index_elasticsearch_documents_task,
index_github_repos_task,
index_google_calendar_events_task,
index_google_gmail_messages_task,
index_jira_issues_task,
index_linear_issues_task,
index_luma_events_task,
index_notion_pages_task,
index_obsidian_vault_task,
index_slack_messages_task,
)
# Map connector type to task
task_map = {
SearchSourceConnectorType.SLACK_CONNECTOR: index_slack_messages_task,
SearchSourceConnectorType.NOTION_CONNECTOR: index_notion_pages_task,
SearchSourceConnectorType.GITHUB_CONNECTOR: index_github_repos_task,
SearchSourceConnectorType.LINEAR_CONNECTOR: index_linear_issues_task,
SearchSourceConnectorType.JIRA_CONNECTOR: index_jira_issues_task,
SearchSourceConnectorType.CONFLUENCE_CONNECTOR: index_confluence_pages_task,
SearchSourceConnectorType.CLICKUP_CONNECTOR: index_clickup_tasks_task,
SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR: index_google_calendar_events_task,
SearchSourceConnectorType.AIRTABLE_CONNECTOR: index_airtable_records_task,
SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR: index_google_gmail_messages_task,
SearchSourceConnectorType.DISCORD_CONNECTOR: index_discord_messages_task,
SearchSourceConnectorType.LUMA_CONNECTOR: index_luma_events_task,
SearchSourceConnectorType.ELASTICSEARCH_CONNECTOR: index_elasticsearch_documents_task,
SearchSourceConnectorType.WEBCRAWLER_CONNECTOR: index_crawled_urls_task,
SearchSourceConnectorType.BOOKSTACK_CONNECTOR: index_bookstack_pages_task,

View file

@ -8,6 +8,7 @@ import { Spinner } from "@/components/ui/spinner";
import { EnumConnectorName } from "@/contracts/enums/connector";
import { getConnectorIcon } from "@/contracts/enums/connectorIcons";
import { cn } from "@/lib/utils";
import { LIVE_CONNECTOR_TYPES } from "../constants/connector-constants";
import { useConnectorStatus } from "../hooks/use-connector-status";
import { ConnectorStatusBadge } from "./connector-status-badge";
@ -55,6 +56,7 @@ export const ConnectorCard: FC<ConnectorCardProps> = ({
onManage,
}) => {
const isMCP = connectorType === EnumConnectorName.MCP_CONNECTOR;
const isLive = !!connectorType && LIVE_CONNECTOR_TYPES.has(connectorType);
// Get connector status
const { getConnectorStatus, isConnectorEnabled, getConnectorStatusMessage, shouldShowWarnings } =
useConnectorStatus();
@ -123,14 +125,14 @@ export const ConnectorCard: FC<ConnectorCardProps> = ({
</span>
) : (
<>
<span>{formatDocumentCount(documentCount)}</span>
{!isLive && <span>{formatDocumentCount(documentCount)}</span>}
{!isLive && accountCount !== undefined && accountCount > 0 && (
<span className="text-muted-foreground/50"></span>
)}
{accountCount !== undefined && accountCount > 0 && (
<>
<span className="text-muted-foreground/50"></span>
<span>
{accountCount} {accountCount === 1 ? "Account" : "Accounts"}
</span>
</>
<span>
{accountCount} {accountCount === 1 ? "Account" : "Accounts"}
</span>
)}
</>
)}

View file

@ -53,8 +53,7 @@ export const DiscordConfig: FC<DiscordConfigProps> = ({ connector }) => {
return () => document.removeEventListener("visibilitychange", handleVisibilityChange);
}, [connector?.id, fetchChannels]);
// Separate channels by indexing capability
const readyToIndex = channels.filter((ch) => ch.can_index);
const accessible = channels.filter((ch) => ch.can_index);
const needsPermissions = channels.filter((ch) => !ch.can_index);
// Format last fetched time
@ -80,7 +79,7 @@ export const DiscordConfig: FC<DiscordConfigProps> = ({ connector }) => {
</div>
<div className="text-xs sm:text-sm">
<p className="text-muted-foreground mt-1 text-[10px] sm:text-sm">
The bot needs &quot;Read Message History&quot; permission to index channels. Ask a
The bot needs &quot;Read Message History&quot; permission to access channels. Ask a
server admin to grant this permission for channels shown below.
</p>
</div>
@ -127,18 +126,18 @@ export const DiscordConfig: FC<DiscordConfigProps> = ({ connector }) => {
</div>
) : (
<div className="rounded-xl bg-slate-400/5 dark:bg-white/5 overflow-hidden">
{/* Ready to index */}
{readyToIndex.length > 0 && (
{/* Accessible channels */}
{accessible.length > 0 && (
<div className={cn("p-3", needsPermissions.length > 0 && "border-b border-border")}>
<div className="flex items-center gap-2 mb-2">
<CheckCircle2 className="size-3.5 text-emerald-500" />
<span className="text-[11px] font-medium">Ready to index</span>
<span className="text-[11px] font-medium">Accessible</span>
<span className="text-[10px] text-muted-foreground">
{readyToIndex.length} {readyToIndex.length === 1 ? "channel" : "channels"}
{accessible.length} {accessible.length === 1 ? "channel" : "channels"}
</span>
</div>
<div className="flex flex-wrap gap-1.5">
{readyToIndex.map((channel) => (
{accessible.map((channel) => (
<ChannelPill key={channel.id} channel={channel} />
))}
</div>
@ -150,7 +149,7 @@ export const DiscordConfig: FC<DiscordConfigProps> = ({ connector }) => {
<div className="p-3">
<div className="flex items-center gap-2 mb-2">
<AlertCircle className="size-3.5 text-amber-500" />
<span className="text-[11px] font-medium">Grant permissions to index</span>
<span className="text-[11px] font-medium">Needs permissions</span>
<span className="text-[10px] text-muted-foreground">
{needsPermissions.length}{" "}
{needsPermissions.length === 1 ? "channel" : "channels"}

View file

@ -0,0 +1,28 @@
"use client";
import { CheckCircle2 } from "lucide-react";
import type { FC } from "react";
import type { ConnectorConfigProps } from "../index";
export const MCPServiceConfig: FC<ConnectorConfigProps> = ({ connector }) => {
const serviceName = connector.config?.mcp_service as string | undefined;
const displayName = serviceName
? serviceName.charAt(0).toUpperCase() + serviceName.slice(1)
: "this service";
return (
<div className="space-y-4">
<div className="rounded-xl border border-border bg-emerald-500/5 p-4 flex items-start gap-3">
<div className="flex h-8 w-8 items-center justify-center rounded-lg bg-emerald-500/10 shrink-0 mt-0.5">
<CheckCircle2 className="size-4 text-emerald-500" />
</div>
<div className="text-xs sm:text-sm">
<p className="font-medium text-xs sm:text-sm">Connected</p>
<p className="text-muted-foreground mt-1 text-[10px] sm:text-sm">
Your agent can search, read, and take actions in {displayName}.
</p>
</div>
</div>
</div>
);
};

View file

@ -18,9 +18,9 @@ export const TeamsConfig: FC<TeamsConfigProps> = () => {
<div className="text-xs sm:text-sm">
<p className="font-medium text-xs sm:text-sm">Microsoft Teams Access</p>
<p className="text-muted-foreground mt-1 text-[10px] sm:text-sm">
SurfSense will index messages from Teams channels that you have access to. The app can
only read messages from teams and channels where you are a member. Make sure you're a
member of the teams you want to index before connecting.
Your agent can search and read messages from Teams channels you have access to,
and send messages on your behalf. Make sure you&#39;re a member of the teams
you want to interact with.
</p>
</div>
</div>

View file

@ -16,8 +16,9 @@ import { DateRangeSelector } from "../../components/date-range-selector";
import { PeriodicSyncConfig } from "../../components/periodic-sync-config";
import { SummaryConfig } from "../../components/summary-config";
import { VisionLLMConfig } from "../../components/vision-llm-config";
import { LIVE_CONNECTOR_TYPES } from "../../constants/connector-constants";
import { getConnectorDisplayName } from "../../tabs/all-connectors-tab";
import { getConnectorConfigComponent } from "../index";
import { type ConnectorConfigProps, getConnectorConfigComponent } from "../index";
const REAUTH_ENDPOINTS: Partial<Record<string, string>> = {
[EnumConnectorName.LINEAR_CONNECTOR]: "/api/v1/auth/linear/connector/reauth",
@ -118,11 +119,17 @@ export const ConnectorEditView: FC<ConnectorEditViewProps> = ({
}
}, [searchSpaceId, searchSpaceIdAtom, reauthEndpoint, connector.id]);
// Get connector-specific config component
const ConnectorConfigComponent = useMemo(
() => getConnectorConfigComponent(connector.connector_type),
[connector.connector_type]
);
const isMCPBacked = Boolean(connector.config?.server_config);
const isLive = isMCPBacked || LIVE_CONNECTOR_TYPES.has(connector.connector_type);
// Get connector-specific config component (MCP-backed connectors use a generic view)
const ConnectorConfigComponent = useMemo(() => {
if (isMCPBacked) {
const { MCPServiceConfig } = require("../components/mcp-service-config");
return MCPServiceConfig as FC<ConnectorConfigProps>;
}
return getConnectorConfigComponent(connector.connector_type);
}, [connector.connector_type, isMCPBacked]);
const [isScrolled, setIsScrolled] = useState(false);
const [hasMoreContent, setHasMoreContent] = useState(false);
const [showDisconnectConfirm, setShowDisconnectConfirm] = useState(false);
@ -223,12 +230,14 @@ export const ConnectorEditView: FC<ConnectorEditViewProps> = ({
{getConnectorDisplayName(connector.name)}
</h2>
<p className="text-xs sm:text-base text-muted-foreground mt-1">
Manage your connector settings and sync configuration
{isLive
? "Manage your connected account"
: "Manage your connector settings and sync configuration"}
</p>
</div>
</div>
{/* Quick Index Button - hidden when auth is expired */}
{connector.is_indexable && onQuickIndex && !isAuthExpired && (
{/* Quick Index Button - hidden for live connectors and when auth is expired */}
{connector.is_indexable && !isLive && onQuickIndex && !isAuthExpired && (
<Button
variant="secondary"
size="sm"
@ -271,8 +280,8 @@ export const ConnectorEditView: FC<ConnectorEditViewProps> = ({
/>
)}
{/* Summary and sync settings - only shown for indexable connectors */}
{connector.is_indexable && (
{/* Summary and sync settings - hidden for live connectors */}
{connector.is_indexable && !isLive && (
<>
{/* AI Summary toggle */}
<SummaryConfig enabled={enableSummary} onEnabledChange={onEnableSummaryChange} />
@ -343,8 +352,8 @@ export const ConnectorEditView: FC<ConnectorEditViewProps> = ({
</>
)}
{/* Info box - only shown for indexable connectors */}
{connector.is_indexable && (
{/* Info box - hidden for live connectors */}
{connector.is_indexable && !isLive && (
<div className="rounded-xl border border-border bg-primary/5 p-4 flex items-start gap-3">
<div className="flex h-8 w-8 items-center justify-center rounded-lg bg-primary/10 shrink-0 mt-0.5">
<Info className="size-4" />
@ -374,10 +383,12 @@ export const ConnectorEditView: FC<ConnectorEditViewProps> = ({
{/* Fixed Footer - Action buttons */}
<div className="flex-shrink-0 flex flex-col sm:flex-row items-stretch sm:items-center justify-between gap-3 sm:gap-0 px-6 sm:px-12 py-6 sm:py-6 bg-muted border-t border-border">
{showDisconnectConfirm ? (
<div className="flex flex-col sm:flex-row items-stretch sm:items-center gap-3 flex-1 sm:flex-initial">
{showDisconnectConfirm ? (
<div className="flex flex-col sm:flex-row items-stretch sm:items-center gap-3 flex-1 sm:flex-initial">
<span className="text-xs sm:text-sm text-muted-foreground sm:whitespace-nowrap">
Are you sure?
{isLive
? "Your agent will lose access to this service."
: "This will remove all indexed data."}
</span>
<div className="flex items-center gap-2 sm:gap-3">
<Button
@ -421,7 +432,7 @@ export const ConnectorEditView: FC<ConnectorEditViewProps> = ({
<RefreshCw className={cn("size-3.5", reauthing && "animate-spin")} />
Re-authenticate
</Button>
) : (
) : !isLive ? (
<Button
onClick={onSave}
disabled={isSaving || isDisconnecting}
@ -430,7 +441,7 @@ export const ConnectorEditView: FC<ConnectorEditViewProps> = ({
<span className={isSaving ? "opacity-0" : ""}>Save Changes</span>
{isSaving && <Spinner size="sm" className="absolute" />}
</Button>
)}
) : null}
</div>
</div>
);

View file

@ -11,7 +11,7 @@ import { DateRangeSelector } from "../../components/date-range-selector";
import { PeriodicSyncConfig } from "../../components/periodic-sync-config";
import { SummaryConfig } from "../../components/summary-config";
import { VisionLLMConfig } from "../../components/vision-llm-config";
import type { IndexingConfigState } from "../../constants/connector-constants";
import { LIVE_CONNECTOR_TYPES, type IndexingConfigState } from "../../constants/connector-constants";
import { getConnectorDisplayName } from "../../tabs/all-connectors-tab";
import { getConnectorConfigComponent } from "../index";
@ -58,6 +58,8 @@ export const IndexingConfigurationView: FC<IndexingConfigurationViewProps> = ({
onStartIndexing,
onSkip,
}) => {
const isLive = LIVE_CONNECTOR_TYPES.has(config.connectorType);
// Get connector-specific config component
const ConnectorConfigComponent = useMemo(
() => (connector ? getConnectorConfigComponent(connector.connector_type) : null),
@ -138,7 +140,9 @@ export const IndexingConfigurationView: FC<IndexingConfigurationViewProps> = ({
)}
</div>
<p className="text-xs sm:text-base text-muted-foreground mt-1">
Configure when to start syncing your data
{isLive
? "Your account is ready to use"
: "Configure when to start syncing your data"}
</p>
</div>
</div>
@ -157,8 +161,8 @@ export const IndexingConfigurationView: FC<IndexingConfigurationViewProps> = ({
<ConnectorConfigComponent connector={connector} onConfigChange={onConfigChange} />
)}
{/* Summary and sync settings - only shown for indexable connectors */}
{connector?.is_indexable && (
{/* Summary and sync settings - hidden for live connectors */}
{connector?.is_indexable && !isLive && (
<>
{/* AI Summary toggle */}
<SummaryConfig enabled={enableSummary} onEnabledChange={onEnableSummaryChange} />
@ -209,8 +213,8 @@ export const IndexingConfigurationView: FC<IndexingConfigurationViewProps> = ({
</>
)}
{/* Info box - only shown for indexable connectors */}
{connector?.is_indexable && (
{/* Info box - hidden for live connectors */}
{connector?.is_indexable && !isLive && (
<div className="rounded-xl border border-border bg-primary/5 p-4 flex items-start gap-3">
<div className="flex h-8 w-8 items-center justify-center rounded-lg bg-primary/10 shrink-0 mt-0.5">
<Info className="size-4" />
@ -238,14 +242,20 @@ export const IndexingConfigurationView: FC<IndexingConfigurationViewProps> = ({
{/* Fixed Footer - Action buttons */}
<div className="flex-shrink-0 flex items-center justify-end px-6 sm:px-12 py-6 bg-muted">
<Button
onClick={onStartIndexing}
disabled={isStartingIndexing}
className="text-xs sm:text-sm relative"
>
<span className={isStartingIndexing ? "opacity-0" : ""}>Start Indexing</span>
{isStartingIndexing && <Spinner size="sm" className="absolute" />}
</Button>
{isLive ? (
<Button onClick={onSkip} className="text-xs sm:text-sm">
Done
</Button>
) : (
<Button
onClick={onStartIndexing}
disabled={isStartingIndexing}
className="text-xs sm:text-sm relative"
>
<span className={isStartingIndexing ? "opacity-0" : ""}>Start Indexing</span>
{isStartingIndexing && <Spinner size="sm" className="absolute" />}
</Button>
)}
</div>
</div>
);

View file

@ -1,5 +1,24 @@
import { EnumConnectorName } from "@/contracts/enums/connector";
/**
* Connectors that operate in real time (no background indexing).
* Used to adjust UI: hide sync controls, show "Connected" instead of doc counts.
*/
export const LIVE_CONNECTOR_TYPES = new Set<string>([
EnumConnectorName.LINEAR_CONNECTOR,
EnumConnectorName.SLACK_CONNECTOR,
EnumConnectorName.JIRA_CONNECTOR,
EnumConnectorName.CLICKUP_CONNECTOR,
EnumConnectorName.AIRTABLE_CONNECTOR,
EnumConnectorName.DISCORD_CONNECTOR,
EnumConnectorName.TEAMS_CONNECTOR,
EnumConnectorName.GOOGLE_CALENDAR_CONNECTOR,
EnumConnectorName.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
EnumConnectorName.GOOGLE_GMAIL_CONNECTOR,
EnumConnectorName.COMPOSIO_GMAIL_CONNECTOR,
EnumConnectorName.LUMA_CONNECTOR,
]);
// OAuth Connectors (Quick Connect)
export const OAUTH_CONNECTORS = [
{
@ -13,7 +32,7 @@ export const OAUTH_CONNECTORS = [
{
id: "google-gmail-connector",
title: "Gmail",
description: "Search through your emails",
description: "Search, read, draft, and send emails",
connectorType: EnumConnectorName.GOOGLE_GMAIL_CONNECTOR,
authEndpoint: "/api/v1/auth/google/gmail/connector/add/",
selfHostedOnly: true,
@ -21,7 +40,7 @@ export const OAUTH_CONNECTORS = [
{
id: "google-calendar-connector",
title: "Google Calendar",
description: "Search through your events",
description: "Search and manage your events",
connectorType: EnumConnectorName.GOOGLE_CALENDAR_CONNECTOR,
authEndpoint: "/api/v1/auth/google/calendar/connector/add/",
selfHostedOnly: true,
@ -29,35 +48,35 @@ export const OAUTH_CONNECTORS = [
{
id: "airtable-connector",
title: "Airtable",
description: "Search your Airtable bases",
description: "Browse bases, tables, and records",
connectorType: EnumConnectorName.AIRTABLE_CONNECTOR,
authEndpoint: "/api/v1/auth/airtable/connector/add/",
authEndpoint: "/api/v1/auth/mcp/airtable/connector/add/",
},
{
id: "notion-connector",
title: "Notion",
description: "Search your Notion pages",
connectorType: EnumConnectorName.NOTION_CONNECTOR,
authEndpoint: "/api/v1/auth/notion/connector/add/",
authEndpoint: "/api/v1/auth/notion/connector/add",
},
{
id: "linear-connector",
title: "Linear",
description: "Search issues & projects",
description: "Search, read, and manage issues & projects",
connectorType: EnumConnectorName.LINEAR_CONNECTOR,
authEndpoint: "/api/v1/auth/linear/connector/add/",
authEndpoint: "/api/v1/auth/mcp/linear/connector/add/",
},
{
id: "slack-connector",
title: "Slack",
description: "Search Slack messages",
description: "Search and read channels and threads",
connectorType: EnumConnectorName.SLACK_CONNECTOR,
authEndpoint: "/api/v1/auth/slack/connector/add/",
authEndpoint: "/api/v1/auth/mcp/slack/connector/add/",
},
{
id: "teams-connector",
title: "Microsoft Teams",
description: "Search Teams messages",
description: "Search, read, and send messages",
connectorType: EnumConnectorName.TEAMS_CONNECTOR,
authEndpoint: "/api/v1/auth/teams/connector/add/",
},
@ -78,16 +97,16 @@ export const OAUTH_CONNECTORS = [
{
id: "discord-connector",
title: "Discord",
description: "Search Discord messages",
description: "Search, read, and send messages",
connectorType: EnumConnectorName.DISCORD_CONNECTOR,
authEndpoint: "/api/v1/auth/discord/connector/add/",
},
{
id: "jira-connector",
title: "Jira",
description: "Search Jira issues",
description: "Search, read, and manage issues",
connectorType: EnumConnectorName.JIRA_CONNECTOR,
authEndpoint: "/api/v1/auth/jira/connector/add/",
authEndpoint: "/api/v1/auth/mcp/jira/connector/add/",
},
{
id: "confluence-connector",
@ -99,9 +118,9 @@ export const OAUTH_CONNECTORS = [
{
id: "clickup-connector",
title: "ClickUp",
description: "Search ClickUp tasks",
description: "Search and read tasks",
connectorType: EnumConnectorName.CLICKUP_CONNECTOR,
authEndpoint: "/api/v1/auth/clickup/connector/add/",
authEndpoint: "/api/v1/auth/mcp/clickup/connector/add/",
},
] as const;
@ -138,7 +157,7 @@ export const OTHER_CONNECTORS = [
{
id: "luma-connector",
title: "Luma",
description: "Search Luma events",
description: "Browse, read, and create events",
connectorType: EnumConnectorName.LUMA_CONNECTOR,
},
{
@ -197,14 +216,14 @@ export const COMPOSIO_CONNECTORS = [
{
id: "composio-gmail",
title: "Gmail",
description: "Search through your emails via Composio",
description: "Search, read, draft, and send emails via Composio",
connectorType: EnumConnectorName.COMPOSIO_GMAIL_CONNECTOR,
authEndpoint: "/api/v1/auth/composio/connector/add/?toolkit_id=gmail",
},
{
id: "composio-googlecalendar",
title: "Google Calendar",
description: "Search through your events via Composio",
description: "Search and manage your events via Composio",
connectorType: EnumConnectorName.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
authEndpoint: "/api/v1/auth/composio/connector/add/?toolkit_id=googlecalendar",
},
@ -221,14 +240,14 @@ export const COMPOSIO_TOOLKITS = [
{
id: "gmail",
name: "Gmail",
description: "Search through your emails",
isIndexable: true,
description: "Search, read, draft, and send emails",
isIndexable: false,
},
{
id: "googlecalendar",
name: "Google Calendar",
description: "Search through your events",
isIndexable: true,
description: "Search and manage your events",
isIndexable: false,
},
{
id: "slack",
@ -258,66 +277,6 @@ export interface AutoIndexConfig {
}
export const AUTO_INDEX_DEFAULTS: Record<string, AutoIndexConfig> = {
[EnumConnectorName.GOOGLE_GMAIL_CONNECTOR]: {
daysBack: 30,
daysForward: 0,
frequencyMinutes: 1440,
syncDescription: "Syncing your last 30 days of emails.",
},
[EnumConnectorName.COMPOSIO_GMAIL_CONNECTOR]: {
daysBack: 30,
daysForward: 0,
frequencyMinutes: 1440,
syncDescription: "Syncing your last 30 days of emails.",
},
[EnumConnectorName.SLACK_CONNECTOR]: {
daysBack: 30,
daysForward: 0,
frequencyMinutes: 1440,
syncDescription: "Syncing your last 30 days of messages.",
},
[EnumConnectorName.DISCORD_CONNECTOR]: {
daysBack: 30,
daysForward: 0,
frequencyMinutes: 1440,
syncDescription: "Syncing your last 30 days of messages.",
},
[EnumConnectorName.TEAMS_CONNECTOR]: {
daysBack: 30,
daysForward: 0,
frequencyMinutes: 1440,
syncDescription: "Syncing your last 30 days of messages.",
},
[EnumConnectorName.GOOGLE_CALENDAR_CONNECTOR]: {
daysBack: 90,
daysForward: 90,
frequencyMinutes: 1440,
syncDescription: "Syncing 90 days of past and upcoming events.",
},
[EnumConnectorName.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR]: {
daysBack: 90,
daysForward: 90,
frequencyMinutes: 1440,
syncDescription: "Syncing 90 days of past and upcoming events.",
},
[EnumConnectorName.LINEAR_CONNECTOR]: {
daysBack: 90,
daysForward: 0,
frequencyMinutes: 1440,
syncDescription: "Syncing your last 90 days of issues.",
},
[EnumConnectorName.JIRA_CONNECTOR]: {
daysBack: 90,
daysForward: 0,
frequencyMinutes: 1440,
syncDescription: "Syncing your last 90 days of issues.",
},
[EnumConnectorName.CLICKUP_CONNECTOR]: {
daysBack: 90,
daysForward: 0,
frequencyMinutes: 1440,
syncDescription: "Syncing your last 90 days of tasks.",
},
[EnumConnectorName.NOTION_CONNECTOR]: {
daysBack: 365,
daysForward: 0,
@ -330,12 +289,6 @@ export const AUTO_INDEX_DEFAULTS: Record<string, AutoIndexConfig> = {
frequencyMinutes: 1440,
syncDescription: "Syncing your documentation.",
},
[EnumConnectorName.AIRTABLE_CONNECTOR]: {
daysBack: 365,
daysForward: 0,
frequencyMinutes: 1440,
syncDescription: "Syncing your bases.",
},
};
export const AUTO_INDEX_CONNECTOR_TYPES = new Set<string>(Object.keys(AUTO_INDEX_DEFAULTS));

View file

@ -38,6 +38,7 @@ import {
AUTO_INDEX_CONNECTOR_TYPES,
AUTO_INDEX_DEFAULTS,
COMPOSIO_CONNECTORS,
LIVE_CONNECTOR_TYPES,
OAUTH_CONNECTORS,
OTHER_CONNECTORS,
} from "../constants/connector-constants";
@ -317,7 +318,12 @@ export const useConnectorDialog = () => {
newConnector.id
);
if (
const isLiveConnector = LIVE_CONNECTOR_TYPES.has(oauthConnector.connectorType);
if (isLiveConnector) {
toast.success(`${oauthConnector.title} connected successfully!`);
await refetchAllConnectors();
} else if (
newConnector.is_indexable &&
AUTO_INDEX_CONNECTOR_TYPES.has(oauthConnector.connectorType)
) {
@ -326,6 +332,9 @@ export const useConnectorDialog = () => {
oauthConnector.title,
oauthConnector.connectorType
);
} else if (!newConnector.is_indexable) {
toast.success(`${oauthConnector.title} connected successfully!`);
await refetchAllConnectors();
} else {
toast.dismiss("auto-index");
const config = validateIndexingConfigState({

View file

@ -9,7 +9,7 @@ import { getConnectorIcon } from "@/contracts/enums/connectorIcons";
import type { SearchSourceConnector } from "@/contracts/types/connector.types";
import { getDocumentTypeLabel } from "@/lib/documents/document-type-labels";
import { cn } from "@/lib/utils";
import { COMPOSIO_CONNECTORS, OAUTH_CONNECTORS } from "../constants/connector-constants";
import { COMPOSIO_CONNECTORS, LIVE_CONNECTOR_TYPES, OAUTH_CONNECTORS } from "../constants/connector-constants";
import { getDocumentCountForConnector } from "../utils/connector-document-mapping";
import { getConnectorDisplayName } from "./all-connectors-tab";
@ -156,6 +156,7 @@ export const ActiveConnectorsTab: FC<ActiveConnectorsTabProps> = ({
{/* OAuth Connectors - Grouped by Type */}
{filteredOAuthConnectorTypes.map(([connectorType, typeConnectors]) => {
const { title } = getOAuthConnectorTypeInfo(connectorType);
const isLive = LIVE_CONNECTOR_TYPES.has(connectorType);
const isAnyIndexing = typeConnectors.some((c: SearchSourceConnector) =>
indexingConnectorIds.has(c.id)
);
@ -202,8 +203,12 @@ export const ActiveConnectorsTab: FC<ActiveConnectorsTabProps> = ({
</p>
) : (
<p className="text-[10px] text-muted-foreground mt-1 flex items-center gap-1.5">
<span>{formatDocumentCount(documentCount)}</span>
<span className="text-muted-foreground/50"></span>
{!isLive && (
<>
<span>{formatDocumentCount(documentCount)}</span>
<span className="text-muted-foreground/50"></span>
</>
)}
<span>
{accountCount} {accountCount === 1 ? "Account" : "Accounts"}
</span>
@ -230,6 +235,7 @@ export const ActiveConnectorsTab: FC<ActiveConnectorsTabProps> = ({
documentTypeCounts
);
const isMCPConnector = connector.connector_type === "MCP_CONNECTOR";
const isLive = LIVE_CONNECTOR_TYPES.has(connector.connector_type);
return (
<div
key={`connector-${connector.id}`}
@ -261,7 +267,7 @@ export const ActiveConnectorsTab: FC<ActiveConnectorsTabProps> = ({
<Spinner size="xs" />
Syncing
</p>
) : !isMCPConnector ? (
) : !isLive && !isMCPConnector ? (
<p className="text-[10px] text-muted-foreground mt-1">
{formatDocumentCount(documentCount)}
</p>

View file

@ -13,6 +13,7 @@ import type { SearchSourceConnector } from "@/contracts/types/connector.types";
import { authenticatedFetch } from "@/lib/auth-utils";
import { formatRelativeDate } from "@/lib/format-date";
import { cn } from "@/lib/utils";
import { LIVE_CONNECTOR_TYPES } from "../constants/connector-constants";
import { useConnectorStatus } from "../hooks/use-connector-status";
import { getConnectorDisplayName } from "../tabs/all-connectors-tab";
@ -43,12 +44,8 @@ interface ConnectorAccountsListViewProps {
addButtonText?: string;
}
/**
* Check if a connector type is indexable
*/
function isIndexableConnector(connectorType: string): boolean {
const nonIndexableTypes = ["MCP_CONNECTOR"];
return !nonIndexableTypes.includes(connectorType);
function isLiveConnector(connectorType: string): boolean {
return LIVE_CONNECTOR_TYPES.has(connectorType) || connectorType === "MCP_CONNECTOR";
}
export const ConnectorAccountsListView: FC<ConnectorAccountsListViewProps> = ({
@ -149,7 +146,7 @@ export const ConnectorAccountsListView: FC<ConnectorAccountsListViewProps> = ({
{connectorTitle}
</h2>
<p className="text-xs sm:text-base text-muted-foreground mt-1">
{statusMessage || "Manage your connector settings and sync configuration"}
{statusMessage || "Manage your connected accounts"}
</p>
</div>
</div>
@ -234,15 +231,13 @@ export const ConnectorAccountsListView: FC<ConnectorAccountsListViewProps> = ({
<Spinner size="xs" />
Syncing
</p>
) : (
<p className="text-[10px] text-muted-foreground mt-1 whitespace-nowrap truncate">
{isIndexableConnector(connector.connector_type)
? connector.last_indexed_at
? `Last indexed: ${formatRelativeDate(connector.last_indexed_at)}`
: "Never indexed"
: "Active"}
) : !isLiveConnector(connector.connector_type) ? (
<p className="text-[10px] mt-1 whitespace-nowrap truncate text-muted-foreground">
{connector.last_indexed_at
? `Last indexed: ${formatRelativeDate(connector.last_indexed_at)}`
: "Never indexed"}
</p>
)}
) : null}
</div>
{isAuthExpired ? (
<Button