mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-02 04:12:47 +02:00
feat: database driven refresh tokens for slack oauth connector
This commit is contained in:
parent
0fe94bfcf3
commit
81e4a4ada0
4 changed files with 426 additions and 58 deletions
|
|
@ -12,6 +12,14 @@ from typing import Any
|
||||||
|
|
||||||
from slack_sdk import WebClient
|
from slack_sdk import WebClient
|
||||||
from slack_sdk.errors import SlackApiError
|
from slack_sdk.errors import SlackApiError
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy.future import select
|
||||||
|
|
||||||
|
from app.config import config
|
||||||
|
from app.db import SearchSourceConnector
|
||||||
|
from app.routes.slack_add_connector_route import refresh_slack_token
|
||||||
|
from app.schemas.slack_auth_credentials import SlackAuthCredentialsBase
|
||||||
|
from app.utils.oauth_security import TokenEncryption
|
||||||
|
|
||||||
logger = logging.getLogger(__name__) # Added logger
|
logger = logging.getLogger(__name__) # Added logger
|
||||||
|
|
||||||
|
|
@ -19,25 +27,195 @@ logger = logging.getLogger(__name__) # Added logger
|
||||||
class SlackHistory:
|
class SlackHistory:
|
||||||
"""Class for retrieving conversation history from Slack channels."""
|
"""Class for retrieving conversation history from Slack channels."""
|
||||||
|
|
||||||
def __init__(self, token: str | None = None):
|
def __init__(
|
||||||
|
self,
|
||||||
|
token: str | None = None,
|
||||||
|
session: AsyncSession | None = None,
|
||||||
|
connector_id: int | None = None,
|
||||||
|
credentials: SlackAuthCredentialsBase | None = None,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Initialize the SlackHistory class.
|
Initialize the SlackHistory class.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
token: Slack API token (optional, can be set later with set_token)
|
token: Slack API token (optional, for backward compatibility)
|
||||||
|
session: Database session for token refresh (optional)
|
||||||
|
connector_id: Connector ID for token refresh (optional)
|
||||||
|
credentials: Slack OAuth credentials (optional, will be loaded from DB if not provided)
|
||||||
"""
|
"""
|
||||||
self.client = WebClient(token=token) if token else None
|
self._session = session
|
||||||
|
self._connector_id = connector_id
|
||||||
|
self._credentials = credentials
|
||||||
|
# For backward compatibility, if token is provided directly, use it
|
||||||
|
if token:
|
||||||
|
self.client = WebClient(token=token)
|
||||||
|
else:
|
||||||
|
self.client = None
|
||||||
|
|
||||||
|
async def _get_valid_token(self) -> str:
|
||||||
|
"""
|
||||||
|
Get valid Slack bot token, refreshing if needed.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Valid bot token
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If credentials are missing or invalid
|
||||||
|
Exception: If token refresh fails
|
||||||
|
"""
|
||||||
|
# If we have a direct token (backward compatibility), use it
|
||||||
|
# Check if client was initialized with a token directly (not via credentials)
|
||||||
|
if self.client and self._session is None and self._connector_id is None:
|
||||||
|
# This means it was initialized with a direct token, extract it
|
||||||
|
# WebClient stores token internally, we need to get it from the client
|
||||||
|
# For backward compatibility, we'll use the client directly
|
||||||
|
# But we can't easily extract the token, so we'll just use the client
|
||||||
|
# In this case, we'll skip refresh logic
|
||||||
|
if self._credentials is None:
|
||||||
|
# This is the old pattern - just use the client as-is
|
||||||
|
# We can't extract token easily, so we'll raise an error
|
||||||
|
# asking to use the new pattern
|
||||||
|
raise ValueError(
|
||||||
|
"Cannot refresh token: Please use session and connector_id for auto-refresh support"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Load credentials from DB if not provided
|
||||||
|
if self._credentials is None:
|
||||||
|
if not self._session or not self._connector_id:
|
||||||
|
raise ValueError(
|
||||||
|
"Cannot load credentials: session and connector_id required"
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await self._session.execute(
|
||||||
|
select(SearchSourceConnector).filter(
|
||||||
|
SearchSourceConnector.id == self._connector_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
connector = result.scalars().first()
|
||||||
|
|
||||||
|
if not connector:
|
||||||
|
raise ValueError(f"Connector {self._connector_id} not found")
|
||||||
|
|
||||||
|
config_data = connector.config.copy()
|
||||||
|
|
||||||
|
# Decrypt credentials if they are encrypted
|
||||||
|
token_encrypted = config_data.get("_token_encrypted", False)
|
||||||
|
if token_encrypted and config.SECRET_KEY:
|
||||||
|
try:
|
||||||
|
token_encryption = TokenEncryption(config.SECRET_KEY)
|
||||||
|
|
||||||
|
# Decrypt sensitive fields
|
||||||
|
if config_data.get("bot_token"):
|
||||||
|
config_data["bot_token"] = token_encryption.decrypt_token(
|
||||||
|
config_data["bot_token"]
|
||||||
|
)
|
||||||
|
if config_data.get("refresh_token"):
|
||||||
|
config_data["refresh_token"] = token_encryption.decrypt_token(
|
||||||
|
config_data["refresh_token"]
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Decrypted Slack credentials for connector {self._connector_id}"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"Failed to decrypt Slack credentials for connector {self._connector_id}: {e!s}"
|
||||||
|
)
|
||||||
|
raise ValueError(
|
||||||
|
f"Failed to decrypt Slack credentials: {e!s}"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
try:
|
||||||
|
self._credentials = SlackAuthCredentialsBase.from_dict(config_data)
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f"Invalid Slack credentials: {e!s}") from e
|
||||||
|
|
||||||
|
# Check if token is expired and refreshable
|
||||||
|
if self._credentials.is_expired and self._credentials.is_refreshable:
|
||||||
|
try:
|
||||||
|
logger.info(
|
||||||
|
f"Slack token expired for connector {self._connector_id}, refreshing..."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get connector for refresh
|
||||||
|
result = await self._session.execute(
|
||||||
|
select(SearchSourceConnector).filter(
|
||||||
|
SearchSourceConnector.id == self._connector_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
connector = result.scalars().first()
|
||||||
|
|
||||||
|
if not connector:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Connector {self._connector_id} not found; cannot refresh token."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Refresh token
|
||||||
|
connector = await refresh_slack_token(self._session, connector)
|
||||||
|
|
||||||
|
# Reload credentials after refresh
|
||||||
|
config_data = connector.config.copy()
|
||||||
|
token_encrypted = config_data.get("_token_encrypted", False)
|
||||||
|
if token_encrypted and config.SECRET_KEY:
|
||||||
|
token_encryption = TokenEncryption(config.SECRET_KEY)
|
||||||
|
if config_data.get("bot_token"):
|
||||||
|
config_data["bot_token"] = token_encryption.decrypt_token(
|
||||||
|
config_data["bot_token"]
|
||||||
|
)
|
||||||
|
if config_data.get("refresh_token"):
|
||||||
|
config_data["refresh_token"] = token_encryption.decrypt_token(
|
||||||
|
config_data["refresh_token"]
|
||||||
|
)
|
||||||
|
|
||||||
|
self._credentials = SlackAuthCredentialsBase.from_dict(config_data)
|
||||||
|
|
||||||
|
# Invalidate cached client so it's recreated with new token
|
||||||
|
self.client = None
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Successfully refreshed Slack token for connector {self._connector_id}"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"Failed to refresh Slack token for connector {self._connector_id}: {e!s}"
|
||||||
|
)
|
||||||
|
raise Exception(
|
||||||
|
f"Failed to refresh Slack OAuth credentials: {e!s}"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
return self._credentials.bot_token
|
||||||
|
|
||||||
|
async def _ensure_client(self) -> WebClient:
|
||||||
|
"""
|
||||||
|
Ensure Slack client is initialized with valid token.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
WebClient instance
|
||||||
|
"""
|
||||||
|
# If client was initialized with direct token (backward compatibility), use it
|
||||||
|
if self.client and (self._session is None or self._connector_id is None):
|
||||||
|
return self.client
|
||||||
|
|
||||||
|
# Otherwise, initialize with token from credentials (with auto-refresh)
|
||||||
|
if self.client is None:
|
||||||
|
token = await self._get_valid_token()
|
||||||
|
# Skip if it's the placeholder for direct token initialization
|
||||||
|
if token != "direct_token_initialized":
|
||||||
|
self.client = WebClient(token=token)
|
||||||
|
return self.client
|
||||||
|
|
||||||
def set_token(self, token: str) -> None:
|
def set_token(self, token: str) -> None:
|
||||||
"""
|
"""
|
||||||
Set the Slack API token.
|
Set the Slack API token (for backward compatibility).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
token: Slack API token
|
token: Slack API token
|
||||||
"""
|
"""
|
||||||
self.client = WebClient(token=token)
|
self.client = WebClient(token=token)
|
||||||
|
|
||||||
def get_all_channels(self, include_private: bool = True) -> list[dict[str, Any]]:
|
async def get_all_channels(
|
||||||
|
self, include_private: bool = True
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
Fetch all channels that the bot has access to, with rate limit handling.
|
Fetch all channels that the bot has access to, with rate limit handling.
|
||||||
|
|
||||||
|
|
@ -52,8 +230,7 @@ class SlackHistory:
|
||||||
SlackApiError: If there's an unrecoverable error calling the Slack API
|
SlackApiError: If there's an unrecoverable error calling the Slack API
|
||||||
RuntimeError: For unexpected errors during channel fetching.
|
RuntimeError: For unexpected errors during channel fetching.
|
||||||
"""
|
"""
|
||||||
if not self.client:
|
client = await self._ensure_client()
|
||||||
raise ValueError("Slack client not initialized. Call set_token() first.")
|
|
||||||
|
|
||||||
channels_list = [] # Changed from dict to list
|
channels_list = [] # Changed from dict to list
|
||||||
types = "public_channel"
|
types = "public_channel"
|
||||||
|
|
@ -72,7 +249,7 @@ class SlackHistory:
|
||||||
time.sleep(3)
|
time.sleep(3)
|
||||||
|
|
||||||
current_limit = 1000 # Max limit
|
current_limit = 1000 # Max limit
|
||||||
api_result = self.client.conversations_list(
|
api_result = client.conversations_list(
|
||||||
types=types, cursor=next_cursor, limit=current_limit
|
types=types, cursor=next_cursor, limit=current_limit
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -129,7 +306,7 @@ class SlackHistory:
|
||||||
|
|
||||||
return channels_list
|
return channels_list
|
||||||
|
|
||||||
def get_conversation_history(
|
async def get_conversation_history(
|
||||||
self,
|
self,
|
||||||
channel_id: str,
|
channel_id: str,
|
||||||
limit: int = 1000,
|
limit: int = 1000,
|
||||||
|
|
@ -152,8 +329,7 @@ class SlackHistory:
|
||||||
ValueError: If no Slack client has been initialized
|
ValueError: If no Slack client has been initialized
|
||||||
SlackApiError: If there's an error calling the Slack API
|
SlackApiError: If there's an error calling the Slack API
|
||||||
"""
|
"""
|
||||||
if not self.client:
|
client = await self._ensure_client()
|
||||||
raise ValueError("Slack client not initialized. Call set_token() first.")
|
|
||||||
|
|
||||||
messages = []
|
messages = []
|
||||||
next_cursor = None
|
next_cursor = None
|
||||||
|
|
@ -177,7 +353,7 @@ class SlackHistory:
|
||||||
current_api_call_successful = False
|
current_api_call_successful = False
|
||||||
result = None # Ensure result is defined
|
result = None # Ensure result is defined
|
||||||
try:
|
try:
|
||||||
result = self.client.conversations_history(**kwargs)
|
result = client.conversations_history(**kwargs)
|
||||||
current_api_call_successful = True
|
current_api_call_successful = True
|
||||||
except SlackApiError as e_history:
|
except SlackApiError as e_history:
|
||||||
if (
|
if (
|
||||||
|
|
@ -252,7 +428,7 @@ class SlackHistory:
|
||||||
except ValueError:
|
except ValueError:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_history_by_date_range(
|
async def get_history_by_date_range(
|
||||||
self, channel_id: str, start_date: str, end_date: str, limit: int = 1000
|
self, channel_id: str, start_date: str, end_date: str, limit: int = 1000
|
||||||
) -> tuple[list[dict[str, Any]], str | None]:
|
) -> tuple[list[dict[str, Any]], str | None]:
|
||||||
"""
|
"""
|
||||||
|
|
@ -282,7 +458,7 @@ class SlackHistory:
|
||||||
latest += 86400 # seconds in a day
|
latest += 86400 # seconds in a day
|
||||||
|
|
||||||
try:
|
try:
|
||||||
messages = self.get_conversation_history(
|
messages = await self.get_conversation_history(
|
||||||
channel_id=channel_id, limit=limit, oldest=oldest, latest=latest
|
channel_id=channel_id, limit=limit, oldest=oldest, latest=latest
|
||||||
)
|
)
|
||||||
return messages, None
|
return messages, None
|
||||||
|
|
@ -291,7 +467,7 @@ class SlackHistory:
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
return [], str(e)
|
return [], str(e)
|
||||||
|
|
||||||
def get_user_info(self, user_id: str) -> dict[str, Any]:
|
async def get_user_info(self, user_id: str) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Get information about a user.
|
Get information about a user.
|
||||||
|
|
||||||
|
|
@ -305,8 +481,7 @@ class SlackHistory:
|
||||||
ValueError: If no Slack client has been initialized
|
ValueError: If no Slack client has been initialized
|
||||||
SlackApiError: If there's an error calling the Slack API
|
SlackApiError: If there's an error calling the Slack API
|
||||||
"""
|
"""
|
||||||
if not self.client:
|
client = await self._ensure_client()
|
||||||
raise ValueError("Slack client not initialized. Call set_token() first.")
|
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
|
|
@ -314,7 +489,7 @@ class SlackHistory:
|
||||||
# For now, we are only adding Retry-After as per plan.
|
# For now, we are only adding Retry-After as per plan.
|
||||||
# time.sleep(0.6) # Optional: ~100 req/min if ever needed.
|
# time.sleep(0.6) # Optional: ~100 req/min if ever needed.
|
||||||
|
|
||||||
result = self.client.users_info(user=user_id)
|
result = client.users_info(user=user_id)
|
||||||
return result["user"] # Success, return and exit loop implicitly
|
return result["user"] # Success, return and exit loop implicitly
|
||||||
|
|
||||||
except SlackApiError as e_user_info:
|
except SlackApiError as e_user_info:
|
||||||
|
|
@ -343,7 +518,7 @@ class SlackHistory:
|
||||||
)
|
)
|
||||||
raise general_error from general_error # Re-raise unexpected errors
|
raise general_error from general_error # Re-raise unexpected errors
|
||||||
|
|
||||||
def format_message(
|
async def format_message(
|
||||||
self, msg: dict[str, Any], include_user_info: bool = False
|
self, msg: dict[str, Any], include_user_info: bool = False
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
|
|
@ -369,9 +544,9 @@ class SlackHistory:
|
||||||
"is_thread": "thread_ts" in msg,
|
"is_thread": "thread_ts" in msg,
|
||||||
}
|
}
|
||||||
|
|
||||||
if include_user_info and "user" in msg and self.client:
|
if include_user_info and "user" in msg:
|
||||||
try:
|
try:
|
||||||
user_info = self.get_user_info(msg["user"])
|
user_info = await self.get_user_info(msg["user"])
|
||||||
formatted["user_name"] = user_info.get("real_name", "Unknown")
|
formatted["user_name"] = user_info.get("real_name", "Unknown")
|
||||||
formatted["user_email"] = user_info.get("profile", {}).get("email", "")
|
formatted["user_email"] = user_info.get("profile", {}).get("email", "")
|
||||||
except Exception:
|
except Exception:
|
||||||
|
|
|
||||||
|
|
@ -23,6 +23,7 @@ from app.db import (
|
||||||
User,
|
User,
|
||||||
get_async_session,
|
get_async_session,
|
||||||
)
|
)
|
||||||
|
from app.schemas.slack_auth_credentials import SlackAuthCredentialsBase
|
||||||
from app.users import current_active_user
|
from app.users import current_active_user
|
||||||
from app.utils.oauth_security import OAuthStateManager, TokenEncryption
|
from app.utils.oauth_security import OAuthStateManager, TokenEncryption
|
||||||
|
|
||||||
|
|
@ -229,7 +230,7 @@ async def slack_callback(
|
||||||
)
|
)
|
||||||
|
|
||||||
# Extract bot token from Slack response
|
# Extract bot token from Slack response
|
||||||
# Slack OAuth v2 returns: { "ok": true, "access_token": "...", "bot": { "bot_user_id": "...", "bot_access_token": "xoxb-..." }, ... }
|
# Slack OAuth v2 returns: { "ok": true, "access_token": "...", "bot": { "bot_user_id": "...", "bot_access_token": "xoxb-..." }, "refresh_token": "...", ... }
|
||||||
bot_token = None
|
bot_token = None
|
||||||
if token_json.get("bot") and token_json["bot"].get("bot_access_token"):
|
if token_json.get("bot") and token_json["bot"].get("bot_access_token"):
|
||||||
bot_token = token_json["bot"]["bot_access_token"]
|
bot_token = token_json["bot"]["bot_access_token"]
|
||||||
|
|
@ -241,6 +242,9 @@ async def slack_callback(
|
||||||
status_code=400, detail="No bot token received from Slack"
|
status_code=400, detail="No bot token received from Slack"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Extract refresh token if available (for token rotation)
|
||||||
|
refresh_token = token_json.get("refresh_token")
|
||||||
|
|
||||||
# Encrypt sensitive tokens before storing
|
# Encrypt sensitive tokens before storing
|
||||||
token_encryption = get_token_encryption()
|
token_encryption = get_token_encryption()
|
||||||
|
|
||||||
|
|
@ -251,9 +255,12 @@ async def slack_callback(
|
||||||
now_utc = datetime.now(UTC)
|
now_utc = datetime.now(UTC)
|
||||||
expires_at = now_utc + timedelta(seconds=int(token_json["expires_in"]))
|
expires_at = now_utc + timedelta(seconds=int(token_json["expires_in"]))
|
||||||
|
|
||||||
# Store the encrypted bot token in connector config
|
# Store the encrypted bot token and refresh token in connector config
|
||||||
connector_config = {
|
connector_config = {
|
||||||
"bot_token": token_encryption.encrypt_token(bot_token),
|
"bot_token": token_encryption.encrypt_token(bot_token),
|
||||||
|
"refresh_token": token_encryption.encrypt_token(refresh_token)
|
||||||
|
if refresh_token
|
||||||
|
else None,
|
||||||
"bot_user_id": token_json.get("bot", {}).get("bot_user_id"),
|
"bot_user_id": token_json.get("bot", {}).get("bot_user_id"),
|
||||||
"team_id": token_json.get("team", {}).get("id"),
|
"team_id": token_json.get("team", {}).get("id"),
|
||||||
"team_name": token_json.get("team", {}).get("name"),
|
"team_name": token_json.get("team", {}).get("name"),
|
||||||
|
|
@ -334,3 +341,138 @@ async def slack_callback(
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=500, detail=f"Failed to complete Slack OAuth: {e!s}"
|
status_code=500, detail=f"Failed to complete Slack OAuth: {e!s}"
|
||||||
) from e
|
) from e
|
||||||
|
|
||||||
|
|
||||||
|
async def refresh_slack_token(
|
||||||
|
session: AsyncSession, connector: SearchSourceConnector
|
||||||
|
) -> SearchSourceConnector:
|
||||||
|
"""
|
||||||
|
Refresh the Slack bot token for a connector.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session: Database session
|
||||||
|
connector: Slack connector to refresh
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Updated connector object
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
logger.info(f"Refreshing Slack token for connector {connector.id}")
|
||||||
|
|
||||||
|
credentials = SlackAuthCredentialsBase.from_dict(connector.config)
|
||||||
|
|
||||||
|
# Decrypt tokens if they are encrypted
|
||||||
|
token_encryption = get_token_encryption()
|
||||||
|
is_encrypted = connector.config.get("_token_encrypted", False)
|
||||||
|
|
||||||
|
refresh_token = credentials.refresh_token
|
||||||
|
if is_encrypted and refresh_token:
|
||||||
|
try:
|
||||||
|
refresh_token = token_encryption.decrypt_token(refresh_token)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to decrypt refresh token: {e!s}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500, detail="Failed to decrypt stored refresh token"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
if not refresh_token:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail="No refresh token available. Please re-authenticate.",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Slack uses oauth.v2.access for token refresh with grant_type=refresh_token
|
||||||
|
refresh_data = {
|
||||||
|
"client_id": config.SLACK_CLIENT_ID,
|
||||||
|
"client_secret": config.SLACK_CLIENT_SECRET,
|
||||||
|
"grant_type": "refresh_token",
|
||||||
|
"refresh_token": refresh_token,
|
||||||
|
}
|
||||||
|
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
token_response = await client.post(
|
||||||
|
TOKEN_URL,
|
||||||
|
data=refresh_data,
|
||||||
|
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||||
|
timeout=30.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
if token_response.status_code != 200:
|
||||||
|
error_detail = token_response.text
|
||||||
|
try:
|
||||||
|
error_json = token_response.json()
|
||||||
|
error_detail = error_json.get("error", error_detail)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400, detail=f"Token refresh failed: {error_detail}"
|
||||||
|
)
|
||||||
|
|
||||||
|
token_json = token_response.json()
|
||||||
|
|
||||||
|
# Slack OAuth v2 returns success status in the JSON
|
||||||
|
if not token_json.get("ok", False):
|
||||||
|
error_msg = token_json.get("error", "Unknown error")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400, detail=f"Slack OAuth refresh error: {error_msg}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Extract bot token from refresh response
|
||||||
|
bot_token = None
|
||||||
|
if token_json.get("bot") and token_json["bot"].get("bot_access_token"):
|
||||||
|
bot_token = token_json["bot"]["bot_access_token"]
|
||||||
|
elif token_json.get("access_token"):
|
||||||
|
bot_token = token_json["access_token"]
|
||||||
|
else:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400, detail="No bot token received from Slack refresh"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get new refresh token if provided (Slack may rotate refresh tokens)
|
||||||
|
new_refresh_token = token_json.get("refresh_token")
|
||||||
|
|
||||||
|
# Calculate expiration time (UTC, tz-aware)
|
||||||
|
expires_at = None
|
||||||
|
expires_in = token_json.get("expires_in")
|
||||||
|
if expires_in:
|
||||||
|
now_utc = datetime.now(UTC)
|
||||||
|
expires_at = now_utc + timedelta(seconds=int(expires_in))
|
||||||
|
|
||||||
|
# Update credentials object with encrypted tokens
|
||||||
|
credentials.bot_token = token_encryption.encrypt_token(bot_token)
|
||||||
|
if new_refresh_token:
|
||||||
|
credentials.refresh_token = token_encryption.encrypt_token(
|
||||||
|
new_refresh_token
|
||||||
|
)
|
||||||
|
credentials.expires_in = expires_in
|
||||||
|
credentials.expires_at = expires_at
|
||||||
|
credentials.scope = token_json.get("scope")
|
||||||
|
|
||||||
|
# Preserve team info
|
||||||
|
if not credentials.team_id:
|
||||||
|
credentials.team_id = connector.config.get("team_id")
|
||||||
|
if not credentials.team_name:
|
||||||
|
credentials.team_name = connector.config.get("team_name")
|
||||||
|
if not credentials.bot_user_id:
|
||||||
|
credentials.bot_user_id = connector.config.get("bot_user_id")
|
||||||
|
|
||||||
|
# Update connector config with encrypted tokens
|
||||||
|
credentials_dict = credentials.to_dict()
|
||||||
|
credentials_dict["_token_encrypted"] = True
|
||||||
|
connector.config = credentials_dict
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(connector)
|
||||||
|
|
||||||
|
logger.info(f"Successfully refreshed Slack token for connector {connector.id}")
|
||||||
|
|
||||||
|
return connector
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"Failed to refresh Slack token for connector {connector.id}: {e!s}",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500, detail=f"Failed to refresh Slack token: {e!s}"
|
||||||
|
) from e
|
||||||
|
|
|
||||||
76
surfsense_backend/app/schemas/slack_auth_credentials.py
Normal file
76
surfsense_backend/app/schemas/slack_auth_credentials.py
Normal file
|
|
@ -0,0 +1,76 @@
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
|
||||||
|
from pydantic import BaseModel, field_validator
|
||||||
|
|
||||||
|
|
||||||
|
class SlackAuthCredentialsBase(BaseModel):
|
||||||
|
bot_token: str
|
||||||
|
refresh_token: str | None = None
|
||||||
|
token_type: str = "Bearer"
|
||||||
|
expires_in: int | None = None
|
||||||
|
expires_at: datetime | None = None
|
||||||
|
scope: str | None = None
|
||||||
|
bot_user_id: str | None = None
|
||||||
|
team_id: str | None = None
|
||||||
|
team_name: str | None = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_expired(self) -> bool:
|
||||||
|
"""Check if the credentials have expired."""
|
||||||
|
if self.expires_at is None:
|
||||||
|
return False # Long-lived token, treat as not expired
|
||||||
|
return self.expires_at <= datetime.now(UTC)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_refreshable(self) -> bool:
|
||||||
|
"""Check if the credentials can be refreshed."""
|
||||||
|
return self.refresh_token is not None
|
||||||
|
|
||||||
|
def to_dict(self) -> dict:
|
||||||
|
"""Convert credentials to dictionary for storage."""
|
||||||
|
return {
|
||||||
|
"bot_token": self.bot_token,
|
||||||
|
"refresh_token": self.refresh_token,
|
||||||
|
"token_type": self.token_type,
|
||||||
|
"expires_in": self.expires_in,
|
||||||
|
"expires_at": self.expires_at.isoformat() if self.expires_at else None,
|
||||||
|
"scope": self.scope,
|
||||||
|
"bot_user_id": self.bot_user_id,
|
||||||
|
"team_id": self.team_id,
|
||||||
|
"team_name": self.team_name,
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, data: dict) -> "SlackAuthCredentialsBase":
|
||||||
|
"""Create credentials from dictionary."""
|
||||||
|
expires_at = None
|
||||||
|
if data.get("expires_at"):
|
||||||
|
expires_at = datetime.fromisoformat(data["expires_at"])
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
bot_token=data.get("bot_token", ""),
|
||||||
|
refresh_token=data.get("refresh_token"),
|
||||||
|
token_type=data.get("token_type", "Bearer"),
|
||||||
|
expires_in=data.get("expires_in"),
|
||||||
|
expires_at=expires_at,
|
||||||
|
scope=data.get("scope"),
|
||||||
|
bot_user_id=data.get("bot_user_id"),
|
||||||
|
team_id=data.get("team_id"),
|
||||||
|
team_name=data.get("team_name"),
|
||||||
|
)
|
||||||
|
|
||||||
|
@field_validator("expires_at", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def ensure_aware_utc(cls, v):
|
||||||
|
# Strings like "2025-08-26T14:46:57.367184"
|
||||||
|
if isinstance(v, str):
|
||||||
|
# add +00:00 if missing tz info
|
||||||
|
if v.endswith("Z"):
|
||||||
|
return datetime.fromisoformat(v.replace("Z", "+00:00"))
|
||||||
|
dt = datetime.fromisoformat(v)
|
||||||
|
return dt if dt.tzinfo else dt.replace(tzinfo=UTC)
|
||||||
|
# datetime objects
|
||||||
|
if isinstance(v, datetime):
|
||||||
|
return v if v.tzinfo else v.replace(tzinfo=UTC)
|
||||||
|
return v
|
||||||
|
|
||||||
|
|
@ -17,7 +17,6 @@ from app.utils.document_converters import (
|
||||||
generate_content_hash,
|
generate_content_hash,
|
||||||
generate_unique_identifier_hash,
|
generate_unique_identifier_hash,
|
||||||
)
|
)
|
||||||
from app.utils.oauth_security import TokenEncryption
|
|
||||||
|
|
||||||
from .base import (
|
from .base import (
|
||||||
build_document_metadata_markdown,
|
build_document_metadata_markdown,
|
||||||
|
|
@ -93,44 +92,20 @@ async def index_slack_messages(
|
||||||
f"Connector with ID {connector_id} not found or is not a Slack connector",
|
f"Connector with ID {connector_id} not found or is not a Slack connector",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get the Slack token from the connector config
|
# Note: Token handling is now done automatically by SlackHistory
|
||||||
# Support both new OAuth format (bot_token) and old API format (SLACK_BOT_TOKEN)
|
# with auto-refresh support. We just need to pass session and connector_id.
|
||||||
config_data = connector.config.copy()
|
|
||||||
slack_token = config_data.get("bot_token") or config_data.get("SLACK_BOT_TOKEN")
|
|
||||||
|
|
||||||
if not slack_token:
|
# Initialize Slack client with auto-refresh support
|
||||||
await task_logger.log_task_failure(
|
|
||||||
log_entry,
|
|
||||||
f"Slack token not found in connector config for connector {connector_id}",
|
|
||||||
"Missing Slack token",
|
|
||||||
{"error_type": "MissingToken"},
|
|
||||||
)
|
|
||||||
return 0, "Slack token not found in connector config"
|
|
||||||
|
|
||||||
# Decrypt token if it's encrypted (OAuth format)
|
|
||||||
token_encrypted = config_data.get("_token_encrypted", False)
|
|
||||||
if token_encrypted and config.SECRET_KEY:
|
|
||||||
try:
|
|
||||||
token_encryption = TokenEncryption(config.SECRET_KEY)
|
|
||||||
slack_token = token_encryption.decrypt_token(slack_token)
|
|
||||||
logger.info(f"Decrypted Slack bot token for connector {connector_id}")
|
|
||||||
except Exception as e:
|
|
||||||
await task_logger.log_task_failure(
|
|
||||||
log_entry,
|
|
||||||
f"Failed to decrypt Slack token for connector {connector_id}: {e!s}",
|
|
||||||
"Token decryption failed",
|
|
||||||
{"error_type": "TokenDecryptionError"},
|
|
||||||
)
|
|
||||||
return 0, f"Failed to decrypt Slack token: {e!s}"
|
|
||||||
|
|
||||||
# Initialize Slack client
|
|
||||||
await task_logger.log_task_progress(
|
await task_logger.log_task_progress(
|
||||||
log_entry,
|
log_entry,
|
||||||
f"Initializing Slack client for connector {connector_id}",
|
f"Initializing Slack client for connector {connector_id}",
|
||||||
{"stage": "client_initialization"},
|
{"stage": "client_initialization"},
|
||||||
)
|
)
|
||||||
|
|
||||||
slack_client = SlackHistory(token=slack_token)
|
# Use the new pattern with session and connector_id for auto-refresh
|
||||||
|
slack_client = SlackHistory(
|
||||||
|
session=session, connector_id=connector_id
|
||||||
|
)
|
||||||
|
|
||||||
# Handle 'undefined' string from frontend (treat as None)
|
# Handle 'undefined' string from frontend (treat as None)
|
||||||
if start_date == "undefined" or start_date == "":
|
if start_date == "undefined" or start_date == "":
|
||||||
|
|
@ -167,7 +142,7 @@ async def index_slack_messages(
|
||||||
|
|
||||||
# Get all channels
|
# Get all channels
|
||||||
try:
|
try:
|
||||||
channels = slack_client.get_all_channels()
|
channels = await slack_client.get_all_channels()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await task_logger.log_task_failure(
|
await task_logger.log_task_failure(
|
||||||
log_entry,
|
log_entry,
|
||||||
|
|
@ -216,7 +191,7 @@ async def index_slack_messages(
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Get messages for this channel
|
# Get messages for this channel
|
||||||
messages, error = slack_client.get_history_by_date_range(
|
messages, error = await slack_client.get_history_by_date_range(
|
||||||
channel_id=channel_id,
|
channel_id=channel_id,
|
||||||
start_date=start_date_str,
|
start_date=start_date_str,
|
||||||
end_date=end_date_str,
|
end_date=end_date_str,
|
||||||
|
|
@ -249,7 +224,7 @@ async def index_slack_messages(
|
||||||
]:
|
]:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
formatted_msg = slack_client.format_message(
|
formatted_msg = await slack_client.format_message(
|
||||||
msg, include_user_info=True
|
msg, include_user_info=True
|
||||||
)
|
)
|
||||||
formatted_messages.append(formatted_msg)
|
formatted_messages.append(formatted_msg)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue