Remove spurious workspace parameter from SPARQL algebra evaluator (#915)

Fix threading of workspace parameter:
- The SPARQL algebra evaluator was threading a workspace parameter
  through every function and passing it to TriplesClient.query(),
  which doesn't accept it. Workspace isolation is handled by pub/sub
  topic routing — the TriplesClient is already scoped to a
  workspace-specific flow, same as GraphRAG. Passing workspace
  explicitly was both incorrect and unnecessary.

Update tests:
- tests/unit/test_query/test_sparql_algebra.py (new) — Tests
  _query_pattern, _eval_bgp, and evaluate() with various algebra
  nodes. Key tests assert workspace is never in tc.query() kwargs,
  plus correctness tests for BGP, JOIN, UNION, SLICE, DISTINCT, and
  edge cases.
- tests/unit/test_retrieval/test_graph_rag.py — Added
  test_triples_query_never_passes_workspace (checks query()) and
  test_follow_edges_never_passes_workspace (checks query_stream()).
This commit is contained in:
cybermaggedon 2026-05-14 12:03:43 +01:00 committed by GitHub
parent f0ad282708
commit bb1109963c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 394 additions and 44 deletions

View file

@ -0,0 +1,302 @@
"""
Tests for the SPARQL algebra evaluator.
Verifies that evaluate() and _query_pattern() call TriplesClient.query()
with the correct arguments, and in particular that workspace is never
passed — workspace isolation is handled by pub/sub topic routing.
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, call
from rdflib.term import Variable, URIRef, Literal
from rdflib.plugins.sparql.parserutils import CompValue
from trustgraph.schema import Term, IRI, LITERAL
from trustgraph.query.sparql.algebra import (
evaluate, _query_pattern, _eval_bgp,
)
# --- Helpers ---
def iri(v):
    """Return an IRI-typed Term wrapping the given IRI string."""
    return Term(type=IRI, iri=v)
def lit(v):
    """Return a literal-typed Term wrapping the given value."""
    return Term(type=LITERAL, value=v)
def make_triple(s, p, o):
    """Create a mock triple exposing .s, .p and .o attributes."""
    triple = MagicMock()
    triple.s, triple.p, triple.o = s, p, o
    return triple
def make_bgp(*patterns):
    """Build a CompValue BGP node from (s, p, o) tuples of rdflib terms."""
    bgp = CompValue("BGP")
    bgp.triples = [*patterns]
    return bgp
def make_project(inner, variables):
    """Wrap *inner* in a Project node selecting the named variables."""
    proj = CompValue("Project")
    proj.p = inner
    proj.PV = [Variable(name) for name in variables]
    return proj
def make_select(inner):
    """Wrap *inner* in a SelectQuery node."""
    select = CompValue("SelectQuery")
    select.p = inner
    return select
def make_join(left, right):
    """Build a Join node over two child algebra nodes."""
    join = CompValue("Join")
    join.p1, join.p2 = left, right
    return join
def make_union(left, right):
    """Build a Union node over two child algebra nodes."""
    union_node = CompValue("Union")
    union_node.p1, union_node.p2 = left, right
    return union_node
def make_slice(inner, start, length):
    """Build a Slice (LIMIT/OFFSET) node around *inner*."""
    slice_node = CompValue("Slice")
    slice_node.p = inner
    slice_node.start, slice_node.length = start, length
    return slice_node
def make_distinct(inner):
    """Wrap *inner* in a Distinct node."""
    distinct_node = CompValue("Distinct")
    distinct_node.p = inner
    return distinct_node
class TestQueryPattern:
    """Tests for _query_pattern — the leaf that calls TriplesClient."""

    @pytest.mark.asyncio
    async def test_passes_correct_args(self):
        """_query_pattern forwards s/p/o, limit and collection verbatim."""
        tc = AsyncMock()
        tc.query.return_value = []
        await _query_pattern(
            tc,
            s=iri("http://example.com/s"),
            p=iri("http://example.com/p"),
            o=None,
            collection="my-collection",
            limit=100,
        )
        # Exactly one query, and with only the kwargs TriplesClient.query()
        # actually accepts — in particular, no workspace.
        tc.query.assert_called_once_with(
            s=iri("http://example.com/s"),
            p=iri("http://example.com/p"),
            o=None,
            limit=100,
            collection="my-collection",
        )

    @pytest.mark.asyncio
    async def test_workspace_not_passed(self):
        """workspace must never appear among the query() keyword arguments."""
        tc = AsyncMock()
        tc.query.return_value = []
        await _query_pattern(tc, None, None, None, "default", 10)
        kwargs = tc.query.call_args.kwargs
        assert "workspace" not in kwargs

    @pytest.mark.asyncio
    async def test_returns_query_results(self):
        """Results from TriplesClient.query() are returned unchanged."""
        tc = AsyncMock()
        triple = make_triple(iri("http://a"), iri("http://b"), lit("c"))
        tc.query.return_value = [triple]
        results = await _query_pattern(tc, None, None, None, "default", 10)
        assert len(results) == 1
        assert results[0].s.iri == "http://a"
class TestEvalBgp:
    """Tests for BGP evaluation — triple pattern queries."""

    @pytest.mark.asyncio
    async def test_single_pattern_all_variables(self):
        """A fully-variable pattern binds s, p and o from each triple."""
        tc = AsyncMock()
        triple = make_triple(iri("http://s"), iri("http://p"), lit("o"))
        tc.query.return_value = [triple]
        bgp = make_bgp(
            (Variable("s"), Variable("p"), Variable("o")),
        )
        solutions = await evaluate(bgp, tc, collection="default", limit=100)
        assert len(solutions) == 1
        assert solutions[0]["s"].iri == "http://s"
        assert solutions[0]["p"].iri == "http://p"
        assert solutions[0]["o"].value == "o"

    @pytest.mark.asyncio
    async def test_single_pattern_bound_subject(self):
        """A bound subject still queries once; workspace never appears."""
        tc = AsyncMock()
        tc.query.return_value = [
            make_triple(iri("http://s"), iri("http://p"), lit("val")),
        ]
        bgp = make_bgp(
            (URIRef("http://s"), Variable("p"), Variable("o")),
        )
        # The return value is irrelevant here — this test only inspects
        # the arguments that reached TriplesClient.query().
        await evaluate(bgp, tc, collection="default")
        tc.query.assert_called_once()
        kwargs = tc.query.call_args.kwargs
        assert "workspace" not in kwargs
        assert kwargs["collection"] == "default"

    @pytest.mark.asyncio
    async def test_empty_bgp_returns_empty_solution(self):
        """An empty BGP yields the single empty solution without querying."""
        tc = AsyncMock()
        bgp = make_bgp()
        solutions = await evaluate(bgp, tc, collection="default")
        assert solutions == [{}]
        tc.query.assert_not_called()

    @pytest.mark.asyncio
    async def test_no_results_returns_empty(self):
        """No matching triples means no solutions at all."""
        tc = AsyncMock()
        tc.query.return_value = []
        bgp = make_bgp(
            (Variable("s"), Variable("p"), Variable("o")),
        )
        solutions = await evaluate(bgp, tc, collection="default")
        assert solutions == []
class TestEvaluate:
    """Tests for the top-level evaluate() dispatcher."""

    @pytest.mark.asyncio
    async def test_select_query_node(self):
        """SelectQuery/Project restricts solutions to the projected vars."""
        tc = AsyncMock()
        tc.query.return_value = [
            make_triple(iri("http://s"), iri("http://p"), lit("o")),
        ]
        bgp = make_bgp(
            (Variable("s"), Variable("p"), Variable("o")),
        )
        select = make_select(make_project(bgp, ["s", "p"]))
        solutions = await evaluate(select, tc, collection="default")
        assert len(solutions) == 1
        assert "s" in solutions[0]
        assert "p" in solutions[0]
        # "o" was bound by the BGP but projected away.
        assert "o" not in solutions[0]

    @pytest.mark.asyncio
    async def test_workspace_never_in_query_calls(self):
        """Verify that no matter the algebra structure, workspace is never
        passed to TriplesClient.query()."""
        tc = AsyncMock()
        tc.query.return_value = [
            make_triple(iri("http://s"), iri("http://p"), lit("o")),
        ]
        bgp1 = make_bgp((Variable("s"), Variable("p"), Variable("o")))
        bgp2 = make_bgp((Variable("a"), Variable("b"), Variable("c")))
        tree = make_select(make_project(
            make_union(bgp1, bgp2), ["s", "p", "o"]
        ))
        await evaluate(tree, tc, collection="test-coll")
        # Check every query() call made anywhere in the tree.
        for c in tc.query.call_args_list:
            assert "workspace" not in c.kwargs

    @pytest.mark.asyncio
    async def test_join(self):
        """Join merges solutions that agree on the shared variable ?s."""
        tc = AsyncMock()
        # One result set per BGP, in evaluation order.
        tc.query.side_effect = [
            [make_triple(iri("http://a"), iri("http://p"), lit("v"))],
            [make_triple(iri("http://a"), iri("http://q"), lit("w"))],
        ]
        bgp1 = make_bgp((Variable("s"), URIRef("http://p"), Variable("v1")))
        bgp2 = make_bgp((Variable("s"), URIRef("http://q"), Variable("v2")))
        tree = make_join(bgp1, bgp2)
        solutions = await evaluate(tree, tc, collection="default")
        assert len(solutions) == 1
        assert solutions[0]["s"].iri == "http://a"

    @pytest.mark.asyncio
    async def test_slice(self):
        """Slice applies OFFSET/LIMIT: start=1, length=2 of 5 → 2 rows."""
        tc = AsyncMock()
        triples = [
            make_triple(iri(f"http://s{i}"), iri("http://p"), lit(f"o{i}"))
            for i in range(5)
        ]
        tc.query.return_value = triples
        bgp = make_bgp((Variable("s"), Variable("p"), Variable("o")))
        tree = make_slice(bgp, start=1, length=2)
        solutions = await evaluate(tree, tc, collection="default")
        assert len(solutions) == 2

    @pytest.mark.asyncio
    async def test_distinct(self):
        """Distinct collapses duplicate solutions."""
        tc = AsyncMock()
        triple = make_triple(iri("http://s"), iri("http://p"), lit("o"))
        tc.query.return_value = [triple, triple]
        bgp = make_bgp((Variable("s"), Variable("p"), Variable("o")))
        tree = make_distinct(bgp)
        solutions = await evaluate(tree, tc, collection="default")
        assert len(solutions) == 1

    @pytest.mark.asyncio
    async def test_unsupported_node_returns_empty_solution(self):
        """Unknown node names fall back to [{}] without querying."""
        tc = AsyncMock()
        node = CompValue("SomethingUnknown")
        solutions = await evaluate(node, tc, collection="default")
        assert solutions == [{}]
        tc.query.assert_not_called()

    @pytest.mark.asyncio
    async def test_non_compvalue_returns_empty_solution(self):
        """Non-CompValue input also falls back to the empty solution."""
        tc = AsyncMock()
        solutions = await evaluate("not a node", tc, collection="default")
        assert solutions == [{}]

View file

@ -337,6 +337,57 @@ class TestQuery:
cache_key = "test_collection:unlabeled_entity" cache_key = "test_collection:unlabeled_entity"
mock_cache.put.assert_called_once_with(cache_key, "unlabeled_entity") mock_cache.put.assert_called_once_with(cache_key, "unlabeled_entity")
@pytest.mark.asyncio
async def test_triples_query_never_passes_workspace(self):
"""Workspace isolation is handled by pub/sub topic routing, not
by passing workspace to TriplesClient.query(). Verify that
GraphRAG never passes workspace as a keyword argument."""
mock_rag = MagicMock()
mock_cache = MagicMock()
mock_cache.get.return_value = None
mock_rag.label_cache = mock_cache
mock_triples_client = AsyncMock()
mock_rag.triples_client = mock_triples_client
mock_triple = MagicMock()
mock_triple.o = "Label"
mock_triples_client.query.return_value = [mock_triple]
query = Query(
rag=mock_rag,
collection="test_collection",
verbose=False
)
await query.maybe_label("http://example.com/entity")
for c in mock_triples_client.query.call_args_list:
assert "workspace" not in c.kwargs
@pytest.mark.asyncio
async def test_follow_edges_never_passes_workspace(self):
"""Verify follow_edges never passes workspace to query_stream."""
mock_rag = MagicMock()
mock_triples_client = AsyncMock()
mock_rag.triples_client = mock_triples_client
mock_triple = MagicMock()
mock_triple.s, mock_triple.p, mock_triple.o = "e1", "p1", "o1"
mock_triples_client.query_stream.return_value = [mock_triple]
query = Query(
rag=mock_rag,
collection="test_collection",
verbose=False,
triple_limit=10
)
subgraph = set()
await query.follow_edges("e1", subgraph, path_length=1)
for c in mock_triples_client.query_stream.call_args_list:
assert "workspace" not in c.kwargs
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_follow_edges_basic_functionality(self): async def test_follow_edges_basic_functionality(self):
"""Test Query.follow_edges method basic triple discovery""" """Test Query.follow_edges method basic triple discovery"""

View file

@ -30,14 +30,13 @@ class EvaluationError(Exception):
pass pass
async def evaluate(node, triples_client, workspace, collection, limit=10000): async def evaluate(node, triples_client, collection, limit=10000):
""" """
Evaluate a SPARQL algebra node. Evaluate a SPARQL algebra node.
Args: Args:
node: rdflib CompValue algebra node node: rdflib CompValue algebra node
triples_client: TriplesClient instance for triple pattern queries triples_client: TriplesClient instance for triple pattern queries
workspace: workspace/keyspace identifier
collection: collection identifier collection: collection identifier
limit: safety limit on results limit: safety limit on results
@ -55,24 +54,24 @@ async def evaluate(node, triples_client, workspace, collection, limit=10000):
logger.warning(f"Unsupported algebra node: {name}") logger.warning(f"Unsupported algebra node: {name}")
return [{}] return [{}]
return await handler(node, triples_client, workspace, collection, limit) return await handler(node, triples_client, collection, limit)
# --- Node handlers --- # --- Node handlers ---
async def _eval_select_query(node, tc, workspace, collection, limit): async def _eval_select_query(node, tc, collection, limit):
"""Evaluate a SelectQuery node.""" """Evaluate a SelectQuery node."""
return await evaluate(node.p, tc, workspace, collection, limit) return await evaluate(node.p, tc, collection, limit)
async def _eval_project(node, tc, workspace, collection, limit): async def _eval_project(node, tc, collection, limit):
"""Evaluate a Project node (SELECT variable projection).""" """Evaluate a Project node (SELECT variable projection)."""
solutions = await evaluate(node.p, tc, workspace, collection, limit) solutions = await evaluate(node.p, tc, collection, limit)
variables = [str(v) for v in node.PV] variables = [str(v) for v in node.PV]
return project(solutions, variables) return project(solutions, variables)
async def _eval_bgp(node, tc, workspace, collection, limit): async def _eval_bgp(node, tc, collection, limit):
""" """
Evaluate a Basic Graph Pattern. Evaluate a Basic Graph Pattern.
@ -107,7 +106,7 @@ async def _eval_bgp(node, tc, workspace, collection, limit):
# Query the triples store # Query the triples store
results = await _query_pattern( results = await _query_pattern(
tc, s_val, p_val, o_val, workspace, collection, limit tc, s_val, p_val, o_val, collection, limit
) )
# Map results back to variable bindings, # Map results back to variable bindings,
@ -130,17 +129,17 @@ async def _eval_bgp(node, tc, workspace, collection, limit):
return solutions[:limit] return solutions[:limit]
async def _eval_join(node, tc, workspace, collection, limit): async def _eval_join(node, tc, collection, limit):
"""Evaluate a Join node.""" """Evaluate a Join node."""
left = await evaluate(node.p1, tc, workspace, collection, limit) left = await evaluate(node.p1, tc, collection, limit)
right = await evaluate(node.p2, tc, workspace, collection, limit) right = await evaluate(node.p2, tc, collection, limit)
return hash_join(left, right)[:limit] return hash_join(left, right)[:limit]
async def _eval_left_join(node, tc, workspace, collection, limit): async def _eval_left_join(node, tc, collection, limit):
"""Evaluate a LeftJoin node (OPTIONAL).""" """Evaluate a LeftJoin node (OPTIONAL)."""
left_sols = await evaluate(node.p1, tc, workspace, collection, limit) left_sols = await evaluate(node.p1, tc, collection, limit)
right_sols = await evaluate(node.p2, tc, workspace, collection, limit) right_sols = await evaluate(node.p2, tc, collection, limit)
filter_fn = None filter_fn = None
if hasattr(node, "expr") and node.expr is not None: if hasattr(node, "expr") and node.expr is not None:
@ -153,16 +152,16 @@ async def _eval_left_join(node, tc, workspace, collection, limit):
return left_join(left_sols, right_sols, filter_fn)[:limit] return left_join(left_sols, right_sols, filter_fn)[:limit]
async def _eval_union(node, tc, workspace, collection, limit): async def _eval_union(node, tc, collection, limit):
"""Evaluate a Union node.""" """Evaluate a Union node."""
left = await evaluate(node.p1, tc, workspace, collection, limit) left = await evaluate(node.p1, tc, collection, limit)
right = await evaluate(node.p2, tc, workspace, collection, limit) right = await evaluate(node.p2, tc, collection, limit)
return union(left, right)[:limit] return union(left, right)[:limit]
async def _eval_filter(node, tc, workspace, collection, limit): async def _eval_filter(node, tc, collection, limit):
"""Evaluate a Filter node.""" """Evaluate a Filter node."""
solutions = await evaluate(node.p, tc, workspace, collection, limit) solutions = await evaluate(node.p, tc, collection, limit)
expr = node.expr expr = node.expr
return [ return [
sol for sol in solutions sol for sol in solutions
@ -170,22 +169,22 @@ async def _eval_filter(node, tc, workspace, collection, limit):
] ]
async def _eval_distinct(node, tc, workspace, collection, limit): async def _eval_distinct(node, tc, collection, limit):
"""Evaluate a Distinct node.""" """Evaluate a Distinct node."""
solutions = await evaluate(node.p, tc, workspace, collection, limit) solutions = await evaluate(node.p, tc, collection, limit)
return distinct(solutions) return distinct(solutions)
async def _eval_reduced(node, tc, workspace, collection, limit): async def _eval_reduced(node, tc, collection, limit):
"""Evaluate a Reduced node (like Distinct but implementation-defined).""" """Evaluate a Reduced node (like Distinct but implementation-defined)."""
# Treat same as Distinct # Treat same as Distinct
solutions = await evaluate(node.p, tc, workspace, collection, limit) solutions = await evaluate(node.p, tc, collection, limit)
return distinct(solutions) return distinct(solutions)
async def _eval_order_by(node, tc, workspace, collection, limit): async def _eval_order_by(node, tc, collection, limit):
"""Evaluate an OrderBy node.""" """Evaluate an OrderBy node."""
solutions = await evaluate(node.p, tc, workspace, collection, limit) solutions = await evaluate(node.p, tc, collection, limit)
key_fns = [] key_fns = []
for cond in node.expr: for cond in node.expr:
@ -206,7 +205,7 @@ async def _eval_order_by(node, tc, workspace, collection, limit):
return order_by(solutions, key_fns) return order_by(solutions, key_fns)
async def _eval_slice(node, tc, workspace, collection, limit): async def _eval_slice(node, tc, collection, limit):
"""Evaluate a Slice node (LIMIT/OFFSET).""" """Evaluate a Slice node (LIMIT/OFFSET)."""
# Pass tighter limit downstream if possible # Pass tighter limit downstream if possible
inner_limit = limit inner_limit = limit
@ -214,13 +213,13 @@ async def _eval_slice(node, tc, workspace, collection, limit):
offset = node.start or 0 offset = node.start or 0
inner_limit = min(limit, offset + node.length) inner_limit = min(limit, offset + node.length)
solutions = await evaluate(node.p, tc, workspace, collection, inner_limit) solutions = await evaluate(node.p, tc, collection, inner_limit)
return slice_solutions(solutions, node.start or 0, node.length) return slice_solutions(solutions, node.start or 0, node.length)
async def _eval_extend(node, tc, workspace, collection, limit): async def _eval_extend(node, tc, collection, limit):
"""Evaluate an Extend node (BIND).""" """Evaluate an Extend node (BIND)."""
solutions = await evaluate(node.p, tc, workspace, collection, limit) solutions = await evaluate(node.p, tc, collection, limit)
var_name = str(node.var) var_name = str(node.var)
expr = node.expr expr = node.expr
@ -246,9 +245,9 @@ async def _eval_extend(node, tc, workspace, collection, limit):
return result return result
async def _eval_group(node, tc, workspace, collection, limit): async def _eval_group(node, tc, collection, limit):
"""Evaluate a Group node (GROUP BY with aggregation).""" """Evaluate a Group node (GROUP BY with aggregation)."""
solutions = await evaluate(node.p, tc, workspace, collection, limit) solutions = await evaluate(node.p, tc, collection, limit)
# Extract grouping expressions # Extract grouping expressions
group_exprs = [] group_exprs = []
@ -289,9 +288,9 @@ async def _eval_group(node, tc, workspace, collection, limit):
return result return result
async def _eval_aggregate_join(node, tc, workspace, collection, limit): async def _eval_aggregate_join(node, tc, collection, limit):
"""Evaluate an AggregateJoin (aggregation functions after GROUP BY).""" """Evaluate an AggregateJoin (aggregation functions after GROUP BY)."""
solutions = await evaluate(node.p, tc, workspace, collection, limit) solutions = await evaluate(node.p, tc, collection, limit)
result = [] result = []
for sol in solutions: for sol in solutions:
@ -310,7 +309,7 @@ async def _eval_aggregate_join(node, tc, workspace, collection, limit):
return result return result
async def _eval_graph(node, tc, workspace, collection, limit): async def _eval_graph(node, tc, collection, limit):
"""Evaluate a Graph node (GRAPH clause).""" """Evaluate a Graph node (GRAPH clause)."""
term = node.term term = node.term
@ -319,16 +318,16 @@ async def _eval_graph(node, tc, workspace, collection, limit):
# We'd need to pass graph to triples queries # We'd need to pass graph to triples queries
# For now, evaluate inner pattern normally # For now, evaluate inner pattern normally
logger.info(f"GRAPH <{term}> clause - graph filtering not yet wired") logger.info(f"GRAPH <{term}> clause - graph filtering not yet wired")
return await evaluate(node.p, tc, workspace, collection, limit) return await evaluate(node.p, tc, collection, limit)
elif isinstance(term, Variable): elif isinstance(term, Variable):
# GRAPH ?g { ... } — variable graph # GRAPH ?g { ... } — variable graph
logger.info(f"GRAPH ?{term} clause - variable graph not yet wired") logger.info(f"GRAPH ?{term} clause - variable graph not yet wired")
return await evaluate(node.p, tc, workspace, collection, limit) return await evaluate(node.p, tc, collection, limit)
else: else:
return await evaluate(node.p, tc, workspace, collection, limit) return await evaluate(node.p, tc, collection, limit)
async def _eval_values(node, tc, workspace, collection, limit): async def _eval_values(node, tc, collection, limit):
"""Evaluate a VALUES clause (inline data).""" """Evaluate a VALUES clause (inline data)."""
variables = [str(v) for v in node.var] variables = [str(v) for v in node.var]
solutions = [] solutions = []
@ -343,9 +342,9 @@ async def _eval_values(node, tc, workspace, collection, limit):
return solutions return solutions
async def _eval_to_multiset(node, tc, workspace, collection, limit): async def _eval_to_multiset(node, tc, collection, limit):
"""Evaluate a ToMultiSet node (subquery).""" """Evaluate a ToMultiSet node (subquery)."""
return await evaluate(node.p, tc, workspace, collection, limit) return await evaluate(node.p, tc, collection, limit)
# --- Aggregate computation --- # --- Aggregate computation ---
@ -487,7 +486,7 @@ def _resolve_term(tmpl, solution):
return rdflib_term_to_term(tmpl) return rdflib_term_to_term(tmpl)
async def _query_pattern(tc, s, p, o, workspace, collection, limit): async def _query_pattern(tc, s, p, o, collection, limit):
""" """
Issue a streaming triple pattern query via TriplesClient. Issue a streaming triple pattern query via TriplesClient.
@ -496,7 +495,6 @@ async def _query_pattern(tc, s, p, o, workspace, collection, limit):
results = await tc.query( results = await tc.query(
s=s, p=p, o=o, s=s, p=p, o=o,
limit=limit, limit=limit,
workspace=workspace,
collection=collection, collection=collection,
) )
return results return results

View file

@ -141,7 +141,6 @@ class Processor(FlowProcessor):
solutions = await evaluate( solutions = await evaluate(
parsed.algebra, parsed.algebra,
triples_client, triples_client,
workspace=flow.workspace,
collection=request.collection or "default", collection=request.collection or "default",
limit=request.limit or 10000, limit=request.limit or 10000,
) )