Mirror of https://github.com/MODSetter/SurfSense.git (synced 2026-05-08 15:22:39 +02:00)

Commit 9b1b9a90c0: Merge remote-tracking branch 'upstream/dev' into feat/obsidian-plugin
175 changed files with 10592 additions and 2302 deletions

@@ -30,6 +30,7 @@ from .jira_add_connector_route import router as jira_add_connector_router
 from .linear_add_connector_route import router as linear_add_connector_router
 from .logs_routes import router as logs_router
 from .luma_add_connector_route import router as luma_add_connector_router
+from .mcp_oauth_route import router as mcp_oauth_router
 from .memory_routes import router as memory_router
 from .model_list_routes import router as model_list_router
 from .new_chat_routes import router as new_chat_router

@@ -97,6 +98,7 @@ router.include_router(logs_router)
 router.include_router(circleback_webhook_router)  # Circleback meeting webhooks
 router.include_router(surfsense_docs_router)  # Surfsense documentation for citations
 router.include_router(notifications_router)  # Notifications with Zero sync
+router.include_router(mcp_oauth_router)  # MCP OAuth 2.1 for Linear, Jira, ClickUp, Slack, Airtable
 router.include_router(composio_router)  # Composio OAuth and toolkit management
 router.include_router(public_chat_router)  # Public chat sharing and cloning
 router.include_router(incentive_tasks_router)  # Incentive tasks for earning free pages

@@ -311,7 +311,7 @@ async def airtable_callback(
         new_connector = SearchSourceConnector(
             name=connector_name,
             connector_type=SearchSourceConnectorType.AIRTABLE_CONNECTOR,
-            is_indexable=True,
+            is_indexable=False,
             config=credentials_dict,
             search_space_id=space_id,
             user_id=user_id,

@@ -301,7 +301,7 @@ async def clickup_callback(
         # Update existing connector
         existing_connector.config = connector_config
         existing_connector.name = "ClickUp Connector"
-        existing_connector.is_indexable = True
+        existing_connector.is_indexable = False
         logger.info(
             f"Updated existing ClickUp connector for user {user_id} in space {space_id}"
         )

@@ -310,7 +310,7 @@ async def clickup_callback(
         new_connector = SearchSourceConnector(
             name="ClickUp Connector",
             connector_type=SearchSourceConnectorType.CLICKUP_CONNECTOR,
-            is_indexable=True,
+            is_indexable=False,
             config=connector_config,
             search_space_id=space_id,
             user_id=user_id,

@@ -326,7 +326,7 @@ async def discord_callback(
         new_connector = SearchSourceConnector(
             name=connector_name,
             connector_type=SearchSourceConnectorType.DISCORD_CONNECTOR,
-            is_indexable=True,
+            is_indexable=False,
             config=connector_config,
             search_space_id=space_id,
             user_id=user_id,

@@ -340,7 +340,7 @@ async def calendar_callback(
             config=creds_dict,
             search_space_id=space_id,
             user_id=user_id,
-            is_indexable=True,
+            is_indexable=False,
         )
         session.add(db_connector)
         await session.commit()

@@ -371,7 +371,7 @@ async def gmail_callback(
             config=creds_dict,
             search_space_id=space_id,
             user_id=user_id,
-            is_indexable=True,
+            is_indexable=False,
         )
         session.add(db_connector)
         await session.commit()

@@ -386,7 +386,7 @@ async def jira_callback(
         new_connector = SearchSourceConnector(
             name=connector_name,
             connector_type=SearchSourceConnectorType.JIRA_CONNECTOR,
-            is_indexable=True,
+            is_indexable=False,
             config=connector_config,
             search_space_id=space_id,
             user_id=user_id,

@@ -399,7 +399,7 @@ async def linear_callback(
         new_connector = SearchSourceConnector(
             name=connector_name,
             connector_type=SearchSourceConnectorType.LINEAR_CONNECTOR,
-            is_indexable=True,
+            is_indexable=False,
             config=connector_config,
             search_space_id=space_id,
             user_id=user_id,

@@ -61,7 +61,7 @@ async def add_luma_connector(
     if existing_connector:
         # Update existing connector with new API key
         existing_connector.config = {"api_key": request.api_key}
-        existing_connector.is_indexable = True
+        existing_connector.is_indexable = False
         await session.commit()
         await session.refresh(existing_connector)

@@ -82,7 +82,7 @@ async def add_luma_connector(
         config={"api_key": request.api_key},
         search_space_id=request.space_id,
         user_id=user.id,
-        is_indexable=True,
+        is_indexable=False,
     )

     session.add(db_connector)

surfsense_backend/app/routes/mcp_oauth_route.py (new file, 601 lines)
@@ -0,0 +1,601 @@

"""Generic MCP OAuth 2.1 route for services with official MCP servers.

Handles the full flow: discovery → DCR → PKCE authorization → token exchange
→ MCP_CONNECTOR creation. Currently supports Linear, Jira, ClickUp, Slack,
and Airtable.
"""

from __future__ import annotations

import logging
from datetime import UTC, datetime, timedelta
from typing import Any
from urllib.parse import urlencode
from uuid import UUID

from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import RedirectResponse
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm.attributes import flag_modified

from app.config import config
from app.db import (
    SearchSourceConnector,
    SearchSourceConnectorType,
    User,
    get_async_session,
)
from app.users import current_active_user
from app.utils.connector_naming import generate_unique_connector_name
from app.utils.oauth_security import OAuthStateManager, TokenEncryption, generate_pkce_pair

logger = logging.getLogger(__name__)

router = APIRouter()


async def _fetch_account_metadata(
    service_key: str, access_token: str, token_json: dict[str, Any],
) -> dict[str, Any]:
    """Fetch display-friendly account metadata after a successful token exchange.

    DCR services (Linear, Jira, ClickUp) issue MCP-scoped tokens that cannot
    call their standard REST/GraphQL APIs — metadata discovery for those
    happens at runtime through MCP tools instead.

    Pre-configured services (Slack, Airtable) use standard OAuth tokens that
    *can* call their APIs, so we extract metadata here.

    Failures are logged but never block connector creation.
    """
    from app.services.mcp_oauth.registry import MCP_SERVICES

    svc = MCP_SERVICES.get(service_key)
    if not svc or svc.supports_dcr:
        return {}

    import httpx

    meta: dict[str, Any] = {}

    try:
        if service_key == "slack":
            team_info = token_json.get("team", {})
            meta["team_id"] = team_info.get("id", "")
            # TODO: oauth.v2.user.access only returns team.id, not
            # team.name. To populate team_name, add "team:read" scope
            # and call GET /api/team.info here.
            meta["team_name"] = team_info.get("name", "")
            if meta["team_name"]:
                meta["display_name"] = meta["team_name"]
            elif meta["team_id"]:
                meta["display_name"] = f"Slack ({meta['team_id']})"

        elif service_key == "airtable":
            async with httpx.AsyncClient(timeout=15.0) as client:
                resp = await client.get(
                    "https://api.airtable.com/v0/meta/whoami",
                    headers={"Authorization": f"Bearer {access_token}"},
                )
                if resp.status_code == 200:
                    whoami = resp.json()
                    meta["user_id"] = whoami.get("id", "")
                    meta["user_email"] = whoami.get("email", "")
                    meta["display_name"] = whoami.get("email", "Airtable")
                else:
                    logger.warning(
                        "Airtable whoami API returned %d (non-blocking)", resp.status_code,
                    )

    except Exception:
        logger.warning(
            "Failed to fetch account metadata for %s (non-blocking)",
            service_key,
            exc_info=True,
        )

    return meta


_state_manager: OAuthStateManager | None = None
_token_encryption: TokenEncryption | None = None


def _get_state_manager() -> OAuthStateManager:
    global _state_manager
    if _state_manager is None:
        if not config.SECRET_KEY:
            raise HTTPException(status_code=500, detail="SECRET_KEY not configured.")
        _state_manager = OAuthStateManager(config.SECRET_KEY)
    return _state_manager


def _get_token_encryption() -> TokenEncryption:
    global _token_encryption
    if _token_encryption is None:
        if not config.SECRET_KEY:
            raise HTTPException(status_code=500, detail="SECRET_KEY not configured.")
        _token_encryption = TokenEncryption(config.SECRET_KEY)
    return _token_encryption


def _build_redirect_uri(service: str) -> str:
    base = config.BACKEND_URL or "http://localhost:8000"
    return f"{base.rstrip('/')}/api/v1/auth/mcp/{service}/connector/callback"


def _frontend_redirect(
    space_id: int | None,
    *,
    success: bool = False,
    connector_id: int | None = None,
    error: str | None = None,
    service: str = "mcp",
) -> RedirectResponse:
    if success and space_id:
        qs = f"success=true&connector={service}-mcp-connector"
        if connector_id:
            qs += f"&connectorId={connector_id}"
        return RedirectResponse(
            url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?{qs}"
        )
    if error and space_id:
        return RedirectResponse(
            url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error={error}"
        )
    return RedirectResponse(url=f"{config.NEXT_FRONTEND_URL}/dashboard")


# ---------------------------------------------------------------------------
# /add — start MCP OAuth flow
# ---------------------------------------------------------------------------

@router.get("/auth/mcp/{service}/connector/add")
async def connect_mcp_service(
    service: str,
    space_id: int,
    user: User = Depends(current_active_user),
):
    from app.services.mcp_oauth.registry import get_service

    svc = get_service(service)
    if not svc:
        raise HTTPException(status_code=404, detail=f"Unknown MCP service: {service}")

    try:
        from app.services.mcp_oauth.discovery import (
            discover_oauth_metadata,
            register_client,
        )

        metadata = await discover_oauth_metadata(
            svc.mcp_url, origin_override=svc.oauth_discovery_origin,
        )
        auth_endpoint = svc.auth_endpoint_override or metadata.get("authorization_endpoint")
        token_endpoint = svc.token_endpoint_override or metadata.get("token_endpoint")
        registration_endpoint = metadata.get("registration_endpoint")

        if not auth_endpoint or not token_endpoint:
            raise HTTPException(
                status_code=502,
                detail=f"{svc.name} MCP server returned incomplete OAuth metadata.",
            )

        redirect_uri = _build_redirect_uri(service)

        if svc.supports_dcr and registration_endpoint:
            dcr = await register_client(registration_endpoint, redirect_uri)
            client_id = dcr.get("client_id")
            client_secret = dcr.get("client_secret", "")
            if not client_id:
                raise HTTPException(
                    status_code=502,
                    detail=f"DCR for {svc.name} did not return a client_id.",
                )
        elif svc.client_id_env:
            client_id = getattr(config, svc.client_id_env, None)
            client_secret = getattr(config, svc.client_secret_env or "", None) or ""
            if not client_id:
                raise HTTPException(
                    status_code=500,
                    detail=f"{svc.name} MCP OAuth not configured ({svc.client_id_env}).",
                )
        else:
            raise HTTPException(
                status_code=502,
                detail=f"{svc.name} MCP server has no DCR and no fallback credentials.",
            )

        verifier, challenge = generate_pkce_pair()
        enc = _get_token_encryption()

        state = _get_state_manager().generate_secure_state(
            space_id,
            user.id,
            service=service,
            code_verifier=verifier,
            mcp_client_id=client_id,
            mcp_client_secret=enc.encrypt_token(client_secret) if client_secret else "",
            mcp_token_endpoint=token_endpoint,
            mcp_url=svc.mcp_url,
        )

        auth_params: dict[str, str] = {
            "client_id": client_id,
            "response_type": "code",
            "redirect_uri": redirect_uri,
            "code_challenge": challenge,
            "code_challenge_method": "S256",
            "state": state,
        }
        if svc.scopes:
            auth_params[svc.scope_param] = " ".join(svc.scopes)

        auth_url = f"{auth_endpoint}?{urlencode(auth_params)}"

        logger.info(
            "Generated %s MCP OAuth URL for user %s, space %s",
            svc.name, user.id, space_id,
        )
        return {"auth_url": auth_url}

    except HTTPException:
        raise
    except Exception as e:
        logger.error("Failed to initiate %s MCP OAuth: %s", service, e, exc_info=True)
        raise HTTPException(
            status_code=500, detail=f"Failed to initiate {service} MCP OAuth.",
        ) from e


# ---------------------------------------------------------------------------
# /callback — handle OAuth redirect
# ---------------------------------------------------------------------------

@router.get("/auth/mcp/{service}/connector/callback")
async def mcp_oauth_callback(
    service: str,
    code: str | None = None,
    error: str | None = None,
    state: str | None = None,
    session: AsyncSession = Depends(get_async_session),
):
    if error:
        logger.warning("%s MCP OAuth error: %s", service, error)
        space_id = None
        if state:
            try:
                data = _get_state_manager().validate_state(state)
                space_id = data.get("space_id")
            except Exception:
                pass
        return _frontend_redirect(
            space_id, error=f"{service}_mcp_oauth_denied", service=service,
        )

    if not code:
        raise HTTPException(status_code=400, detail="Missing authorization code")
    if not state:
        raise HTTPException(status_code=400, detail="Missing state parameter")

    data = _get_state_manager().validate_state(state)
    user_id = UUID(data["user_id"])
    space_id = data["space_id"]
    svc_key = data.get("service", service)

    if svc_key != service:
        raise HTTPException(status_code=400, detail="State/path service mismatch")

    from app.services.mcp_oauth.registry import get_service

    svc = get_service(svc_key)
    if not svc:
        raise HTTPException(status_code=404, detail=f"Unknown MCP service: {svc_key}")

    try:
        from app.services.mcp_oauth.discovery import exchange_code_for_tokens

        enc = _get_token_encryption()
        client_id = data["mcp_client_id"]
        client_secret = (
            enc.decrypt_token(data["mcp_client_secret"])
            if data.get("mcp_client_secret")
            else ""
        )
        token_endpoint = data["mcp_token_endpoint"]
        code_verifier = data["code_verifier"]
        mcp_url = data["mcp_url"]
        redirect_uri = _build_redirect_uri(service)

        token_json = await exchange_code_for_tokens(
            token_endpoint=token_endpoint,
            code=code,
            redirect_uri=redirect_uri,
            client_id=client_id,
            client_secret=client_secret,
            code_verifier=code_verifier,
        )

        access_token = token_json.get("access_token")
        refresh_token = token_json.get("refresh_token")
        expires_in = token_json.get("expires_in")
        scope = token_json.get("scope")

        if not access_token and "authed_user" in token_json:
            authed = token_json["authed_user"]
            access_token = authed.get("access_token")
            refresh_token = refresh_token or authed.get("refresh_token")
            scope = scope or authed.get("scope")
            expires_in = expires_in or authed.get("expires_in")

        if not access_token:
            raise HTTPException(
                status_code=400,
                detail=f"No access token received from {svc.name}.",
            )

        expires_at = None
        if expires_in:
            expires_at = datetime.now(UTC) + timedelta(
                seconds=int(expires_in)
            )

        connector_config = {
            "server_config": {
                "transport": "streamable-http",
                "url": mcp_url,
            },
            "mcp_service": svc_key,
            "mcp_oauth": {
                "client_id": client_id,
                "client_secret": enc.encrypt_token(client_secret) if client_secret else "",
                "token_endpoint": token_endpoint,
                "access_token": enc.encrypt_token(access_token),
                "refresh_token": enc.encrypt_token(refresh_token) if refresh_token else None,
                "expires_at": expires_at.isoformat() if expires_at else None,
                "scope": scope,
            },
            "_token_encrypted": True,
        }

        account_meta = await _fetch_account_metadata(svc_key, access_token, token_json)
        if account_meta:
            _SAFE_META_KEYS = {"display_name", "team_id", "team_name", "user_id", "user_email",
                               "workspace_id", "workspace_name", "organization_name",
                               "organization_url_key", "cloud_id", "site_name", "base_url"}
            for k, v in account_meta.items():
                if k in _SAFE_META_KEYS:
                    connector_config[k] = v
            logger.info(
                "Stored account metadata for %s: display_name=%s",
                svc_key, account_meta.get("display_name", ""),
            )

        # ---- Re-auth path ----
        db_connector_type = SearchSourceConnectorType(svc.connector_type)
        reauth_connector_id = data.get("connector_id")
        if reauth_connector_id:
            result = await session.execute(
                select(SearchSourceConnector).filter(
                    SearchSourceConnector.id == reauth_connector_id,
                    SearchSourceConnector.user_id == user_id,
                    SearchSourceConnector.search_space_id == space_id,
                    SearchSourceConnector.connector_type == db_connector_type,
                )
            )
            db_connector = result.scalars().first()
            if not db_connector:
                raise HTTPException(
                    status_code=404,
                    detail="Connector not found during re-auth",
                )

            db_connector.config = connector_config
            flag_modified(db_connector, "config")
            await session.commit()
            await session.refresh(db_connector)

            _invalidate_cache(space_id)

            logger.info(
                "Re-authenticated %s MCP connector %s for user %s",
                svc.name, db_connector.id, user_id,
            )
            reauth_return_url = data.get("return_url")
            if reauth_return_url and reauth_return_url.startswith("/") and not reauth_return_url.startswith("//"):
                return RedirectResponse(
                    url=f"{config.NEXT_FRONTEND_URL}{reauth_return_url}"
                )
            return _frontend_redirect(
                space_id, success=True, connector_id=db_connector.id, service=service,
            )

        # ---- New connector path ----
        naming_identifier = account_meta.get("display_name")
        connector_name = await generate_unique_connector_name(
            session,
            db_connector_type,
            space_id,
            user_id,
            naming_identifier,
        )

        new_connector = SearchSourceConnector(
            name=connector_name,
            connector_type=db_connector_type,
            is_indexable=False,
            config=connector_config,
            search_space_id=space_id,
            user_id=user_id,
        )
        session.add(new_connector)

        try:
            await session.commit()
        except IntegrityError as e:
            await session.rollback()
            raise HTTPException(
                status_code=409, detail="A connector for this service already exists.",
            ) from e

        _invalidate_cache(space_id)

        logger.info(
            "Created %s MCP connector %s for user %s in space %s",
            svc.name, new_connector.id, user_id, space_id,
        )
        return _frontend_redirect(
            space_id, success=True, connector_id=new_connector.id, service=service,
        )

    except HTTPException:
        raise
    except Exception as e:
        logger.error(
            "Failed to complete %s MCP OAuth: %s", service, e, exc_info=True,
        )
        raise HTTPException(
            status_code=500,
            detail=f"Failed to complete {service} MCP OAuth.",
        ) from e


# ---------------------------------------------------------------------------
# /reauth — re-authenticate an existing MCP connector
# ---------------------------------------------------------------------------

@router.get("/auth/mcp/{service}/connector/reauth")
async def reauth_mcp_service(
    service: str,
    space_id: int,
    connector_id: int,
    return_url: str | None = None,
    user: User = Depends(current_active_user),
    session: AsyncSession = Depends(get_async_session),
):
    from app.services.mcp_oauth.registry import get_service

    svc = get_service(service)
    if not svc:
        raise HTTPException(status_code=404, detail=f"Unknown MCP service: {service}")

    db_connector_type = SearchSourceConnectorType(svc.connector_type)
    result = await session.execute(
        select(SearchSourceConnector).filter(
            SearchSourceConnector.id == connector_id,
            SearchSourceConnector.user_id == user.id,
            SearchSourceConnector.search_space_id == space_id,
            SearchSourceConnector.connector_type == db_connector_type,
        )
    )
    if not result.scalars().first():
        raise HTTPException(
            status_code=404, detail="Connector not found or access denied",
        )

    try:
        from app.services.mcp_oauth.discovery import (
            discover_oauth_metadata,
            register_client,
        )

        metadata = await discover_oauth_metadata(
            svc.mcp_url, origin_override=svc.oauth_discovery_origin,
        )
        auth_endpoint = svc.auth_endpoint_override or metadata.get("authorization_endpoint")
        token_endpoint = svc.token_endpoint_override or metadata.get("token_endpoint")
        registration_endpoint = metadata.get("registration_endpoint")

        if not auth_endpoint or not token_endpoint:
            raise HTTPException(
                status_code=502,
                detail=f"{svc.name} MCP server returned incomplete OAuth metadata.",
            )

        redirect_uri = _build_redirect_uri(service)

        if svc.supports_dcr and registration_endpoint:
            dcr = await register_client(registration_endpoint, redirect_uri)
            client_id = dcr.get("client_id")
            client_secret = dcr.get("client_secret", "")
            if not client_id:
                raise HTTPException(
                    status_code=502,
                    detail=f"DCR for {svc.name} did not return a client_id.",
                )
        elif svc.client_id_env:
            client_id = getattr(config, svc.client_id_env, None)
            client_secret = getattr(config, svc.client_secret_env or "", None) or ""
            if not client_id:
                raise HTTPException(
                    status_code=500,
                    detail=f"{svc.name} MCP OAuth not configured ({svc.client_id_env}).",
                )
        else:
            raise HTTPException(
                status_code=502,
                detail=f"{svc.name} MCP server has no DCR and no fallback credentials.",
            )

        verifier, challenge = generate_pkce_pair()
        enc = _get_token_encryption()

        extra: dict = {
            "service": service,
            "code_verifier": verifier,
            "mcp_client_id": client_id,
            "mcp_client_secret": enc.encrypt_token(client_secret) if client_secret else "",
            "mcp_token_endpoint": token_endpoint,
            "mcp_url": svc.mcp_url,
            "connector_id": connector_id,
        }
        if return_url and return_url.startswith("/"):
            extra["return_url"] = return_url

        state = _get_state_manager().generate_secure_state(
            space_id, user.id, **extra,
        )

        auth_params: dict[str, str] = {
            "client_id": client_id,
            "response_type": "code",
            "redirect_uri": redirect_uri,
            "code_challenge": challenge,
            "code_challenge_method": "S256",
            "state": state,
        }
        if svc.scopes:
            auth_params[svc.scope_param] = " ".join(svc.scopes)

        auth_url = f"{auth_endpoint}?{urlencode(auth_params)}"

        logger.info(
            "Initiating %s MCP re-auth for user %s, connector %s",
            svc.name, user.id, connector_id,
        )
        return {"auth_url": auth_url}

    except HTTPException:
        raise
    except Exception as e:
        logger.error(
            "Failed to initiate %s MCP re-auth: %s", service, e, exc_info=True,
        )
        raise HTTPException(
            status_code=500,
            detail=f"Failed to initiate {service} MCP re-auth.",
        ) from e


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def _invalidate_cache(space_id: int) -> None:
    try:
        from app.agents.new_chat.tools.mcp_tool import invalidate_mcp_tools_cache

        invalidate_mcp_tools_cache(space_id)
    except Exception:
        logger.debug("MCP cache invalidation skipped", exc_info=True)
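
To make the flow above concrete, here is a minimal client-side sketch of kicking off the /add step for Linear. It is an illustration, not part of the commit: the backend base URL is the fallback from _build_redirect_uri, and the fastapi-users session cookie name is an assumption, not confirmed by this diff.

# sketch_mcp_oauth_client.py — hypothetical client for the /add endpoint above.
import asyncio
import webbrowser

import httpx


async def start_linear_mcp_oauth(session_cookie: str, space_id: int) -> None:
    async with httpx.AsyncClient(
        base_url="http://localhost:8000",  # assumed local backend
        cookies={"fastapiusersauth": session_cookie},  # assumed cookie name
    ) as client:
        # The backend performs discovery, DCR, and PKCE, then returns the
        # provider's authorization URL as {"auth_url": ...}.
        resp = await client.get(
            "/api/v1/auth/mcp/linear/connector/add",
            params={"space_id": space_id},
        )
        resp.raise_for_status()
        # The user consents in the browser; the provider then redirects to
        # /api/v1/auth/mcp/linear/connector/callback, which stores the
        # encrypted tokens and creates the MCP connector.
        webbrowser.open(resp.json()["auth_url"])


if __name__ == "__main__":
    asyncio.run(start_linear_mcp_oauth("your-session-cookie", space_id=1))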

@@ -22,6 +22,13 @@ from sqlalchemy.ext.asyncio import AsyncSession
 from sqlalchemy.future import select
 from sqlalchemy.orm import selectinload
 
+from app.agents.new_chat.filesystem_selection import (
+    ClientPlatform,
+    LocalFilesystemMount,
+    FilesystemMode,
+    FilesystemSelection,
+)
+
 from app.config import config
 from app.db import (
     ChatComment,
     ChatVisibility,

@@ -36,6 +43,7 @@ from app.db import (
 )
 from app.schemas.new_chat import (
     AgentToolInfo,
+    LocalFilesystemMountPayload,
     NewChatMessageRead,
     NewChatRequest,
     NewChatThreadCreate,

@@ -63,6 +71,67 @@ _background_tasks: set[asyncio.Task] = set()
 router = APIRouter()
 
 
+def _resolve_filesystem_selection(
+    *,
+    mode: str,
+    client_platform: str,
+    local_mounts: list[LocalFilesystemMountPayload] | None,
+) -> FilesystemSelection:
+    """Validate and normalize filesystem mode settings from request payload."""
+    try:
+        resolved_mode = FilesystemMode(mode)
+    except ValueError as exc:
+        raise HTTPException(status_code=400, detail="Invalid filesystem_mode") from exc
+    try:
+        resolved_platform = ClientPlatform(client_platform)
+    except ValueError as exc:
+        raise HTTPException(status_code=400, detail="Invalid client_platform") from exc
+
+    if resolved_mode == FilesystemMode.DESKTOP_LOCAL_FOLDER:
+        if not config.ENABLE_DESKTOP_LOCAL_FILESYSTEM:
+            raise HTTPException(
+                status_code=400,
+                detail="Desktop local filesystem mode is disabled on this deployment.",
+            )
+        if resolved_platform != ClientPlatform.DESKTOP:
+            raise HTTPException(
+                status_code=400,
+                detail="desktop_local_folder mode is only available on desktop runtime.",
+            )
+        normalized_mounts: list[tuple[str, str]] = []
+        seen_mounts: set[str] = set()
+        for mount in local_mounts or []:
+            mount_id = mount.mount_id.strip()
+            root_path = mount.root_path.strip()
+            if not mount_id or not root_path:
+                continue
+            if mount_id in seen_mounts:
+                continue
+            seen_mounts.add(mount_id)
+            normalized_mounts.append((mount_id, root_path))
+        if not normalized_mounts:
+            raise HTTPException(
+                status_code=400,
+                detail=(
+                    "local_filesystem_mounts must include at least one mount for "
+                    "desktop_local_folder mode."
+                ),
+            )
+        return FilesystemSelection(
+            mode=resolved_mode,
+            client_platform=resolved_platform,
+            local_mounts=tuple(
+                LocalFilesystemMount(mount_id=mount_id, root_path=root_path)
+                for mount_id, root_path in normalized_mounts
+            ),
+        )
+
+    return FilesystemSelection(
+        mode=FilesystemMode.CLOUD,
+        client_platform=resolved_platform,
+    )
+
+
 def _try_delete_sandbox(thread_id: int) -> None:
     """Fire-and-forget sandbox + local file deletion so the HTTP response isn't blocked."""
     from app.agents.new_chat.sandbox import (
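
For reference, a request body that would pass the validation above might look like this sketch. The "desktop_local_folder" string comes from the error text in the function; "desktop" as the ClientPlatform value is an assumption, since the enum values are not shown in this diff.

# Hypothetical NewChatRequest fields consumed by _resolve_filesystem_selection.
payload = {
    "filesystem_mode": "desktop_local_folder",
    "client_platform": "desktop",  # assumed enum value; non-desktop is rejected
    "local_filesystem_mounts": [
        {"mount_id": "vault", "root_path": "/home/me/notes"},
        # A repeated mount_id or blank fields would be silently skipped:
        {"mount_id": "vault", "root_path": "/ignored/duplicate"},
    ],
}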

@@ -1098,6 +1167,7 @@ async def list_agent_tools(
 @router.post("/new_chat")
 async def handle_new_chat(
     request: NewChatRequest,
+    http_request: Request,
     session: AsyncSession = Depends(get_async_session),
     user: User = Depends(current_active_user),
 ):

@@ -1133,6 +1203,11 @@ async def handle_new_chat(
 
     # Check thread-level access based on visibility
     await check_thread_access(session, thread, user)
+    filesystem_selection = _resolve_filesystem_selection(
+        mode=request.filesystem_mode,
+        client_platform=request.client_platform,
+        local_mounts=request.local_filesystem_mounts,
+    )
 
     # Get search space to check LLM config preferences
     search_space_result = await session.execute(

@@ -1175,6 +1250,8 @@ async def handle_new_chat(
             thread_visibility=thread.visibility,
             current_user_display_name=user.display_name or "A team member",
             disabled_tools=request.disabled_tools,
+            filesystem_selection=filesystem_selection,
+            request_id=getattr(http_request.state, "request_id", "unknown"),
         ),
         media_type="text/event-stream",
         headers={

@@ -1202,6 +1279,7 @@ async def handle_new_chat(
 async def regenerate_response(
     thread_id: int,
     request: RegenerateRequest,
+    http_request: Request,
     session: AsyncSession = Depends(get_async_session),
     user: User = Depends(current_active_user),
 ):

@@ -1247,6 +1325,11 @@ async def regenerate_response(
 
     # Check thread-level access based on visibility
     await check_thread_access(session, thread, user)
+    filesystem_selection = _resolve_filesystem_selection(
+        mode=request.filesystem_mode,
+        client_platform=request.client_platform,
+        local_mounts=request.local_filesystem_mounts,
+    )
 
     # Get the checkpointer and state history
     checkpointer = await get_checkpointer()

@@ -1412,6 +1495,8 @@ async def regenerate_response(
             thread_visibility=thread.visibility,
             current_user_display_name=user.display_name or "A team member",
             disabled_tools=request.disabled_tools,
+            filesystem_selection=filesystem_selection,
+            request_id=getattr(http_request.state, "request_id", "unknown"),
         ):
             yield chunk
             streaming_completed = True

@@ -1477,6 +1562,7 @@ async def regenerate_response(
 async def resume_chat(
     thread_id: int,
     request: ResumeRequest,
+    http_request: Request,
     session: AsyncSession = Depends(get_async_session),
     user: User = Depends(current_active_user),
 ):

@@ -1498,6 +1584,11 @@ async def resume_chat(
     )
 
     await check_thread_access(session, thread, user)
+    filesystem_selection = _resolve_filesystem_selection(
+        mode=request.filesystem_mode,
+        client_platform=request.client_platform,
+        local_mounts=request.local_filesystem_mounts,
+    )
 
     search_space_result = await session.execute(
         select(SearchSpace).filter(SearchSpace.id == request.search_space_id)

@@ -1526,6 +1617,8 @@ async def resume_chat(
             user_id=str(user.id),
             llm_config_id=llm_config_id,
             thread_visibility=thread.visibility,
+            filesystem_selection=filesystem_selection,
+            request_id=getattr(http_request.state, "request_id", "unknown"),
         ),
         media_type="text/event-stream",
         headers={

surfsense_backend/app/routes/oauth_connector_base.py (new file, 620 lines)
@@ -0,0 +1,620 @@

"""Reusable base for OAuth 2.0 connector routes.

Subclasses override ``fetch_account_info``, ``build_connector_config``,
and ``get_connector_display_name`` to customise provider-specific behaviour.
Call ``build_router()`` to get a FastAPI ``APIRouter`` with ``/connector/add``,
``/connector/callback``, and ``/connector/reauth`` endpoints.
"""

from __future__ import annotations

import base64
import logging
from datetime import UTC, datetime, timedelta
from typing import Any
from urllib.parse import urlencode
from uuid import UUID

import httpx
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import RedirectResponse
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm.attributes import flag_modified

from app.config import config
from app.db import (
    SearchSourceConnector,
    SearchSourceConnectorType,
    User,
    get_async_session,
)
from app.users import current_active_user
from app.utils.connector_naming import (
    check_duplicate_connector,
    generate_unique_connector_name,
)
from app.utils.oauth_security import OAuthStateManager, TokenEncryption

logger = logging.getLogger(__name__)


class OAuthConnectorRoute:

    def __init__(
        self,
        *,
        provider_name: str,
        connector_type: SearchSourceConnectorType,
        authorize_url: str,
        token_url: str,
        client_id_env: str,
        client_secret_env: str,
        redirect_uri_env: str,
        scopes: list[str],
        auth_prefix: str,
        use_pkce: bool = False,
        token_auth_method: str = "body",
        is_indexable: bool = True,
        extra_auth_params: dict[str, str] | None = None,
    ) -> None:
        self.provider_name = provider_name
        self.connector_type = connector_type
        self.authorize_url = authorize_url
        self.token_url = token_url
        self.client_id_env = client_id_env
        self.client_secret_env = client_secret_env
        self.redirect_uri_env = redirect_uri_env
        self.scopes = scopes
        self.auth_prefix = auth_prefix.rstrip("/")
        self.use_pkce = use_pkce
        self.token_auth_method = token_auth_method
        self.is_indexable = is_indexable
        self.extra_auth_params = extra_auth_params or {}

        self._state_manager: OAuthStateManager | None = None
        self._token_encryption: TokenEncryption | None = None

    def _get_client_id(self) -> str:
        value = getattr(config, self.client_id_env, None)
        if not value:
            raise HTTPException(
                status_code=500,
                detail=f"{self.provider_name.title()} OAuth not configured "
                f"({self.client_id_env} missing).",
            )
        return value

    def _get_client_secret(self) -> str:
        value = getattr(config, self.client_secret_env, None)
        if not value:
            raise HTTPException(
                status_code=500,
                detail=f"{self.provider_name.title()} OAuth not configured "
                f"({self.client_secret_env} missing).",
            )
        return value

    def _get_redirect_uri(self) -> str:
        value = getattr(config, self.redirect_uri_env, None)
        if not value:
            raise HTTPException(
                status_code=500,
                detail=f"{self.redirect_uri_env} not configured.",
            )
        return value

    def _get_state_manager(self) -> OAuthStateManager:
        if self._state_manager is None:
            if not config.SECRET_KEY:
                raise HTTPException(
                    status_code=500,
                    detail="SECRET_KEY not configured for OAuth security.",
                )
            self._state_manager = OAuthStateManager(config.SECRET_KEY)
        return self._state_manager

    def _get_token_encryption(self) -> TokenEncryption:
        if self._token_encryption is None:
            if not config.SECRET_KEY:
                raise HTTPException(
                    status_code=500,
                    detail="SECRET_KEY not configured for token encryption.",
                )
            self._token_encryption = TokenEncryption(config.SECRET_KEY)
        return self._token_encryption

    def _frontend_redirect(
        self,
        space_id: int | None,
        *,
        success: bool = False,
        connector_id: int | None = None,
        error: str | None = None,
    ) -> RedirectResponse:
        if success and space_id:
            connector_slug = f"{self.provider_name}-connector"
            qs = f"success=true&connector={connector_slug}"
            if connector_id:
                qs += f"&connectorId={connector_id}"
            return RedirectResponse(
                url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?{qs}"
            )
        if error and space_id:
            return RedirectResponse(
                url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error={error}"
            )
        if error:
            return RedirectResponse(
                url=f"{config.NEXT_FRONTEND_URL}/dashboard?error={error}"
            )
        return RedirectResponse(url=f"{config.NEXT_FRONTEND_URL}/dashboard")

    async def fetch_account_info(self, access_token: str) -> dict[str, Any]:
        """Override to fetch account/workspace info after token exchange.

        Return dict is merged into connector config; key ``"name"`` is used
        for the display name and dedup.
        """
        return {}

    def build_connector_config(
        self,
        token_json: dict[str, Any],
        account_info: dict[str, Any],
        encryption: TokenEncryption,
    ) -> dict[str, Any]:
        """Override for custom config shapes. Default: standard encrypted OAuth fields."""
        access_token = token_json.get("access_token", "")
        refresh_token = token_json.get("refresh_token")

        expires_at = None
        if token_json.get("expires_in"):
            expires_at = datetime.now(UTC) + timedelta(
                seconds=int(token_json["expires_in"])
            )

        cfg: dict[str, Any] = {
            "access_token": encryption.encrypt_token(access_token),
            "refresh_token": (
                encryption.encrypt_token(refresh_token) if refresh_token else None
            ),
            "token_type": token_json.get("token_type", "Bearer"),
            "expires_in": token_json.get("expires_in"),
            "expires_at": expires_at.isoformat() if expires_at else None,
            "scope": token_json.get("scope"),
            "_token_encrypted": True,
        }
        cfg.update(account_info)
        return cfg

    def get_connector_display_name(self, account_info: dict[str, Any]) -> str:
        return str(account_info.get("name", self.provider_name.title()))

    async def on_token_refresh_failure(
        self,
        session: AsyncSession,
        connector: SearchSourceConnector,
    ) -> None:
        try:
            connector.config = {**connector.config, "auth_expired": True}
            flag_modified(connector, "config")
            await session.commit()
            await session.refresh(connector)
        except Exception:
            logger.warning(
                "Failed to persist auth_expired flag for connector %s",
                connector.id,
                exc_info=True,
            )

    async def _exchange_code(
        self, code: str, extra_state: dict[str, Any]
    ) -> dict[str, Any]:
        client_id = self._get_client_id()
        client_secret = self._get_client_secret()
        redirect_uri = self._get_redirect_uri()

        headers: dict[str, str] = {
            "Content-Type": "application/x-www-form-urlencoded",
        }
        body: dict[str, str] = {
            "grant_type": "authorization_code",
            "code": code,
            "redirect_uri": redirect_uri,
        }

        if self.token_auth_method == "basic":
            creds = base64.b64encode(f"{client_id}:{client_secret}".encode()).decode()
            headers["Authorization"] = f"Basic {creds}"
        else:
            body["client_id"] = client_id
            body["client_secret"] = client_secret

        if self.use_pkce:
            verifier = extra_state.get("code_verifier")
            if verifier:
                body["code_verifier"] = verifier

        async with httpx.AsyncClient() as client:
            resp = await client.post(
                self.token_url, data=body, headers=headers, timeout=30.0
            )

        if resp.status_code != 200:
            detail = resp.text
            try:
                detail = resp.json().get("error_description", detail)
            except Exception:
                pass
            raise HTTPException(
                status_code=400, detail=f"Token exchange failed: {detail}"
            )

        return resp.json()

    async def refresh_token(
        self, session: AsyncSession, connector: SearchSourceConnector
    ) -> SearchSourceConnector:
        encryption = self._get_token_encryption()
        is_encrypted = connector.config.get("_token_encrypted", False)

        refresh_tok = connector.config.get("refresh_token")
        if is_encrypted and refresh_tok:
            try:
                refresh_tok = encryption.decrypt_token(refresh_tok)
            except Exception as e:
                logger.error("Failed to decrypt refresh token: %s", e)
                raise HTTPException(
                    status_code=500, detail="Failed to decrypt stored refresh token"
                ) from e

        if not refresh_tok:
            await self.on_token_refresh_failure(session, connector)
            raise HTTPException(
                status_code=400,
                detail="No refresh token available. Please re-authenticate.",
            )

        client_id = self._get_client_id()
        client_secret = self._get_client_secret()

        headers: dict[str, str] = {
            "Content-Type": "application/x-www-form-urlencoded",
        }
        body: dict[str, str] = {
            "grant_type": "refresh_token",
            "refresh_token": refresh_tok,
        }

        if self.token_auth_method == "basic":
            creds = base64.b64encode(f"{client_id}:{client_secret}".encode()).decode()
            headers["Authorization"] = f"Basic {creds}"
        else:
            body["client_id"] = client_id
            body["client_secret"] = client_secret

        async with httpx.AsyncClient() as client:
            resp = await client.post(
                self.token_url, data=body, headers=headers, timeout=30.0
            )

        if resp.status_code != 200:
            error_detail = resp.text
            try:
                ej = resp.json()
                error_detail = ej.get("error_description", error_detail)
                error_code = ej.get("error", "")
            except Exception:
                error_code = ""
            combined = (error_detail + error_code).lower()
            if any(kw in combined for kw in ("invalid_grant", "expired", "revoked")):
                await self.on_token_refresh_failure(session, connector)
                raise HTTPException(
                    status_code=401,
                    detail=f"{self.provider_name.title()} authentication failed. "
                    "Please re-authenticate.",
                )
            raise HTTPException(
                status_code=400, detail=f"Token refresh failed: {error_detail}"
            )

        token_json = resp.json()
        new_access = token_json.get("access_token")
        if not new_access:
            raise HTTPException(
                status_code=400, detail="No access token received from refresh"
            )

        expires_at = None
        if token_json.get("expires_in"):
            expires_at = datetime.now(UTC) + timedelta(
                seconds=int(token_json["expires_in"])
            )

        updated_config = dict(connector.config)
        updated_config["access_token"] = encryption.encrypt_token(new_access)
        new_refresh = token_json.get("refresh_token")
        if new_refresh:
            updated_config["refresh_token"] = encryption.encrypt_token(new_refresh)
        updated_config["expires_in"] = token_json.get("expires_in")
        updated_config["expires_at"] = expires_at.isoformat() if expires_at else None
        updated_config["scope"] = token_json.get("scope", updated_config.get("scope"))
        updated_config["_token_encrypted"] = True
        updated_config.pop("auth_expired", None)

        connector.config = updated_config
        flag_modified(connector, "config")
        await session.commit()
        await session.refresh(connector)

        logger.info(
            "Refreshed %s token for connector %s",
            self.provider_name,
            connector.id,
        )
        return connector

    def build_router(self) -> APIRouter:
        router = APIRouter()
        oauth = self

        @router.get(f"{oauth.auth_prefix}/connector/add")
        async def connect(
            space_id: int,
            user: User = Depends(current_active_user),
        ):
            if not space_id:
                raise HTTPException(status_code=400, detail="space_id is required")

            client_id = oauth._get_client_id()
            state_mgr = oauth._get_state_manager()

            extra_state: dict[str, Any] = {}
            auth_params: dict[str, str] = {
                "client_id": client_id,
                "response_type": "code",
                "redirect_uri": oauth._get_redirect_uri(),
                "scope": " ".join(oauth.scopes),
            }

            if oauth.use_pkce:
                from app.utils.oauth_security import generate_pkce_pair

                verifier, challenge = generate_pkce_pair()
                extra_state["code_verifier"] = verifier
                auth_params["code_challenge"] = challenge
                auth_params["code_challenge_method"] = "S256"

            auth_params.update(oauth.extra_auth_params)

            state_encoded = state_mgr.generate_secure_state(
                space_id, user.id, **extra_state
            )
            auth_params["state"] = state_encoded
            auth_url = f"{oauth.authorize_url}?{urlencode(auth_params)}"

            logger.info(
                "Generated %s OAuth URL for user %s, space %s",
                oauth.provider_name,
                user.id,
                space_id,
            )
            return {"auth_url": auth_url}

        @router.get(f"{oauth.auth_prefix}/connector/reauth")
        async def reauth(
            space_id: int,
            connector_id: int,
            return_url: str | None = None,
            user: User = Depends(current_active_user),
            session: AsyncSession = Depends(get_async_session),
        ):
            result = await session.execute(
                select(SearchSourceConnector).filter(
                    SearchSourceConnector.id == connector_id,
                    SearchSourceConnector.user_id == user.id,
                    SearchSourceConnector.search_space_id == space_id,
                    SearchSourceConnector.connector_type == oauth.connector_type,
                )
            )
            if not result.scalars().first():
                raise HTTPException(
                    status_code=404,
                    detail=f"{oauth.provider_name.title()} connector not found "
                    "or access denied",
                )

            client_id = oauth._get_client_id()
            state_mgr = oauth._get_state_manager()

            extra: dict[str, Any] = {"connector_id": connector_id}
            if return_url and return_url.startswith("/") and not return_url.startswith("//"):
                extra["return_url"] = return_url

            auth_params: dict[str, str] = {
                "client_id": client_id,
                "response_type": "code",
                "redirect_uri": oauth._get_redirect_uri(),
                "scope": " ".join(oauth.scopes),
            }

            if oauth.use_pkce:
                from app.utils.oauth_security import generate_pkce_pair

                verifier, challenge = generate_pkce_pair()
                extra["code_verifier"] = verifier
                auth_params["code_challenge"] = challenge
                auth_params["code_challenge_method"] = "S256"

            auth_params.update(oauth.extra_auth_params)

            state_encoded = state_mgr.generate_secure_state(
                space_id, user.id, **extra
            )
            auth_params["state"] = state_encoded
            auth_url = f"{oauth.authorize_url}?{urlencode(auth_params)}"

            logger.info(
                "Initiating %s re-auth for user %s, connector %s",
                oauth.provider_name,
                user.id,
                connector_id,
            )
            return {"auth_url": auth_url}

        @router.get(f"{oauth.auth_prefix}/connector/callback")
        async def callback(
            code: str | None = None,
            error: str | None = None,
            state: str | None = None,
            session: AsyncSession = Depends(get_async_session),
        ):
            error_label = f"{oauth.provider_name}_oauth_denied"

            if error:
                logger.warning("%s OAuth error: %s", oauth.provider_name, error)
                space_id = None
                if state:
                    try:
                        data = oauth._get_state_manager().validate_state(state)
                        space_id = data.get("space_id")
                    except Exception:
                        pass
                return oauth._frontend_redirect(space_id, error=error_label)

            if not code:
                raise HTTPException(
                    status_code=400, detail="Missing authorization code"
                )
            if not state:
                raise HTTPException(
                    status_code=400, detail="Missing state parameter"
                )

            state_mgr = oauth._get_state_manager()
            try:
                data = state_mgr.validate_state(state)
            except Exception as e:
                raise HTTPException(
                    status_code=400, detail="Invalid or expired state parameter."
                ) from e

            user_id = UUID(data["user_id"])
            space_id = data["space_id"]

            token_json = await oauth._exchange_code(code, data)

            access_token = token_json.get("access_token", "")
            if not access_token:
                raise HTTPException(
                    status_code=400,
                    detail=f"No access token received from {oauth.provider_name.title()}",
                )

            account_info = await oauth.fetch_account_info(access_token)
            encryption = oauth._get_token_encryption()
            connector_config = oauth.build_connector_config(
                token_json, account_info, encryption
            )

            display_name = oauth.get_connector_display_name(account_info)

            # --- Re-auth path ---
            reauth_connector_id = data.get("connector_id")
            reauth_return_url = data.get("return_url")

            if reauth_connector_id:
                result = await session.execute(
                    select(SearchSourceConnector).filter(
                        SearchSourceConnector.id == reauth_connector_id,
                        SearchSourceConnector.user_id == user_id,
                        SearchSourceConnector.search_space_id == space_id,
                        SearchSourceConnector.connector_type == oauth.connector_type,
                    )
                )
                db_connector = result.scalars().first()
                if not db_connector:
                    raise HTTPException(
                        status_code=404,
                        detail="Connector not found or access denied during re-auth",
                    )

                db_connector.config = connector_config
                flag_modified(db_connector, "config")
                await session.commit()
                await session.refresh(db_connector)

                logger.info(
                    "Re-authenticated %s connector %s for user %s",
                    oauth.provider_name,
                    db_connector.id,
                    user_id,
                )
                if reauth_return_url and reauth_return_url.startswith("/") and not reauth_return_url.startswith("//"):
                    return RedirectResponse(
                        url=f"{config.NEXT_FRONTEND_URL}{reauth_return_url}"
                    )
                return oauth._frontend_redirect(
                    space_id, success=True, connector_id=db_connector.id
                )

            # --- New connector path ---
            is_dup = await check_duplicate_connector(
                session,
                oauth.connector_type,
                space_id,
                user_id,
                display_name,
            )
            if is_dup:
                logger.warning(
                    "Duplicate %s connector for user %s (%s)",
                    oauth.provider_name,
                    user_id,
                    display_name,
                )
                return oauth._frontend_redirect(
                    space_id,
                    error=f"duplicate_account&connector={oauth.provider_name}-connector",
                )

            connector_name = await generate_unique_connector_name(
                session,
                oauth.connector_type,
                space_id,
                user_id,
                display_name,
            )

            new_connector = SearchSourceConnector(
                name=connector_name,
                connector_type=oauth.connector_type,
                is_indexable=oauth.is_indexable,
                config=connector_config,
                search_space_id=space_id,
                user_id=user_id,
            )
            session.add(new_connector)

            try:
                await session.commit()
            except IntegrityError as e:
                await session.rollback()
                raise HTTPException(
                    status_code=409, detail="A connector for this service already exists."
                ) from e

            logger.info(
                "Created %s connector %s for user %s in space %s",
                oauth.provider_name,
                new_connector.id,
                user_id,
                space_id,
            )
            return oauth._frontend_redirect(
                space_id, success=True, connector_id=new_connector.id
            )

        return router
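
As a quick illustration of the base class above, a provider route might be wired up like this. It is a minimal sketch with made-up provider values: the env attribute names, URLs, and connector type are placeholders, not code from this commit.

from typing import Any

from app.db import SearchSourceConnectorType
from app.routes.oauth_connector_base import OAuthConnectorRoute


class ExampleOAuthRoute(OAuthConnectorRoute):
    async def fetch_account_info(self, access_token: str) -> dict[str, Any]:
        # Query the provider's identity endpoint here; the "name" key feeds
        # get_connector_display_name and the duplicate check.
        return {"name": "Example Workspace"}


example_route = ExampleOAuthRoute(
    provider_name="example",
    connector_type=SearchSourceConnectorType.NOTION_CONNECTOR,  # placeholder type
    authorize_url="https://example.com/oauth/authorize",
    token_url="https://example.com/oauth/token",
    client_id_env="EXAMPLE_CLIENT_ID",        # assumed config attribute
    client_secret_env="EXAMPLE_CLIENT_SECRET",
    redirect_uri_env="EXAMPLE_REDIRECT_URI",
    scopes=["read"],
    auth_prefix="/auth/example",
    use_pkce=True,
    token_auth_method="basic",
)
router = example_route.build_router()  # exposes /connector/add, /callback, /reauth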

@@ -693,27 +693,10 @@ async def index_connector_content(
     user: User = Depends(current_active_user),
 ):
     """
-    Index content from a connector to a search space.
-    Requires CONNECTORS_UPDATE permission (to trigger indexing).
+    Index content from a KB connector to a search space.
 
-    Currently supports:
-    - SLACK_CONNECTOR: Indexes messages from all accessible Slack channels
-    - TEAMS_CONNECTOR: Indexes messages from all accessible Microsoft Teams channels
-    - NOTION_CONNECTOR: Indexes pages from all accessible Notion pages
-    - GITHUB_CONNECTOR: Indexes code and documentation from GitHub repositories
-    - LINEAR_CONNECTOR: Indexes issues and comments from Linear
-    - JIRA_CONNECTOR: Indexes issues and comments from Jira
-    - DISCORD_CONNECTOR: Indexes messages from all accessible Discord channels
-    - LUMA_CONNECTOR: Indexes events from Luma
-    - ELASTICSEARCH_CONNECTOR: Indexes documents from Elasticsearch
-    - WEBCRAWLER_CONNECTOR: Indexes web pages from crawled websites
-
-    Args:
-        connector_id: ID of the connector to use
-        search_space_id: ID of the search space to store indexed content
-
-    Returns:
-        Dictionary with indexing status
+    Live connectors (Slack, Teams, Linear, Jira, ClickUp, Calendar, Airtable,
+    Gmail, Discord, Luma) use real-time agent tools instead.
     """
     try:
         # Get the connector first

@@ -770,9 +753,7 @@ async def index_connector_content(
 
         # For calendar connectors, default to today but allow future dates if explicitly provided
         if connector.connector_type in [
             SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR,
             SearchSourceConnectorType.COMPOSIO_GOOGLE_CALENDAR_CONNECTOR,
             SearchSourceConnectorType.LUMA_CONNECTOR,
         ]:
             # Default to today if no end_date provided (users can manually select future dates)
             indexing_to = today_str if end_date is None else end_date
@ -796,33 +777,22 @@ async def index_connector_content(
|
|||
# For non-calendar connectors, cap at today
|
||||
indexing_to = end_date if end_date else today_str
|
||||
|
||||
if connector.connector_type == SearchSourceConnectorType.SLACK_CONNECTOR:
|
||||
from app.tasks.celery_tasks.connector_tasks import (
|
||||
index_slack_messages_task,
|
||||
)
|
||||
from app.services.mcp_oauth.registry import LIVE_CONNECTOR_TYPES
|
||||
|
||||
logger.info(
|
||||
f"Triggering Slack indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}"
|
||||
)
|
||||
index_slack_messages_task.delay(
|
||||
connector_id, search_space_id, str(user.id), indexing_from, indexing_to
|
||||
)
|
||||
response_message = "Slack indexing started in the background."
|
||||
if connector.connector_type in LIVE_CONNECTOR_TYPES:
|
||||
return {
|
||||
"message": (
|
||||
f"{connector.connector_type.value} uses real-time agent tools; "
|
||||
"background indexing is disabled."
|
||||
),
|
||||
"indexing_started": False,
|
||||
"connector_id": connector_id,
|
||||
"search_space_id": search_space_id,
|
||||
"indexing_from": indexing_from,
|
||||
"indexing_to": indexing_to,
|
||||
}
|
||||
|
||||
elif connector.connector_type == SearchSourceConnectorType.TEAMS_CONNECTOR:
|
||||
from app.tasks.celery_tasks.connector_tasks import (
|
||||
index_teams_messages_task,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Triggering Teams indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}"
|
||||
)
|
||||
index_teams_messages_task.delay(
|
||||
connector_id, search_space_id, str(user.id), indexing_from, indexing_to
|
||||
)
|
||||
response_message = "Teams indexing started in the background."
|
||||
|
||||
elif connector.connector_type == SearchSourceConnectorType.NOTION_CONNECTOR:
|
||||
if connector.connector_type == SearchSourceConnectorType.NOTION_CONNECTOR:
|
||||
from app.tasks.celery_tasks.connector_tasks import index_notion_pages_task
|
||||
|
||||
logger.info(
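This hunk replaces the per-connector Slack and Teams branches with a single early return for any live connector. A minimal sketch of that guard, assuming LIVE_CONNECTOR_TYPES is a set of connector-type enum members; the enum below is illustrative, not the project's:

from enum import Enum


class ConnectorType(str, Enum):
    SLACK_CONNECTOR = "SLACK_CONNECTOR"
    NOTION_CONNECTOR = "NOTION_CONNECTOR"


# Assumed registry shape: the connector types served by real-time agent tools.
LIVE_CONNECTOR_TYPES = {ConnectorType.SLACK_CONNECTOR}


def start_indexing(connector_type: ConnectorType, connector_id: int) -> dict:
    if connector_type in LIVE_CONNECTOR_TYPES:
        # Live connectors answer queries in real time; no background job is queued.
        return {
            "message": f"{connector_type.value} uses real-time agent tools; "
            "background indexing is disabled.",
            "indexing_started": False,
            "connector_id": connector_id,
        }
    # KB connectors fall through to the background-task dispatch.
    return {"indexing_started": True, "connector_id": connector_id}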
@ -844,28 +814,6 @@ async def index_connector_content(
)
response_message = "GitHub indexing started in the background."

elif connector.connector_type == SearchSourceConnectorType.LINEAR_CONNECTOR:
from app.tasks.celery_tasks.connector_tasks import index_linear_issues_task

logger.info(
f"Triggering Linear indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}"
)
index_linear_issues_task.delay(
connector_id, search_space_id, str(user.id), indexing_from, indexing_to
)
response_message = "Linear indexing started in the background."

elif connector.connector_type == SearchSourceConnectorType.JIRA_CONNECTOR:
from app.tasks.celery_tasks.connector_tasks import index_jira_issues_task

logger.info(
f"Triggering Jira indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}"
)
index_jira_issues_task.delay(
connector_id, search_space_id, str(user.id), indexing_from, indexing_to
)
response_message = "Jira indexing started in the background."

elif connector.connector_type == SearchSourceConnectorType.CONFLUENCE_CONNECTOR:
from app.tasks.celery_tasks.connector_tasks import (
index_confluence_pages_task,
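Each branch removed here followed the same dispatch shape: lazily import the Celery task, log the indexing window, and queue it with .delay(...). A generic sketch of that shape; the app name and broker URL are placeholders:

from celery import Celery

celery_app = Celery("worker", broker="redis://localhost:6379/0")


@celery_app.task
def index_connector_task(
    connector_id: int,
    search_space_id: int,
    user_id: str,
    indexing_from: str,
    indexing_to: str,
) -> None:
    """Connector-specific indexing would run inside the worker here."""


# .delay() enqueues the job and returns immediately, keeping the HTTP handler fast.
index_connector_task.delay(1, 2, "user-uuid", "2024-01-01", "2024-01-31")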
@ -892,59 +840,6 @@ async def index_connector_content(
)
response_message = "BookStack indexing started in the background."

elif connector.connector_type == SearchSourceConnectorType.CLICKUP_CONNECTOR:
from app.tasks.celery_tasks.connector_tasks import index_clickup_tasks_task

logger.info(
f"Triggering ClickUp indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}"
)
index_clickup_tasks_task.delay(
connector_id, search_space_id, str(user.id), indexing_from, indexing_to
)
response_message = "ClickUp indexing started in the background."

elif (
connector.connector_type
== SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR
):
from app.tasks.celery_tasks.connector_tasks import (
index_google_calendar_events_task,
)

logger.info(
f"Triggering Google Calendar indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}"
)
index_google_calendar_events_task.delay(
connector_id, search_space_id, str(user.id), indexing_from, indexing_to
)
response_message = "Google Calendar indexing started in the background."
elif connector.connector_type == SearchSourceConnectorType.AIRTABLE_CONNECTOR:
from app.tasks.celery_tasks.connector_tasks import (
index_airtable_records_task,
)

logger.info(
f"Triggering Airtable indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}"
)
index_airtable_records_task.delay(
connector_id, search_space_id, str(user.id), indexing_from, indexing_to
)
response_message = "Airtable indexing started in the background."
elif (
connector.connector_type == SearchSourceConnectorType.GOOGLE_GMAIL_CONNECTOR
):
from app.tasks.celery_tasks.connector_tasks import (
index_google_gmail_messages_task,
)

logger.info(
f"Triggering Google Gmail indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}"
)
index_google_gmail_messages_task.delay(
connector_id, search_space_id, str(user.id), indexing_from, indexing_to
)
response_message = "Google Gmail indexing started in the background."

elif (
connector.connector_type == SearchSourceConnectorType.GOOGLE_DRIVE_CONNECTOR
):

@ -1089,30 +984,6 @@ async def index_connector_content(
)
response_message = "Dropbox indexing started in the background."

elif connector.connector_type == SearchSourceConnectorType.DISCORD_CONNECTOR:
from app.tasks.celery_tasks.connector_tasks import (
index_discord_messages_task,
)

logger.info(
f"Triggering Discord indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}"
)
index_discord_messages_task.delay(
connector_id, search_space_id, str(user.id), indexing_from, indexing_to
)
response_message = "Discord indexing started in the background."

elif connector.connector_type == SearchSourceConnectorType.LUMA_CONNECTOR:
from app.tasks.celery_tasks.connector_tasks import index_luma_events_task

logger.info(
f"Triggering Luma indexing for connector {connector_id} into search space {search_space_id} from {indexing_from} to {indexing_to}"
)
index_luma_events_task.delay(
connector_id, search_space_id, str(user.id), indexing_from, indexing_to
)
response_message = "Luma indexing started in the background."

elif (
connector.connector_type
== SearchSourceConnectorType.ELASTICSEARCH_CONNECTOR

@ -1319,57 +1190,6 @@ async def _update_connector_timestamp_by_id(session: AsyncSession, connector_id:
await session.rollback()


async def run_slack_indexing_with_new_session(
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str,
):
"""
Create a new session and run the Slack indexing task.
This prevents session leaks by creating a dedicated session for the background task.
"""
async with async_session_maker() as session:
await run_slack_indexing(
session, connector_id, search_space_id, user_id, start_date, end_date
)


async def run_slack_indexing(
session: AsyncSession,
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str,
):
"""
Background task to run Slack indexing.

Args:
session: Database session
connector_id: ID of the Slack connector
search_space_id: ID of the search space
user_id: ID of the user
start_date: Start date for indexing
end_date: End date for indexing
"""
from app.tasks.connector_indexers import index_slack_messages

await _run_indexing_with_notifications(
session=session,
connector_id=connector_id,
search_space_id=search_space_id,
user_id=user_id,
start_date=start_date,
end_date=end_date,
indexing_function=index_slack_messages,
update_timestamp_func=_update_connector_timestamp_by_id,
supports_heartbeat_callback=True,
)


_AUTH_ERROR_PATTERNS = (
"failed to refresh linear oauth",
"failed to refresh your notion connection",
@ -1908,215 +1728,6 @@ async def run_github_indexing(
)


# Add new helper functions for Linear indexing
async def run_linear_indexing_with_new_session(
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str,
):
"""Wrapper to run Linear indexing with its own database session."""
logger.info(
f"Background task started: Indexing Linear connector {connector_id} into space {search_space_id} from {start_date} to {end_date}"
)
async with async_session_maker() as session:
await run_linear_indexing(
session, connector_id, search_space_id, user_id, start_date, end_date
)
logger.info(f"Background task finished: Indexing Linear connector {connector_id}")


async def run_linear_indexing(
session: AsyncSession,
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str,
):
"""
Background task to run Linear indexing.

Args:
session: Database session
connector_id: ID of the Linear connector
search_space_id: ID of the search space
user_id: ID of the user
start_date: Start date for indexing
end_date: End date for indexing
"""
from app.tasks.connector_indexers import index_linear_issues

await _run_indexing_with_notifications(
session=session,
connector_id=connector_id,
search_space_id=search_space_id,
user_id=user_id,
start_date=start_date,
end_date=end_date,
indexing_function=index_linear_issues,
update_timestamp_func=_update_connector_timestamp_by_id,
supports_heartbeat_callback=True,
)


# Add new helper functions for discord indexing
async def run_discord_indexing_with_new_session(
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str,
):
"""
Create a new session and run the Discord indexing task.
This prevents session leaks by creating a dedicated session for the background task.
"""
async with async_session_maker() as session:
await run_discord_indexing(
session, connector_id, search_space_id, user_id, start_date, end_date
)


async def run_discord_indexing(
session: AsyncSession,
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str,
):
"""
Background task to run Discord indexing.

Args:
session: Database session
connector_id: ID of the Discord connector
search_space_id: ID of the search space
user_id: ID of the user
start_date: Start date for indexing
end_date: End date for indexing
"""
from app.tasks.connector_indexers import index_discord_messages

await _run_indexing_with_notifications(
session=session,
connector_id=connector_id,
search_space_id=search_space_id,
user_id=user_id,
start_date=start_date,
end_date=end_date,
indexing_function=index_discord_messages,
update_timestamp_func=_update_connector_timestamp_by_id,
supports_heartbeat_callback=True,
)


async def run_teams_indexing_with_new_session(
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str,
):
"""
Create a new session and run the Microsoft Teams indexing task.
This prevents session leaks by creating a dedicated session for the background task.
"""
async with async_session_maker() as session:
await run_teams_indexing(
session, connector_id, search_space_id, user_id, start_date, end_date
)


async def run_teams_indexing(
session: AsyncSession,
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str,
):
"""
Background task to run Microsoft Teams indexing.

Args:
session: Database session
connector_id: ID of the Teams connector
search_space_id: ID of the search space
user_id: ID of the user
start_date: Start date for indexing
end_date: End date for indexing
"""
from app.tasks.connector_indexers.teams_indexer import index_teams_messages

await _run_indexing_with_notifications(
session=session,
connector_id=connector_id,
search_space_id=search_space_id,
user_id=user_id,
start_date=start_date,
end_date=end_date,
indexing_function=index_teams_messages,
update_timestamp_func=_update_connector_timestamp_by_id,
supports_heartbeat_callback=True,
)


# Add new helper functions for Jira indexing
async def run_jira_indexing_with_new_session(
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str,
):
"""Wrapper to run Jira indexing with its own database session."""
logger.info(
f"Background task started: Indexing Jira connector {connector_id} into space {search_space_id} from {start_date} to {end_date}"
)
async with async_session_maker() as session:
await run_jira_indexing(
session, connector_id, search_space_id, user_id, start_date, end_date
)
logger.info(f"Background task finished: Indexing Jira connector {connector_id}")


async def run_jira_indexing(
session: AsyncSession,
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str,
):
"""
Background task to run Jira indexing.

Args:
session: Database session
connector_id: ID of the Jira connector
search_space_id: ID of the search space
user_id: ID of the user
start_date: Start date for indexing
end_date: End date for indexing
"""
from app.tasks.connector_indexers import index_jira_issues

await _run_indexing_with_notifications(
session=session,
connector_id=connector_id,
search_space_id=search_space_id,
user_id=user_id,
start_date=start_date,
end_date=end_date,
indexing_function=index_jira_issues,
update_timestamp_func=_update_connector_timestamp_by_id,
supports_heartbeat_callback=True,
)


# Add new helper functions for Confluence indexing
async def run_confluence_indexing_with_new_session(
connector_id: int,
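Every helper removed in this hunk was a thin adapter over a shared _run_indexing_with_notifications driver; only the indexer function changed between them. A sketch of that delegation shape; the driver below is a simplified stand-in, not the project's implementation:

from collections.abc import Awaitable, Callable

Indexer = Callable[[int, str, str], Awaitable[None]]


async def _run_indexing_with_notifications(
    connector_id: int,
    start_date: str,
    end_date: str,
    indexing_function: Indexer,
) -> None:
    # Stand-in: notify start, run the indexer, then notify success or failure.
    await indexing_function(connector_id, start_date, end_date)


async def index_jira_issues(connector_id: int, start_date: str, end_date: str) -> None:
    """Placeholder for the Jira-specific indexer."""


async def run_jira_indexing(connector_id: int, start_date: str, end_date: str) -> None:
    # The per-connector wrapper only selects the indexing function.
    await _run_indexing_with_notifications(
        connector_id, start_date, end_date, indexing_function=index_jira_issues
    )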
@ -2172,112 +1783,6 @@ async def run_confluence_indexing(
)


# Add new helper functions for ClickUp indexing
async def run_clickup_indexing_with_new_session(
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str,
):
"""Wrapper to run ClickUp indexing with its own database session."""
logger.info(
f"Background task started: Indexing ClickUp connector {connector_id} into space {search_space_id} from {start_date} to {end_date}"
)
async with async_session_maker() as session:
await run_clickup_indexing(
session, connector_id, search_space_id, user_id, start_date, end_date
)
logger.info(f"Background task finished: Indexing ClickUp connector {connector_id}")


async def run_clickup_indexing(
session: AsyncSession,
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str,
):
"""
Background task to run ClickUp indexing.

Args:
session: Database session
connector_id: ID of the ClickUp connector
search_space_id: ID of the search space
user_id: ID of the user
start_date: Start date for indexing
end_date: End date for indexing
"""
from app.tasks.connector_indexers import index_clickup_tasks

await _run_indexing_with_notifications(
session=session,
connector_id=connector_id,
search_space_id=search_space_id,
user_id=user_id,
start_date=start_date,
end_date=end_date,
indexing_function=index_clickup_tasks,
update_timestamp_func=_update_connector_timestamp_by_id,
supports_heartbeat_callback=True,
)


# Add new helper functions for Airtable indexing
async def run_airtable_indexing_with_new_session(
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str,
):
"""Wrapper to run Airtable indexing with its own database session."""
logger.info(
f"Background task started: Indexing Airtable connector {connector_id} into space {search_space_id} from {start_date} to {end_date}"
)
async with async_session_maker() as session:
await run_airtable_indexing(
session, connector_id, search_space_id, user_id, start_date, end_date
)
logger.info(f"Background task finished: Indexing Airtable connector {connector_id}")


async def run_airtable_indexing(
session: AsyncSession,
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str,
):
"""
Background task to run Airtable indexing.

Args:
session: Database session
connector_id: ID of the Airtable connector
search_space_id: ID of the search space
user_id: ID of the user
start_date: Start date for indexing
end_date: End date for indexing
"""
from app.tasks.connector_indexers import index_airtable_records

await _run_indexing_with_notifications(
session=session,
connector_id=connector_id,
search_space_id=search_space_id,
user_id=user_id,
start_date=start_date,
end_date=end_date,
indexing_function=index_airtable_records,
update_timestamp_func=_update_connector_timestamp_by_id,
supports_heartbeat_callback=True,
)


# Add new helper functions for Google Calendar indexing
async def run_google_calendar_indexing_with_new_session(
connector_id: int,

@ -2816,58 +2321,6 @@ async def run_dropbox_indexing(
logger.error(f"Failed to update notification: {notif_error!s}")


# Add new helper functions for luma indexing
async def run_luma_indexing_with_new_session(
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str,
):
"""
Create a new session and run the Luma indexing task.
This prevents session leaks by creating a dedicated session for the background task.
"""
async with async_session_maker() as session:
await run_luma_indexing(
session, connector_id, search_space_id, user_id, start_date, end_date
)


async def run_luma_indexing(
session: AsyncSession,
connector_id: int,
search_space_id: int,
user_id: str,
start_date: str,
end_date: str,
):
"""
Background task to run Luma indexing.

Args:
session: Database session
connector_id: ID of the Luma connector
search_space_id: ID of the search space
user_id: ID of the user
start_date: Start date for indexing
end_date: End date for indexing
"""
from app.tasks.connector_indexers import index_luma_events

await _run_indexing_with_notifications(
session=session,
connector_id=connector_id,
search_space_id=search_space_id,
user_id=user_id,
start_date=start_date,
end_date=end_date,
indexing_function=index_luma_events,
update_timestamp_func=_update_connector_timestamp_by_id,
supports_heartbeat_callback=True,
)


async def run_elasticsearch_indexing_with_new_session(
connector_id: int,
search_space_id: int,

@ -3580,13 +3033,18 @@ async def trust_mcp_tool(
"""Add a tool to the MCP connector's trusted (always-allow) list.

Once trusted, the tool executes without HITL approval on subsequent calls.
Works for both generic MCP_CONNECTOR and OAuth-backed MCP connectors
(LINEAR_CONNECTOR, JIRA_CONNECTOR, etc.) by checking for ``server_config``.
"""
try:
from sqlalchemy import cast
from sqlalchemy.dialects.postgresql import JSONB as PG_JSONB

result = await session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == connector_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.MCP_CONNECTOR,
SearchSourceConnector.user_id == user.id,
cast(SearchSourceConnector.config, PG_JSONB).has_key("server_config"),  # noqa: W601
)
)
connector = result.scalars().first()

@ -3631,13 +3089,17 @@ async def untrust_mcp_tool(
"""Remove a tool from the MCP connector's trusted list.

The tool will require HITL approval again on subsequent calls.
Works for both generic MCP_CONNECTOR and OAuth-backed MCP connectors.
"""
try:
from sqlalchemy import cast
from sqlalchemy.dialects.postgresql import JSONB as PG_JSONB

result = await session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == connector_id,
SearchSourceConnector.connector_type
== SearchSourceConnectorType.MCP_CONNECTOR,
SearchSourceConnector.user_id == user.id,
cast(SearchSourceConnector.config, PG_JSONB).has_key("server_config"),  # noqa: W601
)
)
connector = result.scalars().first()
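Both trust/untrust hunks add the same predicate: identify MCP-style connectors by config shape, that is, any row whose JSON config has a top-level server_config key. A standalone sketch of that filter with SQLAlchemy 2.0; the model and table are illustrative:

from sqlalchemy import JSON, cast, select
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class Connector(Base):
    __tablename__ = "connectors"

    id: Mapped[int] = mapped_column(primary_key=True)
    config: Mapped[dict] = mapped_column(JSON)


# Casting the generic JSON column to JSONB lets Postgres's `?` operator,
# rendered by .has_key(), test for the presence of the top-level key.
stmt = select(Connector).where(
    cast(Connector.config, JSONB).has_key("server_config")  # noqa: W601
)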
@ -312,7 +312,7 @@ async def slack_callback(
new_connector = SearchSourceConnector(
name=connector_name,
connector_type=SearchSourceConnectorType.SLACK_CONNECTOR,
is_indexable=True,
is_indexable=False,
config=connector_config,
search_space_id=space_id,
user_id=user_id,

@ -45,6 +45,7 @@ SCOPES = [
"Team.ReadBasic.All", # Read basic team information
"Channel.ReadBasic.All", # Read basic channel information
"ChannelMessage.Read.All", # Read messages in channels
"ChannelMessage.Send", # Send messages in channels
]

# Initialize security utilities
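The scope list gains ChannelMessage.Send, so the Teams consent screen now covers posting as well as reading. A sketch of building the authorization URL with the msal library; client ID, secret, authority, and redirect URI below are placeholders:

import msal

SCOPES = [
    "Team.ReadBasic.All",
    "Channel.ReadBasic.All",
    "ChannelMessage.Read.All",
    "ChannelMessage.Send",
]

app = msal.ConfidentialClientApplication(
    client_id="YOUR_CLIENT_ID",
    client_credential="YOUR_CLIENT_SECRET",
    authority="https://login.microsoftonline.com/common",
)

# The user is asked to consent to every scope, including message sending.
auth_url = app.get_authorization_request_url(
    scopes=SCOPES,
    redirect_uri="https://localhost/callback",
)
print(auth_url)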
@ -320,7 +321,7 @@ async def teams_callback(
new_connector = SearchSourceConnector(
name=connector_name,
connector_type=SearchSourceConnectorType.TEAMS_CONNECTOR,
is_indexable=True,
is_indexable=False,
config=connector_config,
search_space_id=space_id,
user_id=user_id,