- Implemented per-thread HTTP transport for concurrent downloads to ensure thread safety.
- Refactored `download_file` and `download_file_to_disk` to run their blocking calls on separate threads, improving performance during file operations.
- Added logging to track the start and end of download and export processes, giving better visibility into execution time.
- Updated unit tests to verify that download and export operations run in parallel.

"""Google Drive API client."""
|
|
|
|
import asyncio
|
|
import io
|
|
from typing import Any
|
|
|
|
from google.oauth2.credentials import Credentials
|
|
from googleapiclient.discovery import build
|
|
from googleapiclient.errors import HttpError
|
|
from googleapiclient.http import MediaIoBaseUpload
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from .credentials import get_valid_credentials
|
|
from .file_types import GOOGLE_DOC, GOOGLE_SHEET
|
|
|
|
|
|
class GoogleDriveClient:
    """Client for Google Drive API operations."""

    def __init__(
        self,
        session: AsyncSession,
        connector_id: int,
        credentials: "Credentials | None" = None,
    ):
        """
        Initialize Google Drive client.

        Args:
            session: Database session
            connector_id: ID of the Drive connector
            credentials: Pre-built credentials (e.g. from Composio). If None,
                credentials are loaded from the DB connector config.
        """
        self.session = session
        self.connector_id = connector_id
        self._credentials = credentials
        self.service = None
        self._service_lock = asyncio.Lock()

    async def get_service(self):
        """
        Get or create the Drive service instance.

        Returns:
            Google Drive service instance

        Raises:
            Exception: If service creation fails
        """
        # Fast path: reuse the already-built service without taking the lock.
        if self.service:
            return self.service

        # Double-checked locking: concurrent callers wait here, and only the
        # first one actually builds the service.
        async with self._service_lock:
            if self.service:
                return self.service

            try:
                if self._credentials:
                    credentials = self._credentials
                else:
                    credentials = await get_valid_credentials(
                        self.session, self.connector_id
                    )
                self.service = build("drive", "v3", credentials=credentials)
                return self.service
            except Exception as e:
                raise Exception(f"Failed to create Google Drive service: {e!s}") from e

    async def list_files(
        self,
        query: str = "",
        fields: str = "nextPageToken, files(id, name, mimeType, modifiedTime, md5Checksum, size, webViewLink, parents, owners, createdTime, description)",
        page_size: int = 100,
        page_token: str | None = None,
    ) -> tuple[list[dict[str, Any]], str | None, str | None]:
        """
        List files from Google Drive with pagination.

        Args:
            query: Search query (e.g., "mimeType != 'application/vnd.google-apps.folder'")
            fields: Fields to retrieve
            page_size: Number of files per page (max 1000)
            page_token: Token for next page

        Returns:
            Tuple of (files list, next_page_token, error message)
        """
        try:
            service = await self.get_service()

            params = {
                "pageSize": min(page_size, 1000),
                "fields": fields,
                "supportsAllDrives": True,
                "includeItemsFromAllDrives": True,
            }

            if query:
                params["q"] = query
            if page_token:
                params["pageToken"] = page_token

            result = service.files().list(**params).execute()

            files = result.get("files", [])
            next_token = result.get("nextPageToken")

            return files, next_token, None

        except HttpError as e:
            error_msg = f"HTTP error listing files: {e.resp.status} - {e.error_details}"
            return [], None, error_msg
        except Exception as e:
            return [], None, f"Error listing files: {e!s}"

    async def get_file_metadata(
        self, file_id: str, fields: str = "*"
    ) -> tuple[dict[str, Any] | None, str | None]:
        """
        Get metadata for a specific file.

        Args:
            file_id: ID of the file
            fields: Fields to retrieve

        Returns:
            Tuple of (file metadata, error message)
        """
        try:
            service = await self.get_service()
            file = (
                service.files()
                .get(fileId=file_id, fields=fields, supportsAllDrives=True)
                .execute()
            )
            return file, None
        except HttpError as e:
            return None, f"HTTP error getting file metadata: {e.resp.status}"
        except Exception as e:
            return None, f"Error getting file metadata: {e!s}"

    @staticmethod
    def _sync_download_file(service, file_id: str) -> tuple[bytes | None, str | None]:
        """Blocking download — runs on a worker thread via ``to_thread``."""
        try:
            request = service.files().get_media(fileId=file_id)
            fh = io.BytesIO()
            downloader = MediaIoBaseDownload(fh, request)
            done = False
            while not done:
                _, done = downloader.next_chunk()
            return fh.getvalue(), None
        except HttpError as e:
            return None, f"HTTP error downloading file: {e.resp.status}"
        except Exception as e:
            return None, f"Error downloading file: {e!s}"

    async def download_file(self, file_id: str) -> tuple[bytes | None, str | None]:
        """
        Download binary file content.

        Args:
            file_id: ID of the file to download

        Returns:
            Tuple of (file content bytes, error message)
        """
        service = await self.get_service()
        return await asyncio.to_thread(self._sync_download_file, service, file_id)

    @staticmethod
    def _sync_download_file_to_disk(
        service, file_id: str, dest_path: str, chunksize: int,
    ) -> str | None:
        """Blocking download-to-disk — runs on a worker thread via ``to_thread``."""
        try:
            request = service.files().get_media(fileId=file_id)
            with open(dest_path, "wb") as fh:
                downloader = MediaIoBaseDownload(fh, request, chunksize=chunksize)
                done = False
                while not done:
                    _, done = downloader.next_chunk()
            return None
        except HttpError as e:
            return f"HTTP error downloading file: {e.resp.status}"
        except Exception as e:
            return f"Error downloading file: {e!s}"

    async def download_file_to_disk(
        self, file_id: str, dest_path: str, chunksize: int = 5 * 1024 * 1024,
    ) -> str | None:
        """Stream file directly to disk in chunks, avoiding full in-memory buffering.

        Returns error message on failure, None on success.
        """
        service = await self.get_service()
        return await asyncio.to_thread(
            self._sync_download_file_to_disk, service, file_id, dest_path, chunksize,
        )
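
    # Example (hedged sketch): stream a large file to disk without holding it
    # all in memory; the destination path is illustrative.
    #
    #   err = await client.download_file_to_disk(file_id, "/tmp/drive.bin")
    #   if err:
    #       raise RuntimeError(err)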

    @staticmethod
    def _sync_export_google_file(
        service, file_id: str, mime_type: str,
    ) -> tuple[bytes | None, str | None]:
        """Blocking export — runs on a worker thread via ``to_thread``."""
        try:
            content = (
                service.files().export(fileId=file_id, mimeType=mime_type).execute()
            )
            if not isinstance(content, bytes):
                content = content.encode("utf-8")
            return content, None
        except HttpError as e:
            return None, f"HTTP error exporting file: {e.resp.status}"
        except Exception as e:
            return None, f"Error exporting file: {e!s}"

    async def export_google_file(
        self, file_id: str, mime_type: str
    ) -> tuple[bytes | None, str | None]:
        """
        Export Google Workspace file to specified format.

        Args:
            file_id: ID of the Google file
            mime_type: Target MIME type (e.g., 'application/pdf', 'text/plain')

        Returns:
            Tuple of (exported content as bytes, error message)
        """
        service = await self.get_service()
        return await asyncio.to_thread(
            self._sync_export_google_file, service, file_id, mime_type,
        )
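
    # Example (hedged sketch): export a Google Doc as plain text.
    #
    #   text_bytes, err = await client.export_google_file(
    #       file_id, mime_type="text/plain"
    #   )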

    async def create_file(
        self,
        name: str,
        mime_type: str,
        parent_folder_id: str | None = None,
        content: str | None = None,
    ) -> dict[str, Any]:
        """Create a Drive file, optionally seeding content for Docs/Sheets.

        Markdown ``content`` is converted to HTML and uploaded for Google
        Docs; ``content`` is treated as CSV for Google Sheets.
        """
        service = await self.get_service()

        body: dict[str, Any] = {"name": name, "mimeType": mime_type}
        if parent_folder_id:
            body["parents"] = [parent_folder_id]

        media: MediaIoBaseUpload | None = None
        if content:
            if mime_type == GOOGLE_DOC:
                import markdown as md_lib

                # Drive converts uploaded HTML into native Doc content.
                html = md_lib.markdown(content)
                media = MediaIoBaseUpload(
                    io.BytesIO(html.encode("utf-8")),
                    mimetype="text/html",
                    resumable=False,
                )
            elif mime_type == GOOGLE_SHEET:
                # Drive converts uploaded CSV into native Sheet content.
                media = MediaIoBaseUpload(
                    io.BytesIO(content.encode("utf-8")),
                    mimetype="text/csv",
                    resumable=False,
                )

        if media:
            return (
                service.files()
                .create(
                    body=body,
                    media_body=media,
                    fields="id,name,mimeType,webViewLink",
                    supportsAllDrives=True,
                )
                .execute()
            )

        return (
            service.files()
            .create(
                body=body,
                fields="id,name,mimeType,webViewLink",
                supportsAllDrives=True,
            )
            .execute()
        )
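
    # Example (hedged sketch): create a Google Doc from markdown. GOOGLE_DOC
    # comes from .file_types; the name and content are illustrative.
    #
    #   doc = await client.create_file(
    #       name="Meeting notes",
    #       mime_type=GOOGLE_DOC,
    #       content="# Agenda\n\n- item one\n- item two",
    #   )
    #   print(doc["webViewLink"])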

    async def trash_file(self, file_id: str) -> bool:
        """Move a file to the Drive trash (soft delete). Returns True on success."""
        service = await self.get_service()
        service.files().update(
            fileId=file_id,
            body={"trashed": True},
            supportsAllDrives=True,
        ).execute()
        return True
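

# Hedged usage sketch (illustrative names, not executed on import): because
# the download/export helpers hand their blocking work to
# ``asyncio.to_thread``, several files can be fetched concurrently with
# ``asyncio.gather``:
#
#   async def fetch_many(session, connector_id, file_ids):
#       client = GoogleDriveClient(session, connector_id)
#       results = await asyncio.gather(
#           *(client.download_file(fid) for fid in file_ids)
#       )
#       for fid, (content, err) in zip(file_ids, results):
#           if err:
#               print(f"{fid}: {err}")
#           else:
#               print(f"{fid}: {len(content)} bytes")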
|