mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-04-25 00:36:31 +02:00
Merge pull request #1294 from CREDO23/feature/mcp-migration
Some checks are pending
Build and Push Docker Images / tag_release (push) Waiting to run
Build and Push Docker Images / build (./surfsense_backend, ./surfsense_backend/Dockerfile, backend, surfsense-backend, ubuntu-24.04-arm, linux/arm64, arm64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_backend, ./surfsense_backend/Dockerfile, backend, surfsense-backend, ubuntu-latest, linux/amd64, amd64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_web, ./surfsense_web/Dockerfile, web, surfsense-web, ubuntu-24.04-arm, linux/arm64, arm64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_web, ./surfsense_web/Dockerfile, web, surfsense-web, ubuntu-latest, linux/amd64, amd64) (push) Blocked by required conditions
Build and Push Docker Images / create_manifest (backend, surfsense-backend) (push) Blocked by required conditions
Build and Push Docker Images / create_manifest (web, surfsense-web) (push) Blocked by required conditions
Some checks are pending
Build and Push Docker Images / tag_release (push) Waiting to run
Build and Push Docker Images / build (./surfsense_backend, ./surfsense_backend/Dockerfile, backend, surfsense-backend, ubuntu-24.04-arm, linux/arm64, arm64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_backend, ./surfsense_backend/Dockerfile, backend, surfsense-backend, ubuntu-latest, linux/amd64, amd64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_web, ./surfsense_web/Dockerfile, web, surfsense-web, ubuntu-24.04-arm, linux/arm64, arm64) (push) Blocked by required conditions
Build and Push Docker Images / build (./surfsense_web, ./surfsense_web/Dockerfile, web, surfsense-web, ubuntu-latest, linux/amd64, amd64) (push) Blocked by required conditions
Build and Push Docker Images / create_manifest (backend, surfsense-backend) (push) Blocked by required conditions
Build and Push Docker Images / create_manifest (web, surfsense-web) (push) Blocked by required conditions
[FEAT] Live connector tools via MCP OAuth and native APIs
This commit is contained in:
commit
7245ab4046
65 changed files with 4175 additions and 1463 deletions
|
|
@ -47,7 +47,7 @@ from app.agents.new_chat.system_prompt import (
|
|||
build_configurable_system_prompt,
|
||||
build_surfsense_system_prompt,
|
||||
)
|
||||
from app.agents.new_chat.tools.registry import build_tools_async
|
||||
from app.agents.new_chat.tools.registry import build_tools_async, get_connector_gated_tools
|
||||
from app.db import ChatVisibility
|
||||
from app.services.connector_service import ConnectorService
|
||||
from app.utils.perf import get_perf_logger
|
||||
|
|
@ -287,105 +287,10 @@ async def create_surfsense_deep_agent(
|
|||
"llm": llm,
|
||||
}
|
||||
|
||||
# Disable Notion action tools if no Notion connector is configured
|
||||
modified_disabled_tools = list(disabled_tools) if disabled_tools else []
|
||||
has_notion_connector = (
|
||||
available_connectors is not None and "NOTION_CONNECTOR" in available_connectors
|
||||
modified_disabled_tools.extend(
|
||||
get_connector_gated_tools(available_connectors)
|
||||
)
|
||||
if not has_notion_connector:
|
||||
notion_tools = [
|
||||
"create_notion_page",
|
||||
"update_notion_page",
|
||||
"delete_notion_page",
|
||||
]
|
||||
modified_disabled_tools.extend(notion_tools)
|
||||
|
||||
# Disable Linear action tools if no Linear connector is configured
|
||||
has_linear_connector = (
|
||||
available_connectors is not None and "LINEAR_CONNECTOR" in available_connectors
|
||||
)
|
||||
if not has_linear_connector:
|
||||
linear_tools = [
|
||||
"create_linear_issue",
|
||||
"update_linear_issue",
|
||||
"delete_linear_issue",
|
||||
]
|
||||
modified_disabled_tools.extend(linear_tools)
|
||||
|
||||
# Disable Google Drive action tools if no Google Drive connector is configured
|
||||
has_google_drive_connector = (
|
||||
available_connectors is not None and "GOOGLE_DRIVE_FILE" in available_connectors
|
||||
)
|
||||
if not has_google_drive_connector:
|
||||
google_drive_tools = [
|
||||
"create_google_drive_file",
|
||||
"delete_google_drive_file",
|
||||
]
|
||||
modified_disabled_tools.extend(google_drive_tools)
|
||||
|
||||
has_dropbox_connector = (
|
||||
available_connectors is not None and "DROPBOX_FILE" in available_connectors
|
||||
)
|
||||
if not has_dropbox_connector:
|
||||
modified_disabled_tools.extend(["create_dropbox_file", "delete_dropbox_file"])
|
||||
|
||||
has_onedrive_connector = (
|
||||
available_connectors is not None and "ONEDRIVE_FILE" in available_connectors
|
||||
)
|
||||
if not has_onedrive_connector:
|
||||
modified_disabled_tools.extend(["create_onedrive_file", "delete_onedrive_file"])
|
||||
|
||||
# Disable Google Calendar action tools if no Google Calendar connector is configured
|
||||
has_google_calendar_connector = (
|
||||
available_connectors is not None
|
||||
and "GOOGLE_CALENDAR_CONNECTOR" in available_connectors
|
||||
)
|
||||
if not has_google_calendar_connector:
|
||||
calendar_tools = [
|
||||
"create_calendar_event",
|
||||
"update_calendar_event",
|
||||
"delete_calendar_event",
|
||||
]
|
||||
modified_disabled_tools.extend(calendar_tools)
|
||||
|
||||
# Disable Gmail action tools if no Gmail connector is configured
|
||||
has_gmail_connector = (
|
||||
available_connectors is not None
|
||||
and "GOOGLE_GMAIL_CONNECTOR" in available_connectors
|
||||
)
|
||||
if not has_gmail_connector:
|
||||
gmail_tools = [
|
||||
"create_gmail_draft",
|
||||
"update_gmail_draft",
|
||||
"send_gmail_email",
|
||||
"trash_gmail_email",
|
||||
]
|
||||
modified_disabled_tools.extend(gmail_tools)
|
||||
|
||||
# Disable Jira action tools if no Jira connector is configured
|
||||
has_jira_connector = (
|
||||
available_connectors is not None and "JIRA_CONNECTOR" in available_connectors
|
||||
)
|
||||
if not has_jira_connector:
|
||||
jira_tools = [
|
||||
"create_jira_issue",
|
||||
"update_jira_issue",
|
||||
"delete_jira_issue",
|
||||
]
|
||||
modified_disabled_tools.extend(jira_tools)
|
||||
|
||||
# Disable Confluence action tools if no Confluence connector is configured
|
||||
has_confluence_connector = (
|
||||
available_connectors is not None
|
||||
and "CONFLUENCE_CONNECTOR" in available_connectors
|
||||
)
|
||||
if not has_confluence_connector:
|
||||
confluence_tools = [
|
||||
"create_confluence_page",
|
||||
"update_confluence_page",
|
||||
"delete_confluence_page",
|
||||
]
|
||||
modified_disabled_tools.extend(confluence_tools)
|
||||
|
||||
# Remove direct KB search tool; we now pre-seed a scoped filesystem via middleware.
|
||||
if "search_knowledge_base" not in modified_disabled_tools:
|
||||
|
|
|
|||
|
|
@ -38,8 +38,66 @@ CRITICAL RULE — KNOWLEDGE BASE FIRST, NEVER DEFAULT TO GENERAL KNOWLEDGE:
|
|||
* Formatting, summarization, or analysis of content already present in the conversation
|
||||
* Following user instructions that are clearly task-oriented (e.g., "rewrite this in bullet points")
|
||||
* Tool-usage actions like generating reports, podcasts, images, or scraping webpages
|
||||
* Queries about services that have direct tools (Linear, ClickUp, Jira, Slack, Airtable) — see <tool_routing> below
|
||||
</knowledge_base_only_policy>
|
||||
|
||||
<tool_routing>
|
||||
CRITICAL — You have direct tools for these services: Linear, ClickUp, Jira, Slack, Airtable.
|
||||
Their data is NEVER in the knowledge base. You MUST call their tools immediately — never
|
||||
say "I don't see it in the knowledge base" or ask the user if they want you to check.
|
||||
Ignore any knowledge base results for these services.
|
||||
|
||||
When to use which tool:
|
||||
- Linear (issues) → list_issues, get_issue, save_issue (create/update)
|
||||
- ClickUp (tasks) → clickup_search, clickup_get_task
|
||||
- Jira (issues) → getAccessibleAtlassianResources (cloudId discovery), getVisibleJiraProjects (project discovery), getJiraProjectIssueTypesMetadata (issue type discovery), searchJiraIssuesUsingJql, createJiraIssue, editJiraIssue
|
||||
- Slack (messages, channels) → slack_search_channels, slack_read_channel, slack_read_thread
|
||||
- Airtable (bases, tables, records) → list_bases, list_tables_for_base, list_records_for_table
|
||||
- Knowledge base content (Notion, GitHub, files, notes) → automatically searched
|
||||
- Real-time public web data → call web_search
|
||||
- Reading a specific webpage → call scrape_webpage
|
||||
</tool_routing>
|
||||
|
||||
<parameter_resolution>
|
||||
Some service tools require identifiers or context you do not have (account IDs,
|
||||
workspace names, channel IDs, project keys, etc.). NEVER ask the user for raw
|
||||
IDs or technical identifiers — they cannot memorise them.
|
||||
|
||||
Instead, follow this discovery pattern:
|
||||
1. Call a listing/discovery tool to find available options.
|
||||
2. ONE result → use it silently, no question to the user.
|
||||
3. MULTIPLE results → present the options by their display names and let the
|
||||
user choose. Never show raw UUIDs — always use friendly names.
|
||||
|
||||
Discovery tools by level:
|
||||
- Which account/workspace? → get_connected_accounts("<service>")
|
||||
- Which Jira site (cloudId)? → getAccessibleAtlassianResources
|
||||
- Which Jira project? → getVisibleJiraProjects (after resolving cloudId)
|
||||
- Which Jira issue type? → getJiraProjectIssueTypesMetadata (after resolving project)
|
||||
- Which channel? → slack_search_channels
|
||||
- Which base? → list_bases
|
||||
- Which table? → list_tables_for_base (after resolving baseId)
|
||||
- Which task? → clickup_search
|
||||
- Which issue? → list_issues (Linear) or searchJiraIssuesUsingJql (Jira)
|
||||
|
||||
For Jira specifically: ALWAYS call getAccessibleAtlassianResources first to
|
||||
obtain the cloudId, then pass it to other Jira tools. When creating an issue,
|
||||
chain: getAccessibleAtlassianResources → getVisibleJiraProjects → createJiraIssue.
|
||||
If there is only one option at each step, use it silently. If multiple, present
|
||||
friendly names.
|
||||
|
||||
Chain discovery when needed — e.g. for Airtable records: list_bases → pick
|
||||
base → list_tables_for_base → pick table → list_records_for_table.
|
||||
|
||||
MULTI-ACCOUNT TOOL NAMING: When the user has multiple accounts connected for
|
||||
the same service, tool names are prefixed to avoid collisions — e.g.
|
||||
linear_25_list_issues and linear_30_list_issues instead of two list_issues.
|
||||
Each prefixed tool's description starts with [Account: <display_name>] so you
|
||||
know which account it targets. Use get_connected_accounts("<service>") to see
|
||||
the full list of accounts with their connector IDs and display names.
|
||||
When only one account is connected, tools have their normal unprefixed names.
|
||||
</parameter_resolution>
|
||||
|
||||
<memory_protocol>
|
||||
IMPORTANT — After understanding each user message, ALWAYS check: does this message
|
||||
reveal durable facts about the user (role, interests, preferences, projects,
|
||||
|
|
@ -76,8 +134,66 @@ CRITICAL RULE — KNOWLEDGE BASE FIRST, NEVER DEFAULT TO GENERAL KNOWLEDGE:
|
|||
* Formatting, summarization, or analysis of content already present in the conversation
|
||||
* Following user instructions that are clearly task-oriented (e.g., "rewrite this in bullet points")
|
||||
* Tool-usage actions like generating reports, podcasts, images, or scraping webpages
|
||||
* Queries about services that have direct tools (Linear, ClickUp, Jira, Slack, Airtable) — see <tool_routing> below
|
||||
</knowledge_base_only_policy>
|
||||
|
||||
<tool_routing>
|
||||
CRITICAL — You have direct tools for these services: Linear, ClickUp, Jira, Slack, Airtable.
|
||||
Their data is NEVER in the knowledge base. You MUST call their tools immediately — never
|
||||
say "I don't see it in the knowledge base" or ask if they want you to check.
|
||||
Ignore any knowledge base results for these services.
|
||||
|
||||
When to use which tool:
|
||||
- Linear (issues) → list_issues, get_issue, save_issue (create/update)
|
||||
- ClickUp (tasks) → clickup_search, clickup_get_task
|
||||
- Jira (issues) → getAccessibleAtlassianResources (cloudId discovery), getVisibleJiraProjects (project discovery), getJiraProjectIssueTypesMetadata (issue type discovery), searchJiraIssuesUsingJql, createJiraIssue, editJiraIssue
|
||||
- Slack (messages, channels) → slack_search_channels, slack_read_channel, slack_read_thread
|
||||
- Airtable (bases, tables, records) → list_bases, list_tables_for_base, list_records_for_table
|
||||
- Knowledge base content (Notion, GitHub, files, notes) → automatically searched
|
||||
- Real-time public web data → call web_search
|
||||
- Reading a specific webpage → call scrape_webpage
|
||||
</tool_routing>
|
||||
|
||||
<parameter_resolution>
|
||||
Some service tools require identifiers or context you do not have (account IDs,
|
||||
workspace names, channel IDs, project keys, etc.). NEVER ask the user for raw
|
||||
IDs or technical identifiers — they cannot memorise them.
|
||||
|
||||
Instead, follow this discovery pattern:
|
||||
1. Call a listing/discovery tool to find available options.
|
||||
2. ONE result → use it silently, no question to the user.
|
||||
3. MULTIPLE results → present the options by their display names and let the
|
||||
user choose. Never show raw UUIDs — always use friendly names.
|
||||
|
||||
Discovery tools by level:
|
||||
- Which account/workspace? → get_connected_accounts("<service>")
|
||||
- Which Jira site (cloudId)? → getAccessibleAtlassianResources
|
||||
- Which Jira project? → getVisibleJiraProjects (after resolving cloudId)
|
||||
- Which Jira issue type? → getJiraProjectIssueTypesMetadata (after resolving project)
|
||||
- Which channel? → slack_search_channels
|
||||
- Which base? → list_bases
|
||||
- Which table? → list_tables_for_base (after resolving baseId)
|
||||
- Which task? → clickup_search
|
||||
- Which issue? → list_issues (Linear) or searchJiraIssuesUsingJql (Jira)
|
||||
|
||||
For Jira specifically: ALWAYS call getAccessibleAtlassianResources first to
|
||||
obtain the cloudId, then pass it to other Jira tools. When creating an issue,
|
||||
chain: getAccessibleAtlassianResources → getVisibleJiraProjects → createJiraIssue.
|
||||
If there is only one option at each step, use it silently. If multiple, present
|
||||
friendly names.
|
||||
|
||||
Chain discovery when needed — e.g. for Airtable records: list_bases → pick
|
||||
base → list_tables_for_base → pick table → list_records_for_table.
|
||||
|
||||
MULTI-ACCOUNT TOOL NAMING: When the user has multiple accounts connected for
|
||||
the same service, tool names are prefixed to avoid collisions — e.g.
|
||||
linear_25_list_issues and linear_30_list_issues instead of two list_issues.
|
||||
Each prefixed tool's description starts with [Account: <display_name>] so you
|
||||
know which account it targets. Use get_connected_accounts("<service>") to see
|
||||
the full list of accounts with their connector IDs and display names.
|
||||
When only one account is connected, tools have their normal unprefixed names.
|
||||
</parameter_resolution>
|
||||
|
||||
<memory_protocol>
|
||||
IMPORTANT — After understanding each user message, ALWAYS check: does this message
|
||||
reveal durable facts about the team (decisions, conventions, architecture, processes,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,109 @@
|
|||
"""Connected-accounts discovery tool.
|
||||
|
||||
Lets the LLM discover which accounts are connected for a given service
|
||||
(e.g. "jira", "linear", "slack") and retrieve the metadata it needs to
|
||||
call action tools — such as Jira's ``cloudId``.
|
||||
|
||||
The tool returns **only** non-sensitive fields explicitly listed in the
|
||||
service's ``account_metadata_keys`` (see ``registry.py``), plus the
|
||||
always-present ``display_name`` and ``connector_id``.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import StructuredTool
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
from app.services.mcp_oauth.registry import MCP_SERVICES
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_SERVICE_KEY_BY_CONNECTOR_TYPE: dict[str, str] = {
|
||||
cfg.connector_type: key for key, cfg in MCP_SERVICES.items()
|
||||
}
|
||||
|
||||
|
||||
class GetConnectedAccountsInput(BaseModel):
|
||||
service: str = Field(
|
||||
description=(
|
||||
"Service key to look up connected accounts for. "
|
||||
"Valid values: " + ", ".join(sorted(MCP_SERVICES.keys()))
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _extract_display_name(connector: SearchSourceConnector) -> str:
|
||||
"""Best-effort human-readable label for a connector."""
|
||||
cfg = connector.config or {}
|
||||
if cfg.get("display_name"):
|
||||
return cfg["display_name"]
|
||||
if cfg.get("base_url"):
|
||||
return f"{connector.name} ({cfg['base_url']})"
|
||||
if cfg.get("organization_name"):
|
||||
return f"{connector.name} ({cfg['organization_name']})"
|
||||
return connector.name
|
||||
|
||||
|
||||
def create_get_connected_accounts_tool(
|
||||
db_session: AsyncSession,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
) -> StructuredTool:
|
||||
|
||||
async def _run(service: str) -> list[dict[str, Any]]:
|
||||
svc_cfg = MCP_SERVICES.get(service)
|
||||
if not svc_cfg:
|
||||
return [{"error": f"Unknown service '{service}'. Valid: {', '.join(sorted(MCP_SERVICES.keys()))}"}]
|
||||
|
||||
try:
|
||||
connector_type = SearchSourceConnectorType(svc_cfg.connector_type)
|
||||
except ValueError:
|
||||
return [{"error": f"Connector type '{svc_cfg.connector_type}' not found."}]
|
||||
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type == connector_type,
|
||||
)
|
||||
)
|
||||
connectors = result.scalars().all()
|
||||
|
||||
if not connectors:
|
||||
return [{"error": f"No {svc_cfg.name} accounts connected. Ask the user to connect one in settings."}]
|
||||
|
||||
is_multi = len(connectors) > 1
|
||||
|
||||
accounts: list[dict[str, Any]] = []
|
||||
for conn in connectors:
|
||||
cfg = conn.config or {}
|
||||
entry: dict[str, Any] = {
|
||||
"connector_id": conn.id,
|
||||
"display_name": _extract_display_name(conn),
|
||||
"service": service,
|
||||
}
|
||||
if is_multi:
|
||||
entry["tool_prefix"] = f"{service}_{conn.id}"
|
||||
for key in svc_cfg.account_metadata_keys:
|
||||
if key in cfg:
|
||||
entry[key] = cfg[key]
|
||||
accounts.append(entry)
|
||||
|
||||
return accounts
|
||||
|
||||
return StructuredTool(
|
||||
name="get_connected_accounts",
|
||||
description=(
|
||||
"Discover which accounts are connected for a service (e.g. jira, linear, slack, clickup, airtable). "
|
||||
"Returns display names and service-specific metadata the action tools need "
|
||||
"(e.g. Jira's cloudId). Call this BEFORE using a service's action tools when "
|
||||
"you need an account identifier or are unsure which account to use."
|
||||
),
|
||||
coroutine=_run,
|
||||
args_schema=GetConnectedAccountsInput,
|
||||
metadata={"hitl": False},
|
||||
)
|
||||
|
|
@ -0,0 +1,15 @@
|
|||
from app.agents.new_chat.tools.discord.list_channels import (
|
||||
create_list_discord_channels_tool,
|
||||
)
|
||||
from app.agents.new_chat.tools.discord.read_messages import (
|
||||
create_read_discord_messages_tool,
|
||||
)
|
||||
from app.agents.new_chat.tools.discord.send_message import (
|
||||
create_send_discord_message_tool,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"create_list_discord_channels_tool",
|
||||
"create_read_discord_messages_tool",
|
||||
"create_send_discord_message_tool",
|
||||
]
|
||||
42
surfsense_backend/app/agents/new_chat/tools/discord/_auth.py
Normal file
42
surfsense_backend/app/agents/new_chat/tools/discord/_auth.py
Normal file
|
|
@ -0,0 +1,42 @@
|
|||
"""Shared auth helper for Discord agent tools (REST API, not gateway bot)."""
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.config import config
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
from app.utils.oauth_security import TokenEncryption
|
||||
|
||||
DISCORD_API = "https://discord.com/api/v10"
|
||||
|
||||
|
||||
async def get_discord_connector(
|
||||
db_session: AsyncSession,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
) -> SearchSourceConnector | None:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type == SearchSourceConnectorType.DISCORD_CONNECTOR,
|
||||
)
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
|
||||
def get_bot_token(connector: SearchSourceConnector) -> str:
|
||||
"""Extract and decrypt the bot token from connector config."""
|
||||
cfg = dict(connector.config)
|
||||
if cfg.get("_token_encrypted") and config.SECRET_KEY:
|
||||
enc = TokenEncryption(config.SECRET_KEY)
|
||||
if cfg.get("bot_token"):
|
||||
cfg["bot_token"] = enc.decrypt_token(cfg["bot_token"])
|
||||
token = cfg.get("bot_token")
|
||||
if not token:
|
||||
raise ValueError("Discord bot token not found in connector config.")
|
||||
return token
|
||||
|
||||
|
||||
def get_guild_id(connector: SearchSourceConnector) -> str | None:
|
||||
return connector.config.get("guild_id")
|
||||
|
|
@ -0,0 +1,67 @@
|
|||
import logging
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from langchain_core.tools import tool
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from ._auth import DISCORD_API, get_bot_token, get_discord_connector, get_guild_id
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_list_discord_channels_tool(
|
||||
db_session: AsyncSession | None = None,
|
||||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
@tool
|
||||
async def list_discord_channels() -> dict[str, Any]:
|
||||
"""List text channels in the connected Discord server.
|
||||
|
||||
Returns:
|
||||
Dictionary with status and a list of channels (id, name).
|
||||
"""
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
return {"status": "error", "message": "Discord tool not properly configured."}
|
||||
|
||||
try:
|
||||
connector = await get_discord_connector(db_session, search_space_id, user_id)
|
||||
if not connector:
|
||||
return {"status": "error", "message": "No Discord connector found."}
|
||||
|
||||
guild_id = get_guild_id(connector)
|
||||
if not guild_id:
|
||||
return {"status": "error", "message": "No guild ID in Discord connector config."}
|
||||
|
||||
token = get_bot_token(connector)
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.get(
|
||||
f"{DISCORD_API}/guilds/{guild_id}/channels",
|
||||
headers={"Authorization": f"Bot {token}"},
|
||||
timeout=15.0,
|
||||
)
|
||||
|
||||
if resp.status_code == 401:
|
||||
return {"status": "auth_error", "message": "Discord bot token is invalid.", "connector_type": "discord"}
|
||||
if resp.status_code != 200:
|
||||
return {"status": "error", "message": f"Discord API error: {resp.status_code}"}
|
||||
|
||||
# Type 0 = text channel
|
||||
channels = [
|
||||
{"id": ch["id"], "name": ch["name"]}
|
||||
for ch in resp.json()
|
||||
if ch.get("type") == 0
|
||||
]
|
||||
return {"status": "success", "guild_id": guild_id, "channels": channels, "total": len(channels)}
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
||||
if isinstance(e, GraphInterrupt):
|
||||
raise
|
||||
logger.error("Error listing Discord channels: %s", e, exc_info=True)
|
||||
return {"status": "error", "message": "Failed to list Discord channels."}
|
||||
|
||||
return list_discord_channels
|
||||
|
|
@ -0,0 +1,80 @@
|
|||
import logging
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from langchain_core.tools import tool
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from ._auth import DISCORD_API, get_bot_token, get_discord_connector
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_read_discord_messages_tool(
|
||||
db_session: AsyncSession | None = None,
|
||||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
@tool
|
||||
async def read_discord_messages(
|
||||
channel_id: str,
|
||||
limit: int = 25,
|
||||
) -> dict[str, Any]:
|
||||
"""Read recent messages from a Discord text channel.
|
||||
|
||||
Args:
|
||||
channel_id: The Discord channel ID (from list_discord_channels).
|
||||
limit: Number of messages to fetch (default 25, max 50).
|
||||
|
||||
Returns:
|
||||
Dictionary with status and a list of messages including
|
||||
id, author, content, timestamp.
|
||||
"""
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
return {"status": "error", "message": "Discord tool not properly configured."}
|
||||
|
||||
limit = min(limit, 50)
|
||||
|
||||
try:
|
||||
connector = await get_discord_connector(db_session, search_space_id, user_id)
|
||||
if not connector:
|
||||
return {"status": "error", "message": "No Discord connector found."}
|
||||
|
||||
token = get_bot_token(connector)
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.get(
|
||||
f"{DISCORD_API}/channels/{channel_id}/messages",
|
||||
headers={"Authorization": f"Bot {token}"},
|
||||
params={"limit": limit},
|
||||
timeout=15.0,
|
||||
)
|
||||
|
||||
if resp.status_code == 401:
|
||||
return {"status": "auth_error", "message": "Discord bot token is invalid.", "connector_type": "discord"}
|
||||
if resp.status_code == 403:
|
||||
return {"status": "error", "message": "Bot lacks permission to read this channel."}
|
||||
if resp.status_code != 200:
|
||||
return {"status": "error", "message": f"Discord API error: {resp.status_code}"}
|
||||
|
||||
messages = [
|
||||
{
|
||||
"id": m["id"],
|
||||
"author": m.get("author", {}).get("username", "Unknown"),
|
||||
"content": m.get("content", ""),
|
||||
"timestamp": m.get("timestamp", ""),
|
||||
}
|
||||
for m in resp.json()
|
||||
]
|
||||
|
||||
return {"status": "success", "channel_id": channel_id, "messages": messages, "total": len(messages)}
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
||||
if isinstance(e, GraphInterrupt):
|
||||
raise
|
||||
logger.error("Error reading Discord messages: %s", e, exc_info=True)
|
||||
return {"status": "error", "message": "Failed to read Discord messages."}
|
||||
|
||||
return read_discord_messages
|
||||
|
|
@ -0,0 +1,96 @@
|
|||
import logging
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from langchain_core.tools import tool
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.new_chat.tools.hitl import request_approval
|
||||
|
||||
from ._auth import DISCORD_API, get_bot_token, get_discord_connector
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_send_discord_message_tool(
|
||||
db_session: AsyncSession | None = None,
|
||||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
@tool
|
||||
async def send_discord_message(
|
||||
channel_id: str,
|
||||
content: str,
|
||||
) -> dict[str, Any]:
|
||||
"""Send a message to a Discord text channel.
|
||||
|
||||
Args:
|
||||
channel_id: The Discord channel ID (from list_discord_channels).
|
||||
content: The message text (max 2000 characters).
|
||||
|
||||
Returns:
|
||||
Dictionary with status, message_id on success.
|
||||
|
||||
IMPORTANT:
|
||||
- If status is "rejected", the user explicitly declined. Do NOT retry.
|
||||
"""
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
return {"status": "error", "message": "Discord tool not properly configured."}
|
||||
|
||||
if len(content) > 2000:
|
||||
return {"status": "error", "message": "Message exceeds Discord's 2000-character limit."}
|
||||
|
||||
try:
|
||||
connector = await get_discord_connector(db_session, search_space_id, user_id)
|
||||
if not connector:
|
||||
return {"status": "error", "message": "No Discord connector found."}
|
||||
|
||||
result = request_approval(
|
||||
action_type="discord_send_message",
|
||||
tool_name="send_discord_message",
|
||||
params={"channel_id": channel_id, "content": content},
|
||||
context={"connector_id": connector.id},
|
||||
)
|
||||
|
||||
if result.rejected:
|
||||
return {"status": "rejected", "message": "User declined. Message was not sent."}
|
||||
|
||||
final_content = result.params.get("content", content)
|
||||
final_channel = result.params.get("channel_id", channel_id)
|
||||
|
||||
token = get_bot_token(connector)
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.post(
|
||||
f"{DISCORD_API}/channels/{final_channel}/messages",
|
||||
headers={
|
||||
"Authorization": f"Bot {token}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
json={"content": final_content},
|
||||
timeout=15.0,
|
||||
)
|
||||
|
||||
if resp.status_code == 401:
|
||||
return {"status": "auth_error", "message": "Discord bot token is invalid.", "connector_type": "discord"}
|
||||
if resp.status_code == 403:
|
||||
return {"status": "error", "message": "Bot lacks permission to send messages in this channel."}
|
||||
if resp.status_code not in (200, 201):
|
||||
return {"status": "error", "message": f"Discord API error: {resp.status_code}"}
|
||||
|
||||
msg_data = resp.json()
|
||||
return {
|
||||
"status": "success",
|
||||
"message_id": msg_data.get("id"),
|
||||
"message": f"Message sent to channel {final_channel}.",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
||||
if isinstance(e, GraphInterrupt):
|
||||
raise
|
||||
logger.error("Error sending Discord message: %s", e, exc_info=True)
|
||||
return {"status": "error", "message": "Failed to send Discord message."}
|
||||
|
||||
return send_discord_message
|
||||
|
|
@ -1,6 +1,12 @@
|
|||
from app.agents.new_chat.tools.gmail.create_draft import (
|
||||
create_create_gmail_draft_tool,
|
||||
)
|
||||
from app.agents.new_chat.tools.gmail.read_email import (
|
||||
create_read_gmail_email_tool,
|
||||
)
|
||||
from app.agents.new_chat.tools.gmail.search_emails import (
|
||||
create_search_gmail_tool,
|
||||
)
|
||||
from app.agents.new_chat.tools.gmail.send_email import (
|
||||
create_send_gmail_email_tool,
|
||||
)
|
||||
|
|
@ -13,6 +19,8 @@ from app.agents.new_chat.tools.gmail.update_draft import (
|
|||
|
||||
__all__ = [
|
||||
"create_create_gmail_draft_tool",
|
||||
"create_read_gmail_email_tool",
|
||||
"create_search_gmail_tool",
|
||||
"create_send_gmail_email_tool",
|
||||
"create_trash_gmail_email_tool",
|
||||
"create_update_gmail_draft_tool",
|
||||
|
|
|
|||
|
|
@ -0,0 +1,87 @@
|
|||
import logging
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_GMAIL_TYPES = [
|
||||
SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
|
||||
SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
|
||||
]
|
||||
|
||||
|
||||
def create_read_gmail_email_tool(
|
||||
db_session: AsyncSession | None = None,
|
||||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
@tool
|
||||
async def read_gmail_email(message_id: str) -> dict[str, Any]:
|
||||
"""Read the full content of a specific Gmail email by its message ID.
|
||||
|
||||
Use after search_gmail to get the complete body of an email.
|
||||
|
||||
Args:
|
||||
message_id: The Gmail message ID (from search_gmail results).
|
||||
|
||||
Returns:
|
||||
Dictionary with status and the full email content formatted as markdown.
|
||||
"""
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
return {"status": "error", "message": "Gmail tool not properly configured."}
|
||||
|
||||
try:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type.in_(_GMAIL_TYPES),
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
|
||||
}
|
||||
|
||||
from app.agents.new_chat.tools.gmail.search_emails import _build_credentials
|
||||
|
||||
creds = _build_credentials(connector)
|
||||
|
||||
from app.connectors.google_gmail_connector import GoogleGmailConnector
|
||||
|
||||
gmail = GoogleGmailConnector(
|
||||
credentials=creds,
|
||||
session=db_session,
|
||||
user_id=user_id,
|
||||
connector_id=connector.id,
|
||||
)
|
||||
|
||||
detail, error = await gmail.get_message_details(message_id)
|
||||
if error:
|
||||
if "re-authenticate" in error.lower() or "authentication failed" in error.lower():
|
||||
return {"status": "auth_error", "message": error, "connector_type": "gmail"}
|
||||
return {"status": "error", "message": error}
|
||||
|
||||
if not detail:
|
||||
return {"status": "not_found", "message": f"Email with ID '{message_id}' not found."}
|
||||
|
||||
content = gmail.format_message_to_markdown(detail)
|
||||
|
||||
return {"status": "success", "message_id": message_id, "content": content}
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
||||
if isinstance(e, GraphInterrupt):
|
||||
raise
|
||||
logger.error("Error reading Gmail email: %s", e, exc_info=True)
|
||||
return {"status": "error", "message": "Failed to read email. Please try again."}
|
||||
|
||||
return read_gmail_email
|
||||
|
|
@ -0,0 +1,165 @@
|
|||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_GMAIL_TYPES = [
|
||||
SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
|
||||
SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
|
||||
]
|
||||
|
||||
_token_encryption_cache: object | None = None
|
||||
|
||||
|
||||
def _get_token_encryption():
|
||||
global _token_encryption_cache
|
||||
if _token_encryption_cache is None:
|
||||
from app.config import config
|
||||
from app.utils.oauth_security import TokenEncryption
|
||||
|
||||
if not config.SECRET_KEY:
|
||||
raise RuntimeError("SECRET_KEY not configured for token decryption.")
|
||||
_token_encryption_cache = TokenEncryption(config.SECRET_KEY)
|
||||
return _token_encryption_cache
|
||||
|
||||
|
||||
def _build_credentials(connector: SearchSourceConnector):
|
||||
"""Build Google OAuth Credentials from a connector's stored config.
|
||||
|
||||
Handles both native OAuth connectors (with encrypted tokens) and
|
||||
Composio-backed connectors. Shared by Gmail and Calendar tools.
|
||||
"""
|
||||
from app.utils.google_credentials import COMPOSIO_GOOGLE_CONNECTOR_TYPES
|
||||
|
||||
if connector.connector_type in COMPOSIO_GOOGLE_CONNECTOR_TYPES:
|
||||
from app.utils.google_credentials import build_composio_credentials
|
||||
|
||||
cca_id = connector.config.get("composio_connected_account_id")
|
||||
if not cca_id:
|
||||
raise ValueError("Composio connected account ID not found.")
|
||||
return build_composio_credentials(cca_id)
|
||||
|
||||
from google.oauth2.credentials import Credentials
|
||||
|
||||
cfg = dict(connector.config)
|
||||
if cfg.get("_token_encrypted"):
|
||||
enc = _get_token_encryption()
|
||||
for key in ("token", "refresh_token", "client_secret"):
|
||||
if cfg.get(key):
|
||||
cfg[key] = enc.decrypt_token(cfg[key])
|
||||
|
||||
exp = (cfg.get("expiry") or "").replace("Z", "")
|
||||
return Credentials(
|
||||
token=cfg.get("token"),
|
||||
refresh_token=cfg.get("refresh_token"),
|
||||
token_uri=cfg.get("token_uri"),
|
||||
client_id=cfg.get("client_id"),
|
||||
client_secret=cfg.get("client_secret"),
|
||||
scopes=cfg.get("scopes", []),
|
||||
expiry=datetime.fromisoformat(exp) if exp else None,
|
||||
)
|
||||
|
||||
|
||||
def create_search_gmail_tool(
|
||||
db_session: AsyncSession | None = None,
|
||||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
@tool
|
||||
async def search_gmail(
|
||||
query: str,
|
||||
max_results: int = 10,
|
||||
) -> dict[str, Any]:
|
||||
"""Search emails in the user's Gmail inbox using Gmail search syntax.
|
||||
|
||||
Args:
|
||||
query: Gmail search query, same syntax as the Gmail search bar.
|
||||
Examples: "from:alice@example.com", "subject:meeting",
|
||||
"is:unread", "after:2024/01/01 before:2024/02/01",
|
||||
"has:attachment", "in:sent".
|
||||
max_results: Number of emails to return (default 10, max 20).
|
||||
|
||||
Returns:
|
||||
Dictionary with status and a list of email summaries including
|
||||
message_id, subject, from, date, snippet.
|
||||
"""
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
return {"status": "error", "message": "Gmail tool not properly configured."}
|
||||
|
||||
max_results = min(max_results, 20)
|
||||
|
||||
try:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type.in_(_GMAIL_TYPES),
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No Gmail connector found. Please connect Gmail in your workspace settings.",
|
||||
}
|
||||
|
||||
creds = _build_credentials(connector)
|
||||
|
||||
from app.connectors.google_gmail_connector import GoogleGmailConnector
|
||||
|
||||
gmail = GoogleGmailConnector(
|
||||
credentials=creds,
|
||||
session=db_session,
|
||||
user_id=user_id,
|
||||
connector_id=connector.id,
|
||||
)
|
||||
|
||||
messages_list, error = await gmail.get_messages_list(
|
||||
max_results=max_results, query=query
|
||||
)
|
||||
if error:
|
||||
if "re-authenticate" in error.lower() or "authentication failed" in error.lower():
|
||||
return {"status": "auth_error", "message": error, "connector_type": "gmail"}
|
||||
return {"status": "error", "message": error}
|
||||
|
||||
if not messages_list:
|
||||
return {"status": "success", "emails": [], "total": 0, "message": "No emails found."}
|
||||
|
||||
emails = []
|
||||
for msg in messages_list:
|
||||
detail, err = await gmail.get_message_details(msg["id"])
|
||||
if err:
|
||||
continue
|
||||
headers = {
|
||||
h["name"].lower(): h["value"]
|
||||
for h in detail.get("payload", {}).get("headers", [])
|
||||
}
|
||||
emails.append({
|
||||
"message_id": detail.get("id"),
|
||||
"thread_id": detail.get("threadId"),
|
||||
"subject": headers.get("subject", "No Subject"),
|
||||
"from": headers.get("from", "Unknown"),
|
||||
"to": headers.get("to", ""),
|
||||
"date": headers.get("date", ""),
|
||||
"snippet": detail.get("snippet", ""),
|
||||
"labels": detail.get("labelIds", []),
|
||||
})
|
||||
|
||||
return {"status": "success", "emails": emails, "total": len(emails)}
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
||||
if isinstance(e, GraphInterrupt):
|
||||
raise
|
||||
logger.error("Error searching Gmail: %s", e, exc_info=True)
|
||||
return {"status": "error", "message": "Failed to search Gmail. Please try again."}
|
||||
|
||||
return search_gmail
|
||||
|
|
@ -4,6 +4,9 @@ from app.agents.new_chat.tools.google_calendar.create_event import (
|
|||
from app.agents.new_chat.tools.google_calendar.delete_event import (
|
||||
create_delete_calendar_event_tool,
|
||||
)
|
||||
from app.agents.new_chat.tools.google_calendar.search_events import (
|
||||
create_search_calendar_events_tool,
|
||||
)
|
||||
from app.agents.new_chat.tools.google_calendar.update_event import (
|
||||
create_update_calendar_event_tool,
|
||||
)
|
||||
|
|
@ -11,5 +14,6 @@ from app.agents.new_chat.tools.google_calendar.update_event import (
|
|||
__all__ = [
|
||||
"create_create_calendar_event_tool",
|
||||
"create_delete_calendar_event_tool",
|
||||
"create_search_calendar_events_tool",
|
||||
"create_update_calendar_event_tool",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,114 @@
|
|||
import logging
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.agents.new_chat.tools.gmail.search_emails import _build_credentials
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_CALENDAR_TYPES = [
|
||||
SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR,
|
||||
SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
|
||||
]
|
||||
|
||||
|
||||
def create_search_calendar_events_tool(
|
||||
db_session: AsyncSession | None = None,
|
||||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
@tool
|
||||
async def search_calendar_events(
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
max_results: int = 25,
|
||||
) -> dict[str, Any]:
|
||||
"""Search Google Calendar events within a date range.
|
||||
|
||||
Args:
|
||||
start_date: Start date in YYYY-MM-DD format (e.g. "2026-04-01").
|
||||
end_date: End date in YYYY-MM-DD format (e.g. "2026-04-30").
|
||||
max_results: Maximum number of events to return (default 25, max 50).
|
||||
|
||||
Returns:
|
||||
Dictionary with status and a list of events including
|
||||
event_id, summary, start, end, location, attendees.
|
||||
"""
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
return {"status": "error", "message": "Calendar tool not properly configured."}
|
||||
|
||||
max_results = min(max_results, 50)
|
||||
|
||||
try:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type.in_(_CALENDAR_TYPES),
|
||||
)
|
||||
)
|
||||
connector = result.scalars().first()
|
||||
if not connector:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "No Google Calendar connector found. Please connect Google Calendar in your workspace settings.",
|
||||
}
|
||||
|
||||
creds = _build_credentials(connector)
|
||||
|
||||
from app.connectors.google_calendar_connector import GoogleCalendarConnector
|
||||
|
||||
cal = GoogleCalendarConnector(
|
||||
credentials=creds,
|
||||
session=db_session,
|
||||
user_id=user_id,
|
||||
connector_id=connector.id,
|
||||
)
|
||||
|
||||
events_raw, error = await cal.get_all_primary_calendar_events(
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
max_results=max_results,
|
||||
)
|
||||
|
||||
if error:
|
||||
if "re-authenticate" in error.lower() or "authentication failed" in error.lower():
|
||||
return {"status": "auth_error", "message": error, "connector_type": "google_calendar"}
|
||||
if "no events found" in error.lower():
|
||||
return {"status": "success", "events": [], "total": 0, "message": error}
|
||||
return {"status": "error", "message": error}
|
||||
|
||||
events = []
|
||||
for ev in events_raw:
|
||||
start = ev.get("start", {})
|
||||
end = ev.get("end", {})
|
||||
attendees_raw = ev.get("attendees", [])
|
||||
events.append({
|
||||
"event_id": ev.get("id"),
|
||||
"summary": ev.get("summary", "No Title"),
|
||||
"start": start.get("dateTime") or start.get("date", ""),
|
||||
"end": end.get("dateTime") or end.get("date", ""),
|
||||
"location": ev.get("location", ""),
|
||||
"description": ev.get("description", ""),
|
||||
"html_link": ev.get("htmlLink", ""),
|
||||
"attendees": [
|
||||
a.get("email", "") for a in attendees_raw[:10]
|
||||
],
|
||||
"status": ev.get("status", ""),
|
||||
})
|
||||
|
||||
return {"status": "success", "events": events, "total": len(events)}
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
||||
if isinstance(e, GraphInterrupt):
|
||||
raise
|
||||
logger.error("Error searching calendar events: %s", e, exc_info=True)
|
||||
return {"status": "error", "message": "Failed to search calendar events. Please try again."}
|
||||
|
||||
return search_calendar_events
|
||||
|
|
@ -130,8 +130,8 @@ def request_approval(
|
|||
try:
|
||||
decision_type, edited_params = _parse_decision(approval)
|
||||
except ValueError:
|
||||
logger.warning("No approval decision received for %s", tool_name)
|
||||
return HITLResult(rejected=False, decision_type="error", params=params)
|
||||
logger.warning("No approval decision received for %s — rejecting for safety", tool_name)
|
||||
return HITLResult(rejected=True, decision_type="error", params=params)
|
||||
|
||||
logger.info("User decision for %s: %s", tool_name, decision_type)
|
||||
|
||||
|
|
|
|||
15
surfsense_backend/app/agents/new_chat/tools/luma/__init__.py
Normal file
15
surfsense_backend/app/agents/new_chat/tools/luma/__init__.py
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
from app.agents.new_chat.tools.luma.create_event import (
|
||||
create_create_luma_event_tool,
|
||||
)
|
||||
from app.agents.new_chat.tools.luma.list_events import (
|
||||
create_list_luma_events_tool,
|
||||
)
|
||||
from app.agents.new_chat.tools.luma.read_event import (
|
||||
create_read_luma_event_tool,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"create_create_luma_event_tool",
|
||||
"create_list_luma_events_tool",
|
||||
"create_read_luma_event_tool",
|
||||
]
|
||||
38
surfsense_backend/app/agents/new_chat/tools/luma/_auth.py
Normal file
38
surfsense_backend/app/agents/new_chat/tools/luma/_auth.py
Normal file
|
|
@ -0,0 +1,38 @@
|
|||
"""Shared auth helper for Luma agent tools."""
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
LUMA_API = "https://public-api.luma.com/v1"
|
||||
|
||||
|
||||
async def get_luma_connector(
|
||||
db_session: AsyncSession,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
) -> SearchSourceConnector | None:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type == SearchSourceConnectorType.LUMA_CONNECTOR,
|
||||
)
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
|
||||
def get_api_key(connector: SearchSourceConnector) -> str:
|
||||
"""Extract the API key from connector config (handles both key names)."""
|
||||
key = connector.config.get("api_key") or connector.config.get("LUMA_API_KEY")
|
||||
if not key:
|
||||
raise ValueError("Luma API key not found in connector config.")
|
||||
return key
|
||||
|
||||
|
||||
def luma_headers(api_key: str) -> dict[str, str]:
|
||||
return {
|
||||
"Content-Type": "application/json",
|
||||
"x-luma-api-key": api_key,
|
||||
}
|
||||
116
surfsense_backend/app/agents/new_chat/tools/luma/create_event.py
Normal file
116
surfsense_backend/app/agents/new_chat/tools/luma/create_event.py
Normal file
|
|
@ -0,0 +1,116 @@
|
|||
import logging
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from langchain_core.tools import tool
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.new_chat.tools.hitl import request_approval
|
||||
|
||||
from ._auth import LUMA_API, get_api_key, get_luma_connector, luma_headers
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_create_luma_event_tool(
|
||||
db_session: AsyncSession | None = None,
|
||||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
@tool
|
||||
async def create_luma_event(
|
||||
name: str,
|
||||
start_at: str,
|
||||
end_at: str,
|
||||
description: str | None = None,
|
||||
timezone: str = "UTC",
|
||||
) -> dict[str, Any]:
|
||||
"""Create a new event on Luma.
|
||||
|
||||
Args:
|
||||
name: The event title.
|
||||
start_at: Start time in ISO 8601 format (e.g. "2026-05-01T18:00:00").
|
||||
end_at: End time in ISO 8601 format (e.g. "2026-05-01T20:00:00").
|
||||
description: Optional event description (markdown supported).
|
||||
timezone: Timezone string (default "UTC", e.g. "America/New_York").
|
||||
|
||||
Returns:
|
||||
Dictionary with status, event_id on success.
|
||||
|
||||
IMPORTANT:
|
||||
- If status is "rejected", the user explicitly declined. Do NOT retry.
|
||||
"""
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
return {"status": "error", "message": "Luma tool not properly configured."}
|
||||
|
||||
try:
|
||||
connector = await get_luma_connector(db_session, search_space_id, user_id)
|
||||
if not connector:
|
||||
return {"status": "error", "message": "No Luma connector found."}
|
||||
|
||||
result = request_approval(
|
||||
action_type="luma_create_event",
|
||||
tool_name="create_luma_event",
|
||||
params={
|
||||
"name": name,
|
||||
"start_at": start_at,
|
||||
"end_at": end_at,
|
||||
"description": description,
|
||||
"timezone": timezone,
|
||||
},
|
||||
context={"connector_id": connector.id},
|
||||
)
|
||||
|
||||
if result.rejected:
|
||||
return {"status": "rejected", "message": "User declined. Event was not created."}
|
||||
|
||||
final_name = result.params.get("name", name)
|
||||
final_start = result.params.get("start_at", start_at)
|
||||
final_end = result.params.get("end_at", end_at)
|
||||
final_desc = result.params.get("description", description)
|
||||
final_tz = result.params.get("timezone", timezone)
|
||||
|
||||
api_key = get_api_key(connector)
|
||||
headers = luma_headers(api_key)
|
||||
|
||||
body: dict[str, Any] = {
|
||||
"name": final_name,
|
||||
"start_at": final_start,
|
||||
"end_at": final_end,
|
||||
"timezone": final_tz,
|
||||
}
|
||||
if final_desc:
|
||||
body["description_md"] = final_desc
|
||||
|
||||
async with httpx.AsyncClient(timeout=20.0) as client:
|
||||
resp = await client.post(
|
||||
f"{LUMA_API}/event/create",
|
||||
headers=headers,
|
||||
json=body,
|
||||
)
|
||||
|
||||
if resp.status_code == 401:
|
||||
return {"status": "auth_error", "message": "Luma API key is invalid.", "connector_type": "luma"}
|
||||
if resp.status_code == 403:
|
||||
return {"status": "error", "message": "Luma Plus subscription required to create events via API."}
|
||||
if resp.status_code not in (200, 201):
|
||||
return {"status": "error", "message": f"Luma API error: {resp.status_code} — {resp.text[:200]}"}
|
||||
|
||||
data = resp.json()
|
||||
event_id = data.get("api_id") or data.get("event", {}).get("api_id")
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"event_id": event_id,
|
||||
"message": f"Event '{final_name}' created on Luma.",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
||||
if isinstance(e, GraphInterrupt):
|
||||
raise
|
||||
logger.error("Error creating Luma event: %s", e, exc_info=True)
|
||||
return {"status": "error", "message": "Failed to create Luma event."}
|
||||
|
||||
return create_luma_event
|
||||
100
surfsense_backend/app/agents/new_chat/tools/luma/list_events.py
Normal file
100
surfsense_backend/app/agents/new_chat/tools/luma/list_events.py
Normal file
|
|
@ -0,0 +1,100 @@
|
|||
import logging
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from langchain_core.tools import tool
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from ._auth import LUMA_API, get_api_key, get_luma_connector, luma_headers
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_list_luma_events_tool(
|
||||
db_session: AsyncSession | None = None,
|
||||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
@tool
|
||||
async def list_luma_events(
|
||||
max_results: int = 25,
|
||||
) -> dict[str, Any]:
|
||||
"""List upcoming and recent Luma events.
|
||||
|
||||
Args:
|
||||
max_results: Maximum events to return (default 25, max 50).
|
||||
|
||||
Returns:
|
||||
Dictionary with status and a list of events including
|
||||
event_id, name, start_at, end_at, location, url.
|
||||
"""
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
return {"status": "error", "message": "Luma tool not properly configured."}
|
||||
|
||||
max_results = min(max_results, 50)
|
||||
|
||||
try:
|
||||
connector = await get_luma_connector(db_session, search_space_id, user_id)
|
||||
if not connector:
|
||||
return {"status": "error", "message": "No Luma connector found."}
|
||||
|
||||
api_key = get_api_key(connector)
|
||||
headers = luma_headers(api_key)
|
||||
|
||||
all_entries: list[dict] = []
|
||||
cursor = None
|
||||
|
||||
async with httpx.AsyncClient(timeout=20.0) as client:
|
||||
while len(all_entries) < max_results:
|
||||
params: dict[str, Any] = {"limit": min(100, max_results - len(all_entries))}
|
||||
if cursor:
|
||||
params["cursor"] = cursor
|
||||
|
||||
resp = await client.get(
|
||||
f"{LUMA_API}/calendar/list-events",
|
||||
headers=headers,
|
||||
params=params,
|
||||
)
|
||||
|
||||
if resp.status_code == 401:
|
||||
return {"status": "auth_error", "message": "Luma API key is invalid.", "connector_type": "luma"}
|
||||
if resp.status_code != 200:
|
||||
return {"status": "error", "message": f"Luma API error: {resp.status_code}"}
|
||||
|
||||
data = resp.json()
|
||||
entries = data.get("entries", [])
|
||||
if not entries:
|
||||
break
|
||||
all_entries.extend(entries)
|
||||
|
||||
next_cursor = data.get("next_cursor")
|
||||
if not next_cursor:
|
||||
break
|
||||
cursor = next_cursor
|
||||
|
||||
events = []
|
||||
for entry in all_entries[:max_results]:
|
||||
ev = entry.get("event", {})
|
||||
geo = ev.get("geo_info", {})
|
||||
events.append({
|
||||
"event_id": entry.get("api_id"),
|
||||
"name": ev.get("name", "Untitled"),
|
||||
"start_at": ev.get("start_at", ""),
|
||||
"end_at": ev.get("end_at", ""),
|
||||
"timezone": ev.get("timezone", ""),
|
||||
"location": geo.get("name", ""),
|
||||
"url": ev.get("url", ""),
|
||||
"visibility": ev.get("visibility", ""),
|
||||
})
|
||||
|
||||
return {"status": "success", "events": events, "total": len(events)}
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
||||
if isinstance(e, GraphInterrupt):
|
||||
raise
|
||||
logger.error("Error listing Luma events: %s", e, exc_info=True)
|
||||
return {"status": "error", "message": "Failed to list Luma events."}
|
||||
|
||||
return list_luma_events
|
||||
|
|
@ -0,0 +1,82 @@
|
|||
import logging
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from langchain_core.tools import tool
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from ._auth import LUMA_API, get_api_key, get_luma_connector, luma_headers
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_read_luma_event_tool(
|
||||
db_session: AsyncSession | None = None,
|
||||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
@tool
|
||||
async def read_luma_event(event_id: str) -> dict[str, Any]:
|
||||
"""Read detailed information about a specific Luma event.
|
||||
|
||||
Args:
|
||||
event_id: The Luma event API ID (from list_luma_events).
|
||||
|
||||
Returns:
|
||||
Dictionary with status and full event details including
|
||||
description, attendees count, meeting URL.
|
||||
"""
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
return {"status": "error", "message": "Luma tool not properly configured."}
|
||||
|
||||
try:
|
||||
connector = await get_luma_connector(db_session, search_space_id, user_id)
|
||||
if not connector:
|
||||
return {"status": "error", "message": "No Luma connector found."}
|
||||
|
||||
api_key = get_api_key(connector)
|
||||
headers = luma_headers(api_key)
|
||||
|
||||
async with httpx.AsyncClient(timeout=15.0) as client:
|
||||
resp = await client.get(
|
||||
f"{LUMA_API}/events/{event_id}",
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
if resp.status_code == 401:
|
||||
return {"status": "auth_error", "message": "Luma API key is invalid.", "connector_type": "luma"}
|
||||
if resp.status_code == 404:
|
||||
return {"status": "not_found", "message": f"Event '{event_id}' not found."}
|
||||
if resp.status_code != 200:
|
||||
return {"status": "error", "message": f"Luma API error: {resp.status_code}"}
|
||||
|
||||
data = resp.json()
|
||||
ev = data.get("event", data)
|
||||
geo = ev.get("geo_info", {})
|
||||
|
||||
event_detail = {
|
||||
"event_id": event_id,
|
||||
"name": ev.get("name", ""),
|
||||
"description": ev.get("description", ""),
|
||||
"start_at": ev.get("start_at", ""),
|
||||
"end_at": ev.get("end_at", ""),
|
||||
"timezone": ev.get("timezone", ""),
|
||||
"location_name": geo.get("name", ""),
|
||||
"address": geo.get("address", ""),
|
||||
"url": ev.get("url", ""),
|
||||
"meeting_url": ev.get("meeting_url", ""),
|
||||
"visibility": ev.get("visibility", ""),
|
||||
"cover_url": ev.get("cover_url", ""),
|
||||
}
|
||||
|
||||
return {"status": "success", "event": event_detail}
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
||||
if isinstance(e, GraphInterrupt):
|
||||
raise
|
||||
logger.error("Error reading Luma event: %s", e, exc_info=True)
|
||||
return {"status": "error", "message": "Failed to read Luma event."}
|
||||
|
||||
return read_luma_event
|
||||
|
|
@ -14,20 +14,28 @@ clicking "Always Allow", which adds the tool name to the connector's
|
|||
``config.trusted_tools`` allow-list.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.utils.oauth_security import TokenEncryption
|
||||
|
||||
from langchain_core.tools import StructuredTool
|
||||
from mcp import ClientSession
|
||||
from mcp.client.streamable_http import streamablehttp_client
|
||||
from pydantic import BaseModel, create_model
|
||||
from sqlalchemy import select
|
||||
from pydantic import BaseModel, Field, create_model
|
||||
from sqlalchemy import cast, select
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.new_chat.tools.hitl import request_approval
|
||||
from app.agents.new_chat.tools.mcp_client import MCPClient
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
from app.services.mcp_oauth.registry import MCP_SERVICES, get_service_by_connector_type
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -63,18 +71,14 @@ def _create_dynamic_input_model_from_schema(
|
|||
param_description = param_schema.get("description", "")
|
||||
is_required = param_name in required_fields
|
||||
|
||||
from typing import Any as AnyType
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
if is_required:
|
||||
field_definitions[param_name] = (
|
||||
AnyType,
|
||||
Any,
|
||||
Field(..., description=param_description),
|
||||
)
|
||||
else:
|
||||
field_definitions[param_name] = (
|
||||
AnyType | None,
|
||||
Any | None,
|
||||
Field(None, description=param_description),
|
||||
)
|
||||
|
||||
|
|
@ -100,13 +104,13 @@ async def _create_mcp_tool_from_definition_stdio(
|
|||
tool_description = tool_def.get("description", "No description provided")
|
||||
input_schema = tool_def.get("input_schema", {"type": "object", "properties": {}})
|
||||
|
||||
logger.info(f"MCP tool '{tool_name}' input schema: {input_schema}")
|
||||
logger.debug("MCP tool '%s' input schema: %s", tool_name, input_schema)
|
||||
|
||||
input_model = _create_dynamic_input_model_from_schema(tool_name, input_schema)
|
||||
|
||||
async def mcp_tool_call(**kwargs) -> str:
|
||||
"""Execute the MCP tool call via the client with retry support."""
|
||||
logger.info(f"MCP tool '{tool_name}' called with params: {kwargs}")
|
||||
logger.debug("MCP tool '%s' called", tool_name)
|
||||
|
||||
# HITL — OUTSIDE try/except so GraphInterrupt propagates to LangGraph
|
||||
hitl_result = request_approval(
|
||||
|
|
@ -123,20 +127,18 @@ async def _create_mcp_tool_from_definition_stdio(
|
|||
)
|
||||
if hitl_result.rejected:
|
||||
return "Tool call rejected by user."
|
||||
call_kwargs = hitl_result.params
|
||||
call_kwargs = {k: v for k, v in hitl_result.params.items() if v is not None}
|
||||
|
||||
try:
|
||||
async with mcp_client.connect():
|
||||
result = await mcp_client.call_tool(tool_name, call_kwargs)
|
||||
return str(result)
|
||||
except RuntimeError as e:
|
||||
error_msg = f"MCP tool '{tool_name}' connection failed after retries: {e!s}"
|
||||
logger.error(error_msg)
|
||||
return f"Error: {error_msg}"
|
||||
logger.error("MCP tool '%s' connection failed after retries: %s", tool_name, e)
|
||||
return f"Error: MCP tool '{tool_name}' connection failed after retries: {e!s}"
|
||||
except Exception as e:
|
||||
error_msg = f"MCP tool '{tool_name}' execution failed: {e!s}"
|
||||
logger.exception(error_msg)
|
||||
return f"Error: {error_msg}"
|
||||
logger.exception("MCP tool '%s' execution failed: %s", tool_name, e)
|
||||
return f"Error: MCP tool '{tool_name}' execution failed: {e!s}"
|
||||
|
||||
tool = StructuredTool(
|
||||
name=tool_name,
|
||||
|
|
@ -151,7 +153,7 @@ async def _create_mcp_tool_from_definition_stdio(
|
|||
},
|
||||
)
|
||||
|
||||
logger.info(f"Created MCP tool (stdio): '{tool_name}'")
|
||||
logger.debug("Created MCP tool (stdio): '%s'", tool_name)
|
||||
return tool
|
||||
|
||||
|
||||
|
|
@ -163,29 +165,45 @@ async def _create_mcp_tool_from_definition_http(
|
|||
connector_name: str = "",
|
||||
connector_id: int | None = None,
|
||||
trusted_tools: list[str] | None = None,
|
||||
readonly_tools: frozenset[str] | None = None,
|
||||
tool_name_prefix: str | None = None,
|
||||
) -> StructuredTool:
|
||||
"""Create a LangChain tool from an MCP tool definition (HTTP transport).
|
||||
|
||||
All MCP tools are unconditionally wrapped with HITL approval.
|
||||
``request_approval()`` is called OUTSIDE the try/except so that
|
||||
``GraphInterrupt`` propagates cleanly to LangGraph.
|
||||
Write tools are wrapped with HITL approval; read-only tools (listed in
|
||||
``readonly_tools``) execute immediately without user confirmation.
|
||||
|
||||
When ``tool_name_prefix`` is set (multi-account disambiguation), the
|
||||
tool exposed to the LLM gets a prefixed name (e.g. ``linear_25_list_issues``)
|
||||
but the actual MCP ``call_tool`` still uses the original name.
|
||||
"""
|
||||
tool_name = tool_def.get("name", "unnamed_tool")
|
||||
original_tool_name = tool_def.get("name", "unnamed_tool")
|
||||
tool_description = tool_def.get("description", "No description provided")
|
||||
input_schema = tool_def.get("input_schema", {"type": "object", "properties": {}})
|
||||
is_readonly = readonly_tools is not None and original_tool_name in readonly_tools
|
||||
|
||||
logger.info(f"MCP HTTP tool '{tool_name}' input schema: {input_schema}")
|
||||
exposed_name = (
|
||||
f"{tool_name_prefix}_{original_tool_name}"
|
||||
if tool_name_prefix
|
||||
else original_tool_name
|
||||
)
|
||||
if tool_name_prefix:
|
||||
tool_description = f"[Account: {connector_name}] {tool_description}"
|
||||
|
||||
input_model = _create_dynamic_input_model_from_schema(tool_name, input_schema)
|
||||
logger.debug("MCP HTTP tool '%s' input schema: %s", exposed_name, input_schema)
|
||||
|
||||
input_model = _create_dynamic_input_model_from_schema(exposed_name, input_schema)
|
||||
|
||||
async def mcp_http_tool_call(**kwargs) -> str:
|
||||
"""Execute the MCP tool call via HTTP transport."""
|
||||
logger.info(f"MCP HTTP tool '{tool_name}' called with params: {kwargs}")
|
||||
logger.debug("MCP HTTP tool '%s' called", exposed_name)
|
||||
|
||||
# HITL — OUTSIDE try/except so GraphInterrupt propagates to LangGraph
|
||||
if is_readonly:
|
||||
call_kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
||||
else:
|
||||
hitl_result = request_approval(
|
||||
action_type="mcp_tool_call",
|
||||
tool_name=tool_name,
|
||||
tool_name=exposed_name,
|
||||
params=kwargs,
|
||||
context={
|
||||
"mcp_server": connector_name,
|
||||
|
|
@ -197,7 +215,7 @@ async def _create_mcp_tool_from_definition_http(
|
|||
)
|
||||
if hitl_result.rejected:
|
||||
return "Tool call rejected by user."
|
||||
call_kwargs = hitl_result.params
|
||||
call_kwargs = {k: v for k, v in hitl_result.params.items() if v is not None}
|
||||
|
||||
try:
|
||||
async with (
|
||||
|
|
@ -205,7 +223,9 @@ async def _create_mcp_tool_from_definition_http(
|
|||
ClientSession(read, write) as session,
|
||||
):
|
||||
await session.initialize()
|
||||
response = await session.call_tool(tool_name, arguments=call_kwargs)
|
||||
response = await session.call_tool(
|
||||
original_tool_name, arguments=call_kwargs,
|
||||
)
|
||||
|
||||
result = []
|
||||
for content in response.content:
|
||||
|
|
@ -217,18 +237,15 @@ async def _create_mcp_tool_from_definition_http(
|
|||
result.append(str(content))
|
||||
|
||||
result_str = "\n".join(result) if result else ""
|
||||
logger.info(
|
||||
f"MCP HTTP tool '{tool_name}' succeeded: {result_str[:200]}"
|
||||
)
|
||||
logger.debug("MCP HTTP tool '%s' succeeded (len=%d)", exposed_name, len(result_str))
|
||||
return result_str
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"MCP HTTP tool '{tool_name}' execution failed: {e!s}"
|
||||
logger.exception(error_msg)
|
||||
return f"Error: {error_msg}"
|
||||
logger.exception("MCP HTTP tool '%s' execution failed: %s", exposed_name, e)
|
||||
return f"Error: MCP HTTP tool '{exposed_name}' execution failed: {e!s}"
|
||||
|
||||
tool = StructuredTool(
|
||||
name=tool_name,
|
||||
name=exposed_name,
|
||||
description=tool_description,
|
||||
coroutine=mcp_http_tool_call,
|
||||
args_schema=input_model,
|
||||
|
|
@ -236,12 +253,14 @@ async def _create_mcp_tool_from_definition_http(
|
|||
"mcp_input_schema": input_schema,
|
||||
"mcp_transport": "http",
|
||||
"mcp_url": url,
|
||||
"hitl": True,
|
||||
"hitl": not is_readonly,
|
||||
"hitl_dedup_key": next(iter(input_schema.get("required", [])), None),
|
||||
"mcp_original_tool_name": original_tool_name,
|
||||
"mcp_connector_id": connector_id,
|
||||
},
|
||||
)
|
||||
|
||||
logger.info(f"Created MCP tool (HTTP): '{tool_name}'")
|
||||
logger.debug("Created MCP tool (HTTP): '%s'", exposed_name)
|
||||
return tool
|
||||
|
||||
|
||||
|
|
@ -257,21 +276,24 @@ async def _load_stdio_mcp_tools(
|
|||
command = server_config.get("command")
|
||||
if not command or not isinstance(command, str):
|
||||
logger.warning(
|
||||
f"MCP connector {connector_id} (name: '{connector_name}') missing or invalid command field, skipping"
|
||||
"MCP connector %d (name: '%s') missing or invalid command field, skipping",
|
||||
connector_id, connector_name,
|
||||
)
|
||||
return tools
|
||||
|
||||
args = server_config.get("args", [])
|
||||
if not isinstance(args, list):
|
||||
logger.warning(
|
||||
f"MCP connector {connector_id} (name: '{connector_name}') has invalid args field (must be list), skipping"
|
||||
"MCP connector %d (name: '%s') has invalid args field (must be list), skipping",
|
||||
connector_id, connector_name,
|
||||
)
|
||||
return tools
|
||||
|
||||
env = server_config.get("env", {})
|
||||
if not isinstance(env, dict):
|
||||
logger.warning(
|
||||
f"MCP connector {connector_id} (name: '{connector_name}') has invalid env field (must be dict), skipping"
|
||||
"MCP connector %d (name: '%s') has invalid env field (must be dict), skipping",
|
||||
connector_id, connector_name,
|
||||
)
|
||||
return tools
|
||||
|
||||
|
|
@ -281,8 +303,8 @@ async def _load_stdio_mcp_tools(
|
|||
tool_definitions = await mcp_client.list_tools()
|
||||
|
||||
logger.info(
|
||||
f"Discovered {len(tool_definitions)} tools from stdio MCP server "
|
||||
f"'{command}' (connector {connector_id})"
|
||||
"Discovered %d tools from stdio MCP server '%s' (connector %d)",
|
||||
len(tool_definitions), command, connector_id,
|
||||
)
|
||||
|
||||
for tool_def in tool_definitions:
|
||||
|
|
@ -297,8 +319,8 @@ async def _load_stdio_mcp_tools(
|
|||
tools.append(tool)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Failed to create tool '{tool_def.get('name')}' "
|
||||
f"from connector {connector_id}: {e!s}"
|
||||
"Failed to create tool '%s' from connector %d: %s",
|
||||
tool_def.get("name"), connector_id, e,
|
||||
)
|
||||
|
||||
return tools
|
||||
|
|
@ -309,24 +331,40 @@ async def _load_http_mcp_tools(
|
|||
connector_name: str,
|
||||
server_config: dict[str, Any],
|
||||
trusted_tools: list[str] | None = None,
|
||||
allowed_tools: list[str] | None = None,
|
||||
readonly_tools: frozenset[str] | None = None,
|
||||
tool_name_prefix: str | None = None,
|
||||
) -> list[StructuredTool]:
|
||||
"""Load tools from an HTTP-based MCP server."""
|
||||
"""Load tools from an HTTP-based MCP server.
|
||||
|
||||
Args:
|
||||
allowed_tools: If non-empty, only tools whose names appear in this
|
||||
list are loaded. Empty/None means load everything (used for
|
||||
user-managed generic MCP servers).
|
||||
readonly_tools: Tool names that skip HITL approval (read-only operations).
|
||||
tool_name_prefix: If set, each tool name is prefixed for multi-account
|
||||
disambiguation (e.g. ``linear_25``).
|
||||
"""
|
||||
tools: list[StructuredTool] = []
|
||||
|
||||
url = server_config.get("url")
|
||||
if not url or not isinstance(url, str):
|
||||
logger.warning(
|
||||
f"MCP connector {connector_id} (name: '{connector_name}') missing or invalid url field, skipping"
|
||||
"MCP connector %d (name: '%s') missing or invalid url field, skipping",
|
||||
connector_id, connector_name,
|
||||
)
|
||||
return tools
|
||||
|
||||
headers = server_config.get("headers", {})
|
||||
if not isinstance(headers, dict):
|
||||
logger.warning(
|
||||
f"MCP connector {connector_id} (name: '{connector_name}') has invalid headers field (must be dict), skipping"
|
||||
"MCP connector %d (name: '%s') has invalid headers field (must be dict), skipping",
|
||||
connector_id, connector_name,
|
||||
)
|
||||
return tools
|
||||
|
||||
allowed_set = set(allowed_tools) if allowed_tools else None
|
||||
|
||||
try:
|
||||
async with (
|
||||
streamablehttp_client(url, headers=headers) as (read, write, _),
|
||||
|
|
@ -347,9 +385,20 @@ async def _load_http_mcp_tools(
|
|||
}
|
||||
)
|
||||
|
||||
total_discovered = len(tool_definitions)
|
||||
|
||||
if allowed_set:
|
||||
tool_definitions = [
|
||||
td for td in tool_definitions if td["name"] in allowed_set
|
||||
]
|
||||
logger.info(
|
||||
f"Discovered {len(tool_definitions)} tools from HTTP MCP server "
|
||||
f"'{url}' (connector {connector_id})"
|
||||
"HTTP MCP server '%s' (connector %d): %d/%d tools after allowlist filter",
|
||||
url, connector_id, len(tool_definitions), total_discovered,
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
"Discovered %d tools from HTTP MCP server '%s' (connector %d) — no allowlist, loading all",
|
||||
total_discovered, url, connector_id,
|
||||
)
|
||||
|
||||
for tool_def in tool_definitions:
|
||||
|
|
@ -361,22 +410,183 @@ async def _load_http_mcp_tools(
|
|||
connector_name=connector_name,
|
||||
connector_id=connector_id,
|
||||
trusted_tools=trusted_tools,
|
||||
readonly_tools=readonly_tools,
|
||||
tool_name_prefix=tool_name_prefix,
|
||||
)
|
||||
tools.append(tool)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Failed to create HTTP tool '{tool_def.get('name')}' "
|
||||
f"from connector {connector_id}: {e!s}"
|
||||
"Failed to create HTTP tool '%s' from connector %d: %s",
|
||||
tool_def.get("name"), connector_id, e,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Failed to connect to HTTP MCP server at '{url}' (connector {connector_id}): {e!s}"
|
||||
"Failed to connect to HTTP MCP server at '%s' (connector %d): %s",
|
||||
url, connector_id, e,
|
||||
)
|
||||
|
||||
return tools
|
||||
|
||||
|
||||
_TOKEN_REFRESH_BUFFER_SECONDS = 300 # refresh 5 min before expiry
|
||||
|
||||
_token_enc: TokenEncryption | None = None
|
||||
|
||||
|
||||
def _get_token_enc() -> TokenEncryption:
|
||||
global _token_enc
|
||||
if _token_enc is None:
|
||||
from app.config import config as app_config
|
||||
from app.utils.oauth_security import TokenEncryption
|
||||
|
||||
_token_enc = TokenEncryption(app_config.SECRET_KEY)
|
||||
return _token_enc
|
||||
|
||||
|
||||
def _inject_oauth_headers(
|
||||
cfg: dict[str, Any],
|
||||
server_config: dict[str, Any],
|
||||
) -> dict[str, Any] | None:
|
||||
"""Decrypt the MCP OAuth access token and inject it into server_config headers.
|
||||
|
||||
The DB never stores plaintext tokens in ``server_config.headers``. This
|
||||
function decrypts ``mcp_oauth.access_token`` at runtime and returns a
|
||||
*copy* of ``server_config`` with the Authorization header set.
|
||||
"""
|
||||
mcp_oauth = cfg.get("mcp_oauth", {})
|
||||
encrypted_token = mcp_oauth.get("access_token")
|
||||
if not encrypted_token:
|
||||
return server_config
|
||||
|
||||
try:
|
||||
access_token = _get_token_enc().decrypt_token(encrypted_token)
|
||||
|
||||
result = dict(server_config)
|
||||
result["headers"] = {
|
||||
**server_config.get("headers", {}),
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
}
|
||||
return result
|
||||
except Exception:
|
||||
logger.error(
|
||||
"Failed to decrypt MCP OAuth token — connector will be skipped",
|
||||
exc_info=True,
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
async def _maybe_refresh_mcp_oauth_token(
|
||||
session: AsyncSession,
|
||||
connector: "SearchSourceConnector",
|
||||
cfg: dict[str, Any],
|
||||
server_config: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
"""Refresh the access token for an MCP OAuth connector if it is about to expire.
|
||||
|
||||
Returns the (possibly updated) ``server_config``.
|
||||
"""
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
mcp_oauth = cfg.get("mcp_oauth", {})
|
||||
expires_at_str = mcp_oauth.get("expires_at")
|
||||
if not expires_at_str:
|
||||
return server_config
|
||||
|
||||
try:
|
||||
expires_at = datetime.fromisoformat(expires_at_str)
|
||||
if expires_at.tzinfo is None:
|
||||
from datetime import timezone
|
||||
expires_at = expires_at.replace(tzinfo=timezone.utc)
|
||||
|
||||
if datetime.now(UTC) < expires_at - timedelta(seconds=_TOKEN_REFRESH_BUFFER_SECONDS):
|
||||
return server_config
|
||||
except (ValueError, TypeError):
|
||||
return server_config
|
||||
|
||||
refresh_token = mcp_oauth.get("refresh_token")
|
||||
if not refresh_token:
|
||||
logger.warning(
|
||||
"MCP connector %s token expired but no refresh_token available",
|
||||
connector.id,
|
||||
)
|
||||
return server_config
|
||||
|
||||
try:
|
||||
from app.services.mcp_oauth.discovery import refresh_access_token
|
||||
|
||||
enc = _get_token_enc()
|
||||
decrypted_refresh = enc.decrypt_token(refresh_token)
|
||||
decrypted_secret = (
|
||||
enc.decrypt_token(mcp_oauth["client_secret"])
|
||||
if mcp_oauth.get("client_secret")
|
||||
else ""
|
||||
)
|
||||
|
||||
token_json = await refresh_access_token(
|
||||
token_endpoint=mcp_oauth["token_endpoint"],
|
||||
refresh_token=decrypted_refresh,
|
||||
client_id=mcp_oauth["client_id"],
|
||||
client_secret=decrypted_secret,
|
||||
)
|
||||
|
||||
new_access = token_json.get("access_token")
|
||||
if not new_access:
|
||||
logger.warning(
|
||||
"MCP connector %s token refresh returned no access_token",
|
||||
connector.id,
|
||||
)
|
||||
return server_config
|
||||
|
||||
new_expires_at = None
|
||||
if token_json.get("expires_in"):
|
||||
new_expires_at = datetime.now(UTC) + timedelta(
|
||||
seconds=int(token_json["expires_in"])
|
||||
)
|
||||
|
||||
updated_oauth = dict(mcp_oauth)
|
||||
updated_oauth["access_token"] = enc.encrypt_token(new_access)
|
||||
if token_json.get("refresh_token"):
|
||||
updated_oauth["refresh_token"] = enc.encrypt_token(
|
||||
token_json["refresh_token"]
|
||||
)
|
||||
updated_oauth["expires_at"] = (
|
||||
new_expires_at.isoformat() if new_expires_at else None
|
||||
)
|
||||
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
connector.config = {
|
||||
**cfg,
|
||||
"server_config": server_config,
|
||||
"mcp_oauth": updated_oauth,
|
||||
}
|
||||
flag_modified(connector, "config")
|
||||
await session.commit()
|
||||
await session.refresh(connector)
|
||||
|
||||
logger.info("Refreshed MCP OAuth token for connector %s", connector.id)
|
||||
|
||||
# Invalidate cache so next call picks up the new token.
|
||||
invalidate_mcp_tools_cache(connector.search_space_id)
|
||||
|
||||
# Return server_config with the fresh token injected for immediate use.
|
||||
refreshed_config = dict(server_config)
|
||||
refreshed_config["headers"] = {
|
||||
**server_config.get("headers", {}),
|
||||
"Authorization": f"Bearer {new_access}",
|
||||
}
|
||||
return refreshed_config
|
||||
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to refresh MCP OAuth token for connector %s",
|
||||
connector.id,
|
||||
exc_info=True,
|
||||
)
|
||||
return server_config
|
||||
|
||||
|
||||
def invalidate_mcp_tools_cache(search_space_id: int | None = None) -> None:
|
||||
"""Invalidate cached MCP tools.
|
||||
|
||||
|
|
@ -418,27 +628,91 @@ async def load_mcp_tools(
|
|||
return list(cached_tools)
|
||||
|
||||
try:
|
||||
# Find all connectors with MCP server config: generic MCP_CONNECTOR type
|
||||
# and service-specific types (LINEAR_CONNECTOR, etc.) created via MCP OAuth.
|
||||
# Cast JSON -> JSONB so we can use has_key to filter by the presence of "server_config".
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.connector_type
|
||||
== SearchSourceConnectorType.MCP_CONNECTOR,
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
cast(SearchSourceConnector.config, JSONB).has_key("server_config"), # noqa: W601
|
||||
),
|
||||
)
|
||||
|
||||
connectors = list(result.scalars())
|
||||
|
||||
# Group connectors by type to detect multi-account scenarios.
|
||||
# When >1 connector shares the same type, tool names would collide
|
||||
# so we prefix them with "{service_key}_{connector_id}_".
|
||||
type_groups: dict[str, list[SearchSourceConnector]] = defaultdict(list)
|
||||
for connector in connectors:
|
||||
ct = (
|
||||
connector.connector_type.value
|
||||
if hasattr(connector.connector_type, "value")
|
||||
else str(connector.connector_type)
|
||||
)
|
||||
type_groups[ct].append(connector)
|
||||
|
||||
multi_account_types: set[str] = {
|
||||
ct for ct, group in type_groups.items() if len(group) > 1
|
||||
}
|
||||
if multi_account_types:
|
||||
logger.info(
|
||||
"Multi-account detected for connector types: %s",
|
||||
multi_account_types,
|
||||
)
|
||||
|
||||
tools: list[StructuredTool] = []
|
||||
for connector in result.scalars():
|
||||
for connector in connectors:
|
||||
try:
|
||||
config = connector.config or {}
|
||||
server_config = config.get("server_config", {})
|
||||
trusted_tools = config.get("trusted_tools", [])
|
||||
cfg = connector.config or {}
|
||||
server_config = cfg.get("server_config", {})
|
||||
|
||||
if not server_config or not isinstance(server_config, dict):
|
||||
logger.warning(
|
||||
f"MCP connector {connector.id} (name: '{connector.name}') has invalid or missing server_config, skipping"
|
||||
"MCP connector %d (name: '%s') has invalid or missing server_config, skipping",
|
||||
connector.id, connector.name,
|
||||
)
|
||||
continue
|
||||
|
||||
# For MCP OAuth connectors: refresh if needed, then decrypt the
|
||||
# access token and inject it into headers at runtime. The DB
|
||||
# intentionally does NOT store plaintext tokens in server_config.
|
||||
if cfg.get("mcp_oauth"):
|
||||
server_config = await _maybe_refresh_mcp_oauth_token(
|
||||
session, connector, cfg, server_config,
|
||||
)
|
||||
# Re-read cfg after potential refresh (connector was reloaded from DB).
|
||||
cfg = connector.config or {}
|
||||
server_config = _inject_oauth_headers(cfg, server_config)
|
||||
if server_config is None:
|
||||
logger.warning(
|
||||
"Skipping MCP connector %d — OAuth token decryption failed",
|
||||
connector.id,
|
||||
)
|
||||
continue
|
||||
|
||||
trusted_tools = cfg.get("trusted_tools", [])
|
||||
|
||||
ct = (
|
||||
connector.connector_type.value
|
||||
if hasattr(connector.connector_type, "value")
|
||||
else str(connector.connector_type)
|
||||
)
|
||||
|
||||
svc_cfg = get_service_by_connector_type(ct)
|
||||
allowed_tools = svc_cfg.allowed_tools if svc_cfg else []
|
||||
readonly_tools = svc_cfg.readonly_tools if svc_cfg else frozenset()
|
||||
|
||||
# Build a prefix only when multiple accounts share the same type.
|
||||
tool_name_prefix: str | None = None
|
||||
if ct in multi_account_types and svc_cfg:
|
||||
service_key = next(
|
||||
(k for k, v in MCP_SERVICES.items() if v is svc_cfg),
|
||||
None,
|
||||
)
|
||||
if service_key:
|
||||
tool_name_prefix = f"{service_key}_{connector.id}"
|
||||
|
||||
transport = server_config.get("transport", "stdio")
|
||||
|
||||
if transport in ("streamable-http", "http", "sse"):
|
||||
|
|
@ -447,6 +721,9 @@ async def load_mcp_tools(
|
|||
connector.name,
|
||||
server_config,
|
||||
trusted_tools=trusted_tools,
|
||||
allowed_tools=allowed_tools,
|
||||
readonly_tools=readonly_tools,
|
||||
tool_name_prefix=tool_name_prefix,
|
||||
)
|
||||
else:
|
||||
connector_tools = await _load_stdio_mcp_tools(
|
||||
|
|
@ -460,7 +737,8 @@ async def load_mcp_tools(
|
|||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Failed to load tools from MCP connector {connector.id}: {e!s}"
|
||||
"Failed to load tools from MCP connector %d: %s",
|
||||
connector.id, e,
|
||||
)
|
||||
|
||||
_mcp_tools_cache[search_space_id] = (now, tools)
|
||||
|
|
@ -469,9 +747,9 @@ async def load_mcp_tools(
|
|||
oldest_key = min(_mcp_tools_cache, key=lambda k: _mcp_tools_cache[k][0])
|
||||
del _mcp_tools_cache[oldest_key]
|
||||
|
||||
logger.info(f"Loaded {len(tools)} MCP tools for search space {search_space_id}")
|
||||
logger.info("Loaded %d MCP tools for search space %d", len(tools), search_space_id)
|
||||
return tools
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to load MCP tools: {e!s}")
|
||||
logger.exception("Failed to load MCP tools: %s", e)
|
||||
return []
|
||||
|
|
|
|||
|
|
@ -50,6 +50,11 @@ from .confluence import (
|
|||
create_delete_confluence_page_tool,
|
||||
create_update_confluence_page_tool,
|
||||
)
|
||||
from .discord import (
|
||||
create_list_discord_channels_tool,
|
||||
create_read_discord_messages_tool,
|
||||
create_send_discord_message_tool,
|
||||
)
|
||||
from .dropbox import (
|
||||
create_create_dropbox_file_tool,
|
||||
create_delete_dropbox_file_tool,
|
||||
|
|
@ -57,6 +62,8 @@ from .dropbox import (
|
|||
from .generate_image import create_generate_image_tool
|
||||
from .gmail import (
|
||||
create_create_gmail_draft_tool,
|
||||
create_read_gmail_email_tool,
|
||||
create_search_gmail_tool,
|
||||
create_send_gmail_email_tool,
|
||||
create_trash_gmail_email_tool,
|
||||
create_update_gmail_draft_tool,
|
||||
|
|
@ -64,21 +71,18 @@ from .gmail import (
|
|||
from .google_calendar import (
|
||||
create_create_calendar_event_tool,
|
||||
create_delete_calendar_event_tool,
|
||||
create_search_calendar_events_tool,
|
||||
create_update_calendar_event_tool,
|
||||
)
|
||||
from .google_drive import (
|
||||
create_create_google_drive_file_tool,
|
||||
create_delete_google_drive_file_tool,
|
||||
)
|
||||
from .jira import (
|
||||
create_create_jira_issue_tool,
|
||||
create_delete_jira_issue_tool,
|
||||
create_update_jira_issue_tool,
|
||||
)
|
||||
from .linear import (
|
||||
create_create_linear_issue_tool,
|
||||
create_delete_linear_issue_tool,
|
||||
create_update_linear_issue_tool,
|
||||
from .connected_accounts import create_get_connected_accounts_tool
|
||||
from .luma import (
|
||||
create_create_luma_event_tool,
|
||||
create_list_luma_events_tool,
|
||||
create_read_luma_event_tool,
|
||||
)
|
||||
from .mcp_tool import load_mcp_tools
|
||||
from .notion import (
|
||||
|
|
@ -95,6 +99,11 @@ from .report import create_generate_report_tool
|
|||
from .resume import create_generate_resume_tool
|
||||
from .scrape_webpage import create_scrape_webpage_tool
|
||||
from .search_surfsense_docs import create_search_surfsense_docs_tool
|
||||
from .teams import (
|
||||
create_list_teams_channels_tool,
|
||||
create_read_teams_messages_tool,
|
||||
create_send_teams_message_tool,
|
||||
)
|
||||
from .update_memory import create_update_memory_tool, create_update_team_memory_tool
|
||||
from .video_presentation import create_generate_video_presentation_tool
|
||||
from .web_search import create_web_search_tool
|
||||
|
|
@ -114,6 +123,8 @@ class ToolDefinition:
|
|||
factory: Callable that creates the tool. Receives a dict of dependencies.
|
||||
requires: List of dependency names this tool needs (e.g., "search_space_id", "db_session")
|
||||
enabled_by_default: Whether the tool is enabled when no explicit config is provided
|
||||
required_connector: Searchable type string (e.g. ``"LINEAR_CONNECTOR"``)
|
||||
that must be in ``available_connectors`` for the tool to be enabled.
|
||||
|
||||
"""
|
||||
|
||||
|
|
@ -123,6 +134,7 @@ class ToolDefinition:
|
|||
requires: list[str] = field(default_factory=list)
|
||||
enabled_by_default: bool = True
|
||||
hidden: bool = False
|
||||
required_connector: str | None = None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
|
|
@ -221,6 +233,21 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
requires=["db_session"],
|
||||
),
|
||||
# =========================================================================
|
||||
# SERVICE ACCOUNT DISCOVERY
|
||||
# Generic tool for the LLM to discover connected accounts and resolve
|
||||
# service-specific identifiers (e.g. Jira cloudId, Slack team, etc.)
|
||||
# =========================================================================
|
||||
ToolDefinition(
|
||||
name="get_connected_accounts",
|
||||
description="Discover connected accounts for a service and their metadata",
|
||||
factory=lambda deps: create_get_connected_accounts_tool(
|
||||
db_session=deps["db_session"],
|
||||
search_space_id=deps["search_space_id"],
|
||||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
),
|
||||
# =========================================================================
|
||||
# MEMORY TOOL - single update_memory, private or team by thread_visibility
|
||||
# =========================================================================
|
||||
ToolDefinition(
|
||||
|
|
@ -248,40 +275,6 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
],
|
||||
),
|
||||
# =========================================================================
|
||||
# LINEAR TOOLS - create, update, delete issues
|
||||
# Auto-disabled when no Linear connector is configured (see chat_deepagent.py)
|
||||
# =========================================================================
|
||||
ToolDefinition(
|
||||
name="create_linear_issue",
|
||||
description="Create a new issue in the user's Linear workspace",
|
||||
factory=lambda deps: create_create_linear_issue_tool(
|
||||
db_session=deps["db_session"],
|
||||
search_space_id=deps["search_space_id"],
|
||||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
),
|
||||
ToolDefinition(
|
||||
name="update_linear_issue",
|
||||
description="Update an existing indexed Linear issue",
|
||||
factory=lambda deps: create_update_linear_issue_tool(
|
||||
db_session=deps["db_session"],
|
||||
search_space_id=deps["search_space_id"],
|
||||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
),
|
||||
ToolDefinition(
|
||||
name="delete_linear_issue",
|
||||
description="Archive (delete) an existing indexed Linear issue",
|
||||
factory=lambda deps: create_delete_linear_issue_tool(
|
||||
db_session=deps["db_session"],
|
||||
search_space_id=deps["search_space_id"],
|
||||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
),
|
||||
# =========================================================================
|
||||
# NOTION TOOLS - create, update, delete pages
|
||||
# Auto-disabled when no Notion connector is configured (see chat_deepagent.py)
|
||||
# =========================================================================
|
||||
|
|
@ -294,6 +287,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
required_connector="NOTION_CONNECTOR",
|
||||
),
|
||||
ToolDefinition(
|
||||
name="update_notion_page",
|
||||
|
|
@ -304,6 +298,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
required_connector="NOTION_CONNECTOR",
|
||||
),
|
||||
ToolDefinition(
|
||||
name="delete_notion_page",
|
||||
|
|
@ -314,6 +309,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
required_connector="NOTION_CONNECTOR",
|
||||
),
|
||||
# =========================================================================
|
||||
# GOOGLE DRIVE TOOLS - create files, delete files
|
||||
|
|
@ -328,6 +324,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
required_connector="GOOGLE_DRIVE_FILE",
|
||||
),
|
||||
ToolDefinition(
|
||||
name="delete_google_drive_file",
|
||||
|
|
@ -338,6 +335,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
required_connector="GOOGLE_DRIVE_FILE",
|
||||
),
|
||||
# =========================================================================
|
||||
# DROPBOX TOOLS - create and trash files
|
||||
|
|
@ -352,6 +350,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
required_connector="DROPBOX_FILE",
|
||||
),
|
||||
ToolDefinition(
|
||||
name="delete_dropbox_file",
|
||||
|
|
@ -362,6 +361,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
required_connector="DROPBOX_FILE",
|
||||
),
|
||||
# =========================================================================
|
||||
# ONEDRIVE TOOLS - create and trash files
|
||||
|
|
@ -376,6 +376,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
required_connector="ONEDRIVE_FILE",
|
||||
),
|
||||
ToolDefinition(
|
||||
name="delete_onedrive_file",
|
||||
|
|
@ -386,11 +387,23 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
required_connector="ONEDRIVE_FILE",
|
||||
),
|
||||
# =========================================================================
|
||||
# GOOGLE CALENDAR TOOLS - create, update, delete events
|
||||
# GOOGLE CALENDAR TOOLS - search, create, update, delete events
|
||||
# Auto-disabled when no Google Calendar connector is configured
|
||||
# =========================================================================
|
||||
ToolDefinition(
|
||||
name="search_calendar_events",
|
||||
description="Search Google Calendar events within a date range",
|
||||
factory=lambda deps: create_search_calendar_events_tool(
|
||||
db_session=deps["db_session"],
|
||||
search_space_id=deps["search_space_id"],
|
||||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
required_connector="GOOGLE_CALENDAR_CONNECTOR",
|
||||
),
|
||||
ToolDefinition(
|
||||
name="create_calendar_event",
|
||||
description="Create a new event on Google Calendar",
|
||||
|
|
@ -400,6 +413,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
required_connector="GOOGLE_CALENDAR_CONNECTOR",
|
||||
),
|
||||
ToolDefinition(
|
||||
name="update_calendar_event",
|
||||
|
|
@ -410,6 +424,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
required_connector="GOOGLE_CALENDAR_CONNECTOR",
|
||||
),
|
||||
ToolDefinition(
|
||||
name="delete_calendar_event",
|
||||
|
|
@ -420,11 +435,34 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
required_connector="GOOGLE_CALENDAR_CONNECTOR",
|
||||
),
|
||||
# =========================================================================
|
||||
# GMAIL TOOLS - create drafts, update drafts, send emails, trash emails
|
||||
# GMAIL TOOLS - search, read, create drafts, update drafts, send, trash
|
||||
# Auto-disabled when no Gmail connector is configured
|
||||
# =========================================================================
|
||||
ToolDefinition(
|
||||
name="search_gmail",
|
||||
description="Search emails in Gmail using Gmail search syntax",
|
||||
factory=lambda deps: create_search_gmail_tool(
|
||||
db_session=deps["db_session"],
|
||||
search_space_id=deps["search_space_id"],
|
||||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
required_connector="GOOGLE_GMAIL_CONNECTOR",
|
||||
),
|
||||
ToolDefinition(
|
||||
name="read_gmail_email",
|
||||
description="Read the full content of a specific Gmail email",
|
||||
factory=lambda deps: create_read_gmail_email_tool(
|
||||
db_session=deps["db_session"],
|
||||
search_space_id=deps["search_space_id"],
|
||||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
required_connector="GOOGLE_GMAIL_CONNECTOR",
|
||||
),
|
||||
ToolDefinition(
|
||||
name="create_gmail_draft",
|
||||
description="Create a draft email in Gmail",
|
||||
|
|
@ -434,6 +472,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
required_connector="GOOGLE_GMAIL_CONNECTOR",
|
||||
),
|
||||
ToolDefinition(
|
||||
name="send_gmail_email",
|
||||
|
|
@ -444,6 +483,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
required_connector="GOOGLE_GMAIL_CONNECTOR",
|
||||
),
|
||||
ToolDefinition(
|
||||
name="trash_gmail_email",
|
||||
|
|
@ -454,6 +494,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
required_connector="GOOGLE_GMAIL_CONNECTOR",
|
||||
),
|
||||
ToolDefinition(
|
||||
name="update_gmail_draft",
|
||||
|
|
@ -464,40 +505,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
),
|
||||
# =========================================================================
|
||||
# JIRA TOOLS - create, update, delete issues
|
||||
# Auto-disabled when no Jira connector is configured (see chat_deepagent.py)
|
||||
# =========================================================================
|
||||
ToolDefinition(
|
||||
name="create_jira_issue",
|
||||
description="Create a new issue in the user's Jira project",
|
||||
factory=lambda deps: create_create_jira_issue_tool(
|
||||
db_session=deps["db_session"],
|
||||
search_space_id=deps["search_space_id"],
|
||||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
),
|
||||
ToolDefinition(
|
||||
name="update_jira_issue",
|
||||
description="Update an existing indexed Jira issue",
|
||||
factory=lambda deps: create_update_jira_issue_tool(
|
||||
db_session=deps["db_session"],
|
||||
search_space_id=deps["search_space_id"],
|
||||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
),
|
||||
ToolDefinition(
|
||||
name="delete_jira_issue",
|
||||
description="Delete an existing indexed Jira issue",
|
||||
factory=lambda deps: create_delete_jira_issue_tool(
|
||||
db_session=deps["db_session"],
|
||||
search_space_id=deps["search_space_id"],
|
||||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
required_connector="GOOGLE_GMAIL_CONNECTOR",
|
||||
),
|
||||
# =========================================================================
|
||||
# CONFLUENCE TOOLS - create, update, delete pages
|
||||
|
|
@ -512,6 +520,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
required_connector="CONFLUENCE_CONNECTOR",
|
||||
),
|
||||
ToolDefinition(
|
||||
name="update_confluence_page",
|
||||
|
|
@ -522,6 +531,7 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
required_connector="CONFLUENCE_CONNECTOR",
|
||||
),
|
||||
ToolDefinition(
|
||||
name="delete_confluence_page",
|
||||
|
|
@ -532,6 +542,118 @@ BUILTIN_TOOLS: list[ToolDefinition] = [
|
|||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
required_connector="CONFLUENCE_CONNECTOR",
|
||||
),
|
||||
# =========================================================================
|
||||
# DISCORD TOOLS - list channels, read messages, send messages
|
||||
# Auto-disabled when no Discord connector is configured
|
||||
# =========================================================================
|
||||
ToolDefinition(
|
||||
name="list_discord_channels",
|
||||
description="List text channels in the connected Discord server",
|
||||
factory=lambda deps: create_list_discord_channels_tool(
|
||||
db_session=deps["db_session"],
|
||||
search_space_id=deps["search_space_id"],
|
||||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
required_connector="DISCORD_CONNECTOR",
|
||||
),
|
||||
ToolDefinition(
|
||||
name="read_discord_messages",
|
||||
description="Read recent messages from a Discord text channel",
|
||||
factory=lambda deps: create_read_discord_messages_tool(
|
||||
db_session=deps["db_session"],
|
||||
search_space_id=deps["search_space_id"],
|
||||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
required_connector="DISCORD_CONNECTOR",
|
||||
),
|
||||
ToolDefinition(
|
||||
name="send_discord_message",
|
||||
description="Send a message to a Discord text channel",
|
||||
factory=lambda deps: create_send_discord_message_tool(
|
||||
db_session=deps["db_session"],
|
||||
search_space_id=deps["search_space_id"],
|
||||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
required_connector="DISCORD_CONNECTOR",
|
||||
),
|
||||
# =========================================================================
|
||||
# TEAMS TOOLS - list channels, read messages, send messages
|
||||
# Auto-disabled when no Teams connector is configured
|
||||
# =========================================================================
|
||||
ToolDefinition(
|
||||
name="list_teams_channels",
|
||||
description="List Microsoft Teams and their channels",
|
||||
factory=lambda deps: create_list_teams_channels_tool(
|
||||
db_session=deps["db_session"],
|
||||
search_space_id=deps["search_space_id"],
|
||||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
required_connector="TEAMS_CONNECTOR",
|
||||
),
|
||||
ToolDefinition(
|
||||
name="read_teams_messages",
|
||||
description="Read recent messages from a Microsoft Teams channel",
|
||||
factory=lambda deps: create_read_teams_messages_tool(
|
||||
db_session=deps["db_session"],
|
||||
search_space_id=deps["search_space_id"],
|
||||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
required_connector="TEAMS_CONNECTOR",
|
||||
),
|
||||
ToolDefinition(
|
||||
name="send_teams_message",
|
||||
description="Send a message to a Microsoft Teams channel",
|
||||
factory=lambda deps: create_send_teams_message_tool(
|
||||
db_session=deps["db_session"],
|
||||
search_space_id=deps["search_space_id"],
|
||||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
required_connector="TEAMS_CONNECTOR",
|
||||
),
|
||||
# =========================================================================
|
||||
# LUMA TOOLS - list events, read event details, create events
|
||||
# Auto-disabled when no Luma connector is configured
|
||||
# =========================================================================
|
||||
ToolDefinition(
|
||||
name="list_luma_events",
|
||||
description="List upcoming and recent Luma events",
|
||||
factory=lambda deps: create_list_luma_events_tool(
|
||||
db_session=deps["db_session"],
|
||||
search_space_id=deps["search_space_id"],
|
||||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
required_connector="LUMA_CONNECTOR",
|
||||
),
|
||||
ToolDefinition(
|
||||
name="read_luma_event",
|
||||
description="Read detailed information about a specific Luma event",
|
||||
factory=lambda deps: create_read_luma_event_tool(
|
||||
db_session=deps["db_session"],
|
||||
search_space_id=deps["search_space_id"],
|
||||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
required_connector="LUMA_CONNECTOR",
|
||||
),
|
||||
ToolDefinition(
|
||||
name="create_luma_event",
|
||||
description="Create a new event on Luma",
|
||||
factory=lambda deps: create_create_luma_event_tool(
|
||||
db_session=deps["db_session"],
|
||||
search_space_id=deps["search_space_id"],
|
||||
user_id=deps["user_id"],
|
||||
),
|
||||
requires=["db_session", "search_space_id", "user_id"],
|
||||
required_connector="LUMA_CONNECTOR",
|
||||
),
|
||||
]
|
||||
|
||||
|
|
@ -549,6 +671,22 @@ def get_tool_by_name(name: str) -> ToolDefinition | None:
|
|||
return None
|
||||
|
||||
|
||||
def get_connector_gated_tools(
|
||||
available_connectors: list[str] | None,
|
||||
) -> list[str]:
|
||||
"""Return tool names to disable"""
|
||||
if available_connectors is None:
|
||||
available = set()
|
||||
else:
|
||||
available = set(available_connectors)
|
||||
|
||||
disabled: list[str] = []
|
||||
for tool_def in BUILTIN_TOOLS:
|
||||
if tool_def.required_connector and tool_def.required_connector not in available:
|
||||
disabled.append(tool_def.name)
|
||||
return disabled
|
||||
|
||||
|
||||
def get_all_tool_names() -> list[str]:
|
||||
"""Get names of all registered tools."""
|
||||
return [tool_def.name for tool_def in BUILTIN_TOOLS]
|
||||
|
|
@ -690,15 +828,15 @@ async def build_tools_async(
|
|||
)
|
||||
tools.extend(mcp_tools)
|
||||
logging.info(
|
||||
f"Registered {len(mcp_tools)} MCP tools: {[t.name for t in mcp_tools]}",
|
||||
"Registered %d MCP tools: %s",
|
||||
len(mcp_tools), [t.name for t in mcp_tools],
|
||||
)
|
||||
except Exception as e:
|
||||
# Log error but don't fail - just continue without MCP tools
|
||||
logging.exception(f"Failed to load MCP tools: {e!s}")
|
||||
logging.exception("Failed to load MCP tools: %s", e)
|
||||
|
||||
# Log all tools being returned to agent
|
||||
logging.info(
|
||||
f"Total tools for agent: {len(tools)} - {[t.name for t in tools]}",
|
||||
"Total tools for agent: %d — %s",
|
||||
len(tools), [t.name for t in tools],
|
||||
)
|
||||
|
||||
return tools
|
||||
|
|
|
|||
|
|
@ -0,0 +1,15 @@
|
|||
from app.agents.new_chat.tools.teams.list_channels import (
|
||||
create_list_teams_channels_tool,
|
||||
)
|
||||
from app.agents.new_chat.tools.teams.read_messages import (
|
||||
create_read_teams_messages_tool,
|
||||
)
|
||||
from app.agents.new_chat.tools.teams.send_message import (
|
||||
create_send_teams_message_tool,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"create_list_teams_channels_tool",
|
||||
"create_read_teams_messages_tool",
|
||||
"create_send_teams_message_tool",
|
||||
]
|
||||
37
surfsense_backend/app/agents/new_chat/tools/teams/_auth.py
Normal file
37
surfsense_backend/app/agents/new_chat/tools/teams/_auth.py
Normal file
|
|
@ -0,0 +1,37 @@
|
|||
"""Shared auth helper for Teams agent tools (Microsoft Graph REST API)."""
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.db import SearchSourceConnector, SearchSourceConnectorType
|
||||
|
||||
GRAPH_API = "https://graph.microsoft.com/v1.0"
|
||||
|
||||
|
||||
async def get_teams_connector(
|
||||
db_session: AsyncSession,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
) -> SearchSourceConnector | None:
|
||||
result = await db_session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.search_space_id == search_space_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.connector_type == SearchSourceConnectorType.TEAMS_CONNECTOR,
|
||||
)
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
|
||||
async def get_access_token(
|
||||
db_session: AsyncSession,
|
||||
connector: SearchSourceConnector,
|
||||
) -> str:
|
||||
"""Get a valid Microsoft Graph access token, refreshing if expired."""
|
||||
from app.connectors.teams_connector import TeamsConnector
|
||||
|
||||
tc = TeamsConnector(
|
||||
session=db_session,
|
||||
connector_id=connector.id,
|
||||
)
|
||||
return await tc._get_valid_token()
|
||||
|
|
@ -0,0 +1,77 @@
|
|||
import logging
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from langchain_core.tools import tool
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from ._auth import GRAPH_API, get_access_token, get_teams_connector
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_list_teams_channels_tool(
|
||||
db_session: AsyncSession | None = None,
|
||||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
@tool
|
||||
async def list_teams_channels() -> dict[str, Any]:
|
||||
"""List all Microsoft Teams and their channels the user has access to.
|
||||
|
||||
Returns:
|
||||
Dictionary with status and a list of teams, each containing
|
||||
team_id, team_name, and a list of channels (id, name).
|
||||
"""
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
return {"status": "error", "message": "Teams tool not properly configured."}
|
||||
|
||||
try:
|
||||
connector = await get_teams_connector(db_session, search_space_id, user_id)
|
||||
if not connector:
|
||||
return {"status": "error", "message": "No Teams connector found."}
|
||||
|
||||
token = await get_access_token(db_session, connector)
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
|
||||
async with httpx.AsyncClient(timeout=20.0) as client:
|
||||
teams_resp = await client.get(f"{GRAPH_API}/me/joinedTeams", headers=headers)
|
||||
|
||||
if teams_resp.status_code == 401:
|
||||
return {"status": "auth_error", "message": "Teams token expired. Please re-authenticate.", "connector_type": "teams"}
|
||||
if teams_resp.status_code != 200:
|
||||
return {"status": "error", "message": f"Graph API error: {teams_resp.status_code}"}
|
||||
|
||||
teams_data = teams_resp.json().get("value", [])
|
||||
result_teams = []
|
||||
|
||||
async with httpx.AsyncClient(timeout=20.0) as client:
|
||||
for team in teams_data:
|
||||
team_id = team["id"]
|
||||
ch_resp = await client.get(
|
||||
f"{GRAPH_API}/teams/{team_id}/channels",
|
||||
headers=headers,
|
||||
)
|
||||
channels = []
|
||||
if ch_resp.status_code == 200:
|
||||
channels = [
|
||||
{"id": ch["id"], "name": ch.get("displayName", "")}
|
||||
for ch in ch_resp.json().get("value", [])
|
||||
]
|
||||
result_teams.append({
|
||||
"team_id": team_id,
|
||||
"team_name": team.get("displayName", ""),
|
||||
"channels": channels,
|
||||
})
|
||||
|
||||
return {"status": "success", "teams": result_teams, "total_teams": len(result_teams)}
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
||||
if isinstance(e, GraphInterrupt):
|
||||
raise
|
||||
logger.error("Error listing Teams channels: %s", e, exc_info=True)
|
||||
return {"status": "error", "message": "Failed to list Teams channels."}
|
||||
|
||||
return list_teams_channels
|
||||
|
|
@ -0,0 +1,91 @@
|
|||
import logging
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from langchain_core.tools import tool
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from ._auth import GRAPH_API, get_access_token, get_teams_connector
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_read_teams_messages_tool(
|
||||
db_session: AsyncSession | None = None,
|
||||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
@tool
|
||||
async def read_teams_messages(
|
||||
team_id: str,
|
||||
channel_id: str,
|
||||
limit: int = 25,
|
||||
) -> dict[str, Any]:
|
||||
"""Read recent messages from a Microsoft Teams channel.
|
||||
|
||||
Args:
|
||||
team_id: The team ID (from list_teams_channels).
|
||||
channel_id: The channel ID (from list_teams_channels).
|
||||
limit: Number of messages to fetch (default 25, max 50).
|
||||
|
||||
Returns:
|
||||
Dictionary with status and a list of messages including
|
||||
id, sender, content, timestamp.
|
||||
"""
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
return {"status": "error", "message": "Teams tool not properly configured."}
|
||||
|
||||
limit = min(limit, 50)
|
||||
|
||||
try:
|
||||
connector = await get_teams_connector(db_session, search_space_id, user_id)
|
||||
if not connector:
|
||||
return {"status": "error", "message": "No Teams connector found."}
|
||||
|
||||
token = await get_access_token(db_session, connector)
|
||||
|
||||
async with httpx.AsyncClient(timeout=20.0) as client:
|
||||
resp = await client.get(
|
||||
f"{GRAPH_API}/teams/{team_id}/channels/{channel_id}/messages",
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
params={"$top": limit},
|
||||
)
|
||||
|
||||
if resp.status_code == 401:
|
||||
return {"status": "auth_error", "message": "Teams token expired. Please re-authenticate.", "connector_type": "teams"}
|
||||
if resp.status_code == 403:
|
||||
return {"status": "error", "message": "Insufficient permissions to read this channel."}
|
||||
if resp.status_code != 200:
|
||||
return {"status": "error", "message": f"Graph API error: {resp.status_code}"}
|
||||
|
||||
raw_msgs = resp.json().get("value", [])
|
||||
messages = []
|
||||
for m in raw_msgs:
|
||||
sender = m.get("from", {})
|
||||
user_info = sender.get("user", {}) if sender else {}
|
||||
body = m.get("body", {})
|
||||
messages.append({
|
||||
"id": m.get("id"),
|
||||
"sender": user_info.get("displayName", "Unknown"),
|
||||
"content": body.get("content", ""),
|
||||
"content_type": body.get("contentType", "text"),
|
||||
"timestamp": m.get("createdDateTime", ""),
|
||||
})
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"team_id": team_id,
|
||||
"channel_id": channel_id,
|
||||
"messages": messages,
|
||||
"total": len(messages),
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
||||
if isinstance(e, GraphInterrupt):
|
||||
raise
|
||||
logger.error("Error reading Teams messages: %s", e, exc_info=True)
|
||||
return {"status": "error", "message": "Failed to read Teams messages."}
|
||||
|
||||
return read_teams_messages
|
||||
|
|
@ -0,0 +1,101 @@
|
|||
import logging
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from langchain_core.tools import tool
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.new_chat.tools.hitl import request_approval
|
||||
|
||||
from ._auth import GRAPH_API, get_access_token, get_teams_connector
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_send_teams_message_tool(
|
||||
db_session: AsyncSession | None = None,
|
||||
search_space_id: int | None = None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
@tool
|
||||
async def send_teams_message(
|
||||
team_id: str,
|
||||
channel_id: str,
|
||||
content: str,
|
||||
) -> dict[str, Any]:
|
||||
"""Send a message to a Microsoft Teams channel.
|
||||
|
||||
Requires the ChannelMessage.Send OAuth scope. If the user gets a
|
||||
permission error, they may need to re-authenticate with updated scopes.
|
||||
|
||||
Args:
|
||||
team_id: The team ID (from list_teams_channels).
|
||||
channel_id: The channel ID (from list_teams_channels).
|
||||
content: The message text (HTML supported).
|
||||
|
||||
Returns:
|
||||
Dictionary with status, message_id on success.
|
||||
|
||||
IMPORTANT:
|
||||
- If status is "rejected", the user explicitly declined. Do NOT retry.
|
||||
"""
|
||||
if db_session is None or search_space_id is None or user_id is None:
|
||||
return {"status": "error", "message": "Teams tool not properly configured."}
|
||||
|
||||
try:
|
||||
connector = await get_teams_connector(db_session, search_space_id, user_id)
|
||||
if not connector:
|
||||
return {"status": "error", "message": "No Teams connector found."}
|
||||
|
||||
result = request_approval(
|
||||
action_type="teams_send_message",
|
||||
tool_name="send_teams_message",
|
||||
params={"team_id": team_id, "channel_id": channel_id, "content": content},
|
||||
context={"connector_id": connector.id},
|
||||
)
|
||||
|
||||
if result.rejected:
|
||||
return {"status": "rejected", "message": "User declined. Message was not sent."}
|
||||
|
||||
final_content = result.params.get("content", content)
|
||||
final_team = result.params.get("team_id", team_id)
|
||||
final_channel = result.params.get("channel_id", channel_id)
|
||||
|
||||
token = await get_access_token(db_session, connector)
|
||||
|
||||
async with httpx.AsyncClient(timeout=20.0) as client:
|
||||
resp = await client.post(
|
||||
f"{GRAPH_API}/teams/{final_team}/channels/{final_channel}/messages",
|
||||
headers={
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
json={"body": {"content": final_content}},
|
||||
)
|
||||
|
||||
if resp.status_code == 401:
|
||||
return {"status": "auth_error", "message": "Teams token expired. Please re-authenticate.", "connector_type": "teams"}
|
||||
if resp.status_code == 403:
|
||||
return {
|
||||
"status": "insufficient_permissions",
|
||||
"message": "Missing ChannelMessage.Send permission. Please re-authenticate with updated scopes.",
|
||||
}
|
||||
if resp.status_code not in (200, 201):
|
||||
return {"status": "error", "message": f"Graph API error: {resp.status_code} — {resp.text[:200]}"}
|
||||
|
||||
msg_data = resp.json()
|
||||
return {
|
||||
"status": "success",
|
||||
"message_id": msg_data.get("id"),
|
||||
"message": f"Message sent to Teams channel.",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
from langgraph.errors import GraphInterrupt
|
||||
|
||||
if isinstance(e, GraphInterrupt):
|
||||
raise
|
||||
logger.error("Error sending Teams message: %s", e, exc_info=True)
|
||||
return {"status": "error", "message": "Failed to send Teams message."}
|
||||
|
||||
return send_teams_message
|
||||
41
surfsense_backend/app/agents/new_chat/tools/tool_response.py
Normal file
41
surfsense_backend/app/agents/new_chat/tools/tool_response.py
Normal file
|
|
@ -0,0 +1,41 @@
|
|||
"""Standardised response dict factories for LangChain agent tools."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
class ToolResponse:
|
||||
|
||||
@staticmethod
|
||||
def success(message: str, **data: Any) -> dict[str, Any]:
|
||||
return {"status": "success", "message": message, **data}
|
||||
|
||||
@staticmethod
|
||||
def error(error: str, **data: Any) -> dict[str, Any]:
|
||||
return {"status": "error", "error": error, **data}
|
||||
|
||||
@staticmethod
|
||||
def auth_error(service: str, **data: Any) -> dict[str, Any]:
|
||||
return {
|
||||
"status": "auth_error",
|
||||
"error": (
|
||||
f"{service} authentication has expired or been revoked. "
|
||||
"Please re-connect the integration in Settings → Connectors."
|
||||
),
|
||||
**data,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def rejected(message: str = "Action was declined by the user.") -> dict[str, Any]:
|
||||
return {"status": "rejected", "message": message}
|
||||
|
||||
@staticmethod
|
||||
def not_found(
|
||||
resource: str, identifier: str, **data: Any
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
"status": "not_found",
|
||||
"error": f"{resource} '{identifier}' was not found.",
|
||||
**data,
|
||||
}
|
||||
|
|
@ -135,20 +135,12 @@ celery_app.conf.update(
|
|||
# never block fast user-facing tasks (file uploads, podcasts, etc.)
|
||||
task_routes={
|
||||
# Connector indexing tasks → connectors queue
|
||||
"index_slack_messages": {"queue": CONNECTORS_QUEUE},
|
||||
"index_notion_pages": {"queue": CONNECTORS_QUEUE},
|
||||
"index_github_repos": {"queue": CONNECTORS_QUEUE},
|
||||
"index_linear_issues": {"queue": CONNECTORS_QUEUE},
|
||||
"index_jira_issues": {"queue": CONNECTORS_QUEUE},
|
||||
"index_confluence_pages": {"queue": CONNECTORS_QUEUE},
|
||||
"index_clickup_tasks": {"queue": CONNECTORS_QUEUE},
|
||||
"index_google_calendar_events": {"queue": CONNECTORS_QUEUE},
|
||||
"index_airtable_records": {"queue": CONNECTORS_QUEUE},
|
||||
"index_google_gmail_messages": {"queue": CONNECTORS_QUEUE},
|
||||
"index_google_drive_files": {"queue": CONNECTORS_QUEUE},
|
||||
"index_discord_messages": {"queue": CONNECTORS_QUEUE},
|
||||
"index_teams_messages": {"queue": CONNECTORS_QUEUE},
|
||||
"index_luma_events": {"queue": CONNECTORS_QUEUE},
|
||||
"index_elasticsearch_documents": {"queue": CONNECTORS_QUEUE},
|
||||
"index_crawled_urls": {"queue": CONNECTORS_QUEUE},
|
||||
"index_bookstack_pages": {"queue": CONNECTORS_QUEUE},
|
||||
|
|
|
|||
98
surfsense_backend/app/connectors/exceptions.py
Normal file
98
surfsense_backend/app/connectors/exceptions.py
Normal file
|
|
@ -0,0 +1,98 @@
|
|||
"""Standard exception hierarchy for all connectors.
|
||||
|
||||
ConnectorError
|
||||
├── ConnectorAuthError (401/403 — non-retryable)
|
||||
├── ConnectorRateLimitError (429 — retryable, carries ``retry_after``)
|
||||
├── ConnectorTimeoutError (timeout/504 — retryable)
|
||||
└── ConnectorAPIError (5xx or unexpected — retryable when >= 500)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
class ConnectorError(Exception):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
*,
|
||||
service: str = "",
|
||||
status_code: int | None = None,
|
||||
response_body: Any = None,
|
||||
) -> None:
|
||||
super().__init__(message)
|
||||
self.service = service
|
||||
self.status_code = status_code
|
||||
self.response_body = response_body
|
||||
|
||||
@property
|
||||
def retryable(self) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
class ConnectorAuthError(ConnectorError):
|
||||
"""Token expired, revoked, insufficient scopes, or needs re-auth (401/403)."""
|
||||
|
||||
@property
|
||||
def retryable(self) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
class ConnectorRateLimitError(ConnectorError):
|
||||
"""429 Too Many Requests."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Rate limited",
|
||||
*,
|
||||
service: str = "",
|
||||
retry_after: float | None = None,
|
||||
status_code: int = 429,
|
||||
response_body: Any = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
message,
|
||||
service=service,
|
||||
status_code=status_code,
|
||||
response_body=response_body,
|
||||
)
|
||||
self.retry_after = retry_after
|
||||
|
||||
@property
|
||||
def retryable(self) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
class ConnectorTimeoutError(ConnectorError):
|
||||
"""Request timeout or gateway timeout (504)."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "Request timed out",
|
||||
*,
|
||||
service: str = "",
|
||||
status_code: int | None = None,
|
||||
response_body: Any = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
message,
|
||||
service=service,
|
||||
status_code=status_code,
|
||||
response_body=response_body,
|
||||
)
|
||||
|
||||
@property
|
||||
def retryable(self) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
class ConnectorAPIError(ConnectorError):
|
||||
"""Generic API error (5xx or unexpected status codes)."""
|
||||
|
||||
@property
|
||||
def retryable(self) -> bool:
|
||||
if self.status_code is not None:
|
||||
return self.status_code >= 500
|
||||
return False
|
||||
|
|
@ -30,6 +30,7 @@ from .jira_add_connector_route import router as jira_add_connector_router
|
|||
from .linear_add_connector_route import router as linear_add_connector_router
|
||||
from .logs_routes import router as logs_router
|
||||
from .luma_add_connector_route import router as luma_add_connector_router
|
||||
from .mcp_oauth_route import router as mcp_oauth_router
|
||||
from .memory_routes import router as memory_router
|
||||
from .model_list_routes import router as model_list_router
|
||||
from .new_chat_routes import router as new_chat_router
|
||||
|
|
@ -95,6 +96,7 @@ router.include_router(logs_router)
|
|||
router.include_router(circleback_webhook_router) # Circleback meeting webhooks
|
||||
router.include_router(surfsense_docs_router) # Surfsense documentation for citations
|
||||
router.include_router(notifications_router) # Notifications with Zero sync
|
||||
router.include_router(mcp_oauth_router) # MCP OAuth 2.1 for Linear, Jira, ClickUp, Slack, Airtable
|
||||
router.include_router(composio_router) # Composio OAuth and toolkit management
|
||||
router.include_router(public_chat_router) # Public chat sharing and cloning
|
||||
router.include_router(incentive_tasks_router) # Incentive tasks for earning free pages
|
||||
|
|
|
|||
|
|
@ -311,7 +311,7 @@ async def airtable_callback(
|
|||
new_connector = SearchSourceConnector(
|
||||
name=connector_name,
|
||||
connector_type=SearchSourceConnectorType.AIRTABLE_CONNECTOR,
|
||||
is_indexable=True,
|
||||
is_indexable=False,
|
||||
config=credentials_dict,
|
||||
search_space_id=space_id,
|
||||
user_id=user_id,
|
||||
|
|
|
|||
|
|
@ -301,7 +301,7 @@ async def clickup_callback(
|
|||
# Update existing connector
|
||||
existing_connector.config = connector_config
|
||||
existing_connector.name = "ClickUp Connector"
|
||||
existing_connector.is_indexable = True
|
||||
existing_connector.is_indexable = False
|
||||
logger.info(
|
||||
f"Updated existing ClickUp connector for user {user_id} in space {space_id}"
|
||||
)
|
||||
|
|
@ -310,7 +310,7 @@ async def clickup_callback(
|
|||
new_connector = SearchSourceConnector(
|
||||
name="ClickUp Connector",
|
||||
connector_type=SearchSourceConnectorType.CLICKUP_CONNECTOR,
|
||||
is_indexable=True,
|
||||
is_indexable=False,
|
||||
config=connector_config,
|
||||
search_space_id=space_id,
|
||||
user_id=user_id,
|
||||
|
|
|
|||
|
|
@ -326,7 +326,7 @@ async def discord_callback(
|
|||
new_connector = SearchSourceConnector(
|
||||
name=connector_name,
|
||||
connector_type=SearchSourceConnectorType.DISCORD_CONNECTOR,
|
||||
is_indexable=True,
|
||||
is_indexable=False,
|
||||
config=connector_config,
|
||||
search_space_id=space_id,
|
||||
user_id=user_id,
|
||||
|
|
|
|||
|
|
@ -340,7 +340,7 @@ async def calendar_callback(
|
|||
config=creds_dict,
|
||||
search_space_id=space_id,
|
||||
user_id=user_id,
|
||||
is_indexable=True,
|
||||
is_indexable=False,
|
||||
)
|
||||
session.add(db_connector)
|
||||
await session.commit()
|
||||
|
|
|
|||
|
|
@ -371,7 +371,7 @@ async def gmail_callback(
|
|||
config=creds_dict,
|
||||
search_space_id=space_id,
|
||||
user_id=user_id,
|
||||
is_indexable=True,
|
||||
is_indexable=False,
|
||||
)
|
||||
session.add(db_connector)
|
||||
await session.commit()
|
||||
|
|
|
|||
|
|
@ -386,7 +386,7 @@ async def jira_callback(
|
|||
new_connector = SearchSourceConnector(
|
||||
name=connector_name,
|
||||
connector_type=SearchSourceConnectorType.JIRA_CONNECTOR,
|
||||
is_indexable=True,
|
||||
is_indexable=False,
|
||||
config=connector_config,
|
||||
search_space_id=space_id,
|
||||
user_id=user_id,
|
||||
|
|
|
|||
|
|
@ -399,7 +399,7 @@ async def linear_callback(
|
|||
new_connector = SearchSourceConnector(
|
||||
name=connector_name,
|
||||
connector_type=SearchSourceConnectorType.LINEAR_CONNECTOR,
|
||||
is_indexable=True,
|
||||
is_indexable=False,
|
||||
config=connector_config,
|
||||
search_space_id=space_id,
|
||||
user_id=user_id,
|
||||
|
|
|
|||
|
|
@ -61,7 +61,7 @@ async def add_luma_connector(
|
|||
if existing_connector:
|
||||
# Update existing connector with new API key
|
||||
existing_connector.config = {"api_key": request.api_key}
|
||||
existing_connector.is_indexable = True
|
||||
existing_connector.is_indexable = False
|
||||
await session.commit()
|
||||
await session.refresh(existing_connector)
|
||||
|
||||
|
|
@ -82,7 +82,7 @@ async def add_luma_connector(
|
|||
config={"api_key": request.api_key},
|
||||
search_space_id=request.space_id,
|
||||
user_id=user.id,
|
||||
is_indexable=True,
|
||||
is_indexable=False,
|
||||
)
|
||||
|
||||
session.add(db_connector)
|
||||
|
|
|
|||
601
surfsense_backend/app/routes/mcp_oauth_route.py
Normal file
601
surfsense_backend/app/routes/mcp_oauth_route.py
Normal file
|
|
@ -0,0 +1,601 @@
|
|||
"""Generic MCP OAuth 2.1 route for services with official MCP servers.
|
||||
|
||||
Handles the full flow: discovery → DCR → PKCE authorization → token exchange
|
||||
→ MCP_CONNECTOR creation. Currently supports Linear, Jira, ClickUp, Slack,
|
||||
and Airtable.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Any
|
||||
from urllib.parse import urlencode
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import RedirectResponse
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
from app.config import config
|
||||
from app.db import (
|
||||
SearchSourceConnector,
|
||||
SearchSourceConnectorType,
|
||||
User,
|
||||
get_async_session,
|
||||
)
|
||||
from app.users import current_active_user
|
||||
from app.utils.connector_naming import generate_unique_connector_name
|
||||
from app.utils.oauth_security import OAuthStateManager, TokenEncryption, generate_pkce_pair
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
async def _fetch_account_metadata(
|
||||
service_key: str, access_token: str, token_json: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
"""Fetch display-friendly account metadata after a successful token exchange.
|
||||
|
||||
DCR services (Linear, Jira, ClickUp) issue MCP-scoped tokens that cannot
|
||||
call their standard REST/GraphQL APIs — metadata discovery for those
|
||||
happens at runtime through MCP tools instead.
|
||||
|
||||
Pre-configured services (Slack, Airtable) use standard OAuth tokens that
|
||||
*can* call their APIs, so we extract metadata here.
|
||||
|
||||
Failures are logged but never block connector creation.
|
||||
"""
|
||||
from app.services.mcp_oauth.registry import MCP_SERVICES
|
||||
|
||||
svc = MCP_SERVICES.get(service_key)
|
||||
if not svc or svc.supports_dcr:
|
||||
return {}
|
||||
|
||||
import httpx
|
||||
|
||||
meta: dict[str, Any] = {}
|
||||
|
||||
try:
|
||||
if service_key == "slack":
|
||||
team_info = token_json.get("team", {})
|
||||
meta["team_id"] = team_info.get("id", "")
|
||||
# TODO: oauth.v2.user.access only returns team.id, not
|
||||
# team.name. To populate team_name, add "team:read" scope
|
||||
# and call GET /api/team.info here.
|
||||
meta["team_name"] = team_info.get("name", "")
|
||||
if meta["team_name"]:
|
||||
meta["display_name"] = meta["team_name"]
|
||||
elif meta["team_id"]:
|
||||
meta["display_name"] = f"Slack ({meta['team_id']})"
|
||||
|
||||
elif service_key == "airtable":
|
||||
async with httpx.AsyncClient(timeout=15.0) as client:
|
||||
resp = await client.get(
|
||||
"https://api.airtable.com/v0/meta/whoami",
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
)
|
||||
if resp.status_code == 200:
|
||||
whoami = resp.json()
|
||||
meta["user_id"] = whoami.get("id", "")
|
||||
meta["user_email"] = whoami.get("email", "")
|
||||
meta["display_name"] = whoami.get("email", "Airtable")
|
||||
else:
|
||||
logger.warning(
|
||||
"Airtable whoami API returned %d (non-blocking)", resp.status_code,
|
||||
)
|
||||
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to fetch account metadata for %s (non-blocking)",
|
||||
service_key,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
return meta
|
||||
|
||||
_state_manager: OAuthStateManager | None = None
|
||||
_token_encryption: TokenEncryption | None = None
|
||||
|
||||
|
||||
def _get_state_manager() -> OAuthStateManager:
|
||||
global _state_manager
|
||||
if _state_manager is None:
|
||||
if not config.SECRET_KEY:
|
||||
raise HTTPException(status_code=500, detail="SECRET_KEY not configured.")
|
||||
_state_manager = OAuthStateManager(config.SECRET_KEY)
|
||||
return _state_manager
|
||||
|
||||
|
||||
def _get_token_encryption() -> TokenEncryption:
|
||||
global _token_encryption
|
||||
if _token_encryption is None:
|
||||
if not config.SECRET_KEY:
|
||||
raise HTTPException(status_code=500, detail="SECRET_KEY not configured.")
|
||||
_token_encryption = TokenEncryption(config.SECRET_KEY)
|
||||
return _token_encryption
|
||||
|
||||
|
||||
def _build_redirect_uri(service: str) -> str:
|
||||
base = config.BACKEND_URL or "http://localhost:8000"
|
||||
return f"{base.rstrip('/')}/api/v1/auth/mcp/{service}/connector/callback"
|
||||
|
||||
|
||||
def _frontend_redirect(
|
||||
space_id: int | None,
|
||||
*,
|
||||
success: bool = False,
|
||||
connector_id: int | None = None,
|
||||
error: str | None = None,
|
||||
service: str = "mcp",
|
||||
) -> RedirectResponse:
|
||||
if success and space_id:
|
||||
qs = f"success=true&connector={service}-mcp-connector"
|
||||
if connector_id:
|
||||
qs += f"&connectorId={connector_id}"
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?{qs}"
|
||||
)
|
||||
if error and space_id:
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error={error}"
|
||||
)
|
||||
return RedirectResponse(url=f"{config.NEXT_FRONTEND_URL}/dashboard")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# /add — start MCP OAuth flow
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.get("/auth/mcp/{service}/connector/add")
|
||||
async def connect_mcp_service(
|
||||
service: str,
|
||||
space_id: int,
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
from app.services.mcp_oauth.registry import get_service
|
||||
|
||||
svc = get_service(service)
|
||||
if not svc:
|
||||
raise HTTPException(status_code=404, detail=f"Unknown MCP service: {service}")
|
||||
|
||||
try:
|
||||
from app.services.mcp_oauth.discovery import (
|
||||
discover_oauth_metadata,
|
||||
register_client,
|
||||
)
|
||||
|
||||
metadata = await discover_oauth_metadata(
|
||||
svc.mcp_url, origin_override=svc.oauth_discovery_origin,
|
||||
)
|
||||
auth_endpoint = svc.auth_endpoint_override or metadata.get("authorization_endpoint")
|
||||
token_endpoint = svc.token_endpoint_override or metadata.get("token_endpoint")
|
||||
registration_endpoint = metadata.get("registration_endpoint")
|
||||
|
||||
if not auth_endpoint or not token_endpoint:
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail=f"{svc.name} MCP server returned incomplete OAuth metadata.",
|
||||
)
|
||||
|
||||
redirect_uri = _build_redirect_uri(service)
|
||||
|
||||
if svc.supports_dcr and registration_endpoint:
|
||||
dcr = await register_client(registration_endpoint, redirect_uri)
|
||||
client_id = dcr.get("client_id")
|
||||
client_secret = dcr.get("client_secret", "")
|
||||
if not client_id:
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail=f"DCR for {svc.name} did not return a client_id.",
|
||||
)
|
||||
elif svc.client_id_env:
|
||||
client_id = getattr(config, svc.client_id_env, None)
|
||||
client_secret = getattr(config, svc.client_secret_env or "", None) or ""
|
||||
if not client_id:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"{svc.name} MCP OAuth not configured ({svc.client_id_env}).",
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail=f"{svc.name} MCP server has no DCR and no fallback credentials.",
|
||||
)
|
||||
|
||||
verifier, challenge = generate_pkce_pair()
|
||||
enc = _get_token_encryption()
|
||||
|
||||
state = _get_state_manager().generate_secure_state(
|
||||
space_id,
|
||||
user.id,
|
||||
service=service,
|
||||
code_verifier=verifier,
|
||||
mcp_client_id=client_id,
|
||||
mcp_client_secret=enc.encrypt_token(client_secret) if client_secret else "",
|
||||
mcp_token_endpoint=token_endpoint,
|
||||
mcp_url=svc.mcp_url,
|
||||
)
|
||||
|
||||
auth_params: dict[str, str] = {
|
||||
"client_id": client_id,
|
||||
"response_type": "code",
|
||||
"redirect_uri": redirect_uri,
|
||||
"code_challenge": challenge,
|
||||
"code_challenge_method": "S256",
|
||||
"state": state,
|
||||
}
|
||||
if svc.scopes:
|
||||
auth_params[svc.scope_param] = " ".join(svc.scopes)
|
||||
|
||||
auth_url = f"{auth_endpoint}?{urlencode(auth_params)}"
|
||||
|
||||
logger.info(
|
||||
"Generated %s MCP OAuth URL for user %s, space %s",
|
||||
svc.name, user.id, space_id,
|
||||
)
|
||||
return {"auth_url": auth_url}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Failed to initiate %s MCP OAuth: %s", service, e, exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to initiate {service} MCP OAuth.",
|
||||
) from e
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# /callback — handle OAuth redirect
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.get("/auth/mcp/{service}/connector/callback")
|
||||
async def mcp_oauth_callback(
|
||||
service: str,
|
||||
code: str | None = None,
|
||||
error: str | None = None,
|
||||
state: str | None = None,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
if error:
|
||||
logger.warning("%s MCP OAuth error: %s", service, error)
|
||||
space_id = None
|
||||
if state:
|
||||
try:
|
||||
data = _get_state_manager().validate_state(state)
|
||||
space_id = data.get("space_id")
|
||||
except Exception:
|
||||
pass
|
||||
return _frontend_redirect(
|
||||
space_id, error=f"{service}_mcp_oauth_denied", service=service,
|
||||
)
|
||||
|
||||
if not code:
|
||||
raise HTTPException(status_code=400, detail="Missing authorization code")
|
||||
if not state:
|
||||
raise HTTPException(status_code=400, detail="Missing state parameter")
|
||||
|
||||
data = _get_state_manager().validate_state(state)
|
||||
user_id = UUID(data["user_id"])
|
||||
space_id = data["space_id"]
|
||||
svc_key = data.get("service", service)
|
||||
|
||||
if svc_key != service:
|
||||
raise HTTPException(status_code=400, detail="State/path service mismatch")
|
||||
|
||||
from app.services.mcp_oauth.registry import get_service
|
||||
|
||||
svc = get_service(svc_key)
|
||||
if not svc:
|
||||
raise HTTPException(status_code=404, detail=f"Unknown MCP service: {svc_key}")
|
||||
|
||||
try:
|
||||
from app.services.mcp_oauth.discovery import exchange_code_for_tokens
|
||||
|
||||
enc = _get_token_encryption()
|
||||
client_id = data["mcp_client_id"]
|
||||
client_secret = (
|
||||
enc.decrypt_token(data["mcp_client_secret"])
|
||||
if data.get("mcp_client_secret")
|
||||
else ""
|
||||
)
|
||||
token_endpoint = data["mcp_token_endpoint"]
|
||||
code_verifier = data["code_verifier"]
|
||||
mcp_url = data["mcp_url"]
|
||||
redirect_uri = _build_redirect_uri(service)
|
||||
|
||||
token_json = await exchange_code_for_tokens(
|
||||
token_endpoint=token_endpoint,
|
||||
code=code,
|
||||
redirect_uri=redirect_uri,
|
||||
client_id=client_id,
|
||||
client_secret=client_secret,
|
||||
code_verifier=code_verifier,
|
||||
)
|
||||
|
||||
access_token = token_json.get("access_token")
|
||||
refresh_token = token_json.get("refresh_token")
|
||||
expires_in = token_json.get("expires_in")
|
||||
scope = token_json.get("scope")
|
||||
|
||||
if not access_token and "authed_user" in token_json:
|
||||
authed = token_json["authed_user"]
|
||||
access_token = authed.get("access_token")
|
||||
refresh_token = refresh_token or authed.get("refresh_token")
|
||||
scope = scope or authed.get("scope")
|
||||
expires_in = expires_in or authed.get("expires_in")
|
||||
|
||||
if not access_token:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"No access token received from {svc.name}.",
|
||||
)
|
||||
|
||||
expires_at = None
|
||||
if expires_in:
|
||||
expires_at = datetime.now(UTC) + timedelta(
|
||||
seconds=int(expires_in)
|
||||
)
|
||||
|
||||
connector_config = {
|
||||
"server_config": {
|
||||
"transport": "streamable-http",
|
||||
"url": mcp_url,
|
||||
},
|
||||
"mcp_service": svc_key,
|
||||
"mcp_oauth": {
|
||||
"client_id": client_id,
|
||||
"client_secret": enc.encrypt_token(client_secret) if client_secret else "",
|
||||
"token_endpoint": token_endpoint,
|
||||
"access_token": enc.encrypt_token(access_token),
|
||||
"refresh_token": enc.encrypt_token(refresh_token) if refresh_token else None,
|
||||
"expires_at": expires_at.isoformat() if expires_at else None,
|
||||
"scope": scope,
|
||||
},
|
||||
"_token_encrypted": True,
|
||||
}
|
||||
|
||||
account_meta = await _fetch_account_metadata(svc_key, access_token, token_json)
|
||||
if account_meta:
|
||||
_SAFE_META_KEYS = {"display_name", "team_id", "team_name", "user_id", "user_email",
|
||||
"workspace_id", "workspace_name", "organization_name",
|
||||
"organization_url_key", "cloud_id", "site_name", "base_url"}
|
||||
for k, v in account_meta.items():
|
||||
if k in _SAFE_META_KEYS:
|
||||
connector_config[k] = v
|
||||
logger.info(
|
||||
"Stored account metadata for %s: display_name=%s",
|
||||
svc_key, account_meta.get("display_name", ""),
|
||||
)
|
||||
|
||||
# ---- Re-auth path ----
|
||||
db_connector_type = SearchSourceConnectorType(svc.connector_type)
|
||||
reauth_connector_id = data.get("connector_id")
|
||||
if reauth_connector_id:
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == reauth_connector_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.search_space_id == space_id,
|
||||
SearchSourceConnector.connector_type == db_connector_type,
|
||||
)
|
||||
)
|
||||
db_connector = result.scalars().first()
|
||||
if not db_connector:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Connector not found during re-auth",
|
||||
)
|
||||
|
||||
db_connector.config = connector_config
|
||||
flag_modified(db_connector, "config")
|
||||
await session.commit()
|
||||
await session.refresh(db_connector)
|
||||
|
||||
_invalidate_cache(space_id)
|
||||
|
||||
logger.info(
|
||||
"Re-authenticated %s MCP connector %s for user %s",
|
||||
svc.name, db_connector.id, user_id,
|
||||
)
|
||||
reauth_return_url = data.get("return_url")
|
||||
if reauth_return_url and reauth_return_url.startswith("/") and not reauth_return_url.startswith("//"):
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}{reauth_return_url}"
|
||||
)
|
||||
return _frontend_redirect(
|
||||
space_id, success=True, connector_id=db_connector.id, service=service,
|
||||
)
|
||||
|
||||
# ---- New connector path ----
|
||||
naming_identifier = account_meta.get("display_name")
|
||||
connector_name = await generate_unique_connector_name(
|
||||
session,
|
||||
db_connector_type,
|
||||
space_id,
|
||||
user_id,
|
||||
naming_identifier,
|
||||
)
|
||||
|
||||
new_connector = SearchSourceConnector(
|
||||
name=connector_name,
|
||||
connector_type=db_connector_type,
|
||||
is_indexable=False,
|
||||
config=connector_config,
|
||||
search_space_id=space_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
session.add(new_connector)
|
||||
|
||||
try:
|
||||
await session.commit()
|
||||
except IntegrityError as e:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=409, detail="A connector for this service already exists.",
|
||||
) from e
|
||||
|
||||
_invalidate_cache(space_id)
|
||||
|
||||
logger.info(
|
||||
"Created %s MCP connector %s for user %s in space %s",
|
||||
svc.name, new_connector.id, user_id, space_id,
|
||||
)
|
||||
return _frontend_redirect(
|
||||
space_id, success=True, connector_id=new_connector.id, service=service,
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to complete %s MCP OAuth: %s", service, e, exc_info=True,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to complete {service} MCP OAuth.",
|
||||
) from e
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# /reauth — re-authenticate an existing MCP connector
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.get("/auth/mcp/{service}/connector/reauth")
|
||||
async def reauth_mcp_service(
|
||||
service: str,
|
||||
space_id: int,
|
||||
connector_id: int,
|
||||
return_url: str | None = None,
|
||||
user: User = Depends(current_active_user),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
from app.services.mcp_oauth.registry import get_service
|
||||
|
||||
svc = get_service(service)
|
||||
if not svc:
|
||||
raise HTTPException(status_code=404, detail=f"Unknown MCP service: {service}")
|
||||
|
||||
db_connector_type = SearchSourceConnectorType(svc.connector_type)
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == connector_id,
|
||||
SearchSourceConnector.user_id == user.id,
|
||||
SearchSourceConnector.search_space_id == space_id,
|
||||
SearchSourceConnector.connector_type == db_connector_type,
|
||||
)
|
||||
)
|
||||
if not result.scalars().first():
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Connector not found or access denied",
|
||||
)
|
||||
|
||||
try:
|
||||
from app.services.mcp_oauth.discovery import (
|
||||
discover_oauth_metadata,
|
||||
register_client,
|
||||
)
|
||||
|
||||
metadata = await discover_oauth_metadata(
|
||||
svc.mcp_url, origin_override=svc.oauth_discovery_origin,
|
||||
)
|
||||
auth_endpoint = svc.auth_endpoint_override or metadata.get("authorization_endpoint")
|
||||
token_endpoint = svc.token_endpoint_override or metadata.get("token_endpoint")
|
||||
registration_endpoint = metadata.get("registration_endpoint")
|
||||
|
||||
if not auth_endpoint or not token_endpoint:
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail=f"{svc.name} MCP server returned incomplete OAuth metadata.",
|
||||
)
|
||||
|
||||
redirect_uri = _build_redirect_uri(service)
|
||||
|
||||
if svc.supports_dcr and registration_endpoint:
|
||||
dcr = await register_client(registration_endpoint, redirect_uri)
|
||||
client_id = dcr.get("client_id")
|
||||
client_secret = dcr.get("client_secret", "")
|
||||
if not client_id:
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail=f"DCR for {svc.name} did not return a client_id.",
|
||||
)
|
||||
elif svc.client_id_env:
|
||||
client_id = getattr(config, svc.client_id_env, None)
|
||||
client_secret = getattr(config, svc.client_secret_env or "", None) or ""
|
||||
if not client_id:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"{svc.name} MCP OAuth not configured ({svc.client_id_env}).",
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail=f"{svc.name} MCP server has no DCR and no fallback credentials.",
|
||||
)
|
||||
|
||||
verifier, challenge = generate_pkce_pair()
|
||||
enc = _get_token_encryption()
|
||||
|
||||
extra: dict = {
|
||||
"service": service,
|
||||
"code_verifier": verifier,
|
||||
"mcp_client_id": client_id,
|
||||
"mcp_client_secret": enc.encrypt_token(client_secret) if client_secret else "",
|
||||
"mcp_token_endpoint": token_endpoint,
|
||||
"mcp_url": svc.mcp_url,
|
||||
"connector_id": connector_id,
|
||||
}
|
||||
if return_url and return_url.startswith("/"):
|
||||
extra["return_url"] = return_url
|
||||
|
||||
state = _get_state_manager().generate_secure_state(
|
||||
space_id, user.id, **extra,
|
||||
)
|
||||
|
||||
auth_params: dict[str, str] = {
|
||||
"client_id": client_id,
|
||||
"response_type": "code",
|
||||
"redirect_uri": redirect_uri,
|
||||
"code_challenge": challenge,
|
||||
"code_challenge_method": "S256",
|
||||
"state": state,
|
||||
}
|
||||
if svc.scopes:
|
||||
auth_params[svc.scope_param] = " ".join(svc.scopes)
|
||||
|
||||
auth_url = f"{auth_endpoint}?{urlencode(auth_params)}"
|
||||
|
||||
logger.info(
|
||||
"Initiating %s MCP re-auth for user %s, connector %s",
|
||||
svc.name, user.id, connector_id,
|
||||
)
|
||||
return {"auth_url": auth_url}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to initiate %s MCP re-auth: %s", service, e, exc_info=True,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to initiate {service} MCP re-auth.",
|
||||
) from e
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _invalidate_cache(space_id: int) -> None:
|
||||
try:
|
||||
from app.agents.new_chat.tools.mcp_tool import invalidate_mcp_tools_cache
|
||||
|
||||
invalidate_mcp_tools_cache(space_id)
|
||||
except Exception:
|
||||
logger.debug("MCP cache invalidation skipped", exc_info=True)
|
||||
620
surfsense_backend/app/routes/oauth_connector_base.py
Normal file
620
surfsense_backend/app/routes/oauth_connector_base.py
Normal file
|
|
@ -0,0 +1,620 @@
|
|||
"""Reusable base for OAuth 2.0 connector routes.
|
||||
|
||||
Subclasses override ``fetch_account_info``, ``build_connector_config``,
|
||||
and ``get_connector_display_name`` to customise provider-specific behaviour.
|
||||
Call ``build_router()`` to get a FastAPI ``APIRouter`` with ``/connector/add``,
|
||||
``/connector/callback``, and ``/connector/reauth`` endpoints.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import logging
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Any
|
||||
from urllib.parse import urlencode
|
||||
from uuid import UUID
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import RedirectResponse
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
from app.config import config
|
||||
from app.db import (
|
||||
SearchSourceConnector,
|
||||
SearchSourceConnectorType,
|
||||
User,
|
||||
get_async_session,
|
||||
)
|
||||
from app.users import current_active_user
|
||||
from app.utils.connector_naming import (
|
||||
check_duplicate_connector,
|
||||
generate_unique_connector_name,
|
||||
)
|
||||
from app.utils.oauth_security import OAuthStateManager, TokenEncryption
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OAuthConnectorRoute:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
provider_name: str,
|
||||
connector_type: SearchSourceConnectorType,
|
||||
authorize_url: str,
|
||||
token_url: str,
|
||||
client_id_env: str,
|
||||
client_secret_env: str,
|
||||
redirect_uri_env: str,
|
||||
scopes: list[str],
|
||||
auth_prefix: str,
|
||||
use_pkce: bool = False,
|
||||
token_auth_method: str = "body",
|
||||
is_indexable: bool = True,
|
||||
extra_auth_params: dict[str, str] | None = None,
|
||||
) -> None:
|
||||
self.provider_name = provider_name
|
||||
self.connector_type = connector_type
|
||||
self.authorize_url = authorize_url
|
||||
self.token_url = token_url
|
||||
self.client_id_env = client_id_env
|
||||
self.client_secret_env = client_secret_env
|
||||
self.redirect_uri_env = redirect_uri_env
|
||||
self.scopes = scopes
|
||||
self.auth_prefix = auth_prefix.rstrip("/")
|
||||
self.use_pkce = use_pkce
|
||||
self.token_auth_method = token_auth_method
|
||||
self.is_indexable = is_indexable
|
||||
self.extra_auth_params = extra_auth_params or {}
|
||||
|
||||
self._state_manager: OAuthStateManager | None = None
|
||||
self._token_encryption: TokenEncryption | None = None
|
||||
|
||||
def _get_client_id(self) -> str:
|
||||
value = getattr(config, self.client_id_env, None)
|
||||
if not value:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"{self.provider_name.title()} OAuth not configured "
|
||||
f"({self.client_id_env} missing).",
|
||||
)
|
||||
return value
|
||||
|
||||
def _get_client_secret(self) -> str:
|
||||
value = getattr(config, self.client_secret_env, None)
|
||||
if not value:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"{self.provider_name.title()} OAuth not configured "
|
||||
f"({self.client_secret_env} missing).",
|
||||
)
|
||||
return value
|
||||
|
||||
def _get_redirect_uri(self) -> str:
|
||||
value = getattr(config, self.redirect_uri_env, None)
|
||||
if not value:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"{self.redirect_uri_env} not configured.",
|
||||
)
|
||||
return value
|
||||
|
||||
def _get_state_manager(self) -> OAuthStateManager:
|
||||
if self._state_manager is None:
|
||||
if not config.SECRET_KEY:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="SECRET_KEY not configured for OAuth security.",
|
||||
)
|
||||
self._state_manager = OAuthStateManager(config.SECRET_KEY)
|
||||
return self._state_manager
|
||||
|
||||
def _get_token_encryption(self) -> TokenEncryption:
|
||||
if self._token_encryption is None:
|
||||
if not config.SECRET_KEY:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="SECRET_KEY not configured for token encryption.",
|
||||
)
|
||||
self._token_encryption = TokenEncryption(config.SECRET_KEY)
|
||||
return self._token_encryption
|
||||
|
||||
def _frontend_redirect(
|
||||
self,
|
||||
space_id: int | None,
|
||||
*,
|
||||
success: bool = False,
|
||||
connector_id: int | None = None,
|
||||
error: str | None = None,
|
||||
) -> RedirectResponse:
|
||||
if success and space_id:
|
||||
connector_slug = f"{self.provider_name}-connector"
|
||||
qs = f"success=true&connector={connector_slug}"
|
||||
if connector_id:
|
||||
qs += f"&connectorId={connector_id}"
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?{qs}"
|
||||
)
|
||||
if error and space_id:
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error={error}"
|
||||
)
|
||||
if error:
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard?error={error}"
|
||||
)
|
||||
return RedirectResponse(url=f"{config.NEXT_FRONTEND_URL}/dashboard")
|
||||
|
||||
async def fetch_account_info(self, access_token: str) -> dict[str, Any]:
|
||||
"""Override to fetch account/workspace info after token exchange.
|
||||
|
||||
Return dict is merged into connector config; key ``"name"`` is used
|
||||
for the display name and dedup.
|
||||
"""
|
||||
return {}
|
||||
|
||||
def build_connector_config(
|
||||
self,
|
||||
token_json: dict[str, Any],
|
||||
account_info: dict[str, Any],
|
||||
encryption: TokenEncryption,
|
||||
) -> dict[str, Any]:
|
||||
"""Override for custom config shapes. Default: standard encrypted OAuth fields."""
|
||||
access_token = token_json.get("access_token", "")
|
||||
refresh_token = token_json.get("refresh_token")
|
||||
|
||||
expires_at = None
|
||||
if token_json.get("expires_in"):
|
||||
expires_at = datetime.now(UTC) + timedelta(
|
||||
seconds=int(token_json["expires_in"])
|
||||
)
|
||||
|
||||
cfg: dict[str, Any] = {
|
||||
"access_token": encryption.encrypt_token(access_token),
|
||||
"refresh_token": (
|
||||
encryption.encrypt_token(refresh_token) if refresh_token else None
|
||||
),
|
||||
"token_type": token_json.get("token_type", "Bearer"),
|
||||
"expires_in": token_json.get("expires_in"),
|
||||
"expires_at": expires_at.isoformat() if expires_at else None,
|
||||
"scope": token_json.get("scope"),
|
||||
"_token_encrypted": True,
|
||||
}
|
||||
cfg.update(account_info)
|
||||
return cfg
|
||||
|
||||
def get_connector_display_name(self, account_info: dict[str, Any]) -> str:
|
||||
return str(account_info.get("name", self.provider_name.title()))
|
||||
|
||||
async def on_token_refresh_failure(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
connector: SearchSourceConnector,
|
||||
) -> None:
|
||||
try:
|
||||
connector.config = {**connector.config, "auth_expired": True}
|
||||
flag_modified(connector, "config")
|
||||
await session.commit()
|
||||
await session.refresh(connector)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to persist auth_expired flag for connector %s",
|
||||
connector.id,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
async def _exchange_code(
|
||||
self, code: str, extra_state: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
client_id = self._get_client_id()
|
||||
client_secret = self._get_client_secret()
|
||||
redirect_uri = self._get_redirect_uri()
|
||||
|
||||
headers: dict[str, str] = {
|
||||
"Content-Type": "application/x-www-form-urlencoded",
|
||||
}
|
||||
body: dict[str, str] = {
|
||||
"grant_type": "authorization_code",
|
||||
"code": code,
|
||||
"redirect_uri": redirect_uri,
|
||||
}
|
||||
|
||||
if self.token_auth_method == "basic":
|
||||
creds = base64.b64encode(f"{client_id}:{client_secret}".encode()).decode()
|
||||
headers["Authorization"] = f"Basic {creds}"
|
||||
else:
|
||||
body["client_id"] = client_id
|
||||
body["client_secret"] = client_secret
|
||||
|
||||
if self.use_pkce:
|
||||
verifier = extra_state.get("code_verifier")
|
||||
if verifier:
|
||||
body["code_verifier"] = verifier
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.post(
|
||||
self.token_url, data=body, headers=headers, timeout=30.0
|
||||
)
|
||||
|
||||
if resp.status_code != 200:
|
||||
detail = resp.text
|
||||
try:
|
||||
detail = resp.json().get("error_description", detail)
|
||||
except Exception:
|
||||
pass
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Token exchange failed: {detail}"
|
||||
)
|
||||
|
||||
return resp.json()
|
||||
|
||||
async def refresh_token(
|
||||
self, session: AsyncSession, connector: SearchSourceConnector
|
||||
) -> SearchSourceConnector:
|
||||
encryption = self._get_token_encryption()
|
||||
is_encrypted = connector.config.get("_token_encrypted", False)
|
||||
|
||||
refresh_tok = connector.config.get("refresh_token")
|
||||
if is_encrypted and refresh_tok:
|
||||
try:
|
||||
refresh_tok = encryption.decrypt_token(refresh_tok)
|
||||
except Exception as e:
|
||||
logger.error("Failed to decrypt refresh token: %s", e)
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Failed to decrypt stored refresh token"
|
||||
) from e
|
||||
|
||||
if not refresh_tok:
|
||||
await self.on_token_refresh_failure(session, connector)
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="No refresh token available. Please re-authenticate.",
|
||||
)
|
||||
|
||||
client_id = self._get_client_id()
|
||||
client_secret = self._get_client_secret()
|
||||
|
||||
headers: dict[str, str] = {
|
||||
"Content-Type": "application/x-www-form-urlencoded",
|
||||
}
|
||||
body: dict[str, str] = {
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": refresh_tok,
|
||||
}
|
||||
|
||||
if self.token_auth_method == "basic":
|
||||
creds = base64.b64encode(f"{client_id}:{client_secret}".encode()).decode()
|
||||
headers["Authorization"] = f"Basic {creds}"
|
||||
else:
|
||||
body["client_id"] = client_id
|
||||
body["client_secret"] = client_secret
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.post(
|
||||
self.token_url, data=body, headers=headers, timeout=30.0
|
||||
)
|
||||
|
||||
if resp.status_code != 200:
|
||||
error_detail = resp.text
|
||||
try:
|
||||
ej = resp.json()
|
||||
error_detail = ej.get("error_description", error_detail)
|
||||
error_code = ej.get("error", "")
|
||||
except Exception:
|
||||
error_code = ""
|
||||
combined = (error_detail + error_code).lower()
|
||||
if any(kw in combined for kw in ("invalid_grant", "expired", "revoked")):
|
||||
await self.on_token_refresh_failure(session, connector)
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail=f"{self.provider_name.title()} authentication failed. "
|
||||
"Please re-authenticate.",
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Token refresh failed: {error_detail}"
|
||||
)
|
||||
|
||||
token_json = resp.json()
|
||||
new_access = token_json.get("access_token")
|
||||
if not new_access:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="No access token received from refresh"
|
||||
)
|
||||
|
||||
expires_at = None
|
||||
if token_json.get("expires_in"):
|
||||
expires_at = datetime.now(UTC) + timedelta(
|
||||
seconds=int(token_json["expires_in"])
|
||||
)
|
||||
|
||||
updated_config = dict(connector.config)
|
||||
updated_config["access_token"] = encryption.encrypt_token(new_access)
|
||||
new_refresh = token_json.get("refresh_token")
|
||||
if new_refresh:
|
||||
updated_config["refresh_token"] = encryption.encrypt_token(new_refresh)
|
||||
updated_config["expires_in"] = token_json.get("expires_in")
|
||||
updated_config["expires_at"] = expires_at.isoformat() if expires_at else None
|
||||
updated_config["scope"] = token_json.get("scope", updated_config.get("scope"))
|
||||
updated_config["_token_encrypted"] = True
|
||||
updated_config.pop("auth_expired", None)
|
||||
|
||||
connector.config = updated_config
|
||||
flag_modified(connector, "config")
|
||||
await session.commit()
|
||||
await session.refresh(connector)
|
||||
|
||||
logger.info(
|
||||
"Refreshed %s token for connector %s",
|
||||
self.provider_name,
|
||||
connector.id,
|
||||
)
|
||||
return connector
|
||||
|
||||
def build_router(self) -> APIRouter:
|
||||
router = APIRouter()
|
||||
oauth = self
|
||||
|
||||
@router.get(f"{oauth.auth_prefix}/connector/add")
|
||||
async def connect(
|
||||
space_id: int,
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
if not space_id:
|
||||
raise HTTPException(status_code=400, detail="space_id is required")
|
||||
|
||||
client_id = oauth._get_client_id()
|
||||
state_mgr = oauth._get_state_manager()
|
||||
|
||||
extra_state: dict[str, Any] = {}
|
||||
auth_params: dict[str, str] = {
|
||||
"client_id": client_id,
|
||||
"response_type": "code",
|
||||
"redirect_uri": oauth._get_redirect_uri(),
|
||||
"scope": " ".join(oauth.scopes),
|
||||
}
|
||||
|
||||
if oauth.use_pkce:
|
||||
from app.utils.oauth_security import generate_pkce_pair
|
||||
|
||||
verifier, challenge = generate_pkce_pair()
|
||||
extra_state["code_verifier"] = verifier
|
||||
auth_params["code_challenge"] = challenge
|
||||
auth_params["code_challenge_method"] = "S256"
|
||||
|
||||
auth_params.update(oauth.extra_auth_params)
|
||||
|
||||
state_encoded = state_mgr.generate_secure_state(
|
||||
space_id, user.id, **extra_state
|
||||
)
|
||||
auth_params["state"] = state_encoded
|
||||
auth_url = f"{oauth.authorize_url}?{urlencode(auth_params)}"
|
||||
|
||||
logger.info(
|
||||
"Generated %s OAuth URL for user %s, space %s",
|
||||
oauth.provider_name,
|
||||
user.id,
|
||||
space_id,
|
||||
)
|
||||
return {"auth_url": auth_url}
|
||||
|
||||
@router.get(f"{oauth.auth_prefix}/connector/reauth")
|
||||
async def reauth(
|
||||
space_id: int,
|
||||
connector_id: int,
|
||||
return_url: str | None = None,
|
||||
user: User = Depends(current_active_user),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == connector_id,
|
||||
SearchSourceConnector.user_id == user.id,
|
||||
SearchSourceConnector.search_space_id == space_id,
|
||||
SearchSourceConnector.connector_type == oauth.connector_type,
|
||||
)
|
||||
)
|
||||
if not result.scalars().first():
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"{oauth.provider_name.title()} connector not found "
|
||||
"or access denied",
|
||||
)
|
||||
|
||||
client_id = oauth._get_client_id()
|
||||
state_mgr = oauth._get_state_manager()
|
||||
|
||||
extra: dict[str, Any] = {"connector_id": connector_id}
|
||||
if return_url and return_url.startswith("/") and not return_url.startswith("//"):
|
||||
extra["return_url"] = return_url
|
||||
|
||||
auth_params: dict[str, str] = {
|
||||
"client_id": client_id,
|
||||
"response_type": "code",
|
||||
"redirect_uri": oauth._get_redirect_uri(),
|
||||
"scope": " ".join(oauth.scopes),
|
||||
}
|
||||
|
||||
if oauth.use_pkce:
|
||||
from app.utils.oauth_security import generate_pkce_pair
|
||||
|
||||
verifier, challenge = generate_pkce_pair()
|
||||
extra["code_verifier"] = verifier
|
||||
auth_params["code_challenge"] = challenge
|
||||
auth_params["code_challenge_method"] = "S256"
|
||||
|
||||
auth_params.update(oauth.extra_auth_params)
|
||||
|
||||
state_encoded = state_mgr.generate_secure_state(
|
||||
space_id, user.id, **extra
|
||||
)
|
||||
auth_params["state"] = state_encoded
|
||||
auth_url = f"{oauth.authorize_url}?{urlencode(auth_params)}"
|
||||
|
||||
logger.info(
|
||||
"Initiating %s re-auth for user %s, connector %s",
|
||||
oauth.provider_name,
|
||||
user.id,
|
||||
connector_id,
|
||||
)
|
||||
return {"auth_url": auth_url}
|
||||
|
||||
@router.get(f"{oauth.auth_prefix}/connector/callback")
|
||||
async def callback(
|
||||
code: str | None = None,
|
||||
error: str | None = None,
|
||||
state: str | None = None,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
error_label = f"{oauth.provider_name}_oauth_denied"
|
||||
|
||||
if error:
|
||||
logger.warning("%s OAuth error: %s", oauth.provider_name, error)
|
||||
space_id = None
|
||||
if state:
|
||||
try:
|
||||
data = oauth._get_state_manager().validate_state(state)
|
||||
space_id = data.get("space_id")
|
||||
except Exception:
|
||||
pass
|
||||
return oauth._frontend_redirect(space_id, error=error_label)
|
||||
|
||||
if not code:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Missing authorization code"
|
||||
)
|
||||
if not state:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Missing state parameter"
|
||||
)
|
||||
|
||||
state_mgr = oauth._get_state_manager()
|
||||
try:
|
||||
data = state_mgr.validate_state(state)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Invalid or expired state parameter."
|
||||
) from e
|
||||
|
||||
user_id = UUID(data["user_id"])
|
||||
space_id = data["space_id"]
|
||||
|
||||
token_json = await oauth._exchange_code(code, data)
|
||||
|
||||
access_token = token_json.get("access_token", "")
|
||||
if not access_token:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"No access token received from {oauth.provider_name.title()}",
|
||||
)
|
||||
|
||||
account_info = await oauth.fetch_account_info(access_token)
|
||||
encryption = oauth._get_token_encryption()
|
||||
connector_config = oauth.build_connector_config(
|
||||
token_json, account_info, encryption
|
||||
)
|
||||
|
||||
display_name = oauth.get_connector_display_name(account_info)
|
||||
|
||||
# --- Re-auth path ---
|
||||
reauth_connector_id = data.get("connector_id")
|
||||
reauth_return_url = data.get("return_url")
|
||||
|
||||
if reauth_connector_id:
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == reauth_connector_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.search_space_id == space_id,
|
||||
SearchSourceConnector.connector_type == oauth.connector_type,
|
||||
)
|
||||
)
|
||||
db_connector = result.scalars().first()
|
||||
if not db_connector:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Connector not found or access denied during re-auth",
|
||||
)
|
||||
|
||||
db_connector.config = connector_config
|
||||
flag_modified(db_connector, "config")
|
||||
await session.commit()
|
||||
await session.refresh(db_connector)
|
||||
|
||||
logger.info(
|
||||
"Re-authenticated %s connector %s for user %s",
|
||||
oauth.provider_name,
|
||||
db_connector.id,
|
||||
user_id,
|
||||
)
|
||||
if reauth_return_url and reauth_return_url.startswith("/") and not reauth_return_url.startswith("//"):
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}{reauth_return_url}"
|
||||
)
|
||||
return oauth._frontend_redirect(
|
||||
space_id, success=True, connector_id=db_connector.id
|
||||
)
|
||||
|
||||
# --- New connector path ---
|
||||
is_dup = await check_duplicate_connector(
|
||||
session,
|
||||
oauth.connector_type,
|
||||
space_id,
|
||||
user_id,
|
||||
display_name,
|
||||
)
|
||||
if is_dup:
|
||||
logger.warning(
|
||||
"Duplicate %s connector for user %s (%s)",
|
||||
oauth.provider_name,
|
||||
user_id,
|
||||
display_name,
|
||||
)
|
||||
return oauth._frontend_redirect(
|
||||
space_id,
|
||||
error=f"duplicate_account&connector={oauth.provider_name}-connector",
|
||||
)
|
||||
|
||||
connector_name = await generate_unique_connector_name(
|
||||
session,
|
||||
oauth.connector_type,
|
||||
space_id,
|
||||
user_id,
|
||||
display_name,
|
||||
)
|
||||
|
||||
new_connector = SearchSourceConnector(
|
||||
name=connector_name,
|
||||
connector_type=oauth.connector_type,
|
||||
is_indexable=oauth.is_indexable,
|
||||
config=connector_config,
|
||||
search_space_id=space_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
session.add(new_connector)
|
||||
|
||||
try:
|
||||
await session.commit()
|
||||
except IntegrityError as e:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=409, detail="A connector for this service already exists."
|
||||
) from e
|
||||
|
||||
logger.info(
|
||||
"Created %s connector %s for user %s in space %s",
|
||||
oauth.provider_name,
|
||||
new_connector.id,
|
||||
user_id,
|
||||
space_id,
|
||||
)
|
||||
return oauth._frontend_redirect(
|
||||
space_id, success=True, connector_id=new_connector.id
|
||||
)
|
||||
|
||||
return router
|
||||
|
|
@ -693,27 +693,10 @@ async def index_connector_content(
|
|||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""
|
||||
Index content from a connector to a search space.
|
||||
Requires CONNECTORS_UPDATE permission (to trigger indexing).
|
||||
Index content from a KB connector to a search space.
|
||||
|
||||
Currently supports:
|
||||
- SLACK_CONNECTOR: Indexes messages from all accessible Slack channels
|
||||
- TEAMS_CONNECTOR: Indexes messages from all accessible Microsoft Teams channels
|
||||
- NOTION_CONNECTOR: Indexes pages from all accessible Notion pages
|
||||
- GITHUB_CONNECTOR: Indexes code and documentation from GitHub repositories
|
||||
- LINEAR_CONNECTOR: Indexes issues and comments from Linear
|
||||
- JIRA_CONNECTOR: Indexes issues and comments from Jira
|
||||
- DISCORD_CONNECTOR: Indexes messages from all accessible Discord channels
|
||||
- LUMA_CONNECTOR: Indexes events from Luma
|
||||
- ELASTICSEARCH_CONNECTOR: Indexes documents from Elasticsearch
|
||||
- WEBCRAWLER_CONNECTOR: Indexes web pages from crawled websites
|
||||
|
||||
Args:
|
||||
connector_id: ID of the connector to use
|
||||
search_space_id: ID of the search space to store indexed content
|
||||
|
||||
Returns:
|
||||
Dictionary with indexing status
|
||||
Live connectors (Slack, Teams, Linear, Jira, ClickUp, Calendar, Airtable,
|
||||
Gmail, Discord, Luma) use real-time agent tools instead.
|
||||
"""
|
||||
try:
|
||||
# Get the connector first
|
||||
|
|
@ -770,9 +753,7 @@ async def index_connector_content(
|
|||
|
||||
# For calendar connectors, default to today but allow future dates if explicitly provided
|
||||
if connector.connector_type in [
|
||||
SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR,
|
||||
SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
|
||||
SearchSourceConnectorType.LUMA_CONNECTOR,
|
||||
]:
|
||||
# Default to today if no end_date provided (users can manually select future dates)
|
||||
indexing_to = today_str if end_date is None else end_date
|
||||
|
|
@ -796,33 +777,22 @@ async def index_connector_content(
|
|||
# For non-calendar connectors, cap at today
|
||||
indexing_to = end_date if end_date else today_str
|
||||
|
||||
if connector.connector_type == SearchSourceConnectorType.SLACK_CONNECTOR:
|
||||
from app.tasks.celery_tasks.connector_tasks import (
|
||||
index_slack_messages_task,
|
||||
)
|
||||
from app.services.mcp_oauth.registry import LIVE_CONNECTOR_TYPES
|
||||
|
||||
logger.info(
|
||||
f"Triggering Slack indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}"
|
||||
)
|
||||
index_slack_messages_task.delay(
|
||||
connector_id, search_space_id, str(user.id), indexing_from, indexing_to
|
||||
)
|
||||
response_message = "Slack indexing started in the background."
|
||||
if connector.connector_type in LIVE_CONNECTOR_TYPES:
|
||||
return {
|
||||
"message": (
|
||||
f"{connector.connector_type.value} uses real-time agent tools; "
|
||||
"background indexing is disabled."
|
||||
),
|
||||
"indexing_started": False,
|
||||
"connector_id": connector_id,
|
||||
"search_space_id": search_space_id,
|
||||
"indexing_from": indexing_from,
|
||||
"indexing_to": indexing_to,
|
||||
}
|
||||
|
||||
elif connector.connector_type == SearchSourceConnectorType.TEAMS_CONNECTOR:
|
||||
from app.tasks.celery_tasks.connector_tasks import (
|
||||
index_teams_messages_task,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Triggering Teams indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}"
|
||||
)
|
||||
index_teams_messages_task.delay(
|
||||
connector_id, search_space_id, str(user.id), indexing_from, indexing_to
|
||||
)
|
||||
response_message = "Teams indexing started in the background."
|
||||
|
||||
elif connector.connector_type == SearchSourceConnectorType.NOTION_CONNECTOR:
|
||||
if connector.connector_type == SearchSourceConnectorType.NOTION_CONNECTOR:
|
||||
from app.tasks.celery_tasks.connector_tasks import index_notion_pages_task
|
||||
|
||||
logger.info(
|
||||
|
|
@ -844,28 +814,6 @@ async def index_connector_content(
|
|||
)
|
||||
response_message = "GitHub indexing started in the background."
|
||||
|
||||
elif connector.connector_type == SearchSourceConnectorType.LINEAR_CONNECTOR:
|
||||
from app.tasks.celery_tasks.connector_tasks import index_linear_issues_task
|
||||
|
||||
logger.info(
|
||||
f"Triggering Linear indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}"
|
||||
)
|
||||
index_linear_issues_task.delay(
|
||||
connector_id, search_space_id, str(user.id), indexing_from, indexing_to
|
||||
)
|
||||
response_message = "Linear indexing started in the background."
|
||||
|
||||
elif connector.connector_type == SearchSourceConnectorType.JIRA_CONNECTOR:
|
||||
from app.tasks.celery_tasks.connector_tasks import index_jira_issues_task
|
||||
|
||||
logger.info(
|
||||
f"Triggering Jira indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}"
|
||||
)
|
||||
index_jira_issues_task.delay(
|
||||
connector_id, search_space_id, str(user.id), indexing_from, indexing_to
|
||||
)
|
||||
response_message = "Jira indexing started in the background."
|
||||
|
||||
elif connector.connector_type == SearchSourceConnectorType.CONFLUENCE_CONNECTOR:
|
||||
from app.tasks.celery_tasks.connector_tasks import (
|
||||
index_confluence_pages_task,
|
||||
|
|
@ -892,59 +840,6 @@ async def index_connector_content(
|
|||
)
|
||||
response_message = "BookStack indexing started in the background."
|
||||
|
||||
elif connector.connector_type == SearchSourceConnectorType.CLICKUP_CONNECTOR:
|
||||
from app.tasks.celery_tasks.connector_tasks import index_clickup_tasks_task
|
||||
|
||||
logger.info(
|
||||
f"Triggering ClickUp indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}"
|
||||
)
|
||||
index_clickup_tasks_task.delay(
|
||||
connector_id, search_space_id, str(user.id), indexing_from, indexing_to
|
||||
)
|
||||
response_message = "ClickUp indexing started in the background."
|
||||
|
||||
elif (
|
||||
connector.connector_type
|
||||
== SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR
|
||||
):
|
||||
from app.tasks.celery_tasks.connector_tasks import (
|
||||
index_google_calendar_events_task,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Triggering Google Calendar indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}"
|
||||
)
|
||||
index_google_calendar_events_task.delay(
|
||||
connector_id, search_space_id, str(user.id), indexing_from, indexing_to
|
||||
)
|
||||
response_message = "Google Calendar indexing started in the background."
|
||||
elif connector.connector_type == SearchSourceConnectorType.AIRTABLE_CONNECTOR:
|
||||
from app.tasks.celery_tasks.connector_tasks import (
|
||||
index_airtable_records_task,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Triggering Airtable indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}"
|
||||
)
|
||||
index_airtable_records_task.delay(
|
||||
connector_id, search_space_id, str(user.id), indexing_from, indexing_to
|
||||
)
|
||||
response_message = "Airtable indexing started in the background."
|
||||
elif (
|
||||
connector.connector_type == SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR
|
||||
):
|
||||
from app.tasks.celery_tasks.connector_tasks import (
|
||||
index_google_gmail_messages_task,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Triggering Google Gmail indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}"
|
||||
)
|
||||
index_google_gmail_messages_task.delay(
|
||||
connector_id, search_space_id, str(user.id), indexing_from, indexing_to
|
||||
)
|
||||
response_message = "Google Gmail indexing started in the background."
|
||||
|
||||
elif (
|
||||
connector.connector_type == SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR
|
||||
):
|
||||
|
|
@ -1089,30 +984,6 @@ async def index_connector_content(
|
|||
)
|
||||
response_message = "Dropbox indexing started in the background."
|
||||
|
||||
elif connector.connector_type == SearchSourceConnectorType.DISCORD_CONNECTOR:
|
||||
from app.tasks.celery_tasks.connector_tasks import (
|
||||
index_discord_messages_task,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Triggering Discord indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}"
|
||||
)
|
||||
index_discord_messages_task.delay(
|
||||
connector_id, search_space_id, str(user.id), indexing_from, indexing_to
|
||||
)
|
||||
response_message = "Discord indexing started in the background."
|
||||
|
||||
elif connector.connector_type == SearchSourceConnectorType.LUMA_CONNECTOR:
|
||||
from app.tasks.celery_tasks.connector_tasks import index_luma_events_task
|
||||
|
||||
logger.info(
|
||||
f"Triggering Luma indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}"
|
||||
)
|
||||
index_luma_events_task.delay(
|
||||
connector_id, search_space_id, str(user.id), indexing_from, indexing_to
|
||||
)
|
||||
response_message = "Luma indexing started in the background."
|
||||
|
||||
elif (
|
||||
connector.connector_type
|
||||
== SearchSourceConnectorType.ELASTICSEARCH_CONNECTOR
|
||||
|
|
@ -1338,57 +1209,6 @@ async def _update_connector_timestamp_by_id(session: AsyncSession, connector_id:
|
|||
await session.rollback()
|
||||
|
||||
|
||||
async def run_slack_indexing_with_new_session(
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
):
|
||||
"""
|
||||
Create a new session and run the Slack indexing task.
|
||||
This prevents session leaks by creating a dedicated session for the background task.
|
||||
"""
|
||||
async with async_session_maker() as session:
|
||||
await run_slack_indexing(
|
||||
session, connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
|
||||
|
||||
async def run_slack_indexing(
|
||||
session: AsyncSession,
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
):
|
||||
"""
|
||||
Background task to run Slack indexing.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
connector_id: ID of the Slack connector
|
||||
search_space_id: ID of the search space
|
||||
user_id: ID of the user
|
||||
start_date: Start date for indexing
|
||||
end_date: End date for indexing
|
||||
"""
|
||||
from app.tasks.connector_indexers import index_slack_messages
|
||||
|
||||
await _run_indexing_with_notifications(
|
||||
session=session,
|
||||
connector_id=connector_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
indexing_function=index_slack_messages,
|
||||
update_timestamp_func=_update_connector_timestamp_by_id,
|
||||
supports_heartbeat_callback=True,
|
||||
)
|
||||
|
||||
|
||||
_AUTH_ERROR_PATTERNS = (
|
||||
"failed to refresh linear oauth",
|
||||
"failed to refresh your notion connection",
|
||||
|
|
@ -1927,215 +1747,6 @@ async def run_github_indexing(
|
|||
)
|
||||
|
||||
|
||||
# Add new helper functions for Linear indexing
|
||||
async def run_linear_indexing_with_new_session(
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
):
|
||||
"""Wrapper to run Linear indexing with its own database session."""
|
||||
logger.info(
|
||||
f"Background task started: Indexing Linear connector {connector_id} into space {search_space_id} from {start_date} to {end_date}"
|
||||
)
|
||||
async with async_session_maker() as session:
|
||||
await run_linear_indexing(
|
||||
session, connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
logger.info(f"Background task finished: Indexing Linear connector {connector_id}")
|
||||
|
||||
|
||||
async def run_linear_indexing(
|
||||
session: AsyncSession,
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
):
|
||||
"""
|
||||
Background task to run Linear indexing.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
connector_id: ID of the Linear connector
|
||||
search_space_id: ID of the search space
|
||||
user_id: ID of the user
|
||||
start_date: Start date for indexing
|
||||
end_date: End date for indexing
|
||||
"""
|
||||
from app.tasks.connector_indexers import index_linear_issues
|
||||
|
||||
await _run_indexing_with_notifications(
|
||||
session=session,
|
||||
connector_id=connector_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
indexing_function=index_linear_issues,
|
||||
update_timestamp_func=_update_connector_timestamp_by_id,
|
||||
supports_heartbeat_callback=True,
|
||||
)
|
||||
|
||||
|
||||
# Add new helper functions for discord indexing
|
||||
async def run_discord_indexing_with_new_session(
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
):
|
||||
"""
|
||||
Create a new session and run the Discord indexing task.
|
||||
This prevents session leaks by creating a dedicated session for the background task.
|
||||
"""
|
||||
async with async_session_maker() as session:
|
||||
await run_discord_indexing(
|
||||
session, connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
|
||||
|
||||
async def run_discord_indexing(
|
||||
session: AsyncSession,
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
):
|
||||
"""
|
||||
Background task to run Discord indexing.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
connector_id: ID of the Discord connector
|
||||
search_space_id: ID of the search space
|
||||
user_id: ID of the user
|
||||
start_date: Start date for indexing
|
||||
end_date: End date for indexing
|
||||
"""
|
||||
from app.tasks.connector_indexers import index_discord_messages
|
||||
|
||||
await _run_indexing_with_notifications(
|
||||
session=session,
|
||||
connector_id=connector_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
indexing_function=index_discord_messages,
|
||||
update_timestamp_func=_update_connector_timestamp_by_id,
|
||||
supports_heartbeat_callback=True,
|
||||
)
|
||||
|
||||
|
||||
async def run_teams_indexing_with_new_session(
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
):
|
||||
"""
|
||||
Create a new session and run the Microsoft Teams indexing task.
|
||||
This prevents session leaks by creating a dedicated session for the background task.
|
||||
"""
|
||||
async with async_session_maker() as session:
|
||||
await run_teams_indexing(
|
||||
session, connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
|
||||
|
||||
async def run_teams_indexing(
|
||||
session: AsyncSession,
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
):
|
||||
"""
|
||||
Background task to run Microsoft Teams indexing.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
connector_id: ID of the Teams connector
|
||||
search_space_id: ID of the search space
|
||||
user_id: ID of the user
|
||||
start_date: Start date for indexing
|
||||
end_date: End date for indexing
|
||||
"""
|
||||
from app.tasks.connector_indexers.teams_indexer import index_teams_messages
|
||||
|
||||
await _run_indexing_with_notifications(
|
||||
session=session,
|
||||
connector_id=connector_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
indexing_function=index_teams_messages,
|
||||
update_timestamp_func=_update_connector_timestamp_by_id,
|
||||
supports_heartbeat_callback=True,
|
||||
)
|
||||
|
||||
|
||||
# Add new helper functions for Jira indexing
|
||||
async def run_jira_indexing_with_new_session(
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
):
|
||||
"""Wrapper to run Jira indexing with its own database session."""
|
||||
logger.info(
|
||||
f"Background task started: Indexing Jira connector {connector_id} into space {search_space_id} from {start_date} to {end_date}"
|
||||
)
|
||||
async with async_session_maker() as session:
|
||||
await run_jira_indexing(
|
||||
session, connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
logger.info(f"Background task finished: Indexing Jira connector {connector_id}")
|
||||
|
||||
|
||||
async def run_jira_indexing(
|
||||
session: AsyncSession,
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
):
|
||||
"""
|
||||
Background task to run Jira indexing.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
connector_id: ID of the Jira connector
|
||||
search_space_id: ID of the search space
|
||||
user_id: ID of the user
|
||||
start_date: Start date for indexing
|
||||
end_date: End date for indexing
|
||||
"""
|
||||
from app.tasks.connector_indexers import index_jira_issues
|
||||
|
||||
await _run_indexing_with_notifications(
|
||||
session=session,
|
||||
connector_id=connector_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
indexing_function=index_jira_issues,
|
||||
update_timestamp_func=_update_connector_timestamp_by_id,
|
||||
supports_heartbeat_callback=True,
|
||||
)
|
||||
|
||||
|
||||
# Add new helper functions for Confluence indexing
|
||||
async def run_confluence_indexing_with_new_session(
|
||||
connector_id: int,
|
||||
|
|
@ -2191,112 +1802,6 @@ async def run_confluence_indexing(
|
|||
)
|
||||
|
||||
|
||||
# Add new helper functions for ClickUp indexing
|
||||
async def run_clickup_indexing_with_new_session(
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
):
|
||||
"""Wrapper to run ClickUp indexing with its own database session."""
|
||||
logger.info(
|
||||
f"Background task started: Indexing ClickUp connector {connector_id} into space {search_space_id} from {start_date} to {end_date}"
|
||||
)
|
||||
async with async_session_maker() as session:
|
||||
await run_clickup_indexing(
|
||||
session, connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
logger.info(f"Background task finished: Indexing ClickUp connector {connector_id}")
|
||||
|
||||
|
||||
async def run_clickup_indexing(
|
||||
session: AsyncSession,
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
):
|
||||
"""
|
||||
Background task to run ClickUp indexing.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
connector_id: ID of the ClickUp connector
|
||||
search_space_id: ID of the search space
|
||||
user_id: ID of the user
|
||||
start_date: Start date for indexing
|
||||
end_date: End date for indexing
|
||||
"""
|
||||
from app.tasks.connector_indexers import index_clickup_tasks
|
||||
|
||||
await _run_indexing_with_notifications(
|
||||
session=session,
|
||||
connector_id=connector_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
indexing_function=index_clickup_tasks,
|
||||
update_timestamp_func=_update_connector_timestamp_by_id,
|
||||
supports_heartbeat_callback=True,
|
||||
)
|
||||
|
||||
|
||||
# Add new helper functions for Airtable indexing
|
||||
async def run_airtable_indexing_with_new_session(
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
):
|
||||
"""Wrapper to run Airtable indexing with its own database session."""
|
||||
logger.info(
|
||||
f"Background task started: Indexing Airtable connector {connector_id} into space {search_space_id} from {start_date} to {end_date}"
|
||||
)
|
||||
async with async_session_maker() as session:
|
||||
await run_airtable_indexing(
|
||||
session, connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
logger.info(f"Background task finished: Indexing Airtable connector {connector_id}")
|
||||
|
||||
|
||||
async def run_airtable_indexing(
|
||||
session: AsyncSession,
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
):
|
||||
"""
|
||||
Background task to run Airtable indexing.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
connector_id: ID of the Airtable connector
|
||||
search_space_id: ID of the search space
|
||||
user_id: ID of the user
|
||||
start_date: Start date for indexing
|
||||
end_date: End date for indexing
|
||||
"""
|
||||
from app.tasks.connector_indexers import index_airtable_records
|
||||
|
||||
await _run_indexing_with_notifications(
|
||||
session=session,
|
||||
connector_id=connector_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
indexing_function=index_airtable_records,
|
||||
update_timestamp_func=_update_connector_timestamp_by_id,
|
||||
supports_heartbeat_callback=True,
|
||||
)
|
||||
|
||||
|
||||
# Add new helper functions for Google Calendar indexing
|
||||
async def run_google_calendar_indexing_with_new_session(
|
||||
connector_id: int,
|
||||
|
|
@ -2835,58 +2340,6 @@ async def run_dropbox_indexing(
|
|||
logger.error(f"Failed to update notification: {notif_error!s}")
|
||||
|
||||
|
||||
# Add new helper functions for luma indexing
|
||||
async def run_luma_indexing_with_new_session(
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
):
|
||||
"""
|
||||
Create a new session and run the Luma indexing task.
|
||||
This prevents session leaks by creating a dedicated session for the background task.
|
||||
"""
|
||||
async with async_session_maker() as session:
|
||||
await run_luma_indexing(
|
||||
session, connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
|
||||
|
||||
async def run_luma_indexing(
|
||||
session: AsyncSession,
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
):
|
||||
"""
|
||||
Background task to run Luma indexing.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
connector_id: ID of the Luma connector
|
||||
search_space_id: ID of the search space
|
||||
user_id: ID of the user
|
||||
start_date: Start date for indexing
|
||||
end_date: End date for indexing
|
||||
"""
|
||||
from app.tasks.connector_indexers import index_luma_events
|
||||
|
||||
await _run_indexing_with_notifications(
|
||||
session=session,
|
||||
connector_id=connector_id,
|
||||
search_space_id=search_space_id,
|
||||
user_id=user_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
indexing_function=index_luma_events,
|
||||
update_timestamp_func=_update_connector_timestamp_by_id,
|
||||
supports_heartbeat_callback=True,
|
||||
)
|
||||
|
||||
|
||||
async def run_elasticsearch_indexing_with_new_session(
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
|
|
|
|||
|
|
@ -312,7 +312,7 @@ async def slack_callback(
|
|||
new_connector = SearchSourceConnector(
|
||||
name=connector_name,
|
||||
connector_type=SearchSourceConnectorType.SLACK_CONNECTOR,
|
||||
is_indexable=True,
|
||||
is_indexable=False,
|
||||
config=connector_config,
|
||||
search_space_id=space_id,
|
||||
user_id=user_id,
|
||||
|
|
|
|||
|
|
@ -45,6 +45,7 @@ SCOPES = [
|
|||
"Team.ReadBasic.All", # Read basic team information
|
||||
"Channel.ReadBasic.All", # Read basic channel information
|
||||
"ChannelMessage.Read.All", # Read messages in channels
|
||||
"ChannelMessage.Send", # Send messages in channels
|
||||
]
|
||||
|
||||
# Initialize security utilities
|
||||
|
|
@ -320,7 +321,7 @@ async def teams_callback(
|
|||
new_connector = SearchSourceConnector(
|
||||
name=connector_name,
|
||||
connector_type=SearchSourceConnectorType.TEAMS_CONNECTOR,
|
||||
is_indexable=True,
|
||||
is_indexable=False,
|
||||
config=connector_config,
|
||||
search_space_id=space_id,
|
||||
user_id=user_id,
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ COMPOSIO_TOOLKIT_NAMES = {
|
|||
}
|
||||
|
||||
# Toolkits that support indexing (Phase 1: Google services only)
|
||||
INDEXABLE_TOOLKITS = {"googledrive", "gmail", "googlecalendar"}
|
||||
INDEXABLE_TOOLKITS = {"googledrive"}
|
||||
|
||||
# Mapping of toolkit IDs to connector types
|
||||
TOOLKIT_TO_CONNECTOR_TYPE = {
|
||||
|
|
|
|||
|
|
@ -290,6 +290,12 @@ class LLMRouterService:
|
|||
|
||||
instance._router = Router(**router_kwargs)
|
||||
instance._initialized = True
|
||||
|
||||
global _cached_context_profile, _cached_context_profile_computed
|
||||
_cached_context_profile = None
|
||||
_cached_context_profile_computed = False
|
||||
_router_instance_cache.clear()
|
||||
|
||||
logger.info(
|
||||
"LLM Router initialized with %d deployments, "
|
||||
"strategy: %s, context_window_fallbacks: %s, fallbacks: %s",
|
||||
|
|
|
|||
0
surfsense_backend/app/services/mcp_oauth/__init__.py
Normal file
0
surfsense_backend/app/services/mcp_oauth/__init__.py
Normal file
121
surfsense_backend/app/services/mcp_oauth/discovery.py
Normal file
121
surfsense_backend/app/services/mcp_oauth/discovery.py
Normal file
|
|
@ -0,0 +1,121 @@
|
|||
"""MCP OAuth 2.1 metadata discovery, Dynamic Client Registration, and token exchange."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import logging
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def discover_oauth_metadata(
|
||||
mcp_url: str,
|
||||
*,
|
||||
origin_override: str | None = None,
|
||||
timeout: float = 15.0,
|
||||
) -> dict:
|
||||
"""Fetch OAuth 2.1 metadata from the MCP server's well-known endpoint.
|
||||
|
||||
Per the MCP spec the discovery document lives at the *origin* of the
|
||||
MCP server URL. ``origin_override`` can be used when the OAuth server
|
||||
lives on a different domain (e.g. Airtable: MCP at ``mcp.airtable.com``,
|
||||
OAuth at ``airtable.com``).
|
||||
"""
|
||||
if origin_override:
|
||||
origin = origin_override.rstrip("/")
|
||||
else:
|
||||
parsed = urlparse(mcp_url)
|
||||
origin = f"{parsed.scheme}://{parsed.netloc}"
|
||||
discovery_url = f"{origin}/.well-known/oauth-authorization-server"
|
||||
|
||||
async with httpx.AsyncClient(follow_redirects=True) as client:
|
||||
resp = await client.get(discovery_url, timeout=timeout)
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
|
||||
|
||||
async def register_client(
|
||||
registration_endpoint: str,
|
||||
redirect_uri: str,
|
||||
*,
|
||||
client_name: str = "SurfSense",
|
||||
timeout: float = 15.0,
|
||||
) -> dict:
|
||||
"""Perform Dynamic Client Registration (RFC 7591)."""
|
||||
payload = {
|
||||
"client_name": client_name,
|
||||
"redirect_uris": [redirect_uri],
|
||||
"grant_types": ["authorization_code", "refresh_token"],
|
||||
"response_types": ["code"],
|
||||
"token_endpoint_auth_method": "client_secret_basic",
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient(follow_redirects=True) as client:
|
||||
resp = await client.post(
|
||||
registration_endpoint, json=payload, timeout=timeout,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
|
||||
|
||||
async def exchange_code_for_tokens(
|
||||
token_endpoint: str,
|
||||
code: str,
|
||||
redirect_uri: str,
|
||||
client_id: str,
|
||||
client_secret: str,
|
||||
code_verifier: str,
|
||||
*,
|
||||
timeout: float = 30.0,
|
||||
) -> dict:
|
||||
"""Exchange an authorization code for access + refresh tokens."""
|
||||
creds = base64.b64encode(f"{client_id}:{client_secret}".encode()).decode()
|
||||
|
||||
async with httpx.AsyncClient(follow_redirects=True) as client:
|
||||
resp = await client.post(
|
||||
token_endpoint,
|
||||
data={
|
||||
"grant_type": "authorization_code",
|
||||
"code": code,
|
||||
"redirect_uri": redirect_uri,
|
||||
"code_verifier": code_verifier,
|
||||
},
|
||||
headers={
|
||||
"Content-Type": "application/x-www-form-urlencoded",
|
||||
"Authorization": f"Basic {creds}",
|
||||
},
|
||||
timeout=timeout,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
|
||||
|
||||
async def refresh_access_token(
|
||||
token_endpoint: str,
|
||||
refresh_token: str,
|
||||
client_id: str,
|
||||
client_secret: str,
|
||||
*,
|
||||
timeout: float = 30.0,
|
||||
) -> dict:
|
||||
"""Refresh an expired access token."""
|
||||
creds = base64.b64encode(f"{client_id}:{client_secret}".encode()).decode()
|
||||
|
||||
async with httpx.AsyncClient(follow_redirects=True) as client:
|
||||
resp = await client.post(
|
||||
token_endpoint,
|
||||
data={
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": refresh_token,
|
||||
},
|
||||
headers={
|
||||
"Content-Type": "application/x-www-form-urlencoded",
|
||||
"Authorization": f"Basic {creds}",
|
||||
},
|
||||
timeout=timeout,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
161
surfsense_backend/app/services/mcp_oauth/registry.py
Normal file
161
surfsense_backend/app/services/mcp_oauth/registry.py
Normal file
|
|
@ -0,0 +1,161 @@
|
|||
"""Registry of MCP services with OAuth support.
|
||||
|
||||
Each entry maps a URL-safe service key to its MCP server endpoint and
|
||||
authentication configuration. Services with ``supports_dcr=True`` use
|
||||
RFC 7591 Dynamic Client Registration (the MCP server issues its own
|
||||
credentials); the rest use pre-configured credentials via env vars.
|
||||
|
||||
``allowed_tools`` whitelists which MCP tools to expose to the agent.
|
||||
An empty list means "load every tool the server advertises" (used for
|
||||
user-managed generic MCP servers). Service-specific entries should
|
||||
curate this list to keep the agent's tool count low and selection
|
||||
accuracy high.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from app.db import SearchSourceConnectorType
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class MCPServiceConfig:
|
||||
name: str
|
||||
mcp_url: str
|
||||
connector_type: str
|
||||
supports_dcr: bool = True
|
||||
oauth_discovery_origin: str | None = None
|
||||
client_id_env: str | None = None
|
||||
client_secret_env: str | None = None
|
||||
scopes: list[str] = field(default_factory=list)
|
||||
scope_param: str = "scope"
|
||||
auth_endpoint_override: str | None = None
|
||||
token_endpoint_override: str | None = None
|
||||
allowed_tools: list[str] = field(default_factory=list)
|
||||
readonly_tools: frozenset[str] = field(default_factory=frozenset)
|
||||
account_metadata_keys: list[str] = field(default_factory=list)
|
||||
"""``connector.config`` keys exposed by ``get_connected_accounts``.
|
||||
|
||||
Only listed keys are returned to the LLM — tokens and secrets are
|
||||
never included. Every service should at least have its
|
||||
``display_name`` populated during OAuth; additional service-specific
|
||||
fields (e.g. Jira ``cloud_id``) are listed here so the LLM can pass
|
||||
them to action tools.
|
||||
"""
|
||||
|
||||
|
||||
MCP_SERVICES: dict[str, MCPServiceConfig] = {
|
||||
"linear": MCPServiceConfig(
|
||||
name="Linear",
|
||||
mcp_url="https://mcp.linear.app/mcp",
|
||||
connector_type="LINEAR_CONNECTOR",
|
||||
allowed_tools=[
|
||||
"list_issues",
|
||||
"get_issue",
|
||||
"save_issue",
|
||||
],
|
||||
readonly_tools=frozenset({"list_issues", "get_issue"}),
|
||||
account_metadata_keys=["organization_name", "organization_url_key"],
|
||||
),
|
||||
"jira": MCPServiceConfig(
|
||||
name="Jira",
|
||||
mcp_url="https://mcp.atlassian.com/v1/mcp",
|
||||
connector_type="JIRA_CONNECTOR",
|
||||
allowed_tools=[
|
||||
"getAccessibleAtlassianResources",
|
||||
"searchJiraIssuesUsingJql",
|
||||
"getVisibleJiraProjects",
|
||||
"getJiraProjectIssueTypesMetadata",
|
||||
"createJiraIssue",
|
||||
"editJiraIssue",
|
||||
],
|
||||
readonly_tools=frozenset({
|
||||
"getAccessibleAtlassianResources",
|
||||
"searchJiraIssuesUsingJql",
|
||||
"getVisibleJiraProjects",
|
||||
"getJiraProjectIssueTypesMetadata",
|
||||
}),
|
||||
account_metadata_keys=["cloud_id", "site_name", "base_url"],
|
||||
),
|
||||
"clickup": MCPServiceConfig(
|
||||
name="ClickUp",
|
||||
mcp_url="https://mcp.clickup.com/mcp",
|
||||
connector_type="CLICKUP_CONNECTOR",
|
||||
allowed_tools=[
|
||||
"clickup_search",
|
||||
"clickup_get_task",
|
||||
],
|
||||
readonly_tools=frozenset({"clickup_search", "clickup_get_task"}),
|
||||
account_metadata_keys=["workspace_id", "workspace_name"],
|
||||
),
|
||||
"slack": MCPServiceConfig(
|
||||
name="Slack",
|
||||
mcp_url="https://mcp.slack.com/mcp",
|
||||
connector_type="SLACK_CONNECTOR",
|
||||
supports_dcr=False,
|
||||
client_id_env="SLACK_CLIENT_ID",
|
||||
client_secret_env="SLACK_CLIENT_SECRET",
|
||||
auth_endpoint_override="https://slack.com/oauth/v2_user/authorize",
|
||||
token_endpoint_override="https://slack.com/api/oauth.v2.user.access",
|
||||
scopes=[
|
||||
"search:read.public", "search:read.private", "search:read.mpim", "search:read.im",
|
||||
"channels:history", "groups:history", "mpim:history", "im:history",
|
||||
],
|
||||
allowed_tools=[
|
||||
"slack_search_channels",
|
||||
"slack_read_channel",
|
||||
"slack_read_thread",
|
||||
],
|
||||
readonly_tools=frozenset({"slack_search_channels", "slack_read_channel", "slack_read_thread"}),
|
||||
# TODO: oauth.v2.user.access only returns team.id, not team.name.
|
||||
# To populate team_name, either add "team:read" scope and call
|
||||
# GET /api/team.info during OAuth callback, or switch to oauth.v2.access.
|
||||
account_metadata_keys=["team_id", "team_name"],
|
||||
),
|
||||
"airtable": MCPServiceConfig(
|
||||
name="Airtable",
|
||||
mcp_url="https://mcp.airtable.com/mcp",
|
||||
connector_type="AIRTABLE_CONNECTOR",
|
||||
supports_dcr=False,
|
||||
oauth_discovery_origin="https://airtable.com",
|
||||
client_id_env="AIRTABLE_CLIENT_ID",
|
||||
client_secret_env="AIRTABLE_CLIENT_SECRET",
|
||||
scopes=["data.records:read", "schema.bases:read"],
|
||||
allowed_tools=[
|
||||
"list_bases",
|
||||
"list_tables_for_base",
|
||||
"list_records_for_table",
|
||||
],
|
||||
readonly_tools=frozenset({"list_bases", "list_tables_for_base", "list_records_for_table"}),
|
||||
account_metadata_keys=["user_id", "user_email"],
|
||||
),
|
||||
}
|
||||
|
||||
_CONNECTOR_TYPE_TO_SERVICE: dict[str, MCPServiceConfig] = {
|
||||
svc.connector_type: svc for svc in MCP_SERVICES.values()
|
||||
}
|
||||
|
||||
LIVE_CONNECTOR_TYPES: frozenset[SearchSourceConnectorType] = frozenset({
|
||||
SearchSourceConnectorType.SLACK_CONNECTOR,
|
||||
SearchSourceConnectorType.TEAMS_CONNECTOR,
|
||||
SearchSourceConnectorType.LINEAR_CONNECTOR,
|
||||
SearchSourceConnectorType.JIRA_CONNECTOR,
|
||||
SearchSourceConnectorType.CLICKUP_CONNECTOR,
|
||||
SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR,
|
||||
SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
|
||||
SearchSourceConnectorType.AIRTABLE_CONNECTOR,
|
||||
SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR,
|
||||
SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR,
|
||||
SearchSourceConnectorType.DISCORD_CONNECTOR,
|
||||
SearchSourceConnectorType.LUMA_CONNECTOR,
|
||||
})
|
||||
|
||||
|
||||
def get_service(key: str) -> MCPServiceConfig | None:
|
||||
return MCP_SERVICES.get(key)
|
||||
|
||||
|
||||
def get_service_by_connector_type(connector_type: str) -> MCPServiceConfig | None:
|
||||
"""Look up an MCP service config by its ``connector_type`` enum value."""
|
||||
return _CONNECTOR_TYPE_TO_SERVICE.get(connector_type)
|
||||
|
|
@ -227,8 +227,6 @@ class NotionToolMetadataService:
|
|||
async def _check_account_health(self, connector_id: int) -> bool:
|
||||
"""Check if a Notion connector's token is still valid.
|
||||
|
||||
Uses a lightweight ``users.me()`` call to verify the token.
|
||||
|
||||
Returns True if the token is expired/invalid, False if healthy.
|
||||
"""
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -39,52 +39,6 @@ def _handle_greenlet_error(e: Exception, task_name: str, connector_id: int) -> N
|
|||
)
|
||||
|
||||
|
||||
@celery_app.task(name="index_slack_messages", bind=True)
|
||||
def index_slack_messages_task(
|
||||
self,
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
):
|
||||
"""Celery task to index Slack messages."""
|
||||
import asyncio
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
_index_slack_messages(
|
||||
connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
_handle_greenlet_error(e, "index_slack_messages", connector_id)
|
||||
raise
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
async def _index_slack_messages(
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
):
|
||||
"""Index Slack messages with new session."""
|
||||
from app.routes.search_source_connectors_routes import (
|
||||
run_slack_indexing,
|
||||
)
|
||||
|
||||
async with get_celery_session_maker()() as session:
|
||||
await run_slack_indexing(
|
||||
session, connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
|
||||
|
||||
@celery_app.task(name="index_notion_pages", bind=True)
|
||||
def index_notion_pages_task(
|
||||
self,
|
||||
|
|
@ -174,92 +128,6 @@ async def _index_github_repos(
|
|||
)
|
||||
|
||||
|
||||
@celery_app.task(name="index_linear_issues", bind=True)
|
||||
def index_linear_issues_task(
|
||||
self,
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
):
|
||||
"""Celery task to index Linear issues."""
|
||||
import asyncio
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
_index_linear_issues(
|
||||
connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
async def _index_linear_issues(
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
):
|
||||
"""Index Linear issues with new session."""
|
||||
from app.routes.search_source_connectors_routes import (
|
||||
run_linear_indexing,
|
||||
)
|
||||
|
||||
async with get_celery_session_maker()() as session:
|
||||
await run_linear_indexing(
|
||||
session, connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
|
||||
|
||||
@celery_app.task(name="index_jira_issues", bind=True)
|
||||
def index_jira_issues_task(
|
||||
self,
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
):
|
||||
"""Celery task to index Jira issues."""
|
||||
import asyncio
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
_index_jira_issues(
|
||||
connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
async def _index_jira_issues(
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
):
|
||||
"""Index Jira issues with new session."""
|
||||
from app.routes.search_source_connectors_routes import (
|
||||
run_jira_indexing,
|
||||
)
|
||||
|
||||
async with get_celery_session_maker()() as session:
|
||||
await run_jira_indexing(
|
||||
session, connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
|
||||
|
||||
@celery_app.task(name="index_confluence_pages", bind=True)
|
||||
def index_confluence_pages_task(
|
||||
self,
|
||||
|
|
@ -303,49 +171,6 @@ async def _index_confluence_pages(
|
|||
)
|
||||
|
||||
|
||||
@celery_app.task(name="index_clickup_tasks", bind=True)
|
||||
def index_clickup_tasks_task(
|
||||
self,
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
):
|
||||
"""Celery task to index ClickUp tasks."""
|
||||
import asyncio
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
_index_clickup_tasks(
|
||||
connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
async def _index_clickup_tasks(
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
):
|
||||
"""Index ClickUp tasks with new session."""
|
||||
from app.routes.search_source_connectors_routes import (
|
||||
run_clickup_indexing,
|
||||
)
|
||||
|
||||
async with get_celery_session_maker()() as session:
|
||||
await run_clickup_indexing(
|
||||
session, connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
|
||||
|
||||
@celery_app.task(name="index_google_calendar_events", bind=True)
|
||||
def index_google_calendar_events_task(
|
||||
self,
|
||||
|
|
@ -392,49 +217,6 @@ async def _index_google_calendar_events(
|
|||
)
|
||||
|
||||
|
||||
@celery_app.task(name="index_airtable_records", bind=True)
|
||||
def index_airtable_records_task(
|
||||
self,
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
):
|
||||
"""Celery task to index Airtable records."""
|
||||
import asyncio
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
_index_airtable_records(
|
||||
connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
async def _index_airtable_records(
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
):
|
||||
"""Index Airtable records with new session."""
|
||||
from app.routes.search_source_connectors_routes import (
|
||||
run_airtable_indexing,
|
||||
)
|
||||
|
||||
async with get_celery_session_maker()() as session:
|
||||
await run_airtable_indexing(
|
||||
session, connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
|
||||
|
||||
@celery_app.task(name="index_google_gmail_messages", bind=True)
|
||||
def index_google_gmail_messages_task(
|
||||
self,
|
||||
|
|
@ -622,135 +404,6 @@ async def _index_dropbox_files(
|
|||
)
|
||||
|
||||
|
||||
@celery_app.task(name="index_discord_messages", bind=True)
|
||||
def index_discord_messages_task(
|
||||
self,
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
):
|
||||
"""Celery task to index Discord messages."""
|
||||
import asyncio
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
_index_discord_messages(
|
||||
connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
async def _index_discord_messages(
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
):
|
||||
"""Index Discord messages with new session."""
|
||||
from app.routes.search_source_connectors_routes import (
|
||||
run_discord_indexing,
|
||||
)
|
||||
|
||||
async with get_celery_session_maker()() as session:
|
||||
await run_discord_indexing(
|
||||
session, connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
|
||||
|
||||
@celery_app.task(name="index_teams_messages", bind=True)
|
||||
def index_teams_messages_task(
|
||||
self,
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
):
|
||||
"""Celery task to index Microsoft Teams messages."""
|
||||
import asyncio
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
_index_teams_messages(
|
||||
connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
async def _index_teams_messages(
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
):
|
||||
"""Index Microsoft Teams messages with new session."""
|
||||
from app.routes.search_source_connectors_routes import (
|
||||
run_teams_indexing,
|
||||
)
|
||||
|
||||
async with get_celery_session_maker()() as session:
|
||||
await run_teams_indexing(
|
||||
session, connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
|
||||
|
||||
@celery_app.task(name="index_luma_events", bind=True)
|
||||
def index_luma_events_task(
|
||||
self,
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
):
|
||||
"""Celery task to index Luma events."""
|
||||
import asyncio
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
loop.run_until_complete(
|
||||
_index_luma_events(
|
||||
connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
async def _index_luma_events(
|
||||
connector_id: int,
|
||||
search_space_id: int,
|
||||
user_id: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
):
|
||||
"""Index Luma events with new session."""
|
||||
from app.routes.search_source_connectors_routes import (
|
||||
run_luma_indexing,
|
||||
)
|
||||
|
||||
async with get_celery_session_maker()() as session:
|
||||
await run_luma_indexing(
|
||||
session, connector_id, search_space_id, user_id, start_date, end_date
|
||||
)
|
||||
|
||||
|
||||
@celery_app.task(name="index_elasticsearch_documents", bind=True)
|
||||
def index_elasticsearch_documents_task(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -51,50 +51,51 @@ async def _check_and_trigger_schedules():
|
|||
|
||||
logger.info(f"Found {len(due_connectors)} connectors due for indexing")
|
||||
|
||||
# Import all indexing tasks
|
||||
# Import indexing tasks for KB connectors only.
|
||||
# Live connectors (Linear, Slack, Jira, ClickUp, Airtable, Discord,
|
||||
# Teams, Gmail, Calendar, Luma) use real-time tools instead.
|
||||
from app.tasks.celery_tasks.connector_tasks import (
|
||||
index_airtable_records_task,
|
||||
index_clickup_tasks_task,
|
||||
index_confluence_pages_task,
|
||||
index_crawled_urls_task,
|
||||
index_discord_messages_task,
|
||||
index_elasticsearch_documents_task,
|
||||
index_github_repos_task,
|
||||
index_google_calendar_events_task,
|
||||
index_google_drive_files_task,
|
||||
index_google_gmail_messages_task,
|
||||
index_jira_issues_task,
|
||||
index_linear_issues_task,
|
||||
index_luma_events_task,
|
||||
index_notion_pages_task,
|
||||
index_slack_messages_task,
|
||||
)
|
||||
|
||||
# Map connector types to their tasks
|
||||
task_map = {
|
||||
SearchSourceConnectorType.SLACK_CONNECTOR: index_slack_messages_task,
|
||||
SearchSourceConnectorType.NOTION_CONNECTOR: index_notion_pages_task,
|
||||
SearchSourceConnectorType.GITHUB_CONNECTOR: index_github_repos_task,
|
||||
SearchSourceConnectorType.LINEAR_CONNECTOR: index_linear_issues_task,
|
||||
SearchSourceConnectorType.JIRA_CONNECTOR: index_jira_issues_task,
|
||||
SearchSourceConnectorType.CONFLUENCE_CONNECTOR: index_confluence_pages_task,
|
||||
SearchSourceConnectorType.CLICKUP_CONNECTOR: index_clickup_tasks_task,
|
||||
SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR: index_google_calendar_events_task,
|
||||
SearchSourceConnectorType.AIRTABLE_CONNECTOR: index_airtable_records_task,
|
||||
SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR: index_google_gmail_messages_task,
|
||||
SearchSourceConnectorType.DISCORD_CONNECTOR: index_discord_messages_task,
|
||||
SearchSourceConnectorType.LUMA_CONNECTOR: index_luma_events_task,
|
||||
SearchSourceConnectorType.ELASTICSEARCH_CONNECTOR: index_elasticsearch_documents_task,
|
||||
SearchSourceConnectorType.WEBCRAWLER_CONNECTOR: index_crawled_urls_task,
|
||||
SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR: index_google_drive_files_task,
|
||||
# Composio connector types (unified with native Google tasks)
|
||||
SearchSourceConnectorType.COMPOSIO_GOOGLE_DRIVE_CONNECTOR: index_google_drive_files_task,
|
||||
SearchSourceConnectorType.COMPOSIO_GMAIL_CONNECTOR: index_google_gmail_messages_task,
|
||||
SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR: index_google_calendar_events_task,
|
||||
}
|
||||
|
||||
from app.services.mcp_oauth.registry import LIVE_CONNECTOR_TYPES
|
||||
|
||||
# Disable obsolete periodic indexing for live connectors in one batch.
|
||||
live_disabled = []
|
||||
for connector in due_connectors:
|
||||
if connector.connector_type in LIVE_CONNECTOR_TYPES:
|
||||
connector.periodic_indexing_enabled = False
|
||||
connector.next_scheduled_at = None
|
||||
live_disabled.append(connector)
|
||||
if live_disabled:
|
||||
await session.commit()
|
||||
for c in live_disabled:
|
||||
logger.info(
|
||||
"Disabled obsolete periodic indexing for live connector %s (%s)",
|
||||
c.id,
|
||||
c.connector_type.value,
|
||||
)
|
||||
|
||||
# Trigger indexing for each due connector
|
||||
for connector in due_connectors:
|
||||
if connector in live_disabled:
|
||||
continue
|
||||
|
||||
# Primary guard: Redis lock indicates a task is currently running.
|
||||
if is_connector_indexing_locked(connector.id):
|
||||
logger.info(
|
||||
|
|
|
|||
|
|
@ -1,77 +1,31 @@
|
|||
"""
|
||||
Connector indexers module for background tasks.
|
||||
|
||||
This module provides a collection of connector indexers for different platforms
|
||||
and services. Each indexer is responsible for handling the indexing of content
|
||||
from a specific connector type.
|
||||
|
||||
Available indexers:
|
||||
- Slack: Index messages from Slack channels
|
||||
- Notion: Index pages from Notion workspaces
|
||||
- GitHub: Index repositories and files from GitHub
|
||||
- Linear: Index issues from Linear workspaces
|
||||
- Jira: Index issues from Jira projects
|
||||
- Confluence: Index pages from Confluence spaces
|
||||
- BookStack: Index pages from BookStack wiki instances
|
||||
- Discord: Index messages from Discord servers
|
||||
- ClickUp: Index tasks from ClickUp workspaces
|
||||
- Google Gmail: Index messages from Google Gmail
|
||||
- Google Calendar: Index events from Google Calendar
|
||||
- Luma: Index events from Luma
|
||||
- Webcrawler: Index crawled URLs
|
||||
- Elasticsearch: Index documents from Elasticsearch instances
|
||||
Each indexer handles content indexing from a specific connector type.
|
||||
Live connectors (Slack, Linear, Jira, ClickUp, Airtable, Discord, Teams,
|
||||
Luma) now use real-time agent tools instead of background indexing.
|
||||
"""
|
||||
|
||||
# Communication platforms
|
||||
# Calendar and scheduling
|
||||
from .airtable_indexer import index_airtable_records
|
||||
from .bookstack_indexer import index_bookstack_pages
|
||||
|
||||
# Note: composio_indexer is imported directly in connector_tasks.py to avoid circular imports
|
||||
from .clickup_indexer import index_clickup_tasks
|
||||
from .confluence_indexer import index_confluence_pages
|
||||
from .discord_indexer import index_discord_messages
|
||||
|
||||
# Development platforms
|
||||
from .elasticsearch_indexer import index_elasticsearch_documents
|
||||
from .github_indexer import index_github_repos
|
||||
from .google_calendar_indexer import index_google_calendar_events
|
||||
from .google_drive_indexer import index_google_drive_files
|
||||
from .google_gmail_indexer import index_google_gmail_messages
|
||||
from .jira_indexer import index_jira_issues
|
||||
|
||||
# Issue tracking and project management
|
||||
from .linear_indexer import index_linear_issues
|
||||
|
||||
# Documentation and knowledge management
|
||||
from .luma_indexer import index_luma_events
|
||||
from .notion_indexer import index_notion_pages
|
||||
from .obsidian_indexer import index_obsidian_vault
|
||||
from .slack_indexer import index_slack_messages
|
||||
from .webcrawler_indexer import index_crawled_urls
|
||||
|
||||
__all__ = [ # noqa: RUF022
|
||||
"index_airtable_records",
|
||||
__all__ = [
|
||||
"index_bookstack_pages",
|
||||
# "index_composio_connector", # Imported directly in connector_tasks.py to avoid circular imports
|
||||
"index_clickup_tasks",
|
||||
"index_confluence_pages",
|
||||
"index_discord_messages",
|
||||
# Development platforms
|
||||
"index_elasticsearch_documents",
|
||||
"index_github_repos",
|
||||
# Calendar and scheduling
|
||||
"index_google_calendar_events",
|
||||
"index_google_drive_files",
|
||||
"index_luma_events",
|
||||
"index_jira_issues",
|
||||
# Issue tracking and project management
|
||||
"index_linear_issues",
|
||||
# Documentation and knowledge management
|
||||
"index_google_gmail_messages",
|
||||
"index_notion_pages",
|
||||
"index_obsidian_vault",
|
||||
"index_crawled_urls",
|
||||
# Communication platforms
|
||||
"index_slack_messages",
|
||||
"index_google_gmail_messages",
|
||||
]
|
||||
|
|
|
|||
129
surfsense_backend/app/utils/async_retry.py
Normal file
129
surfsense_backend/app/utils/async_retry.py
Normal file
|
|
@ -0,0 +1,129 @@
|
|||
"""Async retry decorators for connector API calls, built on tenacity."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from typing import TypeVar
|
||||
|
||||
import httpx
|
||||
from tenacity import (
|
||||
before_sleep_log,
|
||||
retry,
|
||||
retry_if_exception,
|
||||
stop_after_attempt,
|
||||
stop_after_delay,
|
||||
wait_exponential_jitter,
|
||||
)
|
||||
|
||||
from app.connectors.exceptions import (
|
||||
ConnectorAPIError,
|
||||
ConnectorAuthError,
|
||||
ConnectorError,
|
||||
ConnectorRateLimitError,
|
||||
ConnectorTimeoutError,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
F = TypeVar("F", bound=Callable)
|
||||
|
||||
|
||||
def _is_retryable(exc: BaseException) -> bool:
|
||||
if isinstance(exc, ConnectorError):
|
||||
return exc.retryable
|
||||
if isinstance(exc, (httpx.TimeoutException, httpx.ConnectError)):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def build_retry(
|
||||
*,
|
||||
max_attempts: int = 4,
|
||||
max_delay: float = 60.0,
|
||||
initial_delay: float = 1.0,
|
||||
total_timeout: float = 180.0,
|
||||
service: str = "",
|
||||
) -> Callable:
|
||||
"""Configurable tenacity ``@retry`` decorator with exponential backoff + jitter."""
|
||||
_logger = logging.getLogger(f"connector.retry.{service}") if service else logger
|
||||
|
||||
return retry(
|
||||
retry=retry_if_exception(_is_retryable),
|
||||
stop=(stop_after_attempt(max_attempts) | stop_after_delay(total_timeout)),
|
||||
wait=wait_exponential_jitter(initial=initial_delay, max=max_delay),
|
||||
reraise=True,
|
||||
before_sleep=before_sleep_log(_logger, logging.WARNING),
|
||||
)
|
||||
|
||||
|
||||
def retry_on_transient(
|
||||
*,
|
||||
service: str = "",
|
||||
max_attempts: int = 4,
|
||||
) -> Callable:
|
||||
"""Shorthand: retry up to *max_attempts* on rate-limits, timeouts, and 5xx."""
|
||||
return build_retry(max_attempts=max_attempts, service=service)
|
||||
|
||||
|
||||
def raise_for_status(
|
||||
response: httpx.Response,
|
||||
*,
|
||||
service: str = "",
|
||||
) -> None:
|
||||
"""Map non-2xx httpx responses to the appropriate ``ConnectorError``."""
|
||||
if response.is_success:
|
||||
return
|
||||
|
||||
status = response.status_code
|
||||
|
||||
try:
|
||||
body = response.json()
|
||||
except Exception:
|
||||
body = response.text[:500] if response.text else None
|
||||
|
||||
if status == 429:
|
||||
retry_after_raw = response.headers.get("Retry-After")
|
||||
retry_after: float | None = None
|
||||
if retry_after_raw:
|
||||
try:
|
||||
retry_after = float(retry_after_raw)
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
raise ConnectorRateLimitError(
|
||||
f"{service} rate limited (429)",
|
||||
service=service,
|
||||
retry_after=retry_after,
|
||||
response_body=body,
|
||||
)
|
||||
|
||||
if status in (401, 403):
|
||||
raise ConnectorAuthError(
|
||||
f"{service} authentication failed ({status})",
|
||||
service=service,
|
||||
status_code=status,
|
||||
response_body=body,
|
||||
)
|
||||
|
||||
if status == 504:
|
||||
raise ConnectorTimeoutError(
|
||||
f"{service} gateway timeout (504)",
|
||||
service=service,
|
||||
status_code=status,
|
||||
response_body=body,
|
||||
)
|
||||
|
||||
if status >= 500:
|
||||
raise ConnectorAPIError(
|
||||
f"{service} server error ({status})",
|
||||
service=service,
|
||||
status_code=status,
|
||||
response_body=body,
|
||||
)
|
||||
|
||||
raise ConnectorAPIError(
|
||||
f"{service} request failed ({status})",
|
||||
service=service,
|
||||
status_code=status,
|
||||
response_body=body,
|
||||
)
|
||||
|
|
@ -39,7 +39,7 @@ BASE_NAME_FOR_TYPE = {
|
|||
def get_base_name_for_type(connector_type: SearchSourceConnectorType) -> str:
|
||||
"""Get a friendly display name for a connector type."""
|
||||
return BASE_NAME_FOR_TYPE.get(
|
||||
connector_type, connector_type.replace("_", " ").title()
|
||||
connector_type, connector_type.value.replace("_", " ").title()
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -231,9 +231,11 @@ async def generate_unique_connector_name(
|
|||
base = get_base_name_for_type(connector_type)
|
||||
|
||||
if identifier:
|
||||
return f"{base} - {identifier}"
|
||||
name = f"{base} - {identifier}"
|
||||
return await ensure_unique_connector_name(
|
||||
session, name, search_space_id, user_id,
|
||||
)
|
||||
|
||||
# Fallback: use counter for uniqueness
|
||||
count = await count_connectors_of_type(
|
||||
session, connector_type, search_space_id, user_id
|
||||
)
|
||||
|
|
|
|||
|
|
@ -18,19 +18,9 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
# Mapping of connector types to their corresponding Celery task names
|
||||
CONNECTOR_TASK_MAP = {
|
||||
SearchSourceConnectorType.SLACK_CONNECTOR: "index_slack_messages",
|
||||
SearchSourceConnectorType.TEAMS_CONNECTOR: "index_teams_messages",
|
||||
SearchSourceConnectorType.NOTION_CONNECTOR: "index_notion_pages",
|
||||
SearchSourceConnectorType.GITHUB_CONNECTOR: "index_github_repos",
|
||||
SearchSourceConnectorType.LINEAR_CONNECTOR: "index_linear_issues",
|
||||
SearchSourceConnectorType.JIRA_CONNECTOR: "index_jira_issues",
|
||||
SearchSourceConnectorType.CONFLUENCE_CONNECTOR: "index_confluence_pages",
|
||||
SearchSourceConnectorType.CLICKUP_CONNECTOR: "index_clickup_tasks",
|
||||
SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR: "index_google_calendar_events",
|
||||
SearchSourceConnectorType.AIRTABLE_CONNECTOR: "index_airtable_records",
|
||||
SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR: "index_google_gmail_messages",
|
||||
SearchSourceConnectorType.DISCORD_CONNECTOR: "index_discord_messages",
|
||||
SearchSourceConnectorType.LUMA_CONNECTOR: "index_luma_events",
|
||||
SearchSourceConnectorType.ELASTICSEARCH_CONNECTOR: "index_elasticsearch_documents",
|
||||
SearchSourceConnectorType.WEBCRAWLER_CONNECTOR: "index_crawled_urls",
|
||||
SearchSourceConnectorType.BOOKSTACK_CONNECTOR: "index_bookstack_pages",
|
||||
|
|
@ -84,40 +74,20 @@ def create_periodic_schedule(
|
|||
f"(frequency: {frequency_minutes} minutes). Triggering first run..."
|
||||
)
|
||||
|
||||
# Import all indexing tasks
|
||||
from app.tasks.celery_tasks.connector_tasks import (
|
||||
index_airtable_records_task,
|
||||
index_bookstack_pages_task,
|
||||
index_clickup_tasks_task,
|
||||
index_confluence_pages_task,
|
||||
index_crawled_urls_task,
|
||||
index_discord_messages_task,
|
||||
index_elasticsearch_documents_task,
|
||||
index_github_repos_task,
|
||||
index_google_calendar_events_task,
|
||||
index_google_gmail_messages_task,
|
||||
index_jira_issues_task,
|
||||
index_linear_issues_task,
|
||||
index_luma_events_task,
|
||||
index_notion_pages_task,
|
||||
index_obsidian_vault_task,
|
||||
index_slack_messages_task,
|
||||
)
|
||||
|
||||
# Map connector type to task
|
||||
task_map = {
|
||||
SearchSourceConnectorType.SLACK_CONNECTOR: index_slack_messages_task,
|
||||
SearchSourceConnectorType.NOTION_CONNECTOR: index_notion_pages_task,
|
||||
SearchSourceConnectorType.GITHUB_CONNECTOR: index_github_repos_task,
|
||||
SearchSourceConnectorType.LINEAR_CONNECTOR: index_linear_issues_task,
|
||||
SearchSourceConnectorType.JIRA_CONNECTOR: index_jira_issues_task,
|
||||
SearchSourceConnectorType.CONFLUENCE_CONNECTOR: index_confluence_pages_task,
|
||||
SearchSourceConnectorType.CLICKUP_CONNECTOR: index_clickup_tasks_task,
|
||||
SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR: index_google_calendar_events_task,
|
||||
SearchSourceConnectorType.AIRTABLE_CONNECTOR: index_airtable_records_task,
|
||||
SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR: index_google_gmail_messages_task,
|
||||
SearchSourceConnectorType.DISCORD_CONNECTOR: index_discord_messages_task,
|
||||
SearchSourceConnectorType.LUMA_CONNECTOR: index_luma_events_task,
|
||||
SearchSourceConnectorType.ELASTICSEARCH_CONNECTOR: index_elasticsearch_documents_task,
|
||||
SearchSourceConnectorType.WEBCRAWLER_CONNECTOR: index_crawled_urls_task,
|
||||
SearchSourceConnectorType.BOOKSTACK_CONNECTOR: index_bookstack_pages_task,
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ import { Spinner } from "@/components/ui/spinner";
|
|||
import { EnumConnectorName } from "@/contracts/enums/connector";
|
||||
import { getConnectorIcon } from "@/contracts/enums/connectorIcons";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { LIVE_CONNECTOR_TYPES } from "../constants/connector-constants";
|
||||
import { useConnectorStatus } from "../hooks/use-connector-status";
|
||||
import { ConnectorStatusBadge } from "./connector-status-badge";
|
||||
|
||||
|
|
@ -55,6 +56,7 @@ export const ConnectorCard: FC<ConnectorCardProps> = ({
|
|||
onManage,
|
||||
}) => {
|
||||
const isMCP = connectorType === EnumConnectorName.MCP_CONNECTOR;
|
||||
const isLive = !!connectorType && LIVE_CONNECTOR_TYPES.has(connectorType);
|
||||
// Get connector status
|
||||
const { getConnectorStatus, isConnectorEnabled, getConnectorStatusMessage, shouldShowWarnings } =
|
||||
useConnectorStatus();
|
||||
|
|
@ -123,14 +125,14 @@ export const ConnectorCard: FC<ConnectorCardProps> = ({
|
|||
</span>
|
||||
) : (
|
||||
<>
|
||||
<span>{formatDocumentCount(documentCount)}</span>
|
||||
{accountCount !== undefined && accountCount > 0 && (
|
||||
<>
|
||||
{!isLive && <span>{formatDocumentCount(documentCount)}</span>}
|
||||
{!isLive && accountCount !== undefined && accountCount > 0 && (
|
||||
<span className="text-muted-foreground/50">•</span>
|
||||
)}
|
||||
{accountCount !== undefined && accountCount > 0 && (
|
||||
<span>
|
||||
{accountCount} {accountCount === 1 ? "Account" : "Accounts"}
|
||||
</span>
|
||||
</>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
|
|
|
|||
|
|
@ -53,8 +53,7 @@ export const DiscordConfig: FC<DiscordConfigProps> = ({ connector }) => {
|
|||
return () => document.removeEventListener("visibilitychange", handleVisibilityChange);
|
||||
}, [connector?.id, fetchChannels]);
|
||||
|
||||
// Separate channels by indexing capability
|
||||
const readyToIndex = channels.filter((ch) => ch.can_index);
|
||||
const accessible = channels.filter((ch) => ch.can_index);
|
||||
const needsPermissions = channels.filter((ch) => !ch.can_index);
|
||||
|
||||
// Format last fetched time
|
||||
|
|
@ -80,7 +79,7 @@ export const DiscordConfig: FC<DiscordConfigProps> = ({ connector }) => {
|
|||
</div>
|
||||
<div className="text-xs sm:text-sm">
|
||||
<p className="text-muted-foreground mt-1 text-[10px] sm:text-sm">
|
||||
The bot needs "Read Message History" permission to index channels. Ask a
|
||||
The bot needs "Read Message History" permission to access channels. Ask a
|
||||
server admin to grant this permission for channels shown below.
|
||||
</p>
|
||||
</div>
|
||||
|
|
@ -127,18 +126,18 @@ export const DiscordConfig: FC<DiscordConfigProps> = ({ connector }) => {
|
|||
</div>
|
||||
) : (
|
||||
<div className="rounded-xl bg-slate-400/5 dark:bg-white/5 overflow-hidden">
|
||||
{/* Ready to index */}
|
||||
{readyToIndex.length > 0 && (
|
||||
{/* Accessible channels */}
|
||||
{accessible.length > 0 && (
|
||||
<div className={cn("p-3", needsPermissions.length > 0 && "border-b border-border")}>
|
||||
<div className="flex items-center gap-2 mb-2">
|
||||
<CheckCircle2 className="size-3.5 text-emerald-500" />
|
||||
<span className="text-[11px] font-medium">Ready to index</span>
|
||||
<span className="text-[11px] font-medium">Accessible</span>
|
||||
<span className="text-[10px] text-muted-foreground">
|
||||
{readyToIndex.length} {readyToIndex.length === 1 ? "channel" : "channels"}
|
||||
{accessible.length} {accessible.length === 1 ? "channel" : "channels"}
|
||||
</span>
|
||||
</div>
|
||||
<div className="flex flex-wrap gap-1.5">
|
||||
{readyToIndex.map((channel) => (
|
||||
{accessible.map((channel) => (
|
||||
<ChannelPill key={channel.id} channel={channel} />
|
||||
))}
|
||||
</div>
|
||||
|
|
@ -150,7 +149,7 @@ export const DiscordConfig: FC<DiscordConfigProps> = ({ connector }) => {
|
|||
<div className="p-3">
|
||||
<div className="flex items-center gap-2 mb-2">
|
||||
<AlertCircle className="size-3.5 text-amber-500" />
|
||||
<span className="text-[11px] font-medium">Grant permissions to index</span>
|
||||
<span className="text-[11px] font-medium">Needs permissions</span>
|
||||
<span className="text-[10px] text-muted-foreground">
|
||||
{needsPermissions.length}{" "}
|
||||
{needsPermissions.length === 1 ? "channel" : "channels"}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,28 @@
|
|||
"use client";
|
||||
|
||||
import { CheckCircle2 } from "lucide-react";
|
||||
import type { FC } from "react";
|
||||
import type { ConnectorConfigProps } from "../index";
|
||||
|
||||
export const MCPServiceConfig: FC<ConnectorConfigProps> = ({ connector }) => {
|
||||
const serviceName = connector.config?.mcp_service as string | undefined;
|
||||
const displayName = serviceName
|
||||
? serviceName.charAt(0).toUpperCase() + serviceName.slice(1)
|
||||
: "this service";
|
||||
|
||||
return (
|
||||
<div className="space-y-4">
|
||||
<div className="rounded-xl border border-border bg-emerald-500/5 p-4 flex items-start gap-3">
|
||||
<div className="flex h-8 w-8 items-center justify-center rounded-lg bg-emerald-500/10 shrink-0 mt-0.5">
|
||||
<CheckCircle2 className="size-4 text-emerald-500" />
|
||||
</div>
|
||||
<div className="text-xs sm:text-sm">
|
||||
<p className="font-medium text-xs sm:text-sm">Connected</p>
|
||||
<p className="text-muted-foreground mt-1 text-[10px] sm:text-sm">
|
||||
Your agent can search, read, and take actions in {displayName}.
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
|
@ -18,9 +18,9 @@ export const TeamsConfig: FC<TeamsConfigProps> = () => {
|
|||
<div className="text-xs sm:text-sm">
|
||||
<p className="font-medium text-xs sm:text-sm">Microsoft Teams Access</p>
|
||||
<p className="text-muted-foreground mt-1 text-[10px] sm:text-sm">
|
||||
SurfSense will index messages from Teams channels that you have access to. The app can
|
||||
only read messages from teams and channels where you are a member. Make sure you're a
|
||||
member of the teams you want to index before connecting.
|
||||
Your agent can search and read messages from Teams channels you have access to,
|
||||
and send messages on your behalf. Make sure you're a member of the teams
|
||||
you want to interact with.
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
|
|
|
|||
|
|
@ -16,8 +16,9 @@ import { DateRangeSelector } from "../../components/date-range-selector";
|
|||
import { PeriodicSyncConfig } from "../../components/periodic-sync-config";
|
||||
import { SummaryConfig } from "../../components/summary-config";
|
||||
import { VisionLLMConfig } from "../../components/vision-llm-config";
|
||||
import { LIVE_CONNECTOR_TYPES } from "../../constants/connector-constants";
|
||||
import { getConnectorDisplayName } from "../../tabs/all-connectors-tab";
|
||||
import { getConnectorConfigComponent } from "../index";
|
||||
import { type ConnectorConfigProps, getConnectorConfigComponent } from "../index";
|
||||
|
||||
const REAUTH_ENDPOINTS: Partial<Record<string, string>> = {
|
||||
[EnumConnectorName.LINEAR_CONNECTOR]: "/api/v1/auth/linear/connector/reauth",
|
||||
|
|
@ -118,11 +119,17 @@ export const ConnectorEditView: FC<ConnectorEditViewProps> = ({
|
|||
}
|
||||
}, [searchSpaceId, searchSpaceIdAtom, reauthEndpoint, connector.id]);
|
||||
|
||||
// Get connector-specific config component
|
||||
const ConnectorConfigComponent = useMemo(
|
||||
() => getConnectorConfigComponent(connector.connector_type),
|
||||
[connector.connector_type]
|
||||
);
|
||||
const isMCPBacked = Boolean(connector.config?.server_config);
|
||||
const isLive = isMCPBacked || LIVE_CONNECTOR_TYPES.has(connector.connector_type);
|
||||
|
||||
// Get connector-specific config component (MCP-backed connectors use a generic view)
|
||||
const ConnectorConfigComponent = useMemo(() => {
|
||||
if (isMCPBacked) {
|
||||
const { MCPServiceConfig } = require("../components/mcp-service-config");
|
||||
return MCPServiceConfig as FC<ConnectorConfigProps>;
|
||||
}
|
||||
return getConnectorConfigComponent(connector.connector_type);
|
||||
}, [connector.connector_type, isMCPBacked]);
|
||||
const [isScrolled, setIsScrolled] = useState(false);
|
||||
const [hasMoreContent, setHasMoreContent] = useState(false);
|
||||
const [showDisconnectConfirm, setShowDisconnectConfirm] = useState(false);
|
||||
|
|
@ -223,12 +230,14 @@ export const ConnectorEditView: FC<ConnectorEditViewProps> = ({
|
|||
{getConnectorDisplayName(connector.name)}
|
||||
</h2>
|
||||
<p className="text-xs sm:text-base text-muted-foreground mt-1">
|
||||
Manage your connector settings and sync configuration
|
||||
{isLive
|
||||
? "Manage your connected account"
|
||||
: "Manage your connector settings and sync configuration"}
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
{/* Quick Index Button - hidden when auth is expired */}
|
||||
{connector.is_indexable && onQuickIndex && !isAuthExpired && (
|
||||
{/* Quick Index Button - hidden for live connectors and when auth is expired */}
|
||||
{connector.is_indexable && !isLive && onQuickIndex && !isAuthExpired && (
|
||||
<Button
|
||||
variant="secondary"
|
||||
size="sm"
|
||||
|
|
@ -271,8 +280,8 @@ export const ConnectorEditView: FC<ConnectorEditViewProps> = ({
|
|||
/>
|
||||
)}
|
||||
|
||||
{/* Summary and sync settings - only shown for indexable connectors */}
|
||||
{connector.is_indexable && (
|
||||
{/* Summary and sync settings - hidden for live connectors */}
|
||||
{connector.is_indexable && !isLive && (
|
||||
<>
|
||||
{/* AI Summary toggle */}
|
||||
<SummaryConfig enabled={enableSummary} onEnabledChange={onEnableSummaryChange} />
|
||||
|
|
@ -343,8 +352,8 @@ export const ConnectorEditView: FC<ConnectorEditViewProps> = ({
|
|||
</>
|
||||
)}
|
||||
|
||||
{/* Info box - only shown for indexable connectors */}
|
||||
{connector.is_indexable && (
|
||||
{/* Info box - hidden for live connectors */}
|
||||
{connector.is_indexable && !isLive && (
|
||||
<div className="rounded-xl border border-border bg-primary/5 p-4 flex items-start gap-3">
|
||||
<div className="flex h-8 w-8 items-center justify-center rounded-lg bg-primary/10 shrink-0 mt-0.5">
|
||||
<Info className="size-4" />
|
||||
|
|
@ -377,7 +386,9 @@ export const ConnectorEditView: FC<ConnectorEditViewProps> = ({
|
|||
{showDisconnectConfirm ? (
|
||||
<div className="flex flex-col sm:flex-row items-stretch sm:items-center gap-3 flex-1 sm:flex-initial">
|
||||
<span className="text-xs sm:text-sm text-muted-foreground sm:whitespace-nowrap">
|
||||
Are you sure?
|
||||
{isLive
|
||||
? "Your agent will lose access to this service."
|
||||
: "This will remove all indexed data."}
|
||||
</span>
|
||||
<div className="flex items-center gap-2 sm:gap-3">
|
||||
<Button
|
||||
|
|
@ -421,7 +432,7 @@ export const ConnectorEditView: FC<ConnectorEditViewProps> = ({
|
|||
<RefreshCw className={cn("size-3.5", reauthing && "animate-spin")} />
|
||||
Re-authenticate
|
||||
</Button>
|
||||
) : (
|
||||
) : !isLive ? (
|
||||
<Button
|
||||
onClick={onSave}
|
||||
disabled={isSaving || isDisconnecting}
|
||||
|
|
@ -430,7 +441,7 @@ export const ConnectorEditView: FC<ConnectorEditViewProps> = ({
|
|||
<span className={isSaving ? "opacity-0" : ""}>Save Changes</span>
|
||||
{isSaving && <Spinner size="sm" className="absolute" />}
|
||||
</Button>
|
||||
)}
|
||||
) : null}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ import { DateRangeSelector } from "../../components/date-range-selector";
|
|||
import { PeriodicSyncConfig } from "../../components/periodic-sync-config";
|
||||
import { SummaryConfig } from "../../components/summary-config";
|
||||
import { VisionLLMConfig } from "../../components/vision-llm-config";
|
||||
import type { IndexingConfigState } from "../../constants/connector-constants";
|
||||
import { LIVE_CONNECTOR_TYPES, type IndexingConfigState } from "../../constants/connector-constants";
|
||||
import { getConnectorDisplayName } from "../../tabs/all-connectors-tab";
|
||||
import { getConnectorConfigComponent } from "../index";
|
||||
|
||||
|
|
@ -58,6 +58,8 @@ export const IndexingConfigurationView: FC<IndexingConfigurationViewProps> = ({
|
|||
onStartIndexing,
|
||||
onSkip,
|
||||
}) => {
|
||||
const isLive = LIVE_CONNECTOR_TYPES.has(config.connectorType);
|
||||
|
||||
// Get connector-specific config component
|
||||
const ConnectorConfigComponent = useMemo(
|
||||
() => (connector ? getConnectorConfigComponent(connector.connector_type) : null),
|
||||
|
|
@ -138,7 +140,9 @@ export const IndexingConfigurationView: FC<IndexingConfigurationViewProps> = ({
|
|||
)}
|
||||
</div>
|
||||
<p className="text-xs sm:text-base text-muted-foreground mt-1">
|
||||
Configure when to start syncing your data
|
||||
{isLive
|
||||
? "Your account is ready to use"
|
||||
: "Configure when to start syncing your data"}
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
|
|
@ -157,8 +161,8 @@ export const IndexingConfigurationView: FC<IndexingConfigurationViewProps> = ({
|
|||
<ConnectorConfigComponent connector={connector} onConfigChange={onConfigChange} />
|
||||
)}
|
||||
|
||||
{/* Summary and sync settings - only shown for indexable connectors */}
|
||||
{connector?.is_indexable && (
|
||||
{/* Summary and sync settings - hidden for live connectors */}
|
||||
{connector?.is_indexable && !isLive && (
|
||||
<>
|
||||
{/* AI Summary toggle */}
|
||||
<SummaryConfig enabled={enableSummary} onEnabledChange={onEnableSummaryChange} />
|
||||
|
|
@ -209,8 +213,8 @@ export const IndexingConfigurationView: FC<IndexingConfigurationViewProps> = ({
|
|||
</>
|
||||
)}
|
||||
|
||||
{/* Info box - only shown for indexable connectors */}
|
||||
{connector?.is_indexable && (
|
||||
{/* Info box - hidden for live connectors */}
|
||||
{connector?.is_indexable && !isLive && (
|
||||
<div className="rounded-xl border border-border bg-primary/5 p-4 flex items-start gap-3">
|
||||
<div className="flex h-8 w-8 items-center justify-center rounded-lg bg-primary/10 shrink-0 mt-0.5">
|
||||
<Info className="size-4" />
|
||||
|
|
@ -238,6 +242,11 @@ export const IndexingConfigurationView: FC<IndexingConfigurationViewProps> = ({
|
|||
|
||||
{/* Fixed Footer - Action buttons */}
|
||||
<div className="flex-shrink-0 flex items-center justify-end px-6 sm:px-12 py-6 bg-muted">
|
||||
{isLive ? (
|
||||
<Button onClick={onSkip} className="text-xs sm:text-sm">
|
||||
Done
|
||||
</Button>
|
||||
) : (
|
||||
<Button
|
||||
onClick={onStartIndexing}
|
||||
disabled={isStartingIndexing}
|
||||
|
|
@ -246,6 +255,7 @@ export const IndexingConfigurationView: FC<IndexingConfigurationViewProps> = ({
|
|||
<span className={isStartingIndexing ? "opacity-0" : ""}>Start Indexing</span>
|
||||
{isStartingIndexing && <Spinner size="sm" className="absolute" />}
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
|
|
|
|||
|
|
@ -1,5 +1,24 @@
|
|||
import { EnumConnectorName } from "@/contracts/enums/connector";
|
||||
|
||||
/**
|
||||
* Connectors that operate in real time (no background indexing).
|
||||
* Used to adjust UI: hide sync controls, show "Connected" instead of doc counts.
|
||||
*/
|
||||
export const LIVE_CONNECTOR_TYPES = new Set<string>([
|
||||
EnumConnectorName.LINEAR_CONNECTOR,
|
||||
EnumConnectorName.SLACK_CONNECTOR,
|
||||
EnumConnectorName.JIRA_CONNECTOR,
|
||||
EnumConnectorName.CLICKUP_CONNECTOR,
|
||||
EnumConnectorName.AIRTABLE_CONNECTOR,
|
||||
EnumConnectorName.DISCORD_CONNECTOR,
|
||||
EnumConnectorName.TEAMS_CONNECTOR,
|
||||
EnumConnectorName.GOOGLE_CALENDAR_CONNECTOR,
|
||||
EnumConnectorName.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
|
||||
EnumConnectorName.GOOGLE_GMAIL_CONNECTOR,
|
||||
EnumConnectorName.COMPOSIO_GMAIL_CONNECTOR,
|
||||
EnumConnectorName.LUMA_CONNECTOR,
|
||||
]);
|
||||
|
||||
// OAuth Connectors (Quick Connect)
|
||||
export const OAUTH_CONNECTORS = [
|
||||
{
|
||||
|
|
@ -13,7 +32,7 @@ export const OAUTH_CONNECTORS = [
|
|||
{
|
||||
id: "google-gmail-connector",
|
||||
title: "Gmail",
|
||||
description: "Search through your emails",
|
||||
description: "Search, read, draft, and send emails",
|
||||
connectorType: EnumConnectorName.GOOGLE_GMAIL_CONNECTOR,
|
||||
authEndpoint: "/api/v1/auth/google/gmail/connector/add/",
|
||||
selfHostedOnly: true,
|
||||
|
|
@ -21,7 +40,7 @@ export const OAUTH_CONNECTORS = [
|
|||
{
|
||||
id: "google-calendar-connector",
|
||||
title: "Google Calendar",
|
||||
description: "Search through your events",
|
||||
description: "Search and manage your events",
|
||||
connectorType: EnumConnectorName.GOOGLE_CALENDAR_CONNECTOR,
|
||||
authEndpoint: "/api/v1/auth/google/calendar/connector/add/",
|
||||
selfHostedOnly: true,
|
||||
|
|
@ -29,35 +48,35 @@ export const OAUTH_CONNECTORS = [
|
|||
{
|
||||
id: "airtable-connector",
|
||||
title: "Airtable",
|
||||
description: "Search your Airtable bases",
|
||||
description: "Browse bases, tables, and records",
|
||||
connectorType: EnumConnectorName.AIRTABLE_CONNECTOR,
|
||||
authEndpoint: "/api/v1/auth/airtable/connector/add/",
|
||||
authEndpoint: "/api/v1/auth/mcp/airtable/connector/add/",
|
||||
},
|
||||
{
|
||||
id: "notion-connector",
|
||||
title: "Notion",
|
||||
description: "Search your Notion pages",
|
||||
connectorType: EnumConnectorName.NOTION_CONNECTOR,
|
||||
authEndpoint: "/api/v1/auth/notion/connector/add/",
|
||||
authEndpoint: "/api/v1/auth/notion/connector/add",
|
||||
},
|
||||
{
|
||||
id: "linear-connector",
|
||||
title: "Linear",
|
||||
description: "Search issues & projects",
|
||||
description: "Search, read, and manage issues & projects",
|
||||
connectorType: EnumConnectorName.LINEAR_CONNECTOR,
|
||||
authEndpoint: "/api/v1/auth/linear/connector/add/",
|
||||
authEndpoint: "/api/v1/auth/mcp/linear/connector/add/",
|
||||
},
|
||||
{
|
||||
id: "slack-connector",
|
||||
title: "Slack",
|
||||
description: "Search Slack messages",
|
||||
description: "Search and read channels and threads",
|
||||
connectorType: EnumConnectorName.SLACK_CONNECTOR,
|
||||
authEndpoint: "/api/v1/auth/slack/connector/add/",
|
||||
authEndpoint: "/api/v1/auth/mcp/slack/connector/add/",
|
||||
},
|
||||
{
|
||||
id: "teams-connector",
|
||||
title: "Microsoft Teams",
|
||||
description: "Search Teams messages",
|
||||
description: "Search, read, and send messages",
|
||||
connectorType: EnumConnectorName.TEAMS_CONNECTOR,
|
||||
authEndpoint: "/api/v1/auth/teams/connector/add/",
|
||||
},
|
||||
|
|
@ -78,16 +97,16 @@ export const OAUTH_CONNECTORS = [
|
|||
{
|
||||
id: "discord-connector",
|
||||
title: "Discord",
|
||||
description: "Search Discord messages",
|
||||
description: "Search, read, and send messages",
|
||||
connectorType: EnumConnectorName.DISCORD_CONNECTOR,
|
||||
authEndpoint: "/api/v1/auth/discord/connector/add/",
|
||||
},
|
||||
{
|
||||
id: "jira-connector",
|
||||
title: "Jira",
|
||||
description: "Search Jira issues",
|
||||
description: "Search, read, and manage issues",
|
||||
connectorType: EnumConnectorName.JIRA_CONNECTOR,
|
||||
authEndpoint: "/api/v1/auth/jira/connector/add/",
|
||||
authEndpoint: "/api/v1/auth/mcp/jira/connector/add/",
|
||||
},
|
||||
{
|
||||
id: "confluence-connector",
|
||||
|
|
@ -99,9 +118,9 @@ export const OAUTH_CONNECTORS = [
|
|||
{
|
||||
id: "clickup-connector",
|
||||
title: "ClickUp",
|
||||
description: "Search ClickUp tasks",
|
||||
description: "Search and read tasks",
|
||||
connectorType: EnumConnectorName.CLICKUP_CONNECTOR,
|
||||
authEndpoint: "/api/v1/auth/clickup/connector/add/",
|
||||
authEndpoint: "/api/v1/auth/mcp/clickup/connector/add/",
|
||||
},
|
||||
] as const;
|
||||
|
||||
|
|
@ -138,7 +157,7 @@ export const OTHER_CONNECTORS = [
|
|||
{
|
||||
id: "luma-connector",
|
||||
title: "Luma",
|
||||
description: "Search Luma events",
|
||||
description: "Browse, read, and create events",
|
||||
connectorType: EnumConnectorName.LUMA_CONNECTOR,
|
||||
},
|
||||
{
|
||||
|
|
@ -197,14 +216,14 @@ export const COMPOSIO_CONNECTORS = [
|
|||
{
|
||||
id: "composio-gmail",
|
||||
title: "Gmail",
|
||||
description: "Search through your emails via Composio",
|
||||
description: "Search, read, draft, and send emails via Composio",
|
||||
connectorType: EnumConnectorName.COMPOSIO_GMAIL_CONNECTOR,
|
||||
authEndpoint: "/api/v1/auth/composio/connector/add/?toolkit_id=gmail",
|
||||
},
|
||||
{
|
||||
id: "composio-googlecalendar",
|
||||
title: "Google Calendar",
|
||||
description: "Search through your events via Composio",
|
||||
description: "Search and manage your events via Composio",
|
||||
connectorType: EnumConnectorName.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
|
||||
authEndpoint: "/api/v1/auth/composio/connector/add/?toolkit_id=googlecalendar",
|
||||
},
|
||||
|
|
@ -221,14 +240,14 @@ export const COMPOSIO_TOOLKITS = [
|
|||
{
|
||||
id: "gmail",
|
||||
name: "Gmail",
|
||||
description: "Search through your emails",
|
||||
isIndexable: true,
|
||||
description: "Search, read, draft, and send emails",
|
||||
isIndexable: false,
|
||||
},
|
||||
{
|
||||
id: "googlecalendar",
|
||||
name: "Google Calendar",
|
||||
description: "Search through your events",
|
||||
isIndexable: true,
|
||||
description: "Search and manage your events",
|
||||
isIndexable: false,
|
||||
},
|
||||
{
|
||||
id: "slack",
|
||||
|
|
@ -258,66 +277,6 @@ export interface AutoIndexConfig {
|
|||
}
|
||||
|
||||
export const AUTO_INDEX_DEFAULTS: Record<string, AutoIndexConfig> = {
|
||||
[EnumConnectorName.GOOGLE_GMAIL_CONNECTOR]: {
|
||||
daysBack: 30,
|
||||
daysForward: 0,
|
||||
frequencyMinutes: 1440,
|
||||
syncDescription: "Syncing your last 30 days of emails.",
|
||||
},
|
||||
[EnumConnectorName.COMPOSIO_GMAIL_CONNECTOR]: {
|
||||
daysBack: 30,
|
||||
daysForward: 0,
|
||||
frequencyMinutes: 1440,
|
||||
syncDescription: "Syncing your last 30 days of emails.",
|
||||
},
|
||||
[EnumConnectorName.SLACK_CONNECTOR]: {
|
||||
daysBack: 30,
|
||||
daysForward: 0,
|
||||
frequencyMinutes: 1440,
|
||||
syncDescription: "Syncing your last 30 days of messages.",
|
||||
},
|
||||
[EnumConnectorName.DISCORD_CONNECTOR]: {
|
||||
daysBack: 30,
|
||||
daysForward: 0,
|
||||
frequencyMinutes: 1440,
|
||||
syncDescription: "Syncing your last 30 days of messages.",
|
||||
},
|
||||
[EnumConnectorName.TEAMS_CONNECTOR]: {
|
||||
daysBack: 30,
|
||||
daysForward: 0,
|
||||
frequencyMinutes: 1440,
|
||||
syncDescription: "Syncing your last 30 days of messages.",
|
||||
},
|
||||
[EnumConnectorName.GOOGLE_CALENDAR_CONNECTOR]: {
|
||||
daysBack: 90,
|
||||
daysForward: 90,
|
||||
frequencyMinutes: 1440,
|
||||
syncDescription: "Syncing 90 days of past and upcoming events.",
|
||||
},
|
||||
[EnumConnectorName.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR]: {
|
||||
daysBack: 90,
|
||||
daysForward: 90,
|
||||
frequencyMinutes: 1440,
|
||||
syncDescription: "Syncing 90 days of past and upcoming events.",
|
||||
},
|
||||
[EnumConnectorName.LINEAR_CONNECTOR]: {
|
||||
daysBack: 90,
|
||||
daysForward: 0,
|
||||
frequencyMinutes: 1440,
|
||||
syncDescription: "Syncing your last 90 days of issues.",
|
||||
},
|
||||
[EnumConnectorName.JIRA_CONNECTOR]: {
|
||||
daysBack: 90,
|
||||
daysForward: 0,
|
||||
frequencyMinutes: 1440,
|
||||
syncDescription: "Syncing your last 90 days of issues.",
|
||||
},
|
||||
[EnumConnectorName.CLICKUP_CONNECTOR]: {
|
||||
daysBack: 90,
|
||||
daysForward: 0,
|
||||
frequencyMinutes: 1440,
|
||||
syncDescription: "Syncing your last 90 days of tasks.",
|
||||
},
|
||||
[EnumConnectorName.NOTION_CONNECTOR]: {
|
||||
daysBack: 365,
|
||||
daysForward: 0,
|
||||
|
|
@ -330,12 +289,6 @@ export const AUTO_INDEX_DEFAULTS: Record<string, AutoIndexConfig> = {
|
|||
frequencyMinutes: 1440,
|
||||
syncDescription: "Syncing your documentation.",
|
||||
},
|
||||
[EnumConnectorName.AIRTABLE_CONNECTOR]: {
|
||||
daysBack: 365,
|
||||
daysForward: 0,
|
||||
frequencyMinutes: 1440,
|
||||
syncDescription: "Syncing your bases.",
|
||||
},
|
||||
};
|
||||
|
||||
export const AUTO_INDEX_CONNECTOR_TYPES = new Set<string>(Object.keys(AUTO_INDEX_DEFAULTS));
|
||||
|
|
|
|||
|
|
@ -38,6 +38,7 @@ import {
|
|||
AUTO_INDEX_CONNECTOR_TYPES,
|
||||
AUTO_INDEX_DEFAULTS,
|
||||
COMPOSIO_CONNECTORS,
|
||||
LIVE_CONNECTOR_TYPES,
|
||||
OAUTH_CONNECTORS,
|
||||
OTHER_CONNECTORS,
|
||||
} from "../constants/connector-constants";
|
||||
|
|
@ -317,7 +318,12 @@ export const useConnectorDialog = () => {
|
|||
newConnector.id
|
||||
);
|
||||
|
||||
if (
|
||||
const isLiveConnector = LIVE_CONNECTOR_TYPES.has(oauthConnector.connectorType);
|
||||
|
||||
if (isLiveConnector) {
|
||||
toast.success(`${oauthConnector.title} connected successfully!`);
|
||||
await refetchAllConnectors();
|
||||
} else if (
|
||||
newConnector.is_indexable &&
|
||||
AUTO_INDEX_CONNECTOR_TYPES.has(oauthConnector.connectorType)
|
||||
) {
|
||||
|
|
@ -326,6 +332,9 @@ export const useConnectorDialog = () => {
|
|||
oauthConnector.title,
|
||||
oauthConnector.connectorType
|
||||
);
|
||||
} else if (!newConnector.is_indexable) {
|
||||
toast.success(`${oauthConnector.title} connected successfully!`);
|
||||
await refetchAllConnectors();
|
||||
} else {
|
||||
toast.dismiss("auto-index");
|
||||
const config = validateIndexingConfigState({
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ import { getConnectorIcon } from "@/contracts/enums/connectorIcons";
|
|||
import type { SearchSourceConnector } from "@/contracts/types/connector.types";
|
||||
import { getDocumentTypeLabel } from "@/lib/documents/document-type-labels";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { COMPOSIO_CONNECTORS, OAUTH_CONNECTORS } from "../constants/connector-constants";
|
||||
import { COMPOSIO_CONNECTORS, LIVE_CONNECTOR_TYPES, OAUTH_CONNECTORS } from "../constants/connector-constants";
|
||||
import { getDocumentCountForConnector } from "../utils/connector-document-mapping";
|
||||
import { getConnectorDisplayName } from "./all-connectors-tab";
|
||||
|
||||
|
|
@ -156,6 +156,7 @@ export const ActiveConnectorsTab: FC<ActiveConnectorsTabProps> = ({
|
|||
{/* OAuth Connectors - Grouped by Type */}
|
||||
{filteredOAuthConnectorTypes.map(([connectorType, typeConnectors]) => {
|
||||
const { title } = getOAuthConnectorTypeInfo(connectorType);
|
||||
const isLive = LIVE_CONNECTOR_TYPES.has(connectorType);
|
||||
const isAnyIndexing = typeConnectors.some((c: SearchSourceConnector) =>
|
||||
indexingConnectorIds.has(c.id)
|
||||
);
|
||||
|
|
@ -202,8 +203,12 @@ export const ActiveConnectorsTab: FC<ActiveConnectorsTabProps> = ({
|
|||
</p>
|
||||
) : (
|
||||
<p className="text-[10px] text-muted-foreground mt-1 flex items-center gap-1.5">
|
||||
{!isLive && (
|
||||
<>
|
||||
<span>{formatDocumentCount(documentCount)}</span>
|
||||
<span className="text-muted-foreground/50">•</span>
|
||||
</>
|
||||
)}
|
||||
<span>
|
||||
{accountCount} {accountCount === 1 ? "Account" : "Accounts"}
|
||||
</span>
|
||||
|
|
@ -230,6 +235,7 @@ export const ActiveConnectorsTab: FC<ActiveConnectorsTabProps> = ({
|
|||
documentTypeCounts
|
||||
);
|
||||
const isMCPConnector = connector.connector_type === "MCP_CONNECTOR";
|
||||
const isLive = LIVE_CONNECTOR_TYPES.has(connector.connector_type);
|
||||
return (
|
||||
<div
|
||||
key={`connector-${connector.id}`}
|
||||
|
|
@ -261,7 +267,7 @@ export const ActiveConnectorsTab: FC<ActiveConnectorsTabProps> = ({
|
|||
<Spinner size="xs" />
|
||||
Syncing
|
||||
</p>
|
||||
) : !isMCPConnector ? (
|
||||
) : !isLive && !isMCPConnector ? (
|
||||
<p className="text-[10px] text-muted-foreground mt-1">
|
||||
{formatDocumentCount(documentCount)}
|
||||
</p>
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ import type { SearchSourceConnector } from "@/contracts/types/connector.types";
|
|||
import { authenticatedFetch } from "@/lib/auth-utils";
|
||||
import { formatRelativeDate } from "@/lib/format-date";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { LIVE_CONNECTOR_TYPES } from "../constants/connector-constants";
|
||||
import { useConnectorStatus } from "../hooks/use-connector-status";
|
||||
import { getConnectorDisplayName } from "../tabs/all-connectors-tab";
|
||||
|
||||
|
|
@ -43,12 +44,8 @@ interface ConnectorAccountsListViewProps {
|
|||
addButtonText?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a connector type is indexable
|
||||
*/
|
||||
function isIndexableConnector(connectorType: string): boolean {
|
||||
const nonIndexableTypes = ["MCP_CONNECTOR"];
|
||||
return !nonIndexableTypes.includes(connectorType);
|
||||
function isLiveConnector(connectorType: string): boolean {
|
||||
return LIVE_CONNECTOR_TYPES.has(connectorType) || connectorType === "MCP_CONNECTOR";
|
||||
}
|
||||
|
||||
export const ConnectorAccountsListView: FC<ConnectorAccountsListViewProps> = ({
|
||||
|
|
@ -149,7 +146,7 @@ export const ConnectorAccountsListView: FC<ConnectorAccountsListViewProps> = ({
|
|||
{connectorTitle}
|
||||
</h2>
|
||||
<p className="text-xs sm:text-base text-muted-foreground mt-1">
|
||||
{statusMessage || "Manage your connector settings and sync configuration"}
|
||||
{statusMessage || "Manage your connected accounts"}
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
|
|
@ -234,15 +231,13 @@ export const ConnectorAccountsListView: FC<ConnectorAccountsListViewProps> = ({
|
|||
<Spinner size="xs" />
|
||||
Syncing
|
||||
</p>
|
||||
) : (
|
||||
<p className="text-[10px] text-muted-foreground mt-1 whitespace-nowrap truncate">
|
||||
{isIndexableConnector(connector.connector_type)
|
||||
? connector.last_indexed_at
|
||||
) : !isLiveConnector(connector.connector_type) ? (
|
||||
<p className="text-[10px] mt-1 whitespace-nowrap truncate text-muted-foreground">
|
||||
{connector.last_indexed_at
|
||||
? `Last indexed: ${formatRelativeDate(connector.last_indexed_at)}`
|
||||
: "Never indexed"
|
||||
: "Active"}
|
||||
: "Never indexed"}
|
||||
</p>
|
||||
)}
|
||||
) : null}
|
||||
</div>
|
||||
{isAuthExpired ? (
|
||||
<Button
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue