diff --git a/Makefile b/Makefile index 4d79f554..197a6c63 100644 --- a/Makefile +++ b/Makefile @@ -77,8 +77,8 @@ some-containers: -t ${CONTAINER_BASE}/trustgraph-base:${VERSION} . ${DOCKER} build -f containers/Containerfile.flow \ -t ${CONTAINER_BASE}/trustgraph-flow:${VERSION} . - ${DOCKER} build -f containers/Containerfile.unstructured \ - -t ${CONTAINER_BASE}/trustgraph-unstructured:${VERSION} . +# ${DOCKER} build -f containers/Containerfile.unstructured \ +# -t ${CONTAINER_BASE}/trustgraph-unstructured:${VERSION} . # ${DOCKER} build -f containers/Containerfile.vertexai \ # -t ${CONTAINER_BASE}/trustgraph-vertexai:${VERSION} . # ${DOCKER} build -f containers/Containerfile.mcp \ diff --git a/docs/tech-specs/sparql-query.md b/docs/tech-specs/sparql-query.md new file mode 100644 index 00000000..97e7e115 --- /dev/null +++ b/docs/tech-specs/sparql-query.md @@ -0,0 +1,268 @@ +# SPARQL Query Service Technical Specification + +## Overview + +A pub/sub-hosted SPARQL query service that accepts SPARQL queries, decomposes +them into triple pattern lookups via the existing triples query pub/sub +interface, performs in-memory joins/filters/projections, and returns SPARQL +result bindings. + +This makes the triple store queryable using a standard graph query language +without coupling to any specific backend (Neo4j, Cassandra, FalkorDB, etc.). + +## Goals + +- **SPARQL 1.1 support**: SELECT, ASK, CONSTRUCT, DESCRIBE queries +- **Backend-agnostic**: query via the pub/sub triples interface, not direct + database access +- **Standard service pattern**: FlowProcessor with ConsumerSpec/ProducerSpec, + using TriplesClientSpec to call the triples query service +- **Correct SPARQL semantics**: proper BGP evaluation, joins, OPTIONAL, UNION, + FILTER, BIND, aggregation, solution modifiers (ORDER BY, LIMIT, OFFSET, + DISTINCT) + +## Background + +The triples query service provides a single-pattern lookup: given optional +(s, p, o) values, return matching triples. 
This is the equivalent of one +triple pattern in a SPARQL Basic Graph Pattern. + +To evaluate a full SPARQL query, we need to: +1. Parse the SPARQL string into an algebra tree +2. Walk the algebra tree, issuing triple pattern lookups for each BGP pattern +3. Join results across patterns (nested-loop or hash join) +4. Apply filters, optionals, unions, and aggregations in-memory +5. Project and return the requested variables + +rdflib (already a dependency) provides a SPARQL 1.1 parser and algebra +compiler. We use rdflib to parse queries into algebra trees, then evaluate +the algebra ourselves using the triples query client as the data source. + +## Technical Design + +### Architecture + +``` + pub/sub + [Client] ──request──> [SPARQL Query Service] ──triples-request──> [Triples Query Service] + [Client] <─response── [SPARQL Query Service] <─triples-response── [Triples Query Service] +``` + +The service is a FlowProcessor that: +- Consumes SPARQL query requests +- Uses TriplesClientSpec to issue triple pattern lookups +- Evaluates the SPARQL algebra in-memory +- Produces result responses + +### Components + +1. **SPARQL Query Service (FlowProcessor)** + - ConsumerSpec for incoming SPARQL requests + - ProducerSpec for outgoing results + - TriplesClientSpec for calling the triples query service + - Delegates parsing and evaluation to the components below + + Module: `trustgraph-flow/trustgraph/query/sparql/service.py` + +2. **SPARQL Parser (rdflib wrapper)** + - Uses `rdflib.plugins.sparql.prepareQuery` / `parseQuery` and + `rdflib.plugins.sparql.algebra.translateQuery` to produce an algebra tree + - Extracts PREFIX declarations, query type (SELECT/ASK/CONSTRUCT/DESCRIBE), + and the algebra root + + Module: `trustgraph-flow/trustgraph/query/sparql/parser.py` + +3. 
**Algebra Evaluator** + - Recursive evaluator over the rdflib algebra tree + - Each algebra node type maps to an evaluation function + - BGP nodes issue triple pattern queries via TriplesClient + - Join/Filter/Optional/Union etc. operate on in-memory solution sequences + + Module: `trustgraph-flow/trustgraph/query/sparql/algebra.py` + +4. **Solution Sequence** + - A solution is a dict mapping variable names to Term values + - Solution sequences are lists of solutions + - Join: hash join on shared variables + - LeftJoin (OPTIONAL): hash join preserving unmatched left rows + - Union: concatenation + - Filter: evaluate SPARQL expressions against each solution + - Projection/Distinct/Order/Slice: standard post-processing + + Module: `trustgraph-flow/trustgraph/query/sparql/solutions.py` + +### Data Models + +#### Request + +```python +@dataclass +class SparqlQueryRequest: + user: str = "" + collection: str = "" + query: str = "" # SPARQL query string + limit: int = 10000 # Safety limit on results +``` + +#### Response + +```python +@dataclass +class SparqlQueryResponse: + error: Error | None = None + query_type: str = "" # "select", "ask", "construct", "describe" + + # For SELECT queries + variables: list[str] = field(default_factory=list) + bindings: list[SparqlBinding] = field(default_factory=list) + + # For ASK queries + ask_result: bool = False + + # For CONSTRUCT/DESCRIBE queries + triples: list[Triple] = field(default_factory=list) + +@dataclass +class SparqlBinding: + values: list[Term | None] = field(default_factory=list) +``` + +### BGP Evaluation Strategy + +For each triple pattern in a BGP: +- Extract bound terms (concrete IRIs/literals) and variables +- Call `TriplesClient.query_stream(s, p, o)` with bound terms, None for + variables +- Map returned triples back to variable bindings + +For multi-pattern BGPs, join solutions incrementally: +- Order patterns by selectivity (patterns with more bound terms first) +- For each subsequent pattern, substitute bound 
variables from the current + solution sequence before querying +- This avoids full cross-products and reduces the number of triples queries + +### Streaming and Early Termination + +The triples query service supports streaming responses (batched delivery via +`TriplesClient.query_stream`). The SPARQL evaluator should use streaming +from the start, not as an optimisation. This is important because: + +- **Early termination**: when the SPARQL query has a LIMIT, or when only one + solution is needed (ASK queries), we can stop consuming triples as soon as + we have enough results. Without streaming, a wildcard pattern like + `?s ?p ?o` would fetch the entire graph before we could apply the limit. +- **Memory efficiency**: results are processed batch-by-batch rather than + materialising the full result set in memory before joining. + +The batch callback in `query_stream` returns a boolean to signal completion. +The evaluator should signal completion (return True) as soon as sufficient +solutions have been produced, allowing the underlying pub/sub consumer to +stop pulling batches. + +### Parallel BGP Execution (Phase 2 Optimisation) + +Within a BGP, patterns that share variables benefit from sequential +evaluation with bound-variable substitution (query results from earlier +patterns narrow later queries). However, patterns with no shared variables +are independent and could be issued concurrently via `asyncio.gather`. + +A practical approach for a future optimisation pass: +- Analyse BGP patterns and identify connected components (groups of + patterns linked by shared variables) +- Execute independent components in parallel +- Within each component, evaluate patterns sequentially with substitution + +This is not needed for correctness -- the sequential approach works for all +cases -- but could significantly reduce latency for queries with independent +pattern groups. Flagged as a phase 2 optimisation. 
+ +### FILTER Expression Evaluation + +rdflib's algebra represents FILTER expressions as expression trees. We +evaluate these against each solution row, supporting: +- Comparison operators (=, !=, <, >, <=, >=) +- Logical operators (&&, ||, !) +- SPARQL built-in functions (isIRI, isLiteral, isBlank, str, lang, + datatype, bound, regex, etc.) +- Arithmetic operators (+, -, *, /) + +## Implementation Order + +1. **Schema and service skeleton** -- define SparqlQueryRequest/Response + dataclasses, create the FlowProcessor subclass with ConsumerSpec, + ProducerSpec, and TriplesClientSpec wired up. Verify it starts and + connects. + +2. **SPARQL parsing** -- wrap rdflib's parser to produce algebra trees from + SPARQL strings. Handle parse errors gracefully. Unit test with a range of + query shapes. + +3. **BGP evaluation** -- implement single-pattern and multi-pattern BGP + evaluation using TriplesClient. This is the core building block. Test + with simple SELECT WHERE { ?s ?p ?o } queries. + +4. **Joins and solution sequences** -- implement hash join, left join (for + OPTIONAL), and union. Test with multi-pattern queries. + +5. **FILTER evaluation** -- implement the expression evaluator for FILTER + clauses. Start with comparisons and logical operators, then add built-in + functions incrementally. + +6. **Solution modifiers** -- DISTINCT, ORDER BY, LIMIT, OFFSET, projection. + +7. **ASK / CONSTRUCT / DESCRIBE** -- extend beyond SELECT. ASK is trivial + (non-empty result = true). CONSTRUCT builds triples from a template. + DESCRIBE fetches all triples for matched resources. + +8. **Aggregation** -- GROUP BY, HAVING, COUNT, SUM, AVG, MIN, MAX, + GROUP_CONCAT, SAMPLE. + +9. **BIND, VALUES, subqueries** -- remaining SPARQL 1.1 features. + +10. **API gateway integration** -- add SparqlQueryRequestor dispatcher, + request/response translators, and API endpoint so that the SPARQL + service is accessible via the HTTP gateway. + +11. 
**SDK support** -- add `sparql_query()` method to FlowInstance in the + Python API SDK, following the same pattern as `triples_query()`. + +12. **CLI command** -- add a `tg-sparql-query` CLI command that takes a + SPARQL query string (or reads from a file/stdin), submits it via the + SDK, and prints results in a readable format (table for SELECT, + true/false for ASK, Turtle for CONSTRUCT/DESCRIBE). + +## Performance Considerations + +In-memory join over pub/sub round-trips will be slower than native SPARQL on +a graph database. Key mitigations: + +- **Streaming with early termination**: use `query_stream` so that + limit-bound queries don't fetch entire result sets. A `SELECT ... LIMIT 1` + against a wildcard pattern fetches one batch, not the whole graph. +- **Bound-variable substitution**: when evaluating BGP patterns sequentially, + substitute known bindings into subsequent patterns to issue narrow queries + rather than broad ones followed by in-memory filtering. +- **Parallel independent patterns** (phase 2): patterns with no shared + variables can be issued concurrently. +- **Query complexity limits**: may need a cap on the number of triple pattern + queries issued per SPARQL query to prevent runaway evaluation. + +### Named Graph Mapping + +SPARQL's `GRAPH ?g { ... }` and `GRAPH { ... }` clauses map to the +triples query service's graph filter parameter: + +- `GRAPH { ?s ?p ?o }` — pass `g=uri` to the triples query +- Patterns outside any GRAPH clause — pass `g=""` (default graph only) +- `GRAPH ?g { ?s ?p ?o }` — pass `g="*"` (all graphs), then bind `?g` from + the returned triple's graph field + +The triples query interface does not support a wildcard graph natively in +the SPARQL sense, but `g="*"` (all graphs) combined with client-side +filtering on the returned graph values achieves the same effect. 
+ +## Open Questions + +- **SPARQL 1.2**: rdflib's parser support for 1.2 features (property paths + are already in 1.1; 1.2 adds lateral joins, ADJUST, etc.). Start with + 1.1 and extend as rdflib support matures. diff --git a/specs/api/paths/flow/sparql-query.yaml b/specs/api/paths/flow/sparql-query.yaml new file mode 100644 index 00000000..2f343488 --- /dev/null +++ b/specs/api/paths/flow/sparql-query.yaml @@ -0,0 +1,145 @@ +post: + tags: + - Flow Services + summary: SPARQL query - execute SPARQL 1.1 queries against the knowledge graph + description: | + Execute a SPARQL 1.1 query against the knowledge graph. + + ## Supported Query Types + + - **SELECT**: Returns variable bindings as a table of results + - **ASK**: Returns true/false for existence checks + - **CONSTRUCT**: Returns a set of triples built from a template + - **DESCRIBE**: Returns triples describing matched resources + + ## SPARQL Features + + Supports standard SPARQL 1.1 features including: + - Basic Graph Patterns (BGPs) with triple pattern matching + - OPTIONAL, UNION, FILTER + - BIND, VALUES + - ORDER BY, LIMIT, OFFSET, DISTINCT + - GROUP BY with aggregates (COUNT, SUM, AVG, MIN, MAX, GROUP_CONCAT) + - Built-in functions (isIRI, STR, REGEX, CONTAINS, etc.) + + ## Query Examples + + Find all entities of a type: + ```sparql + SELECT ?s ?label WHERE { + ?s . + ?s ?label . 
+ } + LIMIT 10 + ``` + + Check if an entity exists: + ```sparql + ASK { ?p ?o } + ``` + + operationId: sparqlQueryService + security: + - bearerAuth: [] + parameters: + - name: flow + in: path + required: true + schema: + type: string + description: Flow instance ID + example: my-flow + requestBody: + required: true + content: + application/json: + schema: + type: object + required: + - query + properties: + query: + type: string + description: SPARQL 1.1 query string + user: + type: string + default: trustgraph + description: User/keyspace identifier + collection: + type: string + default: default + description: Collection identifier + limit: + type: integer + default: 10000 + description: Safety limit on number of results + examples: + selectQuery: + summary: SELECT query + value: + query: "SELECT ?s ?p ?o WHERE { ?s ?p ?o } LIMIT 10" + user: trustgraph + collection: default + askQuery: + summary: ASK query + value: + query: "ASK { ?p ?o }" + responses: + '200': + description: Successful response + content: + application/json: + schema: + type: object + properties: + query-type: + type: string + enum: [select, ask, construct, describe] + variables: + type: array + items: + type: string + description: Variable names (SELECT only) + bindings: + type: array + items: + type: object + properties: + values: + type: array + items: + $ref: '../../components/schemas/common/RdfValue.yaml' + description: Result rows (SELECT only) + ask-result: + type: boolean + description: Boolean result (ASK only) + triples: + type: array + description: Result triples (CONSTRUCT/DESCRIBE only) + error: + type: object + properties: + type: + type: string + message: + type: string + examples: + selectResult: + summary: SELECT result + value: + query-type: select + variables: [s, p, o] + bindings: + - values: + - {t: i, i: "http://example.com/alice"} + - {t: i, i: "http://www.w3.org/1999/02/22-rdf-syntax-ns#type"} + - {t: i, i: "http://example.com/Person"} + askResult: + summary: ASK result 
+ value: + query-type: ask + ask-result: true + '401': + $ref: '../../components/responses/Unauthorized.yaml' + '500': + $ref: '../../components/responses/Error.yaml' diff --git a/tests/unit/test_query/test_sparql_expressions.py b/tests/unit/test_query/test_sparql_expressions.py new file mode 100644 index 00000000..63e9188f --- /dev/null +++ b/tests/unit/test_query/test_sparql_expressions.py @@ -0,0 +1,424 @@ +""" +Tests for SPARQL FILTER expression evaluator. +""" + +import pytest +from trustgraph.schema import Term, IRI, LITERAL, BLANK +from trustgraph.query.sparql.expressions import ( + evaluate_expression, _effective_boolean, _to_string, _to_numeric, + _comparable_value, +) + + +# --- Helpers --- + +def iri(v): + return Term(type=IRI, iri=v) + +def lit(v, datatype="", language=""): + return Term(type=LITERAL, value=v, datatype=datatype, language=language) + +def blank(v): + return Term(type=BLANK, id=v) + +XSD = "http://www.w3.org/2001/XMLSchema#" + + +class TestEvaluateExpression: + """Test expression evaluation with rdflib algebra nodes.""" + + def test_variable_bound(self): + from rdflib.term import Variable + result = evaluate_expression(Variable("x"), {"x": lit("hello")}) + assert result.value == "hello" + + def test_variable_unbound(self): + from rdflib.term import Variable + result = evaluate_expression(Variable("x"), {}) + assert result is None + + def test_uriref_constant(self): + from rdflib import URIRef + result = evaluate_expression( + URIRef("http://example.com/a"), {} + ) + assert result.type == IRI + assert result.iri == "http://example.com/a" + + def test_literal_constant(self): + from rdflib import Literal + result = evaluate_expression(Literal("hello"), {}) + assert result.type == LITERAL + assert result.value == "hello" + + def test_boolean_constant(self): + assert evaluate_expression(True, {}) is True + assert evaluate_expression(False, {}) is False + + def test_numeric_constant(self): + assert evaluate_expression(42, {}) == 42 + assert 
# rdflib machinery used throughout these tests: Variable/URIRef/Literal are
# the RDF term types, CompValue is the container rdflib's SPARQL parser uses
# for algebra/expression nodes.  Hoisted here once instead of per-method.
from rdflib import Literal, URIRef
from rdflib.term import Variable
from rdflib.plugins.sparql.parserutils import CompValue


class TestEvaluateExpression:
    """Evaluation of atomic expression nodes: variables and constants."""

    def test_variable_bound(self):
        out = evaluate_expression(Variable("x"), {"x": lit("hello")})
        assert out.value == "hello"

    def test_variable_unbound(self):
        # An unbound variable evaluates to None (SPARQL "error" value).
        assert evaluate_expression(Variable("x"), {}) is None

    def test_uriref_constant(self):
        out = evaluate_expression(URIRef("http://example.com/a"), {})
        assert out.type == IRI
        assert out.iri == "http://example.com/a"

    def test_literal_constant(self):
        out = evaluate_expression(Literal("hello"), {})
        assert out.type == LITERAL
        assert out.value == "hello"

    def test_boolean_constant(self):
        assert evaluate_expression(True, {}) is True
        assert evaluate_expression(False, {}) is False

    def test_numeric_constant(self):
        assert evaluate_expression(42, {}) == 42
        assert evaluate_expression(3.14, {}) == 3.14

    def test_none_returns_true(self):
        # A missing expression is treated as trivially true.
        assert evaluate_expression(None, {}) is True


class TestRelationalExpressions:
    """Comparison operators expressed as RelationalExpression nodes."""

    @staticmethod
    def _rel(lhs, op, rhs):
        # Build a RelationalExpression algebra node: lhs <op> rhs.
        return CompValue("RelationalExpression", expr=lhs, op=op, other=rhs)

    def test_equal_literals(self):
        node = self._rel(Literal("a"), "=", Literal("a"))
        assert evaluate_expression(node, {}) is True

    def test_not_equal_literals(self):
        node = self._rel(Literal("a"), "!=", Literal("b"))
        assert evaluate_expression(node, {}) is True

    def test_less_than(self):
        node = self._rel(Literal("a"), "<", Literal("b"))
        assert evaluate_expression(node, {}) is True

    def test_greater_than(self):
        node = self._rel(Literal("b"), ">", Literal("a"))
        assert evaluate_expression(node, {}) is True

    def test_equal_with_variables(self):
        bindings = {"x": lit("same"), "y": lit("same")}
        node = self._rel(Variable("x"), "=", Variable("y"))
        assert evaluate_expression(node, bindings) is True

    def test_unequal_with_variables(self):
        bindings = {"x": lit("one"), "y": lit("two")}
        node = self._rel(Variable("x"), "=", Variable("y"))
        assert evaluate_expression(node, bindings) is False

    def test_none_operand_returns_false(self):
        # Comparing against an unbound variable yields false, not an error.
        node = self._rel(Variable("x"), "=", Literal("a"))
        assert evaluate_expression(node, {}) is False


class TestLogicalExpressions:
    """&&, || and ! combinators over boolean operands."""

    @staticmethod
    def _conj(operands):
        return CompValue("ConditionalAndExpression",
                         expr=operands[0], other=operands[1:])

    @staticmethod
    def _disj(operands):
        return CompValue("ConditionalOrExpression",
                         expr=operands[0], other=operands[1:])

    @staticmethod
    def _neg(operand):
        return CompValue("UnaryNot", expr=operand)

    def test_and_true_true(self):
        assert evaluate_expression(self._conj([True, True]), {}) is True

    def test_and_true_false(self):
        assert evaluate_expression(self._conj([True, False]), {}) is False

    def test_or_false_true(self):
        assert evaluate_expression(self._disj([False, True]), {}) is True

    def test_or_false_false(self):
        assert evaluate_expression(self._disj([False, False]), {}) is False

    def test_not_true(self):
        assert evaluate_expression(self._neg(True), {}) is False

    def test_not_false(self):
        assert evaluate_expression(self._neg(False), {}) is True


class TestBuiltinFunctions:
    """SPARQL built-in functions, encoded as Builtin_* algebra nodes."""

    @staticmethod
    def _builtin(name, **args):
        # rdflib names builtin nodes "Builtin_<NAME>" with per-builtin kwargs.
        return CompValue(f"Builtin_{name}", **args)

    def test_bound_true(self):
        node = self._builtin("BOUND", arg=Variable("x"))
        assert evaluate_expression(node, {"x": lit("hi")}) is True

    def test_bound_false(self):
        node = self._builtin("BOUND", arg=Variable("x"))
        assert evaluate_expression(node, {}) is False

    def test_isiri_true(self):
        node = self._builtin("isIRI", arg=Variable("x"))
        assert evaluate_expression(node, {"x": iri("http://x")}) is True

    def test_isiri_false(self):
        node = self._builtin("isIRI", arg=Variable("x"))
        assert evaluate_expression(node, {"x": lit("hello")}) is False

    def test_isliteral_true(self):
        node = self._builtin("isLITERAL", arg=Variable("x"))
        assert evaluate_expression(node, {"x": lit("hello")}) is True

    def test_isliteral_false(self):
        node = self._builtin("isLITERAL", arg=Variable("x"))
        assert evaluate_expression(node, {"x": iri("http://x")}) is False

    def test_isblank_true(self):
        node = self._builtin("isBLANK", arg=Variable("x"))
        assert evaluate_expression(node, {"x": blank("b1")}) is True

    def test_isblank_false(self):
        node = self._builtin("isBLANK", arg=Variable("x"))
        assert evaluate_expression(node, {"x": iri("http://x")}) is False

    def test_str(self):
        node = self._builtin("STR", arg=Variable("x"))
        out = evaluate_expression(node, {"x": iri("http://example.com/a")})
        assert out.type == LITERAL
        assert out.value == "http://example.com/a"

    def test_lang(self):
        node = self._builtin("LANG", arg=Variable("x"))
        out = evaluate_expression(node, {"x": lit("hello", language="en")})
        assert out.value == "en"

    def test_lang_no_tag(self):
        node = self._builtin("LANG", arg=Variable("x"))
        out = evaluate_expression(node, {"x": lit("hello")})
        assert out.value == ""

    def test_datatype(self):
        node = self._builtin("DATATYPE", arg=Variable("x"))
        out = evaluate_expression(
            node, {"x": lit("42", datatype=XSD + "integer")}
        )
        assert out.type == IRI
        assert out.iri == XSD + "integer"

    def test_strlen(self):
        node = self._builtin("STRLEN", arg=Variable("x"))
        assert evaluate_expression(node, {"x": lit("hello")}) == 5

    def test_ucase(self):
        node = self._builtin("UCASE", arg=Variable("x"))
        assert evaluate_expression(node, {"x": lit("hello")}).value == "HELLO"

    def test_lcase(self):
        node = self._builtin("LCASE", arg=Variable("x"))
        assert evaluate_expression(node, {"x": lit("HELLO")}).value == "hello"

    def test_contains_true(self):
        node = self._builtin("CONTAINS",
                             arg1=Variable("x"), arg2=Literal("ell"))
        assert evaluate_expression(node, {"x": lit("hello")}) is True

    def test_contains_false(self):
        node = self._builtin("CONTAINS",
                             arg1=Variable("x"), arg2=Literal("xyz"))
        assert evaluate_expression(node, {"x": lit("hello")}) is False

    def test_strstarts_true(self):
        node = self._builtin("STRSTARTS",
                             arg1=Variable("x"), arg2=Literal("hel"))
        assert evaluate_expression(node, {"x": lit("hello")}) is True

    def test_strends_true(self):
        node = self._builtin("STRENDS",
                             arg1=Variable("x"), arg2=Literal("llo"))
        assert evaluate_expression(node, {"x": lit("hello")}) is True

    def test_regex_match(self):
        node = self._builtin("REGEX", text=Variable("x"),
                             pattern=Literal("^hel"), flags=None)
        assert evaluate_expression(node, {"x": lit("hello")}) is True

    def test_regex_case_insensitive(self):
        node = self._builtin("REGEX", text=Variable("x"),
                             pattern=Literal("HELLO"), flags=Literal("i"))
        assert evaluate_expression(node, {"x": lit("hello")}) is True

    def test_regex_no_match(self):
        node = self._builtin("REGEX", text=Variable("x"),
                             pattern=Literal("^world"), flags=None)
        assert evaluate_expression(node, {"x": lit("hello")}) is False


class TestEffectiveBoolean:
    """Effective boolean value (EBV) coercion of Python and Term values."""

    def test_true(self):
        assert _effective_boolean(True) is True

    def test_false(self):
        assert _effective_boolean(False) is False

    def test_none(self):
        assert _effective_boolean(None) is False

    def test_nonzero_int(self):
        assert _effective_boolean(42) is True

    def test_zero_int(self):
        assert _effective_boolean(0) is False

    def test_nonempty_string(self):
        assert _effective_boolean("hello") is True

    def test_empty_string(self):
        assert _effective_boolean("") is False

    def test_iri_term(self):
        assert _effective_boolean(iri("http://x")) is True

    def test_nonempty_literal(self):
        assert _effective_boolean(lit("hello")) is True

    def test_empty_literal(self):
        assert _effective_boolean(lit("")) is False

    def test_boolean_literal_true(self):
        assert _effective_boolean(lit("true", datatype=XSD + "boolean")) is True

    def test_boolean_literal_false(self):
        assert _effective_boolean(lit("false", datatype=XSD + "boolean")) is False

    def test_numeric_literal_nonzero(self):
        assert _effective_boolean(lit("42", datatype=XSD + "integer")) is True

    def test_numeric_literal_zero(self):
        assert _effective_boolean(lit("0", datatype=XSD + "integer")) is False


class TestToString:
    """String coercion of None, plain strings and Terms."""

    def test_none(self):
        assert _to_string(None) == ""

    def test_string(self):
        assert _to_string("hello") == "hello"

    def test_iri_term(self):
        assert _to_string(iri("http://example.com")) == "http://example.com"

    def test_literal_term(self):
        assert _to_string(lit("hello")) == "hello"

    def test_blank_term(self):
        assert _to_string(blank("b1")) == "b1"


class TestToNumeric:
    """Numeric coercion; non-numeric inputs map to None."""

    def test_none(self):
        assert _to_numeric(None) is None

    def test_int(self):
        assert _to_numeric(42) == 42

    def test_float(self):
        assert _to_numeric(3.14) == 3.14

    def test_integer_literal(self):
        assert _to_numeric(lit("42")) == 42

    def test_decimal_literal(self):
        assert _to_numeric(lit("3.14")) == 3.14

    def test_non_numeric_literal(self):
        assert _to_numeric(lit("hello")) is None

    def test_numeric_string(self):
        assert _to_numeric("42") == 42

    def test_non_numeric_string(self):
        assert _to_numeric("abc") is None
test_integer_literal(self): + assert _to_numeric(lit("42")) == 42 + + def test_decimal_literal(self): + assert _to_numeric(lit("3.14")) == 3.14 + + def test_non_numeric_literal(self): + assert _to_numeric(lit("hello")) is None + + def test_numeric_string(self): + assert _to_numeric("42") == 42 + + def test_non_numeric_string(self): + assert _to_numeric("abc") is None + + +class TestComparableValue: + + def test_none(self): + assert _comparable_value(None) == (0, "") + + def test_int(self): + assert _comparable_value(42) == (2, 42) + + def test_iri(self): + assert _comparable_value(iri("http://x")) == (4, "http://x") + + def test_literal(self): + assert _comparable_value(lit("hello")) == (3, "hello") + + def test_numeric_literal(self): + assert _comparable_value(lit("42")) == (2, 42) + + def test_ordering(self): + vals = [lit("b"), lit("a"), lit("c")] + sorted_vals = sorted(vals, key=_comparable_value) + assert sorted_vals[0].value == "a" + assert sorted_vals[1].value == "b" + assert sorted_vals[2].value == "c" diff --git a/tests/unit/test_query/test_sparql_parser.py b/tests/unit/test_query/test_sparql_parser.py new file mode 100644 index 00000000..5ac9fad9 --- /dev/null +++ b/tests/unit/test_query/test_sparql_parser.py @@ -0,0 +1,205 @@ +""" +Tests for the SPARQL parser module. 
+""" + +import pytest +from trustgraph.query.sparql.parser import ( + parse_sparql, ParseError, rdflib_term_to_term, term_to_rdflib, +) +from trustgraph.schema import Term, IRI, LITERAL, BLANK + + +class TestParseSparql: + """Tests for parse_sparql function.""" + + def test_select_query_type(self): + parsed = parse_sparql("SELECT ?s ?p ?o WHERE { ?s ?p ?o }") + assert parsed.query_type == "select" + + def test_select_variables(self): + parsed = parse_sparql("SELECT ?s ?p ?o WHERE { ?s ?p ?o }") + assert parsed.variables == ["s", "p", "o"] + + def test_select_subset_variables(self): + parsed = parse_sparql("SELECT ?s ?o WHERE { ?s ?p ?o }") + assert parsed.variables == ["s", "o"] + + def test_ask_query_type(self): + parsed = parse_sparql( + "ASK { ?p ?o }" + ) + assert parsed.query_type == "ask" + + def test_ask_no_variables(self): + parsed = parse_sparql( + "ASK { ?p ?o }" + ) + assert parsed.variables == [] + + def test_construct_query_type(self): + parsed = parse_sparql( + "CONSTRUCT { ?s ?o } " + "WHERE { ?s ?o }" + ) + assert parsed.query_type == "construct" + + def test_describe_query_type(self): + parsed = parse_sparql( + "DESCRIBE " + ) + assert parsed.query_type == "describe" + + def test_select_with_limit(self): + parsed = parse_sparql( + "SELECT ?s WHERE { ?s ?p ?o } LIMIT 10" + ) + assert parsed.query_type == "select" + assert parsed.variables == ["s"] + + def test_select_with_distinct(self): + parsed = parse_sparql( + "SELECT DISTINCT ?s WHERE { ?s ?p ?o }" + ) + assert parsed.query_type == "select" + assert parsed.variables == ["s"] + + def test_select_with_filter(self): + parsed = parse_sparql( + 'SELECT ?s ?label WHERE { ' + ' ?s ?label . ' + ' FILTER(CONTAINS(STR(?label), "test")) ' + '}' + ) + assert parsed.query_type == "select" + assert parsed.variables == ["s", "label"] + + def test_select_with_optional(self): + parsed = parse_sparql( + "SELECT ?s ?p ?o ?label WHERE { " + " ?s ?p ?o . 
" + " OPTIONAL { ?s ?label } " + "}" + ) + assert parsed.query_type == "select" + assert set(parsed.variables) == {"s", "p", "o", "label"} + + def test_select_with_union(self): + parsed = parse_sparql( + "SELECT ?s ?label WHERE { " + " { ?s ?label } " + " UNION " + " { ?s ?label } " + "}" + ) + assert parsed.query_type == "select" + + def test_select_with_order_by(self): + parsed = parse_sparql( + "SELECT ?s ?label WHERE { ?s ?label } " + "ORDER BY ?label" + ) + assert parsed.query_type == "select" + + def test_select_with_group_by(self): + parsed = parse_sparql( + "SELECT ?p (COUNT(?o) AS ?count) WHERE { ?s ?p ?o } " + "GROUP BY ?p ORDER BY DESC(?count)" + ) + assert parsed.query_type == "select" + + def test_select_with_prefixes(self): + parsed = parse_sparql( + "PREFIX rdfs: " + "SELECT ?s ?label WHERE { ?s rdfs:label ?label }" + ) + assert parsed.query_type == "select" + assert parsed.variables == ["s", "label"] + + def test_algebra_not_none(self): + parsed = parse_sparql("SELECT ?s WHERE { ?s ?p ?o }") + assert parsed.algebra is not None + + def test_parse_error_invalid_sparql(self): + with pytest.raises(ParseError): + parse_sparql("NOT VALID SPARQL AT ALL") + + def test_parse_error_incomplete_query(self): + with pytest.raises(ParseError): + parse_sparql("SELECT ?s WHERE {") + + def test_parse_error_message(self): + with pytest.raises(ParseError, match="SPARQL parse error"): + parse_sparql("GIBBERISH") + + +class TestRdflibTermToTerm: + """Tests for rdflib-to-Term conversion.""" + + def test_uriref_to_term(self): + from rdflib import URIRef + term = rdflib_term_to_term(URIRef("http://example.com/alice")) + assert term.type == IRI + assert term.iri == "http://example.com/alice" + + def test_literal_to_term(self): + from rdflib import Literal + term = rdflib_term_to_term(Literal("hello")) + assert term.type == LITERAL + assert term.value == "hello" + + def test_typed_literal_to_term(self): + from rdflib import Literal, URIRef + term = rdflib_term_to_term( + 
Literal("42", datatype=URIRef("http://www.w3.org/2001/XMLSchema#integer")) + ) + assert term.type == LITERAL + assert term.value == "42" + assert term.datatype == "http://www.w3.org/2001/XMLSchema#integer" + + def test_lang_literal_to_term(self): + from rdflib import Literal + term = rdflib_term_to_term(Literal("hello", lang="en")) + assert term.type == LITERAL + assert term.value == "hello" + assert term.language == "en" + + def test_bnode_to_term(self): + from rdflib import BNode + term = rdflib_term_to_term(BNode("b1")) + assert term.type == BLANK + assert term.id == "b1" + + +class TestTermToRdflib: + """Tests for Term-to-rdflib conversion.""" + + def test_iri_term_to_uriref(self): + from rdflib import URIRef + result = term_to_rdflib(Term(type=IRI, iri="http://example.com/x")) + assert isinstance(result, URIRef) + assert str(result) == "http://example.com/x" + + def test_literal_term_to_literal(self): + from rdflib import Literal + result = term_to_rdflib(Term(type=LITERAL, value="hello")) + assert isinstance(result, Literal) + assert str(result) == "hello" + + def test_typed_literal_roundtrip(self): + from rdflib import URIRef + original = Term( + type=LITERAL, value="42", + datatype="http://www.w3.org/2001/XMLSchema#integer" + ) + rdflib_term = term_to_rdflib(original) + assert rdflib_term.datatype == URIRef("http://www.w3.org/2001/XMLSchema#integer") + + def test_lang_literal_roundtrip(self): + original = Term(type=LITERAL, value="bonjour", language="fr") + rdflib_term = term_to_rdflib(original) + assert rdflib_term.language == "fr" + + def test_blank_term_to_bnode(self): + from rdflib import BNode + result = term_to_rdflib(Term(type=BLANK, id="b1")) + assert isinstance(result, BNode) diff --git a/tests/unit/test_query/test_sparql_solutions.py b/tests/unit/test_query/test_sparql_solutions.py new file mode 100644 index 00000000..5805ca84 --- /dev/null +++ b/tests/unit/test_query/test_sparql_solutions.py @@ -0,0 +1,345 @@ +""" +Tests for SPARQL solution 
sequence operations. +""" + +import pytest +from trustgraph.schema import Term, IRI, LITERAL +from trustgraph.query.sparql.solutions import ( + hash_join, left_join, union, project, distinct, + order_by, slice_solutions, _terms_equal, _compatible, +) + + +# --- Test helpers --- + +def iri(v): + return Term(type=IRI, iri=v) + +def lit(v): + return Term(type=LITERAL, value=v) + + +# --- Fixtures --- + +@pytest.fixture +def alice(): + return iri("http://example.com/alice") + +@pytest.fixture +def bob(): + return iri("http://example.com/bob") + +@pytest.fixture +def carol(): + return iri("http://example.com/carol") + +@pytest.fixture +def knows(): + return iri("http://example.com/knows") + +@pytest.fixture +def name_alice(): + return lit("Alice") + +@pytest.fixture +def name_bob(): + return lit("Bob") + + +class TestTermsEqual: + + def test_equal_iris(self): + assert _terms_equal(iri("http://x.com/a"), iri("http://x.com/a")) + + def test_unequal_iris(self): + assert not _terms_equal(iri("http://x.com/a"), iri("http://x.com/b")) + + def test_equal_literals(self): + assert _terms_equal(lit("hello"), lit("hello")) + + def test_unequal_literals(self): + assert not _terms_equal(lit("hello"), lit("world")) + + def test_iri_vs_literal(self): + assert not _terms_equal(iri("hello"), lit("hello")) + + def test_none_none(self): + assert _terms_equal(None, None) + + def test_none_vs_term(self): + assert not _terms_equal(None, iri("http://x.com/a")) + + +class TestCompatible: + + def test_no_shared_variables(self): + assert _compatible({"a": iri("http://x")}, {"b": iri("http://y")}) + + def test_shared_variable_same_value(self, alice): + assert _compatible({"s": alice, "x": lit("1")}, {"s": alice, "y": lit("2")}) + + def test_shared_variable_different_value(self, alice, bob): + assert not _compatible({"s": alice}, {"s": bob}) + + def test_empty_solutions(self): + assert _compatible({}, {}) + + def test_empty_vs_nonempty(self, alice): + assert _compatible({}, {"s": alice}) + + 
+class TestHashJoin: + + def test_join_on_shared_variable(self, alice, bob, name_alice, name_bob): + left = [ + {"s": alice, "p": iri("http://example.com/knows"), "o": bob}, + {"s": bob, "p": iri("http://example.com/knows"), "o": alice}, + ] + right = [ + {"s": alice, "label": name_alice}, + {"s": bob, "label": name_bob}, + ] + result = hash_join(left, right) + assert len(result) == 2 + # Check that joined solutions have all variables + for sol in result: + assert "s" in sol + assert "p" in sol + assert "o" in sol + assert "label" in sol + + def test_join_no_shared_variables_cross_product(self, alice, bob): + left = [{"a": alice}] + right = [{"b": bob}, {"b": alice}] + result = hash_join(left, right) + assert len(result) == 2 + + def test_join_no_matches(self, alice, bob): + left = [{"s": alice}] + right = [{"s": bob}] + result = hash_join(left, right) + assert len(result) == 0 + + def test_join_empty_left(self, alice): + result = hash_join([], [{"s": alice}]) + assert len(result) == 0 + + def test_join_empty_right(self, alice): + result = hash_join([{"s": alice}], []) + assert len(result) == 0 + + def test_join_multiple_matches(self, alice, name_alice): + left = [ + {"s": alice, "p": iri("http://e.com/a")}, + {"s": alice, "p": iri("http://e.com/b")}, + ] + right = [{"s": alice, "label": name_alice}] + result = hash_join(left, right) + assert len(result) == 2 + + def test_join_preserves_values(self, alice, name_alice): + left = [{"s": alice, "x": lit("1")}] + right = [{"s": alice, "y": lit("2")}] + result = hash_join(left, right) + assert len(result) == 1 + assert result[0]["x"].value == "1" + assert result[0]["y"].value == "2" + + +class TestLeftJoin: + + def test_left_join_with_matches(self, alice, bob, name_alice): + left = [{"s": alice}, {"s": bob}] + right = [{"s": alice, "label": name_alice}] + result = left_join(left, right) + assert len(result) == 2 + # Alice has label + alice_sols = [s for s in result if s["s"].iri == "http://example.com/alice"] + assert 
len(alice_sols) == 1 + assert "label" in alice_sols[0] + # Bob preserved without label + bob_sols = [s for s in result if s["s"].iri == "http://example.com/bob"] + assert len(bob_sols) == 1 + assert "label" not in bob_sols[0] + + def test_left_join_no_matches(self, alice, bob): + left = [{"s": alice}] + right = [{"s": bob, "label": lit("Bob")}] + result = left_join(left, right) + assert len(result) == 1 + assert result[0]["s"].iri == "http://example.com/alice" + assert "label" not in result[0] + + def test_left_join_empty_right(self, alice): + left = [{"s": alice}] + result = left_join(left, []) + assert len(result) == 1 + + def test_left_join_empty_left(self): + result = left_join([], [{"s": iri("http://x")}]) + assert len(result) == 0 + + def test_left_join_with_filter(self, alice, bob): + left = [{"s": alice}, {"s": bob}] + right = [ + {"s": alice, "val": lit("yes")}, + {"s": bob, "val": lit("no")}, + ] + # Filter: only keep joins where val == "yes" + result = left_join( + left, right, + filter_fn=lambda sol: sol.get("val") and sol["val"].value == "yes" + ) + assert len(result) == 2 + # Alice matches filter + alice_sols = [s for s in result if s["s"].iri == "http://example.com/alice"] + assert "val" in alice_sols[0] + assert alice_sols[0]["val"].value == "yes" + # Bob doesn't match filter, preserved without val + bob_sols = [s for s in result if s["s"].iri == "http://example.com/bob"] + assert "val" not in bob_sols[0] + + +class TestUnion: + + def test_union_concatenates(self, alice, bob): + left = [{"s": alice}] + right = [{"s": bob}] + result = union(left, right) + assert len(result) == 2 + + def test_union_preserves_order(self, alice, bob): + left = [{"s": alice}] + right = [{"s": bob}] + result = union(left, right) + assert result[0]["s"].iri == "http://example.com/alice" + assert result[1]["s"].iri == "http://example.com/bob" + + def test_union_empty_left(self, alice): + result = union([], [{"s": alice}]) + assert len(result) == 1 + + def 
test_union_both_empty(self): + result = union([], []) + assert len(result) == 0 + + def test_union_allows_duplicates(self, alice): + result = union([{"s": alice}], [{"s": alice}]) + assert len(result) == 2 + + +class TestProject: + + def test_project_keeps_selected(self, alice, name_alice): + solutions = [{"s": alice, "label": name_alice, "extra": lit("x")}] + result = project(solutions, ["s", "label"]) + assert len(result) == 1 + assert "s" in result[0] + assert "label" in result[0] + assert "extra" not in result[0] + + def test_project_missing_variable(self, alice): + solutions = [{"s": alice}] + result = project(solutions, ["s", "missing"]) + assert len(result) == 1 + assert "s" in result[0] + assert "missing" not in result[0] + + def test_project_empty(self): + result = project([], ["s"]) + assert len(result) == 0 + + +class TestDistinct: + + def test_removes_duplicates(self, alice): + solutions = [{"s": alice}, {"s": alice}, {"s": alice}] + result = distinct(solutions) + assert len(result) == 1 + + def test_keeps_different(self, alice, bob): + solutions = [{"s": alice}, {"s": bob}] + result = distinct(solutions) + assert len(result) == 2 + + def test_empty(self): + result = distinct([]) + assert len(result) == 0 + + def test_multi_variable_distinct(self, alice, bob): + solutions = [ + {"s": alice, "o": bob}, + {"s": alice, "o": bob}, + {"s": alice, "o": alice}, + ] + result = distinct(solutions) + assert len(result) == 2 + + +class TestOrderBy: + + def test_order_by_ascending(self): + solutions = [ + {"label": lit("Charlie")}, + {"label": lit("Alice")}, + {"label": lit("Bob")}, + ] + key_fns = [(lambda sol: sol.get("label"), True)] + result = order_by(solutions, key_fns) + assert result[0]["label"].value == "Alice" + assert result[1]["label"].value == "Bob" + assert result[2]["label"].value == "Charlie" + + def test_order_by_descending(self): + solutions = [ + {"label": lit("Alice")}, + {"label": lit("Charlie")}, + {"label": lit("Bob")}, + ] + key_fns = 
[(lambda sol: sol.get("label"), False)] + result = order_by(solutions, key_fns) + assert result[0]["label"].value == "Charlie" + assert result[1]["label"].value == "Bob" + assert result[2]["label"].value == "Alice" + + def test_order_by_empty(self): + result = order_by([], [(lambda sol: sol.get("x"), True)]) + assert len(result) == 0 + + def test_order_by_no_keys(self, alice): + solutions = [{"s": alice}] + result = order_by(solutions, []) + assert len(result) == 1 + + +class TestSlice: + + def test_limit(self, alice, bob, carol): + solutions = [{"s": alice}, {"s": bob}, {"s": carol}] + result = slice_solutions(solutions, limit=2) + assert len(result) == 2 + + def test_offset(self, alice, bob, carol): + solutions = [{"s": alice}, {"s": bob}, {"s": carol}] + result = slice_solutions(solutions, offset=1) + assert len(result) == 2 + assert result[0]["s"].iri == "http://example.com/bob" + + def test_offset_and_limit(self, alice, bob, carol): + solutions = [{"s": alice}, {"s": bob}, {"s": carol}] + result = slice_solutions(solutions, offset=1, limit=1) + assert len(result) == 1 + assert result[0]["s"].iri == "http://example.com/bob" + + def test_limit_zero(self, alice): + result = slice_solutions([{"s": alice}], limit=0) + assert len(result) == 0 + + def test_offset_beyond_length(self, alice): + result = slice_solutions([{"s": alice}], offset=10) + assert len(result) == 0 + + def test_no_slice(self, alice, bob): + solutions = [{"s": alice}, {"s": bob}] + result = slice_solutions(solutions) + assert len(result) == 2 diff --git a/trustgraph-base/trustgraph/api/flow.py b/trustgraph-base/trustgraph/api/flow.py index d89e16f6..0aa55347 100644 --- a/trustgraph-base/trustgraph/api/flow.py +++ b/trustgraph-base/trustgraph/api/flow.py @@ -1122,6 +1122,45 @@ class FlowInstance: return result + def sparql_query( + self, query, user="trustgraph", collection="default", + limit=10000 + ): + """ + Execute a SPARQL query against the knowledge graph. 
+ + Args: + query: SPARQL 1.1 query string + user: User/keyspace identifier (default: "trustgraph") + collection: Collection identifier (default: "default") + limit: Safety limit on results (default: 10000) + + Returns: + dict with query results. Structure depends on query type: + - SELECT: {"query-type": "select", "variables": [...], "bindings": [...]} + - ASK: {"query-type": "ask", "ask-result": bool} + - CONSTRUCT/DESCRIBE: {"query-type": "construct", "triples": [...]} + + Raises: + ProtocolException: If an error occurs + """ + + input = { + "query": query, + "user": user, + "collection": collection, + "limit": limit, + } + + response = self.request("service/sparql", input) + + if "error" in response and response["error"]: + error_type = response["error"].get("type", "unknown") + error_message = response["error"].get("message", "Unknown error") + raise ProtocolException(f"{error_type}: {error_message}") + + return response + def nlp_query(self, question, max_results=100): """ Convert a natural language question to a GraphQL query. 
diff --git a/trustgraph-base/trustgraph/messaging/__init__.py b/trustgraph-base/trustgraph/messaging/__init__.py index 9fbcbf16..30f5061c 100644 --- a/trustgraph-base/trustgraph/messaging/__init__.py +++ b/trustgraph-base/trustgraph/messaging/__init__.py @@ -27,6 +27,7 @@ from .translators.nlp_query import QuestionToStructuredQueryRequestTranslator, Q from .translators.structured_query import StructuredQueryRequestTranslator, StructuredQueryResponseTranslator from .translators.diagnosis import StructuredDataDiagnosisRequestTranslator, StructuredDataDiagnosisResponseTranslator from .translators.collection import CollectionManagementRequestTranslator, CollectionManagementResponseTranslator +from .translators.sparql_query import SparqlQueryRequestTranslator, SparqlQueryResponseTranslator # Register all service translators TranslatorRegistry.register_service( @@ -149,6 +150,12 @@ TranslatorRegistry.register_service( CollectionManagementResponseTranslator() ) +TranslatorRegistry.register_service( + "sparql-query", + SparqlQueryRequestTranslator(), + SparqlQueryResponseTranslator() +) + # Register single-direction translators for document loading TranslatorRegistry.register_request("document", DocumentTranslator()) TranslatorRegistry.register_request("text-document", TextDocumentTranslator()) diff --git a/trustgraph-base/trustgraph/messaging/translators/sparql_query.py b/trustgraph-base/trustgraph/messaging/translators/sparql_query.py new file mode 100644 index 00000000..d1912429 --- /dev/null +++ b/trustgraph-base/trustgraph/messaging/translators/sparql_query.py @@ -0,0 +1,111 @@ +from typing import Dict, Any, Tuple +from ...schema import ( + SparqlQueryRequest, SparqlQueryResponse, SparqlBinding, + Error, Term, Triple, IRI, LITERAL, BLANK, +) +from .base import MessageTranslator +from .primitives import TermTranslator, TripleTranslator + + +class SparqlQueryRequestTranslator(MessageTranslator): + """Translator for SparqlQueryRequest schema objects.""" + + def 
decode(self, data: Dict[str, Any]) -> SparqlQueryRequest: + return SparqlQueryRequest( + user=data.get("user", "trustgraph"), + collection=data.get("collection", "default"), + query=data.get("query", ""), + limit=int(data.get("limit", 10000)), + ) + + def encode(self, obj: SparqlQueryRequest) -> Dict[str, Any]: + return { + "user": obj.user, + "collection": obj.collection, + "query": obj.query, + "limit": obj.limit, + } + + +class SparqlQueryResponseTranslator(MessageTranslator): + """Translator for SparqlQueryResponse schema objects.""" + + def __init__(self): + self.term_translator = TermTranslator() + self.triple_translator = TripleTranslator() + + def decode(self, data: Dict[str, Any]) -> SparqlQueryResponse: + raise NotImplementedError( + "Response translation to schema not typically needed" + ) + + def _encode_term(self, v): + """Encode a Term, handling both Term objects and dicts from + pub/sub deserialization.""" + if v is None: + return None + if isinstance(v, dict): + # Reconstruct Term from dict (pub/sub deserializes nested + # dataclasses as dicts) + term = Term( + type=v.get("type", ""), + iri=v.get("iri", ""), + id=v.get("id", ""), + value=v.get("value", ""), + datatype=v.get("datatype", ""), + language=v.get("language", ""), + ) + return self.term_translator.encode(term) + return self.term_translator.encode(v) + + def _encode_error(self, error): + """Encode an Error, handling both Error objects and dicts.""" + if isinstance(error, dict): + return { + "type": error.get("type", ""), + "message": error.get("message", ""), + } + return { + "type": error.type, + "message": error.message, + } + + def encode(self, obj: SparqlQueryResponse) -> Dict[str, Any]: + result = { + "query-type": obj.query_type, + } + + if obj.error: + result["error"] = self._encode_error(obj.error) + + if obj.query_type == "select": + result["variables"] = obj.variables + bindings = [] + for binding in obj.bindings: + # binding may be a SparqlBinding or a dict + if 
isinstance(binding, dict): + values = binding.get("values", []) + else: + values = binding.values + bindings.append({ + "values": [ + self._encode_term(v) for v in values + ] + }) + result["bindings"] = bindings + + elif obj.query_type == "ask": + result["ask-result"] = obj.ask_result + + elif obj.query_type in ("construct", "describe"): + result["triples"] = [ + self.triple_translator.encode(t) + for t in obj.triples + ] + + return result + + def encode_with_completion( + self, obj: SparqlQueryResponse + ) -> Tuple[Dict[str, Any], bool]: + return self.encode(obj), True diff --git a/trustgraph-base/trustgraph/schema/services/__init__.py b/trustgraph-base/trustgraph/schema/services/__init__.py index f246bc31..550b7d12 100644 --- a/trustgraph-base/trustgraph/schema/services/__init__.py +++ b/trustgraph-base/trustgraph/schema/services/__init__.py @@ -13,4 +13,5 @@ from .rows_query import * from .diagnosis import * from .collection import * from .storage import * -from .tool_service import * \ No newline at end of file +from .tool_service import * +from .sparql_query import * \ No newline at end of file diff --git a/trustgraph-base/trustgraph/schema/services/sparql_query.py b/trustgraph-base/trustgraph/schema/services/sparql_query.py new file mode 100644 index 00000000..105cc753 --- /dev/null +++ b/trustgraph-base/trustgraph/schema/services/sparql_query.py @@ -0,0 +1,40 @@ +from dataclasses import dataclass, field + +from ..core.primitives import Error, Term, Triple +from ..core.topic import queue + +############################################################################ + +# SPARQL query + +@dataclass +class SparqlBinding: + """A single row of SPARQL SELECT results. + Values are ordered to match the variables list in SparqlQueryResponse. 
+ """ + values: list[Term | None] = field(default_factory=list) + +@dataclass +class SparqlQueryRequest: + user: str = "" + collection: str = "" + query: str = "" # SPARQL query string + limit: int = 10000 # Safety limit on results + +@dataclass +class SparqlQueryResponse: + error: Error | None = None + query_type: str = "" # "select", "ask", "construct", "describe" + + # For SELECT queries + variables: list[str] = field(default_factory=list) + bindings: list[SparqlBinding] = field(default_factory=list) + + # For ASK queries + ask_result: bool = False + + # For CONSTRUCT/DESCRIBE queries + triples: list[Triple] = field(default_factory=list) + +sparql_query_request_queue = queue('sparql-query', cls='request') +sparql_query_response_queue = queue('sparql-query', cls='response') diff --git a/trustgraph-cli/pyproject.toml b/trustgraph-cli/pyproject.toml index 9fd6bed7..2b111cae 100644 --- a/trustgraph-cli/pyproject.toml +++ b/trustgraph-cli/pyproject.toml @@ -51,6 +51,7 @@ tg-invoke-document-embeddings = "trustgraph.cli.invoke_document_embeddings:main" tg-invoke-mcp-tool = "trustgraph.cli.invoke_mcp_tool:main" tg-invoke-nlp-query = "trustgraph.cli.invoke_nlp_query:main" tg-invoke-rows-query = "trustgraph.cli.invoke_rows_query:main" +tg-invoke-sparql-query = "trustgraph.cli.invoke_sparql_query:main" tg-invoke-row-embeddings = "trustgraph.cli.invoke_row_embeddings:main" tg-invoke-prompt = "trustgraph.cli.invoke_prompt:main" tg-invoke-structured-query = "trustgraph.cli.invoke_structured_query:main" diff --git a/trustgraph-cli/trustgraph/cli/invoke_sparql_query.py b/trustgraph-cli/trustgraph/cli/invoke_sparql_query.py new file mode 100644 index 00000000..9547193d --- /dev/null +++ b/trustgraph-cli/trustgraph/cli/invoke_sparql_query.py @@ -0,0 +1,230 @@ +""" +Execute a SPARQL query against the TrustGraph knowledge graph. 
+""" + +import argparse +import os +import json +import sys +from trustgraph.api import Api + +default_url = os.getenv("TRUSTGRAPH_URL", 'http://localhost:8088/') +default_user = 'trustgraph' +default_collection = 'default' + + +def format_select(response, output_format): + """Format SELECT query results.""" + variables = response.get("variables", []) + bindings = response.get("bindings", []) + + if not bindings: + return "No results." + + if output_format == "json": + rows = [] + for binding in bindings: + row = {} + for var, val in zip(variables, binding.get("values", [])): + if val is None: + row[var] = None + elif val.get("t") == "i": + row[var] = val.get("i", "") + elif val.get("t") == "l": + row[var] = val.get("v", "") + else: + row[var] = val.get("v", val.get("i", "")) + rows.append(row) + return json.dumps(rows, indent=2) + + # Table format + col_widths = [len(v) for v in variables] + rows = [] + for binding in bindings: + row = [] + for i, val in enumerate(binding.get("values", [])): + if val is None: + cell = "" + elif val.get("t") == "i": + cell = val.get("i", "") + elif val.get("t") == "l": + cell = val.get("v", "") + else: + cell = val.get("v", val.get("i", "")) + row.append(cell) + if i < len(col_widths): + col_widths[i] = max(col_widths[i], len(cell)) + rows.append(row) + + # Build table + header = " | ".join( + v.ljust(col_widths[i]) for i, v in enumerate(variables) + ) + separator = "-+-".join("-" * w for w in col_widths) + lines = [header, separator] + for row in rows: + line = " | ".join( + cell.ljust(col_widths[i]) if i < len(col_widths) else cell + for i, cell in enumerate(row) + ) + lines.append(line) + return "\n".join(lines) + + +def format_triples(response, output_format): + """Format CONSTRUCT/DESCRIBE results.""" + triples = response.get("triples", []) + + if not triples: + return "No triples." 
+ + if output_format == "json": + return json.dumps(triples, indent=2) + + lines = [] + for t in triples: + s = _term_str(t.get("s")) + p = _term_str(t.get("p")) + o = _term_str(t.get("o")) + lines.append(f"{s} {p} {o} .") + return "\n".join(lines) + + +def _term_str(val): + """Convert a wire-format term to a display string.""" + if val is None: + return "?" + t = val.get("t", "") + if t == "i": + return f"<{val.get('i', '')}>" + elif t == "l": + v = val.get("v", "") + dt = val.get("d", "") + lang = val.get("l", "") + if lang: + return f'"{v}"@{lang}' + elif dt: + return f'"{v}"^^<{dt}>' + return f'"{v}"' + return str(val) + + +def sparql_query(url, token, flow_id, query, user, collection, limit, + output_format): + + api = Api(url=url, token=token).flow().id(flow_id) + + resp = api.sparql_query( + query=query, + user=user, + collection=collection, + limit=limit, + ) + + query_type = resp.get("query-type", "select") + + if query_type == "select": + print(format_select(resp, output_format)) + elif query_type == "ask": + print("true" if resp.get("ask-result") else "false") + elif query_type in ("construct", "describe"): + print(format_triples(resp, output_format)) + else: + print(json.dumps(resp, indent=2)) + + +def main(): + + parser = argparse.ArgumentParser( + prog='tg-invoke-sparql-query', + description=__doc__, + ) + + parser.add_argument( + '-u', '--url', + default=default_url, + help=f'API URL (default: {default_url})', + ) + + parser.add_argument( + '-t', '--token', + default=os.getenv("TRUSTGRAPH_TOKEN"), + help='API bearer token (default: TRUSTGRAPH_TOKEN env var)', + ) + + parser.add_argument( + '-f', '--flow-id', + default="default", + help='Flow ID (default: default)', + ) + + parser.add_argument( + '-q', '--query', + help='SPARQL query string', + ) + + parser.add_argument( + '-i', '--input', + help='Read SPARQL query from file (use - for stdin)', + ) + + parser.add_argument( + '-U', '--user', + default=default_user, + help=f'User ID (default: 
{default_user})', + ) + + parser.add_argument( + '-C', '--collection', + default=default_collection, + help=f'Collection ID (default: {default_collection})', + ) + + parser.add_argument( + '-l', '--limit', + type=int, + default=10000, + help='Result limit (default: 10000)', + ) + + parser.add_argument( + '--format', + choices=['table', 'json'], + default='table', + help='Output format (default: table)', + ) + + args = parser.parse_args() + + # Get query from argument or file + query = args.query + if not query and args.input: + if args.input == '-': + query = sys.stdin.read() + else: + with open(args.input) as f: + query = f.read() + + if not query: + parser.error("Either -q/--query or -i/--input is required") + + try: + + sparql_query( + url=args.url, + token=args.token, + flow_id=args.flow_id, + query=query, + user=args.user, + collection=args.collection, + limit=args.limit, + output_format=args.format, + ) + + except Exception as e: + print(f"Exception: {e}", flush=True, file=sys.stderr) + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/trustgraph-flow/pyproject.toml b/trustgraph-flow/pyproject.toml index 66363305..b2df4a4c 100644 --- a/trustgraph-flow/pyproject.toml +++ b/trustgraph-flow/pyproject.toml @@ -101,6 +101,7 @@ pdf-ocr-mistral = "trustgraph.decoding.mistral_ocr:run" prompt-template = "trustgraph.prompt.template:run" rev-gateway = "trustgraph.rev_gateway:run" run-processing = "trustgraph.processing:run" +sparql-query = "trustgraph.query.sparql:run" structured-query = "trustgraph.retrieval.structured_query:run" structured-diag = "trustgraph.retrieval.structured_diag:run" text-completion-azure = "trustgraph.model.text_completion.azure:run" diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/manager.py b/trustgraph-flow/trustgraph/gateway/dispatch/manager.py index d068ecef..a4bf8de9 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/manager.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/manager.py @@ -22,6 +22,7 @@ from 
. document_rag import DocumentRagRequestor from . triples_query import TriplesQueryRequestor from . rows_query import RowsQueryRequestor from . nlp_query import NLPQueryRequestor +from . sparql_query import SparqlQueryRequestor from . structured_query import StructuredQueryRequestor from . structured_diag import StructuredDiagRequestor from . embeddings import EmbeddingsRequestor @@ -65,6 +66,7 @@ request_response_dispatchers = { "structured-query": StructuredQueryRequestor, "structured-diag": StructuredDiagRequestor, "row-embeddings": RowEmbeddingsQueryRequestor, + "sparql": SparqlQueryRequestor, } global_dispatchers = { diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/sparql_query.py b/trustgraph-flow/trustgraph/gateway/dispatch/sparql_query.py new file mode 100644 index 00000000..f81b9df6 --- /dev/null +++ b/trustgraph-flow/trustgraph/gateway/dispatch/sparql_query.py @@ -0,0 +1,30 @@ +from ... schema import SparqlQueryRequest, SparqlQueryResponse +from ... messaging import TranslatorRegistry + +from . 
requestor import ServiceRequestor + +class SparqlQueryRequestor(ServiceRequestor): + def __init__( + self, backend, request_queue, response_queue, timeout, + consumer, subscriber, + ): + + super(SparqlQueryRequestor, self).__init__( + backend=backend, + request_queue=request_queue, + response_queue=response_queue, + request_schema=SparqlQueryRequest, + response_schema=SparqlQueryResponse, + subscription = subscriber, + consumer_name = consumer, + timeout=timeout, + ) + + self.request_translator = TranslatorRegistry.get_request_translator("sparql-query") + self.response_translator = TranslatorRegistry.get_response_translator("sparql-query") + + def to_request(self, body): + return self.request_translator.decode(body) + + def from_response(self, message): + return self.response_translator.encode_with_completion(message) diff --git a/trustgraph-flow/trustgraph/query/sparql/__init__.py b/trustgraph-flow/trustgraph/query/sparql/__init__.py new file mode 100644 index 00000000..98f4d9da --- /dev/null +++ b/trustgraph-flow/trustgraph/query/sparql/__init__.py @@ -0,0 +1 @@ +from . service import * diff --git a/trustgraph-flow/trustgraph/query/sparql/__main__.py b/trustgraph-flow/trustgraph/query/sparql/__main__.py new file mode 100644 index 00000000..da5a9021 --- /dev/null +++ b/trustgraph-flow/trustgraph/query/sparql/__main__.py @@ -0,0 +1,6 @@ +#!/usr/bin/env python3 + +from . service import run + +if __name__ == '__main__': + run() diff --git a/trustgraph-flow/trustgraph/query/sparql/algebra.py b/trustgraph-flow/trustgraph/query/sparql/algebra.py new file mode 100644 index 00000000..eda83efb --- /dev/null +++ b/trustgraph-flow/trustgraph/query/sparql/algebra.py @@ -0,0 +1,541 @@ +""" +SPARQL algebra evaluator. + +Recursively evaluates an rdflib SPARQL algebra tree by issuing triple +pattern queries via TriplesClient (streaming) and performing in-memory +joins, filters, and projections. 
+""" + +import logging +from collections import defaultdict + +from rdflib.term import Variable, URIRef, Literal, BNode +from rdflib.plugins.sparql.parserutils import CompValue + +from ... schema import Term, Triple, IRI, LITERAL, BLANK +from ... knowledge import Uri +from ... knowledge import Literal as KgLiteral +from . parser import rdflib_term_to_term +from . solutions import ( + hash_join, left_join, union, project, distinct, + order_by, slice_solutions, _term_key, +) +from . expressions import evaluate_expression, _effective_boolean + +logger = logging.getLogger(__name__) + + +class EvaluationError(Exception): + """Raised when SPARQL evaluation fails.""" + pass + + +async def evaluate(node, triples_client, user, collection, limit=10000): + """ + Evaluate a SPARQL algebra node. + + Args: + node: rdflib CompValue algebra node + triples_client: TriplesClient instance for triple pattern queries + user: user/keyspace identifier + collection: collection identifier + limit: safety limit on results + + Returns: + list of solutions (dicts mapping variable names to Term values) + """ + if not isinstance(node, CompValue): + logger.warning(f"Expected CompValue, got {type(node)}: {node}") + return [{}] + + name = node.name + handler = _HANDLERS.get(name) + + if handler is None: + logger.warning(f"Unsupported algebra node: {name}") + return [{}] + + return await handler(node, triples_client, user, collection, limit) + + +# --- Node handlers --- + +async def _eval_select_query(node, tc, user, collection, limit): + """Evaluate a SelectQuery node.""" + return await evaluate(node.p, tc, user, collection, limit) + + +async def _eval_project(node, tc, user, collection, limit): + """Evaluate a Project node (SELECT variable projection).""" + solutions = await evaluate(node.p, tc, user, collection, limit) + variables = [str(v) for v in node.PV] + return project(solutions, variables) + + +async def _eval_bgp(node, tc, user, collection, limit): + """ + Evaluate a Basic Graph Pattern. 
+ + Issues streaming triple pattern queries and joins results. Patterns + are ordered by selectivity (more bound terms first) and evaluated + sequentially with bound-variable substitution. + """ + triples = node.triples + if not triples: + return [{}] + + # Sort patterns by selectivity: more bound terms = more selective + def selectivity(pattern): + return sum(1 for t in pattern if not isinstance(t, Variable)) + + sorted_patterns = sorted( + enumerate(triples), key=lambda x: -selectivity(x[1]) + ) + + solutions = [{}] + + for _, pattern in sorted_patterns: + s_tmpl, p_tmpl, o_tmpl = pattern + + new_solutions = [] + + for sol in solutions: + # Substitute known bindings into the pattern + s_val = _resolve_term(s_tmpl, sol) + p_val = _resolve_term(p_tmpl, sol) + o_val = _resolve_term(o_tmpl, sol) + + # Query the triples store + results = await _query_pattern( + tc, s_val, p_val, o_val, user, collection, limit + ) + + # Map results back to variable bindings, + # converting Uri/Literal to Term objects + for triple in results: + binding = dict(sol) + if isinstance(s_tmpl, Variable): + binding[str(s_tmpl)] = _to_term(triple.s) + if isinstance(p_tmpl, Variable): + binding[str(p_tmpl)] = _to_term(triple.p) + if isinstance(o_tmpl, Variable): + binding[str(o_tmpl)] = _to_term(triple.o) + new_solutions.append(binding) + + solutions = new_solutions + + if not solutions: + break + + return solutions[:limit] + + +async def _eval_join(node, tc, user, collection, limit): + """Evaluate a Join node.""" + left = await evaluate(node.p1, tc, user, collection, limit) + right = await evaluate(node.p2, tc, user, collection, limit) + return hash_join(left, right)[:limit] + + +async def _eval_left_join(node, tc, user, collection, limit): + """Evaluate a LeftJoin node (OPTIONAL).""" + left_sols = await evaluate(node.p1, tc, user, collection, limit) + right_sols = await evaluate(node.p2, tc, user, collection, limit) + + filter_fn = None + if hasattr(node, "expr") and node.expr is not None: + 
expr = node.expr + if not (isinstance(expr, CompValue) and expr.name == "TrueFilter"): + filter_fn = lambda sol: _effective_boolean( + evaluate_expression(expr, sol) + ) + + return left_join(left_sols, right_sols, filter_fn)[:limit] + + +async def _eval_union(node, tc, user, collection, limit): + """Evaluate a Union node.""" + left = await evaluate(node.p1, tc, user, collection, limit) + right = await evaluate(node.p2, tc, user, collection, limit) + return union(left, right)[:limit] + + +async def _eval_filter(node, tc, user, collection, limit): + """Evaluate a Filter node.""" + solutions = await evaluate(node.p, tc, user, collection, limit) + expr = node.expr + return [ + sol for sol in solutions + if _effective_boolean(evaluate_expression(expr, sol)) + ] + + +async def _eval_distinct(node, tc, user, collection, limit): + """Evaluate a Distinct node.""" + solutions = await evaluate(node.p, tc, user, collection, limit) + return distinct(solutions) + + +async def _eval_reduced(node, tc, user, collection, limit): + """Evaluate a Reduced node (like Distinct but implementation-defined).""" + # Treat same as Distinct + solutions = await evaluate(node.p, tc, user, collection, limit) + return distinct(solutions) + + +async def _eval_order_by(node, tc, user, collection, limit): + """Evaluate an OrderBy node.""" + solutions = await evaluate(node.p, tc, user, collection, limit) + + key_fns = [] + for cond in node.expr: + if isinstance(cond, CompValue) and cond.name == "OrderCondition": + ascending = cond.order != "DESC" + expr = cond.expr + key_fns.append(( + lambda sol, e=expr: evaluate_expression(e, sol), + ascending, + )) + else: + # Simple variable or expression + key_fns.append(( + lambda sol, e=cond: evaluate_expression(e, sol), + True, + )) + + return order_by(solutions, key_fns) + + +async def _eval_slice(node, tc, user, collection, limit): + """Evaluate a Slice node (LIMIT/OFFSET).""" + # Pass tighter limit downstream if possible + inner_limit = limit + if 
node.length is not None: + offset = node.start or 0 + inner_limit = min(limit, offset + node.length) + + solutions = await evaluate(node.p, tc, user, collection, inner_limit) + return slice_solutions(solutions, node.start or 0, node.length) + + +async def _eval_extend(node, tc, user, collection, limit): + """Evaluate an Extend node (BIND).""" + solutions = await evaluate(node.p, tc, user, collection, limit) + var_name = str(node.var) + expr = node.expr + + result = [] + for sol in solutions: + val = evaluate_expression(expr, sol) + new_sol = dict(sol) + if isinstance(val, Term): + new_sol[var_name] = val + elif isinstance(val, (int, float)): + new_sol[var_name] = Term(type=LITERAL, value=str(val)) + elif isinstance(val, str): + new_sol[var_name] = Term(type=LITERAL, value=val) + elif isinstance(val, bool): + new_sol[var_name] = Term( + type=LITERAL, value=str(val).lower(), + datatype="http://www.w3.org/2001/XMLSchema#boolean" + ) + elif val is not None: + new_sol[var_name] = Term(type=LITERAL, value=str(val)) + result.append(new_sol) + + return result + + +async def _eval_group(node, tc, user, collection, limit): + """Evaluate a Group node (GROUP BY with aggregation).""" + solutions = await evaluate(node.p, tc, user, collection, limit) + + # Extract grouping expressions + group_exprs = [] + if hasattr(node, "expr") and node.expr: + for expr in node.expr: + if isinstance(expr, CompValue) and expr.name == "GroupAs": + group_exprs.append((expr.expr, str(expr.var) if hasattr(expr, "var") and expr.var else None)) + elif isinstance(expr, Variable): + group_exprs.append((expr, str(expr))) + else: + group_exprs.append((expr, None)) + + # Group solutions + groups = defaultdict(list) + for sol in solutions: + key_parts = [] + for expr, _ in group_exprs: + val = evaluate_expression(expr, sol) + key_parts.append(_term_key(val) if isinstance(val, Term) else val) + groups[tuple(key_parts)].append(sol) + + if not group_exprs: + # No GROUP BY - entire result is one group + 
groups[()].extend(solutions) + + # Build grouped solutions (one per group) + result = [] + for key, group_sols in groups.items(): + sol = {} + # Include group key variables + if group_sols: + for (expr, var_name), k in zip(group_exprs, key): + if var_name and group_sols: + sol[var_name] = evaluate_expression(expr, group_sols[0]) + sol["__group__"] = group_sols + result.append(sol) + + return result + + +async def _eval_aggregate_join(node, tc, user, collection, limit): + """Evaluate an AggregateJoin (aggregation functions after GROUP BY).""" + solutions = await evaluate(node.p, tc, user, collection, limit) + + result = [] + for sol in solutions: + group = sol.get("__group__", [sol]) + new_sol = {k: v for k, v in sol.items() if k != "__group__"} + + # Apply aggregate functions + if hasattr(node, "A") and node.A: + for agg in node.A: + var_name = str(agg.res) + agg_val = _compute_aggregate(agg, group) + new_sol[var_name] = agg_val + + result.append(new_sol) + + return result + + +async def _eval_graph(node, tc, user, collection, limit): + """Evaluate a Graph node (GRAPH clause).""" + term = node.term + + if isinstance(term, URIRef): + # GRAPH { ... } — fixed graph + # We'd need to pass graph to triples queries + # For now, evaluate inner pattern normally + logger.info(f"GRAPH <{term}> clause - graph filtering not yet wired") + return await evaluate(node.p, tc, user, collection, limit) + elif isinstance(term, Variable): + # GRAPH ?g { ... 
} — variable graph + logger.info(f"GRAPH ?{term} clause - variable graph not yet wired") + return await evaluate(node.p, tc, user, collection, limit) + else: + return await evaluate(node.p, tc, user, collection, limit) + + +async def _eval_values(node, tc, user, collection, limit): + """Evaluate a VALUES clause (inline data).""" + variables = [str(v) for v in node.var] + solutions = [] + + for row in node.value: + sol = {} + for var_name, val in zip(variables, row): + if val is not None and str(val) != "UNDEF": + sol[var_name] = rdflib_term_to_term(val) + solutions.append(sol) + + return solutions + + +async def _eval_to_multiset(node, tc, user, collection, limit): + """Evaluate a ToMultiSet node (subquery).""" + return await evaluate(node.p, tc, user, collection, limit) + + +# --- Aggregate computation --- + +def _compute_aggregate(agg, group): + """Compute a single aggregate function over a group of solutions.""" + agg_name = agg.name if hasattr(agg, "name") else "" + + # Get the expression to aggregate + expr = agg.vars if hasattr(agg, "vars") else None + + if agg_name == "Aggregate_Count": + if hasattr(agg, "distinct") and agg.distinct: + vals = set() + for sol in group: + if expr: + val = evaluate_expression(expr, sol) + if val is not None: + vals.add(_term_key(val) if isinstance(val, Term) else val) + else: + vals.add(id(sol)) + return Term(type=LITERAL, value=str(len(vals)), + datatype="http://www.w3.org/2001/XMLSchema#integer") + return Term(type=LITERAL, value=str(len(group)), + datatype="http://www.w3.org/2001/XMLSchema#integer") + + if agg_name == "Aggregate_Sum": + total = 0 + for sol in group: + val = evaluate_expression(expr, sol) if expr else None + num = _try_numeric(val) + if num is not None: + total += num + return Term(type=LITERAL, value=str(total), + datatype="http://www.w3.org/2001/XMLSchema#decimal") + + if agg_name == "Aggregate_Avg": + total = 0 + count = 0 + for sol in group: + val = evaluate_expression(expr, sol) if expr else None + num = 
_try_numeric(val) + if num is not None: + total += num + count += 1 + avg = total / count if count > 0 else 0 + return Term(type=LITERAL, value=str(avg), + datatype="http://www.w3.org/2001/XMLSchema#decimal") + + if agg_name == "Aggregate_Min": + min_val = None + for sol in group: + val = evaluate_expression(expr, sol) if expr else None + if val is not None: + cmp = _term_key(val) if isinstance(val, Term) else val + if min_val is None or cmp < min_val[0]: + min_val = (cmp, val) + if min_val: + val = min_val[1] + if isinstance(val, Term): + return val + return Term(type=LITERAL, value=str(val)) + return None + + if agg_name == "Aggregate_Max": + max_val = None + for sol in group: + val = evaluate_expression(expr, sol) if expr else None + if val is not None: + cmp = _term_key(val) if isinstance(val, Term) else val + if max_val is None or cmp > max_val[0]: + max_val = (cmp, val) + if max_val: + val = max_val[1] + if isinstance(val, Term): + return val + return Term(type=LITERAL, value=str(val)) + return None + + if agg_name == "Aggregate_GroupConcat": + separator = agg.separator if hasattr(agg, "separator") else " " + vals = [] + for sol in group: + val = evaluate_expression(expr, sol) if expr else None + if val is not None: + if isinstance(val, Term): + vals.append(val.value if val.type == LITERAL else val.iri) + else: + vals.append(str(val)) + return Term(type=LITERAL, value=separator.join(vals)) + + if agg_name == "Aggregate_Sample": + if group: + val = evaluate_expression(expr, group[0]) if expr else None + if isinstance(val, Term): + return val + if val is not None: + return Term(type=LITERAL, value=str(val)) + return None + + logger.warning(f"Unsupported aggregate: {agg_name}") + return None + + +# --- Helper functions --- + +def _to_term(val): + """ + Convert a value to a schema Term. Handles Uri and Literal from the + knowledge module (returned by TriplesClient) as well as plain strings. 
+ """ + if val is None: + return None + if isinstance(val, Term): + return val + if isinstance(val, Uri): + return Term(type=IRI, iri=str(val)) + if isinstance(val, KgLiteral): + return Term(type=LITERAL, value=str(val)) + if isinstance(val, str): + if val.startswith("http://") or val.startswith("https://") or val.startswith("urn:"): + return Term(type=IRI, iri=val) + return Term(type=LITERAL, value=val) + return Term(type=LITERAL, value=str(val)) + + +def _resolve_term(tmpl, solution): + """ + Resolve a triple pattern term. If it's a variable and bound in the + solution, return the bound Term. Otherwise return None (wildcard) + for variables, or convert concrete terms. + """ + if isinstance(tmpl, Variable): + name = str(tmpl) + if name in solution: + return solution[name] + return None + else: + return rdflib_term_to_term(tmpl) + + +async def _query_pattern(tc, s, p, o, user, collection, limit): + """ + Issue a streaming triple pattern query via TriplesClient. + + Returns a list of Triple-like objects with s, p, o attributes. + """ + results = await tc.query( + s=s, p=p, o=o, + limit=limit, + user=user, + collection=collection, + ) + return results + + +def _try_numeric(val): + """Try to convert a value to a number, return None on failure.""" + if val is None: + return None + if isinstance(val, (int, float)): + return val + if isinstance(val, Term) and val.type == LITERAL: + try: + if "." 
in val.value: + return float(val.value) + return int(val.value) + except (ValueError, TypeError): + return None + return None + + +# --- Handler registry --- + +_HANDLERS = { + "SelectQuery": _eval_select_query, + "Project": _eval_project, + "BGP": _eval_bgp, + "Join": _eval_join, + "LeftJoin": _eval_left_join, + "Union": _eval_union, + "Filter": _eval_filter, + "Distinct": _eval_distinct, + "Reduced": _eval_reduced, + "OrderBy": _eval_order_by, + "Slice": _eval_slice, + "Extend": _eval_extend, + "Group": _eval_group, + "AggregateJoin": _eval_aggregate_join, + "Graph": _eval_graph, + "values": _eval_values, + "ToMultiSet": _eval_to_multiset, +} diff --git a/trustgraph-flow/trustgraph/query/sparql/expressions.py b/trustgraph-flow/trustgraph/query/sparql/expressions.py new file mode 100644 index 00000000..eac1199c --- /dev/null +++ b/trustgraph-flow/trustgraph/query/sparql/expressions.py @@ -0,0 +1,481 @@ +""" +SPARQL FILTER expression evaluator. + +Evaluates rdflib algebra expression nodes against a solution (variable +binding) to produce a value or boolean result. +""" + +import re +import logging +import operator + +from rdflib.term import Variable, URIRef, Literal, BNode +from rdflib.plugins.sparql.parserutils import CompValue + +from ... schema import Term, IRI, LITERAL, BLANK +from . parser import rdflib_term_to_term + +logger = logging.getLogger(__name__) + + +class ExpressionError(Exception): + """Raised when a SPARQL expression cannot be evaluated.""" + pass + + +def evaluate_expression(expr, solution): + """ + Evaluate a SPARQL expression against a solution binding. 
+ + Args: + expr: rdflib algebra expression node + solution: dict mapping variable names to Term values + + Returns: + The result value (Term, bool, number, string, or None) + """ + if expr is None: + return True + + # rdflib Variable + if isinstance(expr, Variable): + name = str(expr) + return solution.get(name) + + # rdflib concrete terms + if isinstance(expr, URIRef): + return Term(type=IRI, iri=str(expr)) + + if isinstance(expr, Literal): + return rdflib_term_to_term(expr) + + if isinstance(expr, BNode): + return Term(type=BLANK, id=str(expr)) + + # Boolean constants + if isinstance(expr, bool): + return expr + + # Numeric constants + if isinstance(expr, (int, float)): + return expr + + # String constants + if isinstance(expr, str): + return expr + + # CompValue nodes from rdflib algebra + if isinstance(expr, CompValue): + return _evaluate_comp_value(expr, solution) + + # List/tuple (e.g. function arguments) + if isinstance(expr, (list, tuple)): + return [evaluate_expression(e, solution) for e in expr] + + logger.warning(f"Unknown expression type: {type(expr)}: {expr}") + return None + + +def _evaluate_comp_value(node, solution): + """Evaluate a CompValue expression node.""" + name = node.name + + # Relational expressions: =, !=, <, >, <=, >= + if name == "RelationalExpression": + return _eval_relational(node, solution) + + # Conditional AND / OR + if name == "ConditionalAndExpression": + return _eval_conditional_and(node, solution) + + if name == "ConditionalOrExpression": + return _eval_conditional_or(node, solution) + + # Unary NOT + if name == "UnaryNot": + val = evaluate_expression(node.expr, solution) + return not _effective_boolean(val) + + # Unary plus/minus + if name == "UnaryPlus": + return _to_numeric(evaluate_expression(node.expr, solution)) + + if name == "UnaryMinus": + val = _to_numeric(evaluate_expression(node.expr, solution)) + return -val if val is not None else None + + # Arithmetic + if name == "AdditiveExpression": + return 
_eval_additive(node, solution) + + if name == "MultiplicativeExpression": + return _eval_multiplicative(node, solution) + + # SPARQL built-in functions + if name.startswith("Builtin_"): + return _eval_builtin(name, node, solution) + + # Function call + if name == "Function": + return _eval_function(node, solution) + + # Exists / NotExists + if name == "Builtin_EXISTS": + # EXISTS requires graph pattern evaluation - not handled here + logger.warning("EXISTS not supported in filter expressions") + return True + + if name == "Builtin_NOTEXISTS": + logger.warning("NOT EXISTS not supported in filter expressions") + return True + + # TrueFilter (used with OPTIONAL) + if name == "TrueFilter": + return True + + # IN / NOT IN + if name == "Builtin_IN": + return _eval_in(node, solution) + + if name == "Builtin_NOTIN": + return not _eval_in(node, solution) + + logger.warning(f"Unknown CompValue expression: {name}") + return None + + +def _eval_relational(node, solution): + """Evaluate a relational expression (=, !=, <, >, <=, >=).""" + left = evaluate_expression(node.expr, solution) + right = evaluate_expression(node.other, solution) + op = node.op + + if left is None or right is None: + return False + + left_cmp = _comparable_value(left) + right_cmp = _comparable_value(right) + + ops = { + "=": operator.eq, "==": operator.eq, + "!=": operator.ne, + "<": operator.lt, + ">": operator.gt, + "<=": operator.le, + ">=": operator.ge, + } + + op_fn = ops.get(str(op)) + if op_fn is None: + logger.warning(f"Unknown relational operator: {op}") + return False + + try: + return op_fn(left_cmp, right_cmp) + except TypeError: + return False + + +def _eval_conditional_and(node, solution): + """Evaluate AND expression.""" + result = _effective_boolean(evaluate_expression(node.expr, solution)) + if not result: + return False + for other in node.other: + result = _effective_boolean(evaluate_expression(other, solution)) + if not result: + return False + return True + + +def 
_eval_conditional_or(node, solution): + """Evaluate OR expression.""" + result = _effective_boolean(evaluate_expression(node.expr, solution)) + if result: + return True + for other in node.other: + result = _effective_boolean(evaluate_expression(other, solution)) + if result: + return True + return False + + +def _eval_additive(node, solution): + """Evaluate additive expression (a + b - c ...).""" + result = _to_numeric(evaluate_expression(node.expr, solution)) + if result is None: + return None + for op, operand in zip(node.op, node.other): + val = _to_numeric(evaluate_expression(operand, solution)) + if val is None: + return None + if str(op) == "+": + result = result + val + elif str(op) == "-": + result = result - val + return result + + +def _eval_multiplicative(node, solution): + """Evaluate multiplicative expression (a * b / c ...).""" + result = _to_numeric(evaluate_expression(node.expr, solution)) + if result is None: + return None + for op, operand in zip(node.op, node.other): + val = _to_numeric(evaluate_expression(operand, solution)) + if val is None: + return None + if str(op) == "*": + result = result * val + elif str(op) == "/": + if val == 0: + return None + result = result / val + return result + + +def _eval_builtin(name, node, solution): + """Evaluate SPARQL built-in functions.""" + builtin = name[len("Builtin_"):] + + if builtin == "BOUND": + var_name = str(node.arg) + return var_name in solution and solution[var_name] is not None + + if builtin == "isIRI" or builtin == "isURI": + val = evaluate_expression(node.arg, solution) + return isinstance(val, Term) and val.type == IRI + + if builtin == "isLITERAL": + val = evaluate_expression(node.arg, solution) + return isinstance(val, Term) and val.type == LITERAL + + if builtin == "isBLANK": + val = evaluate_expression(node.arg, solution) + return isinstance(val, Term) and val.type == BLANK + + if builtin == "STR": + val = evaluate_expression(node.arg, solution) + return Term(type=LITERAL, 
value=_to_string(val)) + + if builtin == "LANG": + val = evaluate_expression(node.arg, solution) + if isinstance(val, Term) and val.type == LITERAL: + return Term(type=LITERAL, value=val.language or "") + return Term(type=LITERAL, value="") + + if builtin == "DATATYPE": + val = evaluate_expression(node.arg, solution) + if isinstance(val, Term) and val.type == LITERAL and val.datatype: + return Term(type=IRI, iri=val.datatype) + return Term(type=IRI, iri="http://www.w3.org/2001/XMLSchema#string") + + if builtin == "REGEX": + text = _to_string(evaluate_expression(node.text, solution)) + pattern = _to_string(evaluate_expression(node.pattern, solution)) + flags_str = "" + if hasattr(node, "flags") and node.flags is not None: + flags_str = _to_string(evaluate_expression(node.flags, solution)) + + re_flags = 0 + if "i" in flags_str: + re_flags |= re.IGNORECASE + if "m" in flags_str: + re_flags |= re.MULTILINE + if "s" in flags_str: + re_flags |= re.DOTALL + + try: + return bool(re.search(pattern, text, re_flags)) + except re.error: + return False + + if builtin == "STRLEN": + val = _to_string(evaluate_expression(node.arg, solution)) + return len(val) + + if builtin == "UCASE": + val = _to_string(evaluate_expression(node.arg, solution)) + return Term(type=LITERAL, value=val.upper()) + + if builtin == "LCASE": + val = _to_string(evaluate_expression(node.arg, solution)) + return Term(type=LITERAL, value=val.lower()) + + if builtin == "CONTAINS": + string = _to_string(evaluate_expression(node.arg1, solution)) + pattern = _to_string(evaluate_expression(node.arg2, solution)) + return pattern in string + + if builtin == "STRSTARTS": + string = _to_string(evaluate_expression(node.arg1, solution)) + prefix = _to_string(evaluate_expression(node.arg2, solution)) + return string.startswith(prefix) + + if builtin == "STRENDS": + string = _to_string(evaluate_expression(node.arg1, solution)) + suffix = _to_string(evaluate_expression(node.arg2, solution)) + return 
string.endswith(suffix) + + if builtin == "CONCAT": + args = [_to_string(evaluate_expression(a, solution)) for a in node.arg] + return Term(type=LITERAL, value="".join(args)) + + if builtin == "IF": + cond = _effective_boolean(evaluate_expression(node.arg1, solution)) + if cond: + return evaluate_expression(node.arg2, solution) + else: + return evaluate_expression(node.arg3, solution) + + if builtin == "COALESCE": + for arg in node.arg: + val = evaluate_expression(arg, solution) + if val is not None: + return val + return None + + if builtin == "sameTerm": + left = evaluate_expression(node.arg1, solution) + right = evaluate_expression(node.arg2, solution) + if not isinstance(left, Term) or not isinstance(right, Term): + return False + from . solutions import _term_key + return _term_key(left) == _term_key(right) + + logger.warning(f"Unsupported built-in function: {builtin}") + return None + + +def _eval_function(node, solution): + """Evaluate a SPARQL function call.""" + # Cast functions (xsd:integer, xsd:string, etc.) 
+ iri = str(node.iri) if hasattr(node, "iri") else "" + args = [evaluate_expression(a, solution) for a in node.expr] + + xsd = "http://www.w3.org/2001/XMLSchema#" + if iri == xsd + "integer": + try: + return int(_to_numeric(args[0])) + except (TypeError, ValueError): + return None + elif iri == xsd + "decimal" or iri == xsd + "double" or iri == xsd + "float": + try: + return float(_to_numeric(args[0])) + except (TypeError, ValueError): + return None + elif iri == xsd + "string": + return Term(type=LITERAL, value=_to_string(args[0])) + elif iri == xsd + "boolean": + return _effective_boolean(args[0]) + + logger.warning(f"Unsupported function: {iri}") + return None + + +def _eval_in(node, solution): + """Evaluate IN expression.""" + val = evaluate_expression(node.expr, solution) + for item in node.other: + other = evaluate_expression(item, solution) + if _comparable_value(val) == _comparable_value(other): + return True + return False + + +# --- Value conversion helpers --- + +def _effective_boolean(val): + """Convert a value to its effective boolean value (EBV).""" + if isinstance(val, bool): + return val + if val is None: + return False + if isinstance(val, (int, float)): + return val != 0 + if isinstance(val, str): + return len(val) > 0 + if isinstance(val, Term): + if val.type == LITERAL: + v = val.value + if val.datatype == "http://www.w3.org/2001/XMLSchema#boolean": + return v.lower() in ("true", "1") + if val.datatype in ( + "http://www.w3.org/2001/XMLSchema#integer", + "http://www.w3.org/2001/XMLSchema#decimal", + "http://www.w3.org/2001/XMLSchema#double", + "http://www.w3.org/2001/XMLSchema#float", + ): + try: + return float(v) != 0 + except ValueError: + return False + return len(v) > 0 + return True + return bool(val) + + +def _to_string(val): + """Convert a value to a string.""" + if val is None: + return "" + if isinstance(val, str): + return val + if isinstance(val, Term): + if val.type == IRI: + return val.iri + elif val.type == LITERAL: + return 
val.value + elif val.type == BLANK: + return val.id + return str(val) + + +def _to_numeric(val): + """Convert a value to a number.""" + if val is None: + return None + if isinstance(val, (int, float)): + return val + if isinstance(val, Term) and val.type == LITERAL: + try: + if "." in val.value: + return float(val.value) + return int(val.value) + except (ValueError, TypeError): + return None + if isinstance(val, str): + try: + if "." in val: + return float(val) + return int(val) + except (ValueError, TypeError): + return None + return None + + +def _comparable_value(val): + """ + Convert a value to a form suitable for comparison. + Returns a tuple (type, value) for consistent ordering. + """ + if val is None: + return (0, "") + if isinstance(val, bool): + return (1, val) + if isinstance(val, (int, float)): + return (2, val) + if isinstance(val, str): + return (3, val) + if isinstance(val, Term): + if val.type == IRI: + return (4, val.iri) + elif val.type == LITERAL: + # Try numeric comparison for numeric types + num = _to_numeric(val) + if num is not None: + return (2, num) + return (3, val.value) + elif val.type == BLANK: + return (5, val.id) + return (6, str(val)) diff --git a/trustgraph-flow/trustgraph/query/sparql/parser.py b/trustgraph-flow/trustgraph/query/sparql/parser.py new file mode 100644 index 00000000..7de18460 --- /dev/null +++ b/trustgraph-flow/trustgraph/query/sparql/parser.py @@ -0,0 +1,139 @@ +""" +SPARQL parser wrapping rdflib's SPARQL 1.1 parser and algebra compiler. +Parses a SPARQL query string into an algebra tree for evaluation. +""" + +import logging + +from rdflib.plugins.sparql import prepareQuery +from rdflib.plugins.sparql.algebra import translateQuery +from rdflib.plugins.sparql.parserutils import CompValue +from rdflib.term import Variable, URIRef, Literal, BNode + +from ... 
schema import Term, Triple, IRI, LITERAL, BLANK + +logger = logging.getLogger(__name__) + + +class ParseError(Exception): + """Raised when a SPARQL query cannot be parsed.""" + pass + + +class ParsedQuery: + """Result of parsing a SPARQL query string.""" + + def __init__(self, algebra, query_type, variables=None): + self.algebra = algebra + self.query_type = query_type # "select", "ask", "construct", "describe" + self.variables = variables or [] # projected variable names (SELECT) + + +def rdflib_term_to_term(t): + """Convert an rdflib term (URIRef, Literal, BNode) to a schema Term.""" + if isinstance(t, URIRef): + return Term(type=IRI, iri=str(t)) + elif isinstance(t, Literal): + term = Term(type=LITERAL, value=str(t)) + if t.datatype: + term.datatype = str(t.datatype) + if t.language: + term.language = t.language + return term + elif isinstance(t, BNode): + return Term(type=BLANK, id=str(t)) + else: + return Term(type=LITERAL, value=str(t)) + + +def term_to_rdflib(t): + """Convert a schema Term to an rdflib term.""" + if t.type == IRI: + return URIRef(t.iri) + elif t.type == LITERAL: + kwargs = {} + if t.datatype: + kwargs["datatype"] = URIRef(t.datatype) + if t.language: + kwargs["lang"] = t.language + return Literal(t.value, **kwargs) + elif t.type == BLANK: + return BNode(t.id) + else: + return Literal(t.value) + + +def parse_sparql(query_string): + """ + Parse a SPARQL query string into a ParsedQuery. 
+ + Args: + query_string: SPARQL 1.1 query string + + Returns: + ParsedQuery with algebra tree, query type, and projected variables + + Raises: + ParseError: if the query cannot be parsed + """ + try: + prepared = prepareQuery(query_string) + except Exception as e: + raise ParseError(f"SPARQL parse error: {e}") from e + + algebra = prepared.algebra + + # Determine query type and extract variables + query_type = _detect_query_type(algebra) + variables = _extract_variables(algebra, query_type) + + return ParsedQuery( + algebra=algebra, + query_type=query_type, + variables=variables, + ) + + +def _detect_query_type(algebra): + """Detect the SPARQL query type from the algebra root.""" + name = algebra.name + + if name == "SelectQuery": + return "select" + elif name == "AskQuery": + return "ask" + elif name == "ConstructQuery": + return "construct" + elif name == "DescribeQuery": + return "describe" + + # The top-level algebra node may be a modifier (Project, Slice, etc.) + # wrapping the actual query. Check for common patterns. 
+ if name in ("Project", "Distinct", "Reduced", "OrderBy", "Slice"): + return "select" + + logger.warning(f"Unknown algebra root type: {name}, assuming select") + return "select" + + +def _extract_variables(algebra, query_type): + """Extract projected variable names from the algebra.""" + if query_type != "select": + return [] + + # For SELECT queries, the Project node has PV (projected variables) + if hasattr(algebra, "PV"): + return [str(v) for v in algebra.PV] + + # Walk down through modifiers to find Project + node = algebra + while hasattr(node, "p"): + node = node.p + if hasattr(node, "PV"): + return [str(v) for v in node.PV] + + # Fallback: collect all variables from the algebra + if hasattr(algebra, "_vars"): + return [str(v) for v in algebra._vars] + + return [] diff --git a/trustgraph-flow/trustgraph/query/sparql/service.py b/trustgraph-flow/trustgraph/query/sparql/service.py new file mode 100644 index 00000000..e815540f --- /dev/null +++ b/trustgraph-flow/trustgraph/query/sparql/service.py @@ -0,0 +1,230 @@ +""" +SPARQL query service. Accepts SPARQL queries, decomposes them into triple +pattern lookups via the triples query pub/sub interface, performs in-memory +joins/filters/projections, and returns SPARQL result bindings. +""" + +import logging + +from ... schema import SparqlQueryRequest, SparqlQueryResponse +from ... schema import SparqlBinding, Error, Term, Triple +from ... base import FlowProcessor, ConsumerSpec, ProducerSpec +from ... base import TriplesClientSpec + +from . parser import parse_sparql, ParseError +from . 
algebra import evaluate, EvaluationError
+
+logger = logging.getLogger(__name__)
+
+# Defaults for the service identity and per-consumer request concurrency.
+default_ident = "sparql-query"
+default_concurrency = 10
+
+
+class Processor(FlowProcessor):
+    """
+    Pub/sub-hosted SPARQL query service.
+
+    Consumes SparqlQueryRequest messages, parses each query with
+    parse_sparql, evaluates the resulting algebra via the triples-query
+    client (the evaluate coroutine issues the triple-pattern lookups),
+    and publishes a SparqlQueryResponse on the "response" producer.
+    Parse and evaluation failures are reported in-band as Error payloads
+    rather than raised to the caller.
+    """
+
+    def __init__(self, **params):
+
+        # Fall back to module defaults when the caller supplies neither.
+        # NOTE: `id` shadows the builtin, matching the surrounding style.
+        id = params.get("id", default_ident)
+        concurrency = params.get("concurrency", default_concurrency)
+
+        super(Processor, self).__init__(
+            **params | {
+                "id": id,
+                "concurrency": concurrency,
+            }
+        )
+
+        # Incoming SPARQL query requests.
+        self.register_specification(
+            ConsumerSpec(
+                name="request",
+                schema=SparqlQueryRequest,
+                handler=self.on_message,
+                concurrency=concurrency,
+            )
+        )
+
+        # Outgoing query results / errors.
+        self.register_specification(
+            ProducerSpec(
+                name="response",
+                schema=SparqlQueryResponse,
+            )
+        )
+
+        # Client used to issue single triple-pattern lookups against the
+        # backend-agnostic triples query service.
+        self.register_specification(
+            TriplesClientSpec(
+                request_name="triples-request",
+                response_name="triples-response",
+            )
+        )
+
+    async def on_message(self, msg, consumer, flow):
+        """Handle one SPARQL request message; always sends a response."""
+
+        try:
+
+            request = msg.value()
+            # NOTE(review): if msg.value() or the properties lookup raises,
+            # `id` is unbound and the error path below raises NameError
+            # instead of sending the error response — consider assigning
+            # `id` before entering the try block.
+            id = msg.properties()["id"]
+
+            logger.debug(f"Handling SPARQL query request {id}...")
+
+            response = await self.execute_sparql(request, flow)
+
+            # Correlate the response with the request via the "id" property.
+            await flow("response").send(response, properties={"id": id})
+
+            logger.debug("SPARQL query request completed")
+
+        except Exception as e:
+
+            logger.error(
+                f"Exception in SPARQL query service: {e}", exc_info=True
+            )
+
+            r = SparqlQueryResponse(
+                error=Error(
+                    type="sparql-query-error",
+                    message=str(e),
+                ),
+            )
+
+            await flow("response").send(r, properties={"id": id})
+
+    async def execute_sparql(self, request, flow):
+        """Parse and evaluate a SPARQL query.
+
+        Returns a SparqlQueryResponse; parse/evaluation failures are
+        returned as error responses, never raised.
+        """
+
+        # Parse the SPARQL query
+        try:
+            parsed = parse_sparql(request.query)
+        except ParseError as e:
+            return SparqlQueryResponse(
+                error=Error(
+                    type="sparql-parse-error",
+                    message=str(e),
+                ),
+            )
+
+        # Get the triples client from the flow
+        triples_client = flow("triples-request")
+
+        # Evaluate the algebra
+        try:
+            solutions = await evaluate(
+                parsed.algebra,
+                triples_client,
+                # Falsy user/collection fall back to defaults; note that a
+                # request.limit of 0 also falls back to 10000.
+                user=request.user or "trustgraph",
+                collection=request.collection or "default",
+                limit=request.limit or 10000,
+            )
+        except EvaluationError as e:
+            return SparqlQueryResponse(
+                error=Error(
+                    type="sparql-evaluation-error",
+                    message=str(e),
+                ),
+            )
+
+        # Build response based on query type
+        if parsed.query_type == "select":
+            return self._build_select_response(parsed, solutions)
+        elif parsed.query_type == "ask":
+            return self._build_ask_response(solutions)
+        elif parsed.query_type == "construct":
+            return self._build_construct_response(parsed, solutions)
+        elif parsed.query_type == "describe":
+            return self._build_describe_response(parsed, solutions)
+        else:
+            return SparqlQueryResponse(
+                error=Error(
+                    type="sparql-unsupported",
+                    message=f"Unsupported query type: {parsed.query_type}",
+                ),
+            )
+
+    def _build_select_response(self, parsed, solutions):
+        """Build response for SELECT queries.
+
+        Each binding row lists values in the order of parsed.variables;
+        unbound variables yield None via sol.get().
+        """
+        variables = parsed.variables
+
+        bindings = []
+        for sol in solutions:
+            values = [sol.get(v) for v in variables]
+            bindings.append(SparqlBinding(values=values))
+
+        return SparqlQueryResponse(
+            query_type="select",
+            variables=variables,
+            bindings=bindings,
+        )
+
+    def _build_ask_response(self, solutions):
+        """Build response for ASK queries: true iff any solution exists."""
+        return SparqlQueryResponse(
+            query_type="ask",
+            ask_result=len(solutions) > 0,
+        )
+
+    def _build_construct_response(self, parsed, solutions):
+        """Build response for CONSTRUCT queries.
+
+        Instantiates the CONSTRUCT template once per solution and
+        de-duplicates the resulting triples.
+        """
+        # CONSTRUCT template is in the algebra
+        template = []
+        if hasattr(parsed.algebra, "template"):
+            template = parsed.algebra.template
+
+        triples = []
+        seen = set()
+
+        for sol in solutions:
+            for s_tmpl, p_tmpl, o_tmpl in template:
+                # NOTE(review): these imports run on every iteration and are
+                # unused in this loop (_resolve_construct_term re-imports
+                # them itself) — consider hoisting to module scope.
+                from rdflib.term import Variable
+                from . parser import rdflib_term_to_term
+
+                s = self._resolve_construct_term(s_tmpl, sol)
+                p = self._resolve_construct_term(p_tmpl, sol)
+                o = self._resolve_construct_term(o_tmpl, sol)
+
+                # Skip triples with any unbound component; de-duplicate on
+                # a (type, identifier) key per position.
+                if s is not None and p is not None and o is not None:
+                    key = (
+                        s.type, s.iri or s.value,
+                        p.type, p.iri or p.value,
+                        o.type, o.iri or o.value,
+                    )
+                    if key not in seen:
+                        seen.add(key)
+                        triples.append(Triple(s=s, p=p, o=o))
+
+        return SparqlQueryResponse(
+            query_type="construct",
+            triples=triples,
+        )
+
+    def _build_describe_response(self, parsed, solutions):
+        """Build response for DESCRIBE queries.
+
+        Currently a stub: always returns an empty triple list.
+        """
+        # DESCRIBE returns all triples about the described resources
+        # For now, return empty - would need additional triples queries
+        return SparqlQueryResponse(
+            query_type="describe",
+            triples=[],
+        )
+
+    def _resolve_construct_term(self, tmpl, solution):
+        """Resolve a CONSTRUCT template term.
+
+        Variables are looked up in the solution (None if unbound);
+        constant terms are converted from rdflib form.
+        """
+        from rdflib.term import Variable
+        from . parser import rdflib_term_to_term
+
+        if isinstance(tmpl, Variable):
+            return solution.get(str(tmpl))
+        else:
+            return rdflib_term_to_term(tmpl)
+
+    @staticmethod
+    def add_args(parser):
+        # Extend the base FlowProcessor CLI with a concurrency option.
+        FlowProcessor.add_args(parser)
+
+        parser.add_argument(
+            '-c', '--concurrency',
+            type=int,
+            default=default_concurrency,
+            help=f'Number of concurrent requests '
+            f'(default: {default_concurrency})'
+        )
+
+
+def run():
+    # Module entry point used by the packaging scripts.
+    Processor.launch(default_ident, __doc__)
diff --git a/trustgraph-flow/trustgraph/query/sparql/solutions.py b/trustgraph-flow/trustgraph/query/sparql/solutions.py
new file mode 100644
index 00000000..d1ea8373
--- /dev/null
+++ b/trustgraph-flow/trustgraph/query/sparql/solutions.py
@@ -0,0 +1,248 @@
+"""
+Solution sequence operations for SPARQL evaluation.
+
+A solution is a dict mapping variable names (str) to Term values.
+A solution sequence is a list of solutions.
+"""
+
+import logging
+from collections import defaultdict
+
+from ...
schema import Term, IRI, LITERAL, BLANK
+
+logger = logging.getLogger(__name__)
+
+
+def _term_key(term):
+    """Create a hashable key from a Term for join/distinct operations.
+
+    Returns None for an unbound value; otherwise a tuple whose first
+    element tags the term kind so keys of different kinds never collide.
+    Literal components (datatype, language) may be None, so these keys
+    are suitable for hashing/equality but NOT for ordering — use
+    _ordering_key when sorting.
+    """
+    if term is None:
+        return None
+    if term.type == IRI:
+        return ("i", term.iri)
+    elif term.type == LITERAL:
+        return ("l", term.value, term.datatype, term.language)
+    elif term.type == BLANK:
+        return ("b", term.id)
+    else:
+        # Unknown term kind: fall back to the string form.
+        return ("?", str(term))
+
+
+def _ordering_key(val):
+    """Map a solution value to a totally ordered tuple for ORDER BY.
+
+    None components (unbound value, plain-literal datatype/language)
+    are normalised to "" so comparing two keys can never raise
+    TypeError (None is not orderable against str in Python 3).
+    Unbound sorts first, then IRIs, then blanks, then literals.
+    """
+    if val is None:
+        return ("", "")
+    if isinstance(val, Term):
+        return tuple("" if part is None else part for part in _term_key(val))
+    # Non-Term values (e.g. plain Python values produced by expressions).
+    return ("v", str(val))
+
+
+def _solution_key(solution, variables):
+    """Create a hashable key from a solution for the given variables."""
+    return tuple(_term_key(solution.get(v)) for v in variables)
+
+
+def _terms_equal(a, b):
+    """Check if two Terms are equal (None equals only None)."""
+    if a is None and b is None:
+        return True
+    if a is None or b is None:
+        return False
+    return _term_key(a) == _term_key(b)
+
+
+def _compatible(sol_a, sol_b):
+    """Check if two solutions are compatible (agree on shared variables)."""
+    shared = set(sol_a.keys()) & set(sol_b.keys())
+    return all(_terms_equal(sol_a[v], sol_b[v]) for v in shared)
+
+
+def _merge(sol_a, sol_b):
+    """Merge two compatible solutions into one (sol_b wins on overlap)."""
+    result = dict(sol_a)
+    result.update(sol_b)
+    return result
+
+
+def hash_join(left, right):
+    """
+    Inner join two solution sequences on shared variables.
+    Uses hash join for efficiency: the smaller side is indexed by its
+    shared-variable key, then the larger side is probed against it.
+    With no shared variables this degenerates to a cross product.
+    """
+    if not left or not right:
+        return []
+
+    left_vars = set()
+    for sol in left:
+        left_vars.update(sol.keys())
+
+    right_vars = set()
+    for sol in right:
+        right_vars.update(sol.keys())
+
+    shared = sorted(left_vars & right_vars)
+
+    if not shared:
+        # Cross product
+        return [_merge(sol_l, sol_r) for sol_l in left for sol_r in right]
+
+    # Build hash table on the smaller side
+    if len(left) <= len(right):
+        index = defaultdict(list)
+        for sol in left:
+            index[_solution_key(sol, shared)].append(sol)
+
+        results = []
+        for sol_r in right:
+            for sol_l in index.get(_solution_key(sol_r, shared), []):
+                results.append(_merge(sol_l, sol_r))
+        return results
+    else:
+        index = defaultdict(list)
+        for sol in right:
+            index[_solution_key(sol, shared)].append(sol)
+
+        results = []
+        for sol_l in left:
+            for sol_r in index.get(_solution_key(sol_l, shared), []):
+                results.append(_merge(sol_l, sol_r))
+        return results
+
+
+def left_join(left, right, filter_fn=None):
+    """
+    Left outer join (OPTIONAL semantics).
+    Every left solution is preserved. If it joins with right solutions
+    (and passes the optional filter), the merged solutions are included.
+    Otherwise the original left solution is kept.
+    """
+    if not left:
+        return []
+
+    if not right:
+        return list(left)
+
+    right_vars = set()
+    for sol in right:
+        right_vars.update(sol.keys())
+
+    left_vars = set()
+    for sol in left:
+        left_vars.update(sol.keys())
+
+    shared = sorted(left_vars & right_vars)
+
+    # Build hash table on right side; key () means "no shared variables",
+    # in which case every right solution is a candidate match.
+    index = defaultdict(list)
+    for sol in right:
+        key = _solution_key(sol, shared) if shared else ()
+        index[key].append(sol)
+
+    results = []
+    for sol_l in left:
+        key = _solution_key(sol_l, shared) if shared else ()
+        matches = index.get(key, [])
+
+        matched = False
+        for sol_r in matches:
+            merged = _merge(sol_l, sol_r)
+            if filter_fn is None or filter_fn(merged):
+                results.append(merged)
+                matched = True
+
+        if not matched:
+            # No (surviving) match: keep the left solution unchanged.
+            results.append(dict(sol_l))
+
+    return results
+
+
+def union(left, right):
+    """Union two solution sequences (concatenation, duplicates kept)."""
+    return list(left) + list(right)
+
+
+def project(solutions, variables):
+    """Keep only the specified variables in each solution."""
+    return [
+        {v: sol[v] for v in variables if v in sol}
+        for sol in solutions
+    ]
+
+
+def distinct(solutions):
+    """Remove duplicate solutions, preserving first-seen order."""
+    seen = set()
+    results = []
+    for sol in solutions:
+        # Sort by variable name so binding insertion order does not matter.
+        key = tuple(sorted(
+            (k, _term_key(v)) for k, v in sol.items()
+        ))
+        if key not in seen:
+            seen.add(key)
+            results.append(sol)
+    return results
+
+
+def _make_comparator(key_fns):
+    """Build a cmp-style comparator from (fn, ascending) key specs."""
+    def compare(a, b):
+        for fn, ascending in key_fns:
+            ka = _ordering_key(fn(a))
+            kb = _ordering_key(fn(b))
+            if ka < kb:
+                return -1 if ascending else 1
+            if ka > kb:
+                return 1 if ascending else -1
+        return 0
+    return compare
+
+
+def order_by(solutions, key_fns):
+    """
+    Stable sort of solutions by the given key functions.
+
+    key_fns is a list of (fn, ascending) tuples where fn extracts
+    a comparable value from a solution.
+
+    A single comparator handles ascending, descending and mixed key
+    directions in one pass. This fixes two defects of the earlier
+    sort-then-reverse approach: reversing a stable ascending sort is
+    NOT a stable descending sort (ties come out reversed), and sorting
+    literal keys whose datatype/language is None could raise TypeError
+    when compared against a str (see _ordering_key).
+    """
+    if not key_fns:
+        return solutions
+
+    import functools
+    return sorted(
+        solutions, key=functools.cmp_to_key(_make_comparator(key_fns))
+    )
+
+
+def _mixed_sort(solutions, key_fns):
+    """Sort with mixed ascending/descending keys (compatibility shim)."""
+    import functools
+    return sorted(
+        solutions, key=functools.cmp_to_key(_make_comparator(key_fns))
+    )
+
+
+def slice_solutions(solutions, offset=0, limit=None):
+    """Apply OFFSET and LIMIT to a solution sequence."""
+    if offset:
+        solutions = solutions[offset:]
+    if limit is not None:
+        solutions = solutions[:limit]
+    return solutions