feat: database-driven refresh tokens for Slack OAuth connector

This commit is contained in:
Anish Sarkar 2026-01-04 02:38:19 +05:30
parent 0fe94bfcf3
commit 81e4a4ada0
4 changed files with 426 additions and 58 deletions

View file

@ -12,6 +12,14 @@ from typing import Any
from slack_sdk import WebClient
from slack_sdk.errors import SlackApiError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.config import config
from app.db import SearchSourceConnector
from app.routes.slack_add_connector_route import refresh_slack_token
from app.schemas.slack_auth_credentials import SlackAuthCredentialsBase
from app.utils.oauth_security import TokenEncryption
logger = logging.getLogger(__name__) # Added logger
@ -19,25 +27,195 @@ logger = logging.getLogger(__name__) # Added logger
class SlackHistory:
"""Class for retrieving conversation history from Slack channels."""
def __init__(
    self,
    token: str | None = None,
    session: AsyncSession | None = None,
    connector_id: int | None = None,
    credentials: SlackAuthCredentialsBase | None = None,
):
    """
    Set up a SlackHistory instance.

    Two usage patterns are supported: pass ``token`` to build the Slack
    client immediately, or pass ``session`` and ``connector_id`` so the
    token can be resolved later from the stored connector.

    Args:
        token: Slack API token (optional, for backward compatibility)
        session: Database session for token refresh (optional)
        connector_id: Connector ID for token refresh (optional)
        credentials: Slack OAuth credentials (optional, will be loaded from DB if not provided)
    """
    # State used to resolve and refresh the token lazily.
    self._credentials = credentials
    self._connector_id = connector_id
    self._session = session

    # A directly supplied token yields a ready client; otherwise the
    # client stays unset until a token has been resolved.
    self.client = WebClient(token=token) if token else None
async def _get_valid_token(self) -> str:
    """
    Get a valid Slack bot token, refreshing it if it has expired.

    Credentials are loaded from the connector row on first use and cached
    on the instance. When the cached credentials are expired and carry a
    refresh token, the connector is refreshed via ``refresh_slack_token``
    and the cached client is dropped so it gets rebuilt with the new token.

    Returns:
        Valid bot token

    Raises:
        ValueError: If credentials are missing or invalid, cannot be
            decrypted, or contain no bot token
        Exception: If token refresh fails
    """
    # An instance built from a bare token has nothing to load or refresh
    # against; the caller must use the session/connector_id pattern.
    if (
        self.client
        and self._session is None
        and self._connector_id is None
        and self._credentials is None
    ):
        raise ValueError(
            "Cannot refresh token: Please use session and connector_id for auto-refresh support"
        )

    # Load credentials from DB if not provided
    if self._credentials is None:
        if not self._session or not self._connector_id:
            raise ValueError(
                "Cannot load credentials: session and connector_id required"
            )
        connector = await self._load_connector()
        if not connector:
            raise ValueError(f"Connector {self._connector_id} not found")
        self._credentials = self._credentials_from_config(connector.config)

    if self._credentials.is_expired:
        if self._credentials.is_refreshable:
            try:
                logger.info(
                    f"Slack token expired for connector {self._connector_id}, refreshing..."
                )
                # Credentials may have been injected without a session;
                # fail with a readable message instead of an AttributeError.
                if not self._session or not self._connector_id:
                    raise RuntimeError(
                        "session and connector_id are required to refresh the token"
                    )
                connector = await self._load_connector()
                if not connector:
                    raise RuntimeError(
                        f"Connector {self._connector_id} not found; cannot refresh token."
                    )

                connector = await refresh_slack_token(self._session, connector)
                self._credentials = self._credentials_from_config(connector.config)

                # Invalidate cached client so it's recreated with new token
                self.client = None
                logger.info(
                    f"Successfully refreshed Slack token for connector {self._connector_id}"
                )
            except Exception as e:
                logger.error(
                    f"Failed to refresh Slack token for connector {self._connector_id}: {e!s}"
                )
                raise Exception(
                    f"Failed to refresh Slack OAuth credentials: {e!s}"
                ) from e
        else:
            # Nothing can be done here; surface it in the logs so the
            # eventual Slack API failure is easy to diagnose.
            logger.warning(
                f"Slack token for connector {self._connector_id} is expired and has no refresh token"
            )

    if not self._credentials.bot_token:
        raise ValueError("Slack token not found in connector config")
    return self._credentials.bot_token

async def _load_connector(self) -> SearchSourceConnector | None:
    """Fetch the connector row for ``self._connector_id``, or None if absent."""
    result = await self._session.execute(
        select(SearchSourceConnector).filter(
            SearchSourceConnector.id == self._connector_id
        )
    )
    return result.scalars().first()

def _credentials_from_config(
    self, raw_config: dict[str, Any]
) -> SlackAuthCredentialsBase:
    """Build decrypted credentials from a connector config dict.

    Raises:
        ValueError: If decryption fails or the config is not valid credentials.
    """
    config_data = raw_config.copy()

    # Old API-format connectors keep the token under SLACK_BOT_TOKEN;
    # map it onto the OAuth field so both formats keep working.
    if not config_data.get("bot_token") and config_data.get("SLACK_BOT_TOKEN"):
        config_data["bot_token"] = config_data["SLACK_BOT_TOKEN"]

    # Decrypt credentials if they are encrypted
    if config_data.get("_token_encrypted", False):
        # Without the key the stored ciphertext would be used as the
        # token; refuse instead of proceeding with it.
        if not config.SECRET_KEY:
            raise ValueError(
                "Failed to decrypt Slack credentials: SECRET_KEY is not configured"
            )
        try:
            token_encryption = TokenEncryption(config.SECRET_KEY)
            for key in ("bot_token", "refresh_token"):
                if config_data.get(key):
                    config_data[key] = token_encryption.decrypt_token(
                        config_data[key]
                    )
            logger.info(
                f"Decrypted Slack credentials for connector {self._connector_id}"
            )
        except Exception as e:
            logger.error(
                f"Failed to decrypt Slack credentials for connector {self._connector_id}: {e!s}"
            )
            raise ValueError(
                f"Failed to decrypt Slack credentials: {e!s}"
            ) from e

    try:
        return SlackAuthCredentialsBase.from_dict(config_data)
    except Exception as e:
        raise ValueError(f"Invalid Slack credentials: {e!s}") from e
async def _ensure_client(self) -> WebClient:
    """
    Ensure the Slack client is initialized with a valid token.

    For instances managed through stored credentials the token is
    re-validated on every call, so a token that expires while the
    instance is in use is refreshed instead of being reused.

    Returns:
        WebClient instance

    Raises:
        ValueError: If no token can be resolved
        Exception: If token refresh fails
    """
    # If client was initialized with direct token (backward compatibility), use it
    if self.client and (self._session is None or self._connector_id is None):
        return self.client

    # A client that exists without any loaded credentials was supplied
    # explicitly (constructor token or set_token()); leave it untouched.
    if self.client is not None and self._credentials is None:
        return self.client

    # _get_valid_token() clears self.client when it refreshes the token,
    # which makes the check below rebuild the client with the new token.
    token = await self._get_valid_token()
    if self.client is None:
        self.client = WebClient(token=token)
    return self.client
def set_token(self, token: str) -> None:
    """
    Replace the Slack client with one built from the given token.

    Kept for backward compatibility with callers that manage the token
    themselves.

    Args:
        token: Slack API token
    """
    replacement = WebClient(token=token)
    self.client = replacement
def get_all_channels(self, include_private: bool = True) -> list[dict[str, Any]]:
async def get_all_channels(
self, include_private: bool = True
) -> list[dict[str, Any]]:
"""
Fetch all channels that the bot has access to, with rate limit handling.
@ -52,8 +230,7 @@ class SlackHistory:
SlackApiError: If there's an unrecoverable error calling the Slack API
RuntimeError: For unexpected errors during channel fetching.
"""
if not self.client:
raise ValueError("Slack client not initialized. Call set_token() first.")
client = await self._ensure_client()
channels_list = [] # Changed from dict to list
types = "public_channel"
@ -72,7 +249,7 @@ class SlackHistory:
time.sleep(3)
current_limit = 1000 # Max limit
api_result = self.client.conversations_list(
api_result = client.conversations_list(
types=types, cursor=next_cursor, limit=current_limit
)
@ -129,7 +306,7 @@ class SlackHistory:
return channels_list
def get_conversation_history(
async def get_conversation_history(
self,
channel_id: str,
limit: int = 1000,
@ -152,8 +329,7 @@ class SlackHistory:
ValueError: If no Slack client has been initialized
SlackApiError: If there's an error calling the Slack API
"""
if not self.client:
raise ValueError("Slack client not initialized. Call set_token() first.")
client = await self._ensure_client()
messages = []
next_cursor = None
@ -177,7 +353,7 @@ class SlackHistory:
current_api_call_successful = False
result = None # Ensure result is defined
try:
result = self.client.conversations_history(**kwargs)
result = client.conversations_history(**kwargs)
current_api_call_successful = True
except SlackApiError as e_history:
if (
@ -252,7 +428,7 @@ class SlackHistory:
except ValueError:
return None
def get_history_by_date_range(
async def get_history_by_date_range(
self, channel_id: str, start_date: str, end_date: str, limit: int = 1000
) -> tuple[list[dict[str, Any]], str | None]:
"""
@ -282,7 +458,7 @@ class SlackHistory:
latest += 86400 # seconds in a day
try:
messages = self.get_conversation_history(
messages = await self.get_conversation_history(
channel_id=channel_id, limit=limit, oldest=oldest, latest=latest
)
return messages, None
@ -291,7 +467,7 @@ class SlackHistory:
except ValueError as e:
return [], str(e)
def get_user_info(self, user_id: str) -> dict[str, Any]:
async def get_user_info(self, user_id: str) -> dict[str, Any]:
"""
Get information about a user.
@ -305,8 +481,7 @@ class SlackHistory:
ValueError: If no Slack client has been initialized
SlackApiError: If there's an error calling the Slack API
"""
if not self.client:
raise ValueError("Slack client not initialized. Call set_token() first.")
client = await self._ensure_client()
while True:
try:
@ -314,7 +489,7 @@ class SlackHistory:
# For now, we are only adding Retry-After as per plan.
# time.sleep(0.6) # Optional: ~100 req/min if ever needed.
result = self.client.users_info(user=user_id)
result = client.users_info(user=user_id)
return result["user"] # Success, return and exit loop implicitly
except SlackApiError as e_user_info:
@ -343,7 +518,7 @@ class SlackHistory:
)
raise general_error from general_error # Re-raise unexpected errors
def format_message(
async def format_message(
self, msg: dict[str, Any], include_user_info: bool = False
) -> dict[str, Any]:
"""
@ -369,9 +544,9 @@ class SlackHistory:
"is_thread": "thread_ts" in msg,
}
if include_user_info and "user" in msg and self.client:
if include_user_info and "user" in msg:
try:
user_info = self.get_user_info(msg["user"])
user_info = await self.get_user_info(msg["user"])
formatted["user_name"] = user_info.get("real_name", "Unknown")
formatted["user_email"] = user_info.get("profile", {}).get("email", "")
except Exception:

View file

@ -23,6 +23,7 @@ from app.db import (
User,
get_async_session,
)
from app.schemas.slack_auth_credentials import SlackAuthCredentialsBase
from app.users import current_active_user
from app.utils.oauth_security import OAuthStateManager, TokenEncryption
@ -229,7 +230,7 @@ async def slack_callback(
)
# Extract bot token from Slack response
# Slack OAuth v2 returns: { "ok": true, "access_token": "...", "bot": { "bot_user_id": "...", "bot_access_token": "xoxb-..." }, ... }
# Slack OAuth v2 returns: { "ok": true, "access_token": "...", "bot": { "bot_user_id": "...", "bot_access_token": "xoxb-..." }, "refresh_token": "...", ... }
bot_token = None
if token_json.get("bot") and token_json["bot"].get("bot_access_token"):
bot_token = token_json["bot"]["bot_access_token"]
@ -241,6 +242,9 @@ async def slack_callback(
status_code=400, detail="No bot token received from Slack"
)
# Extract refresh token if available (for token rotation)
refresh_token = token_json.get("refresh_token")
# Encrypt sensitive tokens before storing
token_encryption = get_token_encryption()
@ -251,9 +255,12 @@ async def slack_callback(
now_utc = datetime.now(UTC)
expires_at = now_utc + timedelta(seconds=int(token_json["expires_in"]))
# Store the encrypted bot token in connector config
# Store the encrypted bot token and refresh token in connector config
connector_config = {
"bot_token": token_encryption.encrypt_token(bot_token),
"refresh_token": token_encryption.encrypt_token(refresh_token)
if refresh_token
else None,
"bot_user_id": token_json.get("bot", {}).get("bot_user_id"),
"team_id": token_json.get("team", {}).get("id"),
"team_name": token_json.get("team", {}).get("name"),
@ -334,3 +341,138 @@ async def slack_callback(
raise HTTPException(
status_code=500, detail=f"Failed to complete Slack OAuth: {e!s}"
) from e
async def refresh_slack_token(
    session: AsyncSession, connector: SearchSourceConnector
) -> SearchSourceConnector:
    """
    Refresh the Slack bot token for a connector.

    Exchanges the stored refresh token for a new bot token, stores the
    new tokens encrypted in ``connector.config``, then commits and
    refreshes the connector.

    Args:
        session: Database session
        connector: Slack connector to refresh

    Returns:
        Updated connector object

    Raises:
        HTTPException: 400 when no refresh token is stored or Slack
            rejects the refresh; 500 when the stored token cannot be
            decrypted or any other error occurs.
    """
    try:
        logger.info(f"Refreshing Slack token for connector {connector.id}")

        stored_config = connector.config
        credentials = SlackAuthCredentialsBase.from_dict(stored_config)
        token_encryption = get_token_encryption()

        # Decrypt the refresh token if the stored config is encrypted
        refresh_token = credentials.refresh_token
        if refresh_token and stored_config.get("_token_encrypted", False):
            try:
                refresh_token = token_encryption.decrypt_token(refresh_token)
            except Exception as e:
                logger.error(f"Failed to decrypt refresh token: {e!s}")
                raise HTTPException(
                    status_code=500, detail="Failed to decrypt stored refresh token"
                ) from e

        if not refresh_token:
            raise HTTPException(
                status_code=400,
                detail="No refresh token available. Please re-authenticate.",
            )

        # Slack uses oauth.v2.access for token refresh with grant_type=refresh_token
        refresh_data = {
            "client_id": config.SLACK_CLIENT_ID,
            "client_secret": config.SLACK_CLIENT_SECRET,
            "grant_type": "refresh_token",
            "refresh_token": refresh_token,
        }

        async with httpx.AsyncClient() as client:
            token_response = await client.post(
                TOKEN_URL,
                data=refresh_data,
                headers={"Content-Type": "application/x-www-form-urlencoded"},
                timeout=30.0,
            )

        if token_response.status_code != 200:
            error_detail = token_response.text
            try:
                error_detail = token_response.json().get("error", error_detail)
            except Exception:
                # Body was not a JSON object; keep the raw text as the detail.
                pass
            raise HTTPException(
                status_code=400, detail=f"Token refresh failed: {error_detail}"
            )

        token_json = token_response.json()

        # Slack OAuth v2 returns success status in the JSON
        if not token_json.get("ok", False):
            error_msg = token_json.get("error", "Unknown error")
            raise HTTPException(
                status_code=400, detail=f"Slack OAuth refresh error: {error_msg}"
            )

        # Extract bot token from refresh response
        bot_token = (token_json.get("bot") or {}).get(
            "bot_access_token"
        ) or token_json.get("access_token")
        if not bot_token:
            raise HTTPException(
                status_code=400, detail="No bot token received from Slack refresh"
            )

        # Calculate expiration time (UTC, tz-aware)
        expires_at = None
        expires_in = token_json.get("expires_in")
        if expires_in:
            expires_at = datetime.now(UTC) + timedelta(seconds=int(expires_in))

        # Everything written back is marked _token_encrypted, so both
        # tokens must be ciphertext. When Slack does not rotate the refresh
        # token, re-encrypt the current one: it may have been stored in
        # plaintext if the previous config was not encrypted.
        credentials.bot_token = token_encryption.encrypt_token(bot_token)
        credentials.refresh_token = token_encryption.encrypt_token(
            token_json.get("refresh_token") or refresh_token
        )
        credentials.expires_in = expires_in
        credentials.expires_at = expires_at
        # Keep the previously stored scope if the response omits it.
        credentials.scope = token_json.get("scope") or credentials.scope

        # Merge over the existing config so keys the credentials model does
        # not know about (and the team/bot identifiers) are preserved.
        connector.config = {
            **stored_config,
            **credentials.to_dict(),
            "_token_encrypted": True,
        }
        await session.commit()
        await session.refresh(connector)

        logger.info(f"Successfully refreshed Slack token for connector {connector.id}")

        return connector
    except HTTPException:
        raise
    except Exception as e:
        logger.error(
            f"Failed to refresh Slack token for connector {connector.id}: {e!s}",
            exc_info=True,
        )
        raise HTTPException(
            status_code=500, detail=f"Failed to refresh Slack token: {e!s}"
        ) from e

View file

@ -0,0 +1,76 @@
from datetime import UTC, datetime
from pydantic import BaseModel, field_validator
class SlackAuthCredentialsBase(BaseModel):
    """Slack OAuth credentials as stored in a connector's config."""

    bot_token: str
    refresh_token: str | None = None
    token_type: str = "Bearer"
    expires_in: int | None = None
    expires_at: datetime | None = None
    scope: str | None = None
    bot_user_id: str | None = None
    team_id: str | None = None
    team_name: str | None = None

    @field_validator("expires_at", mode="before")
    @classmethod
    def ensure_aware_utc(cls, v):
        """Normalize ``expires_at`` to a timezone-aware datetime.

        ISO strings are parsed (a trailing "Z" is read as +00:00) and
        naive values are tagged as UTC. Other values pass through unchanged.
        """
        # Strings like "2025-08-26T14:46:57.367184"
        if isinstance(v, str):
            if v.endswith("Z"):
                return datetime.fromisoformat(v.replace("Z", "+00:00"))
            dt = datetime.fromisoformat(v)
            # add UTC if tz info is missing
            return dt if dt.tzinfo else dt.replace(tzinfo=UTC)
        # datetime objects
        if isinstance(v, datetime):
            return v if v.tzinfo else v.replace(tzinfo=UTC)
        return v

    @property
    def is_expired(self) -> bool:
        """Check if the credentials have expired."""
        if self.expires_at is None:
            return False  # Long-lived token, treat as not expired
        return self.expires_at <= datetime.now(UTC)

    @property
    def is_refreshable(self) -> bool:
        """Check if the credentials can be refreshed."""
        return self.refresh_token is not None

    def to_dict(self) -> dict:
        """Convert credentials to dictionary for storage."""
        return {
            "bot_token": self.bot_token,
            "refresh_token": self.refresh_token,
            "token_type": self.token_type,
            "expires_in": self.expires_in,
            "expires_at": self.expires_at.isoformat() if self.expires_at else None,
            "scope": self.scope,
            "bot_user_id": self.bot_user_id,
            "team_id": self.team_id,
            "team_name": self.team_name,
        }

    @classmethod
    def from_dict(cls, data: dict) -> "SlackAuthCredentialsBase":
        """Create credentials from dictionary.

        Missing or null ``bot_token`` / ``token_type`` fall back to their
        defaults. ``expires_at`` may be an ISO string or a datetime; it is
        passed straight to the ``expires_at`` validator, so parsing lives
        in one place.
        """
        return cls(
            bot_token=data.get("bot_token") or "",
            refresh_token=data.get("refresh_token"),
            token_type=data.get("token_type") or "Bearer",
            expires_in=data.get("expires_in"),
            # Falsy values (missing, None, "") mean "no expiry".
            expires_at=data.get("expires_at") or None,
            scope=data.get("scope"),
            bot_user_id=data.get("bot_user_id"),
            team_id=data.get("team_id"),
            team_name=data.get("team_name"),
        )

View file

@ -17,7 +17,6 @@ from app.utils.document_converters import (
generate_content_hash,
generate_unique_identifier_hash,
)
from app.utils.oauth_security import TokenEncryption
from .base import (
build_document_metadata_markdown,
@ -93,44 +92,20 @@ async def index_slack_messages(
f"Connector with ID {connector_id} not found or is not a Slack connector",
)
# Get the Slack token from the connector config
# Support both new OAuth format (bot_token) and old API format (SLACK_BOT_TOKEN)
config_data = connector.config.copy()
slack_token = config_data.get("bot_token") or config_data.get("SLACK_BOT_TOKEN")
# Note: Token handling is now done automatically by SlackHistory
# with auto-refresh support. We just need to pass session and connector_id.
if not slack_token:
await task_logger.log_task_failure(
log_entry,
f"Slack token not found in connector config for connector {connector_id}",
"Missing Slack token",
{"error_type": "MissingToken"},
)
return 0, "Slack token not found in connector config"
# Decrypt token if it's encrypted (OAuth format)
token_encrypted = config_data.get("_token_encrypted", False)
if token_encrypted and config.SECRET_KEY:
try:
token_encryption = TokenEncryption(config.SECRET_KEY)
slack_token = token_encryption.decrypt_token(slack_token)
logger.info(f"Decrypted Slack bot token for connector {connector_id}")
except Exception as e:
await task_logger.log_task_failure(
log_entry,
f"Failed to decrypt Slack token for connector {connector_id}: {e!s}",
"Token decryption failed",
{"error_type": "TokenDecryptionError"},
)
return 0, f"Failed to decrypt Slack token: {e!s}"
# Initialize Slack client
# Initialize Slack client with auto-refresh support
await task_logger.log_task_progress(
log_entry,
f"Initializing Slack client for connector {connector_id}",
{"stage": "client_initialization"},
)
slack_client = SlackHistory(token=slack_token)
# Use the new pattern with session and connector_id for auto-refresh
slack_client = SlackHistory(
session=session, connector_id=connector_id
)
# Handle 'undefined' string from frontend (treat as None)
if start_date == "undefined" or start_date == "":
@ -167,7 +142,7 @@ async def index_slack_messages(
# Get all channels
try:
channels = slack_client.get_all_channels()
channels = await slack_client.get_all_channels()
except Exception as e:
await task_logger.log_task_failure(
log_entry,
@ -216,7 +191,7 @@ async def index_slack_messages(
continue
# Get messages for this channel
messages, error = slack_client.get_history_by_date_range(
messages, error = await slack_client.get_history_by_date_range(
channel_id=channel_id,
start_date=start_date_str,
end_date=end_date_str,
@ -249,7 +224,7 @@ async def index_slack_messages(
]:
continue
formatted_msg = slack_client.format_message(
formatted_msg = await slack_client.format_message(
msg, include_user_info=True
)
formatted_messages.append(formatted_msg)