From adea97620307b5e249dd9c16f1fbd33000c333eb Mon Sep 17 00:00:00 2001
From: Het Patel <102606191+CuriousHet@users.noreply.github.com>
Date: Sat, 18 Apr 2026 16:35:19 +0530
Subject: [PATCH] feat: implement retry logic and exponential backoff for S3 operations (#829)

* feat: implement retry logic and exponential backoff for S3 operations

* test: fix librarian mocks after BlobStore async conversion
---
 tests/unit/test_librarian/test_blob_store.py | 119 ++++++++++++++++++
 .../test_librarian/test_chunked_upload.py    |   4 +
 .../trustgraph/librarian/blob_store.py       |  69 +++++++---
 .../trustgraph/librarian/librarian.py        |   8 +-
 4 files changed, 179 insertions(+), 21 deletions(-)
 create mode 100644 tests/unit/test_librarian/test_blob_store.py

diff --git a/tests/unit/test_librarian/test_blob_store.py b/tests/unit/test_librarian/test_blob_store.py
new file mode 100644
index 00000000..b2b1541b
--- /dev/null
+++ b/tests/unit/test_librarian/test_blob_store.py
@@ -0,0 +1,119 @@
+import asyncio
+import io
+import pytest
+from unittest.mock import MagicMock, patch, AsyncMock
+from uuid import uuid4
+from minio.error import S3Error
+from trustgraph.librarian.blob_store import BlobStore
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
+def _make_blob_store():
+    """Create a BlobStore with a mocked Minio client."""
+    mock_minio = MagicMock()
+    with patch('trustgraph.librarian.blob_store.Minio', return_value=mock_minio):
+        # Prevent ensure_bucket from making network calls during init
+        with patch('trustgraph.librarian.blob_store.BlobStore.ensure_bucket'):
+            store = BlobStore(
+                endpoint="localhost:9000",
+                access_key="access",
+                secret_key="secret",
+                bucket_name="test-bucket"
+            )
+    return store, mock_minio
+
+# ---------------------------------------------------------------------------
+# Tests
+# ---------------------------------------------------------------------------
+
+@pytest.mark.asyncio
+async def test_add_success_no_retry():
+    store, mock_minio = _make_blob_store()
+    object_id = uuid4()
+
+    await store.add(object_id, b"data", "text/plain")
+
+    mock_minio.put_object.assert_called_once()
+
+@pytest.mark.asyncio
+async def test_retry_recovery_on_transient_failure():
+    store, mock_minio = _make_blob_store()
+    store.base_delay = 0  # Disable delay for fast tests
+
+    # Fail twice, then succeed on the third attempt
+    mock_minio.put_object.side_effect = [
+        Exception("Error 1"),
+        Exception("Error 2"),
+        MagicMock()
+    ]
+
+    await store.add(uuid4(), b"data", "text/plain")
+
+    assert mock_minio.put_object.call_count == 3
+
+@pytest.mark.asyncio
+async def test_retry_exhaustion_after_8_attempts():
+    store, mock_minio = _make_blob_store()
+    store.base_delay = 0
+
+    # Permanent failure
+    mock_minio.put_object.side_effect = Exception("Permanent failure")
+
+    with pytest.raises(Exception, match="Permanent failure"):
+        await store.add(uuid4(), b"data", "text/plain")
+
+    # max_retries is 8, so exactly 8 attempts are made
+    assert mock_minio.put_object.call_count == 8
+
+@pytest.mark.asyncio
+async def test_s3_error_triggers_retry():
+    store, mock_minio = _make_blob_store()
+    store.base_delay = 0
+
+    # Mock S3Error
+    s3_err = S3Error("code", "msg", "res", "req", "host", None)
+    mock_minio.get_object.side_effect = [s3_err, MagicMock()]
+
+    await store.get(uuid4())
+
+    assert mock_minio.get_object.call_count == 2
+
+@pytest.mark.asyncio
+async def test_exponential_backoff_delays():
+    store, mock_minio = _make_blob_store()
+    # Use the real base_delay so the backoff arithmetic can be verified
+    store.base_delay = 0.25
+
+    # get_size() wraps the client's stat_object call
+    mock_minio.stat_object = MagicMock(side_effect=Exception("Wait"))
+
+    with patch('asyncio.sleep', new_callable=AsyncMock) as mock_sleep:
+        with pytest.raises(Exception):
+            await store.get_size(uuid4())
+
+        # 8 attempts produce 7 sleeps between them
+        assert mock_sleep.call_count == 7
+
+        # Check actual sleep durations: 0.25, 0.5, 1.0, 2.0, 4.0, 8.0, 16.0
+        sleep_args = [call[0][0] for call in mock_sleep.call_args_list]
+        assert sleep_args == [0.25, 0.5, 1.0, 2.0, 4.0, 8.0, 16.0]
+
+@pytest.mark.asyncio
+async def test_runs_in_executor():
+    """Verify that synchronous Minio calls are offloaded to an executor."""
+    store, mock_minio = _make_blob_store()
+
+    # Mock response object with a .read() method
+    mock_response = MagicMock()
+    mock_response.read.return_value = b"result"
+
+    with patch('asyncio.get_event_loop') as mock_loop:
+        mock_loop_instance = MagicMock()
+        mock_loop.return_value = mock_loop_instance
+        mock_loop_instance.run_in_executor = AsyncMock(return_value=mock_response)
+
+        await store.get(uuid4())
+
+        mock_loop_instance.run_in_executor.assert_called_once()
diff --git a/tests/unit/test_librarian/test_chunked_upload.py b/tests/unit/test_librarian/test_chunked_upload.py
index d0b73d48..eef83e1e 100644
--- a/tests/unit/test_librarian/test_chunked_upload.py
+++ b/tests/unit/test_librarian/test_chunked_upload.py
@@ -22,6 +22,10 @@ def _make_librarian(min_chunk_size=1):
     """Create a Librarian with mocked blob_store and table_store."""
     lib = Librarian.__new__(Librarian)
     lib.blob_store = MagicMock()
+    lib.blob_store.create_multipart_upload = AsyncMock()
+    lib.blob_store.upload_part = AsyncMock()
+    lib.blob_store.complete_multipart_upload = AsyncMock()
+    lib.blob_store.abort_multipart_upload = AsyncMock()
     lib.table_store = AsyncMock()
     lib.load_document = AsyncMock()
     lib.min_chunk_size = min_chunk_size
diff --git a/trustgraph-flow/trustgraph/librarian/blob_store.py b/trustgraph-flow/trustgraph/librarian/blob_store.py
index d75a7af9..55ed6c61 100644
--- a/trustgraph-flow/trustgraph/librarian/blob_store.py
+++ b/trustgraph-flow/trustgraph/librarian/blob_store.py
@@ -4,11 +4,12 @@
 from .. exceptions import RequestError
 from minio import Minio
 from minio.datatypes import Part
-import time
+from minio.error import S3Error
 import io
 import logging
 from typing import Iterator, List, Tuple
 from uuid import UUID
+import asyncio
 
 # Module logger
 logger = logging.getLogger(__name__)
@@ -35,8 +36,36 @@ class BlobStore:
         protocol = "https" if use_ssl else "http"
         logger.info(f"Connected to S3-compatible storage at {protocol}://{endpoint}")
 
+        # Retry and exponential backoff configuration
+        self.max_retries = 8
+        self.base_delay = 0.25
+
         self.ensure_bucket()
 
+    async def _with_retry(self, operation, *args, **kwargs):
+        """Execute a minio operation with exponential backoff retry."""
+        last_exception = None
+        for attempt in range(self.max_retries):
+            try:
+                # Run the synchronous minio call in the default executor to avoid blocking the event loop
+                return await asyncio.get_event_loop().run_in_executor(
+                    None, lambda: operation(*args, **kwargs)
+                )
+            except (S3Error, Exception) as e:
+                last_exception = e
+                if attempt < self.max_retries - 1:
+                    delay = self.base_delay * (2 ** attempt)
+                    logger.warning(
+                        f"S3 operation failed: {e}. "
+                        f"Retrying in {delay}s... (Attempt {attempt + 1}/{self.max_retries})"
+                    )
+                    await asyncio.sleep(delay)
+                else:
+                    logger.error(f"S3 operation failed after {self.max_retries} attempts: {e}")
+
+        if last_exception:
+            raise last_exception
+
     def ensure_bucket(self):
 
         # Make the bucket if it doesn't exist.
@@ -49,8 +78,8 @@ class BlobStore:
 
     async def add(self, object_id, blob, kind):
 
-        # FIXME: Loop retry
-        self.client.put_object(
+        await self._with_retry(
+            self.client.put_object,
             bucket_name = self.bucket_name,
             object_name = "doc/" + str(object_id),
             length = len(blob),
@@ -62,8 +91,8 @@ class BlobStore:
 
     async def remove(self, object_id):
 
-        # FIXME: Loop retry
-        self.client.remove_object(
+        await self._with_retry(
+            self.client.remove_object,
             bucket_name = self.bucket_name,
             object_name = "doc/" + str(object_id),
         )
@@ -73,8 +102,8 @@ class BlobStore:
 
     async def get(self, object_id):
 
-        # FIXME: Loop retry
-        resp = self.client.get_object(
+        resp = await self._with_retry(
+            self.client.get_object,
             bucket_name = self.bucket_name,
             object_name = "doc/" + str(object_id),
         )
@@ -83,7 +112,8 @@ class BlobStore:
     async def get_range(self, object_id, offset: int, length: int) -> bytes:
         """Fetch a specific byte range from an object."""
-        resp = self.client.get_object(
+        resp = await self._with_retry(
+            self.client.get_object,
             bucket_name=self.bucket_name,
             object_name="doc/" + str(object_id),
             offset=offset,
             length=length,
@@ -97,7 +127,8 @@ class BlobStore:
 
     async def get_size(self, object_id) -> int:
         """Get the size of an object without downloading it."""
-        stat = self.client.stat_object(
+        stat = await self._with_retry(
+            self.client.stat_object,
             bucket_name=self.bucket_name,
             object_name="doc/" + str(object_id),
         )
@@ -134,7 +165,7 @@ class BlobStore:
 
         logger.debug("Stream complete")
 
-    def create_multipart_upload(self, object_id: UUID, kind: str) -> str:
+    async def create_multipart_upload(self, object_id: UUID, kind: str) -> str:
         """
         Initialize a multipart upload.
 
@@ -148,7 +179,8 @@ class BlobStore:
         object_name = "doc/" + str(object_id)
 
         # Use minio's internal method to create multipart upload
-        upload_id = self.client._create_multipart_upload(
+        upload_id = await self._with_retry(
+            self.client._create_multipart_upload,
             bucket_name=self.bucket_name,
             object_name=object_name,
             headers={"Content-Type": kind},
@@ -157,7 +189,7 @@ class BlobStore:
         logger.info(f"Created multipart upload {upload_id} for {object_id}")
         return upload_id
 
-    def upload_part(
+    async def upload_part(
         self,
         object_id: UUID,
         upload_id: str,
@@ -178,7 +210,8 @@ class BlobStore:
         """
         object_name = "doc/" + str(object_id)
 
-        etag = self.client._upload_part(
+        etag = await self._with_retry(
+            self.client._upload_part,
             bucket_name=self.bucket_name,
             object_name=object_name,
             data=data,
@@ -190,7 +223,7 @@ class BlobStore:
         logger.debug(f"Uploaded part {part_number} for {object_id}, etag={etag}")
         return etag
 
-    def complete_multipart_upload(
+    async def complete_multipart_upload(
         self,
         object_id: UUID,
         upload_id: str,
@@ -214,7 +247,8 @@ class BlobStore:
             for part_number, etag in parts
         ]
 
-        self.client._complete_multipart_upload(
+        await self._with_retry(
+            self.client._complete_multipart_upload,
             bucket_name=self.bucket_name,
             object_name=object_name,
             upload_id=upload_id,
@@ -223,7 +257,7 @@ class BlobStore:
 
         logger.info(f"Completed multipart upload for {object_id}")
 
-    def abort_multipart_upload(self, object_id: UUID, upload_id: str) -> None:
+    async def abort_multipart_upload(self, object_id: UUID, upload_id: str) -> None:
         """
         Abort a multipart upload, cleaning up any uploaded parts.
 
@@ -233,7 +267,8 @@ class BlobStore:
         """
         object_name = "doc/" + str(object_id)
 
-        self.client._abort_multipart_upload(
+        await self._with_retry(
+            self.client._abort_multipart_upload,
             bucket_name=self.bucket_name,
             object_name=object_name,
             upload_id=upload_id,
diff --git a/trustgraph-flow/trustgraph/librarian/librarian.py b/trustgraph-flow/trustgraph/librarian/librarian.py
index 2a8ad3a6..77232650 100644
--- a/trustgraph-flow/trustgraph/librarian/librarian.py
+++ b/trustgraph-flow/trustgraph/librarian/librarian.py
@@ -301,7 +301,7 @@ class Librarian:
         object_id = uuid.uuid4()
 
         # Create S3 multipart upload
-        s3_upload_id = self.blob_store.create_multipart_upload(
+        s3_upload_id = await self.blob_store.create_multipart_upload(
             object_id, request.document_metadata.kind
         )
 
@@ -367,7 +367,7 @@ class Librarian:
 
         # Upload to S3 (part numbers are 1-indexed in S3)
         part_number = request.chunk_index + 1
-        etag = self.blob_store.upload_part(
+        etag = await self.blob_store.upload_part(
             object_id=session["object_id"],
             upload_id=session["s3_upload_id"],
             part_number=part_number,
@@ -440,7 +440,7 @@ class Librarian:
         ]
 
         # Complete S3 multipart upload
-        self.blob_store.complete_multipart_upload(
+        await self.blob_store.complete_multipart_upload(
             object_id=session["object_id"],
             upload_id=session["s3_upload_id"],
             parts=parts,
@@ -492,7 +492,7 @@ class Librarian:
             raise RequestError("Not authorized to abort this upload")
 
         # Abort S3 multipart upload
-        self.blob_store.abort_multipart_upload(
+        await self.blob_store.abort_multipart_upload(
             object_id=session["object_id"],
             upload_id=session["s3_upload_id"],
         )
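
Usage sketch (patch notes, not part of the diff): after this change every
BlobStore method is a coroutine, and transient S3 failures are absorbed by up
to 8 attempts with delays of 0.25, 0.5, 1.0, 2.0, 4.0, 8.0 and 16.0 seconds
(roughly 31.75s of backoff in the worst case) before the final exception
propagates. This is a minimal sketch from a caller's perspective; the
endpoint, credentials and bucket name are placeholder values, assuming a
reachable MinIO or other S3-compatible server.

    import asyncio
    from uuid import uuid4

    from trustgraph.librarian.blob_store import BlobStore

    async def main():
        # Placeholder connection details; swap in real ones
        store = BlobStore(
            endpoint="localhost:9000",
            access_key="access",
            secret_key="secret",
            bucket_name="example-bucket",
        )

        object_id = uuid4()

        # put_object runs in the default executor; transient errors are
        # retried with exponential backoff before this call would raise
        await store.add(object_id, b"hello", "text/plain")

        # get_size wraps stat_object with the same retry behaviour
        print(await store.get_size(object_id))

    asyncio.run(main())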