diff --git a/tests/unit/test_decoding/__init__.py b/tests/unit/test_decoding/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/test_decoding/test_mistral_ocr_processor.py b/tests/unit/test_decoding/test_mistral_ocr_processor.py new file mode 100644 index 00000000..cb8362b7 --- /dev/null +++ b/tests/unit/test_decoding/test_mistral_ocr_processor.py @@ -0,0 +1,296 @@ +""" +Unit tests for trustgraph.decoding.mistral_ocr.processor +""" + +import pytest +import base64 +import uuid +from unittest.mock import AsyncMock, MagicMock, patch, Mock +from unittest import IsolatedAsyncioTestCase +from io import BytesIO + +from trustgraph.decoding.mistral_ocr.processor import Processor +from trustgraph.schema import Document, TextDocument, Metadata + + +class TestMistralOcrProcessor(IsolatedAsyncioTestCase): + """Test Mistral OCR processor functionality""" + + @patch('trustgraph.decoding.mistral_ocr.processor.Mistral') + @patch('trustgraph.base.flow_processor.FlowProcessor.__init__') + async def test_processor_initialization_with_api_key(self, mock_flow_init, mock_mistral_class): + """Test Mistral OCR processor initialization with API key""" + # Arrange + mock_flow_init.return_value = None + mock_mistral = MagicMock() + mock_mistral_class.return_value = mock_mistral + + config = { + 'id': 'test-mistral-ocr', + 'api_key': 'test-api-key', + 'taskgroup': AsyncMock() + } + + # Act + with patch.object(Processor, 'register_specification') as mock_register: + processor = Processor(**config) + + # Assert + mock_flow_init.assert_called_once() + mock_mistral_class.assert_called_once_with(api_key='test-api-key') + + # Verify register_specification was called twice (consumer and producer) + assert mock_register.call_count == 2 + + # Check consumer spec + consumer_call = mock_register.call_args_list[0] + consumer_spec = consumer_call[0][0] + assert consumer_spec.name == "input" + assert consumer_spec.schema == Document + assert consumer_spec.handler == processor.on_message + + # Check producer spec + producer_call = mock_register.call_args_list[1] + producer_spec = producer_call[0][0] + assert producer_spec.name == "output" + assert producer_spec.schema == TextDocument + + @patch('trustgraph.base.flow_processor.FlowProcessor.__init__') + async def test_processor_initialization_without_api_key(self, mock_flow_init): + """Test Mistral OCR processor initialization without API key raises error""" + # Arrange + mock_flow_init.return_value = None + + config = { + 'id': 'test-mistral-ocr', + 'taskgroup': AsyncMock() + } + + # Act & Assert + with patch.object(Processor, 'register_specification'): + with pytest.raises(RuntimeError, match="Mistral API key not specified"): + processor = Processor(**config) + + @patch('trustgraph.decoding.mistral_ocr.processor.uuid.uuid4') + @patch('trustgraph.decoding.mistral_ocr.processor.Mistral') + @patch('trustgraph.base.flow_processor.FlowProcessor.__init__') + async def test_ocr_single_chunk(self, mock_flow_init, mock_mistral_class, mock_uuid): + """Test OCR processing with a single chunk (less than 5 pages)""" + # Arrange + mock_flow_init.return_value = None + mock_uuid.return_value = "test-uuid-1234" + + # Mock Mistral client + mock_mistral = MagicMock() + mock_mistral_class.return_value = mock_mistral + + # Mock file upload + mock_uploaded_file = MagicMock(id="file-123") + mock_mistral.files.upload.return_value = mock_uploaded_file + + # Mock signed URL + mock_signed_url = MagicMock(url="https://example.com/signed-url") + mock_mistral.files.get_signed_url.return_value = mock_signed_url + + # Mock OCR response + mock_page = MagicMock( + markdown="# Page 1\nContent ![img1](img1)", + images=[MagicMock(id="img1", image_base64="data:image/png;base64,abc123")] + ) + mock_ocr_response = MagicMock(pages=[mock_page]) + mock_mistral.ocr.process.return_value = mock_ocr_response + + # Mock PyPDF + mock_pdf_reader = MagicMock() + mock_pdf_reader.pages = [MagicMock(), MagicMock(), MagicMock()] # 3 pages + + config = { + 'id': 'test-mistral-ocr', + 'api_key': 'test-api-key', + 'taskgroup': AsyncMock() + } + + with patch.object(Processor, 'register_specification'): + with patch('trustgraph.decoding.mistral_ocr.processor.PdfReader', return_value=mock_pdf_reader): + with patch('trustgraph.decoding.mistral_ocr.processor.PdfWriter') as mock_pdf_writer_class: + mock_pdf_writer = MagicMock() + mock_pdf_writer_class.return_value = mock_pdf_writer + + processor = Processor(**config) + + # Act + result = processor.ocr(b"fake pdf content") + + # Assert + assert result == "# Page 1\nContent ![img1](data:image/png;base64,abc123)" + + # Verify PDF writer was used to create chunk + assert mock_pdf_writer.add_page.call_count == 3 + mock_pdf_writer.write_stream.assert_called_once() + + # Verify Mistral API calls + mock_mistral.files.upload.assert_called_once() + upload_call = mock_mistral.files.upload.call_args[1] + assert upload_call['file']['file_name'] == "test-uuid-1234" + assert upload_call['purpose'] == 'ocr' + + mock_mistral.files.get_signed_url.assert_called_once_with( + file_id="file-123", expiry=1 + ) + + mock_mistral.ocr.process.assert_called_once_with( + model="mistral-ocr-latest", + include_image_base64=True, + document={ + "type": "document_url", + "document_url": "https://example.com/signed-url", + } + ) + + @patch('trustgraph.decoding.mistral_ocr.processor.uuid.uuid4') + @patch('trustgraph.decoding.mistral_ocr.processor.Mistral') + @patch('trustgraph.base.flow_processor.FlowProcessor.__init__') + async def test_on_message_success(self, mock_flow_init, mock_mistral_class, mock_uuid): + """Test successful message processing""" + # Arrange + mock_flow_init.return_value = None + mock_uuid.return_value = "test-uuid-5678" + + # Mock Mistral client with simple OCR response + mock_mistral = MagicMock() + mock_mistral_class.return_value = mock_mistral + + # Mock the ocr method to return simple markdown + ocr_result = "# Document Title\nThis is the OCR content" + + # Mock message + pdf_content = b"fake pdf content" + pdf_base64 = base64.b64encode(pdf_content).decode('utf-8') + mock_metadata = Metadata(id="test-doc") + mock_document = Document(metadata=mock_metadata, data=pdf_base64) + mock_msg = MagicMock() + mock_msg.value.return_value = mock_document + + # Mock flow - needs to be a callable that returns an object with send method + mock_output_flow = AsyncMock() + mock_flow = MagicMock(return_value=mock_output_flow) + + config = { + 'id': 'test-mistral-ocr', + 'api_key': 'test-api-key', + 'taskgroup': AsyncMock() + } + + with patch.object(Processor, 'register_specification'): + processor = Processor(**config) + + # Mock the ocr method + with patch.object(processor, 'ocr', return_value=ocr_result): + # Act + await processor.on_message(mock_msg, None, mock_flow) + + # Assert + # Verify output was sent + mock_output_flow.send.assert_called_once() + + # Check output + call_args = mock_output_flow.send.call_args[0][0] + assert isinstance(call_args, TextDocument) + assert call_args.metadata == mock_metadata + assert call_args.text == ocr_result.encode('utf-8') + + @patch('trustgraph.decoding.mistral_ocr.processor.Mistral') + @patch('trustgraph.base.flow_processor.FlowProcessor.__init__') + async def test_chunks_function(self, mock_flow_init, mock_mistral_class): + """Test the chunks utility function""" + # Arrange + from trustgraph.decoding.mistral_ocr.processor import chunks + + test_list = list(range(12)) + + # Act + result = list(chunks(test_list, 5)) + + # Assert + assert len(result) == 3 + assert result[0] == [0, 1, 2, 3, 4] + assert result[1] == [5, 6, 7, 8, 9] + assert result[2] == [10, 11] + + @patch('trustgraph.decoding.mistral_ocr.processor.Mistral') + @patch('trustgraph.base.flow_processor.FlowProcessor.__init__') + async def test_replace_images_in_markdown(self, mock_flow_init, mock_mistral_class): + """Test the replace_images_in_markdown function""" + # Arrange + from trustgraph.decoding.mistral_ocr.processor import replace_images_in_markdown + + markdown = "# Title\n![image1](image1)\nSome text\n![image2](image2)" + images_dict = { + "image1": "data:image/png;base64,abc123", + "image2": "data:image/png;base64,def456" + } + + # Act + result = replace_images_in_markdown(markdown, images_dict) + + # Assert + expected = "# Title\n![image1](data:image/png;base64,abc123)\nSome text\n![image2](data:image/png;base64,def456)" + assert result == expected + + @patch('trustgraph.decoding.mistral_ocr.processor.Mistral') + @patch('trustgraph.base.flow_processor.FlowProcessor.__init__') + async def test_get_combined_markdown(self, mock_flow_init, mock_mistral_class): + """Test the get_combined_markdown function""" + # Arrange + from trustgraph.decoding.mistral_ocr.processor import get_combined_markdown + from mistralai.models import OCRResponse + + # Mock OCR response with multiple pages + mock_page1 = MagicMock( + markdown="# Page 1\n![img1](img1)", + images=[MagicMock(id="img1", image_base64="base64_img1")] + ) + mock_page2 = MagicMock( + markdown="# Page 2\n![img2](img2)", + images=[MagicMock(id="img2", image_base64="base64_img2")] + ) + mock_ocr_response = MagicMock(pages=[mock_page1, mock_page2]) + + # Act + result = get_combined_markdown(mock_ocr_response) + + # Assert + expected = "# Page 1\n![img1](base64_img1)\n\n# Page 2\n![img2](base64_img2)" + assert result == expected + + @patch('trustgraph.base.flow_processor.FlowProcessor.add_args') + def test_add_args(self, mock_parent_add_args): + """Test add_args adds API key argument""" + # Arrange + mock_parser = MagicMock() + + # Act + Processor.add_args(mock_parser) + + # Assert + mock_parent_add_args.assert_called_once_with(mock_parser) + mock_parser.add_argument.assert_called_once_with( + '-k', '--api-key', + default=None, # default_api_key is None in test environment + help='Mistral API Key' + ) + + @patch('trustgraph.decoding.mistral_ocr.processor.Processor.launch') + def test_run(self, mock_launch): + """Test run function""" + # Act + from trustgraph.decoding.mistral_ocr.processor import run + run() + + # Assert + mock_launch.assert_called_once_with("mistral-ocr", + "\nSimple decoder, accepts PDF documents on input, outputs pages from the\nPDF document as text as separate output objects.\n") + + +if __name__ == '__main__': + pytest.main([__file__]) \ No newline at end of file diff --git a/tests/unit/test_decoding/test_pdf_decoder.py b/tests/unit/test_decoding/test_pdf_decoder.py new file mode 100644 index 00000000..b40accdf --- /dev/null +++ b/tests/unit/test_decoding/test_pdf_decoder.py @@ -0,0 +1,229 @@ +""" +Unit tests for trustgraph.decoding.pdf.pdf_decoder +""" + +import pytest +import base64 +import tempfile +from unittest.mock import AsyncMock, MagicMock, patch, call +from unittest import IsolatedAsyncioTestCase + +from trustgraph.decoding.pdf.pdf_decoder import Processor +from trustgraph.schema import Document, TextDocument, Metadata + + +class TestPdfDecoderProcessor(IsolatedAsyncioTestCase): + """Test PDF decoder processor functionality""" + + @patch('trustgraph.base.flow_processor.FlowProcessor.__init__') + async def test_processor_initialization(self, mock_flow_init): + """Test PDF decoder processor initialization""" + # Arrange + mock_flow_init.return_value = None + + config = { + 'id': 'test-pdf-decoder', + 'taskgroup': AsyncMock() + } + + # Act + with patch.object(Processor, 'register_specification') as mock_register: + processor = Processor(**config) + + # Assert + mock_flow_init.assert_called_once() + # Verify register_specification was called twice (consumer and producer) + assert mock_register.call_count == 2 + + # Check consumer spec + consumer_call = mock_register.call_args_list[0] + consumer_spec = consumer_call[0][0] + assert consumer_spec.name == "input" + assert consumer_spec.schema == Document + assert consumer_spec.handler == processor.on_message + + # Check producer spec + producer_call = mock_register.call_args_list[1] + producer_spec = producer_call[0][0] + assert producer_spec.name == "output" + assert producer_spec.schema == TextDocument + + @patch('trustgraph.decoding.pdf.pdf_decoder.PyPDFLoader') + @patch('trustgraph.base.flow_processor.FlowProcessor.__init__') + async def test_on_message_success(self, mock_flow_init, mock_pdf_loader_class): + """Test successful PDF processing""" + # Arrange + mock_flow_init.return_value = None + + # Mock PDF content + pdf_content = b"fake pdf content" + pdf_base64 = base64.b64encode(pdf_content).decode('utf-8') + + # Mock PyPDFLoader + mock_loader = MagicMock() + mock_page1 = MagicMock(page_content="Page 1 content") + mock_page2 = MagicMock(page_content="Page 2 content") + mock_loader.load.return_value = [mock_page1, mock_page2] + mock_pdf_loader_class.return_value = mock_loader + + # Mock message + mock_metadata = Metadata(id="test-doc") + mock_document = Document(metadata=mock_metadata, data=pdf_base64) + mock_msg = MagicMock() + mock_msg.value.return_value = mock_document + + # Mock flow - needs to be a callable that returns an object with send method + mock_output_flow = AsyncMock() + mock_flow = MagicMock(return_value=mock_output_flow) + + config = { + 'id': 'test-pdf-decoder', + 'taskgroup': AsyncMock() + } + + with patch.object(Processor, 'register_specification'): + processor = Processor(**config) + + # Act + await processor.on_message(mock_msg, None, mock_flow) + + # Assert + # Verify PyPDFLoader was called + mock_pdf_loader_class.assert_called_once() + mock_loader.load.assert_called_once() + + # Verify output was sent for each page + assert mock_output_flow.send.call_count == 2 + + # Check first page output + first_call = mock_output_flow.send.call_args_list[0] + first_output = first_call[0][0] + assert isinstance(first_output, TextDocument) + assert first_output.metadata == mock_metadata + assert first_output.text == b"Page 1 content" + + # Check second page output + second_call = mock_output_flow.send.call_args_list[1] + second_output = second_call[0][0] + assert isinstance(second_output, TextDocument) + assert second_output.metadata == mock_metadata + assert second_output.text == b"Page 2 content" + + @patch('trustgraph.decoding.pdf.pdf_decoder.PyPDFLoader') + @patch('trustgraph.base.flow_processor.FlowProcessor.__init__') + async def test_on_message_empty_pdf(self, mock_flow_init, mock_pdf_loader_class): + """Test handling of empty PDF""" + # Arrange + mock_flow_init.return_value = None + + # Mock PDF content + pdf_content = b"fake pdf content" + pdf_base64 = base64.b64encode(pdf_content).decode('utf-8') + + # Mock PyPDFLoader with no pages + mock_loader = MagicMock() + mock_loader.load.return_value = [] + mock_pdf_loader_class.return_value = mock_loader + + # Mock message + mock_metadata = Metadata(id="test-doc") + mock_document = Document(metadata=mock_metadata, data=pdf_base64) + mock_msg = MagicMock() + mock_msg.value.return_value = mock_document + + # Mock flow - needs to be a callable that returns an object with send method + mock_output_flow = AsyncMock() + mock_flow = MagicMock(return_value=mock_output_flow) + + config = { + 'id': 'test-pdf-decoder', + 'taskgroup': AsyncMock() + } + + with patch.object(Processor, 'register_specification'): + processor = Processor(**config) + + # Act + await processor.on_message(mock_msg, None, mock_flow) + + # Assert + # Verify PyPDFLoader was called + mock_pdf_loader_class.assert_called_once() + mock_loader.load.assert_called_once() + + # Verify no output was sent + mock_output_flow.send.assert_not_called() + + @patch('trustgraph.decoding.pdf.pdf_decoder.PyPDFLoader') + @patch('trustgraph.base.flow_processor.FlowProcessor.__init__') + async def test_on_message_unicode_content(self, mock_flow_init, mock_pdf_loader_class): + """Test handling of unicode content in PDF""" + # Arrange + mock_flow_init.return_value = None + + # Mock PDF content + pdf_content = b"fake pdf content" + pdf_base64 = base64.b64encode(pdf_content).decode('utf-8') + + # Mock PyPDFLoader with unicode content + mock_loader = MagicMock() + mock_page = MagicMock(page_content="Page with unicode: 你好世界 🌍") + mock_loader.load.return_value = [mock_page] + mock_pdf_loader_class.return_value = mock_loader + + # Mock message + mock_metadata = Metadata(id="test-doc") + mock_document = Document(metadata=mock_metadata, data=pdf_base64) + mock_msg = MagicMock() + mock_msg.value.return_value = mock_document + + # Mock flow - needs to be a callable that returns an object with send method + mock_output_flow = AsyncMock() + mock_flow = MagicMock(return_value=mock_output_flow) + + config = { + 'id': 'test-pdf-decoder', + 'taskgroup': AsyncMock() + } + + with patch.object(Processor, 'register_specification'): + processor = Processor(**config) + + # Act + await processor.on_message(mock_msg, None, mock_flow) + + # Assert + # Verify output was sent + mock_output_flow.send.assert_called_once() + + # Check output + call_args = mock_output_flow.send.call_args[0][0] + assert isinstance(call_args, TextDocument) + assert call_args.text == "Page with unicode: 你好世界 🌍".encode('utf-8') + + @patch('trustgraph.base.flow_processor.FlowProcessor.add_args') + def test_add_args(self, mock_parent_add_args): + """Test add_args calls parent method""" + # Arrange + mock_parser = MagicMock() + + # Act + Processor.add_args(mock_parser) + + # Assert + mock_parent_add_args.assert_called_once_with(mock_parser) + + @patch('trustgraph.decoding.pdf.pdf_decoder.Processor.launch') + def test_run(self, mock_launch): + """Test run function""" + # Act + from trustgraph.decoding.pdf.pdf_decoder import run + run() + + # Assert + mock_launch.assert_called_once_with("pdf-decoder", + "\nSimple decoder, accepts PDF documents on input, outputs pages from the\nPDF document as text as separate output objects.\n") + + +if __name__ == '__main__': + pytest.main([__file__]) \ No newline at end of file diff --git a/trustgraph-flow/pyproject.toml b/trustgraph-flow/pyproject.toml index c7eef10b..911c91a0 100644 --- a/trustgraph-flow/pyproject.toml +++ b/trustgraph-flow/pyproject.toml @@ -36,7 +36,6 @@ dependencies = [ "pulsar-client", "pymilvus", "pypdf", - "mistralai", "pyyaml", "qdrant-client", "rdflib", diff --git a/trustgraph-flow/trustgraph/decoding/mistral_ocr/processor.py b/trustgraph-flow/trustgraph/decoding/mistral_ocr/processor.py index 4bacd278..9532fa0f 100755 --- a/trustgraph-flow/trustgraph/decoding/mistral_ocr/processor.py +++ b/trustgraph-flow/trustgraph/decoding/mistral_ocr/processor.py @@ -15,17 +15,13 @@ from mistralai import DocumentURLChunk, ImageURLChunk, TextChunk from mistralai.models import OCRResponse from ... schema import Document, TextDocument, Metadata -from ... schema import document_ingest_queue, text_ingest_queue -from ... log_level import LogLevel -from ... base import InputOutputProcessor +from ... base import FlowProcessor, ConsumerSpec, ProducerSpec import logging logger = logging.getLogger(__name__) -module = "ocr" - -default_subscriber = module +default_ident = "mistral-ocr" default_api_key = os.getenv("MISTRAL_TOKEN") pages_per_chunk = 5 @@ -73,23 +69,34 @@ def get_combined_markdown(ocr_response: OCRResponse) -> str: return "\n\n".join(markdowns) -class Processor(InputOutputProcessor): +class Processor(FlowProcessor): def __init__(self, **params): - id = params.get("id") - subscriber = params.get("subscriber", default_subscriber) + id = params.get("id", default_ident) api_key = params.get("api_key", default_api_key) super(Processor, self).__init__( **params | { "id": id, - "subscriber": subscriber, - "input_schema": Document, - "output_schema": TextDocument, } ) + self.register_specification( + ConsumerSpec( + name = "input", + schema = Document, + handler = self.on_message, + ) + ) + + self.register_specification( + ProducerSpec( + name = "output", + schema = TextDocument, + ) + ) + if api_key is None: raise RuntimeError("Mistral API key not specified") @@ -98,7 +105,7 @@ class Processor(InputOutputProcessor): # Used with Mistral doc upload self.unique_id = str(uuid.uuid4()) - logger.info("PDF inited") + logger.info("Mistral OCR processor initialized") def ocr(self, blob): @@ -151,7 +158,7 @@ class Processor(InputOutputProcessor): return markdown - async def on_message(self, msg, consumer): + async def on_message(self, msg, consumer, flow): logger.debug("PDF message received") @@ -166,14 +173,14 @@ class Processor(InputOutputProcessor): text=markdown.encode("utf-8"), ) - await consumer.q.output.send(r) + await flow("output").send(r) logger.info("Done.") @staticmethod def add_args(parser): - InputOutputProcessor.add_args(parser, default_subscriber) + FlowProcessor.add_args(parser) parser.add_argument( '-k', '--api-key', @@ -183,5 +190,5 @@ class Processor(InputOutputProcessor): def run(): - Processor.launch(module, __doc__) + Processor.launch(default_ident, __doc__)