mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 00:16:23 +02:00
Fix ontology RAG pipeline + add query concurrency (#691)
- Fix ontology RAG pipeline: embeddings API, chunker provenance, and query concurrency
- Fix ontology embeddings to use correct response shape from embed()
API (returns list of vectors, not list of list of vectors).
- Simplify chunker URI logic to append /c{index} to parent ID
instead of parsing page/doc URI structure which was fragile.
- Add provenance tracking and librarian integration to token
chunker, matching recursive chunker capabilities.
- Add configurable concurrency (default 10) to Cassandra, Qdrant,
and embeddings query services.
This commit is contained in:
parent
312174eb88
commit
45e6ad4abc
9 changed files with 148 additions and 50 deletions
|
|
@ -176,6 +176,9 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
|
|||
|
||||
processor = Processor(**config)
|
||||
|
||||
# Mock save_child_document to avoid librarian producer interactions
|
||||
processor.save_child_document = AsyncMock(return_value="chunk-id")
|
||||
|
||||
# Mock message with TextDocument
|
||||
mock_message = MagicMock()
|
||||
mock_text_doc = MagicMock()
|
||||
|
|
@ -191,11 +194,13 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
|
|||
# Mock consumer and flow with parameter overrides
|
||||
mock_consumer = MagicMock()
|
||||
mock_producer = AsyncMock()
|
||||
mock_triples_producer = AsyncMock()
|
||||
mock_flow = MagicMock()
|
||||
mock_flow.side_effect = lambda param: {
|
||||
"chunk-size": 400,
|
||||
"chunk-overlap": 40,
|
||||
"output": mock_producer
|
||||
"output": mock_producer,
|
||||
"triples": mock_triples_producer,
|
||||
}.get(param)
|
||||
|
||||
# Act
|
||||
|
|
|
|||
|
|
@ -17,12 +17,14 @@ from . producer_spec import ProducerSpec
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
default_ident = "doc-embeddings-query"
|
||||
default_concurrency = 10
|
||||
|
||||
class DocumentEmbeddingsQueryService(FlowProcessor):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
id = params.get("id")
|
||||
concurrency = params.get("concurrency", default_concurrency)
|
||||
|
||||
super(DocumentEmbeddingsQueryService, self).__init__(
|
||||
**params | { "id": id }
|
||||
|
|
@ -32,7 +34,8 @@ class DocumentEmbeddingsQueryService(FlowProcessor):
|
|||
ConsumerSpec(
|
||||
name = "request",
|
||||
schema = DocumentEmbeddingsRequest,
|
||||
handler = self.on_message
|
||||
handler = self.on_message,
|
||||
concurrency = concurrency,
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -83,6 +86,13 @@ class DocumentEmbeddingsQueryService(FlowProcessor):
|
|||
|
||||
FlowProcessor.add_args(parser)
|
||||
|
||||
parser.add_argument(
|
||||
'-c', '--concurrency',
|
||||
type=int,
|
||||
default=default_concurrency,
|
||||
help=f'Number of concurrent requests (default: {default_concurrency})'
|
||||
)
|
||||
|
||||
def run():
|
||||
|
||||
Processor.launch(default_ident, __doc__)
|
||||
|
|
|
|||
|
|
@ -17,12 +17,14 @@ from . producer_spec import ProducerSpec
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
default_ident = "graph-embeddings-query"
|
||||
default_concurrency = 10
|
||||
|
||||
class GraphEmbeddingsQueryService(FlowProcessor):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
id = params.get("id")
|
||||
concurrency = params.get("concurrency", default_concurrency)
|
||||
|
||||
super(GraphEmbeddingsQueryService, self).__init__(
|
||||
**params | { "id": id }
|
||||
|
|
@ -32,7 +34,8 @@ class GraphEmbeddingsQueryService(FlowProcessor):
|
|||
ConsumerSpec(
|
||||
name = "request",
|
||||
schema = GraphEmbeddingsRequest,
|
||||
handler = self.on_message
|
||||
handler = self.on_message,
|
||||
concurrency = concurrency,
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -83,6 +86,13 @@ class GraphEmbeddingsQueryService(FlowProcessor):
|
|||
|
||||
FlowProcessor.add_args(parser)
|
||||
|
||||
parser.add_argument(
|
||||
'-c', '--concurrency',
|
||||
type=int,
|
||||
default=default_concurrency,
|
||||
help=f'Number of concurrent requests (default: {default_concurrency})'
|
||||
)
|
||||
|
||||
def run():
|
||||
|
||||
Processor.launch(default_ident, __doc__)
|
||||
|
|
|
|||
|
|
@ -12,8 +12,7 @@ from ... schema import TextDocument, Chunk, Metadata, Triples
|
|||
from ... base import ChunkingService, ConsumerSpec, ProducerSpec
|
||||
|
||||
from ... provenance import (
|
||||
page_uri, chunk_uri_from_page, chunk_uri_from_doc,
|
||||
derived_entity_triples, document_uri,
|
||||
derived_entity_triples,
|
||||
set_graph, GRAPH_SOURCE,
|
||||
)
|
||||
|
||||
|
|
@ -114,22 +113,9 @@ class Processor(ChunkingService):
|
|||
texts = text_splitter.create_documents([text])
|
||||
|
||||
# Get parent document ID for provenance linking
|
||||
# This could be a page URI (doc/p3) or document URI (doc) - we don't need to parse it
|
||||
parent_doc_id = v.document_id or v.metadata.id
|
||||
|
||||
# Determine if parent is a page (from PDF) or source document (text)
|
||||
# Check if parent_doc_id contains "/p" which indicates a page
|
||||
is_from_page = "/p" in parent_doc_id
|
||||
|
||||
# Extract the root document ID for chunk URI generation
|
||||
if is_from_page:
|
||||
# Parent is a page like "doc123/p3", extract page number
|
||||
parts = parent_doc_id.rsplit("/p", 1)
|
||||
root_doc_id = parts[0]
|
||||
page_num = int(parts[1]) if len(parts) > 1 else 1
|
||||
else:
|
||||
root_doc_id = parent_doc_id
|
||||
page_num = None
|
||||
|
||||
# Track character offset for provenance
|
||||
char_offset = 0
|
||||
|
||||
|
|
@ -138,15 +124,11 @@ class Processor(ChunkingService):
|
|||
|
||||
logger.debug(f"Created chunk of size {len(chunk.page_content)}")
|
||||
|
||||
# Generate chunk document ID
|
||||
if is_from_page:
|
||||
chunk_doc_id = f"{root_doc_id}/p{page_num}/c{chunk_index}"
|
||||
chunk_uri = chunk_uri_from_page(root_doc_id, page_num, chunk_index)
|
||||
parent_uri = page_uri(root_doc_id, page_num)
|
||||
else:
|
||||
chunk_doc_id = f"{root_doc_id}/c{chunk_index}"
|
||||
chunk_uri = chunk_uri_from_doc(root_doc_id, chunk_index)
|
||||
parent_uri = document_uri(root_doc_id)
|
||||
# Generate chunk document ID by appending /c{index} to parent
|
||||
# Works for both page URIs (doc/p3 -> doc/p3/c1) and doc URIs (doc -> doc/c1)
|
||||
chunk_doc_id = f"{parent_doc_id}/c{chunk_index}"
|
||||
chunk_uri = chunk_doc_id # URI is same as document ID
|
||||
parent_uri = parent_doc_id
|
||||
|
||||
chunk_content = chunk.page_content.encode("utf-8")
|
||||
chunk_length = len(chunk.page_content)
|
||||
|
|
|
|||
|
|
@ -8,9 +8,18 @@ import logging
|
|||
from langchain_text_splitters import TokenTextSplitter
|
||||
from prometheus_client import Histogram
|
||||
|
||||
from ... schema import TextDocument, Chunk
|
||||
from ... schema import TextDocument, Chunk, Metadata, Triples
|
||||
from ... base import ChunkingService, ConsumerSpec, ProducerSpec
|
||||
|
||||
from ... provenance import (
|
||||
derived_entity_triples,
|
||||
set_graph, GRAPH_SOURCE,
|
||||
)
|
||||
|
||||
# Component identification for provenance
|
||||
COMPONENT_NAME = "token-chunker"
|
||||
COMPONENT_VERSION = "1.0.0"
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -24,7 +33,7 @@ class Processor(ChunkingService):
|
|||
id = params.get("id", default_ident)
|
||||
chunk_size = params.get("chunk_size", 250)
|
||||
chunk_overlap = params.get("chunk_overlap", 15)
|
||||
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | { "id": id }
|
||||
)
|
||||
|
|
@ -62,6 +71,13 @@ class Processor(ChunkingService):
|
|||
)
|
||||
)
|
||||
|
||||
self.register_specification(
|
||||
ProducerSpec(
|
||||
name = "triples",
|
||||
schema = Triples,
|
||||
)
|
||||
)
|
||||
|
||||
logger.info("Token chunker initialized")
|
||||
|
||||
async def on_message(self, msg, consumer, flow):
|
||||
|
|
@ -94,21 +110,82 @@ class Processor(ChunkingService):
|
|||
|
||||
texts = text_splitter.create_documents([text])
|
||||
|
||||
# Get parent document ID for provenance linking
|
||||
# This could be a page URI (doc/p3) or document URI (doc) - we don't need to parse it
|
||||
parent_doc_id = v.document_id or v.metadata.id
|
||||
|
||||
# Track token offset for provenance (approximate)
|
||||
token_offset = 0
|
||||
|
||||
for ix, chunk in enumerate(texts):
|
||||
chunk_index = ix + 1 # 1-indexed
|
||||
|
||||
logger.debug(f"Created chunk of size {len(chunk.page_content)}")
|
||||
|
||||
# Generate chunk document ID by appending /c{index} to parent
|
||||
# Works for both page URIs (doc/p3 -> doc/p3/c1) and doc URIs (doc -> doc/c1)
|
||||
chunk_doc_id = f"{parent_doc_id}/c{chunk_index}"
|
||||
chunk_uri = chunk_doc_id # URI is same as document ID
|
||||
parent_uri = parent_doc_id
|
||||
|
||||
chunk_content = chunk.page_content.encode("utf-8")
|
||||
chunk_length = len(chunk.page_content)
|
||||
|
||||
# Save chunk to librarian as child document
|
||||
await self.save_child_document(
|
||||
doc_id=chunk_doc_id,
|
||||
parent_id=parent_doc_id,
|
||||
user=v.metadata.user,
|
||||
content=chunk_content,
|
||||
document_type="chunk",
|
||||
title=f"Chunk {chunk_index}",
|
||||
)
|
||||
|
||||
# Emit provenance triples (stored in source graph for separation from core knowledge)
|
||||
prov_triples = derived_entity_triples(
|
||||
entity_uri=chunk_uri,
|
||||
parent_uri=parent_uri,
|
||||
component_name=COMPONENT_NAME,
|
||||
component_version=COMPONENT_VERSION,
|
||||
label=f"Chunk {chunk_index}",
|
||||
chunk_index=chunk_index,
|
||||
char_offset=token_offset, # Note: this is token offset, not char offset
|
||||
char_length=chunk_length,
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
)
|
||||
|
||||
await flow("triples").send(Triples(
|
||||
metadata=Metadata(
|
||||
id=chunk_uri,
|
||||
root=v.metadata.root,
|
||||
user=v.metadata.user,
|
||||
collection=v.metadata.collection,
|
||||
),
|
||||
triples=set_graph(prov_triples, GRAPH_SOURCE),
|
||||
))
|
||||
|
||||
# Forward chunk ID + content (post-chunker optimization)
|
||||
r = Chunk(
|
||||
metadata=v.metadata,
|
||||
chunk=chunk.page_content.encode("utf-8"),
|
||||
metadata=Metadata(
|
||||
id=chunk_uri,
|
||||
root=v.metadata.root,
|
||||
user=v.metadata.user,
|
||||
collection=v.metadata.collection,
|
||||
),
|
||||
chunk=chunk_content,
|
||||
document_id=chunk_doc_id,
|
||||
)
|
||||
|
||||
__class__.chunk_metric.labels(
|
||||
id=consumer.id, flow=consumer.flow
|
||||
).observe(len(chunk.page_content))
|
||||
).observe(chunk_length)
|
||||
|
||||
await flow("output").send(r)
|
||||
|
||||
# Update token offset (approximate, doesn't account for overlap)
|
||||
token_offset += chunk_size - chunk_overlap
|
||||
|
||||
logger.debug("Document chunking complete")
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -120,17 +197,16 @@ class Processor(ChunkingService):
|
|||
'-z', '--chunk-size',
|
||||
type=int,
|
||||
default=250,
|
||||
help=f'Chunk size (default: 250)'
|
||||
help=f'Chunk size in tokens (default: 250)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-v', '--chunk-overlap',
|
||||
type=int,
|
||||
default=15,
|
||||
help=f'Chunk overlap (default: 15)'
|
||||
help=f'Chunk overlap in tokens (default: 15)'
|
||||
)
|
||||
|
||||
def run():
|
||||
|
||||
Processor.launch(default_ident, __doc__)
|
||||
|
||||
|
|
|
|||
|
|
@ -149,7 +149,7 @@ class Processor(FlowProcessor):
|
|||
# Detect embedding dimension by embedding a test string
|
||||
logger.info("Detecting embedding dimension from embeddings service...")
|
||||
test_embedding_response = await embeddings_client.embed(["test"])
|
||||
test_embedding = test_embedding_response[0][0] # Extract first vector from first text
|
||||
test_embedding = test_embedding_response[0] # Extract first vector
|
||||
dimension = len(test_embedding)
|
||||
logger.info(f"Detected embedding dimension: {dimension}")
|
||||
|
||||
|
|
|
|||
|
|
@ -153,14 +153,11 @@ class OntologyEmbedder:
|
|||
# Get embeddings for batch
|
||||
texts = [elem['text'] for elem in batch]
|
||||
try:
|
||||
# Single batch embedding call
|
||||
# Single batch embedding call - returns list of vectors
|
||||
embeddings_response = await self.embedding_service.embed(texts)
|
||||
|
||||
# Extract first vector from each text's vector set
|
||||
embeddings_list = [resp[0] for resp in embeddings_response]
|
||||
|
||||
# Convert to numpy array
|
||||
embeddings = np.array(embeddings_list)
|
||||
embeddings = np.array(embeddings_response)
|
||||
|
||||
# Log embedding shape for debugging
|
||||
logger.debug(f"Embeddings shape: {embeddings.shape}, expected: ({len(batch)}, {self.vector_store.dimension})")
|
||||
|
|
@ -216,9 +213,9 @@ class OntologyEmbedder:
|
|||
return None
|
||||
|
||||
try:
|
||||
# embed() with single text, extract first vector from first text
|
||||
# embed() with single text, extract first vector
|
||||
embedding_response = await self.embedding_service.embed([text])
|
||||
return np.array(embedding_response[0][0])
|
||||
return np.array(embedding_response[0])
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to embed text: {e}")
|
||||
return None
|
||||
|
|
@ -237,11 +234,9 @@ class OntologyEmbedder:
|
|||
return None
|
||||
|
||||
try:
|
||||
# Single batch embedding call
|
||||
# Single batch embedding call - returns list of vectors
|
||||
embeddings_response = await self.embedding_service.embed(texts)
|
||||
# Extract first vector from each text's vector set
|
||||
embeddings_list = [resp[0] for resp in embeddings_response]
|
||||
return np.array(embeddings_list)
|
||||
return np.array(embeddings_response)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to embed texts: {e}")
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
default_ident = "row-embeddings-query"
|
||||
default_store_uri = 'http://localhost:6333'
|
||||
default_concurrency = 10
|
||||
|
||||
|
||||
class Processor(FlowProcessor):
|
||||
|
|
@ -31,6 +32,7 @@ class Processor(FlowProcessor):
|
|||
def __init__(self, **params):
|
||||
|
||||
id = params.get("id", default_ident)
|
||||
concurrency = params.get("concurrency", default_concurrency)
|
||||
|
||||
store_uri = params.get("store_uri", default_store_uri)
|
||||
api_key = params.get("api_key", None)
|
||||
|
|
@ -47,7 +49,8 @@ class Processor(FlowProcessor):
|
|||
ConsumerSpec(
|
||||
name="request",
|
||||
schema=RowEmbeddingsRequest,
|
||||
handler=self.on_message
|
||||
handler=self.on_message,
|
||||
concurrency=concurrency,
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -205,6 +208,13 @@ class Processor(FlowProcessor):
|
|||
help='API key for Qdrant (default: None)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-c', '--concurrency',
|
||||
type=int,
|
||||
default=default_concurrency,
|
||||
help=f'Number of concurrent requests (default: {default_concurrency})'
|
||||
)
|
||||
|
||||
|
||||
def run():
|
||||
"""Entry point for row-embeddings-query-qdrant command"""
|
||||
|
|
|
|||
|
|
@ -30,6 +30,7 @@ from ... graphql import GraphQLSchemaBuilder, SortDirection
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
default_ident = "rows-query"
|
||||
default_concurrency = 10
|
||||
|
||||
|
||||
class Processor(FlowProcessor):
|
||||
|
|
@ -37,6 +38,7 @@ class Processor(FlowProcessor):
|
|||
def __init__(self, **params):
|
||||
|
||||
id = params.get("id", default_ident)
|
||||
concurrency = params.get("concurrency", default_concurrency)
|
||||
|
||||
# Get Cassandra parameters
|
||||
cassandra_host = params.get("cassandra_host")
|
||||
|
|
@ -69,7 +71,8 @@ class Processor(FlowProcessor):
|
|||
ConsumerSpec(
|
||||
name="request",
|
||||
schema=RowsQueryRequest,
|
||||
handler=self.on_message
|
||||
handler=self.on_message,
|
||||
concurrency=concurrency,
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -517,6 +520,13 @@ class Processor(FlowProcessor):
|
|||
help='Configuration type prefix for schemas (default: schema)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-c', '--concurrency',
|
||||
type=int,
|
||||
default=default_concurrency,
|
||||
help=f'Number of concurrent requests (default: {default_concurrency})'
|
||||
)
|
||||
|
||||
|
||||
def run():
|
||||
"""Entry point for rows-query-cassandra command"""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue