From 81e4a4ada06343e27d7cdcf3e3b757b4c9b11d69 Mon Sep 17 00:00:00 2001 From: Anish Sarkar <104695310+AnishSarkar22@users.noreply.github.com> Date: Sun, 4 Jan 2026 02:38:19 +0530 Subject: [PATCH] feat: database driven refresh tokens for slack oauth connector --- .../app/connectors/slack_history.py | 217 ++++++++++++++++-- .../app/routes/slack_add_connector_route.py | 146 +++++++++++- .../app/schemas/slack_auth_credentials.py | 76 ++++++ .../tasks/connector_indexers/slack_indexer.py | 45 +--- 4 files changed, 426 insertions(+), 58 deletions(-) create mode 100644 surfsense_backend/app/schemas/slack_auth_credentials.py diff --git a/surfsense_backend/app/connectors/slack_history.py b/surfsense_backend/app/connectors/slack_history.py index 36160c30b..6a016394e 100644 --- a/surfsense_backend/app/connectors/slack_history.py +++ b/surfsense_backend/app/connectors/slack_history.py @@ -12,6 +12,14 @@ from typing import Any from slack_sdk import WebClient from slack_sdk.errors import SlackApiError +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.future import select + +from app.config import config +from app.db import SearchSourceConnector +from app.routes.slack_add_connector_route import refresh_slack_token +from app.schemas.slack_auth_credentials import SlackAuthCredentialsBase +from app.utils.oauth_security import TokenEncryption logger = logging.getLogger(__name__) # Added logger @@ -19,25 +27,195 @@ logger = logging.getLogger(__name__) # Added logger class SlackHistory: """Class for retrieving conversation history from Slack channels.""" - def __init__(self, token: str | None = None): + def __init__( + self, + token: str | None = None, + session: AsyncSession | None = None, + connector_id: int | None = None, + credentials: SlackAuthCredentialsBase | None = None, + ): """ Initialize the SlackHistory class. Args: - token: Slack API token (optional, can be set later with set_token) + token: Slack API token (optional, for backward compatibility) + session: Database session for token refresh (optional) + connector_id: Connector ID for token refresh (optional) + credentials: Slack OAuth credentials (optional, will be loaded from DB if not provided) """ - self.client = WebClient(token=token) if token else None + self._session = session + self._connector_id = connector_id + self._credentials = credentials + # For backward compatibility, if token is provided directly, use it + if token: + self.client = WebClient(token=token) + else: + self.client = None + + async def _get_valid_token(self) -> str: + """ + Get valid Slack bot token, refreshing if needed. + + Returns: + Valid bot token + + Raises: + ValueError: If credentials are missing or invalid + Exception: If token refresh fails + """ + # If we have a direct token (backward compatibility), use it + # Check if client was initialized with a token directly (not via credentials) + if self.client and self._session is None and self._connector_id is None: + # This means it was initialized with a direct token, extract it + # WebClient stores token internally, we need to get it from the client + # For backward compatibility, we'll use the client directly + # But we can't easily extract the token, so we'll just use the client + # In this case, we'll skip refresh logic + if self._credentials is None: + # This is the old pattern - just use the client as-is + # We can't extract token easily, so we'll raise an error + # asking to use the new pattern + raise ValueError( + "Cannot refresh token: Please use session and connector_id for auto-refresh support" + ) + + # Load credentials from DB if not provided + if self._credentials is None: + if not self._session or not self._connector_id: + raise ValueError( + "Cannot load credentials: session and connector_id required" + ) + + result = await self._session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == self._connector_id + ) + ) + connector = result.scalars().first() + + if not connector: + raise ValueError(f"Connector {self._connector_id} not found") + + config_data = connector.config.copy() + + # Decrypt credentials if they are encrypted + token_encrypted = config_data.get("_token_encrypted", False) + if token_encrypted and config.SECRET_KEY: + try: + token_encryption = TokenEncryption(config.SECRET_KEY) + + # Decrypt sensitive fields + if config_data.get("bot_token"): + config_data["bot_token"] = token_encryption.decrypt_token( + config_data["bot_token"] + ) + if config_data.get("refresh_token"): + config_data["refresh_token"] = token_encryption.decrypt_token( + config_data["refresh_token"] + ) + + logger.info( + f"Decrypted Slack credentials for connector {self._connector_id}" + ) + except Exception as e: + logger.error( + f"Failed to decrypt Slack credentials for connector {self._connector_id}: {e!s}" + ) + raise ValueError( + f"Failed to decrypt Slack credentials: {e!s}" + ) from e + + try: + self._credentials = SlackAuthCredentialsBase.from_dict(config_data) + except Exception as e: + raise ValueError(f"Invalid Slack credentials: {e!s}") from e + + # Check if token is expired and refreshable + if self._credentials.is_expired and self._credentials.is_refreshable: + try: + logger.info( + f"Slack token expired for connector {self._connector_id}, refreshing..." + ) + + # Get connector for refresh + result = await self._session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == self._connector_id + ) + ) + connector = result.scalars().first() + + if not connector: + raise RuntimeError( + f"Connector {self._connector_id} not found; cannot refresh token." + ) + + # Refresh token + connector = await refresh_slack_token(self._session, connector) + + # Reload credentials after refresh + config_data = connector.config.copy() + token_encrypted = config_data.get("_token_encrypted", False) + if token_encrypted and config.SECRET_KEY: + token_encryption = TokenEncryption(config.SECRET_KEY) + if config_data.get("bot_token"): + config_data["bot_token"] = token_encryption.decrypt_token( + config_data["bot_token"] + ) + if config_data.get("refresh_token"): + config_data["refresh_token"] = token_encryption.decrypt_token( + config_data["refresh_token"] + ) + + self._credentials = SlackAuthCredentialsBase.from_dict(config_data) + + # Invalidate cached client so it's recreated with new token + self.client = None + + logger.info( + f"Successfully refreshed Slack token for connector {self._connector_id}" + ) + except Exception as e: + logger.error( + f"Failed to refresh Slack token for connector {self._connector_id}: {e!s}" + ) + raise Exception( + f"Failed to refresh Slack OAuth credentials: {e!s}" + ) from e + + return self._credentials.bot_token + + async def _ensure_client(self) -> WebClient: + """ + Ensure Slack client is initialized with valid token. + + Returns: + WebClient instance + """ + # If client was initialized with direct token (backward compatibility), use it + if self.client and (self._session is None or self._connector_id is None): + return self.client + + # Otherwise, initialize with token from credentials (with auto-refresh) + if self.client is None: + token = await self._get_valid_token() + # Skip if it's the placeholder for direct token initialization + if token != "direct_token_initialized": + self.client = WebClient(token=token) + return self.client def set_token(self, token: str) -> None: """ - Set the Slack API token. + Set the Slack API token (for backward compatibility). Args: token: Slack API token """ self.client = WebClient(token=token) - def get_all_channels(self, include_private: bool = True) -> list[dict[str, Any]]: + async def get_all_channels( + self, include_private: bool = True + ) -> list[dict[str, Any]]: """ Fetch all channels that the bot has access to, with rate limit handling. @@ -52,8 +230,7 @@ class SlackHistory: SlackApiError: If there's an unrecoverable error calling the Slack API RuntimeError: For unexpected errors during channel fetching. """ - if not self.client: - raise ValueError("Slack client not initialized. Call set_token() first.") + client = await self._ensure_client() channels_list = [] # Changed from dict to list types = "public_channel" @@ -72,7 +249,7 @@ class SlackHistory: time.sleep(3) current_limit = 1000 # Max limit - api_result = self.client.conversations_list( + api_result = client.conversations_list( types=types, cursor=next_cursor, limit=current_limit ) @@ -129,7 +306,7 @@ class SlackHistory: return channels_list - def get_conversation_history( + async def get_conversation_history( self, channel_id: str, limit: int = 1000, @@ -152,8 +329,7 @@ class SlackHistory: ValueError: If no Slack client has been initialized SlackApiError: If there's an error calling the Slack API """ - if not self.client: - raise ValueError("Slack client not initialized. Call set_token() first.") + client = await self._ensure_client() messages = [] next_cursor = None @@ -177,7 +353,7 @@ class SlackHistory: current_api_call_successful = False result = None # Ensure result is defined try: - result = self.client.conversations_history(**kwargs) + result = client.conversations_history(**kwargs) current_api_call_successful = True except SlackApiError as e_history: if ( @@ -252,7 +428,7 @@ class SlackHistory: except ValueError: return None - def get_history_by_date_range( + async def get_history_by_date_range( self, channel_id: str, start_date: str, end_date: str, limit: int = 1000 ) -> tuple[list[dict[str, Any]], str | None]: """ @@ -282,7 +458,7 @@ class SlackHistory: latest += 86400 # seconds in a day try: - messages = self.get_conversation_history( + messages = await self.get_conversation_history( channel_id=channel_id, limit=limit, oldest=oldest, latest=latest ) return messages, None @@ -291,7 +467,7 @@ class SlackHistory: except ValueError as e: return [], str(e) - def get_user_info(self, user_id: str) -> dict[str, Any]: + async def get_user_info(self, user_id: str) -> dict[str, Any]: """ Get information about a user. @@ -305,8 +481,7 @@ class SlackHistory: ValueError: If no Slack client has been initialized SlackApiError: If there's an error calling the Slack API """ - if not self.client: - raise ValueError("Slack client not initialized. Call set_token() first.") + client = await self._ensure_client() while True: try: @@ -314,7 +489,7 @@ class SlackHistory: # For now, we are only adding Retry-After as per plan. # time.sleep(0.6) # Optional: ~100 req/min if ever needed. - result = self.client.users_info(user=user_id) + result = client.users_info(user=user_id) return result["user"] # Success, return and exit loop implicitly except SlackApiError as e_user_info: @@ -343,7 +518,7 @@ class SlackHistory: ) raise general_error from general_error # Re-raise unexpected errors - def format_message( + async def format_message( self, msg: dict[str, Any], include_user_info: bool = False ) -> dict[str, Any]: """ @@ -369,9 +544,9 @@ class SlackHistory: "is_thread": "thread_ts" in msg, } - if include_user_info and "user" in msg and self.client: + if include_user_info and "user" in msg: try: - user_info = self.get_user_info(msg["user"]) + user_info = await self.get_user_info(msg["user"]) formatted["user_name"] = user_info.get("real_name", "Unknown") formatted["user_email"] = user_info.get("profile", {}).get("email", "") except Exception: diff --git a/surfsense_backend/app/routes/slack_add_connector_route.py b/surfsense_backend/app/routes/slack_add_connector_route.py index 1bbb4f5f1..71a362119 100644 --- a/surfsense_backend/app/routes/slack_add_connector_route.py +++ b/surfsense_backend/app/routes/slack_add_connector_route.py @@ -23,6 +23,7 @@ from app.db import ( User, get_async_session, ) +from app.schemas.slack_auth_credentials import SlackAuthCredentialsBase from app.users import current_active_user from app.utils.oauth_security import OAuthStateManager, TokenEncryption @@ -229,7 +230,7 @@ async def slack_callback( ) # Extract bot token from Slack response - # Slack OAuth v2 returns: { "ok": true, "access_token": "...", "bot": { "bot_user_id": "...", "bot_access_token": "xoxb-..." }, ... } + # Slack OAuth v2 returns: { "ok": true, "access_token": "...", "bot": { "bot_user_id": "...", "bot_access_token": "xoxb-..." }, "refresh_token": "...", ... } bot_token = None if token_json.get("bot") and token_json["bot"].get("bot_access_token"): bot_token = token_json["bot"]["bot_access_token"] @@ -241,6 +242,9 @@ async def slack_callback( status_code=400, detail="No bot token received from Slack" ) + # Extract refresh token if available (for token rotation) + refresh_token = token_json.get("refresh_token") + # Encrypt sensitive tokens before storing token_encryption = get_token_encryption() @@ -251,9 +255,12 @@ async def slack_callback( now_utc = datetime.now(UTC) expires_at = now_utc + timedelta(seconds=int(token_json["expires_in"])) - # Store the encrypted bot token in connector config + # Store the encrypted bot token and refresh token in connector config connector_config = { "bot_token": token_encryption.encrypt_token(bot_token), + "refresh_token": token_encryption.encrypt_token(refresh_token) + if refresh_token + else None, "bot_user_id": token_json.get("bot", {}).get("bot_user_id"), "team_id": token_json.get("team", {}).get("id"), "team_name": token_json.get("team", {}).get("name"), @@ -334,3 +341,138 @@ async def slack_callback( raise HTTPException( status_code=500, detail=f"Failed to complete Slack OAuth: {e!s}" ) from e + + +async def refresh_slack_token( + session: AsyncSession, connector: SearchSourceConnector +) -> SearchSourceConnector: + """ + Refresh the Slack bot token for a connector. + + Args: + session: Database session + connector: Slack connector to refresh + + Returns: + Updated connector object + """ + try: + logger.info(f"Refreshing Slack token for connector {connector.id}") + + credentials = SlackAuthCredentialsBase.from_dict(connector.config) + + # Decrypt tokens if they are encrypted + token_encryption = get_token_encryption() + is_encrypted = connector.config.get("_token_encrypted", False) + + refresh_token = credentials.refresh_token + if is_encrypted and refresh_token: + try: + refresh_token = token_encryption.decrypt_token(refresh_token) + except Exception as e: + logger.error(f"Failed to decrypt refresh token: {e!s}") + raise HTTPException( + status_code=500, detail="Failed to decrypt stored refresh token" + ) from e + + if not refresh_token: + raise HTTPException( + status_code=400, + detail="No refresh token available. Please re-authenticate.", + ) + + # Slack uses oauth.v2.access for token refresh with grant_type=refresh_token + refresh_data = { + "client_id": config.SLACK_CLIENT_ID, + "client_secret": config.SLACK_CLIENT_SECRET, + "grant_type": "refresh_token", + "refresh_token": refresh_token, + } + + async with httpx.AsyncClient() as client: + token_response = await client.post( + TOKEN_URL, + data=refresh_data, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + timeout=30.0, + ) + + if token_response.status_code != 200: + error_detail = token_response.text + try: + error_json = token_response.json() + error_detail = error_json.get("error", error_detail) + except Exception: + pass + raise HTTPException( + status_code=400, detail=f"Token refresh failed: {error_detail}" + ) + + token_json = token_response.json() + + # Slack OAuth v2 returns success status in the JSON + if not token_json.get("ok", False): + error_msg = token_json.get("error", "Unknown error") + raise HTTPException( + status_code=400, detail=f"Slack OAuth refresh error: {error_msg}" + ) + + # Extract bot token from refresh response + bot_token = None + if token_json.get("bot") and token_json["bot"].get("bot_access_token"): + bot_token = token_json["bot"]["bot_access_token"] + elif token_json.get("access_token"): + bot_token = token_json["access_token"] + else: + raise HTTPException( + status_code=400, detail="No bot token received from Slack refresh" + ) + + # Get new refresh token if provided (Slack may rotate refresh tokens) + new_refresh_token = token_json.get("refresh_token") + + # Calculate expiration time (UTC, tz-aware) + expires_at = None + expires_in = token_json.get("expires_in") + if expires_in: + now_utc = datetime.now(UTC) + expires_at = now_utc + timedelta(seconds=int(expires_in)) + + # Update credentials object with encrypted tokens + credentials.bot_token = token_encryption.encrypt_token(bot_token) + if new_refresh_token: + credentials.refresh_token = token_encryption.encrypt_token( + new_refresh_token + ) + credentials.expires_in = expires_in + credentials.expires_at = expires_at + credentials.scope = token_json.get("scope") + + # Preserve team info + if not credentials.team_id: + credentials.team_id = connector.config.get("team_id") + if not credentials.team_name: + credentials.team_name = connector.config.get("team_name") + if not credentials.bot_user_id: + credentials.bot_user_id = connector.config.get("bot_user_id") + + # Update connector config with encrypted tokens + credentials_dict = credentials.to_dict() + credentials_dict["_token_encrypted"] = True + connector.config = credentials_dict + await session.commit() + await session.refresh(connector) + + logger.info(f"Successfully refreshed Slack token for connector {connector.id}") + + return connector + except HTTPException: + raise + except Exception as e: + logger.error( + f"Failed to refresh Slack token for connector {connector.id}: {e!s}", + exc_info=True, + ) + raise HTTPException( + status_code=500, detail=f"Failed to refresh Slack token: {e!s}" + ) from e diff --git a/surfsense_backend/app/schemas/slack_auth_credentials.py b/surfsense_backend/app/schemas/slack_auth_credentials.py new file mode 100644 index 000000000..ad6a713ef --- /dev/null +++ b/surfsense_backend/app/schemas/slack_auth_credentials.py @@ -0,0 +1,76 @@ +from datetime import UTC, datetime + +from pydantic import BaseModel, field_validator + + +class SlackAuthCredentialsBase(BaseModel): + bot_token: str + refresh_token: str | None = None + token_type: str = "Bearer" + expires_in: int | None = None + expires_at: datetime | None = None + scope: str | None = None + bot_user_id: str | None = None + team_id: str | None = None + team_name: str | None = None + + @property + def is_expired(self) -> bool: + """Check if the credentials have expired.""" + if self.expires_at is None: + return False # Long-lived token, treat as not expired + return self.expires_at <= datetime.now(UTC) + + @property + def is_refreshable(self) -> bool: + """Check if the credentials can be refreshed.""" + return self.refresh_token is not None + + def to_dict(self) -> dict: + """Convert credentials to dictionary for storage.""" + return { + "bot_token": self.bot_token, + "refresh_token": self.refresh_token, + "token_type": self.token_type, + "expires_in": self.expires_in, + "expires_at": self.expires_at.isoformat() if self.expires_at else None, + "scope": self.scope, + "bot_user_id": self.bot_user_id, + "team_id": self.team_id, + "team_name": self.team_name, + } + + @classmethod + def from_dict(cls, data: dict) -> "SlackAuthCredentialsBase": + """Create credentials from dictionary.""" + expires_at = None + if data.get("expires_at"): + expires_at = datetime.fromisoformat(data["expires_at"]) + + return cls( + bot_token=data.get("bot_token", ""), + refresh_token=data.get("refresh_token"), + token_type=data.get("token_type", "Bearer"), + expires_in=data.get("expires_in"), + expires_at=expires_at, + scope=data.get("scope"), + bot_user_id=data.get("bot_user_id"), + team_id=data.get("team_id"), + team_name=data.get("team_name"), + ) + + @field_validator("expires_at", mode="before") + @classmethod + def ensure_aware_utc(cls, v): + # Strings like "2025-08-26T14:46:57.367184" + if isinstance(v, str): + # add +00:00 if missing tz info + if v.endswith("Z"): + return datetime.fromisoformat(v.replace("Z", "+00:00")) + dt = datetime.fromisoformat(v) + return dt if dt.tzinfo else dt.replace(tzinfo=UTC) + # datetime objects + if isinstance(v, datetime): + return v if v.tzinfo else v.replace(tzinfo=UTC) + return v + diff --git a/surfsense_backend/app/tasks/connector_indexers/slack_indexer.py b/surfsense_backend/app/tasks/connector_indexers/slack_indexer.py index 4c4191a4e..c7a815634 100644 --- a/surfsense_backend/app/tasks/connector_indexers/slack_indexer.py +++ b/surfsense_backend/app/tasks/connector_indexers/slack_indexer.py @@ -17,7 +17,6 @@ from app.utils.document_converters import ( generate_content_hash, generate_unique_identifier_hash, ) -from app.utils.oauth_security import TokenEncryption from .base import ( build_document_metadata_markdown, @@ -93,44 +92,20 @@ async def index_slack_messages( f"Connector with ID {connector_id} not found or is not a Slack connector", ) - # Get the Slack token from the connector config - # Support both new OAuth format (bot_token) and old API format (SLACK_BOT_TOKEN) - config_data = connector.config.copy() - slack_token = config_data.get("bot_token") or config_data.get("SLACK_BOT_TOKEN") + # Note: Token handling is now done automatically by SlackHistory + # with auto-refresh support. We just need to pass session and connector_id. - if not slack_token: - await task_logger.log_task_failure( - log_entry, - f"Slack token not found in connector config for connector {connector_id}", - "Missing Slack token", - {"error_type": "MissingToken"}, - ) - return 0, "Slack token not found in connector config" - - # Decrypt token if it's encrypted (OAuth format) - token_encrypted = config_data.get("_token_encrypted", False) - if token_encrypted and config.SECRET_KEY: - try: - token_encryption = TokenEncryption(config.SECRET_KEY) - slack_token = token_encryption.decrypt_token(slack_token) - logger.info(f"Decrypted Slack bot token for connector {connector_id}") - except Exception as e: - await task_logger.log_task_failure( - log_entry, - f"Failed to decrypt Slack token for connector {connector_id}: {e!s}", - "Token decryption failed", - {"error_type": "TokenDecryptionError"}, - ) - return 0, f"Failed to decrypt Slack token: {e!s}" - - # Initialize Slack client + # Initialize Slack client with auto-refresh support await task_logger.log_task_progress( log_entry, f"Initializing Slack client for connector {connector_id}", {"stage": "client_initialization"}, ) - slack_client = SlackHistory(token=slack_token) + # Use the new pattern with session and connector_id for auto-refresh + slack_client = SlackHistory( + session=session, connector_id=connector_id + ) # Handle 'undefined' string from frontend (treat as None) if start_date == "undefined" or start_date == "": @@ -167,7 +142,7 @@ async def index_slack_messages( # Get all channels try: - channels = slack_client.get_all_channels() + channels = await slack_client.get_all_channels() except Exception as e: await task_logger.log_task_failure( log_entry, @@ -216,7 +191,7 @@ async def index_slack_messages( continue # Get messages for this channel - messages, error = slack_client.get_history_by_date_range( + messages, error = await slack_client.get_history_by_date_range( channel_id=channel_id, start_date=start_date_str, end_date=end_date_str, @@ -249,7 +224,7 @@ async def index_slack_messages( ]: continue - formatted_msg = slack_client.format_message( + formatted_msg = await slack_client.format_message( msg, include_user_info=True ) formatted_messages.append(formatted_msg)