Feature/streaming triples (#676)

* Streaming triples

* Also GraphRAG service uses this

* Updated tests
This commit is contained in:
cybermaggedon 2026-03-09 15:46:33 +00:00 committed by GitHub
parent 3c3e11bef5
commit d2d71f859d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 542 additions and 116 deletions

View file

@ -7,6 +7,7 @@ null. Output is a list of quads.
import logging
import json
from cassandra.query import SimpleStatement
from .... direct.cassandra_kg import (
EntityCentricKnowledgeGraph, GRAPH_WILDCARD, DEFAULT_GRAPH
@ -144,28 +145,30 @@ class Processor(TriplesQueryService):
self.cassandra_password = password
self.table = None
def ensure_connection(self, user):
    """(Re)build the knowledge-graph client when the keyspace changes.

    The current keyspace is cached in ``self.table``; if *user* matches it,
    the existing ``self.tg`` client is reused untouched.
    """
    if user == self.table:
        # Already connected to the requested keyspace - nothing to do.
        return
    kwargs = {
        "hosts": self.cassandra_host,
        "keyspace": user,
    }
    # Credentials are optional; only pass them when both are configured.
    if self.cassandra_username and self.cassandra_password:
        kwargs["username"] = self.cassandra_username
        kwargs["password"] = self.cassandra_password
    self.tg = EntityCentricKnowledgeGraph(**kwargs)
    self.table = user
async def query_triples(self, query):
try:
user = query.user
if user != self.table:
# Use factory function to select implementation
KGClass = EntityCentricKnowledgeGraph
if self.cassandra_username and self.cassandra_password:
self.tg = KGClass(
hosts=self.cassandra_host,
keyspace=query.user,
username=self.cassandra_username, password=self.cassandra_password
)
else:
self.tg = KGClass(
hosts=self.cassandra_host,
keyspace=query.user,
)
self.table = user
self.ensure_connection(query.user)
# Extract values from query
s_val = get_term_value(query.s)
@ -291,6 +294,93 @@ class Processor(TriplesQueryService):
logger.error(f"Exception querying triples: {e}", exc_info=True)
raise e
async def query_triples_stream(self, query):
    """
    Streaming query - yields (batch, is_final) tuples.

    Uses Cassandra's paging to fetch results incrementally for the
    "get all" pattern (s, p and o all unbound); every other pattern is
    delegated to _fallback_stream, which runs the normal query and
    re-batches the result.

    Protocol: zero or more (batch, False) tuples of up to batch_size
    triples each, always terminated by exactly one (batch, True) tuple
    (possibly empty) so the consumer can detect end-of-stream.
    """
    try:
        self.ensure_connection(query.user)
        # batch_size/limit <= 0 are treated as "unset" and replaced
        # with defaults (20 per batch, 10000 rows overall).
        batch_size = query.batch_size if query.batch_size > 0 else 20
        limit = query.limit if query.limit > 0 else 10000
        # Extract query pattern (None means "unbound" / wildcard).
        s_val = get_term_value(query.s)
        p_val = get_term_value(query.p)
        o_val = get_term_value(query.o)
        # NOTE(review): g_val is captured but never used below - the
        # streaming CQL does not filter by graph. Confirm whether graph
        # filtering is intentionally unsupported in the streaming path.
        g_val = query.g
        # Helper to extract object metadata from result row; getattr
        # defaults cover rows/tables that lack these columns.
        def get_o_metadata(t):
            otype = getattr(t, 'otype', None)
            dtype = getattr(t, 'dtype', None)
            lang = getattr(t, 'lang', None)
            return otype, dtype, lang
        # For streaming, we need to execute with fetch_size
        # Use the collection table for get_all queries (most common streaming case)
        # Determine which query to use based on pattern
        if s_val is None and p_val is None and o_val is None:
            # Get all - use collection table with paging
            cql = f"SELECT d, s, p, o, otype, dtype, lang FROM {self.tg.collection_table} WHERE collection = %s"
            params = [query.collection]
        else:
            # For specific patterns, fall back to non-streaming
            # (these typically return small result sets anyway)
            async for batch, is_final in self._fallback_stream(query, batch_size):
                yield batch, is_final
            # The fallback already emitted the final batch; stop here.
            return
        # Create statement with fetch_size for true streaming: the
        # driver fetches rows from Cassandra one page at a time as we
        # iterate result_set below.
        statement = SimpleStatement(cql, fetch_size=batch_size)
        result_set = self.tg.session.execute(statement, params)
        batch = []
        count = 0
        for row in result_set:
            if count >= limit:
                # Overall row cap reached; any partial batch is flushed
                # by the final yield below.
                break
            # 'd' column holds the graph id; fall back to the default
            # graph when the column is absent.
            g = row.d if hasattr(row, 'd') else DEFAULT_GRAPH
            otype, dtype, lang = get_o_metadata(row)
            triple = Triple(
                s=create_term(row.s),
                p=create_term(row.p),
                o=create_term(row.o, otype=otype, dtype=dtype, lang=lang),
                # Default graph is represented as None in the Triple.
                g=g if g != DEFAULT_GRAPH else None
            )
            batch.append(triple)
            count += 1
            # Yield batch when full (never mark as final mid-stream)
            if len(batch) >= batch_size:
                yield batch, False
                batch = []
        # Always yield final batch to signal completion
        # This handles: remaining rows, empty result set, or exact batch boundary
        yield batch, True
    except Exception as e:
        logger.error(f"Exception in streaming query: {e}", exc_info=True)
        raise e
async def _fallback_stream(self, query, batch_size):
    """Emulate streaming for non-paged queries.

    Runs the ordinary (non-streaming) query once, then yields the
    result in fixed-size slices as (batch, is_final) tuples.  An empty
    result still produces a single ([], True) tuple so the consumer
    always sees a final marker.
    """
    triples = await self.query_triples(query)
    total = len(triples)
    if total == 0:
        yield [], True
        return
    for start in range(0, total, batch_size):
        chunk = triples[start:start + batch_size]
        yield chunk, start + batch_size >= total
@staticmethod
def add_args(parser):

View file

@ -4,11 +4,25 @@ import logging
import time
from collections import OrderedDict
from ... schema import IRI, LITERAL
# Module logger
logger = logging.getLogger(__name__)
LABEL="http://www.w3.org/2000/01/rdf-schema#label"
def term_to_string(term):
    """Return the plain string carried by *term*.

    IRIs yield their ``iri`` field, literals their ``value``; ``None``
    passes through unchanged.  Unknown term kinds fall back to whichever
    field is populated, or the term's string form as a last resort.
    """
    if term is None:
        return None
    kind = term.type
    if kind == IRI:
        return term.iri
    if kind == LITERAL:
        return term.value
    # Fallback for unexpected term kinds.
    return term.iri or term.value or str(term)
class LRUCacheWithTTL:
"""LRU cache with TTL for label caching
@ -93,7 +107,7 @@ class Query:
)
entities = [
str(e.entity)
term_to_string(e.entity)
for e in entity_matches
]
@ -129,26 +143,29 @@ class Query:
return label
async def execute_batch_triple_queries(self, entities, limit_per_entity):
"""Execute triple queries for multiple entities concurrently"""
"""Execute triple queries for multiple entities concurrently using streaming"""
tasks = []
for entity in entities:
# Create concurrent tasks for all 3 query types per entity
# Create concurrent streaming tasks for all 3 query types per entity
tasks.extend([
self.rag.triples_client.query(
self.rag.triples_client.query_stream(
s=entity, p=None, o=None,
limit=limit_per_entity,
user=self.user, collection=self.collection
user=self.user, collection=self.collection,
batch_size=20,
),
self.rag.triples_client.query(
self.rag.triples_client.query_stream(
s=None, p=entity, o=None,
limit=limit_per_entity,
user=self.user, collection=self.collection
user=self.user, collection=self.collection,
batch_size=20,
),
self.rag.triples_client.query(
self.rag.triples_client.query_stream(
s=None, p=None, o=entity,
limit=limit_per_entity,
user=self.user, collection=self.collection
user=self.user, collection=self.collection,
batch_size=20,
)
])
@ -158,7 +175,7 @@ class Query:
# Combine all results
all_triples = []
for result in results:
if not isinstance(result, Exception):
if not isinstance(result, Exception) and result is not None:
all_triples.extend(result)
return all_triples