def sparql_query_stream(
    self,
    query: str,
    user: str = "trustgraph",
    collection: str = "default",
    limit: int = 10000,
    batch_size: int = 20,
    **kwargs: Any
) -> Iterator[Dict[str, Any]]:
    """Run a SPARQL query in streaming mode, yielding each raw response.

    Args:
        query: SPARQL query text, sent under the ``"query"`` key.
        user: Value sent under the ``"user"`` key.
        collection: Value sent under the ``"collection"`` key.
        limit: Value sent under the ``"limit"`` key.
        batch_size: Value sent under the ``"batch-size"`` key.
        **kwargs: Extra request fields. They are merged in last, so a key
            that collides with one of the fields above replaces it.

    Yields:
        Every item produced by ``self.client._send_request_sync`` for the
        ``"sparql"`` service on this flow, unmodified and in order.
    """
    # Built-in fields first, caller-supplied extras last so extras win.
    payload: Dict[str, Any] = {
        "query": query,
        "user": user,
        "collection": collection,
        "limit": limit,
        "streaming": True,
        "batch-size": batch_size,
        **kwargs,
    }

    responses = self.client._send_request_sync(
        "sparql", self.flow_id, payload, streaming_raw=True
    )
    yield from responses
self.encode(obj), obj.is_final diff --git a/trustgraph-base/trustgraph/schema/services/sparql_query.py b/trustgraph-base/trustgraph/schema/services/sparql_query.py index 105cc753..62c02c93 100644 --- a/trustgraph-base/trustgraph/schema/services/sparql_query.py +++ b/trustgraph-base/trustgraph/schema/services/sparql_query.py @@ -20,6 +20,8 @@ class SparqlQueryRequest: collection: str = "" query: str = "" # SPARQL query string limit: int = 10000 # Safety limit on results + streaming: bool = False # Enable streaming mode + batch_size: int = 20 # Bindings per batch in streaming mode @dataclass class SparqlQueryResponse: @@ -36,5 +38,7 @@ class SparqlQueryResponse: # For CONSTRUCT/DESCRIBE queries triples: list[Triple] = field(default_factory=list) + is_final: bool = True # False for intermediate batches in streaming + sparql_query_request_queue = queue('sparql-query', cls='request') sparql_query_response_queue = queue('sparql-query', cls='response') diff --git a/trustgraph-cli/trustgraph/cli/invoke_sparql_query.py b/trustgraph-cli/trustgraph/cli/invoke_sparql_query.py index 9547193d..82e48456 100644 --- a/trustgraph-cli/trustgraph/cli/invoke_sparql_query.py +++ b/trustgraph-cli/trustgraph/cli/invoke_sparql_query.py @@ -13,85 +13,20 @@ default_user = 'trustgraph' default_collection = 'default' -def format_select(response, output_format): - """Format SELECT query results.""" - variables = response.get("variables", []) - bindings = response.get("bindings", []) - - if not bindings: - return "No results." 
- - if output_format == "json": - rows = [] - for binding in bindings: - row = {} - for var, val in zip(variables, binding.get("values", [])): - if val is None: - row[var] = None - elif val.get("t") == "i": - row[var] = val.get("i", "") - elif val.get("t") == "l": - row[var] = val.get("v", "") - else: - row[var] = val.get("v", val.get("i", "")) - rows.append(row) - return json.dumps(rows, indent=2) - - # Table format - col_widths = [len(v) for v in variables] - rows = [] - for binding in bindings: - row = [] - for i, val in enumerate(binding.get("values", [])): - if val is None: - cell = "" - elif val.get("t") == "i": - cell = val.get("i", "") - elif val.get("t") == "l": - cell = val.get("v", "") - else: - cell = val.get("v", val.get("i", "")) - row.append(cell) - if i < len(col_widths): - col_widths[i] = max(col_widths[i], len(cell)) - rows.append(row) - - # Build table - header = " | ".join( - v.ljust(col_widths[i]) for i, v in enumerate(variables) - ) - separator = "-+-".join("-" * w for w in col_widths) - lines = [header, separator] - for row in rows: - line = " | ".join( - cell.ljust(col_widths[i]) if i < len(col_widths) else cell - for i, cell in enumerate(row) - ) - lines.append(line) - return "\n".join(lines) - - -def format_triples(response, output_format): - """Format CONSTRUCT/DESCRIBE results.""" - triples = response.get("triples", []) - - if not triples: - return "No triples." 
def _term_cell(val):
    """Extract the display string from a wire-format term.

    Returns ``""`` for ``None``, the ``"i"`` field when ``val["t"] == "i"``,
    the ``"v"`` field when ``val["t"] == "l"``, and otherwise ``"v"`` falling
    back to ``"i"`` and then to ``""``.
    """
    if val is None:
        return ""
    t = val.get("t", "")
    if t == "i":
        return val.get("i", "")
    elif t == "l":
        return val.get("v", "")
    return val.get("v", val.get("i", ""))


def sparql_query(url, token, flow_id, query, user, collection, limit,
                 batch_size, output_format):
    """Run a SPARQL query over the streaming socket API and print the result.

    Output depends on the ``"query-type"`` of the responses:

    * ``ask``: prints ``true`` or ``false`` from the first response.
    * ``construct`` / ``describe``: prints the triples of the first
      response, as JSON or as one ``s p o .`` line per triple.
    * anything else (treated as SELECT): bindings from every response are
      accumulated and printed once the stream ends, as a JSON list of
      ``{variable: value}`` objects or as an aligned text table.

    Unbound SELECT values (``None`` on the wire) are emitted as JSON
    ``null`` and as an empty cell in table output.

    Args:
        url: API base URL passed to ``Api``.
        token: Auth token passed to ``Api``.
        flow_id: Flow to run the query against.
        query: SPARQL query text.
        user: User identifier forwarded with the request.
        collection: Collection identifier forwarded with the request.
        limit: Result limit forwarded with the request.
        batch_size: Streaming batch size forwarded with the request.
        output_format: ``"json"`` for JSON output; any other value selects
            the plain-text rendering.
    """
    socket = Api(url=url, token=token).socket()

    try:
        # Looked up inside the try so the socket is closed even if this fails.
        flow = socket.flow(flow_id)

        variables = None
        rows = []

        for response in flow.sparql_query_stream(
            query=query,
            user=user,
            collection=collection,
            limit=limit,
            batch_size=batch_size,
        ):
            query_type = response.get("query-type", "select")

            # ASK: a single boolean answer, nothing to accumulate.
            if query_type == "ask":
                print("true" if response.get("ask-result") else "false")
                return

            # CONSTRUCT/DESCRIBE: print the triples of this response.
            if query_type in ("construct", "describe"):
                triples = response.get("triples", [])
                if not triples:
                    print("No triples.")
                elif output_format == "json":
                    print(json.dumps(triples, indent=2))
                else:
                    for t in triples:
                        s = _term_str(t.get("s"))
                        p = _term_str(t.get("p"))
                        o = _term_str(t.get("o"))
                        print(f"{s} {p} {o} .")
                return

            # SELECT: the variable list is taken from the first batch.
            if variables is None:
                variables = response.get("variables", [])

            # Keep None for unbound values so JSON output can emit null;
            # the table renderer blanks them below.
            rows.extend(
                [
                    None if v is None else _term_cell(v)
                    for v in binding.get("values", [])
                ]
                for binding in response.get("bindings", [])
            )

        # No responses at all, or responses without any bindings.
        if not rows:
            print("No results.")
            return

        if output_format == "json":
            print(json.dumps(
                [dict(zip(variables, row)) for row in rows], indent=2
            ))
            return

        # Table format: unbound values become empty cells.
        table = [["" if cell is None else cell for cell in row]
                 for row in rows]

        col_widths = [len(v) for v in variables]
        for row in table:
            # Cells beyond the declared variables do not affect the widths.
            for i, cell in enumerate(row[:len(col_widths)]):
                col_widths[i] = max(col_widths[i], len(cell))

        print(" | ".join(
            v.ljust(w) for v, w in zip(variables, col_widths)
        ))
        print("-+-".join("-" * w for w in col_widths))
        for row in table:
            print(" | ".join(
                cell.ljust(col_widths[i]) if i < len(col_widths) else cell
                for i, cell in enumerate(row)
            ))

    finally:
        socket.close()
async def send_streaming(self, response, flow, id, request):
    """Send the bindings of ``response`` as a sequence of batch messages.

    ``response.bindings`` is cut into consecutive slices of
    ``request.batch_size`` entries (20 when that value is not positive).
    Each slice is sent on ``flow("response")`` as its own
    ``SparqlQueryResponse`` carrying the original ``query_type`` and
    ``variables``; only the last slice has ``is_final=True``.  When there
    are no bindings, a single empty message with ``is_final=True`` is sent.

    Args:
        response: Result whose ``bindings`` are split into batches.
        flow: Callable returning the producer for the ``"response"`` queue.
        id: Request id attached to every message as the ``"id"`` property.
        request: Originating request; supplies ``batch_size``.
    """
    rows = response.bindings
    total = len(rows)

    step = request.batch_size
    if step <= 0:
        # Non-positive sizes fall back to the default batch size.
        step = 20

    if total == 0:
        # Nothing to batch: emit one terminating message and stop.
        empty = SparqlQueryResponse(
            query_type=response.query_type,
            variables=response.variables,
            bindings=[],
            is_final=True,
        )
        await flow("response").send(empty, properties={"id": id})
        return

    for start in range(0, total, step):
        end = start + step
        chunk = SparqlQueryResponse(
            query_type=response.query_type,
            variables=response.variables,
            bindings=rows[start:end],
            # The slice reaching (or passing) the end closes the stream.
            is_final=end >= total,
        )
        await flow("response").send(chunk, properties={"id": id})