SurfSense/surfsense_backend/app/connectors/google_drive/client.py
2026-03-28 16:39:46 -07:00

372 lines
12 KiB
Python

"""Google Drive API client."""
import asyncio
import io
import logging
import threading
import time
from typing import Any
import httplib2
from google.oauth2.credentials import Credentials
from google_auth_httplib2 import AuthorizedHttp
from googleapiclient.discovery import build
from googleapiclient.errors import HttpError
from googleapiclient.http import MediaIoBaseUpload
from sqlalchemy.ext.asyncio import AsyncSession
from .credentials import get_valid_credentials
from .file_types import GOOGLE_DOC, GOOGLE_SHEET
# Module-level logger, named after this module so every record emitted by the
# Drive client is attributed to it.
logger = logging.getLogger(name=__name__)
def _build_thread_http(credentials: Credentials) -> AuthorizedHttp:
    """Create a fresh authorized HTTP transport for the calling worker thread.

    ``httplib2.Http`` is not thread-safe, so concurrent downloads/exports must
    not share one transport. Each call returns a brand-new transport that the
    calling thread owns exclusively.

    The base transport comes from ``googleapiclient.http.build_http()`` rather
    than a bare ``httplib2.Http()``: ``build_http()`` applies the client
    library's default socket timeout, so a stalled connection raises instead
    of blocking the worker thread forever.

    Args:
        credentials: Credentials used to authorize requests sent over the
            returned transport.

    Returns:
        An ``AuthorizedHttp`` wrapping a newly created ``httplib2.Http``.
    """
    # Local import keeps this fix self-contained within the function.
    from googleapiclient.http import build_http

    return AuthorizedHttp(credentials, http=build_http())
class GoogleDriveClient:
    """Client for Google Drive API operations.

    The Drive ``service`` is built lazily by ``get_service``. Every blocking
    Drive call (``.execute()`` or a media download) runs on a worker thread
    via ``asyncio.to_thread`` with its own transport from
    ``_build_thread_http``. This keeps the event loop responsive, and it means
    the service's shared ``httplib2.Http`` (not thread-safe) is never driven
    from several threads at once.
    """

    def __init__(
        self,
        session: AsyncSession,
        connector_id: int,
        credentials: "Credentials | None" = None,
    ):
        """
        Initialize Google Drive client.

        Args:
            session: Database session (used to load credentials when none
                are supplied).
            connector_id: ID of the Drive connector
            credentials: Pre-built credentials (e.g. from Composio). If None,
                credentials are loaded from the DB connector config.
        """
        self.session = session
        self.connector_id = connector_id
        self._credentials = credentials
        # Credentials actually used to build ``self.service``; handed to worker
        # threads so each one can build its own authorized transport.
        self._resolved_credentials: Credentials | None = None
        self.service = None
        # Guards lazy service creation so concurrent first callers build it once.
        self._service_lock = asyncio.Lock()

    async def get_service(self):
        """
        Get or create the Drive service instance.

        The check is repeated inside the lock so that callers which were
        waiting on it reuse the service created by the first caller.

        Returns:
            Google Drive service instance

        Raises:
            Exception: If service creation fails (the original error is
                chained as ``__cause__``).
        """
        if self.service:
            return self.service
        async with self._service_lock:
            if self.service:
                return self.service
            try:
                if self._credentials:
                    credentials = self._credentials
                else:
                    credentials = await get_valid_credentials(
                        self.session, self.connector_id
                    )
                self._resolved_credentials = credentials
                self.service = build("drive", "v3", credentials=credentials)
                return self.service
            except Exception as e:
                raise Exception(f"Failed to create Google Drive service: {e!s}") from e

    @staticmethod
    def _sync_execute(request, credentials: Credentials) -> Any:
        """Execute a prepared API request over a fresh per-thread transport.

        Blocking — runs on a worker thread via ``to_thread``. The request object
        is only built by the caller (no I/O); the network round-trip happens
        here, on a transport owned by the current thread.
        """
        return request.execute(http=_build_thread_http(credentials))

    async def _execute_off_loop(self, request) -> Any:
        """Execute *request* on a worker thread so the event loop is not blocked.

        Must be called after ``get_service`` has succeeded, because it relies on
        ``self._resolved_credentials``.
        """
        return await asyncio.to_thread(
            self._sync_execute, request, self._resolved_credentials
        )

    async def list_files(
        self,
        query: str = "",
        fields: str = "nextPageToken, files(id, name, mimeType, modifiedTime, md5Checksum, size, webViewLink, parents, owners, createdTime, description)",
        page_size: int = 100,
        page_token: str | None = None,
    ) -> tuple[list[dict[str, Any]], str | None, str | None]:
        """
        List files from Google Drive with pagination.

        Items from shared drives are included.

        Args:
            query: Search query (e.g., "mimeType != 'application/vnd.google-apps.folder'")
            fields: Fields to retrieve
            page_size: Number of files per page (capped at 1000)
            page_token: Token for next page

        Returns:
            Tuple of (files list, next_page_token, error message). On failure
            the list is empty, the token is None and the message is set.
        """
        try:
            service = await self.get_service()
            params: dict[str, Any] = {
                "pageSize": min(page_size, 1000),
                "fields": fields,
                "supportsAllDrives": True,
                "includeItemsFromAllDrives": True,
            }
            if query:
                params["q"] = query
            if page_token:
                params["pageToken"] = page_token
            result = await self._execute_off_loop(service.files().list(**params))
            return result.get("files", []), result.get("nextPageToken"), None
        except HttpError as e:
            error_msg = f"HTTP error listing files: {e.resp.status} - {e.error_details}"
            return [], None, error_msg
        except Exception as e:
            return [], None, f"Error listing files: {e!s}"

    async def get_file_metadata(
        self, file_id: str, fields: str = "*"
    ) -> tuple[dict[str, Any] | None, str | None]:
        """
        Get metadata for a specific file.

        Args:
            file_id: ID of the file
            fields: Fields to retrieve

        Returns:
            Tuple of (file metadata, error message); exactly one is None.
        """
        try:
            service = await self.get_service()
            request = service.files().get(
                fileId=file_id, fields=fields, supportsAllDrives=True
            )
            return await self._execute_off_loop(request), None
        except HttpError as e:
            return None, f"HTTP error getting file metadata: {e.resp.status}"
        except Exception as e:
            return None, f"Error getting file metadata: {e!s}"

    @staticmethod
    def _sync_download_file(
        service,
        file_id: str,
        credentials: Credentials,
    ) -> tuple[bytes | None, str | None]:
        """Blocking download into memory — runs on a worker thread via ``to_thread``."""
        thread = threading.current_thread().name
        t0 = time.monotonic()
        logger.info("[download] START file=%s thread=%s", file_id, thread)
        try:
            from googleapiclient.http import MediaIoBaseDownload

            request = service.files().get_media(fileId=file_id)
            # MediaIoBaseDownload sends chunks over ``request.http``; swap in a
            # transport owned by this thread.
            request.http = _build_thread_http(credentials)
            fh = io.BytesIO()
            downloader = MediaIoBaseDownload(fh, request)
            done = False
            while not done:
                _, done = downloader.next_chunk()
            return fh.getvalue(), None
        except HttpError as e:
            return None, f"HTTP error downloading file: {e.resp.status}"
        except Exception as e:
            return None, f"Error downloading file: {e!s}"
        finally:
            logger.info(
                "[download] END file=%s thread=%s elapsed=%.2fs",
                file_id,
                thread,
                time.monotonic() - t0,
            )

    async def download_file(self, file_id: str) -> tuple[bytes | None, str | None]:
        """
        Download binary file content.

        Args:
            file_id: ID of the file to download

        Returns:
            Tuple of (file content bytes, error message)

        Raises:
            Exception: If service creation fails (see ``get_service``).
        """
        service = await self.get_service()
        return await asyncio.to_thread(
            self._sync_download_file,
            service,
            file_id,
            self._resolved_credentials,
        )

    @staticmethod
    def _sync_download_file_to_disk(
        service,
        file_id: str,
        dest_path: str,
        chunksize: int,
        credentials: Credentials,
    ) -> str | None:
        """Blocking download-to-disk — runs on a worker thread via ``to_thread``.

        ``dest_path`` is opened (created/truncated) before the first chunk is
        fetched, so a partially written file may remain there on failure.
        """
        thread = threading.current_thread().name
        t0 = time.monotonic()
        logger.info("[download-to-disk] START file=%s thread=%s", file_id, thread)
        try:
            from googleapiclient.http import MediaIoBaseDownload

            request = service.files().get_media(fileId=file_id)
            # Use a transport owned by this thread for the chunked download.
            request.http = _build_thread_http(credentials)
            with open(dest_path, "wb") as fh:
                downloader = MediaIoBaseDownload(fh, request, chunksize=chunksize)
                done = False
                while not done:
                    _, done = downloader.next_chunk()
            return None
        except HttpError as e:
            return f"HTTP error downloading file: {e.resp.status}"
        except Exception as e:
            return f"Error downloading file: {e!s}"
        finally:
            logger.info(
                "[download-to-disk] END file=%s thread=%s elapsed=%.2fs",
                file_id,
                thread,
                time.monotonic() - t0,
            )

    async def download_file_to_disk(
        self,
        file_id: str,
        dest_path: str,
        chunksize: int = 5 * 1024 * 1024,
    ) -> str | None:
        """Stream file directly to disk in chunks, avoiding full in-memory buffering.

        Args:
            file_id: ID of the file to download.
            dest_path: Destination path; created or truncated. A partially
                written file may be left behind when an error is returned.
            chunksize: Bytes requested per chunk (default 5 MiB).

        Returns:
            Error message on failure, None on success.

        Raises:
            Exception: If service creation fails (see ``get_service``).
        """
        service = await self.get_service()
        return await asyncio.to_thread(
            self._sync_download_file_to_disk,
            service,
            file_id,
            dest_path,
            chunksize,
            self._resolved_credentials,
        )

    @staticmethod
    def _sync_export_google_file(
        service,
        file_id: str,
        mime_type: str,
        credentials: Credentials,
    ) -> tuple[bytes | None, str | None]:
        """Blocking export — runs on a worker thread via ``to_thread``."""
        thread = threading.current_thread().name
        t0 = time.monotonic()
        logger.info("[export] START file=%s thread=%s", file_id, thread)
        try:
            content = (
                service.files()
                .export(fileId=file_id, mimeType=mime_type)
                .execute(http=_build_thread_http(credentials))
            )
            # Normalize to bytes for callers.
            if not isinstance(content, bytes):
                content = content.encode("utf-8")
            return content, None
        except HttpError as e:
            return None, f"HTTP error exporting file: {e.resp.status}"
        except Exception as e:
            return None, f"Error exporting file: {e!s}"
        finally:
            logger.info(
                "[export] END file=%s thread=%s elapsed=%.2fs",
                file_id,
                thread,
                time.monotonic() - t0,
            )

    async def export_google_file(
        self, file_id: str, mime_type: str
    ) -> tuple[bytes | None, str | None]:
        """
        Export Google Workspace file to specified format.

        Args:
            file_id: ID of the Google file
            mime_type: Target MIME type (e.g., 'application/pdf', 'text/plain')

        Returns:
            Tuple of (exported content as bytes, error message)

        Raises:
            Exception: If service creation fails (see ``get_service``).
        """
        service = await self.get_service()
        return await asyncio.to_thread(
            self._sync_export_google_file,
            service,
            file_id,
            mime_type,
            self._resolved_credentials,
        )

    @staticmethod
    def _build_upload_media(
        mime_type: str, content: str | None
    ) -> MediaIoBaseUpload | None:
        """Return the upload payload for *content*, or None when nothing is uploaded.

        Google Docs receive the content converted from Markdown to HTML, and
        Google Sheets receive it as CSV. Empty content and any other MIME type
        yield None.
        """
        if not content:
            return None
        if mime_type == GOOGLE_DOC:
            import markdown as md_lib

            html = md_lib.markdown(content)
            return MediaIoBaseUpload(
                io.BytesIO(html.encode("utf-8")),
                mimetype="text/html",
                resumable=False,
            )
        if mime_type == GOOGLE_SHEET:
            return MediaIoBaseUpload(
                io.BytesIO(content.encode("utf-8")),
                mimetype="text/csv",
                resumable=False,
            )
        return None

    async def create_file(
        self,
        name: str,
        mime_type: str,
        parent_folder_id: str | None = None,
        content: str | None = None,
    ) -> dict[str, Any]:
        """
        Create a file in Google Drive, optionally with initial content.

        Args:
            name: Name of the new file.
            mime_type: MIME type of the new file.
            parent_folder_id: Folder to create the file in; omitted when falsy.
            content: Initial content. It is uploaded only for ``GOOGLE_DOC``
                (converted from Markdown to HTML) and ``GOOGLE_SHEET`` (sent
                as CSV). It is ignored for other MIME types and when empty.

        Returns:
            The created file's ``id``, ``name``, ``mimeType`` and ``webViewLink``.

        Raises:
            HttpError: If the Drive API rejects the request.
            Exception: If service creation fails (see ``get_service``).
        """
        service = await self.get_service()
        body: dict[str, Any] = {"name": name, "mimeType": mime_type}
        if parent_folder_id:
            body["parents"] = [parent_folder_id]
        create_kwargs: dict[str, Any] = {
            "body": body,
            "fields": "id,name,mimeType,webViewLink",
            "supportsAllDrives": True,
        }
        media = self._build_upload_media(mime_type, content)
        if media is not None:
            create_kwargs["media_body"] = media
        request = service.files().create(**create_kwargs)
        return await self._execute_off_loop(request)

    async def trash_file(self, file_id: str) -> bool:
        """
        Move a file to the Drive trash by setting ``trashed`` to True.

        Args:
            file_id: ID of the file to trash.

        Returns:
            True once the API call has succeeded.

        Raises:
            HttpError: If the Drive API rejects the request.
            Exception: If service creation fails (see ``get_service``).
        """
        service = await self.get_service()
        request = service.files().update(
            fileId=file_id,
            body={"trashed": True},
            supportsAllDrives=True,
        )
        await self._execute_off_loop(request)
        return True