Fix ontology RAG pipeline + add query concurrency (#691)

- Fix ontology RAG pipeline: embeddings API, chunker provenance, and query concurrency

- Fix ontology embeddings to use the correct response shape from the embed()
  API (it returns a list of vectors, not a list of lists of vectors);
  sketched below.
- Simplify chunker URI logic to append /c{index} to the parent ID
  instead of parsing the page/doc URI structure, which was fragile.

- Add provenance tracking and librarian integration to the token
  chunker, matching the recursive chunker's capabilities.

- Add configurable concurrency (default 10) to Cassandra, Qdrant,
  and embeddings query services.
cybermaggedon authored 2026-03-12 11:34:42 +00:00 (committed by GitHub)
parent 312174eb88, commit 45e6ad4abc
9 changed files with 148 additions and 50 deletions
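
A minimal sketch of the response-shape fix from the first bullet, assuming
an async embeddings client like the one used in the diffs below (the
embed_batch helper and its client argument are illustrative names, not part
of the commit):

    import numpy as np

    async def embed_batch(client, texts):
        # embed() returns one vector per input text: a flat list of
        # vectors, not a list of vector lists, so no response[i][0].
        response = await client.embed(texts)
        return np.array(response)   # shape: (len(texts), dimension)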


@@ -176,6 +176,9 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
         processor = Processor(**config)
+        # Mock save_child_document to avoid librarian producer interactions
+        processor.save_child_document = AsyncMock(return_value="chunk-id")
         # Mock message with TextDocument
         mock_message = MagicMock()
         mock_text_doc = MagicMock()
@@ -191,11 +194,13 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
         # Mock consumer and flow with parameter overrides
         mock_consumer = MagicMock()
         mock_producer = AsyncMock()
+        mock_triples_producer = AsyncMock()
         mock_flow = MagicMock()
         mock_flow.side_effect = lambda param: {
             "chunk-size": 400,
             "chunk-overlap": 40,
-            "output": mock_producer
+            "output": mock_producer,
+            "triples": mock_triples_producer,
         }.get(param)
         # Act
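
The mock_flow above stands in for the flow() callable the chunker uses to
look up parameters and producers by name. A hedged sketch of assertions the
test could then make on the new triples output (illustrative, not part of
the commit; the handler signature matches on_message in the token chunker
diff below):

    await processor.on_message(mock_message, mock_consumer, mock_flow)
    # One chunk message and at least one provenance message should go out.
    mock_producer.send.assert_awaited()
    mock_triples_producer.send.assert_awaited()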


@@ -17,12 +17,14 @@ from . producer_spec import ProducerSpec
 logger = logging.getLogger(__name__)
 default_ident = "doc-embeddings-query"
+default_concurrency = 10
 class DocumentEmbeddingsQueryService(FlowProcessor):
     def __init__(self, **params):
         id = params.get("id")
+        concurrency = params.get("concurrency", default_concurrency)
         super(DocumentEmbeddingsQueryService, self).__init__(
             **params | { "id": id }
@@ -32,7 +34,8 @@ class DocumentEmbeddingsQueryService(FlowProcessor):
             ConsumerSpec(
                 name = "request",
                 schema = DocumentEmbeddingsRequest,
-                handler = self.on_message
+                handler = self.on_message,
+                concurrency = concurrency,
             )
         )
@@ -83,6 +86,13 @@ class DocumentEmbeddingsQueryService(FlowProcessor):
         FlowProcessor.add_args(parser)
+        parser.add_argument(
+            '-c', '--concurrency',
+            type=int,
+            default=default_concurrency,
+            help=f'Number of concurrent requests (default: {default_concurrency})'
+        )
 def run():
     Processor.launch(default_ident, __doc__)
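
The same wiring (module default, constructor param, ConsumerSpec field, CLI
flag) repeats in each service below. A sketch of overriding it
programmatically, grounded in the params.get() call above (import path
omitted; the service takes plain keyword params):

    # concurrency resolves as: explicit param, else the --concurrency
    # CLI flag, else the module-level default of 10.
    service = DocumentEmbeddingsQueryService(
        id="doc-embeddings-query",
        concurrency=20,
    )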


@@ -17,12 +17,14 @@ from . producer_spec import ProducerSpec
 logger = logging.getLogger(__name__)
 default_ident = "graph-embeddings-query"
+default_concurrency = 10
 class GraphEmbeddingsQueryService(FlowProcessor):
     def __init__(self, **params):
         id = params.get("id")
+        concurrency = params.get("concurrency", default_concurrency)
         super(GraphEmbeddingsQueryService, self).__init__(
             **params | { "id": id }
@@ -32,7 +34,8 @@ class GraphEmbeddingsQueryService(FlowProcessor):
             ConsumerSpec(
                 name = "request",
                 schema = GraphEmbeddingsRequest,
-                handler = self.on_message
+                handler = self.on_message,
+                concurrency = concurrency,
             )
         )
@@ -83,6 +86,13 @@ class GraphEmbeddingsQueryService(FlowProcessor):
         FlowProcessor.add_args(parser)
+        parser.add_argument(
+            '-c', '--concurrency',
+            type=int,
+            default=default_concurrency,
+            help=f'Number of concurrent requests (default: {default_concurrency})'
+        )
 def run():
     Processor.launch(default_ident, __doc__)
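
GraphEmbeddingsQueryService gets the identical change. For intuition only,
a concurrency-bounded consumer can be sketched with a plain asyncio
semaphore; this is a generic illustration of the pattern, not the actual
ConsumerSpec implementation:

    import asyncio

    async def consume(queue, handler, concurrency=10):
        sem = asyncio.Semaphore(concurrency)

        async def handle_one(msg):
            try:
                await handler(msg)
            finally:
                sem.release()

        while (msg := await queue.get()) is not None:
            await sem.acquire()   # blocks once `concurrency` handlers are in flight
            asyncio.create_task(handle_one(msg))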


@@ -12,8 +12,7 @@ from ... schema import TextDocument, Chunk, Metadata, Triples
 from ... base import ChunkingService, ConsumerSpec, ProducerSpec
 from ... provenance import (
-    page_uri, chunk_uri_from_page, chunk_uri_from_doc,
-    derived_entity_triples, document_uri,
+    derived_entity_triples,
     set_graph, GRAPH_SOURCE,
 )
@@ -114,22 +113,9 @@ class Processor(ChunkingService):
         texts = text_splitter.create_documents([text])
         # Get parent document ID for provenance linking
+        # This could be a page URI (doc/p3) or document URI (doc) - we don't need to parse it
         parent_doc_id = v.document_id or v.metadata.id
-        # Determine if parent is a page (from PDF) or source document (text)
-        # Check if parent_doc_id contains "/p" which indicates a page
-        is_from_page = "/p" in parent_doc_id
-        # Extract the root document ID for chunk URI generation
-        if is_from_page:
-            # Parent is a page like "doc123/p3", extract page number
-            parts = parent_doc_id.rsplit("/p", 1)
-            root_doc_id = parts[0]
-            page_num = int(parts[1]) if len(parts) > 1 else 1
-        else:
-            root_doc_id = parent_doc_id
-            page_num = None
         # Track character offset for provenance
         char_offset = 0
@@ -138,15 +124,11 @@ class Processor(ChunkingService):
             logger.debug(f"Created chunk of size {len(chunk.page_content)}")
-            # Generate chunk document ID
-            if is_from_page:
-                chunk_doc_id = f"{root_doc_id}/p{page_num}/c{chunk_index}"
-                chunk_uri = chunk_uri_from_page(root_doc_id, page_num, chunk_index)
-                parent_uri = page_uri(root_doc_id, page_num)
-            else:
-                chunk_doc_id = f"{root_doc_id}/c{chunk_index}"
-                chunk_uri = chunk_uri_from_doc(root_doc_id, chunk_index)
-                parent_uri = document_uri(root_doc_id)
+            # Generate chunk document ID by appending /c{index} to parent
+            # Works for both page URIs (doc/p3 -> doc/p3/c1) and doc URIs (doc -> doc/c1)
+            chunk_doc_id = f"{parent_doc_id}/c{chunk_index}"
+            chunk_uri = chunk_doc_id  # URI is same as document ID
+            parent_uri = parent_doc_id
             chunk_content = chunk.page_content.encode("utf-8")
             chunk_length = len(chunk.page_content)
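
The simplification works because chunk IDs are purely hierarchical: the
chunk URI is always the parent ID plus /c{index}, whatever the parent is.
A quick self-contained check of both cases:

    def chunk_id(parent_doc_id: str, chunk_index: int) -> str:
        # Same rule for page parents and plain document parents.
        return f"{parent_doc_id}/c{chunk_index}"

    assert chunk_id("doc123/p3", 1) == "doc123/p3/c1"   # chunk of a PDF page
    assert chunk_id("doc123", 1) == "doc123/c1"         # chunk of a text document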


@@ -8,9 +8,18 @@ import logging
 from langchain_text_splitters import TokenTextSplitter
 from prometheus_client import Histogram
-from ... schema import TextDocument, Chunk
+from ... schema import TextDocument, Chunk, Metadata, Triples
 from ... base import ChunkingService, ConsumerSpec, ProducerSpec
+from ... provenance import (
+    derived_entity_triples,
+    set_graph, GRAPH_SOURCE,
+)
+# Component identification for provenance
+COMPONENT_NAME = "token-chunker"
+COMPONENT_VERSION = "1.0.0"
 # Module logger
 logger = logging.getLogger(__name__)
@@ -24,7 +33,7 @@ class Processor(ChunkingService):
         id = params.get("id", default_ident)
         chunk_size = params.get("chunk_size", 250)
         chunk_overlap = params.get("chunk_overlap", 15)
         super(Processor, self).__init__(
             **params | { "id": id }
         )
@@ -62,6 +71,13 @@ class Processor(ChunkingService):
             )
         )
+        self.register_specification(
+            ProducerSpec(
+                name = "triples",
+                schema = Triples,
+            )
+        )
         logger.info("Token chunker initialized")
async def on_message(self, msg, consumer, flow):
@@ -94,21 +110,82 @@ class Processor(ChunkingService):
         texts = text_splitter.create_documents([text])
+        # Get parent document ID for provenance linking
+        # This could be a page URI (doc/p3) or document URI (doc) - we don't need to parse it
+        parent_doc_id = v.document_id or v.metadata.id
+        # Track token offset for provenance (approximate)
+        token_offset = 0
         for ix, chunk in enumerate(texts):
+            chunk_index = ix + 1  # 1-indexed
             logger.debug(f"Created chunk of size {len(chunk.page_content)}")
+            # Generate chunk document ID by appending /c{index} to parent
+            # Works for both page URIs (doc/p3 -> doc/p3/c1) and doc URIs (doc -> doc/c1)
+            chunk_doc_id = f"{parent_doc_id}/c{chunk_index}"
+            chunk_uri = chunk_doc_id  # URI is same as document ID
+            parent_uri = parent_doc_id
+            chunk_content = chunk.page_content.encode("utf-8")
+            chunk_length = len(chunk.page_content)
+            # Save chunk to librarian as child document
+            await self.save_child_document(
+                doc_id=chunk_doc_id,
+                parent_id=parent_doc_id,
+                user=v.metadata.user,
+                content=chunk_content,
+                document_type="chunk",
+                title=f"Chunk {chunk_index}",
+            )
+            # Emit provenance triples (stored in source graph for separation from core knowledge)
+            prov_triples = derived_entity_triples(
+                entity_uri=chunk_uri,
+                parent_uri=parent_uri,
+                component_name=COMPONENT_NAME,
+                component_version=COMPONENT_VERSION,
+                label=f"Chunk {chunk_index}",
+                chunk_index=chunk_index,
+                char_offset=token_offset,  # Note: this is token offset, not char offset
+                char_length=chunk_length,
+                chunk_size=chunk_size,
+                chunk_overlap=chunk_overlap,
+            )
+            await flow("triples").send(Triples(
+                metadata=Metadata(
+                    id=chunk_uri,
+                    root=v.metadata.root,
+                    user=v.metadata.user,
+                    collection=v.metadata.collection,
+                ),
+                triples=set_graph(prov_triples, GRAPH_SOURCE),
+            ))
+            # Forward chunk ID + content (post-chunker optimization)
             r = Chunk(
-                metadata=v.metadata,
-                chunk=chunk.page_content.encode("utf-8"),
+                metadata=Metadata(
+                    id=chunk_uri,
+                    root=v.metadata.root,
+                    user=v.metadata.user,
+                    collection=v.metadata.collection,
+                ),
+                chunk=chunk_content,
+                document_id=chunk_doc_id,
             )
             __class__.chunk_metric.labels(
                 id=consumer.id, flow=consumer.flow
-            ).observe(len(chunk.page_content))
+            ).observe(chunk_length)
             await flow("output").send(r)
+            # Update token offset (approximate, doesn't account for overlap)
+            token_offset += chunk_size - chunk_overlap
         logger.debug("Document chunking complete")
@staticmethod
@@ -120,17 +197,16 @@ class Processor(ChunkingService):
             '-z', '--chunk-size',
             type=int,
             default=250,
-            help=f'Chunk size (default: 250)'
+            help=f'Chunk size in tokens (default: 250)'
         )
         parser.add_argument(
             '-v', '--chunk-overlap',
             type=int,
             default=15,
-            help=f'Chunk overlap (default: 15)'
+            help=f'Chunk overlap in tokens (default: 15)'
         )
 def run():
     Processor.launch(default_ident, __doc__)
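
The token_offset bookkeeping above is deliberately approximate:
TokenTextSplitter emits chunks of up to chunk_size tokens with
chunk_overlap tokens shared between neighbours, so each chunk starts
roughly chunk_size - chunk_overlap tokens after the previous one. A worked
example with the defaults:

    chunk_size, chunk_overlap = 250, 15
    stride = chunk_size - chunk_overlap   # 235 tokens between chunk starts
    offsets = [i * stride for i in range(4)]
    assert offsets == [0, 235, 470, 705]
    # A short final chunk makes the estimate drift, hence the "approximate"
    # notes in the provenance comments above.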


@@ -149,7 +149,7 @@ class Processor(FlowProcessor):
         # Detect embedding dimension by embedding a test string
         logger.info("Detecting embedding dimension from embeddings service...")
         test_embedding_response = await embeddings_client.embed(["test"])
-        test_embedding = test_embedding_response[0][0] # Extract first vector from first text
+        test_embedding = test_embedding_response[0]  # Extract first vector
         dimension = len(test_embedding)
         logger.info(f"Detected embedding dimension: {dimension}")


@@ -153,14 +153,11 @@ class OntologyEmbedder:
         # Get embeddings for batch
         texts = [elem['text'] for elem in batch]
         try:
-            # Single batch embedding call
+            # Single batch embedding call - returns list of vectors
             embeddings_response = await self.embedding_service.embed(texts)
-            # Extract first vector from each text's vector set
-            embeddings_list = [resp[0] for resp in embeddings_response]
-            # Convert to numpy array
-            embeddings = np.array(embeddings_list)
+            embeddings = np.array(embeddings_response)
             # Log embedding shape for debugging
             logger.debug(f"Embeddings shape: {embeddings.shape}, expected: ({len(batch)}, {self.vector_store.dimension})")
@@ -216,9 +213,9 @@ class OntologyEmbedder:
             return None
         try:
-            # embed() with single text, extract first vector from first text
+            # embed() with single text, extract first vector
             embedding_response = await self.embedding_service.embed([text])
-            return np.array(embedding_response[0][0])
+            return np.array(embedding_response[0])
         except Exception as e:
             logger.error(f"Failed to embed text: {e}")
             return None
@@ -237,11 +234,9 @@ class OntologyEmbedder:
             return None
         try:
-            # Single batch embedding call
+            # Single batch embedding call - returns list of vectors
             embeddings_response = await self.embedding_service.embed(texts)
-            # Extract first vector from each text's vector set
-            embeddings_list = [resp[0] for resp in embeddings_response]
-            return np.array(embeddings_list)
+            return np.array(embeddings_response)
         except Exception as e:
             logger.error(f"Failed to embed texts: {e}")
             return None
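
Since embed() now returns one vector per text, np.array() on the raw
response directly yields the (batch, dimension) matrix that the debug log
checks. A small stubbed demonstration of why the old extraction was wrong:

    import numpy as np

    # Stub response: 3 texts from a 4-dimensional embedder.
    embeddings_response = [[0.1, 0.2, 0.3, 0.4]] * 3
    assert np.array(embeddings_response).shape == (3, 4)

    # The old resp[0] indexing would have taken each vector's first
    # element, producing scalars instead of vectors:
    wrong = np.array([resp[0] for resp in embeddings_response])
    assert wrong.shape == (3,)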


@@ -24,6 +24,7 @@ logger = logging.getLogger(__name__)
 default_ident = "row-embeddings-query"
 default_store_uri = 'http://localhost:6333'
+default_concurrency = 10
 class Processor(FlowProcessor):
@@ -31,6 +32,7 @@ class Processor(FlowProcessor):
     def __init__(self, **params):
         id = params.get("id", default_ident)
+        concurrency = params.get("concurrency", default_concurrency)
         store_uri = params.get("store_uri", default_store_uri)
         api_key = params.get("api_key", None)
@@ -47,7 +49,8 @@ class Processor(FlowProcessor):
             ConsumerSpec(
                 name="request",
                 schema=RowEmbeddingsRequest,
-                handler=self.on_message
+                handler=self.on_message,
+                concurrency=concurrency,
             )
         )
@@ -205,6 +208,13 @@ class Processor(FlowProcessor):
             help='API key for Qdrant (default: None)'
         )
+        parser.add_argument(
+            '-c', '--concurrency',
+            type=int,
+            default=default_concurrency,
+            help=f'Number of concurrent requests (default: {default_concurrency})'
+        )
 def run():
     """Entry point for row-embeddings-query-qdrant command"""


@@ -30,6 +30,7 @@ from ... graphql import GraphQLSchemaBuilder, SortDirection
 logger = logging.getLogger(__name__)
 default_ident = "rows-query"
+default_concurrency = 10
 class Processor(FlowProcessor):
@@ -37,6 +38,7 @@ class Processor(FlowProcessor):
     def __init__(self, **params):
         id = params.get("id", default_ident)
+        concurrency = params.get("concurrency", default_concurrency)
         # Get Cassandra parameters
         cassandra_host = params.get("cassandra_host")
@@ -69,7 +71,8 @@ class Processor(FlowProcessor):
             ConsumerSpec(
                 name="request",
                 schema=RowsQueryRequest,
-                handler=self.on_message
+                handler=self.on_message,
+                concurrency=concurrency,
             )
         )
@@ -517,6 +520,13 @@ class Processor(FlowProcessor):
             help='Configuration type prefix for schemas (default: schema)'
         )
+        parser.add_argument(
+            '-c', '--concurrency',
+            type=int,
+            default=default_concurrency,
+            help=f'Number of concurrent requests (default: {default_concurrency})'
+        )
 def run():
     """Entry point for rows-query-cassandra command"""