mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-04-25 00:36:31 +02:00
add reusable OAuth connector route base class
This commit is contained in:
parent
a1804265b8
commit
2dfe03b9b2
1 changed files with 620 additions and 0 deletions
620
surfsense_backend/app/routes/oauth_connector_base.py
Normal file
620
surfsense_backend/app/routes/oauth_connector_base.py
Normal file
|
|
@ -0,0 +1,620 @@
|
|||
"""Reusable base for OAuth 2.0 connector routes.
|
||||
|
||||
Subclasses override ``fetch_account_info``, ``build_connector_config``,
|
||||
and ``get_connector_display_name`` to customise provider-specific behaviour.
|
||||
Call ``build_router()`` to get a FastAPI ``APIRouter`` with ``/connector/add``,
|
||||
``/connector/callback``, and ``/connector/reauth`` endpoints.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import logging
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Any
|
||||
from urllib.parse import urlencode
|
||||
from uuid import UUID
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import RedirectResponse
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
from app.config import config
|
||||
from app.db import (
|
||||
SearchSourceConnector,
|
||||
SearchSourceConnectorType,
|
||||
User,
|
||||
get_async_session,
|
||||
)
|
||||
from app.users import current_active_user
|
||||
from app.utils.connector_naming import (
|
||||
check_duplicate_connector,
|
||||
generate_unique_connector_name,
|
||||
)
|
||||
from app.utils.oauth_security import OAuthStateManager, TokenEncryption
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OAuthConnectorRoute:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
provider_name: str,
|
||||
connector_type: SearchSourceConnectorType,
|
||||
authorize_url: str,
|
||||
token_url: str,
|
||||
client_id_env: str,
|
||||
client_secret_env: str,
|
||||
redirect_uri_env: str,
|
||||
scopes: list[str],
|
||||
auth_prefix: str,
|
||||
use_pkce: bool = False,
|
||||
token_auth_method: str = "body",
|
||||
is_indexable: bool = True,
|
||||
extra_auth_params: dict[str, str] | None = None,
|
||||
) -> None:
|
||||
self.provider_name = provider_name
|
||||
self.connector_type = connector_type
|
||||
self.authorize_url = authorize_url
|
||||
self.token_url = token_url
|
||||
self.client_id_env = client_id_env
|
||||
self.client_secret_env = client_secret_env
|
||||
self.redirect_uri_env = redirect_uri_env
|
||||
self.scopes = scopes
|
||||
self.auth_prefix = auth_prefix.rstrip("/")
|
||||
self.use_pkce = use_pkce
|
||||
self.token_auth_method = token_auth_method
|
||||
self.is_indexable = is_indexable
|
||||
self.extra_auth_params = extra_auth_params or {}
|
||||
|
||||
self._state_manager: OAuthStateManager | None = None
|
||||
self._token_encryption: TokenEncryption | None = None
|
||||
|
||||
def _get_client_id(self) -> str:
|
||||
value = getattr(config, self.client_id_env, None)
|
||||
if not value:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"{self.provider_name.title()} OAuth not configured "
|
||||
f"({self.client_id_env} missing).",
|
||||
)
|
||||
return value
|
||||
|
||||
def _get_client_secret(self) -> str:
|
||||
value = getattr(config, self.client_secret_env, None)
|
||||
if not value:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"{self.provider_name.title()} OAuth not configured "
|
||||
f"({self.client_secret_env} missing).",
|
||||
)
|
||||
return value
|
||||
|
||||
def _get_redirect_uri(self) -> str:
|
||||
value = getattr(config, self.redirect_uri_env, None)
|
||||
if not value:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"{self.redirect_uri_env} not configured.",
|
||||
)
|
||||
return value
|
||||
|
||||
def _get_state_manager(self) -> OAuthStateManager:
|
||||
if self._state_manager is None:
|
||||
if not config.SECRET_KEY:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="SECRET_KEY not configured for OAuth security.",
|
||||
)
|
||||
self._state_manager = OAuthStateManager(config.SECRET_KEY)
|
||||
return self._state_manager
|
||||
|
||||
def _get_token_encryption(self) -> TokenEncryption:
|
||||
if self._token_encryption is None:
|
||||
if not config.SECRET_KEY:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="SECRET_KEY not configured for token encryption.",
|
||||
)
|
||||
self._token_encryption = TokenEncryption(config.SECRET_KEY)
|
||||
return self._token_encryption
|
||||
|
||||
def _frontend_redirect(
|
||||
self,
|
||||
space_id: int | None,
|
||||
*,
|
||||
success: bool = False,
|
||||
connector_id: int | None = None,
|
||||
error: str | None = None,
|
||||
) -> RedirectResponse:
|
||||
if success and space_id:
|
||||
connector_slug = f"{self.provider_name}-connector"
|
||||
qs = f"success=true&connector={connector_slug}"
|
||||
if connector_id:
|
||||
qs += f"&connectorId={connector_id}"
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?{qs}"
|
||||
)
|
||||
if error and space_id:
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error={error}"
|
||||
)
|
||||
if error:
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard?error={error}"
|
||||
)
|
||||
return RedirectResponse(url=f"{config.NEXT_FRONTEND_URL}/dashboard")
|
||||
|
||||
async def fetch_account_info(self, access_token: str) -> dict[str, Any]:
|
||||
"""Override to fetch account/workspace info after token exchange.
|
||||
|
||||
Return dict is merged into connector config; key ``"name"`` is used
|
||||
for the display name and dedup.
|
||||
"""
|
||||
return {}
|
||||
|
||||
def build_connector_config(
|
||||
self,
|
||||
token_json: dict[str, Any],
|
||||
account_info: dict[str, Any],
|
||||
encryption: TokenEncryption,
|
||||
) -> dict[str, Any]:
|
||||
"""Override for custom config shapes. Default: standard encrypted OAuth fields."""
|
||||
access_token = token_json.get("access_token", "")
|
||||
refresh_token = token_json.get("refresh_token")
|
||||
|
||||
expires_at = None
|
||||
if token_json.get("expires_in"):
|
||||
expires_at = datetime.now(UTC) + timedelta(
|
||||
seconds=int(token_json["expires_in"])
|
||||
)
|
||||
|
||||
cfg: dict[str, Any] = {
|
||||
"access_token": encryption.encrypt_token(access_token),
|
||||
"refresh_token": (
|
||||
encryption.encrypt_token(refresh_token) if refresh_token else None
|
||||
),
|
||||
"token_type": token_json.get("token_type", "Bearer"),
|
||||
"expires_in": token_json.get("expires_in"),
|
||||
"expires_at": expires_at.isoformat() if expires_at else None,
|
||||
"scope": token_json.get("scope"),
|
||||
"_token_encrypted": True,
|
||||
}
|
||||
cfg.update(account_info)
|
||||
return cfg
|
||||
|
||||
def get_connector_display_name(self, account_info: dict[str, Any]) -> str:
|
||||
return str(account_info.get("name", self.provider_name.title()))
|
||||
|
||||
async def on_token_refresh_failure(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
connector: SearchSourceConnector,
|
||||
) -> None:
|
||||
try:
|
||||
connector.config = {**connector.config, "auth_expired": True}
|
||||
flag_modified(connector, "config")
|
||||
await session.commit()
|
||||
await session.refresh(connector)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to persist auth_expired flag for connector %s",
|
||||
connector.id,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
async def _exchange_code(
|
||||
self, code: str, extra_state: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
client_id = self._get_client_id()
|
||||
client_secret = self._get_client_secret()
|
||||
redirect_uri = self._get_redirect_uri()
|
||||
|
||||
headers: dict[str, str] = {
|
||||
"Content-Type": "application/x-www-form-urlencoded",
|
||||
}
|
||||
body: dict[str, str] = {
|
||||
"grant_type": "authorization_code",
|
||||
"code": code,
|
||||
"redirect_uri": redirect_uri,
|
||||
}
|
||||
|
||||
if self.token_auth_method == "basic":
|
||||
creds = base64.b64encode(f"{client_id}:{client_secret}".encode()).decode()
|
||||
headers["Authorization"] = f"Basic {creds}"
|
||||
else:
|
||||
body["client_id"] = client_id
|
||||
body["client_secret"] = client_secret
|
||||
|
||||
if self.use_pkce:
|
||||
verifier = extra_state.get("code_verifier")
|
||||
if verifier:
|
||||
body["code_verifier"] = verifier
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.post(
|
||||
self.token_url, data=body, headers=headers, timeout=30.0
|
||||
)
|
||||
|
||||
if resp.status_code != 200:
|
||||
detail = resp.text
|
||||
try:
|
||||
detail = resp.json().get("error_description", detail)
|
||||
except Exception:
|
||||
pass
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Token exchange failed: {detail}"
|
||||
)
|
||||
|
||||
return resp.json()
|
||||
|
||||
async def refresh_token(
|
||||
self, session: AsyncSession, connector: SearchSourceConnector
|
||||
) -> SearchSourceConnector:
|
||||
encryption = self._get_token_encryption()
|
||||
is_encrypted = connector.config.get("_token_encrypted", False)
|
||||
|
||||
refresh_tok = connector.config.get("refresh_token")
|
||||
if is_encrypted and refresh_tok:
|
||||
try:
|
||||
refresh_tok = encryption.decrypt_token(refresh_tok)
|
||||
except Exception as e:
|
||||
logger.error("Failed to decrypt refresh token: %s", e)
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Failed to decrypt stored refresh token"
|
||||
) from e
|
||||
|
||||
if not refresh_tok:
|
||||
await self.on_token_refresh_failure(session, connector)
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="No refresh token available. Please re-authenticate.",
|
||||
)
|
||||
|
||||
client_id = self._get_client_id()
|
||||
client_secret = self._get_client_secret()
|
||||
|
||||
headers: dict[str, str] = {
|
||||
"Content-Type": "application/x-www-form-urlencoded",
|
||||
}
|
||||
body: dict[str, str] = {
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": refresh_tok,
|
||||
}
|
||||
|
||||
if self.token_auth_method == "basic":
|
||||
creds = base64.b64encode(f"{client_id}:{client_secret}".encode()).decode()
|
||||
headers["Authorization"] = f"Basic {creds}"
|
||||
else:
|
||||
body["client_id"] = client_id
|
||||
body["client_secret"] = client_secret
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.post(
|
||||
self.token_url, data=body, headers=headers, timeout=30.0
|
||||
)
|
||||
|
||||
if resp.status_code != 200:
|
||||
error_detail = resp.text
|
||||
try:
|
||||
ej = resp.json()
|
||||
error_detail = ej.get("error_description", error_detail)
|
||||
error_code = ej.get("error", "")
|
||||
except Exception:
|
||||
error_code = ""
|
||||
combined = (error_detail + error_code).lower()
|
||||
if any(kw in combined for kw in ("invalid_grant", "expired", "revoked")):
|
||||
await self.on_token_refresh_failure(session, connector)
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail=f"{self.provider_name.title()} authentication failed. "
|
||||
"Please re-authenticate.",
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Token refresh failed: {error_detail}"
|
||||
)
|
||||
|
||||
token_json = resp.json()
|
||||
new_access = token_json.get("access_token")
|
||||
if not new_access:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="No access token received from refresh"
|
||||
)
|
||||
|
||||
expires_at = None
|
||||
if token_json.get("expires_in"):
|
||||
expires_at = datetime.now(UTC) + timedelta(
|
||||
seconds=int(token_json["expires_in"])
|
||||
)
|
||||
|
||||
updated_config = dict(connector.config)
|
||||
updated_config["access_token"] = encryption.encrypt_token(new_access)
|
||||
new_refresh = token_json.get("refresh_token")
|
||||
if new_refresh:
|
||||
updated_config["refresh_token"] = encryption.encrypt_token(new_refresh)
|
||||
updated_config["expires_in"] = token_json.get("expires_in")
|
||||
updated_config["expires_at"] = expires_at.isoformat() if expires_at else None
|
||||
updated_config["scope"] = token_json.get("scope", updated_config.get("scope"))
|
||||
updated_config["_token_encrypted"] = True
|
||||
updated_config.pop("auth_expired", None)
|
||||
|
||||
connector.config = updated_config
|
||||
flag_modified(connector, "config")
|
||||
await session.commit()
|
||||
await session.refresh(connector)
|
||||
|
||||
logger.info(
|
||||
"Refreshed %s token for connector %s",
|
||||
self.provider_name,
|
||||
connector.id,
|
||||
)
|
||||
return connector
|
||||
|
||||
def build_router(self) -> APIRouter:
|
||||
router = APIRouter()
|
||||
oauth = self
|
||||
|
||||
@router.get(f"{oauth.auth_prefix}/connector/add")
|
||||
async def connect(
|
||||
space_id: int,
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
if not space_id:
|
||||
raise HTTPException(status_code=400, detail="space_id is required")
|
||||
|
||||
client_id = oauth._get_client_id()
|
||||
state_mgr = oauth._get_state_manager()
|
||||
|
||||
extra_state: dict[str, Any] = {}
|
||||
auth_params: dict[str, str] = {
|
||||
"client_id": client_id,
|
||||
"response_type": "code",
|
||||
"redirect_uri": oauth._get_redirect_uri(),
|
||||
"scope": " ".join(oauth.scopes),
|
||||
}
|
||||
|
||||
if oauth.use_pkce:
|
||||
from app.utils.oauth_security import generate_pkce_pair
|
||||
|
||||
verifier, challenge = generate_pkce_pair()
|
||||
extra_state["code_verifier"] = verifier
|
||||
auth_params["code_challenge"] = challenge
|
||||
auth_params["code_challenge_method"] = "S256"
|
||||
|
||||
auth_params.update(oauth.extra_auth_params)
|
||||
|
||||
state_encoded = state_mgr.generate_secure_state(
|
||||
space_id, user.id, **extra_state
|
||||
)
|
||||
auth_params["state"] = state_encoded
|
||||
auth_url = f"{oauth.authorize_url}?{urlencode(auth_params)}"
|
||||
|
||||
logger.info(
|
||||
"Generated %s OAuth URL for user %s, space %s",
|
||||
oauth.provider_name,
|
||||
user.id,
|
||||
space_id,
|
||||
)
|
||||
return {"auth_url": auth_url}
|
||||
|
||||
@router.get(f"{oauth.auth_prefix}/connector/reauth")
|
||||
async def reauth(
|
||||
space_id: int,
|
||||
connector_id: int,
|
||||
return_url: str | None = None,
|
||||
user: User = Depends(current_active_user),
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == connector_id,
|
||||
SearchSourceConnector.user_id == user.id,
|
||||
SearchSourceConnector.search_space_id == space_id,
|
||||
SearchSourceConnector.connector_type == oauth.connector_type,
|
||||
)
|
||||
)
|
||||
if not result.scalars().first():
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"{oauth.provider_name.title()} connector not found "
|
||||
"or access denied",
|
||||
)
|
||||
|
||||
client_id = oauth._get_client_id()
|
||||
state_mgr = oauth._get_state_manager()
|
||||
|
||||
extra: dict[str, Any] = {"connector_id": connector_id}
|
||||
if return_url and return_url.startswith("/"):
|
||||
extra["return_url"] = return_url
|
||||
|
||||
auth_params: dict[str, str] = {
|
||||
"client_id": client_id,
|
||||
"response_type": "code",
|
||||
"redirect_uri": oauth._get_redirect_uri(),
|
||||
"scope": " ".join(oauth.scopes),
|
||||
}
|
||||
|
||||
if oauth.use_pkce:
|
||||
from app.utils.oauth_security import generate_pkce_pair
|
||||
|
||||
verifier, challenge = generate_pkce_pair()
|
||||
extra["code_verifier"] = verifier
|
||||
auth_params["code_challenge"] = challenge
|
||||
auth_params["code_challenge_method"] = "S256"
|
||||
|
||||
auth_params.update(oauth.extra_auth_params)
|
||||
|
||||
state_encoded = state_mgr.generate_secure_state(
|
||||
space_id, user.id, **extra
|
||||
)
|
||||
auth_params["state"] = state_encoded
|
||||
auth_url = f"{oauth.authorize_url}?{urlencode(auth_params)}"
|
||||
|
||||
logger.info(
|
||||
"Initiating %s re-auth for user %s, connector %s",
|
||||
oauth.provider_name,
|
||||
user.id,
|
||||
connector_id,
|
||||
)
|
||||
return {"auth_url": auth_url}
|
||||
|
||||
@router.get(f"{oauth.auth_prefix}/connector/callback")
|
||||
async def callback(
|
||||
code: str | None = None,
|
||||
error: str | None = None,
|
||||
state: str | None = None,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
error_label = f"{oauth.provider_name}_oauth_denied"
|
||||
|
||||
if error:
|
||||
logger.warning("%s OAuth error: %s", oauth.provider_name, error)
|
||||
space_id = None
|
||||
if state:
|
||||
try:
|
||||
data = oauth._get_state_manager().validate_state(state)
|
||||
space_id = data.get("space_id")
|
||||
except Exception:
|
||||
pass
|
||||
return oauth._frontend_redirect(space_id, error=error_label)
|
||||
|
||||
if not code:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Missing authorization code"
|
||||
)
|
||||
if not state:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Missing state parameter"
|
||||
)
|
||||
|
||||
state_mgr = oauth._get_state_manager()
|
||||
try:
|
||||
data = state_mgr.validate_state(state)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Invalid state parameter: {e!s}"
|
||||
) from e
|
||||
|
||||
user_id = UUID(data["user_id"])
|
||||
space_id = data["space_id"]
|
||||
|
||||
token_json = await oauth._exchange_code(code, data)
|
||||
|
||||
access_token = token_json.get("access_token", "")
|
||||
if not access_token:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"No access token received from {oauth.provider_name.title()}",
|
||||
)
|
||||
|
||||
account_info = await oauth.fetch_account_info(access_token)
|
||||
encryption = oauth._get_token_encryption()
|
||||
connector_config = oauth.build_connector_config(
|
||||
token_json, account_info, encryption
|
||||
)
|
||||
|
||||
display_name = oauth.get_connector_display_name(account_info)
|
||||
|
||||
# --- Re-auth path ---
|
||||
reauth_connector_id = data.get("connector_id")
|
||||
reauth_return_url = data.get("return_url")
|
||||
|
||||
if reauth_connector_id:
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
SearchSourceConnector.id == reauth_connector_id,
|
||||
SearchSourceConnector.user_id == user_id,
|
||||
SearchSourceConnector.search_space_id == space_id,
|
||||
SearchSourceConnector.connector_type == oauth.connector_type,
|
||||
)
|
||||
)
|
||||
db_connector = result.scalars().first()
|
||||
if not db_connector:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Connector not found or access denied during re-auth",
|
||||
)
|
||||
|
||||
db_connector.config = connector_config
|
||||
flag_modified(db_connector, "config")
|
||||
await session.commit()
|
||||
await session.refresh(db_connector)
|
||||
|
||||
logger.info(
|
||||
"Re-authenticated %s connector %s for user %s",
|
||||
oauth.provider_name,
|
||||
db_connector.id,
|
||||
user_id,
|
||||
)
|
||||
if reauth_return_url and reauth_return_url.startswith("/"):
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}{reauth_return_url}"
|
||||
)
|
||||
return oauth._frontend_redirect(
|
||||
space_id, success=True, connector_id=db_connector.id
|
||||
)
|
||||
|
||||
# --- New connector path ---
|
||||
is_dup = await check_duplicate_connector(
|
||||
session,
|
||||
oauth.connector_type,
|
||||
space_id,
|
||||
user_id,
|
||||
display_name,
|
||||
)
|
||||
if is_dup:
|
||||
logger.warning(
|
||||
"Duplicate %s connector for user %s (%s)",
|
||||
oauth.provider_name,
|
||||
user_id,
|
||||
display_name,
|
||||
)
|
||||
return oauth._frontend_redirect(
|
||||
space_id,
|
||||
error=f"duplicate_account&connector={oauth.provider_name}-connector",
|
||||
)
|
||||
|
||||
connector_name = await generate_unique_connector_name(
|
||||
session,
|
||||
oauth.connector_type,
|
||||
space_id,
|
||||
user_id,
|
||||
display_name,
|
||||
)
|
||||
|
||||
new_connector = SearchSourceConnector(
|
||||
name=connector_name,
|
||||
connector_type=oauth.connector_type,
|
||||
is_indexable=oauth.is_indexable,
|
||||
config=connector_config,
|
||||
search_space_id=space_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
session.add(new_connector)
|
||||
|
||||
try:
|
||||
await session.commit()
|
||||
except IntegrityError as e:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=409, detail=f"Database integrity error: {e!s}"
|
||||
) from e
|
||||
|
||||
logger.info(
|
||||
"Created %s connector %s for user %s in space %s",
|
||||
oauth.provider_name,
|
||||
new_connector.id,
|
||||
user_id,
|
||||
space_id,
|
||||
)
|
||||
return oauth._frontend_redirect(
|
||||
space_id, success=True, connector_id=new_connector.id
|
||||
)
|
||||
|
||||
return router
|
||||
Loading…
Add table
Add a link
Reference in a new issue