Incremental / large document loading (#659)

Tech spec

BlobStore (trustgraph-flow/trustgraph/librarian/blob_store.py):
- get_stream() - yields document content in chunks for streaming retrieval
- create_multipart_upload() - initializes S3 multipart upload, returns
  upload_id
- upload_part() - uploads a single part, returns etag
- complete_multipart_upload() - finalizes upload with part etags
- abort_multipart_upload() - cancels and cleans up (usage sketched below)
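
A minimal sketch of how these methods compose into an upload loop. The helper
name (upload_in_parts) is hypothetical; the 5MB part size matches the S3
minimum that the librarian also enforces:

import uuid

def upload_in_parts(blob_store, path, kind="application/pdf",
                    part_size=5 * 1024 * 1024):
    # Create the object ID and open a multipart upload session on S3
    object_id = uuid.uuid4()
    upload_id = blob_store.create_multipart_upload(object_id, kind)
    parts = []
    try:
        with open(path, "rb") as f:
            part_number = 1                      # S3 part numbers are 1-indexed
            while True:
                data = f.read(part_size)
                if not data:
                    break
                etag = blob_store.upload_part(object_id, upload_id,
                                              part_number, data)
                parts.append((part_number, etag))
                part_number += 1
        # Parts are assembled server-side into the final object
        blob_store.complete_multipart_upload(object_id, upload_id, parts)
    except Exception:
        # Discard any parts already stored if the upload fails part-way
        blob_store.abort_multipart_upload(object_id, upload_id)
        raise
    return object_id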

Cassandra schema (trustgraph-flow/trustgraph/tables/library.py):
- New upload_session table with 24-hour TTL
- Index on user for listing sessions
- Prepared statements for all operations
- Methods: create_upload_session(), get_upload_session(),
  update_upload_session_chunk(), delete_upload_session(),
  list_upload_sessions() (see the lifecycle sketch below)
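
A minimal sketch of the intended session lifecycle against these methods; the
table_store handle, user name and sizes are illustrative assumptions:

import json
import math
import uuid

async def demo_upload_session(table_store):
    upload_id = str(uuid.uuid4())
    total_size = 25 * 1024 * 1024
    chunk_size = 5 * 1024 * 1024
    total_chunks = math.ceil(total_size / chunk_size)

    # The row carries everything needed to resume; it expires after 24 hours
    await table_store.create_upload_session(
        upload_id=upload_id,
        user="alice",
        document_id="doc-001",
        document_metadata=json.dumps({"id": "doc-001", "kind": "application/pdf"}),
        s3_upload_id="s3-upload-id-from-blob-store",
        object_id=uuid.uuid4(),
        total_size=total_size,
        chunk_size=chunk_size,
        total_chunks=total_chunks,
    )

    # Each stored chunk adds an entry to the chunks_received map
    await table_store.update_upload_session_chunk(upload_id, 0, "etag-0")

    session = await table_store.get_upload_session(upload_id)
    print(len(session["chunks_received"]), "of", session["total_chunks"], "chunks")

    # Sessions for a user can be listed, e.g. to offer resumption
    await table_store.list_upload_sessions("alice")

    # Explicit delete on complete/abort; otherwise the TTL reaps the row
    await table_store.delete_upload_session(upload_id)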

Chunked upload protocol and SDK:
- Schema extended with UploadSession, UploadProgress, and new
  request/response fields
- Librarian methods: begin_upload, upload_chunk, complete_upload,
  abort_upload, get_upload_status, list_uploads
- Service routing for all new operations
- Python SDK with transparent chunked upload:
  - add_document() auto-switches to chunked for files > 10MB
  - Progress callback support (on_progress)
  - get_pending_uploads(), get_upload_status(), abort_upload(),
    resume_upload() (see the usage sketch after this list)
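
A sketch of how the transparent chunked upload is intended to look from client
code. The client object here (lib) and the exact argument names are assumptions
for illustration; only the method names come from the list above:

def on_progress(bytes_sent, total_bytes):
    print(f"uploaded {bytes_sent} of {total_bytes} bytes")

with open("big-report.pdf", "rb") as f:
    # Over 10MB, add_document() is expected to switch to the chunked
    # begin-upload / upload-chunk / complete-upload path automatically
    lib.add_document(
        document=f.read(),
        id="https://example.com/doc/big-report",
        kind="application/pdf",
        title="Big report",
        on_progress=on_progress,
    )

# Interrupted uploads survive (24-hour session TTL) and can be resumed
for pending in lib.get_pending_uploads():
    status = lib.get_upload_status(pending.upload_id)
    print(f"{status.chunks_received}/{status.total_chunks} chunks stored")
    lib.resume_upload(pending.upload_id)   # or lib.abort_upload(...) to discard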

Child documents and streaming retrieval:
- Document table: Added parent_id and document_type columns with index
- Document schema (knowledge/document.py): Added document_id field for
  streaming retrieval
- Librarian operations:
  - add-child-document for extracted PDF pages
  - list-children to get child documents
  - stream-document for chunked content retrieval (see the sketch after
    this list)
  - Cascade delete removes children when parent is deleted
  - list-documents filters children by default
- PDF decoder (decoding/pdf/pdf_decoder.py): Updated to stream large
  documents from librarian API to temp file
- Librarian service (librarian/service.py): Sends document_id instead of
  inline content for large documents (>2MB)
- Deprecated tools (load_pdf.py, load_text.py): Added deprecation
  warnings directing users to tg-add-library-document +
  tg-start-library-processing
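
A sketch of the stream-document retrieval loop referenced above. The transport
helper (send_librarian_request) is an assumption standing in for whatever
issues a LibrarianRequest and returns the LibrarianResponse; the field names
match the librarian changes below:

import base64

def fetch_whole_document(send_librarian_request, user, document_id,
                         chunk_size=1024 * 1024):
    data = bytearray()
    chunk_index = 0
    total_chunks = 1                 # corrected from the first response
    while chunk_index < total_chunks:
        resp = send_librarian_request(
            operation="stream-document",
            user=user,
            document_id=document_id,
            chunk_index=chunk_index,
            chunk_size=chunk_size,
        )
        # Each response carries one base64-encoded chunk plus chunk bookkeeping
        data.extend(base64.b64decode(resp.content))
        total_chunks = resp.total_chunks
        chunk_index += 1
    return bytes(data)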

Remove load_pdf and load_text utils

Move chunker/librarian comms to base class

Updating tests
Commit a630e143ef (parent a38ca9474f), authored by cybermaggedon,
2026-03-04 16:57:58 +00:00.
21 changed files with 3164 additions and 650 deletions

Chunker (first of two chunker processor variants):

@ -16,6 +16,7 @@ logger = logging.getLogger(__name__)
default_ident = "chunker"
class Processor(ChunkingService):
def __init__(self, **params):
@ -23,7 +24,7 @@ class Processor(ChunkingService):
id = params.get("id", default_ident)
chunk_size = params.get("chunk_size", 2000)
chunk_overlap = params.get("chunk_overlap", 100)
super(Processor, self).__init__(
**params | { "id": id }
)
@ -69,6 +70,9 @@ class Processor(ChunkingService):
v = msg.value()
logger.info(f"Chunking document {v.metadata.id}...")
# Get text content (fetches from librarian if needed)
text = await self.get_document_text(v)
# Extract chunk parameters from flow (allows runtime override)
chunk_size, chunk_overlap = await self.chunk_document(
msg, consumer, flow,
@ -90,9 +94,7 @@ class Processor(ChunkingService):
is_separator_regex=False,
)
texts = text_splitter.create_documents(
[v.text.decode("utf-8")]
)
texts = text_splitter.create_documents([text])
for ix, chunk in enumerate(texts):
@ -133,4 +135,3 @@ class Processor(ChunkingService):
def run():
Processor.launch(default_ident, __doc__)

Chunker (second chunker processor variant):

@ -16,6 +16,7 @@ logger = logging.getLogger(__name__)
default_ident = "chunker"
class Processor(ChunkingService):
def __init__(self, **params):
@ -68,6 +69,9 @@ class Processor(ChunkingService):
v = msg.value()
logger.info(f"Chunking document {v.metadata.id}...")
# Get text content (fetches from librarian if needed)
text = await self.get_document_text(v)
# Extract chunk parameters from flow (allows runtime override)
chunk_size, chunk_overlap = await self.chunk_document(
msg, consumer, flow,
@ -88,9 +92,7 @@ class Processor(ChunkingService):
chunk_overlap=chunk_overlap,
)
texts = text_splitter.create_documents(
[v.text.decode("utf-8")]
)
texts = text_splitter.create_documents([text])
for ix, chunk in enumerate(texts):

PDF decoder (decoding/pdf/pdf_decoder.py):

@ -2,21 +2,34 @@
"""
Simple decoder, accepts PDF documents on input, outputs pages from the
PDF document as text as separate output objects.
Supports both inline document data and fetching from librarian via Pulsar
for large documents.
"""
import asyncio
import os
import tempfile
import base64
import logging
import uuid
from langchain_community.document_loaders import PyPDFLoader
from ... schema import Document, TextDocument, Metadata
from ... schema import LibrarianRequest, LibrarianResponse
from ... schema import librarian_request_queue, librarian_response_queue
from ... base import FlowProcessor, ConsumerSpec, ProducerSpec
from ... base import Consumer, Producer, ConsumerMetrics, ProducerMetrics
# Module logger
logger = logging.getLogger(__name__)
default_ident = "pdf-decoder"
default_librarian_request_queue = librarian_request_queue
default_librarian_response_queue = librarian_response_queue
class Processor(FlowProcessor):
def __init__(self, **params):
@ -44,8 +57,97 @@ class Processor(FlowProcessor):
)
)
# Librarian client for fetching document content
librarian_request_q = params.get(
"librarian_request_queue", default_librarian_request_queue
)
librarian_response_q = params.get(
"librarian_response_queue", default_librarian_response_queue
)
librarian_request_metrics = ProducerMetrics(
processor = id, flow = None, name = "librarian-request"
)
self.librarian_request_producer = Producer(
backend = self.pubsub,
topic = librarian_request_q,
schema = LibrarianRequest,
metrics = librarian_request_metrics,
)
librarian_response_metrics = ConsumerMetrics(
processor = id, flow = None, name = "librarian-response"
)
self.librarian_response_consumer = Consumer(
taskgroup = self.taskgroup,
backend = self.pubsub,
flow = None,
topic = librarian_response_q,
subscriber = f"{id}-librarian",
schema = LibrarianResponse,
handler = self.on_librarian_response,
metrics = librarian_response_metrics,
)
# Pending librarian requests: request_id -> asyncio.Future
self.pending_requests = {}
logger.info("PDF decoder initialized")
async def start(self):
await super(Processor, self).start()
await self.librarian_request_producer.start()
await self.librarian_response_consumer.start()
async def on_librarian_response(self, msg, consumer, flow):
"""Handle responses from the librarian service."""
response = msg.value()
request_id = msg.properties().get("id")
if request_id and request_id in self.pending_requests:
future = self.pending_requests.pop(request_id)
future.set_result(response)
else:
logger.warning(f"Received unexpected librarian response: {request_id}")
async def fetch_document_content(self, document_id, user, timeout=120):
"""
Fetch document content from librarian via Pulsar.
"""
request_id = str(uuid.uuid4())
request = LibrarianRequest(
operation="get-document-content",
document_id=document_id,
user=user,
)
# Create future for response
future = asyncio.get_event_loop().create_future()
self.pending_requests[request_id] = future
try:
# Send request
await self.librarian_request_producer.send(
request, properties={"id": request_id}
)
# Wait for response
response = await asyncio.wait_for(future, timeout=timeout)
if response.error:
raise RuntimeError(
f"Librarian error: {response.error.type}: {response.error.message}"
)
return response.content
except asyncio.TimeoutError:
self.pending_requests.pop(request_id, None)
raise RuntimeError(f"Timeout fetching document {document_id}")
async def on_message(self, msg, consumer, flow):
logger.debug("PDF message received")
@ -54,26 +156,53 @@ class Processor(FlowProcessor):
logger.info(f"Decoding PDF {v.metadata.id}...")
with tempfile.NamedTemporaryFile(delete_on_close=False) as fp:
    fp.write(base64.b64decode(v.data))
    fp.close()
    with open(fp.name, mode='rb') as f:
        loader = PyPDFLoader(fp.name)
        pages = loader.load()
        for ix, page in enumerate(pages):
            logger.debug(f"Processing page {ix}")
            r = TextDocument(
                metadata=v.metadata,
                text=page.page_content.encode("utf-8"),
            )
            await flow("output").send(r)

with tempfile.NamedTemporaryFile(delete_on_close=False, suffix='.pdf') as fp:
    temp_path = fp.name
    # Check if we should fetch from librarian or use inline data
    if v.document_id:
        # Fetch from librarian via Pulsar
        logger.info(f"Fetching document {v.document_id} from librarian...")
        fp.close()
        content = await self.fetch_document_content(
            document_id=v.document_id,
            user=v.metadata.user,
        )
        # Content is base64 encoded
        if isinstance(content, str):
            content = content.encode('utf-8')
        decoded_content = base64.b64decode(content)
        with open(temp_path, 'wb') as f:
            f.write(decoded_content)
        logger.info(f"Fetched {len(decoded_content)} bytes from librarian")
    else:
        # Use inline data (backward compatibility)
        fp.write(base64.b64decode(v.data))
        fp.close()
    loader = PyPDFLoader(temp_path)
    pages = loader.load()
    for ix, page in enumerate(pages):
        logger.debug(f"Processing page {ix}")
        r = TextDocument(
            metadata=v.metadata,
            text=page.page_content.encode("utf-8"),
        )
        await flow("output").send(r)
    # Clean up temp file
    try:
        os.unlink(temp_path)
    except OSError:
        pass
logger.debug("PDF decoding complete")
@ -81,7 +210,18 @@ class Processor(FlowProcessor):
def add_args(parser):
FlowProcessor.add_args(parser)
parser.add_argument(
'--librarian-request-queue',
default=default_librarian_request_queue,
help=f'Librarian request queue (default: {default_librarian_request_queue})',
)
parser.add_argument(
'--librarian-response-queue',
default=default_librarian_response_queue,
help=f'Librarian response queue (default: {default_librarian_response_queue})',
)
def run():
Processor.launch(default_ident, __doc__)

BlobStore (trustgraph-flow/trustgraph/librarian/blob_store.py):

@ -3,9 +3,12 @@ from .. knowledge import hash
from .. exceptions import RequestError
from minio import Minio
from minio.datatypes import Part
import time
import io
import logging
from typing import Iterator, List, Tuple
from uuid import UUID
# Module logger
logger = logging.getLogger(__name__)
@ -78,3 +81,141 @@ class BlobStore:
return resp.read()
def get_stream(self, object_id, chunk_size: int = 1024 * 1024) -> Iterator[bytes]:
"""
Stream document content in chunks.
Yields chunks of the document, allowing processing without loading
the entire document into memory.
Args:
object_id: The UUID of the document object
chunk_size: Size of each chunk in bytes (default 1MB)
Yields:
Chunks of document content as bytes
"""
resp = self.client.get_object(
bucket_name=self.bucket_name,
object_name="doc/" + str(object_id),
)
try:
while True:
chunk = resp.read(chunk_size)
if not chunk:
break
yield chunk
finally:
resp.close()
resp.release_conn()
logger.debug("Stream complete")
def create_multipart_upload(self, object_id: UUID, kind: str) -> str:
"""
Initialize a multipart upload.
Args:
object_id: The UUID for the new object
kind: MIME type of the document
Returns:
The S3 upload_id for this multipart upload session
"""
object_name = "doc/" + str(object_id)
# Use minio's internal method to create multipart upload
upload_id = self.client._create_multipart_upload(
bucket_name=self.bucket_name,
object_name=object_name,
headers={"Content-Type": kind},
)
logger.info(f"Created multipart upload {upload_id} for {object_id}")
return upload_id
def upload_part(
self,
object_id: UUID,
upload_id: str,
part_number: int,
data: bytes
) -> str:
"""
Upload a single part of a multipart upload.
Args:
object_id: The UUID of the object being uploaded
upload_id: The S3 upload_id from create_multipart_upload
part_number: Part number (1-indexed, as per S3 spec)
data: The chunk data to upload
Returns:
The ETag for this part (needed for complete_multipart_upload)
"""
object_name = "doc/" + str(object_id)
etag = self.client._upload_part(
bucket_name=self.bucket_name,
object_name=object_name,
data=data,
headers={"Content-Length": str(len(data))},
upload_id=upload_id,
part_number=part_number,
)
logger.debug(f"Uploaded part {part_number} for {object_id}, etag={etag}")
return etag
def complete_multipart_upload(
self,
object_id: UUID,
upload_id: str,
parts: List[Tuple[int, str]]
) -> None:
"""
Complete a multipart upload, assembling all parts into the final object.
S3 coalesces the parts server-side - no data transfer through this client.
Args:
object_id: The UUID of the object
upload_id: The S3 upload_id from create_multipart_upload
parts: List of (part_number, etag) tuples in order
"""
object_name = "doc/" + str(object_id)
# Convert to Part objects as expected by minio
part_objects = [
Part(part_number, etag)
for part_number, etag in parts
]
self.client._complete_multipart_upload(
bucket_name=self.bucket_name,
object_name=object_name,
upload_id=upload_id,
parts=part_objects,
)
logger.info(f"Completed multipart upload for {object_id}")
def abort_multipart_upload(self, object_id: UUID, upload_id: str) -> None:
"""
Abort a multipart upload, cleaning up any uploaded parts.
Args:
object_id: The UUID of the object
upload_id: The S3 upload_id from create_multipart_upload
"""
object_name = "doc/" + str(object_id)
self.client._abort_multipart_upload(
bucket_name=self.bucket_name,
object_name=object_name,
upload_id=upload_id,
)
logger.info(f"Aborted multipart upload {upload_id} for {object_id}")

Librarian core logic (class Librarian):

@ -1,17 +1,24 @@
from .. schema import LibrarianRequest, LibrarianResponse, Error, Triple
from .. schema import UploadSession
from .. knowledge import hash
from .. exceptions import RequestError
from .. tables.library import LibraryTableStore
from . blob_store import BlobStore
import base64
import json
import logging
import math
import time
import uuid
# Module logger
logger = logging.getLogger(__name__)
# Default chunk size for multipart uploads (5MB - S3 minimum)
DEFAULT_CHUNK_SIZE = 5 * 1024 * 1024
class Librarian:
def __init__(
@ -66,13 +73,7 @@ class Librarian:
logger.debug("Add complete")
return LibrarianResponse(
error = None,
document_metadata = None,
content = None,
document_metadatas = None,
processing_metadatas = None,
)
return LibrarianResponse()
async def remove_document(self, request):
@ -84,6 +85,21 @@ class Librarian:
):
raise RuntimeError("Document does not exist")
# First, cascade delete all child documents
children = await self.table_store.list_children(request.document_id)
for child in children:
logger.debug(f"Cascade deleting child document {child.id}")
try:
child_object_id = await self.table_store.get_document_object_id(
child.user,
child.id
)
await self.blob_store.remove(child_object_id)
await self.table_store.remove_document(child.user, child.id)
except Exception as e:
logger.warning(f"Failed to delete child document {child.id}: {e}")
# Now remove the parent document
object_id = await self.table_store.get_document_object_id(
request.user,
request.document_id
@ -100,13 +116,7 @@ class Librarian:
logger.debug("Remove complete")
return LibrarianResponse(
error = None,
document_metadata = None,
content = None,
document_metadatas = None,
processing_metadatas = None,
)
return LibrarianResponse()
async def update_document(self, request):
@ -124,13 +134,7 @@ class Librarian:
logger.debug("Update complete")
return LibrarianResponse(
error = None,
document_metadata = None,
content = None,
document_metadatas = None,
processing_metadatas = None,
)
return LibrarianResponse()
async def get_document_metadata(self, request):
@ -147,8 +151,6 @@ class Librarian:
error = None,
document_metadata = doc,
content = None,
document_metadatas = None,
processing_metadatas = None,
)
async def get_document_content(self, request):
@ -170,8 +172,6 @@ class Librarian:
error = None,
document_metadata = None,
content = base64.b64encode(content),
document_metadatas = None,
processing_metadatas = None,
)
async def add_processing(self, request):
@ -217,13 +217,7 @@ class Librarian:
logger.debug("Add complete")
return LibrarianResponse(
error = None,
document_metadata = None,
content = None,
document_metadatas = None,
processing_metadatas = None,
)
return LibrarianResponse()
async def remove_processing(self, request):
@ -243,24 +237,22 @@ class Librarian:
logger.debug("Remove complete")
return LibrarianResponse(
error = None,
document_metadata = None,
content = None,
document_metadatas = None,
processing_metadatas = None,
)
return LibrarianResponse()
async def list_documents(self, request):
docs = await self.table_store.list_documents(request.user)
# Filter out child documents by default unless include_children is True
include_children = getattr(request, 'include_children', False)
if not include_children:
docs = [
doc for doc in docs
if not doc.parent_id # Only include top-level documents
]
return LibrarianResponse(
error = None,
document_metadata = None,
content = None,
document_metadatas = docs,
processing_metadatas = None,
)
async def list_processing(self, request):
@ -268,10 +260,438 @@ class Librarian:
procs = await self.table_store.list_processing(request.user)
return LibrarianResponse(
error = None,
document_metadata = None,
content = None,
document_metadatas = None,
processing_metadatas = procs,
)
# Chunked upload operations
async def begin_upload(self, request):
"""
Initialize a chunked upload session.
Creates an S3 multipart upload and stores session state in Cassandra.
"""
logger.info(f"Beginning chunked upload for document {request.document_metadata.id}")
if request.document_metadata.kind not in ("text/plain", "application/pdf"):
raise RequestError(
"Invalid document kind: " + request.document_metadata.kind
)
if await self.table_store.document_exists(
request.document_metadata.user,
request.document_metadata.id
):
raise RequestError("Document already exists")
# Validate sizes
total_size = request.total_size
if total_size <= 0:
raise RequestError("total_size must be positive")
# Use provided chunk size or default (minimum 5MB for S3)
chunk_size = request.chunk_size if request.chunk_size > 0 else DEFAULT_CHUNK_SIZE
if chunk_size < DEFAULT_CHUNK_SIZE:
chunk_size = DEFAULT_CHUNK_SIZE
# Calculate total chunks
total_chunks = math.ceil(total_size / chunk_size)
# Generate IDs
upload_id = str(uuid.uuid4())
object_id = uuid.uuid4()
# Create S3 multipart upload
s3_upload_id = self.blob_store.create_multipart_upload(
object_id, request.document_metadata.kind
)
# Serialize document metadata for storage
doc_meta_json = json.dumps({
"id": request.document_metadata.id,
"time": request.document_metadata.time,
"kind": request.document_metadata.kind,
"title": request.document_metadata.title,
"comments": request.document_metadata.comments,
"user": request.document_metadata.user,
"tags": request.document_metadata.tags,
})
# Store session in Cassandra
await self.table_store.create_upload_session(
upload_id=upload_id,
user=request.document_metadata.user,
document_id=request.document_metadata.id,
document_metadata=doc_meta_json,
s3_upload_id=s3_upload_id,
object_id=object_id,
total_size=total_size,
chunk_size=chunk_size,
total_chunks=total_chunks,
)
logger.info(f"Created upload session {upload_id} with {total_chunks} chunks")
return LibrarianResponse(
error=None,
upload_id=upload_id,
chunk_size=chunk_size,
total_chunks=total_chunks,
)
async def upload_chunk(self, request):
"""
Upload a single chunk of a document.
Forwards the chunk to S3 and updates session state.
"""
logger.debug(f"Uploading chunk {request.chunk_index} for upload {request.upload_id}")
# Get session
session = await self.table_store.get_upload_session(request.upload_id)
if session is None:
raise RequestError("Upload session not found or expired")
# Validate ownership
if session["user"] != request.user:
raise RequestError("Not authorized to upload to this session")
# Validate chunk index
if request.chunk_index < 0 or request.chunk_index >= session["total_chunks"]:
raise RequestError(
f"Invalid chunk index {request.chunk_index}, "
f"must be 0-{session['total_chunks']-1}"
)
# Decode content
content = base64.b64decode(request.content)
# Upload to S3 (part numbers are 1-indexed in S3)
part_number = request.chunk_index + 1
etag = self.blob_store.upload_part(
object_id=session["object_id"],
upload_id=session["s3_upload_id"],
part_number=part_number,
data=content,
)
# Update session with chunk info
await self.table_store.update_upload_session_chunk(
upload_id=request.upload_id,
chunk_index=request.chunk_index,
etag=etag,
)
# Calculate progress
chunks_received = session["chunks_received"]
# Add this chunk if not already present
if request.chunk_index not in chunks_received:
chunks_received[request.chunk_index] = etag
num_chunks_received = len(chunks_received)  # the map now includes this chunk
bytes_received = num_chunks_received * session["chunk_size"]
# Adjust for last chunk potentially being smaller
if bytes_received > session["total_size"]:
bytes_received = session["total_size"]
logger.debug(f"Chunk {request.chunk_index} uploaded, {num_chunks_received}/{session['total_chunks']} complete")
return LibrarianResponse(
error=None,
upload_id=request.upload_id,
chunk_index=request.chunk_index,
chunks_received=num_chunks_received,
total_chunks=session["total_chunks"],
bytes_received=bytes_received,
total_bytes=session["total_size"],
)
async def complete_upload(self, request):
"""
Finalize a chunked upload and create the document.
Completes the S3 multipart upload and creates the document metadata.
"""
logger.info(f"Completing upload {request.upload_id}")
# Get session
session = await self.table_store.get_upload_session(request.upload_id)
if session is None:
raise RequestError("Upload session not found or expired")
# Validate ownership
if session["user"] != request.user:
raise RequestError("Not authorized to complete this upload")
# Verify all chunks received
chunks_received = session["chunks_received"]
if len(chunks_received) != session["total_chunks"]:
missing = [
i for i in range(session["total_chunks"])
if i not in chunks_received
]
raise RequestError(
f"Missing chunks: {missing[:10]}{'...' if len(missing) > 10 else ''}"
)
# Build parts list for S3 (sorted by part number)
parts = [
(chunk_index + 1, etag) # S3 part numbers are 1-indexed
for chunk_index, etag in sorted(chunks_received.items())
]
# Complete S3 multipart upload
self.blob_store.complete_multipart_upload(
object_id=session["object_id"],
upload_id=session["s3_upload_id"],
parts=parts,
)
# Parse document metadata from session
doc_meta_dict = json.loads(session["document_metadata"])
# Create DocumentMetadata object
from .. schema import DocumentMetadata
doc_metadata = DocumentMetadata(
id=doc_meta_dict["id"],
time=doc_meta_dict.get("time", int(time.time())),
kind=doc_meta_dict["kind"],
title=doc_meta_dict.get("title", ""),
comments=doc_meta_dict.get("comments", ""),
user=doc_meta_dict["user"],
tags=doc_meta_dict.get("tags", []),
metadata=[], # Triples not supported in chunked upload yet
)
# Add document to table
await self.table_store.add_document(doc_metadata, session["object_id"])
# Delete upload session
await self.table_store.delete_upload_session(request.upload_id)
logger.info(f"Upload {request.upload_id} completed, document {doc_metadata.id} created")
return LibrarianResponse(
error=None,
document_id=doc_metadata.id,
object_id=str(session["object_id"]),
)
async def abort_upload(self, request):
"""
Cancel a chunked upload and clean up resources.
"""
logger.info(f"Aborting upload {request.upload_id}")
# Get session
session = await self.table_store.get_upload_session(request.upload_id)
if session is None:
raise RequestError("Upload session not found or expired")
# Validate ownership
if session["user"] != request.user:
raise RequestError("Not authorized to abort this upload")
# Abort S3 multipart upload
self.blob_store.abort_multipart_upload(
object_id=session["object_id"],
upload_id=session["s3_upload_id"],
)
# Delete session from Cassandra
await self.table_store.delete_upload_session(request.upload_id)
logger.info(f"Upload {request.upload_id} aborted")
return LibrarianResponse(error=None)
async def get_upload_status(self, request):
"""
Get the status of an in-progress upload.
"""
logger.debug(f"Getting status for upload {request.upload_id}")
# Get session
session = await self.table_store.get_upload_session(request.upload_id)
if session is None:
return LibrarianResponse(
error=None,
upload_id=request.upload_id,
upload_state="expired",
)
# Validate ownership
if session["user"] != request.user:
raise RequestError("Not authorized to view this upload")
chunks_received = session["chunks_received"]
received_list = sorted(chunks_received.keys())
missing_list = [
i for i in range(session["total_chunks"])
if i not in chunks_received
]
bytes_received = len(chunks_received) * session["chunk_size"]
if bytes_received > session["total_size"]:
bytes_received = session["total_size"]
return LibrarianResponse(
error=None,
upload_id=request.upload_id,
upload_state="in-progress",
received_chunks=received_list,
missing_chunks=missing_list,
chunks_received=len(chunks_received),
total_chunks=session["total_chunks"],
bytes_received=bytes_received,
total_bytes=session["total_size"],
)
async def list_uploads(self, request):
"""
List all in-progress uploads for a user.
"""
logger.debug(f"Listing uploads for user {request.user}")
sessions = await self.table_store.list_upload_sessions(request.user)
upload_sessions = [
UploadSession(
upload_id=s["upload_id"],
document_id=s["document_id"],
document_metadata_json=s.get("document_metadata", ""),
total_size=s["total_size"],
chunk_size=s["chunk_size"],
total_chunks=s["total_chunks"],
chunks_received=s["chunks_received"],
created_at=str(s.get("created_at", "")),
)
for s in sessions
]
return LibrarianResponse(
error=None,
upload_sessions=upload_sessions,
)
# Child document operations
async def add_child_document(self, request):
"""
Add a child document linked to a parent document.
Child documents are typically extracted content (e.g., pages from a PDF).
They have a parent_id pointing to the source document and document_type
set to "extracted".
"""
logger.info(f"Adding child document {request.document_metadata.id} "
f"for parent {request.document_metadata.parent_id}")
if not request.document_metadata.parent_id:
raise RequestError("parent_id is required for child documents")
# Verify parent exists
if not await self.table_store.document_exists(
request.document_metadata.user,
request.document_metadata.parent_id
):
raise RequestError(
f"Parent document {request.document_metadata.parent_id} does not exist"
)
if await self.table_store.document_exists(
request.document_metadata.user,
request.document_metadata.id
):
raise RequestError("Document already exists")
# Ensure document_type is set to "extracted"
request.document_metadata.document_type = "extracted"
# Create object ID for blob
object_id = uuid.uuid4()
logger.debug("Adding blob...")
await self.blob_store.add(
object_id, base64.b64decode(request.content),
request.document_metadata.kind
)
logger.debug("Adding to table...")
await self.table_store.add_document(
request.document_metadata, object_id
)
logger.debug("Add child document complete")
return LibrarianResponse(
error=None,
document_id=request.document_metadata.id,
)
async def list_children(self, request):
"""
List all child documents for a given parent document.
"""
logger.debug(f"Listing children for parent {request.document_id}")
children = await self.table_store.list_children(request.document_id)
return LibrarianResponse(
error=None,
document_metadatas=children,
)
async def stream_document(self, request):
"""
Stream document content in chunks.
This operation returns document content in smaller chunks, allowing
memory-efficient processing of large documents. The response includes
chunk information for reassembly.
Note: This operation returns a single chunk at a time. Clients should
call repeatedly with increasing chunk_index until all chunks are received.
"""
logger.debug(f"Streaming document {request.document_id}, chunk {request.chunk_index}")
object_id = await self.table_store.get_document_object_id(
request.user,
request.document_id
)
# Default chunk size of 1MB
chunk_size = request.chunk_size if request.chunk_size > 0 else 1024 * 1024
# Get the full content and slice out the requested chunk
# Note: This is a simple implementation. For true streaming, we'd need
# range requests on the object storage.
content = await self.blob_store.get(object_id)
total_size = len(content)
total_chunks = math.ceil(total_size / chunk_size)
if request.chunk_index >= total_chunks:
raise RequestError(
f"Invalid chunk index {request.chunk_index}, "
f"document has {total_chunks} chunks"
)
start = request.chunk_index * chunk_size
end = min(start + chunk_size, total_size)
chunk_content = content[start:end]
logger.debug(f"Returning chunk {request.chunk_index}/{total_chunks}, "
f"bytes {start}-{end} of {total_size}")
return LibrarianResponse(
error=None,
content=base64.b64encode(chunk_content),
chunk_index=request.chunk_index,
chunks_received=1, # Using as "current chunk" indicator
total_chunks=total_chunks,
bytes_received=end,
total_bytes=total_size,
)

Librarian service (librarian/service.py):

@ -271,6 +271,9 @@ class Processor(AsyncProcessor):
pass
# Threshold for sending document_id instead of inline content (2MB)
STREAMING_THRESHOLD = 2 * 1024 * 1024
async def load_document(self, document, processing, content):
logger.debug("Ready for document processing...")
@ -292,26 +295,57 @@ class Processor(AsyncProcessor):
q = flow["interfaces"][kind]
if kind == "text-load":
doc = TextDocument(
metadata = Metadata(
id = document.id,
metadata = document.metadata,
user = processing.user,
collection = processing.collection
),
text = content,
)
# For large text documents, send document_id for streaming retrieval
if len(content) >= self.STREAMING_THRESHOLD:
logger.info(f"Text document {document.id} is large ({len(content)} bytes), "
f"sending document_id for streaming retrieval")
doc = TextDocument(
metadata = Metadata(
id = document.id,
metadata = document.metadata,
user = processing.user,
collection = processing.collection
),
document_id = document.id,
text = b"", # Empty, receiver will fetch via librarian
)
else:
doc = TextDocument(
metadata = Metadata(
id = document.id,
metadata = document.metadata,
user = processing.user,
collection = processing.collection
),
text = content,
)
schema = TextDocument
else:
doc = Document(
metadata = Metadata(
id = document.id,
metadata = document.metadata,
user = processing.user,
collection = processing.collection
),
data = base64.b64encode(content).decode("utf-8")
)
# For large PDF documents, send document_id for streaming retrieval
# instead of embedding the entire content in the message
if len(content) >= self.STREAMING_THRESHOLD:
logger.info(f"Document {document.id} is large ({len(content)} bytes), "
f"sending document_id for streaming retrieval")
doc = Document(
metadata = Metadata(
id = document.id,
metadata = document.metadata,
user = processing.user,
collection = processing.collection
),
document_id = document.id,
data = b"", # Empty data, receiver will fetch via API
)
else:
doc = Document(
metadata = Metadata(
id = document.id,
metadata = document.metadata,
user = processing.user,
collection = processing.collection
),
data = base64.b64encode(content).decode("utf-8")
)
schema = Document
logger.debug(f"Submitting to queue {q}...")
@ -361,6 +395,17 @@ class Processor(AsyncProcessor):
"remove-processing": self.librarian.remove_processing,
"list-documents": self.librarian.list_documents,
"list-processing": self.librarian.list_processing,
# Chunked upload operations
"begin-upload": self.librarian.begin_upload,
"upload-chunk": self.librarian.upload_chunk,
"complete-upload": self.librarian.complete_upload,
"abort-upload": self.librarian.abort_upload,
"get-upload-status": self.librarian.get_upload_status,
"list-uploads": self.librarian.list_uploads,
# Child document and streaming operations
"add-child-document": self.librarian.add_child_document,
"list-children": self.librarian.list_children,
"stream-document": self.librarian.stream_document,
}
if v.operation not in impls:

Library table store (trustgraph-flow/trustgraph/tables/library.py):

@ -112,6 +112,34 @@ class LibraryTableStore:
ON document (object_id)
""");
# Add parent_id and document_type columns for child document support
logger.debug("document table parent_id column...")
try:
self.cassandra.execute("""
ALTER TABLE document ADD parent_id text
""");
except Exception as e:
# Column may already exist
if "already exists" not in str(e).lower() and "Invalid column name" not in str(e):
logger.debug(f"parent_id column may already exist: {e}")
try:
self.cassandra.execute("""
ALTER TABLE document ADD document_type text
""");
except Exception as e:
# Column may already exist
if "already exists" not in str(e).lower() and "Invalid column name" not in str(e):
logger.debug(f"document_type column may already exist: {e}")
logger.debug("document parent index...")
self.cassandra.execute("""
CREATE INDEX IF NOT EXISTS document_parent
ON document (parent_id)
""");
logger.debug("processing table...")
self.cassandra.execute("""
@ -127,6 +155,32 @@ class LibraryTableStore:
);
""");
logger.debug("upload_session table...")
self.cassandra.execute("""
CREATE TABLE IF NOT EXISTS upload_session (
upload_id text PRIMARY KEY,
user text,
document_id text,
document_metadata text,
s3_upload_id text,
object_id uuid,
total_size bigint,
chunk_size int,
total_chunks int,
chunks_received map<int, text>,
created_at timestamp,
updated_at timestamp
) WITH default_time_to_live = 86400;
""");
logger.debug("upload_session user index...")
self.cassandra.execute("""
CREATE INDEX IF NOT EXISTS upload_session_user
ON upload_session (user)
""");
logger.info("Cassandra schema OK.")
def prepare_statements(self):
@ -136,9 +190,10 @@ class LibraryTableStore:
(
id, user, time,
kind, title, comments,
metadata, tags, object_id
metadata, tags, object_id,
parent_id, document_type
)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""")
self.update_document_stmt = self.cassandra.prepare("""
@ -149,7 +204,8 @@ class LibraryTableStore:
""")
self.get_document_stmt = self.cassandra.prepare("""
SELECT time, kind, title, comments, metadata, tags, object_id
SELECT time, kind, title, comments, metadata, tags, object_id,
parent_id, document_type
FROM document
WHERE user = ? AND id = ?
""")
@ -168,14 +224,16 @@ class LibraryTableStore:
self.list_document_stmt = self.cassandra.prepare("""
SELECT
id, time, kind, title, comments, metadata, tags, object_id
id, time, kind, title, comments, metadata, tags, object_id,
parent_id, document_type
FROM document
WHERE user = ?
""")
self.list_document_by_tag_stmt = self.cassandra.prepare("""
SELECT
id, time, kind, title, comments, metadata, tags, object_id
id, time, kind, title, comments, metadata, tags, object_id,
parent_id, document_type
FROM document
WHERE user = ? AND tags CONTAINS ?
ALLOW FILTERING
@ -210,6 +268,57 @@ class LibraryTableStore:
WHERE user = ?
""")
# Upload session prepared statements
self.insert_upload_session_stmt = self.cassandra.prepare("""
INSERT INTO upload_session
(
upload_id, user, document_id, document_metadata,
s3_upload_id, object_id, total_size, chunk_size,
total_chunks, chunks_received, created_at, updated_at
)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""")
self.get_upload_session_stmt = self.cassandra.prepare("""
SELECT
upload_id, user, document_id, document_metadata,
s3_upload_id, object_id, total_size, chunk_size,
total_chunks, chunks_received, created_at, updated_at
FROM upload_session
WHERE upload_id = ?
""")
self.update_upload_session_chunk_stmt = self.cassandra.prepare("""
UPDATE upload_session
SET chunks_received = chunks_received + ?,
updated_at = ?
WHERE upload_id = ?
""")
self.delete_upload_session_stmt = self.cassandra.prepare("""
DELETE FROM upload_session
WHERE upload_id = ?
""")
self.list_upload_sessions_stmt = self.cassandra.prepare("""
SELECT
upload_id, document_id, document_metadata,
total_size, chunk_size, total_chunks,
chunks_received, created_at, updated_at
FROM upload_session
WHERE user = ?
""")
# Child document queries
self.list_children_stmt = self.cassandra.prepare("""
SELECT
id, user, time, kind, title, comments, metadata, tags,
object_id, parent_id, document_type
FROM document
WHERE parent_id = ?
ALLOW FILTERING
""")
async def document_exists(self, user, id):
resp = self.cassandra.execute(
@ -236,6 +345,10 @@ class LibraryTableStore:
for v in document.metadata
]
# Get parent_id and document_type from document, defaulting if not set
parent_id = getattr(document, 'parent_id', '') or ''
document_type = getattr(document, 'document_type', 'source') or 'source'
while True:
try:
@ -245,7 +358,8 @@ class LibraryTableStore:
(
document.id, document.user, int(document.time * 1000),
document.kind, document.title, document.comments,
metadata, document.tags, object_id
metadata, document.tags, object_id,
parent_id, document_type
)
)
@ -349,9 +463,58 @@ class LibraryTableStore:
p=tuple_to_term(m[2], m[3]),
o=tuple_to_term(m[4], m[5])
)
for m in row[5]
for m in (row[5] or [])
],
tags = row[6] if row[6] else [],
parent_id = row[8] if row[8] else "",
document_type = row[9] if row[9] else "source",
)
for row in resp
]
logger.debug("Done")
return lst
async def list_children(self, parent_id):
"""List all child documents for a given parent document ID."""
logger.debug(f"List children for parent {parent_id}")
while True:
try:
resp = self.cassandra.execute(
self.list_children_stmt,
(parent_id,)
)
break
except Exception as e:
logger.error("Exception occurred", exc_info=True)
raise e
lst = [
DocumentMetadata(
id = row[0],
user = row[1],
time = int(time.mktime(row[2].timetuple())),
kind = row[3],
title = row[4],
comments = row[5],
metadata = [
Triple(
s=tuple_to_term(m[0], m[1]),
p=tuple_to_term(m[2], m[3]),
o=tuple_to_term(m[4], m[5])
)
for m in (row[6] or [])
],
tags = row[7] if row[7] else [],
parent_id = row[9] if row[9] else "",
document_type = row[10] if row[10] else "source",
)
for row in resp
]
@ -394,9 +557,11 @@ class LibraryTableStore:
p=tuple_to_term(m[2], m[3]),
o=tuple_to_term(m[4], m[5])
)
for m in row[4]
for m in (row[4] or [])
],
tags = row[5] if row[5] else [],
parent_id = row[7] if row[7] else "",
document_type = row[8] if row[8] else "source",
)
logger.debug("Done")
@ -532,3 +697,152 @@ class LibraryTableStore:
logger.debug("Done")
return lst
# Upload session methods
async def create_upload_session(
self,
upload_id,
user,
document_id,
document_metadata,
s3_upload_id,
object_id,
total_size,
chunk_size,
total_chunks,
):
"""Create a new upload session for chunked upload."""
logger.info(f"Creating upload session {upload_id}")
now = int(time.time() * 1000)
while True:
try:
self.cassandra.execute(
self.insert_upload_session_stmt,
(
upload_id, user, document_id, document_metadata,
s3_upload_id, object_id, total_size, chunk_size,
total_chunks, {}, now, now
)
)
break
except Exception as e:
logger.error("Exception occurred", exc_info=True)
raise e
logger.debug("Upload session created")
async def get_upload_session(self, upload_id):
"""Get an upload session by ID."""
logger.debug(f"Get upload session {upload_id}")
while True:
try:
resp = self.cassandra.execute(
self.get_upload_session_stmt,
(upload_id,)
)
break
except Exception as e:
logger.error("Exception occurred", exc_info=True)
raise e
for row in resp:
session = {
"upload_id": row[0],
"user": row[1],
"document_id": row[2],
"document_metadata": row[3],
"s3_upload_id": row[4],
"object_id": row[5],
"total_size": row[6],
"chunk_size": row[7],
"total_chunks": row[8],
"chunks_received": row[9] if row[9] else {},
"created_at": row[10],
"updated_at": row[11],
}
logger.debug("Done")
return session
return None
async def update_upload_session_chunk(self, upload_id, chunk_index, etag):
"""Record a successfully uploaded chunk."""
logger.debug(f"Update upload session {upload_id} chunk {chunk_index}")
now = int(time.time() * 1000)
while True:
try:
self.cassandra.execute(
self.update_upload_session_chunk_stmt,
(
{chunk_index: etag},
now,
upload_id
)
)
break
except Exception as e:
logger.error("Exception occurred", exc_info=True)
raise e
logger.debug("Chunk recorded")
async def delete_upload_session(self, upload_id):
"""Delete an upload session."""
logger.info(f"Deleting upload session {upload_id}")
while True:
try:
self.cassandra.execute(
self.delete_upload_session_stmt,
(upload_id,)
)
break
except Exception as e:
logger.error("Exception occurred", exc_info=True)
raise e
logger.debug("Upload session deleted")
async def list_upload_sessions(self, user):
"""List all upload sessions for a user."""
logger.debug(f"List upload sessions for {user}")
while True:
try:
resp = self.cassandra.execute(
self.list_upload_sessions_stmt,
(user,)
)
break
except Exception as e:
logger.error("Exception occurred", exc_info=True)
raise e
sessions = []
for row in resp:
chunks_received = row[6] if row[6] else {}
sessions.append({
"upload_id": row[0],
"document_id": row[1],
"document_metadata": row[2],
"total_size": row[3],
"chunk_size": row[4],
"total_chunks": row[5],
"chunks_received": len(chunks_received),
"created_at": row[7],
"updated_at": row[8],
})
logger.debug("Done")
return sessions