mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-05-10 16:22:38 +02:00
feat: implement auto-refresh capability for Linear and Notion connectors, similar to the Google OAuth-based ones
- Enhanced LinearConnector and NotionHistoryConnector classes to support automatic token refresh, improving reliability in accessing APIs. - Updated initialization to require session and connector ID, allowing for dynamic credential management. - Introduced new credential schemas for Linear and Notion, encapsulating access and refresh tokens with expiration handling. - Refactored indexers to utilize the new connector structure, ensuring seamless integration with the updated authentication flow. - Improved error handling and logging during token refresh processes for better debugging and user feedback.
This commit is contained in:
parent
1e30bc6484
commit
4f77d171d8
8 changed files with 731 additions and 149 deletions
|
|
@ -5,33 +5,153 @@ A module for retrieving issues and comments from Linear.
|
|||
Allows fetching issue lists and their comments with date range filtering.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.config import config
|
||||
from app.db import SearchSourceConnector
|
||||
from app.routes.linear_add_connector_route import refresh_linear_token
|
||||
from app.schemas.linear_auth_credentials import LinearAuthCredentialsBase
|
||||
from app.utils.oauth_security import TokenEncryption
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LinearConnector:
|
||||
"""Class for retrieving issues and comments from Linear."""
|
||||
|
||||
def __init__(self, access_token: str | None = None):
|
||||
def __init__(
    self,
    session: AsyncSession,
    connector_id: int,
    credentials: LinearAuthCredentialsBase | None = None,
):
    """
    Initialize the LinearConnector class with auto-refresh capability.

    Args:
        session: Database session for updating connector
        connector_id: Connector ID for direct updates
        credentials: Linear OAuth credentials (optional, will be loaded from DB if not provided)
    """
    self._session = session
    self._connector_id = connector_id
    # Cached credentials; loaded from the connector config and refreshed
    # on demand by _get_valid_token().
    self._credentials = credentials
    self.api_url = "https://api.linear.app/graphql"
|
||||
|
||||
def set_token(self, access_token: str) -> None:
|
||||
async def _get_valid_token(self) -> str:
    """
    Get valid Linear access token, refreshing if needed.

    Returns:
        Valid access token

    Raises:
        ValueError: If credentials are missing or invalid
        Exception: If token refresh fails
    """
    # Load credentials from DB if not provided
    if self._credentials is None:
        result = await self._session.execute(
            select(SearchSourceConnector).filter(
                SearchSourceConnector.id == self._connector_id
            )
        )
        connector = result.scalars().first()

        if not connector:
            raise ValueError(f"Connector {self._connector_id} not found")

        # Copy so the decryption below does not mutate the ORM-backed config.
        config_data = connector.config.copy()

        # Decrypt credentials if they are encrypted
        token_encrypted = config_data.get("_token_encrypted", False)
        if token_encrypted and config.SECRET_KEY:
            try:
                token_encryption = TokenEncryption(config.SECRET_KEY)

                # Decrypt sensitive fields
                if config_data.get("access_token"):
                    config_data["access_token"] = token_encryption.decrypt_token(
                        config_data["access_token"]
                    )
                if config_data.get("refresh_token"):
                    config_data["refresh_token"] = token_encryption.decrypt_token(
                        config_data["refresh_token"]
                    )

                logger.info(
                    f"Decrypted Linear credentials for connector {self._connector_id}"
                )
            except Exception as e:
                logger.error(
                    f"Failed to decrypt Linear credentials for connector {self._connector_id}: {e!s}"
                )
                raise ValueError(
                    f"Failed to decrypt Linear credentials: {e!s}"
                ) from e

        try:
            self._credentials = LinearAuthCredentialsBase.from_dict(config_data)
        except Exception as e:
            raise ValueError(f"Invalid Linear credentials: {e!s}") from e

    # Check if token is expired and refreshable
    if self._credentials.is_expired and self._credentials.is_refreshable:
        try:
            logger.info(
                f"Linear token expired for connector {self._connector_id}, refreshing..."
            )

            # Get connector for refresh (re-fetched so refresh_linear_token
            # persists onto the current ORM row)
            result = await self._session.execute(
                select(SearchSourceConnector).filter(
                    SearchSourceConnector.id == self._connector_id
                )
            )
            connector = result.scalars().first()

            if not connector:
                raise RuntimeError(
                    f"Connector {self._connector_id} not found; cannot refresh token."
                )

            # Refresh token (writes new encrypted tokens to connector.config)
            connector = await refresh_linear_token(self._session, connector)

            # Reload credentials after refresh
            config_data = connector.config.copy()
            token_encrypted = config_data.get("_token_encrypted", False)
            if token_encrypted and config.SECRET_KEY:
                token_encryption = TokenEncryption(config.SECRET_KEY)
                if config_data.get("access_token"):
                    config_data["access_token"] = token_encryption.decrypt_token(
                        config_data["access_token"]
                    )
                if config_data.get("refresh_token"):
                    config_data["refresh_token"] = token_encryption.decrypt_token(
                        config_data["refresh_token"]
                    )

            self._credentials = LinearAuthCredentialsBase.from_dict(config_data)

            logger.info(
                f"Successfully refreshed Linear token for connector {self._connector_id}"
            )
        except Exception as e:
            logger.error(
                f"Failed to refresh Linear token for connector {self._connector_id}: {e!s}"
            )
            raise Exception(
                f"Failed to refresh Linear OAuth credentials: {e!s}"
            ) from e

    return self._credentials.access_token
|
||||
|
||||
def get_headers(self) -> dict[str, str]:
|
||||
"""
|
||||
|
|
@ -43,21 +163,24 @@ class LinearConnector:
|
|||
Raises:
|
||||
ValueError: If no Linear access token has been set
|
||||
"""
|
||||
if not self.access_token:
|
||||
# This is a synchronous method, but we need async token refresh
|
||||
# For now, we'll raise an error if called directly
|
||||
# All API calls should go through execute_graphql_query which handles async refresh
|
||||
if not self._credentials or not self._credentials.access_token:
|
||||
raise ValueError(
|
||||
"Linear access token not initialized. Call set_token() first."
|
||||
"Linear access token not initialized. Use execute_graphql_query() method."
|
||||
)
|
||||
|
||||
return {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self.access_token}",
|
||||
"Authorization": f"Bearer {self._credentials.access_token}",
|
||||
}
|
||||
|
||||
def execute_graphql_query(
|
||||
async def execute_graphql_query(
|
||||
self, query: str, variables: dict[str, Any] | None = None
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Execute a GraphQL query against the Linear API.
|
||||
Execute a GraphQL query against the Linear API with automatic token refresh.
|
||||
|
||||
Args:
|
||||
query: GraphQL query string
|
||||
|
|
@ -70,12 +193,14 @@ class LinearConnector:
|
|||
ValueError: If no Linear access token has been set
|
||||
Exception: If the API request fails
|
||||
"""
|
||||
if not self.access_token:
|
||||
raise ValueError(
|
||||
"Linear access token not initialized. Call set_token() first."
|
||||
)
|
||||
# Get valid token (refreshes if needed)
|
||||
access_token = await self._get_valid_token()
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
}
|
||||
|
||||
headers = self.get_headers()
|
||||
payload = {"query": query}
|
||||
|
||||
if variables:
|
||||
|
|
@ -90,7 +215,9 @@ class LinearConnector:
|
|||
f"Query failed with status code {response.status_code}: {response.text}"
|
||||
)
|
||||
|
||||
def get_all_issues(self, include_comments: bool = True) -> list[dict[str, Any]]:
|
||||
async def get_all_issues(
|
||||
self, include_comments: bool = True
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Fetch all issues from Linear.
|
||||
|
||||
|
|
@ -153,7 +280,7 @@ class LinearConnector:
|
|||
}}
|
||||
"""
|
||||
|
||||
result = self.execute_graphql_query(query)
|
||||
result = await self.execute_graphql_query(query)
|
||||
|
||||
# Extract issues from the response
|
||||
if (
|
||||
|
|
@ -165,7 +292,7 @@ class LinearConnector:
|
|||
|
||||
return []
|
||||
|
||||
def get_issues_by_date_range(
|
||||
async def get_issues_by_date_range(
|
||||
self, start_date: str, end_date: str, include_comments: bool = True
|
||||
) -> tuple[list[dict[str, Any]], str | None]:
|
||||
"""
|
||||
|
|
@ -277,7 +404,7 @@ class LinearConnector:
|
|||
# Handle pagination to get all issues
|
||||
while has_next_page:
|
||||
variables = {"after": cursor} if cursor else {}
|
||||
result = self.execute_graphql_query(query, variables)
|
||||
result = await self.execute_graphql_query(query, variables)
|
||||
|
||||
# Check for errors
|
||||
if "errors" in result:
|
||||
|
|
@ -465,37 +592,3 @@ class LinearConnector:
|
|||
return dt.strftime("%Y-%m-%d %H:%M:%S")
|
||||
except ValueError:
|
||||
return iso_date
|
||||
|
||||
|
||||
# Example usage (uncomment to use):
|
||||
"""
|
||||
if __name__ == "__main__":
|
||||
# Set your OAuth access token here
|
||||
access_token = "YOUR_LINEAR_ACCESS_TOKEN"
|
||||
|
||||
linear = LinearConnector(access_token=access_token)
|
||||
|
||||
try:
|
||||
# Get all issues with comments
|
||||
issues = linear.get_all_issues()
|
||||
print(f"Retrieved {len(issues)} issues")
|
||||
|
||||
# Format and print the first issue as markdown
|
||||
if issues:
|
||||
issue_md = linear.format_issue_to_markdown(issues[0])
|
||||
print("\nSample Issue in Markdown:\n")
|
||||
print(issue_md)
|
||||
|
||||
# Get issues by date range
|
||||
start_date = "2023-01-01"
|
||||
end_date = "2023-01-31"
|
||||
date_issues, error = linear.get_issues_by_date_range(start_date, end_date)
|
||||
|
||||
if error:
|
||||
print(f"Error: {error}")
|
||||
else:
|
||||
print(f"\nRetrieved {len(date_issues)} issues from {start_date} to {end_date}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1,19 +1,167 @@
|
|||
import logging
|
||||
|
||||
from notion_client import AsyncClient
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
|
||||
from app.config import config
|
||||
from app.db import SearchSourceConnector
|
||||
from app.routes.notion_add_connector_route import refresh_notion_token
|
||||
from app.schemas.notion_auth_credentials import NotionAuthCredentialsBase
|
||||
from app.utils.oauth_security import TokenEncryption
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class NotionHistoryConnector:
|
||||
def __init__(self, token):
|
||||
def __init__(
    self,
    session: AsyncSession,
    connector_id: int,
    credentials: NotionAuthCredentialsBase | None = None,
):
    """
    Initialize the NotionHistoryConnector with auto-refresh capability.

    Args:
        session: Database session for updating connector
        connector_id: Connector ID for direct updates
        credentials: Notion OAuth credentials (optional, will be loaded from DB if not provided)
    """
    self._session = session
    self._connector_id = connector_id
    # Cached credentials; loaded/refreshed lazily by _get_valid_token().
    self._credentials = credentials
    # Lazily-created Notion client; built with a valid token in _get_client().
    self._notion_client: AsyncClient | None = None
|
||||
|
||||
async def _get_valid_token(self) -> str:
    """
    Get valid Notion access token, refreshing if needed.

    Returns:
        Valid access token

    Raises:
        ValueError: If credentials are missing or invalid
        Exception: If token refresh fails
    """
    # Load credentials from DB if not provided
    if self._credentials is None:
        result = await self._session.execute(
            select(SearchSourceConnector).filter(
                SearchSourceConnector.id == self._connector_id
            )
        )
        connector = result.scalars().first()

        if not connector:
            raise ValueError(f"Connector {self._connector_id} not found")

        # Copy so the decryption below does not mutate the ORM-backed config.
        config_data = connector.config.copy()

        # Decrypt credentials if they are encrypted
        token_encrypted = config_data.get("_token_encrypted", False)
        if token_encrypted and config.SECRET_KEY:
            try:
                token_encryption = TokenEncryption(config.SECRET_KEY)

                # Decrypt sensitive fields
                if config_data.get("access_token"):
                    config_data["access_token"] = token_encryption.decrypt_token(
                        config_data["access_token"]
                    )
                if config_data.get("refresh_token"):
                    config_data["refresh_token"] = token_encryption.decrypt_token(
                        config_data["refresh_token"]
                    )

                logger.info(
                    f"Decrypted Notion credentials for connector {self._connector_id}"
                )
            except Exception as e:
                logger.error(
                    f"Failed to decrypt Notion credentials for connector {self._connector_id}: {e!s}"
                )
                raise ValueError(
                    f"Failed to decrypt Notion credentials: {e!s}"
                ) from e

        try:
            self._credentials = NotionAuthCredentialsBase.from_dict(config_data)
        except Exception as e:
            raise ValueError(f"Invalid Notion credentials: {e!s}") from e

    # Check if token is expired and refreshable
    if self._credentials.is_expired and self._credentials.is_refreshable:
        try:
            logger.info(
                f"Notion token expired for connector {self._connector_id}, refreshing..."
            )

            # Get connector for refresh (re-fetched so refresh_notion_token
            # persists onto the current ORM row)
            result = await self._session.execute(
                select(SearchSourceConnector).filter(
                    SearchSourceConnector.id == self._connector_id
                )
            )
            connector = result.scalars().first()

            if not connector:
                raise RuntimeError(
                    f"Connector {self._connector_id} not found; cannot refresh token."
                )

            # Refresh token (writes new encrypted tokens to connector.config)
            connector = await refresh_notion_token(self._session, connector)

            # Reload credentials after refresh
            config_data = connector.config.copy()
            token_encrypted = config_data.get("_token_encrypted", False)
            if token_encrypted and config.SECRET_KEY:
                token_encryption = TokenEncryption(config.SECRET_KEY)
                if config_data.get("access_token"):
                    config_data["access_token"] = token_encryption.decrypt_token(
                        config_data["access_token"]
                    )
                if config_data.get("refresh_token"):
                    config_data["refresh_token"] = token_encryption.decrypt_token(
                        config_data["refresh_token"]
                    )

            self._credentials = NotionAuthCredentialsBase.from_dict(config_data)

            # Invalidate cached client so it's recreated with new token
            self._notion_client = None

            logger.info(
                f"Successfully refreshed Notion token for connector {self._connector_id}"
            )
        except Exception as e:
            logger.error(
                f"Failed to refresh Notion token for connector {self._connector_id}: {e!s}"
            )
            raise Exception(
                f"Failed to refresh Notion OAuth credentials: {e!s}"
            ) from e

    return self._credentials.access_token
|
||||
|
||||
async def _get_client(self) -> AsyncClient:
    """
    Get or create Notion AsyncClient with valid token.

    Returns:
        Notion AsyncClient instance
    """
    # Built on first use; _get_valid_token() refreshes an expired token and
    # clears this cache, so a stale client is not reused after a refresh.
    if self._notion_client is None:
        token = await self._get_valid_token()
        self._notion_client = AsyncClient(auth=token)
    return self._notion_client
|
||||
|
||||
async def close(self):
    """Close the async client connection, if one was created."""
    if self._notion_client:
        await self._notion_client.aclose()
        # Drop the reference so a later call can lazily build a fresh client.
        self._notion_client = None
|
||||
|
||||
async def __aenter__(self):
|
||||
"""Async context manager entry."""
|
||||
|
|
@ -34,6 +182,8 @@ class NotionHistoryConnector:
|
|||
Returns:
|
||||
list: List of dictionaries containing page data
|
||||
"""
|
||||
notion = await self._get_client()
|
||||
|
||||
# Build the filter for the search
|
||||
# Note: Notion API requires specific filter structure
|
||||
search_params = {}
|
||||
|
|
@ -67,7 +217,7 @@ class NotionHistoryConnector:
|
|||
if cursor:
|
||||
search_params["start_cursor"] = cursor
|
||||
|
||||
search_results = await self.notion.search(**search_params)
|
||||
search_results = await notion.search(**search_params)
|
||||
|
||||
pages.extend(search_results["results"])
|
||||
has_more = search_results.get("has_more", False)
|
||||
|
|
@ -125,6 +275,8 @@ class NotionHistoryConnector:
|
|||
Returns:
|
||||
list: List of processed blocks from the page
|
||||
"""
|
||||
notion = await self._get_client()
|
||||
|
||||
blocks = []
|
||||
has_more = True
|
||||
cursor = None
|
||||
|
|
@ -132,11 +284,11 @@ class NotionHistoryConnector:
|
|||
# Paginate through all blocks
|
||||
while has_more:
|
||||
if cursor:
|
||||
response = await self.notion.blocks.children.list(
|
||||
response = await notion.blocks.children.list(
|
||||
block_id=page_id, start_cursor=cursor
|
||||
)
|
||||
else:
|
||||
response = await self.notion.blocks.children.list(block_id=page_id)
|
||||
response = await notion.blocks.children.list(block_id=page_id)
|
||||
|
||||
blocks.extend(response["results"])
|
||||
has_more = response["has_more"]
|
||||
|
|
@ -162,6 +314,8 @@ class NotionHistoryConnector:
|
|||
Returns:
|
||||
dict: Processed block with content and children
|
||||
"""
|
||||
notion = await self._get_client()
|
||||
|
||||
block_id = block["id"]
|
||||
block_type = block["type"]
|
||||
|
||||
|
|
@ -174,9 +328,7 @@ class NotionHistoryConnector:
|
|||
|
||||
if has_children:
|
||||
# Fetch and process child blocks
|
||||
children_response = await self.notion.blocks.children.list(
|
||||
block_id=block_id
|
||||
)
|
||||
children_response = await notion.blocks.children.list(block_id=block_id)
|
||||
for child_block in children_response["results"]:
|
||||
child_blocks.append(await self.process_block(child_block))
|
||||
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ from app.db import (
|
|||
User,
|
||||
get_async_session,
|
||||
)
|
||||
from app.schemas.linear_auth_credentials import LinearAuthCredentialsBase
|
||||
from app.users import current_active_user
|
||||
from app.utils.oauth_security import OAuthStateManager, TokenEncryption
|
||||
|
||||
|
|
@ -328,3 +329,120 @@ async def linear_callback(
|
|||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to complete Linear OAuth: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
async def refresh_linear_token(
    session: AsyncSession, connector: SearchSourceConnector
) -> SearchSourceConnector:
    """
    Refresh the Linear access token for a connector.

    Args:
        session: Database session
        connector: Linear connector to refresh

    Returns:
        Updated connector object (config holds the newly encrypted tokens)

    Raises:
        HTTPException: If no usable refresh token is stored, if the token
            endpoint rejects the request, or on any unexpected failure.
    """
    try:
        logger.info(f"Refreshing Linear token for connector {connector.id}")

        credentials = LinearAuthCredentialsBase.from_dict(connector.config)

        # Decrypt tokens if they are encrypted
        token_encryption = get_token_encryption()
        is_encrypted = connector.config.get("_token_encrypted", False)

        refresh_token = credentials.refresh_token
        if is_encrypted and refresh_token:
            try:
                refresh_token = token_encryption.decrypt_token(refresh_token)
            except Exception as e:
                logger.error(f"Failed to decrypt refresh token: {e!s}")
                raise HTTPException(
                    status_code=500, detail="Failed to decrypt stored refresh token"
                ) from e

        if not refresh_token:
            raise HTTPException(
                status_code=400,
                detail="No refresh token available. Please re-authenticate.",
            )

        # Client authenticates to the token endpoint via HTTP Basic auth.
        auth_header = make_basic_auth_header(
            config.LINEAR_CLIENT_ID, config.LINEAR_CLIENT_SECRET
        )

        # Prepare token refresh data (form-encoded refresh_token grant)
        refresh_data = {
            "grant_type": "refresh_token",
            "refresh_token": refresh_token,
        }

        async with httpx.AsyncClient() as client:
            token_response = await client.post(
                TOKEN_URL,
                data=refresh_data,
                headers={
                    "Content-Type": "application/x-www-form-urlencoded",
                    "Authorization": auth_header,
                },
                timeout=30.0,
            )

        if token_response.status_code != 200:
            # Prefer the provider's error_description when the body is JSON.
            error_detail = token_response.text
            try:
                error_json = token_response.json()
                error_detail = error_json.get("error_description", error_detail)
            except Exception:
                pass
            raise HTTPException(
                status_code=400, detail=f"Token refresh failed: {error_detail}"
            )

        token_json = token_response.json()

        # Calculate expiration time (UTC, tz-aware)
        expires_at = None
        expires_in = token_json.get("expires_in")
        if expires_in:
            now_utc = datetime.now(UTC)
            expires_at = now_utc + timedelta(seconds=int(expires_in))

        # Encrypt new tokens before storing
        access_token = token_json.get("access_token")
        new_refresh_token = token_json.get("refresh_token")

        if not access_token:
            raise HTTPException(
                status_code=400, detail="No access token received from Linear refresh"
            )

        # Update credentials object with encrypted tokens
        credentials.access_token = token_encryption.encrypt_token(access_token)
        if new_refresh_token:
            # Only replaced when the provider rotates the refresh token;
            # otherwise the stored (encrypted) value is kept.
            credentials.refresh_token = token_encryption.encrypt_token(
                new_refresh_token
            )
        credentials.expires_in = expires_in
        credentials.expires_at = expires_at
        credentials.scope = token_json.get("scope")

        # Update connector config with encrypted tokens
        credentials_dict = credentials.to_dict()
        credentials_dict["_token_encrypted"] = True
        connector.config = credentials_dict
        await session.commit()
        await session.refresh(connector)

        logger.info(f"Successfully refreshed Linear token for connector {connector.id}")

        return connector
    except HTTPException:
        # Re-raise structured API errors unchanged.
        raise
    except Exception as e:
        logger.error(f"Failed to refresh Linear token: {e!s}", exc_info=True)
        raise HTTPException(
            status_code=500, detail=f"Failed to refresh Linear token: {e!s}"
        ) from e
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ Handles OAuth 2.0 authentication flow for Notion connector.
|
|||
"""
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from uuid import UUID
|
||||
|
||||
import httpx
|
||||
|
|
@ -22,6 +23,7 @@ from app.db import (
|
|||
User,
|
||||
get_async_session,
|
||||
)
|
||||
from app.schemas.notion_auth_credentials import NotionAuthCredentialsBase
|
||||
from app.users import current_active_user
|
||||
from app.utils.oauth_security import OAuthStateManager, TokenEncryption
|
||||
|
||||
|
|
@ -230,15 +232,28 @@ async def notion_callback(
|
|||
# Encrypt sensitive tokens before storing
|
||||
token_encryption = get_token_encryption()
|
||||
access_token = token_json.get("access_token")
|
||||
refresh_token = token_json.get("refresh_token")
|
||||
if not access_token:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="No access token received from Notion"
|
||||
)
|
||||
|
||||
# Notion returns access_token and workspace information
|
||||
# Store the encrypted access token and workspace info in connector config
|
||||
# Calculate expiration time (UTC, tz-aware)
|
||||
expires_at = None
|
||||
expires_in = token_json.get("expires_in")
|
||||
if expires_in:
|
||||
now_utc = datetime.now(UTC)
|
||||
expires_at = now_utc + timedelta(seconds=int(expires_in))
|
||||
|
||||
# Notion returns access_token, refresh_token (if available), and workspace information
|
||||
# Store the encrypted tokens and workspace info in connector config
|
||||
connector_config = {
|
||||
"access_token": token_encryption.encrypt_token(access_token),
|
||||
"refresh_token": token_encryption.encrypt_token(refresh_token)
|
||||
if refresh_token
|
||||
else None,
|
||||
"expires_in": expires_in,
|
||||
"expires_at": expires_at.isoformat() if expires_at else None,
|
||||
"workspace_id": token_json.get("workspace_id"),
|
||||
"workspace_name": token_json.get("workspace_name"),
|
||||
"workspace_icon": token_json.get("workspace_icon"),
|
||||
|
|
@ -316,3 +331,129 @@ async def notion_callback(
|
|||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to complete Notion OAuth: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
async def refresh_notion_token(
    session: AsyncSession, connector: SearchSourceConnector
) -> SearchSourceConnector:
    """
    Refresh the Notion access token for a connector.

    Args:
        session: Database session
        connector: Notion connector to refresh

    Returns:
        Updated connector object (config holds the newly encrypted tokens)

    Raises:
        HTTPException: If no usable refresh token is stored, if the token
            endpoint rejects the request, or on any unexpected failure.
    """
    try:
        logger.info(f"Refreshing Notion token for connector {connector.id}")

        credentials = NotionAuthCredentialsBase.from_dict(connector.config)

        # Decrypt tokens if they are encrypted
        token_encryption = get_token_encryption()
        is_encrypted = connector.config.get("_token_encrypted", False)

        refresh_token = credentials.refresh_token
        if is_encrypted and refresh_token:
            try:
                refresh_token = token_encryption.decrypt_token(refresh_token)
            except Exception as e:
                logger.error(f"Failed to decrypt refresh token: {e!s}")
                raise HTTPException(
                    status_code=500, detail="Failed to decrypt stored refresh token"
                ) from e

        if not refresh_token:
            raise HTTPException(
                status_code=400,
                detail="No refresh token available. Please re-authenticate.",
            )

        # Client authenticates to the token endpoint via HTTP Basic auth.
        auth_header = make_basic_auth_header(
            config.NOTION_CLIENT_ID, config.NOTION_CLIENT_SECRET
        )

        # Prepare token refresh data (Notion's endpoint takes a JSON body,
        # unlike the form-encoded Linear endpoint)
        refresh_data = {
            "grant_type": "refresh_token",
            "refresh_token": refresh_token,
        }

        async with httpx.AsyncClient() as client:
            token_response = await client.post(
                TOKEN_URL,
                json=refresh_data,
                headers={
                    "Content-Type": "application/json",
                    "Authorization": auth_header,
                },
                timeout=30.0,
            )

        if token_response.status_code != 200:
            # Prefer the provider's error_description when the body is JSON.
            error_detail = token_response.text
            try:
                error_json = token_response.json()
                error_detail = error_json.get("error_description", error_detail)
            except Exception:
                pass
            raise HTTPException(
                status_code=400, detail=f"Token refresh failed: {error_detail}"
            )

        token_json = token_response.json()

        # Calculate expiration time (UTC, tz-aware)
        expires_at = None
        expires_in = token_json.get("expires_in")
        if expires_in:
            now_utc = datetime.now(UTC)
            expires_at = now_utc + timedelta(seconds=int(expires_in))

        # Encrypt new tokens before storing
        access_token = token_json.get("access_token")
        new_refresh_token = token_json.get("refresh_token")

        if not access_token:
            raise HTTPException(
                status_code=400, detail="No access token received from Notion refresh"
            )

        # Update credentials object with encrypted tokens
        credentials.access_token = token_encryption.encrypt_token(access_token)
        if new_refresh_token:
            # Only replaced when the provider rotates the refresh token;
            # otherwise the stored (encrypted) value is kept.
            credentials.refresh_token = token_encryption.encrypt_token(
                new_refresh_token
            )
        credentials.expires_in = expires_in
        credentials.expires_at = expires_at

        # Preserve workspace info (the refresh response may omit it)
        if not credentials.workspace_id:
            credentials.workspace_id = connector.config.get("workspace_id")
        if not credentials.workspace_name:
            credentials.workspace_name = connector.config.get("workspace_name")
        if not credentials.workspace_icon:
            credentials.workspace_icon = connector.config.get("workspace_icon")
        if not credentials.bot_id:
            credentials.bot_id = connector.config.get("bot_id")

        # Update connector config with encrypted tokens
        credentials_dict = credentials.to_dict()
        credentials_dict["_token_encrypted"] = True
        connector.config = credentials_dict
        await session.commit()
        await session.refresh(connector)

        logger.info(f"Successfully refreshed Notion token for connector {connector.id}")

        return connector
    except HTTPException:
        # Re-raise structured API errors unchanged.
        raise
    except Exception as e:
        logger.error(f"Failed to refresh Notion token: {e!s}", exc_info=True)
        raise HTTPException(
            status_code=500, detail=f"Failed to refresh Notion token: {e!s}"
        ) from e
|
||||
|
|
|
|||
66
surfsense_backend/app/schemas/linear_auth_credentials.py
Normal file
66
surfsense_backend/app/schemas/linear_auth_credentials.py
Normal file
|
|
@ -0,0 +1,66 @@
|
|||
from datetime import UTC, datetime
|
||||
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
|
||||
class LinearAuthCredentialsBase(BaseModel):
    """OAuth credential payload stored in a Linear connector's config."""

    access_token: str
    refresh_token: str | None = None
    token_type: str = "Bearer"
    expires_in: int | None = None
    expires_at: datetime | None = None
    scope: str | None = None

    @property
    def is_expired(self) -> bool:
        """Check if the credentials have expired (no expiry means never)."""
        if self.expires_at is None:
            return False
        return self.expires_at <= datetime.now(UTC)

    @property
    def is_refreshable(self) -> bool:
        """Check if the credentials can be refreshed."""
        return self.refresh_token is not None

    def to_dict(self) -> dict:
        """Convert credentials to a JSON-serializable dict for storage."""
        return {
            "access_token": self.access_token,
            "refresh_token": self.refresh_token,
            "token_type": self.token_type,
            "expires_in": self.expires_in,
            "expires_at": self.expires_at.isoformat() if self.expires_at else None,
            "scope": self.scope,
        }

    @classmethod
    def from_dict(cls, data: dict) -> "LinearAuthCredentialsBase":
        """Create credentials from dictionary.

        The raw ``expires_at`` value (ISO string or datetime) is passed
        through unchanged; the ``ensure_aware_utc`` validator below parses
        it and normalizes naive or "Z"-suffixed values to aware UTC.
        Pre-parsing it here with ``datetime.fromisoformat`` would reject
        "Z"-suffixed strings on Python <= 3.10.
        """
        return cls(
            access_token=data["access_token"],
            refresh_token=data.get("refresh_token"),
            token_type=data.get("token_type", "Bearer"),
            expires_in=data.get("expires_in"),
            # Falsy values (None, "") are stored as None.
            expires_at=data.get("expires_at") or None,
            scope=data.get("scope"),
        )

    @field_validator("expires_at", mode="before")
    @classmethod
    def ensure_aware_utc(cls, v):
        """Coerce ISO strings and naive datetimes to tz-aware UTC."""
        # Strings like "2025-08-26T14:46:57.367184" or "...Z"
        if isinstance(v, str):
            if v.endswith("Z"):
                return datetime.fromisoformat(v.replace("Z", "+00:00"))
            dt = datetime.fromisoformat(v)
            return dt if dt.tzinfo else dt.replace(tzinfo=UTC)
        # datetime objects: attach UTC if naive
        if isinstance(v, datetime):
            return v if v.tzinfo else v.replace(tzinfo=UTC)
        return v
|
||||
72
surfsense_backend/app/schemas/notion_auth_credentials.py
Normal file
72
surfsense_backend/app/schemas/notion_auth_credentials.py
Normal file
|
|
@ -0,0 +1,72 @@
|
|||
from datetime import UTC, datetime
|
||||
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
|
||||
class NotionAuthCredentialsBase(BaseModel):
    """OAuth credentials for a Notion connector.

    Wraps the access/refresh token pair returned by Notion's OAuth flow,
    together with the workspace/bot metadata Notion includes in its token
    response, and provides helpers for expiry checks plus (de)serialization
    to/from the plain dict stored in the connector config.
    """

    access_token: str
    refresh_token: str | None = None
    expires_in: int | None = None
    expires_at: datetime | None = None
    workspace_id: str | None = None
    workspace_name: str | None = None
    workspace_icon: str | None = None
    bot_id: str | None = None

    @property
    def is_expired(self) -> bool:
        """Check if the credentials have expired."""
        if self.expires_at is None:
            return False  # Long-lived token, treat as not expired
        return self.expires_at <= datetime.now(UTC)

    @property
    def is_refreshable(self) -> bool:
        """Check if the credentials can be refreshed."""
        return self.refresh_token is not None

    def to_dict(self) -> dict:
        """Convert credentials to a dictionary for storage.

        ``expires_at`` is serialized to an ISO-8601 string (or ``None``) so
        the result is JSON-safe.
        """
        return {
            "access_token": self.access_token,
            "refresh_token": self.refresh_token,
            "expires_in": self.expires_in,
            "expires_at": self.expires_at.isoformat() if self.expires_at else None,
            "workspace_id": self.workspace_id,
            "workspace_name": self.workspace_name,
            "workspace_icon": self.workspace_icon,
            "bot_id": self.bot_id,
        }

    @classmethod
    def from_dict(cls, data: dict) -> "NotionAuthCredentialsBase":
        """Create credentials from a dictionary.

        The raw ``expires_at`` value (ISO string or datetime) is passed
        straight to the model so the ``ensure_aware_utc`` validator performs
        the parsing. This fixes an inconsistency where a manual
        ``datetime.fromisoformat`` call here rejected "Z"-suffixed
        timestamps (on Python < 3.11) that the validator explicitly
        supports, and keeps the parsing logic in one place.

        Raises:
            KeyError: if ``access_token`` is missing.
        """
        return cls(
            access_token=data["access_token"],
            refresh_token=data.get("refresh_token"),
            expires_in=data.get("expires_in"),
            # Falsy values ("" / None / absent) normalize to None, matching
            # the previous "only parse if truthy" behavior.
            expires_at=data.get("expires_at") or None,
            workspace_id=data.get("workspace_id"),
            workspace_name=data.get("workspace_name"),
            workspace_icon=data.get("workspace_icon"),
            bot_id=data.get("bot_id"),
        )

    @field_validator("expires_at", mode="before")
    @classmethod
    def ensure_aware_utc(cls, v):
        """Coerce ``expires_at`` into a timezone-aware UTC datetime.

        Accepts ISO strings (with or without timezone info, including a
        trailing "Z") and datetime objects; naive values are assumed UTC.
        """
        # Strings like "2025-08-26T14:46:57.367184"
        if isinstance(v, str):
            # fromisoformat rejects a trailing "Z" before Python 3.11;
            # normalize it to an explicit UTC offset.
            if v.endswith("Z"):
                return datetime.fromisoformat(v.replace("Z", "+00:00"))
            dt = datetime.fromisoformat(v)
            return dt if dt.tzinfo else dt.replace(tzinfo=UTC)
        # datetime objects
        if isinstance(v, datetime):
            return v if v.tzinfo else v.replace(tzinfo=UTC)
        return v
|
||||
|
|
@ -7,6 +7,7 @@ from datetime import datetime
|
|||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import config
|
||||
from app.connectors.linear_connector import LinearConnector
|
||||
from app.db import Document, DocumentType, SearchSourceConnectorType
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
|
|
@ -91,12 +92,10 @@ async def index_linear_issues(
|
|||
f"Connector with ID {connector_id} not found or is not a Linear connector",
|
||||
)
|
||||
|
||||
# Get the Linear access token from the connector config
|
||||
# Support both new OAuth format (access_token) and old API key format (LINEAR_API_KEY)
|
||||
linear_access_token = connector.config.get(
|
||||
"access_token"
|
||||
) or connector.config.get("LINEAR_API_KEY")
|
||||
if not linear_access_token:
|
||||
# Check if access_token exists (support both new OAuth format and old API key format)
|
||||
if not connector.config.get("access_token") and not connector.config.get(
|
||||
"LINEAR_API_KEY"
|
||||
):
|
||||
await task_logger.log_task_failure(
|
||||
log_entry,
|
||||
f"Linear access token not found in connector config for connector {connector_id}",
|
||||
|
|
@ -105,47 +104,16 @@ async def index_linear_issues(
|
|||
)
|
||||
return 0, "Linear access token not found in connector config"
|
||||
|
||||
# Decrypt token if it's encrypted (only when explicitly marked)
|
||||
from app.config import config
|
||||
from app.utils.oauth_security import TokenEncryption
|
||||
|
||||
token_encrypted = connector.config.get("_token_encrypted", False)
|
||||
if token_encrypted:
|
||||
# Token is explicitly marked as encrypted, attempt decryption
|
||||
if not config.SECRET_KEY:
|
||||
await task_logger.log_task_failure(
|
||||
log_entry,
|
||||
f"SECRET_KEY not configured but token is marked as encrypted for connector {connector_id}",
|
||||
"Missing SECRET_KEY for token decryption",
|
||||
{"error_type": "MissingSecretKey"},
|
||||
)
|
||||
return 0, "SECRET_KEY not configured but token is marked as encrypted"
|
||||
try:
|
||||
token_encryption = TokenEncryption(config.SECRET_KEY)
|
||||
linear_access_token = token_encryption.decrypt_token(
|
||||
linear_access_token
|
||||
)
|
||||
logger.info(
|
||||
f"Decrypted Linear access token for connector {connector_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
await task_logger.log_task_failure(
|
||||
log_entry,
|
||||
f"Failed to decrypt Linear access token for connector {connector_id}: {e!s}",
|
||||
"Token decryption failed",
|
||||
{"error_type": "TokenDecryptionError"},
|
||||
)
|
||||
return 0, f"Failed to decrypt Linear access token: {e!s}"
|
||||
# If _token_encrypted is False or not set, treat token as plaintext
|
||||
|
||||
# Initialize Linear client
|
||||
# Initialize Linear client with internal refresh capability
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Initializing Linear client for connector {connector_id}",
|
||||
{"stage": "client_initialization"},
|
||||
)
|
||||
|
||||
linear_client = LinearConnector(access_token=linear_access_token)
|
||||
# Create connector with session and connector_id for internal refresh
|
||||
# Token refresh will happen automatically when needed
|
||||
linear_client = LinearConnector(session=session, connector_id=connector_id)
|
||||
|
||||
# Handle 'undefined' string from frontend (treat as None)
|
||||
if start_date == "undefined" or start_date == "":
|
||||
|
|
@ -172,7 +140,7 @@ async def index_linear_issues(
|
|||
|
||||
# Get issues within date range
|
||||
try:
|
||||
issues, error = linear_client.get_issues_by_date_range(
|
||||
issues, error = await linear_client.get_issues_by_date_range(
|
||||
start_date=start_date_str, end_date=end_date_str, include_comments=True
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -7,7 +7,6 @@ from datetime import datetime
|
|||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import config
|
||||
from app.connectors.notion_history import NotionHistoryConnector
|
||||
from app.db import Document, DocumentType, SearchSourceConnectorType
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
|
|
@ -18,7 +17,6 @@ from app.utils.document_converters import (
|
|||
generate_document_summary,
|
||||
generate_unique_identifier_hash,
|
||||
)
|
||||
from app.utils.oauth_security import TokenEncryption
|
||||
|
||||
from .base import (
|
||||
build_document_metadata_string,
|
||||
|
|
@ -94,12 +92,10 @@ async def index_notion_pages(
|
|||
f"Connector with ID {connector_id} not found or is not a Notion connector",
|
||||
)
|
||||
|
||||
# Get the Notion access token from the connector config
|
||||
# Support both new OAuth format (access_token) and old integration token format (NOTION_INTEGRATION_TOKEN)
|
||||
notion_token = connector.config.get("access_token") or connector.config.get(
|
||||
# Check if access_token exists (support both new OAuth format and old integration token format)
|
||||
if not connector.config.get("access_token") and not connector.config.get(
|
||||
"NOTION_INTEGRATION_TOKEN"
|
||||
)
|
||||
if not notion_token:
|
||||
):
|
||||
await task_logger.log_task_failure(
|
||||
log_entry,
|
||||
f"Notion access token not found in connector config for connector {connector_id}",
|
||||
|
|
@ -108,35 +104,7 @@ async def index_notion_pages(
|
|||
)
|
||||
return 0, "Notion access token not found in connector config"
|
||||
|
||||
# Decrypt token if it's encrypted (only when explicitly marked)
|
||||
token_encrypted = connector.config.get("_token_encrypted", False)
|
||||
if token_encrypted:
|
||||
# Token is explicitly marked as encrypted, attempt decryption
|
||||
if not config.SECRET_KEY:
|
||||
await task_logger.log_task_failure(
|
||||
log_entry,
|
||||
f"SECRET_KEY not configured but token is marked as encrypted for connector {connector_id}",
|
||||
"Missing SECRET_KEY for token decryption",
|
||||
{"error_type": "MissingSecretKey"},
|
||||
)
|
||||
return 0, "SECRET_KEY not configured but token is marked as encrypted"
|
||||
try:
|
||||
token_encryption = TokenEncryption(config.SECRET_KEY)
|
||||
notion_token = token_encryption.decrypt_token(notion_token)
|
||||
logger.info(
|
||||
f"Decrypted Notion access token for connector {connector_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
await task_logger.log_task_failure(
|
||||
log_entry,
|
||||
f"Failed to decrypt Notion access token for connector {connector_id}: {e!s}",
|
||||
"Token decryption failed",
|
||||
{"error_type": "TokenDecryptionError"},
|
||||
)
|
||||
return 0, f"Failed to decrypt Notion access token: {e!s}"
|
||||
# If _token_encrypted is False or not set, treat token as plaintext
|
||||
|
||||
# Initialize Notion client
|
||||
# Initialize Notion client with internal refresh capability
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
f"Initializing Notion client for connector {connector_id}",
|
||||
|
|
@ -164,7 +132,11 @@ async def index_notion_pages(
|
|||
"%Y-%m-%dT%H:%M:%SZ"
|
||||
)
|
||||
|
||||
notion_client = NotionHistoryConnector(token=notion_token)
|
||||
# Create connector with session and connector_id for internal refresh
|
||||
# Token refresh will happen automatically when needed
|
||||
notion_client = NotionHistoryConnector(
|
||||
session=session, connector_id=connector_id
|
||||
)
|
||||
|
||||
logger.info(f"Fetching Notion pages from {start_date_iso} to {end_date_iso}")
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue