mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-05-04 04:42:36 +02:00
Feature/streaming triples (#676)
* Streaming triples * Also GraphRAG service uses this * Updated tests
This commit is contained in:
parent
3c3e11bef5
commit
d2d71f859d
11 changed files with 542 additions and 116 deletions
|
|
@ -7,6 +7,7 @@ null. Output is a list of quads.
|
|||
import logging
|
||||
|
||||
import json
|
||||
from cassandra.query import SimpleStatement
|
||||
|
||||
from .... direct.cassandra_kg import (
|
||||
EntityCentricKnowledgeGraph, GRAPH_WILDCARD, DEFAULT_GRAPH
|
||||
|
|
@ -144,28 +145,30 @@ class Processor(TriplesQueryService):
|
|||
self.cassandra_password = password
|
||||
self.table = None
|
||||
|
||||
def ensure_connection(self, user):
    """Make sure ``self.tg`` is connected to the keyspace named by ``user``.

    When ``user`` already equals ``self.table`` nothing happens. Otherwise
    a new ``EntityCentricKnowledgeGraph`` is created for that keyspace and
    stored in ``self.tg``, and ``self.table`` is updated to ``user``.
    Credentials are passed only when both a username and a password are
    configured.
    """
    if user == self.table:
        return

    options = {"hosts": self.cassandra_host, "keyspace": user}
    if self.cassandra_username and self.cassandra_password:
        options["username"] = self.cassandra_username
        options["password"] = self.cassandra_password

    self.tg = EntityCentricKnowledgeGraph(**options)

    # Record the keyspace only after the connection object was created.
    self.table = user
|
||||
|
||||
async def query_triples(self, query):
|
||||
|
||||
try:
|
||||
|
||||
user = query.user
|
||||
|
||||
if user != self.table:
|
||||
# Use factory function to select implementation
|
||||
KGClass = EntityCentricKnowledgeGraph
|
||||
|
||||
if self.cassandra_username and self.cassandra_password:
|
||||
self.tg = KGClass(
|
||||
hosts=self.cassandra_host,
|
||||
keyspace=query.user,
|
||||
username=self.cassandra_username, password=self.cassandra_password
|
||||
)
|
||||
else:
|
||||
self.tg = KGClass(
|
||||
hosts=self.cassandra_host,
|
||||
keyspace=query.user,
|
||||
)
|
||||
self.table = user
|
||||
self.ensure_connection(query.user)
|
||||
|
||||
# Extract values from query
|
||||
s_val = get_term_value(query.s)
|
||||
|
|
@ -291,6 +294,93 @@ class Processor(TriplesQueryService):
|
|||
logger.error(f"Exception querying triples: {e}", exc_info=True)
|
||||
raise e
|
||||
|
||||
async def query_triples_stream(self, query):
    """
    Streaming query - yields (batch, is_final) tuples.

    Uses Cassandra's paging to fetch results incrementally.

    Only the "get all" pattern (no subject, predicate or object given) is
    streamed from the collection table. Any other pattern is delegated to
    ``_fallback_stream``, which runs the regular query and slices its
    result into batches.

    ``query.batch_size`` and ``query.limit`` fall back to 20 and 10000
    when they are not positive. Full batches are yielded with
    ``is_final=False``; exactly one last batch (possibly empty) is yielded
    with ``is_final=True``.

    Any exception is logged and re-raised.
    """
    # Function-scope import keeps this block self-contained.
    from itertools import islice

    try:
        self.ensure_connection(query.user)

        batch_size = query.batch_size if query.batch_size > 0 else 20
        limit = query.limit if query.limit > 0 else 10000

        # Extract query pattern
        s_val = get_term_value(query.s)
        p_val = get_term_value(query.p)
        o_val = get_term_value(query.o)

        if s_val is not None or p_val is not None or o_val is not None:
            # For specific patterns, fall back to non-streaming
            # (these typically return small result sets anyway)
            async for batch, is_final in self._fallback_stream(query, batch_size):
                yield batch, is_final
            return

        # Get all - use collection table with paging.
        # NOTE(review): query.g is not applied on this path, so rows of
        # every graph in the collection are returned - confirm whether a
        # graph filter is expected here.
        cql = f"SELECT d, s, p, o, otype, dtype, lang FROM {self.tg.collection_table} WHERE collection = %s"
        params = [query.collection]

        # Create statement with fetch_size for true streaming
        # NOTE(review): execute() and the row iteration below are
        # synchronous driver calls inside a coroutine; presumably they
        # block the event loop while a page is fetched - confirm whether
        # they should be offloaded to a thread.
        statement = SimpleStatement(cql, fetch_size=batch_size)
        result_set = self.tg.session.execute(statement, params)

        batch = []

        # islice stops after `limit` rows without asking the result set
        # for another one. Checking the limit inside the loop body pulls
        # one extra row first, which can trigger a needless page fetch
        # when the limit falls on a page boundary.
        for row in islice(result_set, limit):

            g = row.d if hasattr(row, 'd') else DEFAULT_GRAPH

            # Object metadata columns; missing attributes become None.
            otype = getattr(row, 'otype', None)
            dtype = getattr(row, 'dtype', None)
            lang = getattr(row, 'lang', None)

            batch.append(
                Triple(
                    s=create_term(row.s),
                    p=create_term(row.p),
                    o=create_term(row.o, otype=otype, dtype=dtype, lang=lang),
                    g=g if g != DEFAULT_GRAPH else None
                )
            )

            # Yield batch when full (never mark as final mid-stream)
            if len(batch) >= batch_size:
                yield batch, False
                batch = []

        # Always yield final batch to signal completion
        # This handles: remaining rows, empty result set, or exact batch boundary
        yield batch, True

    except Exception as e:
        logger.error(f"Exception in streaming query: {e}", exc_info=True)
        # Bare raise keeps the original traceback intact.
        raise
|
||||
|
||||
async def _fallback_stream(self, query, batch_size):
|
||||
"""Fallback to non-streaming query with post-hoc batching."""
|
||||
triples = await self.query_triples(query)
|
||||
|
||||
for i in range(0, len(triples), batch_size):
|
||||
batch = triples[i:i + batch_size]
|
||||
is_final = (i + batch_size >= len(triples))
|
||||
yield batch, is_final
|
||||
|
||||
if len(triples) == 0:
|
||||
yield [], True
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
|
|
|
|||
|
|
@ -4,11 +4,25 @@ import logging
|
|||
import time
|
||||
from collections import OrderedDict
|
||||
|
||||
from ... schema import IRI, LITERAL
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
LABEL="http://www.w3.org/2000/01/rdf-schema#label"
|
||||
|
||||
|
||||
def term_to_string(term):
    """Return the string carried by a Term, or None when ``term`` is None.

    IRI terms give their ``iri``, LITERAL terms give their ``value``.
    For any other type the first truthy of ``term.iri``, ``term.value``
    and ``str(term)`` is returned.
    """
    if term is None:
        return None

    kind = term.type
    if kind == IRI:
        return term.iri
    if kind == LITERAL:
        return term.value

    # Unrecognised term type: best-effort fallback.
    return term.iri or term.value or str(term)
|
||||
|
||||
class LRUCacheWithTTL:
|
||||
"""LRU cache with TTL for label caching
|
||||
|
||||
|
|
@ -93,7 +107,7 @@ class Query:
|
|||
)
|
||||
|
||||
entities = [
|
||||
str(e.entity)
|
||||
term_to_string(e.entity)
|
||||
for e in entity_matches
|
||||
]
|
||||
|
||||
|
|
@ -129,26 +143,29 @@ class Query:
|
|||
return label
|
||||
|
||||
async def execute_batch_triple_queries(self, entities, limit_per_entity):
|
||||
"""Execute triple queries for multiple entities concurrently"""
|
||||
"""Execute triple queries for multiple entities concurrently using streaming"""
|
||||
tasks = []
|
||||
|
||||
for entity in entities:
|
||||
# Create concurrent tasks for all 3 query types per entity
|
||||
# Create concurrent streaming tasks for all 3 query types per entity
|
||||
tasks.extend([
|
||||
self.rag.triples_client.query(
|
||||
self.rag.triples_client.query_stream(
|
||||
s=entity, p=None, o=None,
|
||||
limit=limit_per_entity,
|
||||
user=self.user, collection=self.collection
|
||||
user=self.user, collection=self.collection,
|
||||
batch_size=20,
|
||||
),
|
||||
self.rag.triples_client.query(
|
||||
self.rag.triples_client.query_stream(
|
||||
s=None, p=entity, o=None,
|
||||
limit=limit_per_entity,
|
||||
user=self.user, collection=self.collection
|
||||
user=self.user, collection=self.collection,
|
||||
batch_size=20,
|
||||
),
|
||||
self.rag.triples_client.query(
|
||||
self.rag.triples_client.query_stream(
|
||||
s=None, p=None, o=entity,
|
||||
limit=limit_per_entity,
|
||||
user=self.user, collection=self.collection
|
||||
user=self.user, collection=self.collection,
|
||||
batch_size=20,
|
||||
)
|
||||
])
|
||||
|
||||
|
|
@ -158,7 +175,7 @@ class Query:
|
|||
# Combine all results
|
||||
all_triples = []
|
||||
for result in results:
|
||||
if not isinstance(result, Exception):
|
||||
if not isinstance(result, Exception) and result is not None:
|
||||
all_triples.extend(result)
|
||||
|
||||
return all_triples
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue