mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 00:16:23 +02:00
Document API updates (#660)
* Doc streaming from librarian * Fix chunk minimum confusion * Add CLI args
This commit is contained in:
parent
a630e143ef
commit
d8f0a576af
3 changed files with 64 additions and 18 deletions
|
|
@ -81,6 +81,28 @@ class BlobStore:
|
|||
|
||||
return resp.read()
|
||||
|
||||
async def get_range(self, object_id, offset: int, length: int) -> bytes:
|
||||
"""Fetch a specific byte range from an object."""
|
||||
resp = self.client.get_object(
|
||||
bucket_name=self.bucket_name,
|
||||
object_name="doc/" + str(object_id),
|
||||
offset=offset,
|
||||
length=length,
|
||||
)
|
||||
try:
|
||||
return resp.read()
|
||||
finally:
|
||||
resp.close()
|
||||
resp.release_conn()
|
||||
|
||||
async def get_size(self, object_id) -> int:
|
||||
"""Get the size of an object without downloading it."""
|
||||
stat = self.client.stat_object(
|
||||
bucket_name=self.bucket_name,
|
||||
object_name="doc/" + str(object_id),
|
||||
)
|
||||
return stat.size
|
||||
|
||||
def get_stream(self, object_id, chunk_size: int = 1024 * 1024) -> Iterator[bytes]:
|
||||
"""
|
||||
Stream document content in chunks.
|
||||
|
|
|
|||
|
|
@ -16,8 +16,8 @@ import uuid
|
|||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Default chunk size for multipart uploads (5MB - S3 minimum)
|
||||
DEFAULT_CHUNK_SIZE = 5 * 1024 * 1024
|
||||
# Default chunk size for multipart uploads
|
||||
DEFAULT_CHUNK_SIZE = 2 * 1024 * 1024 # 2MB default
|
||||
|
||||
class Librarian:
|
||||
|
||||
|
|
@ -27,6 +27,7 @@ class Librarian:
|
|||
object_store_endpoint, object_store_access_key, object_store_secret_key,
|
||||
bucket_name, keyspace, load_document,
|
||||
object_store_use_ssl=False, object_store_region=None,
|
||||
min_chunk_size=1, # Default: no minimum (for Garage)
|
||||
):
|
||||
|
||||
self.blob_store = BlobStore(
|
||||
|
|
@ -39,6 +40,7 @@ class Librarian:
|
|||
)
|
||||
|
||||
self.load_document = load_document
|
||||
self.min_chunk_size = min_chunk_size
|
||||
|
||||
async def add_document(self, request):
|
||||
|
||||
|
|
@ -289,10 +291,12 @@ class Librarian:
|
|||
if total_size <= 0:
|
||||
raise RequestError("total_size must be positive")
|
||||
|
||||
# Use provided chunk size or default (minimum 5MB for S3)
|
||||
# Use provided chunk size or default
|
||||
chunk_size = request.chunk_size if request.chunk_size > 0 else DEFAULT_CHUNK_SIZE
|
||||
if chunk_size < DEFAULT_CHUNK_SIZE:
|
||||
chunk_size = DEFAULT_CHUNK_SIZE
|
||||
if chunk_size < self.min_chunk_size:
|
||||
raise RequestError(
|
||||
f"Chunk size {chunk_size} is below minimum {self.min_chunk_size}"
|
||||
)
|
||||
|
||||
# Calculate total chunks
|
||||
total_chunks = math.ceil(total_size / chunk_size)
|
||||
|
|
@ -657,19 +661,21 @@ class Librarian:
|
|||
"""
|
||||
logger.debug(f"Streaming document {request.document_id}, chunk {request.chunk_index}")
|
||||
|
||||
DEFAULT_CHUNK_SIZE = 1024 * 1024 # 1MB default
|
||||
|
||||
chunk_size = request.chunk_size if request.chunk_size > 0 else DEFAULT_CHUNK_SIZE
|
||||
if chunk_size < self.min_chunk_size:
|
||||
raise RequestError(
|
||||
f"Chunk size {chunk_size} is below minimum {self.min_chunk_size}"
|
||||
)
|
||||
|
||||
object_id = await self.table_store.get_document_object_id(
|
||||
request.user,
|
||||
request.document_id
|
||||
)
|
||||
|
||||
# Default chunk size of 1MB
|
||||
chunk_size = request.chunk_size if request.chunk_size > 0 else 1024 * 1024
|
||||
|
||||
# Get the full content and slice out the requested chunk
|
||||
# Note: This is a simple implementation. For true streaming, we'd need
|
||||
# range requests on the object storage.
|
||||
content = await self.blob_store.get(object_id)
|
||||
total_size = len(content)
|
||||
# Get size via stat (no content download)
|
||||
total_size = await self.blob_store.get_size(object_id)
|
||||
total_chunks = math.ceil(total_size / chunk_size)
|
||||
|
||||
if request.chunk_index >= total_chunks:
|
||||
|
|
@ -678,12 +684,15 @@ class Librarian:
|
|||
f"document has {total_chunks} chunks"
|
||||
)
|
||||
|
||||
start = request.chunk_index * chunk_size
|
||||
end = min(start + chunk_size, total_size)
|
||||
chunk_content = content[start:end]
|
||||
# Calculate byte range
|
||||
offset = request.chunk_index * chunk_size
|
||||
length = min(chunk_size, total_size - offset)
|
||||
|
||||
# Fetch only the requested range
|
||||
chunk_content = await self.blob_store.get_range(object_id, offset, length)
|
||||
|
||||
logger.debug(f"Returning chunk {request.chunk_index}/{total_chunks}, "
|
||||
f"bytes {start}-{end} of {total_size}")
|
||||
f"bytes {offset}-{offset + length} of {total_size}")
|
||||
|
||||
return LibrarianResponse(
|
||||
error=None,
|
||||
|
|
@ -691,7 +700,7 @@ class Librarian:
|
|||
chunk_index=request.chunk_index,
|
||||
chunks_received=1, # Using as "current chunk" indicator
|
||||
total_chunks=total_chunks,
|
||||
bytes_received=end,
|
||||
bytes_received=offset + length,
|
||||
total_bytes=total_size,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -47,6 +47,7 @@ default_object_store_secret_key = "object-password"
|
|||
default_object_store_use_ssl = False
|
||||
default_object_store_region = None
|
||||
default_cassandra_host = "cassandra"
|
||||
default_min_chunk_size = 1 # No minimum by default (for Garage)
|
||||
|
||||
bucket_name = "library"
|
||||
|
||||
|
|
@ -100,6 +101,11 @@ class Processor(AsyncProcessor):
|
|||
default_object_store_region
|
||||
)
|
||||
|
||||
min_chunk_size = params.get(
|
||||
"min_chunk_size",
|
||||
default_min_chunk_size
|
||||
)
|
||||
|
||||
cassandra_host = params.get("cassandra_host")
|
||||
cassandra_username = params.get("cassandra_username")
|
||||
cassandra_password = params.get("cassandra_password")
|
||||
|
|
@ -226,6 +232,7 @@ class Processor(AsyncProcessor):
|
|||
load_document = self.load_document,
|
||||
object_store_use_ssl = object_store_use_ssl,
|
||||
object_store_region = object_store_region,
|
||||
min_chunk_size = min_chunk_size,
|
||||
)
|
||||
|
||||
self.collection_manager = CollectionManager(
|
||||
|
|
@ -583,6 +590,14 @@ class Processor(AsyncProcessor):
|
|||
help='Object storage region (optional)',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--min-chunk-size',
|
||||
type=int,
|
||||
default=default_min_chunk_size,
|
||||
help=f'Minimum chunk size in bytes for uploads/downloads '
|
||||
f'(default: {default_min_chunk_size})',
|
||||
)
|
||||
|
||||
add_cassandra_args(parser)
|
||||
|
||||
def run():
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue