From bb1109963c1ffc5e7dc8e091251db730546748d7 Mon Sep 17 00:00:00 2001 From: cybermaggedon Date: Thu, 14 May 2026 12:03:43 +0100 Subject: [PATCH] Remove spurious workspace parameter from SPARQL algebra evaluator (#915) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fix threading of workspace parameter: - The SPARQL algebra evaluator was threading a workspace parameter through every function and passing it to TriplesClient.query(), which doesn't accept it. Workspace isolation is handled by pub/sub topic routing — the TriplesClient is already scoped to a workspace-specific flow, same as GraphRAG. Passing workspace explicitly was both incorrect and unnecessary. Update tests: - tests/unit/test_query/test_sparql_algebra.py (new) — Tests _query_pattern, _eval_bgp, and evaluate() with various algebra nodes. Key tests assert workspace is never in tc.query() kwargs, plus correctness tests for BGP, JOIN, UNION, SLICE, DISTINCT, and edge cases. - tests/unit/test_retrieval/test_graph_rag.py — Added test_triples_query_never_passes_workspace (checks query()) and test_follow_edges_never_passes_workspace (checks query_stream()). --- tests/unit/test_query/test_sparql_algebra.py | 302 ++++++++++++++++++ tests/unit/test_retrieval/test_graph_rag.py | 51 +++ .../trustgraph/query/sparql/algebra.py | 84 +++-- .../trustgraph/query/sparql/service.py | 1 - 4 files changed, 394 insertions(+), 44 deletions(-) create mode 100644 tests/unit/test_query/test_sparql_algebra.py diff --git a/tests/unit/test_query/test_sparql_algebra.py b/tests/unit/test_query/test_sparql_algebra.py new file mode 100644 index 00000000..9827b2de --- /dev/null +++ b/tests/unit/test_query/test_sparql_algebra.py @@ -0,0 +1,302 @@ +""" +Tests for the SPARQL algebra evaluator. 
+ +Verifies that evaluate() and _query_pattern() call TriplesClient.query() +with the correct arguments, and in particular that workspace is never +passed — workspace isolation is handled by pub/sub topic routing. +""" + +import pytest +from unittest.mock import AsyncMock, MagicMock, call + +from rdflib.term import Variable, URIRef, Literal +from rdflib.plugins.sparql.parserutils import CompValue + +from trustgraph.schema import Term, IRI, LITERAL +from trustgraph.query.sparql.algebra import ( + evaluate, _query_pattern, _eval_bgp, +) + + +# --- Helpers --- + +def iri(v): + return Term(type=IRI, iri=v) + + +def lit(v): + return Term(type=LITERAL, value=v) + + +def make_triple(s, p, o): + t = MagicMock() + t.s = s + t.p = p + t.o = o + return t + + +def make_bgp(*patterns): + """Build a CompValue BGP node from (s, p, o) tuples of rdflib terms.""" + node = CompValue("BGP") + node.triples = list(patterns) + return node + + +def make_project(inner, variables): + node = CompValue("Project") + node.p = inner + node.PV = [Variable(v) for v in variables] + return node + + +def make_select(inner): + node = CompValue("SelectQuery") + node.p = inner + return node + + +def make_join(left, right): + node = CompValue("Join") + node.p1 = left + node.p2 = right + return node + + +def make_union(left, right): + node = CompValue("Union") + node.p1 = left + node.p2 = right + return node + + +def make_slice(inner, start, length): + node = CompValue("Slice") + node.p = inner + node.start = start + node.length = length + return node + + +def make_distinct(inner): + node = CompValue("Distinct") + node.p = inner + return node + + +class TestQueryPattern: + """Tests for _query_pattern — the leaf that calls TriplesClient.""" + + @pytest.mark.asyncio + async def test_passes_correct_args(self): + tc = AsyncMock() + tc.query.return_value = [] + + await _query_pattern( + tc, + s=iri("http://example.com/s"), + p=iri("http://example.com/p"), + o=None, + collection="my-collection", + limit=100, + 
) + + tc.query.assert_called_once_with( + s=iri("http://example.com/s"), + p=iri("http://example.com/p"), + o=None, + limit=100, + collection="my-collection", + ) + + @pytest.mark.asyncio + async def test_workspace_not_passed(self): + tc = AsyncMock() + tc.query.return_value = [] + + await _query_pattern(tc, None, None, None, "default", 10) + + kwargs = tc.query.call_args.kwargs + assert "workspace" not in kwargs + + @pytest.mark.asyncio + async def test_returns_query_results(self): + tc = AsyncMock() + triple = make_triple(iri("http://a"), iri("http://b"), lit("c")) + tc.query.return_value = [triple] + + results = await _query_pattern(tc, None, None, None, "default", 10) + + assert len(results) == 1 + assert results[0].s.iri == "http://a" + + +class TestEvalBgp: + """Tests for BGP evaluation — triple pattern queries.""" + + @pytest.mark.asyncio + async def test_single_pattern_all_variables(self): + tc = AsyncMock() + triple = make_triple(iri("http://s"), iri("http://p"), lit("o")) + tc.query.return_value = [triple] + + bgp = make_bgp( + (Variable("s"), Variable("p"), Variable("o")), + ) + + solutions = await evaluate(bgp, tc, collection="default", limit=100) + + assert len(solutions) == 1 + assert solutions[0]["s"].iri == "http://s" + assert solutions[0]["p"].iri == "http://p" + assert solutions[0]["o"].value == "o" + + @pytest.mark.asyncio + async def test_single_pattern_bound_subject(self): + tc = AsyncMock() + tc.query.return_value = [ + make_triple(iri("http://s"), iri("http://p"), lit("val")), + ] + + bgp = make_bgp( + (URIRef("http://s"), Variable("p"), Variable("o")), + ) + + solutions = await evaluate(bgp, tc, collection="default") + + tc.query.assert_called_once() + kwargs = tc.query.call_args.kwargs + assert "workspace" not in kwargs + assert kwargs["collection"] == "default" + + @pytest.mark.asyncio + async def test_empty_bgp_returns_empty_solution(self): + tc = AsyncMock() + + bgp = make_bgp() + + solutions = await evaluate(bgp, tc, 
collection="default") + + assert solutions == [{}] + tc.query.assert_not_called() + + @pytest.mark.asyncio + async def test_no_results_returns_empty(self): + tc = AsyncMock() + tc.query.return_value = [] + + bgp = make_bgp( + (Variable("s"), Variable("p"), Variable("o")), + ) + + solutions = await evaluate(bgp, tc, collection="default") + + assert solutions == [] + + +class TestEvaluate: + """Tests for the top-level evaluate() dispatcher.""" + + @pytest.mark.asyncio + async def test_select_query_node(self): + tc = AsyncMock() + tc.query.return_value = [ + make_triple(iri("http://s"), iri("http://p"), lit("o")), + ] + + bgp = make_bgp( + (Variable("s"), Variable("p"), Variable("o")), + ) + select = make_select(make_project(bgp, ["s", "p"])) + + solutions = await evaluate(select, tc, collection="default") + + assert len(solutions) == 1 + assert "s" in solutions[0] + assert "p" in solutions[0] + assert "o" not in solutions[0] + + @pytest.mark.asyncio + async def test_workspace_never_in_query_calls(self): + """Verify that no matter the algebra structure, workspace is never + passed to TriplesClient.query().""" + tc = AsyncMock() + tc.query.return_value = [ + make_triple(iri("http://s"), iri("http://p"), lit("o")), + ] + + bgp1 = make_bgp((Variable("s"), Variable("p"), Variable("o"))) + bgp2 = make_bgp((Variable("a"), Variable("b"), Variable("c"))) + tree = make_select(make_project( + make_union(bgp1, bgp2), ["s", "p", "o"] + )) + + await evaluate(tree, tc, collection="test-coll") + + for c in tc.query.call_args_list: + assert "workspace" not in c.kwargs + + @pytest.mark.asyncio + async def test_join(self): + tc = AsyncMock() + tc.query.side_effect = [ + [make_triple(iri("http://a"), iri("http://p"), lit("v"))], + [make_triple(iri("http://a"), iri("http://q"), lit("w"))], + ] + + bgp1 = make_bgp((Variable("s"), URIRef("http://p"), Variable("v1"))) + bgp2 = make_bgp((Variable("s"), URIRef("http://q"), Variable("v2"))) + tree = make_join(bgp1, bgp2) + + solutions = await 
evaluate(tree, tc, collection="default") + + assert len(solutions) == 1 + assert solutions[0]["s"].iri == "http://a" + + @pytest.mark.asyncio + async def test_slice(self): + tc = AsyncMock() + triples = [ + make_triple(iri(f"http://s{i}"), iri("http://p"), lit(f"o{i}")) + for i in range(5) + ] + tc.query.return_value = triples + + bgp = make_bgp((Variable("s"), Variable("p"), Variable("o"))) + tree = make_slice(bgp, start=1, length=2) + + solutions = await evaluate(tree, tc, collection="default") + + assert len(solutions) == 2 + + @pytest.mark.asyncio + async def test_distinct(self): + tc = AsyncMock() + triple = make_triple(iri("http://s"), iri("http://p"), lit("o")) + tc.query.return_value = [triple, triple] + + bgp = make_bgp((Variable("s"), Variable("p"), Variable("o"))) + tree = make_distinct(bgp) + + solutions = await evaluate(tree, tc, collection="default") + + assert len(solutions) == 1 + + @pytest.mark.asyncio + async def test_unsupported_node_returns_empty_solution(self): + tc = AsyncMock() + + node = CompValue("SomethingUnknown") + + solutions = await evaluate(node, tc, collection="default") + + assert solutions == [{}] + tc.query.assert_not_called() + + @pytest.mark.asyncio + async def test_non_compvalue_returns_empty_solution(self): + tc = AsyncMock() + + solutions = await evaluate("not a node", tc, collection="default") + + assert solutions == [{}] diff --git a/tests/unit/test_retrieval/test_graph_rag.py b/tests/unit/test_retrieval/test_graph_rag.py index e0f41357..d1979211 100644 --- a/tests/unit/test_retrieval/test_graph_rag.py +++ b/tests/unit/test_retrieval/test_graph_rag.py @@ -337,6 +337,57 @@ class TestQuery: cache_key = "test_collection:unlabeled_entity" mock_cache.put.assert_called_once_with(cache_key, "unlabeled_entity") + @pytest.mark.asyncio + async def test_triples_query_never_passes_workspace(self): + """Workspace isolation is handled by pub/sub topic routing, not + by passing workspace to TriplesClient.query(). 
Verify that + GraphRAG never passes workspace as a keyword argument.""" + mock_rag = MagicMock() + mock_cache = MagicMock() + mock_cache.get.return_value = None + mock_rag.label_cache = mock_cache + mock_triples_client = AsyncMock() + mock_rag.triples_client = mock_triples_client + + mock_triple = MagicMock() + mock_triple.o = "Label" + mock_triples_client.query.return_value = [mock_triple] + + query = Query( + rag=mock_rag, + collection="test_collection", + verbose=False + ) + + await query.maybe_label("http://example.com/entity") + + for c in mock_triples_client.query.call_args_list: + assert "workspace" not in c.kwargs + + @pytest.mark.asyncio + async def test_follow_edges_never_passes_workspace(self): + """Verify follow_edges never passes workspace to query_stream.""" + mock_rag = MagicMock() + mock_triples_client = AsyncMock() + mock_rag.triples_client = mock_triples_client + + mock_triple = MagicMock() + mock_triple.s, mock_triple.p, mock_triple.o = "e1", "p1", "o1" + mock_triples_client.query_stream.return_value = [mock_triple] + + query = Query( + rag=mock_rag, + collection="test_collection", + verbose=False, + triple_limit=10 + ) + + subgraph = set() + await query.follow_edges("e1", subgraph, path_length=1) + + for c in mock_triples_client.query_stream.call_args_list: + assert "workspace" not in c.kwargs + @pytest.mark.asyncio async def test_follow_edges_basic_functionality(self): """Test Query.follow_edges method basic triple discovery""" diff --git a/trustgraph-flow/trustgraph/query/sparql/algebra.py b/trustgraph-flow/trustgraph/query/sparql/algebra.py index bff9a336..76b1ad8e 100644 --- a/trustgraph-flow/trustgraph/query/sparql/algebra.py +++ b/trustgraph-flow/trustgraph/query/sparql/algebra.py @@ -30,14 +30,13 @@ class EvaluationError(Exception): pass -async def evaluate(node, triples_client, workspace, collection, limit=10000): +async def evaluate(node, triples_client, collection, limit=10000): """ Evaluate a SPARQL algebra node. 
Args: node: rdflib CompValue algebra node triples_client: TriplesClient instance for triple pattern queries - workspace: workspace/keyspace identifier collection: collection identifier limit: safety limit on results @@ -55,24 +54,24 @@ async def evaluate(node, triples_client, workspace, collection, limit=10000): logger.warning(f"Unsupported algebra node: {name}") return [{}] - return await handler(node, triples_client, workspace, collection, limit) + return await handler(node, triples_client, collection, limit) # --- Node handlers --- -async def _eval_select_query(node, tc, workspace, collection, limit): +async def _eval_select_query(node, tc, collection, limit): """Evaluate a SelectQuery node.""" - return await evaluate(node.p, tc, workspace, collection, limit) + return await evaluate(node.p, tc, collection, limit) -async def _eval_project(node, tc, workspace, collection, limit): +async def _eval_project(node, tc, collection, limit): """Evaluate a Project node (SELECT variable projection).""" - solutions = await evaluate(node.p, tc, workspace, collection, limit) + solutions = await evaluate(node.p, tc, collection, limit) variables = [str(v) for v in node.PV] return project(solutions, variables) -async def _eval_bgp(node, tc, workspace, collection, limit): +async def _eval_bgp(node, tc, collection, limit): """ Evaluate a Basic Graph Pattern. 
@@ -107,7 +106,7 @@ async def _eval_bgp(node, tc, workspace, collection, limit): # Query the triples store results = await _query_pattern( - tc, s_val, p_val, o_val, workspace, collection, limit + tc, s_val, p_val, o_val, collection, limit ) # Map results back to variable bindings, @@ -130,17 +129,17 @@ async def _eval_bgp(node, tc, workspace, collection, limit): return solutions[:limit] -async def _eval_join(node, tc, workspace, collection, limit): +async def _eval_join(node, tc, collection, limit): """Evaluate a Join node.""" - left = await evaluate(node.p1, tc, workspace, collection, limit) - right = await evaluate(node.p2, tc, workspace, collection, limit) + left = await evaluate(node.p1, tc, collection, limit) + right = await evaluate(node.p2, tc, collection, limit) return hash_join(left, right)[:limit] -async def _eval_left_join(node, tc, workspace, collection, limit): +async def _eval_left_join(node, tc, collection, limit): """Evaluate a LeftJoin node (OPTIONAL).""" - left_sols = await evaluate(node.p1, tc, workspace, collection, limit) - right_sols = await evaluate(node.p2, tc, workspace, collection, limit) + left_sols = await evaluate(node.p1, tc, collection, limit) + right_sols = await evaluate(node.p2, tc, collection, limit) filter_fn = None if hasattr(node, "expr") and node.expr is not None: @@ -153,16 +152,16 @@ async def _eval_left_join(node, tc, workspace, collection, limit): return left_join(left_sols, right_sols, filter_fn)[:limit] -async def _eval_union(node, tc, workspace, collection, limit): +async def _eval_union(node, tc, collection, limit): """Evaluate a Union node.""" - left = await evaluate(node.p1, tc, workspace, collection, limit) - right = await evaluate(node.p2, tc, workspace, collection, limit) + left = await evaluate(node.p1, tc, collection, limit) + right = await evaluate(node.p2, tc, collection, limit) return union(left, right)[:limit] -async def _eval_filter(node, tc, workspace, collection, limit): +async def _eval_filter(node, tc, 
collection, limit): """Evaluate a Filter node.""" - solutions = await evaluate(node.p, tc, workspace, collection, limit) + solutions = await evaluate(node.p, tc, collection, limit) expr = node.expr return [ sol for sol in solutions @@ -170,22 +169,22 @@ async def _eval_filter(node, tc, workspace, collection, limit): ] -async def _eval_distinct(node, tc, workspace, collection, limit): +async def _eval_distinct(node, tc, collection, limit): """Evaluate a Distinct node.""" - solutions = await evaluate(node.p, tc, workspace, collection, limit) + solutions = await evaluate(node.p, tc, collection, limit) return distinct(solutions) -async def _eval_reduced(node, tc, workspace, collection, limit): +async def _eval_reduced(node, tc, collection, limit): """Evaluate a Reduced node (like Distinct but implementation-defined).""" # Treat same as Distinct - solutions = await evaluate(node.p, tc, workspace, collection, limit) + solutions = await evaluate(node.p, tc, collection, limit) return distinct(solutions) -async def _eval_order_by(node, tc, workspace, collection, limit): +async def _eval_order_by(node, tc, collection, limit): """Evaluate an OrderBy node.""" - solutions = await evaluate(node.p, tc, workspace, collection, limit) + solutions = await evaluate(node.p, tc, collection, limit) key_fns = [] for cond in node.expr: @@ -206,7 +205,7 @@ async def _eval_order_by(node, tc, workspace, collection, limit): return order_by(solutions, key_fns) -async def _eval_slice(node, tc, workspace, collection, limit): +async def _eval_slice(node, tc, collection, limit): """Evaluate a Slice node (LIMIT/OFFSET).""" # Pass tighter limit downstream if possible inner_limit = limit @@ -214,13 +213,13 @@ async def _eval_slice(node, tc, workspace, collection, limit): offset = node.start or 0 inner_limit = min(limit, offset + node.length) - solutions = await evaluate(node.p, tc, workspace, collection, inner_limit) + solutions = await evaluate(node.p, tc, collection, inner_limit) return 
slice_solutions(solutions, node.start or 0, node.length) -async def _eval_extend(node, tc, workspace, collection, limit): +async def _eval_extend(node, tc, collection, limit): """Evaluate an Extend node (BIND).""" - solutions = await evaluate(node.p, tc, workspace, collection, limit) + solutions = await evaluate(node.p, tc, collection, limit) var_name = str(node.var) expr = node.expr @@ -246,9 +245,9 @@ async def _eval_extend(node, tc, workspace, collection, limit): return result -async def _eval_group(node, tc, workspace, collection, limit): +async def _eval_group(node, tc, collection, limit): """Evaluate a Group node (GROUP BY with aggregation).""" - solutions = await evaluate(node.p, tc, workspace, collection, limit) + solutions = await evaluate(node.p, tc, collection, limit) # Extract grouping expressions group_exprs = [] @@ -289,9 +288,9 @@ async def _eval_group(node, tc, workspace, collection, limit): return result -async def _eval_aggregate_join(node, tc, workspace, collection, limit): +async def _eval_aggregate_join(node, tc, collection, limit): """Evaluate an AggregateJoin (aggregation functions after GROUP BY).""" - solutions = await evaluate(node.p, tc, workspace, collection, limit) + solutions = await evaluate(node.p, tc, collection, limit) result = [] for sol in solutions: @@ -310,7 +309,7 @@ async def _eval_aggregate_join(node, tc, workspace, collection, limit): return result -async def _eval_graph(node, tc, workspace, collection, limit): +async def _eval_graph(node, tc, collection, limit): """Evaluate a Graph node (GRAPH clause).""" term = node.term @@ -319,16 +318,16 @@ async def _eval_graph(node, tc, workspace, collection, limit): # We'd need to pass graph to triples queries # For now, evaluate inner pattern normally logger.info(f"GRAPH <{term}> clause - graph filtering not yet wired") - return await evaluate(node.p, tc, workspace, collection, limit) + return await evaluate(node.p, tc, collection, limit) elif isinstance(term, Variable): # GRAPH ?g 
{ ... } — variable graph logger.info(f"GRAPH ?{term} clause - variable graph not yet wired") - return await evaluate(node.p, tc, workspace, collection, limit) + return await evaluate(node.p, tc, collection, limit) else: - return await evaluate(node.p, tc, workspace, collection, limit) + return await evaluate(node.p, tc, collection, limit) -async def _eval_values(node, tc, workspace, collection, limit): +async def _eval_values(node, tc, collection, limit): """Evaluate a VALUES clause (inline data).""" variables = [str(v) for v in node.var] solutions = [] @@ -343,9 +342,9 @@ async def _eval_values(node, tc, workspace, collection, limit): return solutions -async def _eval_to_multiset(node, tc, workspace, collection, limit): +async def _eval_to_multiset(node, tc, collection, limit): """Evaluate a ToMultiSet node (subquery).""" - return await evaluate(node.p, tc, workspace, collection, limit) + return await evaluate(node.p, tc, collection, limit) # --- Aggregate computation --- @@ -487,7 +486,7 @@ def _resolve_term(tmpl, solution): return rdflib_term_to_term(tmpl) -async def _query_pattern(tc, s, p, o, workspace, collection, limit): +async def _query_pattern(tc, s, p, o, collection, limit): """ Issue a streaming triple pattern query via TriplesClient. @@ -496,7 +495,6 @@ async def _query_pattern(tc, s, p, o, workspace, collection, limit): results = await tc.query( s=s, p=p, o=o, limit=limit, - workspace=workspace, collection=collection, ) return results diff --git a/trustgraph-flow/trustgraph/query/sparql/service.py b/trustgraph-flow/trustgraph/query/sparql/service.py index 983cd4f6..75c00dba 100644 --- a/trustgraph-flow/trustgraph/query/sparql/service.py +++ b/trustgraph-flow/trustgraph/query/sparql/service.py @@ -141,7 +141,6 @@ class Processor(FlowProcessor): solutions = await evaluate( parsed.algebra, triples_client, - workspace=flow.workspace, collection=request.collection or "default", limit=request.limit or 10000, )