add account metadata discovery and connected_accounts tool

This commit is contained in:
CREDO23 2026-04-22 18:57:26 +02:00
parent a4bc621c2a
commit 9eb54bc4af
3 changed files with 261 additions and 11 deletions

View file

@ -0,0 +1,109 @@
"""Connected-accounts discovery tool.
Lets the LLM discover which accounts are connected for a given service
(e.g. "jira", "linear", "slack") and retrieve the metadata it needs to
call action tools — such as Jira's ``cloudId``.
The tool returns **only** non-sensitive fields explicitly listed in the
service's ``account_metadata_keys`` (see ``registry.py``), plus the
always-present ``display_name`` and ``connector_id``.
"""
import logging
from typing import Any
from langchain_core.tools import StructuredTool
from pydantic import BaseModel, Field
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.db import SearchSourceConnector, SearchSourceConnectorType
from app.services.mcp_oauth.registry import MCP_SERVICES
logger = logging.getLogger(__name__)

# Reverse index over the MCP registry: connector-type string -> service key.
# If two services ever declared the same connector type, the one iterated
# last would win.
_SERVICE_KEY_BY_CONNECTOR_TYPE: dict[str, str] = {
    service_cfg.connector_type: service_key
    for service_key, service_cfg in MCP_SERVICES.items()
}
class GetConnectedAccountsInput(BaseModel):
    # Argument schema for the ``get_connected_accounts`` tool.
    # The list of valid service keys is rendered once, at import time, from
    # the MCP registry so the model sees the accepted values in the schema.
    service: str = Field(
        description=(
            f"Service key to look up connected accounts for. "
            f"Valid values: {', '.join(sorted(MCP_SERVICES))}"
        ),
    )
def _extract_display_name(connector: SearchSourceConnector) -> str:
"""Best-effort human-readable label for a connector."""
cfg = connector.config or {}
if cfg.get("display_name"):
return cfg["display_name"]
if cfg.get("base_url"):
return f"{connector.name} ({cfg['base_url']})"
if cfg.get("organization_name"):
return f"{connector.name} ({cfg['organization_name']})"
return connector.name
def create_get_connected_accounts_tool(
    db_session: AsyncSession,
    search_space_id: int,
    user_id: str,
) -> StructuredTool:
    """Build the ``get_connected_accounts`` tool bound to one user and search space.

    Args:
        db_session: Async session used to query connectors each time the tool runs.
        search_space_id: Only connectors belonging to this search space are listed.
        user_id: Only connectors owned by this user are listed.

    Returns:
        A ``StructuredTool`` whose coroutine takes a ``service`` key and returns
        one dict per connected account (``connector_id``, ``display_name``,
        ``service``, plus any non-empty keys named in the service's
        ``account_metadata_keys``). Failures are reported as a single-item
        list containing an ``{"error": ...}`` dict rather than raised.
    """

    async def _run(service: str) -> list[dict[str, Any]]:
        # Exact registry keys are honoured first; otherwise forgive the
        # case/whitespace differences ("Jira", " slack ") that model-produced
        # arguments commonly have.
        service_key = service if service in MCP_SERVICES else service.strip().lower()
        svc_cfg = MCP_SERVICES.get(service_key)
        if not svc_cfg:
            valid_keys = ", ".join(sorted(MCP_SERVICES))
            return [{"error": f"Unknown service '{service}'. Valid: {valid_keys}"}]

        try:
            connector_type = SearchSourceConnectorType(svc_cfg.connector_type)
        except ValueError:
            return [{"error": f"Connector type '{svc_cfg.connector_type}' not found."}]

        result = await db_session.execute(
            select(SearchSourceConnector)
            .filter(
                SearchSourceConnector.search_space_id == search_space_id,
                SearchSourceConnector.user_id == user_id,
                SearchSourceConnector.connector_type == connector_type,
            )
            # Stable ordering keeps the account list identical between calls.
            .order_by(SearchSourceConnector.id)
        )
        connectors = result.scalars().all()
        if not connectors:
            return [
                {
                    "error": (
                        f"No {svc_cfg.name} accounts connected. "
                        "Ask the user to connect one in settings."
                    )
                }
            ]

        is_multi = len(connectors) > 1
        accounts: list[dict[str, Any]] = []
        for conn in connectors:
            cfg = conn.config or {}
            entry: dict[str, Any] = {
                "connector_id": conn.id,
                "display_name": _extract_display_name(conn),
                "service": service_key,
            }
            if is_multi:
                # NOTE(review): presumably mirrors the prefix under which each
                # account's action tools are registered — confirm with the
                # tool loader.
                entry["tool_prefix"] = f"{service_key}_{conn.id}"
            # Only whitelisted keys are exposed, and only when they carry a
            # real value: an empty placeholder is useless as an identifier and
            # would mislead the model into passing it to an action tool.
            for key in svc_cfg.account_metadata_keys:
                value = cfg.get(key)
                if value is not None and value != "":
                    entry[key] = value
            accounts.append(entry)
        return accounts

    return StructuredTool(
        name="get_connected_accounts",
        description=(
            "Discover which accounts are connected for a service (e.g. jira, linear, slack, clickup, airtable). "
            "Returns display names and service-specific metadata the action tools need "
            "(e.g. Jira's cloudId). Call this BEFORE using a service's action tools when "
            "you need an account identifier or are unsure which account to use."
        ),
        coroutine=_run,
        args_schema=GetConnectedAccountsInput,
        metadata={"hitl": False},
    )

View file

@ -1,13 +1,15 @@
"""Generic MCP OAuth 2.1 route for services with official MCP servers.
Handles the full flow: discovery → DCR → PKCE authorization → token exchange →
MCP_CONNECTOR creation. Currently supports Linear, Jira, and ClickUp.
MCP_CONNECTOR creation. Currently supports Linear, Jira, ClickUp, Slack,
and Airtable.
"""
from __future__ import annotations
import logging
from datetime import UTC, datetime, timedelta
from typing import Any
from urllib.parse import urlencode
from uuid import UUID
@ -33,6 +35,70 @@ logger = logging.getLogger(__name__)
router = APIRouter()
async def _fetch_account_metadata(
    service_key: str, access_token: str, token_json: dict[str, Any],
) -> dict[str, Any]:
    """Fetch display-friendly account metadata after a successful token exchange.

    DCR services (Linear, Jira, ClickUp) issue MCP-scoped tokens that cannot
    call their standard REST/GraphQL APIs — metadata discovery for those
    happens at runtime through MCP tools instead.

    Pre-configured services (Slack, Airtable) use standard OAuth tokens that
    *can* call their APIs, so we extract metadata here.

    Args:
        service_key: Registry key of the service that was just authorized.
        access_token: Access token from the token exchange; used as the
            bearer token for the Airtable ``whoami`` call.
        token_json: Parsed token-endpoint response; Slack team information is
            read from its ``team`` entry.

    Returns:
        A dict of metadata to merge into the connector config. Only keys with
        a non-empty value are included, so the result is empty when nothing
        useful could be determined (unknown service, DCR service, or failure).

    Failures are logged but never block connector creation.
    """
    from app.services.mcp_oauth.registry import MCP_SERVICES

    svc = MCP_SERVICES.get(service_key)
    if not svc or svc.supports_dcr:
        return {}

    import httpx

    meta: dict[str, Any] = {}

    try:
        if service_key == "slack":
            # ``team`` may be missing or an explicit null; treat both as empty.
            team_info = token_json.get("team") or {}
            team_id = team_info.get("id") or ""
            # TODO: oauth.v2.user.access only returns team.id, not
            # team.name. To populate team_name, add "team:read" scope
            # and call GET /api/team.info here.
            team_name = team_info.get("name") or ""
            if team_id:
                meta["team_id"] = team_id
            if team_name:
                meta["team_name"] = team_name
                meta["display_name"] = team_name
            elif team_id:
                meta["display_name"] = f"Slack ({team_id})"

        elif service_key == "airtable":
            async with httpx.AsyncClient(timeout=15.0) as client:
                resp = await client.get(
                    "https://api.airtable.com/v0/meta/whoami",
                    headers={"Authorization": f"Bearer {access_token}"},
                )
            if resp.status_code == 200:
                whoami = resp.json()
                user_id = whoami.get("id") or ""
                # NOTE(review): the email is likely absent unless the token
                # carries an email-read scope — confirm against Airtable docs.
                user_email = whoami.get("email") or ""
                if user_id:
                    meta["user_id"] = user_id
                if user_email:
                    meta["user_email"] = user_email
                # ``or`` (not a ``.get`` default) so a present-but-empty or
                # null email still falls back to the generic label.
                meta["display_name"] = user_email or "Airtable"
            else:
                logger.warning(
                    "Airtable whoami API response: status=%s body=%s",
                    resp.status_code, resp.text[:300],
                )

    except Exception:
        # Deliberately broad: metadata is best-effort and must never prevent
        # the connector from being created.
        logger.warning(
            "Failed to fetch account metadata for %s (non-blocking)",
            service_key,
            exc_info=True,
        )

    return meta
