diff --git a/surfsense_backend/app/connectors/linear_connector.py b/surfsense_backend/app/connectors/linear_connector.py index 86eeb2e4f..148aa4d0a 100644 --- a/surfsense_backend/app/connectors/linear_connector.py +++ b/surfsense_backend/app/connectors/linear_connector.py @@ -5,33 +5,153 @@ A module for retrieving issues and comments from Linear. Allows fetching issue lists and their comments with date range filtering. """ +import logging from datetime import datetime from typing import Any import requests +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.future import select + +from app.config import config +from app.db import SearchSourceConnector +from app.routes.linear_add_connector_route import refresh_linear_token +from app.schemas.linear_auth_credentials import LinearAuthCredentialsBase +from app.utils.oauth_security import TokenEncryption + +logger = logging.getLogger(__name__) class LinearConnector: """Class for retrieving issues and comments from Linear.""" - def __init__(self, access_token: str | None = None): + def __init__( + self, + session: AsyncSession, + connector_id: int, + credentials: LinearAuthCredentialsBase | None = None, + ): """ - Initialize the LinearConnector class. + Initialize the LinearConnector class with auto-refresh capability. Args: - access_token: Linear OAuth access token or API key (optional, can be set later with set_token) + session: Database session for updating connector + connector_id: Connector ID for direct updates + credentials: Linear OAuth credentials (optional, will be loaded from DB if not provided) """ - self.access_token = access_token + self._session = session + self._connector_id = connector_id + self._credentials = credentials self.api_url = "https://api.linear.app/graphql" - def set_token(self, access_token: str) -> None: + async def _get_valid_token(self) -> str: """ - Set the Linear OAuth access token or API key. + Get valid Linear access token, refreshing if needed. - Args: - access_token: Linear OAuth access token or API key + Returns: + Valid access token + + Raises: + ValueError: If credentials are missing or invalid + Exception: If token refresh fails """ - self.access_token = access_token + # Load credentials from DB if not provided + if self._credentials is None: + result = await self._session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == self._connector_id + ) + ) + connector = result.scalars().first() + + if not connector: + raise ValueError(f"Connector {self._connector_id} not found") + + config_data = connector.config.copy() + + # Decrypt credentials if they are encrypted + token_encrypted = config_data.get("_token_encrypted", False) + if token_encrypted and config.SECRET_KEY: + try: + token_encryption = TokenEncryption(config.SECRET_KEY) + + # Decrypt sensitive fields + if config_data.get("access_token"): + config_data["access_token"] = token_encryption.decrypt_token( + config_data["access_token"] + ) + if config_data.get("refresh_token"): + config_data["refresh_token"] = token_encryption.decrypt_token( + config_data["refresh_token"] + ) + + logger.info( + f"Decrypted Linear credentials for connector {self._connector_id}" + ) + except Exception as e: + logger.error( + f"Failed to decrypt Linear credentials for connector {self._connector_id}: {e!s}" + ) + raise ValueError( + f"Failed to decrypt Linear credentials: {e!s}" + ) from e + + try: + self._credentials = LinearAuthCredentialsBase.from_dict(config_data) + except Exception as e: + raise ValueError(f"Invalid Linear credentials: {e!s}") from e + + # Check if token is expired and refreshable + if self._credentials.is_expired and self._credentials.is_refreshable: + try: + logger.info( + f"Linear token expired for connector {self._connector_id}, refreshing..." + ) + + # Get connector for refresh + result = await self._session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == self._connector_id + ) + ) + connector = result.scalars().first() + + if not connector: + raise RuntimeError( + f"Connector {self._connector_id} not found; cannot refresh token." + ) + + # Refresh token + connector = await refresh_linear_token(self._session, connector) + + # Reload credentials after refresh + config_data = connector.config.copy() + token_encrypted = config_data.get("_token_encrypted", False) + if token_encrypted and config.SECRET_KEY: + token_encryption = TokenEncryption(config.SECRET_KEY) + if config_data.get("access_token"): + config_data["access_token"] = token_encryption.decrypt_token( + config_data["access_token"] + ) + if config_data.get("refresh_token"): + config_data["refresh_token"] = token_encryption.decrypt_token( + config_data["refresh_token"] + ) + + self._credentials = LinearAuthCredentialsBase.from_dict(config_data) + + logger.info( + f"Successfully refreshed Linear token for connector {self._connector_id}" + ) + except Exception as e: + logger.error( + f"Failed to refresh Linear token for connector {self._connector_id}: {e!s}" + ) + raise Exception( + f"Failed to refresh Linear OAuth credentials: {e!s}" + ) from e + + return self._credentials.access_token def get_headers(self) -> dict[str, str]: """ @@ -43,21 +163,24 @@ class LinearConnector: Raises: ValueError: If no Linear access token has been set """ - if not self.access_token: + # This is a synchronous method, but we need async token refresh + # For now, we'll raise an error if called directly + # All API calls should go through execute_graphql_query which handles async refresh + if not self._credentials or not self._credentials.access_token: raise ValueError( - "Linear access token not initialized. Call set_token() first." + "Linear access token not initialized. Use execute_graphql_query() method." ) return { "Content-Type": "application/json", - "Authorization": f"Bearer {self.access_token}", + "Authorization": f"Bearer {self._credentials.access_token}", } - def execute_graphql_query( + async def execute_graphql_query( self, query: str, variables: dict[str, Any] | None = None ) -> dict[str, Any]: """ - Execute a GraphQL query against the Linear API. + Execute a GraphQL query against the Linear API with automatic token refresh. Args: query: GraphQL query string @@ -70,12 +193,14 @@ class LinearConnector: ValueError: If no Linear access token has been set Exception: If the API request fails """ - if not self.access_token: - raise ValueError( - "Linear access token not initialized. Call set_token() first." - ) + # Get valid token (refreshes if needed) + access_token = await self._get_valid_token() + + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {access_token}", + } - headers = self.get_headers() payload = {"query": query} if variables: @@ -90,7 +215,9 @@ class LinearConnector: f"Query failed with status code {response.status_code}: {response.text}" ) - def get_all_issues(self, include_comments: bool = True) -> list[dict[str, Any]]: + async def get_all_issues( + self, include_comments: bool = True + ) -> list[dict[str, Any]]: """ Fetch all issues from Linear. @@ -153,7 +280,7 @@ class LinearConnector: }} """ - result = self.execute_graphql_query(query) + result = await self.execute_graphql_query(query) # Extract issues from the response if ( @@ -165,7 +292,7 @@ class LinearConnector: return [] - def get_issues_by_date_range( + async def get_issues_by_date_range( self, start_date: str, end_date: str, include_comments: bool = True ) -> tuple[list[dict[str, Any]], str | None]: """ @@ -277,7 +404,7 @@ class LinearConnector: # Handle pagination to get all issues while has_next_page: variables = {"after": cursor} if cursor else {} - result = self.execute_graphql_query(query, variables) + result = await self.execute_graphql_query(query, variables) # Check for errors if "errors" in result: @@ -465,37 +592,3 @@ class LinearConnector: return dt.strftime("%Y-%m-%d %H:%M:%S") except ValueError: return iso_date - - -# Example usage (uncomment to use): -""" -if __name__ == "__main__": - # Set your OAuth access token here - access_token = "YOUR_LINEAR_ACCESS_TOKEN" - - linear = LinearConnector(access_token=access_token) - - try: - # Get all issues with comments - issues = linear.get_all_issues() - print(f"Retrieved {len(issues)} issues") - - # Format and print the first issue as markdown - if issues: - issue_md = linear.format_issue_to_markdown(issues[0]) - print("\nSample Issue in Markdown:\n") - print(issue_md) - - # Get issues by date range - start_date = "2023-01-01" - end_date = "2023-01-31" - date_issues, error = linear.get_issues_by_date_range(start_date, end_date) - - if error: - print(f"Error: {error}") - else: - print(f"\nRetrieved {len(date_issues)} issues from {start_date} to {end_date}") - - except Exception as e: - print(f"Error: {e}") -""" diff --git a/surfsense_backend/app/connectors/notion_history.py b/surfsense_backend/app/connectors/notion_history.py index bb518f88c..e38218a6e 100644 --- a/surfsense_backend/app/connectors/notion_history.py +++ b/surfsense_backend/app/connectors/notion_history.py @@ -1,19 +1,167 @@ +import logging + from notion_client import AsyncClient +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.future import select + +from app.config import config +from app.db import SearchSourceConnector +from app.routes.notion_add_connector_route import refresh_notion_token +from app.schemas.notion_auth_credentials import NotionAuthCredentialsBase +from app.utils.oauth_security import TokenEncryption + +logger = logging.getLogger(__name__) class NotionHistoryConnector: - def __init__(self, token): + def __init__( + self, + session: AsyncSession, + connector_id: int, + credentials: NotionAuthCredentialsBase | None = None, + ): """ - Initialize the NotionPageFetcher with a token. + Initialize the NotionHistoryConnector with auto-refresh capability. Args: - token (str): Notion OAuth access token or integration token + session: Database session for updating connector + connector_id: Connector ID for direct updates + credentials: Notion OAuth credentials (optional, will be loaded from DB if not provided) """ - self.notion = AsyncClient(auth=token) + self._session = session + self._connector_id = connector_id + self._credentials = credentials + self._notion_client: AsyncClient | None = None + + async def _get_valid_token(self) -> str: + """ + Get valid Notion access token, refreshing if needed. + + Returns: + Valid access token + + Raises: + ValueError: If credentials are missing or invalid + Exception: If token refresh fails + """ + # Load credentials from DB if not provided + if self._credentials is None: + result = await self._session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == self._connector_id + ) + ) + connector = result.scalars().first() + + if not connector: + raise ValueError(f"Connector {self._connector_id} not found") + + config_data = connector.config.copy() + + # Decrypt credentials if they are encrypted + token_encrypted = config_data.get("_token_encrypted", False) + if token_encrypted and config.SECRET_KEY: + try: + token_encryption = TokenEncryption(config.SECRET_KEY) + + # Decrypt sensitive fields + if config_data.get("access_token"): + config_data["access_token"] = token_encryption.decrypt_token( + config_data["access_token"] + ) + if config_data.get("refresh_token"): + config_data["refresh_token"] = token_encryption.decrypt_token( + config_data["refresh_token"] + ) + + logger.info( + f"Decrypted Notion credentials for connector {self._connector_id}" + ) + except Exception as e: + logger.error( + f"Failed to decrypt Notion credentials for connector {self._connector_id}: {e!s}" + ) + raise ValueError( + f"Failed to decrypt Notion credentials: {e!s}" + ) from e + + try: + self._credentials = NotionAuthCredentialsBase.from_dict(config_data) + except Exception as e: + raise ValueError(f"Invalid Notion credentials: {e!s}") from e + + # Check if token is expired and refreshable + if self._credentials.is_expired and self._credentials.is_refreshable: + try: + logger.info( + f"Notion token expired for connector {self._connector_id}, refreshing..." + ) + + # Get connector for refresh + result = await self._session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == self._connector_id + ) + ) + connector = result.scalars().first() + + if not connector: + raise RuntimeError( + f"Connector {self._connector_id} not found; cannot refresh token." + ) + + # Refresh token + connector = await refresh_notion_token(self._session, connector) + + # Reload credentials after refresh + config_data = connector.config.copy() + token_encrypted = config_data.get("_token_encrypted", False) + if token_encrypted and config.SECRET_KEY: + token_encryption = TokenEncryption(config.SECRET_KEY) + if config_data.get("access_token"): + config_data["access_token"] = token_encryption.decrypt_token( + config_data["access_token"] + ) + if config_data.get("refresh_token"): + config_data["refresh_token"] = token_encryption.decrypt_token( + config_data["refresh_token"] + ) + + self._credentials = NotionAuthCredentialsBase.from_dict(config_data) + + # Invalidate cached client so it's recreated with new token + self._notion_client = None + + logger.info( + f"Successfully refreshed Notion token for connector {self._connector_id}" + ) + except Exception as e: + logger.error( + f"Failed to refresh Notion token for connector {self._connector_id}: {e!s}" + ) + raise Exception( + f"Failed to refresh Notion OAuth credentials: {e!s}" + ) from e + + return self._credentials.access_token + + async def _get_client(self) -> AsyncClient: + """ + Get or create Notion AsyncClient with valid token. + + Returns: + Notion AsyncClient instance + """ + if self._notion_client is None: + token = await self._get_valid_token() + self._notion_client = AsyncClient(auth=token) + return self._notion_client async def close(self): """Close the async client connection.""" - await self.notion.aclose() + if self._notion_client: + await self._notion_client.aclose() + self._notion_client = None async def __aenter__(self): """Async context manager entry.""" @@ -34,6 +182,8 @@ class NotionHistoryConnector: Returns: list: List of dictionaries containing page data """ + notion = await self._get_client() + # Build the filter for the search # Note: Notion API requires specific filter structure search_params = {} @@ -67,7 +217,7 @@ class NotionHistoryConnector: if cursor: search_params["start_cursor"] = cursor - search_results = await self.notion.search(**search_params) + search_results = await notion.search(**search_params) pages.extend(search_results["results"]) has_more = search_results.get("has_more", False) @@ -125,6 +275,8 @@ class NotionHistoryConnector: Returns: list: List of processed blocks from the page """ + notion = await self._get_client() + blocks = [] has_more = True cursor = None @@ -132,11 +284,11 @@ class NotionHistoryConnector: # Paginate through all blocks while has_more: if cursor: - response = await self.notion.blocks.children.list( + response = await notion.blocks.children.list( block_id=page_id, start_cursor=cursor ) else: - response = await self.notion.blocks.children.list(block_id=page_id) + response = await notion.blocks.children.list(block_id=page_id) blocks.extend(response["results"]) has_more = response["has_more"] @@ -162,6 +314,8 @@ class NotionHistoryConnector: Returns: dict: Processed block with content and children """ + notion = await self._get_client() + block_id = block["id"] block_type = block["type"] @@ -174,9 +328,7 @@ class NotionHistoryConnector: if has_children: # Fetch and process child blocks - children_response = await self.notion.blocks.children.list( - block_id=block_id - ) + children_response = await notion.blocks.children.list(block_id=block_id) for child_block in children_response["results"]: child_blocks.append(await self.process_block(child_block)) diff --git a/surfsense_backend/app/routes/linear_add_connector_route.py b/surfsense_backend/app/routes/linear_add_connector_route.py index 9747d4507..7a7fc196a 100644 --- a/surfsense_backend/app/routes/linear_add_connector_route.py +++ b/surfsense_backend/app/routes/linear_add_connector_route.py @@ -23,6 +23,7 @@ from app.db import ( User, get_async_session, ) +from app.schemas.linear_auth_credentials import LinearAuthCredentialsBase from app.users import current_active_user from app.utils.oauth_security import OAuthStateManager, TokenEncryption @@ -328,3 +329,120 @@ async def linear_callback( raise HTTPException( status_code=500, detail=f"Failed to complete Linear OAuth: {e!s}" ) from e + + +async def refresh_linear_token( + session: AsyncSession, connector: SearchSourceConnector +) -> SearchSourceConnector: + """ + Refresh the Linear access token for a connector. + + Args: + session: Database session + connector: Linear connector to refresh + + Returns: + Updated connector object + """ + try: + logger.info(f"Refreshing Linear token for connector {connector.id}") + + credentials = LinearAuthCredentialsBase.from_dict(connector.config) + + # Decrypt tokens if they are encrypted + token_encryption = get_token_encryption() + is_encrypted = connector.config.get("_token_encrypted", False) + + refresh_token = credentials.refresh_token + if is_encrypted and refresh_token: + try: + refresh_token = token_encryption.decrypt_token(refresh_token) + except Exception as e: + logger.error(f"Failed to decrypt refresh token: {e!s}") + raise HTTPException( + status_code=500, detail="Failed to decrypt stored refresh token" + ) from e + + if not refresh_token: + raise HTTPException( + status_code=400, + detail="No refresh token available. Please re-authenticate.", + ) + + auth_header = make_basic_auth_header( + config.LINEAR_CLIENT_ID, config.LINEAR_CLIENT_SECRET + ) + + # Prepare token refresh data + refresh_data = { + "grant_type": "refresh_token", + "refresh_token": refresh_token, + } + + async with httpx.AsyncClient() as client: + token_response = await client.post( + TOKEN_URL, + data=refresh_data, + headers={ + "Content-Type": "application/x-www-form-urlencoded", + "Authorization": auth_header, + }, + timeout=30.0, + ) + + if token_response.status_code != 200: + error_detail = token_response.text + try: + error_json = token_response.json() + error_detail = error_json.get("error_description", error_detail) + except Exception: + pass + raise HTTPException( + status_code=400, detail=f"Token refresh failed: {error_detail}" + ) + + token_json = token_response.json() + + # Calculate expiration time (UTC, tz-aware) + expires_at = None + expires_in = token_json.get("expires_in") + if expires_in: + now_utc = datetime.now(UTC) + expires_at = now_utc + timedelta(seconds=int(expires_in)) + + # Encrypt new tokens before storing + access_token = token_json.get("access_token") + new_refresh_token = token_json.get("refresh_token") + + if not access_token: + raise HTTPException( + status_code=400, detail="No access token received from Linear refresh" + ) + + # Update credentials object with encrypted tokens + credentials.access_token = token_encryption.encrypt_token(access_token) + if new_refresh_token: + credentials.refresh_token = token_encryption.encrypt_token( + new_refresh_token + ) + credentials.expires_in = expires_in + credentials.expires_at = expires_at + credentials.scope = token_json.get("scope") + + # Update connector config with encrypted tokens + credentials_dict = credentials.to_dict() + credentials_dict["_token_encrypted"] = True + connector.config = credentials_dict + await session.commit() + await session.refresh(connector) + + logger.info(f"Successfully refreshed Linear token for connector {connector.id}") + + return connector + except HTTPException: + raise + except Exception as e: + logger.error(f"Failed to refresh Linear token: {e!s}", exc_info=True) + raise HTTPException( + status_code=500, detail=f"Failed to refresh Linear token: {e!s}" + ) from e diff --git a/surfsense_backend/app/routes/notion_add_connector_route.py b/surfsense_backend/app/routes/notion_add_connector_route.py index a36fdf7e8..462ac398c 100644 --- a/surfsense_backend/app/routes/notion_add_connector_route.py +++ b/surfsense_backend/app/routes/notion_add_connector_route.py @@ -5,6 +5,7 @@ Handles OAuth 2.0 authentication flow for Notion connector. """ import logging +from datetime import UTC, datetime, timedelta from uuid import UUID import httpx @@ -22,6 +23,7 @@ from app.db import ( User, get_async_session, ) +from app.schemas.notion_auth_credentials import NotionAuthCredentialsBase from app.users import current_active_user from app.utils.oauth_security import OAuthStateManager, TokenEncryption @@ -230,15 +232,28 @@ async def notion_callback( # Encrypt sensitive tokens before storing token_encryption = get_token_encryption() access_token = token_json.get("access_token") + refresh_token = token_json.get("refresh_token") if not access_token: raise HTTPException( status_code=400, detail="No access token received from Notion" ) - # Notion returns access_token and workspace information - # Store the encrypted access token and workspace info in connector config + # Calculate expiration time (UTC, tz-aware) + expires_at = None + expires_in = token_json.get("expires_in") + if expires_in: + now_utc = datetime.now(UTC) + expires_at = now_utc + timedelta(seconds=int(expires_in)) + + # Notion returns access_token, refresh_token (if available), and workspace information + # Store the encrypted tokens and workspace info in connector config connector_config = { "access_token": token_encryption.encrypt_token(access_token), + "refresh_token": token_encryption.encrypt_token(refresh_token) + if refresh_token + else None, + "expires_in": expires_in, + "expires_at": expires_at.isoformat() if expires_at else None, "workspace_id": token_json.get("workspace_id"), "workspace_name": token_json.get("workspace_name"), "workspace_icon": token_json.get("workspace_icon"), @@ -316,3 +331,129 @@ async def notion_callback( raise HTTPException( status_code=500, detail=f"Failed to complete Notion OAuth: {e!s}" ) from e + + +async def refresh_notion_token( + session: AsyncSession, connector: SearchSourceConnector +) -> SearchSourceConnector: + """ + Refresh the Notion access token for a connector. + + Args: + session: Database session + connector: Notion connector to refresh + + Returns: + Updated connector object + """ + try: + logger.info(f"Refreshing Notion token for connector {connector.id}") + + credentials = NotionAuthCredentialsBase.from_dict(connector.config) + + # Decrypt tokens if they are encrypted + token_encryption = get_token_encryption() + is_encrypted = connector.config.get("_token_encrypted", False) + + refresh_token = credentials.refresh_token + if is_encrypted and refresh_token: + try: + refresh_token = token_encryption.decrypt_token(refresh_token) + except Exception as e: + logger.error(f"Failed to decrypt refresh token: {e!s}") + raise HTTPException( + status_code=500, detail="Failed to decrypt stored refresh token" + ) from e + + if not refresh_token: + raise HTTPException( + status_code=400, + detail="No refresh token available. Please re-authenticate.", + ) + + auth_header = make_basic_auth_header( + config.NOTION_CLIENT_ID, config.NOTION_CLIENT_SECRET + ) + + # Prepare token refresh data + refresh_data = { + "grant_type": "refresh_token", + "refresh_token": refresh_token, + } + + async with httpx.AsyncClient() as client: + token_response = await client.post( + TOKEN_URL, + json=refresh_data, + headers={ + "Content-Type": "application/json", + "Authorization": auth_header, + }, + timeout=30.0, + ) + + if token_response.status_code != 200: + error_detail = token_response.text + try: + error_json = token_response.json() + error_detail = error_json.get("error_description", error_detail) + except Exception: + pass + raise HTTPException( + status_code=400, detail=f"Token refresh failed: {error_detail}" + ) + + token_json = token_response.json() + + # Calculate expiration time (UTC, tz-aware) + expires_at = None + expires_in = token_json.get("expires_in") + if expires_in: + now_utc = datetime.now(UTC) + expires_at = now_utc + timedelta(seconds=int(expires_in)) + + # Encrypt new tokens before storing + access_token = token_json.get("access_token") + new_refresh_token = token_json.get("refresh_token") + + if not access_token: + raise HTTPException( + status_code=400, detail="No access token received from Notion refresh" + ) + + # Update credentials object with encrypted tokens + credentials.access_token = token_encryption.encrypt_token(access_token) + if new_refresh_token: + credentials.refresh_token = token_encryption.encrypt_token( + new_refresh_token + ) + credentials.expires_in = expires_in + credentials.expires_at = expires_at + + # Preserve workspace info + if not credentials.workspace_id: + credentials.workspace_id = connector.config.get("workspace_id") + if not credentials.workspace_name: + credentials.workspace_name = connector.config.get("workspace_name") + if not credentials.workspace_icon: + credentials.workspace_icon = connector.config.get("workspace_icon") + if not credentials.bot_id: + credentials.bot_id = connector.config.get("bot_id") + + # Update connector config with encrypted tokens + credentials_dict = credentials.to_dict() + credentials_dict["_token_encrypted"] = True + connector.config = credentials_dict + await session.commit() + await session.refresh(connector) + + logger.info(f"Successfully refreshed Notion token for connector {connector.id}") + + return connector + except HTTPException: + raise + except Exception as e: + logger.error(f"Failed to refresh Notion token: {e!s}", exc_info=True) + raise HTTPException( + status_code=500, detail=f"Failed to refresh Notion token: {e!s}" + ) from e diff --git a/surfsense_backend/app/schemas/linear_auth_credentials.py b/surfsense_backend/app/schemas/linear_auth_credentials.py new file mode 100644 index 000000000..99e8d9111 --- /dev/null +++ b/surfsense_backend/app/schemas/linear_auth_credentials.py @@ -0,0 +1,66 @@ +from datetime import UTC, datetime + +from pydantic import BaseModel, field_validator + + +class LinearAuthCredentialsBase(BaseModel): + access_token: str + refresh_token: str | None = None + token_type: str = "Bearer" + expires_in: int | None = None + expires_at: datetime | None = None + scope: str | None = None + + @property + def is_expired(self) -> bool: + """Check if the credentials have expired.""" + if self.expires_at is None: + return False + return self.expires_at <= datetime.now(UTC) + + @property + def is_refreshable(self) -> bool: + """Check if the credentials can be refreshed.""" + return self.refresh_token is not None + + def to_dict(self) -> dict: + """Convert credentials to dictionary for storage.""" + return { + "access_token": self.access_token, + "refresh_token": self.refresh_token, + "token_type": self.token_type, + "expires_in": self.expires_in, + "expires_at": self.expires_at.isoformat() if self.expires_at else None, + "scope": self.scope, + } + + @classmethod + def from_dict(cls, data: dict) -> "LinearAuthCredentialsBase": + """Create credentials from dictionary.""" + expires_at = None + if data.get("expires_at"): + expires_at = datetime.fromisoformat(data["expires_at"]) + + return cls( + access_token=data["access_token"], + refresh_token=data.get("refresh_token"), + token_type=data.get("token_type", "Bearer"), + expires_in=data.get("expires_in"), + expires_at=expires_at, + scope=data.get("scope"), + ) + + @field_validator("expires_at", mode="before") + @classmethod + def ensure_aware_utc(cls, v): + # Strings like "2025-08-26T14:46:57.367184" + if isinstance(v, str): + # add +00:00 if missing tz info + if v.endswith("Z"): + return datetime.fromisoformat(v.replace("Z", "+00:00")) + dt = datetime.fromisoformat(v) + return dt if dt.tzinfo else dt.replace(tzinfo=UTC) + # datetime objects + if isinstance(v, datetime): + return v if v.tzinfo else v.replace(tzinfo=UTC) + return v diff --git a/surfsense_backend/app/schemas/notion_auth_credentials.py b/surfsense_backend/app/schemas/notion_auth_credentials.py new file mode 100644 index 000000000..e66afb903 --- /dev/null +++ b/surfsense_backend/app/schemas/notion_auth_credentials.py @@ -0,0 +1,72 @@ +from datetime import UTC, datetime + +from pydantic import BaseModel, field_validator + + +class NotionAuthCredentialsBase(BaseModel): + access_token: str + refresh_token: str | None = None + expires_in: int | None = None + expires_at: datetime | None = None + workspace_id: str | None = None + workspace_name: str | None = None + workspace_icon: str | None = None + bot_id: str | None = None + + @property + def is_expired(self) -> bool: + """Check if the credentials have expired.""" + if self.expires_at is None: + return False # Long-lived token, treat as not expired + return self.expires_at <= datetime.now(UTC) + + @property + def is_refreshable(self) -> bool: + """Check if the credentials can be refreshed.""" + return self.refresh_token is not None + + def to_dict(self) -> dict: + """Convert credentials to dictionary for storage.""" + return { + "access_token": self.access_token, + "refresh_token": self.refresh_token, + "expires_in": self.expires_in, + "expires_at": self.expires_at.isoformat() if self.expires_at else None, + "workspace_id": self.workspace_id, + "workspace_name": self.workspace_name, + "workspace_icon": self.workspace_icon, + "bot_id": self.bot_id, + } + + @classmethod + def from_dict(cls, data: dict) -> "NotionAuthCredentialsBase": + """Create credentials from dictionary.""" + expires_at = None + if data.get("expires_at"): + expires_at = datetime.fromisoformat(data["expires_at"]) + + return cls( + access_token=data["access_token"], + refresh_token=data.get("refresh_token"), + expires_in=data.get("expires_in"), + expires_at=expires_at, + workspace_id=data.get("workspace_id"), + workspace_name=data.get("workspace_name"), + workspace_icon=data.get("workspace_icon"), + bot_id=data.get("bot_id"), + ) + + @field_validator("expires_at", mode="before") + @classmethod + def ensure_aware_utc(cls, v): + # Strings like "2025-08-26T14:46:57.367184" + if isinstance(v, str): + # add +00:00 if missing tz info + if v.endswith("Z"): + return datetime.fromisoformat(v.replace("Z", "+00:00")) + dt = datetime.fromisoformat(v) + return dt if dt.tzinfo else dt.replace(tzinfo=UTC) + # datetime objects + if isinstance(v, datetime): + return v if v.tzinfo else v.replace(tzinfo=UTC) + return v diff --git a/surfsense_backend/app/tasks/connector_indexers/linear_indexer.py b/surfsense_backend/app/tasks/connector_indexers/linear_indexer.py index 5a2933971..f1bfd42e8 100644 --- a/surfsense_backend/app/tasks/connector_indexers/linear_indexer.py +++ b/surfsense_backend/app/tasks/connector_indexers/linear_indexer.py @@ -7,6 +7,7 @@ from datetime import datetime from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.ext.asyncio import AsyncSession +from app.config import config from app.connectors.linear_connector import LinearConnector from app.db import Document, DocumentType, SearchSourceConnectorType from app.services.llm_service import get_user_long_context_llm @@ -91,12 +92,10 @@ async def index_linear_issues( f"Connector with ID {connector_id} not found or is not a Linear connector", ) - # Get the Linear access token from the connector config - # Support both new OAuth format (access_token) and old API key format (LINEAR_API_KEY) - linear_access_token = connector.config.get( - "access_token" - ) or connector.config.get("LINEAR_API_KEY") - if not linear_access_token: + # Check if access_token exists (support both new OAuth format and old API key format) + if not connector.config.get("access_token") and not connector.config.get( + "LINEAR_API_KEY" + ): await task_logger.log_task_failure( log_entry, f"Linear access token not found in connector config for connector {connector_id}", @@ -105,47 +104,16 @@ async def index_linear_issues( ) return 0, "Linear access token not found in connector config" - # Decrypt token if it's encrypted (only when explicitly marked) - from app.config import config - from app.utils.oauth_security import TokenEncryption - - token_encrypted = connector.config.get("_token_encrypted", False) - if token_encrypted: - # Token is explicitly marked as encrypted, attempt decryption - if not config.SECRET_KEY: - await task_logger.log_task_failure( - log_entry, - f"SECRET_KEY not configured but token is marked as encrypted for connector {connector_id}", - "Missing SECRET_KEY for token decryption", - {"error_type": "MissingSecretKey"}, - ) - return 0, "SECRET_KEY not configured but token is marked as encrypted" - try: - token_encryption = TokenEncryption(config.SECRET_KEY) - linear_access_token = token_encryption.decrypt_token( - linear_access_token - ) - logger.info( - f"Decrypted Linear access token for connector {connector_id}" - ) - except Exception as e: - await task_logger.log_task_failure( - log_entry, - f"Failed to decrypt Linear access token for connector {connector_id}: {e!s}", - "Token decryption failed", - {"error_type": "TokenDecryptionError"}, - ) - return 0, f"Failed to decrypt Linear access token: {e!s}" - # If _token_encrypted is False or not set, treat token as plaintext - - # Initialize Linear client + # Initialize Linear client with internal refresh capability await task_logger.log_task_progress( log_entry, f"Initializing Linear client for connector {connector_id}", {"stage": "client_initialization"}, ) - linear_client = LinearConnector(access_token=linear_access_token) + # Create connector with session and connector_id for internal refresh + # Token refresh will happen automatically when needed + linear_client = LinearConnector(session=session, connector_id=connector_id) # Handle 'undefined' string from frontend (treat as None) if start_date == "undefined" or start_date == "": @@ -172,7 +140,7 @@ async def index_linear_issues( # Get issues within date range try: - issues, error = linear_client.get_issues_by_date_range( + issues, error = await linear_client.get_issues_by_date_range( start_date=start_date_str, end_date=end_date_str, include_comments=True ) diff --git a/surfsense_backend/app/tasks/connector_indexers/notion_indexer.py b/surfsense_backend/app/tasks/connector_indexers/notion_indexer.py index 80a9e91e0..13923269d 100644 --- a/surfsense_backend/app/tasks/connector_indexers/notion_indexer.py +++ b/surfsense_backend/app/tasks/connector_indexers/notion_indexer.py @@ -7,7 +7,6 @@ from datetime import datetime from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.ext.asyncio import AsyncSession -from app.config import config from app.connectors.notion_history import NotionHistoryConnector from app.db import Document, DocumentType, SearchSourceConnectorType from app.services.llm_service import get_user_long_context_llm @@ -18,7 +17,6 @@ from app.utils.document_converters import ( generate_document_summary, generate_unique_identifier_hash, ) -from app.utils.oauth_security import TokenEncryption from .base import ( build_document_metadata_string, @@ -94,12 +92,10 @@ async def index_notion_pages( f"Connector with ID {connector_id} not found or is not a Notion connector", ) - # Get the Notion access token from the connector config - # Support both new OAuth format (access_token) and old integration token format (NOTION_INTEGRATION_TOKEN) - notion_token = connector.config.get("access_token") or connector.config.get( + # Check if access_token exists (support both new OAuth format and old integration token format) + if not connector.config.get("access_token") and not connector.config.get( "NOTION_INTEGRATION_TOKEN" - ) - if not notion_token: + ): await task_logger.log_task_failure( log_entry, f"Notion access token not found in connector config for connector {connector_id}", @@ -108,35 +104,7 @@ async def index_notion_pages( ) return 0, "Notion access token not found in connector config" - # Decrypt token if it's encrypted (only when explicitly marked) - token_encrypted = connector.config.get("_token_encrypted", False) - if token_encrypted: - # Token is explicitly marked as encrypted, attempt decryption - if not config.SECRET_KEY: - await task_logger.log_task_failure( - log_entry, - f"SECRET_KEY not configured but token is marked as encrypted for connector {connector_id}", - "Missing SECRET_KEY for token decryption", - {"error_type": "MissingSecretKey"}, - ) - return 0, "SECRET_KEY not configured but token is marked as encrypted" - try: - token_encryption = TokenEncryption(config.SECRET_KEY) - notion_token = token_encryption.decrypt_token(notion_token) - logger.info( - f"Decrypted Notion access token for connector {connector_id}" - ) - except Exception as e: - await task_logger.log_task_failure( - log_entry, - f"Failed to decrypt Notion access token for connector {connector_id}: {e!s}", - "Token decryption failed", - {"error_type": "TokenDecryptionError"}, - ) - return 0, f"Failed to decrypt Notion access token: {e!s}" - # If _token_encrypted is False or not set, treat token as plaintext - - # Initialize Notion client + # Initialize Notion client with internal refresh capability await task_logger.log_task_progress( log_entry, f"Initializing Notion client for connector {connector_id}", @@ -164,7 +132,11 @@ async def index_notion_pages( "%Y-%m-%dT%H:%M:%SZ" ) - notion_client = NotionHistoryConnector(token=notion_token) + # Create connector with session and connector_id for internal refresh + # Token refresh will happen automatically when needed + notion_client = NotionHistoryConnector( + session=session, connector_id=connector_id + ) logger.info(f"Fetching Notion pages from {start_date_iso} to {end_date_iso}")