diff --git a/surfsense_backend/.env.example b/surfsense_backend/.env.example index 91a0cb42f..d2c667178 100644 --- a/surfsense_backend/.env.example +++ b/surfsense_backend/.env.example @@ -38,13 +38,33 @@ GOOGLE_OAUTH_CLIENT_SECRET=GOCSV GOOGLE_CALENDAR_REDIRECT_URI=http://localhost:8000/api/v1/auth/google/calendar/connector/callback GOOGLE_GMAIL_REDIRECT_URI=http://localhost:8000/api/v1/auth/google/gmail/connector/callback GOOGLE_DRIVE_REDIRECT_URI=http://localhost:8000/api/v1/auth/google/drive/connector/callback -GOOGLE_DRIVE_REDIRECT_URI=http://localhost:8000/api/v1/auth/google/drive/connector/callback -# Airtable OAuth for Aitable Connector +# OAuth for Aitable Connector AIRTABLE_CLIENT_ID=your_airtable_client_id AIRTABLE_CLIENT_SECRET=your_airtable_client_secret AIRTABLE_REDIRECT_URI=http://localhost:8000/api/v1/auth/airtable/connector/callback +# Discord OAuth Configuration +DISCORD_CLIENT_ID=your_discord_client_id_here +DISCORD_CLIENT_SECRET=your_discord_client_secret_here +DISCORD_REDIRECT_URI=http://localhost:8000/api/v1/auth/discord/connector/callback +DISCORD_BOT_TOKEN=your_bot_token_from_developer_portal + +# OAuth for Linear Connector +LINEAR_CLIENT_ID=your_linear_client_id +LINEAR_CLIENT_SECRET=your_linear_client_secret +LINEAR_REDIRECT_URI=http://localhost:8000/api/v1/auth/linear/connector/callback + +# OAuth for Notion Connector +NOTION_CLIENT_ID=your_notion_client_id +NOTION_CLIENT_SECRET=your_notion_client_secret +NOTION_REDIRECT_URI=http://localhost:8000/api/v1/auth/notion/connector/callback + +# OAuth for Slack connector +SLACK_CLIENT_ID=1234567890.1234567890123 +SLACK_CLIENT_SECRET=abcdefghijklmnopqrstuvwxyz1234567890 +SLACK_REDIRECT_URI=http://localhost:8000/api/v1/auth/slack/connector/callback + # Embedding Model # Examples: # # Get sentence transformers embeddings diff --git a/surfsense_backend/app/config/__init__.py b/surfsense_backend/app/config/__init__.py index 9c503fb18..f65a94cc0 100644 --- a/surfsense_backend/app/config/__init__.py +++ b/surfsense_backend/app/config/__init__.py @@ -90,6 +90,27 @@ class Config: AIRTABLE_CLIENT_SECRET = os.getenv("AIRTABLE_CLIENT_SECRET") AIRTABLE_REDIRECT_URI = os.getenv("AIRTABLE_REDIRECT_URI") + # Notion OAuth + NOTION_CLIENT_ID = os.getenv("NOTION_CLIENT_ID") + NOTION_CLIENT_SECRET = os.getenv("NOTION_CLIENT_SECRET") + NOTION_REDIRECT_URI = os.getenv("NOTION_REDIRECT_URI") + + # Linear OAuth + LINEAR_CLIENT_ID = os.getenv("LINEAR_CLIENT_ID") + LINEAR_CLIENT_SECRET = os.getenv("LINEAR_CLIENT_SECRET") + LINEAR_REDIRECT_URI = os.getenv("LINEAR_REDIRECT_URI") + + # Slack OAuth + SLACK_CLIENT_ID = os.getenv("SLACK_CLIENT_ID") + SLACK_CLIENT_SECRET = os.getenv("SLACK_CLIENT_SECRET") + SLACK_REDIRECT_URI = os.getenv("SLACK_REDIRECT_URI") + + # Discord OAuth + DISCORD_CLIENT_ID = os.getenv("DISCORD_CLIENT_ID") + DISCORD_CLIENT_SECRET = os.getenv("DISCORD_CLIENT_SECRET") + DISCORD_REDIRECT_URI = os.getenv("DISCORD_REDIRECT_URI") + DISCORD_BOT_TOKEN = os.getenv("DISCORD_BOT_TOKEN") + # LLM instances are now managed per-user through the LLMConfig system # Legacy environment variables removed in favor of user-specific configurations diff --git a/surfsense_backend/app/connectors/discord_connector.py b/surfsense_backend/app/connectors/discord_connector.py index 506b463a5..1e12cb9a4 100644 --- a/surfsense_backend/app/connectors/discord_connector.py +++ b/surfsense_backend/app/connectors/discord_connector.py @@ -3,7 +3,7 @@ Discord Connector A module for interacting with Discord's HTTP API to retrieve guilds, channels, and message history. -Requires a Discord bot token. +Supports both direct bot token and OAuth-based authentication with token refresh. """ import asyncio @@ -12,6 +12,14 @@ import logging import discord from discord.ext import commands +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.future import select + +from app.config import config +from app.db import SearchSourceConnector +from app.routes.discord_add_connector_route import refresh_discord_token +from app.schemas.discord_auth_credentials import DiscordAuthCredentialsBase +from app.utils.oauth_security import TokenEncryption logger = logging.getLogger(__name__) @@ -19,12 +27,21 @@ logger = logging.getLogger(__name__) class DiscordConnector(commands.Bot): """Class for retrieving guild, channel, and message history from Discord.""" - def __init__(self, token: str | None = None): + def __init__( + self, + token: str | None = None, + session: AsyncSession | None = None, + connector_id: int | None = None, + credentials: DiscordAuthCredentialsBase | None = None, + ): """ - Initialize the DiscordConnector with a bot token. + Initialize the DiscordConnector with a bot token or OAuth credentials. Args: - token (str): The Discord bot token. + token: Discord bot token (optional, for backward compatibility) + session: Database session for token refresh (optional) + connector_id: Connector ID for token refresh (optional) + credentials: Discord OAuth credentials (optional, will be loaded from DB if not provided) """ intents = discord.Intents.default() intents.guilds = True # Required to fetch guilds and channels @@ -34,7 +51,14 @@ class DiscordConnector(commands.Bot): super().__init__( command_prefix="!", intents=intents ) # command_prefix is required but not strictly used here - self.token = token + self._session = session + self._connector_id = connector_id + self._credentials = credentials + # For backward compatibility, if token is provided directly, use it + if token: + self.token = token + else: + self.token = None self._bot_task = None # Holds the async bot task self._is_running = False # Flag to track if the bot is running @@ -57,12 +81,143 @@ class DiscordConnector(commands.Bot): async def on_resumed(): logger.debug("Bot resumed connection to Discord gateway.") + async def _get_valid_token(self) -> str: + """ + Get valid Discord bot token, refreshing if needed. + + Returns: + Valid bot token + + Raises: + ValueError: If credentials are missing or invalid + Exception: If token refresh fails + """ + # If we have a direct token (backward compatibility), use it + if ( + self.token + and self._session is None + and self._connector_id is None + and self._credentials is None + ): + # This means it was initialized with a direct token, use it + return self.token + + # Load credentials from DB if not provided + if self._credentials is None: + if not self._session or not self._connector_id: + raise ValueError( + "Cannot load credentials: session and connector_id required" + ) + + result = await self._session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == self._connector_id + ) + ) + connector = result.scalars().first() + + if not connector: + raise ValueError(f"Connector {self._connector_id} not found") + + config_data = connector.config.copy() + + # Decrypt credentials if they are encrypted + token_encrypted = config_data.get("_token_encrypted", False) + if token_encrypted and config.SECRET_KEY: + try: + token_encryption = TokenEncryption(config.SECRET_KEY) + + # Decrypt sensitive fields + if config_data.get("bot_token"): + config_data["bot_token"] = token_encryption.decrypt_token( + config_data["bot_token"] + ) + if config_data.get("refresh_token"): + config_data["refresh_token"] = token_encryption.decrypt_token( + config_data["refresh_token"] + ) + + logger.info( + f"Decrypted Discord credentials for connector {self._connector_id}" + ) + except Exception as e: + logger.error( + f"Failed to decrypt Discord credentials for connector {self._connector_id}: {e!s}" + ) + raise ValueError( + f"Failed to decrypt Discord credentials: {e!s}" + ) from e + + try: + self._credentials = DiscordAuthCredentialsBase.from_dict(config_data) + except Exception as e: + raise ValueError(f"Invalid Discord credentials: {e!s}") from e + + # Check if token is expired and refreshable + if self._credentials.is_expired and self._credentials.is_refreshable: + try: + logger.info( + f"Discord token expired for connector {self._connector_id}, refreshing..." + ) + + # Get connector for refresh + result = await self._session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == self._connector_id + ) + ) + connector = result.scalars().first() + + if not connector: + raise RuntimeError( + f"Connector {self._connector_id} not found; cannot refresh token." + ) + + # Refresh token + connector = await refresh_discord_token(self._session, connector) + + # Reload credentials after refresh + config_data = connector.config.copy() + token_encrypted = config_data.get("_token_encrypted", False) + if token_encrypted and config.SECRET_KEY: + token_encryption = TokenEncryption(config.SECRET_KEY) + if config_data.get("bot_token"): + config_data["bot_token"] = token_encryption.decrypt_token( + config_data["bot_token"] + ) + if config_data.get("refresh_token"): + config_data["refresh_token"] = token_encryption.decrypt_token( + config_data["refresh_token"] + ) + + self._credentials = DiscordAuthCredentialsBase.from_dict(config_data) + + logger.info( + f"Successfully refreshed Discord token for connector {self._connector_id}" + ) + except Exception as e: + logger.error( + f"Failed to refresh Discord token for connector {self._connector_id}: {e!s}" + ) + raise Exception( + f"Failed to refresh Discord OAuth credentials: {e!s}" + ) from e + + return self._credentials.bot_token + async def start_bot(self): """Starts the bot to connect to Discord.""" logger.info("Starting Discord bot...") + # Get valid token (with auto-refresh if using OAuth) if not self.token: - raise ValueError("Discord bot token not set. Call set_token(token) first.") + # Try to get token from credentials + try: + self.token = await self._get_valid_token() + except ValueError as e: + raise ValueError( + f"Discord bot token not set. {e!s} Please authenticate via OAuth or provide a token." + ) from e try: if self._is_running: @@ -107,7 +262,7 @@ class DiscordConnector(commands.Bot): def set_token(self, token: str) -> None: """ - Set the discord bot token. + Set the discord bot token (for backward compatibility). Args: token (str): The Discord bot token. diff --git a/surfsense_backend/app/connectors/google_calendar_connector.py b/surfsense_backend/app/connectors/google_calendar_connector.py index 164d230e0..6d389ddd5 100644 --- a/surfsense_backend/app/connectors/google_calendar_connector.py +++ b/surfsense_backend/app/connectors/google_calendar_connector.py @@ -109,7 +109,36 @@ class GoogleCalendarConnector: raise RuntimeError( "GOOGLE_CALENDAR_CONNECTOR connector not found; cannot persist refreshed token." ) - connector.config = json.loads(self._credentials.to_json()) + + # Encrypt sensitive credentials before storing + from app.config import config + from app.utils.oauth_security import TokenEncryption + + creds_dict = json.loads(self._credentials.to_json()) + token_encrypted = connector.config.get("_token_encrypted", False) + + if token_encrypted and config.SECRET_KEY: + token_encryption = TokenEncryption(config.SECRET_KEY) + # Encrypt sensitive fields + if creds_dict.get("token"): + creds_dict["token"] = token_encryption.encrypt_token( + creds_dict["token"] + ) + if creds_dict.get("refresh_token"): + creds_dict["refresh_token"] = ( + token_encryption.encrypt_token( + creds_dict["refresh_token"] + ) + ) + if creds_dict.get("client_secret"): + creds_dict["client_secret"] = ( + token_encryption.encrypt_token( + creds_dict["client_secret"] + ) + ) + creds_dict["_token_encrypted"] = True + + connector.config = creds_dict flag_modified(connector, "config") await self._session.commit() except Exception as e: @@ -182,6 +211,18 @@ class GoogleCalendarConnector: Tuple containing (events list, error message or None) """ try: + # Validate date strings + if not start_date or start_date.lower() in ("undefined", "null", "none"): + return ( + [], + "Invalid start_date: must be a valid date string in YYYY-MM-DD format", + ) + if not end_date or end_date.lower() in ("undefined", "null", "none"): + return ( + [], + "Invalid end_date: must be a valid date string in YYYY-MM-DD format", + ) + service = await self._get_service() # Parse both dates diff --git a/surfsense_backend/app/connectors/google_drive/credentials.py b/surfsense_backend/app/connectors/google_drive/credentials.py index f88486468..7e6335f6d 100644 --- a/surfsense_backend/app/connectors/google_drive/credentials.py +++ b/surfsense_backend/app/connectors/google_drive/credentials.py @@ -1,6 +1,7 @@ """Google Drive OAuth credential management.""" import json +import logging from datetime import datetime from google.auth.transport.requests import Request @@ -9,7 +10,11 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select from sqlalchemy.orm.attributes import flag_modified +from app.config import config from app.db import SearchSourceConnector +from app.utils.oauth_security import TokenEncryption + +logger = logging.getLogger(__name__) async def get_valid_credentials( @@ -38,7 +43,41 @@ async def get_valid_credentials( if not connector: raise ValueError(f"Connector {connector_id} not found") - config_data = connector.config + config_data = ( + connector.config.copy() + ) # Work with a copy to avoid modifying original + + # Decrypt credentials if they are encrypted + token_encrypted = config_data.get("_token_encrypted", False) + if token_encrypted and config.SECRET_KEY: + try: + token_encryption = TokenEncryption(config.SECRET_KEY) + + # Decrypt sensitive fields + if config_data.get("token"): + config_data["token"] = token_encryption.decrypt_token( + config_data["token"] + ) + if config_data.get("refresh_token"): + config_data["refresh_token"] = token_encryption.decrypt_token( + config_data["refresh_token"] + ) + if config_data.get("client_secret"): + config_data["client_secret"] = token_encryption.decrypt_token( + config_data["client_secret"] + ) + + logger.info( + f"Decrypted Google Drive credentials for connector {connector_id}" + ) + except Exception as e: + logger.error( + f"Failed to decrypt Google Drive credentials for connector {connector_id}: {e!s}" + ) + raise ValueError( + f"Failed to decrypt Google Drive credentials: {e!s}" + ) from e + exp = config_data.get("expiry", "").replace("Z", "") if not all( @@ -66,7 +105,29 @@ async def get_valid_credentials( try: credentials.refresh(Request()) - connector.config = json.loads(credentials.to_json()) + creds_dict = json.loads(credentials.to_json()) + + # Encrypt sensitive credentials before storing + token_encrypted = connector.config.get("_token_encrypted", False) + + if token_encrypted and config.SECRET_KEY: + token_encryption = TokenEncryption(config.SECRET_KEY) + # Encrypt sensitive fields + if creds_dict.get("token"): + creds_dict["token"] = token_encryption.encrypt_token( + creds_dict["token"] + ) + if creds_dict.get("refresh_token"): + creds_dict["refresh_token"] = token_encryption.encrypt_token( + creds_dict["refresh_token"] + ) + if creds_dict.get("client_secret"): + creds_dict["client_secret"] = token_encryption.encrypt_token( + creds_dict["client_secret"] + ) + creds_dict["_token_encrypted"] = True + + connector.config = creds_dict flag_modified(connector, "config") await session.commit() diff --git a/surfsense_backend/app/connectors/linear_connector.py b/surfsense_backend/app/connectors/linear_connector.py index b4c54fda3..148aa4d0a 100644 --- a/surfsense_backend/app/connectors/linear_connector.py +++ b/surfsense_backend/app/connectors/linear_connector.py @@ -5,33 +5,153 @@ A module for retrieving issues and comments from Linear. Allows fetching issue lists and their comments with date range filtering. """ +import logging from datetime import datetime from typing import Any import requests +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.future import select + +from app.config import config +from app.db import SearchSourceConnector +from app.routes.linear_add_connector_route import refresh_linear_token +from app.schemas.linear_auth_credentials import LinearAuthCredentialsBase +from app.utils.oauth_security import TokenEncryption + +logger = logging.getLogger(__name__) class LinearConnector: """Class for retrieving issues and comments from Linear.""" - def __init__(self, token: str | None = None): + def __init__( + self, + session: AsyncSession, + connector_id: int, + credentials: LinearAuthCredentialsBase | None = None, + ): """ - Initialize the LinearConnector class. + Initialize the LinearConnector class with auto-refresh capability. Args: - token: Linear API token (optional, can be set later with set_token) + session: Database session for updating connector + connector_id: Connector ID for direct updates + credentials: Linear OAuth credentials (optional, will be loaded from DB if not provided) """ - self.token = token + self._session = session + self._connector_id = connector_id + self._credentials = credentials self.api_url = "https://api.linear.app/graphql" - def set_token(self, token: str) -> None: + async def _get_valid_token(self) -> str: """ - Set the Linear API token. + Get valid Linear access token, refreshing if needed. - Args: - token: Linear API token + Returns: + Valid access token + + Raises: + ValueError: If credentials are missing or invalid + Exception: If token refresh fails """ - self.token = token + # Load credentials from DB if not provided + if self._credentials is None: + result = await self._session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == self._connector_id + ) + ) + connector = result.scalars().first() + + if not connector: + raise ValueError(f"Connector {self._connector_id} not found") + + config_data = connector.config.copy() + + # Decrypt credentials if they are encrypted + token_encrypted = config_data.get("_token_encrypted", False) + if token_encrypted and config.SECRET_KEY: + try: + token_encryption = TokenEncryption(config.SECRET_KEY) + + # Decrypt sensitive fields + if config_data.get("access_token"): + config_data["access_token"] = token_encryption.decrypt_token( + config_data["access_token"] + ) + if config_data.get("refresh_token"): + config_data["refresh_token"] = token_encryption.decrypt_token( + config_data["refresh_token"] + ) + + logger.info( + f"Decrypted Linear credentials for connector {self._connector_id}" + ) + except Exception as e: + logger.error( + f"Failed to decrypt Linear credentials for connector {self._connector_id}: {e!s}" + ) + raise ValueError( + f"Failed to decrypt Linear credentials: {e!s}" + ) from e + + try: + self._credentials = LinearAuthCredentialsBase.from_dict(config_data) + except Exception as e: + raise ValueError(f"Invalid Linear credentials: {e!s}") from e + + # Check if token is expired and refreshable + if self._credentials.is_expired and self._credentials.is_refreshable: + try: + logger.info( + f"Linear token expired for connector {self._connector_id}, refreshing..." + ) + + # Get connector for refresh + result = await self._session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == self._connector_id + ) + ) + connector = result.scalars().first() + + if not connector: + raise RuntimeError( + f"Connector {self._connector_id} not found; cannot refresh token." + ) + + # Refresh token + connector = await refresh_linear_token(self._session, connector) + + # Reload credentials after refresh + config_data = connector.config.copy() + token_encrypted = config_data.get("_token_encrypted", False) + if token_encrypted and config.SECRET_KEY: + token_encryption = TokenEncryption(config.SECRET_KEY) + if config_data.get("access_token"): + config_data["access_token"] = token_encryption.decrypt_token( + config_data["access_token"] + ) + if config_data.get("refresh_token"): + config_data["refresh_token"] = token_encryption.decrypt_token( + config_data["refresh_token"] + ) + + self._credentials = LinearAuthCredentialsBase.from_dict(config_data) + + logger.info( + f"Successfully refreshed Linear token for connector {self._connector_id}" + ) + except Exception as e: + logger.error( + f"Failed to refresh Linear token for connector {self._connector_id}: {e!s}" + ) + raise Exception( + f"Failed to refresh Linear OAuth credentials: {e!s}" + ) from e + + return self._credentials.access_token def get_headers(self) -> dict[str, str]: """ @@ -41,18 +161,26 @@ class LinearConnector: Dictionary of headers Raises: - ValueError: If no Linear token has been set + ValueError: If no Linear access token has been set """ - if not self.token: - raise ValueError("Linear token not initialized. Call set_token() first.") + # This is a synchronous method, but we need async token refresh + # For now, we'll raise an error if called directly + # All API calls should go through execute_graphql_query which handles async refresh + if not self._credentials or not self._credentials.access_token: + raise ValueError( + "Linear access token not initialized. Use execute_graphql_query() method." + ) - return {"Content-Type": "application/json", "Authorization": self.token} + return { + "Content-Type": "application/json", + "Authorization": f"Bearer {self._credentials.access_token}", + } - def execute_graphql_query( + async def execute_graphql_query( self, query: str, variables: dict[str, Any] | None = None ) -> dict[str, Any]: """ - Execute a GraphQL query against the Linear API. + Execute a GraphQL query against the Linear API with automatic token refresh. Args: query: GraphQL query string @@ -62,13 +190,17 @@ class LinearConnector: Response data from the API Raises: - ValueError: If no Linear token has been set + ValueError: If no Linear access token has been set Exception: If the API request fails """ - if not self.token: - raise ValueError("Linear token not initialized. Call set_token() first.") + # Get valid token (refreshes if needed) + access_token = await self._get_valid_token() + + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {access_token}", + } - headers = self.get_headers() payload = {"query": query} if variables: @@ -83,7 +215,9 @@ class LinearConnector: f"Query failed with status code {response.status_code}: {response.text}" ) - def get_all_issues(self, include_comments: bool = True) -> list[dict[str, Any]]: + async def get_all_issues( + self, include_comments: bool = True + ) -> list[dict[str, Any]]: """ Fetch all issues from Linear. @@ -94,7 +228,7 @@ class LinearConnector: List of issue objects Raises: - ValueError: If no Linear token has been set + ValueError: If no Linear access token has been set Exception: If the API request fails """ comments_query = "" @@ -146,7 +280,7 @@ class LinearConnector: }} """ - result = self.execute_graphql_query(query) + result = await self.execute_graphql_query(query) # Extract issues from the response if ( @@ -158,7 +292,7 @@ class LinearConnector: return [] - def get_issues_by_date_range( + async def get_issues_by_date_range( self, start_date: str, end_date: str, include_comments: bool = True ) -> tuple[list[dict[str, Any]], str | None]: """ @@ -172,6 +306,18 @@ class LinearConnector: Returns: Tuple containing (issues list, error message or None) """ + # Validate date strings + if not start_date or start_date.lower() in ("undefined", "null", "none"): + return ( + [], + "Invalid start_date: must be a valid date string in YYYY-MM-DD format", + ) + if not end_date or end_date.lower() in ("undefined", "null", "none"): + return ( + [], + "Invalid end_date: must be a valid date string in YYYY-MM-DD format", + ) + # Convert date strings to ISO format try: # For Linear API: we need to use a more specific format for the filter @@ -258,7 +404,7 @@ class LinearConnector: # Handle pagination to get all issues while has_next_page: variables = {"after": cursor} if cursor else {} - result = self.execute_graphql_query(query, variables) + result = await self.execute_graphql_query(query, variables) # Check for errors if "errors" in result: @@ -446,37 +592,3 @@ class LinearConnector: return dt.strftime("%Y-%m-%d %H:%M:%S") except ValueError: return iso_date - - -# Example usage (uncomment to use): -""" -if __name__ == "__main__": - # Set your token here - token = "YOUR_LINEAR_API_KEY" - - linear = LinearConnector(token) - - try: - # Get all issues with comments - issues = linear.get_all_issues() - print(f"Retrieved {len(issues)} issues") - - # Format and print the first issue as markdown - if issues: - issue_md = linear.format_issue_to_markdown(issues[0]) - print("\nSample Issue in Markdown:\n") - print(issue_md) - - # Get issues by date range - start_date = "2023-01-01" - end_date = "2023-01-31" - date_issues, error = linear.get_issues_by_date_range(start_date, end_date) - - if error: - print(f"Error: {error}") - else: - print(f"\nRetrieved {len(date_issues)} issues from {start_date} to {end_date}") - - except Exception as e: - print(f"Error: {e}") -""" diff --git a/surfsense_backend/app/connectors/notion_history.py b/surfsense_backend/app/connectors/notion_history.py index 81f6642f1..e38218a6e 100644 --- a/surfsense_backend/app/connectors/notion_history.py +++ b/surfsense_backend/app/connectors/notion_history.py @@ -1,19 +1,167 @@ +import logging + from notion_client import AsyncClient +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.future import select + +from app.config import config +from app.db import SearchSourceConnector +from app.routes.notion_add_connector_route import refresh_notion_token +from app.schemas.notion_auth_credentials import NotionAuthCredentialsBase +from app.utils.oauth_security import TokenEncryption + +logger = logging.getLogger(__name__) class NotionHistoryConnector: - def __init__(self, token): + def __init__( + self, + session: AsyncSession, + connector_id: int, + credentials: NotionAuthCredentialsBase | None = None, + ): """ - Initialize the NotionPageFetcher with a token. + Initialize the NotionHistoryConnector with auto-refresh capability. Args: - token (str): Notion integration token + session: Database session for updating connector + connector_id: Connector ID for direct updates + credentials: Notion OAuth credentials (optional, will be loaded from DB if not provided) """ - self.notion = AsyncClient(auth=token) + self._session = session + self._connector_id = connector_id + self._credentials = credentials + self._notion_client: AsyncClient | None = None + + async def _get_valid_token(self) -> str: + """ + Get valid Notion access token, refreshing if needed. + + Returns: + Valid access token + + Raises: + ValueError: If credentials are missing or invalid + Exception: If token refresh fails + """ + # Load credentials from DB if not provided + if self._credentials is None: + result = await self._session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == self._connector_id + ) + ) + connector = result.scalars().first() + + if not connector: + raise ValueError(f"Connector {self._connector_id} not found") + + config_data = connector.config.copy() + + # Decrypt credentials if they are encrypted + token_encrypted = config_data.get("_token_encrypted", False) + if token_encrypted and config.SECRET_KEY: + try: + token_encryption = TokenEncryption(config.SECRET_KEY) + + # Decrypt sensitive fields + if config_data.get("access_token"): + config_data["access_token"] = token_encryption.decrypt_token( + config_data["access_token"] + ) + if config_data.get("refresh_token"): + config_data["refresh_token"] = token_encryption.decrypt_token( + config_data["refresh_token"] + ) + + logger.info( + f"Decrypted Notion credentials for connector {self._connector_id}" + ) + except Exception as e: + logger.error( + f"Failed to decrypt Notion credentials for connector {self._connector_id}: {e!s}" + ) + raise ValueError( + f"Failed to decrypt Notion credentials: {e!s}" + ) from e + + try: + self._credentials = NotionAuthCredentialsBase.from_dict(config_data) + except Exception as e: + raise ValueError(f"Invalid Notion credentials: {e!s}") from e + + # Check if token is expired and refreshable + if self._credentials.is_expired and self._credentials.is_refreshable: + try: + logger.info( + f"Notion token expired for connector {self._connector_id}, refreshing..." + ) + + # Get connector for refresh + result = await self._session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == self._connector_id + ) + ) + connector = result.scalars().first() + + if not connector: + raise RuntimeError( + f"Connector {self._connector_id} not found; cannot refresh token." + ) + + # Refresh token + connector = await refresh_notion_token(self._session, connector) + + # Reload credentials after refresh + config_data = connector.config.copy() + token_encrypted = config_data.get("_token_encrypted", False) + if token_encrypted and config.SECRET_KEY: + token_encryption = TokenEncryption(config.SECRET_KEY) + if config_data.get("access_token"): + config_data["access_token"] = token_encryption.decrypt_token( + config_data["access_token"] + ) + if config_data.get("refresh_token"): + config_data["refresh_token"] = token_encryption.decrypt_token( + config_data["refresh_token"] + ) + + self._credentials = NotionAuthCredentialsBase.from_dict(config_data) + + # Invalidate cached client so it's recreated with new token + self._notion_client = None + + logger.info( + f"Successfully refreshed Notion token for connector {self._connector_id}" + ) + except Exception as e: + logger.error( + f"Failed to refresh Notion token for connector {self._connector_id}: {e!s}" + ) + raise Exception( + f"Failed to refresh Notion OAuth credentials: {e!s}" + ) from e + + return self._credentials.access_token + + async def _get_client(self) -> AsyncClient: + """ + Get or create Notion AsyncClient with valid token. + + Returns: + Notion AsyncClient instance + """ + if self._notion_client is None: + token = await self._get_valid_token() + self._notion_client = AsyncClient(auth=token) + return self._notion_client async def close(self): """Close the async client connection.""" - await self.notion.aclose() + if self._notion_client: + await self._notion_client.aclose() + self._notion_client = None async def __aenter__(self): """Async context manager entry.""" @@ -34,6 +182,8 @@ class NotionHistoryConnector: Returns: list: List of dictionaries containing page data """ + notion = await self._get_client() + # Build the filter for the search # Note: Notion API requires specific filter structure search_params = {} @@ -67,7 +217,7 @@ class NotionHistoryConnector: if cursor: search_params["start_cursor"] = cursor - search_results = await self.notion.search(**search_params) + search_results = await notion.search(**search_params) pages.extend(search_results["results"]) has_more = search_results.get("has_more", False) @@ -125,6 +275,8 @@ class NotionHistoryConnector: Returns: list: List of processed blocks from the page """ + notion = await self._get_client() + blocks = [] has_more = True cursor = None @@ -132,11 +284,11 @@ class NotionHistoryConnector: # Paginate through all blocks while has_more: if cursor: - response = await self.notion.blocks.children.list( + response = await notion.blocks.children.list( block_id=page_id, start_cursor=cursor ) else: - response = await self.notion.blocks.children.list(block_id=page_id) + response = await notion.blocks.children.list(block_id=page_id) blocks.extend(response["results"]) has_more = response["has_more"] @@ -162,6 +314,8 @@ class NotionHistoryConnector: Returns: dict: Processed block with content and children """ + notion = await self._get_client() + block_id = block["id"] block_type = block["type"] @@ -174,9 +328,7 @@ class NotionHistoryConnector: if has_children: # Fetch and process child blocks - children_response = await self.notion.blocks.children.list( - block_id=block_id - ) + children_response = await notion.blocks.children.list(block_id=block_id) for child_block in children_response["results"]: child_blocks.append(await self.process_block(child_block)) diff --git a/surfsense_backend/app/connectors/slack_history.py b/surfsense_backend/app/connectors/slack_history.py index 36160c30b..dbf43bb24 100644 --- a/surfsense_backend/app/connectors/slack_history.py +++ b/surfsense_backend/app/connectors/slack_history.py @@ -12,6 +12,14 @@ from typing import Any from slack_sdk import WebClient from slack_sdk.errors import SlackApiError +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.future import select + +from app.config import config +from app.db import SearchSourceConnector +from app.routes.slack_add_connector_route import refresh_slack_token +from app.schemas.slack_auth_credentials import SlackAuthCredentialsBase +from app.utils.oauth_security import TokenEncryption logger = logging.getLogger(__name__) # Added logger @@ -19,25 +27,199 @@ logger = logging.getLogger(__name__) # Added logger class SlackHistory: """Class for retrieving conversation history from Slack channels.""" - def __init__(self, token: str | None = None): + def __init__( + self, + token: str | None = None, + session: AsyncSession | None = None, + connector_id: int | None = None, + credentials: SlackAuthCredentialsBase | None = None, + ): """ Initialize the SlackHistory class. Args: - token: Slack API token (optional, can be set later with set_token) + token: Slack API token (optional, for backward compatibility) + session: Database session for token refresh (optional) + connector_id: Connector ID for token refresh (optional) + credentials: Slack OAuth credentials (optional, will be loaded from DB if not provided) """ - self.client = WebClient(token=token) if token else None + self._session = session + self._connector_id = connector_id + self._credentials = credentials + # For backward compatibility, if token is provided directly, use it + if token: + self.client = WebClient(token=token) + else: + self.client = None + + async def _get_valid_token(self) -> str: + """ + Get valid Slack bot token, refreshing if needed. + + Returns: + Valid bot token + + Raises: + ValueError: If credentials are missing or invalid + Exception: If token refresh fails + """ + # If we have a direct token (backward compatibility), use it + # Check if client was initialized with a token directly (not via credentials) + if ( + self.client + and self._session is None + and self._connector_id is None + and self._credentials is None + ): + # This means it was initialized with a direct token, extract it + # WebClient stores token internally, we need to get it from the client + # For backward compatibility, we'll use the client directly + # But we can't easily extract the token, so we'll just use the client + # In this case, we'll skip refresh logic + # This is the old pattern - just use the client as-is + # We can't extract token easily, so we'll raise an error + # asking to use the new pattern + raise ValueError( + "Cannot refresh token: Please use session and connector_id for auto-refresh support" + ) + + # Load credentials from DB if not provided + if self._credentials is None: + if not self._session or not self._connector_id: + raise ValueError( + "Cannot load credentials: session and connector_id required" + ) + + result = await self._session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == self._connector_id + ) + ) + connector = result.scalars().first() + + if not connector: + raise ValueError(f"Connector {self._connector_id} not found") + + config_data = connector.config.copy() + + # Decrypt credentials if they are encrypted + token_encrypted = config_data.get("_token_encrypted", False) + if token_encrypted and config.SECRET_KEY: + try: + token_encryption = TokenEncryption(config.SECRET_KEY) + + # Decrypt sensitive fields + if config_data.get("bot_token"): + config_data["bot_token"] = token_encryption.decrypt_token( + config_data["bot_token"] + ) + if config_data.get("refresh_token"): + config_data["refresh_token"] = token_encryption.decrypt_token( + config_data["refresh_token"] + ) + + logger.info( + f"Decrypted Slack credentials for connector {self._connector_id}" + ) + except Exception as e: + logger.error( + f"Failed to decrypt Slack credentials for connector {self._connector_id}: {e!s}" + ) + raise ValueError( + f"Failed to decrypt Slack credentials: {e!s}" + ) from e + + try: + self._credentials = SlackAuthCredentialsBase.from_dict(config_data) + except Exception as e: + raise ValueError(f"Invalid Slack credentials: {e!s}") from e + + # Check if token is expired and refreshable + if self._credentials.is_expired and self._credentials.is_refreshable: + try: + logger.info( + f"Slack token expired for connector {self._connector_id}, refreshing..." + ) + + # Get connector for refresh + result = await self._session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == self._connector_id + ) + ) + connector = result.scalars().first() + + if not connector: + raise RuntimeError( + f"Connector {self._connector_id} not found; cannot refresh token." + ) + + # Refresh token + connector = await refresh_slack_token(self._session, connector) + + # Reload credentials after refresh + config_data = connector.config.copy() + token_encrypted = config_data.get("_token_encrypted", False) + if token_encrypted and config.SECRET_KEY: + token_encryption = TokenEncryption(config.SECRET_KEY) + if config_data.get("bot_token"): + config_data["bot_token"] = token_encryption.decrypt_token( + config_data["bot_token"] + ) + if config_data.get("refresh_token"): + config_data["refresh_token"] = token_encryption.decrypt_token( + config_data["refresh_token"] + ) + + self._credentials = SlackAuthCredentialsBase.from_dict(config_data) + + # Invalidate cached client so it's recreated with new token + self.client = None + + logger.info( + f"Successfully refreshed Slack token for connector {self._connector_id}" + ) + except Exception as e: + logger.error( + f"Failed to refresh Slack token for connector {self._connector_id}: {e!s}" + ) + raise Exception( + f"Failed to refresh Slack OAuth credentials: {e!s}" + ) from e + + return self._credentials.bot_token + + async def _ensure_client(self) -> WebClient: + """ + Ensure Slack client is initialized with valid token. + + Returns: + WebClient instance + """ + # If client was initialized with direct token (backward compatibility), use it + if self.client and (self._session is None or self._connector_id is None): + return self.client + + # Otherwise, initialize with token from credentials (with auto-refresh) + if self.client is None: + token = await self._get_valid_token() + # Skip if it's the placeholder for direct token initialization + if token != "direct_token_initialized": + self.client = WebClient(token=token) + return self.client def set_token(self, token: str) -> None: """ - Set the Slack API token. + Set the Slack API token (for backward compatibility). Args: token: Slack API token """ self.client = WebClient(token=token) - def get_all_channels(self, include_private: bool = True) -> list[dict[str, Any]]: + async def get_all_channels( + self, include_private: bool = True + ) -> list[dict[str, Any]]: """ Fetch all channels that the bot has access to, with rate limit handling. @@ -52,8 +234,7 @@ class SlackHistory: SlackApiError: If there's an unrecoverable error calling the Slack API RuntimeError: For unexpected errors during channel fetching. """ - if not self.client: - raise ValueError("Slack client not initialized. Call set_token() first.") + client = await self._ensure_client() channels_list = [] # Changed from dict to list types = "public_channel" @@ -72,7 +253,7 @@ class SlackHistory: time.sleep(3) current_limit = 1000 # Max limit - api_result = self.client.conversations_list( + api_result = client.conversations_list( types=types, cursor=next_cursor, limit=current_limit ) @@ -129,7 +310,7 @@ class SlackHistory: return channels_list - def get_conversation_history( + async def get_conversation_history( self, channel_id: str, limit: int = 1000, @@ -152,8 +333,7 @@ class SlackHistory: ValueError: If no Slack client has been initialized SlackApiError: If there's an error calling the Slack API """ - if not self.client: - raise ValueError("Slack client not initialized. Call set_token() first.") + client = await self._ensure_client() messages = [] next_cursor = None @@ -177,7 +357,7 @@ class SlackHistory: current_api_call_successful = False result = None # Ensure result is defined try: - result = self.client.conversations_history(**kwargs) + result = client.conversations_history(**kwargs) current_api_call_successful = True except SlackApiError as e_history: if ( @@ -252,7 +432,7 @@ class SlackHistory: except ValueError: return None - def get_history_by_date_range( + async def get_history_by_date_range( self, channel_id: str, start_date: str, end_date: str, limit: int = 1000 ) -> tuple[list[dict[str, Any]], str | None]: """ @@ -282,7 +462,7 @@ class SlackHistory: latest += 86400 # seconds in a day try: - messages = self.get_conversation_history( + messages = await self.get_conversation_history( channel_id=channel_id, limit=limit, oldest=oldest, latest=latest ) return messages, None @@ -291,7 +471,7 @@ class SlackHistory: except ValueError as e: return [], str(e) - def get_user_info(self, user_id: str) -> dict[str, Any]: + async def get_user_info(self, user_id: str) -> dict[str, Any]: """ Get information about a user. @@ -305,8 +485,7 @@ class SlackHistory: ValueError: If no Slack client has been initialized SlackApiError: If there's an error calling the Slack API """ - if not self.client: - raise ValueError("Slack client not initialized. Call set_token() first.") + client = await self._ensure_client() while True: try: @@ -314,7 +493,7 @@ class SlackHistory: # For now, we are only adding Retry-After as per plan. # time.sleep(0.6) # Optional: ~100 req/min if ever needed. - result = self.client.users_info(user=user_id) + result = client.users_info(user=user_id) return result["user"] # Success, return and exit loop implicitly except SlackApiError as e_user_info: @@ -343,7 +522,7 @@ class SlackHistory: ) raise general_error from general_error # Re-raise unexpected errors - def format_message( + async def format_message( self, msg: dict[str, Any], include_user_info: bool = False ) -> dict[str, Any]: """ @@ -369,9 +548,9 @@ class SlackHistory: "is_thread": "thread_ts" in msg, } - if include_user_info and "user" in msg and self.client: + if include_user_info and "user" in msg: try: - user_info = self.get_user_info(msg["user"]) + user_info = await self.get_user_info(msg["user"]) formatted["user_name"] = user_info.get("real_name", "Unknown") formatted["user_email"] = user_info.get("profile", {}).get("email", "") except Exception: diff --git a/surfsense_backend/app/routes/__init__.py b/surfsense_backend/app/routes/__init__.py index 3c18650ae..b35d743e0 100644 --- a/surfsense_backend/app/routes/__init__.py +++ b/surfsense_backend/app/routes/__init__.py @@ -15,15 +15,19 @@ from .google_drive_add_connector_route import ( from .google_gmail_add_connector_route import ( router as google_gmail_add_connector_router, ) +from .linear_add_connector_route import router as linear_add_connector_router from .logs_routes import router as logs_router from .luma_add_connector_route import router as luma_add_connector_router from .new_chat_routes import router as new_chat_router from .new_llm_config_routes import router as new_llm_config_router from .notes_routes import router as notes_router +from .notion_add_connector_route import router as notion_add_connector_router from .podcasts_routes import router as podcasts_router from .rbac_routes import router as rbac_router from .search_source_connectors_routes import router as search_source_connectors_router from .search_spaces_routes import router as search_spaces_router +from .slack_add_connector_route import router as slack_add_connector_router +from .discord_add_connector_route import router as discord_add_connector_router router = APIRouter() @@ -39,7 +43,11 @@ router.include_router(google_calendar_add_connector_router) router.include_router(google_gmail_add_connector_router) router.include_router(google_drive_add_connector_router) router.include_router(airtable_add_connector_router) +router.include_router(linear_add_connector_router) router.include_router(luma_add_connector_router) +router.include_router(notion_add_connector_router) +router.include_router(slack_add_connector_router) +router.include_router(discord_add_connector_router) router.include_router(new_llm_config_router) # LLM configs with prompt configuration router.include_router(logs_router) router.include_router(circleback_webhook_router) # Circleback meeting webhooks diff --git a/surfsense_backend/app/routes/airtable_add_connector_route.py b/surfsense_backend/app/routes/airtable_add_connector_route.py index 3bcbe4dc0..9284d89e8 100644 --- a/surfsense_backend/app/routes/airtable_add_connector_route.py +++ b/surfsense_backend/app/routes/airtable_add_connector_route.py @@ -1,6 +1,5 @@ import base64 import hashlib -import json import logging import secrets from datetime import UTC, datetime, timedelta @@ -23,6 +22,7 @@ from app.db import ( ) from app.schemas.airtable_auth_credentials import AirtableAuthCredentialsBase from app.users import current_active_user +from app.utils.oauth_security import OAuthStateManager, TokenEncryption logger = logging.getLogger(__name__) @@ -40,6 +40,30 @@ SCOPES = [ "user.email:read", ] +# Initialize security utilities +_state_manager = None +_token_encryption = None + + +def get_state_manager() -> OAuthStateManager: + """Get or create OAuth state manager instance.""" + global _state_manager + if _state_manager is None: + if not config.SECRET_KEY: + raise ValueError("SECRET_KEY must be set for OAuth security") + _state_manager = OAuthStateManager(config.SECRET_KEY) + return _state_manager + + +def get_token_encryption() -> TokenEncryption: + """Get or create token encryption instance.""" + global _token_encryption + if _token_encryption is None: + if not config.SECRET_KEY: + raise ValueError("SECRET_KEY must be set for token encryption") + _token_encryption = TokenEncryption(config.SECRET_KEY) + return _token_encryption + def make_basic_auth_header(client_id: str, client_secret: str) -> str: credentials = f"{client_id}:{client_secret}".encode() @@ -90,18 +114,19 @@ async def connect_airtable(space_id: int, user: User = Depends(current_active_us status_code=500, detail="Airtable OAuth not configured." ) + if not config.SECRET_KEY: + raise HTTPException( + status_code=500, detail="SECRET_KEY not configured for OAuth security." + ) + # Generate PKCE parameters code_verifier, code_challenge = generate_pkce_pair() - # Generate state parameter - state_payload = json.dumps( - { - "space_id": space_id, - "user_id": str(user.id), - "code_verifier": code_verifier, - } + # Generate secure state parameter with HMAC signature (including code_verifier for PKCE) + state_manager = get_state_manager() + state_encoded = state_manager.generate_secure_state( + space_id, user.id, code_verifier=code_verifier ) - state_encoded = base64.urlsafe_b64encode(state_payload.encode()).decode() # Build authorization URL auth_params = { @@ -134,8 +159,9 @@ async def connect_airtable(space_id: int, user: User = Depends(current_active_us @router.get("/auth/airtable/connector/callback") async def airtable_callback( request: Request, - code: str, - state: str, + code: str | None = None, + error: str | None = None, + state: str | None = None, session: AsyncSession = Depends(get_async_session), ): """ @@ -143,7 +169,8 @@ async def airtable_callback( Args: request: FastAPI request object - code: Authorization code from Airtable + code: Authorization code from Airtable (if user granted access) + error: Error code from Airtable (if user denied access or error occurred) state: State parameter containing user/space info session: Database session @@ -151,10 +178,42 @@ async def airtable_callback( Redirect response to frontend """ try: - # Decode and parse the state + # Handle OAuth errors (e.g., user denied access) + if error: + logger.warning(f"Airtable OAuth error: {error}") + # Try to decode state to get space_id for redirect, but don't fail if it's invalid + space_id = None + if state: + try: + state_manager = get_state_manager() + data = state_manager.validate_state(state) + space_id = data.get("space_id") + except Exception: + # If state is invalid, we'll redirect without space_id + logger.warning("Failed to validate state in error handler") + + # Redirect to frontend with error parameter + if space_id: + return RedirectResponse( + url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=airtable_oauth_denied" + ) + else: + return RedirectResponse( + url=f"{config.NEXT_FRONTEND_URL}/dashboard?error=airtable_oauth_denied" + ) + + # Validate required parameters for successful flow + if not code: + raise HTTPException(status_code=400, detail="Missing authorization code") + if not state: + raise HTTPException(status_code=400, detail="Missing state parameter") + + # Validate and decode state with signature verification + state_manager = get_state_manager() try: - decoded_state = base64.urlsafe_b64decode(state.encode()).decode() - data = json.loads(decoded_state) + data = state_manager.validate_state(state) + except HTTPException: + raise except Exception as e: raise HTTPException( status_code=400, detail=f"Invalid state parameter: {e!s}" @@ -162,7 +221,12 @@ async def airtable_callback( user_id = UUID(data["user_id"]) space_id = data["space_id"] - code_verifier = data["code_verifier"] + code_verifier = data.get("code_verifier") + + if not code_verifier: + raise HTTPException( + status_code=400, detail="Missing code_verifier in state parameter" + ) auth_header = make_basic_auth_header( config.AIRTABLE_CLIENT_ID, config.AIRTABLE_CLIENT_SECRET ) @@ -201,22 +265,38 @@ async def airtable_callback( token_json = token_response.json() + # Encrypt sensitive tokens before storing + token_encryption = get_token_encryption() + access_token = token_json.get("access_token") + refresh_token = token_json.get("refresh_token") + + if not access_token: + raise HTTPException( + status_code=400, detail="No access token received from Airtable" + ) + # Calculate expiration time (UTC, tz-aware) expires_at = None if token_json.get("expires_in"): now_utc = datetime.now(UTC) expires_at = now_utc + timedelta(seconds=int(token_json["expires_in"])) - # Create credentials object + # Create credentials object with encrypted tokens credentials = AirtableAuthCredentialsBase( - access_token=token_json["access_token"], - refresh_token=token_json.get("refresh_token"), + access_token=token_encryption.encrypt_token(access_token), + refresh_token=token_encryption.encrypt_token(refresh_token) + if refresh_token + else None, token_type=token_json.get("token_type", "Bearer"), expires_in=token_json.get("expires_in"), expires_at=expires_at, scope=token_json.get("scope"), ) + # Mark that tokens are encrypted for backward compatibility + credentials_dict = credentials.to_dict() + credentials_dict["_token_encrypted"] = True + # Check if connector already exists for this search space and user existing_connector_result = await session.execute( select(SearchSourceConnector).filter( @@ -230,7 +310,7 @@ async def airtable_callback( if existing_connector: # Update existing connector - existing_connector.config = credentials.to_dict() + existing_connector.config = credentials_dict existing_connector.name = "Airtable Connector" existing_connector.is_indexable = True logger.info( @@ -242,7 +322,7 @@ async def airtable_callback( name="Airtable Connector", connector_type=SearchSourceConnectorType.AIRTABLE_CONNECTOR, is_indexable=True, - config=credentials.to_dict(), + config=credentials_dict, search_space_id=space_id, user_id=user_id, ) @@ -306,6 +386,21 @@ async def refresh_airtable_token( logger.info(f"Refreshing Airtable token for connector {connector.id}") credentials = AirtableAuthCredentialsBase.from_dict(connector.config) + + # Decrypt tokens if they are encrypted + token_encryption = get_token_encryption() + is_encrypted = connector.config.get("_token_encrypted", False) + + refresh_token = credentials.refresh_token + if is_encrypted and refresh_token: + try: + refresh_token = token_encryption.decrypt_token(refresh_token) + except Exception as e: + logger.error(f"Failed to decrypt refresh token: {e!s}") + raise HTTPException( + status_code=500, detail="Failed to decrypt stored refresh token" + ) from e + auth_header = make_basic_auth_header( config.AIRTABLE_CLIENT_ID, config.AIRTABLE_CLIENT_SECRET ) @@ -313,7 +408,7 @@ async def refresh_airtable_token( # Prepare token refresh data refresh_data = { "grant_type": "refresh_token", - "refresh_token": credentials.refresh_token, + "refresh_token": refresh_token, "client_id": config.AIRTABLE_CLIENT_ID, "client_secret": config.AIRTABLE_CLIENT_SECRET, } @@ -342,14 +437,29 @@ async def refresh_airtable_token( now_utc = datetime.now(UTC) expires_at = now_utc + timedelta(seconds=int(token_json["expires_in"])) - # Update credentials object - credentials.access_token = token_json["access_token"] + # Encrypt new tokens before storing + access_token = token_json.get("access_token") + new_refresh_token = token_json.get("refresh_token") + + if not access_token: + raise HTTPException( + status_code=400, detail="No access token received from Airtable refresh" + ) + + # Update credentials object with encrypted tokens + credentials.access_token = token_encryption.encrypt_token(access_token) + if new_refresh_token: + credentials.refresh_token = token_encryption.encrypt_token( + new_refresh_token + ) credentials.expires_in = token_json.get("expires_in") credentials.expires_at = expires_at credentials.scope = token_json.get("scope") - # Update connector config - connector.config = credentials.to_dict() + # Update connector config with encrypted tokens + credentials_dict = credentials.to_dict() + credentials_dict["_token_encrypted"] = True + connector.config = credentials_dict await session.commit() await session.refresh(connector) diff --git a/surfsense_backend/app/routes/discord_add_connector_route.py b/surfsense_backend/app/routes/discord_add_connector_route.py new file mode 100644 index 000000000..70a0046a3 --- /dev/null +++ b/surfsense_backend/app/routes/discord_add_connector_route.py @@ -0,0 +1,509 @@ +""" +Discord Connector OAuth Routes. + +Handles OAuth 2.0 authentication flow for Discord connector. +""" + +import logging +from datetime import UTC, datetime, timedelta +from uuid import UUID + +import httpx +from fastapi import APIRouter, Depends, HTTPException, Request +from fastapi.responses import RedirectResponse +from pydantic import ValidationError +from sqlalchemy.exc import IntegrityError +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.future import select + +from app.config import config +from app.db import ( + SearchSourceConnector, + SearchSourceConnectorType, + User, + get_async_session, +) +from app.schemas.discord_auth_credentials import DiscordAuthCredentialsBase +from app.users import current_active_user +from app.utils.oauth_security import OAuthStateManager, TokenEncryption + +logger = logging.getLogger(__name__) + +router = APIRouter() + +# Discord OAuth endpoints +AUTHORIZATION_URL = "https://discord.com/api/oauth2/authorize" +TOKEN_URL = "https://discord.com/api/oauth2/token" + +# OAuth scopes for Discord (Bot Token) +SCOPES = [ + "bot", # Basic bot scope + "guilds", # Access to guild information + "guilds.members.read", # Read member information +] + +# Initialize security utilities +_state_manager = None +_token_encryption = None + + +def get_state_manager() -> OAuthStateManager: + """Get or create OAuth state manager instance.""" + global _state_manager + if _state_manager is None: + if not config.SECRET_KEY: + raise ValueError("SECRET_KEY must be set for OAuth security") + _state_manager = OAuthStateManager(config.SECRET_KEY) + return _state_manager + + +def get_token_encryption() -> TokenEncryption: + """Get or create token encryption instance.""" + global _token_encryption + if _token_encryption is None: + if not config.SECRET_KEY: + raise ValueError("SECRET_KEY must be set for token encryption") + _token_encryption = TokenEncryption(config.SECRET_KEY) + return _token_encryption + + +@router.get("/auth/discord/connector/add") +async def connect_discord(space_id: int, user: User = Depends(current_active_user)): + """ + Initiate Discord OAuth flow. + + Args: + space_id: The search space ID + user: Current authenticated user + + Returns: + Authorization URL for redirect + """ + try: + if not space_id: + raise HTTPException(status_code=400, detail="space_id is required") + + if not config.DISCORD_CLIENT_ID: + raise HTTPException(status_code=500, detail="Discord OAuth not configured.") + + if not config.DISCORD_BOT_TOKEN: + raise HTTPException( + status_code=500, + detail="Discord bot token not configured. Please set DISCORD_BOT_TOKEN in backend configuration.", + ) + + if not config.SECRET_KEY: + raise HTTPException( + status_code=500, detail="SECRET_KEY not configured for OAuth security." + ) + + # Generate secure state parameter with HMAC signature + state_manager = get_state_manager() + state_encoded = state_manager.generate_secure_state(space_id, user.id) + + # Build authorization URL + from urllib.parse import urlencode + + auth_params = { + "client_id": config.DISCORD_CLIENT_ID, + "scope": " ".join(SCOPES), # Discord uses space-separated scopes + "redirect_uri": config.DISCORD_REDIRECT_URI, + "response_type": "code", + "state": state_encoded, + } + + auth_url = f"{AUTHORIZATION_URL}?{urlencode(auth_params)}" + + logger.info(f"Generated Discord OAuth URL for user {user.id}, space {space_id}") + return {"auth_url": auth_url} + + except Exception as e: + logger.error(f"Failed to initiate Discord OAuth: {e!s}", exc_info=True) + raise HTTPException( + status_code=500, detail=f"Failed to initiate Discord OAuth: {e!s}" + ) from e + + +@router.get("/auth/discord/connector/callback") +async def discord_callback( + request: Request, + code: str | None = None, + error: str | None = None, + state: str | None = None, + session: AsyncSession = Depends(get_async_session), +): + """ + Handle Discord OAuth callback. + + Args: + request: FastAPI request object + code: Authorization code from Discord (if user granted access) + error: Error code from Discord (if user denied access or error occurred) + state: State parameter containing user/space info + session: Database session + + Returns: + Redirect response to frontend + """ + try: + # Handle OAuth errors (e.g., user denied access) + if error: + logger.warning(f"Discord OAuth error: {error}") + # Try to decode state to get space_id for redirect, but don't fail if it's invalid + space_id = None + if state: + try: + state_manager = get_state_manager() + data = state_manager.validate_state(state) + space_id = data.get("space_id") + except Exception: + # If state is invalid, we'll redirect without space_id + logger.warning("Failed to validate state in error handler") + + # Redirect to frontend with error parameter + if space_id: + return RedirectResponse( + url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=discord_oauth_denied" + ) + else: + return RedirectResponse( + url=f"{config.NEXT_FRONTEND_URL}/dashboard?error=discord_oauth_denied" + ) + + # Validate required parameters for successful flow + if not code: + raise HTTPException(status_code=400, detail="Missing authorization code") + if not state: + raise HTTPException(status_code=400, detail="Missing state parameter") + + # Validate and decode state with signature verification + state_manager = get_state_manager() + try: + data = state_manager.validate_state(state) + except HTTPException: + raise + except Exception as e: + raise HTTPException( + status_code=400, detail=f"Invalid state parameter: {e!s}" + ) from e + + user_id = UUID(data["user_id"]) + space_id = data["space_id"] + + # Validate redirect URI (security: ensure it matches configured value) + if not config.DISCORD_REDIRECT_URI: + raise HTTPException( + status_code=500, detail="DISCORD_REDIRECT_URI not configured" + ) + + # Exchange authorization code for access token + token_data = { + "client_id": config.DISCORD_CLIENT_ID, + "client_secret": config.DISCORD_CLIENT_SECRET, + "grant_type": "authorization_code", + "code": code, + "redirect_uri": config.DISCORD_REDIRECT_URI, + } + + async with httpx.AsyncClient() as client: + token_response = await client.post( + TOKEN_URL, + data=token_data, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + timeout=30.0, + ) + + if token_response.status_code != 200: + error_detail = token_response.text + try: + error_json = token_response.json() + error_detail = error_json.get("error_description", error_json.get("error", error_detail)) + except Exception: + pass + raise HTTPException( + status_code=400, detail=f"Token exchange failed: {error_detail}" + ) + + token_json = token_response.json() + + # Log OAuth response for debugging (without sensitive data) + logger.info(f"Discord OAuth response received. Keys: {list(token_json.keys())}") + + # Discord OAuth with 'bot' scope returns access_token (user token), not bot token + # The bot token must come from backend config (DISCORD_BOT_TOKEN) + # OAuth is used to authorize bot installation to servers, not to get bot token + if not config.DISCORD_BOT_TOKEN: + raise HTTPException( + status_code=500, + detail="Discord bot token not configured. Please set DISCORD_BOT_TOKEN in backend configuration.", + ) + + # Use the bot token from backend config (not the OAuth access_token) + bot_token = config.DISCORD_BOT_TOKEN + + # Extract OAuth access_token and refresh_token (for reference, not used for bot operations) + oauth_access_token = token_json.get("access_token") + refresh_token = token_json.get("refresh_token") + + # Encrypt sensitive tokens before storing + token_encryption = get_token_encryption() + + # Calculate expiration time (UTC, tz-aware) + expires_at = None + if token_json.get("expires_in"): + now_utc = datetime.now(UTC) + expires_at = now_utc + timedelta(seconds=int(token_json["expires_in"])) + + # Extract guild info from OAuth response if available + guild_id = None + guild_name = None + if token_json.get("guild"): + guild_id = token_json["guild"].get("id") + guild_name = token_json["guild"].get("name") + + # Store the bot token from config and OAuth metadata + connector_config = { + "bot_token": token_encryption.encrypt_token(bot_token), # Use bot token from config + "oauth_access_token": token_encryption.encrypt_token(oauth_access_token) + if oauth_access_token + else None, # Store OAuth token for reference + "refresh_token": token_encryption.encrypt_token(refresh_token) + if refresh_token + else None, + "token_type": token_json.get("token_type", "Bearer"), + "expires_in": token_json.get("expires_in"), + "expires_at": expires_at.isoformat() if expires_at else None, + "scope": token_json.get("scope"), + "guild_id": guild_id, + "guild_name": guild_name, + # Mark that tokens are encrypted for backward compatibility + "_token_encrypted": True, + } + + # Check if connector already exists for this search space and user + existing_connector_result = await session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.search_space_id == space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.DISCORD_CONNECTOR, + ) + ) + existing_connector = existing_connector_result.scalars().first() + + if existing_connector: + # Update existing connector + existing_connector.config = connector_config + existing_connector.name = "Discord Connector" + existing_connector.is_indexable = True + logger.info( + f"Updated existing Discord connector for user {user_id} in space {space_id}" + ) + else: + # Create new connector + new_connector = SearchSourceConnector( + name="Discord Connector", + connector_type=SearchSourceConnectorType.DISCORD_CONNECTOR, + is_indexable=True, + config=connector_config, + search_space_id=space_id, + user_id=user_id, + ) + session.add(new_connector) + logger.info( + f"Created new Discord connector for user {user_id} in space {space_id}" + ) + + try: + await session.commit() + logger.info(f"Successfully saved Discord connector for user {user_id}") + + # Redirect to the frontend with success params + return RedirectResponse( + url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&success=true&connector=discord-connector" + ) + + except ValidationError as e: + await session.rollback() + raise HTTPException( + status_code=422, detail=f"Validation error: {e!s}" + ) from e + except IntegrityError as e: + await session.rollback() + raise HTTPException( + status_code=409, + detail=f"Integrity error: A connector with this type already exists. {e!s}", + ) from e + except Exception as e: + logger.error(f"Failed to create search source connector: {e!s}") + await session.rollback() + raise HTTPException( + status_code=500, + detail=f"Failed to create search source connector: {e!s}", + ) from e + + except HTTPException: + raise + except Exception as e: + logger.error(f"Failed to complete Discord OAuth: {e!s}", exc_info=True) + raise HTTPException( + status_code=500, detail=f"Failed to complete Discord OAuth: {e!s}" + ) from e + + +async def refresh_discord_token( + session: AsyncSession, connector: SearchSourceConnector +) -> SearchSourceConnector: + """ + Refresh the Discord OAuth tokens for a connector. + + Note: Bot tokens from config don't expire, but OAuth access tokens might. + This function refreshes OAuth tokens if needed, but always uses bot token from config. + + Args: + session: Database session + connector: Discord connector to refresh + + Returns: + Updated connector object + """ + try: + logger.info(f"Refreshing Discord OAuth tokens for connector {connector.id}") + + # Bot token always comes from config, not from OAuth + if not config.DISCORD_BOT_TOKEN: + raise HTTPException( + status_code=500, + detail="Discord bot token not configured. Please set DISCORD_BOT_TOKEN in backend configuration.", + ) + + credentials = DiscordAuthCredentialsBase.from_dict(connector.config) + + # Decrypt tokens if they are encrypted + token_encryption = get_token_encryption() + is_encrypted = connector.config.get("_token_encrypted", False) + + refresh_token = credentials.refresh_token + if is_encrypted and refresh_token: + try: + refresh_token = token_encryption.decrypt_token(refresh_token) + except Exception as e: + logger.error(f"Failed to decrypt refresh token: {e!s}") + raise HTTPException( + status_code=500, detail="Failed to decrypt stored refresh token" + ) from e + + # If no refresh token, bot token from config is still valid (bot tokens don't expire) + # Just update the bot token from config in case it was changed + if not refresh_token: + logger.info( + f"No refresh token available for connector {connector.id}. Using bot token from config." + ) + # Update bot token from config (in case it was changed) + credentials.bot_token = token_encryption.encrypt_token(config.DISCORD_BOT_TOKEN) + credentials_dict = credentials.to_dict() + credentials_dict["_token_encrypted"] = True + connector.config = credentials_dict + await session.commit() + await session.refresh(connector) + return connector + + # Discord uses oauth2/token for token refresh with grant_type=refresh_token + refresh_data = { + "client_id": config.DISCORD_CLIENT_ID, + "client_secret": config.DISCORD_CLIENT_SECRET, + "grant_type": "refresh_token", + "refresh_token": refresh_token, + } + + async with httpx.AsyncClient() as client: + token_response = await client.post( + TOKEN_URL, + data=refresh_data, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + timeout=30.0, + ) + + if token_response.status_code != 200: + error_detail = token_response.text + try: + error_json = token_response.json() + error_detail = error_json.get("error_description", error_json.get("error", error_detail)) + except Exception: + pass + # If refresh fails, bot token from config is still valid + logger.warning( + f"OAuth token refresh failed for connector {connector.id}: {error_detail}. " + "Using bot token from config." + ) + # Update bot token from config + credentials.bot_token = token_encryption.encrypt_token(config.DISCORD_BOT_TOKEN) + credentials.refresh_token = None # Clear invalid refresh token + credentials_dict = credentials.to_dict() + credentials_dict["_token_encrypted"] = True + connector.config = credentials_dict + await session.commit() + await session.refresh(connector) + return connector + + token_json = token_response.json() + + # Extract OAuth access token from refresh response (for reference) + oauth_access_token = token_json.get("access_token") + + # Get new refresh token if provided (Discord may rotate refresh tokens) + new_refresh_token = token_json.get("refresh_token") + + # Calculate expiration time (UTC, tz-aware) + expires_at = None + expires_in = token_json.get("expires_in") + if expires_in: + now_utc = datetime.now(UTC) + expires_at = now_utc + timedelta(seconds=int(expires_in)) + + # Always use bot token from config (bot tokens don't expire) + credentials.bot_token = token_encryption.encrypt_token(config.DISCORD_BOT_TOKEN) + + # Update OAuth tokens if available + if oauth_access_token: + # Store OAuth access token for reference + connector.config["oauth_access_token"] = token_encryption.encrypt_token( + oauth_access_token + ) + if new_refresh_token: + credentials.refresh_token = token_encryption.encrypt_token( + new_refresh_token + ) + credentials.expires_in = expires_in + credentials.expires_at = expires_at + credentials.scope = token_json.get("scope") + + # Preserve guild info if present + if not credentials.guild_id: + credentials.guild_id = connector.config.get("guild_id") + if not credentials.guild_name: + credentials.guild_name = connector.config.get("guild_name") + if not credentials.bot_user_id: + credentials.bot_user_id = connector.config.get("bot_user_id") + + # Update connector config with encrypted tokens + credentials_dict = credentials.to_dict() + credentials_dict["_token_encrypted"] = True + connector.config = credentials_dict + await session.commit() + await session.refresh(connector) + + logger.info(f"Successfully refreshed Discord OAuth tokens for connector {connector.id}") + + return connector + except HTTPException: + raise + except Exception as e: + logger.error( + f"Failed to refresh Discord tokens for connector {connector.id}: {e!s}", + exc_info=True, + ) + raise HTTPException( + status_code=500, detail=f"Failed to refresh Discord tokens: {e!s}" + ) from e + diff --git a/surfsense_backend/app/routes/google_calendar_add_connector_route.py b/surfsense_backend/app/routes/google_calendar_add_connector_route.py index 8bb685450..6c6ae4e40 100644 --- a/surfsense_backend/app/routes/google_calendar_add_connector_route.py +++ b/surfsense_backend/app/routes/google_calendar_add_connector_route.py @@ -2,7 +2,6 @@ import os os.environ["OAUTHLIB_RELAX_TOKEN_SCOPE"] = "1" -import base64 import json import logging from uuid import UUID @@ -23,6 +22,7 @@ from app.db import ( get_async_session, ) from app.users import current_active_user +from app.utils.oauth_security import OAuthStateManager, TokenEncryption logger = logging.getLogger(__name__) @@ -31,6 +31,30 @@ router = APIRouter() SCOPES = ["https://www.googleapis.com/auth/calendar.readonly"] REDIRECT_URI = config.GOOGLE_CALENDAR_REDIRECT_URI +# Initialize security utilities +_state_manager = None +_token_encryption = None + + +def get_state_manager() -> OAuthStateManager: + """Get or create OAuth state manager instance.""" + global _state_manager + if _state_manager is None: + if not config.SECRET_KEY: + raise ValueError("SECRET_KEY must be set for OAuth security") + _state_manager = OAuthStateManager(config.SECRET_KEY) + return _state_manager + + +def get_token_encryption() -> TokenEncryption: + """Get or create token encryption instance.""" + global _token_encryption + if _token_encryption is None: + if not config.SECRET_KEY: + raise ValueError("SECRET_KEY must be set for token encryption") + _token_encryption = TokenEncryption(config.SECRET_KEY) + return _token_encryption + def get_google_flow(): try: @@ -59,16 +83,16 @@ async def connect_calendar(space_id: int, user: User = Depends(current_active_us if not space_id: raise HTTPException(status_code=400, detail="space_id is required") + if not config.SECRET_KEY: + raise HTTPException( + status_code=500, detail="SECRET_KEY not configured for OAuth security." + ) + flow = get_google_flow() - # Encode space_id and user_id in state - state_payload = json.dumps( - { - "space_id": space_id, - "user_id": str(user.id), - } - ) - state_encoded = base64.urlsafe_b64encode(state_payload.encode()).decode() + # Generate secure state parameter with HMAC signature + state_manager = get_state_manager() + state_encoded = state_manager.generate_secure_state(space_id, user.id) auth_url, _ = flow.authorization_url( access_type="offline", @@ -86,24 +110,86 @@ async def connect_calendar(space_id: int, user: User = Depends(current_active_us @router.get("/auth/google/calendar/connector/callback") async def calendar_callback( request: Request, - code: str, - state: str, + code: str | None = None, + error: str | None = None, + state: str | None = None, session: AsyncSession = Depends(get_async_session), ): try: - # Decode and parse the state - decoded_state = base64.urlsafe_b64decode(state.encode()).decode() - data = json.loads(decoded_state) + # Handle OAuth errors (e.g., user denied access) + if error: + logger.warning(f"Google Calendar OAuth error: {error}") + # Try to decode state to get space_id for redirect, but don't fail if it's invalid + space_id = None + if state: + try: + state_manager = get_state_manager() + data = state_manager.validate_state(state) + space_id = data.get("space_id") + except Exception: + # If state is invalid, we'll redirect without space_id + logger.warning("Failed to validate state in error handler") + + # Redirect to frontend with error parameter + if space_id: + return RedirectResponse( + url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=google_calendar_oauth_denied" + ) + else: + return RedirectResponse( + url=f"{config.NEXT_FRONTEND_URL}/dashboard?error=google_calendar_oauth_denied" + ) + + # Validate required parameters for successful flow + if not code: + raise HTTPException(status_code=400, detail="Missing authorization code") + if not state: + raise HTTPException(status_code=400, detail="Missing state parameter") + + # Validate and decode state with signature verification + state_manager = get_state_manager() + try: + data = state_manager.validate_state(state) + except HTTPException: + raise + except Exception as e: + raise HTTPException( + status_code=400, detail=f"Invalid state parameter: {e!s}" + ) from e user_id = UUID(data["user_id"]) space_id = data["space_id"] + # Validate redirect URI (security: ensure it matches configured value) + if not config.GOOGLE_CALENDAR_REDIRECT_URI: + raise HTTPException( + status_code=500, detail="GOOGLE_CALENDAR_REDIRECT_URI not configured" + ) + flow = get_google_flow() flow.fetch_token(code=code) creds = flow.credentials creds_dict = json.loads(creds.to_json()) + # Encrypt sensitive credentials before storing + token_encryption = get_token_encryption() + + # Encrypt sensitive fields: token, refresh_token, client_secret + if creds_dict.get("token"): + creds_dict["token"] = token_encryption.encrypt_token(creds_dict["token"]) + if creds_dict.get("refresh_token"): + creds_dict["refresh_token"] = token_encryption.encrypt_token( + creds_dict["refresh_token"] + ) + if creds_dict.get("client_secret"): + creds_dict["client_secret"] = token_encryption.encrypt_token( + creds_dict["client_secret"] + ) + + # Mark that credentials are encrypted for backward compatibility + creds_dict["_token_encrypted"] = True + try: # Check if a connector with the same type already exists for this search space and user result = await session.execute( diff --git a/surfsense_backend/app/routes/google_drive_add_connector_route.py b/surfsense_backend/app/routes/google_drive_add_connector_route.py index 52461319b..6caf3f204 100644 --- a/surfsense_backend/app/routes/google_drive_add_connector_route.py +++ b/surfsense_backend/app/routes/google_drive_add_connector_route.py @@ -10,7 +10,6 @@ Endpoints: - GET /connectors/{connector_id}/google-drive/folders - List user's folders (for index-time selection) """ -import base64 import json import logging import os @@ -37,6 +36,7 @@ from app.db import ( get_async_session, ) from app.users import current_active_user +from app.utils.oauth_security import OAuthStateManager, TokenEncryption # Relax token scope validation for Google OAuth os.environ["OAUTHLIB_RELAX_TOKEN_SCOPE"] = "1" @@ -44,6 +44,31 @@ os.environ["OAUTHLIB_RELAX_TOKEN_SCOPE"] = "1" logger = logging.getLogger(__name__) router = APIRouter() +# Initialize security utilities +_state_manager = None +_token_encryption = None + + +def get_state_manager() -> OAuthStateManager: + """Get or create OAuth state manager instance.""" + global _state_manager + if _state_manager is None: + if not config.SECRET_KEY: + raise ValueError("SECRET_KEY must be set for OAuth security") + _state_manager = OAuthStateManager(config.SECRET_KEY) + return _state_manager + + +def get_token_encryption() -> TokenEncryption: + """Get or create token encryption instance.""" + global _token_encryption + if _token_encryption is None: + if not config.SECRET_KEY: + raise ValueError("SECRET_KEY must be set for token encryption") + _token_encryption = TokenEncryption(config.SECRET_KEY) + return _token_encryption + + # Google Drive OAuth scopes SCOPES = [ "https://www.googleapis.com/auth/drive.readonly", # Read-only access to Drive @@ -90,16 +115,16 @@ async def connect_drive(space_id: int, user: User = Depends(current_active_user) if not space_id: raise HTTPException(status_code=400, detail="space_id is required") + if not config.SECRET_KEY: + raise HTTPException( + status_code=500, detail="SECRET_KEY not configured for OAuth security." + ) + flow = get_google_flow() - # Encode space_id and user_id in state parameter - state_payload = json.dumps( - { - "space_id": space_id, - "user_id": str(user.id), - } - ) - state_encoded = base64.urlsafe_b64encode(state_payload.encode()).decode() + # Generate secure state parameter with HMAC signature + state_manager = get_state_manager() + state_encoded = state_manager.generate_secure_state(space_id, user.id) # Generate authorization URL auth_url, _ = flow.authorization_url( @@ -124,8 +149,9 @@ async def connect_drive(space_id: int, user: User = Depends(current_active_user) @router.get("/auth/google/drive/connector/callback") async def drive_callback( request: Request, - code: str, - state: str, + code: str | None = None, + error: str | None = None, + state: str | None = None, session: AsyncSession = Depends(get_async_session), ): """ @@ -133,15 +159,53 @@ async def drive_callback( Query params: code: Authorization code from Google + error: OAuth error (if user denied access) state: Encoded state with space_id and user_id Returns: Redirect to frontend success page """ try: - # Decode and parse state - decoded_state = base64.urlsafe_b64decode(state.encode()).decode() - data = json.loads(decoded_state) + # Handle OAuth errors (e.g., user denied access) + if error: + logger.warning(f"Google Drive OAuth error: {error}") + # Try to decode state to get space_id for redirect, but don't fail if it's invalid + space_id = None + if state: + try: + state_manager = get_state_manager() + data = state_manager.validate_state(state) + space_id = data.get("space_id") + except Exception: + # If state is invalid, we'll redirect without space_id + logger.warning("Failed to validate state in error handler") + + # Redirect to frontend with error parameter + if space_id: + return RedirectResponse( + url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=google_drive_oauth_denied" + ) + else: + return RedirectResponse( + url=f"{config.NEXT_FRONTEND_URL}/dashboard?error=google_drive_oauth_denied" + ) + + # Validate required parameters for successful flow + if not code: + raise HTTPException(status_code=400, detail="Missing authorization code") + if not state: + raise HTTPException(status_code=400, detail="Missing state parameter") + + # Validate and decode state with signature verification + state_manager = get_state_manager() + try: + data = state_manager.validate_state(state) + except HTTPException: + raise + except Exception as e: + raise HTTPException( + status_code=400, detail=f"Invalid state parameter: {e!s}" + ) from e user_id = UUID(data["user_id"]) space_id = data["space_id"] @@ -150,6 +214,12 @@ async def drive_callback( f"Processing Google Drive callback for user {user_id}, space {space_id}" ) + # Validate redirect URI (security: ensure it matches configured value) + if not config.GOOGLE_DRIVE_REDIRECT_URI: + raise HTTPException( + status_code=500, detail="GOOGLE_DRIVE_REDIRECT_URI not configured" + ) + # Exchange authorization code for tokens flow = get_google_flow() flow.fetch_token(code=code) @@ -157,6 +227,24 @@ async def drive_callback( creds = flow.credentials creds_dict = json.loads(creds.to_json()) + # Encrypt sensitive credentials before storing + token_encryption = get_token_encryption() + + # Encrypt sensitive fields: token, refresh_token, client_secret + if creds_dict.get("token"): + creds_dict["token"] = token_encryption.encrypt_token(creds_dict["token"]) + if creds_dict.get("refresh_token"): + creds_dict["refresh_token"] = token_encryption.encrypt_token( + creds_dict["refresh_token"] + ) + if creds_dict.get("client_secret"): + creds_dict["client_secret"] = token_encryption.encrypt_token( + creds_dict["client_secret"] + ) + + # Mark that credentials are encrypted for backward compatibility + creds_dict["_token_encrypted"] = True + # Check if connector already exists for this space/user result = await session.execute( select(SearchSourceConnector).filter( diff --git a/surfsense_backend/app/routes/google_gmail_add_connector_route.py b/surfsense_backend/app/routes/google_gmail_add_connector_route.py index 21fcf2c38..20a51c1a1 100644 --- a/surfsense_backend/app/routes/google_gmail_add_connector_route.py +++ b/surfsense_backend/app/routes/google_gmail_add_connector_route.py @@ -2,7 +2,6 @@ import os os.environ["OAUTHLIB_RELAX_TOKEN_SCOPE"] = "1" -import base64 import json import logging from uuid import UUID @@ -23,51 +22,90 @@ from app.db import ( get_async_session, ) from app.users import current_active_user +from app.utils.oauth_security import OAuthStateManager, TokenEncryption logger = logging.getLogger(__name__) router = APIRouter() +# Initialize security utilities +_state_manager = None +_token_encryption = None + + +def get_state_manager() -> OAuthStateManager: + """Get or create OAuth state manager instance.""" + global _state_manager + if _state_manager is None: + if not config.SECRET_KEY: + raise ValueError("SECRET_KEY must be set for OAuth security") + _state_manager = OAuthStateManager(config.SECRET_KEY) + return _state_manager + + +def get_token_encryption() -> TokenEncryption: + """Get or create token encryption instance.""" + global _token_encryption + if _token_encryption is None: + if not config.SECRET_KEY: + raise ValueError("SECRET_KEY must be set for token encryption") + _token_encryption = TokenEncryption(config.SECRET_KEY) + return _token_encryption + def get_google_flow(): """Create and return a Google OAuth flow for Gmail API.""" - flow = Flow.from_client_config( - { - "web": { - "client_id": config.GOOGLE_OAUTH_CLIENT_ID, - "client_secret": config.GOOGLE_OAUTH_CLIENT_SECRET, - "auth_uri": "https://accounts.google.com/o/oauth2/auth", - "token_uri": "https://oauth2.googleapis.com/token", - "redirect_uris": [config.GOOGLE_GMAIL_REDIRECT_URI], - } - }, - scopes=[ - "https://www.googleapis.com/auth/gmail.readonly", - "https://www.googleapis.com/auth/userinfo.email", - "https://www.googleapis.com/auth/userinfo.profile", - "openid", - ], - ) - flow.redirect_uri = config.GOOGLE_GMAIL_REDIRECT_URI - return flow + try: + flow = Flow.from_client_config( + { + "web": { + "client_id": config.GOOGLE_OAUTH_CLIENT_ID, + "client_secret": config.GOOGLE_OAUTH_CLIENT_SECRET, + "auth_uri": "https://accounts.google.com/o/oauth2/auth", + "token_uri": "https://oauth2.googleapis.com/token", + "redirect_uris": [config.GOOGLE_GMAIL_REDIRECT_URI], + } + }, + scopes=[ + "https://www.googleapis.com/auth/gmail.readonly", + "https://www.googleapis.com/auth/userinfo.email", + "https://www.googleapis.com/auth/userinfo.profile", + "openid", + ], + ) + flow.redirect_uri = config.GOOGLE_GMAIL_REDIRECT_URI + return flow + except Exception as e: + raise HTTPException( + status_code=500, detail=f"Failed to create Google flow: {e!s}" + ) from e @router.get("/auth/google/gmail/connector/add") async def connect_gmail(space_id: int, user: User = Depends(current_active_user)): + """ + Initiate Google Gmail OAuth flow. + + Query params: + space_id: Search space ID to add connector to + + Returns: + JSON with auth_url to redirect user to Google authorization + """ try: if not space_id: raise HTTPException(status_code=400, detail="space_id is required") + if not config.SECRET_KEY: + raise HTTPException( + status_code=500, detail="SECRET_KEY not configured for OAuth security." + ) + flow = get_google_flow() - # Encode space_id and user_id in state - state_payload = json.dumps( - { - "space_id": space_id, - "user_id": str(user.id), - } - ) - state_encoded = base64.urlsafe_b64encode(state_payload.encode()).decode() + # Generate secure state parameter with HMAC signature + state_manager = get_state_manager() + state_encoded = state_manager.generate_secure_state(space_id, user.id) auth_url, _ = flow.authorization_url( access_type="offline", @@ -75,8 +113,13 @@ async def connect_gmail(space_id: int, user: User = Depends(current_active_user) include_granted_scopes="true", state=state_encoded, ) + + logger.info( + f"Initiating Google Gmail OAuth for user {user.id}, space {space_id}" + ) return {"auth_url": auth_url} except Exception as e: + logger.error(f"Failed to initiate Google Gmail OAuth: {e!s}", exc_info=True) raise HTTPException( status_code=500, detail=f"Failed to initiate Google OAuth: {e!s}" ) from e @@ -85,24 +128,99 @@ async def connect_gmail(space_id: int, user: User = Depends(current_active_user) @router.get("/auth/google/gmail/connector/callback") async def gmail_callback( request: Request, - code: str, - state: str, + code: str | None = None, + error: str | None = None, + state: str | None = None, session: AsyncSession = Depends(get_async_session), ): + """ + Handle Google Gmail OAuth callback. + + Args: + request: FastAPI request object + code: Authorization code from Google (if user granted access) + error: Error code from Google (if user denied access or error occurred) + state: State parameter containing user/space info + session: Database session + + Returns: + Redirect response to frontend + """ try: - # Decode and parse the state - decoded_state = base64.urlsafe_b64decode(state.encode()).decode() - data = json.loads(decoded_state) + # Handle OAuth errors (e.g., user denied access) + if error: + logger.warning(f"Google Gmail OAuth error: {error}") + # Try to decode state to get space_id for redirect, but don't fail if it's invalid + space_id = None + if state: + try: + state_manager = get_state_manager() + data = state_manager.validate_state(state) + space_id = data.get("space_id") + except Exception: + # If state is invalid, we'll redirect without space_id + logger.warning("Failed to validate state in error handler") + + # Redirect to frontend with error parameter + if space_id: + return RedirectResponse( + url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=google_gmail_oauth_denied" + ) + else: + return RedirectResponse( + url=f"{config.NEXT_FRONTEND_URL}/dashboard?error=google_gmail_oauth_denied" + ) + + # Validate required parameters for successful flow + if not code: + raise HTTPException(status_code=400, detail="Missing authorization code") + if not state: + raise HTTPException(status_code=400, detail="Missing state parameter") + + # Validate and decode state with signature verification + state_manager = get_state_manager() + try: + data = state_manager.validate_state(state) + except HTTPException: + raise + except Exception as e: + raise HTTPException( + status_code=400, detail=f"Invalid state parameter: {e!s}" + ) from e user_id = UUID(data["user_id"]) space_id = data["space_id"] + # Validate redirect URI (security: ensure it matches configured value) + if not config.GOOGLE_GMAIL_REDIRECT_URI: + raise HTTPException( + status_code=500, detail="GOOGLE_GMAIL_REDIRECT_URI not configured" + ) + flow = get_google_flow() flow.fetch_token(code=code) creds = flow.credentials creds_dict = json.loads(creds.to_json()) + # Encrypt sensitive credentials before storing + token_encryption = get_token_encryption() + + # Encrypt sensitive fields: token, refresh_token, client_secret + if creds_dict.get("token"): + creds_dict["token"] = token_encryption.encrypt_token(creds_dict["token"]) + if creds_dict.get("refresh_token"): + creds_dict["refresh_token"] = token_encryption.encrypt_token( + creds_dict["refresh_token"] + ) + if creds_dict.get("client_secret"): + creds_dict["client_secret"] = token_encryption.encrypt_token( + creds_dict["client_secret"] + ) + + # Mark that credentials are encrypted for backward compatibility + creds_dict["_token_encrypted"] = True + try: # Check if a connector with the same type already exists for this search space and user result = await session.execute( @@ -160,3 +278,6 @@ async def gmail_callback( raise except Exception as e: logger.error(f"Unexpected error in Gmail callback: {e!s}", exc_info=True) + raise HTTPException( + status_code=500, detail=f"Failed to complete Google Gmail OAuth: {e!s}" + ) from e diff --git a/surfsense_backend/app/routes/linear_add_connector_route.py b/surfsense_backend/app/routes/linear_add_connector_route.py new file mode 100644 index 000000000..7a7fc196a --- /dev/null +++ b/surfsense_backend/app/routes/linear_add_connector_route.py @@ -0,0 +1,448 @@ +""" +Linear Connector OAuth Routes. + +Handles OAuth 2.0 authentication flow for Linear connector. +""" + +import logging +from datetime import UTC, datetime, timedelta +from uuid import UUID + +import httpx +from fastapi import APIRouter, Depends, HTTPException, Request +from fastapi.responses import RedirectResponse +from pydantic import ValidationError +from sqlalchemy.exc import IntegrityError +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.future import select + +from app.config import config +from app.db import ( + SearchSourceConnector, + SearchSourceConnectorType, + User, + get_async_session, +) +from app.schemas.linear_auth_credentials import LinearAuthCredentialsBase +from app.users import current_active_user +from app.utils.oauth_security import OAuthStateManager, TokenEncryption + +logger = logging.getLogger(__name__) + +router = APIRouter() + +# Linear OAuth endpoints +AUTHORIZATION_URL = "https://linear.app/oauth/authorize" +TOKEN_URL = "https://api.linear.app/oauth/token" + +# OAuth scopes for Linear +SCOPES = ["read", "write"] + +# Initialize security utilities +_state_manager = None +_token_encryption = None + + +def get_state_manager() -> OAuthStateManager: + """Get or create OAuth state manager instance.""" + global _state_manager + if _state_manager is None: + if not config.SECRET_KEY: + raise ValueError("SECRET_KEY must be set for OAuth security") + _state_manager = OAuthStateManager(config.SECRET_KEY) + return _state_manager + + +def get_token_encryption() -> TokenEncryption: + """Get or create token encryption instance.""" + global _token_encryption + if _token_encryption is None: + if not config.SECRET_KEY: + raise ValueError("SECRET_KEY must be set for token encryption") + _token_encryption = TokenEncryption(config.SECRET_KEY) + return _token_encryption + + +def make_basic_auth_header(client_id: str, client_secret: str) -> str: + """Create Basic Auth header for Linear OAuth.""" + import base64 + + credentials = f"{client_id}:{client_secret}".encode() + b64 = base64.b64encode(credentials).decode("ascii") + return f"Basic {b64}" + + +@router.get("/auth/linear/connector/add") +async def connect_linear(space_id: int, user: User = Depends(current_active_user)): + """ + Initiate Linear OAuth flow. + + Args: + space_id: The search space ID + user: Current authenticated user + + Returns: + Authorization URL for redirect + """ + try: + if not space_id: + raise HTTPException(status_code=400, detail="space_id is required") + + if not config.LINEAR_CLIENT_ID: + raise HTTPException(status_code=500, detail="Linear OAuth not configured.") + + if not config.SECRET_KEY: + raise HTTPException( + status_code=500, detail="SECRET_KEY not configured for OAuth security." + ) + + # Generate secure state parameter with HMAC signature + state_manager = get_state_manager() + state_encoded = state_manager.generate_secure_state(space_id, user.id) + + # Build authorization URL + from urllib.parse import urlencode + + auth_params = { + "client_id": config.LINEAR_CLIENT_ID, + "response_type": "code", + "redirect_uri": config.LINEAR_REDIRECT_URI, + "scope": " ".join(SCOPES), + "state": state_encoded, + } + + auth_url = f"{AUTHORIZATION_URL}?{urlencode(auth_params)}" + + logger.info(f"Generated Linear OAuth URL for user {user.id}, space {space_id}") + return {"auth_url": auth_url} + + except Exception as e: + logger.error(f"Failed to initiate Linear OAuth: {e!s}", exc_info=True) + raise HTTPException( + status_code=500, detail=f"Failed to initiate Linear OAuth: {e!s}" + ) from e + + +@router.get("/auth/linear/connector/callback") +async def linear_callback( + request: Request, + code: str | None = None, + error: str | None = None, + state: str | None = None, + session: AsyncSession = Depends(get_async_session), +): + """ + Handle Linear OAuth callback. + + Args: + request: FastAPI request object + code: Authorization code from Linear (if user granted access) + error: Error code from Linear (if user denied access or error occurred) + state: State parameter containing user/space info + session: Database session + + Returns: + Redirect response to frontend + """ + try: + # Handle OAuth errors (e.g., user denied access) + if error: + logger.warning(f"Linear OAuth error: {error}") + # Try to decode state to get space_id for redirect, but don't fail if it's invalid + space_id = None + if state: + try: + state_manager = get_state_manager() + data = state_manager.validate_state(state) + space_id = data.get("space_id") + except Exception: + # If state is invalid, we'll redirect without space_id + logger.warning("Failed to validate state in error handler") + + # Redirect to frontend with error parameter + if space_id: + return RedirectResponse( + url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=linear_oauth_denied" + ) + else: + return RedirectResponse( + url=f"{config.NEXT_FRONTEND_URL}/dashboard?error=linear_oauth_denied" + ) + + # Validate required parameters for successful flow + if not code: + raise HTTPException(status_code=400, detail="Missing authorization code") + if not state: + raise HTTPException(status_code=400, detail="Missing state parameter") + + # Validate and decode state with signature verification + state_manager = get_state_manager() + try: + data = state_manager.validate_state(state) + except HTTPException: + raise + except Exception as e: + raise HTTPException( + status_code=400, detail=f"Invalid state parameter: {e!s}" + ) from e + + user_id = UUID(data["user_id"]) + space_id = data["space_id"] + + # Validate redirect URI (security: ensure it matches configured value) + if not config.LINEAR_REDIRECT_URI: + raise HTTPException( + status_code=500, detail="LINEAR_REDIRECT_URI not configured" + ) + + # Exchange authorization code for access token + auth_header = make_basic_auth_header( + config.LINEAR_CLIENT_ID, config.LINEAR_CLIENT_SECRET + ) + + token_data = { + "grant_type": "authorization_code", + "code": code, + "redirect_uri": config.LINEAR_REDIRECT_URI, # Use stored value, not from request + } + + async with httpx.AsyncClient() as client: + token_response = await client.post( + TOKEN_URL, + data=token_data, + headers={ + "Content-Type": "application/x-www-form-urlencoded", + "Authorization": auth_header, + }, + timeout=30.0, + ) + + if token_response.status_code != 200: + error_detail = token_response.text + try: + error_json = token_response.json() + error_detail = error_json.get("error_description", error_detail) + except Exception: + pass + raise HTTPException( + status_code=400, detail=f"Token exchange failed: {error_detail}" + ) + + token_json = token_response.json() + + # Encrypt sensitive tokens before storing + token_encryption = get_token_encryption() + access_token = token_json.get("access_token") + refresh_token = token_json.get("refresh_token") + + if not access_token: + raise HTTPException( + status_code=400, detail="No access token received from Linear" + ) + + # Calculate expiration time (UTC, tz-aware) + expires_at = None + if token_json.get("expires_in"): + now_utc = datetime.now(UTC) + expires_at = now_utc + timedelta(seconds=int(token_json["expires_in"])) + + # Store the encrypted access token and refresh token in connector config + connector_config = { + "access_token": token_encryption.encrypt_token(access_token), + "refresh_token": token_encryption.encrypt_token(refresh_token) + if refresh_token + else None, + "token_type": token_json.get("token_type", "Bearer"), + "expires_in": token_json.get("expires_in"), + "expires_at": expires_at.isoformat() if expires_at else None, + "scope": token_json.get("scope"), + # Mark that tokens are encrypted for backward compatibility + "_token_encrypted": True, + } + + # Check if connector already exists for this search space and user + existing_connector_result = await session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.search_space_id == space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.LINEAR_CONNECTOR, + ) + ) + existing_connector = existing_connector_result.scalars().first() + + if existing_connector: + # Update existing connector + existing_connector.config = connector_config + existing_connector.name = "Linear Connector" + existing_connector.is_indexable = True + logger.info( + f"Updated existing Linear connector for user {user_id} in space {space_id}" + ) + else: + # Create new connector + new_connector = SearchSourceConnector( + name="Linear Connector", + connector_type=SearchSourceConnectorType.LINEAR_CONNECTOR, + is_indexable=True, + config=connector_config, + search_space_id=space_id, + user_id=user_id, + ) + session.add(new_connector) + logger.info( + f"Created new Linear connector for user {user_id} in space {space_id}" + ) + + try: + await session.commit() + logger.info(f"Successfully saved Linear connector for user {user_id}") + + # Redirect to the frontend with success params + return RedirectResponse( + url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&success=true&connector=linear-connector" + ) + + except ValidationError as e: + await session.rollback() + raise HTTPException( + status_code=422, detail=f"Validation error: {e!s}" + ) from e + except IntegrityError as e: + await session.rollback() + raise HTTPException( + status_code=409, + detail=f"Integrity error: A connector with this type already exists. {e!s}", + ) from e + except Exception as e: + logger.error(f"Failed to create search source connector: {e!s}") + await session.rollback() + raise HTTPException( + status_code=500, + detail=f"Failed to create search source connector: {e!s}", + ) from e + + except HTTPException: + raise + except Exception as e: + logger.error(f"Failed to complete Linear OAuth: {e!s}", exc_info=True) + raise HTTPException( + status_code=500, detail=f"Failed to complete Linear OAuth: {e!s}" + ) from e + + +async def refresh_linear_token( + session: AsyncSession, connector: SearchSourceConnector +) -> SearchSourceConnector: + """ + Refresh the Linear access token for a connector. + + Args: + session: Database session + connector: Linear connector to refresh + + Returns: + Updated connector object + """ + try: + logger.info(f"Refreshing Linear token for connector {connector.id}") + + credentials = LinearAuthCredentialsBase.from_dict(connector.config) + + # Decrypt tokens if they are encrypted + token_encryption = get_token_encryption() + is_encrypted = connector.config.get("_token_encrypted", False) + + refresh_token = credentials.refresh_token + if is_encrypted and refresh_token: + try: + refresh_token = token_encryption.decrypt_token(refresh_token) + except Exception as e: + logger.error(f"Failed to decrypt refresh token: {e!s}") + raise HTTPException( + status_code=500, detail="Failed to decrypt stored refresh token" + ) from e + + if not refresh_token: + raise HTTPException( + status_code=400, + detail="No refresh token available. Please re-authenticate.", + ) + + auth_header = make_basic_auth_header( + config.LINEAR_CLIENT_ID, config.LINEAR_CLIENT_SECRET + ) + + # Prepare token refresh data + refresh_data = { + "grant_type": "refresh_token", + "refresh_token": refresh_token, + } + + async with httpx.AsyncClient() as client: + token_response = await client.post( + TOKEN_URL, + data=refresh_data, + headers={ + "Content-Type": "application/x-www-form-urlencoded", + "Authorization": auth_header, + }, + timeout=30.0, + ) + + if token_response.status_code != 200: + error_detail = token_response.text + try: + error_json = token_response.json() + error_detail = error_json.get("error_description", error_detail) + except Exception: + pass + raise HTTPException( + status_code=400, detail=f"Token refresh failed: {error_detail}" + ) + + token_json = token_response.json() + + # Calculate expiration time (UTC, tz-aware) + expires_at = None + expires_in = token_json.get("expires_in") + if expires_in: + now_utc = datetime.now(UTC) + expires_at = now_utc + timedelta(seconds=int(expires_in)) + + # Encrypt new tokens before storing + access_token = token_json.get("access_token") + new_refresh_token = token_json.get("refresh_token") + + if not access_token: + raise HTTPException( + status_code=400, detail="No access token received from Linear refresh" + ) + + # Update credentials object with encrypted tokens + credentials.access_token = token_encryption.encrypt_token(access_token) + if new_refresh_token: + credentials.refresh_token = token_encryption.encrypt_token( + new_refresh_token + ) + credentials.expires_in = expires_in + credentials.expires_at = expires_at + credentials.scope = token_json.get("scope") + + # Update connector config with encrypted tokens + credentials_dict = credentials.to_dict() + credentials_dict["_token_encrypted"] = True + connector.config = credentials_dict + await session.commit() + await session.refresh(connector) + + logger.info(f"Successfully refreshed Linear token for connector {connector.id}") + + return connector + except HTTPException: + raise + except Exception as e: + logger.error(f"Failed to refresh Linear token: {e!s}", exc_info=True) + raise HTTPException( + status_code=500, detail=f"Failed to refresh Linear token: {e!s}" + ) from e diff --git a/surfsense_backend/app/routes/notion_add_connector_route.py b/surfsense_backend/app/routes/notion_add_connector_route.py new file mode 100644 index 000000000..462ac398c --- /dev/null +++ b/surfsense_backend/app/routes/notion_add_connector_route.py @@ -0,0 +1,459 @@ +""" +Notion Connector OAuth Routes. + +Handles OAuth 2.0 authentication flow for Notion connector. +""" + +import logging +from datetime import UTC, datetime, timedelta +from uuid import UUID + +import httpx +from fastapi import APIRouter, Depends, HTTPException, Request +from fastapi.responses import RedirectResponse +from pydantic import ValidationError +from sqlalchemy.exc import IntegrityError +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.future import select + +from app.config import config +from app.db import ( + SearchSourceConnector, + SearchSourceConnectorType, + User, + get_async_session, +) +from app.schemas.notion_auth_credentials import NotionAuthCredentialsBase +from app.users import current_active_user +from app.utils.oauth_security import OAuthStateManager, TokenEncryption + +logger = logging.getLogger(__name__) + +router = APIRouter() + +# Notion OAuth endpoints +AUTHORIZATION_URL = "https://api.notion.com/v1/oauth/authorize" +TOKEN_URL = "https://api.notion.com/v1/oauth/token" + +# Initialize security utilities +_state_manager = None +_token_encryption = None + + +def get_state_manager() -> OAuthStateManager: + """Get or create OAuth state manager instance.""" + global _state_manager + if _state_manager is None: + if not config.SECRET_KEY: + raise ValueError("SECRET_KEY must be set for OAuth security") + _state_manager = OAuthStateManager(config.SECRET_KEY) + return _state_manager + + +def get_token_encryption() -> TokenEncryption: + """Get or create token encryption instance.""" + global _token_encryption + if _token_encryption is None: + if not config.SECRET_KEY: + raise ValueError("SECRET_KEY must be set for token encryption") + _token_encryption = TokenEncryption(config.SECRET_KEY) + return _token_encryption + + +def make_basic_auth_header(client_id: str, client_secret: str) -> str: + """Create Basic Auth header for Notion OAuth.""" + import base64 + + credentials = f"{client_id}:{client_secret}".encode() + b64 = base64.b64encode(credentials).decode("ascii") + return f"Basic {b64}" + + +@router.get("/auth/notion/connector/add") +async def connect_notion(space_id: int, user: User = Depends(current_active_user)): + """ + Initiate Notion OAuth flow. + + Args: + space_id: The search space ID + user: Current authenticated user + + Returns: + Authorization URL for redirect + """ + try: + if not space_id: + raise HTTPException(status_code=400, detail="space_id is required") + + if not config.NOTION_CLIENT_ID: + raise HTTPException(status_code=500, detail="Notion OAuth not configured.") + + if not config.SECRET_KEY: + raise HTTPException( + status_code=500, detail="SECRET_KEY not configured for OAuth security." + ) + + # Generate secure state parameter with HMAC signature + state_manager = get_state_manager() + state_encoded = state_manager.generate_secure_state(space_id, user.id) + + # Build authorization URL + from urllib.parse import urlencode + + auth_params = { + "client_id": config.NOTION_CLIENT_ID, + "response_type": "code", + "owner": "user", # Allows both admins and members to authorize + "redirect_uri": config.NOTION_REDIRECT_URI, + "state": state_encoded, + } + + auth_url = f"{AUTHORIZATION_URL}?{urlencode(auth_params)}" + + logger.info(f"Generated Notion OAuth URL for user {user.id}, space {space_id}") + return {"auth_url": auth_url} + + except Exception as e: + logger.error(f"Failed to initiate Notion OAuth: {e!s}", exc_info=True) + raise HTTPException( + status_code=500, detail=f"Failed to initiate Notion OAuth: {e!s}" + ) from e + + +@router.get("/auth/notion/connector/callback") +async def notion_callback( + request: Request, + code: str | None = None, + error: str | None = None, + state: str | None = None, + session: AsyncSession = Depends(get_async_session), +): + """ + Handle Notion OAuth callback. + + Args: + request: FastAPI request object + code: Authorization code from Notion (if user granted access) + error: Error code from Notion (if user denied access or error occurred) + state: State parameter containing user/space info + session: Database session + + Returns: + Redirect response to frontend + """ + try: + # Handle OAuth errors (e.g., user denied access) + if error: + logger.warning(f"Notion OAuth error: {error}") + # Try to decode state to get space_id for redirect, but don't fail if it's invalid + space_id = None + if state: + try: + state_manager = get_state_manager() + data = state_manager.validate_state(state) + space_id = data.get("space_id") + except Exception: + # If state is invalid, we'll redirect without space_id + logger.warning("Failed to validate state in error handler") + + # Redirect to frontend with error parameter + if space_id: + return RedirectResponse( + url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=notion_oauth_denied" + ) + else: + return RedirectResponse( + url=f"{config.NEXT_FRONTEND_URL}/dashboard?error=notion_oauth_denied" + ) + + # Validate required parameters for successful flow + if not code: + raise HTTPException(status_code=400, detail="Missing authorization code") + if not state: + raise HTTPException(status_code=400, detail="Missing state parameter") + + # Validate and decode state with signature verification + state_manager = get_state_manager() + try: + data = state_manager.validate_state(state) + except HTTPException: + raise + except Exception as e: + raise HTTPException( + status_code=400, detail=f"Invalid state parameter: {e!s}" + ) from e + + user_id = UUID(data["user_id"]) + space_id = data["space_id"] + + # Validate redirect URI (security: ensure it matches configured value) + # Note: Notion doesn't send redirect_uri in callback, but we validate + # that we're using the configured one in token exchange + if not config.NOTION_REDIRECT_URI: + raise HTTPException( + status_code=500, detail="NOTION_REDIRECT_URI not configured" + ) + + # Exchange authorization code for access token + auth_header = make_basic_auth_header( + config.NOTION_CLIENT_ID, config.NOTION_CLIENT_SECRET + ) + + token_data = { + "grant_type": "authorization_code", + "code": code, + "redirect_uri": config.NOTION_REDIRECT_URI, # Use stored value, not from request + } + + async with httpx.AsyncClient() as client: + token_response = await client.post( + TOKEN_URL, + json=token_data, + headers={ + "Content-Type": "application/json", + "Authorization": auth_header, + }, + timeout=30.0, + ) + + if token_response.status_code != 200: + error_detail = token_response.text + try: + error_json = token_response.json() + error_detail = error_json.get("error_description", error_detail) + except Exception: + pass + raise HTTPException( + status_code=400, detail=f"Token exchange failed: {error_detail}" + ) + + token_json = token_response.json() + + # Encrypt sensitive tokens before storing + token_encryption = get_token_encryption() + access_token = token_json.get("access_token") + refresh_token = token_json.get("refresh_token") + if not access_token: + raise HTTPException( + status_code=400, detail="No access token received from Notion" + ) + + # Calculate expiration time (UTC, tz-aware) + expires_at = None + expires_in = token_json.get("expires_in") + if expires_in: + now_utc = datetime.now(UTC) + expires_at = now_utc + timedelta(seconds=int(expires_in)) + + # Notion returns access_token, refresh_token (if available), and workspace information + # Store the encrypted tokens and workspace info in connector config + connector_config = { + "access_token": token_encryption.encrypt_token(access_token), + "refresh_token": token_encryption.encrypt_token(refresh_token) + if refresh_token + else None, + "expires_in": expires_in, + "expires_at": expires_at.isoformat() if expires_at else None, + "workspace_id": token_json.get("workspace_id"), + "workspace_name": token_json.get("workspace_name"), + "workspace_icon": token_json.get("workspace_icon"), + "bot_id": token_json.get("bot_id"), + # Mark that token is encrypted for backward compatibility + "_token_encrypted": True, + } + + # Check if connector already exists for this search space and user + existing_connector_result = await session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.search_space_id == space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.NOTION_CONNECTOR, + ) + ) + existing_connector = existing_connector_result.scalars().first() + + if existing_connector: + # Update existing connector + existing_connector.config = connector_config + existing_connector.name = "Notion Connector" + existing_connector.is_indexable = True + logger.info( + f"Updated existing Notion connector for user {user_id} in space {space_id}" + ) + else: + # Create new connector + new_connector = SearchSourceConnector( + name="Notion Connector", + connector_type=SearchSourceConnectorType.NOTION_CONNECTOR, + is_indexable=True, + config=connector_config, + search_space_id=space_id, + user_id=user_id, + ) + session.add(new_connector) + logger.info( + f"Created new Notion connector for user {user_id} in space {space_id}" + ) + + try: + await session.commit() + logger.info(f"Successfully saved Notion connector for user {user_id}") + + # Redirect to the frontend with success params + return RedirectResponse( + url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&success=true&connector=notion-connector" + ) + + except ValidationError as e: + await session.rollback() + raise HTTPException( + status_code=422, detail=f"Validation error: {e!s}" + ) from e + except IntegrityError as e: + await session.rollback() + raise HTTPException( + status_code=409, + detail=f"Integrity error: A connector with this type already exists. {e!s}", + ) from e + except Exception as e: + logger.error(f"Failed to create search source connector: {e!s}") + await session.rollback() + raise HTTPException( + status_code=500, + detail=f"Failed to create search source connector: {e!s}", + ) from e + + except HTTPException: + raise + except Exception as e: + logger.error(f"Failed to complete Notion OAuth: {e!s}", exc_info=True) + raise HTTPException( + status_code=500, detail=f"Failed to complete Notion OAuth: {e!s}" + ) from e + + +async def refresh_notion_token( + session: AsyncSession, connector: SearchSourceConnector +) -> SearchSourceConnector: + """ + Refresh the Notion access token for a connector. + + Args: + session: Database session + connector: Notion connector to refresh + + Returns: + Updated connector object + """ + try: + logger.info(f"Refreshing Notion token for connector {connector.id}") + + credentials = NotionAuthCredentialsBase.from_dict(connector.config) + + # Decrypt tokens if they are encrypted + token_encryption = get_token_encryption() + is_encrypted = connector.config.get("_token_encrypted", False) + + refresh_token = credentials.refresh_token + if is_encrypted and refresh_token: + try: + refresh_token = token_encryption.decrypt_token(refresh_token) + except Exception as e: + logger.error(f"Failed to decrypt refresh token: {e!s}") + raise HTTPException( + status_code=500, detail="Failed to decrypt stored refresh token" + ) from e + + if not refresh_token: + raise HTTPException( + status_code=400, + detail="No refresh token available. Please re-authenticate.", + ) + + auth_header = make_basic_auth_header( + config.NOTION_CLIENT_ID, config.NOTION_CLIENT_SECRET + ) + + # Prepare token refresh data + refresh_data = { + "grant_type": "refresh_token", + "refresh_token": refresh_token, + } + + async with httpx.AsyncClient() as client: + token_response = await client.post( + TOKEN_URL, + json=refresh_data, + headers={ + "Content-Type": "application/json", + "Authorization": auth_header, + }, + timeout=30.0, + ) + + if token_response.status_code != 200: + error_detail = token_response.text + try: + error_json = token_response.json() + error_detail = error_json.get("error_description", error_detail) + except Exception: + pass + raise HTTPException( + status_code=400, detail=f"Token refresh failed: {error_detail}" + ) + + token_json = token_response.json() + + # Calculate expiration time (UTC, tz-aware) + expires_at = None + expires_in = token_json.get("expires_in") + if expires_in: + now_utc = datetime.now(UTC) + expires_at = now_utc + timedelta(seconds=int(expires_in)) + + # Encrypt new tokens before storing + access_token = token_json.get("access_token") + new_refresh_token = token_json.get("refresh_token") + + if not access_token: + raise HTTPException( + status_code=400, detail="No access token received from Notion refresh" + ) + + # Update credentials object with encrypted tokens + credentials.access_token = token_encryption.encrypt_token(access_token) + if new_refresh_token: + credentials.refresh_token = token_encryption.encrypt_token( + new_refresh_token + ) + credentials.expires_in = expires_in + credentials.expires_at = expires_at + + # Preserve workspace info + if not credentials.workspace_id: + credentials.workspace_id = connector.config.get("workspace_id") + if not credentials.workspace_name: + credentials.workspace_name = connector.config.get("workspace_name") + if not credentials.workspace_icon: + credentials.workspace_icon = connector.config.get("workspace_icon") + if not credentials.bot_id: + credentials.bot_id = connector.config.get("bot_id") + + # Update connector config with encrypted tokens + credentials_dict = credentials.to_dict() + credentials_dict["_token_encrypted"] = True + connector.config = credentials_dict + await session.commit() + await session.refresh(connector) + + logger.info(f"Successfully refreshed Notion token for connector {connector.id}") + + return connector + except HTTPException: + raise + except Exception as e: + logger.error(f"Failed to refresh Notion token: {e!s}", exc_info=True) + raise HTTPException( + status_code=500, detail=f"Failed to refresh Notion token: {e!s}" + ) from e diff --git a/surfsense_backend/app/routes/slack_add_connector_route.py b/surfsense_backend/app/routes/slack_add_connector_route.py new file mode 100644 index 000000000..71a362119 --- /dev/null +++ b/surfsense_backend/app/routes/slack_add_connector_route.py @@ -0,0 +1,478 @@ +""" +Slack Connector OAuth Routes. + +Handles OAuth 2.0 authentication flow for Slack connector. +""" + +import logging +from datetime import UTC, datetime, timedelta +from uuid import UUID + +import httpx +from fastapi import APIRouter, Depends, HTTPException, Request +from fastapi.responses import RedirectResponse +from pydantic import ValidationError +from sqlalchemy.exc import IntegrityError +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.future import select + +from app.config import config +from app.db import ( + SearchSourceConnector, + SearchSourceConnectorType, + User, + get_async_session, +) +from app.schemas.slack_auth_credentials import SlackAuthCredentialsBase +from app.users import current_active_user +from app.utils.oauth_security import OAuthStateManager, TokenEncryption + +logger = logging.getLogger(__name__) + +router = APIRouter() + +# Slack OAuth endpoints +AUTHORIZATION_URL = "https://slack.com/oauth/v2/authorize" +TOKEN_URL = "https://slack.com/api/oauth.v2.access" + +# OAuth scopes for Slack (Bot Token) +SCOPES = [ + "channels:history", # Read messages in public channels + "channels:read", # View basic information about public channels + "groups:history", # Read messages in private channels + "groups:read", # View basic information about private channels + "im:history", # Read messages in direct messages + "mpim:history", # Read messages in group direct messages + "users:read", # Read user information +] + +# Initialize security utilities +_state_manager = None +_token_encryption = None + + +def get_state_manager() -> OAuthStateManager: + """Get or create OAuth state manager instance.""" + global _state_manager + if _state_manager is None: + if not config.SECRET_KEY: + raise ValueError("SECRET_KEY must be set for OAuth security") + _state_manager = OAuthStateManager(config.SECRET_KEY) + return _state_manager + + +def get_token_encryption() -> TokenEncryption: + """Get or create token encryption instance.""" + global _token_encryption + if _token_encryption is None: + if not config.SECRET_KEY: + raise ValueError("SECRET_KEY must be set for token encryption") + _token_encryption = TokenEncryption(config.SECRET_KEY) + return _token_encryption + + +@router.get("/auth/slack/connector/add") +async def connect_slack(space_id: int, user: User = Depends(current_active_user)): + """ + Initiate Slack OAuth flow. + + Args: + space_id: The search space ID + user: Current authenticated user + + Returns: + Authorization URL for redirect + """ + try: + if not space_id: + raise HTTPException(status_code=400, detail="space_id is required") + + if not config.SLACK_CLIENT_ID: + raise HTTPException(status_code=500, detail="Slack OAuth not configured.") + + if not config.SECRET_KEY: + raise HTTPException( + status_code=500, detail="SECRET_KEY not configured for OAuth security." + ) + + # Generate secure state parameter with HMAC signature + state_manager = get_state_manager() + state_encoded = state_manager.generate_secure_state(space_id, user.id) + + # Build authorization URL + from urllib.parse import urlencode + + auth_params = { + "client_id": config.SLACK_CLIENT_ID, + "scope": ",".join(SCOPES), + "redirect_uri": config.SLACK_REDIRECT_URI, + "state": state_encoded, + } + + auth_url = f"{AUTHORIZATION_URL}?{urlencode(auth_params)}" + + logger.info(f"Generated Slack OAuth URL for user {user.id}, space {space_id}") + return {"auth_url": auth_url} + + except Exception as e: + logger.error(f"Failed to initiate Slack OAuth: {e!s}", exc_info=True) + raise HTTPException( + status_code=500, detail=f"Failed to initiate Slack OAuth: {e!s}" + ) from e + + +@router.get("/auth/slack/connector/callback") +async def slack_callback( + request: Request, + code: str | None = None, + error: str | None = None, + state: str | None = None, + session: AsyncSession = Depends(get_async_session), +): + """ + Handle Slack OAuth callback. + + Args: + request: FastAPI request object + code: Authorization code from Slack (if user granted access) + error: Error code from Slack (if user denied access or error occurred) + state: State parameter containing user/space info + session: Database session + + Returns: + Redirect response to frontend + """ + try: + # Handle OAuth errors (e.g., user denied access) + if error: + logger.warning(f"Slack OAuth error: {error}") + # Try to decode state to get space_id for redirect, but don't fail if it's invalid + space_id = None + if state: + try: + state_manager = get_state_manager() + data = state_manager.validate_state(state) + space_id = data.get("space_id") + except Exception: + # If state is invalid, we'll redirect without space_id + logger.warning("Failed to validate state in error handler") + + # Redirect to frontend with error parameter + if space_id: + return RedirectResponse( + url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=slack_oauth_denied" + ) + else: + return RedirectResponse( + url=f"{config.NEXT_FRONTEND_URL}/dashboard?error=slack_oauth_denied" + ) + + # Validate required parameters for successful flow + if not code: + raise HTTPException(status_code=400, detail="Missing authorization code") + if not state: + raise HTTPException(status_code=400, detail="Missing state parameter") + + # Validate and decode state with signature verification + state_manager = get_state_manager() + try: + data = state_manager.validate_state(state) + except HTTPException: + raise + except Exception as e: + raise HTTPException( + status_code=400, detail=f"Invalid state parameter: {e!s}" + ) from e + + user_id = UUID(data["user_id"]) + space_id = data["space_id"] + + # Validate redirect URI (security: ensure it matches configured value) + if not config.SLACK_REDIRECT_URI: + raise HTTPException( + status_code=500, detail="SLACK_REDIRECT_URI not configured" + ) + + # Exchange authorization code for access token + token_data = { + "client_id": config.SLACK_CLIENT_ID, + "client_secret": config.SLACK_CLIENT_SECRET, + "code": code, + "redirect_uri": config.SLACK_REDIRECT_URI, + } + + async with httpx.AsyncClient() as client: + token_response = await client.post( + TOKEN_URL, + data=token_data, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + timeout=30.0, + ) + + if token_response.status_code != 200: + error_detail = token_response.text + try: + error_json = token_response.json() + error_detail = error_json.get("error", error_detail) + except Exception: + pass + raise HTTPException( + status_code=400, detail=f"Token exchange failed: {error_detail}" + ) + + token_json = token_response.json() + + # Slack OAuth v2 returns success status in the JSON + if not token_json.get("ok", False): + error_msg = token_json.get("error", "Unknown error") + raise HTTPException( + status_code=400, detail=f"Slack OAuth error: {error_msg}" + ) + + # Extract bot token from Slack response + # Slack OAuth v2 returns: { "ok": true, "access_token": "...", "bot": { "bot_user_id": "...", "bot_access_token": "xoxb-..." }, "refresh_token": "...", ... } + bot_token = None + if token_json.get("bot") and token_json["bot"].get("bot_access_token"): + bot_token = token_json["bot"]["bot_access_token"] + elif token_json.get("access_token"): + # Fallback to access_token if bot token not available + bot_token = token_json["access_token"] + else: + raise HTTPException( + status_code=400, detail="No bot token received from Slack" + ) + + # Extract refresh token if available (for token rotation) + refresh_token = token_json.get("refresh_token") + + # Encrypt sensitive tokens before storing + token_encryption = get_token_encryption() + + # Calculate expiration time (UTC, tz-aware) + # Slack tokens don't expire by default, but we'll store expiration info if provided + expires_at = None + if token_json.get("expires_in"): + now_utc = datetime.now(UTC) + expires_at = now_utc + timedelta(seconds=int(token_json["expires_in"])) + + # Store the encrypted bot token and refresh token in connector config + connector_config = { + "bot_token": token_encryption.encrypt_token(bot_token), + "refresh_token": token_encryption.encrypt_token(refresh_token) + if refresh_token + else None, + "bot_user_id": token_json.get("bot", {}).get("bot_user_id"), + "team_id": token_json.get("team", {}).get("id"), + "team_name": token_json.get("team", {}).get("name"), + "token_type": token_json.get("token_type", "Bearer"), + "expires_in": token_json.get("expires_in"), + "expires_at": expires_at.isoformat() if expires_at else None, + "scope": token_json.get("scope"), + # Mark that tokens are encrypted for backward compatibility + "_token_encrypted": True, + } + + # Check if connector already exists for this search space and user + existing_connector_result = await session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.search_space_id == space_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.connector_type + == SearchSourceConnectorType.SLACK_CONNECTOR, + ) + ) + existing_connector = existing_connector_result.scalars().first() + + if existing_connector: + # Update existing connector + existing_connector.config = connector_config + existing_connector.name = "Slack Connector" + existing_connector.is_indexable = True + logger.info( + f"Updated existing Slack connector for user {user_id} in space {space_id}" + ) + else: + # Create new connector + new_connector = SearchSourceConnector( + name="Slack Connector", + connector_type=SearchSourceConnectorType.SLACK_CONNECTOR, + is_indexable=True, + config=connector_config, + search_space_id=space_id, + user_id=user_id, + ) + session.add(new_connector) + logger.info( + f"Created new Slack connector for user {user_id} in space {space_id}" + ) + + try: + await session.commit() + logger.info(f"Successfully saved Slack connector for user {user_id}") + + # Redirect to the frontend with success params + return RedirectResponse( + url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&success=true&connector=slack-connector" + ) + + except ValidationError as e: + await session.rollback() + raise HTTPException( + status_code=422, detail=f"Validation error: {e!s}" + ) from e + except IntegrityError as e: + await session.rollback() + raise HTTPException( + status_code=409, + detail=f"Integrity error: A connector with this type already exists. {e!s}", + ) from e + except Exception as e: + logger.error(f"Failed to create search source connector: {e!s}") + await session.rollback() + raise HTTPException( + status_code=500, + detail=f"Failed to create search source connector: {e!s}", + ) from e + + except HTTPException: + raise + except Exception as e: + logger.error(f"Failed to complete Slack OAuth: {e!s}", exc_info=True) + raise HTTPException( + status_code=500, detail=f"Failed to complete Slack OAuth: {e!s}" + ) from e + + +async def refresh_slack_token( + session: AsyncSession, connector: SearchSourceConnector +) -> SearchSourceConnector: + """ + Refresh the Slack bot token for a connector. + + Args: + session: Database session + connector: Slack connector to refresh + + Returns: + Updated connector object + """ + try: + logger.info(f"Refreshing Slack token for connector {connector.id}") + + credentials = SlackAuthCredentialsBase.from_dict(connector.config) + + # Decrypt tokens if they are encrypted + token_encryption = get_token_encryption() + is_encrypted = connector.config.get("_token_encrypted", False) + + refresh_token = credentials.refresh_token + if is_encrypted and refresh_token: + try: + refresh_token = token_encryption.decrypt_token(refresh_token) + except Exception as e: + logger.error(f"Failed to decrypt refresh token: {e!s}") + raise HTTPException( + status_code=500, detail="Failed to decrypt stored refresh token" + ) from e + + if not refresh_token: + raise HTTPException( + status_code=400, + detail="No refresh token available. Please re-authenticate.", + ) + + # Slack uses oauth.v2.access for token refresh with grant_type=refresh_token + refresh_data = { + "client_id": config.SLACK_CLIENT_ID, + "client_secret": config.SLACK_CLIENT_SECRET, + "grant_type": "refresh_token", + "refresh_token": refresh_token, + } + + async with httpx.AsyncClient() as client: + token_response = await client.post( + TOKEN_URL, + data=refresh_data, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + timeout=30.0, + ) + + if token_response.status_code != 200: + error_detail = token_response.text + try: + error_json = token_response.json() + error_detail = error_json.get("error", error_detail) + except Exception: + pass + raise HTTPException( + status_code=400, detail=f"Token refresh failed: {error_detail}" + ) + + token_json = token_response.json() + + # Slack OAuth v2 returns success status in the JSON + if not token_json.get("ok", False): + error_msg = token_json.get("error", "Unknown error") + raise HTTPException( + status_code=400, detail=f"Slack OAuth refresh error: {error_msg}" + ) + + # Extract bot token from refresh response + bot_token = None + if token_json.get("bot") and token_json["bot"].get("bot_access_token"): + bot_token = token_json["bot"]["bot_access_token"] + elif token_json.get("access_token"): + bot_token = token_json["access_token"] + else: + raise HTTPException( + status_code=400, detail="No bot token received from Slack refresh" + ) + + # Get new refresh token if provided (Slack may rotate refresh tokens) + new_refresh_token = token_json.get("refresh_token") + + # Calculate expiration time (UTC, tz-aware) + expires_at = None + expires_in = token_json.get("expires_in") + if expires_in: + now_utc = datetime.now(UTC) + expires_at = now_utc + timedelta(seconds=int(expires_in)) + + # Update credentials object with encrypted tokens + credentials.bot_token = token_encryption.encrypt_token(bot_token) + if new_refresh_token: + credentials.refresh_token = token_encryption.encrypt_token( + new_refresh_token + ) + credentials.expires_in = expires_in + credentials.expires_at = expires_at + credentials.scope = token_json.get("scope") + + # Preserve team info + if not credentials.team_id: + credentials.team_id = connector.config.get("team_id") + if not credentials.team_name: + credentials.team_name = connector.config.get("team_name") + if not credentials.bot_user_id: + credentials.bot_user_id = connector.config.get("bot_user_id") + + # Update connector config with encrypted tokens + credentials_dict = credentials.to_dict() + credentials_dict["_token_encrypted"] = True + connector.config = credentials_dict + await session.commit() + await session.refresh(connector) + + logger.info(f"Successfully refreshed Slack token for connector {connector.id}") + + return connector + except HTTPException: + raise + except Exception as e: + logger.error( + f"Failed to refresh Slack token for connector {connector.id}: {e!s}", + exc_info=True, + ) + raise HTTPException( + status_code=500, detail=f"Failed to refresh Slack token: {e!s}" + ) from e diff --git a/surfsense_backend/app/schemas/discord_auth_credentials.py b/surfsense_backend/app/schemas/discord_auth_credentials.py new file mode 100644 index 000000000..0c18a7554 --- /dev/null +++ b/surfsense_backend/app/schemas/discord_auth_credentials.py @@ -0,0 +1,76 @@ +from datetime import UTC, datetime + +from pydantic import BaseModel, field_validator + + +class DiscordAuthCredentialsBase(BaseModel): + bot_token: str + refresh_token: str | None = None + token_type: str = "Bearer" + expires_in: int | None = None + expires_at: datetime | None = None + scope: str | None = None + bot_user_id: str | None = None + guild_id: str | None = None + guild_name: str | None = None + + @property + def is_expired(self) -> bool: + """Check if the credentials have expired.""" + if self.expires_at is None: + return False # Long-lived token, treat as not expired + return self.expires_at <= datetime.now(UTC) + + @property + def is_refreshable(self) -> bool: + """Check if the credentials can be refreshed.""" + return self.refresh_token is not None + + def to_dict(self) -> dict: + """Convert credentials to dictionary for storage.""" + return { + "bot_token": self.bot_token, + "refresh_token": self.refresh_token, + "token_type": self.token_type, + "expires_in": self.expires_in, + "expires_at": self.expires_at.isoformat() if self.expires_at else None, + "scope": self.scope, + "bot_user_id": self.bot_user_id, + "guild_id": self.guild_id, + "guild_name": self.guild_name, + } + + @classmethod + def from_dict(cls, data: dict) -> "DiscordAuthCredentialsBase": + """Create credentials from dictionary.""" + expires_at = None + if data.get("expires_at"): + expires_at = datetime.fromisoformat(data["expires_at"]) + + return cls( + bot_token=data.get("bot_token", ""), + refresh_token=data.get("refresh_token"), + token_type=data.get("token_type", "Bearer"), + expires_in=data.get("expires_in"), + expires_at=expires_at, + scope=data.get("scope"), + bot_user_id=data.get("bot_user_id"), + guild_id=data.get("guild_id"), + guild_name=data.get("guild_name"), + ) + + @field_validator("expires_at", mode="before") + @classmethod + def ensure_aware_utc(cls, v): + # Strings like "2025-08-26T14:46:57.367184" + if isinstance(v, str): + # add +00:00 if missing tz info + if v.endswith("Z"): + return datetime.fromisoformat(v.replace("Z", "+00:00")) + dt = datetime.fromisoformat(v) + return dt if dt.tzinfo else dt.replace(tzinfo=UTC) + # datetime objects + if isinstance(v, datetime): + return v if v.tzinfo else v.replace(tzinfo=UTC) + return v + diff --git a/surfsense_backend/app/schemas/linear_auth_credentials.py b/surfsense_backend/app/schemas/linear_auth_credentials.py new file mode 100644 index 000000000..99e8d9111 --- /dev/null +++ b/surfsense_backend/app/schemas/linear_auth_credentials.py @@ -0,0 +1,66 @@ +from datetime import UTC, datetime + +from pydantic import BaseModel, field_validator + + +class LinearAuthCredentialsBase(BaseModel): + access_token: str + refresh_token: str | None = None + token_type: str = "Bearer" + expires_in: int | None = None + expires_at: datetime | None = None + scope: str | None = None + + @property + def is_expired(self) -> bool: + """Check if the credentials have expired.""" + if self.expires_at is None: + return False + return self.expires_at <= datetime.now(UTC) + + @property + def is_refreshable(self) -> bool: + """Check if the credentials can be refreshed.""" + return self.refresh_token is not None + + def to_dict(self) -> dict: + """Convert credentials to dictionary for storage.""" + return { + "access_token": self.access_token, + "refresh_token": self.refresh_token, + "token_type": self.token_type, + "expires_in": self.expires_in, + "expires_at": self.expires_at.isoformat() if self.expires_at else None, + "scope": self.scope, + } + + @classmethod + def from_dict(cls, data: dict) -> "LinearAuthCredentialsBase": + """Create credentials from dictionary.""" + expires_at = None + if data.get("expires_at"): + expires_at = datetime.fromisoformat(data["expires_at"]) + + return cls( + access_token=data["access_token"], + refresh_token=data.get("refresh_token"), + token_type=data.get("token_type", "Bearer"), + expires_in=data.get("expires_in"), + expires_at=expires_at, + scope=data.get("scope"), + ) + + @field_validator("expires_at", mode="before") + @classmethod + def ensure_aware_utc(cls, v): + # Strings like "2025-08-26T14:46:57.367184" + if isinstance(v, str): + # add +00:00 if missing tz info + if v.endswith("Z"): + return datetime.fromisoformat(v.replace("Z", "+00:00")) + dt = datetime.fromisoformat(v) + return dt if dt.tzinfo else dt.replace(tzinfo=UTC) + # datetime objects + if isinstance(v, datetime): + return v if v.tzinfo else v.replace(tzinfo=UTC) + return v diff --git a/surfsense_backend/app/schemas/notion_auth_credentials.py b/surfsense_backend/app/schemas/notion_auth_credentials.py new file mode 100644 index 000000000..e66afb903 --- /dev/null +++ b/surfsense_backend/app/schemas/notion_auth_credentials.py @@ -0,0 +1,72 @@ +from datetime import UTC, datetime + +from pydantic import BaseModel, field_validator + + +class NotionAuthCredentialsBase(BaseModel): + access_token: str + refresh_token: str | None = None + expires_in: int | None = None + expires_at: datetime | None = None + workspace_id: str | None = None + workspace_name: str | None = None + workspace_icon: str | None = None + bot_id: str | None = None + + @property + def is_expired(self) -> bool: + """Check if the credentials have expired.""" + if self.expires_at is None: + return False # Long-lived token, treat as not expired + return self.expires_at <= datetime.now(UTC) + + @property + def is_refreshable(self) -> bool: + """Check if the credentials can be refreshed.""" + return self.refresh_token is not None + + def to_dict(self) -> dict: + """Convert credentials to dictionary for storage.""" + return { + "access_token": self.access_token, + "refresh_token": self.refresh_token, + "expires_in": self.expires_in, + "expires_at": self.expires_at.isoformat() if self.expires_at else None, + "workspace_id": self.workspace_id, + "workspace_name": self.workspace_name, + "workspace_icon": self.workspace_icon, + "bot_id": self.bot_id, + } + + @classmethod + def from_dict(cls, data: dict) -> "NotionAuthCredentialsBase": + """Create credentials from dictionary.""" + expires_at = None + if data.get("expires_at"): + expires_at = datetime.fromisoformat(data["expires_at"]) + + return cls( + access_token=data["access_token"], + refresh_token=data.get("refresh_token"), + expires_in=data.get("expires_in"), + expires_at=expires_at, + workspace_id=data.get("workspace_id"), + workspace_name=data.get("workspace_name"), + workspace_icon=data.get("workspace_icon"), + bot_id=data.get("bot_id"), + ) + + @field_validator("expires_at", mode="before") + @classmethod + def ensure_aware_utc(cls, v): + # Strings like "2025-08-26T14:46:57.367184" + if isinstance(v, str): + # add +00:00 if missing tz info + if v.endswith("Z"): + return datetime.fromisoformat(v.replace("Z", "+00:00")) + dt = datetime.fromisoformat(v) + return dt if dt.tzinfo else dt.replace(tzinfo=UTC) + # datetime objects + if isinstance(v, datetime): + return v if v.tzinfo else v.replace(tzinfo=UTC) + return v diff --git a/surfsense_backend/app/schemas/search_source_connector.py b/surfsense_backend/app/schemas/search_source_connector.py index 1e8a7a38d..dbe4dce1f 100644 --- a/surfsense_backend/app/schemas/search_source_connector.py +++ b/surfsense_backend/app/schemas/search_source_connector.py @@ -30,7 +30,12 @@ class SearchSourceConnectorBase(BaseModel): @model_validator(mode="after") def validate_periodic_indexing(self): - """Validate that periodic indexing configuration is consistent.""" + """Validate that periodic indexing configuration is consistent. + + Supported frequencies: Any positive integer (in minutes). + Common values: 5, 15, 60 (1 hour), 360 (6 hours), 720 (12 hours), 1440 (daily), etc. + The schedule checker will handle any frequency >= 1 minute. + """ if self.periodic_indexing_enabled: if not self.is_indexable: raise ValueError( diff --git a/surfsense_backend/app/schemas/slack_auth_credentials.py b/surfsense_backend/app/schemas/slack_auth_credentials.py new file mode 100644 index 000000000..5148a0985 --- /dev/null +++ b/surfsense_backend/app/schemas/slack_auth_credentials.py @@ -0,0 +1,75 @@ +from datetime import UTC, datetime + +from pydantic import BaseModel, field_validator + + +class SlackAuthCredentialsBase(BaseModel): + bot_token: str + refresh_token: str | None = None + token_type: str = "Bearer" + expires_in: int | None = None + expires_at: datetime | None = None + scope: str | None = None + bot_user_id: str | None = None + team_id: str | None = None + team_name: str | None = None + + @property + def is_expired(self) -> bool: + """Check if the credentials have expired.""" + if self.expires_at is None: + return False # Long-lived token, treat as not expired + return self.expires_at <= datetime.now(UTC) + + @property + def is_refreshable(self) -> bool: + """Check if the credentials can be refreshed.""" + return self.refresh_token is not None + + def to_dict(self) -> dict: + """Convert credentials to dictionary for storage.""" + return { + "bot_token": self.bot_token, + "refresh_token": self.refresh_token, + "token_type": self.token_type, + "expires_in": self.expires_in, + "expires_at": self.expires_at.isoformat() if self.expires_at else None, + "scope": self.scope, + "bot_user_id": self.bot_user_id, + "team_id": self.team_id, + "team_name": self.team_name, + } + + @classmethod + def from_dict(cls, data: dict) -> "SlackAuthCredentialsBase": + """Create credentials from dictionary.""" + expires_at = None + if data.get("expires_at"): + expires_at = datetime.fromisoformat(data["expires_at"]) + + return cls( + bot_token=data.get("bot_token", ""), + refresh_token=data.get("refresh_token"), + token_type=data.get("token_type", "Bearer"), + expires_in=data.get("expires_in"), + expires_at=expires_at, + scope=data.get("scope"), + bot_user_id=data.get("bot_user_id"), + team_id=data.get("team_id"), + team_name=data.get("team_name"), + ) + + @field_validator("expires_at", mode="before") + @classmethod + def ensure_aware_utc(cls, v): + # Strings like "2025-08-26T14:46:57.367184" + if isinstance(v, str): + # add +00:00 if missing tz info + if v.endswith("Z"): + return datetime.fromisoformat(v.replace("Z", "+00:00")) + dt = datetime.fromisoformat(v) + return dt if dt.tzinfo else dt.replace(tzinfo=UTC) + # datetime objects + if isinstance(v, datetime): + return v if v.tzinfo else v.replace(tzinfo=UTC) + return v diff --git a/surfsense_backend/app/tasks/chat/stream_new_chat.py b/surfsense_backend/app/tasks/chat/stream_new_chat.py index 69b75e5c4..3b87c33f1 100644 --- a/surfsense_backend/app/tasks/chat/stream_new_chat.py +++ b/surfsense_backend/app/tasks/chat/stream_new_chat.py @@ -270,7 +270,8 @@ async def stream_new_chat( # Track if we just finished a tool (text flows silently after tools) just_finished_tool: bool = False # Track write_todos calls to show "Creating plan" vs "Updating plan" - write_todos_call_count: int = 0 + # Disabled for now + # write_todos_call_count: int = 0 def next_thinking_step_id() -> str: nonlocal thinking_step_counter @@ -479,60 +480,60 @@ async def stream_new_chat( status="in_progress", items=last_active_step_items, ) - elif tool_name == "write_todos": - # Track write_todos calls for better messaging - write_todos_call_count += 1 - todos = ( - tool_input.get("todos", []) - if isinstance(tool_input, dict) - else [] - ) - todo_count = len(todos) if isinstance(todos, list) else 0 + # elif tool_name == "write_todos": # Disabled for now + # # Track write_todos calls for better messaging + # write_todos_call_count += 1 + # todos = ( + # tool_input.get("todos", []) + # if isinstance(tool_input, dict) + # else [] + # ) + # todo_count = len(todos) if isinstance(todos, list) else 0 - if write_todos_call_count == 1: - # First call - creating the plan - last_active_step_title = "Creating plan" - last_active_step_items = [f"Defining {todo_count} tasks..."] - else: - # Subsequent calls - updating the plan - # Try to provide context about what's being updated - in_progress_count = ( - sum( - 1 - for t in todos - if isinstance(t, dict) - and t.get("status") == "in_progress" - ) - if isinstance(todos, list) - else 0 - ) - completed_count = ( - sum( - 1 - for t in todos - if isinstance(t, dict) - and t.get("status") == "completed" - ) - if isinstance(todos, list) - else 0 - ) + # if write_todos_call_count == 1: + # # First call - creating the plan + # last_active_step_title = "Creating plan" + # last_active_step_items = [f"Defining {todo_count} tasks..."] + # else: + # # Subsequent calls - updating the plan + # # Try to provide context about what's being updated + # in_progress_count = ( + # sum( + # 1 + # for t in todos + # if isinstance(t, dict) + # and t.get("status") == "in_progress" + # ) + # if isinstance(todos, list) + # else 0 + # ) + # completed_count = ( + # sum( + # 1 + # for t in todos + # if isinstance(t, dict) + # and t.get("status") == "completed" + # ) + # if isinstance(todos, list) + # else 0 + # ) - last_active_step_title = "Updating progress" - last_active_step_items = ( - [ - f"Progress: {completed_count}/{todo_count} completed", - f"In progress: {in_progress_count} tasks", - ] - if completed_count > 0 - else [f"Working on {todo_count} tasks"] - ) + # last_active_step_title = "Updating progress" + # last_active_step_items = ( + # [ + # f"Progress: {completed_count}/{todo_count} completed", + # f"In progress: {in_progress_count} tasks", + # ] + # if completed_count > 0 + # else [f"Working on {todo_count} tasks"] + # ) - yield streaming_service.format_thinking_step( - step_id=tool_step_id, - title=last_active_step_title, - status="in_progress", - items=last_active_step_items, - ) + # yield streaming_service.format_thinking_step( + # step_id=tool_step_id, + # title=last_active_step_title, + # status="in_progress", + # items=last_active_step_items, + # ) elif tool_name == "generate_podcast": podcast_title = ( tool_input.get("podcast_title", "SurfSense Podcast") @@ -596,10 +597,12 @@ async def stream_new_chat( raw_output = event.get("data", {}).get("output", "") # Handle deepagents' write_todos Command object specially - if tool_name == "write_todos" and hasattr(raw_output, "update"): - # deepagents returns a Command object - extract todos directly - tool_output = extract_todos_from_deepagents(raw_output) - elif hasattr(raw_output, "content"): + # Disabled for now + # if tool_name == "write_todos" and hasattr(raw_output, "update"): + # # deepagents returns a Command object - extract todos directly + # tool_output = extract_todos_from_deepagents(raw_output) + # elif hasattr(raw_output, "content"): + if hasattr(raw_output, "content"): # It's a ToolMessage object - extract the content content = raw_output.content # If content is a string that looks like JSON, try to parse it @@ -758,63 +761,63 @@ async def stream_new_chat( status="completed", items=completed_items, ) - elif tool_name == "write_todos": - # Build completion items for planning/updating - if isinstance(tool_output, dict): - todos = tool_output.get("todos", []) - todo_count = len(todos) if isinstance(todos, list) else 0 - completed_count = ( - sum( - 1 - for t in todos - if isinstance(t, dict) - and t.get("status") == "completed" - ) - if isinstance(todos, list) - else 0 - ) - in_progress_count = ( - sum( - 1 - for t in todos - if isinstance(t, dict) - and t.get("status") == "in_progress" - ) - if isinstance(todos, list) - else 0 - ) + # elif tool_name == "write_todos": # Disabled for now + # # Build completion items for planning/updating + # if isinstance(tool_output, dict): + # todos = tool_output.get("todos", []) + # todo_count = len(todos) if isinstance(todos, list) else 0 + # completed_count = ( + # sum( + # 1 + # for t in todos + # if isinstance(t, dict) + # and t.get("status") == "completed" + # ) + # if isinstance(todos, list) + # else 0 + # ) + # in_progress_count = ( + # sum( + # 1 + # for t in todos + # if isinstance(t, dict) + # and t.get("status") == "in_progress" + # ) + # if isinstance(todos, list) + # else 0 + # ) - # Use context-aware completion message - if last_active_step_title == "Creating plan": - completed_items = [f"Created {todo_count} tasks"] - else: - # Updating progress - show stats - completed_items = [ - f"Progress: {completed_count}/{todo_count} completed", - ] - if in_progress_count > 0: - # Find the currently in-progress task name - in_progress_task = next( - ( - t.get("content", "")[:40] - for t in todos - if isinstance(t, dict) - and t.get("status") == "in_progress" - ), - None, - ) - if in_progress_task: - completed_items.append( - f"Current: {in_progress_task}..." - ) - else: - completed_items = ["Plan updated"] - yield streaming_service.format_thinking_step( - step_id=original_step_id, - title=last_active_step_title, - status="completed", - items=completed_items, - ) + # # Use context-aware completion message + # if last_active_step_title == "Creating plan": + # completed_items = [f"Created {todo_count} tasks"] + # else: + # # Updating progress - show stats + # completed_items = [ + # f"Progress: {completed_count}/{todo_count} completed", + # ] + # if in_progress_count > 0: + # # Find the currently in-progress task name + # in_progress_task = next( + # ( + # t.get("content", "")[:40] + # for t in todos + # if isinstance(t, dict) + # and t.get("status") == "in_progress" + # ), + # None, + # ) + # if in_progress_task: + # completed_items.append( + # f"Current: {in_progress_task}..." + # ) + # else: + # completed_items = ["Plan updated"] + # yield streaming_service.format_thinking_step( + # step_id=original_step_id, + # title=last_active_step_title, + # status="completed", + # items=completed_items, + # ) elif tool_name == "ls": # Build completion items showing file names found if isinstance(tool_output, dict): @@ -992,27 +995,27 @@ async def stream_new_chat( yield streaming_service.format_terminal_info( "Knowledge base search completed", "success" ) - elif tool_name == "write_todos": - # Stream the full write_todos result so frontend can render the Plan component - yield streaming_service.format_tool_output_available( - tool_call_id, - tool_output - if isinstance(tool_output, dict) - else {"result": tool_output}, - ) - # Send terminal message with plan info - if isinstance(tool_output, dict): - todos = tool_output.get("todos", []) - todo_count = len(todos) if isinstance(todos, list) else 0 - yield streaming_service.format_terminal_info( - f"Plan created ({todo_count} tasks)", - "success", - ) - else: - yield streaming_service.format_terminal_info( - "Plan created", - "success", - ) + # elif tool_name == "write_todos": # Disabled for now + # # Stream the full write_todos result so frontend can render the Plan component + # yield streaming_service.format_tool_output_available( + # tool_call_id, + # tool_output + # if isinstance(tool_output, dict) + # else {"result": tool_output}, + # ) + # # Send terminal message with plan info + # if isinstance(tool_output, dict): + # todos = tool_output.get("todos", []) + # todo_count = len(todos) if isinstance(todos, list) else 0 + # yield streaming_service.format_terminal_info( + # f"Plan created ({todo_count} tasks)", + # "success", + # ) + # else: + # yield streaming_service.format_terminal_info( + # "Plan created", + # "success", + # ) else: # Default handling for other tools yield streaming_service.format_tool_output_available( diff --git a/surfsense_backend/app/tasks/connector_indexers/airtable_indexer.py b/surfsense_backend/app/tasks/connector_indexers/airtable_indexer.py index cea2a0529..3ea6dccc9 100644 --- a/surfsense_backend/app/tasks/connector_indexers/airtable_indexer.py +++ b/surfsense_backend/app/tasks/connector_indexers/airtable_indexer.py @@ -18,6 +18,7 @@ from app.utils.document_converters import ( generate_document_summary, generate_unique_identifier_hash, ) +from app.utils.oauth_security import TokenEncryption from .base import ( calculate_date_range, @@ -85,7 +86,52 @@ async def index_airtable_records( return 0, f"Connector with ID {connector_id} not found" # Create credentials from connector config - config_data = connector.config + config_data = ( + connector.config.copy() + ) # Work with a copy to avoid modifying original + + # Decrypt tokens if they are encrypted (only when explicitly marked) + token_encrypted = config_data.get("_token_encrypted", False) + if token_encrypted: + # Tokens are explicitly marked as encrypted, attempt decryption + if not config.SECRET_KEY: + await task_logger.log_task_failure( + log_entry, + f"SECRET_KEY not configured but tokens are marked as encrypted for connector {connector_id}", + "Missing SECRET_KEY for token decryption", + {"error_type": "MissingSecretKey"}, + ) + return 0, "SECRET_KEY not configured but tokens are marked as encrypted" + try: + token_encryption = TokenEncryption(config.SECRET_KEY) + + # Decrypt access_token + if config_data.get("access_token"): + config_data["access_token"] = token_encryption.decrypt_token( + config_data["access_token"] + ) + logger.info( + f"Decrypted Airtable access token for connector {connector_id}" + ) + + # Decrypt refresh_token if present + if config_data.get("refresh_token"): + config_data["refresh_token"] = token_encryption.decrypt_token( + config_data["refresh_token"] + ) + logger.info( + f"Decrypted Airtable refresh token for connector {connector_id}" + ) + except Exception as e: + await task_logger.log_task_failure( + log_entry, + f"Failed to decrypt Airtable tokens for connector {connector_id}: {e!s}", + "Token decryption failed", + {"error_type": "TokenDecryptionError"}, + ) + return 0, f"Failed to decrypt Airtable tokens: {e!s}" + # If _token_encrypted is False or not set, treat tokens as plaintext + try: credentials = AirtableAuthCredentialsBase.from_dict(config_data) except Exception as e: diff --git a/surfsense_backend/app/tasks/connector_indexers/discord_indexer.py b/surfsense_backend/app/tasks/connector_indexers/discord_indexer.py index 9391be788..b3de1f4b5 100644 --- a/surfsense_backend/app/tasks/connector_indexers/discord_indexer.py +++ b/surfsense_backend/app/tasks/connector_indexers/discord_indexer.py @@ -8,6 +8,7 @@ from datetime import UTC, datetime, timedelta from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.ext.asyncio import AsyncSession +from app.config import config from app.connectors.discord_connector import DiscordConnector from app.db import Document, DocumentType, SearchSourceConnectorType from app.services.llm_service import get_user_long_context_llm @@ -69,6 +70,12 @@ async def index_discord_messages( ) try: + # Normalize date parameters - handle 'undefined' strings from frontend + if start_date and (start_date.lower() == "undefined" or start_date.strip() == ""): + start_date = None + if end_date and (end_date.lower() == "undefined" or end_date.strip() == ""): + end_date = None + # Get the connector await task_logger.log_task_progress( log_entry, @@ -92,27 +99,54 @@ async def index_discord_messages( f"Connector with ID {connector_id} not found or is not a Discord connector", ) - # Get the Discord token from the connector config - discord_token = connector.config.get("DISCORD_BOT_TOKEN") - if not discord_token: - await task_logger.log_task_failure( - log_entry, - f"Discord token not found in connector config for connector {connector_id}", - "Missing Discord token", - {"error_type": "MissingToken"}, - ) - return 0, "Discord token not found in connector config" - logger.info(f"Starting Discord indexing for connector {connector_id}") - # Initialize Discord client + # Initialize Discord client with OAuth credentials support await task_logger.log_task_progress( log_entry, f"Initializing Discord client for connector {connector_id}", {"stage": "client_initialization"}, ) - discord_client = DiscordConnector(token=discord_token) + # Check if using OAuth (has bot_token in config) or legacy (has DISCORD_BOT_TOKEN) + has_oauth = connector.config.get("bot_token") is not None + has_legacy = connector.config.get("DISCORD_BOT_TOKEN") is not None + + if has_oauth: + # Use OAuth credentials with auto-refresh + discord_client = DiscordConnector( + session=session, connector_id=connector_id + ) + elif has_legacy: + # Backward compatibility: use legacy token format + discord_token = connector.config.get("DISCORD_BOT_TOKEN") + + # Decrypt token if it's encrypted (legacy tokens might be encrypted) + token_encrypted = connector.config.get("_token_encrypted", False) + if token_encrypted and config.SECRET_KEY and discord_token: + try: + from app.utils.oauth_security import TokenEncryption + token_encryption = TokenEncryption(config.SECRET_KEY) + discord_token = token_encryption.decrypt_token(discord_token) + logger.info( + f"Decrypted legacy Discord token for connector {connector_id}" + ) + except Exception as e: + logger.warning( + f"Failed to decrypt legacy Discord token for connector {connector_id}: {e!s}. " + "Trying to use token as-is (might be unencrypted)." + ) + # Continue with token as-is - might be unencrypted legacy token + + discord_client = DiscordConnector(token=discord_token) + else: + await task_logger.log_task_failure( + log_entry, + f"Discord credentials not found in connector config for connector {connector_id}", + "Missing Discord credentials", + {"error_type": "MissingCredentials"}, + ) + return 0, "Discord credentials not found in connector config" # Calculate date range if start_date is None or end_date is None: @@ -135,32 +169,63 @@ async def index_discord_messages( if start_date is None: start_date_iso = calculated_start_date.isoformat() else: - # Convert YYYY-MM-DD to ISO format + # Validate and convert YYYY-MM-DD to ISO format + try: + start_date_iso = ( + datetime.strptime(start_date, "%Y-%m-%d") + .replace(tzinfo=UTC) + .isoformat() + ) + except ValueError as e: + logger.warning( + f"Invalid start_date format '{start_date}', using calculated start date: {e!s}" + ) + start_date_iso = calculated_start_date.isoformat() + + if end_date is None: + end_date_iso = calculated_end_date.isoformat() + else: + # Validate and convert YYYY-MM-DD to ISO format + try: + end_date_iso = ( + datetime.strptime(end_date, "%Y-%m-%d") + .replace(tzinfo=UTC) + .isoformat() + ) + except ValueError as e: + logger.warning( + f"Invalid end_date format '{end_date}', using calculated end date: {e!s}" + ) + end_date_iso = calculated_end_date.isoformat() + else: + # Convert provided dates to ISO format for Discord API + try: start_date_iso = ( datetime.strptime(start_date, "%Y-%m-%d") .replace(tzinfo=UTC) .isoformat() ) - - if end_date is None: - end_date_iso = calculated_end_date.isoformat() - else: - # Convert YYYY-MM-DD to ISO format - end_date_iso = ( - datetime.strptime(end_date, "%Y-%m-%d") - .replace(tzinfo=UTC) - .isoformat() + except ValueError as e: + await task_logger.log_task_failure( + log_entry, + f"Invalid start_date format: {start_date}", + f"Date parsing error: {e!s}", + {"error_type": "InvalidDateFormat", "start_date": start_date}, ) - else: - # Convert provided dates to ISO format for Discord API - start_date_iso = ( - datetime.strptime(start_date, "%Y-%m-%d") - .replace(tzinfo=UTC) - .isoformat() - ) - end_date_iso = ( - datetime.strptime(end_date, "%Y-%m-%d").replace(tzinfo=UTC).isoformat() - ) + return 0, f"Invalid start_date format: {start_date}. Expected YYYY-MM-DD format." + + try: + end_date_iso = ( + datetime.strptime(end_date, "%Y-%m-%d").replace(tzinfo=UTC).isoformat() + ) + except ValueError as e: + await task_logger.log_task_failure( + log_entry, + f"Invalid end_date format: {end_date}", + f"Date parsing error: {e!s}", + {"error_type": "InvalidDateFormat", "end_date": end_date}, + ) + return 0, f"Invalid end_date format: {end_date}. Expected YYYY-MM-DD format." logger.info( f"Indexing Discord messages from {start_date_iso} to {end_date_iso}" diff --git a/surfsense_backend/app/tasks/connector_indexers/google_calendar_indexer.py b/surfsense_backend/app/tasks/connector_indexers/google_calendar_indexer.py index a5d2bc73a..499f01d66 100644 --- a/surfsense_backend/app/tasks/connector_indexers/google_calendar_indexer.py +++ b/surfsense_backend/app/tasks/connector_indexers/google_calendar_indexer.py @@ -8,7 +8,6 @@ from google.oauth2.credentials import Credentials from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.ext.asyncio import AsyncSession -from app.config import config from app.connectors.google_calendar_connector import GoogleCalendarConnector from app.db import Document, DocumentType, SearchSourceConnectorType from app.services.llm_service import get_user_long_context_llm @@ -84,15 +83,52 @@ async def index_google_calendar_events( return 0, f"Connector with ID {connector_id} not found" # Get the Google Calendar credentials from the connector config - exp = connector.config.get("expiry").replace("Z", "") + config_data = connector.config + + # Decrypt sensitive credentials if encrypted (for backward compatibility) + from app.config import config + from app.utils.oauth_security import TokenEncryption + + token_encrypted = config_data.get("_token_encrypted", False) + if token_encrypted and config.SECRET_KEY: + try: + token_encryption = TokenEncryption(config.SECRET_KEY) + + # Decrypt sensitive fields + if config_data.get("token"): + config_data["token"] = token_encryption.decrypt_token( + config_data["token"] + ) + if config_data.get("refresh_token"): + config_data["refresh_token"] = token_encryption.decrypt_token( + config_data["refresh_token"] + ) + if config_data.get("client_secret"): + config_data["client_secret"] = token_encryption.decrypt_token( + config_data["client_secret"] + ) + + logger.info( + f"Decrypted Google Calendar credentials for connector {connector_id}" + ) + except Exception as e: + await task_logger.log_task_failure( + log_entry, + f"Failed to decrypt Google Calendar credentials for connector {connector_id}: {e!s}", + "Credential decryption failed", + {"error_type": "CredentialDecryptionError"}, + ) + return 0, f"Failed to decrypt Google Calendar credentials: {e!s}" + + exp = config_data.get("expiry", "").replace("Z", "") credentials = Credentials( - token=connector.config.get("token"), - refresh_token=connector.config.get("refresh_token"), - token_uri=connector.config.get("token_uri"), - client_id=connector.config.get("client_id"), - client_secret=connector.config.get("client_secret"), - scopes=connector.config.get("scopes"), - expiry=datetime.fromisoformat(exp), + token=config_data.get("token"), + refresh_token=config_data.get("refresh_token"), + token_uri=config_data.get("token_uri"), + client_id=config_data.get("client_id"), + client_secret=config_data.get("client_secret"), + scopes=config_data.get("scopes"), + expiry=datetime.fromisoformat(exp) if exp else None, ) if ( @@ -122,6 +158,12 @@ async def index_google_calendar_events( connector_id=connector_id, ) + # Handle 'undefined' string from frontend (treat as None) + if start_date == "undefined" or start_date == "": + start_date = None + if end_date == "undefined" or end_date == "": + end_date = None + # Calculate date range if start_date is None or end_date is None: # Fall back to calculating dates based on last_indexed_at diff --git a/surfsense_backend/app/tasks/connector_indexers/google_drive_indexer.py b/surfsense_backend/app/tasks/connector_indexers/google_drive_indexer.py index 343d44072..9eeb46fc8 100644 --- a/surfsense_backend/app/tasks/connector_indexers/google_drive_indexer.py +++ b/surfsense_backend/app/tasks/connector_indexers/google_drive_indexer.py @@ -5,6 +5,7 @@ import logging from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.ext.asyncio import AsyncSession +from app.config import config from app.connectors.google_drive import ( GoogleDriveClient, categorize_change, @@ -87,6 +88,26 @@ async def index_google_drive_files( {"stage": "client_initialization"}, ) + # Check if credentials are encrypted (only when explicitly marked) + token_encrypted = connector.config.get("_token_encrypted", False) + if token_encrypted: + # Credentials are explicitly marked as encrypted, will be decrypted during client initialization + if not config.SECRET_KEY: + await task_logger.log_task_failure( + log_entry, + f"SECRET_KEY not configured but credentials are marked as encrypted for connector {connector_id}", + "Missing SECRET_KEY for token decryption", + {"error_type": "MissingSecretKey"}, + ) + return ( + 0, + "SECRET_KEY not configured but credentials are marked as encrypted", + ) + logger.info( + f"Google Drive credentials are encrypted for connector {connector_id}, will decrypt during client initialization" + ) + # If _token_encrypted is False or not set, treat credentials as plaintext + drive_client = GoogleDriveClient(session, connector_id) if not folder_id: @@ -249,6 +270,26 @@ async def index_google_drive_single_file( {"stage": "client_initialization"}, ) + # Check if credentials are encrypted (only when explicitly marked) + token_encrypted = connector.config.get("_token_encrypted", False) + if token_encrypted: + # Credentials are explicitly marked as encrypted, will be decrypted during client initialization + if not config.SECRET_KEY: + await task_logger.log_task_failure( + log_entry, + f"SECRET_KEY not configured but credentials are marked as encrypted for connector {connector_id}", + "Missing SECRET_KEY for token decryption", + {"error_type": "MissingSecretKey"}, + ) + return ( + 0, + "SECRET_KEY not configured but credentials are marked as encrypted", + ) + logger.info( + f"Google Drive credentials are encrypted for connector {connector_id}, will decrypt during client initialization" + ) + # If _token_encrypted is False or not set, treat credentials as plaintext + drive_client = GoogleDriveClient(session, connector_id) # Fetch the file metadata diff --git a/surfsense_backend/app/tasks/connector_indexers/google_gmail_indexer.py b/surfsense_backend/app/tasks/connector_indexers/google_gmail_indexer.py index d350411e1..e10297057 100644 --- a/surfsense_backend/app/tasks/connector_indexers/google_gmail_indexer.py +++ b/surfsense_backend/app/tasks/connector_indexers/google_gmail_indexer.py @@ -8,7 +8,6 @@ from google.oauth2.credentials import Credentials from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.ext.asyncio import AsyncSession -from app.config import config from app.connectors.google_gmail_connector import GoogleGmailConnector from app.db import ( Document, @@ -88,9 +87,47 @@ async def index_google_gmail_messages( ) return 0, error_msg - # Create credentials from connector config + # Get the Google Gmail credentials from the connector config config_data = connector.config - exp = config_data.get("expiry").replace("Z", "") + + # Decrypt sensitive credentials if encrypted (for backward compatibility) + from app.config import config + from app.utils.oauth_security import TokenEncryption + + token_encrypted = config_data.get("_token_encrypted", False) + if token_encrypted and config.SECRET_KEY: + try: + token_encryption = TokenEncryption(config.SECRET_KEY) + + # Decrypt sensitive fields + if config_data.get("token"): + config_data["token"] = token_encryption.decrypt_token( + config_data["token"] + ) + if config_data.get("refresh_token"): + config_data["refresh_token"] = token_encryption.decrypt_token( + config_data["refresh_token"] + ) + if config_data.get("client_secret"): + config_data["client_secret"] = token_encryption.decrypt_token( + config_data["client_secret"] + ) + + logger.info( + f"Decrypted Google Gmail credentials for connector {connector_id}" + ) + except Exception as e: + await task_logger.log_task_failure( + log_entry, + f"Failed to decrypt Google Gmail credentials for connector {connector_id}: {e!s}", + "Credential decryption failed", + {"error_type": "CredentialDecryptionError"}, + ) + return 0, f"Failed to decrypt Google Gmail credentials: {e!s}" + + exp = config_data.get("expiry", "") + if exp: + exp = exp.replace("Z", "") credentials = Credentials( token=config_data.get("token"), refresh_token=config_data.get("refresh_token"), @@ -98,7 +135,7 @@ async def index_google_gmail_messages( client_id=config_data.get("client_id"), client_secret=config_data.get("client_secret"), scopes=config_data.get("scopes", []), - expiry=datetime.fromisoformat(exp), + expiry=datetime.fromisoformat(exp) if exp else None, ) if ( diff --git a/surfsense_backend/app/tasks/connector_indexers/linear_indexer.py b/surfsense_backend/app/tasks/connector_indexers/linear_indexer.py index afc9ffd3b..f1bfd42e8 100644 --- a/surfsense_backend/app/tasks/connector_indexers/linear_indexer.py +++ b/surfsense_backend/app/tasks/connector_indexers/linear_indexer.py @@ -92,25 +92,34 @@ async def index_linear_issues( f"Connector with ID {connector_id} not found or is not a Linear connector", ) - # Get the Linear token from the connector config - linear_token = connector.config.get("LINEAR_API_KEY") - if not linear_token: + # Check if access_token exists (support both new OAuth format and old API key format) + if not connector.config.get("access_token") and not connector.config.get( + "LINEAR_API_KEY" + ): await task_logger.log_task_failure( log_entry, - f"Linear API token not found in connector config for connector {connector_id}", - "Missing Linear token", + f"Linear access token not found in connector config for connector {connector_id}", + "Missing Linear access token", {"error_type": "MissingToken"}, ) - return 0, "Linear API token not found in connector config" + return 0, "Linear access token not found in connector config" - # Initialize Linear client + # Initialize Linear client with internal refresh capability await task_logger.log_task_progress( log_entry, f"Initializing Linear client for connector {connector_id}", {"stage": "client_initialization"}, ) - linear_client = LinearConnector(token=linear_token) + # Create connector with session and connector_id for internal refresh + # Token refresh will happen automatically when needed + linear_client = LinearConnector(session=session, connector_id=connector_id) + + # Handle 'undefined' string from frontend (treat as None) + if start_date == "undefined" or start_date == "": + start_date = None + if end_date == "undefined" or end_date == "": + end_date = None # Calculate date range start_date_str, end_date_str = calculate_date_range( @@ -131,7 +140,7 @@ async def index_linear_issues( # Get issues within date range try: - issues, error = linear_client.get_issues_by_date_range( + issues, error = await linear_client.get_issues_by_date_range( start_date=start_date_str, end_date=end_date_str, include_comments=True ) diff --git a/surfsense_backend/app/tasks/connector_indexers/notion_indexer.py b/surfsense_backend/app/tasks/connector_indexers/notion_indexer.py index 332d3e39d..13923269d 100644 --- a/surfsense_backend/app/tasks/connector_indexers/notion_indexer.py +++ b/surfsense_backend/app/tasks/connector_indexers/notion_indexer.py @@ -2,7 +2,7 @@ Notion connector indexer. """ -from datetime import datetime, timedelta +from datetime import datetime from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.ext.asyncio import AsyncSession @@ -20,6 +20,7 @@ from app.utils.document_converters import ( from .base import ( build_document_metadata_string, + calculate_date_range, check_document_by_unique_identifier, get_connector_by_id, get_current_timestamp, @@ -91,18 +92,19 @@ async def index_notion_pages( f"Connector with ID {connector_id} not found or is not a Notion connector", ) - # Get the Notion token from the connector config - notion_token = connector.config.get("NOTION_INTEGRATION_TOKEN") - if not notion_token: + # Check if access_token exists (support both new OAuth format and old integration token format) + if not connector.config.get("access_token") and not connector.config.get( + "NOTION_INTEGRATION_TOKEN" + ): await task_logger.log_task_failure( log_entry, - f"Notion integration token not found in connector config for connector {connector_id}", - "Missing Notion token", + f"Notion access token not found in connector config for connector {connector_id}", + "Missing Notion access token", {"error_type": "MissingToken"}, ) - return 0, "Notion integration token not found in connector config" + return 0, "Notion access token not found in connector config" - # Initialize Notion client + # Initialize Notion client with internal refresh capability await task_logger.log_task_progress( log_entry, f"Initializing Notion client for connector {connector_id}", @@ -111,40 +113,30 @@ async def index_notion_pages( logger.info(f"Initializing Notion client for connector {connector_id}") - # Calculate date range - if start_date is None or end_date is None: - # Fall back to calculating dates - calculated_end_date = datetime.now() - calculated_start_date = calculated_end_date - timedelta( - days=365 - ) # Check for last 1 year of pages + # Handle 'undefined' string from frontend (treat as None) + if start_date == "undefined" or start_date == "": + start_date = None + if end_date == "undefined" or end_date == "": + end_date = None - # Use calculated dates if not provided - if start_date is None: - start_date_iso = calculated_start_date.strftime("%Y-%m-%dT%H:%M:%SZ") - else: - # Convert YYYY-MM-DD to ISO format - start_date_iso = datetime.strptime(start_date, "%Y-%m-%d").strftime( - "%Y-%m-%dT%H:%M:%SZ" - ) + # Calculate date range using the shared utility function + start_date_str, end_date_str = calculate_date_range( + connector, start_date, end_date, default_days_back=365 + ) - if end_date is None: - end_date_iso = calculated_end_date.strftime("%Y-%m-%dT%H:%M:%SZ") - else: - # Convert YYYY-MM-DD to ISO format - end_date_iso = datetime.strptime(end_date, "%Y-%m-%d").strftime( - "%Y-%m-%dT%H:%M:%SZ" - ) - else: - # Convert provided dates to ISO format for Notion API - start_date_iso = datetime.strptime(start_date, "%Y-%m-%d").strftime( - "%Y-%m-%dT%H:%M:%SZ" - ) - end_date_iso = datetime.strptime(end_date, "%Y-%m-%d").strftime( - "%Y-%m-%dT%H:%M:%SZ" - ) + # Convert YYYY-MM-DD to ISO format for Notion API + start_date_iso = datetime.strptime(start_date_str, "%Y-%m-%d").strftime( + "%Y-%m-%dT%H:%M:%SZ" + ) + end_date_iso = datetime.strptime(end_date_str, "%Y-%m-%d").strftime( + "%Y-%m-%dT%H:%M:%SZ" + ) - notion_client = NotionHistoryConnector(token=notion_token) + # Create connector with session and connector_id for internal refresh + # Token refresh will happen automatically when needed + notion_client = NotionHistoryConnector( + session=session, connector_id=connector_id + ) logger.info(f"Fetching Notion pages from {start_date_iso} to {end_date_iso}") diff --git a/surfsense_backend/app/tasks/connector_indexers/slack_indexer.py b/surfsense_backend/app/tasks/connector_indexers/slack_indexer.py index 5119aba2e..dad64ad27 100644 --- a/surfsense_backend/app/tasks/connector_indexers/slack_indexer.py +++ b/surfsense_backend/app/tasks/connector_indexers/slack_indexer.py @@ -92,25 +92,24 @@ async def index_slack_messages( f"Connector with ID {connector_id} not found or is not a Slack connector", ) - # Get the Slack token from the connector config - slack_token = connector.config.get("SLACK_BOT_TOKEN") - if not slack_token: - await task_logger.log_task_failure( - log_entry, - f"Slack token not found in connector config for connector {connector_id}", - "Missing Slack token", - {"error_type": "MissingToken"}, - ) - return 0, "Slack token not found in connector config" + # Note: Token handling is now done automatically by SlackHistory + # with auto-refresh support. We just need to pass session and connector_id. - # Initialize Slack client + # Initialize Slack client with auto-refresh support await task_logger.log_task_progress( log_entry, f"Initializing Slack client for connector {connector_id}", {"stage": "client_initialization"}, ) - slack_client = SlackHistory(token=slack_token) + # Use the new pattern with session and connector_id for auto-refresh + slack_client = SlackHistory(session=session, connector_id=connector_id) + + # Handle 'undefined' string from frontend (treat as None) + if start_date == "undefined" or start_date == "": + start_date = None + if end_date == "undefined" or end_date == "": + end_date = None # Calculate date range await task_logger.log_task_progress( @@ -141,7 +140,7 @@ async def index_slack_messages( # Get all channels try: - channels = slack_client.get_all_channels() + channels = await slack_client.get_all_channels() except Exception as e: await task_logger.log_task_failure( log_entry, @@ -190,7 +189,7 @@ async def index_slack_messages( continue # Get messages for this channel - messages, error = slack_client.get_history_by_date_range( + messages, error = await slack_client.get_history_by_date_range( channel_id=channel_id, start_date=start_date_str, end_date=end_date_str, @@ -223,7 +222,7 @@ async def index_slack_messages( ]: continue - formatted_msg = slack_client.format_message( + formatted_msg = await slack_client.format_message( msg, include_user_info=True ) formatted_messages.append(formatted_msg) diff --git a/surfsense_backend/app/utils/oauth_security.py b/surfsense_backend/app/utils/oauth_security.py new file mode 100644 index 000000000..5135cdef4 --- /dev/null +++ b/surfsense_backend/app/utils/oauth_security.py @@ -0,0 +1,210 @@ +""" +OAuth Security Utilities. + +Provides secure state parameter generation/validation and token encryption +for OAuth 2.0 flows. +""" + +import base64 +import hashlib +import hmac +import json +import logging +import time +from uuid import UUID + +from cryptography.fernet import Fernet +from fastapi import HTTPException + +logger = logging.getLogger(__name__) + + +class OAuthStateManager: + """Manages secure OAuth state parameters with HMAC signatures.""" + + def __init__(self, secret_key: str, max_age_seconds: int = 600): + """ + Initialize OAuth state manager. + + Args: + secret_key: Secret key for HMAC signing (should be SECRET_KEY from config) + max_age_seconds: Maximum age of state parameter in seconds (default 10 minutes) + """ + if not secret_key: + raise ValueError("secret_key is required for OAuth state management") + self.secret_key = secret_key + self.max_age_seconds = max_age_seconds + + def generate_secure_state( + self, space_id: int, user_id: UUID, **extra_fields + ) -> str: + """ + Generate cryptographically signed state parameter. + + Args: + space_id: The search space ID + user_id: The user ID + **extra_fields: Additional fields to include in state (e.g., code_verifier for PKCE) + + Returns: + Base64-encoded state parameter with HMAC signature + """ + timestamp = int(time.time()) + state_payload = { + "space_id": space_id, + "user_id": str(user_id), + "timestamp": timestamp, + } + + # Add any extra fields (e.g., code_verifier for PKCE) + state_payload.update(extra_fields) + + # Create signature + payload_str = json.dumps(state_payload, sort_keys=True) + signature = hmac.new( + self.secret_key.encode(), + payload_str.encode(), + hashlib.sha256, + ).hexdigest() + + # Include signature in state + state_payload["signature"] = signature + state_encoded = base64.urlsafe_b64encode( + json.dumps(state_payload).encode() + ).decode() + + return state_encoded + + def validate_state(self, state: str) -> dict: + """ + Validate and decode state parameter with signature verification. + + Args: + state: The state parameter from OAuth callback + + Returns: + Decoded state data (space_id, user_id, timestamp) + + Raises: + HTTPException: If state is invalid, expired, or tampered with + """ + try: + decoded = base64.urlsafe_b64decode(state.encode()).decode() + data = json.loads(decoded) + except Exception as e: + raise HTTPException( + status_code=400, detail=f"Invalid state format: {e!s}" + ) from e + + # Verify signature exists + signature = data.pop("signature", None) + if not signature: + raise HTTPException(status_code=400, detail="Missing state signature") + + # Verify signature + payload_str = json.dumps(data, sort_keys=True) + expected_signature = hmac.new( + self.secret_key.encode(), + payload_str.encode(), + hashlib.sha256, + ).hexdigest() + + if not hmac.compare_digest(signature, expected_signature): + raise HTTPException( + status_code=400, detail="Invalid state signature - possible tampering" + ) + + # Verify timestamp (prevent replay attacks) + timestamp = data.get("timestamp", 0) + current_time = time.time() + age = current_time - timestamp + + if age < 0: + raise HTTPException(status_code=400, detail="Invalid state timestamp") + + if age > self.max_age_seconds: + raise HTTPException( + status_code=400, + detail="State parameter expired. Please try again.", + ) + + return data + + +class TokenEncryption: + """Encrypt/decrypt sensitive OAuth tokens for storage.""" + + def __init__(self, secret_key: str): + """ + Initialize token encryption. + + Args: + secret_key: Secret key for encryption (should be SECRET_KEY from config) + """ + if not secret_key: + raise ValueError("secret_key is required for token encryption") + # Derive Fernet key from secret using SHA256 + # Note: In production, consider using HKDF for key derivation + key = base64.urlsafe_b64encode(hashlib.sha256(secret_key.encode()).digest()) + try: + self.cipher = Fernet(key) + except Exception as e: + raise ValueError(f"Failed to initialize encryption cipher: {e!s}") from e + + def encrypt_token(self, token: str) -> str: + """ + Encrypt a token for storage. + + Args: + token: Plaintext token to encrypt + + Returns: + Encrypted token string + """ + if not token: + return token + try: + return self.cipher.encrypt(token.encode()).decode() + except Exception as e: + logger.error(f"Failed to encrypt token: {e!s}") + raise ValueError(f"Token encryption failed: {e!s}") from e + + def decrypt_token(self, encrypted_token: str) -> str: + """ + Decrypt a stored token. + + Args: + encrypted_token: Encrypted token string + + Returns: + Decrypted plaintext token + """ + if not encrypted_token: + return encrypted_token + try: + return self.cipher.decrypt(encrypted_token.encode()).decode() + except Exception as e: + logger.error(f"Failed to decrypt token: {e!s}") + raise ValueError(f"Token decryption failed: {e!s}") from e + + def is_encrypted(self, token: str) -> bool: + """ + Check if a token appears to be encrypted. + + Args: + token: Token string to check + + Returns: + True if token appears encrypted, False otherwise + """ + if not token: + return False + # Encrypted tokens are base64-encoded and have specific format + # This is a heuristic check - encrypted tokens are longer and base64-like + try: + # Try to decode as base64 + base64.urlsafe_b64decode(token.encode()) + # If it's base64 and reasonably long, likely encrypted + return len(token) > 20 + except Exception: + return False diff --git a/surfsense_backend/app/utils/validators.py b/surfsense_backend/app/utils/validators.py index 6b69fb3e1..f1620c0e5 100644 --- a/surfsense_backend/app/utils/validators.py +++ b/surfsense_backend/app/utils/validators.py @@ -513,11 +513,22 @@ def validate_connector_config( ], "validators": {}, }, - "SLACK_CONNECTOR": {"required": ["SLACK_BOT_TOKEN"], "validators": {}}, - "NOTION_CONNECTOR": { - "required": ["NOTION_INTEGRATION_TOKEN"], - "validators": {}, - }, + # "SLACK_CONNECTOR": { + # "required": [], # OAuth uses bot_token (encrypted), legacy uses SLACK_BOT_TOKEN + # "optional": [ + # "bot_token", + # "SLACK_BOT_TOKEN", + # "bot_user_id", + # "team_id", + # "team_name", + # "token_type", + # "expires_in", + # "expires_at", + # "scope", + # "_token_encrypted", + # ], + # "validators": {}, + # }, "GITHUB_CONNECTOR": { "required": ["GITHUB_PAT", "repo_full_names"], "validators": { @@ -526,8 +537,7 @@ def validate_connector_config( ) }, }, - "LINEAR_CONNECTOR": {"required": ["LINEAR_API_KEY"], "validators": {}}, - "DISCORD_CONNECTOR": {"required": ["DISCORD_BOT_TOKEN"], "validators": {}}, + # "DISCORD_CONNECTOR": {"required": ["DISCORD_BOT_TOKEN"], "validators": {}}, "JIRA_CONNECTOR": { "required": ["JIRA_EMAIL", "JIRA_API_TOKEN", "JIRA_BASE_URL"], "validators": { diff --git a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx index 35a096497..b1abd647f 100644 --- a/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx +++ b/surfsense_web/app/dashboard/[search_space_id]/new-chat/[[...chat_id]]/page.tsx @@ -20,7 +20,7 @@ import { } from "@/atoms/chat/mentioned-documents.atom"; import { clearPlanOwnerRegistry, - extractWriteTodosFromContent, + // extractWriteTodosFromContent, hydratePlanStateAtom, } from "@/atoms/chat/plan-state.atom"; import { Thread } from "@/components/assistant-ui/thread"; @@ -30,7 +30,7 @@ import { DisplayImageToolUI } from "@/components/tool-ui/display-image"; import { GeneratePodcastToolUI } from "@/components/tool-ui/generate-podcast"; import { LinkPreviewToolUI } from "@/components/tool-ui/link-preview"; import { ScrapeWebpageToolUI } from "@/components/tool-ui/scrape-webpage"; -import { WriteTodosToolUI } from "@/components/tool-ui/write-todos"; +// import { WriteTodosToolUI } from "@/components/tool-ui/write-todos"; import { getBearerToken } from "@/lib/auth-utils"; import { createAttachmentAdapter, extractAttachmentContent } from "@/lib/chat/attachment-adapter"; import { @@ -199,7 +199,7 @@ const TOOLS_WITH_UI = new Set([ "link_preview", "display_image", "scrape_webpage", - "write_todos", + // "write_todos", // Disabled for now ]); /** @@ -291,10 +291,11 @@ export default function NewChatPage() { restoredThinkingSteps.set(`msg-${msg.id}`, steps); } // Hydrate write_todos plan state from persisted tool calls - const writeTodosCalls = extractWriteTodosFromContent(msg.content); - for (const todoData of writeTodosCalls) { - hydratePlanState(todoData); - } + // Disabled for now + // const writeTodosCalls = extractWriteTodosFromContent(msg.content); + // for (const todoData of writeTodosCalls) { + // hydratePlanState(todoData); + // } } if (msg.role === "user") { const docs = extractMentionedDocuments(msg.content); @@ -911,7 +912,7 @@ export default function NewChatPage() { - + {/* Disabled for now */}
{ )} - + {/* YouTube Crawler View - shown when adding YouTube videos */} {isYouTubeView && searchSpaceId ? ( @@ -272,7 +272,7 @@ export const ConnectorIndicator: FC = () => { {/* Content */}
-
+
= ({ id, title, @@ -86,13 +125,13 @@ export const ConnectorCard: FC = ({ // Show last indexed date for connected connectors if (lastIndexedAt) { return ( - - Last indexed: {format(new Date(lastIndexedAt), "MMM d, yyyy")} + + Last indexed: {formatLastIndexedDate(lastIndexedAt)} ); } // Fallback for connected but never indexed - return Never indexed; + return Never indexed; } return description; @@ -113,9 +152,9 @@ export const ConnectorCard: FC = ({
{title}
-
{getStatusContent()}
+
{getStatusContent()}
{isConnected && documentCount !== undefined && ( -

+

{formatDocumentCount(documentCount)}

)} @@ -130,12 +169,10 @@ export const ConnectorCard: FC = ({ !isConnected && "shadow-xs" )} onClick={isConnected ? onManage : onConnect} - disabled={isConnecting || isIndexing} + disabled={isConnecting} > {isConnecting ? ( - ) : isIndexing ? ( - "Syncing..." ) : isConnected ? ( "Manage" ) : id === "youtube-crawler" ? ( diff --git a/surfsense_web/components/assistant-ui/connector-popup/components/connector-dialog-header.tsx b/surfsense_web/components/assistant-ui/connector-popup/components/connector-dialog-header.tsx index a18c79a1f..34e1ae2e9 100644 --- a/surfsense_web/components/assistant-ui/connector-popup/components/connector-dialog-header.tsx +++ b/surfsense_web/components/assistant-ui/connector-popup/components/connector-dialog-header.tsx @@ -24,20 +24,20 @@ export const ConnectorDialogHeader: FC = ({ return (
- + Connectors - + Search across all your apps and data in one place. -
+
= ({
- + = ({ + {/* Back button - only show if not from OAuth */} + {!isFromOAuth && ( + + )} {/* Success header */}
@@ -187,15 +193,7 @@ export const IndexingConfigurationView: FC = ({
{/* Fixed Footer - Action buttons */} -
- +
@@ -163,9 +199,8 @@ export const ActiveConnectorsTab: FC = ({ size="sm" className="h-8 text-[11px] px-3 rounded-lg font-medium bg-white text-slate-700 hover:bg-slate-50 border-0 shadow-xs dark:bg-secondary dark:text-secondary-foreground dark:hover:bg-secondary/80" onClick={onManage ? () => onManage(connector) : undefined} - disabled={isIndexing} > - {isIndexing ? "Syncing..." : "Manage"} + Manage
); diff --git a/surfsense_web/components/assistant-ui/document-upload-popup.tsx b/surfsense_web/components/assistant-ui/document-upload-popup.tsx index d1fa208d2..da3b820e5 100644 --- a/surfsense_web/components/assistant-ui/document-upload-popup.tsx +++ b/surfsense_web/components/assistant-ui/document-upload-popup.tsx @@ -1,5 +1,6 @@ "use client"; +import { Upload } from "lucide-react"; import { useAtomValue } from "jotai"; import { useRouter } from "next/navigation"; import { @@ -85,6 +86,7 @@ const DocumentUploadPopupContent: FC<{ }> = ({ isOpen, onOpenChange }) => { const searchSpaceId = useAtomValue(activeSearchSpaceIdAtom); const router = useRouter(); + const [isAccordionExpanded, setIsAccordionExpanded] = useState(false); if (!searchSpaceId) return null; @@ -95,16 +97,40 @@ const DocumentUploadPopupContent: FC<{ return ( - + Upload Document -
-
-
- + + {/* Fixed Header */} +
+ {/* Upload header */} +
+
+ +
+
+

Upload Documents

+

+ Upload and sync your documents to your search space +

- {/* Bottom fade shadow */} -
+
+ + {/* Scrollable Content */} +
+
+
+ +
+
+ {/* Bottom fade shadow - only show when scrolling */} + {isAccordionExpanded && ( +
+ )}
diff --git a/surfsense_web/components/editConnector/types.ts b/surfsense_web/components/editConnector/types.ts index e17a3513a..43fab23e0 100644 --- a/surfsense_web/components/editConnector/types.ts +++ b/surfsense_web/components/editConnector/types.ts @@ -36,7 +36,6 @@ export const editConnectorSchema = z.object({ SEARXNG_LANGUAGE: z.string().optional(), SEARXNG_SAFESEARCH: z.string().optional(), SEARXNG_VERIFY_SSL: z.string().optional(), - LINEAR_API_KEY: z.string().optional(), LINKUP_API_KEY: z.string().optional(), DISCORD_BOT_TOKEN: z.string().optional(), CONFLUENCE_BASE_URL: z.string().optional(), diff --git a/surfsense_web/components/new-chat/document-mention-picker.tsx b/surfsense_web/components/new-chat/document-mention-picker.tsx index 2d5d46267..7a9e7aaa5 100644 --- a/surfsense_web/components/new-chat/document-mention-picker.tsx +++ b/surfsense_web/components/new-chat/document-mention-picker.tsx @@ -12,7 +12,7 @@ import { useState, } from "react"; import { getConnectorIcon } from "@/contracts/enums/connectorIcons"; -import type { Document } from "@/contracts/types/document.types"; +import type { Document, GetDocumentsResponse } from "@/contracts/types/document.types"; import { documentsApiService } from "@/lib/apis/documents-api.service"; import { cacheKeys } from "@/lib/query-client/cache-keys"; import { cn } from "@/lib/utils"; @@ -31,6 +31,8 @@ interface DocumentMentionPickerProps { externalSearch?: string; } +const PAGE_SIZE = 20; + function useDebounced(value: T, delay = 300) { const [debounced, setDebounced] = useState(value); useEffect(() => { @@ -52,12 +54,29 @@ export const DocumentMentionPicker = forwardRef< const debouncedSearch = useDebounced(search, 150); const [highlightedIndex, setHighlightedIndex] = useState(0); const itemRefs = useRef>(new Map()); + const scrollContainerRef = useRef(null); + // State for pagination + const [accumulatedDocuments, setAccumulatedDocuments] = useState([]); + const [currentPage, setCurrentPage] = useState(0); + const [hasMore, setHasMore] = useState(false); + const [isLoadingMore, setIsLoadingMore] = useState(false); + + // Reset pagination when search or search space changes + // biome-ignore lint/correctness/useExhaustiveDependencies: intentionally reset pagination when search/space changes + useEffect(() => { + setAccumulatedDocuments([]); + setCurrentPage(0); + setHasMore(false); + setHighlightedIndex(0); + }, [debouncedSearch, searchSpaceId]); + + // Query params for initial fetch (page 0) const fetchQueryParams = useMemo( () => ({ search_space_id: searchSpaceId, page: 0, - page_size: 20, + page_size: PAGE_SIZE, }), [searchSpaceId] ); @@ -66,31 +85,97 @@ export const DocumentMentionPicker = forwardRef< return { search_space_id: searchSpaceId, page: 0, - page_size: 20, + page_size: PAGE_SIZE, title: debouncedSearch, }; }, [debouncedSearch, searchSpaceId]); - // Use query for fetching documents + // Use query for fetching first page of documents const { data: documents, isLoading: isDocumentsLoading } = useQuery({ queryKey: cacheKeys.documents.withQueryParams(fetchQueryParams), queryFn: () => documentsApiService.getDocuments({ queryParams: fetchQueryParams }), staleTime: 3 * 60 * 1000, - enabled: !!searchSpaceId && !debouncedSearch.trim(), + enabled: !!searchSpaceId && !debouncedSearch.trim() && currentPage === 0, }); - // Searching + // Searching - first page const { data: searchedDocuments, isLoading: isSearchedDocumentsLoading } = useQuery({ queryKey: cacheKeys.documents.withQueryParams(searchQueryParams), queryFn: () => documentsApiService.searchDocuments({ queryParams: searchQueryParams }), staleTime: 3 * 60 * 1000, - enabled: !!searchSpaceId && !!debouncedSearch.trim(), + enabled: !!searchSpaceId && !!debouncedSearch.trim() && currentPage === 0, }); - const actualDocuments = debouncedSearch.trim() - ? searchedDocuments?.items || [] - : documents?.items || []; - const actualLoading = debouncedSearch.trim() ? isSearchedDocumentsLoading : isDocumentsLoading; + // Update accumulated documents when first page loads + useEffect(() => { + if (currentPage === 0) { + if (debouncedSearch.trim()) { + if (searchedDocuments) { + setAccumulatedDocuments(searchedDocuments.items); + setHasMore(searchedDocuments.has_more); + } + } else { + if (documents) { + setAccumulatedDocuments(documents.items); + setHasMore(documents.has_more); + } + } + } + }, [documents, searchedDocuments, debouncedSearch, currentPage]); + + // Function to load next page + const loadNextPage = useCallback(async () => { + if (isLoadingMore || !hasMore) return; + + const nextPage = currentPage + 1; + setIsLoadingMore(true); + + try { + let response: GetDocumentsResponse; + if (debouncedSearch.trim()) { + const queryParams = { + search_space_id: searchSpaceId, + page: nextPage, + page_size: PAGE_SIZE, + title: debouncedSearch, + }; + response = await documentsApiService.searchDocuments({ queryParams }); + } else { + const queryParams = { + search_space_id: searchSpaceId, + page: nextPage, + page_size: PAGE_SIZE, + }; + response = await documentsApiService.getDocuments({ queryParams }); + } + + setAccumulatedDocuments((prev) => [...prev, ...response.items]); + setHasMore(response.has_more); + setCurrentPage(nextPage); + } catch (error) { + console.error("Failed to load next page:", error); + } finally { + setIsLoadingMore(false); + } + }, [currentPage, hasMore, isLoadingMore, debouncedSearch, searchSpaceId]); + + // Infinite scroll handler + const handleScroll = useCallback( + (e: React.UIEvent) => { + const target = e.currentTarget; + const scrollBottom = target.scrollHeight - target.scrollTop - target.clientHeight; + + // Load more when within 50px of bottom + if (scrollBottom < 50 && hasMore && !isLoadingMore) { + loadNextPage(); + } + }, + [hasMore, isLoadingMore, loadNextPage] + ); + + const actualDocuments = accumulatedDocuments; + const actualLoading = + (debouncedSearch.trim() ? isSearchedDocumentsLoading : isDocumentsLoading) && currentPage === 0; // Track already selected document IDs const selectedIds = useMemo( @@ -184,8 +269,12 @@ export const DocumentMentionPicker = forwardRef< role="listbox" tabIndex={-1} > - {/* Document List - Shows max 3 items on mobile, 5 items on desktop */} -
+ {/* Document List - Shows max 5 items on mobile, 7-8 items on desktop */} +
{actualLoading ? (
@@ -235,6 +324,12 @@ export const DocumentMentionPicker = forwardRef< ); })} + {/* Loading indicator for additional pages */} + {isLoadingMore && ( +
+
+
+ )}
)}
diff --git a/surfsense_web/components/sources/DocumentUploadTab.tsx b/surfsense_web/components/sources/DocumentUploadTab.tsx index 5280ea850..0b7f7b51f 100644 --- a/surfsense_web/components/sources/DocumentUploadTab.tsx +++ b/surfsense_web/components/sources/DocumentUploadTab.tsx @@ -31,6 +31,7 @@ import { GridPattern } from "./GridPattern"; interface DocumentUploadTabProps { searchSpaceId: string; onSuccess?: () => void; + onAccordionStateChange?: (isExpanded: boolean) => void; } const audioFileTypes = { @@ -109,11 +110,16 @@ const FILE_TYPE_CONFIG: Record> = { const cardClass = "border border-border bg-slate-400/5 dark:bg-white/5"; -export function DocumentUploadTab({ searchSpaceId, onSuccess }: DocumentUploadTabProps) { +export function DocumentUploadTab({ + searchSpaceId, + onSuccess, + onAccordionStateChange, +}: DocumentUploadTabProps) { const t = useTranslations("upload_documents"); const router = useRouter(); const [files, setFiles] = useState([]); const [uploadProgress, setUploadProgress] = useState(0); + const [accordionValue, setAccordionValue] = useState(""); const [uploadDocumentMutation] = useAtom(uploadDocumentMutationAtom); const { mutate: uploadDocuments, isPending: isUploading } = uploadDocumentMutation; const fileInputRef = useRef(null); @@ -154,6 +160,15 @@ export function DocumentUploadTab({ searchSpaceId, onSuccess }: DocumentUploadTa const totalFileSize = files.reduce((total, file) => total + file.size, 0); + // Track accordion state changes + const handleAccordionChange = useCallback( + (value: string) => { + setAccordionValue(value); + onAccordionStateChange?.(value === "supported-file-types"); + }, + [onAccordionStateChange] + ); + const handleUpload = async () => { setUploadProgress(0); trackDocumentUploadStarted(Number(searchSpaceId), files.length, totalFileSize); @@ -190,11 +205,13 @@ export function DocumentUploadTab({ searchSpaceId, onSuccess }: DocumentUploadTa initial={{ opacity: 0, y: 20 }} animate={{ opacity: 1, y: 0 }} transition={{ duration: 0.3 }} - className="space-y-3 sm:space-y-6 max-w-4xl mx-auto" + className="space-y-3 sm:space-y-6 max-w-4xl mx-auto pt-0" > - - - {t("file_size_limit")} + + + + {t("file_size_limit")} + @@ -366,11 +383,13 @@ export function DocumentUploadTab({ searchSpaceId, onSuccess }: DocumentUploadTa - -
+ +
diff --git a/surfsense_web/hooks/use-connector-edit-page.ts b/surfsense_web/hooks/use-connector-edit-page.ts index 8fa690d04..3beb80247 100644 --- a/surfsense_web/hooks/use-connector-edit-page.ts +++ b/surfsense_web/hooks/use-connector-edit-page.ts @@ -86,7 +86,6 @@ export function useConnectorEditPage(connectorId: number, searchSpaceId: string) SEARXNG_LANGUAGE: "", SEARXNG_SAFESEARCH: "", SEARXNG_VERIFY_SSL: "", - LINEAR_API_KEY: "", DISCORD_BOT_TOKEN: "", CONFLUENCE_BASE_URL: "", CONFLUENCE_EMAIL: "", @@ -134,7 +133,6 @@ export function useConnectorEditPage(connectorId: number, searchSpaceId: string) config.SEARXNG_VERIFY_SSL !== undefined && config.SEARXNG_VERIFY_SSL !== null ? String(config.SEARXNG_VERIFY_SSL) : "", - LINEAR_API_KEY: config.LINEAR_API_KEY || "", LINKUP_API_KEY: config.LINKUP_API_KEY || "", DISCORD_BOT_TOKEN: config.DISCORD_BOT_TOKEN || "", CONFLUENCE_BASE_URL: config.CONFLUENCE_BASE_URL || "", @@ -384,16 +382,6 @@ export function useConnectorEditPage(connectorId: number, searchSpaceId: string) break; } - case "LINEAR_CONNECTOR": - if (formData.LINEAR_API_KEY !== originalConfig.LINEAR_API_KEY) { - if (!formData.LINEAR_API_KEY) { - toast.error("Linear API Key cannot be empty."); - setIsSaving(false); - return; - } - newConfig = { LINEAR_API_KEY: formData.LINEAR_API_KEY }; - } - break; case "LINKUP_API": if (formData.LINKUP_API_KEY !== originalConfig.LINKUP_API_KEY) { if (!formData.LINKUP_API_KEY) { @@ -599,8 +587,6 @@ export function useConnectorEditPage(connectorId: number, searchSpaceId: string) "SEARXNG_VERIFY_SSL", verifyValue === null ? "" : String(verifyValue) ); - } else if (connector.connector_type === "LINEAR_CONNECTOR") { - editForm.setValue("LINEAR_API_KEY", newlySavedConfig.LINEAR_API_KEY || ""); } else if (connector.connector_type === "LINKUP_API") { editForm.setValue("LINKUP_API_KEY", newlySavedConfig.LINKUP_API_KEY || ""); } else if (connector.connector_type === "DISCORD_CONNECTOR") {