mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-11 08:42:39 +02:00
fix: reactive 401 recovery for live MCP connectors and unified reauth endpoints
This commit is contained in:
parent
16f47578d7
commit
e3172dc282
4 changed files with 396 additions and 178 deletions
|
|
@ -194,6 +194,31 @@ async def _create_mcp_tool_from_definition_http(
|
||||||
|
|
||||||
input_model = _create_dynamic_input_model_from_schema(exposed_name, input_schema)
|
input_model = _create_dynamic_input_model_from_schema(exposed_name, input_schema)
|
||||||
|
|
||||||
|
async def _do_mcp_call(
|
||||||
|
call_headers: dict[str, str],
|
||||||
|
call_kwargs: dict[str, Any],
|
||||||
|
) -> str:
|
||||||
|
"""Execute a single MCP HTTP call with the given headers."""
|
||||||
|
async with (
|
||||||
|
streamablehttp_client(url, headers=call_headers) as (read, write, _),
|
||||||
|
ClientSession(read, write) as session,
|
||||||
|
):
|
||||||
|
await session.initialize()
|
||||||
|
response = await session.call_tool(
|
||||||
|
original_tool_name, arguments=call_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = []
|
||||||
|
for content in response.content:
|
||||||
|
if hasattr(content, "text"):
|
||||||
|
result.append(content.text)
|
||||||
|
elif hasattr(content, "data"):
|
||||||
|
result.append(str(content.data))
|
||||||
|
else:
|
||||||
|
result.append(str(content))
|
||||||
|
|
||||||
|
return "\n".join(result) if result else ""
|
||||||
|
|
||||||
async def mcp_http_tool_call(**kwargs) -> str:
|
async def mcp_http_tool_call(**kwargs) -> str:
|
||||||
"""Execute the MCP tool call via HTTP transport."""
|
"""Execute the MCP tool call via HTTP transport."""
|
||||||
logger.debug("MCP HTTP tool '%s' called", exposed_name)
|
logger.debug("MCP HTTP tool '%s' called", exposed_name)
|
||||||
|
|
@ -218,31 +243,46 @@ async def _create_mcp_tool_from_definition_http(
|
||||||
call_kwargs = {k: v for k, v in hitl_result.params.items() if v is not None}
|
call_kwargs = {k: v for k, v in hitl_result.params.items() if v is not None}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async with (
|
result_str = await _do_mcp_call(headers, call_kwargs)
|
||||||
streamablehttp_client(url, headers=headers) as (read, write, _),
|
logger.debug("MCP HTTP tool '%s' succeeded (len=%d)", exposed_name, len(result_str))
|
||||||
ClientSession(read, write) as session,
|
return result_str
|
||||||
):
|
|
||||||
await session.initialize()
|
except Exception as first_err:
|
||||||
response = await session.call_tool(
|
if not _is_auth_error(first_err) or connector_id is None:
|
||||||
original_tool_name, arguments=call_kwargs,
|
logger.exception("MCP HTTP tool '%s' execution failed: %s", exposed_name, first_err)
|
||||||
|
return f"Error: MCP HTTP tool '{exposed_name}' execution failed: {first_err!s}"
|
||||||
|
|
||||||
|
logger.warning(
|
||||||
|
"MCP HTTP tool '%s' got 401 — attempting token refresh for connector %s",
|
||||||
|
exposed_name, connector_id,
|
||||||
|
)
|
||||||
|
fresh_headers = await _force_refresh_and_get_headers(connector_id)
|
||||||
|
if fresh_headers is None:
|
||||||
|
await _mark_connector_auth_expired(connector_id)
|
||||||
|
return (
|
||||||
|
f"Error: MCP tool '{exposed_name}' authentication expired. "
|
||||||
|
"Please re-authenticate the connector in your settings."
|
||||||
)
|
)
|
||||||
|
|
||||||
result = []
|
try:
|
||||||
for content in response.content:
|
result_str = await _do_mcp_call(fresh_headers, call_kwargs)
|
||||||
if hasattr(content, "text"):
|
logger.info(
|
||||||
result.append(content.text)
|
"MCP HTTP tool '%s' succeeded after 401 recovery",
|
||||||
elif hasattr(content, "data"):
|
exposed_name,
|
||||||
result.append(str(content.data))
|
)
|
||||||
else:
|
|
||||||
result.append(str(content))
|
|
||||||
|
|
||||||
result_str = "\n".join(result) if result else ""
|
|
||||||
logger.debug("MCP HTTP tool '%s' succeeded (len=%d)", exposed_name, len(result_str))
|
|
||||||
return result_str
|
return result_str
|
||||||
|
except Exception as retry_err:
|
||||||
except Exception as e:
|
logger.exception(
|
||||||
logger.exception("MCP HTTP tool '%s' execution failed: %s", exposed_name, e)
|
"MCP HTTP tool '%s' still failing after token refresh: %s",
|
||||||
return f"Error: MCP HTTP tool '{exposed_name}' execution failed: {e!s}"
|
exposed_name, retry_err,
|
||||||
|
)
|
||||||
|
if _is_auth_error(retry_err):
|
||||||
|
await _mark_connector_auth_expired(connector_id)
|
||||||
|
return (
|
||||||
|
f"Error: MCP tool '{exposed_name}' authentication expired. "
|
||||||
|
"Please re-authenticate the connector in your settings."
|
||||||
|
)
|
||||||
|
return f"Error: MCP HTTP tool '{exposed_name}' execution failed: {retry_err!s}"
|
||||||
|
|
||||||
tool = StructuredTool(
|
tool = StructuredTool(
|
||||||
name=exposed_name,
|
name=exposed_name,
|
||||||
|
|
@ -365,66 +405,98 @@ async def _load_http_mcp_tools(
|
||||||
|
|
||||||
allowed_set = set(allowed_tools) if allowed_tools else None
|
allowed_set = set(allowed_tools) if allowed_tools else None
|
||||||
|
|
||||||
try:
|
async def _discover(disc_headers: dict[str, str]) -> list[dict[str, Any]]:
|
||||||
|
"""Connect, initialize, and list tools from the MCP server."""
|
||||||
async with (
|
async with (
|
||||||
streamablehttp_client(url, headers=headers) as (read, write, _),
|
streamablehttp_client(url, headers=disc_headers) as (read, write, _),
|
||||||
ClientSession(read, write) as session,
|
ClientSession(read, write) as session,
|
||||||
):
|
):
|
||||||
await session.initialize()
|
await session.initialize()
|
||||||
|
|
||||||
response = await session.list_tools()
|
response = await session.list_tools()
|
||||||
tool_definitions = []
|
return [
|
||||||
for tool in response.tools:
|
{
|
||||||
tool_definitions.append(
|
"name": tool.name,
|
||||||
{
|
"description": tool.description or "",
|
||||||
"name": tool.name,
|
"input_schema": tool.inputSchema
|
||||||
"description": tool.description or "",
|
if hasattr(tool, "inputSchema")
|
||||||
"input_schema": tool.inputSchema
|
else {},
|
||||||
if hasattr(tool, "inputSchema")
|
}
|
||||||
else {},
|
for tool in response.tools
|
||||||
}
|
]
|
||||||
)
|
|
||||||
|
|
||||||
total_discovered = len(tool_definitions)
|
try:
|
||||||
|
tool_definitions = await _discover(headers)
|
||||||
|
except Exception as first_err:
|
||||||
|
if not _is_auth_error(first_err) or connector_id is None:
|
||||||
|
logger.exception(
|
||||||
|
"Failed to connect to HTTP MCP server at '%s' (connector %d): %s",
|
||||||
|
url, connector_id, first_err,
|
||||||
|
)
|
||||||
|
return tools
|
||||||
|
|
||||||
if allowed_set:
|
logger.warning(
|
||||||
tool_definitions = [
|
"HTTP MCP discovery for connector %d got 401 — attempting token refresh",
|
||||||
td for td in tool_definitions if td["name"] in allowed_set
|
connector_id,
|
||||||
]
|
|
||||||
logger.info(
|
|
||||||
"HTTP MCP server '%s' (connector %d): %d/%d tools after allowlist filter",
|
|
||||||
url, connector_id, len(tool_definitions), total_discovered,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logger.info(
|
|
||||||
"Discovered %d tools from HTTP MCP server '%s' (connector %d) — no allowlist, loading all",
|
|
||||||
total_discovered, url, connector_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
for tool_def in tool_definitions:
|
|
||||||
try:
|
|
||||||
tool = await _create_mcp_tool_from_definition_http(
|
|
||||||
tool_def,
|
|
||||||
url,
|
|
||||||
headers,
|
|
||||||
connector_name=connector_name,
|
|
||||||
connector_id=connector_id,
|
|
||||||
trusted_tools=trusted_tools,
|
|
||||||
readonly_tools=readonly_tools,
|
|
||||||
tool_name_prefix=tool_name_prefix,
|
|
||||||
)
|
|
||||||
tools.append(tool)
|
|
||||||
except Exception as e:
|
|
||||||
logger.exception(
|
|
||||||
"Failed to create HTTP tool '%s' from connector %d: %s",
|
|
||||||
tool_def.get("name"), connector_id, e,
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.exception(
|
|
||||||
"Failed to connect to HTTP MCP server at '%s' (connector %d): %s",
|
|
||||||
url, connector_id, e,
|
|
||||||
)
|
)
|
||||||
|
fresh_headers = await _force_refresh_and_get_headers(connector_id)
|
||||||
|
if fresh_headers is None:
|
||||||
|
await _mark_connector_auth_expired(connector_id)
|
||||||
|
logger.error(
|
||||||
|
"HTTP MCP discovery for connector %d: token refresh failed, marking auth_expired",
|
||||||
|
connector_id,
|
||||||
|
)
|
||||||
|
return tools
|
||||||
|
|
||||||
|
try:
|
||||||
|
tool_definitions = await _discover(fresh_headers)
|
||||||
|
headers = fresh_headers
|
||||||
|
logger.info(
|
||||||
|
"HTTP MCP discovery for connector %d succeeded after 401 recovery",
|
||||||
|
connector_id,
|
||||||
|
)
|
||||||
|
except Exception as retry_err:
|
||||||
|
logger.exception(
|
||||||
|
"HTTP MCP discovery for connector %d still failing after refresh: %s",
|
||||||
|
connector_id, retry_err,
|
||||||
|
)
|
||||||
|
if _is_auth_error(retry_err):
|
||||||
|
await _mark_connector_auth_expired(connector_id)
|
||||||
|
return tools
|
||||||
|
|
||||||
|
total_discovered = len(tool_definitions)
|
||||||
|
|
||||||
|
if allowed_set:
|
||||||
|
tool_definitions = [
|
||||||
|
td for td in tool_definitions if td["name"] in allowed_set
|
||||||
|
]
|
||||||
|
logger.info(
|
||||||
|
"HTTP MCP server '%s' (connector %d): %d/%d tools after allowlist filter",
|
||||||
|
url, connector_id, len(tool_definitions), total_discovered,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.info(
|
||||||
|
"Discovered %d tools from HTTP MCP server '%s' (connector %d) — no allowlist, loading all",
|
||||||
|
total_discovered, url, connector_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
for tool_def in tool_definitions:
|
||||||
|
try:
|
||||||
|
tool = await _create_mcp_tool_from_definition_http(
|
||||||
|
tool_def,
|
||||||
|
url,
|
||||||
|
headers,
|
||||||
|
connector_name=connector_name,
|
||||||
|
connector_id=connector_id,
|
||||||
|
trusted_tools=trusted_tools,
|
||||||
|
readonly_tools=readonly_tools,
|
||||||
|
tool_name_prefix=tool_name_prefix,
|
||||||
|
)
|
||||||
|
tools.append(tool)
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(
|
||||||
|
"Failed to create HTTP tool '%s' from connector %d: %s",
|
||||||
|
tool_def.get("name"), connector_id, e,
|
||||||
|
)
|
||||||
|
|
||||||
return tools
|
return tools
|
||||||
|
|
||||||
|
|
@ -476,6 +548,91 @@ def _inject_oauth_headers(
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def _refresh_connector_token(
|
||||||
|
session: AsyncSession,
|
||||||
|
connector: "SearchSourceConnector",
|
||||||
|
) -> str | None:
|
||||||
|
"""Refresh the OAuth token for an MCP connector and persist the result.
|
||||||
|
|
||||||
|
This is the shared core used by both proactive (pre-expiry) and reactive
|
||||||
|
(401 recovery) refresh paths. It handles:
|
||||||
|
- Decrypting the current refresh token / client secret
|
||||||
|
- Calling the token endpoint
|
||||||
|
- Encrypting and persisting the new tokens
|
||||||
|
- Clearing ``auth_expired`` if it was set
|
||||||
|
- Invalidating the MCP tools cache
|
||||||
|
|
||||||
|
Returns the **plaintext** new access token on success, or ``None`` on
|
||||||
|
failure (no refresh token, IdP error, etc.).
|
||||||
|
"""
|
||||||
|
from datetime import UTC, datetime, timedelta
|
||||||
|
|
||||||
|
from sqlalchemy.orm.attributes import flag_modified
|
||||||
|
|
||||||
|
from app.services.mcp_oauth.discovery import refresh_access_token
|
||||||
|
|
||||||
|
cfg = connector.config or {}
|
||||||
|
mcp_oauth = cfg.get("mcp_oauth", {})
|
||||||
|
|
||||||
|
refresh_token = mcp_oauth.get("refresh_token")
|
||||||
|
if not refresh_token:
|
||||||
|
logger.warning(
|
||||||
|
"MCP connector %s: no refresh_token available",
|
||||||
|
connector.id,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
enc = _get_token_enc()
|
||||||
|
decrypted_refresh = enc.decrypt_token(refresh_token)
|
||||||
|
decrypted_secret = (
|
||||||
|
enc.decrypt_token(mcp_oauth["client_secret"])
|
||||||
|
if mcp_oauth.get("client_secret")
|
||||||
|
else ""
|
||||||
|
)
|
||||||
|
|
||||||
|
token_json = await refresh_access_token(
|
||||||
|
token_endpoint=mcp_oauth["token_endpoint"],
|
||||||
|
refresh_token=decrypted_refresh,
|
||||||
|
client_id=mcp_oauth["client_id"],
|
||||||
|
client_secret=decrypted_secret,
|
||||||
|
)
|
||||||
|
|
||||||
|
new_access = token_json.get("access_token")
|
||||||
|
if not new_access:
|
||||||
|
logger.warning(
|
||||||
|
"MCP connector %s: token refresh returned no access_token",
|
||||||
|
connector.id,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
new_expires_at = None
|
||||||
|
if token_json.get("expires_in"):
|
||||||
|
new_expires_at = datetime.now(UTC) + timedelta(
|
||||||
|
seconds=int(token_json["expires_in"])
|
||||||
|
)
|
||||||
|
|
||||||
|
updated_oauth = dict(mcp_oauth)
|
||||||
|
updated_oauth["access_token"] = enc.encrypt_token(new_access)
|
||||||
|
if token_json.get("refresh_token"):
|
||||||
|
updated_oauth["refresh_token"] = enc.encrypt_token(
|
||||||
|
token_json["refresh_token"]
|
||||||
|
)
|
||||||
|
updated_oauth["expires_at"] = (
|
||||||
|
new_expires_at.isoformat() if new_expires_at else None
|
||||||
|
)
|
||||||
|
|
||||||
|
updated_cfg = {**cfg, "mcp_oauth": updated_oauth}
|
||||||
|
updated_cfg.pop("auth_expired", None)
|
||||||
|
connector.config = updated_cfg
|
||||||
|
flag_modified(connector, "config")
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(connector)
|
||||||
|
|
||||||
|
invalidate_mcp_tools_cache(connector.search_space_id)
|
||||||
|
|
||||||
|
return new_access
|
||||||
|
|
||||||
|
|
||||||
async def _maybe_refresh_mcp_oauth_token(
|
async def _maybe_refresh_mcp_oauth_token(
|
||||||
session: AsyncSession,
|
session: AsyncSession,
|
||||||
connector: "SearchSourceConnector",
|
connector: "SearchSourceConnector",
|
||||||
|
|
@ -504,73 +661,13 @@ async def _maybe_refresh_mcp_oauth_token(
|
||||||
except (ValueError, TypeError):
|
except (ValueError, TypeError):
|
||||||
return server_config
|
return server_config
|
||||||
|
|
||||||
refresh_token = mcp_oauth.get("refresh_token")
|
|
||||||
if not refresh_token:
|
|
||||||
logger.warning(
|
|
||||||
"MCP connector %s token expired but no refresh_token available",
|
|
||||||
connector.id,
|
|
||||||
)
|
|
||||||
return server_config
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from app.services.mcp_oauth.discovery import refresh_access_token
|
new_access = await _refresh_connector_token(session, connector)
|
||||||
|
|
||||||
enc = _get_token_enc()
|
|
||||||
decrypted_refresh = enc.decrypt_token(refresh_token)
|
|
||||||
decrypted_secret = (
|
|
||||||
enc.decrypt_token(mcp_oauth["client_secret"])
|
|
||||||
if mcp_oauth.get("client_secret")
|
|
||||||
else ""
|
|
||||||
)
|
|
||||||
|
|
||||||
token_json = await refresh_access_token(
|
|
||||||
token_endpoint=mcp_oauth["token_endpoint"],
|
|
||||||
refresh_token=decrypted_refresh,
|
|
||||||
client_id=mcp_oauth["client_id"],
|
|
||||||
client_secret=decrypted_secret,
|
|
||||||
)
|
|
||||||
|
|
||||||
new_access = token_json.get("access_token")
|
|
||||||
if not new_access:
|
if not new_access:
|
||||||
logger.warning(
|
|
||||||
"MCP connector %s token refresh returned no access_token",
|
|
||||||
connector.id,
|
|
||||||
)
|
|
||||||
return server_config
|
return server_config
|
||||||
|
|
||||||
new_expires_at = None
|
logger.info("Proactively refreshed MCP OAuth token for connector %s", connector.id)
|
||||||
if token_json.get("expires_in"):
|
|
||||||
new_expires_at = datetime.now(UTC) + timedelta(
|
|
||||||
seconds=int(token_json["expires_in"])
|
|
||||||
)
|
|
||||||
|
|
||||||
updated_oauth = dict(mcp_oauth)
|
|
||||||
updated_oauth["access_token"] = enc.encrypt_token(new_access)
|
|
||||||
if token_json.get("refresh_token"):
|
|
||||||
updated_oauth["refresh_token"] = enc.encrypt_token(
|
|
||||||
token_json["refresh_token"]
|
|
||||||
)
|
|
||||||
updated_oauth["expires_at"] = (
|
|
||||||
new_expires_at.isoformat() if new_expires_at else None
|
|
||||||
)
|
|
||||||
|
|
||||||
from sqlalchemy.orm.attributes import flag_modified
|
|
||||||
|
|
||||||
connector.config = {
|
|
||||||
**cfg,
|
|
||||||
"server_config": server_config,
|
|
||||||
"mcp_oauth": updated_oauth,
|
|
||||||
}
|
|
||||||
flag_modified(connector, "config")
|
|
||||||
await session.commit()
|
|
||||||
await session.refresh(connector)
|
|
||||||
|
|
||||||
logger.info("Refreshed MCP OAuth token for connector %s", connector.id)
|
|
||||||
|
|
||||||
# Invalidate cache so next call picks up the new token.
|
|
||||||
invalidate_mcp_tools_cache(connector.search_space_id)
|
|
||||||
|
|
||||||
# Return server_config with the fresh token injected for immediate use.
|
|
||||||
refreshed_config = dict(server_config)
|
refreshed_config = dict(server_config)
|
||||||
refreshed_config["headers"] = {
|
refreshed_config["headers"] = {
|
||||||
**server_config.get("headers", {}),
|
**server_config.get("headers", {}),
|
||||||
|
|
@ -587,6 +684,117 @@ async def _maybe_refresh_mcp_oauth_token(
|
||||||
return server_config
|
return server_config
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Reactive 401 handling helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _is_auth_error(exc: Exception) -> bool:
|
||||||
|
"""Check if an exception indicates an HTTP 401 authentication failure."""
|
||||||
|
try:
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
if isinstance(exc, httpx.HTTPStatusError):
|
||||||
|
return exc.response.status_code == 401
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
err_str = str(exc).lower()
|
||||||
|
return "401" in err_str or "unauthorized" in err_str
|
||||||
|
|
||||||
|
|
||||||
|
async def _force_refresh_and_get_headers(
|
||||||
|
connector_id: int,
|
||||||
|
) -> dict[str, str] | None:
|
||||||
|
"""Force-refresh OAuth token for a connector and return fresh HTTP headers.
|
||||||
|
|
||||||
|
Opens a **new** DB session so this can be called from inside tool closures
|
||||||
|
that don't have access to the original session.
|
||||||
|
|
||||||
|
Returns ``None`` when the connector is not OAuth-backed, has no
|
||||||
|
refresh token, or the refresh itself fails.
|
||||||
|
"""
|
||||||
|
from app.db import async_session_maker
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with async_session_maker() as session:
|
||||||
|
result = await session.execute(
|
||||||
|
select(SearchSourceConnector).filter(
|
||||||
|
SearchSourceConnector.id == connector_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
connector = result.scalars().first()
|
||||||
|
if not connector:
|
||||||
|
return None
|
||||||
|
|
||||||
|
cfg = connector.config or {}
|
||||||
|
if not cfg.get("mcp_oauth"):
|
||||||
|
return None
|
||||||
|
|
||||||
|
server_config = cfg.get("server_config", {})
|
||||||
|
|
||||||
|
new_access = await _refresh_connector_token(session, connector)
|
||||||
|
if not new_access:
|
||||||
|
return None
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Force-refreshed MCP OAuth token for connector %s (401 recovery)",
|
||||||
|
connector_id,
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
**server_config.get("headers", {}),
|
||||||
|
"Authorization": f"Bearer {new_access}",
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
logger.warning(
|
||||||
|
"Failed to force-refresh MCP OAuth token for connector %s",
|
||||||
|
connector_id,
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def _mark_connector_auth_expired(connector_id: int) -> None:
|
||||||
|
"""Set ``config.auth_expired = True`` so the frontend shows re-auth UI."""
|
||||||
|
from app.db import async_session_maker
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with async_session_maker() as session:
|
||||||
|
result = await session.execute(
|
||||||
|
select(SearchSourceConnector).filter(
|
||||||
|
SearchSourceConnector.id == connector_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
connector = result.scalars().first()
|
||||||
|
if not connector:
|
||||||
|
return
|
||||||
|
|
||||||
|
cfg = dict(connector.config or {})
|
||||||
|
if cfg.get("auth_expired"):
|
||||||
|
return
|
||||||
|
|
||||||
|
cfg["auth_expired"] = True
|
||||||
|
connector.config = cfg
|
||||||
|
|
||||||
|
from sqlalchemy.orm.attributes import flag_modified
|
||||||
|
|
||||||
|
flag_modified(connector, "config")
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Marked MCP connector %s as auth_expired after unrecoverable 401",
|
||||||
|
connector_id,
|
||||||
|
)
|
||||||
|
invalidate_mcp_tools_cache(connector.search_space_id)
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
logger.warning(
|
||||||
|
"Failed to mark connector %s as auth_expired",
|
||||||
|
connector_id,
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def invalidate_mcp_tools_cache(search_space_id: int | None = None) -> None:
|
def invalidate_mcp_tools_cache(search_space_id: int | None = None) -> None:
|
||||||
"""Invalidate cached MCP tools.
|
"""Invalidate cached MCP tools.
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,6 @@ import { toast } from "sonner";
|
||||||
import { activeSearchSpaceIdAtom } from "@/atoms/search-spaces/search-space-query.atoms";
|
import { activeSearchSpaceIdAtom } from "@/atoms/search-spaces/search-space-query.atoms";
|
||||||
import { Button } from "@/components/ui/button";
|
import { Button } from "@/components/ui/button";
|
||||||
import { Spinner } from "@/components/ui/spinner";
|
import { Spinner } from "@/components/ui/spinner";
|
||||||
import { EnumConnectorName } from "@/contracts/enums/connector";
|
|
||||||
import { getConnectorIcon } from "@/contracts/enums/connectorIcons";
|
import { getConnectorIcon } from "@/contracts/enums/connectorIcons";
|
||||||
import type { SearchSourceConnector } from "@/contracts/types/connector.types";
|
import type { SearchSourceConnector } from "@/contracts/types/connector.types";
|
||||||
import { authenticatedFetch } from "@/lib/auth-utils";
|
import { authenticatedFetch } from "@/lib/auth-utils";
|
||||||
|
|
@ -16,23 +15,10 @@ import { DateRangeSelector } from "../../components/date-range-selector";
|
||||||
import { PeriodicSyncConfig } from "../../components/periodic-sync-config";
|
import { PeriodicSyncConfig } from "../../components/periodic-sync-config";
|
||||||
import { SummaryConfig } from "../../components/summary-config";
|
import { SummaryConfig } from "../../components/summary-config";
|
||||||
import { VisionLLMConfig } from "../../components/vision-llm-config";
|
import { VisionLLMConfig } from "../../components/vision-llm-config";
|
||||||
import { LIVE_CONNECTOR_TYPES } from "../../constants/connector-constants";
|
import { LIVE_CONNECTOR_TYPES, getReauthEndpoint } from "../../constants/connector-constants";
|
||||||
import { getConnectorDisplayName } from "../../tabs/all-connectors-tab";
|
import { getConnectorDisplayName } from "../../tabs/all-connectors-tab";
|
||||||
import { type ConnectorConfigProps, getConnectorConfigComponent } from "../index";
|
import { type ConnectorConfigProps, getConnectorConfigComponent } from "../index";
|
||||||
|
|
||||||
const REAUTH_ENDPOINTS: Partial<Record<string, string>> = {
|
|
||||||
[EnumConnectorName.LINEAR_CONNECTOR]: "/api/v1/auth/linear/connector/reauth",
|
|
||||||
[EnumConnectorName.NOTION_CONNECTOR]: "/api/v1/auth/notion/connector/reauth",
|
|
||||||
[EnumConnectorName.GOOGLE_DRIVE_CONNECTOR]: "/api/v1/auth/google/drive/connector/reauth",
|
|
||||||
[EnumConnectorName.GOOGLE_GMAIL_CONNECTOR]: "/api/v1/auth/google/gmail/connector/reauth",
|
|
||||||
[EnumConnectorName.GOOGLE_CALENDAR_CONNECTOR]: "/api/v1/auth/google/calendar/connector/reauth",
|
|
||||||
[EnumConnectorName.COMPOSIO_GOOGLE_DRIVE_CONNECTOR]: "/api/v1/auth/composio/connector/reauth",
|
|
||||||
[EnumConnectorName.COMPOSIO_GMAIL_CONNECTOR]: "/api/v1/auth/composio/connector/reauth",
|
|
||||||
[EnumConnectorName.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR]: "/api/v1/auth/composio/connector/reauth",
|
|
||||||
[EnumConnectorName.ONEDRIVE_CONNECTOR]: "/api/v1/auth/onedrive/connector/reauth",
|
|
||||||
[EnumConnectorName.DROPBOX_CONNECTOR]: "/api/v1/auth/dropbox/connector/reauth",
|
|
||||||
};
|
|
||||||
|
|
||||||
interface ConnectorEditViewProps {
|
interface ConnectorEditViewProps {
|
||||||
connector: SearchSourceConnector;
|
connector: SearchSourceConnector;
|
||||||
startDate: Date | undefined;
|
startDate: Date | undefined;
|
||||||
|
|
@ -86,7 +72,7 @@ export const ConnectorEditView: FC<ConnectorEditViewProps> = ({
|
||||||
}) => {
|
}) => {
|
||||||
const searchSpaceIdAtom = useAtomValue(activeSearchSpaceIdAtom);
|
const searchSpaceIdAtom = useAtomValue(activeSearchSpaceIdAtom);
|
||||||
const isAuthExpired = connector.config?.auth_expired === true;
|
const isAuthExpired = connector.config?.auth_expired === true;
|
||||||
const reauthEndpoint = REAUTH_ENDPOINTS[connector.connector_type];
|
const reauthEndpoint = getReauthEndpoint(connector);
|
||||||
const [reauthing, setReauthing] = useState(false);
|
const [reauthing, setReauthing] = useState(false);
|
||||||
|
|
||||||
const handleReauth = useCallback(async () => {
|
const handleReauth = useCallback(async () => {
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
import { EnumConnectorName } from "@/contracts/enums/connector";
|
import { EnumConnectorName } from "@/contracts/enums/connector";
|
||||||
|
import type { SearchSourceConnector } from "@/contracts/types/connector.types";
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Connectors that operate in real time (no background indexing).
|
* Connectors that operate in real time (no background indexing).
|
||||||
|
|
@ -367,5 +368,43 @@ export function getConnectorTelemetryMeta(connectorType: string): ConnectorTelem
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// =============================================================================
|
||||||
|
// REAUTH ENDPOINTS
|
||||||
|
// =============================================================================
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Legacy (non-MCP) OAuth reauth endpoints, keyed by connector type.
|
||||||
|
* These are used for connectors that were NOT created via MCP OAuth.
|
||||||
|
*/
|
||||||
|
export const LEGACY_REAUTH_ENDPOINTS: Partial<Record<string, string>> = {
|
||||||
|
[EnumConnectorName.NOTION_CONNECTOR]: "/api/v1/auth/notion/connector/reauth",
|
||||||
|
[EnumConnectorName.GOOGLE_DRIVE_CONNECTOR]: "/api/v1/auth/google/drive/connector/reauth",
|
||||||
|
[EnumConnectorName.GOOGLE_GMAIL_CONNECTOR]: "/api/v1/auth/google/gmail/connector/reauth",
|
||||||
|
[EnumConnectorName.GOOGLE_CALENDAR_CONNECTOR]: "/api/v1/auth/google/calendar/connector/reauth",
|
||||||
|
[EnumConnectorName.COMPOSIO_GOOGLE_DRIVE_CONNECTOR]: "/api/v1/auth/composio/connector/reauth",
|
||||||
|
[EnumConnectorName.COMPOSIO_GMAIL_CONNECTOR]: "/api/v1/auth/composio/connector/reauth",
|
||||||
|
[EnumConnectorName.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR]: "/api/v1/auth/composio/connector/reauth",
|
||||||
|
[EnumConnectorName.ONEDRIVE_CONNECTOR]: "/api/v1/auth/onedrive/connector/reauth",
|
||||||
|
[EnumConnectorName.DROPBOX_CONNECTOR]: "/api/v1/auth/dropbox/connector/reauth",
|
||||||
|
[EnumConnectorName.CONFLUENCE_CONNECTOR]: "/api/v1/auth/confluence/connector/reauth",
|
||||||
|
[EnumConnectorName.TEAMS_CONNECTOR]: "/api/v1/auth/teams/connector/reauth",
|
||||||
|
[EnumConnectorName.DISCORD_CONNECTOR]: "/api/v1/auth/discord/connector/reauth",
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Resolve the reauth endpoint for a connector.
|
||||||
|
*
|
||||||
|
* MCP OAuth connectors (those with ``config.mcp_service``) dynamically build
|
||||||
|
* the URL from the service key. Legacy OAuth connectors fall back to the
|
||||||
|
* static ``LEGACY_REAUTH_ENDPOINTS`` map.
|
||||||
|
*/
|
||||||
|
export function getReauthEndpoint(connector: SearchSourceConnector): string | undefined {
|
||||||
|
const mcpService = connector.config?.mcp_service as string | undefined;
|
||||||
|
if (mcpService) {
|
||||||
|
return `/api/v1/auth/mcp/${mcpService}/connector/reauth`;
|
||||||
|
}
|
||||||
|
return LEGACY_REAUTH_ENDPOINTS[connector.connector_type];
|
||||||
|
}
|
||||||
|
|
||||||
// Re-export IndexingConfigState from schemas for backward compatibility
|
// Re-export IndexingConfigState from schemas for backward compatibility
|
||||||
export type { IndexingConfigState } from "./connector-popup.schemas";
|
export type { IndexingConfigState } from "./connector-popup.schemas";
|
||||||
|
|
|
||||||
|
|
@ -13,25 +13,10 @@ import type { SearchSourceConnector } from "@/contracts/types/connector.types";
|
||||||
import { authenticatedFetch } from "@/lib/auth-utils";
|
import { authenticatedFetch } from "@/lib/auth-utils";
|
||||||
import { formatRelativeDate } from "@/lib/format-date";
|
import { formatRelativeDate } from "@/lib/format-date";
|
||||||
import { cn } from "@/lib/utils";
|
import { cn } from "@/lib/utils";
|
||||||
import { LIVE_CONNECTOR_TYPES } from "../constants/connector-constants";
|
import { LIVE_CONNECTOR_TYPES, getReauthEndpoint } from "../constants/connector-constants";
|
||||||
import { useConnectorStatus } from "../hooks/use-connector-status";
|
import { useConnectorStatus } from "../hooks/use-connector-status";
|
||||||
import { getConnectorDisplayName } from "../tabs/all-connectors-tab";
|
import { getConnectorDisplayName } from "../tabs/all-connectors-tab";
|
||||||
|
|
||||||
const REAUTH_ENDPOINTS: Partial<Record<string, string>> = {
|
|
||||||
[EnumConnectorName.LINEAR_CONNECTOR]: "/api/v1/auth/linear/connector/reauth",
|
|
||||||
[EnumConnectorName.NOTION_CONNECTOR]: "/api/v1/auth/notion/connector/reauth",
|
|
||||||
[EnumConnectorName.GOOGLE_DRIVE_CONNECTOR]: "/api/v1/auth/google/drive/connector/reauth",
|
|
||||||
[EnumConnectorName.GOOGLE_GMAIL_CONNECTOR]: "/api/v1/auth/google/gmail/connector/reauth",
|
|
||||||
[EnumConnectorName.GOOGLE_CALENDAR_CONNECTOR]: "/api/v1/auth/google/calendar/connector/reauth",
|
|
||||||
[EnumConnectorName.COMPOSIO_GOOGLE_DRIVE_CONNECTOR]: "/api/v1/auth/composio/connector/reauth",
|
|
||||||
[EnumConnectorName.COMPOSIO_GMAIL_CONNECTOR]: "/api/v1/auth/composio/connector/reauth",
|
|
||||||
[EnumConnectorName.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR]: "/api/v1/auth/composio/connector/reauth",
|
|
||||||
[EnumConnectorName.ONEDRIVE_CONNECTOR]: "/api/v1/auth/onedrive/connector/reauth",
|
|
||||||
[EnumConnectorName.JIRA_CONNECTOR]: "/api/v1/auth/jira/connector/reauth",
|
|
||||||
[EnumConnectorName.DROPBOX_CONNECTOR]: "/api/v1/auth/dropbox/connector/reauth",
|
|
||||||
[EnumConnectorName.CONFLUENCE_CONNECTOR]: "/api/v1/auth/confluence/connector/reauth",
|
|
||||||
};
|
|
||||||
|
|
||||||
interface ConnectorAccountsListViewProps {
|
interface ConnectorAccountsListViewProps {
|
||||||
connectorType: string;
|
connectorType: string;
|
||||||
connectorTitle: string;
|
connectorTitle: string;
|
||||||
|
|
@ -68,16 +53,15 @@ export const ConnectorAccountsListView: FC<ConnectorAccountsListViewProps> = ({
|
||||||
const isEnabled = isConnectorEnabled(connectorType);
|
const isEnabled = isConnectorEnabled(connectorType);
|
||||||
const statusMessage = getConnectorStatusMessage(connectorType);
|
const statusMessage = getConnectorStatusMessage(connectorType);
|
||||||
|
|
||||||
const reauthEndpoint = REAUTH_ENDPOINTS[connectorType];
|
|
||||||
|
|
||||||
const handleReauth = useCallback(
|
const handleReauth = useCallback(
|
||||||
async (connectorId: number) => {
|
async (connector: SearchSourceConnector) => {
|
||||||
if (!searchSpaceId || !reauthEndpoint) return;
|
const endpoint = getReauthEndpoint(connector);
|
||||||
setReauthingId(connectorId);
|
if (!searchSpaceId || !endpoint) return;
|
||||||
|
setReauthingId(connector.id);
|
||||||
try {
|
try {
|
||||||
const backendUrl = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000";
|
const backendUrl = process.env.NEXT_PUBLIC_FASTAPI_BACKEND_URL || "http://localhost:8000";
|
||||||
const url = new URL(`${backendUrl}${reauthEndpoint}`);
|
const url = new URL(`${backendUrl}${endpoint}`);
|
||||||
url.searchParams.set("connector_id", String(connectorId));
|
url.searchParams.set("connector_id", String(connector.id));
|
||||||
url.searchParams.set("space_id", String(searchSpaceId));
|
url.searchParams.set("space_id", String(searchSpaceId));
|
||||||
url.searchParams.set("return_url", window.location.pathname);
|
url.searchParams.set("return_url", window.location.pathname);
|
||||||
const response = await authenticatedFetch(url.toString());
|
const response = await authenticatedFetch(url.toString());
|
||||||
|
|
@ -99,7 +83,7 @@ export const ConnectorAccountsListView: FC<ConnectorAccountsListViewProps> = ({
|
||||||
setReauthingId(null);
|
setReauthingId(null);
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
[searchSpaceId, reauthEndpoint]
|
[searchSpaceId]
|
||||||
);
|
);
|
||||||
|
|
||||||
// Filter connectors to only show those of this type
|
// Filter connectors to only show those of this type
|
||||||
|
|
@ -200,7 +184,8 @@ export const ConnectorAccountsListView: FC<ConnectorAccountsListViewProps> = ({
|
||||||
<div className="grid grid-cols-1 sm:grid-cols-2 gap-3">
|
<div className="grid grid-cols-1 sm:grid-cols-2 gap-3">
|
||||||
{typeConnectors.map((connector) => {
|
{typeConnectors.map((connector) => {
|
||||||
const isIndexing = indexingConnectorIds.has(connector.id);
|
const isIndexing = indexingConnectorIds.has(connector.id);
|
||||||
const isAuthExpired = !!reauthEndpoint && connector.config?.auth_expired === true;
|
const connectorReauthEndpoint = getReauthEndpoint(connector);
|
||||||
|
const isAuthExpired = !!connectorReauthEndpoint && connector.config?.auth_expired === true;
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div
|
<div
|
||||||
|
|
@ -243,7 +228,7 @@ export const ConnectorAccountsListView: FC<ConnectorAccountsListViewProps> = ({
|
||||||
<Button
|
<Button
|
||||||
size="sm"
|
size="sm"
|
||||||
className="h-8 text-[11px] px-3 rounded-lg font-medium bg-amber-600 hover:bg-amber-700 text-white border-0 shadow-xs shrink-0"
|
className="h-8 text-[11px] px-3 rounded-lg font-medium bg-amber-600 hover:bg-amber-700 text-white border-0 shadow-xs shrink-0"
|
||||||
onClick={() => handleReauth(connector.id)}
|
onClick={() => handleReauth(connector)}
|
||||||
disabled={reauthingId === connector.id}
|
disabled={reauthingId === connector.id}
|
||||||
>
|
>
|
||||||
<RefreshCw
|
<RefreshCw
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue