diff --git a/Makefile b/Makefile index 197a6c63..4d79f554 100644 --- a/Makefile +++ b/Makefile @@ -77,8 +77,8 @@ some-containers: -t ${CONTAINER_BASE}/trustgraph-base:${VERSION} . ${DOCKER} build -f containers/Containerfile.flow \ -t ${CONTAINER_BASE}/trustgraph-flow:${VERSION} . -# ${DOCKER} build -f containers/Containerfile.unstructured \ -# -t ${CONTAINER_BASE}/trustgraph-unstructured:${VERSION} . + ${DOCKER} build -f containers/Containerfile.unstructured \ + -t ${CONTAINER_BASE}/trustgraph-unstructured:${VERSION} . # ${DOCKER} build -f containers/Containerfile.vertexai \ # -t ${CONTAINER_BASE}/trustgraph-vertexai:${VERSION} . # ${DOCKER} build -f containers/Containerfile.mcp \ diff --git a/specs/api/components/schemas/librarian/LibrarianRequest.yaml b/specs/api/components/schemas/librarian/LibrarianRequest.yaml index 18aa94b1..eed999f0 100644 --- a/specs/api/components/schemas/librarian/LibrarianRequest.yaml +++ b/specs/api/components/schemas/librarian/LibrarianRequest.yaml @@ -3,6 +3,9 @@ description: | Librarian service request for document library management. 
Operations: add-document, remove-document, list-documents, + get-document-metadata, get-document-content, stream-document, + add-child-document, list-children, begin-upload, upload-chunk, + complete-upload, abort-upload, get-upload-status, list-uploads, start-processing, stop-processing, list-processing required: - operation @@ -13,6 +16,17 @@ properties: - add-document - remove-document - list-documents + - get-document-metadata + - get-document-content + - stream-document + - add-child-document + - list-children + - begin-upload + - upload-chunk + - complete-upload + - abort-upload + - get-upload-status + - list-uploads - start-processing - stop-processing - list-processing @@ -21,6 +35,21 @@ properties: - `add-document`: Add document to library - `remove-document`: Remove document from library - `list-documents`: List documents in library + - `get-document-metadata`: Get document metadata + - `get-document-content`: Get full document content in a single response. + **Deprecated** — use `stream-document` instead. Fails for documents + exceeding the broker's max message size. + - `stream-document`: Stream document content in chunks. Each response + includes `chunk_index` and `is_final`. Preferred over `get-document-content` + for all document sizes. + - `add-child-document`: Add a child document (e.g. 
page, chunk) + - `list-children`: List child documents of a parent + - `begin-upload`: Start a chunked upload session + - `upload-chunk`: Upload a chunk of data + - `complete-upload`: Finalize a chunked upload + - `abort-upload`: Cancel a chunked upload + - `get-upload-status`: Check upload progress + - `list-uploads`: List active upload sessions - `start-processing`: Start processing library documents - `stop-processing`: Stop library processing - `list-processing`: List processing status diff --git a/tests/unit/test_chunking/test_recursive_chunker.py b/tests/unit/test_chunking/test_recursive_chunker.py index ae05d22c..a5ec59c8 100644 --- a/tests/unit/test_chunking/test_recursive_chunker.py +++ b/tests/unit/test_chunking/test_recursive_chunker.py @@ -24,8 +24,8 @@ class MockAsyncProcessor: class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase): """Test Recursive chunker functionality""" - @patch('trustgraph.base.chunking_service.Consumer') - @patch('trustgraph.base.chunking_service.Producer') + @patch('trustgraph.base.librarian_client.Consumer') + @patch('trustgraph.base.librarian_client.Producer') @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) def test_processor_initialization_basic(self, mock_producer, mock_consumer): """Test basic processor initialization""" @@ -51,8 +51,8 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase): if hasattr(spec, 'name') and spec.name in ['chunk-size', 'chunk-overlap']] assert len(param_specs) == 2 - @patch('trustgraph.base.chunking_service.Consumer') - @patch('trustgraph.base.chunking_service.Producer') + @patch('trustgraph.base.librarian_client.Consumer') + @patch('trustgraph.base.librarian_client.Producer') @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) async def test_chunk_document_with_chunk_size_override(self, mock_producer, mock_consumer): """Test chunk_document with chunk-size parameter override""" @@ -71,7 +71,7 @@ class 
TestRecursiveChunkerSimple(IsolatedAsyncioTestCase): mock_message = MagicMock() mock_consumer = MagicMock() mock_flow = MagicMock() - mock_flow.side_effect = lambda param: { + mock_flow.parameters.get.side_effect = lambda param: { "chunk-size": 2000, # Override chunk size "chunk-overlap": None # Use default chunk overlap }.get(param) @@ -85,8 +85,8 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase): assert chunk_size == 2000 # Should use overridden value assert chunk_overlap == 100 # Should use default value - @patch('trustgraph.base.chunking_service.Consumer') - @patch('trustgraph.base.chunking_service.Producer') + @patch('trustgraph.base.librarian_client.Consumer') + @patch('trustgraph.base.librarian_client.Producer') @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) async def test_chunk_document_with_chunk_overlap_override(self, mock_producer, mock_consumer): """Test chunk_document with chunk-overlap parameter override""" @@ -105,7 +105,7 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase): mock_message = MagicMock() mock_consumer = MagicMock() mock_flow = MagicMock() - mock_flow.side_effect = lambda param: { + mock_flow.parameters.get.side_effect = lambda param: { "chunk-size": None, # Use default chunk size "chunk-overlap": 200 # Override chunk overlap }.get(param) @@ -119,8 +119,8 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase): assert chunk_size == 1000 # Should use default value assert chunk_overlap == 200 # Should use overridden value - @patch('trustgraph.base.chunking_service.Consumer') - @patch('trustgraph.base.chunking_service.Producer') + @patch('trustgraph.base.librarian_client.Consumer') + @patch('trustgraph.base.librarian_client.Producer') @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) async def test_chunk_document_with_both_parameters_override(self, mock_producer, mock_consumer): """Test chunk_document with both chunk-size and chunk-overlap overrides""" @@ 
-139,7 +139,7 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase): mock_message = MagicMock() mock_consumer = MagicMock() mock_flow = MagicMock() - mock_flow.side_effect = lambda param: { + mock_flow.parameters.get.side_effect = lambda param: { "chunk-size": 1500, # Override chunk size "chunk-overlap": 150 # Override chunk overlap }.get(param) @@ -153,8 +153,8 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase): assert chunk_size == 1500 # Should use overridden value assert chunk_overlap == 150 # Should use overridden value - @patch('trustgraph.base.chunking_service.Consumer') - @patch('trustgraph.base.chunking_service.Producer') + @patch('trustgraph.base.librarian_client.Consumer') + @patch('trustgraph.base.librarian_client.Producer') @patch('trustgraph.chunking.recursive.chunker.RecursiveCharacterTextSplitter') @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) async def test_on_message_uses_flow_parameters(self, mock_splitter_class, mock_producer, mock_consumer): @@ -177,7 +177,7 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase): processor = Processor(**config) # Mock save_child_document to avoid waiting for librarian response - processor.save_child_document = AsyncMock(return_value="mock-doc-id") + processor.librarian.save_child_document = AsyncMock(return_value="mock-doc-id") # Mock message with TextDocument mock_message = MagicMock() @@ -196,12 +196,14 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase): mock_producer = AsyncMock() mock_triples_producer = AsyncMock() mock_flow = MagicMock() - mock_flow.side_effect = lambda param: { + mock_flow.parameters.get.side_effect = lambda param: { "chunk-size": 1500, "chunk-overlap": 150, + }.get(param) + mock_flow.side_effect = lambda name: { "output": mock_producer, "triples": mock_triples_producer, - }.get(param) + }.get(name) # Act await processor.on_message(mock_message, mock_consumer, mock_flow) @@ -219,8 +221,8 @@ class 
TestRecursiveChunkerSimple(IsolatedAsyncioTestCase): sent_chunk = mock_producer.send.call_args[0][0] assert isinstance(sent_chunk, Chunk) - @patch('trustgraph.base.chunking_service.Consumer') - @patch('trustgraph.base.chunking_service.Producer') + @patch('trustgraph.base.librarian_client.Consumer') + @patch('trustgraph.base.librarian_client.Producer') @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) async def test_chunk_document_with_no_overrides(self, mock_producer, mock_consumer): """Test chunk_document when no parameters are overridden (flow returns None)""" @@ -239,7 +241,7 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase): mock_message = MagicMock() mock_consumer = MagicMock() mock_flow = MagicMock() - mock_flow.return_value = None # No overrides + mock_flow.parameters.get.return_value = None # No overrides # Act chunk_size, chunk_overlap = await processor.chunk_document( diff --git a/tests/unit/test_chunking/test_token_chunker.py b/tests/unit/test_chunking/test_token_chunker.py index 2ed37391..f3f83904 100644 --- a/tests/unit/test_chunking/test_token_chunker.py +++ b/tests/unit/test_chunking/test_token_chunker.py @@ -24,8 +24,8 @@ class MockAsyncProcessor: class TestTokenChunkerSimple(IsolatedAsyncioTestCase): """Test Token chunker functionality""" - @patch('trustgraph.base.chunking_service.Consumer') - @patch('trustgraph.base.chunking_service.Producer') + @patch('trustgraph.base.librarian_client.Consumer') + @patch('trustgraph.base.librarian_client.Producer') @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) def test_processor_initialization_basic(self, mock_producer, mock_consumer): """Test basic processor initialization""" @@ -51,8 +51,8 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase): if hasattr(spec, 'name') and spec.name in ['chunk-size', 'chunk-overlap']] assert len(param_specs) == 2 - @patch('trustgraph.base.chunking_service.Consumer') - 
@patch('trustgraph.base.chunking_service.Producer') + @patch('trustgraph.base.librarian_client.Consumer') + @patch('trustgraph.base.librarian_client.Producer') @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) async def test_chunk_document_with_chunk_size_override(self, mock_producer, mock_consumer): """Test chunk_document with chunk-size parameter override""" @@ -71,7 +71,7 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase): mock_message = MagicMock() mock_consumer = MagicMock() mock_flow = MagicMock() - mock_flow.side_effect = lambda param: { + mock_flow.parameters.get.side_effect = lambda param: { "chunk-size": 400, # Override chunk size "chunk-overlap": None # Use default chunk overlap }.get(param) @@ -85,8 +85,8 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase): assert chunk_size == 400 # Should use overridden value assert chunk_overlap == 15 # Should use default value - @patch('trustgraph.base.chunking_service.Consumer') - @patch('trustgraph.base.chunking_service.Producer') + @patch('trustgraph.base.librarian_client.Consumer') + @patch('trustgraph.base.librarian_client.Producer') @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) async def test_chunk_document_with_chunk_overlap_override(self, mock_producer, mock_consumer): """Test chunk_document with chunk-overlap parameter override""" @@ -105,7 +105,7 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase): mock_message = MagicMock() mock_consumer = MagicMock() mock_flow = MagicMock() - mock_flow.side_effect = lambda param: { + mock_flow.parameters.get.side_effect = lambda param: { "chunk-size": None, # Use default chunk size "chunk-overlap": 25 # Override chunk overlap }.get(param) @@ -119,8 +119,8 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase): assert chunk_size == 250 # Should use default value assert chunk_overlap == 25 # Should use overridden value - @patch('trustgraph.base.chunking_service.Consumer') - 
@patch('trustgraph.base.chunking_service.Producer') + @patch('trustgraph.base.librarian_client.Consumer') + @patch('trustgraph.base.librarian_client.Producer') @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) async def test_chunk_document_with_both_parameters_override(self, mock_producer, mock_consumer): """Test chunk_document with both chunk-size and chunk-overlap overrides""" @@ -139,7 +139,7 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase): mock_message = MagicMock() mock_consumer = MagicMock() mock_flow = MagicMock() - mock_flow.side_effect = lambda param: { + mock_flow.parameters.get.side_effect = lambda param: { "chunk-size": 350, # Override chunk size "chunk-overlap": 30 # Override chunk overlap }.get(param) @@ -153,8 +153,8 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase): assert chunk_size == 350 # Should use overridden value assert chunk_overlap == 30 # Should use overridden value - @patch('trustgraph.base.chunking_service.Consumer') - @patch('trustgraph.base.chunking_service.Producer') + @patch('trustgraph.base.librarian_client.Consumer') + @patch('trustgraph.base.librarian_client.Producer') @patch('trustgraph.chunking.token.chunker.TokenTextSplitter') @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) async def test_on_message_uses_flow_parameters(self, mock_splitter_class, mock_producer, mock_consumer): @@ -177,7 +177,7 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase): processor = Processor(**config) # Mock save_child_document to avoid librarian producer interactions - processor.save_child_document = AsyncMock(return_value="chunk-id") + processor.librarian.save_child_document = AsyncMock(return_value="chunk-id") # Mock message with TextDocument mock_message = MagicMock() @@ -196,12 +196,14 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase): mock_producer = AsyncMock() mock_triples_producer = AsyncMock() mock_flow = MagicMock() - mock_flow.side_effect = lambda param: { + 
mock_flow.parameters.get.side_effect = lambda param: { "chunk-size": 400, "chunk-overlap": 40, + }.get(param) + mock_flow.side_effect = lambda name: { "output": mock_producer, "triples": mock_triples_producer, - }.get(param) + }.get(name) # Act await processor.on_message(mock_message, mock_consumer, mock_flow) @@ -223,8 +225,8 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase): sent_chunk = mock_producer.send.call_args[0][0] assert isinstance(sent_chunk, Chunk) - @patch('trustgraph.base.chunking_service.Consumer') - @patch('trustgraph.base.chunking_service.Producer') + @patch('trustgraph.base.librarian_client.Consumer') + @patch('trustgraph.base.librarian_client.Producer') @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) async def test_chunk_document_with_no_overrides(self, mock_producer, mock_consumer): """Test chunk_document when no parameters are overridden (flow returns None)""" @@ -243,7 +245,7 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase): mock_message = MagicMock() mock_consumer = MagicMock() mock_flow = MagicMock() - mock_flow.return_value = None # No overrides + mock_flow.parameters.get.return_value = None # No overrides # Act chunk_size, chunk_overlap = await processor.chunk_document( @@ -254,8 +256,8 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase): assert chunk_size == 250 # Should use default value assert chunk_overlap == 15 # Should use default value - @patch('trustgraph.base.chunking_service.Consumer') - @patch('trustgraph.base.chunking_service.Producer') + @patch('trustgraph.base.librarian_client.Consumer') + @patch('trustgraph.base.librarian_client.Producer') @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) def test_token_chunker_uses_different_defaults(self, mock_producer, mock_consumer): """Test that token chunker has different defaults than recursive chunker""" diff --git a/tests/unit/test_concurrency/test_consumer_concurrency.py 
b/tests/unit/test_concurrency/test_consumer_concurrency.py index 32a6559b..3869aaf3 100644 --- a/tests/unit/test_concurrency/test_consumer_concurrency.py +++ b/tests/unit/test_concurrency/test_consumer_concurrency.py @@ -83,7 +83,7 @@ class TestTaskGroupConcurrency: call_count = 0 original_running = True - async def mock_consume(): + async def mock_consume(backend_consumer): nonlocal call_count call_count += 1 # Wait a bit to let all tasks start, then signal stop @@ -107,7 +107,7 @@ class TestTaskGroupConcurrency: consumer = _make_consumer(concurrency=1) call_count = 0 - async def mock_consume(): + async def mock_consume(backend_consumer): nonlocal call_count call_count += 1 await asyncio.sleep(0.01) @@ -147,7 +147,7 @@ class TestRateLimitRetry: mock_msg = _make_msg() consumer.consumer = MagicMock() - await consumer.handle_one_from_queue(mock_msg) + await consumer.handle_one_from_queue(mock_msg, consumer.consumer) assert call_count == 2 consumer.consumer.acknowledge.assert_called_once_with(mock_msg) @@ -166,7 +166,7 @@ class TestRateLimitRetry: mock_msg = _make_msg() consumer.consumer = MagicMock() - await consumer.handle_one_from_queue(mock_msg) + await consumer.handle_one_from_queue(mock_msg, consumer.consumer) consumer.consumer.negative_acknowledge.assert_called_with(mock_msg) consumer.consumer.acknowledge.assert_not_called() @@ -185,7 +185,7 @@ class TestRateLimitRetry: mock_msg = _make_msg() consumer.consumer = MagicMock() - await consumer.handle_one_from_queue(mock_msg) + await consumer.handle_one_from_queue(mock_msg, consumer.consumer) assert call_count == 1 consumer.consumer.negative_acknowledge.assert_called_once_with(mock_msg) @@ -197,7 +197,7 @@ class TestRateLimitRetry: mock_msg = _make_msg() consumer.consumer = MagicMock() - await consumer.handle_one_from_queue(mock_msg) + await consumer.handle_one_from_queue(mock_msg, consumer.consumer) consumer.consumer.acknowledge.assert_called_once_with(mock_msg) @@ -219,7 +219,7 @@ class TestMetricsIntegration: 
mock_metrics.record_time.return_value.__exit__ = MagicMock() consumer.metrics = mock_metrics - await consumer.handle_one_from_queue(mock_msg) + await consumer.handle_one_from_queue(mock_msg, consumer.consumer) mock_metrics.process.assert_called_once_with("success") @@ -235,7 +235,7 @@ class TestMetricsIntegration: mock_metrics = MagicMock() consumer.metrics = mock_metrics - await consumer.handle_one_from_queue(mock_msg) + await consumer.handle_one_from_queue(mock_msg, consumer.consumer) mock_metrics.process.assert_called_once_with("error") @@ -261,7 +261,7 @@ class TestMetricsIntegration: mock_metrics.record_time.return_value.__exit__ = MagicMock(return_value=False) consumer.metrics = mock_metrics - await consumer.handle_one_from_queue(mock_msg) + await consumer.handle_one_from_queue(mock_msg, consumer.consumer) mock_metrics.rate_limit.assert_called_once() diff --git a/tests/unit/test_decoding/test_mistral_ocr_processor.py b/tests/unit/test_decoding/test_mistral_ocr_processor.py index 3243666c..2b8c25e2 100644 --- a/tests/unit/test_decoding/test_mistral_ocr_processor.py +++ b/tests/unit/test_decoding/test_mistral_ocr_processor.py @@ -25,8 +25,8 @@ class MockAsyncProcessor: class TestMistralOcrProcessor(IsolatedAsyncioTestCase): """Test Mistral OCR processor functionality""" - @patch('trustgraph.decoding.mistral_ocr.processor.Consumer') - @patch('trustgraph.decoding.mistral_ocr.processor.Producer') + @patch('trustgraph.base.librarian_client.Consumer') + @patch('trustgraph.base.librarian_client.Producer') @patch('trustgraph.decoding.mistral_ocr.processor.Mistral') @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) async def test_processor_initialization_with_api_key( @@ -51,8 +51,8 @@ class TestMistralOcrProcessor(IsolatedAsyncioTestCase): assert consumer_specs[0].name == "input" assert consumer_specs[0].schema == Document - @patch('trustgraph.decoding.mistral_ocr.processor.Consumer') - 
@patch('trustgraph.decoding.mistral_ocr.processor.Producer') + @patch('trustgraph.base.librarian_client.Consumer') + @patch('trustgraph.base.librarian_client.Producer') @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) async def test_processor_initialization_without_api_key( self, mock_producer, mock_consumer @@ -66,8 +66,8 @@ class TestMistralOcrProcessor(IsolatedAsyncioTestCase): with pytest.raises(RuntimeError, match="Mistral API key not specified"): Processor(**config) - @patch('trustgraph.decoding.mistral_ocr.processor.Consumer') - @patch('trustgraph.decoding.mistral_ocr.processor.Producer') + @patch('trustgraph.base.librarian_client.Consumer') + @patch('trustgraph.base.librarian_client.Producer') @patch('trustgraph.decoding.mistral_ocr.processor.Mistral') @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) async def test_ocr_single_chunk( @@ -131,8 +131,8 @@ class TestMistralOcrProcessor(IsolatedAsyncioTestCase): ) mock_mistral.ocr.process.assert_called_once() - @patch('trustgraph.decoding.mistral_ocr.processor.Consumer') - @patch('trustgraph.decoding.mistral_ocr.processor.Producer') + @patch('trustgraph.base.librarian_client.Consumer') + @patch('trustgraph.base.librarian_client.Producer') @patch('trustgraph.decoding.mistral_ocr.processor.Mistral') @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) async def test_on_message_success( @@ -172,7 +172,7 @@ class TestMistralOcrProcessor(IsolatedAsyncioTestCase): ] # Mock save_child_document - processor.save_child_document = AsyncMock(return_value="mock-doc-id") + processor.librarian.save_child_document = AsyncMock(return_value="mock-doc-id") with patch.object(processor, 'ocr', return_value=ocr_result): await processor.on_message(mock_msg, None, mock_flow) diff --git a/tests/unit/test_decoding/test_pdf_decoder.py b/tests/unit/test_decoding/test_pdf_decoder.py index 22659479..d2183c0c 100644 --- 
a/tests/unit/test_decoding/test_pdf_decoder.py +++ b/tests/unit/test_decoding/test_pdf_decoder.py @@ -24,12 +24,10 @@ class MockAsyncProcessor: class TestPdfDecoderProcessor(IsolatedAsyncioTestCase): """Test PDF decoder processor functionality""" - @patch('trustgraph.base.chunking_service.Consumer') - @patch('trustgraph.base.chunking_service.Producer') - @patch('trustgraph.decoding.pdf.pdf_decoder.Consumer') - @patch('trustgraph.decoding.pdf.pdf_decoder.Producer') + @patch('trustgraph.base.librarian_client.Consumer') + @patch('trustgraph.base.librarian_client.Producer') @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) - async def test_processor_initialization(self, mock_producer, mock_consumer, mock_cs_producer, mock_cs_consumer): + async def test_processor_initialization(self, mock_producer, mock_consumer): """Test PDF decoder processor initialization""" config = { 'id': 'test-pdf-decoder', @@ -44,13 +42,11 @@ class TestPdfDecoderProcessor(IsolatedAsyncioTestCase): assert consumer_specs[0].name == "input" assert consumer_specs[0].schema == Document - @patch('trustgraph.base.chunking_service.Consumer') - @patch('trustgraph.base.chunking_service.Producer') - @patch('trustgraph.decoding.pdf.pdf_decoder.Consumer') - @patch('trustgraph.decoding.pdf.pdf_decoder.Producer') + @patch('trustgraph.base.librarian_client.Consumer') + @patch('trustgraph.base.librarian_client.Producer') @patch('trustgraph.decoding.pdf.pdf_decoder.PyPDFLoader') @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) - async def test_on_message_success(self, mock_pdf_loader_class, mock_producer, mock_consumer, mock_cs_producer, mock_cs_consumer): + async def test_on_message_success(self, mock_pdf_loader_class, mock_producer, mock_consumer): """Test successful PDF processing""" # Mock PDF content pdf_content = b"fake pdf content" @@ -85,7 +81,7 @@ class TestPdfDecoderProcessor(IsolatedAsyncioTestCase): processor = Processor(**config) # Mock 
save_child_document to avoid waiting for librarian response - processor.save_child_document = AsyncMock(return_value="mock-doc-id") + processor.librarian.save_child_document = AsyncMock(return_value="mock-doc-id") await processor.on_message(mock_msg, None, mock_flow) @@ -94,13 +90,11 @@ class TestPdfDecoderProcessor(IsolatedAsyncioTestCase): # Verify triples were sent for each page (provenance) assert mock_triples_flow.send.call_count == 2 - @patch('trustgraph.base.chunking_service.Consumer') - @patch('trustgraph.base.chunking_service.Producer') - @patch('trustgraph.decoding.pdf.pdf_decoder.Consumer') - @patch('trustgraph.decoding.pdf.pdf_decoder.Producer') + @patch('trustgraph.base.librarian_client.Consumer') + @patch('trustgraph.base.librarian_client.Producer') @patch('trustgraph.decoding.pdf.pdf_decoder.PyPDFLoader') @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) - async def test_on_message_empty_pdf(self, mock_pdf_loader_class, mock_producer, mock_consumer, mock_cs_producer, mock_cs_consumer): + async def test_on_message_empty_pdf(self, mock_pdf_loader_class, mock_producer, mock_consumer): """Test handling of empty PDF""" pdf_content = b"fake pdf content" pdf_base64 = base64.b64encode(pdf_content).decode('utf-8') @@ -128,13 +122,11 @@ class TestPdfDecoderProcessor(IsolatedAsyncioTestCase): mock_output_flow.send.assert_not_called() - @patch('trustgraph.base.chunking_service.Consumer') - @patch('trustgraph.base.chunking_service.Producer') - @patch('trustgraph.decoding.pdf.pdf_decoder.Consumer') - @patch('trustgraph.decoding.pdf.pdf_decoder.Producer') + @patch('trustgraph.base.librarian_client.Consumer') + @patch('trustgraph.base.librarian_client.Producer') @patch('trustgraph.decoding.pdf.pdf_decoder.PyPDFLoader') @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) - async def test_on_message_unicode_content(self, mock_pdf_loader_class, mock_producer, mock_consumer, mock_cs_producer, mock_cs_consumer): + async 
def test_on_message_unicode_content(self, mock_pdf_loader_class, mock_producer, mock_consumer): """Test handling of unicode content in PDF""" pdf_content = b"fake pdf content" pdf_base64 = base64.b64encode(pdf_content).decode('utf-8') @@ -165,7 +157,7 @@ class TestPdfDecoderProcessor(IsolatedAsyncioTestCase): processor = Processor(**config) # Mock save_child_document to avoid waiting for librarian response - processor.save_child_document = AsyncMock(return_value="mock-doc-id") + processor.librarian.save_child_document = AsyncMock(return_value="mock-doc-id") await processor.on_message(mock_msg, None, mock_flow) diff --git a/tests/unit/test_decoding/test_universal_processor.py b/tests/unit/test_decoding/test_universal_processor.py index 8d2e116e..4daa9b68 100644 --- a/tests/unit/test_decoding/test_universal_processor.py +++ b/tests/unit/test_decoding/test_universal_processor.py @@ -142,8 +142,8 @@ class TestPageBasedFormats: class TestUniversalProcessor(IsolatedAsyncioTestCase): """Test universal decoder processor.""" - @patch('trustgraph.decoding.universal.processor.Consumer') - @patch('trustgraph.decoding.universal.processor.Producer') + @patch('trustgraph.base.librarian_client.Consumer') + @patch('trustgraph.base.librarian_client.Producer') @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) async def test_processor_initialization( self, mock_producer, mock_consumer @@ -169,8 +169,8 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase): assert consumer_specs[0].name == "input" assert consumer_specs[0].schema == Document - @patch('trustgraph.decoding.universal.processor.Consumer') - @patch('trustgraph.decoding.universal.processor.Producer') + @patch('trustgraph.base.librarian_client.Consumer') + @patch('trustgraph.base.librarian_client.Producer') @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) async def test_processor_custom_strategy( self, mock_producer, mock_consumer @@ -188,8 +188,8 @@ class 
TestUniversalProcessor(IsolatedAsyncioTestCase): assert processor.partition_strategy == "hi_res" assert processor.section_strategy_name == "heading" - @patch('trustgraph.decoding.universal.processor.Consumer') - @patch('trustgraph.decoding.universal.processor.Producer') + @patch('trustgraph.base.librarian_client.Consumer') + @patch('trustgraph.base.librarian_client.Producer') @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) async def test_group_by_page(self, mock_producer, mock_consumer): """Test page grouping of elements.""" @@ -214,8 +214,8 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase): assert result[1][0] == 2 assert len(result[1][1]) == 1 - @patch('trustgraph.decoding.universal.processor.Consumer') - @patch('trustgraph.decoding.universal.processor.Producer') + @patch('trustgraph.base.librarian_client.Consumer') + @patch('trustgraph.base.librarian_client.Producer') @patch('trustgraph.decoding.universal.processor.partition') @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) async def test_on_message_inline_non_page( @@ -255,7 +255,7 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase): }.get(name)) # Mock save_child_document and magic - processor.save_child_document = AsyncMock(return_value="mock-id") + processor.librarian.save_child_document = AsyncMock(return_value="mock-id") with patch('trustgraph.decoding.universal.processor.magic') as mock_magic: mock_magic.from_buffer.return_value = "text/markdown" @@ -271,8 +271,8 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase): assert call_args.document_id.startswith("urn:section:") assert call_args.text == b"" - @patch('trustgraph.decoding.universal.processor.Consumer') - @patch('trustgraph.decoding.universal.processor.Producer') + @patch('trustgraph.base.librarian_client.Consumer') + @patch('trustgraph.base.librarian_client.Producer') @patch('trustgraph.decoding.universal.processor.partition') 
@patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) async def test_on_message_page_based( @@ -310,7 +310,7 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase): "triples": mock_triples_flow, }.get(name)) - processor.save_child_document = AsyncMock(return_value="mock-id") + processor.librarian.save_child_document = AsyncMock(return_value="mock-id") with patch('trustgraph.decoding.universal.processor.magic') as mock_magic: mock_magic.from_buffer.return_value = "application/pdf" @@ -323,8 +323,8 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase): call_args = mock_output_flow.send.call_args_list[0][0][0] assert call_args.document_id.startswith("urn:page:") - @patch('trustgraph.decoding.universal.processor.Consumer') - @patch('trustgraph.decoding.universal.processor.Producer') + @patch('trustgraph.base.librarian_client.Consumer') + @patch('trustgraph.base.librarian_client.Producer') @patch('trustgraph.decoding.universal.processor.partition') @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) async def test_images_stored_not_emitted( @@ -361,7 +361,7 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase): "triples": mock_triples_flow, }.get(name)) - processor.save_child_document = AsyncMock(return_value="mock-id") + processor.librarian.save_child_document = AsyncMock(return_value="mock-id") with patch('trustgraph.decoding.universal.processor.magic') as mock_magic: mock_magic.from_buffer.return_value = "application/pdf" @@ -374,7 +374,7 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase): assert mock_triples_flow.send.call_count == 2 # save_child_document called twice (page + image) - assert processor.save_child_document.call_count == 2 + assert processor.librarian.save_child_document.call_count == 2 @patch('trustgraph.base.flow_processor.FlowProcessor.add_args') def test_add_args(self, mock_parent_add_args): diff --git a/tests/unit/test_pubsub/test_queue_naming.py 
b/tests/unit/test_pubsub/test_queue_naming.py index 1ee781d9..edd3dfca 100644 --- a/tests/unit/test_pubsub/test_queue_naming.py +++ b/tests/unit/test_pubsub/test_queue_naming.py @@ -109,6 +109,37 @@ class TestAddPubsubArgs: assert args.pubsub_backend == 'pulsar' +class TestAddPubsubArgsRabbitMQ: + + def test_rabbitmq_args_present(self): + parser = argparse.ArgumentParser() + add_pubsub_args(parser) + args = parser.parse_args([ + '--pubsub-backend', 'rabbitmq', + '--rabbitmq-host', 'myhost', + '--rabbitmq-port', '5673', + ]) + assert args.pubsub_backend == 'rabbitmq' + assert args.rabbitmq_host == 'myhost' + assert args.rabbitmq_port == 5673 + + def test_rabbitmq_defaults_container(self): + parser = argparse.ArgumentParser() + add_pubsub_args(parser) + args = parser.parse_args([]) + assert args.rabbitmq_host == 'rabbitmq' + assert args.rabbitmq_port == 5672 + assert args.rabbitmq_username == 'guest' + assert args.rabbitmq_password == 'guest' + assert args.rabbitmq_vhost == '/' + + def test_rabbitmq_standalone_defaults_to_localhost(self): + parser = argparse.ArgumentParser() + add_pubsub_args(parser, standalone=True) + args = parser.parse_args([]) + assert args.rabbitmq_host == 'localhost' + + class TestQueueDefinitions: """Verify the actual queue constants produce correct names.""" @@ -124,9 +155,9 @@ class TestQueueDefinitions: from trustgraph.schema.services.config import config_push_queue assert config_push_queue == 'state:tg:config' - def test_librarian_request_is_persistent(self): + def test_librarian_request(self): from trustgraph.schema.services.library import librarian_request_queue - assert librarian_request_queue.startswith('flow:') + assert librarian_request_queue == 'request:tg:librarian' def test_knowledge_request(self): from trustgraph.schema.knowledge.knowledge import knowledge_request_queue diff --git a/tests/unit/test_pubsub/test_rabbitmq_backend.py b/tests/unit/test_pubsub/test_rabbitmq_backend.py new file mode 100644 index 00000000..578db3b6 --- 
/dev/null +++ b/tests/unit/test_pubsub/test_rabbitmq_backend.py @@ -0,0 +1,107 @@ +""" +Unit tests for RabbitMQ backend — queue name mapping and factory dispatch. +Does not require a running RabbitMQ instance. +""" + +import pytest +import argparse + +pika = pytest.importorskip("pika", reason="pika not installed") + +from trustgraph.base.rabbitmq_backend import RabbitMQBackend +from trustgraph.base.pubsub import get_pubsub, add_pubsub_args + + +class TestRabbitMQMapQueueName: + + @pytest.fixture + def backend(self): + b = object.__new__(RabbitMQBackend) + return b + + def test_flow_is_durable(self, backend): + name, durable = backend.map_queue_name('flow:tg:text-completion-request') + assert durable is True + assert name == 'tg.flow.text-completion-request' + + def test_state_is_durable(self, backend): + name, durable = backend.map_queue_name('state:tg:config') + assert durable is True + assert name == 'tg.state.config' + + def test_request_is_not_durable(self, backend): + name, durable = backend.map_queue_name('request:tg:config') + assert durable is False + assert name == 'tg.request.config' + + def test_response_is_not_durable(self, backend): + name, durable = backend.map_queue_name('response:tg:librarian') + assert durable is False + assert name == 'tg.response.librarian' + + def test_custom_topicspace(self, backend): + name, durable = backend.map_queue_name('flow:prod:my-queue') + assert name == 'prod.flow.my-queue' + assert durable is True + + def test_no_colon_defaults_to_flow(self, backend): + name, durable = backend.map_queue_name('simple-queue') + assert name == 'tg.simple-queue' + assert durable is False + + def test_invalid_class_raises(self, backend): + with pytest.raises(ValueError, match="Invalid queue class"): + backend.map_queue_name('unknown:tg:topic') + + def test_flow_with_flow_suffix(self, backend): + """Queue names with flow suffix (e.g. 
:default) are preserved.""" + name, durable = backend.map_queue_name('request:tg:prompt:default') + assert name == 'tg.request.prompt:default' + + +class TestGetPubsubRabbitMQ: + + def test_factory_creates_rabbitmq_backend(self): + backend = get_pubsub(pubsub_backend='rabbitmq') + assert isinstance(backend, RabbitMQBackend) + + def test_factory_passes_config(self): + backend = get_pubsub( + pubsub_backend='rabbitmq', + rabbitmq_host='myhost', + rabbitmq_port=5673, + rabbitmq_username='user', + rabbitmq_password='pass', + rabbitmq_vhost='/test', + ) + assert isinstance(backend, RabbitMQBackend) + # Verify connection params were set + params = backend._connection_params + assert params.host == 'myhost' + assert params.port == 5673 + assert params.virtual_host == '/test' + + +class TestAddPubsubArgsRabbitMQ: + + def test_rabbitmq_args_present(self): + parser = argparse.ArgumentParser() + add_pubsub_args(parser) + args = parser.parse_args([ + '--pubsub-backend', 'rabbitmq', + '--rabbitmq-host', 'myhost', + '--rabbitmq-port', '5673', + ]) + assert args.pubsub_backend == 'rabbitmq' + assert args.rabbitmq_host == 'myhost' + assert args.rabbitmq_port == 5673 + + def test_rabbitmq_defaults_container(self): + parser = argparse.ArgumentParser() + add_pubsub_args(parser) + args = parser.parse_args([]) + assert args.rabbitmq_host == 'rabbitmq' + assert args.rabbitmq_port == 5672 + assert args.rabbitmq_username == 'guest' + assert args.rabbitmq_password == 'guest' + assert args.rabbitmq_vhost == '/' diff --git a/trustgraph-base/pyproject.toml b/trustgraph-base/pyproject.toml index 7d9f9219..b7b9757c 100644 --- a/trustgraph-base/pyproject.toml +++ b/trustgraph-base/pyproject.toml @@ -14,6 +14,7 @@ dependencies = [ "prometheus-client", "requests", "python-logging-loki", + "pika", ] classifiers = [ "Programming Language :: Python :: 3", diff --git a/trustgraph-base/trustgraph/api/library.py b/trustgraph-base/trustgraph/api/library.py index 396d64e0..c66598aa 100644 --- 
a/trustgraph-base/trustgraph/api/library.py +++ b/trustgraph-base/trustgraph/api/library.py @@ -22,8 +22,9 @@ logger = logging.getLogger(__name__) # Lower threshold provides progress feedback and resumability on slower connections CHUNKED_UPLOAD_THRESHOLD = 2 * 1024 * 1024 -# Default chunk size (5MB - S3 multipart minimum) -DEFAULT_CHUNK_SIZE = 5 * 1024 * 1024 +# Default chunk size (3MB - stays under broker message size limits +# after base64 encoding ~4MB) +DEFAULT_CHUNK_SIZE = 3 * 1024 * 1024 def to_value(x): diff --git a/trustgraph-base/trustgraph/base/__init__.py b/trustgraph-base/trustgraph/base/__init__.py index 5a454279..24b6c1f0 100644 --- a/trustgraph-base/trustgraph/base/__init__.py +++ b/trustgraph-base/trustgraph/base/__init__.py @@ -14,6 +14,7 @@ from . producer_spec import ProducerSpec from . subscriber_spec import SubscriberSpec from . request_response_spec import RequestResponseSpec from . llm_service import LlmService, LlmResult, LlmChunk +from . librarian_client import LibrarianClient from . chunking_service import ChunkingService from . embeddings_service import EmbeddingsService from . 
embeddings_client import EmbeddingsClientSpec diff --git a/trustgraph-base/trustgraph/base/async_processor.py b/trustgraph-base/trustgraph/base/async_processor.py index 94bab278..7f7dbdcd 100644 --- a/trustgraph-base/trustgraph/base/async_processor.py +++ b/trustgraph-base/trustgraph/base/async_processor.py @@ -68,11 +68,12 @@ class AsyncProcessor: processor = self.id, flow = None, name = "config", ) - # Subscribe to config queue + # Subscribe to config queue — exclusive so every processor + # gets its own copy of config pushes (broadcast pattern) self.config_sub_task = Consumer( taskgroup = self.taskgroup, - backend = self.pubsub_backend, # Changed from client to backend + backend = self.pubsub_backend, subscriber = config_subscriber_id, flow = None, @@ -83,9 +84,8 @@ class AsyncProcessor: metrics = config_consumer_metrics, - # This causes new subscriptions to view the entire history of - # configuration - start_of_messages = True + start_of_messages = True, + consumer_type = 'exclusive', ) self.running = True diff --git a/trustgraph-base/trustgraph/base/chunking_service.py b/trustgraph-base/trustgraph/base/chunking_service.py index 753378d4..d4bf4cd4 100644 --- a/trustgraph-base/trustgraph/base/chunking_service.py +++ b/trustgraph-base/trustgraph/base/chunking_service.py @@ -7,23 +7,14 @@ fetching large document content. 
import asyncio import base64 import logging -import uuid from .flow_processor import FlowProcessor from .parameter_spec import ParameterSpec -from .consumer import Consumer -from .producer import Producer -from .metrics import ConsumerMetrics, ProducerMetrics - -from ..schema import LibrarianRequest, LibrarianResponse, DocumentMetadata -from ..schema import librarian_request_queue, librarian_response_queue +from .librarian_client import LibrarianClient # Module logger logger = logging.getLogger(__name__) -default_librarian_request_queue = librarian_request_queue -default_librarian_response_queue = librarian_response_queue - class ChunkingService(FlowProcessor): """Base service for chunking processors with parameter specification support""" @@ -44,155 +35,18 @@ class ChunkingService(FlowProcessor): ParameterSpec(name="chunk-overlap") ) - # Librarian client for fetching document content - librarian_request_q = params.get( - "librarian_request_queue", default_librarian_request_queue - ) - librarian_response_q = params.get( - "librarian_response_queue", default_librarian_response_queue - ) - - librarian_request_metrics = ProducerMetrics( - processor=id, flow=None, name="librarian-request" - ) - - self.librarian_request_producer = Producer( + # Librarian client + self.librarian = LibrarianClient( + id=id, backend=self.pubsub, - topic=librarian_request_q, - schema=LibrarianRequest, - metrics=librarian_request_metrics, - ) - - librarian_response_metrics = ConsumerMetrics( - processor=id, flow=None, name="librarian-response" - ) - - self.librarian_response_consumer = Consumer( taskgroup=self.taskgroup, - backend=self.pubsub, - flow=None, - topic=librarian_response_q, - subscriber=f"{id}-librarian", - schema=LibrarianResponse, - handler=self.on_librarian_response, - metrics=librarian_response_metrics, ) - # Pending librarian requests: request_id -> asyncio.Future - self.pending_requests = {} - logger.debug("ChunkingService initialized with parameter specifications") async 
def start(self): await super(ChunkingService, self).start() - await self.librarian_request_producer.start() - await self.librarian_response_consumer.start() - - async def on_librarian_response(self, msg, consumer, flow): - """Handle responses from the librarian service.""" - response = msg.value() - request_id = msg.properties().get("id") - - if request_id and request_id in self.pending_requests: - future = self.pending_requests.pop(request_id) - future.set_result(response) - - async def fetch_document_content(self, document_id, user, timeout=120): - """ - Fetch document content from librarian via Pulsar. - """ - request_id = str(uuid.uuid4()) - - request = LibrarianRequest( - operation="get-document-content", - document_id=document_id, - user=user, - ) - - # Create future for response - future = asyncio.get_event_loop().create_future() - self.pending_requests[request_id] = future - - try: - # Send request - await self.librarian_request_producer.send( - request, properties={"id": request_id} - ) - - # Wait for response - response = await asyncio.wait_for(future, timeout=timeout) - - if response.error: - raise RuntimeError( - f"Librarian error: {response.error.type}: {response.error.message}" - ) - - return response.content - - except asyncio.TimeoutError: - self.pending_requests.pop(request_id, None) - raise RuntimeError(f"Timeout fetching document {document_id}") - - async def save_child_document(self, doc_id, parent_id, user, content, - document_type="chunk", title=None, timeout=120): - """ - Save a child document (chunk) to the librarian. - - Args: - doc_id: ID for the new child document - parent_id: ID of the parent document - user: User ID - content: Document content (bytes or str) - document_type: Type of document ("chunk", etc.) 
- title: Optional title - timeout: Request timeout in seconds - - Returns: - The document ID on success - """ - request_id = str(uuid.uuid4()) - - if isinstance(content, str): - content = content.encode("utf-8") - - doc_metadata = DocumentMetadata( - id=doc_id, - user=user, - kind="text/plain", - title=title or doc_id, - parent_id=parent_id, - document_type=document_type, - ) - - request = LibrarianRequest( - operation="add-child-document", - document_metadata=doc_metadata, - content=base64.b64encode(content).decode("utf-8"), - ) - - # Create future for response - future = asyncio.get_event_loop().create_future() - self.pending_requests[request_id] = future - - try: - # Send request - await self.librarian_request_producer.send( - request, properties={"id": request_id} - ) - - # Wait for response - response = await asyncio.wait_for(future, timeout=timeout) - - if response.error: - raise RuntimeError( - f"Librarian error saving chunk: {response.error.type}: {response.error.message}" - ) - - return doc_id - - except asyncio.TimeoutError: - self.pending_requests.pop(request_id, None) - raise RuntimeError(f"Timeout saving chunk {doc_id}") + await self.librarian.start() async def get_document_text(self, doc): """ @@ -206,14 +60,10 @@ class ChunkingService(FlowProcessor): """ if doc.document_id and not doc.text: logger.info(f"Fetching document {doc.document_id} from librarian...") - content = await self.fetch_document_content( + text = await self.librarian.fetch_document_text( document_id=doc.document_id, user=doc.metadata.user, ) - # Content is base64 encoded - if isinstance(content, str): - content = content.encode('utf-8') - text = base64.b64decode(content).decode("utf-8") logger.info(f"Fetched {len(text)} characters from librarian") return text else: @@ -224,41 +74,31 @@ class ChunkingService(FlowProcessor): Extract chunk parameters from flow and return effective values Args: - msg: The message containing the document to chunk - consumer: The consumer spec - flow: The 
flow context - default_chunk_size: Default chunk size from processor config - default_chunk_overlap: Default chunk overlap from processor config + msg: The message being processed + consumer: The consumer instance + flow: The flow object containing parameters + default_chunk_size: Default chunk size if not configured + default_chunk_overlap: Default chunk overlap if not configured Returns: - tuple: (chunk_size, chunk_overlap) - effective values to use + tuple: (chunk_size, chunk_overlap) effective values """ - # Extract parameters from flow (flow-configurable parameters) - chunk_size = flow("chunk-size") - chunk_overlap = flow("chunk-overlap") - # Use provided values or fall back to defaults - effective_chunk_size = chunk_size if chunk_size is not None else default_chunk_size - effective_chunk_overlap = chunk_overlap if chunk_overlap is not None else default_chunk_overlap + chunk_size = default_chunk_size + chunk_overlap = default_chunk_overlap - logger.debug(f"Using chunk-size: {effective_chunk_size}") - logger.debug(f"Using chunk-overlap: {effective_chunk_overlap}") + try: + cs = flow.parameters.get("chunk-size") + if cs is not None: + chunk_size = int(cs) + except Exception as e: + logger.warning(f"Could not parse chunk-size parameter: {e}") - return effective_chunk_size, effective_chunk_overlap + try: + co = flow.parameters.get("chunk-overlap") + if co is not None: + chunk_overlap = int(co) + except Exception as e: + logger.warning(f"Could not parse chunk-overlap parameter: {e}") - @staticmethod - def add_args(parser): - """Add chunking service arguments to parser""" - FlowProcessor.add_args(parser) - - parser.add_argument( - '--librarian-request-queue', - default=default_librarian_request_queue, - help=f'Librarian request queue (default: {default_librarian_request_queue})', - ) - - parser.add_argument( - '--librarian-response-queue', - default=default_librarian_response_queue, - help=f'Librarian response queue (default: {default_librarian_response_queue})', - 
) \ No newline at end of file + return chunk_size, chunk_overlap diff --git a/trustgraph-base/trustgraph/base/consumer.py b/trustgraph-base/trustgraph/base/consumer.py index 2a220312..9ae35d49 100644 --- a/trustgraph-base/trustgraph/base/consumer.py +++ b/trustgraph-base/trustgraph/base/consumer.py @@ -32,6 +32,7 @@ class Consumer: rate_limit_retry_time = 10, rate_limit_timeout = 7200, reconnect_time = 5, concurrency = 1, # Number of concurrent requests to handle + consumer_type = 'shared', ): self.taskgroup = taskgroup @@ -42,6 +43,8 @@ class Consumer: self.schema = schema self.handler = handler + self.consumer_type = consumer_type + self.rate_limit_retry_time = rate_limit_retry_time self.rate_limit_timeout = rate_limit_timeout @@ -93,33 +96,11 @@ class Consumer: if self.metrics: self.metrics.state("stopped") - try: - - logger.info(f"Subscribing to topic: {self.topic}") - - # Determine initial position - if self.start_of_messages: - initial_pos = 'earliest' - else: - initial_pos = 'latest' - - # Create consumer via backend - self.consumer = await asyncio.to_thread( - self.backend.create_consumer, - topic = self.topic, - subscription = self.subscriber, - schema = self.schema, - initial_position = initial_pos, - consumer_type = 'shared', - ) - - except Exception as e: - - logger.error(f"Consumer subscription exception: {e}", exc_info=True) - await asyncio.sleep(self.reconnect_time) - continue - - logger.info(f"Successfully subscribed to topic: {self.topic}") + # Determine initial position + if self.start_of_messages: + initial_pos = 'earliest' + else: + initial_pos = 'latest' if self.metrics: self.metrics.state("running") @@ -128,14 +109,30 @@ class Consumer: logger.info(f"Starting {self.concurrency} receiver threads") - async with asyncio.TaskGroup() as tg: - - tasks = [] - - for i in range(0, self.concurrency): - tasks.append( - tg.create_task(self.consume_from_queue()) + # Create one backend consumer per concurrent task. 
+ # Each gets its own connection — required for backends + # like RabbitMQ where connections are not thread-safe. + consumers = [] + for i in range(self.concurrency): + try: + logger.info(f"Subscribing to topic: {self.topic} (worker {i})") + c = await asyncio.to_thread( + self.backend.create_consumer, + topic = self.topic, + subscription = self.subscriber, + schema = self.schema, + initial_position = initial_pos, + consumer_type = self.consumer_type, ) + consumers.append(c) + logger.info(f"Successfully subscribed to topic: {self.topic} (worker {i})") + except Exception as e: + logger.error(f"Consumer subscription exception (worker {i}): {e}", exc_info=True) + raise + + async with asyncio.TaskGroup() as tg: + for c in consumers: + tg.create_task(self.consume_from_queue(c)) if self.metrics: self.metrics.state("stopped") @@ -143,23 +140,31 @@ class Consumer: except Exception as e: logger.error(f"Consumer loop exception: {e}", exc_info=True) - self.consumer.unsubscribe() - self.consumer.close() - self.consumer = None + for c in consumers: + try: + c.unsubscribe() + c.close() + except Exception: + pass + consumers = [] await asyncio.sleep(self.reconnect_time) continue - if self.consumer: - self.consumer.unsubscribe() - self.consumer.close() + finally: + for c in consumers: + try: + c.unsubscribe() + c.close() + except Exception: + pass - async def consume_from_queue(self): + async def consume_from_queue(self, consumer): while self.running: try: msg = await asyncio.to_thread( - self.consumer.receive, + consumer.receive, timeout_millis=2000 ) except Exception as e: @@ -168,9 +173,9 @@ class Consumer: continue raise e - await self.handle_one_from_queue(msg) + await self.handle_one_from_queue(msg, consumer) - async def handle_one_from_queue(self, msg): + async def handle_one_from_queue(self, msg, consumer): expiry = time.time() + self.rate_limit_timeout @@ -183,7 +188,7 @@ class Consumer: # Message failed to be processed, this causes it to # be retried - 
self.consumer.negative_acknowledge(msg) + consumer.negative_acknowledge(msg) if self.metrics: self.metrics.process("error") @@ -206,7 +211,7 @@ class Consumer: logger.debug("Message processed successfully") # Acknowledge successful processing of the message - self.consumer.acknowledge(msg) + consumer.acknowledge(msg) if self.metrics: self.metrics.process("success") @@ -233,7 +238,7 @@ class Consumer: # Message failed to be processed, this causes it to # be retried - self.consumer.negative_acknowledge(msg) + consumer.negative_acknowledge(msg) if self.metrics: self.metrics.process("error") diff --git a/trustgraph-base/trustgraph/base/librarian_client.py b/trustgraph-base/trustgraph/base/librarian_client.py new file mode 100644 index 00000000..6191cff8 --- /dev/null +++ b/trustgraph-base/trustgraph/base/librarian_client.py @@ -0,0 +1,246 @@ +""" +Shared librarian client for services that need to communicate +with the librarian via pub/sub. + +Provides request-response and streaming operations over the message +broker, with proper support for large documents via stream-document. 
+ +Usage: + self.librarian = LibrarianClient( + id=id, backend=self.pubsub, taskgroup=self.taskgroup, **params + ) + await self.librarian.start() + content = await self.librarian.fetch_document_content(doc_id, user) +""" + +import asyncio +import base64 +import logging +import uuid + +from .consumer import Consumer +from .producer import Producer +from .metrics import ConsumerMetrics, ProducerMetrics + +from ..schema import LibrarianRequest, LibrarianResponse, DocumentMetadata +from ..schema import librarian_request_queue, librarian_response_queue + +logger = logging.getLogger(__name__) + + +class LibrarianClient: + """Client for librarian request-response over the message broker.""" + + def __init__(self, id, backend, taskgroup, **params): + + librarian_request_q = params.get( + "librarian_request_queue", librarian_request_queue, + ) + librarian_response_q = params.get( + "librarian_response_queue", librarian_response_queue, + ) + + librarian_request_metrics = ProducerMetrics( + processor=id, flow=None, name="librarian-request", + ) + + self._producer = Producer( + backend=backend, + topic=librarian_request_q, + schema=LibrarianRequest, + metrics=librarian_request_metrics, + ) + + librarian_response_metrics = ConsumerMetrics( + processor=id, flow=None, name="librarian-response", + ) + + self._consumer = Consumer( + taskgroup=taskgroup, + backend=backend, + flow=None, + topic=librarian_response_q, + subscriber=f"{id}-librarian", + schema=LibrarianResponse, + handler=self._on_response, + metrics=librarian_response_metrics, + consumer_type='exclusive', + ) + + # Single-response requests: request_id -> asyncio.Future + self._pending = {} + # Streaming requests: request_id -> asyncio.Queue + self._streams = {} + + async def start(self): + """Start the librarian producer and consumer.""" + await self._producer.start() + await self._consumer.start() + + async def _on_response(self, msg, consumer, flow): + """Route librarian responses to the right waiter.""" + response = 
msg.value() + request_id = msg.properties().get("id") + + if not request_id: + return + + if request_id in self._pending: + future = self._pending.pop(request_id) + future.set_result(response) + elif request_id in self._streams: + await self._streams[request_id].put(response) + + async def request(self, request, timeout=120): + """Send a request to the librarian and wait for a single response.""" + request_id = str(uuid.uuid4()) + + future = asyncio.get_event_loop().create_future() + self._pending[request_id] = future + + try: + await self._producer.send( + request, properties={"id": request_id}, + ) + response = await asyncio.wait_for(future, timeout=timeout) + + if response.error: + raise RuntimeError( + f"Librarian error: {response.error.type}: " + f"{response.error.message}" + ) + + return response + + except asyncio.TimeoutError: + self._pending.pop(request_id, None) + raise RuntimeError("Timeout waiting for librarian response") + + async def stream(self, request, timeout=120): + """Send a request and collect streamed response chunks.""" + request_id = str(uuid.uuid4()) + + q = asyncio.Queue() + self._streams[request_id] = q + + try: + await self._producer.send( + request, properties={"id": request_id}, + ) + + chunks = [] + while True: + response = await asyncio.wait_for(q.get(), timeout=timeout) + + if response.error: + raise RuntimeError( + f"Librarian error: {response.error.type}: " + f"{response.error.message}" + ) + + chunks.append(response) + + if response.is_final: + break + + return chunks + + except asyncio.TimeoutError: + self._streams.pop(request_id, None) + raise RuntimeError("Timeout waiting for librarian stream") + finally: + self._streams.pop(request_id, None) + + async def fetch_document_content(self, document_id, user, timeout=120): + """Fetch document content using streaming. + + Returns base64-encoded content. Caller is responsible for decoding. 
+ """ + req = LibrarianRequest( + operation="stream-document", + document_id=document_id, + user=user, + ) + chunks = await self.stream(req, timeout=timeout) + + # Decode each chunk's base64 to raw bytes, concatenate, + # re-encode for the caller. + raw = b"" + for chunk in chunks: + if chunk.content: + if isinstance(chunk.content, bytes): + raw += base64.b64decode(chunk.content) + else: + raw += base64.b64decode( + chunk.content.encode("utf-8") + ) + + return base64.b64encode(raw) + + async def fetch_document_text(self, document_id, user, timeout=120): + """Fetch document content and decode as UTF-8 text.""" + content = await self.fetch_document_content( + document_id, user, timeout=timeout, + ) + return base64.b64decode(content).decode("utf-8") + + async def fetch_document_metadata(self, document_id, user, timeout=120): + """Fetch document metadata from the librarian.""" + req = LibrarianRequest( + operation="get-document-metadata", + document_id=document_id, + user=user, + ) + response = await self.request(req, timeout=timeout) + return response.document_metadata + + async def save_child_document(self, doc_id, parent_id, user, content, + document_type="chunk", title=None, + kind="text/plain", timeout=120): + """Save a child document to the librarian.""" + if isinstance(content, str): + content = content.encode("utf-8") + + doc_metadata = DocumentMetadata( + id=doc_id, + user=user, + kind=kind, + title=title or doc_id, + parent_id=parent_id, + document_type=document_type, + ) + + req = LibrarianRequest( + operation="add-child-document", + document_metadata=doc_metadata, + content=base64.b64encode(content).decode("utf-8"), + ) + + await self.request(req, timeout=timeout) + return doc_id + + async def save_document(self, doc_id, user, content, title=None, + document_type="answer", kind="text/plain", + timeout=120): + """Save a document to the librarian.""" + if isinstance(content, str): + content = content.encode("utf-8") + + doc_metadata = DocumentMetadata( + 
id=doc_id, + user=user, + kind=kind, + title=title or doc_id, + document_type=document_type, + ) + + req = LibrarianRequest( + operation="add-document", + document_id=doc_id, + document_metadata=doc_metadata, + content=base64.b64encode(content).decode("utf-8"), + user=user, + ) + + await self.request(req, timeout=timeout) + return doc_id diff --git a/trustgraph-base/trustgraph/base/pubsub.py b/trustgraph-base/trustgraph/base/pubsub.py index 04734f28..8fe532d8 100644 --- a/trustgraph-base/trustgraph/base/pubsub.py +++ b/trustgraph-base/trustgraph/base/pubsub.py @@ -8,6 +8,12 @@ logger = logging.getLogger(__name__) DEFAULT_PULSAR_HOST = os.getenv("PULSAR_HOST", 'pulsar://pulsar:6650') DEFAULT_PULSAR_API_KEY = os.getenv("PULSAR_API_KEY", None) +DEFAULT_RABBITMQ_HOST = os.getenv("RABBITMQ_HOST", 'rabbitmq') +DEFAULT_RABBITMQ_PORT = int(os.getenv("RABBITMQ_PORT", '5672')) +DEFAULT_RABBITMQ_USERNAME = os.getenv("RABBITMQ_USERNAME", 'guest') +DEFAULT_RABBITMQ_PASSWORD = os.getenv("RABBITMQ_PASSWORD", 'guest') +DEFAULT_RABBITMQ_VHOST = os.getenv("RABBITMQ_VHOST", '/') + def get_pubsub(**config): """ @@ -29,6 +35,15 @@ def get_pubsub(**config): api_key=config.get('pulsar_api_key', DEFAULT_PULSAR_API_KEY), listener=config.get('pulsar_listener'), ) + elif backend_type == 'rabbitmq': + from .rabbitmq_backend import RabbitMQBackend + return RabbitMQBackend( + host=config.get('rabbitmq_host', DEFAULT_RABBITMQ_HOST), + port=config.get('rabbitmq_port', DEFAULT_RABBITMQ_PORT), + username=config.get('rabbitmq_username', DEFAULT_RABBITMQ_USERNAME), + password=config.get('rabbitmq_password', DEFAULT_RABBITMQ_PASSWORD), + vhost=config.get('rabbitmq_vhost', DEFAULT_RABBITMQ_VHOST), + ) else: raise ValueError(f"Unknown pub/sub backend: {backend_type}") @@ -44,8 +59,9 @@ def add_pubsub_args(parser, standalone=False): standalone: If True, default host is localhost (for CLI tools that run outside containers) """ - host = STANDALONE_PULSAR_HOST if standalone else DEFAULT_PULSAR_HOST - 
listener_default = 'localhost' if standalone else None + pulsar_host = STANDALONE_PULSAR_HOST if standalone else DEFAULT_PULSAR_HOST + pulsar_listener = 'localhost' if standalone else None + rabbitmq_host = 'localhost' if standalone else DEFAULT_RABBITMQ_HOST parser.add_argument( '--pubsub-backend', @@ -53,10 +69,11 @@ def add_pubsub_args(parser, standalone=False): help='Pub/sub backend (default: pulsar, env: PUBSUB_BACKEND)', ) + # Pulsar options parser.add_argument( '-p', '--pulsar-host', - default=host, - help=f'Pulsar host (default: {host})', + default=pulsar_host, + help=f'Pulsar host (default: {pulsar_host})', ) parser.add_argument( @@ -67,6 +84,38 @@ def add_pubsub_args(parser, standalone=False): parser.add_argument( '--pulsar-listener', - default=listener_default, - help=f'Pulsar listener (default: {listener_default or "none"})', + default=pulsar_listener, + help=f'Pulsar listener (default: {pulsar_listener or "none"})', + ) + + # RabbitMQ options + parser.add_argument( + '--rabbitmq-host', + default=rabbitmq_host, + help=f'RabbitMQ host (default: {rabbitmq_host})', + ) + + parser.add_argument( + '--rabbitmq-port', + type=int, + default=DEFAULT_RABBITMQ_PORT, + help=f'RabbitMQ port (default: {DEFAULT_RABBITMQ_PORT})', + ) + + parser.add_argument( + '--rabbitmq-username', + default=DEFAULT_RABBITMQ_USERNAME, + help='RabbitMQ username', + ) + + parser.add_argument( + '--rabbitmq-password', + default=DEFAULT_RABBITMQ_PASSWORD, + help='RabbitMQ password', + ) + + parser.add_argument( + '--rabbitmq-vhost', + default=DEFAULT_RABBITMQ_VHOST, + help=f'RabbitMQ vhost (default: {DEFAULT_RABBITMQ_VHOST})', ) diff --git a/trustgraph-base/trustgraph/base/pulsar_backend.py b/trustgraph-base/trustgraph/base/pulsar_backend.py index 677f2527..9480243e 100644 --- a/trustgraph-base/trustgraph/base/pulsar_backend.py +++ b/trustgraph-base/trustgraph/base/pulsar_backend.py @@ -9,122 +9,14 @@ import pulsar import _pulsar import json import logging -import base64 -import types 
-from dataclasses import asdict, is_dataclass -from typing import Any, get_type_hints +from typing import Any from .backend import PubSubBackend, BackendProducer, BackendConsumer, Message +from .serialization import dataclass_to_dict, dict_to_dataclass logger = logging.getLogger(__name__) -def dataclass_to_dict(obj: Any) -> dict: - """ - Recursively convert a dataclass to a dictionary, handling None values and bytes. - - None values are excluded from the dictionary (not serialized). - Bytes values are decoded as UTF-8 strings for JSON serialization (matching Pulsar behavior). - Handles nested dataclasses, lists, and dictionaries recursively. - """ - if obj is None: - return None - - # Handle bytes - decode to UTF-8 for JSON serialization - if isinstance(obj, bytes): - return obj.decode('utf-8') - - # Handle dataclass - convert to dict then recursively process all values - if is_dataclass(obj): - result = {} - for key, value in asdict(obj).items(): - result[key] = dataclass_to_dict(value) if value is not None else None - return result - - # Handle list - recursively process all items - if isinstance(obj, list): - return [dataclass_to_dict(item) for item in obj] - - # Handle dict - recursively process all values - if isinstance(obj, dict): - return {k: dataclass_to_dict(v) for k, v in obj.items()} - - # Return primitive types as-is - return obj - - -def dict_to_dataclass(data: dict, cls: type) -> Any: - """ - Convert a dictionary back to a dataclass instance. - - Handles nested dataclasses and missing fields. - Uses get_type_hints() to resolve forward references (string annotations). 
- """ - if data is None: - return None - - if not is_dataclass(cls): - return data - - # Get field types from the dataclass, resolving forward references - # get_type_hints() evaluates string annotations like "Triple | None" - try: - field_types = get_type_hints(cls) - except Exception: - # Fallback if get_type_hints fails (shouldn't happen normally) - field_types = {f.name: f.type for f in cls.__dataclass_fields__.values()} - kwargs = {} - - for key, value in data.items(): - if key in field_types: - field_type = field_types[key] - - # Handle modern union types (X | Y) - if isinstance(field_type, types.UnionType): - # Check if it's Optional (X | None) - if type(None) in field_type.__args__: - # Get the non-None type - actual_type = next((t for t in field_type.__args__ if t is not type(None)), None) - if actual_type and is_dataclass(actual_type) and isinstance(value, dict): - kwargs[key] = dict_to_dataclass(value, actual_type) - else: - kwargs[key] = value - else: - kwargs[key] = value - # Check if this is a generic type (list, dict, etc.) 
- elif hasattr(field_type, '__origin__'): - # Handle list[T] - if field_type.__origin__ == list: - item_type = field_type.__args__[0] if field_type.__args__ else None - if item_type and is_dataclass(item_type) and isinstance(value, list): - kwargs[key] = [ - dict_to_dataclass(item, item_type) if isinstance(item, dict) else item - for item in value - ] - else: - kwargs[key] = value - # Handle old-style Optional[T] (which is Union[T, None]) - elif hasattr(field_type, '__args__') and type(None) in field_type.__args__: - # Get the non-None type from Union - actual_type = next((t for t in field_type.__args__ if t is not type(None)), None) - if actual_type and is_dataclass(actual_type) and isinstance(value, dict): - kwargs[key] = dict_to_dataclass(value, actual_type) - else: - kwargs[key] = value - else: - kwargs[key] = value - # Handle direct dataclass fields - elif is_dataclass(field_type) and isinstance(value, dict): - kwargs[key] = dict_to_dataclass(value, field_type) - # Handle bytes fields (UTF-8 encoded strings from JSON) - elif field_type == bytes and isinstance(value, str): - kwargs[key] = value.encode('utf-8') - else: - kwargs[key] = value - - return cls(**kwargs) - - class PulsarMessage: """Wrapper for Pulsar messages to match Message protocol.""" diff --git a/trustgraph-base/trustgraph/base/rabbitmq_backend.py b/trustgraph-base/trustgraph/base/rabbitmq_backend.py new file mode 100644 index 00000000..a80efbaf --- /dev/null +++ b/trustgraph-base/trustgraph/base/rabbitmq_backend.py @@ -0,0 +1,390 @@ +""" +RabbitMQ backend implementation for pub/sub abstraction. + +Uses a single topic exchange per topicspace. The logical queue name +becomes the routing key. Consumer behavior is determined by the +subscription name: + +- Same subscription + same topic = shared queue (competing consumers) +- Different subscriptions = separate queues (broadcast / fan-out) + +This mirrors Pulsar's subscription model using idiomatic RabbitMQ. 
+ +Architecture: + Producer --> [tg exchange] --routing key--> [named queue] --> Consumer + --routing key--> [named queue] --> Consumer + --routing key--> [exclusive q] --> Subscriber + +Uses basic_consume (push) instead of basic_get (polling) for +efficient message delivery. +""" + +import json +import time +import logging +import queue +import threading +import pika +from typing import Any + +from .backend import PubSubBackend, BackendProducer, BackendConsumer, Message +from .serialization import dataclass_to_dict, dict_to_dataclass + +logger = logging.getLogger(__name__) + + +class RabbitMQMessage: + """Wrapper for RabbitMQ messages to match Message protocol.""" + + def __init__(self, method, properties, body, schema_cls): + self._method = method + self._properties = properties + self._body = body + self._schema_cls = schema_cls + self._value = None + + def value(self) -> Any: + """Deserialize and return the message value as a dataclass.""" + if self._value is None: + data_dict = json.loads(self._body.decode('utf-8')) + self._value = dict_to_dataclass(data_dict, self._schema_cls) + return self._value + + def properties(self) -> dict: + """Return message properties from AMQP headers.""" + headers = self._properties.headers or {} + return dict(headers) + + +class RabbitMQBackendProducer: + """Publishes messages to a topic exchange with a routing key. + + Uses thread-local connections so each thread gets its own + connection/channel. This avoids wire corruption from concurrent + threads writing to the same socket (pika is not thread-safe). 
+ """ + + def __init__(self, connection_params, exchange_name, routing_key, + durable): + self._connection_params = connection_params + self._exchange_name = exchange_name + self._routing_key = routing_key + self._durable = durable + self._local = threading.local() + + def _get_channel(self): + """Get or create a thread-local connection and channel.""" + conn = getattr(self._local, 'connection', None) + chan = getattr(self._local, 'channel', None) + + if conn is None or not conn.is_open or chan is None or not chan.is_open: + # Close stale connection if any + if conn is not None: + try: + conn.close() + except Exception: + pass + + conn = pika.BlockingConnection(self._connection_params) + chan = conn.channel() + chan.exchange_declare( + exchange=self._exchange_name, + exchange_type='topic', + durable=True, + ) + self._local.connection = conn + self._local.channel = chan + + return chan + + def send(self, message: Any, properties: dict = {}) -> None: + data_dict = dataclass_to_dict(message) + json_data = json.dumps(data_dict) + + amqp_properties = pika.BasicProperties( + delivery_mode=2 if self._durable else 1, + content_type='application/json', + headers=properties if properties else None, + ) + + for attempt in range(2): + try: + channel = self._get_channel() + channel.basic_publish( + exchange=self._exchange_name, + routing_key=self._routing_key, + body=json_data.encode('utf-8'), + properties=amqp_properties, + ) + return + except Exception as e: + logger.warning( + f"RabbitMQ send failed (attempt {attempt + 1}): {e}" + ) + # Force reconnect on next attempt + self._local.connection = None + self._local.channel = None + if attempt == 1: + raise + + def flush(self) -> None: + pass + + def close(self) -> None: + """Close the thread-local connection if any.""" + conn = getattr(self._local, 'connection', None) + if conn is not None: + try: + conn.close() + except Exception: + pass + self._local.connection = None + self._local.channel = None + + +class 
RabbitMQBackendConsumer: + """Consumes from a queue bound to a topic exchange. + + Uses basic_consume (push model) with messages delivered to an + internal thread-safe queue. process_data_events() drives both + message delivery and heartbeat processing. + """ + + def __init__(self, connection_params, exchange_name, routing_key, + queue_name, schema_cls, durable, exclusive=False, + auto_delete=False): + self._connection_params = connection_params + self._exchange_name = exchange_name + self._routing_key = routing_key + self._queue_name = queue_name + self._schema_cls = schema_cls + self._durable = durable + self._exclusive = exclusive + self._auto_delete = auto_delete + self._connection = None + self._channel = None + self._consumer_tag = None + self._incoming = queue.Queue() + + def _connect(self): + self._connection = pika.BlockingConnection(self._connection_params) + self._channel = self._connection.channel() + + # Declare the topic exchange + self._channel.exchange_declare( + exchange=self._exchange_name, + exchange_type='topic', + durable=True, + ) + + # Declare the queue — anonymous if exclusive + result = self._channel.queue_declare( + queue=self._queue_name, + durable=self._durable, + exclusive=self._exclusive, + auto_delete=self._auto_delete, + ) + # Capture actual name (important for anonymous queues where name='') + self._queue_name = result.method.queue + + self._channel.queue_bind( + queue=self._queue_name, + exchange=self._exchange_name, + routing_key=self._routing_key, + ) + + self._channel.basic_qos(prefetch_count=1) + + # Register push-based consumer + self._consumer_tag = self._channel.basic_consume( + queue=self._queue_name, + on_message_callback=self._on_message, + auto_ack=False, + ) + + def _on_message(self, channel, method, properties, body): + """Callback invoked by pika when a message arrives.""" + self._incoming.put((method, properties, body)) + + def _is_alive(self): + return ( + self._connection is not None + and self._connection.is_open 
+ and self._channel is not None + and self._channel.is_open + ) + + def receive(self, timeout_millis: int = 2000) -> Message: + """Receive a message. Raises TimeoutError if none available.""" + if not self._is_alive(): + self._connect() + + timeout_seconds = timeout_millis / 1000.0 + deadline = time.monotonic() + timeout_seconds + + while time.monotonic() < deadline: + # Check if a message was already delivered + try: + method, properties, body = self._incoming.get_nowait() + return RabbitMQMessage( + method, properties, body, self._schema_cls, + ) + except queue.Empty: + pass + + # Drive pika's I/O — delivers messages and processes heartbeats + remaining = deadline - time.monotonic() + if remaining > 0: + self._connection.process_data_events( + time_limit=min(0.1, remaining), + ) + + raise TimeoutError("No message received within timeout") + + def acknowledge(self, message: Message) -> None: + if isinstance(message, RabbitMQMessage) and message._method: + self._channel.basic_ack( + delivery_tag=message._method.delivery_tag, + ) + + def negative_acknowledge(self, message: Message) -> None: + if isinstance(message, RabbitMQMessage) and message._method: + self._channel.basic_nack( + delivery_tag=message._method.delivery_tag, + requeue=True, + ) + + def unsubscribe(self) -> None: + if self._consumer_tag and self._channel and self._channel.is_open: + try: + self._channel.basic_cancel(self._consumer_tag) + except Exception: + pass + self._consumer_tag = None + + def close(self) -> None: + self.unsubscribe() + try: + if self._channel and self._channel.is_open: + self._channel.close() + except Exception: + pass + try: + if self._connection and self._connection.is_open: + self._connection.close() + except Exception: + pass + self._channel = None + self._connection = None + + +class RabbitMQBackend: + """RabbitMQ pub/sub backend using a topic exchange per topicspace.""" + + def __init__(self, host='localhost', port=5672, username='guest', + password='guest', vhost='/'): + 
self._connection_params = pika.ConnectionParameters( + host=host, + port=port, + virtual_host=vhost, + credentials=pika.PlainCredentials(username, password), + ) + logger.info(f"RabbitMQ backend: {host}:{port} vhost={vhost}") + + def _parse_queue_id(self, queue_id: str) -> tuple[str, str, str, bool]: + """ + Parse queue identifier into exchange, routing key, and durability. + + Format: class:topicspace:topic + Returns: (exchange_name, routing_key, class, durable) + """ + if ':' not in queue_id: + return 'tg', queue_id, 'flow', False + + parts = queue_id.split(':', 2) + if len(parts) != 3: + raise ValueError( + f"Invalid queue format: {queue_id}, " + f"expected class:topicspace:topic" + ) + + cls, topicspace, topic = parts + + if cls in ('flow', 'state'): + durable = True + elif cls in ('request', 'response'): + durable = False + else: + raise ValueError( + f"Invalid queue class: {cls}, " + f"expected flow, request, response, or state" + ) + + # Exchange per topicspace, routing key includes class + exchange_name = topicspace + routing_key = f"{cls}.{topic}" + + return exchange_name, routing_key, cls, durable + + # Keep map_queue_name for backward compatibility with tests + def map_queue_name(self, queue_id: str) -> tuple[str, bool]: + exchange, routing_key, cls, durable = self._parse_queue_id(queue_id) + return f"{exchange}.{routing_key}", durable + + def create_producer(self, topic: str, schema: type, + **options) -> BackendProducer: + exchange, routing_key, cls, durable = self._parse_queue_id(topic) + logger.debug( + f"Creating producer: exchange={exchange}, " + f"routing_key={routing_key}" + ) + return RabbitMQBackendProducer( + self._connection_params, exchange, routing_key, durable, + ) + + def create_consumer(self, topic: str, subscription: str, schema: type, + initial_position: str = 'latest', + consumer_type: str = 'shared', + **options) -> BackendConsumer: + """Create a consumer with a queue bound to the topic exchange. 
+ + consumer_type='shared': Named durable queue. Multiple consumers + with the same subscription compete (round-robin). + consumer_type='exclusive': Anonymous ephemeral queue. Each + consumer gets its own copy of every message (broadcast). + """ + exchange, routing_key, cls, durable = self._parse_queue_id(topic) + + if consumer_type == 'exclusive' and cls == 'state': + # State broadcast: named durable queue per subscriber. + # Retains messages so late-starting processors see current state. + queue_name = f"{exchange}.{routing_key}.{subscription}" + queue_durable = True + exclusive = False + auto_delete = False + elif consumer_type == 'exclusive': + # Broadcast: anonymous queue, auto-deleted on disconnect + queue_name = '' + queue_durable = False + exclusive = True + auto_delete = True + else: + # Shared: named queue, competing consumers + queue_name = f"{exchange}.{routing_key}.{subscription}" + queue_durable = durable + exclusive = False + auto_delete = False + + logger.debug( + f"Creating consumer: exchange={exchange}, " + f"routing_key={routing_key}, queue={queue_name or '(anonymous)'}, " + f"type={consumer_type}" + ) + + return RabbitMQBackendConsumer( + self._connection_params, exchange, routing_key, + queue_name, schema, queue_durable, exclusive, auto_delete, + ) + + def close(self) -> None: + pass diff --git a/trustgraph-base/trustgraph/base/serialization.py b/trustgraph-base/trustgraph/base/serialization.py new file mode 100644 index 00000000..6fd3ca62 --- /dev/null +++ b/trustgraph-base/trustgraph/base/serialization.py @@ -0,0 +1,115 @@ +""" +JSON serialization helpers for dataclass ↔ dict conversion. + +Used by pub/sub backends that use JSON as their wire format. +""" + +import types +from dataclasses import asdict, is_dataclass +from typing import Any, get_type_hints + + +def dataclass_to_dict(obj: Any) -> dict: + """ + Recursively convert a dataclass to a dictionary, handling None values and bytes. 
+ + None values are excluded from the dictionary (not serialized). + Bytes values are decoded as UTF-8 strings for JSON serialization. + Handles nested dataclasses, lists, and dictionaries recursively. + """ + if obj is None: + return None + + # Handle bytes - decode to UTF-8 for JSON serialization + if isinstance(obj, bytes): + return obj.decode('utf-8') + + # Handle dataclass - convert to dict then recursively process all values + if is_dataclass(obj): + result = {} + for key, value in asdict(obj).items(): + result[key] = dataclass_to_dict(value) if value is not None else None + return result + + # Handle list - recursively process all items + if isinstance(obj, list): + return [dataclass_to_dict(item) for item in obj] + + # Handle dict - recursively process all values + if isinstance(obj, dict): + return {k: dataclass_to_dict(v) for k, v in obj.items()} + + # Return primitive types as-is + return obj + + +def dict_to_dataclass(data: dict, cls: type) -> Any: + """ + Convert a dictionary back to a dataclass instance. + + Handles nested dataclasses and missing fields. + Uses get_type_hints() to resolve forward references (string annotations). 
+ """ + if data is None: + return None + + if not is_dataclass(cls): + return data + + # Get field types from the dataclass, resolving forward references + # get_type_hints() evaluates string annotations like "Triple | None" + try: + field_types = get_type_hints(cls) + except Exception: + # Fallback if get_type_hints fails (shouldn't happen normally) + field_types = {f.name: f.type for f in cls.__dataclass_fields__.values()} + kwargs = {} + + for key, value in data.items(): + if key in field_types: + field_type = field_types[key] + + # Handle modern union types (X | Y) + if isinstance(field_type, types.UnionType): + # Check if it's Optional (X | None) + if type(None) in field_type.__args__: + # Get the non-None type + actual_type = next((t for t in field_type.__args__ if t is not type(None)), None) + if actual_type and is_dataclass(actual_type) and isinstance(value, dict): + kwargs[key] = dict_to_dataclass(value, actual_type) + else: + kwargs[key] = value + else: + kwargs[key] = value + # Check if this is a generic type (list, dict, etc.) 
+ elif hasattr(field_type, '__origin__'): + # Handle list[T] + if field_type.__origin__ == list: + item_type = field_type.__args__[0] if field_type.__args__ else None + if item_type and is_dataclass(item_type) and isinstance(value, list): + kwargs[key] = [ + dict_to_dataclass(item, item_type) if isinstance(item, dict) else item + for item in value + ] + else: + kwargs[key] = value + # Handle old-style Optional[T] (which is Union[T, None]) + elif hasattr(field_type, '__args__') and type(None) in field_type.__args__: + # Get the non-None type from Union + actual_type = next((t for t in field_type.__args__ if t is not type(None)), None) + if actual_type and is_dataclass(actual_type) and isinstance(value, dict): + kwargs[key] = dict_to_dataclass(value, actual_type) + else: + kwargs[key] = value + else: + kwargs[key] = value + # Handle direct dataclass fields + elif is_dataclass(field_type) and isinstance(value, dict): + kwargs[key] = dict_to_dataclass(value, field_type) + # Handle bytes fields (UTF-8 encoded strings from JSON) + elif field_type == bytes and isinstance(value, str): + kwargs[key] = value.encode('utf-8') + else: + kwargs[key] = value + + return cls(**kwargs) diff --git a/trustgraph-base/trustgraph/base/subscriber.py b/trustgraph-base/trustgraph/base/subscriber.py index b0d90507..36948131 100644 --- a/trustgraph-base/trustgraph/base/subscriber.py +++ b/trustgraph-base/trustgraph/base/subscriber.py @@ -51,7 +51,7 @@ class Subscriber: topic=self.topic, subscription=self.subscription, schema=self.schema, - consumer_type='shared', + consumer_type='exclusive', ) self.task = asyncio.create_task(self.run()) diff --git a/trustgraph-base/trustgraph/clients/base.py b/trustgraph-base/trustgraph/clients/base.py index a71ba84e..cd4ad72e 100644 --- a/trustgraph-base/trustgraph/clients/base.py +++ b/trustgraph-base/trustgraph/clients/base.py @@ -18,9 +18,7 @@ class BaseClient: output_queue=None, input_schema=None, output_schema=None, - pulsar_host="pulsar://pulsar:6650", 
- pulsar_api_key=None, - listener=None, + **pubsub_config, ): if input_queue == None: raise RuntimeError("Need input_queue") @@ -32,12 +30,7 @@ class BaseClient: subscriber = str(uuid.uuid4()) # Create backend using factory - self.backend = get_pubsub( - pulsar_host=pulsar_host, - pulsar_api_key=pulsar_api_key, - pulsar_listener=listener, - pubsub_backend='pulsar' - ) + self.backend = get_pubsub(**pubsub_config) self.producer = self.backend.create_producer( topic=input_queue, diff --git a/trustgraph-base/trustgraph/clients/config_client.py b/trustgraph-base/trustgraph/clients/config_client.py index daadf652..78b62688 100644 --- a/trustgraph-base/trustgraph/clients/config_client.py +++ b/trustgraph-base/trustgraph/clients/config_client.py @@ -33,9 +33,7 @@ class ConfigClient(BaseClient): subscriber=None, input_queue=None, output_queue=None, - pulsar_host="pulsar://pulsar:6650", - listener=None, - pulsar_api_key=None, + **pubsub_config, ): if input_queue == None: @@ -48,11 +46,9 @@ class ConfigClient(BaseClient): subscriber=subscriber, input_queue=input_queue, output_queue=output_queue, - pulsar_host=pulsar_host, - pulsar_api_key=pulsar_api_key, input_schema=ConfigRequest, output_schema=ConfigResponse, - listener=listener, + **pubsub_config, ) def get(self, keys, timeout=300): diff --git a/trustgraph-base/trustgraph/schema/services/library.py b/trustgraph-base/trustgraph/schema/services/library.py index 51d0d5a5..f5d4592c 100644 --- a/trustgraph-base/trustgraph/schema/services/library.py +++ b/trustgraph-base/trustgraph/schema/services/library.py @@ -24,10 +24,13 @@ from ..core.metadata import Metadata # <- (document_metadata) # <- (error) -# get-document-content +# get-document-content [DEPRECATED — use stream-document instead] # -> (document_id) # <- (content) # <- (error) +# NOTE: Returns entire document in a single message. Fails for documents +# exceeding the broker's max message size. Use stream-document which +# returns content in chunks. 
# add-processing # -> (processing_id, processing_metadata) @@ -220,5 +223,5 @@ class LibrarianResponse: # FIXME: Is this right? Using persistence on librarian so that # message chunking works -librarian_request_queue = queue('librarian-request', cls='flow') -librarian_response_queue = queue('librarian-response', cls='flow') +librarian_request_queue = queue('librarian', cls='request') +librarian_response_queue = queue('librarian', cls='response') diff --git a/trustgraph-cli/trustgraph/cli/dump_queues.py b/trustgraph-cli/trustgraph/cli/dump_queues.py index eb7898c2..95be8529 100644 --- a/trustgraph-cli/trustgraph/cli/dump_queues.py +++ b/trustgraph-cli/trustgraph/cli/dump_queues.py @@ -354,10 +354,8 @@ IMPORTANT: output_file=args.output, subscriber_name=args.subscriber, append_mode=args.append, - pubsub_backend=args.pubsub_backend, - pulsar_host=args.pulsar_host, - pulsar_api_key=args.pulsar_api_key, - pulsar_listener=args.pulsar_listener, + **{k: v for k, v in vars(args).items() + if k not in ('queues', 'output', 'subscriber', 'append')}, )) except KeyboardInterrupt: # Already handled in async_main diff --git a/trustgraph-cli/trustgraph/cli/init_trustgraph.py b/trustgraph-cli/trustgraph/cli/init_trustgraph.py index 02456b1c..514dc75b 100644 --- a/trustgraph-cli/trustgraph/cli/init_trustgraph.py +++ b/trustgraph-cli/trustgraph/cli/init_trustgraph.py @@ -1,5 +1,8 @@ """ -Initialises Pulsar with Trustgraph tenant / namespaces & policy. +Initialises TrustGraph pub/sub infrastructure and pushes initial config. + +For Pulsar: creates tenant, namespaces, and retention policies. +For RabbitMQ: queues are auto-declared, so only config push is needed. 
""" import requests @@ -8,10 +11,11 @@ import argparse import json from trustgraph.clients.config_client import ConfigClient +from trustgraph.base.pubsub import add_pubsub_args default_pulsar_admin_url = "http://pulsar:8080" -default_pulsar_host = "pulsar://pulsar:6650" -subscriber = "tg-init-pulsar" +subscriber = "tg-init-pubsub" + def get_clusters(url): @@ -65,12 +69,11 @@ def ensure_namespace(url, tenant, namespace, config): print(f"Namespace {tenant}/{namespace} created.", flush=True) -def ensure_config(config, pulsar_host, pulsar_api_key): +def ensure_config(config, **pubsub_config): cli = ConfigClient( subscriber=subscriber, - pulsar_host=pulsar_host, - pulsar_api_key=pulsar_api_key, + **pubsub_config, ) while True: @@ -115,11 +118,9 @@ def ensure_config(config, pulsar_host, pulsar_api_key): time.sleep(2) print("Retrying...", flush=True) continue - -def init( - pulsar_admin_url, pulsar_host, pulsar_api_key, tenant, - config, config_file, -): + +def init_pulsar(pulsar_admin_url, tenant): + """Pulsar-specific setup: create tenant, namespaces, retention policies.""" clusters = get_clusters(pulsar_admin_url) @@ -145,17 +146,21 @@ def init( } }) - if config is not None: + +def push_config(config_json, config_file, **pubsub_config): + """Push initial config if provided.""" + + if config_json is not None: try: print("Decoding config...", flush=True) - dec = json.loads(config) + dec = json.loads(config_json) print("Decoded.", flush=True) except Exception as e: print("Exception:", e, flush=True) raise e - ensure_config(dec, pulsar_host, pulsar_api_key) + ensure_config(dec, **pubsub_config) elif config_file is not None: @@ -167,11 +172,12 @@ def init( print("Exception:", e, flush=True) raise e - ensure_config(dec, pulsar_host, pulsar_api_key) + ensure_config(dec, **pubsub_config) else: print("No config to update.", flush=True) + def main(): parser = argparse.ArgumentParser( @@ -180,22 +186,11 @@ def main(): ) parser.add_argument( - '-p', '--pulsar-admin-url', + 
'--pulsar-admin-url', default=default_pulsar_admin_url, help=f'Pulsar admin URL (default: {default_pulsar_admin_url})', ) - parser.add_argument( - '--pulsar-host', - default=default_pulsar_host, - help=f'Pulsar host (default: {default_pulsar_host})', - ) - - parser.add_argument( - '--pulsar-api-key', - help=f'Pulsar API key', - ) - parser.add_argument( '-c', '--config', help=f'Initial configuration to load', @@ -212,18 +207,43 @@ def main(): help=f'Tenant (default: tg)', ) + add_pubsub_args(parser) + args = parser.parse_args() + backend_type = args.pubsub_backend + + # Extract pubsub config from args + pubsub_config = { + k: v for k, v in vars(args).items() + if k not in ('pulsar_admin_url', 'config', 'config_file', 'tenant') + } + while True: try: - print(flush=True) - print( - f"Initialising with Pulsar {args.pulsar_admin_url}...", - flush=True + # Pulsar-specific setup (tenants, namespaces) + if backend_type == 'pulsar': + print(flush=True) + print( + f"Initialising Pulsar at {args.pulsar_admin_url}...", + flush=True, + ) + init_pulsar(args.pulsar_admin_url, args.tenant) + else: + print(flush=True) + print( + f"Using {backend_type} backend (no admin setup needed).", + flush=True, + ) + + # Push config (works with any backend) + push_config( + args.config, args.config_file, + **pubsub_config, ) - init(**vars(args)) + print("Initialisation complete.", flush=True) break @@ -236,4 +256,4 @@ def main(): print("Will retry...", flush=True) if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/trustgraph-cli/trustgraph/cli/monitor_prompts.py b/trustgraph-cli/trustgraph/cli/monitor_prompts.py index c3b71afb..0cfe68ac 100644 --- a/trustgraph-cli/trustgraph/cli/monitor_prompts.py +++ b/trustgraph-cli/trustgraph/cli/monitor_prompts.py @@ -316,10 +316,8 @@ def main(): queue_type=args.queue_type, max_lines=args.max_lines, max_width=args.max_width, - pulsar_host=args.pulsar_host, - pulsar_api_key=args.pulsar_api_key, - 
pulsar_listener=args.pulsar_listener, - pubsub_backend=args.pubsub_backend, + **{k: v for k, v in vars(args).items() + if k not in ('flow', 'queue_type', 'max_lines', 'max_width')}, )) except KeyboardInterrupt: pass diff --git a/trustgraph-flow/trustgraph/chunking/recursive/chunker.py b/trustgraph-flow/trustgraph/chunking/recursive/chunker.py index 64d58457..df2c58bd 100755 --- a/trustgraph-flow/trustgraph/chunking/recursive/chunker.py +++ b/trustgraph-flow/trustgraph/chunking/recursive/chunker.py @@ -133,7 +133,7 @@ class Processor(ChunkingService): chunk_length = len(chunk.page_content) # Save chunk to librarian as child document - await self.save_child_document( + await self.librarian.save_child_document( doc_id=chunk_doc_id, parent_id=parent_doc_id, user=v.metadata.user, diff --git a/trustgraph-flow/trustgraph/chunking/token/chunker.py b/trustgraph-flow/trustgraph/chunking/token/chunker.py index 4302250e..3e1161bf 100755 --- a/trustgraph-flow/trustgraph/chunking/token/chunker.py +++ b/trustgraph-flow/trustgraph/chunking/token/chunker.py @@ -131,7 +131,7 @@ class Processor(ChunkingService): chunk_length = len(chunk.page_content) # Save chunk to librarian as child document - await self.save_child_document( + await self.librarian.save_child_document( doc_id=chunk_doc_id, parent_id=parent_doc_id, user=v.metadata.user, diff --git a/trustgraph-flow/trustgraph/decoding/mistral_ocr/processor.py b/trustgraph-flow/trustgraph/decoding/mistral_ocr/processor.py index 8685aa61..40b8c566 100755 --- a/trustgraph-flow/trustgraph/decoding/mistral_ocr/processor.py +++ b/trustgraph-flow/trustgraph/decoding/mistral_ocr/processor.py @@ -9,20 +9,16 @@ for large documents. from pypdf import PdfWriter, PdfReader from io import BytesIO -import asyncio import base64 import uuid import os from mistralai import Mistral -from mistralai.models import OCRResponse from ... schema import Document, TextDocument, Metadata -from ... 
schema import LibrarianRequest, LibrarianResponse, DocumentMetadata from ... schema import librarian_request_queue, librarian_response_queue from ... schema import Triples -from ... base import FlowProcessor, ConsumerSpec, ProducerSpec -from ... base import Consumer, Producer, ConsumerMetrics, ProducerMetrics +from ... base import FlowProcessor, ConsumerSpec, ProducerSpec, LibrarianClient from ... provenance import ( document_uri, page_uri as make_page_uri, derived_entity_triples, @@ -102,42 +98,10 @@ class Processor(FlowProcessor): ) ) - # Librarian client for fetching document content - librarian_request_q = params.get( - "librarian_request_queue", default_librarian_request_queue + # Librarian client + self.librarian = LibrarianClient( + id=id, backend=self.pubsub, taskgroup=self.taskgroup, ) - librarian_response_q = params.get( - "librarian_response_queue", default_librarian_response_queue - ) - - librarian_request_metrics = ProducerMetrics( - processor = id, flow = None, name = "librarian-request" - ) - - self.librarian_request_producer = Producer( - backend = self.pubsub, - topic = librarian_request_q, - schema = LibrarianRequest, - metrics = librarian_request_metrics, - ) - - librarian_response_metrics = ConsumerMetrics( - processor = id, flow = None, name = "librarian-response" - ) - - self.librarian_response_consumer = Consumer( - taskgroup = self.taskgroup, - backend = self.pubsub, - flow = None, - topic = librarian_response_q, - subscriber = f"{id}-librarian", - schema = LibrarianResponse, - handler = self.on_librarian_response, - metrics = librarian_response_metrics, - ) - - # Pending librarian requests: request_id -> asyncio.Future - self.pending_requests = {} if api_key is None: raise RuntimeError("Mistral API key not specified") @@ -151,132 +115,7 @@ class Processor(FlowProcessor): async def start(self): await super(Processor, self).start() - await self.librarian_request_producer.start() - await self.librarian_response_consumer.start() - - async def 
on_librarian_response(self, msg, consumer, flow): - """Handle responses from the librarian service.""" - response = msg.value() - request_id = msg.properties().get("id") - - if request_id and request_id in self.pending_requests: - future = self.pending_requests.pop(request_id) - future.set_result(response) - - async def fetch_document_metadata(self, document_id, user, timeout=120): - """ - Fetch document metadata from librarian via Pulsar. - """ - request_id = str(uuid.uuid4()) - - request = LibrarianRequest( - operation="get-document-metadata", - document_id=document_id, - user=user, - ) - - future = asyncio.get_event_loop().create_future() - self.pending_requests[request_id] = future - - try: - await self.librarian_request_producer.send( - request, properties={"id": request_id} - ) - - response = await asyncio.wait_for(future, timeout=timeout) - - if response.error: - raise RuntimeError( - f"Librarian error: {response.error.type}: {response.error.message}" - ) - - return response.document_metadata - - except asyncio.TimeoutError: - self.pending_requests.pop(request_id, None) - raise RuntimeError(f"Timeout fetching metadata for {document_id}") - - async def fetch_document_content(self, document_id, user, timeout=120): - """ - Fetch document content from librarian via Pulsar. 
- """ - request_id = str(uuid.uuid4()) - - request = LibrarianRequest( - operation="get-document-content", - document_id=document_id, - user=user, - ) - - # Create future for response - future = asyncio.get_event_loop().create_future() - self.pending_requests[request_id] = future - - try: - # Send request - await self.librarian_request_producer.send( - request, properties={"id": request_id} - ) - - # Wait for response - response = await asyncio.wait_for(future, timeout=timeout) - - if response.error: - raise RuntimeError( - f"Librarian error: {response.error.type}: {response.error.message}" - ) - - return response.content - - except asyncio.TimeoutError: - self.pending_requests.pop(request_id, None) - raise RuntimeError(f"Timeout fetching document {document_id}") - - async def save_child_document(self, doc_id, parent_id, user, content, - document_type="page", title=None, timeout=120): - """ - Save a child document to the librarian. - """ - request_id = str(uuid.uuid4()) - - doc_metadata = DocumentMetadata( - id=doc_id, - user=user, - kind="text/plain", - title=title or doc_id, - parent_id=parent_id, - document_type=document_type, - ) - - request = LibrarianRequest( - operation="add-child-document", - document_metadata=doc_metadata, - content=base64.b64encode(content).decode("utf-8"), - ) - - # Create future for response - future = asyncio.get_event_loop().create_future() - self.pending_requests[request_id] = future - - try: - # Send request - await self.librarian_request_producer.send( - request, properties={"id": request_id} - ) - - # Wait for response - response = await asyncio.wait_for(future, timeout=timeout) - - if response.error: - raise RuntimeError( - f"Librarian error saving child document: {response.error.type}: {response.error.message}" - ) - - return doc_id - - except asyncio.TimeoutError: - self.pending_requests.pop(request_id, None) - raise RuntimeError(f"Timeout saving child document {doc_id}") + await self.librarian.start() def ocr(self, blob): """ 
@@ -359,7 +198,7 @@ class Processor(FlowProcessor): # Check MIME type if fetching from librarian if v.document_id: - doc_meta = await self.fetch_document_metadata( + doc_meta = await self.librarian.fetch_document_metadata( document_id=v.document_id, user=v.metadata.user, ) @@ -374,7 +213,7 @@ class Processor(FlowProcessor): # Get PDF content - fetch from librarian or use inline data if v.document_id: logger.info(f"Fetching document {v.document_id} from librarian...") - content = await self.fetch_document_content( + content = await self.librarian.fetch_document_content( document_id=v.document_id, user=v.metadata.user, ) @@ -401,7 +240,7 @@ class Processor(FlowProcessor): page_content = markdown.encode("utf-8") # Save page as child document in librarian - await self.save_child_document( + await self.librarian.save_child_document( doc_id=page_doc_id, parent_id=source_doc_id, user=v.metadata.user, diff --git a/trustgraph-flow/trustgraph/decoding/pdf/pdf_decoder.py b/trustgraph-flow/trustgraph/decoding/pdf/pdf_decoder.py index 38ca0603..d0061afd 100755 --- a/trustgraph-flow/trustgraph/decoding/pdf/pdf_decoder.py +++ b/trustgraph-flow/trustgraph/decoding/pdf/pdf_decoder.py @@ -7,20 +7,16 @@ Supports both inline document data and fetching from librarian via Pulsar for large documents. """ -import asyncio import os import tempfile import base64 import logging -import uuid from langchain_community.document_loaders import PyPDFLoader from ... schema import Document, TextDocument, Metadata -from ... schema import LibrarianRequest, LibrarianResponse, DocumentMetadata from ... schema import librarian_request_queue, librarian_response_queue from ... schema import Triples -from ... base import FlowProcessor, ConsumerSpec, ProducerSpec -from ... base import Consumer, Producer, ConsumerMetrics, ProducerMetrics +from ... base import FlowProcessor, ConsumerSpec, ProducerSpec, LibrarianClient from ... 
provenance import ( document_uri, page_uri as make_page_uri, derived_entity_triples, @@ -74,187 +70,16 @@ class Processor(FlowProcessor): ) ) - # Librarian client for fetching document content - librarian_request_q = params.get( - "librarian_request_queue", default_librarian_request_queue + # Librarian client + self.librarian = LibrarianClient( + id=id, backend=self.pubsub, taskgroup=self.taskgroup, ) - librarian_response_q = params.get( - "librarian_response_queue", default_librarian_response_queue - ) - - librarian_request_metrics = ProducerMetrics( - processor = id, flow = None, name = "librarian-request" - ) - - self.librarian_request_producer = Producer( - backend = self.pubsub, - topic = librarian_request_q, - schema = LibrarianRequest, - metrics = librarian_request_metrics, - ) - - librarian_response_metrics = ConsumerMetrics( - processor = id, flow = None, name = "librarian-response" - ) - - self.librarian_response_consumer = Consumer( - taskgroup = self.taskgroup, - backend = self.pubsub, - flow = None, - topic = librarian_response_q, - subscriber = f"{id}-librarian", - schema = LibrarianResponse, - handler = self.on_librarian_response, - metrics = librarian_response_metrics, - ) - - # Pending librarian requests: request_id -> asyncio.Future - self.pending_requests = {} logger.info("PDF decoder initialized") async def start(self): await super(Processor, self).start() - await self.librarian_request_producer.start() - await self.librarian_response_consumer.start() - - async def on_librarian_response(self, msg, consumer, flow): - """Handle responses from the librarian service.""" - response = msg.value() - request_id = msg.properties().get("id") - - if request_id and request_id in self.pending_requests: - future = self.pending_requests.pop(request_id) - future.set_result(response) - - async def fetch_document_metadata(self, document_id, user, timeout=120): - """ - Fetch document metadata from librarian via Pulsar. 
- """ - request_id = str(uuid.uuid4()) - - request = LibrarianRequest( - operation="get-document-metadata", - document_id=document_id, - user=user, - ) - - future = asyncio.get_event_loop().create_future() - self.pending_requests[request_id] = future - - try: - await self.librarian_request_producer.send( - request, properties={"id": request_id} - ) - - response = await asyncio.wait_for(future, timeout=timeout) - - if response.error: - raise RuntimeError( - f"Librarian error: {response.error.type}: {response.error.message}" - ) - - return response.document_metadata - - except asyncio.TimeoutError: - self.pending_requests.pop(request_id, None) - raise RuntimeError(f"Timeout fetching metadata for {document_id}") - - async def fetch_document_content(self, document_id, user, timeout=120): - """ - Fetch document content from librarian via Pulsar. - """ - request_id = str(uuid.uuid4()) - - request = LibrarianRequest( - operation="get-document-content", - document_id=document_id, - user=user, - ) - - # Create future for response - future = asyncio.get_event_loop().create_future() - self.pending_requests[request_id] = future - - try: - # Send request - await self.librarian_request_producer.send( - request, properties={"id": request_id} - ) - - # Wait for response - response = await asyncio.wait_for(future, timeout=timeout) - - if response.error: - raise RuntimeError( - f"Librarian error: {response.error.type}: {response.error.message}" - ) - - return response.content - - except asyncio.TimeoutError: - self.pending_requests.pop(request_id, None) - raise RuntimeError(f"Timeout fetching document {document_id}") - - async def save_child_document(self, doc_id, parent_id, user, content, - document_type="page", title=None, timeout=120): - """ - Save a child document to the librarian. - - Args: - doc_id: ID for the new child document - parent_id: ID of the parent document - user: User ID - content: Document content (bytes) - document_type: Type of document ("page", "chunk", etc.) 
- title: Optional title - timeout: Request timeout in seconds - - Returns: - The document ID on success - """ - import base64 - - request_id = str(uuid.uuid4()) - - doc_metadata = DocumentMetadata( - id=doc_id, - user=user, - kind="text/plain", - title=title or doc_id, - parent_id=parent_id, - document_type=document_type, - ) - - request = LibrarianRequest( - operation="add-child-document", - document_metadata=doc_metadata, - content=base64.b64encode(content).decode("utf-8"), - ) - - # Create future for response - future = asyncio.get_event_loop().create_future() - self.pending_requests[request_id] = future - - try: - # Send request - await self.librarian_request_producer.send( - request, properties={"id": request_id} - ) - - # Wait for response - response = await asyncio.wait_for(future, timeout=timeout) - - if response.error: - raise RuntimeError( - f"Librarian error saving child document: {response.error.type}: {response.error.message}" - ) - - return doc_id - - except asyncio.TimeoutError: - self.pending_requests.pop(request_id, None) - raise RuntimeError(f"Timeout saving child document {doc_id}") + await self.librarian.start() async def on_message(self, msg, consumer, flow): @@ -266,7 +91,7 @@ class Processor(FlowProcessor): # Check MIME type if fetching from librarian if v.document_id: - doc_meta = await self.fetch_document_metadata( + doc_meta = await self.librarian.fetch_document_metadata( document_id=v.document_id, user=v.metadata.user, ) @@ -287,7 +112,7 @@ class Processor(FlowProcessor): logger.info(f"Fetching document {v.document_id} from librarian...") fp.close() - content = await self.fetch_document_content( + content = await self.librarian.fetch_document_content( document_id=v.document_id, user=v.metadata.user, ) @@ -323,7 +148,7 @@ class Processor(FlowProcessor): page_content = page.page_content.encode("utf-8") # Save page as child document in librarian - await self.save_child_document( + await self.librarian.save_child_document( doc_id=page_doc_id, 
parent_id=source_doc_id, user=v.metadata.user, diff --git a/trustgraph-flow/trustgraph/gateway/service.py b/trustgraph-flow/trustgraph/gateway/service.py index cdf5daba..8d1aca9e 100755 --- a/trustgraph-flow/trustgraph/gateway/service.py +++ b/trustgraph-flow/trustgraph/gateway/service.py @@ -10,7 +10,7 @@ import logging import os from trustgraph.base.logging import setup_logging -from trustgraph.base.pubsub import get_pubsub +from trustgraph.base.pubsub import get_pubsub, add_pubsub_args from . auth import Authenticator from . config.receiver import ConfigReceiver @@ -167,30 +167,7 @@ def run(): help='Service identifier for logging and metrics (default: api-gateway)', ) - # Pub/sub backend selection - parser.add_argument( - '--pubsub-backend', - default=os.getenv('PUBSUB_BACKEND', 'pulsar'), - choices=['pulsar', 'mqtt'], - help='Pub/sub backend (default: pulsar, env: PUBSUB_BACKEND)', - ) - - parser.add_argument( - '-p', '--pulsar-host', - default=default_pulsar_host, - help=f'Pulsar host (default: {default_pulsar_host})', - ) - - parser.add_argument( - '--pulsar-api-key', - default=default_pulsar_api_key, - help=f'Pulsar API key', - ) - - parser.add_argument( - '--pulsar-listener', - help=f'Pulsar listener (default: none)', - ) + add_pubsub_args(parser) parser.add_argument( '-m', '--prometheus-url', diff --git a/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py b/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py index b81c6321..c0e55d84 100755 --- a/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py +++ b/trustgraph-flow/trustgraph/retrieval/document_rag/rag.py @@ -12,22 +12,18 @@ import uuid from ... schema import DocumentRagQuery, DocumentRagResponse, Error from ... schema import LibrarianRequest, LibrarianResponse, DocumentMetadata -from ... schema import librarian_request_queue, librarian_response_queue from ... schema import Triples, Metadata from ... provenance import GRAPH_RETRIEVAL from . document_rag import DocumentRag from ... 
base import FlowProcessor, ConsumerSpec, ProducerSpec from ... base import PromptClientSpec, EmbeddingsClientSpec from ... base import DocumentEmbeddingsClientSpec -from ... base import Consumer, Producer -from ... base import ConsumerMetrics, ProducerMetrics +from ... base import LibrarianClient # Module logger logger = logging.getLogger(__name__) default_ident = "document-rag" -default_librarian_request_queue = librarian_request_queue -default_librarian_response_queue = librarian_response_queue class Processor(FlowProcessor): @@ -89,111 +85,26 @@ class Processor(FlowProcessor): ) ) - # Librarian client for fetching chunk content from Garage - librarian_request_q = params.get( - "librarian_request_queue", default_librarian_request_queue - ) - librarian_response_q = params.get( - "librarian_response_queue", default_librarian_response_queue - ) - - librarian_request_metrics = ProducerMetrics( - processor=id, flow=None, name="librarian-request" - ) - - self.librarian_request_producer = Producer( + # Librarian client + self.librarian = LibrarianClient( + id=id, backend=self.pubsub, - topic=librarian_request_q, - schema=LibrarianRequest, - metrics=librarian_request_metrics, - ) - - librarian_response_metrics = ConsumerMetrics( - processor=id, flow=None, name="librarian-response" - ) - - self.librarian_response_consumer = Consumer( taskgroup=self.taskgroup, - backend=self.pubsub, - flow=None, - topic=librarian_response_q, - subscriber=f"{id}-librarian", - schema=LibrarianResponse, - handler=self.on_librarian_response, - metrics=librarian_response_metrics, ) - # Pending librarian requests: request_id -> asyncio.Future - self.pending_requests = {} - async def start(self): await super(Processor, self).start() - await self.librarian_request_producer.start() - await self.librarian_response_consumer.start() - - async def on_librarian_response(self, msg, consumer, flow): - """Handle responses from the librarian service.""" - response = msg.value() - request_id = 
msg.properties().get("id") - - if request_id in self.pending_requests: - future = self.pending_requests.pop(request_id) - future.set_result(response) + await self.librarian.start() async def fetch_chunk_content(self, chunk_id, user, timeout=120): - """Fetch chunk content from librarian/Garage.""" - import uuid - request_id = str(uuid.uuid4()) - - request = LibrarianRequest( - operation="get-document-content", - document_id=chunk_id, - user=user, + """Fetch chunk content from librarian. Chunks are small so + single request-response is fine.""" + return await self.librarian.fetch_document_text( + document_id=chunk_id, user=user, timeout=timeout, ) - # Create future for response - future = asyncio.get_event_loop().create_future() - self.pending_requests[request_id] = future - - try: - # Send request - await self.librarian_request_producer.send( - request, properties={"id": request_id} - ) - - # Wait for response - response = await asyncio.wait_for(future, timeout=timeout) - - if response.error: - raise RuntimeError( - f"Librarian error: {response.error.type}: {response.error.message}" - ) - - # Content is base64 encoded - content = response.content - if isinstance(content, str): - content = content.encode('utf-8') - return base64.b64decode(content).decode("utf-8") - - except asyncio.TimeoutError: - self.pending_requests.pop(request_id, None) - raise RuntimeError(f"Timeout fetching chunk {chunk_id}") - async def save_answer_content(self, doc_id, user, content, title=None, timeout=120): - """ - Save answer content to the librarian. 
- - Args: - doc_id: ID for the answer document - user: User ID - content: Answer text content - title: Optional title - timeout: Request timeout in seconds - - Returns: - The document ID on success - """ - request_id = str(uuid.uuid4()) + """Save answer content to the librarian.""" doc_metadata = DocumentMetadata( id=doc_id, @@ -211,29 +122,8 @@ class Processor(FlowProcessor): user=user, ) - # Create future for response - future = asyncio.get_event_loop().create_future() - self.pending_requests[request_id] = future - - try: - # Send request - await self.librarian_request_producer.send( - request, properties={"id": request_id} - ) - - # Wait for response - response = await asyncio.wait_for(future, timeout=timeout) - - if response.error: - raise RuntimeError( - f"Librarian error saving answer: {response.error.type}: {response.error.message}" - ) - - return doc_id - - except asyncio.TimeoutError: - self.pending_requests.pop(request_id, None) - raise RuntimeError(f"Timeout saving answer document {doc_id}") + await self.librarian.request(request, timeout=timeout) + return doc_id async def on_request(self, msg, consumer, flow): @@ -390,4 +280,3 @@ class Processor(FlowProcessor): def run(): Processor.launch(default_ident, __doc__) - diff --git a/trustgraph-ocr/trustgraph/decoding/ocr/pdf_decoder.py b/trustgraph-ocr/trustgraph/decoding/ocr/pdf_decoder.py index dd410d90..4844b104 100755 --- a/trustgraph-ocr/trustgraph/decoding/ocr/pdf_decoder.py +++ b/trustgraph-ocr/trustgraph/decoding/ocr/pdf_decoder.py @@ -7,19 +7,15 @@ Supports both inline document data and fetching from librarian via Pulsar for large documents. """ -import asyncio import base64 import logging -import uuid import pytesseract from pdf2image import convert_from_bytes from ... schema import Document, TextDocument, Metadata -from ... schema import LibrarianRequest, LibrarianResponse, DocumentMetadata from ... schema import librarian_request_queue, librarian_response_queue from ... 
schema import Triples -from ... base import FlowProcessor, ConsumerSpec, ProducerSpec -from ... base import Consumer, Producer, ConsumerMetrics, ProducerMetrics +from ... base import FlowProcessor, ConsumerSpec, ProducerSpec, LibrarianClient from ... provenance import ( document_uri, page_uri as make_page_uri, derived_entity_triples, @@ -72,173 +68,16 @@ class Processor(FlowProcessor): ) ) - # Librarian client for fetching document content - librarian_request_q = params.get( - "librarian_request_queue", default_librarian_request_queue + # Librarian client + self.librarian = LibrarianClient( + id=id, backend=self.pubsub, taskgroup=self.taskgroup, ) - librarian_response_q = params.get( - "librarian_response_queue", default_librarian_response_queue - ) - - librarian_request_metrics = ProducerMetrics( - processor = id, flow = None, name = "librarian-request" - ) - - self.librarian_request_producer = Producer( - backend = self.pubsub, - topic = librarian_request_q, - schema = LibrarianRequest, - metrics = librarian_request_metrics, - ) - - librarian_response_metrics = ConsumerMetrics( - processor = id, flow = None, name = "librarian-response" - ) - - self.librarian_response_consumer = Consumer( - taskgroup = self.taskgroup, - backend = self.pubsub, - flow = None, - topic = librarian_response_q, - subscriber = f"{id}-librarian", - schema = LibrarianResponse, - handler = self.on_librarian_response, - metrics = librarian_response_metrics, - ) - - # Pending librarian requests: request_id -> asyncio.Future - self.pending_requests = {} logger.info("PDF OCR processor initialized") async def start(self): await super(Processor, self).start() - await self.librarian_request_producer.start() - await self.librarian_response_consumer.start() - - async def on_librarian_response(self, msg, consumer, flow): - """Handle responses from the librarian service.""" - response = msg.value() - request_id = msg.properties().get("id") - - if request_id and request_id in self.pending_requests: - 
future = self.pending_requests.pop(request_id) - future.set_result(response) - - async def fetch_document_metadata(self, document_id, user, timeout=120): - """ - Fetch document metadata from librarian via Pulsar. - """ - request_id = str(uuid.uuid4()) - - request = LibrarianRequest( - operation="get-document-metadata", - document_id=document_id, - user=user, - ) - - future = asyncio.get_event_loop().create_future() - self.pending_requests[request_id] = future - - try: - await self.librarian_request_producer.send( - request, properties={"id": request_id} - ) - - response = await asyncio.wait_for(future, timeout=timeout) - - if response.error: - raise RuntimeError( - f"Librarian error: {response.error.type}: {response.error.message}" - ) - - return response.document_metadata - - except asyncio.TimeoutError: - self.pending_requests.pop(request_id, None) - raise RuntimeError(f"Timeout fetching metadata for {document_id}") - - async def fetch_document_content(self, document_id, user, timeout=120): - """ - Fetch document content from librarian via Pulsar. 
- """ - request_id = str(uuid.uuid4()) - - request = LibrarianRequest( - operation="get-document-content", - document_id=document_id, - user=user, - ) - - # Create future for response - future = asyncio.get_event_loop().create_future() - self.pending_requests[request_id] = future - - try: - # Send request - await self.librarian_request_producer.send( - request, properties={"id": request_id} - ) - - # Wait for response - response = await asyncio.wait_for(future, timeout=timeout) - - if response.error: - raise RuntimeError( - f"Librarian error: {response.error.type}: {response.error.message}" - ) - - return response.content - - except asyncio.TimeoutError: - self.pending_requests.pop(request_id, None) - raise RuntimeError(f"Timeout fetching document {document_id}") - - async def save_child_document(self, doc_id, parent_id, user, content, - document_type="page", title=None, timeout=120): - """ - Save a child document to the librarian. - """ - request_id = str(uuid.uuid4()) - - doc_metadata = DocumentMetadata( - id=doc_id, - user=user, - kind="text/plain", - title=title or doc_id, - parent_id=parent_id, - document_type=document_type, - ) - - request = LibrarianRequest( - operation="add-child-document", - document_metadata=doc_metadata, - content=base64.b64encode(content).decode("utf-8"), - ) - - # Create future for response - future = asyncio.get_event_loop().create_future() - self.pending_requests[request_id] = future - - try: - # Send request - await self.librarian_request_producer.send( - request, properties={"id": request_id} - ) - - # Wait for response - response = await asyncio.wait_for(future, timeout=timeout) - - if response.error: - raise RuntimeError( - f"Librarian error saving child document: {response.error.type}: {response.error.message}" - ) - - return doc_id - - except asyncio.TimeoutError: - self.pending_requests.pop(request_id, None) - raise RuntimeError(f"Timeout saving child document {doc_id}") + await self.librarian.start() async def 
on_message(self, msg, consumer, flow): @@ -250,7 +89,7 @@ class Processor(FlowProcessor): # Check MIME type if fetching from librarian if v.document_id: - doc_meta = await self.fetch_document_metadata( + doc_meta = await self.librarian.fetch_document_metadata( document_id=v.document_id, user=v.metadata.user, ) @@ -265,7 +104,7 @@ class Processor(FlowProcessor): # Get PDF content - fetch from librarian or use inline data if v.document_id: logger.info(f"Fetching document {v.document_id} from librarian...") - content = await self.fetch_document_content( + content = await self.librarian.fetch_document_content( document_id=v.document_id, user=v.metadata.user, ) @@ -299,7 +138,7 @@ class Processor(FlowProcessor): page_content = text.encode("utf-8") # Save page as child document in librarian - await self.save_child_document( + await self.librarian.save_child_document( doc_id=page_doc_id, parent_id=source_doc_id, user=v.metadata.user, diff --git a/trustgraph-unstructured/trustgraph/decoding/universal/processor.py b/trustgraph-unstructured/trustgraph/decoding/universal/processor.py index b8d05158..6b7d0246 100644 --- a/trustgraph-unstructured/trustgraph/decoding/universal/processor.py +++ b/trustgraph-unstructured/trustgraph/decoding/universal/processor.py @@ -14,22 +14,18 @@ Tables are preserved as HTML markup for better downstream extraction. Images are stored in the librarian but not sent to the text pipeline. """ -import asyncio import base64 import logging import magic import tempfile import os -import uuid from unstructured.partition.auto import partition from ... schema import Document, TextDocument, Metadata -from ... schema import LibrarianRequest, LibrarianResponse, DocumentMetadata from ... schema import librarian_request_queue, librarian_response_queue from ... schema import Triples -from ... base import FlowProcessor, ConsumerSpec, ProducerSpec -from ... base import Consumer, Producer, ConsumerMetrics, ProducerMetrics +from ... 
base import FlowProcessor, ConsumerSpec, ProducerSpec, LibrarianClient from ... provenance import ( document_uri, page_uri as make_page_uri, @@ -166,128 +162,16 @@ class Processor(FlowProcessor): ) ) - # Librarian client for fetching/storing document content - librarian_request_q = params.get( - "librarian_request_queue", default_librarian_request_queue + # Librarian client + self.librarian = LibrarianClient( + id=id, backend=self.pubsub, taskgroup=self.taskgroup, ) - librarian_response_q = params.get( - "librarian_response_queue", default_librarian_response_queue - ) - - librarian_request_metrics = ProducerMetrics( - processor=id, flow=None, name="librarian-request" - ) - - self.librarian_request_producer = Producer( - backend=self.pubsub, - topic=librarian_request_q, - schema=LibrarianRequest, - metrics=librarian_request_metrics, - ) - - librarian_response_metrics = ConsumerMetrics( - processor=id, flow=None, name="librarian-response" - ) - - self.librarian_response_consumer = Consumer( - taskgroup=self.taskgroup, - backend=self.pubsub, - flow=None, - topic=librarian_response_q, - subscriber=f"{id}-librarian", - schema=LibrarianResponse, - handler=self.on_librarian_response, - metrics=librarian_response_metrics, - ) - - # Pending librarian requests: request_id -> asyncio.Future - self.pending_requests = {} logger.info("Universal decoder initialized") async def start(self): await super(Processor, self).start() - await self.librarian_request_producer.start() - await self.librarian_response_consumer.start() - - async def on_librarian_response(self, msg, consumer, flow): - """Handle responses from the librarian service.""" - response = msg.value() - request_id = msg.properties().get("id") - - if request_id and request_id in self.pending_requests: - future = self.pending_requests.pop(request_id) - future.set_result(response) - - async def _librarian_request(self, request, timeout=120): - """Send a request to the librarian and wait for response.""" - request_id = 
str(uuid.uuid4()) - - future = asyncio.get_event_loop().create_future() - self.pending_requests[request_id] = future - - try: - await self.librarian_request_producer.send( - request, properties={"id": request_id} - ) - response = await asyncio.wait_for(future, timeout=timeout) - - if response.error: - raise RuntimeError( - f"Librarian error: {response.error.type}: " - f"{response.error.message}" - ) - - return response - - except asyncio.TimeoutError: - self.pending_requests.pop(request_id, None) - raise RuntimeError("Timeout waiting for librarian response") - - async def fetch_document_metadata(self, document_id, user): - """Fetch document metadata from the librarian.""" - request = LibrarianRequest( - operation="get-document-metadata", - document_id=document_id, - user=user, - ) - response = await self._librarian_request(request) - return response.document_metadata - - async def fetch_document_content(self, document_id, user): - """Fetch document content from the librarian.""" - request = LibrarianRequest( - operation="get-document-content", - document_id=document_id, - user=user, - ) - response = await self._librarian_request(request) - return response.content - - async def save_child_document(self, doc_id, parent_id, user, content, - document_type="page", title=None, - kind="text/plain"): - """Save a child document to the librarian.""" - if isinstance(content, str): - content = content.encode("utf-8") - - doc_metadata = DocumentMetadata( - id=doc_id, - user=user, - kind=kind, - title=title or doc_id, - parent_id=parent_id, - document_type=document_type, - ) - - request = LibrarianRequest( - operation="add-child-document", - document_metadata=doc_metadata, - content=base64.b64encode(content).decode("utf-8"), - ) - - await self._librarian_request(request) - return doc_id + await self.librarian.start() def extract_elements(self, blob, mime_type=None): """ @@ -388,7 +272,7 @@ class Processor(FlowProcessor): page_content = text.encode("utf-8") # Save to librarian - 
await self.save_child_document( + await self.librarian.save_child_document( doc_id=doc_id, parent_id=parent_doc_id, user=metadata.user, @@ -469,7 +353,7 @@ class Processor(FlowProcessor): # Save to librarian if img_content: - await self.save_child_document( + await self.librarian.save_child_document( doc_id=img_uri, parent_id=parent_doc_id, user=metadata.user, @@ -518,13 +402,13 @@ class Processor(FlowProcessor): f"Fetching document {v.document_id} from librarian..." ) - doc_meta = await self.fetch_document_metadata( + doc_meta = await self.librarian.fetch_document_metadata( document_id=v.document_id, user=v.metadata.user, ) mime_type = doc_meta.kind if doc_meta else None - content = await self.fetch_document_content( + content = await self.librarian.fetch_document_content( document_id=v.document_id, user=v.metadata.user, )