diff --git a/tests/unit/test_decoding/test_mistral_ocr_processor.py b/tests/unit/test_decoding/test_mistral_ocr_processor.py index 4d7b9937..d23112b4 100644 --- a/tests/unit/test_decoding/test_mistral_ocr_processor.py +++ b/tests/unit/test_decoding/test_mistral_ocr_processor.py @@ -10,159 +10,137 @@ from unittest import IsolatedAsyncioTestCase from io import BytesIO from trustgraph.decoding.mistral_ocr.processor import Processor -from trustgraph.schema import Document, TextDocument, Metadata +from trustgraph.schema import Document, TextDocument, Metadata, Triples + + +class MockAsyncProcessor: + def __init__(self, **params): + self.config_handlers = [] + self.id = params.get('id', 'test-service') + self.specifications = [] + self.pubsub = MagicMock() + self.taskgroup = params.get('taskgroup', MagicMock()) class TestMistralOcrProcessor(IsolatedAsyncioTestCase): """Test Mistral OCR processor functionality""" + @patch('trustgraph.decoding.mistral_ocr.processor.Consumer') + @patch('trustgraph.decoding.mistral_ocr.processor.Producer') @patch('trustgraph.decoding.mistral_ocr.processor.Mistral') - @patch('trustgraph.base.flow_processor.FlowProcessor.__init__') - async def test_processor_initialization_with_api_key(self, mock_flow_init, mock_mistral_class): + @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) + async def test_processor_initialization_with_api_key( + self, mock_mistral_class, mock_producer, mock_consumer + ): """Test Mistral OCR processor initialization with API key""" - # Arrange - mock_flow_init.return_value = None - mock_mistral = MagicMock() - mock_mistral_class.return_value = mock_mistral - + mock_mistral_class.return_value = MagicMock() + config = { 'id': 'test-mistral-ocr', 'api_key': 'test-api-key', 'taskgroup': AsyncMock() } - # Act - with patch.object(Processor, 'register_specification') as mock_register: - processor = Processor(**config) + processor = Processor(**config) - # Assert - mock_flow_init.assert_called_once() mock_mistral_class.assert_called_once_with(api_key='test-api-key') - - # Verify register_specification was called twice (consumer and producer) - assert mock_register.call_count == 2 - - # Check consumer spec - consumer_call = mock_register.call_args_list[0] - consumer_spec = consumer_call[0][0] - assert consumer_spec.name == "input" - assert consumer_spec.schema == Document - assert consumer_spec.handler == processor.on_message - - # Check producer spec - producer_call = mock_register.call_args_list[1] - producer_spec = producer_call[0][0] - assert producer_spec.name == "output" - assert producer_spec.schema == TextDocument - @patch('trustgraph.base.flow_processor.FlowProcessor.__init__') - async def test_processor_initialization_without_api_key(self, mock_flow_init): + # Check specs registered: input consumer, output producer, triples producer + consumer_specs = [s for s in processor.specifications if hasattr(s, 'handler')] + assert len(consumer_specs) >= 1 + assert consumer_specs[0].name == "input" + assert consumer_specs[0].schema == Document + + @patch('trustgraph.decoding.mistral_ocr.processor.Consumer') + @patch('trustgraph.decoding.mistral_ocr.processor.Producer') + @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) + async def test_processor_initialization_without_api_key( + self, mock_producer, mock_consumer + ): """Test Mistral OCR processor initialization without API key raises error""" - # Arrange - mock_flow_init.return_value = None - config = { 'id': 'test-mistral-ocr', 'taskgroup': AsyncMock() } - # Act & Assert - with patch.object(Processor, 'register_specification'): - with pytest.raises(RuntimeError, match="Mistral API key not specified"): - processor = Processor(**config) + with pytest.raises(RuntimeError, match="Mistral API key not specified"): + Processor(**config) - @patch('trustgraph.decoding.mistral_ocr.processor.uuid.uuid4') + @patch('trustgraph.decoding.mistral_ocr.processor.Consumer') + @patch('trustgraph.decoding.mistral_ocr.processor.Producer') @patch('trustgraph.decoding.mistral_ocr.processor.Mistral') - @patch('trustgraph.base.flow_processor.FlowProcessor.__init__') - async def test_ocr_single_chunk(self, mock_flow_init, mock_mistral_class, mock_uuid): + @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) + async def test_ocr_single_chunk( + self, mock_mistral_class, mock_producer, mock_consumer + ): """Test OCR processing with a single chunk (less than 5 pages)""" - # Arrange - mock_flow_init.return_value = None - mock_uuid.return_value = "test-uuid-1234" - - # Mock Mistral client mock_mistral = MagicMock() mock_mistral_class.return_value = mock_mistral - + # Mock file upload mock_uploaded_file = MagicMock(id="file-123") mock_mistral.files.upload.return_value = mock_uploaded_file - + # Mock signed URL mock_signed_url = MagicMock(url="https://example.com/signed-url") mock_mistral.files.get_signed_url.return_value = mock_signed_url - - # Mock OCR response - mock_page = MagicMock( + + # Mock OCR response with 2 pages + mock_page1 = MagicMock( markdown="# Page 1\nContent ![img1](img1)", images=[MagicMock(id="img1", image_base64="data:image/png;base64,abc123")] ) - mock_ocr_response = MagicMock(pages=[mock_page]) + mock_page2 = MagicMock( + markdown="# Page 2\nMore content", + images=[] + ) + mock_ocr_response = MagicMock(pages=[mock_page1, mock_page2]) mock_mistral.ocr.process.return_value = mock_ocr_response - + # Mock PyPDF mock_pdf_reader = MagicMock() - mock_pdf_reader.pages = [MagicMock(), MagicMock(), MagicMock()] # 3 pages - + mock_pdf_reader.pages = [MagicMock(), MagicMock(), MagicMock()] + config = { 'id': 'test-mistral-ocr', 'api_key': 'test-api-key', 'taskgroup': AsyncMock() } - with patch.object(Processor, 'register_specification'): - with patch('trustgraph.decoding.mistral_ocr.processor.PdfReader', return_value=mock_pdf_reader): - with patch('trustgraph.decoding.mistral_ocr.processor.PdfWriter') as mock_pdf_writer_class: - mock_pdf_writer = MagicMock() - mock_pdf_writer_class.return_value = mock_pdf_writer - - processor = Processor(**config) - - # Act - result = processor.ocr(b"fake pdf content") + with patch('trustgraph.decoding.mistral_ocr.processor.PdfReader', return_value=mock_pdf_reader): + with patch('trustgraph.decoding.mistral_ocr.processor.PdfWriter') as mock_pdf_writer_class: + mock_pdf_writer = MagicMock() + mock_pdf_writer_class.return_value = mock_pdf_writer - # Assert - assert result == "# Page 1\nContent ![img1](data:image/png;base64,abc123)" - - # Verify PDF writer was used to create chunk + processor = Processor(**config) + result = processor.ocr(b"fake pdf content") + + # Returns list of (markdown, page_num) tuples + assert len(result) == 2 + assert result[0] == ("# Page 1\nContent ![img1](data:image/png;base64,abc123)", 1) + assert result[1] == ("# Page 2\nMore content", 2) + + # Verify PDF writer was used assert mock_pdf_writer.add_page.call_count == 3 mock_pdf_writer.write_stream.assert_called_once() - + # Verify Mistral API calls mock_mistral.files.upload.assert_called_once() - upload_call = mock_mistral.files.upload.call_args[1] - assert upload_call['file']['file_name'] == "test-uuid-1234" - assert upload_call['purpose'] == 'ocr' - mock_mistral.files.get_signed_url.assert_called_once_with( file_id="file-123", expiry=1 ) - - mock_mistral.ocr.process.assert_called_once_with( - model="mistral-ocr-latest", - include_image_base64=True, - document={ - "type": "document_url", - "document_url": "https://example.com/signed-url", - } - ) + mock_mistral.ocr.process.assert_called_once() - @patch('trustgraph.decoding.mistral_ocr.processor.uuid.uuid4') + @patch('trustgraph.decoding.mistral_ocr.processor.Consumer') + @patch('trustgraph.decoding.mistral_ocr.processor.Producer') @patch('trustgraph.decoding.mistral_ocr.processor.Mistral') - @patch('trustgraph.base.flow_processor.FlowProcessor.__init__') - async def test_on_message_success(self, mock_flow_init, mock_mistral_class, mock_uuid): + @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) + async def test_on_message_success( + self, mock_mistral_class, mock_producer, mock_consumer + ): """Test successful message processing""" - # Arrange - mock_flow_init.return_value = None - mock_uuid.return_value = "test-uuid-5678" - - # Mock Mistral client with simple OCR response - mock_mistral = MagicMock() - mock_mistral_class.return_value = mock_mistral - - # Mock the ocr method to return simple markdown - ocr_result = "# Document Title\nThis is the OCR content" - + mock_mistral_class.return_value = MagicMock() + # Mock message pdf_content = b"fake pdf content" pdf_base64 = base64.b64encode(pdf_content).decode('utf-8') @@ -170,126 +148,100 @@ class TestMistralOcrProcessor(IsolatedAsyncioTestCase): mock_document = Document(metadata=mock_metadata, data=pdf_base64) mock_msg = MagicMock() mock_msg.value.return_value = mock_document - - # Mock flow - needs to be a callable that returns an object with send method + + # Mock flow mock_output_flow = AsyncMock() - mock_flow = MagicMock(return_value=mock_output_flow) - + mock_triples_flow = AsyncMock() + mock_flow = MagicMock(side_effect=lambda name: { + "output": mock_output_flow, + "triples": mock_triples_flow, + }.get(name)) + config = { 'id': 'test-mistral-ocr', 'api_key': 'test-api-key', 'taskgroup': AsyncMock() } - with patch.object(Processor, 'register_specification'): - processor = Processor(**config) - - # Mock the ocr method - with patch.object(processor, 'ocr', return_value=ocr_result): - # Act - await processor.on_message(mock_msg, None, mock_flow) + processor = Processor(**config) - # Assert - # Verify output was sent - mock_output_flow.send.assert_called_once() - - # Check output - call_args = mock_output_flow.send.call_args[0][0] + # Mock ocr to return per-page results + ocr_result = [ + ("# Page 1\nContent", 1), + ("# Page 2\nMore content", 2), + ] + + # Mock save_child_document + processor.save_child_document = AsyncMock(return_value="mock-doc-id") + + with patch.object(processor, 'ocr', return_value=ocr_result): + await processor.on_message(mock_msg, None, mock_flow) + + # Verify output was sent for each page + assert mock_output_flow.send.call_count == 2 + # Verify triples were sent for each page + assert mock_triples_flow.send.call_count == 2 + + # Check output uses UUID-based page URNs + call_args = mock_output_flow.send.call_args_list[0][0][0] assert isinstance(call_args, TextDocument) - assert call_args.metadata == mock_metadata - assert call_args.text == ocr_result.encode('utf-8') + assert call_args.document_id.startswith("urn:page:") + assert call_args.text == b"" # Content stored in librarian @patch('trustgraph.decoding.mistral_ocr.processor.Mistral') - @patch('trustgraph.base.flow_processor.FlowProcessor.__init__') - async def test_chunks_function(self, mock_flow_init, mock_mistral_class): + @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) + async def test_chunks_function(self, mock_mistral_class): """Test the chunks utility function""" - # Arrange from trustgraph.decoding.mistral_ocr.processor import chunks - + test_list = list(range(12)) - - # Act result = list(chunks(test_list, 5)) - - # Assert + assert len(result) == 3 assert result[0] == [0, 1, 2, 3, 4] assert result[1] == [5, 6, 7, 8, 9] assert result[2] == [10, 11] @patch('trustgraph.decoding.mistral_ocr.processor.Mistral') - @patch('trustgraph.base.flow_processor.FlowProcessor.__init__') - async def test_replace_images_in_markdown(self, mock_flow_init, mock_mistral_class): + @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) + async def test_replace_images_in_markdown(self, mock_mistral_class): """Test the replace_images_in_markdown function""" - # Arrange from trustgraph.decoding.mistral_ocr.processor import replace_images_in_markdown - + markdown = "# Title\n![image1](image1)\nSome text\n![image2](image2)" images_dict = { "image1": "data:image/png;base64,abc123", "image2": "data:image/png;base64,def456" } - - # Act - result = replace_images_in_markdown(markdown, images_dict) - - # Assert - expected = "# Title\n![image1](data:image/png;base64,abc123)\nSome text\n![image2](data:image/png;base64,def456)" - assert result == expected - @patch('trustgraph.decoding.mistral_ocr.processor.Mistral') - @patch('trustgraph.base.flow_processor.FlowProcessor.__init__') - async def test_get_combined_markdown(self, mock_flow_init, mock_mistral_class): - """Test the get_combined_markdown function""" - # Arrange - from trustgraph.decoding.mistral_ocr.processor import get_combined_markdown - from mistralai.models import OCRResponse - - # Mock OCR response with multiple pages - mock_page1 = MagicMock( - markdown="# Page 1\n![img1](img1)", - images=[MagicMock(id="img1", image_base64="base64_img1")] - ) - mock_page2 = MagicMock( - markdown="# Page 2\n![img2](img2)", - images=[MagicMock(id="img2", image_base64="base64_img2")] - ) - mock_ocr_response = MagicMock(pages=[mock_page1, mock_page2]) - - # Act - result = get_combined_markdown(mock_ocr_response) - - # Assert - expected = "# Page 1\n![img1](base64_img1)\n\n# Page 2\n![img2](base64_img2)" + result = replace_images_in_markdown(markdown, images_dict) + + expected = "# Title\n![image1](data:image/png;base64,abc123)\nSome text\n![image2](data:image/png;base64,def456)" assert result == expected @patch('trustgraph.base.flow_processor.FlowProcessor.add_args') def test_add_args(self, mock_parent_add_args): - """Test add_args adds API key argument""" - # Arrange + """Test add_args adds expected arguments""" mock_parser = MagicMock() - - # Act + Processor.add_args(mock_parser) - - # Assert + mock_parent_add_args.assert_called_once_with(mock_parser) - mock_parser.add_argument.assert_called_once_with( - '-k', '--api-key', - default=None, # default_api_key is None in test environment - help='Mistral API Key' - ) + assert mock_parser.add_argument.call_count == 3 + # Check the API key arg is among them + call_args_list = [c[0] for c in mock_parser.add_argument.call_args_list] + assert ('-k', '--api-key') in call_args_list @patch('trustgraph.decoding.mistral_ocr.processor.Processor.launch') def test_run(self, mock_launch): """Test run function""" - # Act from trustgraph.decoding.mistral_ocr.processor import run run() - - # Assert - mock_launch.assert_called_once_with("pdf-decoder", - "\nSimple decoder, accepts PDF documents on input, outputs pages from the\nPDF document as text as separate output objects.\n") + + mock_launch.assert_called_once() + args = mock_launch.call_args[0] + assert args[0] == "pdf-decoder" + assert "Mistral OCR decoder" in args[1] if __name__ == '__main__': diff --git a/tests/unit/test_decoding/test_pdf_decoder.py b/tests/unit/test_decoding/test_pdf_decoder.py index a3ca3514..c55201ad 100644 --- a/tests/unit/test_decoding/test_pdf_decoder.py +++ b/tests/unit/test_decoding/test_pdf_decoder.py @@ -171,8 +171,8 @@ class TestPdfDecoderProcessor(IsolatedAsyncioTestCase): mock_output_flow.send.assert_called_once() call_args = mock_output_flow.send.call_args[0][0] - # PDF decoder now forwards document_id, chunker fetches content from librarian - assert call_args.document_id == "test-doc/p1" + # PDF decoder now forwards document_id with UUID-based URN + assert call_args.document_id.startswith("urn:page:") assert call_args.text == b"" # Content stored in librarian, not inline @patch('trustgraph.base.flow_processor.FlowProcessor.add_args') diff --git a/tests/unit/test_provenance/test_uris.py b/tests/unit/test_provenance/test_uris.py index 0e69734c..05bb7a1b 100644 --- a/tests/unit/test_provenance/test_uris.py +++ b/tests/unit/test_provenance/test_uris.py @@ -10,8 +10,7 @@ from trustgraph.provenance.uris import ( _encode_id, document_uri, page_uri, - chunk_uri_from_page, - chunk_uri_from_doc, + chunk_uri, activity_uri, subgraph_uri, agent_uri, @@ -60,31 +59,22 @@ class TestDocumentUris: assert document_uri(iri) == iri def test_page_uri_format(self): - result = page_uri("https://example.com/doc/123", 5) - assert result == "https://example.com/doc/123/p5" + result = page_uri() + assert result.startswith("urn:page:") - def test_page_uri_page_zero(self): - result = page_uri("https://example.com/doc/123", 0) - assert result == "https://example.com/doc/123/p0" + def test_page_uri_unique(self): + r1 = page_uri() + r2 = page_uri() + assert r1 != r2 - def test_chunk_uri_from_page_format(self): - result = chunk_uri_from_page("https://example.com/doc/123", 2, 3) - assert result == "https://example.com/doc/123/p2/c3" + def test_chunk_uri_format(self): + result = chunk_uri() + assert result.startswith("urn:chunk:") - def test_chunk_uri_from_doc_format(self): - result = chunk_uri_from_doc("https://example.com/doc/123", 7) - assert result == "https://example.com/doc/123/c7" - - def test_page_uri_preserves_doc_iri(self): - doc = "urn:isbn:978-3-16-148410-0" - result = page_uri(doc, 1) - assert result.startswith(doc) - - def test_chunk_from_page_hierarchy(self): - """Chunk URI should contain both page and chunk identifiers.""" - result = chunk_uri_from_page("https://example.com/doc", 3, 5) - assert "/p3/" in result - assert result.endswith("/c5") + def test_chunk_uri_unique(self): + r1 = chunk_uri() + r2 = chunk_uri() + assert r1 != r2 class TestActivityAndSubgraphUris: diff --git a/trustgraph-base/trustgraph/provenance/__init__.py b/trustgraph-base/trustgraph/provenance/__init__.py index 18ecb0e8..5b9d2129 100644 --- a/trustgraph-base/trustgraph/provenance/__init__.py +++ b/trustgraph-base/trustgraph/provenance/__init__.py @@ -9,14 +9,14 @@ Provides helpers for: Usage example: from trustgraph.provenance import ( - document_uri, page_uri, chunk_uri_from_page, + document_uri, page_uri, chunk_uri, document_triples, derived_entity_triples, get_vocabulary_triples, ) # Generate URIs doc_uri = document_uri("my-doc-123") - page_uri = page_uri("my-doc-123", page_number=1) + pg_uri = page_uri() # Build provenance triples triples = document_triples( @@ -35,8 +35,7 @@ from . uris import ( TRUSTGRAPH_BASE, document_uri, page_uri, - chunk_uri_from_page, - chunk_uri_from_doc, + chunk_uri, activity_uri, subgraph_uri, agent_uri, @@ -138,8 +137,7 @@ __all__ = [ "TRUSTGRAPH_BASE", "document_uri", "page_uri", - "chunk_uri_from_page", - "chunk_uri_from_doc", + "chunk_uri", "activity_uri", "subgraph_uri", "agent_uri", diff --git a/trustgraph-base/trustgraph/provenance/uris.py b/trustgraph-base/trustgraph/provenance/uris.py index 670143df..d851fa0b 100644 --- a/trustgraph-base/trustgraph/provenance/uris.py +++ b/trustgraph-base/trustgraph/provenance/uris.py @@ -1,12 +1,11 @@ """ URI generation for provenance entities. -Document IDs are already IRIs (e.g., https://trustgraph.ai/doc/abc123). -Child entities (pages, chunks) append path segments to the parent IRI: -- Document: {doc_iri} (as provided) -- Page: {doc_iri}/p{page_number} -- Chunk: {page_iri}/c{chunk_index} (from page) - {doc_iri}/c{chunk_index} (from text doc) +Document IDs are externally provided (e.g., https://trustgraph.ai/doc/abc123). +Child entities (pages, chunks) use UUID-based URNs: +- Document: {doc_iri} (as provided, not generated here) +- Page: urn:page:{uuid} +- Chunk: urn:chunk:{uuid} - Activity: https://trustgraph.ai/activity/{uuid} - Subgraph: https://trustgraph.ai/subgraph/{uuid} """ @@ -28,19 +27,14 @@ def document_uri(doc_iri: str) -> str: return doc_iri -def page_uri(doc_iri: str, page_number: int) -> str: - """Generate URI for a page by appending to document IRI.""" - return f"{doc_iri}/p{page_number}" +def page_uri() -> str: + """Generate a unique URI for a page.""" + return f"urn:page:{uuid.uuid4()}" -def chunk_uri_from_page(doc_iri: str, page_number: int, chunk_index: int) -> str: - """Generate URI for a chunk extracted from a page.""" - return f"{doc_iri}/p{page_number}/c{chunk_index}" - - -def chunk_uri_from_doc(doc_iri: str, chunk_index: int) -> str: - """Generate URI for a chunk extracted directly from a text document.""" - return f"{doc_iri}/c{chunk_index}" +def chunk_uri() -> str: + """Generate a unique URI for a chunk.""" + return f"urn:chunk:{uuid.uuid4()}" def activity_uri(activity_id: str = None) -> str: diff --git a/trustgraph-flow/trustgraph/chunking/recursive/chunker.py b/trustgraph-flow/trustgraph/chunking/recursive/chunker.py index fb84c356..64d58457 100755 --- a/trustgraph-flow/trustgraph/chunking/recursive/chunker.py +++ b/trustgraph-flow/trustgraph/chunking/recursive/chunker.py @@ -12,7 +12,7 @@ from ... schema import TextDocument, Chunk, Metadata, Triples from ... base import ChunkingService, ConsumerSpec, ProducerSpec from ... provenance import ( - derived_entity_triples, + chunk_uri as make_chunk_uri, derived_entity_triples, set_graph, GRAPH_SOURCE, ) @@ -124,10 +124,9 @@ class Processor(ChunkingService): logger.debug(f"Created chunk of size {len(chunk.page_content)}") - # Generate chunk document ID by appending /c{index} to parent - # Works for both page URIs (doc/p3 -> doc/p3/c1) and doc URIs (doc -> doc/c1) - chunk_doc_id = f"{parent_doc_id}/c{chunk_index}" - chunk_uri = chunk_doc_id # URI is same as document ID + # Generate unique chunk ID + c_uri = make_chunk_uri() + chunk_doc_id = c_uri parent_uri = parent_doc_id chunk_content = chunk.page_content.encode("utf-8") @@ -145,7 +144,7 @@ class Processor(ChunkingService): # Emit provenance triples (stored in source graph for separation from core knowledge) prov_triples = derived_entity_triples( - entity_uri=chunk_uri, + entity_uri=c_uri, parent_uri=parent_uri, component_name=COMPONENT_NAME, component_version=COMPONENT_VERSION, @@ -159,7 +158,7 @@ class Processor(ChunkingService): await flow("triples").send(Triples( metadata=Metadata( - id=chunk_uri, + id=c_uri, root=v.metadata.root, user=v.metadata.user, collection=v.metadata.collection, @@ -170,7 +169,7 @@ class Processor(ChunkingService): # Forward chunk ID + content (post-chunker optimization) r = Chunk( metadata=Metadata( - id=chunk_uri, + id=c_uri, root=v.metadata.root, user=v.metadata.user, collection=v.metadata.collection, diff --git a/trustgraph-flow/trustgraph/chunking/token/chunker.py b/trustgraph-flow/trustgraph/chunking/token/chunker.py index 909396c6..4302250e 100755 --- a/trustgraph-flow/trustgraph/chunking/token/chunker.py +++ b/trustgraph-flow/trustgraph/chunking/token/chunker.py @@ -12,7 +12,7 @@ from ... schema import TextDocument, Chunk, Metadata, Triples from ... base import ChunkingService, ConsumerSpec, ProducerSpec from ... provenance import ( - derived_entity_triples, + chunk_uri as make_chunk_uri, derived_entity_triples, set_graph, GRAPH_SOURCE, ) @@ -122,10 +122,9 @@ class Processor(ChunkingService): logger.debug(f"Created chunk of size {len(chunk.page_content)}") - # Generate chunk document ID by appending /c{index} to parent - # Works for both page URIs (doc/p3 -> doc/p3/c1) and doc URIs (doc -> doc/c1) - chunk_doc_id = f"{parent_doc_id}/c{chunk_index}" - chunk_uri = chunk_doc_id # URI is same as document ID + # Generate unique chunk ID + c_uri = make_chunk_uri() + chunk_doc_id = c_uri parent_uri = parent_doc_id chunk_content = chunk.page_content.encode("utf-8") @@ -143,7 +142,7 @@ class Processor(ChunkingService): # Emit provenance triples (stored in source graph for separation from core knowledge) prov_triples = derived_entity_triples( - entity_uri=chunk_uri, + entity_uri=c_uri, parent_uri=parent_uri, component_name=COMPONENT_NAME, component_version=COMPONENT_VERSION, @@ -157,7 +156,7 @@ class Processor(ChunkingService): await flow("triples").send(Triples( metadata=Metadata( - id=chunk_uri, + id=c_uri, root=v.metadata.root, user=v.metadata.user, collection=v.metadata.collection, @@ -168,7 +167,7 @@ class Processor(ChunkingService): # Forward chunk ID + content (post-chunker optimization) r = Chunk( metadata=Metadata( - id=chunk_uri, + id=c_uri, root=v.metadata.root, user=v.metadata.user, collection=v.metadata.collection, diff --git a/trustgraph-flow/trustgraph/decoding/mistral_ocr/processor.py b/trustgraph-flow/trustgraph/decoding/mistral_ocr/processor.py index 3cacb16c..6207d659 100755 --- a/trustgraph-flow/trustgraph/decoding/mistral_ocr/processor.py +++ b/trustgraph-flow/trustgraph/decoding/mistral_ocr/processor.py @@ -1,29 +1,48 @@ """ -Simple decoder, accepts PDF documents on input, outputs pages from the -PDF document as text as separate output objects. +Mistral OCR decoder, accepts PDF documents on input, outputs pages from the +PDF document as markdown text as separate output objects. + +Supports both inline document data and fetching from librarian via Pulsar +for large documents. """ from pypdf import PdfWriter, PdfReader from io import BytesIO +import asyncio import base64 import uuid import os from mistralai import Mistral -from mistralai import DocumentURLChunk, ImageURLChunk, TextChunk from mistralai.models import OCRResponse from ... schema import Document, TextDocument, Metadata +from ... schema import LibrarianRequest, LibrarianResponse, DocumentMetadata +from ... schema import librarian_request_queue, librarian_response_queue +from ... schema import Triples from ... base import FlowProcessor, ConsumerSpec, ProducerSpec +from ... base import Consumer, Producer, ConsumerMetrics, ProducerMetrics + +from ... provenance import ( + document_uri, page_uri as make_page_uri, derived_entity_triples, + set_graph, GRAPH_SOURCE, +) import logging logger = logging.getLogger(__name__) +# Component identification for provenance +COMPONENT_NAME = "mistral-ocr-decoder" +COMPONENT_VERSION = "1.0.0" + default_ident = "pdf-decoder" default_api_key = os.getenv("MISTRAL_TOKEN") +default_librarian_request_queue = librarian_request_queue +default_librarian_response_queue = librarian_response_queue + pages_per_chunk = 5 def chunks(lst, n): @@ -48,27 +67,6 @@ def replace_images_in_markdown(markdown_str: str, images_dict: dict) -> str: ) return markdown_str -def get_combined_markdown(ocr_response: OCRResponse) -> str: - """ - Combine OCR text and images into a single markdown document. - - Args: - ocr_response: Response from OCR processing containing text and images - - Returns: - Combined markdown string with embedded images - """ - markdowns: list[str] = [] - # Extract images from page - for page in ocr_response.pages: - image_data = {} - for img in page.images: - image_data[img.id] = img.image_base64 - # Replace image placeholders with actual images - markdowns.append(replace_images_in_markdown(page.markdown, image_data)) - - return "\n\n".join(markdowns) - class Processor(FlowProcessor): def __init__(self, **params): @@ -97,6 +95,50 @@ class Processor(FlowProcessor): ) ) + self.register_specification( + ProducerSpec( + name = "triples", + schema = Triples, + ) + ) + + # Librarian client for fetching document content + librarian_request_q = params.get( + "librarian_request_queue", default_librarian_request_queue + ) + librarian_response_q = params.get( + "librarian_response_queue", default_librarian_response_queue + ) + + librarian_request_metrics = ProducerMetrics( + processor = id, flow = None, name = "librarian-request" + ) + + self.librarian_request_producer = Producer( + backend = self.pubsub, + topic = librarian_request_q, + schema = LibrarianRequest, + metrics = librarian_request_metrics, + ) + + librarian_response_metrics = ConsumerMetrics( + processor = id, flow = None, name = "librarian-response" + ) + + self.librarian_response_consumer = Consumer( + taskgroup = self.taskgroup, + backend = self.pubsub, + flow = None, + topic = librarian_response_q, + subscriber = f"{id}-librarian", + schema = LibrarianResponse, + handler = self.on_librarian_response, + metrics = librarian_response_metrics, + ) + + # Pending librarian requests: request_id -> asyncio.Future + self.pending_requests = {} + if api_key is None: raise RuntimeError("Mistral API key not specified") @@ -107,15 +149,125 @@ class Processor(FlowProcessor): logger.info("Mistral OCR processor initialized") + async def start(self): + await super(Processor, self).start() + await self.librarian_request_producer.start() + await self.librarian_response_consumer.start() + + async def on_librarian_response(self, msg, consumer, flow): + """Handle responses from the librarian service.""" + response = msg.value() + request_id = msg.properties().get("id") + + if request_id and request_id in self.pending_requests: + future = self.pending_requests.pop(request_id) + future.set_result(response) + else: + logger.warning(f"Received unexpected librarian response: {request_id}") + + async def fetch_document_content(self, document_id, user, timeout=120): + """ + Fetch document content from librarian via Pulsar. + """ + request_id = str(uuid.uuid4()) + + request = LibrarianRequest( + operation="get-document-content", + document_id=document_id, + user=user, + ) + + # Create future for response + future = asyncio.get_event_loop().create_future() + self.pending_requests[request_id] = future + + try: + # Send request + await self.librarian_request_producer.send( + request, properties={"id": request_id} + ) + + # Wait for response + response = await asyncio.wait_for(future, timeout=timeout) + + if response.error: + raise RuntimeError( + f"Librarian error: {response.error.type}: {response.error.message}" + ) + + return response.content + + except asyncio.TimeoutError: + self.pending_requests.pop(request_id, None) + raise RuntimeError(f"Timeout fetching document {document_id}") + + async def save_child_document(self, doc_id, parent_id, user, content, + document_type="page", title=None, timeout=120): + """ + Save a child document to the librarian. + """ + request_id = str(uuid.uuid4()) + + doc_metadata = DocumentMetadata( + id=doc_id, + user=user, + kind="text/plain", + title=title or doc_id, + parent_id=parent_id, + document_type=document_type, + ) + + request = LibrarianRequest( + operation="add-child-document", + document_metadata=doc_metadata, + content=base64.b64encode(content).decode("utf-8"), + ) + + # Create future for response + future = asyncio.get_event_loop().create_future() + self.pending_requests[request_id] = future + + try: + # Send request + await self.librarian_request_producer.send( + request, properties={"id": request_id} + ) + + # Wait for response + response = await asyncio.wait_for(future, timeout=timeout) + + if response.error: + raise RuntimeError( + f"Librarian error saving child document: {response.error.type}: {response.error.message}" + ) + + return doc_id + + except asyncio.TimeoutError: + self.pending_requests.pop(request_id, None) + raise RuntimeError(f"Timeout saving child document {doc_id}") + def ocr(self, blob): + """ + Run Mistral OCR on a PDF blob, returning per-page markdown strings. + + Args: + blob: Raw PDF bytes + + Returns: + List of (page_markdown, page_number) tuples, 1-indexed + """ logger.debug("Parse PDF...") pdfbuf = BytesIO(blob) pdf = PdfReader(pdfbuf) + pages = [] + global_page_num = 0 + for chunk in chunks(pdf.pages, pages_per_chunk): - + logger.debug("Get next pages...") part = PdfWriter() @@ -152,11 +304,19 @@ class Processor(FlowProcessor): logger.debug("Extract markdown...") - markdown = get_combined_markdown(processed) + for page in processed.pages: + global_page_num += 1 + image_data = {} + for img in page.images: + image_data[img.id] = img.image_base64 + markdown = replace_images_in_markdown( + page.markdown, image_data + ) + pages.append((markdown, global_page_num)) - logger.info("OCR complete.") + logger.info(f"OCR complete, {len(pages)} pages.") - return markdown + return pages async def on_message(self, msg, consumer, flow): @@ -166,16 +326,83 @@ class Processor(FlowProcessor): logger.info(f"Decoding {v.metadata.id}...") - markdown = self.ocr(base64.b64decode(v.data)) + # Get PDF content - fetch from librarian or use inline data + if v.document_id: + logger.info(f"Fetching document {v.document_id} from librarian...") + content = await self.fetch_document_content( + document_id=v.document_id, + user=v.metadata.user, + ) + if isinstance(content, str): + content = content.encode('utf-8') + blob = base64.b64decode(content) + logger.info(f"Fetched {len(blob)} bytes from librarian") + else: + blob = base64.b64decode(v.data) - r = TextDocument( - metadata=v.metadata, - text=markdown.encode("utf-8"), - ) + # Get the source document ID + source_doc_id = v.document_id or v.metadata.id - await flow("output").send(r) + # Run OCR, get per-page markdown + pages = self.ocr(blob) - logger.info("Done.") + for markdown, page_num in pages: + + logger.debug(f"Processing page {page_num}") + + # Generate unique page ID + pg_uri = make_page_uri() + page_doc_id = pg_uri + page_content = markdown.encode("utf-8") + + # Save page as child document in librarian + await self.save_child_document( + doc_id=page_doc_id, + parent_id=source_doc_id, + user=v.metadata.user, + content=page_content, + document_type="page", + title=f"Page {page_num}", + ) + + # Emit provenance triples + doc_uri = document_uri(source_doc_id) + + prov_triples = derived_entity_triples( + entity_uri=pg_uri, + parent_uri=doc_uri, + component_name=COMPONENT_NAME, + component_version=COMPONENT_VERSION, + label=f"Page {page_num}", + page_number=page_num, + ) + + await flow("triples").send(Triples( + metadata=Metadata( + id=pg_uri, + root=v.metadata.root, + user=v.metadata.user, + collection=v.metadata.collection, + ), + triples=set_graph(prov_triples, GRAPH_SOURCE), + )) + + # Forward page document ID to chunker + # Chunker will fetch content from librarian + r = TextDocument( + metadata=Metadata( + id=pg_uri, + root=v.metadata.root, + user=v.metadata.user, + collection=v.metadata.collection, + ), + document_id=page_doc_id, + text=b"", # Empty, chunker will fetch from librarian + ) + + await flow("output").send(r) + + logger.debug("PDF decoding complete") @staticmethod def add_args(parser): @@ -188,7 +415,18 @@ class Processor(FlowProcessor): help=f'Mistral API Key' ) + parser.add_argument( + '--librarian-request-queue', + default=default_librarian_request_queue, + help=f'Librarian request queue (default: {default_librarian_request_queue})', + ) + + parser.add_argument( + '--librarian-response-queue', + default=default_librarian_response_queue, + help=f'Librarian response queue (default: {default_librarian_response_queue})', + ) + def run(): Processor.launch(default_ident, __doc__) - diff --git a/trustgraph-flow/trustgraph/decoding/pdf/pdf_decoder.py b/trustgraph-flow/trustgraph/decoding/pdf/pdf_decoder.py index 550948fe..865b984e 100755 --- a/trustgraph-flow/trustgraph/decoding/pdf/pdf_decoder.py +++ b/trustgraph-flow/trustgraph/decoding/pdf/pdf_decoder.py @@ -23,7 +23,7 @@ from ... base import FlowProcessor, ConsumerSpec, ProducerSpec from ... base import Consumer, Producer, ConsumerMetrics, ProducerMetrics from ... provenance import ( - document_uri, page_uri, derived_entity_triples, + document_uri, page_uri as make_page_uri, derived_entity_triples, set_graph, GRAPH_SOURCE, ) @@ -272,8 +272,9 @@ class Processor(FlowProcessor): logger.debug(f"Processing page {page_num}") - # Generate page document ID - page_doc_id = f"{source_doc_id}/p{page_num}" + # Generate unique page ID + pg_uri = make_page_uri() + page_doc_id = pg_uri page_content = page.page_content.encode("utf-8") # Save page as child document in librarian @@ -288,7 +289,6 @@ class Processor(FlowProcessor): # Emit provenance triples (stored in source graph for separation from core knowledge) doc_uri = document_uri(source_doc_id) - pg_uri = page_uri(source_doc_id, page_num) prov_triples = derived_entity_triples( entity_uri=pg_uri, diff --git a/trustgraph-ocr/trustgraph/decoding/ocr/pdf_decoder.py b/trustgraph-ocr/trustgraph/decoding/ocr/pdf_decoder.py index b5aac3c2..0c94039d 100755 --- a/trustgraph-ocr/trustgraph/decoding/ocr/pdf_decoder.py +++ b/trustgraph-ocr/trustgraph/decoding/ocr/pdf_decoder.py @@ -2,22 +2,42 @@ """ Simple decoder, accepts PDF documents on input, outputs pages from the PDF document as text as separate output objects. + +Supports both inline document data and fetching from librarian via Pulsar +for large documents. """ -import tempfile +import asyncio import base64 import logging +import uuid import pytesseract from pdf2image import convert_from_bytes from ... schema import Document, TextDocument, Metadata +from ... schema import LibrarianRequest, LibrarianResponse, DocumentMetadata +from ... schema import librarian_request_queue, librarian_response_queue +from ... schema import Triples from ... base import FlowProcessor, ConsumerSpec, ProducerSpec +from ... base import Consumer, Producer, ConsumerMetrics, ProducerMetrics + +from ... provenance import ( + document_uri, page_uri as make_page_uri, derived_entity_triples, + set_graph, GRAPH_SOURCE, +) + +# Component identification for provenance +COMPONENT_NAME = "tesseract-ocr-decoder" +COMPONENT_VERSION = "1.0.0" # Module logger logger = logging.getLogger(__name__) default_ident = "pdf-decoder" +default_librarian_request_queue = librarian_request_queue +default_librarian_response_queue = librarian_response_queue + class Processor(FlowProcessor): def __init__(self, **params): @@ -45,8 +65,150 @@ class Processor(FlowProcessor): ) ) + self.register_specification( + ProducerSpec( + name = "triples", + schema = Triples, + ) + ) + + # Librarian client for fetching document content + librarian_request_q = params.get( + "librarian_request_queue", default_librarian_request_queue + ) + librarian_response_q = params.get( + "librarian_response_queue", default_librarian_response_queue + ) + + librarian_request_metrics = ProducerMetrics( + processor = id, flow = None, name = "librarian-request" + ) + + self.librarian_request_producer = Producer( + backend = self.pubsub, + topic = librarian_request_q, + schema = LibrarianRequest, + metrics = librarian_request_metrics, + ) + + librarian_response_metrics = ConsumerMetrics( + processor = id, flow = None, name = "librarian-response" + ) + + self.librarian_response_consumer = Consumer( + taskgroup = self.taskgroup, + backend = self.pubsub, + flow = None, + topic = librarian_response_q, + subscriber = f"{id}-librarian", + schema = LibrarianResponse, + handler = self.on_librarian_response, + metrics = librarian_response_metrics, + ) + + # Pending librarian requests: request_id -> asyncio.Future + self.pending_requests = {} + logger.info("PDF OCR processor initialized") + async def start(self): + await super(Processor, self).start() + await self.librarian_request_producer.start() + await self.librarian_response_consumer.start() + + async def on_librarian_response(self, msg, consumer, flow): + """Handle responses from the librarian service.""" + response = msg.value() + request_id = msg.properties().get("id") + + if request_id and request_id in self.pending_requests: + future = self.pending_requests.pop(request_id) + future.set_result(response) + else: + logger.warning(f"Received unexpected librarian response: {request_id}") + + async def fetch_document_content(self, document_id, user, timeout=120): + """ + Fetch document content from librarian via Pulsar. + """ + request_id = str(uuid.uuid4()) + + request = LibrarianRequest( + operation="get-document-content", + document_id=document_id, + user=user, + ) + + # Create future for response + future = asyncio.get_event_loop().create_future() + self.pending_requests[request_id] = future + + try: + # Send request + await self.librarian_request_producer.send( + request, properties={"id": request_id} + ) + + # Wait for response + response = await asyncio.wait_for(future, timeout=timeout) + + if response.error: + raise RuntimeError( + f"Librarian error: {response.error.type}: {response.error.message}" + ) + + return response.content + + except asyncio.TimeoutError: + self.pending_requests.pop(request_id, None) + raise RuntimeError(f"Timeout fetching document {document_id}") + + async def save_child_document(self, doc_id, parent_id, user, content, + document_type="page", title=None, timeout=120): + """ + Save a child document to the librarian. + """ + request_id = str(uuid.uuid4()) + + doc_metadata = DocumentMetadata( + id=doc_id, + user=user, + kind="text/plain", + title=title or doc_id, + parent_id=parent_id, + document_type=document_type, + ) + + request = LibrarianRequest( + operation="add-child-document", + document_metadata=doc_metadata, + content=base64.b64encode(content).decode("utf-8"), + ) + + # Create future for response + future = asyncio.get_event_loop().create_future() + self.pending_requests[request_id] = future + + try: + # Send request + await self.librarian_request_producer.send( + request, properties={"id": request_id} + ) + + # Wait for response + response = await asyncio.wait_for(future, timeout=timeout) + + if response.error: + raise RuntimeError( + f"Librarian error saving child document: {response.error.type}: {response.error.message}" + ) + + return doc_id + + except asyncio.TimeoutError: + self.pending_requests.pop(request_id, None) + raise RuntimeError(f"Timeout saving child document {doc_id}") + async def on_message(self, msg, consumer, flow): logger.info("PDF message received") @@ -55,21 +217,85 @@ class Processor(FlowProcessor): logger.info(f"Decoding {v.metadata.id}...") - blob = base64.b64decode(v.data) + # Get PDF content - fetch from librarian or use inline data + if v.document_id: + logger.info(f"Fetching document {v.document_id} from librarian...") + content = await self.fetch_document_content( + document_id=v.document_id, + user=v.metadata.user, + ) + if isinstance(content, str): + content = content.encode('utf-8') + blob = base64.b64decode(content) + logger.info(f"Fetched {len(blob)} bytes from librarian") + else: + blob = base64.b64decode(v.data) + + # Get the source document ID + source_doc_id = v.document_id or v.metadata.id pages = convert_from_bytes(blob) for ix, page in enumerate(pages): + page_num = ix + 1 # 1-indexed + try: text = pytesseract.image_to_string(page, lang='eng') except Exception as e: - logger.warning(f"Page did not OCR: {e}") + logger.warning(f"Page {page_num} did not OCR: {e}") continue + logger.debug(f"Processing page {page_num}") + + # Generate unique page ID + pg_uri = make_page_uri() + page_doc_id = pg_uri + page_content = text.encode("utf-8") + + # Save page as child document in librarian + await self.save_child_document( + doc_id=page_doc_id, + parent_id=source_doc_id, + user=v.metadata.user, + content=page_content, + document_type="page", + title=f"Page {page_num}", + ) + + # Emit provenance triples + doc_uri = document_uri(source_doc_id) + + prov_triples = derived_entity_triples( + entity_uri=pg_uri, + parent_uri=doc_uri, + component_name=COMPONENT_NAME, + component_version=COMPONENT_VERSION, + label=f"Page {page_num}", + page_number=page_num, + ) + + await flow("triples").send(Triples( + metadata=Metadata( + id=pg_uri, + root=v.metadata.root, + user=v.metadata.user, + collection=v.metadata.collection, + ), + triples=set_graph(prov_triples, GRAPH_SOURCE), + )) + + # Forward page document ID to chunker + # Chunker will fetch content from librarian r = TextDocument( - metadata=v.metadata, - text=text.encode("utf-8"), + metadata=Metadata( + id=pg_uri, + root=v.metadata.root, + user=v.metadata.user, + collection=v.metadata.collection, + ), + document_id=page_doc_id, + text=b"", # Empty, chunker will fetch from librarian ) await flow("output").send(r) @@ -78,9 +304,21 @@ class Processor(FlowProcessor): @staticmethod def add_args(parser): + FlowProcessor.add_args(parser) + parser.add_argument( + '--librarian-request-queue', + default=default_librarian_request_queue, + help=f'Librarian request queue (default: {default_librarian_request_queue})', + ) + + parser.add_argument( + '--librarian-response-queue', + default=default_librarian_response_queue, + help=f'Librarian response queue (default: {default_librarian_response_queue})', + ) + def run(): Processor.launch(default_ident, __doc__) -