RabbitMQ pub/sub backend with topic exchange architecture (#752)

Adds a RabbitMQ backend as an alternative to Pulsar, selectable via
PUBSUB_BACKEND=rabbitmq. Both backends implement the same PubSubBackend
protocol — no application code changes needed to switch.

RabbitMQ topology:
- Single topic exchange per topicspace (e.g. 'tg')
- Routing key derived from queue class and topic name
- Shared consumers: named queue bound to exchange (competing, round-robin)
- Exclusive consumers: anonymous auto-delete queue (broadcast, each gets
  every message). Used by Subscriber and config push consumer.
- Thread-local producer connections (pika is not thread-safe)
- Push-based consumption via basic_consume with process_data_events
  for heartbeat processing
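The topology above can be sketched with pika's channel API. The function below takes an already-opened channel; the exchange, queue, and routing-key names are illustrative stand-ins, not the identifiers the codebase actually derives.

```python
def declare_topology(ch, topicspace="tg"):
    """Declare the exchange and the two queue styles described above.

    ch is an open pika channel (e.g. from BlockingConnection). Queue and
    routing-key names here are hypothetical examples.
    """
    # One topic exchange per topicspace
    ch.exchange_declare(exchange=topicspace, exchange_type="topic",
                        durable=True)

    # Shared consumer: named durable queue; multiple consumers bound to
    # the same queue compete for messages (round-robin)
    ch.queue_declare(queue="flow.chunk-load", durable=True)
    ch.queue_bind(queue="flow.chunk-load", exchange=topicspace,
                  routing_key="flow.chunk-load")

    # Exclusive consumer: broker-named auto-delete queue; every such
    # queue bound to the same routing key gets a copy (broadcast)
    result = ch.queue_declare(queue="", exclusive=True, auto_delete=True)
    ch.queue_bind(queue=result.method.queue, exchange=topicspace,
                  routing_key="config.push")
    return result.method.queue
```

Broadcast falls out of the exchange model: each Subscriber declares its own anonymous queue, so the exchange fans a matching message out to all of them, while shared consumers drain one common queue.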

Consumer model changes:
- Consumer class creates one backend consumer per concurrent task
  (required for pika thread safety, harmless for Pulsar)
- Consumer class accepts consumer_type parameter
- Subscriber passes consumer_type='exclusive' for broadcast semantics
- Config push consumer uses consumer_type='exclusive' so every
  processor instance receives config updates
- handle_one_from_queue receives consumer as parameter for correct
  per-connection ack/nack
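A minimal sketch of that consumer model, assuming a hypothetical backend with a `create_consumer` method; only `handle_one_from_queue`, `consumer_type`, and the ack/nack method names follow the source:

```python
import asyncio

class Consumer:
    def __init__(self, backend, queue, concurrency=1, consumer_type="shared"):
        self.backend = backend
        self.queue = queue
        self.concurrency = concurrency
        self.consumer_type = consumer_type

    async def run(self):
        # One backend consumer per concurrent task: required because pika
        # connections/channels are not thread-safe; harmless for Pulsar.
        tasks = [
            asyncio.create_task(self.consume_loop(
                self.backend.create_consumer(self.queue, self.consumer_type)))
            for _ in range(self.concurrency)
        ]
        await asyncio.gather(*tasks)

    async def consume_loop(self, backend_consumer):
        async for msg in backend_consumer:
            await self.handle_one_from_queue(msg, backend_consumer)

    async def handle_one_from_queue(self, msg, backend_consumer):
        # Ack/nack on the consumer that received the message, so the
        # operation goes out on the correct connection.
        try:
            await self.process(msg)
            backend_consumer.acknowledge(msg)
        except Exception:
            backend_consumer.negative_acknowledge(msg)
```

Passing the backend consumer into `handle_one_from_queue` is the key change: with one connection per task, acking on a shared `self.consumer` would use the wrong channel.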

LibrarianClient:
- New shared client class replacing duplicated librarian request-response
  code across 6+ services (chunking, decoders, RAG, etc.)
- Uses stream-document instead of get-document-content for fetching
  document content in 1MB chunks (avoids broker message size limits)
- Standalone object (self.librarian = LibrarianClient(...)) not a mixin
- get-document-content marked deprecated in schema and OpenAPI spec
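A sketch of how a stream-document fetch can reassemble content from chunked responses. The `chunk_index`/`is_final` fields follow the schema; the request helper, field spellings, and class shape are otherwise illustrative assumptions, not the actual LibrarianClient API.

```python
class LibrarianClient:
    def __init__(self, request_fn):
        # request_fn sends one librarian request and yields the
        # response messages for it (hypothetical transport hook)
        self.request = request_fn

    async def fetch_document(self, document_id):
        chunks = {}
        async for resp in self.request({
            "operation": "stream-document",
            "document-id": document_id,   # field name assumed
        }):
            chunks[resp["chunk_index"]] = resp["content"]
            if resp["is_final"]:
                break
        # Reassemble in index order; each ~1MB chunk stays well under
        # broker message-size limits, unlike get-document-content
        return b"".join(chunks[i] for i in sorted(chunks))
```

Because the client is a standalone object, a service just sets `self.librarian = LibrarianClient(...)` rather than inheriting the request-response plumbing.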

Serialisation:
- Extracted dataclass_to_dict/dict_to_dataclass to shared
  serialization.py (used by both Pulsar and RabbitMQ backends)
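A minimal sketch of those helpers, assuming flat dataclasses; the shared serialization.py presumably handles nesting and other cases beyond this:

```python
import dataclasses

def dataclass_to_dict(obj):
    # asdict recurses into nested dataclasses, lists, and dicts
    return dataclasses.asdict(obj)

def dict_to_dataclass(cls, data):
    # Ignore unknown keys so schema additions don't break older consumers
    names = {f.name for f in dataclasses.fields(cls)}
    return cls(**{k: v for k, v in data.items() if k in names})
```

Keeping these in one module means a message serialised by the Pulsar backend deserialises identically under RabbitMQ.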

Librarian queues:
- Changed from flow class (persistent) back to request/response class
  now that stream-document eliminates large single messages
- API upload chunk size reduced from 5MB to 3MB to stay under broker
  limits after base64 encoding
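The arithmetic behind the 3MB choice: base64 emits 4 output bytes per 3 input bytes, so a 3MB raw chunk encodes to exactly 4MB, while a 5MB chunk balloons to ~6.7MB. (The precise broker limit is not stated here; the point is the 4/3 expansion.)

```python
import base64

def encoded_size(raw_bytes):
    # base64: 4 output bytes per 3 input bytes, padded to a multiple of 4
    return len(base64.b64encode(b"\0" * raw_bytes))

MB = 1024 * 1024
assert encoded_size(3 * MB) == 4 * MB   # 3MB raw -> exactly 4MB encoded
assert encoded_size(5 * MB) > 6 * MB    # 5MB raw -> ~6.7MB encoded
```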

Factory and CLI:
- get_pubsub() handles 'rabbitmq' backend with RabbitMQ connection params
- add_pubsub_args() includes RabbitMQ options (host, port, credentials)
- add_pubsub_args(standalone=True) defaults to localhost for CLI tools
- init_trustgraph skips Pulsar admin setup for non-Pulsar backends
- tg-dump-queues and tg-monitor-prompts use backend abstraction
- BaseClient and ConfigClient accept generic pubsub config
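A hedged sketch of the factory's backend selection. The PUBSUB_BACKEND variable is from the source; the backend classes and parameter names below are illustrative stand-ins, not the actual TrustGraph signatures.

```python
import os

class PulsarBackend:                      # hypothetical stand-in
    def __init__(self, **params):
        self.params = params

class RabbitMQBackend:                    # hypothetical stand-in
    def __init__(self, host="localhost", port=5672):
        self.host, self.port = host, port

def get_pubsub(backend=None, **params):
    # Both backends implement the same PubSubBackend protocol, so
    # callers switch via configuration alone.
    backend = backend or os.environ.get("PUBSUB_BACKEND", "pulsar")
    if backend == "pulsar":
        return PulsarBackend(**params)
    if backend == "rabbitmq":
        return RabbitMQBackend(
            host=params.get("rabbitmq_host", "localhost"),
            port=int(params.get("rabbitmq_port", 5672)),
        )
    raise ValueError(f"Unknown pub/sub backend: {backend}")
```

Defaulting to localhost when `standalone=True` mirrors the CLI behaviour described above: tools run against a local broker unless told otherwise.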
Commit 24f0190ce7 (parent 4fb0b4d8e8) by cybermaggedon, 2026-04-02 12:47:16 +01:00, committed by GitHub (GPG key ID: B5690EEEBB952194).
36 changed files with 1277 additions and 1313 deletions.


@@ -77,8 +77,8 @@ some-containers:
 		-t ${CONTAINER_BASE}/trustgraph-base:${VERSION} .
 	${DOCKER} build -f containers/Containerfile.flow \
 		-t ${CONTAINER_BASE}/trustgraph-flow:${VERSION} .
-#	${DOCKER} build -f containers/Containerfile.unstructured \
-#		-t ${CONTAINER_BASE}/trustgraph-unstructured:${VERSION} .
+	${DOCKER} build -f containers/Containerfile.unstructured \
+		-t ${CONTAINER_BASE}/trustgraph-unstructured:${VERSION} .
 #	${DOCKER} build -f containers/Containerfile.vertexai \
 #		-t ${CONTAINER_BASE}/trustgraph-vertexai:${VERSION} .
 #	${DOCKER} build -f containers/Containerfile.mcp \


@@ -3,6 +3,9 @@ description: |
   Librarian service request for document library management.
   Operations: add-document, remove-document, list-documents,
+  get-document-metadata, stream-document, add-child-document,
+  list-children, begin-upload, upload-chunk, complete-upload,
+  abort-upload, get-upload-status, list-uploads,
   start-processing, stop-processing, list-processing
 required:
 - operation
@@ -13,6 +16,17 @@ properties:
   - add-document
   - remove-document
   - list-documents
+  - get-document-metadata
+  - get-document-content
+  - stream-document
+  - add-child-document
+  - list-children
+  - begin-upload
+  - upload-chunk
+  - complete-upload
+  - abort-upload
+  - get-upload-status
+  - list-uploads
   - start-processing
   - stop-processing
   - list-processing
@@ -21,6 +35,21 @@ properties:
   - `add-document`: Add document to library
   - `remove-document`: Remove document from library
   - `list-documents`: List documents in library
+  - `get-document-metadata`: Get document metadata
+  - `get-document-content`: Get full document content in a single response.
+    **Deprecated** — use `stream-document` instead. Fails for documents
+    exceeding the broker's max message size.
+  - `stream-document`: Stream document content in chunks. Each response
+    includes `chunk_index` and `is_final`. Preferred over `get-document-content`
+    for all document sizes.
+  - `add-child-document`: Add a child document (e.g. page, chunk)
+  - `list-children`: List child documents of a parent
+  - `begin-upload`: Start a chunked upload session
+  - `upload-chunk`: Upload a chunk of data
+  - `complete-upload`: Finalize a chunked upload
+  - `abort-upload`: Cancel a chunked upload
+  - `get-upload-status`: Check upload progress
+  - `list-uploads`: List active upload sessions
   - `start-processing`: Start processing library documents
   - `stop-processing`: Stop library processing
   - `list-processing`: List processing status


@@ -24,8 +24,8 @@ class MockAsyncProcessor:
 class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
     """Test Recursive chunker functionality"""
-    @patch('trustgraph.base.chunking_service.Consumer')
-    @patch('trustgraph.base.chunking_service.Producer')
+    @patch('trustgraph.base.librarian_client.Consumer')
+    @patch('trustgraph.base.librarian_client.Producer')
     @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
     def test_processor_initialization_basic(self, mock_producer, mock_consumer):
         """Test basic processor initialization"""
@@ -51,8 +51,8 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
                        if hasattr(spec, 'name') and spec.name in ['chunk-size', 'chunk-overlap']]
         assert len(param_specs) == 2
-    @patch('trustgraph.base.chunking_service.Consumer')
-    @patch('trustgraph.base.chunking_service.Producer')
+    @patch('trustgraph.base.librarian_client.Consumer')
+    @patch('trustgraph.base.librarian_client.Producer')
     @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
     async def test_chunk_document_with_chunk_size_override(self, mock_producer, mock_consumer):
         """Test chunk_document with chunk-size parameter override"""
@@ -71,7 +71,7 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
         mock_message = MagicMock()
         mock_consumer = MagicMock()
         mock_flow = MagicMock()
-        mock_flow.side_effect = lambda param: {
+        mock_flow.parameters.get.side_effect = lambda param: {
             "chunk-size": 2000,    # Override chunk size
             "chunk-overlap": None  # Use default chunk overlap
         }.get(param)
@@ -85,8 +85,8 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
         assert chunk_size == 2000    # Should use overridden value
         assert chunk_overlap == 100  # Should use default value
-    @patch('trustgraph.base.chunking_service.Consumer')
-    @patch('trustgraph.base.chunking_service.Producer')
+    @patch('trustgraph.base.librarian_client.Consumer')
+    @patch('trustgraph.base.librarian_client.Producer')
     @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
     async def test_chunk_document_with_chunk_overlap_override(self, mock_producer, mock_consumer):
         """Test chunk_document with chunk-overlap parameter override"""
@@ -105,7 +105,7 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
         mock_message = MagicMock()
         mock_consumer = MagicMock()
         mock_flow = MagicMock()
-        mock_flow.side_effect = lambda param: {
+        mock_flow.parameters.get.side_effect = lambda param: {
             "chunk-size": None,   # Use default chunk size
             "chunk-overlap": 200  # Override chunk overlap
         }.get(param)
@@ -119,8 +119,8 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
         assert chunk_size == 1000    # Should use default value
         assert chunk_overlap == 200  # Should use overridden value
-    @patch('trustgraph.base.chunking_service.Consumer')
-    @patch('trustgraph.base.chunking_service.Producer')
+    @patch('trustgraph.base.librarian_client.Consumer')
+    @patch('trustgraph.base.librarian_client.Producer')
     @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
     async def test_chunk_document_with_both_parameters_override(self, mock_producer, mock_consumer):
         """Test chunk_document with both chunk-size and chunk-overlap overrides"""
@@ -139,7 +139,7 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
         mock_message = MagicMock()
         mock_consumer = MagicMock()
         mock_flow = MagicMock()
-        mock_flow.side_effect = lambda param: {
+        mock_flow.parameters.get.side_effect = lambda param: {
             "chunk-size": 1500,   # Override chunk size
             "chunk-overlap": 150  # Override chunk overlap
         }.get(param)
@@ -153,8 +153,8 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
         assert chunk_size == 1500    # Should use overridden value
         assert chunk_overlap == 150  # Should use overridden value
-    @patch('trustgraph.base.chunking_service.Consumer')
-    @patch('trustgraph.base.chunking_service.Producer')
+    @patch('trustgraph.base.librarian_client.Consumer')
+    @patch('trustgraph.base.librarian_client.Producer')
     @patch('trustgraph.chunking.recursive.chunker.RecursiveCharacterTextSplitter')
     @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
     async def test_on_message_uses_flow_parameters(self, mock_splitter_class, mock_producer, mock_consumer):
@@ -177,7 +177,7 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
         processor = Processor(**config)
         # Mock save_child_document to avoid waiting for librarian response
-        processor.save_child_document = AsyncMock(return_value="mock-doc-id")
+        processor.librarian.save_child_document = AsyncMock(return_value="mock-doc-id")
         # Mock message with TextDocument
         mock_message = MagicMock()
@@ -196,12 +196,14 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
         mock_producer = AsyncMock()
         mock_triples_producer = AsyncMock()
         mock_flow = MagicMock()
-        mock_flow.side_effect = lambda param: {
+        mock_flow.parameters.get.side_effect = lambda param: {
             "chunk-size": 1500,
             "chunk-overlap": 150,
+        }.get(param)
+        mock_flow.side_effect = lambda name: {
             "output": mock_producer,
             "triples": mock_triples_producer,
-        }.get(param)
+        }.get(name)
         # Act
         await processor.on_message(mock_message, mock_consumer, mock_flow)
@@ -219,8 +221,8 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
         sent_chunk = mock_producer.send.call_args[0][0]
         assert isinstance(sent_chunk, Chunk)
-    @patch('trustgraph.base.chunking_service.Consumer')
-    @patch('trustgraph.base.chunking_service.Producer')
+    @patch('trustgraph.base.librarian_client.Consumer')
+    @patch('trustgraph.base.librarian_client.Producer')
     @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
     async def test_chunk_document_with_no_overrides(self, mock_producer, mock_consumer):
         """Test chunk_document when no parameters are overridden (flow returns None)"""
@@ -239,7 +241,7 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
         mock_message = MagicMock()
         mock_consumer = MagicMock()
         mock_flow = MagicMock()
-        mock_flow.return_value = None  # No overrides
+        mock_flow.parameters.get.return_value = None  # No overrides
         # Act
         chunk_size, chunk_overlap = await processor.chunk_document(


@@ -24,8 +24,8 @@ class MockAsyncProcessor:
 class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
     """Test Token chunker functionality"""
-    @patch('trustgraph.base.chunking_service.Consumer')
-    @patch('trustgraph.base.chunking_service.Producer')
+    @patch('trustgraph.base.librarian_client.Consumer')
+    @patch('trustgraph.base.librarian_client.Producer')
     @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
     def test_processor_initialization_basic(self, mock_producer, mock_consumer):
         """Test basic processor initialization"""
@@ -51,8 +51,8 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
                        if hasattr(spec, 'name') and spec.name in ['chunk-size', 'chunk-overlap']]
         assert len(param_specs) == 2
-    @patch('trustgraph.base.chunking_service.Consumer')
-    @patch('trustgraph.base.chunking_service.Producer')
+    @patch('trustgraph.base.librarian_client.Consumer')
+    @patch('trustgraph.base.librarian_client.Producer')
     @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
     async def test_chunk_document_with_chunk_size_override(self, mock_producer, mock_consumer):
         """Test chunk_document with chunk-size parameter override"""
@@ -71,7 +71,7 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
         mock_message = MagicMock()
         mock_consumer = MagicMock()
         mock_flow = MagicMock()
-        mock_flow.side_effect = lambda param: {
+        mock_flow.parameters.get.side_effect = lambda param: {
             "chunk-size": 400,     # Override chunk size
             "chunk-overlap": None  # Use default chunk overlap
         }.get(param)
@@ -85,8 +85,8 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
         assert chunk_size == 400    # Should use overridden value
         assert chunk_overlap == 15  # Should use default value
-    @patch('trustgraph.base.chunking_service.Consumer')
-    @patch('trustgraph.base.chunking_service.Producer')
+    @patch('trustgraph.base.librarian_client.Consumer')
+    @patch('trustgraph.base.librarian_client.Producer')
     @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
     async def test_chunk_document_with_chunk_overlap_override(self, mock_producer, mock_consumer):
         """Test chunk_document with chunk-overlap parameter override"""
@@ -105,7 +105,7 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
         mock_message = MagicMock()
         mock_consumer = MagicMock()
         mock_flow = MagicMock()
-        mock_flow.side_effect = lambda param: {
+        mock_flow.parameters.get.side_effect = lambda param: {
             "chunk-size": None,  # Use default chunk size
             "chunk-overlap": 25  # Override chunk overlap
         }.get(param)
@@ -119,8 +119,8 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
         assert chunk_size == 250    # Should use default value
         assert chunk_overlap == 25  # Should use overridden value
-    @patch('trustgraph.base.chunking_service.Consumer')
-    @patch('trustgraph.base.chunking_service.Producer')
+    @patch('trustgraph.base.librarian_client.Consumer')
+    @patch('trustgraph.base.librarian_client.Producer')
     @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
     async def test_chunk_document_with_both_parameters_override(self, mock_producer, mock_consumer):
         """Test chunk_document with both chunk-size and chunk-overlap overrides"""
@@ -139,7 +139,7 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
         mock_message = MagicMock()
         mock_consumer = MagicMock()
         mock_flow = MagicMock()
-        mock_flow.side_effect = lambda param: {
+        mock_flow.parameters.get.side_effect = lambda param: {
             "chunk-size": 350,   # Override chunk size
             "chunk-overlap": 30  # Override chunk overlap
         }.get(param)
@@ -153,8 +153,8 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
         assert chunk_size == 350    # Should use overridden value
         assert chunk_overlap == 30  # Should use overridden value
-    @patch('trustgraph.base.chunking_service.Consumer')
-    @patch('trustgraph.base.chunking_service.Producer')
+    @patch('trustgraph.base.librarian_client.Consumer')
+    @patch('trustgraph.base.librarian_client.Producer')
     @patch('trustgraph.chunking.token.chunker.TokenTextSplitter')
     @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
     async def test_on_message_uses_flow_parameters(self, mock_splitter_class, mock_producer, mock_consumer):
@@ -177,7 +177,7 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
         processor = Processor(**config)
         # Mock save_child_document to avoid librarian producer interactions
-        processor.save_child_document = AsyncMock(return_value="chunk-id")
+        processor.librarian.save_child_document = AsyncMock(return_value="chunk-id")
         # Mock message with TextDocument
         mock_message = MagicMock()
@@ -196,12 +196,14 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
         mock_producer = AsyncMock()
         mock_triples_producer = AsyncMock()
         mock_flow = MagicMock()
-        mock_flow.side_effect = lambda param: {
+        mock_flow.parameters.get.side_effect = lambda param: {
             "chunk-size": 400,
             "chunk-overlap": 40,
+        }.get(param)
+        mock_flow.side_effect = lambda name: {
             "output": mock_producer,
             "triples": mock_triples_producer,
-        }.get(param)
+        }.get(name)
         # Act
         await processor.on_message(mock_message, mock_consumer, mock_flow)
@@ -223,8 +225,8 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
         sent_chunk = mock_producer.send.call_args[0][0]
         assert isinstance(sent_chunk, Chunk)
-    @patch('trustgraph.base.chunking_service.Consumer')
-    @patch('trustgraph.base.chunking_service.Producer')
+    @patch('trustgraph.base.librarian_client.Consumer')
+    @patch('trustgraph.base.librarian_client.Producer')
     @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
     async def test_chunk_document_with_no_overrides(self, mock_producer, mock_consumer):
         """Test chunk_document when no parameters are overridden (flow returns None)"""
@@ -243,7 +245,7 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
         mock_message = MagicMock()
         mock_consumer = MagicMock()
         mock_flow = MagicMock()
-        mock_flow.return_value = None  # No overrides
+        mock_flow.parameters.get.return_value = None  # No overrides
         # Act
         chunk_size, chunk_overlap = await processor.chunk_document(
@@ -254,8 +256,8 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
         assert chunk_size == 250    # Should use default value
         assert chunk_overlap == 15  # Should use default value
-    @patch('trustgraph.base.chunking_service.Consumer')
-    @patch('trustgraph.base.chunking_service.Producer')
+    @patch('trustgraph.base.librarian_client.Consumer')
+    @patch('trustgraph.base.librarian_client.Producer')
     @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
     def test_token_chunker_uses_different_defaults(self, mock_producer, mock_consumer):
         """Test that token chunker has different defaults than recursive chunker"""


@@ -83,7 +83,7 @@ class TestTaskGroupConcurrency:
         call_count = 0
         original_running = True
-        async def mock_consume():
+        async def mock_consume(backend_consumer):
             nonlocal call_count
             call_count += 1
             # Wait a bit to let all tasks start, then signal stop
@@ -107,7 +107,7 @@ class TestTaskGroupConcurrency:
         consumer = _make_consumer(concurrency=1)
         call_count = 0
-        async def mock_consume():
+        async def mock_consume(backend_consumer):
             nonlocal call_count
             call_count += 1
             await asyncio.sleep(0.01)
@@ -147,7 +147,7 @@ class TestRateLimitRetry:
         mock_msg = _make_msg()
         consumer.consumer = MagicMock()
-        await consumer.handle_one_from_queue(mock_msg)
+        await consumer.handle_one_from_queue(mock_msg, consumer.consumer)
         assert call_count == 2
         consumer.consumer.acknowledge.assert_called_once_with(mock_msg)
@@ -166,7 +166,7 @@ class TestRateLimitRetry:
         mock_msg = _make_msg()
         consumer.consumer = MagicMock()
-        await consumer.handle_one_from_queue(mock_msg)
+        await consumer.handle_one_from_queue(mock_msg, consumer.consumer)
         consumer.consumer.negative_acknowledge.assert_called_with(mock_msg)
         consumer.consumer.acknowledge.assert_not_called()
@@ -185,7 +185,7 @@ class TestRateLimitRetry:
         mock_msg = _make_msg()
         consumer.consumer = MagicMock()
-        await consumer.handle_one_from_queue(mock_msg)
+        await consumer.handle_one_from_queue(mock_msg, consumer.consumer)
         assert call_count == 1
         consumer.consumer.negative_acknowledge.assert_called_once_with(mock_msg)
@@ -197,7 +197,7 @@ class TestRateLimitRetry:
         mock_msg = _make_msg()
         consumer.consumer = MagicMock()
-        await consumer.handle_one_from_queue(mock_msg)
+        await consumer.handle_one_from_queue(mock_msg, consumer.consumer)
         consumer.consumer.acknowledge.assert_called_once_with(mock_msg)
@@ -219,7 +219,7 @@ class TestMetricsIntegration:
         mock_metrics.record_time.return_value.__exit__ = MagicMock()
         consumer.metrics = mock_metrics
-        await consumer.handle_one_from_queue(mock_msg)
+        await consumer.handle_one_from_queue(mock_msg, consumer.consumer)
         mock_metrics.process.assert_called_once_with("success")
@@ -235,7 +235,7 @@ class TestMetricsIntegration:
         mock_metrics = MagicMock()
         consumer.metrics = mock_metrics
-        await consumer.handle_one_from_queue(mock_msg)
+        await consumer.handle_one_from_queue(mock_msg, consumer.consumer)
         mock_metrics.process.assert_called_once_with("error")
@@ -261,7 +261,7 @@ class TestMetricsIntegration:
         mock_metrics.record_time.return_value.__exit__ = MagicMock(return_value=False)
         consumer.metrics = mock_metrics
-        await consumer.handle_one_from_queue(mock_msg)
+        await consumer.handle_one_from_queue(mock_msg, consumer.consumer)
         mock_metrics.rate_limit.assert_called_once()


@@ -25,8 +25,8 @@ class MockAsyncProcessor:
 class TestMistralOcrProcessor(IsolatedAsyncioTestCase):
     """Test Mistral OCR processor functionality"""
-    @patch('trustgraph.decoding.mistral_ocr.processor.Consumer')
-    @patch('trustgraph.decoding.mistral_ocr.processor.Producer')
+    @patch('trustgraph.base.librarian_client.Consumer')
+    @patch('trustgraph.base.librarian_client.Producer')
     @patch('trustgraph.decoding.mistral_ocr.processor.Mistral')
     @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
     async def test_processor_initialization_with_api_key(
@@ -51,8 +51,8 @@ class TestMistralOcrProcessor(IsolatedAsyncioTestCase):
         assert consumer_specs[0].name == "input"
         assert consumer_specs[0].schema == Document
-    @patch('trustgraph.decoding.mistral_ocr.processor.Consumer')
-    @patch('trustgraph.decoding.mistral_ocr.processor.Producer')
+    @patch('trustgraph.base.librarian_client.Consumer')
+    @patch('trustgraph.base.librarian_client.Producer')
     @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
     async def test_processor_initialization_without_api_key(
         self, mock_producer, mock_consumer
@@ -66,8 +66,8 @@ class TestMistralOcrProcessor(IsolatedAsyncioTestCase):
         with pytest.raises(RuntimeError, match="Mistral API key not specified"):
             Processor(**config)
-    @patch('trustgraph.decoding.mistral_ocr.processor.Consumer')
-    @patch('trustgraph.decoding.mistral_ocr.processor.Producer')
+    @patch('trustgraph.base.librarian_client.Consumer')
+    @patch('trustgraph.base.librarian_client.Producer')
     @patch('trustgraph.decoding.mistral_ocr.processor.Mistral')
     @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
     async def test_ocr_single_chunk(
@@ -131,8 +131,8 @@ class TestMistralOcrProcessor(IsolatedAsyncioTestCase):
         )
         mock_mistral.ocr.process.assert_called_once()
-    @patch('trustgraph.decoding.mistral_ocr.processor.Consumer')
-    @patch('trustgraph.decoding.mistral_ocr.processor.Producer')
+    @patch('trustgraph.base.librarian_client.Consumer')
+    @patch('trustgraph.base.librarian_client.Producer')
     @patch('trustgraph.decoding.mistral_ocr.processor.Mistral')
     @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
     async def test_on_message_success(
@@ -172,7 +172,7 @@ class TestMistralOcrProcessor(IsolatedAsyncioTestCase):
         ]
         # Mock save_child_document
-        processor.save_child_document = AsyncMock(return_value="mock-doc-id")
+        processor.librarian.save_child_document = AsyncMock(return_value="mock-doc-id")
         with patch.object(processor, 'ocr', return_value=ocr_result):
             await processor.on_message(mock_msg, None, mock_flow)


@ -24,12 +24,10 @@ class MockAsyncProcessor:
class TestPdfDecoderProcessor(IsolatedAsyncioTestCase): class TestPdfDecoderProcessor(IsolatedAsyncioTestCase):
"""Test PDF decoder processor functionality""" """Test PDF decoder processor functionality"""
@patch('trustgraph.base.chunking_service.Consumer') @patch('trustgraph.base.librarian_client.Consumer')
@patch('trustgraph.base.chunking_service.Producer') @patch('trustgraph.base.librarian_client.Producer')
@patch('trustgraph.decoding.pdf.pdf_decoder.Consumer')
@patch('trustgraph.decoding.pdf.pdf_decoder.Producer')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_processor_initialization(self, mock_producer, mock_consumer, mock_cs_producer, mock_cs_consumer): async def test_processor_initialization(self, mock_producer, mock_consumer):
"""Test PDF decoder processor initialization""" """Test PDF decoder processor initialization"""
config = { config = {
'id': 'test-pdf-decoder', 'id': 'test-pdf-decoder',
@ -44,13 +42,11 @@ class TestPdfDecoderProcessor(IsolatedAsyncioTestCase):
assert consumer_specs[0].name == "input" assert consumer_specs[0].name == "input"
assert consumer_specs[0].schema == Document assert consumer_specs[0].schema == Document
@patch('trustgraph.base.chunking_service.Consumer') @patch('trustgraph.base.librarian_client.Consumer')
@patch('trustgraph.base.chunking_service.Producer') @patch('trustgraph.base.librarian_client.Producer')
@patch('trustgraph.decoding.pdf.pdf_decoder.Consumer')
@patch('trustgraph.decoding.pdf.pdf_decoder.Producer')
@patch('trustgraph.decoding.pdf.pdf_decoder.PyPDFLoader') @patch('trustgraph.decoding.pdf.pdf_decoder.PyPDFLoader')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_on_message_success(self, mock_pdf_loader_class, mock_producer, mock_consumer, mock_cs_producer, mock_cs_consumer): async def test_on_message_success(self, mock_pdf_loader_class, mock_producer, mock_consumer):
"""Test successful PDF processing""" """Test successful PDF processing"""
# Mock PDF content # Mock PDF content
pdf_content = b"fake pdf content" pdf_content = b"fake pdf content"
@ -85,7 +81,7 @@ class TestPdfDecoderProcessor(IsolatedAsyncioTestCase):
processor = Processor(**config) processor = Processor(**config)
# Mock save_child_document to avoid waiting for librarian response # Mock save_child_document to avoid waiting for librarian response
processor.save_child_document = AsyncMock(return_value="mock-doc-id") processor.librarian.save_child_document = AsyncMock(return_value="mock-doc-id")
await processor.on_message(mock_msg, None, mock_flow) await processor.on_message(mock_msg, None, mock_flow)
@ -94,13 +90,11 @@ class TestPdfDecoderProcessor(IsolatedAsyncioTestCase):
# Verify triples were sent for each page (provenance) # Verify triples were sent for each page (provenance)
assert mock_triples_flow.send.call_count == 2 assert mock_triples_flow.send.call_count == 2
@patch('trustgraph.base.chunking_service.Consumer') @patch('trustgraph.base.librarian_client.Consumer')
@patch('trustgraph.base.chunking_service.Producer') @patch('trustgraph.base.librarian_client.Producer')
@patch('trustgraph.decoding.pdf.pdf_decoder.Consumer')
@patch('trustgraph.decoding.pdf.pdf_decoder.Producer')
@patch('trustgraph.decoding.pdf.pdf_decoder.PyPDFLoader') @patch('trustgraph.decoding.pdf.pdf_decoder.PyPDFLoader')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_on_message_empty_pdf(self, mock_pdf_loader_class, mock_producer, mock_consumer, mock_cs_producer, mock_cs_consumer): async def test_on_message_empty_pdf(self, mock_pdf_loader_class, mock_producer, mock_consumer):
"""Test handling of empty PDF""" """Test handling of empty PDF"""
pdf_content = b"fake pdf content" pdf_content = b"fake pdf content"
pdf_base64 = base64.b64encode(pdf_content).decode('utf-8') pdf_base64 = base64.b64encode(pdf_content).decode('utf-8')
@ -128,13 +122,11 @@ class TestPdfDecoderProcessor(IsolatedAsyncioTestCase):
mock_output_flow.send.assert_not_called() mock_output_flow.send.assert_not_called()
@patch('trustgraph.base.chunking_service.Consumer') @patch('trustgraph.base.librarian_client.Consumer')
@patch('trustgraph.base.chunking_service.Producer') @patch('trustgraph.base.librarian_client.Producer')
@patch('trustgraph.decoding.pdf.pdf_decoder.Consumer')
@patch('trustgraph.decoding.pdf.pdf_decoder.Producer')
@patch('trustgraph.decoding.pdf.pdf_decoder.PyPDFLoader') @patch('trustgraph.decoding.pdf.pdf_decoder.PyPDFLoader')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_on_message_unicode_content(self, mock_pdf_loader_class, mock_producer, mock_consumer, mock_cs_producer, mock_cs_consumer): async def test_on_message_unicode_content(self, mock_pdf_loader_class, mock_producer, mock_consumer):
"""Test handling of unicode content in PDF""" """Test handling of unicode content in PDF"""
pdf_content = b"fake pdf content" pdf_content = b"fake pdf content"
pdf_base64 = base64.b64encode(pdf_content).decode('utf-8') pdf_base64 = base64.b64encode(pdf_content).decode('utf-8')
@ -165,7 +157,7 @@ class TestPdfDecoderProcessor(IsolatedAsyncioTestCase):
processor = Processor(**config) processor = Processor(**config)
# Mock save_child_document to avoid waiting for librarian response # Mock save_child_document to avoid waiting for librarian response
processor.save_child_document = AsyncMock(return_value="mock-doc-id") processor.librarian.save_child_document = AsyncMock(return_value="mock-doc-id")
await processor.on_message(mock_msg, None, mock_flow) await processor.on_message(mock_msg, None, mock_flow)


@ -142,8 +142,8 @@ class TestPageBasedFormats:
class TestUniversalProcessor(IsolatedAsyncioTestCase): class TestUniversalProcessor(IsolatedAsyncioTestCase):
"""Test universal decoder processor.""" """Test universal decoder processor."""
@patch('trustgraph.decoding.universal.processor.Consumer') @patch('trustgraph.base.librarian_client.Consumer')
@patch('trustgraph.decoding.universal.processor.Producer') @patch('trustgraph.base.librarian_client.Producer')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_processor_initialization( async def test_processor_initialization(
self, mock_producer, mock_consumer self, mock_producer, mock_consumer
@ -169,8 +169,8 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
assert consumer_specs[0].name == "input" assert consumer_specs[0].name == "input"
assert consumer_specs[0].schema == Document assert consumer_specs[0].schema == Document
@patch('trustgraph.decoding.universal.processor.Consumer') @patch('trustgraph.base.librarian_client.Consumer')
@patch('trustgraph.decoding.universal.processor.Producer') @patch('trustgraph.base.librarian_client.Producer')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_processor_custom_strategy( async def test_processor_custom_strategy(
self, mock_producer, mock_consumer self, mock_producer, mock_consumer
@ -188,8 +188,8 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
assert processor.partition_strategy == "hi_res" assert processor.partition_strategy == "hi_res"
assert processor.section_strategy_name == "heading" assert processor.section_strategy_name == "heading"
@patch('trustgraph.decoding.universal.processor.Consumer') @patch('trustgraph.base.librarian_client.Consumer')
@patch('trustgraph.decoding.universal.processor.Producer') @patch('trustgraph.base.librarian_client.Producer')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_group_by_page(self, mock_producer, mock_consumer): async def test_group_by_page(self, mock_producer, mock_consumer):
"""Test page grouping of elements.""" """Test page grouping of elements."""
@ -214,8 +214,8 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
assert result[1][0] == 2 assert result[1][0] == 2
assert len(result[1][1]) == 1 assert len(result[1][1]) == 1
@patch('trustgraph.decoding.universal.processor.Consumer') @patch('trustgraph.base.librarian_client.Consumer')
@patch('trustgraph.decoding.universal.processor.Producer') @patch('trustgraph.base.librarian_client.Producer')
@patch('trustgraph.decoding.universal.processor.partition') @patch('trustgraph.decoding.universal.processor.partition')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_on_message_inline_non_page( async def test_on_message_inline_non_page(
@ -255,7 +255,7 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
}.get(name)) }.get(name))
# Mock save_child_document and magic # Mock save_child_document and magic
processor.save_child_document = AsyncMock(return_value="mock-id") processor.librarian.save_child_document = AsyncMock(return_value="mock-id")
with patch('trustgraph.decoding.universal.processor.magic') as mock_magic: with patch('trustgraph.decoding.universal.processor.magic') as mock_magic:
mock_magic.from_buffer.return_value = "text/markdown" mock_magic.from_buffer.return_value = "text/markdown"
@ -271,8 +271,8 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
assert call_args.document_id.startswith("urn:section:") assert call_args.document_id.startswith("urn:section:")
assert call_args.text == b"" assert call_args.text == b""
@patch('trustgraph.decoding.universal.processor.Consumer') @patch('trustgraph.base.librarian_client.Consumer')
@patch('trustgraph.decoding.universal.processor.Producer') @patch('trustgraph.base.librarian_client.Producer')
@patch('trustgraph.decoding.universal.processor.partition') @patch('trustgraph.decoding.universal.processor.partition')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_on_message_page_based( async def test_on_message_page_based(
@ -310,7 +310,7 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
"triples": mock_triples_flow, "triples": mock_triples_flow,
}.get(name)) }.get(name))
processor.save_child_document = AsyncMock(return_value="mock-id") processor.librarian.save_child_document = AsyncMock(return_value="mock-id")
with patch('trustgraph.decoding.universal.processor.magic') as mock_magic: with patch('trustgraph.decoding.universal.processor.magic') as mock_magic:
mock_magic.from_buffer.return_value = "application/pdf" mock_magic.from_buffer.return_value = "application/pdf"
@ -323,8 +323,8 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
call_args = mock_output_flow.send.call_args_list[0][0][0] call_args = mock_output_flow.send.call_args_list[0][0][0]
assert call_args.document_id.startswith("urn:page:") assert call_args.document_id.startswith("urn:page:")
@patch('trustgraph.decoding.universal.processor.Consumer') @patch('trustgraph.base.librarian_client.Consumer')
@patch('trustgraph.decoding.universal.processor.Producer') @patch('trustgraph.base.librarian_client.Producer')
@patch('trustgraph.decoding.universal.processor.partition') @patch('trustgraph.decoding.universal.processor.partition')
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor)
async def test_images_stored_not_emitted( async def test_images_stored_not_emitted(
@ -361,7 +361,7 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
"triples": mock_triples_flow, "triples": mock_triples_flow,
}.get(name)) }.get(name))
processor.save_child_document = AsyncMock(return_value="mock-id") processor.librarian.save_child_document = AsyncMock(return_value="mock-id")
with patch('trustgraph.decoding.universal.processor.magic') as mock_magic: with patch('trustgraph.decoding.universal.processor.magic') as mock_magic:
mock_magic.from_buffer.return_value = "application/pdf" mock_magic.from_buffer.return_value = "application/pdf"
@ -374,7 +374,7 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
assert mock_triples_flow.send.call_count == 2 assert mock_triples_flow.send.call_count == 2
# save_child_document called twice (page + image) # save_child_document called twice (page + image)
assert processor.save_child_document.call_count == 2 assert processor.librarian.save_child_document.call_count == 2
@patch('trustgraph.base.flow_processor.FlowProcessor.add_args') @patch('trustgraph.base.flow_processor.FlowProcessor.add_args')
def test_add_args(self, mock_parent_add_args): def test_add_args(self, mock_parent_add_args):


@ -109,6 +109,37 @@ class TestAddPubsubArgs:
assert args.pubsub_backend == 'pulsar' assert args.pubsub_backend == 'pulsar'
class TestAddPubsubArgsRabbitMQ:
def test_rabbitmq_args_present(self):
parser = argparse.ArgumentParser()
add_pubsub_args(parser)
args = parser.parse_args([
'--pubsub-backend', 'rabbitmq',
'--rabbitmq-host', 'myhost',
'--rabbitmq-port', '5673',
])
assert args.pubsub_backend == 'rabbitmq'
assert args.rabbitmq_host == 'myhost'
assert args.rabbitmq_port == 5673
def test_rabbitmq_defaults_container(self):
parser = argparse.ArgumentParser()
add_pubsub_args(parser)
args = parser.parse_args([])
assert args.rabbitmq_host == 'rabbitmq'
assert args.rabbitmq_port == 5672
assert args.rabbitmq_username == 'guest'
assert args.rabbitmq_password == 'guest'
assert args.rabbitmq_vhost == '/'
def test_rabbitmq_standalone_defaults_to_localhost(self):
parser = argparse.ArgumentParser()
add_pubsub_args(parser, standalone=True)
args = parser.parse_args([])
assert args.rabbitmq_host == 'localhost'
class TestQueueDefinitions: class TestQueueDefinitions:
"""Verify the actual queue constants produce correct names.""" """Verify the actual queue constants produce correct names."""
@ -124,9 +155,9 @@ class TestQueueDefinitions:
from trustgraph.schema.services.config import config_push_queue from trustgraph.schema.services.config import config_push_queue
assert config_push_queue == 'state:tg:config' assert config_push_queue == 'state:tg:config'
def test_librarian_request_is_persistent(self): def test_librarian_request(self):
from trustgraph.schema.services.library import librarian_request_queue from trustgraph.schema.services.library import librarian_request_queue
assert librarian_request_queue.startswith('flow:') assert librarian_request_queue == 'request:tg:librarian'
def test_knowledge_request(self): def test_knowledge_request(self):
from trustgraph.schema.knowledge.knowledge import knowledge_request_queue from trustgraph.schema.knowledge.knowledge import knowledge_request_queue


@ -0,0 +1,107 @@
"""
Unit tests for RabbitMQ backend queue name mapping and factory dispatch.
Does not require a running RabbitMQ instance.
"""
import pytest
import argparse
pika = pytest.importorskip("pika", reason="pika not installed")
from trustgraph.base.rabbitmq_backend import RabbitMQBackend
from trustgraph.base.pubsub import get_pubsub, add_pubsub_args
class TestRabbitMQMapQueueName:
@pytest.fixture
def backend(self):
b = object.__new__(RabbitMQBackend)
return b
def test_flow_is_durable(self, backend):
name, durable = backend.map_queue_name('flow:tg:text-completion-request')
assert durable is True
assert name == 'tg.flow.text-completion-request'
def test_state_is_durable(self, backend):
name, durable = backend.map_queue_name('state:tg:config')
assert durable is True
assert name == 'tg.state.config'
def test_request_is_not_durable(self, backend):
name, durable = backend.map_queue_name('request:tg:config')
assert durable is False
assert name == 'tg.request.config'
def test_response_is_not_durable(self, backend):
name, durable = backend.map_queue_name('response:tg:librarian')
assert durable is False
assert name == 'tg.response.librarian'
def test_custom_topicspace(self, backend):
name, durable = backend.map_queue_name('flow:prod:my-queue')
assert name == 'prod.flow.my-queue'
assert durable is True
def test_no_colon_uses_default_topicspace(self, backend):
name, durable = backend.map_queue_name('simple-queue')
assert name == 'tg.simple-queue'
assert durable is False
def test_invalid_class_raises(self, backend):
with pytest.raises(ValueError, match="Invalid queue class"):
backend.map_queue_name('unknown:tg:topic')
def test_request_with_flow_suffix(self, backend):
"""Queue names with a flow suffix (e.g. :default) are preserved."""
name, durable = backend.map_queue_name('request:tg:prompt:default')
assert name == 'tg.request.prompt:default'
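The mapping these tests pin down can be sketched in a few lines. This is a hypothetical reconstruction from the assertions above, not the actual `RabbitMQBackend.map_queue_name` implementation:

```python
# Sketch of the queue-name mapping implied by the tests above.
# Queue names are 'class:topicspace:topic'; flow/state queues are durable,
# request/response queues are not. Names without a class segment fall back
# to the default topicspace and a non-durable queue.

VALID_CLASSES = {"flow", "state", "request", "response"}
DURABLE_CLASSES = {"flow", "state"}

def map_queue_name(queue: str, default_topicspace: str = "tg"):
    parts = queue.split(":", 2)
    if len(parts) < 3:
        # No class segment: e.g. 'simple-queue' -> 'tg.simple-queue'
        return f"{default_topicspace}.{queue}", False
    qclass, topicspace, topic = parts
    if qclass not in VALID_CLASSES:
        raise ValueError(f"Invalid queue class: {qclass}")
    # splitting on at most two colons keeps any flow suffix
    # (e.g. ':default') inside the topic segment
    return f"{topicspace}.{qclass}.{topic}", qclass in DURABLE_CLASSES
```

Splitting with `maxsplit=2` is what makes the `request:tg:prompt:default` case fall out naturally: everything after the second colon stays in the topic.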
class TestGetPubsubRabbitMQ:
def test_factory_creates_rabbitmq_backend(self):
backend = get_pubsub(pubsub_backend='rabbitmq')
assert isinstance(backend, RabbitMQBackend)
def test_factory_passes_config(self):
backend = get_pubsub(
pubsub_backend='rabbitmq',
rabbitmq_host='myhost',
rabbitmq_port=5673,
rabbitmq_username='user',
rabbitmq_password='pass',
rabbitmq_vhost='/test',
)
assert isinstance(backend, RabbitMQBackend)
# Verify connection params were set
params = backend._connection_params
assert params.host == 'myhost'
assert params.port == 5673
assert params.virtual_host == '/test'
class TestAddPubsubArgsRabbitMQ:
def test_rabbitmq_args_present(self):
parser = argparse.ArgumentParser()
add_pubsub_args(parser)
args = parser.parse_args([
'--pubsub-backend', 'rabbitmq',
'--rabbitmq-host', 'myhost',
'--rabbitmq-port', '5673',
])
assert args.pubsub_backend == 'rabbitmq'
assert args.rabbitmq_host == 'myhost'
assert args.rabbitmq_port == 5673
def test_rabbitmq_defaults_container(self):
parser = argparse.ArgumentParser()
add_pubsub_args(parser)
args = parser.parse_args([])
assert args.rabbitmq_host == 'rabbitmq'
assert args.rabbitmq_port == 5672
assert args.rabbitmq_username == 'guest'
assert args.rabbitmq_password == 'guest'
assert args.rabbitmq_vhost == '/'


@ -14,6 +14,7 @@ dependencies = [
"prometheus-client", "prometheus-client",
"requests", "requests",
"python-logging-loki", "python-logging-loki",
"pika",
] ]
classifiers = [ classifiers = [
"Programming Language :: Python :: 3", "Programming Language :: Python :: 3",


@ -22,8 +22,9 @@ logger = logging.getLogger(__name__)
# Lower threshold provides progress feedback and resumability on slower connections # Lower threshold provides progress feedback and resumability on slower connections
CHUNKED_UPLOAD_THRESHOLD = 2 * 1024 * 1024 CHUNKED_UPLOAD_THRESHOLD = 2 * 1024 * 1024
# Default chunk size (5MB - S3 multipart minimum) # Default chunk size (3MB - stays under broker message size limits
DEFAULT_CHUNK_SIZE = 5 * 1024 * 1024 # after base64 encoding ~4MB)
DEFAULT_CHUNK_SIZE = 3 * 1024 * 1024
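The "~4MB after base64 encoding" figure in the new comment follows from base64's 4/3 expansion (every 3 input bytes become 4 output characters); for a 3 MiB chunk the encoded size works out to exactly 4 MiB:

```python
import base64
import math

DEFAULT_CHUNK_SIZE = 3 * 1024 * 1024  # 3 MiB of raw bytes per upload chunk

# Base64 emits 4 output characters for every 3 input bytes
encoded_len = math.ceil(DEFAULT_CHUNK_SIZE / 3) * 4
assert encoded_len == 4 * 1024 * 1024  # exactly 4 MiB on the wire

# Cross-check against the real encoder
assert len(base64.b64encode(b"\x00" * DEFAULT_CHUNK_SIZE)) == encoded_len
```

The previous 5 MiB chunks encoded to ~6.67 MiB, which is what pushed single messages over broker limits.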
def to_value(x): def to_value(x):


@ -14,6 +14,7 @@ from . producer_spec import ProducerSpec
from . subscriber_spec import SubscriberSpec from . subscriber_spec import SubscriberSpec
from . request_response_spec import RequestResponseSpec from . request_response_spec import RequestResponseSpec
from . llm_service import LlmService, LlmResult, LlmChunk from . llm_service import LlmService, LlmResult, LlmChunk
from . librarian_client import LibrarianClient
from . chunking_service import ChunkingService from . chunking_service import ChunkingService
from . embeddings_service import EmbeddingsService from . embeddings_service import EmbeddingsService
from . embeddings_client import EmbeddingsClientSpec from . embeddings_client import EmbeddingsClientSpec


@ -68,11 +68,12 @@ class AsyncProcessor:
processor = self.id, flow = None, name = "config", processor = self.id, flow = None, name = "config",
) )
# Subscribe to config queue # Subscribe to config queue — exclusive so every processor
# gets its own copy of config pushes (broadcast pattern)
self.config_sub_task = Consumer( self.config_sub_task = Consumer(
taskgroup = self.taskgroup, taskgroup = self.taskgroup,
backend = self.pubsub_backend, # Changed from client to backend backend = self.pubsub_backend,
subscriber = config_subscriber_id, subscriber = config_subscriber_id,
flow = None, flow = None,
@ -83,9 +84,8 @@ class AsyncProcessor:
metrics = config_consumer_metrics, metrics = config_consumer_metrics,
# This causes new subscriptions to view the entire history of start_of_messages = True,
# configuration consumer_type = 'exclusive',
start_of_messages = True
) )
self.running = True self.running = True
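The shared vs exclusive split maps onto RabbitMQ topology as described in the PR summary: one topic exchange per topicspace, a named durable queue for competing consumers, and a server-named auto-delete queue per exclusive consumer. A sketch of the declarations involved (queue and routing-key names are illustrative, and this needs a reachable broker, so it is a topology fragment rather than a runnable example):

```python
import pika

# Assumed connection details; the real backend builds these from
# --rabbitmq-host / --rabbitmq-port etc.
params = pika.ConnectionParameters(host="rabbitmq", port=5672)
conn = pika.BlockingConnection(params)
ch = conn.channel()

# One topic exchange per topicspace
ch.exchange_declare(exchange="tg", exchange_type="topic", durable=True)

# Shared consumer: named durable queue; multiple consumers on the same
# queue compete for messages round-robin
ch.queue_declare(queue="tg.flow.chunk", durable=True)
ch.queue_bind(queue="tg.flow.chunk", exchange="tg", routing_key="flow.chunk")

# Exclusive consumer: server-named auto-delete queue; each subscriber
# binds its own queue, so every one receives every message (broadcast).
# This is what the config push consumer above relies on.
result = ch.queue_declare(queue="", exclusive=True, auto_delete=True)
ch.queue_bind(
    queue=result.method.queue, exchange="tg", routing_key="state.config",
)
```

The broadcast property is why `consumer_type='exclusive'` is required here: with a shared queue, only one processor instance would see each config push.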


@ -7,23 +7,14 @@ fetching large document content.
import asyncio import asyncio
import base64 import base64
import logging import logging
import uuid
from .flow_processor import FlowProcessor from .flow_processor import FlowProcessor
from .parameter_spec import ParameterSpec from .parameter_spec import ParameterSpec
from .consumer import Consumer from .librarian_client import LibrarianClient
from .producer import Producer
from .metrics import ConsumerMetrics, ProducerMetrics
from ..schema import LibrarianRequest, LibrarianResponse, DocumentMetadata
from ..schema import librarian_request_queue, librarian_response_queue
# Module logger # Module logger
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
default_librarian_request_queue = librarian_request_queue
default_librarian_response_queue = librarian_response_queue
class ChunkingService(FlowProcessor): class ChunkingService(FlowProcessor):
"""Base service for chunking processors with parameter specification support""" """Base service for chunking processors with parameter specification support"""
@ -44,155 +35,18 @@ class ChunkingService(FlowProcessor):
ParameterSpec(name="chunk-overlap") ParameterSpec(name="chunk-overlap")
) )
# Librarian client for fetching document content # Librarian client
librarian_request_q = params.get( self.librarian = LibrarianClient(
"librarian_request_queue", default_librarian_request_queue id=id,
)
librarian_response_q = params.get(
"librarian_response_queue", default_librarian_response_queue
)
librarian_request_metrics = ProducerMetrics(
processor=id, flow=None, name="librarian-request"
)
self.librarian_request_producer = Producer(
backend=self.pubsub, backend=self.pubsub,
topic=librarian_request_q,
schema=LibrarianRequest,
metrics=librarian_request_metrics,
)
librarian_response_metrics = ConsumerMetrics(
processor=id, flow=None, name="librarian-response"
)
self.librarian_response_consumer = Consumer(
taskgroup=self.taskgroup, taskgroup=self.taskgroup,
backend=self.pubsub,
flow=None,
topic=librarian_response_q,
subscriber=f"{id}-librarian",
schema=LibrarianResponse,
handler=self.on_librarian_response,
metrics=librarian_response_metrics,
) )
# Pending librarian requests: request_id -> asyncio.Future
self.pending_requests = {}
logger.debug("ChunkingService initialized with parameter specifications") logger.debug("ChunkingService initialized with parameter specifications")
async def start(self): async def start(self):
await super(ChunkingService, self).start() await super(ChunkingService, self).start()
await self.librarian_request_producer.start() await self.librarian.start()
await self.librarian_response_consumer.start()
async def on_librarian_response(self, msg, consumer, flow):
"""Handle responses from the librarian service."""
response = msg.value()
request_id = msg.properties().get("id")
if request_id and request_id in self.pending_requests:
future = self.pending_requests.pop(request_id)
future.set_result(response)
async def fetch_document_content(self, document_id, user, timeout=120):
"""
Fetch document content from librarian via Pulsar.
"""
request_id = str(uuid.uuid4())
request = LibrarianRequest(
operation="get-document-content",
document_id=document_id,
user=user,
)
# Create future for response
future = asyncio.get_event_loop().create_future()
self.pending_requests[request_id] = future
try:
# Send request
await self.librarian_request_producer.send(
request, properties={"id": request_id}
)
# Wait for response
response = await asyncio.wait_for(future, timeout=timeout)
if response.error:
raise RuntimeError(
f"Librarian error: {response.error.type}: {response.error.message}"
)
return response.content
except asyncio.TimeoutError:
self.pending_requests.pop(request_id, None)
raise RuntimeError(f"Timeout fetching document {document_id}")
async def save_child_document(self, doc_id, parent_id, user, content,
document_type="chunk", title=None, timeout=120):
"""
Save a child document (chunk) to the librarian.
Args:
doc_id: ID for the new child document
parent_id: ID of the parent document
user: User ID
content: Document content (bytes or str)
document_type: Type of document ("chunk", etc.)
title: Optional title
timeout: Request timeout in seconds
Returns:
The document ID on success
"""
request_id = str(uuid.uuid4())
if isinstance(content, str):
content = content.encode("utf-8")
doc_metadata = DocumentMetadata(
id=doc_id,
user=user,
kind="text/plain",
title=title or doc_id,
parent_id=parent_id,
document_type=document_type,
)
request = LibrarianRequest(
operation="add-child-document",
document_metadata=doc_metadata,
content=base64.b64encode(content).decode("utf-8"),
)
# Create future for response
future = asyncio.get_event_loop().create_future()
self.pending_requests[request_id] = future
try:
# Send request
await self.librarian_request_producer.send(
request, properties={"id": request_id}
)
# Wait for response
response = await asyncio.wait_for(future, timeout=timeout)
if response.error:
raise RuntimeError(
f"Librarian error saving chunk: {response.error.type}: {response.error.message}"
)
return doc_id
except asyncio.TimeoutError:
self.pending_requests.pop(request_id, None)
raise RuntimeError(f"Timeout saving chunk {doc_id}")
async def get_document_text(self, doc): async def get_document_text(self, doc):
""" """
@ -206,14 +60,10 @@ class ChunkingService(FlowProcessor):
""" """
if doc.document_id and not doc.text: if doc.document_id and not doc.text:
logger.info(f"Fetching document {doc.document_id} from librarian...") logger.info(f"Fetching document {doc.document_id} from librarian...")
content = await self.fetch_document_content( text = await self.librarian.fetch_document_text(
document_id=doc.document_id, document_id=doc.document_id,
user=doc.metadata.user, user=doc.metadata.user,
) )
# Content is base64 encoded
if isinstance(content, str):
content = content.encode('utf-8')
text = base64.b64decode(content).decode("utf-8")
logger.info(f"Fetched {len(text)} characters from librarian") logger.info(f"Fetched {len(text)} characters from librarian")
return text return text
else: else:
@ -224,41 +74,31 @@ class ChunkingService(FlowProcessor):
Extract chunk parameters from flow and return effective values Extract chunk parameters from flow and return effective values
Args: Args:
msg: The message containing the document to chunk msg: The message being processed
consumer: The consumer spec consumer: The consumer instance
flow: The flow context flow: The flow object containing parameters
default_chunk_size: Default chunk size from processor config default_chunk_size: Default chunk size if not configured
default_chunk_overlap: Default chunk overlap from processor config default_chunk_overlap: Default chunk overlap if not configured
Returns: Returns:
tuple: (chunk_size, chunk_overlap) - effective values to use tuple: (chunk_size, chunk_overlap) effective values
""" """
# Extract parameters from flow (flow-configurable parameters)
chunk_size = flow("chunk-size")
chunk_overlap = flow("chunk-overlap")
# Use provided values or fall back to defaults chunk_size = default_chunk_size
effective_chunk_size = chunk_size if chunk_size is not None else default_chunk_size chunk_overlap = default_chunk_overlap
effective_chunk_overlap = chunk_overlap if chunk_overlap is not None else default_chunk_overlap
logger.debug(f"Using chunk-size: {effective_chunk_size}") try:
logger.debug(f"Using chunk-overlap: {effective_chunk_overlap}") cs = flow.parameters.get("chunk-size")
if cs is not None:
chunk_size = int(cs)
except Exception as e:
logger.warning(f"Could not parse chunk-size parameter: {e}")
return effective_chunk_size, effective_chunk_overlap try:
co = flow.parameters.get("chunk-overlap")
if co is not None:
chunk_overlap = int(co)
except Exception as e:
logger.warning(f"Could not parse chunk-overlap parameter: {e}")
@staticmethod return chunk_size, chunk_overlap
def add_args(parser):
"""Add chunking service arguments to parser"""
FlowProcessor.add_args(parser)
parser.add_argument(
'--librarian-request-queue',
default=default_librarian_request_queue,
help=f'Librarian request queue (default: {default_librarian_request_queue})',
)
parser.add_argument(
'--librarian-response-queue',
default=default_librarian_response_queue,
help=f'Librarian response queue (default: {default_librarian_response_queue})',
)
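ChunkingService now delegates all of the removed plumbing to `LibrarianClient`. Its core is the correlation-id pattern the old inline code used: send a request with a unique `id` property, park a future keyed on that id, and resolve the future when the matching response arrives. A minimal stand-in sketch; `LoopbackTransport` replaces the real Producer/Consumer pair and is purely illustrative:

```python
import asyncio
import uuid

class LoopbackTransport:
    """Stand-in for the Producer/Consumer pair bound to the librarian
    request/response queues; immediately echoes a canned response."""
    def __init__(self, client):
        self.client = client

    async def send(self, payload, properties):
        # Simulate the librarian replying on the response queue
        await self.client.on_response(properties["id"], {"content": "hello"})

class LibrarianClientSketch:
    def __init__(self, transport=None):
        self.transport = transport
        self.pending = {}  # request_id -> Future awaiting the response

    async def on_response(self, request_id, response):
        # Response consumer callback: resolve the parked future, if any
        fut = self.pending.pop(request_id, None)
        if fut is not None:
            fut.set_result(response)

    async def request(self, payload, timeout=5):
        request_id = str(uuid.uuid4())
        fut = asyncio.get_running_loop().create_future()
        self.pending[request_id] = fut
        try:
            await self.transport.send(payload, properties={"id": request_id})
            return await asyncio.wait_for(fut, timeout=timeout)
        except asyncio.TimeoutError:
            self.pending.pop(request_id, None)
            raise RuntimeError("Timeout waiting for librarian response")

async def demo():
    client = LibrarianClientSketch()
    client.transport = LoopbackTransport(client)
    return await client.request({"operation": "stream-document"})

assert asyncio.run(demo())["content"] == "hello"
```

In the real client, `fetch_document_text` issues repeated `stream-document` requests and reassembles 1MB chunks, which is what lets the librarian queue drop back to the request/response class.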


@@ -32,6 +32,7 @@ class Consumer:
         rate_limit_retry_time = 10, rate_limit_timeout = 7200,
         reconnect_time = 5,
         concurrency = 1,   # Number of concurrent requests to handle
+        consumer_type = 'shared',
     ):
         self.taskgroup = taskgroup

@@ -42,6 +43,8 @@ class Consumer:
         self.schema = schema
         self.handler = handler
+        self.consumer_type = consumer_type
+
         self.rate_limit_retry_time = rate_limit_retry_time
         self.rate_limit_timeout = rate_limit_timeout

@@ -93,33 +96,11 @@ class Consumer:
             if self.metrics:
                 self.metrics.state("stopped")

-            try:
-
-                logger.info(f"Subscribing to topic: {self.topic}")
-
-                # Determine initial position
-                if self.start_of_messages:
-                    initial_pos = 'earliest'
-                else:
-                    initial_pos = 'latest'
-
-                # Create consumer via backend
-                self.consumer = await asyncio.to_thread(
-                    self.backend.create_consumer,
-                    topic = self.topic,
-                    subscription = self.subscriber,
-                    schema = self.schema,
-                    initial_position = initial_pos,
-                    consumer_type = 'shared',
-                )
-
-            except Exception as e:
-                logger.error(f"Consumer subscription exception: {e}", exc_info=True)
-                await asyncio.sleep(self.reconnect_time)
-                continue
-
-            logger.info(f"Successfully subscribed to topic: {self.topic}")
+            # Determine initial position
+            if self.start_of_messages:
+                initial_pos = 'earliest'
+            else:
+                initial_pos = 'latest'

             if self.metrics:
                 self.metrics.state("running")

@@ -128,14 +109,30 @@ class Consumer:
             logger.info(f"Starting {self.concurrency} receiver threads")

-            async with asyncio.TaskGroup() as tg:
-
-                tasks = []
-
-                for i in range(0, self.concurrency):
-                    tasks.append(
-                        tg.create_task(self.consume_from_queue())
-                    )
+            # Create one backend consumer per concurrent task.
+            # Each gets its own connection — required for backends
+            # like RabbitMQ where connections are not thread-safe.
+            consumers = []
+            for i in range(self.concurrency):
+                try:
+                    logger.info(f"Subscribing to topic: {self.topic} (worker {i})")
+                    c = await asyncio.to_thread(
+                        self.backend.create_consumer,
+                        topic = self.topic,
+                        subscription = self.subscriber,
+                        schema = self.schema,
+                        initial_position = initial_pos,
+                        consumer_type = self.consumer_type,
+                    )
+                    consumers.append(c)
+                    logger.info(f"Successfully subscribed to topic: {self.topic} (worker {i})")
+                except Exception as e:
+                    logger.error(f"Consumer subscription exception (worker {i}): {e}", exc_info=True)
+                    raise
+
+            async with asyncio.TaskGroup() as tg:
+                for c in consumers:
+                    tg.create_task(self.consume_from_queue(c))

             if self.metrics:
                 self.metrics.state("stopped")

@@ -143,23 +140,31 @@ class Consumer:
             except Exception as e:
                 logger.error(f"Consumer loop exception: {e}", exc_info=True)
-                self.consumer.unsubscribe()
-                self.consumer.close()
-                self.consumer = None
+                for c in consumers:
+                    try:
+                        c.unsubscribe()
+                        c.close()
+                    except Exception:
+                        pass
+                consumers = []
                 await asyncio.sleep(self.reconnect_time)
                 continue
-
-        if self.consumer:
-            self.consumer.unsubscribe()
-            self.consumer.close()
+            finally:
+                for c in consumers:
+                    try:
+                        c.unsubscribe()
+                        c.close()
+                    except Exception:
+                        pass

-    async def consume_from_queue(self):
+    async def consume_from_queue(self, consumer):
         while self.running:
             try:
                 msg = await asyncio.to_thread(
-                    self.consumer.receive,
+                    consumer.receive,
                     timeout_millis=2000
                 )
             except Exception as e:

@@ -168,9 +173,9 @@ class Consumer:
                     continue
                 raise e

-            await self.handle_one_from_queue(msg)
+            await self.handle_one_from_queue(msg, consumer)

-    async def handle_one_from_queue(self, msg):
+    async def handle_one_from_queue(self, msg, consumer):
         expiry = time.time() + self.rate_limit_timeout

@@ -183,7 +188,7 @@ class Consumer:
                 # Message failed to be processed, this causes it to
                 # be retried
-                self.consumer.negative_acknowledge(msg)
+                consumer.negative_acknowledge(msg)

                 if self.metrics:
                     self.metrics.process("error")

@@ -206,7 +211,7 @@ class Consumer:
             logger.debug("Message processed successfully")

             # Acknowledge successful processing of the message
-            self.consumer.acknowledge(msg)
+            consumer.acknowledge(msg)

             if self.metrics:
                 self.metrics.process("success")

@@ -233,7 +238,7 @@ class Consumer:
             # Message failed to be processed, this causes it to
             # be retried
-            self.consumer.negative_acknowledge(msg)
+            consumer.negative_acknowledge(msg)

             if self.metrics:
                 self.metrics.process("error")

View file

@@ -0,0 +1,246 @@
"""
Shared librarian client for services that need to communicate
with the librarian via pub/sub.
Provides request-response and streaming operations over the message
broker, with proper support for large documents via stream-document.
Usage:
self.librarian = LibrarianClient(
id=id, backend=self.pubsub, taskgroup=self.taskgroup, **params
)
await self.librarian.start()
content = await self.librarian.fetch_document_content(doc_id, user)
"""
import asyncio
import base64
import logging
import uuid
from .consumer import Consumer
from .producer import Producer
from .metrics import ConsumerMetrics, ProducerMetrics
from ..schema import LibrarianRequest, LibrarianResponse, DocumentMetadata
from ..schema import librarian_request_queue, librarian_response_queue
logger = logging.getLogger(__name__)
class LibrarianClient:
"""Client for librarian request-response over the message broker."""
def __init__(self, id, backend, taskgroup, **params):
librarian_request_q = params.get(
"librarian_request_queue", librarian_request_queue,
)
librarian_response_q = params.get(
"librarian_response_queue", librarian_response_queue,
)
librarian_request_metrics = ProducerMetrics(
processor=id, flow=None, name="librarian-request",
)
self._producer = Producer(
backend=backend,
topic=librarian_request_q,
schema=LibrarianRequest,
metrics=librarian_request_metrics,
)
librarian_response_metrics = ConsumerMetrics(
processor=id, flow=None, name="librarian-response",
)
self._consumer = Consumer(
taskgroup=taskgroup,
backend=backend,
flow=None,
topic=librarian_response_q,
subscriber=f"{id}-librarian",
schema=LibrarianResponse,
handler=self._on_response,
metrics=librarian_response_metrics,
consumer_type='exclusive',
)
# Single-response requests: request_id -> asyncio.Future
self._pending = {}
# Streaming requests: request_id -> asyncio.Queue
self._streams = {}
async def start(self):
"""Start the librarian producer and consumer."""
await self._producer.start()
await self._consumer.start()
async def _on_response(self, msg, consumer, flow):
"""Route librarian responses to the right waiter."""
response = msg.value()
request_id = msg.properties().get("id")
if not request_id:
return
if request_id in self._pending:
future = self._pending.pop(request_id)
future.set_result(response)
elif request_id in self._streams:
await self._streams[request_id].put(response)
async def request(self, request, timeout=120):
"""Send a request to the librarian and wait for a single response."""
request_id = str(uuid.uuid4())
future = asyncio.get_running_loop().create_future()
self._pending[request_id] = future
try:
await self._producer.send(
request, properties={"id": request_id},
)
response = await asyncio.wait_for(future, timeout=timeout)
if response.error:
raise RuntimeError(
f"Librarian error: {response.error.type}: "
f"{response.error.message}"
)
return response
except asyncio.TimeoutError:
self._pending.pop(request_id, None)
raise RuntimeError("Timeout waiting for librarian response")
async def stream(self, request, timeout=120):
"""Send a request and collect streamed response chunks."""
request_id = str(uuid.uuid4())
q = asyncio.Queue()
self._streams[request_id] = q
try:
await self._producer.send(
request, properties={"id": request_id},
)
chunks = []
while True:
response = await asyncio.wait_for(q.get(), timeout=timeout)
if response.error:
raise RuntimeError(
f"Librarian error: {response.error.type}: "
f"{response.error.message}"
)
chunks.append(response)
if response.is_final:
break
return chunks
except asyncio.TimeoutError:
self._streams.pop(request_id, None)
raise RuntimeError("Timeout waiting for librarian stream")
finally:
self._streams.pop(request_id, None)
async def fetch_document_content(self, document_id, user, timeout=120):
"""Fetch document content using streaming.
Returns base64-encoded content. Caller is responsible for decoding.
"""
req = LibrarianRequest(
operation="stream-document",
document_id=document_id,
user=user,
)
chunks = await self.stream(req, timeout=timeout)
# Decode each chunk's base64 to raw bytes, concatenate,
# re-encode for the caller.
raw = b""
for chunk in chunks:
if chunk.content:
if isinstance(chunk.content, bytes):
raw += base64.b64decode(chunk.content)
else:
raw += base64.b64decode(
chunk.content.encode("utf-8")
)
return base64.b64encode(raw)
async def fetch_document_text(self, document_id, user, timeout=120):
"""Fetch document content and decode as UTF-8 text."""
content = await self.fetch_document_content(
document_id, user, timeout=timeout,
)
return base64.b64decode(content).decode("utf-8")
async def fetch_document_metadata(self, document_id, user, timeout=120):
"""Fetch document metadata from the librarian."""
req = LibrarianRequest(
operation="get-document-metadata",
document_id=document_id,
user=user,
)
response = await self.request(req, timeout=timeout)
return response.document_metadata
async def save_child_document(self, doc_id, parent_id, user, content,
document_type="chunk", title=None,
kind="text/plain", timeout=120):
"""Save a child document to the librarian."""
if isinstance(content, str):
content = content.encode("utf-8")
doc_metadata = DocumentMetadata(
id=doc_id,
user=user,
kind=kind,
title=title or doc_id,
parent_id=parent_id,
document_type=document_type,
)
req = LibrarianRequest(
operation="add-child-document",
document_metadata=doc_metadata,
content=base64.b64encode(content).decode("utf-8"),
)
await self.request(req, timeout=timeout)
return doc_id
async def save_document(self, doc_id, user, content, title=None,
document_type="answer", kind="text/plain",
timeout=120):
"""Save a document to the librarian."""
if isinstance(content, str):
content = content.encode("utf-8")
doc_metadata = DocumentMetadata(
id=doc_id,
user=user,
kind=kind,
title=title or doc_id,
document_type=document_type,
)
req = LibrarianRequest(
operation="add-document",
document_id=doc_id,
document_metadata=doc_metadata,
content=base64.b64encode(content).decode("utf-8"),
user=user,
)
await self.request(req, timeout=timeout)
return doc_id
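The reassembly logic in `fetch_document_content` reduces to a small base64 pipeline. A standalone sketch with made-up chunk contents:

```python
import base64

# Each streamed response carries a base64-encoded slice of the document.
# (Illustrative values; real chunks come from stream-document responses.)
chunk_contents = [
    base64.b64encode(b"Hello, "),
    base64.b64encode(b"world!"),
]

# Decode each chunk, concatenate the raw bytes, re-encode for the caller,
# which is the shape of fetch_document_content's loop.
raw = b"".join(base64.b64decode(c) for c in chunk_contents)
content = base64.b64encode(raw)

print(base64.b64decode(content).decode("utf-8"))  # Hello, world!
```

Keeping the return value base64-encoded matches the old get-document-content contract, so callers only swap the fetch mechanism, not their decoding.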

View file

@@ -8,6 +8,12 @@ logger = logging.getLogger(__name__)

 DEFAULT_PULSAR_HOST = os.getenv("PULSAR_HOST", 'pulsar://pulsar:6650')
 DEFAULT_PULSAR_API_KEY = os.getenv("PULSAR_API_KEY", None)
+DEFAULT_RABBITMQ_HOST = os.getenv("RABBITMQ_HOST", 'rabbitmq')
+DEFAULT_RABBITMQ_PORT = int(os.getenv("RABBITMQ_PORT", '5672'))
+DEFAULT_RABBITMQ_USERNAME = os.getenv("RABBITMQ_USERNAME", 'guest')
+DEFAULT_RABBITMQ_PASSWORD = os.getenv("RABBITMQ_PASSWORD", 'guest')
+DEFAULT_RABBITMQ_VHOST = os.getenv("RABBITMQ_VHOST", '/')

 def get_pubsub(**config):
     """

@@ -29,6 +35,15 @@ def get_pubsub(**config):
             api_key=config.get('pulsar_api_key', DEFAULT_PULSAR_API_KEY),
             listener=config.get('pulsar_listener'),
         )
+    elif backend_type == 'rabbitmq':
+        from .rabbitmq_backend import RabbitMQBackend
+        return RabbitMQBackend(
+            host=config.get('rabbitmq_host', DEFAULT_RABBITMQ_HOST),
+            port=config.get('rabbitmq_port', DEFAULT_RABBITMQ_PORT),
+            username=config.get('rabbitmq_username', DEFAULT_RABBITMQ_USERNAME),
+            password=config.get('rabbitmq_password', DEFAULT_RABBITMQ_PASSWORD),
+            vhost=config.get('rabbitmq_vhost', DEFAULT_RABBITMQ_VHOST),
+        )
     else:
         raise ValueError(f"Unknown pub/sub backend: {backend_type}")

@@ -44,8 +59,9 @@ def add_pubsub_args(parser, standalone=False):
         standalone: If True, default host is localhost (for CLI tools
             that run outside containers)
     """
-    host = STANDALONE_PULSAR_HOST if standalone else DEFAULT_PULSAR_HOST
-    listener_default = 'localhost' if standalone else None
+    pulsar_host = STANDALONE_PULSAR_HOST if standalone else DEFAULT_PULSAR_HOST
+    pulsar_listener = 'localhost' if standalone else None
+    rabbitmq_host = 'localhost' if standalone else DEFAULT_RABBITMQ_HOST

     parser.add_argument(
         '--pubsub-backend',

@@ -53,10 +69,11 @@ def add_pubsub_args(parser, standalone=False):
         help='Pub/sub backend (default: pulsar, env: PUBSUB_BACKEND)',
     )

+    # Pulsar options
     parser.add_argument(
         '-p', '--pulsar-host',
-        default=host,
-        help=f'Pulsar host (default: {host})',
+        default=pulsar_host,
+        help=f'Pulsar host (default: {pulsar_host})',
     )

     parser.add_argument(

@@ -67,6 +84,38 @@ def add_pubsub_args(parser, standalone=False):
     parser.add_argument(
         '--pulsar-listener',
-        default=listener_default,
-        help=f'Pulsar listener (default: {listener_default or "none"})',
+        default=pulsar_listener,
+        help=f'Pulsar listener (default: {pulsar_listener or "none"})',
+    )
+
+    # RabbitMQ options
+    parser.add_argument(
+        '--rabbitmq-host',
+        default=rabbitmq_host,
+        help=f'RabbitMQ host (default: {rabbitmq_host})',
+    )
+
+    parser.add_argument(
+        '--rabbitmq-port',
+        type=int,
+        default=DEFAULT_RABBITMQ_PORT,
+        help=f'RabbitMQ port (default: {DEFAULT_RABBITMQ_PORT})',
+    )
+
+    parser.add_argument(
+        '--rabbitmq-username',
+        default=DEFAULT_RABBITMQ_USERNAME,
+        help='RabbitMQ username',
+    )
+
+    parser.add_argument(
+        '--rabbitmq-password',
+        default=DEFAULT_RABBITMQ_PASSWORD,
+        help='RabbitMQ password',
+    )
+
+    parser.add_argument(
+        '--rabbitmq-vhost',
+        default=DEFAULT_RABBITMQ_VHOST,
+        help=f'RabbitMQ vhost (default: {DEFAULT_RABBITMQ_VHOST})',
     )
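The selection flow in the factory above can be sketched without any broker libraries: the `PUBSUB_BACKEND` environment variable (or `--pubsub-backend`) picks the backend, and unknown names are rejected. The function name here is a hypothetical stand-in for illustration; only the flag, env var, and error message mirror the diff.

```python
import argparse

# Sketch of backend selection: CLI flag wins, env var supplies the
# default, anything else raises -- same shape as get_pubsub above.
def get_backend_name(argv, env=None):
    env = env or {}
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--pubsub-backend',
        default=env.get('PUBSUB_BACKEND', 'pulsar'),
    )
    args = parser.parse_args(argv)
    if args.pubsub_backend not in ('pulsar', 'rabbitmq'):
        raise ValueError(f"Unknown pub/sub backend: {args.pubsub_backend}")
    return args.pubsub_backend

print(get_backend_name([], env={'PUBSUB_BACKEND': 'rabbitmq'}))  # rabbitmq
print(get_backend_name(['--pubsub-backend', 'pulsar']))          # pulsar
```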

View file

@@ -9,122 +9,14 @@ import pulsar
 import _pulsar
 import json
 import logging
-import base64
-import types
-from dataclasses import asdict, is_dataclass
-from typing import Any, get_type_hints
+from typing import Any

 from .backend import PubSubBackend, BackendProducer, BackendConsumer, Message
+from .serialization import dataclass_to_dict, dict_to_dataclass

 logger = logging.getLogger(__name__)

-def dataclass_to_dict(obj: Any) -> dict:
-    """
-    Recursively convert a dataclass to a dictionary, handling None values and bytes.
-    None values are excluded from the dictionary (not serialized).
-    Bytes values are decoded as UTF-8 strings for JSON serialization (matching Pulsar behavior).
-    Handles nested dataclasses, lists, and dictionaries recursively.
-    """
-    if obj is None:
-        return None
-
-    # Handle bytes - decode to UTF-8 for JSON serialization
-    if isinstance(obj, bytes):
-        return obj.decode('utf-8')
-
-    # Handle dataclass - convert to dict then recursively process all values
-    if is_dataclass(obj):
-        result = {}
-        for key, value in asdict(obj).items():
-            result[key] = dataclass_to_dict(value) if value is not None else None
-        return result
-
-    # Handle list - recursively process all items
-    if isinstance(obj, list):
-        return [dataclass_to_dict(item) for item in obj]
-
-    # Handle dict - recursively process all values
-    if isinstance(obj, dict):
-        return {k: dataclass_to_dict(v) for k, v in obj.items()}
-
-    # Return primitive types as-is
-    return obj
-
-def dict_to_dataclass(data: dict, cls: type) -> Any:
-    """
-    Convert a dictionary back to a dataclass instance.
-    Handles nested dataclasses and missing fields.
-    Uses get_type_hints() to resolve forward references (string annotations).
-    """
-    if data is None:
-        return None
-
-    if not is_dataclass(cls):
-        return data
-
-    # Get field types from the dataclass, resolving forward references
-    # get_type_hints() evaluates string annotations like "Triple | None"
-    try:
-        field_types = get_type_hints(cls)
-    except Exception:
-        # Fallback if get_type_hints fails (shouldn't happen normally)
-        field_types = {f.name: f.type for f in cls.__dataclass_fields__.values()}
-
-    kwargs = {}
-
-    for key, value in data.items():
-        if key in field_types:
-            field_type = field_types[key]
-
-            # Handle modern union types (X | Y)
-            if isinstance(field_type, types.UnionType):
-                # Check if it's Optional (X | None)
-                if type(None) in field_type.__args__:
-                    # Get the non-None type
-                    actual_type = next((t for t in field_type.__args__ if t is not type(None)), None)
-                    if actual_type and is_dataclass(actual_type) and isinstance(value, dict):
-                        kwargs[key] = dict_to_dataclass(value, actual_type)
-                    else:
-                        kwargs[key] = value
-                else:
-                    kwargs[key] = value
-
-            # Check if this is a generic type (list, dict, etc.)
-            elif hasattr(field_type, '__origin__'):
-                # Handle list[T]
-                if field_type.__origin__ == list:
-                    item_type = field_type.__args__[0] if field_type.__args__ else None
-                    if item_type and is_dataclass(item_type) and isinstance(value, list):
-                        kwargs[key] = [
-                            dict_to_dataclass(item, item_type) if isinstance(item, dict) else item
-                            for item in value
-                        ]
-                    else:
-                        kwargs[key] = value
-                # Handle old-style Optional[T] (which is Union[T, None])
-                elif hasattr(field_type, '__args__') and type(None) in field_type.__args__:
-                    # Get the non-None type from Union
-                    actual_type = next((t for t in field_type.__args__ if t is not type(None)), None)
-                    if actual_type and is_dataclass(actual_type) and isinstance(value, dict):
-                        kwargs[key] = dict_to_dataclass(value, actual_type)
-                    else:
-                        kwargs[key] = value
-                else:
-                    kwargs[key] = value
-
-            # Handle direct dataclass fields
-            elif is_dataclass(field_type) and isinstance(value, dict):
-                kwargs[key] = dict_to_dataclass(value, field_type)
-
-            # Handle bytes fields (UTF-8 encoded strings from JSON)
-            elif field_type == bytes and isinstance(value, str):
-                kwargs[key] = value.encode('utf-8')
-
-            else:
-                kwargs[key] = value
-
-    return cls(**kwargs)

 class PulsarMessage:
     """Wrapper for Pulsar messages to match Message protocol."""

View file

@@ -0,0 +1,390 @@
"""
RabbitMQ backend implementation for pub/sub abstraction.
Uses a single topic exchange per topicspace. The logical queue name
becomes the routing key. Consumer behavior is determined by the
subscription name:
- Same subscription + same topic = shared queue (competing consumers)
- Different subscriptions = separate queues (broadcast / fan-out)
This mirrors Pulsar's subscription model using idiomatic RabbitMQ.
Architecture:
Producer --> [tg exchange] --routing key--> [named queue] --> Consumer
--routing key--> [named queue] --> Consumer
--routing key--> [exclusive q] --> Subscriber
Uses basic_consume (push) instead of basic_get (polling) for
efficient message delivery.
"""
import json
import time
import logging
import queue
import threading
import pika
from typing import Any
from .backend import PubSubBackend, BackendProducer, BackendConsumer, Message
from .serialization import dataclass_to_dict, dict_to_dataclass
logger = logging.getLogger(__name__)
class RabbitMQMessage:
"""Wrapper for RabbitMQ messages to match Message protocol."""
def __init__(self, method, properties, body, schema_cls):
self._method = method
self._properties = properties
self._body = body
self._schema_cls = schema_cls
self._value = None
def value(self) -> Any:
"""Deserialize and return the message value as a dataclass."""
if self._value is None:
data_dict = json.loads(self._body.decode('utf-8'))
self._value = dict_to_dataclass(data_dict, self._schema_cls)
return self._value
def properties(self) -> dict:
"""Return message properties from AMQP headers."""
headers = self._properties.headers or {}
return dict(headers)
class RabbitMQBackendProducer:
"""Publishes messages to a topic exchange with a routing key.
Uses thread-local connections so each thread gets its own
connection/channel. This avoids wire corruption from concurrent
threads writing to the same socket (pika is not thread-safe).
"""
def __init__(self, connection_params, exchange_name, routing_key,
durable):
self._connection_params = connection_params
self._exchange_name = exchange_name
self._routing_key = routing_key
self._durable = durable
self._local = threading.local()
def _get_channel(self):
"""Get or create a thread-local connection and channel."""
conn = getattr(self._local, 'connection', None)
chan = getattr(self._local, 'channel', None)
if conn is None or not conn.is_open or chan is None or not chan.is_open:
# Close stale connection if any
if conn is not None:
try:
conn.close()
except Exception:
pass
conn = pika.BlockingConnection(self._connection_params)
chan = conn.channel()
chan.exchange_declare(
exchange=self._exchange_name,
exchange_type='topic',
durable=True,
)
self._local.connection = conn
self._local.channel = chan
return chan
def send(self, message: Any, properties: dict = {}) -> None:
data_dict = dataclass_to_dict(message)
json_data = json.dumps(data_dict)
amqp_properties = pika.BasicProperties(
delivery_mode=2 if self._durable else 1,
content_type='application/json',
headers=properties if properties else None,
)
for attempt in range(2):
try:
channel = self._get_channel()
channel.basic_publish(
exchange=self._exchange_name,
routing_key=self._routing_key,
body=json_data.encode('utf-8'),
properties=amqp_properties,
)
return
except Exception as e:
logger.warning(
f"RabbitMQ send failed (attempt {attempt + 1}): {e}"
)
# Force reconnect on next attempt
self._local.connection = None
self._local.channel = None
if attempt == 1:
raise
def flush(self) -> None:
pass
def close(self) -> None:
"""Close the thread-local connection if any."""
conn = getattr(self._local, 'connection', None)
if conn is not None:
try:
conn.close()
except Exception:
pass
self._local.connection = None
self._local.channel = None
class RabbitMQBackendConsumer:
"""Consumes from a queue bound to a topic exchange.
Uses basic_consume (push model) with messages delivered to an
internal thread-safe queue. process_data_events() drives both
message delivery and heartbeat processing.
"""
def __init__(self, connection_params, exchange_name, routing_key,
queue_name, schema_cls, durable, exclusive=False,
auto_delete=False):
self._connection_params = connection_params
self._exchange_name = exchange_name
self._routing_key = routing_key
self._queue_name = queue_name
self._schema_cls = schema_cls
self._durable = durable
self._exclusive = exclusive
self._auto_delete = auto_delete
self._connection = None
self._channel = None
self._consumer_tag = None
self._incoming = queue.Queue()
def _connect(self):
self._connection = pika.BlockingConnection(self._connection_params)
self._channel = self._connection.channel()
# Declare the topic exchange
self._channel.exchange_declare(
exchange=self._exchange_name,
exchange_type='topic',
durable=True,
)
# Declare the queue — anonymous if exclusive
result = self._channel.queue_declare(
queue=self._queue_name,
durable=self._durable,
exclusive=self._exclusive,
auto_delete=self._auto_delete,
)
# Capture actual name (important for anonymous queues where name='')
self._queue_name = result.method.queue
self._channel.queue_bind(
queue=self._queue_name,
exchange=self._exchange_name,
routing_key=self._routing_key,
)
self._channel.basic_qos(prefetch_count=1)
# Register push-based consumer
self._consumer_tag = self._channel.basic_consume(
queue=self._queue_name,
on_message_callback=self._on_message,
auto_ack=False,
)
def _on_message(self, channel, method, properties, body):
"""Callback invoked by pika when a message arrives."""
self._incoming.put((method, properties, body))
def _is_alive(self):
return (
self._connection is not None
and self._connection.is_open
and self._channel is not None
and self._channel.is_open
)
def receive(self, timeout_millis: int = 2000) -> Message:
"""Receive a message. Raises TimeoutError if none available."""
if not self._is_alive():
self._connect()
timeout_seconds = timeout_millis / 1000.0
deadline = time.monotonic() + timeout_seconds
while time.monotonic() < deadline:
# Check if a message was already delivered
try:
method, properties, body = self._incoming.get_nowait()
return RabbitMQMessage(
method, properties, body, self._schema_cls,
)
except queue.Empty:
pass
# Drive pika's I/O — delivers messages and processes heartbeats
remaining = deadline - time.monotonic()
if remaining > 0:
self._connection.process_data_events(
time_limit=min(0.1, remaining),
)
raise TimeoutError("No message received within timeout")
def acknowledge(self, message: Message) -> None:
if isinstance(message, RabbitMQMessage) and message._method:
self._channel.basic_ack(
delivery_tag=message._method.delivery_tag,
)
def negative_acknowledge(self, message: Message) -> None:
if isinstance(message, RabbitMQMessage) and message._method:
self._channel.basic_nack(
delivery_tag=message._method.delivery_tag,
requeue=True,
)
def unsubscribe(self) -> None:
if self._consumer_tag and self._channel and self._channel.is_open:
try:
self._channel.basic_cancel(self._consumer_tag)
except Exception:
pass
self._consumer_tag = None
def close(self) -> None:
self.unsubscribe()
try:
if self._channel and self._channel.is_open:
self._channel.close()
except Exception:
pass
try:
if self._connection and self._connection.is_open:
self._connection.close()
except Exception:
pass
self._channel = None
self._connection = None
class RabbitMQBackend:
"""RabbitMQ pub/sub backend using a topic exchange per topicspace."""
def __init__(self, host='localhost', port=5672, username='guest',
password='guest', vhost='/'):
self._connection_params = pika.ConnectionParameters(
host=host,
port=port,
virtual_host=vhost,
credentials=pika.PlainCredentials(username, password),
)
logger.info(f"RabbitMQ backend: {host}:{port} vhost={vhost}")
def _parse_queue_id(self, queue_id: str) -> tuple[str, str, str, bool]:
"""
Parse queue identifier into exchange, routing key, class, and durability.
Format: class:topicspace:topic
Returns: (exchange_name, routing_key, class, durable)
"""
if ':' not in queue_id:
return 'tg', queue_id, 'flow', False
parts = queue_id.split(':', 2)
if len(parts) != 3:
raise ValueError(
f"Invalid queue format: {queue_id}, "
f"expected class:topicspace:topic"
)
cls, topicspace, topic = parts
if cls in ('flow', 'state'):
durable = True
elif cls in ('request', 'response'):
durable = False
else:
raise ValueError(
f"Invalid queue class: {cls}, "
f"expected flow, request, response, or state"
)
# Exchange per topicspace, routing key includes class
exchange_name = topicspace
routing_key = f"{cls}.{topic}"
return exchange_name, routing_key, cls, durable
# Keep map_queue_name for backward compatibility with tests
def map_queue_name(self, queue_id: str) -> tuple[str, bool]:
exchange, routing_key, cls, durable = self._parse_queue_id(queue_id)
return f"{exchange}.{routing_key}", durable
def create_producer(self, topic: str, schema: type,
**options) -> BackendProducer:
exchange, routing_key, cls, durable = self._parse_queue_id(topic)
logger.debug(
f"Creating producer: exchange={exchange}, "
f"routing_key={routing_key}"
)
return RabbitMQBackendProducer(
self._connection_params, exchange, routing_key, durable,
)
def create_consumer(self, topic: str, subscription: str, schema: type,
initial_position: str = 'latest',
consumer_type: str = 'shared',
**options) -> BackendConsumer:
"""Create a consumer with a queue bound to the topic exchange.
consumer_type='shared': Named durable queue. Multiple consumers
with the same subscription compete (round-robin).
consumer_type='exclusive': Anonymous ephemeral queue. Each
consumer gets its own copy of every message (broadcast).
"""
exchange, routing_key, cls, durable = self._parse_queue_id(topic)
if consumer_type == 'exclusive' and cls == 'state':
# State broadcast: named durable queue per subscriber.
# Retains messages so late-starting processors see current state.
queue_name = f"{exchange}.{routing_key}.{subscription}"
queue_durable = True
exclusive = False
auto_delete = False
elif consumer_type == 'exclusive':
# Broadcast: anonymous queue, auto-deleted on disconnect
queue_name = ''
queue_durable = False
exclusive = True
auto_delete = True
else:
# Shared: named queue, competing consumers
queue_name = f"{exchange}.{routing_key}.{subscription}"
queue_durable = durable
exclusive = False
auto_delete = False
logger.debug(
f"Creating consumer: exchange={exchange}, "
f"routing_key={routing_key}, queue={queue_name or '(anonymous)'}, "
f"type={consumer_type}"
)
return RabbitMQBackendConsumer(
self._connection_params, exchange, routing_key,
queue_name, schema, queue_durable, exclusive, auto_delete,
)
def close(self) -> None:
pass
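The queue-id mapping is easiest to see with concrete values. Below is a condensed, standalone version of `_parse_queue_id` (the same logic as the method above, extracted as a plain function for illustration):

```python
def parse_queue_id(queue_id: str) -> tuple[str, str, str, bool]:
    """Map 'class:topicspace:topic' to (exchange, routing_key, class, durable)."""
    if ':' not in queue_id:
        # Bare names fall back to the 'tg' exchange, non-durable.
        return 'tg', queue_id, 'flow', False
    parts = queue_id.split(':', 2)
    if len(parts) != 3:
        raise ValueError(f"Invalid queue format: {queue_id}")
    cls, topicspace, topic = parts
    if cls in ('flow', 'state'):
        durable = True                 # persistent classes
    elif cls in ('request', 'response'):
        durable = False                # transient request/response traffic
    else:
        raise ValueError(f"Invalid queue class: {cls}")
    # Exchange per topicspace; routing key embeds the class.
    return topicspace, f"{cls}.{topic}", cls, durable

print(parse_queue_id('flow:tg:chunk'))         # ('tg', 'flow.chunk', 'flow', True)
print(parse_queue_id('request:tg:librarian'))  # ('tg', 'request.librarian', 'request', False)
```

So a flow queue and a request queue in the same topicspace share one exchange and are distinguished purely by routing key and durability.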

View file

@@ -0,0 +1,115 @@
"""
JSON serialization helpers for dataclass dict conversion.
Used by pub/sub backends that use JSON as their wire format.
"""
import types
from dataclasses import asdict, is_dataclass
from typing import Any, get_type_hints
def dataclass_to_dict(obj: Any) -> dict:
"""
Recursively convert a dataclass to a dictionary, handling None values and bytes.
None values are preserved and serialized as JSON null.
Bytes values are decoded as UTF-8 strings for JSON serialization.
Handles nested dataclasses, lists, and dictionaries recursively.
"""
if obj is None:
return None
# Handle bytes - decode to UTF-8 for JSON serialization
if isinstance(obj, bytes):
return obj.decode('utf-8')
# Handle dataclass - convert to dict then recursively process all values
if is_dataclass(obj):
result = {}
for key, value in asdict(obj).items():
result[key] = dataclass_to_dict(value) if value is not None else None
return result
# Handle list - recursively process all items
if isinstance(obj, list):
return [dataclass_to_dict(item) for item in obj]
# Handle dict - recursively process all values
if isinstance(obj, dict):
return {k: dataclass_to_dict(v) for k, v in obj.items()}
# Return primitive types as-is
return obj
def dict_to_dataclass(data: dict, cls: type) -> Any:
"""
Convert a dictionary back to a dataclass instance.
Handles nested dataclasses and missing fields.
Uses get_type_hints() to resolve forward references (string annotations).
"""
if data is None:
return None
if not is_dataclass(cls):
return data
# Get field types from the dataclass, resolving forward references
# get_type_hints() evaluates string annotations like "Triple | None"
try:
field_types = get_type_hints(cls)
except Exception:
# Fallback if get_type_hints fails (shouldn't happen normally)
field_types = {f.name: f.type for f in cls.__dataclass_fields__.values()}
kwargs = {}
for key, value in data.items():
if key in field_types:
field_type = field_types[key]
# Handle modern union types (X | Y)
if isinstance(field_type, types.UnionType):
# Check if it's Optional (X | None)
if type(None) in field_type.__args__:
# Get the non-None type
actual_type = next((t for t in field_type.__args__ if t is not type(None)), None)
if actual_type and is_dataclass(actual_type) and isinstance(value, dict):
kwargs[key] = dict_to_dataclass(value, actual_type)
else:
kwargs[key] = value
else:
kwargs[key] = value
# Check if this is a generic type (list, dict, etc.)
elif hasattr(field_type, '__origin__'):
# Handle list[T]
if field_type.__origin__ == list:
item_type = field_type.__args__[0] if field_type.__args__ else None
if item_type and is_dataclass(item_type) and isinstance(value, list):
kwargs[key] = [
dict_to_dataclass(item, item_type) if isinstance(item, dict) else item
for item in value
]
else:
kwargs[key] = value
# Handle old-style Optional[T] (which is Union[T, None])
elif hasattr(field_type, '__args__') and type(None) in field_type.__args__:
# Get the non-None type from Union
actual_type = next((t for t in field_type.__args__ if t is not type(None)), None)
if actual_type and is_dataclass(actual_type) and isinstance(value, dict):
kwargs[key] = dict_to_dataclass(value, actual_type)
else:
kwargs[key] = value
else:
kwargs[key] = value
# Handle direct dataclass fields
elif is_dataclass(field_type) and isinstance(value, dict):
kwargs[key] = dict_to_dataclass(value, field_type)
# Handle bytes fields (UTF-8 encoded strings from JSON)
elif field_type == bytes and isinstance(value, str):
kwargs[key] = value.encode('utf-8')
else:
kwargs[key] = value
return cls(**kwargs)

View file

@@ -51,7 +51,7 @@ class Subscriber:
             topic=self.topic,
             subscription=self.subscription,
             schema=self.schema,
-            consumer_type='shared',
+            consumer_type='exclusive',
         )

         self.task = asyncio.create_task(self.run())

View file

@@ -18,9 +18,7 @@ class BaseClient:
         output_queue=None,
         input_schema=None,
         output_schema=None,
-        pulsar_host="pulsar://pulsar:6650",
-        pulsar_api_key=None,
-        listener=None,
+        **pubsub_config,
     ):

         if input_queue == None: raise RuntimeError("Need input_queue")
@@ -32,12 +30,7 @@ class BaseClient:
         subscriber = str(uuid.uuid4())

         # Create backend using factory
-        self.backend = get_pubsub(
-            pulsar_host=pulsar_host,
-            pulsar_api_key=pulsar_api_key,
-            pulsar_listener=listener,
-            pubsub_backend='pulsar'
-        )
+        self.backend = get_pubsub(**pubsub_config)

         self.producer = self.backend.create_producer(
             topic=input_queue,
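Collapsing the Pulsar-specific parameters into **pubsub_config lets every client forward backend options blindly to the factory, so adding a backend touches one place. A hypothetical sketch of the dispatch (the tuple return value and the rabbitmq_host parameter are placeholders, not the real API):

```python
def get_pubsub(pubsub_backend='pulsar', **config):
    """Hypothetical factory: select a backend by name, pass the rest through.
    Each backend reads only the keyword arguments it recognises."""
    if pubsub_backend == 'pulsar':
        return ('pulsar', config.get('pulsar_host', 'pulsar://pulsar:6650'))
    if pubsub_backend == 'rabbitmq':
        return ('rabbitmq', config.get('rabbitmq_host', 'rabbitmq'))
    raise RuntimeError(f"Unknown pub/sub backend: {pubsub_backend}")

# Callers never name backend-specific parameters directly:
backend = get_pubsub(pubsub_backend='rabbitmq', rabbitmq_host='amqp://broker')
```

Unknown keys for the other backend are simply ignored, which is what makes the blind **pubsub_config forwarding safe.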


@@ -33,9 +33,7 @@ class ConfigClient(BaseClient):
         subscriber=None,
         input_queue=None,
         output_queue=None,
-        pulsar_host="pulsar://pulsar:6650",
-        listener=None,
-        pulsar_api_key=None,
+        **pubsub_config,
     ):

         if input_queue == None:
@@ -48,11 +46,9 @@ class ConfigClient(BaseClient):
             subscriber=subscriber,
             input_queue=input_queue,
             output_queue=output_queue,
-            pulsar_host=pulsar_host,
-            pulsar_api_key=pulsar_api_key,
             input_schema=ConfigRequest,
             output_schema=ConfigResponse,
-            listener=listener,
+            **pubsub_config,
         )

     def get(self, keys, timeout=300):


@@ -24,10 +24,13 @@ from ..core.metadata import Metadata
 # <- (document_metadata)
 # <- (error)
-# get-document-content
+# get-document-content [DEPRECATED — use stream-document instead]
 # -> (document_id)
 # <- (content)
 # <- (error)
+# NOTE: Returns entire document in a single message. Fails for documents
+# exceeding the broker's max message size. Use stream-document which
+# returns content in chunks.
 # add-processing
 # -> (processing_id, processing_metadata)
@@ -220,5 +223,5 @@ class LibrarianResponse:
 # FIXME: Is this right? Using persistence on librarian so that
 # message chunking works
-librarian_request_queue = queue('librarian-request', cls='flow')
-librarian_response_queue = queue('librarian-response', cls='flow')
+librarian_request_queue = queue('librarian', cls='request')
+librarian_response_queue = queue('librarian', cls='response')
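The queue-class change is possible because stream-document keeps each message under the broker limit by sending content in fixed-size chunks. The chunking arithmetic can be sketched as follows (the (index, total, chunk) framing is an assumption for illustration, not the actual wire schema):

```python
CHUNK_SIZE = 1024 * 1024  # 1MB chunks, per the stream-document design

def chunks(data: bytes, size: int = CHUNK_SIZE):
    """Yield (index, total, chunk) triples covering the whole payload."""
    total = max(1, (len(data) + size - 1) // size)  # ceiling division
    for i in range(total):
        yield i, total, data[i * size:(i + 1) * size]

def reassemble(parts):
    # Order by index and check all chunks arrived before joining
    ordered = sorted(parts, key=lambda p: p[0])
    assert ordered and ordered[0][1] == len(ordered)
    return b"".join(chunk for _, _, chunk in ordered)

payload = bytes(2_500_000)       # a 2.5MB document becomes 3 messages
parts = list(chunks(payload))
```

Since no single message ever approaches the broker's size limit, the librarian queues no longer need the persistent flow class that message chunking previously required.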


@@ -354,10 +354,8 @@ IMPORTANT:
             output_file=args.output,
             subscriber_name=args.subscriber,
             append_mode=args.append,
-            pubsub_backend=args.pubsub_backend,
-            pulsar_host=args.pulsar_host,
-            pulsar_api_key=args.pulsar_api_key,
-            pulsar_listener=args.pulsar_listener,
+            **{k: v for k, v in vars(args).items()
+               if k not in ('queues', 'output', 'subscriber', 'append')},
         ))
     except KeyboardInterrupt:
         # Already handled in async_main


@@ -1,5 +1,8 @@
 """
-Initialises Pulsar with Trustgraph tenant / namespaces & policy.
+Initialises TrustGraph pub/sub infrastructure and pushes initial config.
+For Pulsar: creates tenant, namespaces, and retention policies.
+For RabbitMQ: queues are auto-declared, so only config push is needed.
 """

 import requests
@@ -8,10 +11,11 @@ import argparse
 import json

 from trustgraph.clients.config_client import ConfigClient
+from trustgraph.base.pubsub import add_pubsub_args

 default_pulsar_admin_url = "http://pulsar:8080"
-default_pulsar_host = "pulsar://pulsar:6650"
-subscriber = "tg-init-pulsar"
+subscriber = "tg-init-pubsub"

 def get_clusters(url):
@@ -65,12 +69,11 @@ def ensure_namespace(url, tenant, namespace, config):
     print(f"Namespace {tenant}/{namespace} created.", flush=True)

-def ensure_config(config, pulsar_host, pulsar_api_key):
+def ensure_config(config, **pubsub_config):

     cli = ConfigClient(
         subscriber=subscriber,
-        pulsar_host=pulsar_host,
-        pulsar_api_key=pulsar_api_key,
+        **pubsub_config,
     )

     while True:
@@ -115,11 +118,9 @@ def ensure_config(config, pulsar_host, pulsar_api_key):
             time.sleep(2)
             print("Retrying...", flush=True)
             continue

-def init(
-    pulsar_admin_url, pulsar_host, pulsar_api_key, tenant,
-    config, config_file,
-):
+def init_pulsar(pulsar_admin_url, tenant):
+    """Pulsar-specific setup: create tenant, namespaces, retention policies."""

     clusters = get_clusters(pulsar_admin_url)
@@ -145,17 +146,21 @@ def init(
         }
     })

-    if config is not None:
+def push_config(config_json, config_file, **pubsub_config):
+    """Push initial config if provided."""
+
+    if config_json is not None:
         try:
             print("Decoding config...", flush=True)
-            dec = json.loads(config)
+            dec = json.loads(config_json)
             print("Decoded.", flush=True)
         except Exception as e:
             print("Exception:", e, flush=True)
             raise e

-        ensure_config(dec, pulsar_host, pulsar_api_key)
+        ensure_config(dec, **pubsub_config)

     elif config_file is not None:
@@ -167,11 +172,12 @@ def init(
             print("Exception:", e, flush=True)
             raise e

-        ensure_config(dec, pulsar_host, pulsar_api_key)
+        ensure_config(dec, **pubsub_config)

     else:
         print("No config to update.", flush=True)

 def main():

     parser = argparse.ArgumentParser(
@@ -180,22 +186,11 @@ def main():
     )

     parser.add_argument(
-        '-p', '--pulsar-admin-url',
+        '--pulsar-admin-url',
         default=default_pulsar_admin_url,
         help=f'Pulsar admin URL (default: {default_pulsar_admin_url})',
     )
-    parser.add_argument(
-        '--pulsar-host',
-        default=default_pulsar_host,
-        help=f'Pulsar host (default: {default_pulsar_host})',
-    )
-    parser.add_argument(
-        '--pulsar-api-key',
-        help=f'Pulsar API key',
-    )
     parser.add_argument(
         '-c', '--config',
         help=f'Initial configuration to load',
@@ -212,18 +207,43 @@ def main():
         help=f'Tenant (default: tg)',
     )

+    add_pubsub_args(parser)
+
     args = parser.parse_args()

+    backend_type = args.pubsub_backend
+
+    # Extract pubsub config from args
+    pubsub_config = {
+        k: v for k, v in vars(args).items()
+        if k not in ('pulsar_admin_url', 'config', 'config_file', 'tenant')
+    }
+
     while True:
         try:
-            print(flush=True)
-            print(
-                f"Initialising with Pulsar {args.pulsar_admin_url}...",
-                flush=True
-            )
-            init(**vars(args))
+            # Pulsar-specific setup (tenants, namespaces)
+            if backend_type == 'pulsar':
+                print(flush=True)
+                print(
+                    f"Initialising Pulsar at {args.pulsar_admin_url}...",
+                    flush=True,
+                )
+                init_pulsar(args.pulsar_admin_url, args.tenant)
+            else:
+                print(flush=True)
+                print(
+                    f"Using {backend_type} backend (no admin setup needed).",
+                    flush=True,
+                )
+
+            # Push config (works with any backend)
+            push_config(
+                args.config, args.config_file,
+                **pubsub_config,
+            )

             print("Initialisation complete.", flush=True)
             break
@@ -236,4 +256,4 @@ def main():
             print("Will retry...", flush=True)

 if __name__ == "__main__":
     main()
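The vars(args) filtering used here and in the other entry points splits argparse output into locally-consumed options and pass-through pub/sub options. A minimal demonstration of the pattern (flag names are illustrative):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--tenant', default='tg')
parser.add_argument('--pubsub-backend', default='pulsar')
parser.add_argument('--pulsar-host', default='pulsar://pulsar:6650')
args = parser.parse_args([])

# Everything not consumed locally is forwarded to the backend factory.
# argparse converts dashes to underscores, so dict keys line up with
# keyword parameter names on the receiving side.
pubsub_config = {
    k: v for k, v in vars(args).items()
    if k not in ('tenant',)
}
```

The exclusion list is the only thing each entry point must maintain; new backend flags flow through without further changes.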


@@ -316,10 +316,8 @@ def main():
             queue_type=args.queue_type,
             max_lines=args.max_lines,
             max_width=args.max_width,
-            pulsar_host=args.pulsar_host,
-            pulsar_api_key=args.pulsar_api_key,
-            pulsar_listener=args.pulsar_listener,
-            pubsub_backend=args.pubsub_backend,
+            **{k: v for k, v in vars(args).items()
+               if k not in ('flow', 'queue_type', 'max_lines', 'max_width')},
         ))
     except KeyboardInterrupt:
         pass


@@ -133,7 +133,7 @@ class Processor(ChunkingService):
                 chunk_length = len(chunk.page_content)

                 # Save chunk to librarian as child document
-                await self.save_child_document(
+                await self.librarian.save_child_document(
                     doc_id=chunk_doc_id,
                     parent_id=parent_doc_id,
                     user=v.metadata.user,


@@ -131,7 +131,7 @@ class Processor(ChunkingService):
                 chunk_length = len(chunk.page_content)

                 # Save chunk to librarian as child document
-                await self.save_child_document(
+                await self.librarian.save_child_document(
                     doc_id=chunk_doc_id,
                     parent_id=parent_doc_id,
                     user=v.metadata.user,


@@ -9,20 +9,16 @@ for large documents.
 from pypdf import PdfWriter, PdfReader
 from io import BytesIO
-import asyncio
 import base64
 import uuid
 import os

 from mistralai import Mistral
-from mistralai.models import OCRResponse

 from ... schema import Document, TextDocument, Metadata
-from ... schema import LibrarianRequest, LibrarianResponse, DocumentMetadata
 from ... schema import librarian_request_queue, librarian_response_queue
 from ... schema import Triples
-from ... base import FlowProcessor, ConsumerSpec, ProducerSpec
-from ... base import Consumer, Producer, ConsumerMetrics, ProducerMetrics
+from ... base import FlowProcessor, ConsumerSpec, ProducerSpec, LibrarianClient

 from ... provenance import (
     document_uri, page_uri as make_page_uri, derived_entity_triples,
@@ -102,42 +98,10 @@ class Processor(FlowProcessor):
             )
         )

-        # Librarian client for fetching document content
-        librarian_request_q = params.get(
-            "librarian_request_queue", default_librarian_request_queue
-        )
-        librarian_response_q = params.get(
-            "librarian_response_queue", default_librarian_response_queue
-        )
-
-        librarian_request_metrics = ProducerMetrics(
-            processor = id, flow = None, name = "librarian-request"
-        )
-
-        self.librarian_request_producer = Producer(
-            backend = self.pubsub,
-            topic = librarian_request_q,
-            schema = LibrarianRequest,
-            metrics = librarian_request_metrics,
-        )
-
-        librarian_response_metrics = ConsumerMetrics(
-            processor = id, flow = None, name = "librarian-response"
-        )
-
-        self.librarian_response_consumer = Consumer(
-            taskgroup = self.taskgroup,
-            backend = self.pubsub,
-            flow = None,
-            topic = librarian_response_q,
-            subscriber = f"{id}-librarian",
-            schema = LibrarianResponse,
-            handler = self.on_librarian_response,
-            metrics = librarian_response_metrics,
-        )
-
-        # Pending librarian requests: request_id -> asyncio.Future
-        self.pending_requests = {}
+        # Librarian client
+        self.librarian = LibrarianClient(
+            id=id, backend=self.pubsub, taskgroup=self.taskgroup,
+        )

         if api_key is None:
             raise RuntimeError("Mistral API key not specified")
@@ -151,132 +115,7 @@ class Processor(FlowProcessor):
     async def start(self):
         await super(Processor, self).start()
-        await self.librarian_request_producer.start()
-        await self.librarian_response_consumer.start()
+        await self.librarian.start()

-    async def on_librarian_response(self, msg, consumer, flow):
-        """Handle responses from the librarian service."""
-        response = msg.value()
-        request_id = msg.properties().get("id")
-        if request_id and request_id in self.pending_requests:
-            future = self.pending_requests.pop(request_id)
-            future.set_result(response)
-
-    async def fetch_document_metadata(self, document_id, user, timeout=120):
-        """
-        Fetch document metadata from librarian via Pulsar.
-        """
-        request_id = str(uuid.uuid4())
-        request = LibrarianRequest(
-            operation="get-document-metadata",
-            document_id=document_id,
-            user=user,
-        )
-        future = asyncio.get_event_loop().create_future()
-        self.pending_requests[request_id] = future
-        try:
-            await self.librarian_request_producer.send(
-                request, properties={"id": request_id}
-            )
-            response = await asyncio.wait_for(future, timeout=timeout)
-            if response.error:
-                raise RuntimeError(
-                    f"Librarian error: {response.error.type}: {response.error.message}"
-                )
-            return response.document_metadata
-        except asyncio.TimeoutError:
-            self.pending_requests.pop(request_id, None)
-            raise RuntimeError(f"Timeout fetching metadata for {document_id}")
-
-    async def fetch_document_content(self, document_id, user, timeout=120):
-        """
-        Fetch document content from librarian via Pulsar.
-        """
-        request_id = str(uuid.uuid4())
-        request = LibrarianRequest(
-            operation="get-document-content",
-            document_id=document_id,
-            user=user,
-        )
-
-        # Create future for response
-        future = asyncio.get_event_loop().create_future()
-        self.pending_requests[request_id] = future
-
-        try:
-            # Send request
-            await self.librarian_request_producer.send(
-                request, properties={"id": request_id}
-            )
-
-            # Wait for response
-            response = await asyncio.wait_for(future, timeout=timeout)
-
-            if response.error:
-                raise RuntimeError(
-                    f"Librarian error: {response.error.type}: {response.error.message}"
-                )
-
-            return response.content
-
-        except asyncio.TimeoutError:
-            self.pending_requests.pop(request_id, None)
-            raise RuntimeError(f"Timeout fetching document {document_id}")
-
-    async def save_child_document(self, doc_id, parent_id, user, content,
-                                  document_type="page", title=None, timeout=120):
-        """
-        Save a child document to the librarian.
-        """
-        request_id = str(uuid.uuid4())
-
-        doc_metadata = DocumentMetadata(
-            id=doc_id,
-            user=user,
-            kind="text/plain",
-            title=title or doc_id,
-            parent_id=parent_id,
-            document_type=document_type,
-        )
-
-        request = LibrarianRequest(
-            operation="add-child-document",
-            document_metadata=doc_metadata,
-            content=base64.b64encode(content).decode("utf-8"),
-        )
-
-        # Create future for response
-        future = asyncio.get_event_loop().create_future()
-        self.pending_requests[request_id] = future
-
-        try:
-            # Send request
-            await self.librarian_request_producer.send(
-                request, properties={"id": request_id}
-            )
-
-            # Wait for response
-            response = await asyncio.wait_for(future, timeout=timeout)
-
-            if response.error:
-                raise RuntimeError(
-                    f"Librarian error saving child document: {response.error.type}: {response.error.message}"
-                )
-
-            return doc_id
-
-        except asyncio.TimeoutError:
-            self.pending_requests.pop(request_id, None)
-            raise RuntimeError(f"Timeout saving child document {doc_id}")

     def ocr(self, blob):
         """
@@ -359,7 +198,7 @@ class Processor(FlowProcessor):
         # Check MIME type if fetching from librarian
         if v.document_id:
-            doc_meta = await self.fetch_document_metadata(
+            doc_meta = await self.librarian.fetch_document_metadata(
                 document_id=v.document_id,
                 user=v.metadata.user,
             )
@@ -374,7 +213,7 @@ class Processor(FlowProcessor):
         # Get PDF content - fetch from librarian or use inline data
         if v.document_id:
             logger.info(f"Fetching document {v.document_id} from librarian...")
-            content = await self.fetch_document_content(
+            content = await self.librarian.fetch_document_content(
                 document_id=v.document_id,
                 user=v.metadata.user,
             )
@@ -401,7 +240,7 @@ class Processor(FlowProcessor):
             page_content = markdown.encode("utf-8")

             # Save page as child document in librarian
-            await self.save_child_document(
+            await self.librarian.save_child_document(
                 doc_id=page_doc_id,
                 parent_id=source_doc_id,
                 user=v.metadata.user,
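The deleted methods all repeat one pattern, which LibrarianClient now centralises: tag each request with a UUID, park an asyncio.Future under that ID, and resolve the future when a response carrying the same ID arrives. A minimal self-contained sketch of that correlation core (the loopback transport stands in for the broker):

```python
import asyncio
import uuid

class RequestResponseClient:
    """Correlate async responses to requests by UUID (sketch of the pattern)."""
    def __init__(self, send):
        self.send = send        # async callable: (payload, request_id)
        self.pending = {}       # request_id -> asyncio.Future

    async def request(self, payload, timeout=120):
        request_id = str(uuid.uuid4())
        future = asyncio.get_running_loop().create_future()
        self.pending[request_id] = future
        try:
            await self.send(payload, request_id)
            return await asyncio.wait_for(future, timeout=timeout)
        except asyncio.TimeoutError:
            self.pending.pop(request_id, None)  # drop the stale entry
            raise

    def on_response(self, request_id, response):
        # Unknown IDs (e.g. late replies after a timeout) are ignored
        future = self.pending.pop(request_id, None)
        if future is not None:
            future.set_result(response)

async def demo():
    client = RequestResponseClient(send=None)
    async def loopback(payload, request_id):
        # Stand-in transport: the "service" replies on the next loop tick
        asyncio.get_running_loop().call_soon(
            client.on_response, request_id, payload.upper())
    client.send = loopback
    return await client.request("ping", timeout=5)

result = asyncio.run(demo())
```

Centralising this in one class removes six near-identical copies of the future bookkeeping, including the easy-to-miss cleanup of pending entries on timeout.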


@@ -7,20 +7,16 @@ Supports both inline document data and fetching from librarian via Pulsar
 for large documents.
 """

-import asyncio
 import os
 import tempfile
 import base64
 import logging
-import uuid

 from langchain_community.document_loaders import PyPDFLoader

 from ... schema import Document, TextDocument, Metadata
-from ... schema import LibrarianRequest, LibrarianResponse, DocumentMetadata
 from ... schema import librarian_request_queue, librarian_response_queue
 from ... schema import Triples
-from ... base import FlowProcessor, ConsumerSpec, ProducerSpec
-from ... base import Consumer, Producer, ConsumerMetrics, ProducerMetrics
+from ... base import FlowProcessor, ConsumerSpec, ProducerSpec, LibrarianClient

 from ... provenance import (
     document_uri, page_uri as make_page_uri, derived_entity_triples,
@@ -74,187 +70,16 @@ class Processor(FlowProcessor):
             )
         )

-        # Librarian client for fetching document content
-        librarian_request_q = params.get(
-            "librarian_request_queue", default_librarian_request_queue
-        )
-        librarian_response_q = params.get(
-            "librarian_response_queue", default_librarian_response_queue
-        )
-
-        librarian_request_metrics = ProducerMetrics(
-            processor = id, flow = None, name = "librarian-request"
-        )
-
-        self.librarian_request_producer = Producer(
-            backend = self.pubsub,
-            topic = librarian_request_q,
-            schema = LibrarianRequest,
-            metrics = librarian_request_metrics,
-        )
-
-        librarian_response_metrics = ConsumerMetrics(
-            processor = id, flow = None, name = "librarian-response"
-        )
-
-        self.librarian_response_consumer = Consumer(
-            taskgroup = self.taskgroup,
-            backend = self.pubsub,
-            flow = None,
-            topic = librarian_response_q,
-            subscriber = f"{id}-librarian",
-            schema = LibrarianResponse,
-            handler = self.on_librarian_response,
-            metrics = librarian_response_metrics,
-        )
-
-        # Pending librarian requests: request_id -> asyncio.Future
-        self.pending_requests = {}
+        # Librarian client
+        self.librarian = LibrarianClient(
+            id=id, backend=self.pubsub, taskgroup=self.taskgroup,
+        )

         logger.info("PDF decoder initialized")

     async def start(self):
         await super(Processor, self).start()
-        await self.librarian_request_producer.start()
-        await self.librarian_response_consumer.start()
+        await self.librarian.start()

-    async def on_librarian_response(self, msg, consumer, flow):
-        """Handle responses from the librarian service."""
-        response = msg.value()
-        request_id = msg.properties().get("id")
-        if request_id and request_id in self.pending_requests:
-            future = self.pending_requests.pop(request_id)
-            future.set_result(response)
-
-    async def fetch_document_metadata(self, document_id, user, timeout=120):
-        """
-        Fetch document metadata from librarian via Pulsar.
-        """
-        request_id = str(uuid.uuid4())
-        request = LibrarianRequest(
-            operation="get-document-metadata",
-            document_id=document_id,
-            user=user,
-        )
-        future = asyncio.get_event_loop().create_future()
-        self.pending_requests[request_id] = future
-        try:
-            await self.librarian_request_producer.send(
-                request, properties={"id": request_id}
-            )
-            response = await asyncio.wait_for(future, timeout=timeout)
-            if response.error:
-                raise RuntimeError(
-                    f"Librarian error: {response.error.type}: {response.error.message}"
-                )
-            return response.document_metadata
-        except asyncio.TimeoutError:
-            self.pending_requests.pop(request_id, None)
-            raise RuntimeError(f"Timeout fetching metadata for {document_id}")
-
-    async def fetch_document_content(self, document_id, user, timeout=120):
-        """
-        Fetch document content from librarian via Pulsar.
-        """
-        request_id = str(uuid.uuid4())
-        request = LibrarianRequest(
-            operation="get-document-content",
-            document_id=document_id,
-            user=user,
-        )
-
-        # Create future for response
-        future = asyncio.get_event_loop().create_future()
-        self.pending_requests[request_id] = future
-
-        try:
-            # Send request
-            await self.librarian_request_producer.send(
-                request, properties={"id": request_id}
-            )
-
-            # Wait for response
-            response = await asyncio.wait_for(future, timeout=timeout)
-
-            if response.error:
-                raise RuntimeError(
-                    f"Librarian error: {response.error.type}: {response.error.message}"
-                )
-
-            return response.content
-
-        except asyncio.TimeoutError:
-            self.pending_requests.pop(request_id, None)
-            raise RuntimeError(f"Timeout fetching document {document_id}")
-
-    async def save_child_document(self, doc_id, parent_id, user, content,
-                                  document_type="page", title=None, timeout=120):
-        """
-        Save a child document to the librarian.
-
-        Args:
-            doc_id: ID for the new child document
-            parent_id: ID of the parent document
-            user: User ID
-            content: Document content (bytes)
-            document_type: Type of document ("page", "chunk", etc.)
-            title: Optional title
-            timeout: Request timeout in seconds
-
-        Returns:
-            The document ID on success
-        """
-        import base64
-
-        request_id = str(uuid.uuid4())
-
-        doc_metadata = DocumentMetadata(
-            id=doc_id,
-            user=user,
-            kind="text/plain",
-            title=title or doc_id,
-            parent_id=parent_id,
-            document_type=document_type,
-        )
-
-        request = LibrarianRequest(
-            operation="add-child-document",
-            document_metadata=doc_metadata,
-            content=base64.b64encode(content).decode("utf-8"),
-        )
-
-        # Create future for response
-        future = asyncio.get_event_loop().create_future()
-        self.pending_requests[request_id] = future
-
-        try:
-            # Send request
-            await self.librarian_request_producer.send(
-                request, properties={"id": request_id}
-            )
-
-            # Wait for response
-            response = await asyncio.wait_for(future, timeout=timeout)
-
-            if response.error:
-                raise RuntimeError(
-                    f"Librarian error saving child document: {response.error.type}: {response.error.message}"
-                )
-
-            return doc_id
-
-        except asyncio.TimeoutError:
-            self.pending_requests.pop(request_id, None)
-            raise RuntimeError(f"Timeout saving child document {doc_id}")

     async def on_message(self, msg, consumer, flow):
@@ -266,7 +91,7 @@ class Processor(FlowProcessor):
         # Check MIME type if fetching from librarian
         if v.document_id:
-            doc_meta = await self.fetch_document_metadata(
+            doc_meta = await self.librarian.fetch_document_metadata(
                 document_id=v.document_id,
                 user=v.metadata.user,
             )
@@ -287,7 +112,7 @@ class Processor(FlowProcessor):
             logger.info(f"Fetching document {v.document_id} from librarian...")
             fp.close()
-            content = await self.fetch_document_content(
+            content = await self.librarian.fetch_document_content(
                 document_id=v.document_id,
                 user=v.metadata.user,
             )
@@ -323,7 +148,7 @@ class Processor(FlowProcessor):
             page_content = page.page_content.encode("utf-8")

             # Save page as child document in librarian
-            await self.save_child_document(
+            await self.librarian.save_child_document(
                 doc_id=page_doc_id,
                 parent_id=source_doc_id,
                 user=v.metadata.user,
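Content crosses the librarian queues base64-encoded so binary page data survives JSON serialisation. The encode/decode pair is symmetric, and the 4/3 size inflation is why upload chunks must be sized below the broker limit before encoding:

```python
import base64

def encode_content(raw: bytes) -> str:
    # bytes -> ASCII-safe string for the JSON message body
    return base64.b64encode(raw).decode("utf-8")

def decode_content(text: str) -> bytes:
    return base64.b64decode(text.encode("utf-8"))

page = "page 1 markdown".encode("utf-8")
roundtrip = decode_content(encode_content(page))
# A 3MB raw chunk becomes a 4MB string after encoding, which is the
# headroom reasoning behind the 5MB -> 3MB upload chunk reduction.
encoded_size = len(encode_content(bytes(3_000_000)))
```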


@@ -10,7 +10,7 @@ import logging
 import os

 from trustgraph.base.logging import setup_logging
-from trustgraph.base.pubsub import get_pubsub
+from trustgraph.base.pubsub import get_pubsub, add_pubsub_args

 from . auth import Authenticator
 from . config.receiver import ConfigReceiver
@@ -167,30 +167,7 @@ def run():
         help='Service identifier for logging and metrics (default: api-gateway)',
     )

-    # Pub/sub backend selection
-    parser.add_argument(
-        '--pubsub-backend',
-        default=os.getenv('PUBSUB_BACKEND', 'pulsar'),
-        choices=['pulsar', 'mqtt'],
-        help='Pub/sub backend (default: pulsar, env: PUBSUB_BACKEND)',
-    )
-
-    parser.add_argument(
-        '-p', '--pulsar-host',
-        default=default_pulsar_host,
-        help=f'Pulsar host (default: {default_pulsar_host})',
-    )
-
-    parser.add_argument(
-        '--pulsar-api-key',
-        default=default_pulsar_api_key,
-        help=f'Pulsar API key',
-    )
-
-    parser.add_argument(
-        '--pulsar-listener',
-        help=f'Pulsar listener (default: none)',
-    )
+    add_pubsub_args(parser)

     parser.add_argument(
         '-m', '--prometheus-url',
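add_pubsub_args centralises the flags the deleted block declared inline, so every entry point exposes the same options. A sketch of what such a helper looks like (flag names inferred from the surrounding diff; the real implementation may differ):

```python
import argparse
import os

def add_pubsub_args(parser):
    """Register the backend-selection flags shared by every entry point."""
    parser.add_argument(
        '--pubsub-backend',
        default=os.getenv('PUBSUB_BACKEND', 'pulsar'),
        choices=['pulsar', 'rabbitmq'],
        help='Pub/sub backend (default: pulsar, env: PUBSUB_BACKEND)',
    )
    parser.add_argument(
        '-p', '--pulsar-host',
        default='pulsar://pulsar:6650',
        help='Pulsar host',
    )
    parser.add_argument('--pulsar-api-key', help='Pulsar API key')
    parser.add_argument('--pulsar-listener', help='Pulsar listener')

parser = argparse.ArgumentParser()
add_pubsub_args(parser)
args = parser.parse_args(['--pubsub-backend', 'rabbitmq'])
```

Reading the default from the PUBSUB_BACKEND environment variable is what makes the backend switchable per deployment without touching command lines.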


@@ -12,22 +12,18 @@ import uuid
 from ... schema import DocumentRagQuery, DocumentRagResponse, Error
 from ... schema import LibrarianRequest, LibrarianResponse, DocumentMetadata
-from ... schema import librarian_request_queue, librarian_response_queue
 from ... schema import Triples, Metadata
 from ... provenance import GRAPH_RETRIEVAL
 from . document_rag import DocumentRag
 from ... base import FlowProcessor, ConsumerSpec, ProducerSpec
 from ... base import PromptClientSpec, EmbeddingsClientSpec
 from ... base import DocumentEmbeddingsClientSpec
-from ... base import Consumer, Producer
-from ... base import ConsumerMetrics, ProducerMetrics
+from ... base import LibrarianClient

 # Module logger
 logger = logging.getLogger(__name__)

 default_ident = "document-rag"
-default_librarian_request_queue = librarian_request_queue
-default_librarian_response_queue = librarian_response_queue

 class Processor(FlowProcessor):
@@ -89,111 +85,26 @@ class Processor(FlowProcessor):
             )
         )

-        # Librarian client for fetching chunk content from Garage
-        librarian_request_q = params.get(
-            "librarian_request_queue", default_librarian_request_queue
-        )
-        librarian_response_q = params.get(
-            "librarian_response_queue", default_librarian_response_queue
-        )
-
-        librarian_request_metrics = ProducerMetrics(
-            processor=id, flow=None, name="librarian-request"
-        )
-
-        self.librarian_request_producer = Producer(
-            backend=self.pubsub,
-            topic=librarian_request_q,
-            schema=LibrarianRequest,
-            metrics=librarian_request_metrics,
-        )
-
-        librarian_response_metrics = ConsumerMetrics(
-            processor=id, flow=None, name="librarian-response"
-        )
-
-        self.librarian_response_consumer = Consumer(
-            taskgroup=self.taskgroup,
-            backend=self.pubsub,
-            flow=None,
-            topic=librarian_response_q,
-            subscriber=f"{id}-librarian",
-            schema=LibrarianResponse,
-            handler=self.on_librarian_response,
-            metrics=librarian_response_metrics,
-        )
-
-        # Pending librarian requests: request_id -> asyncio.Future
-        self.pending_requests = {}
+        # Librarian client
+        self.librarian = LibrarianClient(
+            id=id,
+            backend=self.pubsub,
+            taskgroup=self.taskgroup,
+        )

     async def start(self):
         await super(Processor, self).start()
-        await self.librarian_request_producer.start()
-        await self.librarian_response_consumer.start()
+        await self.librarian.start()

-    async def on_librarian_response(self, msg, consumer, flow):
-        """Handle responses from the librarian service."""
-        response = msg.value()
-        request_id = msg.properties().get("id")
-        if request_id in self.pending_requests:
-            future = self.pending_requests.pop(request_id)
-            future.set_result(response)

     async def fetch_chunk_content(self, chunk_id, user, timeout=120):
-        """Fetch chunk content from librarian/Garage."""
-        import uuid
-
-        request_id = str(uuid.uuid4())
-
-        request = LibrarianRequest(
-            operation="get-document-content",
-            document_id=chunk_id,
-            user=user,
-        )
-
-        # Create future for response
-        future = asyncio.get_event_loop().create_future()
-        self.pending_requests[request_id] = future
-
-        try:
-            # Send request
-            await self.librarian_request_producer.send(
-                request, properties={"id": request_id}
-            )
-
-            # Wait for response
-            response = await asyncio.wait_for(future, timeout=timeout)
-
-            if response.error:
-                raise RuntimeError(
-                    f"Librarian error: {response.error.type}: {response.error.message}"
-                )
-
-            # Content is base64 encoded
-            content = response.content
-            if isinstance(content, str):
-                content = content.encode('utf-8')
-            return base64.b64decode(content).decode("utf-8")
-
-        except asyncio.TimeoutError:
-            self.pending_requests.pop(request_id, None)
-            raise RuntimeError(f"Timeout fetching chunk {chunk_id}")
+        """Fetch chunk content from librarian. Chunks are small so
+        single request-response is fine."""
+        return await self.librarian.fetch_document_text(
+            document_id=chunk_id, user=user, timeout=timeout,
+        )

     async def save_answer_content(self, doc_id, user, content, title=None, timeout=120):
-        """
-        Save answer content to the librarian.
-
-        Args:
-            doc_id: ID for the answer document
-            user: User ID
-            content: Answer text content
-            title: Optional title
-            timeout: Request timeout in seconds
-
-        Returns:
-            The document ID on success
-        """
-        request_id = str(uuid.uuid4())
+        """Save answer content to the librarian."""

         doc_metadata = DocumentMetadata(
             id=doc_id,
@@ -211,29 +122,8 @@ class Processor(FlowProcessor):
             user=user,
         )

-        # Create future for response
-        future = asyncio.get_event_loop().create_future()
-        self.pending_requests[request_id] = future
-
-        try:
-            # Send request
-            await self.librarian_request_producer.send(
-                request, properties={"id": request_id}
-            )
-
-            # Wait for response
-            response = await asyncio.wait_for(future, timeout=timeout)
-
-            if response.error:
-                raise RuntimeError(
-                    f"Librarian error saving answer: {response.error.type}: {response.error.message}"
-                )
-
-            return doc_id
-
-        except asyncio.TimeoutError:
-            self.pending_requests.pop(request_id, None)
-            raise RuntimeError(f"Timeout saving answer document {doc_id}")
+        await self.librarian.request(request, timeout=timeout)
+        return doc_id

     async def on_request(self, msg, consumer, flow):
@@ -390,4 +280,3 @@ class Processor(FlowProcessor):

 def run():
     Processor.launch(default_ident, __doc__)


@ -7,19 +7,15 @@ Supports both inline document data and fetching from librarian via Pulsar
for large documents. for large documents.
""" """
import asyncio
import base64 import base64
import logging import logging
import uuid
import pytesseract import pytesseract
from pdf2image import convert_from_bytes from pdf2image import convert_from_bytes
from ... schema import Document, TextDocument, Metadata from ... schema import Document, TextDocument, Metadata
from ... schema import LibrarianRequest, LibrarianResponse, DocumentMetadata
from ... schema import librarian_request_queue, librarian_response_queue from ... schema import librarian_request_queue, librarian_response_queue
from ... schema import Triples from ... schema import Triples
from ... base import FlowProcessor, ConsumerSpec, ProducerSpec from ... base import FlowProcessor, ConsumerSpec, ProducerSpec, LibrarianClient
from ... base import Consumer, Producer, ConsumerMetrics, ProducerMetrics
from ... provenance import ( from ... provenance import (
document_uri, page_uri as make_page_uri, derived_entity_triples, document_uri, page_uri as make_page_uri, derived_entity_triples,
@@ -72,173 +68,16 @@ class Processor(FlowProcessor):
             )
         )
-        # Librarian client for fetching document content
-        librarian_request_q = params.get(
-            "librarian_request_queue", default_librarian_request_queue
-        )
-        librarian_response_q = params.get(
-            "librarian_response_queue", default_librarian_response_queue
-        )
-        librarian_request_metrics = ProducerMetrics(
-            processor = id, flow = None, name = "librarian-request"
-        )
-        self.librarian_request_producer = Producer(
-            backend = self.pubsub,
-            topic = librarian_request_q,
-            schema = LibrarianRequest,
-            metrics = librarian_request_metrics,
-        )
-        librarian_response_metrics = ConsumerMetrics(
-            processor = id, flow = None, name = "librarian-response"
-        )
-        self.librarian_response_consumer = Consumer(
-            taskgroup = self.taskgroup,
-            backend = self.pubsub,
-            flow = None,
-            topic = librarian_response_q,
-            subscriber = f"{id}-librarian",
-            schema = LibrarianResponse,
-            handler = self.on_librarian_response,
-            metrics = librarian_response_metrics,
-        )
-        # Pending librarian requests: request_id -> asyncio.Future
-        self.pending_requests = {}
+        # Librarian client
+        self.librarian = LibrarianClient(
+            id=id, backend=self.pubsub, taskgroup=self.taskgroup,
+        )
         logger.info("PDF OCR processor initialized")
     async def start(self):
         await super(Processor, self).start()
-        await self.librarian_request_producer.start()
-        await self.librarian_response_consumer.start()
+        await self.librarian.start()
-    async def on_librarian_response(self, msg, consumer, flow):
-        """Handle responses from the librarian service."""
-        response = msg.value()
-        request_id = msg.properties().get("id")
-        if request_id and request_id in self.pending_requests:
-            future = self.pending_requests.pop(request_id)
-            future.set_result(response)
-    async def fetch_document_metadata(self, document_id, user, timeout=120):
-        """
-        Fetch document metadata from librarian via Pulsar.
-        """
-        request_id = str(uuid.uuid4())
-        request = LibrarianRequest(
-            operation="get-document-metadata",
-            document_id=document_id,
-            user=user,
-        )
-        future = asyncio.get_event_loop().create_future()
-        self.pending_requests[request_id] = future
-        try:
-            await self.librarian_request_producer.send(
-                request, properties={"id": request_id}
-            )
-            response = await asyncio.wait_for(future, timeout=timeout)
-            if response.error:
-                raise RuntimeError(
-                    f"Librarian error: {response.error.type}: {response.error.message}"
-                )
-            return response.document_metadata
-        except asyncio.TimeoutError:
-            self.pending_requests.pop(request_id, None)
-            raise RuntimeError(f"Timeout fetching metadata for {document_id}")
-    async def fetch_document_content(self, document_id, user, timeout=120):
-        """
-        Fetch document content from librarian via Pulsar.
-        """
-        request_id = str(uuid.uuid4())
-        request = LibrarianRequest(
-            operation="get-document-content",
-            document_id=document_id,
-            user=user,
-        )
-        # Create future for response
-        future = asyncio.get_event_loop().create_future()
-        self.pending_requests[request_id] = future
-        try:
-            # Send request
-            await self.librarian_request_producer.send(
-                request, properties={"id": request_id}
-            )
-            # Wait for response
-            response = await asyncio.wait_for(future, timeout=timeout)
-            if response.error:
-                raise RuntimeError(
-                    f"Librarian error: {response.error.type}: {response.error.message}"
-                )
-            return response.content
-        except asyncio.TimeoutError:
-            self.pending_requests.pop(request_id, None)
-            raise RuntimeError(f"Timeout fetching document {document_id}")
-    async def save_child_document(self, doc_id, parent_id, user, content,
-                                  document_type="page", title=None, timeout=120):
-        """
-        Save a child document to the librarian.
-        """
-        request_id = str(uuid.uuid4())
-        doc_metadata = DocumentMetadata(
-            id=doc_id,
-            user=user,
-            kind="text/plain",
-            title=title or doc_id,
-            parent_id=parent_id,
-            document_type=document_type,
-        )
-        request = LibrarianRequest(
-            operation="add-child-document",
-            document_metadata=doc_metadata,
-            content=base64.b64encode(content).decode("utf-8"),
-        )
-        # Create future for response
-        future = asyncio.get_event_loop().create_future()
-        self.pending_requests[request_id] = future
-        try:
-            # Send request
-            await self.librarian_request_producer.send(
-                request, properties={"id": request_id}
-            )
-            # Wait for response
-            response = await asyncio.wait_for(future, timeout=timeout)
-            if response.error:
-                raise RuntimeError(
-                    f"Librarian error saving child document: {response.error.type}: {response.error.message}"
-                )
-            return doc_id
-        except asyncio.TimeoutError:
-            self.pending_requests.pop(request_id, None)
-            raise RuntimeError(f"Timeout saving child document {doc_id}")
     async def on_message(self, msg, consumer, flow):
@@ -250,7 +89,7 @@ class Processor(FlowProcessor):
         # Check MIME type if fetching from librarian
         if v.document_id:
-            doc_meta = await self.fetch_document_metadata(
+            doc_meta = await self.librarian.fetch_document_metadata(
                 document_id=v.document_id,
                 user=v.metadata.user,
             )
@@ -265,7 +104,7 @@ class Processor(FlowProcessor):
         # Get PDF content - fetch from librarian or use inline data
         if v.document_id:
             logger.info(f"Fetching document {v.document_id} from librarian...")
-            content = await self.fetch_document_content(
+            content = await self.librarian.fetch_document_content(
                 document_id=v.document_id,
                 user=v.metadata.user,
             )
@@ -299,7 +138,7 @@ class Processor(FlowProcessor):
             page_content = text.encode("utf-8")
             # Save page as child document in librarian
-            await self.save_child_document(
+            await self.librarian.save_child_document(
                 doc_id=page_doc_id,
                 parent_id=source_doc_id,
                 user=v.metadata.user,
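The removed `save_child_document` helpers all built the same add-child-document request by hand: normalise string content to bytes, wrap the metadata, and base64-encode the payload for transport. A runnable sketch of that request construction (the `DocumentMetadata`/`LibrarianRequest` classes here are stand-in dataclasses mirroring the fields used in the diff, not the real schema types):

```python
import base64
from dataclasses import dataclass
from typing import Optional

@dataclass
class DocumentMetadata:
    id: str
    user: str
    kind: str
    title: str
    parent_id: str
    document_type: str

@dataclass
class LibrarianRequest:
    operation: str
    document_metadata: Optional[DocumentMetadata] = None
    content: Optional[str] = None

def make_add_child_request(doc_id, parent_id, user, content,
                           document_type="page", title=None,
                           kind="text/plain"):
    # Accept str or bytes, as the universal decoder's helper did
    if isinstance(content, str):
        content = content.encode("utf-8")
    return LibrarianRequest(
        operation="add-child-document",
        document_metadata=DocumentMetadata(
            id=doc_id, user=user, kind=kind, title=title or doc_id,
            parent_id=parent_id, document_type=document_type,
        ),
        # Content travels base64-encoded inside the request message
        content=base64.b64encode(content).decode("utf-8"),
    )
```

With LibrarianClient this construction lives in one place, so the two decoders can no longer drift (the PDF OCR copy hard-coded `kind="text/plain"` while the universal decoder's copy made it a parameter).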


@@ -14,22 +14,18 @@ Tables are preserved as HTML markup for better downstream extraction.
 Images are stored in the librarian but not sent to the text pipeline.
 """
-import asyncio
 import base64
 import logging
 import magic
 import tempfile
 import os
-import uuid
 from unstructured.partition.auto import partition
 from ... schema import Document, TextDocument, Metadata
-from ... schema import LibrarianRequest, LibrarianResponse, DocumentMetadata
 from ... schema import librarian_request_queue, librarian_response_queue
 from ... schema import Triples
-from ... base import FlowProcessor, ConsumerSpec, ProducerSpec
-from ... base import Consumer, Producer, ConsumerMetrics, ProducerMetrics
+from ... base import FlowProcessor, ConsumerSpec, ProducerSpec, LibrarianClient
 from ... provenance import (
     document_uri, page_uri as make_page_uri,
@@ -166,128 +162,16 @@ class Processor(FlowProcessor):
             )
         )
-        # Librarian client for fetching/storing document content
-        librarian_request_q = params.get(
-            "librarian_request_queue", default_librarian_request_queue
-        )
-        librarian_response_q = params.get(
-            "librarian_response_queue", default_librarian_response_queue
-        )
-        librarian_request_metrics = ProducerMetrics(
-            processor=id, flow=None, name="librarian-request"
-        )
-        self.librarian_request_producer = Producer(
-            backend=self.pubsub,
-            topic=librarian_request_q,
-            schema=LibrarianRequest,
-            metrics=librarian_request_metrics,
-        )
-        librarian_response_metrics = ConsumerMetrics(
-            processor=id, flow=None, name="librarian-response"
-        )
-        self.librarian_response_consumer = Consumer(
-            taskgroup=self.taskgroup,
-            backend=self.pubsub,
-            flow=None,
-            topic=librarian_response_q,
-            subscriber=f"{id}-librarian",
-            schema=LibrarianResponse,
-            handler=self.on_librarian_response,
-            metrics=librarian_response_metrics,
-        )
-        # Pending librarian requests: request_id -> asyncio.Future
-        self.pending_requests = {}
+        # Librarian client
+        self.librarian = LibrarianClient(
+            id=id, backend=self.pubsub, taskgroup=self.taskgroup,
+        )
         logger.info("Universal decoder initialized")
     async def start(self):
         await super(Processor, self).start()
-        await self.librarian_request_producer.start()
-        await self.librarian_response_consumer.start()
+        await self.librarian.start()
-    async def on_librarian_response(self, msg, consumer, flow):
-        """Handle responses from the librarian service."""
-        response = msg.value()
-        request_id = msg.properties().get("id")
-        if request_id and request_id in self.pending_requests:
-            future = self.pending_requests.pop(request_id)
-            future.set_result(response)
-    async def _librarian_request(self, request, timeout=120):
-        """Send a request to the librarian and wait for response."""
-        request_id = str(uuid.uuid4())
-        future = asyncio.get_event_loop().create_future()
-        self.pending_requests[request_id] = future
-        try:
-            await self.librarian_request_producer.send(
-                request, properties={"id": request_id}
-            )
-            response = await asyncio.wait_for(future, timeout=timeout)
-            if response.error:
-                raise RuntimeError(
-                    f"Librarian error: {response.error.type}: "
-                    f"{response.error.message}"
-                )
-            return response
-        except asyncio.TimeoutError:
-            self.pending_requests.pop(request_id, None)
-            raise RuntimeError("Timeout waiting for librarian response")
-    async def fetch_document_metadata(self, document_id, user):
-        """Fetch document metadata from the librarian."""
-        request = LibrarianRequest(
-            operation="get-document-metadata",
-            document_id=document_id,
-            user=user,
-        )
-        response = await self._librarian_request(request)
-        return response.document_metadata
-    async def fetch_document_content(self, document_id, user):
-        """Fetch document content from the librarian."""
-        request = LibrarianRequest(
-            operation="get-document-content",
-            document_id=document_id,
-            user=user,
-        )
-        response = await self._librarian_request(request)
-        return response.content
-    async def save_child_document(self, doc_id, parent_id, user, content,
-                                  document_type="page", title=None,
-                                  kind="text/plain"):
-        """Save a child document to the librarian."""
-        if isinstance(content, str):
-            content = content.encode("utf-8")
-        doc_metadata = DocumentMetadata(
-            id=doc_id,
-            user=user,
-            kind=kind,
-            title=title or doc_id,
-            parent_id=parent_id,
-            document_type=document_type,
-        )
-        request = LibrarianRequest(
-            operation="add-child-document",
-            document_metadata=doc_metadata,
-            content=base64.b64encode(content).decode("utf-8"),
-        )
-        await self._librarian_request(request)
-        return doc_id
     def extract_elements(self, blob, mime_type=None):
         """
@@ -388,7 +272,7 @@ class Processor(FlowProcessor):
             page_content = text.encode("utf-8")
             # Save to librarian
-            await self.save_child_document(
+            await self.librarian.save_child_document(
                 doc_id=doc_id,
                 parent_id=parent_doc_id,
                 user=metadata.user,
@@ -469,7 +353,7 @@ class Processor(FlowProcessor):
             # Save to librarian
             if img_content:
-                await self.save_child_document(
+                await self.librarian.save_child_document(
                     doc_id=img_uri,
                     parent_id=parent_doc_id,
                     user=metadata.user,
@@ -518,13 +402,13 @@ class Processor(FlowProcessor):
                 f"Fetching document {v.document_id} from librarian..."
             )
-            doc_meta = await self.fetch_document_metadata(
+            doc_meta = await self.librarian.fetch_document_metadata(
                 document_id=v.document_id,
                 user=v.metadata.user,
             )
             mime_type = doc_meta.kind if doc_meta else None
-            content = await self.fetch_document_content(
+            content = await self.librarian.fetch_document_content(
                 document_id=v.document_id,
                 user=v.metadata.user,
             )
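Per the PR description, `fetch_document_content` in LibrarianClient now uses stream-document rather than get-document-content, pulling content in 1MB chunks so no single broker message exceeds size limits. The chunk framing is not shown in these hunks, so the following is only a hypothetical sketch of chunked transfer and reassembly, with illustrative names (`split_document`, `reassemble`) and an assumed (seq, is_last, chunk) framing:

```python
CHUNK_SIZE = 1024 * 1024   # 1 MiB chunks, per the PR description

def split_document(blob: bytes):
    """Producer side: yield (seq, is_last, chunk) tuples."""
    n = max(1, (len(blob) + CHUNK_SIZE - 1) // CHUNK_SIZE)
    for seq in range(n):
        yield seq, seq == n - 1, blob[seq * CHUNK_SIZE:(seq + 1) * CHUNK_SIZE]

def reassemble(chunks):
    """Consumer side: concatenate chunks in sequence order."""
    parts = {}
    for seq, is_last, chunk in chunks:
        parts[seq] = chunk
    # Sequence numbers tolerate out-of-order delivery
    return b"".join(parts[i] for i in sorted(parts))
```

Whatever the real framing, the design point stands: streaming shifts the size limit from "largest document" to "largest chunk", which is also why the librarian queues can return to the request/response class and the API upload chunk drops to 3MB.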