_state_manager: OAuthStateManager | None = None
_token_encryption: TokenEncryption | None = None
@ -295,6 +361,14 @@ async def mcp_oauth_callback(
"_token_encrypted": True,
}
account_meta = await _fetch_account_metadata(svc_key, access_token, token_json)
if account_meta:
connector_config.update(account_meta)
logger.info(
"Stored account metadata for %s: display_name=%s",
svc_key, account_meta.get("display_name", ""),
)
# ---- Re-auth path ----
db_connector_type = SearchSourceConnectorType(svc.connector_type)
reauth_connector_id = data.get("connector_id")
@ -335,12 +409,13 @@ async def mcp_oauth_callback(
)
# ---- New connector path ----
naming_identifier = account_meta.get("display_name")
connector_name = await generate_unique_connector_name(
session,
db_connector_type,
space_id,
user_id,
svc.name,
naming_identifier,
)
new_connector = SearchSourceConnector(

View file

@ -4,6 +4,12 @@ Each entry maps a URL-safe service key to its MCP server endpoint and
authentication configuration. Services with ``supports_dcr=True`` use
RFC 7591 Dynamic Client Registration (the MCP server issues its own
credentials); the rest use pre-configured credentials via env vars.
``allowed_tools`` whitelists which MCP tools to expose to the agent.
An empty list means "load every tool the server advertises" (used for
user-managed generic MCP servers). Service-specific entries should
curate this list to keep the agent's tool count low and selection
accuracy high.
"""
from __future__ import annotations
@ -24,6 +30,17 @@ class MCPServiceConfig:
scope_param: str = "scope"
auth_endpoint_override: str | None = None
token_endpoint_override: str | None = None
allowed_tools: list[str] = field(default_factory=list)
readonly_tools: frozenset[str] = field(default_factory=frozenset)
account_metadata_keys: list[str] = field(default_factory=list)
"""``connector.config`` keys exposed by ``get_connected_accounts``.
Only listed keys are returned to the LLM — tokens and secrets are
never included. Every service should at least have its
``display_name`` populated during OAuth; additional service-specific
fields (e.g. Jira ``cloud_id``) are listed here so the LLM can pass
them to action tools.
"""
MCP_SERVICES: dict[str, MCPServiceConfig] = {
@ -31,16 +48,44 @@ MCP_SERVICES: dict[str, MCPServiceConfig] = {
name="Linear",
mcp_url="https://mcp.linear.app/mcp",
connector_type="LINEAR_CONNECTOR",
allowed_tools=[
"list_issues",
"get_issue",
"save_issue",
],
readonly_tools=frozenset({"list_issues", "get_issue"}),
account_metadata_keys=["organization_name", "organization_url_key"],
),
"jira": MCPServiceConfig(
name="Jira",
mcp_url="https://mcp.atlassian.com/v1/mcp",
connector_type="JIRA_CONNECTOR",
allowed_tools=[
"getAccessibleAtlassianResources",
"searchJiraIssuesUsingJql",
"getVisibleJiraProjects",
"getJiraProjectIssueTypesMetadata",
"createJiraIssue",
"editJiraIssue",
],
readonly_tools=frozenset({
"getAccessibleAtlassianResources",
"searchJiraIssuesUsingJql",
"getVisibleJiraProjects",
"getJiraProjectIssueTypesMetadata",
}),
account_metadata_keys=["cloud_id", "site_name", "base_url"],
),
"clickup": MCPServiceConfig(
name="ClickUp",
mcp_url="https://mcp.clickup.com/mcp",
connector_type="CLICKUP_CONNECTOR",
allowed_tools=[
"clickup_search",
"clickup_get_task",
],
readonly_tools=frozenset({"clickup_search", "clickup_get_task"}),
account_metadata_keys=["workspace_id", "workspace_name"],
),
"slack": MCPServiceConfig(
name="Slack",
@ -49,17 +94,22 @@ MCP_SERVICES: dict[str, MCPServiceConfig] = {
supports_dcr=False,
client_id_env="SLACK_CLIENT_ID",
client_secret_env="SLACK_CLIENT_SECRET",
scope_param="user_scope",
auth_endpoint_override="https://slack.com/oauth/v2/authorize",
token_endpoint_override="https://slack.com/api/oauth.v2.access",
auth_endpoint_override="https://slack.com/oauth/v2_user/authorize",
token_endpoint_override="https://slack.com/api/oauth.v2.user.access",
scopes=[
"search:read.public", "search:read.private", "search:read.mpim",
"search:read.im", "search:read.files", "search:read.users",
"chat:write",
"search:read.public", "search:read.private", "search:read.mpim", "search:read.im",
"channels:history", "groups:history", "mpim:history", "im:history",
"canvases:read", "canvases:write",
"users:read", "users:read.email",
],
allowed_tools=[
"slack_search_channels",
"slack_read_channel",
"slack_read_thread",
],
readonly_tools=frozenset({"slack_search_channels", "slack_read_channel", "slack_read_thread"}),
# TODO: oauth.v2.user.access only returns team.id, not team.name.
# To populate team_name, either add "team:read" scope and call
# GET /api/team.info during OAuth callback, or switch to oauth.v2.access.
account_metadata_keys=["team_id", "team_name"],
),
"airtable": MCPServiceConfig(
name="Airtable",
@ -69,10 +119,26 @@ MCP_SERVICES: dict[str, MCPServiceConfig] = {
oauth_discovery_origin="https://airtable.com",
client_id_env="AIRTABLE_CLIENT_ID",
client_secret_env="AIRTABLE_CLIENT_SECRET",
scopes=["data.records:read", "data.records:write", "schema.bases:read", "schema.bases:write"],
scopes=["data.records:read", "schema.bases:read"],
allowed_tools=[
"list_bases",
"list_tables_for_base",
"list_records_for_table",
],
readonly_tools=frozenset({"list_bases", "list_tables_for_base", "list_records_for_table"}),
account_metadata_keys=["user_id", "user_email"],
),
}
# Reverse index of the registry, keyed by each service's ``connector_type``.
# Built once at import time; later entries overwrite earlier ones on a clash.
_CONNECTOR_TYPE_TO_SERVICE: dict[str, MCPServiceConfig] = {
    service_config.connector_type: service_config
    for service_config in MCP_SERVICES.values()
}
def get_service(key: str) -> MCPServiceConfig | None:
    """Look up an MCP service config by its registry key, or ``None`` if unknown."""
    return MCP_SERVICES[key] if key in MCP_SERVICES else None
def get_service_by_connector_type(connector_type: str) -> MCPServiceConfig | None:
    """Return the MCP service config whose ``connector_type`` matches.

    ``None`` is returned when no registered service uses that connector type.
    """
    try:
        return _CONNECTOR_TYPE_TO_SERVICE[connector_type]
    except KeyError:
        return None