From 8929a680a18ff40c2ba77b93248d9dbea6643fae Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Fri, 26 Sep 2025 10:53:32 +0100 Subject: [PATCH] Chunking dynamic params (#536) * Chunking params are dynamic * Update tests --- .../test_chunking/test_recursive_chunker.py | 431 ++++++++------- .../unit/test_chunking/test_token_chunker.py | 515 +++++++++--------- trustgraph-base/trustgraph/base/__init__.py | 1 + .../trustgraph/base/chunking_service.py | 62 +++ .../trustgraph/chunking/recursive/chunker.py | 27 +- .../trustgraph/chunking/token/chunker.py | 26 +- 6 files changed, 584 insertions(+), 478 deletions(-) create mode 100644 trustgraph-base/trustgraph/base/chunking_service.py diff --git a/tests/unit/test_chunking/test_recursive_chunker.py b/tests/unit/test_chunking/test_recursive_chunker.py index 045133cd..8f91d95f 100644 --- a/tests/unit/test_chunking/test_recursive_chunker.py +++ b/tests/unit/test_chunking/test_recursive_chunker.py @@ -1,211 +1,236 @@ +""" +Unit tests for trustgraph.chunking.recursive +Testing parameter override functionality for chunk-size and chunk-overlap +""" + import pytest -import asyncio -from unittest.mock import AsyncMock, Mock, patch, MagicMock +from unittest.mock import AsyncMock, MagicMock, patch +from unittest import IsolatedAsyncioTestCase + +# Import the service under test +from trustgraph.chunking.recursive.chunker import Processor from trustgraph.schema import TextDocument, Chunk, Metadata -from trustgraph.chunking.recursive.chunker import Processor as RecursiveChunker -@pytest.fixture -def mock_flow(): - output_mock = AsyncMock() - flow_mock = Mock(return_value=output_mock) - return flow_mock, output_mock +class MockAsyncProcessor: + def __init__(self, **params): + self.config_handlers = [] + self.id = params.get('id', 'test-service') + self.specifications = [] -@pytest.fixture -def mock_consumer(): - consumer = Mock() - consumer.id = "test-consumer" - consumer.flow = "test-flow" - return consumer +class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase): + """Test Recursive chunker functionality""" + + @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) + def test_processor_initialization_basic(self): + """Test basic processor initialization""" + # Arrange + config = { + 'id': 'test-chunker', + 'chunk_size': 1500, + 'chunk_overlap': 150, + 'concurrency': 1, + 'taskgroup': AsyncMock() + } + + # Act + processor = Processor(**config) + + # Assert + assert processor.default_chunk_size == 1500 + assert processor.default_chunk_overlap == 150 + assert hasattr(processor, 'text_splitter') + + # Verify parameter specs are registered + param_specs = [spec for spec in processor.specifications + if hasattr(spec, 'name') and spec.name in ['chunk-size', 'chunk-overlap']] + assert len(param_specs) == 2 + + @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) + async def test_chunk_document_with_chunk_size_override(self): + """Test chunk_document with chunk-size parameter override""" + # Arrange + config = { + 'id': 'test-chunker', + 'chunk_size': 1000, # Default chunk size + 'chunk_overlap': 100, + 'concurrency': 1, + 'taskgroup': AsyncMock() + } + + processor = Processor(**config) + + # Mock message and flow + mock_message = MagicMock() + mock_consumer = MagicMock() + mock_flow = MagicMock() + mock_flow.side_effect = lambda param: { + "chunk-size": 2000, # Override chunk size + "chunk-overlap": None # Use default chunk overlap + }.get(param) + + # Act + chunk_size, chunk_overlap = await processor.chunk_document( + mock_message, mock_consumer, mock_flow, 1000, 100 + ) + + # Assert + assert chunk_size == 2000 # Should use overridden value + assert chunk_overlap == 100 # Should use default value + + @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) + async def test_chunk_document_with_chunk_overlap_override(self): + """Test chunk_document with chunk-overlap parameter override""" + # Arrange + config = { + 'id': 'test-chunker', + 'chunk_size': 1000, + 'chunk_overlap': 100, # Default chunk overlap + 'concurrency': 1, + 'taskgroup': AsyncMock() + } + + processor = Processor(**config) + + # Mock message and flow + mock_message = MagicMock() + mock_consumer = MagicMock() + mock_flow = MagicMock() + mock_flow.side_effect = lambda param: { + "chunk-size": None, # Use default chunk size + "chunk-overlap": 200 # Override chunk overlap + }.get(param) + + # Act + chunk_size, chunk_overlap = await processor.chunk_document( + mock_message, mock_consumer, mock_flow, 1000, 100 + ) + + # Assert + assert chunk_size == 1000 # Should use default value + assert chunk_overlap == 200 # Should use overridden value + + @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) + async def test_chunk_document_with_both_parameters_override(self): + """Test chunk_document with both chunk-size and chunk-overlap overrides""" + # Arrange + config = { + 'id': 'test-chunker', + 'chunk_size': 1000, + 'chunk_overlap': 100, + 'concurrency': 1, + 'taskgroup': AsyncMock() + } + + processor = Processor(**config) + + # Mock message and flow + mock_message = MagicMock() + mock_consumer = MagicMock() + mock_flow = MagicMock() + mock_flow.side_effect = lambda param: { + "chunk-size": 1500, # Override chunk size + "chunk-overlap": 150 # Override chunk overlap + }.get(param) + + # Act + chunk_size, chunk_overlap = await processor.chunk_document( + mock_message, mock_consumer, mock_flow, 1000, 100 + ) + + # Assert + assert chunk_size == 1500 # Should use overridden value + assert chunk_overlap == 150 # Should use overridden value + + @patch('trustgraph.chunking.recursive.chunker.RecursiveCharacterTextSplitter') + @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) + async def test_on_message_uses_flow_parameters(self, mock_splitter_class): + """Test that on_message method uses parameters from flow""" + # Arrange + mock_splitter = MagicMock() + mock_document = MagicMock() + mock_document.page_content = "Test chunk content" + mock_splitter.create_documents.return_value = [mock_document] + mock_splitter_class.return_value = mock_splitter + + config = { + 'id': 'test-chunker', + 'chunk_size': 1000, + 'chunk_overlap': 100, + 'concurrency': 1, + 'taskgroup': AsyncMock() + } + + processor = Processor(**config) + + # Mock message with TextDocument + mock_message = MagicMock() + mock_text_doc = MagicMock() + mock_text_doc.metadata = Metadata( + id="test-doc-123", + metadata=[], + user="test-user", + collection="test-collection" + ) + mock_text_doc.text = b"This is test document content" + mock_message.value.return_value = mock_text_doc + + # Mock consumer and flow with parameter overrides + mock_consumer = MagicMock() + mock_producer = AsyncMock() + mock_flow = MagicMock() + mock_flow.side_effect = lambda param: { + "chunk-size": 1500, + "chunk-overlap": 150, + "output": mock_producer + }.get(param) + + # Act + await processor.on_message(mock_message, mock_consumer, mock_flow) + + # Assert + # Verify RecursiveCharacterTextSplitter was called with overridden parameters (last call) + actual_last_call = mock_splitter_class.call_args_list[-1] + assert actual_last_call.kwargs['chunk_size'] == 1500 + assert actual_last_call.kwargs['chunk_overlap'] == 150 + assert actual_last_call.kwargs['length_function'] == len + assert actual_last_call.kwargs['is_separator_regex'] == False + + # Verify chunk was sent to output + mock_producer.send.assert_called_once() + sent_chunk = mock_producer.send.call_args[0][0] + assert isinstance(sent_chunk, Chunk) + + @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) + async def test_chunk_document_with_no_overrides(self): + """Test chunk_document when no parameters are overridden (flow returns None)""" + # Arrange + config = { + 'id': 'test-chunker', + 'chunk_size': 1000, + 'chunk_overlap': 100, + 'concurrency': 1, + 'taskgroup': AsyncMock() + } + + processor = Processor(**config) + + # Mock message and flow that returns None for all parameters + mock_message = MagicMock() + mock_consumer = MagicMock() + mock_flow = MagicMock() + mock_flow.return_value = None # No overrides + + # Act + chunk_size, chunk_overlap = await processor.chunk_document( + mock_message, mock_consumer, mock_flow, 1000, 100 + ) + + # Assert + assert chunk_size == 1000 # Should use default value + assert chunk_overlap == 100 # Should use default value -@pytest.fixture -def sample_document(): - metadata = Metadata( - id="test-doc-1", - metadata=[], - user="test-user", - collection="test-collection" - ) - text = "This is a test document. " * 100 # Create text long enough to be chunked - return TextDocument( - metadata=metadata, - text=text.encode("utf-8") - ) - - -@pytest.fixture -def short_document(): - metadata = Metadata( - id="test-doc-2", - metadata=[], - user="test-user", - collection="test-collection" - ) - text = "This is a very short document." - return TextDocument( - metadata=metadata, - text=text.encode("utf-8") - ) - - -class TestRecursiveChunker: - - def test_init_default_params(self, mock_async_processor_init): - processor = RecursiveChunker() - assert processor.text_splitter._chunk_size == 2000 - assert processor.text_splitter._chunk_overlap == 100 - - def test_init_custom_params(self, mock_async_processor_init): - processor = RecursiveChunker(chunk_size=500, chunk_overlap=50) - assert processor.text_splitter._chunk_size == 500 - assert processor.text_splitter._chunk_overlap == 50 - - def test_init_with_id(self, mock_async_processor_init): - processor = RecursiveChunker(id="custom-chunker") - assert processor.id == "custom-chunker" - - @pytest.mark.asyncio - async def test_on_message_single_chunk(self, mock_async_processor_init, mock_flow, mock_consumer, short_document): - flow_mock, output_mock = mock_flow - processor = RecursiveChunker(chunk_size=2000, chunk_overlap=100) - - msg = Mock() - msg.value.return_value = short_document - - await processor.on_message(msg, mock_consumer, flow_mock) - - # Should produce exactly one chunk for short text - assert output_mock.send.call_count == 1 - - # Verify the chunk was created correctly - chunk_call = output_mock.send.call_args[0][0] - assert isinstance(chunk_call, Chunk) - assert chunk_call.metadata == short_document.metadata - assert chunk_call.chunk.decode("utf-8") == short_document.text.decode("utf-8") - - @pytest.mark.asyncio - async def test_on_message_multiple_chunks(self, mock_async_processor_init, mock_flow, mock_consumer, sample_document): - flow_mock, output_mock = mock_flow - processor = RecursiveChunker(chunk_size=100, chunk_overlap=20) - - msg = Mock() - msg.value.return_value = sample_document - - await processor.on_message(msg, mock_consumer, flow_mock) - - # Should produce multiple chunks - assert output_mock.send.call_count > 1 - - # Verify all chunks have correct metadata - for call in output_mock.send.call_args_list: - chunk = call[0][0] - assert isinstance(chunk, Chunk) - assert chunk.metadata == sample_document.metadata - assert len(chunk.chunk) > 0 - - @pytest.mark.asyncio - async def test_on_message_chunk_overlap(self, mock_async_processor_init, mock_flow, mock_consumer): - flow_mock, output_mock = mock_flow - processor = RecursiveChunker(chunk_size=50, chunk_overlap=10) - - # Create a document with predictable content - metadata = Metadata(id="test", metadata=[], user="test-user", collection="test-collection") - text = "ABCDEFGHIJ" * 10 # 100 characters - document = TextDocument(metadata=metadata, text=text.encode("utf-8")) - - msg = Mock() - msg.value.return_value = document - - await processor.on_message(msg, mock_consumer, flow_mock) - - # Collect all chunks - chunks = [] - for call in output_mock.send.call_args_list: - chunk_text = call[0][0].chunk.decode("utf-8") - chunks.append(chunk_text) - - # Verify chunks have expected overlap - for i in range(len(chunks) - 1): - # The end of chunk i should overlap with the beginning of chunk i+1 - # Check if there's some overlap (exact overlap depends on text splitter logic) - assert len(chunks[i]) <= 50 + 10 # chunk_size + some tolerance - - @pytest.mark.asyncio - async def test_on_message_empty_document(self, mock_async_processor_init, mock_flow, mock_consumer): - flow_mock, output_mock = mock_flow - processor = RecursiveChunker() - - metadata = Metadata(id="empty", metadata=[], user="test-user", collection="test-collection") - document = TextDocument(metadata=metadata, text=b"") - - msg = Mock() - msg.value.return_value = document - - await processor.on_message(msg, mock_consumer, flow_mock) - - # Empty documents typically don't produce chunks with langchain splitters - # This behavior is expected - no chunks should be produced - assert output_mock.send.call_count == 0 - - @pytest.mark.asyncio - async def test_on_message_unicode_handling(self, mock_async_processor_init, mock_flow, mock_consumer): - flow_mock, output_mock = mock_flow - processor = RecursiveChunker(chunk_size=500, chunk_overlap=20) # Fixed overlap < chunk_size - - metadata = Metadata(id="unicode", metadata=[], user="test-user", collection="test-collection") - text = "Hello 世界! 🌍 This is a test with émojis and spëcial characters." - document = TextDocument(metadata=metadata, text=text.encode("utf-8")) - - msg = Mock() - msg.value.return_value = document - - await processor.on_message(msg, mock_consumer, flow_mock) - - # Verify unicode is preserved correctly - all_chunks = [] - for call in output_mock.send.call_args_list: - chunk_text = call[0][0].chunk.decode("utf-8") - all_chunks.append(chunk_text) - - # Reconstruct text (approximately, due to overlap) - reconstructed = "".join(all_chunks) - assert "世界" in reconstructed - assert "🌍" in reconstructed - assert "émojis" in reconstructed - - @pytest.mark.asyncio - async def test_metrics_recorded(self, mock_async_processor_init, mock_flow, mock_consumer, sample_document): - flow_mock, output_mock = mock_flow - processor = RecursiveChunker(chunk_size=100) - - msg = Mock() - msg.value.return_value = sample_document - - # Mock the metric - with patch.object(RecursiveChunker.chunk_metric, 'labels') as mock_labels: - mock_observe = Mock() - mock_labels.return_value.observe = mock_observe - - await processor.on_message(msg, mock_consumer, flow_mock) - - # Verify metrics were recorded - mock_labels.assert_called_with(id="test-consumer", flow="test-flow") - assert mock_observe.call_count > 0 - - # Verify chunk sizes were observed - for call in mock_observe.call_args_list: - chunk_size = call[0][0] - assert chunk_size > 0 - - def test_add_args(self): - parser = Mock() - RecursiveChunker.add_args(parser) - - # Verify arguments were added - calls = parser.add_argument.call_args_list - arg_names = [call[0][0] for call in calls] - - assert '-z' in arg_names or '--chunk-size' in arg_names - assert '-v' in arg_names or '--chunk-overlap' in arg_names \ No newline at end of file +if __name__ == '__main__': + pytest.main([__file__]) \ No newline at end of file diff --git a/tests/unit/test_chunking/test_token_chunker.py b/tests/unit/test_chunking/test_token_chunker.py index 31dcc0c3..600df930 100644 --- a/tests/unit/test_chunking/test_token_chunker.py +++ b/tests/unit/test_chunking/test_token_chunker.py @@ -1,275 +1,256 @@ +""" +Unit tests for trustgraph.chunking.token +Testing parameter override functionality for chunk-size and chunk-overlap +""" + import pytest -import asyncio -from unittest.mock import AsyncMock, Mock, patch +from unittest.mock import AsyncMock, MagicMock, patch +from unittest import IsolatedAsyncioTestCase + +# Import the service under test +from trustgraph.chunking.token.chunker import Processor from trustgraph.schema import TextDocument, Chunk, Metadata -from trustgraph.chunking.token.chunker import Processor as TokenChunker -@pytest.fixture -def mock_flow(): - output_mock = AsyncMock() - flow_mock = Mock(return_value=output_mock) - return flow_mock, output_mock +class MockAsyncProcessor: + def __init__(self, **params): + self.config_handlers = [] + self.id = params.get('id', 'test-service') + self.specifications = [] -@pytest.fixture -def mock_consumer(): - consumer = Mock() - consumer.id = "test-consumer" - consumer.flow = "test-flow" - return consumer +class TestTokenChunkerSimple(IsolatedAsyncioTestCase): + """Test Token chunker functionality""" + + @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) + def test_processor_initialization_basic(self): + """Test basic processor initialization""" + # Arrange + config = { + 'id': 'test-chunker', + 'chunk_size': 300, + 'chunk_overlap': 20, + 'concurrency': 1, + 'taskgroup': AsyncMock() + } + + # Act + processor = Processor(**config) + + # Assert + assert processor.default_chunk_size == 300 + assert processor.default_chunk_overlap == 20 + assert hasattr(processor, 'text_splitter') + + # Verify parameter specs are registered + param_specs = [spec for spec in processor.specifications + if hasattr(spec, 'name') and spec.name in ['chunk-size', 'chunk-overlap']] + assert len(param_specs) == 2 + + @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) + async def test_chunk_document_with_chunk_size_override(self): + """Test chunk_document with chunk-size parameter override""" + # Arrange + config = { + 'id': 'test-chunker', + 'chunk_size': 250, # Default chunk size + 'chunk_overlap': 15, + 'concurrency': 1, + 'taskgroup': AsyncMock() + } + + processor = Processor(**config) + + # Mock message and flow + mock_message = MagicMock() + mock_consumer = MagicMock() + mock_flow = MagicMock() + mock_flow.side_effect = lambda param: { + "chunk-size": 400, # Override chunk size + "chunk-overlap": None # Use default chunk overlap + }.get(param) + + # Act + chunk_size, chunk_overlap = await processor.chunk_document( + mock_message, mock_consumer, mock_flow, 250, 15 + ) + + # Assert + assert chunk_size == 400 # Should use overridden value + assert chunk_overlap == 15 # Should use default value + + @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) + async def test_chunk_document_with_chunk_overlap_override(self): + """Test chunk_document with chunk-overlap parameter override""" + # Arrange + config = { + 'id': 'test-chunker', + 'chunk_size': 250, + 'chunk_overlap': 15, # Default chunk overlap + 'concurrency': 1, + 'taskgroup': AsyncMock() + } + + processor = Processor(**config) + + # Mock message and flow + mock_message = MagicMock() + mock_consumer = MagicMock() + mock_flow = MagicMock() + mock_flow.side_effect = lambda param: { + "chunk-size": None, # Use default chunk size + "chunk-overlap": 25 # Override chunk overlap + }.get(param) + + # Act + chunk_size, chunk_overlap = await processor.chunk_document( + mock_message, mock_consumer, mock_flow, 250, 15 + ) + + # Assert + assert chunk_size == 250 # Should use default value + assert chunk_overlap == 25 # Should use overridden value + + @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) + async def test_chunk_document_with_both_parameters_override(self): + """Test chunk_document with both chunk-size and chunk-overlap overrides""" + # Arrange + config = { + 'id': 'test-chunker', + 'chunk_size': 250, + 'chunk_overlap': 15, + 'concurrency': 1, + 'taskgroup': AsyncMock() + } + + processor = Processor(**config) + + # Mock message and flow + mock_message = MagicMock() + mock_consumer = MagicMock() + mock_flow = MagicMock() + mock_flow.side_effect = lambda param: { + "chunk-size": 350, # Override chunk size + "chunk-overlap": 30 # Override chunk overlap + }.get(param) + + # Act + chunk_size, chunk_overlap = await processor.chunk_document( + mock_message, mock_consumer, mock_flow, 250, 15 + ) + + # Assert + assert chunk_size == 350 # Should use overridden value + assert chunk_overlap == 30 # Should use overridden value + + @patch('trustgraph.chunking.token.chunker.TokenTextSplitter') + @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) + async def test_on_message_uses_flow_parameters(self, mock_splitter_class): + """Test that on_message method uses parameters from flow""" + # Arrange + mock_splitter = MagicMock() + mock_document = MagicMock() + mock_document.page_content = "Test token chunk content" + mock_splitter.create_documents.return_value = [mock_document] + mock_splitter_class.return_value = mock_splitter + + config = { + 'id': 'test-chunker', + 'chunk_size': 250, + 'chunk_overlap': 15, + 'concurrency': 1, + 'taskgroup': AsyncMock() + } + + processor = Processor(**config) + + # Mock message with TextDocument + mock_message = MagicMock() + mock_text_doc = MagicMock() + mock_text_doc.metadata = Metadata( + id="test-doc-456", + metadata=[], + user="test-user", + collection="test-collection" + ) + mock_text_doc.text = b"This is test document content for token chunking" + mock_message.value.return_value = mock_text_doc + + # Mock consumer and flow with parameter overrides + mock_consumer = MagicMock() + mock_producer = AsyncMock() + mock_flow = MagicMock() + mock_flow.side_effect = lambda param: { + "chunk-size": 400, + "chunk-overlap": 40, + "output": mock_producer + }.get(param) + + # Act + await processor.on_message(mock_message, mock_consumer, mock_flow) + + # Assert + # Verify TokenTextSplitter was called with overridden parameters (last call) + expected_call = [ + ('encoding_name', 'cl100k_base'), + ('chunk_size', 400), + ('chunk_overlap', 40) + ] + actual_last_call = mock_splitter_class.call_args_list[-1] + assert actual_last_call.kwargs['encoding_name'] == "cl100k_base" + assert actual_last_call.kwargs['chunk_size'] == 400 + assert actual_last_call.kwargs['chunk_overlap'] == 40 + + # Verify chunk was sent to output + mock_producer.send.assert_called_once() + sent_chunk = mock_producer.send.call_args[0][0] + assert isinstance(sent_chunk, Chunk) + + @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) + async def test_chunk_document_with_no_overrides(self): + """Test chunk_document when no parameters are overridden (flow returns None)""" + # Arrange + config = { + 'id': 'test-chunker', + 'chunk_size': 250, + 'chunk_overlap': 15, + 'concurrency': 1, + 'taskgroup': AsyncMock() + } + + processor = Processor(**config) + + # Mock message and flow that returns None for all parameters + mock_message = MagicMock() + mock_consumer = MagicMock() + mock_flow = MagicMock() + mock_flow.return_value = None # No overrides + + # Act + chunk_size, chunk_overlap = await processor.chunk_document( + mock_message, mock_consumer, mock_flow, 250, 15 + ) + + # Assert + assert chunk_size == 250 # Should use default value + assert chunk_overlap == 15 # Should use default value + + @patch('trustgraph.base.async_processor.AsyncProcessor', MockAsyncProcessor) + def test_token_chunker_uses_different_defaults(self): + """Test that token chunker has different defaults than recursive chunker""" + # Arrange & Act + config = { + 'id': 'test-chunker', + 'concurrency': 1, + 'taskgroup': AsyncMock() + } + + processor = Processor(**config) + + # Assert - Token chunker should have different defaults + assert processor.default_chunk_size == 250 # Token chunker default + assert processor.default_chunk_overlap == 15 # Token chunker default -@pytest.fixture -def sample_document(): - metadata = Metadata( - id="test-doc-1", - metadata=[], - user="test-user", - collection="test-collection" - ) - # Create text that will result in multiple token chunks - text = "The quick brown fox jumps over the lazy dog. " * 50 - return TextDocument( - metadata=metadata, - text=text.encode("utf-8") - ) - - -@pytest.fixture -def short_document(): - metadata = Metadata( - id="test-doc-2", - metadata=[], - user="test-user", - collection="test-collection" - ) - text = "Short text." - return TextDocument( - metadata=metadata, - text=text.encode("utf-8") - ) - - -class TestTokenChunker: - - def test_init_default_params(self, mock_async_processor_init): - processor = TokenChunker() - assert processor.text_splitter._chunk_size == 250 - assert processor.text_splitter._chunk_overlap == 15 - # Just verify the text splitter was created (encoding verification is complex) - assert processor.text_splitter is not None - assert hasattr(processor.text_splitter, 'split_text') - - def test_init_custom_params(self, mock_async_processor_init): - processor = TokenChunker(chunk_size=100, chunk_overlap=10) - assert processor.text_splitter._chunk_size == 100 - assert processor.text_splitter._chunk_overlap == 10 - - def test_init_with_id(self, mock_async_processor_init): - processor = TokenChunker(id="custom-token-chunker") - assert processor.id == "custom-token-chunker" - - @pytest.mark.asyncio - async def test_on_message_single_chunk(self, mock_async_processor_init, mock_flow, mock_consumer, short_document): - flow_mock, output_mock = mock_flow - processor = TokenChunker(chunk_size=250, chunk_overlap=15) - - msg = Mock() - msg.value.return_value = short_document - - await processor.on_message(msg, mock_consumer, flow_mock) - - # Short text should produce exactly one chunk - assert output_mock.send.call_count == 1 - - # Verify the chunk was created correctly - chunk_call = output_mock.send.call_args[0][0] - assert isinstance(chunk_call, Chunk) - assert chunk_call.metadata == short_document.metadata - assert chunk_call.chunk.decode("utf-8") == short_document.text.decode("utf-8") - - @pytest.mark.asyncio - async def test_on_message_multiple_chunks(self, mock_async_processor_init, mock_flow, mock_consumer, sample_document): - flow_mock, output_mock = mock_flow - processor = TokenChunker(chunk_size=50, chunk_overlap=5) - - msg = Mock() - msg.value.return_value = sample_document - - await processor.on_message(msg, mock_consumer, flow_mock) - - # Should produce multiple chunks - assert output_mock.send.call_count > 1 - - # Verify all chunks have correct metadata - for call in output_mock.send.call_args_list: - chunk = call[0][0] - assert isinstance(chunk, Chunk) - assert chunk.metadata == sample_document.metadata - assert len(chunk.chunk) > 0 - - @pytest.mark.asyncio - async def test_on_message_token_overlap(self, mock_async_processor_init, mock_flow, mock_consumer): - flow_mock, output_mock = mock_flow - processor = TokenChunker(chunk_size=20, chunk_overlap=5) - - # Create a document with repeated pattern - metadata = Metadata(id="test", metadata=[], user="test-user", collection="test-collection") - text = "one two three four five six seven eight nine ten " * 5 - document = TextDocument(metadata=metadata, text=text.encode("utf-8")) - - msg = Mock() - msg.value.return_value = document - - await processor.on_message(msg, mock_consumer, flow_mock) - - # Collect all chunks - chunks = [] - for call in output_mock.send.call_args_list: - chunk_text = call[0][0].chunk.decode("utf-8") - chunks.append(chunk_text) - - # Should have multiple chunks - assert len(chunks) > 1 - - # Verify chunks are not empty - for chunk in chunks: - assert len(chunk) > 0 - - @pytest.mark.asyncio - async def test_on_message_empty_document(self, mock_async_processor_init, mock_flow, mock_consumer): - flow_mock, output_mock = mock_flow - processor = TokenChunker() - - metadata = Metadata(id="empty", metadata=[], user="test-user", collection="test-collection") - document = TextDocument(metadata=metadata, text=b"") - - msg = Mock() - msg.value.return_value = document - - await processor.on_message(msg, mock_consumer, flow_mock) - - # Empty documents typically don't produce chunks with langchain splitters - # This behavior is expected - no chunks should be produced - assert output_mock.send.call_count == 0 - - @pytest.mark.asyncio - async def test_on_message_unicode_handling(self, mock_async_processor_init, mock_flow, mock_consumer): - flow_mock, output_mock = mock_flow - processor = TokenChunker(chunk_size=50) - - metadata = Metadata(id="unicode", metadata=[], user="test-user", collection="test-collection") - # Test with various unicode characters - text = "Hello 世界! 🌍 Test émojis café naïve résumé. Greek: αβγδε Hebrew: אבגדה" - document = TextDocument(metadata=metadata, text=text.encode("utf-8")) - - msg = Mock() - msg.value.return_value = document - - await processor.on_message(msg, mock_consumer, flow_mock) - - # Verify unicode is preserved correctly - all_chunks = [] - for call in output_mock.send.call_args_list: - chunk_text = call[0][0].chunk.decode("utf-8") - all_chunks.append(chunk_text) - - # Reconstruct text - reconstructed = "".join(all_chunks) - assert "世界" in reconstructed - assert "🌍" in reconstructed - assert "émojis" in reconstructed - assert "αβγδε" in reconstructed - assert "אבגדה" in reconstructed - - @pytest.mark.asyncio - async def test_on_message_token_boundary_preservation(self, mock_async_processor_init, mock_flow, mock_consumer): - flow_mock, output_mock = mock_flow - processor = TokenChunker(chunk_size=10, chunk_overlap=2) - - metadata = Metadata(id="boundary", metadata=[], user="test-user", collection="test-collection") - # Text with clear word boundaries - text = "This is a test of token boundaries and proper splitting." - document = TextDocument(metadata=metadata, text=text.encode("utf-8")) - - msg = Mock() - msg.value.return_value = document - - await processor.on_message(msg, mock_consumer, flow_mock) - - # Collect all chunks - chunks = [] - for call in output_mock.send.call_args_list: - chunk_text = call[0][0].chunk.decode("utf-8") - chunks.append(chunk_text) - - # Token chunker should respect token boundaries - for chunk in chunks: - # Chunks should not start or end with partial words (in most cases) - assert len(chunk.strip()) > 0 - - @pytest.mark.asyncio - async def test_metrics_recorded(self, mock_async_processor_init, mock_flow, mock_consumer, sample_document): - flow_mock, output_mock = mock_flow - processor = TokenChunker(chunk_size=50) - - msg = Mock() - msg.value.return_value = sample_document - - # Mock the metric - with patch.object(TokenChunker.chunk_metric, 'labels') as mock_labels: - mock_observe = Mock() - mock_labels.return_value.observe = mock_observe - - await processor.on_message(msg, mock_consumer, flow_mock) - - # Verify metrics were recorded - mock_labels.assert_called_with(id="test-consumer", flow="test-flow") - assert mock_observe.call_count > 0 - - # Verify chunk sizes were observed - for call in mock_observe.call_args_list: - chunk_size = call[0][0] - assert chunk_size > 0 - - def test_add_args(self): - parser = Mock() - TokenChunker.add_args(parser) - - # Verify arguments were added - calls = parser.add_argument.call_args_list - arg_names = [call[0][0] for call in calls] - - assert '-z' in arg_names or '--chunk-size' in arg_names - assert '-v' in arg_names or '--chunk-overlap' in arg_names - - @pytest.mark.asyncio - async def test_encoding_specific_behavior(self, mock_async_processor_init, mock_flow, mock_consumer): - flow_mock, output_mock = mock_flow - processor = TokenChunker(chunk_size=10, chunk_overlap=0) - - metadata = Metadata(id="encoding", metadata=[], user="test-user", collection="test-collection") - # Test text that might tokenize differently with cl100k_base encoding - text = "GPT-4 is an AI model. It uses tokens." - document = TextDocument(metadata=metadata, text=text.encode("utf-8")) - - msg = Mock() - msg.value.return_value = document - - await processor.on_message(msg, mock_consumer, flow_mock) - - # Verify chunking happened - assert output_mock.send.call_count >= 1 - - # Collect all chunks - chunks = [] - for call in output_mock.send.call_args_list: - chunk_text = call[0][0].chunk.decode("utf-8") - chunks.append(chunk_text) - - # Verify all text is preserved (allowing for overlap) - all_text = " ".join(chunks) - assert "GPT-4" in all_text - assert "AI model" in all_text - assert "tokens" in all_text \ No newline at end of file +if __name__ == '__main__': + pytest.main([__file__]) \ No newline at end of file diff --git a/trustgraph-base/trustgraph/base/__init__.py b/trustgraph-base/trustgraph/base/__init__.py index 0bdf1f7a..5a97c220 100644 --- a/trustgraph-base/trustgraph/base/__init__.py +++ b/trustgraph-base/trustgraph/base/__init__.py @@ -13,6 +13,7 @@ from . producer_spec import ProducerSpec from . subscriber_spec import SubscriberSpec from . request_response_spec import RequestResponseSpec from . llm_service import LlmService, LlmResult +from . chunking_service import ChunkingService from . embeddings_service import EmbeddingsService from . embeddings_client import EmbeddingsClientSpec from . text_completion_client import TextCompletionClientSpec diff --git a/trustgraph-base/trustgraph/base/chunking_service.py b/trustgraph-base/trustgraph/base/chunking_service.py new file mode 100644 index 00000000..2e18a933 --- /dev/null +++ b/trustgraph-base/trustgraph/base/chunking_service.py @@ -0,0 +1,62 @@ +""" +Base chunking service that provides parameter specification functionality +for chunk-size and chunk-overlap parameters +""" + +import logging +from .flow_processor import FlowProcessor +from .parameter_spec import ParameterSpec + +# Module logger +logger = logging.getLogger(__name__) + +class ChunkingService(FlowProcessor): + """Base service for chunking processors with parameter specification support""" + + def __init__(self, **params): + + # Call parent constructor + super(ChunkingService, self).__init__(**params) + + # Register parameter specifications for chunk-size and chunk-overlap + self.register_specification( + ParameterSpec(name="chunk-size") + ) + + self.register_specification( + ParameterSpec(name="chunk-overlap") + ) + + logger.debug("ChunkingService initialized with parameter specifications") + + async def chunk_document(self, msg, consumer, flow, default_chunk_size, default_chunk_overlap): + """ + Extract chunk parameters from flow and return effective values + + Args: + msg: The message containing the document to chunk + consumer: The consumer spec + flow: The flow context + default_chunk_size: Default chunk size from processor config + default_chunk_overlap: Default chunk overlap from processor config + + Returns: + tuple: (chunk_size, chunk_overlap) - effective values to use + """ + # Extract parameters from flow (flow-configurable parameters) + chunk_size = flow("chunk-size") + chunk_overlap = flow("chunk-overlap") + + # Use provided values or fall back to defaults + effective_chunk_size = chunk_size if chunk_size is not None else default_chunk_size + effective_chunk_overlap = chunk_overlap if chunk_overlap is not None else default_chunk_overlap + + logger.debug(f"Using chunk-size: {effective_chunk_size}") + logger.debug(f"Using chunk-overlap: {effective_chunk_overlap}") + + return effective_chunk_size, effective_chunk_overlap + + @staticmethod + def add_args(parser): + """Add chunking service arguments to parser""" + FlowProcessor.add_args(parser) \ No newline at end of file diff --git a/trustgraph-flow/trustgraph/chunking/recursive/chunker.py b/trustgraph-flow/trustgraph/chunking/recursive/chunker.py index fe182b14..8604f4fa 100755 --- a/trustgraph-flow/trustgraph/chunking/recursive/chunker.py +++ b/trustgraph-flow/trustgraph/chunking/recursive/chunker.py @@ -9,14 +9,14 @@ from langchain_text_splitters import RecursiveCharacterTextSplitter from prometheus_client import Histogram from ... schema import TextDocument, Chunk -from ... base import FlowProcessor, ConsumerSpec, ProducerSpec +from ... base import ChunkingService, ConsumerSpec, ProducerSpec # Module logger logger = logging.getLogger(__name__) default_ident = "chunker" -class Processor(FlowProcessor): +class Processor(ChunkingService): def __init__(self, **params): @@ -28,6 +28,10 @@ class Processor(FlowProcessor): **params | { "id": id } ) + # Store default values for parameter override + self.default_chunk_size = chunk_size + self.default_chunk_overlap = chunk_overlap + if not hasattr(__class__, "chunk_metric"): __class__.chunk_metric = Histogram( 'chunk_size', 'Chunk size', @@ -65,7 +69,22 @@ class Processor(FlowProcessor): v = msg.value() logger.info(f"Chunking document {v.metadata.id}...") - texts = self.text_splitter.create_documents( + # Extract chunk parameters from flow (allows runtime override) + chunk_size, chunk_overlap = await self.chunk_document( + msg, consumer, flow, + self.default_chunk_size, + self.default_chunk_overlap + ) + + # Create text splitter with effective parameters + text_splitter = RecursiveCharacterTextSplitter( + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + length_function=len, + is_separator_regex=False, + ) + + texts = text_splitter.create_documents( [v.text.decode("utf-8")] ) @@ -89,7 +108,7 @@ class Processor(FlowProcessor): @staticmethod def add_args(parser): - FlowProcessor.add_args(parser) + ChunkingService.add_args(parser) parser.add_argument( '-z', '--chunk-size', diff --git a/trustgraph-flow/trustgraph/chunking/token/chunker.py b/trustgraph-flow/trustgraph/chunking/token/chunker.py index a1f43a35..b4e55038 100755 --- a/trustgraph-flow/trustgraph/chunking/token/chunker.py +++ b/trustgraph-flow/trustgraph/chunking/token/chunker.py @@ -9,14 +9,14 @@ from langchain_text_splitters import TokenTextSplitter from prometheus_client import Histogram from ... schema import TextDocument, Chunk -from ... base import FlowProcessor, ConsumerSpec, ProducerSpec +from ... base import ChunkingService, ConsumerSpec, ProducerSpec # Module logger logger = logging.getLogger(__name__) default_ident = "chunker" -class Processor(FlowProcessor): +class Processor(ChunkingService): def __init__(self, **params): @@ -28,6 +28,10 @@ class Processor(FlowProcessor): **params | { "id": id } ) + # Store default values for parameter override + self.default_chunk_size = chunk_size + self.default_chunk_overlap = chunk_overlap + if not hasattr(__class__, "chunk_metric"): __class__.chunk_metric = Histogram( 'chunk_size', 'Chunk size', @@ -64,7 +68,21 @@ class Processor(FlowProcessor): v = msg.value() logger.info(f"Chunking document {v.metadata.id}...") - texts = self.text_splitter.create_documents( + # Extract chunk parameters from flow (allows runtime override) + chunk_size, chunk_overlap = await self.chunk_document( + msg, consumer, flow, + self.default_chunk_size, + self.default_chunk_overlap + ) + + # Create text splitter with effective parameters + text_splitter = TokenTextSplitter( + encoding_name="cl100k_base", + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + ) + + texts = text_splitter.create_documents( [v.text.decode("utf-8")] ) @@ -88,7 +106,7 @@ class Processor(FlowProcessor): @staticmethod def add_args(parser): - FlowProcessor.add_args(parser) + ChunkingService.add_args(parser) parser.add_argument( '-z', '--chunk-size',