Feature/streaming triples (#676)

* Streaming triples

* Also GraphRAG service uses this

* Updated tests
This commit is contained in:
cybermaggedon 2026-03-09 15:46:33 +00:00 committed by GitHub
parent 3c3e11bef5
commit d2d71f859d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 542 additions and 116 deletions

View file

@ -48,7 +48,7 @@ class TestGraphRagIntegration:
client = AsyncMock()
# Mock different queries return different triples
async def query_side_effect(s=None, p=None, o=None, limit=None, user=None, collection=None):
async def query_stream_side_effect(s=None, p=None, o=None, limit=None, user=None, collection=None, batch_size=20):
# Mock label queries
if p == "http://www.w3.org/2000/01/rdf-schema#label":
if s == "http://trustgraph.ai/e/machine-learning":
@ -76,7 +76,9 @@ class TestGraphRagIntegration:
return []
client.query.side_effect = query_side_effect
client.query_stream.side_effect = query_stream_side_effect
# Also mock query for label lookups (maybe_label uses query, not query_stream)
client.query.side_effect = query_stream_side_effect
return client
@pytest.fixture
@ -137,7 +139,7 @@ class TestGraphRagIntegration:
assert call_args.kwargs['collection'] == collection
# 3. Should query triples to build knowledge subgraph
assert mock_triples_client.query.call_count > 0
assert mock_triples_client.query_stream.call_count > 0
# 4. Should call prompt with knowledge graph
mock_prompt_client.kg_prompt.assert_called_once()
@ -202,7 +204,7 @@ class TestGraphRagIntegration:
"""Test GraphRAG handles empty knowledge graph gracefully"""
# Arrange
mock_graph_embeddings_client.query.return_value = [] # No entities found
mock_triples_client.query.return_value = [] # No triples found
mock_triples_client.query_stream.return_value = [] # No triples found
# Act
result = await graph_rag.query(
@ -231,7 +233,7 @@ class TestGraphRagIntegration:
collection="test_collection"
)
first_call_count = mock_triples_client.query.call_count
first_call_count = mock_triples_client.query_stream.call_count
mock_triples_client.reset_mock()
# Second identical query
@ -241,7 +243,7 @@ class TestGraphRagIntegration:
collection="test_collection"
)
second_call_count = mock_triples_client.query.call_count
second_call_count = mock_triples_client.query_stream.call_count
# Assert - Second query should make fewer triple queries due to caching
# Note: This is a weak assertion because caching behavior depends on

View file

