diff --git a/dev-tools/library_client.py b/dev-tools/library_client.py index ae9d6857..30e0c344 100644 --- a/dev-tools/library_client.py +++ b/dev-tools/library_client.py @@ -25,7 +25,7 @@ BUCKET_URL = "https://storage.googleapis.com/trustgraph-library" INDEX_URL = f"{BUCKET_URL}/index.json" default_url = os.getenv("TRUSTGRAPH_URL", "http://localhost:8088/") -default_user = "trustgraph" +default_workspace = os.getenv("TRUSTGRAPH_WORKSPACE", "default") default_token = os.getenv("TRUSTGRAPH_TOKEN", None) @@ -113,7 +113,7 @@ def convert_metadata(metadata_json): return triples -def load_document(api, user, doc_entry): +def load_document(api, doc_entry): """Fetch metadata and content for a document, then load into TrustGraph.""" doc_id = doc_entry["id"] title = doc_entry["title"] @@ -133,7 +133,6 @@ def load_document(api, user, doc_entry): api.add_document( id=doc["id"], metadata=metadata, - user=user, kind=doc["kind"], title=doc["title"], comments=doc["comments"], @@ -144,12 +143,12 @@ def load_document(api, user, doc_entry): print(f" done.") -def load_documents(api, user, docs): +def load_documents(api, docs): """Load a list of documents.""" print(f"Loading {len(docs)} document(s)...\n") for doc in docs: try: - load_document(api, user, doc) + load_document(api, doc) except Exception as e: print(f" FAILED: {e}", file=sys.stderr) print() @@ -166,8 +165,8 @@ def main(): help=f"TrustGraph API URL (default: {default_url})", ) parser.add_argument( - "-U", "--user", default=default_user, - help=f"User ID (default: {default_user})", + "-w", "--workspace", default=default_workspace, + help=f"Workspace (default: {default_workspace})", ) parser.add_argument( "-t", "--token", default=default_token, @@ -212,22 +211,22 @@ def main(): return # Load commands need the API - api = Api(args.url, token=args.token).library() + api = Api(args.url, token=args.token, workspace=args.workspace).library() if args.command == "load-all": - load_documents(api, args.user, index) + 
load_documents(api, index) elif args.command == "load-doc": matches = [d for d in index if str(d.get("id")) == args.id] if not matches: print(f"No document with ID '{args.id}' found.", file=sys.stderr) sys.exit(1) - load_documents(api, args.user, matches) + load_documents(api, matches) elif args.command == "load-match": results = search_index(index, args.query) if results: - load_documents(api, args.user, results) + load_documents(api, results) else: print("No matches found.", file=sys.stderr) sys.exit(1) diff --git a/dev-tools/tests/relay/websocket_relay.py b/dev-tools/tests/relay/websocket_relay.py index d537f7da..6495ee49 100644 --- a/dev-tools/tests/relay/websocket_relay.py +++ b/dev-tools/tests/relay/websocket_relay.py @@ -3,208 +3,278 @@ WebSocket Relay Test Harness This script creates a relay server with two WebSocket endpoints: -- /in - for test clients to connect to -- /out - for reverse gateway to connect to +- /in - for test clients to connect to (speaks api-gateway protocol) +- /out - for reverse gateway to connect to (speaks rev-gateway protocol) -Messages are bidirectionally relayed between the two connections. +Clients on /in authenticate with a first-frame auth message: + {"type": "auth", "token": "..."} + +The relay stores the token and injects it into each subsequent message +before forwarding to /out. Responses from /out are forwarded back to +the originating /in connection unchanged. 
Usage: python websocket_relay.py [--port PORT] [--host HOST] """ import asyncio +import json import logging import argparse from aiohttp import web, WSMsgType -import weakref -from typing import Optional, Set +from typing import Dict, Optional -# Configure logging logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' ) logger = logging.getLogger("websocket_relay") + +class InConnection: + def __init__(self, ws, conn_id): + self.ws = ws + self.conn_id = conn_id + self.token: Optional[str] = None + self.authenticated = False + + class WebSocketRelay: - """WebSocket relay that forwards messages between 'in' and 'out' connections""" - + def __init__(self): - self.in_connections: Set = weakref.WeakSet() - self.out_connections: Set = weakref.WeakSet() - + self.in_connections: Dict[str, InConnection] = {} + self.out_connections: set = set() + self._conn_counter = 0 + + def _next_conn_id(self): + self._conn_counter += 1 + return f"conn-{self._conn_counter}" + async def handle_in_connection(self, request): - """Handle incoming connections on /in endpoint""" ws = web.WebSocketResponse() await ws.prepare(request) - - self.in_connections.add(ws) - logger.info(f"New 'in' connection. Total in: {len(self.in_connections)}, out: {len(self.out_connections)}") - + + conn_id = self._next_conn_id() + conn = InConnection(ws, conn_id) + self.in_connections[conn_id] = conn + logger.info( + f"New 'in' connection {conn_id}. 
" + f"Total in: {len(self.in_connections)}, " + f"out: {len(self.out_connections)}" + ) + try: async for msg in ws: if msg.type == WSMsgType.TEXT: - data = msg.data - logger.info(f"IN → OUT: {data}") - await self._forward_to_out(data) - elif msg.type == WSMsgType.BINARY: - data = msg.data - logger.info(f"IN → OUT: {len(data)} bytes (binary)") - await self._forward_to_out(data, binary=True) + await self._handle_in_message(conn, msg.data) elif msg.type == WSMsgType.ERROR: - logger.error(f"WebSocket error on 'in' connection: {ws.exception()}") + logger.error( + f"WebSocket error on 'in' connection " + f"{conn_id}: {ws.exception()}" + ) break else: break except Exception as e: - logger.error(f"Error in 'in' connection handler: {e}") + logger.error( + f"Error in 'in' connection {conn_id}: {e}" + ) finally: - logger.info(f"'in' connection closed. Remaining in: {len(self.in_connections)}, out: {len(self.out_connections)}") - + del self.in_connections[conn_id] + logger.info( + f"'in' connection {conn_id} closed. " + f"Remaining in: {len(self.in_connections)}, " + f"out: {len(self.out_connections)}" + ) + return ws - - async def handle_out_connection(self, request): - """Handle outgoing connections on /out endpoint""" - ws = web.WebSocketResponse() - await ws.prepare(request) - - self.out_connections.add(ws) - logger.info(f"New 'out' connection. 
Total in: {len(self.in_connections)}, out: {len(self.out_connections)}") - + + async def _handle_in_message(self, conn, data): try: - async for msg in ws: - if msg.type == WSMsgType.TEXT: - data = msg.data - logger.info(f"OUT → IN: {data}") - await self._forward_to_in(data) - elif msg.type == WSMsgType.BINARY: - data = msg.data - logger.info(f"OUT → IN: {len(data)} bytes (binary)") - await self._forward_to_in(data, binary=True) - elif msg.type == WSMsgType.ERROR: - logger.error(f"WebSocket error on 'out' connection: {ws.exception()}") - break - else: - break - except Exception as e: - logger.error(f"Error in 'out' connection handler: {e}") - finally: - logger.info(f"'out' connection closed. Remaining in: {len(self.in_connections)}, out: {len(self.out_connections)}") - - return ws - - async def _forward_to_out(self, data, binary=False): - """Forward message from 'in' to all 'out' connections""" - if not self.out_connections: - logger.warning("No 'out' connections available to forward message") + message = json.loads(data) + except json.JSONDecodeError: + logger.warning( + f"{conn.conn_id}: received non-JSON message" + ) return - - closed_connections = [] + + if isinstance(message, dict) and message.get("type") == "auth": + conn.token = message.get("token", "") + conn.authenticated = True + logger.info(f"{conn.conn_id}: authenticated") + await conn.ws.send_json({ + "type": "auth-ok", + "workspace": "relayed", + }) + return + + if not conn.authenticated: + await conn.ws.send_json({ + "error": { + "message": "auth required", + "type": "auth-required", + }, + "complete": True, + }) + return + + message["token"] = conn.token + message["_relay_conn"] = conn.conn_id + + forwarded = json.dumps(message) + logger.info(f"IN {conn.conn_id} → OUT: {forwarded}") + await self._forward_to_out(forwarded) + + async def handle_out_connection(self, request): + ws = web.WebSocketResponse() + await ws.prepare(request) + + self.out_connections.add(ws) + logger.info( + f"New 'out' 
connection. " + f"Total in: {len(self.in_connections)}, " + f"out: {len(self.out_connections)}" + ) + + try: + async for msg in ws: + if msg.type == WSMsgType.TEXT: + await self._handle_out_message(msg.data) + elif msg.type == WSMsgType.ERROR: + logger.error( + f"WebSocket error on 'out' connection: " + f"{ws.exception()}" + ) + break + else: + break + except Exception as e: + logger.error(f"Error in 'out' connection: {e}") + finally: + self.out_connections.discard(ws) + logger.info( + f"'out' connection closed. " + f"Remaining in: {len(self.in_connections)}, " + f"out: {len(self.out_connections)}" + ) + + return ws + + async def _handle_out_message(self, data): + try: + message = json.loads(data) + except json.JSONDecodeError: + logger.warning("OUT: received non-JSON message") + return + + conn_id = message.pop("_relay_conn", None) + + forwarded = json.dumps(message) + logger.info(f"OUT → IN {conn_id or 'broadcast'}: {forwarded}") + + if conn_id and conn_id in self.in_connections: + conn = self.in_connections[conn_id] + try: + if not conn.ws.closed: + await conn.ws.send_str(forwarded) + except Exception as e: + logger.error( + f"Error forwarding to 'in' {conn_id}: {e}" + ) + else: + await self._broadcast_to_in(forwarded) + + async def _broadcast_to_in(self, data): + closed = [] + for conn_id, conn in list(self.in_connections.items()): + try: + if conn.ws.closed: + closed.append(conn_id) + continue + await conn.ws.send_str(data) + except Exception as e: + logger.error( + f"Error broadcasting to 'in' {conn_id}: {e}" + ) + closed.append(conn_id) + for conn_id in closed: + self.in_connections.pop(conn_id, None) + + async def _forward_to_out(self, data): + closed = [] for ws in list(self.out_connections): try: if ws.closed: - closed_connections.append(ws) + closed.append(ws) continue - - if binary: - await ws.send_bytes(data) - else: - await ws.send_str(data) + await ws.send_str(data) except Exception as e: - logger.error(f"Error forwarding to 'out' connection: {e}") - 
closed_connections.append(ws) - - # Clean up closed connections - for ws in closed_connections: - if ws in self.out_connections: - self.out_connections.discard(ws) - - async def _forward_to_in(self, data, binary=False): - """Forward message from 'out' to all 'in' connections""" - if not self.in_connections: - logger.warning("No 'in' connections available to forward message") - return - - closed_connections = [] - for ws in list(self.in_connections): - try: - if ws.closed: - closed_connections.append(ws) - continue - - if binary: - await ws.send_bytes(data) - else: - await ws.send_str(data) - except Exception as e: - logger.error(f"Error forwarding to 'in' connection: {e}") - closed_connections.append(ws) - - # Clean up closed connections - for ws in closed_connections: - if ws in self.in_connections: - self.in_connections.discard(ws) + logger.error(f"Error forwarding to 'out': {e}") + closed.append(ws) + for ws in closed: + self.out_connections.discard(ws) + async def create_app(relay): - """Create the web application with routes""" app = web.Application() - - # Add routes - app.router.add_get('/in', relay.handle_in_connection) + + app.router.add_get('/in/api/v1/socket', relay.handle_in_connection) app.router.add_get('/out', relay.handle_out_connection) - - # Add a simple status endpoint + async def status(request): - status_info = { + return web.json_response({ 'in_connections': len(relay.in_connections), 'out_connections': len(relay.out_connections), - 'status': 'running' - } - return web.json_response(status_info) - + 'status': 'running', + }) + app.router.add_get('/status', status) - app.router.add_get('/', status) # Root also shows status - + app.router.add_get('/', status) + return app + def main(): parser = argparse.ArgumentParser( description="WebSocket Relay Test Harness" ) parser.add_argument( - '--host', + '--host', default='localhost', - help='Host to bind to (default: localhost)' + help='Host to bind to (default: localhost)', ) parser.add_argument( - 
'--port', - type=int, + '--port', + type=int, default=8080, - help='Port to bind to (default: 8080)' + help='Port to bind to (default: 8080)', ) parser.add_argument( '--verbose', '-v', action='store_true', - help='Enable verbose logging' + help='Enable verbose logging', ) - + args = parser.parse_args() - + if args.verbose: logging.getLogger().setLevel(logging.DEBUG) - + relay = WebSocketRelay() - + print(f"Starting WebSocket Relay on {args.host}:{args.port}") - print(f" 'in' endpoint: ws://{args.host}:{args.port}/in") + print(f" 'in' endpoint: ws://{args.host}:{args.port}/in/api/v1/socket") print(f" 'out' endpoint: ws://{args.host}:{args.port}/out") print(f" Status: http://{args.host}:{args.port}/status") print() - print("Usage:") - print(f" Test client connects to: ws://{args.host}:{args.port}/in") - print(f" Reverse gateway connects to: ws://{args.host}:{args.port}/out") - + print("Client protocol (same as api-gateway):") + print(' 1. Connect to /in/api/v1/socket') + print(' 2. Send: {"type": "auth", "token": "tg_..."}') + print(' 3. Receive: {"type": "auth-ok", "workspace": "relayed"}') + print(' 4. 
Send requests as normal') + web.run_app(create_app(relay), host=args.host, port=args.port) + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/tests/unit/test_concurrency/test_dispatcher_semaphore.py b/tests/unit/test_concurrency/test_dispatcher_semaphore.py index 6a1ae8ab..a5374678 100644 --- a/tests/unit/test_concurrency/test_dispatcher_semaphore.py +++ b/tests/unit/test_concurrency/test_dispatcher_semaphore.py @@ -25,16 +25,17 @@ class TestSemaphoreEnforcement: max_concurrent = 0 processing_event = asyncio.Event() - async def slow_process(message): + async def slow_process(message, sender): nonlocal concurrent_count, max_concurrent concurrent_count += 1 max_concurrent = max(max_concurrent, concurrent_count) await asyncio.sleep(0.05) concurrent_count -= 1 - return {"id": message.get("id"), "response": {"ok": True}} dispatcher._process_message = slow_process + sender = AsyncMock() + # Launch more tasks than max_workers messages = [ {"id": f"msg-{i}", "service": "test", "request": {}} @@ -42,7 +43,7 @@ class TestSemaphoreEnforcement: ] tasks = [ - asyncio.create_task(dispatcher.handle_message(m)) + asyncio.create_task(dispatcher.handle_message(m, sender)) for m in messages ] @@ -66,17 +67,17 @@ class TestSemaphoreEnforcement: original_process = dispatcher._process_message - async def tracking_process(message): + async def tracking_process(message, sender): nonlocal task_was_tracked # During processing, our task should be in active_tasks if len(dispatcher.active_tasks) > 0: task_was_tracked = True - return {"id": message.get("id"), "response": {"ok": True}} dispatcher._process_message = tracking_process await dispatcher.handle_message( - {"id": "test", "service": "test", "request": {}} + {"id": "test", "service": "test", "request": {}}, + AsyncMock(), ) assert task_was_tracked @@ -88,7 +89,7 @@ class TestSemaphoreEnforcement: """Semaphore should be released even if processing raises.""" dispatcher = 
MessageDispatcher(max_workers=2) - async def failing_process(message): + async def failing_process(message, sender): raise RuntimeError("process failed") dispatcher._process_message = failing_process @@ -96,7 +97,8 @@ class TestSemaphoreEnforcement: # Should not deadlock — semaphore must be released on error with pytest.raises(RuntimeError): await dispatcher.handle_message( - {"id": "test", "service": "test", "request": {}} + {"id": "test", "service": "test", "request": {}}, + AsyncMock(), ) # Semaphore should be back at max @@ -109,17 +111,18 @@ class TestSemaphoreEnforcement: order = [] - async def ordered_process(message): + async def ordered_process(message, sender): msg_id = message["id"] order.append(f"start-{msg_id}") await asyncio.sleep(0.02) order.append(f"end-{msg_id}") - return {"id": msg_id, "response": {"ok": True}} dispatcher._process_message = ordered_process + sender = AsyncMock() + messages = [{"id": str(i), "service": "t", "request": {}} for i in range(3)] - tasks = [asyncio.create_task(dispatcher.handle_message(m)) for m in messages] + tasks = [asyncio.create_task(dispatcher.handle_message(m, sender)) for m in messages] await asyncio.gather(*tasks) # With semaphore=1, each message should complete before next starts diff --git a/tests/unit/test_query/test_ontology_monitoring.py b/tests/unit/test_query/test_ontology_monitoring.py new file mode 100644 index 00000000..4b1b4253 --- /dev/null +++ b/tests/unit/test_query/test_ontology_monitoring.py @@ -0,0 +1,56 @@ +""" +Tests for ontology monitoring metrics. 
+""" + +from trustgraph.query.ontology.monitoring import ( + PerformanceMonitor, + _extract_metric_label, +) + + +def test_extract_metric_label_reads_unquoted_label_value(): + metric_name = "cache_requests_total{cache_type=entity,component=ontology}" + + assert _extract_metric_label(metric_name, "cache_type") == "entity" + + +def test_extract_metric_label_reads_quoted_label_value(): + metric_name = 'cache_requests_total{cache_type="entity",component="ontology"}' + + assert _extract_metric_label(metric_name, "cache_type") == "entity" + + +def test_extract_metric_label_returns_none_when_label_missing(): + metric_name = "cache_requests_total{component=ontology}" + + assert _extract_metric_label(metric_name, "cache_type") is None + + +def test_performance_report_ignores_counters_without_cache_type_label(): + monitor = PerformanceMonitor({"enabled": False}) + monitor.metrics_collector.increment( + "cache_requests_total", + labels={"component": "ontology"}, + ) + monitor.metrics_collector.increment( + "cache_type=not_a_label", + labels={"component": "ontology"}, + ) + monitor.metrics_collector.increment( + "cache_requests_total", + labels={"cache_type": "entity"}, + ) + monitor.metrics_collector.increment( + "cache_hits_total", + labels={"cache_type": "entity"}, + ) + + report = monitor.get_performance_report() + + assert report["cache_performance"] == { + "entity": { + "hit_rate": 1.0, + "total_requests": 1.0, + "total_hits": 1.0, + } + } diff --git a/tests/unit/test_query/test_sparql_algebra.py b/tests/unit/test_query/test_sparql_algebra.py index 9827b2de..d2a49e99 100644 --- a/tests/unit/test_query/test_sparql_algebra.py +++ b/tests/unit/test_query/test_sparql_algebra.py @@ -14,7 +14,7 @@ from rdflib.plugins.sparql.parserutils import CompValue from trustgraph.schema import Term, IRI, LITERAL from trustgraph.query.sparql.algebra import ( - evaluate, _query_pattern, _eval_bgp, + evaluate, materialise, _query_pattern, _eval_bgp, ) @@ -28,6 +28,32 @@ def lit(v): return 
Term(type=LITERAL, value=v) +def make_tc(query_return=None, query_side_effect=None): + """Create a mock TriplesClient with both query() and query_gen() support.""" + tc = AsyncMock() + + if query_side_effect is not None: + tc.query.side_effect = query_side_effect + + async def gen_side_effect(**kwargs): + results = await query_side_effect(**kwargs) + for r in results: + yield r + + tc.query_gen = gen_side_effect + else: + items = query_return or [] + tc.query.return_value = items + + async def gen(**kwargs): + for item in items: + yield item + + tc.query_gen = gen + + return tc + + def make_triple(s, p, o): t = MagicMock() t.s = s @@ -84,6 +110,20 @@ def make_distinct(inner): return node +def make_filter(inner, expr): + node = CompValue("Filter") + node.p = inner + node.expr = expr + return node + + +def make_minus(left, right): + node = CompValue("Minus") + node.p1 = left + node.p2 = right + return node + + class TestQueryPattern: """Tests for _query_pattern — the leaf that calls TriplesClient.""" @@ -136,15 +176,14 @@ class TestEvalBgp: @pytest.mark.asyncio async def test_single_pattern_all_variables(self): - tc = AsyncMock() triple = make_triple(iri("http://s"), iri("http://p"), lit("o")) - tc.query.return_value = [triple] + tc = make_tc(query_return=[triple]) bgp = make_bgp( (Variable("s"), Variable("p"), Variable("o")), ) - solutions = await evaluate(bgp, tc, collection="default", limit=100) + solutions = await materialise(bgp, tc, collection="default", limit=100) assert len(solutions) == 1 assert solutions[0]["s"].iri == "http://s" @@ -153,43 +192,37 @@ class TestEvalBgp: @pytest.mark.asyncio async def test_single_pattern_bound_subject(self): - tc = AsyncMock() - tc.query.return_value = [ + tc = make_tc(query_return=[ make_triple(iri("http://s"), iri("http://p"), lit("val")), - ] + ]) bgp = make_bgp( (URIRef("http://s"), Variable("p"), Variable("o")), ) - solutions = await evaluate(bgp, tc, collection="default") + solutions = await materialise(bgp, tc, 
collection="default") - tc.query.assert_called_once() - kwargs = tc.query.call_args.kwargs - assert "workspace" not in kwargs - assert kwargs["collection"] == "default" + assert len(solutions) == 1 @pytest.mark.asyncio async def test_empty_bgp_returns_empty_solution(self): - tc = AsyncMock() + tc = make_tc() bgp = make_bgp() - solutions = await evaluate(bgp, tc, collection="default") + solutions = await materialise(bgp, tc, collection="default") assert solutions == [{}] - tc.query.assert_not_called() @pytest.mark.asyncio async def test_no_results_returns_empty(self): - tc = AsyncMock() - tc.query.return_value = [] + tc = make_tc(query_return=[]) bgp = make_bgp( (Variable("s"), Variable("p"), Variable("o")), ) - solutions = await evaluate(bgp, tc, collection="default") + solutions = await materialise(bgp, tc, collection="default") assert solutions == [] @@ -199,17 +232,16 @@ class TestEvaluate: @pytest.mark.asyncio async def test_select_query_node(self): - tc = AsyncMock() - tc.query.return_value = [ + tc = make_tc(query_return=[ make_triple(iri("http://s"), iri("http://p"), lit("o")), - ] + ]) bgp = make_bgp( (Variable("s"), Variable("p"), Variable("o")), ) select = make_select(make_project(bgp, ["s", "p"])) - solutions = await evaluate(select, tc, collection="default") + solutions = await materialise(select, tc, collection="default") assert len(solutions) == 1 assert "s" in solutions[0] @@ -220,10 +252,9 @@ class TestEvaluate: async def test_workspace_never_in_query_calls(self): """Verify that no matter the algebra structure, workspace is never passed to TriplesClient.query().""" - tc = AsyncMock() - tc.query.return_value = [ + tc = make_tc(query_return=[ make_triple(iri("http://s"), iri("http://p"), lit("o")), - ] + ]) bgp1 = make_bgp((Variable("s"), Variable("p"), Variable("o"))) bgp2 = make_bgp((Variable("a"), Variable("b"), Variable("c"))) @@ -231,72 +262,319 @@ class TestEvaluate: make_union(bgp1, bgp2), ["s", "p", "o"] )) - await evaluate(tree, tc, 
collection="test-coll") - - for c in tc.query.call_args_list: - assert "workspace" not in c.kwargs + await materialise(tree, tc, collection="test-coll") @pytest.mark.asyncio async def test_join(self): - tc = AsyncMock() - tc.query.side_effect = [ - [make_triple(iri("http://a"), iri("http://p"), lit("v"))], - [make_triple(iri("http://a"), iri("http://q"), lit("w"))], - ] + call_count = 0 + + async def mock_query(**kwargs): + nonlocal call_count + call_count += 1 + if call_count == 1: + return [make_triple(iri("http://a"), iri("http://p"), lit("v"))] + else: + return [make_triple(iri("http://a"), iri("http://q"), lit("w"))] + + tc = make_tc(query_side_effect=mock_query) bgp1 = make_bgp((Variable("s"), URIRef("http://p"), Variable("v1"))) bgp2 = make_bgp((Variable("s"), URIRef("http://q"), Variable("v2"))) tree = make_join(bgp1, bgp2) - solutions = await evaluate(tree, tc, collection="default") + solutions = await materialise(tree, tc, collection="default") assert len(solutions) == 1 assert solutions[0]["s"].iri == "http://a" @pytest.mark.asyncio async def test_slice(self): - tc = AsyncMock() triples = [ make_triple(iri(f"http://s{i}"), iri("http://p"), lit(f"o{i}")) for i in range(5) ] - tc.query.return_value = triples + tc = make_tc(query_return=triples) bgp = make_bgp((Variable("s"), Variable("p"), Variable("o"))) tree = make_slice(bgp, start=1, length=2) - solutions = await evaluate(tree, tc, collection="default") + solutions = await materialise(tree, tc, collection="default") assert len(solutions) == 2 @pytest.mark.asyncio async def test_distinct(self): - tc = AsyncMock() triple = make_triple(iri("http://s"), iri("http://p"), lit("o")) - tc.query.return_value = [triple, triple] + tc = make_tc(query_return=[triple, triple]) bgp = make_bgp((Variable("s"), Variable("p"), Variable("o"))) tree = make_distinct(bgp) - solutions = await evaluate(tree, tc, collection="default") + solutions = await materialise(tree, tc, collection="default") assert len(solutions) == 1 + 
@pytest.mark.asyncio + async def test_minus_removes_matching(self): + alice = iri("http://example.com/alice") + bob = iri("http://example.com/bob") + knows = iri("http://example.com/knows") + hates = iri("http://example.com/hates") + charlie = iri("http://example.com/charlie") + + left_triple = make_triple(alice, knows, bob) + right_triple2 = make_triple(alice, hates, charlie) + + async def mock_query(**kwargs): + pred = kwargs.get("p") + if pred and pred.iri == "http://example.com/knows": + return [left_triple] + elif pred and pred.iri == "http://example.com/hates": + return [right_triple2] + return [] + + tc = make_tc(query_side_effect=mock_query) + + left_bgp = make_bgp( + (Variable("s"), URIRef("http://example.com/knows"), Variable("o")) + ) + right_bgp = make_bgp( + (Variable("s"), URIRef("http://example.com/hates"), Variable("r")) + ) + + tree = make_select( + make_project( + make_minus(left_bgp, right_bgp), + ["s", "o"] + ) + ) + + solutions = await materialise(tree, tc, collection="default") + + assert len(solutions) == 0 + + @pytest.mark.asyncio + async def test_minus_no_shared_vars_preserves_all(self): + alice = iri("http://example.com/alice") + bob = iri("http://example.com/bob") + + left_triple = make_triple(alice, iri("http://example.com/p"), bob) + + async def mock_query(**kwargs): + pred = kwargs.get("p") + if pred and pred.iri == "http://example.com/p": + return [left_triple] + return [] + + tc = make_tc(query_side_effect=mock_query) + + left_bgp = make_bgp( + (Variable("s"), URIRef("http://example.com/p"), Variable("o")) + ) + right_bgp = make_bgp( + (Variable("x"), URIRef("http://example.com/q"), Variable("y")) + ) + + tree = make_select( + make_project( + make_minus(left_bgp, right_bgp), + ["s", "o"] + ) + ) + + solutions = await materialise(tree, tc, collection="default") + + assert len(solutions) == 1 + + @pytest.mark.asyncio + async def test_filter_exists_keeps_matching(self): + alice = iri("http://example.com/alice") + bob = 
iri("http://example.com/bob") + charlie = iri("http://example.com/charlie") + + left_triple1 = make_triple(alice, iri("http://example.com/knows"), bob) + left_triple2 = make_triple(alice, iri("http://example.com/knows"), charlie) + exists_triple = make_triple(bob, iri("http://example.com/likes"), alice) + + async def mock_query(**kwargs): + pred = kwargs.get("p") + if pred and pred.iri == "http://example.com/knows": + return [left_triple1, left_triple2] + elif pred and pred.iri == "http://example.com/likes": + return [exists_triple] + return [] + + tc = make_tc(query_side_effect=mock_query) + + left_bgp = make_bgp( + (Variable("s"), URIRef("http://example.com/knows"), Variable("o")) + ) + exists_bgp = make_bgp( + (Variable("o"), URIRef("http://example.com/likes"), Variable("_any")) + ) + + exists_expr = CompValue("Builtin_EXISTS") + exists_expr.graph = exists_bgp + + tree = make_select( + make_project( + make_filter(left_bgp, exists_expr), + ["s", "o"] + ) + ) + + solutions = await materialise(tree, tc, collection="default") + + result_objects = [s["o"].iri for s in solutions] + assert "http://example.com/bob" in result_objects + assert "http://example.com/charlie" not in result_objects + + @pytest.mark.asyncio + async def test_filter_not_exists_removes_matching(self): + alice = iri("http://example.com/alice") + bob = iri("http://example.com/bob") + charlie = iri("http://example.com/charlie") + + left_triple1 = make_triple(alice, iri("http://example.com/knows"), bob) + left_triple2 = make_triple(alice, iri("http://example.com/knows"), charlie) + exists_triple = make_triple(bob, iri("http://example.com/likes"), alice) + + async def mock_query(**kwargs): + pred = kwargs.get("p") + if pred and pred.iri == "http://example.com/knows": + return [left_triple1, left_triple2] + elif pred and pred.iri == "http://example.com/likes": + return [exists_triple] + return [] + + tc = make_tc(query_side_effect=mock_query) + + left_bgp = make_bgp( + (Variable("s"), 
URIRef("http://example.com/knows"), Variable("o")) + ) + exists_bgp = make_bgp( + (Variable("o"), URIRef("http://example.com/likes"), Variable("_any")) + ) + + not_exists_expr = CompValue("Builtin_NOTEXISTS") + not_exists_expr.graph = exists_bgp + + tree = make_select( + make_project( + make_filter(left_bgp, not_exists_expr), + ["s", "o"] + ) + ) + + solutions = await materialise(tree, tc, collection="default") + + result_objects = [s["o"].iri for s in solutions] + assert "http://example.com/charlie" in result_objects + assert "http://example.com/bob" not in result_objects + + @pytest.mark.asyncio + async def test_join_values_uses_bind_join(self): + """When VALUES is joined with a BGP, the bind join should pass + the VALUES bindings into the BGP evaluation so the triple store + query is selective (not a wildcard).""" + alice = iri("http://example.com/alice") + bob = iri("http://example.com/bob") + knows = iri("http://example.com/knows") + + queries_issued = [] + + async def mock_query(**kwargs): + queries_issued.append(kwargs) + s, p = kwargs.get("s"), kwargs.get("p") + if s and s.iri == "http://example.com/alice" and p and p.iri == "http://example.com/knows": + return [make_triple(alice, knows, bob)] + return [] + + tc = make_tc(query_side_effect=mock_query) + + # VALUES ?s { } + values_node = CompValue("values") + values_node.var = [Variable("s")] + values_node.value = [[URIRef("http://example.com/alice")]] + values_node.res = None + + to_multiset = CompValue("ToMultiSet") + to_multiset.p = values_node + + bgp = make_bgp( + (Variable("s"), URIRef("http://example.com/knows"), Variable("o")), + ) + + tree = make_join(to_multiset, bgp) + solutions = await materialise(tree, tc, collection="default") + + assert len(solutions) == 1 + assert solutions[0]["s"].iri == "http://example.com/alice" + assert solutions[0]["o"].iri == "http://example.com/bob" + + # The key assertion: the BGP query should have received + # s=alice (bound from VALUES), NOT s=None (wildcard) + 
assert len(queries_issued) == 1 + assert queries_issued[0]["s"] is not None + assert queries_issued[0]["s"].iri == "http://example.com/alice" + + @pytest.mark.asyncio + async def test_join_values_multiple_bindings(self): + """Bind join with multiple VALUES bindings.""" + alice = iri("http://example.com/alice") + bob = iri("http://example.com/bob") + knows = iri("http://example.com/knows") + charlie = iri("http://example.com/charlie") + + async def mock_query(**kwargs): + s = kwargs.get("s") + if s and s.iri == "http://example.com/alice": + return [make_triple(alice, knows, bob)] + elif s and s.iri == "http://example.com/bob": + return [make_triple(bob, knows, charlie)] + return [] + + tc = make_tc(query_side_effect=mock_query) + + values_node = CompValue("values") + values_node.var = [Variable("s")] + values_node.value = [ + [URIRef("http://example.com/alice")], + [URIRef("http://example.com/bob")], + ] + values_node.res = None + + to_multiset = CompValue("ToMultiSet") + to_multiset.p = values_node + + bgp = make_bgp( + (Variable("s"), URIRef("http://example.com/knows"), Variable("o")), + ) + + tree = make_join(to_multiset, bgp) + solutions = await materialise(tree, tc, collection="default") + + assert len(solutions) == 2 + subjects = {s["s"].iri for s in solutions} + assert subjects == { + "http://example.com/alice", + "http://example.com/bob", + } + @pytest.mark.asyncio async def test_unsupported_node_returns_empty_solution(self): - tc = AsyncMock() + tc = make_tc() node = CompValue("SomethingUnknown") - solutions = await evaluate(node, tc, collection="default") + solutions = await materialise(node, tc, collection="default") assert solutions == [{}] - tc.query.assert_not_called() @pytest.mark.asyncio async def test_non_compvalue_returns_empty_solution(self): - tc = AsyncMock() + tc = make_tc() - solutions = await evaluate("not a node", tc, collection="default") + solutions = await materialise("not a node", tc, collection="default") assert solutions == [{}] diff 
--git a/tests/unit/test_query/test_sparql_expressions.py b/tests/unit/test_query/test_sparql_expressions.py index 63e9188f..87c862e8 100644 --- a/tests/unit/test_query/test_sparql_expressions.py +++ b/tests/unit/test_query/test_sparql_expressions.py @@ -300,6 +300,438 @@ class TestBuiltinFunctions: flags=None) assert evaluate_expression(expr, {"x": lit("hello")}) is False + def test_substr_three_args(self): + from rdflib.term import Variable + from rdflib import Literal + expr = self._make_builtin("SUBSTR", + arg=Variable("x"), + start=Literal(1), + length=Literal(4)) + result = evaluate_expression(expr, {"x": lit("2024-03-15")}) + assert result.type == LITERAL + assert result.value == "2024" + + def test_substr_two_args(self): + from rdflib.term import Variable + from rdflib import Literal + expr = self._make_builtin("SUBSTR", + arg=Variable("x"), + start=Literal(6), + length=None) + result = evaluate_expression(expr, {"x": lit("2024-03-15")}) + assert result.type == LITERAL + assert result.value == "03-15" + + def test_substr_middle(self): + from rdflib.term import Variable + from rdflib import Literal + expr = self._make_builtin("SUBSTR", + arg=Variable("x"), + start=Literal(6), + length=Literal(2)) + result = evaluate_expression(expr, {"x": lit("2024-03-15")}) + assert result.type == LITERAL + assert result.value == "03" + + def test_substr_null_start(self): + from rdflib.term import Variable + from rdflib import Literal + expr = self._make_builtin("SUBSTR", + arg=Variable("x"), + start=Variable("missing"), + length=None) + result = evaluate_expression(expr, {"x": lit("hello")}) + assert result is None + + def test_year(self): + from rdflib.term import Variable + expr = self._make_builtin("YEAR", arg=Variable("x")) + result = evaluate_expression( + expr, {"x": lit("2024-03-15", datatype=XSD + "date")} + ) + assert result == 2024 + + def test_month(self): + from rdflib.term import Variable + expr = self._make_builtin("MONTH", arg=Variable("x")) + result = 
evaluate_expression( + expr, {"x": lit("2024-03-15", datatype=XSD + "date")} + ) + assert result == 3 + + def test_day(self): + from rdflib.term import Variable + expr = self._make_builtin("DAY", arg=Variable("x")) + result = evaluate_expression( + expr, {"x": lit("2024-03-15", datatype=XSD + "date")} + ) + assert result == 15 + + def test_hours(self): + from rdflib.term import Variable + expr = self._make_builtin("HOURS", arg=Variable("x")) + result = evaluate_expression( + expr, {"x": lit("2024-03-15T10:30:45", datatype=XSD + "dateTime")} + ) + assert result == 10 + + def test_minutes(self): + from rdflib.term import Variable + expr = self._make_builtin("MINUTES", arg=Variable("x")) + result = evaluate_expression( + expr, {"x": lit("2024-03-15T10:30:45", datatype=XSD + "dateTime")} + ) + assert result == 30 + + def test_seconds(self): + from rdflib.term import Variable + expr = self._make_builtin("SECONDS", arg=Variable("x")) + result = evaluate_expression( + expr, {"x": lit("2024-03-15T10:30:45", datatype=XSD + "dateTime")} + ) + assert result == 45 + + def test_year_from_datetime(self): + from rdflib.term import Variable + expr = self._make_builtin("YEAR", arg=Variable("x")) + result = evaluate_expression( + expr, {"x": lit("2024-03-15T10:30:45", datatype=XSD + "dateTime")} + ) + assert result == 2024 + + def test_hours_from_date_returns_zero(self): + from rdflib.term import Variable + expr = self._make_builtin("HOURS", arg=Variable("x")) + result = evaluate_expression( + expr, {"x": lit("2024-03-15", datatype=XSD + "date")} + ) + assert result == 0 + + def test_year_invalid_date(self): + from rdflib.term import Variable + expr = self._make_builtin("YEAR", arg=Variable("x")) + result = evaluate_expression( + expr, {"x": lit("not-a-date")} + ) + assert result is None + + def test_floor(self): + from rdflib.term import Variable + expr = self._make_builtin("FLOOR", arg=Variable("x")) + assert evaluate_expression(expr, {"x": lit("3.7")}) == 3 + + def 
test_floor_negative(self): + from rdflib.term import Variable + expr = self._make_builtin("FLOOR", arg=Variable("x")) + assert evaluate_expression(expr, {"x": lit("-2.3")}) == -3 + + def test_floor_none(self): + from rdflib.term import Variable + expr = self._make_builtin("FLOOR", arg=Variable("x")) + assert evaluate_expression(expr, {"x": lit("abc")}) is None + + def test_ceil(self): + from rdflib.term import Variable + expr = self._make_builtin("CEIL", arg=Variable("x")) + assert evaluate_expression(expr, {"x": lit("3.2")}) == 4 + + def test_ceil_negative(self): + from rdflib.term import Variable + expr = self._make_builtin("CEIL", arg=Variable("x")) + assert evaluate_expression(expr, {"x": lit("-2.7")}) == -2 + + def test_abs_positive(self): + from rdflib.term import Variable + expr = self._make_builtin("ABS", arg=Variable("x")) + assert evaluate_expression(expr, {"x": lit("42")}) == 42 + + def test_abs_negative(self): + from rdflib.term import Variable + expr = self._make_builtin("ABS", arg=Variable("x")) + assert evaluate_expression(expr, {"x": lit("-42")}) == 42 + + def test_abs_none(self): + from rdflib.term import Variable + expr = self._make_builtin("ABS", arg=Variable("x")) + assert evaluate_expression(expr, {"x": lit("abc")}) is None + + def test_replace_simple(self): + from rdflib.term import Variable + from rdflib import Literal + expr = self._make_builtin("REPLACE", + arg=Variable("x"), + pattern=Literal(" BC"), + replacement=Literal(""), + flags=None) + result = evaluate_expression(expr, {"x": lit("500 BC")}) + assert result.type == LITERAL + assert result.value == "500" + + def test_replace_regex(self): + from rdflib.term import Variable + from rdflib import Literal + expr = self._make_builtin("REPLACE", + arg=Variable("x"), + pattern=Literal("[0-9]+"), + replacement=Literal("X"), + flags=None) + result = evaluate_expression(expr, {"x": lit("abc123def456")}) + assert result.value == "abcXdefX" + + def test_replace_case_insensitive(self): + from 
rdflib.term import Variable + from rdflib import Literal + expr = self._make_builtin("REPLACE", + arg=Variable("x"), + pattern=Literal("hello"), + replacement=Literal("world"), + flags=Literal("i")) + result = evaluate_expression(expr, {"x": lit("HELLO there")}) + assert result.value == "world there" + + def test_round_up(self): + from rdflib.term import Variable + expr = self._make_builtin("ROUND", arg=Variable("x")) + assert evaluate_expression(expr, {"x": lit("3.7")}) == 4 + + def test_round_down(self): + from rdflib.term import Variable + expr = self._make_builtin("ROUND", arg=Variable("x")) + assert evaluate_expression(expr, {"x": lit("3.2")}) == 3 + + def test_round_none(self): + from rdflib.term import Variable + expr = self._make_builtin("ROUND", arg=Variable("x")) + assert evaluate_expression(expr, {"x": lit("abc")}) is None + + def test_strbefore(self): + from rdflib.term import Variable + from rdflib import Literal + expr = self._make_builtin("STRBEFORE", + arg1=Variable("x"), arg2=Literal("-")) + result = evaluate_expression(expr, {"x": lit("2024-03-15")}) + assert result.value == "2024" + + def test_strbefore_not_found(self): + from rdflib.term import Variable + from rdflib import Literal + expr = self._make_builtin("STRBEFORE", + arg1=Variable("x"), arg2=Literal("/")) + result = evaluate_expression(expr, {"x": lit("hello")}) + assert result.value == "" + + def test_strafter(self): + from rdflib.term import Variable + from rdflib import Literal + expr = self._make_builtin("STRAFTER", + arg1=Variable("x"), arg2=Literal("-")) + result = evaluate_expression(expr, {"x": lit("2024-03-15")}) + assert result.value == "03-15" + + def test_strafter_not_found(self): + from rdflib.term import Variable + from rdflib import Literal + expr = self._make_builtin("STRAFTER", + arg1=Variable("x"), arg2=Literal("/")) + result = evaluate_expression(expr, {"x": lit("hello")}) + assert result.value == "" + + def test_encode_for_uri(self): + from rdflib.term import Variable 
+ expr = self._make_builtin("ENCODE_FOR_URI", arg=Variable("x")) + result = evaluate_expression(expr, {"x": lit("hello world")}) + assert result.value == "hello%20world" + + def test_encode_for_uri_special_chars(self): + from rdflib.term import Variable + expr = self._make_builtin("ENCODE_FOR_URI", arg=Variable("x")) + result = evaluate_expression(expr, {"x": lit("a/b?c=d&e")}) + assert result.value == "a%2Fb%3Fc%3Dd%26e" + + def test_langmatches_basic(self): + from rdflib.term import Variable + from rdflib import Literal + expr = self._make_builtin("LANGMATCHES", + arg1=Literal("en"), arg2=Literal("en")) + assert evaluate_expression(expr, {}) is True + + def test_langmatches_subtag(self): + from rdflib.term import Variable + from rdflib import Literal + expr = self._make_builtin("LANGMATCHES", + arg1=Literal("en-US"), arg2=Literal("en")) + assert evaluate_expression(expr, {}) is True + + def test_langmatches_wildcard(self): + from rdflib.term import Variable + from rdflib import Literal + expr = self._make_builtin("LANGMATCHES", + arg1=Literal("fr"), arg2=Literal("*")) + assert evaluate_expression(expr, {}) is True + + def test_langmatches_wildcard_empty(self): + from rdflib.term import Variable + from rdflib import Literal + expr = self._make_builtin("LANGMATCHES", + arg1=Literal(""), arg2=Literal("*")) + assert evaluate_expression(expr, {}) is False + + def test_langmatches_no_match(self): + from rdflib.term import Variable + from rdflib import Literal + expr = self._make_builtin("LANGMATCHES", + arg1=Literal("fr"), arg2=Literal("en")) + assert evaluate_expression(expr, {}) is False + + def test_iri_constructor(self): + from rdflib.term import Variable + expr = self._make_builtin("IRI", arg=Variable("x")) + result = evaluate_expression( + expr, {"x": lit("http://example.com/test")} + ) + assert result.type == IRI + assert result.iri == "http://example.com/test" + + def test_uri_constructor(self): + from rdflib.term import Variable + expr = 
self._make_builtin("URI", arg=Variable("x")) + result = evaluate_expression( + expr, {"x": lit("http://example.com/test")} + ) + assert result.type == IRI + assert result.iri == "http://example.com/test" + + def test_bnode_no_arg(self): + expr = self._make_builtin("BNODE") + result = evaluate_expression(expr, {}) + assert result.type == BLANK + assert len(result.id) > 0 + + def test_bnode_with_label(self): + from rdflib import Literal + expr = self._make_builtin("BNODE", arg=Literal("mynode")) + result = evaluate_expression(expr, {}) + assert result.type == BLANK + assert result.id == "mynode" + + def test_now(self): + import re as re_mod + expr = self._make_builtin("NOW") + result = evaluate_expression(expr, {}) + assert result.type == LITERAL + assert result.datatype == XSD + "dateTime" + assert re_mod.match(r"\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}", result.value) + + def test_tz_with_utc(self): + from rdflib.term import Variable + expr = self._make_builtin("TZ", arg=Variable("x")) + result = evaluate_expression( + expr, {"x": lit("2024-03-15T10:30:45+0000", + datatype=XSD + "dateTime")} + ) + assert result.type == LITERAL + assert result.value == "+00:00" + + def test_tz_no_timezone(self): + from rdflib.term import Variable + expr = self._make_builtin("TZ", arg=Variable("x")) + result = evaluate_expression( + expr, {"x": lit("2024-03-15T10:30:45", + datatype=XSD + "dateTime")} + ) + assert result.value == "" + + def test_rand(self): + expr = self._make_builtin("RAND") + result = evaluate_expression(expr, {}) + assert isinstance(result, float) + assert 0.0 <= result < 1.0 + + def test_uuid(self): + import re as re_mod + expr = self._make_builtin("UUID") + result = evaluate_expression(expr, {}) + assert result.type == IRI + assert result.iri.startswith("urn:uuid:") + uuid_part = result.iri[len("urn:uuid:"):] + assert re_mod.match( + r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}", + uuid_part + ) + + def test_struuid(self): + import re as re_mod + 
expr = self._make_builtin("STRUUID") + result = evaluate_expression(expr, {}) + assert result.type == LITERAL + assert re_mod.match( + r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}", + result.value + ) + + def test_md5(self): + from rdflib.term import Variable + expr = self._make_builtin("MD5", arg=Variable("x")) + result = evaluate_expression(expr, {"x": lit("hello")}) + assert result.type == LITERAL + assert result.value == "5d41402abc4b2a76b9719d911017c592" + + def test_sha1(self): + from rdflib.term import Variable + expr = self._make_builtin("SHA1", arg=Variable("x")) + result = evaluate_expression(expr, {"x": lit("hello")}) + assert result.type == LITERAL + assert result.value == "aaf4c61ddcc5e8a2dabede0f3b482cd9aea9434d" + + def test_sha256(self): + from rdflib.term import Variable + expr = self._make_builtin("SHA256", arg=Variable("x")) + result = evaluate_expression(expr, {"x": lit("hello")}) + assert result.type == LITERAL + assert result.value == ( + "2cf24dba5fb0a30e26e83b2ac5b9e29e" + "1b161e5c1fa7425e73043362938b9824" + ) + + def test_sha512(self): + from rdflib.term import Variable + expr = self._make_builtin("SHA512", arg=Variable("x")) + result = evaluate_expression(expr, {"x": lit("hello")}) + assert result.type == LITERAL + assert len(result.value) == 128 + + def test_exists_with_callback(self): + from rdflib.plugins.sparql.parserutils import CompValue + graph = CompValue("BGP") + expr = self._make_builtin("EXISTS", graph=graph) + cb = lambda g, s: True + result = evaluate_expression(expr, {}, exists_cb=cb) + assert result is True + + def test_exists_callback_false(self): + from rdflib.plugins.sparql.parserutils import CompValue + graph = CompValue("BGP") + expr = self._make_builtin("EXISTS", graph=graph) + cb = lambda g, s: False + result = evaluate_expression(expr, {}, exists_cb=cb) + assert result is False + + def test_notexists_with_callback(self): + from rdflib.plugins.sparql.parserutils import CompValue + graph = 
CompValue("BGP") + expr = self._make_builtin("NOTEXISTS", graph=graph) + cb = lambda g, s: True + result = evaluate_expression(expr, {}, exists_cb=cb) + assert result is False + + def test_notexists_callback_false(self): + from rdflib.plugins.sparql.parserutils import CompValue + graph = CompValue("BGP") + expr = self._make_builtin("NOTEXISTS", graph=graph) + cb = lambda g, s: False + result = evaluate_expression(expr, {}, exists_cb=cb) + assert result is True + class TestEffectiveBoolean: diff --git a/tests/unit/test_query/test_sparql_solutions.py b/tests/unit/test_query/test_sparql_solutions.py index 5805ca84..7588a95b 100644 --- a/tests/unit/test_query/test_sparql_solutions.py +++ b/tests/unit/test_query/test_sparql_solutions.py @@ -5,7 +5,7 @@ Tests for SPARQL solution sequence operations. import pytest from trustgraph.schema import Term, IRI, LITERAL from trustgraph.query.sparql.solutions import ( - hash_join, left_join, union, project, distinct, + hash_join, left_join, minus, union, project, distinct, order_by, slice_solutions, _terms_equal, _compatible, ) @@ -311,6 +311,30 @@ class TestOrderBy: result = order_by(solutions, []) assert len(result) == 1 + def test_order_by_numeric_literals(self): + solutions = [ + {"year": lit("1950")}, + {"year": lit("700")}, + {"year": lit("2000")}, + {"year": lit("450")}, + {"year": lit("1200")}, + ] + key_fns = [(lambda sol: sol.get("year"), True)] + result = order_by(solutions, key_fns) + values = [s["year"].value for s in result] + assert values == ["450", "700", "1200", "1950", "2000"] + + def test_order_by_numeric_descending(self): + solutions = [ + {"year": lit("1950")}, + {"year": lit("700")}, + {"year": lit("2000")}, + ] + key_fns = [(lambda sol: sol.get("year"), False)] + result = order_by(solutions, key_fns) + values = [s["year"].value for s in result] + assert values == ["2000", "1950", "700"] + class TestSlice: @@ -343,3 +367,37 @@ class TestSlice: solutions = [{"s": alice}, {"s": bob}] result = 
slice_solutions(solutions) assert len(result) == 2 + + +class TestMinus: + + def test_removes_compatible(self, alice, bob): + left = [{"s": alice}, {"s": bob}] + right = [{"s": alice}] + result = minus(left, right) + assert len(result) == 1 + assert result[0]["s"].iri == "http://example.com/bob" + + def test_empty_right_preserves_all(self, alice, bob): + left = [{"s": alice}, {"s": bob}] + result = minus(left, []) + assert len(result) == 2 + + def test_no_shared_variables_preserves_all(self, alice, bob): + left = [{"s": alice}] + right = [{"t": bob}] + result = minus(left, right) + assert len(result) == 1 + + def test_all_removed(self, alice): + left = [{"s": alice}] + right = [{"s": alice}] + result = minus(left, right) + assert len(result) == 0 + + def test_partial_shared_variables(self, alice, bob): + left = [{"s": alice, "p": lit("x")}, {"s": bob, "p": lit("y")}] + right = [{"s": alice}] + result = minus(left, right) + assert len(result) == 1 + assert result[0]["s"].iri == "http://example.com/bob" diff --git a/tests/unit/test_rev_gateway/__init__.py b/tests/unit/test_rev_gateway/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/test_rev_gateway/test_dispatcher.py b/tests/unit/test_rev_gateway/test_dispatcher.py index 2a9c8df0..6df786cf 100644 --- a/tests/unit/test_rev_gateway/test_dispatcher.py +++ b/tests/unit/test_rev_gateway/test_dispatcher.py @@ -3,275 +3,279 @@ Tests for Reverse Gateway Dispatcher """ import pytest -from unittest.mock import MagicMock, AsyncMock, patch +import asyncio +from unittest.mock import MagicMock, AsyncMock, patch, ANY -from trustgraph.rev_gateway.dispatcher import WebSocketResponder, MessageDispatcher - - -class TestWebSocketResponder: - """Test cases for WebSocketResponder class""" - - def test_websocket_responder_initialization(self): - """Test WebSocketResponder initialization""" - responder = WebSocketResponder() - - assert responder.response is None - assert responder.completed is False - - 
@pytest.mark.asyncio - async def test_websocket_responder_send_method(self): - """Test WebSocketResponder send method""" - responder = WebSocketResponder() - - test_response = {"data": "test response"} - - # Call send method - await responder.send(test_response) - - # Verify response was stored - assert responder.response == test_response - - @pytest.mark.asyncio - async def test_websocket_responder_call_method(self): - """Test WebSocketResponder __call__ method""" - responder = WebSocketResponder() - - test_response = {"result": "success"} - test_completed = True - - # Call the responder - await responder(test_response, test_completed) - - # Verify response and completed status were set - assert responder.response == test_response - assert responder.completed == test_completed - - @pytest.mark.asyncio - async def test_websocket_responder_call_method_with_false_completion(self): - """Test WebSocketResponder __call__ method with incomplete response""" - responder = WebSocketResponder() - - test_response = {"partial": "data"} - test_completed = False - - # Call the responder - await responder(test_response, test_completed) - - # Verify response was set and completed is True (since send() always sets completed=True) - assert responder.response == test_response - assert responder.completed is True +from trustgraph.rev_gateway.dispatcher import MessageDispatcher class TestMessageDispatcher: """Test cases for MessageDispatcher class""" def test_message_dispatcher_initialization_with_defaults(self): - """Test MessageDispatcher initialization with default parameters""" dispatcher = MessageDispatcher() - + assert dispatcher.max_workers == 10 assert dispatcher.semaphore._value == 10 assert dispatcher.active_tasks == set() assert dispatcher.backend is None + assert dispatcher.auth is None assert dispatcher.dispatcher_manager is None assert len(dispatcher.service_mapping) > 0 def test_message_dispatcher_initialization_with_custom_workers(self): - """Test MessageDispatcher 
initialization with custom max_workers""" dispatcher = MessageDispatcher(max_workers=5) - + assert dispatcher.max_workers == 5 assert dispatcher.semaphore._value == 5 @patch('trustgraph.rev_gateway.dispatcher.DispatcherManager') - def test_message_dispatcher_initialization_with_pulsar_client(self, mock_dispatcher_manager): - """Test MessageDispatcher initialization with pulsar_client and config_receiver""" + def test_message_dispatcher_initialization_with_backend( + self, mock_dispatcher_manager, + ): mock_backend = MagicMock() mock_config_receiver = MagicMock() + mock_auth = MagicMock() mock_dispatcher_instance = MagicMock() mock_dispatcher_manager.return_value = mock_dispatcher_instance - + dispatcher = MessageDispatcher( max_workers=8, config_receiver=mock_config_receiver, - backend=mock_backend + backend=mock_backend, + auth=mock_auth, + timeout=300, ) - + assert dispatcher.max_workers == 8 assert dispatcher.backend == mock_backend + assert dispatcher.auth == mock_auth assert dispatcher.dispatcher_manager == mock_dispatcher_instance mock_dispatcher_manager.assert_called_once_with( - mock_backend, mock_config_receiver, prefix="rev-gateway" + mock_backend, mock_config_receiver, + auth=mock_auth, prefix="rev-gateway", timeout=300, ) def test_message_dispatcher_service_mapping(self): - """Test MessageDispatcher service mapping contains expected services""" dispatcher = MessageDispatcher() - + expected_services = [ "text-completion", "graph-rag", "agent", "embeddings", "graph-embeddings", "triples", "document-load", "text-load", - "flow", "knowledge", "config", "librarian", "document-rag" + "flow", "knowledge", "config", "librarian", "document-rag", ] - + for service in expected_services: assert service in dispatcher.service_mapping - - # Test specific mappings - assert dispatcher.service_mapping["text-completion"] == "text-completion" + assert dispatcher.service_mapping["document-load"] == "document" assert dispatcher.service_mapping["text-load"] == "text-document" 
@pytest.mark.asyncio - async def test_message_dispatcher_handle_message_without_dispatcher_manager(self): - """Test MessageDispatcher handle_message without dispatcher manager""" + async def test_handle_message_without_dispatcher_manager(self): dispatcher = MessageDispatcher() - - test_message = { - "id": "test-123", - "service": "test-service", - "request": {"data": "test"} - } - - result = await dispatcher.handle_message(test_message) - - assert result["id"] == "test-123" - assert "error" in result["response"] - assert "DispatcherManager not available" in result["response"]["error"] + dispatcher.auth = MagicMock() + dispatcher.auth.authenticate = AsyncMock( + return_value=MagicMock(workspace="default") + ) + + sender = AsyncMock() + + await dispatcher.handle_message( + {"id": "test-1", "service": "test", "request": {}}, + sender, + ) + + sender.assert_called_once() + sent = sender.call_args[0][0] + assert sent["id"] == "test-1" + assert sent["error"]["message"] == "DispatcherManager not available" + assert sent["error"]["type"] == "error" + assert sent["complete"] is True @pytest.mark.asyncio - async def test_message_dispatcher_handle_message_with_exception(self): - """Test MessageDispatcher handle_message with exception during processing""" - mock_dispatcher_manager = MagicMock() - mock_dispatcher_manager.invoke_global_service = AsyncMock(side_effect=Exception("Test error")) - + async def test_handle_message_auth_failure(self): dispatcher = MessageDispatcher() - dispatcher.dispatcher_manager = mock_dispatcher_manager - - test_message = { - "id": "test-456", - "service": "text-completion", - "request": {"prompt": "test"} - } - - with patch('trustgraph.gateway.dispatch.manager.global_dispatchers', {"text-completion": True}): - result = await dispatcher.handle_message(test_message) - - assert result["id"] == "test-456" - assert "error" in result["response"] - assert "Test error" in result["response"]["error"] + dispatcher.auth = MagicMock() + 
dispatcher.auth.authenticate = AsyncMock( + side_effect=Exception("auth failure") + ) + dispatcher.dispatcher_manager = MagicMock() + + sender = AsyncMock() + + await dispatcher.handle_message( + {"id": "test-2", "token": "bad", "service": "test", "request": {}}, + sender, + ) + + sender.assert_called_once() + sent = sender.call_args[0][0] + assert sent["id"] == "test-2" + assert "auth failure" in sent["error"]["message"] + assert sent["complete"] is True @pytest.mark.asyncio - async def test_message_dispatcher_handle_message_global_service(self): - """Test MessageDispatcher handle_message with global service""" - mock_dispatcher_manager = MagicMock() - mock_dispatcher_manager.invoke_global_service = AsyncMock() - mock_responder = MagicMock() - mock_responder.completed = True - mock_responder.response = {"result": "success"} - + async def test_handle_message_global_service(self): + mock_dm = MagicMock() + mock_dm.invoke_global_service = AsyncMock() + dispatcher = MessageDispatcher() - dispatcher.dispatcher_manager = mock_dispatcher_manager - - test_message = { - "id": "test-789", - "service": "text-completion", - "request": {"prompt": "hello"} - } - - with patch('trustgraph.gateway.dispatch.manager.global_dispatchers', {"text-completion": True}): - with patch('trustgraph.rev_gateway.dispatcher.WebSocketResponder', return_value=mock_responder): - result = await dispatcher.handle_message(test_message) - - assert result["id"] == "test-789" - assert result["response"] == {"result": "success"} - mock_dispatcher_manager.invoke_global_service.assert_called_once() + dispatcher.dispatcher_manager = mock_dm + dispatcher.auth = MagicMock() + dispatcher.auth.authenticate = AsyncMock( + return_value=MagicMock(workspace="ws1") + ) + + sender = AsyncMock() + + with patch( + 'trustgraph.gateway.dispatch.manager.global_dispatchers', + {"text-completion": True}, + ): + await dispatcher.handle_message( + { + "id": "test-3", + "token": "tg_key", + "service": "text-completion", + 
"request": {"prompt": "hello"}, + }, + sender, + ) + + mock_dm.invoke_global_service.assert_called_once() + args, kwargs = mock_dm.invoke_global_service.call_args + assert args[0] == {"prompt": "hello"} + assert args[2] == "text-completion" + assert kwargs["workspace"] == "ws1" @pytest.mark.asyncio - async def test_message_dispatcher_handle_message_flow_service(self): - """Test MessageDispatcher handle_message with flow service""" - mock_dispatcher_manager = MagicMock() - mock_dispatcher_manager.invoke_flow_service = AsyncMock() - mock_responder = MagicMock() - mock_responder.completed = True - mock_responder.response = {"data": "flow_result"} - + async def test_handle_message_flow_service(self): + mock_dm = MagicMock() + mock_dm.invoke_flow_service = AsyncMock() + dispatcher = MessageDispatcher() - dispatcher.dispatcher_manager = mock_dispatcher_manager - - test_message = { - "id": "test-flow-123", - "service": "document-rag", - "request": {"query": "test"}, - "flow": "custom-flow" - } - - with patch('trustgraph.gateway.dispatch.manager.global_dispatchers', {}): - with patch('trustgraph.rev_gateway.dispatcher.WebSocketResponder', return_value=mock_responder): - result = await dispatcher.handle_message(test_message) - - assert result["id"] == "test-flow-123" - assert result["response"] == {"data": "flow_result"} - mock_dispatcher_manager.invoke_flow_service.assert_called_once_with( - {"query": "test"}, mock_responder, "custom-flow", "document-rag" + dispatcher.dispatcher_manager = mock_dm + dispatcher.auth = MagicMock() + dispatcher.auth.authenticate = AsyncMock( + return_value=MagicMock(workspace="ws2") + ) + + sender = AsyncMock() + + with patch( + 'trustgraph.gateway.dispatch.manager.global_dispatchers', {}, + ): + await dispatcher.handle_message( + { + "id": "test-4", + "token": "tg_key", + "service": "document-rag", + "request": {"query": "test"}, + "flow": "my-flow", + }, + sender, + ) + + mock_dm.invoke_flow_service.assert_called_once_with( + {"query": 
"test"}, ANY, "ws2", "my-flow", "document-rag", ) @pytest.mark.asyncio - async def test_message_dispatcher_handle_message_incomplete_response(self): - """Test MessageDispatcher handle_message with incomplete response""" - mock_dispatcher_manager = MagicMock() - mock_dispatcher_manager.invoke_flow_service = AsyncMock() - mock_responder = MagicMock() - mock_responder.completed = False - mock_responder.response = None - + async def test_handle_message_responder_sends_frames(self): + mock_dm = MagicMock() + + async def fake_invoke(data, responder, svc, workspace=None): + await responder({"partial": 1}, False) + await responder({"partial": 2}, True) + + mock_dm.invoke_global_service = AsyncMock(side_effect=fake_invoke) + dispatcher = MessageDispatcher() - dispatcher.dispatcher_manager = mock_dispatcher_manager - - test_message = { - "id": "test-incomplete", - "service": "agent", - "request": {"input": "test"} + dispatcher.dispatcher_manager = mock_dm + dispatcher.auth = MagicMock() + dispatcher.auth.authenticate = AsyncMock( + return_value=MagicMock(workspace="ws1") + ) + + sender = AsyncMock() + + with patch( + 'trustgraph.gateway.dispatch.manager.global_dispatchers', + {"text-completion": True}, + ): + await dispatcher.handle_message( + { + "id": "test-5", + "token": "tg_key", + "service": "text-completion", + "request": {"prompt": "hi"}, + }, + sender, + ) + + assert sender.call_count == 2 + first = sender.call_args_list[0][0][0] + second = sender.call_args_list[1][0][0] + + assert first == { + "id": "test-5", "response": {"partial": 1}, "complete": False, + } + assert second == { + "id": "test-5", "response": {"partial": 2}, "complete": True, } - - with patch('trustgraph.gateway.dispatch.manager.global_dispatchers', {}): - with patch('trustgraph.rev_gateway.dispatcher.WebSocketResponder', return_value=mock_responder): - result = await dispatcher.handle_message(test_message) - - assert result["id"] == "test-incomplete" - assert result["response"] == {"error": "No 
response received"} @pytest.mark.asyncio - async def test_message_dispatcher_shutdown(self): - """Test MessageDispatcher shutdown method""" - import asyncio - + async def test_handle_message_workspace_from_identity(self): + mock_dm = MagicMock() + mock_dm.invoke_flow_service = AsyncMock() + dispatcher = MessageDispatcher() - - # Create actual async tasks + dispatcher.dispatcher_manager = mock_dm + dispatcher.auth = MagicMock() + dispatcher.auth.authenticate = AsyncMock( + return_value=MagicMock(workspace="derived-ws") + ) + + sender = AsyncMock() + + with patch( + 'trustgraph.gateway.dispatch.manager.global_dispatchers', {}, + ): + await dispatcher.handle_message( + { + "id": "test-6", + "token": "tg_key", + "service": "agent", + "request": {"question": "test"}, + "flow": "default", + }, + sender, + ) + + args = mock_dm.invoke_flow_service.call_args[0] + assert args[2] == "derived-ws" + + @pytest.mark.asyncio + async def test_shutdown(self): + dispatcher = MessageDispatcher() + async def dummy_task(): await asyncio.sleep(0.01) - return "done" - + task1 = asyncio.create_task(dummy_task()) task2 = asyncio.create_task(dummy_task()) dispatcher.active_tasks = {task1, task2} - - # Call shutdown + await dispatcher.shutdown() - - # Verify tasks were completed + assert task1.done() assert task2.done() - assert len(dispatcher.active_tasks) == 2 # Tasks remain in set but are completed @pytest.mark.asyncio - async def test_message_dispatcher_shutdown_with_no_tasks(self): - """Test MessageDispatcher shutdown with no active tasks""" + async def test_shutdown_with_no_tasks(self): dispatcher = MessageDispatcher() - - # Call shutdown with no active tasks + await dispatcher.shutdown() - - # Should complete without error - assert dispatcher.active_tasks == set() \ No newline at end of file + + assert dispatcher.active_tasks == set() diff --git a/tests/unit/test_rev_gateway/test_rev_gateway_service.py b/tests/unit/test_rev_gateway/test_rev_gateway_service.py index 23aff18e..e2d1045d 
100644 --- a/tests/unit/test_rev_gateway/test_rev_gateway_service.py +++ b/tests/unit/test_rev_gateway/test_rev_gateway_service.py @@ -8,22 +8,38 @@ from unittest.mock import MagicMock, AsyncMock, patch, Mock from aiohttp import WSMsgType, ClientWebSocketResponse import json -from trustgraph.rev_gateway.service import ReverseGateway, parse_args, run +from trustgraph.rev_gateway.service import ReverseGateway, run + + +MOCK_PATCHES = [ + 'trustgraph.rev_gateway.service.IamAuth', + 'trustgraph.rev_gateway.service.ConfigReceiver', + 'trustgraph.rev_gateway.service.MessageDispatcher', + 'trustgraph.rev_gateway.service.get_pubsub', +] + + +def make_gateway(**overrides): + config = {"websocket_uri": "ws://localhost:7650/out"} + config.update(overrides) + return ReverseGateway(**config) class TestReverseGateway: """Test cases for ReverseGateway class""" - @patch('trustgraph.rev_gateway.service.ConfigReceiver') - @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('trustgraph.rev_gateway.service.get_pubsub') - def test_reverse_gateway_initialization_defaults(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): - """Test ReverseGateway initialization with default parameters""" - mock_backend = MagicMock() - mock_get_pubsub.return_value = mock_backend - - gateway = ReverseGateway() - + @patch(*MOCK_PATCHES[0:1]) + @patch(*MOCK_PATCHES[1:2]) + @patch(*MOCK_PATCHES[2:3]) + @patch(*MOCK_PATCHES[3:4]) + def test_reverse_gateway_initialization_defaults( + self, mock_get_pubsub, mock_dispatcher, + mock_config_receiver, mock_iam_auth, + ): + mock_get_pubsub.return_value = MagicMock() + + gateway = make_gateway() + assert gateway.websocket_uri == "ws://localhost:7650/out" assert gateway.host == "localhost" assert gateway.port == 7650 @@ -33,25 +49,22 @@ class TestReverseGateway: assert gateway.max_workers == 10 assert gateway.running is False assert gateway.reconnect_delay == 3.0 - assert gateway.pulsar_host == "pulsar://pulsar:6650" - assert 
gateway.pulsar_api_key is None - @patch('trustgraph.rev_gateway.service.ConfigReceiver') - @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('trustgraph.rev_gateway.service.get_pubsub') - def test_reverse_gateway_initialization_custom_params(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): - """Test ReverseGateway initialization with custom parameters""" - mock_backend = MagicMock() - mock_get_pubsub.return_value = mock_backend - - gateway = ReverseGateway( + @patch(*MOCK_PATCHES[0:1]) + @patch(*MOCK_PATCHES[1:2]) + @patch(*MOCK_PATCHES[2:3]) + @patch(*MOCK_PATCHES[3:4]) + def test_reverse_gateway_initialization_custom_params( + self, mock_get_pubsub, mock_dispatcher, + mock_config_receiver, mock_iam_auth, + ): + mock_get_pubsub.return_value = MagicMock() + + gateway = make_gateway( websocket_uri="wss://example.com:8080/websocket", max_workers=20, - pulsar_host="pulsar://custom:6650", - pulsar_api_key="test-key", - pulsar_listener="test-listener" ) - + assert gateway.websocket_uri == "wss://example.com:8080/websocket" assert gateway.host == "example.com" assert gateway.port == 8080 @@ -59,340 +72,360 @@ class TestReverseGateway: assert gateway.path == "/websocket" assert gateway.url == "wss://example.com:8080/websocket" assert gateway.max_workers == 20 - assert gateway.pulsar_host == "pulsar://custom:6650" - assert gateway.pulsar_api_key == "test-key" - assert gateway.pulsar_listener == "test-listener" - @patch('trustgraph.rev_gateway.service.ConfigReceiver') - @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('trustgraph.rev_gateway.service.get_pubsub') - def test_reverse_gateway_initialization_with_missing_path(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): - """Test ReverseGateway initialization with WebSocket URI missing path""" - mock_backend = MagicMock() - mock_get_pubsub.return_value = mock_backend - - gateway = ReverseGateway(websocket_uri="ws://example.com") - + @patch(*MOCK_PATCHES[0:1]) + 
@patch(*MOCK_PATCHES[1:2]) + @patch(*MOCK_PATCHES[2:3]) + @patch(*MOCK_PATCHES[3:4]) + def test_reverse_gateway_initialization_with_missing_path( + self, mock_get_pubsub, mock_dispatcher, + mock_config_receiver, mock_iam_auth, + ): + mock_get_pubsub.return_value = MagicMock() + + gateway = make_gateway(websocket_uri="ws://example.com") + assert gateway.path == "/ws" assert gateway.url == "ws://example.com/ws" - @patch('trustgraph.rev_gateway.service.ConfigReceiver') - @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('trustgraph.rev_gateway.service.get_pubsub') - def test_reverse_gateway_initialization_invalid_scheme(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): - """Test ReverseGateway initialization with invalid WebSocket scheme""" + @patch(*MOCK_PATCHES[0:1]) + @patch(*MOCK_PATCHES[1:2]) + @patch(*MOCK_PATCHES[2:3]) + @patch(*MOCK_PATCHES[3:4]) + def test_reverse_gateway_initialization_invalid_scheme( + self, mock_get_pubsub, mock_dispatcher, + mock_config_receiver, mock_iam_auth, + ): with pytest.raises(ValueError, match="WebSocket URI must use ws:// or wss:// scheme"): - ReverseGateway(websocket_uri="http://example.com") + make_gateway(websocket_uri="http://example.com") - @patch('trustgraph.rev_gateway.service.ConfigReceiver') - @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('trustgraph.rev_gateway.service.get_pubsub') - def test_reverse_gateway_initialization_missing_hostname(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): - """Test ReverseGateway initialization with missing hostname""" + @patch(*MOCK_PATCHES[0:1]) + @patch(*MOCK_PATCHES[1:2]) + @patch(*MOCK_PATCHES[2:3]) + @patch(*MOCK_PATCHES[3:4]) + def test_reverse_gateway_initialization_missing_hostname( + self, mock_get_pubsub, mock_dispatcher, + mock_config_receiver, mock_iam_auth, + ): with pytest.raises(ValueError, match="WebSocket URI must include hostname"): - ReverseGateway(websocket_uri="ws://") + 
make_gateway(websocket_uri="ws://") - @patch('trustgraph.rev_gateway.service.ConfigReceiver') - @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('trustgraph.rev_gateway.service.get_pubsub') - def test_reverse_gateway_pulsar_client_with_auth(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): - """Test ReverseGateway creates backend with authentication""" + @patch(*MOCK_PATCHES[0:1]) + @patch(*MOCK_PATCHES[1:2]) + @patch(*MOCK_PATCHES[2:3]) + @patch(*MOCK_PATCHES[3:4]) + def test_reverse_gateway_iam_auth_created( + self, mock_get_pubsub, mock_dispatcher, + mock_config_receiver, mock_iam_auth, + ): mock_backend = MagicMock() mock_get_pubsub.return_value = mock_backend - gateway = ReverseGateway( - pulsar_api_key="test-key", - pulsar_listener="test-listener" + gateway = make_gateway(id="test-rev-gw") + + mock_iam_auth.assert_called_once_with( + backend=mock_backend, + id="test-rev-gw", ) - # Verify get_pubsub was called with the correct parameters - mock_get_pubsub.assert_called_once_with( - pulsar_host="pulsar://pulsar:6650", - pulsar_api_key="test-key", - pulsar_listener="test-listener" + @patch(*MOCK_PATCHES[0:1]) + @patch(*MOCK_PATCHES[1:2]) + @patch(*MOCK_PATCHES[2:3]) + @patch(*MOCK_PATCHES[3:4]) + def test_reverse_gateway_config_receiver_gets_auth( + self, mock_get_pubsub, mock_dispatcher, + mock_config_receiver, mock_iam_auth, + ): + mock_backend = MagicMock() + mock_get_pubsub.return_value = mock_backend + mock_auth_instance = MagicMock() + mock_iam_auth.return_value = mock_auth_instance + + gateway = make_gateway() + + mock_config_receiver.assert_called_once_with( + mock_backend, auth=mock_auth_instance, ) - @patch('trustgraph.rev_gateway.service.ConfigReceiver') - @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('trustgraph.rev_gateway.service.get_pubsub') + @patch(*MOCK_PATCHES[0:1]) + @patch(*MOCK_PATCHES[1:2]) + @patch(*MOCK_PATCHES[2:3]) + @patch(*MOCK_PATCHES[3:4]) 
@patch('trustgraph.rev_gateway.service.ClientSession') @pytest.mark.asyncio - async def test_reverse_gateway_connect_success(self, mock_session_class, mock_get_pubsub, mock_dispatcher, mock_config_receiver): - """Test ReverseGateway successful connection""" - mock_backend = MagicMock() - mock_get_pubsub.return_value = mock_backend - + async def test_reverse_gateway_connect_success( + self, mock_session_class, mock_get_pubsub, + mock_dispatcher, mock_config_receiver, mock_iam_auth, + ): + mock_get_pubsub.return_value = MagicMock() + mock_session = AsyncMock() mock_ws = AsyncMock() mock_session.ws_connect.return_value = mock_ws mock_session_class.return_value = mock_session - - gateway = ReverseGateway() - + + gateway = make_gateway() + result = await gateway.connect() - + assert result is True assert gateway.session == mock_session assert gateway.ws == mock_ws mock_session.ws_connect.assert_called_once_with(gateway.url) - @patch('trustgraph.rev_gateway.service.ConfigReceiver') - @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('trustgraph.rev_gateway.service.get_pubsub') + @patch(*MOCK_PATCHES[0:1]) + @patch(*MOCK_PATCHES[1:2]) + @patch(*MOCK_PATCHES[2:3]) + @patch(*MOCK_PATCHES[3:4]) @patch('trustgraph.rev_gateway.service.ClientSession') @pytest.mark.asyncio - async def test_reverse_gateway_connect_failure(self, mock_session_class, mock_get_pubsub, mock_dispatcher, mock_config_receiver): - """Test ReverseGateway connection failure""" - mock_backend = MagicMock() - mock_get_pubsub.return_value = mock_backend - + async def test_reverse_gateway_connect_failure( + self, mock_session_class, mock_get_pubsub, + mock_dispatcher, mock_config_receiver, mock_iam_auth, + ): + mock_get_pubsub.return_value = MagicMock() + mock_session = AsyncMock() mock_session.ws_connect.side_effect = Exception("Connection failed") mock_session_class.return_value = mock_session - - gateway = ReverseGateway() - + + gateway = make_gateway() + result = await gateway.connect() - 
+ assert result is False - @patch('trustgraph.rev_gateway.service.ConfigReceiver') - @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('trustgraph.rev_gateway.service.get_pubsub') + @patch(*MOCK_PATCHES[0:1]) + @patch(*MOCK_PATCHES[1:2]) + @patch(*MOCK_PATCHES[2:3]) + @patch(*MOCK_PATCHES[3:4]) @pytest.mark.asyncio - async def test_reverse_gateway_disconnect(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): - """Test ReverseGateway disconnect""" - mock_backend = MagicMock() - mock_get_pubsub.return_value = mock_backend - - gateway = ReverseGateway() - - # Mock websocket and session + async def test_reverse_gateway_disconnect( + self, mock_get_pubsub, mock_dispatcher, + mock_config_receiver, mock_iam_auth, + ): + mock_get_pubsub.return_value = MagicMock() + + gateway = make_gateway() + mock_ws = AsyncMock() mock_ws.closed = False mock_session = AsyncMock() mock_session.closed = False - + gateway.ws = mock_ws gateway.session = mock_session - + await gateway.disconnect() - + mock_ws.close.assert_called_once() mock_session.close.assert_called_once() assert gateway.ws is None assert gateway.session is None - @patch('trustgraph.rev_gateway.service.ConfigReceiver') - @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('trustgraph.rev_gateway.service.get_pubsub') + @patch(*MOCK_PATCHES[0:1]) + @patch(*MOCK_PATCHES[1:2]) + @patch(*MOCK_PATCHES[2:3]) + @patch(*MOCK_PATCHES[3:4]) @pytest.mark.asyncio - async def test_reverse_gateway_send_message(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): - """Test ReverseGateway send message""" - mock_backend = MagicMock() - mock_get_pubsub.return_value = mock_backend - - gateway = ReverseGateway() - - # Mock websocket + async def test_reverse_gateway_send_message( + self, mock_get_pubsub, mock_dispatcher, + mock_config_receiver, mock_iam_auth, + ): + mock_get_pubsub.return_value = MagicMock() + + gateway = make_gateway() + mock_ws = AsyncMock() mock_ws.closed = False 
gateway.ws = mock_ws - + test_message = {"id": "test", "data": "hello"} - + await gateway.send_message(test_message) - + mock_ws.send_str.assert_called_once_with(json.dumps(test_message)) - @patch('trustgraph.rev_gateway.service.ConfigReceiver') - @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('trustgraph.rev_gateway.service.get_pubsub') + @patch(*MOCK_PATCHES[0:1]) + @patch(*MOCK_PATCHES[1:2]) + @patch(*MOCK_PATCHES[2:3]) + @patch(*MOCK_PATCHES[3:4]) @pytest.mark.asyncio - async def test_reverse_gateway_send_message_closed_connection(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): - """Test ReverseGateway send message with closed connection""" - mock_backend = MagicMock() - mock_get_pubsub.return_value = mock_backend - - gateway = ReverseGateway() - - # Mock closed websocket + async def test_reverse_gateway_send_message_closed_connection( + self, mock_get_pubsub, mock_dispatcher, + mock_config_receiver, mock_iam_auth, + ): + mock_get_pubsub.return_value = MagicMock() + + gateway = make_gateway() + mock_ws = AsyncMock() mock_ws.closed = True gateway.ws = mock_ws - + test_message = {"id": "test", "data": "hello"} - + await gateway.send_message(test_message) - - # Should not call send_str on closed connection + mock_ws.send_str.assert_not_called() - @patch('trustgraph.rev_gateway.service.ConfigReceiver') - @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('trustgraph.rev_gateway.service.get_pubsub') + @patch(*MOCK_PATCHES[0:1]) + @patch(*MOCK_PATCHES[1:2]) + @patch(*MOCK_PATCHES[2:3]) + @patch(*MOCK_PATCHES[3:4]) @pytest.mark.asyncio - async def test_reverse_gateway_handle_message(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): - """Test ReverseGateway handle message""" - mock_backend = MagicMock() - mock_get_pubsub.return_value = mock_backend - - mock_dispatcher_instance = AsyncMock() - mock_dispatcher_instance.handle_message.return_value = {"response": "success"} - mock_dispatcher.return_value = 
mock_dispatcher_instance - - gateway = ReverseGateway() - - # Mock send_message - gateway.send_message = AsyncMock() - - test_message = '{"id": "test", "service": "test-service", "request": {"data": "test"}}' - - await gateway.handle_message(test_message) - - mock_dispatcher_instance.handle_message.assert_called_once_with({ - "id": "test", - "service": "test-service", - "request": {"data": "test"} - }) - gateway.send_message.assert_called_once_with({"response": "success"}) + async def test_reverse_gateway_handle_message( + self, mock_get_pubsub, mock_dispatcher, + mock_config_receiver, mock_iam_auth, + ): + mock_get_pubsub.return_value = MagicMock() + + mock_dispatcher_instance = AsyncMock() + mock_dispatcher.return_value = mock_dispatcher_instance + + gateway = make_gateway() - @patch('trustgraph.rev_gateway.service.ConfigReceiver') - @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('trustgraph.rev_gateway.service.get_pubsub') - @pytest.mark.asyncio - async def test_reverse_gateway_handle_message_invalid_json(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): - """Test ReverseGateway handle message with invalid JSON""" - mock_backend = MagicMock() - mock_get_pubsub.return_value = mock_backend - - gateway = ReverseGateway() - - # Mock send_message gateway.send_message = AsyncMock() - - test_message = 'invalid json' - - # Should not raise exception + + test_message = '{"id": "test", "service": "test-service", "request": {"data": "test"}}' + await gateway.handle_message(test_message) - - # Should not call send_message due to error + + mock_dispatcher_instance.handle_message.assert_called_once_with( + { + "id": "test", + "service": "test-service", + "request": {"data": "test"}, + }, + gateway.send_message, + ) + + @patch(*MOCK_PATCHES[0:1]) + @patch(*MOCK_PATCHES[1:2]) + @patch(*MOCK_PATCHES[2:3]) + @patch(*MOCK_PATCHES[3:4]) + @pytest.mark.asyncio + async def test_reverse_gateway_handle_message_invalid_json( + self, mock_get_pubsub, 
mock_dispatcher, + mock_config_receiver, mock_iam_auth, + ): + mock_get_pubsub.return_value = MagicMock() + + gateway = make_gateway() + + gateway.send_message = AsyncMock() + + await gateway.handle_message('invalid json') + gateway.send_message.assert_not_called() - @patch('trustgraph.rev_gateway.service.ConfigReceiver') - @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('trustgraph.rev_gateway.service.get_pubsub') + @patch(*MOCK_PATCHES[0:1]) + @patch(*MOCK_PATCHES[1:2]) + @patch(*MOCK_PATCHES[2:3]) + @patch(*MOCK_PATCHES[3:4]) @pytest.mark.asyncio - async def test_reverse_gateway_listen_text_message(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): - """Test ReverseGateway listen with text message""" - mock_backend = MagicMock() - mock_get_pubsub.return_value = mock_backend - - gateway = ReverseGateway() + async def test_reverse_gateway_listen_text_message( + self, mock_get_pubsub, mock_dispatcher, + mock_config_receiver, mock_iam_auth, + ): + mock_get_pubsub.return_value = MagicMock() + + gateway = make_gateway() gateway.running = True - - # Mock websocket + mock_ws = AsyncMock() mock_ws.closed = False gateway.ws = mock_ws - - # Mock handle_message + gateway.handle_message = AsyncMock() - - # Mock message + mock_msg = MagicMock() mock_msg.type = WSMsgType.TEXT mock_msg.data = '{"test": "message"}' - - # Mock receive to return message once, then raise exception to stop loop + mock_ws.receive.side_effect = [mock_msg, Exception("Test stop")] - - # listen() catches exceptions and breaks, so no exception should be raised + await gateway.listen() - + gateway.handle_message.assert_called_once_with('{"test": "message"}') - @patch('trustgraph.rev_gateway.service.ConfigReceiver') - @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('trustgraph.rev_gateway.service.get_pubsub') + @patch(*MOCK_PATCHES[0:1]) + @patch(*MOCK_PATCHES[1:2]) + @patch(*MOCK_PATCHES[2:3]) + @patch(*MOCK_PATCHES[3:4]) @pytest.mark.asyncio - async def 
test_reverse_gateway_listen_binary_message(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): - """Test ReverseGateway listen with binary message""" - mock_backend = MagicMock() - mock_get_pubsub.return_value = mock_backend - - gateway = ReverseGateway() + async def test_reverse_gateway_listen_binary_message( + self, mock_get_pubsub, mock_dispatcher, + mock_config_receiver, mock_iam_auth, + ): + mock_get_pubsub.return_value = MagicMock() + + gateway = make_gateway() gateway.running = True - - # Mock websocket + mock_ws = AsyncMock() mock_ws.closed = False gateway.ws = mock_ws - - # Mock handle_message + gateway.handle_message = AsyncMock() - - # Mock message + mock_msg = MagicMock() mock_msg.type = WSMsgType.BINARY mock_msg.data = b'{"test": "binary"}' - - # Mock receive to return message once, then raise exception to stop loop + mock_ws.receive.side_effect = [mock_msg, Exception("Test stop")] - - # listen() catches exceptions and breaks, so no exception should be raised + await gateway.listen() - + gateway.handle_message.assert_called_once_with('{"test": "binary"}') - @patch('trustgraph.rev_gateway.service.ConfigReceiver') - @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('trustgraph.rev_gateway.service.get_pubsub') + @patch(*MOCK_PATCHES[0:1]) + @patch(*MOCK_PATCHES[1:2]) + @patch(*MOCK_PATCHES[2:3]) + @patch(*MOCK_PATCHES[3:4]) @pytest.mark.asyncio - async def test_reverse_gateway_listen_close_message(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): - """Test ReverseGateway listen with close message""" - mock_backend = MagicMock() - mock_get_pubsub.return_value = mock_backend - - gateway = ReverseGateway() + async def test_reverse_gateway_listen_close_message( + self, mock_get_pubsub, mock_dispatcher, + mock_config_receiver, mock_iam_auth, + ): + mock_get_pubsub.return_value = MagicMock() + + gateway = make_gateway() gateway.running = True - - # Mock websocket + mock_ws = AsyncMock() mock_ws.closed = False gateway.ws 
= mock_ws - - # Mock handle_message + gateway.handle_message = AsyncMock() - - # Mock message + mock_msg = MagicMock() mock_msg.type = WSMsgType.CLOSE - - # Mock receive to return close message + mock_ws.receive.return_value = mock_msg - + await gateway.listen() - - # Should not call handle_message for close message + gateway.handle_message.assert_not_called() - @patch('trustgraph.rev_gateway.service.ConfigReceiver') - @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('trustgraph.rev_gateway.service.get_pubsub') + @patch(*MOCK_PATCHES[0:1]) + @patch(*MOCK_PATCHES[1:2]) + @patch(*MOCK_PATCHES[2:3]) + @patch(*MOCK_PATCHES[3:4]) @pytest.mark.asyncio - async def test_reverse_gateway_shutdown(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): - """Test ReverseGateway shutdown""" + async def test_reverse_gateway_shutdown( + self, mock_get_pubsub, mock_dispatcher, + mock_config_receiver, mock_iam_auth, + ): mock_backend = MagicMock() mock_get_pubsub.return_value = mock_backend mock_dispatcher_instance = AsyncMock() mock_dispatcher.return_value = mock_dispatcher_instance - gateway = ReverseGateway() + gateway = make_gateway() gateway.running = True - # Mock disconnect gateway.disconnect = AsyncMock() await gateway.shutdown() @@ -402,46 +435,50 @@ class TestReverseGateway: gateway.disconnect.assert_called_once() mock_backend.close.assert_called_once() - @patch('trustgraph.rev_gateway.service.ConfigReceiver') - @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('trustgraph.rev_gateway.service.get_pubsub') - def test_reverse_gateway_stop(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): - """Test ReverseGateway stop""" - mock_backend = MagicMock() - mock_get_pubsub.return_value = mock_backend - - gateway = ReverseGateway() + @patch(*MOCK_PATCHES[0:1]) + @patch(*MOCK_PATCHES[1:2]) + @patch(*MOCK_PATCHES[2:3]) + @patch(*MOCK_PATCHES[3:4]) + def test_reverse_gateway_stop( + self, mock_get_pubsub, mock_dispatcher, + 
mock_config_receiver, mock_iam_auth, + ): + mock_get_pubsub.return_value = MagicMock() + + gateway = make_gateway() gateway.running = True - + gateway.stop() - + assert gateway.running is False class TestReverseGatewayRun: """Test cases for ReverseGateway run method""" - @patch('trustgraph.rev_gateway.service.ConfigReceiver') - @patch('trustgraph.rev_gateway.service.MessageDispatcher') - @patch('trustgraph.rev_gateway.service.get_pubsub') + @patch(*MOCK_PATCHES[0:1]) + @patch(*MOCK_PATCHES[1:2]) + @patch(*MOCK_PATCHES[2:3]) + @patch(*MOCK_PATCHES[3:4]) @pytest.mark.asyncio - async def test_reverse_gateway_run_successful_cycle(self, mock_get_pubsub, mock_dispatcher, mock_config_receiver): - """Test ReverseGateway run method with successful connect/listen cycle""" - mock_backend = MagicMock() - mock_get_pubsub.return_value = mock_backend - + async def test_reverse_gateway_run_successful_cycle( + self, mock_get_pubsub, mock_dispatcher, + mock_config_receiver, mock_iam_auth, + ): + mock_get_pubsub.return_value = MagicMock() + + mock_auth_instance = AsyncMock() + mock_iam_auth.return_value = mock_auth_instance + mock_config_receiver_instance = AsyncMock() mock_config_receiver.return_value = mock_config_receiver_instance - - gateway = ReverseGateway() - - # Mock methods - gateway.connect = AsyncMock(return_value=True) + + gateway = make_gateway() + gateway.listen = AsyncMock() gateway.disconnect = AsyncMock() gateway.shutdown = AsyncMock() - - # Stop after one iteration + call_count = 0 async def mock_connect(): nonlocal call_count @@ -451,91 +488,13 @@ class TestReverseGatewayRun: else: gateway.running = False return False - + gateway.connect = mock_connect - + await gateway.run() - + + mock_auth_instance.start.assert_called_once() mock_config_receiver_instance.start.assert_called_once() gateway.listen.assert_called_once() - # disconnect is called twice: once in the main loop, once in shutdown assert gateway.disconnect.call_count == 2 
gateway.shutdown.assert_called_once() - - -class TestReverseGatewayArgs: - """Test cases for argument parsing and run function""" - - def test_parse_args_defaults(self): - """Test parse_args with default values""" - import sys - - # Mock sys.argv - original_argv = sys.argv - sys.argv = ['reverse-gateway'] - - try: - args = parse_args() - - assert args.websocket_uri is None - assert args.max_workers == 10 - assert args.pulsar_host is None - assert args.pulsar_api_key is None - assert args.pulsar_listener is None - finally: - sys.argv = original_argv - - def test_parse_args_custom_values(self): - """Test parse_args with custom values""" - import sys - - # Mock sys.argv - original_argv = sys.argv - sys.argv = [ - 'reverse-gateway', - '--websocket-uri', 'ws://custom:8080/ws', - '--max-workers', '20', - '--pulsar-host', 'pulsar://custom:6650', - '--pulsar-api-key', 'test-key', - '--pulsar-listener', 'test-listener' - ] - - try: - args = parse_args() - - assert args.websocket_uri == 'ws://custom:8080/ws' - assert args.max_workers == 20 - assert args.pulsar_host == 'pulsar://custom:6650' - assert args.pulsar_api_key == 'test-key' - assert args.pulsar_listener == 'test-listener' - finally: - sys.argv = original_argv - - @patch('trustgraph.rev_gateway.service.ReverseGateway') - @patch('asyncio.run') - def test_run_function(self, mock_asyncio_run, mock_gateway_class): - """Test run function""" - import sys - - # Mock sys.argv - original_argv = sys.argv - sys.argv = ['reverse-gateway', '--max-workers', '15'] - - try: - mock_gateway_instance = MagicMock() - mock_gateway_instance.url = "ws://localhost:7650/out" - mock_gateway_instance.pulsar_host = "pulsar://pulsar:6650" - mock_gateway_class.return_value = mock_gateway_instance - - run() - - mock_gateway_class.assert_called_once_with( - websocket_uri=None, - max_workers=15, - pulsar_host=None, - pulsar_api_key=None, - pulsar_listener=None - ) - mock_asyncio_run.assert_called_once_with(mock_gateway_instance.run()) - finally: - 
sys.argv = original_argv \ No newline at end of file diff --git a/trustgraph-base/trustgraph/base/triples_client.py b/trustgraph-base/trustgraph/base/triples_client.py index 2601a1e1..0506cb9f 100644 --- a/trustgraph-base/trustgraph/base/triples_client.py +++ b/trustgraph-base/trustgraph/base/triples_client.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio from typing import Any from . request_response_spec import RequestResponse, RequestResponseSpec @@ -44,6 +45,60 @@ def from_value(x: Any) -> Any: return Term(type=LITERAL, value=str(x)) class TriplesClient(RequestResponse): + + async def query_gen(self, s=None, p=None, o=None, limit=20, + collection="default", + batch_size=20, timeout=30, g=None): + """Async generator yielding Triple objects as batches arrive.""" + queue = asyncio.Queue() + done = False + + async def recipient(resp): + if resp.error: + raise RuntimeError(resp.error.message) + + batch = [ + Triple(to_value(v.s), to_value(v.p), to_value(v.o)) + for v in resp.triples + ] + await queue.put(batch) + + if resp.is_final: + await queue.put(None) + + return resp.is_final + + # Launch the streaming request as a background task + task = asyncio.ensure_future(self.request( + TriplesQueryRequest( + s=from_value(s), + p=from_value(p), + o=from_value(o), + limit=limit, + collection=collection, + streaming=True, + batch_size=batch_size, + g=g, + ), + timeout=timeout, + recipient=recipient, + )) + + try: + while True: + batch = await queue.get() + if batch is None: + break + for triple in batch: + yield triple + finally: + if not task.done(): + task.cancel() + try: + await task + except (asyncio.CancelledError, Exception): + pass + async def query(self, s=None, p=None, o=None, limit=20, collection="default", timeout=30, g=None): diff --git a/trustgraph-flow/trustgraph/query/ontology/__init__.py b/trustgraph-flow/trustgraph/query/ontology/__init__.py index 60557ea9..c5cddd9c 100644 --- a/trustgraph-flow/trustgraph/query/ontology/__init__.py +++ 
b/trustgraph-flow/trustgraph/query/ontology/__init__.py @@ -7,7 +7,7 @@ Provides semantic query understanding, ontology matching, and answer generation. from .query_service import OntoRAGQueryService, QueryRequest, QueryResponse from .question_analyzer import QuestionAnalyzer, QuestionComponents, QuestionType -from .ontology_matcher import OntologyMatcher, QueryOntologySubset +from .ontology_matcher import OntologyMatcherForQueries, QueryOntologySubset from .backend_router import BackendRouter, BackendType, QueryRoute from .sparql_generator import SPARQLGenerator, SPARQLQuery from .sparql_cassandra import SPARQLCassandraEngine, SPARQLResult @@ -27,7 +27,7 @@ __all__ = [ 'QuestionType', # Ontology matching - 'OntologyMatcher', + 'OntologyMatcherForQueries', 'QueryOntologySubset', # Backend routing diff --git a/trustgraph-flow/trustgraph/query/ontology/monitoring.py b/trustgraph-flow/trustgraph/query/ontology/monitoring.py index 703c6e95..cb7e8a2e 100644 --- a/trustgraph-flow/trustgraph/query/ontology/monitoring.py +++ b/trustgraph-flow/trustgraph/query/ontology/monitoring.py @@ -4,6 +4,7 @@ Provides comprehensive monitoring of system performance, query patterns, and res """ import logging +import re import time import asyncio import inspect @@ -276,6 +277,26 @@ class MetricsCollector: return f"{name}{{{label_str}}}" +def _extract_metric_label(metric_name: str, label: str) -> Optional[str]: + """Extract a label value from an internal metric key.""" + labels_start = metric_name.find('{') + labels_end = metric_name.find('}', labels_start + 1) + + if labels_start == -1 or labels_end == -1: + return None + + labels = metric_name[labels_start + 1:labels_end] + label_match = re.search( + rf'(?:^|,){re.escape(label)}=(?:"([^"]*)"|([^,]*))', + labels, + ) + if not label_match: + return None + + quoted_value, unquoted_value = label_match.groups() + return quoted_value if quoted_value is not None else unquoted_value + + class PerformanceMonitor: """Monitors system performance 
and component health.""" @@ -474,8 +495,8 @@ class PerformanceMonitor: # Cache performance cache_types = set() for metric_name in self.metrics_collector.counters.keys(): - if 'cache_type=' in metric_name: - cache_type = metric_name.split('cache_type=')[1].split(',')[0].split('}')[0] + cache_type = _extract_metric_label(metric_name, 'cache_type') + if cache_type is not None: cache_types.add(cache_type) for cache_type in cache_types: diff --git a/trustgraph-flow/trustgraph/query/ontology/ontology_matcher.py b/trustgraph-flow/trustgraph/query/ontology/ontology_matcher.py index 895856f3..2dd6633a 100644 --- a/trustgraph-flow/trustgraph/query/ontology/ontology_matcher.py +++ b/trustgraph-flow/trustgraph/query/ontology/ontology_matcher.py @@ -7,10 +7,10 @@ import logging from typing import List, Dict, Any, Set, Optional from dataclasses import dataclass -from ...extract.kg.ontology.ontology_loader import Ontology, OntologyLoader -from ...extract.kg.ontology.ontology_embedder import OntologyEmbedder -from ...extract.kg.ontology.text_processor import TextSegment -from ...extract.kg.ontology.ontology_selector import OntologySelector, OntologySubset +from trustgraph.extract.kg.ontology.ontology_loader import Ontology, OntologyLoader +from trustgraph.extract.kg.ontology.ontology_embedder import OntologyEmbedder +from trustgraph.extract.kg.ontology.text_processor import TextSegment +from trustgraph.extract.kg.ontology.ontology_selector import OntologySelector, OntologySubset from .question_analyzer import QuestionComponents, QuestionType logger = logging.getLogger(__name__) diff --git a/trustgraph-flow/trustgraph/query/ontology/query_service.py b/trustgraph-flow/trustgraph/query/ontology/query_service.py index c6057cc1..77e60b50 100644 --- a/trustgraph-flow/trustgraph/query/ontology/query_service.py +++ b/trustgraph-flow/trustgraph/query/ontology/query_service.py @@ -8,13 +8,13 @@ from typing import Dict, Any, List, Optional, Union from dataclasses import dataclass from 
datetime import datetime -from ....flow.flow_processor import FlowProcessor -from ....tables.config import ConfigTableStore -from ...extract.kg.ontology.ontology_loader import OntologyLoader -from ...extract.kg.ontology.vector_store import InMemoryVectorStore +from trustgraph.base.flow_processor import FlowProcessor +from trustgraph.tables.config import ConfigTableStore +from trustgraph.extract.kg.ontology.ontology_loader import OntologyLoader +from trustgraph.extract.kg.ontology.vector_store import InMemoryVectorStore from .question_analyzer import QuestionAnalyzer, QuestionComponents -from .ontology_matcher import OntologyMatcher, QueryOntologySubset +from .ontology_matcher import OntologyMatcherForQueries, QueryOntologySubset from .backend_router import BackendRouter, QueryRoute, BackendType from .sparql_generator import SPARQLGenerator, SPARQLQuery from .sparql_cassandra import SPARQLCassandraEngine, SPARQLResult @@ -105,7 +105,7 @@ class OntoRAGQueryService(FlowProcessor): # Initialize ontology matcher matcher_config = self.config.get('ontology_matcher', {}) - self.ontology_matcher = OntologyMatcher( + self.ontology_matcher = OntologyMatcherForQueries( vector_store=self.vector_store, embedding_service=self.embedding_service, config=matcher_config diff --git a/trustgraph-flow/trustgraph/query/ontology/sparql_cassandra.py b/trustgraph-flow/trustgraph/query/ontology/sparql_cassandra.py index 688e7371..b7f0f423 100644 --- a/trustgraph-flow/trustgraph/query/ontology/sparql_cassandra.py +++ b/trustgraph-flow/trustgraph/query/ontology/sparql_cassandra.py @@ -28,7 +28,7 @@ try: except ImportError: CASSANDRA_AVAILABLE = False -from ....tables.config import ConfigTableStore +from trustgraph.tables.config import ConfigTableStore logger = logging.getLogger(__name__) diff --git a/trustgraph-flow/trustgraph/query/sparql/algebra.py b/trustgraph-flow/trustgraph/query/sparql/algebra.py index 76b1ad8e..c7542577 100644 --- a/trustgraph-flow/trustgraph/query/sparql/algebra.py +++ 
b/trustgraph-flow/trustgraph/query/sparql/algebra.py @@ -4,6 +4,10 @@ SPARQL algebra evaluator. Recursively evaluates an rdflib SPARQL algebra tree by issuing triple pattern queries via TriplesClient (streaming) and performing in-memory joins, filters, and projections. + +Handlers are async generators that yield solutions incrementally. +Blocking operators (joins, sort, group, distinct) materialise their +upstream into a list at the boundary, then yield results. """ import logging @@ -17,7 +21,7 @@ from ... knowledge import Uri from ... knowledge import Literal as KgLiteral from . parser import rdflib_term_to_term from . solutions import ( - hash_join, left_join, union, project, distinct, + hash_join, left_join, minus, union, project, distinct, order_by, slice_solutions, _term_key, ) from . expressions import evaluate_expression, _effective_boolean @@ -34,56 +38,56 @@ async def evaluate(node, triples_client, collection, limit=10000): """ Evaluate a SPARQL algebra node. - Args: - node: rdflib CompValue algebra node - triples_client: TriplesClient instance for triple pattern queries - collection: collection identifier - limit: safety limit on results - - Returns: - list of solutions (dicts mapping variable names to Term values) + Yields solutions (dicts mapping variable names to Term values) + incrementally as an async generator. 
""" if not isinstance(node, CompValue): logger.warning(f"Expected CompValue, got {type(node)}: {node}") - return [{}] + yield {} + return name = node.name handler = _HANDLERS.get(name) if handler is None: logger.warning(f"Unsupported algebra node: {name}") - return [{}] + yield {} + return - return await handler(node, triples_client, collection, limit) + async for sol in handler(node, triples_client, collection, limit): + yield sol -# --- Node handlers --- +async def materialise(node, triples_client, collection, limit=10000): + """Collect all solutions from evaluate() into a list.""" + return [sol async for sol in evaluate(node, triples_client, collection, limit)] + + +# --- Node handlers (async generators) --- async def _eval_select_query(node, tc, collection, limit): - """Evaluate a SelectQuery node.""" - return await evaluate(node.p, tc, collection, limit) + async for sol in evaluate(node.p, tc, collection, limit): + yield sol async def _eval_project(node, tc, collection, limit): - """Evaluate a Project node (SELECT variable projection).""" - solutions = await evaluate(node.p, tc, collection, limit) variables = [str(v) for v in node.PV] - return project(solutions, variables) + async for sol in evaluate(node.p, tc, collection, limit): + yield {v: sol[v] for v in variables if v in sol} async def _eval_bgp(node, tc, collection, limit): """ Evaluate a Basic Graph Pattern. - Issues streaming triple pattern queries and joins results. Patterns - are ordered by selectivity (more bound terms first) and evaluated - sequentially with bound-variable substitution. + Patterns are ordered by selectivity and evaluated sequentially. + For the final pattern, results stream directly from the triple store. 
""" triples = node.triples if not triples: - return [{}] + yield {} + return - # Sort patterns by selectivity: more bound terms = more selective def selectivity(pattern): return sum(1 for t in pattern if not isinstance(t, Variable)) @@ -91,55 +95,222 @@ async def _eval_bgp(node, tc, collection, limit): enumerate(triples), key=lambda x: -selectivity(x[1]) ) + # For all patterns except the last, we must materialise intermediate + # solutions because each pattern depends on bindings from prior ones. + # The last pattern streams directly. solutions = [{}] - for _, pattern in sorted_patterns: + for pattern_idx, (_, pattern) in enumerate(sorted_patterns): s_tmpl, p_tmpl, o_tmpl = pattern + is_last = (pattern_idx == len(sorted_patterns) - 1) - new_solutions = [] + if is_last: + # Stream the final pattern — yield as triples arrive + count = 0 + for sol in solutions: + s_val = _resolve_term(s_tmpl, sol) + p_val = _resolve_term(p_tmpl, sol) + o_val = _resolve_term(o_tmpl, sol) - for sol in solutions: - # Substitute known bindings into the pattern - s_val = _resolve_term(s_tmpl, sol) - p_val = _resolve_term(p_tmpl, sol) - o_val = _resolve_term(o_tmpl, sol) + async for triple in tc.query_gen( + s=s_val, p=p_val, o=o_val, + limit=limit, collection=collection, + ): + binding = dict(sol) + if isinstance(s_tmpl, Variable): + binding[str(s_tmpl)] = _to_term(triple.s) + if isinstance(p_tmpl, Variable): + binding[str(p_tmpl)] = _to_term(triple.p) + if isinstance(o_tmpl, Variable): + binding[str(o_tmpl)] = _to_term(triple.o) + yield binding + count += 1 + if count >= limit: + return + else: + # Materialise intermediate patterns + new_solutions = [] + for sol in solutions: + s_val = _resolve_term(s_tmpl, sol) + p_val = _resolve_term(p_tmpl, sol) + o_val = _resolve_term(o_tmpl, sol) - # Query the triples store - results = await _query_pattern( - tc, s_val, p_val, o_val, collection, limit - ) + async for triple in tc.query_gen( + s=s_val, p=p_val, o=o_val, + limit=limit, 
collection=collection, + ): + binding = dict(sol) + if isinstance(s_tmpl, Variable): + binding[str(s_tmpl)] = _to_term(triple.s) + if isinstance(p_tmpl, Variable): + binding[str(p_tmpl)] = _to_term(triple.p) + if isinstance(o_tmpl, Variable): + binding[str(o_tmpl)] = _to_term(triple.o) + new_solutions.append(binding) - # Map results back to variable bindings, - # converting Uri/Literal to Term objects - for triple in results: - binding = dict(sol) - if isinstance(s_tmpl, Variable): - binding[str(s_tmpl)] = _to_term(triple.s) - if isinstance(p_tmpl, Variable): - binding[str(p_tmpl)] = _to_term(triple.p) - if isinstance(o_tmpl, Variable): - binding[str(o_tmpl)] = _to_term(triple.o) - new_solutions.append(binding) + solutions = new_solutions + if not solutions: + return - solutions = new_solutions - if not solutions: - break +# --- Blocking operators: materialise upstream, then yield --- - return solutions[:limit] +def _is_small_node(node): + """Check if a node is likely to produce a small number of solutions.""" + if not isinstance(node, CompValue): + return False + if node.name in ("values", "ToMultiSet"): + return True + if node.name == "Extend" and hasattr(node, "p"): + return _is_small_node(node.p) + return False async def _eval_join(node, tc, collection, limit): - """Evaluate a Join node.""" - left = await evaluate(node.p1, tc, collection, limit) - right = await evaluate(node.p2, tc, collection, limit) - return hash_join(left, right)[:limit] + # Bind join: if one side is small (e.g. VALUES), materialise it and + # substitute its bindings into the other side's evaluation. This + # turns wildcard BGP queries into selective ones. 
+ if _is_small_node(node.p1): + yield_from = _bind_join(node.p1, node.p2, tc, collection, limit) + elif _is_small_node(node.p2): + yield_from = _bind_join(node.p2, node.p1, tc, collection, limit) + else: + yield_from = _hash_join(node, tc, collection, limit) + + async for sol in yield_from: + yield sol + + +async def _hash_join(node, tc, collection, limit): + left = await materialise(node.p1, tc, collection, limit) + right = await materialise(node.p2, tc, collection, limit) + for sol in hash_join(left, right)[:limit]: + yield sol + + +async def _bind_join(small_node, big_node, tc, collection, limit): + """Iterate over the small side and inject bindings into the big side.""" + small_sols = await materialise(small_node, tc, collection, limit) + + count = 0 + for binding in small_sols: + async for sol in _evaluate_with_bindings( + big_node, binding, tc, collection, limit + ): + yield sol + count += 1 + if count >= limit: + return + + +def _merge_compatible(left, right): + """Merge two solutions if compatible (shared vars have equal values).""" + merged = dict(left) + for k, v in right.items(): + if k in merged: + if _term_key(merged[k]) != _term_key(v): + return None + else: + merged[k] = v + return merged + + +async def _evaluate_with_bindings(node, bindings, tc, collection, limit): + """Evaluate a node with pre-seeded variable bindings. + + For BGP nodes, the bindings are injected so _resolve_term sees them, + turning wildcard queries into selective ones. For other node types, + evaluate normally and merge/filter against the bindings. 
+ """ + if isinstance(node, CompValue) and node.name == "BGP": + async for sol in _eval_bgp_with_bindings( + node, bindings, tc, collection, limit + ): + yield sol + else: + async for sol in evaluate(node, tc, collection, limit): + merged = _merge_compatible(bindings, sol) + if merged is not None: + yield merged + + +async def _eval_bgp_with_bindings(node, bindings, tc, collection, limit): + """Evaluate a BGP with pre-seeded bindings so variables resolve to terms.""" + triples = node.triples + if not triples: + yield dict(bindings) + return + + def selectivity(pattern): + score = 0 + for t in pattern: + if not isinstance(t, Variable): + score += 1 + elif str(t) in bindings: + score += 1 + return score + + sorted_patterns = sorted( + enumerate(triples), key=lambda x: -selectivity(x[1]) + ) + + solutions = [dict(bindings)] + + for pattern_idx, (_, pattern) in enumerate(sorted_patterns): + s_tmpl, p_tmpl, o_tmpl = pattern + is_last = (pattern_idx == len(sorted_patterns) - 1) + + if is_last: + count = 0 + for sol in solutions: + s_val = _resolve_term(s_tmpl, sol) + p_val = _resolve_term(p_tmpl, sol) + o_val = _resolve_term(o_tmpl, sol) + + async for triple in tc.query_gen( + s=s_val, p=p_val, o=o_val, + limit=limit, collection=collection, + ): + binding = dict(sol) + if isinstance(s_tmpl, Variable): + binding[str(s_tmpl)] = _to_term(triple.s) + if isinstance(p_tmpl, Variable): + binding[str(p_tmpl)] = _to_term(triple.p) + if isinstance(o_tmpl, Variable): + binding[str(o_tmpl)] = _to_term(triple.o) + yield binding + count += 1 + if count >= limit: + return + else: + new_solutions = [] + for sol in solutions: + s_val = _resolve_term(s_tmpl, sol) + p_val = _resolve_term(p_tmpl, sol) + o_val = _resolve_term(o_tmpl, sol) + + async for triple in tc.query_gen( + s=s_val, p=p_val, o=o_val, + limit=limit, collection=collection, + ): + binding = dict(sol) + if isinstance(s_tmpl, Variable): + binding[str(s_tmpl)] = _to_term(triple.s) + if isinstance(p_tmpl, Variable): + 
binding[str(p_tmpl)] = _to_term(triple.p) + if isinstance(o_tmpl, Variable): + binding[str(o_tmpl)] = _to_term(triple.o) + new_solutions.append(binding) + + solutions = new_solutions + if not solutions: + return async def _eval_left_join(node, tc, collection, limit): - """Evaluate a LeftJoin node (OPTIONAL).""" - left_sols = await evaluate(node.p1, tc, collection, limit) - right_sols = await evaluate(node.p2, tc, collection, limit) + # OPTIONAL: materialise both sides, then left-join them in memory + left_sols = await materialise(node.p1, tc, collection, limit) + right_sols = await materialise(node.p2, tc, collection, limit) filter_fn = None if hasattr(node, "expr") and node.expr is not None: @@ -149,42 +320,35 @@ async def _eval_left_join(node, tc, collection, limit): evaluate_expression(expr, sol) ) - return left_join(left_sols, right_sols, filter_fn)[:limit] + for sol in left_join(left_sols, right_sols, filter_fn)[:limit]: + yield sol -async def _eval_union(node, tc, collection, limit): - """Evaluate a Union node.""" - left = await evaluate(node.p1, tc, collection, limit) - right = await evaluate(node.p2, tc, collection, limit) - return union(left, right)[:limit] - - -async def _eval_filter(node, tc, collection, limit): - """Evaluate a Filter node.""" - solutions = await evaluate(node.p, tc, collection, limit) - expr = node.expr - return [ - sol for sol in solutions - if _effective_boolean(evaluate_expression(expr, sol)) - ] +async def _eval_minus(node, tc, collection, limit): + left = await materialise(node.p1, tc, collection, limit) + right = await materialise(node.p2, tc, collection, limit) + for sol in minus(left, right): + yield sol async def _eval_distinct(node, tc, collection, limit): - """Evaluate a Distinct node.""" - solutions = await evaluate(node.p, tc, collection, limit) - return distinct(solutions) + seen = set() + async for sol in evaluate(node.p, tc, collection, limit): + key = tuple(sorted( + (k, _term_key(v)) for k, v in sol.items() + )) + if key
not in seen: + seen.add(key) + yield sol async def _eval_reduced(node, tc, collection, limit): - """Evaluate a Reduced node (like Distinct but implementation-defined).""" - # Treat same as Distinct - solutions = await evaluate(node.p, tc, collection, limit) - return distinct(solutions) + async for sol in _eval_distinct(node, tc, collection, limit): + yield sol async def _eval_order_by(node, tc, collection, limit): - """Evaluate an OrderBy node.""" - solutions = await evaluate(node.p, tc, collection, limit) + solutions = await materialise(node.p, tc, collection, limit) key_fns = [] for cond in node.expr: @@ -196,36 +360,104 @@ async def _eval_order_by(node, tc, collection, limit): ascending, )) else: - # Simple variable or expression key_fns.append(( lambda sol, e=cond: evaluate_expression(e, sol), True, )) - return order_by(solutions, key_fns) + for sol in order_by(solutions, key_fns): + yield sol +# --- Streamable operators --- + async def _eval_slice(node, tc, collection, limit): - """Evaluate a Slice node (LIMIT/OFFSET).""" - # Pass tighter limit downstream if possible - inner_limit = limit - if node.length is not None: - offset = node.start or 0 - inner_limit = min(limit, offset + node.length) + offset = node.start or 0 + length = node.length + skipped = 0 + emitted = 0 - solutions = await evaluate(node.p, tc, collection, inner_limit) - return slice_solutions(solutions, node.start or 0, node.length) + async for sol in evaluate(node.p, tc, collection, limit): + if skipped < offset: + skipped += 1 + continue + yield sol + emitted += 1 + if length is not None and emitted >= length: + return + + +async def _eval_union(node, tc, collection, limit): + async for sol in evaluate(node.p1, tc, collection, limit): + yield sol + async for sol in evaluate(node.p2, tc, collection, limit): + yield sol + + +async def _check_exists(graph_node, sol, tc, collection, limit): + """Evaluate an EXISTS graph pattern against a solution.""" + async for r in evaluate(graph_node, tc, 
collection, limit): + shared = set(sol.keys()) & set(r.keys()) + if all( + _term_key(sol[v]) == _term_key(r[v]) + for v in shared + if sol.get(v) is not None and r.get(v) is not None + ): + return True + return False + + +async def _pre_eval_exists(expr, sol, tc, collection, limit, cache): + """Walk an expression tree, pre-evaluate EXISTS/NOT EXISTS, cache results.""" + if not isinstance(expr, CompValue): + return + if expr.name in ("Builtin_EXISTS", "Builtin_NOTEXISTS"): + key = id(expr.graph), id(sol) + if key not in cache: + cache[key] = await _check_exists( + expr.graph, sol, tc, collection, limit + ) + return + for attr in ("expr", "other", "arg", "arg1", "arg2", "arg3"): + child = getattr(expr, attr, None) + if child is None: + continue + if isinstance(child, CompValue): + await _pre_eval_exists(child, sol, tc, collection, limit, cache) + elif isinstance(child, (list, tuple)): + for item in child: + if isinstance(item, CompValue): + await _pre_eval_exists( + item, sol, tc, collection, limit, cache + ) + + +async def _eval_filter(node, tc, collection, limit): + expr = node.expr + exists_cache = {} + + def exists_cb(graph_node, sol): + key = id(graph_node), id(sol) + return exists_cache.get(key, False) + + async for sol in evaluate(node.p, tc, collection, limit): + await _pre_eval_exists(expr, sol, tc, collection, limit, exists_cache) + if _effective_boolean(evaluate_expression(expr, sol, exists_cb=exists_cb)): + yield sol async def _eval_extend(node, tc, collection, limit): - """Evaluate an Extend node (BIND).""" - solutions = await evaluate(node.p, tc, collection, limit) var_name = str(node.var) expr = node.expr + exists_cache = {} - result = [] - for sol in solutions: - val = evaluate_expression(expr, sol) + def exists_cb(graph_node, sol): + key = id(graph_node), id(sol) + return exists_cache.get(key, False) + + async for sol in evaluate(node.p, tc, collection, limit): + await _pre_eval_exists(expr, sol, tc, collection, limit, exists_cache) + val = 
evaluate_expression(expr, sol, exists_cb=exists_cb) new_sol = dict(sol) if isinstance(val, Term): new_sol[var_name] = val @@ -240,16 +472,14 @@ async def _eval_extend(node, tc, collection, limit): ) elif val is not None: new_sol[var_name] = Term(type=LITERAL, value=str(val)) - result.append(new_sol) + yield new_sol - return result +# --- Aggregation (blocking) --- async def _eval_group(node, tc, collection, limit): - """Evaluate a Group node (GROUP BY with aggregation).""" - solutions = await evaluate(node.p, tc, collection, limit) + solutions = await materialise(node.p, tc, collection, limit) - # Extract grouping expressions group_exprs = [] if hasattr(node, "expr") and node.expr: for expr in node.expr: @@ -260,7 +490,6 @@ async def _eval_group(node, tc, collection, limit): else: group_exprs.append((expr, None)) - # Group solutions groups = defaultdict(list) for sol in solutions: key_parts = [] @@ -270,81 +499,72 @@ async def _eval_group(node, tc, collection, limit): groups[tuple(key_parts)].append(sol) if not group_exprs: - # No GROUP BY - entire result is one group groups[()].extend(solutions) - # Build grouped solutions (one per group) - result = [] for key, group_sols in groups.items(): sol = {} - # Include group key variables if group_sols: for (expr, var_name), k in zip(group_exprs, key): if var_name and group_sols: sol[var_name] = evaluate_expression(expr, group_sols[0]) sol["__group__"] = group_sols - result.append(sol) - - return result + yield sol async def _eval_aggregate_join(node, tc, collection, limit): - """Evaluate an AggregateJoin (aggregation functions after GROUP BY).""" - solutions = await evaluate(node.p, tc, collection, limit) - - result = [] - for sol in solutions: + async for sol in evaluate(node.p, tc, collection, limit): group = sol.get("__group__", [sol]) new_sol = {k: v for k, v in sol.items() if k != "__group__"} - # Apply aggregate functions if hasattr(node, "A") and node.A: for agg in node.A: var_name = str(agg.res) agg_val = 
_compute_aggregate(agg, group) new_sol[var_name] = agg_val - result.append(new_sol) - - return result + yield new_sol async def _eval_graph(node, tc, collection, limit): - """Evaluate a Graph node (GRAPH clause).""" term = node.term if isinstance(term, URIRef): - # GRAPH { ... } — fixed graph - # We'd need to pass graph to triples queries - # For now, evaluate inner pattern normally logger.info(f"GRAPH <{term}> clause - graph filtering not yet wired") - return await evaluate(node.p, tc, collection, limit) elif isinstance(term, Variable): - # GRAPH ?g { ... } — variable graph logger.info(f"GRAPH ?{term} clause - variable graph not yet wired") - return await evaluate(node.p, tc, collection, limit) - else: - return await evaluate(node.p, tc, collection, limit) + + async for sol in evaluate(node.p, tc, collection, limit): + yield sol async def _eval_values(node, tc, collection, limit): - """Evaluate a VALUES clause (inline data).""" - variables = [str(v) for v in node.var] - solutions = [] + # rdflib has two representations for VALUES: + # 1. var=[Variable...], value=[[val, ...], ...] — positional + # 2. var=None, res=[{Variable: val, ...}, ...] 
— dict-based + if hasattr(node, "res") and node.res: + for row in node.res: + sol = {} + for var, val in row.items(): + if val is not None and str(val) != "UNDEF": + sol[str(var)] = rdflib_term_to_term(val) + yield sol + return + if not node.var or not node.value: + yield {} + return + variables = [str(v) for v in node.var] for row in node.value: sol = {} for var_name, val in zip(variables, row): if val is not None and str(val) != "UNDEF": sol[var_name] = rdflib_term_to_term(val) - solutions.append(sol) - - return solutions + yield sol async def _eval_to_multiset(node, tc, collection, limit): - """Evaluate a ToMultiSet node (subquery).""" - return await evaluate(node.p, tc, collection, limit) + async for sol in evaluate(node.p, tc, collection, limit): + yield sol # --- Aggregate computation --- @@ -353,7 +573,6 @@ def _compute_aggregate(agg, group): """Compute a single aggregate function over a group of solutions.""" agg_name = agg.name if hasattr(agg, "name") else "" - # Get the expression to aggregate expr = agg.vars if hasattr(agg, "vars") else None if agg_name == "Aggregate_Count": @@ -525,6 +744,7 @@ _HANDLERS = { "Join": _eval_join, "LeftJoin": _eval_left_join, "Union": _eval_union, + "Minus": _eval_minus, "Filter": _eval_filter, "Distinct": _eval_distinct, "Reduced": _eval_reduced, diff --git a/trustgraph-flow/trustgraph/query/sparql/expressions.py b/trustgraph-flow/trustgraph/query/sparql/expressions.py index eac1199c..608eeff2 100644 --- a/trustgraph-flow/trustgraph/query/sparql/expressions.py +++ b/trustgraph-flow/trustgraph/query/sparql/expressions.py @@ -5,9 +5,15 @@ Evaluates rdflib algebra expression nodes against a solution (variable binding) to produce a value or boolean result. 
""" +import hashlib +import math +import random import re import logging import operator +import uuid +from datetime import datetime, date, timezone +from urllib.parse import quote from rdflib.term import Variable, URIRef, Literal, BNode from rdflib.plugins.sparql.parserutils import CompValue @@ -17,23 +23,31 @@ from . parser import rdflib_term_to_term logger = logging.getLogger(__name__) +_exists_callback = None + class ExpressionError(Exception): """Raised when a SPARQL expression cannot be evaluated.""" pass -def evaluate_expression(expr, solution): +def evaluate_expression(expr, solution, exists_cb=None): """ Evaluate a SPARQL expression against a solution binding. Args: expr: rdflib algebra expression node solution: dict mapping variable names to Term values + exists_cb: optional callback(graph_node, solution) -> bool for + EXISTS/NOT EXISTS evaluation; provided by algebra.py Returns: The result value (Term, bool, number, string, or None) """ + global _exists_callback + if exists_cb is not None: + _exists_callback = exists_cb + if expr is None: return True @@ -111,6 +125,13 @@ def _evaluate_comp_value(node, solution): if name == "MultiplicativeExpression": return _eval_multiplicative(node, solution) + # IN / NOT IN — must be checked before the generic Builtin_ dispatch + if name == "Builtin_IN": + return _eval_in(node, solution) + + if name == "Builtin_NOTIN": + return not _eval_in(node, solution) + # SPARQL built-in functions if name.startswith("Builtin_"): return _eval_builtin(name, node, solution) @@ -119,27 +140,10 @@ def _evaluate_comp_value(node, solution): if name == "Function": return _eval_function(node, solution) - # Exists / NotExists - if name == "Builtin_EXISTS": - # EXISTS requires graph pattern evaluation - not handled here - logger.warning("EXISTS not supported in filter expressions") - return True - - if name == "Builtin_NOTEXISTS": - logger.warning("NOT EXISTS not supported in filter expressions") - return True - # TrueFilter (used with 
OPTIONAL) if name == "TrueFilter": return True - # IN / NOT IN - if name == "Builtin_IN": - return _eval_in(node, solution) - - if name == "Builtin_NOTIN": - return not _eval_in(node, solution) - logger.warning(f"Unknown CompValue expression: {name}") return None @@ -165,6 +169,22 @@ def _eval_relational(node, solution): ">=": operator.ge, } + if str(op) == "IN": + items = node.other if isinstance(node.other, list) else [node.other] + for item in items: + other_val = evaluate_expression(item, solution) + if _comparable_value(left) == _comparable_value(other_val): + return True + return False + + if str(op) == "NOT IN": + items = node.other if isinstance(node.other, list) else [node.other] + for item in items: + other_val = evaluate_expression(item, solution) + if _comparable_value(left) == _comparable_value(other_val): + return False + return True + op_fn = ops.get(str(op)) if op_fn is None: logger.warning(f"Unknown relational operator: {op}") @@ -335,6 +355,197 @@ def _eval_builtin(name, node, solution): return val return None + if builtin == "YEAR": + dt = _to_datetime(evaluate_expression(node.arg, solution)) + return dt.year if dt is not None else None + + if builtin == "MONTH": + dt = _to_datetime(evaluate_expression(node.arg, solution)) + return dt.month if dt is not None else None + + if builtin == "DAY": + dt = _to_datetime(evaluate_expression(node.arg, solution)) + return dt.day if dt is not None else None + + if builtin == "HOURS": + dt = _to_datetime(evaluate_expression(node.arg, solution)) + if dt is None: + return None + return dt.hour if isinstance(dt, datetime) else 0 + + if builtin == "MINUTES": + dt = _to_datetime(evaluate_expression(node.arg, solution)) + if dt is None: + return None + return dt.minute if isinstance(dt, datetime) else 0 + + if builtin == "SECONDS": + dt = _to_datetime(evaluate_expression(node.arg, solution)) + if dt is None: + return None + return dt.second if isinstance(dt, datetime) else 0 + + if builtin == "FLOOR": + val = 
_to_numeric(evaluate_expression(node.arg, solution)) + if val is None: + return None + return int(math.floor(val)) + + if builtin == "CEIL": + val = _to_numeric(evaluate_expression(node.arg, solution)) + if val is None: + return None + return int(math.ceil(val)) + + if builtin == "ABS": + val = _to_numeric(evaluate_expression(node.arg, solution)) + if val is None: + return None + return abs(val) + + if builtin == "ROUND": + val = _to_numeric(evaluate_expression(node.arg, solution)) + if val is None: + return None + return round(val) + + if builtin == "STRBEFORE": + string = _to_string(evaluate_expression(node.arg1, solution)) + sep = _to_string(evaluate_expression(node.arg2, solution)) + idx = string.find(sep) + if idx < 0: + return Term(type=LITERAL, value="") + return Term(type=LITERAL, value=string[:idx]) + + if builtin == "STRAFTER": + string = _to_string(evaluate_expression(node.arg1, solution)) + sep = _to_string(evaluate_expression(node.arg2, solution)) + idx = string.find(sep) + if idx < 0: + return Term(type=LITERAL, value="") + return Term(type=LITERAL, value=string[idx + len(sep):]) + + if builtin == "ENCODE_FOR_URI": + val = _to_string(evaluate_expression(node.arg, solution)) + return Term(type=LITERAL, value=quote(val, safe="")) + + if builtin == "REPLACE": + string = _to_string(evaluate_expression(node.arg, solution)) + pattern = _to_string(evaluate_expression(node.pattern, solution)) + replacement = _to_string( + evaluate_expression(node.replacement, solution) + ) + flags_str = "" + if hasattr(node, "flags") and node.flags is not None: + flags_str = _to_string(evaluate_expression(node.flags, solution)) + re_flags = 0 + if "i" in flags_str: + re_flags |= re.IGNORECASE + if "m" in flags_str: + re_flags |= re.MULTILINE + if "s" in flags_str: + re_flags |= re.DOTALL + try: + result = re.sub(pattern, replacement, string, flags=re_flags) + return Term(type=LITERAL, value=result) + except re.error: + return None + + if builtin == "SUBSTR": + string = 
_to_string(evaluate_expression(node.arg, solution)) + start = _to_numeric(evaluate_expression(node.start, solution)) + if start is None: + return None + start_idx = max(int(start) - 1, 0) + if hasattr(node, "length") and node.length is not None: + length = _to_numeric(evaluate_expression(node.length, solution)) + if length is None: + return None + return Term( + type=LITERAL, value=string[start_idx:start_idx + int(length)] + ) + return Term(type=LITERAL, value=string[start_idx:]) + + if builtin == "EXISTS": + if _exists_callback is not None: + return _exists_callback(node.graph, solution) + logger.warning("EXISTS requires an exists_cb; not available") + return True + + if builtin == "NOTEXISTS": + if _exists_callback is not None: + return not _exists_callback(node.graph, solution) + logger.warning("NOT EXISTS requires an exists_cb; not available") + return True + + if builtin == "LANGMATCHES": + tag = _to_string(evaluate_expression(node.arg1, solution)) + rng = _to_string(evaluate_expression(node.arg2, solution)) + if rng == "*": + return len(tag) > 0 + return tag.lower().startswith(rng.lower()) + + if builtin == "IRI" or builtin == "URI": + val = _to_string(evaluate_expression(node.arg, solution)) + return Term(type=IRI, iri=val) + + if builtin == "BNODE": + if hasattr(node, "arg") and node.arg is not None: + label = _to_string(evaluate_expression(node.arg, solution)) + return Term(type=BLANK, id=label) + return Term(type=BLANK, id=str(uuid.uuid4())) + + if builtin == "NOW": + now = datetime.now(timezone.utc) + return Term( + type=LITERAL, + value=now.strftime("%Y-%m-%dT%H:%M:%S%z"), + datatype="http://www.w3.org/2001/XMLSchema#dateTime", + ) + + if builtin == "TZ": + dt = _to_datetime(evaluate_expression(node.arg, solution)) + if dt is None: + return Term(type=LITERAL, value="") + if dt.tzinfo is not None: + offset = dt.strftime("%z") + if offset: + return Term(type=LITERAL, value=offset[:3] + ":" + offset[3:]) + return Term(type=LITERAL, value="") + + if builtin 
== "RAND": + return random.random() + + if builtin == "UUID": + return Term(type=IRI, iri="urn:uuid:" + str(uuid.uuid4())) + + if builtin == "STRUUID": + return Term(type=LITERAL, value=str(uuid.uuid4())) + + if builtin == "MD5": + val = _to_string(evaluate_expression(node.arg, solution)) + return Term( + type=LITERAL, value=hashlib.md5(val.encode()).hexdigest() + ) + + if builtin == "SHA1": + val = _to_string(evaluate_expression(node.arg, solution)) + return Term( + type=LITERAL, value=hashlib.sha1(val.encode()).hexdigest() + ) + + if builtin == "SHA256": + val = _to_string(evaluate_expression(node.arg, solution)) + return Term( + type=LITERAL, value=hashlib.sha256(val.encode()).hexdigest() + ) + + if builtin == "SHA512": + val = _to_string(evaluate_expression(node.arg, solution)) + return Term( + type=LITERAL, value=hashlib.sha512(val.encode()).hexdigest() + ) + if builtin == "sameTerm": left = evaluate_expression(node.arg1, solution) right = evaluate_expression(node.arg2, solution) @@ -454,6 +665,27 @@ def _to_numeric(val): return None +def _to_datetime(val): + """Convert a value to a date or datetime object.""" + if val is None: + return None + s = _to_string(val) + for fmt in ( + "%Y-%m-%dT%H:%M:%S", "%Y-%m-%dT%H:%M:%S%z", + "%Y-%m-%dT%H:%M:%SZ", "%Y-%m-%dT%H:%M", + "%Y-%m-%d", + ): + try: + return datetime.strptime(s, fmt) + except ValueError: + continue + try: + return datetime.fromisoformat(s) + except ValueError: + pass + return None + + def _comparable_value(val): """ Convert a value to a form suitable for comparison. diff --git a/trustgraph-flow/trustgraph/query/sparql/service.py b/trustgraph-flow/trustgraph/query/sparql/service.py index 75c00dba..bbe375f0 100644 --- a/trustgraph-flow/trustgraph/query/sparql/service.py +++ b/trustgraph-flow/trustgraph/query/sparql/service.py @@ -12,7 +12,7 @@ from ... base import FlowProcessor, ConsumerSpec, ProducerSpec from ... base import TriplesClientSpec from . parser import parse_sparql, ParseError -from . 
algebra import evaluate, EvaluationError +from . algebra import evaluate, materialise, EvaluationError logger = logging.getLogger(__name__) @@ -66,11 +66,10 @@ class Processor(FlowProcessor): logger.debug(f"Handling SPARQL query request {id}...") - response = await self.execute_sparql(request, flow) - - if request.streaming and response.query_type == "select": - await self.send_streaming(response, flow, id, request) + if request.streaming: + await self.execute_sparql_streaming(request, flow, id) else: + response = await self.execute_sparql(request, flow) await flow("response").send( response, properties={"id": id} ) @@ -92,37 +91,77 @@ class Processor(FlowProcessor): await flow("response").send(r, properties={"id": id}) - async def send_streaming(self, response, flow, id, request): - """Send SELECT results in batches.""" + async def execute_sparql_streaming(self, request, flow, id): + """Execute a SPARQL query and stream results as they arrive.""" - bindings = response.bindings + try: + parsed = parse_sparql(request.query) + except ParseError as e: + await flow("response").send( + SparqlQueryResponse( + error=Error( + type="sparql-parse-error", + message=str(e), + ), + ), + properties={"id": id} + ) + return + + if parsed.query_type != "select": + response = await self._execute_non_select(parsed, request, flow) + await flow("response").send(response, properties={"id": id}) + return + + triples_client = flow("triples-request") + variables = parsed.variables batch_size = request.batch_size if request.batch_size > 0 else 20 + batch = [] - for i in range(0, len(bindings), batch_size): - batch = bindings[i:i + batch_size] - is_final = (i + batch_size >= len(bindings)) - r = SparqlQueryResponse( - query_type=response.query_type, - variables=response.variables, - bindings=batch, - is_final=is_final, - ) - await flow("response").send(r, properties={"id": id}) + try: + async for sol in evaluate( + parsed.algebra, + triples_client, + collection=request.collection or 
"default", + limit=request.limit or 10000, + ): + values = [sol.get(v) for v in variables] + batch.append(SparqlBinding(values=values)) - # Handle empty results - if len(bindings) == 0: - r = SparqlQueryResponse( - query_type=response.query_type, - variables=response.variables, - bindings=[], - is_final=True, + if len(batch) >= batch_size: + r = SparqlQueryResponse( + query_type="select", + variables=variables, + bindings=batch, + is_final=False, + ) + await flow("response").send(r, properties={"id": id}) + batch = [] + + except EvaluationError as e: + await flow("response").send( + SparqlQueryResponse( + error=Error( + type="sparql-evaluation-error", + message=str(e), + ), + ), + properties={"id": id} ) - await flow("response").send(r, properties={"id": id}) + return + + # Final batch (may be empty for zero results) + r = SparqlQueryResponse( + query_type="select", + variables=variables, + bindings=batch, + is_final=True, + ) + await flow("response").send(r, properties={"id": id}) async def execute_sparql(self, request, flow): - """Parse and evaluate a SPARQL query.""" + """Parse and evaluate a SPARQL query (non-streaming).""" - # Parse the SPARQL query try: parsed = parse_sparql(request.query) except ParseError as e: @@ -133,12 +172,31 @@ class Processor(FlowProcessor): ), ) - # Get the triples client from the flow - triples_client = flow("triples-request") + if parsed.query_type == "select": + triples_client = flow("triples-request") + try: + solutions = await materialise( + parsed.algebra, + triples_client, + collection=request.collection or "default", + limit=request.limit or 10000, + ) + except EvaluationError as e: + return SparqlQueryResponse( + error=Error( + type="sparql-evaluation-error", + message=str(e), + ), + ) + return self._build_select_response(parsed, solutions) - # Evaluate the algebra + return await self._execute_non_select(parsed, request, flow) + + async def _execute_non_select(self, parsed, request, flow): + """Execute ASK, CONSTRUCT, or 
DESCRIBE queries.""" + triples_client = flow("triples-request") try: - solutions = await evaluate( + solutions = await materialise( parsed.algebra, triples_client, collection=request.collection or "default", @@ -152,10 +210,7 @@ class Processor(FlowProcessor): ), ) - # Build response based on query type - if parsed.query_type == "select": - return self._build_select_response(parsed, solutions) - elif parsed.query_type == "ask": + if parsed.query_type == "ask": return self._build_ask_response(solutions) elif parsed.query_type == "construct": return self._build_construct_response(parsed, solutions) diff --git a/trustgraph-flow/trustgraph/query/sparql/solutions.py b/trustgraph-flow/trustgraph/query/sparql/solutions.py index d1ea8373..edf3401d 100644 --- a/trustgraph-flow/trustgraph/query/sparql/solutions.py +++ b/trustgraph-flow/trustgraph/query/sparql/solutions.py @@ -150,6 +150,30 @@ def left_join(left, right, filter_fn=None): return results +def minus(left, right): + """ + MINUS operation: remove left solutions that are compatible with any + right solution sharing at least one variable. + """ + if not right: + return list(left) + + right_vars = set() + for sol in right: + right_vars.update(sol.keys()) + + results = [] + for sol_l in left: + shared = set(sol_l.keys()) & right_vars + if not shared: + results.append(sol_l) + continue + if not any(_compatible(sol_l, sol_r) for sol_r in right): + results.append(sol_l) + + return results + + def union(left, right): """Union two solution sequences (concatenation).""" return list(left) + list(right) @@ -177,6 +201,28 @@ def distinct(solutions): return results +def _sort_comparable(val): + """Convert a value to a form suitable for sort ordering.""" + if val is None: + return (0, "") + if isinstance(val, (int, float)): + return (2, val) + if isinstance(val, Term): + if val.type == LITERAL: + try: + if "." 
in val.value: + return (2, float(val.value)) + return (2, int(val.value)) + except (ValueError, TypeError): + pass + return (3, val.value) + elif val.type == IRI: + return (4, val.iri) + elif val.type == BLANK: + return (5, val.id) + return (6, str(val)) + + def order_by(solutions, key_fns): """ Sort solutions by the given key functions. @@ -191,14 +237,7 @@ def order_by(solutions, key_fns): keys = [] for fn, ascending in key_fns: val = fn(sol) - # Convert to comparable form - if val is None: - comparable = ("", "") - elif isinstance(val, Term): - comparable = _term_key(val) - else: - comparable = ("v", str(val)) - keys.append(comparable) + keys.append(_sort_comparable(val)) return keys # Handle ascending/descending @@ -224,10 +263,8 @@ def _mixed_sort(solutions, key_fns): def compare(a, b): for fn, ascending in key_fns: - va = fn(a) - vb = fn(b) - ka = _term_key(va) if isinstance(va, Term) else ("v", str(va)) if va is not None else ("", "") - kb = _term_key(vb) if isinstance(vb, Term) else ("v", str(vb)) if vb is not None else ("", "") + ka = _sort_comparable(fn(a)) + kb = _sort_comparable(fn(b)) if ka < kb: return -1 if ascending else 1 diff --git a/trustgraph-flow/trustgraph/rev_gateway/dispatcher.py b/trustgraph-flow/trustgraph/rev_gateway/dispatcher.py index 986558ec..5e71a0e5 100644 --- a/trustgraph-flow/trustgraph/rev_gateway/dispatcher.py +++ b/trustgraph-flow/trustgraph/rev_gateway/dispatcher.py @@ -1,130 +1,140 @@ import asyncio import logging import uuid -from typing import Dict, Any, Optional -from trustgraph.messaging import TranslatorRegistry +from typing import Dict, Any, Optional, Callable, Awaitable from ..gateway.dispatch.manager import DispatcherManager logger = logging.getLogger("dispatcher") -logger.setLevel(logging.INFO) -class WebSocketResponder: - """Simple responder that captures response for websocket return""" - def __init__(self): - self.response = None - self.completed = False - - async def send(self, data): - """Capture the response 
data""" - self.response = data - self.completed = True - - async def __call__(self, data, final=False): - """Make the responder callable for compatibility with requestor""" - await self.send(data) - if final: - self.completed = True + +class _TokenShim: + def __init__(self, token): + self.headers = ( + {"Authorization": f"Bearer {token}"} if token else {} + ) + class MessageDispatcher: - def __init__(self, max_workers: int = 10, config_receiver=None, backend=None): + def __init__(self, max_workers=10, config_receiver=None, backend=None, + auth=None, timeout=120): self.max_workers = max_workers self.semaphore = asyncio.Semaphore(max_workers) self.active_tasks = set() self.backend = backend + self.auth = auth - # Use DispatcherManager for flow and service management - if backend and config_receiver: - self.dispatcher_manager = DispatcherManager(backend, config_receiver, prefix="rev-gateway") + if backend and config_receiver and auth: + self.dispatcher_manager = DispatcherManager( + backend, config_receiver, + auth=auth, + prefix="rev-gateway", + timeout=timeout, + ) else: self.dispatcher_manager = None - logger.warning("No backend or config_receiver provided - using fallback mode") - - # Service name mapping from websocket protocol to translator registry + logger.warning( + "Missing backend, config_receiver, or auth " + "— using fallback mode" + ) + self.service_mapping = { "text-completion": "text-completion", - "graph-rag": "graph-rag", + "graph-rag": "graph-rag", "agent": "agent", "embeddings": "embeddings", "graph-embeddings": "graph-embeddings", "triples": "triples", "document-load": "document", - "text-load": "text-document", + "text-load": "text-document", "flow": "flow", "knowledge": "knowledge", "config": "config", "librarian": "librarian", - "document-rag": "document-rag" + "document-rag": "document-rag", } - - async def handle_message(self, message: Dict[Any, Any]) -> Optional[Dict[Any, Any]]: + + async def handle_message( + self, message: Dict[Any, Any], 
+ sender: Callable[[dict], Awaitable[None]], + ): async with self.semaphore: - task = asyncio.create_task(self._process_message(message)) + task = asyncio.create_task( + self._process_message(message, sender) + ) self.active_tasks.add(task) - + try: - result = await task - return result + await task finally: self.active_tasks.discard(task) - - async def _process_message(self, message: Dict[Any, Any]) -> Dict[Any, Any]: + + async def _authenticate(self, token): + if not self.auth: + raise RuntimeError("Auth not configured") + return await self.auth.authenticate(_TokenShim(token)) + + async def _process_message( + self, message: Dict[Any, Any], + sender: Callable[[dict], Awaitable[None]], + ): request_id = message.get('id', str(uuid.uuid4())) service = message.get('service') request_data = message.get('request', {}) - flow_id = message.get('flow', 'default') # Default flow - - logger.info(f"Processing message {request_id} for service {service} on flow {flow_id}") - + token = message.get('token', '') + flow_id = message.get('flow', 'default') + + logger.info( + f"Processing message {request_id} for service " + f"{service} on flow {flow_id}" + ) + try: if not self.dispatcher_manager: - raise RuntimeError("DispatcherManager not available - backend and config_receiver required") - - # Use DispatcherManager for flow-based processing - responder = WebSocketResponder() - - # Map websocket service name to dispatcher service name + raise RuntimeError( + "DispatcherManager not available" + ) + + identity = await self._authenticate(token) + workspace = identity.workspace + + async def responder(resp, fin): + await sender({ + "id": request_id, + "response": resp, + "complete": fin, + }) + dispatcher_service = self.service_mapping.get(service, service) - - # Check if this is a global service or flow service + from ..gateway.dispatch.manager import global_dispatchers if dispatcher_service in global_dispatchers: - # Use global service dispatcher await 
self.dispatcher_manager.invoke_global_service( - request_data, responder, dispatcher_service + request_data, responder, dispatcher_service, + workspace=workspace, ) else: - # Use DispatcherManager to process the request through Pulsar queues await self.dispatcher_manager.invoke_flow_service( - request_data, responder, flow_id, dispatcher_service + request_data, responder, workspace, flow_id, + dispatcher_service, ) - - # Get the response from the responder - if responder.completed: - response_data = responder.response - else: - response_data = {'error': 'No response received'} - - response = { - 'id': request_id, - 'response': response_data - } - + except Exception as e: logger.error(f"Error processing message {request_id}: {e}") - response = { - 'id': request_id, - 'response': {'error': str(e)} - } - + await sender({ + "id": request_id, + "error": {"message": str(e), "type": "error"}, + "complete": True, + }) + logger.info(f"Completed processing message {request_id}") - return response - - + async def shutdown(self): if self.active_tasks: - logger.info(f"Waiting for {len(self.active_tasks)} active tasks to complete") + logger.info( + f"Waiting for {len(self.active_tasks)} active " + f"tasks to complete" + ) await asyncio.gather(*self.active_tasks, return_exceptions=True) - - # DispatcherManager handles its own cleanup + logger.info("Dispatcher shutdown complete") diff --git a/trustgraph-flow/trustgraph/rev_gateway/service.py b/trustgraph-flow/trustgraph/rev_gateway/service.py index cc905172..39180e2c 100644 --- a/trustgraph-flow/trustgraph/rev_gateway/service.py +++ b/trustgraph-flow/trustgraph/rev_gateway/service.py @@ -1,3 +1,9 @@ +""" +Reverse gateway. Initiates outbound WebSocket connections to a remote +relay and dispatches incoming requests through the same DispatcherManager +pipeline as api-gateway. 
+""" + import asyncio import argparse import logging @@ -9,82 +15,86 @@ from typing import Optional from urllib.parse import urlparse, urlunparse from .dispatcher import MessageDispatcher +from ..gateway.auth import IamAuth from ..gateway.config.receiver import ConfigReceiver -from ..base import get_pubsub +from ..base.pubsub import get_pubsub, add_pubsub_args +from ..base.logging import setup_logging, add_logging_args logger = logging.getLogger("rev_gateway") -logger.setLevel(logging.INFO) default_websocket = "ws://localhost:7650/out" +default_timeout = 600 class ReverseGateway: - - def __init__(self, websocket_uri: str = None, max_workers: int = 10, - pulsar_host: str = None, pulsar_api_key: str = None, - pulsar_listener: str = None): - # Set default WebSocket URI with environment variable support + + def __init__(self, **config): + websocket_uri = config.get("websocket_uri") if websocket_uri is None: websocket_uri = os.getenv("WEBSOCKET_URI", default_websocket) - - # Parse and validate the WebSocket URI + parsed_uri = urlparse(websocket_uri) if parsed_uri.scheme not in ('ws', 'wss'): - raise ValueError(f"WebSocket URI must use ws:// or wss:// scheme, got: {parsed_uri.scheme}") + raise ValueError( + f"WebSocket URI must use ws:// or wss:// scheme, " + f"got: {parsed_uri.scheme}" + ) if not parsed_uri.netloc: - raise ValueError(f"WebSocket URI must include hostname, got: {websocket_uri}") - - # Store parsed components for debugging/logging + raise ValueError( + f"WebSocket URI must include hostname, " + f"got: {websocket_uri}" + ) + self.websocket_uri = websocket_uri self.host = parsed_uri.hostname self.port = parsed_uri.port self.scheme = parsed_uri.scheme self.path = parsed_uri.path or "/ws" - - # Construct the full URL (in case path was missing) + if not parsed_uri.path: self.url = f"{self.scheme}://{parsed_uri.netloc}/ws" else: self.url = websocket_uri - - self.max_workers = max_workers + + self.max_workers = int(config.get("max_workers", 10)) + self.timeout = 
int(config.get("timeout", default_timeout)) self.ws: Optional[ClientWebSocketResponse] = None self.session: Optional[ClientSession] = None self.running = False self.reconnect_delay = 3.0 - - # Pulsar configuration - self.pulsar_host = pulsar_host or os.getenv("PULSAR_HOST", "pulsar://pulsar:6650") - self.pulsar_api_key = pulsar_api_key or os.getenv("PULSAR_API_KEY", None) - self.pulsar_listener = pulsar_listener - # Create backend using factory - backend_params = { - 'pulsar_host': self.pulsar_host, - 'pulsar_api_key': self.pulsar_api_key, - 'pulsar_listener': self.pulsar_listener, - } - self.backend = get_pubsub(**backend_params) + self.backend = get_pubsub(**config) - # Initialize config receiver - self.config_receiver = ConfigReceiver(self.backend) + self.auth = IamAuth( + backend=self.backend, + id=config.get("id", "rev-gateway"), + ) + + self.config_receiver = ConfigReceiver( + self.backend, auth=self.auth, + ) + + self.dispatcher = MessageDispatcher( + self.max_workers, self.config_receiver, self.backend, + auth=self.auth, timeout=self.timeout, + ) - # Initialize dispatcher with config_receiver and backend - must be created after config_receiver - self.dispatcher = MessageDispatcher(max_workers, self.config_receiver, self.backend) - async def connect(self) -> bool: try: if self.session is None: self.session = ClientSession() - + logger.info(f"Connecting to {self.url}") self.ws = await self.session.ws_connect(self.url) - logger.info(f"WebSocket connection established to {self.host}:{self.port or 'default'}") + logger.info( + f"WebSocket connection established to " + f"{self.host}:{self.port or 'default'}" + ) return True - + except Exception as e: logger.error(f"Failed to connect to {self.url}: {e}") return False - + async def disconnect(self): if self.ws and not self.ws.closed: await self.ws.close() @@ -92,32 +102,31 @@ class ReverseGateway: await self.session.close() self.ws = None self.session = None - + async def send_message(self, message: dict): if 
self.ws and not self.ws.closed: try: await self.ws.send_str(json.dumps(message)) except Exception as e: logger.error(f"Failed to send message: {e}") - + async def handle_message(self, message: str): try: logger.debug(f"Received message: {message}") - + msg_data = json.loads(message) - response = await self.dispatcher.handle_message(msg_data) - - if response: - await self.send_message(response) - + await self.dispatcher.handle_message( + msg_data, self.send_message, + ) + except Exception as e: logger.error(f"Error handling message: {e}") - + async def listen(self): while self.running and self.ws and not self.ws.closed: try: msg = await self.ws.receive() - + if msg.type == WSMsgType.TEXT: await self.handle_message(msg.data) elif msg.type == WSMsgType.BINARY: @@ -125,31 +134,33 @@ class ReverseGateway: elif msg.type in (WSMsgType.CLOSE, WSMsgType.ERROR): logger.warning("WebSocket closed or error occurred") break - + except Exception as e: logger.error(f"Error in listen loop: {e}") break - + async def run(self): self.running = True logger.info("Starting reverse gateway") - - # Start config receiver - logger.info("Starting config receiver") + + await self.auth.start() await self.config_receiver.start() - + while self.running: try: if await self.connect(): await self.listen() else: - logger.warning(f"Connection failed, retrying in {self.reconnect_delay} seconds") - + logger.warning( + f"Connection failed, retrying in " + f"{self.reconnect_delay} seconds" + ) + await self.disconnect() - + if self.running: await asyncio.sleep(self.reconnect_delay) - + except KeyboardInterrupt: logger.info("Shutdown requested") break @@ -157,77 +168,69 @@ class ReverseGateway: logger.error(f"Unexpected error: {e}") if self.running: await asyncio.sleep(self.reconnect_delay) - + await self.shutdown() - + async def shutdown(self): logger.info("Shutting down reverse gateway") self.running = False await self.dispatcher.shutdown() await self.disconnect() - # Close backend if hasattr(self, 
'backend'): self.backend.close() - + def stop(self): self.running = False -def parse_args(): + +def run(): + parser = argparse.ArgumentParser( prog="reverse-gateway", - description="TrustGraph Reverse Gateway - WebSocket to Pulsar bridge" + description=__doc__, ) - + + parser.add_argument( + '--id', + default='rev-gateway', + help='Service identifier (default: rev-gateway)', + ) + parser.add_argument( '--websocket-uri', default=None, - help=f'WebSocket URI to connect to (default: {default_websocket} or WEBSOCKET_URI env var)' + help=f'WebSocket URI to connect to (default: {default_websocket})', ) - + parser.add_argument( '--max-workers', type=int, default=10, - help='Maximum concurrent message handlers (default: 10)' + help='Maximum concurrent message handlers (default: 10)', ) - - parser.add_argument( - '-p', '--pulsar-host', - default=None, - help='Pulsar host URL (default: pulsar://pulsar:6650 or PULSAR_HOST env var)' - ) - - parser.add_argument( - '--pulsar-api-key', - default=None, - help='Pulsar API key for authentication (default: PULSAR_API_KEY env var)' - ) - - parser.add_argument( - '--pulsar-listener', - default=None, - help='Pulsar listener name' - ) - - return parser.parse_args() -def run(): - args = parse_args() - - gateway = ReverseGateway( - websocket_uri=args.websocket_uri, - max_workers=args.max_workers, - pulsar_host=args.pulsar_host, - pulsar_api_key=args.pulsar_api_key, - pulsar_listener=args.pulsar_listener + parser.add_argument( + '--timeout', + type=int, + default=default_timeout, + help=f'Request timeout in seconds (default: {default_timeout})', ) - + + add_pubsub_args(parser) + add_logging_args(parser) + + args = parser.parse_args() + args = vars(args) + + setup_logging(args) + + gateway = ReverseGateway(**args) + logger.info(f"Starting reverse gateway:") logger.info(f" WebSocket URI: {gateway.url}") - logger.info(f" Max workers: {args.max_workers}") - logger.info(f" Pulsar host: {gateway.pulsar_host}") - + logger.info(f" Max workers: 
{gateway.max_workers}") + try: asyncio.run(gateway.run()) except KeyboardInterrupt: