SurfSense/surfsense_backend/app/connectors/notion_history.py
Anish Sarkar 4f77d171d8 feat: implement auto-refresh capability for Linear and Notion connectors similar to google oauth based ones
- Enhanced LinearConnector and NotionHistoryConnector classes to support automatic token refresh, improving reliability in accessing APIs.
- Updated initialization to require session and connector ID, allowing for dynamic credential management.
- Introduced new credential schemas for Linear and Notion, encapsulating access and refresh tokens with expiration handling.
- Refactored indexers to utilize the new connector structure, ensuring seamless integration with the updated authentication flow.
- Improved error handling and logging during token refresh processes for better debugging and user feedback.
2026-01-03 04:24:00 +05:30

387 lines
13 KiB
Python

import logging
from notion_client import AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.config import config
from app.db import SearchSourceConnector
from app.routes.notion_add_connector_route import refresh_notion_token
from app.schemas.notion_auth_credentials import NotionAuthCredentialsBase
from app.utils.oauth_security import TokenEncryption
logger = logging.getLogger(__name__)
class NotionHistoryConnector:
def __init__(
self,
session: AsyncSession,
connector_id: int,
credentials: NotionAuthCredentialsBase | None = None,
):
"""
Initialize the NotionHistoryConnector with auto-refresh capability.
Args:
session: Database session for updating connector
connector_id: Connector ID for direct updates
credentials: Notion OAuth credentials (optional, will be loaded from DB if not provided)
"""
self._session = session
self._connector_id = connector_id
self._credentials = credentials
self._notion_client: AsyncClient | None = None
async def _get_valid_token(self) -> str:
"""
Get valid Notion access token, refreshing if needed.
Returns:
Valid access token
Raises:
ValueError: If credentials are missing or invalid
Exception: If token refresh fails
"""
# Load credentials from DB if not provided
if self._credentials is None:
result = await self._session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == self._connector_id
)
)
connector = result.scalars().first()
if not connector:
raise ValueError(f"Connector {self._connector_id} not found")
config_data = connector.config.copy()
# Decrypt credentials if they are encrypted
token_encrypted = config_data.get("_token_encrypted", False)
if token_encrypted and config.SECRET_KEY:
try:
token_encryption = TokenEncryption(config.SECRET_KEY)
# Decrypt sensitive fields
if config_data.get("access_token"):
config_data["access_token"] = token_encryption.decrypt_token(
config_data["access_token"]
)
if config_data.get("refresh_token"):
config_data["refresh_token"] = token_encryption.decrypt_token(
config_data["refresh_token"]
)
logger.info(
f"Decrypted Notion credentials for connector {self._connector_id}"
)
except Exception as e:
logger.error(
f"Failed to decrypt Notion credentials for connector {self._connector_id}: {e!s}"
)
raise ValueError(
f"Failed to decrypt Notion credentials: {e!s}"
) from e
try:
self._credentials = NotionAuthCredentialsBase.from_dict(config_data)
except Exception as e:
raise ValueError(f"Invalid Notion credentials: {e!s}") from e
# Check if token is expired and refreshable
if self._credentials.is_expired and self._credentials.is_refreshable:
try:
logger.info(
f"Notion token expired for connector {self._connector_id}, refreshing..."
)
# Get connector for refresh
result = await self._session.execute(
select(SearchSourceConnector).filter(
SearchSourceConnector.id == self._connector_id
)
)
connector = result.scalars().first()
if not connector:
raise RuntimeError(
f"Connector {self._connector_id} not found; cannot refresh token."
)
# Refresh token
connector = await refresh_notion_token(self._session, connector)
# Reload credentials after refresh
config_data = connector.config.copy()
token_encrypted = config_data.get("_token_encrypted", False)
if token_encrypted and config.SECRET_KEY:
token_encryption = TokenEncryption(config.SECRET_KEY)
if config_data.get("access_token"):
config_data["access_token"] = token_encryption.decrypt_token(
config_data["access_token"]
)
if config_data.get("refresh_token"):
config_data["refresh_token"] = token_encryption.decrypt_token(
config_data["refresh_token"]
)
self._credentials = NotionAuthCredentialsBase.from_dict(config_data)
# Invalidate cached client so it's recreated with new token
self._notion_client = None
logger.info(
f"Successfully refreshed Notion token for connector {self._connector_id}"
)
except Exception as e:
logger.error(
f"Failed to refresh Notion token for connector {self._connector_id}: {e!s}"
)
raise Exception(
f"Failed to refresh Notion OAuth credentials: {e!s}"
) from e
return self._credentials.access_token
async def _get_client(self) -> AsyncClient:
"""
Get or create Notion AsyncClient with valid token.
Returns:
Notion AsyncClient instance
"""
if self._notion_client is None:
token = await self._get_valid_token()
self._notion_client = AsyncClient(auth=token)
return self._notion_client
async def close(self):
"""Close the async client connection."""
if self._notion_client:
await self._notion_client.aclose()
self._notion_client = None
async def __aenter__(self):
"""Async context manager entry."""
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
"""Async context manager exit."""
await self.close()
async def get_all_pages(self, start_date=None, end_date=None):
"""
Fetches all pages shared with your integration and their content.
Args:
start_date (str, optional): ISO 8601 date string (e.g., "2023-01-01T00:00:00Z")
end_date (str, optional): ISO 8601 date string (e.g., "2023-12-31T23:59:59Z")
Returns:
list: List of dictionaries containing page data
"""
notion = await self._get_client()
# Build the filter for the search
# Note: Notion API requires specific filter structure
search_params = {}
# Filter for pages only (not databases)
search_params["filter"] = {"value": "page", "property": "object"}
# Add date filters if provided
if start_date or end_date:
date_filter = {}
if start_date:
date_filter["on_or_after"] = start_date
if end_date:
date_filter["on_or_before"] = end_date
# Add the date filter to the search params
if date_filter:
search_params["sort"] = {
"direction": "descending",
"timestamp": "last_edited_time",
}
# Paginate through all pages the integration has access to
pages = []
has_more = True
cursor = None
while has_more:
if cursor:
search_params["start_cursor"] = cursor
search_results = await notion.search(**search_params)
pages.extend(search_results["results"])
has_more = search_results.get("has_more", False)
if has_more:
cursor = search_results.get("next_cursor")
all_page_data = []
for page in pages:
page_id = page["id"]
# Get detailed page information
page_content = await self.get_page_content(page_id)
all_page_data.append(
{
"page_id": page_id,
"title": self.get_page_title(page),
"content": page_content,
}
)
return all_page_data
def get_page_title(self, page):
"""
Extracts the title from a page object.
Args:
page (dict): Notion page object
Returns:
str: Page title or a fallback string
"""
# Title can be in different properties depending on the page type
if "properties" in page:
# Try to find a title property
for _prop_name, prop_data in page["properties"].items():
if prop_data["type"] == "title" and len(prop_data["title"]) > 0:
return " ".join(
[text_obj["plain_text"] for text_obj in prop_data["title"]]
)
# If no title found, return the page ID as fallback
return f"Untitled page ({page['id']})"
async def get_page_content(self, page_id):
"""
Fetches the content (blocks) of a specific page.
Args:
page_id (str): The ID of the page to fetch
Returns:
list: List of processed blocks from the page
"""
notion = await self._get_client()
blocks = []
has_more = True
cursor = None
# Paginate through all blocks
while has_more:
if cursor:
response = await notion.blocks.children.list(
block_id=page_id, start_cursor=cursor
)
else:
response = await notion.blocks.children.list(block_id=page_id)
blocks.extend(response["results"])
has_more = response["has_more"]
if has_more:
cursor = response["next_cursor"]
# Process nested blocks recursively
processed_blocks = []
for block in blocks:
processed_block = await self.process_block(block)
processed_blocks.append(processed_block)
return processed_blocks
async def process_block(self, block):
"""
Processes a block and recursively fetches any child blocks.
Args:
block (dict): The block to process
Returns:
dict: Processed block with content and children
"""
notion = await self._get_client()
block_id = block["id"]
block_type = block["type"]
# Extract block content based on its type
content = self.extract_block_content(block)
# Check if block has children
has_children = block.get("has_children", False)
child_blocks = []
if has_children:
# Fetch and process child blocks
children_response = await notion.blocks.children.list(block_id=block_id)
for child_block in children_response["results"]:
child_blocks.append(await self.process_block(child_block))
return {
"id": block_id,
"type": block_type,
"content": content,
"children": child_blocks,
}
def extract_block_content(self, block):
"""
Extracts the content from a block based on its type.
Args:
block (dict): The block to extract content from
Returns:
str: Extracted content as a string
"""
block_type = block["type"]
# Different block types have different structures
if block_type in block and "rich_text" in block[block_type]:
return "".join(
[text_obj["plain_text"] for text_obj in block[block_type]["rich_text"]]
)
elif block_type == "image":
# Instead of returning the raw URL which may contain sensitive AWS credentials,
# return a placeholder or reference to the image
if "file" in block["image"]:
# For Notion-hosted images (which use AWS S3 pre-signed URLs)
return "[Notion Image]"
elif "external" in block["image"]:
# For external images, we can return a sanitized reference
url = block["image"]["external"]["url"]
# Only return the domain part of external URLs to avoid potential sensitive parameters
try:
from urllib.parse import urlparse
parsed_url = urlparse(url)
return f"[External Image from {parsed_url.netloc}]"
except Exception:
return "[External Image]"
elif block_type == "code":
language = block["code"]["language"]
code_text = "".join(
[text_obj["plain_text"] for text_obj in block["code"]["rich_text"]]
)
return f"```{language}\n{code_text}\n```"
elif block_type == "equation":
return block["equation"]["expression"]
# Add more block types as needed
# Return empty string for unsupported block types
return ""