diff --git a/surfsense_backend/alembic/versions/110_add_onedrive_connector_enums.py b/surfsense_backend/alembic/versions/110_add_onedrive_connector_enums.py new file mode 100644 index 000000000..699a50ef0 --- /dev/null +++ b/surfsense_backend/alembic/versions/110_add_onedrive_connector_enums.py @@ -0,0 +1,54 @@ +"""Add OneDrive connector enums + +Revision ID: 110 +Revises: 109 +Create Date: 2026-03-28 00:00:00.000000 + +""" + +from collections.abc import Sequence + +from alembic import op + +revision: str = "110" +down_revision: str | None = "109" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + op.execute( + """ + DO $$ + BEGIN + IF NOT EXISTS ( + SELECT 1 FROM pg_type t + JOIN pg_enum e ON t.oid = e.enumtypid + WHERE t.typname = 'searchsourceconnectortype' AND e.enumlabel = 'ONEDRIVE_CONNECTOR' + ) THEN + ALTER TYPE searchsourceconnectortype ADD VALUE 'ONEDRIVE_CONNECTOR'; + END IF; + END + $$; + """ + ) + + op.execute( + """ + DO $$ + BEGIN + IF NOT EXISTS ( + SELECT 1 FROM pg_type t + JOIN pg_enum e ON t.oid = e.enumtypid + WHERE t.typname = 'documenttype' AND e.enumlabel = 'ONEDRIVE_FILE' + ) THEN + ALTER TYPE documenttype ADD VALUE 'ONEDRIVE_FILE'; + END IF; + END + $$; + """ + ) + + +def downgrade() -> None: + pass diff --git a/surfsense_backend/app/agents/new_chat/tools/knowledge_base.py b/surfsense_backend/app/agents/new_chat/tools/knowledge_base.py index 429dafc46..d30288390 100644 --- a/surfsense_backend/app/agents/new_chat/tools/knowledge_base.py +++ b/surfsense_backend/app/agents/new_chat/tools/knowledge_base.py @@ -360,6 +360,7 @@ _INTERNAL_METADATA_KEYS: frozenset[str] = frozenset( "event_id", "calendar_id", "google_drive_file_id", + "onedrive_file_id", "page_id", "issue_id", "connector_id", diff --git a/surfsense_backend/app/config/__init__.py b/surfsense_backend/app/config/__init__.py index 186936325..70100bd0a 100644 --- a/surfsense_backend/app/config/__init__.py +++ b/surfsense_backend/app/config/__init__.py @@ -286,6 +286,11 @@ class Config: TEAMS_CLIENT_SECRET = os.getenv("TEAMS_CLIENT_SECRET") TEAMS_REDIRECT_URI = os.getenv("TEAMS_REDIRECT_URI") + # Microsoft OneDrive OAuth + ONEDRIVE_CLIENT_ID = os.getenv("ONEDRIVE_CLIENT_ID") + ONEDRIVE_CLIENT_SECRET = os.getenv("ONEDRIVE_CLIENT_SECRET") + ONEDRIVE_REDIRECT_URI = os.getenv("ONEDRIVE_REDIRECT_URI") + # ClickUp OAuth CLICKUP_CLIENT_ID = os.getenv("CLICKUP_CLIENT_ID") CLICKUP_CLIENT_SECRET = os.getenv("CLICKUP_CLIENT_SECRET") diff --git a/surfsense_backend/app/connectors/onedrive/__init__.py b/surfsense_backend/app/connectors/onedrive/__init__.py new file mode 100644 index 000000000..91b28bd37 --- /dev/null +++ b/surfsense_backend/app/connectors/onedrive/__init__.py @@ -0,0 +1,13 @@ +"""Microsoft OneDrive Connector Module.""" + +from .client import OneDriveClient +from .content_extractor import download_and_extract_content +from .folder_manager import get_file_by_id, get_files_in_folder, list_folder_contents + +__all__ = [ + "OneDriveClient", + "download_and_extract_content", + "get_file_by_id", + "get_files_in_folder", + "list_folder_contents", +] diff --git a/surfsense_backend/app/connectors/onedrive/client.py b/surfsense_backend/app/connectors/onedrive/client.py new file mode 100644 index 000000000..bb9fbb42b --- /dev/null +++ b/surfsense_backend/app/connectors/onedrive/client.py @@ -0,0 +1,276 @@ +"""Microsoft OneDrive API client using Microsoft Graph API v1.0.""" + +import logging +from datetime import UTC, datetime, timedelta +from typing import Any + +import httpx +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.future import select +from sqlalchemy.orm.attributes import flag_modified + +from app.config import config +from app.db import SearchSourceConnector +from app.utils.oauth_security import TokenEncryption + +logger = logging.getLogger(__name__) + +GRAPH_API_BASE = "https://graph.microsoft.com/v1.0" +TOKEN_URL = "https://login.microsoftonline.com/common/oauth2/v2.0/token" + + +class OneDriveClient: + """Client for Microsoft OneDrive via the Graph API.""" + + def __init__(self, session: AsyncSession, connector_id: int): + self._session = session + self._connector_id = connector_id + + async def _get_valid_token(self) -> str: + """Get a valid access token, refreshing if needed.""" + result = await self._session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == self._connector_id + ) + ) + connector = result.scalars().first() + if not connector: + raise ValueError(f"Connector {self._connector_id} not found") + + cfg = connector.config or {} + is_encrypted = cfg.get("_token_encrypted", False) + token_encryption = TokenEncryption(config.SECRET_KEY) if config.SECRET_KEY else None + + access_token = cfg.get("access_token", "") + refresh_token = cfg.get("refresh_token") + + if is_encrypted and token_encryption: + if access_token: + access_token = token_encryption.decrypt_token(access_token) + if refresh_token: + refresh_token = token_encryption.decrypt_token(refresh_token) + + expires_at_str = cfg.get("expires_at") + is_expired = False + if expires_at_str: + expires_at = datetime.fromisoformat(expires_at_str) + if expires_at.tzinfo is None: + expires_at = expires_at.replace(tzinfo=UTC) + is_expired = expires_at <= datetime.now(UTC) + + if not is_expired and access_token: + return access_token + + if not refresh_token: + cfg["auth_expired"] = True + connector.config = cfg + flag_modified(connector, "config") + await self._session.commit() + raise ValueError("OneDrive token expired and no refresh token available") + + token_data = await self._refresh_token(refresh_token) + + new_access = token_data["access_token"] + new_refresh = token_data.get("refresh_token", refresh_token) + expires_in = token_data.get("expires_in") + + new_expires_at = None + if expires_in: + new_expires_at = datetime.now(UTC) + timedelta(seconds=int(expires_in)) + + if token_encryption: + cfg["access_token"] = token_encryption.encrypt_token(new_access) + cfg["refresh_token"] = token_encryption.encrypt_token(new_refresh) + else: + cfg["access_token"] = new_access + cfg["refresh_token"] = new_refresh + + cfg["expires_at"] = new_expires_at.isoformat() if new_expires_at else None + cfg["expires_in"] = expires_in + cfg["_token_encrypted"] = bool(token_encryption) + cfg.pop("auth_expired", None) + + connector.config = cfg + flag_modified(connector, "config") + await self._session.commit() + + return new_access + + async def _refresh_token(self, refresh_token: str) -> dict: + data = { + "client_id": config.ONEDRIVE_CLIENT_ID, + "client_secret": config.ONEDRIVE_CLIENT_SECRET, + "grant_type": "refresh_token", + "refresh_token": refresh_token, + "scope": "offline_access User.Read Files.Read.All Files.ReadWrite.All", + } + async with httpx.AsyncClient() as client: + resp = await client.post( + TOKEN_URL, + data=data, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + timeout=30.0, + ) + if resp.status_code != 200: + error_detail = resp.text + try: + error_json = resp.json() + error_detail = error_json.get("error_description", error_detail) + except Exception: + pass + raise ValueError(f"OneDrive token refresh failed: {error_detail}") + return resp.json() + + async def _request(self, method: str, path: str, **kwargs) -> httpx.Response: + """Make an authenticated request to the Graph API.""" + token = await self._get_valid_token() + headers = {"Authorization": f"Bearer {token}"} + if "headers" in kwargs: + headers.update(kwargs.pop("headers")) + + async with httpx.AsyncClient() as client: + resp = await client.request( + method, + f"{GRAPH_API_BASE}{path}", + headers=headers, + timeout=60.0, + **kwargs, + ) + + if resp.status_code == 401: + result = await self._session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == self._connector_id + ) + ) + connector = result.scalars().first() + if connector: + cfg = connector.config or {} + cfg["auth_expired"] = True + connector.config = cfg + flag_modified(connector, "config") + await self._session.commit() + raise ValueError("OneDrive authentication expired (401)") + + return resp + + async def list_children( + self, item_id: str = "root" + ) -> tuple[list[dict[str, Any]], str | None]: + all_items: list[dict[str, Any]] = [] + url = f"/me/drive/items/{item_id}/children" + params: dict[str, Any] = { + "$top": 200, + "$select": "id,name,size,file,folder,parentReference,lastModifiedDateTime,createdDateTime,webUrl,remoteItem,package", + } + while url: + resp = await self._request("GET", url, params=params) + if resp.status_code != 200: + return [], f"Failed to list children: {resp.status_code} - {resp.text}" + data = resp.json() + all_items.extend(data.get("value", [])) + next_link = data.get("@odata.nextLink") + if next_link: + url = next_link.replace(GRAPH_API_BASE, "") + params = {} + else: + url = "" + return all_items, None + + async def get_item_metadata( + self, item_id: str + ) -> tuple[dict[str, Any] | None, str | None]: + resp = await self._request( + "GET", + f"/me/drive/items/{item_id}", + params={ + "$select": "id,name,size,file,folder,parentReference,lastModifiedDateTime,createdDateTime,webUrl" + }, + ) + if resp.status_code != 200: + return None, f"Failed to get item: {resp.status_code} - {resp.text}" + return resp.json(), None + + async def download_file(self, item_id: str) -> tuple[bytes | None, str | None]: + token = await self._get_valid_token() + async with httpx.AsyncClient(follow_redirects=True) as client: + resp = await client.get( + f"{GRAPH_API_BASE}/me/drive/items/{item_id}/content", + headers={"Authorization": f"Bearer {token}"}, + timeout=120.0, + ) + if resp.status_code != 200: + return None, f"Download failed: {resp.status_code}" + return resp.content, None + + async def download_file_to_disk(self, item_id: str, dest_path: str) -> str | None: + """Stream file content to disk. Returns error message on failure.""" + token = await self._get_valid_token() + async with httpx.AsyncClient(follow_redirects=True) as client: + async with client.stream( + "GET", + f"{GRAPH_API_BASE}/me/drive/items/{item_id}/content", + headers={"Authorization": f"Bearer {token}"}, + timeout=120.0, + ) as resp: + if resp.status_code != 200: + return f"Download failed: {resp.status_code}" + with open(dest_path, "wb") as f: + async for chunk in resp.aiter_bytes(chunk_size=5 * 1024 * 1024): + f.write(chunk) + return None + + async def create_file( + self, + name: str, + parent_id: str | None = None, + content: str | None = None, + mime_type: str | None = None, + ) -> dict[str, Any]: + """Create (upload) a file in OneDrive.""" + folder_path = f"/me/drive/items/{parent_id or 'root'}" + body = (content or "").encode("utf-8") + resp = await self._request( + "PUT", + f"{folder_path}:/{name}:/content", + content=body, + headers={"Content-Type": mime_type or "application/octet-stream"}, + ) + if resp.status_code not in (200, 201): + raise ValueError(f"File creation failed: {resp.status_code} - {resp.text}") + return resp.json() + + async def trash_file(self, item_id: str) -> bool: + """Delete (move to recycle bin) a OneDrive item.""" + resp = await self._request("DELETE", f"/me/drive/items/{item_id}") + if resp.status_code not in (200, 204): + raise ValueError(f"Trash failed: {resp.status_code} - {resp.text}") + return True + + async def get_delta( + self, folder_id: str | None = None, delta_link: str | None = None + ) -> tuple[list[dict[str, Any]], str | None, str | None]: + """Get delta changes. Returns (changes, new_delta_link, error).""" + all_changes: list[dict[str, Any]] = [] + if delta_link: + url = delta_link.replace(GRAPH_API_BASE, "") + elif folder_id: + url = f"/me/drive/items/{folder_id}/delta" + else: + url = "/me/drive/root/delta" + + params: dict[str, Any] = {"$top": 200} + while url: + resp = await self._request("GET", url, params=params) + if resp.status_code != 200: + return [], None, f"Delta failed: {resp.status_code} - {resp.text}" + data = resp.json() + all_changes.extend(data.get("value", [])) + next_link = data.get("@odata.nextLink") + new_delta_link = data.get("@odata.deltaLink") + if next_link: + url = next_link.replace(GRAPH_API_BASE, "") + params = {} + else: + url = "" + return all_changes, new_delta_link, None diff --git a/surfsense_backend/app/connectors/onedrive/content_extractor.py b/surfsense_backend/app/connectors/onedrive/content_extractor.py new file mode 100644 index 000000000..109a8cb15 --- /dev/null +++ b/surfsense_backend/app/connectors/onedrive/content_extractor.py @@ -0,0 +1,169 @@ +"""Content extraction for OneDrive files. + +Reuses the same ETL parsing logic as Google Drive since file parsing is +extension-based, not provider-specific. +""" + +import asyncio +import logging +import os +import tempfile +import threading +import time +from pathlib import Path +from typing import Any + +from .client import OneDriveClient +from .file_types import get_extension_from_mime, should_skip_file + +logger = logging.getLogger(__name__) + + +async def download_and_extract_content( + client: OneDriveClient, + file: dict[str, Any], +) -> tuple[str | None, dict[str, Any], str | None]: + """Download a OneDrive file and extract its content as markdown. + + Returns (markdown_content, onedrive_metadata, error_message). + """ + item_id = file.get("id") + file_name = file.get("name", "Unknown") + + if should_skip_file(file): + return None, {}, "Skipping non-indexable item" + + file_info = file.get("file", {}) + mime_type = file_info.get("mimeType", "") + + logger.info(f"Downloading file for content extraction: {file_name} ({mime_type})") + + metadata: dict[str, Any] = { + "onedrive_file_id": item_id, + "onedrive_file_name": file_name, + "onedrive_mime_type": mime_type, + "source_connector": "onedrive", + } + if "lastModifiedDateTime" in file: + metadata["modified_time"] = file["lastModifiedDateTime"] + if "createdDateTime" in file: + metadata["created_time"] = file["createdDateTime"] + if "size" in file: + metadata["file_size"] = file["size"] + if "webUrl" in file: + metadata["web_url"] = file["webUrl"] + file_hashes = file_info.get("hashes", {}) + if file_hashes.get("sha256Hash"): + metadata["sha256_hash"] = file_hashes["sha256Hash"] + elif file_hashes.get("quickXorHash"): + metadata["quick_xor_hash"] = file_hashes["quickXorHash"] + + temp_file_path = None + try: + extension = Path(file_name).suffix or get_extension_from_mime(mime_type) or ".bin" + with tempfile.NamedTemporaryFile(delete=False, suffix=extension) as tmp: + temp_file_path = tmp.name + + error = await client.download_file_to_disk(item_id, temp_file_path) + if error: + return None, metadata, error + + markdown = await _parse_file_to_markdown(temp_file_path, file_name) + return markdown, metadata, None + + except Exception as e: + logger.warning(f"Failed to extract content from {file_name}: {e!s}") + return None, metadata, str(e) + finally: + if temp_file_path and os.path.exists(temp_file_path): + try: + os.unlink(temp_file_path) + except Exception: + pass + + +async def _parse_file_to_markdown(file_path: str, filename: str) -> str: + """Parse a local file to markdown using the configured ETL service. + + Same logic as Google Drive -- file parsing is extension-based. + """ + lower = filename.lower() + + if lower.endswith((".md", ".markdown", ".txt")): + with open(file_path, encoding="utf-8") as f: + return f.read() + + if lower.endswith((".mp3", ".mp4", ".mpeg", ".mpga", ".m4a", ".wav", ".webm")): + from app.config import config as app_config + from litellm import atranscription + + stt_service_type = ( + "local" + if app_config.STT_SERVICE and app_config.STT_SERVICE.startswith("local/") + else "external" + ) + if stt_service_type == "local": + from app.services.stt_service import stt_service + + t0 = time.monotonic() + logger.info(f"[local-stt] START file={filename} thread={threading.current_thread().name}") + result = await asyncio.to_thread(stt_service.transcribe_file, file_path) + logger.info(f"[local-stt] END file={filename} elapsed={time.monotonic() - t0:.2f}s") + text = result.get("text", "") + else: + with open(file_path, "rb") as audio_file: + kwargs: dict[str, Any] = { + "model": app_config.STT_SERVICE, + "file": audio_file, + "api_key": app_config.STT_SERVICE_API_KEY, + } + if app_config.STT_SERVICE_API_BASE: + kwargs["api_base"] = app_config.STT_SERVICE_API_BASE + resp = await atranscription(**kwargs) + text = resp.get("text", "") + + if not text: + raise ValueError("Transcription returned empty text") + return f"# Transcription of {filename}\n\n{text}" + + from app.config import config as app_config + + if app_config.ETL_SERVICE == "UNSTRUCTURED": + from langchain_unstructured import UnstructuredLoader + + from app.utils.document_converters import convert_document_to_markdown + + loader = UnstructuredLoader( + file_path, + mode="elements", + post_processors=[], + languages=["eng"], + include_orig_elements=False, + include_metadata=False, + strategy="auto", + ) + docs = await loader.aload() + return await convert_document_to_markdown(docs) + + if app_config.ETL_SERVICE == "LLAMACLOUD": + from app.tasks.document_processors.file_processors import ( + parse_with_llamacloud_retry, + ) + + result = await parse_with_llamacloud_retry(file_path=file_path, estimated_pages=50) + markdown_documents = await result.aget_markdown_documents(split_by_page=False) + if not markdown_documents: + raise RuntimeError(f"LlamaCloud returned no documents for {filename}") + return markdown_documents[0].text + + if app_config.ETL_SERVICE == "DOCLING": + from docling.document_converter import DocumentConverter + + converter = DocumentConverter() + t0 = time.monotonic() + logger.info(f"[docling] START file={filename} thread={threading.current_thread().name}") + result = await asyncio.to_thread(converter.convert, file_path) + logger.info(f"[docling] END file={filename} elapsed={time.monotonic() - t0:.2f}s") + return result.document.export_to_markdown() + + raise RuntimeError(f"Unknown ETL_SERVICE: {app_config.ETL_SERVICE}") diff --git a/surfsense_backend/app/connectors/onedrive/file_types.py b/surfsense_backend/app/connectors/onedrive/file_types.py new file mode 100644 index 000000000..403fdc337 --- /dev/null +++ b/surfsense_backend/app/connectors/onedrive/file_types.py @@ -0,0 +1,50 @@ +"""File type handlers for Microsoft OneDrive.""" + +ONEDRIVE_FOLDER_FACET = "folder" +ONENOTE_MIME = "application/msonenote" + +SKIP_MIME_TYPES = frozenset( + { + ONENOTE_MIME, + "application/vnd.ms-onenotesection", + "application/vnd.ms-onenotenotebook", + } +) + +MIME_TO_EXTENSION: dict[str, str] = { + "application/pdf": ".pdf", + "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": ".xlsx", + "application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx", + "application/vnd.openxmlformats-officedocument.presentationml.presentation": ".pptx", + "application/vnd.ms-excel": ".xls", + "application/msword": ".doc", + "application/vnd.ms-powerpoint": ".ppt", + "text/plain": ".txt", + "text/csv": ".csv", + "text/html": ".html", + "text/markdown": ".md", + "application/json": ".json", + "application/xml": ".xml", + "image/png": ".png", + "image/jpeg": ".jpg", +} + + +def get_extension_from_mime(mime_type: str) -> str | None: + return MIME_TO_EXTENSION.get(mime_type) + + +def is_folder(item: dict) -> bool: + return ONEDRIVE_FOLDER_FACET in item + + +def should_skip_file(item: dict) -> bool: + """Skip folders, OneNote files, remote items (shared links), and packages.""" + if is_folder(item): + return True + if "remoteItem" in item: + return True + if "package" in item: + return True + mime = item.get("file", {}).get("mimeType", "") + return mime in SKIP_MIME_TYPES diff --git a/surfsense_backend/app/connectors/onedrive/folder_manager.py b/surfsense_backend/app/connectors/onedrive/folder_manager.py new file mode 100644 index 000000000..ad04e12ff --- /dev/null +++ b/surfsense_backend/app/connectors/onedrive/folder_manager.py @@ -0,0 +1,90 @@ +"""Folder management for Microsoft OneDrive.""" + +import logging +from typing import Any + +from .client import OneDriveClient +from .file_types import is_folder, should_skip_file + +logger = logging.getLogger(__name__) + + +async def list_folder_contents( + client: OneDriveClient, + parent_id: str | None = None, +) -> tuple[list[dict[str, Any]], str | None]: + """List folders and files in a OneDrive folder. + + Returns (items list with folders first, error message). + """ + try: + items, error = await client.list_children(parent_id or "root") + if error: + return [], error + + for item in items: + item["isFolder"] = is_folder(item) + + items.sort(key=lambda x: (not x["isFolder"], x.get("name", "").lower())) + + folder_count = sum(1 for item in items if item["isFolder"]) + file_count = len(items) - folder_count + logger.info( + f"Listed {len(items)} items ({folder_count} folders, {file_count} files) " + + (f"in folder {parent_id}" if parent_id else "in root") + ) + return items, None + + except Exception as e: + logger.error(f"Error listing folder contents: {e!s}", exc_info=True) + return [], f"Error listing folder contents: {e!s}" + + +async def get_files_in_folder( + client: OneDriveClient, + folder_id: str, + include_subfolders: bool = True, +) -> tuple[list[dict[str, Any]], str | None]: + """Get all indexable files in a folder, optionally recursing into subfolders.""" + try: + items, error = await client.list_children(folder_id) + if error: + return [], error + + files: list[dict[str, Any]] = [] + for item in items: + if is_folder(item): + if include_subfolders: + sub_files, sub_error = await get_files_in_folder( + client, item["id"], include_subfolders=True + ) + if sub_error: + logger.warning(f"Error recursing into folder {item.get('name')}: {sub_error}") + continue + files.extend(sub_files) + elif not should_skip_file(item): + files.append(item) + + return files, None + + except Exception as e: + logger.error(f"Error getting files in folder: {e!s}", exc_info=True) + return [], f"Error getting files in folder: {e!s}" + + +async def get_file_by_id( + client: OneDriveClient, + file_id: str, +) -> tuple[dict[str, Any] | None, str | None]: + """Get file metadata by ID.""" + try: + item, error = await client.get_item_metadata(file_id) + if error: + return None, error + if not item: + return None, f"File not found: {file_id}" + return item, None + + except Exception as e: + logger.error(f"Error getting file by ID: {e!s}", exc_info=True) + return None, f"Error getting file by ID: {e!s}" diff --git a/surfsense_backend/app/db.py b/surfsense_backend/app/db.py index 9680a7bfd..a8510ebab 100644 --- a/surfsense_backend/app/db.py +++ b/surfsense_backend/app/db.py @@ -40,6 +40,7 @@ class DocumentType(StrEnum): FILE = "FILE" SLACK_CONNECTOR = "SLACK_CONNECTOR" TEAMS_CONNECTOR = "TEAMS_CONNECTOR" + ONEDRIVE_FILE = "ONEDRIVE_FILE" NOTION_CONNECTOR = "NOTION_CONNECTOR" YOUTUBE_VIDEO = "YOUTUBE_VIDEO" GITHUB_CONNECTOR = "GITHUB_CONNECTOR" @@ -81,6 +82,7 @@ class SearchSourceConnectorType(StrEnum): BAIDU_SEARCH_API = "BAIDU_SEARCH_API" # Baidu AI Search API for Chinese web search SLACK_CONNECTOR = "SLACK_CONNECTOR" TEAMS_CONNECTOR = "TEAMS_CONNECTOR" + ONEDRIVE_CONNECTOR = "ONEDRIVE_CONNECTOR" NOTION_CONNECTOR = "NOTION_CONNECTOR" GITHUB_CONNECTOR = "GITHUB_CONNECTOR" LINEAR_CONNECTOR = "LINEAR_CONNECTOR" diff --git a/surfsense_backend/app/routes/__init__.py b/surfsense_backend/app/routes/__init__.py index 7782c064c..af26e3680 100644 --- a/surfsense_backend/app/routes/__init__.py +++ b/surfsense_backend/app/routes/__init__.py @@ -43,6 +43,7 @@ from .search_spaces_routes import router as search_spaces_router from .slack_add_connector_route import router as slack_add_connector_router from .surfsense_docs_routes import router as surfsense_docs_router from .teams_add_connector_route import router as teams_add_connector_router +from .onedrive_add_connector_route import router as onedrive_add_connector_router from .video_presentations_routes import router as video_presentations_router from .youtube_routes import router as youtube_router @@ -73,6 +74,7 @@ router.include_router(luma_add_connector_router) router.include_router(notion_add_connector_router) router.include_router(slack_add_connector_router) router.include_router(teams_add_connector_router) +router.include_router(onedrive_add_connector_router) router.include_router(discord_add_connector_router) router.include_router(jira_add_connector_router) router.include_router(confluence_add_connector_router) diff --git a/surfsense_backend/app/routes/onedrive_add_connector_route.py b/surfsense_backend/app/routes/onedrive_add_connector_route.py new file mode 100644 index 000000000..0494888d9 --- /dev/null +++ b/surfsense_backend/app/routes/onedrive_add_connector_route.py @@ -0,0 +1,474 @@ +""" +Microsoft OneDrive Connector OAuth Routes. + +Endpoints: +- GET /auth/onedrive/connector/add - Initiate OAuth +- GET /auth/onedrive/connector/callback - Handle OAuth callback +- GET /auth/onedrive/connector/reauth - Re-authenticate existing connector +- GET /connectors/{connector_id}/onedrive/folders - List folder contents +""" + +import logging +from datetime import UTC, datetime, timedelta +from urllib.parse import urlencode +from uuid import UUID + +import httpx +from fastapi import APIRouter, Depends, HTTPException +from fastapi.responses import RedirectResponse +from sqlalchemy.exc import IntegrityError +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.future import select +from sqlalchemy.orm.attributes import flag_modified + +from app.config import config +from app.connectors.onedrive import OneDriveClient, list_folder_contents +from app.db import ( + SearchSourceConnector, + SearchSourceConnectorType, + User, + get_async_session, +) +from app.users import current_active_user +from app.utils.connector_naming import ( + check_duplicate_connector, + extract_identifier_from_credentials, + generate_unique_connector_name, +) +from app.utils.oauth_security import OAuthStateManager, TokenEncryption + +logger = logging.getLogger(__name__) +router = APIRouter() + +AUTHORIZATION_URL = "https://login.microsoftonline.com/common/oauth2/v2.0/authorize" +TOKEN_URL = "https://login.microsoftonline.com/common/oauth2/v2.0/token" + +SCOPES = [ + "offline_access", + "User.Read", + "Files.Read.All", + "Files.ReadWrite.All", +] + +_state_manager = None +_token_encryption = None + + +def get_state_manager() -> OAuthStateManager: + global _state_manager + if _state_manager is None: + if not config.SECRET_KEY: + raise ValueError("SECRET_KEY must be set for OAuth security") + _state_manager = OAuthStateManager(config.SECRET_KEY) + return _state_manager + + +def get_token_encryption() -> TokenEncryption: + global _token_encryption + if _token_encryption is None: + if not config.SECRET_KEY: + raise ValueError("SECRET_KEY must be set for token encryption") + _token_encryption = TokenEncryption(config.SECRET_KEY) + return _token_encryption + + +@router.get("/auth/onedrive/connector/add") +async def connect_onedrive(space_id: int, user: User = Depends(current_active_user)): + """Initiate OneDrive OAuth flow.""" + try: + if not space_id: + raise HTTPException(status_code=400, detail="space_id is required") + if not config.ONEDRIVE_CLIENT_ID: + raise HTTPException(status_code=500, detail="Microsoft OneDrive OAuth not configured.") + if not config.SECRET_KEY: + raise HTTPException(status_code=500, detail="SECRET_KEY not configured for OAuth security.") + + state_manager = get_state_manager() + state_encoded = state_manager.generate_secure_state(space_id, user.id) + + auth_params = { + "client_id": config.ONEDRIVE_CLIENT_ID, + "response_type": "code", + "redirect_uri": config.ONEDRIVE_REDIRECT_URI, + "response_mode": "query", + "scope": " ".join(SCOPES), + "state": state_encoded, + } + auth_url = f"{AUTHORIZATION_URL}?{urlencode(auth_params)}" + + logger.info("Generated OneDrive OAuth URL for user %s, space %s", user.id, space_id) + return {"auth_url": auth_url} + + except HTTPException: + raise + except Exception as e: + logger.error("Failed to initiate OneDrive OAuth: %s", str(e), exc_info=True) + raise HTTPException(status_code=500, detail=f"Failed to initiate OneDrive OAuth: {e!s}") from e + + +@router.get("/auth/onedrive/connector/reauth") +async def reauth_onedrive( + space_id: int, + connector_id: int, + return_url: str | None = None, + user: User = Depends(current_active_user), + session: AsyncSession = Depends(get_async_session), +): + """Re-authenticate an existing OneDrive connector.""" + try: + result = await session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == connector_id, + SearchSourceConnector.user_id == user.id, + SearchSourceConnector.search_space_id == space_id, + SearchSourceConnector.connector_type == SearchSourceConnectorType.ONEDRIVE_CONNECTOR, + ) + ) + connector = result.scalars().first() + if not connector: + raise HTTPException(status_code=404, detail="OneDrive connector not found or access denied") + + if not config.SECRET_KEY: + raise HTTPException(status_code=500, detail="SECRET_KEY not configured for OAuth security.") + + state_manager = get_state_manager() + extra: dict = {"connector_id": connector_id} + if return_url and return_url.startswith("/"): + extra["return_url"] = return_url + state_encoded = state_manager.generate_secure_state(space_id, user.id, **extra) + + auth_params = { + "client_id": config.ONEDRIVE_CLIENT_ID, + "response_type": "code", + "redirect_uri": config.ONEDRIVE_REDIRECT_URI, + "response_mode": "query", + "scope": " ".join(SCOPES), + "state": state_encoded, + "prompt": "consent", + } + auth_url = f"{AUTHORIZATION_URL}?{urlencode(auth_params)}" + + logger.info("Initiating OneDrive re-auth for user %s, connector %s", user.id, connector_id) + return {"auth_url": auth_url} + + except HTTPException: + raise + except Exception as e: + logger.error("Failed to initiate OneDrive re-auth: %s", str(e), exc_info=True) + raise HTTPException(status_code=500, detail=f"Failed to initiate OneDrive re-auth: {e!s}") from e + + +@router.get("/auth/onedrive/connector/callback") +async def onedrive_callback( + code: str | None = None, + error: str | None = None, + error_description: str | None = None, + state: str | None = None, + session: AsyncSession = Depends(get_async_session), +): + """Handle OneDrive OAuth callback.""" + try: + if error: + error_msg = error_description or error + logger.warning("OneDrive OAuth error: %s", error_msg) + space_id = None + if state: + try: + data = get_state_manager().validate_state(state) + space_id = data.get("space_id") + except Exception: + pass + if space_id: + return RedirectResponse( + url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=onedrive_oauth_denied" + ) + return RedirectResponse(url=f"{config.NEXT_FRONTEND_URL}/dashboard?error=onedrive_oauth_denied") + + if not code or not state: + raise HTTPException(status_code=400, detail="Missing required OAuth parameters") + + state_manager = get_state_manager() + try: + data = state_manager.validate_state(state) + space_id = data["space_id"] + user_id = UUID(data["user_id"]) + except (HTTPException, ValueError, KeyError) as e: + logger.error("Invalid OAuth state: %s", str(e)) + return RedirectResponse(url=f"{config.NEXT_FRONTEND_URL}/dashboard?error=invalid_state") + + reauth_connector_id = data.get("connector_id") + reauth_return_url = data.get("return_url") + + token_data = { + "client_id": config.ONEDRIVE_CLIENT_ID, + "client_secret": config.ONEDRIVE_CLIENT_SECRET, + "code": code, + "redirect_uri": config.ONEDRIVE_REDIRECT_URI, + "grant_type": "authorization_code", + } + + async with httpx.AsyncClient() as client: + token_response = await client.post( + TOKEN_URL, + data=token_data, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + timeout=30.0, + ) + + if token_response.status_code != 200: + error_detail = token_response.text + try: + error_json = token_response.json() + error_detail = error_json.get("error_description", error_detail) + except Exception: + pass + raise HTTPException(status_code=400, detail=f"Token exchange failed: {error_detail}") + + token_json = token_response.json() + access_token = token_json.get("access_token") + refresh_token = token_json.get("refresh_token") + + if not access_token: + raise HTTPException(status_code=400, detail="No access token received from Microsoft") + + token_encryption = get_token_encryption() + + expires_at = None + if token_json.get("expires_in"): + expires_at = datetime.now(UTC) + timedelta(seconds=int(token_json["expires_in"])) + + user_info: dict = {} + try: + async with httpx.AsyncClient() as client: + user_response = await client.get( + "https://graph.microsoft.com/v1.0/me", + headers={"Authorization": f"Bearer {access_token}"}, + timeout=30.0, + ) + if user_response.status_code == 200: + user_data = user_response.json() + user_info = { + "user_email": user_data.get("mail") or user_data.get("userPrincipalName"), + "user_name": user_data.get("displayName"), + } + except Exception as e: + logger.warning("Failed to fetch user info from Graph: %s", str(e)) + + connector_config = { + "access_token": token_encryption.encrypt_token(access_token), + "refresh_token": token_encryption.encrypt_token(refresh_token) if refresh_token else None, + "token_type": token_json.get("token_type", "Bearer"), + "expires_in": token_json.get("expires_in"), + "expires_at": expires_at.isoformat() if expires_at else None, + "scope": token_json.get("scope"), + "user_email": user_info.get("user_email"), + "user_name": user_info.get("user_name"), + "_token_encrypted": True, + } + + # Handle re-authentication + if reauth_connector_id: + result = await session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == reauth_connector_id, + SearchSourceConnector.user_id == user_id, + SearchSourceConnector.search_space_id == space_id, + SearchSourceConnector.connector_type == SearchSourceConnectorType.ONEDRIVE_CONNECTOR, + ) + ) + db_connector = result.scalars().first() + if not db_connector: + raise HTTPException(status_code=404, detail="Connector not found or access denied during re-auth") + + existing_delta_link = db_connector.config.get("delta_link") + db_connector.config = {**connector_config, "delta_link": existing_delta_link, "auth_expired": False} + flag_modified(db_connector, "config") + await session.commit() + await session.refresh(db_connector) + + logger.info("Re-authenticated OneDrive connector %s for user %s", db_connector.id, user_id) + if reauth_return_url and reauth_return_url.startswith("/"): + return RedirectResponse(url=f"{config.NEXT_FRONTEND_URL}{reauth_return_url}") + return RedirectResponse( + url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?success=true&connector=ONEDRIVE_CONNECTOR&connectorId={db_connector.id}" + ) + + # New connector -- check for duplicates + connector_identifier = extract_identifier_from_credentials( + SearchSourceConnectorType.ONEDRIVE_CONNECTOR, connector_config + ) + is_duplicate = await check_duplicate_connector( + session, SearchSourceConnectorType.ONEDRIVE_CONNECTOR, space_id, user_id, connector_identifier, + ) + if is_duplicate: + logger.warning("Duplicate OneDrive connector for user %s, space %s", user_id, space_id) + return RedirectResponse( + url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=duplicate_account&connector=ONEDRIVE_CONNECTOR" + ) + + connector_name = await generate_unique_connector_name( + session, SearchSourceConnectorType.ONEDRIVE_CONNECTOR, space_id, user_id, connector_identifier, + ) + + new_connector = SearchSourceConnector( + name=connector_name, + connector_type=SearchSourceConnectorType.ONEDRIVE_CONNECTOR, + is_indexable=True, + config=connector_config, + search_space_id=space_id, + user_id=user_id, + ) + + try: + session.add(new_connector) + await session.commit() + await session.refresh(new_connector) + logger.info("Successfully created OneDrive connector %s for user %s", new_connector.id, user_id) + return RedirectResponse( + url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?success=true&connector=ONEDRIVE_CONNECTOR&connectorId={new_connector.id}" + ) + except IntegrityError as e: + await session.rollback() + logger.error("Database integrity error creating OneDrive connector: %s", str(e)) + return RedirectResponse(url=f"{config.NEXT_FRONTEND_URL}/dashboard?error=connector_creation_failed") + + except HTTPException: + raise + except (IntegrityError, ValueError) as e: + logger.error("OneDrive OAuth callback error: %s", str(e), exc_info=True) + return RedirectResponse(url=f"{config.NEXT_FRONTEND_URL}/dashboard?error=onedrive_auth_error") + + +@router.get("/connectors/{connector_id}/onedrive/folders") +async def list_onedrive_folders( + connector_id: int, + parent_id: str | None = None, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_active_user), +): + """List folders and files in user's OneDrive.""" + connector = None + try: + result = await session.execute( + select(SearchSourceConnector).filter( + SearchSourceConnector.id == connector_id, + SearchSourceConnector.user_id == user.id, + SearchSourceConnector.connector_type == SearchSourceConnectorType.ONEDRIVE_CONNECTOR, + ) + ) + connector = result.scalars().first() + if not connector: + raise HTTPException(status_code=404, detail="OneDrive connector not found or access denied") + + onedrive_client = OneDriveClient(session, connector_id) + items, error = await list_folder_contents(onedrive_client, parent_id=parent_id) + + if error: + error_lower = error.lower() + if "401" in error or "authentication expired" in error_lower or "invalid_grant" in error_lower: + try: + if connector and not connector.config.get("auth_expired"): + connector.config = {**connector.config, "auth_expired": True} + flag_modified(connector, "config") + await session.commit() + except Exception: + logger.warning("Failed to persist auth_expired for connector %s", connector_id, exc_info=True) + raise HTTPException(status_code=400, detail="OneDrive authentication expired. Please re-authenticate.") + raise HTTPException(status_code=500, detail=f"Failed to list folder contents: {error}") + + return {"items": items} + + except HTTPException: + raise + except Exception as e: + logger.error("Error listing OneDrive contents: %s", str(e), exc_info=True) + error_lower = str(e).lower() + if "401" in str(e) or "authentication expired" in error_lower: + try: + if connector and not connector.config.get("auth_expired"): + connector.config = {**connector.config, "auth_expired": True} + flag_modified(connector, "config") + await session.commit() + except Exception: + pass + raise HTTPException(status_code=400, detail="OneDrive authentication expired. Please re-authenticate.") from e + raise HTTPException(status_code=500, detail=f"Failed to list OneDrive contents: {e!s}") from e + + +async def refresh_onedrive_token( + session: AsyncSession, connector: SearchSourceConnector +) -> SearchSourceConnector: + """Refresh OneDrive OAuth tokens.""" + logger.info("Refreshing OneDrive OAuth tokens for connector %s", connector.id) + + token_encryption = get_token_encryption() + is_encrypted = connector.config.get("_token_encrypted", False) + refresh_token = connector.config.get("refresh_token") + + if is_encrypted and refresh_token: + try: + refresh_token = token_encryption.decrypt_token(refresh_token) + except Exception as e: + logger.error("Failed to decrypt refresh token: %s", str(e)) + raise HTTPException(status_code=500, detail="Failed to decrypt stored refresh token") from e + + if not refresh_token: + raise HTTPException(status_code=400, detail=f"No refresh token available for connector {connector.id}") + + refresh_data = { + "client_id": config.ONEDRIVE_CLIENT_ID, + "client_secret": config.ONEDRIVE_CLIENT_SECRET, + "grant_type": "refresh_token", + "refresh_token": refresh_token, + "scope": " ".join(SCOPES), + } + + async with httpx.AsyncClient() as client: + token_response = await client.post( + TOKEN_URL, data=refresh_data, + headers={"Content-Type": "application/x-www-form-urlencoded"}, timeout=30.0, + ) + + if token_response.status_code != 200: + error_detail = token_response.text + error_code = "" + try: + error_json = token_response.json() + error_detail = error_json.get("error_description", error_detail) + error_code = error_json.get("error", "") + except Exception: + pass + error_lower = (error_detail + error_code).lower() + if "invalid_grant" in error_lower or "expired" in error_lower or "revoked" in error_lower: + raise HTTPException(status_code=401, detail="OneDrive authentication failed. Please re-authenticate.") + raise HTTPException(status_code=400, detail=f"Token refresh failed: {error_detail}") + + token_json = token_response.json() + access_token = token_json.get("access_token") + new_refresh_token = token_json.get("refresh_token") + + if not access_token: + raise HTTPException(status_code=400, detail="No access token received from Microsoft refresh") + + expires_at = None + expires_in = token_json.get("expires_in") + if expires_in: + expires_at = datetime.now(UTC) + timedelta(seconds=int(expires_in)) + + cfg = dict(connector.config) + cfg["access_token"] = token_encryption.encrypt_token(access_token) + if new_refresh_token: + cfg["refresh_token"] = token_encryption.encrypt_token(new_refresh_token) + cfg["expires_in"] = expires_in + cfg["expires_at"] = expires_at.isoformat() if expires_at else None + cfg["scope"] = token_json.get("scope") + cfg["_token_encrypted"] = True + cfg.pop("auth_expired", None) + + connector.config = cfg + flag_modified(connector, "config") + await session.commit() + await session.refresh(connector) + + logger.info("Successfully refreshed OneDrive tokens for connector %s", connector.id) + return connector diff --git a/surfsense_backend/app/routes/search_source_connectors_routes.py b/surfsense_backend/app/routes/search_source_connectors_routes.py index bef2329d8..2183e3677 100644 --- a/surfsense_backend/app/routes/search_source_connectors_routes.py +++ b/surfsense_backend/app/routes/search_source_connectors_routes.py @@ -999,6 +999,53 @@ async def index_connector_content( ) response_message = "Google Drive indexing started in the background." + elif connector.connector_type == SearchSourceConnectorType.ONEDRIVE_CONNECTOR: + from app.tasks.celery_tasks.connector_tasks import ( + index_onedrive_files_task, + ) + + if drive_items and drive_items.has_items(): + logger.info( + f"Triggering OneDrive indexing for connector {connector_id} into search space {search_space_id}, " + f"folders: {len(drive_items.folders)}, files: {len(drive_items.files)}" + ) + items_dict = drive_items.model_dump() + else: + config = connector.config or {} + selected_folders = config.get("selected_folders", []) + selected_files = config.get("selected_files", []) + if not selected_folders and not selected_files: + raise HTTPException( + status_code=400, + detail="OneDrive indexing requires folders or files to be configured. " + "Please select folders/files to index.", + ) + indexing_options = config.get( + "indexing_options", + { + "max_files_per_folder": 100, + "incremental_sync": True, + "include_subfolders": True, + }, + ) + items_dict = { + "folders": selected_folders, + "files": selected_files, + "indexing_options": indexing_options, + } + logger.info( + f"Triggering OneDrive indexing for connector {connector_id} into search space {search_space_id} " + f"using existing config" + ) + + index_onedrive_files_task.delay( + connector_id, + search_space_id, + str(user.id), + items_dict, + ) + response_message = "OneDrive indexing started in the background." + elif connector.connector_type == SearchSourceConnectorType.DISCORD_CONNECTOR: from app.tasks.celery_tasks.connector_tasks import ( index_discord_messages_task, @@ -2485,6 +2532,108 @@ async def run_google_drive_indexing( logger.error(f"Failed to update notification: {notif_error!s}") +async def run_onedrive_indexing( + session: AsyncSession, + connector_id: int, + search_space_id: int, + user_id: str, + items_dict: dict, +): + """Runs the OneDrive indexing task for folders and files with notifications.""" + from uuid import UUID + + notification = None + try: + from app.tasks.connector_indexers.onedrive_indexer import index_onedrive_files + + connector_result = await session.execute( + select(SearchSourceConnector).where( + SearchSourceConnector.id == connector_id + ) + ) + connector = connector_result.scalar_one_or_none() + + if connector: + notification = await NotificationService.connector_indexing.notify_google_drive_indexing_started( + session=session, + user_id=UUID(user_id), + connector_id=connector_id, + connector_name=connector.name, + connector_type=connector.connector_type.value, + search_space_id=search_space_id, + folder_count=len(items_dict.get("folders", [])), + file_count=len(items_dict.get("files", [])), + folder_names=[f.get("name", "Unknown") for f in items_dict.get("folders", [])], + file_names=[f.get("name", "Unknown") for f in items_dict.get("files", [])], + ) + + if notification: + await NotificationService.connector_indexing.notify_indexing_progress( + session=session, + notification=notification, + indexed_count=0, + stage="fetching", + ) + + total_indexed, total_skipped, error_message = await index_onedrive_files( + session, + connector_id, + search_space_id, + user_id, + items_dict, + ) + + if error_message: + logger.error( + f"OneDrive indexing completed with errors for connector {connector_id}: {error_message}" + ) + if _is_auth_error(error_message): + await _persist_auth_expired(session, connector_id) + error_message = "OneDrive authentication expired. Please re-authenticate." + else: + if notification: + await session.refresh(notification) + await NotificationService.connector_indexing.notify_indexing_progress( + session=session, + notification=notification, + indexed_count=total_indexed, + stage="storing", + ) + + logger.info( + f"OneDrive indexing successful for connector {connector_id}. Indexed {total_indexed} documents." + ) + await _update_connector_timestamp_by_id(session, connector_id) + await session.commit() + + if notification: + await session.refresh(notification) + await NotificationService.connector_indexing.notify_indexing_completed( + session=session, + notification=notification, + indexed_count=total_indexed, + error_message=error_message, + skipped_count=total_skipped, + ) + + except Exception as e: + logger.error( + f"Critical error in run_onedrive_indexing for connector {connector_id}: {e}", + exc_info=True, + ) + if notification: + try: + await session.refresh(notification) + await NotificationService.connector_indexing.notify_indexing_completed( + session=session, + notification=notification, + indexed_count=0, + error_message=str(e), + ) + except Exception as notif_error: + logger.error(f"Failed to update notification: {notif_error!s}") + + # Add new helper functions for luma indexing async def run_luma_indexing_with_new_session( connector_id: int, diff --git a/surfsense_backend/app/schemas/onedrive_auth_credentials.py b/surfsense_backend/app/schemas/onedrive_auth_credentials.py new file mode 100644 index 000000000..7690a2694 --- /dev/null +++ b/surfsense_backend/app/schemas/onedrive_auth_credentials.py @@ -0,0 +1,71 @@ +"""Microsoft OneDrive OAuth credentials schema.""" + +from datetime import UTC, datetime + +from pydantic import BaseModel, field_validator + + +class OneDriveAuthCredentialsBase(BaseModel): + """Microsoft OneDrive OAuth credentials.""" + + access_token: str + refresh_token: str | None = None + token_type: str = "Bearer" + expires_in: int | None = None + expires_at: datetime | None = None + scope: str | None = None + user_email: str | None = None + user_name: str | None = None + tenant_id: str | None = None + + @property + def is_expired(self) -> bool: + if self.expires_at is None: + return False + return self.expires_at <= datetime.now(UTC) + + @property + def is_refreshable(self) -> bool: + return self.refresh_token is not None + + def to_dict(self) -> dict: + return { + "access_token": self.access_token, + "refresh_token": self.refresh_token, + "token_type": self.token_type, + "expires_in": self.expires_in, + "expires_at": self.expires_at.isoformat() if self.expires_at else None, + "scope": self.scope, + "user_email": self.user_email, + "user_name": self.user_name, + "tenant_id": self.tenant_id, + } + + @classmethod + def from_dict(cls, data: dict) -> "OneDriveAuthCredentialsBase": + expires_at = None + if data.get("expires_at"): + expires_at = datetime.fromisoformat(data["expires_at"]) + return cls( + access_token=data.get("access_token", ""), + refresh_token=data.get("refresh_token"), + token_type=data.get("token_type", "Bearer"), + expires_in=data.get("expires_in"), + expires_at=expires_at, + scope=data.get("scope"), + user_email=data.get("user_email"), + user_name=data.get("user_name"), + tenant_id=data.get("tenant_id"), + ) + + @field_validator("expires_at", mode="before") + @classmethod + def ensure_aware_utc(cls, v): + if isinstance(v, str): + if v.endswith("Z"): + return datetime.fromisoformat(v.replace("Z", "+00:00")) + dt = datetime.fromisoformat(v) + return dt if dt.tzinfo else dt.replace(tzinfo=UTC) + if isinstance(v, datetime): + return v if v.tzinfo else v.replace(tzinfo=UTC) + return v diff --git a/surfsense_backend/app/tasks/celery_tasks/connector_tasks.py b/surfsense_backend/app/tasks/celery_tasks/connector_tasks.py index 9d52add9c..9eccbc798 100644 --- a/surfsense_backend/app/tasks/celery_tasks/connector_tasks.py +++ b/surfsense_backend/app/tasks/celery_tasks/connector_tasks.py @@ -526,6 +526,54 @@ async def _index_google_drive_files( ) +@celery_app.task(name="index_onedrive_files", bind=True) +def index_onedrive_files_task( + self, + connector_id: int, + search_space_id: int, + user_id: str, + items_dict: dict, +): + """Celery task to index OneDrive folders and files.""" + import asyncio + + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + try: + loop.run_until_complete( + _index_onedrive_files( + connector_id, + search_space_id, + user_id, + items_dict, + ) + ) + finally: + loop.close() + + +async def _index_onedrive_files( + connector_id: int, + search_space_id: int, + user_id: str, + items_dict: dict, +): + """Index OneDrive folders and files with new session.""" + from app.routes.search_source_connectors_routes import ( + run_onedrive_indexing, + ) + + async with get_celery_session_maker()() as session: + await run_onedrive_indexing( + session, + connector_id, + search_space_id, + user_id, + items_dict, + ) + + @celery_app.task(name="index_discord_messages", bind=True) def index_discord_messages_task( self, diff --git a/surfsense_backend/app/tasks/connector_indexers/onedrive_indexer.py b/surfsense_backend/app/tasks/connector_indexers/onedrive_indexer.py new file mode 100644 index 000000000..e565f6a6a --- /dev/null +++ b/surfsense_backend/app/tasks/connector_indexers/onedrive_indexer.py @@ -0,0 +1,606 @@ +"""OneDrive indexer using the shared IndexingPipelineService. + +File-level pre-filter (_should_skip_file) handles hash/modifiedDateTime +checks and rename-only detection. download_and_extract_content() +returns markdown which is fed into ConnectorDocument -> pipeline. +""" + +import asyncio +import logging +import time +from collections.abc import Awaitable, Callable + +from sqlalchemy import String, cast, select +from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm.attributes import flag_modified + +from app.config import config +from app.connectors.onedrive import ( + OneDriveClient, + download_and_extract_content, + get_file_by_id, + get_files_in_folder, +) +from app.connectors.onedrive.file_types import should_skip_file as skip_item +from app.db import Document, DocumentStatus, DocumentType, SearchSourceConnectorType +from app.indexing_pipeline.connector_document import ConnectorDocument +from app.indexing_pipeline.document_hashing import compute_identifier_hash +from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService +from app.services.llm_service import get_user_long_context_llm +from app.services.task_logging_service import TaskLoggingService +from app.tasks.connector_indexers.base import ( + check_document_by_unique_identifier, + get_connector_by_id, + update_connector_last_indexed, +) + +HeartbeatCallbackType = Callable[[int], Awaitable[None]] +HEARTBEAT_INTERVAL_SECONDS = 30 + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +async def _should_skip_file( + session: AsyncSession, + file: dict, + search_space_id: int, +) -> tuple[bool, str | None]: + """Pre-filter: detect unchanged / rename-only files.""" + file_id = file.get("id") + file_name = file.get("name", "Unknown") + + if skip_item(file): + return True, "folder/onenote/remote" + if not file_id: + return True, "missing file_id" + + primary_hash = compute_identifier_hash( + DocumentType.ONEDRIVE_FILE.value, file_id, search_space_id + ) + existing = await check_document_by_unique_identifier(session, primary_hash) + + if not existing: + result = await session.execute( + select(Document).where( + Document.search_space_id == search_space_id, + Document.document_type == DocumentType.ONEDRIVE_FILE, + cast(Document.document_metadata["onedrive_file_id"], String) == file_id, + ) + ) + existing = result.scalar_one_or_none() + if existing: + existing.unique_identifier_hash = primary_hash + logger.debug(f"Found OneDrive doc by metadata for file_id: {file_id}") + + if not existing: + return False, None + + incoming_mtime = file.get("lastModifiedDateTime") + meta = existing.document_metadata or {} + stored_mtime = meta.get("modified_time") + + file_info = file.get("file", {}) + file_hashes = file_info.get("hashes", {}) + incoming_hash = file_hashes.get("sha256Hash") or file_hashes.get("quickXorHash") + stored_hash = meta.get("sha256_hash") or meta.get("quick_xor_hash") + + content_unchanged = False + if incoming_hash and stored_hash: + content_unchanged = incoming_hash == stored_hash + elif incoming_hash and not stored_hash: + return False, None + elif not incoming_hash and incoming_mtime and stored_mtime: + content_unchanged = incoming_mtime == stored_mtime + elif not incoming_hash: + return False, None + + if not content_unchanged: + return False, None + + old_name = meta.get("onedrive_file_name") + if old_name and old_name != file_name: + existing.title = file_name + if not existing.document_metadata: + existing.document_metadata = {} + existing.document_metadata["onedrive_file_name"] = file_name + if incoming_mtime: + existing.document_metadata["modified_time"] = incoming_mtime + flag_modified(existing, "document_metadata") + await session.commit() + logger.info(f"Rename-only update: '{old_name}' -> '{file_name}'") + return True, f"File renamed: '{old_name}' -> '{file_name}'" + + if not DocumentStatus.is_state(existing.status, DocumentStatus.READY): + return True, "skipped (previously failed)" + return True, "unchanged" + + +def _build_connector_doc( + file: dict, + markdown: str, + onedrive_metadata: dict, + *, + connector_id: int, + search_space_id: int, + user_id: str, + enable_summary: bool, +) -> ConnectorDocument: + file_id = file.get("id", "") + file_name = file.get("name", "Unknown") + + metadata = { + **onedrive_metadata, + "connector_id": connector_id, + "document_type": "OneDrive File", + "connector_type": "OneDrive", + } + + fallback_summary = f"File: {file_name}\n\n{markdown[:4000]}" + + return ConnectorDocument( + title=file_name, + source_markdown=markdown, + unique_id=file_id, + document_type=DocumentType.ONEDRIVE_FILE, + search_space_id=search_space_id, + connector_id=connector_id, + created_by_id=user_id, + should_summarize=enable_summary, + fallback_summary=fallback_summary, + metadata=metadata, + ) + + +async def _download_files_parallel( + onedrive_client: OneDriveClient, + files: list[dict], + *, + connector_id: int, + search_space_id: int, + user_id: str, + enable_summary: bool, + max_concurrency: int = 3, + on_heartbeat: HeartbeatCallbackType | None = None, +) -> tuple[list[ConnectorDocument], int]: + """Download and ETL files in parallel. Returns (docs, failed_count).""" + results: list[ConnectorDocument] = [] + sem = asyncio.Semaphore(max_concurrency) + last_heartbeat = time.time() + completed_count = 0 + hb_lock = asyncio.Lock() + + async def _download_one(file: dict) -> ConnectorDocument | None: + nonlocal last_heartbeat, completed_count + async with sem: + markdown, od_metadata, error = await download_and_extract_content( + onedrive_client, file + ) + if error or not markdown: + file_name = file.get("name", "Unknown") + reason = error or "empty content" + logger.warning(f"Download/ETL failed for {file_name}: {reason}") + return None + doc = _build_connector_doc( + file, markdown, od_metadata, + connector_id=connector_id, search_space_id=search_space_id, + user_id=user_id, enable_summary=enable_summary, + ) + async with hb_lock: + completed_count += 1 + if on_heartbeat: + now = time.time() + if now - last_heartbeat >= HEARTBEAT_INTERVAL_SECONDS: + await on_heartbeat(completed_count) + last_heartbeat = now + return doc + + tasks = [_download_one(f) for f in files] + outcomes = await asyncio.gather(*tasks, return_exceptions=True) + + failed = 0 + for outcome in outcomes: + if isinstance(outcome, Exception): + failed += 1 + elif outcome is None: + failed += 1 + else: + results.append(outcome) + + return results, failed + + +async def _download_and_index( + onedrive_client: OneDriveClient, + session: AsyncSession, + files: list[dict], + *, + connector_id: int, + search_space_id: int, + user_id: str, + enable_summary: bool, + on_heartbeat: HeartbeatCallbackType | None = None, +) -> tuple[int, int]: + """Parallel download then parallel indexing. Returns (batch_indexed, total_failed).""" + connector_docs, download_failed = await _download_files_parallel( + onedrive_client, files, + connector_id=connector_id, search_space_id=search_space_id, + user_id=user_id, enable_summary=enable_summary, + on_heartbeat=on_heartbeat, + ) + + batch_indexed = 0 + batch_failed = 0 + if connector_docs: + pipeline = IndexingPipelineService(session) + + async def _get_llm(s): + return await get_user_long_context_llm(s, user_id, search_space_id) + + _, batch_indexed, batch_failed = await pipeline.index_batch_parallel( + connector_docs, _get_llm, max_concurrency=3, + on_heartbeat=on_heartbeat, + ) + + return batch_indexed, download_failed + batch_failed + + +async def _remove_document(session: AsyncSession, file_id: str, search_space_id: int): + """Remove a document that was deleted in OneDrive.""" + primary_hash = compute_identifier_hash( + DocumentType.ONEDRIVE_FILE.value, file_id, search_space_id + ) + existing = await check_document_by_unique_identifier(session, primary_hash) + + if not existing: + result = await session.execute( + select(Document).where( + Document.search_space_id == search_space_id, + Document.document_type == DocumentType.ONEDRIVE_FILE, + cast(Document.document_metadata["onedrive_file_id"], String) == file_id, + ) + ) + existing = result.scalar_one_or_none() + + if existing: + await session.delete(existing) + logger.info(f"Removed deleted OneDrive file document: {file_id}") + + +async def _index_selected_files( + onedrive_client: OneDriveClient, + session: AsyncSession, + file_ids: list[tuple[str, str | None]], + *, + connector_id: int, + search_space_id: int, + user_id: str, + enable_summary: bool, + on_heartbeat: HeartbeatCallbackType | None = None, +) -> tuple[int, int, list[str]]: + """Index user-selected files using the parallel pipeline.""" + files_to_download: list[dict] = [] + errors: list[str] = [] + renamed_count = 0 + skipped = 0 + + for file_id, file_name in file_ids: + file, error = await get_file_by_id(onedrive_client, file_id) + if error or not file: + display = file_name or file_id + errors.append(f"File '{display}': {error or 'File not found'}") + continue + + skip, msg = await _should_skip_file(session, file, search_space_id) + if skip: + if msg and "renamed" in msg.lower(): + renamed_count += 1 + else: + skipped += 1 + continue + + files_to_download.append(file) + + batch_indexed, failed = await _download_and_index( + onedrive_client, session, files_to_download, + connector_id=connector_id, search_space_id=search_space_id, + user_id=user_id, enable_summary=enable_summary, + on_heartbeat=on_heartbeat, + ) + + return renamed_count + batch_indexed, skipped, errors + + +# --------------------------------------------------------------------------- +# Scan strategies +# --------------------------------------------------------------------------- + +async def _index_full_scan( + onedrive_client: OneDriveClient, + session: AsyncSession, + connector_id: int, + search_space_id: int, + user_id: str, + folder_id: str, + folder_name: str, + task_logger: TaskLoggingService, + log_entry: object, + max_files: int, + include_subfolders: bool = True, + on_heartbeat_callback: HeartbeatCallbackType | None = None, + enable_summary: bool = True, +) -> tuple[int, int]: + """Full scan indexing of a folder.""" + await task_logger.log_task_progress( + log_entry, + f"Starting full scan of folder: {folder_name}", + {"stage": "full_scan", "folder_id": folder_id, "include_subfolders": include_subfolders}, + ) + + renamed_count = 0 + skipped = 0 + files_to_download: list[dict] = [] + + all_files, error = await get_files_in_folder( + onedrive_client, folder_id, include_subfolders=include_subfolders, + ) + if error: + err_lower = error.lower() + if "401" in error or "authentication expired" in err_lower: + raise Exception(f"OneDrive authentication failed. Please re-authenticate. (Error: {error})") + raise Exception(f"Failed to list OneDrive files: {error}") + + for file in all_files[:max_files]: + skip, msg = await _should_skip_file(session, file, search_space_id) + if skip: + if msg and "renamed" in msg.lower(): + renamed_count += 1 + else: + skipped += 1 + continue + files_to_download.append(file) + + batch_indexed, failed = await _download_and_index( + onedrive_client, session, files_to_download, + connector_id=connector_id, search_space_id=search_space_id, + user_id=user_id, enable_summary=enable_summary, + on_heartbeat=on_heartbeat_callback, + ) + + indexed = renamed_count + batch_indexed + logger.info(f"Full scan complete: {indexed} indexed, {skipped} skipped, {failed} failed") + return indexed, skipped + + +async def _index_with_delta_sync( + onedrive_client: OneDriveClient, + session: AsyncSession, + connector_id: int, + search_space_id: int, + user_id: str, + folder_id: str | None, + delta_link: str, + task_logger: TaskLoggingService, + log_entry: object, + max_files: int, + on_heartbeat_callback: HeartbeatCallbackType | None = None, + enable_summary: bool = True, +) -> tuple[int, int, str | None]: + """Delta sync using OneDrive change tracking. Returns (indexed, skipped, new_delta_link).""" + await task_logger.log_task_progress( + log_entry, "Starting delta sync", + {"stage": "delta_sync"}, + ) + + changes, new_delta_link, error = await onedrive_client.get_delta( + folder_id=folder_id, delta_link=delta_link + ) + if error: + err_lower = error.lower() + if "401" in error or "authentication expired" in err_lower: + raise Exception(f"OneDrive authentication failed. Please re-authenticate. (Error: {error})") + raise Exception(f"Failed to fetch OneDrive changes: {error}") + + if not changes: + logger.info("No changes detected since last sync") + return 0, 0, new_delta_link + + logger.info(f"Processing {len(changes)} delta changes") + + renamed_count = 0 + skipped = 0 + files_to_download: list[dict] = [] + files_processed = 0 + + for change in changes: + if files_processed >= max_files: + break + files_processed += 1 + + if change.get("deleted"): + fid = change.get("id") + if fid: + await _remove_document(session, fid, search_space_id) + continue + + if "folder" in change: + continue + + if not change.get("file"): + continue + + skip, msg = await _should_skip_file(session, change, search_space_id) + if skip: + if msg and "renamed" in msg.lower(): + renamed_count += 1 + else: + skipped += 1 + continue + + files_to_download.append(change) + + batch_indexed, failed = await _download_and_index( + onedrive_client, session, files_to_download, + connector_id=connector_id, search_space_id=search_space_id, + user_id=user_id, enable_summary=enable_summary, + on_heartbeat=on_heartbeat_callback, + ) + + indexed = renamed_count + batch_indexed + logger.info(f"Delta sync complete: {indexed} indexed, {skipped} skipped, {failed} failed") + return indexed, skipped, new_delta_link + + +# --------------------------------------------------------------------------- +# Public entry point +# --------------------------------------------------------------------------- + +async def index_onedrive_files( + session: AsyncSession, + connector_id: int, + search_space_id: int, + user_id: str, + items_dict: dict, +) -> tuple[int, int, str | None]: + """Index OneDrive files for a specific connector. + + items_dict format: + { + "folders": [{"id": "...", "name": "..."}, ...], + "files": [{"id": "...", "name": "..."}, ...], + "indexing_options": {"max_files": 500, "include_subfolders": true, "use_delta_sync": true} + } + """ + task_logger = TaskLoggingService(session, search_space_id) + log_entry = await task_logger.log_task_start( + task_name="onedrive_files_indexing", + source="connector_indexing_task", + message=f"Starting OneDrive indexing for connector {connector_id}", + metadata={"connector_id": connector_id, "user_id": str(user_id)}, + ) + + try: + connector = await get_connector_by_id( + session, connector_id, SearchSourceConnectorType.ONEDRIVE_CONNECTOR + ) + if not connector: + error_msg = f"OneDrive connector with ID {connector_id} not found" + await task_logger.log_task_failure(log_entry, error_msg, None, {"error_type": "ConnectorNotFound"}) + return 0, 0, error_msg + + token_encrypted = connector.config.get("_token_encrypted", False) + if token_encrypted and not config.SECRET_KEY: + error_msg = "SECRET_KEY not configured but credentials are encrypted" + await task_logger.log_task_failure(log_entry, error_msg, "Missing SECRET_KEY", {"error_type": "MissingSecretKey"}) + return 0, 0, error_msg + + connector_enable_summary = getattr(connector, "enable_summary", True) + onedrive_client = OneDriveClient(session, connector_id) + + indexing_options = items_dict.get("indexing_options", {}) + max_files = indexing_options.get("max_files", 500) + include_subfolders = indexing_options.get("include_subfolders", True) + use_delta_sync = indexing_options.get("use_delta_sync", True) + + total_indexed = 0 + total_skipped = 0 + + # Index selected individual files + selected_files = items_dict.get("files", []) + if selected_files: + file_tuples = [(f["id"], f.get("name")) for f in selected_files] + indexed, skipped, errors = await _index_selected_files( + onedrive_client, session, file_tuples, + connector_id=connector_id, search_space_id=search_space_id, + user_id=user_id, enable_summary=connector_enable_summary, + ) + total_indexed += indexed + total_skipped += skipped + + # Index selected folders + folders = items_dict.get("folders", []) + for folder in folders: + folder_id = folder.get("id", "root") + folder_name = folder.get("name", "Root") + + folder_delta_links = connector.config.get("folder_delta_links", {}) + delta_link = folder_delta_links.get(folder_id) + can_use_delta = use_delta_sync and delta_link and connector.last_indexed_at + + if can_use_delta: + logger.info(f"Using delta sync for folder {folder_name}") + indexed, skipped, new_delta_link = await _index_with_delta_sync( + onedrive_client, session, connector_id, search_space_id, user_id, + folder_id, delta_link, task_logger, log_entry, max_files, + enable_summary=connector_enable_summary, + ) + total_indexed += indexed + total_skipped += skipped + + if new_delta_link: + await session.refresh(connector) + if "folder_delta_links" not in connector.config: + connector.config["folder_delta_links"] = {} + connector.config["folder_delta_links"][folder_id] = new_delta_link + flag_modified(connector, "config") + + # Reconciliation full scan + ri, rs = await _index_full_scan( + onedrive_client, session, connector_id, search_space_id, user_id, + folder_id, folder_name, task_logger, log_entry, max_files, + include_subfolders, enable_summary=connector_enable_summary, + ) + total_indexed += ri + total_skipped += rs + else: + logger.info(f"Using full scan for folder {folder_name}") + indexed, skipped = await _index_full_scan( + onedrive_client, session, connector_id, search_space_id, user_id, + folder_id, folder_name, task_logger, log_entry, max_files, + include_subfolders, enable_summary=connector_enable_summary, + ) + total_indexed += indexed + total_skipped += skipped + + # Store new delta link for this folder + _, new_delta_link, _ = await onedrive_client.get_delta(folder_id=folder_id) + if new_delta_link: + await session.refresh(connector) + if "folder_delta_links" not in connector.config: + connector.config["folder_delta_links"] = {} + connector.config["folder_delta_links"][folder_id] = new_delta_link + flag_modified(connector, "config") + + if total_indexed > 0 or folders: + await update_connector_last_indexed(session, connector, True) + + await session.commit() + + await task_logger.log_task_success( + log_entry, + f"Successfully completed OneDrive indexing for connector {connector_id}", + {"files_processed": total_indexed, "files_skipped": total_skipped}, + ) + logger.info(f"OneDrive indexing completed: {total_indexed} indexed, {total_skipped} skipped") + return total_indexed, total_skipped, None + + except SQLAlchemyError as db_error: + await session.rollback() + await task_logger.log_task_failure( + log_entry, f"Database error during OneDrive indexing for connector {connector_id}", + str(db_error), {"error_type": "SQLAlchemyError"}, + ) + logger.error(f"Database error: {db_error!s}", exc_info=True) + return 0, 0, f"Database error: {db_error!s}" + except Exception as e: + await session.rollback() + await task_logger.log_task_failure( + log_entry, f"Failed to index OneDrive files for connector {connector_id}", + str(e), {"error_type": type(e).__name__}, + ) + logger.error(f"Failed to index OneDrive files: {e!s}", exc_info=True) + return 0, 0, f"Failed to index OneDrive files: {e!s}" diff --git a/surfsense_backend/app/utils/connector_naming.py b/surfsense_backend/app/utils/connector_naming.py index 9fdec3e79..7c72e0781 100644 --- a/surfsense_backend/app/utils/connector_naming.py +++ b/surfsense_backend/app/utils/connector_naming.py @@ -21,6 +21,7 @@ BASE_NAME_FOR_TYPE = { SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR: "Google Calendar", SearchSourceConnectorType.SLACK_CONNECTOR: "Slack", SearchSourceConnectorType.TEAMS_CONNECTOR: "Microsoft Teams", + SearchSourceConnectorType.ONEDRIVE_CONNECTOR: "OneDrive", SearchSourceConnectorType.NOTION_CONNECTOR: "Notion", SearchSourceConnectorType.LINEAR_CONNECTOR: "Linear", SearchSourceConnectorType.JIRA_CONNECTOR: "Jira", @@ -61,6 +62,9 @@ def extract_identifier_from_credentials( if connector_type == SearchSourceConnectorType.TEAMS_CONNECTOR: return credentials.get("tenant_name") + if connector_type == SearchSourceConnectorType.ONEDRIVE_CONNECTOR: + return credentials.get("user_email") + if connector_type == SearchSourceConnectorType.NOTION_CONNECTOR: return credentials.get("workspace_name")