mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 00:16:23 +02:00
feat: implement retry logic and exponential backoff for S3 operations (#829)
* feat: implement retry logic and exponential backoff for S3 operations * test: fix librarian mocks after BlobStore async conversion
This commit is contained in:
parent
b341bf5ea1
commit
9a1b2463b6
4 changed files with 179 additions and 21 deletions
119
tests/unit/test_librarian/test_blob_store.py
Normal file
119
tests/unit/test_librarian/test_blob_store.py
Normal file
|
|
@ -0,0 +1,119 @@
|
|||
import asyncio
|
||||
import io
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch, AsyncMock
|
||||
from uuid import uuid4
|
||||
from minio.error import S3Error
|
||||
from trustgraph.librarian.blob_store import BlobStore
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_blob_store():
    """Build a BlobStore wired to a mocked Minio client.

    Returns a ``(store, mock_client)`` pair.  Both the ``Minio``
    constructor and ``BlobStore.ensure_bucket`` are patched so that
    constructing the store makes no network calls.
    """
    fake_client = MagicMock()
    minio_patch = patch(
        'trustgraph.librarian.blob_store.Minio', return_value=fake_client
    )
    # ensure_bucket would talk to the server during __init__; stub it out.
    bucket_patch = patch('trustgraph.librarian.blob_store.BlobStore.ensure_bucket')
    with minio_patch, bucket_patch:
        blob_store = BlobStore(
            endpoint="localhost:9000",
            access_key="access",
            secret_key="secret",
            bucket_name="test-bucket"
        )
        return blob_store, fake_client
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
async def test_add_success_no_retry():
    """A successful put_object is issued exactly once — no retries fire."""
    blob_store, minio_client = _make_blob_store()
    doc_id = uuid4()

    await blob_store.add(doc_id, b"data", "text/plain")

    minio_client.put_object.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
async def test_retry_recovery_on_transient_failure():
    """Two transient failures followed by a success are absorbed by retries."""
    blob_store, minio_client = _make_blob_store()
    blob_store.base_delay = 0  # no real sleeping in unit tests

    # First two calls raise, third returns a normal mock response.
    transient = [Exception("Error 1"), Exception("Error 2")]
    minio_client.put_object.side_effect = transient + [MagicMock()]

    await blob_store.add(uuid4(), b"data", "text/plain")

    assert minio_client.put_object.call_count == 3
|
||||
|
||||
@pytest.mark.asyncio
async def test_retry_exhaustion_after_8_attempts():
    """A permanent failure is re-raised after the full retry budget is spent."""
    blob_store, minio_client = _make_blob_store()
    blob_store.base_delay = 0

    minio_client.put_object.side_effect = Exception("Permanent failure")

    with pytest.raises(Exception, match="Permanent failure"):
        await blob_store.add(uuid4(), b"data", "text/plain")

    # Author requirement: exactly 8 attempts
    assert minio_client.put_object.call_count == 8
|
||||
|
||||
@pytest.mark.asyncio
async def test_s3_error_triggers_retry():
    """An S3Error (not just a generic Exception) must be treated as retryable."""
    blob_store, minio_client = _make_blob_store()
    blob_store.base_delay = 0

    # Minio's S3Error takes (code, message, resource, request_id, host_id, response).
    storage_error = S3Error("code", "msg", "res", "req", "host", None)
    minio_client.get_object.side_effect = [storage_error, MagicMock()]

    await blob_store.get(uuid4())

    assert minio_client.get_object.call_count == 2
|
||||
|
||||
@pytest.mark.asyncio
async def test_exponential_backoff_delays():
    """Sleep durations must double every attempt: base_delay * 2**attempt."""
    blob_store, minio_client = _make_blob_store()
    # Use a real base delay so the doubling arithmetic is observable.
    blob_store.base_delay = 0.25

    # get_size() is backed by stat_object; make it fail permanently.
    minio_client.stat_object = MagicMock(side_effect=Exception("Wait"))

    with patch('asyncio.sleep', new_callable=AsyncMock) as fake_sleep:
        with pytest.raises(Exception):
            await blob_store.get_size(uuid4())

    # 8 attempts means 7 sleeps between them.
    assert fake_sleep.call_count == 7

    # Expected geometric schedule: 0.25, 0.5, 1.0, 2.0, 4.0, 8.0, 16.0
    observed_delays = [invocation[0][0] for invocation in fake_sleep.call_args_list]
    assert observed_delays == [0.25, 0.5, 1.0, 2.0, 4.0, 8.0, 16.0]
|
||||
|
||||
@pytest.mark.asyncio
async def test_runs_in_executor():
    """Verify that synchronous Minio calls are offloaded to an executor."""
    blob_store, minio_client = _make_blob_store()

    # Fake object-store response exposing the .read() the store consumes.
    fake_response = MagicMock()
    fake_response.read.return_value = b"result"

    with patch('asyncio.get_event_loop') as fake_get_loop:
        fake_loop = MagicMock()
        fake_get_loop.return_value = fake_loop
        fake_loop.run_in_executor = AsyncMock(return_value=fake_response)

        await blob_store.get(uuid4())

        fake_loop.run_in_executor.assert_called_once()
|
||||
|
|
@ -22,6 +22,10 @@ def _make_librarian(min_chunk_size=1):
|
|||
"""Create a Librarian with mocked blob_store and table_store."""
|
||||
lib = Librarian.__new__(Librarian)
|
||||
lib.blob_store = MagicMock()
|
||||
lib.blob_store.create_multipart_upload = AsyncMock()
|
||||
lib.blob_store.upload_part = AsyncMock()
|
||||
lib.blob_store.complete_multipart_upload = AsyncMock()
|
||||
lib.blob_store.abort_multipart_upload = AsyncMock()
|
||||
lib.table_store = AsyncMock()
|
||||
lib.load_document = AsyncMock()
|
||||
lib.min_chunk_size = min_chunk_size
|
||||
|
|
|
|||
|
|
@ -4,11 +4,12 @@ from .. exceptions import RequestError
|
|||
|
||||
from minio import Minio
|
||||
from minio.datatypes import Part
|
||||
import time
|
||||
from minio.error import S3Error
|
||||
import io
|
||||
import logging
|
||||
from typing import Iterator, List, Tuple
|
||||
from uuid import UUID
|
||||
import asyncio
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -35,8 +36,36 @@ class BlobStore:
|
|||
protocol = "https" if use_ssl else "http"
|
||||
logger.info(f"Connected to S3-compatible storage at {protocol}://{endpoint}")
|
||||
|
||||
# Retry and Exponential delay configuration
|
||||
self.max_retries = 8
|
||||
self.base_delay = 0.25
|
||||
|
||||
self.ensure_bucket()
|
||||
|
||||
async def _with_retry(self, operation, *args, **kwargs):
    """Execute a synchronous minio operation with exponential-backoff retry.

    The blocking call is offloaded to the default executor so the event
    loop is never stalled.  On failure it is retried up to
    ``self.max_retries`` times, sleeping ``self.base_delay * 2**attempt``
    seconds between attempts; once the budget is exhausted the last
    exception is re-raised to the caller.

    :param operation: synchronous minio client method to invoke
    :param args: positional arguments forwarded to ``operation``
    :param kwargs: keyword arguments forwarded to ``operation``
    :return: whatever ``operation`` returns
    :raises Exception: the last failure, after ``self.max_retries`` attempts
    """
    last_exception = None
    for attempt in range(self.max_retries):
        try:
            # Run the synchronous minio call in the default executor to
            # avoid blocking the event loop.  get_event_loop() is kept
            # (rather than get_running_loop()) because callers/tests patch
            # exactly that name — NOTE(review): migrate both together.
            return await asyncio.get_event_loop().run_in_executor(
                None, lambda: operation(*args, **kwargs)
            )
        except Exception as e:
            # S3Error subclasses Exception, so one broad handler covers
            # both S3-level and transport-level failures.
            last_exception = e
            if attempt < self.max_retries - 1:
                delay = self.base_delay * (2 ** attempt)
                logger.warning(
                    f"S3 operation failed: {e}. "
                    f"Retrying in {delay}s... (Attempt {attempt + 1}/{self.max_retries})"
                )
                await asyncio.sleep(delay)
            else:
                logger.error(f"S3 operation failed after {self.max_retries} attempts: {e}")

    if last_exception:
        raise last_exception
|
||||
|
||||
def ensure_bucket(self):
|
||||
|
||||
# Make the bucket if it doesn't exist.
|
||||
|
|
@ -49,8 +78,8 @@ class BlobStore:
|
|||
|
||||
async def add(self, object_id, blob, kind):
|
||||
|
||||
# FIXME: Loop retry
|
||||
self.client.put_object(
|
||||
await self._with_retry(
|
||||
self.client.put_object,
|
||||
bucket_name = self.bucket_name,
|
||||
object_name = "doc/" + str(object_id),
|
||||
length = len(blob),
|
||||
|
|
@ -62,8 +91,8 @@ class BlobStore:
|
|||
|
||||
async def remove(self, object_id):
|
||||
|
||||
# FIXME: Loop retry
|
||||
self.client.remove_object(
|
||||
await self._with_retry(
|
||||
self.client.remove_object,
|
||||
bucket_name = self.bucket_name,
|
||||
object_name = "doc/" + str(object_id),
|
||||
)
|
||||
|
|
@ -73,8 +102,8 @@ class BlobStore:
|
|||
|
||||
async def get(self, object_id):
|
||||
|
||||
# FIXME: Loop retry
|
||||
resp = self.client.get_object(
|
||||
resp = await self._with_retry(
|
||||
self.client.get_object,
|
||||
bucket_name = self.bucket_name,
|
||||
object_name = "doc/" + str(object_id),
|
||||
)
|
||||
|
|
@ -83,7 +112,8 @@ class BlobStore:
|
|||
|
||||
async def get_range(self, object_id, offset: int, length: int) -> bytes:
|
||||
"""Fetch a specific byte range from an object."""
|
||||
resp = self.client.get_object(
|
||||
resp = await self._with_retry(
|
||||
self.client.get_object,
|
||||
bucket_name=self.bucket_name,
|
||||
object_name="doc/" + str(object_id),
|
||||
offset=offset,
|
||||
|
|
@ -97,7 +127,8 @@ class BlobStore:
|
|||
|
||||
async def get_size(self, object_id) -> int:
|
||||
"""Get the size of an object without downloading it."""
|
||||
stat = self.client.stat_object(
|
||||
stat = await self._with_retry(
|
||||
self.client.stat_object,
|
||||
bucket_name=self.bucket_name,
|
||||
object_name="doc/" + str(object_id),
|
||||
)
|
||||
|
|
@ -134,7 +165,7 @@ class BlobStore:
|
|||
|
||||
logger.debug("Stream complete")
|
||||
|
||||
def create_multipart_upload(self, object_id: UUID, kind: str) -> str:
|
||||
async def create_multipart_upload(self, object_id: UUID, kind: str) -> str:
|
||||
"""
|
||||
Initialize a multipart upload.
|
||||
|
||||
|
|
@ -148,7 +179,8 @@ class BlobStore:
|
|||
object_name = "doc/" + str(object_id)
|
||||
|
||||
# Use minio's internal method to create multipart upload
|
||||
upload_id = self.client._create_multipart_upload(
|
||||
upload_id = await self._with_retry(
|
||||
self.client._create_multipart_upload,
|
||||
bucket_name=self.bucket_name,
|
||||
object_name=object_name,
|
||||
headers={"Content-Type": kind},
|
||||
|
|
@ -157,7 +189,7 @@ class BlobStore:
|
|||
logger.info(f"Created multipart upload {upload_id} for {object_id}")
|
||||
return upload_id
|
||||
|
||||
def upload_part(
|
||||
async def upload_part(
|
||||
self,
|
||||
object_id: UUID,
|
||||
upload_id: str,
|
||||
|
|
@ -178,7 +210,8 @@ class BlobStore:
|
|||
"""
|
||||
object_name = "doc/" + str(object_id)
|
||||
|
||||
etag = self.client._upload_part(
|
||||
etag = await self._with_retry(
|
||||
self.client._upload_part,
|
||||
bucket_name=self.bucket_name,
|
||||
object_name=object_name,
|
||||
data=data,
|
||||
|
|
@ -190,7 +223,7 @@ class BlobStore:
|
|||
logger.debug(f"Uploaded part {part_number} for {object_id}, etag={etag}")
|
||||
return etag
|
||||
|
||||
def complete_multipart_upload(
|
||||
async def complete_multipart_upload(
|
||||
self,
|
||||
object_id: UUID,
|
||||
upload_id: str,
|
||||
|
|
@ -214,7 +247,8 @@ class BlobStore:
|
|||
for part_number, etag in parts
|
||||
]
|
||||
|
||||
self.client._complete_multipart_upload(
|
||||
await self._with_retry(
|
||||
self.client._complete_multipart_upload,
|
||||
bucket_name=self.bucket_name,
|
||||
object_name=object_name,
|
||||
upload_id=upload_id,
|
||||
|
|
@ -223,7 +257,7 @@ class BlobStore:
|
|||
|
||||
logger.info(f"Completed multipart upload for {object_id}")
|
||||
|
||||
def abort_multipart_upload(self, object_id: UUID, upload_id: str) -> None:
|
||||
async def abort_multipart_upload(self, object_id: UUID, upload_id: str) -> None:
|
||||
"""
|
||||
Abort a multipart upload, cleaning up any uploaded parts.
|
||||
|
||||
|
|
@ -233,7 +267,8 @@ class BlobStore:
|
|||
"""
|
||||
object_name = "doc/" + str(object_id)
|
||||
|
||||
self.client._abort_multipart_upload(
|
||||
await self._with_retry(
|
||||
self.client._abort_multipart_upload,
|
||||
bucket_name=self.bucket_name,
|
||||
object_name=object_name,
|
||||
upload_id=upload_id,
|
||||
|
|
|
|||
|
|
@ -301,7 +301,7 @@ class Librarian:
|
|||
object_id = uuid.uuid4()
|
||||
|
||||
# Create S3 multipart upload
|
||||
s3_upload_id = self.blob_store.create_multipart_upload(
|
||||
s3_upload_id = await self.blob_store.create_multipart_upload(
|
||||
object_id, request.document_metadata.kind
|
||||
)
|
||||
|
||||
|
|
@ -367,7 +367,7 @@ class Librarian:
|
|||
|
||||
# Upload to S3 (part numbers are 1-indexed in S3)
|
||||
part_number = request.chunk_index + 1
|
||||
etag = self.blob_store.upload_part(
|
||||
etag = await self.blob_store.upload_part(
|
||||
object_id=session["object_id"],
|
||||
upload_id=session["s3_upload_id"],
|
||||
part_number=part_number,
|
||||
|
|
@ -440,7 +440,7 @@ class Librarian:
|
|||
]
|
||||
|
||||
# Complete S3 multipart upload
|
||||
self.blob_store.complete_multipart_upload(
|
||||
await self.blob_store.complete_multipart_upload(
|
||||
object_id=session["object_id"],
|
||||
upload_id=session["s3_upload_id"],
|
||||
parts=parts,
|
||||
|
|
@ -492,7 +492,7 @@ class Librarian:
|
|||
raise RequestError("Not authorized to abort this upload")
|
||||
|
||||
# Abort S3 multipart upload
|
||||
self.blob_store.abort_multipart_upload(
|
||||
await self.blob_store.abort_multipart_upload(
|
||||
object_id=session["object_id"],
|
||||
upload_id=session["s3_upload_id"],
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue