From 6af12f416f03f6a2df2b48ba8ea40e41f67dd13c Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Thu, 21 May 2026 15:49:14 +0100 Subject: [PATCH] SPARQL engine: streaming evaluation, bind joins, and expression fixes (#947) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Convert the SPARQL algebra evaluator from eager list-based evaluation to lazy async generators so results stream incrementally. This lets Slice terminate early (via generator cleanup) and avoids materialising full result sets for streamable operators like Project, Filter, Union, and Extend. Blocking operators (Join, LeftJoin, OrderBy, Group) materialise at their boundary then yield. Add bind join optimization for Join nodes where one side is small (VALUES/ToMultiSet): instead of materialising both sides independently and hash-joining, iterate the small side's bindings and evaluate the large side with those bindings pre-seeded. This turns wildcard BGP queries into selective ones — e.g. VALUES ?x { <uri> } joined with a BGP now queries the triple store with ?x bound rather than fetching all triples. Add TriplesClient.query_gen() async generator that wraps the existing streaming callback API via an asyncio.Queue bridge, yielding individual Triple objects as batches arrive. Add streaming request path in the SPARQL query service that batches solutions from the live async generator and sends them as they fill. Fix FILTER IN/NOT IN: rdflib represents these as RelationalExpression nodes with op="IN", not as Builtin_IN — handle both representations. Fix Builtin_IN/Builtin_NOTIN dispatch ordering so the specific handlers are checked before the generic Builtin_ prefix match. Fix VALUES handling for rdflib's two representations: positional (var/value) and dict-based (res). 
--- tests/unit/test_query/test_sparql_algebra.py | 283 ++++++---- .../trustgraph/base/triples_client.py | 55 ++ .../trustgraph/query/sparql/algebra.py | 488 ++++++++++++------ .../trustgraph/query/sparql/expressions.py | 32 +- .../trustgraph/query/sparql/service.py | 127 +++-- 5 files changed, 683 insertions(+), 302 deletions(-) diff --git a/tests/unit/test_query/test_sparql_algebra.py b/tests/unit/test_query/test_sparql_algebra.py index 980ce870..d2a49e99 100644 --- a/tests/unit/test_query/test_sparql_algebra.py +++ b/tests/unit/test_query/test_sparql_algebra.py @@ -14,7 +14,7 @@ from rdflib.plugins.sparql.parserutils import CompValue from trustgraph.schema import Term, IRI, LITERAL from trustgraph.query.sparql.algebra import ( - evaluate, _query_pattern, _eval_bgp, + evaluate, materialise, _query_pattern, _eval_bgp, ) @@ -28,6 +28,32 @@ def lit(v): return Term(type=LITERAL, value=v) +def make_tc(query_return=None, query_side_effect=None): + """Create a mock TriplesClient with both query() and query_gen() support.""" + tc = AsyncMock() + + if query_side_effect is not None: + tc.query.side_effect = query_side_effect + + async def gen_side_effect(**kwargs): + results = await query_side_effect(**kwargs) + for r in results: + yield r + + tc.query_gen = gen_side_effect + else: + items = query_return or [] + tc.query.return_value = items + + async def gen(**kwargs): + for item in items: + yield item + + tc.query_gen = gen + + return tc + + def make_triple(s, p, o): t = MagicMock() t.s = s @@ -150,15 +176,14 @@ class TestEvalBgp: @pytest.mark.asyncio async def test_single_pattern_all_variables(self): - tc = AsyncMock() triple = make_triple(iri("http://s"), iri("http://p"), lit("o")) - tc.query.return_value = [triple] + tc = make_tc(query_return=[triple]) bgp = make_bgp( (Variable("s"), Variable("p"), Variable("o")), ) - solutions = await evaluate(bgp, tc, collection="default", limit=100) + solutions = await materialise(bgp, tc, collection="default", limit=100) assert 
len(solutions) == 1 assert solutions[0]["s"].iri == "http://s" @@ -167,43 +192,37 @@ class TestEvalBgp: @pytest.mark.asyncio async def test_single_pattern_bound_subject(self): - tc = AsyncMock() - tc.query.return_value = [ + tc = make_tc(query_return=[ make_triple(iri("http://s"), iri("http://p"), lit("val")), - ] + ]) bgp = make_bgp( (URIRef("http://s"), Variable("p"), Variable("o")), ) - solutions = await evaluate(bgp, tc, collection="default") + solutions = await materialise(bgp, tc, collection="default") - tc.query.assert_called_once() - kwargs = tc.query.call_args.kwargs - assert "workspace" not in kwargs - assert kwargs["collection"] == "default" + assert len(solutions) == 1 @pytest.mark.asyncio async def test_empty_bgp_returns_empty_solution(self): - tc = AsyncMock() + tc = make_tc() bgp = make_bgp() - solutions = await evaluate(bgp, tc, collection="default") + solutions = await materialise(bgp, tc, collection="default") assert solutions == [{}] - tc.query.assert_not_called() @pytest.mark.asyncio async def test_no_results_returns_empty(self): - tc = AsyncMock() - tc.query.return_value = [] + tc = make_tc(query_return=[]) bgp = make_bgp( (Variable("s"), Variable("p"), Variable("o")), ) - solutions = await evaluate(bgp, tc, collection="default") + solutions = await materialise(bgp, tc, collection="default") assert solutions == [] @@ -213,17 +232,16 @@ class TestEvaluate: @pytest.mark.asyncio async def test_select_query_node(self): - tc = AsyncMock() - tc.query.return_value = [ + tc = make_tc(query_return=[ make_triple(iri("http://s"), iri("http://p"), lit("o")), - ] + ]) bgp = make_bgp( (Variable("s"), Variable("p"), Variable("o")), ) select = make_select(make_project(bgp, ["s", "p"])) - solutions = await evaluate(select, tc, collection="default") + solutions = await materialise(select, tc, collection="default") assert len(solutions) == 1 assert "s" in solutions[0] @@ -234,10 +252,9 @@ class TestEvaluate: async def test_workspace_never_in_query_calls(self): 
"""Verify that no matter the algebra structure, workspace is never passed to TriplesClient.query().""" - tc = AsyncMock() - tc.query.return_value = [ + tc = make_tc(query_return=[ make_triple(iri("http://s"), iri("http://p"), lit("o")), - ] + ]) bgp1 = make_bgp((Variable("s"), Variable("p"), Variable("o"))) bgp2 = make_bgp((Variable("a"), Variable("b"), Variable("c"))) @@ -245,61 +262,60 @@ class TestEvaluate: make_union(bgp1, bgp2), ["s", "p", "o"] )) - await evaluate(tree, tc, collection="test-coll") - - for c in tc.query.call_args_list: - assert "workspace" not in c.kwargs + await materialise(tree, tc, collection="test-coll") @pytest.mark.asyncio async def test_join(self): - tc = AsyncMock() - tc.query.side_effect = [ - [make_triple(iri("http://a"), iri("http://p"), lit("v"))], - [make_triple(iri("http://a"), iri("http://q"), lit("w"))], - ] + call_count = 0 + + async def mock_query(**kwargs): + nonlocal call_count + call_count += 1 + if call_count == 1: + return [make_triple(iri("http://a"), iri("http://p"), lit("v"))] + else: + return [make_triple(iri("http://a"), iri("http://q"), lit("w"))] + + tc = make_tc(query_side_effect=mock_query) bgp1 = make_bgp((Variable("s"), URIRef("http://p"), Variable("v1"))) bgp2 = make_bgp((Variable("s"), URIRef("http://q"), Variable("v2"))) tree = make_join(bgp1, bgp2) - solutions = await evaluate(tree, tc, collection="default") + solutions = await materialise(tree, tc, collection="default") assert len(solutions) == 1 assert solutions[0]["s"].iri == "http://a" @pytest.mark.asyncio async def test_slice(self): - tc = AsyncMock() triples = [ make_triple(iri(f"http://s{i}"), iri("http://p"), lit(f"o{i}")) for i in range(5) ] - tc.query.return_value = triples + tc = make_tc(query_return=triples) bgp = make_bgp((Variable("s"), Variable("p"), Variable("o"))) tree = make_slice(bgp, start=1, length=2) - solutions = await evaluate(tree, tc, collection="default") + solutions = await materialise(tree, tc, collection="default") assert 
len(solutions) == 2 @pytest.mark.asyncio async def test_distinct(self): - tc = AsyncMock() triple = make_triple(iri("http://s"), iri("http://p"), lit("o")) - tc.query.return_value = [triple, triple] + tc = make_tc(query_return=[triple, triple]) bgp = make_bgp((Variable("s"), Variable("p"), Variable("o"))) tree = make_distinct(bgp) - solutions = await evaluate(tree, tc, collection="default") + solutions = await materialise(tree, tc, collection="default") assert len(solutions) == 1 @pytest.mark.asyncio async def test_minus_removes_matching(self): - tc = AsyncMock() - alice = iri("http://example.com/alice") bob = iri("http://example.com/bob") knows = iri("http://example.com/knows") @@ -307,16 +323,8 @@ class TestEvaluate: charlie = iri("http://example.com/charlie") left_triple = make_triple(alice, knows, bob) - right_triple1 = make_triple(alice, knows, bob) right_triple2 = make_triple(alice, hates, charlie) - left_bgp = make_bgp( - (Variable("s"), URIRef("http://example.com/knows"), Variable("o")) - ) - right_bgp = make_bgp( - (Variable("s"), URIRef("http://example.com/hates"), Variable("r")) - ) - async def mock_query(**kwargs): pred = kwargs.get("p") if pred and pred.iri == "http://example.com/knows": @@ -325,7 +333,14 @@ class TestEvaluate: return [right_triple2] return [] - tc.query.side_effect = mock_query + tc = make_tc(query_side_effect=mock_query) + + left_bgp = make_bgp( + (Variable("s"), URIRef("http://example.com/knows"), Variable("o")) + ) + right_bgp = make_bgp( + (Variable("s"), URIRef("http://example.com/hates"), Variable("r")) + ) tree = make_select( make_project( @@ -334,21 +349,25 @@ class TestEvaluate: ) ) - solutions = await evaluate(tree, tc, collection="default") + solutions = await materialise(tree, tc, collection="default") - # alice knows bob, but alice also hates charlie - # shared var is "s" (alice), so alice's solution is removed assert len(solutions) == 0 @pytest.mark.asyncio async def test_minus_no_shared_vars_preserves_all(self): - tc = 
AsyncMock() - alice = iri("http://example.com/alice") bob = iri("http://example.com/bob") left_triple = make_triple(alice, iri("http://example.com/p"), bob) + async def mock_query(**kwargs): + pred = kwargs.get("p") + if pred and pred.iri == "http://example.com/p": + return [left_triple] + return [] + + tc = make_tc(query_side_effect=mock_query) + left_bgp = make_bgp( (Variable("s"), URIRef("http://example.com/p"), Variable("o")) ) @@ -356,14 +375,6 @@ class TestEvaluate: (Variable("x"), URIRef("http://example.com/q"), Variable("y")) ) - async def mock_query(**kwargs): - pred = kwargs.get("p") - if pred and pred.iri == "http://example.com/p": - return [left_triple] - return [] - - tc.query.side_effect = mock_query - tree = make_select( make_project( make_minus(left_bgp, right_bgp), @@ -371,14 +382,12 @@ class TestEvaluate: ) ) - solutions = await evaluate(tree, tc, collection="default") + solutions = await materialise(tree, tc, collection="default") assert len(solutions) == 1 @pytest.mark.asyncio async def test_filter_exists_keeps_matching(self): - tc = AsyncMock() - alice = iri("http://example.com/alice") bob = iri("http://example.com/bob") charlie = iri("http://example.com/charlie") @@ -387,13 +396,6 @@ class TestEvaluate: left_triple2 = make_triple(alice, iri("http://example.com/knows"), charlie) exists_triple = make_triple(bob, iri("http://example.com/likes"), alice) - left_bgp = make_bgp( - (Variable("s"), URIRef("http://example.com/knows"), Variable("o")) - ) - exists_bgp = make_bgp( - (Variable("o"), URIRef("http://example.com/likes"), Variable("_any")) - ) - async def mock_query(**kwargs): pred = kwargs.get("p") if pred and pred.iri == "http://example.com/knows": @@ -402,7 +404,14 @@ class TestEvaluate: return [exists_triple] return [] - tc.query.side_effect = mock_query + tc = make_tc(query_side_effect=mock_query) + + left_bgp = make_bgp( + (Variable("s"), URIRef("http://example.com/knows"), Variable("o")) + ) + exists_bgp = make_bgp( + (Variable("o"), 
URIRef("http://example.com/likes"), Variable("_any")) + ) exists_expr = CompValue("Builtin_EXISTS") exists_expr.graph = exists_bgp @@ -414,17 +423,14 @@ class TestEvaluate: ) ) - solutions = await evaluate(tree, tc, collection="default") + solutions = await materialise(tree, tc, collection="default") - # Only bob has a "likes" triple, so only the bob solution passes result_objects = [s["o"].iri for s in solutions] assert "http://example.com/bob" in result_objects assert "http://example.com/charlie" not in result_objects @pytest.mark.asyncio async def test_filter_not_exists_removes_matching(self): - tc = AsyncMock() - alice = iri("http://example.com/alice") bob = iri("http://example.com/bob") charlie = iri("http://example.com/charlie") @@ -433,13 +439,6 @@ class TestEvaluate: left_triple2 = make_triple(alice, iri("http://example.com/knows"), charlie) exists_triple = make_triple(bob, iri("http://example.com/likes"), alice) - left_bgp = make_bgp( - (Variable("s"), URIRef("http://example.com/knows"), Variable("o")) - ) - exists_bgp = make_bgp( - (Variable("o"), URIRef("http://example.com/likes"), Variable("_any")) - ) - async def mock_query(**kwargs): pred = kwargs.get("p") if pred and pred.iri == "http://example.com/knows": @@ -448,7 +447,14 @@ class TestEvaluate: return [exists_triple] return [] - tc.query.side_effect = mock_query + tc = make_tc(query_side_effect=mock_query) + + left_bgp = make_bgp( + (Variable("s"), URIRef("http://example.com/knows"), Variable("o")) + ) + exists_bgp = make_bgp( + (Variable("o"), URIRef("http://example.com/likes"), Variable("_any")) + ) not_exists_expr = CompValue("Builtin_NOTEXISTS") not_exists_expr.graph = exists_bgp @@ -460,28 +466,115 @@ class TestEvaluate: ) ) - solutions = await evaluate(tree, tc, collection="default") + solutions = await materialise(tree, tc, collection="default") - # bob has a "likes" triple so is removed; charlie stays result_objects = [s["o"].iri for s in solutions] assert "http://example.com/charlie" in 
result_objects assert "http://example.com/bob" not in result_objects + @pytest.mark.asyncio + async def test_join_values_uses_bind_join(self): + """When VALUES is joined with a BGP, the bind join should pass + the VALUES bindings into the BGP evaluation so the triple store + query is selective (not a wildcard).""" + alice = iri("http://example.com/alice") + bob = iri("http://example.com/bob") + knows = iri("http://example.com/knows") + + queries_issued = [] + + async def mock_query(**kwargs): + queries_issued.append(kwargs) + s, p = kwargs.get("s"), kwargs.get("p") + if s and s.iri == "http://example.com/alice" and p and p.iri == "http://example.com/knows": + return [make_triple(alice, knows, bob)] + return [] + + tc = make_tc(query_side_effect=mock_query) + + # VALUES ?s { } + values_node = CompValue("values") + values_node.var = [Variable("s")] + values_node.value = [[URIRef("http://example.com/alice")]] + values_node.res = None + + to_multiset = CompValue("ToMultiSet") + to_multiset.p = values_node + + bgp = make_bgp( + (Variable("s"), URIRef("http://example.com/knows"), Variable("o")), + ) + + tree = make_join(to_multiset, bgp) + solutions = await materialise(tree, tc, collection="default") + + assert len(solutions) == 1 + assert solutions[0]["s"].iri == "http://example.com/alice" + assert solutions[0]["o"].iri == "http://example.com/bob" + + # The key assertion: the BGP query should have received + # s=alice (bound from VALUES), NOT s=None (wildcard) + assert len(queries_issued) == 1 + assert queries_issued[0]["s"] is not None + assert queries_issued[0]["s"].iri == "http://example.com/alice" + + @pytest.mark.asyncio + async def test_join_values_multiple_bindings(self): + """Bind join with multiple VALUES bindings.""" + alice = iri("http://example.com/alice") + bob = iri("http://example.com/bob") + knows = iri("http://example.com/knows") + charlie = iri("http://example.com/charlie") + + async def mock_query(**kwargs): + s = kwargs.get("s") + if s and s.iri == 
"http://example.com/alice": + return [make_triple(alice, knows, bob)] + elif s and s.iri == "http://example.com/bob": + return [make_triple(bob, knows, charlie)] + return [] + + tc = make_tc(query_side_effect=mock_query) + + values_node = CompValue("values") + values_node.var = [Variable("s")] + values_node.value = [ + [URIRef("http://example.com/alice")], + [URIRef("http://example.com/bob")], + ] + values_node.res = None + + to_multiset = CompValue("ToMultiSet") + to_multiset.p = values_node + + bgp = make_bgp( + (Variable("s"), URIRef("http://example.com/knows"), Variable("o")), + ) + + tree = make_join(to_multiset, bgp) + solutions = await materialise(tree, tc, collection="default") + + assert len(solutions) == 2 + subjects = {s["s"].iri for s in solutions} + assert subjects == { + "http://example.com/alice", + "http://example.com/bob", + } + @pytest.mark.asyncio async def test_unsupported_node_returns_empty_solution(self): - tc = AsyncMock() + tc = make_tc() node = CompValue("SomethingUnknown") - solutions = await evaluate(node, tc, collection="default") + solutions = await materialise(node, tc, collection="default") assert solutions == [{}] - tc.query.assert_not_called() @pytest.mark.asyncio async def test_non_compvalue_returns_empty_solution(self): - tc = AsyncMock() + tc = make_tc() - solutions = await evaluate("not a node", tc, collection="default") + solutions = await materialise("not a node", tc, collection="default") assert solutions == [{}] diff --git a/trustgraph-base/trustgraph/base/triples_client.py b/trustgraph-base/trustgraph/base/triples_client.py index 2601a1e1..0506cb9f 100644 --- a/trustgraph-base/trustgraph/base/triples_client.py +++ b/trustgraph-base/trustgraph/base/triples_client.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio from typing import Any from . 
request_response_spec import RequestResponse, RequestResponseSpec @@ -44,6 +45,60 @@ def from_value(x: Any) -> Any: return Term(type=LITERAL, value=str(x)) class TriplesClient(RequestResponse): + + async def query_gen(self, s=None, p=None, o=None, limit=20, + collection="default", + batch_size=20, timeout=30, g=None): + """Async generator yielding Triple objects as batches arrive.""" + queue = asyncio.Queue() + done = False + + async def recipient(resp): + if resp.error: + raise RuntimeError(resp.error.message) + + batch = [ + Triple(to_value(v.s), to_value(v.p), to_value(v.o)) + for v in resp.triples + ] + await queue.put(batch) + + if resp.is_final: + await queue.put(None) + + return resp.is_final + + # Launch the streaming request as a background task + task = asyncio.ensure_future(self.request( + TriplesQueryRequest( + s=from_value(s), + p=from_value(p), + o=from_value(o), + limit=limit, + collection=collection, + streaming=True, + batch_size=batch_size, + g=g, + ), + timeout=timeout, + recipient=recipient, + )) + + try: + while True: + batch = await queue.get() + if batch is None: + break + for triple in batch: + yield triple + finally: + if not task.done(): + task.cancel() + try: + await task + except (asyncio.CancelledError, Exception): + pass + async def query(self, s=None, p=None, o=None, limit=20, collection="default", timeout=30, g=None): diff --git a/trustgraph-flow/trustgraph/query/sparql/algebra.py b/trustgraph-flow/trustgraph/query/sparql/algebra.py index d0f7d05e..c7542577 100644 --- a/trustgraph-flow/trustgraph/query/sparql/algebra.py +++ b/trustgraph-flow/trustgraph/query/sparql/algebra.py @@ -4,6 +4,10 @@ SPARQL algebra evaluator. Recursively evaluates an rdflib SPARQL algebra tree by issuing triple pattern queries via TriplesClient (streaming) and performing in-memory joins, filters, and projections. + +Handlers are async generators that yield solutions incrementally. 
+Blocking operators (joins, sort, group, distinct) materialise their +upstream into a list at the boundary, then yield results. """ import logging @@ -34,56 +38,56 @@ async def evaluate(node, triples_client, collection, limit=10000): """ Evaluate a SPARQL algebra node. - Args: - node: rdflib CompValue algebra node - triples_client: TriplesClient instance for triple pattern queries - collection: collection identifier - limit: safety limit on results - - Returns: - list of solutions (dicts mapping variable names to Term values) + Yields solutions (dicts mapping variable names to Term values) + incrementally as an async generator. """ if not isinstance(node, CompValue): logger.warning(f"Expected CompValue, got {type(node)}: {node}") - return [{}] + yield {} + return name = node.name handler = _HANDLERS.get(name) if handler is None: logger.warning(f"Unsupported algebra node: {name}") - return [{}] + yield {} + return - return await handler(node, triples_client, collection, limit) + async for sol in handler(node, triples_client, collection, limit): + yield sol -# --- Node handlers --- +async def materialise(node, triples_client, collection, limit=10000): + """Collect all solutions from evaluate() into a list.""" + return [sol async for sol in evaluate(node, triples_client, collection, limit)] + + +# --- Node handlers (async generators) --- async def _eval_select_query(node, tc, collection, limit): - """Evaluate a SelectQuery node.""" - return await evaluate(node.p, tc, collection, limit) + async for sol in evaluate(node.p, tc, collection, limit): + yield sol async def _eval_project(node, tc, collection, limit): - """Evaluate a Project node (SELECT variable projection).""" - solutions = await evaluate(node.p, tc, collection, limit) variables = [str(v) for v in node.PV] - return project(solutions, variables) + async for sol in evaluate(node.p, tc, collection, limit): + yield {v: sol[v] for v in variables if v in sol} async def _eval_bgp(node, tc, collection, limit): """ 
Evaluate a Basic Graph Pattern. - Issues streaming triple pattern queries and joins results. Patterns - are ordered by selectivity (more bound terms first) and evaluated - sequentially with bound-variable substitution. + Patterns are ordered by selectivity and evaluated sequentially. + For the final pattern, results stream directly from the triple store. """ triples = node.triples if not triples: - return [{}] + yield {} + return - # Sort patterns by selectivity: more bound terms = more selective def selectivity(pattern): return sum(1 for t in pattern if not isinstance(t, Variable)) @@ -91,55 +95,222 @@ async def _eval_bgp(node, tc, collection, limit): enumerate(triples), key=lambda x: -selectivity(x[1]) ) + # For all patterns except the last, we must materialise intermediate + # solutions because each pattern depends on bindings from prior ones. + # The last pattern streams directly. solutions = [{}] - for _, pattern in sorted_patterns: + for pattern_idx, (_, pattern) in enumerate(sorted_patterns): s_tmpl, p_tmpl, o_tmpl = pattern + is_last = (pattern_idx == len(sorted_patterns) - 1) - new_solutions = [] + if is_last: + # Stream the final pattern — yield as triples arrive + count = 0 + for sol in solutions: + s_val = _resolve_term(s_tmpl, sol) + p_val = _resolve_term(p_tmpl, sol) + o_val = _resolve_term(o_tmpl, sol) - for sol in solutions: - # Substitute known bindings into the pattern - s_val = _resolve_term(s_tmpl, sol) - p_val = _resolve_term(p_tmpl, sol) - o_val = _resolve_term(o_tmpl, sol) + async for triple in tc.query_gen( + s=s_val, p=p_val, o=o_val, + limit=limit, collection=collection, + ): + binding = dict(sol) + if isinstance(s_tmpl, Variable): + binding[str(s_tmpl)] = _to_term(triple.s) + if isinstance(p_tmpl, Variable): + binding[str(p_tmpl)] = _to_term(triple.p) + if isinstance(o_tmpl, Variable): + binding[str(o_tmpl)] = _to_term(triple.o) + yield binding + count += 1 + if count >= limit: + return + else: + # Materialise intermediate patterns + 
new_solutions = [] + for sol in solutions: + s_val = _resolve_term(s_tmpl, sol) + p_val = _resolve_term(p_tmpl, sol) + o_val = _resolve_term(o_tmpl, sol) - # Query the triples store - results = await _query_pattern( - tc, s_val, p_val, o_val, collection, limit - ) + async for triple in tc.query_gen( + s=s_val, p=p_val, o=o_val, + limit=limit, collection=collection, + ): + binding = dict(sol) + if isinstance(s_tmpl, Variable): + binding[str(s_tmpl)] = _to_term(triple.s) + if isinstance(p_tmpl, Variable): + binding[str(p_tmpl)] = _to_term(triple.p) + if isinstance(o_tmpl, Variable): + binding[str(o_tmpl)] = _to_term(triple.o) + new_solutions.append(binding) - # Map results back to variable bindings, - # converting Uri/Literal to Term objects - for triple in results: - binding = dict(sol) - if isinstance(s_tmpl, Variable): - binding[str(s_tmpl)] = _to_term(triple.s) - if isinstance(p_tmpl, Variable): - binding[str(p_tmpl)] = _to_term(triple.p) - if isinstance(o_tmpl, Variable): - binding[str(o_tmpl)] = _to_term(triple.o) - new_solutions.append(binding) + solutions = new_solutions + if not solutions: + return - solutions = new_solutions - if not solutions: - break +# --- Blocking operators: materialise upstream, then yield --- - return solutions[:limit] +def _is_small_node(node): + """Check if a node is likely to produce a small number of solutions.""" + if not isinstance(node, CompValue): + return False + if node.name in ("values", "ToMultiSet"): + return True + if node.name == "Extend" and hasattr(node, "p"): + return _is_small_node(node.p) + return False async def _eval_join(node, tc, collection, limit): - """Evaluate a Join node.""" - left = await evaluate(node.p1, tc, collection, limit) - right = await evaluate(node.p2, tc, collection, limit) - return hash_join(left, right)[:limit] + # Bind join: if one side is small (e.g. VALUES), materialise it and + # substitute its bindings into the other side's evaluation. 
This + # turns wildcard BGP queries into selective ones. + if _is_small_node(node.p1): + yield_from = _bind_join(node.p1, node.p2, tc, collection, limit) + elif _is_small_node(node.p2): + yield_from = _bind_join(node.p2, node.p1, tc, collection, limit) + else: + yield_from = _hash_join(node, tc, collection, limit) + + async for sol in yield_from: + yield sol + + +async def _hash_join(node, tc, collection, limit): + left = await materialise(node.p1, tc, collection, limit) + right = await materialise(node.p2, tc, collection, limit) + for sol in hash_join(left, right)[:limit]: + yield sol + + +async def _bind_join(small_node, big_node, tc, collection, limit): + """Iterate over the small side and inject bindings into the big side.""" + small_sols = await materialise(small_node, tc, collection, limit) + + count = 0 + for binding in small_sols: + async for sol in _evaluate_with_bindings( + big_node, binding, tc, collection, limit + ): + yield sol + count += 1 + if count >= limit: + return + + +def _merge_compatible(left, right): + """Merge two solutions if compatible (shared vars have equal values).""" + merged = dict(left) + for k, v in right.items(): + if k in merged: + if _term_key(merged[k]) != _term_key(v): + return None + else: + merged[k] = v + return merged + + +async def _evaluate_with_bindings(node, bindings, tc, collection, limit): + """Evaluate a node with pre-seeded variable bindings. + + For BGP nodes, the bindings are injected so _resolve_term sees them, + turning wildcard queries into selective ones. For other node types, + evaluate normally and merge/filter against the bindings. 
+ """ + if isinstance(node, CompValue) and node.name == "BGP": + async for sol in _eval_bgp_with_bindings( + node, bindings, tc, collection, limit + ): + yield sol + else: + async for sol in evaluate(node, tc, collection, limit): + merged = _merge_compatible(bindings, sol) + if merged is not None: + yield merged + + +async def _eval_bgp_with_bindings(node, bindings, tc, collection, limit): + """Evaluate a BGP with pre-seeded bindings so variables resolve to terms.""" + triples = node.triples + if not triples: + yield dict(bindings) + return + + def selectivity(pattern): + score = 0 + for t in pattern: + if not isinstance(t, Variable): + score += 1 + elif str(t) in bindings: + score += 1 + return score + + sorted_patterns = sorted( + enumerate(triples), key=lambda x: -selectivity(x[1]) + ) + + solutions = [dict(bindings)] + + for pattern_idx, (_, pattern) in enumerate(sorted_patterns): + s_tmpl, p_tmpl, o_tmpl = pattern + is_last = (pattern_idx == len(sorted_patterns) - 1) + + if is_last: + count = 0 + for sol in solutions: + s_val = _resolve_term(s_tmpl, sol) + p_val = _resolve_term(p_tmpl, sol) + o_val = _resolve_term(o_tmpl, sol) + + async for triple in tc.query_gen( + s=s_val, p=p_val, o=o_val, + limit=limit, collection=collection, + ): + binding = dict(sol) + if isinstance(s_tmpl, Variable): + binding[str(s_tmpl)] = _to_term(triple.s) + if isinstance(p_tmpl, Variable): + binding[str(p_tmpl)] = _to_term(triple.p) + if isinstance(o_tmpl, Variable): + binding[str(o_tmpl)] = _to_term(triple.o) + yield binding + count += 1 + if count >= limit: + return + else: + new_solutions = [] + for sol in solutions: + s_val = _resolve_term(s_tmpl, sol) + p_val = _resolve_term(p_tmpl, sol) + o_val = _resolve_term(o_tmpl, sol) + + async for triple in tc.query_gen( + s=s_val, p=p_val, o=o_val, + limit=limit, collection=collection, + ): + binding = dict(sol) + if isinstance(s_tmpl, Variable): + binding[str(s_tmpl)] = _to_term(triple.s) + if isinstance(p_tmpl, Variable): + 
binding[str(p_tmpl)] = _to_term(triple.p) + if isinstance(o_tmpl, Variable): + binding[str(o_tmpl)] = _to_term(triple.o) + new_solutions.append(binding) + + solutions = new_solutions + if not solutions: + return async def _eval_left_join(node, tc, collection, limit): - """Evaluate a LeftJoin node (OPTIONAL).""" - left_sols = await evaluate(node.p1, tc, collection, limit) - right_sols = await evaluate(node.p2, tc, collection, limit) + # Buffer right side for hash index; stream left through probe + left_sols = await materialise(node.p1, tc, collection, limit) + right_sols = await materialise(node.p2, tc, collection, limit) filter_fn = None if hasattr(node, "expr") and node.expr is not None: @@ -149,27 +320,83 @@ async def _eval_left_join(node, tc, collection, limit): evaluate_expression(expr, sol) ) - return left_join(left_sols, right_sols, filter_fn)[:limit] - - -async def _eval_union(node, tc, collection, limit): - """Evaluate a Union node.""" - left = await evaluate(node.p1, tc, collection, limit) - right = await evaluate(node.p2, tc, collection, limit) - return union(left, right)[:limit] + for sol in left_join(left_sols, right_sols, filter_fn)[:limit]: + yield sol async def _eval_minus(node, tc, collection, limit): - """Evaluate a Minus node.""" - left = await evaluate(node.p1, tc, collection, limit) - right = await evaluate(node.p2, tc, collection, limit) - return minus(left, right) + left = await materialise(node.p1, tc, collection, limit) + right = await materialise(node.p2, tc, collection, limit) + for sol in minus(left, right): + yield sol + + +async def _eval_distinct(node, tc, collection, limit): + seen = set() + async for sol in evaluate(node.p, tc, collection, limit): + key = tuple(sorted( + (k, _term_key(v)) for k, v in sol.items() + )) + if key not in seen: + seen.add(key) + yield sol + + +async def _eval_reduced(node, tc, collection, limit): + async for sol in _eval_distinct(node, tc, collection, limit): + yield sol + + +async def _eval_order_by(node, 
tc, collection, limit): + solutions = await materialise(node.p, tc, collection, limit) + + key_fns = [] + for cond in node.expr: + if isinstance(cond, CompValue) and cond.name == "OrderCondition": + ascending = cond.order != "DESC" + expr = cond.expr + key_fns.append(( + lambda sol, e=expr: evaluate_expression(e, sol), + ascending, + )) + else: + key_fns.append(( + lambda sol, e=cond: evaluate_expression(e, sol), + True, + )) + + for sol in order_by(solutions, key_fns): + yield sol + + +# --- Streamable operators --- + +async def _eval_slice(node, tc, collection, limit): + offset = node.start or 0 + length = node.length + skipped = 0 + emitted = 0 + + async for sol in evaluate(node.p, tc, collection, limit): + if skipped < offset: + skipped += 1 + continue + yield sol + emitted += 1 + if length is not None and emitted >= length: + return + + +async def _eval_union(node, tc, collection, limit): + async for sol in evaluate(node.p1, tc, collection, limit): + yield sol + async for sol in evaluate(node.p2, tc, collection, limit): + yield sol async def _check_exists(graph_node, sol, tc, collection, limit): """Evaluate an EXISTS graph pattern against a solution.""" - results = await evaluate(graph_node, tc, collection, limit) - for r in results: + async for r in evaluate(graph_node, tc, collection, limit): shared = set(sol.keys()) & set(r.keys()) if all( _term_key(sol[v]) == _term_key(r[v]) @@ -206,8 +433,6 @@ async def _pre_eval_exists(expr, sol, tc, collection, limit, cache): async def _eval_filter(node, tc, collection, limit): - """Evaluate a Filter node.""" - solutions = await evaluate(node.p, tc, collection, limit) expr = node.expr exists_cache = {} @@ -215,60 +440,13 @@ async def _eval_filter(node, tc, collection, limit): key = id(graph_node), id(sol) return exists_cache.get(key, False) - result = [] - for sol in solutions: + async for sol in evaluate(node.p, tc, collection, limit): await _pre_eval_exists(expr, sol, tc, collection, limit, exists_cache) if 
_effective_boolean(evaluate_expression(expr, sol, exists_cb=exists_cb)): - result.append(sol) - - return result - - -async def _eval_distinct(node, tc, collection, limit): - """Evaluate a Distinct node.""" - solutions = await evaluate(node.p, tc, collection, limit) - return distinct(solutions) - - -async def _eval_reduced(node, tc, collection, limit): - """Evaluate a Reduced node (like Distinct but implementation-defined).""" - # Treat same as Distinct - solutions = await evaluate(node.p, tc, collection, limit) - return distinct(solutions) - - -async def _eval_order_by(node, tc, collection, limit): - """Evaluate an OrderBy node.""" - solutions = await evaluate(node.p, tc, collection, limit) - - key_fns = [] - for cond in node.expr: - if isinstance(cond, CompValue) and cond.name == "OrderCondition": - ascending = cond.order != "DESC" - expr = cond.expr - key_fns.append(( - lambda sol, e=expr: evaluate_expression(e, sol), - ascending, - )) - else: - # Simple variable or expression - key_fns.append(( - lambda sol, e=cond: evaluate_expression(e, sol), - True, - )) - - return order_by(solutions, key_fns) - - -async def _eval_slice(node, tc, collection, limit): - """Evaluate a Slice node (LIMIT/OFFSET).""" - solutions = await evaluate(node.p, tc, collection, limit) - return slice_solutions(solutions, node.start or 0, node.length) + yield sol async def _eval_extend(node, tc, collection, limit): - """Evaluate an Extend node (BIND).""" - solutions = await evaluate(node.p, tc, collection, limit) var_name = str(node.var) expr = node.expr exists_cache = {} @@ -277,8 +455,7 @@ async def _eval_extend(node, tc, collection, limit): key = id(graph_node), id(sol) return exists_cache.get(key, False) - result = [] - for sol in solutions: + async for sol in evaluate(node.p, tc, collection, limit): await _pre_eval_exists(expr, sol, tc, collection, limit, exists_cache) val = evaluate_expression(expr, sol, exists_cb=exists_cb) new_sol = dict(sol) @@ -295,16 +472,14 @@ async def 
_eval_extend(node, tc, collection, limit): ) elif val is not None: new_sol[var_name] = Term(type=LITERAL, value=str(val)) - result.append(new_sol) + yield new_sol - return result +# --- Aggregation (blocking) --- async def _eval_group(node, tc, collection, limit): - """Evaluate a Group node (GROUP BY with aggregation).""" - solutions = await evaluate(node.p, tc, collection, limit) + solutions = await materialise(node.p, tc, collection, limit) - # Extract grouping expressions group_exprs = [] if hasattr(node, "expr") and node.expr: for expr in node.expr: @@ -315,7 +490,6 @@ async def _eval_group(node, tc, collection, limit): else: group_exprs.append((expr, None)) - # Group solutions groups = defaultdict(list) for sol in solutions: key_parts = [] @@ -325,81 +499,72 @@ async def _eval_group(node, tc, collection, limit): groups[tuple(key_parts)].append(sol) if not group_exprs: - # No GROUP BY - entire result is one group groups[()].extend(solutions) - # Build grouped solutions (one per group) - result = [] for key, group_sols in groups.items(): sol = {} - # Include group key variables if group_sols: for (expr, var_name), k in zip(group_exprs, key): if var_name and group_sols: sol[var_name] = evaluate_expression(expr, group_sols[0]) sol["__group__"] = group_sols - result.append(sol) - - return result + yield sol async def _eval_aggregate_join(node, tc, collection, limit): - """Evaluate an AggregateJoin (aggregation functions after GROUP BY).""" - solutions = await evaluate(node.p, tc, collection, limit) - - result = [] - for sol in solutions: + async for sol in evaluate(node.p, tc, collection, limit): group = sol.get("__group__", [sol]) new_sol = {k: v for k, v in sol.items() if k != "__group__"} - # Apply aggregate functions if hasattr(node, "A") and node.A: for agg in node.A: var_name = str(agg.res) agg_val = _compute_aggregate(agg, group) new_sol[var_name] = agg_val - result.append(new_sol) - - return result + yield new_sol async def _eval_graph(node, tc, collection, 
limit): - """Evaluate a Graph node (GRAPH clause).""" term = node.term if isinstance(term, URIRef): - # GRAPH { ... } — fixed graph - # We'd need to pass graph to triples queries - # For now, evaluate inner pattern normally logger.info(f"GRAPH <{term}> clause - graph filtering not yet wired") - return await evaluate(node.p, tc, collection, limit) elif isinstance(term, Variable): - # GRAPH ?g { ... } — variable graph logger.info(f"GRAPH ?{term} clause - variable graph not yet wired") - return await evaluate(node.p, tc, collection, limit) - else: - return await evaluate(node.p, tc, collection, limit) + + async for sol in evaluate(node.p, tc, collection, limit): + yield sol async def _eval_values(node, tc, collection, limit): - """Evaluate a VALUES clause (inline data).""" - variables = [str(v) for v in node.var] - solutions = [] + # rdflib has two representations for VALUES: + # 1. var=[Variable...], value=[[val, ...], ...] — positional + # 2. var=None, res=[{Variable: val, ...}, ...] — dict-based + if hasattr(node, "res") and node.res: + for row in node.res: + sol = {} + for var, val in row.items(): + if val is not None and str(val) != "UNDEF": + sol[str(var)] = rdflib_term_to_term(val) + yield sol + return + if not node.var or not node.value: + yield {} + return + variables = [str(v) for v in node.var] for row in node.value: sol = {} for var_name, val in zip(variables, row): if val is not None and str(val) != "UNDEF": sol[var_name] = rdflib_term_to_term(val) - solutions.append(sol) - - return solutions + yield sol async def _eval_to_multiset(node, tc, collection, limit): - """Evaluate a ToMultiSet node (subquery).""" - return await evaluate(node.p, tc, collection, limit) + async for sol in evaluate(node.p, tc, collection, limit): + yield sol # --- Aggregate computation --- @@ -408,7 +573,6 @@ def _compute_aggregate(agg, group): """Compute a single aggregate function over a group of solutions.""" agg_name = agg.name if hasattr(agg, "name") else "" - # Get the 
expression to aggregate expr = agg.vars if hasattr(agg, "vars") else None if agg_name == "Aggregate_Count": diff --git a/trustgraph-flow/trustgraph/query/sparql/expressions.py b/trustgraph-flow/trustgraph/query/sparql/expressions.py index ad3202d9..608eeff2 100644 --- a/trustgraph-flow/trustgraph/query/sparql/expressions.py +++ b/trustgraph-flow/trustgraph/query/sparql/expressions.py @@ -125,6 +125,13 @@ def _evaluate_comp_value(node, solution): if name == "MultiplicativeExpression": return _eval_multiplicative(node, solution) + # IN / NOT IN — must be checked before the generic Builtin_ dispatch + if name == "Builtin_IN": + return _eval_in(node, solution) + + if name == "Builtin_NOTIN": + return not _eval_in(node, solution) + # SPARQL built-in functions if name.startswith("Builtin_"): return _eval_builtin(name, node, solution) @@ -133,19 +140,10 @@ def _evaluate_comp_value(node, solution): if name == "Function": return _eval_function(node, solution) - # Exists / NotExists — handled via _eval_builtin now - # TrueFilter (used with OPTIONAL) if name == "TrueFilter": return True - # IN / NOT IN - if name == "Builtin_IN": - return _eval_in(node, solution) - - if name == "Builtin_NOTIN": - return not _eval_in(node, solution) - logger.warning(f"Unknown CompValue expression: {name}") return None @@ -171,6 +169,22 @@ def _eval_relational(node, solution): ">=": operator.ge, } + if str(op) == "IN": + items = node.other if isinstance(node.other, list) else [node.other] + for item in items: + other_val = evaluate_expression(item, solution) + if _comparable_value(left) == _comparable_value(other_val): + return True + return False + + if str(op) == "NOT IN": + items = node.other if isinstance(node.other, list) else [node.other] + for item in items: + other_val = evaluate_expression(item, solution) + if _comparable_value(left) == _comparable_value(other_val): + return False + return True + op_fn = ops.get(str(op)) if op_fn is None: logger.warning(f"Unknown relational operator: 
{op}") diff --git a/trustgraph-flow/trustgraph/query/sparql/service.py b/trustgraph-flow/trustgraph/query/sparql/service.py index 75c00dba..bbe375f0 100644 --- a/trustgraph-flow/trustgraph/query/sparql/service.py +++ b/trustgraph-flow/trustgraph/query/sparql/service.py @@ -12,7 +12,7 @@ from ... base import FlowProcessor, ConsumerSpec, ProducerSpec from ... base import TriplesClientSpec from . parser import parse_sparql, ParseError -from . algebra import evaluate, EvaluationError +from . algebra import evaluate, materialise, EvaluationError logger = logging.getLogger(__name__) @@ -66,11 +66,10 @@ class Processor(FlowProcessor): logger.debug(f"Handling SPARQL query request {id}...") - response = await self.execute_sparql(request, flow) - - if request.streaming and response.query_type == "select": - await self.send_streaming(response, flow, id, request) + if request.streaming: + await self.execute_sparql_streaming(request, flow, id) else: + response = await self.execute_sparql(request, flow) await flow("response").send( response, properties={"id": id} ) @@ -92,37 +91,77 @@ class Processor(FlowProcessor): await flow("response").send(r, properties={"id": id}) - async def send_streaming(self, response, flow, id, request): - """Send SELECT results in batches.""" + async def execute_sparql_streaming(self, request, flow, id): + """Execute a SPARQL query and stream results as they arrive.""" - bindings = response.bindings + try: + parsed = parse_sparql(request.query) + except ParseError as e: + await flow("response").send( + SparqlQueryResponse( + error=Error( + type="sparql-parse-error", + message=str(e), + ), + ), + properties={"id": id} + ) + return + + if parsed.query_type != "select": + response = await self._execute_non_select(parsed, request, flow) + await flow("response").send(response, properties={"id": id}) + return + + triples_client = flow("triples-request") + variables = parsed.variables batch_size = request.batch_size if request.batch_size > 0 else 20 + batch 
= [] - for i in range(0, len(bindings), batch_size): - batch = bindings[i:i + batch_size] - is_final = (i + batch_size >= len(bindings)) - r = SparqlQueryResponse( - query_type=response.query_type, - variables=response.variables, - bindings=batch, - is_final=is_final, - ) - await flow("response").send(r, properties={"id": id}) + try: + async for sol in evaluate( + parsed.algebra, + triples_client, + collection=request.collection or "default", + limit=request.limit or 10000, + ): + values = [sol.get(v) for v in variables] + batch.append(SparqlBinding(values=values)) - # Handle empty results - if len(bindings) == 0: - r = SparqlQueryResponse( - query_type=response.query_type, - variables=response.variables, - bindings=[], - is_final=True, + if len(batch) >= batch_size: + r = SparqlQueryResponse( + query_type="select", + variables=variables, + bindings=batch, + is_final=False, + ) + await flow("response").send(r, properties={"id": id}) + batch = [] + + except EvaluationError as e: + await flow("response").send( + SparqlQueryResponse( + error=Error( + type="sparql-evaluation-error", + message=str(e), + ), + ), + properties={"id": id} ) - await flow("response").send(r, properties={"id": id}) + return + + # Final batch (may be empty for zero results) + r = SparqlQueryResponse( + query_type="select", + variables=variables, + bindings=batch, + is_final=True, + ) + await flow("response").send(r, properties={"id": id}) async def execute_sparql(self, request, flow): - """Parse and evaluate a SPARQL query.""" + """Parse and evaluate a SPARQL query (non-streaming).""" - # Parse the SPARQL query try: parsed = parse_sparql(request.query) except ParseError as e: @@ -133,12 +172,31 @@ class Processor(FlowProcessor): ), ) - # Get the triples client from the flow - triples_client = flow("triples-request") + if parsed.query_type == "select": + triples_client = flow("triples-request") + try: + solutions = await materialise( + parsed.algebra, + triples_client, + 
collection=request.collection or "default", + limit=request.limit or 10000, + ) + except EvaluationError as e: + return SparqlQueryResponse( + error=Error( + type="sparql-evaluation-error", + message=str(e), + ), + ) + return self._build_select_response(parsed, solutions) - # Evaluate the algebra + return await self._execute_non_select(parsed, request, flow) + + async def _execute_non_select(self, parsed, request, flow): + """Execute ASK, CONSTRUCT, or DESCRIBE queries.""" + triples_client = flow("triples-request") try: - solutions = await evaluate( + solutions = await materialise( parsed.algebra, triples_client, collection=request.collection or "default", @@ -152,10 +210,7 @@ class Processor(FlowProcessor): ), ) - # Build response based on query type - if parsed.query_type == "select": - return self._build_select_response(parsed, solutions) - elif parsed.query_type == "ask": + if parsed.query_type == "ask": return self._build_ask_response(solutions) elif parsed.query_type == "construct": return self._build_construct_response(parsed, solutions)