mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-02 19:55:18 +02:00
feat: implement token encryption and state management for OAuth connectors
- Added encryption for sensitive tokens (access token, refresh token, client secret) in Google Calendar, Google Drive, Gmail, Linear, and Notion connectors to enhance security. - Introduced OAuthStateManager for secure state parameter generation and validation, improving the integrity of OAuth flows. - Updated callback routes to handle state validation and error management, ensuring robust handling of authorization processes. - Enhanced indexers to support decryption of tokens for backward compatibility, maintaining functionality with existing encrypted credentials. - Improved validation for date parameters in connector routes to ensure proper input handling.
This commit is contained in:
parent
ec7599362d
commit
45489423d1
17 changed files with 1140 additions and 152 deletions
|
|
@ -109,7 +109,32 @@ class GoogleCalendarConnector:
|
|||
raise RuntimeError(
|
||||
"GOOGLE_CALENDAR_CONNECTOR connector not found; cannot persist refreshed token."
|
||||
)
|
||||
connector.config = json.loads(self._credentials.to_json())
|
||||
|
||||
# Encrypt sensitive credentials before storing
|
||||
from app.config import config
|
||||
from app.utils.oauth_security import TokenEncryption
|
||||
|
||||
creds_dict = json.loads(self._credentials.to_json())
|
||||
token_encrypted = connector.config.get("_token_encrypted", False)
|
||||
|
||||
if token_encrypted and config.SECRET_KEY:
|
||||
token_encryption = TokenEncryption(config.SECRET_KEY)
|
||||
# Encrypt sensitive fields
|
||||
if creds_dict.get("token"):
|
||||
creds_dict["token"] = token_encryption.encrypt_token(
|
||||
creds_dict["token"]
|
||||
)
|
||||
if creds_dict.get("refresh_token"):
|
||||
creds_dict["refresh_token"] = token_encryption.encrypt_token(
|
||||
creds_dict["refresh_token"]
|
||||
)
|
||||
if creds_dict.get("client_secret"):
|
||||
creds_dict["client_secret"] = token_encryption.encrypt_token(
|
||||
creds_dict["client_secret"]
|
||||
)
|
||||
creds_dict["_token_encrypted"] = True
|
||||
|
||||
connector.config = creds_dict
|
||||
flag_modified(connector, "config")
|
||||
await self._session.commit()
|
||||
except Exception as e:
|
||||
|
|
@ -182,6 +207,12 @@ class GoogleCalendarConnector:
|
|||
Tuple containing (events list, error message or None)
|
||||
"""
|
||||
try:
|
||||
# Validate date strings
|
||||
if not start_date or start_date.lower() in ("undefined", "null", "none"):
|
||||
return [], "Invalid start_date: must be a valid date string in YYYY-MM-DD format"
|
||||
if not end_date or end_date.lower() in ("undefined", "null", "none"):
|
||||
return [], "Invalid end_date: must be a valid date string in YYYY-MM-DD format"
|
||||
|
||||
service = await self._get_service()
|
||||
|
||||
# Parse both dates
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
"""Google Drive OAuth credential management."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from google.auth.transport.requests import Request
|
||||
|
|
@ -9,7 +10,11 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
from app.config import config
|
||||
from app.db import SearchSourceConnector
|
||||
from app.utils.oauth_security import TokenEncryption
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def get_valid_credentials(
|
||||
|
|
@ -38,7 +43,39 @@ async def get_valid_credentials(
|
|||
if not connector:
|
||||
raise ValueError(f"Connector {connector_id} not found")
|
||||
|
||||
config_data = connector.config
|
||||
config_data = connector.config.copy() # Work with a copy to avoid modifying original
|
||||
|
||||
# Decrypt credentials if they are encrypted
|
||||
token_encrypted = config_data.get("_token_encrypted", False)
|
||||
if token_encrypted and config.SECRET_KEY:
|
||||
try:
|
||||
token_encryption = TokenEncryption(config.SECRET_KEY)
|
||||
|
||||
# Decrypt sensitive fields
|
||||
if config_data.get("token"):
|
||||
config_data["token"] = token_encryption.decrypt_token(
|
||||
config_data["token"]
|
||||
)
|
||||
if config_data.get("refresh_token"):
|
||||
config_data["refresh_token"] = token_encryption.decrypt_token(
|
||||
config_data["refresh_token"]
|
||||
)
|
||||
if config_data.get("client_secret"):
|
||||
config_data["client_secret"] = token_encryption.decrypt_token(
|
||||
config_data["client_secret"]
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Decrypted Google Drive credentials for connector {connector_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to decrypt Google Drive credentials for connector {connector_id}: {e!s}"
|
||||
)
|
||||
raise ValueError(
|
||||
f"Failed to decrypt Google Drive credentials: {e!s}"
|
||||
) from e
|
||||
|
||||
exp = config_data.get("expiry", "").replace("Z", "")
|
||||
|
||||
if not all(
|
||||
|
|
@ -66,7 +103,29 @@ async def get_valid_credentials(
|
|||
try:
|
||||
credentials.refresh(Request())
|
||||
|
||||
connector.config = json.loads(credentials.to_json())
|
||||
creds_dict = json.loads(credentials.to_json())
|
||||
|
||||
# Encrypt sensitive credentials before storing
|
||||
token_encrypted = connector.config.get("_token_encrypted", False)
|
||||
|
||||
if token_encrypted and config.SECRET_KEY:
|
||||
token_encryption = TokenEncryption(config.SECRET_KEY)
|
||||
# Encrypt sensitive fields
|
||||
if creds_dict.get("token"):
|
||||
creds_dict["token"] = token_encryption.encrypt_token(
|
||||
creds_dict["token"]
|
||||
)
|
||||
if creds_dict.get("refresh_token"):
|
||||
creds_dict["refresh_token"] = token_encryption.encrypt_token(
|
||||
creds_dict["refresh_token"]
|
||||
)
|
||||
if creds_dict.get("client_secret"):
|
||||
creds_dict["client_secret"] = token_encryption.encrypt_token(
|
||||
creds_dict["client_secret"]
|
||||
)
|
||||
creds_dict["_token_encrypted"] = True
|
||||
|
||||
connector.config = creds_dict
|
||||
flag_modified(connector, "config")
|
||||
await session.commit()
|
||||
|
||||
|
|
|
|||
|
|
@ -172,6 +172,12 @@ class LinearConnector:
|
|||
Returns:
|
||||
Tuple containing (issues list, error message or None)
|
||||
"""
|
||||
# Validate date strings
|
||||
if not start_date or start_date.lower() in ("undefined", "null", "none"):
|
||||
return [], "Invalid start_date: must be a valid date string in YYYY-MM-DD format"
|
||||
if not end_date or end_date.lower() in ("undefined", "null", "none"):
|
||||
return [], "Invalid end_date: must be a valid date string in YYYY-MM-DD format"
|
||||
|
||||
# Convert date strings to ISO format
|
||||
try:
|
||||
# For Linear API: we need to use a more specific format for the filter
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
import base64
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import secrets
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
|
@ -23,6 +22,7 @@ from app.db import (
|
|||
)
|
||||
from app.schemas.airtable_auth_credentials import AirtableAuthCredentialsBase
|
||||
from app.users import current_active_user
|
||||
from app.utils.oauth_security import OAuthStateManager, TokenEncryption
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -40,6 +40,30 @@ SCOPES = [
|
|||
"user.email:read",
|
||||
]
|
||||
|
||||
# Initialize security utilities
|
||||
_state_manager = None
|
||||
_token_encryption = None
|
||||
|
||||
|
||||
def get_state_manager() -> OAuthStateManager:
|
||||
"""Get or create OAuth state manager instance."""
|
||||
global _state_manager
|
||||
if _state_manager is None:
|
||||
if not config.SECRET_KEY:
|
||||
raise ValueError("SECRET_KEY must be set for OAuth security")
|
||||
_state_manager = OAuthStateManager(config.SECRET_KEY)
|
||||
return _state_manager
|
||||
|
||||
|
||||
def get_token_encryption() -> TokenEncryption:
|
||||
"""Get or create token encryption instance."""
|
||||
global _token_encryption
|
||||
if _token_encryption is None:
|
||||
if not config.SECRET_KEY:
|
||||
raise ValueError("SECRET_KEY must be set for token encryption")
|
||||
_token_encryption = TokenEncryption(config.SECRET_KEY)
|
||||
return _token_encryption
|
||||
|
||||
|
||||
def make_basic_auth_header(client_id: str, client_secret: str) -> str:
|
||||
credentials = f"{client_id}:{client_secret}".encode()
|
||||
|
|
@ -90,18 +114,19 @@ async def connect_airtable(space_id: int, user: User = Depends(current_active_us
|
|||
status_code=500, detail="Airtable OAuth not configured."
|
||||
)
|
||||
|
||||
if not config.SECRET_KEY:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="SECRET_KEY not configured for OAuth security."
|
||||
)
|
||||
|
||||
# Generate PKCE parameters
|
||||
code_verifier, code_challenge = generate_pkce_pair()
|
||||
|
||||
# Generate state parameter
|
||||
state_payload = json.dumps(
|
||||
{
|
||||
"space_id": space_id,
|
||||
"user_id": str(user.id),
|
||||
"code_verifier": code_verifier,
|
||||
}
|
||||
# Generate secure state parameter with HMAC signature (including code_verifier for PKCE)
|
||||
state_manager = get_state_manager()
|
||||
state_encoded = state_manager.generate_secure_state(
|
||||
space_id, user.id, code_verifier=code_verifier
|
||||
)
|
||||
state_encoded = base64.urlsafe_b64encode(state_payload.encode()).decode()
|
||||
|
||||
# Build authorization URL
|
||||
auth_params = {
|
||||
|
|
@ -160,11 +185,12 @@ async def airtable_callback(
|
|||
space_id = None
|
||||
if state:
|
||||
try:
|
||||
decoded_state = base64.urlsafe_b64decode(state.encode()).decode()
|
||||
data = json.loads(decoded_state)
|
||||
state_manager = get_state_manager()
|
||||
data = state_manager.validate_state(state)
|
||||
space_id = data.get("space_id")
|
||||
except Exception:
|
||||
pass # If state is invalid, we'll redirect without space_id
|
||||
# If state is invalid, we'll redirect without space_id
|
||||
logger.warning("Failed to validate state in error handler")
|
||||
|
||||
# Redirect to frontend with error parameter
|
||||
if space_id:
|
||||
|
|
@ -185,11 +211,13 @@ async def airtable_callback(
|
|||
raise HTTPException(
|
||||
status_code=400, detail="Missing state parameter"
|
||||
)
|
||||
|
||||
# Decode and parse the state
|
||||
|
||||
# Validate and decode state with signature verification
|
||||
state_manager = get_state_manager()
|
||||
try:
|
||||
decoded_state = base64.urlsafe_b64decode(state.encode()).decode()
|
||||
data = json.loads(decoded_state)
|
||||
data = state_manager.validate_state(state)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Invalid state parameter: {e!s}"
|
||||
|
|
@ -197,7 +225,12 @@ async def airtable_callback(
|
|||
|
||||
user_id = UUID(data["user_id"])
|
||||
space_id = data["space_id"]
|
||||
code_verifier = data["code_verifier"]
|
||||
code_verifier = data.get("code_verifier")
|
||||
|
||||
if not code_verifier:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Missing code_verifier in state parameter"
|
||||
)
|
||||
auth_header = make_basic_auth_header(
|
||||
config.AIRTABLE_CLIENT_ID, config.AIRTABLE_CLIENT_SECRET
|
||||
)
|
||||
|
|
@ -236,21 +269,37 @@ async def airtable_callback(
|
|||
|
||||
token_json = token_response.json()
|
||||
|
||||
# Encrypt sensitive tokens before storing
|
||||
token_encryption = get_token_encryption()
|
||||
access_token = token_json.get("access_token")
|
||||
refresh_token = token_json.get("refresh_token")
|
||||
|
||||
if not access_token:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="No access token received from Airtable"
|
||||
)
|
||||
|
||||
# Calculate expiration time (UTC, tz-aware)
|
||||
expires_at = None
|
||||
if token_json.get("expires_in"):
|
||||
now_utc = datetime.now(UTC)
|
||||
expires_at = now_utc + timedelta(seconds=int(token_json["expires_in"]))
|
||||
|
||||
# Create credentials object
|
||||
# Create credentials object with encrypted tokens
|
||||
credentials = AirtableAuthCredentialsBase(
|
||||
access_token=token_json["access_token"],
|
||||
refresh_token=token_json.get("refresh_token"),
|
||||
access_token=token_encryption.encrypt_token(access_token),
|
||||
refresh_token=token_encryption.encrypt_token(refresh_token)
|
||||
if refresh_token
|
||||
else None,
|
||||
token_type=token_json.get("token_type", "Bearer"),
|
||||
expires_in=token_json.get("expires_in"),
|
||||
expires_at=expires_at,
|
||||
scope=token_json.get("scope"),
|
||||
)
|
||||
|
||||
# Mark that tokens are encrypted for backward compatibility
|
||||
credentials_dict = credentials.to_dict()
|
||||
credentials_dict["_token_encrypted"] = True
|
||||
|
||||
# Check if connector already exists for this search space and user
|
||||
existing_connector_result = await session.execute(
|
||||
|
|
@ -265,7 +314,7 @@ async def airtable_callback(
|
|||
|
||||
if existing_connector:
|
||||
# Update existing connector
|
||||
existing_connector.config = credentials.to_dict()
|
||||
existing_connector.config = credentials_dict
|
||||
existing_connector.name = "Airtable Connector"
|
||||
existing_connector.is_indexable = True
|
||||
logger.info(
|
||||
|
|
@ -277,7 +326,7 @@ async def airtable_callback(
|
|||
name="Airtable Connector",
|
||||
connector_type=SearchSourceConnectorType.AIRTABLE_CONNECTOR,
|
||||
is_indexable=True,
|
||||
config=credentials.to_dict(),
|
||||
config=credentials_dict,
|
||||
search_space_id=space_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
|
@ -341,6 +390,21 @@ async def refresh_airtable_token(
|
|||
logger.info(f"Refreshing Airtable token for connector {connector.id}")
|
||||
|
||||
credentials = AirtableAuthCredentialsBase.from_dict(connector.config)
|
||||
|
||||
# Decrypt tokens if they are encrypted
|
||||
token_encryption = get_token_encryption()
|
||||
is_encrypted = connector.config.get("_token_encrypted", False)
|
||||
|
||||
refresh_token = credentials.refresh_token
|
||||
if is_encrypted and refresh_token:
|
||||
try:
|
||||
refresh_token = token_encryption.decrypt_token(refresh_token)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to decrypt refresh token: {e!s}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Failed to decrypt stored refresh token"
|
||||
) from e
|
||||
|
||||
auth_header = make_basic_auth_header(
|
||||
config.AIRTABLE_CLIENT_ID, config.AIRTABLE_CLIENT_SECRET
|
||||
)
|
||||
|
|
@ -348,7 +412,7 @@ async def refresh_airtable_token(
|
|||
# Prepare token refresh data
|
||||
refresh_data = {
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": credentials.refresh_token,
|
||||
"refresh_token": refresh_token,
|
||||
"client_id": config.AIRTABLE_CLIENT_ID,
|
||||
"client_secret": config.AIRTABLE_CLIENT_SECRET,
|
||||
}
|
||||
|
|
@ -377,14 +441,27 @@ async def refresh_airtable_token(
|
|||
now_utc = datetime.now(UTC)
|
||||
expires_at = now_utc + timedelta(seconds=int(token_json["expires_in"]))
|
||||
|
||||
# Update credentials object
|
||||
credentials.access_token = token_json["access_token"]
|
||||
# Encrypt new tokens before storing
|
||||
access_token = token_json.get("access_token")
|
||||
new_refresh_token = token_json.get("refresh_token")
|
||||
|
||||
if not access_token:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="No access token received from Airtable refresh"
|
||||
)
|
||||
|
||||
# Update credentials object with encrypted tokens
|
||||
credentials.access_token = token_encryption.encrypt_token(access_token)
|
||||
if new_refresh_token:
|
||||
credentials.refresh_token = token_encryption.encrypt_token(new_refresh_token)
|
||||
credentials.expires_in = token_json.get("expires_in")
|
||||
credentials.expires_at = expires_at
|
||||
credentials.scope = token_json.get("scope")
|
||||
|
||||
# Update connector config
|
||||
connector.config = credentials.to_dict()
|
||||
# Update connector config with encrypted tokens
|
||||
credentials_dict = credentials.to_dict()
|
||||
credentials_dict["_token_encrypted"] = True
|
||||
connector.config = credentials_dict
|
||||
await session.commit()
|
||||
await session.refresh(connector)
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@ import os
|
|||
|
||||
os.environ["OAUTHLIB_RELAX_TOKEN_SCOPE"] = "1"
|
||||
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
from uuid import UUID
|
||||
|
|
@ -23,6 +22,7 @@ from app.db import (
|
|||
get_async_session,
|
||||
)
|
||||
from app.users import current_active_user
|
||||
from app.utils.oauth_security import OAuthStateManager, TokenEncryption
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -31,6 +31,30 @@ router = APIRouter()
|
|||
SCOPES = ["https://www.googleapis.com/auth/calendar.readonly"]
|
||||
REDIRECT_URI = config.GOOGLE_CALENDAR_REDIRECT_URI
|
||||
|
||||
# Initialize security utilities
|
||||
_state_manager = None
|
||||
_token_encryption = None
|
||||
|
||||
|
||||
def get_state_manager() -> OAuthStateManager:
|
||||
"""Get or create OAuth state manager instance."""
|
||||
global _state_manager
|
||||
if _state_manager is None:
|
||||
if not config.SECRET_KEY:
|
||||
raise ValueError("SECRET_KEY must be set for OAuth security")
|
||||
_state_manager = OAuthStateManager(config.SECRET_KEY)
|
||||
return _state_manager
|
||||
|
||||
|
||||
def get_token_encryption() -> TokenEncryption:
|
||||
"""Get or create token encryption instance."""
|
||||
global _token_encryption
|
||||
if _token_encryption is None:
|
||||
if not config.SECRET_KEY:
|
||||
raise ValueError("SECRET_KEY must be set for token encryption")
|
||||
_token_encryption = TokenEncryption(config.SECRET_KEY)
|
||||
return _token_encryption
|
||||
|
||||
|
||||
def get_google_flow():
|
||||
try:
|
||||
|
|
@ -59,16 +83,16 @@ async def connect_calendar(space_id: int, user: User = Depends(current_active_us
|
|||
if not space_id:
|
||||
raise HTTPException(status_code=400, detail="space_id is required")
|
||||
|
||||
if not config.SECRET_KEY:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="SECRET_KEY not configured for OAuth security."
|
||||
)
|
||||
|
||||
flow = get_google_flow()
|
||||
|
||||
# Encode space_id and user_id in state
|
||||
state_payload = json.dumps(
|
||||
{
|
||||
"space_id": space_id,
|
||||
"user_id": str(user.id),
|
||||
}
|
||||
)
|
||||
state_encoded = base64.urlsafe_b64encode(state_payload.encode()).decode()
|
||||
# Generate secure state parameter with HMAC signature
|
||||
state_manager = get_state_manager()
|
||||
state_encoded = state_manager.generate_secure_state(space_id, user.id)
|
||||
|
||||
auth_url, _ = flow.authorization_url(
|
||||
access_type="offline",
|
||||
|
|
@ -86,24 +110,90 @@ async def connect_calendar(space_id: int, user: User = Depends(current_active_us
|
|||
@router.get("/auth/google/calendar/connector/callback")
|
||||
async def calendar_callback(
|
||||
request: Request,
|
||||
code: str,
|
||||
state: str,
|
||||
code: str | None = None,
|
||||
error: str | None = None,
|
||||
state: str | None = None,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
try:
|
||||
# Decode and parse the state
|
||||
decoded_state = base64.urlsafe_b64decode(state.encode()).decode()
|
||||
data = json.loads(decoded_state)
|
||||
# Handle OAuth errors (e.g., user denied access)
|
||||
if error:
|
||||
logger.warning(f"Google Calendar OAuth error: {error}")
|
||||
# Try to decode state to get space_id for redirect, but don't fail if it's invalid
|
||||
space_id = None
|
||||
if state:
|
||||
try:
|
||||
state_manager = get_state_manager()
|
||||
data = state_manager.validate_state(state)
|
||||
space_id = data.get("space_id")
|
||||
except Exception:
|
||||
# If state is invalid, we'll redirect without space_id
|
||||
logger.warning("Failed to validate state in error handler")
|
||||
|
||||
# Redirect to frontend with error parameter
|
||||
if space_id:
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=google_calendar_oauth_denied"
|
||||
)
|
||||
else:
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard?error=google_calendar_oauth_denied"
|
||||
)
|
||||
|
||||
# Validate required parameters for successful flow
|
||||
if not code:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Missing authorization code"
|
||||
)
|
||||
if not state:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Missing state parameter"
|
||||
)
|
||||
|
||||
# Validate and decode state with signature verification
|
||||
state_manager = get_state_manager()
|
||||
try:
|
||||
data = state_manager.validate_state(state)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Invalid state parameter: {e!s}"
|
||||
) from e
|
||||
|
||||
user_id = UUID(data["user_id"])
|
||||
space_id = data["space_id"]
|
||||
|
||||
# Validate redirect URI (security: ensure it matches configured value)
|
||||
if not config.GOOGLE_CALENDAR_REDIRECT_URI:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="GOOGLE_CALENDAR_REDIRECT_URI not configured"
|
||||
)
|
||||
|
||||
flow = get_google_flow()
|
||||
flow.fetch_token(code=code)
|
||||
|
||||
creds = flow.credentials
|
||||
creds_dict = json.loads(creds.to_json())
|
||||
|
||||
# Encrypt sensitive credentials before storing
|
||||
token_encryption = get_token_encryption()
|
||||
|
||||
# Encrypt sensitive fields: token, refresh_token, client_secret
|
||||
if creds_dict.get("token"):
|
||||
creds_dict["token"] = token_encryption.encrypt_token(creds_dict["token"])
|
||||
if creds_dict.get("refresh_token"):
|
||||
creds_dict["refresh_token"] = token_encryption.encrypt_token(
|
||||
creds_dict["refresh_token"]
|
||||
)
|
||||
if creds_dict.get("client_secret"):
|
||||
creds_dict["client_secret"] = token_encryption.encrypt_token(
|
||||
creds_dict["client_secret"]
|
||||
)
|
||||
|
||||
# Mark that credentials are encrypted for backward compatibility
|
||||
creds_dict["_token_encrypted"] = True
|
||||
|
||||
try:
|
||||
# Check if a connector with the same type already exists for this search space and user
|
||||
result = await session.execute(
|
||||
|
|
|
|||
|
|
@ -10,7 +10,6 @@ Endpoints:
|
|||
- GET /connectors/{connector_id}/google-drive/folders - List user's folders (for index-time selection)
|
||||
"""
|
||||
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
|
|
@ -37,6 +36,7 @@ from app.db import (
|
|||
get_async_session,
|
||||
)
|
||||
from app.users import current_active_user
|
||||
from app.utils.oauth_security import OAuthStateManager, TokenEncryption
|
||||
|
||||
# Relax token scope validation for Google OAuth
|
||||
os.environ["OAUTHLIB_RELAX_TOKEN_SCOPE"] = "1"
|
||||
|
|
@ -44,6 +44,30 @@ os.environ["OAUTHLIB_RELAX_TOKEN_SCOPE"] = "1"
|
|||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
# Initialize security utilities
|
||||
_state_manager = None
|
||||
_token_encryption = None
|
||||
|
||||
|
||||
def get_state_manager() -> OAuthStateManager:
|
||||
"""Get or create OAuth state manager instance."""
|
||||
global _state_manager
|
||||
if _state_manager is None:
|
||||
if not config.SECRET_KEY:
|
||||
raise ValueError("SECRET_KEY must be set for OAuth security")
|
||||
_state_manager = OAuthStateManager(config.SECRET_KEY)
|
||||
return _state_manager
|
||||
|
||||
|
||||
def get_token_encryption() -> TokenEncryption:
|
||||
"""Get or create token encryption instance."""
|
||||
global _token_encryption
|
||||
if _token_encryption is None:
|
||||
if not config.SECRET_KEY:
|
||||
raise ValueError("SECRET_KEY must be set for token encryption")
|
||||
_token_encryption = TokenEncryption(config.SECRET_KEY)
|
||||
return _token_encryption
|
||||
|
||||
# Google Drive OAuth scopes
|
||||
SCOPES = [
|
||||
"https://www.googleapis.com/auth/drive.readonly", # Read-only access to Drive
|
||||
|
|
@ -90,16 +114,16 @@ async def connect_drive(space_id: int, user: User = Depends(current_active_user)
|
|||
if not space_id:
|
||||
raise HTTPException(status_code=400, detail="space_id is required")
|
||||
|
||||
if not config.SECRET_KEY:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="SECRET_KEY not configured for OAuth security."
|
||||
)
|
||||
|
||||
flow = get_google_flow()
|
||||
|
||||
# Encode space_id and user_id in state parameter
|
||||
state_payload = json.dumps(
|
||||
{
|
||||
"space_id": space_id,
|
||||
"user_id": str(user.id),
|
||||
}
|
||||
)
|
||||
state_encoded = base64.urlsafe_b64encode(state_payload.encode()).decode()
|
||||
# Generate secure state parameter with HMAC signature
|
||||
state_manager = get_state_manager()
|
||||
state_encoded = state_manager.generate_secure_state(space_id, user.id)
|
||||
|
||||
# Generate authorization URL
|
||||
auth_url, _ = flow.authorization_url(
|
||||
|
|
@ -124,8 +148,9 @@ async def connect_drive(space_id: int, user: User = Depends(current_active_user)
|
|||
@router.get("/auth/google/drive/connector/callback")
|
||||
async def drive_callback(
|
||||
request: Request,
|
||||
code: str,
|
||||
state: str,
|
||||
code: str | None = None,
|
||||
error: str | None = None,
|
||||
state: str | None = None,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
"""
|
||||
|
|
@ -133,15 +158,57 @@ async def drive_callback(
|
|||
|
||||
Query params:
|
||||
code: Authorization code from Google
|
||||
error: OAuth error (if user denied access)
|
||||
state: Encoded state with space_id and user_id
|
||||
|
||||
Returns:
|
||||
Redirect to frontend success page
|
||||
"""
|
||||
try:
|
||||
# Decode and parse state
|
||||
decoded_state = base64.urlsafe_b64decode(state.encode()).decode()
|
||||
data = json.loads(decoded_state)
|
||||
# Handle OAuth errors (e.g., user denied access)
|
||||
if error:
|
||||
logger.warning(f"Google Drive OAuth error: {error}")
|
||||
# Try to decode state to get space_id for redirect, but don't fail if it's invalid
|
||||
space_id = None
|
||||
if state:
|
||||
try:
|
||||
state_manager = get_state_manager()
|
||||
data = state_manager.validate_state(state)
|
||||
space_id = data.get("space_id")
|
||||
except Exception:
|
||||
# If state is invalid, we'll redirect without space_id
|
||||
logger.warning("Failed to validate state in error handler")
|
||||
|
||||
# Redirect to frontend with error parameter
|
||||
if space_id:
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=google_drive_oauth_denied"
|
||||
)
|
||||
else:
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard?error=google_drive_oauth_denied"
|
||||
)
|
||||
|
||||
# Validate required parameters for successful flow
|
||||
if not code:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Missing authorization code"
|
||||
)
|
||||
if not state:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Missing state parameter"
|
||||
)
|
||||
|
||||
# Validate and decode state with signature verification
|
||||
state_manager = get_state_manager()
|
||||
try:
|
||||
data = state_manager.validate_state(state)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Invalid state parameter: {e!s}"
|
||||
) from e
|
||||
|
||||
user_id = UUID(data["user_id"])
|
||||
space_id = data["space_id"]
|
||||
|
|
@ -150,6 +217,12 @@ async def drive_callback(
|
|||
f"Processing Google Drive callback for user {user_id}, space {space_id}"
|
||||
)
|
||||
|
||||
# Validate redirect URI (security: ensure it matches configured value)
|
||||
if not config.GOOGLE_DRIVE_REDIRECT_URI:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="GOOGLE_DRIVE_REDIRECT_URI not configured"
|
||||
)
|
||||
|
||||
# Exchange authorization code for tokens
|
||||
flow = get_google_flow()
|
||||
flow.fetch_token(code=code)
|
||||
|
|
@ -157,6 +230,24 @@ async def drive_callback(
|
|||
creds = flow.credentials
|
||||
creds_dict = json.loads(creds.to_json())
|
||||
|
||||
# Encrypt sensitive credentials before storing
|
||||
token_encryption = get_token_encryption()
|
||||
|
||||
# Encrypt sensitive fields: token, refresh_token, client_secret
|
||||
if creds_dict.get("token"):
|
||||
creds_dict["token"] = token_encryption.encrypt_token(creds_dict["token"])
|
||||
if creds_dict.get("refresh_token"):
|
||||
creds_dict["refresh_token"] = token_encryption.encrypt_token(
|
||||
creds_dict["refresh_token"]
|
||||
)
|
||||
if creds_dict.get("client_secret"):
|
||||
creds_dict["client_secret"] = token_encryption.encrypt_token(
|
||||
creds_dict["client_secret"]
|
||||
)
|
||||
|
||||
# Mark that credentials are encrypted for backward compatibility
|
||||
creds_dict["_token_encrypted"] = True
|
||||
|
||||
# Check if connector already exists for this space/user
|
||||
result = await session.execute(
|
||||
select(SearchSourceConnector).filter(
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@ import os
|
|||
|
||||
os.environ["OAUTHLIB_RELAX_TOKEN_SCOPE"] = "1"
|
||||
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
from uuid import UUID
|
||||
|
|
@ -23,51 +22,90 @@ from app.db import (
|
|||
get_async_session,
|
||||
)
|
||||
from app.users import current_active_user
|
||||
from app.utils.oauth_security import OAuthStateManager, TokenEncryption
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# Initialize security utilities
|
||||
_state_manager = None
|
||||
_token_encryption = None
|
||||
|
||||
|
||||
def get_state_manager() -> OAuthStateManager:
|
||||
"""Get or create OAuth state manager instance."""
|
||||
global _state_manager
|
||||
if _state_manager is None:
|
||||
if not config.SECRET_KEY:
|
||||
raise ValueError("SECRET_KEY must be set for OAuth security")
|
||||
_state_manager = OAuthStateManager(config.SECRET_KEY)
|
||||
return _state_manager
|
||||
|
||||
|
||||
def get_token_encryption() -> TokenEncryption:
|
||||
"""Get or create token encryption instance."""
|
||||
global _token_encryption
|
||||
if _token_encryption is None:
|
||||
if not config.SECRET_KEY:
|
||||
raise ValueError("SECRET_KEY must be set for token encryption")
|
||||
_token_encryption = TokenEncryption(config.SECRET_KEY)
|
||||
return _token_encryption
|
||||
|
||||
|
||||
def get_google_flow():
|
||||
"""Create and return a Google OAuth flow for Gmail API."""
|
||||
flow = Flow.from_client_config(
|
||||
{
|
||||
"web": {
|
||||
"client_id": config.GOOGLE_OAUTH_CLIENT_ID,
|
||||
"client_secret": config.GOOGLE_OAUTH_CLIENT_SECRET,
|
||||
"auth_uri": "https://accounts.google.com/o/oauth2/auth",
|
||||
"token_uri": "https://oauth2.googleapis.com/token",
|
||||
"redirect_uris": [config.GOOGLE_GMAIL_REDIRECT_URI],
|
||||
}
|
||||
},
|
||||
scopes=[
|
||||
"https://www.googleapis.com/auth/gmail.readonly",
|
||||
"https://www.googleapis.com/auth/userinfo.email",
|
||||
"https://www.googleapis.com/auth/userinfo.profile",
|
||||
"openid",
|
||||
],
|
||||
)
|
||||
flow.redirect_uri = config.GOOGLE_GMAIL_REDIRECT_URI
|
||||
return flow
|
||||
try:
|
||||
flow = Flow.from_client_config(
|
||||
{
|
||||
"web": {
|
||||
"client_id": config.GOOGLE_OAUTH_CLIENT_ID,
|
||||
"client_secret": config.GOOGLE_OAUTH_CLIENT_SECRET,
|
||||
"auth_uri": "https://accounts.google.com/o/oauth2/auth",
|
||||
"token_uri": "https://oauth2.googleapis.com/token",
|
||||
"redirect_uris": [config.GOOGLE_GMAIL_REDIRECT_URI],
|
||||
}
|
||||
},
|
||||
scopes=[
|
||||
"https://www.googleapis.com/auth/gmail.readonly",
|
||||
"https://www.googleapis.com/auth/userinfo.email",
|
||||
"https://www.googleapis.com/auth/userinfo.profile",
|
||||
"openid",
|
||||
],
|
||||
)
|
||||
flow.redirect_uri = config.GOOGLE_GMAIL_REDIRECT_URI
|
||||
return flow
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to create Google flow: {e!s}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.get("/auth/google/gmail/connector/add")
|
||||
async def connect_gmail(space_id: int, user: User = Depends(current_active_user)):
|
||||
"""
|
||||
Initiate Google Gmail OAuth flow.
|
||||
|
||||
Query params:
|
||||
space_id: Search space ID to add connector to
|
||||
|
||||
Returns:
|
||||
JSON with auth_url to redirect user to Google authorization
|
||||
"""
|
||||
try:
|
||||
if not space_id:
|
||||
raise HTTPException(status_code=400, detail="space_id is required")
|
||||
|
||||
if not config.SECRET_KEY:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="SECRET_KEY not configured for OAuth security."
|
||||
)
|
||||
|
||||
flow = get_google_flow()
|
||||
|
||||
# Encode space_id and user_id in state
|
||||
state_payload = json.dumps(
|
||||
{
|
||||
"space_id": space_id,
|
||||
"user_id": str(user.id),
|
||||
}
|
||||
)
|
||||
state_encoded = base64.urlsafe_b64encode(state_payload.encode()).decode()
|
||||
# Generate secure state parameter with HMAC signature
|
||||
state_manager = get_state_manager()
|
||||
state_encoded = state_manager.generate_secure_state(space_id, user.id)
|
||||
|
||||
auth_url, _ = flow.authorization_url(
|
||||
access_type="offline",
|
||||
|
|
@ -75,8 +113,13 @@ async def connect_gmail(space_id: int, user: User = Depends(current_active_user)
|
|||
include_granted_scopes="true",
|
||||
state=state_encoded,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Initiating Google Gmail OAuth for user {user.id}, space {space_id}"
|
||||
)
|
||||
return {"auth_url": auth_url}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initiate Google Gmail OAuth: {e!s}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to initiate Google OAuth: {e!s}"
|
||||
) from e
|
||||
|
|
@ -85,24 +128,103 @@ async def connect_gmail(space_id: int, user: User = Depends(current_active_user)
|
|||
@router.get("/auth/google/gmail/connector/callback")
|
||||
async def gmail_callback(
|
||||
request: Request,
|
||||
code: str,
|
||||
state: str,
|
||||
code: str | None = None,
|
||||
error: str | None = None,
|
||||
state: str | None = None,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
"""
|
||||
Handle Google Gmail OAuth callback.
|
||||
|
||||
Args:
|
||||
request: FastAPI request object
|
||||
code: Authorization code from Google (if user granted access)
|
||||
error: Error code from Google (if user denied access or error occurred)
|
||||
state: State parameter containing user/space info
|
||||
session: Database session
|
||||
|
||||
Returns:
|
||||
Redirect response to frontend
|
||||
"""
|
||||
try:
|
||||
# Decode and parse the state
|
||||
decoded_state = base64.urlsafe_b64decode(state.encode()).decode()
|
||||
data = json.loads(decoded_state)
|
||||
# Handle OAuth errors (e.g., user denied access)
|
||||
if error:
|
||||
logger.warning(f"Google Gmail OAuth error: {error}")
|
||||
# Try to decode state to get space_id for redirect, but don't fail if it's invalid
|
||||
space_id = None
|
||||
if state:
|
||||
try:
|
||||
state_manager = get_state_manager()
|
||||
data = state_manager.validate_state(state)
|
||||
space_id = data.get("space_id")
|
||||
except Exception:
|
||||
# If state is invalid, we'll redirect without space_id
|
||||
logger.warning("Failed to validate state in error handler")
|
||||
|
||||
# Redirect to frontend with error parameter
|
||||
if space_id:
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/new-chat?modal=connectors&tab=all&error=google_gmail_oauth_denied"
|
||||
)
|
||||
else:
|
||||
return RedirectResponse(
|
||||
url=f"{config.NEXT_FRONTEND_URL}/dashboard?error=google_gmail_oauth_denied"
|
||||
)
|
||||
|
||||
# Validate required parameters for successful flow
|
||||
if not code:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Missing authorization code"
|
||||
)
|
||||
if not state:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Missing state parameter"
|
||||
)
|
||||
|
||||
# Validate and decode state with signature verification
|
||||
state_manager = get_state_manager()
|
||||
try:
|
||||
data = state_manager.validate_state(state)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Invalid state parameter: {e!s}"
|
||||
) from e
|
||||
|
||||
user_id = UUID(data["user_id"])
|
||||
space_id = data["space_id"]
|
||||
|
||||
# Validate redirect URI (security: ensure it matches configured value)
|
||||
if not config.GOOGLE_GMAIL_REDIRECT_URI:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="GOOGLE_GMAIL_REDIRECT_URI not configured"
|
||||
)
|
||||
|
||||
flow = get_google_flow()
|
||||
flow.fetch_token(code=code)
|
||||
|
||||
creds = flow.credentials
|
||||
creds_dict = json.loads(creds.to_json())
|
||||
|
||||
# Encrypt sensitive credentials before storing
|
||||
token_encryption = get_token_encryption()
|
||||
|
||||
# Encrypt sensitive fields: token, refresh_token, client_secret
|
||||
if creds_dict.get("token"):
|
||||
creds_dict["token"] = token_encryption.encrypt_token(creds_dict["token"])
|
||||
if creds_dict.get("refresh_token"):
|
||||
creds_dict["refresh_token"] = token_encryption.encrypt_token(
|
||||
creds_dict["refresh_token"]
|
||||
)
|
||||
if creds_dict.get("client_secret"):
|
||||
creds_dict["client_secret"] = token_encryption.encrypt_token(
|
||||
creds_dict["client_secret"]
|
||||
)
|
||||
|
||||
# Mark that credentials are encrypted for backward compatibility
|
||||
creds_dict["_token_encrypted"] = True
|
||||
|
||||
try:
|
||||
# Check if a connector with the same type already exists for this search space and user
|
||||
result = await session.execute(
|
||||
|
|
@ -160,3 +282,6 @@ async def gmail_callback(
|
|||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in Gmail callback: {e!s}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to complete Google Gmail OAuth: {e!s}"
|
||||
) from e
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@ Linear Connector OAuth Routes.
|
|||
Handles OAuth 2.0 authentication flow for Linear connector.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
|
@ -26,6 +25,7 @@ from app.db import (
|
|||
get_async_session,
|
||||
)
|
||||
from app.users import current_active_user
|
||||
from app.utils.oauth_security import OAuthStateManager, TokenEncryption
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -38,9 +38,35 @@ TOKEN_URL = "https://api.linear.app/oauth/token"
|
|||
# OAuth scopes for Linear
|
||||
SCOPES = ["read", "write"]
|
||||
|
||||
# Initialize security utilities
|
||||
_state_manager = None
|
||||
_token_encryption = None
|
||||
|
||||
|
||||
def get_state_manager() -> OAuthStateManager:
|
||||
"""Get or create OAuth state manager instance."""
|
||||
global _state_manager
|
||||
if _state_manager is None:
|
||||
if not config.SECRET_KEY:
|
||||
raise ValueError("SECRET_KEY must be set for OAuth security")
|
||||
_state_manager = OAuthStateManager(config.SECRET_KEY)
|
||||
return _state_manager
|
||||
|
||||
|
||||
def get_token_encryption() -> TokenEncryption:
|
||||
"""Get or create token encryption instance."""
|
||||
global _token_encryption
|
||||
if _token_encryption is None:
|
||||
if not config.SECRET_KEY:
|
||||
raise ValueError("SECRET_KEY must be set for token encryption")
|
||||
_token_encryption = TokenEncryption(config.SECRET_KEY)
|
||||
return _token_encryption
|
||||
|
||||
|
||||
def make_basic_auth_header(client_id: str, client_secret: str) -> str:
|
||||
"""Create Basic Auth header for Linear OAuth."""
|
||||
import base64
|
||||
|
||||
credentials = f"{client_id}:{client_secret}".encode()
|
||||
b64 = base64.b64encode(credentials).decode("ascii")
|
||||
return f"Basic {b64}"
|
||||
|
|
@ -67,14 +93,14 @@ async def connect_linear(space_id: int, user: User = Depends(current_active_user
|
|||
status_code=500, detail="Linear OAuth not configured."
|
||||
)
|
||||
|
||||
# Generate state parameter
|
||||
state_payload = json.dumps(
|
||||
{
|
||||
"space_id": space_id,
|
||||
"user_id": str(user.id),
|
||||
}
|
||||
)
|
||||
state_encoded = base64.urlsafe_b64encode(state_payload.encode()).decode()
|
||||
if not config.SECRET_KEY:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="SECRET_KEY not configured for OAuth security."
|
||||
)
|
||||
|
||||
# Generate secure state parameter with HMAC signature
|
||||
state_manager = get_state_manager()
|
||||
state_encoded = state_manager.generate_secure_state(space_id, user.id)
|
||||
|
||||
# Build authorization URL
|
||||
from urllib.parse import urlencode
|
||||
|
|
@ -130,11 +156,12 @@ async def linear_callback(
|
|||
space_id = None
|
||||
if state:
|
||||
try:
|
||||
decoded_state = base64.urlsafe_b64decode(state.encode()).decode()
|
||||
data = json.loads(decoded_state)
|
||||
state_manager = get_state_manager()
|
||||
data = state_manager.validate_state(state)
|
||||
space_id = data.get("space_id")
|
||||
except Exception:
|
||||
pass # If state is invalid, we'll redirect without space_id
|
||||
# If state is invalid, we'll redirect without space_id
|
||||
logger.warning("Failed to validate state in error handler")
|
||||
|
||||
# Redirect to frontend with error parameter
|
||||
if space_id:
|
||||
|
|
@ -155,11 +182,13 @@ async def linear_callback(
|
|||
raise HTTPException(
|
||||
status_code=400, detail="Missing state parameter"
|
||||
)
|
||||
|
||||
# Decode and parse the state
|
||||
|
||||
# Validate and decode state with signature verification
|
||||
state_manager = get_state_manager()
|
||||
try:
|
||||
decoded_state = base64.urlsafe_b64decode(state.encode()).decode()
|
||||
data = json.loads(decoded_state)
|
||||
data = state_manager.validate_state(state)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Invalid state parameter: {e!s}"
|
||||
|
|
@ -168,6 +197,12 @@ async def linear_callback(
|
|||
user_id = UUID(data["user_id"])
|
||||
space_id = data["space_id"]
|
||||
|
||||
# Validate redirect URI (security: ensure it matches configured value)
|
||||
if not config.LINEAR_REDIRECT_URI:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="LINEAR_REDIRECT_URI not configured"
|
||||
)
|
||||
|
||||
# Exchange authorization code for access token
|
||||
auth_header = make_basic_auth_header(
|
||||
config.LINEAR_CLIENT_ID, config.LINEAR_CLIENT_SECRET
|
||||
|
|
@ -176,7 +211,7 @@ async def linear_callback(
|
|||
token_data = {
|
||||
"grant_type": "authorization_code",
|
||||
"code": code,
|
||||
"redirect_uri": config.LINEAR_REDIRECT_URI,
|
||||
"redirect_uri": config.LINEAR_REDIRECT_URI, # Use stored value, not from request
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
|
|
@ -203,20 +238,34 @@ async def linear_callback(
|
|||
|
||||
token_json = token_response.json()
|
||||
|
||||
# Encrypt sensitive tokens before storing
|
||||
token_encryption = get_token_encryption()
|
||||
access_token = token_json.get("access_token")
|
||||
refresh_token = token_json.get("refresh_token")
|
||||
|
||||
if not access_token:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="No access token received from Linear"
|
||||
)
|
||||
|
||||
# Calculate expiration time (UTC, tz-aware)
|
||||
expires_at = None
|
||||
if token_json.get("expires_in"):
|
||||
now_utc = datetime.now(UTC)
|
||||
expires_at = now_utc + timedelta(seconds=int(token_json["expires_in"]))
|
||||
|
||||
# Store the access token and refresh token in connector config
|
||||
# Store the encrypted access token and refresh token in connector config
|
||||
connector_config = {
|
||||
"access_token": token_json["access_token"],
|
||||
"refresh_token": token_json.get("refresh_token"),
|
||||
"access_token": token_encryption.encrypt_token(access_token),
|
||||
"refresh_token": token_encryption.encrypt_token(refresh_token)
|
||||
if refresh_token
|
||||
else None,
|
||||
"token_type": token_json.get("token_type", "Bearer"),
|
||||
"expires_in": token_json.get("expires_in"),
|
||||
"expires_at": expires_at.isoformat() if expires_at else None,
|
||||
"scope": token_json.get("scope"),
|
||||
# Mark that tokens are encrypted for backward compatibility
|
||||
"_token_encrypted": True,
|
||||
}
|
||||
|
||||
# Check if connector already exists for this search space and user
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@ Notion Connector OAuth Routes.
|
|||
Handles OAuth 2.0 authentication flow for Notion connector.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
from uuid import UUID
|
||||
|
|
@ -25,6 +24,7 @@ from app.db import (
|
|||
get_async_session,
|
||||
)
|
||||
from app.users import current_active_user
|
||||
from app.utils.oauth_security import OAuthStateManager, TokenEncryption
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -34,9 +34,35 @@ router = APIRouter()
|
|||
AUTHORIZATION_URL = "https://api.notion.com/v1/oauth/authorize"
|
||||
TOKEN_URL = "https://api.notion.com/v1/oauth/token"
|
||||
|
||||
# Initialize security utilities
|
||||
_state_manager = None
|
||||
_token_encryption = None
|
||||
|
||||
|
||||
def get_state_manager() -> OAuthStateManager:
|
||||
"""Get or create OAuth state manager instance."""
|
||||
global _state_manager
|
||||
if _state_manager is None:
|
||||
if not config.SECRET_KEY:
|
||||
raise ValueError("SECRET_KEY must be set for OAuth security")
|
||||
_state_manager = OAuthStateManager(config.SECRET_KEY)
|
||||
return _state_manager
|
||||
|
||||
|
||||
def get_token_encryption() -> TokenEncryption:
|
||||
"""Get or create token encryption instance."""
|
||||
global _token_encryption
|
||||
if _token_encryption is None:
|
||||
if not config.SECRET_KEY:
|
||||
raise ValueError("SECRET_KEY must be set for token encryption")
|
||||
_token_encryption = TokenEncryption(config.SECRET_KEY)
|
||||
return _token_encryption
|
||||
|
||||
|
||||
def make_basic_auth_header(client_id: str, client_secret: str) -> str:
|
||||
"""Create Basic Auth header for Notion OAuth."""
|
||||
import base64
|
||||
|
||||
credentials = f"{client_id}:{client_secret}".encode()
|
||||
b64 = base64.b64encode(credentials).decode("ascii")
|
||||
return f"Basic {b64}"
|
||||
|
|
@ -63,14 +89,14 @@ async def connect_notion(space_id: int, user: User = Depends(current_active_user
|
|||
status_code=500, detail="Notion OAuth not configured."
|
||||
)
|
||||
|
||||
# Generate state parameter
|
||||
state_payload = json.dumps(
|
||||
{
|
||||
"space_id": space_id,
|
||||
"user_id": str(user.id),
|
||||
}
|
||||
)
|
||||
state_encoded = base64.urlsafe_b64encode(state_payload.encode()).decode()
|
||||
if not config.SECRET_KEY:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="SECRET_KEY not configured for OAuth security."
|
||||
)
|
||||
|
||||
# Generate secure state parameter with HMAC signature
|
||||
state_manager = get_state_manager()
|
||||
state_encoded = state_manager.generate_secure_state(space_id, user.id)
|
||||
|
||||
# Build authorization URL
|
||||
from urllib.parse import urlencode
|
||||
|
|
@ -126,11 +152,12 @@ async def notion_callback(
|
|||
space_id = None
|
||||
if state:
|
||||
try:
|
||||
decoded_state = base64.urlsafe_b64decode(state.encode()).decode()
|
||||
data = json.loads(decoded_state)
|
||||
state_manager = get_state_manager()
|
||||
data = state_manager.validate_state(state)
|
||||
space_id = data.get("space_id")
|
||||
except Exception:
|
||||
pass # If state is invalid, we'll redirect without space_id
|
||||
# If state is invalid, we'll redirect without space_id
|
||||
logger.warning("Failed to validate state in error handler")
|
||||
|
||||
# Redirect to frontend with error parameter
|
||||
if space_id:
|
||||
|
|
@ -151,11 +178,13 @@ async def notion_callback(
|
|||
raise HTTPException(
|
||||
status_code=400, detail="Missing state parameter"
|
||||
)
|
||||
|
||||
# Decode and parse the state
|
||||
|
||||
# Validate and decode state with signature verification
|
||||
state_manager = get_state_manager()
|
||||
try:
|
||||
decoded_state = base64.urlsafe_b64decode(state.encode()).decode()
|
||||
data = json.loads(decoded_state)
|
||||
data = state_manager.validate_state(state)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Invalid state parameter: {e!s}"
|
||||
|
|
@ -164,6 +193,14 @@ async def notion_callback(
|
|||
user_id = UUID(data["user_id"])
|
||||
space_id = data["space_id"]
|
||||
|
||||
# Validate redirect URI (security: ensure it matches configured value)
|
||||
# Note: Notion doesn't send redirect_uri in callback, but we validate
|
||||
# that we're using the configured one in token exchange
|
||||
if not config.NOTION_REDIRECT_URI:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="NOTION_REDIRECT_URI not configured"
|
||||
)
|
||||
|
||||
# Exchange authorization code for access token
|
||||
auth_header = make_basic_auth_header(
|
||||
config.NOTION_CLIENT_ID, config.NOTION_CLIENT_SECRET
|
||||
|
|
@ -172,7 +209,7 @@ async def notion_callback(
|
|||
token_data = {
|
||||
"grant_type": "authorization_code",
|
||||
"code": code,
|
||||
"redirect_uri": config.NOTION_REDIRECT_URI,
|
||||
"redirect_uri": config.NOTION_REDIRECT_URI, # Use stored value, not from request
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
|
|
@ -199,14 +236,24 @@ async def notion_callback(
|
|||
|
||||
token_json = token_response.json()
|
||||
|
||||
# Encrypt sensitive tokens before storing
|
||||
token_encryption = get_token_encryption()
|
||||
access_token = token_json.get("access_token")
|
||||
if not access_token:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="No access token received from Notion"
|
||||
)
|
||||
|
||||
# Notion returns access_token and workspace information
|
||||
# Store the access token and workspace info in connector config
|
||||
# Store the encrypted access token and workspace info in connector config
|
||||
connector_config = {
|
||||
"access_token": token_json["access_token"],
|
||||
"access_token": token_encryption.encrypt_token(access_token),
|
||||
"workspace_id": token_json.get("workspace_id"),
|
||||
"workspace_name": token_json.get("workspace_name"),
|
||||
"workspace_icon": token_json.get("workspace_icon"),
|
||||
"bot_id": token_json.get("bot_id"),
|
||||
# Mark that token is encrypted for backward compatibility
|
||||
"_token_encrypted": True,
|
||||
}
|
||||
|
||||
# Check if connector already exists for this search space and user
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from app.routes.airtable_add_connector_route import refresh_airtable_token
|
|||
from app.schemas.airtable_auth_credentials import AirtableAuthCredentialsBase
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
from app.services.task_logging_service import TaskLoggingService
|
||||
from app.utils.oauth_security import TokenEncryption
|
||||
from app.utils.document_converters import (
|
||||
create_document_chunks,
|
||||
generate_content_hash,
|
||||
|
|
@ -85,7 +86,38 @@ async def index_airtable_records(
|
|||
return 0, f"Connector with ID {connector_id} not found"
|
||||
|
||||
# Create credentials from connector config
|
||||
config_data = connector.config
|
||||
config_data = connector.config.copy() # Work with a copy to avoid modifying original
|
||||
|
||||
# Decrypt tokens if they are encrypted (for backward compatibility)
|
||||
token_encrypted = config_data.get("_token_encrypted", False)
|
||||
if token_encrypted and config.SECRET_KEY:
|
||||
try:
|
||||
token_encryption = TokenEncryption(config.SECRET_KEY)
|
||||
|
||||
# Decrypt access_token
|
||||
if config_data.get("access_token"):
|
||||
if token_encryption.is_encrypted(config_data["access_token"]):
|
||||
config_data["access_token"] = token_encryption.decrypt_token(
|
||||
config_data["access_token"]
|
||||
)
|
||||
logger.info(f"Decrypted Airtable access token for connector {connector_id}")
|
||||
|
||||
# Decrypt refresh_token if present
|
||||
if config_data.get("refresh_token"):
|
||||
if token_encryption.is_encrypted(config_data["refresh_token"]):
|
||||
config_data["refresh_token"] = token_encryption.decrypt_token(
|
||||
config_data["refresh_token"]
|
||||
)
|
||||
logger.info(f"Decrypted Airtable refresh token for connector {connector_id}")
|
||||
except Exception as e:
|
||||
await task_logger.log_task_failure(
|
||||
log_entry,
|
||||
f"Failed to decrypt Airtable tokens for connector {connector_id}: {e!s}",
|
||||
"Token decryption failed",
|
||||
{"error_type": "TokenDecryptionError"},
|
||||
)
|
||||
return 0, f"Failed to decrypt Airtable tokens: {e!s}"
|
||||
|
||||
try:
|
||||
credentials = AirtableAuthCredentialsBase.from_dict(config_data)
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -84,15 +84,52 @@ async def index_google_calendar_events(
|
|||
return 0, f"Connector with ID {connector_id} not found"
|
||||
|
||||
# Get the Google Calendar credentials from the connector config
|
||||
exp = connector.config.get("expiry").replace("Z", "")
|
||||
config_data = connector.config
|
||||
|
||||
# Decrypt sensitive credentials if encrypted (for backward compatibility)
|
||||
from app.config import config
|
||||
from app.utils.oauth_security import TokenEncryption
|
||||
|
||||
token_encrypted = config_data.get("_token_encrypted", False)
|
||||
if token_encrypted and config.SECRET_KEY:
|
||||
try:
|
||||
token_encryption = TokenEncryption(config.SECRET_KEY)
|
||||
|
||||
# Decrypt sensitive fields
|
||||
if config_data.get("token"):
|
||||
config_data["token"] = token_encryption.decrypt_token(
|
||||
config_data["token"]
|
||||
)
|
||||
if config_data.get("refresh_token"):
|
||||
config_data["refresh_token"] = token_encryption.decrypt_token(
|
||||
config_data["refresh_token"]
|
||||
)
|
||||
if config_data.get("client_secret"):
|
||||
config_data["client_secret"] = token_encryption.decrypt_token(
|
||||
config_data["client_secret"]
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Decrypted Google Calendar credentials for connector {connector_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
await task_logger.log_task_failure(
|
||||
log_entry,
|
||||
f"Failed to decrypt Google Calendar credentials for connector {connector_id}: {e!s}",
|
||||
"Credential decryption failed",
|
||||
{"error_type": "CredentialDecryptionError"},
|
||||
)
|
||||
return 0, f"Failed to decrypt Google Calendar credentials: {e!s}"
|
||||
|
||||
exp = config_data.get("expiry", "").replace("Z", "")
|
||||
credentials = Credentials(
|
||||
token=connector.config.get("token"),
|
||||
refresh_token=connector.config.get("refresh_token"),
|
||||
token_uri=connector.config.get("token_uri"),
|
||||
client_id=connector.config.get("client_id"),
|
||||
client_secret=connector.config.get("client_secret"),
|
||||
scopes=connector.config.get("scopes"),
|
||||
expiry=datetime.fromisoformat(exp),
|
||||
token=config_data.get("token"),
|
||||
refresh_token=config_data.get("refresh_token"),
|
||||
token_uri=config_data.get("token_uri"),
|
||||
client_id=config_data.get("client_id"),
|
||||
client_secret=config_data.get("client_secret"),
|
||||
scopes=config_data.get("scopes"),
|
||||
expiry=datetime.fromisoformat(exp) if exp else None,
|
||||
)
|
||||
|
||||
if (
|
||||
|
|
@ -122,6 +159,12 @@ async def index_google_calendar_events(
|
|||
connector_id=connector_id,
|
||||
)
|
||||
|
||||
# Handle 'undefined' string from frontend (treat as None)
|
||||
if start_date == "undefined" or start_date == "":
|
||||
start_date = None
|
||||
if end_date == "undefined" or end_date == "":
|
||||
end_date = None
|
||||
|
||||
# Calculate date range
|
||||
if start_date is None or end_date is None:
|
||||
# Fall back to calculating dates based on last_indexed_at
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ import logging
|
|||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import config
|
||||
from app.connectors.google_drive import (
|
||||
GoogleDriveClient,
|
||||
categorize_change,
|
||||
|
|
@ -22,6 +23,7 @@ from app.tasks.connector_indexers.base import (
|
|||
update_connector_last_indexed,
|
||||
)
|
||||
from app.utils.document_converters import generate_unique_identifier_hash
|
||||
from app.utils.oauth_security import TokenEncryption
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -87,6 +89,28 @@ async def index_google_drive_files(
|
|||
{"stage": "client_initialization"},
|
||||
)
|
||||
|
||||
# Check if credentials are encrypted and validate decryption capability
|
||||
token_encrypted = connector.config.get("_token_encrypted", False)
|
||||
if token_encrypted and config.SECRET_KEY:
|
||||
try:
|
||||
# Verify we can decrypt credentials before proceeding
|
||||
token_encryption = TokenEncryption(config.SECRET_KEY)
|
||||
# Check if any sensitive fields exist and are encrypted
|
||||
if connector.config.get("token") and token_encryption.is_encrypted(
|
||||
connector.config.get("token")
|
||||
):
|
||||
logger.info(
|
||||
f"Google Drive credentials are encrypted for connector {connector_id}, will decrypt during client initialization"
|
||||
)
|
||||
except Exception as e:
|
||||
await task_logger.log_task_failure(
|
||||
log_entry,
|
||||
f"Failed to initialize token decryption for Google Drive connector {connector_id}: {e!s}",
|
||||
"Token decryption initialization failed",
|
||||
{"error_type": "TokenDecryptionError"},
|
||||
)
|
||||
return 0, f"Failed to initialize token decryption: {e!s}"
|
||||
|
||||
drive_client = GoogleDriveClient(session, connector_id)
|
||||
|
||||
if not folder_id:
|
||||
|
|
@ -249,6 +273,28 @@ async def index_google_drive_single_file(
|
|||
{"stage": "client_initialization"},
|
||||
)
|
||||
|
||||
# Check if credentials are encrypted and validate decryption capability
|
||||
token_encrypted = connector.config.get("_token_encrypted", False)
|
||||
if token_encrypted and config.SECRET_KEY:
|
||||
try:
|
||||
# Verify we can decrypt credentials before proceeding
|
||||
token_encryption = TokenEncryption(config.SECRET_KEY)
|
||||
# Check if any sensitive fields exist and are encrypted
|
||||
if connector.config.get("token") and token_encryption.is_encrypted(
|
||||
connector.config.get("token")
|
||||
):
|
||||
logger.info(
|
||||
f"Google Drive credentials are encrypted for connector {connector_id}, will decrypt during client initialization"
|
||||
)
|
||||
except Exception as e:
|
||||
await task_logger.log_task_failure(
|
||||
log_entry,
|
||||
f"Failed to initialize token decryption for Google Drive connector {connector_id}: {e!s}",
|
||||
"Token decryption initialization failed",
|
||||
{"error_type": "TokenDecryptionError"},
|
||||
)
|
||||
return 0, f"Failed to initialize token decryption: {e!s}"
|
||||
|
||||
drive_client = GoogleDriveClient(session, connector_id)
|
||||
|
||||
# Fetch the file metadata
|
||||
|
|
|
|||
|
|
@ -88,9 +88,47 @@ async def index_google_gmail_messages(
|
|||
)
|
||||
return 0, error_msg
|
||||
|
||||
# Create credentials from connector config
|
||||
# Get the Google Gmail credentials from the connector config
|
||||
config_data = connector.config
|
||||
exp = config_data.get("expiry").replace("Z", "")
|
||||
|
||||
# Decrypt sensitive credentials if encrypted (for backward compatibility)
|
||||
from app.config import config
|
||||
from app.utils.oauth_security import TokenEncryption
|
||||
|
||||
token_encrypted = config_data.get("_token_encrypted", False)
|
||||
if token_encrypted and config.SECRET_KEY:
|
||||
try:
|
||||
token_encryption = TokenEncryption(config.SECRET_KEY)
|
||||
|
||||
# Decrypt sensitive fields
|
||||
if config_data.get("token"):
|
||||
config_data["token"] = token_encryption.decrypt_token(
|
||||
config_data["token"]
|
||||
)
|
||||
if config_data.get("refresh_token"):
|
||||
config_data["refresh_token"] = token_encryption.decrypt_token(
|
||||
config_data["refresh_token"]
|
||||
)
|
||||
if config_data.get("client_secret"):
|
||||
config_data["client_secret"] = token_encryption.decrypt_token(
|
||||
config_data["client_secret"]
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Decrypted Google Gmail credentials for connector {connector_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
await task_logger.log_task_failure(
|
||||
log_entry,
|
||||
f"Failed to decrypt Google Gmail credentials for connector {connector_id}: {e!s}",
|
||||
"Credential decryption failed",
|
||||
{"error_type": "CredentialDecryptionError"},
|
||||
)
|
||||
return 0, f"Failed to decrypt Google Gmail credentials: {e!s}"
|
||||
|
||||
exp = config_data.get("expiry", "")
|
||||
if exp:
|
||||
exp = exp.replace("Z", "")
|
||||
credentials = Credentials(
|
||||
token=config_data.get("token"),
|
||||
refresh_token=config_data.get("refresh_token"),
|
||||
|
|
@ -98,7 +136,7 @@ async def index_google_gmail_messages(
|
|||
client_id=config_data.get("client_id"),
|
||||
client_secret=config_data.get("client_secret"),
|
||||
scopes=config_data.get("scopes", []),
|
||||
expiry=datetime.fromisoformat(exp),
|
||||
expiry=datetime.fromisoformat(exp) if exp else None,
|
||||
)
|
||||
|
||||
if (
|
||||
|
|
|
|||
|
|
@ -103,6 +103,30 @@ async def index_linear_issues(
|
|||
)
|
||||
return 0, "Linear access token not found in connector config"
|
||||
|
||||
# Decrypt token if it's encrypted (for backward compatibility)
|
||||
from app.config import config
|
||||
from app.utils.oauth_security import TokenEncryption
|
||||
|
||||
token_encrypted = connector.config.get("_token_encrypted", False)
|
||||
if token_encrypted or (
|
||||
config.SECRET_KEY
|
||||
and TokenEncryption(config.SECRET_KEY).is_encrypted(linear_access_token)
|
||||
):
|
||||
try:
|
||||
token_encryption = TokenEncryption(config.SECRET_KEY)
|
||||
linear_access_token = token_encryption.decrypt_token(linear_access_token)
|
||||
logger.info(
|
||||
f"Decrypted Linear access token for connector {connector_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
await task_logger.log_task_failure(
|
||||
log_entry,
|
||||
f"Failed to decrypt Linear access token for connector {connector_id}: {e!s}",
|
||||
"Token decryption failed",
|
||||
{"error_type": "TokenDecryptionError"},
|
||||
)
|
||||
return 0, f"Failed to decrypt Linear access token: {e!s}"
|
||||
|
||||
# Initialize Linear client
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
|
|
@ -112,6 +136,12 @@ async def index_linear_issues(
|
|||
|
||||
linear_client = LinearConnector(access_token=linear_access_token)
|
||||
|
||||
# Handle 'undefined' string from frontend (treat as None)
|
||||
if start_date == "undefined" or start_date == "":
|
||||
start_date = None
|
||||
if end_date == "undefined" or end_date == "":
|
||||
end_date = None
|
||||
|
||||
# Calculate date range
|
||||
start_date_str, end_date_str = calculate_date_range(
|
||||
connector, start_date, end_date, default_days_back=365
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from datetime import datetime
|
|||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import config
|
||||
from app.connectors.notion_history import NotionHistoryConnector
|
||||
from app.db import Document, DocumentType, SearchSourceConnectorType
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
|
|
@ -17,6 +18,7 @@ from app.utils.document_converters import (
|
|||
generate_document_summary,
|
||||
generate_unique_identifier_hash,
|
||||
)
|
||||
from app.utils.oauth_security import TokenEncryption
|
||||
|
||||
from .base import (
|
||||
build_document_metadata_string,
|
||||
|
|
@ -103,6 +105,22 @@ async def index_notion_pages(
|
|||
)
|
||||
return 0, "Notion access token not found in connector config"
|
||||
|
||||
# Decrypt token if it's encrypted (for backward compatibility)
|
||||
token_encrypted = connector.config.get("_token_encrypted", False)
|
||||
if token_encrypted or (config.SECRET_KEY and TokenEncryption(config.SECRET_KEY).is_encrypted(notion_token)):
|
||||
try:
|
||||
token_encryption = TokenEncryption(config.SECRET_KEY)
|
||||
notion_token = token_encryption.decrypt_token(notion_token)
|
||||
logger.info(f"Decrypted Notion access token for connector {connector_id}")
|
||||
except Exception as e:
|
||||
await task_logger.log_task_failure(
|
||||
log_entry,
|
||||
f"Failed to decrypt Notion access token for connector {connector_id}: {e!s}",
|
||||
"Token decryption failed",
|
||||
{"error_type": "TokenDecryptionError"},
|
||||
)
|
||||
return 0, f"Failed to decrypt Notion access token: {e!s}"
|
||||
|
||||
# Initialize Notion client
|
||||
await task_logger.log_task_progress(
|
||||
log_entry,
|
||||
|
|
|
|||
216
surfsense_backend/app/utils/oauth_security.py
Normal file
216
surfsense_backend/app/utils/oauth_security.py
Normal file
|
|
@ -0,0 +1,216 @@
|
|||
"""
|
||||
OAuth Security Utilities.
|
||||
|
||||
Provides secure state parameter generation/validation and token encryption
|
||||
for OAuth 2.0 flows.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from uuid import UUID
|
||||
|
||||
from cryptography.fernet import Fernet
|
||||
from fastapi import HTTPException
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OAuthStateManager:
|
||||
"""Manages secure OAuth state parameters with HMAC signatures."""
|
||||
|
||||
def __init__(self, secret_key: str, max_age_seconds: int = 600):
|
||||
"""
|
||||
Initialize OAuth state manager.
|
||||
|
||||
Args:
|
||||
secret_key: Secret key for HMAC signing (should be SECRET_KEY from config)
|
||||
max_age_seconds: Maximum age of state parameter in seconds (default 10 minutes)
|
||||
"""
|
||||
if not secret_key:
|
||||
raise ValueError("secret_key is required for OAuth state management")
|
||||
self.secret_key = secret_key
|
||||
self.max_age_seconds = max_age_seconds
|
||||
|
||||
def generate_secure_state(self, space_id: int, user_id: UUID, **extra_fields) -> str:
|
||||
"""
|
||||
Generate cryptographically signed state parameter.
|
||||
|
||||
Args:
|
||||
space_id: The search space ID
|
||||
user_id: The user ID
|
||||
**extra_fields: Additional fields to include in state (e.g., code_verifier for PKCE)
|
||||
|
||||
Returns:
|
||||
Base64-encoded state parameter with HMAC signature
|
||||
"""
|
||||
timestamp = int(time.time())
|
||||
state_payload = {
|
||||
"space_id": space_id,
|
||||
"user_id": str(user_id),
|
||||
"timestamp": timestamp,
|
||||
}
|
||||
|
||||
# Add any extra fields (e.g., code_verifier for PKCE)
|
||||
state_payload.update(extra_fields)
|
||||
|
||||
# Create signature
|
||||
payload_str = json.dumps(state_payload, sort_keys=True)
|
||||
signature = hmac.new(
|
||||
self.secret_key.encode(),
|
||||
payload_str.encode(),
|
||||
hashlib.sha256,
|
||||
).hexdigest()
|
||||
|
||||
# Include signature in state
|
||||
state_payload["signature"] = signature
|
||||
state_encoded = base64.urlsafe_b64encode(
|
||||
json.dumps(state_payload).encode()
|
||||
).decode()
|
||||
|
||||
return state_encoded
|
||||
|
||||
def validate_state(self, state: str) -> dict:
|
||||
"""
|
||||
Validate and decode state parameter with signature verification.
|
||||
|
||||
Args:
|
||||
state: The state parameter from OAuth callback
|
||||
|
||||
Returns:
|
||||
Decoded state data (space_id, user_id, timestamp)
|
||||
|
||||
Raises:
|
||||
HTTPException: If state is invalid, expired, or tampered with
|
||||
"""
|
||||
try:
|
||||
decoded = base64.urlsafe_b64decode(state.encode()).decode()
|
||||
data = json.loads(decoded)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Invalid state format: {e!s}"
|
||||
) from e
|
||||
|
||||
# Verify signature exists
|
||||
signature = data.pop("signature", None)
|
||||
if not signature:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Missing state signature"
|
||||
)
|
||||
|
||||
# Verify signature
|
||||
payload_str = json.dumps(data, sort_keys=True)
|
||||
expected_signature = hmac.new(
|
||||
self.secret_key.encode(),
|
||||
payload_str.encode(),
|
||||
hashlib.sha256,
|
||||
).hexdigest()
|
||||
|
||||
if not hmac.compare_digest(signature, expected_signature):
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Invalid state signature - possible tampering"
|
||||
)
|
||||
|
||||
# Verify timestamp (prevent replay attacks)
|
||||
timestamp = data.get("timestamp", 0)
|
||||
current_time = time.time()
|
||||
age = current_time - timestamp
|
||||
|
||||
if age < 0:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Invalid state timestamp"
|
||||
)
|
||||
|
||||
if age > self.max_age_seconds:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="State parameter expired. Please try again.",
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
|
||||
class TokenEncryption:
|
||||
"""Encrypt/decrypt sensitive OAuth tokens for storage."""
|
||||
|
||||
def __init__(self, secret_key: str):
|
||||
"""
|
||||
Initialize token encryption.
|
||||
|
||||
Args:
|
||||
secret_key: Secret key for encryption (should be SECRET_KEY from config)
|
||||
"""
|
||||
if not secret_key:
|
||||
raise ValueError("secret_key is required for token encryption")
|
||||
# Derive Fernet key from secret using SHA256
|
||||
# Note: In production, consider using HKDF for key derivation
|
||||
key = base64.urlsafe_b64encode(
|
||||
hashlib.sha256(secret_key.encode()).digest()
|
||||
)
|
||||
try:
|
||||
self.cipher = Fernet(key)
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
f"Failed to initialize encryption cipher: {e!s}"
|
||||
) from e
|
||||
|
||||
def encrypt_token(self, token: str) -> str:
|
||||
"""
|
||||
Encrypt a token for storage.
|
||||
|
||||
Args:
|
||||
token: Plaintext token to encrypt
|
||||
|
||||
Returns:
|
||||
Encrypted token string
|
||||
"""
|
||||
if not token:
|
||||
return token
|
||||
try:
|
||||
return self.cipher.encrypt(token.encode()).decode()
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to encrypt token: {e!s}")
|
||||
raise ValueError(f"Token encryption failed: {e!s}") from e
|
||||
|
||||
def decrypt_token(self, encrypted_token: str) -> str:
|
||||
"""
|
||||
Decrypt a stored token.
|
||||
|
||||
Args:
|
||||
encrypted_token: Encrypted token string
|
||||
|
||||
Returns:
|
||||
Decrypted plaintext token
|
||||
"""
|
||||
if not encrypted_token:
|
||||
return encrypted_token
|
||||
try:
|
||||
return self.cipher.decrypt(encrypted_token.encode()).decode()
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to decrypt token: {e!s}")
|
||||
raise ValueError(f"Token decryption failed: {e!s}") from e
|
||||
|
||||
def is_encrypted(self, token: str) -> bool:
|
||||
"""
|
||||
Check if a token appears to be encrypted.
|
||||
|
||||
Args:
|
||||
token: Token string to check
|
||||
|
||||
Returns:
|
||||
True if token appears encrypted, False otherwise
|
||||
"""
|
||||
if not token:
|
||||
return False
|
||||
# Encrypted tokens are base64-encoded and have specific format
|
||||
# This is a heuristic check - encrypted tokens are longer and base64-like
|
||||
try:
|
||||
# Try to decode as base64
|
||||
base64.urlsafe_b64decode(token.encode())
|
||||
# If it's base64 and reasonably long, likely encrypted
|
||||
return len(token) > 20
|
||||
except Exception:
|
||||
return False
|
||||
|
|
@ -514,16 +514,6 @@ def validate_connector_config(
|
|||
"validators": {},
|
||||
},
|
||||
"SLACK_CONNECTOR": {"required": ["SLACK_BOT_TOKEN"], "validators": {}},
|
||||
"NOTION_CONNECTOR": {
|
||||
"required": ["access_token"], # OAuth-based only
|
||||
"optional": [
|
||||
"workspace_id", # OAuth fields
|
||||
"workspace_name",
|
||||
"workspace_icon",
|
||||
"bot_id",
|
||||
],
|
||||
"validators": {},
|
||||
},
|
||||
"GITHUB_CONNECTOR": {
|
||||
"required": ["GITHUB_PAT", "repo_full_names"],
|
||||
"validators": {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue