mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-06 20:15:17 +02:00
feat: enhance Google Drive client with improved logging and thread-safe operations
- Added logging to track the start and end of file download and export processes, improving visibility into execution time. - Implemented per-thread HTTP transport for concurrent downloads and exports, ensuring thread safety. - Refactored download and export methods to utilize resolved credentials, enhancing functionality. - Updated unit tests to validate the new threading and logging features, ensuring robust parallel execution.
This commit is contained in:
parent
d2a4b238d7
commit
00934ff462
4 changed files with 65 additions and 8 deletions
|
|
@ -2,9 +2,14 @@
|
|||
|
||||
import asyncio
|
||||
import io
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import httplib2
|
||||
from google.oauth2.credentials import Credentials
|
||||
from google_auth_httplib2 import AuthorizedHttp
|
||||
from googleapiclient.discovery import build
|
||||
from googleapiclient.errors import HttpError
|
||||
from googleapiclient.http import MediaIoBaseUpload
|
||||
|
|
@ -13,6 +18,14 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
from .credentials import get_valid_credentials
|
||||
from .file_types import GOOGLE_DOC, GOOGLE_SHEET
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _build_thread_http(credentials: Credentials) -> AuthorizedHttp:
|
||||
"""Create a per-thread HTTP transport so concurrent downloads don't share
|
||||
the same ``httplib2.Http`` (which is not thread-safe)."""
|
||||
return AuthorizedHttp(credentials, http=httplib2.Http())
|
||||
|
||||
|
||||
class GoogleDriveClient:
|
||||
"""Client for Google Drive API operations."""
|
||||
|
|
@ -35,6 +48,7 @@ class GoogleDriveClient:
|
|||
self.session = session
|
||||
self.connector_id = connector_id
|
||||
self._credentials = credentials
|
||||
self._resolved_credentials: Credentials | None = None
|
||||
self.service = None
|
||||
self._service_lock = asyncio.Lock()
|
||||
|
||||
|
|
@ -62,6 +76,7 @@ class GoogleDriveClient:
|
|||
credentials = await get_valid_credentials(
|
||||
self.session, self.connector_id
|
||||
)
|
||||
self._resolved_credentials = credentials
|
||||
self.service = build("drive", "v3", credentials=credentials)
|
||||
return self.service
|
||||
except Exception as e:
|
||||
|
|
@ -141,12 +156,19 @@ class GoogleDriveClient:
|
|||
return None, f"Error getting file metadata: {e!s}"
|
||||
|
||||
@staticmethod
|
||||
def _sync_download_file(service, file_id: str) -> tuple[bytes | None, str | None]:
|
||||
def _sync_download_file(
|
||||
service, file_id: str, credentials: Credentials,
|
||||
) -> tuple[bytes | None, str | None]:
|
||||
"""Blocking download — runs on a worker thread via ``to_thread``."""
|
||||
thread = threading.current_thread().name
|
||||
t0 = time.monotonic()
|
||||
logger.info(f"[download] START file={file_id} thread={thread}")
|
||||
try:
|
||||
from googleapiclient.http import MediaIoBaseDownload
|
||||
|
||||
http = _build_thread_http(credentials)
|
||||
request = service.files().get_media(fileId=file_id)
|
||||
request.http = http
|
||||
fh = io.BytesIO()
|
||||
downloader = MediaIoBaseDownload(fh, request)
|
||||
done = False
|
||||
|
|
@ -157,6 +179,8 @@ class GoogleDriveClient:
|
|||
return None, f"HTTP error downloading file: {e.resp.status}"
|
||||
except Exception as e:
|
||||
return None, f"Error downloading file: {e!s}"
|
||||
finally:
|
||||
logger.info(f"[download] END file={file_id} thread={thread} elapsed={time.monotonic() - t0:.2f}s")
|
||||
|
||||
async def download_file(self, file_id: str) -> tuple[bytes | None, str | None]:
|
||||
"""
|
||||
|
|
@ -169,17 +193,25 @@ class GoogleDriveClient:
|
|||
Tuple of (file content bytes, error message)
|
||||
"""
|
||||
service = await self.get_service()
|
||||
return await asyncio.to_thread(self._sync_download_file, service, file_id)
|
||||
return await asyncio.to_thread(
|
||||
self._sync_download_file, service, file_id, self._resolved_credentials,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _sync_download_file_to_disk(
|
||||
service, file_id: str, dest_path: str, chunksize: int,
|
||||
credentials: Credentials,
|
||||
) -> str | None:
|
||||
"""Blocking download-to-disk — runs on a worker thread via ``to_thread``."""
|
||||
thread = threading.current_thread().name
|
||||
t0 = time.monotonic()
|
||||
logger.info(f"[download-to-disk] START file={file_id} thread={thread}")
|
||||
try:
|
||||
from googleapiclient.http import MediaIoBaseDownload
|
||||
|
||||
http = _build_thread_http(credentials)
|
||||
request = service.files().get_media(fileId=file_id)
|
||||
request.http = http
|
||||
with open(dest_path, "wb") as fh:
|
||||
downloader = MediaIoBaseDownload(fh, request, chunksize=chunksize)
|
||||
done = False
|
||||
|
|
@ -190,6 +222,8 @@ class GoogleDriveClient:
|
|||
return f"HTTP error downloading file: {e.resp.status}"
|
||||
except Exception as e:
|
||||
return f"Error downloading file: {e!s}"
|
||||
finally:
|
||||
logger.info(f"[download-to-disk] END file={file_id} thread={thread} elapsed={time.monotonic() - t0:.2f}s")
|
||||
|
||||
async def download_file_to_disk(
|
||||
self, file_id: str, dest_path: str, chunksize: int = 5 * 1024 * 1024,
|
||||
|
|
@ -200,17 +234,24 @@ class GoogleDriveClient:
|
|||
"""
|
||||
service = await self.get_service()
|
||||
return await asyncio.to_thread(
|
||||
self._sync_download_file_to_disk, service, file_id, dest_path, chunksize,
|
||||
self._sync_download_file_to_disk,
|
||||
service, file_id, dest_path, chunksize, self._resolved_credentials,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _sync_export_google_file(
|
||||
service, file_id: str, mime_type: str,
|
||||
service, file_id: str, mime_type: str, credentials: Credentials,
|
||||
) -> tuple[bytes | None, str | None]:
|
||||
"""Blocking export — runs on a worker thread via ``to_thread``."""
|
||||
thread = threading.current_thread().name
|
||||
t0 = time.monotonic()
|
||||
logger.info(f"[export] START file={file_id} thread={thread}")
|
||||
try:
|
||||
http = _build_thread_http(credentials)
|
||||
content = (
|
||||
service.files().export(fileId=file_id, mimeType=mime_type).execute()
|
||||
service.files()
|
||||
.export(fileId=file_id, mimeType=mime_type)
|
||||
.execute(http=http)
|
||||
)
|
||||
if not isinstance(content, bytes):
|
||||
content = content.encode("utf-8")
|
||||
|
|
@ -219,6 +260,8 @@ class GoogleDriveClient:
|
|||
return None, f"HTTP error exporting file: {e.resp.status}"
|
||||
except Exception as e:
|
||||
return None, f"Error exporting file: {e!s}"
|
||||
finally:
|
||||
logger.info(f"[export] END file={file_id} thread={thread} elapsed={time.monotonic() - t0:.2f}s")
|
||||
|
||||
async def export_google_file(
|
||||
self, file_id: str, mime_type: str
|
||||
|
|
@ -236,6 +279,7 @@ class GoogleDriveClient:
|
|||
service = await self.get_service()
|
||||
return await asyncio.to_thread(
|
||||
self._sync_export_google_file, service, file_id, mime_type,
|
||||
self._resolved_credentials,
|
||||
)
|
||||
|
||||
async def create_file(
|
||||
|
|
|
|||
|
|
@ -4,6 +4,8 @@ import asyncio
|
|||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
|
|
@ -119,7 +121,10 @@ async def _parse_file_to_markdown(file_path: str, filename: str) -> str:
|
|||
)
|
||||
if stt_service_type == "local":
|
||||
from app.services.stt_service import stt_service
|
||||
t0 = time.monotonic()
|
||||
logger.info(f"[local-stt] START file={filename} thread={threading.current_thread().name}")
|
||||
result = await asyncio.to_thread(stt_service.transcribe_file, file_path)
|
||||
logger.info(f"[local-stt] END file={filename} elapsed={time.monotonic() - t0:.2f}s")
|
||||
text = result.get("text", "")
|
||||
else:
|
||||
with open(file_path, "rb") as audio_file:
|
||||
|
|
@ -171,7 +176,10 @@ async def _parse_file_to_markdown(file_path: str, filename: str) -> str:
|
|||
from docling.document_converter import DocumentConverter
|
||||
|
||||
converter = DocumentConverter()
|
||||
t0 = time.monotonic()
|
||||
logger.info(f"[docling] START file={filename} thread={threading.current_thread().name}")
|
||||
result = await asyncio.to_thread(converter.convert, file_path)
|
||||
logger.info(f"[docling] END file={filename} elapsed={time.monotonic() - t0:.2f}s")
|
||||
return result.document.export_to_markdown()
|
||||
|
||||
raise RuntimeError(f"Unknown ETL_SERVICE: {app_config.ETL_SERVICE}")
|
||||
|
|
|
|||
|
|
@ -199,7 +199,7 @@ async def _download_files_parallel(
|
|||
search_space_id: int,
|
||||
user_id: str,
|
||||
enable_summary: bool,
|
||||
max_concurrency: int = 5,
|
||||
max_concurrency: int = 3,
|
||||
on_heartbeat: HeartbeatCallbackType | None = None,
|
||||
) -> tuple[list[ConnectorDocument], int]:
|
||||
"""Download and ETL files in parallel, returning ConnectorDocuments.
|
||||
|
|
@ -219,6 +219,9 @@ async def _download_files_parallel(
|
|||
drive_client, file
|
||||
)
|
||||
if error or not markdown:
|
||||
file_name = file.get("name", "Unknown")
|
||||
reason = error or "empty content"
|
||||
logger.warning(f"Download/ETL failed for {file_name}: {reason}")
|
||||
return None
|
||||
doc = _build_connector_doc(
|
||||
file,
|
||||
|
|
|
|||
|
|
@ -605,12 +605,13 @@ async def test_client_download_file_runs_in_thread_parallel():
|
|||
BLOCK_SECONDS = 0.2
|
||||
NUM_CALLS = 3
|
||||
|
||||
def _blocking_download(service, file_id):
|
||||
def _blocking_download(service, file_id, credentials):
|
||||
time.sleep(BLOCK_SECONDS)
|
||||
return b"fake-content", None
|
||||
|
||||
client = GoogleDriveClient.__new__(GoogleDriveClient)
|
||||
client.service = MagicMock()
|
||||
client._resolved_credentials = MagicMock()
|
||||
client._service_lock = asyncio.Lock()
|
||||
|
||||
with patch.object(
|
||||
|
|
@ -640,12 +641,13 @@ async def test_client_export_google_file_runs_in_thread_parallel():
|
|||
BLOCK_SECONDS = 0.2
|
||||
NUM_CALLS = 3
|
||||
|
||||
def _blocking_export(service, file_id, mime_type):
|
||||
def _blocking_export(service, file_id, mime_type, credentials):
|
||||
time.sleep(BLOCK_SECONDS)
|
||||
return b"exported", None
|
||||
|
||||
client = GoogleDriveClient.__new__(GoogleDriveClient)
|
||||
client.service = MagicMock()
|
||||
client._resolved_credentials = MagicMock()
|
||||
client._service_lock = asyncio.Lock()
|
||||
|
||||
with patch.object(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue