mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-04-27 17:56:25 +02:00
372 lines
12 KiB
Python
372 lines
12 KiB
Python
"""Google Drive API client."""
|
|
|
|
import asyncio
|
|
import io
|
|
import logging
|
|
import threading
|
|
import time
|
|
from typing import Any
|
|
|
|
import httplib2
|
|
from google.oauth2.credentials import Credentials
|
|
from google_auth_httplib2 import AuthorizedHttp
|
|
from googleapiclient.discovery import build
|
|
from googleapiclient.errors import HttpError
|
|
from googleapiclient.http import MediaIoBaseUpload
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from .credentials import get_valid_credentials
|
|
from .file_types import GOOGLE_DOC, GOOGLE_SHEET
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def _build_thread_http(credentials: Credentials) -> AuthorizedHttp:
    """Return a fresh authorized HTTP transport for one worker thread.

    ``httplib2.Http`` is not thread-safe, so concurrent downloads must not
    share a single instance; each caller gets a brand-new ``httplib2.Http``
    wrapped with the supplied ``credentials``.
    """
    transport = httplib2.Http()
    return AuthorizedHttp(credentials, http=transport)
|
|
|
|
|
|
class GoogleDriveClient:
    """Client for Google Drive API operations.

    The ``googleapiclient`` request objects are blocking, so every network
    round-trip issued by this class is dispatched with ``asyncio.to_thread``
    and, where credentials are available, executed over a per-thread
    transport created by ``_build_thread_http``.
    """

    def __init__(
        self,
        session: AsyncSession,
        connector_id: int,
        credentials: "Credentials | None" = None,
    ):
        """
        Initialize Google Drive client.

        Args:
            session: Database session
            connector_id: ID of the Drive connector
            credentials: Pre-built credentials (e.g. from Composio). If None,
                credentials are loaded from the DB connector config.
        """
        self.session = session
        self.connector_id = connector_id
        self._credentials = credentials
        # Credentials actually used to build the service; populated by
        # ``get_service`` and handed to worker threads for their transports.
        self._resolved_credentials: Credentials | None = None
        self.service = None
        # Serializes first-time service construction across coroutines.
        self._service_lock = asyncio.Lock()

    async def get_service(self):
        """
        Get or create the Drive service instance.

        Returns:
            Google Drive service instance

        Raises:
            Exception: If service creation fails
        """
        if self.service:
            return self.service

        async with self._service_lock:
            # Re-check: another coroutine may have built the service while
            # this one was waiting for the lock.
            if self.service:
                return self.service

            try:
                if self._credentials:
                    credentials = self._credentials
                else:
                    credentials = await get_valid_credentials(
                        self.session, self.connector_id
                    )
                self._resolved_credentials = credentials
                self.service = build("drive", "v3", credentials=credentials)
                return self.service
            except Exception as e:
                raise Exception(f"Failed to create Google Drive service: {e!s}") from e

    @staticmethod
    def _sync_execute_request(request, credentials: "Credentials | None"):
        """Blocking ``request.execute()`` — runs on a worker thread via ``to_thread``.

        Args:
            request: A prepared googleapiclient request object.
            credentials: Credentials used to build a per-thread transport.
                When None (e.g. ``service`` was injected without going
                through ``get_service``), the request's own default
                transport is used instead.

        Returns:
            Whatever ``request.execute()`` returns.

        Raises:
            Any exception raised by ``request.execute()`` propagates.
        """
        if credentials is None:
            return request.execute()
        # A dedicated transport keeps this thread off the shared
        # ``httplib2.Http`` instance, which is not thread-safe.
        return request.execute(http=_build_thread_http(credentials))

    async def list_files(
        self,
        query: str = "",
        fields: str = "nextPageToken, files(id, name, mimeType, modifiedTime, md5Checksum, size, webViewLink, parents, owners, createdTime, description)",
        page_size: int = 100,
        page_token: str | None = None,
    ) -> tuple[list[dict[str, Any]], str | None, str | None]:
        """
        List files from Google Drive with pagination.

        Args:
            query: Search query (e.g., "mimeType != 'application/vnd.google-apps.folder'")
            fields: Fields to retrieve
            page_size: Number of files per page (max 1000)
            page_token: Token for next page

        Returns:
            Tuple of (files list, next_page_token, error message)
        """
        try:
            service = await self.get_service()

            params: dict[str, Any] = {
                "pageSize": min(page_size, 1000),
                "fields": fields,
                "supportsAllDrives": True,
                "includeItemsFromAllDrives": True,
            }

            if query:
                params["q"] = query
            if page_token:
                params["pageToken"] = page_token

            # Building the request does no I/O; only execute() blocks, so
            # only that part is moved off the event loop.
            request = service.files().list(**params)
            result = await asyncio.to_thread(
                self._sync_execute_request,
                request,
                self._resolved_credentials,
            )

            files = result.get("files", [])
            next_token = result.get("nextPageToken")

            return files, next_token, None

        except HttpError as e:
            error_msg = f"HTTP error listing files: {e.resp.status} - {e.error_details}"
            return [], None, error_msg
        except Exception as e:
            return [], None, f"Error listing files: {e!s}"

    async def get_file_metadata(
        self, file_id: str, fields: str = "*"
    ) -> tuple[dict[str, Any] | None, str | None]:
        """
        Get metadata for a specific file.

        Args:
            file_id: ID of the file
            fields: Fields to retrieve

        Returns:
            Tuple of (file metadata, error message)
        """
        try:
            service = await self.get_service()
            request = service.files().get(
                fileId=file_id, fields=fields, supportsAllDrives=True
            )
            # Run the blocking HTTP call on a worker thread.
            file = await asyncio.to_thread(
                self._sync_execute_request,
                request,
                self._resolved_credentials,
            )
            return file, None
        except HttpError as e:
            return None, f"HTTP error getting file metadata: {e.resp.status}"
        except Exception as e:
            return None, f"Error getting file metadata: {e!s}"

    @staticmethod
    def _sync_download_file(
        service,
        file_id: str,
        credentials: Credentials,
    ) -> tuple[bytes | None, str | None]:
        """Blocking download — runs on a worker thread via ``to_thread``.

        Returns:
            Tuple of (file content bytes, error message); exactly one of the
            two is None.
        """
        thread = threading.current_thread().name
        t0 = time.monotonic()
        logger.info(f"[download] START file={file_id} thread={thread}")
        try:
            from googleapiclient.http import MediaIoBaseDownload

            http = _build_thread_http(credentials)
            request = service.files().get_media(fileId=file_id)
            # Route this request through the per-thread transport.
            request.http = http
            fh = io.BytesIO()
            downloader = MediaIoBaseDownload(fh, request)
            done = False
            while not done:
                _, done = downloader.next_chunk()
            return fh.getvalue(), None
        except HttpError as e:
            return None, f"HTTP error downloading file: {e.resp.status}"
        except Exception as e:
            return None, f"Error downloading file: {e!s}"
        finally:
            logger.info(
                f"[download] END file={file_id} thread={thread} elapsed={time.monotonic() - t0:.2f}s"
            )

    async def download_file(self, file_id: str) -> tuple[bytes | None, str | None]:
        """
        Download binary file content.

        Args:
            file_id: ID of the file to download

        Returns:
            Tuple of (file content bytes, error message)
        """
        service = await self.get_service()
        return await asyncio.to_thread(
            self._sync_download_file,
            service,
            file_id,
            self._resolved_credentials,
        )

    @staticmethod
    def _sync_download_file_to_disk(
        service,
        file_id: str,
        dest_path: str,
        chunksize: int,
        credentials: Credentials,
    ) -> str | None:
        """Blocking download-to-disk — runs on a worker thread via ``to_thread``.

        Returns:
            None on success, otherwise an error message.
        """
        thread = threading.current_thread().name
        t0 = time.monotonic()
        logger.info(f"[download-to-disk] START file={file_id} thread={thread}")
        try:
            from googleapiclient.http import MediaIoBaseDownload

            http = _build_thread_http(credentials)
            request = service.files().get_media(fileId=file_id)
            # Route this request through the per-thread transport.
            request.http = http
            with open(dest_path, "wb") as fh:
                downloader = MediaIoBaseDownload(fh, request, chunksize=chunksize)
                done = False
                while not done:
                    _, done = downloader.next_chunk()
            return None
        except HttpError as e:
            return f"HTTP error downloading file: {e.resp.status}"
        except Exception as e:
            return f"Error downloading file: {e!s}"
        finally:
            logger.info(
                f"[download-to-disk] END file={file_id} thread={thread} elapsed={time.monotonic() - t0:.2f}s"
            )

    async def download_file_to_disk(
        self,
        file_id: str,
        dest_path: str,
        chunksize: int = 5 * 1024 * 1024,
    ) -> str | None:
        """Stream file directly to disk in chunks, avoiding full in-memory buffering.

        Returns error message on failure, None on success.
        """
        service = await self.get_service()
        return await asyncio.to_thread(
            self._sync_download_file_to_disk,
            service,
            file_id,
            dest_path,
            chunksize,
            self._resolved_credentials,
        )

    @staticmethod
    def _sync_export_google_file(
        service,
        file_id: str,
        mime_type: str,
        credentials: Credentials,
    ) -> tuple[bytes | None, str | None]:
        """Blocking export — runs on a worker thread via ``to_thread``.

        Returns:
            Tuple of (exported content as bytes, error message); exactly one
            of the two is None.
        """
        thread = threading.current_thread().name
        t0 = time.monotonic()
        logger.info(f"[export] START file={file_id} thread={thread}")
        try:
            http = _build_thread_http(credentials)
            content = (
                service.files()
                .export(fileId=file_id, mimeType=mime_type)
                .execute(http=http)
            )
            # Normalize non-bytes results so callers always receive bytes.
            if not isinstance(content, bytes):
                content = content.encode("utf-8")
            return content, None
        except HttpError as e:
            return None, f"HTTP error exporting file: {e.resp.status}"
        except Exception as e:
            return None, f"Error exporting file: {e!s}"
        finally:
            logger.info(
                f"[export] END file={file_id} thread={thread} elapsed={time.monotonic() - t0:.2f}s"
            )

    async def export_google_file(
        self, file_id: str, mime_type: str
    ) -> tuple[bytes | None, str | None]:
        """
        Export Google Workspace file to specified format.

        Args:
            file_id: ID of the Google file
            mime_type: Target MIME type (e.g., 'application/pdf', 'text/plain')

        Returns:
            Tuple of (exported content as bytes, error message)
        """
        service = await self.get_service()
        return await asyncio.to_thread(
            self._sync_export_google_file,
            service,
            file_id,
            mime_type,
            self._resolved_credentials,
        )

    async def create_file(
        self,
        name: str,
        mime_type: str,
        parent_folder_id: str | None = None,
        content: str | None = None,
    ) -> dict[str, Any]:
        """
        Create a file in Google Drive, optionally with initial content.

        Args:
            name: Name of the new file
            mime_type: MIME type of the new file
            parent_folder_id: ID of the folder to create the file in, if any
            content: Initial content. For ``GOOGLE_DOC`` it is rendered from
                Markdown to HTML and uploaded as ``text/html``; for
                ``GOOGLE_SHEET`` it is uploaded as ``text/csv``.

        Returns:
            The API response for the created file (requested fields:
            id, name, mimeType, webViewLink)

        Raises:
            Any exception from service creation or the API call propagates;
            unlike the read methods, errors are not converted to messages.
        """
        service = await self.get_service()

        body: dict[str, Any] = {"name": name, "mimeType": mime_type}
        if parent_folder_id:
            body["parents"] = [parent_folder_id]

        media: MediaIoBaseUpload | None = None
        if content:
            if mime_type == GOOGLE_DOC:
                import markdown as md_lib

                html = md_lib.markdown(content)
                media = MediaIoBaseUpload(
                    io.BytesIO(html.encode("utf-8")),
                    mimetype="text/html",
                    resumable=False,
                )
            elif mime_type == GOOGLE_SHEET:
                media = MediaIoBaseUpload(
                    io.BytesIO(content.encode("utf-8")),
                    mimetype="text/csv",
                    resumable=False,
                )
            # NOTE(review): content supplied with any other mime_type is
            # not uploaded — confirm this is intended.

        create_kwargs: dict[str, Any] = {
            "body": body,
            "fields": "id,name,mimeType,webViewLink",
            "supportsAllDrives": True,
        }
        # Only pass media_body when there is something to upload, matching
        # the plain metadata-only create call otherwise.
        if media:
            create_kwargs["media_body"] = media

        request = service.files().create(**create_kwargs)
        # Run the blocking HTTP call on a worker thread.
        return await asyncio.to_thread(
            self._sync_execute_request,
            request,
            self._resolved_credentials,
        )

    async def trash_file(self, file_id: str) -> bool:
        """
        Move a file to the trash by setting its ``trashed`` flag.

        Args:
            file_id: ID of the file to trash

        Returns:
            True once the update call has completed

        Raises:
            Any exception from service creation or the API call propagates.
        """
        service = await self.get_service()
        request = service.files().update(
            fileId=file_id,
            body={"trashed": True},
            supportsAllDrives=True,
        )
        # Run the blocking HTTP call on a worker thread.
        await asyncio.to_thread(
            self._sync_execute_request,
            request,
            self._resolved_credentials,
        )
        return True
|