mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-04-25 00:36:31 +02:00
feat: implement Microsoft OneDrive connector with OAuth support and indexing capabilities
This commit is contained in:
parent
64be61b627
commit
5bddde60cb
16 changed files with 2014 additions and 0 deletions
|
|
@ -0,0 +1,54 @@
|
|||
"""Add OneDrive connector enums
|
||||
|
||||
Revision ID: 110
|
||||
Revises: 109
|
||||
Create Date: 2026-03-28 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from alembic import op
|
||||
|
||||
# Alembic revision identifiers: this migration ("110") is applied on top of
# revision "109"; it declares no branch labels and no extra dependencies.
revision: str = "110"
down_revision: str | None = "109"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Register the OneDrive labels on the connector and document enum types.

    Each statement first looks the label up in ``pg_enum`` and only issues
    ``ALTER TYPE ... ADD VALUE`` when the label is missing, so running the
    migration against a database that already has the value does not fail.
    """
    add_connector_type = """
        DO $$
        BEGIN
            IF NOT EXISTS (
                SELECT 1 FROM pg_type t
                JOIN pg_enum e ON t.oid = e.enumtypid
                WHERE t.typname = 'searchsourceconnectortype' AND e.enumlabel = 'ONEDRIVE_CONNECTOR'
            ) THEN
                ALTER TYPE searchsourceconnectortype ADD VALUE 'ONEDRIVE_CONNECTOR';
            END IF;
        END
        $$;
        """
    add_document_type = """
        DO $$
        BEGIN
            IF NOT EXISTS (
                SELECT 1 FROM pg_type t
                JOIN pg_enum e ON t.oid = e.enumtypid
                WHERE t.typname = 'documenttype' AND e.enumlabel = 'ONEDRIVE_FILE'
            ) THEN
                ALTER TYPE documenttype ADD VALUE 'ONEDRIVE_FILE';
            END IF;
        END
        $$;
        """
    # Connector enum first, then the document enum -- same order as before.
    for statement in (add_connector_type, add_document_type):
        op.execute(statement)
|
||||
|
||||
|
||||
def downgrade() -> None:
    # Intentionally a no-op: the enum labels added by upgrade() are left in place.
    pass
|
||||
|
|
@ -360,6 +360,7 @@ _INTERNAL_METADATA_KEYS: frozenset[str] = frozenset(
|
|||
"event_id",
|
||||
"calendar_id",
|
||||
"google_drive_file_id",
|
||||
"onedrive_file_id",
|
||||
"page_id",
|
||||
"issue_id",
|
||||
"connector_id",
|
||||
|
|
|
|||
|
|
@ -286,6 +286,11 @@ class Config:
|
|||
TEAMS_CLIENT_SECRET = os.getenv("TEAMS_CLIENT_SECRET")
|
||||
TEAMS_REDIRECT_URI = os.getenv("TEAMS_REDIRECT_URI")
|
||||
|
||||
# Microsoft OneDrive OAuth
|
||||
ONEDRIVE_CLIENT_ID = os.getenv("ONEDRIVE_CLIENT_ID")
|
||||
ONEDRIVE_CLIENT_SECRET = os.getenv("ONEDRIVE_CLIENT_SECRET")
|
||||
ONEDRIVE_REDIRECT_URI = os.getenv("ONEDRIVE_REDIRECT_URI")
|
||||
|
||||
# ClickUp OAuth
|
||||
CLICKUP_CLIENT_ID = os.getenv("CLICKUP_CLIENT_ID")
|
||||
CLICKUP_CLIENT_SECRET = os.getenv("CLICKUP_CLIENT_SECRET")
|
||||
|
|
|
|||
13
surfsense_backend/app/connectors/onedrive/__init__.py
Normal file
13
surfsense_backend/app/connectors/onedrive/__init__.py
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
"""Microsoft OneDrive Connector Module."""
|
||||
|
||||
from .client import OneDriveClient
|
||||
from .content_extractor import download_and_extract_content
|
||||
from .folder_manager import get_file_by_id, get_files_in_folder, list_folder_contents
|
||||
|
||||
__all__ = [
|
||||
"OneDriveClient",
|
||||
"download_and_extract_content",
|
||||
"get_file_by_id",
|
||||
"get_files_in_folder",
|
||||
"list_folder_contents",
|
||||
]
|
||||
276
surfsense_backend/app/connectors/onedrive/client.py
Normal file
276
surfsense_backend/app/connectors/onedrive/client.py
Normal file
|
|
@ -0,0 +1,276 @@
|
|||
"""Microsoft OneDrive API client using Microsoft Graph API v1.0."""
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
from app.config import config
|
||||
from app.db import SearchSourceConnector
|
||||
from app.utils.oauth_security import TokenEncryption
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Base URL that every Graph request path built in this module is appended to.
GRAPH_API_BASE = "https://graph.microsoft.com/v1.0"
# OAuth 2.0 token endpoint that refresh-token exchanges are POSTed to.
TOKEN_URL = "https://login.microsoftonline.com/common/oauth2/v2.0/token"
|
||||
|
||||
|
||||
class OneDriveClient:
    """Client for Microsoft OneDrive via the Graph API.

    Access and refresh tokens live in the ``config`` JSON of the
    ``SearchSourceConnector`` row identified by ``connector_id``. The client
    reloads that row whenever it needs a token and writes refreshed tokens
    (and the ``auth_expired`` flag) back to it.
    """

    def __init__(self, session: AsyncSession, connector_id: int):
        """Keep the DB session and the id of the connector row holding the tokens."""
        self._session = session
        self._connector_id = connector_id

    async def _get_valid_token(self) -> str:
        """Get a valid access token, refreshing if needed.

        Returns:
            The plaintext access token to send as a bearer credential.

        Raises:
            ValueError: If the connector row does not exist, if the stored
                token is expired or absent and no refresh token is stored, or
                if the refresh request is rejected.
        """
        result = await self._session.execute(
            select(SearchSourceConnector).filter(
                SearchSourceConnector.id == self._connector_id
            )
        )
        connector = result.scalars().first()
        if not connector:
            raise ValueError(f"Connector {self._connector_id} not found")

        cfg = connector.config or {}
        is_encrypted = cfg.get("_token_encrypted", False)
        token_encryption = (
            TokenEncryption(config.SECRET_KEY) if config.SECRET_KEY else None
        )

        access_token = cfg.get("access_token", "")
        refresh_token = cfg.get("refresh_token")

        # Stored tokens are only decrypted when they are flagged as encrypted
        # and a SECRET_KEY is available to decrypt them with.
        if is_encrypted and token_encryption:
            if access_token:
                access_token = token_encryption.decrypt_token(access_token)
            if refresh_token:
                refresh_token = token_encryption.decrypt_token(refresh_token)

        # A missing expiry timestamp is treated as "not expired".
        expires_at_str = cfg.get("expires_at")
        is_expired = False
        if expires_at_str:
            expires_at = datetime.fromisoformat(expires_at_str)
            if expires_at.tzinfo is None:
                # Naive timestamps are interpreted as UTC before comparing.
                expires_at = expires_at.replace(tzinfo=UTC)
            is_expired = expires_at <= datetime.now(UTC)

        if not is_expired and access_token:
            return access_token

        if not refresh_token:
            # Nothing to refresh with: persist the flag so it survives this call.
            cfg["auth_expired"] = True
            connector.config = cfg
            flag_modified(connector, "config")
            await self._session.commit()
            raise ValueError("OneDrive token expired and no refresh token available")

        token_data = await self._refresh_token(refresh_token)

        new_access = token_data["access_token"]
        # Fall back to the current refresh token when the response omits one.
        new_refresh = token_data.get("refresh_token", refresh_token)
        expires_in = token_data.get("expires_in")

        new_expires_at = None
        if expires_in:
            new_expires_at = datetime.now(UTC) + timedelta(seconds=int(expires_in))

        if token_encryption:
            cfg["access_token"] = token_encryption.encrypt_token(new_access)
            cfg["refresh_token"] = token_encryption.encrypt_token(new_refresh)
        else:
            cfg["access_token"] = new_access
            cfg["refresh_token"] = new_refresh

        cfg["expires_at"] = new_expires_at.isoformat() if new_expires_at else None
        cfg["expires_in"] = expires_in
        cfg["_token_encrypted"] = bool(token_encryption)
        cfg.pop("auth_expired", None)

        connector.config = cfg
        # The config dict is mutated in place, so mark the attribute dirty explicitly.
        flag_modified(connector, "config")
        await self._session.commit()

        return new_access

    async def _refresh_token(self, refresh_token: str) -> dict:
        """Exchange ``refresh_token`` for a new token set at ``TOKEN_URL``.

        Returns:
            The decoded JSON body of the token response.

        Raises:
            ValueError: If the endpoint answers with anything other than 200.
                The message carries ``error_description`` from the response
                when available, otherwise the raw response text.
        """
        data = {
            "client_id": config.ONEDRIVE_CLIENT_ID,
            "client_secret": config.ONEDRIVE_CLIENT_SECRET,
            "grant_type": "refresh_token",
            "refresh_token": refresh_token,
            "scope": "offline_access User.Read Files.Read.All Files.ReadWrite.All",
        }
        async with httpx.AsyncClient() as client:
            resp = await client.post(
                TOKEN_URL,
                data=data,
                headers={"Content-Type": "application/x-www-form-urlencoded"},
                timeout=30.0,
            )
        if resp.status_code != 200:
            error_detail = resp.text
            try:
                error_detail = resp.json().get("error_description", error_detail)
            except (ValueError, AttributeError):
                # Body is not JSON, or not a JSON object: keep the raw text.
                pass
            raise ValueError(f"OneDrive token refresh failed: {error_detail}")
        return resp.json()

    async def _request(self, method: str, path: str, **kwargs) -> httpx.Response:
        """Make an authenticated request to the Graph API.

        Args:
            method: HTTP method.
            path: Path (optionally with a query string) appended to
                ``GRAPH_API_BASE``.
            **kwargs: Forwarded to ``httpx.AsyncClient.request``; a ``headers``
                entry is merged over the Authorization header.

        Returns:
            The response for any status other than 401.

        Raises:
            ValueError: On a 401 response, after ``auth_expired`` has been
                stored on the connector row (when the row still exists).
        """
        token = await self._get_valid_token()
        headers = {"Authorization": f"Bearer {token}"}
        if "headers" in kwargs:
            headers.update(kwargs.pop("headers"))

        async with httpx.AsyncClient() as client:
            resp = await client.request(
                method,
                f"{GRAPH_API_BASE}{path}",
                headers=headers,
                timeout=60.0,
                **kwargs,
            )

        if resp.status_code == 401:
            result = await self._session.execute(
                select(SearchSourceConnector).filter(
                    SearchSourceConnector.id == self._connector_id
                )
            )
            connector = result.scalars().first()
            if connector:
                cfg = connector.config or {}
                cfg["auth_expired"] = True
                connector.config = cfg
                flag_modified(connector, "config")
                await self._session.commit()
            raise ValueError("OneDrive authentication expired (401)")

        return resp

    async def list_children(
        self, item_id: str = "root"
    ) -> tuple[list[dict[str, Any]], str | None]:
        """List every child of a drive item, following pagination links.

        Returns:
            ``(items, None)`` on success, or ``([], error_message)`` as soon
            as any page answers with a non-200 status.
        """
        all_items: list[dict[str, Any]] = []
        url = f"/me/drive/items/{item_id}/children"
        params: dict[str, Any] | None = {
            "$top": 200,
            "$select": "id,name,size,file,folder,parentReference,lastModifiedDateTime,createdDateTime,webUrl,remoteItem,package",
        }
        while url:
            resp = await self._request("GET", url, params=params)
            if resp.status_code != 200:
                return [], f"Failed to list children: {resp.status_code} - {resp.text}"
            data = resp.json()
            all_items.extend(data.get("value", []))
            next_link = data.get("@odata.nextLink")
            if next_link:
                url = next_link.replace(GRAPH_API_BASE, "")
                # The next link already carries its full query string. Pass no
                # params at all: an explicit mapping (even an empty one) may
                # replace that query string in some httpx versions.
                params = None
            else:
                url = ""
        return all_items, None

    async def get_item_metadata(
        self, item_id: str
    ) -> tuple[dict[str, Any] | None, str | None]:
        """Fetch selected metadata fields of one drive item.

        Returns:
            ``(item, None)`` on success, or ``(None, error_message)`` on a
            non-200 response.
        """
        resp = await self._request(
            "GET",
            f"/me/drive/items/{item_id}",
            params={
                "$select": "id,name,size,file,folder,parentReference,lastModifiedDateTime,createdDateTime,webUrl"
            },
        )
        if resp.status_code != 200:
            return None, f"Failed to get item: {resp.status_code} - {resp.text}"
        return resp.json(), None

    async def download_file(self, item_id: str) -> tuple[bytes | None, str | None]:
        """Download an item's content into memory.

        Returns:
            ``(content, None)`` on success, or ``(None, error_message)`` on a
            non-200 final response.
        """
        token = await self._get_valid_token()
        # Redirects are followed so the final response holds the file content.
        async with httpx.AsyncClient(follow_redirects=True) as client:
            resp = await client.get(
                f"{GRAPH_API_BASE}/me/drive/items/{item_id}/content",
                headers={"Authorization": f"Bearer {token}"},
                timeout=120.0,
            )
        if resp.status_code != 200:
            return None, f"Download failed: {resp.status_code}"
        return resp.content, None

    async def download_file_to_disk(self, item_id: str, dest_path: str) -> str | None:
        """Stream file content to disk. Returns error message on failure."""
        token = await self._get_valid_token()
        async with httpx.AsyncClient(follow_redirects=True) as client:
            async with client.stream(
                "GET",
                f"{GRAPH_API_BASE}/me/drive/items/{item_id}/content",
                headers={"Authorization": f"Bearer {token}"},
                timeout=120.0,
            ) as resp:
                if resp.status_code != 200:
                    return f"Download failed: {resp.status_code}"
                with open(dest_path, "wb") as f:
                    # Written in 5 MiB chunks so the whole file is never held in memory.
                    async for chunk in resp.aiter_bytes(chunk_size=5 * 1024 * 1024):
                        f.write(chunk)
        return None

    async def create_file(
        self,
        name: str,
        parent_id: str | None = None,
        content: str | None = None,
        mime_type: str | None = None,
    ) -> dict[str, Any]:
        """Create (upload) a file in OneDrive.

        Args:
            name: File name to create under the parent.
            parent_id: Id of the parent item; the drive root when omitted.
            content: Text content, uploaded UTF-8 encoded (empty when omitted).
            mime_type: Content-Type to send; ``application/octet-stream``
                when omitted.

        Returns:
            The decoded JSON body of the upload response.

        Raises:
            ValueError: If the response status is neither 200 nor 201.
        """
        from urllib.parse import quote

        folder_path = f"/me/drive/items/{parent_id or 'root'}"
        body = (content or "").encode("utf-8")
        # The name becomes part of the URL path, so it is percent-encoded:
        # characters such as '#', '?' or '%' would otherwise be parsed as URL
        # syntax and corrupt the request target.
        encoded_name = quote(name, safe="")
        resp = await self._request(
            "PUT",
            f"{folder_path}:/{encoded_name}:/content",
            content=body,
            headers={"Content-Type": mime_type or "application/octet-stream"},
        )
        if resp.status_code not in (200, 201):
            raise ValueError(f"File creation failed: {resp.status_code} - {resp.text}")
        return resp.json()

    async def trash_file(self, item_id: str) -> bool:
        """Delete (move to recycle bin) a OneDrive item.

        Returns:
            ``True`` when the response status is 200 or 204.

        Raises:
            ValueError: For any other response status.
        """
        resp = await self._request("DELETE", f"/me/drive/items/{item_id}")
        if resp.status_code not in (200, 204):
            raise ValueError(f"Trash failed: {resp.status_code} - {resp.text}")
        return True

    async def get_delta(
        self, folder_id: str | None = None, delta_link: str | None = None
    ) -> tuple[list[dict[str, Any]], str | None, str | None]:
        """Get delta changes. Returns (changes, new_delta_link, error).

        Args:
            folder_id: Item to track when no ``delta_link`` is given; the
                drive root when both are omitted.
            delta_link: Delta link from an earlier call; takes precedence
                over ``folder_id``.

        Returns:
            ``(changes, new_delta_link, None)`` on success, or
            ``([], None, error_message)`` as soon as any page answers with a
            non-200 status.
        """
        all_changes: list[dict[str, Any]] = []
        new_delta_link: str | None = None
        params: dict[str, Any] | None
        if delta_link:
            url = delta_link.replace(GRAPH_API_BASE, "")
            # A stored delta link is sent verbatim; it already carries its own
            # query string, which extra params could alter or replace.
            params = None
        else:
            url = f"/me/drive/items/{folder_id}/delta" if folder_id else "/me/drive/root/delta"
            params = {"$top": 200}

        while url:
            resp = await self._request("GET", url, params=params)
            if resp.status_code != 200:
                return [], None, f"Delta failed: {resp.status_code} - {resp.text}"
            data = resp.json()
            all_changes.extend(data.get("value", []))
            next_link = data.get("@odata.nextLink")
            new_delta_link = data.get("@odata.deltaLink")
            if next_link:
                url = next_link.replace(GRAPH_API_BASE, "")
                # Same reasoning as above: follow the next link untouched.
                params = None
            else:
                url = ""
        return all_changes, new_delta_link, None
|
||||
169
surfsense_backend/app/connectors/onedrive/content_extractor.py
Normal file
169
surfsense_backend/app/connectors/onedrive/content_extractor.py
Normal file
|
|
@ -0,0 +1,169 @@
|
|||
"""Content extraction for OneDrive files.
|
||||
|
||||
Reuses the same ETL parsing logic as Google Drive since file parsing is
|
||||
extension-based, not provider-specific.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from .client import OneDriveClient
|
||||
from .file_types import get_extension_from_mime, should_skip_file
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def download_and_extract_content(
    client: OneDriveClient,
    file: dict[str, Any],
) -> tuple[str | None, dict[str, Any], str | None]:
    """Fetch a OneDrive item into a temp file and convert it to markdown.

    Returns a ``(markdown_content, onedrive_metadata, error_message)`` triple.
    Items rejected by ``should_skip_file`` yield ``(None, {}, reason)``; a
    failed download or parse yields ``(None, metadata, error)``. The temp file
    is removed in every case.
    """
    if should_skip_file(file):
        return None, {}, "Skipping non-indexable item"

    item_id = file.get("id")
    file_name = file.get("name", "Unknown")
    file_facet = file.get("file", {})
    mime_type = file_facet.get("mimeType", "")

    logger.info(f"Downloading file for content extraction: {file_name} ({mime_type})")

    metadata: dict[str, Any] = {
        "onedrive_file_id": item_id,
        "onedrive_file_name": file_name,
        "onedrive_mime_type": mime_type,
        "source_connector": "onedrive",
    }

    # Optional Graph fields copied over under our own key names when present.
    optional_fields = (
        ("lastModifiedDateTime", "modified_time"),
        ("createdDateTime", "created_time"),
        ("size", "file_size"),
        ("webUrl", "web_url"),
    )
    for graph_key, metadata_key in optional_fields:
        if graph_key in file:
            metadata[metadata_key] = file[graph_key]

    # Record one content hash, preferring SHA-256 over quickXor.
    hashes = file_facet.get("hashes", {})
    sha256 = hashes.get("sha256Hash")
    quick_xor = hashes.get("quickXorHash")
    if sha256:
        metadata["sha256_hash"] = sha256
    elif quick_xor:
        metadata["quick_xor_hash"] = quick_xor

    temp_file_path = None
    try:
        # Suffix comes from the file name, then the MIME type, then ".bin".
        suffix = Path(file_name).suffix or get_extension_from_mime(mime_type) or ".bin"
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as handle:
            temp_file_path = handle.name

        download_error = await client.download_file_to_disk(item_id, temp_file_path)
        if download_error:
            return None, metadata, download_error

        return await _parse_file_to_markdown(temp_file_path, file_name), metadata, None

    except Exception as e:
        logger.warning(f"Failed to extract content from {file_name}: {e!s}")
        return None, metadata, str(e)
    finally:
        if temp_file_path and os.path.exists(temp_file_path):
            try:
                os.unlink(temp_file_path)
            except Exception:
                # Cleanup is best-effort; a leftover temp file must not mask the result.
                pass
|
||||
|
||||
|
||||
async def _parse_file_to_markdown(file_path: str, filename: str) -> str:
    """Parse a local file to markdown using the configured ETL service.

    Same logic as Google Drive -- file parsing is extension-based.

    Dispatch is driven by the (lower-cased) extension of ``filename``:
    plain-text and markdown files are returned as-is, audio/video files are
    transcribed, and everything else goes to the parser selected by
    ``ETL_SERVICE``.

    Args:
        file_path: Path of the file on local disk.
        filename: Original file name; used for extension checks and messages.

    Returns:
        The markdown text.

    Raises:
        ValueError: If transcription yields empty text.
        RuntimeError: If LlamaCloud returns no documents, or ``ETL_SERVICE``
            is not one of the handled values.
    """
    lower = filename.lower()

    # Text-like files need no conversion; they are read as UTF-8.
    if lower.endswith((".md", ".markdown", ".txt")):
        with open(file_path, encoding="utf-8") as f:
            return f.read()

    if lower.endswith((".mp3", ".mp4", ".mpeg", ".mpga", ".m4a", ".wav", ".webm")):
        # Imported lazily so these modules are only loaded when this branch runs.
        from app.config import config as app_config
        from litellm import atranscription

        # An STT_SERVICE value starting with "local/" selects the in-process service.
        stt_service_type = (
            "local"
            if app_config.STT_SERVICE and app_config.STT_SERVICE.startswith("local/")
            else "external"
        )
        if stt_service_type == "local":
            from app.services.stt_service import stt_service

            t0 = time.monotonic()
            logger.info(f"[local-stt] START file={filename} thread={threading.current_thread().name}")
            # Run in a worker thread so the event loop is not blocked.
            result = await asyncio.to_thread(stt_service.transcribe_file, file_path)
            logger.info(f"[local-stt] END file={filename} elapsed={time.monotonic() - t0:.2f}s")
            text = result.get("text", "")
        else:
            with open(file_path, "rb") as audio_file:
                kwargs: dict[str, Any] = {
                    "model": app_config.STT_SERVICE,
                    "file": audio_file,
                    "api_key": app_config.STT_SERVICE_API_KEY,
                }
                # The API base is only passed when one is configured.
                if app_config.STT_SERVICE_API_BASE:
                    kwargs["api_base"] = app_config.STT_SERVICE_API_BASE
                resp = await atranscription(**kwargs)
            text = resp.get("text", "")

        if not text:
            raise ValueError("Transcription returned empty text")
        return f"# Transcription of {filename}\n\n{text}"

    from app.config import config as app_config

    if app_config.ETL_SERVICE == "UNSTRUCTURED":
        from langchain_unstructured import UnstructuredLoader

        from app.utils.document_converters import convert_document_to_markdown

        loader = UnstructuredLoader(
            file_path,
            mode="elements",
            post_processors=[],
            languages=["eng"],
            include_orig_elements=False,
            include_metadata=False,
            strategy="auto",
        )
        docs = await loader.aload()
        return await convert_document_to_markdown(docs)

    if app_config.ETL_SERVICE == "LLAMACLOUD":
        from app.tasks.document_processors.file_processors import (
            parse_with_llamacloud_retry,
        )

        # NOTE(review): estimated_pages is a fixed 50 regardless of the file -- confirm intent.
        result = await parse_with_llamacloud_retry(file_path=file_path, estimated_pages=50)
        markdown_documents = await result.aget_markdown_documents(split_by_page=False)
        if not markdown_documents:
            raise RuntimeError(f"LlamaCloud returned no documents for {filename}")
        # Only the first returned document is used.
        return markdown_documents[0].text

    if app_config.ETL_SERVICE == "DOCLING":
        from docling.document_converter import DocumentConverter

        converter = DocumentConverter()
        t0 = time.monotonic()
        logger.info(f"[docling] START file={filename} thread={threading.current_thread().name}")
        # Run in a worker thread so the event loop is not blocked.
        result = await asyncio.to_thread(converter.convert, file_path)
        logger.info(f"[docling] END file={filename} elapsed={time.monotonic() - t0:.2f}s")
        return result.document.export_to_markdown()

    raise RuntimeError(f"Unknown ETL_SERVICE: {app_config.ETL_SERVICE}")
|
||||
50
surfsense_backend/app/connectors/onedrive/file_types.py
Normal file
50
surfsense_backend/app/connectors/onedrive/file_types.py
Normal file
|
|
@ -0,0 +1,50 @@
|
|||
"""File type handlers for Microsoft OneDrive."""
|
||||
|
||||
ONEDRIVE_FOLDER_FACET = "folder"
|
||||
ONENOTE_MIME = "application/msonenote"
|
||||
|
||||
SKIP_MIME_TYPES = frozenset(
|
||||
{
|
||||
ONENOTE_MIME,
|
||||
"application/vnd.ms-onenotesection",
|
||||
"application/vnd.ms-onenotenotebook",
|
||||
}
|
||||
)
|
||||
|
||||
MIME_TO_EXTENSION: dict[str, str] = {
|
||||
"application/pdf": ".pdf",
|
||||
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": ".xlsx",
|
||||
"application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx",
|
||||
"application/vnd.openxmlformats-officedocument.presentationml.presentation": ".pptx",
|
||||
"application/vnd.ms-excel": ".xls",
|
||||
"application/msword": ".doc",
|
||||
"application/vnd.ms-powerpoint": ".ppt",
|
||||
"text/plain": ".txt",
|
||||
"text/csv": ".csv",
|
||||
"text/html": ".html",
|
||||
"text/markdown": ".md",
|
||||
"application/json": ".json",
|
||||
"application/xml": ".xml",
|
||||
"image/png": ".png",
|
||||
"image/jpeg": ".jpg",
|
||||
}
|
||||
|
||||
|
||||
def get_extension_from_mime(mime_type: str) -> str | None:
|
||||
return MIME_TO_EXTENSION.get(mime_type)
|
||||
|
||||
|
||||
def is_folder(item: dict) -> bool:
|
||||
return ONEDRIVE_FOLDER_FACET in item
|
||||
|
||||
|
||||
def should_skip_file(item: dict) -> bool:
|
||||
"""Skip folders, OneNote files, remote items (shared links), and packages."""
|
||||
if is_folder(item):
|
||||
return True
|
||||
if "remoteItem" in item:
|
||||
return True
|
||||
if "package" in item:
|
||||
return True
|
||||
mime = item.get("file", {}).get("mimeType", "")
|
||||
return mime in SKIP_MIME_TYPES
|
||||
90
surfsense_backend/app/connectors/onedrive/folder_manager.py
Normal file
90
surfsense_backend/app/connectors/onedrive/folder_manager.py
Normal file
|
|
@ -0,0 +1,90 @@
|
|||
"""Folder management for Microsoft OneDrive."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from .client import OneDriveClient
|
||||
from .file_types import is_folder, should_skip_file
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def list_folder_contents(
    client: OneDriveClient,
    parent_id: str | None = None,
) -> tuple[list[dict[str, Any]], str | None]:
    """Return the children of a OneDrive folder, folders ahead of files.

    Every returned item gains an ``isFolder`` flag. Within each group the
    items are ordered by case-insensitive name. When ``parent_id`` is omitted
    the drive root is listed.

    Returns:
        ``(items, None)`` on success, ``([], error_message)`` otherwise.
    """
    try:
        children, listing_error = await client.list_children(parent_id or "root")
        if listing_error:
            return [], listing_error

        folder_count = 0
        for entry in children:
            entry["isFolder"] = is_folder(entry)
            if entry["isFolder"]:
                folder_count += 1

        # False sorts before True, so negating the flag puts folders first.
        children.sort(
            key=lambda entry: (not entry["isFolder"], entry.get("name", "").lower())
        )

        file_count = len(children) - folder_count
        location = f"in folder {parent_id}" if parent_id else "in root"
        logger.info(
            f"Listed {len(children)} items ({folder_count} folders, {file_count} files) "
            + location
        )
        return children, None

    except Exception as e:
        logger.error(f"Error listing folder contents: {e!s}", exc_info=True)
        return [], f"Error listing folder contents: {e!s}"
|
||||
|
||||
|
||||
async def get_files_in_folder(
    client: OneDriveClient,
    folder_id: str,
    include_subfolders: bool = True,
) -> tuple[list[dict[str, Any]], str | None]:
    """Collect the indexable files under a folder.

    Subfolders are walked recursively when ``include_subfolders`` is true and
    ignored otherwise. A subfolder that fails to list is logged and skipped
    rather than aborting the whole walk.

    Returns:
        ``(files, None)`` on success, ``([], error_message)`` when the folder
        itself cannot be listed or an unexpected exception occurs.
    """
    try:
        children, listing_error = await client.list_children(folder_id)
        if listing_error:
            return [], listing_error

        collected: list[dict[str, Any]] = []
        for entry in children:
            if not is_folder(entry):
                if not should_skip_file(entry):
                    collected.append(entry)
                continue

            if not include_subfolders:
                continue

            nested_files, nested_error = await get_files_in_folder(
                client, entry["id"], include_subfolders=True
            )
            if nested_error:
                logger.warning(f"Error recursing into folder {entry.get('name')}: {nested_error}")
                continue
            collected.extend(nested_files)

        return collected, None

    except Exception as e:
        logger.error(f"Error getting files in folder: {e!s}", exc_info=True)
        return [], f"Error getting files in folder: {e!s}"
|
||||
|
||||
|
||||
async def get_file_by_id(
    client: OneDriveClient,
    file_id: str,
) -> tuple[dict[str, Any] | None, str | None]:
    """Look up a single item's metadata by its id.

    Returns:
        ``(item, None)`` when found, otherwise ``(None, error_message)`` --
        including when the lookup raises.
    """
    try:
        metadata, lookup_error = await client.get_item_metadata(file_id)
    except Exception as e:
        logger.error(f"Error getting file by ID: {e!s}", exc_info=True)
        return None, f"Error getting file by ID: {e!s}"

    if lookup_error:
        return None, lookup_error
    if not metadata:
        return None, f"File not found: {file_id}"
    return metadata, None
|
||||
|
|
@ -40,6 +40,7 @@ class DocumentType(StrEnum):
|
|||
FILE = "FILE"
|
||||
SLACK_CONNECTOR = "SLACK_CONNECTOR"
|
||||
TEAMS_CONNECTOR = "TEAMS_CONNECTOR"
|
||||
ONEDRIVE_FILE = "ONEDRIVE_FILE"
|
||||
NOTION_CONNECTOR = "NOTION_CONNECTOR"
|
||||
YOUTUBE_VIDEO = "YOUTUBE_VIDEO"
|
||||
GITHUB_CONNECTOR = "GITHUB_CONNECTOR"
|
||||
|
|
@ -81,6 +82,7 @@ class SearchSourceConnectorType(StrEnum):
|
|||
BAIDU_SEARCH_API = "BAIDU_SEARCH_API" # Baidu AI Search API for Chinese web search
|
||||
SLACK_CONNECTOR = "SLACK_CONNECTOR"
|
||||
TEAMS_CONNECTOR = "TEAMS_CONNECTOR"
|
||||
ONEDRIVE_CONNECTOR = "ONEDRIVE_CONNECTOR"
|
||||
NOTION_CONNECTOR = "NOTION_CONNECTOR"
|
||||
GITHUB_CONNECTOR = "GITHUB_CONNECTOR"
|
||||
LINEAR_CONNECTOR = "LINEAR_CONNECTOR"
|
||||
|
|
|
|||
|
|
@ -43,6 +43,7 @@ from .search_spaces_routes import router as search_spaces_router
|
|||
from .slack_add_connector_route import router as slack_add_connector_router
|
||||
from .surfsense_docs_routes import router as surfsense_docs_router
|
||||
from .teams_add_connector_route import router as teams_add_connector_router
|
||||
from .onedrive_add_connector_route import router as onedrive_add_connector_router
|
||||
from .video_presentations_routes import router as video_presentations_router
|
||||
from .youtube_routes import router as youtube_router
|
||||
|
||||
|
|
@ -73,6 +74,7 @@ router.include_router(luma_add_connector_router)
|
|||
router.include_router(notion_add_connector_router)
|
||||
router.include_router(slack_add_connector_router)
|
||||
router.include_router(teams_add_connector_router)
|
||||
router.include_router(onedrive_add_connector_router)
|
||||
router.include_router(discord_add_connector_router)
|
||||
router.include_router(jira_add_connector_router)
|
||||
router.include_router(confluence_add_connector_router)
|
||||
|
|
|
|||
474
surfsense_backend/app/routes/onedrive_add_connector_route.py
Normal file
474
surfsense_backend/app/routes/onedrive_add_connector_route.py
Normal file
|
|
@ -0,0 +1,474 @@
|
|||
"""
|
||||
Microsoft OneDrive Connector OAuth Routes.
|
||||
|
||||
Endpoints:
|
||||
- GET /auth/onedrive/connector/add - Initiate OAuth
|
||||
- GET /auth/onedrive/connector/callback - Handle OAuth callback
|
||||
- GET /auth/onedrive/connector/reauth - Re-authenticate existing connector
|
||||
- GET /connectors/{connector_id}/onedrive/folders - List folder contents
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from urllib.parse import urlencode
|
||||
from uuid import UUID
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import RedirectResponse
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
from app.config import config
|
||||
from app.connectors.onedrive import OneDriveClient, list_folder_contents
|
||||
from app.db import (
|
||||
SearchSourceConnector,
|
||||
SearchSourceConnectorType,
|
||||
User,
|
||||
get_async_session,
|
||||
)
|
||||
from app.users import current_active_user
|
||||
from app.utils.connector_naming import (
|
||||
check_duplicate_connector,
|
||||
extract_identifier_from_credentials,
|
||||
generate_unique_connector_name,
|
||||
)
|
||||
from app.utils.oauth_security import OAuthStateManager, TokenEncryption
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
# OAuth 2.0 (v2.0) authorize and token endpoints on the "common" tenant.
AUTHORIZATION_URL = "https://login.microsoftonline.com/common/oauth2/v2.0/authorize"
TOKEN_URL = "https://login.microsoftonline.com/common/oauth2/v2.0/token"

# Scopes requested at authorization time; joined with spaces into the
# "scope" query parameter.
SCOPES = [
    "offline_access",
    "User.Read",
    "Files.Read.All",
    "Files.ReadWrite.All",
]

# Module-level singletons, created lazily by get_state_manager() and
# get_token_encryption() on first use.
_state_manager = None
_token_encryption = None
|
||||
|
||||
|
||||
def get_state_manager() -> OAuthStateManager:
    """Return the shared ``OAuthStateManager``, building it on first use.

    Raises:
        ValueError: If no instance exists yet and ``SECRET_KEY`` is not set.
    """
    global _state_manager
    if _state_manager is not None:
        return _state_manager
    if not config.SECRET_KEY:
        raise ValueError("SECRET_KEY must be set for OAuth security")
    _state_manager = OAuthStateManager(config.SECRET_KEY)
    return _state_manager
|
||||
|
||||
|
||||
def get_token_encryption() -> TokenEncryption:
    """Return the shared ``TokenEncryption``, building it on first use.

    Raises:
        ValueError: If no instance exists yet and ``SECRET_KEY`` is not set.
    """
    global _token_encryption
    if _token_encryption is not None:
        return _token_encryption
    if not config.SECRET_KEY:
        raise ValueError("SECRET_KEY must be set for token encryption")
    _token_encryption = TokenEncryption(config.SECRET_KEY)
    return _token_encryption
|
||||
|
||||
|
||||
@router.get("/auth/onedrive/connector/add")
async def connect_onedrive(space_id: int, user: User = Depends(current_active_user)):
    """Start the OneDrive OAuth flow for a search space.

    Returns:
        ``{"auth_url": ...}`` -- the authorization URL to open, carrying a
        state value generated for this user and space.

    Raises:
        HTTPException: 400 when ``space_id`` is falsy; 500 when the OneDrive
            client id or ``SECRET_KEY`` is not configured, or on any
            unexpected failure.
    """
    try:
        # (value that must be truthy, status to answer with, detail message)
        preconditions = (
            (space_id, 400, "space_id is required"),
            (config.ONEDRIVE_CLIENT_ID, 500, "Microsoft OneDrive OAuth not configured."),
            (config.SECRET_KEY, 500, "SECRET_KEY not configured for OAuth security."),
        )
        for required_value, status_code, detail in preconditions:
            if not required_value:
                raise HTTPException(status_code=status_code, detail=detail)

        signed_state = get_state_manager().generate_secure_state(space_id, user.id)
        query_string = urlencode(
            {
                "client_id": config.ONEDRIVE_CLIENT_ID,
                "response_type": "code",
                "redirect_uri": config.ONEDRIVE_REDIRECT_URI,
                "response_mode": "query",
                "scope": " ".join(SCOPES),
                "state": signed_state,
            }
        )

        logger.info("Generated OneDrive OAuth URL for user %s, space %s", user.id, space_id)
        return {"auth_url": f"{AUTHORIZATION_URL}?{query_string}"}

    except HTTPException:
        # Deliberate HTTP errors pass through untouched.
        raise
    except Exception as e:
        logger.error("Failed to initiate OneDrive OAuth: %s", str(e), exc_info=True)
        raise HTTPException(status_code=500, detail=f"Failed to initiate OneDrive OAuth: {e!s}") from e
|
||||
|
||||
|
||||
@router.get("/auth/onedrive/connector/reauth")
async def reauth_onedrive(
    space_id: int,
    connector_id: int,
    return_url: str | None = None,
    user: User = Depends(current_active_user),
    session: AsyncSession = Depends(get_async_session),
):
    """Re-authenticate an existing OneDrive connector.

    Verifies that the connector belongs to the current user and search space,
    then returns a Microsoft consent URL (with ``prompt=consent``) whose state
    carries ``connector_id`` and, when it is a relative path, ``return_url``.

    Returns:
        A dict of the form ``{"auth_url": <consent URL>}``.

    Raises:
        HTTPException: 404 when no matching OneDrive connector exists; 500
            when the OneDrive client ID or ``SECRET_KEY`` is not configured,
            or when building the URL fails unexpectedly.
    """
    try:
        result = await session.execute(
            select(SearchSourceConnector).filter(
                SearchSourceConnector.id == connector_id,
                SearchSourceConnector.user_id == user.id,
                SearchSourceConnector.search_space_id == space_id,
                SearchSourceConnector.connector_type == SearchSourceConnectorType.ONEDRIVE_CONNECTOR,
            )
        )
        connector = result.scalars().first()
        if not connector:
            raise HTTPException(status_code=404, detail="OneDrive connector not found or access denied")

        # Same guard as connect_onedrive: without it urlencode() would emit
        # the literal "client_id=None" and hand the user a broken consent URL.
        if not config.ONEDRIVE_CLIENT_ID:
            raise HTTPException(status_code=500, detail="Microsoft OneDrive OAuth not configured.")
        if not config.SECRET_KEY:
            raise HTTPException(status_code=500, detail="SECRET_KEY not configured for OAuth security.")

        state_manager = get_state_manager()
        extra: dict = {"connector_id": connector_id}
        # Only relative paths are carried through the state; anything else is dropped.
        if return_url and return_url.startswith("/"):
            extra["return_url"] = return_url
        state_encoded = state_manager.generate_secure_state(space_id, user.id, **extra)

        auth_params = {
            "client_id": config.ONEDRIVE_CLIENT_ID,
            "response_type": "code",
            "redirect_uri": config.ONEDRIVE_REDIRECT_URI,
            "response_mode": "query",
            "scope": " ".join(SCOPES),
            "state": state_encoded,
            "prompt": "consent",
        }
        auth_url = f"{AUTHORIZATION_URL}?{urlencode(auth_params)}"

        logger.info("Initiating OneDrive re-auth for user %s, connector %s", user.id, connector_id)
        return {"auth_url": auth_url}

    except HTTPException:
        raise
    except Exception as e:
        logger.error("Failed to initiate OneDrive re-auth: %s", str(e), exc_info=True)
        raise HTTPException(status_code=500, detail=f"Failed to initiate OneDrive re-auth: {e!s}") from e
|
||||
|
||||
|
||||
@router.get("/auth/onedrive/connector/callback")
async def onedrive_callback(
    code: str | None = None,
    error: str | None = None,
    error_description: str | None = None,
    state: str | None = None,
    session: AsyncSession = Depends(get_async_session),
):
    """Handle the redirect back from Microsoft after the OneDrive consent step.

    Validates ``state``, exchanges ``code`` for tokens, stores the encrypted
    tokens and redirects the browser to the frontend. When the state carries
    a ``connector_id`` the existing connector is updated (re-authentication);
    otherwise a new connector is created unless a duplicate is detected.

    Returns:
        A ``RedirectResponse`` to the frontend describing success or failure.

    Raises:
        HTTPException: 400 when ``code``/``state`` are missing, the token
            exchange is rejected, or no access token is returned; 404 when
            the connector to re-authenticate cannot be found.
    """
    try:
        if error:
            error_msg = error_description or error
            logger.warning("OneDrive OAuth error: %s", error_msg)
            space_id = None
            if state:
                try:
                    data = get_state_manager().validate_state(state)
                    space_id = data.get("space_id")
                except Exception:
                    # Best effort only: without a usable state we fall back to
                    # the generic dashboard redirect below.
                    logger.debug("Could not recover space_id from OneDrive OAuth state", exc_info=True)
            if space_id:
                return RedirectResponse(
                    url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=onedrive_oauth_denied"
                )
            return RedirectResponse(url=f"{config.NEXT_FRONTEND_URL}/dashboard?error=onedrive_oauth_denied")

        if not code or not state:
            raise HTTPException(status_code=400, detail="Missing required OAuth parameters")

        state_manager = get_state_manager()
        try:
            data = state_manager.validate_state(state)
            space_id = data["space_id"]
            user_id = UUID(data["user_id"])
        except (HTTPException, ValueError, KeyError) as e:
            logger.error("Invalid OAuth state: %s", str(e))
            return RedirectResponse(url=f"{config.NEXT_FRONTEND_URL}/dashboard?error=invalid_state")

        # Present only when the flow was started by the re-auth endpoint.
        reauth_connector_id = data.get("connector_id")
        reauth_return_url = data.get("return_url")

        token_data = {
            "client_id": config.ONEDRIVE_CLIENT_ID,
            "client_secret": config.ONEDRIVE_CLIENT_SECRET,
            "code": code,
            "redirect_uri": config.ONEDRIVE_REDIRECT_URI,
            "grant_type": "authorization_code",
        }

        async with httpx.AsyncClient() as client:
            token_response = await client.post(
                TOKEN_URL,
                data=token_data,
                headers={"Content-Type": "application/x-www-form-urlencoded"},
                timeout=30.0,
            )

        if token_response.status_code != 200:
            error_detail = token_response.text
            try:
                error_json = token_response.json()
                error_detail = error_json.get("error_description", error_detail)
            except Exception:
                # Non-JSON error body: keep the raw response text.
                pass
            raise HTTPException(status_code=400, detail=f"Token exchange failed: {error_detail}")

        token_json = token_response.json()
        access_token = token_json.get("access_token")
        refresh_token = token_json.get("refresh_token")

        if not access_token:
            raise HTTPException(status_code=400, detail="No access token received from Microsoft")

        token_encryption = get_token_encryption()

        expires_at = None
        if token_json.get("expires_in"):
            expires_at = datetime.now(UTC) + timedelta(seconds=int(token_json["expires_in"]))

        # Profile lookup is optional; a failure only leaves the name/email empty.
        user_info: dict = {}
        try:
            async with httpx.AsyncClient() as client:
                user_response = await client.get(
                    "https://graph.microsoft.com/v1.0/me",
                    headers={"Authorization": f"Bearer {access_token}"},
                    timeout=30.0,
                )
            if user_response.status_code == 200:
                user_data = user_response.json()
                user_info = {
                    "user_email": user_data.get("mail") or user_data.get("userPrincipalName"),
                    "user_name": user_data.get("displayName"),
                }
        except Exception as e:
            logger.warning("Failed to fetch user info from Graph: %s", str(e))

        connector_config = {
            "access_token": token_encryption.encrypt_token(access_token),
            "refresh_token": token_encryption.encrypt_token(refresh_token) if refresh_token else None,
            "token_type": token_json.get("token_type", "Bearer"),
            "expires_in": token_json.get("expires_in"),
            "expires_at": expires_at.isoformat() if expires_at else None,
            "scope": token_json.get("scope"),
            "user_email": user_info.get("user_email"),
            "user_name": user_info.get("user_name"),
            "_token_encrypted": True,
        }

        # Handle re-authentication
        if reauth_connector_id:
            result = await session.execute(
                select(SearchSourceConnector).filter(
                    SearchSourceConnector.id == reauth_connector_id,
                    SearchSourceConnector.user_id == user_id,
                    SearchSourceConnector.search_space_id == space_id,
                    SearchSourceConnector.connector_type == SearchSourceConnectorType.ONEDRIVE_CONNECTOR,
                )
            )
            db_connector = result.scalars().first()
            if not db_connector:
                raise HTTPException(status_code=404, detail="Connector not found or access denied during re-auth")

            # Merge instead of replace: the stored config also holds settings
            # unrelated to the tokens (e.g. selected folders/files and
            # indexing options), which must survive a credential refresh.
            existing_config = db_connector.config or {}
            db_connector.config = {
                **existing_config,
                **connector_config,
                "delta_link": existing_config.get("delta_link"),
                "auth_expired": False,
            }
            flag_modified(db_connector, "config")
            await session.commit()
            await session.refresh(db_connector)

            logger.info("Re-authenticated OneDrive connector %s for user %s", db_connector.id, user_id)
            if reauth_return_url and reauth_return_url.startswith("/"):
                return RedirectResponse(url=f"{config.NEXT_FRONTEND_URL}{reauth_return_url}")
            return RedirectResponse(
                url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?success=true&connector=ONEDRIVE_CONNECTOR&connectorId={db_connector.id}"
            )

        # New connector -- check for duplicates
        connector_identifier = extract_identifier_from_credentials(
            SearchSourceConnectorType.ONEDRIVE_CONNECTOR, connector_config
        )
        is_duplicate = await check_duplicate_connector(
            session, SearchSourceConnectorType.ONEDRIVE_CONNECTOR, space_id, user_id, connector_identifier,
        )
        if is_duplicate:
            logger.warning("Duplicate OneDrive connector for user %s, space %s", user_id, space_id)
            return RedirectResponse(
                url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?error=duplicate_account&connector=ONEDRIVE_CONNECTOR"
            )

        connector_name = await generate_unique_connector_name(
            session, SearchSourceConnectorType.ONEDRIVE_CONNECTOR, space_id, user_id, connector_identifier,
        )

        new_connector = SearchSourceConnector(
            name=connector_name,
            connector_type=SearchSourceConnectorType.ONEDRIVE_CONNECTOR,
            is_indexable=True,
            config=connector_config,
            search_space_id=space_id,
            user_id=user_id,
        )

        try:
            session.add(new_connector)
            await session.commit()
            await session.refresh(new_connector)
            logger.info("Successfully created OneDrive connector %s for user %s", new_connector.id, user_id)
            return RedirectResponse(
                url=f"{config.NEXT_FRONTEND_URL}/dashboard/{space_id}/connectors/callback?success=true&connector=ONEDRIVE_CONNECTOR&connectorId={new_connector.id}"
            )
        except IntegrityError as e:
            await session.rollback()
            logger.error("Database integrity error creating OneDrive connector: %s", str(e))
            return RedirectResponse(url=f"{config.NEXT_FRONTEND_URL}/dashboard?error=connector_creation_failed")

    except HTTPException:
        raise
    except (IntegrityError, ValueError, httpx.HTTPError) as e:
        # httpx.HTTPError covers transport failures (timeouts, connection
        # errors) during the token exchange, so the user is redirected to the
        # frontend error page instead of seeing a bare 500 response.
        logger.error("OneDrive OAuth callback error: %s", str(e), exc_info=True)
        return RedirectResponse(url=f"{config.NEXT_FRONTEND_URL}/dashboard?error=onedrive_auth_error")
|
||||
|
||||
|
||||
def _is_onedrive_auth_failure(message: str) -> bool:
    """Return True when ``message`` looks like an expired/revoked-credentials error."""
    lowered = message.lower()
    return "401" in message or "authentication expired" in lowered or "invalid_grant" in lowered


async def _flag_onedrive_auth_expired(
    session: AsyncSession,
    connector: SearchSourceConnector | None,
    connector_id: int,
) -> None:
    """Best-effort: persist ``auth_expired=True`` in the connector's config.

    Failures are logged and swallowed so the caller can still report the
    authentication problem to the client.
    """
    try:
        if connector and not connector.config.get("auth_expired"):
            connector.config = {**connector.config, "auth_expired": True}
            flag_modified(connector, "config")
            await session.commit()
    except Exception:
        logger.warning("Failed to persist auth_expired for connector %s", connector_id, exc_info=True)


@router.get("/connectors/{connector_id}/onedrive/folders")
async def list_onedrive_folders(
    connector_id: int,
    parent_id: str | None = None,
    session: AsyncSession = Depends(get_async_session),
    user: User = Depends(current_active_user),
):
    """List folders and files in the user's OneDrive.

    Args:
        connector_id: ID of a OneDrive connector owned by the current user.
        parent_id: Forwarded to ``list_folder_contents`` as ``parent_id``.

    Returns:
        A dict of the form ``{"items": [...]}``.

    Raises:
        HTTPException: 404 when the connector is not found for this user;
            400 when the failure looks like expired authentication (the
            connector is then flagged ``auth_expired``); 500 for any other
            failure.
    """
    connector = None
    try:
        result = await session.execute(
            select(SearchSourceConnector).filter(
                SearchSourceConnector.id == connector_id,
                SearchSourceConnector.user_id == user.id,
                SearchSourceConnector.connector_type == SearchSourceConnectorType.ONEDRIVE_CONNECTOR,
            )
        )
        connector = result.scalars().first()
        if not connector:
            raise HTTPException(status_code=404, detail="OneDrive connector not found or access denied")

        onedrive_client = OneDriveClient(session, connector_id)
        items, error = await list_folder_contents(onedrive_client, parent_id=parent_id)

        if error:
            if _is_onedrive_auth_failure(error):
                await _flag_onedrive_auth_expired(session, connector, connector_id)
                raise HTTPException(status_code=400, detail="OneDrive authentication expired. Please re-authenticate.")
            raise HTTPException(status_code=500, detail=f"Failed to list folder contents: {error}")

        return {"items": items}

    except HTTPException:
        raise
    except Exception as e:
        logger.error("Error listing OneDrive contents: %s", str(e), exc_info=True)
        # Same detection as the error-string path above, so an auth failure
        # is reported identically whether it is returned or raised.
        if _is_onedrive_auth_failure(str(e)):
            await _flag_onedrive_auth_expired(session, connector, connector_id)
            raise HTTPException(status_code=400, detail="OneDrive authentication expired. Please re-authenticate.") from e
        raise HTTPException(status_code=500, detail=f"Failed to list OneDrive contents: {e!s}") from e
|
||||
|
||||
|
||||
async def refresh_onedrive_token(
    session: AsyncSession, connector: SearchSourceConnector
) -> SearchSourceConnector:
    """Exchange the connector's refresh token for new OneDrive tokens.

    The stored refresh token is decrypted when the config is flagged
    ``_token_encrypted``; the new tokens are encrypted, written back to
    ``connector.config`` (clearing ``auth_expired``) and committed.

    Returns:
        The refreshed ``connector`` instance.

    Raises:
        HTTPException: 500 when the stored refresh token cannot be decrypted;
            400 when no refresh token is available, the refresh request is
            rejected, or no access token is returned; 401 when the rejection
            mentions ``invalid_grant``, ``expired`` or ``revoked``.
    """
    logger.info("Refreshing OneDrive OAuth tokens for connector %s", connector.id)

    encryptor = get_token_encryption()
    stored_encrypted = connector.config.get("_token_encrypted", False)
    current_refresh = connector.config.get("refresh_token")

    if stored_encrypted and current_refresh:
        try:
            current_refresh = encryptor.decrypt_token(current_refresh)
        except Exception as exc:
            logger.error("Failed to decrypt refresh token: %s", str(exc))
            raise HTTPException(status_code=500, detail="Failed to decrypt stored refresh token") from exc

    if not current_refresh:
        raise HTTPException(status_code=400, detail=f"No refresh token available for connector {connector.id}")

    form_fields = {
        "client_id": config.ONEDRIVE_CLIENT_ID,
        "client_secret": config.ONEDRIVE_CLIENT_SECRET,
        "grant_type": "refresh_token",
        "refresh_token": current_refresh,
        "scope": " ".join(SCOPES),
    }

    async with httpx.AsyncClient() as http:
        response = await http.post(
            TOKEN_URL,
            data=form_fields,
            headers={"Content-Type": "application/x-www-form-urlencoded"},
            timeout=30.0,
        )

    if response.status_code != 200:
        detail_text = response.text
        code_text = ""
        try:
            body = response.json()
            detail_text = body.get("error_description", detail_text)
            code_text = body.get("error", "")
        except Exception:
            # Non-JSON error body: fall back to the raw response text.
            pass
        haystack = (detail_text + code_text).lower()
        # These markers mean the grant itself is no longer usable.
        if any(marker in haystack for marker in ("invalid_grant", "expired", "revoked")):
            raise HTTPException(status_code=401, detail="OneDrive authentication failed. Please re-authenticate.")
        raise HTTPException(status_code=400, detail=f"Token refresh failed: {detail_text}")

    payload = response.json()
    fresh_access = payload.get("access_token")
    rotated_refresh = payload.get("refresh_token")

    if not fresh_access:
        raise HTTPException(status_code=400, detail="No access token received from Microsoft refresh")

    lifetime = payload.get("expires_in")
    expiry = datetime.now(UTC) + timedelta(seconds=int(lifetime)) if lifetime else None

    updated = {**connector.config, "access_token": encryptor.encrypt_token(fresh_access)}
    # Keep the previous refresh token when the response does not rotate it.
    if rotated_refresh:
        updated["refresh_token"] = encryptor.encrypt_token(rotated_refresh)
    updated.update(
        expires_in=lifetime,
        expires_at=expiry.isoformat() if expiry else None,
        scope=payload.get("scope"),
        _token_encrypted=True,
    )
    updated.pop("auth_expired", None)

    connector.config = updated
    flag_modified(connector, "config")
    await session.commit()
    await session.refresh(connector)

    logger.info("Successfully refreshed OneDrive tokens for connector %s", connector.id)
    return connector
|
||||
|
|
@ -999,6 +999,53 @@ async def index_connector_content(
|
|||
)
|
||||
response_message = "Google Drive indexing started in the background."
|
||||
|
||||
elif connector.connector_type == SearchSourceConnectorType.ONEDRIVE_CONNECTOR:
|
||||
from app.tasks.celery_tasks.connector_tasks import (
|
||||
index_onedrive_files_task,
|
||||
)
|
||||
|
||||
if drive_items and drive_items.has_items():
|
||||
logger.info(
|
||||
f"Triggering OneDrive indexing for connector {connector_id} into search space {search_space_id}, "
|
||||
f"folders: {len(drive_items.folders)}, files: {len(drive_items.files)}"
|
||||
)
|
||||
items_dict = drive_items.model_dump()
|
||||
else:
|
||||
config = connector.config or {}
|
||||
selected_folders = config.get("selected_folders", [])
|
||||
selected_files = config.get("selected_files", [])
|
||||
if not selected_folders and not selected_files:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="OneDrive indexing requires folders or files to be configured. "
|
||||
"Please select folders/files to index.",
|
||||
)
|
||||
indexing_options = config.get(
|
||||
"indexing_options",
|
||||
{
|
||||
"max_files_per_folder": 100,
|
||||
"incremental_sync": True,
|
||||
"include_subfolders": True,
|
||||
},
|
||||
)
|
||||
items_dict = {
|
||||
"folders": selected_folders,
|
||||
"files": selected_files,
|
||||
"indexing_options": indexing_options,
|
||||
}
|
||||
logger.info(
|
||||
f"Triggering OneDrive indexing for connector {connector_id} into search space {search_space_id} "
|
||||
f"using existing config"
|
||||
)
|
||||
|
||||
index_onedrive_files_task.delay(
|
||||
connector_id,
|
||||
search_space_id,
|
||||
str(user.id),
|
||||
items_dict,
|
||||
)
|
||||
response_message = "OneDrive indexing started in the background."
|
||||
|
||||
elif connector.connector_type == SearchSourceConnectorType.DISCORD_CONNECTOR:
|
||||
from app.tasks.celery_tasks.connector_tasks import (
|
||||
index_discord_messages_task,
|
||||
|
|
@ -2485,6 +2532,108 @@ async def run_google_drive_indexing(
|
|||
logger.error(f"Failed to update notification: {notif_error!s}")
|
||||
|
||||
|
||||
async def run_onedrive_indexing(
    session: AsyncSession,
    connector_id: int,
    search_space_id: int,
    user_id: str,
    items_dict: dict,
):
    """Index the selected OneDrive folders/files and keep the user notified.

    Looks up the connector, opens an indexing notification when the connector
    exists, runs ``index_onedrive_files`` and reports progress and the final
    outcome through ``NotificationService``. Any exception is logged and,
    when a notification was created, reported on it; nothing is re-raised.
    """
    from uuid import UUID

    notification = None
    try:
        from app.tasks.connector_indexers.onedrive_indexer import index_onedrive_files

        lookup = select(SearchSourceConnector).where(
            SearchSourceConnector.id == connector_id
        )
        connector = (await session.execute(lookup)).scalar_one_or_none()

        if connector:
            # NOTE(review): this reuses the Google Drive "started" notification
            # for OneDrive -- confirm that is intended.
            notification = await NotificationService.connector_indexing.notify_google_drive_indexing_started(
                session=session,
                user_id=UUID(user_id),
                connector_id=connector_id,
                connector_name=connector.name,
                connector_type=connector.connector_type.value,
                search_space_id=search_space_id,
                folder_count=len(items_dict.get("folders", [])),
                file_count=len(items_dict.get("files", [])),
                folder_names=[entry.get("name", "Unknown") for entry in items_dict.get("folders", [])],
                file_names=[entry.get("name", "Unknown") for entry in items_dict.get("files", [])],
            )

        if notification:
            await NotificationService.connector_indexing.notify_indexing_progress(
                session=session,
                notification=notification,
                indexed_count=0,
                stage="fetching",
            )

        indexed_total, skipped_total, failure = await index_onedrive_files(
            session,
            connector_id,
            search_space_id,
            user_id,
            items_dict,
        )

        if not failure:
            if notification:
                await session.refresh(notification)
                await NotificationService.connector_indexing.notify_indexing_progress(
                    session=session,
                    notification=notification,
                    indexed_count=indexed_total,
                    stage="storing",
                )

            logger.info(
                f"OneDrive indexing successful for connector {connector_id}. Indexed {indexed_total} documents."
            )
            await _update_connector_timestamp_by_id(session, connector_id)
            await session.commit()
        else:
            logger.error(
                f"OneDrive indexing completed with errors for connector {connector_id}: {failure}"
            )
            if _is_auth_error(failure):
                await _persist_auth_expired(session, connector_id)
                # Replace the raw error with the message reported to the user.
                failure = "OneDrive authentication expired. Please re-authenticate."

        # Final notification covers both the success and the error outcome.
        if notification:
            await session.refresh(notification)
            await NotificationService.connector_indexing.notify_indexing_completed(
                session=session,
                notification=notification,
                indexed_count=indexed_total,
                error_message=failure,
                skipped_count=skipped_total,
            )

    except Exception as exc:
        logger.error(
            f"Critical error in run_onedrive_indexing for connector {connector_id}: {exc}",
            exc_info=True,
        )
        if notification:
            try:
                await session.refresh(notification)
                await NotificationService.connector_indexing.notify_indexing_completed(
                    session=session,
                    notification=notification,
                    indexed_count=0,
                    error_message=str(exc),
                )
            except Exception as notify_exc:
                logger.error(f"Failed to update notification: {notify_exc!s}")
|
||||
|
||||
|
||||
# Add new helper functions for luma indexing
|
||||
async def run_luma_indexing_with_new_session(
|
||||
connector_id: int,
|
||||
|
|
|
|||
71
surfsense_backend/app/schemas/onedrive_auth_credentials.py
Normal file
71
surfsense_backend/app/schemas/onedrive_auth_credentials.py
Normal file
|
|
@ -0,0 +1,71 @@
|
|||
"""Microsoft OneDrive OAuth credentials schema."""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
|
||||
class OneDriveAuthCredentialsBase(BaseModel):
    """Microsoft OneDrive OAuth credentials.

    ``expires_at`` is always stored timezone-aware: ISO strings and naive
    datetimes are normalised by the ``ensure_aware_utc`` validator.
    """

    access_token: str
    refresh_token: str | None = None
    token_type: str = "Bearer"
    expires_in: int | None = None
    expires_at: datetime | None = None
    scope: str | None = None
    user_email: str | None = None
    user_name: str | None = None
    tenant_id: str | None = None

    @property
    def is_expired(self) -> bool:
        """Whether ``expires_at`` has passed; ``False`` when no expiry is known."""
        if self.expires_at is None:
            return False
        return self.expires_at <= datetime.now(UTC)

    @property
    def is_refreshable(self) -> bool:
        """Whether a refresh token is present."""
        return self.refresh_token is not None

    def to_dict(self) -> dict:
        """Return the credentials as a plain dict with ``expires_at`` as an ISO string."""
        return {
            "access_token": self.access_token,
            "refresh_token": self.refresh_token,
            "token_type": self.token_type,
            "expires_in": self.expires_in,
            "expires_at": self.expires_at.isoformat() if self.expires_at else None,
            "scope": self.scope,
            "user_email": self.user_email,
            "user_name": self.user_name,
            "tenant_id": self.tenant_id,
        }

    @classmethod
    def from_dict(cls, data: dict) -> "OneDriveAuthCredentialsBase":
        """Build credentials from a dict such as the one produced by ``to_dict``.

        Missing keys fall back to the field defaults (``access_token`` falls
        back to an empty string). ``expires_at`` may be an ISO string or a
        ``datetime``; parsing is left to ``ensure_aware_utc`` so both forms
        are handled in one place.
        """
        return cls(
            access_token=data.get("access_token", ""),
            refresh_token=data.get("refresh_token"),
            token_type=data.get("token_type", "Bearer"),
            expires_in=data.get("expires_in"),
            # Falsy values (None, "") mean "no expiry known".
            expires_at=data.get("expires_at") or None,
            scope=data.get("scope"),
            user_email=data.get("user_email"),
            user_name=data.get("user_name"),
            tenant_id=data.get("tenant_id"),
        )

    @field_validator("expires_at", mode="before")
    @classmethod
    def ensure_aware_utc(cls, v):
        """Parse ISO strings and attach UTC to naive datetimes.

        Values that are neither ``str`` nor ``datetime`` are returned
        unchanged.
        """
        if isinstance(v, str):
            if v.endswith("Z"):
                return datetime.fromisoformat(v.replace("Z", "+00:00"))
            dt = datetime.fromisoformat(v)
            return dt if dt.tzinfo else dt.replace(tzinfo=UTC)
        if isinstance(v, datetime):
            return v if v.tzinfo else v.replace(tzinfo=UTC)
        return v
|
||||
|
|
@ -526,6 +526,54 @@ async def _index_google_drive_files(
|
|||
)
|
||||
|
||||
|
||||
@celery_app.task(name="index_onedrive_files", bind=True)
def index_onedrive_files_task(
    self,
    connector_id: int,
    search_space_id: int,
    user_id: str,
    items_dict: dict,
):
    """Celery entry point that indexes OneDrive folders and files.

    Runs the async ``_index_onedrive_files`` coroutine to completion on a
    freshly created event loop, which is closed afterwards.
    """
    import asyncio

    event_loop = asyncio.new_event_loop()
    asyncio.set_event_loop(event_loop)

    try:
        indexing_job = _index_onedrive_files(
            connector_id,
            search_space_id,
            user_id,
            items_dict,
        )
        event_loop.run_until_complete(indexing_job)
    finally:
        event_loop.close()
|
||||
|
||||
|
||||
async def _index_onedrive_files(
    connector_id: int,
    search_space_id: int,
    user_id: str,
    items_dict: dict,
):
    """Run ``run_onedrive_indexing`` inside a session from the Celery session maker."""
    from app.routes.search_source_connectors_routes import (
        run_onedrive_indexing,
    )

    session_factory = get_celery_session_maker()
    async with session_factory() as db_session:
        await run_onedrive_indexing(
            session=db_session,
            connector_id=connector_id,
            search_space_id=search_space_id,
            user_id=user_id,
            items_dict=items_dict,
        )
|
||||
|
||||
|
||||
@celery_app.task(name="index_discord_messages", bind=True)
|
||||
def index_discord_messages_task(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,606 @@
|
|||
"""OneDrive indexer using the shared IndexingPipelineService.
|
||||
|
||||
File-level pre-filter (_should_skip_file) handles hash/modifiedDateTime
|
||||
checks and rename-only detection. download_and_extract_content()
|
||||
returns markdown which is fed into ConnectorDocument -> pipeline.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
||||
from sqlalchemy import String, cast, select
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
from app.config import config
|
||||
from app.connectors.onedrive import (
|
||||
OneDriveClient,
|
||||
download_and_extract_content,
|
||||
get_file_by_id,
|
||||
get_files_in_folder,
|
||||
)
|
||||
from app.connectors.onedrive.file_types import should_skip_file as skip_item
|
||||
from app.db import Document, DocumentStatus, DocumentType, SearchSourceConnectorType
|
||||
from app.indexing_pipeline.connector_document import ConnectorDocument
|
||||
from app.indexing_pipeline.document_hashing import compute_identifier_hash
|
||||
from app.indexing_pipeline.indexing_pipeline_service import IndexingPipelineService
|
||||
from app.services.llm_service import get_user_long_context_llm
|
||||
from app.services.task_logging_service import TaskLoggingService
|
||||
from app.tasks.connector_indexers.base import (
|
||||
check_document_by_unique_identifier,
|
||||
get_connector_by_id,
|
||||
update_connector_last_indexed,
|
||||
)
|
||||
|
||||
# Signature of the optional progress callback: an async callable that is
# awaited with the number of files completed so far.
HeartbeatCallbackType = Callable[[int], Awaitable[None]]
# Minimum number of seconds between two heartbeat callback invocations.
HEARTBEAT_INTERVAL_SECONDS = 30

# Module-level logger named after this module.
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def _should_skip_file(
    session: AsyncSession,
    file: dict,
    search_space_id: int,
) -> tuple[bool, str | None]:
    """Decide whether a OneDrive item can be skipped before downloading it.

    Returns ``(skip, reason)``. An item is skipped when ``skip_item`` rejects
    it, when it has no ``id``, or when a stored document for it exists and
    its content is unchanged (same hash, or same modification time when the
    item reports no hash). A rename without a content change updates the
    stored title/metadata, commits, and is also reported as skipped.
    """
    item_id = file.get("id")
    item_name = file.get("name", "Unknown")

    if skip_item(file):
        return True, "folder/onenote/remote"
    if not item_id:
        return True, "missing file_id"

    identifier_hash = compute_identifier_hash(
        DocumentType.ONEDRIVE_FILE.value, item_id, search_space_id
    )
    document = await check_document_by_unique_identifier(session, identifier_hash)

    if not document:
        # Fallback lookup: match on the file id stored in the document metadata.
        by_metadata = select(Document).where(
            Document.search_space_id == search_space_id,
            Document.document_type == DocumentType.ONEDRIVE_FILE,
            cast(Document.document_metadata["onedrive_file_id"], String) == item_id,
        )
        document = (await session.execute(by_metadata)).scalar_one_or_none()
        if document:
            document.unique_identifier_hash = identifier_hash
            logger.debug(f"Found OneDrive doc by metadata for file_id: {item_id}")

    if not document:
        # Never indexed before: must be processed.
        return False, None

    remote_mtime = file.get("lastModifiedDateTime")
    stored_meta = document.document_metadata or {}
    known_mtime = stored_meta.get("modified_time")

    remote_hashes = file.get("file", {}).get("hashes", {})
    remote_hash = remote_hashes.get("sha256Hash") or remote_hashes.get("quickXorHash")
    known_hash = stored_meta.get("sha256_hash") or stored_meta.get("quick_xor_hash")

    if remote_hash:
        # A hash is available remotely; without a stored one we cannot compare.
        if not known_hash:
            return False, None
        content_unchanged = remote_hash == known_hash
    else:
        # No remote hash: fall back to modification times when both are known.
        if not (remote_mtime and known_mtime):
            return False, None
        content_unchanged = remote_mtime == known_mtime

    if not content_unchanged:
        return False, None

    previous_name = stored_meta.get("onedrive_file_name")
    if previous_name and previous_name != item_name:
        document.title = item_name
        if not document.document_metadata:
            document.document_metadata = {}
        document.document_metadata["onedrive_file_name"] = item_name
        if remote_mtime:
            document.document_metadata["modified_time"] = remote_mtime
        # In-place edits of the metadata dict must be flagged explicitly.
        flag_modified(document, "document_metadata")
        await session.commit()
        logger.info(f"Rename-only update: '{previous_name}' -> '{item_name}'")
        return True, f"File renamed: '{previous_name}' -> '{item_name}'"

    if DocumentStatus.is_state(document.status, DocumentStatus.READY):
        return True, "unchanged"
    return True, "skipped (previously failed)"
|
||||
|
||||
|
||||
def _build_connector_doc(
    file: dict,
    markdown: str,
    onedrive_metadata: dict,
    *,
    connector_id: int,
    search_space_id: int,
    user_id: str,
    enable_summary: bool,
) -> ConnectorDocument:
    """Wrap an extracted OneDrive file in a ``ConnectorDocument``.

    The document metadata is ``onedrive_metadata`` extended with the
    connector id and fixed type labels. The fallback summary is the file
    name followed by the first 4000 characters of ``markdown``.
    """
    title = file.get("name", "Unknown")

    doc_metadata = dict(onedrive_metadata)
    doc_metadata.update(
        {
            "connector_id": connector_id,
            "document_type": "OneDrive File",
            "connector_type": "OneDrive",
        }
    )

    return ConnectorDocument(
        title=title,
        source_markdown=markdown,
        unique_id=file.get("id", ""),
        document_type=DocumentType.ONEDRIVE_FILE,
        search_space_id=search_space_id,
        connector_id=connector_id,
        created_by_id=user_id,
        should_summarize=enable_summary,
        fallback_summary=f"File: {title}\n\n{markdown[:4000]}",
        metadata=doc_metadata,
    )
|
||||
|
||||
|
||||
async def _download_files_parallel(
    onedrive_client: OneDriveClient,
    files: list[dict],
    *,
    connector_id: int,
    search_space_id: int,
    user_id: str,
    enable_summary: bool,
    max_concurrency: int = 3,
    on_heartbeat: HeartbeatCallbackType | None = None,
) -> tuple[list[ConnectorDocument], int]:
    """Download and ETL files in parallel. Returns (docs, failed_count).

    At most ``max_concurrency`` downloads run at once. A file counts as
    failed when extraction reports an error, yields empty content, or raises.
    When given, ``on_heartbeat`` is awaited with the number of completed
    files, at most once every ``HEARTBEAT_INTERVAL_SECONDS``.
    """
    results: list[ConnectorDocument] = []
    sem = asyncio.Semaphore(max_concurrency)
    # Monotonic clock: the interval must not be affected by wall-clock changes.
    last_heartbeat = time.monotonic()
    completed_count = 0
    hb_lock = asyncio.Lock()

    async def _download_one(file: dict) -> ConnectorDocument | None:
        nonlocal last_heartbeat, completed_count
        async with sem:
            markdown, od_metadata, error = await download_and_extract_content(
                onedrive_client, file
            )
            if error or not markdown:
                file_name = file.get("name", "Unknown")
                reason = error or "empty content"
                logger.warning(f"Download/ETL failed for {file_name}: {reason}")
                return None
            doc = _build_connector_doc(
                file, markdown, od_metadata,
                connector_id=connector_id, search_space_id=search_space_id,
                user_id=user_id, enable_summary=enable_summary,
            )
            # Serialise counter/heartbeat updates across concurrent downloads.
            async with hb_lock:
                completed_count += 1
                if on_heartbeat:
                    now = time.monotonic()
                    if now - last_heartbeat >= HEARTBEAT_INTERVAL_SECONDS:
                        await on_heartbeat(completed_count)
                        last_heartbeat = now
            return doc

    tasks = [_download_one(f) for f in files]
    outcomes = await asyncio.gather(*tasks, return_exceptions=True)

    failed = 0
    # gather() preserves input order, so each outcome lines up with its file.
    for file, outcome in zip(files, outcomes):
        # BaseException (not Exception) so that a CancelledError returned by
        # gather() is counted as a failure instead of being kept as a document.
        if isinstance(outcome, BaseException):
            failed += 1
            file_name = file.get("name", "Unknown") if isinstance(file, dict) else "Unknown"
            logger.warning("Download/ETL raised for %s: %r", file_name, outcome)
        elif outcome is None:
            failed += 1
        else:
            results.append(outcome)

    return results, failed
|
||||
|
||||
|
||||
async def _download_and_index(
    onedrive_client: OneDriveClient,
    session: AsyncSession,
    files: list[dict],
    *,
    connector_id: int,
    search_space_id: int,
    user_id: str,
    enable_summary: bool,
    on_heartbeat: HeartbeatCallbackType | None = None,
) -> tuple[int, int]:
    """Download ``files`` in parallel, then index the results in parallel.

    Returns ``(batch_indexed, total_failed)`` where ``total_failed`` adds the
    download failures to the indexing failures.
    """
    docs, download_failures = await _download_files_parallel(
        onedrive_client,
        files,
        connector_id=connector_id,
        search_space_id=search_space_id,
        user_id=user_id,
        enable_summary=enable_summary,
        on_heartbeat=on_heartbeat,
    )

    # Nothing was downloaded successfully: skip the indexing pipeline.
    if not docs:
        return 0, download_failures

    async def _resolve_llm(s):
        return await get_user_long_context_llm(s, user_id, search_space_id)

    indexing_service = IndexingPipelineService(session)
    _, indexed_count, indexing_failures = await indexing_service.index_batch_parallel(
        docs,
        _resolve_llm,
        max_concurrency=3,
        on_heartbeat=on_heartbeat,
    )

    return indexed_count, download_failures + indexing_failures
|
||||
|
||||
|
||||
async def _remove_document(session: AsyncSession, file_id: str, search_space_id: int):
    """Delete the indexed document for a file that was removed from OneDrive.

    The document is looked up by its identifier hash first. If that finds
    nothing, a second lookup matches ``onedrive_file_id`` in the document
    metadata within the same search space. Nothing happens when neither
    lookup finds a document.
    """
    identifier_hash = compute_identifier_hash(
        DocumentType.ONEDRIVE_FILE.value, file_id, search_space_id
    )
    doc = await check_document_by_unique_identifier(session, identifier_hash)

    if not doc:
        # Fallback: match on the OneDrive file id stored in the metadata.
        metadata_lookup = select(Document).where(
            Document.search_space_id == search_space_id,
            Document.document_type == DocumentType.ONEDRIVE_FILE,
            cast(Document.document_metadata["onedrive_file_id"], String) == file_id,
        )
        doc = (await session.execute(metadata_lookup)).scalar_one_or_none()

    if not doc:
        return

    await session.delete(doc)
    logger.info(f"Removed deleted OneDrive file document: {file_id}")
|
||||
|
||||
|
||||
async def _index_selected_files(
    onedrive_client: OneDriveClient,
    session: AsyncSession,
    file_ids: list[tuple[str, str | None]],
    *,
    connector_id: int,
    search_space_id: int,
    user_id: str,
    enable_summary: bool,
    on_heartbeat: HeartbeatCallbackType | None = None,
) -> tuple[int, int, list[str]]:
    """Index user-selected files using the parallel pipeline.

    ``file_ids`` holds ``(file_id, optional_display_name)`` pairs. Returns
    ``(indexed, skipped, errors)``: ``indexed`` counts renamed files plus
    files indexed by the pipeline, and ``errors`` has one message per file
    that could not be fetched.
    """
    pending: list[dict] = []
    errors: list[str] = []
    renamed_count = 0
    skipped = 0

    for file_id, file_name in file_ids:
        file, error = await get_file_by_id(onedrive_client, file_id)
        if error or not file:
            label = file_name or file_id
            errors.append(f"File '{label}': {error or 'File not found'}")
            continue

        skip, msg = await _should_skip_file(session, file, search_space_id)
        if not skip:
            pending.append(file)
        elif msg and "renamed" in msg.lower():
            # A rename was already handled by the skip check; count it as indexed.
            renamed_count += 1
        else:
            skipped += 1

    # The failure count is not part of this function's return value.
    batch_indexed, _failed = await _download_and_index(
        onedrive_client,
        session,
        pending,
        connector_id=connector_id,
        search_space_id=search_space_id,
        user_id=user_id,
        enable_summary=enable_summary,
        on_heartbeat=on_heartbeat,
    )

    return renamed_count + batch_indexed, skipped, errors
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Scan strategies
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def _index_full_scan(
    onedrive_client: OneDriveClient,
    session: AsyncSession,
    connector_id: int,
    search_space_id: int,
    user_id: str,
    folder_id: str,
    folder_name: str,
    task_logger: TaskLoggingService,
    log_entry: object,
    max_files: int,
    include_subfolders: bool = True,
    on_heartbeat_callback: HeartbeatCallbackType | None = None,
    enable_summary: bool = True,
) -> tuple[int, int]:
    """Index a folder by listing its files and indexing those that changed.

    Only the first ``max_files`` listed files are considered. Returns
    ``(indexed, skipped)``, where ``indexed`` counts renamed files plus
    files indexed by the pipeline.

    Raises:
        Exception: if listing the folder fails. The message asks for
            re-authentication when the error mentions "401" or
            "authentication expired".
    """
    await task_logger.log_task_progress(
        log_entry,
        f"Starting full scan of folder: {folder_name}",
        {"stage": "full_scan", "folder_id": folder_id, "include_subfolders": include_subfolders},
    )

    all_files, error = await get_files_in_folder(
        onedrive_client,
        folder_id,
        include_subfolders=include_subfolders,
    )
    if error:
        auth_problem = "401" in error or "authentication expired" in error.lower()
        if auth_problem:
            raise Exception(f"OneDrive authentication failed. Please re-authenticate. (Error: {error})")
        raise Exception(f"Failed to list OneDrive files: {error}")

    renamed_count = 0
    skipped = 0
    pending: list[dict] = []

    for candidate in all_files[:max_files]:
        skip, msg = await _should_skip_file(session, candidate, search_space_id)
        if not skip:
            pending.append(candidate)
        elif msg and "renamed" in msg.lower():
            renamed_count += 1
        else:
            skipped += 1

    batch_indexed, failed = await _download_and_index(
        onedrive_client,
        session,
        pending,
        connector_id=connector_id,
        search_space_id=search_space_id,
        user_id=user_id,
        enable_summary=enable_summary,
        on_heartbeat=on_heartbeat_callback,
    )

    indexed = renamed_count + batch_indexed
    logger.info(f"Full scan complete: {indexed} indexed, {skipped} skipped, {failed} failed")
    return indexed, skipped
|
||||
|
||||
|
||||
async def _index_with_delta_sync(
    onedrive_client: OneDriveClient,
    session: AsyncSession,
    connector_id: int,
    search_space_id: int,
    user_id: str,
    folder_id: str | None,
    delta_link: str,
    task_logger: TaskLoggingService,
    log_entry: object,
    max_files: int,
    on_heartbeat_callback: HeartbeatCallbackType | None = None,
    enable_summary: bool = True,
) -> tuple[int, int, str | None]:
    """Apply the changes OneDrive reports since ``delta_link``.

    Deleted items have their documents removed, folders and non-file items
    are ignored, and changed files go through the download/index pipeline.
    At most ``max_files`` change entries are examined, and every entry
    counts toward that limit (deletions and folders included).

    Returns ``(indexed, skipped, new_delta_link)``.

    Raises:
        Exception: if fetching the changes fails. The message asks for
            re-authentication when the error mentions "401" or
            "authentication expired".
    """
    await task_logger.log_task_progress(
        log_entry,
        "Starting delta sync",
        {"stage": "delta_sync"},
    )

    changes, new_delta_link, error = await onedrive_client.get_delta(
        folder_id=folder_id, delta_link=delta_link
    )
    if error:
        auth_problem = "401" in error or "authentication expired" in error.lower()
        if auth_problem:
            raise Exception(f"OneDrive authentication failed. Please re-authenticate. (Error: {error})")
        raise Exception(f"Failed to fetch OneDrive changes: {error}")

    if not changes:
        logger.info("No changes detected since last sync")
        return 0, 0, new_delta_link

    logger.info(f"Processing {len(changes)} delta changes")

    renamed_count = 0
    skipped = 0
    pending: list[dict] = []

    for position, change in enumerate(changes):
        # Stop once ``max_files`` entries have been examined.
        if position >= max_files:
            break

        if change.get("deleted"):
            removed_id = change.get("id")
            if removed_id:
                await _remove_document(session, removed_id, search_space_id)
            continue

        # Only real files are indexed; folders and other item kinds are ignored.
        if "folder" in change or not change.get("file"):
            continue

        skip, msg = await _should_skip_file(session, change, search_space_id)
        if not skip:
            pending.append(change)
        elif msg and "renamed" in msg.lower():
            renamed_count += 1
        else:
            skipped += 1

    batch_indexed, failed = await _download_and_index(
        onedrive_client,
        session,
        pending,
        connector_id=connector_id,
        search_space_id=search_space_id,
        user_id=user_id,
        enable_summary=enable_summary,
        on_heartbeat=on_heartbeat_callback,
    )

    indexed = renamed_count + batch_indexed
    logger.info(f"Delta sync complete: {indexed} indexed, {skipped} skipped, {failed} failed")
    return indexed, skipped, new_delta_link
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public entry point
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def index_onedrive_files(
    session: AsyncSession,
    connector_id: int,
    search_space_id: int,
    user_id: str,
    items_dict: dict,
) -> tuple[int, int, str | None]:
    """Index OneDrive files for a specific connector.

    items_dict format:
        {
            "folders": [{"id": "...", "name": "..."}, ...],
            "files": [{"id": "...", "name": "..."}, ...],
            "indexing_options": {"max_files": 500, "include_subfolders": true, "use_delta_sync": true}
        }

    Individually selected files are indexed first, then each folder. A
    folder uses delta sync when it is enabled, a delta link is stored for
    the folder and the connector has been indexed before; a reconciliation
    full scan follows the delta sync. Otherwise the folder gets a full scan
    and a fresh delta link is requested and stored for it.

    Returns:
        ``(total_indexed, total_skipped, None)`` on success, or
        ``(0, 0, error_message)`` when the connector is missing, the
        secret key is missing for encrypted credentials, or indexing
        raises. Exceptions are logged and reported, not re-raised.
    """
    task_logger = TaskLoggingService(session, search_space_id)
    log_entry = await task_logger.log_task_start(
        task_name="onedrive_files_indexing",
        source="connector_indexing_task",
        message=f"Starting OneDrive indexing for connector {connector_id}",
        metadata={"connector_id": connector_id, "user_id": str(user_id)},
    )

    try:
        connector = await get_connector_by_id(
            session, connector_id, SearchSourceConnectorType.ONEDRIVE_CONNECTOR
        )
        if not connector:
            error_msg = f"OneDrive connector with ID {connector_id} not found"
            await task_logger.log_task_failure(log_entry, error_msg, None, {"error_type": "ConnectorNotFound"})
            return 0, 0, error_msg

        token_encrypted = connector.config.get("_token_encrypted", False)
        if token_encrypted and not config.SECRET_KEY:
            error_msg = "SECRET_KEY not configured but credentials are encrypted"
            await task_logger.log_task_failure(log_entry, error_msg, "Missing SECRET_KEY", {"error_type": "MissingSecretKey"})
            return 0, 0, error_msg

        connector_enable_summary = getattr(connector, "enable_summary", True)
        onedrive_client = OneDriveClient(session, connector_id)

        indexing_options = items_dict.get("indexing_options", {})
        max_files = indexing_options.get("max_files", 500)
        include_subfolders = indexing_options.get("include_subfolders", True)
        use_delta_sync = indexing_options.get("use_delta_sync", True)

        total_indexed = 0
        total_skipped = 0

        # Index selected individual files
        selected_files = items_dict.get("files", [])
        if selected_files:
            file_tuples = [(f["id"], f.get("name")) for f in selected_files]
            indexed, skipped, file_errors = await _index_selected_files(
                onedrive_client,
                session,
                file_tuples,
                connector_id=connector_id,
                search_space_id=search_space_id,
                user_id=user_id,
                enable_summary=connector_enable_summary,
            )
            total_indexed += indexed
            total_skipped += skipped
            # Per-file lookup errors do not fail the run, but log them so
            # they are not lost.
            for file_error in file_errors:
                logger.warning(f"OneDrive selected file could not be indexed: {file_error}")

        # Index selected folders
        folders = items_dict.get("folders", [])
        for folder in folders:
            folder_id = folder.get("id", "root")
            folder_name = folder.get("name", "Root")

            folder_delta_links = connector.config.get("folder_delta_links", {})
            delta_link = folder_delta_links.get(folder_id)
            can_use_delta = use_delta_sync and delta_link and connector.last_indexed_at

            if can_use_delta:
                logger.info(f"Using delta sync for folder {folder_name}")
                indexed, skipped, new_delta_link = await _index_with_delta_sync(
                    onedrive_client, session, connector_id, search_space_id, user_id,
                    folder_id, delta_link, task_logger, log_entry, max_files,
                    enable_summary=connector_enable_summary,
                )
                total_indexed += indexed
                total_skipped += skipped

                if new_delta_link:
                    await _store_folder_delta_link(
                        session, connector, folder_id, new_delta_link
                    )

                # Reconciliation full scan
                ri, rs = await _index_full_scan(
                    onedrive_client, session, connector_id, search_space_id, user_id,
                    folder_id, folder_name, task_logger, log_entry, max_files,
                    include_subfolders, enable_summary=connector_enable_summary,
                )
                total_indexed += ri
                total_skipped += rs
            else:
                logger.info(f"Using full scan for folder {folder_name}")
                indexed, skipped = await _index_full_scan(
                    onedrive_client, session, connector_id, search_space_id, user_id,
                    folder_id, folder_name, task_logger, log_entry, max_files,
                    include_subfolders, enable_summary=connector_enable_summary,
                )
                total_indexed += indexed
                total_skipped += skipped

                # Store new delta link for this folder
                _, new_delta_link, delta_error = await onedrive_client.get_delta(folder_id=folder_id)
                if delta_error:
                    # Not fatal, but worth logging when the link request fails.
                    logger.warning(
                        f"Could not obtain delta link for folder {folder_name}: {delta_error}"
                    )
                if new_delta_link:
                    await _store_folder_delta_link(
                        session, connector, folder_id, new_delta_link
                    )

        if total_indexed > 0 or folders:
            await update_connector_last_indexed(session, connector, True)

        await session.commit()

        await task_logger.log_task_success(
            log_entry,
            f"Successfully completed OneDrive indexing for connector {connector_id}",
            {"files_processed": total_indexed, "files_skipped": total_skipped},
        )
        logger.info(f"OneDrive indexing completed: {total_indexed} indexed, {total_skipped} skipped")
        return total_indexed, total_skipped, None

    except SQLAlchemyError as db_error:
        await session.rollback()
        await task_logger.log_task_failure(
            log_entry,
            f"Database error during OneDrive indexing for connector {connector_id}",
            str(db_error),
            {"error_type": "SQLAlchemyError"},
        )
        logger.error(f"Database error: {db_error!s}", exc_info=True)
        return 0, 0, f"Database error: {db_error!s}"
    except Exception as e:
        # Top-level boundary: report the failure to the task log and caller.
        await session.rollback()
        await task_logger.log_task_failure(
            log_entry,
            f"Failed to index OneDrive files for connector {connector_id}",
            str(e),
            {"error_type": type(e).__name__},
        )
        logger.error(f"Failed to index OneDrive files: {e!s}", exc_info=True)
        return 0, 0, f"Failed to index OneDrive files: {e!s}"


async def _store_folder_delta_link(
    session: AsyncSession, connector, folder_id: str, delta_link: str
) -> None:
    """Record ``delta_link`` for ``folder_id`` in the connector's config.

    The connector is refreshed from the session first, and the config
    attribute is flagged as modified after the in-place change.
    """
    await session.refresh(connector)
    if "folder_delta_links" not in connector.config:
        connector.config["folder_delta_links"] = {}
    connector.config["folder_delta_links"][folder_id] = delta_link
    flag_modified(connector, "config")
|
||||
|
|
@ -21,6 +21,7 @@ BASE_NAME_FOR_TYPE = {
|
|||
SearchSourceConnectorType.GOOGLE_CALENDAR_CONNECTOR: "Google Calendar",
|
||||
SearchSourceConnectorType.SLACK_CONNECTOR: "Slack",
|
||||
SearchSourceConnectorType.TEAMS_CONNECTOR: "Microsoft Teams",
|
||||
SearchSourceConnectorType.ONEDRIVE_CONNECTOR: "OneDrive",
|
||||
SearchSourceConnectorType.NOTION_CONNECTOR: "Notion",
|
||||
SearchSourceConnectorType.LINEAR_CONNECTOR: "Linear",
|
||||
SearchSourceConnectorType.JIRA_CONNECTOR: "Jira",
|
||||
|
|
@ -61,6 +62,9 @@ def extract_identifier_from_credentials(
|
|||
if connector_type == SearchSourceConnectorType.TEAMS_CONNECTOR:
|
||||
return credentials.get("tenant_name")
|
||||
|
||||
if connector_type == SearchSourceConnectorType.ONEDRIVE_CONNECTOR:
|
||||
return credentials.get("user_email")
|
||||
|
||||
if connector_type == SearchSourceConnectorType.NOTION_CONNECTOR:
|
||||
return credentials.get("workspace_name")
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue