feat: implement retry logic and exponential backoff for S3 operations (#829)

* feat: implement retry logic and exponential backoff for S3 operations

* test: fix librarian mocks after BlobStore async conversion
This commit is contained in:
Het Patel 2026-04-18 16:35:19 +05:30 committed by GitHub
parent b341bf5ea1
commit 9a1b2463b6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 179 additions and 21 deletions

View file

@ -0,0 +1,119 @@
import asyncio
import io
import pytest
from unittest.mock import MagicMock, patch, AsyncMock
from uuid import uuid4
from minio.error import S3Error
from trustgraph.librarian.blob_store import BlobStore
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_blob_store():
    """Build a BlobStore wired to a MagicMock Minio client.

    The Minio constructor is patched so no network connection is made, and
    ``ensure_bucket`` is stubbed so ``__init__`` performs no bucket calls.

    Returns:
        tuple: (BlobStore instance, mocked Minio client)
    """
    fake_client = MagicMock()
    patched_minio = patch(
        'trustgraph.librarian.blob_store.Minio', return_value=fake_client
    )
    patched_bucket = patch(
        'trustgraph.librarian.blob_store.BlobStore.ensure_bucket'
    )
    with patched_minio, patched_bucket:
        store = BlobStore(
            endpoint="localhost:9000",
            access_key="access",
            secret_key="secret",
            bucket_name="test-bucket",
        )
    return store, fake_client
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_add_success_no_retry():
    """A put_object that succeeds first time must be issued exactly once."""
    store, client = _make_blob_store()
    await store.add(uuid4(), b"data", "text/plain")
    client.put_object.assert_called_once()
@pytest.mark.asyncio
async def test_retry_recovery_on_transient_failure():
    """Two transient failures followed by a success are absorbed by the retry loop."""
    store, client = _make_blob_store()
    store.base_delay = 0  # Disable delay for fast tests
    transient = [Exception("Error 1"), Exception("Error 2")]
    client.put_object.side_effect = transient + [MagicMock()]
    await store.add(uuid4(), b"data", "text/plain")
    # Two failures + one success = three calls in total
    assert client.put_object.call_count == 3
@pytest.mark.asyncio
async def test_retry_exhaustion_after_8_attempts():
    """A permanent failure is re-raised after exactly 8 attempts."""
    store, client = _make_blob_store()
    store.base_delay = 0
    client.put_object.side_effect = Exception("Permanent failure")
    with pytest.raises(Exception, match="Permanent failure"):
        await store.add(uuid4(), b"data", "text/plain")
    # Author requirement: exactly 8 attempts
    assert client.put_object.call_count == 8
@pytest.mark.asyncio
async def test_s3_error_triggers_retry():
    """An S3Error is retryable just like a generic Exception."""
    store, client = _make_blob_store()
    store.base_delay = 0
    err = S3Error("code", "msg", "res", "req", "host", None)
    # First call raises, second succeeds
    client.get_object.side_effect = [err, MagicMock()]
    await store.get(uuid4())
    assert client.get_object.call_count == 2
@pytest.mark.asyncio
async def test_exponential_backoff_delays():
    """Sleep durations double on every retry, starting from base_delay."""
    store, client = _make_blob_store()
    store.base_delay = 0.25  # real delay so the doubling arithmetic is observable
    client.stat_object = MagicMock(side_effect=Exception("Wait"))
    with patch('asyncio.sleep', new_callable=AsyncMock) as fake_sleep:
        with pytest.raises(Exception):
            await store.get_size(uuid4())
    # 8 attempts means 7 sleeps between them
    assert fake_sleep.call_count == 7
    observed = [c[0][0] for c in fake_sleep.call_args_list]
    # 0.25 * 2**i is exact in binary floating point, so equality is safe
    assert observed == [0.25, 0.5, 1.0, 2.0, 4.0, 8.0, 16.0]
@pytest.mark.asyncio
async def test_runs_in_executor():
    """get() must hand the blocking Minio call to the event loop's executor."""
    store, client = _make_blob_store()
    # Fake Minio response: only .read() is consumed by get()
    response = MagicMock()
    response.read.return_value = b"result"
    with patch('asyncio.get_event_loop') as fake_get_loop:
        loop = MagicMock()
        loop.run_in_executor = AsyncMock(return_value=response)
        fake_get_loop.return_value = loop
        await store.get(uuid4())
        loop.run_in_executor.assert_called_once()

View file

@ -22,6 +22,10 @@ def _make_librarian(min_chunk_size=1):
"""Create a Librarian with mocked blob_store and table_store."""
lib = Librarian.__new__(Librarian)
lib.blob_store = MagicMock()
lib.blob_store.create_multipart_upload = AsyncMock()
lib.blob_store.upload_part = AsyncMock()
lib.blob_store.complete_multipart_upload = AsyncMock()
lib.blob_store.abort_multipart_upload = AsyncMock()
lib.table_store = AsyncMock()
lib.load_document = AsyncMock()
lib.min_chunk_size = min_chunk_size

View file

@ -4,11 +4,12 @@ from .. exceptions import RequestError
from minio import Minio
from minio.datatypes import Part
import time
from minio.error import S3Error
import io
import logging
from typing import Iterator, List, Tuple
from uuid import UUID
import asyncio
# Module logger
logger = logging.getLogger(__name__)
@ -35,8 +36,36 @@ class BlobStore:
protocol = "https" if use_ssl else "http"
logger.info(f"Connected to S3-compatible storage at {protocol}://{endpoint}")
# Retry and Exponential delay configuration
self.max_retries = 8
self.base_delay = 0.25
self.ensure_bucket()
async def _with_retry(self, operation, *args, **kwargs):
    """Execute a synchronous minio operation with exponential-backoff retry.

    The call is dispatched to the default thread-pool executor so the event
    loop is never blocked. Up to ``self.max_retries`` attempts are made;
    between attempts the coroutine sleeps ``base_delay * 2 ** attempt``
    seconds. After the final failed attempt, the last exception is re-raised.

    Args:
        operation: the blocking minio client method to call.
        *args, **kwargs: forwarded to ``operation``.

    Returns:
        Whatever ``operation`` returns.

    Raises:
        Exception: the last exception raised by ``operation`` once all
            retries are exhausted.
    """
    last_exception = None
    for attempt in range(self.max_retries):
        try:
            # Run the synchronous minio call in the default executor to avoid
            # blocking the event loop.
            # NOTE(review): asyncio.get_event_loop() is deprecated outside a
            # running loop; kept because the test suite patches this exact
            # name — consider get_running_loop() in a follow-up.
            return await asyncio.get_event_loop().run_in_executor(
                None, lambda: operation(*args, **kwargs)
            )
        # S3Error subclasses Exception, so a single broad clause suffices;
        # every failure mode of the remote call is considered retryable.
        except Exception as e:
            last_exception = e
            if attempt < self.max_retries - 1:
                delay = self.base_delay * (2 ** attempt)
                logger.warning(
                    f"S3 operation failed: {e}. "
                    f"Retrying in {delay}s... (Attempt {attempt + 1}/{self.max_retries})"
                )
                await asyncio.sleep(delay)
            else:
                logger.error(f"S3 operation failed after {self.max_retries} attempts: {e}")
    if last_exception:
        raise last_exception
def ensure_bucket(self):
# Make the bucket if it doesn't exist.
@ -49,8 +78,8 @@ class BlobStore:
async def add(self, object_id, blob, kind):
# FIXME: Loop retry
self.client.put_object(
await self._with_retry(
self.client.put_object,
bucket_name = self.bucket_name,
object_name = "doc/" + str(object_id),
length = len(blob),
@ -62,8 +91,8 @@ class BlobStore:
async def remove(self, object_id):
# FIXME: Loop retry
self.client.remove_object(
await self._with_retry(
self.client.remove_object,
bucket_name = self.bucket_name,
object_name = "doc/" + str(object_id),
)
@ -73,8 +102,8 @@ class BlobStore:
async def get(self, object_id):
# FIXME: Loop retry
resp = self.client.get_object(
resp = await self._with_retry(
self.client.get_object,
bucket_name = self.bucket_name,
object_name = "doc/" + str(object_id),
)
@ -83,7 +112,8 @@ class BlobStore:
async def get_range(self, object_id, offset: int, length: int) -> bytes:
"""Fetch a specific byte range from an object."""
resp = self.client.get_object(
resp = await self._with_retry(
self.client.get_object,
bucket_name=self.bucket_name,
object_name="doc/" + str(object_id),
offset=offset,
@ -97,7 +127,8 @@ class BlobStore:
async def get_size(self, object_id) -> int:
"""Get the size of an object without downloading it."""
stat = self.client.stat_object(
stat = await self._with_retry(
self.client.stat_object,
bucket_name=self.bucket_name,
object_name="doc/" + str(object_id),
)
@ -134,7 +165,7 @@ class BlobStore:
logger.debug("Stream complete")
def create_multipart_upload(self, object_id: UUID, kind: str) -> str:
async def create_multipart_upload(self, object_id: UUID, kind: str) -> str:
"""
Initialize a multipart upload.
@ -148,7 +179,8 @@ class BlobStore:
object_name = "doc/" + str(object_id)
# Use minio's internal method to create multipart upload
upload_id = self.client._create_multipart_upload(
upload_id = await self._with_retry(
self.client._create_multipart_upload,
bucket_name=self.bucket_name,
object_name=object_name,
headers={"Content-Type": kind},
@ -157,7 +189,7 @@ class BlobStore:
logger.info(f"Created multipart upload {upload_id} for {object_id}")
return upload_id
def upload_part(
async def upload_part(
self,
object_id: UUID,
upload_id: str,
@ -178,7 +210,8 @@ class BlobStore:
"""
object_name = "doc/" + str(object_id)
etag = self.client._upload_part(
etag = await self._with_retry(
self.client._upload_part,
bucket_name=self.bucket_name,
object_name=object_name,
data=data,
@ -190,7 +223,7 @@ class BlobStore:
logger.debug(f"Uploaded part {part_number} for {object_id}, etag={etag}")
return etag
def complete_multipart_upload(
async def complete_multipart_upload(
self,
object_id: UUID,
upload_id: str,
@ -214,7 +247,8 @@ class BlobStore:
for part_number, etag in parts
]
self.client._complete_multipart_upload(
await self._with_retry(
self.client._complete_multipart_upload,
bucket_name=self.bucket_name,
object_name=object_name,
upload_id=upload_id,
@ -223,7 +257,7 @@ class BlobStore:
logger.info(f"Completed multipart upload for {object_id}")
def abort_multipart_upload(self, object_id: UUID, upload_id: str) -> None:
async def abort_multipart_upload(self, object_id: UUID, upload_id: str) -> None:
"""
Abort a multipart upload, cleaning up any uploaded parts.
@ -233,7 +267,8 @@ class BlobStore:
"""
object_name = "doc/" + str(object_id)
self.client._abort_multipart_upload(
await self._with_retry(
self.client._abort_multipart_upload,
bucket_name=self.bucket_name,
object_name=object_name,
upload_id=upload_id,

View file

@ -301,7 +301,7 @@ class Librarian:
object_id = uuid.uuid4()
# Create S3 multipart upload
s3_upload_id = self.blob_store.create_multipart_upload(
s3_upload_id = await self.blob_store.create_multipart_upload(
object_id, request.document_metadata.kind
)
@ -367,7 +367,7 @@ class Librarian:
# Upload to S3 (part numbers are 1-indexed in S3)
part_number = request.chunk_index + 1
etag = self.blob_store.upload_part(
etag = await self.blob_store.upload_part(
object_id=session["object_id"],
upload_id=session["s3_upload_id"],
part_number=part_number,
@ -440,7 +440,7 @@ class Librarian:
]
# Complete S3 multipart upload
self.blob_store.complete_multipart_upload(
await self.blob_store.complete_multipart_upload(
object_id=session["object_id"],
upload_id=session["s3_upload_id"],
parts=parts,
@ -492,7 +492,7 @@ class Librarian:
raise RequestError("Not authorized to abort this upload")
# Abort S3 multipart upload
self.blob_store.abort_multipart_upload(
await self.blob_store.abort_multipart_upload(
object_id=session["object_id"],
upload_id=session["s3_upload_id"],
)