feat: implement Microsoft OneDrive connector with OAuth support and indexing capabilities

This commit is contained in:
Anish Sarkar 2026-03-28 14:31:25 +05:30
parent 64be61b627
commit 5bddde60cb
16 changed files with 2014 additions and 0 deletions

View file

@ -0,0 +1,13 @@
"""Microsoft OneDrive Connector Module."""
from .client import OneDriveClient
from .content_extractor import download_and_extract_content
from .folder_manager import get_file_by_id, get_files_in_folder, list_folder_contents
# Public API of the OneDrive connector package: exactly the names imported
# above from the client, content_extractor and folder_manager submodules.
__all__ = [
"OneDriveClient",
"download_and_extract_content",
"get_file_by_id",
"get_files_in_folder",
"list_folder_contents",
]

View file

@ -0,0 +1,276 @@
"""Microsoft OneDrive API client using Microsoft Graph API v1.0."""
import logging
from datetime import UTC, datetime, timedelta
from typing import Any
import httpx
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm.attributes import flag_modified
from app.config import config
from app.db import SearchSourceConnector
from app.utils.oauth_security import TokenEncryption
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Base URL that is prepended to every request path sent to Microsoft Graph,
# and stripped from next/delta links before they are re-requested.
GRAPH_API_BASE = "https://graph.microsoft.com/v1.0"
# OAuth 2.0 token endpoint that refresh-token requests are POSTed to.
TOKEN_URL = "https://login.microsoftonline.com/common/oauth2/v2.0/token"
class OneDriveClient:
    """Client for Microsoft OneDrive via the Graph API.

    OAuth tokens are read from, and written back to, the ``config`` mapping
    of the ``SearchSourceConnector`` row identified by ``connector_id``.
    Tokens are refreshed on demand and the connector is flagged with
    ``auth_expired`` when authentication can no longer be recovered.
    """

    def __init__(self, session: AsyncSession, connector_id: int):
        """Store the DB session and the id of the connector holding the tokens."""
        self._session = session
        self._connector_id = connector_id

    @staticmethod
    def _to_graph_path(link: str) -> str:
        """Turn an absolute Graph link into a path usable with ``_request``."""
        return link.replace(GRAPH_API_BASE, "")

    async def _load_connector(self) -> "SearchSourceConnector | None":
        """Fetch this client's connector row, or ``None`` if it does not exist."""
        result = await self._session.execute(
            select(SearchSourceConnector).filter(
                SearchSourceConnector.id == self._connector_id
            )
        )
        return result.scalars().first()

    async def _mark_auth_expired(self, connector: "SearchSourceConnector") -> None:
        """Persist ``auth_expired = True`` in the connector's config."""
        cfg = connector.config or {}
        cfg["auth_expired"] = True
        connector.config = cfg
        # The config dict is mutated in place, so the change must be flagged
        # explicitly for it to be written on commit.
        flag_modified(connector, "config")
        await self._session.commit()

    async def _get_valid_token(self) -> str:
        """Get a valid access token, refreshing if needed.

        Returns:
            A plaintext access token.

        Raises:
            ValueError: if the connector does not exist, if the token is
                expired/missing and no refresh token is stored, or if the
                refresh call fails or yields no access token.
        """
        connector = await self._load_connector()
        if not connector:
            raise ValueError(f"Connector {self._connector_id} not found")

        cfg = connector.config or {}
        is_encrypted = cfg.get("_token_encrypted", False)
        token_encryption = (
            TokenEncryption(config.SECRET_KEY) if config.SECRET_KEY else None
        )

        access_token = cfg.get("access_token", "")
        refresh_token = cfg.get("refresh_token")
        if is_encrypted and token_encryption:
            if access_token:
                access_token = token_encryption.decrypt_token(access_token)
            if refresh_token:
                refresh_token = token_encryption.decrypt_token(refresh_token)

        # A missing expiry timestamp is treated as "not expired".
        is_expired = False
        expires_at_str = cfg.get("expires_at")
        if expires_at_str:
            expires_at = datetime.fromisoformat(expires_at_str)
            if expires_at.tzinfo is None:
                # Naive timestamps are interpreted as UTC.
                expires_at = expires_at.replace(tzinfo=UTC)
            is_expired = expires_at <= datetime.now(UTC)

        if not is_expired and access_token:
            return access_token

        if not refresh_token:
            await self._mark_auth_expired(connector)
            raise ValueError("OneDrive token expired and no refresh token available")

        token_data = await self._refresh_token(refresh_token)
        new_access = token_data.get("access_token")
        if not new_access:
            # Fail with the same exception type as every other auth failure
            # instead of an opaque KeyError.
            raise ValueError("OneDrive token refresh returned no access token")
        # The token endpoint may omit a new refresh token; keep the old one.
        new_refresh = token_data.get("refresh_token", refresh_token)
        expires_in = token_data.get("expires_in")

        new_expires_at = None
        if expires_in:
            new_expires_at = datetime.now(UTC) + timedelta(seconds=int(expires_in))

        if token_encryption:
            cfg["access_token"] = token_encryption.encrypt_token(new_access)
            cfg["refresh_token"] = token_encryption.encrypt_token(new_refresh)
        else:
            cfg["access_token"] = new_access
            cfg["refresh_token"] = new_refresh
        cfg["expires_at"] = new_expires_at.isoformat() if new_expires_at else None
        cfg["expires_in"] = expires_in
        cfg["_token_encrypted"] = bool(token_encryption)
        cfg.pop("auth_expired", None)

        connector.config = cfg
        flag_modified(connector, "config")
        await self._session.commit()
        return new_access

    async def _refresh_token(self, refresh_token: str) -> dict:
        """Exchange a refresh token for new tokens at ``TOKEN_URL``.

        Returns:
            The decoded JSON body of the token response.

        Raises:
            ValueError: if the token endpoint does not answer with HTTP 200.
        """
        data = {
            "client_id": config.ONEDRIVE_CLIENT_ID,
            "client_secret": config.ONEDRIVE_CLIENT_SECRET,
            "grant_type": "refresh_token",
            "refresh_token": refresh_token,
            "scope": "offline_access User.Read Files.Read.All Files.ReadWrite.All",
        }
        async with httpx.AsyncClient() as client:
            resp = await client.post(
                TOKEN_URL,
                data=data,
                headers={"Content-Type": "application/x-www-form-urlencoded"},
                timeout=30.0,
            )

        if resp.status_code != 200:
            error_detail = resp.text
            try:
                error_detail = resp.json().get("error_description", error_detail)
            except (ValueError, AttributeError, TypeError):
                # Body is not a JSON object; fall back to the raw text.
                pass
            raise ValueError(f"OneDrive token refresh failed: {error_detail}")
        return resp.json()

    async def _request(self, method: str, path: str, **kwargs) -> httpx.Response:
        """Make an authenticated request to the Graph API.

        Args:
            method: HTTP method.
            path: Path relative to ``GRAPH_API_BASE`` (may include a query).
            **kwargs: Passed through to ``httpx.AsyncClient.request``; any
                ``headers`` are merged over the Authorization header.

        Raises:
            ValueError: on HTTP 401, after flagging the connector as
                ``auth_expired``; also anything ``_get_valid_token`` raises.
        """
        token = await self._get_valid_token()
        headers = {"Authorization": f"Bearer {token}"}
        if "headers" in kwargs:
            headers.update(kwargs.pop("headers"))

        async with httpx.AsyncClient() as client:
            resp = await client.request(
                method,
                f"{GRAPH_API_BASE}{path}",
                headers=headers,
                timeout=60.0,
                **kwargs,
            )

        if resp.status_code == 401:
            connector = await self._load_connector()
            if connector:
                await self._mark_auth_expired(connector)
            raise ValueError("OneDrive authentication expired (401)")
        return resp

    async def list_children(
        self, item_id: str = "root"
    ) -> tuple[list[dict[str, Any]], str | None]:
        """List every child of a drive item, following pagination.

        Returns:
            ``(items, None)`` on success or ``([], error_message)`` as soon as
            any page fails.
        """
        all_items: list[dict[str, Any]] = []
        url = f"/me/drive/items/{item_id}/children"
        params: dict[str, Any] | None = {
            "$top": 200,
            "$select": "id,name,size,file,folder,parentReference,lastModifiedDateTime,createdDateTime,webUrl,remoteItem,package",
        }
        while url:
            resp = await self._request("GET", url, params=params)
            if resp.status_code != 200:
                return [], f"Failed to list children: {resp.status_code} - {resp.text}"
            data = resp.json()
            all_items.extend(data.get("value", []))
            next_link = data.get("@odata.nextLink")
            url = self._to_graph_path(next_link) if next_link else ""
            # A next link already carries its full query string; send no
            # params at all so that query string is left untouched.
            params = None
        return all_items, None

    async def get_item_metadata(
        self, item_id: str
    ) -> tuple[dict[str, Any] | None, str | None]:
        """Fetch metadata for one drive item.

        Returns:
            ``(item, None)`` on success or ``(None, error_message)``.
        """
        resp = await self._request(
            "GET",
            f"/me/drive/items/{item_id}",
            params={
                "$select": "id,name,size,file,folder,parentReference,lastModifiedDateTime,createdDateTime,webUrl"
            },
        )
        if resp.status_code != 200:
            return None, f"Failed to get item: {resp.status_code} - {resp.text}"
        return resp.json(), None

    async def download_file(self, item_id: str) -> tuple[bytes | None, str | None]:
        """Download an item's content into memory.

        Returns:
            ``(content, None)`` on success or ``(None, error_message)``.
        """
        token = await self._get_valid_token()
        # The content endpoint redirects to the download location.
        async with httpx.AsyncClient(follow_redirects=True) as client:
            resp = await client.get(
                f"{GRAPH_API_BASE}/me/drive/items/{item_id}/content",
                headers={"Authorization": f"Bearer {token}"},
                timeout=120.0,
            )
        if resp.status_code != 200:
            return None, f"Download failed: {resp.status_code}"
        return resp.content, None

    async def download_file_to_disk(self, item_id: str, dest_path: str) -> str | None:
        """Stream file content to disk. Returns error message on failure."""
        token = await self._get_valid_token()
        async with (
            httpx.AsyncClient(follow_redirects=True) as client,
            client.stream(
                "GET",
                f"{GRAPH_API_BASE}/me/drive/items/{item_id}/content",
                headers={"Authorization": f"Bearer {token}"},
                timeout=120.0,
            ) as resp,
        ):
            if resp.status_code != 200:
                return f"Download failed: {resp.status_code}"
            with open(dest_path, "wb") as f:
                async for chunk in resp.aiter_bytes(chunk_size=5 * 1024 * 1024):
                    f.write(chunk)
        return None

    async def create_file(
        self,
        name: str,
        parent_id: str | None = None,
        content: str | None = None,
        mime_type: str | None = None,
    ) -> dict[str, Any]:
        """Create (upload) a file in OneDrive.

        Args:
            name: File name to create inside the parent folder.
            parent_id: Id of the parent folder; the drive root when omitted.
            content: Text content, uploaded UTF-8 encoded (empty when omitted).
            mime_type: Content-Type header; ``application/octet-stream`` when
                omitted.

        Returns:
            The decoded JSON body of the upload response.

        Raises:
            ValueError: if the upload does not answer with HTTP 200 or 201.
        """
        from urllib.parse import quote

        folder_path = f"/me/drive/items/{parent_id or 'root'}"
        body = (content or "").encode("utf-8")
        # The name is embedded in the URL path, so characters such as '#',
        # '?', '%' or spaces must be percent-encoded to stay part of the name.
        encoded_name = quote(name, safe="")
        resp = await self._request(
            "PUT",
            f"{folder_path}:/{encoded_name}:/content",
            content=body,
            headers={"Content-Type": mime_type or "application/octet-stream"},
        )
        if resp.status_code not in (200, 201):
            raise ValueError(f"File creation failed: {resp.status_code} - {resp.text}")
        return resp.json()

    async def trash_file(self, item_id: str) -> bool:
        """Delete (move to recycle bin) a OneDrive item.

        Returns:
            ``True`` on success.

        Raises:
            ValueError: if the request does not answer with HTTP 200 or 204.
        """
        resp = await self._request("DELETE", f"/me/drive/items/{item_id}")
        if resp.status_code not in (200, 204):
            raise ValueError(f"Trash failed: {resp.status_code} - {resp.text}")
        return True

    async def get_delta(
        self, folder_id: str | None = None, delta_link: str | None = None
    ) -> tuple[list[dict[str, Any]], str | None, str | None]:
        """Get delta changes. Returns (changes, new_delta_link, error).

        Args:
            folder_id: Folder to track; the drive root when omitted. Ignored
                when ``delta_link`` is given.
            delta_link: Delta link returned by a previous call, to resume from.
        """
        all_changes: list[dict[str, Any]] = []
        # Defined up front so the return below is safe even if no page is read.
        new_delta_link: str | None = None
        params: dict[str, Any] | None = None

        if delta_link:
            # A stored delta link carries its own query string; request it
            # exactly as it was handed out, without extra params.
            url = self._to_graph_path(delta_link)
        else:
            url = (
                f"/me/drive/items/{folder_id}/delta"
                if folder_id
                else "/me/drive/root/delta"
            )
            params = {"$top": 200}

        while url:
            resp = await self._request("GET", url, params=params)
            if resp.status_code != 200:
                return [], None, f"Delta failed: {resp.status_code} - {resp.text}"
            data = resp.json()
            all_changes.extend(data.get("value", []))
            next_link = data.get("@odata.nextLink")
            new_delta_link = data.get("@odata.deltaLink")
            url = self._to_graph_path(next_link) if next_link else ""
            # Next links already include every query parameter they need.
            params = None
        return all_changes, new_delta_link, None

View file

@ -0,0 +1,169 @@
"""Content extraction for OneDrive files.
Reuses the same ETL parsing logic as Google Drive since file parsing is
extension-based, not provider-specific.
"""
import asyncio
import logging
import os
import tempfile
import threading
import time
from pathlib import Path
from typing import Any
from .client import OneDriveClient
from .file_types import get_extension_from_mime, should_skip_file
logger = logging.getLogger(__name__)
async def download_and_extract_content(
    client: OneDriveClient,
    file: dict[str, Any],
) -> tuple[str | None, dict[str, Any], str | None]:
    """Download a OneDrive file and extract its content as markdown.

    Args:
        client: Client used to stream the file to a temporary location.
        file: Drive item dict as returned by the Graph API.

    Returns:
        ``(markdown_content, onedrive_metadata, error_message)``. On failure
        the markdown is ``None`` and ``error_message`` is set; items that are
        skipped up front are returned with empty metadata.
    """
    item_id = file.get("id")
    # Fall back to a placeholder when the name is missing or null.
    file_name = file.get("name") or "Unknown"

    if should_skip_file(file):
        return None, {}, "Skipping non-indexable item"

    file_info = file.get("file", {})
    mime_type = file_info.get("mimeType", "")

    metadata: dict[str, Any] = {
        "onedrive_file_id": item_id,
        "onedrive_file_name": file_name,
        "onedrive_mime_type": mime_type,
        "source_connector": "onedrive",
    }
    # Copy optional item fields under the connector's own metadata keys.
    for source_key, target_key in (
        ("lastModifiedDateTime", "modified_time"),
        ("createdDateTime", "created_time"),
        ("size", "file_size"),
        ("webUrl", "web_url"),
    ):
        if source_key in file:
            metadata[target_key] = file[source_key]

    # Prefer the SHA-256 hash; only fall back to quickXorHash without it.
    file_hashes = file_info.get("hashes", {})
    if file_hashes.get("sha256Hash"):
        metadata["sha256_hash"] = file_hashes["sha256Hash"]
    elif file_hashes.get("quickXorHash"):
        metadata["quick_xor_hash"] = file_hashes["quickXorHash"]

    if not item_id:
        # Without an id there is nothing to download; fail clearly instead of
        # requesting a bogus item URL.
        return None, metadata, "OneDrive item is missing an id"

    logger.info(f"Downloading file for content extraction: {file_name} ({mime_type})")

    temp_file_path = None
    try:
        extension = (
            Path(file_name).suffix or get_extension_from_mime(mime_type) or ".bin"
        )
        # Only reserve a uniquely named path here; the download writes to it.
        with tempfile.NamedTemporaryFile(delete=False, suffix=extension) as tmp:
            temp_file_path = tmp.name

        error = await client.download_file_to_disk(item_id, temp_file_path)
        if error:
            return None, metadata, error

        markdown = await _parse_file_to_markdown(temp_file_path, file_name)
        return markdown, metadata, None
    except Exception as e:
        # Best-effort per file: report the failure to the caller as an error
        # string rather than raising.
        logger.warning(f"Failed to extract content from {file_name}: {e!s}")
        return None, metadata, str(e)
    finally:
        if temp_file_path and os.path.exists(temp_file_path):
            try:
                os.unlink(temp_file_path)
            except OSError as cleanup_error:
                # Cleanup is best-effort, but leave a trace when it fails.
                logger.debug(
                    "Could not remove temp file %s: %s",
                    temp_file_path,
                    cleanup_error,
                )
async def _parse_file_to_markdown(file_path: str, filename: str) -> str:
    """Parse a local file to markdown using the configured ETL service.

    Same logic as Google Drive -- file parsing is extension-based.

    Args:
        file_path: Path of the downloaded file on disk.
        filename: Original file name; its extension selects the parser.

    Returns:
        The file content as markdown (plain text files are returned as-is,
        audio/video files as a transcription).

    Raises:
        ValueError: if a transcription comes back empty.
        RuntimeError: if LlamaCloud returns no documents or ``ETL_SERVICE``
            is not one of the handled values.
    """
    lower = filename.lower()

    if lower.endswith((".md", ".markdown", ".txt")):
        # "utf-8-sig" decodes plain UTF-8 identically but drops a leading
        # byte-order mark, so it does not leak into the markdown.
        with open(file_path, encoding="utf-8-sig") as f:
            return f.read()

    from app.config import config as app_config

    if lower.endswith((".mp3", ".mp4", ".mpeg", ".mpga", ".m4a", ".wav", ".webm")):
        use_local_stt = bool(
            app_config.STT_SERVICE and app_config.STT_SERVICE.startswith("local/")
        )
        if use_local_stt:
            from app.services.stt_service import stt_service

            t0 = time.monotonic()
            logger.info(f"[local-stt] START file={filename} thread={threading.current_thread().name}")
            # Run the transcription in a worker thread via asyncio.to_thread.
            result = await asyncio.to_thread(stt_service.transcribe_file, file_path)
            logger.info(f"[local-stt] END file={filename} elapsed={time.monotonic() - t0:.2f}s")
            text = result.get("text", "")
        else:
            # Imported only on the branch that actually needs it.
            from litellm import atranscription

            with open(file_path, "rb") as audio_file:
                kwargs: dict[str, Any] = {
                    "model": app_config.STT_SERVICE,
                    "file": audio_file,
                    "api_key": app_config.STT_SERVICE_API_KEY,
                }
                if app_config.STT_SERVICE_API_BASE:
                    kwargs["api_base"] = app_config.STT_SERVICE_API_BASE
                resp = await atranscription(**kwargs)
                text = resp.get("text", "")

        if not text:
            raise ValueError("Transcription returned empty text")
        return f"# Transcription of {filename}\n\n{text}"

    if app_config.ETL_SERVICE == "UNSTRUCTURED":
        from langchain_unstructured import UnstructuredLoader

        from app.utils.document_converters import convert_document_to_markdown

        loader = UnstructuredLoader(
            file_path,
            mode="elements",
            post_processors=[],
            languages=["eng"],
            include_orig_elements=False,
            include_metadata=False,
            strategy="auto",
        )
        docs = await loader.aload()
        return await convert_document_to_markdown(docs)

    if app_config.ETL_SERVICE == "LLAMACLOUD":
        from app.tasks.document_processors.file_processors import (
            parse_with_llamacloud_retry,
        )

        result = await parse_with_llamacloud_retry(file_path=file_path, estimated_pages=50)
        markdown_documents = await result.aget_markdown_documents(split_by_page=False)
        if not markdown_documents:
            raise RuntimeError(f"LlamaCloud returned no documents for {filename}")
        return markdown_documents[0].text

    if app_config.ETL_SERVICE == "DOCLING":
        from docling.document_converter import DocumentConverter

        converter = DocumentConverter()
        t0 = time.monotonic()
        logger.info(f"[docling] START file={filename} thread={threading.current_thread().name}")
        # Run the conversion in a worker thread via asyncio.to_thread.
        result = await asyncio.to_thread(converter.convert, file_path)
        logger.info(f"[docling] END file={filename} elapsed={time.monotonic() - t0:.2f}s")
        return result.document.export_to_markdown()

    raise RuntimeError(f"Unknown ETL_SERVICE: {app_config.ETL_SERVICE}")

View file

@ -0,0 +1,50 @@
"""File type handlers for Microsoft OneDrive."""
# Key whose presence in a drive item dict marks the item as a folder
# (checked by is_folder below).
ONEDRIVE_FOLDER_FACET = "folder"
# MIME type used for OneNote files.
ONENOTE_MIME = "application/msonenote"
# MIME types that should_skip_file treats as non-indexable.
SKIP_MIME_TYPES = frozenset(
{
ONENOTE_MIME,
"application/vnd.ms-onenotesection",
"application/vnd.ms-onenotenotebook",
}
)
MIME_TO_EXTENSION: dict[str, str] = {
"application/pdf": ".pdf",
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": ".xlsx",
"application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx",
"application/vnd.openxmlformats-officedocument.presentationml.presentation": ".pptx",
"application/vnd.ms-excel": ".xls",
"application/msword": ".doc",
"application/vnd.ms-powerpoint": ".ppt",
"text/plain": ".txt",
"text/csv": ".csv",
"text/html": ".html",
"text/markdown": ".md",
"application/json": ".json",
"application/xml": ".xml",
"image/png": ".png",
"image/jpeg": ".jpg",
}
def get_extension_from_mime(mime_type: str) -> str | None:
return MIME_TO_EXTENSION.get(mime_type)
def is_folder(item: dict) -> bool:
    """Tell whether a drive item dict carries the folder facet key."""
    has_folder_facet = ONEDRIVE_FOLDER_FACET in item
    return has_folder_facet
def should_skip_file(item: dict) -> bool:
    """Decide whether a drive item must be left out of indexing.

    Folders, items carrying a ``remoteItem`` or ``package`` key, and files
    whose MIME type is listed in ``SKIP_MIME_TYPES`` are skipped.
    """
    if is_folder(item) or "remoteItem" in item or "package" in item:
        return True
    file_facet = item.get("file", {})
    return file_facet.get("mimeType", "") in SKIP_MIME_TYPES

View file

@ -0,0 +1,90 @@
"""Folder management for Microsoft OneDrive."""
import logging
from typing import Any
from .client import OneDriveClient
from .file_types import is_folder, should_skip_file
logger = logging.getLogger(__name__)
async def list_folder_contents(
    client: OneDriveClient,
    parent_id: str | None = None,
) -> tuple[list[dict[str, Any]], str | None]:
    """Fetch the direct children of a OneDrive folder.

    Each returned item gets an ``isFolder`` flag, and the list is ordered
    with folders first, then case-insensitively by name. The drive root is
    listed when ``parent_id`` is not given.

    Returns (items, None) on success, or ([], error message) on failure.
    """
    try:
        children, failure = await client.list_children(parent_id or "root")
        if failure:
            return [], failure

        for child in children:
            child["isFolder"] = is_folder(child)

        def sort_key(entry: dict[str, Any]) -> tuple[bool, str]:
            # False sorts before True, so folders come first.
            return (not entry["isFolder"], entry.get("name", "").lower())

        children.sort(key=sort_key)

        total = len(children)
        folder_count = sum(1 for child in children if child["isFolder"])
        file_count = total - folder_count
        location = f"in folder {parent_id}" if parent_id else "in root"
        logger.info(
            f"Listed {total} items ({folder_count} folders, {file_count} files) "
            + location
        )
        return children, None
    except Exception as e:
        logger.error(f"Error listing folder contents: {e!s}", exc_info=True)
        return [], f"Error listing folder contents: {e!s}"
async def get_files_in_folder(
    client: OneDriveClient,
    folder_id: str,
    include_subfolders: bool = True,
) -> tuple[list[dict[str, Any]], str | None]:
    """Collect the indexable files below a folder.

    Subfolders are walked recursively when ``include_subfolders`` is true. A
    subfolder that fails to list is logged and left out; a failure listing
    ``folder_id`` itself is returned as the error.

    Returns (files, None) on success, or ([], error message) on failure.
    """
    try:
        children, failure = await client.list_children(folder_id)
        if failure:
            return [], failure

        collected: list[dict[str, Any]] = []
        for child in children:
            if not is_folder(child):
                if not should_skip_file(child):
                    collected.append(child)
                continue

            if not include_subfolders:
                continue

            nested, nested_failure = await get_files_in_folder(
                client, child["id"], include_subfolders=True
            )
            if nested_failure:
                logger.warning(f"Error recursing into folder {child.get('name')}: {nested_failure}")
                continue
            collected.extend(nested)

        return collected, None
    except Exception as e:
        logger.error(f"Error getting files in folder: {e!s}", exc_info=True)
        return [], f"Error getting files in folder: {e!s}"
async def get_file_by_id(
    client: OneDriveClient,
    file_id: str,
) -> tuple[dict[str, Any] | None, str | None]:
    """Look up a single item's metadata by its id.

    Returns (item, None) on success, or (None, error message) when the
    lookup fails, yields nothing, or raises.
    """
    try:
        found, failure = await client.get_item_metadata(file_id)
        if failure:
            return None, failure
        if found:
            return found, None
        return None, f"File not found: {file_id}"
    except Exception as e:
        logger.error(f"Error getting file by ID: {e!s}", exc_info=True)
        return None, f"Error getting file by ID: {e!s}"