diff --git a/tests/unit/test_chunking/test_token_chunker.py b/tests/unit/test_chunking/test_token_chunker.py index 45fab919..2ed37391 100644 --- a/tests/unit/test_chunking/test_token_chunker.py +++ b/tests/unit/test_chunking/test_token_chunker.py @@ -176,6 +176,9 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase): processor = Processor(**config) + # Mock save_child_document to avoid librarian producer interactions + processor.save_child_document = AsyncMock(return_value="chunk-id") + # Mock message with TextDocument mock_message = MagicMock() mock_text_doc = MagicMock() @@ -191,11 +194,13 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase): # Mock consumer and flow with parameter overrides mock_consumer = MagicMock() mock_producer = AsyncMock() + mock_triples_producer = AsyncMock() mock_flow = MagicMock() mock_flow.side_effect = lambda param: { "chunk-size": 400, "chunk-overlap": 40, - "output": mock_producer + "output": mock_producer, + "triples": mock_triples_producer, }.get(param) # Act diff --git a/trustgraph-base/trustgraph/base/document_embeddings_query_service.py b/trustgraph-base/trustgraph/base/document_embeddings_query_service.py index b8979776..c7aef104 100644 --- a/trustgraph-base/trustgraph/base/document_embeddings_query_service.py +++ b/trustgraph-base/trustgraph/base/document_embeddings_query_service.py @@ -17,12 +17,14 @@ from . producer_spec import ProducerSpec logger = logging.getLogger(__name__) default_ident = "doc-embeddings-query" +default_concurrency = 10 class DocumentEmbeddingsQueryService(FlowProcessor): def __init__(self, **params): id = params.get("id") + concurrency = params.get("concurrency", default_concurrency) super(DocumentEmbeddingsQueryService, self).__init__( **params | { "id": id } @@ -32,7 +34,8 @@ class DocumentEmbeddingsQueryService(FlowProcessor): ConsumerSpec( name = "request", schema = DocumentEmbeddingsRequest, - handler = self.on_message + handler = self.on_message, + concurrency = concurrency, ) ) @@ -83,6 +86,13 @@ class DocumentEmbeddingsQueryService(FlowProcessor): FlowProcessor.add_args(parser) + parser.add_argument( + '-c', '--concurrency', + type=int, + default=default_concurrency, + help=f'Number of concurrent requests (default: {default_concurrency})' + ) + def run(): Processor.launch(default_ident, __doc__) diff --git a/trustgraph-base/trustgraph/base/graph_embeddings_query_service.py b/trustgraph-base/trustgraph/base/graph_embeddings_query_service.py index d429b3a5..cbbef4f2 100644 --- a/trustgraph-base/trustgraph/base/graph_embeddings_query_service.py +++ b/trustgraph-base/trustgraph/base/graph_embeddings_query_service.py @@ -17,12 +17,14 @@ from . producer_spec import ProducerSpec logger = logging.getLogger(__name__) default_ident = "graph-embeddings-query" +default_concurrency = 10 class GraphEmbeddingsQueryService(FlowProcessor): def __init__(self, **params): id = params.get("id") + concurrency = params.get("concurrency", default_concurrency) super(GraphEmbeddingsQueryService, self).__init__( **params | { "id": id } @@ -32,7 +34,8 @@ class GraphEmbeddingsQueryService(FlowProcessor): ConsumerSpec( name = "request", schema = GraphEmbeddingsRequest, - handler = self.on_message + handler = self.on_message, + concurrency = concurrency, ) ) @@ -83,6 +86,13 @@ class GraphEmbeddingsQueryService(FlowProcessor): FlowProcessor.add_args(parser) + parser.add_argument( + '-c', '--concurrency', + type=int, + default=default_concurrency, + help=f'Number of concurrent requests (default: {default_concurrency})' + ) + def run(): Processor.launch(default_ident, __doc__) diff --git a/trustgraph-flow/trustgraph/chunking/recursive/chunker.py b/trustgraph-flow/trustgraph/chunking/recursive/chunker.py index 3438093c..fb84c356 100755 --- a/trustgraph-flow/trustgraph/chunking/recursive/chunker.py +++ b/trustgraph-flow/trustgraph/chunking/recursive/chunker.py @@ -12,8 +12,7 @@ from ... schema import TextDocument, Chunk, Metadata, Triples from ... base import ChunkingService, ConsumerSpec, ProducerSpec from ... provenance import ( - page_uri, chunk_uri_from_page, chunk_uri_from_doc, - derived_entity_triples, document_uri, + derived_entity_triples, set_graph, GRAPH_SOURCE, ) @@ -114,22 +113,9 @@ class Processor(ChunkingService): texts = text_splitter.create_documents([text]) # Get parent document ID for provenance linking + # This could be a page URI (doc/p3) or document URI (doc) - we don't need to parse it parent_doc_id = v.document_id or v.metadata.id - # Determine if parent is a page (from PDF) or source document (text) - # Check if parent_doc_id contains "/p" which indicates a page - is_from_page = "/p" in parent_doc_id - - # Extract the root document ID for chunk URI generation - if is_from_page: - # Parent is a page like "doc123/p3", extract page number - parts = parent_doc_id.rsplit("/p", 1) - root_doc_id = parts[0] - page_num = int(parts[1]) if len(parts) > 1 else 1 - else: - root_doc_id = parent_doc_id - page_num = None - # Track character offset for provenance char_offset = 0 @@ -138,15 +124,11 @@ class Processor(ChunkingService): logger.debug(f"Created chunk of size {len(chunk.page_content)}") - # Generate chunk document ID - if is_from_page: - chunk_doc_id = f"{root_doc_id}/p{page_num}/c{chunk_index}" - chunk_uri = chunk_uri_from_page(root_doc_id, page_num, chunk_index) - parent_uri = page_uri(root_doc_id, page_num) - else: - chunk_doc_id = f"{root_doc_id}/c{chunk_index}" - chunk_uri = chunk_uri_from_doc(root_doc_id, chunk_index) - parent_uri = document_uri(root_doc_id) + # Generate chunk document ID by appending /c{index} to parent + # Works for both page URIs (doc/p3 -> doc/p3/c1) and doc URIs (doc -> doc/c1) + chunk_doc_id = f"{parent_doc_id}/c{chunk_index}" + chunk_uri = chunk_doc_id # URI is same as document ID + parent_uri = parent_doc_id chunk_content = chunk.page_content.encode("utf-8") chunk_length = len(chunk.page_content) diff --git a/trustgraph-flow/trustgraph/chunking/token/chunker.py b/trustgraph-flow/trustgraph/chunking/token/chunker.py index cfb068f0..909396c6 100755 --- a/trustgraph-flow/trustgraph/chunking/token/chunker.py +++ b/trustgraph-flow/trustgraph/chunking/token/chunker.py @@ -8,9 +8,18 @@ import logging from langchain_text_splitters import TokenTextSplitter from prometheus_client import Histogram -from ... schema import TextDocument, Chunk +from ... schema import TextDocument, Chunk, Metadata, Triples from ... base import ChunkingService, ConsumerSpec, ProducerSpec +from ... provenance import ( + derived_entity_triples, + set_graph, GRAPH_SOURCE, +) + +# Component identification for provenance +COMPONENT_NAME = "token-chunker" +COMPONENT_VERSION = "1.0.0" + # Module logger logger = logging.getLogger(__name__) @@ -24,7 +33,7 @@ class Processor(ChunkingService): id = params.get("id", default_ident) chunk_size = params.get("chunk_size", 250) chunk_overlap = params.get("chunk_overlap", 15) - + super(Processor, self).__init__( **params | { "id": id } ) @@ -62,6 +71,13 @@ class Processor(ChunkingService): ) ) + self.register_specification( + ProducerSpec( + name = "triples", + schema = Triples, + ) + ) + logger.info("Token chunker initialized") async def on_message(self, msg, consumer, flow): @@ -94,21 +110,82 @@ class Processor(ChunkingService): texts = text_splitter.create_documents([text]) + # Get parent document ID for provenance linking + # This could be a page URI (doc/p3) or document URI (doc) - we don't need to parse it + parent_doc_id = v.document_id or v.metadata.id + + # Track token offset for provenance (approximate) + token_offset = 0 + for ix, chunk in enumerate(texts): + chunk_index = ix + 1 # 1-indexed logger.debug(f"Created chunk of size {len(chunk.page_content)}") + # Generate chunk document ID by appending /c{index} to parent + # Works for both page URIs (doc/p3 -> doc/p3/c1) and doc URIs (doc -> doc/c1) + chunk_doc_id = f"{parent_doc_id}/c{chunk_index}" + chunk_uri = chunk_doc_id # URI is same as document ID + parent_uri = parent_doc_id + + chunk_content = chunk.page_content.encode("utf-8") + chunk_length = len(chunk.page_content) + + # Save chunk to librarian as child document + await self.save_child_document( + doc_id=chunk_doc_id, + parent_id=parent_doc_id, + user=v.metadata.user, + content=chunk_content, + document_type="chunk", + title=f"Chunk {chunk_index}", + ) + + # Emit provenance triples (stored in source graph for separation from core knowledge) + prov_triples = derived_entity_triples( + entity_uri=chunk_uri, + parent_uri=parent_uri, + component_name=COMPONENT_NAME, + component_version=COMPONENT_VERSION, + label=f"Chunk {chunk_index}", + chunk_index=chunk_index, + char_offset=token_offset, # Note: this is token offset, not char offset + char_length=chunk_length, + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + ) + + await flow("triples").send(Triples( + metadata=Metadata( + id=chunk_uri, + root=v.metadata.root, + user=v.metadata.user, + collection=v.metadata.collection, + ), + triples=set_graph(prov_triples, GRAPH_SOURCE), + )) + + # Forward chunk ID + content (post-chunker optimization) r = Chunk( - metadata=v.metadata, - chunk=chunk.page_content.encode("utf-8"), + metadata=Metadata( + id=chunk_uri, + root=v.metadata.root, + user=v.metadata.user, + collection=v.metadata.collection, + ), + chunk=chunk_content, + document_id=chunk_doc_id, ) __class__.chunk_metric.labels( id=consumer.id, flow=consumer.flow - ).observe(len(chunk.page_content)) + ).observe(chunk_length) await flow("output").send(r) + # Update token offset (approximate, doesn't account for overlap) + token_offset += chunk_size - chunk_overlap + logger.debug("Document chunking complete") @staticmethod @@ -120,17 +197,16 @@ class Processor(ChunkingService): '-z', '--chunk-size', type=int, default=250, - help=f'Chunk size (default: 250)' + help=f'Chunk size in tokens (default: 250)' ) parser.add_argument( '-v', '--chunk-overlap', type=int, default=15, - help=f'Chunk overlap (default: 15)' + help=f'Chunk overlap in tokens (default: 15)' ) def run(): Processor.launch(default_ident, __doc__) - diff --git a/trustgraph-flow/trustgraph/extract/kg/ontology/extract.py b/trustgraph-flow/trustgraph/extract/kg/ontology/extract.py index eec6face..16383752 100644 --- a/trustgraph-flow/trustgraph/extract/kg/ontology/extract.py +++ b/trustgraph-flow/trustgraph/extract/kg/ontology/extract.py @@ -149,7 +149,7 @@ class Processor(FlowProcessor): # Detect embedding dimension by embedding a test string logger.info("Detecting embedding dimension from embeddings service...") test_embedding_response = await embeddings_client.embed(["test"]) - test_embedding = test_embedding_response[0][0] # Extract first vector from first text + test_embedding = test_embedding_response[0] # Extract first vector dimension = len(test_embedding) logger.info(f"Detected embedding dimension: {dimension}") diff --git a/trustgraph-flow/trustgraph/extract/kg/ontology/ontology_embedder.py b/trustgraph-flow/trustgraph/extract/kg/ontology/ontology_embedder.py index 4bff6551..64127487 100644 --- a/trustgraph-flow/trustgraph/extract/kg/ontology/ontology_embedder.py +++ b/trustgraph-flow/trustgraph/extract/kg/ontology/ontology_embedder.py @@ -153,14 +153,11 @@ class OntologyEmbedder: # Get embeddings for batch texts = [elem['text'] for elem in batch] try: - # Single batch embedding call + # Single batch embedding call - returns list of vectors embeddings_response = await self.embedding_service.embed(texts) - # Extract first vector from each text's vector set - embeddings_list = [resp[0] for resp in embeddings_response] - # Convert to numpy array - embeddings = np.array(embeddings_list) + embeddings = np.array(embeddings_response) # Log embedding shape for debugging logger.debug(f"Embeddings shape: {embeddings.shape}, expected: ({len(batch)}, {self.vector_store.dimension})") @@ -216,9 +213,9 @@ class OntologyEmbedder: return None try: - # embed() with single text, extract first vector from first text + # embed() with single text, extract first vector embedding_response = await self.embedding_service.embed([text]) - return np.array(embedding_response[0][0]) + return np.array(embedding_response[0]) except Exception as e: logger.error(f"Failed to embed text: {e}") return None @@ -237,11 +234,9 @@ class OntologyEmbedder: return None try: - # Single batch embedding call + # Single batch embedding call - returns list of vectors embeddings_response = await self.embedding_service.embed(texts) - # Extract first vector from each text's vector set - embeddings_list = [resp[0] for resp in embeddings_response] - return np.array(embeddings_list) + return np.array(embeddings_response) except Exception as e: logger.error(f"Failed to embed texts: {e}") return None diff --git a/trustgraph-flow/trustgraph/query/row_embeddings/qdrant/service.py b/trustgraph-flow/trustgraph/query/row_embeddings/qdrant/service.py index 307899d6..7fc20303 100644 --- a/trustgraph-flow/trustgraph/query/row_embeddings/qdrant/service.py +++ b/trustgraph-flow/trustgraph/query/row_embeddings/qdrant/service.py @@ -24,6 +24,7 @@ logger = logging.getLogger(__name__) default_ident = "row-embeddings-query" default_store_uri = 'http://localhost:6333' +default_concurrency = 10 class Processor(FlowProcessor): @@ -31,6 +32,7 @@ class Processor(FlowProcessor): def __init__(self, **params): id = params.get("id", default_ident) + concurrency = params.get("concurrency", default_concurrency) store_uri = params.get("store_uri", default_store_uri) api_key = params.get("api_key", None) @@ -47,7 +49,8 @@ class Processor(FlowProcessor): ConsumerSpec( name="request", schema=RowEmbeddingsRequest, - handler=self.on_message + handler=self.on_message, + concurrency=concurrency, ) ) @@ -205,6 +208,13 @@ class Processor(FlowProcessor): help='API key for Qdrant (default: None)' ) + parser.add_argument( + '-c', '--concurrency', + type=int, + default=default_concurrency, + help=f'Number of concurrent requests (default: {default_concurrency})' + ) + def run(): """Entry point for row-embeddings-query-qdrant command""" diff --git a/trustgraph-flow/trustgraph/query/rows/cassandra/service.py b/trustgraph-flow/trustgraph/query/rows/cassandra/service.py index 3808cdb0..2337642f 100644 --- a/trustgraph-flow/trustgraph/query/rows/cassandra/service.py +++ b/trustgraph-flow/trustgraph/query/rows/cassandra/service.py @@ -30,6 +30,7 @@ from ... graphql import GraphQLSchemaBuilder, SortDirection logger = logging.getLogger(__name__) default_ident = "rows-query" +default_concurrency = 10 class Processor(FlowProcessor): @@ -37,6 +38,7 @@ class Processor(FlowProcessor): def __init__(self, **params): id = params.get("id", default_ident) + concurrency = params.get("concurrency", default_concurrency) # Get Cassandra parameters cassandra_host = params.get("cassandra_host") @@ -69,7 +71,8 @@ class Processor(FlowProcessor): ConsumerSpec( name="request", schema=RowsQueryRequest, - handler=self.on_message + handler=self.on_message, + concurrency=concurrency, ) ) @@ -517,6 +520,13 @@ class Processor(FlowProcessor): help='Configuration type prefix for schemas (default: schema)' ) + parser.add_argument( + '-c', '--concurrency', + type=int, + default=default_concurrency, + help=f'Number of concurrent requests (default: {default_concurrency})' + ) + def run(): """Entry point for rows-query-cassandra command"""