@ -193,15 +193,17 @@ class TestQuery:
test_vectors = [[0.1, 0.2, 0.3]]
mock_embeddings_client.embed.return_value = [test_vectors]
# Mock EntityMatch objects with entity that has string representation
# Mock EntityMatch objects with entity as Term-like object
mock_entity1 = MagicMock()
mock_entity1.__str__ = MagicMock(return_value="entity1")
mock_entity1.type = "i" # IRI type
mock_entity1.iri = "entity1"
mock_match1 = MagicMock()
mock_match1.entity = mock_entity1
mock_match1.score = 0.95
mock_entity2 = MagicMock()
mock_entity2.__str__ = MagicMock(return_value="entity2")
mock_entity2.type = "i" # IRI type
mock_entity2.iri = "entity2"
mock_match2 = MagicMock()
mock_match2.entity = mock_entity2
mock_match2.score = 0.85
@ -363,10 +365,10 @@ class TestQuery:
mock_triple3 = MagicMock()
mock_triple3.s, mock_triple3.p, mock_triple3.o = "subject3", "predicate3", "entity1"
# Setup query responses for s=ent, p=ent, o=ent patterns
mock_triples_client.query.side_effect = [
# Setup query_stream responses for s=ent, p=ent, o=ent patterns
mock_triples_client.query_stream.side_effect = [
[mock_triple1], # s=ent, p=None, o=None
[mock_triple2], # s=None, p=ent, o=None
[mock_triple2], # s=None, p=ent, o=None
[mock_triple3], # s=None, p=None, o=ent
]
@ -384,20 +386,20 @@ class TestQuery:
await query.follow_edges("entity1", subgraph, path_length=1)
# Verify all three query patterns were called
assert mock_triples_client.query.call_count == 3
# Verify query calls
mock_triples_client.query.assert_any_call(
assert mock_triples_client.query_stream.call_count == 3
# Verify query_stream calls
mock_triples_client.query_stream.assert_any_call(
s="entity1", p=None, o=None, limit=10,
user="test_user", collection="test_collection"
user="test_user", collection="test_collection", batch_size=20
)
mock_triples_client.query.assert_any_call(
mock_triples_client.query_stream.assert_any_call(
s=None, p="entity1", o=None, limit=10,
user="test_user", collection="test_collection"
user="test_user", collection="test_collection", batch_size=20
)
mock_triples_client.query.assert_any_call(
mock_triples_client.query_stream.assert_any_call(
s=None, p=None, o="entity1", limit=10,
user="test_user", collection="test_collection"
user="test_user", collection="test_collection", batch_size=20
)
# Verify subgraph contains discovered triples
@ -427,9 +429,9 @@ class TestQuery:
# Call follow_edges with path_length=0
subgraph = set()
await query.follow_edges("entity1", subgraph, path_length=0)
# Verify no queries were made
mock_triples_client.query.assert_not_called()
mock_triples_client.query_stream.assert_not_called()
# Verify subgraph remains empty
assert subgraph == set()
@ -456,9 +458,9 @@ class TestQuery:
# Call follow_edges
await query.follow_edges("entity1", subgraph, path_length=1)
# Verify no queries were made due to size limit
mock_triples_client.query.assert_not_called()
mock_triples_client.query_stream.assert_not_called()
# Verify subgraph unchanged
assert len(subgraph) == 3

View file

@ -91,9 +91,18 @@ class SocketClient:
service: str,
flow: Optional[str],
request: Dict[str, Any],
streaming: bool = False
) -> Union[Dict[str, Any], Iterator[StreamingChunk]]:
"""Synchronous wrapper around async WebSocket communication"""
streaming: bool = False,
streaming_raw: bool = False
) -> Union[Dict[str, Any], Iterator[StreamingChunk], Iterator[Dict[str, Any]]]:
"""Synchronous wrapper around async WebSocket communication.
Args:
service: Service name
flow: Flow ID (optional)
request: Request payload
streaming: Use parsed streaming (for agent/RAG chunk types)
streaming_raw: Use raw streaming (for data batches like triples)
"""
# Create event loop if needed
try:
loop = asyncio.get_event_loop()
@ -105,12 +114,14 @@ class SocketClient:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
if streaming:
# For streaming, we need to return an iterator
# Create a generator that runs async code
if streaming_raw:
# Raw streaming for data batches (triples, rows, etc.)
return self._streaming_generator_raw(service, flow, request, loop)
elif streaming:
# Parsed streaming for agent/RAG chunk types
return self._streaming_generator(service, flow, request, loop)
else:
# For non-streaming, just run the async code and return result
# Non-streaming single response
return loop.run_until_complete(self._send_request_async(service, flow, request))
def _streaming_generator(
@ -120,7 +131,7 @@ class SocketClient:
request: Dict[str, Any],
loop: asyncio.AbstractEventLoop
) -> Iterator[StreamingChunk]:
"""Generator that yields streaming chunks"""
"""Generator that yields streaming chunks (for agent/RAG responses)"""
async_gen = self._send_request_async_streaming(service, flow, request)
try:
@ -137,6 +148,74 @@ class SocketClient:
except:
pass
def _streaming_generator_raw(
self,
service: str,
flow: Optional[str],
request: Dict[str, Any],
loop: asyncio.AbstractEventLoop
) -> Iterator[Dict[str, Any]]:
"""Generator that yields raw response dicts (for data streaming like triples)"""
async_gen = self._send_request_async_streaming_raw(service, flow, request)
try:
while True:
try:
data = loop.run_until_complete(async_gen.__anext__())
yield data
except StopAsyncIteration:
break
finally:
try:
loop.run_until_complete(async_gen.aclose())
except:
pass
async def _send_request_async_streaming_raw(
    self,
    service: str,
    flow: Optional[str],
    request: Dict[str, Any]
) -> Iterator[Dict[str, Any]]:
    """Async streaming that yields raw response dicts without parsing.

    Used for data streaming (triples, rows, etc.) where responses are
    just batches of data, not agent/RAG chunk types.

    Args:
        service: Service name placed in the request envelope.
        flow: Optional flow ID; only added to the envelope when set.
        request: Request payload, forwarded verbatim under "request".

    Yields:
        The "response" member of each message matching this request ID.

    Raises:
        Whatever raise_from_error_dict() raises when the server sends an
        "error" member for this request.
    """
    # Allocate a unique request ID under the lock so concurrent callers
    # never observe the same counter value.
    with self._lock:
        self._request_counter += 1
        request_id = f"req-{self._request_counter}"

    ws_url = f"{self.url}/api/v1/socket"
    if self.token:
        ws_url = f"{ws_url}?token={self.token}"

    message = {
        "id": request_id,
        "service": service,
        "request": request
    }

    if flow:
        message["flow"] = flow

    async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
        await websocket.send(json.dumps(message))

        async for raw_message in websocket:
            response = json.loads(raw_message)

            # The socket is shared: skip messages addressed to other
            # in-flight requests.
            if response.get("id") != request_id:
                continue

            if "error" in response:
                raise_from_error_dict(response["error"])

            if "response" in response:
                yield response["response"]

            # The server marks the final message with "complete".
            if response.get("complete"):
                break
async def _send_request_async(
self,
service: str,
@ -790,6 +869,74 @@ class SocketFlowInstance:
return self.client._send_request_sync("triples", self.flow_id, request, False)
def triples_query_stream(
    self,
    s: Optional[str] = None,
    p: Optional[str] = None,
    o: Optional[str] = None,
    user: Optional[str] = None,
    collection: Optional[str] = None,
    limit: int = 100,
    batch_size: int = 20,
    **kwargs: Any
) -> Iterator[List[Dict[str, Any]]]:
    """
    Stream knowledge-graph triples in batches.

    Batches are yielded as they arrive from the service, which lowers
    time-to-first-result and avoids holding the full result set in memory.

    Args:
        s: Subject URI, or None for a wildcard.
        p: Predicate URI, or None for a wildcard.
        o: Object URI/Literal, or None for a wildcard.
        user: User/keyspace identifier (optional).
        collection: Collection identifier (optional).
        limit: Maximum results to return (default: 100).
        batch_size: Triples per batch (default: 20).
        **kwargs: Extra parameters merged into the request.

    Yields:
        List[Dict]: Batches of triples in wire format.

    Example:
        ```python
        socket = api.socket()
        flow = socket.flow("default")
        for batch in flow.triples_query_stream(
            user="trustgraph",
            collection="default"
        ):
            for triple in batch:
                print(triple["s"], triple["p"], triple["o"])
        ```
    """
    request: Dict[str, Any] = {
        "limit": limit,
        "streaming": True,
        "batch-size": batch_size,
    }

    # Pattern terms are stringified; user/collection pass through as-is.
    for key, value in (("s", s), ("p", p), ("o", o)):
        if value is not None:
            request[key] = str(value)
    if user is not None:
        request["user"] = user
    if collection is not None:
        request["collection"] = collection
    request.update(kwargs)

    # Raw streaming: response dicts are delivered without chunk parsing.
    stream = self.client._send_request_sync(
        "triples", self.flow_id, request, streaming_raw=True
    )
    for item in stream:
        # Some paths deliver {"response": [...]} envelopes; unwrap those,
        # otherwise pass the batch through unchanged.
        if isinstance(item, dict) and "response" in item:
            yield item["response"]
        else:
            yield item
def rows_query(
self,
query: str,

View file

@ -22,11 +22,19 @@ def to_value(x):
def from_value(x):
    """Convert Uri, Literal, or string to schema Term.

    Args:
        x: A Uri, Literal, plain string, None, or any other value.

    Returns:
        Term of type IRI or LITERAL, or None when x is None.
    """
    if x is None:
        return None
    if isinstance(x, Uri):
        return Term(type=IRI, iri=str(x))
    elif isinstance(x, Literal):
        return Term(type=LITERAL, value=str(x))
    elif isinstance(x, str):
        # Plain strings: detect IRIs by common scheme prefixes; anything
        # else is treated as a literal value.
        if x.startswith(("http://", "https://", "urn:")):
            return Term(type=IRI, iri=x)
        else:
            return Term(type=LITERAL, value=x)
    else:
        # Fallback: stringify unknown types as literals.
        return Term(type=LITERAL, value=str(x))
@ -57,6 +65,61 @@ class TriplesClient(RequestResponse):
return triples
async def query_stream(self, s=None, p=None, o=None, limit=20,
                       user="trustgraph", collection="default",
                       batch_size=20, timeout=30,
                       batch_callback=None):
    """
    Streaming triple query - calls callback for each batch as it arrives.

    Args:
        s, p, o: Triple pattern (None for wildcard)
        limit: Maximum total triples to return
        user: User/keyspace
        collection: Collection name
        batch_size: Triples per batch
        timeout: Request timeout in seconds
        batch_callback: Async callback(batch, is_final) called for each batch

    Returns:
        List[Triple]: All triples (flattened) if no callback provided;
        None when a batch_callback consumes the batches instead.

    Raises:
        RuntimeError: if the service reports an error for this query.
    """
    all_triples = []

    async def recipient(resp):
        # Surface server-side errors to the awaiting caller.
        if resp.error:
            raise RuntimeError(resp.error.message)

        # Convert wire-format terms into Triple values.
        batch = [
            Triple(to_value(v.s), to_value(v.p), to_value(v.o))
            for v in resp.triples
        ]

        if batch_callback:
            await batch_callback(batch, resp.is_final)
        else:
            # No callback: accumulate batches to return them flattened.
            all_triples.extend(batch)

        # Returning is_final tells the transport when to stop waiting
        # for further responses to this request.
        return resp.is_final

    await self.request(
        TriplesQueryRequest(
            s=from_value(s),
            p=from_value(p),
            o=from_value(o),
            limit=limit,
            user=user,
            collection=collection,
            streaming=True,
            batch_size=batch_size,
        ),
        timeout=timeout,
        recipient=recipient,
    )

    if not batch_callback:
        return all_triples
class TriplesClientSpec(RequestResponseSpec):
def __init__(
self, request_name, response_name,

View file

@ -52,13 +52,23 @@ class TriplesQueryService(FlowProcessor):
logger.debug(f"Handling triples query request {id}...")
triples = await self.query_triples(request)
logger.debug("Sending triples query response...")
r = TriplesQueryResponse(triples=triples, error=None)
await flow("response").send(r, properties={"id": id})
logger.debug("Triples query request completed")
if request.streaming:
# Streaming mode: send batches
async for batch, is_final in self.query_triples_stream(request):
r = TriplesQueryResponse(
triples=batch,
error=None,
is_final=is_final,
)
await flow("response").send(r, properties={"id": id})
logger.debug("Triples query streaming completed")
else:
# Non-streaming mode: single response
triples = await self.query_triples(request)
logger.debug("Sending triples query response...")
r = TriplesQueryResponse(triples=triples, error=None)
await flow("response").send(r, properties={"id": id})
logger.debug("Triples query request completed")
except Exception as e:
@ -76,6 +86,24 @@ class TriplesQueryService(FlowProcessor):
await flow("response").send(r, properties={"id": id})
async def query_triples_stream(self, request):
    """
    Streaming query - yields (batch, is_final) tuples.

    Default implementation runs the plain query_triples() and slices the
    result into batches after the fact. Backends capable of true
    incremental delivery should override this.
    """
    results = await self.query_triples(request)
    size = request.batch_size if request.batch_size > 0 else 20

    if not results:
        # Still emit one empty, final batch so consumers see completion.
        yield [], True
        return

    total = len(results)
    start = 0
    while start < total:
        end = start + size
        # The batch that reaches the end of the results is marked final.
        yield results[start:end], end >= total
        start = end
@staticmethod
def add_args(parser):

View file

@ -23,14 +23,18 @@ class TriplesQueryRequestTranslator(MessageTranslator):
g=g,
limit=int(data.get("limit", 10000)),
user=data.get("user", "trustgraph"),
collection=data.get("collection", "default")
collection=data.get("collection", "default"),
streaming=data.get("streaming", False),
batch_size=int(data.get("batch-size", 20)),
)
def from_pulsar(self, obj: TriplesQueryRequest) -> Dict[str, Any]:
result = {
"limit": obj.limit,
"user": obj.user,
"collection": obj.collection
"collection": obj.collection,
"streaming": obj.streaming,
"batch-size": obj.batch_size,
}
if obj.s:
@ -61,4 +65,4 @@ class TriplesQueryResponseTranslator(MessageTranslator):
def from_response_with_completion(self, obj: TriplesQueryResponse) -> Tuple[Dict[str, Any], bool]:
    """Returns (response_dict, is_final).

    is_final is False for intermediate batches in streaming mode, so the
    gateway knows more responses for this request are still in flight.
    """
    return self.from_pulsar(obj), obj.is_final

View file

@ -38,11 +38,14 @@ class TriplesQueryRequest:
o: Term | None = None
g: str | None = None # Graph IRI. None=default graph, "*"=all graphs
limit: int = 0
streaming: bool = False # Enable streaming mode (multiple batched responses)
batch_size: int = 20 # Triples per batch in streaming mode
@dataclass
class TriplesQueryResponse:
    """Response to a triples query.

    In streaming mode a sequence of these is sent for one request, with
    is_final marking the last batch.
    """
    error: Error | None = None
    triples: list[Triple] = field(default_factory=list)
    is_final: bool = True  # False for intermediate batches in streaming mode
############################################################################

View file

@ -1,6 +1,7 @@
"""
Connects to the graph query service and dumps all graph edges in Turtle
format with RDF-star support for quoted triples.
Uses streaming mode for lower time-to-first-processing.
"""
import rdflib
@ -9,66 +10,82 @@ import sys
import argparse
import os
from trustgraph.api import Api, Uri
from trustgraph.knowledge import QuotedTriple
from trustgraph.api import Api
default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/')
default_user = 'trustgraph'
default_collection = 'default'
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
def term_to_rdflib(term):
    """Convert a wire-format term dict to an rdflib term.

    The wire format tags each term with "t": "i" (IRI), "l" (literal, with
    optional datatype "d" / language "l"), or "r" (RDF-star quoted triple).

    Args:
        term: Wire-format term dict, or None.

    Returns:
        An rdflib term, or None when the input is None or unconvertible
        (e.g. a malformed IRI containing spaces).
    """
    if term is None:
        return None

    t = term.get("t", "")

    if t == "i":  # IRI
        iri = term.get("i", "")
        # Skip malformed URLs with spaces
        if " " in iri:
            return None
        return rdflib.term.URIRef(iri)
    elif t == "l":  # Literal
        value = term.get("v", "")
        datatype = term.get("d")
        language = term.get("l")
        if language:
            return rdflib.term.Literal(value, lang=language)
        elif datatype:
            return rdflib.term.Literal(value, datatype=rdflib.term.URIRef(datatype))
        else:
            return rdflib.term.Literal(value)
    elif t == "r":  # Quoted triple (RDF-star)
        triple = term.get("r", {})
        s_term = term_to_rdflib(triple.get("s"))
        p_term = term_to_rdflib(triple.get("p"))
        o_term = term_to_rdflib(triple.get("o"))
        if s_term is None or p_term is None or o_term is None:
            return None
        try:
            return rdflib.term.Triple((s_term, p_term, o_term))
        except AttributeError:
            # Fallback for older rdflib versions without RDF-star terms
            return rdflib.term.Literal(f"<<{s_term} {p_term} {o_term}>>")
    else:
        # Unknown tag: best-effort fallback to a stringified literal
        return rdflib.term.Literal(str(term))
def show_graph(url, flow_id, user, collection, limit, batch_size, token=None):
    """Stream all triples of a collection and emit them as Turtle on stdout.

    Args:
        url: TrustGraph API base URL.
        flow_id: Flow ID to query through.
        user: User/keyspace identifier.
        collection: Collection identifier.
        limit: Maximum number of triples to fetch.
        batch_size: Triples per streaming batch.
        token: Optional authentication token.
    """
    # Local import: only needed for the in-memory serialization buffer.
    import io

    socket = Api(url, token=token).socket()
    flow = socket.flow(flow_id)

    g = rdflib.Graph()

    try:
        for batch in flow.triples_query_stream(
            s=None, p=None, o=None,
            user=user, collection=collection,
            limit=limit,
            batch_size=batch_size,
        ):
            for triple in batch:
                sv = term_to_rdflib(triple.get("s"))
                pv = term_to_rdflib(triple.get("p"))
                ov = term_to_rdflib(triple.get("o"))
                # Skip triples with unconvertible terms (malformed IRIs etc.)
                if sv is None or pv is None or ov is None:
                    continue
                g.add((sv, pv, ov))
    finally:
        socket.close()

    buf = io.BytesIO()
    g.serialize(destination=buf, format="turtle")
    sys.stdout.write(buf.getvalue().decode("utf-8"))
@ -103,6 +120,26 @@ def main():
help=f'Collection ID (default: {default_collection})'
)
parser.add_argument(
'-t', '--token',
default=default_token,
help='Authentication token (default: $TRUSTGRAPH_TOKEN)',
)
parser.add_argument(
'-l', '--limit',
type=int,
default=10000,
help='Maximum number of triples to return (default: 10000)',
)
parser.add_argument(
'-b', '--batch-size',
type=int,
default=20,
help='Triples per streaming batch (default: 20)',
)
args = parser.parse_args()
try:
@ -112,6 +149,9 @@ def main():
flow_id = args.flow_id,
user = args.user,
collection = args.collection,
limit = args.limit,
batch_size = args.batch_size,
token = args.token,
)
except Exception as e:

View file

@ -1,5 +1,6 @@
"""
Connects to the graph query service and dumps all graph edges.
Uses streaming mode for lower time-to-first-result and reduced memory overhead.
"""
import argparse
@ -11,17 +12,30 @@ default_user = 'trustgraph'
default_collection = 'default'
default_token = os.getenv("TRUSTGRAPH_TOKEN", None)
def show_graph(url, flow_id, user, collection, token=None):
def show_graph(url, flow_id, user, collection, limit, batch_size, token=None):
api = Api(url, token=token).flow().id(flow_id)
socket = Api(url, token=token).socket()
flow = socket.flow(flow_id)
rows = api.triples_query(
user=user, collection=collection,
s=None, p=None, o=None, limit=10_000,
)
for row in rows:
print(row.s, row.p, row.o)
try:
for batch in flow.triples_query_stream(
user=user,
collection=collection,
s=None, p=None, o=None,
limit=limit,
batch_size=batch_size,
):
for triple in batch:
s = triple.get("s", {})
p = triple.get("p", {})
o = triple.get("o", {})
# Format terms for display
s_str = s.get("v", s.get("i", str(s)))
p_str = p.get("v", p.get("i", str(p)))
o_str = o.get("v", o.get("i", str(o)))
print(s_str, p_str, o_str)
finally:
socket.close()
def main():
@ -60,6 +74,20 @@ def main():
help='Authentication token (default: $TRUSTGRAPH_TOKEN)',
)
parser.add_argument(
'-l', '--limit',
type=int,
default=10000,
help='Maximum number of triples to return (default: 10000)',
)
parser.add_argument(
'-b', '--batch-size',
type=int,
default=20,
help='Triples per streaming batch (default: 20)',
)
args = parser.parse_args()
try:
@ -69,6 +97,8 @@ def main():
flow_id = args.flow_id,
user = args.user,
collection = args.collection,
limit = args.limit,
batch_size = args.batch_size,
token = args.token,
)

View file

@ -7,6 +7,7 @@ null. Output is a list of quads.
import logging
import json
from cassandra.query import SimpleStatement
from .... direct.cassandra_kg import (
EntityCentricKnowledgeGraph, GRAPH_WILDCARD, DEFAULT_GRAPH
@ -144,28 +145,30 @@ class Processor(TriplesQueryService):
self.cassandra_password = password
self.table = None
def ensure_connection(self, user):
    """Ensure we have a connection to the correct keyspace.

    Reconnects only when the requested keyspace differs from the cached
    one; credentials are included when configured.
    """
    if user == self.table:
        return

    KGClass = EntityCentricKnowledgeGraph
    kg_kwargs = {
        "hosts": self.cassandra_host,
        "keyspace": user,
    }
    if self.cassandra_username and self.cassandra_password:
        kg_kwargs["username"] = self.cassandra_username
        kg_kwargs["password"] = self.cassandra_password

    self.tg = KGClass(**kg_kwargs)
    self.table = user
async def query_triples(self, query):
try:
user = query.user
if user != self.table:
# Use factory function to select implementation
KGClass = EntityCentricKnowledgeGraph
if self.cassandra_username and self.cassandra_password:
self.tg = KGClass(
hosts=self.cassandra_host,
keyspace=query.user,
username=self.cassandra_username, password=self.cassandra_password
)
else:
self.tg = KGClass(
hosts=self.cassandra_host,
keyspace=query.user,
)
self.table = user
self.ensure_connection(query.user)
# Extract values from query
s_val = get_term_value(query.s)
@ -291,6 +294,93 @@ class Processor(TriplesQueryService):
logger.error(f"Exception querying triples: {e}", exc_info=True)
raise e
async def query_triples_stream(self, query):
    """
    Streaming query - yields (batch, is_final) tuples.

    Uses Cassandra's paging (fetch_size) to fetch results incrementally
    for unconstrained scans; patterned queries fall back to the
    non-streaming path with post-hoc batching.

    Args:
        query: Triples query request (pattern s/p/o, user, collection,
            limit, batch_size).

    Yields:
        (list[Triple], bool): Batches of triples; the bool is True only
        on the final batch.
    """
    try:
        self.ensure_connection(query.user)

        # Guard against non-positive sizes/limits from the request.
        batch_size = query.batch_size if query.batch_size > 0 else 20
        limit = query.limit if query.limit > 0 else 10000

        # Extract query pattern
        s_val = get_term_value(query.s)
        p_val = get_term_value(query.p)
        o_val = get_term_value(query.o)
        g_val = query.g  # NOTE(review): g_val is not used below — confirm graph filtering is intentional here

        # Helper to extract object metadata from result row
        def get_o_metadata(t):
            otype = getattr(t, 'otype', None)
            dtype = getattr(t, 'dtype', None)
            lang = getattr(t, 'lang', None)
            return otype, dtype, lang

        # For streaming, we need to execute with fetch_size
        # Use the collection table for get_all queries (most common streaming case)

        # Determine which query to use based on pattern
        if s_val is None and p_val is None and o_val is None:
            # Get all - use collection table with paging
            cql = f"SELECT d, s, p, o, otype, dtype, lang FROM {self.tg.collection_table} WHERE collection = %s"
            params = [query.collection]
        else:
            # For specific patterns, fall back to non-streaming
            # (these typically return small result sets anyway)
            async for batch, is_final in self._fallback_stream(query, batch_size):
                yield batch, is_final
            return

        # Create statement with fetch_size for true streaming
        statement = SimpleStatement(cql, fetch_size=batch_size)
        result_set = self.tg.session.execute(statement, params)

        batch = []
        count = 0

        for row in result_set:
            if count >= limit:
                break

            # Rows without a 'd' column belong to the default graph.
            g = row.d if hasattr(row, 'd') else DEFAULT_GRAPH
            otype, dtype, lang = get_o_metadata(row)

            triple = Triple(
                s=create_term(row.s),
                p=create_term(row.p),
                o=create_term(row.o, otype=otype, dtype=dtype, lang=lang),
                g=g if g != DEFAULT_GRAPH else None
            )
            batch.append(triple)
            count += 1

            # Yield batch when full (never mark as final mid-stream)
            if len(batch) >= batch_size:
                yield batch, False
                batch = []

        # Always yield final batch to signal completion
        # This handles: remaining rows, empty result set, or exact batch boundary
        yield batch, True

    except Exception as e:
        logger.error(f"Exception in streaming query: {e}", exc_info=True)
        raise e
async def _fallback_stream(self, query, batch_size):
"""Fallback to non-streaming query with post-hoc batching."""
triples = await self.query_triples(query)
for i in range(0, len(triples), batch_size):
batch = triples[i:i + batch_size]
is_final = (i + batch_size >= len(triples))
yield batch, is_final
if len(triples) == 0:
yield [], True
@staticmethod
def add_args(parser):

View file

@ -4,11 +4,25 @@ import logging
import time
from collections import OrderedDict
from ... schema import IRI, LITERAL
# Module logger
logger = logging.getLogger(__name__)
LABEL="http://www.w3.org/2000/01/rdf-schema#label"
def term_to_string(term):
    """Return the plain string carried by a schema Term.

    IRI terms yield their IRI, literal terms their value; None passes
    straight through. Unknown term types fall back to whichever field
    is populated, then to str().
    """
    if term is None:
        return None
    kind = term.type
    if kind == LITERAL:
        return term.value
    if kind == IRI:
        return term.iri
    # Unrecognised term kind: best-effort fallback
    return term.iri or term.value or str(term)
class LRUCacheWithTTL:
"""LRU cache with TTL for label caching
@ -93,7 +107,7 @@ class Query:
)
entities = [
str(e.entity)
term_to_string(e.entity)
for e in entity_matches
]
@ -129,26 +143,29 @@ class Query:
return label
async def execute_batch_triple_queries(self, entities, limit_per_entity):
"""Execute triple queries for multiple entities concurrently"""
"""Execute triple queries for multiple entities concurrently using streaming"""
tasks = []
for entity in entities:
# Create concurrent tasks for all 3 query types per entity
# Create concurrent streaming tasks for all 3 query types per entity
tasks.extend([
self.rag.triples_client.query(
self.rag.triples_client.query_stream(
s=entity, p=None, o=None,
limit=limit_per_entity,
user=self.user, collection=self.collection
user=self.user, collection=self.collection,
batch_size=20,
),
self.rag.triples_client.query(
self.rag.triples_client.query_stream(
s=None, p=entity, o=None,
limit=limit_per_entity,
user=self.user, collection=self.collection
user=self.user, collection=self.collection,
batch_size=20,
),
self.rag.triples_client.query(
self.rag.triples_client.query_stream(
s=None, p=None, o=entity,
limit=limit_per_entity,
user=self.user, collection=self.collection
user=self.user, collection=self.collection,
batch_size=20,
)
])
@ -158,7 +175,7 @@ class Query:
# Combine all results
all_triples = []
for result in results:
if not isinstance(result, Exception):
if not isinstance(result, Exception) and result is not None:
all_triples.extend(result)
return all_triples