mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-04-29 19:06:24 +02:00
283 lines
11 KiB
Python
283 lines
11 KiB
Python
"""Microsoft OneDrive API client using Microsoft Graph API v1.0."""
|
|
|
|
import logging
|
|
from datetime import UTC, datetime, timedelta
|
|
from typing import Any
|
|
|
|
import httpx
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy.future import select
|
|
from sqlalchemy.orm.attributes import flag_modified
|
|
|
|
from app.config import config
|
|
from app.db import SearchSourceConnector
|
|
from app.utils.oauth_security import TokenEncryption
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
GRAPH_API_BASE = "https://graph.microsoft.com/v1.0"
|
|
TOKEN_URL = "https://login.microsoftonline.com/common/oauth2/v2.0/token"
|
|
|
|
|
|
class OneDriveClient:
|
|
"""Client for Microsoft OneDrive via the Graph API."""
|
|
|
|
def __init__(self, session: AsyncSession, connector_id: int):
|
|
self._session = session
|
|
self._connector_id = connector_id
|
|
|
|
async def _get_valid_token(self) -> str:
|
|
"""Get a valid access token, refreshing if needed."""
|
|
result = await self._session.execute(
|
|
select(SearchSourceConnector).filter(
|
|
SearchSourceConnector.id == self._connector_id
|
|
)
|
|
)
|
|
connector = result.scalars().first()
|
|
if not connector:
|
|
raise ValueError(f"Connector {self._connector_id} not found")
|
|
|
|
cfg = connector.config or {}
|
|
is_encrypted = cfg.get("_token_encrypted", False)
|
|
token_encryption = (
|
|
TokenEncryption(config.SECRET_KEY) if config.SECRET_KEY else None
|
|
)
|
|
|
|
access_token = cfg.get("access_token", "")
|
|
refresh_token = cfg.get("refresh_token")
|
|
|
|
if is_encrypted and token_encryption:
|
|
if access_token:
|
|
access_token = token_encryption.decrypt_token(access_token)
|
|
if refresh_token:
|
|
refresh_token = token_encryption.decrypt_token(refresh_token)
|
|
|
|
expires_at_str = cfg.get("expires_at")
|
|
is_expired = False
|
|
if expires_at_str:
|
|
expires_at = datetime.fromisoformat(expires_at_str)
|
|
if expires_at.tzinfo is None:
|
|
expires_at = expires_at.replace(tzinfo=UTC)
|
|
is_expired = expires_at <= datetime.now(UTC)
|
|
|
|
if not is_expired and access_token:
|
|
return access_token
|
|
|
|
if not refresh_token:
|
|
cfg["auth_expired"] = True
|
|
connector.config = cfg
|
|
flag_modified(connector, "config")
|
|
await self._session.commit()
|
|
raise ValueError("OneDrive token expired and no refresh token available")
|
|
|
|
token_data = await self._refresh_token(refresh_token)
|
|
|
|
new_access = token_data["access_token"]
|
|
new_refresh = token_data.get("refresh_token", refresh_token)
|
|
expires_in = token_data.get("expires_in")
|
|
|
|
new_expires_at = None
|
|
if expires_in:
|
|
new_expires_at = datetime.now(UTC) + timedelta(seconds=int(expires_in))
|
|
|
|
if token_encryption:
|
|
cfg["access_token"] = token_encryption.encrypt_token(new_access)
|
|
cfg["refresh_token"] = token_encryption.encrypt_token(new_refresh)
|
|
else:
|
|
cfg["access_token"] = new_access
|
|
cfg["refresh_token"] = new_refresh
|
|
|
|
cfg["expires_at"] = new_expires_at.isoformat() if new_expires_at else None
|
|
cfg["expires_in"] = expires_in
|
|
cfg["_token_encrypted"] = bool(token_encryption)
|
|
cfg.pop("auth_expired", None)
|
|
|
|
connector.config = cfg
|
|
flag_modified(connector, "config")
|
|
await self._session.commit()
|
|
|
|
return new_access
|
|
|
|
async def _refresh_token(self, refresh_token: str) -> dict:
|
|
data = {
|
|
"client_id": config.MICROSOFT_CLIENT_ID,
|
|
"client_secret": config.MICROSOFT_CLIENT_SECRET,
|
|
"grant_type": "refresh_token",
|
|
"refresh_token": refresh_token,
|
|
"scope": "offline_access User.Read Files.Read.All Files.ReadWrite.All",
|
|
}
|
|
async with httpx.AsyncClient() as client:
|
|
resp = await client.post(
|
|
TOKEN_URL,
|
|
data=data,
|
|
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
|
timeout=30.0,
|
|
)
|
|
if resp.status_code != 200:
|
|
error_detail = resp.text
|
|
try:
|
|
error_json = resp.json()
|
|
error_detail = error_json.get("error_description", error_detail)
|
|
except Exception:
|
|
pass
|
|
raise ValueError(f"OneDrive token refresh failed: {error_detail}")
|
|
return resp.json()
|
|
|
|
async def _request(self, method: str, path: str, **kwargs) -> httpx.Response:
|
|
"""Make an authenticated request to the Graph API."""
|
|
token = await self._get_valid_token()
|
|
headers = {"Authorization": f"Bearer {token}"}
|
|
if "headers" in kwargs:
|
|
headers.update(kwargs.pop("headers"))
|
|
|
|
async with httpx.AsyncClient() as client:
|
|
resp = await client.request(
|
|
method,
|
|
f"{GRAPH_API_BASE}{path}",
|
|
headers=headers,
|
|
timeout=60.0,
|
|
**kwargs,
|
|
)
|
|
|
|
if resp.status_code == 401:
|
|
result = await self._session.execute(
|
|
select(SearchSourceConnector).filter(
|
|
SearchSourceConnector.id == self._connector_id
|
|
)
|
|
)
|
|
connector = result.scalars().first()
|
|
if connector:
|
|
cfg = connector.config or {}
|
|
cfg["auth_expired"] = True
|
|
connector.config = cfg
|
|
flag_modified(connector, "config")
|
|
await self._session.commit()
|
|
raise ValueError("OneDrive authentication expired (401)")
|
|
|
|
return resp
|
|
|
|
async def list_children(
|
|
self, item_id: str = "root"
|
|
) -> tuple[list[dict[str, Any]], str | None]:
|
|
all_items: list[dict[str, Any]] = []
|
|
url = f"/me/drive/items/{item_id}/children"
|
|
params: dict[str, Any] = {
|
|
"$top": 200,
|
|
"$select": "id,name,size,file,folder,parentReference,lastModifiedDateTime,createdDateTime,webUrl,remoteItem,package",
|
|
}
|
|
while url:
|
|
resp = await self._request("GET", url, params=params)
|
|
if resp.status_code != 200:
|
|
return [], f"Failed to list children: {resp.status_code} - {resp.text}"
|
|
data = resp.json()
|
|
all_items.extend(data.get("value", []))
|
|
next_link = data.get("@odata.nextLink")
|
|
if next_link:
|
|
url = next_link.replace(GRAPH_API_BASE, "")
|
|
params = {}
|
|
else:
|
|
url = ""
|
|
return all_items, None
|
|
|
|
async def get_item_metadata(
|
|
self, item_id: str
|
|
) -> tuple[dict[str, Any] | None, str | None]:
|
|
resp = await self._request(
|
|
"GET",
|
|
f"/me/drive/items/{item_id}",
|
|
params={
|
|
"$select": "id,name,size,file,folder,parentReference,lastModifiedDateTime,createdDateTime,webUrl"
|
|
},
|
|
)
|
|
if resp.status_code != 200:
|
|
return None, f"Failed to get item: {resp.status_code} - {resp.text}"
|
|
return resp.json(), None
|
|
|
|
async def download_file(self, item_id: str) -> tuple[bytes | None, str | None]:
|
|
token = await self._get_valid_token()
|
|
async with httpx.AsyncClient(follow_redirects=True) as client:
|
|
resp = await client.get(
|
|
f"{GRAPH_API_BASE}/me/drive/items/{item_id}/content",
|
|
headers={"Authorization": f"Bearer {token}"},
|
|
timeout=120.0,
|
|
)
|
|
if resp.status_code != 200:
|
|
return None, f"Download failed: {resp.status_code}"
|
|
return resp.content, None
|
|
|
|
async def download_file_to_disk(self, item_id: str, dest_path: str) -> str | None:
|
|
"""Stream file content to disk. Returns error message on failure."""
|
|
token = await self._get_valid_token()
|
|
async with (
|
|
httpx.AsyncClient(follow_redirects=True) as client,
|
|
client.stream(
|
|
"GET",
|
|
f"{GRAPH_API_BASE}/me/drive/items/{item_id}/content",
|
|
headers={"Authorization": f"Bearer {token}"},
|
|
timeout=120.0,
|
|
) as resp,
|
|
):
|
|
if resp.status_code != 200:
|
|
return f"Download failed: {resp.status_code}"
|
|
with open(dest_path, "wb") as f:
|
|
async for chunk in resp.aiter_bytes(chunk_size=5 * 1024 * 1024):
|
|
f.write(chunk)
|
|
return None
|
|
|
|
async def create_file(
|
|
self,
|
|
name: str,
|
|
parent_id: str | None = None,
|
|
content: str | bytes | None = None,
|
|
mime_type: str | None = None,
|
|
) -> dict[str, Any]:
|
|
"""Create (upload) a file in OneDrive."""
|
|
folder_path = f"/me/drive/items/{parent_id or 'root'}"
|
|
if isinstance(content, bytes):
|
|
body = content
|
|
else:
|
|
body = (content or "").encode("utf-8")
|
|
resp = await self._request(
|
|
"PUT",
|
|
f"{folder_path}:/{name}:/content",
|
|
content=body,
|
|
headers={"Content-Type": mime_type or "application/octet-stream"},
|
|
)
|
|
if resp.status_code not in (200, 201):
|
|
raise ValueError(f"File creation failed: {resp.status_code} - {resp.text}")
|
|
return resp.json()
|
|
|
|
async def trash_file(self, item_id: str) -> bool:
|
|
"""Delete (move to recycle bin) a OneDrive item."""
|
|
resp = await self._request("DELETE", f"/me/drive/items/{item_id}")
|
|
if resp.status_code not in (200, 204):
|
|
raise ValueError(f"Trash failed: {resp.status_code} - {resp.text}")
|
|
return True
|
|
|
|
async def get_delta(
|
|
self, folder_id: str | None = None, delta_link: str | None = None
|
|
) -> tuple[list[dict[str, Any]], str | None, str | None]:
|
|
"""Get delta changes. Returns (changes, new_delta_link, error)."""
|
|
all_changes: list[dict[str, Any]] = []
|
|
if delta_link:
|
|
url = delta_link.replace(GRAPH_API_BASE, "")
|
|
elif folder_id:
|
|
url = f"/me/drive/items/{folder_id}/delta"
|
|
else:
|
|
url = "/me/drive/root/delta"
|
|
|
|
params: dict[str, Any] = {"$top": 200}
|
|
while url:
|
|
resp = await self._request("GET", url, params=params)
|
|
if resp.status_code != 200:
|
|
return [], None, f"Delta failed: {resp.status_code} - {resp.text}"
|
|
data = resp.json()
|
|
all_changes.extend(data.get("value", []))
|
|
next_link = data.get("@odata.nextLink")
|
|
new_delta_link = data.get("@odata.deltaLink")
|
|
if next_link:
|
|
url = next_link.replace(GRAPH_API_BASE, "")
|
|
params = {}
|
|
else:
|
|
url = ""
|
|
return all_changes, new_delta_link, None
|