mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 00:16:23 +02:00
RabbitMQ pub/sub backend with topic exchange architecture (#752)
Adds a RabbitMQ backend as an alternative to Pulsar, selectable via
PUBSUB_BACKEND=rabbitmq. Both backends implement the same PubSubBackend
protocol, so no application code changes are needed to switch.

RabbitMQ topology:
- Single topic exchange per topicspace (e.g. 'tg')
- Routing key derived from queue class and topic name
- Shared consumers: named queue bound to the exchange (competing, round-robin)
- Exclusive consumers: anonymous auto-delete queue (broadcast; each consumer
  gets every message). Used by Subscriber and the config push consumer.
- Thread-local producer connections (pika is not thread-safe)
- Push-based consumption via basic_consume, with process_data_events for
  heartbeat processing

Consumer model changes:
- Consumer class creates one backend consumer per concurrent task (required
  for pika thread safety, harmless for Pulsar)
- Consumer class accepts a consumer_type parameter
- Subscriber passes consumer_type='exclusive' for broadcast semantics
- Config push consumer uses consumer_type='exclusive' so every processor
  instance receives config updates
- handle_one_from_queue receives the consumer as a parameter for correct
  per-connection ack/nack

LibrarianClient:
- New shared client class replacing duplicated librarian request-response
  code across 6+ services (chunking, decoders, RAG, etc.)
- Uses stream-document instead of get-document-content, fetching document
  content in 1MB chunks (avoids broker message size limits)
- Standalone object (self.librarian = LibrarianClient(...)), not a mixin
- get-document-content marked deprecated in the schema and OpenAPI spec

Serialisation:
- Extracted dataclass_to_dict/dict_to_dataclass to a shared serialization.py
  (used by both the Pulsar and RabbitMQ backends)

Librarian queues:
- Changed from the flow class (persistent) back to the request/response
  class, now that stream-document eliminates large single messages
- API upload chunk size reduced from 5MB to 3MB to stay under broker limits
  after base64 encoding

Factory and CLI:
- get_pubsub() handles the 'rabbitmq' backend with RabbitMQ connection params
- add_pubsub_args() includes RabbitMQ options (host, port, credentials)
- add_pubsub_args(standalone=True) defaults to localhost for CLI tools
- init_trustgraph skips Pulsar admin setup for non-Pulsar backends
- tg-dump-queues and tg-monitor-prompts use the backend abstraction
- BaseClient and ConfigClient accept generic pubsub config
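The "thread-local producer connections" point can be sketched in isolation. This is an illustration of the pattern only, not TrustGraph's actual backend code: the `ProducerPool` name is made up, and a plain `object()` stands in for a `pika.BlockingConnection` so the sketch runs without a broker.

```python
import threading

class ProducerPool:
    """Illustrative per-thread producer pool: pika connections are not
    thread-safe, so each thread lazily creates and reuses its own."""

    def __init__(self):
        self._local = threading.local()

    def get(self):
        # First call on this thread: open a connection. In a real backend
        # this would be a pika.BlockingConnection(...); a stand-in object
        # keeps the sketch runnable without a broker.
        if not hasattr(self._local, "conn"):
            self._local.conn = object()
        return self._local.conn
```

Each publishing thread gets a connection of its own, so no channel is ever shared across threads.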
This commit is contained in:
parent 4fb0b4d8e8
commit 24f0190ce7
36 changed files with 1277 additions and 1313 deletions
Makefile (4 changes)
@@ -77,8 +77,8 @@ some-containers:
 	-t ${CONTAINER_BASE}/trustgraph-base:${VERSION} .
 	${DOCKER} build -f containers/Containerfile.flow \
 	-t ${CONTAINER_BASE}/trustgraph-flow:${VERSION} .
-#	${DOCKER} build -f containers/Containerfile.unstructured \
-#	-t ${CONTAINER_BASE}/trustgraph-unstructured:${VERSION} .
+	${DOCKER} build -f containers/Containerfile.unstructured \
+	-t ${CONTAINER_BASE}/trustgraph-unstructured:${VERSION} .
 #	${DOCKER} build -f containers/Containerfile.vertexai \
 #	-t ${CONTAINER_BASE}/trustgraph-vertexai:${VERSION} .
 #	${DOCKER} build -f containers/Containerfile.mcp \
@@ -3,6 +3,9 @@ description: |
   Librarian service request for document library management.
 
   Operations: add-document, remove-document, list-documents,
+  get-document-metadata, stream-document, add-child-document,
+  list-children, begin-upload, upload-chunk, complete-upload,
+  abort-upload, get-upload-status, list-uploads,
   start-processing, stop-processing, list-processing
 required:
   - operation
@@ -13,6 +16,17 @@ properties:
       - add-document
       - remove-document
       - list-documents
+      - get-document-metadata
+      - get-document-content
+      - stream-document
+      - add-child-document
+      - list-children
+      - begin-upload
+      - upload-chunk
+      - complete-upload
+      - abort-upload
+      - get-upload-status
+      - list-uploads
       - start-processing
       - stop-processing
       - list-processing
@@ -21,6 +35,21 @@ properties:
       - `add-document`: Add document to library
       - `remove-document`: Remove document from library
       - `list-documents`: List documents in library
+      - `get-document-metadata`: Get document metadata
+      - `get-document-content`: Get full document content in a single response.
+        **Deprecated** — use `stream-document` instead. Fails for documents
+        exceeding the broker's max message size.
+      - `stream-document`: Stream document content in chunks. Each response
+        includes `chunk_index` and `is_final`. Preferred over `get-document-content`
+        for all document sizes.
+      - `add-child-document`: Add a child document (e.g. page, chunk)
+      - `list-children`: List child documents of a parent
+      - `begin-upload`: Start a chunked upload session
+      - `upload-chunk`: Upload a chunk of data
+      - `complete-upload`: Finalize a chunked upload
+      - `abort-upload`: Cancel a chunked upload
+      - `get-upload-status`: Check upload progress
+      - `list-uploads`: List active upload sessions
       - `start-processing`: Start processing library documents
       - `stop-processing`: Stop library processing
       - `list-processing`: List processing status
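The schema above says each `stream-document` response carries `chunk_index` and `is_final`. A client-side reassembly step might look like this; the `StreamChunk` dataclass and `reassemble` helper are illustrative stand-ins for whatever response objects the real client receives, not part of the schema.

```python
from dataclasses import dataclass

@dataclass
class StreamChunk:
    """Stand-in for one stream-document response."""
    content: bytes
    chunk_index: int
    is_final: bool

def reassemble(responses):
    """Collect streamed chunks, stop at the final one, and join in
    chunk_index order (guards against out-of-order delivery)."""
    parts = {}
    for r in responses:
        parts[r.chunk_index] = r.content
        if r.is_final:
            break
    return b"".join(parts[i] for i in sorted(parts))
```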
@@ -24,8 +24,8 @@ class MockAsyncProcessor:
 class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
     """Test Recursive chunker functionality"""
 
-    @patch('trustgraph.base.chunking_service.Consumer')
-    @patch('trustgraph.base.chunking_service.Producer')
+    @patch('trustgraph.base.librarian_client.Consumer')
+    @patch('trustgraph.base.librarian_client.Producer')
     @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
     def test_processor_initialization_basic(self, mock_producer, mock_consumer):
         """Test basic processor initialization"""
@@ -51,8 +51,8 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
                        if hasattr(spec, 'name') and spec.name in ['chunk-size', 'chunk-overlap']]
         assert len(param_specs) == 2
 
-    @patch('trustgraph.base.chunking_service.Consumer')
-    @patch('trustgraph.base.chunking_service.Producer')
+    @patch('trustgraph.base.librarian_client.Consumer')
+    @patch('trustgraph.base.librarian_client.Producer')
     @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
     async def test_chunk_document_with_chunk_size_override(self, mock_producer, mock_consumer):
         """Test chunk_document with chunk-size parameter override"""
@@ -71,7 +71,7 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
         mock_message = MagicMock()
         mock_consumer = MagicMock()
         mock_flow = MagicMock()
-        mock_flow.side_effect = lambda param: {
+        mock_flow.parameters.get.side_effect = lambda param: {
             "chunk-size": 2000,   # Override chunk size
             "chunk-overlap": None # Use default chunk overlap
         }.get(param)
@@ -85,8 +85,8 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
         assert chunk_size == 2000   # Should use overridden value
         assert chunk_overlap == 100 # Should use default value
 
-    @patch('trustgraph.base.chunking_service.Consumer')
-    @patch('trustgraph.base.chunking_service.Producer')
+    @patch('trustgraph.base.librarian_client.Consumer')
+    @patch('trustgraph.base.librarian_client.Producer')
     @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
     async def test_chunk_document_with_chunk_overlap_override(self, mock_producer, mock_consumer):
         """Test chunk_document with chunk-overlap parameter override"""
@@ -105,7 +105,7 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
         mock_message = MagicMock()
         mock_consumer = MagicMock()
         mock_flow = MagicMock()
-        mock_flow.side_effect = lambda param: {
+        mock_flow.parameters.get.side_effect = lambda param: {
             "chunk-size": None,  # Use default chunk size
             "chunk-overlap": 200 # Override chunk overlap
         }.get(param)
@@ -119,8 +119,8 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
         assert chunk_size == 1000   # Should use default value
         assert chunk_overlap == 200 # Should use overridden value
 
-    @patch('trustgraph.base.chunking_service.Consumer')
-    @patch('trustgraph.base.chunking_service.Producer')
+    @patch('trustgraph.base.librarian_client.Consumer')
+    @patch('trustgraph.base.librarian_client.Producer')
     @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
     async def test_chunk_document_with_both_parameters_override(self, mock_producer, mock_consumer):
         """Test chunk_document with both chunk-size and chunk-overlap overrides"""
@@ -139,7 +139,7 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
         mock_message = MagicMock()
         mock_consumer = MagicMock()
         mock_flow = MagicMock()
-        mock_flow.side_effect = lambda param: {
+        mock_flow.parameters.get.side_effect = lambda param: {
             "chunk-size": 1500,  # Override chunk size
             "chunk-overlap": 150 # Override chunk overlap
         }.get(param)
@@ -153,8 +153,8 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
         assert chunk_size == 1500   # Should use overridden value
         assert chunk_overlap == 150 # Should use overridden value
 
-    @patch('trustgraph.base.chunking_service.Consumer')
-    @patch('trustgraph.base.chunking_service.Producer')
+    @patch('trustgraph.base.librarian_client.Consumer')
+    @patch('trustgraph.base.librarian_client.Producer')
     @patch('trustgraph.chunking.recursive.chunker.RecursiveCharacterTextSplitter')
     @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
     async def test_on_message_uses_flow_parameters(self, mock_splitter_class, mock_producer, mock_consumer):
@@ -177,7 +177,7 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
         processor = Processor(**config)
 
         # Mock save_child_document to avoid waiting for librarian response
-        processor.save_child_document = AsyncMock(return_value="mock-doc-id")
+        processor.librarian.save_child_document = AsyncMock(return_value="mock-doc-id")
 
         # Mock message with TextDocument
         mock_message = MagicMock()
@@ -196,12 +196,14 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
         mock_producer = AsyncMock()
         mock_triples_producer = AsyncMock()
         mock_flow = MagicMock()
-        mock_flow.side_effect = lambda param: {
+        mock_flow.parameters.get.side_effect = lambda param: {
             "chunk-size": 1500,
             "chunk-overlap": 150,
+        }.get(param)
+        mock_flow.side_effect = lambda name: {
             "output": mock_producer,
             "triples": mock_triples_producer,
-        }.get(param)
+        }.get(name)
 
         # Act
         await processor.on_message(mock_message, mock_consumer, mock_flow)
@@ -219,8 +221,8 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
         sent_chunk = mock_producer.send.call_args[0][0]
         assert isinstance(sent_chunk, Chunk)
 
-    @patch('trustgraph.base.chunking_service.Consumer')
-    @patch('trustgraph.base.chunking_service.Producer')
+    @patch('trustgraph.base.librarian_client.Consumer')
+    @patch('trustgraph.base.librarian_client.Producer')
     @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
     async def test_chunk_document_with_no_overrides(self, mock_producer, mock_consumer):
         """Test chunk_document when no parameters are overridden (flow returns None)"""
@@ -239,7 +241,7 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
         mock_message = MagicMock()
         mock_consumer = MagicMock()
         mock_flow = MagicMock()
-        mock_flow.return_value = None  # No overrides
+        mock_flow.parameters.get.return_value = None  # No overrides
 
         # Act
         chunk_size, chunk_overlap = await processor.chunk_document(
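The stubbing pattern these tests rely on — routing a `MagicMock` call through a dict with `side_effect` — works stand-alone; this snippet isolates it outside the Processor, with parameter names borrowed from the tests above.

```python
from unittest.mock import MagicMock

# A MagicMock auto-creates the `parameters.get` attribute chain; assigning
# side_effect makes calls to it go through the dict lookup below.
mock_flow = MagicMock()
mock_flow.parameters.get.side_effect = lambda param: {
    "chunk-size": 2000,    # override
    "chunk-overlap": None  # fall back to the processor's default
}.get(param)

assert mock_flow.parameters.get("chunk-size") == 2000
assert mock_flow.parameters.get("chunk-overlap") is None
assert mock_flow.parameters.get("anything-else") is None  # dict.get default
```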
@@ -24,8 +24,8 @@ class MockAsyncProcessor:
 class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
     """Test Token chunker functionality"""
 
-    @patch('trustgraph.base.chunking_service.Consumer')
-    @patch('trustgraph.base.chunking_service.Producer')
+    @patch('trustgraph.base.librarian_client.Consumer')
+    @patch('trustgraph.base.librarian_client.Producer')
     @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
     def test_processor_initialization_basic(self, mock_producer, mock_consumer):
         """Test basic processor initialization"""
@@ -51,8 +51,8 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
                        if hasattr(spec, 'name') and spec.name in ['chunk-size', 'chunk-overlap']]
         assert len(param_specs) == 2
 
-    @patch('trustgraph.base.chunking_service.Consumer')
-    @patch('trustgraph.base.chunking_service.Producer')
+    @patch('trustgraph.base.librarian_client.Consumer')
+    @patch('trustgraph.base.librarian_client.Producer')
     @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
     async def test_chunk_document_with_chunk_size_override(self, mock_producer, mock_consumer):
         """Test chunk_document with chunk-size parameter override"""
@@ -71,7 +71,7 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
         mock_message = MagicMock()
         mock_consumer = MagicMock()
         mock_flow = MagicMock()
-        mock_flow.side_effect = lambda param: {
+        mock_flow.parameters.get.side_effect = lambda param: {
             "chunk-size": 400,    # Override chunk size
             "chunk-overlap": None # Use default chunk overlap
         }.get(param)
@@ -85,8 +85,8 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
         assert chunk_size == 400   # Should use overridden value
         assert chunk_overlap == 15 # Should use default value
 
-    @patch('trustgraph.base.chunking_service.Consumer')
-    @patch('trustgraph.base.chunking_service.Producer')
+    @patch('trustgraph.base.librarian_client.Consumer')
+    @patch('trustgraph.base.librarian_client.Producer')
     @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
     async def test_chunk_document_with_chunk_overlap_override(self, mock_producer, mock_consumer):
         """Test chunk_document with chunk-overlap parameter override"""
@@ -105,7 +105,7 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
         mock_message = MagicMock()
         mock_consumer = MagicMock()
         mock_flow = MagicMock()
-        mock_flow.side_effect = lambda param: {
+        mock_flow.parameters.get.side_effect = lambda param: {
             "chunk-size": None, # Use default chunk size
             "chunk-overlap": 25 # Override chunk overlap
         }.get(param)
@@ -119,8 +119,8 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
         assert chunk_size == 250   # Should use default value
         assert chunk_overlap == 25 # Should use overridden value
 
-    @patch('trustgraph.base.chunking_service.Consumer')
-    @patch('trustgraph.base.chunking_service.Producer')
+    @patch('trustgraph.base.librarian_client.Consumer')
+    @patch('trustgraph.base.librarian_client.Producer')
     @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
     async def test_chunk_document_with_both_parameters_override(self, mock_producer, mock_consumer):
         """Test chunk_document with both chunk-size and chunk-overlap overrides"""
@@ -139,7 +139,7 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
         mock_message = MagicMock()
         mock_consumer = MagicMock()
         mock_flow = MagicMock()
-        mock_flow.side_effect = lambda param: {
+        mock_flow.parameters.get.side_effect = lambda param: {
             "chunk-size": 350,  # Override chunk size
             "chunk-overlap": 30 # Override chunk overlap
         }.get(param)
@@ -153,8 +153,8 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
         assert chunk_size == 350   # Should use overridden value
         assert chunk_overlap == 30 # Should use overridden value
 
-    @patch('trustgraph.base.chunking_service.Consumer')
-    @patch('trustgraph.base.chunking_service.Producer')
+    @patch('trustgraph.base.librarian_client.Consumer')
+    @patch('trustgraph.base.librarian_client.Producer')
     @patch('trustgraph.chunking.token.chunker.TokenTextSplitter')
     @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
     async def test_on_message_uses_flow_parameters(self, mock_splitter_class, mock_producer, mock_consumer):
@@ -177,7 +177,7 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
         processor = Processor(**config)
 
         # Mock save_child_document to avoid librarian producer interactions
-        processor.save_child_document = AsyncMock(return_value="chunk-id")
+        processor.librarian.save_child_document = AsyncMock(return_value="chunk-id")
 
         # Mock message with TextDocument
         mock_message = MagicMock()
@@ -196,12 +196,14 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
         mock_producer = AsyncMock()
         mock_triples_producer = AsyncMock()
         mock_flow = MagicMock()
-        mock_flow.side_effect = lambda param: {
+        mock_flow.parameters.get.side_effect = lambda param: {
             "chunk-size": 400,
             "chunk-overlap": 40,
+        }.get(param)
+        mock_flow.side_effect = lambda name: {
             "output": mock_producer,
             "triples": mock_triples_producer,
-        }.get(param)
+        }.get(name)
 
         # Act
         await processor.on_message(mock_message, mock_consumer, mock_flow)
@@ -223,8 +225,8 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
         sent_chunk = mock_producer.send.call_args[0][0]
         assert isinstance(sent_chunk, Chunk)
 
-    @patch('trustgraph.base.chunking_service.Consumer')
-    @patch('trustgraph.base.chunking_service.Producer')
+    @patch('trustgraph.base.librarian_client.Consumer')
+    @patch('trustgraph.base.librarian_client.Producer')
     @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
     async def test_chunk_document_with_no_overrides(self, mock_producer, mock_consumer):
         """Test chunk_document when no parameters are overridden (flow returns None)"""
@@ -243,7 +245,7 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
         mock_message = MagicMock()
         mock_consumer = MagicMock()
         mock_flow = MagicMock()
-        mock_flow.return_value = None  # No overrides
+        mock_flow.parameters.get.return_value = None  # No overrides
 
         # Act
         chunk_size, chunk_overlap = await processor.chunk_document(
@@ -254,8 +256,8 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
         assert chunk_size == 250   # Should use default value
         assert chunk_overlap == 15 # Should use default value
 
-    @patch('trustgraph.base.chunking_service.Consumer')
-    @patch('trustgraph.base.chunking_service.Producer')
+    @patch('trustgraph.base.librarian_client.Consumer')
+    @patch('trustgraph.base.librarian_client.Producer')
     @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
     def test_token_chunker_uses_different_defaults(self, mock_producer, mock_consumer):
         """Test that token chunker has different defaults than recursive chunker"""
@@ -83,7 +83,7 @@ class TestTaskGroupConcurrency:
         call_count = 0
         original_running = True
 
-        async def mock_consume():
+        async def mock_consume(backend_consumer):
             nonlocal call_count
             call_count += 1
             # Wait a bit to let all tasks start, then signal stop
@@ -107,7 +107,7 @@ class TestTaskGroupConcurrency:
         consumer = _make_consumer(concurrency=1)
         call_count = 0
 
-        async def mock_consume():
+        async def mock_consume(backend_consumer):
             nonlocal call_count
             call_count += 1
             await asyncio.sleep(0.01)
@@ -147,7 +147,7 @@ class TestRateLimitRetry:
         mock_msg = _make_msg()
         consumer.consumer = MagicMock()
 
-        await consumer.handle_one_from_queue(mock_msg)
+        await consumer.handle_one_from_queue(mock_msg, consumer.consumer)
 
         assert call_count == 2
         consumer.consumer.acknowledge.assert_called_once_with(mock_msg)
@@ -166,7 +166,7 @@ class TestRateLimitRetry:
         mock_msg = _make_msg()
         consumer.consumer = MagicMock()
 
-        await consumer.handle_one_from_queue(mock_msg)
+        await consumer.handle_one_from_queue(mock_msg, consumer.consumer)
 
         consumer.consumer.negative_acknowledge.assert_called_with(mock_msg)
         consumer.consumer.acknowledge.assert_not_called()
@@ -185,7 +185,7 @@ class TestRateLimitRetry:
         mock_msg = _make_msg()
         consumer.consumer = MagicMock()
 
-        await consumer.handle_one_from_queue(mock_msg)
+        await consumer.handle_one_from_queue(mock_msg, consumer.consumer)
 
         assert call_count == 1
         consumer.consumer.negative_acknowledge.assert_called_once_with(mock_msg)
@@ -197,7 +197,7 @@ class TestRateLimitRetry:
         mock_msg = _make_msg()
         consumer.consumer = MagicMock()
 
-        await consumer.handle_one_from_queue(mock_msg)
+        await consumer.handle_one_from_queue(mock_msg, consumer.consumer)
 
         consumer.consumer.acknowledge.assert_called_once_with(mock_msg)
 
@@ -219,7 +219,7 @@ class TestMetricsIntegration:
         mock_metrics.record_time.return_value.__exit__ = MagicMock()
         consumer.metrics = mock_metrics
 
-        await consumer.handle_one_from_queue(mock_msg)
+        await consumer.handle_one_from_queue(mock_msg, consumer.consumer)
 
         mock_metrics.process.assert_called_once_with("success")
 
@@ -235,7 +235,7 @@ class TestMetricsIntegration:
         mock_metrics = MagicMock()
         consumer.metrics = mock_metrics
 
-        await consumer.handle_one_from_queue(mock_msg)
+        await consumer.handle_one_from_queue(mock_msg, consumer.consumer)
 
         mock_metrics.process.assert_called_once_with("error")
 
@@ -261,7 +261,7 @@ class TestMetricsIntegration:
         mock_metrics.record_time.return_value.__exit__ = MagicMock(return_value=False)
         consumer.metrics = mock_metrics
 
-        await consumer.handle_one_from_queue(mock_msg)
+        await consumer.handle_one_from_queue(mock_msg, consumer.consumer)
 
         mock_metrics.rate_limit.assert_called_once()
|
|
||||||
|
|
|
||||||
|
|
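The recurring test change above reflects the new Consumer contract: acks and nacks must be issued on the same backend consumer that delivered the message (required for per-connection correctness with pika), so `handle_one_from_queue` now takes the consumer as a parameter rather than reading it from `self`. A minimal sketch of the pattern, simplified from the real Consumer class (names follow the tests, the surrounding class is hypothetical):

```python
import asyncio
from unittest.mock import MagicMock, AsyncMock

class Consumer:
    """Simplified sketch: ack/nack go to the backend consumer that
    delivered the message, passed in per call."""

    def __init__(self, handler):
        self.handler = handler

    async def handle_one_from_queue(self, msg, consumer):
        try:
            await self.handler(msg, consumer, None)
            consumer.acknowledge(msg)           # success: ack on same connection
        except Exception:
            consumer.negative_acknowledge(msg)  # failure: nack for redelivery

backend_consumer = MagicMock()
msg = MagicMock()

# Success path: handler completes, message is acknowledged
asyncio.run(Consumer(AsyncMock()).handle_one_from_queue(msg, backend_consumer))
backend_consumer.acknowledge.assert_called_once_with(msg)

# Failure path: handler raises, message is negatively acknowledged
failing = Consumer(AsyncMock(side_effect=RuntimeError("boom")))
asyncio.run(failing.handle_one_from_queue(msg, backend_consumer))
backend_consumer.negative_acknowledge.assert_called_once_with(msg)
```

With one backend consumer per concurrent task, passing the consumer explicitly guarantees the ack lands on the connection that owns the delivery tag.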
@@ -25,8 +25,8 @@ class MockAsyncProcessor:
 class TestMistralOcrProcessor(IsolatedAsyncioTestCase):
     """Test Mistral OCR processor functionality"""

-    @patch('trustgraph.decoding.mistral_ocr.processor.Consumer')
-    @patch('trustgraph.decoding.mistral_ocr.processor.Producer')
+    @patch('trustgraph.base.librarian_client.Consumer')
+    @patch('trustgraph.base.librarian_client.Producer')
     @patch('trustgraph.decoding.mistral_ocr.processor.Mistral')
     @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
     async def test_processor_initialization_with_api_key(

@@ -51,8 +51,8 @@ class TestMistralOcrProcessor(IsolatedAsyncioTestCase):
         assert consumer_specs[0].name == "input"
         assert consumer_specs[0].schema == Document

-    @patch('trustgraph.decoding.mistral_ocr.processor.Consumer')
-    @patch('trustgraph.decoding.mistral_ocr.processor.Producer')
+    @patch('trustgraph.base.librarian_client.Consumer')
+    @patch('trustgraph.base.librarian_client.Producer')
     @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
     async def test_processor_initialization_without_api_key(
         self, mock_producer, mock_consumer

@@ -66,8 +66,8 @@ class TestMistralOcrProcessor(IsolatedAsyncioTestCase):
         with pytest.raises(RuntimeError, match="Mistral API key not specified"):
             Processor(**config)

-    @patch('trustgraph.decoding.mistral_ocr.processor.Consumer')
-    @patch('trustgraph.decoding.mistral_ocr.processor.Producer')
+    @patch('trustgraph.base.librarian_client.Consumer')
+    @patch('trustgraph.base.librarian_client.Producer')
     @patch('trustgraph.decoding.mistral_ocr.processor.Mistral')
     @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
     async def test_ocr_single_chunk(

@@ -131,8 +131,8 @@ class TestMistralOcrProcessor(IsolatedAsyncioTestCase):
         )
         mock_mistral.ocr.process.assert_called_once()

-    @patch('trustgraph.decoding.mistral_ocr.processor.Consumer')
-    @patch('trustgraph.decoding.mistral_ocr.processor.Producer')
+    @patch('trustgraph.base.librarian_client.Consumer')
+    @patch('trustgraph.base.librarian_client.Producer')
     @patch('trustgraph.decoding.mistral_ocr.processor.Mistral')
     @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
     async def test_on_message_success(

@@ -172,7 +172,7 @@ class TestMistralOcrProcessor(IsolatedAsyncioTestCase):
         ]

         # Mock save_child_document
-        processor.save_child_document = AsyncMock(return_value="mock-doc-id")
+        processor.librarian.save_child_document = AsyncMock(return_value="mock-doc-id")

         with patch.object(processor, 'ocr', return_value=ocr_result):
             await processor.on_message(mock_msg, None, mock_flow)
@@ -24,12 +24,10 @@ class MockAsyncProcessor:
 class TestPdfDecoderProcessor(IsolatedAsyncioTestCase):
     """Test PDF decoder processor functionality"""

-    @patch('trustgraph.base.chunking_service.Consumer')
-    @patch('trustgraph.base.chunking_service.Producer')
-    @patch('trustgraph.decoding.pdf.pdf_decoder.Consumer')
-    @patch('trustgraph.decoding.pdf.pdf_decoder.Producer')
+    @patch('trustgraph.base.librarian_client.Consumer')
+    @patch('trustgraph.base.librarian_client.Producer')
     @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
-    async def test_processor_initialization(self, mock_producer, mock_consumer, mock_cs_producer, mock_cs_consumer):
+    async def test_processor_initialization(self, mock_producer, mock_consumer):
         """Test PDF decoder processor initialization"""
         config = {
             'id': 'test-pdf-decoder',

@@ -44,13 +42,11 @@ class TestPdfDecoderProcessor(IsolatedAsyncioTestCase):
         assert consumer_specs[0].name == "input"
         assert consumer_specs[0].schema == Document

-    @patch('trustgraph.base.chunking_service.Consumer')
-    @patch('trustgraph.base.chunking_service.Producer')
-    @patch('trustgraph.decoding.pdf.pdf_decoder.Consumer')
-    @patch('trustgraph.decoding.pdf.pdf_decoder.Producer')
+    @patch('trustgraph.base.librarian_client.Consumer')
+    @patch('trustgraph.base.librarian_client.Producer')
     @patch('trustgraph.decoding.pdf.pdf_decoder.PyPDFLoader')
     @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
-    async def test_on_message_success(self, mock_pdf_loader_class, mock_producer, mock_consumer, mock_cs_producer, mock_cs_consumer):
+    async def test_on_message_success(self, mock_pdf_loader_class, mock_producer, mock_consumer):
         """Test successful PDF processing"""
         # Mock PDF content
         pdf_content = b"fake pdf content"

@@ -85,7 +81,7 @@ class TestPdfDecoderProcessor(IsolatedAsyncioTestCase):
         processor = Processor(**config)

         # Mock save_child_document to avoid waiting for librarian response
-        processor.save_child_document = AsyncMock(return_value="mock-doc-id")
+        processor.librarian.save_child_document = AsyncMock(return_value="mock-doc-id")

         await processor.on_message(mock_msg, None, mock_flow)

@@ -94,13 +90,11 @@ class TestPdfDecoderProcessor(IsolatedAsyncioTestCase):
         # Verify triples were sent for each page (provenance)
         assert mock_triples_flow.send.call_count == 2

-    @patch('trustgraph.base.chunking_service.Consumer')
-    @patch('trustgraph.base.chunking_service.Producer')
-    @patch('trustgraph.decoding.pdf.pdf_decoder.Consumer')
-    @patch('trustgraph.decoding.pdf.pdf_decoder.Producer')
+    @patch('trustgraph.base.librarian_client.Consumer')
+    @patch('trustgraph.base.librarian_client.Producer')
     @patch('trustgraph.decoding.pdf.pdf_decoder.PyPDFLoader')
     @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
-    async def test_on_message_empty_pdf(self, mock_pdf_loader_class, mock_producer, mock_consumer, mock_cs_producer, mock_cs_consumer):
+    async def test_on_message_empty_pdf(self, mock_pdf_loader_class, mock_producer, mock_consumer):
         """Test handling of empty PDF"""
         pdf_content = b"fake pdf content"
         pdf_base64 = base64.b64encode(pdf_content).decode('utf-8')

@@ -128,13 +122,11 @@ class TestPdfDecoderProcessor(IsolatedAsyncioTestCase):

         mock_output_flow.send.assert_not_called()

-    @patch('trustgraph.base.chunking_service.Consumer')
-    @patch('trustgraph.base.chunking_service.Producer')
-    @patch('trustgraph.decoding.pdf.pdf_decoder.Consumer')
-    @patch('trustgraph.decoding.pdf.pdf_decoder.Producer')
+    @patch('trustgraph.base.librarian_client.Consumer')
+    @patch('trustgraph.base.librarian_client.Producer')
     @patch('trustgraph.decoding.pdf.pdf_decoder.PyPDFLoader')
     @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
-    async def test_on_message_unicode_content(self, mock_pdf_loader_class, mock_producer, mock_consumer, mock_cs_producer, mock_cs_consumer):
+    async def test_on_message_unicode_content(self, mock_pdf_loader_class, mock_producer, mock_consumer):
         """Test handling of unicode content in PDF"""
         pdf_content = b"fake pdf content"
         pdf_base64 = base64.b64encode(pdf_content).decode('utf-8')

@@ -165,7 +157,7 @@ class TestPdfDecoderProcessor(IsolatedAsyncioTestCase):
         processor = Processor(**config)

         # Mock save_child_document to avoid waiting for librarian response
-        processor.save_child_document = AsyncMock(return_value="mock-doc-id")
+        processor.librarian.save_child_document = AsyncMock(return_value="mock-doc-id")

         await processor.on_message(mock_msg, None, mock_flow)
@@ -142,8 +142,8 @@ class TestPageBasedFormats:
 class TestUniversalProcessor(IsolatedAsyncioTestCase):
     """Test universal decoder processor."""

-    @patch('trustgraph.decoding.universal.processor.Consumer')
-    @patch('trustgraph.decoding.universal.processor.Producer')
+    @patch('trustgraph.base.librarian_client.Consumer')
+    @patch('trustgraph.base.librarian_client.Producer')
     @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
     async def test_processor_initialization(
         self, mock_producer, mock_consumer

@@ -169,8 +169,8 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
         assert consumer_specs[0].name == "input"
         assert consumer_specs[0].schema == Document

-    @patch('trustgraph.decoding.universal.processor.Consumer')
-    @patch('trustgraph.decoding.universal.processor.Producer')
+    @patch('trustgraph.base.librarian_client.Consumer')
+    @patch('trustgraph.base.librarian_client.Producer')
     @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
     async def test_processor_custom_strategy(
         self, mock_producer, mock_consumer

@@ -188,8 +188,8 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
         assert processor.partition_strategy == "hi_res"
         assert processor.section_strategy_name == "heading"

-    @patch('trustgraph.decoding.universal.processor.Consumer')
-    @patch('trustgraph.decoding.universal.processor.Producer')
+    @patch('trustgraph.base.librarian_client.Consumer')
+    @patch('trustgraph.base.librarian_client.Producer')
     @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
     async def test_group_by_page(self, mock_producer, mock_consumer):
         """Test page grouping of elements."""

@@ -214,8 +214,8 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
         assert result[1][0] == 2
         assert len(result[1][1]) == 1

-    @patch('trustgraph.decoding.universal.processor.Consumer')
-    @patch('trustgraph.decoding.universal.processor.Producer')
+    @patch('trustgraph.base.librarian_client.Consumer')
+    @patch('trustgraph.base.librarian_client.Producer')
     @patch('trustgraph.decoding.universal.processor.partition')
     @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
     async def test_on_message_inline_non_page(

@@ -255,7 +255,7 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
         }.get(name))

         # Mock save_child_document and magic
-        processor.save_child_document = AsyncMock(return_value="mock-id")
+        processor.librarian.save_child_document = AsyncMock(return_value="mock-id")

         with patch('trustgraph.decoding.universal.processor.magic') as mock_magic:
             mock_magic.from_buffer.return_value = "text/markdown"

@@ -271,8 +271,8 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
         assert call_args.document_id.startswith("urn:section:")
         assert call_args.text == b""

-    @patch('trustgraph.decoding.universal.processor.Consumer')
-    @patch('trustgraph.decoding.universal.processor.Producer')
+    @patch('trustgraph.base.librarian_client.Consumer')
+    @patch('trustgraph.base.librarian_client.Producer')
     @patch('trustgraph.decoding.universal.processor.partition')
     @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
     async def test_on_message_page_based(

@@ -310,7 +310,7 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
             "triples": mock_triples_flow,
         }.get(name))

-        processor.save_child_document = AsyncMock(return_value="mock-id")
+        processor.librarian.save_child_document = AsyncMock(return_value="mock-id")

         with patch('trustgraph.decoding.universal.processor.magic') as mock_magic:
             mock_magic.from_buffer.return_value = "application/pdf"

@@ -323,8 +323,8 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
         call_args = mock_output_flow.send.call_args_list[0][0][0]
         assert call_args.document_id.startswith("urn:page:")

-    @patch('trustgraph.decoding.universal.processor.Consumer')
-    @patch('trustgraph.decoding.universal.processor.Producer')
+    @patch('trustgraph.base.librarian_client.Consumer')
+    @patch('trustgraph.base.librarian_client.Producer')
     @patch('trustgraph.decoding.universal.processor.partition')
     @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
     async def test_images_stored_not_emitted(

@@ -361,7 +361,7 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
             "triples": mock_triples_flow,
         }.get(name))

-        processor.save_child_document = AsyncMock(return_value="mock-id")
+        processor.librarian.save_child_document = AsyncMock(return_value="mock-id")

         with patch('trustgraph.decoding.universal.processor.magic') as mock_magic:
             mock_magic.from_buffer.return_value = "application/pdf"

@@ -374,7 +374,7 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
         assert mock_triples_flow.send.call_count == 2

         # save_child_document called twice (page + image)
-        assert processor.save_child_document.call_count == 2
+        assert processor.librarian.save_child_document.call_count == 2

     @patch('trustgraph.base.flow_processor.FlowProcessor.add_args')
     def test_add_args(self, mock_parent_add_args):
@@ -109,6 +109,37 @@ class TestAddPubsubArgs:
         assert args.pubsub_backend == 'pulsar'


+class TestAddPubsubArgsRabbitMQ:
+
+    def test_rabbitmq_args_present(self):
+        parser = argparse.ArgumentParser()
+        add_pubsub_args(parser)
+        args = parser.parse_args([
+            '--pubsub-backend', 'rabbitmq',
+            '--rabbitmq-host', 'myhost',
+            '--rabbitmq-port', '5673',
+        ])
+        assert args.pubsub_backend == 'rabbitmq'
+        assert args.rabbitmq_host == 'myhost'
+        assert args.rabbitmq_port == 5673
+
+    def test_rabbitmq_defaults_container(self):
+        parser = argparse.ArgumentParser()
+        add_pubsub_args(parser)
+        args = parser.parse_args([])
+        assert args.rabbitmq_host == 'rabbitmq'
+        assert args.rabbitmq_port == 5672
+        assert args.rabbitmq_username == 'guest'
+        assert args.rabbitmq_password == 'guest'
+        assert args.rabbitmq_vhost == '/'
+
+    def test_rabbitmq_standalone_defaults_to_localhost(self):
+        parser = argparse.ArgumentParser()
+        add_pubsub_args(parser, standalone=True)
+        args = parser.parse_args([])
+        assert args.rabbitmq_host == 'localhost'
+
+
 class TestQueueDefinitions:
     """Verify the actual queue constants produce correct names."""

@@ -124,9 +155,9 @@ class TestQueueDefinitions:
         from trustgraph.schema.services.config import config_push_queue
         assert config_push_queue == 'state:tg:config'

-    def test_librarian_request_is_persistent(self):
+    def test_librarian_request(self):
         from trustgraph.schema.services.library import librarian_request_queue
-        assert librarian_request_queue.startswith('flow:')
+        assert librarian_request_queue == 'request:tg:librarian'

     def test_knowledge_request(self):
         from trustgraph.schema.knowledge.knowledge import knowledge_request_queue
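The argument tests above pin down the expected CLI surface. A hypothetical sketch of an `add_pubsub_args` consistent with those tests follows; the container default resolves the `rabbitmq` service hostname, while `standalone=True` (for CLI tools) defaults to localhost. This is an illustration of the behaviour the tests require, not the actual TrustGraph implementation:

```python
import argparse

def add_pubsub_args(parser, standalone=False):
    # Container deployments resolve the 'rabbitmq' service hostname;
    # standalone CLI tools default to localhost (assumption from the tests)
    default_host = 'localhost' if standalone else 'rabbitmq'
    parser.add_argument('--pubsub-backend', default='pulsar',
                        choices=['pulsar', 'rabbitmq'])
    parser.add_argument('--rabbitmq-host', default=default_host)
    parser.add_argument('--rabbitmq-port', type=int, default=5672)
    parser.add_argument('--rabbitmq-username', default='guest')
    parser.add_argument('--rabbitmq-password', default='guest')
    parser.add_argument('--rabbitmq-vhost', default='/')

# Container defaults
container = argparse.ArgumentParser()
add_pubsub_args(container)
container_args = container.parse_args([])

# Standalone (CLI tool) defaults
standalone_parser = argparse.ArgumentParser()
add_pubsub_args(standalone_parser, standalone=True)
standalone_args = standalone_parser.parse_args([])
```

Note that `type=int` on the port is what makes `parse_args(['--rabbitmq-port', '5673'])` yield the integer 5673 asserted in the tests.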
tests/unit/test_pubsub/test_rabbitmq_backend.py (new file, 107 lines)

@@ -0,0 +1,107 @@
+"""
+Unit tests for RabbitMQ backend — queue name mapping and factory dispatch.
+Does not require a running RabbitMQ instance.
+"""
+
+import pytest
+import argparse
+
+pika = pytest.importorskip("pika", reason="pika not installed")
+
+from trustgraph.base.rabbitmq_backend import RabbitMQBackend
+from trustgraph.base.pubsub import get_pubsub, add_pubsub_args
+
+
+class TestRabbitMQMapQueueName:
+
+    @pytest.fixture
+    def backend(self):
+        b = object.__new__(RabbitMQBackend)
+        return b
+
+    def test_flow_is_durable(self, backend):
+        name, durable = backend.map_queue_name('flow:tg:text-completion-request')
+        assert durable is True
+        assert name == 'tg.flow.text-completion-request'
+
+    def test_state_is_durable(self, backend):
+        name, durable = backend.map_queue_name('state:tg:config')
+        assert durable is True
+        assert name == 'tg.state.config'
+
+    def test_request_is_not_durable(self, backend):
+        name, durable = backend.map_queue_name('request:tg:config')
+        assert durable is False
+        assert name == 'tg.request.config'
+
+    def test_response_is_not_durable(self, backend):
+        name, durable = backend.map_queue_name('response:tg:librarian')
+        assert durable is False
+        assert name == 'tg.response.librarian'
+
+    def test_custom_topicspace(self, backend):
+        name, durable = backend.map_queue_name('flow:prod:my-queue')
+        assert name == 'prod.flow.my-queue'
+        assert durable is True
+
+    def test_no_colon_defaults_to_flow(self, backend):
+        name, durable = backend.map_queue_name('simple-queue')
+        assert name == 'tg.simple-queue'
+        assert durable is False
+
+    def test_invalid_class_raises(self, backend):
+        with pytest.raises(ValueError, match="Invalid queue class"):
+            backend.map_queue_name('unknown:tg:topic')
+
+    def test_flow_with_flow_suffix(self, backend):
+        """Queue names with flow suffix (e.g. :default) are preserved."""
+        name, durable = backend.map_queue_name('request:tg:prompt:default')
+        assert name == 'tg.request.prompt:default'
+
+
+class TestGetPubsubRabbitMQ:
+
+    def test_factory_creates_rabbitmq_backend(self):
+        backend = get_pubsub(pubsub_backend='rabbitmq')
+        assert isinstance(backend, RabbitMQBackend)
+
+    def test_factory_passes_config(self):
+        backend = get_pubsub(
+            pubsub_backend='rabbitmq',
+            rabbitmq_host='myhost',
+            rabbitmq_port=5673,
+            rabbitmq_username='user',
+            rabbitmq_password='pass',
+            rabbitmq_vhost='/test',
+        )
+        assert isinstance(backend, RabbitMQBackend)
+        # Verify connection params were set
+        params = backend._connection_params
+        assert params.host == 'myhost'
+        assert params.port == 5673
+        assert params.virtual_host == '/test'
+
+
+class TestAddPubsubArgsRabbitMQ:
+
+    def test_rabbitmq_args_present(self):
+        parser = argparse.ArgumentParser()
+        add_pubsub_args(parser)
+        args = parser.parse_args([
+            '--pubsub-backend', 'rabbitmq',
+            '--rabbitmq-host', 'myhost',
+            '--rabbitmq-port', '5673',
+        ])
+        assert args.pubsub_backend == 'rabbitmq'
+        assert args.rabbitmq_host == 'myhost'
+        assert args.rabbitmq_port == 5673
+
+    def test_rabbitmq_defaults_container(self):
+        parser = argparse.ArgumentParser()
+        add_pubsub_args(parser)
+        args = parser.parse_args([])
+        assert args.rabbitmq_host == 'rabbitmq'
+        assert args.rabbitmq_port == 5672
+        assert args.rabbitmq_username == 'guest'
+        assert args.rabbitmq_password == 'guest'
+        assert args.rabbitmq_vhost == '/'
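The `map_queue_name` tests above fully determine the mapping's observable behaviour: `flow` and `state` queues are durable, `request` and `response` are transient, the topicspace becomes the routing-key prefix, and anything after the third colon (a flow suffix like `:default`) is preserved. A reference sketch consistent with those tests (not the actual RabbitMQBackend code) looks like this:

```python
# Queue classes the tests distinguish: durable survive broker restart,
# transient are for ephemeral request/response traffic.
DURABLE = {'flow', 'state'}
TRANSIENT = {'request', 'response'}

def map_queue_name(queue):
    """Map 'class:topicspace:topic' to ('topicspace.class.topic', durable)."""
    parts = queue.split(':', 2)   # at most 3 parts; flow suffixes stay in topic
    if len(parts) < 3:
        # No class/topicspace prefix: default topicspace 'tg', non-durable
        return f'tg.{queue}', False
    qclass, space, topic = parts
    if qclass not in DURABLE | TRANSIENT:
        raise ValueError(f'Invalid queue class: {qclass}')
    return f'{space}.{qclass}.{topic}', qclass in DURABLE
```

Because `split(':', 2)` stops after two splits, `request:tg:prompt:default` keeps `prompt:default` intact as the topic, matching the flow-suffix test.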
@@ -14,6 +14,7 @@ dependencies = [
     "prometheus-client",
     "requests",
     "python-logging-loki",
+    "pika",
 ]
 classifiers = [
     "Programming Language :: Python :: 3",
@@ -22,8 +22,9 @@ logger = logging.getLogger(__name__)
 # Lower threshold provides progress feedback and resumability on slower connections
 CHUNKED_UPLOAD_THRESHOLD = 2 * 1024 * 1024

-# Default chunk size (5MB - S3 multipart minimum)
-DEFAULT_CHUNK_SIZE = 5 * 1024 * 1024
+# Default chunk size (3MB - stays under broker message size limits
+# after base64 encoding ~4MB)
+DEFAULT_CHUNK_SIZE = 3 * 1024 * 1024


 def to_value(x):
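The 3 MB choice above follows from base64's fixed 4/3 expansion: a 3 MiB chunk encodes to exactly 4 MiB, while the old 5 MiB chunk would encode to roughly 6.7 MiB. The quick arithmetic (the precise broker message limit is an assumption of roughly 4 to 5 MiB, per the change description):

```python
import math

chunk = 3 * 1024 * 1024                 # new DEFAULT_CHUNK_SIZE
encoded = math.ceil(chunk / 3) * 4      # base64: 4 output bytes per 3 input
# 3 MiB encodes to exactly 4 MiB, staying under the assumed broker limit

old = 5 * 1024 * 1024                   # previous chunk size
old_encoded = math.ceil(old / 3) * 4    # about 6.67 MiB after encoding
```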
@@ -14,6 +14,7 @@ from . producer_spec import ProducerSpec
 from . subscriber_spec import SubscriberSpec
 from . request_response_spec import RequestResponseSpec
 from . llm_service import LlmService, LlmResult, LlmChunk
+from . librarian_client import LibrarianClient
 from . chunking_service import ChunkingService
 from . embeddings_service import EmbeddingsService
 from . embeddings_client import EmbeddingsClientSpec
@@ -68,11 +68,12 @@ class AsyncProcessor:
             processor = self.id, flow = None, name = "config",
         )

-        # Subscribe to config queue
+        # Subscribe to config queue — exclusive so every processor
+        # gets its own copy of config pushes (broadcast pattern)
         self.config_sub_task = Consumer(
             taskgroup = self.taskgroup,
-            backend = self.pubsub_backend,  # Changed from client to backend
+            backend = self.pubsub_backend,
             subscriber = config_subscriber_id,
             flow = None,

@@ -83,9 +84,8 @@ class AsyncProcessor:
             metrics = config_consumer_metrics,

-            # This causes new subscriptions to view the entire history of
-            # configuration
-            start_of_messages = True
+            start_of_messages = True,
+            consumer_type = 'exclusive',
         )

         self.running = True
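The `consumer_type = 'exclusive'` setting above maps onto the RabbitMQ topology this change introduces: shared consumers bind a named queue on which instances compete round-robin, whereas exclusive consumers each declare an anonymous auto-delete queue, so every processor instance receives every config push. A hypothetical sketch of that declaration choice (illustrative, not the actual backend code; the pika `queue_declare` keywords `queue`, `exclusive`, `auto_delete`, and `durable` are real):

```python
def queue_declare_args(consumer_type, subscriber, mapped_name):
    """Pick queue_declare arguments for the two consumer models."""
    if consumer_type == 'exclusive':
        # Broadcast: broker-named, auto-delete queue per consumer; each
        # binding to the topic exchange gets its own copy of every message
        return {'queue': '', 'exclusive': True, 'auto_delete': True}
    # Shared: all consumers using the same subscriber name attach to one
    # named queue and compete for messages (round-robin)
    return {'queue': f'{mapped_name}.{subscriber}', 'durable': True}

broadcast = queue_declare_args('exclusive', 'proc-1', 'tg.state.config')
shared = queue_declare_args('shared', 'workers', 'tg.flow.chunk')
```

An empty `queue` name asks the broker to generate one, which is what makes the exclusive queues anonymous.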
@@ -7,23 +7,14 @@ fetching large document content.
 import asyncio
 import base64
 import logging
-import uuid
 
 from .flow_processor import FlowProcessor
 from .parameter_spec import ParameterSpec
-from .consumer import Consumer
-from .producer import Producer
-from .metrics import ConsumerMetrics, ProducerMetrics
-
-from ..schema import LibrarianRequest, LibrarianResponse, DocumentMetadata
-from ..schema import librarian_request_queue, librarian_response_queue
+from .librarian_client import LibrarianClient
 
 # Module logger
 logger = logging.getLogger(__name__)
 
-default_librarian_request_queue = librarian_request_queue
-default_librarian_response_queue = librarian_response_queue
-
 
 class ChunkingService(FlowProcessor):
     """Base service for chunking processors with parameter specification support"""
@@ -44,155 +35,18 @@ class ChunkingService(FlowProcessor):
             ParameterSpec(name="chunk-overlap")
         )
 
-        # Librarian client for fetching document content
-        librarian_request_q = params.get(
-            "librarian_request_queue", default_librarian_request_queue
-        )
-        librarian_response_q = params.get(
-            "librarian_response_queue", default_librarian_response_queue
-        )
-
-        librarian_request_metrics = ProducerMetrics(
-            processor=id, flow=None, name="librarian-request"
-        )
-
-        self.librarian_request_producer = Producer(
-            backend=self.pubsub,
-            topic=librarian_request_q,
-            schema=LibrarianRequest,
-            metrics=librarian_request_metrics,
-        )
-
-        librarian_response_metrics = ConsumerMetrics(
-            processor=id, flow=None, name="librarian-response"
-        )
-
-        self.librarian_response_consumer = Consumer(
-            taskgroup=self.taskgroup,
-            backend=self.pubsub,
-            flow=None,
-            topic=librarian_response_q,
-            subscriber=f"{id}-librarian",
-            schema=LibrarianResponse,
-            handler=self.on_librarian_response,
-            metrics=librarian_response_metrics,
-        )
-
-        # Pending librarian requests: request_id -> asyncio.Future
-        self.pending_requests = {}
+        # Librarian client
+        self.librarian = LibrarianClient(
+            id=id,
+            backend=self.pubsub,
+            taskgroup=self.taskgroup,
+        )
 
         logger.debug("ChunkingService initialized with parameter specifications")
 
     async def start(self):
         await super(ChunkingService, self).start()
-        await self.librarian_request_producer.start()
-        await self.librarian_response_consumer.start()
-
-    async def on_librarian_response(self, msg, consumer, flow):
-        """Handle responses from the librarian service."""
-        response = msg.value()
-        request_id = msg.properties().get("id")
-
-        if request_id and request_id in self.pending_requests:
-            future = self.pending_requests.pop(request_id)
-            future.set_result(response)
-
-    async def fetch_document_content(self, document_id, user, timeout=120):
-        """
-        Fetch document content from librarian via Pulsar.
-        """
-        request_id = str(uuid.uuid4())
-
-        request = LibrarianRequest(
-            operation="get-document-content",
-            document_id=document_id,
-            user=user,
-        )
-
-        # Create future for response
-        future = asyncio.get_event_loop().create_future()
-        self.pending_requests[request_id] = future
-
-        try:
-            # Send request
-            await self.librarian_request_producer.send(
-                request, properties={"id": request_id}
-            )
-
-            # Wait for response
-            response = await asyncio.wait_for(future, timeout=timeout)
-
-            if response.error:
-                raise RuntimeError(
-                    f"Librarian error: {response.error.type}: {response.error.message}"
-                )
-
-            return response.content
-
-        except asyncio.TimeoutError:
-            self.pending_requests.pop(request_id, None)
-            raise RuntimeError(f"Timeout fetching document {document_id}")
-
-    async def save_child_document(self, doc_id, parent_id, user, content,
-                                  document_type="chunk", title=None, timeout=120):
-        """
-        Save a child document (chunk) to the librarian.
-
-        Args:
-            doc_id: ID for the new child document
-            parent_id: ID of the parent document
-            user: User ID
-            content: Document content (bytes or str)
-            document_type: Type of document ("chunk", etc.)
-            title: Optional title
-            timeout: Request timeout in seconds
-
-        Returns:
-            The document ID on success
-        """
-        request_id = str(uuid.uuid4())
-
-        if isinstance(content, str):
-            content = content.encode("utf-8")
-
-        doc_metadata = DocumentMetadata(
-            id=doc_id,
-            user=user,
-            kind="text/plain",
-            title=title or doc_id,
-            parent_id=parent_id,
-            document_type=document_type,
-        )
-
-        request = LibrarianRequest(
-            operation="add-child-document",
-            document_metadata=doc_metadata,
-            content=base64.b64encode(content).decode("utf-8"),
-        )
-
-        # Create future for response
-        future = asyncio.get_event_loop().create_future()
-        self.pending_requests[request_id] = future
-
-        try:
-            # Send request
-            await self.librarian_request_producer.send(
-                request, properties={"id": request_id}
-            )
-
-            # Wait for response
-            response = await asyncio.wait_for(future, timeout=timeout)
-
-            if response.error:
-                raise RuntimeError(
-                    f"Librarian error saving chunk: {response.error.type}: {response.error.message}"
-                )
-
-            return doc_id
-
-        except asyncio.TimeoutError:
-            self.pending_requests.pop(request_id, None)
-            raise RuntimeError(f"Timeout saving chunk {doc_id}")
+        await self.librarian.start()
 
     async def get_document_text(self, doc):
         """
@@ -206,14 +60,10 @@ class ChunkingService(FlowProcessor):
         """
         if doc.document_id and not doc.text:
            logger.info(f"Fetching document {doc.document_id} from librarian...")
-            content = await self.fetch_document_content(
+            text = await self.librarian.fetch_document_text(
                 document_id=doc.document_id,
                 user=doc.metadata.user,
             )
-            # Content is base64 encoded
-            if isinstance(content, str):
-                content = content.encode('utf-8')
-            text = base64.b64decode(content).decode("utf-8")
             logger.info(f"Fetched {len(text)} characters from librarian")
             return text
         else:
@@ -224,41 +74,31 @@ class ChunkingService(FlowProcessor):
         Extract chunk parameters from flow and return effective values
 
         Args:
-            msg: The message containing the document to chunk
-            consumer: The consumer spec
-            flow: The flow context
-            default_chunk_size: Default chunk size from processor config
-            default_chunk_overlap: Default chunk overlap from processor config
+            msg: The message being processed
+            consumer: The consumer instance
+            flow: The flow object containing parameters
+            default_chunk_size: Default chunk size if not configured
+            default_chunk_overlap: Default chunk overlap if not configured
 
         Returns:
-            tuple: (chunk_size, chunk_overlap) - effective values to use
+            tuple: (chunk_size, chunk_overlap) effective values
         """
-        # Extract parameters from flow (flow-configurable parameters)
-        chunk_size = flow("chunk-size")
-        chunk_overlap = flow("chunk-overlap")
-
-        # Use provided values or fall back to defaults
-        effective_chunk_size = chunk_size if chunk_size is not None else default_chunk_size
-        effective_chunk_overlap = chunk_overlap if chunk_overlap is not None else default_chunk_overlap
-
-        logger.debug(f"Using chunk-size: {effective_chunk_size}")
-        logger.debug(f"Using chunk-overlap: {effective_chunk_overlap}")
-
-        return effective_chunk_size, effective_chunk_overlap
-
-    @staticmethod
-    def add_args(parser):
-        """Add chunking service arguments to parser"""
-        FlowProcessor.add_args(parser)
-
-        parser.add_argument(
-            '--librarian-request-queue',
-            default=default_librarian_request_queue,
-            help=f'Librarian request queue (default: {default_librarian_request_queue})',
-        )
-
-        parser.add_argument(
-            '--librarian-response-queue',
-            default=default_librarian_response_queue,
-            help=f'Librarian response queue (default: {default_librarian_response_queue})',
-        )
+        chunk_size = default_chunk_size
+        chunk_overlap = default_chunk_overlap
+
+        try:
+            cs = flow.parameters.get("chunk-size")
+            if cs is not None:
+                chunk_size = int(cs)
+        except Exception as e:
+            logger.warning(f"Could not parse chunk-size parameter: {e}")
+
+        try:
+            co = flow.parameters.get("chunk-overlap")
+            if co is not None:
+                chunk_overlap = int(co)
+        except Exception as e:
+            logger.warning(f"Could not parse chunk-overlap parameter: {e}")
+
+        return chunk_size, chunk_overlap
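The chunking code that was removed above, and the LibrarianClient that replaces it, both rely on the same correlation pattern: a request id attached as a message property, with the response handler resolving the matching asyncio.Future. A minimal self-contained sketch of that pattern; `RequestMatcher` and `fake_send` are hypothetical names, not the repo's API:

```python
# Correlate asynchronous responses with their requests: each request
# registers a Future under a fresh id; the response callback looks up
# that id and resolves the Future.

import asyncio
import uuid

class RequestMatcher:
    def __init__(self):
        self.pending = {}   # request_id -> asyncio.Future

    async def request(self, send, timeout=5):
        request_id = str(uuid.uuid4())
        future = asyncio.get_running_loop().create_future()
        self.pending[request_id] = future
        try:
            await send(request_id)
            return await asyncio.wait_for(future, timeout=timeout)
        except asyncio.TimeoutError:
            # Clean up so the dict does not leak abandoned requests
            self.pending.pop(request_id, None)
            raise

    def on_response(self, request_id, payload):
        # Called from the consumer callback; unknown ids are ignored
        future = self.pending.pop(request_id, None)
        if future is not None:
            future.set_result(payload)

async def main():
    m = RequestMatcher()

    async def fake_send(request_id):
        # Simulate the librarian answering out of band a moment later
        asyncio.get_running_loop().call_later(
            0.01, m.on_response, request_id, "document-content")

    return await m.request(fake_send)

print(asyncio.run(main()))  # document-content
```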
@@ -32,6 +32,7 @@ class Consumer:
         rate_limit_retry_time = 10, rate_limit_timeout = 7200,
         reconnect_time = 5,
         concurrency = 1, # Number of concurrent requests to handle
+        consumer_type = 'shared',
     ):
 
         self.taskgroup = taskgroup
@@ -42,6 +43,8 @@ class Consumer:
         self.schema = schema
         self.handler = handler
 
+        self.consumer_type = consumer_type
+
         self.rate_limit_retry_time = rate_limit_retry_time
         self.rate_limit_timeout = rate_limit_timeout
@@ -93,33 +96,11 @@ class Consumer:
             if self.metrics:
                 self.metrics.state("stopped")
 
-            try:
-
-                logger.info(f"Subscribing to topic: {self.topic}")
-
-                # Determine initial position
-                if self.start_of_messages:
-                    initial_pos = 'earliest'
-                else:
-                    initial_pos = 'latest'
-
-                # Create consumer via backend
-                self.consumer = await asyncio.to_thread(
-                    self.backend.create_consumer,
-                    topic = self.topic,
-                    subscription = self.subscriber,
-                    schema = self.schema,
-                    initial_position = initial_pos,
-                    consumer_type = 'shared',
-                )
-
-            except Exception as e:
-
-                logger.error(f"Consumer subscription exception: {e}", exc_info=True)
-                await asyncio.sleep(self.reconnect_time)
-                continue
-
-            logger.info(f"Successfully subscribed to topic: {self.topic}")
+            # Determine initial position
+            if self.start_of_messages:
+                initial_pos = 'earliest'
+            else:
+                initial_pos = 'latest'
 
             if self.metrics:
                 self.metrics.state("running")
@@ -128,14 +109,30 @@ class Consumer:
 
             logger.info(f"Starting {self.concurrency} receiver threads")
 
-            async with asyncio.TaskGroup() as tg:
-
-                tasks = []
-
-                for i in range(0, self.concurrency):
-                    tasks.append(
-                        tg.create_task(self.consume_from_queue())
-                    )
+            # Create one backend consumer per concurrent task.
+            # Each gets its own connection — required for backends
+            # like RabbitMQ where connections are not thread-safe.
+            consumers = []
+            for i in range(self.concurrency):
+                try:
+                    logger.info(f"Subscribing to topic: {self.topic} (worker {i})")
+                    c = await asyncio.to_thread(
+                        self.backend.create_consumer,
+                        topic = self.topic,
+                        subscription = self.subscriber,
+                        schema = self.schema,
+                        initial_position = initial_pos,
+                        consumer_type = self.consumer_type,
+                    )
+                    consumers.append(c)
+                    logger.info(f"Successfully subscribed to topic: {self.topic} (worker {i})")
+                except Exception as e:
+                    logger.error(f"Consumer subscription exception (worker {i}): {e}", exc_info=True)
+                    raise
+
+            async with asyncio.TaskGroup() as tg:
+                for c in consumers:
+                    tg.create_task(self.consume_from_queue(c))
 
             if self.metrics:
                 self.metrics.state("stopped")
@@ -143,23 +140,31 @@ class Consumer:
         except Exception as e:
 
             logger.error(f"Consumer loop exception: {e}", exc_info=True)
-            self.consumer.unsubscribe()
-            self.consumer.close()
-            self.consumer = None
+            for c in consumers:
+                try:
+                    c.unsubscribe()
+                    c.close()
+                except Exception:
+                    pass
+            consumers = []
             await asyncio.sleep(self.reconnect_time)
             continue
 
-        if self.consumer:
-            self.consumer.unsubscribe()
-            self.consumer.close()
+        finally:
+            for c in consumers:
+                try:
+                    c.unsubscribe()
+                    c.close()
+                except Exception:
+                    pass
 
-    async def consume_from_queue(self):
+    async def consume_from_queue(self, consumer):
 
         while self.running:
 
             try:
                 msg = await asyncio.to_thread(
-                    self.consumer.receive,
+                    consumer.receive,
                     timeout_millis=2000
                 )
             except Exception as e:
@@ -168,9 +173,9 @@ class Consumer:
                     continue
                 raise e
 
-            await self.handle_one_from_queue(msg)
+            await self.handle_one_from_queue(msg, consumer)
 
-    async def handle_one_from_queue(self, msg):
+    async def handle_one_from_queue(self, msg, consumer):
 
         expiry = time.time() + self.rate_limit_timeout
 
@@ -183,7 +188,7 @@ class Consumer:
 
                 # Message failed to be processed, this causes it to
                 # be retried
-                self.consumer.negative_acknowledge(msg)
+                consumer.negative_acknowledge(msg)
 
                 if self.metrics:
                     self.metrics.process("error")
@@ -206,7 +211,7 @@ class Consumer:
             logger.debug("Message processed successfully")
 
             # Acknowledge successful processing of the message
-            self.consumer.acknowledge(msg)
+            consumer.acknowledge(msg)
 
             if self.metrics:
                 self.metrics.process("success")
@@ -233,7 +238,7 @@ class Consumer:
 
                 # Message failed to be processed, this causes it to
                 # be retried
-                self.consumer.negative_acknowledge(msg)
+                consumer.negative_acknowledge(msg)
 
                 if self.metrics:
                     self.metrics.process("error")
trustgraph-base/trustgraph/base/librarian_client.py (new file, 246 lines)
@@ -0,0 +1,246 @@
+"""
+Shared librarian client for services that need to communicate
+with the librarian via pub/sub.
+
+Provides request-response and streaming operations over the message
+broker, with proper support for large documents via stream-document.
+
+Usage:
+    self.librarian = LibrarianClient(
+        id=id, backend=self.pubsub, taskgroup=self.taskgroup, **params
+    )
+    await self.librarian.start()
+    content = await self.librarian.fetch_document_content(doc_id, user)
+"""
+
+import asyncio
+import base64
+import logging
+import uuid
+
+from .consumer import Consumer
+from .producer import Producer
+from .metrics import ConsumerMetrics, ProducerMetrics
+
+from ..schema import LibrarianRequest, LibrarianResponse, DocumentMetadata
+from ..schema import librarian_request_queue, librarian_response_queue
+
+logger = logging.getLogger(__name__)
+
+
+class LibrarianClient:
+    """Client for librarian request-response over the message broker."""
+
+    def __init__(self, id, backend, taskgroup, **params):
+
+        librarian_request_q = params.get(
+            "librarian_request_queue", librarian_request_queue,
+        )
+        librarian_response_q = params.get(
+            "librarian_response_queue", librarian_response_queue,
+        )
+
+        librarian_request_metrics = ProducerMetrics(
+            processor=id, flow=None, name="librarian-request",
+        )
+
+        self._producer = Producer(
+            backend=backend,
+            topic=librarian_request_q,
+            schema=LibrarianRequest,
+            metrics=librarian_request_metrics,
+        )
+
+        librarian_response_metrics = ConsumerMetrics(
+            processor=id, flow=None, name="librarian-response",
+        )
+
+        self._consumer = Consumer(
+            taskgroup=taskgroup,
+            backend=backend,
+            flow=None,
+            topic=librarian_response_q,
+            subscriber=f"{id}-librarian",
+            schema=LibrarianResponse,
+            handler=self._on_response,
+            metrics=librarian_response_metrics,
+            consumer_type='exclusive',
+        )
+
+        # Single-response requests: request_id -> asyncio.Future
+        self._pending = {}
+        # Streaming requests: request_id -> asyncio.Queue
+        self._streams = {}
+
+    async def start(self):
+        """Start the librarian producer and consumer."""
+        await self._producer.start()
+        await self._consumer.start()
+
+    async def _on_response(self, msg, consumer, flow):
+        """Route librarian responses to the right waiter."""
+        response = msg.value()
+        request_id = msg.properties().get("id")
+
+        if not request_id:
+            return
+
+        if request_id in self._pending:
+            future = self._pending.pop(request_id)
+            future.set_result(response)
+        elif request_id in self._streams:
+            await self._streams[request_id].put(response)
+
+    async def request(self, request, timeout=120):
+        """Send a request to the librarian and wait for a single response."""
+        request_id = str(uuid.uuid4())
+
+        future = asyncio.get_event_loop().create_future()
+        self._pending[request_id] = future
+
+        try:
+            await self._producer.send(
+                request, properties={"id": request_id},
+            )
+            response = await asyncio.wait_for(future, timeout=timeout)
+
+            if response.error:
+                raise RuntimeError(
+                    f"Librarian error: {response.error.type}: "
+                    f"{response.error.message}"
+                )
+
+            return response
+
+        except asyncio.TimeoutError:
+            self._pending.pop(request_id, None)
+            raise RuntimeError("Timeout waiting for librarian response")
+
+    async def stream(self, request, timeout=120):
+        """Send a request and collect streamed response chunks."""
+        request_id = str(uuid.uuid4())
+
+        q = asyncio.Queue()
+        self._streams[request_id] = q
+
+        try:
+            await self._producer.send(
+                request, properties={"id": request_id},
+            )
+
+            chunks = []
+            while True:
+                response = await asyncio.wait_for(q.get(), timeout=timeout)
+
+                if response.error:
+                    raise RuntimeError(
+                        f"Librarian error: {response.error.type}: "
+                        f"{response.error.message}"
+                    )
+
+                chunks.append(response)
+
+                if response.is_final:
+                    break
+
+            return chunks
+
+        except asyncio.TimeoutError:
+            self._streams.pop(request_id, None)
+            raise RuntimeError("Timeout waiting for librarian stream")
+        finally:
+            self._streams.pop(request_id, None)
+
+    async def fetch_document_content(self, document_id, user, timeout=120):
+        """Fetch document content using streaming.
+
+        Returns base64-encoded content. Caller is responsible for decoding.
+        """
+        req = LibrarianRequest(
+            operation="stream-document",
+            document_id=document_id,
+            user=user,
+        )
+        chunks = await self.stream(req, timeout=timeout)
+
+        # Decode each chunk's base64 to raw bytes, concatenate,
+        # re-encode for the caller.
+        raw = b""
+        for chunk in chunks:
+            if chunk.content:
+                if isinstance(chunk.content, bytes):
+                    raw += base64.b64decode(chunk.content)
+                else:
+                    raw += base64.b64decode(
+                        chunk.content.encode("utf-8")
+                    )
+
+        return base64.b64encode(raw)
+
+    async def fetch_document_text(self, document_id, user, timeout=120):
+        """Fetch document content and decode as UTF-8 text."""
+        content = await self.fetch_document_content(
+            document_id, user, timeout=timeout,
+        )
+        return base64.b64decode(content).decode("utf-8")
+
+    async def fetch_document_metadata(self, document_id, user, timeout=120):
+        """Fetch document metadata from the librarian."""
+        req = LibrarianRequest(
+            operation="get-document-metadata",
+            document_id=document_id,
+            user=user,
+        )
+        response = await self.request(req, timeout=timeout)
+        return response.document_metadata
+
+    async def save_child_document(self, doc_id, parent_id, user, content,
+                                  document_type="chunk", title=None,
+                                  kind="text/plain", timeout=120):
+        """Save a child document to the librarian."""
+        if isinstance(content, str):
+            content = content.encode("utf-8")
+
+        doc_metadata = DocumentMetadata(
+            id=doc_id,
+            user=user,
+            kind=kind,
+            title=title or doc_id,
+            parent_id=parent_id,
+            document_type=document_type,
+        )
+
+        req = LibrarianRequest(
+            operation="add-child-document",
+            document_metadata=doc_metadata,
+            content=base64.b64encode(content).decode("utf-8"),
+        )
+
+        await self.request(req, timeout=timeout)
+        return doc_id
+
+    async def save_document(self, doc_id, user, content, title=None,
+                            document_type="answer", kind="text/plain",
+                            timeout=120):
+        """Save a document to the librarian."""
+        if isinstance(content, str):
+            content = content.encode("utf-8")
+
+        doc_metadata = DocumentMetadata(
+            id=doc_id,
+            user=user,
+            kind=kind,
+            title=title or doc_id,
+            document_type=document_type,
+        )
+
+        req = LibrarianRequest(
+            operation="add-document",
+            document_id=doc_id,
+            document_metadata=doc_metadata,
+            content=base64.b64encode(content).decode("utf-8"),
+            user=user,
+        )
+
+        await self.request(req, timeout=timeout)
+        return doc_id
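The chunk reassembly that `fetch_document_content` performs can be demonstrated standalone: each streamed chunk carries base64 content, which is decoded, concatenated, and re-encoded for the caller. A pure-Python sketch of just that step (the `reassemble` helper is illustrative; the real method works on response objects, not bare strings):

```python
# Reassemble a document streamed as base64-encoded chunks.

import base64

def reassemble(chunks):
    raw = b""
    for content in chunks:
        if isinstance(content, bytes):
            raw += base64.b64decode(content)
        else:
            raw += base64.b64decode(content.encode("utf-8"))
    # Re-encode the combined document for the caller
    return base64.b64encode(raw)

# Simulate a document streamed in 1MB-style chunks (tiny here)
doc = b"hello, librarian"
parts = [doc[:6], doc[6:]]
chunks = [base64.b64encode(p).decode("utf-8") for p in parts]

combined = reassemble(chunks)
print(base64.b64decode(combined).decode("utf-8"))  # hello, librarian
```

Streaming in bounded chunks is what lets the librarian queues move back to the request/response class: no single message ever approaches the broker's size limit.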
@@ -8,6 +8,12 @@ logger = logging.getLogger(__name__)
 DEFAULT_PULSAR_HOST = os.getenv("PULSAR_HOST", 'pulsar://pulsar:6650')
 DEFAULT_PULSAR_API_KEY = os.getenv("PULSAR_API_KEY", None)
 
+DEFAULT_RABBITMQ_HOST = os.getenv("RABBITMQ_HOST", 'rabbitmq')
+DEFAULT_RABBITMQ_PORT = int(os.getenv("RABBITMQ_PORT", '5672'))
+DEFAULT_RABBITMQ_USERNAME = os.getenv("RABBITMQ_USERNAME", 'guest')
+DEFAULT_RABBITMQ_PASSWORD = os.getenv("RABBITMQ_PASSWORD", 'guest')
+DEFAULT_RABBITMQ_VHOST = os.getenv("RABBITMQ_VHOST", '/')
+
 
 def get_pubsub(**config):
     """
@@ -29,6 +35,15 @@ def get_pubsub(**config):
             api_key=config.get('pulsar_api_key', DEFAULT_PULSAR_API_KEY),
             listener=config.get('pulsar_listener'),
         )
+    elif backend_type == 'rabbitmq':
+        from .rabbitmq_backend import RabbitMQBackend
+        return RabbitMQBackend(
+            host=config.get('rabbitmq_host', DEFAULT_RABBITMQ_HOST),
+            port=config.get('rabbitmq_port', DEFAULT_RABBITMQ_PORT),
+            username=config.get('rabbitmq_username', DEFAULT_RABBITMQ_USERNAME),
+            password=config.get('rabbitmq_password', DEFAULT_RABBITMQ_PASSWORD),
+            vhost=config.get('rabbitmq_vhost', DEFAULT_RABBITMQ_VHOST),
+        )
     else:
         raise ValueError(f"Unknown pub/sub backend: {backend_type}")
 
@@ -44,8 +59,9 @@ def add_pubsub_args(parser, standalone=False):
         standalone: If True, default host is localhost (for CLI tools
                     that run outside containers)
     """
-    host = STANDALONE_PULSAR_HOST if standalone else DEFAULT_PULSAR_HOST
-    listener_default = 'localhost' if standalone else None
+    pulsar_host = STANDALONE_PULSAR_HOST if standalone else DEFAULT_PULSAR_HOST
+    pulsar_listener = 'localhost' if standalone else None
+    rabbitmq_host = 'localhost' if standalone else DEFAULT_RABBITMQ_HOST
 
     parser.add_argument(
         '--pubsub-backend',
@@ -53,10 +69,11 @@ def add_pubsub_args(parser, standalone=False):
         help='Pub/sub backend (default: pulsar, env: PUBSUB_BACKEND)',
     )
 
+    # Pulsar options
     parser.add_argument(
         '-p', '--pulsar-host',
-        default=host,
-        help=f'Pulsar host (default: {host})',
+        default=pulsar_host,
+        help=f'Pulsar host (default: {pulsar_host})',
     )
 
     parser.add_argument(
@@ -67,6 +84,38 @@ def add_pubsub_args(parser, standalone=False):
 
     parser.add_argument(
         '--pulsar-listener',
-        default=listener_default,
-        help=f'Pulsar listener (default: {listener_default or "none"})',
+        default=pulsar_listener,
+        help=f'Pulsar listener (default: {pulsar_listener or "none"})',
+    )
+
+    # RabbitMQ options
+    parser.add_argument(
+        '--rabbitmq-host',
+        default=rabbitmq_host,
+        help=f'RabbitMQ host (default: {rabbitmq_host})',
+    )
+
+    parser.add_argument(
+        '--rabbitmq-port',
+        type=int,
+        default=DEFAULT_RABBITMQ_PORT,
+        help=f'RabbitMQ port (default: {DEFAULT_RABBITMQ_PORT})',
+    )
+
+    parser.add_argument(
+        '--rabbitmq-username',
+        default=DEFAULT_RABBITMQ_USERNAME,
+        help='RabbitMQ username',
+    )
+
+    parser.add_argument(
+        '--rabbitmq-password',
+        default=DEFAULT_RABBITMQ_PASSWORD,
+        help='RabbitMQ password',
+    )
+
+    parser.add_argument(
+        '--rabbitmq-vhost',
+        default=DEFAULT_RABBITMQ_VHOST,
+        help=f'RabbitMQ vhost (default: {DEFAULT_RABBITMQ_VHOST})',
     )
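The factory and CLI wiring above can be condensed into a runnable sketch showing how a `--pubsub-backend` option (with a `PUBSUB_BACKEND` env default and standalone localhost defaults) selects a backend. The real factory returns backend instances; here they are reduced to strings, and `get_backend` is an illustrative name:

```python
# Condensed stand-in for add_pubsub_args / get_pubsub backend selection.

import argparse
import os

def add_pubsub_args(parser, standalone=False):
    parser.add_argument(
        '--pubsub-backend',
        default=os.getenv("PUBSUB_BACKEND", "pulsar"),
        help='Pub/sub backend (default: pulsar, env: PUBSUB_BACKEND)',
    )
    # standalone=True targets CLI tools running outside containers
    parser.add_argument(
        '--rabbitmq-host',
        default='localhost' if standalone else 'rabbitmq',
    )

def get_backend(args):
    if args.pubsub_backend == 'pulsar':
        return 'PulsarBackend'
    elif args.pubsub_backend == 'rabbitmq':
        return 'RabbitMQBackend'
    raise ValueError(f"Unknown pub/sub backend: {args.pubsub_backend}")

parser = argparse.ArgumentParser()
add_pubsub_args(parser, standalone=True)

args = parser.parse_args(['--pubsub-backend', 'rabbitmq'])
print(get_backend(args), args.rabbitmq_host)  # RabbitMQBackend localhost
```

Because both backends implement the same PubSubBackend protocol, switching is purely a configuration change; application code never branches on the backend type.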
@@ -9,122 +9,14 @@ import pulsar
 import _pulsar
 import json
 import logging
-import base64
-import types
-from dataclasses import asdict, is_dataclass
-from typing import Any, get_type_hints
+from typing import Any
 
 from .backend import PubSubBackend, BackendProducer, BackendConsumer, Message
+from .serialization import dataclass_to_dict, dict_to_dataclass
 
 logger = logging.getLogger(__name__)
 
-
-def dataclass_to_dict(obj: Any) -> dict:
-    """
-    Recursively convert a dataclass to a dictionary, handling None values and bytes.
-
-    None values are excluded from the dictionary (not serialized).
-    Bytes values are decoded as UTF-8 strings for JSON serialization (matching Pulsar behavior).
-    Handles nested dataclasses, lists, and dictionaries recursively.
-    """
-    if obj is None:
-        return None
-
-    # Handle bytes - decode to UTF-8 for JSON serialization
-    if isinstance(obj, bytes):
-        return obj.decode('utf-8')
-
-    # Handle dataclass - convert to dict then recursively process all values
-    if is_dataclass(obj):
-        result = {}
-        for key, value in asdict(obj).items():
-            result[key] = dataclass_to_dict(value) if value is not None else None
-        return result
-
-    # Handle list - recursively process all items
-    if isinstance(obj, list):
-        return [dataclass_to_dict(item) for item in obj]
-
-    # Handle dict - recursively process all values
-    if isinstance(obj, dict):
-        return {k: dataclass_to_dict(v) for k, v in obj.items()}
-
-    # Return primitive types as-is
-    return obj
-
-
-def dict_to_dataclass(data: dict, cls: type) -> Any:
-    """
-    Convert a dictionary back to a dataclass instance.
-
-    Handles nested dataclasses and missing fields.
-    Uses get_type_hints() to resolve forward references (string annotations).
-    """
-    if data is None:
-        return None
-
-    if not is_dataclass(cls):
-        return data
-
-    # Get field types from the dataclass, resolving forward references
-    # get_type_hints() evaluates string annotations like "Triple | None"
-    try:
-        field_types = get_type_hints(cls)
-    except Exception:
-        # Fallback if get_type_hints fails (shouldn't happen normally)
-        field_types = {f.name: f.type for f in cls.__dataclass_fields__.values()}
-
-    kwargs = {}
-
-    for key, value in data.items():
-        if key in field_types:
-            field_type = field_types[key]
-
-            # Handle modern union types (X | Y)
-            if isinstance(field_type, types.UnionType):
-                # Check if it's Optional (X | None)
-                if type(None) in field_type.__args__:
-                    # Get the non-None type
-                    actual_type = next((t for t in field_type.__args__ if t is not type(None)), None)
-                    if actual_type and is_dataclass(actual_type) and isinstance(value, dict):
-                        kwargs[key] = dict_to_dataclass(value, actual_type)
-                    else:
-                        kwargs[key] = value
-                else:
-                    kwargs[key] = value
-            # Check if this is a generic type (list, dict, etc.)
-            elif hasattr(field_type, '__origin__'):
-                # Handle list[T]
-                if field_type.__origin__ == list:
-                    item_type = field_type.__args__[0] if field_type.__args__ else None
-                    if item_type and is_dataclass(item_type) and isinstance(value, list):
-                        kwargs[key] = [
-                            dict_to_dataclass(item, item_type) if isinstance(item, dict) else item
-                            for item in value
-                        ]
-                    else:
-                        kwargs[key] = value
-                # Handle old-style Optional[T] (which is Union[T, None])
-                elif hasattr(field_type, '__args__') and type(None) in field_type.__args__:
|
|
||||||
# Get the non-None type from Union
|
|
||||||
actual_type = next((t for t in field_type.__args__ if t is not type(None)), None)
|
|
||||||
if actual_type and is_dataclass(actual_type) and isinstance(value, dict):
|
|
||||||
kwargs[key] = dict_to_dataclass(value, actual_type)
|
|
||||||
else:
|
|
||||||
kwargs[key] = value
|
|
||||||
else:
|
|
||||||
kwargs[key] = value
|
|
||||||
# Handle direct dataclass fields
|
|
||||||
elif is_dataclass(field_type) and isinstance(value, dict):
|
|
||||||
kwargs[key] = dict_to_dataclass(value, field_type)
|
|
||||||
# Handle bytes fields (UTF-8 encoded strings from JSON)
|
|
||||||
elif field_type == bytes and isinstance(value, str):
|
|
||||||
kwargs[key] = value.encode('utf-8')
|
|
||||||
else:
|
|
||||||
kwargs[key] = value
|
|
||||||
|
|
||||||
return cls(**kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
class PulsarMessage:
|
class PulsarMessage:
|
||||||
"""Wrapper for Pulsar messages to match Message protocol."""
|
"""Wrapper for Pulsar messages to match Message protocol."""
|
||||||
|
|
||||||
|
|
|
||||||
trustgraph-base/trustgraph/base/rabbitmq_backend.py (new file, 390 lines)
@@ -0,0 +1,390 @@
"""
RabbitMQ backend implementation for pub/sub abstraction.

Uses a single topic exchange per topicspace. The logical queue name
becomes the routing key. Consumer behavior is determined by the
subscription name:

- Same subscription + same topic = shared queue (competing consumers)
- Different subscriptions = separate queues (broadcast / fan-out)

This mirrors Pulsar's subscription model using idiomatic RabbitMQ.

Architecture:
    Producer --> [tg exchange] --routing key--> [named queue] --> Consumer
                               --routing key--> [named queue] --> Consumer
                               --routing key--> [exclusive q] --> Subscriber

Uses basic_consume (push) instead of basic_get (polling) for
efficient message delivery.
"""

import json
import time
import logging
import queue
import threading
import pika
from typing import Any

from .backend import PubSubBackend, BackendProducer, BackendConsumer, Message
from .serialization import dataclass_to_dict, dict_to_dataclass

logger = logging.getLogger(__name__)


class RabbitMQMessage:
    """Wrapper for RabbitMQ messages to match Message protocol."""

    def __init__(self, method, properties, body, schema_cls):
        self._method = method
        self._properties = properties
        self._body = body
        self._schema_cls = schema_cls
        self._value = None

    def value(self) -> Any:
        """Deserialize and return the message value as a dataclass."""
        if self._value is None:
            data_dict = json.loads(self._body.decode('utf-8'))
            self._value = dict_to_dataclass(data_dict, self._schema_cls)
        return self._value

    def properties(self) -> dict:
        """Return message properties from AMQP headers."""
        headers = self._properties.headers or {}
        return dict(headers)


class RabbitMQBackendProducer:
    """Publishes messages to a topic exchange with a routing key.

    Uses thread-local connections so each thread gets its own
    connection/channel. This avoids wire corruption from concurrent
    threads writing to the same socket (pika is not thread-safe).
    """

    def __init__(self, connection_params, exchange_name, routing_key,
                 durable):
        self._connection_params = connection_params
        self._exchange_name = exchange_name
        self._routing_key = routing_key
        self._durable = durable
        self._local = threading.local()

    def _get_channel(self):
        """Get or create a thread-local connection and channel."""
        conn = getattr(self._local, 'connection', None)
        chan = getattr(self._local, 'channel', None)

        if conn is None or not conn.is_open or chan is None or not chan.is_open:
            # Close stale connection if any
            if conn is not None:
                try:
                    conn.close()
                except Exception:
                    pass

            conn = pika.BlockingConnection(self._connection_params)
            chan = conn.channel()
            chan.exchange_declare(
                exchange=self._exchange_name,
                exchange_type='topic',
                durable=True,
            )
            self._local.connection = conn
            self._local.channel = chan

        return chan

    def send(self, message: Any, properties: dict = {}) -> None:
        data_dict = dataclass_to_dict(message)
        json_data = json.dumps(data_dict)

        amqp_properties = pika.BasicProperties(
            delivery_mode=2 if self._durable else 1,
            content_type='application/json',
            headers=properties if properties else None,
        )

        for attempt in range(2):
            try:
                channel = self._get_channel()
                channel.basic_publish(
                    exchange=self._exchange_name,
                    routing_key=self._routing_key,
                    body=json_data.encode('utf-8'),
                    properties=amqp_properties,
                )
                return
            except Exception as e:
                logger.warning(
                    f"RabbitMQ send failed (attempt {attempt + 1}): {e}"
                )
                # Force reconnect on next attempt
                self._local.connection = None
                self._local.channel = None
                if attempt == 1:
                    raise

    def flush(self) -> None:
        pass

    def close(self) -> None:
        """Close the thread-local connection if any."""
        conn = getattr(self._local, 'connection', None)
        if conn is not None:
            try:
                conn.close()
            except Exception:
                pass
        self._local.connection = None
        self._local.channel = None


class RabbitMQBackendConsumer:
    """Consumes from a queue bound to a topic exchange.

    Uses basic_consume (push model) with messages delivered to an
    internal thread-safe queue. process_data_events() drives both
    message delivery and heartbeat processing.
    """

    def __init__(self, connection_params, exchange_name, routing_key,
                 queue_name, schema_cls, durable, exclusive=False,
                 auto_delete=False):
        self._connection_params = connection_params
        self._exchange_name = exchange_name
        self._routing_key = routing_key
        self._queue_name = queue_name
        self._schema_cls = schema_cls
        self._durable = durable
        self._exclusive = exclusive
        self._auto_delete = auto_delete
        self._connection = None
        self._channel = None
        self._consumer_tag = None
        self._incoming = queue.Queue()

    def _connect(self):
        self._connection = pika.BlockingConnection(self._connection_params)
        self._channel = self._connection.channel()

        # Declare the topic exchange
        self._channel.exchange_declare(
            exchange=self._exchange_name,
            exchange_type='topic',
            durable=True,
        )

        # Declare the queue — anonymous if exclusive
        result = self._channel.queue_declare(
            queue=self._queue_name,
            durable=self._durable,
            exclusive=self._exclusive,
            auto_delete=self._auto_delete,
        )
        # Capture actual name (important for anonymous queues where name='')
        self._queue_name = result.method.queue

        self._channel.queue_bind(
            queue=self._queue_name,
            exchange=self._exchange_name,
            routing_key=self._routing_key,
        )

        self._channel.basic_qos(prefetch_count=1)

        # Register push-based consumer
        self._consumer_tag = self._channel.basic_consume(
            queue=self._queue_name,
            on_message_callback=self._on_message,
            auto_ack=False,
        )

    def _on_message(self, channel, method, properties, body):
        """Callback invoked by pika when a message arrives."""
        self._incoming.put((method, properties, body))

    def _is_alive(self):
        return (
            self._connection is not None
            and self._connection.is_open
            and self._channel is not None
            and self._channel.is_open
        )

    def receive(self, timeout_millis: int = 2000) -> Message:
        """Receive a message. Raises TimeoutError if none available."""
        if not self._is_alive():
            self._connect()

        timeout_seconds = timeout_millis / 1000.0
        deadline = time.monotonic() + timeout_seconds

        while time.monotonic() < deadline:
            # Check if a message was already delivered
            try:
                method, properties, body = self._incoming.get_nowait()
                return RabbitMQMessage(
                    method, properties, body, self._schema_cls,
                )
            except queue.Empty:
                pass

            # Drive pika's I/O — delivers messages and processes heartbeats
            remaining = deadline - time.monotonic()
            if remaining > 0:
                self._connection.process_data_events(
                    time_limit=min(0.1, remaining),
                )

        raise TimeoutError("No message received within timeout")

    def acknowledge(self, message: Message) -> None:
        if isinstance(message, RabbitMQMessage) and message._method:
            self._channel.basic_ack(
                delivery_tag=message._method.delivery_tag,
            )

    def negative_acknowledge(self, message: Message) -> None:
        if isinstance(message, RabbitMQMessage) and message._method:
            self._channel.basic_nack(
                delivery_tag=message._method.delivery_tag,
                requeue=True,
            )

    def unsubscribe(self) -> None:
        if self._consumer_tag and self._channel and self._channel.is_open:
            try:
                self._channel.basic_cancel(self._consumer_tag)
            except Exception:
                pass
        self._consumer_tag = None

    def close(self) -> None:
        self.unsubscribe()
        try:
            if self._channel and self._channel.is_open:
                self._channel.close()
        except Exception:
            pass
        try:
            if self._connection and self._connection.is_open:
                self._connection.close()
        except Exception:
            pass
        self._channel = None
        self._connection = None


class RabbitMQBackend:
    """RabbitMQ pub/sub backend using a topic exchange per topicspace."""

    def __init__(self, host='localhost', port=5672, username='guest',
                 password='guest', vhost='/'):
        self._connection_params = pika.ConnectionParameters(
            host=host,
            port=port,
            virtual_host=vhost,
            credentials=pika.PlainCredentials(username, password),
        )
        logger.info(f"RabbitMQ backend: {host}:{port} vhost={vhost}")

    def _parse_queue_id(self, queue_id: str) -> tuple[str, str, str, bool]:
        """
        Parse queue identifier into exchange, routing key, and durability.

        Format: class:topicspace:topic
        Returns: (exchange_name, routing_key, class, durable)
        """
        if ':' not in queue_id:
            return 'tg', queue_id, 'flow', False

        parts = queue_id.split(':', 2)
        if len(parts) != 3:
            raise ValueError(
                f"Invalid queue format: {queue_id}, "
                f"expected class:topicspace:topic"
            )

        cls, topicspace, topic = parts

        if cls in ('flow', 'state'):
            durable = True
        elif cls in ('request', 'response'):
            durable = False
        else:
            raise ValueError(
                f"Invalid queue class: {cls}, "
                f"expected flow, request, response, or state"
            )

        # Exchange per topicspace, routing key includes class
        exchange_name = topicspace
        routing_key = f"{cls}.{topic}"

        return exchange_name, routing_key, cls, durable

    # Keep map_queue_name for backward compatibility with tests
    def map_queue_name(self, queue_id: str) -> tuple[str, bool]:
        exchange, routing_key, cls, durable = self._parse_queue_id(queue_id)
        return f"{exchange}.{routing_key}", durable

    def create_producer(self, topic: str, schema: type,
                        **options) -> BackendProducer:
        exchange, routing_key, cls, durable = self._parse_queue_id(topic)
        logger.debug(
            f"Creating producer: exchange={exchange}, "
            f"routing_key={routing_key}"
        )
        return RabbitMQBackendProducer(
            self._connection_params, exchange, routing_key, durable,
        )

    def create_consumer(self, topic: str, subscription: str, schema: type,
                        initial_position: str = 'latest',
                        consumer_type: str = 'shared',
                        **options) -> BackendConsumer:
        """Create a consumer with a queue bound to the topic exchange.

        consumer_type='shared': Named durable queue. Multiple consumers
            with the same subscription compete (round-robin).
        consumer_type='exclusive': Anonymous ephemeral queue. Each
            consumer gets its own copy of every message (broadcast).
        """
        exchange, routing_key, cls, durable = self._parse_queue_id(topic)

        if consumer_type == 'exclusive' and cls == 'state':
            # State broadcast: named durable queue per subscriber.
            # Retains messages so late-starting processors see current state.
            queue_name = f"{exchange}.{routing_key}.{subscription}"
            queue_durable = True
            exclusive = False
            auto_delete = False
        elif consumer_type == 'exclusive':
            # Broadcast: anonymous queue, auto-deleted on disconnect
            queue_name = ''
            queue_durable = False
            exclusive = True
            auto_delete = True
        else:
            # Shared: named queue, competing consumers
            queue_name = f"{exchange}.{routing_key}.{subscription}"
            queue_durable = durable
            exclusive = False
            auto_delete = False

        logger.debug(
            f"Creating consumer: exchange={exchange}, "
            f"routing_key={routing_key}, queue={queue_name or '(anonymous)'}, "
            f"type={consumer_type}"
        )

        return RabbitMQBackendConsumer(
            self._connection_params, exchange, routing_key,
            queue_name, schema, queue_durable, exclusive, auto_delete,
        )

    def close(self) -> None:
        pass
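The naming scheme above can be exercised without a broker. This standalone sketch mirrors the `_parse_queue_id` logic from the file: a `class:topicspace:topic` identifier maps to an exchange, a routing key, and a durability flag (the `parse_queue_id` helper name and the three-element return are illustrative simplifications, not the shipped API).

```python
# Sketch of the queue-id scheme: 'class:topicspace:topic' maps to
# (exchange, routing_key, durable). Mirrors _parse_queue_id above.

def parse_queue_id(queue_id):
    if ':' not in queue_id:
        return 'tg', queue_id, False        # bare names: default exchange
    parts = queue_id.split(':', 2)
    if len(parts) != 3:
        raise ValueError(f"Invalid queue format: {queue_id}")
    cls, topicspace, topic = parts
    if cls in ('flow', 'state'):
        durable = True                      # persistent queue classes
    elif cls in ('request', 'response'):
        durable = False                     # transient request/response
    else:
        raise ValueError(f"Invalid queue class: {cls}")
    return topicspace, f"{cls}.{topic}", durable

print(parse_queue_id('flow:tg:chunk-load'))
```

A shared consumer then derives its queue name as `{exchange}.{routing_key}.{subscription}`, so two processors with the same subscription land on the same queue and compete.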
trustgraph-base/trustgraph/base/serialization.py (new file, 115 lines)
@@ -0,0 +1,115 @@
"""
JSON serialization helpers for dataclass ↔ dict conversion.

Used by pub/sub backends that use JSON as their wire format.
"""

import types
from dataclasses import asdict, is_dataclass
from typing import Any, get_type_hints


def dataclass_to_dict(obj: Any) -> dict:
    """
    Recursively convert a dataclass to a dictionary, handling None values and bytes.

    None values are excluded from the dictionary (not serialized).
    Bytes values are decoded as UTF-8 strings for JSON serialization.
    Handles nested dataclasses, lists, and dictionaries recursively.
    """
    if obj is None:
        return None

    # Handle bytes - decode to UTF-8 for JSON serialization
    if isinstance(obj, bytes):
        return obj.decode('utf-8')

    # Handle dataclass - convert to dict then recursively process all values
    if is_dataclass(obj):
        result = {}
        for key, value in asdict(obj).items():
            result[key] = dataclass_to_dict(value) if value is not None else None
        return result

    # Handle list - recursively process all items
    if isinstance(obj, list):
        return [dataclass_to_dict(item) for item in obj]

    # Handle dict - recursively process all values
    if isinstance(obj, dict):
        return {k: dataclass_to_dict(v) for k, v in obj.items()}

    # Return primitive types as-is
    return obj


def dict_to_dataclass(data: dict, cls: type) -> Any:
    """
    Convert a dictionary back to a dataclass instance.

    Handles nested dataclasses and missing fields.
    Uses get_type_hints() to resolve forward references (string annotations).
    """
    if data is None:
        return None

    if not is_dataclass(cls):
        return data

    # Get field types from the dataclass, resolving forward references
    # get_type_hints() evaluates string annotations like "Triple | None"
    try:
        field_types = get_type_hints(cls)
    except Exception:
        # Fallback if get_type_hints fails (shouldn't happen normally)
        field_types = {f.name: f.type for f in cls.__dataclass_fields__.values()}

    kwargs = {}

    for key, value in data.items():
        if key in field_types:
            field_type = field_types[key]

            # Handle modern union types (X | Y)
            if isinstance(field_type, types.UnionType):
                # Check if it's Optional (X | None)
                if type(None) in field_type.__args__:
                    # Get the non-None type
                    actual_type = next((t for t in field_type.__args__ if t is not type(None)), None)
                    if actual_type and is_dataclass(actual_type) and isinstance(value, dict):
                        kwargs[key] = dict_to_dataclass(value, actual_type)
                    else:
                        kwargs[key] = value
                else:
                    kwargs[key] = value
            # Check if this is a generic type (list, dict, etc.)
            elif hasattr(field_type, '__origin__'):
                # Handle list[T]
                if field_type.__origin__ == list:
                    item_type = field_type.__args__[0] if field_type.__args__ else None
                    if item_type and is_dataclass(item_type) and isinstance(value, list):
                        kwargs[key] = [
                            dict_to_dataclass(item, item_type) if isinstance(item, dict) else item
                            for item in value
                        ]
                    else:
                        kwargs[key] = value
                # Handle old-style Optional[T] (which is Union[T, None])
                elif hasattr(field_type, '__args__') and type(None) in field_type.__args__:
                    # Get the non-None type from Union
                    actual_type = next((t for t in field_type.__args__ if t is not type(None)), None)
                    if actual_type and is_dataclass(actual_type) and isinstance(value, dict):
                        kwargs[key] = dict_to_dataclass(value, actual_type)
                    else:
                        kwargs[key] = value
                else:
                    kwargs[key] = value
            # Handle direct dataclass fields
            elif is_dataclass(field_type) and isinstance(value, dict):
                kwargs[key] = dict_to_dataclass(value, field_type)
            # Handle bytes fields (UTF-8 encoded strings from JSON)
            elif field_type == bytes and isinstance(value, str):
                kwargs[key] = value.encode('utf-8')
            else:
                kwargs[key] = value

    return cls(**kwargs)
@@ -51,7 +51,7 @@ class Subscriber:
             topic=self.topic,
             subscription=self.subscription,
             schema=self.schema,
-            consumer_type='shared',
+            consumer_type='exclusive',
         )

         self.task = asyncio.create_task(self.run())
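The semantics this one-line change switches between can be modelled without a broker: a shared consumer binds a queue named from exchange, routing key, and subscription (so equal subscriptions land on one competing queue), while an exclusive consumer binds a fresh anonymous queue and receives every message. `ToyExchange` and its naming are illustrative only.

```python
from collections import defaultdict

class ToyExchange:
    """Toy model of the two consumer_type behaviours (no broker)."""
    def __init__(self):
        self.queues = defaultdict(list)     # queue name -> delivered messages
        self.bindings = defaultdict(set)    # routing key -> bound queue names
        self._anon = 0

    def bind(self, routing_key, subscription, consumer_type):
        if consumer_type == 'exclusive':
            self._anon += 1
            name = f"amq.gen-{self._anon}"             # per-consumer queue
        else:
            name = f"tg.{routing_key}.{subscription}"  # shared named queue
        self.bindings[routing_key].add(name)
        return name

    def publish(self, routing_key, msg):
        # Topic exchange: every bound queue gets a copy.
        for q in self.bindings[routing_key]:
            self.queues[q].append(msg)

ex = ToyExchange()
q1 = ex.bind('flow.chunk', 'sub', 'shared')
q2 = ex.bind('flow.chunk', 'sub', 'shared')      # same queue: competing
b1 = ex.bind('flow.chunk', 'sub', 'exclusive')   # own queue: broadcast
b2 = ex.bind('flow.chunk', 'sub', 'exclusive')
ex.publish('flow.chunk', 'm1')
```

After the publish, the shared queue holds a single copy of `m1` for its two consumers to compete over, while each exclusive queue holds its own copy, which is the broadcast behaviour the Subscriber now requests.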
@@ -18,9 +18,7 @@ class BaseClient:
         output_queue=None,
         input_schema=None,
         output_schema=None,
-        pulsar_host="pulsar://pulsar:6650",
-        pulsar_api_key=None,
-        listener=None,
+        **pubsub_config,
     ):

         if input_queue == None: raise RuntimeError("Need input_queue")
@@ -32,12 +30,7 @@ class BaseClient:
         subscriber = str(uuid.uuid4())

         # Create backend using factory
-        self.backend = get_pubsub(
-            pulsar_host=pulsar_host,
-            pulsar_api_key=pulsar_api_key,
-            pulsar_listener=listener,
-            pubsub_backend='pulsar'
-        )
+        self.backend = get_pubsub(**pubsub_config)

         self.producer = self.backend.create_producer(
             topic=input_queue,
@@ -33,9 +33,7 @@ class ConfigClient(BaseClient):
         subscriber=None,
         input_queue=None,
         output_queue=None,
-        pulsar_host="pulsar://pulsar:6650",
-        listener=None,
-        pulsar_api_key=None,
+        **pubsub_config,
     ):

         if input_queue == None:
@@ -48,11 +46,9 @@ class ConfigClient(BaseClient):
             subscriber=subscriber,
             input_queue=input_queue,
             output_queue=output_queue,
-            pulsar_host=pulsar_host,
-            pulsar_api_key=pulsar_api_key,
             input_schema=ConfigRequest,
             output_schema=ConfigResponse,
-            listener=listener,
+            **pubsub_config,
         )

     def get(self, keys, timeout=300):
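The pattern in these client hunks is a plain `**kwargs` passthrough: callers stop naming Pulsar-specific parameters and instead forward whatever backend options they were given down to the factory. A minimal sketch, with a stand-in `get_pubsub` and stripped-down client classes (illustrative, not the project's signatures):

```python
def get_pubsub(pubsub_backend='pulsar', **options):
    # Stand-in for the real factory: just report what it received.
    return (pubsub_backend, options)

class BaseClient:
    def __init__(self, **pubsub_config):
        # Backend selection and connection options flow through untouched.
        self.backend = get_pubsub(**pubsub_config)

class ConfigClient(BaseClient):
    def __init__(self, subscriber=None, **pubsub_config):
        self.subscriber = subscriber
        super().__init__(**pubsub_config)

c = ConfigClient(subscriber='x', pubsub_backend='rabbitmq', host='mq')
```

The upshot is that adding a backend with new connection parameters needs no change to any client signature.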
@@ -24,10 +24,13 @@ from ..core.metadata import Metadata
 # <- (document_metadata)
 # <- (error)

-# get-document-content
+# get-document-content [DEPRECATED — use stream-document instead]
 # -> (document_id)
 # <- (content)
 # <- (error)
+# NOTE: Returns entire document in a single message. Fails for documents
+# exceeding the broker's max message size. Use stream-document which
+# returns content in chunks.

 # add-processing
 # -> (processing_id, processing_metadata)
@@ -220,5 +223,5 @@ class LibrarianResponse:
 # FIXME: Is this right? Using persistence on librarian so that
 # message chunking works

-librarian_request_queue = queue('librarian-request', cls='flow')
-librarian_response_queue = queue('librarian-response', cls='flow')
+librarian_request_queue = queue('librarian', cls='request')
+librarian_response_queue = queue('librarian', cls='response')
@@ -354,10 +354,8 @@ IMPORTANT:
         output_file=args.output,
         subscriber_name=args.subscriber,
         append_mode=args.append,
-        pubsub_backend=args.pubsub_backend,
-        pulsar_host=args.pulsar_host,
-        pulsar_api_key=args.pulsar_api_key,
-        pulsar_listener=args.pulsar_listener,
+        **{k: v for k, v in vars(args).items()
+           if k not in ('queues', 'output', 'subscriber', 'append')},
     ))
 except KeyboardInterrupt:
     # Already handled in async_main
@@ -1,5 +1,8 @@
 """
-Initialises Pulsar with Trustgraph tenant / namespaces & policy.
+Initialises TrustGraph pub/sub infrastructure and pushes initial config.
+
+For Pulsar: creates tenant, namespaces, and retention policies.
+For RabbitMQ: queues are auto-declared, so only config push is needed.
 """

 import requests

@@ -8,10 +11,11 @@ import argparse
 import json

 from trustgraph.clients.config_client import ConfigClient
+from trustgraph.base.pubsub import add_pubsub_args

 default_pulsar_admin_url = "http://pulsar:8080"
-default_pulsar_host = "pulsar://pulsar:6650"
-subscriber = "tg-init-pulsar"
+subscriber = "tg-init-pubsub"

 def get_clusters(url):

@@ -65,12 +69,11 @@ def ensure_namespace(url, tenant, namespace, config):

     print(f"Namespace {tenant}/{namespace} created.", flush=True)

-def ensure_config(config, pulsar_host, pulsar_api_key):
+def ensure_config(config, **pubsub_config):

     cli = ConfigClient(
         subscriber=subscriber,
-        pulsar_host=pulsar_host,
-        pulsar_api_key=pulsar_api_key,
+        **pubsub_config,
     )

     while True:

@@ -115,11 +118,9 @@ def ensure_config(config, pulsar_host, pulsar_api_key):
             time.sleep(2)
             print("Retrying...", flush=True)
             continue

-def init(
-    pulsar_admin_url, pulsar_host, pulsar_api_key, tenant,
-    config, config_file,
-):
+def init_pulsar(pulsar_admin_url, tenant):
+    """Pulsar-specific setup: create tenant, namespaces, retention policies."""

     clusters = get_clusters(pulsar_admin_url)

@@ -145,17 +146,21 @@
         }
     })

-    if config is not None:
+
+def push_config(config_json, config_file, **pubsub_config):
+    """Push initial config if provided."""
+
+    if config_json is not None:

         try:
             print("Decoding config...", flush=True)
-            dec = json.loads(config)
+            dec = json.loads(config_json)
             print("Decoded.", flush=True)
         except Exception as e:
             print("Exception:", e, flush=True)
             raise e

-        ensure_config(dec, pulsar_host, pulsar_api_key)
+        ensure_config(dec, **pubsub_config)

     elif config_file is not None:

@@ -167,11 +172,12 @@
             print("Exception:", e, flush=True)
             raise e

-        ensure_config(dec, pulsar_host, pulsar_api_key)
+        ensure_config(dec, **pubsub_config)

     else:
         print("No config to update.", flush=True)


 def main():

     parser = argparse.ArgumentParser(

@@ -180,22 +186,11 @@
     )

     parser.add_argument(
-        '-p', '--pulsar-admin-url',
+        '--pulsar-admin-url',
         default=default_pulsar_admin_url,
         help=f'Pulsar admin URL (default: {default_pulsar_admin_url})',
     )

-    parser.add_argument(
-        '--pulsar-host',
-        default=default_pulsar_host,
-        help=f'Pulsar host (default: {default_pulsar_host})',
-    )
-
-    parser.add_argument(
-        '--pulsar-api-key',
-        help=f'Pulsar API key',
-    )
-
     parser.add_argument(
         '-c', '--config',
         help=f'Initial configuration to load',

@@ -212,18 +207,43 @@
         help=f'Tenant (default: tg)',
     )

+    add_pubsub_args(parser)
+
     args = parser.parse_args()

+    backend_type = args.pubsub_backend
+
+    # Extract pubsub config from args
+    pubsub_config = {
+        k: v for k, v in vars(args).items()
+        if k not in ('pulsar_admin_url', 'config', 'config_file', 'tenant')
+    }
+
     while True:

         try:

-            print(flush=True)
-            print(
-                f"Initialising with Pulsar {args.pulsar_admin_url}...",
-                flush=True
-            )
-            init(**vars(args))
+            # Pulsar-specific setup (tenants, namespaces)
+            if backend_type == 'pulsar':
+                print(flush=True)
+                print(
+                    f"Initialising Pulsar at {args.pulsar_admin_url}...",
+                    flush=True,
+                )
+                init_pulsar(args.pulsar_admin_url, args.tenant)
+            else:
+                print(flush=True)
+                print(
+                    f"Using {backend_type} backend (no admin setup needed).",
+                    flush=True,
+                )
+
+            # Push config (works with any backend)
+            push_config(
+                args.config, args.config_file,
+                **pubsub_config,
+            )
             print("Initialisation complete.", flush=True)
             break

@@ -236,4 +256,4 @@ def main():
         print("Will retry...", flush=True)

 if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
|
||||||
|
|
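The per-service librarian plumbing removed throughout this patch follows a standard correlate-by-id pattern: each outgoing request carries a UUID in its message properties, and the response consumer resolves a matching `asyncio.Future`. A minimal standalone sketch of that pattern (names are illustrative, not the TrustGraph API):

```python
import asyncio
import uuid

class RequestResponseClient:
    """Correlate async responses to requests via a request_id -> Future map."""

    def __init__(self):
        self.pending = {}  # request_id -> asyncio.Future

    async def request(self, send, payload, timeout=120):
        request_id = str(uuid.uuid4())
        future = asyncio.get_running_loop().create_future()
        self.pending[request_id] = future
        try:
            # Producer sends the payload tagged with the correlation id
            await send(payload, properties={"id": request_id})
            return await asyncio.wait_for(future, timeout=timeout)
        except asyncio.TimeoutError:
            self.pending.pop(request_id, None)
            raise RuntimeError("Timeout waiting for response")

    def on_response(self, request_id, value):
        # Called by the response consumer; unknown/late replies are ignored
        future = self.pending.pop(request_id, None)
        if future is not None:
            future.set_result(value)

async def demo():
    client = RequestResponseClient()

    async def send(payload, properties):
        # Simulate the broker echoing a reply back after a delay
        async def reply():
            await asyncio.sleep(0.01)
            client.on_response(properties["id"], payload.upper())
        asyncio.create_task(reply())

    return await client.request(send, "hello", timeout=5)

print(asyncio.run(demo()))  # HELLO
```

This is the behaviour the new `LibrarianClient` centralizes, replacing the copies deleted below.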
@@ -316,10 +316,8 @@ def main():
             queue_type=args.queue_type,
             max_lines=args.max_lines,
             max_width=args.max_width,
-            pulsar_host=args.pulsar_host,
-            pulsar_api_key=args.pulsar_api_key,
-            pulsar_listener=args.pulsar_listener,
-            pubsub_backend=args.pubsub_backend,
+            **{k: v for k, v in vars(args).items()
+               if k not in ('flow', 'queue_type', 'max_lines', 'max_width')},
         ))
     except KeyboardInterrupt:
         pass
@@ -133,7 +133,7 @@ class Processor(ChunkingService):
             chunk_length = len(chunk.page_content)
 
             # Save chunk to librarian as child document
-            await self.save_child_document(
+            await self.librarian.save_child_document(
                 doc_id=chunk_doc_id,
                 parent_id=parent_doc_id,
                 user=v.metadata.user,
@@ -131,7 +131,7 @@ class Processor(ChunkingService):
             chunk_length = len(chunk.page_content)
 
             # Save chunk to librarian as child document
-            await self.save_child_document(
+            await self.librarian.save_child_document(
                 doc_id=chunk_doc_id,
                 parent_id=parent_doc_id,
                 user=v.metadata.user,
@@ -9,20 +9,16 @@ for large documents.
 from pypdf import PdfWriter, PdfReader
 from io import BytesIO
-import asyncio
 import base64
 import uuid
 import os
 
 from mistralai import Mistral
-from mistralai.models import OCRResponse
 
 from ... schema import Document, TextDocument, Metadata
-from ... schema import LibrarianRequest, LibrarianResponse, DocumentMetadata
 from ... schema import librarian_request_queue, librarian_response_queue
 from ... schema import Triples
-from ... base import FlowProcessor, ConsumerSpec, ProducerSpec
-from ... base import Consumer, Producer, ConsumerMetrics, ProducerMetrics
+from ... base import FlowProcessor, ConsumerSpec, ProducerSpec, LibrarianClient
 
 from ... provenance import (
     document_uri, page_uri as make_page_uri, derived_entity_triples,
 
@@ -102,42 +98,10 @@ class Processor(FlowProcessor):
             )
         )
 
-        # Librarian client for fetching document content
-        librarian_request_q = params.get(
-            "librarian_request_queue", default_librarian_request_queue
+        # Librarian client
+        self.librarian = LibrarianClient(
+            id=id, backend=self.pubsub, taskgroup=self.taskgroup,
         )
-        librarian_response_q = params.get(
-            "librarian_response_queue", default_librarian_response_queue
-        )
-
-        librarian_request_metrics = ProducerMetrics(
-            processor = id, flow = None, name = "librarian-request"
-        )
-
-        self.librarian_request_producer = Producer(
-            backend = self.pubsub,
-            topic = librarian_request_q,
-            schema = LibrarianRequest,
-            metrics = librarian_request_metrics,
-        )
-
-        librarian_response_metrics = ConsumerMetrics(
-            processor = id, flow = None, name = "librarian-response"
-        )
-
-        self.librarian_response_consumer = Consumer(
-            taskgroup = self.taskgroup,
-            backend = self.pubsub,
-            flow = None,
-            topic = librarian_response_q,
-            subscriber = f"{id}-librarian",
-            schema = LibrarianResponse,
-            handler = self.on_librarian_response,
-            metrics = librarian_response_metrics,
-        )
-
-        # Pending librarian requests: request_id -> asyncio.Future
-        self.pending_requests = {}
 
         if api_key is None:
             raise RuntimeError("Mistral API key not specified")
 
@@ -151,132 +115,7 @@ class Processor(FlowProcessor):
 
     async def start(self):
         await super(Processor, self).start()
-        await self.librarian_request_producer.start()
-        await self.librarian_response_consumer.start()
+        await self.librarian.start()
 
-    async def on_librarian_response(self, msg, consumer, flow):
-        """Handle responses from the librarian service."""
-        response = msg.value()
-        request_id = msg.properties().get("id")
-
-        if request_id and request_id in self.pending_requests:
-            future = self.pending_requests.pop(request_id)
-            future.set_result(response)
-
-    async def fetch_document_metadata(self, document_id, user, timeout=120):
-        """
-        Fetch document metadata from librarian via Pulsar.
-        """
-        request_id = str(uuid.uuid4())
-
-        request = LibrarianRequest(
-            operation="get-document-metadata",
-            document_id=document_id,
-            user=user,
-        )
-
-        future = asyncio.get_event_loop().create_future()
-        self.pending_requests[request_id] = future
-
-        try:
-            await self.librarian_request_producer.send(
-                request, properties={"id": request_id}
-            )
-
-            response = await asyncio.wait_for(future, timeout=timeout)
-
-            if response.error:
-                raise RuntimeError(
-                    f"Librarian error: {response.error.type}: {response.error.message}"
-                )
-
-            return response.document_metadata
-
-        except asyncio.TimeoutError:
-            self.pending_requests.pop(request_id, None)
-            raise RuntimeError(f"Timeout fetching metadata for {document_id}")
-
-    async def fetch_document_content(self, document_id, user, timeout=120):
-        """
-        Fetch document content from librarian via Pulsar.
-        """
-        request_id = str(uuid.uuid4())
-
-        request = LibrarianRequest(
-            operation="get-document-content",
-            document_id=document_id,
-            user=user,
-        )
-
-        # Create future for response
-        future = asyncio.get_event_loop().create_future()
-        self.pending_requests[request_id] = future
-
-        try:
-            # Send request
-            await self.librarian_request_producer.send(
-                request, properties={"id": request_id}
-            )
-
-            # Wait for response
-            response = await asyncio.wait_for(future, timeout=timeout)
-
-            if response.error:
-                raise RuntimeError(
-                    f"Librarian error: {response.error.type}: {response.error.message}"
-                )
-
-            return response.content
-
-        except asyncio.TimeoutError:
-            self.pending_requests.pop(request_id, None)
-            raise RuntimeError(f"Timeout fetching document {document_id}")
-
-    async def save_child_document(self, doc_id, parent_id, user, content,
-                                  document_type="page", title=None, timeout=120):
-        """
-        Save a child document to the librarian.
-        """
-        request_id = str(uuid.uuid4())
-
-        doc_metadata = DocumentMetadata(
-            id=doc_id,
-            user=user,
-            kind="text/plain",
-            title=title or doc_id,
-            parent_id=parent_id,
-            document_type=document_type,
-        )
-
-        request = LibrarianRequest(
-            operation="add-child-document",
-            document_metadata=doc_metadata,
-            content=base64.b64encode(content).decode("utf-8"),
-        )
-
-        # Create future for response
-        future = asyncio.get_event_loop().create_future()
-        self.pending_requests[request_id] = future
-
-        try:
-            # Send request
-            await self.librarian_request_producer.send(
-                request, properties={"id": request_id}
-            )
-
-            # Wait for response
-            response = await asyncio.wait_for(future, timeout=timeout)
-
-            if response.error:
-                raise RuntimeError(
-                    f"Librarian error saving child document: {response.error.type}: {response.error.message}"
-                )
-
-            return doc_id
-
-        except asyncio.TimeoutError:
-            self.pending_requests.pop(request_id, None)
-            raise RuntimeError(f"Timeout saving child document {doc_id}")
 
     def ocr(self, blob):
         """
 
@@ -359,7 +198,7 @@ class Processor(FlowProcessor):
 
         # Check MIME type if fetching from librarian
         if v.document_id:
-            doc_meta = await self.fetch_document_metadata(
+            doc_meta = await self.librarian.fetch_document_metadata(
                 document_id=v.document_id,
                 user=v.metadata.user,
             )
 
@@ -374,7 +213,7 @@ class Processor(FlowProcessor):
         # Get PDF content - fetch from librarian or use inline data
         if v.document_id:
             logger.info(f"Fetching document {v.document_id} from librarian...")
-            content = await self.fetch_document_content(
+            content = await self.librarian.fetch_document_content(
                 document_id=v.document_id,
                 user=v.metadata.user,
             )
 
@@ -401,7 +240,7 @@ class Processor(FlowProcessor):
             page_content = markdown.encode("utf-8")
 
             # Save page as child document in librarian
-            await self.save_child_document(
+            await self.librarian.save_child_document(
                 doc_id=page_doc_id,
                 parent_id=source_doc_id,
                 user=v.metadata.user,
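Elsewhere in this PR, `get-document-content` (one large message) is replaced by `stream-document` (a sequence of 1 MB chunks), and the API upload chunk size drops from 5 MB to 3 MB because base64 inflates payloads by roughly 4/3. The chunk-and-reassemble arithmetic can be sketched as follows (illustrative only, not the actual wire protocol):

```python
import base64

CHUNK_SIZE = 1024 * 1024  # 1 MiB of raw bytes per message, under broker limits

def split_chunks(data: bytes, size: int = CHUNK_SIZE):
    """Yield base64-encoded chunks covering at most `size` raw bytes each."""
    for off in range(0, len(data), size):
        yield base64.b64encode(data[off:off + size]).decode("ascii")

def reassemble(chunks):
    """Decode and concatenate chunks back into the original bytes."""
    return b"".join(base64.b64decode(c) for c in chunks)

doc = b"x" * (3 * CHUNK_SIZE + 17)
chunks = list(split_chunks(doc))
assert len(chunks) == 4                 # three full chunks plus a 17-byte tail
assert reassemble(chunks) == doc        # lossless round trip
# Note each encoded chunk is ~4/3 the raw size, hence the 5 MB -> 3 MB limit.
```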
@@ -7,20 +7,16 @@ Supports both inline document data and fetching from librarian via Pulsar
 for large documents.
 """
 
-import asyncio
 import os
 import tempfile
 import base64
 import logging
-import uuid
 from langchain_community.document_loaders import PyPDFLoader
 
 from ... schema import Document, TextDocument, Metadata
-from ... schema import LibrarianRequest, LibrarianResponse, DocumentMetadata
 from ... schema import librarian_request_queue, librarian_response_queue
 from ... schema import Triples
-from ... base import FlowProcessor, ConsumerSpec, ProducerSpec
-from ... base import Consumer, Producer, ConsumerMetrics, ProducerMetrics
+from ... base import FlowProcessor, ConsumerSpec, ProducerSpec, LibrarianClient
 
 from ... provenance import (
     document_uri, page_uri as make_page_uri, derived_entity_triples,
 
@@ -74,187 +70,16 @@ class Processor(FlowProcessor):
             )
         )
 
-        # Librarian client for fetching document content
-        librarian_request_q = params.get(
-            "librarian_request_queue", default_librarian_request_queue
+        # Librarian client
+        self.librarian = LibrarianClient(
+            id=id, backend=self.pubsub, taskgroup=self.taskgroup,
         )
-        librarian_response_q = params.get(
-            "librarian_response_queue", default_librarian_response_queue
-        )
-
-        librarian_request_metrics = ProducerMetrics(
-            processor = id, flow = None, name = "librarian-request"
-        )
-
-        self.librarian_request_producer = Producer(
-            backend = self.pubsub,
-            topic = librarian_request_q,
-            schema = LibrarianRequest,
-            metrics = librarian_request_metrics,
-        )
-
-        librarian_response_metrics = ConsumerMetrics(
-            processor = id, flow = None, name = "librarian-response"
-        )
-
-        self.librarian_response_consumer = Consumer(
-            taskgroup = self.taskgroup,
-            backend = self.pubsub,
-            flow = None,
-            topic = librarian_response_q,
-            subscriber = f"{id}-librarian",
-            schema = LibrarianResponse,
-            handler = self.on_librarian_response,
-            metrics = librarian_response_metrics,
-        )
-
-        # Pending librarian requests: request_id -> asyncio.Future
-        self.pending_requests = {}
 
         logger.info("PDF decoder initialized")
 
     async def start(self):
         await super(Processor, self).start()
-        await self.librarian_request_producer.start()
-        await self.librarian_response_consumer.start()
+        await self.librarian.start()
 
-    async def on_librarian_response(self, msg, consumer, flow):
-        """Handle responses from the librarian service."""
-        response = msg.value()
-        request_id = msg.properties().get("id")
-
-        if request_id and request_id in self.pending_requests:
-            future = self.pending_requests.pop(request_id)
-            future.set_result(response)
-
-    async def fetch_document_metadata(self, document_id, user, timeout=120):
-        """
-        Fetch document metadata from librarian via Pulsar.
-        """
-        request_id = str(uuid.uuid4())
-
-        request = LibrarianRequest(
-            operation="get-document-metadata",
-            document_id=document_id,
-            user=user,
-        )
-
-        future = asyncio.get_event_loop().create_future()
-        self.pending_requests[request_id] = future
-
-        try:
-            await self.librarian_request_producer.send(
-                request, properties={"id": request_id}
-            )
-
-            response = await asyncio.wait_for(future, timeout=timeout)
-
-            if response.error:
-                raise RuntimeError(
-                    f"Librarian error: {response.error.type}: {response.error.message}"
-                )
-
-            return response.document_metadata
-
-        except asyncio.TimeoutError:
-            self.pending_requests.pop(request_id, None)
-            raise RuntimeError(f"Timeout fetching metadata for {document_id}")
-
-    async def fetch_document_content(self, document_id, user, timeout=120):
-        """
-        Fetch document content from librarian via Pulsar.
-        """
-        request_id = str(uuid.uuid4())
-
-        request = LibrarianRequest(
-            operation="get-document-content",
-            document_id=document_id,
-            user=user,
-        )
-
-        # Create future for response
-        future = asyncio.get_event_loop().create_future()
-        self.pending_requests[request_id] = future
-
-        try:
-            # Send request
-            await self.librarian_request_producer.send(
-                request, properties={"id": request_id}
-            )
-
-            # Wait for response
-            response = await asyncio.wait_for(future, timeout=timeout)
-
-            if response.error:
-                raise RuntimeError(
-                    f"Librarian error: {response.error.type}: {response.error.message}"
-                )
-
-            return response.content
-
-        except asyncio.TimeoutError:
-            self.pending_requests.pop(request_id, None)
-            raise RuntimeError(f"Timeout fetching document {document_id}")
-
-    async def save_child_document(self, doc_id, parent_id, user, content,
-                                  document_type="page", title=None, timeout=120):
-        """
-        Save a child document to the librarian.
-
-        Args:
-            doc_id: ID for the new child document
-            parent_id: ID of the parent document
-            user: User ID
-            content: Document content (bytes)
-            document_type: Type of document ("page", "chunk", etc.)
-            title: Optional title
-            timeout: Request timeout in seconds
-
-        Returns:
-            The document ID on success
-        """
-        import base64
-
-        request_id = str(uuid.uuid4())
-
-        doc_metadata = DocumentMetadata(
-            id=doc_id,
-            user=user,
-            kind="text/plain",
-            title=title or doc_id,
-            parent_id=parent_id,
-            document_type=document_type,
-        )
-
-        request = LibrarianRequest(
-            operation="add-child-document",
-            document_metadata=doc_metadata,
-            content=base64.b64encode(content).decode("utf-8"),
-        )
-
-        # Create future for response
-        future = asyncio.get_event_loop().create_future()
-        self.pending_requests[request_id] = future
-
-        try:
-            # Send request
-            await self.librarian_request_producer.send(
-                request, properties={"id": request_id}
-            )
-
-            # Wait for response
-            response = await asyncio.wait_for(future, timeout=timeout)
-
-            if response.error:
-                raise RuntimeError(
-                    f"Librarian error saving child document: {response.error.type}: {response.error.message}"
-                )
-
-            return doc_id
-
-        except asyncio.TimeoutError:
-            self.pending_requests.pop(request_id, None)
-            raise RuntimeError(f"Timeout saving child document {doc_id}")
 
     async def on_message(self, msg, consumer, flow):
 
@@ -266,7 +91,7 @@ class Processor(FlowProcessor):
 
         # Check MIME type if fetching from librarian
         if v.document_id:
-            doc_meta = await self.fetch_document_metadata(
+            doc_meta = await self.librarian.fetch_document_metadata(
                 document_id=v.document_id,
                 user=v.metadata.user,
             )
 
@@ -287,7 +112,7 @@ class Processor(FlowProcessor):
            logger.info(f"Fetching document {v.document_id} from librarian...")
             fp.close()
 
-            content = await self.fetch_document_content(
+            content = await self.librarian.fetch_document_content(
                 document_id=v.document_id,
                 user=v.metadata.user,
             )
 
@@ -323,7 +148,7 @@ class Processor(FlowProcessor):
             page_content = page.page_content.encode("utf-8")
 
             # Save page as child document in librarian
-            await self.save_child_document(
+            await self.librarian.save_child_document(
                 doc_id=page_doc_id,
                 parent_id=source_doc_id,
                 user=v.metadata.user,
@@ -10,7 +10,7 @@ import logging
 import os
 
 from trustgraph.base.logging import setup_logging
-from trustgraph.base.pubsub import get_pubsub
+from trustgraph.base.pubsub import get_pubsub, add_pubsub_args
 
 from . auth import Authenticator
 from . config.receiver import ConfigReceiver
 
@@ -167,30 +167,7 @@ def run():
         help='Service identifier for logging and metrics (default: api-gateway)',
     )
 
-    # Pub/sub backend selection
-    parser.add_argument(
-        '--pubsub-backend',
-        default=os.getenv('PUBSUB_BACKEND', 'pulsar'),
-        choices=['pulsar', 'mqtt'],
-        help='Pub/sub backend (default: pulsar, env: PUBSUB_BACKEND)',
-    )
-
-    parser.add_argument(
-        '-p', '--pulsar-host',
-        default=default_pulsar_host,
-        help=f'Pulsar host (default: {default_pulsar_host})',
-    )
-
-    parser.add_argument(
-        '--pulsar-api-key',
-        default=default_pulsar_api_key,
-        help=f'Pulsar API key',
-    )
-
-    parser.add_argument(
-        '--pulsar-listener',
-        help=f'Pulsar listener (default: none)',
-    )
+    add_pubsub_args(parser)
 
     parser.add_argument(
         '-m', '--prometheus-url',
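The hunks above and below replace each service's hand-rolled Pulsar flags with a shared `add_pubsub_args(parser)` helper. The shape of such a helper might look like this (a hypothetical sketch assuming the flag names in this patch, not the actual TrustGraph signature):

```python
import argparse
import os

def add_pubsub_args(parser, standalone=False):
    """Register shared pub/sub backend options on an argparse parser.

    With standalone=True (CLI tools), hosts default to localhost.
    """
    pulsar_default = "localhost" if standalone else os.getenv("PULSAR_HOST", "pulsar")
    parser.add_argument(
        "--pubsub-backend",
        default=os.getenv("PUBSUB_BACKEND", "pulsar"),
        choices=["pulsar", "rabbitmq"],
        help="Pub/sub backend (default: pulsar, env: PUBSUB_BACKEND)",
    )
    parser.add_argument("--pulsar-host", default=pulsar_default)
    parser.add_argument("--pulsar-api-key")
    parser.add_argument("--rabbitmq-host",
                        default=os.getenv("RABBITMQ_HOST", "localhost"))
    parser.add_argument("--rabbitmq-port", type=int, default=5672)
    return parser

# Usage: every service registers the same flags with one call
args = add_pubsub_args(argparse.ArgumentParser(), standalone=True).parse_args([])
```

The payoff is that `vars(args)` can then be splatted straight into `get_pubsub()` regardless of backend, as the dict-filter idiom in the CLI hunks above does.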
@ -12,22 +12,18 @@ import uuid
|
||||||
|
|
||||||
from ... schema import DocumentRagQuery, DocumentRagResponse, Error
|
from ... schema import DocumentRagQuery, DocumentRagResponse, Error
|
||||||
from ... schema import LibrarianRequest, LibrarianResponse, DocumentMetadata
|
from ... schema import LibrarianRequest, LibrarianResponse, DocumentMetadata
|
||||||
from ... schema import librarian_request_queue, librarian_response_queue
|
|
||||||
from ... schema import Triples, Metadata
|
from ... schema import Triples, Metadata
|
||||||
from ... provenance import GRAPH_RETRIEVAL
|
from ... provenance import GRAPH_RETRIEVAL
|
||||||
from . document_rag import DocumentRag
|
from . document_rag import DocumentRag
|
||||||
from ... base import FlowProcessor, ConsumerSpec, ProducerSpec
|
from ... base import FlowProcessor, ConsumerSpec, ProducerSpec
|
||||||
from ... base import PromptClientSpec, EmbeddingsClientSpec
|
from ... base import PromptClientSpec, EmbeddingsClientSpec
|
||||||
from ... base import DocumentEmbeddingsClientSpec
|
from ... base import DocumentEmbeddingsClientSpec
|
||||||
from ... base import Consumer, Producer
|
from ... base import LibrarianClient
|
||||||
from ... base import ConsumerMetrics, ProducerMetrics
|
|
||||||
|
|
||||||
# Module logger
|
# Module logger
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
default_ident = "document-rag"
|
default_ident = "document-rag"
|
||||||
default_librarian_request_queue = librarian_request_queue
|
|
||||||
default_librarian_response_queue = librarian_response_queue
|
|
||||||
|
|
||||||
class Processor(FlowProcessor):
|
class Processor(FlowProcessor):
|
||||||
|
|
||||||
|
|
@ -89,111 +85,26 @@ class Processor(FlowProcessor):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Librarian client for fetching chunk content from Garage
|
# Librarian client
|
||||||
librarian_request_q = params.get(
|
self.librarian = LibrarianClient(
|
||||||
"librarian_request_queue", default_librarian_request_queue
|
id=id,
|
||||||
)
|
|
||||||
librarian_response_q = params.get(
|
|
||||||
"librarian_response_queue", default_librarian_response_queue
|
|
||||||
)
|
|
||||||
|
|
||||||
librarian_request_metrics = ProducerMetrics(
|
|
||||||
processor=id, flow=None, name="librarian-request"
|
|
||||||
)
|
|
||||||
|
|
||||||
self.librarian_request_producer = Producer(
|
|
||||||
backend=self.pubsub,
|
backend=self.pubsub,
|
||||||
topic=librarian_request_q,
|
|
||||||
schema=LibrarianRequest,
|
|
||||||
metrics=librarian_request_metrics,
|
|
||||||
)
|
|
||||||
|
|
||||||
librarian_response_metrics = ConsumerMetrics(
|
|
||||||
processor=id, flow=None, name="librarian-response"
|
|
||||||
)
|
|
||||||
|
|
||||||
self.librarian_response_consumer = Consumer(
|
|
||||||
taskgroup=self.taskgroup,
|
taskgroup=self.taskgroup,
|
||||||
backend=self.pubsub,
|
|
||||||
flow=None,
|
|
||||||
topic=librarian_response_q,
|
|
||||||
subscriber=f"{id}-librarian",
|
|
||||||
schema=LibrarianResponse,
|
|
||||||
handler=self.on_librarian_response,
|
|
||||||
metrics=librarian_response_metrics,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Pending librarian requests: request_id -> asyncio.Future
|
|
||||||
self.pending_requests = {}
|
|
||||||
|
|
||||||
async def start(self):
|
async def start(self):
|
||||||
await super(Processor, self).start()
|
await super(Processor, self).start()
|
||||||
await self.librarian_request_producer.start()
|
await self.librarian.start()
|
||||||
await self.librarian_response_consumer.start()
|
|
||||||
|
|
||||||
async def on_librarian_response(self, msg, consumer, flow):
|
|
||||||
"""Handle responses from the librarian service."""
|
|
||||||
response = msg.value()
|
|
||||||
request_id = msg.properties().get("id")
|
|
||||||
|
|
||||||
if request_id in self.pending_requests:
|
|
||||||
future = self.pending_requests.pop(request_id)
|
|
||||||
future.set_result(response)
|
|
||||||
|
|
||||||
async def fetch_chunk_content(self, chunk_id, user, timeout=120):
|
async def fetch_chunk_content(self, chunk_id, user, timeout=120):
|
||||||
"""Fetch chunk content from librarian/Garage."""
|
"""Fetch chunk content from librarian. Chunks are small so
|
||||||
import uuid
|
single request-response is fine."""
|
||||||
request_id = str(uuid.uuid4())
|
return await self.librarian.fetch_document_text(
|
||||||
|
document_id=chunk_id, user=user, timeout=timeout,
|
||||||
request = LibrarianRequest(
|
|
||||||
operation="get-document-content",
|
|
||||||
document_id=chunk_id,
|
|
||||||
user=user,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create future for response
|
|
||||||
future = asyncio.get_event_loop().create_future()
|
|
||||||
self.pending_requests[request_id] = future
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Send request
|
|
||||||
await self.librarian_request_producer.send(
|
|
||||||
request, properties={"id": request_id}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Wait for response
|
|
||||||
-            response = await asyncio.wait_for(future, timeout=timeout)
-
-            if response.error:
-                raise RuntimeError(
-                    f"Librarian error: {response.error.type}: {response.error.message}"
-                )
-
-            # Content is base64 encoded
-            content = response.content
-            if isinstance(content, str):
-                content = content.encode('utf-8')
-            return base64.b64decode(content).decode("utf-8")
-
-        except asyncio.TimeoutError:
-            self.pending_requests.pop(request_id, None)
-            raise RuntimeError(f"Timeout fetching chunk {chunk_id}")
-
     async def save_answer_content(self, doc_id, user, content, title=None, timeout=120):
-        """
-        Save answer content to the librarian.
-
-        Args:
-            doc_id: ID for the answer document
-            user: User ID
-            content: Answer text content
-            title: Optional title
-            timeout: Request timeout in seconds
-
-        Returns:
-            The document ID on success
-        """
-        request_id = str(uuid.uuid4())
-
+        """Save answer content to the librarian."""
         doc_metadata = DocumentMetadata(
             id=doc_id,
@@ -211,29 +122,8 @@ class Processor(FlowProcessor):
             user=user,
         )

-        # Create future for response
-        future = asyncio.get_event_loop().create_future()
-        self.pending_requests[request_id] = future
-
-        try:
-            # Send request
-            await self.librarian_request_producer.send(
-                request, properties={"id": request_id}
-            )
-
-            # Wait for response
-            response = await asyncio.wait_for(future, timeout=timeout)
-
-            if response.error:
-                raise RuntimeError(
-                    f"Librarian error saving answer: {response.error.type}: {response.error.message}"
-                )
-
-            return doc_id
-
-        except asyncio.TimeoutError:
-            self.pending_requests.pop(request_id, None)
-            raise RuntimeError(f"Timeout saving answer document {doc_id}")
+        await self.librarian.request(request, timeout=timeout)
+        return doc_id

     async def on_request(self, msg, consumer, flow):

@@ -390,4 +280,3 @@ class Processor(FlowProcessor):
 def run():

     Processor.launch(default_ident, __doc__)
-
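The duplicated code removed above all follows the same request/response correlation pattern that the shared LibrarianClient consolidates: each request is tagged with a UUID, the caller awaits an asyncio.Future keyed by that UUID, and the response consumer resolves the matching future. A minimal sketch of that pattern, assuming an injected `send` transport hook; `RequestResponseClient` and its method names are illustrative, not the actual TrustGraph API:

```python
import asyncio
import uuid

class RequestResponseClient:
    def __init__(self, send):
        self.send = send    # coroutine: (request, properties) -> None
        self.pending = {}   # request_id -> asyncio.Future

    async def request(self, request, timeout=120):
        request_id = str(uuid.uuid4())
        future = asyncio.get_running_loop().create_future()
        self.pending[request_id] = future
        try:
            # Tag the outgoing message so the reply can be correlated
            await self.send(request, {"id": request_id})
            return await asyncio.wait_for(future, timeout=timeout)
        except asyncio.TimeoutError:
            self.pending.pop(request_id, None)
            raise RuntimeError("Timeout waiting for librarian response")

    def on_response(self, response, properties):
        # Called by the response consumer; unmatched ids are stale replies
        future = self.pending.pop(properties.get("id"), None)
        if future is not None:
            future.set_result(response)
```

Centralising this in one class means the timeout cleanup and stale-reply handling are fixed in one place rather than in each of the six-plus services that previously copied it.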
@@ -7,19 +7,15 @@ Supports both inline document data and fetching from librarian via Pulsar
 for large documents.
 """

-import asyncio
 import base64
 import logging
-import uuid
 import pytesseract
 from pdf2image import convert_from_bytes

 from ... schema import Document, TextDocument, Metadata
-from ... schema import LibrarianRequest, LibrarianResponse, DocumentMetadata
 from ... schema import librarian_request_queue, librarian_response_queue
 from ... schema import Triples
-from ... base import FlowProcessor, ConsumerSpec, ProducerSpec
-from ... base import Consumer, Producer, ConsumerMetrics, ProducerMetrics
+from ... base import FlowProcessor, ConsumerSpec, ProducerSpec, LibrarianClient

 from ... provenance import (
     document_uri, page_uri as make_page_uri, derived_entity_triples,

@@ -72,173 +68,16 @@ class Processor(FlowProcessor):
             )
         )

-        # Librarian client for fetching document content
-        librarian_request_q = params.get(
-            "librarian_request_queue", default_librarian_request_queue
-        )
-        librarian_response_q = params.get(
-            "librarian_response_queue", default_librarian_response_queue
-        )
-
-        librarian_request_metrics = ProducerMetrics(
-            processor = id, flow = None, name = "librarian-request"
-        )
-
-        self.librarian_request_producer = Producer(
-            backend = self.pubsub,
-            topic = librarian_request_q,
-            schema = LibrarianRequest,
-            metrics = librarian_request_metrics,
-        )
-
-        librarian_response_metrics = ConsumerMetrics(
-            processor = id, flow = None, name = "librarian-response"
-        )
-
-        self.librarian_response_consumer = Consumer(
-            taskgroup = self.taskgroup,
-            backend = self.pubsub,
-            flow = None,
-            topic = librarian_response_q,
-            subscriber = f"{id}-librarian",
-            schema = LibrarianResponse,
-            handler = self.on_librarian_response,
-            metrics = librarian_response_metrics,
-        )
-
-        # Pending librarian requests: request_id -> asyncio.Future
-        self.pending_requests = {}
+        # Librarian client
+        self.librarian = LibrarianClient(
+            id=id, backend=self.pubsub, taskgroup=self.taskgroup,
+        )

         logger.info("PDF OCR processor initialized")

     async def start(self):
         await super(Processor, self).start()
-        await self.librarian_request_producer.start()
-        await self.librarian_response_consumer.start()
-
-    async def on_librarian_response(self, msg, consumer, flow):
-        """Handle responses from the librarian service."""
-        response = msg.value()
-        request_id = msg.properties().get("id")
-
-        if request_id and request_id in self.pending_requests:
-            future = self.pending_requests.pop(request_id)
-            future.set_result(response)
-
-    async def fetch_document_metadata(self, document_id, user, timeout=120):
-        """
-        Fetch document metadata from librarian via Pulsar.
-        """
-        request_id = str(uuid.uuid4())
-
-        request = LibrarianRequest(
-            operation="get-document-metadata",
-            document_id=document_id,
-            user=user,
-        )
-
-        future = asyncio.get_event_loop().create_future()
-        self.pending_requests[request_id] = future
-
-        try:
-            await self.librarian_request_producer.send(
-                request, properties={"id": request_id}
-            )
-
-            response = await asyncio.wait_for(future, timeout=timeout)
-
-            if response.error:
-                raise RuntimeError(
-                    f"Librarian error: {response.error.type}: {response.error.message}"
-                )
-
-            return response.document_metadata
-
-        except asyncio.TimeoutError:
-            self.pending_requests.pop(request_id, None)
-            raise RuntimeError(f"Timeout fetching metadata for {document_id}")
-
-    async def fetch_document_content(self, document_id, user, timeout=120):
-        """
-        Fetch document content from librarian via Pulsar.
-        """
-        request_id = str(uuid.uuid4())
-
-        request = LibrarianRequest(
-            operation="get-document-content",
-            document_id=document_id,
-            user=user,
-        )
-
-        # Create future for response
-        future = asyncio.get_event_loop().create_future()
-        self.pending_requests[request_id] = future
-
-        try:
-            # Send request
-            await self.librarian_request_producer.send(
-                request, properties={"id": request_id}
-            )
-
-            # Wait for response
-            response = await asyncio.wait_for(future, timeout=timeout)
-
-            if response.error:
-                raise RuntimeError(
-                    f"Librarian error: {response.error.type}: {response.error.message}"
-                )
-
-            return response.content
-
-        except asyncio.TimeoutError:
-            self.pending_requests.pop(request_id, None)
-            raise RuntimeError(f"Timeout fetching document {document_id}")
-
-    async def save_child_document(self, doc_id, parent_id, user, content,
-                                  document_type="page", title=None, timeout=120):
-        """
-        Save a child document to the librarian.
-        """
-        request_id = str(uuid.uuid4())
-
-        doc_metadata = DocumentMetadata(
-            id=doc_id,
-            user=user,
-            kind="text/plain",
-            title=title or doc_id,
-            parent_id=parent_id,
-            document_type=document_type,
-        )
-
-        request = LibrarianRequest(
-            operation="add-child-document",
-            document_metadata=doc_metadata,
-            content=base64.b64encode(content).decode("utf-8"),
-        )
-
-        # Create future for response
-        future = asyncio.get_event_loop().create_future()
-        self.pending_requests[request_id] = future
-
-        try:
-            # Send request
-            await self.librarian_request_producer.send(
-                request, properties={"id": request_id}
-            )
-
-            # Wait for response
-            response = await asyncio.wait_for(future, timeout=timeout)
-
-            if response.error:
-                raise RuntimeError(
-                    f"Librarian error saving child document: {response.error.type}: {response.error.message}"
-                )
-
-            return doc_id
-
-        except asyncio.TimeoutError:
-            self.pending_requests.pop(request_id, None)
-            raise RuntimeError(f"Timeout saving child document {doc_id}")
+        await self.librarian.start()

     async def on_message(self, msg, consumer, flow):

@@ -250,7 +89,7 @@ class Processor(FlowProcessor):

         # Check MIME type if fetching from librarian
         if v.document_id:
-            doc_meta = await self.fetch_document_metadata(
+            doc_meta = await self.librarian.fetch_document_metadata(
                 document_id=v.document_id,
                 user=v.metadata.user,
             )

@@ -265,7 +104,7 @@ class Processor(FlowProcessor):
         # Get PDF content - fetch from librarian or use inline data
         if v.document_id:
             logger.info(f"Fetching document {v.document_id} from librarian...")
-            content = await self.fetch_document_content(
+            content = await self.librarian.fetch_document_content(
                 document_id=v.document_id,
                 user=v.metadata.user,
             )

@@ -299,7 +138,7 @@ class Processor(FlowProcessor):
                 page_content = text.encode("utf-8")

                 # Save page as child document in librarian
-                await self.save_child_document(
+                await self.librarian.save_child_document(
                     doc_id=page_doc_id,
                     parent_id=source_doc_id,
                     user=v.metadata.user,
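The save path above ships document bytes as `base64.b64encode(content).decode("utf-8")`. Base64 inflates payloads by a factor of 4/3 (plus padding), which is why chunk sizes have to be budgeted against the broker limit after encoding, not before: a 3 MB raw chunk becomes 4 MB on the wire. A quick sketch of the arithmetic:

```python
import base64

def encoded_size(raw_size):
    # base64 emits 4 output bytes for every 3 input bytes, rounded up
    return 4 * ((raw_size + 2) // 3)

raw = b"\x00" * (3 * 1024 * 1024)        # 3 MB raw chunk
encoded = base64.b64encode(raw)
assert len(encoded) == encoded_size(len(raw)) == 4 * 1024 * 1024

# Round-trip as used by the librarian operations: encode for transport,
# decode back to bytes on receipt
wire = base64.b64encode(raw).decode("utf-8")
assert base64.b64decode(wire) == raw
```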
@@ -14,22 +14,18 @@ Tables are preserved as HTML markup for better downstream extraction.
 Images are stored in the librarian but not sent to the text pipeline.
 """

-import asyncio
 import base64
 import logging
 import magic
 import tempfile
 import os
-import uuid

 from unstructured.partition.auto import partition

 from ... schema import Document, TextDocument, Metadata
-from ... schema import LibrarianRequest, LibrarianResponse, DocumentMetadata
 from ... schema import librarian_request_queue, librarian_response_queue
 from ... schema import Triples
-from ... base import FlowProcessor, ConsumerSpec, ProducerSpec
-from ... base import Consumer, Producer, ConsumerMetrics, ProducerMetrics
+from ... base import FlowProcessor, ConsumerSpec, ProducerSpec, LibrarianClient

 from ... provenance import (
     document_uri, page_uri as make_page_uri,

@@ -166,128 +162,16 @@ class Processor(FlowProcessor):
             )
         )

-        # Librarian client for fetching/storing document content
-        librarian_request_q = params.get(
-            "librarian_request_queue", default_librarian_request_queue
-        )
-        librarian_response_q = params.get(
-            "librarian_response_queue", default_librarian_response_queue
-        )
-
-        librarian_request_metrics = ProducerMetrics(
-            processor=id, flow=None, name="librarian-request"
-        )
-
-        self.librarian_request_producer = Producer(
-            backend=self.pubsub,
-            topic=librarian_request_q,
-            schema=LibrarianRequest,
-            metrics=librarian_request_metrics,
-        )
-
-        librarian_response_metrics = ConsumerMetrics(
-            processor=id, flow=None, name="librarian-response"
-        )
-
-        self.librarian_response_consumer = Consumer(
-            taskgroup=self.taskgroup,
-            backend=self.pubsub,
-            flow=None,
-            topic=librarian_response_q,
-            subscriber=f"{id}-librarian",
-            schema=LibrarianResponse,
-            handler=self.on_librarian_response,
-            metrics=librarian_response_metrics,
-        )
-
-        # Pending librarian requests: request_id -> asyncio.Future
-        self.pending_requests = {}
+        # Librarian client
+        self.librarian = LibrarianClient(
+            id=id, backend=self.pubsub, taskgroup=self.taskgroup,
+        )

         logger.info("Universal decoder initialized")

     async def start(self):
         await super(Processor, self).start()
-        await self.librarian_request_producer.start()
-        await self.librarian_response_consumer.start()
-
-    async def on_librarian_response(self, msg, consumer, flow):
-        """Handle responses from the librarian service."""
-        response = msg.value()
-        request_id = msg.properties().get("id")
-
-        if request_id and request_id in self.pending_requests:
-            future = self.pending_requests.pop(request_id)
-            future.set_result(response)
-
-    async def _librarian_request(self, request, timeout=120):
-        """Send a request to the librarian and wait for response."""
-        request_id = str(uuid.uuid4())
-
-        future = asyncio.get_event_loop().create_future()
-        self.pending_requests[request_id] = future
-
-        try:
-            await self.librarian_request_producer.send(
-                request, properties={"id": request_id}
-            )
-            response = await asyncio.wait_for(future, timeout=timeout)
-
-            if response.error:
-                raise RuntimeError(
-                    f"Librarian error: {response.error.type}: "
-                    f"{response.error.message}"
-                )
-
-            return response
-
-        except asyncio.TimeoutError:
-            self.pending_requests.pop(request_id, None)
-            raise RuntimeError("Timeout waiting for librarian response")
-
-    async def fetch_document_metadata(self, document_id, user):
-        """Fetch document metadata from the librarian."""
-        request = LibrarianRequest(
-            operation="get-document-metadata",
-            document_id=document_id,
-            user=user,
-        )
-        response = await self._librarian_request(request)
-        return response.document_metadata
-
-    async def fetch_document_content(self, document_id, user):
-        """Fetch document content from the librarian."""
-        request = LibrarianRequest(
-            operation="get-document-content",
-            document_id=document_id,
-            user=user,
-        )
-        response = await self._librarian_request(request)
-        return response.content
-
-    async def save_child_document(self, doc_id, parent_id, user, content,
-                                  document_type="page", title=None,
-                                  kind="text/plain"):
-        """Save a child document to the librarian."""
-        if isinstance(content, str):
-            content = content.encode("utf-8")
-
-        doc_metadata = DocumentMetadata(
-            id=doc_id,
-            user=user,
-            kind=kind,
-            title=title or doc_id,
-            parent_id=parent_id,
-            document_type=document_type,
-        )
-
-        request = LibrarianRequest(
-            operation="add-child-document",
-            document_metadata=doc_metadata,
-            content=base64.b64encode(content).decode("utf-8"),
-        )
-
-        await self._librarian_request(request)
-        return doc_id
+        await self.librarian.start()

     def extract_elements(self, blob, mime_type=None):
         """

@@ -388,7 +272,7 @@ class Processor(FlowProcessor):
                 page_content = text.encode("utf-8")

                 # Save to librarian
-                await self.save_child_document(
+                await self.librarian.save_child_document(
                     doc_id=doc_id,
                     parent_id=parent_doc_id,
                     user=metadata.user,

@@ -469,7 +353,7 @@ class Processor(FlowProcessor):

         # Save to librarian
         if img_content:
-            await self.save_child_document(
+            await self.librarian.save_child_document(
                 doc_id=img_uri,
                 parent_id=parent_doc_id,
                 user=metadata.user,

@@ -518,13 +402,13 @@ class Processor(FlowProcessor):
                 f"Fetching document {v.document_id} from librarian..."
             )

-            doc_meta = await self.fetch_document_metadata(
+            doc_meta = await self.librarian.fetch_document_metadata(
                 document_id=v.document_id,
                 user=v.metadata.user,
            )
             mime_type = doc_meta.kind if doc_meta else None

-            content = await self.fetch_document_content(
+            content = await self.librarian.fetch_document_content(
                 document_id=v.document_id,
                 user=v.metadata.user,
             )
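The fetch paths in these diffs tolerate the librarian returning content either as str or as bytes, normalising to bytes before base64-decoding. A minimal sketch of that defensive decode step, assuming UTF-8 text content (`decode_content` is an illustrative helper, not part of the TrustGraph API):

```python
import base64

def decode_content(content):
    # Responses may carry the base64 payload as str or bytes;
    # b64decode accepts both, but normalising makes the intent explicit
    if isinstance(content, str):
        content = content.encode("utf-8")
    return base64.b64decode(content).decode("utf-8")

encoded = base64.b64encode("hello".encode("utf-8"))
assert decode_content(encoded) == "hello"                   # bytes input
assert decode_content(encoded.decode("utf-8")) == "hello"   # str input
```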