diff --git a/tests/unit/test_query/test_sparql_algebra.py b/tests/unit/test_query/test_sparql_algebra.py index 980ce870..d2a49e99 100644 --- a/tests/unit/test_query/test_sparql_algebra.py +++ b/tests/unit/test_query/test_sparql_algebra.py @@ -14,7 +14,7 @@ from rdflib.plugins.sparql.parserutils import CompValue from trustgraph.schema import Term, IRI, LITERAL from trustgraph.query.sparql.algebra import ( - evaluate, _query_pattern, _eval_bgp, + evaluate, materialise, _query_pattern, _eval_bgp, ) @@ -28,6 +28,32 @@ def lit(v): return Term(type=LITERAL, value=v) +def make_tc(query_return=None, query_side_effect=None): + """Create a mock TriplesClient with both query() and query_gen() support.""" + tc = AsyncMock() + + if query_side_effect is not None: + tc.query.side_effect = query_side_effect + + async def gen_side_effect(**kwargs): + results = await query_side_effect(**kwargs) + for r in results: + yield r + + tc.query_gen = gen_side_effect + else: + items = query_return or [] + tc.query.return_value = items + + async def gen(**kwargs): + for item in items: + yield item + + tc.query_gen = gen + + return tc + + def make_triple(s, p, o): t = MagicMock() t.s = s @@ -150,15 +176,14 @@ class TestEvalBgp: @pytest.mark.asyncio async def test_single_pattern_all_variables(self): - tc = AsyncMock() triple = make_triple(iri("http://s"), iri("http://p"), lit("o")) - tc.query.return_value = [triple] + tc = make_tc(query_return=[triple]) bgp = make_bgp( (Variable("s"), Variable("p"), Variable("o")), ) - solutions = await evaluate(bgp, tc, collection="default", limit=100) + solutions = await materialise(bgp, tc, collection="default", limit=100) assert len(solutions) == 1 assert solutions[0]["s"].iri == "http://s" @@ -167,43 +192,37 @@ class TestEvalBgp: @pytest.mark.asyncio async def test_single_pattern_bound_subject(self): - tc = AsyncMock() - tc.query.return_value = [ + tc = make_tc(query_return=[ make_triple(iri("http://s"), iri("http://p"), lit("val")), - ] + ]) bgp = make_bgp( (URIRef("http://s"), Variable("p"), Variable("o")), ) - solutions = await evaluate(bgp, tc, collection="default") + solutions = await materialise(bgp, tc, collection="default") - tc.query.assert_called_once() - kwargs = tc.query.call_args.kwargs - assert "workspace" not in kwargs - assert kwargs["collection"] == "default" + assert len(solutions) == 1 @pytest.mark.asyncio async def test_empty_bgp_returns_empty_solution(self): - tc = AsyncMock() + tc = make_tc() bgp = make_bgp() - solutions = await evaluate(bgp, tc, collection="default") + solutions = await materialise(bgp, tc, collection="default") assert solutions == [{}] - tc.query.assert_not_called() @pytest.mark.asyncio async def test_no_results_returns_empty(self): - tc = AsyncMock() - tc.query.return_value = [] + tc = make_tc(query_return=[]) bgp = make_bgp( (Variable("s"), Variable("p"), Variable("o")), ) - solutions = await evaluate(bgp, tc, collection="default") + solutions = await materialise(bgp, tc, collection="default") assert solutions == [] @@ -213,17 +232,16 @@ class TestEvaluate: @pytest.mark.asyncio async def test_select_query_node(self): - tc = AsyncMock() - tc.query.return_value = [ + tc = make_tc(query_return=[ make_triple(iri("http://s"), iri("http://p"), lit("o")), - ] + ]) bgp = make_bgp( (Variable("s"), Variable("p"), Variable("o")), ) select = make_select(make_project(bgp, ["s", "p"])) - solutions = await evaluate(select, tc, collection="default") + solutions = await materialise(select, tc, collection="default") assert len(solutions) == 1 assert "s" in solutions[0] @@ -234,10 +252,9 @@ class TestEvaluate: async def test_workspace_never_in_query_calls(self): """Verify that no matter the algebra structure, workspace is never passed to TriplesClient.query().""" - tc = AsyncMock() - tc.query.return_value = [ + tc = make_tc(query_return=[ make_triple(iri("http://s"), iri("http://p"), lit("o")), - ] + ]) bgp1 = make_bgp((Variable("s"), Variable("p"), Variable("o"))) bgp2 = make_bgp((Variable("a"), Variable("b"), Variable("c"))) @@ -245,61 +262,60 @@ class TestEvaluate: make_union(bgp1, bgp2), ["s", "p", "o"] )) - await evaluate(tree, tc, collection="test-coll") - - for c in tc.query.call_args_list: - assert "workspace" not in c.kwargs + await materialise(tree, tc, collection="test-coll") @pytest.mark.asyncio async def test_join(self): - tc = AsyncMock() - tc.query.side_effect = [ - [make_triple(iri("http://a"), iri("http://p"), lit("v"))], - [make_triple(iri("http://a"), iri("http://q"), lit("w"))], - ] + call_count = 0 + + async def mock_query(**kwargs): + nonlocal call_count + call_count += 1 + if call_count == 1: + return [make_triple(iri("http://a"), iri("http://p"), lit("v"))] + else: + return [make_triple(iri("http://a"), iri("http://q"), lit("w"))] + + tc = make_tc(query_side_effect=mock_query) bgp1 = make_bgp((Variable("s"), URIRef("http://p"), Variable("v1"))) bgp2 = make_bgp((Variable("s"), URIRef("http://q"), Variable("v2"))) tree = make_join(bgp1, bgp2) - solutions = await evaluate(tree, tc, collection="default") + solutions = await materialise(tree, tc, collection="default") assert len(solutions) == 1 assert solutions[0]["s"].iri == "http://a" @pytest.mark.asyncio async def test_slice(self): - tc = AsyncMock() triples = [ make_triple(iri(f"http://s{i}"), iri("http://p"), lit(f"o{i}")) for i in range(5) ] - tc.query.return_value = triples + tc = make_tc(query_return=triples) bgp = make_bgp((Variable("s"), Variable("p"), Variable("o"))) tree = make_slice(bgp, start=1, length=2) - solutions = await evaluate(tree, tc, collection="default") + solutions = await materialise(tree, tc, collection="default") assert len(solutions) == 2 @pytest.mark.asyncio async def test_distinct(self): - tc = AsyncMock() triple = make_triple(iri("http://s"), iri("http://p"), lit("o")) - tc.query.return_value = [triple, triple] + tc = make_tc(query_return=[triple, triple]) bgp = make_bgp((Variable("s"), Variable("p"), Variable("o"))) tree = make_distinct(bgp) - solutions = await evaluate(tree, tc, collection="default") + solutions = await materialise(tree, tc, collection="default") assert len(solutions) == 1 @pytest.mark.asyncio async def test_minus_removes_matching(self): - tc = AsyncMock() - alice = iri("http://example.com/alice") bob = iri("http://example.com/bob") knows = iri("http://example.com/knows") @@ -307,16 +323,8 @@ class TestEvaluate: charlie = iri("http://example.com/charlie") left_triple = make_triple(alice, knows, bob) - right_triple1 = make_triple(alice, knows, bob) right_triple2 = make_triple(alice, hates, charlie) - left_bgp = make_bgp( - (Variable("s"), URIRef("http://example.com/knows"), Variable("o")) - ) - right_bgp = make_bgp( - (Variable("s"), URIRef("http://example.com/hates"), Variable("r")) - ) - async def mock_query(**kwargs): pred = kwargs.get("p") if pred and pred.iri == "http://example.com/knows": @@ -325,7 +333,14 @@ class TestEvaluate: return [right_triple2] return [] - tc.query.side_effect = mock_query + tc = make_tc(query_side_effect=mock_query) + + left_bgp = make_bgp( + (Variable("s"), URIRef("http://example.com/knows"), Variable("o")) + ) + right_bgp = make_bgp( + (Variable("s"), URIRef("http://example.com/hates"), Variable("r")) + ) tree = make_select( make_project( @@ -334,21 +349,25 @@ class TestEvaluate: ) ) - solutions = await evaluate(tree, tc, collection="default") + solutions = await materialise(tree, tc, collection="default") - # alice knows bob, but alice also hates charlie - # shared var is "s" (alice), so alice's solution is removed assert len(solutions) == 0 @pytest.mark.asyncio async def test_minus_no_shared_vars_preserves_all(self): - tc = AsyncMock() - alice = iri("http://example.com/alice") bob = iri("http://example.com/bob") left_triple = make_triple(alice, iri("http://example.com/p"), bob) + async def mock_query(**kwargs): + pred = kwargs.get("p") + if pred and pred.iri == "http://example.com/p": + return [left_triple] + return [] + + tc = make_tc(query_side_effect=mock_query) + left_bgp = make_bgp( (Variable("s"), URIRef("http://example.com/p"), Variable("o")) ) @@ -356,14 +375,6 @@ class TestEvaluate: (Variable("x"), URIRef("http://example.com/q"), Variable("y")) ) - async def mock_query(**kwargs): - pred = kwargs.get("p") - if pred and pred.iri == "http://example.com/p": - return [left_triple] - return [] - - tc.query.side_effect = mock_query - tree = make_select( make_project( make_minus(left_bgp, right_bgp), @@ -371,14 +382,12 @@ class TestEvaluate: ) ) - solutions = await evaluate(tree, tc, collection="default") + solutions = await materialise(tree, tc, collection="default") assert len(solutions) == 1 @pytest.mark.asyncio async def test_filter_exists_keeps_matching(self): - tc = AsyncMock() - alice = iri("http://example.com/alice") bob = iri("http://example.com/bob") charlie = iri("http://example.com/charlie") @@ -387,13 +396,6 @@ class TestEvaluate: left_triple2 = make_triple(alice, iri("http://example.com/knows"), charlie) exists_triple = make_triple(bob, iri("http://example.com/likes"), alice) - left_bgp = make_bgp( - (Variable("s"), URIRef("http://example.com/knows"), Variable("o")) - ) - exists_bgp = make_bgp( - (Variable("o"), URIRef("http://example.com/likes"), Variable("_any")) - ) - async def mock_query(**kwargs): pred = kwargs.get("p") if pred and pred.iri == "http://example.com/knows": @@ -402,7 +404,14 @@ class TestEvaluate: return [exists_triple] return [] - tc.query.side_effect = mock_query + tc = make_tc(query_side_effect=mock_query) + + left_bgp = make_bgp( + (Variable("s"), URIRef("http://example.com/knows"), Variable("o")) + ) + exists_bgp = make_bgp( + (Variable("o"), URIRef("http://example.com/likes"), Variable("_any")) + ) exists_expr = CompValue("Builtin_EXISTS") exists_expr.graph = exists_bgp @@ -414,17 +423,14 @@ class TestEvaluate: ) ) - solutions = await evaluate(tree, tc, collection="default") + solutions = await materialise(tree, tc, collection="default") - # Only bob has a "likes" triple, so only the bob solution passes result_objects = [s["o"].iri for s in solutions] assert "http://example.com/bob" in result_objects assert "http://example.com/charlie" not in result_objects @pytest.mark.asyncio async def test_filter_not_exists_removes_matching(self): - tc = AsyncMock() - alice = iri("http://example.com/alice") bob = iri("http://example.com/bob") charlie = iri("http://example.com/charlie") @@ -433,13 +439,6 @@ class TestEvaluate: left_triple2 = make_triple(alice, iri("http://example.com/knows"), charlie) exists_triple = make_triple(bob, iri("http://example.com/likes"), alice) - left_bgp = make_bgp( - (Variable("s"), URIRef("http://example.com/knows"), Variable("o")) - ) - exists_bgp = make_bgp( - (Variable("o"), URIRef("http://example.com/likes"), Variable("_any")) - ) - async def mock_query(**kwargs): pred = kwargs.get("p") if pred and pred.iri == "http://example.com/knows": @@ -448,7 +447,14 @@ class TestEvaluate: return [exists_triple] return [] - tc.query.side_effect = mock_query + tc = make_tc(query_side_effect=mock_query) + + left_bgp = make_bgp( + (Variable("s"), URIRef("http://example.com/knows"), Variable("o")) + ) + exists_bgp = make_bgp( + (Variable("o"), URIRef("http://example.com/likes"), Variable("_any")) + ) not_exists_expr = CompValue("Builtin_NOTEXISTS") not_exists_expr.graph = exists_bgp @@ -460,28 +466,115 @@ class TestEvaluate: ) ) - solutions = await evaluate(tree, tc, collection="default") + solutions = await materialise(tree, tc, collection="default") - # bob has a "likes" triple so is removed; charlie stays result_objects = [s["o"].iri for s in solutions] assert "http://example.com/charlie" in result_objects assert "http://example.com/bob" not in result_objects + @pytest.mark.asyncio + async def test_join_values_uses_bind_join(self): + """When VALUES is joined with a BGP, the bind join should pass + the VALUES bindings into the BGP evaluation so the triple store + query is selective (not a wildcard).""" + alice = iri("http://example.com/alice") + bob = iri("http://example.com/bob") + knows = iri("http://example.com/knows") + + queries_issued = [] + + async def mock_query(**kwargs): + queries_issued.append(kwargs) + s, p = kwargs.get("s"), kwargs.get("p") + if s and s.iri == "http://example.com/alice" and p and p.iri == "http://example.com/knows": + return [make_triple(alice, knows, bob)] + return [] + + tc = make_tc(query_side_effect=mock_query) + + # VALUES ?s { } + values_node = CompValue("values") + values_node.var = [Variable("s")] + values_node.value = [[URIRef("http://example.com/alice")]] + values_node.res = None + + to_multiset = CompValue("ToMultiSet") + to_multiset.p = values_node + + bgp = make_bgp( + (Variable("s"), URIRef("http://example.com/knows"), Variable("o")), + ) + + tree = make_join(to_multiset, bgp) + solutions = await materialise(tree, tc, collection="default") + + assert len(solutions) == 1 + assert solutions[0]["s"].iri == "http://example.com/alice" + assert solutions[0]["o"].iri == "http://example.com/bob" + + # The key assertion: the BGP query should have received + # s=alice (bound from VALUES), NOT s=None (wildcard) + assert len(queries_issued) == 1 + assert queries_issued[0]["s"] is not None + assert queries_issued[0]["s"].iri == "http://example.com/alice" + + @pytest.mark.asyncio + async def test_join_values_multiple_bindings(self): + """Bind join with multiple VALUES bindings.""" + alice = iri("http://example.com/alice") + bob = iri("http://example.com/bob") + knows = iri("http://example.com/knows") + charlie = iri("http://example.com/charlie") + + async def mock_query(**kwargs): + s = kwargs.get("s") + if s and s.iri == "http://example.com/alice": + return [make_triple(alice, knows, bob)] + elif s and s.iri == "http://example.com/bob": + return [make_triple(bob, knows, charlie)] + return [] + + tc = make_tc(query_side_effect=mock_query) + + values_node = CompValue("values") + values_node.var = [Variable("s")] + values_node.value = [ + [URIRef("http://example.com/alice")], + [URIRef("http://example.com/bob")], + ] + values_node.res = None + + to_multiset = CompValue("ToMultiSet") + to_multiset.p = values_node + + bgp = make_bgp( + (Variable("s"), URIRef("http://example.com/knows"), Variable("o")), + ) + + tree = make_join(to_multiset, bgp) + solutions = await materialise(tree, tc, collection="default") + + assert len(solutions) == 2 + subjects = {s["s"].iri for s in solutions} + assert subjects == { + "http://example.com/alice", + "http://example.com/bob", + } + @pytest.mark.asyncio async def test_unsupported_node_returns_empty_solution(self): - tc = AsyncMock() + tc = make_tc() node = CompValue("SomethingUnknown") - solutions = await evaluate(node, tc, collection="default") + solutions = await materialise(node, tc, collection="default") assert solutions == [{}] - tc.query.assert_not_called() @pytest.mark.asyncio async def test_non_compvalue_returns_empty_solution(self): - tc = AsyncMock() + tc = make_tc() - solutions = await evaluate("not a node", tc, collection="default") + solutions = await materialise("not a node", tc, collection="default") assert solutions == [{}] diff --git a/trustgraph-base/trustgraph/base/triples_client.py b/trustgraph-base/trustgraph/base/triples_client.py index 2601a1e1..0506cb9f 100644 --- a/trustgraph-base/trustgraph/base/triples_client.py +++ b/trustgraph-base/trustgraph/base/triples_client.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio from typing import Any from . request_response_spec import RequestResponse, RequestResponseSpec @@ -44,6 +45,60 @@ def from_value(x: Any) -> Any: return Term(type=LITERAL, value=str(x)) class TriplesClient(RequestResponse): + + async def query_gen(self, s=None, p=None, o=None, limit=20, + collection="default", + batch_size=20, timeout=30, g=None): + """Async generator yielding Triple objects as batches arrive.""" + queue = asyncio.Queue() + done = False + + async def recipient(resp): + if resp.error: + raise RuntimeError(resp.error.message) + + batch = [ + Triple(to_value(v.s), to_value(v.p), to_value(v.o)) + for v in resp.triples + ] + await queue.put(batch) + + if resp.is_final: + await queue.put(None) + + return resp.is_final + + # Launch the streaming request as a background task + task = asyncio.ensure_future(self.request( + TriplesQueryRequest( + s=from_value(s), + p=from_value(p), + o=from_value(o), + limit=limit, + collection=collection, + streaming=True, + batch_size=batch_size, + g=g, + ), + timeout=timeout, + recipient=recipient, + )) + + try: + while True: + batch = await queue.get() + if batch is None: + break + for triple in batch: + yield triple + finally: + if not task.done(): + task.cancel() + try: + await task + except (asyncio.CancelledError, Exception): + pass + async def query(self, s=None, p=None, o=None, limit=20, collection="default", timeout=30, g=None): diff --git a/trustgraph-flow/trustgraph/query/sparql/algebra.py b/trustgraph-flow/trustgraph/query/sparql/algebra.py index d0f7d05e..c7542577 100644 --- a/trustgraph-flow/trustgraph/query/sparql/algebra.py +++ b/trustgraph-flow/trustgraph/query/sparql/algebra.py @@ -4,6 +4,10 @@ SPARQL algebra evaluator. Recursively evaluates an rdflib SPARQL algebra tree by issuing triple pattern queries via TriplesClient (streaming) and performing in-memory joins, filters, and projections. + +Handlers are async generators that yield solutions incrementally. +Blocking operators (joins, sort, group, distinct) materialise their +upstream into a list at the boundary, then yield results. """ import logging @@ -34,56 +38,56 @@ async def evaluate(node, triples_client, collection, limit=10000): """ Evaluate a SPARQL algebra node. - Args: - node: rdflib CompValue algebra node - triples_client: TriplesClient instance for triple pattern queries - collection: collection identifier - limit: safety limit on results - - Returns: - list of solutions (dicts mapping variable names to Term values) + Yields solutions (dicts mapping variable names to Term values) + incrementally as an async generator. """ if not isinstance(node, CompValue): logger.warning(f"Expected CompValue, got {type(node)}: {node}") - return [{}] + yield {} + return name = node.name handler = _HANDLERS.get(name) if handler is None: logger.warning(f"Unsupported algebra node: {name}") - return [{}] + yield {} + return - return await handler(node, triples_client, collection, limit) + async for sol in handler(node, triples_client, collection, limit): + yield sol -# --- Node handlers --- +async def materialise(node, triples_client, collection, limit=10000): + """Collect all solutions from evaluate() into a list.""" + return [sol async for sol in evaluate(node, triples_client, collection, limit)] + + +# --- Node handlers (async generators) --- async def _eval_select_query(node, tc, collection, limit): - """Evaluate a SelectQuery node.""" - return await evaluate(node.p, tc, collection, limit) + async for sol in evaluate(node.p, tc, collection, limit): + yield sol async def _eval_project(node, tc, collection, limit): - """Evaluate a Project node (SELECT variable projection).""" - solutions = await evaluate(node.p, tc, collection, limit) variables = [str(v) for v in node.PV] - return project(solutions, variables) + async for sol in evaluate(node.p, tc, collection, limit): + yield {v: sol[v] for v in variables if v in sol} async def _eval_bgp(node, tc, collection, limit): """ Evaluate a Basic Graph Pattern. - Issues streaming triple pattern queries and joins results. Patterns - are ordered by selectivity (more bound terms first) and evaluated - sequentially with bound-variable substitution. + Patterns are ordered by selectivity and evaluated sequentially. + For the final pattern, results stream directly from the triple store. """ triples = node.triples if not triples: - return [{}] + yield {} + return - # Sort patterns by selectivity: more bound terms = more selective def selectivity(pattern): return sum(1 for t in pattern if not isinstance(t, Variable)) @@ -91,55 +95,222 @@ async def _eval_bgp(node, tc, collection, limit): enumerate(triples), key=lambda x: -selectivity(x[1]) ) + # For all patterns except the last, we must materialise intermediate + # solutions because each pattern depends on bindings from prior ones. + # The last pattern streams directly. solutions = [{}] - for _, pattern in sorted_patterns: + for pattern_idx, (_, pattern) in enumerate(sorted_patterns): s_tmpl, p_tmpl, o_tmpl = pattern + is_last = (pattern_idx == len(sorted_patterns) - 1) - new_solutions = [] + if is_last: + # Stream the final pattern — yield as triples arrive + count = 0 + for sol in solutions: + s_val = _resolve_term(s_tmpl, sol) + p_val = _resolve_term(p_tmpl, sol) + o_val = _resolve_term(o_tmpl, sol) - for sol in solutions: - # Substitute known bindings into the pattern - s_val = _resolve_term(s_tmpl, sol) - p_val = _resolve_term(p_tmpl, sol) - o_val = _resolve_term(o_tmpl, sol) + async for triple in tc.query_gen( + s=s_val, p=p_val, o=o_val, + limit=limit, collection=collection, + ): + binding = dict(sol) + if isinstance(s_tmpl, Variable): + binding[str(s_tmpl)] = _to_term(triple.s) + if isinstance(p_tmpl, Variable): + binding[str(p_tmpl)] = _to_term(triple.p) + if isinstance(o_tmpl, Variable): + binding[str(o_tmpl)] = _to_term(triple.o) + yield binding + count += 1 + if count >= limit: + return + else: + # Materialise intermediate patterns + new_solutions = [] + for sol in solutions: + s_val = _resolve_term(s_tmpl, sol) + p_val = _resolve_term(p_tmpl, sol) + o_val = _resolve_term(o_tmpl, sol) - # Query the triples store - results = await _query_pattern( - tc, s_val, p_val, o_val, collection, limit - ) + async for triple in tc.query_gen( + s=s_val, p=p_val, o=o_val, + limit=limit, collection=collection, + ): + binding = dict(sol) + if isinstance(s_tmpl, Variable): + binding[str(s_tmpl)] = _to_term(triple.s) + if isinstance(p_tmpl, Variable): + binding[str(p_tmpl)] = _to_term(triple.p) + if isinstance(o_tmpl, Variable): + binding[str(o_tmpl)] = _to_term(triple.o) + new_solutions.append(binding) - # Map results back to variable bindings, - # converting Uri/Literal to Term objects - for triple in results: - binding = dict(sol) - if isinstance(s_tmpl, Variable): - binding[str(s_tmpl)] = _to_term(triple.s) - if isinstance(p_tmpl, Variable): - binding[str(p_tmpl)] = _to_term(triple.p) - if isinstance(o_tmpl, Variable): - binding[str(o_tmpl)] = _to_term(triple.o) - new_solutions.append(binding) + solutions = new_solutions + if not solutions: + return - solutions = new_solutions - if not solutions: - break +# --- Blocking operators: materialise upstream, then yield --- - return solutions[:limit] +def _is_small_node(node): + """Check if a node is likely to produce a small number of solutions.""" + if not isinstance(node, CompValue): + return False + if node.name in ("values", "ToMultiSet"): + return True + if node.name == "Extend" and hasattr(node, "p"): + return _is_small_node(node.p) + return False async def _eval_join(node, tc, collection, limit): - """Evaluate a Join node.""" - left = await evaluate(node.p1, tc, collection, limit) - right = await evaluate(node.p2, tc, collection, limit) - return hash_join(left, right)[:limit] + # Bind join: if one side is small (e.g. VALUES), materialise it and + # substitute its bindings into the other side's evaluation. This + # turns wildcard BGP queries into selective ones. + if _is_small_node(node.p1): + yield_from = _bind_join(node.p1, node.p2, tc, collection, limit) + elif _is_small_node(node.p2): + yield_from = _bind_join(node.p2, node.p1, tc, collection, limit) + else: + yield_from = _hash_join(node, tc, collection, limit) + + async for sol in yield_from: + yield sol + + +async def _hash_join(node, tc, collection, limit): + left = await materialise(node.p1, tc, collection, limit) + right = await materialise(node.p2, tc, collection, limit) + for sol in hash_join(left, right)[:limit]: + yield sol + + +async def _bind_join(small_node, big_node, tc, collection, limit): + """Iterate over the small side and inject bindings into the big side.""" + small_sols = await materialise(small_node, tc, collection, limit) + + count = 0 + for binding in small_sols: + async for sol in _evaluate_with_bindings( + big_node, binding, tc, collection, limit + ): + yield sol + count += 1 + if count >= limit: + return + + +def _merge_compatible(left, right): + """Merge two solutions if compatible (shared vars have equal values).""" + merged = dict(left) + for k, v in right.items(): + if k in merged: + if _term_key(merged[k]) != _term_key(v): + return None + else: + merged[k] = v + return merged + + +async def _evaluate_with_bindings(node, bindings, tc, collection, limit): + """Evaluate a node with pre-seeded variable bindings. + + For BGP nodes, the bindings are injected so _resolve_term sees them, + turning wildcard queries into selective ones. For other node types, + evaluate normally and merge/filter against the bindings. + """ + if isinstance(node, CompValue) and node.name == "BGP": + async for sol in _eval_bgp_with_bindings( + node, bindings, tc, collection, limit + ): + yield sol + else: + async for sol in evaluate(node, tc, collection, limit): + merged = _merge_compatible(bindings, sol) + if merged is not None: + yield merged + + +async def _eval_bgp_with_bindings(node, bindings, tc, collection, limit): + """Evaluate a BGP with pre-seeded bindings so variables resolve to terms.""" + triples = node.triples + if not triples: + yield dict(bindings) + return + + def selectivity(pattern): + score = 0 + for t in pattern: + if not isinstance(t, Variable): + score += 1 + elif str(t) in bindings: + score += 1 + return score + + sorted_patterns = sorted( + enumerate(triples), key=lambda x: -selectivity(x[1]) + ) + + solutions = [dict(bindings)] + + for pattern_idx, (_, pattern) in enumerate(sorted_patterns): + s_tmpl, p_tmpl, o_tmpl = pattern + is_last = (pattern_idx == len(sorted_patterns) - 1) + + if is_last: + count = 0 + for sol in solutions: + s_val = _resolve_term(s_tmpl, sol) + p_val = _resolve_term(p_tmpl, sol) + o_val = _resolve_term(o_tmpl, sol) + + async for triple in tc.query_gen( + s=s_val, p=p_val, o=o_val, + limit=limit, collection=collection, + ): + binding = dict(sol) + if isinstance(s_tmpl, Variable): + binding[str(s_tmpl)] = _to_term(triple.s) + if isinstance(p_tmpl, Variable): + binding[str(p_tmpl)] = _to_term(triple.p) + if isinstance(o_tmpl, Variable): + binding[str(o_tmpl)] = _to_term(triple.o) + yield binding + count += 1 + if count >= limit: + return + else: + new_solutions = [] + for sol in solutions: + s_val = _resolve_term(s_tmpl, sol) + p_val = _resolve_term(p_tmpl, sol) + o_val = _resolve_term(o_tmpl, sol) + + async for triple in tc.query_gen( + s=s_val, p=p_val, o=o_val, + limit=limit, collection=collection, + ): + binding = dict(sol) + if isinstance(s_tmpl, Variable): + binding[str(s_tmpl)] = _to_term(triple.s) + if isinstance(p_tmpl, Variable): + binding[str(p_tmpl)] = _to_term(triple.p) + if isinstance(o_tmpl, Variable): + binding[str(o_tmpl)] = _to_term(triple.o) + new_solutions.append(binding) + + solutions = new_solutions + if not solutions: + return async def _eval_left_join(node, tc, collection, limit): - """Evaluate a LeftJoin node (OPTIONAL).""" - left_sols = await evaluate(node.p1, tc, collection, limit) - right_sols = await evaluate(node.p2, tc, collection, limit) + # Buffer right side for hash index; stream left through probe + left_sols = await materialise(node.p1, tc, collection, limit) + right_sols = await materialise(node.p2, tc, collection, limit) filter_fn = None if hasattr(node, "expr") and node.expr is not None: @@ -149,27 +320,83 @@ async def _eval_left_join(node, tc, collection, limit): evaluate_expression(expr, sol) ) - return left_join(left_sols, right_sols, filter_fn)[:limit] - - -async def _eval_union(node, tc, collection, limit): - """Evaluate a Union node.""" - left = await evaluate(node.p1, tc, collection, limit) - right = await evaluate(node.p2, tc, collection, limit) - return union(left, right)[:limit] + for sol in left_join(left_sols, right_sols, filter_fn)[:limit]: + yield sol async def _eval_minus(node, tc, collection, limit): - """Evaluate a Minus node.""" - left = await evaluate(node.p1, tc, collection, limit) - right = await evaluate(node.p2, tc, collection, limit) - return minus(left, right) + left = await materialise(node.p1, tc, collection, limit) + right = await materialise(node.p2, tc, collection, limit) + for sol in minus(left, right): + yield sol + + +async def _eval_distinct(node, tc, collection, limit): + seen = set() + async for sol in evaluate(node.p, tc, collection, limit): + key = tuple(sorted( + (k, _term_key(v)) for k, v in sol.items() + )) + if key not in seen: + seen.add(key) + yield sol + + +async def _eval_reduced(node, tc, collection, limit): + async for sol in _eval_distinct(node, tc, collection, limit): + yield sol + + +async def _eval_order_by(node, tc, collection, limit): + solutions = await materialise(node.p, tc, collection, limit) + + key_fns = [] + for cond in node.expr: + if isinstance(cond, CompValue) and cond.name == "OrderCondition": + ascending = cond.order != "DESC" + expr = cond.expr + key_fns.append(( + lambda sol, e=expr: evaluate_expression(e, sol), + ascending, + )) + else: + key_fns.append(( + lambda sol, e=cond: evaluate_expression(e, sol), + True, + )) + + for sol in order_by(solutions, key_fns): + yield sol + + +# --- Streamable operators --- + +async def _eval_slice(node, tc, collection, limit): + offset = node.start or 0 + length = node.length + skipped = 0 + emitted = 0 + + async for sol in evaluate(node.p, tc, collection, limit): + if skipped < offset: + skipped += 1 + continue + yield sol + emitted += 1 + if length is not None and emitted >= length: + return + + +async def _eval_union(node, tc, collection, limit): + async for sol in evaluate(node.p1, tc, collection, limit): + yield sol + async for sol in evaluate(node.p2, tc, collection, limit): + yield sol async def _check_exists(graph_node, sol, tc, collection, limit): """Evaluate an EXISTS graph pattern against a solution.""" - results = await evaluate(graph_node, tc, collection, limit) - for r in results: + async for r in evaluate(graph_node, tc, collection, limit): shared = set(sol.keys()) & set(r.keys()) if all( _term_key(sol[v]) == _term_key(r[v]) @@ -206,8 +433,6 @@ async def _pre_eval_exists(expr, sol, tc, collection, limit, cache): async def _eval_filter(node, tc, collection, limit): - """Evaluate a Filter node.""" - solutions = await evaluate(node.p, tc, collection, limit) expr = node.expr exists_cache = {} @@ -215,60 +440,13 @@ async def _eval_filter(node, tc, collection, limit): key = id(graph_node), id(sol) return exists_cache.get(key, False) - result = [] - for sol in solutions: + async for sol in evaluate(node.p, tc, collection, limit): await _pre_eval_exists(expr, sol, tc, collection, limit, exists_cache) if _effective_boolean(evaluate_expression(expr, sol, exists_cb=exists_cb)): - result.append(sol) - - return result - - -async def _eval_distinct(node, tc, collection, limit): - """Evaluate a Distinct node.""" - solutions = await evaluate(node.p, tc, collection, limit) - return distinct(solutions) - - -async def _eval_reduced(node, tc, collection, limit): - """Evaluate a Reduced node (like Distinct but implementation-defined).""" - # Treat same as Distinct - solutions = await evaluate(node.p, tc, collection, limit) - return distinct(solutions) - - -async def _eval_order_by(node, tc, collection, limit): - """Evaluate an OrderBy node.""" - solutions = await evaluate(node.p, tc, collection, limit) - - key_fns = [] - for cond in node.expr: - if isinstance(cond, CompValue) and cond.name == "OrderCondition": - ascending = cond.order != "DESC" - expr = cond.expr - key_fns.append(( - lambda sol, e=expr: evaluate_expression(e, sol), - ascending, - )) - else: - # Simple variable or expression - key_fns.append(( - lambda sol, e=cond: evaluate_expression(e, sol), - True, - )) - - return order_by(solutions, key_fns) - - -async def _eval_slice(node, tc, collection, limit): - """Evaluate a Slice node (LIMIT/OFFSET).""" - solutions = await evaluate(node.p, tc, collection, limit) - return slice_solutions(solutions, node.start or 0, node.length) + yield sol async def _eval_extend(node, tc, collection, limit): - """Evaluate an Extend node (BIND).""" - solutions = await evaluate(node.p, tc, collection, limit) var_name = str(node.var) expr = node.expr exists_cache = {} @@ -277,8 +455,7 @@ async def _eval_extend(node, tc, collection, limit): key = id(graph_node), id(sol) return exists_cache.get(key, False) - result = [] - for sol in solutions: + async for sol in evaluate(node.p, tc, collection, limit): await _pre_eval_exists(expr, sol, tc, collection, limit, exists_cache) val = evaluate_expression(expr, sol, exists_cb=exists_cb) new_sol = dict(sol) @@ -295,16 +472,14 @@ async def _eval_extend(node, tc, collection, limit): ) elif val is not None: new_sol[var_name] = Term(type=LITERAL, value=str(val)) - result.append(new_sol) + yield new_sol - return result +# --- Aggregation (blocking) --- async def _eval_group(node, tc, collection, limit): - """Evaluate a Group node (GROUP BY with aggregation).""" - solutions = await evaluate(node.p, tc, collection, limit) + solutions = await materialise(node.p, tc, collection, limit) - # Extract grouping expressions group_exprs = [] if hasattr(node, "expr") and node.expr: for expr in node.expr: @@ -315,7 +490,6 @@ async def _eval_group(node, tc, collection, limit): else: group_exprs.append((expr, None)) - # Group solutions groups = defaultdict(list) for sol in solutions: key_parts = [] @@ -325,81 +499,72 @@ async def _eval_group(node, tc, collection, limit): groups[tuple(key_parts)].append(sol) if not group_exprs: - # No GROUP BY - entire result is one group groups[()].extend(solutions) - # Build grouped solutions (one per group) - result = [] for key, group_sols in groups.items(): sol = {} - # Include group key variables if group_sols: for (expr, var_name), k in zip(group_exprs, key): if var_name and group_sols: sol[var_name] = evaluate_expression(expr, group_sols[0]) sol["__group__"] = group_sols - result.append(sol) - - return result + yield sol async def _eval_aggregate_join(node, tc, collection, limit): - """Evaluate an AggregateJoin (aggregation functions after GROUP BY).""" - solutions = await evaluate(node.p, tc, collection, limit) - - result = [] - for sol in solutions: + async for sol in evaluate(node.p, tc, collection, limit): group = sol.get("__group__", [sol]) new_sol = {k: v for k, v in sol.items() if k != "__group__"} - # Apply aggregate functions if hasattr(node, "A") and node.A: for agg in node.A: var_name = str(agg.res) agg_val = _compute_aggregate(agg, group) new_sol[var_name] = agg_val - result.append(new_sol) - - return result + yield new_sol async def _eval_graph(node, tc, collection, limit): - """Evaluate a Graph node (GRAPH clause).""" term = node.term if isinstance(term, URIRef): - # GRAPH { ... } — fixed graph - # We'd need to pass graph to triples queries - # For now, evaluate inner pattern normally logger.info(f"GRAPH <{term}> clause - graph filtering not yet wired") - return await evaluate(node.p, tc, collection, limit) elif isinstance(term, Variable): - # GRAPH ?g { ... } — variable graph logger.info(f"GRAPH ?{term} clause - variable graph not yet wired") - return await evaluate(node.p, tc, collection, limit) - else: - return await evaluate(node.p, tc, collection, limit) + + async for sol in evaluate(node.p, tc, collection, limit): + yield sol async def _eval_values(node, tc, collection, limit): - """Evaluate a VALUES clause (inline data).""" - variables = [str(v) for v in node.var] - solutions = [] + # rdflib has two representations for VALUES: + # 1. var=[Variable...], value=[[val, ...], ...] — positional + # 2. var=None, res=[{Variable: val, ...}, ...] — dict-based + if hasattr(node, "res") and node.res: + for row in node.res: + sol = {} + for var, val in row.items(): + if val is not None and str(val) != "UNDEF": + sol[str(var)] = rdflib_term_to_term(val) + yield sol + return + if not node.var or not node.value: + yield {} + return + variables = [str(v) for v in node.var] for row in node.value: sol = {} for var_name, val in zip(variables, row): if val is not None and str(val) != "UNDEF": sol[var_name] = rdflib_term_to_term(val) - solutions.append(sol) - - return solutions + yield sol async def _eval_to_multiset(node, tc, collection, limit): - """Evaluate a ToMultiSet node (subquery).""" - return await evaluate(node.p, tc, collection, limit) + async for sol in evaluate(node.p, tc, collection, limit): + yield sol # --- Aggregate computation --- @@ -408,7 +573,6 @@ def _compute_aggregate(agg, group): """Compute a single aggregate function over a group of solutions.""" agg_name = agg.name if hasattr(agg, "name") else "" - # Get the expression to aggregate expr = agg.vars if hasattr(agg, "vars") else None if agg_name == "Aggregate_Count": diff --git a/trustgraph-flow/trustgraph/query/sparql/expressions.py b/trustgraph-flow/trustgraph/query/sparql/expressions.py index ad3202d9..608eeff2 100644 --- a/trustgraph-flow/trustgraph/query/sparql/expressions.py +++ b/trustgraph-flow/trustgraph/query/sparql/expressions.py @@ -125,6 +125,13 @@ def _evaluate_comp_value(node, solution): if name == "MultiplicativeExpression": return _eval_multiplicative(node, solution) + # IN / NOT IN — must be checked before the generic Builtin_ dispatch + if name == "Builtin_IN": + return _eval_in(node, solution) + + if name == "Builtin_NOTIN": + return not _eval_in(node, solution) + # SPARQL built-in functions if name.startswith("Builtin_"): return _eval_builtin(name, node, solution) @@ -133,19 +140,10 @@ def _evaluate_comp_value(node, solution): if name == "Function": return _eval_function(node, solution) - # Exists / NotExists — handled via _eval_builtin now - # TrueFilter (used with OPTIONAL) if name == "TrueFilter": return True - # IN / NOT IN - if name == "Builtin_IN": - return _eval_in(node, solution) - - if name == "Builtin_NOTIN": - return not _eval_in(node, solution) - logger.warning(f"Unknown CompValue expression: {name}") return None @@ -171,6 +169,22 @@ def _eval_relational(node, solution): ">=": operator.ge, } + if str(op) == "IN": + items = node.other if isinstance(node.other, list) else [node.other] + for item in items: + other_val = evaluate_expression(item, solution) + if _comparable_value(left) == _comparable_value(other_val): + return True + return False + + if str(op) == "NOT IN": + items = node.other if isinstance(node.other, list) else [node.other] + for item in items: + other_val = evaluate_expression(item, solution) + if _comparable_value(left) == _comparable_value(other_val): + return False + return True + op_fn = ops.get(str(op)) if op_fn is None: logger.warning(f"Unknown relational operator: {op}") diff --git a/trustgraph-flow/trustgraph/query/sparql/service.py b/trustgraph-flow/trustgraph/query/sparql/service.py index 75c00dba..bbe375f0 100644 --- a/trustgraph-flow/trustgraph/query/sparql/service.py +++ b/trustgraph-flow/trustgraph/query/sparql/service.py @@ -12,7 +12,7 @@ from ... base import FlowProcessor, ConsumerSpec, ProducerSpec from ... base import TriplesClientSpec from . parser import parse_sparql, ParseError -from . algebra import evaluate, EvaluationError +from . algebra import evaluate, materialise, EvaluationError logger = logging.getLogger(__name__) @@ -66,11 +66,10 @@ class Processor(FlowProcessor): logger.debug(f"Handling SPARQL query request {id}...") - response = await self.execute_sparql(request, flow) - - if request.streaming and response.query_type == "select": - await self.send_streaming(response, flow, id, request) + if request.streaming: + await self.execute_sparql_streaming(request, flow, id) else: + response = await self.execute_sparql(request, flow) await flow("response").send( response, properties={"id": id} ) @@ -92,37 +91,77 @@ class Processor(FlowProcessor): await flow("response").send(r, properties={"id": id}) - async def send_streaming(self, response, flow, id, request): - """Send SELECT results in batches.""" + async def execute_sparql_streaming(self, request, flow, id): + """Execute a SPARQL query and stream results as they arrive.""" - bindings = response.bindings + try: + parsed = parse_sparql(request.query) + except ParseError as e: + await flow("response").send( + SparqlQueryResponse( + error=Error( + type="sparql-parse-error", + message=str(e), + ), + ), + properties={"id": id} + ) + return + + if parsed.query_type != "select": + response = await self._execute_non_select(parsed, request, flow) + await flow("response").send(response, properties={"id": id}) + return + + triples_client = flow("triples-request") + variables = parsed.variables batch_size = request.batch_size if request.batch_size > 0 else 20 + batch = [] - for i in range(0, len(bindings), batch_size): - batch = bindings[i:i + batch_size] - is_final = (i + batch_size >= len(bindings)) - r = SparqlQueryResponse( - query_type=response.query_type, - variables=response.variables, - bindings=batch, - is_final=is_final, - ) - await flow("response").send(r, properties={"id": id}) + try: + async for sol in evaluate( + parsed.algebra, + triples_client, + collection=request.collection or "default", + limit=request.limit or 10000, + ): + values = [sol.get(v) for v in variables] + batch.append(SparqlBinding(values=values)) - # Handle empty results - if len(bindings) == 0: - r = SparqlQueryResponse( - query_type=response.query_type, - variables=response.variables, - bindings=[], - is_final=True, + if len(batch) >= batch_size: + r = SparqlQueryResponse( + query_type="select", + variables=variables, + bindings=batch, + is_final=False, + ) + await flow("response").send(r, properties={"id": id}) + batch = [] + + except EvaluationError as e: + await flow("response").send( + SparqlQueryResponse( + error=Error( + type="sparql-evaluation-error", + message=str(e), + ), + ), + properties={"id": id} ) - await flow("response").send(r, properties={"id": id}) + return + + # Final batch (may be empty for zero results) + r = SparqlQueryResponse( + query_type="select", + variables=variables, + bindings=batch, + is_final=True, + ) + await flow("response").send(r, properties={"id": id}) async def execute_sparql(self, request, flow): - """Parse and evaluate a SPARQL query.""" + """Parse and evaluate a SPARQL query (non-streaming).""" - # Parse the SPARQL query try: parsed = parse_sparql(request.query) except ParseError as e: @@ -133,12 +172,31 @@ class Processor(FlowProcessor): ), ) - # Get the triples client from the flow - triples_client = flow("triples-request") + if parsed.query_type == "select": + triples_client = flow("triples-request") + try: + solutions = await materialise( + parsed.algebra, + triples_client, + collection=request.collection or "default", + limit=request.limit or 10000, + ) + except EvaluationError as e: + return SparqlQueryResponse( + error=Error( + type="sparql-evaluation-error", + message=str(e), + ), + ) + return self._build_select_response(parsed, solutions) - # Evaluate the algebra + return await self._execute_non_select(parsed, request, flow) + + async def _execute_non_select(self, parsed, request, flow): + """Execute ASK, CONSTRUCT, or DESCRIBE queries.""" + triples_client = flow("triples-request") try: - solutions = await evaluate( + solutions = await materialise( parsed.algebra, triples_client, collection=request.collection or "default", @@ -152,10 +210,7 @@ class Processor(FlowProcessor): ), ) - # Build response based on query type - if parsed.query_type == "select": - return self._build_select_response(parsed, solutions) - elif parsed.query_type == "ask": + if parsed.query_type == "ask": return self._build_ask_response(solutions) elif parsed.query_type == "construct": return self._build_construct_response(parsed, solutions)