From d2d71f859def7a709ef8b453dfce11d7187df22e Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Mon, 9 Mar 2026 15:46:33 +0000 Subject: [PATCH] Feature/streaming triples (#676) * Streaming triples * Also GraphRAG service uses this * Updated tests --- .../integration/test_graph_rag_integration.py | 14 +- tests/unit/test_retrieval/test_graph_rag.py | 40 +++-- .../trustgraph/api/socket_client.py | 163 +++++++++++++++++- .../trustgraph/base/triples_client.py | 65 ++++++- .../trustgraph/base/triples_query_service.py | 42 ++++- .../messaging/translators/triples.py | 10 +- .../trustgraph/schema/services/query.py | 3 + .../trustgraph/cli/graph_to_turtle.py | 110 ++++++++---- trustgraph-cli/trustgraph/cli/show_graph.py | 48 +++++- .../query/triples/cassandra/service.py | 126 ++++++++++++-- .../retrieval/graph_rag/graph_rag.py | 37 ++-- 11 files changed, 542 insertions(+), 116 deletions(-) diff --git a/tests/integration/test_graph_rag_integration.py b/tests/integration/test_graph_rag_integration.py index 25a572c0..d7e39a2e 100644 --- a/tests/integration/test_graph_rag_integration.py +++ b/tests/integration/test_graph_rag_integration.py @@ -48,7 +48,7 @@ class TestGraphRagIntegration: client = AsyncMock() # Mock different queries return different triples - async def query_side_effect(s=None, p=None, o=None, limit=None, user=None, collection=None): + async def query_stream_side_effect(s=None, p=None, o=None, limit=None, user=None, collection=None, batch_size=20): # Mock label queries if p == "http://www.w3.org/2000/01/rdf-schema#label": if s == "http://trustgraph.ai/e/machine-learning": @@ -76,7 +76,9 @@ class TestGraphRagIntegration: return [] - client.query.side_effect = query_side_effect + client.query_stream.side_effect = query_stream_side_effect + # Also mock query for label lookups (maybe_label uses query, not query_stream) + client.query.side_effect = query_stream_side_effect return client @pytest.fixture @@ -137,7 +139,7 @@ class TestGraphRagIntegration: assert 
call_args.kwargs['collection'] == collection # 3. Should query triples to build knowledge subgraph - assert mock_triples_client.query.call_count > 0 + assert mock_triples_client.query_stream.call_count > 0 # 4. Should call prompt with knowledge graph mock_prompt_client.kg_prompt.assert_called_once() @@ -202,7 +204,7 @@ class TestGraphRagIntegration: """Test GraphRAG handles empty knowledge graph gracefully""" # Arrange mock_graph_embeddings_client.query.return_value = [] # No entities found - mock_triples_client.query.return_value = [] # No triples found + mock_triples_client.query_stream.return_value = [] # No triples found # Act result = await graph_rag.query( @@ -231,7 +233,7 @@ class TestGraphRagIntegration: collection="test_collection" ) - first_call_count = mock_triples_client.query.call_count + first_call_count = mock_triples_client.query_stream.call_count mock_triples_client.reset_mock() # Second identical query @@ -241,7 +243,7 @@ class TestGraphRagIntegration: collection="test_collection" ) - second_call_count = mock_triples_client.query.call_count + second_call_count = mock_triples_client.query_stream.call_count # Assert - Second query should make fewer triple queries due to caching # Note: This is a weak assertion because caching behavior depends on diff --git a/tests/unit/test_retrieval/test_graph_rag.py b/tests/unit/test_retrieval/test_graph_rag.py index e763d089..eddc1e12 100644 --- a/tests/unit/test_retrieval/test_graph_rag.py +++ b/tests/unit/test_retrieval/test_graph_rag.py @@ -193,15 +193,17 @@ class TestQuery: test_vectors = [[0.1, 0.2, 0.3]] mock_embeddings_client.embed.return_value = [test_vectors] - # Mock EntityMatch objects with entity that has string representation + # Mock EntityMatch objects with entity as Term-like object mock_entity1 = MagicMock() - mock_entity1.__str__ = MagicMock(return_value="entity1") + mock_entity1.type = "i" # IRI type + mock_entity1.iri = "entity1" mock_match1 = MagicMock() mock_match1.entity = mock_entity1 
mock_match1.score = 0.95 mock_entity2 = MagicMock() - mock_entity2.__str__ = MagicMock(return_value="entity2") + mock_entity2.type = "i" # IRI type + mock_entity2.iri = "entity2" mock_match2 = MagicMock() mock_match2.entity = mock_entity2 mock_match2.score = 0.85 @@ -363,10 +365,10 @@ class TestQuery: mock_triple3 = MagicMock() mock_triple3.s, mock_triple3.p, mock_triple3.o = "subject3", "predicate3", "entity1" - # Setup query responses for s=ent, p=ent, o=ent patterns - mock_triples_client.query.side_effect = [ + # Setup query_stream responses for s=ent, p=ent, o=ent patterns + mock_triples_client.query_stream.side_effect = [ [mock_triple1], # s=ent, p=None, o=None - [mock_triple2], # s=None, p=ent, o=None + [mock_triple2], # s=None, p=ent, o=None [mock_triple3], # s=None, p=None, o=ent ] @@ -384,20 +386,20 @@ class TestQuery: await query.follow_edges("entity1", subgraph, path_length=1) # Verify all three query patterns were called - assert mock_triples_client.query.call_count == 3 - - # Verify query calls - mock_triples_client.query.assert_any_call( + assert mock_triples_client.query_stream.call_count == 3 + + # Verify query_stream calls + mock_triples_client.query_stream.assert_any_call( s="entity1", p=None, o=None, limit=10, - user="test_user", collection="test_collection" + user="test_user", collection="test_collection", batch_size=20 ) - mock_triples_client.query.assert_any_call( + mock_triples_client.query_stream.assert_any_call( s=None, p="entity1", o=None, limit=10, - user="test_user", collection="test_collection" + user="test_user", collection="test_collection", batch_size=20 ) - mock_triples_client.query.assert_any_call( + mock_triples_client.query_stream.assert_any_call( s=None, p=None, o="entity1", limit=10, - user="test_user", collection="test_collection" + user="test_user", collection="test_collection", batch_size=20 ) # Verify subgraph contains discovered triples @@ -427,9 +429,9 @@ class TestQuery: # Call follow_edges with path_length=0 subgraph = 
set() await query.follow_edges("entity1", subgraph, path_length=0) - + # Verify no queries were made - mock_triples_client.query.assert_not_called() + mock_triples_client.query_stream.assert_not_called() # Verify subgraph remains empty assert subgraph == set() @@ -456,9 +458,9 @@ class TestQuery: # Call follow_edges await query.follow_edges("entity1", subgraph, path_length=1) - + # Verify no queries were made due to size limit - mock_triples_client.query.assert_not_called() + mock_triples_client.query_stream.assert_not_called() # Verify subgraph unchanged assert len(subgraph) == 3 diff --git a/trustgraph-base/trustgraph/api/socket_client.py b/trustgraph-base/trustgraph/api/socket_client.py index 26c241a7..700a4531 100644 --- a/trustgraph-base/trustgraph/api/socket_client.py +++ b/trustgraph-base/trustgraph/api/socket_client.py @@ -91,9 +91,18 @@ class SocketClient: service: str, flow: Optional[str], request: Dict[str, Any], - streaming: bool = False - ) -> Union[Dict[str, Any], Iterator[StreamingChunk]]: - """Synchronous wrapper around async WebSocket communication""" + streaming: bool = False, + streaming_raw: bool = False + ) -> Union[Dict[str, Any], Iterator[StreamingChunk], Iterator[Dict[str, Any]]]: + """Synchronous wrapper around async WebSocket communication. + + Args: + service: Service name + flow: Flow ID (optional) + request: Request payload + streaming: Use parsed streaming (for agent/RAG chunk types) + streaming_raw: Use raw streaming (for data batches like triples) + """ # Create event loop if needed try: loop = asyncio.get_event_loop() @@ -105,12 +114,14 @@ class SocketClient: loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) - if streaming: - # For streaming, we need to return an iterator - # Create a generator that runs async code + if streaming_raw: + # Raw streaming for data batches (triples, rows, etc.) 
+ return self._streaming_generator_raw(service, flow, request, loop) + elif streaming: + # Parsed streaming for agent/RAG chunk types return self._streaming_generator(service, flow, request, loop) else: - # For non-streaming, just run the async code and return result + # Non-streaming single response return loop.run_until_complete(self._send_request_async(service, flow, request)) def _streaming_generator( @@ -120,7 +131,7 @@ class SocketClient: request: Dict[str, Any], loop: asyncio.AbstractEventLoop ) -> Iterator[StreamingChunk]: - """Generator that yields streaming chunks""" + """Generator that yields streaming chunks (for agent/RAG responses)""" async_gen = self._send_request_async_streaming(service, flow, request) try: @@ -137,6 +148,74 @@ class SocketClient: except: pass + def _streaming_generator_raw( + self, + service: str, + flow: Optional[str], + request: Dict[str, Any], + loop: asyncio.AbstractEventLoop + ) -> Iterator[Dict[str, Any]]: + """Generator that yields raw response dicts (for data streaming like triples)""" + async_gen = self._send_request_async_streaming_raw(service, flow, request) + + try: + while True: + try: + data = loop.run_until_complete(async_gen.__anext__()) + yield data + except StopAsyncIteration: + break + finally: + try: + loop.run_until_complete(async_gen.aclose()) + except: + pass + + async def _send_request_async_streaming_raw( + self, + service: str, + flow: Optional[str], + request: Dict[str, Any] + ) -> Iterator[Dict[str, Any]]: + """Async streaming that yields raw response dicts without parsing. + + Used for data streaming (triples, rows, etc.) where responses are + just batches of data, not agent/RAG chunk types. 
+ """ + with self._lock: + self._request_counter += 1 + request_id = f"req-{self._request_counter}" + + ws_url = f"{self.url}/api/v1/socket" + if self.token: + ws_url = f"{ws_url}?token={self.token}" + + message = { + "id": request_id, + "service": service, + "request": request + } + if flow: + message["flow"] = flow + + async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket: + await websocket.send(json.dumps(message)) + + async for raw_message in websocket: + response = json.loads(raw_message) + + if response.get("id") != request_id: + continue + + if "error" in response: + raise_from_error_dict(response["error"]) + + if "response" in response: + yield response["response"] + + if response.get("complete"): + break + async def _send_request_async( self, service: str, @@ -790,6 +869,74 @@ class SocketFlowInstance: return self.client._send_request_sync("triples", self.flow_id, request, False) + def triples_query_stream( + self, + s: Optional[str] = None, + p: Optional[str] = None, + o: Optional[str] = None, + user: Optional[str] = None, + collection: Optional[str] = None, + limit: int = 100, + batch_size: int = 20, + **kwargs: Any + ) -> Iterator[List[Dict[str, Any]]]: + """ + Query knowledge graph triples with streaming batches. + + Yields batches of triples as they arrive, reducing time-to-first-result + and memory overhead for large result sets. 
+ + Args: + s: Subject URI (optional, use None for wildcard) + p: Predicate URI (optional, use None for wildcard) + o: Object URI or Literal (optional, use None for wildcard) + user: User/keyspace identifier (optional) + collection: Collection identifier (optional) + limit: Maximum results to return (default: 100) + batch_size: Triples per batch (default: 20) + **kwargs: Additional parameters passed to the service + + Yields: + List[Dict]: Batches of triples in wire format + + Example: + ```python + socket = api.socket() + flow = socket.flow("default") + + for batch in flow.triples_query_stream( + user="trustgraph", + collection="default" + ): + for triple in batch: + print(triple["s"], triple["p"], triple["o"]) + ``` + """ + request = { + "limit": limit, + "streaming": True, + "batch-size": batch_size, + } + if s is not None: + request["s"] = str(s) + if p is not None: + request["p"] = str(p) + if o is not None: + request["o"] = str(o) + if user is not None: + request["user"] = user + if collection is not None: + request["collection"] = collection + request.update(kwargs) + + # Use raw streaming - yields response dicts directly without parsing + for response in self.client._send_request_sync("triples", self.flow_id, request, streaming_raw=True): + # Response is {"response": [...triples...]} from translator + if isinstance(response, dict) and "response" in response: + yield response["response"] + else: + yield response + def rows_query( self, query: str, diff --git a/trustgraph-base/trustgraph/base/triples_client.py b/trustgraph-base/trustgraph/base/triples_client.py index 7258d3ca..84e95ebe 100644 --- a/trustgraph-base/trustgraph/base/triples_client.py +++ b/trustgraph-base/trustgraph/base/triples_client.py @@ -22,11 +22,19 @@ def to_value(x): def from_value(x): - """Convert Uri or Literal to schema Term.""" + """Convert Uri, Literal, or string to schema Term.""" if x is None: return None if isinstance(x, Uri): return Term(type=IRI, iri=str(x)) + elif 
isinstance(x, Literal): + return Term(type=LITERAL, value=str(x)) + elif isinstance(x, str): + # Detect IRIs by common prefixes + if x.startswith("http://") or x.startswith("https://") or x.startswith("urn:"): + return Term(type=IRI, iri=x) + else: + return Term(type=LITERAL, value=x) else: return Term(type=LITERAL, value=str(x)) @@ -57,6 +65,61 @@ class TriplesClient(RequestResponse): return triples + async def query_stream(self, s=None, p=None, o=None, limit=20, + user="trustgraph", collection="default", + batch_size=20, timeout=30, + batch_callback=None): + """ + Streaming triple query - calls callback for each batch as it arrives. + + Args: + s, p, o: Triple pattern (None for wildcard) + limit: Maximum total triples to return + user: User/keyspace + collection: Collection name + batch_size: Triples per batch + timeout: Request timeout in seconds + batch_callback: Async callback(batch, is_final) called for each batch + + Returns: + List[Triple]: All triples (flattened) if no callback provided + """ + all_triples = [] + + async def recipient(resp): + if resp.error: + raise RuntimeError(resp.error.message) + + batch = [ + Triple(to_value(v.s), to_value(v.p), to_value(v.o)) + for v in resp.triples + ] + + if batch_callback: + await batch_callback(batch, resp.is_final) + else: + all_triples.extend(batch) + + return resp.is_final + + await self.request( + TriplesQueryRequest( + s=from_value(s), + p=from_value(p), + o=from_value(o), + limit=limit, + user=user, + collection=collection, + streaming=True, + batch_size=batch_size, + ), + timeout=timeout, + recipient=recipient, + ) + + if not batch_callback: + return all_triples + class TriplesClientSpec(RequestResponseSpec): def __init__( self, request_name, response_name, diff --git a/trustgraph-base/trustgraph/base/triples_query_service.py b/trustgraph-base/trustgraph/base/triples_query_service.py index b156ef55..b8053b01 100644 --- a/trustgraph-base/trustgraph/base/triples_query_service.py +++ 
b/trustgraph-base/trustgraph/base/triples_query_service.py @@ -52,13 +52,23 @@ class TriplesQueryService(FlowProcessor): logger.debug(f"Handling triples query request {id}...") - triples = await self.query_triples(request) - - logger.debug("Sending triples query response...") - r = TriplesQueryResponse(triples=triples, error=None) - await flow("response").send(r, properties={"id": id}) - - logger.debug("Triples query request completed") + if request.streaming: + # Streaming mode: send batches + async for batch, is_final in self.query_triples_stream(request): + r = TriplesQueryResponse( + triples=batch, + error=None, + is_final=is_final, + ) + await flow("response").send(r, properties={"id": id}) + logger.debug("Triples query streaming completed") + else: + # Non-streaming mode: single response + triples = await self.query_triples(request) + logger.debug("Sending triples query response...") + r = TriplesQueryResponse(triples=triples, error=None) + await flow("response").send(r, properties={"id": id}) + logger.debug("Triples query request completed") except Exception as e: @@ -76,6 +86,24 @@ class TriplesQueryService(FlowProcessor): await flow("response").send(r, properties={"id": id}) + async def query_triples_stream(self, request): + """ + Streaming query - yields (batch, is_final) tuples. + Default implementation batches results from query_triples. + Override for true streaming from backend. 
+ """ + triples = await self.query_triples(request) + batch_size = request.batch_size if request.batch_size > 0 else 20 + + for i in range(0, len(triples), batch_size): + batch = triples[i:i + batch_size] + is_final = (i + batch_size >= len(triples)) + yield batch, is_final + + # Handle empty result + if len(triples) == 0: + yield [], True + @staticmethod def add_args(parser): diff --git a/trustgraph-base/trustgraph/messaging/translators/triples.py b/trustgraph-base/trustgraph/messaging/translators/triples.py index 2b01b1bc..2f29aa56 100644 --- a/trustgraph-base/trustgraph/messaging/translators/triples.py +++ b/trustgraph-base/trustgraph/messaging/translators/triples.py @@ -23,14 +23,18 @@ class TriplesQueryRequestTranslator(MessageTranslator): g=g, limit=int(data.get("limit", 10000)), user=data.get("user", "trustgraph"), - collection=data.get("collection", "default") + collection=data.get("collection", "default"), + streaming=data.get("streaming", False), + batch_size=int(data.get("batch-size", 20)), ) def from_pulsar(self, obj: TriplesQueryRequest) -> Dict[str, Any]: result = { "limit": obj.limit, "user": obj.user, - "collection": obj.collection + "collection": obj.collection, + "streaming": obj.streaming, + "batch-size": obj.batch_size, } if obj.s: @@ -61,4 +65,4 @@ class TriplesQueryResponseTranslator(MessageTranslator): def from_response_with_completion(self, obj: TriplesQueryResponse) -> Tuple[Dict[str, Any], bool]: """Returns (response_dict, is_final)""" - return self.from_pulsar(obj), True \ No newline at end of file + return self.from_pulsar(obj), obj.is_final \ No newline at end of file diff --git a/trustgraph-base/trustgraph/schema/services/query.py b/trustgraph-base/trustgraph/schema/services/query.py index 67caa2be..7a65f775 100644 --- a/trustgraph-base/trustgraph/schema/services/query.py +++ b/trustgraph-base/trustgraph/schema/services/query.py @@ -38,11 +38,14 @@ class TriplesQueryRequest: o: Term | None = None g: str | None = None # Graph IRI. 
None=default graph, "*"=all graphs limit: int = 0 + streaming: bool = False # Enable streaming mode (multiple batched responses) + batch_size: int = 20 # Triples per batch in streaming mode @dataclass class TriplesQueryResponse: error: Error | None = None triples: list[Triple] = field(default_factory=list) + is_final: bool = True # False for intermediate batches in streaming mode ############################################################################ diff --git a/trustgraph-cli/trustgraph/cli/graph_to_turtle.py b/trustgraph-cli/trustgraph/cli/graph_to_turtle.py index f42fe140..840f8574 100644 --- a/trustgraph-cli/trustgraph/cli/graph_to_turtle.py +++ b/trustgraph-cli/trustgraph/cli/graph_to_turtle.py @@ -1,6 +1,7 @@ """ Connects to the graph query service and dumps all graph edges in Turtle format with RDF-star support for quoted triples. +Uses streaming mode for lower time-to-first-processing. """ import rdflib @@ -9,66 +10,82 @@ import sys import argparse import os -from trustgraph.api import Api, Uri -from trustgraph.knowledge import QuotedTriple +from trustgraph.api import Api default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') default_user = 'trustgraph' default_collection = 'default' +default_token = os.getenv("TRUSTGRAPH_TOKEN", None) -def value_to_rdflib(val): - """Convert a TrustGraph value to an rdflib term.""" - if isinstance(val, Uri): +def term_to_rdflib(term): + """Convert a wire-format term to an rdflib term.""" + if term is None: + return None + + t = term.get("t", "") + + if t == "i": # IRI + iri = term.get("i", "") # Skip malformed URLs with spaces - if " " in val: + if " " in iri: return None - return rdflib.term.URIRef(val) - elif isinstance(val, QuotedTriple): - # RDF-star quoted triple - s_term = value_to_rdflib(val.s) - p_term = value_to_rdflib(val.p) - o_term = value_to_rdflib(val.o) + return rdflib.term.URIRef(iri) + elif t == "l": # Literal + value = term.get("v", "") + datatype = term.get("d") + language = 
term.get("l") + if language: + return rdflib.term.Literal(value, lang=language) + elif datatype: + return rdflib.term.Literal(value, datatype=rdflib.term.URIRef(datatype)) + else: + return rdflib.term.Literal(value) + elif t == "r": # Quoted triple (RDF-star) + triple = term.get("r", {}) + s_term = term_to_rdflib(triple.get("s")) + p_term = term_to_rdflib(triple.get("p")) + o_term = term_to_rdflib(triple.get("o")) if s_term is None or p_term is None or o_term is None: return None - # rdflib 6.x+ supports Triple as a term type try: return rdflib.term.Triple((s_term, p_term, o_term)) except AttributeError: - # Fallback for older rdflib versions - represent as string - return rdflib.term.Literal(f"<<{val.s} {val.p} {val.o}>>") + # Fallback for older rdflib versions + return rdflib.term.Literal(f"<<{s_term} {p_term} {o_term}>>") else: - return rdflib.term.Literal(str(val)) + # Fallback + return rdflib.term.Literal(str(term)) -def show_graph(url, flow_id, user, collection): +def show_graph(url, flow_id, user, collection, limit, batch_size, token=None): - api = Api(url).flow().id(flow_id) - - rows = api.triples_query( - s=None, p=None, o=None, - user=user, collection=collection, - limit=10_000) + socket = Api(url, token=token).socket() + flow = socket.flow(flow_id) g = rdflib.Graph() - for row in rows: + try: + for batch in flow.triples_query_stream( + s=None, p=None, o=None, + user=user, collection=collection, + limit=limit, + batch_size=batch_size, + ): + for triple in batch: + sv = term_to_rdflib(triple.get("s")) + pv = term_to_rdflib(triple.get("p")) + ov = term_to_rdflib(triple.get("o")) - sv = rdflib.term.URIRef(row.s) - pv = rdflib.term.URIRef(row.p) - ov = value_to_rdflib(row.o) + if sv is None or pv is None or ov is None: + continue - if ov is None: - continue - - g.add((sv, pv, ov)) - - g.serialize(destination="output.ttl", format="turtle") + g.add((sv, pv, ov)) + finally: + socket.close() buf = io.BytesIO() - g.serialize(destination=buf, format="turtle") - 
sys.stdout.write(buf.getvalue().decode("utf-8")) @@ -103,6 +120,26 @@ def main(): help=f'Collection ID (default: {default_collection})' ) + parser.add_argument( + '-t', '--token', + default=default_token, + help='Authentication token (default: $TRUSTGRAPH_TOKEN)', + ) + + parser.add_argument( + '-l', '--limit', + type=int, + default=10000, + help='Maximum number of triples to return (default: 10000)', + ) + + parser.add_argument( + '-b', '--batch-size', + type=int, + default=20, + help='Triples per streaming batch (default: 20)', + ) + args = parser.parse_args() try: @@ -112,6 +149,9 @@ def main(): flow_id = args.flow_id, user = args.user, collection = args.collection, + limit = args.limit, + batch_size = args.batch_size, + token = args.token, ) except Exception as e: diff --git a/trustgraph-cli/trustgraph/cli/show_graph.py b/trustgraph-cli/trustgraph/cli/show_graph.py index b5b15e3c..105fe604 100644 --- a/trustgraph-cli/trustgraph/cli/show_graph.py +++ b/trustgraph-cli/trustgraph/cli/show_graph.py @@ -1,5 +1,6 @@ """ Connects to the graph query service and dumps all graph edges. +Uses streaming mode for lower time-to-first-result and reduced memory overhead. 
""" import argparse @@ -11,17 +12,30 @@ default_user = 'trustgraph' default_collection = 'default' default_token = os.getenv("TRUSTGRAPH_TOKEN", None) -def show_graph(url, flow_id, user, collection, token=None): +def show_graph(url, flow_id, user, collection, limit, batch_size, token=None): - api = Api(url, token=token).flow().id(flow_id) + socket = Api(url, token=token).socket() + flow = socket.flow(flow_id) - rows = api.triples_query( - user=user, collection=collection, - s=None, p=None, o=None, limit=10_000, - ) - - for row in rows: - print(row.s, row.p, row.o) + try: + for batch in flow.triples_query_stream( + user=user, + collection=collection, + s=None, p=None, o=None, + limit=limit, + batch_size=batch_size, + ): + for triple in batch: + s = triple.get("s", {}) + p = triple.get("p", {}) + o = triple.get("o", {}) + # Format terms for display + s_str = s.get("v", s.get("i", str(s))) + p_str = p.get("v", p.get("i", str(p))) + o_str = o.get("v", o.get("i", str(o))) + print(s_str, p_str, o_str) + finally: + socket.close() def main(): @@ -60,6 +74,20 @@ def main(): help='Authentication token (default: $TRUSTGRAPH_TOKEN)', ) + parser.add_argument( + '-l', '--limit', + type=int, + default=10000, + help='Maximum number of triples to return (default: 10000)', + ) + + parser.add_argument( + '-b', '--batch-size', + type=int, + default=20, + help='Triples per streaming batch (default: 20)', + ) + args = parser.parse_args() try: @@ -69,6 +97,8 @@ def main(): flow_id = args.flow_id, user = args.user, collection = args.collection, + limit = args.limit, + batch_size = args.batch_size, token = args.token, ) diff --git a/trustgraph-flow/trustgraph/query/triples/cassandra/service.py b/trustgraph-flow/trustgraph/query/triples/cassandra/service.py index 4d4290b1..9cea4f48 100755 --- a/trustgraph-flow/trustgraph/query/triples/cassandra/service.py +++ b/trustgraph-flow/trustgraph/query/triples/cassandra/service.py @@ -7,6 +7,7 @@ null. Output is a list of quads. 
import logging import json +from cassandra.query import SimpleStatement from .... direct.cassandra_kg import ( EntityCentricKnowledgeGraph, GRAPH_WILDCARD, DEFAULT_GRAPH @@ -144,28 +145,30 @@ class Processor(TriplesQueryService): self.cassandra_password = password self.table = None + def ensure_connection(self, user): + """Ensure we have a connection to the correct keyspace.""" + if user != self.table: + KGClass = EntityCentricKnowledgeGraph + + if self.cassandra_username and self.cassandra_password: + self.tg = KGClass( + hosts=self.cassandra_host, + keyspace=user, + username=self.cassandra_username, + password=self.cassandra_password + ) + else: + self.tg = KGClass( + hosts=self.cassandra_host, + keyspace=user, + ) + self.table = user + async def query_triples(self, query): try: - user = query.user - - if user != self.table: - # Use factory function to select implementation - KGClass = EntityCentricKnowledgeGraph - - if self.cassandra_username and self.cassandra_password: - self.tg = KGClass( - hosts=self.cassandra_host, - keyspace=query.user, - username=self.cassandra_username, password=self.cassandra_password - ) - else: - self.tg = KGClass( - hosts=self.cassandra_host, - keyspace=query.user, - ) - self.table = user + self.ensure_connection(query.user) # Extract values from query s_val = get_term_value(query.s) @@ -291,6 +294,93 @@ class Processor(TriplesQueryService): logger.error(f"Exception querying triples: {e}", exc_info=True) raise e + async def query_triples_stream(self, query): + """ + Streaming query - yields (batch, is_final) tuples. + Uses Cassandra's paging to fetch results incrementally. 
+ """ + try: + self.ensure_connection(query.user) + + batch_size = query.batch_size if query.batch_size > 0 else 20 + limit = query.limit if query.limit > 0 else 10000 + + # Extract query pattern + s_val = get_term_value(query.s) + p_val = get_term_value(query.p) + o_val = get_term_value(query.o) + g_val = query.g + + # Helper to extract object metadata from result row + def get_o_metadata(t): + otype = getattr(t, 'otype', None) + dtype = getattr(t, 'dtype', None) + lang = getattr(t, 'lang', None) + return otype, dtype, lang + + # For streaming, we need to execute with fetch_size + # Use the collection table for get_all queries (most common streaming case) + + # Determine which query to use based on pattern + if s_val is None and p_val is None and o_val is None: + # Get all - use collection table with paging + cql = f"SELECT d, s, p, o, otype, dtype, lang FROM {self.tg.collection_table} WHERE collection = %s" + params = [query.collection] + else: + # For specific patterns, fall back to non-streaming + # (these typically return small result sets anyway) + async for batch, is_final in self._fallback_stream(query, batch_size): + yield batch, is_final + return + + # Create statement with fetch_size for true streaming + statement = SimpleStatement(cql, fetch_size=batch_size) + result_set = self.tg.session.execute(statement, params) + + batch = [] + count = 0 + + for row in result_set: + if count >= limit: + break + + g = row.d if hasattr(row, 'd') else DEFAULT_GRAPH + otype, dtype, lang = get_o_metadata(row) + + triple = Triple( + s=create_term(row.s), + p=create_term(row.p), + o=create_term(row.o, otype=otype, dtype=dtype, lang=lang), + g=g if g != DEFAULT_GRAPH else None + ) + batch.append(triple) + count += 1 + + # Yield batch when full (never mark as final mid-stream) + if len(batch) >= batch_size: + yield batch, False + batch = [] + + # Always yield final batch to signal completion + # This handles: remaining rows, empty result set, or exact batch boundary + yield 
batch, True + + except Exception as e: + logger.error(f"Exception in streaming query: {e}", exc_info=True) + raise e + + async def _fallback_stream(self, query, batch_size): + """Fallback to non-streaming query with post-hoc batching.""" + triples = await self.query_triples(query) + + for i in range(0, len(triples), batch_size): + batch = triples[i:i + batch_size] + is_final = (i + batch_size >= len(triples)) + yield batch, is_final + + if len(triples) == 0: + yield [], True + @staticmethod def add_args(parser): diff --git a/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py b/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py index 2bf6b2ea..8dbeb41b 100644 --- a/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py +++ b/trustgraph-flow/trustgraph/retrieval/graph_rag/graph_rag.py @@ -4,11 +4,25 @@ import logging import time from collections import OrderedDict +from ... schema import IRI, LITERAL + # Module logger logger = logging.getLogger(__name__) LABEL="http://www.w3.org/2000/01/rdf-schema#label" + +def term_to_string(term): + """Extract string value from a Term object.""" + if term is None: + return None + if term.type == IRI: + return term.iri + elif term.type == LITERAL: + return term.value + # Fallback + return term.iri or term.value or str(term) + class LRUCacheWithTTL: """LRU cache with TTL for label caching @@ -93,7 +107,7 @@ class Query: ) entities = [ - str(e.entity) + term_to_string(e.entity) for e in entity_matches ] @@ -129,26 +143,29 @@ class Query: return label async def execute_batch_triple_queries(self, entities, limit_per_entity): - """Execute triple queries for multiple entities concurrently""" + """Execute triple queries for multiple entities concurrently using streaming""" tasks = [] for entity in entities: - # Create concurrent tasks for all 3 query types per entity + # Create concurrent streaming tasks for all 3 query types per entity tasks.extend([ - self.rag.triples_client.query( + 
self.rag.triples_client.query_stream( s=entity, p=None, o=None, limit=limit_per_entity, - user=self.user, collection=self.collection + user=self.user, collection=self.collection, + batch_size=20, ), - self.rag.triples_client.query( + self.rag.triples_client.query_stream( s=None, p=entity, o=None, limit=limit_per_entity, - user=self.user, collection=self.collection + user=self.user, collection=self.collection, + batch_size=20, ), - self.rag.triples_client.query( + self.rag.triples_client.query_stream( s=None, p=None, o=entity, limit=limit_per_entity, - user=self.user, collection=self.collection + user=self.user, collection=self.collection, + batch_size=20, ) ]) @@ -158,7 +175,7 @@ class Query: # Combine all results all_triples = [] for result in results: - if not isinstance(result, Exception): + if not isinstance(result, Exception) and result is not None: all_triples.extend(result) return all_triples